{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%use deeplearning4j\n", "%use krangl@2fcf74dfbbe382f1803d1ab9e4739439e1f5671b\n", "%use lets-plot" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "val iris_data = \"sepal-length,sepal-width,petal-length,petal-width,species\\n5.1,3.5,1.4,0.2,Iris-setosa\\n4.9,3.0,1.4,0.2,Iris-setosa\\n4.7,3.2,1.3,0.2,Iris-setosa\\n4.6,3.1,1.5,0.2,Iris-setosa\\n5.0,3.6,1.4,0.2,Iris-setosa\\n5.4,3.9,1.7,0.4,Iris-setosa\\n4.6,3.4,1.4,0.3,Iris-setosa\\n5.0,3.4,1.5,0.2,Iris-setosa\\n4.4,2.9,1.4,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n5.4,3.7,1.5,0.2,Iris-setosa\\n4.8,3.4,1.6,0.2,Iris-setosa\\n4.8,3.0,1.4,0.1,Iris-setosa\\n4.3,3.0,1.1,0.1,Iris-setosa\\n5.8,4.0,1.2,0.2,Iris-setosa\\n5.7,4.4,1.5,0.4,Iris-setosa\\n5.4,3.9,1.3,0.4,Iris-setosa\\n5.1,3.5,1.4,0.3,Iris-setosa\\n5.7,3.8,1.7,0.3,Iris-setosa\\n5.1,3.8,1.5,0.3,Iris-setosa\\n5.4,3.4,1.7,0.2,Iris-setosa\\n5.1,3.7,1.5,0.4,Iris-setosa\\n4.6,3.6,1.0,0.2,Iris-setosa\\n5.1,3.3,1.7,0.5,Iris-setosa\\n4.8,3.4,1.9,0.2,Iris-setosa\\n5.0,3.0,1.6,0.2,Iris-setosa\\n5.0,3.4,1.6,0.4,Iris-setosa\\n5.2,3.5,1.5,0.2,Iris-setosa\\n5.2,3.4,1.4,0.2,Iris-setosa\\n4.7,3.2,1.6,0.2,Iris-setosa\\n4.8,3.1,1.6,0.2,Iris-setosa\\n5.4,3.4,1.5,0.4,Iris-setosa\\n5.2,4.1,1.5,0.1,Iris-setosa\\n5.5,4.2,1.4,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n5.0,3.2,1.2,0.2,Iris-setosa\\n5.5,3.5,1.3,0.2,Iris-setosa\\n4.9,3.1,1.5,0.1,Iris-setosa\\n4.4,3.0,1.3,0.2,Iris-setosa\\n5.1,3.4,1.5,0.2,Iris-setosa\\n5.0,3.5,1.3,0.3,Iris-setosa\\n4.5,2.3,1.3,0.3,Iris-setosa\\n4.4,3.2,1.3,0.2,Iris-setosa\\n5.0,3.5,1.6,0.6,Iris-setosa\\n5.1,3.8,1.9,0.4,Iris-setosa\\n4.8,3.0,1.4,0.3,Iris-setosa\\n5.1,3.8,1.6,0.2,Iris-setosa\\n4.6,3.2,1.4,0.2,Iris-setosa\\n5.3,3.7,1.5,0.2,Iris-setosa\\n5.0,3.3,1.4,0.2,Iris-setosa\\n7.0,3.2,4.7,1.4,Iris-versicolor\\n6.4,3.2,4.5,1.5,Iris-versicolor\\n6.9,3.1,4.9,1.5,Iris-versicolor\\n5.5,2.3,4.0,1.3,Iris-versicolor\\n6.5,2.8,4.6,1.5
,Iris-versicolor\\n5.7,2.8,4.5,1.3,Iris-versicolor\\n6.3,3.3,4.7,1.6,Iris-versicolor\\n4.9,2.4,3.3,1.0,Iris-versicolor\\n6.6,2.9,4.6,1.3,Iris-versicolor\\n5.2,2.7,3.9,1.4,Iris-versicolor\\n5.0,2.0,3.5,1.0,Iris-versicolor\\n5.9,3.0,4.2,1.5,Iris-versicolor\\n6.0,2.2,4.0,1.0,Iris-versicolor\\n6.1,2.9,4.7,1.4,Iris-versicolor\\n5.6,2.9,3.6,1.3,Iris-versicolor\\n6.7,3.1,4.4,1.4,Iris-versicolor\\n5.6,3.0,4.5,1.5,Iris-versicolor\\n5.8,2.7,4.1,1.0,Iris-versicolor\\n6.2,2.2,4.5,1.5,Iris-versicolor\\n5.6,2.5,3.9,1.1,Iris-versicolor\\n5.9,3.2,4.8,1.8,Iris-versicolor\\n6.1,2.8,4.0,1.3,Iris-versicolor\\n6.3,2.5,4.9,1.5,Iris-versicolor\\n6.1,2.8,4.7,1.2,Iris-versicolor\\n6.4,2.9,4.3,1.3,Iris-versicolor\\n6.6,3.0,4.4,1.4,Iris-versicolor\\n6.8,2.8,4.8,1.4,Iris-versicolor\\n6.7,3.0,5.0,1.7,Iris-versicolor\\n6.0,2.9,4.5,1.5,Iris-versicolor\\n5.7,2.6,3.5,1.0,Iris-versicolor\\n5.5,2.4,3.8,1.1,Iris-versicolor\\n5.5,2.4,3.7,1.0,Iris-versicolor\\n5.8,2.7,3.9,1.2,Iris-versicolor\\n6.0,2.7,5.1,1.6,Iris-versicolor\\n5.4,3.0,4.5,1.5,Iris-versicolor\\n6.0,3.4,4.5,1.6,Iris-versicolor\\n6.7,3.1,4.7,1.5,Iris-versicolor\\n6.3,2.3,4.4,1.3,Iris-versicolor\\n5.6,3.0,4.1,1.3,Iris-versicolor\\n5.5,2.5,4.0,1.3,Iris-versicolor\\n5.5,2.6,4.4,1.2,Iris-versicolor\\n6.1,3.0,4.6,1.4,Iris-versicolor\\n5.8,2.6,4.0,1.2,Iris-versicolor\\n5.0,2.3,3.3,1.0,Iris-versicolor\\n5.6,2.7,4.2,1.3,Iris-versicolor\\n5.7,3.0,4.2,1.2,Iris-versicolor\\n5.7,2.9,4.2,1.3,Iris-versicolor\\n6.2,2.9,4.3,1.3,Iris-versicolor\\n5.1,2.5,3.0,1.1,Iris-versicolor\\n5.7,2.8,4.1,1.3,Iris-versicolor\\n6.3,3.3,6.0,2.5,Iris-virginica\\n5.8,2.7,5.1,1.9,Iris-virginica\\n7.1,3.0,5.9,2.1,Iris-virginica\\n6.3,2.9,5.6,1.8,Iris-virginica\\n6.5,3.0,5.8,2.2,Iris-virginica\\n7.6,3.0,6.6,2.1,Iris-virginica\\n4.9,2.5,4.5,1.7,Iris-virginica\\n7.3,2.9,6.3,1.8,Iris-virginica\\n6.7,2.5,5.8,1.8,Iris-virginica\\n7.2,3.6,6.1,2.5,Iris-virginica\\n6.5,3.2,5.1,2.0,Iris-virginica\\n6.4,2.7,5.3,1.9,Iris-virginica\\n6.8,3.0,5.5,2.1,Iris-virginica\\n5.7,2.5,5.0,2.0,Iris-v
irginica\\n5.8,2.8,5.1,2.4,Iris-virginica\\n6.4,3.2,5.3,2.3,Iris-virginica\\n6.5,3.0,5.5,1.8,Iris-virginica\\n7.7,3.8,6.7,2.2,Iris-virginica\\n7.7,2.6,6.9,2.3,Iris-virginica\\n6.0,2.2,5.0,1.5,Iris-virginica\\n6.9,3.2,5.7,2.3,Iris-virginica\\n5.6,2.8,4.9,2.0,Iris-virginica\\n7.7,2.8,6.7,2.0,Iris-virginica\\n6.3,2.7,4.9,1.8,Iris-virginica\\n6.7,3.3,5.7,2.1,Iris-virginica\\n7.2,3.2,6.0,1.8,Iris-virginica\\n6.2,2.8,4.8,1.8,Iris-virginica\\n6.1,3.0,4.9,1.8,Iris-virginica\\n6.4,2.8,5.6,2.1,Iris-virginica\\n7.2,3.0,5.8,1.6,Iris-virginica\\n7.4,2.8,6.1,1.9,Iris-virginica\\n7.9,3.8,6.4,2.0,Iris-virginica\\n6.4,2.8,5.6,2.2,Iris-virginica\\n6.3,2.8,5.1,1.5,Iris-virginica\\n6.1,2.6,5.6,1.4,Iris-virginica\\n7.7,3.0,6.1,2.3,Iris-virginica\\n6.3,3.4,5.6,2.4,Iris-virginica\\n6.4,3.1,5.5,1.8,Iris-virginica\\n6.0,3.0,4.8,1.8,Iris-virginica\\n6.9,3.1,5.4,2.1,Iris-virginica\\n6.7,3.1,5.6,2.4,Iris-virginica\\n6.9,3.1,5.1,2.3,Iris-virginica\\n5.8,2.7,5.1,1.9,Iris-virginica\\n6.8,3.2,5.9,2.3,Iris-virginica\\n6.7,3.3,5.7,2.5,Iris-virginica\\n6.7,3.0,5.2,2.3,Iris-virginica\\n6.3,2.5,5.0,1.9,Iris-virginica\\n6.5,3.0,5.2,2.0,Iris-virginica\\n6.2,3.4,5.4,2.3,Iris-virginica\\n5.9,3.0,5.1,1.8,Iris-virginica\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
sepal-lengthsepal-widthpetal-lengthpetal-widthspecies
5.13.31.70.5Iris-setosa
5.82.75.11.9Iris-virginica
5.62.84.92.0Iris-virginica
4.83.01.40.3Iris-setosa
7.72.66.92.3Iris-virginica
" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "import java.util.*           // Arrays.deepToString is used by the matrix-printing cells\n",
 "import java.io.StringReader\n",
 "\n",
 "// Rows are kept in file order here. They are shuffled later with the fixed `seed`\n",
 "// (fullDataSet.shuffle(seed)) right before the train/test split; an unseeded\n",
 "// krangl shuffle() at this point would give a different split on every run.\n",
 "val iris = DataFrame.readDelim(StringReader(iris_data))\n",
 "iris.head()"
 ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "// Scatter plot of sepal length vs. sepal width, coloured by species.\n",
 "val sepalData = mapOf(\n",
 "    \"x\" to iris[\"sepal-length\"].values().toList(),\n",
 "    \"y\" to iris[\"sepal-width\"].values().toList(),\n",
 "    \"color\" to iris[\"species\"].values().toList()\n",
 ")\n",
 "\n",
 "val points = geomPoint(data = sepalData, alpha = 1.0) {\n",
 "    x = \"x\"\n",
 "    y = \"y\"\n",
 "    color = \"color\"\n",
 "}\n",
 "\n",
 "ggplot() + points"
 ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
sepal-lengthsepal-widthpetal-lengthpetal-width
5.13.31.70.5
5.82.75.11.9
5.62.84.92.0
4.83.01.40.3
7.72.66.92.3
" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "// Feature columns only: drop the species label, which is one-hot encoded separately as the target.\n", "val irisWithoutLabel = iris.remove(\"species\")\n", "irisWithoutLabel.head()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[5.1, 3.3, 1.7, 0.5]\n", "[5.8, 2.7, 5.1, 1.9]\n", "[5.6, 2.8, 4.9, 2.0]\n", "[4.8, 3.0, 1.4, 0.3]\n", "[7.7, 2.6, 6.9, 2.3]\n", "[5.6, 2.9, 3.6, 1.3]\n", "[6.9, 3.1, 5.4, 2.1]\n", "[5.9, 3.0, 4.2, 1.5]\n", "[4.9, 3.1, 1.5, 0.1]\n", "[6.8, 2.8, 4.8, 1.4]\n", "[6.0, 2.2, 5.0, 1.5]\n", "[6.0, 3.4, 4.5, 1.6]\n", "[5.4, 3.9, 1.3, 0.4]\n", "[5.7, 3.0, 4.2, 1.2]\n", "[7.2, 3.0, 5.8, 1.6]\n", "[6.0, 2.7, 5.1, 1.6]\n", "[6.4, 3.2, 5.3, 2.3]\n", "[5.7, 2.8, 4.1, 1.3]\n", "[5.7, 2.5, 5.0, 2.0]\n", "[6.2, 2.8, 4.8, 1.8]\n", "[5.0, 3.5, 1.3, 0.3]\n", "[5.7, 4.4, 1.5, 0.4]\n", "[6.3, 2.5, 5.0, 1.9]\n", "[7.7, 3.0, 6.1, 2.3]\n", "[4.8, 3.0, 1.4, 0.1]\n", "[5.8, 2.7, 3.9, 1.2]\n", "[5.1, 2.5, 3.0, 1.1]\n", "[6.4, 2.8, 5.6, 2.1]\n", "[5.3, 3.7, 1.5, 0.2]\n", "[4.6, 3.4, 1.4, 0.3]\n", "[7.6, 3.0, 6.6, 2.1]\n", "[4.5, 2.3, 1.3, 0.3]\n", "[5.6, 2.7, 4.2, 1.3]\n", "[5.7, 2.6, 3.5, 1.0]\n", "[6.7, 3.0, 5.0, 1.7]\n", "[6.5, 3.0, 5.8, 2.2]\n", "[5.0, 2.3, 3.3, 1.0]\n", "[6.1, 3.0, 4.9, 1.8]\n", "[6.5, 3.0, 5.2, 2.0]\n", "[6.2, 3.4, 5.4, 2.3]\n", "[4.4, 2.9, 1.4, 0.2]\n", "[5.2, 3.5, 1.5, 0.2]\n", "[7.2, 3.6, 6.1, 2.5]\n", "[5.5, 4.2, 1.4, 0.2]\n", "[6.4, 2.9, 4.3, 1.3]\n", "[4.9, 3.0, 1.4, 0.2]\n", "[6.3, 2.5, 4.9, 1.5]\n", "[5.5, 2.4, 3.7, 1.0]\n", "[4.7, 3.2, 1.6, 0.2]\n", "[6.3, 2.7, 4.9, 1.8]\n", "[6.3, 2.3, 4.4, 1.3]\n", "[7.1, 3.0, 5.9, 2.1]\n", "[5.0, 3.5, 1.6, 0.6]\n", "[6.8, 3.0, 5.5, 2.1]\n", "[4.8, 3.4, 1.9, 0.2]\n", "[6.7, 3.1, 5.6, 2.4]\n", "[5.8, 2.6, 4.0, 1.2]\n", "[5.0, 3.2, 1.2, 0.2]\n", "[6.7, 3.3, 5.7, 2.5]\n", "[5.1, 3.5, 1.4, 0.2]\n", "[6.4, 2.7, 5.3, 1.9]\n", "[7.0, 3.2, 4.7, 1.4]\n", "[6.1, 2.8, 4.7, 1.2]\n", "[5.4, 3.4, 1.7, 0.2]\n", "[4.9, 2.4, 3.3, 1.0]\n", 
"[5.2, 3.4, 1.4, 0.2]\n", "[6.5, 2.8, 4.6, 1.5]\n", "[5.4, 3.0, 4.5, 1.5]\n", "[7.3, 2.9, 6.3, 1.8]\n", "[5.2, 2.7, 3.9, 1.4]\n", "[5.4, 3.9, 1.7, 0.4]\n", "[6.2, 2.2, 4.5, 1.5]\n", "[5.1, 3.5, 1.4, 0.3]\n", "[4.8, 3.4, 1.6, 0.2]\n", "[7.7, 3.8, 6.7, 2.2]\n", "[5.6, 3.0, 4.5, 1.5]\n", "[6.3, 3.4, 5.6, 2.4]\n", "[5.8, 2.8, 5.1, 2.4]\n", "[5.5, 2.3, 4.0, 1.3]\n", "[4.9, 2.5, 4.5, 1.7]\n", "[6.0, 2.2, 4.0, 1.0]\n", "[5.0, 2.0, 3.5, 1.0]\n", "[5.9, 3.2, 4.8, 1.8]\n", "[5.4, 3.4, 1.5, 0.4]\n", "[6.9, 3.1, 4.9, 1.5]\n", "[4.9, 3.1, 1.5, 0.1]\n", "[5.2, 4.1, 1.5, 0.1]\n", "[5.1, 3.8, 1.5, 0.3]\n", "[5.1, 3.8, 1.6, 0.2]\n", "[6.7, 3.1, 4.7, 1.5]\n", "[5.9, 3.0, 5.1, 1.8]\n", "[5.8, 4.0, 1.2, 0.2]\n", "[4.3, 3.0, 1.1, 0.1]\n", "[6.7, 2.5, 5.8, 1.8]\n", "[6.3, 3.3, 6.0, 2.5]\n", "[5.6, 2.5, 3.9, 1.1]\n", "[4.4, 3.2, 1.3, 0.2]\n", "[4.6, 3.1, 1.5, 0.2]\n", "[5.5, 2.6, 4.4, 1.2]\n", "[6.9, 3.1, 5.1, 2.3]\n", "[6.0, 2.9, 4.5, 1.5]\n", "[7.2, 3.2, 6.0, 1.8]\n", "[6.1, 2.8, 4.0, 1.3]\n", "[5.7, 2.9, 4.2, 1.3]\n", "[5.8, 2.7, 4.1, 1.0]\n", "[4.8, 3.1, 1.6, 0.2]\n", "[6.9, 3.2, 5.7, 2.3]\n", "[5.5, 2.4, 3.8, 1.1]\n", "[5.0, 3.4, 1.5, 0.2]\n", "[4.6, 3.2, 1.4, 0.2]\n", "[4.9, 3.1, 1.5, 0.1]\n", "[6.0, 3.0, 4.8, 1.8]\n", "[6.3, 2.9, 5.6, 1.8]\n", "[6.6, 3.0, 4.4, 1.4]\n", "[7.9, 3.8, 6.4, 2.0]\n", "[5.6, 3.0, 4.1, 1.3]\n", "[5.7, 3.8, 1.7, 0.3]\n", "[5.0, 3.4, 1.6, 0.4]\n", "[5.7, 2.8, 4.5, 1.3]\n", "[6.7, 3.3, 5.7, 2.1]\n", "[6.7, 3.1, 4.4, 1.4]\n", "[6.7, 3.0, 5.2, 2.3]\n", "[5.5, 2.5, 4.0, 1.3]\n", "[5.0, 3.3, 1.4, 0.2]\n", "[4.4, 3.0, 1.3, 0.2]\n", "[6.6, 2.9, 4.6, 1.3]\n", "[7.4, 2.8, 6.1, 1.9]\n", "[6.5, 3.0, 5.5, 1.8]\n", "[6.3, 2.8, 5.1, 1.5]\n", "[6.4, 3.2, 4.5, 1.5]\n", "[6.1, 2.9, 4.7, 1.4]\n", "[4.6, 3.6, 1.0, 0.2]\n", "[5.4, 3.7, 1.5, 0.2]\n", "[5.5, 3.5, 1.3, 0.2]\n", "[6.1, 3.0, 4.6, 1.4]\n", "[5.8, 2.7, 5.1, 1.9]\n", "[6.8, 3.2, 5.9, 2.3]\n", "[6.4, 3.1, 5.5, 1.8]\n", "[7.7, 2.8, 6.7, 2.0]\n", "[5.0, 3.0, 1.6, 0.2]\n", "[6.2, 2.9, 4.3, 1.3]\n", "[5.1, 3.4, 1.5, 
0.2]\n", "[6.5, 3.2, 5.1, 2.0]\n", "[5.1, 3.7, 1.5, 0.4]\n", "[6.4, 2.8, 5.6, 2.2]\n", "[5.1, 3.8, 1.9, 0.4]\n", "[5.0, 3.6, 1.4, 0.2]\n", "[4.7, 3.2, 1.3, 0.2]\n", "[6.3, 3.3, 4.7, 1.6]\n", "[6.1, 2.6, 5.6, 1.4]]\n" ] } ], "source": [
 "// Convert the feature columns into a row-major matrix (150x4 for iris).\n",
 "// Dimensions are taken from the frame instead of being hard-coded.\n",
 "val row = irisWithoutLabel.nrow\n",
 "val col = irisWithoutLabel.ncol\n",
 "\n",
 "// Frame indexing is column-first: irisWithoutLabel[c][r] is row r of column c.\n",
 "val irisMatrix = Array(row) { r ->\n",
 "    DoubleArray(col) { c -> irisWithoutLabel[c][r] as Double }\n",
 "}\n",
 "println(Arrays.deepToString(irisMatrix).replace(\"], \", \"]\\n\"))"
 ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 
