Questions tagged [google-jax]

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

12 questions
4
votes
0 answers

Saving Gradient in Backward Pass Google-JAX

I am using JAX to implement a simple neural network (NN) and I want to access and save the gradients from the backward pass for further analysis after the NN has run. I can access and inspect the gradients temporarily with the Python debugger (as long…
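A minimal sketch of the usual starting point, assuming a toy loss and parameter pytree (the names `loss`, `params`, and `saved` are illustrative, not from the question): `jax.grad` returns the full gradient pytree, which can simply be copied off-device and stored.

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # Toy "network": a single linear layer followed by a squared-error loss.
    return jnp.sum((params["w"] @ x + params["b"]) ** 2)

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(2)}
x = jnp.arange(3.0)

# jax.grad returns the gradient as a pytree matching `params`,
# so it can be stored for later analysis without a debugger.
grads = jax.grad(loss)(params, x)
saved = jax.device_get(grads)  # copy to host as plain NumPy arrays
print(saved["w"].shape)  # (2, 3)
```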
2
votes
0 answers

Is it possible to use objects with Google's Jax machine learning library

I am trying to write a DCGAN network using Google's JAX machine learning library. To do this, I created objects to serve as the discriminator and generator; however, while testing the discriminator, I got the error: TypeError: Argument…
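That TypeError typically appears when a plain Python object is passed through a JAX transformation. One common remedy, sketched here with a hypothetical minimal `Discriminator` class, is to register the class as a pytree so JAX knows how to flatten it into arrays:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class Discriminator:
    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return jnp.tanh(self.w @ x)

# Registering the class as a pytree tells JAX how to flatten it into
# array leaves, so instances can be passed through jit/grad/vmap.
tree_util.register_pytree_node(
    Discriminator,
    lambda d: ((d.w,), None),                    # flatten: leaves, static aux data
    lambda aux, leaves: Discriminator(*leaves),  # unflatten
)

@jax.jit
def apply(disc, x):
    return disc(x)

disc = Discriminator(jnp.ones((2, 3)))
out = apply(disc, jnp.arange(3.0))
print(out.shape)  # (2,)
```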
1
vote
1 answer

Jax vmap, in_axes doesn't work if keyword argument is passed

The in_axes parameter of vmap seems to work only for positional arguments, but throws an AssertionError (with no message) when called with a keyword argument. from jax import vmap import numpy as np def foo(a, b, c): return a * b + c foo = vmap(foo,…
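This matches documented behavior: `in_axes` pairs up with positional arguments only, so mapped arguments must be passed positionally. A minimal sketch of the working call pattern, using the `foo` from the excerpt:

```python
import jax
import jax.numpy as jnp

def foo(a, b, c):
    return a * b + c

# in_axes entries correspond to positional arguments only.
vfoo = jax.vmap(foo, in_axes=(0, 0, None))

a = jnp.arange(3.0)
b = jnp.arange(3.0)

# Works: every mapped argument is passed positionally.
# Calling vfoo(a, b, c=1.0) instead would fail, because keyword
# arguments bypass the in_axes specification.
out = vfoo(a, b, 1.0)
print(out)  # [1. 2. 5.]
```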
Amith M
  • 63
  • 5
1
vote
1 answer

Jax.lax.scan with arguments?

I'm trying to speed up the execution of my code by rewriting for loops as jax.lax.scan, but I ran into the issue that I need the scanFunction to handle parameters passed to the main function. How can I do that? Here I get NameError: name 'coefs' is…
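One standard workaround, sketched below under assumed shapes (the `step`/`run` names and the dot-product body are illustrative): bind the extra parameter with `functools.partial` (or a closure) so that scan sees the expected `(carry, x)` signature.

```python
import functools
import jax
import jax.numpy as jnp

def step(coefs, carry, x):
    # `coefs` is a fixed parameter, threaded in via functools.partial.
    return carry + coefs @ x, None

def run(xs, coefs):
    # Bind the parameter so scan sees the expected (carry, x) signature.
    f = functools.partial(step, coefs)
    total, _ = jax.lax.scan(f, jnp.zeros(()), xs)
    return total

xs = jnp.ones((4, 2))
coefs = jnp.array([2.0, 3.0])
print(run(xs, coefs))  # 20.0
```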
pepazdepa
  • 117
  • 8
1
vote
1 answer

Rewriting for loop with jax.lax.scan

I'm having trouble understanding the JAX documentation. Can somebody give me a hint on how to rewrite simple code like this with jax.lax.scan? numbers = numpy.array( [ [3.0, 14.0], [15.0, -7.0], [16.0, -11.0] ]) evenNumbers = 0 for row in numbers: …
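The loop body is truncated, but assuming it accumulates the even entries, a scan version could look like this sketch: each scan step consumes one row and updates the running total as the carry.

```python
import jax
import jax.numpy as jnp

numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])

# Assumption: the original loop sums the even entries. Each scan step
# processes one row and adds its even elements to the carried total.
def step(total, row):
    is_even = row % 2 == 0
    return total + jnp.sum(jnp.where(is_even, row, 0.0)), None

even_sum, _ = jax.lax.scan(step, jnp.zeros(()), numbers)
print(even_sum)  # 30.0  (14.0 + 16.0)
```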
pepazdepa
  • 117
  • 8
1
vote
1 answer

JAX: Getting rid of zero-gradient

Is there a way to modify this function (MyFunc) so that it gives the same result but its derivative is not identically zero? from jax import grad import jax.nn as nn import numpy as np def MyFunc(coefs): a = coefs[0] b = coefs[1] c =…
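The body of MyFunc is truncated, but a zero gradient usually comes from a piecewise-constant operation (a hard threshold, rounding, comparison). A hypothetical illustration of the general fix, replacing a hard step with a smooth surrogate such as a sigmoid:

```python
import jax
import jax.numpy as jnp
import jax.nn as nn

# Hypothetical example of the general pattern: a hard step has zero
# gradient almost everywhere, while a sigmoid surrogate does not.
def hard(x):
    return jnp.where(x > 0, 1.0, 0.0)

def smooth(x, temperature=10.0):
    # Large temperature keeps the forward value close to the step,
    # while the gradient stays nonzero.
    return nn.sigmoid(temperature * x)

print(jax.grad(hard)(0.5))    # 0.0 -- the gradient signal is lost
print(jax.grad(smooth)(0.5))  # positive
```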
pepazdepa
  • 117
  • 8
1
vote
1 answer

JAX grad function: why am I getting a list of zeros instead of the gradients?

I'm trying to find the global maximum of a Python function with many variables (500+). For this purpose I'm trying to use JAX's grad() to compute the gradient of MyFunction, but I'm obviously doing something wrong, because each time I try…
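Without the full code one can only guess, but a frequent cause of this symptom is computing intermediate values with plain NumPy (or non-float inputs), which severs the traced dependence on the input. A sketch of the working pattern, with an illustrative stand-in for MyFunction:

```python
import jax
import jax.numpy as jnp

# Common pitfall: building values with plain NumPy inside the function
# makes them constants under tracing, so grad returns zeros. Using
# jax.numpy throughout keeps the computation differentiable.
def my_function(coefs):
    return jnp.sum(jnp.sin(coefs) ** 2)

coefs = jnp.linspace(0.0, 1.0, 500)  # float inputs, as grad requires
g = jax.grad(my_function)(coefs)
print(g.shape)  # (500,)
```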
1
vote
0 answers

Optimizing sampling with varying sample sizes in jax

I'm looking for ideas on how to optimize the sampling of a varying number of guests for a varying number of hosts. Let me clarify what I'm trying to do. Given a number of hosts "n_hosts", each one with a different number of possible guests,…
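Since JAX transformations want static shapes, one common strategy for this kind of ragged problem (sketched below under an assumed setup; `max_guests`, `n_guests`, and `sample_one` are illustrative names) is to pad every host's guest pool to a common maximum and mask out the padding:

```python
import jax
import jax.numpy as jnp

# Sketch: pad every host's guest pool to the same maximum size, give
# padded slots -inf logits so they can never be drawn, and vmap the
# per-host sampler over all hosts at once.
max_guests = 5
n_guests = jnp.array([2, 5, 3])  # per-host pool sizes (assumed data)
keys = jax.random.split(jax.random.PRNGKey(0), n_guests.shape[0])

def sample_one(key, n):
    logits = jnp.where(jnp.arange(max_guests) < n, 0.0, -jnp.inf)
    return jax.random.categorical(key, logits)

picks = jax.vmap(sample_one)(keys, n_guests)
print(picks)  # one valid guest index per host
```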
1
vote
0 answers

JAX/XLA slow compilation using conda

I'm getting into using Google JAX and the built-in jit and grad functionality. These aspects are working nicely on my machine, but when I increase the number of arguments I get the following notification: ******************************** Slow…
Drphoton
  • 164
  • 9
1
vote
1 answer

How to use grad convolution in google-jax?

Thanks for reading my question! I was just learning about custom grad functions in JAX, and I found the approach JAX takes to defining custom functions quite elegant. One thing troubles me, though. I created a wrapper to make lax convolution look…
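For reference, the general custom-gradient pattern the question builds on is `jax.custom_vjp`; here is a minimal sketch of that pattern with a simple elementwise function standing in for the lax convolution wrapper:

```python
import jax
import jax.numpy as jnp

# The custom_vjp pattern: define the primal, a forward pass that saves
# residuals, and a backward pass that returns one cotangent per input.
@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Save residuals needed by the backward pass.
    return f(x), x

def f_bwd(residual, g):
    x = residual
    return (g * jnp.cos(x),)  # chain rule

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(0.0))  # 1.0
```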
0
votes
1 answer

How to vectorize JAX functions using jit compilation and vmap auto-vectorization

How can I use jit and vmap in JAX to vectorize and speed up the following computation: @jit def distance(X, Y): """Compute distance between two matrices X and Y. Args: X (jax.numpy.ndarray): matrix of shape (n, m) Y…
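The excerpt's docstring suggests a pairwise distance between rows of two matrices; a sketch of the usual jit-plus-nested-vmap formulation (the Euclidean metric is assumed, since the excerpt is truncated):

```python
import jax
import jax.numpy as jnp

@jax.jit
def distance(X, Y):
    # Pairwise Euclidean distances between rows of X (n, m) and rows
    # of Y (k, m), vectorized with nested vmaps instead of loops.
    row_dist = lambda x, y: jnp.sqrt(jnp.sum((x - y) ** 2))
    return jax.vmap(lambda x: jax.vmap(lambda y: row_dist(x, y))(Y))(X)

X = jnp.zeros((4, 3))
Y = jnp.ones((5, 3))
D = distance(X, Y)
print(D.shape)  # (4, 5)
```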
0
votes
1 answer

Profiling JAX code: What is redzone_checker and why does it take so much time?

I have found this post but am still unclear on what the redzone_checker kernel is doing and why. Specifically, should it be taking > 90% of my application's runtime? TensorBoard reports that it is taking the vast majority of the runtime of my JAX…
emprice
  • 912
  • 11
  • 21