My code is as follows:
!pip install flax
init_params = TransporterNets().init(key, init_img, init_text, init_pix)['params']
print(f'Model parameters: {n_params(init_params):,}')
optim = flax.optim.Adam(lr=1e-4).create(init_params)
However, it raises the following error:
AttributeError: module 'flax' has no attribute 'optim'
I get this even though I have seen documentation for an optim attribute in the flax module. How can I fix it?