
I am trying to save the configuration of a Keras model to a file so that I can read it back later and reproduce the training.

Before I implemented a custom metric as a function, I could save the configuration as shown below (just without mean_pred). Now I get the error TypeError: Object of type 'function' is not JSON serializable.

I read here that I can get the function's name as a string with custom_metric_name = mean_pred.__name__. I would prefer to save not just the name but, if possible, a reference to the function itself.
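If a reference to the function is wanted rather than a bare name, one option is to store the function's importable location (its module plus its qualified name) and resolve it with importlib when the config is read back. This is a sketch under the assumption that the custom metric is defined at the top level of an importable module (for example a hypothetical my_metrics.py); function_to_path, path_to_function and resolve are illustrative helpers, not part of Keras, and math.sqrt stands in for mean_pred so the example runs without Keras.

```python
import importlib
import json
import math


def function_to_path(fn):
    """Encode a top-level function as an importable 'module:qualname' string."""
    return f"{fn.__module__}:{fn.__qualname__}"


def path_to_function(path):
    """Import the module named in 'module:qualname' and look the function up in it."""
    module_name, _, qualname = path.partition(":")
    obj = importlib.import_module(module_name)
    for attr in qualname.split("."):
        obj = getattr(obj, attr)
    return obj


def resolve(entry):
    """Turn 'module:qualname' strings back into functions; leave names like 'accuracy' as-is."""
    return path_to_function(entry) if ":" in entry else entry


# math.sqrt stands in for mean_pred (which would live in e.g. my_metrics.py)
config = {"metrics": {"class": ["accuracy", function_to_path(math.sqrt)]}}
text = json.dumps(config)  # {"metrics": {"class": ["accuracy", "math:sqrt"]}}

# Reading the config back: the stored path is imported and resolved again
loaded = json.loads(text)
metrics = {key: [resolve(e) for e in entries]
           for key, entries in loaded["metrics"].items()}
```

The design choice here is that the file stays plain JSON while still pointing at one specific function. It only works for functions that can be imported by name: a metric defined in the training script itself would be stored as '__main__:mean_pred', which resolves only when that same script is running, and lambdas or nested functions cannot be looked up this way at all.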

Perhaps, as mentioned here, I should also consider using ConfigObj instead of keeping my configuration in the .py file. Unless that would solve my current problem, I would implement it later.

Minimal working example of the problem:

import keras.backend as K
import json

def mean_pred(y_true, y_pred):
    return K.mean(y_pred)

config = {'epochs':500,
          'loss':{'class':'categorical_crossentropy'},
          'optimizer':'Adam',
          'metrics':{'class':['accuracy', mean_pred]}
          }

# Do the training etc...

config_filename = 'config.txt'
with open(config_filename, 'w') as f:
    f.write(json.dumps(config))
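If only the name needs to be stored, json.dumps can do the conversion itself: its default argument is called for every object the encoder cannot serialize, so a callable that returns the function's __name__ turns mean_pred into 'mean_pred' without changing the config dict. A minimal sketch, assuming the name alone is enough to find the function again later; function_name is a hypothetical helper, and a plain Python mean_pred stands in for the Keras version so the snippet runs without Keras.

```python
import json


def mean_pred(y_true, y_pred):
    # Stand-in for the Keras metric; the body does not matter for serialization
    return sum(y_pred) / len(y_pred)


def function_name(obj):
    """json 'default' hook: encode named callables (e.g. functions) by __name__."""
    if callable(obj) and hasattr(obj, "__name__"):
        return obj.__name__
    # Anything else is still rejected, as json would do without the hook
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


config = {'epochs': 500,
          'loss': {'class': 'categorical_crossentropy'},
          'optimizer': 'Adam',
          'metrics': {'class': ['accuracy', mean_pred]}
          }

text = json.dumps(config, default=function_name)
```

json.dump(config, f, default=function_name) works the same way when writing straight to the file. Note that this is one-way: reading the file back yields the string 'mean_pred', so a name-to-function lookup is still needed before training.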

I would greatly appreciate help with this problem, as well as suggestions for other good ways to save my configuration.

a-doering

1 Answer


To solve the problem, I stored the function's name as a string in the config and then looked the function up in a dictionary to pass it to the model as a metric. Alternatively, one could write 'class': ['accuracy', mean_pred.__name__] to put the name into the config as a string without typing it by hand. This also works for several custom functions and for additional keys under metrics (e.g. defining metrics for 'reg' in the same way as for 'class' when doing both regression and classification).

import json
from collections import defaultdict

import keras.backend as K


def mean_pred(y_true, y_pred):
    # Custom metric: mean of the predicted values
    return K.mean(y_pred)


# Only JSON-serializable values go into the config; the custom metric is
# stored by its name ('mean_pred') instead of as a function object
config = {'epochs': 500,
          'loss': {'class': 'categorical_crossentropy'},
          'optimizer': 'Adam',
          'metrics': {'class': ['accuracy', 'mean_pred']}
          }

# Map each custom metric name to the function it stands for
custom_metrics = {'mean_pred': mean_pred}

# Build the metrics for the model: custom names are replaced by their
# functions, built-in names such as 'accuracy' are kept as strings
metrics = defaultdict(list)
for metric_type, metric_names in config['metrics'].items():
    for name in metric_names:
        metrics[metric_type].append(custom_metrics.get(name, name))

# Do the training, use metrics

config_filename = 'config.txt'
with open(config_filename, 'w') as f:
    json.dump(config, f)
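To actually reproduce a run, the saved file also has to be read back and the names turned into functions again. A minimal sketch of that round trip, reusing the dictionary lookup from the code above; load_metrics and CUSTOM_METRICS are hypothetical names, a plain Python mean_pred stands in for the Keras version so the snippet runs without Keras, and the file is written to a temporary directory only to keep the example self-contained.

```python
import json
import os
import tempfile


def mean_pred(y_true, y_pred):
    # Stand-in for the Keras metric defined in the answer
    return sum(y_pred) / len(y_pred)


# Same idea as custom_metrics above: map saved names to functions
CUSTOM_METRICS = {'mean_pred': mean_pred}


def load_metrics(config, custom_metrics=CUSTOM_METRICS):
    """Replace custom metric names with their functions; keep other names as strings."""
    return {metric_type: [custom_metrics.get(name, name) for name in names]
            for metric_type, names in config['metrics'].items()}


config = {'epochs': 500,
          'loss': {'class': 'categorical_crossentropy'},
          'optimizer': 'Adam',
          'metrics': {'class': ['accuracy', 'mean_pred']}}

# Save the config, then read it back as a later run would
config_filename = os.path.join(tempfile.mkdtemp(), 'config.txt')
with open(config_filename, 'w') as f:
    json.dump(config, f, indent=2)  # indent keeps the file readable by hand

with open(config_filename) as f:
    loaded = json.load(f)

metrics = load_metrics(loaded)  # {'class': ['accuracy', <function mean_pred>]}
```

As in the loop above, unknown names pass through unchanged so built-in metrics like 'accuracy' still reach Keras as strings. The flip side is that a custom metric missing from the dictionary is also passed on as a plain string and would only fail later, inside Keras, so keeping the dictionary in one shared module helps.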