{ "cells": [ { "cell_type": "markdown", "metadata": { "toc": "true" }, "source": [ "# Table of Contents\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# *Do we even need a smart learning algorithm? Is UCB useless?*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This short notebook demonstrates that \"smart\" Multi-Armed Bandits learning algorithms, like UCB, are indeed needed to learn the distribution of arms, even in the simplest case.\n", "\n", "We will use an example of a small Single-Player simulation, and compare the `UCB` algorithm with a naive \"max empirical reward\" algorithm.\n", "The goal is to illustrate that introducing an exploration term (the confidence width), like what is done in UCB and similar algorithms, really helps learning and improves performance.\n", "\n", "----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Notations for the arms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To remind the usual notations, there is a fixed number $K \\geq 1$ of levers, or \"arms\", and a player has to select one lever at each discrete times $t \\geq 1, t \\in \\mathbb{N}$, ie $k = A(t)$. 
Selecting an arm $k$ at time $t$ yields a (random) *reward*, $r_k(t)$, and the goal of the player is to maximize the cumulative reward $R_T = \\sum_{t = 1}^T r_{A(t)}(t)$.\n", "\n", "Each arm $k = 1,\\dots,K$ is associated with a distribution $\\nu_k$, and the usual restriction is to consider one-dimensional exponential families (this includes Gaussian, Exponential and Bernoulli distributions), i.e., distributions parameterized by their means $\\mu_k$.\n", "So the rewards of arm $k$, $r_k(t) \\sim \\nu_k$, are i.i.d., and assumed bounded in $[a,b] = [0,1]$.\n", "\n", "For instance, arms can follow Bernoulli distributions, of means $\\mu_1,\\dots,\\mu_K \\in [0,1]$: $r_k(t) \\sim \\mathrm{Bern}(\\mu_k)$, i.e., $\\mathbb{P}(r_k(t) = 1) = \\mu_k$.\n", "\n", "Let $N_k(t) = \\sum_{\\tau=1}^t \\mathbb{1}(A(\\tau) = k)$ be the number of times arm $k$ was selected up to time $t \\geq 1$.\n", "The empirical mean of arm $k$ is then defined as $\\hat{\\mu}_k(t) := \\frac{\\sum_{\\tau=1}^t \\mathbb{1}(A(\\tau) = k) \\, r_k(\\tau)}{N_k(t)}$." 
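As a quick illustration of these notations, here is a small standalone NumPy sketch (independent of the SMPyBandits API; the names `mu`, `pulls` and `rewards` are mine): it draws Bernoulli rewards for a uniformly random player and computes the counts $N_k(t)$ and the empirical means $\hat{\mu}_k(t)$.

```python
import numpy as np

rng = np.random.default_rng(42)   # fixed seed, for reproducibility

K = 3
mu = np.array([0.2, 0.5, 0.8])    # true (unknown) means of the K Bernoulli arms
T = 1000

pulls = np.zeros(K, dtype=int)    # N_k(t): number of selections of each arm
rewards = np.zeros(K)             # cumulative reward collected on each arm

for t in range(T):
    k = rng.integers(K)                  # here: a uniformly random choice A(t)
    r = float(rng.random() < mu[k])      # r_k(t) ~ Bern(mu_k)
    pulls[k] += 1
    rewards[k] += r

mu_hat = rewards / pulls                 # empirical means, close to mu for large N_k(t)
print("pulls:", pulls)
print("empirical means:", np.round(mu_hat, 3))
```

With $T = 1000$ uniform selections, each arm is pulled roughly $T/K$ times and the empirical means concentrate around the true means.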
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Importing the algorithms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, be sure to be in the main folder, and import `Evaluator` from `Environment` package:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Info: Using the Jupyter notebook version of the tqdm() decorator, tqdm_notebook() ...\n" ] } ], "source": [ "# Local imports\n", "from SMPyBandits.Environment import Evaluator, tqdm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need arms, for instance `Bernoulli`-distributed arm:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Import arms\n", "from SMPyBandits.Arms import Bernoulli" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And finally we need some single-player Reinforcement Learning algorithms.\n", "I focus here on the `UCB` index policy, and the base class `IndexPolicy` will be used to easily define another algorithm." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Import algorithms\n", "from SMPyBandits.Policies import UCB, UCBalpha, EmpiricalMeans\n", "from SMPyBandits.Policies.IndexPolicy import IndexPolicy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The `UCB` algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we can check the documentation of the `UCB` class, implementing the **Upper-Confidence Bounds algorithm**." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "code_folding": [ 0 ] }, "outputs": [], "source": [ "# Just improving the ?? in Jupyter. 
Thanks to https://nbviewer.jupyter.org/gist/minrk/7715212\n", "from __future__ import print_function\n", "from IPython.core import page\n", "def myprint(s):\n", " try:\n", " print(s['text/plain'])\n", " except (KeyError, TypeError):\n", " print(s)\n", "page.page = myprint" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[0;31mInit signature:\u001b[0m \u001b[0mUCB\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnbArms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlower\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mamplitude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mDocstring:\u001b[0m \n", "The UCB policy for bounded bandits.\n", "\n", "- Reference: [Lai & Robbins, 1985].\n", "\u001b[0;31mInit docstring:\u001b[0m\n", "New generic index policy.\n", "\n", "- nbArms: the number of arms,\n", "- lower, amplitude: lower value and known amplitude of the rewards.\n", "\u001b[0;31mFile:\u001b[0m /tmp/SMPyBandits/notebooks/venv3/lib/python3.6/site-packages/SMPyBandits/Policies/UCB.py\n", "\u001b[0;31mType:\u001b[0m type\n", "\n" ] } ], "source": [ "UCB?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us quickly have a look to the code of the `UCB` policy imported above." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[0;31mInit signature:\u001b[0m \u001b[0mUCB\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnbArms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlower\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mamplitude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mSource:\u001b[0m \n", "\u001b[0;32mclass\u001b[0m \u001b[0mUCB\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mIndexPolicy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\" The UCB policy for bounded bandits.\u001b[0m\n", "\u001b[0;34m\u001b[0m\n", "\u001b[0;34m - Reference: [Lai & Robbins, 1985].\u001b[0m\n", "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcomputeIndex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34mr\"\"\" Compute the current index, at time t and after :math:`N_k(t)` pulls of arm k:\u001b[0m\n", "\u001b[0;34m\u001b[0m\n", "\u001b[0;34m .. 
math:: I_k(t) = \\frac{X_k(t)}{N_k(t)} + \\sqrt{\\frac{2 \\log(t)}{N_k(t)}}.\u001b[0m\n", "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0marm\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'+inf'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrewards\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0marm\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0marm\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0marm\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcomputeAllIndex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\" Compute the current indexes for all arms, in a vectorized manner.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mindexes\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrewards\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mindexes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'+inf'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mindexes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mFile:\u001b[0m /tmp/SMPyBandits/notebooks/venv3/lib/python3.6/site-packages/SMPyBandits/Policies/UCB.py\n", "\u001b[0;31mType:\u001b[0m type\n", "\n" ] } ], "source": [ "UCB??" 
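The `computeIndex` logic displayed above can be reproduced as a tiny standalone function (a sketch for illustration, not the library code; the arguments mirror the `rewards`, `pulls` and `t` attributes used in the class):

```python
from math import log, sqrt

def ucb_index(rewards_k, pulls_k, t):
    """UCB index of one arm: empirical mean + confidence width sqrt(2 log(t) / N_k(t)).

    rewards_k: cumulative reward of the arm,
    pulls_k:   number of pulls N_k(t) of the arm,
    t:         current time step.
    """
    if pulls_k < 1:                # never pulled: infinite index forces exploration
        return float('+inf')
    return (rewards_k / pulls_k) + sqrt((2 * log(t)) / pulls_k)

# Both arms have empirical mean 0.5, but the rarely pulled one gets a larger bonus:
print(ucb_index(2.0, 4, 100))      # 4 pulls   -> wide confidence interval
print(ucb_index(50.0, 100, 100))   # 100 pulls -> narrow confidence interval
```

The confidence width shrinks as $N_k(t)$ grows, so the index of a well-explored arm converges to its empirical mean.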
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This policy is defined by inheriting from `IndexPolicy`, which is a generic class already implementing all the methods (`choice()` to get $A(t) \\in \\{1,\\dots,K\\}$, etc).\n", "The only method defined in this class is the `computeIndex(arm)` method, which here uses a UCB index: the empirical mean plus a confidence width term (hence the name \"upper confidence bound\").\n", "\n", "For the classical `UCB` algorithm, with $\\alpha=4$, the index is computed in two parts:\n", "\n", "- the empirical mean: $\\hat{\\mu}_k(t) := \\frac{\\sum_{\\tau=1}^t \\mathbb{1}(A(t) = k) r_k(t) }{N_k(t)}$, computed as `rewards[k] / pulls[k]` in the code,\n", "- the upper confidence bound, $B_k(t) := \\sqrt{\\frac{\\alpha \\log(t)}{2 N_k(t)}}$, computed as `sqrt((2 * log(t)) / pulls[k]` in the code." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then the index $X_k(t) = \\hat{\\mu}_k(t) + B_k(t)$ is used to decide which arm to select at time $t+1$:\n", "$$ A(t+1) = \\arg\\max_k X_k(t). $$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The simple `UCB1` algorithm uses $\\alpha = 4$, but empirically $\\alpha = 1$ is known to work better." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The `EmpiricalMeans` algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can write a new bandit algorithm quite easily with my framework.\n", "For simple index-based policy, we simply need to write a `computeIndex(arm)` method, as presented above.\n", "\n", "The `EmpiricalMeans` algorithm will be simpler than `UCB`, as the decision will only be based on the empirical means $\\hat{\\mu}_k(t)$:\n", "$$ A(t+1) = \\arg\\max_k \\hat{\\mu}_k(t). 
$$" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[0;31mInit signature:\u001b[0m \u001b[0mEmpiricalMeans\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnbArms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlower\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mamplitude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mDocstring:\u001b[0m The naive Empirical Means policy for bounded bandits: like UCB but without a bias correction term. Note that it is equal to UCBalpha with alpha=0, only quicker.\n", "\u001b[0;31mInit docstring:\u001b[0m\n", "New generic index policy.\n", "\n", "- nbArms: the number of arms,\n", "- lower, amplitude: lower value and known amplitude of the rewards.\n", "\u001b[0;31mFile:\u001b[0m /tmp/SMPyBandits/notebooks/venv3/lib/python3.6/site-packages/SMPyBandits/Policies/EmpiricalMeans.py\n", "\u001b[0;31mType:\u001b[0m type\n", "\n" ] } ], "source": [ "EmpiricalMeans?" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[0;31mInit signature:\u001b[0m \u001b[0mEmpiricalMeans\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnbArms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlower\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mamplitude\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mSource:\u001b[0m \n", "\u001b[0;32mclass\u001b[0m \u001b[0mEmpiricalMeans\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mIndexPolicy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\" The naive Empirical Means policy for bounded bandits: like UCB but without a bias correction term. 
Note that it is equal to UCBalpha with alpha=0, only quicker.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcomputeIndex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34mr\"\"\" Compute the current index, at time t and after :math:`N_k(t)` pulls of arm k:\u001b[0m\n", "\u001b[0;34m\u001b[0m\n", "\u001b[0;34m .. math:: I_k(t) = \\frac{X_k(t)}{N_k(t)}.\u001b[0m\n", "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0marm\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'+inf'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrewards\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0marm\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0marm\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcomputeAllIndex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\" Compute the current indexes for all arms, in a vectorized manner.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m 
\u001b[0mindexes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrewards\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mindexes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpulls\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'+inf'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mindexes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mFile:\u001b[0m /tmp/SMPyBandits/notebooks/venv3/lib/python3.6/site-packages/SMPyBandits/Policies/EmpiricalMeans.py\n", "\u001b[0;31mType:\u001b[0m type\n", "\n" ] } ], "source": [ "EmpiricalMeans??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "----" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating some MAB problems" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Parameters for the simulation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- $T = 10000$ is the time horizon,\n", "- $N = 100$ is the number of repetitions,\n", "- `N_JOBS = 4` is the number of cores used to parallelize the code." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "HORIZON = 10000\n", "REPETITIONS = 100\n", "N_JOBS = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Some MAB problem with Bernoulli arms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We consider in this example $3$ problems, with `Bernoulli` arms, of different means." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "ENVIRONMENTS = [ # 1) Bernoulli arms\n", " { # A very easy problem, but it is used in a lot of articles\n", " \"arm_type\": Bernoulli,\n", " \"params\": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n", " },\n", " { # An other problem, best arm = last, with three groups: very bad arms (0.01, 0.02), middle arms (0.3 - 0.6) and very good arms (0.78, 0.8, 0.82)\n", " \"arm_type\": Bernoulli,\n", " \"params\": [0.01, 0.02, 0.3, 0.4, 0.5, 0.6, 0.795, 0.8, 0.805]\n", " },\n", " { # A very hard problem, as used in [Cappé et al, 2012]\n", " \"arm_type\": Bernoulli,\n", " \"params\": [0.01, 0.01, 0.01, 0.02, 0.02, 0.02, 0.05, 0.05, 0.1]\n", " },\n", " ]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Some RL algorithms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We simply want to compare the $\\mathrm{UCB}_1$ algorithm (`UCB`) against the `EmpiricalMeans` algorithm, defined above." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "POLICIES = [\n", " # --- UCB1 algorithm\n", " {\n", " \"archtype\": UCB,\n", " \"params\": {}\n", " },\n", " # --- UCB alpha algorithm with alpha=1/2\n", " {\n", " \"archtype\": UCBalpha,\n", " \"params\": {\n", " \"alpha\": 0.5\n", " }\n", " },\n", " # --- EmpiricalMeans algorithm\n", " {\n", " \"archtype\": EmpiricalMeans,\n", " \"params\": {}\n", " },\n", " ]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So the complete configuration for the problem will be this dictionary:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'horizon': 10000,\n", " 'repetitions': 100,\n", " 'n_jobs': 4,\n", " 'verbosity': 6,\n", " 'environment': [{'arm_type': SMPyBandits.Arms.Bernoulli.Bernoulli,\n", " 'params': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]},\n", " {'arm_type': SMPyBandits.Arms.Bernoulli.Bernoulli,\n", " 'params': [0.01, 0.02, 0.3, 0.4, 0.5, 0.6, 0.795, 0.8, 0.805]},\n", " {'arm_type': SMPyBandits.Arms.Bernoulli.Bernoulli,\n", " 'params': [0.01, 0.01, 0.01, 0.02, 0.02, 0.02, 0.05, 0.05, 0.1]}],\n", " 'policies': [{'archtype': SMPyBandits.Policies.UCB.UCB, 'params': {}},\n", " {'archtype': SMPyBandits.Policies.UCBalpha.UCBalpha,\n", " 'params': {'alpha': 0.5}},\n", " {'archtype': SMPyBandits.Policies.EmpiricalMeans.EmpiricalMeans,\n", " 'params': {}}]}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "configuration = {\n", " # --- Duration of the experiment\n", " \"horizon\": HORIZON,\n", " # --- Number of repetition of the experiment (to have an average)\n", " \"repetitions\": REPETITIONS,\n", " # --- Parameters for the use of joblib.Parallel\n", " \"n_jobs\": N_JOBS, # = nb of CPU cores\n", " \"verbosity\": 6, # Max joblib verbosity\n", " # --- Arms\n", " \"environment\": ENVIRONMENTS,\n", " # --- Algorithms\n", " \"policies\": POLICIES,\n", "}\n", 
"configuration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating the `Evaluator` object" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of policies in this comparison: 3\n", "Time horizon: 10000\n", "Number of repetitions: 100\n", "Sampling rate for plotting, delta_t_plot: 1\n", "Number of jobs for parallelization: 4\n", "Using this dictionary to create a new environment:\n", " {'arm_type':