An error occurred while executing the following cell:
------------------
# Multi-armed bandit problem for a linear Gaussian model
# with linear reward function.
# In this demo, we consider three arms:
# 1. The first arm is an upward-trending arm with an initial negative bias
# 2. The second arm is a downward-trending arm with an initial positive bias
# 3. The third arm is a stationary arm with an initial zero bias
# !pip install -qq -Uq tfp-nightly[jax] > /dev/null
# Author: Gerardo Durán-Martín (@gerdm)

import jax
import seaborn as sns
import matplotlib.pyplot as plt

try:
    import probml_utils as pml
except ModuleNotFoundError:
    %pip install -qq git+https://github.com/probml/probml-utils.git
    import probml_utils as pml

import jax.numpy as jnp
import pandas as pd
from jax import random
from functools import partial
from jax.nn import one_hot

try:
    from tensorflow_probability.substrates import jax as tfp
except ModuleNotFoundError:
    %pip install -qq tensorflow-probability
    from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


class NormalGammaBandit:
    def sample(self, key, params, state):
        key_sigma, key_w = random.split(key, 2)
        sigma2_samp = tfd.InverseGamma(concentration=params["a"], scale=params["b"]).sample(seed=key_sigma)
        cov_matrix_samples = sigma2_samp[:, None, None] * params["Sigma"]
        w_samp = tfd.MultivariateNormalFullCovariance(loc=params["mu"], covariance_matrix=cov_matrix_samples).sample(
            seed=key_w
        )
        return sigma2_samp, w_samp

    def predict_rewards(self, params_sample, state):
        sigma2_samp, w_samp = params_sample
        predicted_reward = jnp.einsum("m,km->k", state, w_samp)
        return predicted_reward

    def update(self, action, params, state, reward):
        """
        Update the parameters of the model for the chosen arm.
        """
        mu_k = params["mu"][action]
        Sigma_k = params["Sigma"][action]
        Lambda_k = jnp.linalg.inv(Sigma_k)
        a_k = params["a"][action]
        b_k = params["b"][action]

        # weight params
        Lambda_update = jnp.outer(state, state) + Lambda_k
        Sigma_update = jnp.linalg.inv(Lambda_update)
        mu_update = Sigma_update @ (Lambda_k @ mu_k + state * reward)
        # noise params
        a_update = a_k + 1 / 2
        b_update = b_k + (reward**2 + mu_k.T @ Lambda_k @ mu_k - mu_update.T @ Lambda_update @ mu_update) / 2

        # Update only the chosen action at time t
        params["mu"] = params["mu"].at[action].set(mu_update)
        params["Sigma"] = params["Sigma"].at[action].set(Sigma_update)
        params["a"] = params["a"].at[action].set(a_update)
        params["b"] = params["b"].at[action].set(b_update)

        params = {"mu": params["mu"], "Sigma": params["Sigma"], "a": params["a"], "b": params["b"]}
        return params


def true_reward(key, action, state, true_params):
    """
    Compute the true reward as the linear combination of each set of weights
    and the observed state, plus the noise from each arm.
    """
    w_k = true_params["w"][action]
    sigma_k = jnp.sqrt(true_params["sigma2"][action])
    reward = w_k @ state + random.normal(key) * sigma_k
    return reward


def thompson_sampling_step(model_params, state, model, environment):
    """
    Contextual implementation of the Thompson sampling algorithm.
    This implementation considers a single step.

    Parameters
    ----------
    model_params: dict
    state: tuple of (jax.random.PRNGKey, context vector)
    model: instance of a Bandit model
    environment: function
    """
    key, context = state
    key_sample, key_reward = random.split(key)
    # Sample and choose an action
    params = model.sample(key_sample, model_params, context)
    pred_rewards = model.predict_rewards(params, context)
    action = pred_rewards.argmax()
    # Environment reward
    reward = environment(key_reward, action, context)
    model_params = model.update(action, model_params, context, reward)
    arm_reward = one_hot(action, K) * reward
    return model_params, (model_params, arm_reward)


plt.rcParams["axes.spines.top"] = False
plt.rcParams["axes.spines.right"] = False

# 1. Specify the underlying dynamics (unknown to the agent)
W = jnp.array([[-5.0, 2.0, 0.5], [0.0, 0.0, 0.0], [5.0, -1.5, -1.0]])
sigmas = jnp.ones(3)

K, M = W.shape
N = 500
T = 4
x = jnp.linspace(0, T, N)
X = jnp.c_[jnp.ones(N), x, x**2]
true_params = {"w": W, "sigma2": sigmas**2}

# 2. Sample one instance of the multi-armed bandit process;
# this is only for plotting, it will not be used for training
key = random.PRNGKey(314)
noise = random.multivariate_normal(key, mean=jnp.zeros(K), cov=jnp.eye(K) * sigmas, shape=(N,))
Y = jnp.einsum("nm,km->nk", X, W) + noise

# 3. Configure the model parameters that will be used
# during Thompson sampling
eta = 2.0
lmbda = 5.0
init_params = {
    "mu": jnp.zeros((K, M)),
    "Sigma": lmbda * jnp.eye(M) * jnp.ones((K, 1, 1)),
    "a": eta * jnp.ones(K),
    "b": eta * jnp.ones(K),
}

environment = partial(true_reward, true_params=true_params)
thompson_partial = partial(thompson_sampling_step, model=NormalGammaBandit(), environment=environment)
thompson_vmap = jax.vmap(lambda key: jax.lax.scan(thompson_partial, init_params, (random.split(key, N), X)))

# 4. Do Thompson sampling
nsamples = 100
key = random.PRNGKey(3141)
keys = random.split(key, nsamples)
posteriors_samples, (_, hist_reward_samples) = thompson_vmap(keys)

# 5. Plotting
# 5.1 Example dataset
plt.plot(x, Y)
plt.axhline(y=0, c="black")
plt.legend([f"arm{i}" for i in range(K)])
pml.savefig("bandit-lingauss-true-reward.pdf")

# 5.2 Plot heatmap of chosen arm and given reward
ix = 0
map_reward = hist_reward_samples[ix]
map_reward = map_reward.at[map_reward == 0].set(jnp.nan)

labels = [f"arm{i}" for i in range(K)]
map_reward_df = pd.DataFrame(map_reward, index=[f"{t:0.2f}" for t in x], columns=labels)

fig, ax = plt.subplots(figsize=(4, 5))
sns.heatmap(map_reward_df, cmap="viridis", ax=ax, xticklabels=labels)
plt.ylabel("time")
pml.savefig("bandit-lingauss-heatmap.pdf")

# 5.3 Plot cumulative reward per arm
fig, ax = plt.subplots()
plt.plot(x, hist_reward_samples[ix].cumsum(axis=0))
plt.legend(labels, loc="upper left")
plt.ylabel("cumulative reward")
plt.xlabel("time")
pml.savefig("bandit-lingauss-cumulative-reward.pdf")

# 5.4 Plot regret
fig, ax = plt.subplots()
expected_hist_reward = hist_reward_samples.mean(axis=0)
optimal_reward = jnp.einsum("nm,km->nk", X, true_params["w"]).max(axis=1)
regret = optimal_reward - expected_hist_reward.max(axis=1)
cumulative_regret = regret.cumsum()
# plt.plot(x, cumulative_regret)
plt.plot(x, cumulative_regret, label="observed")
scale_factor = 20  # empirical
plt.plot(x, scale_factor * jnp.sqrt(x), label=r"c $\sqrt{t}$")
plt.title("Cumulative regret")
plt.ylabel("$L_T$")
plt.xlabel("time")
plt.legend()
pml.savefig("bandit-lingauss-cumulative-regret.pdf")
plt.show()
------------------
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_5111/2794284974.py in <module>
    178 
    179 fig, ax = plt.subplots(figsize=(4, 5))
--> 180 sns.heatmap(map_reward_df, cmap="viridis", ax=ax, xticklabels=labels)
    181 plt.ylabel("time")
    182 pml.savefig("bandit-lingauss-heatmap.pdf")

~/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/matrix.py in heatmap(data, vmin, vmax, cmap, center, robust, annot, fmt, annot_kws, linewidths, linecolor, cbar, cbar_kws, cbar_ax, square, xticklabels, yticklabels, mask, ax, **kwargs)
    457     if square:
    458         ax.set_aspect("equal")
--> 459     plotter.plot(ax, cbar_ax, kwargs)
    460     return ax
    461 

~/miniconda3/envs/py37/lib/python3.7/site-packages/seaborn/matrix.py in plot(self, ax, cax, kws)
    304 
    305         # Draw the heatmap
--> 306         mesh = ax.pcolormesh(self.plot_data, cmap=self.cmap, **kws)
    307 
    308         # Set the axis limits

~/miniconda3/envs/py37/lib/python3.7/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1412     def inner(ax, *args, data=None, **kwargs):
   1413         if data is None:
-> 1414             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1415 
   1416         bound = new_sig.bind(ax, *args, **kwargs)

~/miniconda3/envs/py37/lib/python3.7/site-packages/matplotlib/axes/_axes.py in pcolormesh(self, alpha, norm, cmap, vmin, vmax, shading, antialiased, *args, **kwargs)
   6072         collection = mcoll.QuadMesh(
   6073             coords, antialiased=antialiased, shading=shading,
-> 6074             array=C, cmap=cmap, norm=norm, alpha=alpha, **kwargs)
   6075         collection._scale_norm(norm, vmin, vmax)
   6076         self._pcolor_grid_deprecation_helper()

~/miniconda3/envs/py37/lib/python3.7/site-packages/matplotlib/collections.py in __init__(self, *args, **kwargs)
   2013         # super init delayed after own init because array kwarg requires
   2014         # self._coordinates and self._shading
-> 2015         super().__init__(**kwargs)
   2016         self.mouseover = False
   2017 
~/miniconda3/envs/py37/lib/python3.7/site-packages/matplotlib/collections.py in __init__(self, edgecolors, facecolors, linewidths, linestyles, capstyle, joinstyle, antialiaseds, offsets, transOffset, norm, cmap, pickradius, hatch, urls, zorder, **kwargs)
    215 
    216         self._path_effects = None
--> 217         self.update(kwargs)
    218         self._paths = None
    219 

~/miniconda3/envs/py37/lib/python3.7/site-packages/matplotlib/artist.py in update(self, props)
   1067                     raise AttributeError(f"{type(self).__name__!r} object "
   1068                                          f"has no property {k!r}")
-> 1069                 ret.append(func(v))
   1070             if ret:
   1071                 self.pchanged()

~/miniconda3/envs/py37/lib/python3.7/site-packages/matplotlib/collections.py in set_array(self, A)
   2075                     f"X ({width}) and/or Y ({height})")
   2076 
-> 2077         return super().set_array(A)
   2078 
   2079     def get_datalim(self, transData):

~/miniconda3/envs/py37/lib/python3.7/site-packages/matplotlib/cm.py in set_array(self, A)
    475         A = cbook.safe_masked_invalid(A, copy=True)
    476         if not np.can_cast(A.dtype, float, "same_kind"):
--> 477             raise TypeError(f"Image data of dtype {A.dtype} cannot be "
    478                             "converted to float")
    479 

TypeError: Image data of dtype object cannot be converted to float
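Note on the failure: the traceback shows that the array reaching matplotlib's set_array has dtype object, i.e. the DataFrame built in step 5.2 is not numeric. A minimal sketch of one possible fix, assuming the object dtype comes from handing the JAX array directly to pd.DataFrame, is to convert it to a plain NumPy float array (and mask the zero entries with np.where) before building the DataFrame; the rest of the cell stays the same:

import numpy as np

# Hypothetical replacement for section 5.2 of the cell above (reuses the
# variables hist_reward_samples, x, K and the pd, sns, plt, pml imports).
# Converting the JAX array to a NumPy float array guarantees a numeric
# DataFrame, so seaborn/matplotlib no longer receive object-dtype data.
ix = 0
map_reward = np.asarray(hist_reward_samples[ix], dtype=float)
map_reward = np.where(map_reward == 0, np.nan, map_reward)  # hide arms that were not pulled

labels = [f"arm{i}" for i in range(K)]
map_reward_df = pd.DataFrame(map_reward, index=[f"{t:0.2f}" for t in x], columns=labels)
assert map_reward_df.dtypes.eq(float).all()  # heatmap needs numeric columns

fig, ax = plt.subplots(figsize=(4, 5))
sns.heatmap(map_reward_df, cmap="viridis", ax=ax, xticklabels=labels)
plt.ylabel("time")
pml.savefig("bandit-lingauss-heatmap.pdf")

With a float-typed DataFrame, ax.pcolormesh receives data it can cast to float, and the run should continue to the remaining plots.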