
I am trying to use JAX on the code from another SO question to evaluate JAX's applicability and performance for it (that question has useful information about what the code does). For this purpose, I have replaced the NumPy calls with their jax.numpy (jnp) equivalents. Substituting the NumPy code with equivalent jnp code was not as easy as I expected, given my limited experience with JAX, and it could probably be written better. I checked the results against the previous code (the optimized algorithm) and they were the same. However, the JAX version takes 7.5 seconds for a sample case, whereas the previous one took 0.10 seconds (using Colab). I think this long runtime may be related to the for loop in the code, which might be replaced by JAX constructs such as fori_loop or vectorization. But I don't know what changes must be made, or how, to make this code perform satisfactorily with JAX.

import numpy as np
from scipy.spatial import cKDTree, distance
import jax
from jax import numpy as jnp
jax.config.update("jax_enable_x64", True)


# ---------------------------- input data ----------------------------
""" For testing by prepared files:
radii = np.load('a.npy')
poss = np.load('b.npy')
"""

rnd = np.random.RandomState(70)
data_volume = 1000

radii = rnd.uniform(0.0005, 0.122, data_volume)
dia_max = 2 * radii.max()

x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))
# --------------------------------------------------------------------


# @jax.jit
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = jnp.array([], dtype=np.float64)

    # kdtree = cKDTree(poss)                                                                                              # Using SciPy

    for particle_idx in range(len(poss)):

        cur_point = poss[particle_idx]
        # nears_i_ind = jnp.array(kdtree.query_ball_point(cur_point, r=dia_max, return_sorted=True), dtype=np.int64)      # Using SciPy
        
        # Using NumPy: bounding-box pre-filter, evaluated on the full `poss`
        # array so that the indices returned by jnp.where refer to rows of `poss`
        # (filtering a copy with row `particle_idx` deleted would shift every
        # later index by one)
        dist_max = radii[particle_idx] + radii.max()

        lx_limit_idx = poss[:, 0] <= cur_point[0] + dist_max
        ux_limit_idx = poss[:, 0] >= cur_point[0] - dist_max
        ly_limit_idx = poss[:, 1] <= cur_point[1] + dist_max
        uy_limit_idx = poss[:, 1] >= cur_point[1] - dist_max
        lz_limit_idx = poss[:, 2] <= cur_point[2] + dist_max
        uz_limit_idx = poss[:, 2] >= cur_point[2] - dist_max

        nears_i_ind = jnp.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]

        # assert len(nears_i_ind) > 0

        # if len(nears_i_ind) <= 1:
        #     continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]

        # dist_i = distance.cdist(poss[tuple(nears_i_ind[None, :])], cur_point[None, :]).squeeze()                        # Using SciPy
        dist_i = jnp.linalg.norm(poss[nears_i_ind] - cur_point[None, :], axis=-1)                                      # Using NumPy
        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])

        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps = jnp.concatenate((particle_corsp_overlaps, connected))

        contacts_ind = jnp.where(contact_check <= 0)[0]
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = jnp.sort(contacts_sec_ind)

        ends_ind_mod_temp = jnp.array([jnp.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:   # ---> these 4 lines could perhaps be replaced by a single list append, e.g. "ends_ind.append(ends_ind_mod_temp)"
            ends_ind = jnp.concatenate((ends_ind, ends_ind_mod_temp))
        else:
            ends_ind = jnp.array(ends_ind_mod_temp, dtype=np.int64)

    ends_ind_org = ends_ind
    ends_ind, ends_ind_idx = jnp.unique(jnp.sort(ends_ind_org), axis=0, return_index=True)
    gap = jnp.array(particle_corsp_overlaps)[ends_ind_idx]

    return gap, ends_ind, ends_ind_idx, ends_ind_org
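A minimal sketch of the single-append pattern that the comment on the `if particle_idx > 0` line alludes to: collect the per-particle `(k, 2)` blocks in a Python list and call `jnp.concatenate` once after the loop, so the first particle needs no special case. `collect_pairs` and its `neighbor_lists` argument are hypothetical names used only for illustration. This runs eagerly and is not meant to be jitted, because the block sizes depend on the data.

```python
import jax.numpy as jnp

def collect_pairs(neighbor_lists):
    """Hypothetical helper: build the (i, j) index array with one concatenate.

    neighbor_lists[i] holds the indices of the spheres in contact with sphere i.
    """
    chunks = []
    for i, nbrs in enumerate(neighbor_lists):
        nbrs = jnp.sort(jnp.asarray(nbrs, dtype=jnp.int32))
        # (k, 2) block for particle i: column 0 is i, column 1 the neighbor
        chunks.append(jnp.stack([jnp.full_like(nbrs, i), nbrs], axis=1))
    # A single concatenate at the end replaces the if/else growth pattern
    return jnp.concatenate(chunks, axis=0)

pairs = collect_pairs([[2, 1], [0], [0]])
print(pairs.tolist())  # [[0, 1], [0, 2], [1, 0], [2, 0]]
```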

I tried to use @jax.jit on this code, but it raises TracerArrayConversionError or ConcretizationTypeError (on a Colab TPU):

Using SciPy:

TracerArrayConversionError: The numpy.ndarray conversion method array() was called on the JAX Tracer object Traced<ShapedArray(float64[1000,3])>with<DynamicJaxprTrace(level=0/1)> While tracing the function ends_gap at :1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'poss'. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
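The SciPy error can be reproduced without SciPy. `cKDTree` converts its input with NumPy, and under `jax.jit` the input is an abstract tracer with no concrete values, so that conversion fails. A minimal sketch, using a hypothetical function `calls_numpy` that performs the same kind of conversion:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def calls_numpy(x):  # hypothetical example function
    # np.asarray needs concrete values, but under jit `x` is an abstract tracer
    return jnp.asarray(np.asarray(x) * 2)

try:
    calls_numpy(jnp.arange(3.0))
    raised = False
except jax.errors.TracerArrayConversionError as err:
    raised = True
    print(type(err).__name__)  # TracerArrayConversionError
```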

Using NumPy:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. While tracing the function ends_gap at :1 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'poss' and 'dia_max'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
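The NumPy-path error comes from `jnp.where`/`jnp.nonzero`, whose output length depends on the data. A minimal sketch of the usual workaround, assuming a fixed upper bound on the number of results is acceptable: pass a static `size` (here `len(poss)`, which can never be too small) and a `fill_value` for the unused slots. `near_indices` is a hypothetical name. The result still includes the particle itself, which would have to be masked out, for example by comparing against its index.

```python
import jax
import jax.numpy as jnp

@jax.jit
def near_indices(poss, center, dist_max):  # hypothetical helper
    # Equivalent (up to rounding) to the question's six bounding-box comparisons
    inside = jnp.all(jnp.abs(poss - center) <= dist_max, axis=1)
    # poss.shape[0] is static under jit, so the output shape is fixed;
    # slots beyond the real matches are filled with -1
    (idx,) = jnp.nonzero(inside, size=poss.shape[0], fill_value=-1)
    return idx, inside.sum()

poss = jnp.array([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [1.0, 1.0, 1.0]])
idx, count = near_indices(poss, poss[0], 0.2)
print(idx.tolist(), int(count))  # [0, 1, -1] 2
```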

I would appreciate any help speeding up this code with JAX by getting past these problems (using jax.jit, if possible). How can JAX be used to get the best performance on CPU, GPU, or TPU?


Prepared sample test data:
a.npy = Radii data
b.npy = Poss data

Updates


The main aim of this question is to find out how to modify the code to get the best performance out of it with the JAX library.

Based on jakevdp's answer, I have commented out the SciPy-related lines in the code and uncommented the equivalent NumPy-related sections.
To get better answers, I'm numbering the important points:

  1. Are scikit-learn's BallTree methods compatible with JAX? They could be a good alternative to SciPy's cKDTree in terms of memory usage (for possible vectorization).
  2. What is the best way to handle the loop in the code: using fori_loop, or putting the lines of the loop body inside a function and then vectorizing it, jitting it, or …?
  • I had problems preparing the code for fori_loop. What I tried can be seen in the following line, where particle_corsp_overlaps was the input of the function I defined (this function contains only the loop body). If fori_loop is recommended, it would be useful to see how to do it.
particle_corsp_overlaps, ends_ind = jax.lax.fori_loop(0, len(poss), jax_loop, particle_corsp_overlaps)
  • I put the NumPy section in a function and decorated it with @jax.jit to check whether that improves performance (I don't know how much it can help). It raised a ConcretizationTypeError (--> Shape depends on Traced Value) related to poss. So I tried the @partial(jax.jit, static_argnums=0) decorator (importing partial from functools), but now I get the following error. How can I solve it, if this approach is recommended? For example, for:
@partial(jax.jit, static_argnums=0)
def ends_gap(poss):

    for particle_idx in range(len(poss)):

        cur_point = poss[particle_idx]

        dist_max = radii[particle_idx] + radii.max()

        lx_limit_idx = poss[:, 0] <= cur_point[0] + dist_max
        ux_limit_idx = poss[:, 0] >= cur_point[0] - dist_max
        ly_limit_idx = poss[:, 1] <= cur_point[1] + dist_max
        uy_limit_idx = poss[:, 1] >= cur_point[1] - dist_max
        lz_limit_idx = poss[:, 2] <= cur_point[2] + dist_max
        uz_limit_idx = poss[:, 2] >= cur_point[2] - dist_max

        nears_i_ind = jnp.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]

        dist_i = jnp.linalg.norm(poss[nears_i_ind] - cur_point[None, :], axis=-1)

ValueError: Non-hashable static arguments are not supported. An error occured during a call to 'nearest_neighbors_jax' while trying to hash an object of type <class 'jaxlib.xla_extension.DeviceArray'>, [[ 8.42519143e-01 1.37693422e+00 -7.97775882e-01] [-3.31436445e-01 -1.67346250e+00 -8.61069684e-01] [-1.57500126e-01 -1.17502591e+00 -7.48879998e-01]]. The error was: TypeError: unhashable type: 'DeviceArray'

I did not put the whole loop body into the function because I got stuck on this short function. A function containing the whole loop body that can be jitted (or …) would be of interest, if possible.

  3. Can the 4-line if-else statement for ends_ind be written as a single line using JAX methods, to avoid possible problems with the if statement during jitting or …?
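Regarding item 2: the fori_loop attempt shown there most likely fails because its carry (`particle_corsp_overlaps`) grows on every iteration, while `jax.lax.fori_loop` requires the carry to keep the same shape and dtype. A minimal sketch, assuming an `(N, N)` overlap matrix fits in memory: carry that fixed-size matrix and fill row `i` with `.at[i].set(...)`. `overlap_matrix_fori` is a hypothetical name. The contact pairs can then be extracted outside `jit`, for example with `np.nonzero(np.triu(np.asarray(m) < 0, 1))`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def overlap_matrix_fori(poss, radii):  # hypothetical helper
    n = poss.shape[0]

    def body(i, overlaps):
        # Signed gap between sphere i and every sphere (negative = overlap),
        # same sign convention as `contact_check` in the question
        gap = jnp.linalg.norm(poss - poss[i], axis=1) - (radii + radii[i])
        # Keep real contacts only; self-pairs and non-contacts are stored as 0
        keep = (gap <= 0) & (jnp.arange(n) != i)
        # .at[i].set returns a new array with the carry's shape and dtype
        return overlaps.at[i].set(jnp.where(keep, gap, 0.0))

    init = jnp.zeros((n, n), dtype=poss.dtype)
    return jax.lax.fori_loop(0, n, body, init)

poss = jnp.array([[0.0, 0.0, 0.0], [0.15, 0.0, 0.0], [1.0, 1.0, 1.0]])
radii = jnp.array([0.1, 0.1, 0.1])
m = overlap_matrix_fori(poss, radii)
```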
Ali_Sh
  • In general when using JAX, if you are looping over large arrays using Python `for` loops or `fori_loop`, your implementation will not be particularly performant. I'd suggest working towards implementing your function using vectorized operations. It would be easier to give suggestions about this if you included a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) of what you're doing: the code you provided has syntax errors when defining the inputs. – jakevdp Feb 14 '22 at 20:27
  • @jakevdp Did you use `np.load` to load the data (I added it to the code)? Is there a syntax error when using it? Which one? – Ali_Sh Feb 15 '22 at 07:44
  • Yes, for example `np.load(radii data)` results in a syntax error when you attempt to execute it. Perhaps for the sake of the reproducible example, you could change it to something like `np.random.rand(shape)` where `shape` is replaced with the shape of the actual dataset? – jakevdp Feb 15 '22 at 13:20
  • @jakevdp Just to be sure: the file names are `a.npy` and `b.npy`, which can be loaded with `np.load('a.npy')` and `np.load('b.npy')`. ***Radii*** have shape `(n_spheres, )` and ***poss*** has shape `(n_spheres, 3)`. I will prepare data as you suggest, but I must make sure that the created data (which are spheres) overlap from the beginning. – Ali_Sh Feb 15 '22 at 13:54
  • That's helpful information – up until now, there has been no way to know the shape or number of dimensions of your arrays, or to know that there are requirements on their content. Do you see how your question was nearly impossible to answer before you provided this information? That would be the benefit of using a minimal reproducible example from the start. – jakevdp Feb 15 '22 at 16:08
  • @jakevdp, I edited the question and added an input section to the code, too. I'm really curious to see how it could be handled efficiently by JAX. Which would JAX (or jitting) handle better in terms of performance: the SciPy method or the NumPy method? – Ali_Sh Mar 07 '22 at 23:27
  • What do you mean by "the numpy method"? I only see one function in your question. – jakevdp Mar 08 '22 at 18:26
  • @jakevdp, I used both SciPy and NumPy to determine `nears_i_ind` and `dist_i` in the code (they are marked by comments at the end of the related lines; please scroll right). In the body of the same function, I commented out the NumPy line for `dist_i`, and the multi-line NumPy section for determining `nears_i_ind` with triple single quotes. – Ali_Sh Mar 08 '22 at 18:48
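Following the suggestion to use vectorized operations, here is a minimal brute-force sketch. Broadcasting computes the signed gap for every pair at once inside `jit`, with a fixed `(N, N)` output shape, and the data-dependent selection of contacts happens afterwards on concrete NumPy arrays. `pairwise_gaps` and `contact_pairs` are hypothetical names. Memory grows as N², so this is only meant for moderate N such as the 1000 spheres in the example. Keeping only `i < j` plays the role of the `sort`/`unique` step in the question.

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def pairwise_gaps(poss, radii):  # hypothetical helper
    # (N, 1, 3) - (1, N, 3) -> (N, N, 3): every pair at once, no Python loop
    diff = poss[:, None, :] - poss[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))
    # Signed gap: negative entries are overlaps (same convention as contact_check)
    return dist - (radii[:, None] + radii[None, :])

def contact_pairs(poss, radii):  # hypothetical helper, runs outside jit
    gaps = np.asarray(pairwise_gaps(poss, radii))
    # k=1 keeps each unordered pair once and drops the diagonal (self-pairs)
    i, j = np.nonzero(np.triu(gaps <= 0, k=1))
    return np.stack([i, j], axis=1), gaps[i, j]

poss = jnp.array([[0.0, 0.0, 0.0], [0.15, 0.0, 0.0], [1.0, 1.0, 1.0]])
radii = jnp.array([0.1, 0.1, 0.1])
pairs, overlaps = contact_pairs(poss, radii)
print(pairs.tolist())  # [[0, 1]]
```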

2 Answers


JAX cannot be used to optimize general numpy/scipy code; however, it can be used to optimize/compile code written in JAX.

Your example revolves around the use of scipy's cKDTree. This is not implemented in JAX, and so it cannot be optimized or compiled in JAX, and using it within a jitted function will lead to the error you're seeing. If you want to use a KD tree with JAX, you'll have to find one implemented in JAX. I don't know of any such code.

As for why the code becomes slower when you replace np with jnp here, it's because you're really only using JAX as an alternate array container. Every time you pass a JAX array to a cKDTree call, it has to be converted to a numpy array, and then the result has to be converted back to a JAX array. This extra movement of data adds overhead to each call, making the result slower. This is not because JAX itself is slow, it's because you're not really using JAX as anything but a way of temporarily storing your data before converting it back to numpy.

Generally this kind of overhead can be reduced by wrapping the function in jax.jit, but as mentioned before, this is not compatible with non-jax code like scipy.spatial.cKDTree.
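To illustrate the "code written in JAX" case: once the loop body is expressed only with fixed-shape `jax.numpy` operations (no SciPy, no boolean-mask indexing), it can be vectorized over the particle index with `jax.vmap` and compiled with `jax.jit`. This is a minimal sketch. `contacts_of_one` is a hypothetical name, and it only counts contacts per particle so that the output shape stays fixed.

```python
import jax
import jax.numpy as jnp

def contacts_of_one(i, poss, radii):  # hypothetical per-particle loop body
    # Signed gap between particle i and all particles, fixed shape (N,)
    gap = jnp.linalg.norm(poss - poss[i], axis=1) - (radii + radii[i])
    in_contact = (gap <= 0) & (jnp.arange(poss.shape[0]) != i)
    return in_contact.sum()

# vmap maps over the particle index; poss and radii are shared (in_axes=None)
contacts_per_particle = jax.jit(jax.vmap(contacts_of_one, in_axes=(0, None, None)))

poss = jnp.array([[0.0, 0.0, 0.0], [0.15, 0.0, 0.0], [1.0, 1.0, 1.0]])
radii = jnp.array([0.1, 0.1, 0.1])
counts = contacts_per_particle(jnp.arange(3), poss, radii)
print(counts.tolist())  # [1, 1, 0]
```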

I suspect your best course of action would be to avoid using JAX and just use numpy/scipy and the cKDTree. I don't know of any JAX-compatible implementation of tree-based neighbor search algorithms, and full brute force approaches would not be competitive with cKDTree for large arrays.
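A rough worked estimate of why a dense brute-force approach scales poorly: storing one float64 value per pair needs 8 × N² bytes, before counting intermediates (a broadcast `(N, N, 3)` difference array needs three times that). `dense_gap_matrix_bytes` is a hypothetical helper used for the arithmetic only.

```python
def dense_gap_matrix_bytes(n_spheres, bytes_per_value=8):  # hypothetical helper
    # One float64 value for every ordered (i, j) pair: n_spheres**2 entries
    return n_spheres ** 2 * bytes_per_value

for n in (1_000, 100_000):
    print(f"N={n}: {dense_gap_matrix_bytes(n) / 1e9:g} GB")
# N=1000: 0.008 GB
# N=100000: 80 GB
```

A tree-based search such as cKDTree avoids materializing all pairs in the first place.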

jakevdp
  • The SciPy-related error was `TracerArrayConversionError`, not `ConcretizationTypeError`, *when we use jax.jit*. What if we remove all SciPy-related lines and use the equivalent prepared (commented) NumPy lines to find nearest neighbors, given that SciPy's cKDTree is incompatible with JAX? The `ConcretizationTypeError` appears when NumPy is used to find the nearest neighbors of each point, as mentioned. Why does jax.jit have a problem with this code now? How can the problem be solved, if possible? Please tell me if I should separate the two methods (SciPy and NumPy), in case having them in one function is confusing. – Ali_Sh Mar 08 '22 at 20:45
  • You could certainly JIT-compile a function that only contains `jax.numpy` functions. But `scipy.cKDTree` appears to be at the core of the algorithm in your question, so you'd have to find some replacement for it. – jakevdp Mar 08 '22 at 21:40
  • Regarding the `ConcretizationTypeError`, search for "Shape depends on Traced Value" in the link given in the error message for more information. – jakevdp Mar 08 '22 at 21:43
  • Since my question could not be answered as requested, based on the answers in the discussion I opened on the JAX repository ([#9813](https://github.com/google/jax/discussions/9813)), the bounty was awarded to make up for your time and specialized answers. For other users, please update/complete your answer based on your replies in the JAX discussion, if you find it useful from your experience. Thanks. – Ali_Sh Mar 12 '22 at 00:11
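On the `static_argnums` error in the question: static arguments must be hashable, and a JAX array is not, which is where `ValueError: Non-hashable static arguments are not supported` comes from. What can be static is a plain Python value that determines a shape. A minimal sketch, where `first_contacts` is a hypothetical function and `max_contacts` is a hypothetical, user-chosen upper bound on contacts per particle:

```python
from functools import partial

import jax
import jax.numpy as jnp

# max_contacts (hypothetical) is a Python int, hence hashable and allowed as static
@partial(jax.jit, static_argnames="max_contacts")
def first_contacts(gap_row, max_contacts):
    # The output shape depends only on the static int, never on array values
    (idx,) = jnp.nonzero(gap_row <= 0, size=max_contacts, fill_value=-1)
    return idx

row = jnp.array([0.3, -0.01, 0.2, -0.02])
print(first_contacts(row, max_contacts=3).tolist())  # [1, 3, -1]
```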

I looked into this earlier this year. I had an existing numba implementation and wanted to port it to jax. I started (repo here) but abandoned the project when I realized that jax's jit performance is currently woeful compared to numba for these types of algorithms with loops and index updates. I believe it may be related to this issue, but I could certainly be wrong.

For the moment, if you want to execute KDTree operations inside a jitted function you can use jax.experimental.host_callback.call to wrap an existing implementation. It won't speed up the external function, but jax's jit may improve other aspects of the jitted code.
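A minimal sketch of that idea. `jax.experimental.host_callback` has since been deprecated, and newer JAX releases provide `jax.pure_callback` for calling host-side Python from jitted code. The sketch below uses that function, assuming it is available in your JAX version. Because the callback's output shape must be known when tracing, the sketch uses `cKDTree.query` with a fixed number of neighbors `k` (passed as static) instead of `query_ball_point`, whose output length varies. Missing neighbors come back with distance `inf` and index `len(poss)`, and each point's first neighbor is itself. `neighbors` and `_kdtree_query` are hypothetical names.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from scipy.spatial import cKDTree

jax.config.update("jax_enable_x64", True)  # as in the question

def _kdtree_query(poss, radius, k):  # hypothetical host-side helper
    # Runs on the host with concrete NumPy arrays, so SciPy is allowed here
    dist, idx = cKDTree(poss).query(poss, k=k, distance_upper_bound=float(radius))
    n = poss.shape[0]
    # reshape keeps the (n, k) layout even when k == 1
    return (dist.reshape(n, k).astype(np.float64),
            idx.reshape(n, k).astype(np.int64))

@partial(jax.jit, static_argnames="k")
def neighbors(poss, radius, k):  # hypothetical jitted wrapper
    n = poss.shape[0]
    # Output shapes/dtypes must be declared up front; k is static, so they are known
    out_types = (
        jax.ShapeDtypeStruct((n, k), jnp.float64),
        jax.ShapeDtypeStruct((n, k), jnp.int64),
    )
    return jax.pure_callback(partial(_kdtree_query, k=k), out_types, poss, radius)

poss = jnp.array([[0.0, 0.0, 0.0], [0.15, 0.0, 0.0], [1.0, 1.0, 1.0]])
dist, idx = neighbors(poss, 0.3, k=2)
print(idx.tolist())  # [[0, 1], [1, 0], [2, 3]]
```

As the answer notes, this does not make the KD-tree itself faster; it only lets the surrounding code stay inside `jit`.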

DomJack