
I am using JAX for automatic differentiation. I am trying to convert concrete tracer values to regular Python floats using astype(float), but it still seems to return a concrete tracer. However, when I use astype(int), it correctly converts the concrete tracer into an int.

Here is an example:

import jax
from jax import numpy as jnp

def func(mu):
  print(mu.astype(float))
  print(mu)
  print(mu.astype(int))

  return mu

f = jax.grad(func)

print(f(jnp.array(5.0)))

This prints the following:

Traced<ConcreteArray(5.0)>with<JVPTrace(level=2/0)>
  with primal = Traced<ConcreteArray(5.0):JaxprTrace(level=1/0)>
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(5.0)>with<JVPTrace(level=2/0)>
  with primal = Traced<ConcreteArray(5.0):JaxprTrace(level=1/0)>
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
5
1.0

print(mu) and print(mu.astype(float)) print exactly the same tracer, while print(mu.astype(int)) prints the concrete integer value 5. Is this expected? This is a simple example, but in my original function I would like to convert one of the input parameters into a regular Python float.

Thanks in advance!

jakevdp

1 Answer


There is some information on this in JAX's FAQ: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array

In general, it is not safe to convert traced values to Python floats, because any operations you do on the float value will no longer be tracked by JAX's autodiff, so autodiff may return incorrect results. This is why JAX lets you convert a JVP tracer to an integer (integer values have no derivative, so they cannot affect autodiff), while converting it to a float returns another traced value (float values do affect autodiff!).
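To illustrate how an untracked value leads to a wrong derivative, here is a minimal sketch. It uses jax.lax.stop_gradient as a stand-in for "converted to a plain float": the value is still used in the computation, but autodiff no longer sees that it depends on mu. The function names f_tracked and f_detached are hypothetical.

```python
import jax

def f_tracked(mu):
    # Every use of mu is visible to autodiff.
    return mu ** 2

def f_detached(mu):
    # Stand-in for "c = float(mu)": c has the same value as mu,
    # but autodiff treats it as a constant.
    c = jax.lax.stop_gradient(mu)
    return c * mu

print(f_tracked(3.0), f_detached(3.0))  # both functions evaluate to 9.0
print(jax.grad(f_tracked)(3.0))   # 6.0, the true derivative 2 * mu
print(jax.grad(f_detached)(3.0))  # 3.0, because the dependence through c was lost
```

Both functions compute the same value, but only the tracked version reports the correct derivative.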

If you would like to call a non-JAX function from within a transformed JAX function, your best bet is probably to use jax.pure_callback along with jax.custom_jvp; there is an example in the JAX docs, in the "External Callbacks in JAX" guide under "pure_callback with custom_jvp".
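A minimal sketch of that pattern, assuming a hypothetical NumPy-only helper np_square that genuinely needs a Python float. jax.pure_callback runs it on concrete NumPy values (so float() is safe there), and jax.custom_jvp supplies the derivative, since JAX cannot differentiate through the callback. This sketch only handles scalar inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np

def np_square(x):
    # Hypothetical non-JAX function. Inside the callback, x is a concrete
    # NumPy array rather than a tracer, so converting it to float is fine.
    value = float(x)
    return np.asarray(value ** 2, dtype=x.dtype)

@jax.custom_jvp
def square(x):
    x = jnp.asarray(x)
    # pure_callback needs to know the shape and dtype of the result up front.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(np_square, out_type, x)

@square.defjvp
def square_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # The derivative is written by hand: d/dx x**2 = 2 * x.
    return square(x), 2 * x * x_dot

print(square(3.0))            # 9.0
print(jax.grad(square)(3.0))  # 6.0
```

Because the JVP rule is supplied manually, it is your responsibility to keep it consistent with what the callback actually computes.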

Sorry for not answering the question earlier – these callback tools are relatively new and there was no good answer to your question when it was first asked!

jakevdp