What are main differences between flax and dm-haiku?
From their descriptions:
- Flax, a neural network library for JAX
- Haiku, a neural network library for JAX inspired by Sonnet
Question:
Which JAX-based library should I pick to implement, say, a DeepSpeech model (CNN layers + LSTM layers + fully connected layers) with CTC loss?
UPD.
Found an explanation of the differences from a developer of dm-haiku:
> Flax is a bit more batteries included, and comes with optimizers, mixed precision and some training loops (I am told these are decoupled and you can use as much or as little as you want). Haiku aims to just solve NN modules and state management, it leaves other parts of the problem to other libraries (e.g. optax for optimization).
>
> Haiku is designed to be a port of Sonnet (a TF NN library) to JAX. So Haiku is a better choice if (like DeepMind) you have a significant amount of Sonnet+TF code that you might want to use in JAX and you want migrating that code (in either direction) to be as easy as possible.
>
> I think otherwise it comes down to personal preference. Within Alphabet there are 100s of researchers using each library so I don't think you can go wrong either way. At DeepMind we have standardised on Haiku because it makes sense for us. I would suggest taking a look at the example code provided by both libraries and seeing which matches your preferences for structuring experiments. I think you'll find that moving code from one library to another is not very complicated if you change your mind in the future.
The original question is still relevant.