
I have a network with a large number of items that need to be embedded.

However, in each training batch only a very small portion of those items is actually used.

If I use the normal tf.keras.layers.Embedding layer, the entire embedding table becomes a dense network parameter. This consumes a lot of memory and slows down distributed training significantly, since at every step the gradients of all the unused rows are still synchronized.

What I want is that, in each training step, only the embedding weights that are actually used are added to the graph, computed, and synchronized.

PyTorch already provides this functionality with torch.nn.Embedding(sparse=True).
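A minimal sketch of the behaviour I mean (the table size and ids here are made up):

import torch

emb = torch.nn.Embedding(1_000_000, 64, sparse=True)
ids = torch.tensor([3, 42, 7])
loss = emb(ids).sum()
loss.backward()
# With sparse=True the gradient is a sparse tensor holding only
# rows 3, 42 and 7, not the full 1,000,000 x 64 table.
print(emb.weight.grad.is_sparse)  # True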

How can I implement this in TensorFlow 2?

  • Given https://stackoverflow.com/questions/47868341/why-are-embeddings-in-pytorch-implemented-as-sparse-layers, the key point seems to be that PyTorch returns a sparse gradient for a sparse embedding. How can this be implemented in TensorFlow? – Visitor Lord Jul 24 '20 at 08:36

2 Answers


My bad... checking with tf.GradientTape() tells me that the gradient of tf.gather is already a sparse tensor (tf.IndexedSlices), so this needs no extra work.
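A minimal check of what tf.GradientTape() reports (the table size and ids here are made up):

import tensorflow as tf

table = tf.Variable(tf.random.normal([1_000_000, 64]))
ids = tf.constant([3, 42, 7])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.gather(table, ids))
grad = tape.gradient(loss, table)
# The gradient is a tf.IndexedSlices that carries only the gathered
# rows, so only those rows need to be computed and synchronized.
print(isinstance(grad, tf.IndexedSlices))  # True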


I was facing the same problem. I solved it following https://github.com/tensorflow/tensorflow/issues/33880:

import tensorflow as tf
from tensorflow.keras import backend
# safe_embedding_lookup_sparse_v2 is also exposed publicly as
# tf.nn.safe_embedding_lookup_sparse.
from tensorflow.python.ops import embedding_ops, math_ops


class SparseEmbedding(tf.keras.layers.Embedding):
    def __init__(self, *args, combiner=None, mask_zero=None, **kwargs):
        assert mask_zero is None, "Cannot use sparse embedding with mask zero!"
        self._combiner = combiner
        super().__init__(*args, **kwargs)

    def compute_mask(self, inputs, mask=None):
        # Masking is not supported for sparse id lookups.
        return None

    def call(self, inputs):
        # `inputs` is expected to be a tf.SparseTensor of integer ids.
        dtype = backend.dtype(inputs)
        if dtype != 'int32' and dtype != 'int64':
            raise TypeError(
                "Failed to process inputs! Expected int32 or int64 ids, "
                f"received: {dtype}")
        # Only the rows that appear in `inputs` are gathered (and combined
        # per example), so the gradient w.r.t. the table is sparse.
        out = embedding_ops.safe_embedding_lookup_sparse_v2(
            embedding_weights=self.embeddings,
            sparse_ids=inputs,
            combiner=self._combiner)
        if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
            # Instead of casting the variable as in most layers, cast the
            # output, as this is mathematically equivalent but faster.
            out = math_ops.cast(out, self._dtype_policy.compute_dtype)
        return out

    def get_config(self):
        base_config = super().get_config()
        config = {"combiner": self._combiner}
        return dict(list(base_config.items()) + list(config.items()))
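
A usage sketch, assuming the ids arrive as a tf.SparseTensor (the sizes and ids here are made up):

layer = SparseEmbedding(input_dim=1_000_000, output_dim=64, combiner="mean")

# Zeros in the dense form are treated as missing entries by
# tf.sparse.from_dense, so each row can hold a different number of ids.
ids = tf.sparse.from_dense(
    tf.constant([[3, 42, 7], [15, 0, 0]], dtype=tf.int64))
out = layer(ids)  # shape (2, 64): one combined embedding per example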