{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-14-mab.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/P920826%20%7C%20Multi-Armed%20Bandits.ipynb","timestamp":1644613989516}],"collapsed_sections":[],"authorship_tag":"ABX9TyPgqv1n0wEAKJ0oCISigkms"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Multi-Armed Bandits"],"metadata":{"id":"KR6cZY4I9aOu"}},{"cell_type":"markdown","metadata":{"id":"JL13uxO2M40G"},"source":["## Agents"]},{"cell_type":"markdown","metadata":{"id":"MHlgXpT9NRee"},"source":["### Base"]},{"cell_type":"code","metadata":{"id":"c1E-ptqGNRbF"},"source":["\n","import numpy as np\n","from collections import OrderedDict"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-w_5Of8fOAKd"},"source":["\n","class Agent(object):\n"," def __init__(self, name):\n"," self.name = name\n"," self.history = OrderedDict()\n","\n"," def init(self):\n"," pass\n","\n"," def step(self, t):\n"," action = self._step(t)\n"," self.history[t] = (action, np.nan)\n"," return action\n","\n"," def _step(self, t):\n"," raise NotImplementedError\n","\n"," def get_reward(self, reward, t):\n"," assert t in self.history, \"time t when the action was taken doesn't exist in history\"\n"," action = self.history[t][0]\n"," self.history[t] = (action, reward)\n"," self._get_reward(action, reward, t)\n","\n"," def _get_reward(self, action, reward, t):\n"," raise NotImplementedError\n","\n"," def reset(self):\n"," self.history = OrderedDict()\n"," self.init()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"lS1tFETvNRYz"},"source":["### Epsilon Greedy"]},{"cell_type":"code","metadata":{"id":"ufQof1W-NRUj"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"JIfMj5SROFrJ"},"source":["\n","class EpsilonGreedy(Agent):\n"," def 
__init__(self, n_actions, epsilon=None):\n"," name = r\"$\\epsilon$-greedy with $\\epsilon$={}\".format(epsilon) if epsilon is not None else r\"Optimal $\\epsilon$-greedy algorithm\"\n"," super(EpsilonGreedy, self).__init__(name)\n"," self.epsilon = epsilon\n"," self.n_actions = n_actions\n"," self.count_actions = None\n"," self.exp_reward = None\n"," self.init()\n","\n"," def init(self):\n"," self.count_actions = np.zeros(self.n_actions)\n"," self.exp_reward = np.zeros(self.n_actions)\n","\n"," def _step(self, t):\n"," if self.epsilon is None:\n"," epsilon = (t + 1) ** (-1/3) * (self.n_actions * np.log(t + 1)) ** (1/3)\n"," else:\n"," epsilon = self.epsilon\n"," valid_actions = np.arange(self.n_actions)\n"," if np.random.random() > epsilon:\n"," r = self.exp_reward\n"," valid_actions = valid_actions[r == r.max()]\n"," action = np.random.choice(valid_actions)\n"," self.count_actions[action] += 1\n"," return action\n","\n"," def _get_reward(self, action, reward, t):\n"," self.exp_reward[action] += (reward - self.exp_reward[action]) / self.count_actions[action]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OAyZI8S_NRSY"},"source":["### Explore-Exploit"]},{"cell_type":"code","metadata":{"cellView":"form","id":"28tOfZiONRQh"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"2U2TMoDROKjB"},"source":["\n","class ExploreExploit(Agent):\n"," def __init__(self, n_actions, n_explore):\n"," super(ExploreExploit, self).__init__(\"Explore and exploit algorithm $N={}$\".format(n_explore))\n"," self.n_explore = n_explore\n"," self.n_actions = n_actions\n"," self.count_actions = None\n"," self.sum_reward = None\n"," self.chosen_action = None\n"," self.init()\n","\n"," def init(self):\n"," # builtin int: the np.int alias was removed in NumPy 1.24 and now raises AttributeError\n"," self.count_actions = np.zeros(self.n_actions, dtype=int)\n"," self.sum_reward = 
np.zeros(self.n_actions)\n","\n"," def _step(self, t):\n"," count = self.count_actions.sum()\n"," 
if self.count_actions.min() < self.n_explore:\n"," action = self.count_actions.argmin()\n"," elif count == self.n_explore * self.n_actions:\n"," action = np.random.choice(np.arange(self.n_actions)[self.sum_reward == self.sum_reward.max()])\n"," self.chosen_action = action\n"," else:\n"," action = self.chosen_action\n"," self.count_actions[action] += 1\n"," return action\n","\n"," def _get_reward(self, action, reward, t):\n"," self.sum_reward[action] += reward\n","\n","\n","class ExploreExploitOptimal(ExploreExploit):\n"," def __init__(self, n_actions, n_steps):\n"," n_explore = int((n_steps / n_actions * np.sqrt(2 * np.log(n_steps))) ** (2/3))\n"," super(ExploreExploitOptimal, self).__init__(n_actions, n_explore)\n"," self.name = r\"Optimal explore and exploit algorithm ($N = {}$)\".format(n_explore)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"M-QO1m-CNROU"},"source":["### Successive Elimination"]},{"cell_type":"code","metadata":{"cellView":"form","id":"CWmNixf1Nat4"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"6FjtYgRVOUDu"},"source":["\n","class SuccessiveElimination(Agent):\n"," def __init__(self, n_actions, n_steps):\n"," super(SuccessiveElimination, self).__init__(\"Successive elimination\")\n"," self.n_steps = n_steps\n"," self.n_actions = n_actions\n"," self.count_actions = None\n"," self.exp_reward = None\n"," self.active_actions = None\n"," self.round_actions = None\n"," self.init()\n","\n"," def init(self):\n"," self.count_actions = np.zeros(self.n_actions)\n"," self.exp_reward = np.zeros(self.n_actions)\n"," self.active_actions = np.ones(self.n_actions)\n"," self.round_actions = list(np.arange(self.n_actions))\n","\n"," def calc_conf_radius(self):\n"," return np.sqrt(2 * np.log(self.n_steps) / self.count_actions)\n","\n"," def ucb(self):\n"," return self.exp_reward + self.calc_conf_radius()\n","\n"," def lcb(self):\n"," return 
self.exp_reward - self.calc_conf_radius()\n","\n"," def _step(self, t):\n"," if self.round_actions:\n"," action = self.round_actions.pop()\n"," elif self.active_actions.sum() == 1:\n"," action = self.active_actions.argmax()\n"," else:\n"," lcb, ucb = self.lcb(), self.ucb()\n"," lcb_max = lcb[self.active_actions.astype(bool)].max()\n"," stay_active = ucb >= lcb_max\n"," self.active_actions *= stay_active\n"," self.round_actions = list(np.arange(self.n_actions)[self.active_actions.astype(bool)])\n"," action = self.round_actions.pop()\n"," self.count_actions[action] += 1\n"," return action\n","\n"," def _get_reward(self, action, reward, t):\n"," self.exp_reward[action] += (reward - self.exp_reward[action]) / self.count_actions[action]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PH35rbInNarl"},"source":["### UCB1"]},{"cell_type":"code","metadata":{"cellView":"form","id":"DhZiSH-kNhT0"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"xt2Bkm8EPdoU"},"source":["\n","class UCB1(Agent):\n"," def __init__(self, n_actions, n_steps=None):\n"," super(UCB1, self).__init__(\"UCB1\" if n_steps is None else \"UCB1 with fixed total #steps\")\n"," self.n_steps = n_steps\n"," self.n_actions = n_actions\n"," self.count_actions = None\n"," self.exp_reward = None\n"," self.round_actions = None\n"," self.init()\n","\n"," def init(self):\n"," self.count_actions = np.zeros(self.n_actions)\n"," self.exp_reward = np.zeros(self.n_actions)\n"," self.round_actions = list(np.arange(self.n_actions))\n","\n"," def calc_conf_radius(self):\n"," T = self.count_actions.sum() if self.n_steps is None else self.n_steps\n"," return np.sqrt(2 * np.log(T) / self.count_actions)\n","\n"," def ucb(self):\n"," return self.exp_reward + self.calc_conf_radius()\n","\n"," def _step(self, t):\n"," if self.round_actions:\n"," action = self.round_actions.pop()\n"," else:\n"," action = 
self.ucb().argmax()\n"," self.count_actions[action] += 1\n"," return action\n","\n"," def _get_reward(self, action, reward, t):\n"," self.exp_reward[action] += (reward - self.exp_reward[action]) / self.count_actions[action]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"PSJLN6NkNhCe"},"source":["### UCB2"]},{"cell_type":"code","metadata":{"cellView":"form","id":"57YFo814Ng_o"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"myR0nk0iPkeC"},"source":["\n","class UCB2(Agent):\n"," def __init__(self, n_actions, alpha):\n"," super(UCB2, self).__init__(r\"UCB2 where $\\alpha$={}\".format(alpha))\n"," self.n_actions = n_actions\n"," self.alpha = alpha\n"," self.count_actions = None\n"," self.exp_reward = None\n"," self.count_r = None\n"," self.selected_action = None\n"," self.init()\n","\n"," def init(self):\n"," self.count_actions = np.zeros(self.n_actions)\n"," self.exp_reward = np.zeros(self.n_actions)\n"," self.count_r = np.zeros(self.n_actions)\n"," self.selected_action = None\n","\n"," def tau(self, r):\n"," return np.ceil((1.0 + self.alpha) ** r)\n","\n"," def ucb2(self):\n"," tau = self.tau(self.count_r)\n"," radius = np.sqrt((1 + self.alpha) * np.log(np.e * self.count_actions.sum() / tau) / (2 * tau))\n"," return self.exp_reward + radius\n","\n"," def _step(self, t):\n"," if self.count_actions.min() == 0:\n"," action = self.count_actions.argmin()\n"," else:\n"," while True:\n"," if self.selected_action is None:\n"," action = self.ucb2().argmax()\n"," r_a = self.count_r[action]\n"," n_times = self.tau(r_a + 1) - self.tau(r_a)\n"," self.selected_action = (action, n_times)\n"," action, n_remaining = self.selected_action\n"," if n_remaining == 0:\n"," self.selected_action = None\n"," self.count_r[action] += 1\n"," else:\n"," break\n"," self.selected_action = (action, n_remaining - 1)\n"," self.count_actions[action] += 1\n"," return action\n","\n"," def 
_get_reward(self, action, reward, t):\n"," self.exp_reward[action] += (reward - self.exp_reward[action]) / self.count_actions[action]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"jgWtcf2dNEh3"},"source":["## Bandits"]},{"cell_type":"markdown","metadata":{"id":"HyMfc-H_Pu7b"},"source":["### Base"]},{"cell_type":"code","metadata":{"cellView":"form","id":"W3f7BbovPtR4"},"source":["\n","class Bandit(object):\n"," def __init__(self, n_actions):\n"," self.t = 0\n"," self.n_actions = n_actions\n","\n"," def init(self):\n"," pass\n","\n"," def step(self):\n"," self.t += 1\n"," self._step()\n","\n"," def _step(self):\n"," pass\n","\n"," def best_expectation(self):\n"," raise NotImplementedError\n","\n"," def get_reward(self, action):\n"," raise NotImplementedError\n","\n"," def reset(self):\n"," self.t = 0\n"," self.init()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"249K9FEcP2oX"},"source":["### Independent"]},{"cell_type":"code","metadata":{"cellView":"form","id":"ZJbIs80YPtOe"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"NAr0nheJPtMr"},"source":["\n","class IndependentBandit(Bandit):\n"," def __init__(self, arms):\n"," super(IndependentBandit, self).__init__(len(arms))\n"," self.arms = arms\n","\n"," def best_expectation(self):\n"," return np.max([arm.get_expectation(self.t) for arm in self.arms])\n","\n"," def get_reward(self, action):\n"," arm = self.arms[action]\n"," return arm.get_reward(self.t), self.best_expectation() - arm.get_expectation(self.t)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gk0VpiIbPtK8"},"source":["## Arms"]},{"cell_type":"markdown","metadata":{"id":"HZLkZ7E9S1OW"},"source":["Types of bandits are:\n","\n","- bernoulli (default): bandit arms have bernoulli distributed rewards\n","- normal: bandit arms have Gaussian distributed rewards\n","- bernoulli 
periodic: success probability of the bernoulli distribution oscillates as a sinusoid."]},{"cell_type":"markdown","metadata":{"id":"3iJWMowzPtIu"},"source":["### Base"]},{"cell_type":"code","metadata":{"cellView":"form","id":"-mY9IFa2QL4H"},"source":["\n","class Arm(object):\n"," def __init__(self):\n"," pass\n","\n"," def get_expectation(self, t):\n"," raise NotImplementedError\n","\n"," def get_reward(self, t):\n"," raise NotImplementedError"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NPzkMf27QGA5"},"source":["### Bernoulli"]},{"cell_type":"code","metadata":{"cellView":"form","id":"0ylXypdDQVlO"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"b3hSvZiZQKwy"},"source":["\n","class BernoulliArm(Arm):\n"," def __init__(self, p_success):\n"," super(BernoulliArm, self).__init__()\n"," self.p_success = p_success\n","\n"," def get_expectation(self, t):\n"," return self.p_success\n","\n"," def get_reward(self, t):\n"," return np.random.binomial(1, self.p_success)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4JF5e13rQHGK"},"source":["### Bernoulli Periodic"]},{"cell_type":"code","metadata":{"cellView":"form","id":"OsZRM8PiQKHq"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"pOcnSJOVQa0A"},"source":["\n","class BernoulliPeriodicArm(Arm):\n"," def __init__(self, p_min, p_max, period, offset=0):\n"," super(BernoulliPeriodicArm, self).__init__()\n"," assert 0 < p_min < p_max < 1, \"wrong initialisation of probability\"\n"," self.p_min = p_min\n"," self.p_max = p_max\n"," self.offset = offset\n"," self.period = period\n","\n"," def p_success(self, t):\n"," y = np.sin(2 * np.pi * (t + self.offset) / self.period)\n"," return (self.p_max - self.p_min) / 2 * y + (self.p_max + self.p_min) / 2\n","\n"," def get_expectation(self, t):\n"," return 
self.p_success(t)\n","\n"," def get_reward(self, t):\n"," return np.random.binomial(1, self.p_success(t))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"GU4RYgE0QHid"},"source":["### Normal"]},{"cell_type":"code","metadata":{"cellView":"form","id":"WvEkvePSPtF4"},"source":["\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"d7fLmQuTQfNW"},"source":["\n","class NormalArm(Arm):\n"," def __init__(self, mean, var):\n"," super(NormalArm, self).__init__()\n"," self.mean = mean\n"," self.var = var\n","\n"," def get_expectation(self, t):\n"," return self.mean\n","\n"," def get_reward(self, t):\n"," # np.random.normal expects the standard deviation as `scale`, so convert the variance\n"," return np.random.normal(self.mean, np.sqrt(self.var))"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AJXsZY1wNJKk"},"source":["## Environments"]},{"cell_type":"code","metadata":{"cellView":"form","id":"2Dl0XpsyQjqm"},"source":["\n","import numpy as np\n","from collections import OrderedDict"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"9MauCBMiQnqz"},"source":["\n","class Environment(object):\n"," def __init__(self, bandit, agent, delay=0):\n"," self.bandit = bandit\n"," self.agent = agent\n"," self.delay = delay\n"," self.t = 0\n"," self.actions = OrderedDict()\n"," self.rewards = OrderedDict()\n"," self.regrets = OrderedDict()\n","\n"," def init(self):\n"," pass\n","\n"," def claim_reward(self, action, delay=0):\n"," reward, regret = self.bandit.get_reward(action)\n"," self.rewards[self.t + delay] = reward\n"," self.regrets[self.t + delay] = regret\n","\n"," def step(self):\n"," action = self._step()\n"," reward = self.rewards[self.t] if self.t in self.rewards else None\n"," regret = self.regrets[self.t] if self.t in self.regrets else None\n"," if reward is not None:\n"," # the reward due now was claimed `delay` steps ago, so credit the action taken at that time\n"," self.agent.get_reward(reward, self.t - 
self.delay)\n"," self.t += 1\n"," self.bandit.step()\n"," return action, reward, regret\n","\n"," def 
_step(self):\n"," action = self.agent.step(self.t)\n"," self.claim_reward(action, delay=self.delay)\n"," return action\n","\n"," def run(self, n_steps):\n"," self.reset()\n"," actions, rewards, cum_rewards, cum_rewards_mean, cum_regrets = [], [], [], [], []\n"," # running totals keep every step O(1) instead of re-summing the whole history each step\n"," reward_sum, n_rewards, regret_sum = 0, 0, 0\n"," for _ in range(n_steps):\n"," action, reward, regret = self.step()\n"," actions.append(action)\n"," rewards.append(reward if reward is not None else 0)\n"," if reward is not None:\n"," reward_sum += reward\n"," n_rewards += 1\n"," cum_rewards.append(reward_sum)\n"," cum_rewards_mean.append(reward_sum / n_rewards if n_rewards else 0)\n"," if regret is not None:\n"," regret_sum += regret\n"," cum_regrets.append(regret_sum)\n","\n"," return actions, rewards, cum_rewards, cum_rewards_mean, cum_regrets\n","\n"," def reset(self):\n"," self.bandit.reset()\n"," self.agent.reset()\n"," self.t = 0\n"," # clear pending (delayed) rewards/regrets so nothing carries over between runs\n"," self.actions = OrderedDict()\n"," self.rewards = OrderedDict()\n"," self.regrets = OrderedDict()\n"," self.init()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"4gQrA45cNLXN"},"source":["## Simulations"]},{"cell_type":"markdown","metadata":{"id":"_nxse79DNNfC"},"source":["### Experiment"]},{"cell_type":"code","metadata":{"cellView":"form","id":"MFeE3HngQtuk"},"source":["\n","import pickle\n","from datetime import datetime\n","import numpy as np"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"Ku4vOfUxQzZs"},"source":["\n","class Experiment(object):\n"," def __init__(self, envs, bandit, name, args, save=True):\n"," self.name = name\n"," self.envs = envs\n"," self.bandit = bandit\n"," self.type_bandit = args.bandit\n"," self.eval_regrets = args.regrets\n"," self.n_regret_eval = args.n_regret_eval\n"," self.n_steps = args.n_steps\n"," self.n_runs = args.n_runs\n"," self.labels = [env.agent.name for env in 
envs]\n"," self.results = None\n"," self.entries = [\"actions\", \"rewards\", 
\"cum_rewards\", \"cum_mean\", \"regrets\", \"final_regrets\"]\n"," self.run()\n"," if save:\n"," self.save()\n","\n"," def run(self):\n"," n_envs = len(self.envs)\n"," n_actions = self.bandit.n_actions\n"," n_steps = self.n_steps\n"," n_runs = self.n_runs\n","\n"," n_reg = self.n_regret_eval\n"," Ts = np.linspace(0, n_steps, n_reg + 1).astype(int)\n","\n"," rewards_runs = np.zeros((n_envs, n_steps))\n"," cum_rewards_runs = np.zeros((n_envs, n_steps))\n"," cum_rewards_mean_runs = np.zeros((n_envs, n_steps))\n"," regrets_runs = np.zeros((n_envs, n_steps))\n"," cum_actions_count_runs = np.zeros((n_envs, n_steps, n_actions))\n"," final_regrets = None\n","\n"," for i_env, env in enumerate(self.envs):\n"," for i_run in range(self.n_runs):\n"," actions, rewards, cum_rewards, cum_rewards_mean, regrets = env.run(n_steps=self.n_steps)\n","\n"," rewards_runs[i_env] += np.array(rewards)\n"," cum_rewards_runs[i_env] += np.array(cum_rewards)\n"," cum_rewards_mean_runs[i_env] += np.array(cum_rewards_mean)\n"," regrets_runs[i_env] += np.array(regrets)\n"," cum_actions_count_runs[i_env, np.arange(n_steps), np.array(actions)] += 1.0\n","\n"," rewards_runs /= n_runs\n"," cum_rewards_runs /= n_runs\n"," cum_rewards_mean_runs /= n_runs\n"," regrets_runs /= n_runs\n","\n"," if self.eval_regrets:\n"," final_regrets = np.zeros((n_envs, len(Ts)))\n"," final_regrets[:, -1] = regrets_runs[:, -1]\n"," for i, T in enumerate(Ts):\n"," if i == 0 or T == n_steps:\n"," continue\n"," final_regrets_sum = np.zeros(n_envs)\n"," for i_env, env in enumerate(self.envs):\n"," for i_run in range(self.n_runs):\n"," _, _, _, _, regrets = env.run(n_steps=T)\n"," final_regrets_sum[i_env] += regrets[-1]\n"," final_regrets[:, i] = final_regrets_sum / n_runs\n","\n"," self.results = {\n"," \"actions\": cum_actions_count_runs,\n"," \"rewards\": rewards_runs,\n"," \"cum_rewards\": cum_rewards_runs,\n"," \"cum_mean\": cum_rewards_mean_runs,\n"," \"regrets\": regrets_runs,\n"," \"final_regrets\": final_regrets\n"," 
}\n","\n"," def get_results(self):\n"," if self.results is None:\n"," return [None for _ in self.entries]\n"," return [self.results[e] for e in self.entries]\n","\n"," def save(self):\n"," now = datetime.now()\n"," current_time = now.strftime(\"%y%m%d_%H%M%S\")\n"," filename = \"{}_{}_{}_{}_steps_{}_runs\".format(current_time, self.type_bandit, self.name, self.n_steps, self.n_runs)\n"," filepath = \"data/\" + filename + \".p\"\n","\n"," with open(filepath, 'wb') as handle:\n"," pickle.dump(self, handle)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"IvPY1A1aQtrx"},"source":["### IID"]},{"cell_type":"code","metadata":{"cellView":"form","id":"wm9t2t6PQ7sa"},"source":["\n","import numpy as np\n","from datetime import datetime\n","import pickle"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"zFo2PM8vQ7o7"},"source":["\n","# from env.base import Environment\n","# from agent.eps_greedy import EpsilonGreedy\n","# from agent.exp_exp import ExploreExploit, ExploreExploitOptimal\n","# from agent.succ_elim import SuccessiveElimination\n","# from agent.ucb1 import UCB1\n","# from agent.ucb2 import UCB2\n","# from sim.experiment import Experiment\n","# from sim.utils import plot, plot_regrets"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"DMi39h6nQto7"},"source":["\n","def get_envs(bandit, agents, args):\n"," envs = [Environment(bandit, agent, delay=args.delay) for agent in agents]\n"," return envs\n","\n","\n","def run_epsilon_on_iid(bandit, epsilons, args):\n"," agents = [EpsilonGreedy(bandit.n_actions, epsilon) for epsilon in epsilons]\n"," envs = get_envs(bandit, agents, args)\n"," experiment = Experiment(envs, bandit, \"eps-greedy\", args)\n"," plot(experiment)\n","\n","\n","def run_exp_exp_on_iid(bandit, n_explores, args):\n"," agents = [ExploreExploit(bandit.n_actions, n_explore) for n_explore in n_explores]\n"," envs = get_envs(bandit, agents, 
args)\n"," experiment = Experiment(envs, bandit, \"exp-exp\", args)\n"," plot(experiment)\n","\n","\n","def run_exp_exp_opt_on_iid(bandit, args):\n"," return run_regret_experiment(bandit, ExploreExploitOptimal, args, \"exp-exp-opt\")\n","\n","\n","def run_regret_experiment(bandit, Agent, args, name):\n"," n_runs = args.n_runs\n"," n_steps = args.n_steps\n"," Ts = np.linspace(0, n_steps, args.n_regret_eval + 1).astype(int)\n"," final_regrets = np.zeros(len(Ts))\n"," agent = None\n"," env = None\n"," for i, T in enumerate(Ts):\n"," if i == 0:\n"," continue\n"," agent = Agent(bandit.n_actions, T)\n"," env = Environment(bandit, agent, delay=args.delay)\n"," regrets_sum = 0\n"," for _ in range(args.n_runs):\n"," _, _, _, _, regrets = env.run(n_steps=T)\n"," regrets_sum += regrets[-1]\n"," final_regrets[i] = regrets_sum / n_runs\n","\n"," now = datetime.now()\n"," current_time = now.strftime(\"%y%m%d_%H%M%S\")\n"," filename = \"{}_{}_{}_{}_steps_{}_runs\".format(current_time, args.bandit, name + \"-regrets\", n_steps, n_runs)\n"," filepath = \"data/\" + filename + \".p\"\n","\n"," with open(filepath, 'wb') as handle:\n"," pickle.dump((Ts, final_regrets, agent.name), handle)\n","\n"," experiment = Experiment([env], bandit, name, args)\n","\n"," plot_regrets([Ts], [final_regrets], [agent.name])\n"," plot(experiment)\n","\n"," return Ts, final_regrets, agent.name\n","\n","\n","def run_succ_elim_on_iid(bandit, args):\n"," return run_regret_experiment(bandit, SuccessiveElimination, args, \"succ-elim\")\n","\n","\n","def run_ucb1_fixed_steps_on_iid(bandit, args):\n"," agents = [UCB1(bandit.n_actions, n_steps=args.n_steps)]\n"," envs = get_envs(bandit, agents, args)\n"," experiment = Experiment(envs, bandit, \"ucb1-fixed\", args)\n"," plot(experiment)\n","\n","\n","def run_ucb1_on_iid(bandit, args):\n"," agents = [UCB1(bandit.n_actions)]\n"," envs = get_envs(bandit, agents, args)\n"," experiment = Experiment(envs, bandit, \"ucb1\", args)\n"," plot(experiment)\n","\n","\n","def 
run_ucb2_on_iid(bandit, alphas, args):\n"," agents = [UCB2(bandit.n_actions, alpha) for alpha in alphas]\n"," envs = get_envs(bandit, agents, args)\n"," experiment = Experiment(envs, bandit, \"ucb2\", args)\n"," plot(experiment)\n","\n","\n","def run_all_on_iid(bandit, args):\n"," agents = [\n"," ExploreExploitOptimal(bandit.n_actions, args.n_steps),\n"," EpsilonGreedy(bandit.n_actions),\n"," SuccessiveElimination(bandit.n_actions, args.n_steps),\n"," UCB1(bandit.n_actions),\n"," UCB2(bandit.n_actions, 0.01)\n"," ]\n"," envs = get_envs(bandit, agents, args)\n"," experiment = Experiment(envs, bandit, \"all\", args)\n"," plot(experiment)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"MnGmBa_cQtlR"},"source":["### Utils"]},{"cell_type":"code","metadata":{"cellView":"form","id":"EpCORFwoRFTq"},"source":["\n","import numpy as np\n","import matplotlib.pyplot as plt\n","import os, pickle"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"BJitSmZZRDL7"},"source":["\n","def load_data(filename):\n"," # Load an object pickled by this notebook (an Experiment, or a (Ts, regrets, label) tuple).\n"," # Use the path as given when it exists: __file__ is undefined inside a notebook kernel,\n"," # so the original script-relative lookup is kept only as a fallback.\n"," data_path = filename if os.path.exists(filename) else os.path.join(os.path.dirname(__file__)[:-3], filename[2:])\n","\n"," # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.\n"," with open(data_path, 'rb') as handle:\n"," experiment = pickle.load(handle)\n","\n"," return experiment\n","\n","\n","def plot_from_data(filename):\n"," if \"regret\" in filename:\n"," Ts, regrets, labels = load_data(filename)\n"," plot_regrets([Ts], [regrets], [labels])\n"," plt.show()\n"," else:\n"," experiment = load_data(filename)\n"," plot(experiment)\n","\n","\n","def plot_regrets(Ts_arr, final_regrets, labels):\n"," plt.figure()\n"," for Ts, final_regrets_item, label in zip(Ts_arr, final_regrets, labels):\n"," 
plt.loglog(Ts, final_regrets_item, label=label)\n"," plt.legend()\n"," plt.xlabel(r\"Number of total time steps $T$\")\n"," plt.ylabel(r\"$Regret(T)$\")\n","\n","\n","def plot(experiment):\n"," n_steps = experiment.n_steps\n"," n_actions = experiment.bandit.n_actions\n"," labels = experiment.labels\n","\n"," 
actions, rewards, cum_rewards, cum_rewards_mean, regrets, final_regrets = experiment.get_results()\n","\n"," if final_regrets is not None:\n"," Ts = np.linspace(0, n_steps, final_regrets.shape[1]).astype(int)\n"," Ts_arr = [Ts for _ in range(len(labels))]\n"," plot_regrets(Ts_arr, final_regrets, labels)\n","\n"," plt.figure()\n"," for cum_rewards_mean_item, label in zip(cum_rewards_mean, labels):\n"," plt.plot(np.arange(n_steps), cum_rewards_mean_item, label=label)\n"," plt.legend()\n"," plt.xlabel(r\"Number of time steps $t$\")\n"," plt.ylabel(r\"$\\overline{Reward}(t)$\")\n","\n"," for actions_item, label in zip(actions, labels):\n"," plt.figure()\n"," bottom_sum = np.zeros(n_steps)\n"," for action in range(n_actions):\n"," plt.title(label)\n"," count = actions_item[:, action]\n"," plt.fill_between(np.arange(n_steps), bottom_sum, count + bottom_sum, label=str(action))\n"," bottom_sum += count\n"," plt.legend()\n"," plt.xlabel(\"Number of time steps\")\n"," plt.ylabel(\"Action chosen for each time step\")\n","\n"," plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"nbt-dDjyRRrj"},"source":["## Main"]},{"cell_type":"code","metadata":{"cellView":"form","id":"RXx7c32CRSZ-"},"source":["\n","import argparse\n","\n","\n","def parse_command():\n"," parser = argparse.ArgumentParser()\n"," _add_experiment_parser(parser)\n"," # args=[] keeps argparse away from sys.argv, which holds the Jupyter kernel's own flags\n"," args = parser.parse_args(args=[])\n","\n"," return args\n","\n","\n","def _add_experiment_parser(parser):\n"," o_parser = parser.add_argument_group(title='Experiment types')\n"," o_parser.add_argument('--plot', default=\"\",\n"," help=\"experiment data to plot\")\n"," o_parser.add_argument('--exp', type=int, default=0,\n"," help=\"experiment to run\")\n"," o_parser.add_argument('--n_runs', type=int, default=200,\n"," help=\"number of runs\")\n"," o_parser.add_argument('--n_steps', type=int, default=1000,\n"," help=\"number of steps\")\n"," # type=bool would map every non-empty string, including 'False', to True\n"," o_parser.add_argument('--regrets', type=lambda s: 
s.lower() in ('1', 'true', 'yes', 'y'), default=True,\n"," 
help=\"plot regret against the number of rounds\")\n"," o_parser.add_argument('--n_regret_eval', type=int, default=10,\n"," help=\"number of experiments to evaluate regrets\")\n"," o_parser.add_argument('--bandit', type=str, default=\"bernoulli\",\n"," help=\"type of bandit\")\n"," o_parser.add_argument('--delay', type=int, default=0,\n"," help=\"delay of reward\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"cellView":"form","id":"AQS-_y7aRYvf"},"source":["\n","# from sim.iid import *\n","# from cli import parse_command\n","# from sim.utils import plot_from_data\n","# from bandit.independent import IndependentBandit\n","# from bandit.arm.normal import NormalArm\n","# from bandit.arm.bernoulli import BernoulliArm\n","# from bandit.arm.bernoulli_periodic import BernoulliPeriodicArm"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"-77SYM2bR5hB","executionInfo":{"status":"ok","timestamp":1632229412022,"user_tz":-330,"elapsed":1267035,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"2691bfb6-e163-4b25-b731-bea7c0e0af2d"},"source":["\n","MEANS = [0.0, 0.1, 0.2, 0.3, 0.4]\n","VARS = [3.0, 2.4, 1.8, 1.2, 0.6]\n","P_SUCCESSES = [0.4, 0.45, 0.5, 0.55, 0.6]\n","\n","# Experiment.save pickles into data/; exist_ok makes this cell safe to re-run (os is imported in the Utils cell)\n","os.makedirs('data', exist_ok=True)\n","\n","def get_normal_bandit(means, vars):\n"," arms = [NormalArm(mean, var) for mean, var in zip(means, vars)]\n"," return IndependentBandit(arms)\n","\n","\n","def get_bernoulli_bandit(ps):\n"," arms = [BernoulliArm(p) for p in ps]\n"," return IndependentBandit(arms)\n","\n","\n","def get_periodic_bandit(p_min, p_max, period, n):\n"," arms = [BernoulliPeriodicArm(p_min, p_max, period, period * i / n) for i in range(n)]\n"," return IndependentBandit(arms)\n","\n","\n","def main(args):\n","\n"," bandit = get_bernoulli_bandit(P_SUCCESSES)\n"," if args.bandit == \"normal\":\n"," bandit = 
get_normal_bandit(MEANS, VARS)\n"," elif args.bandit == \"periodic\":\n"," bandit = get_periodic_bandit(0.3, 0.7, 100, 5)\n","\n"," if args.plot != \"\":\n"," plot_from_data(args.plot)\n","\n"," elif args.exp == 0:\n"," print(\"Explore-exploit algorithm\")\n"," n_explores = [0, 5, 10, 50, 100, 150]\n"," run_exp_exp_on_iid(bandit, n_explores, args)\n","\n"," elif args.exp == 1:\n"," print(\"Optimal explore-exploit algorithm\")\n"," run_exp_exp_opt_on_iid(bandit, args)\n","\n"," elif args.exp == 2:\n"," print(\"Epsilon-greedy algorithm\")\n"," epsilons = [1e-1, 1e-2, 1e-3, 1e-4, 0.0, None]\n"," run_epsilon_on_iid(bandit, epsilons, args)\n","\n"," elif args.exp == 3:\n"," print(\"Successive elimination algorithm\")\n"," run_succ_elim_on_iid(bandit, args)\n","\n"," elif args.exp == 4:\n"," print(\"UCB1 algorithm\")\n"," run_ucb1_on_iid(bandit, args)\n","\n"," elif args.exp == 5:\n"," print(\"UCB2 algorithm\")\n"," alphas = [0.001, 0.003, 0.01, 0.03, 0.1, 0.3]\n"," run_ucb2_on_iid(bandit, alphas, args)\n","\n"," elif args.exp == 6:\n"," print(\"All algorithms\")\n"," run_all_on_iid(bandit, args)\n","\n","\n","if __name__ == '__main__':\n"," args = parse_command()\n"," main(args)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Explore-exploit algorithm\n"]},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}},{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]}]}