I am using JAX for automatic differentiation. I am trying to convert concrete tracer values to regular float values using `astype(float)`, but it still seems to return a concrete tracer value. However, when I use `astype(int)`, it seems to correctly convert the concrete tracer into an int.
Here is an example:
```python
import jax
from jax import numpy as jnp

def func(mu):
    print(mu.astype(float))
    print(mu)
    print(mu.astype(int))
    return mu

f = jax.grad(func)
print(f(jnp.array(5.0)))
```
This returns the following:
```
Traced<ConcreteArray(5.0)>with<JVPTrace(level=2/0)>
  with primal = Traced<ConcreteArray(5.0):JaxprTrace(level=1/0)>
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
Traced<ConcreteArray(5.0)>with<JVPTrace(level=2/0)>
  with primal = Traced<ConcreteArray(5.0):JaxprTrace(level=1/0)>
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)>
5
1.0
```
Clearly, `print(mu)` and `print(mu.astype(float))` print exactly the same thing, while `print(mu.astype(int))` prints the correct int value. Is this expected? This is a simple example, but in my original function I would like to convert one of the input parameters into a regular float value.
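For reference, a minimal check (a sketch assuming a recent JAX release where `jax.debug.print` and `jax.lax.stop_gradient` are available) separating two things that are easy to conflate: `astype` only changes an array's dtype and never produces a Python `float`, even outside of `jax.grad`, whereas `float()` on a concrete, untraced array does. Inside the differentiated function, `jax.debug.print` shows the runtime value, and `jax.lax.stop_gradient` is one way to use a parameter as a constant. Whether that is an acceptable substitute for a "regular float" depends on whether that parameter should contribute to the gradient.

```python
import jax
import jax.numpy as jnp

# Outside any transformation, astype(float) still returns a JAX array;
# only the dtype changes (float32 unless 64-bit mode is enabled).
x = jnp.array(5.0)
y = x.astype(float)
print(type(y), y.dtype)

# float() on a concrete, untraced array does return a Python float.
z = float(x)
print(type(z), z)

def func(mu):
    # jax.debug.print prints the runtime value of a traced array.
    jax.debug.print("mu = {}", mu)
    # stop_gradient keeps mu as an array but treats this use as a constant,
    # so d/dmu (mu * stop_gradient(mu)) == stop_gradient(mu).
    mu_const = jax.lax.stop_gradient(mu)
    return mu * mu_const

g = jax.grad(func)(jnp.array(5.0))
print(g)
```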
Thanks in advance!