{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict wine quality" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use the latest versions of DataFrame and KotlinDL libraries from [version repository](https://github.com/Kotlin/kotlin-jupyter-libraries)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "pycharm": { "is_executing": true }, "ExecuteTime": { "end_time": "2023-12-05T11:16:19.227175498Z", "start_time": "2023-12-05T11:16:10.985704525Z" } }, "outputs": [], "source": [ "%useLatestDescriptors\n", "%use dataframe, kotlin-dl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read the dataframe from CSV and print the first few lines of it" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:22.897712600Z", "start_time": "2023-12-05T11:16:19.230308412Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
7.4000000.7000000.0000001.9000000.07600011.00000034.0000000.9978003.5100000.5600009.4000005
7.8000000.8800000.0000002.6000000.09800025.00000067.0000000.9968003.2000000.6800009.8000005
7.8000000.7600000.0400002.3000000.09200015.00000054.0000000.9970003.2600000.6500009.8000005
11.2000000.2800000.5600001.9000000.07500017.00000060.0000000.9980003.1600000.5800009.8000006
7.4000000.7000000.0000001.9000000.07600011.00000034.0000000.9978003.5100000.5600009.4000005
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":5,\"ncol\":12,\"columns\":[\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"residual sugar\",\"chlorides\",\"free sulfur dioxide\",\"total sulfur dioxide\",\"density\",\"pH\",\"sulphates\",\"alcohol\",\"quality\"],\"kotlin_dataframe\":[{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"residual sugar\":1.9,\"chlorides\":0.076,\"free sulfur dioxide\":11.0,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"pH\":3.51,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.88,\"citric acid\":0.0,\"residual sugar\":2.6,\"chlorides\":0.098,\"free sulfur dioxide\":25.0,\"total sulfur dioxide\":67.0,\"density\":0.9968,\"pH\":3.2,\"sulphates\":0.68,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.76,\"citric acid\":0.04,\"residual sugar\":2.3,\"chlorides\":0.092,\"free sulfur dioxide\":15.0,\"total sulfur dioxide\":54.0,\"density\":0.997,\"pH\":3.26,\"sulphates\":0.65,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":11.2,\"volatile acidity\":0.28,\"citric acid\":0.56,\"residual sugar\":1.9,\"chlorides\":0.075,\"free sulfur dioxide\":17.0,\"total sulfur dioxide\":60.0,\"density\":0.998,\"pH\":3.16,\"sulphates\":0.58,\"alcohol\":9.8,\"quality\":6},{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"residual sugar\":1.9,\"chlorides\":0.076,\"free sulfur dioxide\":11.0,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"pH\":3.51,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5}]}" }, "metadata": {}, "output_type": "execute_result", "execution_count": 2 } ], "source": [ "val raw_df = DataFrame.readCSV(fileOrUrl = \"winequality-red.csv\", delimiter = ';')\n", "raw_df.head()" ] }, { "cell_type": "markdown", "source": [ "Note: For formatting, the DataFrame needs to be rendered as HTML. This means that when running in Kotlin Notebook, \"Render DataFrame tables natively\" needs to be turned off." 
], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:23.901440325Z", "start_time": "2023-12-05T11:16:22.875220486Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
columnfixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
fixed acidity1.000000-0.2561310.6717030.1147770.093705-0.153794-0.1131810.668047-0.6829780.183006-0.0616680.124052
volatile acidity-0.2561311.000000-0.5524960.0019180.061298-0.0105040.0764700.0220260.234937-0.260987-0.202288-0.390558
citric acid0.671703-0.5524961.0000000.1435770.203823-0.0609780.0355330.364947-0.5419040.3127700.1099030.226373
residual sugar0.1147770.0019180.1435771.0000000.0556100.1870490.2030280.355283-0.0856520.0055270.0420750.013732
chlorides0.0937050.0612980.2038230.0556101.0000000.0055620.0474000.200632-0.2650260.371260-0.221141-0.128907
free sulfur dioxide-0.153794-0.010504-0.0609780.1870490.0055621.0000000.667666-0.0219460.0703770.051658-0.069408-0.050656
total sulfur dioxide-0.1131810.0764700.0355330.2030280.0474000.6676661.0000000.071269-0.0664950.042947-0.205654-0.185100
density0.6680470.0220260.3649470.3552830.200632-0.0219460.0712691.000000-0.3416990.148506-0.496180-0.174919
pH-0.6829780.234937-0.541904-0.085652-0.2650260.070377-0.066495-0.3416991.000000-0.1966480.205633-0.057731
sulphates0.183006-0.2609870.3127700.0055270.3712600.0516580.0429470.148506-0.1966481.0000000.0935950.251397
alcohol-0.061668-0.2022880.1099030.042075-0.221141-0.069408-0.205654-0.4961800.2056330.0935951.0000000.476166
quality0.124052-0.3905580.2263730.013732-0.128907-0.050656-0.185100-0.174919-0.0577310.2513970.4761661.000000
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":12,\"ncol\":13,\"columns\":[\"column\",\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"residual sugar\",\"chlorides\",\"free sulfur dioxide\",\"total sulfur dioxide\",\"density\",\"pH\",\"sulphates\",\"alcohol\",\"quality\"],\"kotlin_dataframe\":[{\"column\":\"fixed acidity\",\"fixed acidity\":1.0,\"volatile acidity\":-0.2561308947703819,\"citric acid\":0.6717034347641041,\"residual sugar\":0.1147767244949209,\"chlorides\":0.09370518632130498,\"free sulfur dioxide\":-0.15379419286482485,\"total sulfur dioxide\":-0.11318144304548039,\"density\":0.6680472921189711,\"pH\":-0.6829781945685299,\"sulphates\":0.18300566393215348,\"alcohol\":-0.06166827062815111,\"quality\":0.1240516491132247},{\"column\":\"volatile acidity\",\"fixed acidity\":-0.2561308947703819,\"volatile acidity\":1.0,\"citric acid\":-0.5524956845595839,\"residual sugar\":0.001917881962790698,\"chlorides\":0.061297772476461614,\"free sulfur dioxide\":-0.010503827006591856,\"total sulfur dioxide\":0.07647000482092836,\"density\":0.022026232195215885,\"pH\":0.23493729440739436,\"sulphates\":-0.26098668528329055,\"alcohol\":-0.20228802715325686,\"quality\":-0.3905577802640094},{\"column\":\"citric acid\",\"fixed acidity\":0.6717034347641041,\"volatile acidity\":-0.5524956845595839,\"citric acid\":1.0,\"residual sugar\":0.14357716157031483,\"chlorides\":0.2038229138290425,\"free sulfur dioxide\":-0.06097812919230497,\"total sulfur dioxide\":0.035533023931161666,\"density\":0.364947175211252,\"pH\":-0.5419041447395132,\"sulphates\":0.31277004385441737,\"alcohol\":0.10990324664156755,\"quality\":0.2263725143180432},{\"column\":\"residual sugar\",\"fixed acidity\":0.1147767244949209,\"volatile acidity\":0.001917881962790698,\"citric acid\":0.14357716157031483,\"residual sugar\":1.0,\"chlorides\":0.05560953520353218,\"free sulfur dioxide\":0.18704899510428666,\"total sulfur 
dioxide\":0.2030278816971015,\"density\":0.35528337098337653,\"pH\":-0.08565242221887161,\"sulphates\":0.005527121339138363,\"alcohol\":0.04207543720973116,\"quality\":0.013731637340066346},{\"column\":\"chlorides\",\"fixed acidity\":0.09370518632130498,\"volatile acidity\":0.061297772476461614,\"citric acid\":0.2038229138290425,\"residual sugar\":0.05560953520353218,\"chlorides\":1.0,\"free sulfur dioxide\":0.005562147004781117,\"total sulfur dioxide\":0.04740046825907533,\"density\":0.200632326641512,\"pH\":-0.2650261311732279,\"sulphates\":0.371260481285427,\"alcohol\":-0.22114054478828302,\"quality\":-0.12890655993005312},{\"column\":\"free sulfur dioxide\",\"fixed acidity\":-0.15379419286482485,\"volatile acidity\":-0.010503827006591856,\"citric acid\":-0.06097812919230497,\"residual sugar\":0.18704899510428666,\"chlorides\":0.005562147004781117,\"free sulfur dioxide\":1.0,\"total sulfur dioxide\":0.6676664504810212,\"density\":-0.021945831163489242,\"pH\":0.07037749850494217,\"sulphates\":0.051657571842828584,\"alcohol\":-0.06940835356499997,\"quality\":-0.05065605724427643},{\"column\":\"total sulfur dioxide\",\"fixed acidity\":-0.11318144304548039,\"volatile acidity\":0.07647000482092836,\"citric acid\":0.035533023931161666,\"residual sugar\":0.2030278816971015,\"chlorides\":0.04740046825907533,\"free sulfur dioxide\":0.6676664504810212,\"total sulfur dioxide\":1.0,\"density\":0.07126947620310328,\"pH\":-0.06649455901285606,\"sulphates\":0.04294683623953844,\"alcohol\":-0.20565394374367177,\"quality\":-0.18510028892653843},{\"column\":\"density\",\"fixed acidity\":0.6680472921189711,\"volatile acidity\":0.022026232195215885,\"citric acid\":0.364947175211252,\"residual sugar\":0.35528337098337653,\"chlorides\":0.200632326641512,\"free sulfur dioxide\":-0.021945831163489242,\"total sulfur 
dioxide\":0.07126947620310328,\"density\":1.0,\"pH\":-0.3416993347850301,\"sulphates\":0.14850641172078524,\"alcohol\":-0.4961797702417023,\"quality\":-0.1749192277833492},{\"column\":\"pH\",\"fixed acidity\":-0.6829781945685299,\"volatile acidity\":0.23493729440739436,\"citric acid\":-0.5419041447395132,\"residual sugar\":-0.08565242221887161,\"chlorides\":-0.2650261311732279,\"free sulfur dioxide\":0.07037749850494217,\"total sulfur dioxide\":-0.06649455901285606,\"density\":-0.3416993347850301,\"pH\":1.0,\"sulphates\":-0.1966476023043703,\"alcohol\":0.20563250850549894,\"quality\":-0.0577313912053823},{\"column\":\"sulphates\",\"fixed acidity\":0.18300566393215348,\"volatile acidity\":-0.26098668528329055,\"citric acid\":0.31277004385441737,\"residual sugar\":0.005527121339138363,\"chlorides\":0.371260481285427,\"free sulfur dioxide\":0.051657571842828584,\"total sulfur dioxide\":0.04294683623953844,\"density\":0.14850641172078524,\"pH\":-0.1966476023043703,\"sulphates\":1.0,\"alcohol\":0.09359475041046762,\"quality\":0.25139707906926206},{\"column\":\"alcohol\",\"fixed acidity\":-0.06166827062815111,\"volatile acidity\":-0.20228802715325686,\"citric acid\":0.10990324664156755,\"residual sugar\":0.04207543720973116,\"chlorides\":-0.22114054478828302,\"free sulfur dioxide\":-0.06940835356499997,\"total sulfur dioxide\":-0.20565394374367177,\"density\":-0.4961797702417023,\"pH\":0.20563250850549894,\"sulphates\":0.09359475041046762,\"alcohol\":1.0,\"quality\":0.47616632400114156},{\"column\":\"quality\",\"fixed acidity\":0.1240516491132247,\"volatile acidity\":-0.3905577802640094,\"citric acid\":0.2263725143180432,\"residual sugar\":0.013731637340066346,\"chlorides\":-0.12890655993005312,\"free sulfur dioxide\":-0.05065605724427643,\"total sulfur dioxide\":-0.18510028892653843,\"density\":-0.1749192277833492,\"pH\":-0.0577313912053823,\"sulphates\":0.25139707906926206,\"alcohol\":0.47616632400114156,\"quality\":1.0}]}" }, "metadata": {}, "output_type": 
"execute_result", "execution_count": 3 } ], "source": [ "raw_df.corr().format { colsOf() }.with { \n", " linearBg(value = it, from = -1.0 to red, to = 1.0 to green)\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the correlation, we can remove some columns, they seem to be insignificant" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:25.803086434Z", "start_time": "2023-12-05T11:16:23.862783060Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidchloridestotal sulfur dioxidedensitysulphatesalcoholquality
7.4000000.7000000.0000000.07600034.0000000.9978000.5600009.4000005
7.8000000.8800000.0000000.09800067.0000000.9968000.6800009.8000005
7.8000000.7600000.0400000.09200054.0000000.9970000.6500009.8000005
11.2000000.2800000.5600000.07500060.0000000.9980000.5800009.8000006
7.4000000.7000000.0000000.07600034.0000000.9978000.5600009.4000005
7.4000000.6600000.0000000.07500040.0000000.9978000.5600009.4000005
7.9000000.6000000.0600000.06900059.0000000.9964000.4600009.4000005
7.3000000.6500000.0000000.06500021.0000000.9946000.47000010.0000007
7.8000000.5800000.0200000.07300018.0000000.9968000.5700009.5000007
7.5000000.5000000.3600000.071000102.0000000.9978000.80000010.5000005
6.7000000.5800000.0800000.09700065.0000000.9959000.5400009.2000005
7.5000000.5000000.3600000.071000102.0000000.9978000.80000010.5000005
5.6000000.6150000.0000000.08900059.0000000.9943000.5200009.9000005
7.8000000.6100000.2900000.11400029.0000000.9974001.5600009.1000005
8.9000000.6200000.1800000.176000145.0000000.9986000.8800009.2000005
8.9000000.6200000.1900000.170000148.0000000.9986000.9300009.2000005
8.5000000.2800000.5600000.092000103.0000000.9969000.75000010.5000007
8.1000000.5600000.2800000.36800056.0000000.9968001.2800009.3000005
7.4000000.5900000.0800000.08600029.0000000.9974000.5000009.0000004
7.9000000.3200000.5100000.34100056.0000000.9969001.0800009.2000006
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":1599,\"ncol\":9,\"columns\":[\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"chlorides\",\"total sulfur dioxide\",\"density\",\"sulphates\",\"alcohol\",\"quality\"],\"kotlin_dataframe\":[{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"chlorides\":0.076,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.88,\"citric acid\":0.0,\"chlorides\":0.098,\"total sulfur dioxide\":67.0,\"density\":0.9968,\"sulphates\":0.68,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.76,\"citric acid\":0.04,\"chlorides\":0.092,\"total sulfur dioxide\":54.0,\"density\":0.997,\"sulphates\":0.65,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":11.2,\"volatile acidity\":0.28,\"citric acid\":0.56,\"chlorides\":0.075,\"total sulfur dioxide\":60.0,\"density\":0.998,\"sulphates\":0.58,\"alcohol\":9.8,\"quality\":6},{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"chlorides\":0.076,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.4,\"volatile acidity\":0.66,\"citric acid\":0.0,\"chlorides\":0.075,\"total sulfur dioxide\":40.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.9,\"volatile acidity\":0.6,\"citric acid\":0.06,\"chlorides\":0.069,\"total sulfur dioxide\":59.0,\"density\":0.9964,\"sulphates\":0.46,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.3,\"volatile acidity\":0.65,\"citric acid\":0.0,\"chlorides\":0.065,\"total sulfur dioxide\":21.0,\"density\":0.9946,\"sulphates\":0.47,\"alcohol\":10.0,\"quality\":7},{\"fixed acidity\":7.8,\"volatile acidity\":0.58,\"citric acid\":0.02,\"chlorides\":0.073,\"total sulfur dioxide\":18.0,\"density\":0.9968,\"sulphates\":0.57,\"alcohol\":9.5,\"quality\":7},{\"fixed acidity\":7.5,\"volatile 
acidity\":0.5,\"citric acid\":0.36,\"chlorides\":0.071,\"total sulfur dioxide\":102.0,\"density\":0.9978,\"sulphates\":0.8,\"alcohol\":10.5,\"quality\":5},{\"fixed acidity\":6.7,\"volatile acidity\":0.58,\"citric acid\":0.08,\"chlorides\":0.097,\"total sulfur dioxide\":65.0,\"density\":0.9959,\"sulphates\":0.54,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":7.5,\"volatile acidity\":0.5,\"citric acid\":0.36,\"chlorides\":0.071,\"total sulfur dioxide\":102.0,\"density\":0.9978,\"sulphates\":0.8,\"alcohol\":10.5,\"quality\":5},{\"fixed acidity\":5.6,\"volatile acidity\":0.615,\"citric acid\":0.0,\"chlorides\":0.089,\"total sulfur dioxide\":59.0,\"density\":0.9943,\"sulphates\":0.52,\"alcohol\":9.9,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.61,\"citric acid\":0.29,\"chlorides\":0.114,\"total sulfur dioxide\":29.0,\"density\":0.9974,\"sulphates\":1.56,\"alcohol\":9.1,\"quality\":5},{\"fixed acidity\":8.9,\"volatile acidity\":0.62,\"citric acid\":0.18,\"chlorides\":0.176,\"total sulfur dioxide\":145.0,\"density\":0.9986,\"sulphates\":0.88,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":8.9,\"volatile acidity\":0.62,\"citric acid\":0.19,\"chlorides\":0.17,\"total sulfur dioxide\":148.0,\"density\":0.9986,\"sulphates\":0.93,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":8.5,\"volatile acidity\":0.28,\"citric acid\":0.56,\"chlorides\":0.092,\"total sulfur dioxide\":103.0,\"density\":0.9969,\"sulphates\":0.75,\"alcohol\":10.5,\"quality\":7},{\"fixed acidity\":8.1,\"volatile acidity\":0.56,\"citric acid\":0.28,\"chlorides\":0.368,\"total sulfur dioxide\":56.0,\"density\":0.9968,\"sulphates\":1.28,\"alcohol\":9.3,\"quality\":5},{\"fixed acidity\":7.4,\"volatile acidity\":0.59,\"citric acid\":0.08,\"chlorides\":0.086,\"total sulfur dioxide\":29.0,\"density\":0.9974,\"sulphates\":0.5,\"alcohol\":9.0,\"quality\":4},{\"fixed acidity\":7.9,\"volatile acidity\":0.32,\"citric acid\":0.51,\"chlorides\":0.341,\"total sulfur 
dioxide\":56.0,\"density\":0.9969,\"sulphates\":1.08,\"alcohol\":9.2,\"quality\":6}]}" }, "metadata": {}, "output_type": "execute_result", "execution_count": 4 } ], "source": [ "val df = raw_df.remove { `free sulfur dioxide` and `residual sugar` and pH }\n", "df" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Predict wine quality: first approach" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:26.957258498Z", "start_time": "2023-12-05T11:16:25.797377618Z" } }, "outputs": [], "source": [ "// Simple converter functions between DataFrame and KotlinDL data representations\n", "\n", "/**\n", " * Converts this DataFrame into a KotlinDL [OnHeapDataset].\n", " *\n", " * @param labelColumnName name of the column used as labels (y); all remaining columns become features (x)\n", " */\n", "fun DataFrame<*>.toOnHeapDataset(labelColumnName: String): OnHeapDataset {\n", "    return OnHeapDataset.create(\n", "        dataframe = this,\n", "        yColumn = labelColumnName\n", "    )\n", "}\n", "\n", "/**\n", " * Creates an [OnHeapDataset] from [dataframe], taking labels from [yColumn] and features from every other column.\n", " *\n", " * Cell values are cast to Float without any conversion, so numeric columns must be converted to Float beforehand.\n", " */\n", "fun OnHeapDataset.Companion.create(\n", "    dataframe: DataFrame<*>,\n", "    yColumn: String\n", "): OnHeapDataset {\n", "    // Features: one FloatArray per row, built from every column except the label column\n", "    fun extractX(): Array<FloatArray> =\n", "        dataframe.remove(yColumn).rows()\n", "            .map { (it.values() as List<Float>).toFloatArray() }.toTypedArray()\n", "\n", "    // Labels: the values of the label column as a flat FloatArray\n", "    fun extractY(): FloatArray =\n", "        dataframe.get { yColumn<Float>() }.toList().toFloatArray()\n", "\n", "    return create(\n", "        ::extractX,\n", "        ::extractY\n", "    )\n", "}" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:27.394202002Z", "start_time": "2023-12-05T11:16:26.797989581Z" } }, "outputs": [], "source": [ "// Convert every numeric column to Float first: the dataset converter casts cell values to Float as-is\n", "val (train, test) = df.convert { colsOf<Number>() }.toFloat()\n", "    .toOnHeapDataset(labelColumnName = \"quality\")\n", "    .split(0.8)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Define a simple neural network with two hidden dense layers and a single-neuron linear output layer" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:29.483880350Z", "start_time": "2023-12-05T11:16:27.172120668Z" } }, "outputs": [], "source": [ "// The input layer is as wide as one feature row of the training set\n", "val inputNeurons = train.x[0].size.toLong()\n", "\n", "// Two tanh hidden layers (10 neurons per input feature each) feeding one linear output neuron\n", "val model = Sequential.of(\n", "    Input(\n", "        inputNeurons,\n", "    ),\n", "    Dense(\n", "        outputSize = (inputNeurons * 10).toInt(),\n", "        activation = Activations.Tanh,\n", "        kernelInitializer = HeNormal(),\n", "        biasInitializer = HeNormal(),\n", "    ),\n", "    Dense(\n", "        outputSize = (inputNeurons * 10).toInt(),\n", "        activation = Activations.Tanh,\n", "        kernelInitializer = HeNormal(),\n", "        biasInitializer = HeNormal(),\n", "    ),\n", "    Dense(\n", "        outputSize = 1,\n", "        activation = Activations.Linear,\n", "        kernelInitializer = HeNormal(),\n", "        biasInitializer = HeNormal(),\n", "    )\n", ")" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:29.682507911Z", "start_time": "2023-12-05T11:16:29.432917424Z" } }, "outputs": [], "source": [ "model.compile(optimizer = Adam(), loss = Losses.MSE, metric = Metrics.MAE)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:29.791145300Z", "start_time": "2023-12-05T11:16:29.626841115Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==============================================================================\n", "Model type: Sequential\n", "______________________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "==============================================================================\n", "input_1(Input) [None, 8] 0 \n", "______________________________________________________________________________\n", "dense_2(Dense) [None, 80] 720 \n", "______________________________________________________________________________\n", "dense_3(Dense) [None, 80] 6480 \n", "______________________________________________________________________________\n", "dense_4(Dense) [None, 1] 81 \n", "______________________________________________________________________________\n", "==============================================================================\n", "Total 
trainable params: 7281\n", "Total frozen params: 0\n", "Total params: 7281\n", "______________________________________________________________________________\n" ] } ], "source": [ "model.printSummary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train it!" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:38.582565633Z", "start_time": "2023-12-05T11:16:29.756612239Z" } }, "outputs": [], "source": [ "val trainHist = model.fit(train, batchSize = 500, epochs=2000)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:38.871304073Z", "start_time": "2023-12-05T11:16:38.583131469Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
epochIndexlossValuemetricValuesvalLossValuevalMetricValues
19960.334877[0.45114806294441223]NaN[NaN]
19970.334841[0.45111772418022156]NaN[NaN]
19980.334805[0.45108696818351746]NaN[NaN]
19990.334768[0.45105621218681335]NaN[NaN]
20000.334732[0.4510253965854645]NaN[NaN]
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":5,\"ncol\":5,\"columns\":[\"epochIndex\",\"lossValue\",\"metricValues\",\"valLossValue\",\"valMetricValues\"],\"kotlin_dataframe\":[{\"epochIndex\":1996,\"lossValue\":0.3348773419857025,\"metricValues\":[0.45114806294441223],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1997,\"lossValue\":0.3348410427570343,\"metricValues\":[0.45111772418022156],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1998,\"lossValue\":0.3348047435283661,\"metricValues\":[0.45108696818351746],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1999,\"lossValue\":0.3347683250904083,\"metricValues\":[0.45105621218681335],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":2000,\"lossValue\":0.33473196625709534,\"metricValues\":[0.4510253965854645],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]}]}" }, "metadata": {}, "output_type": "execute_result", "execution_count": 11 } ], "source": [ "trainHist.epochHistory.toDataFrame().tail()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's check that our network predicts values more or less correctly:" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:39.115254360Z", "start_time": "2023-12-05T11:16:38.837383297Z" } }, "outputs": [ { "data": { "text/plain": [ "5.2477317" ] }, "metadata": {}, "output_type": "execute_result", "execution_count": 12 } ], "source": [ "model.predictSoftly(test.x[9])[0]" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:39.309250279Z", "start_time": "2023-12-05T11:16:38.956283998Z" } }, "outputs": [ { "data": { "text/plain": [ "5.0" ] }, "metadata": {}, "output_type": "execute_result", "execution_count": 13 } ], "source": [ "test.y[9]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Close the model:" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:39.367595540Z", "start_time": "2023-12-05T11:16:39.021784659Z" } }, "outputs": [], "source": [ "model.close()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Predict wine quality: second approach" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:39.858262608Z", "start_time": "2023-12-05T11:16:39.092452824Z" } }, "outputs": [], "source": [ "/** Feature (X) and label (Y) frames of the train and test parts produced by [trainTestSplit]. */\n", "data class TrainTestSplitResult<T>(\n", "    val trainX: DataFrame<T>,\n", "    val trainY: DataFrame<T>,\n", "    val testX: DataFrame<T>,\n", "    val testY: DataFrame<T>,\n", ")\n", "\n", "/**\n", " * Randomly splits the rows of [d] into a train and a test part and separates column [col] from the others.\n", " *\n", " * @param d frame to split\n", " * @param col name of the label column; it forms the Y frames, all other columns form the X frames\n", " * @param trainPart fraction of rows, within 0.0..1.0, that goes to the train part (rounded up)\n", " * @throws IllegalArgumentException if [trainPart] is outside 0.0..1.0\n", " */\n", "fun <T> trainTestSplit(\n", "    d: DataFrame<T>,\n", "    col: String,\n", "    trainPart: Double,\n", "): TrainTestSplitResult<T> {\n", "    // Fail early with a clear message instead of an index error from subList below\n", "    require(trainPart in 0.0..1.0) { \"trainPart must be within 0.0..1.0, but was $trainPart\" }\n", "\n", "    val n = d.count()\n", "    val trainN = ceil(n * trainPart).toInt()\n", "\n", "    // NOTE: the shuffle is not seeded, so every run produces a different split\n", "    val shuffledInd = (0 until n).shuffled()\n", "    val trainInd = shuffledInd.subList(0, trainN)\n", "    val testInd = shuffledInd.subList(trainN, n)\n", "\n", "    val train = d[trainInd]\n", "    val test = d[testInd]\n", "\n", "    val trainX = train.select { all().except(cols(col)) }\n", "    val trainY = train.select(col)\n", "\n", "    val testX = test.select { all().except(cols(col)) }\n", "    val testY = test.select(col)\n", "\n", "    return TrainTestSplitResult(trainX, trainY, testX, testY)\n", "}" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's create and then train the model as we did before" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:41.588501858Z", "start_time": "2023-12-05T11:16:39.773619363Z" } }, "outputs": [], "source": [ "val (trainX, trainY, testX, testY) =\n", "    trainTestSplit(df, \"quality\", 0.8)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:42.028324910Z", "start_time": "2023-12-05T11:16:41.587739397Z" } }, "outputs": [], "source": [ "/** Merges all numeric columns of each row into one FloatArray and returns the rows as a feature matrix. */\n", "fun <T> DataFrame<T>.toX(): Array<FloatArray> =\n", "    merge { colsOf<Number>() }.by { it.map { it.toFloat() }.toFloatArray() }.into(\"X\")\n", "        .get { \"X\"<FloatArray>() }\n", "        .toList()\n", "        .toTypedArray()" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:42.332839608Z", "start_time": "2023-12-05T11:16:41.996806378Z" } }, "outputs": [], "source": [ "/** Returns the values of the `quality` column as a FloatArray of labels. */\n", "fun <T> DataFrame<T>.toY(): FloatArray =\n", "    get { \"quality\"<Number>() }\n", "        .asIterable()\n", "        .map { it.toFloat() }\n", "        .toFloatArray()" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:42.527702465Z", "start_time": "2023-12-05T11:16:42.217032792Z" } }, "outputs": [], "source": [ "val trainXDL = trainX.toX()\n", "val trainYDL = trainY.toY()\n", "val testXDL = testX.toX()\n", "val testYDL = testY.toY()" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:42.774734665Z", "start_time": "2023-12-05T11:16:42.439947039Z" } }, "outputs": [], "source": [ "val trainKotlinDLDataset = OnHeapDataset.create({ trainXDL }, { trainYDL })\n", "val testKotlinDLDataset = OnHeapDataset.create({ testXDL }, { testYDL })" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:43.061367684Z", "start_time": "2023-12-05T11:16:42.605357371Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==============================================================================\n", "Model type: Sequential\n", "______________________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "==============================================================================\n", "input_1(Input) [None, 8] 0 \n", "______________________________________________________________________________\n", "dense_2(Dense) [None, 80] 720 \n", "______________________________________________________________________________\n", "dense_3(Dense) [None, 80] 6480 \n", "______________________________________________________________________________\n", "dense_4(Dense) [None, 1] 81 \n", "______________________________________________________________________________\n", "==============================================================================\n", "Total trainable params: 7281\n", "Total frozen params: 0\n", "Total params: 7281\n", "______________________________________________________________________________\n" ] } ], "source": [ "// Size the input layer from this approach's own feature matrix rather than the first approach's dataset\n", "val inputNeurons = trainXDL[0].size.toLong()\n", "\n", "// Same architecture as in the first approach: two tanh hidden layers and one linear output neuron\n", "val model2 = Sequential.of(\n", "    Input(\n", "        inputNeurons\n", "    ),\n", "    Dense(\n", "        outputSize = (inputNeurons * 10).toInt(),\n", "        activation = Activations.Tanh,\n", "        kernelInitializer = HeNormal(),\n", "        biasInitializer = HeNormal()\n", "    ),\n", "    Dense(\n", "        outputSize = (inputNeurons * 10).toInt(),\n", "        activation = Activations.Tanh,\n", "        kernelInitializer = HeNormal(),\n", "        biasInitializer = HeNormal()\n", "    ),\n", "    Dense(\n", "        outputSize = 1,\n", "        activation = Activations.Linear,\n", "        kernelInitializer = HeNormal(),\n", "        biasInitializer = HeNormal()\n", "    )\n", ")\n", "model2.compile(optimizer = Adam(), loss = Losses.MSE, metric = Metrics.MAE)\n", "model2.printSummary()" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:50.819499613Z", "start_time": "2023-12-05T11:16:42.942896113Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
epochIndexlossValuemetricValuesvalLossValuevalMetricValues
19960.334773[0.45107388496398926]NaN[NaN]
19970.334737[0.45104312896728516]NaN[NaN]
19980.334700[0.45101237297058105]NaN[NaN]
19990.334663[0.4509815275669098]NaN[NaN]
20000.334626[0.45095062255859375]NaN[NaN]
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":5,\"ncol\":5,\"columns\":[\"epochIndex\",\"lossValue\",\"metricValues\",\"valLossValue\",\"valMetricValues\"],\"kotlin_dataframe\":[{\"epochIndex\":1996,\"lossValue\":0.3347732126712799,\"metricValues\":[0.45107388496398926],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1997,\"lossValue\":0.33473655581474304,\"metricValues\":[0.45104312896728516],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1998,\"lossValue\":0.3346997797489166,\"metricValues\":[0.45101237297058105],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1999,\"lossValue\":0.33466312289237976,\"metricValues\":[0.4509815275669098],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":2000,\"lossValue\":0.3346264064311981,\"metricValues\":[0.45095062255859375],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]}]}" }, "metadata": {}, "output_type": "execute_result", "execution_count": 22 } ], "source": [ "// Train on the dataset built from this approach's shuffled split. Fitting on `train` from the first\n", "// approach would leave trainKotlinDLDataset unused and let rows of this approach's test split leak into training.\n", "val trainHist = model2.fit(trainKotlinDLDataset, batchSize = 500, epochs = 2000)\n", "trainHist.epochHistory.toDataFrame().tail()" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:50.980777499Z", "start_time": "2023-12-05T11:16:50.812978010Z" } }, "outputs": [ { "data": { "text/plain": [ "6.6911993" ] }, "metadata": {}, "output_type": "execute_result", "execution_count": 23 } ], "source": [ "model2.predictSoftly(testXDL[9])[0]" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:51.081838805Z", "start_time": "2023-12-05T11:16:50.895468359Z" } }, "outputs": [ { "data": { "text/plain": [ "7.0" ] }, "metadata": {}, "output_type": "execute_result", "execution_count": 24 } ], "source": [ "testYDL[9]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can also compare predicted and ground truth values to ensure predictions are correct" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:51.523573225Z", "start_time": "2023-12-05T11:16:50.955044609Z" } }, "outputs": [], "source": [ "// One prediction per test row, rounded to the nearest integer quality grade\n", "val predicted = testXDL\n", "    .map { features -> round(model2.predictSoftly(features)[0]).toInt() }\n", "    .toColumn(\"predicted\")\n", "\n", "// The matching true labels, as integers so they can be compared with the rounded predictions\n", "val ground_truth = testYDL\n", "    .map { it.toInt() }\n", "    .toColumn(\"ground_truth\")\n", "\n", "val predDf = dataFrameOf(predicted, ground_truth)" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:51.721871863Z", "start_time": "2023-12-05T11:16:51.510838483Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
predictedground_truth
55
55
65
55
66
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":5,\"ncol\":2,\"columns\":[\"predicted\",\"ground_truth\"],\"kotlin_dataframe\":[{\"predicted\":5,\"ground_truth\":5},{\"predicted\":5,\"ground_truth\":5},{\"predicted\":6,\"ground_truth\":5},{\"predicted\":5,\"ground_truth\":5},{\"predicted\":6,\"ground_truth\":6}]}" }, "metadata": {}, "output_type": "execute_result", "execution_count": 26 } ], "source": [ "predDf.head()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:52.792811636Z", "start_time": "2023-12-05T11:16:51.613822929Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
ground_truth5647
32010
48210
51054210
6347805
7020016
80103
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":6,\"ncol\":5,\"columns\":[\"ground_truth\",\"5\",\"6\",\"4\",\"7\"],\"kotlin_dataframe\":[{\"ground_truth\":3,\"5\":2,\"6\":0,\"4\":1,\"7\":0},{\"ground_truth\":4,\"5\":8,\"6\":2,\"4\":1,\"7\":0},{\"ground_truth\":5,\"5\":105,\"6\":42,\"4\":1,\"7\":0},{\"ground_truth\":6,\"5\":34,\"6\":78,\"4\":0,\"7\":5},{\"ground_truth\":7,\"5\":0,\"6\":20,\"4\":0,\"7\":16},{\"ground_truth\":8,\"5\":0,\"6\":1,\"4\":0,\"7\":3}]}" }, "metadata": {}, "output_type": "execute_result", "execution_count": 27 } ], "source": [ "// Confusion table: one row per true grade, one count column per predicted grade.\n", "val ctab = predDf\n", " .groupBy { ground_truth }.pivotCounts(inward = false) { predicted }\n", " .sortBy { ground_truth }\n", "\n", "// Shade every count cell: the closer the predicted grade (column name) is to the\n", "// true grade (row), the brighter the green channel of the background.\n", "ctab.format { drop(1) }.perRowCol { row, col ->\n", " val y = col.name().toInt()\n", " val x = row.ground_truth\n", " val k = 1.0 - abs(x - y) / 10.0\n", " background(RGBColor(50, (50 + k * 200).toInt().toShort(), 50))\n", "}" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:53.265516155Z", "start_time": "2023-12-05T11:16:52.633490769Z" } }, "outputs": [], "source": [ "// Absolute error of every prediction, stored in the column named avg_dev.\n", "val predDf2 = predDf.add(\"avg_dev\") { abs(predicted - ground_truth) }" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:53.466031186Z", "start_time": "2023-12-05T11:16:53.040384705Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
nametypecountuniquenullstopfreqmeanstdminmedianmax
avg_devInt3193002000.3887150.519432002
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":1,\"ncol\":12,\"columns\":[\"name\",\"type\",\"count\",\"unique\",\"nulls\",\"top\",\"freq\",\"mean\",\"std\",\"min\",\"median\",\"max\"],\"kotlin_dataframe\":[{\"name\":\"avg_dev\",\"type\":\"Int\",\"count\":319,\"unique\":3,\"nulls\":0,\"top\":0,\"freq\":200,\"mean\":0.3887147335423197,\"std\":0.5194317560418836,\"min\":0,\"median\":0,\"max\":2}]}" }, "metadata": {}, "output_type": "execute_result", "execution_count": 29 } ], "source": [ "// Summary statistics of the absolute error; the column holds Int values, so cast to Int explicitly.\n", "predDf2.avg_dev.cast<Int>().describe()" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:53.495338730Z", "start_time": "2023-12-05T11:16:53.198655138Z" } }, "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
predictedground_truthavg_dev
651
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"nrow\":1,\"ncol\":3,\"columns\":[\"predicted\",\"ground_truth\",\"avg_dev\"],\"kotlin_dataframe\":[{\"predicted\":6,\"ground_truth\":5,\"avg_dev\":1}]}" }, "metadata": {}, "output_type": "execute_result", "execution_count": 30 } ], "source": [ "// Row at the 70th percentile of absolute error. The index is derived from the actual\n", "// row count instead of a hard-coded test-set size, so it stays valid if the split changes.\n", "predDf2.sortBy { avg_dev }[(0.7 * (predDf2.rowsCount() - 1)).toInt()]" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "ExecuteTime": { "end_time": "2023-12-05T11:16:53.539689733Z", "start_time": "2023-12-05T11:16:53.340929615Z" } }, "outputs": [], "source": [ "model2.close()" ] } ], "metadata": { "kernelspec": { "display_name": "Kotlin", "language": "kotlin", "name": "kotlin" }, "language_info": { "codemirror_mode": "text/x-kotlin", "file_extension": ".kt", "mimetype": "text/x-kotlin", "name": "kotlin", "nbconvert_exporter": "", "pygments_lexer": "kotlin", "version": "1.8.0-dev-707" }, "ktnbPluginMetadata": { "projectLibraries": [] } }, "nbformat": 4, "nbformat_minor": 1 }