How would one access a training operation from a tf.keras.models.Model? Consider the following:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Flatten
from tensorflow.keras.models import Model
import numpy as np
# Make some dummy data
dummy_data_shape=(5,5)
def batch_generator(size):
    """ Makes some random data """
    def _gen():
        y_batch=np.random.randint(0,2, size=size)
        y_batch=np.expand_dims(y_batch,-1)
        y_expanded=np.expand_dims(y_batch,-1)
        x_batch=np.ones((size,*dummy_data_shape))*y_expanded
        yield x_batch,y_batch
    return _gen()

# Make some simple model
Y=tf.placeholder(tf.float32,[None,1])
X = Input(shape=dummy_data_shape)
layer_mod = Flatten()(X)
layer_mod = Dense(1)(layer_mod)

# Tie it all together and compile
out_model = Model(inputs=[X], outputs=[layer_mod])
out_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.binary_crossentropy
)

### How can I access a train_op from the out_model?
with tf.Session() as sess:
    data_iter=batch_generator(10)
    sess.run(tf.global_variables_initializer())
    x,y=next(data_iter)
    ## Here: How to access the operation that trains the model?
    train_op=out_model.train_op #<-- ?
    sess.run(train_op, feed_dict={X:x,Y:y})

What should the second-to-last line in the code above be so that the model trains?

  • Maybe you can refer to https://stackoverflow.com/questions/42685994/how-to-get-a-tensorflow-op-by-name – giser_yugang Apr 15 '19 at 13:01
  • @giser_yugang Yes, if I could find a single operation that corresponds to the training operation in the model. The closest I can find are [these lines in the implementation](https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/engine/training.py#L1894), which only seem to execute on `train_on_batch`, and that doesn't really do it for me :/ – Andreas Storvik Strauman Apr 16 '19 at 08:27
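
For anyone landing here: one possible workaround is a minimal sketch that does not retrieve the op Keras builds internally, but sidesteps `compile` and builds a separate train op directly on `model.output`. Everything below is an assumption rather than a confirmed answer: it uses the `tf.compat.v1` graph-mode API with a Keras 2-style `tf.keras` (newer Keras 3 builds may not accept symbolic model outputs in plain TensorFlow ops), it treats the `Dense(1)` output as logits, and the names `tf1`, `logits`, `weights_before` and `weights_after` are purely illustrative.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # assumption: graph mode, as in the question

# Same architecture as in the question.
X = tf.keras.layers.Input(shape=(5, 5))
flat = tf.keras.layers.Flatten()(X)
logits = tf.keras.layers.Dense(1)(flat)  # assumption: raw output treated as logits
model = tf.keras.models.Model(inputs=[X], outputs=[logits])

# Hand-built loss and train op (not the op Keras creates in compile/fit).
Y = tf1.placeholder(tf.float32, [None, 1])
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=Y, logits=model.output)
)
train_op = tf1.train.AdamOptimizer().minimize(
    loss, var_list=model.trainable_weights
)

# Deterministic dummy batch: labels alternate 0/1, inputs are ones * label.
y = np.array([[0.0], [1.0]] * 5, dtype=np.float32)          # shape (10, 1)
x = np.ones((10, 5, 5), dtype=np.float32) * y[:, :, None]   # shape (10, 5, 5)

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())
    weights_before = sess.run(model.trainable_weights[0])    # Dense kernel
    _, loss_value = sess.run(
        [train_op, loss], feed_dict={model.input: x, Y: y}
    )
    weights_after = sess.run(model.trainable_weights[0])
```

The design choice here is that the optimizer and loss live outside the Keras model, so the settings passed to `out_model.compile` in the question are not reused; the model only supplies the forward graph and its trainable weights.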

0 Answers