{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predict wine quality"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the latest versions of the DataFrame and KotlinDL libraries from the [version repository](https://github.com/Kotlin/kotlin-jupyter-libraries)."
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-07T12:08:53.598805Z",
"start_time": "2025-03-07T12:08:50.871926Z"
}
},
"cell_type": "code",
"source": "%use dataframe",
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"metadata": {
"pycharm": {
"is_executing": true
},
"ExecuteTime": {
"end_time": "2025-03-07T12:09:23.739653Z",
"start_time": "2025-03-07T12:08:53.625650Z"
}
},
"source": "%use kotlin-dl",
"outputs": [],
"execution_count": 2
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Read the dataset from a CSV file into a DataFrame and display its first few rows."
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-07T12:09:24.312493Z",
"start_time": "2025-03-07T12:09:23.762745Z"
}
},
"source": [
"// Load the semicolon-separated wine dataset from a path relative to the working directory.\n",
"val rawDf = DataFrame.readCsv(\n",
"    fileOrUrl = \"winequality-red.csv\",\n",
"    delimiter = ';',\n",
")\n",
"\n",
"// Preview the first rows of the loaded frame.\n",
"rawDf.head()"
],
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
"
\n",
" \n",
" \n",
" \n",
" | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
|---|
| 7,400000 | 0,700000 | 0,000000 | 1,900000 | 0,076000 | 11,000000 | 34,000000 | 0,997800 | 3,510000 | 0,560000 | 9,400000 | 5 |
| 7,800000 | 0,880000 | 0,000000 | 2,600000 | 0,098000 | 25,000000 | 67,000000 | 0,996800 | 3,200000 | 0,680000 | 9,800000 | 5 |
| 7,800000 | 0,760000 | 0,040000 | 2,300000 | 0,092000 | 15,000000 | 54,000000 | 0,997000 | 3,260000 | 0,650000 | 9,800000 | 5 |
| 11,200000 | 0,280000 | 0,560000 | 1,900000 | 0,075000 | 17,000000 | 60,000000 | 0,998000 | 3,160000 | 0,580000 | 9,800000 | 6 |
| 7,400000 | 0,700000 | 0,000000 | 1,900000 | 0,076000 | 11,000000 | 34,000000 | 0,997800 | 3,510000 | 0,560000 | 9,400000 | 5 |
\n",
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"residual sugar\",\"chlorides\",\"free sulfur dioxide\",\"total sulfur dioxide\",\"density\",\"pH\",\"sulphates\",\"alcohol\",\"quality\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"}],\"nrow\":5,\"ncol\":12},\"kotlin_dataframe\":[{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"residual sugar\":1.9,\"chlorides\":0.076,\"free sulfur dioxide\":11.0,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"pH\":3.51,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.88,\"citric acid\":0.0,\"residual sugar\":2.6,\"chlorides\":0.098,\"free sulfur dioxide\":25.0,\"total sulfur dioxide\":67.0,\"density\":0.9968,\"pH\":3.2,\"sulphates\":0.68,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.76,\"citric acid\":0.04,\"residual sugar\":2.3,\"chlorides\":0.092,\"free sulfur dioxide\":15.0,\"total sulfur dioxide\":54.0,\"density\":0.997,\"pH\":3.26,\"sulphates\":0.65,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":11.2,\"volatile acidity\":0.28,\"citric acid\":0.56,\"residual sugar\":1.9,\"chlorides\":0.075,\"free sulfur dioxide\":17.0,\"total sulfur 
dioxide\":60.0,\"density\":0.998,\"pH\":3.16,\"sulphates\":0.58,\"alcohol\":9.8,\"quality\":6},{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"residual sugar\":1.9,\"chlorides\":0.076,\"free sulfur dioxide\":11.0,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"pH\":3.51,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5}]}"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"cell_type": "markdown",
"source": [
"Note: For formatting, the DataFrame needs to be rendered as HTML. This means that when running in Kotlin Notebook, \"Render DataFrame tables natively\" needs to be turned off."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-07T12:09:24.547956Z",
"start_time": "2025-03-07T12:09:24.355637Z"
}
},
"source": [
"// Colour every correlation coefficient on a red (-1.0) to green (1.0) gradient.\n",
"// The type argument is required: `colsOf` cannot infer which columns to format.\n",
"rawDf.corr()\n",
"    .format { colsOf<Double>() }\n",
"    .with { linearBg(value = it, from = -1.0 to red, to = 1.0 to green) }"
],
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" | column | fixed acidity | volatile acidity | residual sugar | chlorides | density | pH | sulphates | alcohol | quality |
|---|
| fixed acidity | 1,000000 | -0,256131 | 0,114777 | 0,093705 | 0,668047 | -0,682978 | 0,183006 | -0,061668 | 0,124052 |
| volatile acidity | -0,256131 | 1,000000 | 0,001918 | 0,061298 | 0,022026 | 0,234937 | -0,260987 | -0,202288 | -0,390558 |
| residual sugar | 0,114777 | 0,001918 | 1,000000 | 0,055610 | 0,355283 | -0,085652 | 0,005527 | 0,042075 | 0,013732 |
| chlorides | 0,093705 | 0,061298 | 0,055610 | 1,000000 | 0,200632 | -0,265026 | 0,371260 | -0,221141 | -0,128907 |
| density | 0,668047 | 0,022026 | 0,355283 | 0,200632 | 1,000000 | -0,341699 | 0,148506 | -0,496180 | -0,174919 |
| pH | -0,682978 | 0,234937 | -0,085652 | -0,265026 | -0,341699 | 1,000000 | -0,196648 | 0,205633 | -0,057731 |
| sulphates | 0,183006 | -0,260987 | 0,005527 | 0,371260 | 0,148506 | -0,196648 | 1,000000 | 0,093595 | 0,251397 |
| alcohol | -0,061668 | -0,202288 | 0,042075 | -0,221141 | -0,496180 | 0,205633 | 0,093595 | 1,000000 | 0,476166 |
| quality | 0,124052 | -0,390558 | 0,013732 | -0,128907 | -0,174919 | -0,057731 | 0,251397 | 0,476166 | 1,000000 |
\n",
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"column\",\"fixed acidity\",\"volatile acidity\",\"residual sugar\",\"chlorides\",\"density\",\"pH\",\"sulphates\",\"alcohol\",\"quality\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.String\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"}],\"nrow\":9,\"ncol\":10},\"kotlin_dataframe\":[{\"column\":\"fixed acidity\",\"fixed acidity\":1.0,\"volatile acidity\":-0.2561308947703819,\"residual sugar\":0.1147767244949209,\"chlorides\":0.09370518632130498,\"density\":0.6680472921189711,\"pH\":-0.6829781945685299,\"sulphates\":0.18300566393215348,\"alcohol\":-0.06166827062815111,\"quality\":0.1240516491132247},{\"column\":\"volatile acidity\",\"fixed acidity\":-0.2561308947703819,\"volatile acidity\":1.0,\"residual sugar\":0.001917881962790698,\"chlorides\":0.061297772476461614,\"density\":0.022026232195215885,\"pH\":0.23493729440739436,\"sulphates\":-0.26098668528329055,\"alcohol\":-0.20228802715325686,\"quality\":-0.3905577802640094},{\"column\":\"residual sugar\",\"fixed acidity\":0.1147767244949209,\"volatile acidity\":0.001917881962790698,\"residual sugar\":1.0,\"chlorides\":0.05560953520353218,\"density\":0.35528337098337653,\"pH\":-0.08565242221887161,\"sulphates\":0.005527121339138363,\"alcohol\":0.04207543720973116,\"quality\":0.013731637340066346},{\"column\":\"chlorides\",\"fixed acidity\":0.09370518632130498,\"volatile acidity\":0.061297772476461614,\"residual 
sugar\":0.05560953520353218,\"chlorides\":1.0,\"density\":0.200632326641512,\"pH\":-0.2650261311732279,\"sulphates\":0.371260481285427,\"alcohol\":-0.22114054478828302,\"quality\":-0.12890655993005312},{\"column\":\"density\",\"fixed acidity\":0.6680472921189711,\"volatile acidity\":0.022026232195215885,\"residual sugar\":0.35528337098337653,\"chlorides\":0.200632326641512,\"density\":1.0,\"pH\":-0.3416993347850301,\"sulphates\":0.14850641172078524,\"alcohol\":-0.4961797702417023,\"quality\":-0.1749192277833492},{\"column\":\"pH\",\"fixed acidity\":-0.6829781945685299,\"volatile acidity\":0.23493729440739436,\"residual sugar\":-0.08565242221887161,\"chlorides\":-0.2650261311732279,\"density\":-0.3416993347850301,\"pH\":1.0,\"sulphates\":-0.1966476023043703,\"alcohol\":0.20563250850549894,\"quality\":-0.0577313912053823},{\"column\":\"sulphates\",\"fixed acidity\":0.18300566393215348,\"volatile acidity\":-0.26098668528329055,\"residual sugar\":0.005527121339138363,\"chlorides\":0.371260481285427,\"density\":0.14850641172078524,\"pH\":-0.1966476023043703,\"sulphates\":1.0,\"alcohol\":0.09359475041046762,\"quality\":0.25139707906926206},{\"column\":\"alcohol\",\"fixed acidity\":-0.06166827062815111,\"volatile acidity\":-0.20228802715325686,\"residual sugar\":0.04207543720973116,\"chlorides\":-0.22114054478828302,\"density\":-0.4961797702417023,\"pH\":0.20563250850549894,\"sulphates\":0.09359475041046762,\"alcohol\":1.0,\"quality\":0.47616632400114156},{\"column\":\"quality\",\"fixed acidity\":0.1240516491132247,\"volatile acidity\":-0.3905577802640094,\"residual sugar\":0.013731637340066346,\"chlorides\":-0.12890655993005312,\"density\":-0.1749192277833492,\"pH\":-0.0577313912053823,\"sulphates\":0.25139707906926206,\"alcohol\":0.47616632400114156,\"quality\":1.0}]}"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the correlation matrix, we can remove some columns that appear to be insignificant for predicting quality."
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-07T12:09:24.855470Z",
"start_time": "2025-03-07T12:09:24.639458Z"
}
},
"source": [
"// Keep the raw frame intact and derive a new one without the three dropped columns.\n",
"val df = rawDf.remove {\n",
"    `free sulfur dioxide` and `residual sugar` and pH\n",
"}\n",
"\n",
"df"
],
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" | fixed acidity | volatile acidity | citric acid | chlorides | total sulfur dioxide | density | sulphates | alcohol | quality |
|---|
| 7,400000 | 0,700000 | 0,000000 | 0,076000 | 34,000000 | 0,997800 | 0,560000 | 9,400000 | 5 |
| 7,800000 | 0,880000 | 0,000000 | 0,098000 | 67,000000 | 0,996800 | 0,680000 | 9,800000 | 5 |
| 7,800000 | 0,760000 | 0,040000 | 0,092000 | 54,000000 | 0,997000 | 0,650000 | 9,800000 | 5 |
| 11,200000 | 0,280000 | 0,560000 | 0,075000 | 60,000000 | 0,998000 | 0,580000 | 9,800000 | 6 |
| 7,400000 | 0,700000 | 0,000000 | 0,076000 | 34,000000 | 0,997800 | 0,560000 | 9,400000 | 5 |
| 7,400000 | 0,660000 | 0,000000 | 0,075000 | 40,000000 | 0,997800 | 0,560000 | 9,400000 | 5 |
| 7,900000 | 0,600000 | 0,060000 | 0,069000 | 59,000000 | 0,996400 | 0,460000 | 9,400000 | 5 |
| 7,300000 | 0,650000 | 0,000000 | 0,065000 | 21,000000 | 0,994600 | 0,470000 | 10,000000 | 7 |
| 7,800000 | 0,580000 | 0,020000 | 0,073000 | 18,000000 | 0,996800 | 0,570000 | 9,500000 | 7 |
| 7,500000 | 0,500000 | 0,360000 | 0,071000 | 102,000000 | 0,997800 | 0,800000 | 10,500000 | 5 |
| 6,700000 | 0,580000 | 0,080000 | 0,097000 | 65,000000 | 0,995900 | 0,540000 | 9,200000 | 5 |
| 7,500000 | 0,500000 | 0,360000 | 0,071000 | 102,000000 | 0,997800 | 0,800000 | 10,500000 | 5 |
| 5,600000 | 0,615000 | 0,000000 | 0,089000 | 59,000000 | 0,994300 | 0,520000 | 9,900000 | 5 |
| 7,800000 | 0,610000 | 0,290000 | 0,114000 | 29,000000 | 0,997400 | 1,560000 | 9,100000 | 5 |
| 8,900000 | 0,620000 | 0,180000 | 0,176000 | 145,000000 | 0,998600 | 0,880000 | 9,200000 | 5 |
| 8,900000 | 0,620000 | 0,190000 | 0,170000 | 148,000000 | 0,998600 | 0,930000 | 9,200000 | 5 |
| 8,500000 | 0,280000 | 0,560000 | 0,092000 | 103,000000 | 0,996900 | 0,750000 | 10,500000 | 7 |
| 8,100000 | 0,560000 | 0,280000 | 0,368000 | 56,000000 | 0,996800 | 1,280000 | 9,300000 | 5 |
| 7,400000 | 0,590000 | 0,080000 | 0,086000 | 29,000000 | 0,997400 | 0,500000 | 9,000000 | 4 |
| 7,900000 | 0,320000 | 0,510000 | 0,341000 | 56,000000 | 0,996900 | 1,080000 | 9,200000 | 6 |
\n",
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"$version\":\"2.1.1\",\"metadata\":{\"columns\":[\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"chlorides\",\"total sulfur dioxide\",\"density\",\"sulphates\",\"alcohol\",\"quality\"],\"types\":[{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double?\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double?\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Double\"},{\"kind\":\"ValueColumn\",\"type\":\"kotlin.Int\"}],\"nrow\":1599,\"ncol\":9},\"kotlin_dataframe\":[{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"chlorides\":0.076,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.88,\"citric acid\":0.0,\"chlorides\":0.098,\"total sulfur dioxide\":67.0,\"density\":0.9968,\"sulphates\":0.68,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.76,\"citric acid\":0.04,\"chlorides\":0.092,\"total sulfur dioxide\":54.0,\"density\":0.997,\"sulphates\":0.65,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":11.2,\"volatile acidity\":0.28,\"citric acid\":0.56,\"chlorides\":0.075,\"total sulfur dioxide\":60.0,\"density\":0.998,\"sulphates\":0.58,\"alcohol\":9.8,\"quality\":6},{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"chlorides\":0.076,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.4,\"volatile acidity\":0.66,\"citric acid\":0.0,\"chlorides\":0.075,\"total sulfur dioxide\":40.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.9,\"volatile acidity\":0.6,\"citric 
acid\":0.06,\"chlorides\":0.069,\"total sulfur dioxide\":59.0,\"density\":0.9964,\"sulphates\":0.46,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.3,\"volatile acidity\":0.65,\"citric acid\":0.0,\"chlorides\":0.065,\"total sulfur dioxide\":21.0,\"density\":0.9946,\"sulphates\":0.47,\"alcohol\":10.0,\"quality\":7},{\"fixed acidity\":7.8,\"volatile acidity\":0.58,\"citric acid\":0.02,\"chlorides\":0.073,\"total sulfur dioxide\":18.0,\"density\":0.9968,\"sulphates\":0.57,\"alcohol\":9.5,\"quality\":7},{\"fixed acidity\":7.5,\"volatile acidity\":0.5,\"citric acid\":0.36,\"chlorides\":0.071,\"total sulfur dioxide\":102.0,\"density\":0.9978,\"sulphates\":0.8,\"alcohol\":10.5,\"quality\":5},{\"fixed acidity\":6.7,\"volatile acidity\":0.58,\"citric acid\":0.08,\"chlorides\":0.097,\"total sulfur dioxide\":65.0,\"density\":0.9959,\"sulphates\":0.54,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":7.5,\"volatile acidity\":0.5,\"citric acid\":0.36,\"chlorides\":0.071,\"total sulfur dioxide\":102.0,\"density\":0.9978,\"sulphates\":0.8,\"alcohol\":10.5,\"quality\":5},{\"fixed acidity\":5.6,\"volatile acidity\":0.615,\"citric acid\":0.0,\"chlorides\":0.089,\"total sulfur dioxide\":59.0,\"density\":0.9943,\"sulphates\":0.52,\"alcohol\":9.9,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.61,\"citric acid\":0.29,\"chlorides\":0.114,\"total sulfur dioxide\":29.0,\"density\":0.9974,\"sulphates\":1.56,\"alcohol\":9.1,\"quality\":5},{\"fixed acidity\":8.9,\"volatile acidity\":0.62,\"citric acid\":0.18,\"chlorides\":0.176,\"total sulfur dioxide\":145.0,\"density\":0.9986,\"sulphates\":0.88,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":8.9,\"volatile acidity\":0.62,\"citric acid\":0.19,\"chlorides\":0.17,\"total sulfur dioxide\":148.0,\"density\":0.9986,\"sulphates\":0.93,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":8.5,\"volatile acidity\":0.28,\"citric acid\":0.56,\"chlorides\":0.092,\"total sulfur 
dioxide\":103.0,\"density\":0.9969,\"sulphates\":0.75,\"alcohol\":10.5,\"quality\":7},{\"fixed acidity\":8.1,\"volatile acidity\":0.56,\"citric acid\":0.28,\"chlorides\":0.368,\"total sulfur dioxide\":56.0,\"density\":0.9968,\"sulphates\":1.28,\"alcohol\":9.3,\"quality\":5},{\"fixed acidity\":7.4,\"volatile acidity\":0.59,\"citric acid\":0.08,\"chlorides\":0.086,\"total sulfur dioxide\":29.0,\"density\":0.9974,\"sulphates\":0.5,\"alcohol\":9.0,\"quality\":4},{\"fixed acidity\":7.9,\"volatile acidity\":0.32,\"citric acid\":0.51,\"chlorides\":0.341,\"total sulfur dioxide\":56.0,\"density\":0.9969,\"sulphates\":1.08,\"alcohol\":9.2,\"quality\":6}]}"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 5
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict wine quality: first approach"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-07T12:09:25.019813Z",
"start_time": "2025-03-07T12:09:24.874927Z"
}
},
"source": [
"/**\n",
" * Simple converter between the DataFrame and KotlinDL data representations.\n",
" *\n",
" * @param labelColumnName name of the column that holds the labels\n",
" * @return the dataset built by `OnHeapDataset.create(dataframe, yColumn)` for this frame\n",
" */\n",
"fun DataFrame<*>.toOnHeapDataset(labelColumnName: String): OnHeapDataset =\n",
"    OnHeapDataset.create(dataframe = this, yColumn = labelColumnName)\n",
"\n",
"/**\n",
" * Builds an [OnHeapDataset] from a DataFrame.\n",
" *\n",
" * Every column except [yColumn] becomes a feature; [yColumn] supplies the labels.\n",
" * Each cell must hold a non-null number, which is converted to `Float`.\n",
" *\n",
" * @param dataframe source table\n",
" * @param yColumn name of the label column\n",
" * @throws IllegalArgumentException if a cell is null or not a number\n",
" */\n",
"fun OnHeapDataset.Companion.create(\n",
"    dataframe: DataFrame<*>,\n",
"    yColumn: String\n",
"): OnHeapDataset {\n",
"    // Convert through Number instead of an unchecked cast to List<Float>, so Int and\n",
"    // Double columns work as well and a bad cell fails with a descriptive message.\n",
"    fun toFloat(value: Any?, column: String): Float =\n",
"        (value as? Number)?.toFloat()\n",
"            ?: throw IllegalArgumentException(\"Column '$column' contains a non-numeric value: $value\")\n",
"\n",
"    fun extractX(): Array<FloatArray> {\n",
"        val features = dataframe.remove(yColumn)\n",
"        val names = features.columnNames()\n",
"        return features.rows()\n",
"            .map { row -> FloatArray(names.size) { i -> toFloat(row[i], names[i]) } }\n",
"            .toTypedArray()\n",
"    }\n",
"\n",
"    fun extractY(): FloatArray =\n",
"        dataframe[yColumn].values().map { toFloat(it, yColumn) }.toFloatArray()\n",
"\n",
"    // Delegates to the companion's own generator-based factory.\n",
"    return create(::extractX, ::extractY)\n",
"}"
],
"outputs": [],
"execution_count": 6
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-07T12:09:25.179183Z",
"start_time": "2025-03-07T12:09:25.122890Z"
}
},
"source": [
"val (train, test) = df.convert { colsOf() }.toFloat()\n",
" .toOnHeapDataset(labelColumnName = \"quality\")\n",
" .split(0.8)"
],
"outputs": [],
"execution_count": 7
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define a simple neural network with two hidden dense layers and a single-neuron linear output layer"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-07T12:09:25.351518Z",
"start_time": "2025-03-07T12:09:25.235016Z"
}
},
"source": [
"// The input layer width equals the number of features in a training sample.\n",
"val inputNeurons = train.x[0].size.toLong()\n",
"\n",
"// Two Tanh hidden layers, each ten times wider than the input, followed by a\n",
"// single linear output neuron; every kernel and bias uses a HeNormal initializer.\n",
"val model = run {\n",
"    val hiddenLayerSize = (inputNeurons * 10).toInt()\n",
"\n",
"    Sequential.of(\n",
"        Input(inputNeurons),\n",
"        Dense(\n",
"            outputSize = hiddenLayerSize,\n",
"            activation = Activations.Tanh,\n",
"            kernelInitializer = HeNormal(),\n",
"            biasInitializer = HeNormal(),\n",
"        ),\n",
"        Dense(\n",
"            outputSize = hiddenLayerSize,\n",
"            activation = Activations.Tanh,\n",
"            kernelInitializer = HeNormal(),\n",
"            biasInitializer = HeNormal(),\n",
"        ),\n",
"        Dense(\n",
"            outputSize = 1,\n",
"            activation = Activations.Linear,\n",
"            kernelInitializer = HeNormal(),\n",
"            biasInitializer = HeNormal(),\n",
"        ),\n",
"    )\n",
"}"
],
"outputs": [
{
"ename": "java.lang.UnsatisfiedLinkError",
"evalue": "Cannot find TensorFlow native library for OS: darwin, architecture: aarch64. See https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java/README.md for possible solutions (such as building the library from source). Additional information on attempts to find the native library can be obtained by adding org.tensorflow.NativeLibrary.DEBUG=1 to the system properties of the JVM.",
"output_type": "error",
"traceback": [
"java.lang.UnsatisfiedLinkError: Cannot find TensorFlow native library for OS: darwin, architecture: aarch64. See https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java/README.md for possible solutions (such as building the library from source). Additional information on attempts to find the native library can be obtained by adding org.tensorflow.NativeLibrary.DEBUG=1 to the system properties of the JVM.",
"\tat org.tensorflow.NativeLibrary.load(NativeLibrary.java:77)",
"\tat org.tensorflow.TensorFlow.init(TensorFlow.java:67)",
"\tat org.tensorflow.TensorFlow.(TensorFlow.java:82)",
"\tat org.tensorflow.Graph.(Graph.java:479)",
"\tat org.jetbrains.kotlinx.dl.api.core.GraphTrainableModel.(GraphTrainableModel.kt:113)",
"\tat org.jetbrains.kotlinx.dl.api.core.Sequential.(Sequential.kt:26)",
"\tat org.jetbrains.kotlinx.dl.api.core.Sequential$Companion.of(Sequential.kt:45)",
"\tat org.jetbrains.kotlinx.dl.api.core.Sequential$Companion.of$default(Sequential.kt:39)",
"\tat Line_25_jupyter.(Line_25.jupyter.kts:3) at Cell In[8], line 3",
"\tat java.base/jdk.internal.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)",
"\tat java.base/jdk.internal.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)",
"\tat java.base/jdk.internal.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)",
"\tat java.base/java.lang.reflect.Constructor.newInstance(Constructor.java:490)",
"\tat kotlin.script.experimental.jvm.BasicJvmScriptEvaluator.evalWithConfigAndOtherScriptsResults(BasicJvmScriptEvaluator.kt:122)",
"\tat kotlin.script.experimental.jvm.BasicJvmScriptEvaluator.invoke$suspendImpl(BasicJvmScriptEvaluator.kt:48)",
"\tat kotlin.script.experimental.jvm.BasicJvmScriptEvaluator.invoke(BasicJvmScriptEvaluator.kt)",
"\tat kotlin.script.experimental.jvm.BasicJvmReplEvaluator.eval(BasicJvmReplEvaluator.kt:49)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.InternalEvaluatorImpl$eval$resultWithDiagnostics$1.invokeSuspend(InternalEvaluatorImpl.kt:137)",
"\tat kotlin.coroutines.jvm.internal.BaseContinuationImpl.resumeWith(ContinuationImpl.kt:33)",
"\tat kotlinx.coroutines.DispatchedTask.run(DispatchedTask.kt:104)",
"\tat kotlinx.coroutines.EventLoopImplBase.processNextEvent(EventLoop.common.kt:277)",
"\tat kotlinx.coroutines.BlockingCoroutine.joinBlocking(Builders.kt:95)",
"\tat kotlinx.coroutines.BuildersKt__BuildersKt.runBlocking(Builders.kt:69)",
"\tat kotlinx.coroutines.BuildersKt.runBlocking(Unknown Source)",
"\tat kotlinx.coroutines.BuildersKt__BuildersKt.runBlocking$default(Builders.kt:48)",
"\tat kotlinx.coroutines.BuildersKt.runBlocking$default(Unknown Source)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.InternalEvaluatorImpl.eval(InternalEvaluatorImpl.kt:137)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.CellExecutorImpl$execute$1$result$1.invoke(CellExecutorImpl.kt:80)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.CellExecutorImpl$execute$1$result$1.invoke(CellExecutorImpl.kt:78)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.ReplForJupyterImpl.withHost(ReplForJupyterImpl.kt:774)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.CellExecutorImpl.execute-L4Nmkdk(CellExecutorImpl.kt:78)",
"\tat org.jetbrains.kotlinx.jupyter.repl.execution.CellExecutor$DefaultImpls.execute-L4Nmkdk$default(CellExecutor.kt:13)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.ReplForJupyterImpl.evaluateUserCode-wNURfNM(ReplForJupyterImpl.kt:596)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.ReplForJupyterImpl.evalExImpl(ReplForJupyterImpl.kt:454)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.ReplForJupyterImpl.access$evalExImpl(ReplForJupyterImpl.kt:141)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.ReplForJupyterImpl$evalEx$1.invoke(ReplForJupyterImpl.kt:447)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.ReplForJupyterImpl$evalEx$1.invoke(ReplForJupyterImpl.kt:446)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.ReplForJupyterImpl.withEvalContext(ReplForJupyterImpl.kt:427)",
"\tat org.jetbrains.kotlinx.jupyter.repl.impl.ReplForJupyterImpl.evalEx(ReplForJupyterImpl.kt:446)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor$processExecuteRequest$1$response$1$1.invoke(IdeCompatibleMessageRequestProcessor.kt:171)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor$processExecuteRequest$1$response$1$1.invoke(IdeCompatibleMessageRequestProcessor.kt:170)",
"\tat org.jetbrains.kotlinx.jupyter.streams.BlockingSubstitutionEngine.withDataSubstitution(SubstitutionEngine.kt:70)",
"\tat org.jetbrains.kotlinx.jupyter.streams.StreamSubstitutionManager.withSubstitutedStreams(StreamSubstitutionManager.kt:118)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor.withForkedIn(IdeCompatibleMessageRequestProcessor.kt:347)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor.access$withForkedIn(IdeCompatibleMessageRequestProcessor.kt:67)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor$evalWithIO$1$1.invoke(IdeCompatibleMessageRequestProcessor.kt:361)",
"\tat org.jetbrains.kotlinx.jupyter.streams.BlockingSubstitutionEngine.withDataSubstitution(SubstitutionEngine.kt:70)",
"\tat org.jetbrains.kotlinx.jupyter.streams.StreamSubstitutionManager.withSubstitutedStreams(StreamSubstitutionManager.kt:118)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor.withForkedErr(IdeCompatibleMessageRequestProcessor.kt:336)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor.access$withForkedErr(IdeCompatibleMessageRequestProcessor.kt:67)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor$evalWithIO$1.invoke(IdeCompatibleMessageRequestProcessor.kt:360)",
"\tat org.jetbrains.kotlinx.jupyter.streams.BlockingSubstitutionEngine.withDataSubstitution(SubstitutionEngine.kt:70)",
"\tat org.jetbrains.kotlinx.jupyter.streams.StreamSubstitutionManager.withSubstitutedStreams(StreamSubstitutionManager.kt:118)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor.withForkedOut(IdeCompatibleMessageRequestProcessor.kt:328)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor.evalWithIO(IdeCompatibleMessageRequestProcessor.kt:359)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor$processExecuteRequest$1$response$1.invoke(IdeCompatibleMessageRequestProcessor.kt:170)",
"\tat org.jetbrains.kotlinx.jupyter.messaging.IdeCompatibleMessageRequestProcessor$processExecuteRequest$1$response$1.invoke(IdeCompatibleMessageRequestProcessor.kt:169)",
"\tat org.jetbrains.kotlinx.jupyter.execution.JupyterExecutorImpl$Task.execute(JupyterExecutorImpl.kt:41)",
"\tat org.jetbrains.kotlinx.jupyter.execution.JupyterExecutorImpl$executorThread$1.invoke(JupyterExecutorImpl.kt:81)",
"\tat org.jetbrains.kotlinx.jupyter.execution.JupyterExecutorImpl$executorThread$1.invoke(JupyterExecutorImpl.kt:79)",
"\tat kotlin.concurrent.ThreadsKt$thread$thread$1.run(Thread.kt:30)",
"",
"java.lang.UnsatisfiedLinkError: Cannot find TensorFlow native library for OS: darwin, architecture: aarch64. See https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java/README.md for possible solutions (such as building the library from source). Additional information on attempts to find the native library can be obtained by adding org.tensorflow.NativeLibrary.DEBUG=1 to the system properties of the JVM.",
"at Cell In[8], line 3",
""
]
}
],
"execution_count": 8
},
{
"cell_type": "code",
"metadata": {},
"source": [
"model.compile(optimizer = Adam(), loss = Losses.MSE, metric = Metrics.MAE)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"model.printSummary()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train it!"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"val trainHist = model.fit(train, batchSize = 500, epochs=2000)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"trainHist.epochHistory.toDataFrame().tail()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check that our network predicts values more or less correctly:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"model.predictSoftly(test.x[9])[0]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"test.y[9]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Close the model:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"model.close()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict wine quality: second approach"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"/**\n",
" * Result of [trainTestSplit]: feature (`X`) and label (`Y`) frames for the train and test parts.\n",
" */\n",
"data class TrainTestSplitResult(\n",
"    val trainX: DataFrame<*>,\n",
"    val trainY: DataFrame<*>,\n",
"    val testX: DataFrame<*>,\n",
"    val testY: DataFrame<*>,\n",
")\n",
"\n",
"/**\n",
" * Randomly splits [d] into train and test parts and separates the label column from the features.\n",
" *\n",
" * @param d frame to split\n",
" * @param col name of the label column; it becomes the only column of `trainY` and `testY`\n",
" * @param trainPart fraction of rows (0.0..1.0) assigned to the train part, rounded up\n",
" * @param random source of randomness for the shuffle; pass a seeded instance for a reproducible split\n",
" * @throws IllegalArgumentException if [trainPart] is outside 0.0..1.0\n",
" */\n",
"fun trainTestSplit(\n",
"    d: DataFrame<*>,\n",
"    col: String,\n",
"    trainPart: Double,\n",
"    random: kotlin.random.Random = kotlin.random.Random.Default,\n",
"): TrainTestSplitResult {\n",
"    // Fail early with a clear message instead of an index error from subList below.\n",
"    require(trainPart in 0.0..1.0) { \"trainPart must be within 0.0..1.0, but was $trainPart\" }\n",
"\n",
"    val n = d.count()\n",
"    val trainN = ceil(n * trainPart).toInt()\n",
"\n",
"    // Shuffle the row indices once, then cut the shuffled list into two disjoint parts.\n",
"    val shuffledInd = (0 until n).shuffled(random)\n",
"    val train = d[shuffledInd.subList(0, trainN)]\n",
"    val test = d[shuffledInd.subList(trainN, n)]\n",
"\n",
"    return TrainTestSplitResult(\n",
"        trainX = train.remove(col),\n",
"        trainY = train.select(col),\n",
"        testX = test.remove(col),\n",
"        testY = test.select(col),\n",
"    )\n",
"}"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create and then train the model as we did before"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"val (trainX, trainY, testX, testY) =\n",
" trainTestSplit(df, \"quality\", 0.8)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"/**\n",
" * Converts the numeric columns of this frame into a KotlinDL feature matrix.\n",
" *\n",
" * @return one `FloatArray` per row, holding that row's numeric cell values converted to `Float`\n",
" */\n",
"fun DataFrame<*>.toX(): Array<FloatArray> =\n",
"    merge { colsOf<Number>() }\n",
"        .by { cells -> cells.map { it.toFloat() }.toFloatArray() }\n",
"        .into(\"X\")\n",
"        .get { \"X\"<FloatArray>() }\n",
"        .toList()\n",
"        .toTypedArray()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"/**\n",
" * Extracts the label column of this frame as a KotlinDL label vector.\n",
" *\n",
" * @param labelColumn name of the label column; defaults to \"quality\"\n",
" * @return the column's values converted to `Float`, in row order\n",
" */\n",
"fun DataFrame<*>.toY(labelColumn: String = \"quality\"): FloatArray =\n",
"    // Convert through Number so the column may hold any non-null numeric type.\n",
"    this[labelColumn].values()\n",
"        .map { (it as Number).toFloat() }\n",
"        .toFloatArray()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"val trainXDL = trainX.toX()\n",
"val trainYDL = trainY.toY()\n",
"val testXDL = testX.toX()\n",
"val testYDL = testY.toY()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"val trainKotlinDLDataset = OnHeapDataset.create({ trainXDL }, { trainYDL })\n",
"val testKotlinDLDataset = OnHeapDataset.create({ testXDL }, { testYDL })"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"// Size the input layer from this section's own feature matrix (`trainXDL`) rather than\n",
"// from the first approach's `train` dataset.\n",
"val inputNeurons = trainXDL[0].size.toLong()\n",
"\n",
"// Same architecture as the first approach: two Tanh hidden layers and a linear output.\n",
"val model2 = Sequential.of(\n",
"    Input(\n",
"        inputNeurons\n",
"    ),\n",
"    Dense(\n",
"        outputSize = (inputNeurons * 10).toInt(),\n",
"        activation = Activations.Tanh,\n",
"        kernelInitializer = HeNormal(),\n",
"        biasInitializer = HeNormal()\n",
"    ),\n",
"    Dense(\n",
"        outputSize = (inputNeurons * 10).toInt(),\n",
"        activation = Activations.Tanh,\n",
"        kernelInitializer = HeNormal(),\n",
"        biasInitializer = HeNormal()\n",
"    ),\n",
"    Dense(\n",
"        outputSize = 1,\n",
"        activation = Activations.Linear,\n",
"        kernelInitializer = HeNormal(),\n",
"        biasInitializer = HeNormal()\n",
"    )\n",
")\n",
"model2.compile(optimizer = Adam(), loss = Losses.MSE, metric = Metrics.MAE)\n",
"model2.printSummary()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"// Train on the dataset produced by this section's own split. Fitting on the first\n",
"// approach's `train` would leak rows of the independently shuffled test part into training.\n",
"val trainHist = model2.fit(trainKotlinDLDataset, batchSize = 500, epochs = 2000)\n",
"trainHist.epochHistory.toDataFrame().tail()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"model2.predictSoftly(testXDL[9])[0]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"testYDL[9]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also compare predicted and ground truth values to ensure predictions are correct"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"// Round each soft prediction to the nearest integer quality score.\n",
"val predicted = testXDL\n",
"    .map { features -> round(model2.predictSoftly(features)[0]).toInt() }\n",
"    .toColumn(\"predicted\")\n",
"\n",
"// The labels are held in a FloatArray; convert them to integer scores.\n",
"val ground_truth = testYDL\n",
"    .map { label -> label.toInt() }\n",
"    .toColumn(\"ground_truth\")\n",
"\n",
"val predDf = dataFrameOf(predicted, ground_truth)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"predDf.head()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"// Cross-tabulate ground-truth scores (group rows) against predicted scores (pivoted columns).\n",
"val ctab = predDf\n",
"    .groupBy { ground_truth }.pivotCounts(inward = false) { predicted }\n",
"    .sortBy { ground_truth }\n",
"\n",
"// Format every column but the first. The green channel gets brighter as the\n",
"// predicted score (column name) gets closer to the true score of the row.\n",
"ctab.format { drop(1) }.perRowCol { row, col ->\n",
"    val y = col.name().toInt()\n",
"    val x = row.ground_truth\n",
"    val k = 1.0 - abs(x - y) / 10.0\n",
"    background(RGBColor(50, (50 + k * 200).toInt().toShort(), 50))\n",
"}"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"val predDf2 = predDf.add(\"avg_dev\") { abs(predicted - ground_truth) }"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"predDf2.avg_dev.cast().describe()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"// Row at the 70% position of the rows ordered by absolute error. The index is derived\n",
"// from the actual row count instead of a hard-coded test-set size.\n",
"predDf2.sortBy { avg_dev }[(0.7 * (predDf2.count() - 1)).toInt()]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"metadata": {},
"source": [
"model2.close()"
],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Kotlin",
"language": "kotlin",
"name": "kotlin"
},
"language_info": {
"codemirror_mode": "text/x-kotlin",
"file_extension": ".kt",
"mimetype": "text/x-kotlin",
"name": "kotlin",
"nbconvert_exporter": "",
"pygments_lexer": "kotlin",
"version": "1.8.0-dev-707"
},
"ktnbPluginMetadata": {
"projectLibraries": []
}
},
"nbformat": 4,
"nbformat_minor": 1
}