0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 0.0, 1.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[1.0, 0.0, 0.0]\n", "[0.0, 1.0, 0.0]\n", "[0.0, 0.0, 1.0]]\n" ] } ], "source": [ "//Now do the same for the label data\n", "val irisLabel = iris.select(\"species\")[0]\n", "\n", 
"val rowLabel = iris.nrow  // one label row per data-frame row (150 for iris)\n",
 "val colLabel = 3          // one-hot width: setosa, versicolor, virginica\n",
 "\n",
 "val twodimLabel = Array(rowLabel) { DoubleArray(colLabel) }\n",
 "for (r in 0 until rowLabel) {\n",
 "    when (val label = irisLabel[r]) {\n",
 "        \"Iris-setosa\" -> twodimLabel[r][0] = 1.0\n",
 "        \"Iris-versicolor\" -> twodimLabel[r][1] = 1.0\n",
 "        \"Iris-virginica\" -> twodimLabel[r][2] = 1.0\n",
 "        // Fail loudly instead of silently leaving an all-zero (invalid) label row.\n",
 "        else -> error(\"Unexpected species label '$label' in row $r\")\n",
 "    }\n",
 "}\n",
 "println(Arrays.deepToString(twodimLabel).replace(\"], \", \"]\\n\"))"
 ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
 "// Convert the feature matrix and the one-hot label matrix into INDArrays for training\n",
 "val dataIn = Nd4j.create(irisMatrix)\n",
 "val dataOut = Nd4j.create(twodimLabel)"
 ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
 "import org.nd4j.linalg.lossfunctions.LossFunctions\n",
 "\n",
 "val seed: Long = 6\n",
 "\n",
 "// Network: 4 inputs -> dense(3, tanh) -> dense(3) -> 3-way softmax output.\n",
 "val conf = NeuralNetConfiguration.Builder()\n",
 "    .seed(seed)        // fixed random seed for reproducible weight initialisation\n",
 "    .updater(Nadam())  // Nadam updater (Adam with Nesterov momentum), default hyper-parameters\n",
 "    .l2(1e-4)          // L2 weight regularisation\n",
 "    .list()\n",
 "    .layer(DenseLayer.Builder()\n",
 "        .nIn(4)\n",
 "        .nOut(3)\n",
 "        .activation(Activation.TANH)\n",
 "        .weightInit(WeightInit.XAVIER)\n",
 "        .build())\n",
 "    // NOTE(review): no activation/weightInit set on this layer, so it uses the\n",
 "    // builder-level defaults; confirm which ones for the DL4J version loaded by %use.\n",
 "    .layer(DenseLayer.Builder()\n",
 "        .nIn(3)\n",
 "        .nOut(3)\n",
 "        .build())\n",
 "    .layer(OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)\n",
 "        .nIn(3)\n",
 "        .nOut(3)\n",
 "        .activation(Activation.SOFTMAX)\n",
 "        .weightInit(WeightInit.XAVIER)\n",
 "        .build())\n",
 "    .build()\n",
 "\n",
 "val model = MultiLayerNetwork(conf)\n",
 "// Allocate and initialise the parameters explicitly before the network is trained or queried.\n",
 "model.init()"
 ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Score \n", "\n", "========================Evaluation Metrics========================\n", " # of classes: 3\n", " Accuracy: 1,0000\n", " Precision: 1,0000\n", " Recall: 1,0000\n", " F1 Score: 1,0000\n", 
"Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)\n", "\n", "\n", "=========================Confusion Matrix=========================\n", " 0 1 2\n", "-------\n", " 5 0 0 | 0 = 0\n", " 0 3 0 | 1 = 1\n", " 0 0 7 | 2 = 2\n", "\n", "Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times\n", "==================================================================\n" ] } ], "source": [
 "import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization\n",
 "import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize\n",
 "\n",
 "// Create a data set from the INDArrays and shuffle it in place with the fixed seed before splitting\n",
 "val fullDataSet = DataSet(dataIn, dataOut)\n",
 "fullDataSet.shuffle(seed)\n",
 "\n",
 "// 90% of the rows go to training; the remainder is held out for evaluation\n",
 "val splitSet = fullDataSet.splitTestAndTrain(0.90)\n",
 "val trainingData = splitSet.train\n",
 "val testData = splitSet.test\n",
 "\n",
 "// Normalize the features with NormalizerStandardize (mean 0, unit variance)\n",
 "val normalizer: DataNormalization = NormalizerStandardize()\n",
 "normalizer.fit(trainingData)        // collect mean/stdev from the training data only; does not modify it\n",
 "normalizer.transform(trainingData)  // apply normalization to the training data\n",
 "normalizer.transform(testData)      // apply the statistics calculated from the *training* set to the test data\n",
 "\n",
 "// Train the network: each fit() call trains on the whole training DataSet once\n",
 "val nEpochs = 2000\n",
 "model.setListeners(ScoreIterationListener(100))\n",
 "repeat(nEpochs) {\n",
 "    model.fit(trainingData)\n",
 "}\n",
 "\n",
 "// Evaluate on the held-out test set\n",
 "val eval = Evaluation()\n",
 "val output: INDArray = model.output(testData.features)\n",
 "eval.eval(testData.labels, output)\n",
 "println(\"Test-set evaluation\" + eval.stats())"
 ] } ], "metadata": { "kernelspec": { "display_name": "Kotlin", "language": "kotlin", "name": "kotlin" }, "language_info": { "codemirror_mode": "text/x-kotlin", "file_extension": ".kt", "mimetype": "text/x-kotlin", "name": "kotlin", "nbconvert_exporter": "", "pygments_lexer": "kotlin", "version": "1.5.20-dev-4184" } }, "nbformat": 4, "nbformat_minor": 2 }