{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## DecisionTree Classification \n", "\n", "[https://github.com/arundhaj](https://github.com/arundhaj)\n", "\n", "using vertebrate data set" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Read the CSV file\n", "data = pd.read_csv('data/vertebrate.csv')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0 mammal\n", "1 reptile\n", "2 fish\n", "3 mammal\n", "4 amphibian\n", "5 reptile\n", "6 mammal\n", "7 bird\n", "8 mammal\n", "9 fish\n", "10 reptile\n", "11 bird\n", "12 mammal\n", "13 fish\n", "14 amphibian\n", "15 NaN\n", "Name: Class Label, dtype: object" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# List of all classes\n", "data['Class Label']" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array(['mammal', 'reptile', 'fish', 'amphibian', 'bird', nan], dtype=object)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# List of unique classes\n", "data['Class Label'].unique()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "Class Label\n", "amphibian 2\n", "bird 2\n", "fish 3\n", "mammal 5\n", "reptile 3\n", "dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Number of entries for each unique classes\n", "class_group = data.groupby('Class Label').apply(lambda x: len(x))\n", "class_group" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": [ "iVBORw0KGgoAAAANSUhEUgAAAWYAAAE4CAYAAABoolmRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\n", "AAALEgAACxIB0t1+/AAAFC5JREFUeJzt3X2UbXdd3/H3J/cSA8SIKIvHuFIwCEpCAiQEwXIDFbNa\n", "QamgFgEBm0pbSUqlrV2VcsUlWF0rZSFai2iCUIXFQ1CkPJNZJJASE3JDQoIFTGikCaUlPEQMAfLt\n", "H3tPZnJz78y5uTNnf8+c92utWZmzz54zn7sz5zO/+e2nVBWSpD6OmDqAJOmOLGZJasZilqRmLGZJ\n", "asZilqRmLGZJamb3LCsluQ74KvBt4JtVdep2hpKkZTZTMQMF7KmqL21nGEnSoU1lZNtSSJJuN2sx\n", "F/CBJJcmOXM7A0nSspt1KuPxVXVDkvsA70/yqaq6ECCJ53RL0l1QVQeciZipmKvqhvG/X0xyPnAq\n", "cOFmLz5PSfZW1d6pc3TgtlizU7bFMAA63DHQ3vHjsJK0eL8frg4/FxsNajedykhyjyTfOX5+T+Ap\n", "wJVbF0+StN4sI+b7AucnWV3/v1XV+7Y1lSQtsU2LuaquBU6aQ5bDtTJ1gEZWpg7QyMrUAfrYM3WA\n", "TlamDrCRHO71mJPUTphzkjrbmjnmrbAz5pg72Kg7PSVbkpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWp\n", "GYtZkpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWpGYtZ\n", "kpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWpGYtZkpqxmCWpGYtZkpqx\n", "mCWpGYtZkpqxmCWpGYtZkpqZqZiT7EpyeZJ3bncgSVp2s46YzwauBmobs0iSmKGYkzwI+IfA64Bs\n", "eyJJWnKzjJj/M/BvgNu2OYskCdi90ZNJfhz4P1V1eZI9G6y3d93Dlapa2ZJ0krRDjB26Z6Z1qw4+\n", "bZzkFcBzgG8BRwHHAG+rqueuW6eqyikOaRslqR67eILv962xUXduWMz7vcgTgZdU1VNnfXFJW8Ni\n", "3nk26s5DPY65w0+GJO1oM4+YD/oCjpilbeeIeefZyhGzJGmbWcyS1IzFLEnNWMyS1IzFLEnNWMyS\n", "1IzFLEnNWMyS1IzFLEnNWMyS1IzFLEnNWMyS1IzFLEnNWMyS1IzFLEnNWMyS1IzFLEnNWMyS1IzF\n", "LEnNWMyS1IzFLEnNWMyS1IzFLEnNWMyS1IzFLEnNWMyS1IzFLEnNWMyS1IzFLEnNWMyS1IzFLEnN\n", "WMyS1IzFLEnNWMyS1IzFLEnNbFrMSY5K8rEk+5JcneSV8wgmSctq92YrVNUtSU6vqq8n2Q1clOQJ\n", "VXXRHPJJ0tKZaSqjqr4+fnoksAv40rYlkqQlN1MxJzkiyT7gC8AFVXX19saSpOU164j5tqo6CXgQ\n", "8PeT7NnWVJK0xDadY16vqr6S5F3AY4CV1eVJ9q5bbaWqVpAOU5KaOsOqqsrUGTRY1J+LcUC7Z6Z1\n", "qzb+Nyb5XuBbVfXlJHcH3gv8WlV9cHy+/KHVdhjegB3eg5m8mN0W6xLskG2xUXfOMmK+P/D6JEcw\n", "TH28YbWUJUlbb9MR86Yv4IhZ22SnjIy2JIHbYi3BDtkWG3WnZ/5JUjMWsyQ1YzFLUjMWsyQ1YzFL\n", "UjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMW\n", "syQ1YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1\n", "YzFLUjMWsyQ1YzFLUjMWsyQ1YzFLUjMWsyQ1s2kxJzk2yQVJPpnkqiRnzSOYJC2rVNXGKyT3A+5X\n", "VfuSHA1cBvxkVV0zPl9Vle2PqmWTpGDjn8/5CFP/jLst1iXYIdtio+7cdMRcVTdW1b7x85uBa4AH\n", "3NUwkqSNHdIcc5LjgJOBj21HGEkS7J51xXEa463A2ePIef1ze9c9XKmqlUMJMfxp0kOPP9N6mHpb\n", "SDtJkj3AnpnW3WyOeXzBuwF/Aby7ql6133OHPce8U+aMtiSB22ItgdtiLYHbYi3BDtkWhzXHnCTA\n", "HwJX71/KkqStN8sc8+OBZwOnJ7l8/Dhjm3NJ0tLadI65qi7CE1EkaW4sXElqxmKWpGYsZklqxmKW\n", "pGYsZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGYs\n", "ZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGYsZklq\n", "xmKWpGYsZklqxmKWpGYsZklqxmKWpGYsZklqxmKWpGY2LeYkf5TkC0munEcgSVp2s4yYzwXO2O4g\n", "kqTBpsVcVRcCN80hiyQJ55glqZ3dW/EiSfaue7hSVStb8bqStFMk2QPsmWndqprlBY8D3llVJxzg\n", "uaqqHFLCA7wGbJ5j+4XD/bccdgK3xVoCt8VaArfFWoIdsi026k6nMiSpmVkOl/tT4KPAQ5Ncn+T5\n", "2x9LkpbXTFMZG76AUxlbm8BtsZbAbbGWwG2xlmCHbAunMiRpgVjMktSMxSxJzVjMktSMxSxJzVjM\n", "ktSMxSxJzVjMktSMxSxJzVjMktSMxSxJzVjMktSMxSxJzVjMktSMxSxJzVjMktSMxSxJzVjMktSM\n", "xSxJzVjMktSMxSxJzVjMktSMxSxJzVjMktSMxSxJzVjMktSMxSxJzVjMktSMxSxJzVjMktSMxSxJ\n", "zVjMktSMxSxJzVjMktSMxSxJzWxazEnOSPKpJJ9O8u/mEequWZk6QCMrUwdoZGXqAI2sTB2gkZWp\n", "A2xow2JOsgt4DXAG8IPAP0ny8HkEO3QrUwdoZGXqAI2sTB2gkZWpAzSyMnWADW02Yj4V+ExVXVdV\n", "3wTeBPzE9seSpOW1WTE/ELh+3eO/GZdJkrZJqurgTyY/BZxRVWeOj58NPLaqXrRunYO/gCTpoKoq\n", "B1q+e5Ov+zxw7LrHxzKMmjd9YUnSXbPZVMalwPFJjktyJPAzwJ9vfyxJWl4bjpir6ltJfgl4L7AL\n", "+MOqumYuySRpSW04xyxJmj/P/JOkZjbb+afGkvzOuocFrN8RW1V11pwjtZDkgcBxDNNvYdgWH540\n", "1BwlufdGz1fVl+aVpZMkPwJ8f1Wdm+Q+wNFVde3UuQ5kYYt5PJTvN4H7slZIVVXHTJdq7i4b//vD\n", "DGdmvplhWzwT+ORUoaaU5D8x7KS+Gvj2uqeWppiBjzP8oj6YvzevIF0k2Qs8GvgB4FzgSOCNwOMn\n", "jHVQCzvHnOSzwI+7MxKSfAx4wnh2JknuBlxUVY+dNtn8JfmfwAlV9Y2ps6iPJFcAJwOXVdXJ47JP\n", "VNWJ0yY7sIUdMQM3Wsq3uxdwDPD/xsffOS5bRp9lGA1ZzECS7waOB45aXbZM0zrrfKOqbkuGP66T\n", "3HPiPBta5GK+NMmbgXcAt47LqqrePmGmqfwm8PEkFzBMZTwR2DtpojlbN9/+dWBfkg+yVs5LOd+e\n", "5EzgLIYTwy4HTgMuBp40Za6JvCXJfwXuleSfAS8AXjdxpoNa5KmM88ZP7/APqKrnzz/NdJIcATwO\n", "+GvgsQzb45KqumHSYHOW5HnccQfo6ucFUFWvnybZdJJcBZwCXFxVJyV5GPDKqnr6xNEmkeQpwFPG\n", "h++tqvdPmWcjC1vMWpNkX1WdNHWObsajE46tqiumzjKFJJdW1WOS7ANOq6pbklxdVT84dTZtbGGn\n", "MpLcHfgFhqMR7s7ayOgFU+aayAeSPAN4Wy35b9okK8DTGH62LwO+mOQjVfXiSYNN4/pxjvkdwPuT\n", "3ARcN22k+UpyMwc/QqXtUVwLO2JO8lbgGuDngF8Dng1cs6RziTcD92A4POyWcXHbH7rttPrXQ5J/\n", "yjBaflmSK6vqhKmzTSnJHoYdxO+pqls3WV0TW9gRM8OB4s9I8hNV9fokfwJcNHWoKVTV0VNnaGRX\n", "kvsDPw386rhsMUcfW2AcMR8LfBX4GvAIhuOcl0KSY6rqqwc76abryTaLXMyrv/W/kuQE4EbgPhPm\n", "mbskD6+qa5I86kDPV9XSvAHXeTnDRbc+UlWXJHkI8OmJM00iya8Dz2PYMXzbuqdOnyTQNP4U+Ecc\n", "/KSblifbLPJUxpnA24ATgPOAo4GXVtXvT5lrnpL8QVWdOc6r3ul/ZFUt0xtQ+xlPtnmEUxeLZ2GL\n", "WWvGHaH/AngCQ0FfBPyXqvq7SYPNUZJ/W1W/NR7P7HVDgCTnAy+sqi9MnWVqST5YVU/ebFkXCzeV\n", "keQ5VfWGJL+8bvHtx6xW1TkTRZvSHzPMIb6aYTs8a1z2zClDzdmvAL/FcObfTawdw3z7scxL6BXA\n", "5ePxzOtPtnnahJnmahy03AO4z37zzMfQ+P6lC1fMDBsZhtOO17/hlvkN+EP7HZv6oSRXT5ZmGjcm\n", "eQDDGV172G/EPEmi6f0xw1mhV7E2x7xs2+IXgbOBB7B20S8YdoS+ZpJEM3AqYwdI8kbgd6vq4vHx\n", "acC/rKrnTJtsfpKcBfxz4MHA/97v6aqqB88/1bSS/GVVnTJ1jg6SnFVVr546x6wWtpjHve2vYjgd\n", "uYCPAi+uqr+eNNgcJbly/HQ3w+UMr2fYFt8H/FVVPXyqbFNJ8vtV9cKpc3SQ5ByGKYw/Z91FnZbx\n", "aJ0D7Ie5kGE/zC0bfuFEFrmYP8bwp8ibxkU/A7xomS51meS4DZ6uqvrcnKKoIY/WWZPkLQz7Yd7I\n", "2n6Y76qqlvthFrmY73Qt1SRXVNUjp8okqacDXSOk83VDFm7n37hnNcC7k/x7hgPIYRgxv3uyYFIz\n", "41l/z2W4zdbqe30pDx1kuCzu4/bbD3PZJl8zmYUbMSe5jgPvWV49XK7lmTzSvCW5mOH6y1cyHJWx\n", "+h5Zxkugfgp4KPvthwG+xbBNWt3JZOGKWdJskny8qg54uv6y2WR/DFV13VyCzGjhijnJk6rqQ+PN\n", "WA+0Y2MZ72Ai3UmSlzDs8Hondzwqo+WFe7abd8neXk8EPgQ8lQNPaVjM0uAW4LeB/8AdTzBZxmO6\n", "9+JdsiVNLcm1wClV9X+nzjI175I9J0mOAn6KO+9xfvlkoaRePg0szYWsNuFdsufkz4AvMxzy0vLs\n", "HWliq3cMv4AlvmN4hjb+C++SPQdJrqqqR0ydQ+pqvHP4/pbucLmxmK8EXgz82Li49V2yF3nE/NEk\n", "J1bVJ6YOInVUVedNnaGDqqoklwFfqaqXTJ1nFgs3Yl534Z5dwPHAtdzxz7SWk/nSvCV5KMM1mVfv\n", "JA/Le6W9vwK+H/gc8Lfj4rZ9sYgj5qfu93j1N0v2X1FacucCLwPOYbhG9fMZBjTL6Mc2X6WPhRsx\n", "r5fk0QyX8buN4eabS3c5Q+lgVs/8S3JlVZ2wftnU2bSxI6YOcFcl+Y8MN2G9N8Pdsc9N8tJJQ0m9\n", "3JJkF/CZJL+U5B8DrQ8T02BhR8zjHYBPXL3Q9Xgh7Cuq6qHTJpN6SHIKcA1wL+DXGe5z99tV9T8m\n", "DaZNLeIc86rPM+zQWD2G+Sjgb6aLI7X0BtZOwgrwWqDlDi+tWeQR858BpwDvGxf9KHAJQzkv3UH0\n", "0v7Gvypfwh1vxtruSmq6s0Uu5udt8PTSHUQv7S/JR6qq5UV6tLGFLWZJG0vyFIY7+3wAuHVcXF4a\n", "t7+FnWNO8lTg5dz5IkbHTBZK6uXnGS5zuZt1Uxl4adz2FnbEnOSzwNOBq6rqts3Wl5bNeLbbw2pR\n", "3+RLbGGPY2bYyfdJS1k6qI8ynI6tBbPII+bTGKYyLuCO82fnTJdK6mO8AelD8HoyC2dh55gZDpj/\n", "GsPxy0dOnEXq6IypA+iuWeQRs9djlrQjLfIc839PslBXjJKkWSzyiPlm4B4M88vfHBd7uJykhbew\n", "c8xVdXSSezNcLP+oqfNI0lZZ2GJOciZwFvAgYB9wGnAx8KQpc0nS4VrkOeazgVOBz1XV6cCjgK9M\n", "G0mSDt8iF/MtVfV3AEmOqqprGE4/laSFtrBTGcD1Sb4beAfw/iQ3AddNG0mSDt/CHpWxXpI9DHdn\n", "eE9V3brJ6pLU2o4oZknaSRZ5jlmSdiSLWZKasZglqRmLWXOX5H5J3pTkM0kuTfKuJMcnOS7Jldv0\n", "Pfcm+eVDWP/m7Xx9aSOLfLicFlCSAOcD51bVz47LTgTuy3Dzg+1yqHu5t3t96aAcMWveTgdurarX\n", "ri6oqk9U1UXrVxpHzx9Octn48bhx+f3H5ZcnuTLJ45MckeS88fEnkvyrWcMkOX8ctV81nua//rlz\n", "xuUfSPK947KHJHn3+DUfTuJJTdpyjpg1b48ALpthvS8AP1pV30hyPPAnwCnAsxiOV3/FOPq+J3Ay\n", "8ICqOgEgyXcdQp4XVNVNSe4OXJLkrVV10/i6f1lV/zrJS4GXAS8CXgv8YlV9Jsljgd8DnnwI30/a\n", "lMWseZv1T/4jgdckeSTwbYarCAJcAvxRkrsB76iqK8Yb8z44yauBdwHvO4Q8Zyf5yfHzY8fvcwnD\n", "XaXfPC5/I/D2JPcEfhh4y/A74fac0pZyKkPz9kng0TOs92LghvH+dI8BvgOgqi4EfgT4PHBekudU\n", "1ZeBRwIrwAuB180SZDxj9MnAaVV1EnA5B76EbBh+oRwB3FRVJ6/7+KFZvpd0KCxmzVVVfQj4jvXz\n", "uUlOTPKE/VY9Brhx/Py5wK5x3e8DvlhVr2Mo4Ecl+R5gV1W9HXgpw5UGZ3EMQ9HekuRhDJeOXXUE\n", "8Mzx82cBF1bV14BrkzxjzJJxx6W0pSxmTeHpwD8YD5e7CvgN4IbxudWpjt8Dfj7JPoarBq4evnY6\n", "sC/Jx4GfBl4FPBC4IMnlwBuAXznI9/3VJNePH/8LeA+wO8nVwCsZrue96m+BU8fD9/Yw3JEd4OeA\n", "XxhzXQU8bd3XeGSGtoTXypCkZhwxS1IzFrMkNWMxS1IzFrMkNWMxS1IzFrMkNWMxS1Iz/x9F+l34\n", "PYVbkgAAAABJRU5ErkJggg==\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot bar chart based on Class Label\n", "class_group.plot(kind='bar', grid=False)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.feature_extraction import DictVectorizer\n", "\n", "cols_to_retain = ['Body Temperature', 'Skin Cover', 'Gives Birth', 'Aquatic Creature', 'Aerial Creature', 'Has Legs', 'Hibernates']\n", "\n", "X_feature = data[cols_to_retain]\n", "X_dict = X_feature.T.to_dict().values()\n", "\n", "# turn list of dicts into a numpy array\n", "vect = DictVectorizer(sparse=False)\n", "X_vector = vect.fit_transform(X_dict)\n", "\n", "# print the features\n", "# vect.get_feature_names()\n", "\n", "# 0 to 14 is train set\n", "X_Train = X_vector[:-1]\n", "# 15th is test set\n", "X_Test = X_vector[-1:] \n", "\n", "# Used to vectorize the class label\n", "le = LabelEncoder()\n", "y_train = le.fit_transform(data['Class Label'][:-1])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn import tree\n", "\n", "clf = tree.DecisionTreeClassifier(criterion='entropy')\n", "clf = clf.fit(X_Train,y_train)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array(['reptile'], dtype=object)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Predict the test data, not seen earlier\n", "le.inverse_transform(clf.predict(X_Test))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# prediction with the same training set\n", "Train_predict = clf.predict(X_Train)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The model predicted the training set correctly\n", "(Train_predict == y_train).all()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy is: 1.0\n", " precision recall f1-score support\n", "\n", " 0 1.00 1.00 1.00 2\n", " 1 1.00 1.00 1.00 2\n", " 2 1.00 1.00 1.00 3\n", " 3 1.00 1.00 1.00 5\n", " 4 1.00 1.00 1.00 3\n", "\n", "avg / total 1.00 1.00 1.00 15\n", "\n" ] } ], "source": [ "# Metrics related to the DecisionTreeClassifier\n", "from sklearn.metrics import accuracy_score, classification_report\n", "\n", "print 'Accuracy is:', accuracy_score(y_train, Train_predict)\n", "print classification_report(y_train, Train_predict)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "\"\\nimport pydot\\nimport pyparsing\\n\\nimport StringIO\\ndot_data = StringIO.StringIO() \\ntree.export_graphviz(clf, out_file=dot_data) \\ngraph = pydot.graph_from_dot_data(dot_data.getvalue()) \\ngraph.write_png('data/vertebrate/tree.png') \\nfrom IPython.core.display import Image \\nImage(filename='data/vertebrate/tree.png')\\n\"" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "'''\n", "import pydot\n", "import pyparsing\n", "\n", "import StringIO\n", "dot_data = StringIO.StringIO() \n", "tree.export_graphviz(clf, out_file=dot_data) \n", "graph = pydot.graph_from_dot_data(dot_data.getvalue()) \n", "graph.write_png('data/vertebrate/tree.png') \n", "from IPython.core.display import Image \n", "Image(filename='data/vertebrate/tree.png')\n", "'''" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }