In dm-haiku, the parameters of a neural network are stored in a dictionary whose keys are module (and submodule) names. There are several ways to traverse the values, as shown in this dm-haiku issue. However, the dictionary doesn't preserve the order in which the modules are called, which makes it hard to work with submodules. For example, if I have 2 linear layers, each followed by an MLP layer, then hk.data_structures.traverse(params) will (roughly) return the module names as:
['linear', 'linear_2', 'mlp/~/1', 'mlp/~/2']
whereas I would like it to return:
['linear', 'mlp/~/1', 'linear_2', 'mlp/~/2']
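The difference between the two orderings can be reproduced without Haiku. This is only a minimal sketch, assuming the params behave like a plain mapping from module name to a dict of arrays and that traversal visits module names in sorted order (which matches the output above). The module names are copied from the rough example, and the values are placeholders rather than real arrays:

```python
# Stand-in for a Haiku params mapping, filled in the order the modules are called.
# Placeholder lists are used instead of real arrays.
params = {
    "linear": {"w": [1.0], "b": [0.0]},
    "mlp/~/1": {"w": [2.0], "b": [0.0]},
    "linear_2": {"w": [3.0], "b": [0.0]},
    "mlp/~/2": {"w": [4.0], "b": [0.0]},
}

call_order = list(params)      # insertion order, i.e. the order I want
sorted_order = sorted(params)  # alphabetical order, as in the traversal output

print(call_order)
print(sorted_order)
```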
I want this ordering for cases such as building an invertible neural network, where the params need to be applied in reverse order; isolating constituent parts for other purposes (e.g. transfer learning); or, more generally, having more control over how and where to (re)use trained parameters.
To deal with this, I've resorted to matching the module names with a regex, putting them in the order I want, and then using hk.data_structures.filter(predicate, params) to filter by the sorted module names. However, this is quite tedious, since I have to write a new regex every time I want to do this.
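For comparison, a regex-free version of this workaround might keep an explicit list of module names and index the mapping with it. The sketch below is an assumption-laden illustration, not a Haiku API: order_params is a hypothetical helper, and the params are again a plain-dict stand-in with placeholder values. It returns a list of pairs rather than a dict because a list keeps its order, whereas JAX flattens dict pytrees with their keys sorted:

```python
from typing import Any, List, Mapping, Sequence, Tuple


def order_params(
    params: Mapping[str, Any], order: Sequence[str]
) -> List[Tuple[str, Any]]:
    """Hypothetical helper: return (module_name, module_params) pairs in `order`."""
    missing = [name for name in order if name not in params]
    if missing:
        raise KeyError(f"unknown module names: {missing}")
    return [(name, params[name]) for name in order]


# Plain-dict stand-in for Haiku params, keyed in alphabetical order.
params = {
    "linear": {"w": [1.0]},
    "linear_2": {"w": [3.0]},
    "mlp/~/1": {"w": [2.0]},
    "mlp/~/2": {"w": [4.0]},
}

wanted = ["linear", "mlp/~/1", "linear_2", "mlp/~/2"]
ordered = order_params(params, wanted)

# Reversing (e.g. for an invertible network) is then a plain list operation.
reversed_names = [name for name, _ in reversed(ordered)]
print(reversed_names)
```

The explicit list still has to be written once per model, so this only removes the regex, not the bookkeeping.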
I'm wondering if there is a way to convert a dm-haiku dictionary of params into something like a pytree with a hierarchy and ordering that makes this easier. I believe equinox handles parameters in this manner (and I'm going to look into how that is done soon), but I wanted to check whether I'm overlooking a simple method that allows grouping, reversing, and other permutations of the params dictionary.
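To make the kind of grouping I have in mind concrete, here is a rough sketch that collects modules under the first component of their name. It assumes module names are "/"-separated paths as in the example above. group_by_top_level is a hypothetical helper of my own, not something Haiku provides, and the params are a plain-dict stand-in with placeholder values:

```python
from typing import Any, Dict, Mapping


def group_by_top_level(params: Mapping[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: group modules by the first part of their name."""
    groups: Dict[str, Dict[str, Any]] = {}
    for name, module_params in params.items():
        top = name.split("/", 1)[0]  # "mlp/~/1" -> "mlp", "linear" -> "linear"
        groups.setdefault(top, {})[name] = module_params
    return groups


# Plain-dict stand-in for Haiku params, in the order the modules are called.
params = {
    "linear": {"w": [1.0]},
    "mlp/~/1": {"w": [2.0]},
    "linear_2": {"w": [3.0]},
    "mlp/~/2": {"w": [4.0]},
}

groups = group_by_top_level(params)
print(list(groups))
print(list(groups["mlp"]))
```

This gives a hierarchy, but the ordering only survives as long as the result stays a plain Python dict and is not flattened as a JAX pytree.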