{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CIFAR10 CNN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train a simple deep CNN on the CIFAR10 small images dataset.\n", "\n", "Some constants we'll use:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "batch_size = 32\n", "num_classes = 10\n", "epochs = 200\n", "data_augmentation = True\n", "num_predictions = 20" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using Theano backend.\n", "conx, version 3.5.9\n" ] } ], "source": [ "from conx import *" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "net = Network(\"CIRAR10\")\n", "net.add(ImageLayer(\"input\", (32, 32), 3)) # depends on K.image_data_format(), right?\n", "net.add(Conv2DLayer(\"conv1\", 32, (3, 3), padding='same', activation='relu'))\n", "net.add(Conv2DLayer(\"conv2\", 32, (3, 3), activation='relu'))\n", "net.add(MaxPool2DLayer(\"pool1\", pool_size=(2, 2), dropout=0.25))\n", "net.add(Conv2DLayer(\"conv3\", 64, (3, 3), padding='same', activation='relu'))\n", "net.add(Conv2DLayer(\"conv4\", 64, (3, 3), activation='relu'))\n", "net.add(MaxPool2DLayer(\"pool2\", pool_size=(2, 2), dropout=0.25))\n", "net.add(FlattenLayer(\"flatten\"))\n", "net.add(Layer(\"hidden1\", 512, activation='relu', vshape=(16, 32), dropout=0.5))\n", "net.add(Layer(\"output\", num_classes, activation='softmax'))\n", "net.connect()\n", "\n", "# initiate RMSprop optimizer\n", "opt = RMSprop(lr=0.0001, decay=1e-6)\n", "\n", "net.compile(loss='categorical_crossentropy',\n", " optimizer=opt)\n", "\n", "# Let's train the model using RMSprop\n", "net.compile(loss='categorical_crossentropy',\n", " optimizer=opt,\n", " metrics=['accuracy'])\n", "model = net.model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ 
"net.dataset.get(\"cifar10\")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset name**: CIFAR-10\n", "\n", "\n", "Original source: https://www.cs.toronto.edu/~kriz/cifar.html\n", "\n", "The CIFAR-10 dataset consists of 60000 32x32 colour images in 10\n", "classes, with 6000 images per class.\n", "\n", "The classes are completely mutually exclusive. There is no overlap\n", "between automobiles and trucks. \"Automobile\" includes sedans, SUVs,\n", "things of that sort. \"Truck\" includes only big trucks. Neither\n", "includes pickup trucks.\n", "\n", "**Dataset Split**:\n", " * training : 60000\n", " * testing : 0\n", " * total : 60000\n", "\n", "**Input Summary**:\n", " * shape : [(32, 32, 3)]\n", " * range : [(0.0, 1.0)]\n", "\n", "**Target Summary**:\n", " * shape : [(10,)]\n", " * range : [(0.0, 1.0)]\n", "\n" ], "text/plain": [ "**Dataset name**: CIFAR-10\n", "\n", "\n", "Original source: https://www.cs.toronto.edu/~kriz/cifar.html\n", "\n", "The CIFAR-10 dataset consists of 60000 32x32 colour images in 10\n", "classes, with 6000 images per class.\n", "\n", "The classes are completely mutually exclusive. There is no overlap\n", "between automobiles and trucks. \"Automobile\" includes sedans, SUVs,\n", "things of that sort. \"Truck\" includes only big trucks. 
Neither\n", "includes pickup trucks.\n", "\n", "**Dataset Split**:\n", " * training : 60000\n", " * testing : 0\n", " * total : 60000\n", "\n", "**Input Summary**:\n", " * shape : [(32, 32, 3)]\n", " * range : [(0.0, 1.0)]\n", "\n", "**Target Summary**:\n", " * shape : [(10,)]\n", " * range : [(0.0, 1.0)]\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net.dataset.summary()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(60000, 32, 32, 3)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.dataset._inputs[0].shape" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "**Dataset name**: CIFAR-10\n", "\n", "\n", "Original source: https://www.cs.toronto.edu/~kriz/cifar.html\n", "\n", "The CIFAR-10 dataset consists of 60000 32x32 colour images in 10\n", "classes, with 6000 images per class.\n", "\n", "The classes are completely mutually exclusive. There is no overlap\n", "between automobiles and trucks. \"Automobile\" includes sedans, SUVs,\n", "things of that sort. \"Truck\" includes only big trucks. Neither\n", "includes pickup trucks.\n", "\n", "**Dataset Split**:\n", " * training : 10000\n", " * testing : 50000\n", " * total : 60000\n", "\n", "**Input Summary**:\n", " * shape : [(32, 32, 3)]\n", " * range : [(0.0, 1.0)]\n", "\n", "**Target Summary**:\n", " * shape : [(10,)]\n", " * range : [(0.0, 1.0)]\n", "\n" ], "text/plain": [ "**Dataset name**: CIFAR-10\n", "\n", "\n", "Original source: https://www.cs.toronto.edu/~kriz/cifar.html\n", "\n", "The CIFAR-10 dataset consists of 60000 32x32 colour images in 10\n", "classes, with 6000 images per class.\n", "\n", "The classes are completely mutually exclusive. There is no overlap\n", "between automobiles and trucks. \"Automobile\" includes sedans, SUVs,\n", "things of that sort. \"Truck\" includes only big trucks. 
Neither\n", "includes pickup trucks.\n", "\n", "**Dataset Split**:\n", " * training : 10000\n", " * testing : 50000\n", " * total : 60000\n", "\n", "**Input Summary**:\n", " * shape : [(32, 32, 3)]\n", " * range : [(0.0, 1.0)]\n", "\n", "**Target Summary**:\n", " * shape : [(10,)]\n", " * range : [(0.0, 1.0)]\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net.dataset.split(50000)\n", "net.dataset.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Examine Input as Image" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(32, 32, 3)]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.dataset.inputs.shape" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(10,)]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.dataset.targets.shape" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAKAAAACgCAIAAAAErfB6AAAQbUlEQVR4nO2dW6/kRhWFy3bZbfe9+5zuc585zExmmESZCRBQNNyVF3iJeOPfkT+AEEIICQkEQiI8kCgISAjDZCaTc5tz6/bpdtvtO6/1+YEfUNR6W7JdLnu31Utr79plfft73xcKwnCm0pZdqXTs1Sq9tdFW6WTcUenmsKtSz3FVKluBSoUjVTabhyrNCtx3NByo1C5zlaZpqtL1eq1SP/BVWopSpXESqXQw7GOSNU7O0kyljsADOo6j0l4Xb6PTwbtyXcwq4ci1ZWMaNt5VYxpFbeFcYaA1TIA1hwmw5jAB1hzyk08/UXl4daXSMf77hbUBvln2cDSYqnRVQa9FJYRSbXkqjddQCnECoZSX0HpXDnSELzFyUeBkh5Kk1WrxvitcW2Ea1npDpTZkk8ip5gKJlxNR+8zKQqXtNkSWZUOgWRSkwsZ3GK8hKosc1JF4QPMFaw4TYM1hAqw5TIA1hwwkBIvAP7S4TVV1uAULaToZqzRoCAcLIycpHKV1DoVS82QvoM9FJ6uucO1gDEOtyHGy52KoEmaUcDw8cJphknmBWbV5suxgZJ9HCwvyza4h/QqBkakaRbeDJ4pWMWcFVWXz2uXiBkeFgdYwAdYcJsCawwRYc0jfgsPS68H3ub83UulGAC/HrSBJohm8m7LCryeJcSMbRpboM7coKVjCmyWOYo5i3IMkWS6gbjJ6VQltoJpip8ssXp4lmHOJG7s0xUpmLSWFU5riqOfi+e0KLyeN5ioVNAFbNNSKCvLtZgUFar5gzWECrDlMgDWHCbDmkKMWhENA4TCgWTPpI41VVrCF6BEJR1IMMOeVVpQkFE6Svk+ZQuzUDoa6uAhxco6JLGPYQHEJJdgNWHWV4lpHYBq2BbHjtFhItYLkbLsYWda4ds30aJJDZFUCJ4cRRg5jvLqI6nWd4+WYL1hzmABrDhNgzWECrDnkZAil0HOhjHwf1Hbw5x8wqZcXUCgVTaK6hqxo1LKXGYRDVdNvojKqJWygZQavqiwx55j1XAXpcoUbncwwlMu6/36EJ8pfoX4tuYGau7V5T6XT6b5KrR6Seun8WqVRhGncLCGyrm4gOb84wlAllxCYL1hzmABrDhNgzWECrDnk7gQJsr4HW6TbhpyxqH0EDReL9lOaQHTY1FwbPZR3dTrQeosb6JdBH67Qkim/lyc4OUohsjxMSuy1aZm5FCzXoUrTmulROlmDPur+n7z+tkoXZ5CcdcxrN+EJpjFmFUX48FouTj7Yxn2n0y2Vni+gyMwXrDlMgDWHCbDmMAHWHHLcgxsls1ClLRd//u0Wqp/SBGInZ2HRcIh6rpr5sqzEbyvPmWtjw4PTS9QZff4S3s3lEvdl9kzcZh3ZT777lkr3d3Cjn3/0XKV/efZKpY21h9LGEy3DS0wjwpx7PS4YLLlA0sdRjwZi28LRggsVbx3s4kYz1K+ZL1hzmABrDhNgzWECrDnkdIwuBcmMPohFh4XVQEmGf3tpMU/H2qjGTynJIViGI3hVGUu9nx+fqnS2oEnE7KHDiq2+j5OnEhrEn0EKvdbfVunZGEOdhxcqTWM8wsdPn6rUZieJvMPirwHsp0brq8EAYrZXsZ6LqdU6W6j0kNak+YI1hwmw5jAB1hwmwJpDjjYnKh91YWzZ7N8ULrDqLV+hdaddNmqyoDJqmmLdLvKDuQD913MIllWKGiXfR3W+77F2n/0PRg6U4EfPzlVaZLg2HUBkTUaYlSUglPICgjTmUsQV84NZgWlY1JhMpQqXbRlqNuhyuUigYLuumvrUfMGawwRYc5gAaw4TYM0hRaMNJst/Gmgxq9UWME0kfy42lxPm1FytADVZV69gMMVXUHN32BSVHbeET1X14O4epsGzC/b5XFA2SgeJyJ6HB9wY3VXp3dduqfTFl39V6WdPT1TqSUqhGvq0KFitTm/O9TDnij0bGgsMLLaHN1+w5jA
B1hwmwJrDBFhzyEbfKCtPeAL8l9UKmamM3QIKmy3PY+imBeneAWRFXeDo7U0Ih7u7UBnxGkf37j9WqVdDVc1v8IDBEOlRcQ2T6GB7R6XhCg7ana++ptL+qE36EPe9xBPNbyDfXMo3u4Y3l7M3BkWVKNnvodGMtFH7Zr5gzWECrDlMgDWHCbDmkKXF+iYWVTf+sQMfycQuu4CeXkKgvThGIbh02Yj9HGVW63Oc/NoUqurdH0DdfH6CDXt6e8h4bm4g5XdxifzgcEh1U7HinIm5i0u4UdIPVXoZnqn05AzmlOvi5Qz7UEpJwh72El+aReFUUXPZ7I5v0TFkttB8wbrDBFhzmABrDhNgzSGHbLVeSIisiG0wa9ay3yzhzrz8EnImiiA6Ah8/prMXMMW2fCTI9vZuq3S4+xWVuktaO0xi7j/+Fg6+glAKCqi5UuABV2woutOGfMvYY8vq4NXtd7jKbwitt7zGQsWLczTGyrl+cJ0htyi4jLHDJqgZN7Vu5BbNF6w5TIA1hwmw5jAB1hxyGeLfXmZIcrks8BHs4S4dLieMoLlGPXhGQ3bCSuYQWdNdZPH2Hn1fpf88Rpn402egT3awhWIY4ujWXSQTbYHuXVkKzTVkq6/FBV5OwGV9O2Pet0TKz32EDhYJba8//+ZXKj0+wjQcr1EZx10g6VXljVI47hdtvmDNYQKsOUyANYcJsOZo7LAnStoija39bJZolezZMGen0sWCGbEU2mdnAAn2zR/+UKX7D95R6S/e/5lKt2khOVzWd/L8c5x853WV+htoxN6pWXA/Q5OGoIJQythe9WoJOpzAbtvYPlRpEmFlos2GDqUHB62RLsy5FNFiZ32rBm3W0AsDrWECrDlMgDWHCbDmkGxiLkr6II16H1YOiZrNSC0m8cYbKErabkOgff3t+yp9+ASqan4Brdcq4JHd2ccWNRVvvD1Fjq9Ys1Vp2NjdB0fzBAqlFFBzn58cq/Qf//xQpU/ewcgb2/DmFkvINxZsic1DSM6qUWaVUUZRrt5chipNlxjafMGawwRYc5gAaw4TYM0hK9oiSQrB4tEzkhJpLMfGv/29bfg+foBfz+HtA5U+/g6sq50Hj1T6t7+8r9JbBxh5+403MckJOivINppDxGvotWQB6+r89Eil83PIqDKHVxX0kPHc5NY4R6cfq3RrB50kipgWYYKqK2uFThJlzQ2xqYSDFov1t9mUotUwHw20hgmw5jAB1hwmwJpDutxPeM4UWMlmCUEbqwsdFmRPaV0dnYUqvfv1H6l0/01QISCj8iV6Jwy40eHk/lsqXUnURn3yMfpVpQmGWiwwq6uTL1XqcCdq38fL2fsKdNOj+8g8Fg7cKNcZgnpw/eSajUxfojq/oXwLfoYRS+HaG7jvFqvbzBesOUyANYcJsOYwAdYcMk24aWALssLiJnquzQYP7PcQdHHyez99T6VPfvyuSvub3Nb4+b9U6vBGIZcxXn7xb5WeLiFJ/vjLX6q0G3DhXgpHaXsL8q3PYv0Xx/C5Ms5qvHuo0vtvfkOlgnXwsxAeWaPV1zxhP/gaUVgnsBcj9tWouQL04RCzMF+w5jAB1hwmwJrDBFhzyKrm/i5syWRxA76iZhEW01h+C9Xcb30DoqPFXvKf/g3JtfkpqtVTtmlfztEY6+jZpyqNavhrbolruxLSr+9DRk1GEFln59wRmhVq8RIC7egFXDAhPsGsIm6SKPGuitZUpdcFXl0QIC/Z5ibegYR8W8ZYp1lwm27zBWsOE2DNYQKsOUyANYcU3O2mKqC5JAu0S6axMi423Bog5ffbX/1apeMtaJDpDkq0spgN0V3oiC53V5ZsGdqhfNuecsvrJcqdAgcjX19eqTRniXmPzVczdv76z8cofD/7DPstpgV757uYc9l4hH1IP9FBFOwWZKNPGTUSmOTDN7DI0XzBmsMEWHOYAGsOE2DNIasKeSuPvo8vuWKwsXExq5AqtpG6uoIrFF2CBjn8l4otuMYjCKX
hLhcMlqgaPznFyLXgvjI2Um+N5YQOu4B2fIhK+njCaXAaeWUGnWjzxS5iaL2sBQnW28UTrYJQpcsKmmu9wme50b+j0s2pqcn6f4IJsOYwAdYcJsCaQ9oWnB2/BVukplfVCaBBOr1NlcY5DJeNHpq4Sw6V3aA9fGXj5NiFnNnagjtTZRAdDx6ho8MHf/g9blSjlN/lnjRJhKP9HiwzT0KgOewVEbF4/cUZZFQY4nlTC/X3k/v4tPaGtMxqvI35FSbprSkM9+jcxdyDRxhoDRNgzWECrDlMgDWH9Nj7Kk5hqTisYKqYa4u5m7TD3QlbHkulXAzlsdHCoI+jr7jhYLwHGTU9wLK+kwuk/N745rdVGl1ik8TnT5G1XEWhSqWDJxoMoLksplbPTjDyly/pZLXwRP0t6NPJmCNTr1kzXDuac5HjFKsp94d4Oc8+ha9nvmDNYQKsOUyANYcJsOaQWxPEOL/GRjIJd+tbwY0RtQ3TRNL36ffhsHisnEpWSBcGLq4VGeiHH3yg0jsPIMGOjykrmNNss6uUQ50YBJAzqwgiK0lACxasdQMM9eRraK/q0xQrHDbLZweu5Agiy16i8H3a7qn0a/ffwNEh1ml+dPYCQwkDrWECrDlMgDWHCbDmkLcOkJkaWPh7f3YELXB+Ca8qY5eCbhfKaMVa9rJC1bjD39bsEuJuGUGSrHMM5dSgvS4K7s9fYSniMfd8rmpIsK0JlKBVoaxsHiID2OrgeYcDaB/PwROlrKEX7OO6SnFyFjEDWOHovQPsNb3LXvJHx5Cc15cImfmCNYcJsOYwAdYcJsCaQ/ZHNJj4Fz2ackvoDnJeV+fILa5ZKiU9WDlZo1VEDg2Ss5b9JoG66dAzWsfQTcka6cKMI5ekdY0nihasyeoHpMhpJo29C68xyW4XplhjRyKrgD71JG7UgrQVnodJHt47xDRiDPWnP6Ghxd+fYv8e8wVrDhNgzWECrDlMgDWHlGxq7vdhbI273LuQm8G4AZKJC5YOiRLXBj46Q5UsbS/TUKVeG0O5ErNyHGi9tMZQGbdTrmldNfZqrDPoNbbYEi7tJ+FB64VziKyEiysHw0bbCbwNm08Uc1XA+RV6bM3p6y1XMPJ+98fPcC2EoPmCdYcJsOYwAdYcJsCaQ0ZMVAkHmxV2O1AdbgCJ0qEBMxhw8d0iIUVWK+IiuHzNBlUeMmI+67kKVudL1u57/NG6LbhCloXDbaY42e5BFGxp7wUsOhtC681mUEZLSr/+GE8Us7zrP18gW/rZP9BpfotV8lv73F3axo02mcQ0X7DmMAHWHCbAmsMEWHPI45fgaQjd1JtAZfgBzRoIMjEeQ4NEK3gqYQg6v2aXAogM4VRQRhU3kilLljtV/6tpgcU6eIfV+QntthqPK1yWaBUxqr1KZg9L2l4hm0M0KrRmVKBfPMPzh9dYYZCtcPH2ACVaD29jR0UObL5g3WECrDlMgDWHCbDmkKWLXle597ZK0wqekV2g+skfQL8MJxBoo8ZOfzEMl3CGoqTwCqoqWUEKlQUUmajxu6zYI3TN3Rg9j6lGdltdrnFtwn0AXe451LNhElU2FkjmOebc6nDPIfawH3oY+Y4YqvTNxyjvevDosUoP76GDxbfegZo7PsUCA/MFaw4TYM1hAqw5TIA1x38Bk+bX8S2vl/AAAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image = array2image(net.dataset.inputs[0], scale=5.0)\n", "image" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "3f5f41f3fadf444aa175904939c7dbf7", "version_major": 2, "version_minor": 0 }, "text/html": [ "
<p>Failed to display Jupyter Widget of type <code>Dashboard</code>.</p>\n", "<p>\n", "  If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", "  that the widgets JavaScript is still loading. If this message persists, it\n", "  likely means that the widgets JavaScript library is either not installed or\n", "  not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n", "  Widgets Documentation</a> for setup instructions.\n", "</p>\n", "<p>\n", "  If you're reading this message in another frontend (for example, a static\n", "  rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n", "  it may mean that your frontend doesn't currently support widgets.\n", "</p>
\n" ], "text/plain": [ "Dashboard(children=(Accordion(children=(HBox(children=(VBox(children=(Select(description='Dataset:', index=1, options=('Test', 'Train'), rows=1, value='Train'), FloatSlider(value=1.0, continuous_update=False, description='Zoom', max=3.0, min=0.5), IntText(value=150, description='Horizontal space between banks:', style=DescriptionStyle(description_width='initial')), IntText(value=30, description='Vertical space between layers:', style=DescriptionStyle(description_width='initial')), HBox(children=(Checkbox(value=False, description='Show Targets', style=DescriptionStyle(description_width='initial')), Checkbox(value=False, description='Errors', style=DescriptionStyle(description_width='initial')))), Select(description='Features:', options=('', 'input', 'conv1', 'conv2', 'pool1', 'conv3', 'conv4', 'pool2'), rows=1, value=''), IntText(value=3, description='Feature columns:', style=DescriptionStyle(description_width='initial')), FloatText(value=2.0, description='Feature scale:', style=DescriptionStyle(description_width='initial'))), layout=Layout(width='100%')), VBox(children=(Select(description='Layer:', index=9, options=('input', 'conv1', 'conv2', 'pool1', 'conv3', 'conv4', 'pool2', 'flatten', 'hidden1', 'output'), rows=1, value='output'), Checkbox(value=True, description='Visible'), Select(description='Colormap:', options=('', 'Accent', 'Accent_r', 'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', 'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', 'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', 'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', 'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', 'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', 'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 
'Spectral_r', 'Vega10', 'Vega10_r', 'Vega20', 'Vega20_r', 'Vega20b', 'Vega20b_r', 'Vega20c', 'Vega20c_r', 'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', 'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', 'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cool', 'cool_r', 'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'cubehelix', 'cubehelix_r', 'flag', 'flag_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', 'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', 'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', 'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'inferno', 'inferno_r', 'jet', 'jet_r', 'magma', 'magma_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', 'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', 'rainbow_r', 'seismic', 'seismic_r', 'spectral', 'spectral_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', 'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', 'terrain_r', 'viridis', 'viridis_r', 'winter', 'winter_r'), rows=1, value=''), HTML(value=''), FloatText(value=-1.0, description='Leftmost color maps to:', style=DescriptionStyle(description_width='initial')), FloatText(value=1.0, description='Rightmost color maps to:', style=DescriptionStyle(description_width='initial')), IntText(value=0, description='Feature to show:', style=DescriptionStyle(description_width='initial'))), layout=Layout(width='100%')))),), selected_index=None, _titles={'0': 'CIRAR10'}), VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='Dataset index', layout=Layout(width='100%'), max=9999), Label(value='of 10000', layout=Layout(width='100px'))), layout=Layout(height='40px')), HBox(children=(Button(icon='fast-backward', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='backward', 
layout=Layout(width='100%'), style=ButtonStyle()), IntText(value=0, layout=Layout(width='100%')), Button(icon='forward', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='fast-forward', layout=Layout(width='100%'), style=ButtonStyle()), Button(description='Play', icon='play', layout=Layout(width='100%'), style=ButtonStyle()), Button(icon='refresh', layout=Layout(width='25%'), style=ButtonStyle())), layout=Layout(height='50px', width='100%'))), layout=Layout(width='100%')), HTML(value='
', layout=Layout(justify_content='center', overflow_x='auto', overflow_y='auto', width='95%')), Output()))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", "require(['base/js/namespace'], function(Jupyter) {\n", " Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {\n", " comm.on_msg(function(msg) {\n", " var data = msg[\"content\"][\"data\"];\n", " var images = document.getElementsByClassName(data[\"class\"]);\n", " for (var i = 0; i < images.length; i++) {\n", " if (data[\"href\"]) {\n", " images[i].setAttributeNS(null, \"href\", data[\"href\"]);\n", " }\n", " if (data[\"src\"]) {\n", " images[i].setAttributeNS(null, \"src\", data[\"src\"]);\n", " }\n", " }\n", " });\n", " });\n", "});\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", "require(['base/js/namespace'], function(Jupyter) {\n", " Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {\n", " comm.on_msg(function(msg) {\n", " var data = msg[\"content\"][\"data\"];\n", " var images = document.getElementsByClassName(data[\"class\"]);\n", " for (var i = 0; i < images.length; i++) {\n", " if (data[\"href\"]) {\n", " images[i].setAttributeNS(null, \"href\", data[\"href\"]);\n", " }\n", " if (data[\"src\"]) {\n", " images[i].setAttributeNS(null, \"src\", data[\"src\"]);\n", " }\n", " }\n", " });\n", " });\n", "});\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", "require(['base/js/namespace'], function(Jupyter) {\n", " Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {\n", " comm.on_msg(function(msg) {\n", " var data = msg[\"content\"][\"data\"];\n", " var images = document.getElementsByClassName(data[\"class\"]);\n", " for (var i = 0; i < images.length; i++) {\n", " if (data[\"href\"]) 
{\n", " images[i].setAttributeNS(null, \"href\", data[\"href\"]);\n", " }\n", " if (data[\"src\"]) {\n", " images[i].setAttributeNS(null, \"src\", data[\"src\"]);\n", " }\n", " }\n", " });\n", " });\n", "});\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net.dashboard()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.0949893519282341,\n", " 0.09791158884763718,\n", " 0.10104776173830032,\n", " 0.10612580180168152,\n", " 0.09653763473033905,\n", " 0.10584577172994614,\n", " 0.0927908718585968,\n", " 0.10790029913187027,\n", " 0.0942075327038765,\n", " 0.10264338552951813]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.propagate(net.dataset.inputs[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Let Keras Take over from here" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "import keras\n", "from keras.datasets import cifar10\n", "from keras.preprocessing.image import ImageDataGenerator" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ 
"========================================================================\n", " | Training | Training | Validate | Validate \n", "Epochs | Error | Accuracy | Error | Accuracy \n", "------ | --------- | --------- | --------- | --------- \n", "# 1 | 2.14204 | 0.20052 | 1.97837 | 0.29421 \n" ] } ], "source": [ "net.train(plot=True)" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1, 50000, 32, 32, 3), (1, 50000, 10))" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shape(x_test), shape(y_test)" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "model = net.model" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "(x_train, y_train), (x_test, y_test) = net.dataset._split_data()" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1, 10000, 32, 32, 3), (1, 10000, 10), (1, 50000, 32, 32, 3), (1, 50000, 10))" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shape(x_train), shape(y_train), shape(x_test), shape(y_test)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using real-time data augmentation.\n", "Epoch 1/200\n", "313/312 [==============================] - 219s 700ms/step - loss: 1.6980 - acc: 0.3732 - val_loss: 1.5751 - val_acc: 0.4305\n", "Epoch 2/200\n", "313/312 [==============================] - 225s 718ms/step - loss: 1.6487 - acc: 0.3901 - val_loss: 1.5533 - val_acc: 0.4324\n", "Epoch 3/200\n", "311/312 [============================>.] 
- ETA: 0s - loss: 1.6029 - acc: 0.4136" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m validation_data=(x_test, y_test))\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 85\u001b[0m warnings.warn('Update your `' + object_name +\n\u001b[1;32m 86\u001b[0m '` call to the Keras 2 API: ' + signature, stacklevel=2)\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_original_function\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/engine/training.py\u001b[0m in 
\u001b[0;36mfit_generator\u001b[0;34m(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 2142\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2143\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mval_sample_weights\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2144\u001b[0;31m verbose=0)\n\u001b[0m\u001b[1;32m 2145\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_outs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2146\u001b[0m \u001b[0mval_outs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mval_outs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(self, x, y, batch_size, verbose, sample_weight, steps)\u001b[0m\n\u001b[1;32m 1725\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1726\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1727\u001b[0;31m steps=steps)\n\u001b[0m\u001b[1;32m 1728\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1729\u001b[0m def predict(self, x,\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36m_test_loop\u001b[0;34m(self, f, ins, batch_size, verbose, steps)\u001b[0m\n\u001b[1;32m 1368\u001b[0m \u001b[0mins_batch\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0m_slice_arrays\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_ids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1369\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1370\u001b[0;31m \u001b[0mbatch_outs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1371\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_outs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1372\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_index\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/keras/backend/theano_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 1221\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1222\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1223\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1224\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/theano/compile/function_module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 882\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 883\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 884\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0moutput_subset\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 885\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_subset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_subset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 886\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/theano/ifelse.py\u001b[0m in \u001b[0;36mthunk\u001b[0;34m()\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 246\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0mthunk\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 247\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcompute_map\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcond\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 248\u001b[0m 
\u001b[0;32mreturn\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "if not data_augmentation:\n", " print('Not using data augmentation.')\n", " model.fit(x_train[0], y_train[0],\n", " batch_size=batch_size,\n", " epochs=epochs,\n", " validation_data=(x_test[0], y_test[0]),\n", " shuffle=True)\n", "else:\n", " print('Using real-time data augmentation.')\n", " # This will do preprocessing and real-time data augmentation:\n", " datagen = ImageDataGenerator(\n", " featurewise_center=False, # set input mean to 0 over the dataset\n", " samplewise_center=False, # set each sample mean to 0\n", " featurewise_std_normalization=False, # divide inputs by std of the dataset\n", " samplewise_std_normalization=False, # divide each input by its std\n", " zca_whitening=False, # apply ZCA whitening\n", " rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)\n", " width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)\n", " height_shift_range=0.1, # randomly shift images vertically (fraction of total height)\n", " horizontal_flip=True, # randomly flip images horizontally\n", " vertical_flip=False) # randomly flip images vertically\n", "\n", " # Compute quantities required for feature-wise normalization\n", " # (std, mean, and principal components if ZCA whitening is applied).\n", " datagen.fit(x_train[0])\n", "\n", " # Fit the model on the batches generated by datagen.flow().\n", " model.fit_generator(datagen.flow(x_train[0], y_train[0],\n", " batch_size=batch_size),\n", " steps_per_epoch=x_train[0].shape[0] // batch_size,\n", " epochs=epochs,\n", " validation_data=(x_test[0], y_test[0]))" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model Accuracy = 0.43\n", "Actual Label = [ 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] vs. 
Predicted Label = [ 0.07781513 0.11049026 0.0632093 0.02424265 0.10382453 0.04574762\n", " 0.03441789 0.42630658 0.02449505 0.08945103]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] vs. Predicted Label = [ 0.02727769 0.19264257 0.02373797 0.01587937 0.0090649 0.00988456\n", " 0.06715904 0.01502883 0.06517704 0.57414806]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] vs. Predicted Label = [ 0.05388747 0.72588551 0.00831101 0.00158087 0.00303652 0.00182245\n", " 0.00224504 0.00250599 0.10607026 0.0946549 ]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] vs. Predicted Label = [ 0.01729291 0.19587232 0.07844362 0.10607366 0.05843127 0.09728301\n", " 0.20552364 0.0685864 0.01614707 0.15634608]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] vs. Predicted Label = [ 0.04557527 0.47399527 0.04536494 0.04711241 0.02142892 0.03853729\n", " 0.02306257 0.03526938 0.04544082 0.22421315]\n", "Actual Label = [ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] vs. Predicted Label = [ 0.01227825 0.0398778 0.04749325 0.05364216 0.03648441 0.07180805\n", " 0.13807741 0.3137106 0.00615273 0.28047535]\n", "Actual Label = [ 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] vs. Predicted Label = [ 0.00148747 0.00222605 0.03411757 0.24753405 0.05446577 0.16512007\n", " 0.44495872 0.04444287 0.00107558 0.00457185]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] vs. Predicted Label = [ 0.00741599 0.02212191 0.05528059 0.25858176 0.04116985 0.24418072\n", " 0.209802 0.12378619 0.00374054 0.03392046]\n", "Actual Label = [ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] vs. Predicted Label = [ 0.00456545 0.00249086 0.09983079 0.03194337 0.23732299 0.03509691\n", " 0.14823847 0.42950818 0.00175047 0.00925251]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] vs. Predicted Label = [ 0.01474431 0.00438144 0.08705295 0.17581116 0.1401801 0.10807402\n", " 0.40283573 0.04501324 0.00754052 0.01436656]\n", "Actual Label = [ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] vs. 
Predicted Label = [ 0.20857221 0.03858071 0.11660673 0.02005922 0.05146967 0.01053343\n", " 0.02061583 0.02458823 0.28485668 0.22411728]\n", "Actual Label = [ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] vs. Predicted Label = [ 0.04369825 0.00953077 0.28183886 0.03723752 0.39792955 0.04845857\n", " 0.05786296 0.08070038 0.03028672 0.01245641]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] vs. Predicted Label = [ 0.03174659 0.14564584 0.01506849 0.00770176 0.00636712 0.00667131\n", " 0.01012393 0.01783629 0.16324712 0.59559155]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] vs. Predicted Label = [ 0.01913743 0.42621526 0.00817715 0.00716725 0.0026362 0.00949341\n", " 0.01235665 0.01647761 0.10527103 0.39306799]\n", "Actual Label = [ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] vs. Predicted Label = [ 0.28236881 0.04135361 0.1647729 0.03696084 0.15157856 0.03136211\n", " 0.02248296 0.04669645 0.17347305 0.04895073]\n", "Actual Label = [ 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] vs. Predicted Label = [ 0.05246869 0.45963278 0.00994109 0.00715726 0.0059837 0.0033786\n", " 0.01398534 0.00770203 0.06684003 0.37291047]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] vs. Predicted Label = [ 0.04716396 0.04682264 0.11449952 0.13412473 0.13372231 0.10707654\n", " 0.17844412 0.09313755 0.0549487 0.09005994]\n", "Actual Label = [ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] vs. Predicted Label = [ 0.03275985 0.08420437 0.18667556 0.0925509 0.07995293 0.05029331\n", " 0.30199629 0.04829691 0.01449488 0.108775 ]\n", "Actual Label = [ 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] vs. Predicted Label = [ 0.00613854 0.00442893 0.12136693 0.05807019 0.19370349 0.09075771\n", " 0.36420116 0.14638041 0.00331027 0.0116424 ]\n", "Actual Label = [ 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.] vs. Predicted Label = [ 0.01641119 0.00571625 0.2858237 0.06443242 0.29150918 0.04093195\n", " 0.22316463 0.05377265 0.00830103 0.00993697]\n", "Actual Label = [ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] vs. 
Predicted Label = [ 0.08868346 0.36420235 0.01048361 0.00159288 0.00550383 0.00113587\n", " 0.00297236 0.00778512 0.1336575 0.38398302]\n" ] } ], "source": [ "# Evaluate the model on the test set and show sample prediction results\n", "evaluation = model.evaluate_generator(datagen.flow(x_test[0], y_test[0],\n", " batch_size=batch_size),\n", " steps=x_test[0].shape[0] // batch_size)\n", "\n", "print('Model Accuracy = %.2f' % (evaluation[1]))\n", "\n", "# Note: datagen.flow applies the same random augmentation to the test images\n", "predict_gen = model.predict_generator(datagen.flow(x_test[0], y_test[0],\n", " batch_size=batch_size),\n", " steps=x_test[0].shape[0] // batch_size)\n", "\n", "for predict_index, predicted_y in enumerate(predict_gen):\n", " print('Actual Label = %s vs. Predicted Label = %s' % (y_test[0][predict_index],\n", " predicted_y))\n", " if predict_index == num_predictions:\n", " break" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }