{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# KNN exercise with NBA player data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "- NBA player statistics from 2014-2015 (partial season): [data](https://github.com/justmarkham/DAT4-students/blob/master/kerry/Final/NBA_players_2015.csv), [data dictionary](https://github.com/justmarkham/DAT-project-examples/blob/master/pdf/nba_paper.pdf)\n", "- **Goal:** Predict player position using assists, steals, blocks, turnovers, and personal fouls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 1: Read the data into Pandas" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# read the data into a DataFrame\n", "import pandas as pd\n", "url = 'https://raw.githubusercontent.com/justmarkham/DAT4-students/master/kerry/Final/NBA_players_2015.csv'\n", "nba = pd.read_csv(url, index_col=0)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Index([u'season_end', u'player', u'pos', u'age', u'bref_team_id', u'g', u'gs',\n", " u'mp', u'fg', u'fga', u'fg_', u'x3p', u'x3pa', u'x3p_', u'x2p', u'x2pa',\n", " u'x2p_', u'ft', u'fta', u'ft_', u'orb', u'drb', u'trb', u'ast', u'stl',\n", " u'blk', u'tov', u'pf', u'pts', u'G', u'MP', u'PER', u'TS%', u'3PAr',\n", " u'FTr', u'TRB%', u'AST%', u'STL%', u'BLK%', u'TOV%', u'USG%', u'OWS',\n", " u'DWS', u'WS', u'WS/48', u'OBPM', u'DBPM', u'BPM', u'VORP'],\n", " dtype='object')" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# examine the columns\n", "nba.columns" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "G 200\n", "F 199\n", "C 79\n", "dtype: int64" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# examine the 
positions\n", "nba.pos.value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 2: Create X and y\n", "\n", "Use the following features: assists, steals, blocks, turnovers, personal fouls" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# map positions to numbers\n", "nba['pos_num'] = nba.pos.map({'C':0, 'F':1, 'G':2})" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# create feature matrix (X)\n", "feature_cols = ['ast', 'stl', 'blk', 'tov', 'pf']\n", "X = nba[feature_cols]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# alternative way to create X\n", "X = nba.loc[:, 'ast':'pf']" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# create response vector (y)\n", "y = nba.pos_num" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 3: Train a KNN model (K=5)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# import class\n", "from sklearn.neighbors import KNeighborsClassifier" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# instantiate with K=5\n", "knn = KNeighborsClassifier(n_neighbors=5)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n", " metric_params=None, n_neighbors=5, p=2, weights='uniform')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# fit with data\n", "knn.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 4: Predict player position and calculate predicted probability of each position\n", 
"\n", "Predict for a player with these statistics: 1 assist, 1 steal, 0 blocks, 1 turnover, 2 personal fouls" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# create a list to represent a player\n", "player = [1, 1, 0, 1, 2]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([2], dtype=int64)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# make a prediction\n", "knn.predict(player)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 0. , 0.2, 0.8]])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# calculate predicted probabilities\n", "knn.predict_proba(player)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Step 5: Repeat steps 3 and 4 using K=50" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([1], dtype=int64)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# repeat for K=50\n", "knn = KNeighborsClassifier(n_neighbors=50)\n", "knn.fit(X, y)\n", "knn.predict(player)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 0.06, 0.62, 0.32]])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# calculate predicted probabilities\n", "knn.predict_proba(player)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bonus: Explore the features to decide which ones are predictive" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# allow plots to appear in the notebook\n", "%matplotlib 
inline\n", "import matplotlib.pyplot as plt\n", "\n", "# increase default figure and font sizes for easier viewing\n", "plt.rcParams['figure.figsize'] = (6, 4)\n", "plt.rcParams['font.size'] = 14" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
\n",
"| pos | count | mean | std | min | 25% | 50% | 75% | max |\n",
"|---|---|---|---|---|---|---|---|---|\n",
"| C | 79 | 0.945570 | 0.858263 | 0 | 0.40 | 0.80 | 1.15 | 4.4 |\n",
"| F | 199 | 1.173367 | 1.086252 | 0 | 0.45 | 0.90 | 1.50 | 7.3 |\n",
"| G | 200 | 2.729000 | 2.128287 | 0 | 1.10 | 2.25 | 3.80 | 10.2 |\n", "