An error occurred while executing the following cell: ------------------ from typing import Sequence import matplotlib.pyplot as plt %matplotlib inline %pip install -qq jax==0.3.14 import jax import jax.numpy as jnp try: import flax.linen as nn from flax.training import train_state except ModuleNotFoundError: %pip install -qq flax import flax.linen as nn from flax.training import train_state try: import optax except ModuleNotFoundError: %pip install -qq optax import optax import functools import scipy as sp import math try: import seaborn as sns except ModuleNotFoundError: %pip install -qq seaborn import seaborn as sns try: import probml_utils as pml from probml_utils import savefig, latexify except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git import probml_utils as pml from probml_utils import savefig, latexify rng = jax.random.PRNGKey(0) ------------------ --------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) /tmp/ipykernel_5910/1147266451.py in 5 6 get_ipython().run_line_magic('pip', 'install -qq jax==0.3.14') ----> 7 import jax 8 import jax.numpy as jnp 9 ~/miniconda3/envs/py37/lib/python3.7/site-packages/jax/__init__.py in 33 # We want the exported object to be the class, so we first import the module 34 # to make sure a later import doesn't overwrite the class. ---> 35 from jax import config as _config_module 36 del _config_module 37 ~/miniconda3/envs/py37/lib/python3.7/site-packages/jax/config.py in 15 # TODO(phawkins): fix users of this alias and delete this file. 16 ---> 17 from jax._src.config import config ~/miniconda3/envs/py37/lib/python3.7/site-packages/jax/_src/config.py in 27 from absl import logging 28 ---> 29 from jax._src import lib 30 from jax._src.lib import jax_jit 31 from jax._src.lib import transfer_guard_lib ~/miniconda3/envs/py37/lib/python3.7/site-packages/jax/_src/lib/__init__.py in 95 jax_version=jax.version.__version__, 96 jaxlib_version=jaxlib.version.__version__, ---> 97 minimum_jaxlib_version=jax.version._minimum_jaxlib_version) 98 99 ~/miniconda3/envs/py37/lib/python3.7/site-packages/jax/_src/lib/__init__.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version) 87 f'incompatible with jax version {jax_version}. Please ' 88 'update your jax and/or jaxlib packages.') ---> 89 raise RuntimeError(msg) 90 91 return _jaxlib_version RuntimeError: jaxlib version 0.3.25 is newer than and incompatible with jax version 0.3.14. Please update your jax and/or jaxlib packages. RuntimeError: jaxlib version 0.3.25 is newer than and incompatible with jax version 0.3.14. Please update your jax and/or jaxlib packages.