{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%use deeplearning4j\n",
"%use krangl@2fcf74dfbbe382f1803d1ab9e4739439e1f5671b\n",
"%use lets-plot"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"val iris_data = \"sepal-length,sepal-width,petal-length,petal-width,species\\n5.1,3.5,1.4,0.2,Iris-setosa\\n4.9,3.0,1.4,0.2,Iris-setosa\\n4.7,3.2,1.3,0.2,Iris-setosa\\n4.6,3.1,1.5,0.2,Iris-setosa\\n5.0,3.6,1.4,0.2,Iris-setosa\\n5.4,3.9,1.7,0.4,Iris-setosa\\n4.6,3.4,1.4,0.3,Iris-setosa\\n5.0,3.4,1.5,0.2,Iris-setosa\\n4.4,2.9,1.4,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n5.4,3.7,1.5,0.2,Iris-setosa\\n4.8,3.4,1.6,0.2,Iris-setosa\\n4.8,3.0,1.4,0.1,Iris-setosa\\n4.3,3.0,1.1,0.1,Iris-setosa\\n5.8,4.0,1.2,0.2,Iris-setosa\\n5.7,4.4,1.5,0.4,Iris-setosa\\n5.4,3.9,1.3,0.4,Iris-setosa\\n5.1,3.5,1.4,0.3,Iris-setosa\\n5.7,3.8,1.7,0.3,Iris-setosa\\n5.1,3.8,1.5,0.3,Iris-setosa\\n5.4,3.4,1.7,0.2,Iris-setosa\\n5.1,3.7,1.5,0.4,Iris-setosa\\n4.6,3.6,1.0,0.2,Iris-setosa\\n5.1,3.3,1.7,0.5,Iris-setosa\\n4.8,3.4,1.9,0.2,Iris-setosa\\n5.0,3.0,1.6,0.2,Iris-setosa\\n5.0,3.4,1.6,0.4,Iris-setosa\\n5.2,3.5,1.5,0.2,Iris-setosa\\n5.2,3.4,1.4,0.2,Iris-setosa\\n4.7,3.2,1.6,0.2,Iris-setosa\\n4.8,3.1,1.6,0.2,Iris-setosa\\n5.4,3.4,1.5,0.4,Iris-setosa\\n5.2,4.1,1.5,0.1,Iris-setosa\\n5.5,4.2,1.4,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n5.0,3.2,1.2,0.2,Iris-setosa\\n5.5,3.5,1.3,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n4.4,3.0,1.3,0.2,Iris-setosa\\n5.1,3.4,1.5,0.2,Iris-setosa\\n5.0,3.5,1.3,0.3,Iris-setosa\\n4.5,2.3,1.3,0.3,Iris-setosa\\n4.4,3.2,1.3,0.2,Iris-setosa\\n5.0,3.5,1.6,0.6,Iris-setosa\\n5.1,3.8,1.9,0.4,Iris-setosa\\n4.8,3.0,1.4,0.3,Iris-setosa\\n5.1,3.8,1.6,0.2,Iris-setosa\\n4.6,3.2,1.4,0.2,Iris-setosa\\n5.3,3.7,1.5,0.2,Iris-setosa\\n5.0,3.3,1.4,0.2,Iris-setosa\\n7.0,3.2,4.7,1.4,Iris-versicolor\\n6.4,3.2,4.5,1.5,Iris-versicolor\\n6.9,3.1,4.9,1.5,Iris-versicolor\\n5.5,2.3,4.0,1.3,Iris-versicolor\\n6.5,2.8,4.6,1.5,Iris-versicolor\\n5.7,2.8,4.5,1.3,Iris-versicolor\\n6.3,3.3,4.7,1.6,Iris-versicolor\\n4.9,2.4,3.3,1.0,Iris-versicolor\\n6.6,2.9,4.6,1.3,Iris-versicolor\\n5.2,2.7,3.9,1.4,Iris-versicolor\\n5.0,2.0,3.5,1.0,Iris-versicolor\\n5.9,3.0,4.2,1.5,Iris-versicolor\\n6.0,2.2,4.0,1
.0,Iris-versicolor\\n6.1,2.9,4.7,1.4,Iris-versicolor\\n5.6,2.9,3.6,1.3,Iris-versicolor\\n6.7,3.1,4.4,1.4,Iris-versicolor\\n5.6,3.0,4.5,1.5,Iris-versicolor\\n5.8,2.7,4.1,1.0,Iris-versicolor\\n6.2,2.2,4.5,1.5,Iris-versicolor\\n5.6,2.5,3.9,1.1,Iris-versicolor\\n5.9,3.2,4.8,1.8,Iris-versicolor\\n6.1,2.8,4.0,1.3,Iris-versicolor\\n6.3,2.5,4.9,1.5,Iris-versicolor\\n6.1,2.8,4.7,1.2,Iris-versicolor\\n6.4,2.9,4.3,1.3,Iris-versicolor\\n6.6,3.0,4.4,1.4,Iris-versicolor\\n6.8,2.8,4.8,1.4,Iris-versicolor\\n6.7,3.0,5.0,1.7,Iris-versicolor\\n6.0,2.9,4.5,1.5,Iris-versicolor\\n5.7,2.6,3.5,1.0,Iris-versicolor\\n5.5,2.4,3.8,1.1,Iris-versicolor\\n5.5,2.4,3.7,1.0,Iris-versicolor\\n5.8,2.7,3.9,1.2,Iris-versicolor\\n6.0,2.7,5.1,1.6,Iris-versicolor\\n5.4,3.0,4.5,1.5,Iris-versicolor\\n6.0,3.4,4.5,1.6,Iris-versicolor\\n6.7,3.1,4.7,1.5,Iris-versicolor\\n6.3,2.3,4.4,1.3,Iris-versicolor\\n5.6,3.0,4.1,1.3,Iris-versicolor\\n5.5,2.5,4.0,1.3,Iris-versicolor\\n5.5,2.6,4.4,1.2,Iris-versicolor\\n6.1,3.0,4.6,1.4,Iris-versicolor\\n5.8,2.6,4.0,1.2,Iris-versicolor\\n5.0,2.3,3.3,1.0,Iris-versicolor\\n5.6,2.7,4.2,1.3,Iris-versicolor\\n5.7,3.0,4.2,1.2,Iris-versicolor\\n5.7,2.9,4.2,1.3,Iris-versicolor\\n6.2,2.9,4.3,1.3,Iris-versicolor\\n5.1,2.5,3.0,1.1,Iris-versicolor\\n5.7,2.8,4.1,1.3,Iris-versicolor\\n6.3,3.3,6.0,2.5,Iris-virginica\\n5.8,2.7,5.1,1.9,Iris-virginica\\n7.1,3.0,5.9,2.1,Iris-virginica\\n6.3,2.9,5.6,1.8,Iris-virginica\\n6.5,3.0,5.8,2.2,Iris-virginica\\n7.6,3.0,6.6,2.1,Iris-virginica\\n4.9,2.5,4.5,1.7,Iris-virginica\\n7.3,2.9,6.3,1.8,Iris-virginica\\n6.7,2.5,5.8,1.8,Iris-virginica\\n7.2,3.6,6.1,2.5,Iris-virginica\\n6.5,3.2,5.1,2.0,Iris-virginica\\n6.4,2.7,5.3,1.9,Iris-virginica\\n6.8,3.0,5.5,2.1,Iris-virginica\\n5.7,2.5,5.0,2.0,Iris-virginica\\n5.8,2.8,5.1,2.4,Iris-virginica\\n6.4,3.2,5.3,2.3,Iris-virginica\\n6.5,3.0,5.5,1.8,Iris-virginica\\n7.7,3.8,6.7,2.2,Iris-virginica\\n7.7,2.6,6.9,2.3,Iris-virginica\\n6.0,2.2,5.0,1.5,Iris-virginica\\n6.9,3.2,5.7,2.3,Iris-virginica\\n5.6,2.8,4.9,2.0,Iris-virgini
ca\\n7.7,2.8,6.7,2.0,Iris-virginica\\n6.3,2.7,4.9,1.8,Iris-virginica\\n6.7,3.3,5.7,2.1,Iris-virginica\\n7.2,3.2,6.0,1.8,Iris-virginica\\n6.2,2.8,4.8,1.8,Iris-virginica\\n6.1,3.0,4.9,1.8,Iris-virginica\\n6.4,2.8,5.6,2.1,Iris-virginica\\n7.2,3.0,5.8,1.6,Iris-virginica\\n7.4,2.8,6.1,1.9,Iris-virginica\\n7.9,3.8,6.4,2.0,Iris-virginica\\n6.4,2.8,5.6,2.2,Iris-virginica\\n6.3,2.8,5.1,1.5,Iris-virginica\\n6.1,2.6,5.6,1.4,Iris-virginica\\n7.7,3.0,6.1,2.3,Iris-virginica\\n6.3,3.4,5.6,2.4,Iris-virginica\\n6.4,3.1,5.5,1.8,Iris-virginica\\n6.0,3.0,4.8,1.8,Iris-virginica\\n6.9,3.1,5.4,2.1,Iris-virginica\\n6.7,3.1,5.6,2.4,Iris-virginica\\n6.9,3.1,5.1,2.3,Iris-virginica\\n5.8,2.7,5.1,1.9,Iris-virginica\\n6.8,3.2,5.9,2.3,Iris-virginica\\n6.7,3.3,5.7,2.5,Iris-virginica\\n6.7,3.0,5.2,2.3,Iris-virginica\\n6.3,2.5,5.0,1.9,Iris-virginica\\n6.5,3.0,5.2,2.0,Iris-virginica\\n6.2,3.4,5.4,2.3,Iris-virginica\\n5.9,3.0,5.1,1.8,Iris-virginica\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"| sepal-length | sepal-width | petal-length | petal-width | species |
|---|
| 5.1 | 3.3 | 1.7 | 0.5 | Iris-setosa |
| 5.8 | 2.7 | 5.1 | 1.9 | Iris-virginica |
| 5.6 | 2.8 | 4.9 | 2.0 | Iris-virginica |
| 4.8 | 3.0 | 1.4 | 0.3 | Iris-setosa |
| 7.7 | 2.6 | 6.9 | 2.3 | Iris-virginica |
"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import java.util.*\n",
"import java.io.StringReader\n",
"\n",
"// Parse the embedded CSV text into a krangl data frame, then shuffle the rows.\n",
"// NOTE(review): shuffle() is called without a seed here, so the row order (and the\n",
"// saved outputs of the cells below) presumably differ from run to run - confirm.\n",
"val irisRaw = DataFrame.readDelim(StringReader(iris_data))\n",
"val iris = irisRaw.shuffle()\n",
"iris.head()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" \n",
" "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"// Scatter plot of sepal width against sepal length, coloured by species.\n",
"// The keys of the data map are what the plot shows as axis and legend titles,\n",
"// so they carry the real column names instead of generic x / y / color.\n",
"val points = geomPoint(\n",
"    data = mapOf(\n",
"        \"sepal-length\" to iris[\"sepal-length\"].values().toList(),\n",
"        \"sepal-width\" to iris[\"sepal-width\"].values().toList(),\n",
"        \"species\" to iris[\"species\"].values().toList()\n",
"    ),\n",
"    alpha = 1.0\n",
") {\n",
"    x = \"sepal-length\"\n",
"    y = \"sepal-width\"\n",
"    color = \"species\"\n",
"}\n",
"\n",
"ggplot() + points"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"| sepal-length | sepal-width | petal-length | petal-width |
|---|
| 5.1 | 3.3 | 1.7 | 0.5 |
| 5.8 | 2.7 | 5.1 | 1.9 |
| 5.6 | 2.8 | 4.9 | 2.0 |
| 4.8 | 3.0 | 1.4 | 0.3 |
| 7.7 | 2.6 | 6.9 | 2.3 |
"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"// Feature columns only: drop the species label before building the network input\n",
"val irisWithoutLabel = iris.remove(\"species\")\n",
"irisWithoutLabel.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[5.1, 3.3, 1.7, 0.5]\n",
"[5.8, 2.7, 5.1, 1.9]\n",
"[5.6, 2.8, 4.9, 2.0]\n",
"[4.8, 3.0, 1.4, 0.3]\n",
"[7.7, 2.6, 6.9, 2.3]\n",
"[5.6, 2.9, 3.6, 1.3]\n",
"[6.9, 3.1, 5.4, 2.1]\n",
"[5.9, 3.0, 4.2, 1.5]\n",
"[4.9, 3.1, 1.5, 0.1]\n",
"[6.8, 2.8, 4.8, 1.4]\n",
"[6.0, 2.2, 5.0, 1.5]\n",
"[6.0, 3.4, 4.5, 1.6]\n",
"[5.4, 3.9, 1.3, 0.4]\n",
"[5.7, 3.0, 4.2, 1.2]\n",
"[7.2, 3.0, 5.8, 1.6]\n",
"[6.0, 2.7, 5.1, 1.6]\n",
"[6.4, 3.2, 5.3, 2.3]\n",
"[5.7, 2.8, 4.1, 1.3]\n",
"[5.7, 2.5, 5.0, 2.0]\n",
"[6.2, 2.8, 4.8, 1.8]\n",
"[5.0, 3.5, 1.3, 0.3]\n",
"[5.7, 4.4, 1.5, 0.4]\n",
"[6.3, 2.5, 5.0, 1.9]\n",
"[7.7, 3.0, 6.1, 2.3]\n",
"[4.8, 3.0, 1.4, 0.1]\n",
"[5.8, 2.7, 3.9, 1.2]\n",
"[5.1, 2.5, 3.0, 1.1]\n",
"[6.4, 2.8, 5.6, 2.1]\n",
"[5.3, 3.7, 1.5, 0.2]\n",
"[4.6, 3.4, 1.4, 0.3]\n",
"[7.6, 3.0, 6.6, 2.1]\n",
"[4.5, 2.3, 1.3, 0.3]\n",
"[5.6, 2.7, 4.2, 1.3]\n",
"[5.7, 2.6, 3.5, 1.0]\n",
"[6.7, 3.0, 5.0, 1.7]\n",
"[6.5, 3.0, 5.8, 2.2]\n",
"[5.0, 2.3, 3.3, 1.0]\n",
"[6.1, 3.0, 4.9, 1.8]\n",
"[6.5, 3.0, 5.2, 2.0]\n",
"[6.2, 3.4, 5.4, 2.3]\n",
"[4.4, 2.9, 1.4, 0.2]\n",
"[5.2, 3.5, 1.5, 0.2]\n",
"[7.2, 3.6, 6.1, 2.5]\n",
"[5.5, 4.2, 1.4, 0.2]\n",
"[6.4, 2.9, 4.3, 1.3]\n",
"[4.9, 3.0, 1.4, 0.2]\n",
"[6.3, 2.5, 4.9, 1.5]\n",
"[5.5, 2.4, 3.7, 1.0]\n",
"[4.7, 3.2, 1.6, 0.2]\n",
"[6.3, 2.7, 4.9, 1.8]\n",
"[6.3, 2.3, 4.4, 1.3]\n",
"[7.1, 3.0, 5.9, 2.1]\n",
"[5.0, 3.5, 1.6, 0.6]\n",
"[6.8, 3.0, 5.5, 2.1]\n",
"[4.8, 3.4, 1.9, 0.2]\n",
"[6.7, 3.1, 5.6, 2.4]\n",
"[5.8, 2.6, 4.0, 1.2]\n",
"[5.0, 3.2, 1.2, 0.2]\n",
"[6.7, 3.3, 5.7, 2.5]\n",
"[5.1, 3.5, 1.4, 0.2]\n",
"[6.4, 2.7, 5.3, 1.9]\n",
"[7.0, 3.2, 4.7, 1.4]\n",
"[6.1, 2.8, 4.7, 1.2]\n",
"[5.4, 3.4, 1.7, 0.2]\n",
"[4.9, 2.4, 3.3, 1.0]\n",
"[5.2, 3.4, 1.4, 0.2]\n",
"[6.5, 2.8, 4.6, 1.5]\n",
"[5.4, 3.0, 4.5, 1.5]\n",
"[7.3, 2.9, 6.3, 1.8]\n",
"[5.2, 2.7, 3.9, 1.4]\n",
"[5.4, 3.9, 1.7, 0.4]\n",
"[6.2, 2.2, 4.5, 1.5]\n",
"[5.1, 3.5, 1.4, 0.3]\n",
"[4.8, 3.4, 1.6, 0.2]\n",
"[7.7, 3.8, 6.7, 2.2]\n",
"[5.6, 3.0, 4.5, 1.5]\n",
"[6.3, 3.4, 5.6, 2.4]\n",
"[5.8, 2.8, 5.1, 2.4]\n",
"[5.5, 2.3, 4.0, 1.3]\n",
"[4.9, 2.5, 4.5, 1.7]\n",
"[6.0, 2.2, 4.0, 1.0]\n",
"[5.0, 2.0, 3.5, 1.0]\n",
"[5.9, 3.2, 4.8, 1.8]\n",
"[5.4, 3.4, 1.5, 0.4]\n",
"[6.9, 3.1, 4.9, 1.5]\n",
"[4.9, 3.1, 1.5, 0.1]\n",
"[5.2, 4.1, 1.5, 0.1]\n",
"[5.1, 3.8, 1.5, 0.3]\n",
"[5.1, 3.8, 1.6, 0.2]\n",
"[6.7, 3.1, 4.7, 1.5]\n",
"[5.9, 3.0, 5.1, 1.8]\n",
"[5.8, 4.0, 1.2, 0.2]\n",
"[4.3, 3.0, 1.1, 0.1]\n",
"[6.7, 2.5, 5.8, 1.8]\n",
"[6.3, 3.3, 6.0, 2.5]\n",
"[5.6, 2.5, 3.9, 1.1]\n",
"[4.4, 3.2, 1.3, 0.2]\n",
"[4.6, 3.1, 1.5, 0.2]\n",
"[5.5, 2.6, 4.4, 1.2]\n",
"[6.9, 3.1, 5.1, 2.3]\n",
"[6.0, 2.9, 4.5, 1.5]\n",
"[7.2, 3.2, 6.0, 1.8]\n",
"[6.1, 2.8, 4.0, 1.3]\n",
"[5.7, 2.9, 4.2, 1.3]\n",
"[5.8, 2.7, 4.1, 1.0]\n",
"[4.8, 3.1, 1.6, 0.2]\n",
"[6.9, 3.2, 5.7, 2.3]\n",
"[5.5, 2.4, 3.8, 1.1]\n",
"[5.0, 3.4, 1.5, 0.2]\n",
"[4.6, 3.2, 1.4, 0.2]\n",
"[4.9, 3.1, 1.5, 0.1]\n",
"[6.0, 3.0, 4.8, 1.8]\n",
"[6.3, 2.9, 5.6, 1.8]\n",
"[6.6, 3.0, 4.4, 1.4]\n",
"[7.9, 3.8, 6.4, 2.0]\n",
"[5.6, 3.0, 4.1, 1.3]\n",
"[5.7, 3.8, 1.7, 0.3]\n",
"[5.0, 3.4, 1.6, 0.4]\n",
"[5.7, 2.8, 4.5, 1.3]\n",
"[6.7, 3.3, 5.7, 2.1]\n",
"[6.7, 3.1, 4.4, 1.4]\n",
"[6.7, 3.0, 5.2, 2.3]\n",
"[5.5, 2.5, 4.0, 1.3]\n",
"[5.0, 3.3, 1.4, 0.2]\n",
"[4.4, 3.0, 1.3, 0.2]\n",
"[6.6, 2.9, 4.6, 1.3]\n",
"[7.4, 2.8, 6.1, 1.9]\n",
"[6.5, 3.0, 5.5, 1.8]\n",
"[6.3, 2.8, 5.1, 1.5]\n",
"[6.4, 3.2, 4.5, 1.5]\n",
"[6.1, 2.9, 4.7, 1.4]\n",
"[4.6, 3.6, 1.0, 0.2]\n",
"[5.4, 3.7, 1.5, 0.2]\n",
"[5.5, 3.5, 1.3, 0.2]\n",
"[6.1, 3.0, 4.6, 1.4]\n",
"[5.8, 2.7, 5.1, 1.9]\n",
"[6.8, 3.2, 5.9, 2.3]\n",
"[6.4, 3.1, 5.5, 1.8]\n",
"[7.7, 2.8, 6.7, 2.0]\n",
"[5.0, 3.0, 1.6, 0.2]\n",
"[6.2, 2.9, 4.3, 1.3]\n",
"[5.1, 3.4, 1.5, 0.2]\n",
"[6.5, 3.2, 5.1, 2.0]\n",
"[5.1, 3.7, 1.5, 0.4]\n",
"[6.4, 2.8, 5.6, 2.2]\n",
"[5.1, 3.8, 1.9, 0.4]\n",
"[5.0, 3.6, 1.4, 0.2]\n",
"[4.7, 3.2, 1.3, 0.2]\n",
"[6.3, 3.3, 4.7, 1.6]\n",
"[6.1, 2.6, 5.6, 1.4]]\n"
]
}
],
"source": [
"// Convert the feature data frame into a row-major matrix of doubles:\n",
"// one row per sample, one column per feature (150 x 4 for the full iris data).\n",
"// The dimensions are read from the data frame so the cell keeps working if the data changes.\n",
"val row = irisWithoutLabel.nrow\n",
"val col = irisWithoutLabel.ncol\n",
"\n",
"// The data frame is indexed column-first, the matrix row-first, hence [c][r]\n",
"val irisMatrix = Array(row) { r ->\n",
"    DoubleArray(col) { c -> irisWithoutLabel[c][r] as Double }\n",
"}\n",
"println(Arrays.deepToString(irisMatrix).replace(\"], \", \"]\\n\"))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 0.0, 1.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[1.0, 0.0, 0.0]\n",
"[0.0, 1.0, 0.0]\n",
"[0.0, 0.0, 1.0]]\n"
]
}
],
"source": [
"// One-hot encode the species label: one row per sample, one column per class\n",
"// (column 0 = setosa, 1 = versicolor, 2 = virginica).\n",
"val irisLabel = iris.select(\"species\")[0]\n",
"\n",
"val rowLabel = iris.nrow\n",
"val colLabel = 3\n",
"\n",
"// DoubleArray starts out as all zeros, so only the matching class column is set\n",
"val twodimLabel = Array(rowLabel) { DoubleArray(colLabel) }\n",
"for (r in 0 until rowLabel) {\n",
"    when (val species = irisLabel[r]) {\n",
"        \"Iris-setosa\" -> twodimLabel[r][0] = 1.0\n",
"        \"Iris-versicolor\" -> twodimLabel[r][1] = 1.0\n",
"        \"Iris-virginica\" -> twodimLabel[r][2] = 1.0\n",
"        // Fail loudly rather than train on a silent all-zero label row\n",
"        else -> error(\"Unexpected species label at row $r: $species\")\n",
"    }\n",
"}\n",
"println(Arrays.deepToString(twodimLabel).replace(\"], \", \"]\\n\"))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"// Wrap the feature matrix and the one-hot label matrix as ND4J arrays for training\n",
"val dataIn: INDArray = Nd4j.create(irisMatrix)\n",
"val dataOut: INDArray = Nd4j.create(twodimLabel)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import org.nd4j.linalg.lossfunctions.LossFunctions\n",
"\n",
"// Seed shared by the network configuration and the data-set shuffle in the training cell\n",
"val seed: Long = 6\n",
"\n",
"// Network layout: 4 input features -> 3 tanh units -> 3 hidden units -> 3-way softmax output\n",
"val conf = NeuralNetConfiguration.Builder()\n",
"    .seed(seed) // fixed random seed for reproducible weight initialisation\n",
"    .updater(Nadam()) // Nadam optimiser (Adam with Nesterov momentum), default hyper-parameters\n",
"    .l2(1e-4) // L2 regularisation coefficient\n",
"    .list()\n",
"    .layer(DenseLayer.Builder()\n",
"        .nIn(4)\n",
"        .nOut(3)\n",
"        .activation(Activation.TANH)\n",
"        .weightInit(WeightInit.XAVIER)\n",
"        .build())\n",
"    // NOTE(review): no activation or weight init is set on this layer, so the\n",
"    // library defaults apply - confirm that this is intended.\n",
"    .layer(DenseLayer.Builder()\n",
"        .nIn(3)\n",
"        .nOut(3)\n",
"        .build())\n",
"    .layer(OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)\n",
"        .nIn(3)\n",
"        .nOut(3)\n",
"        .activation(Activation.SOFTMAX)\n",
"        .weightInit(WeightInit.XAVIER)\n",
"        .build())\n",
"    .build()\n",
"\n",
"val model = MultiLayerNetwork(conf)\n",
"// Initialise the parameters explicitly instead of relying on a later call to trigger it\n",
"model.init()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score \n",
"\n",
"========================Evaluation Metrics========================\n",
" # of classes: 3\n",
" Accuracy: 1,0000\n",
" Precision: 1,0000\n",
" Recall: 1,0000\n",
" F1 Score: 1,0000\n",
"Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)\n",
"\n",
"\n",
"=========================Confusion Matrix=========================\n",
" 0 1 2\n",
"-------\n",
" 5 0 0 | 0 = 0\n",
" 0 3 0 | 1 = 1\n",
" 0 0 7 | 2 = 2\n",
"\n",
"Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times\n",
"==================================================================\n"
]
}
],
"source": [
"import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization\n",
"import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize\n",
"\n",
"// Wrap the INDArrays in a DataSet and shuffle it with the fixed seed\n",
"val fullDataSet = DataSet(dataIn, dataOut)\n",
"fullDataSet.shuffle(seed)\n",
"\n",
"// Use 90% of the examples for training and hold out the remaining 10% for testing\n",
"val splitSet = fullDataSet.splitTestAndTrain(0.90)\n",
"val trainingData = splitSet.train\n",
"val testData = splitSet.test\n",
"\n",
"// Normalize the features with NormalizerStandardize (zero mean, unit variance).\n",
"// The statistics come from the training set only and are then applied to both sets,\n",
"// so no information from the test set leaks into training.\n",
"val normalizer: DataNormalization = NormalizerStandardize()\n",
"normalizer.fit(trainingData) // collect mean / stdev; does not modify the data\n",
"normalizer.transform(trainingData) // normalizes the training data in place\n",
"normalizer.transform(testData) // normalizes the test data in place, using the training statistics\n",
"\n",
"// Train the network: one full-batch fit() call per epoch\n",
"val nEpochs = 2000\n",
"model.setListeners(ScoreIterationListener(100))\n",
"repeat(nEpochs) {\n",
"    model.fit(trainingData)\n",
"}\n",
"\n",
"// Evaluate the network on the held-out test set\n",
"val eval = Evaluation()\n",
"val output: INDArray = model.output(testData.features)\n",
"eval.eval(testData.labels, output)\n",
"println(\"Score \" + eval.stats())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Kotlin",
"language": "kotlin",
"name": "kotlin"
},
"language_info": {
"codemirror_mode": "text/x-kotlin",
"file_extension": ".kt",
"mimetype": "text/x-kotlin",
"name": "kotlin",
"nbconvert_exporter": "",
"pygments_lexer": "kotlin",
"version": "1.5.20-dev-4184"
}
},
"nbformat": 4,
"nbformat_minor": 2
}