{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "https://github.com/christopherjenness/NBA-prediction\n", "https://github.com/fastai/courses/blob/master/deeplearning1/nbs/utils.py" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPython 3.5.3\n", "IPython 5.1.0\n", "\n", "numpy 1.11.3\n", "pandas 0.19.2\n", "keras 1.2.0\n", "tensorflow 0.10.0rc0\n", "\n", "compiler : GCC 4.4.7 20120313 (Red Hat 4.4.7-1)\n", "system : Linux\n", "release : 4.4.0-72-generic\n", "machine : x86_64\n", "processor : x86_64\n", "CPU cores : 4\n", "interpreter: 64bit\n", "Git hash : 72c5182d31686b9c85995ce2b433353f418c2b9d\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -v -m -p numpy,pandas,keras,tensorflow -g" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import pandas as pd\n", "\n", "df = pd.read_pickle('./2016_scores.pkl')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
away_scoreaway_teamhome_scorehome_teamdifference
05Toronto Blue Jays3Tampa Bay Rays-2
11St. Louis Cardinals4Pittsburgh Pirates3
23New York Mets4Kansas City Royals1
32Seattle Mariners3Texas Rangers1
45Toronto Blue Jays3Tampa Bay Rays-2
\n", "
" ], "text/plain": [ " away_score away_team home_score home_team difference\n", "0 5 Toronto Blue Jays 3 Tampa Bay Rays -2\n", "1 1 St. Louis Cardinals 4 Pittsburgh Pirates 3\n", "2 3 New York Mets 4 Kansas City Royals 1\n", "3 2 Seattle Mariners 3 Texas Rangers 1\n", "4 5 Toronto Blue Jays 3 Tampa Bay Rays -2" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['difference'] = df['home_score'] - df['away_score']\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "list_of_teams = list(df['away_team'].unique()) " ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "list_of_teams = df['away_team'].unique()\n", "n_teams = len(list_of_teams)\n", "\n", "teamid2idx = {o:i for i,o in enumerate(list_of_teams)}\n", "idx2teamid = {teamid2idx[key]:key for key in teamid2idx.keys()}" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "df.away_team = df.away_team.apply(lambda x: teamid2idx[x])\n", "df.home_team = df.home_team.apply(lambda x: teamid2idx[x])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy as np\n", "\n", "msk = np.random.rand(len(df)) < 0.8\n", "trn = df[msk]\n", "val = df[~msk]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from keras.layers import Input, Embedding, merge\n", "from keras.layers.core import Flatten\n", "from keras.models import Model\n", "from keras.regularizers import l2\n", "from keras.optimizers import Adam" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def embedding_input(name, n_in, n_out, reg):\n", " inp = Input(shape=(1,), dtype='int64', name=name)\n", " return inp, Embedding(n_in, n_out, 
input_length=1, W_regularizer=l2(reg))(inp)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def create_bias(inp, n_in):\n", " x = Embedding(n_in, 1, input_length=1)(inp)\n", " return Flatten()(x)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "n_factors = 2\n", "\n", "away_in = Input(shape=(1,), dtype='int64', name='away_in')\n", "home_in = Input(shape=(1,), dtype='int64', name='home_in')\n", "\n", "embedding_layer = Embedding(n_teams, n_factors, input_length=1, W_regularizer=l2(1e-4))\n", "\n", "#away_in, a = embedding_input('away_in', n_teams, n_factors, 1e-4)\n", "#home_in, h = embedding_input('home_in', n_teams, n_factors, 1e-4)\n", "\n", "a = embedding_layer(away_in)\n", "h = embedding_layer(home_in)\n", "\n", "ab = create_bias(away_in, n_teams)\n", "hb = create_bias(home_in, n_teams)\n", "\n", "x = merge([a, h], mode='dot')\n", "x = Flatten()(x)\n", "x = merge([x, ab], mode='sum')\n", "x = merge([x, hb], mode='sum')\n", "model = Model([away_in, home_in], x)\n", "model.compile(Adam(0.001), loss='mse')" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 1959 samples, validate on 504 samples\n", "Epoch 1/1\n", "1959/1959 [==============================] - 0s - loss: 18.3939 - val_loss: 19.2317\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit([trn.away_team, trn.home_team], trn.difference, batch_size=64, nb_epoch=1, \n", " validation_data=([val.away_team, val.home_team], val.difference))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 1959 samples, validate on 504 samples\n", "Epoch 1/6\n", "1959/1959 [==============================] - 0s - loss: 
18.3673 - val_loss: 19.2271\n", "Epoch 2/6\n", "1959/1959 [==============================] - 0s - loss: 18.3447 - val_loss: 19.2208: 18.42\n", "Epoch 3/6\n", "1959/1959 [==============================] - 0s - loss: 18.3232 - val_loss: 19.2184\n", "Epoch 4/6\n", "1959/1959 [==============================] - 0s - loss: 18.3002 - val_loss: 19.2128\n", "Epoch 5/6\n", "1959/1959 [==============================] - 0s - loss: 18.2789 - val_loss: 19.2074\n", "Epoch 6/6\n", "1959/1959 [==============================] - 0s - loss: 18.2568 - val_loss: 19.2036\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [
"from keras import backend as K\n",
"\n",
"# `model.optimizer.lr = 0.01` would only rebind the attribute to a Python float and\n",
"# leave the optimizer's backend variable untouched; set the variable's value instead\n",
"# so the new learning rate is actually used by the next fit().\n",
"K.set_value(model.optimizer.lr, 0.01)\n",
"model.fit([trn.away_team, trn.home_team], trn.difference, batch_size=64, nb_epoch=6, \n",
" validation_data=([val.away_team, val.home_team], val.difference))" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 1959 samples, validate on 504 samples\n", "Epoch 1/5\n", "1959/1959 [==============================] - 0s - loss: 18.2359 - val_loss: 19.2007\n", "Epoch 2/5\n", "1959/1959 [==============================] - 0s - loss: 18.2140 - val_loss: 19.1982\n", "Epoch 3/5\n", "1959/1959 [==============================] - 0s - loss: 18.1922 - val_loss: 19.1933\n", "Epoch 4/5\n", "1959/1959 [==============================] - 0s - loss: 18.1710 - val_loss: 19.1917\n", "Epoch 5/5\n", "1959/1959 [==============================] - 0s - loss: 18.1485 - val_loss: 19.1912\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Drop back to the lower learning rate for the final epochs (K is imported in the previous cell).\n",
"K.set_value(model.optimizer.lr, 0.001)\n",
"model.fit([trn.away_team, trn.home_team], trn.difference, batch_size=64, nb_epoch=5, \n",
" validation_data=([val.away_team, val.home_team], val.difference))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ 
"https://github.com/fastai/courses/blob/master/deeplearning1/nbs/lesson4.ipynb" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(0.19213127, 'Philadelphia Phillies'),\n", " (0.13538371, 'Colorado Rockies'),\n", " (0.1154874, 'Cincinnati Reds'),\n", " (0.10293819, 'San Diego Padres'),\n", " (0.091762304, 'New York Yankees'),\n", " (0.086704448, 'Atlanta Braves'),\n", " (0.086652882, 'Chicago White Sox'),\n", " (0.073435932, 'Milwaukee Brewers'),\n", " (0.066562518, \"Arizona D'Backs\"),\n", " (0.054421984, 'Baltimore Orioles'),\n", " (0.050462101, 'Kansas City Royals'),\n", " (0.041441962, 'LA Angels of Anaheim'),\n", " (0.033019256, 'Houston Astros'),\n", " (0.021627087, 'Tampa Bay Rays'),\n", " (0.01519491, 'Miami Marlins')]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_away_bias = Model(away_in, ab)\n", "away_bias = get_away_bias.predict(df.away_team.unique())\n", "away_rating = [(b[0], idx2teamid[i]) for i,b in zip(df.away_team.unique(),away_bias)] \n", " \n", "sorted(away_rating, key=lambda x: x[0], reverse=True)[:15]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(-0.16621463, 'Washington Nationals'),\n", " (-0.1011001, 'Toronto Blue Jays'),\n", " (-0.094313346, 'Boston Red Sox'),\n", " (-0.082905471, 'St. 
Louis Cardinals'),\n", " (-0.071792163, 'Chicago Cubs'),\n", " (-0.07038597, 'New York Mets'),\n", " (-0.054257698, 'San Francisco Giants'),\n", " (-0.050478533, 'Texas Rangers'),\n", " (-0.028976118, 'Seattle Mariners'),\n", " (-0.028675666, 'Pittsburgh Pirates'),\n", " (-0.010899131, 'Oakland Athletics'),\n", " (0.0073902896, 'Cleveland Indians'),\n", " (0.0084861945, 'Detroit Tigers'),\n", " (0.014417341, 'Minnesota Twins'),\n", " (0.014768536, 'Los Angeles Dodgers')]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Same ratings, smallest bias first.\n",
"sorted(away_rating, key=lambda pair: pair[0])[:15]" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true }, "outputs": [], "source": [
"# Shuffled baseline: MSE between the real score differences and a random\n",
"# permutation of those same differences, repeated 100 times.\n",
"mse = []\n",
"for _ in range(100):\n",
"    permuted = df.difference.values.copy()\n",
"    np.random.shuffle(permuted)\n",
"    mse.append(np.mean((df.difference - permuted) ** 2))" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEKxJREFUeJzt3X+s3XV9x/HnawW2BdnQ9Yr86qoJkqDRztxV49CBP0hb\niExjXJtlojOpEjTTuC11S5z7j825ZQ4jqdKgiyKaWUdGRcG4oYmALSlQFEYlNbQwWnSCTDdSfe+P\n++1yvZ7Teznf055z/Twfycn9fj/fz/d83p/7vXn1e74953tSVUiS2vFLky5AknR8GfyS1BiDX5Ia\nY/BLUmMMfklqjMEvSY0x+CWpMQa/JDXG4Jekxpww6QIGWblyZa1evXrSZUjSsrFr167HqmpmKX2n\nMvhXr17Nzp07J12GJC0bSb671L5e6pGkxhj8ktQYg1+SGmPwS1JjDH5JaozBL0mNMfglqTEGvyQ1\nxuCXpMZM5Sd3pcWs3nLjRMbdd+XFExlXGifP+CWpMQa/JDXG4Jekxhj8ktQYg1+SGmPwS1JjDH5J\naozBL0mNMfglqTEGvyQ1xuCXpMYseq+eJNuAS4CDVfXCru164Nyuy6nAD6pqzYB99wE/BH4CHK6q\n2THVLUka0VJu0nYtcBXwySMNVfX7R5aTfAh4/Cj7X1hVj41aoCRpvBYN/qq6NcnqQduSBHgT8Krx\nliVJOlb6XuN/BfBoVT0wZHsBtyTZlWRzz7EkSWPQ9378m4DrjrL9/Ko6kOTZwM1J7quqWwd17P5h\n2AywatWqnmVJkoYZ+Yw/yQnAG4Drh/WpqgPdz4PAdmDtUfpurarZqpqdmZkZtSxJ0iL6XOp5DXBf\nVe0ftDHJyUlOObIMXATs6TGeJGkMFg3+JNcB3wDOTbI/ydu6TRtZcJknyRlJdnSrpwFfT3IXcAdw\nY1XdNL7SJUmjWMq7ejYNaX/LgLaHgQ3d8oPAi3vWJ0kaMz+5K0mN6fuuHqkpq7fcOJFx91158UTG\n1S8mz/glqTEGvyQ1xuCXpMYY/JLUGINfkhpj8EtSYwx+SWqMwS9JjTH4JakxBr8kNcbgl6TGGPyS\n1BiDX5IaY/BLUmMMfklqjMEvSY0x+CWpMUv5svVtSQ4m2TOv7QNJDiTZ3T02DNl3XZL7k+xNsmWc\nhUuSRrOUM/5rgXUD2v++qtZ0jx0LNyZZAXwEWA+cB2xKcl6fYiVJ/S0a/FV1K/D9EZ57LbC3qh6s\nqqeAzwCXjvA8kqQx6vNl6+9K8mZgJ/DeqvqvBdvPBB6at74feOmwJ0uyGdgMsGrVqh5ltccvAJf0\ndIz6n7sfBZ4HrAEeAT7Ut5Cq2lpVs1U1OzMz0/fpJElDjBT8VfVoVf2kqn4KfIy5yzoLHQDOnrd+\nVtcmSZqgkYI/yenzVl8P7BnQ7ZvAOUmem+QkYCNwwyjjSZLGZ9Fr/EmuAy4AVibZD/wlcEGSNUAB\n+4C3d33PAD5eVRuq6nCSdwJfAlYA26rq3mMyC0nSki0a/FW1aUDzNUP6PgxsmLe+A/i5t3pKkibH\nT+5KUmMMfklqjMEvSY0x+CWpMQa/JDWmzy0b1LhJ3SpCUj+e8UtSYwx+SWqMwS9JjTH4JakxBr8k\nNcbgl6TGGPyS1BiDX5IaY/BLUmMMfklqjMEvSY0x+CWpMYsGf5JtSQ4m2TOv7YNJ7ktyd5LtSU4d\nsu++JPck2Z1k5zgLlySNZiln/NcC6xa03Qy8sKpeBPwH8L6j7H9hVa2pqtnRSpQkjdOiwV9VtwLf\nX9D25ao63K3eBpx1DGqTJB0D47jG/0fAF4dsK+CWJLuSbB7DWJKknnp9EUuSvwAOA58a0uX8qjqQ\n5NnAzUnu615BDHquzcBmgFWrVvUpS5J0FCOf8Sd5C3AJ8AdVVYP6VNWB7udBYDuwdtjzVdXWqpqt\nqtmZmZlRy5IkLWKk4E+yDvgz4HVV9aMhfU5OcsqRZeAiYM+gv
pKk42cpb+e8DvgGcG6S/UneBlwF\nnMLc5ZvdSa7u+p6RZEe362nA15PcBdwB3FhVNx2TWUiSlmzRa/xVtWlA8zVD+j4MbOiWHwRe3Ks6\nSdLY+cldSWqMwS9JjTH4JakxBr8kNcbgl6TGGPyS1BiDX5IaY/BLUmMMfklqjMEvSY0x+CWpMQa/\nJDXG4Jekxhj8ktQYg1+SGmPwS1JjDH5JaozBL0mNWcp37m5LcjDJnnltz0pyc5IHup/PHLLvuiT3\nJ9mbZMs4C5ckjWYpZ/zXAusWtG0BvlJV5wBf6dZ/RpIVwEeA9cB5wKYk5/WqVpLU26LBX1W3At9f\n0Hwp8Ilu+RPA7w3YdS2wt6oerKqngM90+0mSJmjUa/ynVdUj3fJ/AqcN6HMm8NC89f1dmyRpgnr/\n525VFVB9nyfJ5iQ7k+w8dOhQ36eTJA0xavA/muR0gO7nwQF9DgBnz1s/q2sbqKq2VtVsVc3OzMyM\nWJYkaTGjBv8NwGXd8mXAvwzo803gnCTPTXISsLHbT5I0QUt5O+d1wDeAc5PsT/I24ErgtUkeAF7T\nrZPkjCQ7AKrqMPBO4EvAt4HPVtW9x2YakqSlOmGxDlW1acimVw/o+zCwYd76DmDHyNVJksbOT+5K\nUmMMfklqjMEvSY0x+CWpMQa/JDXG4Jekxhj8ktQYg1+SGmPwS1JjDH5JaozBL0mNMfglqTEGvyQ1\nxuCXpMYY/JLUGINfkhpj8EtSYxb9Bi4tzeotN066BOmYmNTf9r4rL57IuC0Y+Yw/yblJds97PJHk\n3Qv6XJDk8Xl93t+/ZElSHyOf8VfV/cAagCQrgAPA9gFdv1ZVl4w6jiRpvMZ1jf/VwHeq6rtjej5J\n0jEyruDfCFw3ZNvLk9yd5ItJXjCm8SRJI+od/ElOAl4HfG7A5juBVVX1IuAfgS8c5Xk2J9mZZOeh\nQ4f6liVJGmIcZ/zrgTur6tGFG6rqiap6slveAZyYZOWgJ6mqrVU1W1WzMzMzYyhLkjTIOIJ/E0Mu\n8yR5TpJ0y2u78b43hjElSSPq9T7+JCcDrwXePq/tHQBVdTXwRuDyJIeBHwMbq6r6jClJ6qdX8FfV\nfwO/saDt6nnLVwFX9RlDkjRe3rJBkhpj8EtSYwx+SWqMwS9JjTH4JakxBr8kNcbgl6TGGPyS1BiD\nX5IaY/BLUmMMfklqzC/cl637pef6ReTftcbJM35JaozBL0mNMfglqTEGvyQ1xuCXpMYY/JLUGINf\nkhrTK/iT7EtyT5LdSXYO2J4kH06yN8ndSV7SZzxJUn/j+ADXhVX12JBt64FzusdLgY92PyVJE3Ks\nL/VcCnyy5twGnJrk9GM8piTpKPoGfwG3JNmVZPOA7WcCD81b39+1/Zwkm5PsTLLz0KFDPcuSJA3T\nN/jPr6o1zF3SuSLJK0d9oqraWlWzVTU7MzPTsyxJ0jC9gr+qDnQ/DwLbgbULuhwAzp63flbXJkma\nkJGDP8nJSU45sgxcBOxZ0O0G4M3du3teBjxeVY+MXK0kqbc+7+o5Ddie5MjzfLqqbkryDoCquhrY\nAWwA9gI/At7ar1xJUl8jB39VPQi8eED71fOWC7hi1DEkSePnJ3clqTEGvyQ1xuCXpMYY/JLUGINf\nkhozjpu0SdLYrd5y48TG3nflxRMb+3jwjF+SGmPwS1JjDH5JaozBL0mNMfglqTEGvyQ1xuCXpMYY\n/JLUGINfkhpj8EtSY7xlgyQtMKnbRRyvW0V4xi9JjenzZetnJ/lqkm8luTfJHw/oc0GSx5Ps7h7v\n71euJKmvPpd6DgPvrao7k5wC7Epyc1V9a0G/r1XVJT3GkSSN0chn/FX1SFXd2S3/EPg2cOa4CpMk\nHRtjucafZDXwW8DtAza/PMndSb6Y5AXjGE+SNLre7+pJ8gzgn4F3V9UTCzbfCayqqieTbAC+AJwz\n5Hk2A5sBVq1a1bcsSdIQv
c74k5zIXOh/qqo+v3B7VT1RVU92yzuAE5OsHPRcVbW1qmaranZmZqZP\nWZKko+jzrp4A1wDfrqq/G9LnOV0/kqztxvveqGNKkvrrc6nnd4A/BO5Jsrtr+3NgFUBVXQ28Ebg8\nyWHgx8DGqqoeY0qSeho5+Kvq60AW6XMVcNWoY0iSxs9P7kpSYwx+SWqMwS9JjTH4JakxBr8kNcbg\nl6TGGPyS1BiDX5IaY/BLUmMMfklqjMEvSY0x+CWpMQa/JDXG4Jekxhj8ktQYg1+SGmPwS1JjDH5J\nakyv4E+yLsn9SfYm2TJge5J8uNt+d5KX9BlPktTfyMGfZAXwEWA9cB6wKcl5C7qtB87pHpuBj446\nniRpPPqc8a8F9lbVg1X1FPAZ4NIFfS4FPllzbgNOTXJ6jzElST31Cf4zgYfmre/v2p5uH0nScXTC\npAs4Islm5i4HATyZ5P5ueSXw2GSqGovlXj8s/zlY/+Qt9zkcl/rz1712/82lduwT/AeAs+etn9W1\nPd0+AFTVVmDrwvYkO6tqtkedE7Xc64flPwfrn7zlPoflXv9CfS71fBM4J8lzk5wEbARuWNDnBuDN\n3bt7XgY8XlWP9BhTktTTyGf8VXU4yTuBLwErgG1VdW+Sd3TbrwZ2ABuAvcCPgLf2L1mS1Eeva/xV\ntYO5cJ/fdvW85QKu6DMGAy7/LDPLvX5Y/nOw/slb7nNY7vX/jMxlsySpFd6yQZIaM5HgT/IrSe5I\ncleSe5P8Vdf+gSQHkuzuHhuG7L8vyT1dn53Ht/r/r2HgHLpt70pyX9f+N0P2P+rtLo61MdQ/0WNw\nlL+h6+f9/exLsnvI/hP9/Xc19J3DtB6DNUluO1JXkrVD9p/mY7DUOUw8i0ZSVcf9AQR4Rrd8InA7\n8DLgA8CfLGH/fcDKSdS+hDlcCNwC/HK37dkD9l0BfAd4HnAScBdw3nKpfxqOwbD6F/T5EPD+afz9\n953DNB8D4MvA+q59A/Bvy+0YLGUO03AMRn1M5Iy/5jzZrZ7YPZbVfzYcZQ6XA1dW1f92/Q4O2H0p\nt7s4pnrWP3GL/Q0lCfAm4LoBu0/89w+95zBxR6m/gF/r2n8deHjA7tN+DJYyh2VrYtf4k6zoXsIe\nBG6uqtu7Te/q7uS5Lckzh+xewC1JdnWf+J2IIXN4PvCKJLcn+fckvz1g16m4lUWP+mEKjsFR/oYA\nXgE8WlUPDNh1Kn7/0GsOML3H4N3AB5M8BPwt8L4Bu077MVjKHGAKjsEoJhb8VfWTqlrD3Kd51yZ5\nIXN373wesAZ4hLmXuYOc3+27HrgiySuPR80LDZnDCcCzmHu5+KfAZ7szt6nTs/6JH4Mh9R+xiSk9\nU56v5xym9RhcDrynqs4G3gNcc7zrejp6zmHix2AUE39XT1X9APgqsK6qHu0Owk+BjzH3cnDQPge6\nnweB7cP6HS/z58Dcmcvnu5eQdwA/Ze4+H/Mt+VYWx8MI9U/VMVhQP0lOAN4AXD9kl6n6/cNIc5jm\nY3AZ8Plu0+eG1DXtx2Apc5iqY/B0TOpdPTNJTu2WfxV4LXBffvaWza8H9gzY9+QkpxxZBi4a1O9Y\nGzYH4AvM/QcpSZ7P3H9cLby501Jud3FM9al/Go7BUeoHeA1wX1XtH7L7xH//0G8OU34MHgZ+t+v2\nKmDQpappPwaLzmEajsGoJnV3ztOBT2Tuy1x+CfhsVf1rkn9Ksoa562b7gLcDJDkD+HhVbQBOA7Z3\nVx9OAD5dVTdN0RxOArYl2QM8BVxWVTV/DjXkdhfLpX6m4xgMrL/btpEFl0im8PcPPebAFB+DJD8A\n/qF71fI/dHfdXU7HYClzYDqOwUj85K4kNWbi1/glSceXwS9JjTH4JakxBr8kNcbgl6TGGPyS1BiD\nX5IaY/BLUmP+D5AweouX4NqPAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, 
"metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "plt.hist(mse);" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "teams[1].away_team" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "18.3493827748383\n" ] } ], "source": [ "predictions = model.predict([df.away_team.values, df.home_team.values])\n", "t_mse = (df.difference-np.squeeze(predictions))**2\n", "print(np.mean(t_mse))" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "18.589783174166605\n" ] } ], "source": [ "predictions = model.predict([df.away_team.values, df.home_team.values])\n", "np.random.shuffle(predictions)\n", "t_mse = (df.difference-np.squeeze(predictions))**2\n", "print(np.mean(t_mse))" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-0.06541088],\n", " [ 0.15904462],\n", " [ 0.03773554],\n", " [-0.12947126],\n", " [-0.05763809],\n", " [-0.09660609],\n", " [ 0.19942555],\n", " [ 0.28710777],\n", " [ 0.07351479],\n", " [-0.0760739 ]], dtype=float32)" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions[0:10]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda root]", "language": "python", "name": "conda-root-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.3" } }, 
"nbformat": 4, "nbformat_minor": 2 }