In Flax, we typically initialize a model by passing in a dummy input and letting the library figure out the correct parameter shapes via shape inference. For example, this is what the tutorial does:
import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()  # the model class defined earlier in the tutorial
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']  # dummy input, used only for its shape
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
It is worth noting that the concrete value of jnp.ones([1, 28, 28, 1]) does not matter, since shape inference relies only on its shape. I could replace it with jnp.zeros([1, 28, 28, 1]) or jax.random.normal(jax.random.PRNGKey(42), [1, 28, 28, 1]), and it would give me exactly the same result.
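A minimal check of that claim (a sketch, assuming CNN and rng from the snippet above, and that the tutorial's default initializers depend only on the input's shape):

import jax

rng = jax.random.PRNGKey(0)
cnn = CNN()

# Initialize twice with dummy inputs of the same shape but different values.
params_ones = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']
params_zeros = cnn.init(rng, jnp.zeros([1, 28, 28, 1]))['params']

# Every parameter leaf should be bitwise identical, since only the shape is used.
same = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda a, b: bool((a == b).all()),
                           params_ones, params_zeros))
print(same)  # True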
My question is: can I use jnp.empty([1, 28, 28, 1]) instead? I want to use jnp.empty to make it clear that we don't care about the value (it could also be faster, though the speedup is negligible). However, C has the notion of a trap representation, and it looks like reading from jnp.empty without overwriting it first could trigger undefined behavior. Since NumPy is a thin wrapper around C, should I worry about that?
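Concretely, the replacement I have in mind (same create_train_state as above, assuming cnn and rng from that snippet):

# What I'd like to write instead, provided the read is safe:
params = cnn.init(rng, jnp.empty([1, 28, 28, 1]))['params']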
Bonus question: let's forget about JAX and focus on vanilla NumPy. Is it safe to read from np.empty([...])? Again, I don't care about the value, but I do care about not getting a segfault.
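To make "reading" concrete, this is the kind of access I mean (a toy shape, purely illustrative):

import numpy as np

a = np.empty([2, 3])  # allocated but deliberately not overwritten
print(a)              # values are arbitrary garbage; is this read well-defined?
total = a.sum()       # a second read; I want a guarantee this never crashes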