I don't think JAX is any less capable of doing this than TensorFlow. Nothing forbids you from doing this in JAX:
```python
new_array = my_array[mask]
```
However, `mask` should be indices (integers) and not booleans. This way, JAX is aware of the shape of `new_array` (the same as `mask`). In that sense, I'm pretty sure that `tf.boolean_mask` is not differentiable, i.e. it will raise an error if you try to compute its gradient at some point.
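To make the distinction concrete, here is a minimal sketch (the array values are made up) showing that integer indexing compiles and differentiates fine under `jit`, while boolean indexing does not:
```python
import jax
import jax.numpy as jnp

my_array = jnp.array([3.0, 5.0, 7.0, 9.0])
indices = jnp.array([0, 2])  # fixed-size integer indices

@jax.jit
def gather(x, idx):
    # Integer indexing: the output shape equals idx's shape,
    # which is known at trace time.
    return x[idx]

print(gather(my_array, indices))  # [3. 7.]

# The gather is differentiable: gradients flow to the selected entries.
print(jax.grad(lambda x: gather(x, indices).sum())(my_array))  # [1. 0. 1. 0.]

# Boolean indexing under jit fails, because the output shape would
# depend on the mask's values:
#   gather(my_array, my_array > 4.0)  # raises NonConcreteBooleanIndexError
```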
More generally, if you need to mask an array, whatever library you are using, there are two approaches:
- if you know in advance which indices need to be selected, you provide these indices so that the library can compute the shape before compilation;
- if you can't define these indices, for whatever reason, then you need to design your code to prevent the padding from affecting your result.
Examples for each situation
Let's say you're writing a simple embedding layer in JAX. The input is a batch of token indices corresponding to several sentences. To get the word embeddings corresponding to these indices, I will simply write `word_embeddings = embeddings[input]`. Since I don't know the length of the sentences in advance, I need to pad all token sequences to the same length beforehand, so that `input` is of shape `(number_of_sentences, sentence_max_length)`. Now, JAX will recompile the masking operation every time this shape changes. To minimize the number of compilations, you can keep the number of sentences (also called the batch size) constant and set `sentence_max_length` to the maximum sentence length in the entire corpus. This way, there will be only one compilation during training. Of course, you need to reserve one row in `embeddings` that corresponds to the pad index. But still, the masking works.
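As a rough sketch of this setup (the vocabulary size and embedding dimension are invented):
```python
import jax
import jax.numpy as jnp

vocab_size, embed_dim = 100, 8  # hypothetical sizes
key = jax.random.PRNGKey(0)
embeddings = jax.random.normal(key, (vocab_size, embed_dim))  # row 0 reserved for the pad token

# Two sentences padded with index 0 to sentence_max_length = 4.
input = jnp.array([[5, 12, 7, 0],
                   [3, 9, 0, 0]])  # shape (number_of_sentences, sentence_max_length)

word_embeddings = embeddings[input]  # shape (2, 4, embed_dim), static
```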
Later in the model, let's say you want to express each word of each sentence as a weighted average of all the other words in the sentence (like a self-attention mechanism). The weights are computed in parallel for the entire batch and are stored in the matrix `A` of shape `(number_of_sentences, sentence_max_length, sentence_max_length)`. The weighted averages are computed with the formula `A @ word_embeddings`. Now, you need to make sure the pad tokens don't affect this formula. To do so, you can zero out the entries of `A` corresponding to the pad indices to remove their influence from the averaging. If the pad token index is 0, you would do:
```python
import jax.numpy as jnp

mask = jnp.array(input > 0, dtype=jnp.float32)  # 1.0 for real tokens, 0.0 for pads
A = A * mask[:, jnp.newaxis, :]                 # zero out attention weights on pad columns
weighted_mean = A @ word_embeddings
```
So here we used a boolean mask, but the operation remains differentiable since we multiply the mask with another matrix instead of using it as an index. Note that we should proceed the same way to zero out the rows of `weighted_mean` that also correspond to pad tokens.
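For completeness, that last step could look like this, reusing the same `mask`:
```python
# Again mask by multiplication rather than indexing, so the
# shape stays static and the operation stays differentiable.
weighted_mean = weighted_mean * mask[:, :, jnp.newaxis]
```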