{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PPO\n", "\n", "The goal of this exercise is to use the `tianshou` library to apply PPO on the cartpole environment. `tianshou` is a recent and actively maintained DRL library. It is based on pytorch for the deep networks and works with gymnasium rather than the older gym.\n", "\n", "Github: \\\n", "Documentation: \n", "\n", "Install it in your virtual environment simply with:\n", "\n", "```bash\n", "pip install -U tianshou\n", "```\n", "\n", "It will also install pytorch, which duplicates tensorflow if you already have it installed, but well, storage is cheap...\n", "\n", "Let's first import the usual stuff:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "try:\n", " import google.colab\n", " IN_COLAB = True\n", "except ImportError:\n", " IN_COLAB = False\n", "\n", "if IN_COLAB:\n", " !pip install -U gymnasium pygame moviepy\n", " !pip install -U tianshou" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gym version: 0.28.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/vitay/.virtualenvs/tianshou/lib/python3.11/site-packages/pygame/pkgdata.py:25: DeprecationWarning: pkg_resources is deprecated as an API. 
See https://setuptools.pypa.io/en/latest/pkg_resources.html\n", " from pkg_resources import resource_stream, resource_exists\n" ] } ], "source": [ "import numpy as np\n", "rng = np.random.default_rng()\n", "import matplotlib.pyplot as plt\n", "import os\n", "from IPython.display import clear_output\n", "from collections import deque\n", "\n", "import gymnasium as gym\n", "print(\"gym version:\", gym.__version__)\n", "\n", "import tianshou as ts\n", "\n", "import torch\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "import pygame\n", "from moviepy.editor import ImageSequenceClip, ipython_display\n", "\n", "\n", "class GymRecorder(object):\n", " \"\"\"\n", " Simple wrapper over moviepy to generate a .gif with the frames of a gym environment.\n", " \n", " The environment must have the render_mode `rgb_array_list`.\n", " \"\"\"\n", " def __init__(self, env):\n", " self.env = env\n", " self._frames = []\n", "\n", " def record(self, frames):\n", " \"To be called at the end of an episode.\"\n", " for frame in frames:\n", " self._frames.append(np.array(frame))\n", "\n", " def make_video(self, filename):\n", " \"Generates the gif video.\"\n", " directory = os.path.dirname(os.path.abspath(filename))\n", " if not os.path.exists(directory):\n", " os.mkdir(directory)\n", " self.clip = ImageSequenceClip(list(self._frames), fps=self.env.metadata[\"render_fps\"])\n", " self.clip.write_gif(filename, fps=self.env.metadata[\"render_fps\"], loop=0)\n", " del self._frames\n", " self._frames = []" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## tianshou\n", "\n", "### Structure\n", "\n", "``tianshou`` provides an implementation of most model-free algorithms seen in the course: DQN and its variants, A3C, DDPG, PPO and more. It also has several offline RL algorithms. 
You can see the list of algorithms here:\n", "\n", "\n", "\n", "``tianshou`` relies on several concepts, which are explained here:\n", "\n", "\n", "\n", "![](https://tianshou.readthedocs.io/en/latest/_images/concepts_arch2.png)\n", "\n", "* The **policy** is actually the DRL algorithm (DQN, PPO), not the mapping from states into actions used in the course. It relies on one (or more) neural networks called the **model**.\n", "* The interaction of the policy with the environment is done by the **collector**. By default, the collector uses **distributed learning**, i.e. it uses parallel workers to interact with copies of the environment, thereby speeding up data collection. This is used even for algorithms which do not need distributed learning (DQN), as it can only speed up data collection.\n", "* The data collected by the collector is stored in a **buffer**, which can be an experience replay memory (ERM) for off-policy algorithms or a temporary buffer for on-policy ones.\n", "* The (distributed) data is stored in **batches**. How data circulates between the collector, the policy and the buffer during training is controlled by the **trainer**.\n", "\n", "Let's demonstrate this interaction with a dummy DQN network on Cartpole:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "env = gym.make('CartPole-v0')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Policy\n", "\n", "The first step is to create the neural network for the DQN agent. It must have 4 input neurons (`env.observation_space.shape` is `(4,)`) and `env.action_space.n=2` output neurons, one per discrete action. 
Let's use two hidden layers of 64 neurons each for now:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "net = ts.utils.net.common.Net(\n", " env.observation_space.shape,\n", " env.action_space.n,\n", " hidden_sizes=[64, 64],\n", " device=device,\n", ").to(device)\n", "\n", "optim = torch.optim.Adam(net.parameters(), lr=0.001)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`optim` is the Adam optimizer in pytorch, modifying all parameters (weights and biases) of the value network. Check the doc of `Net()` if you want a more specific architecture.\n", "\n", "The output layer of the network is discrete, so that tianshou knows how to sample an action from the output (here the output neurons represent the Q-values, but they could also be the logits of a stochastic policy over the discrete actions). \n", "\n", "Now that we have the neural network, we can create the DQN policy object:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "policy = ts.policy.DQNPolicy(\n", " model=net, # value network\n", " optim=optim, # optimizer\n", " discount_factor=0.95, # gamma\n", " target_update_freq=1000, # how often to update the target network\n", " action_space=env.action_space, # action space\n", ")\n", "policy.set_eps(0.1) # epsilon-greedy action selection" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the doc of `DQNPolicy` for additional parameters (e.g. to implement a double duelling DQN).\n", "\n", "We can now use the policy to interact with the environment as usual and visualize a trial with an untrained network:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Return: 10.0\n", "MoviePy - Building file videos/cartpole-before.gif with imageio.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluation mode\n", "policy.eval() \n", "\n", "# Create a recordable environment\n", "env = gym.make('CartPole-v0', render_mode=\"rgb_array_list\")\n", "recorder = GymRecorder(env)\n", "\n", "# Sample the initial state\n", "state, info = env.reset()\n", "\n", "# One episode:\n", "done = False\n", "return_episode = 0\n", "while not done:\n", "\n", " # Select an action from the learned policy\n", " action = policy.forward(ts.data.Batch(obs=[state], info=None)).act[0]\n", " \n", " # Sample a single transition\n", " next_state, reward, terminal, truncated, info = env.step(action)\n", "\n", " # End of the episode\n", " done = terminal or truncated\n", "\n", " # Update undiscounted return\n", " return_episode += reward\n", " \n", " # Go in the next state\n", " state = next_state\n", "\n", "print(\"Return:\", return_episode)\n", "\n", "recorder.record(env.render())\n", "video = \"videos/cartpole-before.gif\"\n", "recorder.make_video(video)\n", "ipython_display(video)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The action selection is done by calling `policy.forward()` on a batch of data containing only the current state:\n", "\n", "```python\n", "action = policy.forward(ts.data.Batch(obs=[state], info=None)).act[0]\n", "```\n", "\n", "### Collector\n", "\n", "As we have seen in the DQN exercise, using a neural network with a batch size of 1 is extremely inefficient and slow. It is much better to use **distributed learning** and parallel workers to collect data. That way, we can form a minibatch of states that can be processed efficiently by the NN. 
" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "collector = ts.data.Collector(\n", " policy=policy, \n", " env=ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]),\n", " exploration_noise=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`policy` is the exploration policy. The `exploration_noise` flag allows to switch exploration on and off. For a discrete DQN policy, this impacts the $\\epsilon$-greedy action selection scheme, but other algorithms might use another mechanism (softmax, Gaussian policies, Ornstein-Uhlenbeck, noisy parameters, etc).\n", "\n", "`ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])` means that we create 10 copies of the Cartpole environment which will be acted upon in parallel using the policy.\n", "\n", "Let's collect some data with the collector. We can either collect a fixed number of steps (over the parallel workers) or episodes. Let's start with 10 steps:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'n/ep': 0,\n", " 'n/st': 10,\n", " 'rews': array([], dtype=float64),\n", " 'lens': array([], dtype=int64),\n", " 'idxs': array([], dtype=int64),\n", " 'rew': 0,\n", " 'len': 0,\n", " 'rew_std': 0,\n", " 'len_std': 0}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "collector.collect(n_step=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's weird, we apparently did not receive any reward. Let's try to collect more steps." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'n/ep': 98,\n", " 'n/st': 1000,\n", " 'rews': array([ 8., 9., 9., 10., 10., 10., 11., 12., 12., 13., 8., 9., 10.,\n", " 10., 9., 11., 10., 9., 10., 17., 10., 10., 10., 11., 10., 10.,\n", " 10., 10., 11., 8., 10., 10., 10., 10., 8., 11., 9., 10., 14.,\n", " 9., 10., 9., 10., 10., 13., 9., 11., 11., 10., 10., 9., 9.,\n", " 13., 9., 9., 10., 9., 10., 8., 10., 10., 8., 9., 10., 10.,\n", " 11., 9., 12., 10., 10., 13., 11., 9., 10., 11., 10., 12., 10.,\n", " 8., 9., 9., 9., 10., 11., 10., 9., 10., 9., 9., 9., 12.,\n", " 10., 9., 9., 10., 10., 12., 9.]),\n", " 'lens': array([ 8, 9, 9, 10, 10, 10, 11, 12, 12, 13, 8, 9, 10, 10, 9, 11, 10,\n", " 9, 10, 17, 10, 10, 10, 11, 10, 10, 10, 10, 11, 8, 10, 10, 10, 10,\n", " 8, 11, 9, 10, 14, 9, 10, 9, 10, 10, 13, 9, 11, 11, 10, 10, 9,\n", " 9, 13, 9, 9, 10, 9, 10, 8, 10, 10, 8, 9, 10, 10, 11, 9, 12,\n", " 10, 10, 13, 11, 9, 10, 11, 10, 12, 10, 8, 9, 9, 9, 10, 11, 10,\n", " 9, 10, 9, 9, 9, 12, 10, 9, 9, 10, 10, 12, 9]),\n", " 'idxs': array([7, 3, 6, 1, 2, 5, 8, 0, 9, 4, 1, 6, 7, 3, 8, 5, 0, 4, 9, 2, 1, 6,\n", " 7, 3, 8, 5, 4, 9, 0, 6, 2, 1, 7, 3, 0, 8, 9, 4, 5, 2, 6, 7, 3, 0,\n", " 1, 4, 8, 9, 5, 2, 7, 3, 6, 1, 4, 0, 9, 8, 7, 5, 2, 3, 0, 1, 4, 6,\n", " 8, 9, 5, 2, 7, 3, 6, 0, 1, 8, 4, 9, 2, 5, 7, 3, 6, 0, 1, 4, 8, 9,\n", " 2, 7, 5, 3, 1, 8, 0, 4, 6, 9]),\n", " 'rew': 10.051020408163266,\n", " 'len': 10.051020408163266,\n", " 'rew_std': 1.395125983482348,\n", " 'len_std': 1.395125983482348}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "collector.collect(n_step=1000)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alright, returns are only reported at the end of an episode. With 1000 steps (in parallel over 10 workers, i.e. each of them did 100 steps), we collected around 100 episodes of length 9 or 10, i.e. 
the cartpole falls right away, as expected with an untrained network.\n", "\n", "Can we collect complete episodes? Yes:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'n/ep': 10,\n", " 'n/st': 71,\n", " 'rews': array([10., 9., 11., 10., 8., 9., 10., 9., 10., 10.]),\n", " 'lens': array([10, 9, 11, 10, 8, 9, 10, 9, 10, 10]),\n", " 'idxs': array([2, 5, 7, 3, 9, 0, 1, 4, 8, 6]),\n", " 'rew': 9.6,\n", " 'len': 9.6,\n", " 'rew_std': 0.8,\n", " 'len_std': 0.8}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "collector.collect(n_episode=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "But where is the data, i.e. the collected transitions? Nowhere, because we forgot to create a buffer to store them. Let's fix that mistake." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "collector = ts.data.Collector(\n", " policy=policy, \n", " env=ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]),\n", " buffer=ts.data.VectorReplayBuffer(1000, 10),\n", " exploration_noise=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We preallocate a replay memory of 1000 transitions in total, split into 10 sub-buffers of 100 transitions, one per worker. One could imagine a single shared replay buffer instead, but tianshou requires a vectorized buffer when collecting from several environments. 
Let's collect some episodes and look at the data stored in the first buffer:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ReplayBuffer(\n", " obs: array([[ 3.4645274e-02, -3.8563393e-02, 2.8588787e-02, 4.0672716e-02],\n", " [ 3.3874005e-02, 1.5613718e-01, 2.9402241e-02, -2.4285485e-01],\n", " [ 3.6996752e-02, 3.5082710e-01, 2.4545144e-02, -5.2612048e-01],\n", " [ 4.4013292e-02, 5.4559523e-01, 1.4022734e-02, -8.1096911e-01],\n", " [ 5.4925196e-02, 7.4052227e-01, -2.1966479e-03, -1.0992084e+00],\n", " [ 6.9735639e-02, 9.3567306e-01, -2.4180816e-02, -1.3925797e+00],\n", " [ 8.8449106e-02, 1.1310875e+00, -5.2032411e-02, -1.6927242e+00],\n", " [ 1.1107086e-01, 1.3267703e+00, -8.5886896e-02, -2.0011418e+00],\n", " [ 1.3760626e-01, 1.5226773e+00, -1.2590973e-01, -2.3191388e+00],\n", " [ 1.6805981e-01, 1.7187009e+00, -1.7229250e-01, -2.6477661e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 
0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 
0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],\n", " dtype=float32),\n", " rew: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),\n", " truncated: array([False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, 
False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False]),\n", " act: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n", " done: array([False, False, False, False, False, False, False, False, False,\n", " True, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False]),\n", " policy: Batch(),\n", " info: Batch(\n", " env_id: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n", " ),\n", " terminated: array([False, 
False, False, False, False, False, False, False, False,\n", " True, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False]),\n", " obs_next: array([[ 3.3874005e-02, 1.5613718e-01, 2.9402241e-02, -2.4285485e-01],\n", " [ 3.6996752e-02, 3.5082710e-01, 2.4545144e-02, -5.2612048e-01],\n", " [ 4.4013292e-02, 5.4559523e-01, 1.4022734e-02, -8.1096911e-01],\n", " [ 5.4925196e-02, 7.4052227e-01, -2.1966479e-03, -1.0992084e+00],\n", " [ 6.9735639e-02, 9.3567306e-01, -2.4180816e-02, -1.3925797e+00],\n", " [ 8.8449106e-02, 1.1310875e+00, -5.2032411e-02, -1.6927242e+00],\n", " [ 1.1107086e-01, 1.3267703e+00, -8.5886896e-02, -2.0011418e+00],\n", " [ 1.3760626e-01, 1.5226773e+00, -1.2590973e-01, -2.3191388e+00],\n", " [ 1.6805981e-01, 1.7187009e+00, -1.7229250e-01, -2.6477661e+00],\n", " [ 2.0243382e-01, 1.9146510e+00, -2.2524783e-01, -2.9877436e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 
0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " 
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],\n", " dtype=float32),\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "collector.collect(n_episode=10)\n", "collector.buffer.buffers[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that each element of 
the transitions (s, a, r, s', done, terminated, truncated) is saved in a preallocated array of 100 entries. The first replay buffer has only saved one short episode, so most of the data is zero.\n", "\n", "Sampling a minibatch of transitions is easy:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Batch(\n", " obs: array([[ 0.06973564, 0.93567306, -0.02418082, -1.3925797 ],\n", " [ 0.00501526, 0.438202 , -0.00610071, -0.5498225 ],\n", " [ 0.00501526, 0.438202 , -0.00610071, -0.5498225 ],\n", " [ 0.08791052, 1.4160656 , -0.12085497, -2.0735345 ],\n", " [-0.00449122, 0.5818429 , 0.01416115, -0.8354608 ],\n", " [ 0.1239062 , 1.754971 , -0.17513801, -2.6740465 ],\n", " [-0.04072677, 0.42806664, -0.00650083, -0.5649815 ],\n", " [-0.01800847, 0.19060245, -0.03275078, -0.3052972 ],\n", " [ 0.04785302, 0.36885726, -0.01701378, -0.5969896 ],\n", " [-0.01481533, -0.01626822, -0.01261962, -0.03139872]],\n", " dtype=float32),\n", " act: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),\n", " rew: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n", " terminated: array([False, False, False, False, False, True, False, False, False,\n", " False]),\n", " truncated: array([False, False, False, False, False, False, False, False, False,\n", " False]),\n", " done: array([False, False, False, False, False, True, False, False, False,\n", " False]),\n", " obs_next: array([[ 8.84491056e-02, 1.13108754e+00, -5.20324111e-02,\n", " -1.69272423e+00],\n", " [ 1.37792965e-02, 6.33409083e-01, -1.70971639e-02,\n", " -8.44421327e-01],\n", " [ 1.37792965e-02, 6.33409083e-01, -1.70971639e-02,\n", " -8.44421327e-01],\n", " [ 1.16231829e-01, 1.61218965e+00, -1.62325650e-01,\n", " -2.40101957e+00],\n", " [ 7.14564137e-03, 7.76768565e-01, -2.54806434e-03,\n", " -1.12365675e+00],\n", " [ 1.59005612e-01, 1.95090282e+00, -2.28618950e-01,\n", " -3.01467609e+00],\n", " [-3.21654342e-02, 6.23279154e-01, -1.78004559e-02,\n", " -8.59705389e-01],\n", " [-1.41964173e-02, 
3.86175454e-01, -3.88567224e-02,\n", " -6.08126342e-01],\n", " [ 5.52301593e-02, 5.64213097e-01, -2.89535765e-02,\n", " -8.94982755e-01],\n", " [-1.51406955e-02, 1.79032415e-01, -1.32475952e-02,\n", " -3.28036398e-01]], dtype=float32),\n", " info: Batch(\n", " env_id: array([0, 1, 1, 1, 2, 2, 4, 6, 8, 9]),\n", " ),\n", " policy: Batch(),\n", " ),\n", " array([ 5, 104, 104, 109, 203, 209, 402, 601, 802, 900]))" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "collector.buffer.sample(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The minibatch mixes transitions drawn at random from the sub-buffers of the different workers (see `env_id`), so we do not need to handle the workers individually.\n", "\n", "We can reset the buffers with the following command:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "collector.reset_buffer(keep_statistics=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we print the first sub-buffer, the old data still seems to be there:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ReplayBuffer(\n", " obs: array([[ 3.4645274e-02, -3.8563393e-02, 2.8588787e-02, 4.0672716e-02],\n", " [ 3.3874005e-02, 1.5613718e-01, 2.9402241e-02, -2.4285485e-01],\n", " [ 3.6996752e-02, 3.5082710e-01, 2.4545144e-02, -5.2612048e-01],\n", " [ 4.4013292e-02, 5.4559523e-01, 1.4022734e-02, -8.1096911e-01],\n", " [ 5.4925196e-02, 7.4052227e-01, -2.1966479e-03, -1.0992084e+00],\n", " [ 6.9735639e-02, 9.3567306e-01, -2.4180816e-02, -1.3925797e+00],\n", " [ 8.8449106e-02, 1.1310875e+00, -5.2032411e-02, -1.6927242e+00],\n", " [ 1.1107086e-01, 1.3267703e+00, -8.5886896e-02, -2.0011418e+00],\n", " [ 1.3760626e-01, 1.5226773e+00, -1.2590973e-01, -2.3191388e+00],\n", " [ 1.6805981e-01, 1.7187009e+00, -1.7229250e-01, -2.6477661e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " 
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", 
" [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 
0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],\n", " dtype=float32),\n", " rew: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),\n", " truncated: array([False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False]),\n", " act: array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n", " done: array([False, False, False, False, False, False, False, False, False,\n", " True, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, 
False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False]),\n", " policy: Batch(),\n", " info: Batch(\n", " env_id: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n", " ),\n", " terminated: array([False, False, False, False, False, False, False, False, False,\n", " True, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False,\n", " False]),\n", " obs_next: array([[ 3.3874005e-02, 1.5613718e-01, 2.9402241e-02, -2.4285485e-01],\n", " [ 3.6996752e-02, 3.5082710e-01, 2.4545144e-02, -5.2612048e-01],\n", " [ 4.4013292e-02, 5.4559523e-01, 1.4022734e-02, -8.1096911e-01],\n", " [ 5.4925196e-02, 7.4052227e-01, -2.1966479e-03, -1.0992084e+00],\n", " [ 6.9735639e-02, 
9.3567306e-01, -2.4180816e-02, -1.3925797e+00],\n", " [ 8.8449106e-02, 1.1310875e+00, -5.2032411e-02, -1.6927242e+00],\n", " [ 1.1107086e-01, 1.3267703e+00, -8.5886896e-02, -2.0011418e+00],\n", " [ 1.3760626e-01, 1.5226773e+00, -1.2590973e-01, -2.3191388e+00],\n", " [ 1.6805981e-01, 1.7187009e+00, -1.7229250e-01, -2.6477661e+00],\n", " [ 2.0243382e-01, 1.9146510e+00, -2.2524783e-01, -2.9877436e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 
0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " 
[ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", 
" [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],\n", " [ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00]],\n", " dtype=float32),\n", ")" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "collector.buffer.buffers[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But sampling returns an error:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "ename": "ValueError", "evalue": "probabilities contain NaN", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[16], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mcollector\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbuffer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/.virtualenvs/tianshou/lib/python3.11/site-packages/tianshou/data/buffer/base.py:313\u001b[0m, in \u001b[0;36mReplayBuffer.sample\u001b[0;34m(self, batch_size)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msample\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch_size: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[Batch, np\u001b[38;5;241m.\u001b[39mndarray]:\n\u001b[1;32m 
307\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Get a random sample from buffer with size = batch_size.\u001b[39;00m\n\u001b[1;32m 308\u001b[0m \n\u001b[1;32m 309\u001b[0m \u001b[38;5;124;03m Return all the data in the buffer if batch_size is 0.\u001b[39;00m\n\u001b[1;32m 310\u001b[0m \n\u001b[1;32m 311\u001b[0m \u001b[38;5;124;03m :return: Sample data and its corresponding index inside the buffer.\u001b[39;00m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 313\u001b[0m indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample_indices\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m[indices], indices\n", "File \u001b[0;32m~/.virtualenvs/tianshou/lib/python3.11/site-packages/tianshou/data/buffer/manager.py:180\u001b[0m, in \u001b[0;36mReplayBufferManager.sample_indices\u001b[0;34m(self, batch_size)\u001b[0m\n\u001b[1;32m 178\u001b[0m sample_num \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuffer_num, \u001b[38;5;28mint\u001b[39m)\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 180\u001b[0m buffer_idx \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mchoice\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 181\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbuffer_num\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_lengths\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_lengths\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msum\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 182\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 183\u001b[0m sample_num \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mbincount(buffer_idx, minlength\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuffer_num)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;66;03m# avoid batch_size > 0 and sample_num == 0 -> get child's all data\u001b[39;00m\n", "File \u001b[0;32mmtrand.pyx:954\u001b[0m, in \u001b[0;36mnumpy.random.mtrand.RandomState.choice\u001b[0;34m()\u001b[0m\n", "\u001b[0;31mValueError\u001b[0m: probabilities contain NaN" ] } ], "source": [ "collector.buffer.sample(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interaction loop\n", "\n", "We do not even need to actually sample the buffer, because `policy.update()` takes the buffer and a batch size as input. The following code implements DQN on Cartpole, with an okayish choice of hyperparameters. The main interaction loop consists of:\n", "\n", "1. `collector.collect()`: Collect 100 transitions using the 10 workers and store them in the replay buffer.\n", "2. `policy.update()`: Sample `repeat=10` minibatches of 64 transitions from the buffer and learn from them.\n", "3. `test_collector.collect()`: Test the performance by running 10 episodes without exploration.\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Model\n", "net = ts.utils.net.common.Net(\n", " env.observation_space.shape,\n", " env.action_space.n,\n", " hidden_sizes=[64, 64],\n", " device=device,\n", ").to(device)\n", "\n", "optim = torch.optim.Adam(net.parameters(), lr=0.001)\n", "\n", "# Policy\n", "policy = ts.policy.DQNPolicy(\n", " model=net, # value network\n", " optim=optim, # optimizer\n", " discount_factor=0.99, # gamma\n", " estimation_step=1, # n-step returns\n", " is_double=False, # double Q-learning\n", " target_update_freq=120, # how often to update the target network\n", " action_space=env.action_space, # action space\n", ")\n", "policy.set_eps(0.1) # epsilon-greedy action selection\n", "\n", "# Collector for training\n", "collector = ts.data.Collector(\n", " policy=policy, \n", " env=ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]),\n", " buffer=ts.data.VectorReplayBuffer(20000, 10),\n", " exploration_noise=True\n", ")\n", "# Collector for testing (without exploration). 
No need for a buffer\n", "test_collector = ts.data.Collector(\n", " policy=policy, \n", " env=ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]),\n", " exploration_noise=False\n", ")\n", "# Pre-fill the training buffer with random transitions\n", "collector.collect(n_step=1000, random=True)\n", "\n", "# Interaction\n", "returns = []\n", "for iteration in range(1000):\n", "\n", " # Training mode\n", " policy.train()\n", " \n", " # Collect transitions\n", " result = collector.collect(n_step=100)\n", "\n", " # Train DQN network on minibatch\n", " policy.update(\n", " buffer=collector.buffer,\n", " sample_size=0, # use the whole buffer\n", " batch_size=64,\n", " repeat=10,\n", " )\n", "\n", " # Test 10 episodes\n", " policy.eval()\n", " result = test_collector.collect(n_episode=10)\n", " mean_reward = result['rew']\n", " #print(iteration, \":\", mean_reward)\n", " returns.append(mean_reward)\n", "\n", "plt.figure()\n", "plt.plot(np.array(returns))\n", "plt.xlabel(\"Epochs\")\n", "plt.ylabel(\"Returns\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Return: 200.0\n", "MoviePy - Building file videos/cartpole-dqn.gif with imageio.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluation mode\n", "policy.eval() \n", "\n", "# Create a recordable environment\n", "env = gym.make('CartPole-v0', render_mode=\"rgb_array_list\")\n", "recorder = GymRecorder(env)\n", "\n", "# Sample the initial state\n", "state, info = env.reset()\n", "\n", "# One episode:\n", "done = False\n", "return_episode = 0\n", "while not done:\n", " # Select an action from the learned policy\n", " action = policy.forward(ts.data.Batch(obs=[state], info=None)).act[0]\n", " # Sample a single transition\n", " next_state, reward, terminal, truncated, info = env.step(action)\n", " # End of the episode\n", " done = terminal or truncated\n", " # Update undiscounted return\n", " return_episode += reward\n", " # Go in the next state\n", " state = next_state\n", "print(\"Return:\", return_episode)\n", "\n", "recorder.record(env.render())\n", "video = \"videos/cartpole-dqn.gif\"\n", "recorder.make_video(video)\n", "ipython_display(video)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q:** Understand and run the code. Experiment with the hyperparameters and compare them to the previous exercise. In particular, what is `estimation_step=3` in the constructor of PPO? What is its influence?\n", "\n", "**Q:** Implement scheduling of the exploration parameter with the right hyperparameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**A:** The main difference is that much more data is collected by the 10 workers before the DQN is updated (10 times in a row instead of 1). Distributed learning can also be used in DQN-like algorithms ans speed up learning. Collecting 1000 transitions instead of 100 at each iteration is even better.\n", "\n", "`estimation_step` allows to use n-step returns instead of the vanilla 1-step return r + gamma * V(s'). 
Choosing a value of `n=3` (and double Q-learning) stabilizes learning a lot.\n", "\n", "There are many ways to implement exploration scheduling. Here is one that works OK. " ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Model\n", "net = ts.utils.net.common.Net(\n", " env.observation_space.shape,\n", " env.action_space.n,\n", " hidden_sizes=[128, 128],\n", " device=device,\n", ").to(device)\n", "\n", "optim = torch.optim.Adam(net.parameters(), lr=0.001)\n", "\n", "# Policy\n", "policy = ts.policy.DQNPolicy(\n", " model=net, # value network\n", " optim=optim, # optimizer\n", " discount_factor=0.99, # gamma\n", " estimation_step=3, # n-step returns\n", " is_double=True, # double Q-learning\n", " target_update_freq=120, # how often to update the target network\n", " action_space=env.action_space, # action space\n", ")\n", "policy.set_eps(0.1) # epsilon-greedy action selection\n", "\n", "# Collector for training\n", "collector = ts.data.Collector(\n", " policy=policy, \n", " env=ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]),\n", " buffer=ts.data.VectorReplayBuffer(20000, 10),\n", " exploration_noise=True\n", ")\n", "# Collector for testing (without exploration). 
No need for a buffer\n", "test_collector = ts.data.Collector(\n", " policy=policy, \n", " env=ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]),\n", " exploration_noise=False\n", ")\n", "# Pre-fill the training buffer with random transitions\n", "collector.collect(n_step=1000, random=True)\n", "\n", "# Interaction\n", "returns = []\n", "for iteration in range(1000):\n", "\n", " # Training mode\n", " policy.train()\n", "\n", " # Exploration schedule\n", " if iteration <= 100:\n", " eps = 0.5\n", " elif iteration <= 900:\n", " eps = 0.5 - (iteration - 100) / (900 - 100) * (0.5 - 0.05)\n", " else:\n", " eps = 0.05\n", " policy.set_eps(eps)\n", " \n", " # Collect transitions\n", " result = collector.collect(n_step=1000)\n", "\n", " # Train DQN network on minibatch\n", " policy.update(\n", " buffer=collector.buffer,\n", " sample_size=0, # use the whole buffer\n", " batch_size=64,\n", " repeat=10,\n", " )\n", "\n", " # Test 10 episodes\n", " policy.eval()\n", " result = test_collector.collect(n_episode=10)\n", " mean_reward = result['rew']\n", " #print(iteration, \":\", mean_reward, eps)\n", " returns.append(mean_reward)\n", "\n", "plt.figure()\n", "plt.plot(np.array(returns))\n", "plt.xlabel(\"Epochs\")\n", "plt.ylabel(\"Returns\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PPO\n", "\n", "Now that DQN works on Cartpole, let's use PPO and compare its performance to DQN.\n", "\n", "You will need to use the PPO policy, obviously:\n", "\n", "```python\n", "policy = ts.policy.PPOPolicy(\n", " actor=actor, \n", " critic=critic, \n", " optim=optim,\n", " dist_fn=torch.distributions.Categorical, \n", " action_space=env.action_space,\n", " discount_factor=0.99,\n", " max_grad_norm=0.5,\n", " eps_clip=0.2,\n", " gae_lambda=0.95,\n", " deterministic_eval=True,\n", " action_scaling=False,\n", ")\n", "```\n", "\n", "It has many more hyperparameters, which can be left at their default value (or not, depending on the time
you have). Check the documentation for their meaning. The important thing is that you now need an actor and a critic, not a single network.\n", "\n", "One way to do it is to use the actor/critic specifications provided by tianshou:\n", "\n", "```python\n", "features = ts.utils.net.common.Net(\n", " state_shape=env.observation_space.shape, \n", " hidden_sizes=[64, 64], \n", " device=device)\n", "\n", "actor = ts.utils.net.discrete.Actor(\n", " preprocess_net=features, \n", " action_shape=env.action_space.n, \n", " device=device).to(device)\n", "\n", "critic = ts.utils.net.discrete.Critic(\n", " preprocess_net=features, \n", " device=device).to(device)\n", "\n", "actor_critic = ts.utils.net.common.ActorCritic(actor=actor, critic=critic)\n", "\n", "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.001)\n", "```\n", "\n", "``features`` is the shared feature extractor between the actor and the critic. ``actor`` is the policy head (one neuron per discrete action), ``critic`` is a single output neuron for the value V(s). ``actor_critic`` is the combined two-headed network. \n", "\n", "When defining the PPO policy, `dist_fn=torch.distributions.Categorical` specifies the distribution from which actions are sampled during exploration, here a categorical distribution over the actor's softmax output for the two actions left and right. \n", "\n", "**Q:** Implement PPO on Cartpole. You will need to find the right hyperparameters for the task. \n", "\n", "Remember that learning is **on-policy**, so the transition buffer must be emptied after training the network. Your interaction loop must therefore look like this:\n", "\n", "```python\n", "for iteration in range(N):\n", "\n", " # Collect enough on-policy steps\n", " result = collector.collect(...)\n", "\n", " # Update the PPO network by learning from the on-policy buffer\n", " policy.update(...)\n", "\n", " # Empty the buffer as we are on-policy\n", " collector.reset_buffer(keep_statistics=False)\n", "```\n", "\n", "Do not hesitate to collect many on-policy steps before training the network."
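, "\n", "\n", "As a reminder of what `discount_factor`, `gae_lambda` and `eps_clip` control, here is a sketch of the standard PPO formulation. The notation ($\\delta_t$, $\\hat{A}_t$, $\\rho_t$) is an assumption of this note, and tianshou's implementation may differ in details such as advantage normalization or the additional value and entropy terms. Advantages are estimated with GAE, where $\\gamma$ is `discount_factor` and $\\lambda$ is `gae_lambda`:\n", "\n", "$$\\delta_t = r_t + \\gamma \\, V(s_{t+1}) - V(s_t), \\qquad \\hat{A}_t = \\sum_{l \\geq 0} (\\gamma \\lambda)^l \\, \\delta_{t+l}$$\n", "\n", "The actor then maximizes the clipped surrogate objective, where $\\epsilon$ is `eps_clip` and $\\rho_t(\\theta)$ is the probability ratio between the updated policy and the policy that collected the data:\n", "\n", "$$\\rho_t(\\theta) = \\frac{\\pi_\\theta(a_t | s_t)}{\\pi_{\\theta_{\\text{old}}}(a_t | s_t)}, \\qquad L^{\\text{CLIP}}(\\theta) = \\mathbb{E}_t \\left[ \\min\\left( \\rho_t(\\theta) \\, \\hat{A}_t, \\; \\text{clip}(\\rho_t(\\theta), 1 - \\epsilon, 1 + \\epsilon) \\, \\hat{A}_t \\right) \\right]$$\n", "\n", "With $\\epsilon = 0.2$, a transition stops contributing to the gradient once its ratio leaves the interval $[0.8, 1.2]$ in the direction favored by its advantage, which is what makes several gradient steps on the same on-policy batch reasonably safe."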
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGwCAYAAACD0J42AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAzfUlEQVR4nO3df3xU1Z3/8fdMJpmEH0lINImRBPAnUClSUIiyCpICAfEH2F1cSqPykNVNVEgXFX+tbpfGurb+oCi2q7BWWb9rK1TTlhZBg6whQpBWKUXRKKyQoMYkJJDJJHO/f8BcMhBIZubemQl5PR+PPB7M3JvJmdN9eN577vmc4zAMwxAAAEAMcUa7AQAAAMcjoAAAgJhDQAEAADGHgAIAAGIOAQUAAMQcAgoAAIg5BBQAABBzXNFuQCh8Pp/27dun/v37y+FwRLs5AACgGwzD0MGDB5WdnS2n89RzJD0yoOzbt085OTnRbgYAAAjB3r17NXDgwFPe0yMDSv/+/SUd+YLJyclRbg0AAOiOxsZG5eTkmOP4qfTIgOJ/rJOcnExAAQCgh+nO8gwWyQIAgJhDQAEAADGHgAIAAGIOAQUAAMQcAgoAAIg5BBQAABBzCCgAACDmEFAAAEDMIaAAAICYQ0ABAAAxJ6iAUlpaqksuuUT9+/dXRkaGrrvuOu3atSvgnpaWFhUVFSk9PV39+vXTrFmzVFtbG3DPnj17NH36dPXp00cZGRlatGiR2trawv82AADgtBBUQCkvL1dRUZE2b96sdevWyev1avLkyWpubjbvWbhwod544w29+uqrKi8v1759+zRz5kzzent7u6ZPn67W1la9++67+q//+i+tXLlSDz30kHXfCgAA9GgOwzCMUH/5yy+/VEZGhsrLy3XFFVeooaFBZ555platWqUbbrhBkvS3v/1Nw4YNU0VFhcaNG6c//OEPuvrqq7Vv3z5lZmZKkpYvX6577rlHX375pRISErr8u42NjUpJSVFDQ0OPPiyw3WfI2+5TYnxct3/HMAzVH/IqJSleTufJD1tq9rTpm0OtVjQTANALuV1xOrO/29LPDGb8Dus044aGBklSWlqaJKmqqkper1f5+fnmPUOHDlVubq4ZUCoqKjRixAgznEjSlClTdPvtt2vHjh0aNWrUCX/H4/HI4/EEfMHTwT/+crOqv2rW24smqE9C5/9T/OX/6vXq1v/T3m8OaW/dIf3fN4flafNpcHofLfzuBbr629mK6xBUvm7yaHn5J3qx4nN52nyR+ioAgNPMFRecqRdvuTRqfz/kgOLz+bRgwQJdfvnluuiiiyRJNTU1SkhIUGpqasC9mZmZqqmpMe/pGE781/3XOlNaWqpHHnkk1KbGJMMwtOWzOvkM6bOvDml4dudJ8t7ffKC/7j8xkH329SHd9cp2LXtrt0q+e4HGnZOu/3ynWi/8b7UOtbZLkhJcTnV9oDUAACdKiIvuCBJyQCkqKtKHH36oTZs2WdmeTi1evFglJSXm68bGRuXk5Nj+d+3U3Nou39GHa/WneBRT09giSVo05UJdnJOqnAF9lJIUr5cqP9dz5Z/oo9om3fbSNrmcDrUd/cCLzk7WDydfqAkXnCmHg4gCAOh5QgooxcXFKisr08aNGzVw4EDz/aysLLW2tqq+vj5gFqW2tlZZWVnmPe+9917A5/mrfPz3HM/tdsvttvY5WLQ1Hvaa/647SUBp9xlmePn7MTkBzwKLJp6n748bpP9851O9sKlaza3tuiCzn0q+e6GmfCuTYAIA6NGCquIxDEPFxcVavXq1NmzYoCFDhgRcHz16tOLj47V+/XrzvV27dmnPnj3Ky8uTJOXl5emDDz7QgQMHzHvWrVun5ORkDR8+PJzv0qMcbDlWVv1Nc+cBpfGw15
xlSe0Tf8L1lKR4/XDyhXrnnqv0yvxx+sNdV2jqRVmEEwBAjxfUDEpRUZFWrVql3/72t+rfv7+5ZiQlJUVJSUlKSUnRvHnzVFJSorS0NCUnJ+uOO+5QXl6exo0bJ0maPHmyhg8frrlz5+qxxx5TTU2NHnjgARUVFZ12sySn0tjSYQal2dvpPf6ZleREl+LjTp4l0/omaNw56dY2EACAKAoqoDz77LOSpAkTJgS8v2LFCt10002SpCeeeEJOp1OzZs2Sx+PRlClT9Mwzz5j3xsXFqaysTLfffrvy8vLUt29fFRYW6t/+7d/C+yY9TMdHPCcrB/bPrKT17br0GgCA00lQAaU7W6YkJiZq2bJlWrZs2UnvGTRokH7/+98H86dPO4EzKJ0HFP/7AwgoAIBehrN4oqTxcIc1KCebQTn6flofAgoAoHchoERJdx7x+NemMIMCAOhtCChR0vERzzcnWSRrzqAQUAAAvQwBJUo6PuLpcg0Kj3gAAL0MASVKOs6gHPa26/DR7ek7OlbFc+IeKAAAnM4IKFHSMaBIna9D8e+DwgwKAKC3IaBEScdHPFLnAYV9UAAAvRUBJUpOmEHpZKEs+6AAAHorAkqU+M/iST8aPo4/MNDb7lPj0XvYBwUA0NsQUKLAMAxzH5Tc9D6STjwwsP7QketOh5ScxCJZAEDvQkCJgsPedrUdPaZ4cHpfSSeWGvvXpKT2SVCck9OJAQC9CwElCvwLZF1Oh7JTEyVJ9cc94jm2BwqzJwCA3oeAEgX+BbLJSfFK6+uWJNUdOn7RLBU8AIDei4ASBf71J8mJLnMTtuPXoLAHCgCgNyOgREHHGRR/ADlhDQozKACAXoyAEgX+NSjJiccCyvEbtXGSMQCgNyOgRIF/BqV/osucITlZFQ97oAAAeiMCShQcW4MSb86QeNp8AQcGsossAKA3I6BEgX+H2OQkl/omxCkh7sj/DB13kzVnUDjJGADQCxFQoqDjDIrD4dCATip5ju2DwgwKAKD3IaBEwUFzBuVIMOmskocqHgBAb0ZAiYJjZcYuSTqhkqfF267mo+tRWIMCAOiNCChR0PERj6QTKnn8BwW6nA71d7ui0EIAAKKLgBIFjcc/4vGvQTkaTL5u9hx9P0EOBwcFAgB6HwJKFJwwg+J/xHN0BuWbo5u0sQcKAKC3IqBEmGEYJ65B8T/iOboGxTyHhxJjAEAvRUCJsBavT952Q5LU/7g1KMdmUKjgAQD0bgSUCPPPnjgdUt+EOElS6nFlxuyBAgDo7QgoEWauP0mKNxfAph1XZnxsF1kCCgCgdyKgRJi5/iTx2PqSjlU8hmEwgwIA6PUIKBHWePjYOTx+/pmS1jafDrW2M4MCAOj1CCgR1tkMSlJ8nNyuowcGNreq7miZMbvIAgB6KwJKhJmbtHUIKA6H41glz6HWY1U8POIBAPRSBJQIO7ZINnALe38lz9fNreyDAgDo9QgoEdbZIx5JSjsaRr745rBa23xH32MGBQDQOxFQIuzYItnAgOKv2Pn0y2ZJktvlVFJ8XGQbBwBAjCCgRJh/BqV/YuAjHv9sySdfNpmvOSgQANBbEVAi7PiDAv38Myj+gMIeKACA3oyAEmFmFU/S8WtQjgSSL+oPB7wGAKA3IqBE2EFzBuX4Kp4jgcU4co4ge6AAAHo1AkqEmVU8J5lBMV/3ocQYANB7EVAiyDCMLqt4zNfMoAAAejECigUMw9DeukMy/M9nTsLT5lNr+5E9To5/xHPCDAoBBQDQixFQLPDM25/o7x57S6//ed8p7/NX8DgdUt+EwIBywgwKVTwAgF6MgGKBTw4cKQ1+628HTnmfv4Knf2K8nM7APU6SEuICNmZjBgUA0JsRUCzgOfrY5v299ae879gCWVen1zuGEmZQAAC9GQHFAh7vkYDy+deHVHf0JOLO+B/x9Hd3XqGT2qFyhxkUAEBvRkCxgH/hqyT9+RSzKMc2aet6BiWVMmMAQC8WdEDZuHGjZsyYoezsbD
kcDq1ZsybgelNTk4qLizVw4EAlJSVp+PDhWr58ecA9LS0tKioqUnp6uvr166dZs2aptrY2rC8STa1t7ea/T/WY52Tb3Pv5H+v0TYhTIgcFAgB6saADSnNzs0aOHKlly5Z1er2kpERr167VSy+9pJ07d2rBggUqLi7W66+/bt6zcOFCvfHGG3r11VdVXl6uffv2aebMmaF/iyjztB2bQdl+yhmUzjdp8/PPoKT14/EOAKB36/xZwykUFBSooKDgpNffffddFRYWasKECZKk+fPn67nnntN7772na665Rg0NDXr++ee1atUqXXXVVZKkFStWaNiwYdq8ebPGjRsX2jeJota2wEc8hmF0ehKxuUlbFzMoaSyQBQD0cpavQbnsssv0+uuv64svvpBhGHrrrbf00UcfafLkyZKkqqoqeb1e5efnm78zdOhQ5ebmqqKiotPP9Hg8amxsDPiJJR0DSsNhr6q/au70vi6reI7OnLCLLACgt7M8oCxdulTDhw/XwIEDlZCQoKlTp2rZsmW64oorJEk1NTVKSEhQampqwO9lZmaqpqam088sLS1VSkqK+ZOTk2N1s8Pif8ST4DrSnSd7zNPVGpTJwzN11dAMFV422PI2AgDQk9gSUDZv3qzXX39dVVVV+ulPf6qioiK9+eabIX/m4sWL1dDQYP7s3bvXwhaHzz+DcnFOqqRTBJSWzs/h8ctMTtQLN12iiRdmWN5GAAB6kqDXoJzK4cOHdd9992n16tWaPn26JOnb3/62tm/frscff1z5+fnKyspSa2ur6uvrA2ZRamtrlZWV1ennut1uud1uK5tqKX+Z8bghaXqvuq4bMyiWdjsAAKcdS2dQvF6vvF6vnM7Aj42Li5PPd2QQHz16tOLj47V+/Xrz+q5du7Rnzx7l5eVZ2ZyI8c+gXDokXZK0c3+jWrztJ9zXVRUPAAA4Iuj/V76pqUm7d+82X1dXV2v79u1KS0tTbm6urrzySi1atEhJSUkaNGiQysvL9eKLL+pnP/uZJCklJUXz5s1TSUmJ0tLSlJycrDvuuEN5eXk9soJHkjxH90E558y+OqNfgr5qatWOfY0aPWhAwH0HzbN4mEEBAOBUgh4pt27dqokTJ5qvS0pKJEmFhYVauXKlXnnlFS1evFhz5sxRXV2dBg0apCVLlui2224zf+eJJ56Q0+nUrFmz5PF4NGXKFD3zzDMWfJ3I8/kMedsNSZLb5dTFOal6c+cBbd9bf0JA6WqRLAAAOCLogDJhwgQZhnHS61lZWVqxYsUpPyMxMVHLli076WZvPUnHbe4TjgsoHbV4281qHx7xAABwapzFE6aOu8i6XXG6OOfIrMn2vd8E3Od/vONwSP3dPOIBAOBUCChh6rhJW3ycQ9/OSZHDIe2tO6yvmzzmNf8C2X5ul5zOE3eZBQAAxxBQwuR/xJPgcsrhcCg5MV7nntlPUuB+KKw/AQCg+wgoYfIcLSd2xx3rSv+Gbe/vqTff62qTNgAAcAwBJUz+GRR3/IkBpfMZFNafAADQFUbLMPnXoCR0MoNSWf21/ulXW3XZuWeoprFFEjMoAAB0BwElTMcfFChJQ7P668LM/tpVe1B/3FGrP+6oNa+xBgUAgK4RUMLkn0Fxu+LM91xxTv3uzvH64IsGvfvJ16r45Gtt+axOnjafLjo7OVpNBQCgxyCghKm1kxkU6UhIGZU7QKNyB6ho4nnytLWrtsGjgQOSotFMAAB6FAJKmPzn8BwfUI7ndsUpN71PJJoEAECPRxVPmDzmIx66EgAAqzCqhulkj3gAAEDoGFXDZO4kG0dXAgBgFUbVMHm8/o3a4rq4EwAAdBcBJUzMoAAAYD1G1TCxBgUAAOsxqobJX2ZMFQ8AANZhVA1TK2XGAABYjlE1TDziAQDAeoyqYfJ0cpoxAAAID6NqmMxHPPF0JQAAVmFUDZOHMmMAACzHqBom/0ZtCS42agMAwCoElDD5N2qjigcAAOswqoap9eg+KFTxAABgHU
bVMFFmDACA9RhVw+RhozYAACzHqBomZlAAALAeo2qYWCQLAID1GFXD5C8zdlNmDACAZQgoYfLPoPCIBwAA6zCqhqmVs3gAALAco2qYPOyDAgCA5RhVw+DzGfK2G5JYJAsAgJUYVcPgX38iMYMCAICVGFXD4N+kTSKgAABgJUbVMLR2DCgskgUAwDKMqmHoWGLscDii3BoAAE4fBJQweLxHKnjczJ4AAGApRtYwmNvcx9ONAABYiZE1DGzSBgCAPRhZw8BJxgAA2IORNQz+MmMOCgQAwFoElDAwgwIAgD0YWcPgIaAAAGALRtYw+A8K5BweAACsxcgaBh7xAABgD0bWMJg7yVJmDACApYIeWTdu3KgZM2YoOztbDodDa9asOeGenTt36pprrlFKSor69u2rSy65RHv27DGvt7S0qKioSOnp6erXr59mzZql2trasL5INHi8zKAAAGCHoEfW5uZmjRw5UsuWLev0+ieffKLx48dr6NChevvtt/WXv/xFDz74oBITE817Fi5cqDfeeEOvvvqqysvLtW/fPs2cOTP0bxEl5k6ylBkDAGApV7C/UFBQoIKCgpNev//++zVt2jQ99thj5nvnnnuu+e+GhgY9//zzWrVqla666ipJ0ooVKzRs2DBt3rxZ48aNC7ZJUcMaFAAA7GHpyOrz+fS73/1OF1xwgaZMmaKMjAyNHTs24DFQVVWVvF6v8vPzzfeGDh2q3NxcVVRUdPq5Ho9HjY2NAT+xgCoeAADsYenIeuDAATU1NenRRx/V1KlT9ac//UnXX3+9Zs6cqfLycklSTU2NEhISlJqaGvC7mZmZqqmp6fRzS0tLlZKSYv7k5ORY2eyQtZo7yRJQAACwkuUzKJJ07bXXauHChbr44ot177336uqrr9by5ctD/tzFixeroaHB/Nm7d69VTQ4Lj3gAALBH0GtQTuWMM86Qy+XS8OHDA94fNmyYNm3aJEnKyspSa2ur6uvrA2ZRamtrlZWV1ennut1uud1uK5tqCQ+nGQMAYAtLR9aEhARdcskl2rVrV8D7H330kQYNGiRJGj16tOLj47V+/Xrz+q5du7Rnzx7l5eVZ2RzbmY944gkoAABYKegZlKamJu3evdt8XV1dre3btystLU25ublatGiR/uEf/kFXXHGFJk6cqLVr1+qNN97Q22+/LUlKSUnRvHnzVFJSorS0NCUnJ+uOO+5QXl5ej6rgkSQPG7UBAGCLoAPK1q1bNXHiRPN1SUmJJKmwsFArV67U9ddfr+XLl6u0tFR33nmnLrzwQv3mN7/R+PHjzd954okn5HQ6NWvWLHk8Hk2ZMkXPPPOMBV8nso6tQWEfFAAArOQwDMOIdiOC1djYqJSUFDU0NCg5OTlq7fjBC+9p40df6qffG6lZowdGrR0AAPQEwYzfPJsIQ+vRfVCo4gEAwFqMrGGgzBgAAHswsobBw0ZtAADYgpE1DMygAABgD0bWMBw7zZhuBADASoysYfB4/fugUGYMAICVCChhMGdQ2EkWAABLMbKGoZWzeAAAsAUjaxg87IMCAIAtGFlD5PMZ8rYf2YSXRbIAAFiLkTVE/vUnEjMoAABYjZE1RP5N2iQCCgAAVmNkDVFrx4DCIlkAACzFyBoi/yOeBJdTDocjyq0BAOD0QkAJkX8Gxc3sCQAAlmN0DZG/xJhN2gAAsB6ja4jYpA0AAPswuoaIk4wBALAPo2uI/GXGbhcHBQIAYDUCSoiYQQEAwD6MriHyEFAAALANo2uIzIMCWSQLAIDlGF1DZO6DQpkxAACWY3QNkbmTLDMoAABYjtE1RB4va1AAALALo2uI/DMolBkDAGA9AkqIKDMGAMA+jK4hMs/iIaAAAGA5RtcQmVU8BBQAACzH6BoiHvEAAGAfRtcQUWYMAIB9GF1D5C8zZqM2AACsx+gaIg8zKAAA2IbRNUTH1qCwDwoAAFYjoITIQxUPAAC2YXQNUav/NGMCCgAAlmN0DRFlxgAA2IfRNUQeAg
oAALZhdA0RO8kCAGAfRtcQHTvNmC4EAMBqjK4h8m/UlhBHmTEAAFYjoITInEFhJ1kAACzH6Bois4qHnWQBALAco2uIPOyDAgCAbRhdQ+DzGfK2G5JYJAsAgB0YXUPgX38iMYMCAIAdGF1DQEABAMBejK4h8JcYSyySBQDADkGPrhs3btSMGTOUnZ0th8OhNWvWnPTe2267TQ6HQ08++WTA+3V1dZozZ46Sk5OVmpqqefPmqampKdimRI1/BiXB5ZTD4YhyawAAOP0EHVCam5s1cuRILVu27JT3rV69Wps3b1Z2dvYJ1+bMmaMdO3Zo3bp1Kisr08aNGzV//vxgmxI15jb3zJ4AAGALV7C/UFBQoIKCglPe88UXX+iOO+7QH//4R02fPj3g2s6dO7V27Vpt2bJFY8aMkSQtXbpU06ZN0+OPP95poPF4PPJ4PObrxsbGYJttKX+JMZu0AQBgD8tHWJ/Pp7lz52rRokX61re+dcL1iooKpaammuFEkvLz8+V0OlVZWdnpZ5aWliolJcX8ycnJsbrZQWGTNgAA7GX5CPuTn/xELpdLd955Z6fXa2pqlJGREfCey+VSWlqaampqOv2dxYsXq6GhwfzZu3ev1c0OihlQqOABAMAWQT/iOZWqqio99dRT2rZtm6WLR91ut9xut2WfFy4PAQUAAFtZOsK+8847OnDggHJzc+VyueRyufT555/rhz/8oQYPHixJysrK0oEDBwJ+r62tTXV1dcrKyrKyObYxF8m6OMkYAAA7WDqDMnfuXOXn5we8N2XKFM2dO1c333yzJCkvL0/19fWqqqrS6NGjJUkbNmyQz+fT2LFjrWyObZhBAQDAXkEHlKamJu3evdt8XV1dre3btystLU25ublKT08PuD8+Pl5ZWVm68MILJUnDhg3T1KlTdeutt2r58uXyer0qLi7W7NmzO63giUXmQYEskgUAwBZBj7Bbt27VqFGjNGrUKElSSUmJRo0apYceeqjbn/Hyyy9r6NChmjRpkqZNm6bx48frF7/4RbBNiRrzEQ9lxgAA2CLoGZQJEybIMIxu3//ZZ5+d8F5aWppWrVoV7J+OGeZOssygAABgC0bYEPjP4mENCgAA9mCEDYF/BoUqHgAA7EFACQEbtQEAYC9G2BAc2weF7gMAwA6MsCEwDwskoAAAYAtG2BDwiAcAAHsxwoaAMmMAAOzFCBsCf5kxG7UBAGAPRtgQeJhBAQDAVoywITi2BoV9UAAAsAMBJQScZgwAgL0YYUPQSpkxAAC2YoQNAWXGAADYixE2BDziAQDAXoywIWCrewAA7MUIG4JjpxnTfQAA2IERNgT+jdoS4igzBgDADgSUEJgzKOwkCwCALRhhQ2BW8bCTLAAAtmCEDQFlxgAA2IsRNkg+n8EiWQAAbMYIGyR/OJGYQQEAwC6MsEEioAAAYD9G2CD5S4wlFskCAGAXRtgg+WdQElxOORyOKLcGAIDTEwElSOY298yeAABgG0bZIHna2iWxSRsAAHZilA3S4dYjASUxnm3uAQCwCwElSC1HF8kSUAAAsA8BJUgt3iMzKEkEFAAAbENACdJhAgoAALYjoATJXIOSQEABAMAuBJQgtRyt4klkF1kAAGzDKBsk/wxKEjMoAADYhoASJBbJAgBgPwJKkCgzBgDAfgSUIPmreAgoAADYh4ASJMqMAQCwHwElSC3mIlm6DgAAuzDKBsksM2YGBQAA2xBQgsRhgQAA2I+AEiTWoAAAYD8CSpAoMwYAwH4ElCCxURsAAPYjoATJfMRDFQ8AALZhlA0Si2QBALAfASVILewkCwCA7QgoQfIvkmUNCgAA9gk6oGzcuFEzZsxQdna2HA6H1qxZY17zer265557NGLECPXt21fZ2dn6wQ9+oH379gV8Rl1dnebMmaPk5GSlpqZq3rx5ampqCvvL2K2t3afWdgIKAAB2CzqgNDc3a+TIkVq2bNkJ1w4dOqRt27bpwQcf1LZt2/Taa69p165duu
aaawLumzNnjnbs2KF169aprKxMGzdu1Pz580P/FhHS0uYz/80jHgAA7OMK9hcKCgpUUFDQ6bWUlBStW7cu4L2f//znuvTSS7Vnzx7l5uZq586dWrt2rbZs2aIxY8ZIkpYuXapp06bp8ccfV3Z29gmf6/F45PF4zNeNjY3BNtsS/vUnkuR28XQMAAC72D7KNjQ0yOFwKDU1VZJUUVGh1NRUM5xIUn5+vpxOpyorKzv9jNLSUqWkpJg/OTk5dje7U8cqeJxyOh1RaQMAAL2BrQGlpaVF99xzj2688UYlJydLkmpqapSRkRFwn8vlUlpammpqajr9nMWLF6uhocH82bt3r53NPik2aQMAIDKCfsTTXV6vV3//938vwzD07LPPhvVZbrdbbrfbopaFjm3uAQCIDFsCij+cfP7559qwYYM5eyJJWVlZOnDgQMD9bW1tqqurU1ZWlh3NsQwHBQIAEBmWP+Lxh5OPP/5Yb775ptLT0wOu5+Xlqb6+XlVVVeZ7GzZskM/n09ixY61ujqUOs0kbAAAREfQMSlNTk3bv3m2+rq6u1vbt25WWlqazzjpLN9xwg7Zt26aysjK1t7eb60rS0tKUkJCgYcOGaerUqbr11lu1fPlyeb1eFRcXa/bs2Z1W8MSSjotkAQCAfYIOKFu3btXEiRPN1yUlJZKkwsJCPfzww3r99dclSRdffHHA77311luaMGGCJOnll19WcXGxJk2aJKfTqVmzZunpp58O8StEjqfNf1AgMygAANgp6IAyYcIEGYZx0uunuuaXlpamVatWBfuno84/g8IaFAAA7MWziiCwBgUAgMggoASBMmMAACKDgBIEyowBAIgMAkoQzJ1kWSQLAICtCChBMMuMOSgQAABbMdIGwT+DksgMCgAAtiKgBIE1KAAARAYBJQicZgwAQGQQUIJAmTEAAJFBQAkCG7UBABAZBJQgmFvds0gWAABbEVCCYFbxUGYMAICtGGmDwEZtAABEBgElCJQZAwAQGQSUILBIFgCAyCCgdJNhGJQZAwAQIQSUbvK0+cx/swYFAAB7EVC6yV9iLFHFAwCA3Rhpu8m//iQ+ziFXHN0GAICdGGm7qYUFsgAARAwBpZsoMQYAIHIIKN3EJm0AAEQOAaWbzBJjFwEFAAC7EVC6yV/Fk8gMCgAAtiOgdNOxNSh0GQAAdmO07Sa2uQcAIHIIKN3koYoHAICIIaB0E2XGAABEDgGlmw63Hq3iYZEsAAC2I6B0U0vb0TUolBkDAGA7Ako3+cuMkxLoMgAA7MZo200trEEBACBiCCjdRJkxAACRQ0DpJk4zBgAgcggo3XT46Fk8POIBAMB+BJRuamnlNGMAACKFgNJNZpkxZ/EAAGA7RttuMk8z5hEPAAC2I6B0E1vdAwAQOQSUbqKKBwCAyCGgdFMLVTwAAEQMAaUbDMM49oiHKh4AAGxHQOkGb7uhdp8hiUc8AABEAgGlG/wlxhJlxgAARAKjbTf4N2lzOqSEOLoMAAC7Mdp2Q8cSY4fDEeXWAABw+iOgdAMnGQMAEFkElG7wlxgTUAAAiIygA8rGjRs1Y8YMZWdny+FwaM2aNQHXDcPQQw89pLPOOktJSUnKz8/Xxx9/HHBPXV2d5syZo+TkZKWmpmrevHlqamoK64vY6TAHBQIAEFFBB5Tm5maNHDlSy5Yt6/T6Y489pqefflrLly9XZWWl+vbtqylTpqilpcW8Z86cOdqxY4fWrVunsrIybdy4UfPnzw/9W9ishW3uAQCIKFewv1BQUKCCgoJOrxmGoSeffFIPPPCArr32WknSiy++qMzMTK1Zs0azZ8/Wzp07tXbtWm3ZskVjxoyRJC1dulTTpk3T448/ruzs7BM+1+PxyOPxmK8bGxuDbXZYjq1B4YkYAACRYOmIW11drZqaGuXn55vvpaSkaOzYsaqoqJAkVVRUKDU11QwnkpSfny+n06nKyspOP7e0tFQpKSnmT05OjpXN7hLn8AAAEFmWBpSamh
pJUmZmZsD7mZmZ5rWamhplZGQEXHe5XEpLSzPvOd7ixYvV0NBg/uzdu9fKZneJk4wBAIisoB/xRIPb7Zbb7Y7a3/cvkmUGBQCAyLB0BiUrK0uSVFtbG/B+bW2teS0rK0sHDhwIuN7W1qa6ujrznljjaeMkYwAAIsnSgDJkyBBlZWVp/fr15nuNjY2qrKxUXl6eJCkvL0/19fWqqqoy79mwYYN8Pp/Gjh1rZXMsQ5kxAACRFfQjnqamJu3evdt8XV1dre3btystLU25ublasGCB/v3f/13nn3++hgwZogcffFDZ2dm67rrrJEnDhg3T1KlTdeutt2r58uXyer0qLi7W7NmzO63giQXsJAsAQGQFHVC2bt2qiRMnmq9LSkokSYWFhVq5cqXuvvtuNTc3a/78+aqvr9f48eO1du1aJSYmmr/z8ssvq7i4WJMmTZLT6dSsWbP09NNPW/B17EGZMQAAkeUwDMOIdiOC1djYqJSUFDU0NCg5Odn2v1fyP9v12rYvtLhgqP7pynNt/3sAAJyOghm/mRLoBnMnWdagAAAQEQSUbjDLjF0EFAAAIoGA0g3macbMoAAAEBEElG5gJ1kAACKLgNINLVTxAAAQUYy43cAMCgAAkUVA6QZOMwYAILIIKN3AVvcAAEQWAaUbzCoeZlAAAIgIAkoX2n2GWts5zRgAgEgioHTBv/5EIqAAABApBJQuHO4QUNwuugsAgEhgxO2Cf4Gs2+WU0+mIcmsAAOgdCChd8LRRwQMAQKQRULpwuJUFsgAARBoBpQuH2aQNAICII6B0gV1kAQCIPAJKF46dw0NXAQAQKYy6XWAGBQCAyCOgdME8h4eAAgBAxBBQumDOoFBmDABAxBBQunDYS5kxAACRRkDpwrEyY7oKAIBIYdTtgsfLGhQAACKNgNKFwwQUAAAijoDSBfOwQAIKAAARQ0DpAjMoAABEHgGlC/4yY04zBgAgcggop/CX/6vXOx9/JUnKTHZHuTUAAPQeBJSTOHCwRf/0qyp52ny6amiGrrwgI9pNAgCg1yCgdMLT1q7bX9qm/Q0tOvfMvnpy9sWKczqi3SwAAHoNAspxDMPQQ2t2qOrzb9Q/0aVf/mCMkhPjo90sAAB6FQLKcV6s+Fz/b+teOR3S0htH6Zwz+0W7SQAA9DoElA7e3f2V/q3sr5KkewuGasKFrDsBACAaXNFuQCxxOh1KTnRpwoUZuvXvzol2cwAA6LUIKB2MOyddb9wxXmf0c8vhYFEsAADRQkA5zsABfaLdBAAAej3WoAAAgJhDQAEAADGHgAIAAGIOAQUAAMQcAgoAAIg5BBQAABBzCCgAACDmEFAAAEDMIaAAAICYQ0ABAAAxh4ACAABiDgEFAADEHAIKAACIOT3yNGPDMCRJjY2NUW4JAADoLv+47R/HT6VHBpSDBw9KknJycqLcEgAAEKyDBw8qJSXllPc4jO7EmBjj8/m0b98+9e/fXw6Hw9LPbmxsVE5Ojvbu3avk5GRLPxuB6OvIoa8jh76OHPo6cqzqa8MwdPDgQWVnZ8vpPPUqkx45g+J0OjVw4EBb/0ZycjL/Bx8h9HXk0NeRQ19HDn0dOVb0dVczJ34skgUAADGHgAIAAGIOAeU4brdb//qv/yq32x3tppz26OvIoa8jh76OHPo6cqLR1z1ykSwAADi9MYMCAABiDgEFAADEHAIKAACIOQQUAAAQcwgoHSxbtkyDBw9WYmKixo4dq/feey/aTerxSktLdckll6h///7KyMjQddddp127dgXc09LSoqKiIqWnp6tfv36aNWuWamtro9Ti08ejjz4qh8OhBQsWmO/R19b54osv9P3vf1/p6elKSkrSiBEjtHXrVvO6YRh66KGHdNZZZykpKUn5+fn6+OOPo9jinqm9vV0PPvighgwZoqSkJJ177rn60Y9+FHCWC30duo0bN2rGjBnKzs6Ww+HQmjVrAq53p2/r6uo0Z84cJScnKzU1VfPmzVNTU1P4jTNgGI
ZhvPLKK0ZCQoLxwgsvGDt27DBuvfVWIzU11aitrY1203q0KVOmGCtWrDA+/PBDY/v27ca0adOM3Nxco6mpybzntttuM3Jycoz169cbW7duNcaNG2dcdtllUWx1z/fee+8ZgwcPNr797W8bd911l/k+fW2Nuro6Y9CgQcZNN91kVFZWGp9++qnxxz/+0di9e7d5z6OPPmqkpKQYa9asMf785z8b11xzjTFkyBDj8OHDUWx5z7NkyRIjPT3dKCsrM6qrq41XX33V6Nevn/HUU0+Z99DXofv9739v3H///cZrr71mSDJWr14dcL07fTt16lRj5MiRxubNm4133nnHOO+884wbb7wx7LYRUI669NJLjaKiIvN1e3u7kZ2dbZSWlkaxVaefAwcOGJKM8vJywzAMo76+3oiPjzdeffVV856dO3cakoyKiopoNbNHO3jwoHH++ecb69atM6688kozoNDX1rnnnnuM8ePHn/S6z+czsrKyjP/4j/8w36uvrzfcbrfx3//935Fo4mlj+vTpxi233BLw3syZM405c+YYhkFfW+n4gNKdvv3rX/9qSDK2bNli3vOHP/zBcDgcxhdffBFWe3jEI6m1tVVVVVXKz88333M6ncrPz1dFRUUUW3b6aWhokCSlpaVJkqqqquT1egP6fujQocrNzaXvQ1RUVKTp06cH9KlEX1vp9ddf15gxY/S9731PGRkZGjVqlH75y1+a16urq1VTUxPQ1ykpKRo7dix9HaTLLrtM69ev10cffSRJ+vOf/6xNmzapoKBAEn1tp+70bUVFhVJTUzVmzBjznvz8fDmdTlVWVob193vkYYFW++qrr9Te3q7MzMyA9zMzM/W3v/0tSq06/fh8Pi1YsECXX365LrroIklSTU2NEhISlJqaGnBvZmamampqotDKnu2VV17Rtm3btGXLlhOu0dfW+fTTT/Xss8+qpKRE9913n7Zs2aI777xTCQkJKiwsNPuzs/+m0NfBuffee9XY2KihQ4cqLi5O7e3tWrJkiebMmSNJ9LWNutO3NTU1ysjICLjucrmUlpYWdv8TUBAxRUVF+vDDD7Vp06ZoN+W0tHfvXt11111at26dEhMTo92c05rP59OYMWP04x//WJI0atQoffjhh1q+fLkKCwuj3LrTy//8z//o5Zdf1qpVq/Stb31L27dv14IFC5SdnU1fn+Z4xCPpjDPOUFxc3AnVDLW1tcrKyopSq04vxcXFKisr01tvvaWBAwea72dlZam1tVX19fUB99P3wauqqtKBAwf0ne98Ry6XSy6XS+Xl5Xr66aflcrmUmZlJX1vkrLPO0vDhwwPeGzZsmPbs2SNJZn/y35TwLVq0SPfee69mz56tESNGaO7cuVq4cKFKS0sl0dd26k7fZmVl6cCBAwHX29raVFdXF3b/E1AkJSQkaPTo0Vq/fr35ns/n0/r165WXlxfFlvV8hmGouLhYq1ev1oYNGzRkyJCA66NHj1Z8fHxA3+/atUt79uyh74M0adIkffDBB9q+fbv5M2bMGM2ZM8f8N31tjcsvv/yEcvmPPvpIgwYNkiQNGTJEWVlZAX3d2NioyspK+jpIhw4dktMZOFTFxcXJ5/NJoq/t1J2+zcvLU319vaqqqsx7NmzYIJ/Pp7Fjx4bXgLCW2J5GXnnlFcPtdhsrV640/vrXvxrz5883UlNTjZqammg3rUe7/fbbjZSUFOPtt9829u/fb/4cOnTIvOe2224zcnNzjQ0bNhhbt2418vLyjLy8vCi2+vTRsYrHMOhrq7z33nuGy+UylixZYnz88cfGyy+/bPTp08d46aWXzHseffRRIzU11fjtb39r/OUvfzGuvfZaSl9DUFhYaJx99tlmmfFrr71mnHHGGcbdd99t3kNfh+7gwYPG+++/b7z//vuGJONnP/uZ8f777xuff/65YRjd69upU6cao0aNMiorK41NmzYZ559/PmXGVlu6dKmRm5trJCQkGJdeeqmxefPmaDepx5
PU6c+KFSvMew4fPmz88z//szFgwACjT58+xvXXX2/s378/eo0+jRwfUOhr67zxxhvGRRddZLjdbmPo0KHGL37xi4DrPp/PePDBB43MzEzD7XYbkyZNMnbt2hWl1vZcjY2Nxl133WXk5uYaiYmJxjnnnGPcf//9hsfjMe+hr0P31ltvdfrf6MLCQsMwute3X3/9tXHjjTca/fr1M5KTk42bb77ZOHjwYNhtcxhGh+34AAAAYgBrUAAAQMwhoAAAgJhDQAEAADGHgAIAAGIOAQUAAMQcAgoAAIg5BBQAABBzCCgAACDmEFAAAEDMIaAACMpNN90kh8Mhh8Oh+Ph4DRkyRHfffbdaWlq69ftvv/22HA7HCacqA0BHrmg3AEDPM3XqVK1YsUJer1dVVVUqLCyUw+HQT37yk4i2w+v1Kj4+PqJ/E0BkMIMCIGhut1tZWVnKycnRddddp/z8fK1bt06S5PP5VFpaqiFDhigpKUkjR47Ur3/9a0nSZ599pokTJ0qSBgwYIIfDoZtuukmSNHjwYD355JMBf+fiiy/Www8/bL52OBx69tlndc0116hv375asmSJHn74YV188cX61a9+pcGDByslJUWzZ8/WwYMHzd/79a9/rREjRigpKUnp6enKz89Xc3OzfR0EIGwEFABh+fDDD/Xuu+8qISFBklRaWqoXX3xRy5cv144dO7Rw4UJ9//vfV3l5uXJycvSb3/xGkrRr1y7t379fTz31VFB/7+GHH9b111+vDz74QLfccosk6ZNPPtGaNWtUVlamsrIylZeX69FHH5Uk7d+/XzfeeKNuueUW7dy5U2+//bZmzpwpzkkFYhuPeAAEraysTP369VNbW5s8Ho+cTqd+/vOfy+Px6Mc//rHefPNN5eXlSZLOOeccbdq0Sc8995yuvPJKpaWlSZIyMjKUmpoa9N/+x3/8R918880B7/l8Pq1cuVL9+/eXJM2dO1fr16/XkiVLtH//frW1tWnmzJkaNGiQJGnEiBFhfHsAkUBAARC0iRMn6tlnn1Vzc7OeeOIJuVwuzZo1Szt27NChQ4f03e9+N+D+1tZWjRo1ypK/PWbMmBPeGzx4sBlOJOmss87SgQMHJEkjR47UpEmTNGLECE2ZMkWTJ0/WDTfcoAEDBljSHgD2IKAACFrfvn113nnnSZJeeOEFjRw5Us8//7wuuugiSdLvfvc7nX322QG/43a7T/mZTqfzhMcuXq+30799vOMXyjocDvl8PklSXFyc1q1bp3fffVd/+tOftHTpUt1///2qrKzUkCFDuvimAKKFNSgAwuJ0OnXffffpgQce0PDhw+V2u7Vnzx6dd955AT85OTmSZK5VaW9vD/icM888U/v37zdfNzY2qrq62pI2OhwOXX755XrkkUf0/vvvKyEhQatXr7bkswHYgxkUAGH73ve+p0WLFum5557Tv/zLv2jhwoXy+XwaP368Ghoa9L//+79KTk5WYWGhBg0aJIfDobKyMk2bNk1JSUnq16+frrrqKq1cuVIzZsxQamqqHnroIcXFxYXdtsrKSq1fv16TJ09WRkaGKisr9eWXX2rYsGEWfHMAdiGgAAiby+VScXGxHnvsMVVXV+vMM89UaWmpPv30U6Wmpuo73/mO7rvvPknS2WefrUceeUT33nuvbr75Zv3gBz/QypUrtXjxYlVXV+vqq69WSkqKfvSjH1kyg5KcnKyNGzfqySefVGNjowYNGqSf/vSnKigoCPuzAdjHYVBrBwAAYgxrUAAAQMwhoAAAgJhDQAEAADGHgAIAAGIOAQUAAMQcAgoAAIg5BBQAABBzCCgAACDmEFAAAEDMIaAAAICYQ0ABAAAx5/8DJW9hIVRGyIYAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Model\n", "features = ts.utils.net.common.Net(\n", "    state_shape=env.observation_space.shape, \n", "    hidden_sizes=[64, 64], \n", "    device=device)\n", "\n", "actor = ts.utils.net.discrete.Actor(\n", "    preprocess_net=features, \n", "    action_shape=env.action_space.n, \n", "    device=device).to(device)\n", "\n", "critic = ts.utils.net.discrete.Critic(\n", "    preprocess_net=features, \n", "    device=device).to(device)\n", "\n", "actor_critic = ts.utils.net.common.ActorCritic(actor=actor, critic=critic)\n", "\n", "optim = torch.optim.Adam(actor_critic.parameters(), lr=0.003)\n", "\n", "# Policy\n", "policy = ts.policy.PPOPolicy(\n", "    actor=actor, \n", "    critic=critic, \n", "    optim=optim,\n", "    dist_fn=torch.distributions.Categorical, \n", "    action_space=env.action_space,\n", "    deterministic_eval=True,\n", "    action_scaling=False,\n", "    discount_factor=0.99,\n", "    max_grad_norm=0.5,\n", "    eps_clip=0.2,\n", "    gae_lambda=0.95,\n", ")\n", "\n", "# Collector\n", "collector = ts.data.Collector(\n", "    policy=policy, \n", "    env=ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]), \n", "    buffer=ts.data.VectorReplayBuffer(20000, 10),\n", "    exploration_noise=True\n", ")\n", "test_collector = ts.data.Collector(\n", "    policy=policy, \n", "    env=ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]),\n", "    exploration_noise=False\n", ")\n", "\n", "# On-policy trainer\n", "returns = []\n", "for iteration in range(100):\n", "\n", "    policy.train() \n", "\n", "    # Collect enough on-policy steps\n", "    result = collector.collect(n_step=3000)\n", "\n", "    # Update the PPO network by training on the on-policy buffer 10 times\n", "    policy.update(\n", "        buffer=collector.buffer,\n", "        sample_size=0, # use the whole buffer\n", "        batch_size=256,\n", "        repeat=10,\n", "    )\n", "\n", "    # Empty the buffer as we are on-policy\n", "    collector.reset_buffer(keep_statistics=False)\n", "\n", "    # Test 10 
episodes\n", "    policy.eval()\n", "    result = test_collector.collect(n_episode=10)\n", "    mean_reward = result['rew']\n", "    #print(iteration, \":\", mean_reward)\n", "    returns.append(mean_reward)\n", "\n", "plt.figure()\n", "plt.plot(np.array(returns))\n", "plt.xlabel(\"Epochs\")\n", "plt.ylabel(\"Returns\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Return: 200.0\n", "MoviePy - Building file videos/cartpole-ppo.gif with imageio.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Evaluation mode\n", "policy.eval() \n", "\n", "# Create a recordable environment\n", "env = gym.make('CartPole-v0', render_mode=\"rgb_array_list\")\n", "recorder = GymRecorder(env)\n", "\n", "# Sample the initial state\n", "state, info = env.reset()\n", "\n", "# One episode:\n", "done = False\n", "return_episode = 0\n", "while not done:\n", "    # Select an action from the learned policy (moved to the CPU first in case the policy runs on a GPU)\n", "    action = policy.forward(ts.data.Batch(obs=[state], info=None)).act[0].cpu().numpy()\n", "    # Sample a single transition\n", "    next_state, reward, terminal, truncated, info = env.step(action)\n", "    # End of the episode\n", "    done = terminal or truncated\n", "    # Update undiscounted return\n", "    return_episode += reward\n", "    # Go to the next state\n", "    state = next_state\n", "\n", "print(\"Return:\", return_episode)\n", "\n", "recorder.record(env.render())\n", "video = \"videos/cartpole-ppo.gif\"\n", "recorder.make_video(video)\n", "ipython_display(video)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Q:** How does it compare to DQN once the right hyperparameters are found? What is their influence? Play especially with `eps_clip`, the $\\epsilon$ threshold used to clip the importance sampling (IS) weight in the PPO loss.\n", "\n", "**Q:** Apply PPO to more complex environments available in gymnasium. 
Beware that for continuous action spaces, you will need to use continuous actor/critic networks, i.e.:\n", "\n", "```python\n", "actor = ts.utils.net.continuous.Actor(...)\n", "critic = ts.utils.net.continuous.Critic(...)\n", "```\n", "\n", "instead of:\n", "\n", "```python\n", "actor = ts.utils.net.discrete.Actor(...)\n", "critic = ts.utils.net.discrete.Critic(...)\n", "```\n", "\n", "The `dist_fn` argument of `PPOPolicy` must also be set accordingly. Check the doc!" ] } ], "metadata": { "kernelspec": { "display_name": "tianshou", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 2 }