I am trying to create custom parameters for a Haiku hk.Module. I know that hk.get_parameter() creates new parameters or returns existing ones. However, I also want to modify the values of those parameters before they are returned by .init().

Modifying the parameters after .init() returns is straightforward, as described in the following Stack Overflow thread: How does one get a parameter from a params (pytree) in haiku? (jax framework)

I have tried using hk.custom_getter(), but it only modifies the value that hk.get_parameter() returns to the caller; the original parameters remain unmodified.
Is there a way to achieve this?
Here is a code snippet of what I am trying to achieve:
import haiku as hk
import jax
import jax.numpy as jnp

class HaikuModule(hk.Module):
    def __call__(self):
        # create parameter w initialized to 0s
        w = hk.get_parameter("w", shape=[5, 5], init=jnp.zeros)
        # modify the value of w
        # ???
        return w

def forward_fn():
    model = HaikuModule()
    return model()

fn = hk.transform(forward_fn)
rng = jax.random.PRNGKey(10)
params = fn.init(rng)
# params should contain the modified values, not 0s (the initialized value)