{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "colab": { "name": "2021-06-12-vowpalwabbit-changing-context-part2.ipynb", "provenance": [], "collapsed_sections": [] } }, "cells": [
{ "cell_type": "markdown", "metadata": { "id": "RnXUndQf1Zqt" }, "source": [ "# Contextual bandit with changing context part 2\n", "> Customizing the context and changing it midway to see how fast the agent can adapt to the new context and start recommending better products as per the context\n", "\n", "- toc: true\n", "- badges: true\n", "- comments: true\n", "- categories: [contextual bandit]\n", "- image: " ] },
{ "cell_type": "code", "metadata": { "id": "NDf-kJiZ1WfJ" }, "source": [
"# Human-readable labels -> compact tokens used as VW feature values.\n",
"# Reverse any of them with {v: k for k, v in mapping.items()}.\n",
"mapping_users = {\n",
"    'Alex': 'usera',\n",
"    'Ben': 'userb',\n",
"    'Cindy': 'userc'\n",
"}\n",
"\n",
"mapping_context1 = {  # time of day\n",
"    'Morning': 'ctx11',\n",
"    'Evening': 'ctx12',\n",
"}\n",
"\n",
"mapping_context2 = {  # season\n",
"    'Summer': 'ctx21',\n",
"    'Winter': 'ctx22'\n",
"}\n",
"\n",
"mapping_items = {\n",
"    'Politics': 'item1',\n",
"    'Economics': 'item2',\n",
"    'Technology': 'item3',\n",
"    'Movies': 'item4',\n",
"    'Business': 'item5',\n",
"    'History': 'item6'\n",
"}"
], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "ywUBRGWp1WfO" }, "source": [
"import random\n",
"from itertools import product\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from vowpalwabbit import pyvw\n",
"\n",
"# Seed every RNG the simulation draws from (python `random` for users, contexts and\n",
"# exploration sampling; numpy for the reward tables) so Restart & Run All is reproducible.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)"
], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "bLp9p35a1WfT" }, "source": [ "users = list(mapping_users.values())\n", "items = list(mapping_items.values())\n", "context1 = list(mapping_context1.values())\n", "context2 = list(mapping_context2.values())" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "bJPAR-pJ1WfW" }, "source": [ "context = pd.DataFrame(list(product(users, context1, context2, items)), columns=['users', 'context1', 'context2', 'items'])\n", "context['reward'] = np.random.choice([0,1],len(context))\n", "context['cost'] = context['reward']*-1\n", "contextdf = context.copy()" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "EklmWXvw1WfY", "outputId": "13a9bf54-8780-40ac-b8ee-6a1f7a7b8d63" }, "source": [ "contextdf" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
userscontext1context2itemsrewardcost
0useractx11ctx21item11-1
1useractx11ctx21item21-1
2useractx11ctx21item31-1
3useractx11ctx21item400
4useractx11ctx21item500
.....................
67usercctx12ctx22item200
68usercctx12ctx22item300
69usercctx12ctx22item41-1
70usercctx12ctx22item51-1
71usercctx12ctx22item600
\n", "

72 rows × 6 columns

\n", "
" ], "text/plain": [ " users context1 context2 items reward cost\n", "0 usera ctx11 ctx21 item1 1 -1\n", "1 usera ctx11 ctx21 item2 1 -1\n", "2 usera ctx11 ctx21 item3 1 -1\n", "3 usera ctx11 ctx21 item4 0 0\n", "4 usera ctx11 ctx21 item5 0 0\n", ".. ... ... ... ... ... ...\n", "67 userc ctx12 ctx22 item2 0 0\n", "68 userc ctx12 ctx22 item3 0 0\n", "69 userc ctx12 ctx22 item4 1 -1\n", "70 userc ctx12 ctx22 item5 1 -1\n", "71 userc ctx12 ctx22 item6 0 0\n", "\n", "[72 rows x 6 columns]" ] }, "metadata": { "tags": [] }, "execution_count": 56 } ] },
{ "cell_type": "code", "metadata": { "id": "FAOCxBU31Wfb" }, "source": [ "import numpy as np\n", "import scipy\n", "import scipy.stats as stats\n", "from vowpalwabbit import pyvw\n", "import random\n", "import pandas as pd\n", "from itertools import product" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "ojr3s0vC1Wfc" }, "source": [
"# Helpers that turn the simulation's dict contexts into VW text examples and sample actions.\n",
"def to_vw_example_format(context, actions, cb_label=None):\n",
"    \"\"\"Serialise one decision as a multi-line VW ADF example.\n",
"\n",
"    context  -- dict with keys \"user\", \"context1\" and \"context2\" (shared features)\n",
"    actions  -- candidate item tokens; each becomes one `|Action` line\n",
"    cb_label -- optional (chosen_action, cost, prob); when given, the chosen\n",
"                action's line is prefixed with the `0:cost:prob` label\n",
"    Returns the example text without a trailing newline.\n",
"    \"\"\"\n",
"    if cb_label is not None:\n",
"        chosen_action, cost, prob = cb_label\n",
"    lines = [\"shared |User users={} context1={} context2={}\".format(\n",
"        context[\"user\"], context[\"context1\"], context[\"context2\"])]\n",
"    for action in actions:\n",
"        label = \"\"\n",
"        if cb_label is not None and action == chosen_action:\n",
"            label = \"0:{}:{} \".format(cost, prob)\n",
"        lines.append(label + \"|Action items={} \".format(action))\n",
"    return \"\\n\".join(lines)\n",
"\n",
"\n",
"def sample_custom_pmf(pmf):\n",
"    \"\"\"Sample an index from a (possibly unnormalised) pmf.\n",
"\n",
"    Returns (index, normalised probability of that index).\n",
"    \"\"\"\n",
"    total = sum(pmf)\n",
"    scale = 1 / total\n",
"    pmf = [x * scale for x in pmf]\n",
"    draw = random.random()\n",
"    sum_prob = 0.0\n",
"    for index, prob in enumerate(pmf):\n",
"        sum_prob += prob\n",
"        if sum_prob > draw:\n",
"            return index, prob\n",
"    # Float round-off can leave sum_prob a hair below draw; rather than falling\n",
"    # through and implicitly returning None, use the last entry with non-zero mass.\n",
"    index = max(i for i, p in enumerate(pmf) if p > 0)\n",
"    return index, pmf[index]\n",
"\n",
"\n",
"def get_action(vw, context, actions):\n",
"    \"\"\"Score `actions` for `context` with VW and sample one from the returned pmf.\n",
"\n",
"    Returns (chosen action, probability it was chosen with).\n",
"    \"\"\"\n",
"    vw_text_example = to_vw_example_format(context, actions)\n",
"    pmf = vw.predict(vw_text_example)\n",
"    chosen_action_index, prob = sample_custom_pmf(pmf)\n",
"    return actions[chosen_action_index], prob\n",
"\n",
"\n",
"def choose_user(users):\n",
"    \"\"\"Pick a user token uniformly at random.\"\"\"\n",
"    return random.choice(users)\n",
"\n",
"\n",
"def choose_context1(context1):\n",
"    \"\"\"Pick a time-of-day token (Morning/Evening) uniformly at random.\"\"\"\n",
"    return random.choice(context1)\n",
"\n",
"\n",
"def choose_context2(context2):\n",
"    \"\"\"Pick a season token (Summer/Winter) uniformly at random.\"\"\"\n",
"    return random.choice(context2)"
], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "S2WXd8Xd1Wff" }, "source": [
"class VWCSimulation():\n",
"    \"\"\"Simulated environment that feeds random contexts to a VW contextual bandit.\n",
"\n",
"    The reward table has columns users, context1, context2, items and reward;\n",
"    cost is stored as -reward and is what goes into the `0:cost:prob` label.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vw, ictxt, n=100000):\n",
"        # NOTE: `n` is not used anywhere; kept so existing calls stay valid.\n",
"        self.vw = vw\n",
"        self.users = ictxt['users'].unique().tolist()\n",
"        self.contexts1 = ictxt['context1'].unique().tolist()\n",
"        self.contexts2 = ictxt['context2'].unique().tolist()\n",
"        self.actions = ictxt['items'].unique().tolist()\n",
"        self.update_context(ictxt)\n",
"\n",
"    def get_cost(self, context, action):\n",
"        \"\"\"Return the cost of showing `action` in `context` under the current reward table.\"\"\"\n",
"        df = self.contextdf\n",
"        mask = ((df['users'] == context['user'])\n",
"                & (df['context1'] == context['context1'])\n",
"                & (df['context2'] == context['context2'])\n",
"                & (df['items'] == action))\n",
"        return df.loc[mask, 'cost'].values[0]\n",
"\n",
"    def update_context(self, new_ctxt):\n",
"        \"\"\"Swap in a new reward table; users, contexts and actions stay as set in __init__.\"\"\"\n",
"        self.contextdf = new_ctxt.copy()\n",
"        self.contextdf['cost'] = self.contextdf['reward'] * -1\n",
"\n",
"    def step(self):\n",
"        \"\"\"Run one predict/sample/learn round.\n",
"\n",
"        Returns (user, context1, context2, action, cost, prob).\n",
"        \"\"\"\n",
"        context = {\n",
"            'user': choose_user(self.users),\n",
"            'context1': choose_context1(self.contexts1),\n",
"            'context2': choose_context2(self.contexts2),\n",
"        }\n",
"        action, prob = get_action(self.vw, context, self.actions)\n",
"        cost = self.get_cost(context, action)\n",
"        # Label the played action with its observed cost and the probability it was sampled with.\n",
"        vw_format = self.vw.parse(\n",
"            to_vw_example_format(context, self.actions, (action, cost, prob)),\n",
"            pyvw.vw.lContextualBandit)\n",
"        self.vw.learn(vw_format)\n",
"        self.vw.finish_example(vw_format)\n",
"        return (context['user'], context['context1'], context['context2'], action, cost, prob)"
], "execution_count": null,
"outputs": [] },
{ "cell_type": "code", "metadata": { "id": "z_33SgaP1Wfi", "outputId": "bc8f39ed-e666-40d8-f69d-58960bbba94c" }, "source": [ "context = pd.DataFrame(list(product(users, context1, context2, items)), columns=['users', 'context1', 'context2', 'items'])\n", "context['reward'] = np.random.choice([0,1],len(context),p=[0.8,0.2])\n", "contextdf = context.copy()\n", "contextdf.reward.value_counts()" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0 61\n", "1 11\n", "Name: reward, dtype: int64" ] }, "metadata": { "tags": [] }, "execution_count": 136 } ] },
{ "cell_type": "code", "metadata": { "id": "hLYqC-B21Wfj" }, "source": [ "vw = pyvw.vw(\"--cb_explore_adf -q UA --quiet --epsilon 0.2\")\n", "vws = VWCSimulation(vw, contextdf)" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "iN7_K3-l1Wfl", "outputId": "382ec01f-22ba-45c2-ba2a-d181ac41db11" }, "source": [ "vws.step()" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "('usera', 'ctx11', 'ctx22', 'item4', 0, 0.16666666666666666)" ] }, "metadata": { "tags": [] }, "execution_count": 138 } ] },
{ "cell_type": "code", "metadata": { "id": "GnEvCkub1Wfm" }, "source": [ "_temp = []\n", "for _ in range(5000):\n", "    _temp.append(vws.step())" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "IykFBA931Wfn" }, "source": [ "x = pd.DataFrame.from_records(_temp, columns=['user','context1','context2','item','cost','prob'])" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "-36YdGnH1Wfn", "outputId": "3adfdb67-ec48-4a0e-bfd9-861dd5be98d2" }, "source": [ "xx = x.copy()\n", "xx['ccost'] = xx['cost'].cumsum()\n", "xx = xx.fillna(0)\n", "xx = xx.rename_axis('iter').reset_index()\n", "# iter is 0-based, so row i summarises iter + 1 rounds (dividing by iter gives 0/0 or inf at row 0)\n", "xx['ctr'] = -1*xx['ccost']/(xx['iter'] + 1)\n", "xx.sample(10)" ], "execution_count": null, "outputs": [ { "output_type": 
"execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
iterusercontext1context2itemcostprobccostctr
31993199useractx11ctx22item5-10.833333-17810.556736
700700userbctx11ctx22item4-10.833333-3430.490000
36603660useractx12ctx22item500.833333-20350.556011
41234123usercctx12ctx22item5-10.833333-23020.558331
44484448userbctx11ctx22item4-10.833333-24800.557554
485485useractx11ctx21item2-10.833333-2190.451546
32803280userbctx11ctx21item100.033333-18220.555488
16791679userbctx12ctx21item6-10.833333-9010.536629
34893489usercctx11ctx21item3-10.833333-19320.553740
103103useractx11ctx22item200.833333-240.233010
\n", "
" ], "text/plain": [ " iter user context1 context2 item cost prob ccost ctr\n", "3199 3199 usera ctx11 ctx22 item5 -1 0.833333 -1781 0.556736\n", "700 700 userb ctx11 ctx22 item4 -1 0.833333 -343 0.490000\n", "3660 3660 usera ctx12 ctx22 item5 0 0.833333 -2035 0.556011\n", "4123 4123 userc ctx12 ctx22 item5 -1 0.833333 -2302 0.558331\n", "4448 4448 userb ctx11 ctx22 item4 -1 0.833333 -2480 0.557554\n", "485 485 usera ctx11 ctx21 item2 -1 0.833333 -219 0.451546\n", "3280 3280 userb ctx11 ctx21 item1 0 0.033333 -1822 0.555488\n", "1679 1679 userb ctx12 ctx21 item6 -1 0.833333 -901 0.536629\n", "3489 3489 userc ctx11 ctx21 item3 -1 0.833333 -1932 0.553740\n", "103 103 usera ctx11 ctx22 item2 0 0.833333 -24 0.233010" ] }, "metadata": { "tags": [] }, "execution_count": 141 } ] }, { "cell_type": "code", "metadata": { "id": "X1Aa6gjX1Wfo", "outputId": "5b62b160-8a68-4b61-f1f2-28f524737614" }, "source": [ "xx['ccost'].plot()" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 142 }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] }, { "cell_type": "code", "metadata": { "id": "L4MXTt071Wfo", "outputId": "4d5ae04a-a638-4b3a-8f98-eef32ede93f4" }, "source": [ "xx['ctr'].plot()" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 143 }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAc60lEQVR4nO3de3xdZZ3v8c8vO/dbmzTpNb1CuRS5lVBBFJCbBZTLy5kjKF4YPcjhoONxvKDDOCjqeDnjUQ4wlUEcmDMj44ygPVjtMIwICkhTKEgpLaUtbeglKSVN0iY72Xv/5o/spjvJbrObJllZa3/fr1deWetZK3v9nl6+XX3Ws9Yyd0dERMKvIOgCRERkdCjQRUQiQoEuIhIRCnQRkYhQoIuIRERhUAeuq6vzefPmBXV4EZFQWr169W53r8+2LbBAnzdvHk1NTUEdXkQklMzs9UNt05CLiEhEKNBFRCJCgS4iEhEKdBGRiFCgi4hEhAJdRCQiFOgiIhER2Dx0EYm+9u5eigoKMIOuniSVpYXsiyeoKi2ivauXyeVFxBMp2rt6aemIs7erl71dvcypLWfW5DJqKopHvSZ3x8x4a18Pb7R1MbWqhO7eFPt7E8ycXEZVSSEvNu9ld2echppydrV3A1BbUUxvMkV1WRFVpYVUlRRhBvviCSpKCiktig05VirldPUmWbu9nZLCAl5t6WR3Z5zTZk/mrAVTRr1vCnSRkOlJpPhp0zb29yQ4tWEyi+fWUBQrwN3Z35OkomRs/lonkil+s76VlWt3srsz3t++t6uXmZPKOG32ZHa2d/PWvh42tnbSGU+wZfc+AFIjfO1CXWUxc2rLOWnmJGbVlPHOY+s4aWY1ZnbIn3F3OuIJXt3VSTyRpL0rwYZdHax+/S1KCgt4etOblBTGeGt/D8mRFjZIdWkh5x8/lfbuXjq6E7zW2knb/t5D7v+Jd84fk0C3oF5w0djY6LpTVCQ33b1J1mxr44/Ne/nHZ15n6579h9y3rrKEE2dUcUx9JYvn1lBRHGNadSmv7Oxgfl05x0+vprKkkO7eJHv29bBjbxc9CWfH3i52tcfZumcfr+zsoL2rl+qyItz7jv/Kzg4AYgVGXWUxhuE4k8uK2bx7Hz3JVH8Nx06tZMakUhbUVVBSFCPem8TMKIoZxYUFtO3vZUplCbv2dlNcWMD0SaVUlxVRXhSjuLCA3Z1xmt/qYld7Ny80t7FtT1f/Z1eXFjK/vpICg7KiGDMmldEZ78UdUu6s2bZ3wD84B1SWFDKprIiTZlaTTDkNNWW8bdYkdqZrqCotYmNLJ4lUimOnVtJQU8aOvd3MqS1nf0+Szu4EiVSK7t4Ubft7iSeSJN0pK4qxZlsbv3t1Nw01ZVSVFtFQU8bUqhJKi2IUxQpYOK2S4lgBDTXlzJlSTnVp4WH/UTocM1vt7o1ZtynQRSaORDLFj363mWc2vUnS+85+u3uSrN/V0b/P7NoybnvfSUyfVMrj61vZ2NJJyp2a8mLcnRea9/JaSycd8cSI65hSUczMyWV0xhPECozKkkJOnFHF6XNquPK0mZQUDhxeS
KacrXv2U1dZTFVp0YiPm4270xlPsGdfD8vXbOfZLXuIJ1K82RmnpT3O/t4kU6tKACiMGSdMr+aUWZOYXVtOZUkhNRXF1FYUM7+uYlTrylbnSEP6SCjQRSYod+ep197k3ic38eSru0mkhwAKC4yiWAFzasuJFRgnzqjm5FnVnDCjmsa5NRTGhp/P0NLRzZqtbfQkU+xo66auqpjOeJKXt+9lXzxJTXkRk8uLaagpozBmxAoKOHF6FQvqK4kVjH0wjZbxCtKJ4nCBrjF0kTHU3ZukpLAAM+P1N/exc283p82ZzObd+1i3o50Hnn6d57e2UVxYwHHTqphUVsSHz57LZSfPOOpjT60q5ZKTpo9CLya2fArz4SjQRcbA8he28+mfPD/sfmVFMT7/nuP5wJmzqassGYfKJMoU6BIJ7k48kco6dSwXvckUuzvjbN69j2nVfRfzMs/8tu3Zz92Pv8ZPnt0KwOTyIhbUVeDAjrZuJpcXcfKsSTy+oZXu3iQd3QPHr+sqiwFjd2ecM+bWcNXps5hRXUrjvBoml4/+1DzJTwp0ydm/Nm3j8//2IgCfvnAhZy2o5R3H1AF9gXfud3/DlIoSHv1f51JTUUx3b5L/vXI9v3ppJ2+0dXHBCVP5u+sWD7mglovWjjhlxTH2xRP8+PdbuHbJbOZOqcDd6e5NceHfPs72vd2H/PmqkkKmVpdwxamz+OR5C/jminU88PQhHysN9IXw/p4k+3uSQ7a17e/lua1t/es727v7Z4EAXLtkDrddsWhEfRUZKV0UlZx89f+v5ce/35J125nzali15a0j/swl82q57/oz2djSyR82vckrOzvY1NrJC817j7LaI1dYYBTGjO7eVNbtdZUlXHXaTG5697FUlxYST6R4fmsbC+or6E2m6E068UTfePmxU6vGuXrJJ5rlIkespb2b9u4EZvDQc83c9ZvXAPjUBcfS0Z3goeeaaR80rHDSzGrOmFsz4Mz35FmTuOPa09kXT/DJf1zNG21dHK2zFtTyzKY9A9ruuPZ0rjh1Zv96KuXs7eqlpqKYlo5udu2N01BTxofu/QMv72jnY++Yx80XHEvKnalVpVmPM/jvhi6+yUSgQJcjsrGlg4u+98SQ9qZbLxpy4e7+p7Zw/1NbePimc5hU3jf/eLhpZMmU09md4Mxv/Ef/zShXnDqT+qoSyopivPfUGRxbX8nO9m4mlxfTtGUP5x1XP+AzO+MJ1r6xl8Z5taGaYidytBTokrOW9m6WfPOxIe3Lbz6HUxomj39BIjKA5qHLYXV099IZT/DYuhZu/flLAHzukuO4+YKFAVcmIkdCgR4iR3tHXCrldPYkKC3se15GKuUs+PKKrPt+8rxjRnwcEQmGAj0kfvS7zdz+yMucc+wU7rx2cf9jRVMp5/ZfvsyfnjGbRTOr+/ePJ5Jcd+8fMDOe3bxnyOdNry5lZ/vQaX63Xn4in3jXgrHriIiMGY2hh0Ay5Rwz6Ez6hOlV/OLmczj+1l/3t11/zjy2t3Vx/TnzueaeZ3L+/Of+6mJKCgvG7LGrIjJ6dFE0xOKJ5IDQHonFcybTOK+W/9bYwLwpFXzv0Q3c/fhrLLtuMRedOC2nBz2JyMSgQA+pju5eTr7t3/vXX7l9KaVFMR5bt4uP3980oP3k21bSmxz4e/nkF97N7NrycatXRMbeUQe6mS0FfgDEgHvd/VuDtp8P/ALYnG56yN2/drjPVKAf3vce3cAdj73av77pm5dRkDHf+merm5lSWcz5x08F+i6YtnTEmVad/SYZEYmGo5q2aGYx4C7gYqAZWGVmy9395UG7Punu7z3qaoXGr//HgDeu3HLpCQPCHOD9ZzQMWDczhblInsvlKtgSYKO7bwIwsweBK4HBgS5Had4tvxyw/u7j6/nx9UsCqkZEwiaXQJ8FbMtYbwbenmW/s83sBWA78Dl3Xzt4BzO7AbgBYM6cOUdebYT98x+2Dli/72ONXHDCtICqEZEwymV6Q7Y7WQYPv
D8HzHX3U4H/C/w82we5+z3u3ujujfX19UdUaJSlUs6XH/5j//rH3jFPYS4iRyyXM/RmYHbGegN9Z+H93L09Y3mFmd1tZnXuvnt0ygyntv093PRPz1FVWsgPP9xI2/4ezv3Ob3j7gilcsmgaf//kJjbs6hzwM1u+dXlA1YpI2OUS6KuAhWY2H3gDuAb4YOYOZjYd2OXubmZL6Dvzf3O0iw2b0772aP9y5vj4oy/v4tGXdw3Z/4+3XTIudYlINA075OLuCeBmYCWwDvipu681sxvN7Mb0bn8CvJQeQ78DuMaDmuAeoCdfbSWe6Hu7zUlfyf1moJryIv7zL86jqrRorEoTkTyQ073e7r4CWDGobVnG8p3AnaNbWni4O/O/dPCX55FPvZN96deW3f9nS1i9ZQ8tHXEeXLWNG85dwJcvO5FkyjEYMh1RRGSkdKfoKBg83TCTxsRFZDQd7sYiPcTjKN22fMjszH4bvn7pOFYiIvlOgX4U9uzr4R+e2gLAX1x8HFu+dTkfenvf/PopFcUUF+qXV0TGj56XehQeW3dwpsqnLux7u883rj6Zb1x9clAliUge0ynkCLV39/L5f3sRgLVffU/A1YiIKNBH7JSMx9rqxRAiMhEoiY5QKuX8dcaF0L9+36IAqxEROUiBfgSefLWVT/3kedr29/a3XX/O/AArEhE5SIGeo3giyYd/9OyAtnVfWxpQNSIiQynQc3Tvk5sHrG/+m8sw012eIjJx6KJoDvb3JPjuyvX963d+8HSFuYhMODpDz8F7vv9E/7Ju5ReRiUpn6MPoSaTYtqcLgFsvPzHgakREDk2BPoynNx18rPsn3rUgwEpERA5PgX4YLR3dfPS+vpktD95wVsDViIgcngJ9kDXb2vibX60DYMk3HutvP2vBlKBKEhHJiS6KDnLVXb8H4Ie/3dTf9tvPnx9QNSIiudMZeoar7/79kLbPXnwcc6dUBFCNiMiRUaCnfefXr/D81rYh7Z9OPxZXRGSiU6Cn3f34a/3LC+r6zsjvuPb0oMoRETliCnT6XvJ8wG8/fz4XLZoGHAx2EZEw0EVRoL0rAfTdODR3SgWfu+R43n38VN42a1LAlYmI5E5n6MCa5jYAqsuKACguLODsYzRNUUTCRYEO/TcPVZfqPywiEl55H+jJ1MHx8/ecND3ASkREjk7eB/q6He0A/OkZDXokroiEWt4H+mf+ZQ0Ak8uLgi1EROQo5RToZrbUzNab2UYzu+Uw+51pZkkz+5PRK3FsbWzpBODmC3QDkYiE27CBbmYx4C7gUmARcK2ZDXnVfXq/bwMrR7vIsfLoy7v6lyeV6QxdRMItlzP0JcBGd9/k7j3Ag8CVWfb7FPAzoGUU6xszm3fv478/0BR0GSIioyaXQJ8FbMtYb0639TOzWcDVwLLDfZCZ3WBmTWbW1NraeqS1jqqfrW7uX374pncEWImIyOjIJdCzTf3wQevfB77o7snDfZC73+Puje7eWF9fn2OJY2POlHIAvrj0BE6fUxNoLSIioyGXO2magdkZ6w3A9kH7NAIPpqf91QGXmVnC3X8+GkWOhT37egD4yNlzA65ERGR05BLoq4CFZjYfeAO4Bvhg5g7uPv/Aspn9A/DIRA5zgDc745QUFlBeHAu6FBGRUTFsoLt7wsxupm/2Sgy4z93XmtmN6e2HHTefiFo74vz9k5sBdDORiERGTg8vcfcVwIpBbVmD3N0/dvRlja0r7vxd0CWIiIy6vLxTdMfebgBOadDjcUUkOvIu0DMfxnXmvNoAKxERGV15F+jLX3ijf/nq02cdZk8RkXDJmweAJ1PObze00NWTAuBn/+MdeiORiERK3gT6MV8ecE2Xuekbi0REoiIvhly6e4fewFpbXhxAJSIiYycvAv2lN/YOaSso0PxzEYmWvBhy+caKdUDfO0Pv/OBiZtWUBVyRiMjoy4tAf35rGwA//HAjZx8zJdhiRETGSF4MubxrYR0Ab5+veeciEl15Eeg15cXMn
VKucXMRibTID7nEE0mWvzD4ab8iItET+TP0JzbsDroEEZFxEflAv/2Rl4MuQURkXEQ+0Lfu2Q/ALz/9zoArEREZW5EP9AMWzagOugQRkTGVF4F+/vH1ejORiERepAO9pb3vRRaPr28NuBIRkbEX6UBfuXYnAH9+4cKAKxERGXuRDfTVr7/FX/1iLQAnavxcRPJAZAN9b1dP/3JRTOPnIhJ9kQ1042CIn3dcfYCViIiMj8gGemc80b9cGItsN0VE+kU26TIDXUQkH0Q20Ndu73tL0R9vuyTgSkRExkfknrYYTyS58G9/S/NbXQBUFEeuiyIiWeV0hm5mS81svZltNLNbsmy/0sxeNLM1ZtZkZoE9OGX9zo7+MAe9O1RE8sewp69mFgPuAi4GmoFVZrbc3TMfY/gYsNzd3cxOAX4KnDAWBQ/nsXUtQRxWRCRwuZyhLwE2uvsmd+8BHgSuzNzB3Tvd3dOrFYATkB889mpQhxYRCVQugT4L2Jax3pxuG8DMrjazV4BfAn+W7YPM7Ib0kExTa+vYP1+l6daLxvwYIiITRS6Bnm0QesgZuLs/7O4nAFcBt2f7IHe/x90b3b2xvn5sb/Z59ssXUldZMqbHEBGZSHIJ9GZgdsZ6A3DIl3S6+xPAMWZWd5S1HbFkqu/fmXlTyplaXTrehxcRCVQugb4KWGhm882sGLgGWJ65g5kda+kHjpvZYqAYeHO0ix3O5t2dAGx5c/94H1pEJHDDznJx94SZ3QysBGLAfe6+1sxuTG9fBrwf+IiZ9QJdwAcyLpKOG73EQkTyWU533bj7CmDFoLZlGcvfBr49uqUdue7eZNAliIgEJlK3/h8I9GXXnRFwJSIi4y9Sgd7VkwKgtqI44EpERMZfpAL9heY2AMqKYsEWIiISgEgF+ndXrgcgpue3iEgeilSgH6BAF5F8FKlAv+LUmQAcP70q4EpERMZfpAK9uLCAWZPLgi5DRCQQkQr0eCJFSWGkuiQikrNIpV+8N0mxAl1E8lSk0i+eSFGiKYsikqciFuhJDbmISN6KTPolU85b+3oV6CKStyKTfrc/8jLrd3VoDrqI5K3IBPoDT28BIN6bCrYQEZGA5PT43Inosz9dw69f2sm5C+v5zMULSb+siPW7OoItTEQkIKEN9IeeewOAX6/dyeMbWvrb9+zrCaokEZFARWLIpVvDLCIi0Qj0TJe+bXrQJYiIBCJygf5/PnBa0CWIiAQicoFeqjtFRSRPRS7QRUTyVaQC/TvvPyXoEkREAhOpQC/QXaIiksciFehdPYmgSxARCUwoA93ds7aftWDKOFciIjJxhDTQh7Z9YenxLJymd4mKSP7KKdDNbKmZrTezjWZ2S5btHzKzF9NfT5nZqaNf6kGD8/zrV72Nm84/diwPKSIy4Q0b6GYWA+4CLgUWAdea2aJBu20GznP3U4DbgXtGu9BMg4dcrjtr7lgeTkQkFHI5Q18CbHT3Te7eAzwIXJm5g7s/5e5vpVefARpGt8yB9ERFEZGhcgn0WcC2jPXmdNuhfBz4VbYNZnaDmTWZWVNra2vuVQ5y+R2/619ePGfyiD9HRCRKcgn0bJO7s04zMbN30xfoX8y23d3vcfdGd2+sr6/PvcrDeOimc0blc0REwi6X56E3A7Mz1huA7YN3MrNTgHuBS939zdEpT0REcpXLGfoqYKGZzTezYuAaYHnmDmY2B3gI+LC7bxj9MkVEZDjDnqG7e8LMbgZWAjHgPndfa2Y3prcvA74CTAHuNjOAhLs3jl3ZIiIyWE6voHP3FcCKQW3LMpY/AXxidEsTEZEjEco7RUVEZKhQB/oPrjkt6BJERCaMUAd6bUVx0CWIiEwYoQ50yzpFXkQkP4U60PU+CxGRg0Id6DpBFxE5KNSBXmBKdBGRAxToIiIREepAV56LiBwU6kDXRVERkYNCHei6KioiclCoA11n6CIiB4U80JXoIiIHhDrQleciIgeFO9A1hi4i0i/UgS4iIgeFOtCTn
vVd1SIieSnUgZ5SoIuI9At3oKcU6CIiB4Q60JMKdBGRfqEOdOW5iMhBIQ90JbqIyAEKdBGRiAh1oM+bUhF0CSIiE0aoA312bXnQJYiITBihDnQRETkop0A3s6Vmtt7MNprZLVm2n2BmT5tZ3Mw+N/plDvW+U2eOx2FEREKjcLgdzCwG3AVcDDQDq8xsubu/nLHbHuDTwFVjUeRgsQJjTm3ZeBxKRCQ0cjlDXwJsdPdN7t4DPAhcmbmDu7e4+yqgdwxqzEpPWhQRGSiXQJ8FbMtYb063iYjIBJJLoGc7FR7RBHAzu8HMmsysqbW1dSQf0XdwzT8XERkil0BvBmZnrDcA20dyMHe/x90b3b2xvr5+JB/RT28rEhEZKJdAXwUsNLP5ZlYMXAMsH9uyRETkSA07y8XdE2Z2M7ASiAH3uftaM7sxvX2ZmU0HmoBqIGVmnwEWuXv7WBStARcRkaGGDXQAd18BrBjUtixjeSd9QzHjRiMuIiID6U5REZGICGWga5KLiMhQoQx0QNNcREQGCW+gi4jIAAp0EZGICG2ga8BFRGSg0Aa6iIgMFLpA13NcRESyC12gH6BJLiIiA4U20EVEZCAFuohIRIQu0DWELiKSXegC/QC9gk5EZKDQBrqIiAwUukDXiIuISHahC/QDNG1RRGSg0Aa6iIgMFLpA152iIiLZhS7QD9CIi4jIQKENdBERGSh0ga4BFxGR7EIX6AdolouIyEChDXQRERkodIGuSS4iItmFLtAPMI25iIgMENpAFxGRgUIX6K55LiIiWeUU6Ga21MzWm9lGM7sly3YzszvS2180s8WjX6qIiBzOsIFuZjHgLuBSYBFwrZktGrTbpcDC9NcNwN+Ncp0iIjKMXM7QlwAb3X2Tu/cADwJXDtrnSuAB7/MMMNnMZoxyrQA8sWH3WHysiEjo5RLos4BtGevN6bYj3Qczu8HMmsysqbW19UhrBaC2opj3nTqTixdNG9HPi4hEVWEO+2SbHzj4ymQu++Du9wD3ADQ2No7o6uYZc2s4Y27NSH5URCTScjlDbwZmZ6w3ANtHsI+IiIyhXAJ9FbDQzOabWTFwDbB80D7LgY+kZ7ucBex19x2jXKuIiBzGsEMu7p4ws5uBlUAMuM/d15rZjenty4AVwGXARmA/cP3YlSwiItnkMoaOu6+gL7Qz25ZlLDvwP0e3NBERORKhu1NURESyU6CLiESEAl1EJCIU6CIiEWEe0BsjzKwVeH2EP14H5NszANTn/KA+54ej6fNcd6/PtiGwQD8aZtbk7o1B1zGe1Of8oD7nh7Hqs4ZcREQiQoEuIhIRYQ30e4IuIADqc35Qn/PDmPQ5lGPoIiIyVFjP0EVEZBAFuohIRIQu0Id7YXWYmNl9ZtZiZi9ltNWa2aNm9mr6e03Gti+l+73ezN6T0X6Gmf0xve0OM8v2wpHAmdlsM/uNma0zs7Vm9ufp9ij3udTMnjWzF9J9/mq6PbJ9PsDMYmb2vJk9kl6PdJ/NbEu61jVm1pRuG98+u3tovuh7fO9rwAKgGHgBWBR0XUfRn3OBxcBLGW3fAW5JL98CfDu9vCjd3xJgfvrXIZbe9ixwNn1vjvoVcGnQfTtEf2cAi9PLVcCGdL+i3GcDKtPLRcAfgLOi3OeMvn8W+Gfgkaj/2U7XugWoG9Q2rn0O2xl6Li+sDg13fwLYM6j5SuD+9PL9wFUZ7Q+6e9zdN9P37Pkl6ZdxV7v70973p+GBjJ+ZUNx9h7s/l17uANbR9+7ZKPfZ3b0zvVqU/nIi3GcAM2sALgfuzWiOdJ8PYVz7HLZAz+ll1CE3zdNve0p/n5puP1TfZ6WXB7dPaGY2DzidvjPWSPc5PfSwBmgBHnX3yPcZ+D7wBSCV0Rb1Pjvw72a22sxuSLeNa59zesHFBJLTy6gj6lB9D92viZlVAj8DPuPu7YcZIoxEn909CZxmZpOBh83sbYfZPfR9NrP3Ai3uvtrMzs/lR7K0harPaee4+3Yzmwo8a
mavHGbfMelz2M7Q8+Fl1LvS/+0i/b0l3X6ovjenlwe3T0hmVkRfmP+Tuz+Ubo50nw9w9zbgcWAp0e7zOcAVZraFvmHRC8zs/xHtPuPu29PfW4CH6RsiHtc+hy3Qc3lhddgtBz6aXv4o8IuM9mvMrMTM5gMLgWfT/43rMLOz0lfDP5LxMxNKur4fAevc/XsZm6Lc5/r0mTlmVgZcBLxChPvs7l9y9wZ3n0ff39H/dPfriHCfzazCzKoOLAOXAC8x3n0O+srwCK4kX0bf7IjXgL8Mup6j7MtPgB1AL33/Mn8cmAI8Brya/l6bsf9fpvu9nowr30Bj+g/Pa8CdpO8AnmhfwDvp++/ji8Ca9NdlEe/zKcDz6T6/BHwl3R7ZPg/q//kcnOUS2T7TN/PuhfTX2gPZNN591q3/IiIREbYhFxEROQQFuohIRCjQRUQiQoEuIhIRCnQRkYhQoIuIRIQCXUQkIv4LtWqA4NNMHe8AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] },
{ "cell_type": "code", "metadata": { "id": "M6WRP_0P1Wfp" }, "source": [ "# Keep the phase-1 results before the reward table is changed\n", "tempdf1 = xx.copy()" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "ukACuqoc1Wfq", "outputId": "f6941465-fc4c-40ff-8a7c-84f8e104366e" }, "source": [ "# Phase 2: each user rewards exactly one item, whatever context1/context2 are\n", "context = pd.DataFrame(list(product(users, context1, context2, items)), columns=['users', 'context1', 'context2', 'items'])\n", "context['reward'] = 0\n", "X = context.copy()\n", "X.loc[(X['users']=='usera')&(X['items']=='item1'),'reward']=1\n", "X.loc[(X['users']=='userb')&(X['items']=='item2'),'reward']=1\n", "X.loc[(X['users']=='userc')&(X['items']=='item3'),'reward']=1\n", "X.reward.value_counts()" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0 60\n", "1 12\n", "Name: reward, dtype: int64" ] }, "metadata": { "tags": [] }, "execution_count": 145 } ] },
{ "cell_type": "code", "metadata": { "id": "NwVd7J4C1Wfq" }, "source": [ "vws.update_context(X)" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "nqs6J8sy1Wfr" }, "source": [ "_temp = []\n", "for _ in range(5000):\n", "    _temp.append(vws.step())" ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "JV3H9J4N1Wfr", "outputId": "5f675cad-76e7-42eb-b769-bd2b896cca7f" }, "source": [ "x = pd.DataFrame.from_records(_temp, columns=['user','context1','context2','item','cost','prob'])\n", "xx = x.copy()\n", "xx['ccost'] = xx['cost'].cumsum()\n", "xx = xx.fillna(0)\n", "xx = xx.rename_axis('iter').reset_index()\n", "# iter is 0-based, so row i summarises iter + 1 rounds (dividing by iter gives 0/0 or inf at row 0)\n", "xx['ctr'] = -1*xx['ccost']/(xx['iter'] + 1)\n", "xx.sample(10)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
iterusercontext1context2itemcostprobccostctr
354354usercctx12ctx21item200.833333-880.248588
33623362userbctx11ctx22item2-10.833333-24810.737954
485485useractx11ctx22item1-10.833333-1710.352577
33833383userbctx11ctx22item300.033333-24990.738693
28032803usercctx12ctx21item3-10.833333-20020.714235
34103410userbctx12ctx21item300.033333-25210.739296
267267useractx12ctx21item400.033333-540.202247
38483848userbctx12ctx21item2-10.833333-28880.750520
5454userbctx12ctx21item600.833333-130.240741
14471447usercctx11ctx21item500.033333-8910.615757
\n", "
" ], "text/plain": [ " iter user context1 context2 item cost prob ccost ctr\n", "354 354 userc ctx12 ctx21 item2 0 0.833333 -88 0.248588\n", "3362 3362 userb ctx11 ctx22 item2 -1 0.833333 -2481 0.737954\n", "485 485 usera ctx11 ctx22 item1 -1 0.833333 -171 0.352577\n", "3383 3383 userb ctx11 ctx22 item3 0 0.033333 -2499 0.738693\n", "2803 2803 userc ctx12 ctx21 item3 -1 0.833333 -2002 0.714235\n", "3410 3410 userb ctx12 ctx21 item3 0 0.033333 -2521 0.739296\n", "267 267 usera ctx12 ctx21 item4 0 0.033333 -54 0.202247\n", "3848 3848 userb ctx12 ctx21 item2 -1 0.833333 -2888 0.750520\n", "54 54 userb ctx12 ctx21 item6 0 0.833333 -13 0.240741\n", "1447 1447 userc ctx11 ctx21 item5 0 0.033333 -891 0.615757" ] }, "metadata": { "tags": [] }, "execution_count": 148 } ] },
{ "cell_type": "code", "metadata": { "id": "6-2NfwBF1Wfs", "outputId": "5b89213a-e4f0-48a2-af31-cf79277a650a" }, "source": [ "# DataFrame.append was removed in pandas 2.0; pd.concat is the supported equivalent.\n", "# ccost/ctr restart at the phase boundary because each phase was summarised separately.\n", "tempdf2 = pd.concat([tempdf1, xx], ignore_index=True)\n", "tempdf2.sample(10)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
iterusercontext1context2itemcostprobccostctr
88373837useractx12ctx22item1-10.833333-28790.750326
82103210usercctx11ctx21item200.033333-23480.731464
32803280userbctx11ctx21item100.033333-18220.555488
10291029usercctx12ctx21item100.033333-5310.516035
41834183usercctx12ctx22item200.033333-23280.556538
27152715usercctx11ctx22item4-10.833333-14670.540331
63001300useractx12ctx21item1-10.833333-7760.596923
71112111useractx12ctx21item1-10.833333-14270.675983
40084008useractx11ctx22item5-10.833333-22390.558633
22362236usercctx12ctx21item2-10.833333-12030.538014
\n", "
" ], "text/plain": [ " iter user context1 context2 item cost prob ccost ctr\n", "8837 3837 usera ctx12 ctx22 item1 -1 0.833333 -2879 0.750326\n", "8210 3210 userc ctx11 ctx21 item2 0 0.033333 -2348 0.731464\n", "3280 3280 userb ctx11 ctx21 item1 0 0.033333 -1822 0.555488\n", "1029 1029 userc ctx12 ctx21 item1 0 0.033333 -531 0.516035\n", "4183 4183 userc ctx12 ctx22 item2 0 0.033333 -2328 0.556538\n", "2715 2715 userc ctx11 ctx22 item4 -1 0.833333 -1467 0.540331\n", "6300 1300 usera ctx12 ctx21 item1 -1 0.833333 -776 0.596923\n", "7111 2111 usera ctx12 ctx21 item1 -1 0.833333 -1427 0.675983\n", "4008 4008 usera ctx11 ctx22 item5 -1 0.833333 -2239 0.558633\n", "2236 2236 userc ctx12 ctx21 item2 -1 0.833333 -1203 0.538014" ] }, "metadata": { "tags": [] }, "execution_count": 149 } ] },
{ "cell_type": "code", "metadata": { "id": "28SHrzVq1Wfs", "outputId": "16959443-416f-4fbe-f0e7-e331e324fc93" }, "source": [ "# Dashed line marks where the reward table was swapped; ccost restarts there because each phase was summarised separately\n", "ax = tempdf2['ccost'].plot(title='Cumulative cost per phase')\n", "ax.axvline(len(tempdf1), color='red', linestyle='--', label='reward table changed')\n", "ax.set(xlabel='step', ylabel='cumulative cost')\n", "ax.legend();" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 150 }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] },
{ "cell_type": "code", "metadata": { "id": "2Nq8ZhSr1Wft", "outputId": "a9f12adf-13da-416b-e51c-dc66bdac8c59" }, "source": [ "# Running CTR of each phase; the dashed line marks the switch to the new reward table\n", "ax = tempdf2['ctr'].plot(title='Running CTR per phase')\n", "ax.axvline(len(tempdf1), color='red', linestyle='--', label='reward table changed')\n", "ax.set(xlabel='step', ylabel='CTR')\n", "ax.legend();" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 151 }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] } ] }