Recently, all my Google Colab notebooks were updated to Python 3.10, and I can't run my old code. I was using jax 0.2.17 and numpyro 0.7.1, but these no longer work with Python 3.10, so I tried switching to Python 3.9. I used the following code:
!sudo apt-get update -y
!sudo apt-get install python3.9
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
!sudo update-alternatives --config python3
!apt-get install python3-pip
However, after running these commands and then !pip install jax==0.2.17, checking jax.__version__ still gives 0.4.8, the version of jax preinstalled in Google Colab. Is there a way to fix this? I want to use the old versions of jax and numpyro. pip appears to install them, but when I import them I seem to be importing from the Python 3.10 environment instead. Thanks in advance!
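In case it helps with diagnosing this, here is a minimal sketch of a check that compares the interpreter running the notebook kernel with the one that python3 resolves to in a shell. The helper names kernel_python and shell_python_version are hypothetical, made up for this illustration. It assumes that ! commands run in a separate shell process, while import statements run inside the kernel's own interpreter.

```python
import shutil
import subprocess
import sys


def kernel_python():
    """Return (executable path, 'major.minor') of the interpreter running this code."""
    return sys.executable, "%d.%d" % sys.version_info[:2]


def shell_python_version(cmd="python3"):
    """Return 'major.minor' of the interpreter `cmd` resolves to on PATH, or None if not found."""
    path = shutil.which(cmd)
    if path is None:
        return None
    # Ask that interpreter for its own version in a child process.
    result = subprocess.run(
        [path, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"],
        capture_output=True,
        text=True,
    )
    return result.stdout.strip() or None


exe, kernel_version = kernel_python()
print("kernel interpreter:", exe, kernel_version)
print("shell python3:", shell_python_version())
```

If the two versions differ, that would be consistent with what I am seeing: the shell commands may be switching and installing into one interpreter while the already-running kernel keeps importing from another.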