{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "2021-06-17-recostep-tutorial-offline-replayer-eval-recogym.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMo+rs0MFOnygJoYfsqfz0h",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B7BGtRmWYACb"
},
"source": [
"## Environment setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "lE0TYT5zVVjy"
},
"source": [
"!pip install -q recogym"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "uSUfDrbZywvc"
},
"source": [
"# Offline Replayer Evaluation I - Recogym\n",
"> Running recogym for offline simulation and evaluation\n",
"\n",
"- toc: true\n",
"- badges: true\n",
"- comments: true\n",
"- categories: [bandit]\n",
"- image: "
]
},
{
"cell_type": "code",
"metadata": {
"id": "cUH6Vc1OU89n"
},
"source": [
"import numpy as np\n",
"from numpy.random.mtrand import RandomState\n",
"from scipy.special import logsumexp\n",
"import scipy\n",
"import pandas as pd\n",
"from scipy.stats.distributions import beta\n",
"from copy import deepcopy\n",
"\n",
"from scipy.sparse import csr_matrix\n",
"from scipy.sparse.linalg import svds\n",
"\n",
"from itertools import chain\n",
"from sklearn.neighbors import NearestNeighbors\n",
"from IPython.display import display, HTML\n",
"\n",
"from matplotlib.ticker import FormatStrFormatter\n",
"\n",
"import gym, recogym\n",
"from recogym import env_1_args, Configuration\n",
"from recogym.agents import OrganicUserEventCounterAgent, organic_user_count_args\n",
"from recogym.agents.organic_count import OrganicCount, organic_count_args, to_categorical\n",
"from recogym import Configuration\n",
"from recogym.agents import Agent\n",
"from recogym.envs.observation import Observation\n",
"from recogym.agents import RandomAgent, random_args\n",
"from recogym import verify_agents, verify_agents_IPS\n",
"from recogym.evaluate_agent import plot_verify_agents, verify_agents_recall_at_k\n",
"\n",
"from recogym.envs.session import OrganicSessions\n",
"from recogym.envs.context import DefaultContext\n",
"from recogym.envs.observation import Observation\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"P = 2000 # Number of Products\n",
"U = 2000 # Number of Users"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vk3eavl0VN_n"
},
"source": [
"# You can overwrite environment arguments here\n",
"env_1_args['random_seed'] = 42\n",
"env_1_args['num_products']= P\n",
"env_1_args['phi_var']=0.0\n",
"env_1_args['number_of_flips']=P//2\n",
"env_1_args['sigma_mu_organic'] = 0.1\n",
"env_1_args['sigma_omega']=0.05"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "8ShuzevnVI7b"
},
"source": [
"# Initialize the gym for the first time by calling .make() and .init_gym()\n",
"env = gym.make('reco-gym-v1')\n",
"env.init_gym(env_1_args)"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "p3ZfkuZcVMZT"
},
"source": [
"# env.reset()"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 703
},
"id": "zdnvCn74VnMn",
"outputId": "b57ab8a8-9fc6-4982-8995-3a41398bfc95"
},
"source": [
"# Generate RecSys logs for U users\n",
"reco_log = env.generate_logs(U)\n",
"reco_log.head(20)"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Organic Users: 0it [00:00, ?it/s]\n",
"Users: 100%|██████████| 2000/2000 [02:15<00:00, 14.73it/s]\n"
],
"name": "stderr"
},
{
"output_type": "execute_result",
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" t | \n",
" u | \n",
" z | \n",
" v | \n",
" a | \n",
" c | \n",
" ps | \n",
" ps-a | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.0 | \n",
" 0 | \n",
" organic | \n",
" 116 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 1 | \n",
" 1.0 | \n",
" 0 | \n",
" bandit | \n",
" <NA> | \n",
" 1123 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 2 | \n",
" 2.0 | \n",
" 0 | \n",
" bandit | \n",
" <NA> | \n",
" 1332 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 3 | \n",
" 3.0 | \n",
" 0 | \n",
" bandit | \n",
" <NA> | \n",
" 805 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 4 | \n",
" 4.0 | \n",
" 0 | \n",
" bandit | \n",
" <NA> | \n",
" 1184 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 5 | \n",
" 0.0 | \n",
" 1 | \n",
" organic | \n",
" 1205 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 6 | \n",
" 1.0 | \n",
" 1 | \n",
" organic | \n",
" 1137 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 7 | \n",
" 2.0 | \n",
" 1 | \n",
" organic | \n",
" 1337 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 8 | \n",
" 3.0 | \n",
" 1 | \n",
" organic | \n",
" 972 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 9 | \n",
" 4.0 | \n",
" 1 | \n",
" organic | \n",
" 841 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 10 | \n",
" 5.0 | \n",
" 1 | \n",
" organic | \n",
" 140 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 11 | \n",
" 6.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 1367 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 12 | \n",
" 7.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 1086 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 13 | \n",
" 8.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 1698 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 14 | \n",
" 9.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 1513 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 15 | \n",
" 10.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 134 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 16 | \n",
" 11.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 1056 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 17 | \n",
" 12.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 1751 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 18 | \n",
" 13.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 725 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
" 19 | \n",
" 14.0 | \n",
" 1 | \n",
" bandit | \n",
" <NA> | \n",
" 612 | \n",
" 0.0 | \n",
" 0.0005 | \n",
" () | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" t u z v a c ps ps-a\n",
"0 0.0 0 organic 116 NaN NaN None\n",
"1 1.0 0 bandit 1123 0.0 0.0005 ()\n",
"2 2.0 0 bandit 1332 0.0 0.0005 ()\n",
"3 3.0 0 bandit 805 0.0 0.0005 ()\n",
"4 4.0 0 bandit 1184 0.0 0.0005 ()\n",
"5 0.0 1 organic 1205 NaN NaN None\n",
"6 1.0 1 organic 1137 NaN NaN None\n",
"7 2.0 1 organic 1337 NaN NaN None\n",
"8 3.0 1 organic 972 NaN NaN None\n",
"9 4.0 1 organic 841 NaN NaN None\n",
"10 5.0 1 organic 140 NaN NaN None\n",
"11 6.0 1 bandit 1367 0.0 0.0005 ()\n",
"12 7.0 1 bandit 1086 0.0 0.0005 ()\n",
"13 8.0 1 bandit 1698 0.0 0.0005 ()\n",
"14 9.0 1 bandit 1513 0.0 0.0005 ()\n",
"15 10.0 1 bandit 134 0.0 0.0005 ()\n",
"16 11.0 1 bandit 1056 0.0 0.0005 ()\n",
"17 12.0 1 bandit 1751 0.0 0.0005 ()\n",
"18 13.0 1 bandit 725 0.0 0.0005 ()\n",
"19 14.0 1 bandit 612 0.0 0.0005 ()"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xbB6tEBPVuDZ",
"outputId": "22d85507-4855-4552-902b-9d8b5afe4e0e"
},
"source": [
"n_events = reco_log.shape[0]\n",
"n_organic = reco_log.loc[reco_log['z'] == 'organic'].shape[0]\n",
"print('Training on {0} organic and {1} bandit events'.format(n_organic, n_events - n_organic))"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"Training on 43753 organic and 159967 bandit events\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gX5Gve7eYG_r"
},
"source": [
"## Defining evaluation methods"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "485rufJ-X8Ju"
},
"source": [
"### Traditional evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "44GbKe76WZAV"
},
"source": [
"def leave_one_out(reco_log, agent, last = False, N = 1, folds = 10):\n",
" # 1. Extract all organic events\n",
" reco_log = reco_log.loc[reco_log['z'] == 'organic']\n",
" \n",
" # 2. For every user sequence - randomly sample out an item\n",
" hits = []\n",
" for _ in range(folds):\n",
" user_id = 0\n",
" history = []\n",
" session = OrganicSessions()\n",
" agent.reset()\n",
" for row in reco_log.itertuples():\n",
" # If we have a new user\n",
" if row.u != user_id:\n",
" if last:\n",
" # Sample out last item\n",
" index = len(history) - 1\n",
" else:\n",
" # Sample out a random item from the history\n",
" index = np.random.choice(len(history),\n",
" replace = False)\n",
" test = history[index]\n",
" train = history[:index] + history[index + 1:]\n",
"\n",
" # 3. Recreate the user sequence without these items - Let the agent observe the incomplete sequence\n",
" for t, v in list(train):\n",
" session.next(DefaultContext(t, user_id), int(v))\n",
"\n",
" # 4. Generate a top-N set of recommendations by letting the agent act\n",
" # TODO - For now only works for N = 1\n",
" try:\n",
" prob_a = agent.act(Observation(DefaultContext(t + 1, user_id), session), 0, False)['ps-a']\n",
" except:\n",
" prob_a = [1 / P] * P\n",
"\n",
" # 5. Compute metrics checking whether the sampled test item is in the top-N\n",
" try:\n",
" hits.append(np.argmax(prob_a) == int(test[1]))\n",
" except:\n",
" hits.append(0)\n",
"\n",
" # Reset variables\n",
" user_id = row.u\n",
" history = []\n",
" session = OrganicSessions()\n",
" agent.reset()\n",
"\n",
" # Save the organic interaction to the running average for the session\n",
" history.append((row.t,row.v))\n",
" \n",
" # Error analysis\n",
" mean_hits = np.mean(hits)\n",
" serr_hits = np.std(hits) / np.sqrt(len(hits))\n",
" low_bound = mean_hits - 1.96 * serr_hits\n",
" upp_bound = mean_hits + 1.96 * serr_hits\n",
" \n",
" return mean_hits, low_bound, upp_bound\n",
"\n",
"def verify_agents_traditional(reco_log, agents, last = False, N = 1, folds = 10):\n",
" # Placeholder DataFrame for result\n",
" stat = {\n",
" 'Agent': [],\n",
" '0.025': [],\n",
" '0.500' : [],\n",
" '0.975': [],\n",
" }\n",
"\n",
" # For every agent\n",
" for agent_id in agents:\n",
" # Compute HR@k\n",
" mean, low, upp = leave_one_out(reco_log, agents[agent_id], last = last, N = N, folds = folds)\n",
" stat['Agent'].append(agent_id)\n",
" stat['0.025'].append(low)\n",
" stat['0.500'].append(mean)\n",
" stat['0.975'].append(upp)\n",
" return pd.DataFrame().from_dict(stat)"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ka70n5JcYPYE"
},
"source": [
"### Counterfactual evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_HoADS4uX54y"
},
"source": [
"def compute_ips_weights(agent, reco_log):\n",
" # Placeholder for return values\n",
" rewards = [] # Labels for actions\n",
" t_props = [] # Treatment propensities\n",
" l_props = [] # Logging propensities\n",
" \n",
" # For every logged interaction\n",
" user_id = 0\n",
" session = OrganicSessions()\n",
" agent.reset()\n",
" for row in reco_log.itertuples():\n",
" # If we have a new user\n",
" if row.u != user_id:\n",
" # Reset\n",
" session = OrganicSessions()\n",
" agent.reset()\n",
" user_id = row.u\n",
" \n",
" # If we have an organic event\n",
" if row.z == 'organic':\n",
" session.next(DefaultContext(row.t, row.u), int(row.v)) \n",
" \n",
" else:\n",
" prob_a = agent.act(Observation(DefaultContext(row.t, row.u), session), 0, False)['ps-a']\n",
" rewards.append(row.c)\n",
" try:\n",
" t_props.append(prob_a[int(row.a)])\n",
" except:\n",
" t_props.append(0)\n",
" l_props.append(row.ps)\n",
" session = OrganicSessions()\n",
" \n",
" return np.asarray(rewards), np.asarray(t_props), np.asarray(l_props)\n",
"\n",
"def verify_agents_counterfactual(reco_log, agents, cap = 3):\n",
" # Placeholder DataFrame for results\n",
" IPS_stat = {\n",
" 'Agent': [],\n",
" '0.025': [],\n",
" '0.500' : [],\n",
" '0.975': [],\n",
" }\n",
" CIPS_stat = {\n",
" 'Agent': [],\n",
" '0.025': [],\n",
" '0.500' : [],\n",
" '0.975': [],\n",
" }\n",
" SNIPS_stat = {\n",
" 'Agent': [],\n",
" '0.025': [],\n",
" '0.500' : [],\n",
" '0.975': [],\n",
" }\n",
"\n",
" # For every agent\n",
" for agent_id in agents:\n",
" # Get the rewards and propensities\n",
" rewards, t_props, l_props = compute_ips_weights(agents[agent_id], reco_log)\n",
" \n",
" # Compute the sample weights - propensity ratios\n",
" p_ratio = t_props / l_props\n",
"\n",
" # Effective sample size for E_t estimate (from A. Owen)\n",
" n_e = len(rewards) * (np.mean(p_ratio) ** 2) / (p_ratio ** 2).mean()\n",
" n_e = 0 if np.isnan(n_e) else n_e\n",
" print(\"Effective sample size for agent {} is {}\".format(str(agent_id), n_e))\n",
" \n",
" # Critical value from t-distribution as we have unknown variance\n",
" alpha = .00125\n",
" cv = scipy.stats.t.ppf(1 - alpha, df = int(n_e) - 1)\n",
" \n",
" ###############\n",
" # VANILLA IPS #\n",
" ###############\n",
" # Expected reward for pi_t\n",
" E_t = np.mean(rewards * p_ratio)\n",
"\n",
" # Variance of the estimate\n",
" var = ((rewards * p_ratio - E_t) ** 2).mean()\n",
" stddev = np.sqrt(var)\n",
" \n",
" # C.I. assuming unknown variance - use t-distribution and effective sample size\n",
" min_bound = E_t - cv * stddev / np.sqrt(int(n_e))\n",
" max_bound = E_t + cv * stddev / np.sqrt(int(n_e))\n",
" \n",
" # Store result\n",
" IPS_stat['Agent'].append(agent_id)\n",
" IPS_stat['0.025'].append(min_bound)\n",
" IPS_stat['0.500'].append(E_t)\n",
" IPS_stat['0.975'].append(max_bound)\n",
" \n",
" ############## \n",
" # CAPPED IPS #\n",
" ##############\n",
" # Cap ratios\n",
" p_ratio_capped = np.clip(p_ratio, a_min = None, a_max = cap)\n",
" \n",
" # Expected reward for pi_t\n",
" E_t_capped = np.mean(rewards * p_ratio_capped)\n",
"\n",
" # Variance of the estimate\n",
" var_capped = ((rewards * p_ratio_capped - E_t_capped) ** 2).mean()\n",
" stddev_capped = np.sqrt(var_capped) \n",
" \n",
" # C.I. assuming unknown variance - use t-distribution and effective sample size\n",
" min_bound_capped = E_t_capped - cv * stddev_capped / np.sqrt(int(n_e))\n",
" max_bound_capped = E_t_capped + cv * stddev_capped / np.sqrt(int(n_e))\n",
" \n",
" # Store result\n",
" CIPS_stat['Agent'].append(agent_id)\n",
" CIPS_stat['0.025'].append(min_bound_capped)\n",
" CIPS_stat['0.500'].append(E_t_capped)\n",
" CIPS_stat['0.975'].append(max_bound_capped)\n",
" \n",
" ##############\n",
" # NORMED IPS #\n",
" ##############\n",
" # Expected reward for pi_t\n",
" E_t_normed = np.sum(rewards * p_ratio) / np.sum(p_ratio)\n",
"\n",
" # Variance of the estimate\n",
" var_normed = np.sum(((rewards - E_t_normed) ** 2) * (p_ratio ** 2)) / (p_ratio.sum() ** 2) \n",
" stddev_normed = np.sqrt(var_normed)\n",
"\n",
" # C.I. assuming unknown variance - use t-distribution and effective sample size\n",
" min_bound_normed = E_t_normed - cv * stddev_normed / np.sqrt(int(n_e))\n",
" max_bound_normed = E_t_normed + cv * stddev_normed / np.sqrt(int(n_e))\n",
"\n",
" # Store result\n",
" SNIPS_stat['Agent'].append(agent_id)\n",
" SNIPS_stat['0.025'].append(min_bound_normed)\n",
" SNIPS_stat['0.500'].append(E_t_normed)\n",
" SNIPS_stat['0.975'].append(max_bound_normed)\n",
" \n",
" return pd.DataFrame().from_dict(IPS_stat), pd.DataFrame().from_dict(CIPS_stat), pd.DataFrame().from_dict(SNIPS_stat)"
],
"execution_count": 31,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "bvr-NmAdYUSg"
},
"source": [
"## Creating agents"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MTAIbCveYWTy"
},
"source": [
"### SVD agent"
]
},
{
"cell_type": "code",
"metadata": {
"id": "6xcs7TqWYSbI"
},
"source": [
"class SVDAgent(Agent):\n",
" def __init__(self, config, U = U, P = P, K = 5):\n",
" super(SVDAgent, self).__init__(config)\n",
" self.rng = RandomState(self.config.random_seed)\n",
" assert(P >= K)\n",
" self.K = K\n",
" self.R = csr_matrix((U,P))\n",
" self.V = np.zeros((P,K))\n",
" self.user_history = np.zeros(P)\n",
" \n",
" def train(self, reco_log, U = U, P = P):\n",
" # Extract all organic user logs\n",
" reco_log = reco_log.loc[reco_log['z'] == 'organic']\n",
" \n",
" # Generate ratings matrix for training, row-based for efficient row (user) retrieval\n",
" self.R = csr_matrix((np.ones(len(reco_log)),\n",
" (reco_log['u'],reco_log['v'])),\n",
" (U,P))\n",
"\n",
" # Singular Value Decomposition\n",
" _, _, self.V = svds(self.R, k = self.K)\n",
" \n",
" def observe(self, observation):\n",
" for session in observation.sessions():\n",
" self.user_history[session['v']] += 1\n",
"\n",
" def act(self, observation, reward, done):\n",
" \"\"\"Act method returns an Action based on current observation and past history\"\"\"\n",
" self.observe(observation)\n",
" scores = self.user_history.dot(self.V.T).dot(self.V)\n",
" action = np.argmax(scores)\n",
" prob = np.zeros_like(scores)\n",
" prob[action] = 1.0\n",
"\n",
" return {\n",
" **super().act(observation, reward, done),\n",
" **{\n",
" 'a': action,\n",
" 'ps': prob[action],\n",
" 'ps-a': prob,\n",
" },\n",
" }\n",
"\n",
" def reset(self):\n",
" self.user_history = np.zeros(P)"
],
"execution_count": 13,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "CwOGMwwyYjTQ"
},
"source": [
"### Item-KNN agent"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YHIXb-KHYejQ"
},
"source": [
"class itemkNNAgent(Agent):\n",
" def __init__(self, config, U = U, P = P, k = 5, greedy = False, alpha = 1):\n",
" super(itemkNNAgent, self).__init__(config)\n",
" self.rng = RandomState(self.config.random_seed)\n",
" self.k = min(P,k)\n",
" self.greedy = greedy\n",
" self.alpha = alpha\n",
" self.Rt = csr_matrix((P,U))\n",
" self.user_history = np.zeros(P)\n",
" self.S = np.eye(P)\n",
" \n",
" def train(self, reco_log, U = U, P = P):\n",
" # Extract all organic user logs\n",
" reco_log = reco_log.loc[reco_log['z'] == 'organic']\n",
" \n",
" # Generate ratings matrix for training, row-based for efficient row (user) retrieval\n",
" self.R_t = csr_matrix((np.ones(len(reco_log)),\n",
" (reco_log['v'],reco_log['u'])),\n",
" (P,U))\n",
"\n",
" # Set up nearest neighbours module\n",
" nn = NearestNeighbors(n_neighbors = self.k,\n",
" metric = 'cosine')\n",
"\n",
" # Initialise placeholder for distances and indices\n",
" distances = []\n",
" indices = []\n",
"\n",
" # Dirty fix for multiprocessing backend being unable to pickle large objects\n",
" nn.fit(self.R_t)\n",
" distances, indices = nn.kneighbors(self.R_t, return_distance = True)\n",
"\n",
" # Precompute similarity matrix S\n",
" data = list(chain.from_iterable(1.0 - distances))\n",
" rows = list(chain.from_iterable([i] * self.k for i in range(P)))\n",
" cols = list(chain.from_iterable(indices))\n",
" \n",
" # (P,P)-matrix with cosine similarities between items\n",
" self.S = csr_matrix((data,(rows, cols)), (P,P)).todense()\n",
" \n",
" def observe(self, observation):\n",
" for session in observation.sessions():\n",
" self.user_history[session['v']] += 1\n",
"\n",
" def act(self, observation, reward, done):\n",
" \"\"\"Act method returns an Action based on current observation and past history\"\"\"\n",
" self.observe(observation)\n",
" scores = self.user_history.dot(self.S).A1\n",
" \n",
" if self.greedy:\n",
" action = np.argmax(scores)\n",
" prob = np.zeros_like(scores)\n",
" prob[action] = 1.0\n",
" else:\n",
" scores **= self.alpha\n",
" prob = scores / np.sum(scores)\n",
" action = self.rng.choice(self.S.shape[0], p = prob)\n",
"\n",
" return {\n",
" **super().act(observation, reward, done),\n",
" **{\n",
" 'a': action,\n",
" 'ps': prob[action],\n",
" 'ps-a': prob,\n",
" },\n",
" }\n",
"\n",
" def reset(self):\n",
" self.user_history = np.zeros(P)"
],
"execution_count": 15,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "qK82zI6qYspo"
},
"source": [
"### User-KNN agent"
]
},
{
"cell_type": "code",
"metadata": {
"id": "uoRDZ8mNYpLo"
},
"source": [
"class userkNNAgent(Agent):\n",
" def __init__(self, config, U = U, P = P, k = 5, greedy = False, alpha = 1):\n",
" super(userkNNAgent, self).__init__(config)\n",
" self.rng = RandomState(self.config.random_seed)\n",
" self.k = min(P,k)\n",
" self.greedy = greedy\n",
" self.alpha = alpha\n",
" self.U = U\n",
" self.P = P\n",
" self.R = csr_matrix((U,P))\n",
" self.user_history = np.zeros(P)\n",
" self.nn = NearestNeighbors(n_neighbors = self.k, metric = 'cosine')\n",
" \n",
" def train(self, reco_log, U = U, P = P):\n",
" # Extract all organic user logs\n",
" reco_log = reco_log.loc[reco_log['z'] == 'organic']\n",
" \n",
" # Generate ratings matrix for training, row-based for efficient row (user) retrieval\n",
" self.R = csr_matrix((np.ones(len(reco_log)),\n",
" (reco_log['u'],reco_log['v'])),\n",
" (U,P))\n",
"\n",
" # Fit nearest neighbours\n",
" self.nn.fit(self.R)\n",
" \n",
" def observe(self, observation):\n",
" for session in observation.sessions():\n",
" self.user_history[session['v']] += 1\n",
"\n",
" def act(self, observation, reward, done):\n",
" \"\"\"Act method returns an Action based on current observation and past history\"\"\"\n",
" self.observe(observation)\n",
" \n",
" # Get neighbouring users based on user history\n",
" distances, indices = self.nn.kneighbors(self.user_history.reshape(1,-1))\n",
" scores = np.add.reduce([dist * self.R[idx,:] for dist, idx in zip(distances,indices)])\n",
" \n",
" if self.greedy:\n",
" action = np.argmax(scores)\n",
" prob = np.zeros_like(scores)\n",
" prob[action] = 1.0\n",
" else:\n",
" scores **= self.alpha\n",
" prob = scores / np.sum(scores)\n",
" action = self.rng.choice(self.P, p = prob)\n",
"\n",
" return {\n",
" **super().act(observation, reward, done),\n",
" **{\n",
" 'a': action,\n",
" 'ps': prob[action],\n",
" 'ps-a': prob,\n",
" },\n",
" }\n",
"\n",
" def reset(self):\n",
" self.user_history = np.zeros(P)"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "CcVL6ih6Y0p8"
},
"source": [
"### Agent initializations"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wYX3_5fYYumd"
},
"source": [
"# SVD Agent\n",
"SVD_agent = SVDAgent(Configuration(env_1_args), U, P, 30)\n",
"SVD_agent.train(reco_log)\n",
"\n",
"# item-kNN Agent\n",
"itemkNN_agent = itemkNNAgent(Configuration(env_1_args), U, P, 500, greedy = True)\n",
"itemkNN_agent.train(reco_log)\n",
"\n",
"# user-kNN Agent\n",
"userkNN_agent = userkNNAgent(Configuration(env_1_args), U, P, 20, greedy = True)\n",
"userkNN_agent.train(reco_log)\n",
"\n",
"# Generalised Popularity agent\n",
"GPOP_agent = OrganicCount(Configuration({\n",
" **env_1_args,\n",
" 'select_randomly': True,\n",
"}))\n",
"\n",
"# Generalised Popularity agent\n",
"GPOP_agent_greedy = OrganicCount(Configuration({\n",
" **env_1_args,\n",
" 'select_randomly': False,\n",
"}))\n",
"\n",
"# Peronalised Popularity agent\n",
"PPOP_agent = OrganicUserEventCounterAgent(Configuration({\n",
" **organic_user_count_args,\n",
" **env_1_args,\n",
" 'select_randomly': True,\n",
"}))\n",
"\n",
"# Peronalised Popularity agent\n",
"PPOP_agent_greedy = OrganicUserEventCounterAgent(Configuration({\n",
" **organic_user_count_args,\n",
" **env_1_args,\n",
" 'select_randomly': False,\n",
"}))\n",
"\n",
"# Random Agent\n",
"random_args['num_products'] = P\n",
"RAND_agent = RandomAgent(Configuration({**env_1_args, **random_args,}))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "qRqbHqMJY9vL"
},
"source": [
"## Offline evaluation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E9r-zhlAZZ9M"
},
"source": [
"### Generating test logs"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pnzTANe1Y3lW",
"outputId": "bb01cede-b583-4960-c01c-c0fea14cbc28"
},
"source": [
"%%time\n",
"# Placeholder for agents\n",
"agents = {\n",
" ' Random': RAND_agent,\n",
" ' Popular': GPOP_agent_greedy,\n",
" ' User-pop': PPOP_agent,\n",
" ' SVD': SVD_agent,\n",
" ' User-kNN': userkNN_agent,\n",
" 'Item-kNN': itemkNN_agent,\n",
"}\n",
"agent_ids = sorted(list(agents.keys()))#['SVD','GPOP','PPOP','RAND']\n",
"# Generate new logs, to be used for offline testing\n",
"n_test_users = 5000 # U\n",
"test_log = env.generate_logs(n_test_users)\n",
"n_events = test_log.shape[0]\n",
"n_organic = test_log.loc[test_log['z'] == 'organic'].shape[0]\n",
"print('Testing on {0} organic and {1} bandit events'.format(n_organic, n_events - n_organic))"
],
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"text": [
"Organic Users: 0it [00:00, ?it/s]\n",
"Users: 100%|██████████| 5000/5000 [05:38<00:00, 14.77it/s]\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Testing on 109631 organic and 403487 bandit events\n",
"CPU times: user 6min 4s, sys: 4min 47s, total: 10min 51s\n",
"Wall time: 5min 40s\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "igubir8vaB0m"
},
"source": [
"### (Util) helper function to plot barchart"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zCpbb9K-ZB-b"
},
"source": [
"def plot_barchart(result, title, xlabel, col = 'tab:red', figname = 'fig.eps', size = (6,2), fontsize = 12):\n",
" fig, axes = plt.subplots(figsize = size)\n",
" plt.title(title, size = fontsize)\n",
" n_agents = len(result)\n",
" yticks = np.arange(n_agents)\n",
" mean = result['0.500']\n",
" lower = result['0.500'] - result['0.025']\n",
" upper = result['0.975'] - result['0.500']\n",
" plt.barh(yticks,\n",
" mean,\n",
" height = .25,\n",
" xerr = (lower, upper),\n",
" align = 'center',\n",
" color = col,)\n",
" plt.yticks(yticks, result['Agent'], size = fontsize)\n",
" plt.xticks(size = fontsize)\n",
" plt.xlabel(xlabel, size = fontsize)\n",
" plt.xlim(.0,None)\n",
" plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%.2f'))\n",
" plt.savefig(figname, bbox_inches = 'tight')\n",
" plt.show()"
],
"execution_count": 20,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "w9r6FSeDZdyM"
},
"source": [
"### Leave-one-out evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 479
},
"id": "zJSfQ5RDZpab",
"outputId": "57ae38fc-2143-4be7-9c89-8fef3257f9a1"
},
"source": [
"%%time\n",
"result_LOO = verify_agents_traditional(test_log, deepcopy(agents))\n",
"display(result_LOO)\n",
"plot_barchart(result_LOO, 'Evaluate on Organic Feedback', 'HR@1', 'tab:red', 'traditional_eval.eps')"
],
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/recogym/agents/organic_user_count.py:51: RuntimeWarning: invalid value encountered in true_divide\n",
" action_proba = features / np.sum(features)\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Agent | \n",
" 0.025 | \n",
" 0.500 | \n",
" 0.975 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Random | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 1 | \n",
" Popular | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 2 | \n",
" User-pop | \n",
" 0.000000 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
" 3 | \n",
" SVD | \n",
" 0.074560 | \n",
" 0.076895 | \n",
" 0.079231 | \n",
"
\n",
" \n",
" 4 | \n",
" User-kNN | \n",
" 0.079927 | \n",
" 0.082336 | \n",
" 0.084746 | \n",
"
\n",
" \n",
" 5 | \n",
" Item-kNN | \n",
" 0.074461 | \n",
" 0.076795 | \n",
" 0.079130 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Agent 0.025 0.500 0.975\n",
"0 Random 0.000000 0.000000 0.000000\n",
"1 Popular 0.000000 0.000000 0.000000\n",
"2 User-pop 0.000000 0.000000 0.000000\n",
"3 SVD 0.074560 0.076895 0.079231\n",
"4 User-kNN 0.079927 0.082336 0.084746\n",
"5 Item-kNN 0.074461 0.076795 0.079130"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaUAAACwCAYAAACvvnEyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAeBUlEQVR4nO3debgcVZnH8e+PEAhLFiABTCAJi0AIQnBAFBGiLALig7IpOGAYVh0UZ5QBBSQgUVwQVGQQRHYCYVgGFwQdDAoiGFwCCChgICEEshCyECDLO3+c01i3033XvrfrXn+f56kn3XVOVb1V96bfrlN161VEYGZmVgZrNDsAMzOzCiclMzMrDSclMzMrDSclMzMrDSclMzMrDSclMzMrDScl6xMkTZV0fLPj6OskXSbp7GbHUUvxd0DSBEn3d+c2rHs4KVmPkjRD0jJJSwrTJc2Oq0LSeEmzmh1HhaTtJd0p6VVJiyX9StLuzYonIk6OiK90dLk6P/fh3RGj9W5OStYMH46I9QvTKc0OqIwkbQU8ADwKbAEMB24H7pH0njrLrNlzEXZY9c99drMDsvJxUrJSkLS2pIWSdijMG5a/XW8saQNJP5E0V9Ir+fVmddY1UdL1hfejJUXlA1vSsZKeyGcez0o6Kc9fD7gLGF78Ni9pDUlnSHpG0nxJUyRt2Mq+nCDpaUkL8lnO8EJbSDpZ0t/y/n5fkuqsaiLwYEScGRELImJxRHwXuA74etW+HSfpeeBeSf0kXShpnqS/SzqlPfuf28ZLmiXp85JelvSipGML7VdLOr/w/mBJf5K0KB+f/esdlzrHarCkK/N2XpB0vqR+hfZ/y7G+IuluSaMKbftKejKfRV4CVB9HSboktz8pae9CQ91j0N79kvQ2SdMlndaRfbbWOSlZKUTEG8BtwJGF2UcA90XEy6Tf1auAUcBIYBnQ2WG/l4GDgEHAscBFkt4ZEUuBA4DZVd/mPwN8BNiLdLbyCvD9WiuW9AHgazn2twHPATdVdTsI2BXYMff7YJ049wVuqTF/CvBeSesU5u0FjMnrOiHvxzjgnTn2Nve/0L4pMBgYARwHfF/SBjX29V3AtcBpwBBgT2BGnX2p52pgBbA1sDOwH1C5LnQw8CXgEGAY8Btgcm4bSvp9OQsYCjwDvLdq3bvl+UOBc4DbCl8m6h6D9uyXpC2A+4BLIuKbHdxna01EePLUYxPpP/cSYGFhOiG37QM8U+j7AHBMnfWMA14pvJ8KHJ9fTwSuL7SNBgJYs8667gBOza/HA7Oq2p8A9i68fxuwvNb6gCuBbxTer5/7js7vA9ij0D4FOKNOXCuA/WvM3y6vZ0Rh37YstN8LnFR4v08H939ZsS/pA/zd+fXVwPn59Q+Aizr5c78D2AR4A1in0O9I4Ff59V3AcYW2NYDXSF9MjgF+V2gTMKvwOzABmA2o0Odh4Oh2HIO6+5V/z76d9+fIZv9/6otTmcefre/6SET8ssb8XwHrStoNeImUeG4HkLQucBGwP1D51j5QUr+IWNmRjUs6gPTNeRvSB926pOs29YwCbpe0qjBvJelD9YWqvsOBP1TeRMQSSfNJCWRGnj2n0P81UuKqZR4pAVZ7G7CKdMa2cZ43syqG4vvi6/bs//yIWNGOGDcHflYn9lpa/NzzGUl/4MXCCOYahXhHAd+RdGExfNKxbLGPERGSWuwn8ELkTJI9l5dr6xi0tV+fAJ4G/qe1nbXO8fCdlUZOLlNI35aPBH4SEYtz8+eBbYHdImIQaUgFVr+OALCU9CFTsWnlhaS1gVuBbwGbRMQQ0gdQZT21Hps/EzggIoYUpgERUZ2QIH07L173WA/YiNWTV3v8Eji8xvwjSNeaXivMK8b9IlC83rZ5IZ629r8jZgJbdWK54vJvAEMLx3VQRIwttJ9UddzXiYjfkvaxuF8qvs9GVF2vGwnMbscxaGu/JpK+MNxYvP5ljeGkZGVzI/Ax0rfRGwvzB5KGlRbm6wLntLKOPwF7ShopaTDwxULbWsDawFxgRf7GvF+h/SVgo7xcxWXApMpFdqUbMA6us+3JwLGSxuUPv68CD0XEjNZ2uo5zgd0lTZK0oaSBkj5DGro6vZXlpgCnShohaUhV37b2vyOuJO3r3ko3g4yQtF17F46IF4F7gAslDcrr2ErSXrnLZcAXJY2Ft26KqCTpnwJjJR2Sb+D4LIUvH9nGwGcl9c/LjSEln7aOQVv7tZz0ZWE94FpJ/hxtIB9Ma4Yfq+Xfq9xeaYiIh0hnOsNJ1xQqLgbWIX1D/R3w83orj4hfADcD04FHgJ8U2haTPsCmkIa/jgLuLLQ/SUoszyrdHTcc+E7uc4+kxXn7u9XZ9i+Bs0nfxF8kfeP+eDuOSa11/Q3YA9iJNPT3InAo8MGIeKCVRa8gfdhPB/5I+iBeAaxsa/87GN/D5JsEgFdJF/5HtbrQ6o4hJYm/5Hj+hzxkGRG3k+4yvEnSIuAx0g0cRMQ8UmK4AJgPvJ10DbLooTx/HjAJOCwi5rfjd6DN/YqIN0k3YGwC/MiJqXHUcsjVzPqafCZwWUR0NGGY9Thnd7M+RtI6kg6UtKakEaShztvbWs6sDHymZNbH5DsV7yPdOr6MdP3l1IhY1NTAzNrBScnMzErDw3dmZlYaTkpmZlYafqJDFwwdOjRGjx7d7DDMzHqVRx55ZF5EDKvV5qTUBaNHj2batGnNDsPMrFeR9Fy9tl49fKeqkgRmZta7NSQpKVWV3EfdVIK4MwoJ62dV86+XNDG/Hp/7XFrV535JE3ouWjMzg15+ptROu6n18tFLgaMlje6ZcMzMrJ5GDnuNAb4J9Je0BFgREUPyQyknkZ5svDbpL8v/IyKWSRoPXA98F/gCqRzAp4A3Sc86Gwp8KyK+2p4AJB0KXEgq3rUkz/5G3v776yy2MMd0Dul5V+32+mOP88R2YzqyiJlZw4x58olmh9BwjUxKTwAnk4ps7VGYfwHpoZTjSE/XvRH4Mv94cvOmwABSjZQJpIdJ/gL4F9Kj5qdJmhwRf29t40olm88E9omIpwtnPpeSnhS8T50aPpCS1l8lXRART7V7j83MutEnn697PwAA644f3+Y6pk6d2phgeki3Dt/lWiYnks6MFuSn836Vlk9NXg5MiojlpLLRQ4HvRMTiiHic9PTgndrY1OdIpYvHR8TTVW3LSEnn/HoLR8Qc0mPyz2vHPp0oaZqkaQtWrmiru5mZdUB337U2jFRs7ZFCrS0BxcJY8wuVQ5flf18qtC8jV73Mw4IV2xdenwacFxGz6sTxQ+A0SR9uJdavA89IajUBRsTlwOUAOwxYx89oMrNuc83I1h/sPqaXnQW1R6OTUvWH9DxSUhlbp0pnx1Ye0aIkc2GIbj/g55LmRMStNZZ7U9K5wFeAx+use76ki3MfMzNrgkYnpZeAzSStFRFvRsQqSVcAF0k6JSJezo/S3yEi7m7gdh8H9gfulrQ8ImoVLbsOOCP3+1ud9XwbeJZ2loYesMNYxviPZ83MGqbR15TuJSWIOZLm5XmnA08Dv8vVI38JbNvg7RIRfybddXdFLmpW3b6SdIPFhq2sYxHpbr26fczMrPu4dEUX7LLLLuHHDJmZdYykRyJil1pt/wx/PGtmZr2Ek5KZmZWGk5KZmZWGk5KZmZWGk5KZmZWGk5KZmZWGk5KZmZWGk5KZmZVGn0hKuXrs1s2Ow8zMuqZR5dArpcfXrJp/taS6JSN6Qo7rUUlrFOadL+nq/LrNsulmZtYzet2ZkqR+bfdazXBa1nCqpa2y6WZm1s26u57SW/Lw2pX8owLt/0XEx3LbdsD3SNVm5wJnR8SU3HY1qfzFKGAv4GDSQ13rbWcPYDJwdERMzbO/AZwraUpE1KvM11bZ9NW4HLqZ9TXNLrHeY0mJVKfoHtKH/lrALgCS1iOVP/8ycADwDuAXkh6LiL/kZY8CDiQ9BXytehuQtD+pnPqhEfFwoek24AhSufUf1lm8PWXTzcx6rbbKq0PzS6z35PDdctLZzvCIeD0i7s/zDwJmRMRVEbEiIv4I3AocXlj2fyPigYhYFRGv11n/4cAPgAOqEhKk4oNnA2dLqpfU2iybDi6HbmbWnRp1plT5dO5feF15vzy//i/S2dLDkl4BLoyIH5ES1W6SFlbFdV3h/czKC0mP52UgJaDf5NefA66NiMdqBRgRP5M0Cziplf1os2y6y6GbWW/VVnl1aH6J9UYlpRdJyWc0UByQ3II0NEdEzAFOgLeu+/xS0q9JCee+iNi3lfW/9eEfEWPr9DkcuFLSrIj4Tp0+Z5KuN02uuZF2lE03M7Pu05CkFBErJd0KTJJ0ArAIOAzYHrgLQNLhwIMRMQt4hZRoVgE/AS6QdDRwU17lOGBJRHTkittsYG9gqqQ3I+K/a8Q5VdJjwCeBH9dZT3vKpgMuh25m1miNvKb0aWABMB14GTgF+FBEvJTbdwUekrQEuBM4NSKejYjFwH6kW7ZnA3OArwNrdzSAiHielJjOkHR8nW5n0XpJ9DbLppuZWfdwOfQucDl0M7OOczl0MzPrFZyUzMysNJyUzMysNJyUzMysNJyUzMysNJyUzMysNJyUzMysNJyUzMysNJyUzMysNPpEUpJ0nKQnJS2W9JKkn0kaKOmM/NDX6v5DJb0paQdJEyStlLQkT3+XdJWkbZqxL2Zm/8x6fVKStBfwVeDIiBgIjAFuzs3XA7tL2qJqsY8DjxbKXDwYEesDg4F9SLWVHpG0Q7fvgJmZvaUnK892l11JSeWPABGxALgmty2WdC9wNHBeYZljgGurV5QfxvoM8GlJI4GJpKed1+Ry6GbWbM0uX95ofSEpPQR8JddBugeYFhFvFNqvISWX8wAkbUsqjfGhNtZ7G/C1hkdrZtZOvaF8eaP1+uG7XHn2EOCdwE+B+ZK+Lalf7nI7sImk3fP7Y4C7ImJuG6ueTY3yFS6HbmbWffrCmRIRcRdwl6Q1gPcDtwBPAT+IiNck3QIcI+lB4BPA59ux2hGk+lDV23I5dDPrEb2hfHmj9YmkVBERq4D/y9eRijcpXAPcQRqSG0j9qrNFHwV+01oHV541M2usXp+UJB0MrAPcDSwk3fiwF/C5Qrff5LbLgZsi4s066+oHjAT+ExgPvKfbAjczs9X0+mtKwCvACcDfgEWk28C/GRE3VDpEKq97LTCKGnfdAe/JZdoXAVOBQcCuEfFo94ZuZmZFLofeBS6HbmbWcS6HbmZmvYKTkpmZlYaTkpmZlYaTkpmZlYaTkpmZlYaTkpmZlYaTkpmZlYaTkpmZlYaTkpmZlUZDk1IuLX5/jfkzJO3TyG2ZmVnf0yvPlCT1+gfJmpnZ6no8KUk6UNJfJC2W9IKkLxTaDpL0J0kLJf1W0o6FthmSTpc0HVhanZgkjZc0S9KXJM3L/T9RaB8s6VpJcyU9J+msXH+pcob3gKRLJL0q6UlJe/fA4TAzs4JmnCldCZwUEQNJNY/uBZC0M/Aj4CRgI+AHwJ2S1i4seySpjPmQiKhV9nVTYCipQN8ngctz+XOA7wGDgS1JpS2OAY4tLLsb8Exe/hzgNkmrVZ41M7Pu04yktBzYXtKgiHglIv6Q559IqhT7UESsjIhrgDeAdxeW/W5EzIyIZa2s/+yIeCMi7iOVRz8i10n6OPDFiFgcETOAC4GjC8u9DFwcEcsj4mZS5doPVa+8WA597ty2KqqbmVlHNDoprQD615jfn5SMAA4FDgSek3SfpEohvVHA5/PQ3UJJC4HNgeGF9cwEkDRS0pLKVGh/JSKWFt4/l5cfmmN4rqptROH9C9Gyjkdl2RYi4vKI2CUidhk2bFitY2BmZp3U6KT0PDBSkiozJK0LbExOCBHx+4g4OM+7A5iSu84EJkXEkMK0bkRMLqw/8jqej4j1K1OhfQNJ6xXejwRmA/NISXFUVdsLhfcjinEXljUzsx7S6KT0EPA6cIakATlBXABMI50ZrSXpE5IGR8RyUqXXVXnZK4CTJe2mZD1JH5I0sIMxnJu38z7gIOCWiFhJSn6TJA2UNIpU8vz6wnIbA5+V1F/S4cAY4GedOwxmZtYZDU1KEfEG6TrMeGAW8CxpCOyIwtDY0cAMSYuAk4FP5GWnkcqaX0Iqcf40MKGDIczJy84GbgBOjognc9tngKU5pvuBG0k3VlQ8BLyddFY1CTgsIuZ3cPtmZtYFfaYcuqTxwPURsVknlp0AHB8Re3RkOZdDNzPrOJdDNzOzXsFJyczMSqPPJKWImNqZobu87NUdHbozM7PG6zNJyczMej8nJTMzKw0nJTMzKw0nJTMzKw0nJTMzK41/mqQkKSRt3ew4zMysvmYU+RsvaVV+wvdiSU9JOrbtJc3MrK9r1pnS7Px070HA6cAVkrZvUiytcul1M7Oe09Thu0juID1EdXtJa0u6WNLsPF1cqTzbjnLnUyUdX3g/QdL9tbabnz7+R0mLJM2UNLHQNjoP9R0n6XlyZVwzM+t+TU1KktaQ9FFgCPAocCap0uw4YCfgXcBZhUVaK3feEUtJ5dCHkJ5q/ilJH6nqsxepfMUHO7F+MzPrhGYlpeG5suw84Bzg6Ih4ilTG4ryIeDki5gLn0rJkOdQod97RjedHEj0aEasiYjowmZSEiiZGxNLq0usuh25m1n2adb1kdp3n1A1n9ZLlxZLk9cqdd4ik3UjFB3cA1gLWBm6p6jaz1rIRcTlwOaTSFR3dtpmZ1Ve2W8Jns3rJ8mJJ8nrlziENya1baNu0le3cCNwJbB4Rg4HLAFX1ccIxM+thZUtKk4GzJA2TNBT4Mi1LlkONcud5/p+AQyStm/8e6bhWtjMQWBARr0t6F3BUY3fDzMw6o2y3O59Puk18en5/S55XUSx3/hoty51fBOwKvJSXvwHYp852Pg1cKOkS4D5gCummBzMza6JeUw69K+XOu4vLoZuZdZzLoZuZWa/gpGRmZqVRtmtKdUXEVKA0Q3dmZtZ4PlMyM7PScFIyM7PScFIyM7PScFIyM7PScFIyM7PS6FNJySXPzcx6t25NSrnQ3spc+nyRpD9LOqg7t2lmZr1XT5wpPZhLnw8BLgVukuTnzJmZ2Wp6bPguIlYB1wHrAW8HkLSVpHslzc8lzm8oJqxc8vwLkqZLelXSzZIGFNpPk/RiLp3+b8XtSRos6VpJcyU9J+ksSWvktgmSHpB0kaSFkp6VtHueP1PSy5I+2SMHxszM3tJjSUlSP+BYYDn/KOQn4GukQn1jgM2BiVWLHgHsD2wB7AhMyOvbH/gCsC8pyVU/Efx7wGBgS1JV2WPy9it2Iz1NfCNSfaWbSE8Z3xr4V+ASSet3eofNzKzDeiIpvTuXPn8d+BbwrxHxMkBEPB0Rv8jlzecC32b1suTfjYjZEbEA+DEwLs8/ArgqIh7L1WgnVhbICfDjwBcjYnFEzAAupGVp9b9HxFURsRK4mZQQz8ux3AO8SUpQLbgcuplZ9+mJpPS7iBgCbECq9vq+SoOkTSTdJOkFSYtIBf2GVi0/p/D6NaBy9jKcliXLi2XUhwL9Wb20+ojC+5cKr5cBRET1vNXOlCLi8ojYJSJ2GTZsWHWzmZl1QU9eU1oCfAo4WtLOefZXSWXH3xERg0jDZtVlyet5kXR2UzGy8HoeaZiwurT6C50I3czMekiP/p1SHoL7IanMOaSy5EuAVyWNAE7rwOqmABMkbS9pXeCcwnZW5vZJkgZKGgX8J6uXVjczsxJpxh/PXgwcKGlH4FzgncCrwE+B29q7koi4K6/rXuDp/G/RZ4ClwLPA/aSbGX7U1eDNzKz79Jpy6GXkcuhmZh3ncuhmZtYr+EypCyQtBp5qdhw1DCXd7FE2ZYyrjDGB4+qIMsYE5YyrLDGNioiaty/3mnLoJfVUvVPQZpI0zXG1TxljAsfVEWWMCcoZVxljqubhOzMzKw0nJTMzKw0npa65vNkB1OG42q+MMYHj6ogyxgTljKuMMbXgGx3MzKw0fKZkZmal4aRkZmal4aRURdKGkm6XtDQXBzyqTj9J+nouUDg/v1ahfZykRyS9lv8dV2s9TYjrcklPSVolaUKzY5K0jaT/zcUYF0i6W9K2JYhraC4EOT8XgnxQ0nubHVdVv2MkhaTjmx1TjmOppCV5+mFnY2pwXP0kna9UCHSxpD+qk5WvG/R79b7CMapMIenQzsTUqLhy+wck/UHSIqXCpyd2NqYuiQhPhQmYTKqvtD6wB+m5fGNr9DuJ9Iezm5FKYvwFODm3rUUqlfEfwNrAZ/P7tZoZV27/d2BvYBowoQTH6l3AccCGpHIjXwGeLEFcA4BtSV/cBHwEWACs2eyfYe6zAfAk8BhwfLNjIj3tf+sy/T/M7eeTnos5Kv8cdwAGNPvnV+g7HlgMrNfk3/f+ebmT8nHalfSw7J0a9TNt9/709AbLPJFKtb8JbFOYdx1wQY2+vwVOLLw/jlQ7CmA/UpkMFdqfB/ZvZlxV/e6nC0mpO2LKbRvmD7iNyhIXKTF9OMe1cRniAi4DPg1MpZNJqZEx0cCk1MD/hxvkD9atyhJTjb5XkYqVNvtYbZJ/husW2n8PHNmIn2lHJg/ftbQNsCIi/lqY92dgbI2+Y3NbrX5jgemRf7LZ9Drr6cm4Gqm7YtoTmBMR88sQl6TppKrJdwI/jFw1uZlxSXoXsAspMXVFo3+Gv5Y0R9JtkkaXIK53ACuAw3Jcf5X0702O6S2S1gMOA67pZEwNiytSgdPJwLF5yPM9pLPL+7sQW6c4KbW0PrCoat6rpLpPtfq+WtVv/TxGW93W2np6Mq5GanhMkjYDvk+qfVWKuCJiR2AQcBRd+w/akLgk9QMuBU6JiFVdiKdhMeX3ewGjge2A2cBPJHX2MWaNimszYDDpg3sLUgKYKGnfJsZUdAjpOXT3dSKe7ohrMqnW3RvAb4AzI2ImPcxJqaUlpA+gokGkMd+2+g4CluSzo46spyfjaqSGxiRpGHAPcGlETC5LXAAR8XqO6QxJOzU5rk+TzsJ/18k4uiMmIuLXEfFmRCwETiUlgTFNjmtZnndeRCyLiOnATcCBTYyp6JPAtV38v9mQuCRtRzo2x5CuiY8F/kvSh7oQW6c4KbX0V2BNSW8vzNsJeLxG38dzW61+jwM7Vn0z2rHOenoyrkZqWEySNiAlpDsjYlJZ4qqhP7Blk+PaG/hoHo6aA+wOXCjpkibGVEuQLph3RqPiml6IhRqvmxETAJI2J93kcG0n42l0XDsAf42IuyNiVUQ8RSq8ekAX4+u4nr6IVfaJ9G1hMukC4nupfyfLycATpLtYhpN+uNV3351KuvvuFLp+912X4yrENgB4ADghv16jicdqEPAwcEnJfobvJt3JtBawDnA66dvn8CbHNQTYtDD9ljTcObiJMY0FxgH9SENEF5Pu8upfgt/3XwM/yP8PxwAvA3s3M6bc50vAr0v0+74V6UzqA6QvE1uRKnqf2IgYO7Q/Pb3Bsk+ku7/uIJVSfx44Ks9/H+lUt9JPwDdItwkvyK+Ld9vtDDxCGkL4A7BzSeKaSvq2WJzGNysm0hBG5HUsKUwjm3msSNdI/kxKRAtI4/57luFnWLXOqXTtlvBGHKsPkJLQUtKH/h3A28twrEgfwD/Pv1PPAic1O6bc50nguK4co244VkeQ/sRgMTAL+Dqd/MLalcnPvjMzs9LwNSUzMysNJyUzMysNJyUzMysNJyUzMysNJyUzMysNJyUzMysNJyUzMysNJyUzMysNJyWzEpI0Q9I+VfMmSLq/0L4sVy6dI+lqSevXWM+2kq6U9Helyr6PSjpX0sCqfu+X9CtJr0qa0a07Z9YKJyWz3uvDEbE+6blzOwNfLDZKOgS4i/SYq/cCGwEHkR7r9JCkkYXuS4EfAaf1QNxmdXW23omZlUREzJF0Nyk5AekMifRssz0jYlah+3OkmkK/J1U93Tuv42Hg4eqzM7Oe5jMls14uF0c8gPRU54ozgLMjYpakI/Lw3XOSzpR0RUT8FFglaYemBG1Wh5OSWXndIWlhZSJVnK1uXwzMJD2d+5xC23jgNkkb5uUOI1Vg3YZUFwrgT6RKsWal4aRkVl4fiYghlYlUdba6fSApAW0HDC20KSLeALYGno2IR/L7mwt9Ngde6L7wzTrOScmsl4uI+4CrgW8VZq+StBZpSG9LSe+UtDapZk4/SR8DRgO/7+FwzVrlpGTWN1wM7CupUu76t6S78xaQzrBuJZUHn0W6E28/4OCIWAEgaQ1JA0hDe5I0ICc1sx7lu+/M+oCImCvpWuDLwKHABaRrSg9ExBRgSqWvpLNJw3urCqvYE/hV4f0yUrXd8d0du1mRK8+a9VGSjgLOIyWqu0hlrv8FmAhcHxE3NC86s9qclMz6MEnjgNOB9wHrAX8BLouI65oamFkdTkpmZlYavtHBzMxKw0nJzMxKw0nJzMxKw0nJzMxKw0nJzMxKw0nJzMxK4/8BERXuZmwAjHMAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"CPU times: user 6min 45s, sys: 43 s, total: 7min 28s\n",
"Wall time: 4min 43s\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KMEG7NLGZ5Kp"
},
"source": [
"### IPS Estimators"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RuQ5m3goeHnR"
},
"source": [
"# Generate new logs, to be used for offline testing\n",
"test_log_ppop = env.generate_logs(n_test_users, agent = deepcopy(PPOP_agent))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"id": "Gjj31s8ge6qZ",
"outputId": "27bb9173-54e6-4cd7-d68a-deaeb939857f"
},
"source": [
"test_log_ppop.head()"
],
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" t | \n",
" u | \n",
" z | \n",
" v | \n",
" a | \n",
" c | \n",
" ps | \n",
" ps-a | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.0 | \n",
" 0 | \n",
" organic | \n",
" 220 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 1 | \n",
" 1.0 | \n",
" 0 | \n",
" organic | \n",
" 867 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 2 | \n",
" 2.0 | \n",
" 0 | \n",
" organic | \n",
" 969 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 3 | \n",
" 3.0 | \n",
" 0 | \n",
" organic | \n",
" 730 | \n",
" <NA> | \n",
" NaN | \n",
" NaN | \n",
" None | \n",
"
\n",
" \n",
" 4 | \n",
" 4.0 | \n",
" 0 | \n",
" bandit | \n",
" <NA> | \n",
" 969 | \n",
" 0.0 | \n",
" 0.25 | \n",
" () | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" t u z v a c ps ps-a\n",
"0 0.0 0 organic 220 NaN NaN None\n",
"1 1.0 0 organic 867 NaN NaN None\n",
"2 2.0 0 organic 969 NaN NaN None\n",
"3 3.0 0 organic 730 NaN NaN None\n",
"4 4.0 0 bandit 969 0.0 0.25 ()"
]
},
"metadata": {
"tags": []
},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 975
},
"id": "SdKK38bSZ6Bl",
"outputId": "e15dbd9e-cf50-44ff-ae9d-5f8ae7ec26b7"
},
"source": [
"%%time\n",
"cap = 15\n",
"result_IPS, result_CIPS, result_SNIPS = verify_agents_counterfactual(test_log_ppop, deepcopy(agents), cap = cap)\n",
"display(result_IPS)\n",
"plot_barchart(result_IPS, 'IPS', 'CTR', 'tab:blue', 'bandit_eval_noclip.eps')\n",
"display(result_CIPS)\n",
"plot_barchart(result_CIPS, 'Clipped IPS', 'CTR', 'tab:blue', 'bandit_eval_clip{0}.eps'.format(cap))"
],
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:65: RuntimeWarning: invalid value encountered in double_scalars\n",
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:120: RuntimeWarning: invalid value encountered in double_scalars\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Effective sample size for agent Random is 0\n",
"Effective sample size for agent Popular is 0\n",
"Effective sample size for agent User-pop is 0\n",
"Effective sample size for agent SVD is 21243.67667925834\n",
"Effective sample size for agent User-kNN is 33971.880634000554\n",
"Effective sample size for agent Item-kNN is 39627.482103973896\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Agent | \n",
" 0.025 | \n",
" 0.500 | \n",
" 0.975 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Random | \n",
" NaN | \n",
" 0.000000 | \n",
" NaN | \n",
"
\n",
" \n",
" 1 | \n",
" Popular | \n",
" NaN | \n",
" 0.000000 | \n",
" NaN | \n",
"
\n",
" \n",
" 2 | \n",
" User-pop | \n",
" NaN | \n",
" 0.000000 | \n",
" NaN | \n",
"
\n",
" \n",
" 3 | \n",
" SVD | \n",
" 0.005323 | \n",
" 0.013057 | \n",
" 0.020791 | \n",
"
\n",
" \n",
" 4 | \n",
" User-kNN | \n",
" 0.010922 | \n",
" 0.017513 | \n",
" 0.024104 | \n",
"
\n",
" \n",
" 5 | \n",
" Item-kNN | \n",
" 0.012314 | \n",
" 0.018250 | \n",
" 0.024186 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Agent 0.025 0.500 0.975\n",
"0 Random NaN 0.000000 NaN\n",
"1 Popular NaN 0.000000 NaN\n",
"2 User-pop NaN 0.000000 NaN\n",
"3 SVD 0.005323 0.013057 0.020791\n",
"4 User-kNN 0.010922 0.017513 0.024104\n",
"5 Item-kNN 0.012314 0.018250 0.024186"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAa0AAACwCAYAAAC8aTHGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXIklEQVR4nO3de5hddX3v8feHW5BbQBNFohAtSAVE7BkKUi3xKbUo+Ii1cBQMhAoRedR6UEpUUKBiEYtSpbQGb9xvj6h4gIMoJ1Qs4NmpCCKiiAmQGEi45MYlIfmcP35rdGUzM8nM7Jk9a/J5Pc9+svf6XdZv/TKZb35rrb2+sk1EREQTbNLtAURERGyoBK2IiGiMBK2IiGiMBK2IiGiMBK2IiGiMBK2IiGiMBK2IiGiMBK2IhpI0T9JBkmZIWiNphaRlku6SdGit3icl/a4qf0TSVd0cd8RwJGhFjA+3294G2B74OnC1pB0kHQNMBw6qynuAH3VxnBHDkqAVMY7YXgt8A3gR8CfAvsBNtn9blS+yPbuLQ4wYls26PYCI6BxJmwHHASuA3wB3AF+WtAD4v8DPbK/p4hAjhiUrrYjxYX9JTwGLgPcC77K91PalwIeBvwFuBR6TdEoXxxkxLFlpRYwPd9h+U18Fti8DLpO0OXBY9f4u2zeN6ggjOiArrYiNhO3Vtq8B7gb26vZ4IoYiK62IcUzSDGAx8J/ASsppwj2BO7s4rIghS9CKGN+WAZ8ELgU2BeYDH7R9W1dHFTFEShLIiIhoilzTioiIxkjQioiIxkjQioiIxkjQioiIxkjQioiIxsgt78MwadIkT506tdvDiIholLlz5y6xPXkobRO0hmHq1Km0Wq1uDyMiolEkzR9q20afHpQ0VZKrJ1tHRMQ415Gg1ZZBdUx8074W0G5o236ppNOr99OqOhe01bmtevxNRESMIY1eaW2g/SQdMED5SmC6pKmjM5yIiBiqTp5Wey3wBWBzSSuA521vL2kCcBZwBDAB+A7wv2w/I2ka5ZloXwY+DqwBPgisAs4DJgH/YvtzGzIASe8GzgUOpSTBAzin2v9b+mn2VDWmzwDHDuaA71mwlKmzrh9Mk4gYQ+adfUi3hxCD1MmgdR9wAnBcW16fsylpv/cBVgOXA58GPlGV7whsCUwBZgAXAjcD/wPYGWhJusL27wbauaRjgU8BB9l+oLZyugD4iKSDbP+wn+ZnAb+WdLbt+zf4iCPGiEWXz+r2EBpp2h1f6PYQGmnOnDld2/eInh6UJGAmZWX1hO3lwOeA99SqrQbOsr0auJKyuvpX28tt3wv8Enj9enb1UeBkYJrtB9rKnqEEpc/219j2IuA/gDM34JhmSmpJaq15eun6qkdERAeN9F13k4GtgLklfgEgSoqEXo/bXlO9f6b689Fa+TPANgDVacdee9TenwycafuRfsbxNeBkSe8YYKyfB34racAAaXs2MBtgwst3yyPyY0zY8cizuz2ERpqT04ON0+mg1f5LfAkl6Oxpe8GwO7e3qX+unQJ8K/B/JC2y/e0+2q2SdAbwT8C9/fT9uKTzqjoRETEGdTpoPQq8QtIWtlfZXivpQuBLkj5k+zFJU4C9bN/Uwf3eCxwM3CRpte3r+qhzCTCrqvebfvr5IvAgZTW4Xq+bMpFW/qcWETFqOn1N6xZKAFkkaUm17RTgAeAOScuAHwK7d3i/2P455a7BCyW9rY/yNZQbQF48QB/LKHcb9lsnIiK6J5mLh6Gnp8d5jFNExOBImmu7ZyhtN4YvF0dExDiRoBUREY2RoBUREY2RoBUREY2RoBUREY2RoBUREY2RoBUREY2RoBUREY0xLoJWlX14126PIyIiRlZHglYttf1mbdu/JanflCCjoRrXPZI2qW37rKRvVe97x35DW7tLJZ0+uqONiIiBNG6lJWnT9dd6gZ1YN4dXX/aTdMAQ+o6IiFEy0vm0/qA6ffd1/pjB+Ee2/2dV9qfAVyjZihcDp9m+uir7FiW9yS7AgcA7KQ/d7W8/bwKuAKbbnlNtPgc4Q9LVtp/vp+k5lGSRb9nQY7pnwVKmzrp+Q6tHxCiZl+wL49aoBS1KnqofUILCFkAPgKStgZspT2B/G/A64GZJv7D9y6rtkcDbKU9x36K/HUg6GLgQeLftn9aKrgWOAGZQEkL25QLgI5IOst1vUIzolEWXz+r2EMataXd8odtDGLfmzJnT1f2P5unB1ZTV0k62n7V9W7X9UGCe7W/aft72z4BvA4fX2n7P9k9sr7X9bD/9Hw58FXhbW8CCkpzyNOA0Sf0FvWcoK60Br8FJmimpJam15umlA1WNiIgO69RKq/eU2+a1972fV1fv/5Gy2vqppCeBc21/gxLI9pP0VNu4Lql9frj3jaR7qzZQAtSPq/cfBS62/Yu+Bmj7BkmPAB8Y4Di+Bpws6R39VbA9G5gNMOHluyWvSwzZjkee3e0hjFtzcnpw3OpU0Po9JThNBe6rbX8V5dQfthcBx8Mfrjv9UNJ/UgLSrbb/eoD+/xAcbO/ZT53Dga9LesT2v/ZT51OU611X9LkTe5WkMyjB9d4BxhMREV3QkaBle42kbwNnSToeWAb8HbAHcCOApMOB220/AjxJCURrgf8NnC1pOnBl1eU+wArb97HhFgJ/BcyRtMr2v/cxzjmSfgEcA3y/n34uAWYBBwO/GWiHr5sykVb+RxcRMWo6eU3rROAJ4G7gMeBDwCG2H63K9wXulLQCuA74B9sP2l4OvJVyS/pCYBHweWDCYAdg+yFK4Jol6bh+qp0KvHiAPtZQbgrpt05ERHSH7FyWGaqenh63Wq1uDyMiolEkzbXdM5S2jftycUREbLwStCIiojEStCIiojEStCIiojEStCIiojEStCIiojEStCIiojEStCIiojEStCIiojHGRdCS9H5Jv5K0XNKjkm6QtK2kWdVDedvrT5K0StJekmZIWiNpRfX6naRvSnpNN44lIiL61/igJelA4HPAe21vC7wWuKoqvhQ4QNKr2pq9B7inlsbkdtvbABOBgyi5teZK2mvEDyAiIjbYaGYuHin7UoLOzwBsPwFcVJUtl3QLMB04s9bmaODi9o6qh+X+FjhR0s7A6ZSn1ffpngVLmTrr+k4cQ8SoSBr6aLrxELTuBP6pyoP1A6Bl+7la+UWU4HMmgKTdKalP1vev91rgnzs+2o1EUsmPTUlDPzZ1O4V9kzT+9GCVufhvgT8Drgcel/RFSZtWVb4DvEzSAdXno4EbbS9eT9cL6SM9iaSZklqSWmueXtqZg4iIiA0yHlZa2L4RuFHSJsBbgGuA+4Gv2n5a0jXA0ZJuB44CPrYB3U6h5Adr39dsYDbAhJfvlrwu/Ugq+bEpaeij6cZF0Opley3wo+o6Vv0miouA71JO+W1L/1mL694F/HigCslcHBExuhoftCS9E3gRcBPwFOXGjAOBj9aq/bgqmw1caXtVP31tCuwMnARMA944YgOPiIhBa/w1LeBJ4HjgN8Ayym3uX7B9WW8Fl/TMFwO70Mddg8AbJa2o2s8BtgP2tX3PyA49IiIGQ+X3eQxFT0+PW61Wt4cREdEokuba7hlK2/Gw0oqIiI1EglZERDRGglZERDRGglZERDRGglZERDRGglZERDRGglZERDRGglZERDRGglZERDRGR4NWlbr+tj62z5N0UCf3FRERG59GrrQkNf5BvxERMXijHrQkvV3SLyUtl7RA0sdrZYdKukvSU5L+S9LetbJ5kk6RdDewsj1wSZom6RFJn5S0pKp/VK18oqSLJS2WNF/SqVX+rd4V4k8knS9pqaRfSfqrUZiOiIgYhG6stL4OfMD2tpScV7cASHoD8A3gA8BLgK8C10maUGv7XuAQYHvbz/fR947AJEoCx2OA2ZJ2r8q+AkwEXk1JXXI0cGyt7X7Ab6v2nwGulfSCzMUREdE93Qhaq4E9JG1n+0nb/11tn0nJNHyn7TW2LwKeA/avtf2y7YdtPzNA/6fZfs72rcD1wBFVnqz3AJ+wvdz2POBcYHqt3WPAebZX276Kkvn4BRkeJc2U1JLUWrx48dBmICIihqTTQet5YPM+tm9OCVYA7wbeDsyXdKuk3kSLuwAfq04NPiXpKeCVwE61fh4GkLSzpBW9r1r5k7ZX1j7Pr9pPqsYwv61sSu3zAq+bp6W37Tpsz7bdY7tn8uTJfc1BRESMkE4HrYeAnSWpd4OkrYCXUgUM2//P9jurbd8Frq6qPgycZXv72msr21fU+nfVx0O2t+l91cp3kLR17fPOwEJgCSVo7tJWtqD2eUp93LW2ERExRnQ6aN0JPAvMkrRlFUDOBlqUldUWko6SNNH2akqm4LVV2wuBEyTtp2JrSYdI2naQYzij2s+bgUOBa2yvoQTHsyRtK2kX4CRKluNeLwU+ImlzSYcDrwVuGNo0RETESOho0LL9HOU60DTgEeBByim2I2qn3qYD8yQtA04AjqratoDjgfOBJ4EHgBmDHMKiqu1C4DLgBNu/qso+DKysxnQbcDnlxo9edwK7UVZlZwF/Z/vxQe4/IiJGkNa9jNNckqYBl9p+xRDazgCOs/2mwbTr6elxq9Ua7O4iIjZqkuba7hlK20Z+uTgiIjZOCVoREdEY4yZo2Z4zlFODVdtvDfbUYEREjL5xE7QiImL8S9CKiIjGSNCKiIjGSNCKiIjGSNCKiIjG2GiCliRL2rXb44iIiKHrRhLIaZLWVk9oXy7pfknHrr9lRERs7Lq10lpYPZ19O+AU4EJJe3RpLANqz5AcERHd09XTgy6+S3nI7R6SJkg6T9LC6nVeb+biaoX2iKRPSloiaZ6ko3r7kjRH0nG1zzMk3dbXfqunx/9M0jJJD0s6vVY2tTqV+H5JD1FlVo6IiO7ratCStImkdwHbA/cAn6JkKt4HeD3w58CptSY7UhI6TgGOAWZL2n0Iu14JHF3t9xDgg5IOa6tzICU9yd8Mof+IiBgB3QpaO1WZiZcAnwGm276fkqbkTNuP2V4MnEFJZVJ3mu3nbN8KXA8cMdidV498usf2Wtt3A1dQglTd6bZX2n6mvlHSTEktSa3FixcPdtcRETEM3bpes7Cf5wTuRJXhuNKe8v5J2ysHKN8gkvajJKfcC9gCmABc01bt4b7a2p4NzIaSmmSw+46IiKEba7e8LwR2qX1uT3m/Q5UNua/ylcBWtbIdB9jP5cB1wCttTwT+A1BbnQSkiIgxZqwFrSuAUyVNljQJ+DRwaVudMyRtIenNwKH8cYV0F/C3kraqvo/1/gH2sy3whO1nJf05cGRnDyMiIkbCWLud+7OU2+Dvrj5fU23rtYhyp+FC4GngBNu/qsq+BOwLPFq1vww4qJ/9nAicK+l84FbgaspNGRERMYbJbsZZMEnTgEuHmjNrJPT09LjVanV7GBERjSJpru2eobQda6cHIyIi+pWgFRERjTHWrmn1y/YcYMycGoyIiNGXlVZERDRGglZERDRGglZERDRGglZERDRGglZERDTGuApaVR6sXbs9joiIGBkjGrSqRIxrJK2oEi7+XNKhI7nPiIgYv0ZjpXW77W0oz/a7ALhSUp7zFxERgzZqpwdtrwUuAbYGdgOQ9CeSbpH0uKQlki6rBzRJ8yR9XNLdkpZKukrSlrXykyX9XtJCSX9f35+kiZIulrRY0nxJp0rapCqbIeknkr4k6SlJD0o6oNr+sKTHJB0zKhMTEREbbNSClqRNgWOB1fwx0aOAf6Ykcnwt8Erg9LamRwAHA68C9gZmVP0dDHwc+GtKEGx/ovtXgInAqylZiY+u9t9rP8rT4F9Cya91JeUp8bsC7wPOl7TNkA84IiI6bjSC1v6SngKeBf4FeJ/txwBsP2D7ZtvP2V4MfJEXpr3/su2Ftp8Avg/sU20/Avim7V9U2YxP721QBcj3AJ+wvdz2POBcYHqt39/Z/qbtNcBVlIB5ZjWWHwCrKAFsHZJmSmpJai1evHhYExMREYMzGkHrDtvbAztQsgW/ubdA0sskXSlpgaRllISPk9raL6q9fxroXf3sBDxcK5tfez8J2Lxt23xgSu3zo7X3zwDYbt/2gpWW7dm2e2z3TJ48ub04IiJG0Ghe01oBfBCYLukN1ebPUdLav872dpTTcu1p7/vze8rqqNfOtfdLKKchd2krXzCEoUdExBgxqt/Tqk7xfQ34dLVpW2AFsFTSFODkQXR3NTBD0h6StgI+U9vPmqr8LEnbStoFOImykouIiIbqxpeLzwPeLmlv4Azgz4ClwPXAtRvaie0bq75uAR6o/qz7MLASeBC4jXKzxTeGO/iIiOge2e72GBqrp6fHrVar28OIiGgUSXNt9wyl7bh6jFNERIxvWWkNg6TlwP3dHscYMIly88vGLvNQZB6KzEPR1zzsYntIt19vNvzxbNTuH+oSdzyR1Mo8ZB56ZR6KzEPR6XnI6cGIiGiMBK2IiGiMBK3hmd3tAYwRmYci81BkHorMQ9HReciNGBER0RhZaUVERGMkaEVERGMkaLWR9GJJ35G0skoeeWQ/9STp81UCy8er96qV7yNprqSnqz/36aufsaqD8zBb0v2S1kqaMWoH0CGZhyLzUHRiHiS9RtL3VBLUPiHpJkm7j+6RDE+H5mGSSjLex1WS8d4u6S/Wt+8ErRf6N0ourZcBRwH/LmnPPurNBA4DXk9JTvkO4AMAkrYAvkd5QO8OwEXA96rtTTHseaj8HDgR+O8RHe3IyTwUmYeiE/OwPSVN0+5VPz+l/L5okk7Mwwrg74HJlN+Tnwe+L2ng7w/bzqt6AVtXfxGvqW27BDi7j7r/BcysfX4/JXcYwFspaVBUK38IOLjbxzia89BW7zZgRrePLfOQeRhL81CVvZiSoukl3T7GLv48bEIJaAZeOtD+s9Ja12uA523/urbt50Bf/4PYsyrrq96ewN2u/jYqd/fTz1jUqXlousxDkXkoRmoe/hJYZPvxjoxy5HV0HiTdTclsfx3wNVeZ7fuTxzitaxtgWdu2pZS8X33VXdpWb5vqfG172UD9jEUdmYe2oN1EmYci81B0fB4kvYJyqu2kDo91JHV0HmzvLWlL4F3Aei+hJGitawWwXdu27YDlG1B3O2CFbUsaTD9jUUfmYYTGNpoyD0XmoejoPEiaDPwAuMD2FR0e60jq+M+D7WeBKyTdJ+ku2/XV2TpyenBdvwY2k7RbbdvrgXv7qHtvVdZXvXuBvet3TVEuQvbVz1jUqXlousxDkXkoOjYPknagBKzrbJ81AmMdSSP587A58OoB997ti3pj7QVcCVxBudj4F5Tl7J591DsBuA+YAuxU/UWcUJVtAcwH/gGYAHyo+rxFt49vNOehNhdbAj8Bjq/eb9Lt48s8ZB66NQ+U1cZPgfO7fTxdnof9gTdVPxMvAk6hrNZ2GnDf3T74sfai3MnzXWAl5Y6/I6vtb6Ysa3vrCTgHeKJ6ncO6dwu+AZgLPEO5vfcN3T62Ls3DHModQfXXtG4fX+Yh89CteQCOqY57JeX0We9r524f3yjPw4GUGzOWV2W3An+5vn3n2YMREdEYuaYVERGNkaAVERGNkaAVERGNkaAVERGNkaAVERGNkaAVERGNkaAVERGNkaAV0TCSjpTUkrRC0u8l3Sjph9XnFZJWSVpd+3yjpKmSXNs2T9Ksbh9LxGDlgbkRDSLpJGAW5fE4N1HyGh1MeZLAQVWd04Fdbb+v1m5q9XZ7289L6gFulTTX9s2jdwQRw5OgFdEQkiYCZwLH2r62VvT96rXBbLck3QvsAyRoRWPk9GBEc7yR8oDZ7wy3I0n7A3sBDwy3r4jRlJVWRHO8BFhi+/lh9LFE0gRK8DuX8tDTiMbISiuiOR4HJkkazn82J1GyyX4MmEbJXxTRGAlaEc1xO/AccNhwOrG9xvYXgWeBEzsxsIjRkqAV0RC2lwKfBv5N0mGStpK0uaS3STpnCF2eDfyjpC07O9KIkZOgFdEgts8FTgJOBRYDD1MyYw/l2tT1wJOUDMIRjZAkkBER0RhZaUVERGMkaEVERGMkaEVERGMkaEVERGMkaEVERGMkaEVERGMkaEVERGMkaEVERGMkaEVERGP8f4ovmeAFMuqvAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Agent | \n",
" 0.025 | \n",
" 0.500 | \n",
" 0.975 | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Random | \n",
" NaN | \n",
" 0.000000 | \n",
" NaN | \n",
"
\n",
" \n",
" 1 | \n",
" Popular | \n",
" NaN | \n",
" 0.000000 | \n",
" NaN | \n",
"
\n",
" \n",
" 2 | \n",
" User-pop | \n",
" NaN | \n",
" 0.000000 | \n",
" NaN | \n",
"
\n",
" \n",
" 3 | \n",
" SVD | \n",
" 0.005590 | \n",
" 0.012160 | \n",
" 0.018731 | \n",
"
\n",
" \n",
" 4 | \n",
" User-kNN | \n",
" 0.010801 | \n",
" 0.016735 | \n",
" 0.022670 | \n",
"
\n",
" \n",
" 5 | \n",
" Item-kNN | \n",
" 0.012184 | \n",
" 0.017599 | \n",
" 0.023014 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Agent 0.025 0.500 0.975\n",
"0 Random NaN 0.000000 NaN\n",
"1 Popular NaN 0.000000 NaN\n",
"2 User-pop NaN 0.000000 NaN\n",
"3 SVD 0.005590 0.012160 0.018731\n",
"4 User-kNN 0.010801 0.016735 0.022670\n",
"5 Item-kNN 0.012184 0.017599 0.023014"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaUAAACwCAYAAACvvnEyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXVUlEQVR4nO3debxVZb3H8c9XBbyIggppokDlkLN1MdMs8ZV11exqpV7TUMwcsvGaJl0nJDGyLCu1Qst5zrHBHF9YmkOQI5XlAKKIAiqTA9Pv/vE8Rxf77H04h7PP2Wsfvu/Xa7/cez1rPftZz6v48ay9WF9FBGZmZmWwWqMHYGZm1sJFyczMSsNFyczMSsNFyczMSsNFyczMSsNFyczMSsNFyawBJI2RdHl+P0TSAkmrN2AcUyXt0d3fa1aLi5JZF5F0sKRJueC8KOlWSbtW7hcRz0VEv4hY2ohx1iLpYkln5PfDJEU+lwW5mI0u7LuvpEckzZM0W9Ldkt7TuNFbs1qj0QMw64kkHQeMBo4BbgMWAXsC+wL3NnBonTUgIpZI2hm4S9IjwFPApcBngbuBfsAngVIVWWsOXimZ1Zmk/sBY4CsRcUNELIyIxRHx24g4ocr+LauQNfLniZK+J+mhvPK4WdJ6FfseJWlGXoEdX+hrNUmjJT0taY6ka1uOze0jJU3LbSet7DlGxP3AFGAbYAfg2Yi4K5L5EXF9RDy3sv3bqstFyaz+dgbWBG7sRB+HAl8E3g0sAX5a0b47sBlpRXJi4XehrwH7AbsBGwGvAucBSNoK+DkwMretD2zc0YEp+QiwNfAw8Dfg/ZJ+LGl3Sf062qdZCxcls/pbH5gdEUs60cdlEfFERCwETgEOrLgR4vS8AnscuAj4fN5+DHBSRDwfEW8BY4D98ypsf+B3EfGn3HYKsKyD45oNvAJcCIzOq6NngBHAYOBaYHb+PcrFyTrMvymZ1d8cYKCkNTpRmKYX3k8DegED22jfNr8fCtwoqVhslgIbkFZHbx8XEQslzenguAZWO6eIeAA4EEDSjsA1wEnAdzrYv63ivFIyq7/7gbdIl9FW1iaF90OAxaRVSq32Gfn9dGCviBhQeK0ZES8ALxaPk9SXtKqrq4j4K3AD6fcmsw5xUTKrs4iYC5wKnCdpP0l9JfWStJeks9rZzRckbZULx1jgNxW3jJ+S+90aOJy0MgH4BTBO0lAASYMk7ZvbfgPsI2lXSb1zv53+MyD3d6Skd+XP7wf+G3igs33bqsdFyawLRMTZwHHAycAs0grmq8BN7eziMuBiYCbppomvV7TfQ7oV+y7ghxFxe97+E+AW4HZJ80mFYac8pinAV4ArSaumV4HnO352rbxGKkKPS1oA/JF0k0d7C7DZ2+SQP7NykTQRuDwiLqzSNgx4FujVyRspzErJKyUzMysNFyUzMysNX74zM7PS8ErJzMxKw0XJzMxKw0906ISBAwfGsGHDGj0MM7OmMnny5NkRMaham4tSJwwbNoxJkyY1ehhmZk1F0rRabU19+a7ykf9mZtbc6lKUWiKVJY2SVIoAs0LB+kPF9ssljcnvR+R9zq/Y515Jo7pvtGZmBk2+UmqnnSTt0kb7QmBk/pfyZmbWQPW87LUl8AOgV37+1ZKIGCCpDzCO9Fj7PqRnYv1vRLwhaQRwOSnA7HjSI/a/TIqOPof0qP4fRsSZ7RmApM8BZwP7AAvy5rPy9+9e47DX8phOIz3Yst0ef2Euw0b/viOHmFlJTR3/qUYPwahvUfoHKWDsSxGxa2H7eOB9pMjkxaSHQZ7KOzkrG5IeODkYGAVcANwB/CfpkfyTJF0VEc+29eWSDiflt+wREU8VVj7nA1+XtEdE3Fnj8HHAvySNj4gn233GZiU388rRjR5C0xjxwA8aPYSmMXHixC7ru0sv30kScBRpZfRKRMwHzgQOKuy2GBgXEYuBq0mro59ExPz8VOO/A9uv4Ku+CZwAjIiIpyra3iAVnTNqHRwRM0mP/B/bjnM6StIkSZOWvj53RbubmVkHdPVda4OAvsDkVJ8AEFCMdZ5TyIl5I//3pUL7G0A/gHxZsMVWhfcnAGMjotZj+C8ETpD06TbG+n3gaUltFsCImABMAOjz7s38jCYrtQ0PHt/oITSNib58Vwr1LkqVf0jPJhWVrXPyZec6j+hX/Fy4RPdJ4I+SZkbE9VWOWyTpdOC7wJQafc+RdE7ex8zMGqDeReklYGNJvSNiUUQsk3QB8GNJX42IlyUNBraJiNvq+L1TgD2B2yQtjohbquxzGTA67/fvGv38CHiGtJpboW0H92eS/3ZlZlY39f5N6W5SgZgpaXbediIpIfMBSfOAO4Et6vy9RMSjpLvuLpC0V5X2paQbLNZro495pLv1au5jZmZdx9EVnTB8+PDwY4bMzDpG0uSIGF6tbVX4x7NmZtYkXJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0ekRRyumxmzZ6HGZm1jn1ikNviR5fo2L7xZJqRkZ0hzyuxyWtVth2hqSL8/sVxqabmVn3aLqVkqTVV7xXKxuxfIZTNSuKTTczsy7W1XlKb8uX137FOwm0d0XE/+S29wM/I6XNzgJOiYhrc9vFpPiLocBuwL6kh7rW+p5dgauAkRExMW8+Czhd0rURsaTGoSuKTW/Fcehm5eV48+bUbUWJlFN0O+kP/d7AcABJa5Hiz08F9gK2Be6Q9ERE/D0fezCwN+kp4L1rfYGkPUlx6p+LiIcKTTcAB5Li1i+scXh7YtPNupwjzOvD8eb105Xx55W68/LdYtJqZ6OIeDMi7s3b9wGmRsRFEbEkIh4GrgcOKBx7c0TcFxHLIuLNGv0fAPwS2KuiIEEKHzwFOEVSraK2wth0cBy6mVlXqtdKqeWSWK/C+5bPi/P7b5NWSw9JehU4OyJ+TSpUO0l6rWJclxU+T295I2lKPgZSAfpzfv9N4NKIeKLaACPiD5KeB45u4zxWGJvuOHTrao4wrw/HmzenehWlF0nFZxjwj8L295AuzRERM4Ej4e3ffe6U9CdSwbknIj7RRv9v/+EfEVvX2OcA4FeSno+In9TY5yTS701XVf2SdsSmm5lZ16lLUYqIpZKuB8ZJOhKYB+wPbAXcCiDpAOD+iHgeeJVUaJYBvwPGSxoJXJ273AFYEBH/oP1mAB8HJkpaFBE/rzLOiZKeAA4Dflujn/bEpgOOQzczq7d6/qZ0LPAK8BjwMvBV4FMR8VJu3xF4UNIC4BbgGxHxTETMBz5JumV7BjAT+D7Qp6MDiIjnSIVptKQv1djtZNqORF9hbLqZmXUNx6F3guPQzcw6znHoZmbWFFyUzMysNFyUzMysNFyUzMysNFyUzMysNFyUzMysNFyUzMysNFyUzMysNFyUzMysNHpEUZJ0hKR/Spov6SVJf5C0tqTR+aGvlfsPlLRI0jaSRklaKmlBfj0r6SJJmzfiXMzMVmVNX5Qk7QacCXw+ItYGtgSuyc2XA7tIek/FYQcBjxdiLu6PiH5Af2APUrbSZEnbdPkJmJnZ27ozebar7EgqKg8DRMQrwCW5bb6ku4GRwNjCMYcCl1Z2lB/G+jRwrKQhwBjS086rchy6lZ0jwa3Z9ISi9CDw3ZyDdDswKSLeKrRfQiouYwEkbUGKxljR/1tvAL5X99Ga4767kSPBu093Rob3ZE1/+S4nz34W+CDwe2COpB9JWj3vciOwgaRd8udDgVsjYtYKup5BlfgKx6GbmXWdnrBSIiJuBW6VtBqwO3Ad8CTwy4h4XdJ1wKGS7gcOAb7Vjm4Hk/KhKr/Lceid5Ljv7uNIcGs2PaIotYiIZcBd+Xek4k0KlwA3kS7JrU3t1NmizwB/bmsHJ8+amdVX0xclSfsC/wHcBrxGuvFhN+Cbhd3+nNsmAFdHxKIafa0ODAGOA0YAO3fZwM3MrJWm/00JeBU4Evg3MI90G/gPIuKKlh0ixeteCgylyl13wM45pn0eMBFYB9gxIh7v2qGbmVmR49A7wXHoZmYd5zh0MzNrCi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGnUtSjla/N4q26dK2qOe32VmZj1PU66UJDX9g2TNzKy1bi9KkvaW9HdJ8yW9IOn4Qts+kh6R9Jqkv0jartA2VdKJkh4DFlYWJkkjJD0v6f8kzc77H1Jo7y/pUkmzJE2TdHLOX2pZ4d0n6VxJcyX9U9LHu2E6zMysoBErpV8BR0fE2qTMo7sBJH0A+DVwNLA+8EvgFkl9Csd+nhRjPiAillTpe0NgICmg7zBgQo4/B/gZ0B94Lyna4lDg8MKxOwFP5+NPA26Q1Cp51szMuk4jitJiYCtJ60TEqxHxt7z9KFJS7IMRsTQiLgHeAj5cOPanETE9It5oo/9TIuKtiLiHFI9+YM5JOgj4TkTMj4ipwNnAyMJxLwPnRMTiiLiGlFzbKsGvGIc+a9aKEtXNzKwj6l2UlgC9qmzvRSpGAJ8D9gamSbpHUkuQ3lDgW/nS3WuSXgM2ATYq9DMdQNIQSQtaXoX2VyNiYeHztHz8wDyGaRVtgwufX4jlczxajl1OREyIiOERMXzQoEHV5sDMzFZSvYvSc8AQSWrZIKkv8C5yQYiIv0bEvnnbTcC1edfpwLiIGFB49Y2Iqwr9R+7juYjo1/IqtK8raa3C5yHADGA2qSgOrWh7ofB5cHHchWPNzKyb1LsoPQi8CYyWtGYuEOOBSaSVUW9Jh0jqHxGLSUmvy/KxFwDHSNpJyVqSPiVp7Q6O4fT8PR8F9gGui4ilpOI3TtLakoaSIs8vLxz3LuDrknpJOgDYEvjDyk2DmZmtjLoWpYh4i/Q7zAjgeeAZ0iWwAwuXxkYCUyXNA44BDsnHTiLFmp9Lijh/ChjVwSHMzMfOAK4AjomIf+a2rwEL85juBa4k3VjR4kFgM9Kqahywf0TM6eD3m5lZJ/SYOHRJI4DLI2LjlTh2FPCliNi1I8c5Dt3MrOMch25mZk3BRcnMzEqjxxSliJi4Mpfu8rEXd/TSnZmZ1V+PKUpmZtb8XJTMzKw0XJTMzKw0XJTMzKw0XJTMzKw0VpmiJCkkbdrocZiZWW2NCPkbIWlZfsL3fElPSjp8xUeamVlP16iV0oz8dO91gBOBCyRt1aCxtMnR62Zm3aehl+8iuYn0ENWtJPWRdI6kGfl1TkvybDvizidK+lLh8yhJ91b73vz08YclzZM0XdKYQtuwfKnvCEnPkZNxzcys6zW0KElaTdJngAHA48BJpKTZHYDtgQ8BJxcOaSvuvCMWkuLQB5Ceav5lSftV7LMbKb7iv1aifzMzWwmNKkob5WTZ2cBpwMiIeJIUYzE2Il6OiFnA6SwfWQ5V4s47+uX5kUSPR8SyiHgMuIpUhIrGRMTCyuh1x6GbmXWdRv1eMqPGc+o2onVkeTGSvFbceYdI2okUPrgN0BvoA1xXsdv0asdGxARgAqToio5+t5mZ1Va2W8Jn0DqyvBhJXivuHNIlub6Ftg3b+J4rgVuATSKiP/ALQBX7uOCYmXWzshWlq4CTJQ2SNBA4leUjy6FK3Hne/gjwWUl9879HOqKN71kbeCUi3pT0IeDg+p6GmZmtjLLd7nwG6Tbxx/Ln6/K2FsW489dZPu78x8COwEv5+CuAPWp8z7HA2ZLOBe4BriXd9GBmZg3UNHHonYk77yqOQzcz6zjHoZuZWVNwUTIzs9Io229KNUXERKA0l+7MzKz+vFIyM7PScFEyM7PScFEyM7PScFEyM7PScFEyM7PS6FFFyZHnZmbNrUuLUg7aW5qjz+dJelTSPl35nWZm1ry6Y6V0f44+HwCcD1wtyc+ZMzOzVrrt8l1ELAMuA9YCNgOQ9D5Jd0uakyPOrygWrBx5frykxyTNlXSNpDUL7SdIejFHp3+x+H2S+ku6VNIsSdMknSxptdw2StJ9kn4s6TVJz0jaJW+fLullSYd1y8SYmdnbuq0oSVodOBxYzDtBfgK+Rwrq2xLYBBhTceiBwJ7Ae4DtgFG5vz2B44FPkIpc5RPBfwb0B95LSpU9NH9/i51ITxNfn5SvdDXpKeObAl8AzpXUb6VP2MzMOqw7itKHc/T5m8APgS9ExMsAEfFURNyR481nAT+idSz5TyNiRkS8AvwW2CFvPxC4KCKeyGm0Y1oOyAXwIOA7ETE/IqYCZ7N8tPqzEXFRRCwFriEVxLF5LLcDi0gFajmOQzcz6zrdUZQeiIgBwLqktNePtjRI2kDS1ZJekDSPFOg3sOL4mYX3rwMtq5eNWD6yvBijPhDoReto9cGFzy8V3r8BEBGV21qtlCJiQkQMj4jhgwYNqmw2M7NO6M7flBYAXwZGSvpA3nwmKXZ824hYh3TZrDKWvJYXSaubFkMK72eTLhNWRqu/sBJDNzOzbtKt/04pX4K7kBRzDimWfAEwV9Jg4IQOdHctMErSVpL6AqcVvmdpbh8naW1JQ4HjaB2tbmZmJdKIfzx7DrC3pO2A04EPAnOB3wM3tLeTiLg193U38FT+b9HXgIXAM8C9pJsZft3ZwZuZWddpmjj0MnIcuplZxzkO3czMmoJXSp0gaT7wZKPHUUIDSTeb2Ds8J615TqpbFeZlaERUvX25aeLQS+rJWkvQVZmkSZ6X5XlOWvOcVLeqz4sv35mZWWm4KJmZWWm4KHXOhEYPoKQ8L615TlrznFS3Ss+Lb3QwM7PS8ErJzMxKw0XJzMxKw0WpgqT1JN0oaWEOBzy4xn6S9P0cUDgnv1ehfQdJkyW9nv+7Q7V+mkEd52SCpCclLZM0qttOoIt4XlrznLRWjzmRtLmkm5VCS1+RdJukLbr3TLqHi1Jr55GylDYADgF+LmnrKvsdBewHbE8KH/w0cDSApN7AzaQHwK4LXALcnLc3o07PSfYocCzwty4dbffxvLTmOWmtHnMygBT9s0Xu5yHSnzE9T0T4lV+kqPZFwOaFbZcB46vs+xfgqMLnI0jZUQCfJMVkqND+HLBno8+xUXNSsd+9wKhGn5vnxXPSjHOS29Yjxf6s3+hzrPfLK6XlbQ4siYh/FbY9ClT7W83Wua3aflsDj0X+X0/2WI1+yq5ec9LTeF5a85y01lVz8jFgZkTMqcsoS8RFaXn9gHkV2+aScp+q7Tu3Yr9++RpwZVtb/ZRdveakp/G8tOY5aa3ucyJpY9IlwePqOM7ScFFa3gJgnYpt6wDz27HvOsCCvDrqSD9lV6856Wk8L615Tlqr65xIGgTcDpwfEVfVeayl4KK0vH8Ba0jarLBte2BKlX2n5LZq+00Btqv4G852Nfopu3rNSU/jeWnNc9Ja3eZE0rqkgnRLRIzrgrGWQ6N/1CrbC7gauIr0A+VHSEvoravsdwzwD2AwsBHpfzzH5LbewDTgG0Af4Kv5c+9Gn1+j5qQwL2sC9wFH5verNfr8PC+ek7LPCWnV9BBwbqPPp8vnq9EDKNuLdFfLTaQo9eeAg/P2j5KW0i37CTgLeCW/zmL5u+0+AEwG3iDd1vqBRp9bCeZkIumOoeJrRKPPz/PiOSn7nACH5TlYSLrM1/Ia0ujzq/fLz74zM7PS8G9KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZmZWGi5KZk1G0sGSJklaIOlFSbdKujN/XiBpkaTFhc+3ShomKQrbpkoa3ehzMau0RqMHYGbtJ+k4YDTpkTS3kbJ69gQ+FhF75H3GAJtGxBcKxw3LbwdExBJJw4F7JE2OiDu67wzM2uaiZNYkJPUHxgKHR8QNhabf5le7RcQkSVOAHQAXJSsNX74zax47kx5MemNnO5L0YWAb4KnO9mVWT14pmTWP9YHZEbGkE33MltSHVNzOJj0o1Kw0vFIyax5zgIGSOvOXyYGkhNNvASOAXnUYl1nduCiZNY/7gbeA/TrTSUQsjYgfAW8Cx9ZjYGb14qJk1iQiYi5wKnCepP0k9ZXUS9Jeks5aiS7HA9+WtGZ9R2q28lyUzJpIRJwNHAecDMwCppOSjVfmt6HfA6+Skl3NSsEhf2ZmVhpeKZmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWm4KJmZWWn8P6JaP70kYUMaAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"CPU times: user 45min 36s, sys: 2min 15s, total: 47min 51s\n",
"Wall time: 29min 38s\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NqDH9cq_b4vh"
},
"source": [
"### A/B tests"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DsK98lKNb6RB",
"outputId": "d4596fb8-1674-4c67-e0ec-13dbc52e67dd"
},
"source": [
"%%time\n",
"result_AB = verify_agents(env, n_test_users, deepcopy(agents))\n",
"display(result_AB)\n",
"plot_barchart(result_AB, 'A/B-test', 'CTR', 'tab:green', 'ABtest_eval.eps')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Organic Users: 0it [00:00, ?it/s]\n",
"Users: 100%|██████████| 5000/5000 [05:39<00:00, 14.71it/s]\n",
"Organic Users: 0it [00:00, ?it/s]\n",
"Users: 100%|██████████| 5000/5000 [05:35<00:00, 14.92it/s]\n",
"Organic Users: 0it [00:00, ?it/s]\n",
"Users: 100%|██████████| 5000/5000 [06:49<00:00, 12.20it/s]\n",
"Organic Users: 0it [00:00, ?it/s]\n",
"Users: 100%|██████████| 5000/5000 [06:39<00:00, 12.53it/s]\n",
"Organic Users: 0it [00:00, ?it/s]\n",
"Users: 83%|████████▎ | 4163/5000 [18:04<03:51, 3.62it/s]"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ieFcpaa-b8Tt"
},
"source": [
"def combine_barchart(resultAB, resultCIPS, title, xlabel, figname = 'fig.eps', size = (6,2), fontsize = 12):\n",
" fig, axes = plt.subplots(figsize = size)\n",
" plt.title(title, size = fontsize)\n",
" n_agents = len(resultAB)\n",
" \n",
" for i, (name, colour, result) in enumerate([('A/B-test', 'tab:green', result_AB),('CIPS', 'tab:blue', result_CIPS)]):\n",
" mean = result['0.500']\n",
" lower = result['0.500'] - result['0.025']\n",
" upper = result['0.975'] - result['0.500']\n",
" height = .25\n",
" yticks = [a + i * height for a in range(n_agents)]\n",
" plt.barh(yticks,\n",
" mean,\n",
" height = height,\n",
" xerr = (lower, upper),\n",
" align = 'edge',\n",
" label = name,\n",
" color = colour)\n",
" plt.yticks(yticks, result['Agent'], size = fontsize)\n",
" plt.xticks(size = fontsize)\n",
" plt.xlabel(xlabel, size = fontsize)\n",
" plt.legend(loc = 'lower right')\n",
" plt.xlim(.0,None)\n",
" plt.gca().xaxis.set_major_formatter(FormatStrFormatter('%.3f'))\n",
" plt.savefig(figname, bbox_inches = 'tight')\n",
" plt.show()\n",
"combine_barchart(result_AB, result_CIPS, 'Evaluate on Bandit Feedback', 'CTR', 'ABtest_CIPS.eps')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Q4NB72lgb-Ek"
},
"source": [
"plot_barchart(result_LOO, 'Evaluate on Organic Feedback', 'HR@1', 'tab:red', 'traditional_eval.eps')"
],
"execution_count": null,
"outputs": []
}
]
}