An error occurred while executing the following cell: ------------------ """ Fits Bernoulli mixture model for mnist digits using em algorithm Author: Meduri Venkata Shivaditya, Aleyna Kara(@karalleyna) """ from jax.random import PRNGKey, randint try: import tensorflow as tf except ModuleNotFoundError: %pip install -qq tensorflow import tensorflow as tf try: from probml_utils.mix_bernoulli_lib import BMM except ModuleNotFoundError: %pip install -qq git+https://github.com/probml/probml-utils.git from probml_utils.mix_bernoulli_lib import BMM from probml_utils.mix_bernoulli_em_mnist import mnist_data def main(): n_obs = 1000 observations = mnist_data(n_obs) # subsample the MNIST dataset n_vars = len(observations[0]) K, num_of_iters = 12, 10 n_row, n_col = 3, 4 bmm = BMM(K, n_vars) _ = bmm.fit_em(observations, num_of_iters=num_of_iters) bmm.plot(n_row, n_col, "bmm_em_mnist") if __name__ == "__main__": main() ------------------ --------------------------------------------------------------------------- ImportError Traceback (most recent call last) /tmp/ipykernel_5885/205051138.py in 13 import tensorflow as tf 14 try: ---> 15 from probml_utils.mix_bernoulli_lib import BMM 16 except ModuleNotFoundError: 17 get_ipython().run_line_magic('pip', 'install -qq git+https://github.com/probml/probml-utils.git') ~/miniconda3/envs/py37/lib/python3.7/site-packages/probml_utils/mix_bernoulli_lib.py in 8 from jax.scipy.special import expit, logit 9 from jax.nn import softmax ---> 10 from jax.experimental import optimizers 11 12 import distrax ImportError: cannot import name 'optimizers' from 'jax.experimental' (/github/home/miniconda3/envs/py37/lib/python3.7/site-packages/jax/experimental/__init__.py) ImportError: cannot import name 'optimizers' from 'jax.experimental' (/github/home/miniconda3/envs/py37/lib/python3.7/site-packages/jax/experimental/__init__.py)