{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Datasets\n", "\n", "A dataset is a list of (input, target) pairs that can be further split into training and testing lists.\n", "\n", "Let's make an example network to use as demonstration. This network will compute whether the number of 1's in a set of 5 bits is odd." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "input (InputLayer) (None, 5) 0 \n", "_________________________________________________________________\n", "hidden (Dense) (None, 10) 60 \n", "_________________________________________________________________\n", "output (Dense) (None, 1) 11 \n", "=================================================================\n", "Total params: 71\n", "Trainable params: 71\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "ConX, version 3.7.5\n" ] } ], "source": [ "import conx as cx\n", "\n", "net = cx.Network(\"Odd Network\")\n", "net.add(cx.Layer(\"input\", 5))\n", "net.add(cx.Layer(\"hidden\", 10, activation=\"relu\"))\n", "net.add(cx.Layer(\"output\", 1, activation=\"sigmoid\"))\n", "net.connect()\n", "net.compile(error=\"mse\", optimizer=\"adam\")\n", "net.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## As a list of (input, target) pairs\n", "\n", "The most straightforward method of adding input, target vectors to train on is to use a list of (input, target) pairs. 
First we define a function that takes a number and returns its binary representation as a list of bits:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def num2bin(i, bits=5):\n", " \"\"\"\n", " Take a number and turn it into a list of bits (most significant first).\n", " \"\"\"\n", " return [int(s) for s in ((\"0\" * bits) + bin(i)[2:])[-bits:]]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1, 0, 1, 1, 1]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num2bin(23)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we make a list of (input, target) pairs:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "patterns = []\n", "\n", "for i in range(2 ** 5):\n", " inputs = num2bin(i)\n", " targets = [int(sum(inputs) % 2 == 1.0)]\n", " patterns.append((inputs, targets))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pair 5 looks like:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([0, 0, 1, 0, 1], [0])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "patterns[5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the full list of patterns, which we then load as the network's dataset:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[([0, 0, 0, 0, 0], [0]),\n", " ([0, 0, 0, 0, 1], [1]),\n", " ([0, 0, 0, 1, 0], [1]),\n", " ([0, 0, 0, 1, 1], [0]),\n", " ([0, 0, 1, 0, 0], [1]),\n", " ([0, 0, 1, 0, 1], [0]),\n", " ([0, 0, 1, 1, 0], [0]),\n", " ([0, 0, 1, 1, 1], [1]),\n", " ([0, 1, 0, 0, 0], [1]),\n", " ([0, 1, 0, 0, 1], [0]),\n", " ([0, 1, 0, 1, 0], [0]),\n", " ([0, 1, 0, 1, 1], [1]),\n", " ([0, 1, 1, 0, 0], [0]),\n", " ([0, 1, 1, 0, 1], [1]),\n", " ([0, 1, 1, 1, 0], [1]),\n", " ([0, 1, 1, 1, 1], [0]),\n", " ([1, 
0, 0, 0, 0], [1]),\n", " ([1, 0, 0, 0, 1], [0]),\n", " ([1, 0, 0, 1, 0], [0]),\n", " ([1, 0, 0, 1, 1], [1]),\n", " ([1, 0, 1, 0, 0], [0]),\n", " ([1, 0, 1, 0, 1], [1]),\n", " ([1, 0, 1, 1, 0], [1]),\n", " ([1, 0, 1, 1, 1], [0]),\n", " ([1, 1, 0, 0, 0], [0]),\n", " ([1, 1, 0, 0, 1], [1]),\n", " ([1, 1, 0, 1, 0], [1]),\n", " ([1, 1, 0, 1, 1], [0]),\n", " ([1, 1, 1, 0, 0], [1]),\n", " ([1, 1, 1, 0, 1], [0]),\n", " ([1, 1, 1, 1, 0], [0]),\n", " ([1, 1, 1, 1, 1], [1])]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "patterns" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "net.dataset.load(patterns)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset**: Dataset for Odd Network\n", "\n", "**Information**:\n", " * name : None\n", " * length : 32\n", "\n", "**Input Summary**:\n", " * shape : (5,)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (1,)\n", " * range : (0.0, 1.0)\n", "\n" ], "text/plain": [ "**Dataset**: Dataset for Odd Network\n", "\n", "**Information**:\n", " * name : None\n", " * length : 32\n", "\n", "**Input Summary**:\n", " * shape : (5,)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (1,)\n", " * range : (0.0, 1.0)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net.dataset.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset.append()\n", "\n", "You can use the default `dataset` and append one pattern at a time. Consider the task of training a network to determine whether the number of 1s in the input is even (0) or odd (1). 
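In plain Python, that labeling rule is just a parity check (a sketch; the helper name `parity_target` is our own, not part of ConX):

```python
def parity_target(bits):
    """Return [1] if bits contains an odd number of 1s, else [0]."""
    return [sum(bits) % 2]

# Two 1s -> even -> [0]; one 1 -> odd -> [1].
assert parity_target([0, 0, 0, 1, 1]) == [0]
assert parity_target([0, 0, 1, 0, 0]) == [1]
```

This is the same target computation used in the cells below. 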
We could add inputs one at a time:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "net.dataset.clear()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "net.dataset.append([0, 0, 0, 0, 1], [1])\n", "net.dataset.append([0, 0, 0, 1, 1], [0])\n", "net.dataset.append([0, 0, 1, 0, 0], [1])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "net.dataset.clear()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "for i in range(2 ** 5):\n", " inputs = num2bin(i)\n", " targets = [int(sum(inputs) % 2 == 1.0)]\n", " net.dataset.append(inputs, targets)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset**: Dataset for Odd Network\n", "\n", "**Information**:\n", " * name : None\n", " * length : 32\n", "\n", "**Input Summary**:\n", " * shape : (5,)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (1,)\n", " * range : (0.0, 1.0)\n", "\n" ], "text/plain": [ "**Dataset**: Dataset for Odd Network\n", "\n", "**Information**:\n", " * name : None\n", " * length : 32\n", "\n", "**Input Summary**:\n", " * shape : (5,)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (1,)\n", " * range : (0.0, 1.0)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net.dataset.info()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.0, 1.0, 1.0, 0.0, 1.0]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.dataset.inputs[13]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1.0]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.dataset.targets[13]" ] }, { "cell_type": 
"code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "net.reset()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "========================================================\n", " | Training | Training \n", "Epochs | Error | Accuracy \n", "------ | --------- | --------- \n", "# 3982 | 0.02909 | 0.75000 \n" ] } ], "source": [ "net.train(epochs=5000, accuracy=.75, tolerance=.2, report_rate=100, plot=True)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "========================================================\n", "Testing validation dataset with tolerance 0.2...\n", "# | inputs | targets | outputs | result\n", "---------------------------------------\n", "0 | [[0.00, 0.00, 0.00, 0.00, 0.00]] | [[0.00]] | [0.01] | correct\n", "1 | [[0.00, 0.00, 0.00, 0.00, 1.00]] | [[1.00]] | [0.83] | correct\n", "2 | [[0.00, 0.00, 0.00, 1.00, 0.00]] | [[1.00]] | [0.83] | correct\n", "3 | [[0.00, 0.00, 0.00, 1.00, 1.00]] | [[0.00]] | [0.04] | correct\n", "4 | [[0.00, 0.00, 1.00, 0.00, 0.00]] | [[1.00]] | [0.82] | correct\n", "5 | [[0.00, 0.00, 1.00, 0.00, 1.00]] | [[0.00]] | [0.09] | correct\n", "6 | [[0.00, 0.00, 1.00, 1.00, 0.00]] | [[0.00]] | [0.27] | X\n", "7 | [[0.00, 0.00, 1.00, 1.00, 1.00]] | [[1.00]] | [0.96] | correct\n", "8 | [[0.00, 1.00, 0.00, 0.00, 0.00]] | [[1.00]] | [0.93] | correct\n", "9 | [[0.00, 1.00, 0.00, 0.00, 1.00]] | [[0.00]] | [0.33] | X\n", "10 | [[0.00, 1.00, 0.00, 1.00, 0.00]] | [[0.00]] | [0.01] | correct\n", "11 
| [[0.00, 1.00, 0.00, 1.00, 1.00]] | [[1.00]] | [0.84] | correct\n", "12 | [[0.00, 1.00, 1.00, 0.00, 0.00]] | [[0.00]] | [0.04] | correct\n", "13 | [[0.00, 1.00, 1.00, 0.00, 1.00]] | [[1.00]] | [0.80] | X\n", "14 | [[0.00, 1.00, 1.00, 1.00, 0.00]] | [[1.00]] | [0.91] | correct\n", "15 | [[0.00, 1.00, 1.00, 1.00, 1.00]] | [[0.00]] | [0.06] | correct\n", "16 | [[1.00, 0.00, 0.00, 0.00, 0.00]] | [[1.00]] | [0.93] | correct\n", "17 | [[1.00, 0.00, 0.00, 0.00, 1.00]] | [[0.00]] | [0.21] | X\n", "18 | [[1.00, 0.00, 0.00, 1.00, 0.00]] | [[0.00]] | [0.17] | correct\n", "19 | [[1.00, 0.00, 0.00, 1.00, 1.00]] | [[1.00]] | [0.83] | correct\n", "20 | [[1.00, 0.00, 1.00, 0.00, 0.00]] | [[0.00]] | [0.22] | X\n", "21 | [[1.00, 0.00, 1.00, 0.00, 1.00]] | [[1.00]] | [0.93] | correct\n", "22 | [[1.00, 0.00, 1.00, 1.00, 0.00]] | [[1.00]] | [0.71] | X\n", "23 | [[1.00, 0.00, 1.00, 1.00, 1.00]] | [[0.00]] | [0.06] | correct\n", "24 | [[1.00, 1.00, 0.00, 0.00, 0.00]] | [[0.00]] | [0.04] | correct\n", "25 | [[1.00, 1.00, 0.00, 0.00, 1.00]] | [[1.00]] | [0.62] | X\n", "26 | [[1.00, 1.00, 0.00, 1.00, 0.00]] | [[1.00]] | [0.95] | correct\n", "27 | [[1.00, 1.00, 0.00, 1.00, 1.00]] | [[0.00]] | [0.25] | X\n", "28 | [[1.00, 1.00, 1.00, 0.00, 0.00]] | [[1.00]] | [0.85] | correct\n", "29 | [[1.00, 1.00, 1.00, 0.00, 1.00]] | [[0.00]] | [0.20] | X\n", "30 | [[1.00, 1.00, 1.00, 1.00, 0.00]] | [[0.00]] | [0.16] | correct\n", "31 | [[1.00, 1.00, 1.00, 1.00, 1.00]] | [[1.00]] | [0.90] | correct\n", "Total count: 32\n", " correct: 23\n", " incorrect: 9\n", "Total percentage correct: 0.71875\n" ] } ], "source": [ "net.evaluate(tolerance=.2, show=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset inputs and targets\n", "\n", "Inputs and targets in the dataset are represented in the same format as given (as lists, or lists of lists). These formats are automatically converted into an internal format." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "ds = net.dataset" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[1.0, 0.0, 0.0, 0.0, 1.0]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.inputs[17]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see/access the internal format, use the underscore before inputs or targets. This is a numpy array. ConX is designed so that you need not have to use numpy for most network operations." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1., 0., 0., 0., 1.], dtype=float32)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds._inputs[0][17]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Built-in datasets" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['cifar10',\n", " 'cifar100',\n", " 'cmu_faces_full_size',\n", " 'cmu_faces_half_size',\n", " 'cmu_faces_quarter_size',\n", " 'colors',\n", " 'figure_ground_a',\n", " 'fingers',\n", " 'gridfonts',\n", " 'mnist',\n", " 'vmnist']" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cx.Dataset.datasets()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset**: MNIST\n", "\n", "\n", "Original source: http://yann.lecun.com/exdb/mnist/\n", "\n", "The MNIST dataset contains 70,000 images of handwritten digits (zero\n", "to nine) that have been size-normalized and centered in a square grid\n", "of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers\n", "representing grayscale intensities ranging from 0 (black) to 1\n", "(white). 
The target data consists of one-hot binary vectors of size\n", "10, corresponding to the digit classification categories zero through\n", "nine. Some example MNIST images are shown below:\n", "\n", "![MNIST Images](https://github.com/Calysto/conx/raw/master/data/mnist_images.png)\n", "\n", "**Information**:\n", " * name : MNIST\n", " * length : 70000\n", "\n", "**Input Summary**:\n", " * shape : (28, 28, 1)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (10,)\n", " * range : (0.0, 1.0)\n", "\n" ], "text/plain": [ "**Dataset**: MNIST\n", "\n", "\n", "Original source: http://yann.lecun.com/exdb/mnist/\n", "\n", "The MNIST dataset contains 70,000 images of handwritten digits (zero\n", "to nine) that have been size-normalized and centered in a square grid\n", "of pixels. Each image is a 28 × 28 × 1 array of floating-point numbers\n", "representing grayscale intensities ranging from 0 (black) to 1\n", "(white). The target data consists of one-hot binary vectors of size\n", "10, corresponding to the digit classification categories zero through\n", "nine. Some example MNIST images are shown below:\n", "\n", "![MNIST Images](https://github.com/Calysto/conx/raw/master/data/mnist_images.png)\n", "\n", "**Information**:\n", " * name : MNIST\n", " * length : 70000\n", "\n", "**Input Summary**:\n", " * shape : (28, 28, 1)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (10,)\n", " * range : (0.0, 1.0)\n" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds = cx.Dataset.get('mnist')\n", "ds" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset**: CIFAR-10\n", "\n", "\n", "Original source: https://www.cs.toronto.edu/~kriz/cifar.html\n", "\n", "The CIFAR-10 dataset consists of 60000 32x32 colour images in 10\n", "classes, with 6000 images per class.\n", "\n", "The classes are completely mutually exclusive. 
There is no overlap\n", "between automobiles and trucks. \"Automobile\" includes sedans, SUVs,\n", "things of that sort. \"Truck\" includes only big trucks. Neither\n", "includes pickup trucks.\n", "\n", "**Information**:\n", " * name : CIFAR-10\n", " * length : 60000\n", "\n", "**Input Summary**:\n", " * shape : (32, 32, 3)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (10,)\n", " * range : (0.0, 1.0)\n", "\n" ], "text/plain": [ "**Dataset**: CIFAR-10\n", "\n", "\n", "Original source: https://www.cs.toronto.edu/~kriz/cifar.html\n", "\n", "The CIFAR-10 dataset consists of 60000 32x32 colour images in 10\n", "classes, with 6000 images per class.\n", "\n", "The classes are completely mutually exclusive. There is no overlap\n", "between automobiles and trucks. \"Automobile\" includes sedans, SUVs,\n", "things of that sort. \"Truck\" includes only big trucks. Neither\n", "includes pickup trucks.\n", "\n", "**Information**:\n", " * name : CIFAR-10\n", " * length : 60000\n", "\n", "**Input Summary**:\n", " * shape : (32, 32, 3)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (10,)\n", " * range : (0.0, 1.0)\n" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds = cx.Dataset.get('cifar10')\n", "ds" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset**: Gridfonts\n", "\n", "\n", "This dataset originates from Douglas Hofstadter's research\n", "group:\n", "\n", "http://goosie.cogsci.indiana.edu/pub/gridfonts.data\n", "\n", "![Gridfont Grid](https://github.com/Calysto/conx/raw/master/data/grid.png)\n", "\n", "These data have been processed to make them neural\n", "network friendly:\n", "\n", "https://github.com/Calysto/conx/blob/master/data/gridfonts.py\n", "\n", "The dataset is composed of letters on a 25 row x 9 column\n", "grid. 
The inputs and targets are identical, and the labels\n", "contain a string identifying the letter.\n", "\n", "You can read a thesis using part of this dataset here:\n", "https://repository.brynmawr.edu/compsci_pubs/78/\n", "\n", "**Information**:\n", " * name : Gridfonts\n", " * length : 7462\n", "\n", "**Input Summary**:\n", " * shape : (25, 9)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (25, 9)\n", " * range : (0.0, 1.0)\n", "\n" ], "text/plain": [ "**Dataset**: Gridfonts\n", "\n", "\n", "This dataset originates from Douglas Hofstadter's research\n", "group:\n", "\n", "http://goosie.cogsci.indiana.edu/pub/gridfonts.data\n", "\n", "![Gridfont Grid](https://github.com/Calysto/conx/raw/master/data/grid.png)\n", "\n", "These data have been processed to make them neural\n", "network friendly:\n", "\n", "https://github.com/Calysto/conx/blob/master/data/gridfonts.py\n", "\n", "The dataset is composed of letters on a 25 row x 9 column\n", "grid. The inputs and targets are identical, and the labels\n", "contain a string identifying the letter.\n", "\n", "You can read a thesis using part of this dataset here:\n", "https://repository.brynmawr.edu/compsci_pubs/78/\n", "\n", "**Information**:\n", " * name : Gridfonts\n", " * length : 7462\n", "\n", "**Input Summary**:\n", " * shape : (25, 9)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (25, 9)\n", " * range : (0.0, 1.0)\n" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds = cx.Dataset.get(\"gridfonts\")\n", "ds" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset**: CIFAR-100\n", "\n", "\n", "Original source: https://www.cs.toronto.edu/~kriz/cifar.html\n", "\n", "This dataset is just like the CIFAR-10, except it has 100 classes\n", "containing 600 images each. The 100 classes in the CIFAR-100 are grouped\n", "into 20 superclasses. 
Each image comes with a \"fine\" label (the class\n", "to which it belongs) and a \"coarse\" label (the superclass to which it\n", "belongs). Here is the list of classes in the CIFAR-100:\n", "\n", "Superclass | Classes\n", "-------------------------------|-----------------------------------------------------\n", "aquatic mammals\t | beaver, dolphin, otter, seal, whale\n", "fish | aquarium fish, flatfish, ray, shark, trout\n", "flowers\t | orchids, poppies, roses, sunflowers, tulips\n", "food containers | bottles, bowls, cans, cups, plates\n", "fruit and vegetables | apples, mushrooms, oranges, pears, sweet peppers\n", "household electrical devices | clock, computer keyboard, lamp, telephone, television\n", "household furniture | bed, chair, couch, table, wardrobe\n", "insects\t | bee, beetle, butterfly, caterpillar, cockroach\n", "large carnivores | bear, leopard, lion, tiger, wolf\n", "large man-made outdoor things | bridge, castle, house, road, skyscraper\n", "large natural outdoor scenes | cloud, forest, mountain, plain, sea\n", "large omnivores and herbivores | camel, cattle, chimpanzee, elephant, kangaroo\n", "medium-sized mammals | fox, porcupine, possum, raccoon, skunk\n", "non-insect invertebrates | crab, lobster, snail, spider, worm\n", "people\t | baby, boy, girl, man, woman\n", "reptiles | crocodile, dinosaur, lizard, snake, turtle\n", "small mammals | hamster, mouse, rabbit, shrew, squirrel\n", "trees | maple, oak, palm, pine, willow\n", "vehicles 1 | bicycle, bus, motorcycle, pickup truck, train\n", "vehicles 2 | lawn-mower, rocket, streetcar, tank, tractor\n", "\n", "\n", "**Information**:\n", " * name : CIFAR-100\n", " * length : 60000\n", "\n", "**Input Summary**:\n", " * shape : (32, 32, 3)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (100,)\n", " * range : (0.0, 1.0)\n", "\n" ], "text/plain": [ "**Dataset**: CIFAR-100\n", "\n", "\n", "Original source: https://www.cs.toronto.edu/~kriz/cifar.html\n", "\n", "This dataset 
is just like the CIFAR-10, except it has 100 classes\n", "containing 600 images each. The 100 classes in the CIFAR-100 are grouped\n", "into 20 superclasses. Each image comes with a \"fine\" label (the class\n", "to which it belongs) and a \"coarse\" label (the superclass to which it\n", "belongs). Here is the list of classes in the CIFAR-100:\n", "\n", "Superclass | Classes\n", "-------------------------------|-----------------------------------------------------\n", "aquatic mammals\t | beaver, dolphin, otter, seal, whale\n", "fish | aquarium fish, flatfish, ray, shark, trout\n", "flowers\t | orchids, poppies, roses, sunflowers, tulips\n", "food containers | bottles, bowls, cans, cups, plates\n", "fruit and vegetables | apples, mushrooms, oranges, pears, sweet peppers\n", "household electrical devices | clock, computer keyboard, lamp, telephone, television\n", "household furniture | bed, chair, couch, table, wardrobe\n", "insects\t | bee, beetle, butterfly, caterpillar, cockroach\n", "large carnivores | bear, leopard, lion, tiger, wolf\n", "large man-made outdoor things | bridge, castle, house, road, skyscraper\n", "large natural outdoor scenes | cloud, forest, mountain, plain, sea\n", "large omnivores and herbivores | camel, cattle, chimpanzee, elephant, kangaroo\n", "medium-sized mammals | fox, porcupine, possum, raccoon, skunk\n", "non-insect invertebrates | crab, lobster, snail, spider, worm\n", "people\t | baby, boy, girl, man, woman\n", "reptiles | crocodile, dinosaur, lizard, snake, turtle\n", "small mammals | hamster, mouse, rabbit, shrew, squirrel\n", "trees | maple, oak, palm, pine, willow\n", "vehicles 1 | bicycle, bus, motorcycle, pickup truck, train\n", "vehicles 2 | lawn-mower, rocket, streetcar, tank, tractor\n", "\n", "\n", "**Information**:\n", " * name : CIFAR-100\n", " * length : 60000\n", "\n", "**Input Summary**:\n", " * shape : (32, 32, 3)\n", " * range : (0.0, 1.0)\n", "\n", "**Target Summary**:\n", " * shape : (100,)\n", " * range : (0.0, 
1.0)\n" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds = cx.Dataset.get('cifar100')\n", "ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset Methods\n", "\n", "Class methods:\n", "\n", "* Dataset.datasets() - get a list of all known datasets\n", "* Dataset.get(name) - get a named dataset and return a ConX Dataset\n", "* summary() - display a summary of the dataset\n", "\n", "Instance methods:\n", "\n", "**General operations**:\n", "\n", "* datasets() - get a list of all known datasets\n", "* clear() - clear the current dataset of all data\n", "* copy(dataset) - get a copy of the dataset\n", "\n", "**Constructing datasets**:\n", "\n", "* get(name) - get a named dataset; overwrites previous dataset if any\n", "* load_direct(inputs=None, targets=None, labels=None) - loads internal dataset format directly\n", "* load(pairs=None, inputs=None, targets=None, labels=None) - load by fields\n", "* add(inputs, targets) - add a single [inputs] / [targets] pair\n", "* add_random(count, frange=(-1, 1)) - adds count random patterns to dataset; requires a network\n", "* add_by_function(width, frange, ifunction, tfunction) - adds to inputs with ifunction, and to targets with tfunction\n", "\n", "* slice(start=None, stop=None) - select the data between start and stop; clears split\n", "* shuffle() - shuffle the dataset; shuffles entire set; clears split\n", "* split(split=None) - split the dataset into train/test sets. split=0.1 saves 10% for testing. 
The split amount can be a fraction or an integer\n", "* chop(amount) - remove this amount from the end; amount can be a fraction or an integer\n", "\n", "* set_targets_from_inputs(f=None, input_bank=0, target_bank=0) - set the targets from the inputs, optionally mapped through f (e.g., for autoencoding)\n", "* set_inputs_from_targets(f=None, input_bank=0, target_bank=0) - set the inputs from the targets, optionally mapped through f\n", "* set_targets_from_labels(num_classes=None, bank_index=0) - set the targets to one-hot vectors built from the labels\n", "* rescale_inputs(bank_index, old_range, new_range, new_dtype) - rescale an input bank from old_range to new_range, converting to new_dtype\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dataset Examples\n", "\n", "Dataset.split() divides the dataset into training and testing sets. You can give split() an integer (to split at a specific index) or a floating-point value (to split by percentage)." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "ds.split(20)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "ds.split(.5)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: dataset split reset to 0\n" ] } ], "source": [ "ds.slice(10)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "ds.shuffle()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "ds.chop(5)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "CIFAR-100:\n", "Patterns Shape Range \n", "=================================================================\n", "inputs (32, 32, 3) (0.0, 1.0) \n", "targets (100,) (0.0, 1.0) \n", "=================================================================\n", "Total patterns: 5\n", " Training patterns: 5\n", " Testing patterns: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "ds.summary()" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "### Additional operations\n" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "ds.set_targets_from_inputs()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "ds.set_inputs_from_targets()" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(32, 32, 3)]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.inputs.shape" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "ds.inputs.reshape(0, (32 * 32 * 3,))" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(3072,)]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds.inputs.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Vector Operations\n", "\n", "Each dataset has the following virtual fields:\n", "\n", "* inputs - a complete list of all input vectors\n", "* targets - a complete list of all target vectors\n", "* labels - a complete list of all labels (if any)\n", "\n", "* train_inputs - a list of all input vectors for training\n", "* train_targets - a list of all target vectors for training\n", "* train_labels - a list of all labels (if any) for training\n", "\n", "* test_inputs - a list of all input vectors for testing\n", "* test_targets - a list of all target vectors for testing\n", "* test_labels - a list of all labels (if any) for testing\n", "\n", "You may perform standard list-based operations on these virtual arrays, including:\n", "\n", "* len(FIELD) - length\n", "* FIELD[num] - indexing\n", "* FIELD[START:END] - slice\n", "* FIELD[num, num, num, ...] 
- selection by index\n", "\n", "In addition, each field has the following methods:\n", "\n", "* get_shape(bank_index=None) - get the shape of a bank\n", "* filter_indices(function) - get a list of the indices i for which function(FIELD[i]) is true\n", "* filter(function) - get a list of the FIELD[i] values for which function(FIELD[i]) is true\n", "* reshape(bank_index, new_shape=None) - change the shape of a bank" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset direct manipulation\n", "\n", "You can also set the internal data directly, provided it is in the correct format:\n", "\n", "* use a list of columns for multi-bank inputs or targets\n", "* use np.array(vectors) for single-bank inputs or targets" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "inputs = []\n", "targets = []\n", "\n", "for i in range(2 ** 5):\n", " v = num2bin(i)\n", " inputs.append(v)\n", " targets.append([int(sum(v) % 2 == 1.0)])\n", "\n", "net = cx.Network(\"Even-Odd\", 5, 2, 2, 1)\n", "net.compile(error=\"mse\", optimizer=\"adam\")\n", "net.dataset.load_direct([np.array(inputs)], [np.array(targets)])" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "========================================================\n", "Testing validation dataset with tolerance 0.2...\n", "# | inputs | targets | outputs | result\n", "---------------------------------------\n", "0 | [[0.00, 0.00, 0.00, 0.00, 0.00]] | [[0.00]] | [0.00] | correct\n", "1 | [[0.00, 0.00, 0.00, 0.00, 1.00]] | [[1.00]] | [0.42] | X\n", "2 | [[0.00, 0.00, 0.00, 1.00, 0.00]] | [[1.00]] | [-0.31] | X\n", "3 | [[0.00, 0.00, 0.00, 1.00, 1.00]] | [[0.00]] | [0.11] | correct\n", "4 | [[0.00, 0.00, 1.00, 0.00, 0.00]] | [[1.00]] | [-0.07] | X\n", "5 | [[0.00, 0.00, 1.00, 0.00, 1.00]] | [[0.00]] | [0.35] | X\n", "6 | [[0.00, 0.00, 1.00, 1.00, 0.00]] | [[0.00]] | [-0.37] | 
X\n", "7 | [[0.00, 0.00, 1.00, 1.00, 1.00]] | [[1.00]] | [0.05] | X\n", "8 | [[0.00, 1.00, 0.00, 0.00, 0.00]] | [[1.00]] | [0.60] | X\n", "9 | [[0.00, 1.00, 0.00, 0.00, 1.00]] | [[0.00]] | [1.02] | X\n", "10 | [[0.00, 1.00, 0.00, 1.00, 0.00]] | [[0.00]] | [0.30] | X\n", "11 | [[0.00, 1.00, 0.00, 1.00, 1.00]] | [[1.00]] | [0.71] | X\n", "12 | [[0.00, 1.00, 1.00, 0.00, 0.00]] | [[0.00]] | [0.54] | X\n", "13 | [[0.00, 1.00, 1.00, 0.00, 1.00]] | [[1.00]] | [0.95] | correct\n", "14 | [[0.00, 1.00, 1.00, 1.00, 0.00]] | [[1.00]] | [0.23] | X\n", "15 | [[0.00, 1.00, 1.00, 1.00, 1.00]] | [[0.00]] | [0.65] | X\n", "16 | [[1.00, 0.00, 0.00, 0.00, 0.00]] | [[1.00]] | [0.30] | X\n", "17 | [[1.00, 0.00, 0.00, 0.00, 1.00]] | [[0.00]] | [0.72] | X\n", "18 | [[1.00, 0.00, 0.00, 1.00, 0.00]] | [[0.00]] | [-0.01] | correct\n", "19 | [[1.00, 0.00, 0.00, 1.00, 1.00]] | [[1.00]] | [0.41] | X\n", "20 | [[1.00, 0.00, 1.00, 0.00, 0.00]] | [[0.00]] | [0.23] | X\n", "21 | [[1.00, 0.00, 1.00, 0.00, 1.00]] | [[1.00]] | [0.65] | X\n", "22 | [[1.00, 0.00, 1.00, 1.00, 0.00]] | [[1.00]] | [-0.07] | X\n", "23 | [[1.00, 0.00, 1.00, 1.00, 1.00]] | [[0.00]] | [0.34] | X\n", "24 | [[1.00, 1.00, 0.00, 0.00, 0.00]] | [[0.00]] | [0.90] | X\n", "25 | [[1.00, 1.00, 0.00, 0.00, 1.00]] | [[1.00]] | [1.32] | X\n", "26 | [[1.00, 1.00, 0.00, 1.00, 0.00]] | [[1.00]] | [0.59] | X\n", "27 | [[1.00, 1.00, 0.00, 1.00, 1.00]] | [[0.00]] | [1.01] | X\n", "28 | [[1.00, 1.00, 1.00, 0.00, 0.00]] | [[1.00]] | [0.83] | correct\n", "29 | [[1.00, 1.00, 1.00, 0.00, 1.00]] | [[0.00]] | [1.25] | X\n", "30 | [[1.00, 1.00, 1.00, 1.00, 0.00]] | [[0.00]] | [0.53] | X\n", "31 | [[1.00, 1.00, 1.00, 1.00, 1.00]] | [[1.00]] | [0.95] | correct\n", "Total count: 32\n", " correct: 6\n", " incorrect: 26\n", "Total percentage correct: 0.1875\n" ] } ], "source": [ "net.evaluate(tolerance=.2, show=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { 
"display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 2 }