{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Bayesian Statistics Made Simple\n", "===\n", "\n", "Code and exercises from my workshop on Bayesian statistics in Python.\n", "\n", "Copyright 2018 Allen Downey\n", "\n", "MIT License: https://opensource.org/licenses/MIT" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# If we're running on Colab, install empiricaldist\n", "# https://pypi.org/project/empiricaldist/\n", "\n", "import sys\n", "IN_COLAB = 'google.colab' in sys.modules\n", "\n", "if IN_COLAB:\n", "    !pip install empiricaldist" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "\n", "import seaborn as sns\n", "sns.set_style('white')\n", "sns.set_context('talk')\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "from empiricaldist import Pmf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Bayesian bandit problem\n", "\n", "Suppose you have several \"one-armed bandit\" slot machines, and reason to think that they have different probabilities of paying off.\n", "\n", "Each time you play a machine, you either win or lose, and you can use the outcome to update your belief about the probability of winning.\n", "\n", "Then, to decide which machine to play next, you can use the \"Bayesian bandit\" strategy, explained below.\n", "\n", "First, let's see how to do the update." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The prior\n", "\n", "If we know nothing about the probability of winning, we can start with a uniform prior."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def decorate_bandit(title):\n", "    \"\"\"Labels the axes.\n", "    \n", "    title: string\n", "    \"\"\"\n", "    plt.xlabel('Probability of winning')\n", "    plt.ylabel('PMF')\n", "    plt.title(title)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "bandit = Pmf.from_seq(range(101))\n", "bandit.plot()\n", "decorate_bandit('Prior distribution')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The likelihood function\n", "\n", "The likelihood function computes the probability of an outcome (W or L) for a hypothetical value of x, the probability of winning (from 0 to 1)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def likelihood_bandit(data, hypo):\n", "    \"\"\"Likelihood function for Bayesian bandit\n", "    \n", "    data: string, either 'W' or 'L'\n", "    hypo: probability of winning, as a percentage (0-100)\n", "    \n", "    returns: float probability\n", "    \"\"\"\n", "    x = hypo / 100\n", "    if data == 'W':\n", "        return x\n", "    else:\n", "        return 1-x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise 1:** Suppose you play a machine 10 times and win once. What is the posterior distribution of $x$?" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multiple bandits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now suppose we have several bandits and we want to decide which one to play.\n", "\n", "For this example, we have 4 machines with these probabilities:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "actual_probs = [0.10, 0.20, 0.30, 0.40]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function `play` simulates playing one machine once and returns `W` or `L`."
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from random import random\n", "from collections import Counter\n", "\n", "# count how many times we've played each machine\n", "counter = Counter()\n", "\n", "def flip(p):\n", "    \"\"\"Return True with probability p.\"\"\"\n", "    return random() < p\n", "\n", "def play(i):\n", "    \"\"\"Play machine i.\n", "    \n", "    returns: string 'W' or 'L'\n", "    \"\"\"\n", "    counter[i] += 1\n", "    p = actual_probs[i]\n", "    if flip(p):\n", "        return 'W'\n", "    else:\n", "        return 'L'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's a test, playing machine 3 twenty times:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "for i in range(20):\n", "    result = play(3)\n", "    print(result, end=' ')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now I'll make 4 `Pmf` objects to represent our beliefs about the 4 machines." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "prior = range(101)\n", "beliefs = [Pmf.from_seq(prior) for i in range(4)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function displays the four distributions, one per machine:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "options = dict(xticklabels='invisible', yticklabels='invisible')\n", "\n", "def plot(beliefs, **options):\n", "    sns.set_context('paper')\n", "    for i, b in enumerate(beliefs):\n", "        plt.subplot(2, 2, i+1)\n", "        b.plot(label='Machine %s' % i)\n", "        plt.gca().set_yticklabels([])\n", "        plt.legend()\n", "    \n", "    plt.tight_layout()\n", "    sns.set_context('talk')" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": true }, "outputs": [], "source": [ "plot(beliefs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following function updates our beliefs about one of the machines based on one outcome."
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def update(beliefs, i, outcome):\n", " \"\"\"Update beliefs about machine i, given outcome.\n", " \n", " beliefs: list of Pmf\n", " i: index into beliefs\n", " outcome: string 'W' or 'L'\n", " \"\"\"\n", " beliefs[i].update(likelihood_bandit, outcome)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise 2:** Write a nested loop that plays each machine 10 times; then plot the posterior distributions. \n", "\n", "Hint: call `play` and then `update`." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After playing each machine 10 times, we can summarize `beliefs` by printing the posterior mean and credible interval:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "for i, b in enumerate(beliefs):\n", " print(b.mean(), b.credible_interval(0.9))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bayesian Bandits\n", "\n", "To get more information, we could play each machine 100 times, but while we are gathering data, we are not making good use of it. The kernel of the Bayesian Bandits algorithm is that it collects and uses data at the same time. 
In other words, it balances exploration and exploitation.\n", "\n", "The following function chooses among the machines so that the probability of choosing each machine is proportional to its \"probability of superiority\".\n", "\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def choose(beliefs):\n", "    \"\"\"Use the Bayesian bandit strategy to choose a machine.\n", "    \n", "    Draws a sample from each distribution.\n", "    \n", "    returns: index of the machine that yielded the highest value\n", "    \"\"\"\n", "    ps = [b.choice() for b in beliefs]\n", "    return np.argmax(ps)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function draws one value from the posterior distribution of each machine and then uses `argmax` to find the index of the machine that yielded the highest value.\n", "\n", "Here's an example." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "choose(beliefs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise 3:** Putting it all together, fill in the following function to choose a machine, play once, and update `beliefs`:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def choose_play_update(beliefs, verbose=False):\n", "    \"\"\"Choose a machine, play it, and update beliefs.\n", "    \n", "    beliefs: list of Pmf objects\n", "    verbose: Boolean, whether to print results\n", "    \"\"\"\n", "    # choose a machine\n", "    machine = ____\n", "    \n", "    # play it\n", "    outcome = ____\n", "    \n", "    # update beliefs\n", "    update(____)\n", "    \n", "    if verbose:\n", "        print(machine, outcome, beliefs[machine].mean())" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's an example:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [
"choose_play_update(beliefs, verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Trying it out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start again with a fresh set of machines and an empty `Counter`." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "beliefs = [Pmf.from_seq(prior) for i in range(4)]\n", "counter = Counter()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we run the bandit algorithm 100 times, we can see how `beliefs` gets updated:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "num_plays = 100\n", "\n", "for i in range(num_plays):\n", " choose_play_update(beliefs)\n", " \n", "plot(beliefs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can summarize `beliefs` by printing the posterior mean and credible interval:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "for i, b in enumerate(beliefs):\n", " print(b.mean(), b.credible_interval(0.9))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The credible intervals usually contain the true values (10, 20, 30, and 40).\n", "\n", "The estimates are still rough, especially for the lower-probability machines. But that's a feature, not a bug: the goal is to play the high-probability machines most often. Making the estimates more precise is a means to that end, but not an end itself.\n", "\n", "Let's see how many times each machine got played. If things go according to plan, the machines with higher probabilities should get played more often." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "for machine, count in sorted(counter.items()):\n", "    print(machine, count)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "**Exercise 4:** Go back and run this section again with a different value of `num_plays` and see how it does." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 1 }