{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4","authorship_tag":"ABX9TyNkfccbNHhNy78YPugx3bFf"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[
{"cell_type":"markdown","metadata":{"id":"mdTitle0001"},"source":[
"# Reinforcement learning on FrozenLake 8x8\n",
"\n",
"Part 1 trains a Deep Q-Network (DQN) with an experience-replay buffer and a hard-updated target network. Part 2 trains a REINFORCE policy-gradient agent with a learned value baseline. Each part ends by recording a rollout and displaying the video."
]},
{"cell_type":"code","metadata":{"id":"hQ90ciAqdQxk"},"execution_count":null,"outputs":[],"source":[
"import os\n",
"import gc\n",
"import glob\n",
"import base64\n",
"import argparse\n",
"from pathlib import Path\n",
"from datetime import datetime\n",
"from collections import deque, namedtuple\n",
"\n",
"import numpy as np\n",
"import pygame\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"import gymnasium as gym\n",
"# Part 2 (REINFORCE) runs on the legacy `gym` API (env.seed, 4-tuple step). It is imported\n",
"# under its own name so that `gym` keeps referring to gymnasium throughout the notebook.\n",
"import gym as legacy_gym\n",
"from gym.wrappers.record_video import RecordVideo\n",
"from IPython.display import HTML, clear_output\n",
"from torch.distributions import Categorical\n",
"from torch.utils.tensorboard import SummaryWriter"
]},
{"cell_type":"code","metadata":{"id":"8cdjpJiKab9f"},"execution_count":null,"outputs":[],"source":[
"def show_video(video_path):\n",
"    \"\"\"\n",
"    Embed the most recently written .mp4 file in `video_path` as an HTML5 video.\n",
"\n",
"    Raises:\n",
"        FileNotFoundError: if the folder contains no .mp4 file.\n",
"    \"\"\"\n",
"    videos = glob.glob(os.path.join(video_path, \"*.mp4\"))\n",
"    if not videos:\n",
"        raise FileNotFoundError(f\"No .mp4 video found in {video_path!r}\")\n",
"    # The recorder may leave several files behind; show the newest one.\n",
"    video_file = max(videos, key=os.path.getmtime)\n",
"    with open(video_file, 'rb') as f:\n",
"        encoded = base64.b64encode(f.read()).decode()\n",
"    video_url = f\"data:video/mp4;base64,{encoded}\"\n",
"    return HTML(f'<video width=\"600\" controls><source src=\"{video_url}\" type=\"video/mp4\"></video>')"
]},
{"cell_type":"code","metadata":{"id":"RzOmjR7MGLWz"},"execution_count":null,"outputs":[],"source":[
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"gc.collect()\n",
"torch.cuda.empty_cache()\n",
"# Debugging aid: report CUDA errors at the failing call. Only honoured if set before\n",
"# the CUDA context is created.\n",
"os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
"\n",
"# Seed everything for reproducible results\n",
"seed = 2024\n",
"np.random.seed(seed)  # the code below draws from the legacy global RNG (np.random.*)\n",
"# NOTE: PYTHONHASHSEED only affects Python processes started after this point.\n",
"os.environ['PYTHONHASHSEED'] = str(seed)\n",
"torch.manual_seed(seed)\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.manual_seed(seed)\n",
"    torch.backends.cudnn.deterministic = True\n",
"    torch.backends.cudnn.benchmark = False"
]},
{"cell_type":"markdown","metadata":{"id":"mdPart10001"},"source":[
"## Part 1: Deep Q-Network (DQN)\n",
"\n",
"A replay buffer, a small fully connected Q-network, and an epsilon-greedy agent with a hard-updated target network. FrozenLake rewards are sparse, so gradient updates start only after the goal has been reached at least once."
]},
{"cell_type":"code","metadata":{"id":"C0692mQcGU77"},"execution_count":null,"outputs":[],"source":[
"class ReplayMemory:\n",
"    \"\"\"Fixed-capacity FIFO experience replay buffer backed by parallel deques.\"\"\"\n",
"\n",
"    def __init__(self, capacity):\n",
"        \"\"\"\n",
"        Parameters:\n",
"            capacity (int): Maximum number of transitions kept; once full, the\n",
"                oldest transitions are discarded first.\n",
"        \"\"\"\n",
"        self.capacity = capacity\n",
"\n",
"        self.states = deque(maxlen=capacity)\n",
"        self.actions = deque(maxlen=capacity)\n",
"        self.next_states = deque(maxlen=capacity)\n",
"        self.rewards = deque(maxlen=capacity)\n",
"        self.dones = deque(maxlen=capacity)\n",
"\n",
"    def store(self, state, action, next_state, reward, done):\n",
"        \"\"\"Append one transition to the buffer.\"\"\"\n",
"        self.states.append(state)\n",
"        self.actions.append(action)\n",
"        self.next_states.append(next_state)\n",
"        self.rewards.append(reward)\n",
"        self.dones.append(done)\n",
"\n",
"    def sample(self, batch_size):\n",
"        \"\"\"\n",
"        Draw `batch_size` distinct transitions uniformly at random.\n",
"\n",
"        Returns:\n",
"            (states, actions, next_states, rewards, dones) as tensors on `device`.\n",
"\n",
"        Requires len(self) >= batch_size, because sampling is without replacement.\n",
"        \"\"\"\n",
"        indices = np.random.choice(len(self), size=batch_size, replace=False)\n",
"\n",
"        def as_float(x):\n",
"            return torch.as_tensor(x, dtype=torch.float32, device=device)\n",
"\n",
"        states = torch.stack([as_float(self.states[i]) for i in indices])\n",
"        actions = torch.as_tensor([self.actions[i] for i in indices], dtype=torch.long, device=device)\n",
"        next_states = torch.stack([as_float(self.next_states[i]) for i in indices])\n",
"        rewards = torch.as_tensor([self.rewards[i] for i in indices], dtype=torch.float32, device=device)\n",
"        dones = torch.as_tensor([self.dones[i] for i in indices], dtype=torch.bool, device=device)\n",
"\n",
"        return states, actions, next_states, rewards, dones\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of stored transitions.\"\"\"\n",
"        return len(self.dones)"
]},
{"cell_type":"code","metadata":{"id":"FCYb-z21Gdy7"},"execution_count":null,"outputs":[],"source":[
"class DQN_Network(nn.Module):\n",
"    \"\"\"\n",
"    Fully connected Q-network (input_dim -> 12 -> 8 -> num_actions, ReLU\n",
"    activations) that outputs one Q-value per action.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_actions, input_dim):\n",
"        \"\"\"\n",
"        Parameters:\n",
"            num_actions (int): The number of possible actions in the environment.\n",
"            input_dim (int): The dimensionality of the input state space.\n",
"        \"\"\"\n",
"        super(DQN_Network, self).__init__()\n",
"\n",
"        self.FC = nn.Sequential(\n",
"            nn.Linear(input_dim, 12),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Linear(12, 8),\n",
"            nn.ReLU(inplace=True),\n",
"            nn.Linear(8, num_actions)\n",
"        )\n",
"\n",
"        # He (Kaiming) initialisation matches the ReLU activations.\n",
"        for module in self.FC:\n",
"            if isinstance(module, nn.Linear):\n",
"                nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"\n",
"        Parameters:\n",
"            x (torch.Tensor): Encoded state(s).\n",
"\n",
"        Returns:\n",
"            torch.Tensor: Q-values, one per action.\n",
"        \"\"\"\n",
"        return self.FC(x)"
]},
{"cell_type":"code","metadata":{"id":"ITicWIe9T_eK"},"execution_count":null,"outputs":[],"source":[
"class DQN_Agent:\n",
"    \"\"\"\n",
"    DQN agent: epsilon-greedy action selection, learning from replayed\n",
"    transitions against a periodically hard-updated target network.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, env, epsilon_max, epsilon_min, epsilon_decay,\n",
"                 clip_grad_norm, learning_rate, discount, memory_capacity):\n",
"\n",
"        # Loss bookkeeping: the mean loss of every episode is kept for plotting\n",
"        self.loss_history = []\n",
"        self.running_loss = 0\n",
"        self.learned_counts = 0\n",
"\n",
"        # RL hyperparameters. epsilon_max holds the *current* epsilon and is decayed in place.\n",
"        self.epsilon_max = epsilon_max\n",
"        self.epsilon_min = epsilon_min\n",
"        self.epsilon_decay = epsilon_decay\n",
"        self.discount = discount\n",
"\n",
"        self.action_space = env.action_space\n",
"        self.action_space.seed(seed)  # reproducible exploration samples\n",
"        self.observation_space = env.observation_space\n",
"        self.replay_memory = ReplayMemory(memory_capacity)\n",
"\n",
"        # Main (online) and target networks start from identical weights\n",
"        self.main_network = DQN_Network(num_actions=self.action_space.n, input_dim=self.observation_space.n).to(device)\n",
"        self.target_network = DQN_Network(num_actions=self.action_space.n, input_dim=self.observation_space.n).to(device).eval()\n",
"        self.target_network.load_state_dict(self.main_network.state_dict())\n",
"\n",
"        self.clip_grad_norm = clip_grad_norm  # max gradient norm, guards against exploding gradients\n",
"        self.criterion = nn.MSELoss()\n",
"        self.optimizer = optim.Adam(self.main_network.parameters(), lr=learning_rate)\n",
"\n",
"    def select_action(self, state):\n",
"        \"\"\"\n",
"        Epsilon-greedy action selection.\n",
"\n",
"        Parameters:\n",
"            state (torch.Tensor): Encoded state.\n",
"\n",
"        Returns:\n",
"            int: A random action with probability epsilon, otherwise argmax_a Q(state, a).\n",
"        \"\"\"\n",
"        # Exploration\n",
"        if np.random.random() < self.epsilon_max:\n",
"            return self.action_space.sample()\n",
"\n",
"        # Exploitation: pick the action with the highest Q-value\n",
"        with torch.no_grad():\n",
"            Q_values = self.main_network(state)\n",
"            action = torch.argmax(Q_values).item()\n",
"\n",
"        return action\n",
"\n",
"    def learn(self, batch_size, done):\n",
"        \"\"\"\n",
"        Run one gradient step on a minibatch sampled from the replay memory.\n",
"\n",
"        Parameters:\n",
"            batch_size (int): Number of transitions to sample.\n",
"            done (bool): True on the last step of an episode; the mean loss of\n",
"                that episode is then appended to self.loss_history.\n",
"        \"\"\"\n",
"        states, actions, next_states, rewards, dones = self.replay_memory.sample(batch_size)\n",
"\n",
"        actions = actions.unsqueeze(1)\n",
"        rewards = rewards.unsqueeze(1)\n",
"        dones = dones.unsqueeze(1)\n",
"\n",
"        # Q(s, a) for the actions that were actually taken\n",
"        predicted_q = self.main_network(states).gather(dim=1, index=actions)\n",
"\n",
"        # max over a' of Q_target(s', a'); the value itself, not the argmax action\n",
"        with torch.no_grad():\n",
"            next_target_q_value = self.target_network(next_states).max(dim=1, keepdim=True)[0]\n",
"        next_target_q_value[dones] = 0  # terminal transitions do not bootstrap\n",
"\n",
"        y_js = rewards + (self.discount * next_target_q_value)  # TD targets\n",
"        loss = self.criterion(predicted_q, y_js)\n",
"\n",
"        # Update the running loss and learned counts for logging and plotting\n",
"        self.running_loss += loss.item()\n",
"        self.learned_counts += 1\n",
"\n",
"        if done:\n",
"            self.loss_history.append(self.running_loss / self.learned_counts)  # mean loss of the episode\n",
"            self.running_loss = 0\n",
"            self.learned_counts = 0\n",
"\n",
"        self.optimizer.zero_grad()\n",
"        loss.backward()\n",
"        # Clip the gradients to prevent exploding gradients\n",
"        torch.nn.utils.clip_grad_norm_(self.main_network.parameters(), self.clip_grad_norm)\n",
"        self.optimizer.step()\n",
"\n",
"    def hard_update(self):\n",
"        \"\"\"\n",
"        Naive update: copy the main network parameters into the target network.\n",
"        \"\"\"\n",
"        self.target_network.load_state_dict(self.main_network.state_dict())\n",
"\n",
"    def update_epsilon(self):\n",
"        \"\"\"\n",
"        Decay epsilon multiplicatively by epsilon_decay, never going below epsilon_min.\n",
"        \"\"\"\n",
"        self.epsilon_max = max(self.epsilon_min, self.epsilon_max * self.epsilon_decay)\n",
"\n",
"    def save(self, path):\n",
"        \"\"\"\n",
"        Save the main network parameters to `path` (a .pth file).\n",
"        \"\"\"\n",
"        torch.save(self.main_network.state_dict(), path)"
]},
{"cell_type":"code","metadata":{"id":"6403NkUVGu7z"},"execution_count":null,"outputs":[],"source":[
"class Model_TrainTest:\n",
"    \"\"\"\n",
"    Builds the FrozenLake environment and a DQN_Agent from a hyperparameter dict\n",
"    (see generate_parameters) and runs training or greedy evaluation.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, hyperparams):\n",
"        # Define RL Hyperparameters\n",
"        self.train_mode = hyperparams[\"train_mode\"]\n",
"        self.RL_load_path = hyperparams[\"RL_load_path\"]\n",
"        self.save_path = hyperparams[\"save_path\"]\n",
"        self.save_interval = hyperparams[\"save_interval\"]\n",
"\n",
"        self.clip_grad_norm = hyperparams[\"clip_grad_norm\"]\n",
"        self.learning_rate = hyperparams[\"learning_rate\"]\n",
"        self.discount_factor = hyperparams[\"discount_factor\"]\n",
"        self.batch_size = hyperparams[\"batch_size\"]\n",
"        self.update_frequency = hyperparams[\"update_frequency\"]\n",
"        self.max_episodes = hyperparams[\"max_episodes\"]\n",
"        self.max_steps = hyperparams[\"max_steps\"]\n",
"        self.render = hyperparams[\"render\"]\n",
"\n",
"        self.epsilon_max = hyperparams[\"epsilon_max\"]\n",
"        self.epsilon_min = hyperparams[\"epsilon_min\"]\n",
"        self.epsilon_decay = hyperparams[\"epsilon_decay\"]\n",
"\n",
"        self.memory_capacity = hyperparams[\"memory_capacity\"]\n",
"\n",
"        self.num_states = hyperparams[\"num_states\"]\n",
"        self.map_size = hyperparams[\"map_size\"]\n",
"        self.render_fps = hyperparams[\"render_fps\"]\n",
"\n",
"        # Define Env; rgb_array rendering is only needed when a video is recorded\n",
"        self.env = gym.make('FrozenLake-v1', map_name=f\"{self.map_size}x{self.map_size}\",\n",
"                            is_slippery=False, max_episode_steps=self.max_steps,\n",
"                            render_mode=\"rgb_array\" if self.render else None)\n",
"        self.env.metadata['render_fps'] = self.render_fps  # For max frame rate make it 0\n",
"        if self.render:\n",
"            video_folder = \"vid\"\n",
"            os.makedirs(video_folder, exist_ok=True)\n",
"            self.env = RecordVideo(self.env, video_folder)\n",
"\n",
"        # Define the agent class\n",
"        self.agent = DQN_Agent(env=self.env,\n",
"                               epsilon_max=self.epsilon_max,\n",
"                               epsilon_min=self.epsilon_min,\n",
"                               epsilon_decay=self.epsilon_decay,\n",
"                               clip_grad_norm=self.clip_grad_norm,\n",
"                               learning_rate=self.learning_rate,\n",
"                               discount=self.discount_factor,\n",
"                               memory_capacity=self.memory_capacity)\n",
"\n",
"    def state_preprocess(self, state: int, num_states: int):\n",
"        \"\"\"\n",
"        One-hot encode a discrete state index on `device`. For example, state 2\n",
"        of 5 states becomes tensor([0., 0., 1., 0., 0.]).\n",
"        \"\"\"\n",
"        onehot_vector = torch.zeros(num_states, dtype=torch.float32, device=device)\n",
"        onehot_vector[state] = 1\n",
"        return onehot_vector\n",
"\n",
"    @staticmethod\n",
"    def _unpack_step(step_result):\n",
"        \"\"\"\n",
"        Normalise env.step output to (obs, reward, terminated, truncated).\n",
"\n",
"        gym's RecordVideo (used when rendering) returns the legacy 4-tuple\n",
"        (obs, reward, done, info), while a bare gymnasium env returns\n",
"        (obs, reward, terminated, truncated, info).\n",
"        \"\"\"\n",
"        if len(step_result) == 5:\n",
"            obs, reward, terminated, truncated, _ = step_result\n",
"            return obs, reward, terminated, truncated\n",
"        obs, reward, done, _ = step_result\n",
"        return obs, reward, done, False\n",
"\n",
"    def train(self):\n",
"        \"\"\"\n",
"        DQN training loop over `max_episodes` episodes. A checkpoint is saved\n",
"        every `save_interval` episodes and the reward curve is plotted.\n",
"        \"\"\"\n",
"        total_steps = 0\n",
"        self.reward_history = []\n",
"\n",
"        for episode in range(1, self.max_episodes + 1):\n",
"            state, _ = self.env.reset(seed=seed)\n",
"            state = self.state_preprocess(state, num_states=self.num_states)\n",
"            done = False\n",
"            truncation = False\n",
"            step_size = 0\n",
"            episode_reward = 0\n",
"\n",
"            while not done and not truncation:\n",
"                action = self.agent.select_action(state)\n",
"                next_state, reward, done, truncation = self._unpack_step(self.env.step(action))\n",
"                next_state = self.state_preprocess(next_state, num_states=self.num_states)\n",
"\n",
"                # Only termination is stored: a time-limit truncation must still\n",
"                # bootstrap from the value of the next state.\n",
"                self.agent.replay_memory.store(state, action, next_state, reward, done)\n",
"\n",
"                # Rewards are sparse, so learning starts once the goal has been reached at least once.\n",
"                if len(self.agent.replay_memory) > self.batch_size and sum(self.reward_history) > 0:\n",
"                    self.agent.learn(self.batch_size, (done or truncation))\n",
"\n",
"                state = next_state\n",
"                episode_reward += reward\n",
"                step_size += 1\n",
"                total_steps += 1\n",
"\n",
"                # Sync the target network every `update_frequency` environment steps.\n",
"                if total_steps % self.update_frequency == 0:\n",
"                    self.agent.hard_update()\n",
"\n",
"                # Safety net in case the env is not wrapped in a TimeLimit.\n",
"                if step_size >= self.max_steps:\n",
"                    truncation = True\n",
"\n",
"            self.reward_history.append(episode_reward)\n",
"\n",
"            # Decay epsilon at the end of each episode\n",
"            self.agent.update_epsilon()\n",
"\n",
"            # Interval checkpoint (and progress plot, except on the last episode)\n",
"            if episode % self.save_interval == 0:\n",
"                self.agent.save(self.save_path + '_' + f'{episode}' + '.pth')\n",
"                if episode != self.max_episodes:\n",
"                    self.plot_training(episode)\n",
"\n",
"        if self.reward_history:\n",
"            self.plot_training(self.max_episodes)\n",
"\n",
"    def test(self, max_episodes):\n",
"        \"\"\"\n",
"        Evaluate the checkpoint at RL_load_path for `max_episodes` episodes,\n",
"        printing the steps and reward of each one.\n",
"        \"\"\"\n",
"        self.agent.main_network.load_state_dict(torch.load(self.RL_load_path, map_location=device))\n",
"        self.agent.main_network.eval()\n",
"\n",
"        for episode in range(1, max_episodes + 1):\n",
"            state, _ = self.env.reset(seed=seed)\n",
"            done = False\n",
"            step_size = 0\n",
"            episode_reward = 0\n",
"\n",
"            while not done:\n",
"                state = self.state_preprocess(state, num_states=self.num_states)\n",
"                action = self.agent.select_action(state)\n",
"                next_state, reward, terminated, truncated = self._unpack_step(self.env.step(action))\n",
"                done = terminated or truncated\n",
"\n",
"                state = next_state\n",
"                episode_reward += reward\n",
"                step_size += 1\n",
"\n",
"            print(f\"Episode: {episode}, \"\n",
"                  f\"Steps: {step_size}, \"\n",
"                  f\"Reward: {episode_reward:.2f}\")\n",
"\n",
"        # Closing the env lets the video recorder finalise its file.\n",
"        self.env.close()\n",
"\n",
"    def plot_training(self, episode, window=50):\n",
"        \"\"\"\n",
"        Plot the raw episode rewards and their simple moving average over\n",
"        `window` episodes. The figure is saved to ./reward_plot.png on the final episode.\n",
"        \"\"\"\n",
"        clear_output(wait=True)\n",
"\n",
"        plt.figure()\n",
"        plt.title(\"Rewards\")\n",
"        plt.plot(self.reward_history, label='Raw Reward', color='#F6CE3B', alpha=1)\n",
"        # np.convolve(mode='valid') only yields meaningful values with at least `window` points\n",
"        if len(self.reward_history) >= window:\n",
"            sma = np.convolve(self.reward_history, np.ones(window) / window, mode='valid')\n",
"            # Align each average with the last episode of its window\n",
"            plt.plot(np.arange(window - 1, len(self.reward_history)), sma,\n",
"                     label=f'SMA {window}', color='#385DAA')\n",
"        plt.xlabel(\"Episode\")\n",
"        plt.ylabel(\"Rewards\")\n",
"        plt.legend()\n",
"        plt.grid(True)\n",
"        plt.tight_layout()\n",
"\n",
"        # Save only after the figure is complete, and only for the last episode\n",
"        if episode == self.max_episodes:\n",
"            plt.savefig('./reward_plot.png', format='png', dpi=600, bbox_inches='tight')\n",
"        plt.show()\n",
"        plt.close()"
]},
{"cell_type":"code","metadata":{"id":"eI1XZ3JIG1fM"},"execution_count":null,"outputs":[],"source":[
"# TODO: Set Hyperparameters\n",
"def generate_parameters(train_mode):\n",
"    \"\"\"Return the Model_TrainTest hyperparameter dict for training (True) or evaluation (False).\"\"\"\n",
"    render = not train_mode\n",
"    map_size = 8  # 4x4 or 8x8\n",
"    train_episodes = 600\n",
"    RL_hyperparams = {\n",
"        \"train_mode\"        : train_mode,\n",
"        # Checkpoint written at the last training episode (a multiple of save_interval)\n",
"        \"RL_load_path\"      : f'./final_weights_{train_episodes}.pth',\n",
"        \"save_path\"         : './final_weights',\n",
"        \"save_interval\"     : 50,\n",
"\n",
"        \"clip_grad_norm\"    : 1,\n",
"        \"learning_rate\"     : 3e-4,\n",
"        \"discount_factor\"   : 0.97,\n",
"        \"batch_size\"        : 128,\n",
"        \"update_frequency\"  : 4,\n",
"        \"max_episodes\"      : train_episodes if train_mode else 5,\n",
"        \"max_steps\"         : 200,\n",
"        \"render\"            : render,\n",
"\n",
"        \"epsilon_max\"       : 0.3 if train_mode else -1,  # -1: always greedy during evaluation\n",
"        \"epsilon_min\"       : 0.01,\n",
"        \"epsilon_decay\"     : 0.9998,\n",
"\n",
"        \"memory_capacity\"   : 1000 if train_mode else 0,\n",
"\n",
"        \"map_size\"          : map_size,\n",
"        \"num_states\"        : map_size ** 2,\n",
"        \"render_fps\"        : 6,\n",
"    }\n",
"    return RL_hyperparams"
]},
{"cell_type":"code","metadata":{"id":"wi9srHSzMnB0"},"execution_count":null,"outputs":[],"source":[
"RL_hyperparams = generate_parameters(train_mode=True)\n",
"DRL = Model_TrainTest(RL_hyperparams)\n",
"DRL.train()"
]},
{"cell_type":"code","metadata":{"id":"Kbr4XPHsJEe5"},"execution_count":null,"outputs":[],"source":[
"RL_hyperparams = generate_parameters(train_mode=False)\n",
"DRL = Model_TrainTest(RL_hyperparams)\n",
"DRL.test(max_episodes=1)\n",
"clear_output(wait=True)\n",
"# Show the recorded video\n",
"show_video('vid')"
]},
{"cell_type":"markdown","metadata":{"id":"mdPart20001"},"source":[
"## Part 2: REINFORCE with a learned baseline\n",
"\n",
"An actor-critic network is trained with Monte-Carlo policy gradients, using the value head as a baseline. This part uses the legacy `gym` API (`env.seed`, 4-tuple `step`) through the `legacy_gym` alias."
]},
{"cell_type":"code","metadata":{"id":"ok3JjDgDIC9F"},"execution_count":null,"outputs":[],"source":[
"EPS = 1e-12  # numerical guard when normalising advantages\n",
"SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]},
{"cell_type":"code","metadata":{"id":"s7PzaKayIix5"},"execution_count":null,"outputs":[],"source":[
"class Policy(nn.Module):\n",
"    \"\"\"\n",
"    Actor-critic model for REINFORCE with a baseline. A shared input layer feeds\n",
"    a policy head (action probabilities) and a value head (state value).\n",
"    \"\"\"\n",
"    def __init__(self, hid_dim=16, environment=None):\n",
"        \"\"\"\n",
"        Parameters:\n",
"            hid_dim (int): Width of every hidden layer.\n",
"            environment: Env whose discrete observation space (and action space)\n",
"                sizes the network; defaults to the notebook-level `env`.\n",
"        \"\"\"\n",
"        super(Policy, self).__init__()\n",
"        environment = env if environment is None else environment\n",
"\n",
"        # Extract the dimensionality of state and action spaces. The check is\n",
"        # duck-typed so it accepts both gym and gymnasium Discrete spaces.\n",
"        self.discrete = hasattr(environment.action_space, \"n\")\n",
"        self.observation_dim = environment.observation_space.n\n",
"        self.action_dim = environment.action_space.n if self.discrete else environment.action_space.shape[0]\n",
"        self.hid_dim = hid_dim\n",
"\n",
"        self.input_layer = nn.Sequential(nn.Linear(self.observation_dim, self.hid_dim), nn.ReLU())\n",
"        self.p_layer1 = nn.Sequential(nn.Linear(self.hid_dim, self.hid_dim), nn.ReLU())\n",
"        self.p_layer2 = nn.Linear(self.hid_dim, self.action_dim)\n",
"\n",
"        self.v_layers = nn.ModuleList(\n",
"            [nn.Sequential(nn.Linear(self.hid_dim, self.hid_dim), nn.ReLU())\n",
"             for _ in range(2)])\n",
"        self.v_output_layer = nn.Linear(self.hid_dim, 1)\n",
"\n",
"        # action & reward memory for the current episode\n",
"        self.saved_actions = []\n",
"        self.rewards = []\n",
"\n",
"    def forward(self, state):\n",
"        \"\"\"Return (action probabilities, state value) for a batch of encoded states.\"\"\"\n",
"        x = self.input_layer(state)\n",
"        action_prob = F.softmax(self.p_layer2(self.p_layer1(x)), dim=-1)\n",
"\n",
"        # The value layers are applied in sequence (each consumes the previous output).\n",
"        value = x\n",
"        for layer in self.v_layers:\n",
"            value = layer(value)\n",
"        state_value = self.v_output_layer(value)\n",
"\n",
"        return action_prob, state_value\n",
"\n",
"    def select_action(self, state):\n",
"        \"\"\"\n",
"        Sample an action for the discrete observation index `state`, recording\n",
"        its log-probability and the state value for the end-of-episode update.\n",
"        \"\"\"\n",
"        one_hot = torch.zeros(1, self.observation_dim, dtype=torch.float32, device=device)\n",
"        one_hot[0, state] = 1.0\n",
"\n",
"        action_prob, state_value = self.forward(one_hot)\n",
"        dist = Categorical(action_prob)  # convert to a distribution\n",
"        action = dist.sample()  # choose action from the distribution\n",
"\n",
"        self.saved_actions.append(SavedAction(dist.log_prob(action), state_value))\n",
"\n",
"        return action.item()\n",
"\n",
"    def saved_rewards(self, reward):\n",
"        \"\"\"Record the reward of the latest step.\"\"\"\n",
"        self.rewards.append(reward)\n",
"\n",
"    def calculate_loss(self, gamma=0.999):\n",
"        \"\"\"\n",
"        Loss for the stored episode: -sum_t A_t * log pi(a_t | s_t) + MSE(V(s_t), G_t).\n",
"\n",
"        G_t are the discounted returns, and A_t = G_t - V(s_t) (normalised) is the\n",
"        advantage. The baseline is detached inside the policy term, so only the\n",
"        MSE term trains the value head.\n",
"        \"\"\"\n",
"        # Discounted returns, computed backwards and kept in time order\n",
"        returns = []\n",
"        running_return = 0.0\n",
"        for reward in reversed(self.rewards):\n",
"            running_return = reward + gamma * running_return\n",
"            returns.insert(0, running_return)\n",
"        returns = torch.tensor(returns, dtype=torch.float32, device=device)\n",
"\n",
"        log_probs = torch.cat([saved.log_prob for saved in self.saved_actions])            # (T,)\n",
"        values = torch.cat([saved.value for saved in self.saved_actions]).squeeze(-1)      # (T,)\n",
"\n",
"        advantages = returns - values.detach()\n",
"        if advantages.numel() > 1:  # std of a single element is NaN\n",
"            advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)  # for stability\n",
"\n",
"        policy_loss = (advantages * log_probs).sum()\n",
"        value_loss = F.mse_loss(values, returns)\n",
"        loss = -policy_loss + value_loss\n",
"\n",
"        return loss\n",
"\n",
"    def clear_memory(self):\n",
"        \"\"\"Reset the reward and action buffers.\"\"\"\n",
"        del self.rewards[:]\n",
"        del self.saved_actions[:]"
]},
{"cell_type":"code","metadata":{"id":"gpDowsN7In9K"},"execution_count":null,"outputs":[],"source":[
"def train(args):\n",
"    \"\"\"\n",
"    Train the REINFORCE-with-baseline Policy on the global `env`.\n",
"\n",
"    The weights are saved to ./models/{args.env}_baseline.pth once the EWMA reward\n",
"    exceeds env.spec.reward_threshold or the last episode is reached.\n",
"    \"\"\"\n",
"    model = Policy(hid_dim=args.hid_dim).to(device)\n",
"    optimizer = optim.Adam(model.parameters(), lr=args.lr)\n",
"    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.9)\n",
"    max_steps = getattr(args, \"max_steps\", 200)  # per-episode step cap\n",
"    sma_window = 50\n",
"\n",
"    ewma_reward = 0  # EWMA reward for tracking the learning progress\n",
"    reward_history = []\n",
"    loss_history = []\n",
"\n",
"    for episode in range(args.episodes):\n",
"        # Reset environment and episode reward (legacy gym API: reset() returns the observation)\n",
"        state = env.reset()\n",
"        ep_reward = 0\n",
"\n",
"        for t in range(max_steps):\n",
"            action = model.select_action(state=state)\n",
"            state, reward, done, _ = env.step(action)\n",
"            model.saved_rewards(reward)\n",
"            ep_reward += reward\n",
"            if done:\n",
"                break\n",
"\n",
"        # One Monte-Carlo policy-gradient update per episode\n",
"        loss = model.calculate_loss(gamma=args.gamma)\n",
"        loss_history.append(loss.item())\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        scheduler.step()\n",
"\n",
"        model.clear_memory()\n",
"\n",
"        reward_history.append(ep_reward)\n",
"        ewma_reward = 0.05 * ep_reward + (1 - 0.05) * ewma_reward\n",
"\n",
"        if episode % 50 == 0:\n",
"            clear_output(wait=True)\n",
"            plt.figure()\n",
"            plt.plot(reward_history, label='Raw Reward', color='#F6CE3B', alpha=1)\n",
"            # np.convolve(mode='valid') only yields meaningful values with at least `sma_window` points\n",
"            if len(reward_history) >= sma_window:\n",
"                sma = np.convolve(reward_history, np.ones(sma_window) / sma_window, mode='valid')\n",
"                plt.plot(np.arange(sma_window - 1, len(reward_history)), sma, label='SMA 50', color='#385DAA')\n",
"            plt.title(\"Rewards\")\n",
"            plt.xlabel(\"Episode\")\n",
"            plt.ylabel(\"Rewards\")\n",
"            plt.legend()  # after plotting, so the labelled lines are included\n",
"            plt.grid(True)\n",
"            plt.tight_layout()\n",
"            plt.show()\n",
"            plt.close()\n",
"\n",
"        if ewma_reward > env.spec.reward_threshold or episode == args.episodes - 1:\n",
"            os.makedirs(\"./models\", exist_ok=True)\n",
"            torch.save(model.state_dict(), f\"./models/{args.env}_baseline.pth\")\n",
"            break"
]},
{"cell_type":"code","metadata":{"id":"7T4rbjqxIqzu"},"execution_count":null,"outputs":[],"source":[
"def test(model_name, max_episode_len=10000, video_folder=\"vid1\"):\n",
"    \"\"\"\n",
"    Roll out one episode of a saved Policy on the global `env`, record it to\n",
"    `video_folder`, and print the episode reward.\n",
"    \"\"\"\n",
"    model = Policy(hid_dim=args.hid_dim).to(device)\n",
"    model.load_state_dict(torch.load(f\"./models/{model_name}\", map_location=device))\n",
"    model.eval()\n",
"\n",
"    os.makedirs(video_folder, exist_ok=True)\n",
"    newenv = RecordVideo(env, video_folder)\n",
"\n",
"    state = newenv.reset()\n",
"    running_reward = 0\n",
"    with torch.no_grad():  # inference only; no autograd graph is needed\n",
"        for t in range(max_episode_len + 1):\n",
"            action = model.select_action(state)\n",
"            state, reward, done, info = newenv.step(action)\n",
"            running_reward += reward\n",
"            if done:\n",
"                break\n",
"    model.clear_memory()\n",
"    print(f\"Testing: Reward: {running_reward}\")\n",
"    # Closing the recorder finalises the video and also closes the wrapped env.\n",
"    newenv.close()"
]},
{"cell_type":"code","metadata":{"id":"_HjDshhmN44Q"},"execution_count":null,"outputs":[],"source":[
"def generate_args(lr, gamma, hidden_dims, step_size, max_steps=200):\n",
"    \"\"\"\n",
"    Build the REINFORCE run configuration as an argparse Namespace. An empty argv\n",
"    is parsed, so inside a notebook only the defaults below apply.\n",
"\n",
"    step_size is the StepLR period (in episodes); max_steps caps the length of one episode.\n",
"    \"\"\"\n",
"    parser = argparse.ArgumentParser(\"REINFORCE algorithm using baseline\")\n",
"    parser.add_argument(\"--env\", type=str, default=\"FrozenLake-v1\", help=\"Name of the environment\")\n",
"    parser.add_argument(\"--seed\", type=int, default=10, help=\"Random seed\")\n",
"    parser.add_argument(\"--lr\", type=float, default=lr, help=\"Learning rate\")\n",
"    parser.add_argument(\"--step_size\", type=int, default=step_size, help=\"Step size for lr scheduler\")\n",
"    parser.add_argument(\"--max_steps\", type=int, default=max_steps, help=\"Maximum number of steps per episode\")\n",
"    parser.add_argument(\"--episodes\", type=int, default=1000, help=\"Number of episodes for training\")\n",
"    parser.add_argument(\"--gamma\", type=float, default=gamma, help=\"Discount factor\")\n",
"    parser.add_argument(\"--hid_dim\", type=int, default=hidden_dims, help=\"Hidden dimension of the policy network\")\n",
"\n",
"    args = parser.parse_args(args=[])\n",
"    return args"
]},
{"cell_type":"code","metadata":{"id":"FKtxlQ2RlokX"},"execution_count":null,"outputs":[],"source":[
"# TODO: Set Hyperparameters\n",
"lr = 2e-4                # Learning rate\n",
"hidden_dims = 64         # Hidden layer width\n",
"gamma = 0.998            # Discount factor\n",
"lr_step_size = 200       # StepLR period (episodes between learning-rate decays)\n",
"max_episode_steps = 200  # Maximum number of steps in one episode\n",
"\n",
"args = generate_args(lr=lr, hidden_dims=hidden_dims, gamma=gamma,\n",
"                     step_size=lr_step_size, max_steps=max_episode_steps)\n",
"random_seed = args.seed\n",
"env = legacy_gym.make(args.env, map_name=\"8x8\", is_slippery=False,\n",
"                      max_episode_steps=args.max_steps, render_mode=\"rgb_array\")\n",
"\n",
"env.seed(random_seed)\n",
"torch.manual_seed(random_seed)\n",
"\n",
"train(args)"
]},
{"cell_type":"code","metadata":{"id":"8YSyRZUYLk4j"},"execution_count":null,"outputs":[],"source":[
"test(f'{args.env}_baseline.pth')\n",
"clear_output(wait=True)\n",
"# Show the recorded video\n",
"show_video('vid1')"
]}
]}