{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Object Detection & classification",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zyBhfd53HCFk",
        "outputId": "60ee2486-4823-4817-f1e1-adf2c7eb8c42"
      },
      "source": [
        "!apt-get install libcairo2-dev libjpeg-dev libgif-dev\n",
        "!pip install pycairo\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib\n",
        "%matplotlib inline"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Reading package lists... Done\n",
            "Building dependency tree       \n",
            "Reading state information... Done\n",
            "libjpeg-dev is already the newest version (8c-2ubuntu8).\n",
            "libcairo2-dev is already the newest version (1.15.10-2ubuntu0.1).\n",
            "libgif-dev is already the newest version (5.1.4-2ubuntu0.1).\n",
            "0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.\n",
            "Requirement already satisfied: pycairo in /usr/local/lib/python3.7/dist-packages (1.20.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sFI6vDGWHKk1"
      },
      "source": [
        "import cairo\n",
        "num_imgs = 1000\n",
        "\n",
        "img_size = 32\n",
        "min_object_size = 4\n",
        "max_object_size = 16\n",
        "num_objects = 2\n",
        "\n",
        "bboxes = np.zeros((num_imgs, num_objects, 4))\n",
        "imgs = np.zeros((num_imgs, img_size, img_size, 4), dtype=np.uint8)  # format: BGRA\n",
        "shapes = np.zeros((num_imgs, num_objects), dtype=int)\n",
        "num_shapes = 3\n",
        "shape_labels = ['rectangle', 'circle', 'triangle']\n",
        "colors = np.zeros((num_imgs, num_objects), dtype=int)\n",
        "num_colors = 3\n",
        "color_labels = ['r', 'g', 'b']\n",
        "\n",
        "for i_img in range(num_imgs):\n",
        "    surface = cairo.ImageSurface.create_for_data(imgs[i_img], cairo.FORMAT_ARGB32, img_size, img_size)\n",
        "    cr = cairo.Context(surface)\n",
        "\n",
        "    # Fill background white.\n",
        "    cr.set_source_rgb(1, 1, 1)\n",
        "    cr.paint()\n",
        "    \n",
        "    # TODO: Try no overlap here.\n",
        "    # Draw random shapes.\n",
        "    for i_object in range(num_objects):\n",
        "        shape = np.random.randint(num_shapes)\n",
        "        shapes[i_img, i_object] = shape\n",
        "        if shape == 0:  # rectangle\n",
        "            w, h = np.random.randint(min_object_size, max_object_size, size=2)\n",
        "            x = np.random.randint(0, img_size - w)\n",
        "            y = np.random.randint(0, img_size - h)\n",
        "            bboxes[i_img, i_object] = [x, y, w, h]\n",
        "            cr.rectangle(x, y, w, h)            \n",
        "        elif shape == 1:  # circle   \n",
        "            r = 0.5 * np.random.randint(min_object_size, max_object_size)\n",
        "            x = np.random.randint(r, img_size - r)\n",
        "            y = np.random.randint(r, img_size - r)\n",
        "            bboxes[i_img, i_object] = [x - r, y - r, 2 * r, 2 * r]\n",
        "            cr.arc(x, y, r, 0, 2*np.pi)\n",
        "        elif shape == 2:  # triangle\n",
        "            w, h = np.random.randint(min_object_size, max_object_size, size=2)\n",
        "            x = np.random.randint(0, img_size - w)\n",
        "            y = np.random.randint(0, img_size - h)\n",
        "            bboxes[i_img, i_object] = [x, y, w, h]\n",
        "            cr.move_to(x, y)\n",
        "            cr.line_to(x+w, y)\n",
        "            cr.line_to(x+w, y+h)\n",
        "            cr.line_to(x, y)\n",
        "            cr.close_path()\n",
        "        \n",
        "        # TODO: Introduce some variation to the colors by adding a small random offset to the rgb values.\n",
        "        color = np.random.randint(num_colors)\n",
        "        colors[i_img, i_object] = color\n",
        "        max_offset = 0.3\n",
        "        r_offset, g_offset, b_offset = max_offset * 2. * (np.random.rand(3) - 0.5)\n",
        "        if color == 0:\n",
        "            cr.set_source_rgb(1-max_offset+r_offset, 0+g_offset, 0+b_offset)\n",
        "        elif color == 1:\n",
        "            cr.set_source_rgb(0+r_offset, 1-max_offset+g_offset, 0+b_offset)\n",
        "        elif color == 2:\n",
        "            cr.set_source_rgb(0+r_offset, 0-max_offset+g_offset, 1+b_offset)\n",
        "        cr.fill()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1b22TdRRWW4F"
      },
      "source": [
        "imgs = imgs[..., 2::-1]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "elXj1uS7WZjK",
        "outputId": "20b97507-b299-4232-eaf6-985cc2c30730"
      },
      "source": [
        "imgs.shape"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1000, 32, 32, 3)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 265
        },
        "id": "D8xyIRNzWp64",
        "outputId": "8d5fbe48-3304-4b65-9d8d-42ecbd97d2c6"
      },
      "source": [
        "i = 3\n",
        "plt.imshow(imgs[i], interpolation='none', origin='lower', extent=[0, img_size, 0, img_size])\n",
        "for bbox, shape, color in zip(bboxes[i], shapes[i], colors[i]):\n",
        "    plt.gca().add_patch(matplotlib.patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], ec='k', fc='none'))\n",
        "    plt.annotate(shape_labels[shape], (bbox[0], bbox[1] + bbox[3] + 0.7), color=color_labels[color], clip_on=False)\n",
        "# surface.write_to_png(\"circle.png\")"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAARI0lEQVR4nO3df4xU5X7H8ffXvVwgQFyVlVCVXS41Kv64C45bEX9QULTGn81tvWpuiRr3tsFEEktCuLVXkxrXWjXGVCsWA2moXL1IsXrbQs0aa23BBVZlIQUxu9bNCmt1FSxoF7794zlkl2WGHWfOzCzzfF7JZM4858w5X074zHPOM2f2mLsjItXvpEoXICLlobCLREJhF4mEwi4SCYVdJBI/KOfGJk6c6A0NDeXcpEhUOjs7+fzzzy3bvLKGvaGhgba2tnJuUiQqmUwm5zwdxotEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEYtiwm9kYM9tkZu+bWYeZPZy0TzWzjWb2kZn9ysx+WPpyRaRQ+fTs3wJz3f3HQCNwnZldCjwGPOXuvw18CdxTujJFpFjDht2D/cnLUcnDgbnAr5P2lcAtJalQRFKR1zm7mdWYWTuwF9gA7Ab63L0/WeRT4Iwc7202szYza+vt7U2jZhEpQF5hd/dD7t4InAk0AefmuwF3X+buGXfP1NXVFVimiBTre43Gu3sf0ArMAmrN7MhPZM8EulOuTURSlM9ofJ2Z1SbTY4FrgB2E0P8kWWwBsK5URYpI8fL54xWTgZVmVkP4cHjZ3V83s+3AajP7C2ArsLyEdYpIkYYNu7t/AMzI0v4x4fxdRE4AuoJOJBIKu0gkFHapXn198Oyzuedfdln623zrLbjhhvTXmwKFXapXrrD3J9eCvftueeupMIVdqteSJbB7NzQ2wiWXwBVXwE03wfTpYf748eF5/36YNw9mzoQLL4R1ybfInZ1w3nlw771w/vkwfz4cOBDmvfceXHRRWPfixXDBBcdu/5tv4O67oakJZswYWG+FKOxSvVpaYNo0aG+Hxx+HLVvg6adh586jlxszBtauDfNbW+GBB+DIrcx37YKFC6GjA2prYc2a0H7XXfD882HdNTXZt//IIzB3LmzaFNa7eHH4AKgQhV3i0dQEU6ce2+4OS5eGnvrqq6G7G/bsCfOmTg29N8DFF4fevq8P9u2DWbNC+x13ZN/e+vXhA6exEebMgYMH4ZNP0v5X5a2sd4QRqahx47K3r1oFvb2weTOMGgUNDSGYAKNHDyxXUzNwGJ8P93AkcM45BZecJvXsUr0mTAg98HC++gpOPz0EvbUVurqOv3xtbVj3xo3h9erV2Ze79lp45pmBU4KtW/OvvQTUs0v1Ou00mD07DJ6NHQuTJmVf7s474cYbw+BcJgPn5vGjzuXLw8DdSSfBVVfByScfu8yDD8KiReH04PDhcErw+uvF/ZuKYH7kU6cMMpmM68aOUhX27x8YzW9pgZ6eMPhXYZlMhra2tsrfxVWkarzxBjz6aPjOvr4eVqyodEXDUthFCnHbbeFxAtEAnUgkFHapWg0NDZjZCftoaGhIdX/oMF6qVldXF+UcgC7Ud/+Q/cc6o29dmOp21LOLREJhF4mEwi4SCYVdJBIKu0gkNBovUgb9G/8p57xvX/hFWWpQzy4SCYVdJBIKu0gk8rnX21lm1mpm282sw8zuT9ofMrNuM2tPHteXvlwRKVQ+A3T9wAPuvsXMJgCbzWxDMu8pd/+r0pUnImnJ515vPUBPMr3PzHYAZ5S6MBFJ1/f66s3MGgg3edwIzAbuM7M/AtoIvf+XWd7TDDQDTJkypchyRUauQzu35Jx3sOWu3G88fKgE1Rwr7wE6MxsPrAEWufvXwHPANKCR0PM/ke197r7M3TPunqmrq0uhZBEpRF5hN7NRhKCvcvdXAdx9j7sfcvfDwA
vo9s0iI1o+o/EGLAd2uPuTg9onD1rsVmBb+uWJSFryOWefDfwM+NDM2pO2pcDtZtYIONAJ/LwkFYpIKvIZjX8HyPanaX+TfjkiUiq6gk4kEvrVm8j3cPizzpzzDvzyD3LO84P/W4Jqvh/17CKRUNhFIqGwi0RCYReJhMIuEgmNxosM4V9/kXPegT/7/dzv6+stRTmpUc8uEgmFXSQSCrtIJBR2kUgo7CKRUNhFIqGv3iRO3x3IOevAw7flnHe4+6NSVFMW6tlFIqGwi0RCYReJhMIuEgmFXSQSCrtIJPTVm1Q3P5y1+cDjzTnfcmj7xlJVU1Hq2UUiobCLREJhF4lEPvd6O8vMWs1su5l1mNn9SfupZrbBzHYlz6eUvlwRKVQ+PXs/4d7r04FLgYVmNh1YArzp7mcDbyavRWSEGjbs7t7j7luS6X3ADuAM4GZgZbLYSuCWUhUpIsX7Xl+9mVkDMAPYCExy955k1mfApBzvaQaaAaZMmVJonSIF+XbZ0qzt/e+sK3MllZf3AJ2ZjQfWAIvc/evB89zdCbduPoa7L3P3jLtn6urqiipWRAqXV9jNbBQh6Kvc/dWkeY+ZTU7mTwb2lqZEEUlDPqPxBiwHdrj7k4NmvQYsSKYXAPEdF4mcQPI5Z58N/Az40Mzak7alQAvwspndA3QBf1iaEkUkDcOG3d3fASzH7HnpliMipaIr6EQioV+9SdWqr69nzB8/VukyClZfX5/q+hR2qVqdnZ2VLmFE0WG8SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSKhsItEQmEXiUQ+93p70cz2mtm2QW0PmVm3mbUnj+tLW6aIFCufnn0FcF2W9qfcvTF5/CbdskQkbcOG3d3fBr4oQy0iUkLFnLPfZ2YfJIf5p+RayMyazazNzNp6e3uL2JyIFKPQsD8HTAMagR7giVwLuvsyd8+4e6aurq7AzYlIsQoKu7vvcfdD7n4YeAFoSrcsEUlbQWE3s8mDXt4KbMu1rIiMDMPexdXMXgLmABPN7FPgl8AcM2sEHOgEfl7CGkUkBcOG3d1vz9K8vAS1iEgJ6Qo6kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgMG/bklsx7zWzboLZTzWyDme1KnnPesllERoZ8evYVwHVD2pYAb7r72cCbyWsRGcGGDbu7vw18MaT5ZmBlMr0SuCXlukQkZYWes09y955k+jNgUq4FzazZzNrMrK23t7fAzYlIsYoeoHN3J9y6Odf8Ze6ecfdMXV1dsZsTkQIVGvY9ZjYZIHnem15JIlIKhYb9NWBBMr0AWJdOOSJSKvl89fYS8B/AOWb2qZndA7QA15jZLuDq5LWIjGA/GG4Bd789x6x5KdciIiWkK+hEIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIlG1Ye/rg2efzT3/ssvS3+Zbb8ENN6S/XpE0RBf2/v7w/O675a1HpNKqNuxLlsDu3dDYCJdcAldcATfdBNOnh/njx4fn/fth3jyYORMuvBDWJb/f6+yE886De++F88+H+fPhwIEw77334KKLwroXL4YLLjh2+998A3ffDU1NMGPGwHpFKqVqw97SAtOmQXs7PP44bNkCTz8NO3cevdyYMbB2bZjf2goPPACe/CmOXbtg4ULo6IDaWlizJrTfdRc8/3xYd01N9u0/8gjMnQubNoX1Ll4cPgBEKqVqwz5UUxNMnXpsuzssXRp66quvhu5u2LMnzJs6NfTeABdfHHr7vj7Ytw9mzQrtd9yRfXvr14cPnMZGmDMHDh6ETz5J+18lkr9hf+JaLcaNy96+ah
X09sLmzTBqFDQ0hGACjB49sFxNzcBhfD7cw5HAOecUXLJIqqq2Z58wIfTAw/nqKzj99BD01lbo6jr+8rW1Yd0bN4bXq1dnX+7aa+GZZwZOCbZuzb92kVKo2p79tNNg9uwweDZ2LEzK8fdv77wTbrwxDM5lMnDuucOve/nyMHB30klw1VVw8snHLvPgg7BoUTg9OHw4nBK8/npx/yaRYph7zj8Mm7pMJuNtbW1l216p7N8/MJrf0gI9PWHwT6TSMpkMbW1tlm1e1fbspfTGG/Doo+E7+/p6WLGi0hWJDE9hL8Btt4WHyImkagfoRORoVduzNzQ00DXc0LoMq76+ns7OzkqXISmo2rB3dXVRzsHHkaKrI/vFAG+/8j853/P2K0Pv2zngb7f/uOiaZGTQYbxIJBR2kUgUdRhvZp3APuAQ0O/umTSKEpH0pXHO/rvu/nkK6xGREtJhvEgkiu3ZHVhvZg487+7Lhi5gZs1AM8CUKVOK3Fx16fn425zz/vMfv8w579+OM3re8e95/PpHolRsz365u88Efg9YaGZXDl3A3Ze5e8bdM3V1dUVuTkQKVVTY3b07ed4LrAWa0ihKRNJXcNjNbJyZTTgyDcwHtqVVmIikq5hz9knAWjM7sp6/d/d/TqUqEUldwWF3948BXUspcoLQV28ikajaH8KUQq6vyvQ1mZwI1LOLREJhF4mEwi4SCYVdJBIKu0gkFHaRSFT1V2/6qkxkgHp2kUgo7CKRUNhFIqGwi0RCYReJRNWOxtfX1/Nb08ZUuowTXn19faVLkJRUbdh1fzKRo+kwXiQSCrtIJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEgmFXSQSRYXdzK4zs/8ys4/MbElaRYlI+oq5sWMN8NeE2zVPB243s+lpFSYi6SqmZ28CPnL3j939O2A1cHM6ZYlI2or5IcwZwH8Pev0p8DtDFzKzZqA5efmtmY2E2zpPBD6vdBGojqFUx9EKqSPnzxRL/qs3d18GLAMwszZ3z5R6m8NRHaojxjqKOYzvBs4a9PrMpE1ERqBiwv4ecLaZTTWzHwI/BV5LpywRSVvBh/Hu3m9m9wH/AtQAL7p7xzBvW1bo9lKmOo6mOo5WlXWYu6e5PhEZoXQFnUgkFHaRSJQl7CPpsloz6zSzD82s3czayrjdF81s7+DrDMzsVDPbYGa7kudTKlTHQ2bWneyTdjO7vsQ1nGVmrWa23cw6zOz+pL2s++M4dZR7f4wxs01m9n5Sx8NJ+1Qz25jk5lfJQHjh3L2kD8Lg3W7gR8APgfeB6aXe7nHq6QQmVmC7VwIzgW2D2v4SWJJMLwEeq1AdDwF/WsZ9MRmYmUxPAHYSLrku6/44Th3l3h8GjE+mRwEbgUuBl4GfJu1/A/xJMdspR8+uy2oBd38bGHqL2JuBlcn0SuCWCtVRVu7e4+5bkul9wA7CFZll3R/HqaOsPNifvByVPByYC/w6aS96f5Qj7Nkuqy37Dh3EgfVmtjm5lLeSJrl7TzL9GTCpgrXcZ2YfJIf5JT+dOMLMGoAZhN6sYvtjSB1Q5v1hZjVm1g7sBTYQjob73L0/WaTo3MQ4QHe5u88k/FpvoZldWemCIHy6Ez6IKuE5YBrQCPQAT5Rjo2Y2HlgDLHL3rwfPK+f+yFJH2feHux9y90bClahNwLlpb6McYR9Rl9W6e3fyvBdYS9ixlbLHzCYDJM97K1GEu+9J/rMdBl6gDPvEzEYRArbK3V9Nmsu+P7LVUYn9cYS79wGtwCyg1syOXPhWdG7KEfYRc1mtmY0zswlHpoH5QCV/hfcasCCZXgCsq0QRRwKWuJUS7xMzM2A5sMPdnxw0q6z7I1cdFdgfdWZWm0yPBa4hjB+0Aj9JFit+f5RptPF6wkjnbuAX5RrlzFLHjwjfBrwPdJSzFuAlwiHh/xHOv+4BTgPeBHYB/wqcWqE6/g74EPiAELjJJa7hcsIh+gdAe/K4vtz74zh1lHt/XARsTba3DfjzQf
9fNwEfAa8Ao4vZji6XFYlEjAN0IlFS2EUiobCLREJhF4mEwi4SCYVdJBIKu0gk/h9HlHmjFp3pYwAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TilaovYy5gRI",
        "outputId": "5ca01987-ea9d-4143-b083-8af98f1ab83a"
      },
      "source": [
        "X = (imgs - 128.) / 255.\n",
        "X.shape, np.mean(X), np.std(X)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "((1000, 32, 32, 3), 0.40640019786560416, 0.2643851084254905)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WjPWoR6P5m7m",
        "outputId": "d6f4ad87-0871-42a0-a1ca-b2c815f4fab9"
      },
      "source": [
        "colors_onehot = np.zeros((num_imgs, num_objects, num_colors))\n",
        "for i_img in range(num_imgs):\n",
        "    for i_object in range(num_objects):\n",
        "        colors_onehot[i_img, i_object, colors[i_img, i_object]] = 1\n",
        "\n",
        "shapes_onehot = np.zeros((num_imgs, num_objects, num_shapes))\n",
        "for i_img in range(num_imgs):\n",
        "    for i_object in range(num_objects):\n",
        "        shapes_onehot[i_img, i_object, shapes[i_img, i_object]] = 1\n",
        "        \n",
        "y = np.concatenate([bboxes / img_size, shapes_onehot, colors_onehot], axis=-1).reshape(num_imgs, -1)\n",
        "y.shape, np.all(np.argmax(colors_onehot, axis=-1) == colors)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "((1000, 20), True)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_O6XUxO86b60",
        "outputId": "9bc521d7-6685-4127-ae52-d5621f41f332"
      },
      "source": [
        "y[0]"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([0.4375 , 0.15625, 0.40625, 0.3125 , 0.     , 0.     , 1.     ,\n",
              "       0.     , 0.     , 1.     , 0.65625, 0.53125, 0.15625, 0.25   ,\n",
              "       1.     , 0.     , 0.     , 1.     , 0.     , 0.     ])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h3SWVFsY6pDq"
      },
      "source": [
        "i = int(0.8 * num_imgs)\n",
        "train_X = X[:i]\n",
        "test_X = X[i:]\n",
        "train_y = y[:i]\n",
        "test_y = y[i:]\n",
        "test_imgs = imgs[i:]\n",
        "test_bboxes = bboxes[i:]\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gVoFNlek6uxM"
      },
      "source": [
        "from keras.models import Sequential\n",
        "from keras.layers import Dense, Activation, Dropout, Conv2D, Convolution2D, MaxPooling2D, Flatten\n",
        "\n",
        "filter_size = 3\n",
        "pool_size = 2\n",
        "\n",
        "model = Sequential([\n",
        "        Conv2D(32,5, input_shape=X.shape[1:],  activation='relu'), \n",
        "        MaxPooling2D(pool_size=(pool_size, pool_size)), \n",
        "        Conv2D(64, filter_size, activation='relu'), \n",
        "        MaxPooling2D(pool_size=(pool_size, pool_size)), \n",
        "        Conv2D(128, filter_size, activation='relu'),\n",
        "        Conv2D(128, filter_size, activation='relu'), \n",
        "        Flatten(), \n",
        "        Dropout(0.4), \n",
        "        Dense(256, activation='relu'), \n",
        "        Dropout(0.4), \n",
        "        Dense(y.shape[-1])\n",
        "    ]) \n",
        "\n",
        "model.compile('adadelta', 'mse')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ERIpAdm18Ggc"
      },
      "source": [
        "# Flip bboxes during training.\n",
        "# Note: The validation loss is always quite big here because we don't flip the bounding boxes for the validation data. \n",
        "def IOU(bbox1, bbox2):\n",
        "    '''Calculate overlap between two bounding boxes [x, y, w, h] as the area of intersection over the area of unity'''\n",
        "    x1, y1, w1, h1 = bbox1[0], bbox1[1], bbox1[2], bbox1[3]  # TODO: Check if its more performant if tensor elements are accessed directly below.\n",
        "    x2, y2, w2, h2 = bbox2[0], bbox2[1], bbox2[2], bbox2[3]\n",
        "\n",
        "    w_I = min(x1 + w1, x2 + w2) - max(x1, x2)\n",
        "    h_I = min(y1 + h1, y2 + h2) - max(y1, y2)\n",
        "    if w_I <= 0 or h_I <= 0:  # no overlap\n",
        "        return 0\n",
        "    I = w_I * h_I\n",
        "\n",
        "    U = w1 * h1 + w2 * h2 - I\n",
        "\n",
        "    return I / U\n",
        "\n",
        "def dist(bbox1, bbox2):\n",
        "    return np.sqrt(np.sum(np.square(bbox1[:2] - bbox2[:2])))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ULFtqFOw8Ops"
      },
      "source": [
        "num_epochs_flipping = 1\n",
        "num_epochs_no_flipping = 0  # has no significant effect\n",
        "\n",
        "flipped_train_y = np.array(train_y)\n",
        "flipped = np.zeros((len(train_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "ious_epoch = np.zeros((len(train_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "dists_epoch = np.zeros((len(train_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "mses_epoch = np.zeros((len(train_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "acc_shapes_epoch = np.zeros((len(train_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "acc_colors_epoch = np.zeros((len(train_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "\n",
        "flipped_test_y = np.array(test_y)\n",
        "flipped_test = np.zeros((len(test_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "ious_test_epoch = np.zeros((len(test_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "dists_test_epoch = np.zeros((len(test_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "mses_test_epoch = np.zeros((len(test_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "acc_shapes_test_epoch = np.zeros((len(test_y), num_epochs_flipping + num_epochs_no_flipping))\n",
        "acc_colors_test_epoch = np.zeros((len(test_y), num_epochs_flipping + num_epochs_no_flipping))\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rr-QEHUK8TzH",
        "outputId": "482d76c2-02a3-496b-beb3-a9aa5b58aacd"
      },
      "source": [
        "# TODO: Calculate ious directly for all samples (using slices of the array pred_y for x, y, w, h).\n",
        "for epoch in range(num_epochs_flipping):\n",
        "    print('Epoch', epoch)\n",
        "    model.fit(train_X, flipped_train_y, epochs=1, validation_data=(test_X, test_y), verbose=2)\n",
        "    pred_y = model.predict(train_X)\n",
        "\n",
        "    for sample, (pred, exp) in enumerate(zip(pred_y, flipped_train_y)):\n",
        "        \n",
        "        # TODO: Make this simpler.\n",
        "        pred = pred.reshape(num_objects, -1)\n",
        "        exp = exp.reshape(num_objects, -1)\n",
        "        \n",
        "        pred_bboxes = pred[:, :4]\n",
        "        exp_bboxes = exp[:, :4]\n",
        "        \n",
        "        ious = np.zeros((num_objects, num_objects))\n",
        "        dists = np.zeros((num_objects, num_objects))\n",
        "        mses = np.zeros((num_objects, num_objects))\n",
        "        for i, exp_bbox in enumerate(exp_bboxes):\n",
        "            for j, pred_bbox in enumerate(pred_bboxes):\n",
        "                ious[i, j] = IOU(exp_bbox, pred_bbox)\n",
        "                dists[i, j] = dist(exp_bbox, pred_bbox)\n",
        "                mses[i, j] = np.mean(np.square(exp_bbox - pred_bbox))\n",
        "                \n",
        "        new_order = np.zeros(num_objects, dtype=int)\n",
        "        \n",
        "        for i in range(num_objects):\n",
        "            # Find pred and exp bbox with maximum iou and assign them to each other (i.e. switch the positions of the exp bboxes in y).\n",
        "            ind_exp_bbox, ind_pred_bbox = np.unravel_index(ious.argmax(), ious.shape)\n",
        "            ious_epoch[sample, epoch] += ious[ind_exp_bbox, ind_pred_bbox]\n",
        "            dists_epoch[sample, epoch] += dists[ind_exp_bbox, ind_pred_bbox]\n",
        "            mses_epoch[sample, epoch] += mses[ind_exp_bbox, ind_pred_bbox]\n",
        "            ious[ind_exp_bbox] = -1  # set iou of assigned bboxes to -1, so they don't get assigned again\n",
        "            ious[:, ind_pred_bbox] = -1\n",
        "            new_order[ind_pred_bbox] = ind_exp_bbox\n",
        "        \n",
        "        flipped_train_y[sample] = exp[new_order].flatten()\n",
        "        \n",
        "        flipped[sample, epoch] = 1. - np.mean(new_order == np.arange(num_objects, dtype=int))#np.array_equal(new_order, np.arange(num_objects, dtype=int))  # TODO: Change this to reflect the number of flips.\n",
        "        ious_epoch[sample, epoch] /= num_objects\n",
        "        dists_epoch[sample, epoch] /= num_objects\n",
        "        mses_epoch[sample, epoch] /= num_objects\n",
        "        \n",
        "        acc_shapes_epoch[sample, epoch] = np.mean(np.argmax(pred[:, 4:4+num_shapes], axis=-1) == np.argmax(exp[:, 4:4+num_shapes], axis=-1))\n",
        "        acc_colors_epoch[sample, epoch] = np.mean(np.argmax(pred[:, 4+num_shapes:4+num_shapes+num_colors], axis=-1) == np.argmax(exp[:, 4+num_shapes:4+num_shapes+num_colors], axis=-1))\n",
        "\n",
        "    \n",
        "    # Calculate metrics on test data. \n",
        "    pred_test_y = model.predict(test_X)\n",
        "    # TODO: Make this simpler.\n",
        "    for sample, (pred, exp) in enumerate(zip(pred_test_y, flipped_test_y)):\n",
        "        \n",
        "        # TODO: Make this simpler.\n",
        "        pred = pred.reshape(num_objects, -1)\n",
        "        exp = exp.reshape(num_objects, -1)\n",
        "        \n",
        "        pred_bboxes = pred[:, :4]\n",
        "        exp_bboxes = exp[:, :4]\n",
        "        \n",
        "        ious = np.zeros((num_objects, num_objects))\n",
        "        dists = np.zeros((num_objects, num_objects))\n",
        "        mses = np.zeros((num_objects, num_objects))\n",
        "        for i, exp_bbox in enumerate(exp_bboxes):\n",
        "            for j, pred_bbox in enumerate(pred_bboxes):\n",
        "                ious[i, j] = IOU(exp_bbox, pred_bbox)\n",
        "                dists[i, j] = dist(exp_bbox, pred_bbox)\n",
        "                mses[i, j] = np.mean(np.square(exp_bbox - pred_bbox))\n",
        "                \n",
        "        new_order = np.zeros(num_objects, dtype=int)\n",
        "        \n",
        "        for i in range(num_objects):\n",
        "            # Find pred and exp bbox with maximum iou and assign them to each other (i.e. switch the positions of the exp bboxes in y).\n",
        "            ind_exp_bbox, ind_pred_bbox = np.unravel_index(mses.argmin(), mses.shape)\n",
        "            ious_test_epoch[sample, epoch] += ious[ind_exp_bbox, ind_pred_bbox]\n",
        "            dists_test_epoch[sample, epoch] += dists[ind_exp_bbox, ind_pred_bbox]\n",
        "            mses_test_epoch[sample, epoch] += mses[ind_exp_bbox, ind_pred_bbox]\n",
        "            mses[ind_exp_bbox] = 1000000#-1  # set iou of assigned bboxes to -1, so they don't get assigned again\n",
        "            mses[:, ind_pred_bbox] = 10000000#-1\n",
        "            new_order[ind_pred_bbox] = ind_exp_bbox\n",
        "        \n",
        "        flipped_test_y[sample] = exp[new_order].flatten()\n",
        "        \n",
        "        flipped_test[sample, epoch] = 1. - np.mean(new_order == np.arange(num_objects, dtype=int))#np.array_equal(new_order, np.arange(num_objects, dtype=int))  # TODO: Change this to reflect the number of flips.\n",
        "        ious_test_epoch[sample, epoch] /= num_objects\n",
        "        dists_test_epoch[sample, epoch] /= num_objects\n",
        "        mses_test_epoch[sample, epoch] /= num_objects\n",
        "        \n",
        "        acc_shapes_test_epoch[sample, epoch] = np.mean(np.argmax(pred[:, 4:4+num_shapes], axis=-1) == np.argmax(exp[:, 4:4+num_shapes], axis=-1))\n",
        "        acc_colors_test_epoch[sample, epoch] = np.mean(np.argmax(pred[:, 4+num_shapes:4+num_shapes+num_colors], axis=-1) == np.argmax(exp[:, 4+num_shapes:4+num_shapes+num_colors], axis=-1))\n",
        "       \n",
        "            \n",
        "    print('Flipped {} % of all elements'.format(np.mean(flipped[:, epoch]) * 100.))\n",
        "    print('Mean IOU: {}'.format(np.mean(ious_epoch[:, epoch])))\n",
        "    print('Mean dist: {}'.format(np.mean(dists_epoch[:, epoch])))\n",
        "    print('Mean mse: {}'.format(np.mean(mses_epoch[:, epoch])))\n",
        "    print('Accuracy shapes: {}'.format(np.mean(acc_shapes_epoch[:, epoch])))\n",
        "    print('Accuracy colors: {}'.format(np.mean(acc_colors_epoch[:, epoch])))\n",
        "    \n",
        "    print('--------------- TEST ----------------')\n",
        "    print('Flipped {} % of all elements'.format(np.mean(flipped_test[:, epoch]) * 100.))\n",
        "    print('Mean IOU: {}'.format(np.mean(ious_test_epoch[:, epoch])))\n",
        "    print('Mean dist: {}'.format(np.mean(dists_test_epoch[:, epoch])))\n",
        "    print('Mean mse: {}'.format(np.mean(mses_test_epoch[:, epoch])))\n",
        "    print('Accuracy shapes: {}'.format(np.mean(acc_shapes_test_epoch[:, epoch])))\n",
        "    print('Accuracy colors: {}'.format(np.mean(acc_colors_test_epoch[:, epoch])))\n",
        "    print()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0\n",
            "25/25 - 33s - loss: 0.2424 - val_loss: 0.2394\n",
            "Flipped 0.0 % of all elements\n",
            "Mean IOU: 0.0\n",
            "Mean dist: 0.5303613788074256\n",
            "Mean mse: 0.1333769995181736\n",
            "Accuracy shapes: 0.344375\n",
            "Accuracy colors: 0.31\n",
            "--------------- TEST ----------------\n",
            "Flipped 49.0 % of all elements\n",
            "Mean IOU: 0.0\n",
            "Mean dist: 0.5323391392244374\n",
            "Mean mse: 0.13168510682584525\n",
            "Accuracy shapes: 0.31\n",
            "Accuracy colors: 0.325\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 949
        },
        "id": "avZAedkX_NfQ",
        "outputId": "e8ab82f5-0c90-47b7-85e7-192995c4b962"
      },
      "source": [
        "# model.layers\n",
        "weights = model.layers[0].get_weights()[0]\n",
        "weights = weights.transpose(3, 0, 1, 2)\n",
        "print(weights.shape)\n",
        "# plt.imshow(weights[0] * 255. + 128., interpolation='none', origin='lower')\n",
        "print(np.mean(weights[0]), np.std(weights[0]), np.min(weights[0]), np.max(weights[0]))\n",
        "adj_weights = (weights * 255.) + 128.\n",
        "print(np.mean(adj_weights[0]), np.std(adj_weights[0]), np.min(adj_weights[0]), np.max(adj_weights[0]))\n",
        "plt.figure(figsize=(16, 8))\n",
        "for i in range(24):\n",
        "    plt.subplot(4, 6, i+1)\n",
        "    plt.imshow(adj_weights[i, :, :], interpolation='none', origin='lower', cmap='Greys')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "(32, 5, 5, 3)\n",
            "0.0071782973 0.045646656 -0.08047159 0.08199458\n",
            "129.83046 11.639896 107.479744 148.90862\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n",
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "\n",
            "text/plain": [
              "<Figure size 1152x576 with 24 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 300
        },
        "id": "gs3CoEUQAAER",
        "outputId": "6dce11dc-c057-4eda-86dd-4518f66699cc"
      },
      "source": [
        "# Draw the first 1000 rows of `flipped` as a greyscale grid. Going by the axis\n",
        "# labels, each column is an epoch and each row is a training sample.\n",
        "plt.xlabel('Epoch')\n",
        "plt.pcolor(flipped[:1000], vmax=1., cmap='Greys')\n",
        "# Optional marker, currently disabled:\n",
        "# plt.axvline(num_epochs_flipping, c='r')\n",
        "plt.ylabel('Training sample')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Text(0, 0.5, 'Training sample')"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAEKCAYAAADuEgmxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXq0lEQVR4nO3de7RedX3n8fdHAoiggBAzTAKFqVGHsVXxaLG6OhWqS2gXwalFGFsjK9PMsqhYHUc6M6s61q6lM44X2i40FjVYFZFqiZZamYg6zggahCIXHSNySeQSFQKKoOB3/nh+mTzEXH7nJPucJ8n7tdaznt/+7cv5nr2SfLJvv52qQpKkHo+a6wIkSbsPQ0OS1M3QkCR1MzQkSd0MDUlSN0NDktRt0NBI8sdJrk9yXZKPJXl0kmOSXJlkbZKPJ9mvLbt/m17b5h89ZG2SpOkbLDSSLAReA0xV1VOBfYDTgbcD76qqJwJ3A8vaKsuAu1v/u9pykqQJMvTpqXnAAUnmAY8BbgdOAC5u81cCp7b2kjZNm39ikgxcnyRpGuYNteGqWp/kHcCtwE+AzwFXAfdU1UNtsXXAwtZeCNzW1n0oyUbgMOD749tNshxYDnDggQc+8ylPecpQv4Ik7ZGuuuqq71fV/JmsO1hoJDmU0dHDMcA9wCeAF+3sdqtqBbACYGpqqtasWbOzm5SkvUqSW2a67pCnp34L+G5VbaiqnwGfBJ4LHNJOVwEsAta39nrgSIA2/2DgBwPWJ0mapiFD41bg+CSPadcmTgRuAC4HXtKWWQpc0tqr2jRt/ufL0RQlaaIMFhpVdSWjC9pfB77RftYK4I3A65KsZXTN4vy2yvnAYa3/dcA5Q9UmSZqZ7M7/mfeahiRNX5KrqmpqJuv6RLgkqZuhIUnqZmhIkroZGpKkboaGJKmboSFJ6mZoSJK6GRqSpG6GhiSpm6EhSepmaEiSuhkakqRuhoYkqZuhIUnqZmhIkroZGpKkboaGJKmboSFJ6jZYaCR5cpJrxj73JnltkscnuSzJt9v3oW35JDk3ydok1yY5bqjaJEkzM1hoVNW3qurpVfV04JnA/cCngHOA1VW1GFjdpgFOAha3z3LgvKFqkyTNzGydnjoR+E5V3QIsAVa2/pXAqa29BLigRq4ADklyxCzVJ0nqMFuhcTrwsdZeUFW3t/YdwILWXgjcNrbOutYnSZoQg4dGkv2AU4BPbDmvqgqoaW5veZI1SdZs2LBhF1UpSeoxG0caJwFfr6o72/Sdm047te+7Wv964Mix9Ra1vkeoqhVVNVVVU/Pnzx+wbEnSlmYjNM5g86kpgFXA0tZeClwy1v/ydhfV8cDGsdNYkqQJMG/IjSc5EHgB8O/Hut8GXJRkGXALcFrrvxQ4GVjL6E6rM4esTZI0fYOGRlX9GDhsi74fMLqbastlCzhryHokSTvHJ8IlSd0MDUlSN0NDktTN0JAkdTM0JEndDA1JUjdDQ5LUzdCQJHUzNCRJ3QwNSVI3Q0OS1M3QkCR1MzQkSd0MDUlSN0NDktTN0JAkdTM0JEndDA1JUjdDQ5LUbdDQSHJIkouTfDPJjUmek+TxSS5L8u32fWhbNknOTbI2ybVJjhuyNknS9A19pPEe4LNV9RTgacCNwDnA6qpaDKxu0wAnAYvbZzlw3sC1SZKmabDQSHIw8BvA+QBV9dOqugdYAqxsi60ETm3tJcAFNXIFcEiSI4aqT5I0fUMeaRwDbAA+mOTqJH+d5EBgQVXd3pa5A1jQ2guB28bWX9f6HiHJ8iRrkqzZsGHDgOVLkrY0ZGjMA44DzquqZwA/ZvOpKACqqoCazkarakVVTVXV1Pz583dZsZKkHRsyNNYB66rqyjZ9MaMQuXPTaaf2fVebvx44cmz9Ra1PkjQhBguNqroDuC3Jk1vXicANwCpgaetbClzS2quAl7e7qI4HNo6dxpIkTYB5A2
//1cBHkuwH3AScySioLkqyDLgFOK0teylwMrAWuL8tK0maIIOGRlVdA0xtZdaJW1m2gLOGrEeStHN8IlyS1M3QkCR1MzQkSd0MDUlSN0NDktTN0JAkdTM0JEndDA1JUjdDQ5LUzdCQJHXrCo0kz0tyZmvPT3LMsGVJkibRDkMjyZuANwJ/0rr2Bf5myKIkSZOp50jjxcApjF6iRFV9D3jskEVJkiZTT2j8dPwNe+2VrZKkvVBPaFyU5H3AIUn+EPifwPuHLUuSNIl2+D6NqnpHkhcA9wJPBv60qi4bvDJJ0sTpeglTCwmDQpL2ctsMjST30a5jbDmL0Yv2HjdYVZKkibTN0Kiqnb5DKsnNwH3Aw8BDVTWV5PHAx4GjgZuB06rq7iQB3sPoPeH3A6+oqq/vbA2SpF2n9+G+45K8Jsmrkzxjmj/j+VX19Kra9K7wc4DVVbUYWN2mAU4CFrfPcuC8af4cSdLAeh7u+1NgJXAYcDjwoST/ZSd+5pK2Pdr3qWP9F9TIFYzu1jpiJ36OJGkX67kQ/jLgaVX1AECStwHXAG/tWLeAzyUp4H1VtQJYUFW3t/l3AAtaeyFw29i661rf7WN9JFnO6EiEo446qqMESdKu0hMa3wMeDTzQpvcH1ndu/3lVtT7JE4DLknxzfGZVVQuUbi14VgBMTU1Na11J0s7pCY2NwPVJLmN05PAC4KtJzgWoqtdsa8WqWt++70ryKeDZwJ1Jjqiq29vpp7va4uuBI8dWX0R/OEmSZkFPaHyqfTb5Qs+G23Ajj6qq+1r7hcBbgFXAUuBt7fuStsoq4FVJLgR+Ddg4dhpLkjQBep4IX7mjZbZhAfCp0Z20zAM+WlWfTfI1RkOTLANuAU5ry1/K6HbbtYxuuT1zhj9XkjSQHYZGkt8B/gz4pbZ818N9VXUT8LSt9P8AOHEr/QWc1Ve2JGku9Jyeejfwb4BvtH/YJUl7qZ6H+24DrjMwJEk9Rxr/Ebg0yReBBzd1VtU7B6tKkjSRekLjz4EfMXpWY79hy5EkTbKe0PjnVfXUwSuRJE28nmsalyZ54eCVSJImXk9ovBL4bJKfJLk3yX1J7h26MEnS5Ol5uG+n36shSdozdL3uNcmhjN5z8ehNfVX1paGKkiRNpp4nwv8dcDajAQSvAY4HvgKcMGxpkqRJ03NN42zgWcAtVfV84BnAPYNWJUmaSD2h8cDYC5j2r6pvAk8etixJ0iTquaaxLskhwN8xepHS3YxGp5Uk7WV67p56cWu+OcnlwMHAZwetSpI0kXZ4eirJLyfZf9MkcDTwmCGLkiRNpp5rGn8LPJzkiYzezX0k8NFBq5IkTaSe0Ph5VT0EvBj4i6p6A3DEsGVJkiZRT2j8LMkZjN7n/ZnWt+9wJUmSJlVPaJwJPAf486r6bpJjgA/3/oAk+yS5Osln2vQxSa5MsjbJx5Ps1/r3b9Nr2/yjp//rSJKGtMPQqKobquo1VfWxNv3dqnr7NH7G2cCNY9NvB95VVU8E7gaWtf5lwN2t/11tOUnSBOk50pixJIuA3wb+uk2H0fAjF7dFVgKntvaSNk2bf2JbXpI0IQYNDeDdjF4X+/M2fRhwT7uwDrAOWNjaCxm9j5w2f2Nb/hGSLE+yJsmaDRs2DFm7JGkLg4VGkt8B7qqqq3bldqtqRVVNVdXU/Pnzd+WmJUk70DPK7aeB2qJ7I7AGeN+mcam24rnAKUlOZjSk+uOA9wCHJJnXjiYWAevb8usZPQOyLsk8Rk+e/2Cav48kaUA9Rxo3AT8C3t8+9wL3AU9q01tVVX9SVYuq6mjgdODzVfUy4HLgJW2xpcAlrb2qTdPmf76qtgwrSdIc6hmw8Ner6llj059O8rWqelaS62fwM98IXJjkrcDVwPmt/3zgw0nWAj9kFDSSpAnSExoHJTmqqm4FSHIUcFCb99OeH1JVXwC+0No3Ac/eyjIPAL/Xsz1J0tzoCY3XA19O8h1GAxYeA/xRkgPZfIusJGkv0DM0+q
VJFgNPaV3fGrv4/e7BKpMkTZyeIw2AZzIaEn0e8LQkVNUFg1UlSZpIPbfcfhj4ZeAa4OHWXYChIUl7mZ4jjSngWG9/lST1PKdxHfDPhi5EkjT5eo40DgduSPJV4MFNnVV1ymBVSZImUk9ovHnoIiRJu4eeW26/OBuFSJIm3zZDI8mXq+p5Se7jkQMWBqiqetzg1UmSJso2Q6Oqnte+Hzt75UiSJlnXw31J9gEWjC+/aSwqSdLeo+fhvlcDbwLuZPMb+Ar41QHrkiRNoJ4jjbOBJ1eVL0SSpL1cz8N9tzF6U58kaS/Xc6RxE/CFJH/PIx/ue+dgVUmSJlJPaNzaPvu1jyRpL9XzcN9/nY1CJEmTb3sP9727ql6b5NM88uE+YMdjTyV5NPAlYP/2cy6uqjclOQa4EDgMuAr4g6r6aZL9GQ23/kzgB8BLq+rmmf1akqQhbO9I48Pt+x0z3PaDwAlV9aMk+zJ6Zew/AK8D3lVVFyZ5L7AMOK99311VT0xyOvB24KUz/NmSpAFs74nwq9r3jMaeau/f+FGb3Ld9CjgB+LetfyWjARHPA5aweXDEi4G/TBLf4yFJk2OHt9wmWZzk4iQ3JLlp06dn40n2SXINcBdwGfAd4J6qeqgtsg5Y2NoLGd3eS5u/kdEprC23uTzJmiRrNmzY0FOGJGkX6XlO44OMjgQeAp7P6LrD3/RsvKoerqqnA4uAZwNPmWGd49tcUVVTVTU1f/78nd2cJGkaekLjgKpaDaSqbqmqNwO/PZ0fUlX3AJcDzwEOSbLptNgiYH1rrweOBGjzD2Z0QVySNCF6QuPBJI8Cvp3kVUleDBy0o5WSzE9ySGsfALwAuJFReLykLbYUuKS1V7Vp2vzPez1DkiZL79hTjwFeA/wZo1NUS7e7xsgRwMo2Qu6jgIuq6jNJbgAuTPJW4Grg/Lb8+cCHk6wFfgicPq3fRJI0uO2GRvsH/6VV9R8Y3Ql1Zu+Gq+pa4Blb6b+J0fWNLfsfAH6vd/uSpNm3zdNTSeZV1cPA82axHknSBNvekcZXgeOAq5OsAj4B/HjTzKr65MC1SZImTM81jUczuovpBEYP56V9GxqStJfZXmg8IcnrgOvYHBabeFeTJO2Fthca+zC6tTZbmWdoSNJeaHuhcXtVvWXWKpEkTbztPdy3tSMMSdJebHuhceKsVSFJ2i1sMzSq6oezWYgkafL1jD0lSRJgaEiSpsHQkCR1MzQkSd0MDUlSN0NDktTN0JAkdTM0JEndDA1JUrfBQiPJkUkuT3JDkuuTnN36H5/ksiTfbt+Htv4kOTfJ2iTXJjluqNokSTMz5JHGQ8Drq+pY4HjgrCTHAucAq6tqMbC6TQOcBCxun+XAeQPWJkmagcFCo6pur6qvt/Z9wI3AQmAJsLItthI4tbWXABfUyBXAIUmOGKo+SdL0zco1jSRHA88ArgQWVNXtbdYdwILWXgjcNrbauta35baWJ1mTZM2GDRsGq1mS9IsGD40kBwF/C7y2qu4dn1dVxTTfAlhVK6pqqqqm5s+fvwsrlSTtyKChkWRfRoHxkar6ZOu+c9Npp/Z9V+tfDxw5tvqi1idJmhBD3j0V4Hzgxqp659isVcDS1l4KXDLW//J2F9XxwMax01iSpAmwvXeE76znAn8AfCPJNa3vPwFvAy5Ksgy4BTitzbsUOBlYC9wPnDlgbZKkGRgsNKrqy2z7PeO/8CrZdn3jrKHqkSTtPJ8IlyR1MzQkSd0MDUlSN0NDktTN0JAkdTM0JEndDA1JUjdDQ5LUzdCQJHUzNCRJ3QwNSVI3Q0OS1M3QkCR1MzQkSd0MDUlSN0NDktTN0JAkdTM0JEndBguNJB9IcleS68b6Hp/ksiTfbt+Htv4kOTfJ2iTXJjluqLokSTM35JHGh4AXbdF3DrC6qhYDq9s0wEnA4vZZDpw3YF2SpBkaLDSq6kvAD7foXgKsbO2VwKlj/RfUyBXAIUmOGKo2SdLMzPY1jQVVdXtr3wEsaO
2FwG1jy61rfZKkCTJnF8KrqoCa7npJlidZk2TNhg0bBqhMkrQtsx0ad2467dS+72r964Ejx5Zb1Pp+QVWtqKqpqpqaP3/+oMVKkh5ptkNjFbC0tZcCl4z1v7zdRXU8sHHsNJYkaULMG2rDST4G/CZweJJ1wJuAtwEXJVkG3AKc1ha/FDgZWAvcD5w5VF2SpJkbLDSq6oxtzDpxK8sWcNZQtUiSdg2fCJckdTM0JEndDA1JUjdDQ5LUzdCQJHUzNCRJ3QwNSVI3Q0OS1M3QkCR1MzQkSd0MDUlSN0NDktTN0JAkdTM0JEndDA1JUjdDQ5LUzdCQJHUzNCRJ3QwNSVK3iQqNJC9K8q0ka5OcM9f1SJIeaWJCI8k+wF8BJwHHAmckOXZuq5IkjZuY0ACeDaytqpuq6qfAhcCSOa5JkjRm3lwXMGYhcNvY9Drg17ZcKMlyYHmbfDDJdbNQ2+7gcOD7c13EhHBfbOa+2Mx9sdmTZ7riJIVGl6paAawASLKmqqbmuKSJ4L7YzH2xmftiM/fFZknWzHTdSTo9tR44cmx6UeuTJE2ISQqNrwGLkxyTZD/gdGDVHNckSRozMaenquqhJK8C/hHYB/hAVV2/g9VWDF/ZbsN9sZn7YjP3xWbui81mvC9SVbuyEEnSHmySTk9JkiacoSFJ6rZbhMaOhhdJsn+Sj7f5VyY5evarnB0d++J1SW5Icm2S1Ul+aS7qnA29w84k+d0klWSPvd2yZ18kOa392bg+yUdnu8bZ0vF35Kgklye5uv09OXku6hxakg8kuWtbz7Jl5Ny2n65NclzXhqtqoj+MLop/B/gXwH7APwHHbrHMHwHvbe3TgY/Pdd1zuC+eDzymtV+5N++LttxjgS8BVwBTc133HP65WAxcDRzapp8w13XP4b5YAbyytY8Fbp7rugfaF78BHAdct435JwP/AAQ4HriyZ7u7w5FGz/AiS4CVrX0xcGKSzGKNs2WH+6KqLq+q+9vkFYyed9kT9Q4782fA24EHZrO4WdazL/4Q+Kuquhugqu6a5RpnS8++KOBxrX0w8L1ZrG/WVNWXgB9uZ5ElwAU1cgVwSJIjdrTd3SE0tja8yMJtLVNVDwEbgcNmpbrZ1bMvxi1j9D+JPdEO90U73D6yqv5+NgubAz1/Lp4EPCnJ/05yRZIXzVp1s6tnX7wZ+P0k64BLgVfPTmkTZ7r/ngAT9JyGdq0kvw9MAf96rmuZC0keBbwTeMUclzIp5jE6RfWbjI4+v5TkV6rqnjmtam6cAXyoqv5HkucAH07y1Kr6+VwXtjvYHY40eoYX+f/LJJnH6JDzB7NS3ezqGmolyW8B/xk4paoenKXaZtuO9sVjgacCX0hyM6Nztqv20IvhPX8u1gGrqupnVfVd4P8yCpE9Tc++WAZcBFBVXwEezWgww73NjIZu2h1Co2d4kVXA0tZ+CfD5ald69jA73BdJngG8j1Fg7KnnrWEH+6KqNlbV4VV1dFUdzej6zilVNeOB2iZYz9+Rv2N0lEGSwxmdrrppNoucJT374lbgRIAk/5JRaGyY1Sonwyrg5e0uquOBjVV1+45WmvjTU7WN4UWSvAVYU1WrgPMZHWKuZXTh5/S5q3g4nfvivwMHAZ9o9wLcWlWnzFnRA+ncF3uFzn3xj8ALk9wAPAy8oar2uKPxzn3xeuD9Sf6Y0UXxV+yJ/8lM8jFG/1E4vF2/eROwL0BVvZfR9ZyTgbXA/cCZXdvdA/eVJGkgu8PpKUnShDA0JEndDA1JUjdDQ5LUzdCQJHUzNKTtSPJwkmvGPtscTXcG2z56WyOQSpNq4p/TkObYT6rq6XNdhDQpPNKQZiDJzUn+W5JvJPlqkie2/qOTfH7sfSZHtf4FST6V5J/a59fbpvZJ8v72jovPJTlgzn4pqYOhIW3fAVucnnrp2LyNVfUrwF8C7259fwGsrKpfBT4CnNv6zwW+WFVPY/SOg+tb/2JGQ5b/K+Ae4HcH/n2kneIT4dJ2JPlRVR
20lf6bgROq6qYk+wJ3VNVhSb4PHFFVP2v9t1fV4Uk2AIvGB5DM6A2Tl1XV4jb9RmDfqnrr8L+ZNDMeaUgzV9toT8f4KMQP43VGTThDQ5q5l459f6W1/w+bB8x8GfC/Wns1o9fvkmSfJAfPVpHSruT/aqTtOyDJNWPTn62qTbfdHprkWkZHC2e0vlcDH0zyBkbDbW8aOfRsYEWSZYyOKF4J7HAYamnSeE1DmoF2TWOqqr4/17VIs8nTU5Kkbh5pSJK6eaQhSepmaEiSuhkakqRuhoYkqZuhIUnq9v8Aa20lzV+bDIoAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3G7WkFCiADlB",
        "outputId": "161e76b1-f215-4cdb-ce88-7eb8564648e1"
      },
      "source": [
        "\n",
        "# Predict on the test inputs, then regroup the flat model output so that\n",
        "# axis 1 indexes the objects of an image and the last axis holds one object's values.\n",
        "pred_y = model.predict(test_X)\n",
        "pred_y = pred_y.reshape((pred_y.shape[0], num_objects, -1))\n",
        "# Layout of the last axis: 4 bbox values, then num_shapes shape scores,\n",
        "# then num_colors colour scores.\n",
        "pred_bboxes = img_size * pred_y[..., :4]  # scale the bbox values by the image size\n",
        "# Class index = position of the largest value inside each score slice.\n",
        "pred_shapes = pred_y[..., 4:4 + num_shapes].argmax(axis=-1).astype(int)\n",
        "pred_colors = pred_y[..., 4 + num_shapes:4 + num_shapes + num_colors].argmax(axis=-1).astype(int)\n",
        "# Sanity check: display the shapes of the three decoded arrays.\n",
        "(pred_bboxes.shape, pred_shapes.shape, pred_colors.shape)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "((200, 2, 4), (200, 2), (200, 2))"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "p0NOPbY8AKBS",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 479
        },
        "outputId": "39250029-6816-45f6-d05d-3e2f7f9d174d"
      },
      "source": [
        "# Show 8 randomly drawn test images in a 2x4 grid, overlaying the predicted\n",
        "# bounding box and label of every object.\n",
        "plt.figure(figsize=(16, 8))\n",
        "for panel in range(8):\n",
        "    panel_ax = plt.subplot(2, 4, panel + 1)\n",
        "    i = np.random.randint(len(test_X))  # random test index (may repeat across panels)\n",
        "    plt.imshow(test_imgs[i], extent=[0, img_size, 0, img_size], origin='lower', interpolation='none')\n",
        "    for box, shape_idx, color_idx in zip(pred_bboxes[i], pred_shapes[i], pred_colors[i]):\n",
        "        # The four box values are used as Rectangle's anchor (x, y), width and height.\n",
        "        box_x, box_y, box_w, box_h = box\n",
        "        panel_ax.add_patch(matplotlib.patches.Rectangle((box_x, box_y), box_w, box_h, ec='k', fc='none'))\n",
        "        # Shape label placed slightly above the box, text coloured by the predicted colour label.\n",
        "        panel_ax.annotate(shape_labels[shape_idx], (box_x, box_y + box_h + 0.7), color=color_labels[color_idx], clip_on=False, bbox={'fc': 'w', 'ec': 'none', 'pad': 1, 'alpha': 0.6})"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "\n",
            "text/plain": [
              "<Figure size 1152x576 with 8 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    }
  ]
}