`functools.partial` partially applies a function: it binds some of the function's arguments now and returns a new callable that supplies them when it is called later. Here's an example of it being used with a function:
```python
from functools import partial

def f(x, y, z):
    print(f"{x=} {y=} {z=}")

g = partial(f, 1, z=3)
g(2)
# x=1 y=2 z=3
```
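Roughly, `partial` stores the callable and the bound arguments. On each call it puts the stored positional arguments first and merges the stored keywords with the new ones, with call-time keywords winning. The sketch below is a simplified pure-Python stand-in, `my_partial` (a hypothetical name for illustration, modeled on the "roughly equivalent" code in the Python docs; the real `partial` is implemented in C). It is compared against the real one, whose `func`, `args` and `keywords` attributes expose what was bound. Here `f` returns a tuple instead of printing.

```python
from functools import partial


def my_partial(func, /, *args, **keywords):
    """Simplified stand-in for functools.partial (illustration only)."""
    def newfunc(*fargs, **fkeywords):
        # Call-time keywords override the ones bound earlier.
        newkeywords = {**keywords, **fkeywords}
        # Bound positional args come first, then the call-time ones.
        return func(*args, *fargs, **newkeywords)
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc


def f(x, y, z):
    # Returns instead of printing, so the results can be compared.
    return (x, y, z)


g_real = partial(f, 1, z=3)
g_mine = my_partial(f, 1, z=3)

print(g_real(2), g_mine(2))  # (1, 2, 3) (1, 2, 3)
print(g_real.func is f, g_real.args, g_real.keywords)  # True (1,) {'z': 3}
```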
And here is an example of it being used on a class constructor:
```python
from functools import partial
from typing import NamedTuple

class MyClass(NamedTuple):
    a: int
    b: int
    c: int

make_class = partial(MyClass, 1, c=3)
print(make_class(b=2))
# MyClass(a=1, b=2, c=3)
```
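The two kinds of bound argument behave differently. A bound keyword works like a default that a call can override, while bound positional arguments are fixed. A minimal sketch with the same `MyClass`: a second positional argument lands on `c`, which is already bound by keyword, so the constructor raises `TypeError` (the exact message varies between Python versions, so it is not shown).

```python
from functools import partial
from typing import NamedTuple


class MyClass(NamedTuple):
    a: int
    b: int
    c: int


make_class = partial(MyClass, 1, c=3)

# The bound keyword `c=3` is overridden at call time.
print(make_class(2, c=30))  # MyClass(a=1, b=2, c=30)

# Two call-time positionals fill `b` and `c`, but `c` is also bound by
# keyword, so the constructor gets two values for it.
try:
    make_class(2, 4)
except TypeError as err:
    print("TypeError:", err)
```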
The use in the Flax example is conceptually the same: `partial(f, ...)` returns a new callable that, when called, passes the bound arguments along with any new ones to the original callable, whether that is a function, a method, or a class constructor.
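To illustrate the "any callable" point, here is a short sketch using only standard-library callables: a bound method, an unbound method (where the first bound positional argument plays the role of `self`), and a class constructor (`int` with a bound `base`, like the example in the `functools` docs).

```python
from functools import partial

# Bound method: the template string is already attached as `self`.
greet = partial("{greeting}, {name}!".format, greeting="Hello")
print(greet(name="Ada"))  # Hello, Ada!

# Unbound method: the bound ", " becomes `self` for str.join.
join_commas = partial(str.join, ", ")
print(join_commas(["a", "b", "c"]))  # a, b, c

# Class constructor: bind a keyword argument of int().
parse_hex = partial(int, base=16)
print(parse_hex("ff"))  # 255
```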
For example, the `ResNet18` function defined in the Flax example as
```python
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],
                   block_cls=ResNetBlock)
```
is a partially applied `ResNet` constructor, and it is called in one of the example's tests:
```python
@parameterized.product(
    model=(models.ResNet18, models.ResNet18Local)
)
def test_resnet_18_v1_model(self, model):
    """Tests ResNet18 V1 model definition and output (variables)."""
    rng = jax.random.PRNGKey(0)
    model_def = model(num_classes=2, dtype=jnp.float32)
    variables = model_def.init(
        rng, jnp.ones((1, 64, 64, 3), jnp.float32))
    self.assertLen(variables, 2)
    self.assertLen(variables['params'], 11)
```
In the `ResNet18` case, `model` is the partially applied `ResNet18`. Calling it with `num_classes=2, dtype=jnp.float32` returns a fully instantiated `ResNet` module that also carries the `stage_sizes` and `block_cls` bound in the `ResNet18` definition.
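Flax Linen modules are Python dataclasses, so `partial(ResNet, ...)` behaves like `partial` over any dataclass constructor. The sketch below runs without Flax or JAX. It uses a plain `@dataclass` named `ResNet` and an empty `ResNetBlock` class as hypothetical stand-ins: the fields mirror the arguments in the snippets above, and `dtype` is a string here. `ResNet34` is an extra hypothetical partial using the usual ResNet-34 stage sizes. The loop mimics how the parameterized test calls each `model` the same way.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Sequence


class ResNetBlock:
    """Hypothetical placeholder for the real block module."""


@dataclass
class ResNet:
    """Hypothetical stand-in for the Flax ResNet module (fields only, no layers)."""
    stage_sizes: Sequence[int]
    block_cls: Any
    num_classes: int
    dtype: str = "float32"


# Architecture is fixed by the partial; num_classes/dtype are left open.
ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)

# Each `model` is a partial; calling it builds a complete ResNet instance.
for model in (ResNet18, ResNet34):
    model_def = model(num_classes=2, dtype="float32")
    print(type(model_def).__name__, model_def.stage_sizes, model_def.num_classes)
# ResNet [2, 2, 2, 2] 2
# ResNet [3, 4, 6, 3] 2
```

Because both partials expose the same remaining signature, a test can treat them interchangeably, which is what the `parameterized.product` decorator relies on.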