{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-23-recsim.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T219174%20%7C%20Recsim%20Catalyst.ipynb","timestamp":1644663600909}],"collapsed_sections":[],"authorship_tag":"ABX9TyOvUSUvC4gSatWU0WTCUGZo"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"oeyCsKy3R112"},"source":["# Recsim Catalyst\n","\n","> We will create a recommender bot by neural networks, and use RL methods to train it."]},{"cell_type":"markdown","metadata":{"id":"XK8c_5VaiGKl"},"source":["## Abstract\n","\n","*We propose RecSim, a configurable platform for authoring simulation environments for recommender systems (RSs) that naturally supports sequential interaction with users. RecSim allows the creation of new environments that reflect particular aspects of user behavior and item structure at a level of abstraction well-suited to pushing the limits of current reinforcement learning (RL) and RS techniques in sequential interactive recommendation problems. Environments can be easily configured that vary assumptions about: user preferences and item familiarity; user latent state and its dynamics; and choice models and other user response behavior. We outline how RecSim offers value to RL and RS researchers and practitioners, and how it can serve as a vehicle for academic-industrial collaboration.*\n","\n","https://arxiv.org/abs/1909.04847\n","\n","[GitHub](https://github.com/google-research/recsim), [Video](https://youtu.be/T6ZLpi65Bsc), [Medium](https://medium.com/dataseries/googles-recsim-is-an-open-source-simulation-framework-for-recommender-systems-9a802377acc2)"]},{"cell_type":"markdown","metadata":{"id":"TTmaDZWqiZCu"},"source":["RecSim is a configurable platform for authoring simulation environments for recommender systems (RSs) that naturally supports sequential interaction with users. RecSim allows the creation of new environments that reflect particular aspects of user behavior and item structure at a level of abstraction well-suited to pushing the limits of current reinforcement learning (RL) and RS techniques in sequential interactive recommendation problems. Environments can be easily configured that vary assumptions about: user preferences and item familiarity; user latent state and its dynamics; and choice models and other user response behavior. We outline how RecSim offers value to RL and RS researchers and practitioners, and how it can serve as a vehicle for academic-industrial collaboration. For a detailed description of the RecSim architecture please read Ie et al. Please cite the paper if you use the code from this repository in your work.\n","\n","RecSim simulates a recommender agent’s interaction with an environment where the agent interacts by doing some recommendations to users. Both the user and the subject of recommendations are simulated. The simulations are done based on popularity, interests, demographics, frequency and other traits. When an RL agent recommends something to a user, then depending on the user’s acceptance, few traits are scored high. This still sounds like a typical recommendation system. However, with RecSim, a developer can author these traits. The features in a user choice model can be made more customised as the agent gets rewarded for making the right recommendation.\n","\n","![](https://github.com/recohut/nbs/blob/main/raw/_images/T219174_1.png?raw=1)\n","\n","*Green and blue boxes show the environment. We need to implement special classes, User and Document. Our bot(\"Agent\") have to choose from several documents the most relevant for the user. The user can move to the offered document if he accepts it, to random document overwise or stay on the current document.*\n","\n","Green and blue boxes show the environment. We need to implement special classes, User and Document. Our bot(\"Agent\") have to choose from several documents the most relevant for the user. The user can move to the offered document if he accepts it, to random document overwise or stay on the current document.\n","\n","Recsim is a configurable simulation platform for recommender systems make by Google, which utilized the document and user database directly. We can break Recsim into two parts,\n","\n","- The environment consists of a user model, a document (item) model and a user-choice model. The user model samples users from a prior distribution of observable and latent user features; the document model samples items from a prior over observable and latent document features; and the user-choice model determines the user’s response, which is dependent on observable document features, observable and latent user features.\n","- The SlateQ Simulation Environment, which uses the SlateQ Algorithm to return a slate of items back to the simulation environment.\n","\n","Unlike virtual Taobao, Recsim has a concrete representation of items, and the actions returned by the reinforcement learning agent can be directly associated with items. However, the user model and item model of Recsim are too simple, and without sufficient data support, the prior probability distribution for generating simulated users and virtual items is difficult to be accurate."]},{"cell_type":"markdown","metadata":{"id":"KpLT9UqfUsup"},"source":["## Setup"]},{"cell_type":"code","metadata":{"id":"02lAdPChRvMu"},"source":["!pip install -Uq catalyst gym recsim"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wW4ZjWQGRytT"},"source":["from collections import deque, namedtuple\n","import random\n","import numpy as np\n","import gym\n","\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","from torch.utils.data import DataLoader\n","\n","from catalyst import dl, utils\n","\n","from gym import spaces\n","\n","from recsim import document, user\n","from recsim.choice_model import AbstractChoiceModel\n","from recsim.simulator import recsim_gym, environment"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"ZnhIeHteUw9K"},"source":["## Params"]},{"cell_type":"code","metadata":{"id":"C_7AvLRvS08O"},"source":["device = utils.get_device()\n","utils.set_global_seed(42)\n","\n","DOC_NUM = 10\n","EMB_SIZE = 4\n","P_EXIT_ACCEPTED = 0.1\n","P_EXIT_NOT_ACCEPTED = 0.2\n","\n","# let's define a matrix W for simulation of users' respose\n","# (based on the section 7.3 of the paper https://arxiv.org/pdf/1512.07679.pdf)\n","# W_ij defines the probability that a user will accept recommendation j\n","# given that he is consuming item i at the moment\n","\n","W = (np.ones((DOC_NUM, DOC_NUM)) - np.eye(DOC_NUM)) * \\\n"," np.random.uniform(0.0, P_EXIT_NOT_ACCEPTED, (DOC_NUM, DOC_NUM)) + \\\n"," np.diag(np.random.uniform(1.0 - P_EXIT_ACCEPTED, 1.0, DOC_NUM))\n","W = W[:, np.random.permutation(DOC_NUM)]"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"l9MY2EHQUyvV"},"source":["## Document Model"]},{"cell_type":"code","metadata":{"id":"m51NN5sFS4Yb"},"source":["class Document(document.AbstractDocument):\n","\n"," def __init__(self, doc_id):\n"," super().__init__(doc_id)\n","\n"," def create_observation(self):\n"," return (self._doc_id,)\n","\n"," @staticmethod\n"," def observation_space():\n"," return spaces.Discrete(DOC_NUM)\n","\n"," def __str__(self):\n"," return \"Document #{}\".format(self._doc_id)\n","\n","\n","class DocumentSampler(document.AbstractDocumentSampler):\n","\n"," def __init__(self, doc_ctor=Document):\n"," super().__init__(doc_ctor)\n"," self._doc_count = 0\n","\n"," def sample_document(self):\n"," doc = self._doc_ctor(self._doc_count % DOC_NUM)\n"," self._doc_count += 1\n"," return doc"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"AonnK8VNU2Zk"},"source":["## User Model"]},{"cell_type":"code","metadata":{"id":"w9Le47H-TNa1"},"source":["class UserState(user.AbstractUserState):\n"," def __init__(self, user_id, current, active_session=True):\n"," self.user_id = user_id\n"," self.current = current\n"," self.active_session = active_session\n","\n"," def create_observation(self):\n"," return (self.current,)\n","\n"," def __str__(self):\n"," return \"User #{}\".format(self.user_id)\n","\n"," @staticmethod\n"," def observation_space():\n"," return spaces.Discrete(DOC_NUM)\n","\n"," def score_document(self, doc_obs):\n"," return W[self.current, doc_obs[0]]\n","\n","\n","class StaticUserSampler(user.AbstractUserSampler):\n"," def __init__(self, user_ctor=UserState):\n"," super().__init__(user_ctor)\n"," self.user_count = 0\n","\n"," def sample_user(self):\n"," self.user_count += 1\n"," sampled_user = self._user_ctor(\n"," self.user_count, np.random.randint(DOC_NUM))\n"," return sampled_user\n","\n","\n","class Response(user.AbstractResponse):\n"," def __init__(self, accept=False):\n"," self.accept = accept\n","\n"," def create_observation(self):\n"," return (int(self.accept),)\n","\n"," @classmethod\n"," def response_space(cls):\n"," return spaces.Discrete(2)\n","\n","\n","class UserChoiceModel(AbstractChoiceModel):\n"," def __init__(self):\n"," super().__init__()\n"," self._score_no_click = P_EXIT_ACCEPTED\n","\n"," def score_documents(self, user_state, doc_obs):\n"," if len(doc_obs) != 1:\n"," raise ValueError(\n"," \"Expecting single document, but got: {}\".format(doc_obs))\n"," self._scores = np.array(\n"," [user_state.score_document(doc) for doc in doc_obs])\n","\n"," def choose_item(self):\n"," if np.random.random() < self.scores[0]:\n"," return 0\n","\n","\n","class UserModel(user.AbstractUserModel):\n"," def __init__(self):\n"," super().__init__(Response, StaticUserSampler(), 1)\n"," self.choice_model = UserChoiceModel()\n","\n"," def simulate_response(self, slate_documents):\n"," if len(slate_documents) != 1:\n"," raise ValueError(\"Expecting single document, but got: {}\".format(\n"," slate_documents))\n","\n"," responses = [self._response_model_ctor() for _ in slate_documents]\n","\n"," self.choice_model.score_documents(\n"," self._user_state,\n"," [doc.create_observation() for doc in slate_documents]\n"," )\n"," selected_index = self.choice_model.choose_item()\n","\n"," if selected_index is not None:\n"," responses[selected_index].accept = True\n","\n"," return responses\n","\n"," def update_state(self, slate_documents, responses):\n"," if len(slate_documents) != 1:\n"," raise ValueError(\n"," f\"Expecting single document, but got: {slate_documents}\"\n"," )\n","\n"," response = responses[0]\n"," doc = slate_documents[0]\n"," if response.accept:\n"," self._user_state.current = doc.doc_id()\n"," self._user_state.active_session = bool(\n"," np.random.binomial(1, 1 - P_EXIT_ACCEPTED))\n"," else:\n"," self._user_state.current = np.random.choice(DOC_NUM)\n"," self._user_state.active_session = bool(\n"," np.random.binomial(1, 1 - P_EXIT_NOT_ACCEPTED))\n","\n"," def is_terminal(self):\n"," \"\"\"Returns a boolean indicating if the session is over.\"\"\"\n"," return not self._user_state.active_session\n","\n","\n","def clicked_reward(responses):\n"," reward = 0.0\n"," for response in responses:\n"," if response.accept:\n"," reward += 1\n"," return reward"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zncpiikvUogu"},"source":["## RecSim Environment"]},{"cell_type":"code","metadata":{"id":"5EN5ZP-MTR_a"},"source":["def make_env():\n"," env = recsim_gym.RecSimGymEnv(\n"," environment.Environment(\n"," UserModel(), \n"," DocumentSampler(), \n"," DOC_NUM, \n"," 1, \n"," resample_documents=False\n"," ),\n"," clicked_reward\n"," )\n"," return env"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"OuX-2DMQUi8D"},"source":["## Actor-Critic Policy"]},{"cell_type":"markdown","metadata":{"id":"bEV0WC7bUKoy"},"source":["The actor is a simple NN, that generate embedding action vector based on current state. The critic model is more complicated. In our implementation, we need action embeddings. Our actions is a picking a document. So, we just need a embedding vector for each document. They can be trained as well as a critic model. And we have to implement choosing process by choosing top-k variants and calculate q-value on them."]},{"cell_type":"code","metadata":{"id":"9R-YJVeaTY3-"},"source":["from catalyst.contrib.nn import Normalize\n","\n","\n","inner_fn = utils.get_optimal_inner_init(nn.ReLU)\n","outer_fn = utils.outer_init\n","\n","\n","class ActorModel(nn.Module):\n"," def __init__(self, hidden=64, doc_num=10, doc_emb_size=4):\n"," super().__init__()\n"," \n"," self.actor = nn.Sequential(\n"," nn.Linear(doc_num, hidden),\n"," nn.ReLU(),\n"," nn.Linear(hidden, hidden),\n"," nn.ReLU(),\n"," )\n"," self.head = nn.Sequential(\n"," nn.Linear(hidden, doc_emb_size),\n"," Normalize()\n"," )\n"," \n"," self.actor.apply(inner_fn)\n"," self.head.apply(outer_fn)\n"," \n"," self.doc_num = doc_num\n"," self.doc_emb_size = doc_emb_size\n"," \n"," def forward(self, states):\n"," return self.head(self.actor(states))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"o3Sj7TEDUasp"},"source":["class CriticModel(nn.Module):\n"," def __init__(self, hidden=64, doc_num=10, doc_emb_size=4):\n"," super().__init__()\n"," \n"," self.critic = nn.Sequential(\n"," nn.Linear(doc_num + doc_emb_size, hidden),\n"," nn.ReLU(),\n"," nn.Linear(hidden, hidden),\n"," nn.ReLU(),\n"," )\n"," \n"," self.head = nn.Linear(hidden, 1)\n"," \n"," self.critic.apply(inner_fn)\n"," self.head.apply(outer_fn)\n"," \n"," self.doc_embs = nn.Sequential(\n"," nn.Embedding(doc_num, doc_emb_size),\n"," Normalize()\n"," )\n"," \n"," self.doc_num = doc_num\n"," self.doc_emb_size = doc_emb_size\n"," \n"," def _generate_input(self, states, proto_actions):\n"," return torch.cat([states, proto_actions], 1)\n"," \n"," def forward(self, states, proto_actions):\n"," inputs = self._generate_input(states, proto_actions)\n"," return self.head(self.critic(inputs))\n"," \n"," def get_topk(self, states, proto_actions, top_k=1):\n"," # Instead of kNN algorithm we can calculate distance across all of the objects.\n"," dist = torch.cdist(proto_actions, self.doc_embs[0].weight)\n"," indexes = torch.topk(dist, k=top_k, largest=False)[1]\n"," return torch.cat([self.doc_embs(index).unsqueeze(0) for index in indexes]), indexes\n"," \n"," def get_best(self, states, proto_actions, top_k=1):\n"," doc_embs, indexes = self.get_topk(states, proto_actions, top_k)\n"," top_k = doc_embs.size(1)\n"," best_values = torch.empty(states.size(0)).to(states.device)\n"," best_indexes = torch.empty(states.size(0)).to(states.device)\n"," for num, (state, actions, idx) in enumerate(zip(states, doc_embs, indexes)):\n"," new_states = state.repeat(top_k, 1)\n"," # for each pair of state and action we use critic to calculate values\n"," values = self(new_states, actions)\n"," best = values.max(0)[1].item()\n"," best_values[num] = values[best]\n"," best_indexes[num] = idx[best]\n"," return best_indexes, best_values"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"zBpfW5cnUc6i"},"source":["## Training"]},{"cell_type":"code","metadata":{"id":"UVfQwQHaU6Bj"},"source":["import numpy as np\n","from collections import deque, namedtuple\n","\n","Transition = namedtuple(\n"," 'Transition', \n"," field_names=[\n"," 'state', \n"," 'action', \n"," 'reward',\n"," 'done', \n"," 'next_state'\n"," ]\n",")\n","\n","class ReplayBuffer:\n"," def __init__(self, capacity: int):\n"," self.buffer = deque(maxlen=capacity)\n"," \n"," def append(self, transition: Transition):\n"," self.buffer.append(transition)\n"," \n"," def sample(self, batch_size: int):\n"," indices = np.random.choice(\n"," len(self.buffer), \n"," batch_size, \n"," replace=batch_size > len(self.buffer)\n"," )\n"," states, actions, rewards, dones, next_states = \\\n"," zip(*[self.buffer[idx] for idx in indices])\n"," return (\n"," np.array(states, dtype=np.float32), \n"," np.array(actions, dtype=np.int64), \n"," np.array(rewards, dtype=np.float32),\n"," np.array(dones, dtype=np.bool), \n"," np.array(next_states, dtype=np.float32)\n"," )\n"," \n"," def __len__(self):\n"," return len(self.buffer)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"HYp1A4U8U8jN"},"source":["from torch.utils.data.dataset import IterableDataset\n","\n","\n","class ReplayDataset(IterableDataset):\n","\n"," def __init__(self, buffer: ReplayBuffer, epoch_size: int = int(1e3)):\n"," self.buffer = buffer\n"," self.epoch_size = epoch_size\n","\n"," def __iter__(self):\n"," states, actions, rewards, dones, next_states = \\\n"," self.buffer.sample(self.epoch_size)\n"," for i in range(len(dones)):\n"," yield states[i], actions[i], rewards[i], dones[i], next_states[i]\n"," \n"," def __len__(self):\n"," return self.epoch_size"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ex7Fop1ZU_LX"},"source":["def extract_state(env, state):\n"," user_space = env.observation_space.spaces[\"user\"]\n"," return spaces.flatten(user_space, state[\"user\"])\n","\n","def get_action(env, actor, critic, state, top_k=10, epsilon=None):\n"," # Our framework is created by PG process and it must be trained with \n"," # a noise added to the actor's output.\n"," # But in our framework it's better to sample action from the enviroment.\n"," state = torch.tensor(state, dtype=torch.float32).to(device).unsqueeze(0)\n"," if epsilon is None or random.random() < epsilon:\n"," proto_action = actor(state)\n"," action = critic.get_best(state, proto_action, top_k)[0]\n"," action = action.detach().cpu().numpy().astype(int)\n"," else:\n"," action = env.action_space.sample()\n"," return action\n","\n","\n","def generate_session(\n"," env, \n"," actor,\n"," critic,\n"," replay_buffer=None,\n"," epsilon=None,\n"," top_k=10\n","):\n"," total_reward = 0\n"," s = env.reset()\n"," s = extract_state(env, s)\n","\n"," for t in range(1000):\n"," a = get_action(env, actor, critic, epsilon=epsilon, state=s, top_k=top_k)\n"," next_s, r, done, _ = env.step(a)\n"," next_s = extract_state(env, next_s)\n","\n"," if replay_buffer is not None:\n"," transition = Transition(s, a, r, done, next_s)\n"," replay_buffer.append(transition)\n","\n"," total_reward += r\n"," s = next_s\n"," if done:\n"," break\n","\n"," return total_reward\n","\n","def generate_sessions(\n"," env, \n"," actor,\n"," critic,\n"," replay_buffer=None,\n"," num_sessions=100,\n"," epsilon=None,\n"," top_k=10\n","):\n"," sessions_reward = 0\n"," for i_episone in range(num_sessions):\n"," reward = generate_session(\n"," env=env, \n"," actor=actor,\n"," critic=critic,\n"," epsilon=epsilon,\n"," replay_buffer=replay_buffer,\n"," top_k=top_k\n"," )\n"," sessions_reward += reward\n"," sessions_reward /= num_sessions\n"," return sessions_reward\n","\n","def soft_update(target, source, tau):\n"," \"\"\"Updates the target data with smoothing by ``tau``\"\"\"\n"," for target_param, param in zip(target.parameters(), source.parameters()):\n"," target_param.data.copy_(\n"," target_param.data * (1.0 - tau) + param.data * tau\n"," )"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"0hIDu0ZmVc9_"},"source":["It's a standart GameCallback!"]},{"cell_type":"code","metadata":{"id":"2LhjNdyNVLhW"},"source":["class RecSimCallback(dl.Callback):\n"," def __init__(self, order=0, session_period=1):\n"," super().__init__(order=0)\n"," self.session_period = session_period\n"," \n"," def on_stage_start(self, runner: dl.IRunner):\n"," generate_sessions(\n"," env=runner.env, \n"," actor=runner.model[\"origin_actor\"],\n"," critic=runner.model[\"origin_critic\"],\n"," replay_buffer=runner.replay_buffer,\n"," top_k=runner.k,\n"," epsilon=runner.epsilon,\n"," )\n"," \n"," def on_batch_end(self, runner: dl.IRunner):\n"," if runner.global_batch_step % self.session_period == 0:\n"," session_reward = generate_session(\n"," env=runner.env, \n"," actor=runner.model[\"origin_actor\"],\n"," critic=runner.model[\"origin_critic\"],\n"," replay_buffer=runner.replay_buffer,\n"," top_k=runner.k,\n"," epsilon=runner.epsilon,\n"," )\n"," runner.batch_metrics.update({\"s_reward\": session_reward})\n"," \n"," def on_epoch_end(self, runner: dl.IRunner):\n"," valid_reward = generate_sessions(\n"," env=runner.env, \n"," actor=runner.model[\"origin_actor\"],\n"," critic=runner.model[\"origin_critic\"],\n"," top_k=runner.k,\n"," epsilon=None\n"," )\n"," runner.epoch_metrics[\"_epoch_\"][\"train_v_reward\"] = valid_reward"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"_aTp1p-BVaL5"},"source":["class CustomRunner(dl.Runner):\n"," \n"," def __init__(self, *, env, replay_buffer, gamma, tau, epsilon=0.2, tau_period=1, k=5, **kwargs):\n"," super().__init__(**kwargs)\n"," self.env = env\n"," self.replay_buffer = replay_buffer\n"," self.gamma = gamma\n"," self.tau = tau\n"," self.tau_period = tau_period\n"," self.epsilon = epsilon\n"," self.k = k\n"," \n"," def on_stage_start(self, runner: dl.IRunner):\n"," super().on_stage_start(runner)\n"," soft_update(self.model[\"origin_actor\"], self.model[\"target_actor\"], 1.0)\n"," soft_update(self.model[\"origin_critic\"], self.model[\"target_critic\"], 1.0)\n","\n"," def handle_batch(self, batch):\n"," # model train/valid step\n"," states, actions, rewards, dones, next_states = batch\n"," \n"," proto_actions = self.model[\"origin_actor\"](states)\n"," policy_loss = (-self.model[\"origin_critic\"](states, proto_actions)).mean()\n"," \n"," with torch.no_grad():\n"," target_proto_actions = self.model[\"target_actor\"](next_states)\n"," target_values = self.model[\"target_critic\"].get_best(next_states, target_proto_actions, self.k)[1].detach()\n","\n"," dones = dones * 1.0\n"," expected_values = target_values * self.gamma * (1 - dones) + rewards\n"," actions = self.model[\"origin_critic\"].doc_embs(actions.squeeze())\n"," values = self.model[\"origin_critic\"](states, actions).squeeze()\n"," \n"," value_loss = self.criterion(\n"," values,\n"," expected_values\n"," )\n"," \n"," self.batch_metrics.update(\n"," {\n"," \"critic_loss\": value_loss, \n"," \"actor_loss\": policy_loss,\n"," }\n"," )\n","\n"," if self.is_train_loader:\n"," self.optimizer[\"actor\"].zero_grad()\n"," policy_loss.backward()\n"," self.optimizer[\"actor\"].step()\n"," \n"," self.optimizer[\"critic\"].zero_grad()\n"," value_loss.backward()\n"," self.optimizer[\"critic\"].step()\n"," \n"," if self.global_batch_step % self.tau_period == 0:\n"," soft_update(self.model[\"target_critic\"], self.model[\"origin_critic\"], self.tau)\n"," soft_update(self.model[\"target_actor\"], self.model[\"origin_actor\"], self.tau)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Q8LcvEBBVWOp"},"source":["Let's train our model and check the results."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"twCqmv1VVXTj","executionInfo":{"status":"ok","timestamp":1634626657205,"user_tz":-330,"elapsed":89358,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"ed5cdf22-5ad9-45bf-dafb-0ab7de0351b1"},"source":["utils.set_global_seed(42)\n","\n","env = make_env()\n","replay_buffer = ReplayBuffer(int(1e5))\n","gamma = 0.99\n","tau = 0.001\n","tau_period = 1\n","session_period = 1\n","epoch_size = int(1e4)\n","\n","\n","models = {\n"," \"origin_actor\": ActorModel(doc_num=DOC_NUM, doc_emb_size=EMB_SIZE),\n"," \"origin_critic\": CriticModel(doc_num=DOC_NUM, doc_emb_size=EMB_SIZE),\n"," \"target_actor\": ActorModel(doc_num=DOC_NUM, doc_emb_size=EMB_SIZE),\n"," \"target_critic\": CriticModel(doc_num=DOC_NUM, doc_emb_size=EMB_SIZE),\n","}\n","with torch.no_grad():\n"," models[\"origin_critic\"].doc_embs[0].weight.copy_(models[\"target_critic\"].doc_embs[0].weight)\n","\n","utils.set_requires_grad(models[\"target_actor\"], requires_grad=False)\n","utils.set_requires_grad(models[\"target_critic\"], requires_grad=False)\n","\n","criterion = torch.nn.MSELoss()\n","optimizer = {\n"," \"actor\": torch.optim.Adam(models[\"origin_actor\"].parameters(), lr=1e-3),\n"," \"critic\": torch.optim.Adam(models[\"origin_critic\"].parameters(), lr=1e-3),\n","}\n","\n","loaders = {\n"," \"train\": DataLoader(\n"," ReplayDataset(replay_buffer, epoch_size=epoch_size), \n"," batch_size=32,\n"," ),\n","}\n","\n","\n","runner = CustomRunner(\n"," env=env, \n"," replay_buffer=replay_buffer, \n"," gamma=gamma, \n"," tau=tau,\n"," tau_period=tau_period\n",")\n","\n","runner.train(\n"," model=models,\n"," criterion=criterion,\n"," optimizer=optimizer,\n"," loaders=loaders,\n"," logdir=\"./logs_rl\",\n"," valid_loader=\"_epoch_\",\n"," valid_metric=\"train_v_reward\",\n"," minimize_valid_metric=False,\n"," load_best_on_end=True,\n"," num_epochs=20,\n"," verbose=False,\n"," callbacks=[RecSimCallback(order=0, session_period=session_period)]\n",")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/catalyst/core/runner.py:624: UserWarning: No ``ICriterionCallback/CriterionCallback`` were found while runner.criterion is not None.Do you compute the loss during ``runner.handle_batch``?\n"," \"No ``ICriterionCallback/CriterionCallback`` were found \"\n"]},{"output_type":"stream","name":"stdout","text":["train (1/20) \n","* Epoch (1/20) train_v_reward: 4.45\n","train (2/20) \n","* Epoch (2/20) train_v_reward: 5.48\n","train (3/20) \n","* Epoch (3/20) train_v_reward: 4.35\n","train (4/20) \n","* Epoch (4/20) train_v_reward: 4.73\n","train (5/20) \n","* Epoch (5/20) train_v_reward: 5.52\n","train (6/20) \n","* Epoch (6/20) train_v_reward: 4.36\n","train (7/20) \n","* Epoch (7/20) train_v_reward: 4.88\n","train (8/20) \n","* Epoch (8/20) train_v_reward: 4.33\n","train (9/20) \n","* Epoch (9/20) train_v_reward: 5.01\n","train (10/20) \n","* Epoch (10/20) train_v_reward: 4.78\n","train (11/20) \n","* Epoch (11/20) train_v_reward: 4.9\n","train (12/20) \n","* Epoch (12/20) train_v_reward: 5.41\n","train (13/20) \n","* Epoch (13/20) train_v_reward: 4.63\n","train (14/20) \n","* Epoch (14/20) train_v_reward: 4.86\n","train (15/20) \n","* Epoch (15/20) train_v_reward: 4.38\n","train (16/20) \n","* Epoch (16/20) train_v_reward: 4.15\n","train (17/20) \n","* Epoch (17/20) train_v_reward: 4.45\n","train (18/20) \n","* Epoch (18/20) train_v_reward: 4.27\n","train (19/20) \n","* Epoch (19/20) train_v_reward: 5.27\n","train (20/20) \n","* Epoch (20/20) train_v_reward: 5.14\n","Top best models:\n","logs_rl/checkpoints/train.5.pth\t5.5200\n"]}]},{"cell_type":"markdown","metadata":{"id":"j0UI-CknVfsW"},"source":["In our case, we can compare RL bot results with the optimal recommender agent. The agent can be built by the relation matrix W. We need to chose an index with the maximum value in the column."]},{"cell_type":"code","metadata":{"id":"qFpNh2QiVnUK"},"source":["from recsim.agent import AbstractEpisodicRecommenderAgent\n","\n","class OptimalRecommender(AbstractEpisodicRecommenderAgent):\n","\n"," def __init__(self, environment, W):\n"," super().__init__(environment.action_space)\n"," self._observation_space = environment.observation_space\n"," self._W = W\n","\n"," def step(self, reward, observation):\n"," return [self._W[observation[\"user\"], :].argmax()]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"4XUNIDjhVrHJ"},"source":["def run_agent(\n"," env, \n"," agent, \n"," num_steps: int = int(1e4), \n"," log_every: int = int(1e3)\n","):\n"," reward_history = []\n"," step, episode = 1, 1\n","\n"," observation = env.reset()\n"," while step < num_steps:\n"," action = agent.begin_episode(observation)\n"," episode_reward = 0\n"," while True:\n"," observation, reward, done, info = env.step(action)\n"," episode_reward += reward\n","\n"," if step % log_every == 0:\n"," print(step, np.mean(reward_history[-50:]))\n"," step += 1\n"," if done:\n"," break\n"," else:\n"," action = agent.step(reward, observation)\n","\n"," agent.end_episode(reward, observation)\n"," reward_history.append(episode_reward)\n","\n"," return reward_history"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"A4vqLdJ6Vsrk","executionInfo":{"status":"ok","timestamp":1634626675640,"user_tz":-330,"elapsed":743,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"e003bff5-c17a-4b5e-9275-7e6acb296b94"},"source":["env = make_env()\n","agent = OptimalRecommender(env, W)\n","\n","reward_history = run_agent(env, agent)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["1000 8.26\n","2000 7.22\n","3000 8.18\n","4000 9.14\n","5000 8.22\n","6000 8.22\n","7000 10.76\n","8000 9.44\n","9000 8.2\n","10000 9.5\n"]}]}]}