{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict wine quality" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use the latest versions of DataFrame and KotlinDL libraries from the [version repository](https://github.com/Kotlin/kotlin-jupyter-libraries).\n", "\n", "To run this notebook in Kotlin Notebook, please make sure \"Resolve multiplatform dependencies\" is turned OFF for this library" ] }, { "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:56:40.349956Z", "start_time": "2025-05-28T10:56:40.324636Z" } }, "cell_type": "code", "source": "%useLatestDescriptors", "outputs": [], "execution_count": 1 }, { "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:56:46.678907Z", "start_time": "2025-05-28T10:56:40.353708Z" } }, "cell_type": "code", "source": "%use dataframe", "outputs": [], "execution_count": 2 }, { "cell_type": "code", "metadata": { "pycharm": { "is_executing": true }, "ExecuteTime": { "end_time": "2025-05-28T10:56:49.158217Z", "start_time": "2025-05-28T10:56:46.694225Z" } }, "source": "%use kotlin-dl", "outputs": [], "execution_count": 3 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Read the dataframe from CSV and print the first few lines of it" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:56:53.207927Z", "start_time": "2025-05-28T10:56:49.171422Z" } }, "source": [ "val rawDf = DataFrame.readCsv(fileOrUrl = \"winequality-red.csv\", delimiter = ';')\n", "rawDf.head()" ], "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidresidual sugarchloridesfree sulfur dioxidetotal sulfur dioxidedensitypHsulphatesalcoholquality
7.4000000.7000000.0000001.9000000.07600011.00000034.0000000.9978003.5100000.5600009.4000005
7.8000000.8800000.0000002.6000000.09800025.00000067.0000000.9968003.2000000.6800009.8000005
7.8000000.7600000.0400002.3000000.09200015.00000054.0000000.9970003.2600000.6500009.8000005
11.2000000.2800000.5600001.9000000.07500017.00000060.0000000.9980003.1600000.5800009.8000006
7.4000000.7000000.0000001.9000000.07600011.00000034.0000000.9978003.5100000.5600009.4000005
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"residual sugar\",\"chlorides\",\"free sulfur dioxide\",\"total sulfur dioxide\",\"density\",\"pH\",\"sulphates\",\"alcohol\",\"quality\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"}],\"nrow\":5,\"ncol\":12},\"kotlin_dataframe\":[{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"residual sugar\":1.9,\"chlorides\":0.076,\"free sulfur dioxide\":11.0,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"pH\":3.51,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.88,\"citric acid\":0.0,\"residual sugar\":2.6,\"chlorides\":0.098,\"free sulfur dioxide\":25.0,\"total sulfur dioxide\":67.0,\"density\":0.9968,\"pH\":3.2,\"sulphates\":0.68,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.76,\"citric acid\":0.04,\"residual sugar\":2.3,\"chlorides\":0.092,\"free sulfur dioxide\":15.0,\"total sulfur dioxide\":54.0,\"density\":0.997,\"pH\":3.26,\"sulphates\":0.65,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":11.2,\"volatile acidity\":0.28,\"citric acid\":0.56,\"residual sugar\":1.9,\"chlorides\":0.075,\"free sulfur dioxide\":17.0,\"total sulfur 
dioxide\":60.0,\"density\":0.998,\"pH\":3.16,\"sulphates\":0.58,\"alcohol\":9.8,\"quality\":6},{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"residual sugar\":1.9,\"chlorides\":0.076,\"free sulfur dioxide\":11.0,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"pH\":3.51,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5}]}" }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 4 }, { "metadata": {}, "cell_type": "markdown", "source": [ "_Note:_ For formatting, the DataFrame needs to be rendered as HTML.\n", "This means that when running in Kotlin Notebook, either \"Render DataFrame tables natively\" needs to be turned off, or\n", "the DataFrame must be converted to HTML explicitly with `toHtml()`, as the cell below does." ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:58:34.197613Z", "start_time": "2025-05-28T10:58:34.019830Z" } }, "source": [ "rawDf.corr()\n", " .format { colsOf() }.with { linearBg(value = it, from = -1.0 to red, to = 1.0 to green) }\n", " .toHtml()" ], "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 6 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on the correlation matrix, we drop the `free sulfur dioxide`, `residual sugar`, and `pH` columns, which seem to have little influence on `quality`." ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:58:45.713301Z", "start_time": "2025-05-28T10:58:45.136784Z" } }, "source": [ "val df = rawDf.remove { `free sulfur dioxide` and `residual sugar` and pH }\n", "df" ], "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
fixed acidityvolatile aciditycitric acidchloridestotal sulfur dioxidedensitysulphatesalcoholquality
7.4000000.7000000.0000000.07600034.0000000.9978000.5600009.4000005
7.8000000.8800000.0000000.09800067.0000000.9968000.6800009.8000005
7.8000000.7600000.0400000.09200054.0000000.9970000.6500009.8000005
11.2000000.2800000.5600000.07500060.0000000.9980000.5800009.8000006
7.4000000.7000000.0000000.07600034.0000000.9978000.5600009.4000005
7.4000000.6600000.0000000.07500040.0000000.9978000.5600009.4000005
7.9000000.6000000.0600000.06900059.0000000.9964000.4600009.4000005
7.3000000.6500000.0000000.06500021.0000000.9946000.47000010.0000007
7.8000000.5800000.0200000.07300018.0000000.9968000.5700009.5000007
7.5000000.5000000.3600000.071000102.0000000.9978000.80000010.5000005
6.7000000.5800000.0800000.09700065.0000000.9959000.5400009.2000005
7.5000000.5000000.3600000.071000102.0000000.9978000.80000010.5000005
5.6000000.6150000.0000000.08900059.0000000.9943000.5200009.9000005
7.8000000.6100000.2900000.11400029.0000000.9974001.5600009.1000005
8.9000000.6200000.1800000.176000145.0000000.9986000.8800009.2000005
8.9000000.6200000.1900000.170000148.0000000.9986000.9300009.2000005
8.5000000.2800000.5600000.092000103.0000000.9969000.75000010.5000007
8.1000000.5600000.2800000.36800056.0000000.9968001.2800009.3000005
7.4000000.5900000.0800000.08600029.0000000.9974000.5000009.0000004
7.9000000.3200000.5100000.34100056.0000000.9969001.0800009.2000006
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"chlorides\",\"total sulfur dioxide\",\"density\",\"sulphates\",\"alcohol\",\"quality\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double?\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double?\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"}],\"nrow\":1599,\"ncol\":9},\"kotlin_dataframe\":[{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"chlorides\":0.076,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.88,\"citric acid\":0.0,\"chlorides\":0.098,\"total sulfur dioxide\":67.0,\"density\":0.9968,\"sulphates\":0.68,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.76,\"citric acid\":0.04,\"chlorides\":0.092,\"total sulfur dioxide\":54.0,\"density\":0.997,\"sulphates\":0.65,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":11.2,\"volatile acidity\":0.28,\"citric acid\":0.56,\"chlorides\":0.075,\"total sulfur dioxide\":60.0,\"density\":0.998,\"sulphates\":0.58,\"alcohol\":9.8,\"quality\":6},{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"chlorides\":0.076,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.4,\"volatile acidity\":0.66,\"citric acid\":0.0,\"chlorides\":0.075,\"total sulfur dioxide\":40.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.9,\"volatile acidity\":0.6,\"citric 
acid\":0.06,\"chlorides\":0.069,\"total sulfur dioxide\":59.0,\"density\":0.9964,\"sulphates\":0.46,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.3,\"volatile acidity\":0.65,\"citric acid\":0.0,\"chlorides\":0.065,\"total sulfur dioxide\":21.0,\"density\":0.9946,\"sulphates\":0.47,\"alcohol\":10.0,\"quality\":7},{\"fixed acidity\":7.8,\"volatile acidity\":0.58,\"citric acid\":0.02,\"chlorides\":0.073,\"total sulfur dioxide\":18.0,\"density\":0.9968,\"sulphates\":0.57,\"alcohol\":9.5,\"quality\":7},{\"fixed acidity\":7.5,\"volatile acidity\":0.5,\"citric acid\":0.36,\"chlorides\":0.071,\"total sulfur dioxide\":102.0,\"density\":0.9978,\"sulphates\":0.8,\"alcohol\":10.5,\"quality\":5},{\"fixed acidity\":6.7,\"volatile acidity\":0.58,\"citric acid\":0.08,\"chlorides\":0.097,\"total sulfur dioxide\":65.0,\"density\":0.9959,\"sulphates\":0.54,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":7.5,\"volatile acidity\":0.5,\"citric acid\":0.36,\"chlorides\":0.071,\"total sulfur dioxide\":102.0,\"density\":0.9978,\"sulphates\":0.8,\"alcohol\":10.5,\"quality\":5},{\"fixed acidity\":5.6,\"volatile acidity\":0.615,\"citric acid\":0.0,\"chlorides\":0.089,\"total sulfur dioxide\":59.0,\"density\":0.9943,\"sulphates\":0.52,\"alcohol\":9.9,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.61,\"citric acid\":0.29,\"chlorides\":0.114,\"total sulfur dioxide\":29.0,\"density\":0.9974,\"sulphates\":1.56,\"alcohol\":9.1,\"quality\":5},{\"fixed acidity\":8.9,\"volatile acidity\":0.62,\"citric acid\":0.18,\"chlorides\":0.176,\"total sulfur dioxide\":145.0,\"density\":0.9986,\"sulphates\":0.88,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":8.9,\"volatile acidity\":0.62,\"citric acid\":0.19,\"chlorides\":0.17,\"total sulfur dioxide\":148.0,\"density\":0.9986,\"sulphates\":0.93,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":8.5,\"volatile acidity\":0.28,\"citric acid\":0.56,\"chlorides\":0.092,\"total sulfur 
dioxide\":103.0,\"density\":0.9969,\"sulphates\":0.75,\"alcohol\":10.5,\"quality\":7},{\"fixed acidity\":8.1,\"volatile acidity\":0.56,\"citric acid\":0.28,\"chlorides\":0.368,\"total sulfur dioxide\":56.0,\"density\":0.9968,\"sulphates\":1.28,\"alcohol\":9.3,\"quality\":5},{\"fixed acidity\":7.4,\"volatile acidity\":0.59,\"citric acid\":0.08,\"chlorides\":0.086,\"total sulfur dioxide\":29.0,\"density\":0.9974,\"sulphates\":0.5,\"alcohol\":9.0,\"quality\":4},{\"fixed acidity\":7.9,\"volatile acidity\":0.32,\"citric acid\":0.51,\"chlorides\":0.341,\"total sulfur dioxide\":56.0,\"density\":0.9969,\"sulphates\":1.08,\"alcohol\":9.2,\"quality\":6}]}" }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 7 }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predict wine quality: first approach" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:58:56.621291Z", "start_time": "2025-05-28T10:58:56.116886Z" } }, "source": [ "// Simple converter function between DataFrame and KotlinDL data representations\n", "fun DataFrame.toOnHeapDataset(labelColumnName: String): OnHeapDataset {\n", " return OnHeapDataset.create(\n", " dataframe = this,\n", " yColumn = labelColumnName\n", " )\n", "}\n", "\n", "fun OnHeapDataset.Companion.create(\n", " dataframe: DataFrame,\n", " yColumn: String\n", "): OnHeapDataset {\n", " fun extractX(): Array =\n", " dataframe.remove(yColumn).rows()\n", " .map { (it.values() as List).toFloatArray() }.toTypedArray()\n", "\n", " fun extractY(): FloatArray =\n", " dataframe.get { yColumn() }.toList().toFloatArray()\n", "\n", " return create(\n", " ::extractX,\n", " ::extractY\n", " )\n", "}" ], "outputs": [], "execution_count": 8 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:04.592184Z", "start_time": "2025-05-28T10:59:04.415519Z" } }, "source": [ "val (train, test) = df.convert { colsOf() }.toFloat()\n", " 
.toOnHeapDataset(labelColumnName = \"quality\")\n", " .split(0.8)" ], "outputs": [], "execution_count": 9 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define a simple neural network: two hidden `Dense` layers with `Tanh` activation, followed by a single-neuron linear output layer" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:17.233420Z", "start_time": "2025-05-28T10:59:15.925180Z" } }, "source": [ "val inputNeurons = train.x[0].size.toLong()\n", "\n", "val model = Sequential.of(\n", " Input(\n", " inputNeurons,\n", " ),\n", " Dense(\n", " outputSize = (inputNeurons * 10).toInt(),\n", " activation = Activations.Tanh,\n", " kernelInitializer = HeNormal(),\n", " biasInitializer = HeNormal(),\n", " ),\n", " Dense(\n", " outputSize = (inputNeurons * 10).toInt(),\n", " activation = Activations.Tanh,\n", " kernelInitializer = HeNormal(),\n", " biasInitializer = HeNormal(),\n", " ),\n", " Dense(\n", " outputSize = 1,\n", " activation = Activations.Linear,\n", " kernelInitializer = HeNormal(),\n", " biasInitializer = HeNormal(),\n", " )\n", ")" ], "outputs": [], "execution_count": 10 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:19.520847Z", "start_time": "2025-05-28T10:59:19.323480Z" } }, "source": [ "model.compile(optimizer = Adam(), loss = Losses.MSE, metric = Metrics.MAE)" ], "outputs": [], "execution_count": 11 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:21.999158Z", "start_time": "2025-05-28T10:59:21.845761Z" } }, "source": [ "model.printSummary()" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==============================================================================\n", "Model type: Sequential\n", "______________________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "==============================================================================\n", "input_1(Input) [None, 8] 0 \n", 
"______________________________________________________________________________\n", "dense_2(Dense) [None, 80] 720 \n", "______________________________________________________________________________\n", "dense_3(Dense) [None, 80] 6480 \n", "______________________________________________________________________________\n", "dense_4(Dense) [None, 1] 81 \n", "______________________________________________________________________________\n", "==============================================================================\n", "Total trainable params: 7281\n", "Total frozen params: 0\n", "Total params: 7281\n", "______________________________________________________________________________\n" ] } ], "execution_count": 12 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train it!" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:34.097925Z", "start_time": "2025-05-28T10:59:28.769627Z" } }, "source": [ "val trainHist = model.fit(train, batchSize = 500, epochs=2000)" ], "outputs": [], "execution_count": 13 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:34.289579Z", "start_time": "2025-05-28T10:59:34.101283Z" } }, "source": [ "trainHist.epochHistory.toDataFrame().tail()" ], "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
epochIndexlossValuemetricValuesvalLossValuevalMetricValues
19960.334851[0.45112717151641846]NaN[NaN]
19970.334814[0.45109668374061584]NaN[NaN]
19980.334778[0.45106613636016846]NaN[NaN]
19990.334741[0.45103588700294495]NaN[NaN]
20000.334705[0.45100536942481995]NaN[NaN]
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"epochIndex\",\"lossValue\",\"metricValues\",\"valLossValue\",\"valMetricValues\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.collections.List\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.collections.List\"}],\"nrow\":5,\"ncol\":5},\"kotlin_dataframe\":[{\"epochIndex\":1996,\"lossValue\":0.334850937128067,\"metricValues\":[0.45112717151641846],\"valLossValue\":NaN,\"valMetricValues\":[NaN]},{\"epochIndex\":1997,\"lossValue\":0.3348143994808197,\"metricValues\":[0.45109668374061584],\"valLossValue\":NaN,\"valMetricValues\":[NaN]},{\"epochIndex\":1998,\"lossValue\":0.33477771282196045,\"metricValues\":[0.45106613636016846],\"valLossValue\":NaN,\"valMetricValues\":[NaN]},{\"epochIndex\":1999,\"lossValue\":0.3347410261631012,\"metricValues\":[0.45103588700294495],\"valLossValue\":NaN,\"valMetricValues\":[NaN]},{\"epochIndex\":2000,\"lossValue\":0.33470451831817627,\"metricValues\":[0.45100536942481995],\"valLossValue\":NaN,\"valMetricValues\":[NaN]}]}" }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 14 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check that our network predicts values more or less correctly:" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:41.417539Z", "start_time": "2025-05-28T10:59:41.312689Z" } }, "source": [ "model.predictSoftly(test.x[9])[0]" ], "outputs": [ { "data": { "text/plain": [ "5.24972" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 15 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:42.274260Z", "start_time": "2025-05-28T10:59:42.202425Z" } }, "source": [ 
"test.y[9]" ], "outputs": [ { "data": { "text/plain": [ "5.0" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 16 }, { "cell_type": "markdown", "metadata": {}, "source": [ "Close the model:" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T10:59:44.410731Z", "start_time": "2025-05-28T10:59:44.301119Z" } }, "source": [ "model.close()" ], "outputs": [], "execution_count": 17 }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predict wine quality: second approach" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:00:40.005511Z", "start_time": "2025-05-28T11:00:39.811245Z" } }, "source": [ "data class TrainTestSplitResult(\n", " val trainX: DataFrame,\n", " val trainY: DataFrame,\n", " val testX: DataFrame,\n", " val testY: DataFrame,\n", ")\n", "\n", "fun trainTestSplit(\n", " d: DataFrame,\n", " col: String,\n", " trainPart: Double,\n", "): TrainTestSplitResult {\n", " val n = d.count()\n", " val trainN = ceil(n * trainPart).toInt()\n", "\n", " val shuffledInd = (0.. 
DataFrame.toX(): Array =\n", " merge { colsOf() }.by { it.map { it.toFloat() }.toFloatArray() }.into(\"X\")\n", " .get { \"X\"() }\n", " .toList()\n", " .toTypedArray()" ], "outputs": [], "execution_count": 22 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:01:38.146951Z", "start_time": "2025-05-28T11:01:38.051890Z" } }, "source": [ "fun DataFrame.toY(): FloatArray = \n", " get { \"quality\"() }\n", " .asIterable()\n", " .map { it.toFloat() }\n", " .toFloatArray()" ], "outputs": [], "execution_count": 23 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:01:42.822059Z", "start_time": "2025-05-28T11:01:42.710503Z" } }, "source": [ "val trainXDL = trainX.toX()\n", "val trainYDL = trainY.toY()\n", "val testXDL = testX.toX()\n", "val testYDL = testY.toY()" ], "outputs": [], "execution_count": 24 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:01:44.236813Z", "start_time": "2025-05-28T11:01:44.058822Z" } }, "source": [ "val trainKotlinDLDataset = OnHeapDataset.create({ trainXDL }, { trainYDL })\n", "val testKotlinDLDataset = OnHeapDataset.create({ testXDL }, { testYDL })" ], "outputs": [], "execution_count": 25 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:01:48.178391Z", "start_time": "2025-05-28T11:01:47.942204Z" } }, "source": [ "val inputNeurons = train.x[0].size.toLong()\n", "\n", "val model2 = Sequential.of(\n", " Input(\n", " inputNeurons\n", " ),\n", " Dense(\n", " outputSize = (inputNeurons * 10).toInt(),\n", " activation = Activations.Tanh,\n", " kernelInitializer = HeNormal(),\n", " biasInitializer = HeNormal()\n", " ),\n", " Dense(\n", " outputSize = (inputNeurons * 10).toInt(),\n", " activation = Activations.Tanh,\n", " kernelInitializer = HeNormal(),\n", " biasInitializer = HeNormal()\n", " ),\n", " Dense(\n", " outputSize = 1,\n", " activation = Activations.Linear,\n", " kernelInitializer = HeNormal(),\n", " 
biasInitializer = HeNormal()\n", " )\n", ")\n", "model2.compile(optimizer = Adam(), loss = Losses.MSE, metric = Metrics.MAE)\n", "model2.printSummary()" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==============================================================================\n", "Model type: Sequential\n", "______________________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "==============================================================================\n", "input_1(Input) [None, 8] 0 \n", "______________________________________________________________________________\n", "dense_2(Dense) [None, 80] 720 \n", "______________________________________________________________________________\n", "dense_3(Dense) [None, 80] 6480 \n", "______________________________________________________________________________\n", "dense_4(Dense) [None, 1] 81 \n", "______________________________________________________________________________\n", "==============================================================================\n", "Total trainable params: 7281\n", "Total frozen params: 0\n", "Total params: 7281\n", "______________________________________________________________________________\n" ] } ], "execution_count": 26 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:01:56.732805Z", "start_time": "2025-05-28T11:01:50.419220Z" } }, "source": [ "// Train on this section's own split: the first approach's `train` overlaps testXDL/testYDL,\n", "// which would leak test rows into training and inflate the evaluation below.\n", "val trainHist = model2.fit(trainKotlinDLDataset, batchSize = 500, epochs = 2000)\n", "trainHist.epochHistory.toDataFrame().tail()" ], "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
epochIndexlossValuemetricValuesvalLossValuevalMetricValues
19960.334532[0.4508610963821411]NaN[NaN]
19970.334495[0.45082950592041016]NaN[NaN]
19980.334458[0.45079800486564636]NaN[NaN]
19990.334421[0.4507667124271393]NaN[NaN]
20000.334384[0.45073509216308594]NaN[NaN]
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"epochIndex\",\"lossValue\",\"metricValues\",\"valLossValue\",\"valMetricValues\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.collections.List\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.collections.List\"}],\"nrow\":5,\"ncol\":5},\"kotlin_dataframe\":[{\"epochIndex\":1996,\"lossValue\":0.3345320522785187,\"metricValues\":[0.4508610963821411],\"valLossValue\":NaN,\"valMetricValues\":[NaN]},{\"epochIndex\":1997,\"lossValue\":0.33449509739875793,\"metricValues\":[0.45082950592041016],\"valLossValue\":NaN,\"valMetricValues\":[NaN]},{\"epochIndex\":1998,\"lossValue\":0.3344581127166748,\"metricValues\":[0.45079800486564636],\"valLossValue\":NaN,\"valMetricValues\":[NaN]},{\"epochIndex\":1999,\"lossValue\":0.33442115783691406,\"metricValues\":[0.4507667124271393],\"valLossValue\":NaN,\"valMetricValues\":[NaN]},{\"epochIndex\":2000,\"lossValue\":0.33438411355018616,\"metricValues\":[0.45073509216308594],\"valLossValue\":NaN,\"valMetricValues\":[NaN]}]}" }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 27 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:01:56.818019Z", "start_time": "2025-05-28T11:01:56.735440Z" } }, "source": [ "model2.predictSoftly(testXDL[9])[0]" ], "outputs": [ { "data": { "text/plain": [ "5.8768764" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 28 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:02:06.727475Z", "start_time": "2025-05-28T11:02:06.684985Z" } }, "source": [ "testYDL[9]" ], "outputs": [ { "data": { "text/plain": [ "5.0" ] }, "execution_count": 29, "metadata": {}, "output_type": 
"execute_result" } ], "execution_count": 29 }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also compare the predicted and ground-truth quality scores across the whole test set to see how accurate the predictions are" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:02:58.701570Z", "start_time": "2025-05-28T11:02:58.378637Z" } }, "source": [ "val predicted = testXDL.mapIndexed { i, _ ->\n", " round(model2.predictSoftly(testXDL[i])[0]).toInt()\n", "}.toColumn(\"predicted\")\n", "\n", "val ground_truth = testYDL.mapIndexed { i, _ ->\n", " testYDL[i].toInt()\n", "}.toColumn(\"ground_truth\")\n", "\n", "val predDf = dataFrameOf(predicted, ground_truth)" ], "outputs": [], "execution_count": 30 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:02:59.233583Z", "start_time": "2025-05-28T11:02:59.172927Z" } }, "source": [ "predDf.head()" ], "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
predictedground_truth
66
54
66
55
55
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"predicted\",\"ground_truth\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"}],\"nrow\":5,\"ncol\":2},\"kotlin_dataframe\":[{\"predicted\":6,\"ground_truth\":6},{\"predicted\":5,\"ground_truth\":4},{\"predicted\":6,\"ground_truth\":6},{\"predicted\":5,\"ground_truth\":5},{\"predicted\":5,\"ground_truth\":5}]}" }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 31 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:03:39.596244Z", "start_time": "2025-05-28T11:03:39.225342Z" } }, "source": [ "val inds = List(10) { it + 1 }\n", "val ctab = predDf\n", " .groupBy { ground_truth }.pivotCounts(inward = false) { predicted }\n", " .sortBy { ground_truth }\n", "\n", "ctab.format { drop(1) }.perRowCol { row, col ->\n", " val y = col.name().toInt()\n", " val x = row.ground_truth\n", " val k = 1.0 - abs(x - y) / 10.0\n", " background(RGBColor(50, (50 + k * 200).toInt().toShort(), 50))\n", "}.toHtml()" ], "outputs": [ { "data": { "text/html": [ " \n", " " ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 35 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:03:56.515424Z", "start_time": "2025-05-28T11:03:56.317184Z" } }, "source": [ "val predDf2 = predDf.add(\"avg_dev\") { abs(predicted - ground_truth) }" ], "outputs": [], "execution_count": 36 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:03:57.220978Z", "start_time": "2025-05-28T11:03:57.020582Z" } }, "source": [ "predDf2.avg_dev.cast().describe()" ], "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
nametypecountuniquenullstopfreqmeanstdminp25medianp75max
avg_devInt3193001960.4075240.53500700.0000000.0000001.0000002
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"name\",\"type\",\"count\",\"unique\",\"nulls\",\"top\",\"freq\",\"mean\",\"std\",\"min\",\"p25\",\"median\",\"p75\",\"max\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.String\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.String\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"}],\"nrow\":1,\"ncol\":14},\"kotlin_dataframe\":[{\"name\":\"avg_dev\",\"type\":\"Int\",\"count\":319,\"unique\":3,\"nulls\":0,\"top\":0,\"freq\":196,\"mean\":0.40752351097178685,\"std\":0.5350070344969485,\"min\":0,\"p25\":0.0,\"median\":0.0,\"p75\":1.0,\"max\":2}]}" }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 37 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:03:58.933406Z", "start_time": "2025-05-28T11:03:58.800454Z" } }, "source": [ "predDf2.sortBy { avg_dev }[(0.7 * (319 - 1)).toInt()]" ], "outputs": [ { "data": { "text/html": [ " \n", " \n", " \n", " \n", " \n", " \n", "
predictedground_truthavg_dev
651
\n", " \n", " \n", " " ], "application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"predicted\",\"ground_truth\",\"avg_dev\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"}],\"nrow\":1,\"ncol\":3},\"kotlin_dataframe\":[{\"predicted\":6,\"ground_truth\":5,\"avg_dev\":1}]}" }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 38 }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2025-05-28T11:04:00.646210Z", "start_time": "2025-05-28T11:04:00.585028Z" } }, "source": [ "model2.close()" ], "outputs": [], "execution_count": 39 }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": "" } ], "metadata": { "kernelspec": { "display_name": "Kotlin", "language": "kotlin", "name": "kotlin" }, "language_info": { "codemirror_mode": "text/x-kotlin", "file_extension": ".kt", "mimetype": "text/x-kotlin", "name": "kotlin", "nbconvert_exporter": "", "pygments_lexer": "kotlin", "version": "1.8.0-dev-707" }, "ktnbPluginMetadata": { "projectLibraries": [] } }, "nbformat": 4, "nbformat_minor": 1 }