Questions tagged [dm-haiku]

4 questions
12 votes, 1 answer

What is the main difference between Flax (Google) and dm-haiku (DeepMind)?

What are the main differences between Flax and dm-haiku? From their descriptions: "Flax, a neural network library for JAX" and "Haiku, a neural network library for JAX inspired by Sonnet". Question: which JAX-based library should I pick to implement,…
discort
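Whatever the surface differences, both libraries turn a model into a pair of pure functions over an explicit parameter pytree: Haiku's `hk.transform` returns `init`/`apply` functions, and a Flax Linen module exposes `init`/`apply` methods. The sketch below shows that shared contract in plain JAX, without either library, as a baseline for comparing them. The `init` and `apply` functions and the `linear_0`/`linear_1` layer names are hypothetical and do not match either library's exact signatures.

```python
import jax
import jax.numpy as jnp


def init(rng, x, hidden=8, out=1):
    """Hypothetical init: build a parameter pytree from an RNG key and an example input."""
    k0, k1 = jax.random.split(rng)
    d = x.shape[-1]
    return {
        "linear_0": {"w": jax.random.normal(k0, (d, hidden)) / jnp.sqrt(d),
                     "b": jnp.zeros(hidden)},
        "linear_1": {"w": jax.random.normal(k1, (hidden, out)) / jnp.sqrt(hidden),
                     "b": jnp.zeros(out)},
    }


def apply(params, x):
    """Hypothetical apply: a pure function of (params, inputs), so it jits cleanly."""
    h = jax.nn.relu(x @ params["linear_0"]["w"] + params["linear_0"]["b"])
    return h @ params["linear_1"]["w"] + params["linear_1"]["b"]


x = jnp.ones((4, 3))
params = init(jax.random.PRNGKey(0), x)
y = jax.jit(apply)(params, x)
print(jax.tree_util.tree_map(lambda a: a.shape, params))
```

The two libraries differ mainly in how they produce these functions (Haiku wraps ordinary functions that instantiate Sonnet-style modules; Flax Linen defines modules as classes with `init`/`apply` methods). The parameters end up as a nested dict in both cases.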
1 vote, 1 answer

Migrating from Haiku: alternative to Haiku's PRNGSequence?

I am writing a Markov chain Monte Carlo simulation in JAX that involves a long series of sampling steps. I currently rely on Haiku's PRNGSequence for the pseudorandom number generator (PRNG) key bookkeeping: import haiku as hk def step(key,…
Hylke
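To drop the Haiku dependency, you can reproduce the key bookkeeping that `hk.PRNGSequence` does with `jax.random.split` alone. Here is a minimal sketch. The question's own `step` function is cut off, so this assumes a hypothetical standard-normal target and a random-walk Metropolis sampler. `KeySequence` mimics the `next(seq)` style in plain Python. `run_chain` threads the key through the `jax.lax.scan` carry so the whole chain can be jitted. Both helpers are hypothetical, not library APIs.

```python
import jax
import jax.numpy as jnp


class KeySequence:
    """Hypothetical stand-in for hk.PRNGSequence: yields a fresh subkey per next()."""

    def __init__(self, key_or_seed):
        if isinstance(key_or_seed, int):
            key_or_seed = jax.random.PRNGKey(key_or_seed)
        self._key = key_or_seed

    def __iter__(self):
        return self

    def __next__(self):
        # Split the carried key: keep one half, hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey


def log_prob(x):
    # Hypothetical target density: standard normal (up to a constant).
    return -0.5 * jnp.sum(x ** 2)


def run_chain(key, x0, n_steps, step_size=0.5):
    """Random-walk Metropolis; the key is threaded through lax.scan's carry."""

    def step(carry, _):
        key, x = carry
        key, k_prop, k_accept = jax.random.split(key, 3)
        proposal = x + step_size * jax.random.normal(k_prop, x.shape)
        log_alpha = log_prob(proposal) - log_prob(x)
        accept = jnp.log(jax.random.uniform(k_accept)) < log_alpha
        x = jnp.where(accept, proposal, x)
        return (key, x), x

    _, samples = jax.lax.scan(step, (key, x0), None, length=n_steps)
    return samples


seq = KeySequence(0)
k1, k2 = next(seq), next(seq)
chain = jax.jit(run_chain, static_argnums=2)
samples = chain(next(seq), jnp.zeros(2), 1000)
```

`KeySequence` mutates Python-side state, so it is only suitable outside `jit`. Inside a jitted loop, the usual functional pattern is to split the key in the loop carry, as `run_chain` does.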
1 vote, 1 answer

AttributeError: module 'jax.tree_util' has no attribute 'PyTreeDef' when importing haiku

I'm trying to use JAX with Haiku on Python 3.6 in a conda environment, and I'm stuck on the error below. I have tried updating my JAX version, but nothing changed. How can I fix it? Traceback (most recent call last): File "train.py", line 14, in…
Manyerror
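One likely cause, though the truncated traceback does not confirm it, is a version mismatch: the installed Haiku references `jax.tree_util.PyTreeDef`, which the installed JAX does not provide. Python 3.6 also limits how new a JAX release pip can install. The hypothetical `environment_report` helper below is a small diagnostic sketch. It records the interpreter and package versions and whether `PyTreeDef` exists, catching import errors instead of crashing.

```python
import importlib
import sys


def environment_report(packages=("jax", "jaxlib", "haiku")):
    """Collect Python and package versions; import errors are recorded, not raised."""
    report = {"python": sys.version.split()[0]}
    for name in packages:
        try:
            module = importlib.import_module(name)
            report[name] = getattr(module, "__version__", "unknown")
        except Exception as exc:  # e.g. ModuleNotFoundError, or the AttributeError above
            report[name] = f"import failed: {type(exc).__name__}: {exc}"
    try:
        import jax.tree_util
        report["has_PyTreeDef"] = hasattr(jax.tree_util, "PyTreeDef")
    except Exception:
        report["has_PyTreeDef"] = None  # JAX itself could not be imported
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```

If `has_PyTreeDef` is `False` while Haiku fails to import, the installed JAX is likely older than the Haiku release expects. Either pin a Haiku version contemporary with that JAX, or move to a newer Python so a newer JAX can be installed.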
1 vote, 0 answers

Mixed-precision training using JAX

I'm trying to understand how Haiku achieved a 2x speedup when training ResNet50 on ImageNet (https://github.com/deepmind/dm-haiku/tree/main/examples/imagenet) using DeepMind's JMP library (https://github.com/deepmind/jmp), and how to replicate this with…
Jankins21
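JMP's approach is a precision policy. Parameters are kept in float32 and cast to half precision for computation, and the loss is scaled so that small float16 gradients do not underflow. The sketch below reproduces that idea in plain JAX without JMP. The linear model, `cast_floating`, `train_step`, and the halve-on-overflow rule are hypothetical simplifications; JMP's `DynamicLossScale` also grows the scale again after a run of finite steps. Any speedup depends on the hardware, so this only shows the mechanics.

```python
import jax
import jax.numpy as jnp

COMPUTE_DTYPE = jnp.float16  # half precision for the forward/backward pass
LEARNING_RATE = 5e-2


def cast_floating(tree, dtype):
    """Cast every floating-point leaf of a pytree to `dtype` (hypothetical helper)."""
    return jax.tree_util.tree_map(
        lambda a: a.astype(dtype) if jnp.issubdtype(a.dtype, jnp.floating) else a,
        tree)


def loss_fn(params, x, y):
    # Hypothetical linear model; the matmul runs in the dtype of its inputs.
    pred = x @ params["w"] + params["b"]
    # Reduce in float32 so the mean itself does not lose precision.
    return jnp.mean((pred.astype(jnp.float32) - y) ** 2)


@jax.jit
def train_step(params, loss_scale, x, y):
    def scaled_loss(p):
        # float32 master params are cast to a half-precision copy for compute.
        loss = loss_fn(cast_floating(p, COMPUTE_DTYPE), x.astype(COMPUTE_DTYPE), y)
        return loss * loss_scale, loss

    (_, loss), grads = jax.value_and_grad(scaled_loss, has_aux=True)(params)
    # Gradients flow back through astype, so they are float32 here; unscale them.
    grads = jax.tree_util.tree_map(lambda g: g / loss_scale, grads)
    finite = jnp.all(jnp.stack(
        [jnp.all(jnp.isfinite(g)) for g in jax.tree_util.tree_leaves(grads)]))
    # Skip the SGD update if any gradient overflowed in float16.
    new_params = jax.tree_util.tree_map(
        lambda p, g: jnp.where(finite, p - LEARNING_RATE * g, p), params, grads)
    # Simplified dynamic loss scaling: halve the scale after an overflow.
    new_scale = jnp.where(finite, loss_scale, loss_scale / 2)
    return new_params, new_scale, loss


x = jax.random.normal(jax.random.PRNGKey(0), (32, 4))
y = x @ jnp.arange(1.0, 5.0).reshape(4, 1)
params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
scale = jnp.float32(2.0 ** 15)
for _ in range(300):
    params, scale, loss = train_step(params, scale, x, y)
print(float(loss), float(scale))
```

The first few steps overflow in float16 and are skipped while the scale halves. After that, training proceeds normally. The float32 master copy accumulates updates smaller than float16 can resolve.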