
The documentation for JAX says,

Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

Now I am somewhat surprised, because TensorFlow has operations like tf.boolean_mask that do what JAX seems incapable of doing when compiled.
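
For concreteness, here is a minimal example of the kind of failure I mean (the exact error class may differ between JAX versions):

import jax
import jax.numpy as jnp

@jax.jit
def positive_entries(x):
    # The output length depends on the values of x, so its shape
    # cannot be known when the function is traced and compiled.
    return x[x > 0]

# Fails under jit (e.g. a NonConcreteBooleanIndexError or a
# ConcretizationTypeError, depending on the JAX version).
positive_entries(jnp.array([1.0, -2.0, 3.0]))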

  1. Why is there such a regression from TensorFlow? I was under the assumption that the underlying XLA representation was shared between the two frameworks, but I may be mistaken. I don't recall TensorFlow ever having trouble with dynamic shapes, and functions such as tf.boolean_mask have been around forever.
  2. Can we expect this gap to close in the future? If not, what makes it impossible to do in JAX's jit what TensorFlow (among others) enables?

EDIT

The gradient passes through tf.boolean_mask (obviously not with respect to the mask values, which are discrete); a case in point here using TF1-style graphs, where values are unknown so TF cannot rely on them:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x1 = tf.placeholder(tf.float32, (3,))
x2 = tf.placeholder(tf.float32, (3,))
y = tf.boolean_mask(x1, x2 > 0)
print(y.shape)  # prints "(?,)"
dydx1, dydx2 = tf.gradients(y, [x1, x2])
assert dydx1 is not None and dydx2 is None
user209974
  • This question is a bit too subjective for StackOverflow. You may have more luck asking about this at https://github.com/google/jax/discussions – jakevdp Mar 19 '21 at 17:36
  • Hi @jakevdp, I don't think the question is subjective as it relates to capacities of jit compilation of operators on dynamic shapes in JAX and TF. I agree the title of my question doesn't reflect that. – user209974 Mar 19 '21 at 18:00
  • OK, let me rephrase: you're asking things about JAX's design and roadmap; such questions are often closed as off-topic by StackOverflow moderators, and the people who can answer such questions are more active on JAX's github discussions than they are here. – jakevdp Mar 19 '21 at 18:41
  • Oh, I see what you mean. Fair enough. – user209974 Mar 20 '21 at 22:05

2 Answers


Currently, you can't (as discussed here).

This is not a limitation of JAX's jit vs TensorFlow's, but a limitation of XLA, or rather of how the two compile.

JAX simply uses XLA to compile the function, and XLA needs to know the static shapes. That's an inherent design choice within XLA.
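
To make this concrete, here is a minimal sketch (the function name is mine): jit specializes the compiled code on the input shapes, so a new input shape triggers a new trace and compilation. The print only fires while tracing.

import jax
import jax.numpy as jnp

@jax.jit
def total(x):
    # Python side effect: runs only while JAX traces the function,
    # i.e. once per new input shape.
    print("tracing for shape", x.shape)
    return jnp.sum(x)

total(jnp.ones(3))  # prints "tracing for shape (3,)" and compiles
total(jnp.ones(3))  # same shape: cached compilation is reused, no print
total(jnp.ones(4))  # new shape: traced and compiled again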

TensorFlow uses tf.function: this creates a graph which can have shapes that are not statically known. This is not as efficient as using XLA, but still fine. However, tf.function offers an option, jit_compile, which will compile the graph inside the function with XLA. While this often offers a decent speedup (for free), it comes with restrictions: shapes need to be statically known (surprise, surprise, ...).
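
As a rough sketch of the difference (the exact behavior of the jit_compile path can vary with the TensorFlow/XLA version):

import tensorflow as tf

@tf.function
def keep_positive(x):
    # The output shape depends on the values of x; plain graph mode
    # handles this fine.
    return tf.boolean_mask(x, x > 0)

print(keep_positive(tf.constant([1.0, -2.0, 3.0])))  # [1. 3.]

@tf.function(jit_compile=True)
def keep_positive_xla(x):
    # Same computation, but compiled with XLA: the value-dependent output
    # shape is exactly what causes trouble here (whether calling this raises
    # or falls back depends on the TensorFlow/XLA version).
    return tf.boolean_mask(x, x > 0)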

This is overall not too surprising behavior: in general, computations are faster the more is known up front (given a decent optimizer went over them), because more parameters (memory layout, ...) can be scheduled optimally. The less is known, the slower the code (at this end of the spectrum sits plain Python).

Mayou36

I don't think JAX is any less capable of doing this than TensorFlow. Nothing forbids you from doing this in JAX:

new_array = my_array[mask]

However, mask should contain indices (integers) and not booleans. This way, JAX is aware of the shape of new_array (the same as that of mask). In that sense, I'm pretty sure that tf.boolean_mask is not differentiable, i.e. it will raise an error if you try to compute its gradient at some point.
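
Here is a minimal sketch of what I mean (the names are mine):

import jax
import jax.numpy as jnp

@jax.jit
def gather(x, idx):
    # idx is an integer array of fixed length, so the output shape
    # (len(idx),) is known at trace time.
    return x[idx]

x = jnp.array([1.0, -2.0, 3.0, 4.0])
idx = jnp.array([0, 2, 3])
print(gather(x, idx))  # [1. 3. 4.]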

More generally, if you need to mask an array, whatever library you are using, there are two approaches:

  1. if you know in advance which indices need to be selected, you can provide these indices directly so that the library can compute the shape before compilation;
  2. if you can't define these indices, for whatever reason, then you need to design your code so that the (necessary) padding does not affect your result.

Examples for each situation

  1. Let's say you're writing a simple embedding layer in JAX. The input is a batch of token indices corresponding to several sentences. To get the word embeddings corresponding to these indices, I will simply write word_embeddings = embeddings[input]. Since I don't know the length of the sentences in advance, I need to pad all token sequences to the same length beforehand, so that input is of shape (number_of_sentences, sentence_max_length). Now, JAX will recompile the masking operation every time this shape changes. To minimize the number of compilations, you can keep the number of sentences (also called the batch size) constant and set sentence_max_length to the maximum sentence length in the entire corpus. This way, there will be only one compilation during training. Of course, you need to reserve one row in word_embeddings that corresponds to the pad index. But still, the masking works (a short sketch follows these two examples).

  2. Later in the model, let's say you want to express each word of each sentence as a weighted average of all the other words in the sentence (like a self-attention mechanism). The weights are computed in parallel for the entire batch and are stored in a matrix A of dimension (number_of_sentences, sentence_max_length, sentence_max_length). The weighted averages are computed with the formula A @ word_embeddings. Now, you need to make sure that the pad tokens don't affect this formula. To do so, you can zero out the entries of A corresponding to the pad indices to remove their influence in the averaging. If the pad token index is 0, you would do:

    mask = jnp.array(input > 0, dtype=jnp.float32)
    A = A * mask[:, jnp.newaxis, :]
    weighted_mean = A @ word_embeddings 

So here we used a boolean mask, but the masking is differentiable in some sense, since we multiply the mask by another matrix instead of using it as an index. Note that we should proceed the same way to zero out the rows of weighted_mean that correspond to pad tokens.
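
Going back to the first situation, here is a rough sketch of the padded embedding lookup described above (the sizes and names are illustrative):

import jax
import jax.numpy as jnp

vocab_size, embed_dim = 100, 8                   # illustrative sizes
embeddings = jnp.ones((vocab_size, embed_dim))   # row 0 is reserved for the pad token

@jax.jit
def embed(input):
    # Integer indexing: the output shape
    # (number_of_sentences, sentence_max_length, embed_dim) is static,
    # so this compiles once per input shape.
    return embeddings[input]

input = jnp.array([[5, 7, 0, 0],                 # sentences padded with index 0
                   [3, 2, 9, 1]])
word_embeddings = embed(input)                   # shape (2, 4, 8)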

Robin
  • Thanks for your answer. Maybe I didn't understand your comment, but gradients *do* pass through `tf.boolean_mask` (obviously not through the mask). I edited my question to provide a small illustration of it. – user209974 May 05 '21 at 20:36
  • jax.numpy.where or np.asarray(condition).nonzero() might be the closest operations to perform this with JAX, but the shapes are needed to jit them. After all it's not so much a question of gradient but a question of returning well shaped arrays. What happens if you use a mask = [[True, True], [False, True]] on a matrix X = [[1,2], [3,4]] ? – Robin May 06 '21 at 08:21
  • The answer of TF to this (the only possible answer, indeed) is to flatten the common dimensions of the array and the mask. Your example would have shape `(?,)` during build, and `(3,)` at run-time. – user209974 May 06 '21 at 08:27
  • It is less capable than *tf.function without XLA* because it needs statically known shapes. When you write "mask" in your example, I think you mean "indices"; a mask *is* made of booleans. And, as written, "This way, JAX is aware of the shape": TF doesn't need to be with tf.function. It's not about the gradients, btw! It's about the compilation. So the gradient argument (which I also think is not true) does not really apply here. – Mayou36 Nov 30 '21 at 18:30