I have a network with a large number of items that need to be embedded. However, in each training batch only a very small fraction of those items is actually used.
If I use the standard tf.keras.layers.Embedding layer, the entire embedding table becomes a trainable parameter of the network. This consumes a lot of memory and slows down distributed training significantly, because the gradients for all the unused rows are still synchronized at every step.
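Roughly, the situation looks like the sketch below; the vocabulary size, embedding dimension, ids, and loss are made-up placeholders just to illustrate the pattern:

```python
import tensorflow as tf

NUM_ITEMS = 1_000_000   # placeholder; the real table is much larger
EMB_DIM = 64            # placeholder embedding dimension

# One big table that owns a row for every item.
embedding = tf.keras.layers.Embedding(NUM_ITEMS, EMB_DIM)

# In a given batch only a handful of item ids are actually looked up.
item_ids = tf.constant([[3, 17, 42, 99]])

with tf.GradientTape() as tape:
    vectors = embedding(item_ids)        # shape (1, 4, EMB_DIM)
    loss = tf.reduce_sum(vectors ** 2)   # stand-in for the real loss

# The gradient is computed with respect to the whole (NUM_ITEMS, EMB_DIM)
# variable, even though only four of its rows were used in this step.
grads = tape.gradient(loss, embedding.trainable_variables)
```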
What I want is that, in each training step, only the embedding weights that are actually used are added to the graph, computed, and synchronized.
PyTorch already provides this functionality with torch.nn.Embedding(sparse=True).
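For comparison, the PyTorch version (again with placeholder sizes) produces a gradient that only covers the rows that were looked up:

```python
import torch

# sparse=True makes the gradient of the weight a sparse tensor that only
# contains entries for the rows used in this forward pass.
embedding = torch.nn.Embedding(1_000_000, 64, sparse=True)
item_ids = torch.tensor([[3, 17, 42, 99]])

loss = embedding(item_ids).pow(2).sum()
loss.backward()

print(embedding.weight.grad.is_sparse)  # True
```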
How can I implement this in TensorFlow 2?