{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "4a34e4a6-2e42-407c-b8c8-872ab304be66", "metadata": {}, "outputs": [], "source": [ "import operator\n", "from functools import reduce\n", "from typing import List\n", "\n", "import arviz as az\n", "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import numpyro\n", "import numpyro.distributions as dist\n", "from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs\n", "from numpyro.infer.util import Predictive\n", "\n", "rng_key = jax.random.PRNGKey(2)" ] }, { "cell_type": "code", "execution_count": 2, "id": "f65d6260-c78c-4a69-9538-c3446ce87200", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2\n", "%load_ext watermark" ] }, { "cell_type": "code", "execution_count": 3, "id": "26551ebc-7208-4ebd-be9a-5bc18117eb8b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python implementation: CPython\n", "Python version : 3.8.11\n", "IPython version : 7.18.1\n", "\n", "arviz : 0.11.2\n", "jax : 0.2.19\n", "matplotlib: 3.4.3\n", "numpy : 1.20.3\n", "pandas : 1.3.2\n", "numpyro : 0.7.2\n", "\n", "Compiler : GCC 7.5.0\n", "OS : Linux\n", "Release : 4.19.193-1-MANJARO\n", "Machine : x86_64\n", "Processor : \n", "CPU cores : 4\n", "Architecture: 64bit\n", "\n" ] } ], "source": [ "%watermark -v -m -p arviz,jax,matplotlib,numpy,pandas,numpyro" ] }, { "cell_type": "code", "execution_count": 4, "id": "30fb2e78-4584-455d-b602-f72e9495c2eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Git hash: fea84b871dfddca7ae6efdf221a994d7b439e749\n", "\n", "Git branch: master\n", "\n" ] } ], "source": [ "%watermark -gb" ] }, { "cell_type": "code", "execution_count": 5, "id": "6305d599-e33a-4b6b-9b35-68ee1ae8e53d", "metadata": {}, "outputs": [], "source": [ "expected = pd.DataFrame(\n", " [\n", " (False, False, 
False, 0.101, 0.101),\n", "        (True, False, False, 0.802, 0.034),\n", "        (False, True, False, 0.034, 0.802),\n", "        (True, True, False, 0.561, 0.561),\n", "        (False, False, True, 0.148, 0.148),\n", "        (True, False, True, 0.862, 0.326),\n", "        (False, True, True, 0.326, 0.862),\n", "        (True, True, True, 0.946, 0.946),\n", "    ],\n", "    columns=[\"IsCorrect1\", \"IsCorrect2\", \"IsCorrect3\", \"P(csharp)\", \"P(sql)\"],\n", ")" ] }, { "cell_type": "markdown", "id": "e5a8a1de-8dc7-411c-8908-04c66b4aac68", "metadata": {}, "source": [ "# Purpose\n", "- Reproducing [`fritzo`'s answer](https://forum.pyro.ai/t/model-based-machine-learning-book-chapter-2-skills-example-in-pyro-tensor-dimension-issue/464/12?u=bdatko) to [Chapter 2 MBML Learning skills](https://mbmlbook.com/LearningSkills.html)\n", "\n", "The twist:\n", "1. we are using `numpyro.__version__ == 0.7.2` instead of `pyro.__version__ == 0.3`\n", "2. we assume a fixed guessing probability (i.e. we are building one of the first iterations of the model from the book)\n", "3. we reproduce the results for just three questions and two skills, using the model from [**Figure 2.17**](https://mbmlbook.com/LearningSkills_Moving_to_real_data.html) with [**Table 2.4**](https://mbmlbook.com/LearningSkills_Testing_out_the_model.html), reproduced below\n", "\n", "| | IsCorrect1 | IsCorrect2 | IsCorrect3 | P(csharp) | P(sql) |\n", "|---:|:-------------|:-------------|:-------------|------------:|---------:|\n", "| 0 | False | False | False | 0.101 | 0.101 |\n", "| 1 | True | False | False | 0.802 | 0.034 |\n", "| 2 | False | True | False | 0.034 | 0.802 |\n", "| 3 | True | True | False | 0.561 | 0.561 |\n", "| 4 | False | False | True | 0.148 | 0.148 |\n", "| 5 | True | False | True | 0.862 | 0.326 |\n", "| 6 | False | True | True | 0.326 | 0.862 |\n", "| 7 | True | True | True | 0.946 | 0.946 |\n", "\n", "The table above can be used to check our model, and to get us ready for the *real data*. 
Let's view each combination of answers (each row of the table) as a data record, resulting in a table of 3 responses from 8 people, where each question needs either skill 0 (`csharp`), skill 1 (`sql`), or both skill 0 and skill 1. The toy data is shown below:" ] }, { "cell_type": "code", "execution_count": 6, "id": "f23d59c6-5e66-40e4-8b92-335d8659a8bb", "metadata": {}, "outputs": [], "source": [ "responses_check = jnp.array([[0., 1., 0., 1., 0., 1., 0., 1.], [0., 0., 1., 1., 0., 0., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.]])\n", "skills_needed_check = [[0], [1], [0, 1]]" ] }, { "cell_type": "markdown", "id": "3bd538ce-e14e-4c89-ae87-197f67969c11", "metadata": {}, "source": [ "- I have been playing around with various models and inference engines\n", "- trying out iterations based on the discussion on the [Pyro forum](https://forum.pyro.ai/t/numpyro-chapter-2-mbml/3184?u=bdatko)" ] }, { "cell_type": "markdown", "id": "4ada8c1d-c0e2-46c7-b685-80887fe974cc", "metadata": {}, "source": [ "#### model_00\n", "* trying out the two for loops over skills\n", "* beta priors for skills" ] }, { "cell_type": "code", "execution_count": 7, "id": "310fdfdb-d803-4598-bb28-ef6f436f0069", "metadata": {}, "outputs": [], "source": [ "def model_00(\n", "    graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2\n", "):\n", "    n_questions, n_participants = graded_responses.shape\n", "    n_skills = max(map(max, skills_needed)) + 1\n", "    \n", "    participants_plate = numpyro.plate(\"participants_plate\", n_participants)\n", "    \n", "    with participants_plate:\n", "        with numpyro.plate(\"skills_plate\", n_skills):\n", "            theta = numpyro.sample(\"theta\", dist.Beta(1,1))\n", "    \n", "    skills = []\n", "    \n", "    for s in range(n_skills):\n", "        skills.append([])\n", "        for p in range(n_participants):\n", "            sample = numpyro.sample(\"skill_{}_{}\".format(s,p), dist.Bernoulli(theta[s,p]))\n", "            skills[s].append(sample.squeeze())\n", "    \n", "\n", "    for q in range(n_questions):\n", "        has_skills = reduce(operator.mul, 
[jnp.array(skills[i]) for i in skills_needed[q]])\n", " for p in range(n_participants):\n", " prob_correct = has_skills[p] * (1 - prob_mistake) + (1 - has_skills[p]) * prob_guess\n", " isCorrect = numpyro.sample(\"isCorrect_{}_{}\".format(q,p), dist.Bernoulli(prob_correct), obs=graded_responses[q,p],)" ] }, { "cell_type": "code", "execution_count": 8, "id": "e72de02f-49e8-4320-8ec2-61082afaed85", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|██████████| 1200/1200 [01:30<00:00, 13.31it/s, 7 steps of size 4.32e-01. acc. prob=0.90]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " skill_0_0 0.09 0.29 0.00 0.00 0.00 856.76 1.00\n", " skill_0_1 0.80 0.40 1.00 0.00 1.00 902.07 1.00\n", " skill_0_2 0.03 0.17 0.00 0.00 0.00 917.94 1.00\n", " skill_0_3 0.56 0.50 1.00 0.00 1.00 672.03 1.00\n", " skill_0_4 0.14 0.35 0.00 0.00 1.00 647.06 1.00\n", " skill_0_5 0.87 0.34 1.00 0.00 1.00 893.91 1.00\n", " skill_0_6 0.32 0.47 0.00 0.00 1.00 916.99 1.00\n", " skill_0_7 0.94 0.24 1.00 1.00 1.00 1029.96 1.00\n", " skill_1_0 0.10 0.30 0.00 0.00 1.00 900.22 1.00\n", " skill_1_1 0.03 0.18 0.00 0.00 0.00 758.48 1.00\n", " skill_1_2 0.80 0.40 1.00 0.00 1.00 604.73 1.00\n", " skill_1_3 0.54 0.50 1.00 0.00 1.00 622.10 1.00\n", " skill_1_4 0.15 0.36 0.00 0.00 1.00 964.93 1.00\n", " skill_1_5 0.34 0.47 0.00 0.00 1.00 979.84 1.00\n", " skill_1_6 0.86 0.35 1.00 0.00 1.00 821.40 1.00\n", " skill_1_7 0.94 0.23 1.00 1.00 1.00 1048.92 1.00\n", " theta[0,0] 0.37 0.25 0.34 0.00 0.75 1164.94 1.00\n", " theta[0,1] 0.60 0.28 0.64 0.19 1.00 918.45 1.00\n", " theta[0,2] 0.34 0.23 0.30 0.00 0.70 1364.64 1.00\n", " theta[0,3] 0.52 0.29 0.53 0.11 1.00 729.96 1.00\n", " theta[0,4] 0.37 0.26 0.33 0.00 0.79 723.06 1.00\n", " theta[0,5] 0.63 0.26 0.67 0.24 1.00 1217.48 1.00\n", " theta[0,6] 0.44 0.28 0.40 0.01 0.85 854.28 1.00\n", " theta[0,7] 0.63 0.24 0.65 0.26 0.98 751.65 1.01\n", " theta[1,0] 
0.37 0.26 0.34 0.00 0.76 915.23 1.00\n", " theta[1,1] 0.36 0.25 0.32 0.00 0.73 827.01 1.00\n", " theta[1,2] 0.61 0.28 0.67 0.18 1.00 530.27 1.00\n", " theta[1,3] 0.50 0.29 0.50 0.06 0.95 799.04 1.00\n", " theta[1,4] 0.39 0.27 0.36 0.00 0.81 1397.22 1.00\n", " theta[1,5] 0.45 0.29 0.41 0.02 0.90 843.83 1.00\n", " theta[1,6] 0.62 0.26 0.67 0.23 1.00 1060.55 1.00\n", " theta[1,7] 0.66 0.24 0.70 0.30 1.00 1182.22 1.00\n", "\n" ] } ], "source": [ "nuts_kernel = NUTS(model_00)\n", "\n", "kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)\n", "\n", "mcmc = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=1)\n", "mcmc.run(rng_key, responses_check, skills_needed_check)\n", "mcmc.print_summary()" ] }, { "cell_type": "code", "execution_count": 9, "id": "10300a42-5cd4-4544-9d8a-21736e27772f", "metadata": {}, "outputs": [], "source": [ "# Look the skill sites up by name instead of slicing the keys of the samples dict by position;\n", "# positional slicing silently picks the wrong sites if the set or order of sites changes (e.g. `theta`).\n", "samples_00 = mcmc.get_samples()\n", "n_check = responses_check.shape[1]\n", "expected[\"model_00 P(csharp)\"] = [samples_00[\"skill_0_{}\".format(p)].mean() for p in range(n_check)]\n", "expected[\"model_00 P(sql)\"] = [samples_00[\"skill_1_{}\".format(p)].mean() for p in range(n_check)]" ] }, { "cell_type": "markdown", "id": "a3075f1a-3893-4ae2-a1cc-7f51666f5299", "metadata": {}, "source": [ "* the code below results in an AssertionError\n", "* trying to use `infer_discrete` without NUTS and MCMC\n", "* probably b/c of the beta priors? 
I am not sure though\n", "\n", "```python\n", "predictive = Predictive(\n", " model_00,\n", " num_samples=1000,\n", " infer_discrete=True,\n", ")\n", "discrete_samples = predictive(rng_key, responses_check, skills_needed_check)\n", "```\n", "\n", "\n", "```python\n", "AssertionError Traceback (most recent call last)\n", " in \n", " 4 infer_discrete=True,\n", " 5 )\n", "----> 6 discrete_samples = predictive(rng_key, responses_check, skills_needed_check)\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)\n", " 892 )\n", " 893 model = substitute(self.model, self.params)\n", "--> 894 return _predictive(\n", " 895 rng_key,\n", " 896 model,\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)\n", " 737 rng_key = rng_key.reshape(batch_shape + (2,))\n", " 738 chunk_size = num_samples if parallel else 1\n", "--> 739 return soft_vmap(\n", " 740 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size\n", " 741 )\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/util.py in soft_vmap(fn, xs, batch_ndims, chunk_size)\n", " 403 fn = vmap(fn)\n", " 404 \n", "--> 405 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)\n", " 406 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)\n", " 407 ys = tree_map(\n", "\n", " [... 
skipping hidden 15 frame]\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in single_prediction(val)\n", " 702 model_trace = prototype_trace\n", " 703 temperature = 1\n", "--> 704 pred_samples = _sample_posterior(\n", " 705 config_enumerate(condition(model, samples)),\n", " 706 first_available_dim,\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/contrib/funsor/discrete.py in _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs)\n", " 60 with funsor.adjoint.AdjointTape() as tape:\n", " 61 with block(), enum(first_available_dim=first_available_dim):\n", "---> 62 log_prob, model_tr, log_measures = _enum_log_density(\n", " 63 model, args, kwargs, {}, sum_op, prod_op\n", " 64 )\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)\n", " 157 model = substitute(model, data=params)\n", " 158 with plate_to_enum_plate():\n", "--> 159 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)\n", " 160 log_factors = []\n", " 161 time_to_factors = defaultdict(list) # log prob factors\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)\n", " 163 :return: `OrderedDict` containing the execution trace.\n", " 164 \"\"\"\n", "--> 165 self(*args, **kwargs)\n", " 166 return self.trace\n", " 167 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", 
"~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", " in model_00(graded_responses, skills_needed, prob_mistake, prob_guess)\n", " 9 with participants_plate:\n", " 10 with numpyro.plate(\"skills_plate\", n_skills):\n", "---> 11 theta = numpyro.sample(\"theta\", dist.Beta(1,1))\n", " 12 \n", " 13 skills = []\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)\n", " 157 \n", " 158 # ...and use apply_stack to send it to the Messengers\n", "--> 159 msg = apply_stack(initial_msg)\n", " 160 return msg[\"value\"]\n", " 161 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in apply_stack(msg)\n", " 29 if msg[\"value\"] is None:\n", " 30 if msg[\"type\"] == \"sample\":\n", "---> 31 msg[\"value\"], msg[\"intermediates\"] = msg[\"fn\"](\n", " 32 *msg[\"args\"], sample_intermediates=True, **msg[\"kwargs\"]\n", " 33 )\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/distribution.py in __call__(self, *args, **kwargs)\n", " 300 sample_intermediates = kwargs.pop(\"sample_intermediates\", False)\n", " 301 if sample_intermediates:\n", "--> 302 return self.sample_with_intermediates(key, *args, **kwargs)\n", " 303 return self.sample(key, *args, **kwargs)\n", " 304 \n", 
"\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/distribution.py in sample_with_intermediates(self, key, sample_shape)\n", " 573 \n", " 574 def sample_with_intermediates(self, key, sample_shape=()):\n", "--> 575 return self._sample(self.base_dist.sample_with_intermediates, key, sample_shape)\n", " 576 \n", " 577 def sample(self, key, sample_shape=()):\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/distribution.py in _sample(self, sample_fn, key, sample_shape)\n", " 532 batch_shape = expanded_sizes + interstitial_sizes\n", " 533 # shape = sample_shape + expanded_sizes + interstitial_sizes + base_dist.shape()\n", "--> 534 samples, intermediates = sample_fn(key, sample_shape=sample_shape + batch_shape)\n", " 535 \n", " 536 interstitial_dims = tuple(self._interstitial_sizes.keys())\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/distribution.py in sample_with_intermediates(self, key, sample_shape)\n", " 259 :rtype: numpy.ndarray\n", " 260 \"\"\"\n", "--> 261 return self.sample(key, sample_shape=sample_shape), []\n", " 262 \n", " 263 def log_prob(self, value):\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/continuous.py in sample(self, key, sample_shape)\n", " 79 \n", " 80 def sample(self, key, sample_shape=()):\n", "---> 81 assert is_prng_key(key)\n", " 82 return self._dirichlet.sample(key, sample_shape)[..., 0]\n", " 83 \n", "\n", "AssertionError: \n", "```" ] }, { "cell_type": "markdown", "id": "d3d8470e-a381-42da-9a1c-153825a910cf", "metadata": {}, "source": [ "#### model_01\n", "* trying out the two for loops over skills, suggested [here](https://forum.pyro.ai/t/numpyro-chapter-2-mbml/3184/2?u=bdatko) and again [here](https://forum.pyro.ai/t/numpyro-chapter-2-mbml/3184/6?u=bdatko)\n", "* removing beta priors for skills, more like the book" ] }, { "cell_type": "code", "execution_count": 10, "id": 
"d79fa480-61c4-484c-9fd8-ccd1b83891c5", "metadata": {}, "outputs": [], "source": [ "def model_01a(\n", " graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2\n", "):\n", " n_questions, n_participants = graded_responses.shape\n", " n_skills = max(map(max, skills_needed)) + 1\n", " \n", " participants_plate = numpyro.plate(\"participants_plate\", n_participants)\n", " \n", " skills = []\n", " \n", " for s in range(n_skills):\n", " skills.append([])\n", " for p in range(n_participants):\n", " sample = numpyro.sample(\"skill_{}_{}\".format(s,p), dist.Bernoulli(0.5))\n", " skills[s].append(sample.squeeze())\n", " \n", "\n", " for q in range(n_questions):\n", " has_skills = reduce(operator.mul, [jnp.array(skills[i]) for i in skills_needed[q]])\n", " for p in range(n_participants):\n", " prob_correct = has_skills[p] * (1 - prob_mistake) + (1 - has_skills[p]) * prob_guess\n", " isCorrect = numpyro.sample(\"isCorrect_{}_{}\".format(q,p), dist.Bernoulli(prob_correct), obs=graded_responses[q,p],)" ] }, { "cell_type": "code", "execution_count": 11, "id": "f953f845-eb33-4951-9b6a-324019631330", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|██████████| 1200/1200 [01:10<00:00, 17.00it/s, 1 steps of size 1.19e+37. acc. 
prob=1.00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", " skill_0_0 0.10 0.30 0.00 0.00 0.00 1084.04 1.00\n", " skill_0_1 0.80 0.40 1.00 0.00 1.00 1625.57 1.00\n", " skill_0_2 0.03 0.18 0.00 0.00 0.00 1029.71 1.00\n", " skill_0_3 0.55 0.50 1.00 0.00 1.00 1126.00 1.00\n", " skill_0_4 0.15 0.36 0.00 0.00 1.00 1129.83 1.00\n", " skill_0_5 0.86 0.35 1.00 0.00 1.00 1472.38 1.00\n", " skill_0_6 0.33 0.47 0.00 0.00 1.00 3153.84 1.00\n", " skill_0_7 0.93 0.26 1.00 1.00 1.00 953.76 1.00\n", " skill_1_0 0.09 0.29 0.00 0.00 0.00 1250.97 1.00\n", " skill_1_1 0.03 0.17 0.00 0.00 0.00 992.63 1.00\n", " skill_1_2 0.80 0.40 1.00 0.00 1.00 1217.17 1.00\n", " skill_1_3 0.55 0.50 1.00 0.00 1.00 1138.18 1.00\n", " skill_1_4 0.15 0.36 0.00 0.00 1.00 1289.65 1.00\n", " skill_1_5 0.33 0.47 0.00 0.00 1.00 2924.36 1.00\n", " skill_1_6 0.86 0.34 1.00 0.00 1.00 1191.89 1.00\n", " skill_1_7 0.93 0.25 1.00 1.00 1.00 1174.43 1.00\n", "\n" ] } ], "source": [ "nuts_kernel = NUTS(model_01a)\n", "\n", "kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)\n", "\n", "mcmc = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=1)\n", "mcmc.run(rng_key, responses_check, skills_needed_check)\n", "mcmc.print_summary()" ] }, { "cell_type": "code", "execution_count": 12, "id": "b2e92a35-3c0b-4ca9-ba34-9a261249bb7b", "metadata": {}, "outputs": [], "source": [ "expected[\"model_01a P(csharp)\"] = [mcmc.get_samples()[key].mean() for key in list(mcmc.get_samples().keys())[:8]]\n", "expected[\"model_01a P(sql)\"] = [mcmc.get_samples()[key].mean() for key in list(mcmc.get_samples().keys())[8:]]" ] }, { "cell_type": "code", "execution_count": 13, "id": "8711f0d2-78b2-4fdc-98a2-a618b82fa41b", "metadata": {}, "outputs": [], "source": [ "def model_01b(\n", " graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2\n", "):\n", " n_questions, n_participants = graded_responses.shape\n", " n_skills = max(map(max, 
skills_needed)) + 1\n", "    \n", "    participants_plate = numpyro.plate(\"participants_plate\", n_participants)\n", "    \n", "    skills = []\n", "    \n", "    for s in range(n_skills):\n", "        skills.append([])\n", "        for p in range(n_participants):\n", "            sample = numpyro.sample(\"skill_{}_{}\".format(s,p), dist.Bernoulli(0.5))\n", "            skills[s].append(sample.squeeze())\n", "    \n", "\n", "    for q in range(n_questions):\n", "        has_skills = reduce(operator.mul, [jnp.array(skills[i]) for i in skills_needed[q]])\n", "        for p in range(n_participants):\n", "            prob_correct = has_skills[p] * (1 - prob_mistake) + (1 - has_skills[p]) * prob_guess\n", "            isCorrect = numpyro.sample(\"isCorrect_{}_{}\".format(q,p), dist.Bernoulli(prob_correct), obs=graded_responses[q,p],)" ] }, { "cell_type": "code", "execution_count": 14, "id": "cafa0404-4592-43ae-9c8b-15675b825400", "metadata": {}, "outputs": [], "source": [ "predictive = Predictive(\n", "    model_01b,\n", "    num_samples=3000,\n", "    infer_discrete=True,\n", ")\n", "discrete_samples = predictive(rng_key, responses_check, skills_needed_check)" ] }, { "cell_type": "code", "execution_count": 15, "id": "4f3ffa23-1ce1-42e1-82c9-2009f609b49e", "metadata": {}, "outputs": [], "source": [ "# Look the skill sites up by name rather than by position in the key order\n", "# (a positional [24:32] / [32:] slice only works if the isCorrect_* sites happen to sort first).\n", "n_check = responses_check.shape[1]\n", "expected[\"model_01b P(csharp)\"] = [discrete_samples[\"skill_0_{}\".format(p)].mean() for p in range(n_check)]\n", "expected[\"model_01b P(sql)\"] = [discrete_samples[\"skill_1_{}\".format(p)].mean() for p in range(n_check)]" ] }, { "cell_type": "markdown", "id": "b73273ab-4656-4bc2-8f65-488987f1ebda", "metadata": {}, "source": [ "#### model_02\n", "* trying not to use the double for loop, which is slow\n", "* beta priors for skills\n", "* this model is very similar to the original post on the forum from [`fritzo`'s answer](https://forum.pyro.ai/t/model-based-machine-learning-book-chapter-2-skills-example-in-pyro-tensor-dimension-issue/464/12?u=bdatko) " ] }, { "cell_type": "code", "execution_count": 16, "id": "01a9a7ce-7c2c-485a-9342-061f42f03eed", "metadata": {}, "outputs": 
[], "source": [ "def model_02(\n", " graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2\n", "):\n", " n_questions, n_participants = graded_responses.shape\n", " n_skills = max(map(max, skills_needed)) + 1\n", " \n", " participants_plate = numpyro.plate(\"participants_plate\", n_participants)\n", " \n", " with participants_plate:\n", " with numpyro.plate(\"skills_plate\", n_skills):\n", " theta = numpyro.sample(\"theta\", dist.Beta(1,1))\n", " \n", " with participants_plate:\n", " skills = []\n", " for s in range(n_skills):\n", " skills.append(numpyro.sample(\"skill_{}\".format(s), dist.Bernoulli(theta[s])))\n", "\n", " for q in range(n_questions):\n", " has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])\n", " prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess\n", " isCorrect = numpyro.sample(\n", " \"isCorrect_{}\".format(q),\n", " dist.Bernoulli(prob_correct).to_event(1),\n", " obs=graded_responses[q],\n", " )" ] }, { "cell_type": "code", "execution_count": 17, "id": "cb067b12-4fba-4302-b0f5-059a4bf16a37", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|██████████| 1200/1200 [00:11<00:00, 108.14it/s, 7 steps of size 4.12e-01. acc. 
prob=0.91] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", "skill_0[0] 0.09 0.28 0.00 0.00 0.00 916.08 1.00\n", "skill_0[1] 0.81 0.40 1.00 0.00 1.00 811.18 1.00\n", "skill_0[2] 0.03 0.17 0.00 0.00 0.00 771.79 1.00\n", "skill_0[3] 0.58 0.49 1.00 0.00 1.00 447.39 1.00\n", "skill_0[4] 0.13 0.34 0.00 0.00 1.00 593.39 1.00\n", "skill_0[5] 0.86 0.35 1.00 0.00 1.00 734.09 1.00\n", "skill_0[6] 0.33 0.47 0.00 0.00 1.00 973.74 1.00\n", "skill_0[7] 0.94 0.23 1.00 1.00 1.00 1052.75 1.00\n", "skill_1[0] 0.10 0.30 0.00 0.00 1.00 950.95 1.00\n", "skill_1[1] 0.04 0.18 0.00 0.00 0.00 724.80 1.00\n", "skill_1[2] 0.79 0.41 1.00 0.00 1.00 861.81 1.00\n", "skill_1[3] 0.53 0.50 1.00 0.00 1.00 676.18 1.00\n", "skill_1[4] 0.15 0.36 0.00 0.00 1.00 738.70 1.00\n", "skill_1[5] 0.33 0.47 0.00 0.00 1.00 1071.68 1.00\n", "skill_1[6] 0.86 0.35 1.00 0.00 1.00 980.21 1.00\n", "skill_1[7] 0.95 0.23 1.00 1.00 1.00 959.35 1.00\n", "theta[0,0] 0.36 0.25 0.32 0.00 0.74 1065.44 1.00\n", "theta[0,1] 0.61 0.27 0.65 0.19 1.00 851.69 1.00\n", "theta[0,2] 0.35 0.24 0.31 0.00 0.70 1030.36 1.00\n", "theta[0,3] 0.53 0.30 0.54 0.09 0.99 331.72 1.00\n", "theta[0,4] 0.37 0.25 0.34 0.00 0.77 589.42 1.00\n", "theta[0,5] 0.62 0.27 0.67 0.20 1.00 660.04 1.00\n", "theta[0,6] 0.44 0.29 0.43 0.00 0.86 828.42 1.00\n", "theta[0,7] 0.63 0.24 0.66 0.25 0.98 826.83 1.00\n", "theta[1,0] 0.37 0.26 0.33 0.00 0.76 906.31 1.00\n", "theta[1,1] 0.36 0.25 0.33 0.00 0.73 825.56 1.00\n", "theta[1,2] 0.61 0.28 0.64 0.18 1.00 951.56 1.00\n", "theta[1,3] 0.50 0.29 0.51 0.11 1.00 872.60 1.00\n", "theta[1,4] 0.38 0.26 0.35 0.00 0.77 1007.59 1.00\n", "theta[1,5] 0.45 0.29 0.43 0.00 0.86 802.19 1.00\n", "theta[1,6] 0.62 0.26 0.66 0.22 1.00 1071.47 1.00\n", "theta[1,7] 0.66 0.24 0.71 0.30 1.00 1275.90 1.00\n", "\n" ] } ], "source": [ "nuts_kernel = NUTS(model_02)\n", "\n", "kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)\n", "\n", "mcmc = MCMC(kernel, num_warmup=200, 
num_samples=1000, num_chains=1)\n", "mcmc.run(rng_key, responses_check, skills_needed_check)\n", "mcmc.print_summary()" ] }, { "cell_type": "code", "execution_count": 18, "id": "b3962c2a-9c60-40d5-9137-c8c015343dce", "metadata": {}, "outputs": [], "source": [ "expected[\"model_02 P(csharp)\"] = mcmc.get_samples(group_by_chain=False)[\"skill_0\"].mean(0)\n", "expected[\"model_02 P(sql)\"] = mcmc.get_samples(group_by_chain=False)[\"skill_1\"].mean(0)" ] }, { "cell_type": "markdown", "id": "7e3ba1d9-67d7-465f-81e5-6f69fed13287", "metadata": {}, "source": [ "* below the code results in an AssertionError\n", "* trying using `infer_discrete` without NUTS and MCMC\n", "* probably b/c of the beta priors? I am not sure though\n", "\n", "```python\n", "predictive = Predictive(\n", " model_02,\n", " num_samples=3000,\n", " infer_discrete=True,\n", ")\n", "discrete_samples = predictive(rng_key, responses_check, skills_needed_check)\n", "```\n", "\n", "\n", "```python\n", "AssertionError Traceback (most recent call last)\n", " in \n", " 4 infer_discrete=True,\n", " 5 )\n", "----> 6 discrete_samples = predictive(rng_key, responses_check, skills_needed_check)\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)\n", " 892 )\n", " 893 model = substitute(self.model, self.params)\n", "--> 894 return _predictive(\n", " 895 rng_key,\n", " 896 model,\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)\n", " 737 rng_key = rng_key.reshape(batch_shape + (2,))\n", " 738 chunk_size = num_samples if parallel else 1\n", "--> 739 return soft_vmap(\n", " 740 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size\n", " 741 )\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/util.py in soft_vmap(fn, xs, 
batch_ndims, chunk_size)\n", " 403 fn = vmap(fn)\n", " 404 \n", "--> 405 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)\n", " 406 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)\n", " 407 ys = tree_map(\n", "\n", " [... skipping hidden 15 frame]\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in single_prediction(val)\n", " 702 model_trace = prototype_trace\n", " 703 temperature = 1\n", "--> 704 pred_samples = _sample_posterior(\n", " 705 config_enumerate(condition(model, samples)),\n", " 706 first_available_dim,\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/contrib/funsor/discrete.py in _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs)\n", " 60 with funsor.adjoint.AdjointTape() as tape:\n", " 61 with block(), enum(first_available_dim=first_available_dim):\n", "---> 62 log_prob, model_tr, log_measures = _enum_log_density(\n", " 63 model, args, kwargs, {}, sum_op, prod_op\n", " 64 )\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)\n", " 157 model = substitute(model, data=params)\n", " 158 with plate_to_enum_plate():\n", "--> 159 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)\n", " 160 log_factors = []\n", " 161 time_to_factors = defaultdict(list) # log prob factors\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)\n", " 163 :return: `OrderedDict` containing the execution trace.\n", " 164 \"\"\"\n", "--> 165 self(*args, **kwargs)\n", " 166 return self.trace\n", " 167 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", 
"~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)\n", " 85 return self\n", " 86 with self:\n", "---> 87 return self.fn(*args, **kwargs)\n", " 88 \n", " 89 \n", "\n", " in model_02(graded_responses, skills_needed, prob_mistake, prob_guess)\n", " 9 with participants_plate:\n", " 10 with numpyro.plate(\"skills_plate\", n_skills):\n", "---> 11 theta = numpyro.sample(\"theta\", dist.Beta(1,1))\n", " 12 \n", " 13 with participants_plate:\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)\n", " 157 \n", " 158 # ...and use apply_stack to send it to the Messengers\n", "--> 159 msg = apply_stack(initial_msg)\n", " 160 return msg[\"value\"]\n", " 161 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in apply_stack(msg)\n", " 29 if msg[\"value\"] is None:\n", " 30 if msg[\"type\"] == \"sample\":\n", "---> 31 msg[\"value\"], msg[\"intermediates\"] = msg[\"fn\"](\n", " 32 *msg[\"args\"], sample_intermediates=True, **msg[\"kwargs\"]\n", " 33 )\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/distribution.py in __call__(self, *args, **kwargs)\n", " 
300 sample_intermediates = kwargs.pop(\"sample_intermediates\", False)\n", " 301 if sample_intermediates:\n", "--> 302 return self.sample_with_intermediates(key, *args, **kwargs)\n", " 303 return self.sample(key, *args, **kwargs)\n", " 304 \n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/distribution.py in sample_with_intermediates(self, key, sample_shape)\n", " 573 \n", " 574 def sample_with_intermediates(self, key, sample_shape=()):\n", "--> 575 return self._sample(self.base_dist.sample_with_intermediates, key, sample_shape)\n", " 576 \n", " 577 def sample(self, key, sample_shape=()):\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/distribution.py in _sample(self, sample_fn, key, sample_shape)\n", " 532 batch_shape = expanded_sizes + interstitial_sizes\n", " 533 # shape = sample_shape + expanded_sizes + interstitial_sizes + base_dist.shape()\n", "--> 534 samples, intermediates = sample_fn(key, sample_shape=sample_shape + batch_shape)\n", " 535 \n", " 536 interstitial_dims = tuple(self._interstitial_sizes.keys())\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/distribution.py in sample_with_intermediates(self, key, sample_shape)\n", " 259 :rtype: numpy.ndarray\n", " 260 \"\"\"\n", "--> 261 return self.sample(key, sample_shape=sample_shape), []\n", " 262 \n", " 263 def log_prob(self, value):\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/distributions/continuous.py in sample(self, key, sample_shape)\n", " 79 \n", " 80 def sample(self, key, sample_shape=()):\n", "---> 81 assert is_prng_key(key)\n", " 82 return self._dirichlet.sample(key, sample_shape)[..., 0]\n", " 83 \n", "\n", "AssertionError: \n", "```" ] }, { "cell_type": "markdown", "id": "33c33b23-154d-482d-b1a9-5838b8628817", "metadata": {}, "source": [ "#### model_03\n", "* trying not to use the doulbe for loop, so slow\n", "* removing beta priors for 
skills, most similar to the book\n", "* this model is very similar to the original post on the forum from [`fritzo`'s answer](https://forum.pyro.ai/t/model-based-machine-learning-book-chapter-2-skills-example-in-pyro-tensor-dimension-issue/464/12?u=bdatko) " ] }, { "cell_type": "code", "execution_count": 19, "id": "8cc9706c-e84d-439d-b47b-16d42efdaa35", "metadata": {}, "outputs": [], "source": [ "def model_03(\n", " graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2\n", "):\n", " n_questions, n_participants = graded_responses.shape\n", " n_skills = max(map(max, skills_needed)) + 1\n", " \n", " participants_plate = numpyro.plate(\"participants_plate\", n_participants)\n", " \n", " with participants_plate:\n", " skills = []\n", " for s in range(n_skills):\n", " skills.append(numpyro.sample(\"skill_{}\".format(s), dist.Bernoulli(0.5)))\n", "\n", " for q in range(n_questions):\n", " has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])\n", " prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess\n", " isCorrect = numpyro.sample(\n", " \"isCorrect_{}\".format(q),\n", " dist.Bernoulli(prob_correct).to_event(1),\n", " obs=graded_responses[q],\n", " )" ] }, { "cell_type": "code", "execution_count": 20, "id": "23bdddb1-4f8c-40d6-98d1-a3f9fdbc1d69", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "sample: 100%|██████████| 1200/1200 [00:09<00:00, 126.70it/s, 1 steps of size 1.19e+37. acc. 
prob=1.00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.0% 95.0% n_eff r_hat\n", "skill_0[0] 0.10 0.30 0.00 0.00 0.00 1084.04 1.00\n", "skill_0[1] 0.80 0.40 1.00 0.00 1.00 1625.57 1.00\n", "skill_0[2] 0.03 0.18 0.00 0.00 0.00 1029.71 1.00\n", "skill_0[3] 0.55 0.50 1.00 0.00 1.00 1126.00 1.00\n", "skill_0[4] 0.15 0.36 0.00 0.00 1.00 1129.83 1.00\n", "skill_0[5] 0.86 0.35 1.00 0.00 1.00 1472.38 1.00\n", "skill_0[6] 0.33 0.47 0.00 0.00 1.00 3153.84 1.00\n", "skill_0[7] 0.93 0.26 1.00 1.00 1.00 953.76 1.00\n", "skill_1[0] 0.09 0.29 0.00 0.00 0.00 1250.97 1.00\n", "skill_1[1] 0.03 0.17 0.00 0.00 0.00 992.63 1.00\n", "skill_1[2] 0.80 0.40 1.00 0.00 1.00 1217.17 1.00\n", "skill_1[3] 0.55 0.50 1.00 0.00 1.00 1138.18 1.00\n", "skill_1[4] 0.15 0.36 0.00 0.00 1.00 1289.65 1.00\n", "skill_1[5] 0.33 0.47 0.00 0.00 1.00 2924.36 1.00\n", "skill_1[6] 0.86 0.34 1.00 0.00 1.00 1191.89 1.00\n", "skill_1[7] 0.93 0.25 1.00 1.00 1.00 1174.43 1.00\n", "\n" ] } ], "source": [ "nuts_kernel = NUTS(model_03)\n", "\n", "kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)\n", "\n", "mcmc = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=1)\n", "mcmc.run(rng_key, responses_check, skills_needed_check)\n", "mcmc.print_summary()" ] }, { "cell_type": "code", "execution_count": 21, "id": "7257c583-33fb-4361-98bc-f753b588ceb6", "metadata": {}, "outputs": [], "source": [ "expected[\"model_03 P(csharp)\"] = mcmc.get_samples(group_by_chain=False)[\"skill_0\"].mean(0)\n", "expected[\"model_03 P(sql)\"] = mcmc.get_samples(group_by_chain=False)[\"skill_1\"].mean(0)" ] }, { "cell_type": "markdown", "id": "14c8ffde-2a07-472a-8086-e43a5978e60f", "metadata": {}, "source": [ "#### model_04\n", "* trying using SVI as suggested [here](https://forum.pyro.ai/t/numpyro-chapter-2-mbml/3184/2?u=bdatko) and again [here](https://forum.pyro.ai/t/numpyro-chapter-2-mbml/3184/5?u=bdatko)\n", "* removed beta priors for skills, most similar to the book\n", "* this model 
is very similar to the original post on the forum from [`fritzo`'s answer](https://forum.pyro.ai/t/model-based-machine-learning-book-chapter-2-skills-example-in-pyro-tensor-dimension-issue/464/12?u=bdatko) " ] }, { "cell_type": "code", "execution_count": 22, "id": "21dfd601-9d44-4f19-83e0-f1c54acb123e", "metadata": {}, "outputs": [], "source": [ "from numpyro.infer import SVI, TraceGraph_ELBO" ] }, { "cell_type": "code", "execution_count": 23, "id": "21ccdef4-5b66-4f2f-9e43-ea71cd558b04", "metadata": {}, "outputs": [], "source": [ "def model_04(\n", " graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2\n", "):\n", " n_questions, n_participants = graded_responses.shape\n", " n_skills = max(map(max, skills_needed)) + 1\n", "\n", " with numpyro.plate(\"participants_plate\", n_participants):\n", " with numpyro.plate(\"skills_plate\", n_skills):\n", " skills = numpyro.sample(\n", " \"skills\", dist.Bernoulli(0.5), infer={\"enumerate\": \"parallel\"}\n", " )\n", "\n", " for q in range(n_questions):\n", " has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])\n", " prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess\n", " isCorrect = numpyro.sample(\n", " \"isCorrect_{}\".format(q),\n", " dist.Bernoulli(prob_correct).to_event(1),\n", " obs=graded_responses[q],\n", " )\n", "\n", "\n", "def guide_04(\n", " graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2\n", "):\n", " _, n_participants = graded_responses.shape\n", " n_skills = max(map(max, skills_needed)) + 1\n", "\n", " skill_p = numpyro.param(\n", " \"skill_p\",\n", " 0.5 * jnp.ones((n_skills, n_participants)),\n", " constraint=dist.constraints.unit_interval,\n", " )\n", "\n", " with numpyro.plate(\"participants_plate\", n_participants):\n", " with numpyro.plate(\"skills_plate\", n_skills):\n", " skills = numpyro.sample(\"skills\", dist.Bernoulli(skill_p))\n", "\n", " return skills, skill_p" ] }, { "cell_type": 
"code", "execution_count": 24, "id": "7d0527fe-72eb-400a-bf08-67a7e28a9689", "metadata": {}, "outputs": [], "source": [ "optimizer = numpyro.optim.Adam(step_size=0.05)\n", "\n", "svi = SVI(model_04, guide_04, optimizer, loss=TraceGraph_ELBO())" ] }, { "cell_type": "code", "execution_count": 25, "id": "0a4a2b2b-e66c-41ee-a271-836ec050fa72", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 10000/10000 [00:03<00:00, 3144.86it/s, init loss: 21.2843, avg. loss [9501-10000]: 19.4706]\n" ] } ], "source": [ "svi_result = svi.run(rng_key, 10_000, responses_check, skills_needed_check)" ] }, { "cell_type": "code", "execution_count": 26, "id": "588de0fc-3bee-4d7f-b690-7e3dfc475764", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.02325671, 0.7117256 , 0.13351652, 0.4025505 , 0.04318797,\n", " 0.92429334, 0.04476254, 0.97155374],\n", " [0.05118924, 0.01298811, 0.8495267 , 0.76294196, 0.07470088,\n", " 0.8607131 , 0.90841603, 0.959564 ]], dtype=float32)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# params, state, losses\n", "svi_result.params[\"skill_p\"]" ] }, { "cell_type": "code", "execution_count": 27, "id": "a59fe6be-c312-4438-b76b-4432d42b7032", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYgAAAD4CAYAAAD2FnFTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAuBElEQVR4nO3dd3wUZf4H8M83BUJC6L2GGgggvQvSpah4nnpiPxT17GeFH9azHOep51lOz3aeiqCingpIR3oRFAQEJPTQQpHQSdnn98eW7G5mdmd2Z3d2N5/368WLZHbKM9nd+T79EaUUiIiI/CXZnQAiIopNDBBERKSJAYKIiDQxQBARkSYGCCIi0pRidwKsUqtWLZWVlWV3MoiI4sratWuPKKVqa72WMAEiKysLa9assTsZRERxRUR2673GKiYiItLEAEFERJoYIIiISBMDBBERaWKAICIiTQwQRESkiQGCiIg0MUD4WZZ7BDuPnLY7GUREtkuYgXJWue7dVQCAXZNG2ZwSIiJ7sQRBRESaGCCIiEgTAwSAaWvzsDz3iN3JICKKKWyDAPDQ5+sBsN2BiMgbSxABXPT3hXjg03V2J4OIyBYMEAHsPnoGX/60z+5kEBHZggGCiIg0MUAQEZEmBggvZwqL7U4CEVHMYIDwkvPEbLuTQEQUMxggDChxKExetRtFJQ67k0JEFDUMEAZ8+sNeTPxqI95dstPupBARRQ0DhI78k+c8P/92phAAcPxsoV3JISKKOgYIHYNfWuT5+fUFuTamhIjIHgwQOk6eK+3RdLaoxMaUEBHZgwGCiIg0MUCYoexOABFR9DBAEBGRJgYIIiLSxAARhvV7jyPvtzN2J4OIKCK4YJAZ4vvr6DeWAeBCQ0SUmFiCMION1ERUjsR0gBCR5iLynohMszstRETlTdAAISJpIrJaRNaLyCYReTrUi4nI+yKSLyIbNV4bLiJbRSRXRMYDgFJqh1LqllCvF00b9xXgpTlb7U4GEZFljJQgzgMYpJTqCKATgOEi0st7BxGpIyKZfttaapzrAwDD/TeKSDKANwCMAJADYIyI5Bi5gVgx+o1leG1BLrYfPoUPlnFSPyKKf0EDhHI65fo11fXPvzb+IgBfi0gaAIjIOACvapxrMYBjGpfpASDXVWIoBDAVwGgjNyAil4rI2wUFBUZ2jxiHcv5JrnxzOZ769hecL+b0HEQU3wy1QYhIsoisA5APYK5SapX360qpzwHMAjBVRK4DMBbA1SbS0RDAXq/f8wA0FJGaIvIWgM4iMkHrQKXUt0qp26pWrWricqHRa6Pef/wsXPEBBWeLIp4OIqJoMNTNVSlVAqCTiFQD8JWItFdKbfTb5wURmQrgTQAtvEodRojGNqWUOgrgDhPnscXj/yv9U7iDiGKPJyKKc6Z6MSmljgP4HtrtCP0AtAfwFYAnTaYjD0Bjr98bAdhv8hwRpxXF/DEwEFGiMNKLqbar5AARqQRgCIAtfvt0BvAOnO0GfwRQQ0SeNZGOHwC0EpFmIlIBwDUAvjFxfMwRI9GEiCiGGSlB1AewUER+hvNBPlcpNd1vn3QAVymltiulHABuArDb/0QiMgXACgDZIpInIrcAgFKqGMDdAGYD2AzgM6XUplBvKlJOnS+Gw2GsiMCSBBHFu6BtEEqpnwF0DrLPMr/fi+AsUfjvNybAOWYCmBksPXaavGoPamZUwAPDsn22MxYQUSKK6ZHUsWj6zwfsTgIRUVSU+wBx/9SfTB+TNX4GHp32c1jXPV9cgqzxMzB5VZmaOCKimFDuA8T/1oXWWerTNaXDNvb9dtbQMUNeXoTxXzgDS8EZ53iJV+ZtC+n6RESRVu4DhFk7jpwus23roZOGjs3NP4WpP+wNvqOGq99ageGvLA7pWCKiUHA9iDixepfWDCVERJHDEkSEXPfuKmSNn6H52sGCczh04nyUU0REZA5LEBGydvdvuq/1+uv8KKaEiCg0LEHYjAPqiChWMUDEgNfmb8MMjq8
gohjDKqYIKy5xICU5cBx+ae6vAICcBgOQkiRoXCM9GkkjIgqIJYgIaznxO3R4cjYOFAQfKzHwxe/R74WFps5fWOzAjsNmZlYnIjKGASIKTp4vxqyNByNy7m7PzsWglxbh8En2iiIiazFAREluvnYu/8ip8B7sJ84Vu/7nSnZEZC0GiCiZvGpPSMftPXbG0H5cfoKIrMYAEcOW5R5BvxcW4ut1++xOChGVQwwQMWzzgRMAgPV7C4LuK1zCjogsxgARg9bu5rxLRGQ/BogY9Ps3V/j8zsIBEdmBASKGOVzzcCzffjTovowhRPFDKYWzhSV2JyMoBogYVnDW2XXV3RYRCEsZRPHjvaU70faJWTh04pzdSQmIAYKIKMpmbHDOvZZncDVKuzBAxDCrZno9WHAOJQ5OG0sUe2L7e8kAkSBEpxXiyKnz6PXX+fjrzM1RThER6XF/W2N9un8GiBh1vrjEVLtC3m/aI66PnykEACzYmm9FshLC0VPn46KBkBKXe9xSjMcHBohYlf3YLBw7bXx+pflb8jHope+xdNsRz7YFWw5h5Q7nmIpYz6lEU9dn5+HyN5bZnQwqx+KlTwkDRAw7fLJsD4dth05q7rv32BnsOHwaT36z0bNt7Adr8Nj/nL87GCF8bNX5OxJFw9kiZwk21r+WDBAxTKtdeeg/FmvueyjIdN8MEESxY9N+Z9d1FePfS64oF8PMfHjW7z3u+bnj03OQU7+Kz+sOh1WpIiKrxHrnQpYgYti5IvNPdQXnALsVO3xHX5c4VMznVoiMOF9cgtPni+1ORkCnzxdrpvH4mUJs91oB8uNVu6OZLNMYIGKY/0M+HAdPnMM7S3ZYdj4iu1z62lK0e3K23ckIqN2Ts9HuydnIeWIWlFLYfOAEZm86iFGvLsXglxZ59jtw3HegnFIKx04X+mzbuK8A//fVBk8G7+t1+5A1fkbYi40ZwQCRYAJ13/xoZWznVoiM+PVQ/KzBfqawBB8s34UR/1yC2z9ai33HA4+c/mzNXnR5Zi42HziBlTuOorDYgRveW4VPVu3Bb2ecvRo/dn2PjUzBEy4GiARzoEB/bpe9x5wfzoMF59Dx6Tm6PaKIKDTr9h7HM9N/8dnm3T7o78S5YpwvLs3ULXZ1U3/gs/W45u2VmPTdljJjJdw1xTe8t9qKJAfEAFEOzdp4AAVnizw5kfLuvaU7seeo9kDDyat26w5CJPJ3xb+W4b2lO322/W/dft39c/NP+T7oXQ9/d+lgW35pJs6OsRMMEHHsN7+6SjLvxLkiPDP9F4x5Z6XmaxO/2ojr311lQ8qovFi985inFKE0xlYH6luywsBSAOFggIhDd33yI3YfPY35Wzh9RriUq6PYyXO+o9afn7kZFzw1BwA8db9EwYTaTzD7sVlB99GaemfMOysj2hbBcRBxaMbPB/Bz3nFPmwKFTivHBgBvLy7t8VVU4sDGfQVo37BqtJJF5OHuveRQ8GmvcDtQcBZt/cY9WYUliDjF4GDM+eIS7D0WvA1BAsyMeKawBJe8thQHA3QAIALCmzrjwr8twMwNB3Vf//On65D92KwyWZqxH6wJ/aJBMECUQ7E+etNKj077Gf1eWIgzhdoDq9xf6IKzRcgaPwMnzulXJwV6zW5Z42dgwpc/252MiDsV4wPkwqG3eJD767ro18MAgBNno/c5ZIAoh15dsA0AcPxsEVZsP+ozstOIZblH8MOuY4b3f27GL7js9aWmrmGV711fqvMGR6XnB1gC8uFpsf0AnrJ6r91JiLj2MT5ALhJOnvMNioUl0Zs3h20Q5dBxV6Pr1+v242tXF7xdk0YZPv46V68eo8e8s2Rn8J0izIpCU6D+7LFk474CzNp4EH8e2hrJSfEysXTo5v1yCAcKzqJfq9ooLHFgzqaDuHtQK7uTZQmtKqtozpjDAEEAgPum/oSJo9qiTmaa3UmxlNnH45CXF+P6Xk0ikpZoueQ1Z2nNoRQeGd4m6tcvLnHgyW824a6BLdG
gWqWIX+/WD0vr4JOTBCUOhXH9m6NiSnLErx1px8+W7cq+x0CbmlVYxVTObMgr0Nz+9br9eHH21oDHrtl1zLNCHQCs3HEUhwJUycQDrczYxyv3RD0dkbB422Fbrrts+1FMXrUH47/cEPVru9de11uCN97YPb8mA0Q5c9kb+m0BgT6MSilc+dYKT/USAFzz9kr0fH4+luUeQdb4GT7BI1Zt3KcdIBORQPCv73MN9eIKlcOhcOP7q7E8t3Qlw2jNGnyuKPGXjTWz7HAkMECUMwGDgOv/rPEz8MbCXM193AudeHvz++0AgA1x8PB1V78kEr0JGjfsK8ALs7bijx/8ELFrF5wtwuJfD+POT37Ekm2H0f+FhThfHJ1G1CEvLwq+kx+HQ2HWxoM4fqYQ+SfPoftz8zD3l0OWpCcSgdjuiQnZBkGa/j57K+4a2NLuZFjm77O3WNLL53xxSczVbb8y79eAr5+xqGvoF2vz0KVpdTSrlVHmteNnijxzCrkflEcjPB21XrdQQH8A5IcrduGpb30n0xv34RqsnDAY9aqG1/7W74WFYR2vpTBKwVYPSxDk4V+6GPvBD7hnyk+Wnf/2j9ZEvR+7ewCcVnA4V1RieiK+5RGe+yYU3uMzSjQGuey3aIDfg5+vx/BXfJe81aoCcX+ONu0/Uaa6qbDYgYlfbUC+xnrrVtIrKR/QaTN76PP1YV0v0DT70RCp6jYGCNK1YEs+vl2/HwVnivDFj/vCPt/sTYfw1U9lz1PiULju3ZVYvv2IxlGRc8+Un3DZ68tMHaPVzlJwtgi/nS5EUYnD9lX7It0OdL7YYaoqxT9ezdt8CJNX7cHT3/yifUCE/XuR9qJZxRpr8m47dFIz4Gp5ZX7gUlykmRmXZAYDBHkolF3NCgAe+Gxd2DmsQI6eOo9luUdx39R1lp63uMSheT9u8zebr3uevv5AmW0dn56Dzs/MRauJ3+HvQXqCmTVn08GgVUjRjklaM9/qOeU3yKvY9cANJTOgdYxD5wH+r4W5WLglX7fXnj//Xk+5+Scx9B+L8Y+5xh78eoEn3jFAkMex04Xo8szcMtsPBunKujTXxJc9wNNMKeD3by5H9mPfGT9fAP+cv82S83jbeeR0wNen/mDtaObbPlqLV+YZv481u3+z9PpajnvNbtv7rwvKvO5d7TRp1haf1+51VVkamSF3++FTGP36UszffAivzd+Ga98pO+36fZ+u0zz21QW5+OMHP+DS15fibGEJlFKmSleHTjjbT37c4/v33HnkdLnqCcdGavL4fmt4/eat6Hu+1sIH3OYDgVfMC2VOqh1HTqO4xIGUZO28VaASSzTc/tHaiF/j1Pli7DxyGs1qZeBskLrvcOrGX577K9bnFeCW/+pPRvftev3FeNx2HzuN5blH8Zfp+tVaxQ6H5vvqn58Z+OL3AMzNPBDPWIIgy+j1HLHDhC9/xrwQqpCMuOPjtVgYYC2OE+eKcMsHPwSc1ynaRr26pMxSmOEY+OL3+OqnvKD7OcKp/7Lo4/SfpbuwIMjaKT/s+g3DvBrgzWR11kSo/j8WMEBQUBEfkGThYCClFNbtPR7Rievmbc4POLbgw+W7MH9LvmcsSf6Jc4ZGnC/LPRKwUdThUJrdHo08gzftP1FmKcxAXpu/Dd9vDfxQ/fOnwdulzMSH/JPnMHPDAXy/NR8frthl/MAgSgwmYsfh09h60Pw67Ve+tcL0MfGCVUwU1PbDgevdzSi7ALvCpn3WrYg1bW1e1GZdVUphxY6y3V53+61v3eP5+QCAN6/rghEd6muea8m2w7jhvdV4cGhr3DPYOdGcf2Ce+L+NmLJ6T5nqjUMR6DL6kqtx1n2tdSYmKnx2xmbPz2YKAT2em29ib+Omrc1DnxY1De178SuLdauP3NNtlycsQZBlDp80PzDqo5W7Qxrpq1eqyTU5dXk4npuxWbPh9PO1zqoXBd8VwP40+UdPukscCgVeDbXuRtGdR0uDsf8U6VNWl84R9cv+E57
crpm2o99OF+Khz9cHLR14m/vLIVz+hrnuwG5G5+qKVDdNtyILpsi+6f3VZbaZCZzxiAGCLPPAZ8a7wv645zc88fXGoA3J3pRS2HLwBFbvPIY2j8/CMq/eU+7BYj/tPm74fOF6N0iVzYcrdpdZazg33xnAnp+5GR3/MifgwEHvaRaK/R5wI19dgov9Bq0Z0fmZuZi2Ng83/8dYUM4aPwOTV+02fR23Xw8Ze3+vinA1TahNIe52Nb33KdTAGS8YIMhSK7YfDTiQ6tdDJ7EhrwBXv7UCH67YjRKvAUpHgkzN8N/luzD8lSV4zbXgkXviwKXbjuCCp+Zg6bYjMb/imPtBNeNn53iK+ZsPobA4+AC7lhOt6fobinB6tx0/U4QNeQX4Ym0ehvrNnXR1FOvuTXf/9WsX8x/PUV6wDYIs5T2ISqsu9+OVe/Dxyj1IcS1kc+SU8W6hq13VEEu2lZYc1u09jn+4BpKt3nkUSTGe5XH36nGv43Pf1HW4D+vw9ysvAGD/NNW/7D+Bs0XWPgw3HzyBRzTahdzvp5GuqtHkXX3pjtvfbSw7QLI8YIAgWwXqfni2sARzfjmI0Z0aAvANDG7eRfxih0JyjEeI0W8swzd39/XMEeVmtCoGADYfKG3UD6XdJ5CRry6x9HwAgrZUWznflxWKHapMoH76W3umBrEbAwTFrL9M34Qpq/civUIKhubUDZq3/vXQSSTHwToxl72+DBVTfAOZ97KsSikUleg/VUf8s/Qh3v25edYn0GLT1gYfLxFLjMynNebtlZo92BJNTAcIEWkOYCKAqkqpK+1OD5nz0pytqJlRIeTj3WMZxn24Bld0aRh0/3mbjffMsVugNRPeWbIDz8/covt6vFkdoIfSEptWvQukw1Nz8J+buwfcpzwEB8BAI7WINBaRhSKyWUQ2ich9oV5MRN4XkXwR2ajx2nAR2SoiuSIyHgCUUjuUUreEej2y12sLcsvMvR+qL3/chxPlpKHwi7Xhz5xrRolD2bY6m3sNiVgz8Svncqmrdh7D24u325wa+xipsC0G8KBSqi2AXgDuEpEc7x1EpI6IZPpt01pt5gMAw/03ikgygDcAjACQA2CM/zUosRQHmQgpa/yMKKUktuw+ehpbTbRHhKPgbBHumvwj2j85G20enxX8gHLEew2NRCrNmRU0QCilDiilfnT9fBLAZgD+5f2LAHwtImkAICLjALyqca7FALTKmz0A5LpKDIUApgIYbeQGRORSEXm7oKD8zLBIiSsas7G6rdh+FDM2HAg64R6VX6a6fIhIFoDOAHyGjyqlPgcwC8BUEbkOwFgAV5s4dUMA3pPn5AFoKCI1ReQtAJ1FZILWgUqpb5VSt1WtWtXE5Yjojo99Z34NNg6Fyh/DjdQiUhnAFwDuV0qVmTxHKfWCiEwF8CaAFkopM3MeaPU9UUqpowDuMHEeIgpRt2djv0cURZehEoSIpMIZHCYrpb7U2acfgPYAvgLwpMl05AFo7PV7IwCxNXqGiKicMdKLSQC8B2CzUuplnX06A3gHznaDPwKoISLPmkjHDwBaiUgzEakA4BoA35g4noio3IrUCHwjJYi+AG4AMEhE1rn+jfTbJx3AVUqp7UopB4CbAJSZ4UtEpgBYASBbRPJE5BYAUEoVA7gbwGw4G8E/U0ptCvmuiIgobEHbIJRSSxFkSRel1DK/34vgLFH47zcmwDlmApgZLD1ERBQdsT1xDRER2YYBgoiINDFAEBGRJgYIIiLSxABBRESaGCCIiOKcI9RFt4NggCAiinORCQ8MEEREcc/IKnihYIAgIopzLEEQEZEmliCIiEiTQ3+J87AwQBARxTlWMUVIyzqV7U4CEVFY2M01Qga0rm13EoiIwhKh+MAAQUQU/1iCiIgbe2fZnQQiorA4WIKIjCY107HhqWF2J4OIKGQNq1WKyHnLfYAgIop3Kcn2rUlNRETlEAMEEVGcYy+mCBKJTPGMyqcKyfpfq0jVFRNFAgOEl4wKyXY
nwbBalSvYnQTSEyC/EakBTYnmk1t72p0EAgNE3Ppd54Z2J4F01MmsqPtakxrpUUxJ/OrTshZqVdb/O5KvaumpETkvAwSRxbJqZui+9vDF2ZZdZ9IVHSw7V6Lo06Imbrmwmd3JiLpG1SOT8WCAiAPP/a693Unw0aRGOv7QrbHmawOyrZ+6pEpaiuXn1HJZxwYRv0a3rBqWnWt4+3q4sGUty86XCCpXTMHjl+RgcJs6diclITBAxIFAOVI7XNqxPto3rBK16w3I5pddi0CgDE6x0D2reoRT45RqaX988+01w9rVAwC0rMtJOK3AABFDsutm2p0EQ1Ug9wxqBej0/IpEf7BnY6wEpaVNvdL3LthDcmSHepZfv26VwPX1ofTUa9fAfCZAIvIJMO7Kro1svX6iYYCwyehOZasz/tC9Mcb0aGLo+Eh1ze3RLHgVSFpqsucx0D2rOkZ1qO95LRJ9dFKSjN/rbf2bh3ydUP+kzWv7lvD+duUFga8T5CG69NGBhq/tPteLV3UMsp++jo2raW5/5Q+dDKfDDukR7nX45KU5uq/dWk7aORggAKSlOP8M9w9pbXNKgBa1jVUnJUUoQLSpl4ldk0YZ3r9lHftLPd7+b2TbkI8NtQfqzHv7Iccrt10nM83wsclJggZVffdvVD0dHRpWBQA8e3ng0lPdKs5j0ysEbqdxf1xaaax/klWzbAPnkLZ10CqEEq3RKi9D59I5Vd+WNQEA9w9pFdb5nwoQAIL5fTkpqTBAAEhJTsKuSaMwLozcp1mddXJtFVPL5or8Q8HNfbJw96CW1icKQGZa4O5y7u50pfHJ91tcHoccpqUm4/nfBe9RNGVcL5/fL2xZC3P/3B8pGgPrPhnXE/++oSuu79XUs23+gxf57iTAX0a3w8tXd0TXpsbaGFp7VYUNzK6NRQ8PQCWNz1zvFtqN32mpkX9kdHPdi96YkVauTElykn5a3CUr76o/f52bBP6bBcowtK1fxVRGKl4xQNjkpj5ZmtuNPGCfuqwdKleMTs8ef4kyzqu/xQtFpWk8ZHW53uQ/dG+M5rUra1ZtZaal4uJ2vm0VLWqXzf1nVEzBFV1M5Ga93r///LEHmtbMwISRbXFBI2eJpU29TCwfPwhj+2YZP6cX/+qzt2/oqrlflybVdM9xn6tk4E7qU5fm4J/XdCq9husSKsCH8YouznFCI9rX190n3nwyLvqDBxkg/Nzevzlu8Mq1mbHmsSGG99VrQ3B/5K/t6dsWcWPv0NJk1Idje2DRwwPCPs9jl5QW2+tXNV7VAjhHz3p3TxyWUxd1MiuiYkrodc0VUrQ/4t0N5LjvGtgi5OuakRxidaGZwwK1e1StlIqHhjk7J9TOrIgG1SqF3sbld9iwdtoN8l/e2Vf3FO6/h8O1yMHlnRuiZkZpI7yRhvDWdZ1Vpc0MVtmaUbVSZAalBeOuTvTnHTytxgDhZ8LItujToqbndzMNpFaO/PS/qqkcqgEPDvVtb+nfujaaanSnrej3gHXn2txfUqVK651fv7azTy43WMOpvz4ta+GtG7p6Gh9/17khVk8cgmQT74G/R3R6ZfVrXRuNa/jOi+T/THz44jYhX1eL3jPXSMcAt3du7GZRaswLVF3jlhrGe+XPnVkSMd6d1593avx76IUSA398fKipTgRW0iswje4UuVkVGCBiiAJQK8M5x5JP7lsCN0q/d5P5h8agtuGNLXA3ovb1Gqjln7ML5VGRmpyE/q2sqf4J9ABQSlnerbh6emrAXlSeakG/L3qwAXrfPzQA3959IQBgaE5dTL/nQtzcJwuZJqoZ63l9nu4d1FJ3QsFAVYhTb+ul/6JLDQNzhAXN8XqqkJz/Jwl82knE73Wj7hro224XSpfcGhkVgrbTJRIGCA12Tu46vH09vHldF9xxUWn1Rt0qaQHT1KyW+WJ0qP3V3Y2mHRpVxfonh+HSAA+3BkFmLq2REdkJB3c8P9Lz86gLIl8X/dMTw3R7Ub17Yze0dwVVN6O
fs6xaGejQqPTY9g2r4qnL2pmqBhrslSF4YFg2fn1uhOm0VEuvEPSh7P36xe3qAgBy6vuOp/DP8ep1V3U3UouITyO8VV/PtvUDZxAGh5mJsoLRXo2RwgARRLSDhYhgRIf6Pj1bvKttvOs/QwkMpdcp/TlQTloBWPTwAHw4tge2Pz/Sp5juTsvQHOeDINuvCiKrVgaWjx+k2UtmTI8mYc9I27x2Bp65vL1nlPAQ1xd63gP9Me2O3j4P0Hp+9bfuijJvQ9rWDSs9gQzJKXturYetd8bArLF9w++bb7QqR68mqalXl9mrujqnY+nYuKr2zgDaN6yCFeMH+2zzrr50X0srGCoofDKuJ968rovu+QN9f7V6j3nTqnL195+buwfdJxQ9mtVwjjG6wDsDFv0eIgwQYQh1bv9QGrncX5ZQB4Ld1r851j0xVPO1QPXaAucXpX/r2khOEs0v6u86N8KWZ4ajpUYfe2eDp++22pkV8dcrOqCSTt99d8mjit/f6Sq/vuepSUm4oVdTT/XbWNfgpZZ1MsvMeeT/MNZ6OF/asQG2PjscKycMDlid0tNEm4Emv7+HdxtLOLNyPmGgX79V4xR8H1yl+rasVWY2W71Sx9rHhmDaHX3KLJfZqLrz/W/tmi7Dv3rVu4qpT4taGNFBv3RopqQcSoAd2KaO5T3iAOCz23vj8zv64M4B0ekooYcBwoQB2bWxa9IodGyknyMKZskjA8PqLeT9XQnUzc9fo+qVUC29gqe47z5Pdt1MNNEYKGVWuI3o13n12np0RDbeuLaLT/vG1meH42+/DzxC2QoVU5JRr2oaejWvqfn6RRF4GHT3CmbhFlhXThiM1RMHl9lu1RQY7lLrpCs64Ms7+2D1/5XN/bur0oKVvmtWrugcle+13/1DWqGxa0r0D/7YA5Nv7Vnms5XkCqhGSvd6+8z9c3+f37c9NwLjR7QJaR2KN67tbPoYf5d2bKA5Fbz73p3dgqNf980AoSGSff0b10hHtXTtqhUzD3yzru/pbDv49PZeWPTwAE/X0bpBuqJaXcVWr0oa/nJZuzLbvR/IFVOSy7QZVExJ9jwYQkmjCLDhqWE+1WlGjuvg1W5wY++m+O/YHpZO2Q04H3jhTBHirV7VNFTX+XwFEiyAuEd7f3xrT7x3UzdkVExBlybVUUen6yVQ+j0y8xnybpeqnlHBJ5OQnCTo0qQa7hrYEmN6NPEZRKjH3QvPvyec/yjx1OQkVEhJQh+N2XEDre8BBB9cakTb+plY/Ih276jtz4/EtDv6wI4qJntGW8U7G1uxvWOImUZK98M1My3V84F+5Q+ddKfnrpCShMJiB9o3CL205Oad5pVeOc6XrroAQ15eHPb5jcpMS0V6RXMlnc/v6I3CEgcKix2o5qry6pZVA7smjULW+BmRSGbYtD4V7l5MbeoFnoBPL4/irsapVbkiBgdpqwnn2xGo+nW7V6eDvxpcC2Ngdh08NKw1buidFVJ6Nj19Mc4VlaDrs/NCOt6oQHnDcLp5h4slCA1WPf/H9NBeMyEUWmkKt8RxeeeGuqWZZq4GOiP12v5evKojLteYjNBfyzqZeHS4c6xBsB5P1jP2t0tLTUaVtFTUqlwxaKNmLOvatDq+vqtvma6eblZ95rU/p8aPt3pNjqQkwd2DWukGnl7NA7cnZVRM0Zz+Jhj/8UPxiiWIIJxF79AexF2b1sCU1XsN799OI7funmkz0JdMRHDngBZYueMoftxz3GQqtbkbM/VGIgdyZddGhqddvr1/c/RvXUvz3q3kqe7w+z0c393Xr0xDeiy6upvzvdCbtRUAujSpjq5Nq+OxUdoZgkjmYd2N0G3qZUZslmI979/cHUdPFYZ9nowKyThdWGLqmBd+fwEe+eJnw/vbMc1NYoS5CPLu9eF5gwy+U2Y+6q+O6YzeLXwbRpc+OhBT/OZf0StJPDK8Db68s2/I04SUPafrehFuGEtKElPBoV+rsnXEnRo7u7kGqyu2Wtv6VXR7sr1
zYzc8M7psW0u0VUtPxQtXBh/RXqlCMr74Ux+fWWkBYMKINp7XIyUtNRnv3dQNH4fQQByu9AopnkZxPUbmPVs+fjBWTijbOSCQq7uX1jBEsv0xHAwQYbAys6M1JqBR9fSg0zj7e/qydmUaUUOZ19/dAFjP5HxKkfbOjd3w/s2+3XIfGtYa393Xz9DU4+6uuJEeDTs0p66pem8rHxBW5sLHXtgMDwxtrVs1pWX8iDbo3KSaJ8NjJDmD29a1dKqaaKuanhpz3xUrMEAEESgHHew7Ha3SsvcDISlJ8Ce/wVahNHLd1CcLuyaNsnRisreu157Z04y01OQy7RUpyUloW9/Y6md/Gd0eH47tgex6majp6jHz8MXZEZ3wzAyrPzPhni41OQn3Dm5lqhtzq7qZ+OrOvshw5bxjNHNsi091xtfE6t+IbRAWuKBRVd26W6uEM8CpU4C652jq31p7jYFIcleNuKdzSEtN9gxsevySHLRrUAXX92oa9bpvPVY/KGL0uVMupSYLejaviSWPDMT+42ftTo4hLEFoMvewGNevuakZOTWF8E3Ocg1a0lvERQTYNWlU0DrWSLNylTGzru7WGA9fnK25wFJGxRTc0Dsr6sGhpWvqFO82k1gJUJE0pkeTiE1NEcvcPfXcGtdIR0/XuB/v6e3NmH7PhWGnywgGiBCUTkMcvWtqVXW9OqYz3r+5G+pX9a1yidVcox0L2qcmJ+GugS0tny49HPcOboWpt/XyPCSAyDVSRqPq4rnflS6LqvUOuzMog9vUwcAQH4jxYsq4Xph5bz+8cW3p/FC3BFi/2r2KpX8HFS3ut7JlncplJn6MFFYxmRDRgBDCuaukpWJQG/1BS4mfJ41PyUmiO5WHFZ+xaL/v1/Vsit1Hz+DtxTs0X7/johbIaVAFAyIwTUmscT/ovXuDBSpB92peE7nPjTA1xiaa7y9LECHw70sfjZKEmWoad3ICTcVNiSspSTDHNddQtEq57llcG1Yv2+03OUkwMLtOuahG8+bf9qdXgo7lAZgsQQRj4DMdajG+UmoyzhaZG1xjpJomKUnw4+NDkZmWmG+ve+zBrf3Cn946VlhdFVTb1WU0Wr1jru3RBC1qVw5/ptsoiNb34sNbeiDvmHWN0e5vfjQDSmI+QcKkt8yo/5ctWIYo2OuLHh6AQS8twqnzxSZSZ0ykF+OxU2ZaKnZNGmV3MmKaHeuY6FWbRcOCBy/CgYJzQfdbPXFwSO1R0++5EEUlDlPHVElLRU6DVM/a2g8Max3kiMBa1qmMuwa2wDXdmwTf2SIMEBqCNaRZlSmrUyUN7RtWwcodxwwfY2ePIIqcclb7YrnmtSujee2y65H4q5MZ2mC2cBqFk5LEkgyNiFi+TnowsVv5ZSM7Z0/UY0cPIIp/sTqFA8UHliDiRDyXHPiMCo5/o9jRvHZGyKtFJhoGiCAC5dvtqBaI55IEq1GCs/r9LW89h6yw4MEBdichZjBAECUwVjHFl6m39YqpDiYMEEHUzKiA/Tq9I1Jd3c38F1UnX96LzJM291oNbeoHn5HWiHguaZZndvYE08JG6iA+vb03RrSv57Pt8Uty0KJ2Bibf2hPj+jXDxe20RzNb+SVt7pp3SWth81j3xZ/64E8DWujOGUXOQY1LHx2Ifq0Sf7QxWad6emSnrWcJQsfMe/sh/+Q5NK6Rjqu7N8Z3Gw96XuueVQPzXfWUEyM8i6vb1d0ao0XtyujatHpUrmeldg2qRnzFuETQqLp1wb+iKxiPuqC+Zeek2PLW9V3RvqGxae5DxSydjpwGVTAgO7yJxUSA18Z0DriP+4EfrH+2iKBbVg02OpIhaanJ+PHxoXhmdPvgO1NcGt6+nqWZCi0MEAa0cq1CFsrcRsGOeWBoNuY90N+z0hmRVWpkVIjpeX4o9rGKyYBG1dOx/fmRYQ+g05qnJjlJDC2VSUQUbcxeGBSLo6uJiCKJAYKIiDQxQBARkSYGCCIi0sQAQUREmhggouCf13QCANw
3pJW9CSEiMoHdXKNgdKeGGNmhvmfuJiKieMAnVpQwOBBRvOFTK4I4LQYRxTMGCCIi0sQAQUREmhggiIhIEwMEERFpYoAgIiJNDBAR1Lh6JbuTQEQUMg6Ui5DZ9/dHdj2u80BE8YsliAhhcCCieMcAQUREmhggiIhIEwMEERFpYoAgIiJNDBBERKSJ3VwtNvnWnthz7IzdySAiChsDhMX6tqyFvnYngojIAqxiIiIiTQwQRESkiQGCiIg0MUAQEZEmBggiItLEAEFERJoYIIiISBMDBBERaWKAICIiTRxJTRRBz17eHh0aVrU7GUQhYYAgiqDrezW1OwlEIWMVExERaWKAICIiTQwQRESkiQGCiIg0MUAQEZEmBggiItLEAEFERJoYIIiISJMopexOgyVE5DCA3SEeXgvAEQuTEw94z+UD7znxhXu/TZVStbVeSJgAEQ4RWaOU6mZ3OqKJ91w+8J4TXyTvl1VMRESkiQGCiIg0MUA4vW13AmzAey4feM+JL2L3yzYIIiLSxBIEERFpYoAgIiJN5T5AiMhwEdkqIrkiMt7u9IRKRBqLyEIR2Swim0TkPtf2GiIyV0S2uf6v7nXMBNd9bxWRi722dxWRDa7XXhURseOejBKRZBH5SUSmu35P6HsWkWoiMk1Etrje796JfM8i8mfXZ3qjiEwRkbREvF8ReV9E8kVko9c2y+5TRCqKyKeu7atEJCtoopRS5fYfgGQA2wE0B1ABwHoAOXanK8R7qQ+gi+vnTAC/AsgB8AKA8a7t4wH8zfVzjut+KwJo5vo7JLteWw2gNwAB8B2AEXbfX5B7fwDAJwCmu35P6HsG8F8At7p+rgCgWqLeM4CGAHYCqOT6/TMANyfi/QLoD6ALgI1e2yy7TwB3AnjL9fM1AD4Nmia7/yg2vyG9Acz2+n0CgAl2p8uie/sawFAAWwHUd22rD2Cr1r0CmO36e9QHsMVr+xgA/7b7fgLcZyMA8wEMQmmASNh7BlDF9cAUv+0Jec+uALEXQA04l0ieDmBYAt9vll+AsOw+3fu4fk6Bc/S1BEpPea9icn/43PJc2+Kaq+jYGcAqAHWVUgcAwPV/Hdduevfe0PWz//ZY9QqARwA4vLYl8j03B3AYwH9c1WrvikgGEvSelVL7ALwIYA+AAwAKlFJzkKD3q8HK+/Qco5QqBlAAoGagi5f3AKFVBxnX/X5FpDKALwDcr5Q6EWhXjW0qwPaYIyKXAMhXSq01eojGtri6Zzhzfl0AvKmU6gzgNJxVD3ri+p5dde6j4axGaQAgQ0SuD3SIxra4uV8TQrlP03+D8h4g8gA09vq9EYD9NqUlbCKSCmdwmKyU+tK1+ZCI1He9Xh9Avmu73r3nuX723x6L+gK4TER2AZgKYJCIfIzEvuc8AHlKqVWu36fBGTAS9Z6HANiplDqslCoC8CWAPkjc+/Vn5X16jhGRFABVARwLdPHyHiB+ANBKRJqJSAU4G26+sTlNIXH1VHgPwGal1MteL30D4CbXzzfB2Tbh3n6Nq2dDMwCtAKx2FWNPikgv1zlv9DompiilJiilGimlsuB87xYopa5HYt/zQQB7RSTbtWkwgF+QuPe8B0AvEUl3pXMwgM1I3Pv1Z+V9ep/rSji/L4FLUXY3ytj9D8BIOHv8bAcw0e70hHEfF8JZXPwZwDrXv5Fw1jHOB7DN9X8Nr2Mmuu57K7x6dADoBmCj67XXEaQhKxb+ARiA0kbqhL5nAJ0ArHG91/8DUD2R7xnA0wC2uNL6EZw9dxLufgFMgbOdpQjO3P4tVt4ngDQAnwPIhbOnU/NgaeJUG0REpKm8VzEREZEOBggiItLEAEFERJoYIIiISBMDBBERaWKAICIiTQwQRESk6f8BGtVbYuQ6EnoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.semilogy(np.array(svi_result.losses))" ] }, { "cell_type": "code", "execution_count": 28, "id": "da701767-6f29-4004-8bbc-8e95fce39128", "metadata": {}, "outputs": [], "source": [ "expected[\"model_04 skill_01 P(csharp)\"] = np.array(svi_result.params[\"skill_p\"][0])\n", "expected[\"model_04 skill_02 P(sql)\"] = np.array(svi_result.params[\"skill_p\"][1])" ] }, { "cell_type": "markdown", "id": "68df5128-939b-415e-b58f-ef8b0c5ecc14", "metadata": {}, "source": [ "#### Final Result" ] }, { "cell_type": "code", "execution_count": 29, "id": "a4391f2e-5455-4f98-8f72-ec5f00b85e06", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IsCorrect1IsCorrect2IsCorrect2P(csharp)P(sql)model_00 P(csharp)model_00 P(sql)model_01a P(csharp)model_01a P(sql)model_01b P(csharp)model_01b P(sql)model_02 P(csharp)model_02 P(sql)model_03 P(csharp)model_03 P(sql)model_04 skill_01 P(csharp)model_04 skill_02 P(sql)
0FalseFalseFalse0.1010.1010.090.1030.0990.0910.00133333330.493666680.0860.1010.0990.0910.0232570.051189
1TrueFalseFalse0.8020.0340.8040.0320.7950.0310.497333350.491333340.8050.0350.7950.0310.7117260.012988
2FalseTrueFalse0.0340.8020.0280.7950.0340.8040.503666640.4950.0300.7900.0340.8040.1335170.849527
3TrueTrueFalse0.5610.5610.5550.5380.5530.5470.503333330.501333360.5810.5340.5530.5470.4025500.762942
4FalseFalseTrue0.1480.1480.1380.1510.150.1510.5090.4890.1330.1500.1500.1510.0431880.074701
5TrueFalseTrue0.8620.3260.8670.3390.8620.3250.4940.4990.8580.3280.8620.3250.9242930.860713
6FalseTrueTrue0.3260.8620.3210.8560.330.8640.511333350.516666650.3270.8560.3300.8640.0447630.908416
7TrueTrueTrue0.9460.9460.940.9430.930.9310.4960.496666670.9420.9460.9300.9310.9715540.959564
\n", "
" ], "text/plain": [ " IsCorrect1 IsCorrect2 IsCorrect2 P(csharp) P(sql) model_00 P(csharp) \\\n", "0 False False False 0.101 0.101 0.09 \n", "1 True False False 0.802 0.034 0.804 \n", "2 False True False 0.034 0.802 0.028 \n", "3 True True False 0.561 0.561 0.555 \n", "4 False False True 0.148 0.148 0.138 \n", "5 True False True 0.862 0.326 0.867 \n", "6 False True True 0.326 0.862 0.321 \n", "7 True True True 0.946 0.946 0.94 \n", "\n", " model_00 P(sql) model_01a P(csharp) model_01a P(sql) model_01b P(csharp) \\\n", "0 0.103 0.099 0.091 0.0013333333 \n", "1 0.032 0.795 0.031 0.49733335 \n", "2 0.795 0.034 0.804 0.50366664 \n", "3 0.538 0.553 0.547 0.50333333 \n", "4 0.151 0.15 0.151 0.509 \n", "5 0.339 0.862 0.325 0.494 \n", "6 0.856 0.33 0.864 0.51133335 \n", "7 0.943 0.93 0.931 0.496 \n", "\n", " model_01b P(sql) model_02 P(csharp) model_02 P(sql) model_03 P(csharp) \\\n", "0 0.49366668 0.086 0.101 0.099 \n", "1 0.49133334 0.805 0.035 0.795 \n", "2 0.495 0.030 0.790 0.034 \n", "3 0.50133336 0.581 0.534 0.553 \n", "4 0.489 0.133 0.150 0.150 \n", "5 0.499 0.858 0.328 0.862 \n", "6 0.51666665 0.327 0.856 0.330 \n", "7 0.49666667 0.942 0.946 0.930 \n", "\n", " model_03 P(sql) model_04 skill_01 P(csharp) model_04 skill_02 P(sql) \n", "0 0.091 0.023257 0.051189 \n", "1 0.031 0.711726 0.012988 \n", "2 0.804 0.133517 0.849527 \n", "3 0.547 0.402550 0.762942 \n", "4 0.151 0.043188 0.074701 \n", "5 0.325 0.924293 0.860713 \n", "6 0.864 0.044763 0.908416 \n", "7 0.931 0.971554 0.959564 " ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "expected" ] }, { "cell_type": "markdown", "id": "59c28b09-3726-49c7-bbba-92861dfe8b88", "metadata": {}, "source": [ "#### model_05\n", "* trying explicit `config_enumerate`\n", "* otherwise the same as `model_03`" ] }, { "cell_type": "code", "execution_count": 30, "id": "44f2c934-c2d3-494d-8afb-f13612f63b12", "metadata": {}, "outputs": [], "source": [ "from numpyro.contrib.funsor import config_enumerate" ] }, { 
"cell_type": "code", "execution_count": 31, "id": "19725234-7f12-498a-8122-d4549b2b896e", "metadata": {}, "outputs": [], "source": [ "@config_enumerate\n", "def model_05(\n", " graded_responses, skills_needed: List[List[int]], prob_mistake=0.1, prob_guess=0.2\n", "):\n", " n_questions, n_participants = graded_responses.shape\n", " n_skills = max(map(max, skills_needed)) + 1\n", " \n", " participants_plate = numpyro.plate(\"participants_plate\", n_participants)\n", " \n", " with participants_plate:\n", " skills = []\n", " for s in range(n_skills):\n", " skills.append(numpyro.sample(\"skill_{}\".format(s), dist.Bernoulli(0.5), infer={\"enumerate\": \"parallel\"}))\n", "\n", " for q in range(n_questions):\n", " has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])\n", " prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess\n", " isCorrect = numpyro.sample(\n", " \"isCorrect_{}\".format(q),\n", " dist.Bernoulli(prob_correct).to_event(1),\n", " obs=graded_responses[q],\n", " )" ] }, { "cell_type": "markdown", "id": "4d45ee14-f2e6-40d7-ab3d-87c1e2188496", "metadata": {}, "source": [ "* trying out `@config_enumerate` with `DiscreteHMCGibbs`, which shouldn't work: sites marked for parallel enumeration are not treated as Gibbs sites (see the traceback below)\n", "\n", "```python\n", "nuts_kernel = NUTS(model_05)\n", "\n", "kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)\n", "\n", "mcmc = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=1)\n", "mcmc.run(rng_key, responses_check, skills_needed_check)\n", "mcmc.print_summary()\n", "\n", "```\n", "\n", "```python\n", "AssertionError Traceback (most recent call last)\n", " in \n", " 4 \n", " 5 mcmc = MCMC(kernel, num_warmup=200, num_samples=1000, num_chains=1)\n", "----> 6 mcmc.run(rng_key, responses_check, skills_needed_check)\n", " 7 mcmc.print_summary()\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)\n", " 564 map_args = (rng_key, init_state, init_params)\n", " 565 if 
self.num_chains == 1:\n", "--> 566 states_flat, last_state = partial_map_fn(map_args)\n", " 567 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)\n", " 568 else:\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)\n", " 353 rng_key, init_state, init_params = init\n", " 354 if init_state is None:\n", "--> 355 init_state = self.sampler.init(\n", " 356 rng_key,\n", " 357 self.num_warmup,\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/hmc_gibbs.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)\n", " 439 and site[\"infer\"].get(\"enumerate\", \"\") != \"parallel\"\n", " 440 ]\n", "--> 441 assert (\n", " 442 self._gibbs_sites\n", " 443 ), \"Cannot detect any discrete latent variables in the model.\"\n", "\n", "AssertionError: Cannot detect any discrete latent variables in the model.\n", "```" ] }, { "cell_type": "markdown", "id": "d3f3079b-a022-49f6-9d29-b065c93f3b3b", "metadata": {}, "source": [ "* trying again with Predictive\n", "\n", "```python\n", "predictive = Predictive(\n", " model_05,\n", " num_samples=3000,\n", " infer_discrete=True,\n", ")\n", "discrete_samples = predictive(rng_key, responses_check, skills_needed_check)\n", "```\n", "\n", "```python\n", "ValueError Traceback (most recent call last)\n", " in \n", " 4 infer_discrete=True,\n", " 5 )\n", "----> 6 discrete_samples = predictive(rng_key, responses_check, skills_needed_check)\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)\n", " 892 )\n", " 893 model = substitute(self.model, self.params)\n", "--> 894 return _predictive(\n", " 895 rng_key,\n", " 896 model,\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, 
parallel, model_args, model_kwargs)\n", " 737 rng_key = rng_key.reshape(batch_shape + (2,))\n", " 738 chunk_size = num_samples if parallel else 1\n", "--> 739 return soft_vmap(\n", " 740 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size\n", " 741 )\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/util.py in soft_vmap(fn, xs, batch_ndims, chunk_size)\n", " 403 fn = vmap(fn)\n", " 404 \n", "--> 405 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)\n", " 406 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)\n", " 407 ys = tree_map(\n", "\n", " [... skipping hidden 15 frame]\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/infer/util.py in single_prediction(val)\n", " 702 model_trace = prototype_trace\n", " 703 temperature = 1\n", "--> 704 pred_samples = _sample_posterior(\n", " 705 config_enumerate(condition(model, samples)),\n", " 706 first_available_dim,\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/contrib/funsor/discrete.py in _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs)\n", " 60 with funsor.adjoint.AdjointTape() as tape:\n", " 61 with block(), enum(first_available_dim=first_available_dim):\n", "---> 62 log_prob, model_tr, log_measures = _enum_log_density(\n", " 63 model, args, kwargs, {}, sum_op, prod_op\n", " 64 )\n", "\n", "~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)\n", " 238 result = funsor.optimizer.apply_optimizer(lazy_result)\n", " 239 if len(result.inputs) > 0:\n", "--> 240 raise ValueError(\n", " 241 \"Expected the joint log density is a scalar, but got {}. \"\n", " 242 \"There seems to be something wrong at the following sites: {}.\".format(\n", "\n", "ValueError: Expected the joint log density is a scalar, but got (2,). 
There seems to be something wrong at the following sites: {'_pyro_dim_1'}.\n", "```" ] } ], "metadata": { "jupytext": { "formats": "ipynb,py" }, "kernelspec": { "display_name": "Python [conda env:numpyro_play]", "language": "python", "name": "conda-env-numpyro_play-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.11" } }, "nbformat": 4, "nbformat_minor": 5 }