
In JAX, I am looking to vmap a function over a fixed-length list of dataclasses, for example:

import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class EnvParams:
    max_steps: int = 500
    random_respawn: bool = False

def foo(params: EnvParams):
    ...

param_list = jnp.array([EnvParams(max_steps=500), EnvParams(max_steps=600)])  # fails: not array-like
jax.vmap(foo)(param_list)

The example above fails because it is not possible to create a JAX array of custom objects, and JAX doesn't allow vmapping over Python lists. The only remaining option I see is to transform the dataclass so that it represents a batch of parameters, like so:

@struct.dataclass
class EnvParamBatch:
    max_steps: jax.Array = jnp.array([500, 600])
    random_respawn: jax.Array = jnp.array([False, True])

def bar(params):
    ...

jax.vmap(bar)(EnvParamBatch())

It would be preferable to use a container of structs (with each representing a single parameter set), so I'm wondering if there are any alternative approaches to this?

N.B. I am aware of this answer, however it's not precisely the same question and there may now be better solutions.

    JAX's `vmap` cannot operate on array-of-structs, but can operate on struct-of-arrays, so your second solution is the approach you should use with JAX. I'd add an answer, but it seems you've already answered your question! – jakevdp Sep 19 '22 at 12:51

2 Answers


Your second solution is correct. I agree it is awkward to use a struct of arrays but it is usually the best option in JAX (so arrays can be stored in GPU / TPU memory instead of CPU). Here's an example:

import typing
import jax
import jax.numpy as jnp

class EnvParams(typing.NamedTuple):
    max_steps: int = 500
    random_respawn: bool = False

param_array = EnvParams(
    max_steps=jnp.array([500, 600]),
    random_respawn=jnp.array([False, False]))
vmap_param_array = jax.vmap(lambda x: x)(param_array)

However, if you really must use lists, it is possible to convert between the two layouts. Here is an example:

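def list_to_array(items):
    # List of structs -> struct of arrays: stack each field across the list
    cls = type(items[0])
    return cls(**{k: jnp.array([getattr(v, k) for v in items]) for k in cls._fields})

def array_to_list(array):
    # Struct of arrays -> list of structs: cast each element back to its annotated type
    cls = type(array)
    size = len(getattr(array, cls._fields[0]))
    # NamedTuple._field_types was removed in Python 3.9; __annotations__ holds the same mapping
    return [cls(**{k: v(getattr(array, k)[i]) for k, v in cls.__annotations__.items()}) for i in range(size)]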

param_list = [EnvParams(max_steps=500), EnvParams(max_steps=600)]
param_array = list_to_array(param_list)
vmap_param_array = jax.vmap(lambda x: x)(param_array)
vmap_param_list = array_to_list(vmap_param_array)
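
A field-agnostic variant of the same conversion can be sketched with jax.tree_util.tree_map, which stacks matching leaves across the list. This is only a sketch: the helper names stack_params and unstack_params and the function steps_remaining are made up for illustration, and it assumes every element of the list has the same structure.

```python
import typing
import jax
import jax.numpy as jnp

class EnvParams(typing.NamedTuple):
    max_steps: int = 500
    random_respawn: bool = False

def stack_params(items):
    # Hypothetical helper: list of structs -> struct of arrays.
    # tree_map passes the matching leaf of every item to the lambda.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack([jnp.asarray(x) for x in leaves]), *items)

def unstack_params(batch):
    # Hypothetical helper: struct of arrays -> list of structs.
    leaves, treedef = jax.tree_util.tree_flatten(batch)
    size = leaves[0].shape[0]
    return [jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
            for i in range(size)]

def steps_remaining(params, t):
    # Illustrative function of a single parameter set (not from the question).
    return jnp.where(params.random_respawn, 0, params.max_steps - t)

param_list = [EnvParams(max_steps=500),
              EnvParams(max_steps=600, random_respawn=True)]
batch = stack_params(param_list)

# Map over the leading axis of every field; t is shared across the batch.
out = jax.vmap(steps_remaining, in_axes=(0, None))(batch, 100)
restored = unstack_params(batch)
```

The function passed to vmap still sees one EnvParams at a time, so the calling code can keep a list of structs and only convert at the vmap boundary.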
Chris Flesher

Chris gave a valid answer for simple code, but there's a way to do it without having to restructure the dataclass. I ran into the exact same problem, and other parts of my code depend on overloaded operators in the dataclass, so I could not modify the data structures easily.

The solution is to use pytrees and tree_map(). Pytrees are JAX's data structures for nested lists/dicts of traced arrays. First, you'll need to register your class as a pytree. This should require very little effort.

Since lists of pytrees are also pytrees, jax.tree_util.tree_map works on the list without you having to change how your dataclass stores its data.

Here is a minimum working example:

import jax
from jax import jit, vmap, tree_util
from functools import partial # for JAX jit with static params

class MyContainer:
    """
    A container with a traced and a static member.
    The * operator is overloaded as a demonstration.
    """
    def __init__(self, a: int):
        self.a = a
        self.a_stat = a * 100

    def __mul__(self, other):
        return MyContainer(self.a * other.a)

    # Pytree hooks for JAX
    def _tree_flatten(self):
        children = (self.a,)  # arrays / dynamic values
        aux_data = {'a_stat': self.a_stat}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # Bypass __init__, which does not accept a_stat as an argument
        obj = cls.__new__(cls)
        obj.a, = children
        obj.a_stat = aux_data['a_stat']
        return obj
    
# Registering the datatype with JAX
tree_util.register_pytree_node(
    MyContainer,
    MyContainer._tree_flatten,
    MyContainer._tree_unflatten)


X_list = [MyContainer(3),MyContainer(4),MyContainer(5)]
Y_list = [MyContainer(1),MyContainer(10),MyContainer(100)]

# A simple callable adds the traced var a to the static var a_stat
def simple_callable(my_container):
    return(MyContainer(my_container.a+my_container.a_stat))

# Note that tree_map will try to traverse into class members as well. 
# To stop it from doing that, we add is_leaf to stop it from looking 
# deeper when the item is a MyContainer. 
test_simple_list = jax.tree_util.tree_map(
    simple_callable, 
    [MyContainer(3),MyContainer(4),MyContainer(5)], 
    is_leaf=lambda n: isinstance(n, MyContainer)
)

# see if it works
for i in range(len(X_list)):
    print('simple_callable', test_simple_list[i].a)
    

# This also works for callables containing such list of dataclasses
# However, to do indexing, you need a list for the indices.
# this list will be automatically handled as a pytree.
tree_ind = list(range(len(X_list))) 
def callcables_containing_dataclass(i):
    return(X_list[i]*Y_list[i])

test_callable_list = jax.tree_util.tree_map(callcables_containing_dataclass, tree_ind)

# seeing if it works
for i in range(len(X_list)):
    print('callable with dataclass', test_callable_list[i].a)

# jitting works
@jit
def test():
    return (
        jax.tree_util.tree_map(
            simple_callable,
            [MyContainer(3), MyContainer(4), MyContainer(5)],
            is_leaf=lambda n: isinstance(n, MyContainer)
        ),
        jax.tree_util.tree_map(callcables_containing_dataclass, tree_ind),
    )

test()
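
One caveat worth keeping in mind: tree_map calls the function once per leaf in a Python loop, so it does not batch the computation the way vmap does. If true vectorization is needed while keeping the overloaded operators, one possible approach is to stack the list into a single container holding arrays and vmap over that. The sketch below assumes a simplified stand-in class (Scaled, a hypothetical name with only a dynamic field) rather than MyContainer itself, and the helper stack is also made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class Scaled:
    """Hypothetical stand-in for MyContainer with a single dynamic field."""
    def __init__(self, a):
        self.a = a

    def __mul__(self, other):
        return Scaled(self.a * other.a)

    def tree_flatten(self):
        return (self.a,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, so no validation runs on tracers.
        obj = cls.__new__(cls)
        obj.a, = children
        return obj

def stack(items):
    # List of containers -> one container whose field has a leading batch axis.
    return tree_util.tree_map(
        lambda *xs: jnp.stack([jnp.asarray(x) for x in xs]), *items)

xs = stack([Scaled(3), Scaled(4), Scaled(5)])
ys = stack([Scaled(1), Scaled(10), Scaled(100)])

# Inside vmap each argument is a Scaled holding one (traced) element,
# so the overloaded * operator is used unchanged.
prod = jax.vmap(lambda x, y: x * y)(xs, ys)
print(prod.a)
```

The operator code stays untouched; only the storage is converted at the boundary, which is the struct-of-arrays layout the comment on the question recommends.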