Two approaches to counting operations in TensorFlow give different answers, and it is not clear how each one does the counting. Can anyone explain this?
I tried the two approaches below (`get_flops` and `get_flops_tfv2_1`). I didn't expect the two profiling approaches to give different numbers: `get_flops_tfv2_1` outputs 1440 (close to the theoretical number of operations), while `get_flops` outputs 120.
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv1D, GlobalAveragePooling1D
def model_gen(win_size, n_feats):
    """Build a minimal Conv1D -> GlobalAveragePooling1D Keras model.

    Args:
        win_size: Number of time steps in each input sequence.
        n_feats: Number of features per time step.

    Returns:
        An uncompiled ``Model`` with an input layer named ``'input'`` of
        shape ``(win_size, n_feats)``. The input passes through a 120-filter,
        kernel-size-1 ``Conv1D`` and is then averaged over the time axis.
    """
    inputs = Input(shape=(win_size, n_feats), name='input')
    conv = Conv1D(
        filters=120,
        kernel_size=1,
        dilation_rate=1,
        strides=1,
        padding='same',
    )
    features = conv(inputs)
    pooled = GlobalAveragePooling1D()(features)
    return Model(inputs=inputs, outputs=pooled)
def get_flops(model_h5_path):  # https://stackoverflow.com/a/59862883
    """Count float ops of a saved Keras model using the TF1 graph profiler.

    Loads the model into the default TF1 graph with eager execution disabled.
    It then runs ``tf.compat.v1.profiler.profile`` (``cmd='op'``) with the
    ``float_operation`` options over that whole graph.

    Args:
        model_h5_path: Path to a model file readable by
            ``tf.keras.models.load_model``.

    Returns:
        The profiler's ``total_float_ops``, with no halving applied.

    Side effects:
        Calls ``tf.compat.v1.disable_eager_execution()``, which switches the
        whole process to graph mode and is not undone here. The default graph
        is reset before returning, including when loading or profiling raises.
    """
    # NOTE(review): this is a process-wide switch. Any later TF code in the
    # same process (e.g. get_flops_tfv2_1) will no longer run eagerly.
    tf.compat.v1.disable_eager_execution()
    graph = tf.compat.v1.get_default_graph()
    try:
        # Entering the Session makes it the default session, so Keras loads
        # the model's variables into it. Leaving the block closes it and
        # releases its resources; previously the session was never closed.
        with graph.as_default(), tf.compat.v1.Session(graph=graph):
            # Loading populates `graph` with the model's ops; the returned
            # model object itself is not needed.
            tf.keras.models.load_model(model_h5_path)
            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
            # NOTE(review): the model's input has an unknown (None) batch
            # dimension in this graph. The profiler presumably cannot compute
            # float ops for nodes whose shapes are not fully defined. That is
            # a likely reason this count is far lower than the tf.function
            # approach, which fixes the batch size. Confirm before relying on it.
            flops = tf.compat.v1.profiler.profile(
                graph=graph, run_meta=run_meta, cmd='op', options=opts)
    finally:
        # Runs outside every graph context, so resetting is permitted here.
        tf.compat.v1.reset_default_graph()
    return flops.total_float_ops
def get_flops_tfv2_1(model_h5_path, batch_size=1):  # https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-1065141109
    """Count multiply-accumulate ops of a saved Keras model via tf.function.

    Traces ``model.call`` into a concrete function whose single input has the
    fully defined shape ``(batch_size,) + model.input_shape[1:]`` (float32).
    It then profiles that function's graph with the ``float_operation``
    options.

    Args:
        model_h5_path: Path to a model file readable by
            ``tf.keras.models.load_model``.
        batch_size: Leading (batch) dimension used when tracing. Defaults to
            1, which matches the previous hard-coded behavior. A concrete
            batch size presumably lets the profiler compute shapes, and hence
            float ops, for ops that depend on the batch dimension.

    Returns:
        ``total_float_ops // 2``. The profiler counts a multiply-accumulate
        as two float ops, so this is the number of multiply-accumulates
        rather than raw FLOPs.
    """
    model = tf.keras.models.load_model(model_h5_path)
    # Use the public tf.compat.v1 aliases rather than private
    # `tensorflow.python.profiler.*` modules, which can move between releases.
    profile = tf.compat.v1.profiler.profile
    option_builder = tf.compat.v1.profiler.ProfileOptionBuilder
    input_spec = tf.TensorSpec(shape=(batch_size,) + tuple(model.input_shape[1:]))
    forward_pass = tf.function(model.call, input_signature=[input_spec])
    graph_info = profile(graph=forward_pass.get_concrete_function().graph,
                         options=option_builder.float_operation())
    # The //2 is necessary since `profile` counts multiply and accumulate
    # as two flops, here we report the total number of multiply accumulate ops
    flops = graph_info.total_float_ops // 2
    return flops
if __name__ == '__main__':
    # Build a tiny Conv1D model, save it to HDF5, then profile the saved file.
    # To compare the two counting approaches, call get_flops_tfv2_1(h5_path)
    # instead. Do this in a fresh process, because get_flops disables eager
    # execution for the whole process.
    window, features = 6, 1
    h5_path = 'tmp.h5'
    model_gen(window, features).save(h5_path)
    op_count = get_flops(h5_path)
    print(f'#Operations = {op_count}')