{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST Visualizations\n",
"\n",
"In this notebook, we continue with the MNIST analysis after our [initial exploration](MNIST.ipynb).\n",
"\n",
"Let's again begin by reading in the MNIST dataset."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"ConX, version 3.6.5\n"
]
}
],
"source": [
"import conx as cx"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"**Dataset**: MNIST\n",
"\n",
"\n",
"Original source: http://yann.lecun.com/exdb/mnist/\n",
"\n",
"The MNIST dataset contains 70,000 images of handwritten digits (zero\n",
"to nine) that have been size-normalized and centered in a square grid\n",
"of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers\n",
"representing grayscale intensities ranging from 0 (black) to 1\n",
"(white). The target data consists of one-hot binary vectors of size\n",
"10, corresponding to the digit classification categories zero through\n",
"nine. Some example MNIST images are shown below:\n",
"\n",
"\n",
"\n",
"**Information**:\n",
" * name : MNIST\n",
" * length : 70000\n",
"\n",
"**Input Summary**:\n",
" * shape : (28, 28, 1)\n",
" * range : (0.0, 1.0)\n",
"\n",
"**Target Summary**:\n",
" * shape : (10,)\n",
" * range : (0.0, 1.0)\n",
"\n"
],
"text/plain": [
"**Dataset**: MNIST\n",
"\n",
"\n",
"Original source: http://yann.lecun.com/exdb/mnist/\n",
"\n",
"The MNIST dataset contains 70,000 images of handwritten digits (zero\n",
"to nine) that have been size-normalized and centered in a square grid\n",
"of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers\n",
"representing grayscale intensities ranging from 0 (black) to 1\n",
"(white). The target data consists of one-hot binary vectors of size\n",
"10, corresponding to the digit classification categories zero through\n",
"nine. Some example MNIST images are shown below:\n",
"\n",
"\n",
"\n",
"**Information**:\n",
" * name : MNIST\n",
" * length : 70000\n",
"\n",
"**Input Summary**:\n",
" * shape : (28, 28, 1)\n",
" * range : (0.0, 1.0)\n",
"\n",
"**Target Summary**:\n",
" * shape : (10,)\n",
" * range : (0.0, 1.0)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"mnist = cx.Dataset.get('mnist')\n",
"mnist.info()"
]
},
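{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each target is a one-hot vector of length 10, so the digit it encodes is the position of its single 1. The sketch below shows this with plain NumPy rather than any ConX call. `one_hot` and `decode` are hypothetical helpers introduced only for illustration. The same `argmax` step turns the network's 10 softmax outputs into a predicted digit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def one_hot(digit, num_classes=10):\n",
"    # Hypothetical helper: zeros everywhere except 1.0 at the digit's index\n",
"    vec = np.zeros(num_classes)\n",
"    vec[digit] = 1.0\n",
"    return vec\n",
"\n",
"def decode(target):\n",
"    # Hypothetical helper: the digit is the index of the largest entry\n",
"    return int(np.argmax(target))\n",
"\n",
"target = one_hot(7)\n",
"print(target)\n",
"print(decode(target))  # 7"
]
},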
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A Convolutional Network for MNIST Classification\n",
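"\n",
"As in the initial exploration, we build a convolutional neural network (CNN). It has two stages of 5 × 5 convolution followed by 2 × 2 max pooling, with 20% dropout applied to each convolution's output. These stages feed a 30-unit hidden layer and a 10-unit softmax output layer."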
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"cnn = cx.Network(\"MNIST_CNN_Visualize\")\n",
"\n",
"cnn.add(cx.Layer(\"input\", (28,28,1), colormap=\"gray\"),\n",
" cx.Conv2DLayer(\"conv2D_1\", 16, (5,5), activation=\"relu\", dropout=0.20),\n",
" cx.MaxPool2DLayer(\"maxpool1\", (2,2)),\n",
" cx.Conv2DLayer(\"conv2D_2\", 32, (5,5), activation=\"relu\", dropout=0.20),\n",
" cx.MaxPool2DLayer(\"maxpool2\", (2,2)),\n",
" cx.FlattenLayer(\"flat\"),\n",
" cx.Layer(\"hidden\", 30, activation='relu'),\n",
" cx.Layer(\"output\", 10, activation='softmax'))\n",
"\n",
"cnn.connect()"
]
},
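{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each 5 × 5 convolution trims 4 pixels from each spatial dimension, and each 2 × 2 max pooling halves it. Here is a minimal sketch of that arithmetic. It assumes stride 1 and no padding, which is consistent with the shapes that `cnn.summary()` reports. `conv_out` and `pool_out` are hypothetical helpers. Dropout is left out because it does not change shapes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def conv_out(size, kernel):\n",
"    # 'valid' convolution, stride 1: the kernel must fit inside the image\n",
"    return size - kernel + 1\n",
"\n",
"def pool_out(size, pool):\n",
"    # Non-overlapping max pooling: the stride equals the pool size\n",
"    return size // pool\n",
"\n",
"size = 28                  # input: 28 x 28 x 1\n",
"size = conv_out(size, 5)   # conv2D_1: 24 x 24 x 16\n",
"size = pool_out(size, 2)   # maxpool1: 12 x 12 x 16\n",
"size = conv_out(size, 5)   # conv2D_2: 8 x 8 x 32\n",
"size = pool_out(size, 2)   # maxpool2: 4 x 4 x 32\n",
"flat = size * size * 32    # flat: one value per position per filter\n",
"print(size, flat)          # 4 512"
]
},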
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"MNIST:\n",
"Patterns Shape Range \n",
"=================================================================\n",
"inputs (28, 28, 1) (0.0, 1.0) \n",
"targets (10,) (0.0, 1.0) \n",
"=================================================================\n",
"Total patterns: 70000\n",
" Training patterns: 60000\n",
" Testing patterns: 10000\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"cnn.get_dataset(\"MNIST\")\n",
"cnn.dataset.split(10000)\n",
"cnn.dataset.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, we use the RMSprop optimizer. It gives each weight its own effective learning rate by dividing that weight's gradient by the square root of a running average of its recent squared gradients."
]
},
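{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal NumPy sketch of the standard RMSprop update. The defaults `lr=0.001`, `rho=0.9`, and `eps=1e-7` are assumptions about what `optimizer='RMSprop'` resolves to in the underlying Keras version. The two-weight quadratic is a made-up toy problem. Each gradient is divided by the root of its own running average, so the two weights take equal-sized first steps even though their gradients differ by a factor of 10,000."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def rmsprop_step(w, grad, cache, lr=0.001, rho=0.9, eps=1e-7):\n",
"    # Exponentially weighted running average of squared gradients, per weight\n",
"    cache = rho * cache + (1.0 - rho) * grad ** 2\n",
"    # Scale each weight's step by the inverse root of its own average\n",
"    w = w - lr * grad / (np.sqrt(cache) + eps)\n",
"    return w, cache\n",
"\n",
"# Toy problem: minimize sum(scale * w**2), whose gradient is 2 * scale * w\n",
"scale = np.array([100.0, 0.01])\n",
"w = np.array([1.0, 1.0])\n",
"cache = np.zeros_like(w)\n",
"for step in range(500):\n",
"    grad = 2.0 * scale * w\n",
"    w, cache = rmsprop_step(w, grad, cache, lr=0.01)\n",
"    if step == 0:\n",
"        print(w)  # both weights moved by about 0.0316\n",
"print(w)  # both weights end up near 0"
]
},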
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"cnn.compile(error='categorical_crossentropy', optimizer='RMSprop')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"input (InputLayer) (None, 28, 28, 1) 0 \n",
"_________________________________________________________________\n",
"conv2D_1 (Conv2D) (None, 24, 24, 16) 416 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 24, 24, 16) 0 \n",
"_________________________________________________________________\n",
"maxpool1 (MaxPooling2D) (None, 12, 12, 16) 0 \n",
"_________________________________________________________________\n",
"conv2D_2 (Conv2D) (None, 8, 8, 32) 12832 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 8, 8, 32) 0 \n",
"_________________________________________________________________\n",
"maxpool2 (MaxPooling2D) (None, 4, 4, 32) 0 \n",
"_________________________________________________________________\n",
"flat (Flatten) (None, 512) 0 \n",
"_________________________________________________________________\n",
"hidden (Dense) (None, 30) 15390 \n",
"_________________________________________________________________\n",
"output (Dense) (None, 10) 310 \n",
"=================================================================\n",
"Total params: 28,948\n",
"Trainable params: 28,948\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"cnn.summary()"
]
},
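{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in this summary follow from the layer sizes. Each convolution filter has one weight per kernel position and input channel, plus one bias. Each dense unit has one weight per input, plus one bias. The dropout, pooling, and flatten layers have no trainable parameters. The cell below sketches this arithmetic using the hypothetical helpers `conv_params` and `dense_params`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def conv_params(kh, kw, in_ch, filters):\n",
"    # One kh x kw x in_ch kernel per filter, plus one bias per filter\n",
"    return kh * kw * in_ch * filters + filters\n",
"\n",
"def dense_params(n_in, n_out):\n",
"    # Full weight matrix, plus one bias per output unit\n",
"    return n_in * n_out + n_out\n",
"\n",
"counts = {\n",
"    'conv2D_1': conv_params(5, 5, 1, 16),   # 416\n",
"    'conv2D_2': conv_params(5, 5, 16, 32),  # 12832\n",
"    'hidden': dense_params(512, 30),        # 15390\n",
"    'output': dense_params(30, 10),         # 310\n",
"}\n",
"print(sum(counts.values()))  # 28948"
]
},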
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ed204ec31fb4ef8ae7a5e3f8ab0147c",
"version_major": 2,
"version_minor": 0
},
"text/html": [
"
Failed to display Jupyter Widget of type Dashboard.
\n", " If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", " that the widgets JavaScript is still loading. If this message persists, it\n", " likely means that the widgets JavaScript library is either not installed or\n", " not enabled. See the Jupyter\n", " Widgets Documentation for setup instructions.\n", "
\n", "\n", " If you're reading this message in another frontend (for example, a static\n", " rendering on GitHub or NBViewer),\n", " it may mean that your frontend doesn't currently support widgets.\n", "
\n" ], "text/plain": [ "Dashboard(children=(Accordion(children=(HBox(children=(VBox(children=(Select(description='Dataset:', index=1, options=('Test', 'Train'), rows=1, value='Train'), FloatSlider(value=0.5, continuous_update=False, description='Zoom', layout=Layout(width='65%'), max=1.0, style=SliderStyle(description_width='initial')), IntText(value=150, description='Horizontal space between banks:', style=DescriptionStyle(description_width='initial')), IntText(value=30, description='Vertical space between layers:', style=DescriptionStyle(description_width='initial')), HBox(children=(Checkbox(value=True, description='Show Targets', style=DescriptionStyle(description_width='initial')), Checkbox(value=True, description='Errors', style=DescriptionStyle(description_width='initial')))), Select(description='Features:', index=2, options=('', 'input', 'conv2D_1', 'maxpool1', 'conv2D_2', 'maxpool2'), rows=1, value='conv2D_1'), IntText(value=8, description='Feature columns:', style=DescriptionStyle(description_width='initial')), FloatText(value=4.0, description='Feature scale:', style=DescriptionStyle(description_width='initial'))), layout=Layout(width='100%')), VBox(children=(Select(description='Layer:', index=7, options=('input', 'conv2D_1', 'maxpool1', 'conv2D_2', 'maxpool2', 'flat', 'hidden', 'output'), rows=1, value='output'), Checkbox(value=True, description='Visible'), Select(description='Colormap:', options=('', 'Accent', 'Accent_r', 'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', 'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', 'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', 'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', 'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', 'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', 'Reds_r', 
'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', 'Vega10', 'Vega10_r', 'Vega20', 'Vega20_r', 'Vega20b', 'Vega20b_r', 'Vega20c', 'Vega20c_r', 'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', 'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', 'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cool', 'cool_r', 'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'cubehelix', 'cubehelix_r', 'flag', 'flag_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', 'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', 'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', 'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'inferno', 'inferno_r', 'jet', 'jet_r', 'magma', 'magma_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', 'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', 'rainbow_r', 'seismic', 'seismic_r', 'spectral', 'spectral_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', 'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', 'terrain_r', 'viridis', 'viridis_r', 'winter', 'winter_r'), rows=1, value=''), HTML(value='