{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-23-rl-sac.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/T219631%20%7C%20Training%20Stock%20Trading%20RL%20Agent%20using%20SAC%20and%20Deploying%20%20as%20a%20Service.ipynb","timestamp":1644663599248}],"collapsed_sections":[],"authorship_tag":"ABX9TyOXzB1pylTORYaR1gwZK7Zg"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","metadata":{"id":"4TKk8aKOc-zR"},"source":["# Training Stock Trading RL Agent using SAC and Deploying as a Service"]},{"cell_type":"code","metadata":{"id":"Mi3_RmfZFKCA"},"source":["import functools\n","import random\n","from collections import deque\n","\n","import numpy as np\n","import tensorflow as tf\n","import tensorflow_probability as tfp\n","from tensorflow.keras.layers import Concatenate, Dense, Input\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.optimizers import Adam\n","\n","import functools\n","from collections import deque"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"lIB0PD0OFO6Q"},"source":["tf.keras.backend.set_floatx(\"float64\")"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"8seA5Hb_IMDQ"},"source":["## Implementing the RL agent’s runtime components"]},{"cell_type":"code","metadata":{"id":"mruDbPR2FO3v"},"source":["def actor(state_shape, action_shape, units=(512, 256, 64)):\n"," state_shape_flattened = functools.reduce(lambda x, y: x * y, state_shape)\n"," state = Input(shape=state_shape_flattened)\n"," x = Dense(units[0], name=\"L0\", activation=\"relu\")(state)\n"," for index in range(1, len(units)):\n"," x = Dense(units[index], name=\"L{}\".format(index), activation=\"relu\")(x)\n","\n"," actions_mean = Dense(action_shape[0], name=\"Out_mean\")(x)\n"," actions_std = Dense(action_shape[0], name=\"Out_std\")(x)\n","\n"," model = Model(inputs=state, outputs=[actions_mean, actions_std])\n","\n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"f9n4iEqjFO1W"},"source":["def critic(state_shape, action_shape, units=(512, 256, 64)):\n"," state_shape_flattened = functools.reduce(lambda x, y: x * y, state_shape)\n"," inputs = [Input(shape=state_shape_flattened), Input(shape=action_shape)]\n"," concat = Concatenate(axis=-1)(inputs)\n"," x = Dense(units[0], name=\"Hidden0\", activation=\"relu\")(concat)\n"," for index in range(1, len(units)):\n"," x = Dense(units[index], name=\"Hidden{}\".format(index), activation=\"relu\")(x)\n","\n"," output = Dense(1, name=\"Out_QVal\")(x)\n"," model = Model(inputs=inputs, outputs=output)\n","\n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"4knlOHw_FU9e"},"source":["def update_target_weights(model, target_model, tau=0.005):\n"," weights = model.get_weights()\n"," target_weights = target_model.get_weights()\n"," for i in range(len(target_weights)): # set tau% of target model to be new weights\n"," target_weights[i] = weights[i] * tau + target_weights[i] * (1 - tau)\n"," target_model.set_weights(target_weights)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"tuhIY539FXyt"},"source":["# class SAC(object):\n","# def __init__(\n","# self,\n","# observation_shape,\n","# action_space,\n","# lr_actor=3e-5,\n","# lr_critic=3e-4,\n","# actor_units=(64, 64),\n","# critic_units=(64, 64),\n","# auto_alpha=True,\n","# alpha=0.2,\n","# tau=0.005,\n","# gamma=0.99,\n","# batch_size=128,\n","# memory_cap=100000,\n","# ):\n","# self.state_shape = observation_shape # shape of observations\n","# self.action_shape = action_space.shape # number of actions\n","# self.action_bound = (action_space.high - action_space.low) / 2\n","# self.action_shift = (action_space.high + action_space.low) / 2\n","# self.memory = deque(maxlen=int(memory_cap))\n","\n","# # Define and initialize actor network\n","# self.actor = actor(self.state_shape, self.action_shape, actor_units)\n","# self.actor_optimizer = Adam(learning_rate=lr_actor)\n","# self.log_std_min = -20\n","# self.log_std_max = 2\n","# print(self.actor.summary())\n","\n","# # Define and initialize critic networks\n","# self.critic_1 = critic(self.state_shape, self.action_shape, critic_units)\n","# self.critic_target_1 = critic(self.state_shape, self.action_shape, critic_units)\n","# self.critic_optimizer_1 = Adam(learning_rate=lr_critic)\n","# update_target_weights(self.critic_1, self.critic_target_1, tau=1.0)\n","\n","# self.critic_2 = critic(self.state_shape, self.action_shape, critic_units)\n","# self.critic_target_2 = critic(self.state_shape, self.action_shape, critic_units)\n","# self.critic_optimizer_2 = Adam(learning_rate=lr_critic)\n","# update_target_weights(self.critic_2, self.critic_target_2, tau=1.0)\n","\n","# print(self.critic_1.summary())\n","\n","# # Define and initialize temperature alpha and target entropy\n","# self.auto_alpha = auto_alpha\n","# if auto_alpha:\n","# self.target_entropy = -np.prod(self.action_shape)\n","# self.log_alpha = tf.Variable(0.0, dtype=tf.float64)\n","# self.alpha = tf.Variable(0.0, dtype=tf.float64)\n","# self.alpha.assign(tf.exp(self.log_alpha))\n","# self.alpha_optimizer = Adam(learning_rate=lr_actor)\n","# else:\n","# self.alpha = tf.Variable(alpha, dtype=tf.float64)\n","\n","# # Set hyperparameters\n","# self.gamma = gamma # discount factor\n","# self.tau = tau # target model update\n","# self.batch_size = batch_size\n","\n","# # Tensorboard\n","# self.summaries = {}\n","\n","# def process_actions(self, mean, log_std, test=False, eps=1e-6):\n","# std = tf.math.exp(log_std)\n","# raw_actions = mean\n","\n","# if not test:\n","# raw_actions += tf.random.normal(shape=mean.shape, dtype=tf.float64) * std\n","\n","# log_prob_u = tfp.distributions.Normal(loc=mean, scale=std).log_prob(raw_actions)\n","# actions = tf.math.tanh(raw_actions)\n","\n","# log_prob = tf.reduce_sum(log_prob_u - tf.math.log(1 - actions ** 2 + eps))\n","\n","# actions = actions * self.action_bound + self.action_shift\n","\n","# return actions, log_prob\n","\n","# def act(self, state, test=False, use_random=False):\n","# state = state.reshape(-1) # Flatten state\n","# state = np.expand_dims(state, axis=0).astype(np.float64)\n","\n","# if use_random:\n","# a = tf.random.uniform(\n","# shape=(1, self.action_shape[0]), minval=-1, maxval=1, dtype=tf.float64\n","# )\n","# else:\n","# means, log_stds = self.actor.predict(state)\n","# log_stds = tf.clip_by_value(log_stds, self.log_std_min, self.log_std_max)\n","\n","# a, log_prob = self.process_actions(means, log_stds, test=test)\n","\n","# q1 = self.critic_1.predict([state, a])[0][0]\n","# q2 = self.critic_2.predict([state, a])[0][0]\n","# self.summaries[\"q_min\"] = tf.math.minimum(q1, q2)\n","# self.summaries[\"q_mean\"] = np.mean([q1, q2])\n","\n","# return a\n","\n","# def load_actor(self, a_fn):\n","# self.actor.load_weights(a_fn)\n","# print(self.actor.summary())\n","\n","# def load_critic(self, c_fn):\n","# self.critic_1.load_weights(c_fn)\n","# self.critic_target_1.load_weights(c_fn)\n","# self.critic_2.load_weights(c_fn)\n","# self.critic_target_2.load_weights(c_fn)\n","# print(self.critic_1.summary())"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"TOHIrXBRGDxo"},"source":["we implemented the essential runtime components for the SAC agent. The runtime components include the actor and critic model definitions, a mechanism to load weights from previously trained agent models, and an agent interface to generate actions given states using the actor’s prediction and to process the prediction to generate an executable action.\n","\n","The runtime components for other actor-critic-based RL agent algorithms, such as A2C, A3C, and DDPG, as well as their extensions and variants, will be very similar, if not the same."]},{"cell_type":"markdown","metadata":{"id":"POMp4qjeSrKe"},"source":["SAC Agent Base"]},{"cell_type":"code","metadata":{"id":"vEJqAWaYSpmM"},"source":["class SAC(object):\n"," def __init__(\n"," self,\n"," observation_shape,\n"," action_space,\n"," lr_actor=3e-5,\n"," lr_critic=3e-4,\n"," actor_units=(64, 64),\n"," critic_units=(64, 64),\n"," auto_alpha=True,\n"," alpha=0.2,\n"," tau=0.005,\n"," gamma=0.99,\n"," batch_size=128,\n"," memory_cap=100000,\n"," ):\n"," self.state_shape = observation_shape # shape of observations\n"," self.action_shape = action_space.shape # number of actions\n"," self.action_bound = (action_space.high - action_space.low) / 2\n"," self.action_shift = (action_space.high + action_space.low) / 2\n"," self.memory = deque(maxlen=int(memory_cap))\n","\n"," # Define and initialize actor network\n"," self.actor = actor(self.state_shape, self.action_shape, actor_units)\n"," self.actor_optimizer = Adam(learning_rate=lr_actor)\n"," self.log_std_min = -20\n"," self.log_std_max = 2\n"," print(self.actor.summary())\n","\n"," # Define and initialize critic networks\n"," self.critic_1 = critic(self.state_shape, self.action_shape, critic_units)\n"," self.critic_target_1 = critic(self.state_shape, self.action_shape, critic_units)\n"," self.critic_optimizer_1 = Adam(learning_rate=lr_critic)\n"," update_target_weights(self.critic_1, self.critic_target_1, tau=1.0)\n","\n"," self.critic_2 = critic(self.state_shape, self.action_shape, critic_units)\n"," self.critic_target_2 = critic(self.state_shape, self.action_shape, critic_units)\n"," self.critic_optimizer_2 = Adam(learning_rate=lr_critic)\n"," update_target_weights(self.critic_2, self.critic_target_2, tau=1.0)\n","\n"," print(self.critic_1.summary())\n","\n"," # Define and initialize temperature alpha and target entropy\n"," self.auto_alpha = auto_alpha\n"," if auto_alpha:\n"," self.target_entropy = -np.prod(self.action_shape)\n"," self.log_alpha = tf.Variable(0.0, dtype=tf.float64)\n"," self.alpha = tf.Variable(0.0, dtype=tf.float64)\n"," self.alpha.assign(tf.exp(self.log_alpha))\n"," self.alpha_optimizer = Adam(learning_rate=lr_actor)\n"," else:\n"," self.alpha = tf.Variable(alpha, dtype=tf.float64)\n","\n"," # Set hyperparameters\n"," self.gamma = gamma # discount factor\n"," self.tau = tau # target model update\n"," self.batch_size = batch_size\n","\n"," # Tensorboard\n"," self.summaries = {}\n","\n"," def process_actions(self, mean, log_std, test=False, eps=1e-6):\n"," std = tf.math.exp(log_std)\n"," raw_actions = mean\n","\n"," if not test:\n"," raw_actions += tf.random.normal(shape=mean.shape, dtype=tf.float64) * std\n","\n"," log_prob_u = tfp.distributions.Normal(loc=mean, scale=std).log_prob(raw_actions)\n"," actions = tf.math.tanh(raw_actions)\n","\n"," log_prob = tf.reduce_sum(log_prob_u - tf.math.log(1 - actions ** 2 + eps))\n","\n"," actions = actions * self.action_bound + self.action_shift\n","\n"," return actions, log_prob\n","\n"," def act(self, state, test=False, use_random=False):\n"," state = state.reshape(-1) # Flatten state\n"," state = np.expand_dims(state, axis=0).astype(np.float64)\n","\n"," if use_random and len(self.memory) > self.batch_size:\n"," a = tf.random.uniform(\n"," shape=(1, self.action_shape[0]), minval=-1, maxval=1, dtype=tf.float64\n"," )\n"," else:\n"," means, log_stds = self.actor.predict(state)\n"," log_stds = tf.clip_by_value(log_stds, self.log_std_min, self.log_std_max)\n","\n"," a, log_prob = self.process_actions(means, log_stds, test=test)\n","\n"," q1 = self.critic_1.predict([state, a])[0][0]\n"," q2 = self.critic_2.predict([state, a])[0][0]\n"," self.summaries[\"q_min\"] = tf.math.minimum(q1, q2)\n"," self.summaries[\"q_mean\"] = np.mean([q1, q2])\n","\n"," return a\n","\n"," def save_model(self, a_fn, c_fn):\n"," self.actor.save(a_fn)\n"," self.critic_1.save(c_fn)\n","\n"," def load_actor(self, a_fn):\n"," self.actor.load_weights(a_fn)\n"," print(self.actor.summary())\n","\n"," def load_critic(self, c_fn):\n"," self.critic_1.load_weights(c_fn)\n"," self.critic_target_1.load_weights(c_fn)\n"," self.critic_2.load_weights(c_fn)\n"," self.critic_target_2.load_weights(c_fn)\n"," print(self.critic_1.summary())\n","\n"," def remember(self, state, action, reward, next_state, done):\n"," state = state.reshape(-1) # Flatten state\n"," state = np.expand_dims(state, axis=0)\n"," next_state = next_state.reshape(-1) # Flatten next-state\n"," next_state = np.expand_dims(next_state, axis=0)\n"," self.memory.append([state, action, reward, next_state, done])\n","\n"," def replay(self):\n"," if len(self.memory) < self.batch_size:\n"," return\n","\n"," samples = random.sample(self.memory, self.batch_size)\n"," s = np.array(samples).T\n"," states, actions, rewards, next_states, dones = [\n"," np.vstack(s[i, :]).astype(np.float) for i in range(5)\n"," ]\n","\n"," with tf.GradientTape(persistent=True) as tape:\n"," # next state action log probs\n"," means, log_stds = self.actor(next_states)\n"," log_stds = tf.clip_by_value(log_stds, self.log_std_min, self.log_std_max)\n"," next_actions, log_probs = self.process_actions(means, log_stds)\n","\n"," # critics loss\n"," current_q_1 = self.critic_1([states, actions])\n"," current_q_2 = self.critic_2([states, actions])\n"," next_q_1 = self.critic_target_1([next_states, next_actions])\n"," next_q_2 = self.critic_target_2([next_states, next_actions])\n"," next_q_min = tf.math.minimum(next_q_1, next_q_2)\n"," state_values = next_q_min - self.alpha * log_probs\n"," target_qs = tf.stop_gradient(\n"," rewards + state_values * self.gamma * (1.0 - dones)\n"," )\n"," critic_loss_1 = tf.reduce_mean(\n"," 0.5 * tf.math.square(current_q_1 - target_qs)\n"," )\n"," critic_loss_2 = tf.reduce_mean(\n"," 0.5 * tf.math.square(current_q_2 - target_qs)\n"," )\n","\n"," # current state action log probs\n"," means, log_stds = self.actor(states)\n"," log_stds = tf.clip_by_value(log_stds, self.log_std_min, self.log_std_max)\n"," actions, log_probs = self.process_actions(means, log_stds)\n","\n"," # actor loss\n"," current_q_1 = self.critic_1([states, actions])\n"," current_q_2 = self.critic_2([states, actions])\n"," current_q_min = tf.math.minimum(current_q_1, current_q_2)\n"," actor_loss = tf.reduce_mean(self.alpha * log_probs - current_q_min)\n","\n"," # temperature loss\n"," if self.auto_alpha:\n"," alpha_loss = -tf.reduce_mean(\n"," (self.log_alpha * tf.stop_gradient(log_probs + self.target_entropy))\n"," )\n","\n"," critic_grad = tape.gradient(\n"," critic_loss_1, self.critic_1.trainable_variables\n"," ) # compute actor gradient\n"," self.critic_optimizer_1.apply_gradients(\n"," zip(critic_grad, self.critic_1.trainable_variables)\n"," )\n","\n"," critic_grad = tape.gradient(\n"," critic_loss_2, self.critic_2.trainable_variables\n"," ) # compute actor gradient\n"," self.critic_optimizer_2.apply_gradients(\n"," zip(critic_grad, self.critic_2.trainable_variables)\n"," )\n","\n"," actor_grad = tape.gradient(\n"," actor_loss, self.actor.trainable_variables\n"," ) # compute actor gradient\n"," self.actor_optimizer.apply_gradients(\n"," zip(actor_grad, self.actor.trainable_variables)\n"," )\n","\n"," # tensorboard info\n"," self.summaries[\"q1_loss\"] = critic_loss_1\n"," self.summaries[\"q2_loss\"] = critic_loss_2\n"," self.summaries[\"actor_loss\"] = actor_loss\n","\n"," if self.auto_alpha:\n"," # optimize temperature\n"," alpha_grad = tape.gradient(alpha_loss, [self.log_alpha])\n"," self.alpha_optimizer.apply_gradients(zip(alpha_grad, [self.log_alpha]))\n"," self.alpha.assign(tf.exp(self.log_alpha))\n"," # tensorboard info\n"," self.summaries[\"alpha_loss\"] = alpha_loss\n","\n"," def train(self, cur_state, action, reward, next_state, done):\n"," self.remember(cur_state, action, reward, next_state, done) # add to memory\n"," self.replay() # train models through memory replay\n"," update_target_weights(\n"," self.critic_1, self.critic_target_1, tau=self.tau\n"," ) # iterates target model\n"," update_target_weights(self.critic_2, self.critic_target_2, tau=self.tau)\n","\n"," def update_memory(self, xp_store):\n"," for (cur_state, action, reward, next_state, done) in zip(\n"," xp_store[\"cur_states\"],\n"," xp_store[\"actions\"],\n"," xp_store[\"rewards\"],\n"," xp_store[\"next_states\"],\n"," xp_store[\"dones\"],\n"," ):\n"," self.remember(cur_state, action, reward, next_state, done) # add to memory\n","\n"," def train_with_distributed_replay_memory(self, new_experiences):\n"," self.update_memory(new_experiences)\n"," self.replay() # train models through memory replay\n"," update_target_weights(\n"," self.critic_1, self.critic_target_1, tau=self.tau\n"," ) # iterates target model\n"," update_target_weights(self.critic_2, self.critic_target_2, tau=self.tau)\n","\n"," def log_status(self, summary_writer, episode_num, reward):\n"," \"\"\"Write training stats using TF `summary_writer`\"\"\"\n"," with summary_writer.as_default():\n"," if len(self.memory) > self.batch_size:\n"," tf.summary.scalar(\n"," \"Loss/actor_loss\", self.summaries[\"actor_loss\"], step=episode_num\n"," )\n"," tf.summary.scalar(\n"," \"Loss/q1_loss\", self.summaries[\"q1_loss\"], step=episode_num\n"," )\n"," tf.summary.scalar(\n"," \"Loss/q2_loss\", self.summaries[\"q2_loss\"], step=episode_num\n"," )\n"," if self.auto_alpha:\n"," tf.summary.scalar(\n"," \"Loss/alpha_loss\",\n"," self.summaries[\"alpha_loss\"],\n"," step=episode_num,\n"," )\n","\n"," tf.summary.scalar(\"Stats/alpha\", self.alpha, step=episode_num)\n"," if self.auto_alpha:\n"," tf.summary.scalar(\"Stats/log_alpha\", self.log_alpha, step=episode_num)\n"," tf.summary.scalar(\"Stats/q_min\", self.summaries[\"q_min\"], step=episode_num)\n"," tf.summary.scalar(\n"," \"Stats/q_mean\", self.summaries[\"q_mean\"], step=episode_num\n"," )\n"," tf.summary.scalar(\"Main/step_reward\", reward, step=episode_num)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"vH8PiR47GPEc"},"source":["## Building RL environment simulators as a service"]},{"cell_type":"markdown","metadata":{"id":"sPkSUD8SIZOg"},"source":["Our implementation will contain two core modules – the tradegym server and the tradegym client, which are built based on the OpenAI Gym HTTP API. The recipe will focus on the customizations and the core components of the HTTP service interface. We will first define a minimum set of custom environments exposed as part of the tradegym library and then build the server and client modules."]},{"cell_type":"code","metadata":{"id":"EAyj6BrDJPh0"},"source":["!pip install -U flask"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"-NwnKpqOIZvs"},"source":["!wget -q --show-progress https://github.com/RecoHut-Projects/drl-recsys/raw/S990517/tools/tradegym.zip\n","!unzip tradegym.zip"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XoGuSTIfMkS6","executionInfo":{"status":"ok","timestamp":1638516648019,"user_tz":-330,"elapsed":617,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"3a666787-75a5-4f4e-8880-568ba57065b9"},"source":["%%writefile tradegym_server.py\n","import argparse\n","import json\n","import logging\n","import os\n","import sys\n","import uuid\n","\n","import numpy as np\n","import six\n","from flask import Flask, jsonify, request\n","\n","import gym\n","import tradegym\n","\n","\n","logger = logging.getLogger(\"tradegym\")\n","logger.setLevel(logging.ERROR)\n","\n","\n","########## Container for environments ##########\n","class Envs(object):\n"," \"\"\"\n"," Container and manager for the environments instantiated\n"," on this server.\n"," When a new environment is created, such as with\n"," envs.create('CartPole-v0'), it is stored under a short\n"," identifier (such as '3c657dbc'). Future API calls make\n"," use of this instance_id to identify which environment\n"," should be manipulated.\n"," \"\"\"\n","\n"," def __init__(self):\n"," self.envs = {}\n"," self.id_len = 8\n","\n"," def _lookup_env(self, instance_id):\n"," try:\n"," return self.envs[instance_id]\n"," except KeyError:\n"," raise InvalidUsage(\"Instance_id {} unknown\".format(instance_id))\n","\n"," def _remove_env(self, instance_id):\n"," try:\n"," del self.envs[instance_id]\n"," except KeyError:\n"," raise InvalidUsage(\"Instance_id {} unknown\".format(instance_id))\n","\n"," def create(self, env_id, seed=None):\n"," try:\n"," env = gym.make(env_id)\n"," if seed:\n"," env.seed(seed)\n"," except gym.error.Error:\n"," raise InvalidUsage(\n"," \"Attempted to look up malformed environment ID '{}'\".format(env_id)\n"," )\n","\n"," instance_id = str(uuid.uuid4().hex)[: self.id_len]\n"," self.envs[instance_id] = env\n"," return instance_id\n","\n"," def list_all(self):\n"," return dict(\n"," [(instance_id, env.spec.id) for (instance_id, env) in self.envs.items()]\n"," )\n","\n"," def reset(self, instance_id):\n"," env = self._lookup_env(instance_id)\n"," obs = env.reset()\n"," return env.observation_space.to_jsonable(obs)\n","\n"," def step(self, instance_id, action, render):\n"," env = self._lookup_env(instance_id)\n"," if isinstance(action, six.integer_types):\n"," nice_action = action\n"," else:\n"," nice_action = np.array(action)\n"," if render:\n"," env.render()\n"," [observation, reward, done, info] = env.step(nice_action)\n"," obs_jsonable = env.observation_space.to_jsonable(observation)\n"," return [obs_jsonable, reward, done, info]\n","\n"," def get_action_space_contains(self, instance_id, x):\n"," env = self._lookup_env(instance_id)\n"," return env.action_space.contains(int(x))\n","\n"," def get_action_space_info(self, instance_id):\n"," env = self._lookup_env(instance_id)\n"," return self._get_space_properties(env.action_space)\n","\n"," def get_action_space_sample(self, instance_id):\n"," env = self._lookup_env(instance_id)\n"," action = env.action_space.sample()\n"," if isinstance(action, (list, tuple)) or (\"numpy\" in str(type(action))):\n"," try:\n"," action = action.tolist()\n"," except TypeError:\n"," print(type(action))\n"," print(\"TypeError\")\n"," return action\n","\n"," def get_observation_space_contains(self, instance_id, j):\n"," env = self._lookup_env(instance_id)\n"," info = self._get_space_properties(env.observation_space)\n"," for key, value in j.items():\n"," # Convert both values to json for comparibility\n"," if json.dumps(info[key]) != json.dumps(value):\n"," print(\n"," 'Values for \"{}\" do not match. Passed \"{}\", Observed \"{}\".'.format(\n"," key, value, info[key]\n"," )\n"," )\n"," return False\n"," return True\n","\n"," def get_observation_space_info(self, instance_id):\n"," env = self._lookup_env(instance_id)\n"," return self._get_space_properties(env.observation_space)\n","\n"," def _get_space_properties(self, space):\n"," info = {}\n"," info[\"name\"] = space.__class__.__name__\n"," if info[\"name\"] == \"Discrete\":\n"," info[\"n\"] = space.n\n"," elif info[\"name\"] == \"Box\":\n"," info[\"shape\"] = space.shape\n"," # It's not JSON compliant to have Infinity, -Infinity, NaN.\n"," # Many newer JSON parsers allow it, but many don't. Notably python json\n"," # module can read and write such floats. So we only here fix \"export version\",\n"," # also make it flat.\n"," info[\"low\"] = [\n"," (x if x != -np.inf else -1e100) for x in np.array(space.low).flatten()\n"," ]\n"," info[\"high\"] = [\n"," (x if x != +np.inf else +1e100) for x in np.array(space.high).flatten()\n"," ]\n"," elif info[\"name\"] == \"HighLow\":\n"," info[\"num_rows\"] = space.num_rows\n"," info[\"matrix\"] = [\n"," ((float(x) if x != -np.inf else -1e100) if x != +np.inf else +1e100)\n"," for x in np.array(space.matrix).flatten()\n"," ]\n"," return info\n","\n"," def monitor_start(self, instance_id, directory, force, resume, video_callable):\n"," env = self._lookup_env(instance_id)\n"," if video_callable == False:\n"," v_c = lambda count: False\n"," else:\n"," v_c = lambda count: count % video_callable == 0\n"," self.envs[instance_id] = gym.wrappers.Monitor(\n"," env, directory, force=force, resume=resume, video_callable=v_c\n"," )\n","\n"," def monitor_close(self, instance_id):\n"," env = self._lookup_env(instance_id)\n"," env.close()\n","\n"," def env_close(self, instance_id):\n"," env = self._lookup_env(instance_id)\n"," env.close()\n"," self._remove_env(instance_id)\n","\n","\n","app = Flask(__name__)\n","app.config[\"JSONIFY_PRETTYPRINT_REGULAR\"] = False\n","envs = Envs()\n","\n","\n","class InvalidUsage(Exception):\n"," status_code = 400\n","\n"," def __init__(self, message, status_code=None, payload=None):\n"," Exception.__init__(self)\n"," self.message = message\n"," if status_code is not None:\n"," self.status_code = status_code\n"," self.payload = payload\n","\n"," def to_dict(self):\n"," rv = dict(self.payload or ())\n"," rv[\"message\"] = self.message\n"," return rv\n","\n","\n","def get_required_param(json, param):\n"," if json is None:\n"," logger.info(\"Request is not a valid json\")\n"," raise InvalidUsage(\"Request is not a valid json\")\n"," value = json.get(param, None)\n"," if (value is None) or (value == \"\") or (value == []):\n"," logger.info(\n"," \"A required request parameter '{}' had value {}\".format(param, value)\n"," )\n"," raise InvalidUsage(\n"," \"A required request parameter '{}' was not provided\".format(param)\n"," )\n"," return value\n","\n","\n","def get_optional_param(json, param, default):\n"," if json is None:\n"," logger.info(\"Request is not a valid json\")\n"," raise InvalidUsage(\"Request is not a valid json\")\n"," value = json.get(param, None)\n"," if (value is None) or (value == \"\") or (value == []):\n"," logger.info(\n"," \"An optional request parameter '{}' had value {} and was replaced with default value {}\".format(\n"," param, value, default\n"," )\n"," )\n"," value = default\n"," return value\n","\n","\n","@app.errorhandler(InvalidUsage)\n","def handle_invalid_usage(error):\n"," response = jsonify(error.to_dict())\n"," response.status_code = error.status_code\n"," return response\n","\n","\n","########## API route definitions ##########\n","@app.route(\"/v1/envs/\", methods=[\"POST\"])\n","def env_create():\n"," \"\"\"\n"," Create an instance of the specified environment\n"," Parameters:\n"," - env_id: gym environment ID string, such as 'CartPole-v0'\n"," - seed: set the seed for this env's random number generator(s).\n"," Returns:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the created environment instance. The instance_id is\n"," used in future API calls to identify the environment to be\n"," manipulated\n"," \"\"\"\n"," env_id = get_required_param(request.get_json(), \"env_id\")\n"," seed = get_optional_param(request.get_json(), \"seed\", None)\n"," instance_id = envs.create(env_id, seed)\n"," return jsonify(instance_id=instance_id)\n","\n","\n","@app.route(\"/v1/envs/\", methods=[\"GET\"])\n","def env_list_all():\n"," \"\"\"\n"," List all environments running on the server\n"," Returns:\n"," - envs: dict mapping instance_id to env_id\n"," (e.g. {'3c657dbc': 'CartPole-v0'}) for every env\n"," on the server\n"," \"\"\"\n"," all_envs = envs.list_all()\n"," return jsonify(all_envs=all_envs)\n","\n","\n","@app.route(\"/v1/envs//reset/\", methods=[\"POST\"])\n","def env_reset(instance_id):\n"," \"\"\"\n"," Reset the state of the environment and return an initial\n"," observation.\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," Returns:\n"," - observation: the initial observation of the space\n"," \"\"\"\n"," observation = envs.reset(instance_id)\n"," if np.isscalar(observation):\n"," observation = observation.item()\n"," return jsonify(observation=observation)\n","\n","\n","@app.route(\"/v1/envs//step/\", methods=[\"POST\"])\n","def env_step(instance_id):\n"," \"\"\"\n"," Run one timestep of the environment's dynamics.\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," - action: an action to take in the environment\n"," Returns:\n"," - observation: agent's observation of the current\n"," environment\n"," - reward: amount of reward returned after previous action\n"," - done: whether the episode has ended\n"," - info: a dict containing auxiliary diagnostic information\n"," \"\"\"\n"," json = request.get_json()\n"," action = get_required_param(json, \"action\")\n"," render = get_optional_param(json, \"render\", False)\n"," [obs_jsonable, reward, done, info] = envs.step(instance_id, action, render)\n"," return jsonify(observation=obs_jsonable, reward=reward, done=done, info=info)\n","\n","\n","@app.route(\"/v1/envs//action_space/\", methods=[\"GET\"])\n","def env_action_space_info(instance_id):\n"," \"\"\"\n"," Get information (name and dimensions/bounds) of the env's\n"," action_space\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," Returns:\n"," - info: a dict containing 'name' (such as 'Discrete'), and\n"," additional dimensional info (such as 'n') which varies from\n"," space to space\n"," \"\"\"\n"," info = envs.get_action_space_info(instance_id)\n"," return jsonify(info=info)\n","\n","\n","@app.route(\"/v1/envs//action_space/sample\", methods=[\"GET\"])\n","def env_action_space_sample(instance_id):\n"," \"\"\"\n"," Get a sample from the env's action_space\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," Returns:\n"," - action: a randomly sampled element belonging to the action_space\n"," \"\"\"\n"," action = envs.get_action_space_sample(instance_id)\n"," return jsonify(action=action)\n","\n","\n","@app.route(\"/v1/envs//action_space/contains/\", methods=[\"GET\"])\n","def env_action_space_contains(instance_id, x):\n"," \"\"\"\n"," Assess that value is a member of the env's action_space\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," - x: the value to be checked as member\n"," Returns:\n"," - member: whether the value passed as parameter belongs to the action_space\n"," \"\"\"\n","\n"," member = envs.get_action_space_contains(instance_id, x)\n"," return jsonify(member=member)\n","\n","\n","@app.route(\"/v1/envs//observation_space/contains\", methods=[\"POST\"])\n","def env_observation_space_contains(instance_id):\n"," \"\"\"\n"," Assess that the parameters are members of the env's observation_space\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," Returns:\n"," - member: whether all the values passed belong to the observation_space\n"," \"\"\"\n"," j = request.get_json()\n"," member = envs.get_observation_space_contains(instance_id, j)\n"," return jsonify(member=member)\n","\n","\n","@app.route(\"/v1/envs//observation_space/\", methods=[\"GET\"])\n","def env_observation_space_info(instance_id):\n"," \"\"\"\n"," Get information (name and dimensions/bounds) of the env's\n"," observation_space\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," Returns:\n"," - info: a dict containing 'name' (such as 'Discrete'),\n"," and additional dimensional info (such as 'n') which\n"," varies from space to space\n"," \"\"\"\n"," info = envs.get_observation_space_info(instance_id)\n"," return jsonify(info=info)\n","\n","\n","@app.route(\"/v1/envs//monitor/start/\", methods=[\"POST\"])\n","def env_monitor_start(instance_id):\n"," \"\"\"\n"," Start monitoring.\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," - force (default=False): Clear out existing training\n"," data from this directory (by deleting every file\n"," prefixed with \"openaigym.\")\n"," - resume (default=False): Retain the training data\n"," already in this directory, which will be merged with\n"," our new data\n"," \"\"\"\n"," j = request.get_json()\n","\n"," directory = get_required_param(j, \"directory\")\n"," force = get_optional_param(j, \"force\", False)\n"," resume = get_optional_param(j, \"resume\", False)\n"," video_callable = get_optional_param(j, \"video_callable\", False)\n"," envs.monitor_start(instance_id, directory, force, resume, video_callable)\n"," return (\"\", 204)\n","\n","\n","@app.route(\"/v1/envs//monitor/close/\", methods=[\"POST\"])\n","def env_monitor_close(instance_id):\n"," \"\"\"\n"," Flush all monitor data to disk.\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," \"\"\"\n"," envs.monitor_close(instance_id)\n"," return (\"\", 204)\n","\n","\n","@app.route(\"/v1/envs//close/\", methods=[\"POST\"])\n","def env_close(instance_id):\n"," \"\"\"\n"," Manually close an environment\n"," Parameters:\n"," - instance_id: a short identifier (such as '3c657dbc')\n"," for the environment instance\n"," \"\"\"\n"," envs.env_close(instance_id)\n"," return (\"\", 204)\n","\n","\n","@app.route(\"/v1/upload/\", methods=[\"POST\"])\n","def upload():\n"," \"\"\"\n"," Upload the results of training (as automatically recorded by\n"," your env's monitor) to OpenAI Gym.\n"," Parameters:\n"," - training_dir: A directory containing the results of a\n"," training run.\n"," - api_key: Your OpenAI API key\n"," - algorithm_id (default=None): An arbitrary string\n"," indicating the paricular version of the algorithm\n"," (including choices of parameters) you are running.\n"," \"\"\"\n"," j = request.get_json()\n"," training_dir = get_required_param(j, \"training_dir\")\n"," api_key = get_required_param(j, \"api_key\")\n"," algorithm_id = get_optional_param(j, \"algorithm_id\", None)\n","\n"," try:\n"," gym.upload(\n"," training_dir,\n"," algorithm_id,\n"," writeup=None,\n"," api_key=api_key,\n"," ignore_open_monitors=False,\n"," )\n"," return (\"\", 204)\n"," except gym.error.AuthenticationError:\n"," raise InvalidUsage(\"You must provide an OpenAI Gym API key\")\n","\n","\n","@app.route(\"/v1/shutdown/\", methods=[\"POST\"])\n","def shutdown():\n"," \"\"\" Request a server shutdown - currently used by the integration tests to repeatedly create and destroy fresh copies of the server running in a separate thread\"\"\"\n"," f = request.environ.get(\"werkzeug.server.shutdown\")\n"," f()\n"," return \"Server shutting down\"\n","\n","\n","if __name__ == \"__main__\":\n"," parser = argparse.ArgumentParser(description=\"Start a Gym HTTP API server\")\n"," parser.add_argument(\n"," \"-l\", \"--listen\", help=\"interface to listen to\", default=\"0.0.0.0\"\n"," )\n"," parser.add_argument(\"-p\", \"--port\", default=6666, type=int, help=\"port to bind to\")\n","\n"," args = parser.parse_args()\n"," print(\"Server starting at: \" + \"http://{}:{}\".format(args.listen, args.port))\n"," app.run(host=args.listen, port=args.port, debug=True)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Writing tradegym_server.py\n"]}]},{"cell_type":"code","metadata":{"id":"Q8K9EYdCIk0a"},"source":["!sudo nohup python tradegym_server.py > log.txt 2>&1 &"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"DWOW3gWdIoEJ","executionInfo":{"status":"ok","timestamp":1638516672900,"user_tz":-330,"elapsed":627,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"23027dd4-daf0-4a03-e417-f1af6272ff69"},"source":["!head log.txt"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":[" * Running on http://0.0.0.0:6666/ (Press CTRL+C to quit)\n"," * Restarting with stat\n"," * Debugger is active!\n"," * Debugger PIN: 323-023-100\n"]}]},{"cell_type":"markdown","metadata":{"id":"ZqZOCOK3QMjG"},"source":["Client"]},{"cell_type":"code","metadata":{"id":"GObnTQTYQIUy"},"source":["!pip install mplfinance"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9nRRmT-JQTBl"},"source":["import warnings\n","warnings.filterwarnings('ignore')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"8kkdOzdvP-bd"},"source":["import os\n","import sys\n","import requests\n","import json\n","import logging\n","import six.moves.urllib.parse as urlparse\n","\n","import gym\n","import tradegym"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Afb6Qt4OIvII"},"source":["# host_ip = \"0.0.0.0\"\n","# host_port = 6666\n","# endpoint = \"v1/act\"\n","# env = gym.make(\"StockTradingContinuousEnv-v0\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"8yXMsA3ZIwgh"},"source":["# Create an App-level child logger\n","logger = logging.getLogger(\"TFRL-tradegym-client\")\n","# Set handler for this logger to handle messages\n","logger.addHandler(logging.StreamHandler())\n","# Set logging-level for this logger's handler\n","logger.setLevel(logging.DEBUG)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"yE3Ug10AI5eI"},"source":["class Client(object):\n"," \"\"\"\n"," Gym client to interface with gym_http_server\n"," \"\"\"\n","\n"," def __init__(self, remote_base):\n"," self.remote_base = remote_base\n"," self.session = requests.Session()\n"," self.session.headers.update({\"Content-type\": \"application/json\"})\n","\n"," def _parse_server_error_or_raise_for_status(self, resp):\n"," j = {}\n"," try:\n"," j = resp.json()\n"," except:\n"," # Most likely json parse failed because of network error, not server error (server\n"," # sends its errors in json). Don't let parse exception go up, but rather raise default\n"," # error.\n"," resp.raise_for_status()\n"," if (\n"," resp.status_code != 200 and \"message\" in j\n"," ): # descriptive message from server side\n"," raise ServerError(message=j[\"message\"], status_code=resp.status_code)\n"," resp.raise_for_status()\n"," return j\n","\n"," def _post_request(self, route, data):\n"," url = urlparse.urljoin(self.remote_base, route)\n"," # logger.info(\"POST {}\\n{}\".format(url, json.dumps(data)))\n"," resp = self.session.post(\n"," urlparse.urljoin(self.remote_base, route), data=json.dumps(data)\n"," )\n"," return self._parse_server_error_or_raise_for_status(resp)\n","\n"," def _get_request(self, route):\n"," url = urlparse.urljoin(self.remote_base, route)\n"," # logger.info(\"GET {}\".format(url))\n"," resp = self.session.get(url)\n"," return self._parse_server_error_or_raise_for_status(resp)\n","\n"," def env_create(self, env_id):\n"," route = \"/v1/envs/\"\n"," data = {\"env_id\": env_id}\n"," resp = self._post_request(route, data)\n"," instance_id = resp[\"instance_id\"]\n"," return instance_id\n","\n"," def env_list_all(self):\n"," route = \"/v1/envs/\"\n"," resp = self._get_request(route)\n"," all_envs = resp[\"all_envs\"]\n"," return all_envs\n","\n"," def env_reset(self, instance_id):\n"," route = \"/v1/envs/{}/reset/\".format(instance_id)\n"," resp = self._post_request(route, None)\n"," observation = resp[\"observation\"]\n"," return observation\n","\n"," def env_step(self, instance_id, action, render=False):\n"," route = \"/v1/envs/{}/step/\".format(instance_id)\n"," data = {\"action\": action, \"render\": render}\n"," resp = self._post_request(route, data)\n"," observation = resp[\"observation\"]\n"," reward = resp[\"reward\"]\n"," done = resp[\"done\"]\n"," info = resp[\"info\"]\n"," return [observation, reward, done, info]\n","\n"," def env_action_space_info(self, instance_id):\n"," route = \"/v1/envs/{}/action_space/\".format(instance_id)\n"," resp = self._get_request(route)\n"," info = resp[\"info\"]\n"," return info\n","\n"," def env_action_space_sample(self, instance_id):\n"," route = \"/v1/envs/{}/action_space/sample\".format(instance_id)\n"," resp = self._get_request(route)\n"," action = resp[\"action\"]\n"," return action\n","\n"," def env_action_space_contains(self, instance_id, x):\n"," route = \"/v1/envs/{}/action_space/contains/{}\".format(instance_id, x)\n"," resp = self._get_request(route)\n"," member = resp[\"member\"]\n"," return member\n","\n"," def env_observation_space_info(self, instance_id):\n"," route = \"/v1/envs/{}/observation_space/\".format(instance_id)\n"," resp = self._get_request(route)\n"," info = resp[\"info\"]\n"," return info\n","\n"," def env_observation_space_contains(self, instance_id, params):\n"," route = \"/v1/envs/{}/observation_space/contains\".format(instance_id)\n"," resp = self._post_request(route, params)\n"," member = resp[\"member\"]\n"," return member\n","\n"," def env_monitor_start(\n"," self, instance_id, directory, force=False, resume=False, video_callable=False\n"," ):\n"," route = \"/v1/envs/{}/monitor/start/\".format(instance_id)\n"," data = {\n"," \"directory\": directory,\n"," \"force\": force,\n"," \"resume\": resume,\n"," \"video_callable\": video_callable,\n"," }\n"," self._post_request(route, data)\n","\n"," def env_monitor_close(self, instance_id):\n"," route = \"/v1/envs/{}/monitor/close/\".format(instance_id)\n"," self._post_request(route, None)\n","\n"," def env_close(self, instance_id):\n"," route = \"/v1/envs/{}/close/\".format(instance_id)\n"," self._post_request(route, None)\n","\n"," def upload(self, training_dir, algorithm_id=None, api_key=None):\n"," if not api_key:\n"," api_key = os.environ.get(\"OPENAI_GYM_API_KEY\")\n","\n"," route = \"/v1/upload/\"\n"," data = {\n"," \"training_dir\": training_dir,\n"," \"algorithm_id\": algorithm_id,\n"," \"api_key\": api_key,\n"," }\n"," self._post_request(route, data)\n","\n"," def shutdown_server(self):\n"," route = \"/v1/shutdown/\"\n"," self._post_request(route, None)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MdHg1t6BOOzi"},"source":["class ServerError(Exception):\n"," def __init__(self, message, status_code=None):\n"," Exception.__init__(self)\n"," self.message = message\n"," if status_code is not None:\n"," self.status_code = status_code"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ga_fvsUyONaR","executionInfo":{"status":"ok","timestamp":1638517822470,"user_tz":-330,"elapsed":1339,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"808b2f8d-a5dc-411a-c53e-4cf505ef9c65"},"source":["if __name__ == \"__main__\":\n"," remote_base = \"http://0.0.0.0:6666\"\n"," client = Client(remote_base)\n","\n"," # Create environment\n"," env_id = \"StockTradingContinuousEnv-v0\"\n"," # env_id = \"CartPole-v0\"\n"," instance_id = client.env_create(env_id)\n","\n"," # Check properties\n"," all_envs = client.env_list_all()\n"," logger.info(f\"all_envs:{all_envs}\")\n"," action_info = client.env_action_space_info(instance_id)\n"," logger.info(f\"action_info:{action_info}\")\n"," obs_info = client.env_observation_space_info(instance_id)\n"," # logger.info(f\"obs_info:{obs_info}\")\n","\n"," # Run a single step\n"," init_obs = client.env_reset(instance_id)\n"," [observation, reward, done, info] = client.env_step(instance_id, 1, True)\n"," logger.info(f\"reward:{reward} done:{done} info:{info}\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["all_envs:{'208b8156': 'StockTradingContinuousEnv-v0'}\n","action_info:{'high': [1.0], 'low': [-1.0], 'name': 'Box', 'shape': [1]}\n","reward:0.0 done:False info:{}\n"]}]},{"cell_type":"code","metadata":{"id":"kMX09dT9MMF1"},"source":["# !kill -9 $(lsof -t -i:6666)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"U9i6mgMDLauI"},"source":["## Training Deep RL agents using remote simulators"]},{"cell_type":"code","metadata":{"id":"lyOmMj2ASwvO"},"source":["import datetime\n","import os\n","import sys\n","import logging\n","\n","import gym.spaces\n","import numpy as np\n","import tensorflow as tf"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"k2YXu28ySBuN","executionInfo":{"status":"ok","timestamp":1638518748902,"user_tz":-330,"elapsed":109431,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"4adf83a0-0e44-4638-dcb2-61da7914ec1b"},"source":["# Create an App-level child logger\n","logger = logging.getLogger(\"TFRL-training-with-sim-server\")\n","# Set handler for this logger to handle messages\n","logger.addHandler(logging.StreamHandler())\n","# Set logging-level for this logger's handler\n","logger.setLevel(logging.DEBUG)\n","\n","current_time = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n","train_log_dir = os.path.join(\"logs\", \"TFRL-SAC\", current_time)\n","summary_writer = tf.summary.create_file_writer(train_log_dir)\n","\n","\n","if __name__ == \"__main__\":\n","\n"," # Set up client to connect to sim server\n"," sim_service_address = \"http://0.0.0.0:6666\"\n"," client = Client(sim_service_address)\n","\n"," # Set up training environment\n"," env_id = \"StockTradingContinuousEnv-v0\"\n"," instance_id = client.env_create(env_id)\n","\n"," # Set up agent\n"," observation_space_info = client.env_observation_space_info(instance_id)\n"," observation_shape = observation_space_info.get(\"shape\")\n"," action_space_info = client.env_action_space_info(instance_id)\n"," action_space = gym.spaces.Box(\n"," np.array(action_space_info.get(\"low\")),\n"," np.array(action_space_info.get(\"high\")),\n"," action_space_info.get(\"shape\"),\n"," )\n"," agent = SAC(observation_shape, action_space)\n","\n"," # Configure training\n"," max_epochs = 500 # 30000\n"," random_epochs = 0.6 * max_epochs\n"," max_steps = 100\n"," save_freq = 100 # 500\n"," reward = 0\n"," done = False\n","\n"," done, use_random, episode, steps, epoch, episode_reward = (\n"," False,\n"," True,\n"," 0,\n"," 0,\n"," 0,\n"," 0,\n"," )\n"," cur_state = client.env_reset(instance_id)\n","\n"," # Start training\n"," while epoch < max_epochs:\n"," if steps > max_steps:\n"," done = True\n","\n"," if done:\n"," episode += 1\n"," logger.info(\n"," f\"episode:{episode} cumulative_reward:{episode_reward} steps:{steps} epochs:{epoch}\"\n"," )\n"," with summary_writer.as_default():\n"," tf.summary.scalar(\"Main/episode_reward\", episode_reward, step=episode)\n"," tf.summary.scalar(\"Main/episode_steps\", steps, step=episode)\n"," summary_writer.flush()\n","\n"," done, cur_state, steps, episode_reward = (\n"," False,\n"," client.env_reset(instance_id),\n"," 0,\n"," 0,\n"," )\n"," if episode % save_freq == 0:\n"," agent.save_model(\n"," f\"sac_actor_episode{episode}_{env_id}.h5\",\n"," f\"sac_critic_episode{episode}_{env_id}.h5\",\n"," )\n","\n"," if epoch > random_epochs:\n"," use_random = False\n","\n"," action = agent.act(np.array(cur_state), use_random=use_random)\n"," next_state, reward, done, _ = client.env_step(\n"," instance_id, action.numpy().tolist()\n"," )\n"," agent.train(np.array(cur_state), action, reward, np.array(next_state), done)\n","\n"," cur_state = next_state\n"," episode_reward += reward\n"," steps += 1\n"," epoch += 1\n","\n"," # Update Tensorboard with Agent's training status\n"," agent.log_status(summary_writer, epoch, reward)\n"," summary_writer.flush()\n","\n"," agent.save_model(\n"," f\"sac_actor_final_episode_{env_id}.h5\", f\"sac_critic_final_episode_{env_id}.h5\"\n"," )"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Model: \"model_5\"\n","__________________________________________________________________________________________________\n"," Layer (type) Output Shape Param # Connected to \n","==================================================================================================\n"," input_10 (InputLayer) [(None, 186)] 0 [] \n"," \n"," L0 (Dense) (None, 64) 11968 ['input_10[0][0]'] \n"," \n"," L1 (Dense) (None, 64) 4160 ['L0[0][0]'] \n"," \n"," Out_mean (Dense) (None, 1) 65 ['L1[0][0]'] \n"," \n"," Out_std (Dense) (None, 1) 65 ['L1[0][0]'] \n"," \n","==================================================================================================\n","Total params: 16,258\n","Trainable params: 16,258\n","Non-trainable params: 0\n","__________________________________________________________________________________________________\n","None\n","Model: \"model_6\"\n","__________________________________________________________________________________________________\n"," Layer (type) Output Shape Param # Connected to \n","==================================================================================================\n"," input_11 (InputLayer) [(None, 186)] 0 [] \n"," \n"," input_12 (InputLayer) [(None, 1)] 0 [] \n"," \n"," concatenate_4 (Concatenate) (None, 187) 0 ['input_11[0][0]', \n"," 'input_12[0][0]'] \n"," \n"," Hidden0 (Dense) (None, 64) 12032 ['concatenate_4[0][0]'] \n"," \n"," Hidden1 (Dense) (None, 64) 4160 ['Hidden0[0][0]'] \n"," \n"," Out_QVal (Dense) (None, 1) 65 ['Hidden1[0][0]'] \n"," \n","==================================================================================================\n","Total params: 16,257\n","Trainable params: 16,257\n","Non-trainable params: 0\n","__________________________________________________________________________________________________\n","None\n"]},{"output_type":"stream","name":"stderr","text":["episode:1 cumulative_reward:233.1740702673518 steps:9 epochs:9\n","episode:1 cumulative_reward:233.1740702673518 steps:9 epochs:9\n","episode:2 cumulative_reward:261.92559452597607 steps:9 epochs:18\n","episode:2 cumulative_reward:261.92559452597607 steps:9 epochs:18\n","episode:3 cumulative_reward:321.24403826611274 steps:9 epochs:27\n","episode:3 cumulative_reward:321.24403826611274 steps:9 epochs:27\n","episode:4 cumulative_reward:256.7701639306034 steps:9 epochs:36\n","episode:4 cumulative_reward:256.7701639306034 steps:9 epochs:36\n","episode:5 cumulative_reward:243.89784571937196 steps:9 epochs:45\n","episode:5 cumulative_reward:243.89784571937196 steps:9 epochs:45\n","episode:6 cumulative_reward:230.29517490639364 steps:9 epochs:54\n","episode:6 cumulative_reward:230.29517490639364 steps:9 epochs:54\n","episode:7 cumulative_reward:211.36595799415477 steps:9 epochs:63\n","episode:7 cumulative_reward:211.36595799415477 steps:9 epochs:63\n","episode:8 cumulative_reward:302.6277125060683 steps:9 epochs:72\n","episode:8 cumulative_reward:302.6277125060683 steps:9 epochs:72\n","episode:9 cumulative_reward:316.0823152105854 steps:9 epochs:81\n","episode:9 cumulative_reward:316.0823152105854 steps:9 epochs:81\n","episode:10 cumulative_reward:293.7166451140265 steps:9 epochs:90\n","episode:10 cumulative_reward:293.7166451140265 steps:9 epochs:90\n","episode:11 cumulative_reward:222.07366889740342 steps:9 epochs:99\n","episode:11 cumulative_reward:222.07366889740342 steps:9 epochs:99\n","episode:12 cumulative_reward:240.87059492654691 steps:9 epochs:108\n","episode:12 cumulative_reward:240.87059492654691 steps:9 epochs:108\n","episode:13 cumulative_reward:328.90013981451773 steps:9 epochs:117\n","episode:13 cumulative_reward:328.90013981451773 steps:9 epochs:117\n","episode:14 cumulative_reward:156.0423163113037 steps:9 epochs:126\n","episode:14 cumulative_reward:156.0423163113037 steps:9 epochs:126\n","episode:15 cumulative_reward:236.0697199301511 steps:9 epochs:135\n","episode:15 cumulative_reward:236.0697199301511 steps:9 epochs:135\n","episode:16 cumulative_reward:233.79245189225708 steps:9 epochs:144\n","episode:16 cumulative_reward:233.79245189225708 steps:9 epochs:144\n","episode:17 cumulative_reward:26.059310669382285 steps:9 epochs:153\n","episode:17 cumulative_reward:26.059310669382285 steps:9 epochs:153\n","episode:18 cumulative_reward:81.53442025301308 steps:9 epochs:162\n","episode:18 cumulative_reward:81.53442025301308 steps:9 epochs:162\n","episode:19 cumulative_reward:206.21227849492084 steps:9 epochs:171\n","episode:19 cumulative_reward:206.21227849492084 steps:9 epochs:171\n","episode:20 cumulative_reward:209.02516909434826 steps:9 epochs:180\n","episode:20 cumulative_reward:209.02516909434826 steps:9 epochs:180\n","episode:21 cumulative_reward:95.10058260351684 steps:9 epochs:189\n","episode:21 cumulative_reward:95.10058260351684 steps:9 epochs:189\n","episode:22 cumulative_reward:194.6398110285793 steps:9 epochs:198\n","episode:22 cumulative_reward:194.6398110285793 steps:9 epochs:198\n","episode:23 cumulative_reward:75.56473537750503 steps:9 epochs:207\n","episode:23 cumulative_reward:75.56473537750503 steps:9 epochs:207\n","episode:24 cumulative_reward:95.47650417557884 steps:9 epochs:216\n","episode:24 cumulative_reward:95.47650417557884 steps:9 epochs:216\n","episode:25 cumulative_reward:303.2763836178467 steps:9 epochs:225\n","episode:25 cumulative_reward:303.2763836178467 steps:9 epochs:225\n","episode:26 cumulative_reward:181.11959752307575 steps:9 epochs:234\n","episode:26 cumulative_reward:181.11959752307575 steps:9 epochs:234\n","episode:27 cumulative_reward:31.754469251230944 steps:9 epochs:243\n","episode:27 cumulative_reward:31.754469251230944 steps:9 epochs:243\n","episode:28 cumulative_reward:117.503465579195 steps:9 epochs:252\n","episode:28 cumulative_reward:117.503465579195 steps:9 epochs:252\n","episode:29 cumulative_reward:99.90602867259929 steps:9 epochs:261\n","episode:29 cumulative_reward:99.90602867259929 steps:9 epochs:261\n","episode:30 cumulative_reward:93.7334058460674 steps:9 epochs:270\n","episode:30 cumulative_reward:93.7334058460674 steps:9 epochs:270\n","episode:31 cumulative_reward:34.46928390982646 steps:9 epochs:279\n","episode:31 cumulative_reward:34.46928390982646 steps:9 epochs:279\n","episode:32 cumulative_reward:180.10973321057725 steps:9 epochs:288\n","episode:32 cumulative_reward:180.10973321057725 steps:9 epochs:288\n","episode:33 cumulative_reward:148.11403567440448 steps:9 epochs:297\n","episode:33 cumulative_reward:148.11403567440448 steps:9 epochs:297\n","episode:34 cumulative_reward:174.93217712407602 steps:9 epochs:306\n","episode:34 cumulative_reward:174.93217712407602 steps:9 epochs:306\n","episode:35 cumulative_reward:180.9962155749114 steps:9 epochs:315\n","episode:35 cumulative_reward:180.9962155749114 steps:9 epochs:315\n","episode:36 cumulative_reward:225.85023074597552 steps:9 epochs:324\n","episode:36 cumulative_reward:225.85023074597552 steps:9 epochs:324\n","episode:37 cumulative_reward:206.50479409381308 steps:9 epochs:333\n","episode:37 cumulative_reward:206.50479409381308 steps:9 epochs:333\n","episode:38 cumulative_reward:288.47638902850963 steps:9 epochs:342\n","episode:38 cumulative_reward:288.47638902850963 steps:9 epochs:342\n","episode:39 cumulative_reward:283.9976062844444 steps:9 epochs:351\n","episode:39 cumulative_reward:283.9976062844444 steps:9 epochs:351\n","episode:40 cumulative_reward:215.61304171572306 steps:9 epochs:360\n","episode:40 cumulative_reward:215.61304171572306 steps:9 epochs:360\n","episode:41 cumulative_reward:309.9738444209918 steps:9 epochs:369\n","episode:41 cumulative_reward:309.9738444209918 steps:9 epochs:369\n","episode:42 cumulative_reward:257.2864536465588 steps:9 epochs:378\n","episode:42 cumulative_reward:257.2864536465588 steps:9 epochs:378\n","episode:43 cumulative_reward:140.890514619103 steps:9 epochs:387\n","episode:43 cumulative_reward:140.890514619103 steps:9 epochs:387\n","episode:44 cumulative_reward:271.4749288865371 steps:9 epochs:396\n","episode:44 cumulative_reward:271.4749288865371 steps:9 epochs:396\n","episode:45 cumulative_reward:297.9361340009875 steps:9 epochs:405\n","episode:45 cumulative_reward:297.9361340009875 steps:9 epochs:405\n","episode:46 cumulative_reward:271.53708554801676 steps:9 epochs:414\n","episode:46 cumulative_reward:271.53708554801676 steps:9 epochs:414\n","episode:47 cumulative_reward:273.282989472988 steps:9 epochs:423\n","episode:47 cumulative_reward:273.282989472988 steps:9 epochs:423\n","episode:48 cumulative_reward:253.60024494308152 steps:9 epochs:432\n","episode:48 cumulative_reward:253.60024494308152 steps:9 epochs:432\n","episode:49 cumulative_reward:290.06738997472064 steps:9 epochs:441\n","episode:49 cumulative_reward:290.06738997472064 steps:9 epochs:441\n","episode:50 cumulative_reward:174.05891005797082 steps:9 epochs:450\n","episode:50 cumulative_reward:174.05891005797082 steps:9 epochs:450\n","episode:51 cumulative_reward:216.02808332903987 steps:9 epochs:459\n","episode:51 cumulative_reward:216.02808332903987 steps:9 epochs:459\n","episode:52 cumulative_reward:341.3635398833105 steps:9 epochs:468\n","episode:52 cumulative_reward:341.3635398833105 steps:9 epochs:468\n","episode:53 cumulative_reward:189.98273007161868 steps:9 epochs:477\n","episode:53 cumulative_reward:189.98273007161868 steps:9 epochs:477\n","episode:54 cumulative_reward:191.61824861133084 steps:9 epochs:486\n","episode:54 cumulative_reward:191.61824861133084 steps:9 epochs:486\n","episode:55 cumulative_reward:199.3696430663797 steps:9 epochs:495\n","episode:55 cumulative_reward:199.3696430663797 steps:9 epochs:495\n"]},{"output_type":"stream","name":"stdout","text":["WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n","WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n"]}]},{"cell_type":"markdown","metadata":{"id":"vsP4OKmBUX0P"},"source":["## Evaluating deep RL agents"]},{"cell_type":"markdown","metadata":{"id":"hGaHZJUOTu8B"},"source":["Let’s assume that you have trained the SAC agent in one of the trading environments using the training script (previous recipe) and that you have several versions of the trained agent models, each with different policy network architectures or hyperparameters or your own tweaks and customizations to improve its performance. When you want to deploy an agent, you want to make sure that you pick the best performing agent, don’t you?"]},{"cell_type":"markdown","metadata":{"id":"73tLkCZAUV_e"},"source":["We will build a lean script to evaluate a given pre-trained agent model locally so that you can get a quantitative performance assessment and compare several trained models before choosing the right agent model for deployment. Specifically, we will use the tradegym module and the sac_agent_runtime module that we built earlier in this chapter to evaluate the agent models that we train."]},{"cell_type":"code","metadata":{"id":"kLMILht6Uw0P"},"source":["import os\n","import sys\n","\n","from argparse import ArgumentParser\n","import imageio\n","import gym\n","\n","import tradegym\n","\n","parser = ArgumentParser(prog=\"TFRL-Evaluating-RL-Agents\")\n","parser.add_argument(\"--agent\", default=\"SAC\", help=\"Name of Agent. Default=SAC\")\n","parser.add_argument(\n"," \"--env\",\n"," default=\"StockTradingContinuousEnv-v0\",\n"," help=\"Name of Gym env. Default=StockTradingContinuousEnv-v0\",\n",")\n","parser.add_argument(\n"," \"--num-episodes\",\n"," default=10,\n"," help=\"Number of episodes to evaluate the agent. Default=100\",\n",")\n","parser.add_argument(\n"," \"--trained-models-dir\",\n"," default=\"/content\",\n"," help=\"Directory contained trained models.\",\n",")\n","parser.add_argument(\n"," \"--model-version\",\n"," default=\"final_episode_StockTradingContinuousEnv-v0\",\n"," help=\"Trained model version\",\n",")\n","parser.add_argument(\n"," \"--render\",\n"," type=bool,\n"," help=\"Render environment and write to file? (True/False). Default=True\",\n",")\n","args = parser.parse_args([])\n","\n","\n","if __name__ == \"__main__\":\n"," # Create an instance of the evaluation environment\n"," env = gym.make(args.env)\n"," if args.agent != \"SAC\":\n"," print(f\"Unsupported Agent: {args.agent}. Using SAC Agent\")\n"," args.agent = \"SAC\"\n"," # Create an instance of the Soft Actor-Critic Agent\n"," agent = SAC(env.observation_space.shape, env.action_space)\n"," # Load trained Agent model/brain\n"," model_version = args.model_version\n"," agent.load_actor(\n"," os.path.join(args.trained_models_dir, f\"sac_actor_{model_version}.h5\")\n"," )\n"," agent.load_critic(\n"," os.path.join(args.trained_models_dir, f\"sac_critic_{model_version}.h5\")\n"," )\n"," print(f\"Loaded {args.agent} agent with trained model version:{model_version}\")\n"," render = args.render\n"," # Evaluate/Test/Rollout Agent with trained model/brain\n"," video = imageio.get_writer(\"/content/agent_eval_video.mp4\", fps=30)\n"," avg_reward = 0\n"," for i in range(args.num_episodes):\n"," cur_state, done, rewards = env.reset(), False, 0\n"," while not done:\n"," action = agent.act(cur_state, test=True)\n"," next_state, reward, done, _ = env.step(action[0])\n"," cur_state = next_state\n"," rewards += reward\n"," if render:\n"," video.append_data(env.render(mode=\"rgb_array\"))\n"," print(f\"Episode#:{i} cumulative_reward:{rewards}\")\n"," avg_reward += rewards\n"," avg_reward /= args.num_episodes\n"," video.close()\n"," print(f\"Average rewards over {args.num_episodes} episodes: {avg_reward}\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"R3FOPry1VBCc"},"source":["%load_ext tensorboard"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"RYVrmJidWRhE"},"source":["%tensorboard --logdir /content/logs/TFRL-SAC"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K3mvMj_nXE_p"},"source":["![image.png]()"]},{"cell_type":"markdown","metadata":{"id":"c0vq9jqtYGRq"},"source":["## Packaging deep RL agents for cloud deployments"]},{"cell_type":"code","metadata":{"id":"vbasjLj1YTHc"},"source":["%%writefile sac_runtime_components.py\n","import functools\n","from collections import deque\n","\n","import numpy as np\n","import tensorflow as tf\n","import tensorflow_probability as tfp\n","from tensorflow.keras.layers import Concatenate, Dense, Input\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.optimizers import Adam\n","\n","tf.keras.backend.set_floatx(\"float64\")\n","\n","\n","def actor(state_shape, action_shape, units=(512, 256, 64)):\n"," state_shape_flattened = functools.reduce(lambda x, y: x * y, state_shape)\n"," state = Input(shape=state_shape_flattened)\n"," x = Dense(units[0], name=\"L0\", activation=\"relu\")(state)\n"," for index in range(1, len(units)):\n"," x = Dense(units[index], name=\"L{}\".format(index), activation=\"relu\")(x)\n","\n"," actions_mean = Dense(action_shape[0], name=\"Out_mean\")(x)\n"," actions_std = Dense(action_shape[0], name=\"Out_std\")(x)\n","\n"," model = Model(inputs=state, outputs=[actions_mean, actions_std])\n","\n"," return model\n","\n","\n","def critic(state_shape, action_shape, units=(512, 256, 64)):\n"," state_shape_flattened = functools.reduce(lambda x, y: x * y, state_shape)\n"," inputs = [Input(shape=state_shape_flattened), Input(shape=action_shape)]\n"," concat = Concatenate(axis=-1)(inputs)\n"," x = Dense(units[0], name=\"Hidden0\", activation=\"relu\")(concat)\n"," for index in range(1, len(units)):\n"," x = Dense(units[index], name=\"Hidden{}\".format(index), activation=\"relu\")(x)\n","\n"," output = Dense(1, name=\"Out_QVal\")(x)\n"," model = Model(inputs=inputs, outputs=output)\n","\n"," return model\n","\n","\n","def update_target_weights(model, target_model, tau=0.005):\n"," weights = model.get_weights()\n"," target_weights = target_model.get_weights()\n"," for i in range(len(target_weights)): # set tau% of target model to be new weights\n"," target_weights[i] = weights[i] * tau + target_weights[i] * (1 - tau)\n"," target_model.set_weights(target_weights)\n","\n","\n","class SAC(object):\n"," def __init__(\n"," self,\n"," observation_shape,\n"," action_space,\n"," lr_actor=3e-5,\n"," lr_critic=3e-4,\n"," actor_units=(64, 64),\n"," critic_units=(64, 64),\n"," auto_alpha=True,\n"," alpha=0.2,\n"," tau=0.005,\n"," gamma=0.99,\n"," batch_size=128,\n"," memory_cap=100000,\n"," ):\n"," self.state_shape = observation_shape # shape of observations\n"," self.action_shape = action_space.shape # number of actions\n"," self.action_bound = (action_space.high - action_space.low) / 2\n"," self.action_shift = (action_space.high + action_space.low) / 2\n"," self.memory = deque(maxlen=int(memory_cap))\n","\n"," # Define and initialize actor network\n"," self.actor = actor(self.state_shape, self.action_shape, actor_units)\n"," self.actor_optimizer = Adam(learning_rate=lr_actor)\n"," self.log_std_min = -20\n"," self.log_std_max = 2\n"," print(self.actor.summary())\n","\n"," # Define and initialize critic networks\n"," self.critic_1 = critic(self.state_shape, self.action_shape, critic_units)\n"," self.critic_target_1 = critic(self.state_shape, self.action_shape, critic_units)\n"," self.critic_optimizer_1 = Adam(learning_rate=lr_critic)\n"," update_target_weights(self.critic_1, self.critic_target_1, tau=1.0)\n","\n"," self.critic_2 = critic(self.state_shape, self.action_shape, critic_units)\n"," self.critic_target_2 = critic(self.state_shape, self.action_shape, critic_units)\n"," self.critic_optimizer_2 = Adam(learning_rate=lr_critic)\n"," update_target_weights(self.critic_2, self.critic_target_2, tau=1.0)\n","\n"," print(self.critic_1.summary())\n","\n"," # Define and initialize temperature alpha and target entropy\n"," self.auto_alpha = auto_alpha\n"," if auto_alpha:\n"," self.target_entropy = -np.prod(self.action_shape)\n"," self.log_alpha = tf.Variable(0.0, dtype=tf.float64)\n"," self.alpha = tf.Variable(0.0, dtype=tf.float64)\n"," self.alpha.assign(tf.exp(self.log_alpha))\n"," self.alpha_optimizer = Adam(learning_rate=lr_actor)\n"," else:\n"," self.alpha = tf.Variable(alpha, dtype=tf.float64)\n","\n"," # Set hyperparameters\n"," self.gamma = gamma # discount factor\n"," self.tau = tau # target model update\n"," self.batch_size = batch_size\n","\n"," # Tensorboard\n"," self.summaries = {}\n","\n"," def process_actions(self, mean, log_std, test=False, eps=1e-6):\n"," std = tf.math.exp(log_std)\n"," raw_actions = mean\n","\n"," if not test:\n"," raw_actions += tf.random.normal(shape=mean.shape, dtype=tf.float64) * std\n","\n"," log_prob_u = tfp.distributions.Normal(loc=mean, scale=std).log_prob(raw_actions)\n"," actions = tf.math.tanh(raw_actions)\n","\n"," log_prob = tf.reduce_sum(log_prob_u - tf.math.log(1 - actions ** 2 + eps))\n","\n"," actions = actions * self.action_bound + self.action_shift\n","\n"," return actions, log_prob\n","\n"," def act(self, state, test=False, use_random=False):\n"," state = state.reshape(-1) # Flatten state\n"," state = np.expand_dims(state, axis=0).astype(np.float64)\n","\n"," if use_random:\n"," a = tf.random.uniform(\n"," shape=(1, self.action_shape[0]), minval=-1, maxval=1, dtype=tf.float64\n"," )\n"," else:\n"," means, log_stds = self.actor.predict(state)\n"," log_stds = tf.clip_by_value(log_stds, self.log_std_min, self.log_std_max)\n","\n"," a, log_prob = self.process_actions(means, log_stds, test=test)\n","\n"," q1 = self.critic_1.predict([state, a])[0][0]\n"," q2 = self.critic_2.predict([state, a])[0][0]\n"," self.summaries[\"q_min\"] = tf.math.minimum(q1, q2)\n"," self.summaries[\"q_mean\"] = np.mean([q1, q2])\n","\n"," return a\n","\n"," def load_actor(self, a_fn):\n"," self.actor.load_weights(a_fn)\n"," print(self.actor.summary())\n","\n"," def load_critic(self, c_fn):\n"," self.critic_1.load_weights(c_fn)\n"," self.critic_target_1.load_weights(c_fn)\n"," self.critic_2.load_weights(c_fn)\n"," self.critic_target_2.load_weights(c_fn)\n"," print(self.critic_1.summary())"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"T6YFCS2rYG9B","executionInfo":{"status":"ok","timestamp":1638519872409,"user_tz":-330,"elapsed":498,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"7b6bce9d-fe95-4c80-87d1-8cf8d9d058cc"},"source":["%%writefile trading_agent.py\n","import os\n","import sys\n","from argparse import ArgumentParser\n","\n","import gym.spaces\n","from flask import Flask, request\n","import numpy as np\n","\n","from sac_runtime_components import SAC\n","\n","\n","parser = ArgumentParser(\n"," prog=\"TFRL-Packaging-RL-Agents-For-Cloud-Deployments\"\n",")\n","\n","parser.add_argument(\"--agent\", default=\"SAC\", help=\"Name of Agent. Default=SAC\")\n","parser.add_argument(\n"," \"--host-ip\",\n"," default=\"0.0.0.0\",\n"," help=\"IP Address of the host server where Agent service is run. Default=127.0.0.1\",\n",")\n","parser.add_argument(\n"," \"--host-port\",\n"," default=\"5555\",\n"," help=\"Port on the host server to use for Agent service. Default=5555\",\n",")\n","parser.add_argument(\n"," \"--trained-models-dir\",\n"," default=\"/content\",\n"," help=\"Directory contained trained models. Default=content\",\n",")\n","parser.add_argument(\n"," \"--config\",\n"," default=\"runtime_config.json\",\n"," help=\"Runtime config parameters for the Agent. Default=runtime_config.json\",\n",")\n","parser.add_argument(\n"," \"--observation-shape\",\n"," default=(6, 31),\n"," help=\"Shape of observations. Default=(6, 31)\",\n",")\n","parser.add_argument(\n"," \"--action-space-low\", default=[-1], help=\"Low value of action space. Default=[-1]\"\n",")\n","parser.add_argument(\n"," \"--action-space-high\", default=[1], help=\"High value of action space. Default=[1]\"\n",")\n","parser.add_argument(\n"," \"--action-shape\", default=(1,), help=\"Shape of actions. Default=(1,)\"\n",")\n","parser.add_argument(\n"," \"--model-version\",\n"," default=\"final_episode_StockTradingContinuousEnv-v0\",\n"," help=\"Trained model version\",\n",")\n","args = parser.parse_args()\n","\n","\n","if __name__ == \"__main__\":\n"," if args.agent != \"SAC\":\n"," print(f\"Unsupported Agent: {args.agent}. Using SAC Agent\")\n"," args.agent = \"SAC\"\n"," # Set Agent's runtime configs\n"," observation_shape = args.observation_shape\n"," action_space = gym.spaces.Box(\n"," np.array(args.action_space_low),\n"," np.array(args.action_space_high),\n"," args.action_shape,\n"," )\n","\n"," # Create an instance of the Agent\n"," agent = SAC(observation_shape, action_space)\n"," # Load trained Agent model/brain\n"," model_version = args.model_version\n"," agent.load_actor(\n"," os.path.join(args.trained_models_dir, f\"sac_actor_{model_version}.h5\")\n"," )\n"," agent.load_critic(\n"," os.path.join(args.trained_models_dir, f\"sac_critic_{model_version}.h5\")\n"," )\n"," print(f\"Loaded {args.agent} agent with trained model version:{model_version}\")\n","\n"," # Setup Agent (http) service\n"," app = Flask(__name__)\n","\n"," @app.route(\"/v1/act\", methods=[\"POST\"])\n"," def get_action():\n"," data = request.get_json()\n"," action = agent.act(np.array(data.get(\"observation\")), test=True)\n"," return {\"action\": action.numpy().tolist()}\n","\n"," # Launch/Run the Agent (http) service\n"," app.run(host=args.host_ip, port=args.host_port, debug=True)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Writing trading_agent.py\n"]}]},{"cell_type":"code","metadata":{"id":"X8VGTqOYZRkG"},"source":["!sudo nohup python trading_agent.py > log_trading_agent.txt 2>&1 &"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6uy4LrGHZRkG","executionInfo":{"status":"ok","timestamp":1638520234396,"user_tz":-330,"elapsed":436,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"6097a6d5-2864-4d1d-e28d-db4f587174a4"},"source":["!head log_trading_agent.txt"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["2021-12-03 08:29:23.726443: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n","Model: \"model\"\n","__________________________________________________________________________________________________\n"," Layer (type) Output Shape Param # Connected to \n","==================================================================================================\n"," input_1 (InputLayer) [(None, 186)] 0 [] \n"," \n"," L0 (Dense) (None, 64) 11968 ['input_1[0][0]'] \n"," \n"," L1 (Dense) (None, 64) 4160 ['L0[0][0]'] \n"]}]},{"cell_type":"markdown","metadata":{"id":"sflc0BYJZWgq"},"source":["## Simple test for the deployed Trading Bot-as-a-Service"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ANaCrnOGauZk","executionInfo":{"status":"ok","timestamp":1638520309028,"user_tz":-330,"elapsed":950,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"a2e88932-0e52-4239-b60c-9ef130567c19"},"source":["import os\n","import sys\n","\n","import gym\n","import requests\n","\n","import tradegym # Register tradegym envs with OpenAI Gym registry\n","\n","host_ip = \"127.0.0.1\"\n","host_port = 5555\n","endpoint = \"v1/act\"\n","env = gym.make(\"StockTradingContinuousEnv-v0\")\n","\n","post_data = {\"observation\": env.observation_space.sample().tolist()}\n","res = requests.post(f\"http://{host_ip}:{host_port}/{endpoint}\", json=post_data)\n","if res.ok:\n"," print(f\"Received Agent action:{res.json()}\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Received Agent action:{'action': [[0.2692909986143388]]}\n"]}]},{"cell_type":"markdown","metadata":{"id":"ZZZGrJ1lbsox"},"source":["---"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"7vSyzqg4bsoy","executionInfo":{"status":"ok","timestamp":1638520570638,"user_tz":-330,"elapsed":7178,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"bc9bfb31-31ee-488a-8f26-441e61f2f903"},"source":["!apt-get -qq install tree\n","!rm -r sample_data"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Selecting previously unselected package tree.\n","(Reading database ... 155222 files and directories currently installed.)\n","Preparing to unpack .../tree_1.7.0-5_amd64.deb ...\n","Unpacking tree (1.7.0-5) ...\n","Setting up tree (1.7.0-5) ...\n","Processing triggers for man-db (2.8.3-2ubuntu0.1) ...\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"h_e9ZLORbsoy","executionInfo":{"status":"ok","timestamp":1638520570639,"user_tz":-330,"elapsed":22,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"f5358163-3101-41fb-972a-60d145ca9fb4"},"source":["!tree -h --du ."],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":[".\n","├── [924K] logs\n","│   └── [920K] TFRL-SAC\n","│   ├── [623K] 20211203-080023\n","│   │   └── [619K] events.out.tfevents.1638518423.fe9c99dc08cc.645.0.v2\n","│   └── [293K] 20211203-080403\n","│   └── [289K] events.out.tfevents.1638518643.fe9c99dc08cc.645.1.v2\n","├── [8.1K] log_trading_agent.txt\n","├── [141K] log.txt\n","├── [8.9K] __pycache__\n","│   └── [4.9K] sac_runtime_components.cpython-37.pyc\n","├── [148K] sac_actor_final_episode_StockTradingContinuousEnv-v0.h5\n","├── [146K] sac_critic_final_episode_StockTradingContinuousEnv-v0.h5\n","├── [5.8K] sac_runtime_components.py\n","├── [ 83K] tradegym\n","│   ├── [7.3K] crypto_trading_env.py\n","│   ├── [ 41K] data\n","│   │   ├── [ 19K] MSFT.csv\n","│   │   └── [ 18K] TSLA.csv\n","│   ├── [ 775] __init__.py\n","│   ├── [ 14K] __pycache__\n","│   │   ├── [ 685] __init__.cpython-37.pyc\n","│   │   ├── [4.3K] stock_trading_continuous_env.cpython-37.pyc\n","│   │   └── [5.4K] trading_utils.cpython-37.pyc\n","│   ├── [7.1K] stock_trading_continuous_env.py\n","│   └── [8.2K] trading_utils.py\n","├── [ 16K] tradegym_server.py\n","├── [ 20K] tradegym.zip\n","└── [2.7K] trading_agent.py\n","\n"," 1.5M used in 8 directories, 20 files\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"z777BEW9bsoz","executionInfo":{"status":"ok","timestamp":1638520664655,"user_tz":-330,"elapsed":6813,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"835c2920-60e8-4d3e-f4fa-1a72851c7d58"},"source":["!pip install -q watermark\n","%reload_ext watermark\n","%watermark -a \"Sparsh A.\" -m -iv -u -t -d -p numpy,tensorflow,flask"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Author: Sparsh A.\n","\n","Last updated: 2021-12-03 08:37:44\n","\n","numpy : 1.19.5\n","tensorflow: 2.7.0\n","flask : 2.0.2\n","\n","Compiler : GCC 7.5.0\n","OS : Linux\n","Release : 5.4.104+\n","Machine : x86_64\n","Processor : x86_64\n","CPU cores : 2\n","Architecture: 64bit\n","\n","gym : 0.17.3\n","sys : 3.7.12 (default, Sep 10 2021, 00:21:48) \n","[GCC 7.5.0]\n","IPython : 5.5.0\n","requests: 2.23.0\n","\n"]}]},{"cell_type":"markdown","metadata":{"id":"cXAMfLf9bsoz"},"source":["---"]},{"cell_type":"markdown","metadata":{"id":"qYeBkCjDbso0"},"source":["**END**"]}]}