
For example, suppose you set up a module that has params. If you want to regularize something in a loss, what is the pattern?

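import jax
import jax.numpy as jnp
import haiku as hk

# `mlp` is a transformed model defined elsewhere,
# e.g. mlp = hk.without_apply_rng(hk.transform(forward_fn))
def loss(params, x, y):
    l = jnp.sum((y - mlp.apply(params, x)) ** 2)
    w = hk.get_params(params, 'w')  # does not work like this: there is no hk.get_params
    l += jnp.sum(w ** 2)            # L2 penalty on the weights
    return l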

This pattern seems to be missing from the examples.

mathtick
It looks like there is a `params.keys()` that you can look at, and you can just access the items in the dict that way. I'm not sure if this is the "official" pattern for stuff like this. – mathtick Sep 02 '21 at 14:14
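The `params.keys()` approach from the comment works because Haiku params form a two-level mapping: module name first, then parameter name. A minimal sketch of inspecting that structure, assuming a hand-built dict in that layout (the module names `mlp/~/linear_0` and `mlp/~/linear_1` and all shapes are made up for illustration), using only JAX:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters laid out as {module_name: {param_name: array}},
# mimicking what a Haiku `init` call returns. Names and shapes are made up.
params = {
    "mlp/~/linear_0": {"w": jnp.ones((2, 4)), "b": jnp.zeros((4,))},
    "mlp/~/linear_1": {"w": jnp.ones((4, 1)), "b": jnp.zeros((1,))},
}

# Top-level keys are module names; each value maps parameter names to arrays.
for module_name, module_params in params.items():
    print(module_name, sorted(module_params.keys()))

# Shape of every leaf, returned in the same nested structure as `params`.
shapes = jax.tree_util.tree_map(lambda a: a.shape, params)
print(shapes)
```

Printing the shapes this way is a quick way to find out which module names your model actually uses before indexing into them.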

1 Answer


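`params` is essentially a read-only dictionary, keyed first by module name and then by parameter name, so you can get the value of a parameter by indexing into it like a nested dictionary:

print(params['linear']['w'])  # 'linear' is the name of the module that owns 'w'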
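Applied to the question, the regularizer can read the weights straight out of `params` inside the loss, and `jax.grad` differentiates through the penalty as usual. A minimal sketch, assuming a hypothetical single linear layer stored under the made-up module name `linear` (the `apply` function below stands in for `mlp.apply` and is not Haiku's API), with a penalty on every parameter named `w`:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for `mlp.apply(params, x)`: one linear layer whose
# parameters live under the made-up module name "linear".
def apply(params, x):
    p = params["linear"]
    return x @ p["w"] + p["b"]

def loss(params, x, y, l2=1e-3):
    # Data term: sum of squared errors.
    data = jnp.sum((y - apply(params, x)) ** 2)
    # L2 penalty on every parameter named "w", found by walking the
    # {module_name: {param_name: array}} mapping.
    reg = sum(jnp.sum(p["w"] ** 2) for p in params.values() if "w" in p)
    return data + l2 * reg

params = {"linear": {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}}
x = jnp.ones((5, 3))
y = jnp.zeros((5, 1))

value = loss(params, x, y)
grads = jax.grad(loss)(params, x, y)  # same nested structure as `params`
```

With Haiku installed, `hk.data_structures.filter(lambda module_name, name, value: name == 'w', params)` selects the same subset of parameters, which can be convenient when the model has many modules.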

If you want to update the parameters, you cannot do it in place; you first have to convert them to a mutable dictionary:

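params_mutable = hk.data_structures.to_mutable_dict(params)
# Replace the array with one of the same shape, filled with 3.14.
params_mutable['linear']['w'] = jnp.full_like(params_mutable['linear']['w'], 3.14)
params_new = hk.data_structures.to_immutable_dict(params_mutable)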
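A functional update that never mutates anything is another option. A minimal sketch, assuming the params are plain nested dicts (the module name `linear` is hypothetical), that builds new dicts on the way down or maps a function over every leaf with `jax.tree_util.tree_map`:

```python
import jax
import jax.numpy as jnp

# Hypothetical params in the {module_name: {param_name: array}} layout.
params = {"linear": {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}}

# Replace a single entry by copying only the dicts along the path to it;
# `params` itself is left unchanged.
params_new = {
    **params,
    "linear": {**params["linear"], "w": jnp.full((2, 2), 3.14)},
}

# Apply a function to every leaf; tree_map returns a new structure.
params_scaled = jax.tree_util.tree_map(lambda a: 0.5 * a, params)
```

Untouched arrays (here `b`) are shared between the old and new structures rather than copied, which is safe because JAX arrays are immutable.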
jakevdp