I defined a custom loss function for Deep Metric Learning in Keras, and training is abysmally slow on Colab with a GPU accelerator, notwithstanding the simplicity of the problem. What can I do to speed the training up?
Sample code: the code is written to be run in Colab/Jupyter, but I guess it could also be run as a script (I didn't try, but I don't see any obvious reason why it wouldn't work).
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train / 255.0).astype('float32')
X_test = (X_test / 255.0).astype('float32')
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)
tf.keras.backend.clear_session()
embedding_size = 3
model = keras.Sequential()
model.add(Flatten(input_shape = input_shape))
model.add(Dense(128, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(embedding_size))
model.summary()
margin = 0.2
def contrastive_loss(y_true, embeddings):
    loss = 0.0
    b = embeddings.shape[0]
    for i in range(0, b):
        yi = y_true[i]
        xi = embeddings[i]
        for j in range(i + 1, b):
            yj = y_true[j]
            xj = embeddings[j]
            # yij = 1 if the labels differ, 0 if they are equal
            yij = tf.minimum(1.0, tf.abs(tf.cast(yi - yj, dtype=tf.float32)))
            distance = tf.norm(xi - xj)
            # pull same-label pairs together, push different-label pairs
            # apart up to the margin
            loss = loss + (1 - yij) * distance**2 + yij * tf.maximum(0.0, margin - distance)**2
    loss = 0.5 * loss
    return loss
model.compile(keras.optimizers.Adam(learning_rate=1e-4), loss=contrastive_loss)
history = model.fit(X_train, y_train, batch_size=32, epochs=10, shuffle=True)
print(X_test.shape)
emb = model.predict(X_test)
print(emb.shape, "emb")
print(y_test.shape, "y_test")
plt.scatter(emb[:,0], emb[:,1], c=y_test, cmap='tab10')
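For reference, the loss function above is meant to compute the standard pairwise contrastive loss over all pairs in the batch,

loss = 0.5 * sum_{i<j} [ (1 - yij) * d_ij^2 + yij * max(0, margin - d_ij)^2 ]

where yij = 1 if samples i and j have different labels (0 if they are the same) and d_ij is the Euclidean distance between their embeddings.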
The training is very slow on Colab with a GPU accelerator. Any hints on how I could speed it up? Right now I get > 200 ms/step, i.e., ≈ 380 s per epoch. (Note that with batch_size=32 the loss runs a Python double loop over 32·31/2 = 496 pairs on every single step.)
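As a rough check, this snippet (random stand-in data, using the contrastive_loss and margin defined above) times a single eager call of the loss outside of model.fit:

import time
import numpy as np
import tensorflow as tf

# random stand-ins for one batch of labels and embeddings
y = tf.constant(np.random.randint(0, 10, size=(32, 1)))
emb = tf.random.normal((32, 3))

start = time.perf_counter()
contrastive_loss(y, emb)
print(f"one eager call: {time.perf_counter() - start:.3f} s")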
EDIT: initially I posted a stripped-down version of the actual code but, according to some comments, the way to speed it up may depend on what I actually compute inside the loss function, so I have now put in the actual code. Sorry if that makes the code go slower...
EDIT2: I tried to implement the solution in Jeff's answer but I'm getting an error whose cause I don't understand. Here's the new code:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = (X_train / 255.0).astype('float32')
X_test = (X_test / 255.0).astype('float32')
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
input_shape = (28, 28, 1)
tf.keras.backend.clear_session()
embedding_size = 3
model = keras.Sequential()
model.add(Flatten(input_shape = input_shape))
model.add(Dense(128, activation='relu'))
model.add(Dense(128, activation='relu'))
model.add(Dense(embedding_size))
model.summary()
margin = 0.2
def contrastive_loss(y_true, embeddings):
    # it's easier to work with a flattened array
    flat_y_true = tf.reshape(y_true, (-1))
    # matrix of absolute differences, clipped to 1 (if yi != yj then yij = 1)
    yijs = abs(np.subtract.outer(flat_y_true, flat_y_true)).clip(max=1)
    # we only need the upper triangular part of the matrix
    yijs = yijs[np.triu_indices_from(yijs)]
    # first compute row differences of the embeddings matrix, then compute
    # norms for each pair with axis=2
    distances = np.linalg.norm(embeddings[:, None] - embeddings[None], axis=2)
    # we only need the upper triangular part, again
    distances = distances[np.triu_indices_from(distances)]
    loss = ((1 - yijs) * (distances**2) + (yijs * ((margin - distances).clip(min=0)**2))).sum() * 0.5
    return loss
model.compile(keras.optimizers.Adam(learning_rate=1e-4), loss=contrastive_loss)
history = model.fit(X_train, y_train, batch_size=32, epochs=10, shuffle=True)
I get the following error:
Epoch 1/10
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-5-1ae713ccf38c> in <module>()
4
5 # When setting batch size, remember we are *quadratically* expanding it in our loss.
----> 6 history = model.fit(X_train, y_train, batch_size=32, epochs=10, shuffle=True)
10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
971 except Exception as e: # pylint:disable=broad-except
972 if hasattr(e, "ag_error_metadata"):
--> 973 raise e.ag_error_metadata.to_exception(e)
974 else:
975 raise
ValueError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:806 train_function *
return step_function(self, iterator)
<ipython-input-4-6c3b8535eb43>:5 contrastive_loss *
flat_y_true = tf.reshape(y_true, (-1))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper **
return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py:195 reshape
result = gen_array_ops.reshape(tensor, shape, name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py:8234 reshape
"Reshape", tensor=tensor, shape=shape, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:744 _apply_op_helper
attrs=attr_protos, op_def=op_def)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:593 _create_op_internal
compute_device)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:3485 _create_op_internal
op_def=op_def)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1975 __init__
control_input_ops, op_def)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1815 _create_c_op
raise ValueError(str(e))
ValueError: Shape must be rank 1 but is rank 0 for '{{node contrastive_loss/Reshape}} = Reshape[T=DT_UINT8, Tshape=DT_INT32](ExpandDims, contrastive_loss/Reshape/shape)' with input shapes: [32,1], [].
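The traceback points at the tf.reshape call on the first line of the loss. This minimal snippet (my own reduction, run outside of model.fit) seems to reproduce the same rank error:

import tensorflow as tf

# same shape/dtype as y_true in the traceback ([32, 1], uint8)
t = tf.zeros((32, 1), dtype=tf.uint8)
flat = tf.reshape(t, (-1))  # fails for me with a similar "rank 1 but is rank 0" complaint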