{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using gpu device 0: GeForce GTX 965M (CNMeM is disabled, CuDNN 4007)\n", "/usr/local/lib/python3.4/dist-packages/theano/tensor/signal/downsample.py:5: UserWarning: downsample module has been moved to the pool module.\n", " warnings.warn(\"downsample module has been moved to the pool module.\")\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Couldn't import dot_parser, loading of dot files will not be possible.\n" ] }, { "data": { "text/plain": [ "'float32'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import theano\n", "import theano.tensor as T\n", "\n", "import lasagne\n", "floatX = theano.config.floatX\n", "floatX" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from IPython.display import HTML, display" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%load_ext Cython" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%%cython\n", "# cython: infer_types=True, annotation_typing=True\n", "## cython: infer_types.verbose=True \n", "from IPython.display import HTML, display\n", "import numpy as np\n", "\n", "floatX = np.float32\n", "\n", "binary6 = np.array([ list(map(int,bin(2**6+i)[:2:-1])) for i in range(2**6)], dtype=floatX)\n", "height = np.array([-1]*65, dtype=np.int32)\n", "for __i in range(6):\n", " height[2**__i]=__i\n", "\n", "cdef class Connect4:\n", " cdef public:\n", " long turn\n", " long long[2] data\n", " cpdef long get_col_row(self, col: long, row: long):\n", " pos = col * 7 + row\n", " mask = (1) << pos \n", " if self.data[1] & mask:\n", " return 2\n", " return bool(self.data[0] & mask)\n", " \n", " cpdef long is_end(self):\n", " cdef long long mask\n", " bitboard = self.data[1-self.turn%2]\n", " bound = (1)<<48 # 49 = 7*(6+1) \n", " # horizontal: 0x204081 = 1|(1<<7)|(1<<14)|(1<<21)\n", " # vertical: 0xf = 1|(1<<1)|(1<<2)|(1<<3)\n", " # up-right: 0x1010101 = 1|(1<<8)|(1<<16)|(1<<24)\n", " # down-right: 0x208208 = (1<<3)|(1<<9)|(1<<15)|(1<<21)\n", " for mask in [0x204081, 0xf, 0x1010101, 0x208208]:\n", " while mask < bound:\n", " if mask & bitboard == mask:\n", " return True\n", " mask <<= 1\n", " return False\n", " \n", " cpdef set_col_row(self, col:long, row:long, value:long):\n", " # assert value in [0,1,2]\n", " pos = col * 7 + row\n", " mask = (1) << pos\n", " neg_mask = ~mask \n", " if value == 1 or value ==2:\n", " self.data[value-1] |= mask\n", " self.data[2-value] &= neg_mask\n", " else:\n", " self.data[0] &= neg_mask\n", " self.data[1] &= neg_mask\n", " \n", " def __init__(self, data=None, turn=0):\n", " if data is not None:\n", " self.data = data[:]\n", " else:\n", " self.data = [0, 0]\n", " self.turn = turn\n", " \n", " cpdef remove(self, col:long):\n", " shift = col*7\n", " mask = (((self.data[0]|self.data[1]) >> shift) &0x3f) +1\n", " mask = (mask >> 1) << shift\n", " # print(shift, hex(mask), hex(self.data[0]), hex(self.data[1]))\n", " neg_mask = ~mask\n", " self.data[0] &= neg_mask\n", " self.data[1] &= neg_mask\n", " \n", " def _np_branch(self):\n", " c = self.turn%2 # who's turn\n", " base = np.zeros((2,7,6), dtype=floatX)\n", " pos = []\n", " moves = []\n", " red, yellow = self.data\n", " for i in range(7):\n", " mask = ((red|yellow) &0x3f) + 1\n", " p = height[mask]\n", " if p != -1:\n", " moves.append(i)\n", " pos.append(height[mask])\n", " base[c, i] = binary6[red&0x3f]\n", " base[1-c, i] = binary6[yellow&0x3f]\n", " red >>= 7\n", " yellow >>= 7\n", " boards = np.zeros( (len(moves), 2, 7, 6), dtype=floatX)\n", " for i in range(len(moves)):\n", " m = moves[i]\n", " p = pos[i]\n", " boards[i]=base\n", " boards[i, 0, m, p] = 1\n", " return moves, boards\n", " \n", " def _np_board(self):\n", " c = (self.turn-1)%2 # who played\n", " board = np.zeros((2, 7, 6), dtype=floatX)\n", " pos = []\n", " moves = []\n", " red, yellow = self.data\n", " for i in range(7):\n", " mask = ((red|yellow) &0x3f) + 1\n", " p = height[mask]\n", " if p != -1:\n", " moves.append(i)\n", " pos.append(height[mask])\n", " board[c, i] = binary6[red&0x3f]\n", " board[1-c, i] = binary6[yellow&0x3f]\n", " red >>= 7\n", " yellow >>= 7\n", " return board\n", " \n", " \n", " cpdef move(self, col:long, test=False):\n", " # assert 0<= col <7\n", " shift = col*7\n", " mask = (((self.data[0]|self.data[1]) >> shift) &0x3f) +1\n", " # print(\"mask=\", mask)\n", " if mask >= 64:\n", " return None\n", " if not test:\n", " self.data[self.turn%2] |= (mask<\"\n", " header = \"\"\"
\"\"\"\n", " header += \"\\n\".join(imgstr%('empty', pos(5-j), pos(i), 0) for i in range(7) for j in range(6))\n", " return header +\"\\n\".join(imgstr%('red_coin' if c==1 else 'yellow_coin', pos(5-j), pos(i), 2) for (i,j,c) in self.board_data()) +\"
\"\n", " \n", " def display(self):\n", " display(HTML(self._repr_html_()))\n", " \n", " def __repr__(self):\n", " row_str = lambda j: \"\".join(\".ox\"[self.get_col_row(i,j)] for i in range(7))\n", " return \"\\n\".join(row_str(j) for j in range(5,-1,-1))\n", "\n", "from random import randint\n", "def random_play(init_data=None, init_turn=0, display=False):\n", " game = Connect4(init_data, init_turn)\n", " while game.turn < 42 and not game.is_end():\n", " while game.move(randint(0,6)) is None:\n", " continue\n", " if display:\n", " game.display()\n", " if game.is_end():\n", " return game.turn\n", " return 0" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def MC_agent(_game, N=200):\n", " score = [-1.0*N]*7\n", " for i in range(7):\n", " game = Connect4(_game.data, _game.turn)\n", " if game.move(i):\n", " if game.is_end():\n", " return i\n", " s = 0\n", " for j in range(N):\n", " #print(\"move\", i, \"case\", j)\n", " r = random_play(game.data, game.turn)\n", " turn = (r-1)%2\n", " if r == 0:\n", " pass\n", " elif (r-1)%2 == _game.turn%2:\n", " s += 0.95** (r-_game.turn-1)\n", " else:\n", " s -= .95** (r-_game.turn-1)\n", " score[i] = s/N\n", " return max(zip(score, range(7)))[1]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def random_vs_MC(init_data=None, init_turn=0, display=False):\n", " game = Connect4(init_data, init_turn)\n", " while game.turn < 42 and not game.is_end():\n", " if game.turn%2 == 0:\n", " while game.move(randint(0,6)) is None:\n", " continue\n", " else:\n", " i = MC_agent(game)\n", " game.move(i)\n", " if display == 'all':\n", " game.display()\n", " if display:\n", " game.display()\n", " if game.is_end():\n", " return game.turn\n", " return 0" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "input_var = T.tensor4('inputs')\n", "target_var = T.vector('targets')\n", "l_in = lasagne.layers.InputLayer(shape=(None, 2, 7, 6), input_var=input_var)\n", "#l_in_drop = lasagne.layers.DropoutLayer(l_in, p=0.2)\n", "l_hidden = lasagne.layers.DenseLayer(l_in, num_units=400, nonlinearity=lasagne.nonlinearities.rectify, W=lasagne.init.GlorotUniform())\n", "l_hidden2 = lasagne.layers.DenseLayer(l_hidden, num_units=200, nonlinearity=lasagne.nonlinearities.rectify, W=lasagne.init.GlorotUniform())\n", "l_hidden2_drop = lasagne.layers.DropoutLayer(l_hidden2, p=0.2)\n", "l_hidden3 = lasagne.layers.DenseLayer(l_hidden2_drop, num_units=40, nonlinearity=lasagne.nonlinearities.rectify, W=lasagne.init.GlorotUniform())\n", "l_out = lasagne.layers.DenseLayer(l_hidden3, num_units=1, nonlinearity=lasagne.nonlinearities.tanh, W=lasagne.init.GlorotUniform())\n", "\n", "prediction = lasagne.layers.get_output(l_out).flatten()\n", "V = theano.function([input_var], prediction)\n", "#loss = lasagne.objectives.binary_crossentropy(prediction, target_var)\n", "loss = lasagne.objectives.squared_error(prediction, target_var)\n", "\n", "loss = loss.mean()\n", "\n", "params = lasagne.layers.get_all_params(l_out, trainable=True)\n", "updates = lasagne.updates.adam(loss, params)\n", "train_fn = theano.function([input_var, target_var], loss, updates=updates)\n", "test_prediction = lasagne.layers.get_output(l_out, deterministic=True).flatten()\n", "test_V = theano.function([input_var], T.gt(test_prediction, 0.))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from random import random, randint\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def player_NN(game):\n", " moves, boards = game._np_branch()\n", " return moves[np.argmax(V(boards))]\n", "\n", "def player_random(game):\n", " while 1:\n", " r = randint(0,6)\n", " if game.move(r, test=True) is not None:\n", " return r\n", "\n", "def get_player_MC(N=100):\n", " def player(game):\n", " return MC_agent(game, N=N)\n", " return player\n", "\n", "def get_player_mixed(*settings):\n", " def player(game):\n", " r = random()\n", " for player, prop in settings:\n", " r-=prop\n", " if r<=0:\n", " return player(game)\n", " return player_random(game)\n", " return player\n", " \n", "\n", "def vs(player1, player2, display=False):\n", " game = Connect4()\n", " history = []\n", " while game.turn < 42 and not game.is_end():\n", " if game.turn%2 == 0:\n", " m = player1(game)\n", " else:\n", " m = player2(game)\n", " game.move(m)\n", " history.append(m)\n", " if display == 'all':\n", " game.display()\n", " if display:\n", " game.display()\n", " if game.is_end():\n", " return game.turn, history\n", " return 0, history\n", " " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def train_if(results):\n", " def train(r, history, old_histories, γ=0.95):\n", " _r = -1 if r ==0 else (r-1)%2\n", " if _r not in results:\n", " return\n", " game = Connect4()\n", " board_history = []\n", " for m in history:\n", " game.move(m)\n", " board_history.append(game._np_board())\n", " estimate_V = np.zeros(len(history), dtype=floatX)\n", " if _r==-1:\n", " r = 0.\n", " else:\n", " r = 1.\n", " for i in range(len(history)-1, -1, -1):\n", " estimate_V[i]=r\n", " r *= -γ\n", " old_histories.append( (board_history, estimate_V) )\n", " data0 = np.array([x for h in old_histories for x in h[0]], dtype=floatX)\n", " data1 = np.array([x for h in old_histories for x in h[1]], dtype=floatX)\n", " loss = train_fn(data0, data1) \n", " return train\n", "\n", "def vs_test(player1, player2, old_histories=[], ngames=1000, train=None):\n", " result = [0,0,0]\n", " for i in range(ngames):\n", " r, history = vs(player1, player2)\n", " if r == 0:\n", " result[0]+=1\n", " else:\n", " result[1 + (r-1)%2]+=1\n", " if train is not None:\n", " train(r, history, old_histories)\n", " old_histories= old_histories[-10:]\n", " return result\n", "\n", " " ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [], "source": [ "train2 = train=train_if([-1,0])\n", "train1 = train=train_if([-1,1])\n", "train_all = train=train_if([-1,0,1])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "time: 0.0003426074981689453\n", "inital result: mc100 vs nn: [0, 100, 0], nn vs mc100 [0, 0, 100]\n", "time: 11.120651960372925\n", "time: 61.95448708534241\n", "#5000 avgloss=0.406109, rand vs nn: [2, 307, 691], nn vs rand [1, 752, 247]\n", "time: 116.81879782676697\n", "#10000 avgloss=0.424629, rand vs nn: [2, 128, 870], nn vs rand [1, 870, 129]\n", "time: 171.35795402526855\n", "#15000 avgloss=0.412438, rand vs nn: [1, 164, 835], nn vs rand [1, 848, 151]\n", "time: 231.119473695755\n", "#20000 avgloss=0.397607, rand vs nn: [0, 40, 960], nn vs rand [0, 967, 33]\n", "time: 285.53651309013367\n", "#25000 avgloss=0.383559, rand vs nn: [1, 31, 968], nn vs rand [1, 959, 40]\n", "time: 340.74573159217834\n", "#30000 avgloss=0.394265, rand vs nn: [0, 48, 952], nn vs rand [0, 987, 13]\n", "time: 394.1741552352905\n", "#35000 avgloss=0.383031, rand vs nn: [1, 22, 977], nn vs rand [0, 998, 2]\n", "time: 447.6618604660034\n", "#40000 avgloss=0.384509, rand vs nn: [0, 12, 988], nn vs rand [0, 993, 7]\n", "time: 502.2312698364258\n", "#45000 avgloss=0.369094, rand vs nn: [0, 11, 989], nn vs rand [0, 997, 3]\n", "time: 560.1054155826569\n", "#50000 midterm avgloss=0.351583\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 45\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m%\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mN\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m==\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mN\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"#%d midterm avgloss=%f\"\u001b[0m\u001b[1;33m%\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtotal_loss\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0mN\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 47\u001b[1;33m print(\" mc100 vs nn: %s, nn vs mc100 %s\"%(vs_test(get_player_MC(100), player_NN, ngames=100, train=train2), \n\u001b[0m\u001b[0;32m 48\u001b[0m vs_test(player_NN, get_player_MC(100), ngames=100, train=train1)) )\n\u001b[0;32m 49\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m\u001b[0m in \u001b[0;36mvs_test\u001b[1;34m(player1, player2, old_histories, ngames, train)\u001b[0m\n\u001b[0;32m 26\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 27\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mngames\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 28\u001b[1;33m \u001b[0mr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhistory\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mvs\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mplayer1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mplayer2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 29\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mr\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m+=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m\u001b[0m in \u001b[0;36mvs\u001b[1;34m(player1, player2, display)\u001b[0m\n\u001b[0;32m 30\u001b[0m \u001b[1;32mwhile\u001b[0m \u001b[0mgame\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mturn\u001b[0m \u001b[1;33m<\u001b[0m \u001b[1;36m42\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mgame\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_end\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 31\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mgame\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mturn\u001b[0m\u001b[1;33m%\u001b[0m\u001b[1;36m2\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 32\u001b[1;33m \u001b[0mm\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mplayer1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgame\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 33\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 34\u001b[0m \u001b[0mm\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mplayer2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgame\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m\u001b[0m in \u001b[0;36mplayer\u001b[1;34m(game)\u001b[0m\n\u001b[0;32m 11\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mget_player_MC\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mN\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m100\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 12\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mplayer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgame\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 13\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mMC_agent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgame\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mN\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mN\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 14\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mplayer\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 15\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m\u001b[0m in \u001b[0;36mMC_agent\u001b[1;34m(_game, N)\u001b[0m\n\u001b[0;32m 9\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mN\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[1;31m#print(\"move\", i, \"case\", j)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 11\u001b[1;33m \u001b[0mr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mrandom_play\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mgame\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgame\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mturn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 12\u001b[0m \u001b[0mturn\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m%\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mr\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "import time\n", "import sys\n", "start_time=time.time()\n", "def run_game(V, verbose = False, ɛ=0.1, γ=0.95):\n", " game = Connect4()\n", " history=[]\n", " runtime_V=[]\n", " while game.turn < 42 and not game.is_end():\n", " s = 1 if game.turn%2 == 0 else -1\n", " moves, boards = game._np_branch()\n", " #print(boards)\n", " if random() < ɛ:\n", " idx = randint(0, len(moves)-1)\n", " values=[None]*len(moves)\n", " else:\n", " values = V(boards)\n", " idx = np.argmax(values)\n", " m = moves[idx]\n", " game.move(m)\n", " history.append(boards[idx])\n", " runtime_V.append(values[idx])\n", " if game.is_end():\n", " result = 1.\n", " else:\n", " result = 0. # Tie\n", " # train here\n", " #game.display()\n", " estimate_V = np.zeros(len(history), dtype=floatX)\n", " r = result\n", " for i in range(len(history)-1, -1, -1):\n", " estimate_V[i]=r\n", " r *= -γ\n", " loss = train_fn(np.array(history, dtype=floatX), estimate_V)\n", " return loss\n", "total_loss = 0\n", "N = 5000\n", "print(\"time:\",time.time()-start_time)\n", "print(\"inital result: mc100 vs nn: %s, nn vs mc100 %s\"%(vs_test(get_player_MC(100), player_NN, ngames=100), \n", " vs_test(player_NN, get_player_MC(100), ngames=100)))\n", "print(\"time:\",time.time()-start_time)\n", "for i in range(100*N):\n", " total_loss += run_game(V)\n", " if i%N==N-1:\n", " print(\"time:\",time.time()-start_time)\n", " if i%(10*N)==10*N-1:\n", " print(\"#%d midterm avgloss=%f\"%(i+1, total_loss/N))\n", " print(\" mc100 vs nn: %s, nn vs mc100 %s\"%(vs_test(get_player_MC(100), player_NN, ngames=100, train=train2), \n", " vs_test(player_NN, get_player_MC(100), ngames=100, train=train1)) )\n", " else:\n", " print(\"#%d avgloss=%f, rand vs nn: %s, nn vs rand %s\"%(i+1, total_loss/N, vs_test(player_random, player_NN, ngames=1000, train=train2), \n", " vs_test(player_NN, player_random, ngames=1000, train=train1)) )\n", " total_loss = 0\n", " sys.stdout.flush()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.4.3" } }, "nbformat": 4, "nbformat_minor": 0 }