{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [
  "### ML.NET Binary Classification\n",
  "Creates a binary classification model to predict the quality of wine using 11 physicochemical features. Uses the DataFrame API to read the raw data and prepare it." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### NuGet package installation" ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [
  { "data": { "text/html": [ "Installing package Microsoft.ML, version 1.4.0................done!" ] }, "metadata": {}, "output_type": "display_data" },
  { "data": { "text/html": [ "Successfully added reference to package Microsoft.ML, version 1.4.0" ] }, "metadata": {}, "output_type": "display_data" },
  { "data": { "text/html": [ "Installing package XPlot.Plotly, version 3.0.1........done!" ] }, "metadata": {}, "output_type": "display_data" },
  { "data": { "text/html": [ "Successfully added reference to package XPlot.Plotly, version 3.0.1" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [
  "#r \"nuget:Microsoft.ML, 1.4.0\"\n",
  "#r \"nuget:XPlot.Plotly, 3.0.1\"" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Namespaces" ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
  "using Microsoft.ML;\n",
  "using Microsoft.ML.Data;\n",
  "using Microsoft.ML.Trainers;\n",
  "using Microsoft.ML.Transforms;\n",
  "using XPlot.Plotly;" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Input class definition" ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
  "// One row of the wine-quality file. LoadColumn(n) binds a field to the\n",
  "// n-th (zero-based) column: eleven measurements followed by the quality score.\n",
  "public class BinaryClassificationData\n",
  "{\n",
  "    [LoadColumn(0)]\n",
  "    public float FixedAcidity;\n",
  "\n",
  "    [LoadColumn(1)]\n",
  "    public float VolatileAcidity;\n",
  "\n",
  "    [LoadColumn(2)]\n",
  "    public float CitricAcid;\n",
  "\n",
  "    [LoadColumn(3)]\n",
  "    public float ResidualSugar;\n",
  "\n",
  "    [LoadColumn(4)]\n",
  "    public float Chlorides;\n",
  "\n",
  "    [LoadColumn(5)]\n",
  "    public float FreeSulfurDioxide;\n",
  "\n",
  "    [LoadColumn(6)]\n",
  "    public float TotalSulfurDioxide;\n",
  "\n",
  "    [LoadColumn(7)]\n",
  "    public float Density;\n",
  "\n",
  "    [LoadColumn(8)]\n",
  "    public float Ph;\n",
  "\n",
  "    [LoadColumn(9)]\n",
  "    public float Sulphates;\n",
  "\n",
  "    [LoadColumn(10)]\n",
  "    public float Alcohol;\n",
  "\n",
  "    [LoadColumn(11)]\n",
  "    public float Quality;\n",
  "}\n",
  "\n",
  "// Adds the binary target derived from Quality: true when Quality is greater than 5.\n",
  "public class RichBinaryClassificationData: BinaryClassificationData\n",
  "{\n",
  "    public bool Label => Quality > 5;\n",
  "}" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Output class definition" ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
  "// Prediction row: the incoming Label, the model's PredictedLabel,\n",
  "// and the prediction expressed as 1 (true) or 0 (false).\n",
  "public class BinaryClassificationPrediction\n",
  "{\n",
  "    public bool Label;\n",
  "\n",
  "    [ColumnName(\"PredictedLabel\")]\n",
  "    public bool PredictedLabel;\n",
  "\n",
  "    public int LabelAsNumber => PredictedLabel ? 1 : 0;\n",
  "}" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Bring in the DataFrame" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [
  { "data": { "text/html": [ "Installing package Microsoft.Data.Analysis, version 0.2.0......done!" ] }, "metadata": {}, "output_type": "display_data" },
  { "data": { "text/html": [ "Successfully added reference to package Microsoft.Data.Analysis, version 0.2.0" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [
  "#r \"nuget:Microsoft.Data.Analysis,0.2.0\"\n",
  "using Microsoft.Data.Analysis;\n",
  "using Microsoft.AspNetCore.Html;\n",
  "\n",
  "// Convenient custom formatter: renders a DataFrame as an HTML table with an\n",
  "// extra leading index column. Only the first `take` rows are written.\n",
  "// The type arguments are spelled out because they cannot be inferred.\n",
  "Formatter<DataFrame>.Register((df, writer) =>\n",
  "{\n",
  "    // Header row: \"index\" followed by one cell per column name.\n",
  "    var headers = new List<IHtmlContent>();\n",
  "    headers.Add(th(i(\"index\")));\n",
  "    headers.AddRange(df.Columns.Select(c => (IHtmlContent) th(c.Name)));\n",
  "\n",
  "    // Body rows: each entry holds the cells of one table row.\n",
  "    var rows = new List<List<IHtmlContent>>();\n",
  "    var take = 5;\n",
  "    for (var i = 0; i < Math.Min(take, df.Rows.Count); i++)\n",
  "    {\n",
  "        var cells = new List<IHtmlContent>();\n",
  "        cells.Add(td(i));\n",
  "        foreach (var obj in df.Rows[i])\n",
  "        {\n",
  "            cells.Add(td(obj));\n",
  "        }\n",
  "        rows.Add(cells);\n",
  "    }\n",
  "\n",
  "    var t = table(\n",
  "        thead(\n",
  "            headers),\n",
  "        tbody(\n",
  "            rows.Select(\n",
  "                r => tr(r))));\n",
  "\n",
  "    writer.Write(t);\n",
  "}, \"text/html\");" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Read the raw data" ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
indexFixedAcidityVolatileAcidityCitricAcidResidualSugarChloridesFreeSulfurDioxideTotalSulfurDioxideDensityPhSulphatesAlcoholQuality
070.270.3620.70.045451701.00130.458.86
16.30.30.341.60.049141320.9943.30.499.56
28.10.280.46.90.0530970.99513.260.4410.16
37.20.230.328.50.058471860.99563.190.49.96
47.20.230.328.50.058471860.99563.190.49.96
" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [
  "// Read the semicolon-separated training file, naming the twelve columns explicitly.\n",
  "var trainingData = DataFrame.LoadCsv(\n",
  "    \"./WineQuality_White_Train.csv\",\n",
  "    separator: ';',\n",
  "    columnNames: new string[]\n",
  "    {\n",
  "        \"FixedAcidity\", \"VolatileAcidity\", \"CitricAcid\", \"ResidualSugar\",\n",
  "        \"Chlorides\", \"FreeSulfurDioxide\", \"TotalSulfurDioxide\", \"Density\",\n",
  "        \"Ph\", \"Sulphates\", \"Alcohol\", \"Quality\"\n",
  "    });\n",
  "\n",
  "// Rendered by the formatter registered for DataFrame in the previous code cell.\n",
  "display(trainingData);" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare the data" ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
indexFixedAcidityVolatileAcidityCitricAcidResidualSugarChloridesFreeSulfurDioxideTotalSulfurDioxideDensityPhSulphatesAlcoholQualityLabel
070.270.3620.70.045451701.00130.458.86True
16.30.30.341.60.049141320.9943.30.499.56True
28.10.280.46.90.0530970.99513.260.4410.16True
37.20.230.328.50.058471860.99563.190.49.96True
47.20.230.328.50.058471860.99563.190.49.96True
" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [
  "// Derive the boolean Label column (Quality of 6 or more) and append it to the frame.\n",
  "var labelCol = trainingData[\"Quality\"].ElementwiseGreaterThanOrEqual(6);\n",
  "labelCol.SetName(\"Label\");\n",
  "trainingData.Columns.Add(labelCol);\n",
  "\n",
  "// Quality stays in the frame because later cells still need it;\n",
  "// dropping it would otherwise work:\n",
  "// trainingData.Columns.Remove(trainingData[\"Quality\"]);\n",
  "\n",
  "display(trainingData);" ] },
 { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
  "var mlContext = new MLContext(seed: null);\n",
  "\n",
  "// Step 1: replace missing FixedAcidity values with the column mean.\n",
  "var imputeFixedAcidity = mlContext.Transforms.ReplaceMissingValues(\n",
  "    outputColumnName: \"FixedAcidity\",\n",
  "    replacementMode: MissingValueReplacingEstimator.ReplacementMode.Mean);\n",
  "\n",
  "// Step 2: gather the eleven measurements into the \"Features\" column.\n",
  "var assembleFeatures = mlContext.Transforms.Concatenate(\n",
  "    \"Features\",\n",
  "    \"FixedAcidity\", \"VolatileAcidity\", \"CitricAcid\", \"ResidualSugar\",\n",
  "    \"Chlorides\", \"FreeSulfurDioxide\", \"TotalSulfurDioxide\", \"Density\",\n",
  "    \"Ph\", \"Sulphates\", \"Alcohol\");\n",
  "\n",
  "// Step 3: the logistic regression trainer, created with its default arguments.\n",
  "var trainer = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression();\n",
  "\n",
  "// Define the pipeline by chaining the three steps in order.\n",
  "var pipeline = imputeFixedAcidity\n",
  "    .Append(assembleFeatures)\n",
  "    .Append(trainer);" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model" ] },
 { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
  "// Fit every step of the pipeline on the prepared training frame.\n",
  "var model = pipeline.Fit(trainingData);" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate the model" ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LogLossLogLossReductionEntropyAreaUnderRocCurveAccuracyPositivePrecisionPositiveRecallNegativePrecisionNegativeRecallF1ScoreAreaUnderPrecisionRecallCurveConfusionMatrix
0.74520452598977840.19740040333434760.9284885378533480.79096305668454760.73908603523104420.7645249487354750.87042801556420240.66397578203834510.48849294729027470.81404657933042210.8749940309174482{ Microsoft.ML.Data.ConfusionMatrix: PerClassPrecision: [ 0.764524948735475, 0.6639757820383451 ], PerClassRecall: [ 0.8704280155642024, 0.4884929472902747 ], Counts: [ [ 2237, 333 ], [ 689, 658 ] ], NumberOfClasses: 2 }
" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [
  "// Load the raw test data (semicolon-separated, first line is a header).\n",
  "var testData = mlContext.Data.LoadFromTextFile<BinaryClassificationData>(\n",
  "    \"./WineQuality_White_Test.csv\",\n",
  "    separatorChar: ';',\n",
  "    hasHeader: true);\n",
  "\n",
  "// Calculate the Label (IDataView to IEnumerable to IDataView).\n",
  "// The rows must come from testData: enumerating trainingData here would\n",
  "// score the model on the very rows it was fitted on.\n",
  "var stronglyTypedTestData = mlContext.Data.CreateEnumerable<RichBinaryClassificationData>(testData, false);\n",
  "testData = mlContext.Data.LoadFromEnumerable(stronglyTypedTestData);\n",
  "\n",
  "// Score the test data and calculate the metrics.\n",
  "var scoredData = model.Transform(testData);\n",
  "var qualityMetrics = mlContext.BinaryClassification.Evaluate(scoredData);\n",
  "display(qualityMetrics);" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize the quality metrics" ] },
 { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [
  "// Bar labels, listed in the same order as the values below.\n",
  "var metricNames = new string[]\n",
  "{\n",
  "    \"Log Loss\", \"Log Loss Reduction\", \"Entropy\", \"Area Under Curve\",\n",
  "    \"Accuracy\", \"Positive Recall\", \"Negative Recall\", \"F1 Score\"\n",
  "};\n",
  "\n",
  "// Values taken from the evaluation result, aligned with metricNames.\n",
  "var metricValues = new double[]\n",
  "{\n",
  "    qualityMetrics.LogLoss,\n",
  "    qualityMetrics.LogLossReduction,\n",
  "    qualityMetrics.Entropy,\n",
  "    qualityMetrics.AreaUnderRocCurve,\n",
  "    qualityMetrics.Accuracy,\n",
  "    qualityMetrics.PositiveRecall,\n",
  "    qualityMetrics.NegativeRecall,\n",
  "    qualityMetrics.F1Score\n",
  "};\n",
  "\n",
  "// Horizontal bars: values on the x axis, metric names on the y axis.\n",
  "var graph = new Graph.Bar();\n",
  "graph.x = metricValues;\n",
  "graph.y = metricNames;\n",
  "graph.orientation = \"h\";\n",
  "graph.marker = new Graph.Marker { color = \"darkred\" };\n",
  "\n",
  "var layout = new Layout.Layout { title = \"Quality Metrics\" };\n",
  "\n",
  "var chart = Chart.Plot(graph);\n",
  "chart.WithLayout(layout);\n",
  "\n",
  "display(chart);" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "### Drawing the Confusion Matrix\n",
  "\n",
  "#### Default" ] },
 { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
PerClassPrecisionPerClassRecallCountsNumberOfClasses
[ 0.764524948735475, 0.6639757820383451 ][ 0.8704280155642024, 0.4884929472902747 ][ [ 2237, 333 ], [ 689, 658 ] ]2
" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [ "display(qualityMetrics.ConfusionMatrix);" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "#### Custom Formatter for Binary Confusion Matrix" ] },
 { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
  "// Renders a two-class ConfusionMatrix as an HTML table with the predicted\n",
  "// classes across the columns and the actual classes down the rows.\n",
  "// The type arguments are spelled out because they cannot be inferred.\n",
  "Formatter<ConfusionMatrix>.Register((df, writer) =>\n",
  "{\n",
  "    // Inline styles shared by the rows and cells below.\n",
  "    const string rowStyle = \"background-color: transparent\";\n",
  "    const string cornerStyle = \"text-align: center; background-color: transparent\";\n",
  "    const string axisStyle = \"border: 1px solid black; text-align: center; padding: 24px; background-color: lightsteelblue\";\n",
  "    const string columnHeaderStyle = \"border: 1px solid black; padding: 24px; background-color: #E3EAF3\";\n",
  "    const string rowHeaderStyle = \"border: 1px solid black; text-align: center; padding: 24px; background-color: #E3EAF3\";\n",
  "    const string countStyle = \"border: 1px solid black; padding: 24px\";\n",
  "\n",
  "    var rows = new List<IHtmlContent>();\n",
  "\n",
  "    // Row 1: the total of all four counts and the \"Predicted\" axis label.\n",
  "    var cells = new List<IHtmlContent>();\n",
  "    var n = df.Counts[0][0] + df.Counts[0][1] + df.Counts[1][0] + df.Counts[1][1];\n",
  "    cells.Add(td[rowspan: 2, colspan: 2, style: cornerStyle](\"n = \" + n));\n",
  "    cells.Add(td[colspan: 2, style: axisStyle](b(\"Predicted\")));\n",
  "    rows.Add(tr[style: rowStyle](cells));\n",
  "\n",
  "    // Row 2: the predicted-class column headers.\n",
  "    cells = new List<IHtmlContent>();\n",
  "    cells.Add(td[style: columnHeaderStyle](b(\"True\")));\n",
  "    cells.Add(td[style: columnHeaderStyle](b(\"False\")));\n",
  "    rows.Add(tr[style: rowStyle](cells));\n",
  "\n",
  "    // Row 3: the \"Actual\" axis label and the counts from Counts[0].\n",
  "    cells = new List<IHtmlContent>();\n",
  "    cells.Add(td[rowspan: 2, style: axisStyle](b(\"Actual\")));\n",
  "    cells.Add(td[style: rowHeaderStyle](b(\"True\")));\n",
  "    cells.Add(td[style: countStyle](df.Counts[0][0]));\n",
  "    cells.Add(td[style: countStyle](df.Counts[0][1]));\n",
  "    rows.Add(tr[style: rowStyle](cells));\n",
  "\n",
  "    // Row 4: the counts from Counts[1]; styled like the other rows.\n",
  "    cells = new List<IHtmlContent>();\n",
  "    cells.Add(td[style: rowHeaderStyle](b(\"False\")));\n",
  "    cells.Add(td[style: countStyle](df.Counts[1][0]));\n",
  "    cells.Add(td[style: countStyle](df.Counts[1][1]));\n",
  "    rows.Add(tr[style: rowStyle](cells));\n",
  "\n",
  "    var t = table(\n",
  "        tbody(\n",
  "            rows));\n",
  "\n",
  "    writer.Write(t);\n",
  "}, \"text/html\");" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "#### Tadaa" ] },
 { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
n = 3917Predicted
TrueFalse
ActualTrue2237333
False689658
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(qualityMetrics.ConfusionMatrix);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a prediction engine and use it on a random sample from the training data" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
LabelFixedAcidityVolatileAcidityCitricAcidResidualSugarChloridesFreeSulfurDioxideTotalSulfurDioxideDensityPhSulphatesAlcoholQuality
False7.10.370.6710.50.045491550.99753.160.448.75
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
LabelAsNumberLabelPredictedLabel
0FalseFalse
" ] }, "metadata": {}, "output_type": "display_data" } ],
  "source": [
  "// Create prediction engine. The input and output row types are given\n",
  "// explicitly because they cannot be inferred from the model argument.\n",
  "var predictionEngine = mlContext.Model.CreatePredictionEngine<RichBinaryClassificationData, BinaryClassificationPrediction>(model);\n",
  "\n",
  "// Get a random data sample: shuffle the training frame and keep one row.\n",
  "var shuffledData = mlContext.Data.ShuffleRows(trainingData);\n",
  "var rawSample = mlContext.Data.TakeRows(shuffledData, 1);\n",
  "var sample = mlContext.Data.CreateEnumerable<RichBinaryClassificationData>(rawSample, false).First();\n",
  "display(sample);\n",
  "\n",
  "// Predict quality of sample\n",
  "var prediction = predictionEngine.Predict(sample);\n",
  "display(prediction);" ] } ],
 "metadata": {
  "kernelspec": { "display_name": ".NET (C#)", "language": "C#", "name": ".net-csharp" },
  "language_info": { "file_extension": ".cs", "mimetype": "text/x-csharp", "name": "C#", "pygments_lexer": "csharp", "version": "8.0" } },
 "nbformat": 4, "nbformat_minor": 2 }