{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Auto-encoder with Conv2D\n",
"\n",
"This notebook demonstrates using a Conv2D network in an autoencoding task with the MNIST dataset."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"ConX, version 3.7.5\n"
]
}
],
"source": [
"import conx as cx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we make a network. We will work with a 3 dimensional input from MNIST, but a flat target vector."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"net = cx.Network(\"Auto-Encoding with Conv\")\n",
"net.add(cx.Layer(\"input\", (28,28,1)),\n",
" cx.Conv2DLayer(\"Conv2D-1\", 16, (5,5), colormap=\"gray\", activation=\"relu\"),\n",
" cx.MaxPool2DLayer(\"maxpool1\", (2,2)),\n",
" cx.Conv2DLayer(\"Conv2D-2\", 132, (5,5), activation=\"relu\"),\n",
" cx.MaxPool2DLayer(\"maxpool2\", (2,2)),\n",
" cx.FlattenLayer(\"flatten\"))\n",
"net.add(cx.Layer(\"output\", 28 * 28, vshape=(28,28), activation='sigmoid'))\n",
"net.connect()"
]
},
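{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we could inspect the layer-by-layer output shapes and parameter counts. This is a sketch assuming ConX exposes the underlying Keras model summary as `net.summary()`; this cell was not part of the original run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumed API): print each layer's output shape and parameter count\n",
"net.summary()"
]
},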
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"net.compile(error=\"mse\", optimizer=\"adam\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We load the MNIST dataset and examine the shapes of the inputs and targets."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"net.get_dataset(\"mnist\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"**Dataset**: MNIST\n",
"\n",
"\n",
"Original source: http://yann.lecun.com/exdb/mnist/\n",
"\n",
"The MNIST dataset contains 70,000 images of handwritten digits (zero\n",
"to nine) that have been size-normalized and centered in a square grid\n",
"of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers\n",
"representing grayscale intensities ranging from 0 (black) to 1\n",
"(white). The target data consists of one-hot binary vectors of size\n",
"10, corresponding to the digit classification categories zero through\n",
"nine. Some example MNIST images are shown below:\n",
"\n",
"![MNIST Images](https://github.com/Calysto/conx/raw/master/data/mnist_images.png)\n",
"\n",
"**Information**:\n",
" * name : MNIST\n",
" * length : 70000\n",
"\n",
"**Input Summary**:\n",
" * shape : (28, 28, 1)\n",
" * range : (0.0, 1.0)\n",
"\n",
"**Target Summary**:\n",
" * shape : (10,)\n",
" * range : (0.0, 1.0)\n",
"\n"
],
"text/plain": [
"**Dataset**: MNIST\n",
"\n",
"\n",
"Original source: http://yann.lecun.com/exdb/mnist/\n",
"\n",
"The MNIST dataset contains 70,000 images of handwritten digits (zero\n",
"to nine) that have been size-normalized and centered in a square grid\n",
"of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers\n",
"representing grayscale intensities ranging from 0 (black) to 1\n",
"(white). The target data consists of one-hot binary vectors of size\n",
"10, corresponding to the digit classification categories zero through\n",
"nine. Some example MNIST images are shown below:\n",
"\n",
"![MNIST Images](https://github.com/Calysto/conx/raw/master/data/mnist_images.png)\n",
"\n",
"**Information**:\n",
" * name : MNIST\n",
" * length : 70000\n",
"\n",
"**Input Summary**:\n",
" * shape : (28, 28, 1)\n",
" * range : (0.0, 1.0)\n",
"\n",
"**Target Summary**:\n",
" * shape : (10,)\n",
" * range : (0.0, 1.0)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"net.dataset.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because this is an auto-encoding task, we wish that the targets were the same as the inputs:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: network 'Auto-Encoding with Conv' target bank #0 has a multi-dimensional shape; is this correct?\n"
]
}
],
"source": [
"net.dataset.set_targets_from_inputs()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, that gives a warning. ConX does not allow targets to have a shape, so we flatten them:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"net.dataset.targets.reshape(28 * 28)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(784,)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"net.dataset.targets.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "35f9df111fa44e8ca24a2202d44df588",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Dashboard(children=(Accordion(children=(HBox(children=(VBox(children=(Select(description='Dataset:', index=1, …"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"net.dashboard()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just to test our design, we chop the majority of patterns, leaving only 100."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"net.dataset.chop(69900)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And save 10% for testing/validation:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"net.dataset.split(0.1)"
]
},
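{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the small dataset split into training and validation sets, a short training run could be sketched as follows. The argument names (`epochs`, `report_rate`) are assumed from the ConX `train` API, and this cell was not executed here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumed API): train briefly on the 90 training patterns,\n",
"# reporting error every 5 epochs\n",
"net.train(epochs=20, report_rate=5)"
]
},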
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
""
],
"text/plain": [
"Network display: