{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## The pair of dice problem\n", "\n", "Copyright 2018 Allen Downey\n", "\n", "MIT License: https://opensource.org/licenses/MIT\n" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Configure Jupyter so figures appear in the notebook\n", "%matplotlib inline\n", "\n", "# Configure Jupyter to display the assigned value after an assignment\n", "%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'\n", "\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import thinkplot\n", "from thinkbayes2 import Pmf, Suite\n", "\n", "from fractions import Fraction" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### The BayesTable class\n", "\n", "Here's a class that represents a Bayesian update table, with one row per hypothesis." ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "class BayesTable(pd.DataFrame):\n", "    def __init__(self, hypo, prior=1, **options):\n", "        columns = ['hypo', 'prior', 'likelihood', 'unnorm', 'posterior']\n", "        super().__init__(columns=columns, **options)\n", "        self.hypo = hypo\n", "        self.prior = prior\n", "\n", "    def mult(self):\n", "        # Unnormalized posterior: prior times likelihood, row by row\n", "        self.unnorm = self.prior * self.likelihood\n", "\n", "    def norm(self):\n", "        # Normalize and return the normalizing constant\n", "        nc = np.sum(self.unnorm)\n", "        self.posterior = self.unnorm / nc\n", "        return nc\n", "\n", "    def update(self):\n", "        self.mult()\n", "        return self.norm()\n", "\n", "    def reset(self):\n", "        # New table whose priors are this table's posteriors\n", "        return BayesTable(self.hypo, self.posterior)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### The pair of dice problem\n", "\n", "Suppose I have a box that contains one each of 4-sided, 6-sided, 8-sided, and 12-sided dice. I choose two dice at random and roll them without letting you see the dice or the outcome. I report that the sum of the dice is 3.\n", "\n", "1) What is the posterior probability that I rolled each possible pair of dice?\n", "\n", "\n", "2) If I roll the same dice again, what is the probability that the sum of the dice is 11?" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "**Solution**\n", "\n", "I'll start by making a list of possible pairs of dice." ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(4, 6), (4, 8), (4, 12), (6, 8), (6, 12), (8, 12)]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sides = [4, 6, 8, 12]\n", "hypo = []\n", "for die1 in sides:\n", "    for die2 in sides:\n", "        if die2 > die1:\n", "            hypo.append((die1, die2))\n", "\n", "hypo" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Here's a `BayesTable` that represents the hypotheses." ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>hypo</th><th>prior</th><th>likelihood</th><th>unnorm</th><th>posterior</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>(4, 6)</td><td>1</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>1</th><td>(4, 8)</td><td>1</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>2</th><td>(4, 12)</td><td>1</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>3</th><td>(6, 8)</td><td>1</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>4</th><td>(6, 12)</td><td>1</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>5</th><td>(8, 12)</td><td>1</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " hypo prior likelihood unnorm posterior\n", "0 (4, 6) 1 NaN NaN NaN\n", "1 (4, 8) 1 NaN NaN NaN\n", "2 (4, 12) 1 NaN NaN NaN\n", "3 (6, 8) 1 NaN NaN NaN\n", "4 (6, 12) 1 NaN NaN NaN\n", "5 (8, 12) 1 NaN NaN NaN" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "table = BayesTable(hypo)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Since we didn't specify prior probabilities, the default, `prior=1`, gives every hypothesis the same prior. The priors don't have to be normalized, because we have to normalize the posteriors anyway.\n", "\n", "Now we can specify the likelihoods: if the first die has `n1` sides and the second die has `n2` sides, the probability of getting a sum of 3 is\n", "\n", "`2 / n1 / n2`\n", "\n", "The factor of `2` is there because there are two ways the sum can be 3: either the first die is `1` and the second is `2`, or the other way around.\n", "\n", "So the likelihoods are:" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>hypo</th><th>prior</th><th>likelihood</th><th>unnorm</th><th>posterior</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>(4, 6)</td><td>1</td><td>0.0833333</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>1</th><td>(4, 8)</td><td>1</td><td>0.0625</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>2</th><td>(4, 12)</td><td>1</td><td>0.0416667</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>3</th><td>(6, 8)</td><td>1</td><td>0.0416667</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>4</th><td>(6, 12)</td><td>1</td><td>0.0277778</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>5</th><td>(8, 12)</td><td>1</td><td>0.0208333</td><td>NaN</td><td>NaN</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " hypo prior likelihood unnorm posterior\n", "0 (4, 6) 1 0.0833333 NaN NaN\n", "1 (4, 8) 1 0.0625 NaN NaN\n", "2 (4, 12) 1 0.0416667 NaN NaN\n", "3 (6, 8) 1 0.0416667 NaN NaN\n", "4 (6, 12) 1 0.0277778 NaN NaN\n", "5 (8, 12) 1 0.0208333 NaN NaN" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "for i, row in table.iterrows():\n", "    n1, n2 = row.hypo\n", "    table.loc[i, 'likelihood'] = 2 / n1 / n2\n", "\n", "table" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we can use `update` to compute the posterior probabilities:" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>hypo</th><th>prior</th><th>likelihood</th><th>unnorm</th><th>posterior</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>(4, 6)</td><td>1</td><td>0.0833333</td><td>0.0833333</td><td>0.3</td></tr>\n",
"    <tr><th>1</th><td>(4, 8)</td><td>1</td><td>0.0625</td><td>0.0625</td><td>0.225</td></tr>\n",
"    <tr><th>2</th><td>(4, 12)</td><td>1</td><td>0.0416667</td><td>0.0416667</td><td>0.15</td></tr>\n",
"    <tr><th>3</th><td>(6, 8)</td><td>1</td><td>0.0416667</td><td>0.0416667</td><td>0.15</td></tr>\n",
"    <tr><th>4</th><td>(6, 12)</td><td>1</td><td>0.0277778</td><td>0.0277778</td><td>0.1</td></tr>\n",
"    <tr><th>5</th><td>(8, 12)</td><td>1</td><td>0.0208333</td><td>0.0208333</td><td>0.075</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " hypo prior likelihood unnorm posterior\n", "0 (4, 6) 1 0.0833333 0.0833333 0.3\n", "1 (4, 8) 1 0.0625 0.0625 0.225\n", "2 (4, 12) 1 0.0416667 0.0416667 0.15\n", "3 (6, 8) 1 0.0416667 0.0416667 0.15\n", "4 (6, 12) 1 0.0277778 0.0277778 0.1\n", "5 (8, 12) 1 0.0208333 0.0208333 0.075" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "table.update()\n", "table" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Part two\n", "\n", "The second part of the problem asks for the (posterior predictive) probability of getting a total of 11 if we roll the same dice again.\n", "\n", "For this, it will be useful to write a more general function that computes the probability of getting a total, `k`, given `n1` and `n2`.\n", "\n", "Here's an example with the `4` and `6` sided dice:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEwhJREFUeJzt3X+QXfV53/H3J1LAMR5jBzaZRBKVPChp5MZNnEVxmppmTGuLaYvSqdRKThPI0FE7idK0SZriDsWOwh+l+UE6A+1YDcQEYgQlTqupt5EZ00lnMi7VAi54UdSsFYrWcsu6YFKcIVjm6R/3MHO5urBnVxddOd/3a0az53zPc8557h3t5549955zU1VIktrwDdNuQJJ07hj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIasn3YDoy699NLavHnztNuQpK8rjzzyyJeqamaluvMu9Ddv3sz8/Py025CkrytJ/lefOk/vSFJDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQ867K3L1Z9P+m++d2r5vu3HvedcHTK+X0T7UFo/0Jakhhr4kNcTQl6SG9Ar9JDuSHE+ymOSGMcuvTPJoktNJdo0suyzJp5IcS/Jkks2TaV2StForhn6SdcDtwNXANmBvkm0jZU8D1wEfH7OJ3wR+qaq+C9gOPHM2DUuS1q7Pp3e2A4tVdQIgySFgJ/DkKwVV9VS37OXhFbsXh/VV9WBX98Jk2pYkrUWf0zsbgJND80vdWB/fAXw5ySeSPJbkl7q/HCRJU9An9DNmrHpufz3wXuDngCuAdzA4DfTqHST7kswnmV9eXu65aUnSavUJ/SVg09D8RuBUz+0vAY9V1YmqOg38B+Ddo0VVdbCqZqtqdmZmxa94lCStUZ/QPwpsTbIlyQXAHuBwz+0fBd6e5JUkfx9D7wVIks6tFUO/O0LfDxwBjgH3V9VCkgNJrgFIckWSJWA38NEkC926X2NwaufTS
Z5gcKro370xD0WStJJe996pqjlgbmTspqHpowxO+4xb90HgXWfRoyRpQrwiV5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUkF6hn2RHkuNJFpPcMGb5lUkeTXI6ya4xy9+a5AtJbptE05KktVkx9JOsA24Hrga2AXuTbBspexq4Dvj4a2zmF4HfW3ubkqRJ6HOkvx1YrKoTVfUScAjYOVxQVU9V1ePAy6MrJ/k+4FuBT02gX0nSWegT+huAk0PzS93YipJ8A/ArwD9dfWuSpEnrE/oZM1Y9t/8TwFxVnXy9oiT7kswnmV9eXu65aUnSaq3vUbMEbBqa3wic6rn9HwDem+QngLcAFyR5oape9WZwVR0EDgLMzs72fUGRJK1Sn9A/CmxNsgX4ArAH+GCfjVfVj7wyneQ6YHY08CVJ586Kp3eq6jSwHzgCHAPur6qFJAeSXAOQ5IokS8Bu4KNJFt7IpiVJa9PnSJ+qmgPmRsZuGpo+yuC0z+tt42PAx1bdoSRpYrwiV5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSG97rKpr1/7b753avu+7ca9U9u3+vP/SFs80pekhhj6ktSQXqGfZEeS40kWk5zxdYdJrkzyaJLTSXYNjX9Pks8kWUjyeJK/O8nmJUmrs2LoJ1kH3A5cDWwD9ibZNlL2NHAd8PGR8T8Bfqyq3gnsAH4tydvOtmlJ0tr0eSN3O7BYVScAkhwCdgJPvlJQVU91y14eXrGq/ufQ9KkkzwAzwJfPunNJ0qr1Ob2zATg5NL/Uja1Kku3ABcDnV7uuJGky+oR+xozVanaS5NuAu4Efr6qXxyzfl2Q+yfzy8vJqNi1JWoU+ob8EbBqa3wic6ruDJG8FPgncWFX/bVxNVR2sqtmqmp2Zmem7aUnSKvUJ/aPA1iRbklwA7AEO99l4V/87wG9W1b9fe5uSpElYMfSr6jSwHzgCHAPur6qFJAeSXAOQ5IokS8Bu4KNJFrrV/w5wJXBdks92/77nDXkkkqQV9boNQ1XNAXMjYzcNTR9lcNpndL17gHvOskdJ0oR4Ra4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSG9Qj/JjiTHkywmuWHM8iuTPJrkdJJdI8uuTfKH3b9rJ9W4JGn1Vgz9JOuA24GrgW3A3iTbRsqeBq4DPj6y7jcDHwa+H9gOfDjJ28++bUnSWvQ50t8OLFbViap6CTgE7BwuqKqnqupx4OWRdT8APFhVz1bVc8CDwI4J9C1JWoM+ob8BODk0v9SN9XE260qSJqxP6GfMWPXcfq91k+xLMp9kfnl5ueemJUmr1Sf0l4BNQ/MbgVM9t99r3ao6WFWzVTU7MzPTc9OSpNXqE/pHga1JtiS5ANgDHO65/SPA+5O8vXsD9/3dmCRpClYM/ao6DexnENbHgPuraiHJgSTXACS5IskSsBv4aJKFbt1ngV9k8MJxFDjQjUmSpmB9n6KqmgPmRsZuGpo+yuDUzbh17wTuPIseJUkT4hW5ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNaTX/fS1evtvvndq+77txr1T27e0Vv7OnBu9jvST7EhyPMlikhvGLL8wyX3d8oeTbO7GvzHJXUmeSHIsyYcm274kaTVWDP0k64DbgauBbcDeJNtGyq4Hnquqy4FbgVu68d3AhVX13cD3Af/glRcESdK51+dIfzuwWFUnquol4BCwc6RmJ3BXN/0AcFWSAAVclGQ98E3AS8AfT6RzSdKq9Qn9DcDJofmlbmxsTfdF6
s8DlzB4AfgK8EXgaeCXx30xepJ9SeaTzC8vL6/6QUiS+ukT+hkzVj1rtgNfA74d2AL8bJJ3nFFYdbCqZqtqdmZmpkdLkqS16BP6S8CmofmNwKnXqulO5VwMPAt8EPjdqvpqVT0D/D4we7ZNS5LWpk/oHwW2JtmS5AJgD3B4pOYwcG03vQt4qKqKwSmd92XgIuA9wB9MpnVJ0mqtGPrdOfr9wBHgGHB/VS0kOZDkmq7sDuCSJIvAzwCvfKzzduAtwOcYvHj8RlU9PuHHIEnqqdfFWVU1B8yNjN00NP0ig49njq73wrhxSdJ0eBsGSWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDeoV+kh1JjidZTHLDmOUXJrmvW/5wks1Dy96V5DNJFpI8keRNk2tfkrQaK4Z+knUMvvbwamAbsDfJtpGy64Hnqupy4Fbglm7d9cA9wD+sqncCPwR8dWLdS5JWpc+R/nZgsapOVNVLwCFg50jNTuCubvoB4KokAd4PPF5V/wOgqv5vVX1tMq1LklarT+hvAE4OzS91Y2Nrui9Sfx64BPgOoJIcSfJokp8ft4Mk+5LMJ5lfXl5e7WOQJPXUJ/QzZqx61qwH/jLwI93Pv5XkqjMKqw5W1WxVzc7MzPRoSZK0Fn1CfwnYNDS/ETj1WjXdefyLgWe78d+rqi9V1Z8Ac8C7z7ZpSdLa9An9o8DWJFuSXADsAQ6P1BwGru2mdwEPVVUBR4B3JXlz92LwV4AnJ9O6JGm11q9UUFWnk+xnEODrgDuraiHJAWC+qg4DdwB3J1lkcIS/p1v3uSS/yuCFo4C5qvrkG/RYJEkrWDH0AapqjsGpmeGxm4amXwR2v8a69zD42KYkacq8IleSGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SG9Ar9JDuSHE+ymOSGMcsvTHJft/zhJJtHll+W5IUkPzeZtiVJa7Fi6CdZB9wOXA1sA/Ym2TZSdj3wXFVdDtwK3DKy/FbgP599u5Kks9Hnm7O2A4tVdQIgySFgJ6/+rtudwEe66QeA25KkqirJDwMngK9MrOvXsf/me8/Fbs5w2417p7JfSZM1rQyBc5MjfU7vbABODs0vdWNja6rqNPA8cEmSi4B/BvzC2bcqSTpbfUI/Y8aqZ80vALdW1Quvu4NkX5L5JPPLy8s9WpIkrUWf0ztLwKah+Y3AqdeoWUqyHrgYeBb4fmBXkn8FvA14OcmLVXXb8MpVdRA4CDA7Ozv6giJJmpA+oX8U2JpkC/AFYA/wwZGaw8C1wGeAXcBDVVXAe18pSPIR4IXRwJcknTsrhn5VnU6yHzgCrAPurKqFJAeA+ao6DNwB3J1kkcER/p43smlJ0tr0OdKnquaAuZGxm4amXwR2r7CNj6yhP0nSBHlFriQ1xNCXpIYY+pLUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SGGPqS1BBDX5IaYuhLUkMMfUlqiKEvSQ0x9CWpIb1CP8mOJMeTLCa5YczyC5Pc1y1/OMnmbvyvJXkkyRPdz/dNtn1J0mqsGPpJ1gG3A1cD24C9SbaNlF0PPFdVlwO3Ard0418C/mZVfTeD79C9e1KNS5JWr8+R/nZgsapOVNVLwCFg50jNTuCubvoB4KokqarHqupUN74AvCnJhZNoXJK0en1CfwNwcmh+qRsbW1NVp4HngUtGav428FhV/enaWpUkna0+X4yeMWO1mpok72Rwyuf9Y3eQ7AP2AVx22WU9WpIkrUWfI/0lYNPQ/Ebg1GvVJFkPXAw8281vBH4H+LGq+vy4HVTVwaqararZmZmZ1T0CSVJvfUL/KLA1yZYkFwB7gMMjNYcZvFELsAt4qKoqyduATwIfqqrfn1TTkqS1W
TH0u3P0+4EjwDHg/qpaSHIgyTVd2R3AJUkWgZ8BXvlY537gcuBfJPls9+9bJv4oJEm99DmnT1XNAXMjYzcNTb8I7B6z3s3AzWfZoyRpQrwiV5IaYuhLUkMMfUlqiKEvSQ0x9CWpIYa+JDXE0Jekhhj6ktQQQ1+SGmLoS1JDDH1JaoihL0kNMfQlqSGGviQ1xNCXpIYY+pLUkF6hn2RHkuNJFpPcMGb5hUnu65Y/nGTz0LIPdePHk3xgcq1LklZrxdBPsg64Hbga2AbsTbJtpOx64Lmquhy4FbilW3cbg+/UfSewA/g33fYkSVPQ50h/O7BYVSeq6iXgELBzpGYncFc3/QBwVZJ044eq6k+r6o+AxW57kqQp6BP6G4CTQ/NL3djYmu6L1J8HLum5riTpHElVvX5Bshv4QFX9/W7+R4HtVfVTQzULXc1SN/95Bkf0B4DPVNU93fgdwFxV/fbIPvYB+7rZ7wSOT+CxrcWlwJemtO/zkc/HmXxOXs3n40zTek7+XFXNrFS0vseGloBNQ/MbgVOvUbOUZD1wMfBsz3WpqoPAwR69vKGSzFfV7LT7OF/4fJzJ5+TVfD7OdL4/J31O7xwFtibZkuQCBm/MHh6pOQxc203vAh6qwZ8Qh4E93ad7tgBbgf8+mdYlSau14pF+VZ1Osh84AqwD7qyqhSQHgPmqOgzcAdydZJHBEf6ebt2FJPcDTwKngZ+sqq+9QY9FkrSCFc/ptyTJvu5Uk/D5GMfn5NV8Ps50vj8nhr4kNcTbMEhSQwx9IMmmJP8lybEkC0l+eto9nQ+SrEvyWJL/NO1epi3J25I8kOQPuv8nPzDtnqYtyT/pfl8+l+TeJG+adk/nUpI7kzyT5HNDY9+c5MEkf9j9fPs0exzH0B84DfxsVX0X8B7gJ8fcaqJFPw0cm3YT54l/DfxuVf154C/S+POSZAPwj4DZqvoLDD7ksWe6XZ1zH2Nwe5lhNwCfrqqtwKe7+fOKoQ9U1Rer6tFu+v8x+IVu+srhJBuBvw78+rR7mbYkbwWuZPApNarqpar68nS7Oi+sB76puzbnzYy5BufPsqr6rww+rThs+JY0dwE/fE6b6sHQH9HdIfR7gYen28nU/Rrw88DL027kPPAOYBn4je50168nuWjaTU1TVX0B+GXgaeCLwPNV9anpdnVe+Naq+iIMDiaBb5lyP2cw9IckeQvw28A/rqo/nnY/05LkbwDPVNUj0+7lPLEeeDfwb6vqe4GvcB7+2X4udeeqdwJbgG8HLkry96bblfow9DtJvpFB4P9WVX1i2v1M2Q8C1yR5isFdVd+X5J7ptjRVS8BSVb3y198DDF4EWvZXgT+qquWq+irwCeAvTbmn88H/SfJtAN3PZ6bczxkMfaC7DfQdwLGq+tVp9zNtVfWhqtpYVZsZvDn3UFU1exRXVf8bOJnkO7uhqxhcZd6yp4H3JHlz9/tzFY2/ud0ZviXNtcB/nGIvY/W54VoLfhD4UeCJJJ/txv55Vc1NsSedX34K+K3u/lMngB+fcj9TVVUPJ3kAeJTBp98e4zy4aeK5lORe4IeAS5MsAR8G/iVwf5LrGbww7p5eh+N5Ra4kNcTTO5LUEENfkhpi6EtSQwx9SWqIoS9JDTH0Jakhhr4kNcTQl6SG/H/vbU+kWX1jogAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "n1, n2 = 4, 6\n", "d1 = Pmf(range(1, n1+1))\n", "d2 = Pmf(range(1, n2+1))\n", "total = d1 + d2\n", "thinkplot.Hist(total)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here's the general function:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def prob_total(k, n1, n2):\n", " d1 = Pmf(range(1, n1+1))\n", " d2 = Pmf(range(1, n2+1))\n", " total = d1 + d2\n", " return total[k]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To check the results, I'll compare them to the likelihoods in the previous table:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4 6 0.08333333333333333 True\n", "4 8 0.0625 True\n", "4 12 0.041666666666666664 True\n", "6 8 0.041666666666666664 True\n", "6 12 0.027777777777777776 True\n", "8 12 0.020833333333333332 True\n" ] } ], "source": [ "for i, row in table.iterrows():\n", " n1, n2 = row.hypo\n", " p = prob_total(3, n1, n2)\n", " print(n1, n2, p, p == row.likelihood)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can answer the second part of the question using the law of total probability. 
The chance of getting `11` on the second roll is\n", "\n", "$\\sum_{n1, n2} P(n1, n2 ~|~ D) \\cdot P(11 ~|~ n1, n2)$\n", "\n", "The first factor is the posterior probability of each pair, which we can read from the table; the second factor is `prob_total(11, n1, n2)`.\n", "\n", "Here's how we compute the total probability:" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.05364583333333333" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "total = 0\n", "for i, row in table.iterrows():\n", "    n1, n2 = row.hypo\n", "    p = prob_total(11, n1, n2)\n", "    total += row.posterior * p\n", "\n", "total" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "This calculation is similar to the first step of the update, so we can also compute it by\n", "\n", "1) Creating a new table that uses the posteriors from `table` as its priors.\n", "\n", "2) Adding the likelihood of getting a total of `11` on the next roll.\n", "\n", "3) Computing the normalizing constant." ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>hypo</th><th>prior</th><th>likelihood</th><th>unnorm</th><th>posterior</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>(4, 6)</td><td>0.3</td><td>0</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>1</th><td>(4, 8)</td><td>0.225</td><td>0.0625</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>2</th><td>(4, 12)</td><td>0.15</td><td>0.0833333</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>3</th><td>(6, 8)</td><td>0.15</td><td>0.0833333</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>4</th><td>(6, 12)</td><td>0.1</td><td>0.0833333</td><td>NaN</td><td>NaN</td></tr>\n",
"    <tr><th>5</th><td>(8, 12)</td><td>0.075</td><td>0.0833333</td><td>NaN</td><td>NaN</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " hypo prior likelihood unnorm posterior\n", "0 (4, 6) 0.3 0 NaN NaN\n", "1 (4, 8) 0.225 0.0625 NaN NaN\n", "2 (4, 12) 0.15 0.0833333 NaN NaN\n", "3 (6, 8) 0.15 0.0833333 NaN NaN\n", "4 (6, 12) 0.1 0.0833333 NaN NaN\n", "5 (8, 12) 0.075 0.0833333 NaN NaN" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "table2 = table.reset()\n", "for i, row in table2.iterrows():\n", "    n1, n2 = row.hypo\n", "    table2.loc[i, 'likelihood'] = prob_total(11, n1, n2)\n", "\n", "table2" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.05364583333333333" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "table2.update()" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>hypo</th><th>prior</th><th>likelihood</th><th>unnorm</th><th>posterior</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>(4, 6)</td><td>0.3</td><td>0</td><td>0</td><td>0</td></tr>\n",
"    <tr><th>1</th><td>(4, 8)</td><td>0.225</td><td>0.0625</td><td>0.0140625</td><td>0.262136</td></tr>\n",
"    <tr><th>2</th><td>(4, 12)</td><td>0.15</td><td>0.0833333</td><td>0.0125</td><td>0.23301</td></tr>\n",
"    <tr><th>3</th><td>(6, 8)</td><td>0.15</td><td>0.0833333</td><td>0.0125</td><td>0.23301</td></tr>\n",
"    <tr><th>4</th><td>(6, 12)</td><td>0.1</td><td>0.0833333</td><td>0.00833333</td><td>0.15534</td></tr>\n",
"    <tr><th>5</th><td>(8, 12)</td><td>0.075</td><td>0.0833333</td><td>0.00625</td><td>0.116505</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " hypo prior likelihood unnorm posterior\n", "0 (4, 6) 0.3 0 0 0\n", "1 (4, 8) 0.225 0.0625 0.0140625 0.262136\n", "2 (4, 12) 0.15 0.0833333 0.0125 0.23301\n", "3 (6, 8) 0.15 0.0833333 0.0125 0.23301\n", "4 (6, 12) 0.1 0.0833333 0.00833333 0.15534\n", "5 (8, 12) 0.075 0.0833333 0.00625 0.116505" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "table2" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Using a Suite" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can solve this problem more concisely, and more efficiently, using a `Suite`.\n", "\n", "First, I'll create a `Pmf` object for each die." ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "dice = {}\n", "for n in sides:\n", "    dice[n] = Pmf(range(1, n+1))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "And a `Pmf` object for the sum of each pair of dice." ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "pairs = {}\n", "for n1 in sides:\n", "    for n2 in sides:\n", "        if n2 > n1:\n", "            pairs[n1, n2] = dice[n1] + dice[n2]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Here's a `Dice` class that implements `Likelihood` by looking up `data`, the observed total, in the `Pmf` that corresponds to `hypo`:" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "class Dice(Suite):\n", "\n", "    def Likelihood(self, data, hypo):\n", "        \"\"\"Likelihood of the data given the hypothesis.\n", "\n", "        data: total of two dice\n", "        hypo: pair of sides\n", "\n", "        return: probability\n", "        \"\"\"\n", "        return pairs[hypo][data]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Here's the prior:" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 6) 0.16666666666666666\n", "(4, 8) 0.16666666666666666\n", "(4, 12) 0.16666666666666666\n", "(6, 8) 0.16666666666666666\n", "(6, 12) 0.16666666666666666\n", "(8, 12) 0.16666666666666666\n" ] } ], "source": [ "suite = Dice(pairs.keys())\n", "suite.Print()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "And the posterior:" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 6) 0.3\n", "(4, 8) 0.225\n", "(4, 12) 0.15\n", "(6, 8) 0.15\n", "(6, 12) 0.1\n", "(8, 12) 0.075\n" ] } ], "source": [ "suite.Update(3)\n", "suite.Print()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "And the posterior predictive probability of getting `11` on the next roll. `Update` returns the normalizing constant, which is the total probability of the data under the current distribution. Note that this call also updates `suite` as if `11` had actually been rolled." ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.05364583333333333" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "suite.Update(11)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }