An error occurred while executing the following cell: ------------------ def sample_from_posterior(key, data, save_name): kernel = kernels.ExpSquared(scale=1.0) gp = GaussianProcess(kernel, data["x_train"], diag=1e-8 + data["noise"]) pred_gp = gp.condition(data["y_train"], data["x_test"]).gp samples = pred_gp.sample(key, shape=(data["n_samples"],)) mean = pred_gp.mean std = pred_gp.variance**0.5 plt.figure() plt.plot(data["x_test"], samples.T, color=c_0) plt.scatter(data["x_train"], data["y_train"], color=c_1, zorder=10, s=5) plt.fill_between(data["x_test"].flatten(), mean - 2 * std, mean + 2 * std, alpha=0.2) plt.xlabel("$x$") plt.ylabel("$f$") # plt.legend(["samples"], loc="lower left") sns.despine() if len(save_name) > 0: savefig(save_name) key = jax.random.PRNGKey(0) x_train = jnp.array([[-4], [-3], [-2], [-1], [1]]) y_train = jnp.sin(x_train).flatten() x_test = jnp.arange(-5, 5.2, 0.2) data = {"noise": 0.0, "n_samples": 3, "x_train": x_train, "y_train": y_train, "x_test": x_test} sample_from_posterior(key, data, "gprDemoNoiseFreePost") data = {"noise": 0.3, "n_samples": 3, "x_train": x_train, "y_train": y_train, "x_test": x_test} sample_from_posterior(key, data, "gprDemoNoisyPost") ------------------ --------------------------------------------------------------------------- ValueError Traceback (most recent call last) /tmp/ipykernel_4161/2076173755.py in 26 27 data = {"noise": 0.0, "n_samples": 3, "x_train": x_train, "y_train": y_train, "x_test": x_test} ---> 28 sample_from_posterior(key, data, "gprDemoNoiseFreePost") 29 30 data = {"noise": 0.3, "n_samples": 3, "x_train": x_train, "y_train": y_train, "x_test": x_test} /tmp/ipykernel_4161/2076173755.py in sample_from_posterior(key, data, save_name) 2 kernel = kernels.ExpSquared(scale=1.0) 3 gp = GaussianProcess(kernel, data["x_train"], diag=1e-8 + data["noise"]) ----> 4 pred_gp = gp.condition(data["y_train"], data["x_test"]).gp 5 samples = pred_gp.sample(key, shape=(data["n_samples"],)) 6 mean = pred_gp.mean ~/miniconda3/envs/py37/lib/python3.7/site-packages/tinygp/gp.py in condition(self, y, X_test, diag, noise, include_mean, kernel) 181 if not jax.tree_util.tree_reduce(lambda a, b: a and b, matches): 182 raise ValueError( --> 183 "`X_test` must have the same tree structure as the input `X`, " 184 "and all but the leading dimension must have matching sizes" 185 ) ValueError: `X_test` must have the same tree structure as the input `X`, and all but the leading dimension must have matching sizes ValueError: `X_test` must have the same tree structure as the input `X`, and all but the leading dimension must have matching sizes