{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-19-ac.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T043789%20%7C%20Training%20RL%20Agent%20in%20CartPole%20Environment%20with%20Actor-Critic%20method.ipynb","timestamp":1644651718134}],"collapsed_sections":[],"authorship_tag":"ABX9TyMqmVBsoiHiDw0f6yiNuEgc"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"xFAAnaUK9y0n"},"source":["# Training RL Agent in CartPole Environment with Actor-Critic method"]},{"cell_type":"code","metadata":{"id":"7QLHLqC1_ivD"},"source":["import numpy as np\n","import tensorflow as tf\n","import gym\n","import tensorflow_probability as tfp"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vLAh4L1h_j0e"},"source":["class ActorCritic(tf.keras.Model):\n"," def __init__(self, action_dim):\n"," super().__init__()\n"," self.fc1 = tf.keras.layers.Dense(512, activation=\"relu\")\n"," self.fc2 = tf.keras.layers.Dense(128, activation=\"relu\")\n"," self.critic = tf.keras.layers.Dense(1, activation=None)\n"," self.actor = tf.keras.layers.Dense(action_dim, activation=None)\n","\n"," def call(self, input_data):\n"," x = self.fc1(input_data)\n"," x1 = self.fc2(x)\n"," actor = self.actor(x1)\n"," critic = self.critic(x1)\n"," return critic, actor"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"QReTNLPx_lWZ"},"source":["class Agent:\n"," def __init__(self, action_dim=4, gamma=0.99):\n"," \"\"\"Agent with a neural-network brain powered policy\n","\n"," Args:\n"," action_dim (int): Action dimension\n"," gamma (float) : Discount factor. Default=0.99\n"," \"\"\"\n","\n"," self.gamma = gamma\n"," self.opt = tf.keras.optimizers.Adam(learning_rate=1e-4)\n"," self.actor_critic = ActorCritic(action_dim)\n","\n"," def get_action(self, state):\n"," _, action_probabilities = self.actor_critic(np.array([state]))\n"," action_probabilities = tf.nn.softmax(action_probabilities)\n"," action_probabilities = action_probabilities.numpy()\n"," dist = tfp.distributions.Categorical(\n"," probs=action_probabilities, dtype=tf.float32\n"," )\n"," action = dist.sample()\n"," return int(action.numpy()[0])\n","\n"," def actor_loss(self, prob, action, td):\n"," prob = tf.nn.softmax(prob)\n"," dist = tfp.distributions.Categorical(probs=prob, dtype=tf.float32)\n"," log_prob = dist.log_prob(action)\n"," loss = -log_prob * td\n"," return loss\n","\n"," def learn(self, state, action, reward, next_state, done):\n"," state = np.array([state])\n"," next_state = np.array([next_state])\n","\n"," with tf.GradientTape() as tape:\n"," value, action_probabilities = self.actor_critic(state, training=True)\n"," value_next_st, _ = self.actor_critic(next_state, training=True)\n"," td = reward + self.gamma * value_next_st * (1 - int(done)) - value\n"," actor_loss = self.actor_loss(action_probabilities, action, td)\n"," critic_loss = td ** 2\n"," total_loss = actor_loss + critic_loss\n"," grads = tape.gradient(total_loss, self.actor_critic.trainable_variables)\n"," self.opt.apply_gradients(zip(grads, self.actor_critic.trainable_variables))\n"," return total_loss"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"pnpydnWc_orT"},"source":["def train(agent, env, episodes, render=True):\n"," \"\"\"Train `agent` in `env` for `episodes`\n","\n"," Args:\n"," agent (Agent): Agent to train\n"," env (gym.Env): Environment to train the agent\n"," episodes (int): 
{"cell_type":"code","metadata":{"id":"pnpydnWc_orT"},"source":["def train(agent, env, episodes, render=True):\n","    \"\"\"Train `agent` in `env` for `episodes` episodes\n","\n","    Args:\n","        agent (Agent): Agent to train\n","        env (gym.Env): Environment to train the agent in\n","        episodes (int): Number of episodes to train for\n","        render (bool): Whether to render the environment. Default=True\n","    \"\"\"\n","\n","    for episode in range(episodes):\n","\n","        done = False\n","        state = env.reset()\n","        total_reward = 0\n","        all_loss = []  # per-step losses, collected for optional inspection\n","\n","        while not done:\n","            action = agent.get_action(state)\n","            next_state, reward, done, _ = env.step(action)\n","            loss = agent.learn(state, action, reward, next_state, done)\n","            all_loss.append(loss)\n","            state = next_state\n","            total_reward += reward\n","            if render:\n","                env.render()\n","        print(f\"Episode#:{episode} ep_reward:{total_reward}\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_jg1Vo-N_paC","executionInfo":{"status":"ok","timestamp":1638446075700,"user_tz":-330,"elapsed":3908,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"0425833e-6475-4ab0-cb46-a4d1312707fa"},"source":["if __name__ == \"__main__\":\n","\n","    env = gym.make(\"CartPole-v0\")\n","    agent = Agent(env.action_space.n)\n","    num_episodes = 10  # Increase the number of episodes for longer training\n","    # Set render=True to visualize the agent's actions in the environment\n","    train(agent, env, num_episodes, render=False)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Episode#:0 ep_reward:10.0\n","Episode#:1 ep_reward:28.0\n","Episode#:2 ep_reward:10.0\n","Episode#:3 ep_reward:17.0\n","Episode#:4 ep_reward:32.0\n","Episode#:5 ep_reward:32.0\n","Episode#:6 ep_reward:15.0\n","Episode#:7 ep_reward:37.0\n","Episode#:8 ep_reward:10.0\n","Episode#:9 ep_reward:21.0\n"]}]},{"cell_type":"markdown","metadata":{"id":"XmPZbFhL_7fN"},"source":["---"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"xxO4q8Uo_sW-","executionInfo":{"status":"ok","timestamp":1638446126199,"user_tz":-330,"elapsed":3640,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"f8b10262-5828-438c-d30d-812dd45bc602"},"source":["!pip install -q watermark\n","%reload_ext watermark\n","%watermark -a \"Sparsh A.\" -m -iv -u -t -d"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Author: Sparsh A.\n","\n","Last updated: 2021-12-02 11:55:28\n","\n","Compiler : GCC 7.5.0\n","OS : Linux\n","Release : 5.4.104+\n","Machine : x86_64\n","Processor : x86_64\n","CPU cores : 2\n","Architecture: 64bit\n","\n","gym : 0.17.3\n","IPython : 5.5.0\n","tensorflow_probability: 0.15.0\n","tensorflow : 2.7.0\n","numpy : 1.19.5\n","\n"]}]},{"cell_type":"markdown","metadata":{"id":"tlR6JySQ_8ta"},"source":["---"]},{"cell_type":"markdown","metadata":{"id":"NZYSPhLS_8rO"},"source":["**END**"]}]}