{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predict wine quality"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the latest versions of DataFrame and KotlinDL libraries from [version repository](https://github.com/Kotlin/kotlin-jupyter-libraries)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"pycharm": {
"is_executing": true
},
"ExecuteTime": {
"end_time": "2023-12-05T11:16:19.227175498Z",
"start_time": "2023-12-05T11:16:10.985704525Z"
}
},
"outputs": [],
"source": [
"%useLatestDescriptors\n",
"%use dataframe, kotlin-dl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Read the dataframe from CSV and print the first few lines of it"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:22.897712600Z",
"start_time": "2023-12-05T11:16:19.230308412Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
"
\n",
" \n",
" \n",
" \n",
" | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 7.400000 | 0.700000 | 0.000000 | 1.900000 | 0.076000 | 11.000000 | 34.000000 | 0.997800 | 3.510000 | 0.560000 | 9.400000 | 5 |
| 7.800000 | 0.880000 | 0.000000 | 2.600000 | 0.098000 | 25.000000 | 67.000000 | 0.996800 | 3.200000 | 0.680000 | 9.800000 | 5 |
| 7.800000 | 0.760000 | 0.040000 | 2.300000 | 0.092000 | 15.000000 | 54.000000 | 0.997000 | 3.260000 | 0.650000 | 9.800000 | 5 |
| 11.200000 | 0.280000 | 0.560000 | 1.900000 | 0.075000 | 17.000000 | 60.000000 | 0.998000 | 3.160000 | 0.580000 | 9.800000 | 6 |
| 7.400000 | 0.700000 | 0.000000 | 1.900000 | 0.076000 | 11.000000 | 34.000000 | 0.997800 | 3.510000 | 0.560000 | 9.400000 | 5 |
\n",
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":5,\"ncol\":12,\"columns\":[\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"residual sugar\",\"chlorides\",\"free sulfur dioxide\",\"total sulfur dioxide\",\"density\",\"pH\",\"sulphates\",\"alcohol\",\"quality\"],\"kotlin_dataframe\":[{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"residual sugar\":1.9,\"chlorides\":0.076,\"free sulfur dioxide\":11.0,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"pH\":3.51,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.88,\"citric acid\":0.0,\"residual sugar\":2.6,\"chlorides\":0.098,\"free sulfur dioxide\":25.0,\"total sulfur dioxide\":67.0,\"density\":0.9968,\"pH\":3.2,\"sulphates\":0.68,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.76,\"citric acid\":0.04,\"residual sugar\":2.3,\"chlorides\":0.092,\"free sulfur dioxide\":15.0,\"total sulfur dioxide\":54.0,\"density\":0.997,\"pH\":3.26,\"sulphates\":0.65,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":11.2,\"volatile acidity\":0.28,\"citric acid\":0.56,\"residual sugar\":1.9,\"chlorides\":0.075,\"free sulfur dioxide\":17.0,\"total sulfur dioxide\":60.0,\"density\":0.998,\"pH\":3.16,\"sulphates\":0.58,\"alcohol\":9.8,\"quality\":6},{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"residual sugar\":1.9,\"chlorides\":0.076,\"free sulfur dioxide\":11.0,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"pH\":3.51,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 2
}
],
"source": [
"val raw_df = DataFrame.readCSV(fileOrUrl = \"winequality-red.csv\", delimiter = ';')\n",
"raw_df.head()"
]
},
{
"cell_type": "markdown",
"source": [
"Note: For formatting, the DataFrame needs to be rendered as HTML. This means that when running in Kotlin Notebook, \"Render DataFrame tables natively\" needs to be turned off."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:23.901440325Z",
"start_time": "2023-12-05T11:16:22.875220486Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" | column | fixed acidity | volatile acidity | citric acid | residual sugar | chlorides | free sulfur dioxide | total sulfur dioxide | density | pH | sulphates | alcohol | quality |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| fixed acidity | 1.000000 | -0.256131 | 0.671703 | 0.114777 | 0.093705 | -0.153794 | -0.113181 | 0.668047 | -0.682978 | 0.183006 | -0.061668 | 0.124052 |
| volatile acidity | -0.256131 | 1.000000 | -0.552496 | 0.001918 | 0.061298 | -0.010504 | 0.076470 | 0.022026 | 0.234937 | -0.260987 | -0.202288 | -0.390558 |
| citric acid | 0.671703 | -0.552496 | 1.000000 | 0.143577 | 0.203823 | -0.060978 | 0.035533 | 0.364947 | -0.541904 | 0.312770 | 0.109903 | 0.226373 |
| residual sugar | 0.114777 | 0.001918 | 0.143577 | 1.000000 | 0.055610 | 0.187049 | 0.203028 | 0.355283 | -0.085652 | 0.005527 | 0.042075 | 0.013732 |
| chlorides | 0.093705 | 0.061298 | 0.203823 | 0.055610 | 1.000000 | 0.005562 | 0.047400 | 0.200632 | -0.265026 | 0.371260 | -0.221141 | -0.128907 |
| free sulfur dioxide | -0.153794 | -0.010504 | -0.060978 | 0.187049 | 0.005562 | 1.000000 | 0.667666 | -0.021946 | 0.070377 | 0.051658 | -0.069408 | -0.050656 |
| total sulfur dioxide | -0.113181 | 0.076470 | 0.035533 | 0.203028 | 0.047400 | 0.667666 | 1.000000 | 0.071269 | -0.066495 | 0.042947 | -0.205654 | -0.185100 |
| density | 0.668047 | 0.022026 | 0.364947 | 0.355283 | 0.200632 | -0.021946 | 0.071269 | 1.000000 | -0.341699 | 0.148506 | -0.496180 | -0.174919 |
| pH | -0.682978 | 0.234937 | -0.541904 | -0.085652 | -0.265026 | 0.070377 | -0.066495 | -0.341699 | 1.000000 | -0.196648 | 0.205633 | -0.057731 |
| sulphates | 0.183006 | -0.260987 | 0.312770 | 0.005527 | 0.371260 | 0.051658 | 0.042947 | 0.148506 | -0.196648 | 1.000000 | 0.093595 | 0.251397 |
| alcohol | -0.061668 | -0.202288 | 0.109903 | 0.042075 | -0.221141 | -0.069408 | -0.205654 | -0.496180 | 0.205633 | 0.093595 | 1.000000 | 0.476166 |
| quality | 0.124052 | -0.390558 | 0.226373 | 0.013732 | -0.128907 | -0.050656 | -0.185100 | -0.174919 | -0.057731 | 0.251397 | 0.476166 | 1.000000 |
\n",
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":12,\"ncol\":13,\"columns\":[\"column\",\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"residual sugar\",\"chlorides\",\"free sulfur dioxide\",\"total sulfur dioxide\",\"density\",\"pH\",\"sulphates\",\"alcohol\",\"quality\"],\"kotlin_dataframe\":[{\"column\":\"fixed acidity\",\"fixed acidity\":1.0,\"volatile acidity\":-0.2561308947703819,\"citric acid\":0.6717034347641041,\"residual sugar\":0.1147767244949209,\"chlorides\":0.09370518632130498,\"free sulfur dioxide\":-0.15379419286482485,\"total sulfur dioxide\":-0.11318144304548039,\"density\":0.6680472921189711,\"pH\":-0.6829781945685299,\"sulphates\":0.18300566393215348,\"alcohol\":-0.06166827062815111,\"quality\":0.1240516491132247},{\"column\":\"volatile acidity\",\"fixed acidity\":-0.2561308947703819,\"volatile acidity\":1.0,\"citric acid\":-0.5524956845595839,\"residual sugar\":0.001917881962790698,\"chlorides\":0.061297772476461614,\"free sulfur dioxide\":-0.010503827006591856,\"total sulfur dioxide\":0.07647000482092836,\"density\":0.022026232195215885,\"pH\":0.23493729440739436,\"sulphates\":-0.26098668528329055,\"alcohol\":-0.20228802715325686,\"quality\":-0.3905577802640094},{\"column\":\"citric acid\",\"fixed acidity\":0.6717034347641041,\"volatile acidity\":-0.5524956845595839,\"citric acid\":1.0,\"residual sugar\":0.14357716157031483,\"chlorides\":0.2038229138290425,\"free sulfur dioxide\":-0.06097812919230497,\"total sulfur dioxide\":0.035533023931161666,\"density\":0.364947175211252,\"pH\":-0.5419041447395132,\"sulphates\":0.31277004385441737,\"alcohol\":0.10990324664156755,\"quality\":0.2263725143180432},{\"column\":\"residual sugar\",\"fixed acidity\":0.1147767244949209,\"volatile acidity\":0.001917881962790698,\"citric acid\":0.14357716157031483,\"residual sugar\":1.0,\"chlorides\":0.05560953520353218,\"free sulfur dioxide\":0.18704899510428666,\"total sulfur 
dioxide\":0.2030278816971015,\"density\":0.35528337098337653,\"pH\":-0.08565242221887161,\"sulphates\":0.005527121339138363,\"alcohol\":0.04207543720973116,\"quality\":0.013731637340066346},{\"column\":\"chlorides\",\"fixed acidity\":0.09370518632130498,\"volatile acidity\":0.061297772476461614,\"citric acid\":0.2038229138290425,\"residual sugar\":0.05560953520353218,\"chlorides\":1.0,\"free sulfur dioxide\":0.005562147004781117,\"total sulfur dioxide\":0.04740046825907533,\"density\":0.200632326641512,\"pH\":-0.2650261311732279,\"sulphates\":0.371260481285427,\"alcohol\":-0.22114054478828302,\"quality\":-0.12890655993005312},{\"column\":\"free sulfur dioxide\",\"fixed acidity\":-0.15379419286482485,\"volatile acidity\":-0.010503827006591856,\"citric acid\":-0.06097812919230497,\"residual sugar\":0.18704899510428666,\"chlorides\":0.005562147004781117,\"free sulfur dioxide\":1.0,\"total sulfur dioxide\":0.6676664504810212,\"density\":-0.021945831163489242,\"pH\":0.07037749850494217,\"sulphates\":0.051657571842828584,\"alcohol\":-0.06940835356499997,\"quality\":-0.05065605724427643},{\"column\":\"total sulfur dioxide\",\"fixed acidity\":-0.11318144304548039,\"volatile acidity\":0.07647000482092836,\"citric acid\":0.035533023931161666,\"residual sugar\":0.2030278816971015,\"chlorides\":0.04740046825907533,\"free sulfur dioxide\":0.6676664504810212,\"total sulfur dioxide\":1.0,\"density\":0.07126947620310328,\"pH\":-0.06649455901285606,\"sulphates\":0.04294683623953844,\"alcohol\":-0.20565394374367177,\"quality\":-0.18510028892653843},{\"column\":\"density\",\"fixed acidity\":0.6680472921189711,\"volatile acidity\":0.022026232195215885,\"citric acid\":0.364947175211252,\"residual sugar\":0.35528337098337653,\"chlorides\":0.200632326641512,\"free sulfur dioxide\":-0.021945831163489242,\"total sulfur 
dioxide\":0.07126947620310328,\"density\":1.0,\"pH\":-0.3416993347850301,\"sulphates\":0.14850641172078524,\"alcohol\":-0.4961797702417023,\"quality\":-0.1749192277833492},{\"column\":\"pH\",\"fixed acidity\":-0.6829781945685299,\"volatile acidity\":0.23493729440739436,\"citric acid\":-0.5419041447395132,\"residual sugar\":-0.08565242221887161,\"chlorides\":-0.2650261311732279,\"free sulfur dioxide\":0.07037749850494217,\"total sulfur dioxide\":-0.06649455901285606,\"density\":-0.3416993347850301,\"pH\":1.0,\"sulphates\":-0.1966476023043703,\"alcohol\":0.20563250850549894,\"quality\":-0.0577313912053823},{\"column\":\"sulphates\",\"fixed acidity\":0.18300566393215348,\"volatile acidity\":-0.26098668528329055,\"citric acid\":0.31277004385441737,\"residual sugar\":0.005527121339138363,\"chlorides\":0.371260481285427,\"free sulfur dioxide\":0.051657571842828584,\"total sulfur dioxide\":0.04294683623953844,\"density\":0.14850641172078524,\"pH\":-0.1966476023043703,\"sulphates\":1.0,\"alcohol\":0.09359475041046762,\"quality\":0.25139707906926206},{\"column\":\"alcohol\",\"fixed acidity\":-0.06166827062815111,\"volatile acidity\":-0.20228802715325686,\"citric acid\":0.10990324664156755,\"residual sugar\":0.04207543720973116,\"chlorides\":-0.22114054478828302,\"free sulfur dioxide\":-0.06940835356499997,\"total sulfur dioxide\":-0.20565394374367177,\"density\":-0.4961797702417023,\"pH\":0.20563250850549894,\"sulphates\":0.09359475041046762,\"alcohol\":1.0,\"quality\":0.47616632400114156},{\"column\":\"quality\",\"fixed acidity\":0.1240516491132247,\"volatile acidity\":-0.3905577802640094,\"citric acid\":0.2263725143180432,\"residual sugar\":0.013731637340066346,\"chlorides\":-0.12890655993005312,\"free sulfur dioxide\":-0.05065605724427643,\"total sulfur dioxide\":-0.18510028892653843,\"density\":-0.1749192277833492,\"pH\":-0.0577313912053823,\"sulphates\":0.25139707906926206,\"alcohol\":0.47616632400114156,\"quality\":1.0}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 3
}
],
"source": [
"raw_df.corr().format { colsOf() }.with { \n",
" linearBg(value = it, from = -1.0 to red, to = 1.0 to green)\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Based on the correlation, we can remove some columns, they seem to be insignificant"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:25.803086434Z",
"start_time": "2023-12-05T11:16:23.862783060Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" | fixed acidity | volatile acidity | citric acid | chlorides | total sulfur dioxide | density | sulphates | alcohol | quality |
|---|---|---|---|---|---|---|---|---|
| 7.400000 | 0.700000 | 0.000000 | 0.076000 | 34.000000 | 0.997800 | 0.560000 | 9.400000 | 5 |
| 7.800000 | 0.880000 | 0.000000 | 0.098000 | 67.000000 | 0.996800 | 0.680000 | 9.800000 | 5 |
| 7.800000 | 0.760000 | 0.040000 | 0.092000 | 54.000000 | 0.997000 | 0.650000 | 9.800000 | 5 |
| 11.200000 | 0.280000 | 0.560000 | 0.075000 | 60.000000 | 0.998000 | 0.580000 | 9.800000 | 6 |
| 7.400000 | 0.700000 | 0.000000 | 0.076000 | 34.000000 | 0.997800 | 0.560000 | 9.400000 | 5 |
| 7.400000 | 0.660000 | 0.000000 | 0.075000 | 40.000000 | 0.997800 | 0.560000 | 9.400000 | 5 |
| 7.900000 | 0.600000 | 0.060000 | 0.069000 | 59.000000 | 0.996400 | 0.460000 | 9.400000 | 5 |
| 7.300000 | 0.650000 | 0.000000 | 0.065000 | 21.000000 | 0.994600 | 0.470000 | 10.000000 | 7 |
| 7.800000 | 0.580000 | 0.020000 | 0.073000 | 18.000000 | 0.996800 | 0.570000 | 9.500000 | 7 |
| 7.500000 | 0.500000 | 0.360000 | 0.071000 | 102.000000 | 0.997800 | 0.800000 | 10.500000 | 5 |
| 6.700000 | 0.580000 | 0.080000 | 0.097000 | 65.000000 | 0.995900 | 0.540000 | 9.200000 | 5 |
| 7.500000 | 0.500000 | 0.360000 | 0.071000 | 102.000000 | 0.997800 | 0.800000 | 10.500000 | 5 |
| 5.600000 | 0.615000 | 0.000000 | 0.089000 | 59.000000 | 0.994300 | 0.520000 | 9.900000 | 5 |
| 7.800000 | 0.610000 | 0.290000 | 0.114000 | 29.000000 | 0.997400 | 1.560000 | 9.100000 | 5 |
| 8.900000 | 0.620000 | 0.180000 | 0.176000 | 145.000000 | 0.998600 | 0.880000 | 9.200000 | 5 |
| 8.900000 | 0.620000 | 0.190000 | 0.170000 | 148.000000 | 0.998600 | 0.930000 | 9.200000 | 5 |
| 8.500000 | 0.280000 | 0.560000 | 0.092000 | 103.000000 | 0.996900 | 0.750000 | 10.500000 | 7 |
| 8.100000 | 0.560000 | 0.280000 | 0.368000 | 56.000000 | 0.996800 | 1.280000 | 9.300000 | 5 |
| 7.400000 | 0.590000 | 0.080000 | 0.086000 | 29.000000 | 0.997400 | 0.500000 | 9.000000 | 4 |
| 7.900000 | 0.320000 | 0.510000 | 0.341000 | 56.000000 | 0.996900 | 1.080000 | 9.200000 | 6 |
\n",
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":1599,\"ncol\":9,\"columns\":[\"fixed acidity\",\"volatile acidity\",\"citric acid\",\"chlorides\",\"total sulfur dioxide\",\"density\",\"sulphates\",\"alcohol\",\"quality\"],\"kotlin_dataframe\":[{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"chlorides\":0.076,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.88,\"citric acid\":0.0,\"chlorides\":0.098,\"total sulfur dioxide\":67.0,\"density\":0.9968,\"sulphates\":0.68,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.76,\"citric acid\":0.04,\"chlorides\":0.092,\"total sulfur dioxide\":54.0,\"density\":0.997,\"sulphates\":0.65,\"alcohol\":9.8,\"quality\":5},{\"fixed acidity\":11.2,\"volatile acidity\":0.28,\"citric acid\":0.56,\"chlorides\":0.075,\"total sulfur dioxide\":60.0,\"density\":0.998,\"sulphates\":0.58,\"alcohol\":9.8,\"quality\":6},{\"fixed acidity\":7.4,\"volatile acidity\":0.7,\"citric acid\":0.0,\"chlorides\":0.076,\"total sulfur dioxide\":34.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.4,\"volatile acidity\":0.66,\"citric acid\":0.0,\"chlorides\":0.075,\"total sulfur dioxide\":40.0,\"density\":0.9978,\"sulphates\":0.56,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.9,\"volatile acidity\":0.6,\"citric acid\":0.06,\"chlorides\":0.069,\"total sulfur dioxide\":59.0,\"density\":0.9964,\"sulphates\":0.46,\"alcohol\":9.4,\"quality\":5},{\"fixed acidity\":7.3,\"volatile acidity\":0.65,\"citric acid\":0.0,\"chlorides\":0.065,\"total sulfur dioxide\":21.0,\"density\":0.9946,\"sulphates\":0.47,\"alcohol\":10.0,\"quality\":7},{\"fixed acidity\":7.8,\"volatile acidity\":0.58,\"citric acid\":0.02,\"chlorides\":0.073,\"total sulfur dioxide\":18.0,\"density\":0.9968,\"sulphates\":0.57,\"alcohol\":9.5,\"quality\":7},{\"fixed acidity\":7.5,\"volatile acidity\":0.5,\"citric 
acid\":0.36,\"chlorides\":0.071,\"total sulfur dioxide\":102.0,\"density\":0.9978,\"sulphates\":0.8,\"alcohol\":10.5,\"quality\":5},{\"fixed acidity\":6.7,\"volatile acidity\":0.58,\"citric acid\":0.08,\"chlorides\":0.097,\"total sulfur dioxide\":65.0,\"density\":0.9959,\"sulphates\":0.54,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":7.5,\"volatile acidity\":0.5,\"citric acid\":0.36,\"chlorides\":0.071,\"total sulfur dioxide\":102.0,\"density\":0.9978,\"sulphates\":0.8,\"alcohol\":10.5,\"quality\":5},{\"fixed acidity\":5.6,\"volatile acidity\":0.615,\"citric acid\":0.0,\"chlorides\":0.089,\"total sulfur dioxide\":59.0,\"density\":0.9943,\"sulphates\":0.52,\"alcohol\":9.9,\"quality\":5},{\"fixed acidity\":7.8,\"volatile acidity\":0.61,\"citric acid\":0.29,\"chlorides\":0.114,\"total sulfur dioxide\":29.0,\"density\":0.9974,\"sulphates\":1.56,\"alcohol\":9.1,\"quality\":5},{\"fixed acidity\":8.9,\"volatile acidity\":0.62,\"citric acid\":0.18,\"chlorides\":0.176,\"total sulfur dioxide\":145.0,\"density\":0.9986,\"sulphates\":0.88,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":8.9,\"volatile acidity\":0.62,\"citric acid\":0.19,\"chlorides\":0.17,\"total sulfur dioxide\":148.0,\"density\":0.9986,\"sulphates\":0.93,\"alcohol\":9.2,\"quality\":5},{\"fixed acidity\":8.5,\"volatile acidity\":0.28,\"citric acid\":0.56,\"chlorides\":0.092,\"total sulfur dioxide\":103.0,\"density\":0.9969,\"sulphates\":0.75,\"alcohol\":10.5,\"quality\":7},{\"fixed acidity\":8.1,\"volatile acidity\":0.56,\"citric acid\":0.28,\"chlorides\":0.368,\"total sulfur dioxide\":56.0,\"density\":0.9968,\"sulphates\":1.28,\"alcohol\":9.3,\"quality\":5},{\"fixed acidity\":7.4,\"volatile acidity\":0.59,\"citric acid\":0.08,\"chlorides\":0.086,\"total sulfur dioxide\":29.0,\"density\":0.9974,\"sulphates\":0.5,\"alcohol\":9.0,\"quality\":4},{\"fixed acidity\":7.9,\"volatile acidity\":0.32,\"citric acid\":0.51,\"chlorides\":0.341,\"total sulfur 
dioxide\":56.0,\"density\":0.9969,\"sulphates\":1.08,\"alcohol\":9.2,\"quality\":6}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 4
}
],
"source": [
"val df = raw_df.remove { `free sulfur dioxide` and `residual sugar` and pH }\n",
"df"
]
},
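{
"cell_type": "markdown",
"metadata": {},
"source": [
"The column choice above can be reproduced by hand. The following is a minimal, stdlib-only sketch that assumes the table shows Pearson coefficients; `pearson`, `weakFeatures`, and the 0.06 cut-off are hypothetical names and values chosen for illustration, not part of the DataFrame API.\n",
"\n",
"```kotlin\n",
"import kotlin.math.abs\n",
"import kotlin.math.sqrt\n",
"\n",
"// Pearson correlation coefficient of two equally sized samples.\n",
"fun pearson(x: DoubleArray, y: DoubleArray): Double {\n",
"    require(x.size == y.size && x.isNotEmpty())\n",
"    val meanX = x.average()\n",
"    val meanY = y.average()\n",
"    var cov = 0.0\n",
"    var varX = 0.0\n",
"    var varY = 0.0\n",
"    for (i in x.indices) {\n",
"        val dx = x[i] - meanX\n",
"        val dy = y[i] - meanY\n",
"        cov += dx * dy\n",
"        varX += dx * dx\n",
"        varY += dy * dy\n",
"    }\n",
"    return cov / sqrt(varX * varY)\n",
"}\n",
"\n",
"// Hypothetical helper: features whose absolute correlation with the target is below `threshold`.\n",
"fun weakFeatures(corrWithTarget: Map<String, Double>, threshold: Double): List<String> =\n",
"    corrWithTarget.filterValues { abs(it) < threshold }.keys.toList()\n",
"\n",
"fun main() {\n",
"    // Rounded values copied from the `quality` row of the correlation table.\n",
"    val corrWithQuality = mapOf(\n",
"        \"fixed acidity\" to 0.124,\n",
"        \"volatile acidity\" to -0.391,\n",
"        \"citric acid\" to 0.226,\n",
"        \"residual sugar\" to 0.014,\n",
"        \"chlorides\" to -0.129,\n",
"        \"free sulfur dioxide\" to -0.051,\n",
"        \"total sulfur dioxide\" to -0.185,\n",
"        \"density\" to -0.175,\n",
"        \"pH\" to -0.058,\n",
"        \"sulphates\" to 0.251,\n",
"        \"alcohol\" to 0.476,\n",
"    )\n",
"    println(weakFeatures(corrWithQuality, threshold = 0.06))\n",
"    // prints [residual sugar, free sulfur dioxide, pH]\n",
"}\n",
"```\n",
"\n",
"With this assumed cut-off, the sketch selects the same three columns that were removed above. A different threshold would keep or drop other features, so treat the value as a tuning choice."
]
},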
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict wine quality: first approach"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:26.957258498Z",
"start_time": "2023-12-05T11:16:25.797377618Z"
}
},
"outputs": [],
"source": [
"// Simple converter function between DataFrame and KotlinDL data representations\n",
"fun DataFrame.toOnHeapDataset(labelColumnName: String): OnHeapDataset {\n",
" return OnHeapDataset.create(\n",
" dataframe = this,\n",
" yColumn = labelColumnName\n",
" )\n",
"}\n",
"\n",
"fun OnHeapDataset.Companion.create(\n",
" dataframe: DataFrame,\n",
" yColumn: String\n",
"): OnHeapDataset {\n",
" fun extractX(): Array =\n",
" dataframe.remove(yColumn).rows()\n",
" .map { (it.values() as List).toFloatArray() }.toTypedArray()\n",
"\n",
" fun extractY(): FloatArray =\n",
" dataframe.get { yColumn() }.toList().toFloatArray()\n",
"\n",
" return create(\n",
" ::extractX,\n",
" ::extractY\n",
" )\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:27.394202002Z",
"start_time": "2023-12-05T11:16:26.797989581Z"
}
},
"outputs": [],
"source": [
"val (train, test) = df.convert { colsOf() }.toFloat()\n",
" .toOnHeapDataset(labelColumnName = \"quality\")\n",
" .split(0.8)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define simple neural network with only 2 dense layers"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:29.483880350Z",
"start_time": "2023-12-05T11:16:27.172120668Z"
}
},
"outputs": [],
"source": [
"val inputNeurons = train.x[0].size.toLong()\n",
"\n",
"val model = Sequential.of(\n",
" Input(\n",
" inputNeurons,\n",
" ),\n",
" Dense(\n",
" outputSize = (inputNeurons * 10).toInt(),\n",
" activation = Activations.Tanh,\n",
" kernelInitializer = HeNormal(),\n",
" biasInitializer = HeNormal(),\n",
" ),\n",
" Dense(\n",
" outputSize = (inputNeurons * 10).toInt(),\n",
" activation = Activations.Tanh,\n",
" kernelInitializer = HeNormal(),\n",
" biasInitializer = HeNormal(),\n",
" ),\n",
" Dense(\n",
" outputSize = 1,\n",
" activation = Activations.Linear,\n",
" kernelInitializer = HeNormal(),\n",
" biasInitializer = HeNormal(),\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:29.682507911Z",
"start_time": "2023-12-05T11:16:29.432917424Z"
}
},
"outputs": [],
"source": [
"model.compile(optimizer = Adam(), loss = Losses.MSE, metric = Metrics.MAE)"
]
},
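{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model is compiled with mean squared error (MSE) as the loss and mean absolute error (MAE) as the reported metric. As a rough illustration of what these two quantities measure, here is a stdlib-only sketch; `mse` and `mae` are hypothetical helpers and do not reflect how KotlinDL computes them internally.\n",
"\n",
"```kotlin\n",
"import kotlin.math.abs\n",
"\n",
"// Mean of squared differences: penalizes large errors more strongly.\n",
"fun mse(pred: FloatArray, truth: FloatArray): Float {\n",
"    require(pred.size == truth.size && pred.isNotEmpty())\n",
"    var sum = 0f\n",
"    for (i in pred.indices) {\n",
"        val d = pred[i] - truth[i]\n",
"        sum += d * d\n",
"    }\n",
"    return sum / pred.size\n",
"}\n",
"\n",
"// Mean of absolute differences: the average distance in quality points.\n",
"fun mae(pred: FloatArray, truth: FloatArray): Float {\n",
"    require(pred.size == truth.size && pred.isNotEmpty())\n",
"    var sum = 0f\n",
"    for (i in pred.indices) {\n",
"        sum += abs(pred[i] - truth[i])\n",
"    }\n",
"    return sum / pred.size\n",
"}\n",
"\n",
"fun main() {\n",
"    // Made-up predictions and labels, for illustration only.\n",
"    val pred = floatArrayOf(5.5f, 6.5f, 4.0f, 5.0f)\n",
"    val truth = floatArrayOf(5.0f, 6.0f, 5.0f, 5.0f)\n",
"    println(mse(pred, truth)) // 0.375\n",
"    println(mae(pred, truth)) // 0.5\n",
"}\n",
"```\n",
"\n",
"An MAE of 0.5 would mean that predictions are off by half a quality point on average, which makes the metric easier to read than the squared loss."
]
},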
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:29.791145300Z",
"start_time": "2023-12-05T11:16:29.626841115Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==============================================================================\n",
"Model type: Sequential\n",
"______________________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"==============================================================================\n",
"input_1(Input) [None, 8] 0 \n",
"______________________________________________________________________________\n",
"dense_2(Dense) [None, 80] 720 \n",
"______________________________________________________________________________\n",
"dense_3(Dense) [None, 80] 6480 \n",
"______________________________________________________________________________\n",
"dense_4(Dense) [None, 1] 81 \n",
"______________________________________________________________________________\n",
"==============================================================================\n",
"Total trainable params: 7281\n",
"Total frozen params: 0\n",
"Total params: 7281\n",
"______________________________________________________________________________\n"
]
}
],
"source": [
"model.printSummary()"
]
},
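{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in the summary follow from the layer sizes: a dense layer with `n` inputs and `m` outputs has `n * m` weights plus `m` biases. A short sketch that reproduces the numbers; `denseParams` is a hypothetical helper, not a KotlinDL function.\n",
"\n",
"```kotlin\n",
"// Trainable parameters of a fully connected layer: weights plus biases.\n",
"fun denseParams(inputs: Int, outputs: Int): Int = inputs * outputs + outputs\n",
"\n",
"fun main() {\n",
"    // Input width followed by the output size of each Dense layer.\n",
"    val layerSizes = listOf(8, 80, 80, 1)\n",
"    val perLayer = layerSizes.zipWithNext { i, o -> denseParams(i, o) }\n",
"    println(perLayer)       // [720, 6480, 81]\n",
"    println(perLayer.sum()) // 7281\n",
"}\n",
"```\n",
"\n",
"The totals match the `Param #` column and the `Total params` line printed by `printSummary()`."
]
},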
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train it!"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:38.582565633Z",
"start_time": "2023-12-05T11:16:29.756612239Z"
}
},
"outputs": [],
"source": [
"val trainHist = model.fit(train, batchSize = 500, epochs=2000)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:38.871304073Z",
"start_time": "2023-12-05T11:16:38.583131469Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" | epochIndex | lossValue | metricValues | valLossValue | valMetricValues |
|---|---|---|---|---|
| 1996 | 0.334877 | [0.45114806294441223] | NaN | [NaN] |
| 1997 | 0.334841 | [0.45111772418022156] | NaN | [NaN] |
| 1998 | 0.334805 | [0.45108696818351746] | NaN | [NaN] |
| 1999 | 0.334768 | [0.45105621218681335] | NaN | [NaN] |
| 2000 | 0.334732 | [0.4510253965854645] | NaN | [NaN] |
\n",
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":5,\"ncol\":5,\"columns\":[\"epochIndex\",\"lossValue\",\"metricValues\",\"valLossValue\",\"valMetricValues\"],\"kotlin_dataframe\":[{\"epochIndex\":1996,\"lossValue\":0.3348773419857025,\"metricValues\":[0.45114806294441223],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1997,\"lossValue\":0.3348410427570343,\"metricValues\":[0.45111772418022156],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1998,\"lossValue\":0.3348047435283661,\"metricValues\":[0.45108696818351746],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1999,\"lossValue\":0.3347683250904083,\"metricValues\":[0.45105621218681335],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":2000,\"lossValue\":0.33473196625709534,\"metricValues\":[0.4510253965854645],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 11
}
],
"source": [
"trainHist.epochHistory.toDataFrame().tail()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check that our network predicts values more or less correctly:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:39.115254360Z",
"start_time": "2023-12-05T11:16:38.837383297Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"5.2477317"
]
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 12
}
],
"source": [
"model.predictSoftly(test.x[9])[0]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:39.309250279Z",
"start_time": "2023-12-05T11:16:38.956283998Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"5.0"
]
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 13
}
],
"source": [
"test.y[9]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Close the model:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:39.367595540Z",
"start_time": "2023-12-05T11:16:39.021784659Z"
}
},
"outputs": [],
"source": [
"model.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict wine quality: second approach"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:39.858262608Z",
"start_time": "2023-12-05T11:16:39.092452824Z"
}
},
"outputs": [],
"source": [
"data class TrainTestSplitResult(\n",
" val trainX: DataFrame,\n",
" val trainY: DataFrame,\n",
" val testX: DataFrame,\n",
" val testY: DataFrame,\n",
")\n",
"\n",
"fun trainTestSplit(\n",
" d: DataFrame,\n",
" col: String,\n",
" trainPart: Double,\n",
"): TrainTestSplitResult {\n",
" val n = d.count()\n",
" val trainN = ceil(n * trainPart).toInt()\n",
"\n",
" val shuffledInd = (0 until n).shuffled()\n",
" val trainInd = shuffledInd.subList(0, trainN)\n",
" val testInd = shuffledInd.subList(trainN, n)\n",
"\n",
" val train = d[trainInd]\n",
" val test = d[testInd]\n",
"\n",
" val trainX = train.select { all().except(cols(col)) }\n",
" val trainY = train.select(col)\n",
"\n",
" val testX = test.select { all().except(cols(col)) }\n",
" val testY = test.select(col)\n",
"\n",
" return TrainTestSplitResult(trainX, trainY, testX, testY)\n",
"}"
]
},
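{
"cell_type": "markdown",
"metadata": {},
"source": [
"The helper above shuffles without a fixed seed, so every run produces a different split. If reproducibility matters, the same index-based split can be seeded. A minimal stdlib-only sketch; `splitIndices` and the seed value are hypothetical and not part of DataFrame or KotlinDL.\n",
"\n",
"```kotlin\n",
"import kotlin.math.ceil\n",
"import kotlin.random.Random\n",
"\n",
"// Returns (train indices, test indices) for n rows, shuffled with a fixed seed.\n",
"fun splitIndices(n: Int, trainPart: Double, seed: Long): Pair<List<Int>, List<Int>> {\n",
"    require(n >= 0 && trainPart in 0.0..1.0)\n",
"    val trainN = ceil(n * trainPart).toInt()\n",
"    val shuffled = (0 until n).shuffled(Random(seed))\n",
"    return shuffled.subList(0, trainN) to shuffled.subList(trainN, n)\n",
"}\n",
"\n",
"fun main() {\n",
"    // 1599 is the row count of the wine dataset used in this notebook.\n",
"    val (trainIdx, testIdx) = splitIndices(n = 1599, trainPart = 0.8, seed = 42L)\n",
"    println(trainIdx.size) // 1280\n",
"    println(testIdx.size)  // 319\n",
"}\n",
"```\n",
"\n",
"The two index lists are disjoint and together cover every row, so they could be passed to `d[trainInd]` and `d[testInd]` in the same way as in `trainTestSplit`."
]
},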
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create and then train the model as we did before"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:41.588501858Z",
"start_time": "2023-12-05T11:16:39.773619363Z"
}
},
"outputs": [],
"source": [
"val (trainX, trainY, testX, testY) =\n",
" trainTestSplit(df, \"quality\", 0.8)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:42.028324910Z",
"start_time": "2023-12-05T11:16:41.587739397Z"
}
},
"outputs": [],
"source": [
"fun DataFrame.toX(): Array =\n",
" merge { colsOf() }.by { it.map { it.toFloat() }.toFloatArray() }.into(\"X\")\n",
" .get { \"X\"() }\n",
" .toList()\n",
" .toTypedArray()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:42.332839608Z",
"start_time": "2023-12-05T11:16:41.996806378Z"
}
},
"outputs": [],
"source": [
"fun DataFrame.toY(): FloatArray = \n",
" get { \"quality\"() }\n",
" .asIterable()\n",
" .map { it.toFloat() }\n",
" .toFloatArray()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:42.527702465Z",
"start_time": "2023-12-05T11:16:42.217032792Z"
}
},
"outputs": [],
"source": [
"val trainXDL = trainX.toX()\n",
"val trainYDL = trainY.toY()\n",
"val testXDL = testX.toX()\n",
"val testYDL = testY.toY()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:42.774734665Z",
"start_time": "2023-12-05T11:16:42.439947039Z"
}
},
"outputs": [],
"source": [
"val trainKotlinDLDataset = OnHeapDataset.create({ trainXDL }, { trainYDL })\n",
"val testKotlinDLDataset = OnHeapDataset.create({ testXDL }, { testYDL })"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:43.061367684Z",
"start_time": "2023-12-05T11:16:42.605357371Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==============================================================================\n",
"Model type: Sequential\n",
"______________________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"==============================================================================\n",
"input_1(Input) [None, 8] 0 \n",
"______________________________________________________________________________\n",
"dense_2(Dense) [None, 80] 720 \n",
"______________________________________________________________________________\n",
"dense_3(Dense) [None, 80] 6480 \n",
"______________________________________________________________________________\n",
"dense_4(Dense) [None, 1] 81 \n",
"______________________________________________________________________________\n",
"==============================================================================\n",
"Total trainable params: 7281\n",
"Total frozen params: 0\n",
"Total params: 7281\n",
"______________________________________________________________________________\n"
]
}
],
"source": [
"val inputNeurons = train.x[0].size.toLong()\n",
"\n",
"val model2 = Sequential.of(\n",
" Input(\n",
" inputNeurons\n",
" ),\n",
" Dense(\n",
" outputSize = (inputNeurons * 10).toInt(),\n",
" activation = Activations.Tanh,\n",
" kernelInitializer = HeNormal(),\n",
" biasInitializer = HeNormal()\n",
" ),\n",
" Dense(\n",
" outputSize = (inputNeurons * 10).toInt(),\n",
" activation = Activations.Tanh,\n",
" kernelInitializer = HeNormal(),\n",
" biasInitializer = HeNormal()\n",
" ),\n",
" Dense(\n",
" outputSize = 1,\n",
" activation = Activations.Linear,\n",
" kernelInitializer = HeNormal(),\n",
" biasInitializer = HeNormal()\n",
" )\n",
")\n",
"model2.compile(optimizer = Adam(), loss = Losses.MSE, metric = Metrics.MAE)\n",
"model2.printSummary()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:50.819499613Z",
"start_time": "2023-12-05T11:16:42.942896113Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
" | epochIndex | lossValue | metricValues | valLossValue | valMetricValues |
|---|---|---|---|---|
| 1996 | 0.334773 | [0.45107388496398926] | NaN | [NaN] |
| 1997 | 0.334737 | [0.45104312896728516] | NaN | [NaN] |
| 1998 | 0.334700 | [0.45101237297058105] | NaN | [NaN] |
| 1999 | 0.334663 | [0.4509815275669098] | NaN | [NaN] |
| 2000 | 0.334626 | [0.45095062255859375] | NaN | [NaN] |
\n",
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":5,\"ncol\":5,\"columns\":[\"epochIndex\",\"lossValue\",\"metricValues\",\"valLossValue\",\"valMetricValues\"],\"kotlin_dataframe\":[{\"epochIndex\":1996,\"lossValue\":0.3347732126712799,\"metricValues\":[0.45107388496398926],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1997,\"lossValue\":0.33473655581474304,\"metricValues\":[0.45104312896728516],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1998,\"lossValue\":0.3346997797489166,\"metricValues\":[0.45101237297058105],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":1999,\"lossValue\":0.33466312289237976,\"metricValues\":[0.4509815275669098],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]},{\"epochIndex\":2000,\"lossValue\":0.3346264064311981,\"metricValues\":[0.45095062255859375],\"valLossValue\":\"NaN\",\"valMetricValues\":[NaN]}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 22
}
],
"source": [
"val trainHist = model2.fit(train, batchSize = 500, epochs = 2000)\n",
"trainHist.epochHistory.toDataFrame().tail()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:50.980777499Z",
"start_time": "2023-12-05T11:16:50.812978010Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"6.6911993"
]
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 23
}
],
"source": [
"model2.predictSoftly(testXDL[9])[0]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:51.081838805Z",
"start_time": "2023-12-05T11:16:50.895468359Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"7.0"
]
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 24
}
],
"source": [
"testYDL[9]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also compare the predicted values with the ground truth to see how closely they agree"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:51.523573225Z",
"start_time": "2023-12-05T11:16:50.955044609Z"
}
},
"outputs": [],
"source": [
"// Round each continuous prediction to the nearest integer quality score.\n",
"val predicted = testXDL.map { features ->\n",
"    round(model2.predictSoftly(features)[0]).toInt()\n",
"}.toColumn(\"predicted\")\n",
"\n",
"// Labels are stored as floating-point values; convert them to integers for comparison.\n",
"val ground_truth = testYDL.map { it.toInt() }.toColumn(\"ground_truth\")\n",
"\n",
"val predDf = dataFrameOf(predicted, ground_truth)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:51.721871863Z",
"start_time": "2023-12-05T11:16:51.510838483Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
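"<table>\n",
"<thead><tr><th>predicted</th><th>ground_truth</th></tr></thead>\n",
"<tbody>\n",
"<tr><td>5</td><td>5</td></tr>\n",
"<tr><td>5</td><td>5</td></tr>\n",
"<tr><td>6</td><td>5</td></tr>\n",
"<tr><td>5</td><td>5</td></tr>\n",
"<tr><td>6</td><td>6</td></tr>\n",
"</tbody>\n",
"</table>\n",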
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":5,\"ncol\":2,\"columns\":[\"predicted\",\"ground_truth\"],\"kotlin_dataframe\":[{\"predicted\":5,\"ground_truth\":5},{\"predicted\":5,\"ground_truth\":5},{\"predicted\":6,\"ground_truth\":5},{\"predicted\":5,\"ground_truth\":5},{\"predicted\":6,\"ground_truth\":6}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 26
}
],
"source": [
"predDf.head()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:52.792811636Z",
"start_time": "2023-12-05T11:16:51.613822929Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
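"<table>\n",
"<thead><tr><th>ground_truth</th><th>5</th><th>6</th><th>4</th><th>7</th></tr></thead>\n",
"<tbody>\n",
"<tr><td>3</td><td>2</td><td>0</td><td>1</td><td>0</td></tr>\n",
"<tr><td>4</td><td>8</td><td>2</td><td>1</td><td>0</td></tr>\n",
"<tr><td>5</td><td>105</td><td>42</td><td>1</td><td>0</td></tr>\n",
"<tr><td>6</td><td>34</td><td>78</td><td>0</td><td>5</td></tr>\n",
"<tr><td>7</td><td>0</td><td>20</td><td>0</td><td>16</td></tr>\n",
"<tr><td>8</td><td>0</td><td>1</td><td>0</td><td>3</td></tr>\n",
"</tbody>\n",
"</table>\n",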
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":6,\"ncol\":5,\"columns\":[\"ground_truth\",\"5\",\"6\",\"4\",\"7\"],\"kotlin_dataframe\":[{\"ground_truth\":3,\"5\":2,\"6\":0,\"4\":1,\"7\":0},{\"ground_truth\":4,\"5\":8,\"6\":2,\"4\":1,\"7\":0},{\"ground_truth\":5,\"5\":105,\"6\":42,\"4\":1,\"7\":0},{\"ground_truth\":6,\"5\":34,\"6\":78,\"4\":0,\"7\":5},{\"ground_truth\":7,\"5\":0,\"6\":20,\"4\":0,\"7\":16},{\"ground_truth\":8,\"5\":0,\"6\":1,\"4\":0,\"7\":3}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 27
}
],
"source": [
"// Cross-tabulate: one row per ground-truth quality, one column per predicted quality.\n",
"val ctab = predDf\n",
"    .groupBy { ground_truth }.pivotCounts(inward = false) { predicted }\n",
"    .sortBy { ground_truth }\n",
"\n",
"// Shade every count cell: the closer the predicted quality is to the true one, the brighter the green.\n",
"ctab.format { drop(1) }.perRowCol { row, col ->\n",
"    val y = col.name().toInt()\n",
"    val x = row.ground_truth\n",
"    val k = 1.0 - abs(x - y) / 10.0\n",
"    background(RGBColor(50, (50 + k * 200).toInt().toShort(), 50))\n",
"}"
]
},
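{
"cell_type": "markdown",
"metadata": {},
"source": [
"The table above counts, for each ground-truth quality, how often every quality value was predicted; the cells where both agree are the exact matches. As a minimal sketch in plain Kotlin, the same comparison can be condensed into two numbers. It assumes the predictions and labels are available as `List<Int>`; `exactAccuracy` and `meanAbsError` are hypothetical helper names, not part of DataFrame or KotlinDL, and the sample values are the five rows shown by `predDf.head()`.\n",
"\n",
"```kotlin\n",
"import kotlin.math.abs\n",
"\n",
"// Hypothetical helper: share of predictions that equal the label exactly.\n",
"fun exactAccuracy(predicted: List<Int>, truth: List<Int>): Double {\n",
"    require(predicted.size == truth.size) { \"inputs must have equal length\" }\n",
"    return predicted.zip(truth).count { (p, t) -> p == t }.toDouble() / truth.size\n",
"}\n",
"\n",
"// Hypothetical helper: mean absolute difference between prediction and label.\n",
"fun meanAbsError(predicted: List<Int>, truth: List<Int>): Double {\n",
"    require(predicted.size == truth.size) { \"inputs must have equal length\" }\n",
"    return predicted.zip(truth).map { (p, t) -> abs(p - t) }.average()\n",
"}\n",
"\n",
"// The five rows shown by predDf.head().\n",
"val samplePredicted = listOf(5, 5, 6, 5, 6)\n",
"val sampleTruth = listOf(5, 5, 5, 5, 6)\n",
"\n",
"println(exactAccuracy(samplePredicted, sampleTruth)) // 0.8\n",
"println(meanAbsError(samplePredicted, sampleTruth))  // 0.2\n",
"```\n",
"\n",
"Exact-match accuracy treats a miss by one quality point the same as a miss by three, so the mean absolute error is a useful companion for an ordinal target like this one."
]
},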
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:53.265516155Z",
"start_time": "2023-12-05T11:16:52.633490769Z"
}
},
"outputs": [],
"source": [
"// Despite its name, avg_dev stores each row's absolute deviation |predicted - ground_truth|.\n",
"val predDf2 = predDf.add(\"avg_dev\") { abs(predicted - ground_truth) }"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:53.466031186Z",
"start_time": "2023-12-05T11:16:53.040384705Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
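"<table>\n",
"<thead><tr><th>name</th><th>type</th><th>count</th><th>unique</th><th>nulls</th><th>top</th><th>freq</th><th>mean</th><th>std</th><th>min</th><th>median</th><th>max</th></tr></thead>\n",
"<tbody>\n",
"<tr><td>avg_dev</td><td>Int</td><td>319</td><td>3</td><td>0</td><td>0</td><td>200</td><td>0.388715</td><td>0.519432</td><td>0</td><td>0</td><td>2</td></tr>\n",
"</tbody>\n",
"</table>\n",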
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":1,\"ncol\":12,\"columns\":[\"name\",\"type\",\"count\",\"unique\",\"nulls\",\"top\",\"freq\",\"mean\",\"std\",\"min\",\"median\",\"max\"],\"kotlin_dataframe\":[{\"name\":\"avg_dev\",\"type\":\"Int\",\"count\":319,\"unique\":3,\"nulls\":0,\"top\":0,\"freq\":200,\"mean\":0.3887147335423197,\"std\":0.5194317560418836,\"min\":0,\"median\":0,\"max\":2}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 29
}
],
"source": [
"predDf2.avg_dev.cast<Int>().describe()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:53.495338730Z",
"start_time": "2023-12-05T11:16:53.198655138Z"
}
},
"outputs": [
{
"data": {
"text/html": [
" \n",
" \n",
" \n",
" \n",
" \n",
" \n",
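"<table>\n",
"<thead><tr><th>predicted</th><th>ground_truth</th><th>avg_dev</th></tr></thead>\n",
"<tbody>\n",
"<tr><td>6</td><td>5</td><td>1</td></tr>\n",
"</tbody>\n",
"</table>\n",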
" \n",
" \n",
" "
],
"application/kotlindataframe+json": "{\"nrow\":1,\"ncol\":3,\"columns\":[\"predicted\",\"ground_truth\",\"avg_dev\"],\"kotlin_dataframe\":[{\"predicted\":6,\"ground_truth\":5,\"avg_dev\":1}]}"
},
"metadata": {},
"output_type": "execute_result",
"execution_count": 30
}
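],
"source": [
"// Row at the 70th percentile of absolute deviation (index truncated toward zero).\n",
"predDf2.sortBy { avg_dev }[(0.7 * (predDf2.rowsCount() - 1)).toInt()]"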
]
},
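{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above picks a percentile by sorting the rows and indexing at a fraction of the last position. A minimal sketch of that indexing rule in plain Kotlin follows; `percentile` is a hypothetical helper name, and the sample deviations are illustrative values, not the notebook's data.\n",
"\n",
"```kotlin\n",
"// Hypothetical helper: value found at fraction p of the sorted data.\n",
"// The index p * (n - 1) is truncated toward zero, as in the cell above.\n",
"fun percentile(values: List<Int>, p: Double): Int {\n",
"    require(values.isNotEmpty()) { \"values must not be empty\" }\n",
"    require(p in 0.0..1.0) { \"p must be in [0, 1]\" }\n",
"    val sorted = values.sorted()\n",
"    return sorted[(p * (sorted.size - 1)).toInt()]\n",
"}\n",
"\n",
"// Illustrative absolute deviations, not the notebook's real data.\n",
"val sampleDeviations = listOf(0, 0, 1, 0, 0, 2, 1, 0, 0, 1)\n",
"\n",
"println(percentile(sampleDeviations, 0.5)) // 0\n",
"println(percentile(sampleDeviations, 0.7)) // 1\n",
"println(percentile(sampleDeviations, 1.0)) // 2\n",
"```\n",
"\n",
"Truncating the index never interpolates between neighbours, which suits integer deviations; other percentile definitions may give slightly different values on small samples."
]
},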
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"ExecuteTime": {
"end_time": "2023-12-05T11:16:53.539689733Z",
"start_time": "2023-12-05T11:16:53.340929615Z"
}
},
"outputs": [],
"source": [
"model2.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Kotlin",
"language": "kotlin",
"name": "kotlin"
},
"language_info": {
"codemirror_mode": "text/x-kotlin",
"file_extension": ".kt",
"mimetype": "text/x-kotlin",
"name": "kotlin",
"nbconvert_exporter": "",
"pygments_lexer": "kotlin",
"version": "1.8.0-dev-707"
},
"ktnbPluginMetadata": {
"projectLibraries": []
}
},
"nbformat": 4,
"nbformat_minor": 1
}