
Hi, why can't I vectorize the condition function to apply it to a list of booleans? Or is there something else going on here?

import jax
import jax.numpy as jnp

DK = jnp.array([[True],[True],[False],[True]])
f1 = lambda x: 1
f2 = lambda y: 0
cond = lambda dk: jax.lax.cond(dk,f1,f2)
vcond = jax.vmap(cond)
vcond(DK)

I was expecting it to give me an array.

Kapil

2 Answers


There are two issues here. First, lax.cond requires a scalar predicate. DK has shape (4, 1), and vmap maps over only the leading axis, so each call to cond receives a predicate of shape (1,) rather than a scalar. You can fix this by using two levels of vmap, one per axis:

vcond = jax.vmap(jax.vmap(cond))
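To see why two levels are needed, here is a small sketch (the `record_shape` helper is hypothetical, used only for illustration) that records the shape each mapped call receives at trace time. One `vmap` over `DK` (shape `(4, 1)`) hands the function arrays of shape `(1,)`; a second `vmap` reduces that to a scalar, which is what `lax.cond` needs.

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])  # shape (4, 1)

seen_shapes = []

def record_shape(dk):
    # Hypothetical helper: store the per-call shape seen while vmap traces the function
    seen_shapes.append(dk.shape)
    return dk

jax.vmap(record_shape)(DK)            # one vmap strips only the leading axis
jax.vmap(jax.vmap(record_shape))(DK)  # two vmaps strip both axes
print(seen_shapes)  # [(1,), ()]
```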

Second, you've set up your f1 and f2 to take an argument x, but it doesn't appear you have any value x to pass to them (if you did, you could pass it as jax.lax.cond(dk,f1,f2,x)). If no x argument is needed, you can redefine the functions to take no arguments. The resulting code looks like this:

import jax.numpy as jnp
import jax

DK = jnp.array([[True],[True],[False],[True]])
f1 = lambda: 1
f2 = lambda: 0
cond = lambda dk: jax.lax.cond(dk,f1,f2)
vcond = jax.vmap(jax.vmap(cond))
result = vcond(DK)
print(result)
# [[1]
#  [1]
#  [0]
#  [1]]
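If you do have per-element data for the branches, the operand form looks like the sketch below. The `X` array and the `+1`/`-1` branches are hypothetical, chosen only to make the branch selection visible in the output.

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])
X = jnp.array([[10], [20], [30], [40]])  # hypothetical per-element operand, same shape as DK

def cond(dk, x):
    # lax.cond(pred, true_fun, false_fun, *operands): x is forwarded to the chosen branch
    return jax.lax.cond(dk, lambda v: v + 1, lambda v: v - 1, x)

vcond = jax.vmap(jax.vmap(cond))  # map over both axes of DK and X together
result = vcond(DK, X)
print(result.tolist())  # [[11], [21], [29], [41]]
```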

Note that you can avoid a lot of this complexity by replacing your code with jnp.where. For example:

result = jnp.where(DK, 1, 0)

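This will lower to essentially the same XLA select operation as your original code, because vmapping a lax.cond over a batched predicate converts it into a select that evaluates both branches.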
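One way to check the lowering claim is to compare the jaxprs of both versions with `jax.make_jaxpr`. This sketch assumes a recent JAX release, where the select primitive is printed as `select_n` (older releases named it `select`).

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])

f1 = lambda: 1
f2 = lambda: 0
cond = lambda dk: jax.lax.cond(dk, f1, f2)
vcond = jax.vmap(jax.vmap(cond))

# Trace both versions to jaxprs (JAX's intermediate representation) without running them
cond_jaxpr = jax.make_jaxpr(vcond)(DK)
where_jaxpr = jax.make_jaxpr(lambda dk: jnp.where(dk, 1, 0))(DK)

print(cond_jaxpr)
print(where_jaxpr)

# Both are expected to contain a select_n primitive rather than a real branch
print("select_n" in str(cond_jaxpr), "select_n" in str(where_jaxpr))

# And the two versions produce the same values
print(jnp.array_equal(vcond(DK), jnp.where(DK, 1, 0)))
```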

jakevdp
  • I find that `jax.numpy.where`/`jax.numpy.select` (`jax.lax.select`) and `jax.numpy.piecewise` (`jax.lax.switch`) cover most reasons to use a conditional. Sometimes these methods are substantially faster than `at[mask].set()`, other times essentially the same. – DavidJ Dec 15 '22 at 10:42
  • I see. Do you have a blog or tutorial where these are used interchangeably for common conditional cases? – Kapil Dec 16 '22 at 20:51
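As a rough illustration of the alternatives mentioned in the comments above (not a tutorial, and not benchmarked), here is one hypothetical rule, clipping values to [-1, 1], written with `jnp.select`, `jnp.piecewise`, and `lax.switch`, plus a plain two-way `jnp.where`.

```python
import jax
import jax.numpy as jnp

# Hypothetical example input; the rule below clips values to [-1, 1]
x = jnp.array([-2.0, -0.5, 0.5, 2.0])

# jnp.where: a two-way elementwise choice (here, a ReLU)
relu = jnp.where(x > 0, x, 0.0)

# jnp.select: the first matching condition picks its choice; otherwise `default`
clipped_select = jnp.select(
    [x < -1, x > 1],
    [-jnp.ones_like(x), jnp.ones_like(x)],
    default=x,
)

# jnp.piecewise: one function (or constant) per condition, plus a final "otherwise" entry
clipped_piecewise = jnp.piecewise(x, [x < -1, x > 1], [-1.0, 1.0, lambda v: v])

# lax.switch under vmap: compute an integer branch index per element, then dispatch
branches = [lambda v: -jnp.ones_like(v), lambda v: v, lambda v: jnp.ones_like(v)]
index = jnp.where(x < -1, 0, jnp.where(x > 1, 2, 1))
clipped_switch = jax.vmap(lambda i, v: jax.lax.switch(i, branches, v))(index, x)

print(relu.tolist())               # [0.0, 0.0, 0.5, 2.0]
print(clipped_select.tolist())     # [-1.0, -0.5, 0.5, 1.0]
print(clipped_piecewise.tolist())  # [-1.0, -0.5, 0.5, 1.0]
print(clipped_switch.tolist())     # [-1.0, -0.5, 0.5, 1.0]
```

`jnp.where` and `jnp.select` evaluate every choice for every element and then pick, while `lax.switch` expresses the choice as branches; under `vmap` it is also converted to a select, so in practice all of these compute every branch.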

Try this:

import jax
import jax.numpy as jnp
DK = jnp.array([[True],[True],[False],[True]])
f1 = lambda x: 1
f2 = lambda y: 0

# New version: lax.cond(pred, true_fun, false_fun, *operands)
# dk is the predicate; the trailing dk is the operand passed to whichever branch runs
cond = lambda dk: jax.lax.cond(dk, lambda x: f1(x), lambda x: f2(x), dk)

# Old version, using the deprecated signature lax.cond(pred, true_operand, true_fun, false_operand, false_fun).
# Kept commented out for reference; left active, it would overwrite the new version above.
# cond = lambda dk: jax.lax.cond(dk,
#                                dk, lambda x: f1(x),
#                                dk, lambda x: f2(x))

vcond = jax.vmap(jax.vmap(cond))
vcond(DK)

Output:

DeviceArray([[1],
             [1],
             [0],
             [1]], dtype=int32, weak_type=True)
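An alternative to nesting `vmap`, sketched here as an assumption rather than as part of this answer, is to flatten `DK` so a single `vmap` already sees scalar predicates, then reshape the result back. This works for inputs of any rank, at the cost of an explicit `ravel`/`reshape`.

```python
import jax
import jax.numpy as jnp

DK = jnp.array([[True], [True], [False], [True]])
f1 = lambda x: 1
f2 = lambda x: 0

# Same per-element function as above, using the new lax.cond signature
cond = lambda dk: jax.lax.cond(dk, f1, f2, dk)

# Flatten to 1D so one vmap sees scalar predicates, then restore the original shape
flat = jax.vmap(cond)(DK.ravel())
result = flat.reshape(DK.shape)
print(result.tolist())  # [[1], [1], [0], [1]]
```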
I'mahdi
  • `jax.lax.cond(dk, dk, lambda x: f1(x), dk, lambda x: f2(x))`: why is there an extra dk before each lambda function? The documentation for jax.lax.cond only shows (predicate, fun1, fun2). Amazingly, it worked! – Kapil Dec 09 '22 at 23:25
  • `jax.lax.cond(dk, x, f1, x, f2)` is an older deprecated signature for `lax.cond`, which still works but will raise a warning. The new call signature is `jax.lax.cond(dk, f1, f2, x)` where `dk` is the predicate, and `x` is the argument that will be passed to either `f1` or `f2`. – jakevdp Dec 10 '22 at 01:07
  • @jakevdp, thanks a lot; I edited the answer and added a new version. – I'mahdi Dec 10 '22 at 04:05
  • @Kapil, you're welcome. Thanks to `jakevdp` for the explanation; I also updated the answer with the new version. – I'mahdi Dec 10 '22 at 04:14
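For code written against the old five-argument form, where each branch gets its own operand, one way to migrate (a sketch with hypothetical `true_fun`/`false_fun`) is to pass both operands as a tuple and let each branch pick the one it needs, since the new signature forwards the same operands to both branches.

```python
import jax
import jax.numpy as jnp

# Hypothetical branch functions that each expect a different operand
def true_fun(a):
    return a * 2

def false_fun(b):
    return b - 1

def cond_new(pred, true_operand, false_operand):
    # Old form: lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
    # New form: lax.cond(pred, true_fun, false_fun, *operands), same operands for both branches,
    # so pack them in a tuple and unpack the relevant one inside each branch
    return jax.lax.cond(
        pred,
        lambda ops: true_fun(ops[0]),
        lambda ops: false_fun(ops[1]),
        (true_operand, false_operand),
    )

print(cond_new(True, 3, 10))   # 6
print(cond_new(False, 3, 10))  # 9

# Works under vmap too, with one predicate and one pair of operands per element
preds = jnp.array([True, False, True])
batched = jax.vmap(cond_new)(preds, jnp.array([1, 2, 3]), jnp.array([10, 20, 30]))
print(batched.tolist())  # [2, 19, 6]
```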