
Below is the code I am using. I have commented out the line that converts my model to a TPU model. On a GPU, one epoch over the same data takes 7 seconds, while on the TPU it takes 90 seconds.

    import tensorflow as tf  # TF 1.x: uses tf.train and tf.contrib

    # input_dim, stop_criteria, X_tra, y_tra, batch_size and TPU_ADDRESS are defined elsewhere
    Inp = tf.keras.Input(name='input', shape=(input_dim,), dtype=tf.float32)
    x = tf.keras.layers.Dense(900, kernel_initializer='uniform', activation='relu', name='Dense_01')(Inp)
    x = tf.keras.layers.Dropout(0.3, name='Dropout_02')(x)
    output = tf.keras.layers.Dense(stop_criteria, activation='softmax', name='Dense_02')(x)

    model = tf.keras.Model(inputs=[Inp], outputs=[output])
    opt = tf.train.AdamOptimizer(.001)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['acc'])

    # TPU conversion, currently commented out:
    # tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    #     model,
    #     strategy=tf.contrib.tpu.TPUDistributionStrategy(
    #         tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)))
    model.fit(X_tra, y_tra, epochs=5, batch_size=batch_size, shuffle=False,
              validation_split=0.1, verbose=2)

Here is the link to the notebook




Have you tried the tpu_model.fit_generator method, as in the example below? The rest of your code looks fine. Another possible culprit is the Adam optimizer: I remember reading about an issue with it, but I can't find the link now. Try a different optimizer together with the code below; if training is fast with the other optimizer, you know the problem has something to do with Adam.

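    import tensorflow as tf  # TF 1.x: uses tf.contrib.tpu

    # lstm_model, training_generator and TPU_WORKER are defined elsewhere in the example (not shown here)
    tf.keras.backend.clear_session()

    training_model = lstm_model(seq_len=100, batch_size=128, stateful=False)

    # Convert the Keras model into a TPU model
    tpu_model = tf.contrib.tpu.keras_to_tpu_model(
        training_model,
        strategy=tf.contrib.tpu.TPUDistributionStrategy(
            tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))

    # Train from a Python generator instead of passing in-memory arrays to fit()
    tpu_model.fit_generator(
        training_generator(seq_len=100, batch_size=1024),
        steps_per_epoch=100,
        epochs=10,
    )
    tpu_model.save_weights('/tmp/bard.h5', overwrite=True)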
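To use fit_generator with your own X_tra / y_tra arrays, you need a generator that yields (x, y) batches indefinitely. Below is a minimal sketch of one, assuming the data fits in memory as NumPy arrays. batch_generator is a hypothetical helper, not part of Keras. It yields only full batches so that every batch has the same shape; that choice is an assumption on my part, based on TPU compilation generally favoring fixed shapes.

```python
import numpy as np


def batch_generator(X, y, batch_size, shuffle=False, seed=None):
    """Hypothetical helper: yield (x_batch, y_batch) pairs forever for fit_generator.

    Only full batches are produced, so every batch has the same shape.
    The trailing len(X) % batch_size samples are skipped on each pass.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    n_full = len(X) // batch_size
    if n_full == 0:
        raise ValueError("batch_size is larger than the dataset")
    rng = np.random.default_rng(seed)
    while True:
        # One pass over the data; reshuffle the indices at the start of each pass if requested
        idx = rng.permutation(len(X)) if shuffle else np.arange(len(X))
        for b in range(n_full):
            sel = idx[b * batch_size:(b + 1) * batch_size]
            yield X[sel], y[sel]
```

You would then pass it in place of training_generator, for example tpu_model.fit_generator(batch_generator(X_tra, y_tra, batch_size), steps_per_epoch=len(X_tra) // batch_size, epochs=5). Note that fit_generator has no validation_split argument, so the 10% validation data would have to be split off by hand and passed as validation_data.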