{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# AlphaZero\n", "\n", "This notebook is based on the paper:\n", "\n", "* [Mastering the game of Go without Human Knowledge](https://deepmind.com/research/publications/mastering-game-go-without-human-knowledge/)\n", "\n", "with additional insight from:\n", "\n", "* https://applied-data.science/blog/how-to-build-your-own-alphazero-ai-using-python-and-keras/\n", "* https://github.com/AppliedDataSciencePartners/DeepReinforcementLearning\n", "* https://github.com/junxiaosong/AlphaZero_Gomoku/blob/master/mcts_alphaZero.py\n", "\n", "\n", "This code uses the new [conx](http://conx.readthedocs.io/en/latest/) layer that sits on top of Keras. Conx is designed to be simpler and more intuitive than Keras, with integrated visualizations.\n", "\n", "Currently this code requires the TensorFlow backend, as it has a function specific to TF." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Game\n", "\n", "First, let's look at a specific game. Many games would work, but for this demonstration we'll pick ConnectFour. A good code base of different games, along with a game engine, is available in the [aima3 code](https://github.com/Calysto/aima3/) based on [Artificial Intelligence: A Modern Approach](http://aima.cs.berkeley.edu/).\n", "\n", "If you would like to install aima3, you can use something like this in a cell:\n", "\n", "```bash\n", "! pip install aima3 -U --user\n", "```\n", "\n", "aima3 has other games that you can play besides ConnectFour, including TicTacToe. \n", "aima3 has many AI algorithms wrapped up to play games. 
You can see more details about the game engine and ConnectFour here:\n", "\n", "* https://github.com/Calysto/aima3/blob/master/notebooks/games.ipynb\n", "* https://github.com/Calysto/aima3/blob/master/notebooks/connect_four.ipynb\n", "\n", "and other resources in that repository.\n", "\n", "We import some of these that will be useful in our AlphaZero exploration:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from aima3.games import (ConnectFour, RandomPlayer, \n", " MCTSPlayer, QueryPlayer, Player,\n", " MiniMaxPlayer, AlphaBetaPlayer,\n", " AlphaBetaCutoffPlayer)\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's make a game:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "game = ConnectFour()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and play a game between two random players:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random-2 is thinking...\n", "Random-2 makes action (1, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X . . . . . . \n", "Random-1 is thinking...\n", "Random-1 makes action (3, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X . O . . . . \n", "Random-2 is thinking...\n", "Random-2 makes action (5, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X . O . X . . \n", "Random-1 is thinking...\n", "Random-1 makes action (4, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X . O O X . . \n", "Random-2 is thinking...\n", "Random-2 makes action (3, 2):\n", ". . . . . . . \n", ". . . . . . . 
\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . . . . \n", "X . O O X . . \n", "Random-1 is thinking...\n", "Random-1 makes action (7, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . . . . \n", "X . O O X . O \n", "Random-2 is thinking...\n", "Random-2 makes action (7, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . . . X \n", "X . O O X . O \n", "Random-1 is thinking...\n", "Random-1 makes action (5, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . O . X \n", "X . O O X . O \n", "Random-2 is thinking...\n", "Random-2 makes action (4, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X X O . X \n", "X . O O X . O \n", "Random-1 is thinking...\n", "Random-1 makes action (1, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "O . X X O . X \n", "X . O O X . O \n", "Random-2 is thinking...\n", "Random-2 makes action (4, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . X . . . \n", "O . X X O . X \n", "X . O O X . O \n", "Random-1 is thinking...\n", "Random-1 makes action (3, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . O X . . . \n", "O . X X O . X \n", "X . O O X . O \n", "Random-2 is thinking...\n", "Random-2 makes action (1, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X . O X . . . \n", "O . X X O . X \n", "X . O O X . O \n", "Random-1 is thinking...\n", "Random-1 makes action (3, 4):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . O . . . . \n", "X . O X . . . \n", "O . X X O . X \n", "X . O O X . 
O \n", "Random-2 is thinking...\n", "Random-2 makes action (1, 4):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X . O . . . . \n", "X . O X . . . \n", "O . X X O . X \n", "X . O O X . O \n", "Random-1 is thinking...\n", "Random-1 makes action (1, 5):\n", ". . . . . . . \n", ". . . . . . . \n", "O . . . . . . \n", "X . O . . . . \n", "X . O X . . . \n", "O . X X O . X \n", "X . O O X . O \n", "Random-2 is thinking...\n", "Random-2 makes action (2, 1):\n", ". . . . . . . \n", ". . . . . . . \n", "O . . . . . . \n", "X . O . . . . \n", "X . O X . . . \n", "O . X X O . X \n", "X X O O X . O \n", "Random-1 is thinking...\n", "Random-1 makes action (6, 1):\n", ". . . . . . . \n", ". . . . . . . \n", "O . . . . . . \n", "X . O . . . . \n", "X . O X . . . \n", "O . X X O . X \n", "X X O O X O O \n", "Random-2 is thinking...\n", "Random-2 makes action (1, 6):\n", ". . . . . . . \n", "X . . . . . . \n", "O . . . . . . \n", "X . O . . . . \n", "X . O X . . . \n", "O . X X O . X \n", "X X O O X O O \n", "Random-1 is thinking...\n", "Random-1 makes action (3, 5):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . . \n", "X . O . . . . \n", "X . O X . . . \n", "O . X X O . X \n", "X X O O X O O \n", "Random-2 is thinking...\n", "Random-2 makes action (2, 2):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . . \n", "X . O . . . . \n", "X . O X . . . \n", "O X X X O . X \n", "X X O O X O O \n", "Random-1 is thinking...\n", "Random-1 makes action (4, 4):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . . \n", "X . O O . . . \n", "X . O X . . . \n", "O X X X O . X \n", "X X O O X O O \n", "Random-2 is thinking...\n", "Random-2 makes action (5, 3):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . . \n", "X . O O . . . \n", "X . O X X . . \n", "O X X X O . X \n", "X X O O X O O \n", "Random-1 is thinking...\n", "Random-1 makes action (5, 4):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . . \n", "X . O O O . . \n", "X . O X X . 
. \n", "O X X X O . X \n", "X X O O X O O \n", "Random-2 is thinking...\n", "Random-2 makes action (7, 3):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . . \n", "X . O O O . . \n", "X . O X X . X \n", "O X X X O . X \n", "X X O O X O O \n", "Random-1 is thinking...\n", "Random-1 makes action (6, 2):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . . \n", "X . O O O . . \n", "X . O X X . X \n", "O X X X O O X \n", "X X O O X O O \n", "Random-2 is thinking...\n", "Random-2 makes action (7, 4):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . . \n", "X . O O O . X \n", "X . O X X . X \n", "O X X X O O X \n", "X X O O X O O \n", "Random-1 is thinking...\n", "Random-1 makes action (7, 5):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . O \n", "X . O O O . X \n", "X . O X X . X \n", "O X X X O O X \n", "X X O O X O O \n", "Random-2 is thinking...\n", "Random-2 makes action (2, 3):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . . . O \n", "X . O O O . X \n", "X X O X X . X \n", "O X X X O O X \n", "X X O O X O O \n", "Random-1 is thinking...\n", "Random-1 makes action (5, 5):\n", ". . . . . . . \n", "X . . . . . . \n", "O . O . O . O \n", "X . O O O . X \n", "X X O X X . X \n", "O X X X O O X \n", "X X O O X O O \n", "Random-2 is thinking...\n", "Random-2 makes action (3, 6):\n", ". . . . . . . \n", "X . X . . . . \n", "O . O . O . O \n", "X . O O O . X \n", "X X O X X . X \n", "O X X X O O X \n", "X X O O X O O \n", "Random-1 is thinking...\n", "Random-1 makes action (2, 4):\n", ". . . . . . . \n", "X . X . . . . \n", "O . O . O . O \n", "X O O O O . X \n", "X X O X X . 
X \n", "O X X X O O X \n", "X X O O X O O \n", "***** Random-1 wins!\n" ] }, { "data": { "text/plain": [ "['Random-1']" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "game.play_game(RandomPlayer(\"Random-1\"), RandomPlayer(\"Random-2\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also play a match (a series of games) or even a tournament among several players.\n", "\n", "```python\n", "p1 = RandomPlayer(\"Random-1\")\n", "p2 = MiniMaxPlayer(\"MiniMax-1\")\n", "p3 = AlphaBetaCutoffPlayer(\"ABCutoff-1\")\n", "\n", "game.play_matches(10, p1, p2)\n", "\n", "game.play_tournament(1, p1, p2, p3)\n", "```\n", "\n", "Can you beat RandomPlayer? Hope so!\n", "\n", "Can you beat MiniMax? No! But it takes too long to choose its moves.\n", "\n", "Humans enter their moves as (column, row), where columns are numbered from 1 starting at the left and rows are numbered from 1 starting at the bottom." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# game.play_game(AlphaBetaCutoffPlayer(\"AlphaBetaCutoff\"), HumanPlayer(\"Your Name Here\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Network\n", "\n", "Next, we are going to build the same kind of network described in the AlphaZero paper.\n", "\n", "Make sure to set your Keras backend to TensorFlow for now, as we have a function that is written at that level."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n", "/usr/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n", " return f(*args, **kwds)\n", "conx, version 3.5.14\n" ] } ], "source": [ "import conx as cx\n", "from aima3.games import Game\n", "from keras import regularizers" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# TODO: rewrite this function using backend-agnostic Keras ops.\n", "\n", "import tensorflow as tf\n", "\n", "def softmax_cross_entropy_with_logits(y_true, y_pred):\n", "    \"\"\"Softmax cross-entropy between target policy pi (y_true) and logits p (y_pred).\"\"\"\n", "    p = y_pred\n", "    pi = y_true\n", "    # Where the target probability is exactly zero, replace the logit with -100\n", "    # so that entry receives near-zero probability under the softmax.\n", "    zero = tf.zeros(shape=tf.shape(pi), dtype=tf.float32)\n", "    where = tf.equal(pi, zero)\n", "    negatives = tf.fill(tf.shape(pi), -100.0)\n", "    p = tf.where(where, negatives, p)\n", "    loss = tf.nn.softmax_cross_entropy_with_logits(labels=pi, logits=p)\n", "    return loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Representations\n", "\n", "The board state is the most important piece of information. How should we represent it? Possible ideas:\n", "\n", "* a vector of 42 values\n", "* a 6x7 matrix\n", "\n", "We decided to represent the state of the board as two 6x7 matrices: one for the current player's pieces, and the other for the opponent's pieces.\n", "\n", "We also need to represent actions. Possible ideas:\n", "\n", "* 7 outputs, each representing a column to drop a piece into\n", "* two outputs, one representing the row, and the other the column\n", "* a 6x7 matrix, each entry representing a position on the grid\n", "* 42 outputs, each representing a position on the grid\n", "\n", "We chose the final option: 42 outputs."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The AlphaZero network architecture is quite large and is built from repeating blocks of layers. To help construct the network, we define some helper functions:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def add_conv_block(net, input_layer):\n", "    \"\"\"Add conv -> batch-norm -> leaky-ReLU layers; return the last layer's name.\"\"\"\n", "    cname = net.add(cx.Conv2DLayer(\"conv2d-%d\",\n", "        filters=75,\n", "        kernel_size=(4,4),\n", "        padding='same',\n", "        use_bias=False,\n", "        activation='linear',\n", "        kernel_regularizer=regularizers.l2(0.0001)))\n", "    bname = net.add(cx.BatchNormalizationLayer(\"batch-norm-%d\", axis=1))\n", "    lname = net.add(cx.LeakyReLULayer(\"leaky-relu-%d\"))\n", "    net.connect(input_layer, cname)\n", "    net.connect(cname, bname)\n", "    net.connect(bname, lname)\n", "    return lname\n", "\n", "def add_residual_block(net, input_layer):\n", "    \"\"\"Add a conv block, a second conv/batch-norm, and a skip connection from input_layer.\"\"\"\n", "    prev_layer = add_conv_block(net, input_layer)\n", "    cname = net.add(cx.Conv2DLayer(\"conv2d-%d\",\n", "        filters=75,\n", "        kernel_size=(4,4),\n", "        padding='same',\n", "        use_bias=False,\n", "        activation='linear',\n", "        kernel_regularizer=regularizers.l2(0.0001)))\n", "    bname = net.add(cx.BatchNormalizationLayer(\"batch-norm-%d\", axis=1))\n", "    aname = net.add(cx.AddLayer(\"add-%d\"))\n", "    lname = net.add(cx.LeakyReLULayer(\"leaky-relu-%d\"))\n", "    net.connect(prev_layer, cname)\n", "    net.connect(cname, bname)\n", "    net.connect(input_layer, aname)\n", "    net.connect(bname, aname)\n", "    net.connect(aname, lname)\n", "    return lname\n", "\n", "def add_value_block(net, input_layer):\n", "    \"\"\"Add the value head, ending in a single tanh unit named 'value_head'.\"\"\"\n", "    l1 = net.add(cx.Conv2DLayer(\"conv2d-%d\",\n", "        filters=1,\n", "        kernel_size=(1,1),\n", "        padding='same',\n", "        use_bias=False,\n", "        activation='linear',\n", "        kernel_regularizer=regularizers.l2(0.0001)))\n", "    l2 = net.add(cx.BatchNormalizationLayer(\"batch-norm-%d\", axis=1))\n", "    l3 = net.add(cx.LeakyReLULayer(\"leaky-relu-%d\"))\n", "    l4 = net.add(cx.FlattenLayer(\"flatten-%d\"))\n", "    l5 = net.add(cx.Layer(\"dense-%d\",\n", "        20,\n", "        use_bias=False,\n", "        activation='linear',\n", "        kernel_regularizer=regularizers.l2(0.0001)))\n", "    l6 = net.add(cx.LeakyReLULayer(\"leaky-relu-%d\"))\n", "    l7 = net.add(cx.Layer('value_head',\n", "        1,\n", "        use_bias=False,\n", "        activation='tanh',\n", "        kernel_regularizer=regularizers.l2(0.0001)))\n", "    net.connect(input_layer, l1)\n", "    net.connect(l1, l2)\n", "    net.connect(l2, l3)\n", "    net.connect(l3, l4)\n", "    net.connect(l4, l5)\n", "    net.connect(l5, l6)\n", "    net.connect(l6, l7)\n", "    return l7\n", "\n", "def add_policy_block(net, input_layer):\n", "    \"\"\"Add the policy head, ending in 42 linear outputs (logits) named 'policy_head'.\"\"\"\n", "    l1 = net.add(cx.Conv2DLayer(\"conv2d-%d\",\n", "        filters=2,\n", "        kernel_size=(1,1),\n", "        padding='same',\n", "        use_bias=False,\n", "        activation='linear',\n", "        kernel_regularizer=regularizers.l2(0.0001)))\n", "    l2 = net.add(cx.BatchNormalizationLayer(\"batch-norm-%d\", axis=1))\n", "    l3 = net.add(cx.LeakyReLULayer(\"leaky-relu-%d\"))\n", "    l4 = net.add(cx.FlattenLayer(\"flatten-%d\"))\n", "    l5 = net.add(cx.Layer('policy_head',\n", "        42,\n", "        use_bias=False,\n", "        activation='linear',\n", "        kernel_regularizer=regularizers.l2(0.0001)))\n", "    net.connect(input_layer, l1)\n", "    net.connect(l1, l2)\n", "    net.connect(l2, l3)\n", "    net.connect(l3, l4)\n", "    net.connect(l4, l5)\n", "    return l5" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def make_network(game, residuals=5):\n", "    \"\"\"Build and compile the network: a conv block, `residuals` residual blocks, then both heads.\"\"\"\n", "    net = cx.Network(\"AlphaZero Network\")\n", "    net.add(cx.Layer(\"main_input\", (game.v, game.h, 2)))\n", "    out_layer = add_conv_block(net, \"main_input\")\n", "    for i in range(residuals):\n", "        out_layer = add_residual_block(net, out_layer)\n", "    add_policy_block(net, out_layer)\n", "    add_value_block(net, out_layer)\n", "    net.compile(loss={'value_head': 'mean_squared_error',\n", "                      'policy_head': softmax_cross_entropy_with_logits},\n", "                optimizer=cx.SGD(lr=0.1, momentum=0.9),\n", "                loss_weights={'value_head': 0.5,\n", "                              'policy_head': 0.5})\n", "    # Hide the hidden layers in conx visualizations\n", "    for layer in net.layers:\n", "        if layer.kind() == \"hidden\":\n", "            layer.visible = False\n", "    return net" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "game = ConnectFour()\n", "net = make_network(game)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", "main_input (InputLayer) (None, 6, 7, 2) 0 \n", "__________________________________________________________________________________________________\n", "conv2d-1 (Conv2D) (None, 6, 7, 75) 2400 main_input[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-1 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-1[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-1 (LeakyReLU) (None, 6, 7, 75) 0 batch-norm-1[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-2 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-1[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-2 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-2[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-2 (LeakyReLU) (None, 6, 7, 75) 0 batch-norm-2[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-3 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-2[0][0] \n", 
"__________________________________________________________________________________________________\n", "batch-norm-3 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-3[0][0] \n", "__________________________________________________________________________________________________\n", "add-1 (Add) (None, 6, 7, 75) 0 leaky-relu-1[0][0] \n", " batch-norm-3[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-3 (LeakyReLU) (None, 6, 7, 75) 0 add-1[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-4 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-3[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-4 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-4[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-4 (LeakyReLU) (None, 6, 7, 75) 0 batch-norm-4[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-5 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-4[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-5 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-5[0][0] \n", "__________________________________________________________________________________________________\n", "add-2 (Add) (None, 6, 7, 75) 0 leaky-relu-3[0][0] \n", " batch-norm-5[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-5 (LeakyReLU) (None, 6, 7, 75) 0 add-2[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-6 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-5[0][0] \n", 
"__________________________________________________________________________________________________\n", "batch-norm-6 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-6[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-6 (LeakyReLU) (None, 6, 7, 75) 0 batch-norm-6[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-7 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-6[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-7 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-7[0][0] \n", "__________________________________________________________________________________________________\n", "add-3 (Add) (None, 6, 7, 75) 0 leaky-relu-5[0][0] \n", " batch-norm-7[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-7 (LeakyReLU) (None, 6, 7, 75) 0 add-3[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-8 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-7[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-8 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-8[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-8 (LeakyReLU) (None, 6, 7, 75) 0 batch-norm-8[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-9 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-8[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-9 (BatchNormalizatio (None, 6, 7, 75) 24 conv2d-9[0][0] \n", 
"__________________________________________________________________________________________________\n", "add-4 (Add) (None, 6, 7, 75) 0 leaky-relu-7[0][0] \n", " batch-norm-9[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-9 (LeakyReLU) (None, 6, 7, 75) 0 add-4[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-10 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-9[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-10 (BatchNormalizati (None, 6, 7, 75) 24 conv2d-10[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-10 (LeakyReLU) (None, 6, 7, 75) 0 batch-norm-10[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-11 (Conv2D) (None, 6, 7, 75) 90000 leaky-relu-10[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-11 (BatchNormalizati (None, 6, 7, 75) 24 conv2d-11[0][0] \n", "__________________________________________________________________________________________________\n", "add-5 (Add) (None, 6, 7, 75) 0 leaky-relu-9[0][0] \n", " batch-norm-11[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-11 (LeakyReLU) (None, 6, 7, 75) 0 add-5[0][0] \n", "__________________________________________________________________________________________________\n", "conv2d-13 (Conv2D) (None, 6, 7, 1) 75 leaky-relu-11[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-13 (BatchNormalizati (None, 6, 7, 1) 24 conv2d-13[0][0] \n", 
"__________________________________________________________________________________________________\n", "conv2d-12 (Conv2D) (None, 6, 7, 2) 150 leaky-relu-11[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-13 (LeakyReLU) (None, 6, 7, 1) 0 batch-norm-13[0][0] \n", "__________________________________________________________________________________________________\n", "batch-norm-12 (BatchNormalizati (None, 6, 7, 2) 24 conv2d-12[0][0] \n", "__________________________________________________________________________________________________\n", "flatten-2 (Flatten) (None, 42) 0 leaky-relu-13[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-12 (LeakyReLU) (None, 6, 7, 2) 0 batch-norm-12[0][0] \n", "__________________________________________________________________________________________________\n", "dense-1 (Dense) (None, 20) 840 flatten-2[0][0] \n", "__________________________________________________________________________________________________\n", "flatten-1 (Flatten) (None, 84) 0 leaky-relu-12[0][0] \n", "__________________________________________________________________________________________________\n", "leaky-relu-14 (LeakyReLU) (None, 20) 0 dense-1[0][0] \n", "__________________________________________________________________________________________________\n", "policy_head (Dense) (None, 42) 3528 flatten-1[0][0] \n", "__________________________________________________________________________________________________\n", "value_head (Dense) (None, 1) 20 leaky-relu-14[0][0] \n", "==================================================================================================\n", "Total params: 907,325\n", "Trainable params: 907,169\n", "Non-trainable params: 156\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ 
"net.model.summary()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "51" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(net.layers)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/javascript": [ "\n", "require(['base/js/namespace'], function(Jupyter) {\n", " Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {\n", " comm.on_msg(function(msg) {\n", " var data = msg[\"content\"][\"data\"];\n", " var images = document.getElementsByClassName(data[\"class\"]);\n", " for (var i = 0; i < images.length; i++) {\n", " if (data[\"href\"]) {\n", " images[i].setAttributeNS(null, \"href\", data[\"href\"]);\n", " }\n", " if (data[\"src\"]) {\n", " images[i].setAttributeNS(null, \"src\", data[\"src\"]);\n", " }\n", " }\n", " });\n", " });\n", "});\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " AlphaZero NetworkLayer: policy_head (output)\n", " shape = (42,)\n", " Keras class = Dense\n", " use_bias = False\n", " activation = linear\n", " kernel_regularizer = <keras.regularizers.L1L2 object at 0x7ffb46395080>policy_headLayer: value_head (output)\n", " shape = (1,)\n", " Keras class = Dense\n", " use_bias = False\n", " activation = tanh\n", " kernel_regularizer = <keras.regularizers.L1L2 object at 0x7ffb46395da0>value_headWeights from flatten-1 to policy_head\n", " policy_head/kernel:0 has shape (84, 42)Weights from flatten-1 to policy_head\n", " policy_head/kernel:0 has shape (84, 42)Layer: main_input (input)\n", " shape = (6, 7, 2)\n", " Keras class = Inputmain_input20" ], "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.render()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 
Connecting the Network to the Game" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we need a mapping from game (x,y) moves to positions in a list of actions and probabilities." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def make_mappings(game):\n", "    \"\"\"\n", "    Get a mapping from game's (x,y) to array position.\n", "\n", "    Positions run row by row from the top row (y = game.v)\n", "    down to y = 1, left to right within each row.\n", "    \"\"\"\n", "    move2pos = {}\n", "    pos2move = []\n", "    position = 0\n", "    for y in range(game.v, 0, -1):\n", "        for x in range(1, game.h + 1):\n", "            move2pos[(x,y)] = position\n", "            pos2move.append((x,y))\n", "            position += 1\n", "    return move2pos, pos2move" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use the ConnectFour game, defined above:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "move2pos, pos2move = make_mappings(game)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "36" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "move2pos[(2,1)]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 1)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pos2move[35]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need a way to convert the moves stored on a state's board into an array:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def state2array(game, state):\n", "    \"\"\"Flatten the board, top row first: 1 = player to move, -1 = opponent, 0 = empty.\"\"\"\n", "    array = []\n", "    to_move = game.to_move(state)\n", "    for y in range(game.v, 0, -1):\n", "        for x in range(1, game.h + 1):\n", "            item = state.board.get((x, y), 0)\n", "            if item != 0:\n", "                item = 1 if item == to_move else -1\n", "            array.append(item)\n", "    return array" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(42,)" ] }, 
"execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cx.shape(state2array(game, game.initial))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So, state2array returns a list of 42 numbers, where:\n", "\n", "* 0 represents an empty place\n", "* 1 represents one of my pieces\n", "* -1 represents one of my opponent's pieces\n", "\n", "Note that \"my\" and \"my opponent\" swap back and forth depending on perspective (i.e., whose turn it is, as determined by game.to_move(state))." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def state2inputs(game, state):\n", "    \"\"\"Convert a state into a (game.v, game.h, 2) nested list: [my piece, opponent piece] per cell.\"\"\"\n", "    board = np.array(state2array(game, state)) # 1 is my pieces, -1 other\n", "    # The plain int dtype replaces the np.int alias, which newer NumPy releases removed\n", "    currentplayer_position = np.zeros(len(board), dtype=int)\n", "    currentplayer_position[board==1] = 1\n", "    other_position = np.zeros(len(board), dtype=int)\n", "    other_position[board==-1] = 1\n", "    position = np.array(list(zip(currentplayer_position,other_position)))\n", "    inputs = position.reshape((game.v, game.h, 2))\n", "    return inputs.tolist()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This converts the state's board into the form the neural network expects. For the initial, empty board:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],\n", " [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],\n", " [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],\n", " [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],\n", " [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],\n", " [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state2inputs(game, game.initial)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can check to see if this is correct by propagating the activations to the first 
layer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initial board state has no pieces on the board:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

Feature 0

Feature 1
" ], "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state = game.initial\n", "net.propagate_to_features(\"main_input\", state2inputs(game, state))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we make a move to (1,1). But note that after the move, it is now the other player's move. So the first move is seen on the opponent's board (the right side, feature #1):" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

Feature 0

Feature 1
" ], "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state = game.result(game.initial, (1,1))\n", "net.propagate_to_features(\"main_input\", state2inputs(game, state))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, the second player moves to (2,1). Now we are back to the original perspective, and so the right-hand board is on the left, because that is now the current player's perspective." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

Feature 0

Feature 1
" ], "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state = game.result(state, (3,1))\n", "net.propagate_to_features(\"main_input\", state2inputs(game, state))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we are ready to connect the game to the network. We define a function `get_predictions` that takes a game and state, and propagates it through the network returning a (value, probabilities, allowedActions). The probabilities are the pi list from the AlphaZero paper." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def get_predictions(net, game, state):\n", " \"\"\"\n", " Given a state, give output of network on preferred\n", " actions. state.allowedActions removes impossible\n", " actions.\n", "\n", " Returns (value, probabilties, allowedActions)\n", " \"\"\"\n", " board = np.array(state2array(game, state)) # 1 is my pieces, -1 other\n", " inputs = state2inputs(game, state)\n", " preds = net.propagate(inputs, visualize=True)\n", " value = preds[1][0]\n", " logits = np.array(preds[0])\n", " allowedActions = np.array([move2pos[act] for act in game.actions(state)])\n", " mask = np.ones(len(board), dtype=bool)\n", " mask[allowedActions] = False\n", " logits[mask] = -100\n", " #SOFTMAX\n", " odds = np.exp(logits)\n", " probs = odds / np.sum(odds) ###put this just before the for?\n", " return (value, probs.tolist(), allowedActions.tolist())" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "value, probs, acts = get_predictions(net, game, state)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "\n", "require(['base/js/namespace'], function(Jupyter) {\n", " Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {\n", " comm.on_msg(function(msg) {\n", " var data = 
msg[\"content\"][\"data\"];\n", " var images = document.getElementsByClassName(data[\"class\"]);\n", " for (var i = 0; i < images.length; i++) {\n", " if (data[\"href\"]) {\n", " images[i].setAttributeNS(null, \"href\", data[\"href\"]);\n", " }\n", " if (data[\"src\"]) {\n", " images[i].setAttributeNS(null, \"src\", data[\"src\"]);\n", " }\n", " }\n", " });\n", " });\n", "});\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " AlphaZero NetworkLayer: policy_head (output)\n", " shape = (42,)\n", " Keras class = Dense\n", " use_bias = False\n", " activation = linear\n", " kernel_regularizer = <keras.regularizers.L1L2 object at 0x7ffb46395080>policy_headLayer: value_head (output)\n", " shape = (1,)\n", " Keras class = Dense\n", " use_bias = False\n", " activation = tanh\n", " kernel_regularizer = <keras.regularizers.L1L2 object at 0x7ffb46395da0>value_headWeights from flatten-1 to policy_head\n", " policy_head/kernel:0 has shape (84, 42)Weights from flatten-1 to policy_head\n", " policy_head/kernel:0 has shape (84, 42)Layer: main_input (input)\n", " shape = (6, 7, 2)\n", " Keras class = Inputmain_input20" ], "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.snapshot(state2inputs(game, state))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Testing Game and Network Integration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we turn the predictions into a move, and we can play a game with the network." 
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "class NNPlayer(Player):\n", "\n", " def set_game(self, game):\n", " \"\"\"\n", " Create the network for this game and build the mappings\n", " between the game's (x,y) moves and array positions.\n", " \"\"\"\n", " self.net = make_network(game)\n", " self.game = game\n", " self.move2pos = {}\n", " self.pos2move = []\n", " position = 0\n", " for y in range(self.game.v, 0, -1):\n", " for x in range(1, self.game.h + 1):\n", " self.move2pos[(x,y)] = position\n", " self.pos2move.append((x,y))\n", " position += 1\n", "\n", " def get_predictions(self, state):\n", " \"\"\"\n", " Given a state, give output of network on preferred\n", " actions. Impossible actions are masked out using\n", " self.game.actions(state).\n", "\n", " Returns (value, probabilities, allowedActions)\n", " \"\"\"\n", " board = np.array(self.state2array(state)) # 1 is my pieces, -1 other\n", " inputs = self.state2inputs(state)\n", " preds = self.net.propagate(inputs)\n", " value = preds[1][0]\n", " logits = np.array(preds[0])\n", " # dtype=int keeps the index array valid even when no actions remain\n", " allowedActions = np.array([self.move2pos[act] for act in self.game.actions(state)], dtype=int)\n", " mask = np.ones(len(board), dtype=bool)\n", " mask[allowedActions] = False\n", " logits[mask] = -100\n", " # softmax over the masked logits\n", " odds = np.exp(logits)\n", " probs = odds / np.sum(odds)\n", " return (value, probs.tolist(), allowedActions.tolist())\n", " \n", " def get_action(self, state, turn):\n", " value, probabilities, moves = self.get_predictions(state)\n", " probs = np.array(probabilities)[moves]\n", " pos = cx.choice(moves, probs)\n", " return self.pos2move[pos]\n", "\n", " def state2inputs(self, state):\n", " board = np.array(self.state2array(state)) # 1 is my pieces, -1 other\n", " # np.int was removed in NumPy 1.24; the builtin int is equivalent here\n", " currentplayer_position = np.zeros(len(board), dtype=int)\n", " currentplayer_position[board==1] = 1\n", " other_position = np.zeros(len(board), dtype=int)\n", " other_position[board==-1] = 1\n", " position = np.array(list(zip(currentplayer_position,other_position)))\n", " inputs = 
position.reshape((self.game.v, self.game.h, 2))\n", " return inputs\n", "\n", " def state2array(self, state):\n", " array = []\n", " to_move = self.game.to_move(state)\n", " for y in range(self.game.v, 0, -1):\n", " for x in range(1, self.game.h + 1):\n", " item = state.board.get((x, y), 0)\n", " if item != 0:\n", " item = 1 if item == to_move else -1\n", " array.append(item)\n", " return array" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "p1 = RandomPlayer(\"Random\")\n", "p2 = NNPlayer(\"NNPlayer\")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "p2.set_game(game)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2, 1)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p2.get_action(state, 2)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NNPlayer is thinking...\n", "NNPlayer makes action (2, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". X . . . . . \n", "Random is thinking...\n", "Random makes action (4, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". X . O . . . \n", "NNPlayer is thinking...\n", "NNPlayer makes action (5, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". X . O X . . \n", "Random is thinking...\n", "Random makes action (4, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . O . . . \n", ". X . O X . . \n", "NNPlayer is thinking...\n", "NNPlayer makes action (5, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . 
. . . . . \n", ". . . O X . . \n", ". X . O X . . \n", "Random is thinking...\n", "Random makes action (5, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . O . . \n", ". . . O X . . \n", ". X . O X . . \n", "NNPlayer is thinking...\n", "NNPlayer makes action (3, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . O . . \n", ". . . O X . . \n", ". X X O X . . \n", "Random is thinking...\n", "Random makes action (2, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . O . . \n", ". O . O X . . \n", ". X X O X . . \n", "NNPlayer is thinking...\n", "NNPlayer makes action (3, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . O . . \n", ". O X O X . . \n", ". X X O X . . \n", "Random is thinking...\n", "Random makes action (7, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . O . . \n", ". O X O X . . \n", ". X X O X . O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (3, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . O . . \n", ". O X O X . . \n", ". X X O X . O \n", "Random is thinking...\n", "Random makes action (6, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . O . . \n", ". O X O X . . \n", ". X X O X O O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (1, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . O . . \n", ". O X O X . . \n", "X X X O X O O \n", "Random is thinking...\n", "Random makes action (7, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . O . . \n", ". O X O X . O \n", "X X X O X O O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (2, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . 
. . . . . \n", ". X X . O . . \n", ". O X O X . O \n", "X X X O X O O \n", "Random is thinking...\n", "Random makes action (3, 4):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . O . . . . \n", ". X X . O . . \n", ". O X O X . O \n", "X X X O X O O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (1, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . O . . . . \n", ". X X . O . . \n", "X O X O X . O \n", "X X X O X O O \n", "Random is thinking...\n", "Random makes action (5, 4):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . O . O . . \n", ". X X . O . . \n", "X O X O X . O \n", "X X X O X O O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (3, 5):\n", ". . . . . . . \n", ". . . . . . . \n", ". . X . . . . \n", ". . O . O . . \n", ". X X . O . . \n", "X O X O X . O \n", "X X X O X O O \n", "Random is thinking...\n", "Random makes action (5, 5):\n", ". . . . . . . \n", ". . . . . . . \n", ". . X . O . . \n", ". . O . O . . \n", ". X X . O . . \n", "X O X O X . O \n", "X X X O X O O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (1, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . X . O . . \n", ". . O . O . . \n", "X X X . O . . \n", "X O X O X . O \n", "X X X O X O O \n", "Random is thinking...\n", "Random makes action (3, 6):\n", ". . . . . . . \n", ". . O . . . . \n", ". . X . O . . \n", ". . O . O . . \n", "X X X . O . . \n", "X O X O X . O \n", "X X X O X O O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (5, 6):\n", ". . . . . . . \n", ". . O . X . . \n", ". . X . O . . \n", ". . O . O . . \n", "X X X . O . . \n", "X O X O X . O \n", "X X X O X O O \n", "Random is thinking...\n", "Random makes action (6, 2):\n", ". . . . . . . \n", ". . O . X . . \n", ". . X . O . . \n", ". . O . O . . \n", "X X X . O . . \n", "X O X O X O O \n", "X X X O X O O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (2, 4):\n", ". . . . . . . \n", ". . O . X . . \n", ". . 
X . O . . \n", ". X O . O . . \n", "X X X . O . . \n", "X O X O X O O \n", "X X X O X O O \n", "Random is thinking...\n", "Random makes action (4, 3):\n", ". . . . . . . \n", ". . O . X . . \n", ". . X . O . . \n", ". X O . O . . \n", "X X X O O . . \n", "X O X O X O O \n", "X X X O X O O \n", "NNPlayer is thinking...\n", "NNPlayer makes action (1, 4):\n", ". . . . . . . \n", ". . O . X . . \n", ". . X . O . . \n", "X X O . O . . \n", "X X X O O . . \n", "X O X O X O O \n", "X X X O X O O \n", "***** NNPlayer wins!\n" ] }, { "data": { "text/plain": [ "['NNPlayer']" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "game.play_game(p1, p2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training The Network\n", "\n", "Now we are ready to train the network. The training is a clever use of Monte Carlo Tree Search, combined with playing against itself.\n", "\n", "There is a [Monte Carlo Tree Search player](https://github.com/Calysto/aima3/blob/master/notebooks/monte_carlo_tree_search.ipynb) in aima3 that we will use. We set the policy to come from predictions from the neural network." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "class AlphaZeroMCTSPlayer(MCTSPlayer):\n", " \"\"\"\n", " A Monte Carlo Tree Search with policy function from\n", " neural network. 
Network will be set later to self.nnplayer.\n", " \"\"\"\n", " def policy(self, game, state):\n", " # these moves are positions:\n", " value, probs_all, moves = self.nnplayer.get_predictions(state)\n", " if len(moves) == 0:\n", " result = [], value\n", " else:\n", " probs = np.array(probs_all)[moves]\n", " moves = [self.nnplayer.pos2move[pos] for pos in moves]\n", " # we need to return (move, prob) pairs for the game\n", " result = list(zip(moves, probs)), value\n", " return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main AlphaZeroPlayer needs to be able to play in one of two modes:\n", "\n", "* self_play: it plays against itself (using two separate MCTS players, which this version requires). The network provides the policy evaluation for each state as the search looks ahead.\n", "* regular play: moves come directly from the network" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "class AlphaZeroPlayer(NNPlayer):\n", " ## Load weights if continuing\n", " def __init__(self, name, n_playout=40, *args, **kwargs):\n", " super().__init__(name, *args, **kwargs)\n", " self.mcts_players = [AlphaZeroMCTSPlayer(\"MCTS-1\", n_playout=n_playout),\n", " AlphaZeroMCTSPlayer(\"MCTS-2\", n_playout=n_playout)]\n", " \n", " def set_game(self, game):\n", " super().set_game(game)\n", " self.mcts_players[0].set_game(game)\n", " self.mcts_players[1].set_game(game)\n", " self.mcts_players[0].nnplayer = self\n", " self.mcts_players[1].nnplayer = self\n", " self.data = [[], []]\n", " self.cache = {}\n", " \n", " def get_action(self, state, turn, self_play):\n", " if self_play:\n", " ## Only way to determine which is which?\n", " if turn in self.cache:\n", " player_num = 1\n", " else:\n", " player_num = 0\n", " self.cache[turn] = True\n", " ## now use the policy to get some probs:\n", " move, pi = self.mcts_players[player_num].get_action(state, round(turn), return_prob=True)\n", " ## save the state and 
probs:\n", " self.data[player_num].append((self.state2inputs(state), self.move_probs2all_probs(pi)))\n", " return move\n", " else:\n", " # play the network, we're in the playoffs!\n", " return super().get_action(state, round(turn))\n", "\n", " def move_probs2all_probs(self, move_probs):\n", " # use this player's own game rather than the global one\n", " all_probs = np.zeros(len(self.state2array(self.game.initial)))\n", " for move in move_probs:\n", " all_probs[self.move2pos[move]] = move_probs[move]\n", " return all_probs.tolist()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now set up the game to play in one of the two modes.\n", "\n", "One complication of self-play: the same player object is asked for both sides' moves, so it cannot tell which side it is moving for, and we want to keep the two sides' data separate. To keep track, we cache the turn; if we see the same turn again, then we know the call is for the second side." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "class AlphaZeroGame(ConnectFour):\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(*args, **kwargs)\n", " self.memory = []\n", " \n", " def play_game(self, *players, flip_coin=False, verbose=1, **kwargs):\n", " results = super().play_game(*players, flip_coin=flip_coin, verbose=verbose, **kwargs)\n", " if \"self_play\" in kwargs and kwargs[\"self_play\"]:\n", " ## Do not allow flipping coins when self play:\n", " ## Assumes that player1 == player2 when self-playing\n", " assert flip_coin is False, \"no coin_flip when self-playing\"\n", " ## value is in terms of player 0\n", " value = self.final_utility\n", " for state, probs in players[0].data[0]:\n", " self.memory.append([state, [probs, [value]]])\n", " # also data from opponent, so flip value:\n", " value = -value\n", " for state, probs in players[1].data[1]:\n", " self.memory.append([state, [probs, [value]]])\n", " return results" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "game = AlphaZeroGame()\n", "best_player = AlphaZeroPlayer(\"best_player\")\n", 
"current_player = AlphaZeroPlayer(\"current_player\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some basic tests to make sure things are going in the right place:" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(5, 1)\n", "(4, 1)\n", "(3, 1)\n" ] } ], "source": [ "current_player.set_game(game)\n", "assert current_player.data == [[], []]\n", "print(current_player.get_action(game.initial, 1, self_play=False))\n", "assert current_player.data == [[], []]\n", "print(current_player.get_action(game.initial, 1, self_play=True))\n", "assert current_player.data[0] != []\n", "print(current_player.get_action(game.initial, 1, self_play=True))\n", "assert current_player.data[1] != []" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample just for testing:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tournament to begin with 2 matches...\n", "best_player is thinking...\n", "best_player makes action (3, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . . . . \n", "best_player is thinking...\n", "best_player makes action (7, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X . . . O \n", "best_player is thinking...\n", "best_player makes action (2, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". X X . . . O \n", "best_player is thinking...\n", "best_player makes action (6, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". X X . . O O \n", "best_player is thinking...\n", "best_player makes action (1, 1):\n", ". . . . . . . 
\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X X X . . O O \n", "best_player is thinking...\n", "best_player makes action (4, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X X X O . O O \n", "best_player is thinking...\n", "best_player makes action (5, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X X X O X O O \n", "best_player is thinking...\n", "best_player makes action (4, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . O . . . \n", "X X X O X O O \n", "best_player is thinking...\n", "best_player makes action (3, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . X O . . . \n", "X X X O X O O \n", "best_player is thinking...\n", "best_player makes action (4, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . O . . . \n", ". . X O . . . \n", "X X X O X O O \n", "best_player is thinking...\n", "best_player makes action (7, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . O . . . \n", ". . X O . . X \n", "X X X O X O O \n", "best_player is thinking...\n", "best_player makes action (4, 4):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . O . . . \n", ". . . O . . . \n", ". . X O . . X \n", "X X X O X O O \n", "***** best_player wins!\n", "best_player is thinking...\n", "best_player makes action (7, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . X \n", "best_player is thinking...\n", "best_player makes action (2, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . 
\n", ". . . . . . . \n", ". O . . . . X \n", "best_player is thinking...\n", "best_player makes action (4, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". O . X . . X \n", "best_player is thinking...\n", "best_player makes action (2, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". O . . . . . \n", ". O . X . . X \n", "best_player is thinking...\n", "best_player makes action (1, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". O . . . . . \n", "X O . X . . X \n", "best_player is thinking...\n", "best_player makes action (7, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". O . . . . O \n", "X O . X . . X \n", "best_player is thinking...\n", "best_player makes action (4, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". O . X . . O \n", "X O . X . . X \n", "best_player is thinking...\n", "best_player makes action (3, 1):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". O . X . . O \n", "X O O X . . X \n", "best_player is thinking...\n", "best_player makes action (1, 2):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X O . X . . O \n", "X O O X . . X \n", "best_player is thinking...\n", "best_player makes action (2, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". O . . . . . \n", "X O . X . . O \n", "X O O X . . X \n", "best_player is thinking...\n", "best_player makes action (1, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X O . . . . . \n", "X O . X . . O \n", "X O O X . . 
X \n", "best_player is thinking...\n", "best_player makes action (4, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X O . O . . . \n", "X O . X . . O \n", "X O O X . . X \n", "best_player is thinking...\n", "best_player makes action (7, 3):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", "X O . O . . X \n", "X O . X . . O \n", "X O O X . . X \n", "best_player is thinking...\n", "best_player makes action (2, 4):\n", ". . . . . . . \n", ". . . . . . . \n", ". . . . . . . \n", ". O . . . . . \n", "X O . O . . X \n", "X O . X . . O \n", "X O O X . . X \n", "***** best_player wins!\n" ] }, { "data": { "text/plain": [ "{'DRAW': 0, 'best_player': 2}" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "game.play_tournament(1, best_player, best_player, verbose=1, mode=\"ordered\", self_play=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Did we collect some history?" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "26" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(game.memory)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ok, we are ready to learn!" 
] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "config = dict(\n", " MINIMUM_MEMORY_SIZE_BEFORE_TRAINING = 1000, # min size of memory\n", " TRAINING_EPOCHS_PER_CYCLE = 500, # training on current network\n", " CYCLES = 1, # number of cycles to run\n", " SELF_PLAY_MATCHES = 1, # matches to test yo' self per self-play round\n", " TOURNAMENT_MATCHES = 2, # plays each player as first mover per match, so * 2\n", " BEST_SWAP_PERCENT = 1.0, # you must be this much better than best\n", ")" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "def alphazero_train(config):\n", " ## Uses global game, best_player, and current_player\n", " for cycle in range(config[\"CYCLES\"]):\n", " print(\"Epoch #%s...\" % cycle)\n", " # self-play, collect data:\n", " print(\"Self-play matches begin...\")\n", " while len(game.memory) < config[\"MINIMUM_MEMORY_SIZE_BEFORE_TRAINING\"]:\n", " results = game.play_tournament(config[\"SELF_PLAY_MATCHES\"], \n", " best_player, best_player, \n", " mode=\"ordered\", self_play=True)\n", " print(\"Memory size is %s\" % len(game.memory))\n", " print(\"Enough to train!\")\n", " current_player.net.dataset.clear()\n", " current_player.net.dataset.load(game.memory)\n", " print(\"Training on \", len(current_player.net.dataset.inputs), \"patterns...\")\n", " current_player.net.train(config[\"TRAINING_EPOCHS_PER_CYCLE\"], \n", " batch_size=len(game.memory),\n", " plot=True)\n", " ## save dataset every once in a while\n", " ## now see which net is better:\n", " print(\"Playing best vs current to see who wins the title...\")\n", " results = game.play_tournament(config[\"TOURNAMENT_MATCHES\"], \n", " best_player, current_player, \n", " mode=\"one-each\", self_play=False)\n", " if results[\"current_player\"] > results[\"best_player\"] * config[\"BEST_SWAP_PERCENT\"]:\n", " print(\"current won! 
swapping weights\")\n", " # give the better weights to the best_player\n", " best_player.net.set_weights(\n", " current_player.net.get_weights())\n", " game.memory = []\n", " else:\n", " print(\"best won!\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch #0...\n", "Self-play matches begin...\n", "Memory size is 171\n", "Memory size is 208\n", "Memory size is 247\n", "Memory size is 284\n", "Memory size is 310\n", "Memory size is 349\n", "Memory size is 387\n", "Memory size is 425\n", "Memory size is 455\n", "Memory size is 481\n", "Memory size is 541\n", "Memory size is 571\n", "Memory size is 608\n", "Memory size is 657\n", "Memory size is 703\n", "Memory size is 744\n", "Memory size is 788\n", "Memory size is 833\n", "Memory size is 878\n", "Memory size is 920\n", "Memory size is 979\n", "Memory size is 1028\n", "Enough to train!\n", "Training on 1028 patterns...\n", "Training...\n", " | Training | policy | value \n", "Epochs | Error | head acc | head acc \n", "------ | --------- | --------- | --------- \n", "# 801 | 0.39601 | 0.00000 | 0.92708 \n", "# 802 | 1.03875 | 0.00000 | 0.27140 \n", "# 803 | 0.87153 | 0.00000 | 0.15953 \n", "# 804 | 0.79308 | 0.00000 | 0.18191 \n", "# 805 | 0.73491 | 0.00000 | 0.29475 \n", "# 806 | 0.73332 | 0.00000 | 0.35019 \n", "# 807 | 0.70390 | 0.00000 | 0.36868 \n", "# 808 | 0.67224 | 0.00000 | 0.40467 \n", "# 809 | 0.64956 | 0.00000 | 0.46790 \n", "# 810 | 0.65823 | 0.00000 | 0.49222 \n", "# 811 | 0.64335 | 0.00000 | 0.49611 \n", "# 812 | 0.68785 | 0.00000 | 0.47374 \n", "# 813 | 0.64451 | 0.00000 | 0.50486 \n", "# 814 | 0.62052 | 0.00000 | 0.51946 \n", "# 815 | 0.65647 | 0.00000 | 0.47179 \n", "# 816 | 0.71668 | 0.00000 | 0.45039 \n", "# 817 | 0.69026 | 0.00000 | 0.45525 \n", "# 818 | 0.62992 | 0.00000 | 0.53794 \n", "# 819 | 0.59203 | 0.00000 | 0.61479 \n", "# 820 | 0.62125 | 0.00000 | 0.60019 \n", "# 821 | 0.59042 | 0.00000 | 
0.62257 \n", "# 822 | 0.56548 | 0.00000 | 0.64008 \n", "# 823 | 0.57956 | 0.00000 | 0.68288 \n", "# 824 | 0.60517 | 0.00000 | 0.60798 \n", "# 825 | 0.56350 | 0.00000 | 0.65953 \n", "# 826 | 0.54747 | 0.00000 | 0.71304 \n", "# 827 | 0.54154 | 0.00000 | 0.73054 \n", "# 828 | 0.57241 | 0.00000 | 0.69553 \n", "# 829 | 0.56825 | 0.00000 | 0.70914 \n", "# 830 | 0.55524 | 0.00000 | 0.69747 \n", "# 831 | 0.54866 | 0.00000 | 0.71109 \n", "# 832 | 0.54924 | 0.00000 | 0.71693 \n", "# 833 | 0.54617 | 0.00000 | 0.73249 \n", "# 834 | 0.53730 | 0.00000 | 0.73638 \n", "# 835 | 0.53897 | 0.00000 | 0.74125 \n", "# 836 | 0.54266 | 0.00000 | 0.72860 \n", "# 837 | 0.52672 | 0.00000 | 0.75486 \n", "# 838 | 0.53487 | 0.00000 | 0.77626 \n", "# 839 | 0.53721 | 0.00000 | 0.73054 \n", "# 840 | 0.56242 | 0.00000 | 0.70623 \n", "# 841 | 0.54416 | 0.00000 | 0.71693 \n", "# 842 | 0.53821 | 0.00000 | 0.73054 \n", "# 843 | 0.52338 | 0.00000 | 0.76654 \n", "# 844 | 0.52326 | 0.00000 | 0.77335 \n", "# 845 | 0.53004 | 0.00000 | 0.75973 \n", "# 846 | 0.52644 | 0.00000 | 0.77918 \n", "# 847 | 0.52099 | 0.00000 | 0.78113 \n", "# 848 | 0.52136 | 0.00000 | 0.78599 \n", "# 849 | 0.51461 | 0.00000 | 0.79961 \n", "# 850 | 0.52613 | 0.00000 | 0.75389 \n", "# 851 | 0.52171 | 0.00000 | 0.78210 \n", "# 852 | 0.52806 | 0.00000 | 0.76556 \n", "# 853 | 0.52643 | 0.00000 | 0.78502 \n", "# 854 | 0.52034 | 0.00000 | 0.78599 \n", "# 855 | 0.51269 | 0.00000 | 0.79183 \n", "# 856 | 0.51479 | 0.00000 | 0.78210 \n", "# 857 | 0.50916 | 0.00000 | 0.79572 \n", "# 858 | 0.50813 | 0.00000 | 0.79475 \n", "# 859 | 0.50480 | 0.00000 | 0.81615 \n", "# 860 | 0.50278 | 0.00000 | 0.82879 \n", "# 861 | 0.49981 | 0.00000 | 0.82296 \n", "# 862 | 0.51360 | 0.00000 | 0.80447 \n", "# 863 | 0.50122 | 0.00000 | 0.79377 \n", "# 864 | 0.49991 | 0.00000 | 0.81712 \n", "# 865 | 0.49977 | 0.00000 | 0.82101 \n", "# 866 | 0.49591 | 0.00000 | 0.81323 \n", "# 867 | 0.49709 | 0.00000 | 0.83074 \n", "# 868 | 0.49621 | 0.00000 | 0.82685 \n", "# 869 | 
0.49390 | 0.00000 | 0.81809 \n", "# 870 | 0.49332 | 0.00000 | 0.82685 \n", "# 871 | 0.49415 | 0.00000 | 0.82588 \n", "# 872 | 0.49587 | 0.00000 | 0.82198 \n", "# 873 | 0.50590 | 0.00000 | 0.76167 \n", "# 874 | 0.50653 | 0.00000 | 0.79280 \n", "# 875 | 0.50713 | 0.00000 | 0.79669 \n", "# 876 | 0.50457 | 0.00000 | 0.78210 \n", "# 877 | 0.50043 | 0.00000 | 0.80447 \n", "# 878 | 0.51643 | 0.00000 | 0.74805 \n", "# 879 | 0.51950 | 0.00000 | 0.74903 \n", "# 880 | 0.51771 | 0.00000 | 0.77140 \n", "# 881 | 0.49895 | 0.00000 | 0.77140 \n", "# 882 | 0.49755 | 0.00000 | 0.81420 \n", "# 883 | 0.49531 | 0.00000 | 0.79669 \n", "# 884 | 0.51883 | 0.00000 | 0.77335 \n", "# 885 | 0.51188 | 0.00000 | 0.76265 \n", "# 886 | 0.50649 | 0.00000 | 0.78307 \n", "# 887 | 0.50355 | 0.00000 | 0.78405 \n", "# 888 | 0.49438 | 0.00000 | 0.79864 \n", "# 889 | 0.49338 | 0.00000 | 0.82101 \n", "# 890 | 0.49197 | 0.00000 | 0.82393 \n", "# 891 | 0.48883 | 0.00000 | 0.82490 \n", "# 892 | 0.49004 | 0.00000 | 0.83658 \n", "# 893 | 0.48731 | 0.00000 | 0.82198 \n", "# 894 | 0.48745 | 0.00000 | 0.83560 \n", "# 895 | 0.49023 | 0.00000 | 0.81907 \n", "# 896 | 0.48879 | 0.00000 | 0.82393 \n", "# 897 | 0.48995 | 0.00000 | 0.82198 \n", "# 898 | 0.48738 | 0.00000 | 0.81809 \n", "# 899 | 0.48518 | 0.00000 | 0.81615 \n", "# 900 | 0.48521 | 0.00000 | 0.83852 \n", "# 901 | 0.48788 | 0.00000 | 0.82101 \n", "# 902 | 0.48487 | 0.00000 | 0.83268 \n", "# 903 | 0.48504 | 0.00000 | 0.82296 \n", "# 904 | 0.48348 | 0.00000 | 0.82782 \n", "# 905 | 0.48291 | 0.00000 | 0.83074 \n", "# 906 | 0.48198 | 0.00000 | 0.83074 \n", "# 907 | 0.48150 | 0.00000 | 0.83560 \n", "# 908 | 0.47950 | 0.00000 | 0.83463 \n", "# 909 | 0.47981 | 0.00000 | 0.83852 \n", "# 910 | 0.48167 | 0.00000 | 0.84047 \n", "# 911 | 0.48813 | 0.00000 | 0.75486 \n", "# 912 | 0.48913 | 0.00000 | 0.82101 \n", "# 913 | 0.48447 | 0.00000 | 0.80058 \n", "# 914 | 0.48232 | 0.00000 | 0.83658 \n", "# 915 | 0.47927 | 0.00000 | 0.83658 \n", "# 916 | 0.47946 | 0.00000 | 
0.83463 \n", "# 917 | 0.48028 | 0.00000 | 0.83463 \n", "# 918 | 0.47864 | 0.00000 | 0.82490 \n", "# 919 | 0.47678 | 0.00000 | 0.83949 \n", "# 920 | 0.47827 | 0.00000 | 0.83755 \n", "# 921 | 0.47766 | 0.00000 | 0.84047 \n", "# 922 | 0.47743 | 0.00000 | 0.82296 \n", "# 923 | 0.47477 | 0.00000 | 0.83755 \n", "# 924 | 0.47580 | 0.00000 | 0.83755 \n", "# 925 | 0.47772 | 0.00000 | 0.84241 \n", "# 926 | 0.47633 | 0.00000 | 0.82879 \n", "# 927 | 0.47586 | 0.00000 | 0.83658 \n", "# 928 | 0.47410 | 0.00000 | 0.84922 \n", "# 929 | 0.47460 | 0.00000 | 0.83755 \n", "# 930 | 0.47401 | 0.00000 | 0.83463 \n", "# 931 | 0.47358 | 0.00000 | 0.83463 \n", "# 932 | 0.48433 | 0.00000 | 0.77529 \n", "# 933 | 0.48837 | 0.00000 | 0.80156 \n", "# 934 | 0.49228 | 0.00000 | 0.75000 \n", "# 935 | 0.53406 | 0.00000 | 0.72957 \n", "# 936 | 0.59349 | 0.00000 | 0.65370 \n", "# 937 | 0.54461 | 0.00000 | 0.65661 \n", "# 938 | 0.51751 | 0.00000 | 0.73541 \n", "# 939 | 0.51946 | 0.00000 | 0.76654 \n", "# 940 | 0.51198 | 0.00000 | 0.74611 \n", "# 941 | 0.49071 | 0.00000 | 0.79572 \n", "# 942 | 0.48741 | 0.00000 | 0.80837 \n", "# 943 | 0.48940 | 0.00000 | 0.83171 \n", "# 944 | 0.49434 | 0.00000 | 0.78307 \n", "# 945 | 0.50810 | 0.00000 | 0.79086 \n", "# 946 | 0.50190 | 0.00000 | 0.77529 \n", "# 947 | 0.50138 | 0.00000 | 0.79280 \n", "# 948 | 0.49911 | 0.00000 | 0.79572 \n", "# 949 | 0.50314 | 0.00000 | 0.80739 \n", "# 950 | 0.49738 | 0.00000 | 0.80058 \n", "# 951 | 0.53510 | 0.00000 | 0.74222 \n", "# 952 | 0.53245 | 0.00000 | 0.75973 \n", "# 953 | 0.50113 | 0.00000 | 0.79280 \n", "# 954 | 0.49678 | 0.00000 | 0.80642 \n", "# 955 | 0.49968 | 0.00000 | 0.82004 \n", "# 956 | 0.49353 | 0.00000 | 0.79864 \n", "# 957 | 0.49409 | 0.00000 | 0.82393 \n", "# 958 | 0.48850 | 0.00000 | 0.81712 \n", "# 959 | 0.48732 | 0.00000 | 0.82198 \n", "# 960 | 0.48583 | 0.00000 | 0.83463 \n", "# 961 | 0.48717 | 0.00000 | 0.83560 \n", "# 962 | 0.49166 | 0.00000 | 0.78988 \n", "# 963 | 0.48718 | 0.00000 | 0.82101 \n", "# 964 | 
0.48449 | 0.00000 | 0.84436 \n", "# 965 | 0.48572 | 0.00000 | 0.84241 \n", "# 966 | 0.48336 | 0.00000 | 0.83366 \n", "# 967 | 0.48527 | 0.00000 | 0.82198 \n", "# 968 | 0.48378 | 0.00000 | 0.83658 \n", "# 969 | 0.48560 | 0.00000 | 0.83463 \n", "# 970 | 0.48140 | 0.00000 | 0.83560 \n", "# 971 | 0.47973 | 0.00000 | 0.84047 \n", "# 972 | 0.48026 | 0.00000 | 0.84339 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "# 973 | 0.52743 | 0.00000 | 0.71109 \n", "# 974 | 0.51828 | 0.00000 | 0.73152 \n", "# 975 | 0.51169 | 0.00000 | 0.75681 \n", "# 976 | 0.50040 | 0.00000 | 0.77626 \n", "# 977 | 0.49987 | 0.00000 | 0.80545 \n", "# 978 | 0.49475 | 0.00000 | 0.79961 \n", "# 979 | 0.48799 | 0.00000 | 0.82685 \n", "# 980 | 0.48720 | 0.00000 | 0.83171 \n", "# 981 | 0.48790 | 0.00000 | 0.84241 \n", "# 982 | 0.48720 | 0.00000 | 0.82296 \n", "# 983 | 0.48248 | 0.00000 | 0.83366 \n", "# 984 | 0.48260 | 0.00000 | 0.85019 \n", "# 985 | 0.48272 | 0.00000 | 0.83560 \n", "# 986 | 0.48424 | 0.00000 | 0.82101 \n", "# 987 | 0.48143 | 0.00000 | 0.84047 \n", "# 988 | 0.48048 | 0.00000 | 0.82977 \n", "# 989 | 0.47887 | 0.00000 | 0.84922 \n", "# 990 | 0.48022 | 0.00000 | 0.84241 \n", "# 991 | 0.47867 | 0.00000 | 0.84533 \n", "# 992 | 0.47878 | 0.00000 | 0.84241 \n", "# 993 | 0.47876 | 0.00000 | 0.83560 \n", "# 994 | 0.47767 | 0.00000 | 0.84144 \n", "# 995 | 0.47786 | 0.00000 | 0.85214 \n", "# 996 | 0.47649 | 0.00000 | 0.84533 \n", "# 997 | 0.47572 | 0.00000 | 0.84728 \n", "# 998 | 0.47814 | 0.00000 | 0.83560 \n", "# 999 | 0.47693 | 0.00000 | 0.82588 \n", "# 1000 | 0.47429 | 0.00000 | 0.82977 \n", "# 1001 | 0.47584 | 0.00000 | 0.85798 \n", "# 1002 | 0.47488 | 0.00000 | 0.84825 \n", "# 1003 | 0.47582 | 0.00000 | 0.83560 \n", "# 1004 | 0.47530 | 0.00000 | 0.83366 \n", "# 1005 | 0.47396 | 0.00000 | 0.83852 \n", "# 1006 | 0.47526 | 0.00000 | 0.82879 \n", "# 1007 | 0.47325 | 0.00000 | 0.83755 \n", "# 1008 | 0.47301 | 0.00000 | 0.83949 \n", "# 1009 | 0.47212 | 0.00000 | 0.85117 \n", "# 1010 
| 0.47124 | 0.00000 | 0.84728 \n", "# 1011 | 0.47168 | 0.00000 | 0.85700 \n", "# 1012 | 0.49259 | 0.00000 | 0.79864 \n", "# 1013 | 0.49331 | 0.00000 | 0.78891 \n", "# 1014 | 0.48602 | 0.00000 | 0.79280 \n", "# 1015 | 0.51709 | 0.00000 | 0.73444 \n", "# 1016 | 0.50436 | 0.00000 | 0.76070 \n", "# 1017 | 0.49104 | 0.00000 | 0.78016 \n", "# 1018 | 0.48956 | 0.00000 | 0.80253 \n", "# 1019 | 0.54473 | 0.00000 | 0.76459 \n", "# 1020 | 0.51186 | 0.00000 | 0.74514 \n" ] } ], "source": [ "alphazero_train(config)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "133" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(game.memory)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's train best_player some more:" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Interrupted!
Cleaning up...\n", "========================================================================\n", " | Training | policy | value \n", "Epochs | Error | head acc | head acc \n", "------ | --------- | --------- | --------- \n", "# 801 | 0.39601 | 0.00000 | 0.92708 \n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mcurrent_player\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mcurrent_player\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgame\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mcurrent_player\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreport_rate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/.local/lib/python3.6/site-packages/conx/network.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, epochs, accuracy, error, batch_size, report_rate, verbose, kverbose, shuffle, tolerance, class_weight, sample_weight, use_validation_to_stop, plot, record, callbacks, save)\u001b[0m\n\u001b[1;32m 1262\u001b[0m 
\u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Saved!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1263\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minterrupted\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1264\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mverbose\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1266\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepoch_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "best_player.net.train(1000, report_rate=5, plot=True)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "best_player.net[\"policy_head\"].vshape = (6,7)\n", "best_player.net.config[\"show_targets\"] = True" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cacd226ab591480f845322f74401034b", "version_major": 2, "version_minor": 0 }, "text/html": [ "

\n" ], "text/plain": [ "Dashboard(children=(Accordion(children=(HBox(children=(VBox(children=(Select(description='Dataset:', index=1, options=('Test', 'Train'), rows=1, value='Train'), FloatSlider(value=1.0, continuous_update=False, description='Zoom', max=3.0, min=0.5), IntText(value=150, description='Horizontal space between banks:', style=DescriptionStyle(description_width='initial')), IntText(value=30, description='Vertical space between layers:', style=DescriptionStyle(description_width='initial')), HBox(children=(Checkbox(value=True, description='Show Targets', style=DescriptionStyle(description_width='initial')), Checkbox(value=False, description='Errors', style=DescriptionStyle(description_width='initial')))), Select(description='Features:', options=('', 'main_input', 'conv2d-1', 'batch-norm-1', 'leaky-relu-1', 'conv2d-2', 'batch-norm-2', 'leaky-relu-2', 'conv2d-3', 'batch-norm-3', 'add-1', 'leaky-relu-3', 'conv2d-4', 'batch-norm-4', 'leaky-relu-4', 'conv2d-5', 'batch-norm-5', 'add-2', 'leaky-relu-5', 'conv2d-6', 'batch-norm-6', 'leaky-relu-6', 'conv2d-7', 'batch-norm-7', 'add-3', 'leaky-relu-7', 'conv2d-8', 'batch-norm-8', 'leaky-relu-8', 'conv2d-9', 'batch-norm-9', 'add-4', 'leaky-relu-9', 'conv2d-10', 'batch-norm-10', 'leaky-relu-10', 'conv2d-11', 'batch-norm-11', 'add-5', 'leaky-relu-11', 'conv2d-12', 'batch-norm-12', 'leaky-relu-12', 'conv2d-13', 'batch-norm-13', 'leaky-relu-13'), rows=1, value=''), IntText(value=3, description='Feature columns:', style=DescriptionStyle(description_width='initial')), FloatText(value=2.0, description='Feature scale:', style=DescriptionStyle(description_width='initial'))), layout=Layout(width='100%')), VBox(children=(Select(description='Layer:', index=50, options=('main_input', 'conv2d-1', 'batch-norm-1', 'leaky-relu-1', 'conv2d-2', 'batch-norm-2', 'leaky-relu-2', 'conv2d-3', 'batch-norm-3', 'add-1', 'leaky-relu-3', 'conv2d-4', 'batch-norm-4', 'leaky-relu-4', 'conv2d-5', 'batch-norm-5', 'add-2', 'leaky-relu-5', 
'conv2d-6', 'batch-norm-6', 'leaky-relu-6', 'conv2d-7', 'batch-norm-7', 'add-3', 'leaky-relu-7', 'conv2d-8', 'batch-norm-8', 'leaky-relu-8', 'conv2d-9', 'batch-norm-9', 'add-4', 'leaky-relu-9', 'conv2d-10', 'batch-norm-10', 'leaky-relu-10', 'conv2d-11', 'batch-norm-11', 'add-5', 'leaky-relu-11', 'conv2d-12', 'batch-norm-12', 'leaky-relu-12', 'flatten-1', 'policy_head', 'conv2d-13', 'batch-norm-13', 'leaky-relu-13', 'flatten-2', 'dense-1', 'leaky-relu-14', 'value_head'), rows=1, value='value_head'), Checkbox(value=True, description='Visible'), Select(description='Colormap:', options=('', 'Accent', 'Accent_r', 'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', 'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', 'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', 'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', 'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', 'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', 'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', 'Vega10', 'Vega10_r', 'Vega20', 'Vega20_r', 'Vega20b', 'Vega20b_r', 'Vega20c', 'Vega20c_r', 'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', 'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', 'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cool', 'cool_r', 'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'cubehelix', 'cubehelix_r', 'flag', 'flag_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', 'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', 'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', 'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'inferno', 'inferno_r', 'jet', 
'jet_r', 'magma', 'magma_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', 'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', 'rainbow_r', 'seismic', 'seismic_r', 'spectral', 'spectral_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', 'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', 'terrain_r', 'viridis', 'viridis_r', 'winter', 'winter_r'), rows=1, value=''), HTML(value=''), FloatText(value=-1.0, description='Leftmost color maps to:', style=DescriptionStyle(description_width='initial')), FloatText(value=1.0, description='Rightmost color maps to:', style=DescriptionStyle(description_width='initial')), IntText(value=0, description='Feature to show:', style=DescriptionStyle(description_width='initial'))), layout=Layout(width='100%')))),), selected_index=None, _titles={'0': 'AlphaZero Network'}), VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='Dataset index', layout=Layout(width='100%'), max=132), Label(value='of 133', layout=Layout(width='100px'))), layout=Layout(height='40px')), HBox(children=(Button(icon='fast-backward', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='backward', layout=Layout(width='100%'), style=ButtonStyle()), IntText(value=0, layout=Layout(width='100%')), Button(icon='forward', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='fast-forward', layout=Layout(width='100%'), style=ButtonStyle()), Button(description='Play', icon='play', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='refresh', layout=Layout(width='25%'), style=ButtonStyle())), layout=Layout(height='50px', width='100%'))), layout=Layout(width='100%')), HTML(value='

', layout=Layout(justify_content='center', overflow_x='auto', overflow_y='auto', width='95%')), Output()))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", "require(['base/js/namespace'], function(Jupyter) {\n", " Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {\n", " comm.on_msg(function(msg) {\n", " var data = msg[\"content\"][\"data\"];\n", " var images = document.getElementsByClassName(data[\"class\"]);\n", " for (var i = 0; i < images.length; i++) {\n", " if (data[\"href\"]) {\n", " images[i].setAttributeNS(null, \"href\", data[\"href\"]);\n", " }\n", " if (data[\"src\"]) {\n", " images[i].setAttributeNS(null, \"src\", data[\"src\"]);\n", " }\n", " }\n", " });\n", " });\n", "});\n" ], "text/plain": [ "" ] }, "metadata": {},
"output_type": "display_data" } ], "source": [ "best_player.net.dashboard()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now you can play against the best player to see how it does:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p1 = QueryPlayer(\"Your Name\")\n", "p2 = NNPlayer(\"Trained AlphaZero\")\n", "p2.net = best_player.net\n", "connect4 = ConnectFour()\n", "connect4.play_game(p1, p2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "* Learns by playing against itself, so the opponent is always at just the right level (evolution-style).\n", "* Uses search during training." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": { "0741e46b9b0f4bf5844bd3d85c4232d9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "height": "40px" } }, "0776543edcaf412e9f635d7a22655c6b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "IntTextModel", "state": { "description": "Horizontal space between banks:", "layout": "IPY_MODEL_fbb8a16d0747495e99f5bfc2808109ff", "step": 1, "style": "IPY_MODEL_a1badee473d84e34b941b2c9f758497e", "value": 150 } }, "0923567ad61746768ba409b22a19e778": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "FloatTextModel", "state": { "description": "Rightmost color maps to:", "layout": "IPY_MODEL_fd93bee2d6344272b7af327ceedbe041", "step": null, "style": "IPY_MODEL_92fcc589ae3f43e3835b39bebe0eb5af", "value": 1 } }, "0ad2f90079b940afbae181f44e2fa95c": {
"model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "0cca1cfa7b60414780471f0f4a6d5928": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "0dcae20a89eb4f3f9f72e1749488e972": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "height": "50px", "width": "100%" } }, "0dffa73cbed44de797de9757880b194d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "SliderStyleModel", "state": { "description_width": "" } }, "1061517f16174b528ea3e598d9a05390": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "VBoxModel", "state": { "children": [ "IPY_MODEL_5b1f003557944ddbaf762229351297a2", "IPY_MODEL_6d3e564d860847f9b3ded1f0a666141d", "IPY_MODEL_0776543edcaf412e9f635d7a22655c6b", "IPY_MODEL_929c52b13ba3416e980762d30075433b", "IPY_MODEL_af765efc36444d1a83f795dce70759ca", "IPY_MODEL_6c3867c8e825497eba374bbee12868bb", "IPY_MODEL_b36f0ddf249b43749b2bc484734a1797", "IPY_MODEL_47eedfb230484e189827e89986730a27" ], "layout": "IPY_MODEL_6ab6f6dff1c74d81a479cdd6db7fd23d" } }, "12a4410c8917409fbe22407c92bd5851": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "12bd2b38882e43008400b0a17656f423": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "179ed4ed33034da29acb906a1656f109": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "187ca2da185c48cd8396c48f75ac4eaf": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "1a3cf9f38b1d4d11a9f75692b93fc318": { 
"model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "1dd255503b57462b81cfb7c78003a688": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "IntSliderModel", "state": { "continuous_update": false, "description": "Dataset index", "layout": "IPY_MODEL_179ed4ed33034da29acb906a1656f109", "max": 3, "style": "IPY_MODEL_d9da6ff4050f4edb99e198d819516bda", "value": 1 } }, "1dfa04d6b8a84c56887183f68e4335fd": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "1fa1279ef6b94fb5a45eb551c19091e9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "1fc06217b9474474ab88eb96e110cb72": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "217042ec394e443cb966ee7f50e7d70e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "22231be8fd7745b9bc764cef17bc1b6a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "HBoxModel", "state": { "children": [ "IPY_MODEL_1061517f16174b528ea3e598d9a05390", "IPY_MODEL_7a68e61b8bb54e0eac2f7b5aa775dfb5" ], "layout": "IPY_MODEL_8a191438cd5a4f90bf89cfc1de5d4dd3" } }, "271ce6ca358f4348a9a676a9ec313fda": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "28192858eb7d4f7ca186e8f46c242e3d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "2b3adfdf59714146bfc616daea045c05": { "model_module": 
"@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonModel", "state": { "icon": "backward", "layout": "IPY_MODEL_0ad2f90079b940afbae181f44e2fa95c", "style": "IPY_MODEL_507e754ab8334b738975df70ab1cb124" } }, "2df28eb027f643ac89e58be0acbc5519": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "3432f3b78cfb4f89a48ed6a2baf5db45": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "35796f507edf4facbac6578ec38f9b25": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "justify_content": "center", "overflow_x": "auto", "overflow_y": "auto", "width": "95%" } }, "3bcb4ce374a54d49addef5e241765ad6": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "3be538730bc442a0b3408132b13e0a21": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "FloatTextModel", "state": { "description": "Leftmost color maps to:", "layout": "IPY_MODEL_12bd2b38882e43008400b0a17656f423", "step": null, "style": "IPY_MODEL_c9fdaf2f01a5424cb0bcf548fcd5fba0", "value": -1 } }, "3e33934bd3054262adbdad542c4fea63": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_35796f507edf4facbac6578ec38f9b25", "style": "IPY_MODEL_28192858eb7d4f7ca186e8f46c242e3d", "value": "\n\n \n \n \n\n

Residual CNN network diagram. Outputs: policy_head (shape (42,), Dense, linear activation, no bias) and value_head (shape (1,), Dense, tanh activation, no bias); weights from flatten-1 to policy_head have shape (84, 42). Input: main_input (shape (6, 7, 2)); main_input features shown: Feature 0, Feature 1.
" } }, "41c3c7e072f04c6abf41f9133fc592fc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "42fb14e627004bf69ce83a88647494e5": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "46b3ca31424346b8881e81ea3b528b6c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonModel", "state": { "icon": "fast-forward", "layout": "IPY_MODEL_9823a1c0a34f40f5903189fc734bd864", "style": "IPY_MODEL_71cfc547f6e548719d1136aa62a82b4d" } }, "47eedfb230484e189827e89986730a27": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "FloatTextModel", "state": { "description": "Feature scale:", "layout": "IPY_MODEL_eff0474a98d94315b818f72b70906a8e", "step": null, "style": "IPY_MODEL_12a4410c8917409fbe22407c92bd5851", "value": 2 } }, "4b5e2f4caa21408ab667a52dc4e97681": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "HBoxModel", "state": { "children": [ "IPY_MODEL_7e15b8ba58e34c98be93831feff33791", "IPY_MODEL_2b3adfdf59714146bfc616daea045c05", "IPY_MODEL_7d758c7a5da947fbb1a9c3671ac8057a", "IPY_MODEL_c2657c5043af450787a756d5ff70a797", "IPY_MODEL_46b3ca31424346b8881e81ea3b528b6c", "IPY_MODEL_adde82d52fef48b5870b3966a2401a13", "IPY_MODEL_89765308b2544059b1cd31c85f5f1aea" ], "layout": "IPY_MODEL_0dcae20a89eb4f3f9f72e1749488e972" } }, "507e754ab8334b738975df70ab1cb124": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonStyleModel", "state": {} }, "545de5cb599c4f80b51a57eae1ccda28": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "VBoxModel", "state": { "children": [ "IPY_MODEL_8b78b24619ac45d49a1ccbd3c6801a5f", 
"IPY_MODEL_66e7a0fb8d3a4b5b96a090468ee0630f", "IPY_MODEL_3e33934bd3054262adbdad542c4fea63", "IPY_MODEL_dba1ab8cf7604081977cb1904bbd6e03" ], "layout": "IPY_MODEL_81b70d642cee4b15806a0ed6bde80999" } }, "555b06db5a3a42d6b6facd771ac3fff7": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonStyleModel", "state": {} }, "56f6ff9e50674984b90124f8428aefd8": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "SelectModel", "state": { "_options_labels": [ "", "Accent", "Accent_r", "Blues", "Blues_r", "BrBG", "BrBG_r", "BuGn", "BuGn_r", "BuPu", "BuPu_r", "CMRmap", "CMRmap_r", "Dark2", "Dark2_r", "GnBu", "GnBu_r", "Greens", "Greens_r", "Greys", "Greys_r", "OrRd", "OrRd_r", "Oranges", "Oranges_r", "PRGn", "PRGn_r", "Paired", "Paired_r", "Pastel1", "Pastel1_r", "Pastel2", "Pastel2_r", "PiYG", "PiYG_r", "PuBu", "PuBuGn", "PuBuGn_r", "PuBu_r", "PuOr", "PuOr_r", "PuRd", "PuRd_r", "Purples", "Purples_r", "RdBu", "RdBu_r", "RdGy", "RdGy_r", "RdPu", "RdPu_r", "RdYlBu", "RdYlBu_r", "RdYlGn", "RdYlGn_r", "Reds", "Reds_r", "Set1", "Set1_r", "Set2", "Set2_r", "Set3", "Set3_r", "Spectral", "Spectral_r", "Vega10", "Vega10_r", "Vega20", "Vega20_r", "Vega20b", "Vega20b_r", "Vega20c", "Vega20c_r", "Wistia", "Wistia_r", "YlGn", "YlGnBu", "YlGnBu_r", "YlGn_r", "YlOrBr", "YlOrBr_r", "YlOrRd", "YlOrRd_r", "afmhot", "afmhot_r", "autumn", "autumn_r", "binary", "binary_r", "bone", "bone_r", "brg", "brg_r", "bwr", "bwr_r", "cool", "cool_r", "coolwarm", "coolwarm_r", "copper", "copper_r", "cubehelix", "cubehelix_r", "flag", "flag_r", "gist_earth", "gist_earth_r", "gist_gray", "gist_gray_r", "gist_heat", "gist_heat_r", "gist_ncar", "gist_ncar_r", "gist_rainbow", "gist_rainbow_r", "gist_stern", "gist_stern_r", "gist_yarg", "gist_yarg_r", "gnuplot", "gnuplot2", "gnuplot2_r", "gnuplot_r", "gray", "gray_r", "hot", "hot_r", "hsv", "hsv_r", "inferno", "inferno_r", "jet", "jet_r", "magma", "magma_r", "nipy_spectral", 
"nipy_spectral_r", "ocean", "ocean_r", "pink", "pink_r", "plasma", "plasma_r", "prism", "prism_r", "rainbow", "rainbow_r", "seismic", "seismic_r", "spectral", "spectral_r", "spring", "spring_r", "summer", "summer_r", "tab10", "tab10_r", "tab20", "tab20_r", "tab20b", "tab20b_r", "tab20c", "tab20c_r", "terrain", "terrain_r", "viridis", "viridis_r", "winter", "winter_r" ], "description": "Colormap:", "index": 0, "layout": "IPY_MODEL_fbb8a16d0747495e99f5bfc2808109ff", "rows": 1, "style": "IPY_MODEL_c65cae16013b437aa57b924afe681ef4" } }, "5b1f003557944ddbaf762229351297a2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "SelectModel", "state": { "_options_labels": [ "Test", "Train" ], "description": "Dataset:", "index": 1, "layout": "IPY_MODEL_c263abd70a4d48c4bcfb66f823a490a2", "rows": 1, "style": "IPY_MODEL_837fdb67e2164e4585e29fcedec8ac0b" } }, "5e253d6e7f8c4ed0923d32aec9cb1e3e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "25%" } }, "5e6bbae775a34f77836b19a2499c8f34": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "IntTextModel", "state": { "description": "Feature to show:", "layout": "IPY_MODEL_f0f23f5f0b31443ca2fa8bc3fadaa94c", "step": 1, "style": "IPY_MODEL_9835ea8988394379a41b196038121392" } }, "5fd493c4923a48cbae37d1b27bb1bebd": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "63b13f0a739e4d328e0f42d31cc13753": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "655fb18010474bc7adfc93ea9f88f195": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "CheckboxModel", "state": { "description": "Visible", "disabled": false, "layout": "IPY_MODEL_fbb8a16d0747495e99f5bfc2808109ff", "style": 
"IPY_MODEL_72f1453f5ff44eaf83235ec62ee65a97", "value": true } }, "66e7a0fb8d3a4b5b96a090468ee0630f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "VBoxModel", "state": { "children": [ "IPY_MODEL_b1759105ed0e4029ad8b4b853649b20a", "IPY_MODEL_4b5e2f4caa21408ab667a52dc4e97681" ], "layout": "IPY_MODEL_271ce6ca358f4348a9a676a9ec313fda" } }, "6ab6f6dff1c74d81a479cdd6db7fd23d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "6c3867c8e825497eba374bbee12868bb": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "SelectModel", "state": { "_options_labels": [ "", "main_input", "conv2d-1", "batch-norm-1", "leaky-relu-1", "conv2d-2", "batch-norm-2", "leaky-relu-2", "conv2d-3", "batch-norm-3", "add-1", "leaky-relu-3", "conv2d-4", "batch-norm-4", "leaky-relu-4", "conv2d-5", "batch-norm-5", "add-2", "leaky-relu-5", "conv2d-6", "batch-norm-6", "leaky-relu-6", "conv2d-7", "batch-norm-7", "add-3", "leaky-relu-7", "conv2d-8", "batch-norm-8", "leaky-relu-8", "conv2d-9", "batch-norm-9", "add-4", "leaky-relu-9", "conv2d-10", "batch-norm-10", "leaky-relu-10", "conv2d-11", "batch-norm-11", "add-5", "leaky-relu-11", "conv2d-12", "batch-norm-12", "leaky-relu-12", "conv2d-13", "batch-norm-13", "leaky-relu-13" ], "description": "Features:", "index": 1, "layout": "IPY_MODEL_63b13f0a739e4d328e0f42d31cc13753", "rows": 1, "style": "IPY_MODEL_ca36085589ea4aae83167a05606f2ee9" } }, "6d3e564d860847f9b3ded1f0a666141d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "FloatSliderModel", "state": { "continuous_update": false, "description": "Zoom", "layout": "IPY_MODEL_76884212b09440088f9ffee6e5c484a4", "max": 3, "min": 0.5, "step": 0.1, "style": "IPY_MODEL_0dffa73cbed44de797de9757880b194d", "value": 1 } }, "6f31104f50c74f428b95d3f2afc9cee7": { "model_module": 
"@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonStyleModel", "state": {} }, "705fd9c4bc8e41fe9efe90be5711d1af": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "71cfc547f6e548719d1136aa62a82b4d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonStyleModel", "state": {} }, "72f1453f5ff44eaf83235ec62ee65a97": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "733ce58e1f1e44bf814c6651058c7195": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "LabelModel", "state": { "layout": "IPY_MODEL_cdd164f9e1e345edb568290864583612", "style": "IPY_MODEL_217042ec394e443cb966ee7f50e7d70e", "value": "of 4" } }, "76884212b09440088f9ffee6e5c484a4": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "7a68e61b8bb54e0eac2f7b5aa775dfb5": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "VBoxModel", "state": { "children": [ "IPY_MODEL_d920c397643e496cb085df4e13aa21d7", "IPY_MODEL_655fb18010474bc7adfc93ea9f88f195", "IPY_MODEL_56f6ff9e50674984b90124f8428aefd8", "IPY_MODEL_9be5400ae11f4df280a22d357c7d8dde", "IPY_MODEL_3be538730bc442a0b3408132b13e0a21", "IPY_MODEL_0923567ad61746768ba409b22a19e778", "IPY_MODEL_5e6bbae775a34f77836b19a2499c8f34" ], "layout": "IPY_MODEL_2df28eb027f643ac89e58be0acbc5519" } }, "7d758c7a5da947fbb1a9c3671ac8057a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "IntTextModel", "state": { "layout": "IPY_MODEL_1fa1279ef6b94fb5a45eb551c19091e9", "step": 1, "style": "IPY_MODEL_c70bdb047b4f4f8d88e5c7f1ee0ea5ea", "value": 1 } }, "7e15b8ba58e34c98be93831feff33791": { "model_module": 
"@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonModel", "state": { "icon": "fast-backward", "layout": "IPY_MODEL_5fd493c4923a48cbae37d1b27bb1bebd", "style": "IPY_MODEL_9aa099ae3a1f4b42a0d45d1bc4e36958" } }, "81b70d642cee4b15806a0ed6bde80999": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "837fdb67e2164e4585e29fcedec8ac0b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "89765308b2544059b1cd31c85f5f1aea": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonModel", "state": { "icon": "refresh", "layout": "IPY_MODEL_5e253d6e7f8c4ed0923d32aec9cb1e3e", "style": "IPY_MODEL_555b06db5a3a42d6b6facd771ac3fff7" } }, "8a191438cd5a4f90bf89cfc1de5d4dd3": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "8b78b24619ac45d49a1ccbd3c6801a5f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "AccordionModel", "state": { "_titles": { "0": "Residual CNN" }, "children": [ "IPY_MODEL_22231be8fd7745b9bc764cef17bc1b6a" ], "layout": "IPY_MODEL_3bcb4ce374a54d49addef5e241765ad6" } }, "929c52b13ba3416e980762d30075433b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "IntTextModel", "state": { "description": "Vertical space between layers:", "layout": "IPY_MODEL_fbb8a16d0747495e99f5bfc2808109ff", "step": 1, "style": "IPY_MODEL_42fb14e627004bf69ce83a88647494e5", "value": 30 } }, "92fcc589ae3f43e3835b39bebe0eb5af": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "9823a1c0a34f40f5903189fc734bd864": { "model_module": "@jupyter-widgets/base", 
"model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "9835ea8988394379a41b196038121392": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "9aa099ae3a1f4b42a0d45d1bc4e36958": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonStyleModel", "state": {} }, "9be5400ae11f4df280a22d357c7d8dde": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "HTMLModel", "state": { "layout": "IPY_MODEL_187ca2da185c48cd8396c48f75ac4eaf", "style": "IPY_MODEL_1fc06217b9474474ab88eb96e110cb72", "value": "" } }, "a1badee473d84e34b941b2c9f758497e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "a528fe0f71264a5a96bdb4c65e9ee777": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100%" } }, "adde82d52fef48b5870b3966a2401a13": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonModel", "state": { "description": "Play", "icon": "play", "layout": "IPY_MODEL_705fd9c4bc8e41fe9efe90be5711d1af", "style": "IPY_MODEL_6f31104f50c74f428b95d3f2afc9cee7" } }, "af765efc36444d1a83f795dce70759ca": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "HBoxModel", "state": { "children": [ "IPY_MODEL_c8ca60c826ea46f2baa7fc2e2c4e61c4", "IPY_MODEL_ca38c8c68b444c9097213485ef6dda23" ], "layout": "IPY_MODEL_0cca1cfa7b60414780471f0f4a6d5928" } }, "b1759105ed0e4029ad8b4b853649b20a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "HBoxModel", "state": { "children": [ "IPY_MODEL_1dd255503b57462b81cfb7c78003a688", 
"IPY_MODEL_733ce58e1f1e44bf814c6651058c7195" ], "layout": "IPY_MODEL_0741e46b9b0f4bf5844bd3d85c4232d9" } }, "b36f0ddf249b43749b2bc484734a1797": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "IntTextModel", "state": { "description": "Feature columns:", "layout": "IPY_MODEL_c71d3120be1a46b6be38508bfc832385", "step": 1, "style": "IPY_MODEL_1a3cf9f38b1d4d11a9f75692b93fc318", "value": 3 } }, "c263abd70a4d48c4bcfb66f823a490a2": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "c2657c5043af450787a756d5ff70a797": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonModel", "state": { "icon": "forward", "layout": "IPY_MODEL_a528fe0f71264a5a96bdb4c65e9ee777", "style": "IPY_MODEL_df4d3667610c4c2c9673115f2a3802e6" } }, "c65cae16013b437aa57b924afe681ef4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "c70bdb047b4f4f8d88e5c7f1ee0ea5ea": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "c71d3120be1a46b6be38508bfc832385": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "c896b0d8bfac44ee9db8a93e0d532a73": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "c8ca60c826ea46f2baa7fc2e2c4e61c4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "CheckboxModel", "state": { "description": "Show Targets", "disabled": false, "layout": "IPY_MODEL_fbb8a16d0747495e99f5bfc2808109ff", "style": "IPY_MODEL_c896b0d8bfac44ee9db8a93e0d532a73", "value": false } }, 
"c9fdaf2f01a5424cb0bcf548fcd5fba0": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "initial" } }, "ca36085589ea4aae83167a05606f2ee9": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "DescriptionStyleModel", "state": { "description_width": "" } }, "ca38c8c68b444c9097213485ef6dda23": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "CheckboxModel", "state": { "description": "Errors", "disabled": false, "layout": "IPY_MODEL_fbb8a16d0747495e99f5bfc2808109ff", "style": "IPY_MODEL_3432f3b78cfb4f89a48ed6a2baf5db45", "value": false } }, "cdd164f9e1e345edb568290864583612": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": { "width": "100px" } }, "d301fe44fae94261be077dffdad67bd4": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "d920c397643e496cb085df4e13aa21d7": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "SelectModel", "state": { "_options_labels": [ "main_input", "conv2d-1", "batch-norm-1", "leaky-relu-1", "conv2d-2", "batch-norm-2", "leaky-relu-2", "conv2d-3", "batch-norm-3", "add-1", "leaky-relu-3", "conv2d-4", "batch-norm-4", "leaky-relu-4", "conv2d-5", "batch-norm-5", "add-2", "leaky-relu-5", "conv2d-6", "batch-norm-6", "leaky-relu-6", "conv2d-7", "batch-norm-7", "add-3", "leaky-relu-7", "conv2d-8", "batch-norm-8", "leaky-relu-8", "conv2d-9", "batch-norm-9", "add-4", "leaky-relu-9", "conv2d-10", "batch-norm-10", "leaky-relu-10", "conv2d-11", "batch-norm-11", "add-5", "leaky-relu-11", "conv2d-12", "batch-norm-12", "leaky-relu-12", "flatten-1", "policy_head", "conv2d-13", "batch-norm-13", "leaky-relu-13", "flatten-2", "dense-1", "leaky-relu-14", "value_head" ], "description": "Layer:", "index": 
50, "layout": "IPY_MODEL_1dfa04d6b8a84c56887183f68e4335fd", "rows": 1, "style": "IPY_MODEL_41c3c7e072f04c6abf41f9133fc592fc" } }, "d9da6ff4050f4edb99e198d819516bda": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "SliderStyleModel", "state": { "description_width": "" } }, "dba1ab8cf7604081977cb1904bbd6e03": { "model_module": "@jupyter-widgets/output", "model_module_version": "1.0.0", "model_name": "OutputModel", "state": { "layout": "IPY_MODEL_d301fe44fae94261be077dffdad67bd4" } }, "df4d3667610c4c2c9673115f2a3802e6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.1.0", "model_name": "ButtonStyleModel", "state": {} }, "eff0474a98d94315b818f72b70906a8e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "f0f23f5f0b31443ca2fa8bc3fadaa94c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "fbb8a16d0747495e99f5bfc2808109ff": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} }, "fd93bee2d6344272b7af327ceedbe041": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.0.0", "model_name": "LayoutModel", "state": {} } }, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 2 }