{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import os\n", "\n", "# Order CUDA devices by PCI bus ID and expose only device 0 to this process.\n", "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Collect MNIST Dataset as Arrays" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.datasets import mnist\n", "import numpy as np\n", "\n", "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", "\n", "# Cast the images to float32 and insert a new axis at position 3.\n", "x_train = np.expand_dims(x_train.astype('float32'), axis=3)\n", "x_test = np.expand_dims(x_test.astype('float32'), axis=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 1: Preprocess Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import ktrain\n", "from ktrain import vision as vis" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Augmentation settings handed to ktrain's data-augmentation helper.\n", "data_aug = vis.get_data_aug(\n", "    rotation_range=15,\n", "    zoom_range=0.1,\n", "    width_shift_range=0.1,\n", "    height_shift_range=0.1,\n", ")\n", "\n", "# Human-readable names for the ten digit labels, listed in digit order.\n", "classes = [\n", "    'zero', 'one', 'two', 'three', 'four',\n", "    'five', 'six', 'seven', 'eight', 'nine',\n", "]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# No explicit validation set is supplied; val_pct and random_state presumably\n", "# control how ktrain splits one off from the training arrays (TODO confirm).\n", "trn, val, preproc = vis.images_from_array(\n", "    x_train,\n", "    y_train,\n", "    validation_data=None,\n", "    val_pct=0.1,\n", "    random_state=42,\n", "    data_aug=data_aug,\n", "    class_names=classes,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 2: Load Model and Wrap in `Learner`" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Is Multi-Label? 
False\n", "default_cnn model created.\n" ] } ], "source": [ "# Build the classifier registered as 'default_cnn' (a LeNet-style network).\n", "model = vis.image_classifier('default_cnn', trn, val)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Wrap the model and both datasets in a ktrain Learner.\n", "learner = ktrain.get_learner(\n", "    model,\n", "    train_data=trn,\n", "    val_data=val,\n", "    batch_size=128,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 3: Find Learning Rate" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "simulating training for different learning rates... this may take a few moments...\n", "Train for 421 steps\n", "Epoch 1/3\n", "421/421 [==============================] - 15s 35ms/step - loss: 2.9301 - accuracy: 0.1410\n", "Epoch 2/3\n", "421/421 [==============================] - 14s 33ms/step - loss: 0.8085 - accuracy: 0.7409\n", "Epoch 3/3\n", "197/421 [=============>................] - ETA: 7s - loss: 0.4978 - accuracy: 0.8852\n", "\n", "done.\n", "Visually inspect loss plot and select learning rate associated with falling loss\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Simulate training over a range of learning rates (at most 3 epochs) and plot the loss.\n", "learner.lr_find(max_epochs=3, show_plot=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## STEP 4: Train Model\n", "\n", "We only train for three epochs for demonstration purposes." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "begin training using onecycle policy with max lr of 0.001...\n", "Train for 422 steps, validate for 188 steps\n", "Epoch 1/3\n", "422/422 [==============================] - 15s 35ms/step - loss: 0.8279 - accuracy: 0.7384 - val_loss: 0.0681 - val_accuracy: 0.9802\n", "Epoch 2/3\n", "422/422 [==============================] - 15s 35ms/step - loss: 0.1559 - accuracy: 0.9532 - val_loss: 0.0654 - val_accuracy: 0.9798\n", "Epoch 3/3\n", "422/422 [==============================] - 14s 34ms/step - loss: 0.0990 - accuracy: 0.9702 - val_loss: 0.0316 - val_accuracy: 0.9905\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# One-cycle policy with a maximum learning rate of 1e-3 for three epochs.\n", "# NOTE(review): 1e-3 was presumably read off the lr_find plot above - confirm.\n", "learner.fit_onecycle(lr=1e-3, epochs=3)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " zero 1.00 0.99 0.99 624\n", " one 1.00 1.00 1.00 654\n", " two 0.99 0.99 0.99 572\n", " three 0.99 0.99 0.99 589\n", " four 0.99 0.99 0.99 580\n", " five 1.00 0.99 0.99 551\n", " six 0.99 0.99 0.99 580\n", " seven 0.99 0.99 0.99 633\n", " eight 0.98 0.98 0.98 585\n", " nine 0.99 0.99 0.99 632\n", "\n", " accuracy 0.99 6000\n", " macro avg 0.99 0.99 0.99 6000\n", "weighted avg 0.99 0.99 0.99 6000\n", "\n" ] }, { "data": { "text/plain": [ "array([[618, 0, 0, 0, 0, 0, 2, 0, 4, 0],\n", " [ 0, 652, 0, 1, 0, 0, 0, 0, 0, 1],\n", " [ 0, 1, 566, 0, 0, 0, 0, 1, 4, 0],\n", " [ 0, 0, 2, 586, 0, 0, 0, 1, 0, 0],\n", " [ 0, 0, 0, 0, 574, 0, 
0, 1, 0, 5],\n", " [ 1, 0, 0, 2, 0, 545, 1, 0, 1, 1],\n", " [ 0, 0, 1, 1, 1, 0, 576, 0, 1, 0],\n", " [ 0, 1, 3, 0, 1, 0, 0, 626, 2, 0],\n", " [ 1, 0, 2, 0, 3, 1, 0, 1, 576, 1],\n", " [ 0, 0, 0, 0, 1, 1, 0, 5, 1, 624]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Print a per-class report for the validation data, labelled with the\n", "# preprocessor's class names; the returned array is shown as the cell result.\n", "learner.validate(\n", "    class_names=preproc.get_classes(),\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Display the single validation example with the highest loss.\n", "learner.view_top_losses(n=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make Predictions" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Bundle the trained model with its preprocessor so raw arrays can be passed directly.\n", "predictor = ktrain.get_predictor(learner.model, preproc)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'seven'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Predicted class name for the first test image.\n", "predictor.predict(x_test[0:1])[0]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Same image, but take the argmax over the returned class probabilities.\n", "np.argmax(predictor.predict(x_test[0:1], return_proba=True)[0])" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "import tempfile\n", "\n", "# Build the save location once, under the platform's temp directory, instead of\n", "# repeating a hard-coded '/tmp' path that does not exist on every OS.\n", "predictor_path = os.path.join(tempfile.gettempdir(), 'my_mnist')\n", "predictor.save(predictor_path)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Reload from the same location to check that the saved predictor round-trips.\n", "reloaded_predictor = ktrain.load_predictor(predictor_path)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'seven'" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reloaded_predictor.predict(x_test[0:1])[0]" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Class-name predictions for every test image.\n", "predictions = reloaded_predictor.predict(x_test)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
PredictedActual
0seven7
1two2
2one1
3zero0
4four4
\n", "
" ], "text/plain": [ " Predicted Actual\n", "0 seven 7\n", "1 two 2\n", "2 one 1\n", "3 zero 0\n", "4 four 4" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "# Pair each prediction with its ground-truth label, one row per test image.\n", "# NOTE: 'Predicted' holds class names while 'Actual' holds the raw integer\n", "# labels, so the two columns cannot be compared with == as they stand.\n", "prediction_rows = list(zip(predictions, y_test))\n", "results_df = pd.DataFrame(prediction_rows, columns=['Predicted', 'Actual'])\n", "results_df.head()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }