{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Scikit-learn Machine Learning Library\n", "![ScikitLog](scikit-learn-logo-small.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Scikit](http://scikit-learn.org/stable/_images/sphx_glr_plot_classifier_comparison_001_carousel.png)\n", "http://scikit-learn.org/stable/index.html" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Preprocessing\n", "### Feature extraction and normalization.\n", "\n", "* Application: Transforming input data such as text for use with machine learning algorithms.\n", "* Modules: preprocessing, feature extraction." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model selection\n", "### Comparing, validating and choosing parameters and models.\n", "\n", "* Goal: Improved accuracy via parameter tuning\n", "* Modules: grid search, cross validation, metrics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Regression\n", "### Predicting a continuous-valued attribute associated with an object.\n", "\n", "* Applications: Drug response, Stock prices.\n", "* Algorithms: SVR, ridge regression, Lasso" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification\n", "### Identifying which category an object belongs to.\n", "\n", "* Applications: Spam detection, Image recognition.\n", "* Algorithms: SVM, nearest neighbors, random forest" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Clustering\n", "### Automatic grouping of similar objects into sets.\n", "\n", "* Applications: Customer segmentation, Grouping experiment outcomes\n", "* Algorithms: k-Means, spectral clustering, mean-shift" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Dimensionality reduction\n", "### Reducing the number of random variables to consider.\n", "\n", "* Applications: Visualization, Increased efficiency\n", "* Algorithms: PCA, feature selection, non-negative matrix factorization" ] }, { 
"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from sklearn import datasets" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "iris = datasets.load_iris()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sklearn.utils.Bunch" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(iris)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[5.1 3.5 1.4 0.2]\n", " [4.9 3. 1.4 0.2]\n", " [4.7 3.2 1.3 0.2]\n", " [4.6 3.1 1.5 0.2]\n", " [5. 3.6 1.4 0.2]\n", " [5.4 3.9 1.7 0.4]\n", " [4.6 3.4 1.4 0.3]\n", " [5. 3.4 1.5 0.2]\n", " [4.4 2.9 1.4 0.2]\n", " [4.9 3.1 1.5 0.1]\n", " [5.4 3.7 1.5 0.2]\n", " [4.8 3.4 1.6 0.2]\n", " [4.8 3. 1.4 0.1]\n", " [4.3 3. 1.1 0.1]\n", " [5.8 4. 1.2 0.2]\n", " [5.7 4.4 1.5 0.4]\n", " [5.4 3.9 1.3 0.4]\n", " [5.1 3.5 1.4 0.3]\n", " [5.7 3.8 1.7 0.3]\n", " [5.1 3.8 1.5 0.3]\n", " [5.4 3.4 1.7 0.2]\n", " [5.1 3.7 1.5 0.4]\n", " [4.6 3.6 1. 0.2]\n", " [5.1 3.3 1.7 0.5]\n", " [4.8 3.4 1.9 0.2]\n", " [5. 3. 1.6 0.2]\n", " [5. 3.4 1.6 0.4]\n", " [5.2 3.5 1.5 0.2]\n", " [5.2 3.4 1.4 0.2]\n", " [4.7 3.2 1.6 0.2]\n", " [4.8 3.1 1.6 0.2]\n", " [5.4 3.4 1.5 0.4]\n", " [5.2 4.1 1.5 0.1]\n", " [5.5 4.2 1.4 0.2]\n", " [4.9 3.1 1.5 0.1]\n", " [5. 3.2 1.2 0.2]\n", " [5.5 3.5 1.3 0.2]\n", " [4.9 3.1 1.5 0.1]\n", " [4.4 3. 1.3 0.2]\n", " [5.1 3.4 1.5 0.2]\n", " [5. 3.5 1.3 0.3]\n", " [4.5 2.3 1.3 0.3]\n", " [4.4 3.2 1.3 0.2]\n", " [5. 3.5 1.6 0.6]\n", " [5.1 3.8 1.9 0.4]\n", " [4.8 3. 1.4 0.3]\n", " [5.1 3.8 1.6 0.2]\n", " [4.6 3.2 1.4 0.2]\n", " [5.3 3.7 1.5 0.2]\n", " [5. 3.3 1.4 0.2]\n", " [7. 3.2 4.7 1.4]\n", " [6.4 3.2 4.5 1.5]\n", " [6.9 3.1 4.9 1.5]\n", " [5.5 2.3 4. 1.3]\n", " [6.5 2.8 4.6 1.5]\n", " [5.7 2.8 4.5 1.3]\n", " [6.3 3.3 4.7 1.6]\n", " [4.9 2.4 3.3 1. 
]\n", " [6.6 2.9 4.6 1.3]\n", " [5.2 2.7 3.9 1.4]\n", " [5. 2. 3.5 1. ]\n", " [5.9 3. 4.2 1.5]\n", " [6. 2.2 4. 1. ]\n", " [6.1 2.9 4.7 1.4]\n", " [5.6 2.9 3.6 1.3]\n", " [6.7 3.1 4.4 1.4]\n", " [5.6 3. 4.5 1.5]\n", " [5.8 2.7 4.1 1. ]\n", " [6.2 2.2 4.5 1.5]\n", " [5.6 2.5 3.9 1.1]\n", " [5.9 3.2 4.8 1.8]\n", " [6.1 2.8 4. 1.3]\n", " [6.3 2.5 4.9 1.5]\n", " [6.1 2.8 4.7 1.2]\n", " [6.4 2.9 4.3 1.3]\n", " [6.6 3. 4.4 1.4]\n", " [6.8 2.8 4.8 1.4]\n", " [6.7 3. 5. 1.7]\n", " [6. 2.9 4.5 1.5]\n", " [5.7 2.6 3.5 1. ]\n", " [5.5 2.4 3.8 1.1]\n", " [5.5 2.4 3.7 1. ]\n", " [5.8 2.7 3.9 1.2]\n", " [6. 2.7 5.1 1.6]\n", " [5.4 3. 4.5 1.5]\n", " [6. 3.4 4.5 1.6]\n", " [6.7 3.1 4.7 1.5]\n", " [6.3 2.3 4.4 1.3]\n", " [5.6 3. 4.1 1.3]\n", " [5.5 2.5 4. 1.3]\n", " [5.5 2.6 4.4 1.2]\n", " [6.1 3. 4.6 1.4]\n", " [5.8 2.6 4. 1.2]\n", " [5. 2.3 3.3 1. ]\n", " [5.6 2.7 4.2 1.3]\n", " [5.7 3. 4.2 1.2]\n", " [5.7 2.9 4.2 1.3]\n", " [6.2 2.9 4.3 1.3]\n", " [5.1 2.5 3. 1.1]\n", " [5.7 2.8 4.1 1.3]\n", " [6.3 3.3 6. 2.5]\n", " [5.8 2.7 5.1 1.9]\n", " [7.1 3. 5.9 2.1]\n", " [6.3 2.9 5.6 1.8]\n", " [6.5 3. 5.8 2.2]\n", " [7.6 3. 6.6 2.1]\n", " [4.9 2.5 4.5 1.7]\n", " [7.3 2.9 6.3 1.8]\n", " [6.7 2.5 5.8 1.8]\n", " [7.2 3.6 6.1 2.5]\n", " [6.5 3.2 5.1 2. ]\n", " [6.4 2.7 5.3 1.9]\n", " [6.8 3. 5.5 2.1]\n", " [5.7 2.5 5. 2. ]\n", " [5.8 2.8 5.1 2.4]\n", " [6.4 3.2 5.3 2.3]\n", " [6.5 3. 5.5 1.8]\n", " [7.7 3.8 6.7 2.2]\n", " [7.7 2.6 6.9 2.3]\n", " [6. 2.2 5. 1.5]\n", " [6.9 3.2 5.7 2.3]\n", " [5.6 2.8 4.9 2. ]\n", " [7.7 2.8 6.7 2. ]\n", " [6.3 2.7 4.9 1.8]\n", " [6.7 3.3 5.7 2.1]\n", " [7.2 3.2 6. 1.8]\n", " [6.2 2.8 4.8 1.8]\n", " [6.1 3. 4.9 1.8]\n", " [6.4 2.8 5.6 2.1]\n", " [7.2 3. 5.8 1.6]\n", " [7.4 2.8 6.1 1.9]\n", " [7.9 3.8 6.4 2. ]\n", " [6.4 2.8 5.6 2.2]\n", " [6.3 2.8 5.1 1.5]\n", " [6.1 2.6 5.6 1.4]\n", " [7.7 3. 6.1 2.3]\n", " [6.3 3.4 5.6 2.4]\n", " [6.4 3.1 5.5 1.8]\n", " [6. 3. 
4.8 1.8]\n", " [6.9 3.1 5.4 2.1]\n", " [6.7 3.1 5.6 2.4]\n", " [6.9 3.1 5.1 2.3]\n", " [5.8 2.7 5.1 1.9]\n", " [6.8 3.2 5.9 2.3]\n", " [6.7 3.3 5.7 2.5]\n", " [6.7 3. 5.2 2.3]\n", " [6.3 2.5 5. 1.9]\n", " [6.5 3. 5.2 2. ]\n", " [6.2 3.4 5.4 2.3]\n", " [5.9 3. 5.1 1.8]]\n" ] } ], "source": [ "print(iris.data)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "numpy.ndarray" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(iris.data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading from external datasets\n", "http://scikit-learn.org/stable/datasets/index.html#external-datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from sklearn import svm\n", "clf = svm.SVC()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n", " decision_function_shape='ovr', degree=3, gamma='auto', kernel='rbf',\n", " max_iter=-1, probability=False, random_state=None, shrinking=True,\n", " tol=0.001, verbose=False)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, y = iris.data, iris.target\n", "clf.fit(X, y) " ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "150" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(iris.data)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Persistent Models using pickle\n", "import pickle\n", "s = pickle.dumps(clf)\n", "clf2 = pickle.loads(s)\n", 
"clf2.predict(X[0:1])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "sklearn.svm.classes.SVC" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(clf2)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Big data pickle\n", "In the specific case of the scikit, it may be more interesting to use joblib’s replacement of pickle (joblib.dump & joblib.load), which is more efficient on big data, but can only pickle to the disk and not to a string:\n", "\n", "\n", "` from sklearn.externals import joblib`\n", "\n", "` joblib.dump(clf, 'filename.pkl') `" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Full Scikit learn process" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. Define Problem.\n", "2. Prepare Data.\n", "3. Evaluate Algorithms.\n", "4. Improve Results.\n", "5. 
Present Results.\n", "\n", "from https://machinelearningmastery.com/machine-learning-in-python-step-by-step/" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Library setup/import\n", "import pandas as pd\n", "from pandas.plotting import scatter_matrix\n", "import matplotlib.pyplot as plt\n", "from sklearn import model_selection\n", "from sklearn.metrics import classification_report\n", "from sklearn.metrics import confusion_matrix\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", "from sklearn.naive_bayes import GaussianNB\n", "from sklearn.svm import SVC" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Load dataset\n", "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data\" #UCI ml archive\n", "names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']\n", "dataset = pd.read_csv(url, names=names)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(150, 5)\n" ] } ], "source": [ "# Summarizing data\n", "\n", "# shape\n", "print(dataset.shape)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n",
" <tr><th></th><th>sepal-length</th><th>sepal-width</th><th>petal-length</th><th>petal-width</th><th>class</th></tr>\n",
" </thead>\n", " <tbody>\n",
" <tr><th>0</th><td>5.1</td><td>3.5</td><td>1.4</td><td>0.2</td><td>Iris-setosa</td></tr>\n",
" <tr><th>1</th><td>4.9</td><td>3.0</td><td>1.4</td><td>0.2</td><td>Iris-setosa</td></tr>\n",
" <tr><th>2</th><td>4.7</td><td>3.2</td><td>1.3</td><td>0.2</td><td>Iris-setosa</td></tr>\n",
" <tr><th>3</th><td>4.6</td><td>3.1</td><td>1.5</td><td>0.2</td><td>Iris-setosa</td></tr>\n",
" <tr><th>4</th><td>5.0</td><td>3.6</td><td>1.4</td><td>0.2</td><td>Iris-setosa</td></tr>\n",
" </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " sepal-length sepal-width petal-length petal-width class\n", "0 5.1 3.5 1.4 0.2 Iris-setosa\n", "1 4.9 3.0 1.4 0.2 Iris-setosa\n", "2 4.7 3.2 1.3 0.2 Iris-setosa\n", "3 4.6 3.1 1.5 0.2 Iris-setosa\n", "4 5.0 3.6 1.4 0.2 Iris-setosa" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.head()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n",
" <tr><th></th><th>sepal-length</th><th>sepal-width</th><th>petal-length</th><th>petal-width</th></tr>\n",
" </thead>\n", " <tbody>\n",
" <tr><th>count</th><td>150.000000</td><td>150.000000</td><td>150.000000</td><td>150.000000</td></tr>\n",
" <tr><th>mean</th><td>5.843333</td><td>3.054000</td><td>3.758667</td><td>1.198667</td></tr>\n",
" <tr><th>std</th><td>0.828066</td><td>0.433594</td><td>1.764420</td><td>0.763161</td></tr>\n",
" <tr><th>min</th><td>4.300000</td><td>2.000000</td><td>1.000000</td><td>0.100000</td></tr>\n",
" <tr><th>25%</th><td>5.100000</td><td>2.800000</td><td>1.600000</td><td>0.300000</td></tr>\n",
" <tr><th>50%</th><td>5.800000</td><td>3.000000</td><td>4.350000</td><td>1.300000</td></tr>\n",
" <tr><th>75%</th><td>6.400000</td><td>3.300000</td><td>5.100000</td><td>1.800000</td></tr>\n",
" <tr><th>max</th><td>7.900000</td><td>4.400000</td><td>6.900000</td><td>2.500000</td></tr>\n",
" </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " sepal-length sepal-width petal-length petal-width\n", "count 150.000000 150.000000 150.000000 150.000000\n", "mean 5.843333 3.054000 3.758667 1.198667\n", "std 0.828066 0.433594 1.764420 0.763161\n", "min 4.300000 2.000000 1.000000 0.100000\n", "25% 5.100000 2.800000 1.600000 0.300000\n", "50% 5.800000 3.000000 4.350000 1.300000\n", "75% 6.400000 3.300000 5.100000 1.800000\n", "max 7.900000 4.400000 6.900000 2.500000" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Stat summary\n", "dataset.describe()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "class\n", "Iris-setosa 50\n", "Iris-versicolor 50\n", "Iris-virginica 50\n", "dtype: int64\n" ] } ], "source": [ "# Class Distribution\n", "print(dataset.groupby('class').size()) # notice aggregration similar to SQL" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualization\n", "# univariate plots, that is, plots of each individual variable.\n", "# box and whisker plots\n", "dataset.plot(kind='box', subplots=True, layout=(2,2), sharex=False, sharey=False)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# histograms\n", "dataset.hist()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Multivariate Plots" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# scatter plot matrix\n", "scatter_matrix(dataset)\n", "plt.show()" ] }, { 
"cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# Note the diagonal grouping of some pairs of attributes. This suggests a high correlation and a predictable relationship." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Algorithm Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a Validation Dataset" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "# Split-out validation dataset\n", "# We will split the loaded dataset into two, \n", "# 80% of which we will use to train our models and \n", "# 20% that we will hold back as a validation dataset.\n", "array = dataset.values\n", "X = array[:,0:4] # our attributes\n", "Y = array[:,4] # our Class\n", "validation_size = 0.20\n", "seed = 7\n", "X_train, X_validation, Y_train, Y_validation = model_selection.train_test_split(X, Y, test_size=validation_size, random_state=seed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# WE have now have training data in the X_train and Y_train for preparing models and a X_validation and Y_validation sets that we can use later." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test Harness\n", "We will use 10-fold cross validation to estimate accuracy." 
] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# Test options and evaluation metric\n", "seed = 7 # fixed seed for reproducible pseudo-random numbers; any constant works\n", "scoring = 'accuracy'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Build Models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Let’s evaluate 6 different algorithms:\n", "\n", "* Logistic Regression (LR)\n", "* Linear Discriminant Analysis (LDA)\n", "* K-Nearest Neighbors (KNN)\n", "* Classification and Regression Trees (CART)\n", "* Gaussian Naive Bayes (NB)\n", "* Support Vector Machines (SVM)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LR: 0.966667 (0.040825)\n", "LDA: 0.975000 (0.038188)\n", "KNN: 0.983333 (0.033333)\n", "CART: 0.966667 (0.040825)\n", "NB: 0.975000 (0.053359)\n", "SVM: 0.991667 (0.025000)\n" ] } ], "source": [ "# Spot Check Algorithms\n", "models = []\n", "models.append(('LR', LogisticRegression()))\n", "models.append(('LDA', LinearDiscriminantAnalysis()))\n", "models.append(('KNN', KNeighborsClassifier()))\n", "models.append(('CART', DecisionTreeClassifier()))\n", "models.append(('NB', GaussianNB()))\n", "models.append(('SVM', SVC()))\n", "# evaluate each model in turn\n", "results = []\n", "names = []\n", "for name, model in models:\n", "\t# folds are not shuffled, so random_state is omitted (recent scikit-learn rejects it when shuffle=False)\n", "\tkfold = model_selection.KFold(n_splits=10)\n", "\tcv_results = model_selection.cross_val_score(model, X_train, Y_train, cv=kfold, scoring=scoring)\n", "\tresults.append(cv_results)\n", "\tnames.append(name)\n", "\tmsg = \"%s: %f (%f)\" % (name, cv_results.mean(), cv_results.std())\n", "\tprint(msg)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# plot comparison\n", "# Compare 
Algorithms\n", "fig = plt.figure()\n", "fig.suptitle('Algorithm Comparison')\n", "ax = fig.add_subplot(111)\n", "plt.boxplot(results)\n", "ax.set_xticklabels(names)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Make Predictions\n", "\n", "KNN was among the most accurate models that we tested (mean accuracy 0.983, second only to SVM at 0.992 in the results above). Now we want to get an idea of the accuracy of the KNN model on our validation set.\n", "\n", "This will give us an independent final check on the accuracy of the chosen model. It is valuable to keep a validation set in case you made a slip during training, such as overfitting to the training set or a data leak. Both will result in an overly optimistic result.\n", "\n", "We can run the KNN model directly on the validation set and summarize the results as a final accuracy score, a confusion matrix and a classification report." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9\n", "[[ 7 0 0]\n", " [ 0 11 1]\n", " [ 0 2 9]]\n", " precision recall f1-score support\n", "\n", " Iris-setosa 1.00 1.00 1.00 7\n", "Iris-versicolor 0.85 0.92 0.88 12\n", " Iris-virginica 0.90 0.82 0.86 11\n", "\n", " avg / total 0.90 0.90 0.90 30\n", "\n" ] } ], "source": [ "# Make predictions on validation dataset\n", "knn = KNeighborsClassifier()\n", "knn.fit(X_train, Y_train)\n", "predictions = knn.predict(X_validation)\n", "print(accuracy_score(Y_validation, predictions))\n", "print(confusion_matrix(Y_validation, predictions))\n", "print(classification_report(Y_validation, predictions))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the accuracy is 0.9 or 90%. The confusion matrix shows the three errors made: one Iris-versicolor sample predicted as Iris-virginica and two Iris-virginica samples predicted as Iris-versicolor. Finally, the classification report provides a breakdown of each class by precision, recall, f1-score and support, showing excellent results (granted the validation dataset was small)." 
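] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick check of the arithmetic, the accuracy and the error count can be recovered from the confusion matrix alone. This is a sketch that retypes the matrix printed above by hand (the name `cm` is illustrative); rows are true classes and columns are predicted classes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Confusion matrix copied from the output above (rows: true class, columns: predicted class)\n", "cm = np.array([[7, 0, 0], [0, 11, 1], [0, 2, 9]])\n", "correct = np.trace(cm) # diagonal entries are the correct predictions\n", "total = cm.sum() # all validation samples\n", "print(total - correct) # misclassified samples: 3\n", "print(correct / total) # accuracy: 0.9" 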
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }