In dm-haiku, the parameters of a neural network are stored in dictionaries whose keys are module (and submodule) names. There are several ways to traverse the values, as shown in this dm-haiku issue. However, the dictionary doesn't respect the order in which the modules are called, which makes it hard to parse submodules. For example, if I have 2 linear layers, each followed by an MLP layer, then hk.data_structures.traverse(params) will (roughly) return:

['linear', 'linear_2', 'mlp/~/1', 'mlp/~/2'].

whereas I would like it to return:

['linear', 'mlp/~/1', 'linear_2', 'mlp/~/2'].

My reasons for wanting this form include building an invertible neural network and needing to reverse the order in which the params are called, isolating constituent parts for other purposes (e.g. transfer learning), or, more generally, wanting more control over how and where to (re)use trained parameters.

To deal with this, I've resorted to matching the module names with a regex, putting them in the order I want, and then using hk.data_structures.filter(predicate, params) to filter by the sorted module names. This is quite tedious, though, because I have to write a new regex every time I want to do this.
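A minimal sketch of the reordering step without a regex, assuming the params are a plain nested dict of the form {module_name: {param_name: value}} and that the desired call order is written out by hand. The helper name `reorder_params` and the toy values are hypothetical, not part of haiku:

```python
from typing import Any, Dict, Mapping, Sequence

Params = Mapping[str, Mapping[str, Any]]


def reorder_params(params: Params, order: Sequence[str]) -> Dict[str, Dict[str, Any]]:
    """Return a new dict whose module names follow `order` (hypothetical helper)."""
    missing = [name for name in order if name not in params]
    if missing:
        raise KeyError(f"unknown module names: {missing}")
    # Plain Python dicts preserve insertion order, so inserting in `order` is enough.
    return {name: dict(params[name]) for name in order}


# Toy stand-in for a haiku params dict; lists replace the real arrays.
params = {
    "linear": {"w": [1.0], "b": [0.0]},
    "linear_2": {"w": [2.0], "b": [0.0]},
    "mlp/~/1": {"w": [3.0], "b": [0.0]},
    "mlp/~/2": {"w": [4.0], "b": [0.0]},
}

call_order = ["linear", "mlp/~/1", "linear_2", "mlp/~/2"]
ordered = reorder_params(params, call_order)
reversed_params = reorder_params(params, call_order[::-1])

print(list(ordered))
print(list(reversed_params))
```

The ordering here only affects plain iteration over the dict. Any utility that sorts the keys itself will ignore it.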

Is there a way to convert a dm-haiku dictionary of params into something like a pytree with a hierarchy and ordering that makes this easier? I believe equinox handles parameters in this manner (and I'm going to look into how that is done soon), but I wanted to check whether I'm overlooking a simple method that allows grouping, reversing, and other permutations of the params dictionary.

VdZ

1 Answer

According to the source code (https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/filtering.py#L42#L46), hk.data_structures.traverse iterates over the dictionary keys using Python's sorted function (haiku parameters have been vanilla dicts since 0.0.6). Therefore you can't get the result you want without modifying the function itself. By the way, I don't understand precisely what you mean by "to reverse the order the params are called". All parameters are passed together as input, and the only thing that determines the order in which they are used is the architecture of the function itself. So you should manually invert the forward pass, but you don't need to change anything in params.
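To illustrate why the order comes out as it does, here is a plain-Python stand-in for a traversal that sorts the keys, as described above. The function name `sorted_traverse` is hypothetical, and this is only a sketch of the described behaviour, not haiku's actual code:

```python
from typing import Any, Iterator, Mapping, Tuple


def sorted_traverse(
    params: Mapping[str, Mapping[str, Any]]
) -> Iterator[Tuple[str, str, Any]]:
    """Yield (module_name, param_name, value) in sorted key order (stand-in)."""
    for module_name in sorted(params):
        bundle = params[module_name]
        for name in sorted(bundle):
            yield module_name, name, bundle[name]


# Toy params inserted in call order; lists replace the real arrays.
params = {
    "linear": {"w": [1.0], "b": [0.0]},
    "mlp/~/1": {"w": [3.0], "b": [0.0]},
    "linear_2": {"w": [2.0], "b": [0.0]},
    "mlp/~/2": {"w": [4.0], "b": [0.0]},
}

# Collect each module name once, in the order the traversal visits it.
visited = []
for module_name, _, _ in sorted_traverse(params):
    if module_name not in visited:
        visited.append(module_name)

print(visited)
```

Even though the dict was built in call order, the sorted traversal visits 'linear_2' before 'mlp/~/1', because string sorting compares the names character by character.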

Valentin Goldité
    Thanks for the tip! I hacked together a way using regex since I couldn't find an answer in the docs... but I should have looked deeper into the source code for this utility function. – VdZ Aug 29 '22 at 20:52