
I am working on an active machine learning pipeline with the TensorFlow Object Detection API. My goal is to change the paths in the network's .config file dynamically.

The standard configuration looks like this:

    train_input_reader: {
      tf_record_input_reader {
        input_path: "/PATH_TO_CONFIGURE/train.record"
      }
      label_map_path: "/PATH_TO_CONFIGURE/label_map.pbtxt"
    }

"PATH_TO_CONFIGURE" should be replaced dynamically from within a cell of my Jupyter notebook.
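
For comparison with the protobuf-based answers, the most direct reading of the question is plain text substitution on a template file. A minimal sketch, assuming the template keeps the literal `/PATH_TO_CONFIGURE` prefix shown above; `render_config` and `write_rendered_config` are hypothetical helper names. This does not parse or validate the config, so a typo in the template passes through silently.

```python
from pathlib import Path

PLACEHOLDER = "/PATH_TO_CONFIGURE"  # assumed literal prefix used in the template


def render_config(template_text: str, base_dir: str) -> str:
    """Replace every occurrence of the placeholder prefix with base_dir."""
    # Drop a trailing slash so "/data/run_01/" and "/data/run_01" behave the same.
    return template_text.replace(PLACEHOLDER, base_dir.rstrip("/"))


def write_rendered_config(template_path: str, output_path: str, base_dir: str) -> None:
    """Read a template config, substitute the paths, and write the result."""
    text = Path(template_path).read_text()
    Path(output_path).write_text(render_config(text, base_dir))


template = '''train_input_reader: {
  tf_record_input_reader {
    input_path: "/PATH_TO_CONFIGURE/train.record"
  }
  label_map_path: "/PATH_TO_CONFIGURE/label_map.pbtxt"
}'''

print(render_config(template, "/data/run_01"))
```

In a notebook, `base_dir` can be computed in an earlier cell, and `write_rendered_config` called before training starts.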

desertnaut
Sh0rtey

2 Answers


The Object Detection API configuration files are in protobuf text format. Here's roughly how you can read, edit, and save them.

    import tensorflow as tf
    from google.protobuf import text_format
    from object_detection.protos import pipeline_pb2

    # Parse the text-format config into a TrainEvalPipelineConfig message
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile('config path', "r") as f:
        proto_str = f.read()
        text_format.Merge(proto_str, pipeline)

    # input_path is a repeated field: replace its contents with a one-element list
    pipeline.train_input_reader.tf_record_input_reader.input_path[:] = ['your new entry']
    pipeline.train_input_reader.label_map_path = 'your new entry'

    # Serialize back to text format and overwrite the config file
    config_text = text_format.MessageToString(pipeline)
    with tf.gfile.Open('config path', "wb") as f:
        f.write(config_text)

You will have to adjust the code, but the general idea should be clear. I'd recommend refactoring it into a function and calling that from Jupyter.
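Following the suggestion to wrap this in a function, here is a minimal sketch of the editing step as a reusable helper; reading and writing the file stay as in the code above. `set_train_paths` is a hypothetical name. The helper only uses attribute access and slice assignment, so the usage lines pass a `SimpleNamespace` stand-in instead of a real `TrainEvalPipelineConfig` (an assumption made purely so the sketch runs without TensorFlow).

```python
from types import SimpleNamespace


def set_train_paths(pipeline, record_path, label_map_path):
    """Point the training input reader at new files and return the pipeline.

    `pipeline` is meant to be a parsed TrainEvalPipelineConfig, but only
    attribute access and slice assignment are used here.
    """
    reader = pipeline.train_input_reader
    # Repeated field: assign a one-element list, not a bare string.
    reader.tf_record_input_reader.input_path[:] = [record_path]
    reader.label_map_path = label_map_path
    return pipeline


# Hypothetical stand-in with the same attribute layout as the protobuf message.
fake = SimpleNamespace(
    train_input_reader=SimpleNamespace(
        tf_record_input_reader=SimpleNamespace(input_path=["old.record"]),
        label_map_path="old.pbtxt",
    )
)

base = "/data/run_01"
set_train_paths(fake, f"{base}/train.record", f"{base}/label_map.pbtxt")
print(fake.train_input_reader.tf_record_input_reader.input_path)
```

With the real message, the call would sit between parsing the file and `text_format.MessageToString`, with `base` chosen per pipeline run.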

Lukasz Tracewski
  • It works for the label_map_path variable. When I try to edit input_path, though, it throws an error: "AttributeError: Assignment not allowed to repeated field "input_path" in protocol message object." – Sh0rtey Nov 20 '19 at 17:47
  • Can you try `pipeline.train_input_reader.tf_record_input_reader.input_path[:] = ..` (note the colon)? You might have to delete it first if it's a composite, see https://stackoverflow.com/questions/23726335/how-to-assign-to-repeated-field – Lukasz Tracewski Nov 20 '19 at 21:21
  • That was the missing puzzle piece! Careful, though: the string needs to be wrapped in a list (square brackets): pipeline.train_input_reader.tf_record_input_reader.input_path[:] = ['PATH']. Otherwise it gets split into its characters, each written as a separate input_path line. – Sh0rtey Nov 20 '19 at 21:51
  • Thanks @Sh0rtey , I will edit the answer to reflect that. – Lukasz Tracewski Nov 20 '19 at 21:52
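
The splitting described in the last comments follows from ordinary Python slice-assignment semantics: the right-hand side is consumed as an iterable, and a string iterates character by character. A minimal illustration, using a plain list as a stand-in for the repeated `input_path` field (an assumption: the protobuf container treats the right-hand side the same way, which matches the behavior reported in the comment).

```python
# A plain list stands in for the repeated `input_path` field here.
path = "/data/train.record"

entries = ["old.record"]
entries[:] = path            # a str is iterable, so each character becomes an entry
split_result = list(entries)

entries[:] = [path]          # wrapping the path in a list gives exactly one entry

print(len(split_result), split_result[:3])
print(entries)
```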

Here is what worked for me with TensorFlow 2 (the file API changed slightly, from `tf.gfile.GFile` to `tf.io.gfile.GFile`):

    import tensorflow as tf
    from google.protobuf import text_format
    from object_detection.protos import pipeline_pb2

    def read_config():
        # Parse pipeline.config (protobuf text format) into a message object
        pipeline = pipeline_pb2.TrainEvalPipelineConfig()
        with tf.io.gfile.GFile('pipeline.config', "r") as f:
            proto_str = f.read()
            text_format.Merge(proto_str, pipeline)
        return pipeline

    def write_config(pipeline):
        # Serialize the message back to text format and overwrite the file
        config_text = text_format.MessageToString(pipeline)
        with tf.io.gfile.GFile('pipeline.config', "wb") as f:
            f.write(config_text)

    def modify_config(pipeline):
        # Assumes an SSD model; other architectures keep num_classes under a different field
        pipeline.model.ssd.num_classes = 1
        pipeline.train_config.fine_tune_checkpoint_type = 'detection'

        # input_path is a repeated field; [0] overwrites its existing first entry
        pipeline.train_input_reader.label_map_path = 'label_map.pbtxt'
        pipeline.train_input_reader.tf_record_input_reader.input_path[0] = 'train.record'

        # eval_input_reader is repeated as well, so index its first reader
        pipeline.eval_input_reader[0].label_map_path = 'label_map.pbtxt'
        pipeline.eval_input_reader[0].tf_record_input_reader.input_path[0] = 'test.record'

        return pipeline

    def setup_pipeline():
        pipeline = read_config()
        pipeline = modify_config(pipeline)
        write_config(pipeline)

    setup_pipeline()
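
The eval part of `modify_config` above assumes that `pipeline.config` already contains at least one eval reader and that each `input_path` already has a first entry; indexing an empty repeated field with `[0]` fails. A minimal sketch of a more defensive variant, using the slice assignment from the first answer's comments and updating every eval reader rather than only the first. `set_eval_paths` and the `SimpleNamespace` stand-ins are hypothetical, used only so the sketch runs without TensorFlow.

```python
from types import SimpleNamespace


def set_eval_paths(pipeline, eval_record, label_map_path):
    """Update every eval input reader, not just eval_input_reader[0]."""
    for reader in pipeline.eval_input_reader:
        reader.label_map_path = label_map_path
        # Slice assignment also works when input_path is still empty,
        # where input_path[0] = ... would raise IndexError.
        reader.tf_record_input_reader.input_path[:] = [eval_record]
    return pipeline


def make_reader(paths):
    # Hypothetical stand-in for one eval input reader message.
    return SimpleNamespace(
        label_map_path="",
        tf_record_input_reader=SimpleNamespace(input_path=list(paths)),
    )


fake = SimpleNamespace(eval_input_reader=[make_reader(["old.record"]), make_reader([])])
set_eval_paths(fake, "test.record", "label_map.pbtxt")
for r in fake.eval_input_reader:
    print(r.label_map_path, r.tf_record_input_reader.input_path)
```

If the config has no eval reader at all, the loop simply does nothing; adding one would be a separate step.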
Oleksii Trekhleb