{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-21-gridworld-td.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T122762%20%7C%20Training%20RL%20Agent%20in%20Gridworld%20with%20Temporal%20Difference%20learning%20method.ipynb","timestamp":1644659399764}],"collapsed_sections":[],"authorship_tag":"ABX9TyObH3bb/z+IhSDM8ZLwKd4r"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[
{"cell_type":"markdown","metadata":{"id":"bmIy05Cqwt3X"},"source":["# Training RL Agent in Gridworld with Temporal Difference learning method\n",
"\n",
"Estimates the state-value function of a uniform-random policy in a 3x4 gridworld with tabular TD(0) and plots it as a heatmap."]},
{"cell_type":"markdown","metadata":{"id":"CeesDf_Xupst"},"source":["## Imports"]},
{"cell_type":"code","metadata":{"id":"qCl7x_aFuhxe"},"source":["import gym\n","import numpy as np\n","import matplotlib.pyplot as plt\n","import seaborn as sns"],"execution_count":null,"outputs":[]},
{"cell_type":"markdown","metadata":{"id":"tR76MDYoujcE"},"source":["## Gridworld"]},
{"cell_type":"code","metadata":{"id":"PTl1sTu6ujZe"},"source":[
"class GridworldV2Env(gym.Env):\n",
"    \"\"\"3x4 gridworld with a goal, a bomb and one wall cell.\n",
"\n",
"    States are indexed 0..11 row-major over the (3, 4) grid:\n",
"\n",
"         0  1  2  3      3 = goal (terminal)\n",
"         4  5  6  7      7 = bomb (terminal)\n",
"         8  9 10 11      5 = wall (cannot be entered)\n",
"\n",
"    Actions: 0 = UP, 1 = RIGHT, 2 = DOWN, 3 = LEFT. A move that would leave\n",
"    the grid or enter the wall keeps the agent in place. Entering the goal\n",
"    gives +1 and the bomb -1 (both end the episode); any other move gives\n",
"    ``step_cost``. Episodes are also cut off after ``max_ep_length`` steps.\n",
"\n",
"    Note: ``step`` returns a 3-tuple ``(next_state, reward, done)`` rather\n",
"    than gym's usual 4-tuple with an ``info`` dict.\n",
"    \"\"\"\n",
"\n",
"    # (row, col) offset of each action; unknown actions are treated as no-ops.\n",
"    _ACTION_DELTAS = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}\n",
"\n",
"    def __init__(self, step_cost=-0.2, max_ep_length=500, explore_start=False):\n",
"        self.grid_shape = (3, 4)\n",
"        n_rows, n_cols = self.grid_shape\n",
"        n_states = n_rows * n_cols\n",
"        # Row-major index <-> coordinate maps: \"0\" -> [0, 0], ..., \"11\" -> [2, 3].\n",
"        self.index_to_coordinate_map = {\n",
"            str(i): list(divmod(i, n_cols)) for i in range(n_states)\n",
"        }\n",
"        self.coordinate_to_index_map = {\n",
"            str(val): int(key) for key, val in self.index_to_coordinate_map.items()\n",
"        }\n",
"        self.distinct_states = [str(i) for i in range(n_states)]\n",
"        # One discrete observation per cell (was Discrete(1), which did not\n",
"        # match the 12 state indices that reset/step actually return).\n",
"        self.observation_space = gym.spaces.Discrete(n_states)\n",
"        self.goal_coordinate = [0, 3]\n",
"        self.bomb_coordinate = [1, 3]\n",
"        self.wall_coordinate = [1, 1]\n",
"        self.goal_state = self.coordinate_to_index_map[str(self.goal_coordinate)]  # 3\n",
"        self.bomb_state = self.coordinate_to_index_map[str(self.bomb_coordinate)]  # 7\n",
"        self.wall_state = self.coordinate_to_index_map[str(self.wall_coordinate)]  # 5\n",
"        self.map = self._build_map()\n",
"\n",
"        self.exploring_starts = explore_start\n",
"        self.state = 8\n",
"        self.done = False\n",
"        self.max_ep_length = max_ep_length\n",
"        self.steps = 0\n",
"        self.step_cost = step_cost\n",
"        self.action_space = gym.spaces.Discrete(4)\n",
"        self.action_map = {\"UP\": 0, \"RIGHT\": 1, \"DOWN\": 2, \"LEFT\": 3}\n",
"        self.possible_actions = list(self.action_map.values())\n",
"\n",
"    def _build_map(self):\n",
"        \"\"\"Return a fresh grid map: 1 = goal, -1 = bomb, 2 = wall, 0 = free.\"\"\"\n",
"        grid = np.zeros(self.grid_shape)\n",
"        grid[tuple(self.goal_coordinate)] = 1\n",
"        grid[tuple(self.bomb_coordinate)] = -1\n",
"        grid[tuple(self.wall_coordinate)] = 2\n",
"        return grid\n",
"\n",
"    def reset(self):\n",
"        \"\"\"Start a new episode and return the initial state index.\"\"\"\n",
"        self.done = False\n",
"        self.steps = 0\n",
"        self.map = self._build_map()\n",
"\n",
"        if self.exploring_starts:\n",
"            # Any cell except the goal, the bomb and the wall,\n",
"            # i.e. [0, 1, 2, 4, 6, 8, 9, 10, 11].\n",
"            start_states = [\n",
"                s\n",
"                for s in range(len(self.distinct_states))\n",
"                if s not in (self.goal_state, self.bomb_state, self.wall_state)\n",
"            ]\n",
"            self.state = np.random.choice(start_states)\n",
"        else:\n",
"            self.state = 8\n",
"        return self.state\n",
"\n",
"    def get_next_state(self, current_position, action):\n",
"        \"\"\"Return the state index reached by taking `action` from `current_position`.\n",
"\n",
"        Moves off the grid or into the wall (and unknown actions) leave the\n",
"        agent where it is. Does not modify the environment.\n",
"        \"\"\"\n",
"        row, col = self.index_to_coordinate_map[str(current_position)]\n",
"        d_row, d_col = self._ACTION_DELTAS.get(action, (0, 0))\n",
"        next_coordinate = [row + d_row, col + d_col]\n",
"        n_rows, n_cols = self.grid_shape\n",
"        inside_grid = 0 <= next_coordinate[0] < n_rows and 0 <= next_coordinate[1] < n_cols\n",
"        if not inside_grid or next_coordinate == self.wall_coordinate:\n",
"            next_coordinate = [row, col]\n",
"        return self.coordinate_to_index_map[str(next_coordinate)]\n",
"\n",
"    def step(self, action):\n",
"        \"\"\"Apply `action` and return `(next_state, reward, done)`.\n",
"\n",
"        `done` is True when the goal or bomb is reached, or once\n",
"        `max_ep_length` steps have been taken in this episode.\n",
"        \"\"\"\n",
"        assert action in self.possible_actions, f\"Invalid action:{action}\"\n",
"\n",
"        next_state = self.get_next_state(self.state, action)\n",
"        self.steps += 1\n",
"\n",
"        if next_state == self.goal_state:\n",
"            reward = 1\n",
"            self.done = True\n",
"        elif next_state == self.bomb_state:\n",
"            reward = -1\n",
"            self.done = True\n",
"        else:\n",
"            reward = self.step_cost\n",
"\n",
"        if self.steps == self.max_ep_length:\n",
"            # Time-limit cut-off: next_state need not be terminal here.\n",
"            self.done = True\n",
"\n",
"        self.state = next_state\n",
"        return next_state, reward, self.done"
],"execution_count":null,"outputs":[]},
{"cell_type":"markdown","metadata":{"id":"PMeLTRHPujWT"},"source":["## Visualization function"]},
{"cell_type":"code","metadata":{"id":"4wfzk8Kiu28Q"},"source":[
"def visualize_grid_state_values(grid_state_values, title=\"State-value function V(s)\"):\n",
"    \"\"\"Visualize a 2-D array of state values as an annotated heatmap.\n",
"\n",
"    Args:\n",
"        grid_state_values: 2-D array with one value per grid cell.\n",
"        title: figure title.\n",
"    \"\"\"\n",
"    _, ax = plt.subplots(figsize=(10, 5))\n",
"    sns.heatmap(\n",
"        grid_state_values,\n",
"        cmap=\"Greens\",\n",
"        annot=True,\n",
"        fmt=\".1f\",\n",
"        annot_kws={\"size\": 16},\n",
"        square=True,\n",
"        ax=ax,\n",
"    )\n",
"    # Explicit y-limits keep the first/last rows from being cropped\n",
"    # (presumably a workaround for the matplotlib 3.1.1 heatmap-cropping issue).\n",
"    ax.set_ylim(len(grid_state_values) + 0.01, -0.01)\n",
"    ax.set_title(title)\n",
"    plt.show()"
],"execution_count":null,"outputs":[]},
{"cell_type":"markdown","metadata":{"id":"JrwdkXQwu3tM"},"source":["## Temporal Difference learning method\n",
"\n",
"Tabular TD(0) evaluation of a uniform-random policy: after each transition $(s, r, s')$,\n",
"\n",
"$$V(s) \\leftarrow V(s) + \\alpha \\left[r + \\gamma V(s') - V(s)\\right],$$\n",
"\n",
"where $V(s') = 0$ when $s'$ is terminal (goal or bomb)."]},
{"cell_type":"code","metadata":{"id":"2dlmJ4BAu76v"},"source":[
"def temporal_difference_learning(env, max_episodes, gamma=0.99, alpha=0.01, visualize=True):\n",
"    \"\"\"Estimate the state values of a uniform-random policy with tabular TD(0).\n",
"\n",
"    Args:\n",
"        env: a GridworldV2Env (uses reset/step, action_space, distinct_states,\n",
"            goal_state, bomb_state and map).\n",
"        max_episodes: number of episodes to learn from.\n",
"        gamma: discount factor.\n",
"        alpha: learning rate (TD step size).\n",
"        visualize: if True, plot the learned values as a heatmap.\n",
"\n",
"    Returns:\n",
"        Array of shape (len(env.distinct_states), 1) with the estimated values.\n",
"        The goal/bomb entries stay at +1/-1 for display only; they are never\n",
"        used as bootstrap targets.\n",
"    \"\"\"\n",
"    terminal_states = {env.goal_state, env.bomb_state}\n",
"    # v: state-value function\n",
"    v = np.zeros((len(env.distinct_states), 1))\n",
"    v[env.goal_state] = 1\n",
"    v[env.bomb_state] = -1\n",
"\n",
"    for _ in range(max_episodes):\n",
"        state = env.reset()\n",
"        done = False\n",
"        while not done:\n",
"            action = env.action_space.sample()  # random policy\n",
"            next_state, reward, done = env.step(action)\n",
"\n",
"            # TD(0) target. A terminal state's value is 0, so on reaching the\n",
"            # goal/bomb the target is just the reward (bootstrapping from the\n",
"            # +1/-1 display values would count that reward twice). When the\n",
"            # episode only stops at the step limit, next_state is not terminal\n",
"            # and we still bootstrap from it.\n",
"            if next_state in terminal_states:\n",
"                td_target = reward\n",
"            else:\n",
"                td_target = reward + gamma * v[next_state]\n",
"            v[state] += alpha * (td_target - v[state])\n",
"            state = next_state\n",
"\n",
"    if visualize:\n",
"        visualize_grid_state_values(v.reshape(env.map.shape))\n",
"    return v"
],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":324},"id":"6ecXS55wu-r6","executionInfo":{"status":"ok","timestamp":1638441777337,"user_tz":-330,"elapsed":1941,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"02df2eda-3cce-4977-888d-e0c681f7638f"},"source":[
"if __name__ == \"__main__\":\n",
"    SEED = 0  # fixes the random policy's action sequence so reruns match\n",
"    max_episodes = 4000\n",
"    env = GridworldV2Env(step_cost=-0.1, max_ep_length=30)\n",
"    np.random.seed(SEED)  # used by reset() when explore_start=True\n",
"    env.action_space.seed(SEED)  # the random policy samples from action_space\n",
"    temporal_difference_learning(env, max_episodes)"
],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"6F1CvviNvf0U"},"source":["---"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"eY5_ri7ovC5d","executionInfo":{"status":"ok","timestamp":1638441822532,"user_tz":-330,"elapsed":3730,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"f3416f07-e18a-4b51-d436-2948f1f20495"},"source":["!pip install -q watermark\n","%reload_ext watermark\n","%watermark -a \"Sparsh A.\" -m -iv -u -t -d"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Author: Sparsh A.\n","\n","Last updated: 2021-12-02 10:43:44\n","\n","Compiler : GCC 7.5.0\n","OS : Linux\n","Release : 5.4.104+\n","Machine : x86_64\n","Processor : x86_64\n","CPU cores : 2\n","Architecture: 64bit\n","\n","gym : 0.17.3\n","matplotlib: 3.2.2\n","numpy : 1.19.5\n","seaborn : 0.11.2\n","IPython : 5.5.0\n","\n"]}]},{"cell_type":"markdown","metadata":{"id":"eQZmkNqovgsp"},"source":["---"]},{"cell_type":"markdown","metadata":{"id":"Jpw59lINviPG"},"source":["**END**"]}]}