Chris gave a valid answer for simple code, but there's a way to do it without modifying the dataclass. I ran into the exact same problem, and other parts of my code depend on overloaded operators in the dataclass, so I could not easily change the data structures.
The solution is to use pytrees and tree_map(). Pytrees are how JAX handles nested lists/dicts of (traced) arrays. First, you'll need to register your class as a pytree node. This should require very little effort.
Since lists of pytrees are also pytrees, jax.tree_util.tree_map will work on a list of your objects without you having to change how your dataclass behaves.
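To make the pytree idea concrete before the full example, here is a small sketch that uses only built-in containers, so no registration is needed. The names `params`, `doubled` and `leaves` are made up for illustration and are not part of the original problem.

```python
import jax
import jax.numpy as jnp

# A nested list/dict of arrays is already a pytree; nothing to register.
params = [
    {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)},
    {"w": jnp.full((2, 2), 3.0), "b": jnp.ones(2)},
]

# tree_map applies the function to every leaf (each array) and rebuilds
# the same list-of-dicts structure around the results.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)

# tree_leaves returns the flat list of leaves: 2 dicts x 2 arrays each.
leaves = jax.tree_util.tree_leaves(params)
print(len(leaves))
print(doubled[1]["w"])
```

A registered custom class simply becomes one more kind of container that tree_map knows how to take apart and put back together.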
Here is a minimal working example:
import jax
from jax import jit, tree_util

class MyContainer:
    """
    A container with a traced and a static member.
    The * operator is overloaded as a demonstration.
    """
    def __init__(self, a: int):
        self.a = a
        self.a_stat = a*100

    def __mul__(self, other):
        return MyContainer(self.a*other.a)

    # For JAX use
    def _tree_flatten(self):
        children = (self.a,)  # arrays / dynamic values
        aux_data = {'a_stat': self.a_stat}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # __init__ only accepts `a`, so bypass it and restore both
        # members directly from the flattened pieces.
        obj = object.__new__(cls)
        obj.a, = children
        obj.a_stat = aux_data['a_stat']
        return obj

# Registering the datatype with JAX
tree_util.register_pytree_node(
    MyContainer,
    MyContainer._tree_flatten,
    MyContainer._tree_unflatten)

X_list = [MyContainer(3), MyContainer(4), MyContainer(5)]
Y_list = [MyContainer(1), MyContainer(10), MyContainer(100)]

# A simple callable that adds the traced var a to the static var a_stat
def simple_callable(my_container):
    return MyContainer(my_container.a + my_container.a_stat)

# Note that tree_map will try to traverse into class members as well.
# To stop it from doing that, we add is_leaf to stop it from looking
# deeper when the item is a MyContainer.
test_simple_list = jax.tree_util.tree_map(
    simple_callable,
    [MyContainer(3), MyContainer(4), MyContainer(5)],
    is_leaf=lambda n: isinstance(n, MyContainer)
)

# see if it works (prints 303, 404, 505)
for i in range(len(test_simple_list)):
    print('simple_callable', test_simple_list[i].a)

# This also works for callables that use such lists of dataclasses.
# However, to do indexing, you need a list for the indices.
# This list will be automatically handled as a pytree.
tree_ind = list(range(len(X_list)))

def callcables_containing_dataclass(i):
    return X_list[i]*Y_list[i]

test_callable_list = jax.tree_util.tree_map(callcables_containing_dataclass, tree_ind)

# see if it works (prints 3, 40, 500)
for i in range(len(X_list)):
    print('callable with dataclass', test_callable_list[i].a)

# jitting works
@jit
def test():
    return (
        jax.tree_util.tree_map(
            simple_callable,
            [MyContainer(3), MyContainer(4), MyContainer(5)],
            is_leaf=lambda n: isinstance(n, MyContainer)
        ),
        jax.tree_util.tree_map(callcables_containing_dataclass, tree_ind)
    )

test()
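
In practice you will probably want to pass the lists into the jitted function as arguments rather than close over them. Below is a minimal sketch of that, under my own assumptions: `Pair`, `_flatten`, `_unflatten` and `multiply_all` are hypothetical names standing in for your dataclass and code, and the static member is a plain int so it can be used as aux_data. It also shows what changes when is_leaf is left out.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Pair:
    """Hypothetical stand-in for MyContainer: one traced, one static field."""

    def __init__(self, a, scale):
        self.a = a          # dynamic value (traced under jit)
        self.scale = scale  # static value (stored in the tree structure)

    def __mul__(self, other):
        return Pair(self.a * other.a, self.scale)


def _flatten(p):
    return (p.a,), p.scale  # (children, aux_data)


def _unflatten(scale, children):
    return Pair(children[0], scale)


tree_util.register_pytree_node(Pair, _flatten, _unflatten)

xs = [Pair(jnp.array(3.0), 2), Pair(jnp.array(4.0), 2)]
ys = [Pair(jnp.array(1.0), 2), Pair(jnp.array(10.0), 2)]


@jax.jit
def multiply_all(xs, ys):
    # With several trees, tree_map calls the function on matching entries.
    # is_leaf makes each Pair a leaf, so the overloaded * is what runs.
    return tree_util.tree_map(
        lambda x, y: x * y, xs, ys,
        is_leaf=lambda n: isinstance(n, Pair),
    )


out = multiply_all(xs, ys)
print([float(p.a) for p in out])

# Without is_leaf, tree_map descends into each Pair and maps over `a` itself.
plain = tree_util.tree_map(lambda a: a + 1, xs)
print([float(p.a) for p in plain])
```

The design point is the same as in the example above: is_leaf decides whether your function sees whole objects (so overloaded operators apply) or the raw arrays inside them.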