{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.1.0\n" ] } ], "source": [ "# Import Tensorflow and check the version\n", "import tensorflow as tf\n", "print(tf.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(b'Hello world!', shape=(), dtype=string)\n" ] } ], "source": [ "a = tf.constant(\"Hello world!\")\n", "print(a)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Autograph and tf.function()\n", "@tf.function\n", "def f(x):\n", " return tf.add(x, 1.)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tf.Tensor(2.0, shape=(), dtype=float32)\n", "tf.Tensor([2. 2.], shape=(2,), dtype=float32)\n", "tf.Tensor([[4.]], shape=(1, 1), dtype=float32)\n" ] } ], "source": [ "scalar = tf.constant(1.0)\n", "vector = tf.constant([1.0, 1.0])\n", "matrix = tf.constant([[3.0]])\n", "\n", "print(f(scalar))\n", "print(f(vector))\n", "print(f(matrix))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Building a model\n", "n_input = 4\n", "n_output = 3\n", "n_hidden = 10\n", "\n", "# hyperparameter\n", "learning_rate = 0.01\n", "training_epochs = 2000\n", "display_steps = 200" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Getting data\n", "from sklearn. datasets import load_iris\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import OneHotEncoder\n", "\n", "iris_data = load_iris() # load the iris dataset\n", "\n", "x = iris_data.data\n", "y_ = iris_data.target.reshape(-1, 1) # Convert data to a single column" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# One Hot encode the class labels\n", "encoder = OneHotEncoder(sparse = False)\n", "y = encoder.fit_transform(y_)\n", "\n", "train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.20)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Build the model\n", "model = tf.keras.Sequential()\n", "\n", "model.add(tf.keras.layers.Dense(n_hidden, input_shape=(n_input,), activation='relu', name='fc1'))\n", "model.add(tf.keras.layers.Dense(n_output, activation='softmax', name='output'))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Neural Network Model Summary: \n", "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "fc1 (Dense) (None, 10) 50 \n", "_________________________________________________________________\n", "output (Dense) (None, 3) 33 \n", "=================================================================\n", "Total params: 83\n", "Trainable params: 83\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "# Adam optimizer with learning rate of 0.001\n", "optimizer = tf.keras.optimizers.Adam(lr=0.001)\n", "model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])\n", "\n", "print('Neural Network Model Summary: ')\n", "print(model.summary())" 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/200\n", "24/24 - 0s - loss: 4.0796 - accuracy: 0.3417\n", "Epoch 2/200\n", "24/24 - 0s - loss: 3.2345 - accuracy: 0.3417\n", "Epoch 3/200\n", "24/24 - 0s - loss: 2.6355 - accuracy: 0.3417\n", "Epoch 4/200\n", "24/24 - 0s - loss: 2.1942 - accuracy: 0.3417\n", "Epoch 5/200\n", "24/24 - 0s - loss: 1.8734 - accuracy: 0.3417\n", "Epoch 6/200\n", "24/24 - 0s - loss: 1.6620 - accuracy: 0.4500\n", "Epoch 7/200\n", "24/24 - 0s - loss: 1.4956 - accuracy: 0.6667\n", "Epoch 8/200\n", "24/24 - 0s - loss: 1.3713 - accuracy: 0.6833\n", "Epoch 9/200\n", "24/24 - 0s - loss: 1.2623 - accuracy: 0.6833\n", "Epoch 10/200\n", "24/24 - 0s - loss: 1.1800 - accuracy: 0.6833\n", "Epoch 11/200\n", "24/24 - 0s - loss: 1.1269 - accuracy: 0.6833\n", "Epoch 12/200\n", "24/24 - 0s - loss: 1.0675 - accuracy: 0.6833\n", "Epoch 13/200\n", "24/24 - 0s - loss: 1.0308 - accuracy: 0.6833\n", "Epoch 14/200\n", "24/24 - 0s - loss: 1.0018 - accuracy: 0.6833\n", "Epoch 15/200\n", "24/24 - 0s - loss: 0.9756 - accuracy: 0.6833\n", "Epoch 16/200\n", "24/24 - 0s - loss: 0.9544 - accuracy: 0.6833\n", "Epoch 17/200\n", "24/24 - 0s - loss: 0.9416 - accuracy: 0.6750\n", "Epoch 18/200\n", "24/24 - 0s - loss: 0.9235 - accuracy: 0.6750\n", "Epoch 19/200\n", "24/24 - 0s - loss: 0.9105 - accuracy: 0.6583\n", "Epoch 20/200\n", "24/24 - 0s - loss: 0.8986 - accuracy: 0.6417\n", "Epoch 21/200\n", "24/24 - 0s - loss: 0.8873 - accuracy: 0.6167\n", "Epoch 22/200\n", "24/24 - 0s - loss: 0.8777 - accuracy: 0.5167\n", "Epoch 23/200\n", "24/24 - 0s - loss: 0.8672 - accuracy: 0.5083\n", "Epoch 24/200\n", "24/24 - 0s - loss: 0.8563 - accuracy: 0.4500\n", "Epoch 25/200\n", "24/24 - 0s - loss: 0.8468 - accuracy: 0.4333\n", "Epoch 26/200\n", "24/24 - 0s - loss: 0.8384 - accuracy: 0.4250\n", "Epoch 27/200\n", "24/24 - 0s - loss: 0.8265 - accuracy: 0.4250\n", "Epoch 28/200\n", "24/24 - 0s - loss: 0.8180 - accuracy: 0.4000\n", "Epoch 29/200\n", "24/24 - 0s - loss: 0.8071 - accuracy: 0.4083\n", "Epoch 30/200\n", "24/24 - 0s - loss: 0.7983 - accuracy: 0.4167\n", "Epoch 31/200\n", "24/24 - 0s - loss: 0.7878 - accuracy: 0.4250\n", "Epoch 32/200\n", "24/24 - 0s - loss: 0.7796 - accuracy: 0.4000\n", "Epoch 33/200\n", "24/24 - 0s - loss: 0.7683 - accuracy: 0.4000\n", "Epoch 34/200\n", "24/24 - 0s - loss: 0.7599 - accuracy: 0.4250\n", "Epoch 35/200\n", "24/24 - 0s - loss: 0.7511 - accuracy: 0.4083\n", "Epoch 36/200\n", "24/24 - 0s - loss: 0.7402 - accuracy: 0.4250\n", "Epoch 37/200\n", "24/24 - 0s - loss: 0.7313 - accuracy: 0.4333\n", "Epoch 38/200\n", "24/24 - 0s - loss: 0.7223 - accuracy: 0.4333\n", "Epoch 39/200\n", "24/24 - 0s - loss: 0.7141 - accuracy: 0.4250\n", "Epoch 40/200\n", "24/24 - 0s - loss: 0.7044 - accuracy: 0.4250\n", "Epoch 41/200\n", "24/24 - 0s - loss: 0.6964 - accuracy: 0.4333\n", "Epoch 42/200\n", "24/24 - 0s - loss: 0.6895 - accuracy: 0.4250\n", "Epoch 43/200\n", "24/24 - 0s - loss: 0.6800 - accuracy: 0.4500\n", "Epoch 44/200\n", "24/24 - 0s - loss: 0.6727 - accuracy: 0.4583\n", "Epoch 45/200\n", "24/24 - 0s - loss: 0.6653 - accuracy: 0.4833\n", "Epoch 46/200\n", "24/24 - 0s - loss: 0.6574 - accuracy: 0.4500\n", "Epoch 47/200\n", "24/24 - 0s - loss: 0.6503 - accuracy: 0.4667\n", "Epoch 48/200\n", "24/24 - 0s - loss: 0.6433 - accuracy: 0.4917\n", "Epoch 49/200\n", "24/24 - 0s - loss: 0.6379 - accuracy: 0.4750\n", "Epoch 50/200\n", "24/24 - 0s - loss: 0.6299 - accuracy: 0.4417\n", "Epoch 51/200\n", "24/24 - 
0s - loss: 0.6237 - accuracy: 0.5083\n", "Epoch 52/200\n", "24/24 - 0s - loss: 0.6173 - accuracy: 0.5500\n", "Epoch 53/200\n", "24/24 - 0s - loss: 0.6110 - accuracy: 0.5000\n", "Epoch 54/200\n", "24/24 - 0s - loss: 0.6070 - accuracy: 0.5417\n", "Epoch 55/200\n", "24/24 - 0s - loss: 0.6002 - accuracy: 0.5333\n", "Epoch 56/200\n", "24/24 - 0s - loss: 0.5945 - accuracy: 0.5167\n", "Epoch 57/200\n", "24/24 - 0s - loss: 0.5895 - accuracy: 0.5500\n", "Epoch 58/200\n", "24/24 - 0s - loss: 0.5847 - accuracy: 0.5417\n", "Epoch 59/200\n", "24/24 - 0s - loss: 0.5800 - accuracy: 0.5917\n", "Epoch 60/200\n", "24/24 - 0s - loss: 0.5753 - accuracy: 0.6083\n", "Epoch 61/200\n", "24/24 - 0s - loss: 0.5712 - accuracy: 0.5500\n", "Epoch 62/200\n", "24/24 - 0s - loss: 0.5666 - accuracy: 0.6000\n", "Epoch 63/200\n", "24/24 - 0s - loss: 0.5620 - accuracy: 0.6250\n", "Epoch 64/200\n", "24/24 - 0s - loss: 0.5583 - accuracy: 0.6083\n", "Epoch 65/200\n", "24/24 - 0s - loss: 0.5544 - accuracy: 0.6333\n", "Epoch 66/200\n", "24/24 - 0s - loss: 0.5499 - accuracy: 0.6417\n", "Epoch 67/200\n", "24/24 - 0s - loss: 0.5466 - accuracy: 0.6417\n", "Epoch 68/200\n", "24/24 - 0s - loss: 0.5410 - accuracy: 0.6333\n", "Epoch 69/200\n", "24/24 - 0s - loss: 0.5248 - accuracy: 0.6417\n", "Epoch 70/200\n", "24/24 - 0s - loss: 0.5063 - accuracy: 0.7583\n", "Epoch 71/200\n", "24/24 - 0s - loss: 0.5015 - accuracy: 0.7667\n", "Epoch 72/200\n", "24/24 - 0s - loss: 0.4920 - accuracy: 0.8000\n", "Epoch 73/200\n", "24/24 - 0s - loss: 0.4844 - accuracy: 0.8000\n", "Epoch 74/200\n", "24/24 - 0s - loss: 0.4777 - accuracy: 0.8333\n", "Epoch 75/200\n", "24/24 - 0s - loss: 0.4765 - accuracy: 0.8417\n", "Epoch 76/200\n", "24/24 - 0s - loss: 0.4639 - accuracy: 0.8667\n", "Epoch 77/200\n", "24/24 - 0s - loss: 0.4576 - accuracy: 0.8917\n", "Epoch 78/200\n", "24/24 - 0s - loss: 0.4523 - accuracy: 0.8833\n", "Epoch 79/200\n", "24/24 - 0s - loss: 0.4483 - accuracy: 0.8667\n", "Epoch 80/200\n", "24/24 - 0s - loss: 0.4409 - accuracy: 0.9000\n", "Epoch 81/200\n", "24/24 - 0s - loss: 0.4364 - accuracy: 0.8833\n", "Epoch 82/200\n", "24/24 - 0s - loss: 0.4321 - accuracy: 0.8750\n", "Epoch 83/200\n", "24/24 - 0s - loss: 0.4308 - accuracy: 0.9000\n", "Epoch 84/200\n", "24/24 - 0s - loss: 0.4206 - accuracy: 0.9000\n", "Epoch 85/200\n", "24/24 - 0s - loss: 0.4129 - accuracy: 0.9083\n", "Epoch 86/200\n", "24/24 - 0s - loss: 0.4113 - accuracy: 0.8917\n", "Epoch 87/200\n", "24/24 - 0s - loss: 0.4063 - accuracy: 0.9333\n", "Epoch 88/200\n", "24/24 - 0s - loss: 0.3997 - accuracy: 0.9250\n", "Epoch 89/200\n", "24/24 - 0s - loss: 0.3958 - accuracy: 0.9250\n", "Epoch 90/200\n", "24/24 - 0s - loss: 0.3939 - accuracy: 0.9000\n", "Epoch 91/200\n", "24/24 - 0s - loss: 0.3853 - accuracy: 0.9333\n", "Epoch 92/200\n", "24/24 - 0s - loss: 0.3813 - accuracy: 0.9250\n", "Epoch 93/200\n", "24/24 - 0s - loss: 0.3763 - accuracy: 0.9333\n", "Epoch 94/200\n", "24/24 - 0s - loss: 0.3711 - accuracy: 0.9417\n", "Epoch 95/200\n", "24/24 - 0s - loss: 0.3672 - accuracy: 0.9417\n", "Epoch 96/200\n", "24/24 - 0s - loss: 0.3646 - accuracy: 0.9417\n", "Epoch 97/200\n", "24/24 - 0s - loss: 0.3582 - accuracy: 0.9333\n", "Epoch 98/200\n", "24/24 - 0s - loss: 0.3581 - accuracy: 0.9417\n", "Epoch 99/200\n", "24/24 - 0s - loss: 0.3531 - accuracy: 0.9167\n", "Epoch 100/200\n", "24/24 - 0s - loss: 0.3469 - accuracy: 0.9417\n", "Epoch 101/200\n", "24/24 - 0s - loss: 0.3446 - accuracy: 0.9417\n", "Epoch 102/200\n", "24/24 - 0s - loss: 0.3441 - accuracy: 0.9083\n", "Epoch 103/200\n", "24/24 - 0s - loss: 
0.3326 - accuracy: 0.9500\n", "Epoch 104/200\n", "24/24 - 0s - loss: 0.3305 - accuracy: 0.9417\n", "Epoch 105/200\n", "24/24 - 0s - loss: 0.3264 - accuracy: 0.9667\n", "Epoch 106/200\n", "24/24 - 0s - loss: 0.3233 - accuracy: 0.9500\n", "Epoch 107/200\n", "24/24 - 0s - loss: 0.3173 - accuracy: 0.9667\n", "Epoch 108/200\n", "24/24 - 0s - loss: 0.3159 - accuracy: 0.9583\n", "Epoch 109/200\n", "24/24 - 0s - loss: 0.3122 - accuracy: 0.9583\n", "Epoch 110/200\n", "24/24 - 0s - loss: 0.3064 - accuracy: 0.9583\n", "Epoch 111/200\n", "24/24 - 0s - loss: 0.3035 - accuracy: 0.9667\n", "Epoch 112/200\n", "24/24 - 0s - loss: 0.3005 - accuracy: 0.9583\n", "Epoch 113/200\n", "24/24 - 0s - loss: 0.2964 - accuracy: 0.9500\n", "Epoch 114/200\n", "24/24 - 0s - loss: 0.2955 - accuracy: 0.9667\n", "Epoch 115/200\n", "24/24 - 0s - loss: 0.2891 - accuracy: 0.9833\n", "Epoch 116/200\n", "24/24 - 0s - loss: 0.2853 - accuracy: 0.9583\n", "Epoch 117/200\n", "24/24 - 0s - loss: 0.2850 - accuracy: 0.9750\n", "Epoch 118/200\n", "24/24 - 0s - loss: 0.2804 - accuracy: 0.9667\n", "Epoch 119/200\n", "24/24 - 0s - loss: 0.2768 - accuracy: 0.9833\n", "Epoch 120/200\n", "24/24 - 0s - loss: 0.2737 - accuracy: 0.9750\n", "Epoch 121/200\n", "24/24 - 0s - loss: 0.2725 - accuracy: 0.9500\n", "Epoch 122/200\n", "24/24 - 0s - loss: 0.2700 - accuracy: 0.9583\n", "Epoch 123/200\n", "24/24 - 0s - loss: 0.2649 - accuracy: 0.9750\n", "Epoch 124/200\n", "24/24 - 0s - loss: 0.2614 - accuracy: 0.9750\n", "Epoch 125/200\n", "24/24 - 0s - loss: 0.2685 - accuracy: 0.9417\n", "Epoch 126/200\n", "24/24 - 0s - loss: 0.2606 - accuracy: 0.9500\n", "Epoch 127/200\n", "24/24 - 0s - loss: 0.2513 - accuracy: 0.9750\n", "Epoch 128/200\n", "24/24 - 0s - loss: 0.2522 - accuracy: 0.9583\n", "Epoch 129/200\n", "24/24 - 0s - loss: 0.2468 - accuracy: 0.9500\n", "Epoch 130/200\n", "24/24 - 0s - loss: 0.2465 - accuracy: 0.9750\n", "Epoch 131/200\n", "24/24 - 0s - loss: 0.2430 - accuracy: 0.9667\n", "Epoch 132/200\n", "24/24 - 0s - loss: 0.2388 - accuracy: 0.9583\n", "Epoch 133/200\n", "24/24 - 0s - loss: 0.2384 - accuracy: 0.9750\n", "Epoch 134/200\n", "24/24 - 0s - loss: 0.2349 - accuracy: 0.9667\n", "Epoch 135/200\n", "24/24 - 0s - loss: 0.2322 - accuracy: 0.9583\n", "Epoch 136/200\n", "24/24 - 0s - loss: 0.2289 - accuracy: 0.9667\n", "Epoch 137/200\n", "24/24 - 0s - loss: 0.2269 - accuracy: 0.9667\n", "Epoch 138/200\n", "24/24 - 0s - loss: 0.2256 - accuracy: 0.9667\n", "Epoch 139/200\n", "24/24 - 0s - loss: 0.2212 - accuracy: 0.9667\n", "Epoch 140/200\n", "24/24 - 0s - loss: 0.2184 - accuracy: 0.9833\n", "Epoch 141/200\n", "24/24 - 0s - loss: 0.2175 - accuracy: 0.9667\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 142/200\n", "24/24 - 0s - loss: 0.2156 - accuracy: 0.9667\n", "Epoch 143/200\n", "24/24 - 0s - loss: 0.2146 - accuracy: 0.9667\n", "Epoch 144/200\n", "24/24 - 0s - loss: 0.2091 - accuracy: 0.9667\n", "Epoch 145/200\n", "24/24 - 0s - loss: 0.2108 - accuracy: 0.9667\n", "Epoch 146/200\n", "24/24 - 0s - loss: 0.2057 - accuracy: 0.9750\n", "Epoch 147/200\n", "24/24 - 0s - loss: 0.2032 - accuracy: 0.9750\n", "Epoch 148/200\n", "24/24 - 0s - loss: 0.2012 - accuracy: 0.9833\n", "Epoch 149/200\n", "24/24 - 0s - loss: 0.2014 - accuracy: 0.9667\n", "Epoch 150/200\n", "24/24 - 0s - loss: 0.1973 - accuracy: 0.9667\n", "Epoch 151/200\n", "24/24 - 0s - loss: 0.1958 - accuracy: 0.9833\n", "Epoch 152/200\n", "24/24 - 0s - loss: 0.1936 - accuracy: 0.9750\n", "Epoch 153/200\n", "24/24 - 0s - loss: 0.1919 - accuracy: 0.9750\n", "Epoch 
154/200\n", "24/24 - 0s - loss: 0.1898 - accuracy: 0.9833\n", "Epoch 155/200\n", "24/24 - 0s - loss: 0.1887 - accuracy: 0.9833\n", "Epoch 156/200\n", "24/24 - 0s - loss: 0.1890 - accuracy: 0.9583\n", "Epoch 157/200\n", "24/24 - 0s - loss: 0.1840 - accuracy: 0.9833\n", "Epoch 158/200\n", "24/24 - 0s - loss: 0.1886 - accuracy: 0.9750\n", "Epoch 159/200\n", "24/24 - 0s - loss: 0.1832 - accuracy: 0.9750\n", "Epoch 160/200\n", "24/24 - 0s - loss: 0.1791 - accuracy: 0.9833\n", "Epoch 161/200\n", "24/24 - 0s - loss: 0.1791 - accuracy: 0.9750\n", "Epoch 162/200\n", "24/24 - 0s - loss: 0.1763 - accuracy: 0.9750\n", "Epoch 163/200\n", "24/24 - 0s - loss: 0.1747 - accuracy: 0.9750\n", "Epoch 164/200\n", "24/24 - 0s - loss: 0.1740 - accuracy: 0.9833\n", "Epoch 165/200\n", "24/24 - 0s - loss: 0.1786 - accuracy: 0.9667\n", "Epoch 166/200\n", "24/24 - 0s - loss: 0.1705 - accuracy: 0.9750\n", "Epoch 167/200\n", "24/24 - 0s - loss: 0.1680 - accuracy: 0.9833\n", "Epoch 168/200\n", "24/24 - 0s - loss: 0.1666 - accuracy: 0.9750\n", "Epoch 169/200\n", "24/24 - 0s - loss: 0.1663 - accuracy: 0.9833\n", "Epoch 170/200\n", "24/24 - 0s - loss: 0.1653 - accuracy: 0.9833\n", "Epoch 171/200\n", "24/24 - 0s - loss: 0.1658 - accuracy: 0.9583\n", "Epoch 172/200\n", "24/24 - 0s - loss: 0.1626 - accuracy: 0.9750\n", "Epoch 173/200\n", "24/24 - 0s - loss: 0.1606 - accuracy: 0.9750\n", "Epoch 174/200\n", "24/24 - 0s - loss: 0.1590 - accuracy: 0.9833\n", "Epoch 175/200\n", "24/24 - 0s - loss: 0.1572 - accuracy: 0.9833\n", "Epoch 176/200\n", "24/24 - 0s - loss: 0.1574 - accuracy: 0.9750\n", "Epoch 177/200\n", "24/24 - 0s - loss: 0.1551 - accuracy: 0.9833\n", "Epoch 178/200\n", "24/24 - 0s - loss: 0.1552 - accuracy: 0.9750\n", "Epoch 179/200\n", "24/24 - 0s - loss: 0.1541 - accuracy: 0.9833\n", "Epoch 180/200\n", "24/24 - 0s - loss: 0.1530 - accuracy: 0.9833\n", "Epoch 181/200\n", "24/24 - 0s - loss: 0.1513 - accuracy: 0.9833\n", "Epoch 182/200\n", "24/24 - 0s - loss: 0.1486 - accuracy: 0.9833\n", "Epoch 183/200\n", "24/24 - 0s - loss: 0.1483 - accuracy: 0.9750\n", "Epoch 184/200\n", "24/24 - 0s - loss: 0.1458 - accuracy: 0.9833\n", "Epoch 185/200\n", "24/24 - 0s - loss: 0.1471 - accuracy: 0.9667\n", "Epoch 186/200\n", "24/24 - 0s - loss: 0.1476 - accuracy: 0.9750\n", "Epoch 187/200\n", "24/24 - 0s - loss: 0.1437 - accuracy: 0.9750\n", "Epoch 188/200\n", "24/24 - 0s - loss: 0.1418 - accuracy: 0.9833\n", "Epoch 189/200\n", "24/24 - 0s - loss: 0.1427 - accuracy: 0.9833\n", "Epoch 190/200\n", "24/24 - 0s - loss: 0.1418 - accuracy: 0.9833\n", "Epoch 191/200\n", "24/24 - 0s - loss: 0.1447 - accuracy: 0.9750\n", "Epoch 192/200\n", "24/24 - 0s - loss: 0.1388 - accuracy: 0.9750\n", "Epoch 193/200\n", "24/24 - 0s - loss: 0.1375 - accuracy: 0.9833\n", "Epoch 194/200\n", "24/24 - 0s - loss: 0.1362 - accuracy: 0.9750\n", "Epoch 195/200\n", "24/24 - 0s - loss: 0.1380 - accuracy: 0.9667\n", "Epoch 196/200\n", "24/24 - 0s - loss: 0.1355 - accuracy: 0.9750\n", "Epoch 197/200\n", "24/24 - 0s - loss: 0.1358 - accuracy: 0.9833\n", "Epoch 198/200\n", "24/24 - 0s - loss: 0.1329 - accuracy: 0.9833\n", "Epoch 199/200\n", "24/24 - 0s - loss: 0.1320 - accuracy: 0.9750\n", "Epoch 200/200\n", "24/24 - 0s - loss: 0.1322 - accuracy: 0.9917\n", "1/1 [==============================] - 0s 2ms/step - loss: 0.1019 - accuracy: 1.0000\n", "Final test set loss: 0.101947\n", "Final test set accuracy: 1.000000\n" ] } ], "source": [ "# Train the model\n", "model.fit(train_x, train_y, verbose=2, batch_size=5, epochs=200)\n", "\n", "# Test on unseen data\n", "results = 
model.evaluate(test_x, test_y)\n", "\n", "print('Final test set loss: {:4f}'.format(results[0]))\n", "print('Final test set accuracy: {:4f}'.format(results[1]))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['abstract_reasoning',\n", " 'aeslc',\n", " 'aflw2k3d',\n", " 'amazon_us_reviews',\n", " 'arc',\n", " 'bair_robot_pushing_small',\n", " 'beans',\n", " 'big_patent',\n", " 'bigearthnet',\n", " 'billsum',\n", " 'binarized_mnist',\n", " 'binary_alpha_digits',\n", " 'c4',\n", " 'caltech101',\n", " 'caltech_birds2010',\n", " 'caltech_birds2011',\n", " 'cars196',\n", " 'cassava',\n", " 'cats_vs_dogs',\n", " 'celeb_a',\n", " 'celeb_a_hq',\n", " 'cfq',\n", " 'chexpert',\n", " 'cifar10',\n", " 'cifar100',\n", " 'cifar10_1',\n", " 'cifar10_corrupted',\n", " 'citrus_leaves',\n", " 'cityscapes',\n", " 'civil_comments',\n", " 'clevr',\n", " 'cmaterdb',\n", " 'cnn_dailymail',\n", " 'coco',\n", " 'coil100',\n", " 'colorectal_histology',\n", " 'colorectal_histology_large',\n", " 'cos_e',\n", " 'curated_breast_imaging_ddsm',\n", " 'cycle_gan',\n", " 'deep_weeds',\n", " 'definite_pronoun_resolution',\n", " 'diabetic_retinopathy_detection',\n", " 'div2k',\n", " 'dmlab',\n", " 'downsampled_imagenet',\n", " 'dsprites',\n", " 'dtd',\n", " 'duke_ultrasound',\n", " 'dummy_dataset_shared_generator',\n", " 'dummy_mnist',\n", " 'emnist',\n", " 'eraser_multi_rc',\n", " 'esnli',\n", " 'eurosat',\n", " 'fashion_mnist',\n", " 'flic',\n", " 'flores',\n", " 'food101',\n", " 'gap',\n", " 'gigaword',\n", " 'glue',\n", " 'groove',\n", " 'higgs',\n", " 'horses_or_humans',\n", " 'i_naturalist2017',\n", " 'image_label_folder',\n", " 'imagenet2012',\n", " 'imagenet2012_corrupted',\n", " 'imagenet_resized',\n", " 'imagenette',\n", " 'imagewang',\n", " 'imdb_reviews',\n", " 'iris',\n", " 'kitti',\n", " 'kmnist',\n", " 'lfw',\n", " 'librispeech',\n", " 'librispeech_lm',\n", " 'libritts',\n", " 'lm1b',\n", " 'lost_and_found',\n", " 'lsun',\n", " 'malaria',\n", " 'math_dataset',\n", " 'mnist',\n", " 'mnist_corrupted',\n", " 'movie_rationales',\n", " 'moving_mnist',\n", " 'multi_news',\n", " 'multi_nli',\n", " 'multi_nli_mismatch',\n", " 'natural_questions',\n", " 'newsroom',\n", " 'nsynth',\n", " 'omniglot',\n", " 'open_images_v4',\n", " 'opinosis',\n", " 'oxford_flowers102',\n", " 'oxford_iiit_pet',\n", " 'para_crawl',\n", " 'patch_camelyon',\n", " 'pet_finder',\n", " 'places365_small',\n", " 'plant_leaves',\n", " 'plant_village',\n", " 'plantae_k',\n", " 'qa4mre',\n", " 'quickdraw_bitmap',\n", " 'reddit_tifu',\n", " 'resisc45',\n", " 'rock_paper_scissors',\n", " 'rock_you',\n", " 'scan',\n", " 'scene_parse150',\n", " 'scicite',\n", " 'scientific_papers',\n", " 'shapes3d',\n", " 'smallnorb',\n", " 'snli',\n", " 'so2sat',\n", " 'speech_commands',\n", " 'squad',\n", " 'stanford_dogs',\n", " 'stanford_online_products',\n", " 'starcraft_video',\n", " 'sun397',\n", " 'super_glue',\n", " 'svhn_cropped',\n", " 'ted_hrlr_translate',\n", " 'ted_multi_translate',\n", " 'tf_flowers',\n", " 'the300w_lp',\n", " 'tiny_shakespeare',\n", " 'titanic',\n", " 'trivia_qa',\n", " 'uc_merced',\n", " 'ucf101',\n", " 'vgg_face2',\n", " 'visual_domain_decathlon',\n", " 'voc',\n", " 'wider_face',\n", " 'wikihow',\n", " 'wikipedia',\n", " 'wmt14_translate',\n", " 'wmt15_translate',\n", " 'wmt16_translate',\n", " 'wmt17_translate',\n", " 'wmt18_translate',\n", " 'wmt19_translate',\n", " 'wmt_t2t_translate',\n", " 'wmt_translate',\n", " 'xnli',\n", " 'xsum',\n", " 'yelp_polarity_reviews']" ] }, 
"execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import tensorflow_datasets as tfds\n", "tfds.list_builders()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "iris = tfds.load(name=\"iris\", split=None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }