import numpy as np


class ReplayMemoryFast:
    """Fixed-capacity experience replay buffer backed by a circular list.

    The buffer stores the agent's experience as
    ``(observation, action, reward, newobservation, is_terminal)`` tuples.
    Once ``memory_size`` experiences have been stored, each new experience
    overwrites the oldest one. Minibatches of experience are drawn at random
    from the stored entries for training the network.

    Attributes:
        memory_size: maximum number of experiences kept in the buffer.
        minibatch_size: number of experiences returned by ``sample``.
        experience: preallocated list of length ``memory_size``; slots that
            have not been written yet hold ``None``.
        current_index: slot that the next call to ``store`` will write.
        size: number of slots currently holding an experience.
    """

    def __init__(self, memory_size, minibatch_size):
        """Allocate an empty buffer.

        Args:
            memory_size: maximum number of experiences to keep; must be >= 1.
            minibatch_size: number of experiences returned by ``sample``.
                If it is larger than ``memory_size``, ``sample`` will always
                return an empty list.

        Raises:
            ValueError: if ``memory_size`` is smaller than 1. Without this
                check the first call to ``store`` would fail with an
                ``IndexError`` on the empty backing list.
        """
        if memory_size < 1:
            raise ValueError(
                "memory_size must be at least 1, got %r" % (memory_size,)
            )

        # max number of samples to store
        self.memory_size = memory_size

        # mini batch size
        self.minibatch_size = minibatch_size

        # preallocate every slot so that store() is a plain index assignment
        self.experience = [None] * self.memory_size
        self.current_index = 0
        self.size = 0

    def store(self, observation, action, reward, newobservation, is_terminal):
        """Record one transition, overwriting the oldest one when full.

        Args:
            observation: current state.
            action: action taken in the current state.
            reward: reward received for the action.
            newobservation: next state.
            is_terminal: whether the next state is a terminal state.
        """
        # store the experience as a tuple
        # (current state, action, reward, next state, is it a terminal state)
        self.experience[self.current_index] = (
            observation, action, reward, newobservation, is_terminal
        )

        # advance the write position, wrapping around to slot 0 once the end
        # of the buffer is reached
        self.current_index = (self.current_index + 1) % self.memory_size

        # the number of filled slots grows until the buffer is full
        self.size = min(self.size + 1, self.memory_size)

    def sample(self):
        """Draw a random minibatch of stored experiences.

        Indices are drawn independently, so the same experience may appear
        more than once in a minibatch.

        Returns:
            A list of ``minibatch_size`` experience tuples, or an empty list
            if fewer than ``minibatch_size`` experiences have been stored.
        """
        if self.size < self.minibatch_size:
            return []

        # first we randomly sample some indices: random() lies in [0, 1), so
        # floor(random() * size) is an index into the filled slots only
        samples_index = np.floor(
            np.random.random((self.minibatch_size,)) * self.size
        )

        # select the experience stored at each sampled index
        return [self.experience[int(i)] for i in samples_index]