{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "2021-06-17-recostep-tutorial-offline-replayer-eval-recogym.ipynb", "provenance": [], "collapsed_sections": [], "authorship_tag": "ABX9TyMo+rs0MFOnygJoYfsqfz0h", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "B7BGtRmWYACb" }, "source": [ "## Environment setup" ] }, { "cell_type": "code", "metadata": { "id": "lE0TYT5zVVjy" }, "source": [ "!pip install -q recogym" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "uSUfDrbZywvc" }, "source": [ "# Offline Replayer Evaluation I - Recogym\n", "> Running recogym for offline simulation and evaluation\n", "\n", "- toc: true\n", "- badges: true\n", "- comments: true\n", "- categories: [bandit]\n", "- image: " ] }, { "cell_type": "code", "metadata": { "id": "cUH6Vc1OU89n" }, "source": [ "import numpy as np\n", "from numpy.random.mtrand import RandomState\n", "from scipy.special import logsumexp\n", "import scipy\n", "import pandas as pd\n", "from scipy.stats.distributions import beta\n", "from copy import deepcopy\n", "\n", "from scipy.sparse import csr_matrix\n", "from scipy.sparse.linalg import svds\n", "\n", "from itertools import chain\n", "from sklearn.neighbors import NearestNeighbors\n", "from IPython.display import display, HTML\n", "\n", "from matplotlib.ticker import FormatStrFormatter\n", "\n", "import gym, recogym\n", "from recogym import env_1_args, Configuration\n", "from recogym.agents import OrganicUserEventCounterAgent, organic_user_count_args\n", "from recogym.agents.organic_count import OrganicCount, organic_count_args, to_categorical\n", "from recogym import Configuration\n", "from recogym.agents import Agent\n", "from 
recogym.envs.observation import Observation\n", "from recogym.agents import RandomAgent, random_args\n", "from recogym import verify_agents, verify_agents_IPS\n", "from recogym.evaluate_agent import plot_verify_agents, verify_agents_recall_at_k\n", "\n", "from recogym.envs.session import OrganicSessions\n", "from recogym.envs.context import DefaultContext\n", "from recogym.envs.observation import Observation\n", "\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "P = 2000 # Number of Products\n", "U = 2000 # Number of Users" ], "execution_count": 19, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "vk3eavl0VN_n" }, "source": [ "# You can overwrite environment arguments here\n", "env_1_args['random_seed'] = 42\n", "env_1_args['num_products']= P\n", "env_1_args['phi_var']=0.0\n", "env_1_args['number_of_flips']=P//2\n", "env_1_args['sigma_mu_organic'] = 0.1\n", "env_1_args['sigma_omega']=0.05" ], "execution_count": 4, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "8ShuzevnVI7b" }, "source": [ "# Initialize the gym for the first time by calling .make() and .init_gym()\n", "env = gym.make('reco-gym-v1')\n", "env.init_gym(env_1_args)" ], "execution_count": 5, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "p3ZfkuZcVMZT" }, "source": [ "# env.reset()" ], "execution_count": 6, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 703 }, "id": "zdnvCn74VnMn", "outputId": "b57ab8a8-9fc6-4982-8995-3a41398bfc95" }, "source": [ "# Generate RecSys logs for U users\n", "reco_log = env.generate_logs(U)\n", "reco_log.head(20)" ], "execution_count": 7, "outputs": [ { "output_type": "stream", "text": [ "Organic Users: 0it [00:00, ?it/s]\n", "Users: 100%|██████████| 2000/2000 [02:15<00:00, 14.73it/s]\n" ], "name": "stderr" }, { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
tuzvacpsps-a
00.00organic116<NA>NaNNaNNone
11.00bandit<NA>11230.00.0005()
22.00bandit<NA>13320.00.0005()
33.00bandit<NA>8050.00.0005()
44.00bandit<NA>11840.00.0005()
50.01organic1205<NA>NaNNaNNone
61.01organic1137<NA>NaNNaNNone
72.01organic1337<NA>NaNNaNNone
83.01organic972<NA>NaNNaNNone
94.01organic841<NA>NaNNaNNone
105.01organic140<NA>NaNNaNNone
116.01bandit<NA>13670.00.0005()
127.01bandit<NA>10860.00.0005()
138.01bandit<NA>16980.00.0005()
149.01bandit<NA>15130.00.0005()
1510.01bandit<NA>1340.00.0005()
1611.01bandit<NA>10560.00.0005()
1712.01bandit<NA>17510.00.0005()
1813.01bandit<NA>7250.00.0005()
1914.01bandit<NA>6120.00.0005()
\n", "
" ], "text/plain": [ " t u z v a c ps ps-a\n", "0 0.0 0 organic 116 NaN NaN None\n", "1 1.0 0 bandit 1123 0.0 0.0005 ()\n", "2 2.0 0 bandit 1332 0.0 0.0005 ()\n", "3 3.0 0 bandit 805 0.0 0.0005 ()\n", "4 4.0 0 bandit 1184 0.0 0.0005 ()\n", "5 0.0 1 organic 1205 NaN NaN None\n", "6 1.0 1 organic 1137 NaN NaN None\n", "7 2.0 1 organic 1337 NaN NaN None\n", "8 3.0 1 organic 972 NaN NaN None\n", "9 4.0 1 organic 841 NaN NaN None\n", "10 5.0 1 organic 140 NaN NaN None\n", "11 6.0 1 bandit 1367 0.0 0.0005 ()\n", "12 7.0 1 bandit 1086 0.0 0.0005 ()\n", "13 8.0 1 bandit 1698 0.0 0.0005 ()\n", "14 9.0 1 bandit 1513 0.0 0.0005 ()\n", "15 10.0 1 bandit 134 0.0 0.0005 ()\n", "16 11.0 1 bandit 1056 0.0 0.0005 ()\n", "17 12.0 1 bandit 1751 0.0 0.0005 ()\n", "18 13.0 1 bandit 725 0.0 0.0005 ()\n", "19 14.0 1 bandit 612 0.0 0.0005 ()" ] }, "metadata": { "tags": [] }, "execution_count": 7 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xbB6tEBPVuDZ", "outputId": "22d85507-4855-4552-902b-9d8b5afe4e0e" }, "source": [ "n_events = reco_log.shape[0]\n", "n_organic = reco_log.loc[reco_log['z'] == 'organic'].shape[0]\n", "print('Training on {0} organic and {1} bandit events'.format(n_organic, n_events - n_organic))" ], "execution_count": 8, "outputs": [ { "output_type": "stream", "text": [ "Training on 43753 organic and 159967 bandit events\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "gX5Gve7eYG_r" }, "source": [ "## Defining evaluation methods" ] }, { "cell_type": "markdown", "metadata": { "id": "485rufJ-X8Ju" }, "source": [ "### Traditional evaluation" ] }, { "cell_type": "code", "metadata": { "id": "44GbKe76WZAV" }, "source": [ "def leave_one_out(reco_log, agent, last = False, N = 1, folds = 10):\n", " # 1. Extract all organic events\n", " reco_log = reco_log.loc[reco_log['z'] == 'organic']\n", " \n", " # 2. 
For every user sequence - randomly sample out an item\n", " hits = []\n", " for _ in range(folds):\n", " user_id = 0\n", " history = []\n", " session = OrganicSessions()\n", " agent.reset()\n", " for row in reco_log.itertuples():\n", " # If we have a new user\n", " if row.u != user_id:\n", " if last:\n", " # Sample out last item\n", " index = len(history) - 1\n", " else:\n", " # Sample out a random item from the history\n", " index = np.random.choice(len(history),\n", " replace = False)\n", " test = history[index]\n", " train = history[:index] + history[index + 1:]\n", "\n", " # 3. Recreate the user sequence without these items - Let the agent observe the incomplete sequence\n", " for t, v in list(train):\n", " session.next(DefaultContext(t, user_id), int(v))\n", "\n", " # 4. Generate a top-N set of recommendations by letting the agent act\n", " # TODO - For now only works for N = 1\n", " try:\n", " prob_a = agent.act(Observation(DefaultContext(t + 1, user_id), session), 0, False)['ps-a']\n", " except:\n", " prob_a = [1 / P] * P\n", "\n", " # 5. 
Compute metrics checking whether the sampled test item is in the top-N\n", " try:\n", " hits.append(np.argmax(prob_a) == int(test[1]))\n", " except:\n", " hits.append(0)\n", "\n", " # Reset variables\n", " user_id = row.u\n", " history = []\n", " session = OrganicSessions()\n", " agent.reset()\n", "\n", " # Save the organic interaction to the running average for the session\n", " history.append((row.t,row.v))\n", " \n", " # Error analysis\n", " mean_hits = np.mean(hits)\n", " serr_hits = np.std(hits) / np.sqrt(len(hits))\n", " low_bound = mean_hits - 1.96 * serr_hits\n", " upp_bound = mean_hits + 1.96 * serr_hits\n", " \n", " return mean_hits, low_bound, upp_bound\n", "\n", "def verify_agents_traditional(reco_log, agents, last = False, N = 1, folds = 10):\n", " # Placeholder DataFrame for result\n", " stat = {\n", " 'Agent': [],\n", " '0.025': [],\n", " '0.500' : [],\n", " '0.975': [],\n", " }\n", "\n", " # For every agent\n", " for agent_id in agents:\n", " # Compute HR@k\n", " mean, low, upp = leave_one_out(reco_log, agents[agent_id], last = last, N = N, folds = folds)\n", " stat['Agent'].append(agent_id)\n", " stat['0.025'].append(low)\n", " stat['0.500'].append(mean)\n", " stat['0.975'].append(upp)\n", " return pd.DataFrame().from_dict(stat)" ], "execution_count": 22, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Ka70n5JcYPYE" }, "source": [ "### Counterfactual evaluation" ] }, { "cell_type": "code", "metadata": { "id": "_HoADS4uX54y" }, "source": [ "def compute_ips_weights(agent, reco_log):\n", " # Placeholder for return values\n", " rewards = [] # Labels for actions\n", " t_props = [] # Treatment propensities\n", " l_props = [] # Logging propensities\n", " \n", " # For every logged interaction\n", " user_id = 0\n", " session = OrganicSessions()\n", " agent.reset()\n", " for row in reco_log.itertuples():\n", " # If we have a new user\n", " if row.u != user_id:\n", " # Reset\n", " session = OrganicSessions()\n", " agent.reset()\n", " user_id 
= row.u\n", " \n", " # If we have an organic event\n", " if row.z == 'organic':\n", " session.next(DefaultContext(row.t, row.u), int(row.v)) \n", " \n", " else:\n", " prob_a = agent.act(Observation(DefaultContext(row.t, row.u), session), 0, False)['ps-a']\n", " rewards.append(row.c)\n", " try:\n", " t_props.append(prob_a[int(row.a)])\n", " except:\n", " t_props.append(0)\n", " l_props.append(row.ps)\n", " session = OrganicSessions()\n", " \n", " return np.asarray(rewards), np.asarray(t_props), np.asarray(l_props)\n", "\n", "def verify_agents_counterfactual(reco_log, agents, cap = 3):\n", " # Placeholder DataFrame for results\n", " IPS_stat = {\n", " 'Agent': [],\n", " '0.025': [],\n", " '0.500' : [],\n", " '0.975': [],\n", " }\n", " CIPS_stat = {\n", " 'Agent': [],\n", " '0.025': [],\n", " '0.500' : [],\n", " '0.975': [],\n", " }\n", " SNIPS_stat = {\n", " 'Agent': [],\n", " '0.025': [],\n", " '0.500' : [],\n", " '0.975': [],\n", " }\n", "\n", " # For every agent\n", " for agent_id in agents:\n", " # Get the rewards and propensities\n", " rewards, t_props, l_props = compute_ips_weights(agents[agent_id], reco_log)\n", " \n", " # Compute the sample weights - propensity ratios\n", " p_ratio = t_props / l_props\n", "\n", " # Effective sample size for E_t estimate (from A. Owen)\n", " n_e = len(rewards) * (np.mean(p_ratio) ** 2) / (p_ratio ** 2).mean()\n", " n_e = 0 if np.isnan(n_e) else n_e\n", " print(\"Effective sample size for agent {} is {}\".format(str(agent_id), n_e))\n", " \n", " # Critical value from t-distribution as we have unknown variance\n", " alpha = .00125\n", " cv = scipy.stats.t.ppf(1 - alpha, df = int(n_e) - 1)\n", " \n", " ###############\n", " # VANILLA IPS #\n", " ###############\n", " # Expected reward for pi_t\n", " E_t = np.mean(rewards * p_ratio)\n", "\n", " # Variance of the estimate\n", " var = ((rewards * p_ratio - E_t) ** 2).mean()\n", " stddev = np.sqrt(var)\n", " \n", " # C.I. 
assuming unknown variance - use t-distribution and effective sample size\n", " min_bound = E_t - cv * stddev / np.sqrt(int(n_e))\n", " max_bound = E_t + cv * stddev / np.sqrt(int(n_e))\n", " \n", " # Store result\n", " IPS_stat['Agent'].append(agent_id)\n", " IPS_stat['0.025'].append(min_bound)\n", " IPS_stat['0.500'].append(E_t)\n", " IPS_stat['0.975'].append(max_bound)\n", " \n", " ############## \n", " # CAPPED IPS #\n", " ##############\n", " # Cap ratios\n", " p_ratio_capped = np.clip(p_ratio, a_min = None, a_max = cap)\n", " \n", " # Expected reward for pi_t\n", " E_t_capped = np.mean(rewards * p_ratio_capped)\n", "\n", " # Variance of the estimate\n", " var_capped = ((rewards * p_ratio_capped - E_t_capped) ** 2).mean()\n", " stddev_capped = np.sqrt(var_capped) \n", " \n", " # C.I. assuming unknown variance - use t-distribution and effective sample size\n", " min_bound_capped = E_t_capped - cv * stddev_capped / np.sqrt(int(n_e))\n", " max_bound_capped = E_t_capped + cv * stddev_capped / np.sqrt(int(n_e))\n", " \n", " # Store result\n", " CIPS_stat['Agent'].append(agent_id)\n", " CIPS_stat['0.025'].append(min_bound_capped)\n", " CIPS_stat['0.500'].append(E_t_capped)\n", " CIPS_stat['0.975'].append(max_bound_capped)\n", " \n", " ##############\n", " # NORMED IPS #\n", " ##############\n", " # Expected reward for pi_t\n", " E_t_normed = np.sum(rewards * p_ratio) / np.sum(p_ratio)\n", "\n", " # Variance of the estimate\n", " var_normed = np.sum(((rewards - E_t_normed) ** 2) * (p_ratio ** 2)) / (p_ratio.sum() ** 2) \n", " stddev_normed = np.sqrt(var_normed)\n", "\n", " # C.I. 
assuming unknown variance - use t-distribution and effective sample size\n", " min_bound_normed = E_t_normed - cv * stddev_normed / np.sqrt(int(n_e))\n", " max_bound_normed = E_t_normed + cv * stddev_normed / np.sqrt(int(n_e))\n", "\n", " # Store result\n", " SNIPS_stat['Agent'].append(agent_id)\n", " SNIPS_stat['0.025'].append(min_bound_normed)\n", " SNIPS_stat['0.500'].append(E_t_normed)\n", " SNIPS_stat['0.975'].append(max_bound_normed)\n", " \n", " return pd.DataFrame().from_dict(IPS_stat), pd.DataFrame().from_dict(CIPS_stat), pd.DataFrame().from_dict(SNIPS_stat)" ], "execution_count": 31, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "bvr-NmAdYUSg" }, "source": [ "## Creating agents" ] }, { "cell_type": "markdown", "metadata": { "id": "MTAIbCveYWTy" }, "source": [ "### SVD agent" ] }, { "cell_type": "code", "metadata": { "id": "6xcs7TqWYSbI" }, "source": [ "class SVDAgent(Agent):\n", " def __init__(self, config, U = U, P = P, K = 5):\n", " super(SVDAgent, self).__init__(config)\n", " self.rng = RandomState(self.config.random_seed)\n", " assert(P >= K)\n", " self.K = K\n", " self.R = csr_matrix((U,P))\n", " self.V = np.zeros((P,K))\n", " self.user_history = np.zeros(P)\n", " \n", " def train(self, reco_log, U = U, P = P):\n", " # Extract all organic user logs\n", " reco_log = reco_log.loc[reco_log['z'] == 'organic']\n", " \n", " # Generate ratings matrix for training, row-based for efficient row (user) retrieval\n", " self.R = csr_matrix((np.ones(len(reco_log)),\n", " (reco_log['u'],reco_log['v'])),\n", " (U,P))\n", "\n", " # Singular Value Decomposition\n", " _, _, self.V = svds(self.R, k = self.K)\n", " \n", " def observe(self, observation):\n", " for session in observation.sessions():\n", " self.user_history[session['v']] += 1\n", "\n", " def act(self, observation, reward, done):\n", " \"\"\"Act method returns an Action based on current observation and past history\"\"\"\n", " self.observe(observation)\n", " scores = 
self.user_history.dot(self.V.T).dot(self.V)\n", " action = np.argmax(scores)\n", " prob = np.zeros_like(scores)\n", " prob[action] = 1.0\n", "\n", " return {\n", " **super().act(observation, reward, done),\n", " **{\n", " 'a': action,\n", " 'ps': prob[action],\n", " 'ps-a': prob,\n", " },\n", " }\n", "\n", " def reset(self):\n", " self.user_history = np.zeros(P)" ], "execution_count": 13, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "CwOGMwwyYjTQ" }, "source": [ "### Item-KNN agent" ] }, { "cell_type": "code", "metadata": { "id": "YHIXb-KHYejQ" }, "source": [ "class itemkNNAgent(Agent):\n", " def __init__(self, config, U = U, P = P, k = 5, greedy = False, alpha = 1):\n", " super(itemkNNAgent, self).__init__(config)\n", " self.rng = RandomState(self.config.random_seed)\n", " self.k = min(P,k)\n", " self.greedy = greedy\n", " self.alpha = alpha\n", " self.Rt = csr_matrix((P,U))\n", " self.user_history = np.zeros(P)\n", " self.S = np.eye(P)\n", " \n", " def train(self, reco_log, U = U, P = P):\n", " # Extract all organic user logs\n", " reco_log = reco_log.loc[reco_log['z'] == 'organic']\n", " \n", " # Generate ratings matrix for training, row-based for efficient row (user) retrieval\n", " self.R_t = csr_matrix((np.ones(len(reco_log)),\n", " (reco_log['v'],reco_log['u'])),\n", " (P,U))\n", "\n", " # Set up nearest neighbours module\n", " nn = NearestNeighbors(n_neighbors = self.k,\n", " metric = 'cosine')\n", "\n", " # Initialise placeholder for distances and indices\n", " distances = []\n", " indices = []\n", "\n", " # Dirty fix for multiprocessing backend being unable to pickle large objects\n", " nn.fit(self.R_t)\n", " distances, indices = nn.kneighbors(self.R_t, return_distance = True)\n", "\n", " # Precompute similarity matrix S\n", " data = list(chain.from_iterable(1.0 - distances))\n", " rows = list(chain.from_iterable([i] * self.k for i in range(P)))\n", " cols = list(chain.from_iterable(indices))\n", " \n", " # (P,P)-matrix with cosine 
similarities between items\n", " self.S = csr_matrix((data,(rows, cols)), (P,P)).todense()\n", " \n", " def observe(self, observation):\n", " for session in observation.sessions():\n", " self.user_history[session['v']] += 1\n", "\n", " def act(self, observation, reward, done):\n", " \"\"\"Act method returns an Action based on current observation and past history\"\"\"\n", " self.observe(observation)\n", " scores = self.user_history.dot(self.S).A1\n", " \n", " if self.greedy:\n", " action = np.argmax(scores)\n", " prob = np.zeros_like(scores)\n", " prob[action] = 1.0\n", " else:\n", " scores **= self.alpha\n", " prob = scores / np.sum(scores)\n", " action = self.rng.choice(self.S.shape[0], p = prob)\n", "\n", " return {\n", " **super().act(observation, reward, done),\n", " **{\n", " 'a': action,\n", " 'ps': prob[action],\n", " 'ps-a': prob,\n", " },\n", " }\n", "\n", " def reset(self):\n", " self.user_history = np.zeros(P)" ], "execution_count": 15, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "qK82zI6qYspo" }, "source": [ "### User-KNN agent" ] }, { "cell_type": "code", "metadata": { "id": "uoRDZ8mNYpLo" }, "source": [ "class userkNNAgent(Agent):\n", " def __init__(self, config, U = U, P = P, k = 5, greedy = False, alpha = 1):\n", " super(userkNNAgent, self).__init__(config)\n", " self.rng = RandomState(self.config.random_seed)\n", " self.k = min(P,k)\n", " self.greedy = greedy\n", " self.alpha = alpha\n", " self.U = U\n", " self.P = P\n", " self.R = csr_matrix((U,P))\n", " self.user_history = np.zeros(P)\n", " self.nn = NearestNeighbors(n_neighbors = self.k, metric = 'cosine')\n", " \n", " def train(self, reco_log, U = U, P = P):\n", " # Extract all organic user logs\n", " reco_log = reco_log.loc[reco_log['z'] == 'organic']\n", " \n", " # Generate ratings matrix for training, row-based for efficient row (user) retrieval\n", " self.R = csr_matrix((np.ones(len(reco_log)),\n", " (reco_log['u'],reco_log['v'])),\n", " (U,P))\n", "\n", " # Fit nearest 
neighbours\n", " self.nn.fit(self.R)\n", " \n", " def observe(self, observation):\n", " for session in observation.sessions():\n", " self.user_history[session['v']] += 1\n", "\n", " def act(self, observation, reward, done):\n", " \"\"\"Act method returns an Action based on current observation and past history\"\"\"\n", " self.observe(observation)\n", " \n", " # Get neighbouring users based on user history\n", " distances, indices = self.nn.kneighbors(self.user_history.reshape(1,-1))\n", " scores = np.add.reduce([dist * self.R[idx,:] for dist, idx in zip(distances,indices)])\n", " \n", " if self.greedy:\n", " action = np.argmax(scores)\n", " prob = np.zeros_like(scores)\n", " prob[action] = 1.0\n", " else:\n", " scores **= self.alpha\n", " prob = scores / np.sum(scores)\n", " action = self.rng.choice(self.P, p = prob)\n", "\n", " return {\n", " **super().act(observation, reward, done),\n", " **{\n", " 'a': action,\n", " 'ps': prob[action],\n", " 'ps-a': prob,\n", " },\n", " }\n", "\n", " def reset(self):\n", " self.user_history = np.zeros(P)" ], "execution_count": 16, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "CcVL6ih6Y0p8" }, "source": [ "### Agent initializations" ] }, { "cell_type": "code", "metadata": { "id": "wYX3_5fYYumd" }, "source": [ "# SVD Agent\n", "SVD_agent = SVDAgent(Configuration(env_1_args), U, P, 30)\n", "SVD_agent.train(reco_log)\n", "\n", "# item-kNN Agent\n", "itemkNN_agent = itemkNNAgent(Configuration(env_1_args), U, P, 500, greedy = True)\n", "itemkNN_agent.train(reco_log)\n", "\n", "# user-kNN Agent\n", "userkNN_agent = userkNNAgent(Configuration(env_1_args), U, P, 20, greedy = True)\n", "userkNN_agent.train(reco_log)\n", "\n", "# Generalised Popularity agent\n", "GPOP_agent = OrganicCount(Configuration({\n", " **env_1_args,\n", " 'select_randomly': True,\n", "}))\n", "\n", "# Generalised Popularity agent\n", "GPOP_agent_greedy = OrganicCount(Configuration({\n", " **env_1_args,\n", " 'select_randomly': False,\n", "}))\n", 
"\n", "# Personalised Popularity agent\n", "PPOP_agent = OrganicUserEventCounterAgent(Configuration({\n", " **organic_user_count_args,\n", " **env_1_args,\n", " 'select_randomly': True,\n", "}))\n", "\n", "# Personalised Popularity agent (greedy)\n", "PPOP_agent_greedy = OrganicUserEventCounterAgent(Configuration({\n", " **organic_user_count_args,\n", " **env_1_args,\n", " 'select_randomly': False,\n", "}))\n", "\n", "# Random Agent\n", "random_args['num_products'] = P\n", "RAND_agent = RandomAgent(Configuration({**env_1_args, **random_args,}))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "qRqbHqMJY9vL" }, "source": [ "## Offline evaluation" ] }, { "cell_type": "markdown", "metadata": { "id": "E9r-zhlAZZ9M" }, "source": [ "### Generating test logs" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "pnzTANe1Y3lW", "outputId": "bb01cede-b583-4960-c01c-c0fea14cbc28" }, "source": [ "%%time\n", "# Placeholder for agents\n", "agents = {\n", " ' Random': RAND_agent,\n", " ' Popular': GPOP_agent_greedy,\n", " ' User-pop': PPOP_agent,\n", " ' SVD': SVD_agent,\n", " ' User-kNN': userkNN_agent,\n", " 'Item-kNN': itemkNN_agent,\n", "}\n", "agent_ids = sorted(list(agents.keys()))#['SVD','GPOP','PPOP','RAND']\n", "# Generate new logs, to be used for offline testing\n", "n_test_users = 5000 # U\n", "test_log = env.generate_logs(n_test_users)\n", "n_events = test_log.shape[0]\n", "n_organic = test_log.loc[test_log['z'] == 'organic'].shape[0]\n", "print('Testing on {0} organic and {1} bandit events'.format(n_organic, n_events - n_organic))" ], "execution_count": 18, "outputs": [ { "output_type": "stream", "text": [ "Organic Users: 0it [00:00, ?it/s]\n", "Users: 100%|██████████| 5000/5000 [05:38<00:00, 14.77it/s]\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "Testing on 109631 organic and 403487 bandit events\n", "CPU times: user 6min 4s, sys: 4min 47s, total: 10min 51s\n", "Wall
time: 5min 40s\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "igubir8vaB0m" }, "source": [ "### (Util) helper function to plot barchart" ] }, { "cell_type": "code", "metadata": { "id": "zCpbb9K-ZB-b" }, "source": [ "def plot_barchart(result, title, xlabel, col = 'tab:red', figname = 'fig.eps', size = (6,2), fontsize = 12):\n", " fig, axes = plt.subplots(figsize = size)\n", " plt.title(title, size = fontsize)\n", " n_agents = len(result)\n", " yticks = np.arange(n_agents)\n", " mean = result['0.500']\n", " lower = result['0.500'] - result['0.025']\n", " upper = result['0.975'] - result['0.500']\n", " plt.barh(yticks,\n", " mean,\n", " height = .25,\n", " xerr = (lower, upper),\n", " align = 'center',\n", " color = col,)\n", " plt.yticks(yticks, result['Agent'], size = fontsize)\n", " plt.xticks(size = fontsize)\n", " plt.xlabel(xlabel, size = fontsize)\n", " plt.xlim(.0,None)\n", " plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%.2f'))\n", " plt.savefig(figname, bbox_inches = 'tight')\n", " plt.show()" ], "execution_count": 20, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "w9r6FSeDZdyM" }, "source": [ "### Leave-one-out evaluation" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 479 }, "id": "zJSfQ5RDZpab", "outputId": "57ae38fc-2143-4be7-9c89-8fef3257f9a1" }, "source": [ "%%time\n", "result_LOO = verify_agents_traditional(test_log, deepcopy(agents))\n", "display(result_LOO)\n", "plot_barchart(result_LOO, 'Evaluate on Organic Feedback', 'HR@1', 'tab:red', 'traditional_eval.eps')" ], "execution_count": 23, "outputs": [ { "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/recogym/agents/organic_user_count.py:51: RuntimeWarning: invalid value encountered in true_divide\n", " action_proba = features / np.sum(features)\n" ], "name": "stderr" }, { "output_type": "display_data", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Agent0.0250.5000.975
0Random0.0000000.0000000.000000
1Popular0.0000000.0000000.000000
2User-pop0.0000000.0000000.000000
3SVD0.0745600.0768950.079231
4User-kNN0.0799270.0823360.084746
5Item-kNN0.0744610.0767950.079130
\n", "
" ], "text/plain": [ " Agent 0.025 0.500 0.975\n", "0 Random 0.000000 0.000000 0.000000\n", "1 Popular 0.000000 0.000000 0.000000\n", "2 User-pop 0.000000 0.000000 0.000000\n", "3 SVD 0.074560 0.076895 0.079231\n", "4 User-kNN 0.079927 0.082336 0.084746\n", "5 Item-kNN 0.074461 0.076795 0.079130" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } }, { "output_type": "stream", "text": [ "CPU times: user 6min 45s, sys: 43 s, total: 7min 28s\n", "Wall time: 4min 43s\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "KMEG7NLGZ5Kp" }, "source": [ "### IPS Estimators" ] }, { "cell_type": "code", "metadata": { "id": "RuQ5m3goeHnR" }, "source": [ "# Generate new logs, to be used for offline testing\n", "test_log_ppop = env.generate_logs(n_test_users, agent = deepcopy(PPOP_agent))" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 204 }, "id": "Gjj31s8ge6qZ", "outputId": "27bb9173-54e6-4cd7-d68a-deaeb939857f" }, "source": [ "test_log_ppop.head()" ], "execution_count": 26, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
tuzvacpsps-a
00.00organic220<NA>NaNNaNNone
11.00organic867<NA>NaNNaNNone
22.00organic969<NA>NaNNaNNone
33.00organic730<NA>NaNNaNNone
44.00bandit<NA>9690.00.25()
\n", "
" ], "text/plain": [ " t u z v a c ps ps-a\n", "0 0.0 0 organic 220 NaN NaN None\n", "1 1.0 0 organic 867 NaN NaN None\n", "2 2.0 0 organic 969 NaN NaN None\n", "3 3.0 0 organic 730 NaN NaN None\n", "4 4.0 0 bandit 969 0.0 0.25 ()" ] }, "metadata": { "tags": [] }, "execution_count": 26 } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 975 }, "id": "SdKK38bSZ6Bl", "outputId": "e15dbd9e-cf50-44ff-ae9d-5f8ae7ec26b7" }, "source": [ "%%time\n", "cap = 15\n", "result_IPS, result_CIPS, result_SNIPS = verify_agents_counterfactual(test_log_ppop, deepcopy(agents), cap = cap)\n", "display(result_IPS)\n", "plot_barchart(result_IPS, 'IPS', 'CTR', 'tab:blue', 'bandit_eval_noclip.eps')\n", "display(result_CIPS)\n", "plot_barchart(result_CIPS, 'Clipped IPS', 'CTR', 'tab:blue', 'bandit_eval_clip{0}.eps'.format(cap))" ], "execution_count": 32, "outputs": [ { "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:65: RuntimeWarning: invalid value encountered in double_scalars\n", "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:120: RuntimeWarning: invalid value encountered in double_scalars\n" ], "name": "stderr" }, { "output_type": "stream", "text": [ "Effective sample size for agent Random is 0\n", "Effective sample size for agent Popular is 0\n", "Effective sample size for agent User-pop is 0\n", "Effective sample size for agent SVD is 21243.67667925834\n", "Effective sample size for agent User-kNN is 33971.880634000554\n", "Effective sample size for agent Item-kNN is 39627.482103973896\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Agent0.0250.5000.975
0RandomNaN0.000000NaN
1PopularNaN0.000000NaN
2User-popNaN0.000000NaN
3SVD0.0053230.0130570.020791
4User-kNN0.0109220.0175130.024104
5Item-kNN0.0123140.0182500.024186
\n", "
" ], "text/plain": [ " Agent 0.025 0.500 0.975\n", "0 Random NaN 0.000000 NaN\n", "1 Popular NaN 0.000000 NaN\n", "2 User-pop NaN 0.000000 NaN\n", "3 SVD 0.005323 0.013057 0.020791\n", "4 User-kNN 0.010922 0.017513 0.024104\n", "5 Item-kNN 0.012314 0.018250 0.024186" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } }, { "output_type": "display_data", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Agent0.0250.5000.975
0RandomNaN0.000000NaN
1PopularNaN0.000000NaN
2User-popNaN0.000000NaN
3SVD0.0055900.0121600.018731
4User-kNN0.0108010.0167350.022670
5Item-kNN0.0121840.0175990.023014
\n", "
" ], "text/plain": [ " Agent 0.025 0.500 0.975\n", "0 Random NaN 0.000000 NaN\n", "1 Popular NaN 0.000000 NaN\n", "2 User-pop NaN 0.000000 NaN\n", "3 SVD 0.005590 0.012160 0.018731\n", "4 User-kNN 0.010801 0.016735 0.022670\n", "5 Item-kNN 0.012184 0.017599 0.023014" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaUAAACwCAYAAACvvnEyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXVUlEQVR4nO3debxVZb3H8c9XBbyIggppokDlkLN1MdMs8ZV11exqpV7TUMwcsvGaJl0nJDGyLCu1Qst5zrHBHF9YmkOQI5XlAKKIAiqTA9Pv/vE8Rxf77H04h7PP2Wsfvu/Xa7/cez1rPftZz6v48ay9WF9FBGZmZmWwWqMHYGZm1sJFyczMSsNFyczMSsNFyczMSsNFyczMSsNFyczMSsNFyawBJI2RdHl+P0TSAkmrN2AcUyXt0d3fa1aLi5JZF5F0sKRJueC8KOlWSbtW7hcRz0VEv4hY2ohx1iLpYkln5PfDJEU+lwW5mI0u7LuvpEckzZM0W9Ldkt7TuNFbs1qj0QMw64kkHQeMBo4BbgMWAXsC+wL3NnBonTUgIpZI2hm4S9IjwFPApcBngbuBfsAngVIVWWsOXimZ1Zmk/sBY4CsRcUNELIyIxRHx24g4ocr+LauQNfLniZK+J+mhvPK4WdJ6FfseJWlGXoEdX+hrNUmjJT0taY6ka1uOze0jJU3LbSet7DlGxP3AFGAbYAfg2Yi4K5L5EXF9RDy3sv3bqstFyaz+dgbWBG7sRB+HAl8E3g0sAX5a0b47sBlpRXJi4XehrwH7AbsBGwGvAucBSNoK+DkwMretD2zc0YEp+QiwNfAw8Dfg/ZJ+LGl3Sf062qdZCxcls/pbH5gdEUs60cdlEfFERCwETgEOrLgR4vS8AnscuAj4fN5+DHBSRDwfEW8BY4D98ypsf+B3EfGn3HYKsKyD45oNvAJcCIzOq6NngBHAYOBaYHb+PcrFyTrMvymZ1d8cYKCkNTpRmKYX3k8DegED22jfNr8fCtwoqVhslgIbkFZHbx8XEQslzenguAZWO6eIeAA4EEDSjsA1wEnAdzrYv63ivFIyq7/7gbdIl9FW1iaF90OAxaRVSq32Gfn9dGCviBhQeK0ZES8ALxaPk9SXtKqrq4j4K3AD6fcmsw5xUTKrs4iYC5wKnCdpP0l9JfWStJeks9rZzRckbZULx1jgNxW3jJ+S+90aOJy0MgH4BTBO0lAASYMk7ZvbfgPsI2lXSb1zv53+MyD3d6Skd+XP7wf+G3igs33bqsdFyawLRMTZwHHAycAs0grmq8BN7eziMuBiYCbppomvV7TfQ7oV+y7ghxFxe97+E+AW4HZJ80mFYac8pinAV4ArSaumV4HnO352rbxGKkKPS1oA/JF0k0d7C7DZ2+SQP7NykTQRuDwiLqzSNgx4FujVyRspzErJKyUzMysNFyUzMysNX74zM7PS8ErJzMxKw0XJzMxKw0906ISBAwfGsGHDGj0MM7OmMnny5NkRMaham4tSJwwbNoxJkyY1ehhmZk1F0rRabU19+a7ykf9mZtbc6lKUWiKVJY2SVIoAs0LB+kPF9ssljcnvR+R9zq/Y515Jo7pvtGZmBk2+UmqnnSTt0kb7QmBk/pfyZmbWQPW8
7LUl8AOgV37+1ZKIGCCpDzCO9Fj7PqRnYv1vRLwhaQRwOSnA7HjSI/a/TIqOPof0qP4fRsSZ7RmApM8BZwP7AAvy5rPy9+9e47DX8phOIz3Yst0ef2Euw0b/viOHmFlJTR3/qUYPwahvUfoHKWDsSxGxa2H7eOB9pMjkxaSHQZ7KOzkrG5IeODkYGAVcANwB/CfpkfyTJF0VEc+29eWSDiflt+wREU8VVj7nA1+XtEdE3Fnj8HHAvySNj4gn233GZiU388rRjR5C0xjxwA8aPYSmMXHixC7ru0sv30kScBRpZfRKRMwHzgQOKuy2GBgXEYuBq0mro59ExPz8VOO/A9uv4Ku+CZwAjIiIpyra3iAVnTNqHRwRM0mP/B/bjnM6StIkSZOWvj53RbubmVkHdPVda4OAvsDkVJ8AEFCMdZ5TyIl5I//3pUL7G0A/gHxZsMVWhfcnAGMjotZj+C8ETpD06TbG+n3gaUltFsCImABMAOjz7s38jCYrtQ0PHt/oITSNib58Vwr1LkqVf0jPJhWVrXPyZec6j+hX/Fy4RPdJ4I+SZkbE9VWOWyTpdOC7wJQafc+RdE7ex8zMGqDeReklYGNJvSNiUUQsk3QB8GNJX42IlyUNBraJiNvq+L1TgD2B2yQtjohbquxzGTA67/fvGv38CHiGtJpboW0H92eS/3ZlZlY39f5N6W5SgZgpaXbediIpIfMBSfOAO4Et6vy9RMSjpLvuLpC0V5X2paQbLNZro495pLv1au5jZmZdx9EVnTB8+PDwY4bMzDpG0uSIGF6tbVX4x7NmZtYkXJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0ekRRyumxmzZ6HGZm1jn1ikNviR5fo2L7xZJqRkZ0hzyuxyWtVth2hqSL8/sVxqabmVn3aLqVkqTVV7xXKxuxfIZTNSuKTTczsy7W1XlKb8uX137FOwm0d0XE/+S29wM/I6XNzgJOiYhrc9vFpPiLocBuwL6kh7rW+p5dgauAkRExMW8+Czhd0rURsaTGoSuKTW/Fcehm5eV48+bUbUWJlFN0O+kP/d7AcABJa5Hiz08F9gK2Be6Q9ERE/D0fezCwN+kp4L1rfYGkPUlx6p+LiIcKTTcAB5Li1i+scXh7YtPNupwjzOvD8eb105Xx55W68/LdYtJqZ6OIeDMi7s3b9wGmRsRFEbEkIh4GrgcOKBx7c0TcFxHLIuLNGv0fAPwS2KuiIEEKHzwFOEVSraK2wth0cBy6mVlXqtdKqeWSWK/C+5bPi/P7b5NWSw9JehU4OyJ+TSpUO0l6rWJclxU+T295I2lKPgZSAfpzfv9N4NKIeKLaACPiD5KeB45u4zxWGJvuOHTrao4wrw/HmzenehWlF0nFZxjwj8L295AuzRERM4Ej4e3ffe6U9CdSwbknIj7RRv9v/+EfEVvX2OcA4FeSno+In9TY5yTS701XVf2SdsSmm5lZ16lLUYqIpZKuB8ZJOhKYB+wPbAXcCiDpAOD+iHgeeJVUaJYBvwPGSxoJXJ273AFYEBH/oP1mAB8HJkpaFBE/rzLOiZKeAA4Dflujn/bEpgOOQzczq7d6/qZ0LPAK8BjwMvBV4FMR8VJu3xF4UNIC4BbgGxHxTETMBz5JumV7BjAT+D7Qp6MDiIjnSIVptKQv1djtZNqORF9hbLqZmXUNx6F3guPQzcw6znHoZmbWFFyUzMysNFyUzMysNFyUzMysNFyUzMysNFyUzMysNFyUzMysNFyUzMysNFyUzMysNHpEUZJ0hKR/Spov6SVJf5C0tqTR+aGvlfsPlLRI0jaSRklaKmlBfj0r6SJJmzfiXMzMVmVNX5Qk7QacCXw+ItYGtgSuyc2XA7tIek/FYQcBjxdiLu6PiH5Af2APUrbSZEnbdPkJmJnZ27ozebar7EgqKg8DRMQrwCW5bb6ku4GRwNjCMYcCl1Z2lB/G+jRwrKQhwBjS086rchy6lZ0j
wa3Z9ISi9CDw3ZyDdDswKSLeKrRfQiouYwEkbUGKxljR/1tvAL5X99Ga4767kSPBu093Rob3ZE1/+S4nz34W+CDwe2COpB9JWj3vciOwgaRd8udDgVsjYtYKup5BlfgKx6GbmXWdnrBSIiJuBW6VtBqwO3Ad8CTwy4h4XdJ1wKGS7gcOAb7Vjm4Hk/KhKr/Lceid5Ljv7uNIcGs2PaIotYiIZcBd+Xek4k0KlwA3kS7JrU3t1NmizwB/bmsHJ8+amdVX0xclSfsC/wHcBrxGuvFhN+Cbhd3+nNsmAFdHxKIafa0ODAGOA0YAO3fZwM3MrJWm/00JeBU4Evg3MI90G/gPIuKKlh0ixeteCgylyl13wM45pn0eMBFYB9gxIh7v2qGbmVmR49A7wXHoZmYd5zh0MzNrCi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGnUtSjla/N4q26dK2qOe32VmZj1PU66UJDX9g2TNzKy1bi9KkvaW9HdJ8yW9IOn4Qts+kh6R9Jqkv0jartA2VdKJkh4DFlYWJkkjJD0v6f8kzc77H1Jo7y/pUkmzJE2TdHLOX2pZ4d0n6VxJcyX9U9LHu2E6zMysoBErpV8BR0fE2qTMo7sBJH0A+DVwNLA+8EvgFkl9Csd+nhRjPiAillTpe0NgICmg7zBgQo4/B/gZ0B94Lyna4lDg8MKxOwFP5+NPA26Q1Cp51szMuk4jitJiYCtJ60TEqxHxt7z9KFJS7IMRsTQiLgHeAj5cOPanETE9It5oo/9TIuKtiLiHFI9+YM5JOgj4TkTMj4ipwNnAyMJxLwPnRMTiiLiGlFzbKsGvGIc+a9aKEtXNzKwj6l2UlgC9qmzvRSpGAJ8D9gamSbpHUkuQ3lDgW/nS3WuSXgM2ATYq9DMdQNIQSQtaXoX2VyNiYeHztHz8wDyGaRVtgwufX4jlczxajl1OREyIiOERMXzQoEHV5sDMzFZSvYvSc8AQSWrZIKkv8C5yQYiIv0bEvnnbTcC1edfpwLiIGFB49Y2Iqwr9R+7juYjo1/IqtK8raa3C5yHADGA2qSgOrWh7ofB5cHHchWPNzKyb1LsoPQi8CYyWtGYuEOOBSaSVUW9Jh0jqHxGLSUmvy/KxFwDHSNpJyVqSPiVp7Q6O4fT8PR8F9gGui4ilpOI3TtLakoaSIs8vLxz3LuDrknpJOgDYEvjDyk2DmZmtjLoWpYh4i/Q7zAjgeeAZ0iWwAwuXxkYCUyXNA44BDsnHTiLFmp9Lijh/ChjVwSHMzMfOAK4AjomIf+a2rwEL85juBa4k3VjR4kFgM9Kqahywf0TM6eD3m5lZJ/SYOHRJI4DLI2LjlTh2FPCliNi1I8c5Dt3MrOMch25mZk3BRcnMzEqjxxSliJi4Mpfu8rEXd/TSnZmZ1V+PKUpmZtb8XJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0VpmiJCkkbdrocZiZWW2NCPkbIWlZfsL3fElPSjp8xUeamVlP16iV0oz8dO91gBOBCyRt1aCxtMnR62Zm3aehl+8iuYn0ENWtJPWRdI6kGfl1TkvybDvizidK+lLh8yhJ91b73vz08YclzZM0XdKYQtuwfKnvCEnPkZNxzcys6zW0KElaTdJngAHA48BJpKTZHYDtgQ8BJxcOaSvuvCMWkuLQB5Ceav5lSftV7LMbKb7iv1aifzMzWwmNKkob5WTZ2cBpwMiIeJIUYzE2Il6OiFnA6SwfWQ5V4s47+uX5kUSPR8SyiHgMuIpUhIrGRMTCyuh1x6GbmXWdRv1eMqPGc+o2onVkeTGSvFbceYdI2okUPrgN0BvoA1xXsdv0asdGxARgAqToio5+t5mZ1Va2W8Jn0DqyvBhJXivuHNIlub6Ftg3b+J4rgVuATSKiP/ALQBX7uOCYmXWzshWlq4CTJQ2SNBA4leUjy6FK3Hne/gjwWUl9879HOqKN71kbeCUi3pT0
IeDg+p6GmZmtjLLd7nwG6Tbxx/Ln6/K2FsW489dZPu78x8COwEv5+CuAPWp8z7HA2ZLOBe4BriXd9GBmZg3UNHHonYk77yqOQzcz6zjHoZuZWVNwUTIzs9Io229KNUXERKA0l+7MzKz+vFIyM7PScFEyM7PScFEyM7PScFEyM7PScFEyM7PS6FFFyZHnZmbNrUuLUg7aW5qjz+dJelTSPl35nWZm1ry6Y6V0f44+HwCcD1wtyc+ZMzOzVrrt8l1ELAMuA9YCNgOQ9D5Jd0uakyPOrygWrBx5frykxyTNlXSNpDUL7SdIejFHp3+x+H2S+ku6VNIsSdMknSxptdw2StJ9kn4s6TVJz0jaJW+fLullSYd1y8SYmdnbuq0oSVodOBxYzDtBfgK+Rwrq2xLYBBhTceiBwJ7Ae4DtgFG5vz2B44FPkIpc5RPBfwb0B95LSpU9NH9/i51ITxNfn5SvdDXpKeObAl8AzpXUb6VP2MzMOqw7itKHc/T5m8APgS9ExMsAEfFURNyR481nAT+idSz5TyNiRkS8AvwW2CFvPxC4KCKeyGm0Y1oOyAXwIOA7ETE/IqYCZ7N8tPqzEXFRRCwFriEVxLF5LLcDi0gFajmOQzcz6zrdUZQeiIgBwLqktNePtjRI2kDS1ZJekDSPFOg3sOL4mYX3rwMtq5eNWD6yvBijPhDoReto9cGFzy8V3r8BEBGV21qtlCJiQkQMj4jhgwYNqmw2M7NO6M7flBYAXwZGSvpA3nwmKXZ824hYh3TZrDKWvJYXSaubFkMK72eTLhNWRqu/sBJDNzOzbtKt/04pX4K7kBRzDimWfAEwV9Jg4IQOdHctMErSVpL6AqcVvmdpbh8naW1JQ4HjaB2tbmZmJdKIfzx7DrC3pO2A04EPAnOB3wM3tLeTiLg193U38FT+b9HXgIXAM8C9pJsZft3ZwZuZWddpmjj0MnIcuplZxzkO3czMmoJXSp0gaT7wZKPHUUIDSTeb2Ds8J615TqpbFeZlaERUvX25aeLQS+rJWkvQVZmkSZ6X5XlOWvOcVLeqz4sv35mZWWm4KJmZWWm4KHXOhEYPoKQ8L615TlrznFS3Ss+Lb3QwM7PS8ErJzMxKw0XJzMxKw0WpgqT1JN0oaWEOBzy4xn6S9P0cUDgnv1ehfQdJkyW9nv+7Q7V+mkEd52SCpCclLZM0qttOoIt4XlrznLRWjzmRtLmkm5VCS1+RdJukLbr3TLqHi1Jr55GylDYADgF+LmnrKvsdBewHbE8KH/w0cDSApN7AzaQHwK4LXALcnLc3o07PSfYocCzwty4dbffxvLTmOWmtHnMygBT9s0Xu5yHSnzE9T0T4lV+kqPZFwOaFbZcB46vs+xfgqMLnI0jZUQCfJMVkqND+HLBno8+xUXNSsd+9wKhGn5vnxXPSjHOS29Yjxf6s3+hzrPfLK6XlbQ4siYh/FbY9ClT7W83Wua3aflsDj0X+X0/2WI1+yq5ec9LTeF5a85y01lVz8jFgZkTMqcsoS8RFaXn9gHkV2+aScp+q7Tu3Yr9++RpwZVtb/ZRdveakp/G8tOY5aa3ucyJpY9IlwePqOM7ScFFa3gJgnYpt6wDz27HvOsCCvDrqSD9lV6856Wk8L615Tlqr65xIGgTcDpwfEVfVeayl4KK0vH8Ba0jarLBte2BKlX2n5LZq+00Btqv4G852Nfopu3rNSU/jeWnNc9Ja3eZE0rqkgnRLRIzrgrGWQ6N/1CrbC7gauIr0A+VHSEvoravsdwzwD2AwsBHpfzzH5LbewDTgG0Af4Kv5c+9Gn1+j5qQwL2sC9wFH5verNfr8PC+ek7LPCWnV9BBwbqPPp8vnq9EDKNuLdFfLTaQo9eeAg/P2j5KW0i37CTgLeCW/zmL5u+0+AEwG3iDd1vqBRp9bCeZkIumOoeJrRKPPz/PiOSn7nACH5TlYSLrM1/Ia0ujzq/fLz74zM7PS8G9KZmZWGi5KZmZWGi5KZmZWGi5KZmZW
Gi5KZmZWGi5KZmZWGi5KZmZWGi5KZk1G0sGSJklaIOlFSbdKujN/XiBpkaTFhc+3ShomKQrbpkoa3ehzMau0RqMHYGbtJ+k4YDTpkTS3kbJ69gQ+FhF75H3GAJtGxBcKxw3LbwdExBJJw4F7JE2OiDu67wzM2uaiZNYkJPUHxgKHR8QNhabf5le7RcQkSVOAHQAXJSsNX74zax47kx5MemNnO5L0YWAb4KnO9mVWT14pmTWP9YHZEbGkE33MltSHVNzOJj0o1Kw0vFIyax5zgIGSOvOXyYGkhNNvASOAXnUYl1nduCiZNY/7gbeA/TrTSUQsjYgfAW8Cx9ZjYGb14qJk1iQiYi5wKnCepP0k9ZXUS9Jeks5aiS7HA9+WtGZ9R2q28lyUzJpIRJwNHAecDMwCppOSjVfmt6HfA6+Skl3NSsEhf2ZmVhpeKZmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWn8P6JaP70kYUMaAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } }, { "output_type": "stream", "text": [ "CPU times: user 45min 36s, sys: 2min 15s, total: 47min 51s\n", "Wall time: 29min 38s\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "NqDH9cq_b4vh" }, "source": [ "### A/B tests" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DsK98lKNb6RB", "outputId": "d4596fb8-1674-4c67-e0ec-13dbc52e67dd" }, "source": [ "%%time\n", "result_AB = verify_agents(env, n_test_users, deepcopy(agents))\n", "display(result_AB)\n", "plot_barchart(result_AB, 'A/B-test', 'CTR', 'tab:green', 'ABtest_eval.eps')" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Organic Users: 0it [00:00, ?it/s]\n", "Users: 100%|██████████| 5000/5000 [05:39<00:00, 14.71it/s]\n", "Organic Users: 0it [00:00, ?it/s]\n", "Users: 100%|██████████| 5000/5000 [05:35<00:00, 14.92it/s]\n", "Organic Users: 0it [00:00, ?it/s]\n", "Users: 100%|██████████| 5000/5000 [06:49<00:00, 12.20it/s]\n", "Organic Users: 0it [00:00, ?it/s]\n", "Users: 100%|██████████| 5000/5000 [06:39<00:00, 12.53it/s]\n", "Organic Users: 0it [00:00, ?it/s]\n", "Users: 83%|████████▎ | 4163/5000 [18:04<03:51, 3.62it/s]" ], "name": "stderr" } ] }, { "cell_type": "code", "metadata": { "id": "ieFcpaa-b8Tt" }, "source": [
"def combine_barchart(resultAB, resultCIPS, title, xlabel, figname = 'fig.eps', size = (6,2), fontsize = 12):\n",
"    \"\"\"Plot A/B-test and clipped-IPS estimates side by side, one pair of horizontal bars per agent.\n",
"\n",
"    Each bar is drawn at the '0.500' column, with asymmetric error bars reaching\n",
"    the '0.025' and '0.975' columns.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    resultAB, resultCIPS : pandas.DataFrame\n",
"        Result tables with columns 'Agent', '0.025', '0.500' and '0.975'.\n",
"        NOTE(review): both are assumed to list the same agents in the same row\n",
"        order (not checked); the y-axis labels are taken from resultCIPS.\n",
"    title, xlabel : str\n",
"        Figure title and x-axis label.\n",
"    figname : str\n",
"        Path the figure is saved to (bbox_inches='tight') before it is shown.\n",
"    size : tuple\n",
"        Figure size passed to plt.subplots.\n",
"    fontsize : int\n",
"        Font size for the title, tick labels and x-axis label.\n",
"    \"\"\"\n",
"    n_agents = len(resultAB)\n",
"    height = .25\n",
"    fig, ax = plt.subplots(figsize = size)\n",
"    ax.set_title(title, size = fontsize)\n",
"\n",
"    # Plot the tables that were passed in, not any notebook-level globals.\n",
"    series = [('A/B-test', 'tab:green', resultAB), ('CIPS', 'tab:blue', resultCIPS)]\n",
"    for i, (name, colour, result) in enumerate(series):\n",
"        median = result['0.500']\n",
"        # asymmetric error bars from the '0.025' column up to the '0.975' column\n",
"        lower = median - result['0.025']\n",
"        upper = result['0.975'] - median\n",
"        # align='edge': bar i of agent a covers [a + i*height, a + (i+1)*height]\n",
"        ax.barh([a + i * height for a in range(n_agents)],\n",
"                median,\n",
"                height = height,\n",
"                xerr = (lower, upper),\n",
"                align = 'edge',\n",
"                label = name,\n",
"                color = colour)\n",
"\n",
"    # one label per agent, placed at the boundary between its two bars\n",
"    ax.set_yticks([a + height for a in range(n_agents)])\n",
"    ax.set_yticklabels(list(resultCIPS['Agent']), size = fontsize)\n",
"    ax.tick_params(axis = 'x', labelsize = fontsize)\n",
"    ax.set_xlabel(xlabel, size = fontsize)\n",
"    ax.legend(loc = 'lower right')\n",
"    ax.set_xlim(.0, None)\n",
"    ax.xaxis.set_major_formatter(FormatStrFormatter('%.3f'))\n",
"    fig.savefig(figname, bbox_inches = 'tight')\n",
"    plt.show()\n",
"\n",
"combine_barchart(result_AB, result_CIPS, 'Evaluate on Bandit Feedback', 'CTR', 'ABtest_CIPS.eps')"
], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Q4NB72lgb-Ek" }, "source": [ "plot_barchart(result_LOO, 'Evaluate on Organic Feedback', 'HR@1', 'tab:red', 'traditional_eval.eps')" ], "execution_count": null, "outputs": [] } ] }