{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MNIST Visualizations\n", "\n", "In this notebook, we continue with the MNIST analysis after our [initial exploration](MNIST.ipynb).\n", "\n", "Let's again begin by reading in the MNIST dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n", "ConX, version 3.6.5\n" ] } ], "source": [ "import conx as cx" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset**: MNIST\n", "\n", "\n", "Original source: http://yann.lecun.com/exdb/mnist/\n", "\n", "The MNIST dataset contains 70,000 images of handwritten digits (zero\n", "to nine) that have been size-normalized and centered in a square grid\n", "of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers\n", "representing grayscale intensities ranging from 0 (black) to 1\n", "(white). The target data consists of one-hot binary vectors of size\n", "10, corresponding to the digit classification categories zero through\n", "nine. Some example MNIST images are shown below:\n", "\n", "![MNIST Images](https://github.com/Calysto/conx/raw/master/data/mnist_images.png)\n", "\n", "**Information**:\n", " * name : MNIST\n", " * length : 70000\n", "\n", "**Input Summary**:\n", " * shape : (28, 28, 1)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (10,)\n", " * range : (0.0, 1.0)\n", "\n" ], "text/plain": [ "**Dataset**: MNIST\n", "\n", "\n", "Original source: http://yann.lecun.com/exdb/mnist/\n", "\n", "The MNIST dataset contains 70,000 images of handwritten digits (zero\n", "to nine) that have been size-normalized and centered in a square grid\n", "of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers\n", "representing grayscale intensities ranging from 0 (black) to 1\n", "(white). 
The target data consists of one-hot binary vectors of size\n", "10, corresponding to the digit classification categories zero through\n", "nine. Some example MNIST images are shown below:\n", "\n", "![MNIST Images](https://github.com/Calysto/conx/raw/master/data/mnist_images.png)\n", "\n", "**Information**:\n", " * name : MNIST\n", " * length : 70000\n", "\n", "**Input Summary**:\n", " * shape : (28, 28, 1)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (10,)\n", " * range : (0.0, 1.0)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "mnist = cx.Dataset.get('mnist')\n", "mnist.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A Convolutional Network for MNIST Classification\n", "\n", "\n", "Again, we build a CNN." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "cnn = cx.Network(\"MNIST_CNN_Visualize\")\n", "\n", "cnn.add(cx.Layer(\"input\", (28,28,1), colormap=\"gray\"),\n", " cx.Conv2DLayer(\"conv2D_1\", 16, (5,5), activation=\"relu\", dropout=0.20),\n", " cx.MaxPool2DLayer(\"maxpool1\", (2,2)),\n", " cx.Conv2DLayer(\"conv2D_2\", 32, (5,5), activation=\"relu\", dropout=0.20),\n", " cx.MaxPool2DLayer(\"maxpool2\", (2,2)),\n", " cx.FlattenLayer(\"flat\"),\n", " cx.Layer(\"hidden\", 30, activation='relu'),\n", " cx.Layer(\"output\", 10, activation='softmax'))\n", "\n", "cnn.connect()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "MNIST:\n", "Patterns Shape Range \n", "=================================================================\n", "inputs (28, 28, 1) (0.0, 1.0) \n", "targets (10,) (0.0, 1.0) \n", "=================================================================\n", "Total patterns: 70000\n", " Training patterns: 60000\n", " Testing patterns: 10000\n", 
"_________________________________________________________________\n" ] } ], "source": [ "cnn.get_dataset(\"MNIST\")\n", "cnn.dataset.split(10000)\n", "cnn.dataset.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, we will use the RMSprop optimizer, which scales each weight's update by the inverse square root of a running average of that weight's recent squared gradients, so the effective learning rate adapts as training proceeds." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "cnn.compile(error='categorical_crossentropy', optimizer='RMSprop')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "input (InputLayer) (None, 28, 28, 1) 0 \n", "_________________________________________________________________\n", "conv2D_1 (Conv2D) (None, 24, 24, 16) 416 \n", "_________________________________________________________________\n", "dropout_1 (Dropout) (None, 24, 24, 16) 0 \n", "_________________________________________________________________\n", "maxpool1 (MaxPooling2D) (None, 12, 12, 16) 0 \n", "_________________________________________________________________\n", "conv2D_2 (Conv2D) (None, 8, 8, 32) 12832 \n", "_________________________________________________________________\n", "dropout_2 (Dropout) (None, 8, 8, 32) 0 \n", "_________________________________________________________________\n", "maxpool2 (MaxPooling2D) (None, 4, 4, 32) 0 \n", "_________________________________________________________________\n", "flat (Flatten) (None, 512) 0 \n", "_________________________________________________________________\n", "hidden (Dense) (None, 30) 15390 \n", 
"_________________________________________________________________\n", "output (Dense) (None, 10) 310 \n", "=================================================================\n", "Total 
params: 28,948\n", "Trainable params: 28,948\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "cnn.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3ed204ec31fb4ef8ae7a5e3f8ab0147c", "version_major": 2, "version_minor": 0 }, "text/html": [ "
\n" ], "text/plain": [ "Dashboard(children=(Accordion(children=(HBox(children=(VBox(children=(Select(description='Dataset:', index=1, options=('Test', 'Train'), rows=1, value='Train'), FloatSlider(value=0.5, continuous_update=False, description='Zoom', layout=Layout(width='65%'), max=1.0, style=SliderStyle(description_width='initial')), IntText(value=150, description='Horizontal space between banks:', style=DescriptionStyle(description_width='initial')), IntText(value=30, description='Vertical space between layers:', style=DescriptionStyle(description_width='initial')), HBox(children=(Checkbox(value=True, description='Show Targets', style=DescriptionStyle(description_width='initial')), Checkbox(value=True, description='Errors', style=DescriptionStyle(description_width='initial')))), Select(description='Features:', index=2, options=('', 'input', 'conv2D_1', 'maxpool1', 'conv2D_2', 'maxpool2'), rows=1, value='conv2D_1'), IntText(value=8, description='Feature columns:', style=DescriptionStyle(description_width='initial')), FloatText(value=4.0, description='Feature scale:', style=DescriptionStyle(description_width='initial'))), layout=Layout(width='100%')), VBox(children=(Select(description='Layer:', index=7, options=('input', 'conv2D_1', 'maxpool1', 'conv2D_2', 'maxpool2', 'flat', 'hidden', 'output'), rows=1, value='output'), Checkbox(value=True, description='Visible'), Select(description='Colormap:', options=('', 'Accent', 'Accent_r', 'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', 'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', 'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', 'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', 'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', 'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', 'Reds_r', 
'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', 'Vega10', 'Vega10_r', 'Vega20', 'Vega20_r', 'Vega20b', 'Vega20b_r', 'Vega20c', 'Vega20c_r', 'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', 'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', 'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cool', 'cool_r', 'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'cubehelix', 'cubehelix_r', 'flag', 'flag_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', 'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', 'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', 'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'inferno', 'inferno_r', 'jet', 'jet_r', 'magma', 'magma_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', 'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', 'rainbow_r', 'seismic', 'seismic_r', 'spectral', 'spectral_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', 'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', 'terrain_r', 'viridis', 'viridis_r', 'winter', 'winter_r'), rows=1, value=''), HTML(value=''), FloatText(value=0.0, description='Leftmost color maps to:', style=DescriptionStyle(description_width='initial')), FloatText(value=1.0, description='Rightmost color maps to:', style=DescriptionStyle(description_width='initial')), IntText(value=0, description='Feature to show:', style=DescriptionStyle(description_width='initial')), HBox(children=(Checkbox(value=True, description='Rotate network', layout=Layout(width='52%'), style=DescriptionStyle(description_width='initial')), Button(icon='save', layout=Layout(width='10%'), style=ButtonStyle())))), layout=Layout(width='100%')))),), selected_index=None, _titles={'0': 'MNIST_CNN_Visualize'}), VBox(children=(HBox(children=(IntSlider(value=0, 
continuous_update=False, description='Dataset index', layout=Layout(width='100%'), max=59999), Label(value='of 60000', layout=Layout(width='100px'))), layout=Layout(height='40px')), HBox(children=(Button(icon='fast-backward', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='backward', layout=Layout(width='100%'), style=ButtonStyle()), IntText(value=0, layout=Layout(width='100%')), Button(icon='forward', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='fast-forward', layout=Layout(width='100%'), style=ButtonStyle()), Button(description='Play', icon='play', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='refresh', layout=Layout(width='25%'), style=ButtonStyle())), layout=Layout(height='50px', width='100%'))), layout=Layout(width='100%')), HTML(value='', layout=Layout(justify_content='center', overflow_x='auto', overflow_y='auto', width='95%')), Output()))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", "require(['base/js/namespace'], function(Jupyter) {\n", " Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {\n", " comm.on_msg(function(msg) {\n", " var data = msg[\"content\"][\"data\"];\n", " var images = document.getElementsByClassName(data[\"class\"]);\n", " for (var i = 0; i < images.length; i++) {\n", " if (data[\"href\"]) {\n", " images[i].setAttributeNS(null, \"href\", data[\"href\"]);\n", " }\n", " if (data[\"src\"]) {\n", " images[i].setAttributeNS(null, \"src\", data[\"src\"]);\n", " }\n", " }\n", " });\n", " });\n", "});\n" ], "text/plain": [ "