Say you set up a module, and it has params. If you want to regularize some of those params in the loss, what is the pattern?
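import haiku as hk
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # mlp: a network built with hk.without_apply_rng(hk.transform(...)), defined elsewhere
    l = jnp.sum((y - mlp.apply(params, x)) ** 2)
    w = hk.get_params(params, 'w')  # does not work like this
    l += jnp.sum(w ** 2)  # L2 penalty on the weights
    return l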
This pattern seems to be missing from the examples.
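One pattern that should work: the params that Haiku's `init` returns are a plain nested mapping, `{module_name: {param_name: array}}` (for an `hk.nets.MLP` the keys typically look like `params["mlp/~/linear_0"]["w"]`), so inside the loss you can index or traverse them directly with ordinary dict and `jax.tree_util` operations instead of looking them up through an `hk` call. The sketch below assumes that layout. `init_params`, `apply_fn`, `l2_on_weights`, `l2_all`, and `weight_decay` are hypothetical stand-ins (for `mlp.init`, `mlp.apply`, and your own penalty) so that the block runs with JAX alone.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for mlp.init: params shaped like Haiku's
# {module_name: {param_name: array}} layout.
def init_params(key, in_dim=3, out_dim=1):
    return {
        "linear": {
            "w": jax.random.normal(key, (in_dim, out_dim)),
            "b": jnp.zeros((out_dim,)),
        }
    }

# Hypothetical stand-in for mlp.apply(params, x).
def apply_fn(params, x):
    p = params["linear"]
    return x @ p["w"] + p["b"]

def l2_on_weights(params):
    # Sum of squares over every param named "w" in any module;
    # biases ("b") are left unregularized.
    return sum(jnp.sum(mod["w"] ** 2) for mod in params.values() if "w" in mod)

def l2_all(params):
    # Alternative: penalize every leaf of the params pytree.
    return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(params))

def loss(params, x, y, weight_decay=1e-4):
    mse = jnp.sum((y - apply_fn(params, x)) ** 2)
    return mse + weight_decay * l2_on_weights(params)

params = init_params(jax.random.PRNGKey(0))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))
grads = jax.grad(loss)(params, x, y)  # same nested structure as params
```

Because the penalty is just a function of the params pytree, `jax.grad` differentiates through it like any other term. If you need finer selection by module or param name, Haiku also provides `hk.data_structures.filter`, which takes a predicate over `(module_name, name, value)`.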