{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Class Prediction Error Visualizer\n",
    "The `ClassPredictionError` visualizer is a ScoreVisualizer that wraps a scikit-learn classifier. It draws a stacked bar chart with a color-coded breakdown of which classes the model predicted for each actual class in a set of test `X` and `y` values. This gives a quick way to see how well your classifier predicts each class and which classes it confuses with one another.\n",
    "\n",
    "Below is an example that fits a random forest on a synthetic five-class dataset and visualizes its prediction errors on a held-out test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split as tts\n",
    "\n",
    "from yellowbrick.classifier import ClassPredictionError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so Restart & Run All reproduces the same data, model, and figure\n",
    "RANDOM_STATE = 42\n",
    "\n",
    "# Arbitrary, human-readable labels for the five synthetic classes (0-4)\n",
    "CLASS_NAMES = ['apple', 'kiwi', 'pear', 'banana', 'orange']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a synthetic dataset\n",
    "`make_classification` generates 1,000 samples spread over 5 classes, with 3 informative features and a single cluster per class. We hold out 20% of the samples as a test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = make_classification(n_samples=1000, n_classes=5, n_informative=3,\n",
    "                           n_clusters_per_class=1, random_state=RANDOM_STATE)\n",
    "\n",
    "# Every generated class needs a display name, and vice versa\n",
    "assert len(CLASS_NAMES) == len(set(y)), \"CLASS_NAMES must match the number of classes\"\n",
    "\n",
    "# Perform 80/20 training/test split\n",
    "X_train, X_test, y_train, y_test = tts(X, y, test_size=0.20,\n",
    "                                       random_state=RANDOM_STATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit, score, and visualize\n",
    "The visualizer wraps the classifier: `fit` trains it on the training split, `score` uses the test split to build the visualization, and `show` renders the figure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pin n_estimators explicitly: its default differs across scikit-learn versions\n",
    "visualizer = ClassPredictionError(\n",
    "    RandomForestClassifier(n_estimators=100, random_state=RANDOM_STATE),\n",
    "    classes=CLASS_NAMES,\n",
    ")\n",
    "\n",
    "visualizer.fit(X_train, y_train)\n",
    "visualizer.score(X_test, y_test)\n",
    "\n",
    "# The trailing ';' suppresses any echoed return value below the figure\n",
    "visualizer.show();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}