Questions tagged [flax]
40 questions
12 votes, 1 answer
What is the main difference between Flax (Google) and dm-haiku (DeepMind)?
What are the main differences between Flax and dm-haiku?
From their descriptions:
Flax, a neural network library for JAX
Haiku, a neural network library for JAX inspired by Sonnet
Question:
Which JAX-based library should I pick to implement,…

discort
6 votes, 1 answer
AttributeError: module 'flax' has no attribute 'nn'
I'm trying to run RegNeRF, which requires Flax. After installing the latest version, flax==0.6.0, I got an error stating that flax has no attribute optim. This answer suggested downgrading flax to 0.5.1. After doing that, I'm now getting the error…

Nagabhushan S N
4 votes, 1 answer
Flax much slower than pure JAX for neural networks?
For a project, I am trying to code up a very simple MLP example, but I noticed that the Flax implementation is about 20 times slower than the pure JAX implementation. What am I doing wrong here?
import time
import jax.numpy as np
from jax import…

Luca Thiede
2 votes, 2 answers
JAX: vmap over a batch of dataclasses
In JAX, I am looking to vmap a function over a fixed length list of dataclasses, for example:
import jax, chex
from flax import struct
@struct.dataclass
class EnvParams:
    max_steps: int = 500
    random_respawn: bool = False
def foo(params:…

EmptyJackson
2 votes, 1 answer
AttributeError: module 'flax' has no attribute 'optim'
My code is as follows:
!pip install flax
init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params']
print(f'Model parameters: {n_params(init_params):,}')
optim = flax.optim.Adam(lr=1e-4).create(init_params)
However, it shows…
2 votes, 0 answers
PyTorch equivalent of `register_buffer` in Flax/JAX
I'm looking for a way to write the equivalent of the following PyTorch module in Flax, but I haven't found one. The important thing is that the constant should be saved to and loaded from checkpoints.
class SillyModule(nn.Module):
    def…

ysig
2 votes, 1 answer
Calculating the Hessian-vector product of a Flax NN output with respect to the inputs
I am trying to get the second derivative of the output with respect to the input of a neural network built using Flax. The network is structured as follows:
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax…

Vignesh Gopakumar
1 vote, 1 answer
Computing the gradient of a batched function using JAX
I need to compute the gradient of a batched function using JAX. The following is a minimal example of what I would like to do:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
x = jnp.expand_dims(jnp.linspace(-1, 1, 20),…

al_cc
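If the function maps a single example to a scalar, composing `jax.grad` with `jax.vmap` gives per-example gradients for the whole batch. A sketch, assuming the inputs are arranged as a (20, 1) column and using a hypothetical scalar function `f`, since the question's own function is truncated:

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-1, 1, 20).reshape(20, 1)   # assumed layout: one input per row


def f(xi):
    # hypothetical per-example scalar function of a length-1 input
    return jnp.sin(xi[0]) * xi[0]


# grad differentiates one example; vmap maps that over the batch axis
dfdx = jax.vmap(jax.grad(f))(x)              # shape (20, 1)
```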
1 vote, 0 answers
How to convert .safetensors or .ckpt files and use them in FlaxStableDiffusionImg2ImgPipeline?
I am trying to convert a .safetensors model to a diffusers model using the Python script found at https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py. The command I tried is python3…

Aero Wang
1 vote, 1 answer
data_format in JAX/Flax
I did not find any setting for data_format=channels_first or data_format=channels_last in Flax modules (which are based on JAX).
By contrast, TensorFlow does have that option. Is the choice of data_format irrelevant to JAX…

ujjwalnur
1 vote, 1 answer
Passing JAX tracers to Huggingface CLIP transformer for calculating loss
I'm working on a vision task using JAX, and I'm facing an issue with passing intermediate JAX tracer objects as images to the CLIP model for calculating the loss. The CLIP model expects NumPy arrays as inputs, so the JAX tracer objects are not…

Kian
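The failure can be reproduced without CLIP: inside a `jax.jit`-traced function, values are abstract tracers, and converting one to a NumPy array raises `jax.errors.TracerArrayConversionError`. A minimal sketch; it suggests the loss has to be computed with JAX operations end to end (for example, with a JAX/Flax implementation of the model) rather than through a NumPy conversion. The functions below are hypothetical stand-ins:

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def numpy_loss(img):
    arr = np.asarray(img)       # fails under jit: img is a tracer, not concrete data
    return arr.mean()


@jax.jit
def jax_loss(img):
    return jnp.mean(img)        # stays inside JAX, so it traces and differentiates


img = jnp.ones((2, 2))
try:
    numpy_loss(img)
    failed = False
except jax.errors.TracerArrayConversionError:
    failed = True

g = jax.grad(jax_loss)(img)     # every entry is 1 / img.size
```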
1 vote, 1 answer
The exact meaning of n_jitted_steps=5
I have tried to run the code. It has a setting called n_jitted_steps=5, which, according to the authors, accumulates several steps. Since the code is rather complicated, it is difficult to understand. However, I have tried the…

RanWang
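The repository is not shown here, so this is only a guess at the mechanism: a common way to run several update steps per compiled call is to wrap the single-step function in `jax.lax.scan` inside `jax.jit`, so one dispatch performs `n_jitted_steps` updates. A sketch with a hypothetical toy `step` function:

```python
import jax
import jax.numpy as jnp

n_jitted_steps = 5


def step(state, batch):
    # hypothetical single update: move the state halfway toward the batch mean
    new_state = state + 0.5 * (batch.mean() - state)
    return new_state, new_state          # (carry, per-step output)


@jax.jit
def multi_step(state, batches):
    # batches has a leading axis of length n_jitted_steps; scan applies step once per slice
    return jax.lax.scan(step, state, batches)


batches = jnp.ones((n_jitted_steps, 4))
final_state, history = multi_step(jnp.zeros(()), batches)
```

Under this reading, each Python-level call consumes `n_jitted_steps` batches, which reduces per-step dispatch overhead.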
1 vote, 1 answer
No module named 'jax.experimental.global_device_array' when running the official Flax Example on Colab with V100
I have been trying to understand this official Flax example on a Colab Pro+ account with a V100. When I execute the command python main.py --workdir=./imagenet --config=configs/v100_x8.py, the error returned is
File…

RanWang
1 vote, 1 answer
Failing to understand the use of partial in the official Flax ResNet example
I have been trying to understand this official example. However, I am very confused about the use of partial in two places.
For example, in line 94, we have the following:
conv = partial(self.conv, use_bias=False, dtype=self.dtype)
I am not sure…

RanWang
1 vote, 0 answers
Why does run_t5_mlm_flax.py not produce a model weight file, etc.?
I was trying to reproduce this Hugging Face tutorial on T5-like span masked-language-modeling.
I have the following code tokenizing_and_configing.py:
import datasets
from t5_tokenizer_model import SentencePieceUnigramTokenizer
from transformers…

littleworth