{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import os\n",
    "# Restrict this notebook to GPU 0, with devices ordered by PCI bus ID.\n",
    "# These must be set before TensorFlow is imported (it is pulled in by ktrain in the next cell).\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "# import ktrain and ktrain.vision modules\n",
    "import ktrain\n",
    "from ktrain import vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download a PNG version of the **MNIST** dataset from [here](https://s3.amazonaws.com/fast-ai-imageclas/mnist_png.tgz) and set DATADIR to the extracted folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "color_mode detected (grayscale) different than color_mode selected (rgb)\n",
      "Found 60000 images belonging to 10 classes.\n",
      "Found 60000 images belonging to 10 classes.\n",
      "Found 10000 images belonging to 10 classes.\n"
     ]
    }
   ],
   "source": [
    "# load the data with some modest data augmentation\n",
    "# We load as RGB even though we have grayscale images\n",
    "# since some models only support RGB images.\n",
    "# DATADIR is the single place to configure the dataset location;\n",
    "# every later cell that touches image files derives its path from it.\n",
    "DATADIR = 'data/mnist_png'\n",
    "data_aug = vision.get_data_aug(\n",
    "    featurewise_center=True,\n",
    "    featurewise_std_normalization=True,\n",
    "    rotation_range=15,\n",
    "    zoom_range=0.1,\n",
    "    width_shift_range=0.1,\n",
    "    height_shift_range=0.1)\n",
    "(train_data, val_data, preproc) = vision.images_from_folder(\n",
    "    datadir=DATADIR,\n",
    "    data_aug=data_aug,\n",
    "    train_test_names=['training', 'testing'],\n",
    "    target_size=(32,32), color_mode='rgb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is Multi-Label? False\n",
      "wrn22 model created.\n"
     ]
    }
   ],
   "source": [
    "# get a pre-canned 22-layer Wide Residual Network model\n",
    "model = vision.image_classifier('wrn22', train_data, val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get a Learner object\n",
    "learner = ktrain.get_learner(\n",
    "    model=model, train_data=train_data, val_data=val_data,\n",
    "    workers=8, use_multiprocessing=True, batch_size=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "simulating training for different learning rates... this may take a few moments...\n",
      "Epoch 1/5\n",
      "937/937 [==============================] - 65s 69ms/step - loss: 6.9835 - acc: 0.1659\n",
      "Epoch 2/5\n",
      "937/937 [==============================] - 59s 63ms/step - loss: 5.1547 - acc: 0.6753\n",
      "Epoch 3/5\n",
      "937/937 [==============================] - 59s 63ms/step - loss: 0.9958 - acc: 0.9507\n",
      "Epoch 4/5\n",
      "756/937 [=======================>......] - ETA: 11s - loss: 1.4162 - acc: 0.8885\n",
      "\n",
      "done.\n",
      "Please invoke the Learner.lr_plot() method to visually inspect the loss plot to help identify the maximal learning rate associated with falling loss.\n"
     ]
    }
   ],
   "source": [
    "# find a good learning rate\n",
    "learner.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       ""
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "learner.lr_plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "begin training using triangular learning rate policy with max lr of 0.002...\n",
      "Epoch 1/1\n",
      "936/937 [============================>.] - ETA: 0s - loss: 1.0951 - acc: 0.9207\n",
      "937/937 [==============================] - 67s 71ms/step - loss: 1.0942 - acc: 0.9207 - val_loss: 0.2455 - val_acc: 0.9932\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# train WRN-22 model for a single epoch\n",
    "learner.autofit(2e-3, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get a Predictor object that we can use to classify (potentially unlabeled) images\n",
    "predictor = ktrain.get_predictor(learner.model, preproc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# let's see the class labels and their indices\n",
    "predictor.get_classes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['7']"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# let's try to predict an image depicting a 7\n",
    "# (the path is built from DATADIR so the cell works wherever the dataset was extracted)\n",
    "predictor.predict_filename(os.path.join(DATADIR, 'testing', '7', '7021.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[9.99729097e-01, 1.02616505e-05, 4.12802947e-05, 1.04568608e-05,\n",
       " 5.07383811e-06, 3.03435208e-05, 1.10756089e-04, 2.12177038e-05,\n",
       " 1.56492970e-05, 2.57711799e-05]], dtype=float32)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# let's try predicting an image showing a 0 and return probabilities for all classes\n",
    "predictor.predict_filename(os.path.join(DATADIR, 'testing', '0', '101.png'), return_proba=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 1010 images belonging to 1 classes.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[('3/1020.png', '3'),\n",
       " ('3/1028.png', '3'),\n",
       " ('3/1042.png', '3'),\n",
       " ('3/1062.png', '3'),\n",
       " ('3/1066.png', '3'),\n",
       " ('3/1067.png', '3'),\n",
       " ('3/1069.png', '3'),\n",
       " ('3/1072.png', '3'),\n",
       " ('3/1092.png', '3'),\n",
       " ('3/1095.png', '3')]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# let's predict all images showing a 3 in our validation set\n",
    "# (joining with '' yields a folder path that ends in a path separator, i.e. '.../testing/3/')\n",
    "predictor.predict_folder(os.path.join(DATADIR, 'testing', '3', ''))[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's save the predictor for possible later deployment in an application\n",
    "predictor.save('/tmp/mypredictor')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload the predictor from a file\n",
    "predictor = ktrain.load_predictor('/tmp/mypredictor')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['7']"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# let's use the reloaded predictor to verify it still works correctly\n",
    "predictor.predict_filename(os.path.join(DATADIR, 'testing', '7', '7021.png'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}