{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "2022-01-24-apprentice-mountaincar.ipynb",
   "provenance": [
    {
     "file_id": "https://github.com/recohut/nbs/blob/main/raw/T250391%20%7C%20Apprenticeship%20Learning%20in%20Mountaincar%20Environment.ipynb",
     "timestamp": 1644669408947
    },
    {
     "file_id": "1K1DpwKNrsmvK-CKDvfOPXt4UfktdSABH",
     "timestamp": 1636606729887
    }
   ],
   "collapsed_sections": [],
   "toc_visible": true,
   "authorship_tag": "ABX9TyPVuNwWSQKASm9LSWCS0n7f"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "b7tF9-oSdIEv"
   },
   "source": [
    "# Apprenticeship Learning in the MountainCar Environment\n",
    "\n",
    "This notebook trains a tabular Q-learning agent on `MountainCar-v0` with apprenticeship learning via inverse reinforcement learning (Abbeel & Ng, 2004). The reward is linear in Gaussian state features, $r(s) = w^\\top \\phi(s)$, and the weights $w$ are periodically re-fit by a quadratic program: the smallest-norm $w$ such that the expert's discounted feature expectations $\\mu_E$ exceed every learner policy's $\\mu_j$ by a margin, $(\\mu_E - \\mu_j)^\\top w \\ge 2$.\n",
    "\n",
    "The expert demonstrations (`expert_demo.npy`) come from the [lets-do-irl](https://github.com/reinforcement-learning-kr/lets-do-irl) repository."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pV5flWpeVM6M"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VdBAv8vWVM36"
   },
   "source": [
    "### Installations"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "y1GyS_RCTdvg"
   },
   "source": [
    "# Virtual display + ffmpeg so gym can render and record video inside Colab.\n",
    "!apt-get update > /dev/null 2>&1\n",
    "!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1\n",
    "\n",
    "# The code below uses the classic gym API (`gym.wrappers.Monitor`, `env.seed`, 4-tuple `env.step`),\n",
    "# which later gym releases changed or removed, so pin a release from that era.\n",
    "%pip install -q \"gym==0.21.0\" pyvirtualdisplay readchar cvxpy"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7ZVxXGqdVLeL"
   },
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "gXzsjx1rS4m5"
   },
   "source": [
    "import base64\n",
    "import glob\n",
    "import os\n",
    "\n",
    "import cvxpy as cp\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import readchar\n",
    "from gym.wrappers import Monitor\n",
    "from IPython import display as ipythondisplay\n",
    "from IPython.display import HTML\n",
    "from pyvirtualdisplay import Display"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G1lcBr_bVR6g"
   },
   "source": [
    "### Gym render"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "rsZZr8XaTdvh"
   },
   "source": [
    "# Headless X display so env.render() works without a physical screen.\n",
    "virtual_display = Display(visible=0, size=(1400, 900))\n",
    "virtual_display.start()\n",
    "\n",
    "\n",
    "def show_video():\n",
    "    \"\"\"Embed the most recently written MP4 under ./video in the notebook output.\n",
    "\n",
    "    Videos are produced by environments wrapped with `wrap_env`.\n",
    "    \"\"\"\n",
    "    mp4list = glob.glob('video/*.mp4')\n",
    "    if not mp4list:\n",
    "        print(\"Could not find video\")\n",
    "        return\n",
    "    with open(max(mp4list, key=os.path.getmtime), 'rb') as f:\n",
    "        encoded = base64.b64encode(f.read()).decode('ascii')\n",
    "    ipythondisplay.display(HTML(data='''<video alt=\"test\" autoplay loop controls style=\"height: 400px;\">\n",
    "  <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\" />\n",
    "</video>'''.format(encoded)))\n",
    "\n",
    "\n",
    "def wrap_env(env):\n",
    "    \"\"\"Wrap `env` in a gym Monitor that records videos into ./video (force=True overwrites earlier output).\"\"\"\n",
    "    return Monitor(env, './video', force=True)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XH0uQ0dGVJ-8"
   },
   "source": [
    "### Params"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "S7Ut5RnVTZwG"
   },
   "source": [
    "# MountainCar-v0 action ids.\n",
    "PUSH_LEFT = 0\n",
    "NO_PUSH = 1\n",
    "PUSH_RIGHT = 2\n",
    "\n",
    "# Arrow-key escape sequences returned by readchar -> action (for recording demos by hand).\n",
    "ARROW_KEYS = {\n",
    "    '\\x1b[D': PUSH_LEFT,   # left arrow\n",
    "    '\\x1b[B': NO_PUSH,     # down arrow\n",
    "    '\\x1b[C': PUSH_RIGHT,  # right arrow\n",
    "}"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "MTzgzJ-SUOt8"
   },
   "source": [
    "N_BINS = 20               # discretization bins per state dimension (position, velocity)\n",
    "N_STATES = N_BINS ** 2    # 400 discrete states\n",
    "N_ACTIONS = 3\n",
    "FEATURE_NUM = 4           # Gaussian reward features: half for position, half for velocity\n",
    "GAMMA = 0.99              # discount for Q-learning and for feature expectations\n",
    "Q_LEARNING_RATE = 0.03\n",
    "\n",
    "N_TRAIN_EPISODES = 60000\n",
    "LOG_EVERY = 1000          # print progress and checkpoint the Q-table\n",
    "QP_EVERY = 5000           # re-fit the reward weights\n",
    "\n",
    "SEED = 1\n",
    "np.random.seed(SEED)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Dr1OTNgjVHV4"
   },
   "source": [
    "## Expert Demo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "expertDemoNote01"
   },
   "source": [
    "The demonstrations were recorded by hand with the arrow keys: 20 trajectories, each step stored as `(position, velocity, action)`. `record_expert_demo` shows how they were recorded. It reads raw key presses, which generally does not work inside a notebook, so the pre-recorded file is downloaded below instead."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "r7hM2cStTyjf"
   },
   "source": [
    "def record_expert_demo(n_trajectories=20, trajectory_length=130, path=\"expert_demo\"):\n",
    "    \"\"\"Record expert trajectories by driving MountainCar-v0 with the arrow keys.\n",
    "\n",
    "    Each step appends (position, velocity, action), using the state returned by `env.step`.\n",
    "    A trajectory ends when a non-arrow key is pressed, or when the car is at the goal after\n",
    "    at least `trajectory_length` steps have been recorded. All trajectories are saved with\n",
    "    `np.save(path, ...)` and returned; `np.array` needs them to have equal lengths.\n",
    "    \"\"\"\n",
    "    env = wrap_env(gym.make(\"MountainCar-v0\"))\n",
    "    trajectories = []\n",
    "\n",
    "    for episode in range(n_trajectories):\n",
    "        trajectory = []\n",
    "        step = 0\n",
    "        env.reset()\n",
    "        print(\"episode_step\", episode)\n",
    "\n",
    "        while True:\n",
    "            env.render()\n",
    "            print(\"step\", step)\n",
    "\n",
    "            key = readchar.readkey()\n",
    "            if key not in ARROW_KEYS:\n",
    "                break\n",
    "\n",
    "            action = ARROW_KEYS[key]\n",
    "            state, reward, done, _ = env.step(action)\n",
    "\n",
    "            if state[0] >= env.env.goal_position and step >= trajectory_length:\n",
    "                break\n",
    "\n",
    "            trajectory.append((state[0], state[1], action))\n",
    "            step += 1\n",
    "\n",
    "        print(\"trajectory_numpy.shape\", np.array(trajectory, float).shape)\n",
    "        trajectories.append(trajectory)\n",
    "\n",
    "    env.close()\n",
    "    np_trajectories = np.array(trajectories, float)\n",
    "    print(\"np_trajectories.shape\", np_trajectories.shape)\n",
    "    np.save(path, arr=np_trajectories)\n",
    "    return np_trajectories"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "YFF2azikVZi9"
   },
   "source": [
    "# -nc: skip the download if the file is already present, so re-runs do not create expert_demo.npy.1.\n",
    "!wget -nc -q --show-progress https://github.com/reinforcement-learning-kr/lets-do-irl/raw/master/mountaincar/app/expert_demo/expert_demo.npy"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UzzcDLaaWB0m"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Rn5btkVIYHT0"
   },
   "source": [
    "def idx_state(env, state, n_bins=N_BINS):\n",
    "    \"\"\"Map a continuous (position, velocity) observation to a discrete Q-table row.\n",
    "\n",
    "    Each dimension is split into `n_bins` equal-width bins over the observation space, and the\n",
    "    index is `position_bin + velocity_bin * n_bins`, always in range(n_bins ** 2).\n",
    "    \"\"\"\n",
    "    env_low = env.observation_space.low\n",
    "    env_high = env.observation_space.high\n",
    "    bin_width = (env_high - env_low) / n_bins\n",
    "    # Clip so an observation exactly on the upper bound (e.g. velocity at its maximum) falls in the\n",
    "    # last bin instead of producing an index past the end of the Q-table.\n",
    "    bins = np.clip(((np.asarray(state) - env_low) / bin_width).astype(int), 0, n_bins - 1)\n",
    "    return int(bins[0] + bins[1] * n_bins)\n",
    "\n",
    "\n",
    "def update_q_table(q_table, state, action, reward, next_state, gamma=GAMMA, learning_rate=Q_LEARNING_RATE):\n",
    "    \"\"\"Apply one tabular Q-learning update to `q_table[state, action]` in place.\n",
    "\n",
    "    `state` and `next_state` are row indices produced by `idx_state`.\n",
    "    \"\"\"\n",
    "    td_target = reward + gamma * np.max(q_table[next_state])\n",
    "    q_table[state, action] += learning_rate * (td_target - q_table[state, action])"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "0bk-oKZAYQhd"
   },
   "source": [
    "class FeatureEstimate:\n",
    "    \"\"\"Gaussian (unit-width) features of a MountainCar state, used as the basis of the IRL reward.\"\"\"\n",
    "\n",
    "    def __init__(self, feature_num, env):\n",
    "        self.env = env\n",
    "        self.feature_num = feature_num\n",
    "\n",
    "    def gaussian_function(self, x, mu):\n",
    "        \"\"\"Unnormalized Gaussian exp(-(x - mu)^2 / 2), i.e. standard deviation 1.\"\"\"\n",
    "        return np.exp(-np.power(x - mu, 2.) / (2 * np.power(1., 2.)))\n",
    "\n",
    "    def get_features(self, state):\n",
    "        \"\"\"Return a new array of `feature_num` features for `state` = (position, velocity, ...).\n",
    "\n",
    "        With half = feature_num // 2, features [0, half) are Gaussians of the position and\n",
    "        [half, 2 * half) Gaussians of the velocity, centred at `low + i * (high - low) / (feature_num - 1)`\n",
    "        for i in range(half). A trailing feature (odd feature_num) stays 1.\n",
    "        \"\"\"\n",
    "        # NOTE(review): with unit width and MountainCar's small velocity range, the velocity\n",
    "        # features are nearly constant; a per-dimension width may be worth trying.\n",
    "        env_low = self.env.observation_space.low\n",
    "        env_high = self.env.observation_space.high\n",
    "        env_distance = (env_high - env_low) / (self.feature_num - 1)\n",
    "        half = self.feature_num // 2\n",
    "\n",
    "        # A fresh array per call, so features held by a caller are never overwritten by a later call.\n",
    "        features = np.ones(self.feature_num)\n",
    "        for i in range(half):\n",
    "            features[i] = self.gaussian_function(state[0], env_low[0] + i * env_distance[0])\n",
    "            features[i + half] = self.gaussian_function(state[1], env_low[1] + i * env_distance[1])\n",
    "        return features"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "featExpect01"
   },
   "source": [
    "def calc_feature_expectation(feature_num, gamma, q_table, demonstrations, env):\n",
    "    \"\"\"Discounted feature expectations of the greedy policy of `q_table`, averaged over episodes.\n",
    "\n",
    "    Runs one greedy episode per demonstration and accumulates sum_t gamma**t * phi(s_{t+1}), where\n",
    "    s_{t+1} is the state returned by `env.step` at step t. This matches `expert_feature_expectation`,\n",
    "    whose demonstration rows are also post-step states weighted from gamma**0.\n",
    "    \"\"\"\n",
    "    feature_estimate = FeatureEstimate(feature_num, env)\n",
    "    feature_expectations = np.zeros(feature_num)\n",
    "    demo_num = len(demonstrations)\n",
    "\n",
    "    for _ in range(demo_num):\n",
    "        state = env.reset()\n",
    "        done = False\n",
    "        t = 0\n",
    "        while not done:\n",
    "            action = np.argmax(q_table[idx_state(env, state)])\n",
    "            next_state, reward, done, _ = env.step(action)\n",
    "            feature_expectations += gamma ** t * feature_estimate.get_features(next_state)\n",
    "            state = next_state\n",
    "            t += 1\n",
    "\n",
    "    return feature_expectations / demo_num\n",
    "\n",
    "\n",
    "def expert_feature_expectation(feature_num, gamma, demonstrations, env):\n",
    "    \"\"\"Discounted feature expectations of the expert, averaged over the demonstrations.\n",
    "\n",
    "    Each demonstration is a sequence of (position, velocity, action) rows; row t is weighted by gamma**t.\n",
    "    \"\"\"\n",
    "    feature_estimate = FeatureEstimate(feature_num, env)\n",
    "    feature_expectations = np.zeros(feature_num)\n",
    "    for trajectory in demonstrations:\n",
    "        for t, state in enumerate(trajectory):\n",
    "            feature_expectations += gamma ** t * feature_estimate.get_features(state)\n",
    "    return feature_expectations / len(demonstrations)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "qpOptim01"
   },
   "source": [
    "# cvxpy statuses for which `w.value` is a usable solution.\n",
    "QP_SUCCESS = (cp.OPTIMAL, cp.OPTIMAL_INACCURATE)\n",
    "\n",
    "\n",
    "def QP_optimizer(feature_num, learner, expert):\n",
    "    \"\"\"Solve min ||w||_2 s.t. (expert - learner[j]) @ w >= 2 for every learner row j.\n",
    "\n",
    "    Args:\n",
    "        feature_num: number of reward weights.\n",
    "        learner: (k, feature_num) array, one row of feature expectations per learner policy.\n",
    "        expert: (feature_num,) array of expert feature expectations.\n",
    "\n",
    "    Returns:\n",
    "        (weights, status): the solution and cvxpy's status string; weights are zeros unless the\n",
    "        status is in QP_SUCCESS.\n",
    "    \"\"\"\n",
    "    w = cp.Variable(feature_num)\n",
    "    margins = (np.asarray(expert) - np.atleast_2d(learner)) @ w\n",
    "    prob = cp.Problem(cp.Minimize(cp.norm(w)), [margins >= 2])\n",
    "    prob.solve()\n",
    "\n",
    "    print(\"status:\", prob.status)\n",
    "    if prob.status not in QP_SUCCESS:\n",
    "        return np.zeros(feature_num), prob.status\n",
    "    print(\"optimal value\", prob.value)\n",
    "    return np.squeeze(np.asarray(w.value)), prob.status\n",
    "\n",
    "\n",
    "def add_feature_expectation(learner, temp_learner):\n",
    "    \"\"\"Append the feature expectations of the latest learner policy as a new row.\"\"\"\n",
    "    return np.vstack([learner, temp_learner])\n",
    "\n",
    "\n",
    "def subtract_feature_expectation(learner):\n",
    "    \"\"\"Drop the oldest learner row (used when the QP has no solution).\"\"\"\n",
    "    return learner[1:]\n",
    "\n",
    "\n",
    "def fit_reward_weights(learner, expert, w_prev):\n",
    "    \"\"\"Re-fit the reward weights, dropping the oldest learner rows while the QP has no solution.\n",
    "\n",
    "    Returns (w, learner). If even a single learner row gives no solution, `w_prev` is kept rather\n",
    "    than falling back to all-zero weights (which would make the IRL reward identically zero).\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        w, status = QP_optimizer(FEATURE_NUM, learner, expert)\n",
    "        if status in QP_SUCCESS:\n",
    "            return w, learner\n",
    "        if len(learner) <= 1:\n",
    "            print(\"QP has no solution; keeping the previous reward weights\")\n",
    "            return w_prev, learner\n",
    "        learner = subtract_feature_expectation(learner)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "trainSetup01"
   },
   "source": [
    "env = wrap_env(gym.make(\"MountainCar-v0\"))\n",
    "env.seed(SEED)\n",
    "\n",
    "demonstrations = np.load(file=\"expert_demo.npy\")\n",
    "feature_estimate = FeatureEstimate(FEATURE_NUM, env)\n",
    "\n",
    "# A fresh Q-table, so re-running this cell restarts training instead of silently continuing it.\n",
    "q_table = np.zeros((N_STATES, N_ACTIONS))\n",
    "\n",
    "# learner: one row of feature expectations per learner policy evaluated so far (plain 2-D ndarray).\n",
    "learner = np.atleast_2d(calc_feature_expectation(FEATURE_NUM, GAMMA, q_table, demonstrations, env))\n",
    "expert = expert_feature_expectation(FEATURE_NUM, GAMMA, demonstrations, env)\n",
    "\n",
    "w, status = QP_optimizer(FEATURE_NUM, learner, expert)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "mC_jI_fqTyEY"
   },
   "source": [
    "%%time\n",
    "episodes, scores = [], []\n",
    "\n",
    "for episode in range(N_TRAIN_EPISODES):\n",
    "    state = env.reset()\n",
    "    score = 0\n",
    "    done = False\n",
    "\n",
    "    while not done:\n",
    "        state_idx = idx_state(env, state)\n",
    "        action = np.argmax(q_table[state_idx])\n",
    "        next_state, reward, done, _ = env.step(action)\n",
    "\n",
    "        # Q-learning uses the IRL reward w . phi(s); the environment reward is only tracked as the score.\n",
    "        irl_reward = np.dot(w, feature_estimate.get_features(state))\n",
    "        update_q_table(q_table, state_idx, action, irl_reward, idx_state(env, next_state))\n",
    "\n",
    "        score += reward\n",
    "        state = next_state\n",
    "\n",
    "    scores.append(score)\n",
    "    episodes.append(episode)\n",
    "\n",
    "    if episode % LOG_EVERY == 0:\n",
    "        print('{} episode | mean score of last {} episodes: {:.2f}'.format(\n",
    "            episode, LOG_EVERY, np.mean(scores[-LOG_EVERY:])))\n",
    "        np.save(\"app_q_table\", arr=q_table)\n",
    "\n",
    "    if episode % QP_EVERY == 0:\n",
    "        # Evaluate the current greedy policy, then re-fit the reward against every learner policy so far.\n",
    "        temp_learner = calc_feature_expectation(FEATURE_NUM, GAMMA, q_table, demonstrations, env)\n",
    "        learner = add_feature_expectation(learner, temp_learner)\n",
    "        w, learner = fit_reward_weights(learner, expert, w)\n",
    "\n",
    "# Final checkpoint: the periodic save above does not cover the last LOG_EVERY - 1 episodes.\n",
    "np.save(\"app_q_table\", arr=q_table)\n",
    "env.close()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "-SCtYi_odal3"
   },
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.plot(episodes, scores, 'b')\n",
    "ax.set(title=\"Training score per episode\", xlabel=\"Episode\", ylabel=\"Score (environment reward)\")\n",
    "plt.show()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "o0A-BMMjdT-q"
   },
   "source": [
    "show_video()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bCWHIAAgTyg7"
   },
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "OMy7a8KwWkjr"
   },
   "source": [
    "print(\":: Testing APP-learning.\\n\")\n",
    "\n",
    "# Load the trained agent.\n",
    "q_table = np.load(file=\"app_q_table.npy\")\n",
    "\n",
    "# Create a new recorded game instance.\n",
    "env = wrap_env(gym.make(\"MountainCar-v0\"))\n",
    "n_episode = 10  # test the agent 10 times\n",
    "test_scores = []\n",
    "\n",
    "for ep in range(n_episode):\n",
    "    state = env.reset()\n",
    "    score = 0\n",
    "    done = False\n",
    "\n",
    "    while not done:\n",
    "        env.render()  # render the play\n",
    "        action = np.argmax(q_table[idx_state(env, state)])\n",
    "        state, reward, done, _ = env.step(action)\n",
    "        score += reward\n",
    "\n",
    "    test_scores.append(score)\n",
    "    print('{} episode | score: {:.1f}'.format(ep + 1, score))\n",
    "\n",
    "env.close()\n",
    "print('mean test score: {:.1f}'.format(np.mean(test_scores)))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "dajTkpC-Tdvi"
   },
   "source": [
    "show_video()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "resultsNote01"
   },
   "source": [
    "## Notes\n",
    "\n",
    "- `MountainCar-v0` gives -1 reward per step and ends after 200 steps, so a score of -200 means the car never reached the flag.\n",
    "- A previous run of this notebook (60,000 training episodes) reported the running-average training score improving from -200 to about -158, and test scores between -97 and -158 over 10 greedy episodes. Re-run the cells above for results of the current code."
   ]
  }
 ]
}