{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-13-listwise-retail.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/P327463%20%7C%20List-wise%20Product%20Recommendations%20using%20RL%20methods%20on%20Retail%20dataset.ipynb","timestamp":1644610285298}],"collapsed_sections":[],"authorship_tag":"ABX9TyM7CW1JToK0uwxBLH2E9tH0"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# List-wise Product Recommendations using RL methods on Retail dataset"],"metadata":{"id":"9FbVBgGDaUV5"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"pQhWTElKuY6Z"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","import torch.nn.functional as F \n","import torch.autograd\n","from torch.autograd import Variable\n","import torch.optim as optim\n","\n","import gym\n","from gym import spaces\n","\n","import pandas as pd\n","import numpy as np\n","from sklearn.preprocessing import OneHotEncoder\n","import random\n","from collections import deque\n","import os"]},{"cell_type":"markdown","source":["## Data"],"metadata":{"id":"ecbUGmt8uw7i"}},{"cell_type":"code","source":["!gdown --id 1h5DEIT-JYeR5e8D8BK6dny5zYCwth1rl"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"zkRdSXGGuw5K","executionInfo":{"status":"ok","timestamp":1639481833996,"user_tz":-330,"elapsed":4872,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"c19370d0-6516-49a2-84fa-169bb62c4d95"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Downloading...\n","From: https://drive.google.com/uc?id=1h5DEIT-JYeR5e8D8BK6dny5zYCwth1rl\n","To: /content/dataset.zip\n","100% 22.9M/22.9M [00:00<00:00, 62.9MB/s]\n"]}]},{"cell_type":"code","source":["!unzip dataset.zip"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Ev3QK2Zgu2O5","executionInfo":{"status":"ok","timestamp":1639481844856,"user_tz":-330,"elapsed":1594,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"e06b6cf4-718f-42c9-b32b-c177aefae9de"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Archive: dataset.zip\n"," creating: dataset/\n"," inflating: dataset/.DS_Store \n"," inflating: dataset/Data_Acc_Item.csv \n"," inflating: dataset/Item_inf.csv \n"," inflating: dataset/train_acc_inf.csv \n"]}]},{"cell_type":"code","source":["PARENT_PATH = 'weight'\n","ACTOR_PATH = 'weight/actor'\n","ACTOR_TARGET_PATH = 'weight/actor_target'\n","CRITIC_PATH = 'weight/critic'\n","CRITIC_TARGET_PATH = 'weight/critic_target'"],"metadata":{"id":"jJA8g_x9v9GK"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model"],"metadata":{"id":"XPEY3UuYu5pf"}},{"cell_type":"code","source":["class Critic(nn.Module):\n"," def __init__(self, state_size, action_size, hidden_size, action_sequence_length):\n"," super(Critic, self).__init__()\n"," self.encode_state = nn.LSTM(state_size,action_size,batch_first = True)\n"," hidden_stack = [nn.Linear((action_sequence_length + 1)*action_size, hidden_size),\n"," nn.ReLU(),]\n"," for i in range(3):\n"," hidden_stack.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()])\n"," self.hidden_layer = nn.Sequential(*hidden_stack)\n"," self.output_layer = nn.Linear(hidden_size, 1)\n","\n"," def forward(self, state, 
action):\n","        \"\"\"\n","        Params state and action are torch tensors\n","        \"\"\"\n","        if not isinstance(state, torch.Tensor):\n","            state = torch.as_tensor(state, dtype=torch.float32)\n","        if not isinstance(action, torch.Tensor):\n","            action = torch.as_tensor(action, dtype=torch.float32)\n","        # Remember whether a single (unbatched) sample was passed in, so the\n","        # batch dimension added below can be removed again at the end.\n","        single_sample = (len(state.shape) == 2) and (len(action.shape) == 2)\n","        if single_sample:\n","            action = action.unsqueeze(0)\n","            state = state.unsqueeze(0)\n","        _, (encoded_state, __) = self.encode_state(state)\n","        encoded_state = encoded_state.squeeze(0)\n","        action = action.flatten(1)\n","        x = torch.cat([encoded_state, action], -1)\n","        x = self.hidden_layer(x)\n","        x = self.output_layer(x)\n","        if single_sample:\n","            x = x.squeeze(0)\n","        return x"],"metadata":{"id":"t2rvSOLbu8LM"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["class Actor(nn.Module):\n","    def __init__(self, input_size, input_sequence_length, output_sequence_length, output_size):\n","        super(Actor, self).__init__()\n","        # Learnable weights that pool the item-history sequence into a single vector.\n","        self.weight_matrix = torch.nn.Parameter(torch.ones((1, input_sequence_length), requires_grad=True))\n","        self.Linear = nn.Linear(input_size, output_size)\n","        self.Activation = nn.Softmax(dim=-1)\n","        self.output_shape = (output_sequence_length, output_size)\n","\n","    def forward(self, state):\n","        \"\"\"\n","        Param state is a torch tensor\n","        \"\"\"\n","        state = torch.FloatTensor(state)\n","        size = len(state.shape)\n","        if size == 2:\n","            state = state.unsqueeze(0)\n","        state = self.weight_matrix.matmul(state)\n","        state = state.squeeze(1)\n","        action = []\n","#         x = self.Linear(state)\n","        action.append(self.Activation(state))\n","        # Build one probability distribution per slate position: each position\n","        # zeroes out the previous position's top item and renormalises.\n","        for i in range(self.output_shape[0] - 1):\n","            indices = action[i].argmax(-1).unsqueeze(-1)\n","            action_i = action[i].scatter(-1, indices, 0)\n","            action_i = action_i / action_i.sum(-1).unsqueeze(-1)\n","            action.append(action_i)\n","        action = torch.cat(action, -1).reshape((-1, self.output_shape[0], self.output_shape[1]))\n","        if size == 2:\n","            action = action.squeeze(0)\n","        return action"],"metadata":{"id":"nUxrr7dcvAOu"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["class OUNoise(object):\n","    def __init__(self, action_space, mu=0.0, theta=0.1, max_sigma=0.5, min_sigma=0.0, decay_period=500):\n","        self.mu = mu\n","        self.theta = theta\n","        self.sigma = max_sigma\n","        self.max_sigma = max_sigma\n","        self.min_sigma = min_sigma\n","        self.decay_period = decay_period\n","        self.action_dim = action_space.shape\n","        self.low = action_space.low\n","        self.high = action_space.high\n","        self.reset()\n","\n","    def reset(self):\n","        self.state = np.ones(self.action_dim) * self.mu\n","\n","    def evolve_state(self):\n","        x = self.state\n","        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim[0], self.action_dim[1])\n","        self.state = x + dx\n","        return self.state\n","\n","    def get_action(self, action, t=0):\n","        ou_state = self.evolve_state()\n","        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period)\n","        action = np.clip(action + ou_state, self.low, self.high)\n","        action = torch.from_numpy(action)\n","        # Renormalise the perturbed scores so each slate position is again a distribution.\n","        action = torch.nn.Softmax(dim=-1)(action).detach().numpy()\n","        return action"],"metadata":{"id":"nrOnx1GcvNmo"},"execution_count":null,"outputs":[]},
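{"cell_type":"markdown","source":["A quick sanity check of the `Actor` and `Critic` defined above, assuming the dimensions used with the retail data below (576 items, a history window of 20 interactions, slates of 3 recommendations): the `Actor` should map a `(20, 576)` history to a `(3, 576)` slate of per-position item distributions, and the `Critic` should score the state-slate pair with a single Q-value."],"metadata":{}},
{"cell_type":"code","source":["# Illustrative shape check only -- sizes are assumptions mirroring the retail\n","# setup used later (576 items, history window of 20, slates of 3 items).\n","_n_item, _max_lag, _n_reco = 576, 20, 3\n","_actor = Actor(_n_item, _max_lag, _n_reco, _n_item)\n","_critic = Critic(_n_item, _n_item, _n_item, _n_reco)\n","\n","_dummy_state = torch.rand(_max_lag, _n_item)   # one fake user history\n","_slate = _actor(_dummy_state)                  # (3, 576): one distribution per slate position\n","_q_value = _critic(_dummy_state, _slate)       # Q estimate for the state-slate pair\n","print(_slate.shape, _q_value.shape)"],"metadata":{},"execution_count":null,"outputs":[]},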
{"cell_type":"code","source":["class ActionSpace(gym.Space):\n","    def __init__(self, n_reco, n_item):\n","        self.shape = (n_reco, n_item)\n","        self.dtype = np.int64\n","        self.low = 0\n","        self.high = 1\n","        super(ActionSpace, self).__init__(self.shape, self.dtype)\n","\n","    def sample(self):\n","        # One random item per slate position, encoded as a one-hot row.\n","        sample = torch.zeros(self.shape, dtype=torch.int64)\n","        indices = torch.randint(0, self.shape[1], (self.shape[0], 1))\n","        sample = sample.scatter_(1, indices, 1)\n","        return sample.numpy()"],"metadata":{"id":"AGQGxVTUvTFN"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["class StateSpace(gym.Space):\n","    def __init__(self, max_state, n_item):\n","        self.shape = (max_state, n_item)\n","        self.dtype = np.int64\n","        super(StateSpace, self).__init__(self.shape, self.dtype)"],"metadata":{"id":"IqT3xUflvoFc"},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":["## Memory buffer"],"metadata":{"id":"Q5d-EPhuvjkv"}},
{"cell_type":"code","source":["class Memory:\n","    def __init__(self, max_size):\n","        self.buffer = deque(maxlen=max_size)\n","\n","    def push(self, state, action, reward, next_state, done):\n","        experience = (state, action, np.array([reward]), next_state, done)\n","        self.buffer.append(experience)\n","\n","    def sample(self, batch_size):\n","        state_batch = []\n","        action_batch = []\n","        reward_batch = []\n","        next_state_batch = []\n","        done_batch = []\n","\n","        batch = random.sample(self.buffer, batch_size)\n","\n","        for experience in batch:\n","            state, action, reward, next_state, done = experience\n","            state_batch.append(state)\n","            action_batch.append(action)\n","            reward_batch.append(reward)\n","            next_state_batch.append(next_state)\n","            done_batch.append(done)\n","\n","        return state_batch, action_batch, reward_batch, next_state_batch, done_batch\n","\n","    def __len__(self):\n","        return len(self.buffer)"],"metadata":{"id":"dz_Lj1THvlKj"},"execution_count":null,"outputs":[]},
{"cell_type":"markdown","source":["### Agent"],"metadata":{"id":"wq50vLsxwD10"}},
{"cell_type":"code","source":["class DDPGagent:\n","    def __init__(self, env, hidden_size=576,\n","                 actor_learning_rate=1e-4,\n","                 critic_learning_rate=1e-3,\n","                 gamma=0.99, tau=1e-2,\n","                 max_memory_size=50000):\n","        # Params\n","        self.size_states = env.observation_space.shape\n","        self.size_actions = env.action_space.shape\n","        self.gamma = gamma\n","        self.tau = tau\n","\n","        # Networks. Actor takes (input_size, input_sequence_length,\n","        # output_sequence_length, output_size): it reads a state of shape\n","        # (max_lag, n_item) and emits a slate of shape (n_reco, n_item).\n","        self.actor = Actor(self.size_states[1], self.size_states[0], self.size_actions[0], self.size_actions[1])\n","        self.actor_target = Actor(self.size_states[1], self.size_states[0], self.size_actions[0], self.size_actions[1])\n","        self.critic = Critic(self.size_states[1], self.size_actions[1], hidden_size, self.size_actions[0])\n","        self.critic_target = Critic(self.size_states[1], self.size_actions[1], hidden_size, self.size_actions[0])\n","\n","        self.load_()\n","\n","        # Training\n","        self.memory = Memory(max_memory_size)\n","        self.critic_criterion = nn.MSELoss()\n","        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_learning_rate)\n","        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_learning_rate)\n","\n","        for p in self.actor_target.parameters():\n","            p.requires_grad = False\n","        for p in self.critic_target.parameters():\n","            p.requires_grad = False\n","\n","    def from_probability_distribution_to_action(self, action):\n","        # Turn each slate position's distribution into a one-hot pick of its top item.\n","        if not isinstance(action, torch.Tensor):\n","            action = torch.FloatTensor(action)\n","        indices = torch.max(action, -1).indices.unsqueeze(-1)\n","        action = action.zero_().scatter_(-1, indices, 1).numpy()\n","        return action\n","\n","    def get_action(self, state):\n","        if not isinstance(state, torch.Tensor):\n","            state = torch.FloatTensor(state)\n","        with torch.no_grad():\n","            action = self.actor.forward(state)\n","        action = action.detach().numpy()\n","        return action\n","\n","    def update(self, 
batch_size):\n"," states, actions, rewards, next_states, _ = self.memory.sample(batch_size)\n"," states = torch.FloatTensor(states)\n"," actions = torch.FloatTensor(actions)\n"," rewards = torch.FloatTensor(rewards)\n"," next_states = torch.FloatTensor(next_states)\n"," \n"," # Critic loss \n"," Qvals = self.critic.forward(states, actions)\n"," next_actions = self.actor_target.forward(next_states)\n"," next_actions = self.from_probability_distribution_to_action(next_actions)\n"," next_Q = self.critic_target.forward(next_states, next_actions)\n"," Qprime = rewards + self.gamma * next_Q\n"," critic_loss = self.critic_criterion(Qvals, Qprime)\n","\n"," # Actor loss\n"," policy_loss = -self.critic.forward(states, self.actor.forward(states)).mean()\n"," \n"," # update networks\n"," self.actor_optimizer.zero_grad()\n"," policy_loss.backward()\n"," self.actor_optimizer.step()\n","\n"," self.critic_optimizer.zero_grad()\n"," critic_loss.backward() \n"," self.critic_optimizer.step()\n","\n"," # update target networks\n"," with torch.no_grad():\n"," for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):\n"," target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))\n"," \n"," for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):\n"," target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))\n"," def save_(self):\n"," if not os.path.exists(PARENT_PATH):\n"," os.mkdir(PARENT_PATH)\n"," torch.save(self.actor.state_dict(), ACTOR_PATH)\n"," torch.save(self.actor_target.state_dict(), ACTOR_TARGET_PATH)\n"," torch.save(self.critic.state_dict(), CRITIC_PATH)\n"," torch.save(self.critic_target.state_dict(), CRITIC_TARGET_PATH)\n"," def load_(self):\n"," try:\n"," self.actor.load_state_dict(torch.load(ACTOR_PATH))\n"," self.actor_target.load_state_dict(torch.load(ACTOR_TARGET_PATH))\n"," self.critic.load_state_dict(torch.load(CRITIC_PATH))\n"," self.critic_target.load_state_dict(torch.load(CRITIC_TARGET_PATH))\n"," except Exception:\n"," for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):\n"," target_param.data.copy_(param.data)\n","\n"," for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):\n"," target_param.data.copy_(param.data)\n"," print(self.actor.eval(), self.critic.eval())"],"metadata":{"id":"mW0xdq3CvytJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Environment"],"metadata":{"id":"oCC7LoW0vlfZ"}},{"cell_type":"code","source":["class MyEnv(gym.Env):\n","\n"," def __init__(self,\n"," history_data: pd.DataFrame,\n"," item_data: pd.DataFrame,\n"," user_data: pd.DataFrame,\n"," dim_action: int = 3,\n"," max_lag: int = 20,\n"," ): \n"," \n"," super(MyEnv, self).__init__()\n"," self.history_data = history_data\n"," self.item_data = item_data\n"," self.user_data = user_data\n"," self.dim_action = dim_action\n"," self.max_lag = max_lag\n"," self.list_item = item_data.ID.tolist()\n"," self.n_item = len(self.list_item)\n"," self.encode = OneHotEncoder(handle_unknown='ignore')\n"," self.encode.fit(np.array(self.list_item).reshape(-1,1))\n"," self.action_space = ActionSpace(self.dim_action, self.n_item)\n"," self.observation_space = StateSpace(self.max_lag, self.n_item)\n"," self.idx_current = 0\n"," \n"," def step(self, action):\n"," action = np.array(action)\n"," _current_itemID = self.history_data.iloc[self.idx_current].ItemID\n"," _current_AcountID = 
self.history_data.iloc[self.idx_current].AccountID\n"," _temp = self.history_data.iloc[:self.idx_current + 1]\n"," current_frame = _temp[_temp.AccountID == _current_AcountID]\n"," if (len(current_frame) < self.max_lag):\n"," first_state = obs = np.zeros((self.max_lag - len(current_frame),self.n_item))\n"," str_obs = current_frame.ItemID.to_numpy().reshape(-1,1)\n"," last_state = self.encode.transform(str_obs).toarray()\n"," obs = np.concatenate([first_state, last_state],0)\n"," else:\n"," str_obs = current_frame[-self.max_lag:].ItemID.to_numpy().reshape(-1,1)\n"," obs = self.encode.transform(str_obs).toarray()\n"," \n"," _encode_current_itemID = self.encode.transform([[_current_itemID]]).toarray().reshape(-1)\n"," reward = 0\n"," for i in range(self.dim_action):\n"," if (action[i]==_encode_current_itemID).all():\n"," reward = self.dim_action - i\n"," break\n"," if (np.sum(action,1) > 1).any():\n"," reward = reward - 10\n"," done = False\n"," return obs, reward, done, {}\n"," def get_observation(self, reset = False):\n"," if reset:\n"," self.idx_current = np.random.randint(len(self.history_data))\n"," else:\n"," if (self.idx_current+1) == len(self.history_data):\n"," self.idx_current = 0\n"," else:\n"," self.idx_current = self.idx_current + 1\n"," _current_AcountID = self.history_data.iloc[self.idx_current].AccountID\n"," _temp = self.history_data.iloc[:self.idx_current]\n"," recent_past_frame = _temp[_temp.AccountID == _current_AcountID]\n"," \n"," first_state = obs = np.zeros((len(recent_past_frame),self.n_item))\n"," if (len(recent_past_frame) < self.max_lag):\n"," first_state = obs = np.zeros(( self.max_lag - len(recent_past_frame),self.n_item))\n"," str_obs = recent_past_frame.ItemID.to_numpy().reshape(-1,1)\n"," if len(str_obs) !=0:\n"," last_state = self.encode.transform(str_obs).toarray()\n"," obs = np.concatenate([first_state, last_state],0)\n"," else:\n"," str_obs = recent_past_frame[-self.max_lag:].ItemID.to_numpy().reshape(-1,1)\n"," obs = self.encode.transform(str_obs).toarray()\n"," return obs\n"," \n"," def render(self, mode='human', close=False):\n"," # Render the environment to the screen\n"," raise Exception()"],"metadata":{"id":"HLQkXPdXvycs"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Training"],"metadata":{"id":"ih6WzjaFwUSw"}},{"cell_type":"code","source":["rating = pd.read_csv('dataset/Data_Acc_Item.csv')\n","item = pd.read_csv('dataset/Item_inf.csv',index_col = 'Unnamed: 0')\n","user = pd.read_csv('dataset/train_acc_inf.csv')\n","\n","env = MyEnv(rating,item,user)\n","agent = DDPGagent(env)\n","noise = OUNoise(env.action_space)\n","batch_size = 100\n","rewards = []\n","avg_rewards = []"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vtMSwuuLwMCC","executionInfo":{"status":"ok","timestamp":1639482215218,"user_tz":-330,"elapsed":2892,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"9f6c74a4-35d0-4e57-fb43-c04c85da043a"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Actor(\n"," (Linear): Linear(in_features=576, out_features=576, bias=True)\n"," (Activation): Softmax(dim=-1)\n",") Critic(\n"," (encode_state): LSTM(576, 576, batch_first=True)\n"," (hidden_layer): Sequential(\n"," (0): Linear(in_features=2304, out_features=576, bias=True)\n"," (1): ReLU()\n"," (2): Linear(in_features=576, out_features=576, bias=True)\n"," (3): ReLU()\n"," (4): Linear(in_features=576, out_features=576, 
bias=True)\n","    (5): ReLU()\n","    (6): Linear(in_features=576, out_features=576, bias=True)\n","    (7): ReLU()\n","  )\n","  (output_layer): Linear(in_features=576, out_features=1, bias=True)\n",")\n"]}]},
{"cell_type":"code","source":["for episode in range(20):\n","    state = env.get_observation(reset = True)\n","    noise.reset()\n","    episode_reward = 0\n","\n","    for step in range(500):\n","        action = agent.get_action(state)\n","        action = noise.get_action(action, step)\n","        action = agent.from_probability_distribution_to_action(action)\n","        new_state, reward, done, _ = env.step(action)\n","        agent.memory.push(state, action, reward, new_state, done)\n","\n","        if len(agent.memory) > batch_size:\n","            agent.update(batch_size)\n","\n","        state = env.get_observation()\n","        episode_reward += reward\n","        print('step {} in episode {}: reward is {}'.format(step, episode, reward))\n","\n","    rewards.append(episode_reward)\n","    avg_rewards.append(np.mean(rewards[-10:]))"],"metadata":{"id":"giGBoeJfwTFJ"},"execution_count":null,"outputs":[]},
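{"cell_type":"markdown","source":["As a rough, illustrative check after training, the one-hot slate the agent produces for a sampled state can be decoded back into item IDs with the environment's fitted `OneHotEncoder`."],"metadata":{}},
{"cell_type":"code","source":["# Illustrative sketch: decode the recommended slate back into item IDs.\n","# Reuses `env` and `agent` from the cells above; nothing is persisted here.\n","sample_state = env.get_observation(reset=True)\n","slate = agent.from_probability_distribution_to_action(agent.get_action(sample_state))\n","recommended_items = env.encode.inverse_transform(slate).reshape(-1)\n","print('Recommended items:', recommended_items)"],"metadata":{},"execution_count":null,"outputs":[]},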
{"cell_type":"code","source":["import matplotlib.pyplot as plt\n","\n","plt.plot(rewards, label='episode reward')\n","plt.plot(avg_rewards, label='10-episode average')\n","plt.xlabel('Episode')\n","plt.ylabel('Reward')\n","plt.legend()\n","plt.show()"],"metadata":{"id":"AeNUU95lwWWJ"},"execution_count":null,"outputs":[]}]}