Questions tagged [jax]

JAX allows you to write automatically differentiable functions. It provides an interface compatible with NumPy and native Python, built on composable function transformations. Further optimization comes from automatic vectorization and from running code on GPUs/TPUs.
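Here is a minimal illustration of the composable transformations described in this tag summary. It is a sketch, not taken from the documentation; the function `loss` and the array shapes are made up. It stacks `jax.grad`, `jax.jit`, and `jax.vmap` on one plain `jax.numpy` function:

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical example loss: squared norm of a linear map.
    return jnp.sum((x @ w) ** 2)


# grad differentiates with respect to the first argument (w); jit compiles the result.
grad_loss = jax.jit(jax.grad(loss))

# vmap maps over the leading axis of x while sharing w (in_axes=None for w).
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

w = jnp.array([1.0, 2.0])
xs = jnp.ones((3, 4, 2))  # a batch of 3 inputs, each of shape (4, 2)
g = batched_grad(w, xs)
print(g.shape)  # (3, 2)
```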

Documentation: https://jax.readthedocs.io

Project repo: https://github.com/google/jax

496 questions
13 votes, 5 answers

Not able to install jaxlib

I am trying to install jaxlib on my Windows 10 machine with the following command, which I found in the documentation: pip install jaxlib. It shows the following error: Collecting jaxlib Could not find a version that satisfies the requirement jaxlib (from…
user13782377
12 votes, 1 answer

What is the main difference between flax (google) and dm-haiku (deepmind)?

What are the main differences between flax and dm-haiku? From their descriptions: Flax, a neural network library for JAX; Haiku, a neural network library for JAX inspired by Sonnet. Question: Which JAX-based library should I pick to implement,…
discort
8 votes, 4 answers

Importing jax fails on Mac with M1 chip

With Python 3.8.8 on the new MacBook Air (with the M1 chip), both in Jupyter notebooks and in the Python terminal, import jax raises this error: Python 3.8.8 (default, Apr 13 2021, 12:59:45) [Clang 10.0.0 ] :: Anaconda, Inc. on darwin Type "help",…
dcxst
8 votes, 1 answer

Non-hashable static arguments are not supported in Jax when using vmap

This is related to this question. After some work, I managed to narrow it down to one last error. The code now looks like this: import jax.numpy as jnp from jax import grad, jit, value_and_grad from jax import vmap, pmap from jax import…
RanWang
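The excerpt is truncated, so this is only a sketch of the general rule behind this error, using a hypothetical function `scale`. Arguments marked static via `static_argnums` are hashed to key JAX's compilation cache, so they must be hashable: a tuple works, a list does not. Under `vmap`, one option is to bind the static value in a closure so that `vmap` only sees array arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def scale(x, factors):
    # `factors` is static, so it must be hashable (e.g. a tuple of floats).
    return x * sum(factors)


x = jnp.arange(3.0)
print(scale(x, (1.0, 2.0)))  # tuple: hashable, accepted

try:
    scale(x, [1.0, 2.0])  # list: not hashable, rejected by jit
except (ValueError, TypeError) as err:
    print(type(err).__name__)

# Bind the static tuple in a closure; vmap then maps only over array inputs.
batched = jax.vmap(lambda row: scale(row, (1.0, 2.0)))
print(batched(jnp.ones((2, 3))))  # every entry is 3.0
```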
7 votes, 2 answers

JAX Apply function only on slice of array under jit

I am using JAX, and I want to perform an operation like @jax.jit def fun(x, index): x[:index] = other_fun(x[:index]) return x This cannot be performed under jit. Is there a way of doing this with jax.ops or jax.lax? I thought of using…
Federico Taschin
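This sketch shows two common workarounds. It assumes, hypothetically, that `other_fun` is elementwise, like the stand-in below. JAX arrays are immutable, so in-place assignment becomes `x.at[...].set(...)`, and `x[:index]` only works under `jit` when `index` is static. If `index` must stay a traced value, a mask keeps every shape fixed.

```python
from functools import partial

import jax
import jax.numpy as jnp


def other_fun(v):
    # Hypothetical stand-in for the question's other_fun (elementwise).
    return v * 10.0


@partial(jax.jit, static_argnums=1)
def fun_static(x, index):
    # index is a compile-time constant, so x[:index] has a static shape.
    return x.at[:index].set(other_fun(x[:index]))


@jax.jit
def fun_masked(x, index):
    # index is traced: apply other_fun everywhere, keep it only where pos < index.
    mask = jnp.arange(x.shape[0]) < index
    return jnp.where(mask, other_fun(x), x)


x = jnp.arange(5.0)
print(fun_static(x, 2))  # [ 0. 10.  2.  3.  4.]
print(fun_masked(x, 2))  # [ 0. 10.  2.  3.  4.]
```

The static version recompiles for every distinct `index` value. The masked version compiles once but computes `other_fun` on the whole array.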
6 votes, 1 answer

Test jax.pmap before deploying on multi-device hardware

My question is fairly simple: I am coding on a small single-device laptop, and I am using jax.pmap because my code will run on multiple TPUs. I would like to "fake" having multiple devices to test my code and try different things. Is there any way to…
Valentin Macé
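One widely used approach for testing `pmap` code on a laptop is XLA's `--xla_force_host_platform_device_count` flag, which splits the host CPU into several XLA devices. This sketch assumes a CPU-only machine, since the flag only affects the CPU platform. It also assumes the flag is set before JAX initializes its backends. The device count of 4 is arbitrary.

```python
import os

# Must happen before JAX creates its backends, so set it before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp

n = jax.local_device_count()  # 4 on a CPU-only machine if the flag took effect
print(jax.devices())

# pmap needs a leading axis whose size matches the number of devices used.
x = jnp.arange(n * 3.0).reshape(n, 3)
doubled = jax.pmap(lambda row: row * 2)(x)

# Collectives work too: psum adds one value from each device.
totals = jax.pmap(
    lambda row: jax.lax.psum(row.sum(), axis_name="i"), axis_name="i"
)(x)
print(doubled.shape, totals)
```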
6 votes, 1 answer

AttributeError: module 'flax' has no attribute 'nn'

I'm trying to run RegNeRF, which requires flax. After installing the latest version, flax==0.6.0, I got an error stating that flax has no attribute optim. This answer suggested downgrading flax to 0.5.1. After doing that, I'm now getting the error…
Nagabhushan S N
6 votes, 1 answer

JAX pmap with multi-core CPU

What is the correct method for using multiple CPU cores with jax.pmap? The following example creates an environment variable for SPMD on CPU core backends, tests that JAX recognises the devices, and attempts a device lock. import…
DavidJ
6 votes, 2 answers

How to install trax, jax, jaxlib on M1 Mac on macOS 12?

New to trax, I'm trying to run it locally (macOS 12.1, Apple Silicon ARM M1 processor, 8GB RAM, Anaconda), but I'm running into some issues. In an environment with Python 3.8.5, I installed trax by running pip3 install trax==1.3.9 inside an…
arturomp
6 votes, 1 answer

JAX: Converting Concrete Tracer values to regular float values doesn't work

I am using JAX for automatic differentiation. In doing so, I am trying to convert concrete tracer values to regular float values using astype(float), but it still seems to return a concrete tracer value. However, when I use astype(int), it seems to correctly…
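Here is a hedged sketch of the underlying behavior, using a made-up function `f`. `astype` changes an array's dtype but not whether it is traced, so inside a transformation it still returns a tracer. To inspect values from inside, `jax.debug.print` can be used. `float(...)` is safe on the concrete result that the transformation returns.

```python
import jax
import jax.numpy as jnp


def f(x):
    y = x ** 2
    # Inside grad/jit, y is a tracer; y.astype(float) would still be a tracer.
    jax.debug.print("y inside f: {}", y)
    return y


@jax.jit
def g(x):
    # float(x ** 2) here would raise a ConcretizationTypeError under jit.
    return jax.grad(f)(x)


result = g(3.0)       # a concrete jax.Array once returned
print(float(result))  # 6.0
```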
6 votes, 2 answers

Jax, jit and dynamic shapes: a regression from Tensorflow?

The documentation for JAX says: "Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time." I am somewhat surprised, because TensorFlow has operations like tf.boolean_mask that do what JAX seems…
user209974
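The usual jit-compatible substitutes for boolean masking keep output shapes static. `jnp.where` replaces masked-out entries instead of dropping them, and `jnp.nonzero` accepts a static `size` with padding. The sketch below shows both; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_positive(x):
    # x[x > 0] would have a data-dependent shape and fails under jit.
    # jnp.where keeps the shape static by zeroing out masked entries.
    return jnp.sum(jnp.where(x > 0, x, 0.0))


@jax.jit
def positive_indices(x):
    # jnp.nonzero with a static size pads missing slots with fill_value.
    (idx,) = jnp.nonzero(x > 0, size=x.shape[0], fill_value=-1)
    return idx


x = jnp.array([1.0, -2.0, 3.0, -4.0])
print(sum_positive(x))      # 4.0
print(positive_indices(x))  # [ 0  2 -1 -1]
```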
5 votes, 1 answer

Modify an array from indexes contained in another array

I have an array of shape (2,10), such as arr = jnp.ones(shape=(2,10)) * 2, i.e. [[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.] [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]], and another array, for example [2,4]. I want the second array to specify from which index the elements…
Valentin Macé
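The question is truncated, so this sketch assumes a hypothetical goal: in row `i`, replace every element at or after column `starts[i]` with a new value (0 here). Broadcasting a column-index vector against the per-row start indices builds the mask without any Python loop.

```python
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
starts = jnp.array([2, 4])  # per-row start index (assumed meaning)

cols = jnp.arange(arr.shape[1])          # shape (10,)
mask = cols[None, :] >= starts[:, None]  # shape (2, 10) via broadcasting
result = jnp.where(mask, 0.0, arr)       # assumed replacement value: 0.0
print(result)
# [[2. 2. 0. 0. 0. 0. 0. 0. 0. 0.]
#  [2. 2. 2. 2. 0. 0. 0. 0. 0. 0.]]
```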
5 votes, 0 answers

Cannot install Jaxlib on Apple M1, on Docker

I am trying to install Jax, Jaxlib and Chex on Docker, on an Apple M1 Pro machine. The Docker image's base OS is Debian. Running uname -m gives aarch64 inside the container and arm64 in my local terminal. I am able to install Jax with no issues via…
Asier R.
5 votes, 1 answer

in_axes keyword in JAX's vmap

I'm trying to understand JAX's auto-vectorization capabilities using vmap and have implemented a minimal working example based on JAX's documentation. I don't understand how in_axes is used correctly. In the example below, I can set in_axes=(None, 0) or…
Gilfoyle
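This sketch, with a made-up function `affine`, shows what the two `in_axes` settings mean. Each entry of `in_axes` lines up with one positional argument. `None` means that argument is shared across the batch. `0` means the batch runs along that argument's leading axis.

```python
import jax
import jax.numpy as jnp


def affine(w, x):
    return jnp.dot(w, x)


W = jnp.eye(2) * 3.0   # one (2, 2) matrix
xs = jnp.ones((5, 2))  # a batch of 5 vectors

# (None, 0): W is shared, xs is mapped over its leading axis.
out1 = jax.vmap(affine, in_axes=(None, 0))(W, xs)

# (0, 0): both arguments carry a batch axis of the same size.
Ws = jnp.stack([W] * 5)  # shape (5, 2, 2)
out2 = jax.vmap(affine, in_axes=(0, 0))(Ws, xs)

print(out1.shape, out2.shape)  # (5, 2) (5, 2)
```

Both calls give the same result here, because `Ws` is just `W` repeated. `(None, 0)` avoids materializing the copies.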
5 votes, 1 answer

Jax - Debugging NaN-values

Good evening everyone, I spent the last 6 hours trying to debug seemingly randomly occurring NaN values in JAX. I have narrowed it down: the NaNs initially stem from either the loss function or its gradient. A minimal notebook that reproduces the…
Simon B
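Here is a minimal sketch of JAX's built-in NaN checker, the `jax_debug_nans` config option, using a made-up `loss`. When enabled, JAX raises a `FloatingPointError` at the operation that first produces a NaN, rather than letting it propagate silently. The option adds overhead, so it is meant for debugging runs only.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Hypothetical loss: log of a negative input produces NaN.
    return jnp.log(x)


# Raise at the first operation whose output contains a NaN.
jax.config.update("jax_debug_nans", True)
caught = None
try:
    loss(jnp.array(-1.0))
except FloatingPointError as err:
    caught = err
finally:
    jax.config.update("jax_debug_nans", False)  # restore the default

print(caught)
```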