I am trying to calculate the FLOPS of my model, which is a tf.keras model.

As a workaround, I am treating my model as a pure TensorFlow one, since I am not aware of a way to calculate FLOPS directly for a Keras model.

The problem I am facing is (apparently) that the shape of some layers is considered undefined, and I am getting an error.

import tensorflow as tf
import numpy as np

model = tf.keras.applications.ResNet50(
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000)
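# Count trainable parameters (for reference only)
nparams = np.sum([np.prod(v.get_shape().as_list()) for v in tf.compat.v1.trainable_variables()])
# Profile the float operations registered for the ops in the default graph
options = tf.profiler.ProfileOptionBuilder.float_operation()
options['output'] = 'none'  # suppress the per-op report
flops = tf.profiler.profile(tf.get_default_graph(), options=options).total_float_ops
# The profiler counts a multiply and an add as 2 FLOPs; halving gives multiply-adds
flops = flops // 2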

111 ops no flops stats due to incomplete shapes.

On the other hand, if I check the summary of the model, I cannot find any undefined shapes in the layers except for the batch size, and I do not think I can explicitly define the batch size.

model.summary()
Model: "resnet50"

input_1 (InputLayer) [(None, 224, 224, 3) 0
...

The problem is that, as I understand it, the returned FLOPS are not accurate. So how can I get the actual FLOPS of my model?

My TensorFlow version is 1.15, Keras is 2.2.5 and Keras-Applications is 1.0.8.

Eypros

1 Answer

After some research, I finally managed to find a solution. Some observations:

1) It seems that the None (batch) dimension in the input shape is enough for the profiler to skip these ops. The model should instead be built on an input with a fully defined shape, for example a placeholder with batch size 1:

ResNet50(include_top=True, weights="imagenet", input_tensor=tf.placeholder('float32', shape=(1, 32, 32, 3)), input_shape=None, pooling=None, classes=1000)
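To see why a single None is enough, here is a simplified sketch of how a per-op FLOP count is formed for a convolution. It assumes the rule TensorFlow registers for Conv2D (2 FLOPs per multiply-add, times the number of output elements); `conv2d_flops` is a hypothetical helper for illustration, not a TensorFlow function. Like the profiler, it gives up as soon as any dimension is unknown.

```python
from math import prod
from typing import Optional, Sequence, Tuple


def conv2d_flops(output_shape: Sequence[Optional[int]],
                 kernel_size: Tuple[int, int],
                 in_channels: int) -> Optional[int]:
    """Simplified FLOP count for one Conv2D op with an NHWC output shape.

    Every output element needs kernel_h * kernel_w * in_channels
    multiply-adds, counted as 2 FLOPs each. If any dimension is unknown
    (None), no count is possible, which is what "incomplete shapes" means.
    """
    if any(dim is None for dim in output_shape):
        return None
    kernel_h, kernel_w = kernel_size
    return 2 * prod(output_shape) * kernel_h * kernel_w * in_channels


# First ResNet50 convolution at 224x224: 7x7 kernel, 3 input channels,
# output of shape (batch, 112, 112, 64).
print(conv2d_flops((None, 112, 112, 64), (7, 7), 3))  # None: batch size unknown
print(conv2d_flops((1, 112, 112, 64), (7, 7), 3))     # 236027904
```

Because the count is linear in the batch size, fixing the batch to 1 gives the per-image FLOPS.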

2) This solution seems to be valid only for TensorFlow < 2. A workaround for TF 2.0+ is the following:

import tensorflow as tf


def get_flops(model_h5_path):
    # Use a TF1-style graph and session so the v1 profiler can inspect the model
    session = tf.compat.v1.Session()
    graph = tf.compat.v1.get_default_graph()

    with graph.as_default():
        with session.as_default():
            # Loading the model inside the graph context adds its ops to `graph`
            model = tf.keras.models.load_model(model_h5_path)

            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()

            # We use the Keras session graph in the call to the profiler.
            flops = tf.compat.v1.profiler.profile(graph=graph,
                                                  run_meta=run_meta, cmd='op', options=opts)

            return flops.total_float_ops

Taken from here.

3) The actual solution works only for frozen models. The good news is that this is how FLOPS are typically measured in the first place (on a frozen inference model, to be exact). So a working solution is:

import tensorflow as tf
import keras.backend as K
from keras.applications.resnet50 import ResNet50


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Return a GraphDef with the variables (except keep_var_names) replaced by constants."""
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            # Drop device placements so the frozen graph can be loaded anywhere
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


def load_pb(pb):
    """Load a serialized (frozen) GraphDef from disk into a new tf.Graph."""
    with tf.gfile.GFile(pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph


with tf.Session(graph=tf.Graph()) as sess:
    K.set_session(sess)
    # A placeholder with a fully defined shape (batch size 1) gives every op a known shape
    net = ResNet50(include_top=True, weights="imagenet", input_tensor=tf.placeholder('float32', shape=(1, 32, 32, 3)), input_shape=None, pooling=None, classes=1000)

    # Freeze the Keras session graph and save it as graph.pb
    frozen_graph = freeze_session(K.get_session(), output_names=[out.op.name for out in net.outputs])
    with tf.gfile.GFile('graph.pb', "wb") as f:
        f.write(frozen_graph.SerializeToString())

    # Reload the frozen graph and profile its float operations
    g2 = load_pb('./graph.pb')
    with g2.as_default():
        flops = tf.profiler.profile(g2, options=tf.profiler.ProfileOptionBuilder.float_operation())
        # Halve the count (multiply and add are counted separately) and convert to millions
        print('FLOP after freezing {} MFLOPS'.format(float(flops.total_float_ops//2) * 1e-6))

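This finally prints:

FLOP after freezing 80.87084 MFLOPS

So it calculates the FLOPS of the frozen model and, as a by-product, saves a pb model (graph.pb) to disk, which can of course be deleted afterwards. Note that this number is for the (1, 32, 32, 3) placeholder used above, not for the default 224x224 ImageNet input; use shape=(1, 224, 224, 3) to profile at that resolution.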
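The final print uses the same `// 2` as the question: the profiler counts a multiply and an add as two FLOPs, so halving gives multiply-adds, the unit many papers quote as FLOPs. As a rough sanity check of the 80.87 figure, the sketch below assumes that every feature map in Keras' ResNet50 is exactly 7x larger per side at 224x224 than at 32x32, so everything except the final 2048-to-1000 Dense layer (assumed to cost 2048 * 1000 multiply-adds, bias ignored) grows by about 49x. Both helper names are hypothetical.

```python
def to_mflops(total_float_ops: int) -> float:
    """Same conversion as the answer: halve the raw count, express in millions."""
    return float(total_float_ops // 2) * 1e-6


def estimate_at_resolution(mflops: float, fixed_mflops: float, side_ratio: int) -> float:
    """Rescale a count measured at one input size to a larger one.

    Assumes everything except `fixed_mflops` (e.g. the final Dense layer)
    scales with the feature-map area, i.e. with side_ratio ** 2.
    """
    return (mflops - fixed_mflops) * side_ratio ** 2 + fixed_mflops


# Raw count for ResNet50's first conv at batch 1 (2 * 112*112*64 * 7*7*3)
print(to_mflops(236_027_904))  # about 118.01 (millions of multiply-adds)

# Assumption: the Dense(1000) head on 2048 features costs 2048 * 1000 multiply-adds
dense_mflops = 2048 * 1000 * 1e-6
estimate = estimate_at_resolution(80.87084, dense_mflops, side_ratio=224 // 32)
print("Estimated MFLOPS at 224x224: {:.0f}".format(estimate))  # 3864
```

This only checks the order of magnitude (it lands in the same range as the roughly 3.8 billion multiply-adds usually cited for ResNet-50); profiling with a (1, 224, 224, 3) placeholder is the direct way to get the count at that resolution.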

The solution borrows heavily from these answers (just to be fair).
