{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Intro"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Very basic CNN for recognizing hand-drawn images of various objects to which Gaussian noise was added.  \n",
"One class of objects contains only noise.  \n",
"The dataset is no longer available, but the CNN results can still be seen below.  \n",
"tl;dr: ~60% accuracy (disappointing). Noise reduction & feature extraction would probably help."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "d6f97c979edf53a5ac054750fa6484aa5b318ee3",
"colab": {},
"colab_type": "code",
"id": "vNt7fSoZWlyQ"
},
"outputs": [],
"source": [
"from sklearn import preprocessing\n",
"from sklearn.model_selection import train_test_split\n",
"import numpy as np \n",
"import pickle\n",
"import keras\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.utils.np_utils import to_categorical\n",
"from keras import models, layers\n",
"import os\n",
"import pandas as pd"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "dd7df30613f7ab2bcf548945d2b4e68657c46dd0",
"colab_type": "text",
"id": "Y9I0F6PJNBCs"
},
"source": [
"## 1. Data loading & preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "5f946a240811336d11101c9d4ce7b464b85f9f28",
"colab_type": "text",
"id": "WQX_NWufQUio"
},
"source": [
"### Set seed for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "b3b10adca9fe738d1e58e419d785c2eb6c448a19",
"colab": {},
"colab_type": "code",
"id": "wj6dCMOKDFbA"
},
"outputs": [],
"source": [
"# Single seed reused by NumPy, the train/validation split and the data generator\n",
"SEED = 123\n",
"np.random.seed(seed=SEED)"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "86bafbf9544b79ed558cfee3d1c55bd12a606946",
"colab_type": "text",
"id": "Cw6tQyE3Qc9s"
},
"source": [
"### Load data"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "28291c6eb3b5ef980402077ed6e5e9ea1d2cf055",
"colab": {},
"colab_type": "code",
"id": "I_Iz0td90lMQ"
},
"outputs": [],
"source": [
"# Directory the input files are read from and generated artifacts are written to\n",
"DIR = \"input\""
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "f57ad597b449cdd125949f50c81bfdc15b4ed3be",
"colab": {},
"colab_type": "code",
"id": "056xpFX4XAfo"
},
"outputs": [],
"source": [
"# The .npy files appear to hold pickled object arrays (the latin1 encoding only\n",
"# matters for pickles). NumPy >= 1.16.3 refuses to load those unless\n",
"# allow_pickle=True is passed explicitly.\n",
"# NOTE: unpickling can execute arbitrary code -- only load files you trust.\n",
"train_raw = np.load(os.path.join(DIR, \"train_images.npy\"), encoding=\"latin1\", allow_pickle=True)\n",
"test_raw = np.load(os.path.join(DIR, \"test_images.npy\"), encoding=\"latin1\", allow_pickle=True)\n",
"train_labels = np.genfromtxt(os.path.join(DIR, \"train_labels.csv\"),\n",
"                             names=True,\n",
"                             delimiter=\",\",\n",
"                             dtype=[(\"Id\", \"i8\"), (\"Category\", \"S15\")])\n",
"\n",
"# Element 1 of every record is kept as the image data. The raw arrays live\n",
"# under their own names so that re-running this cell stays safe.\n",
"X_train = np.array([record[1] for record in train_raw])\n",
"X_test = np.array([record[1] for record in test_raw])"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "815617a18fa250bafa9d0494145cfec3b98dce18",
"colab_type": "text",
"id": "xJpTJ4AqXM0s"
},
"source": [
"### Numerically-encode image labels"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "7b4aa597d5ff7a5e87a1d03abeb76942e7546062",
"colab": {},
"colab_type": "code",
"id": "fnNWfdDAXN0w"
},
"outputs": [],
"source": [
"# Map every category name to an integer class index\n",
"le = preprocessing.LabelEncoder()\n",
"y = le.fit_transform(train_labels[\"Category\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "afcfbad1573f6f5973d7994fc4ae046747e94023",
"colab_type": "text",
"id": "_jfpAOcWAxIY"
},
"source": [
"### Normalize data"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "5c675f98cf00e66b77a4e806642916a073c5d8f8",
"colab": {},
"colab_type": "code",
"id": "JK1-Zjx3AwIU"
},
"outputs": [],
"source": [
"# Rescale pixel values so that the training maximum becomes 1.\n",
"# The same divisor is applied to the test images so both sets share one scale,\n",
"# and out-of-place division is used because in-place `/=` raises on integer arrays.\n",
"pixel_max = X_train.max()\n",
"X_train = X_train / pixel_max\n",
"X_test = X_test / pixel_max"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "87e543d7e075f600465d6a0880495425453d04a4",
"colab_type": "text",
"id": "OUvmY7oHDTEc"
},
"source": [
"### Split into training and validation sets"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "55e8052e81d553b693e35f83a716b28faa3635dc",
"colab": {},
"colab_type": "code",
"id": "-E_MJHuKDVlY"
},
"outputs": [],
"source": [
"# Hold out 25% of the labelled images for validation (the library default, made\n",
"# explicit here) and stratify on the encoded labels so that both subsets keep\n",
"# the same class proportions.\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
"    X_train,\n",
"    y,\n",
"    test_size=0.25,\n",
"    random_state=SEED,\n",
"    stratify=y,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "90f146a26f099ab26148465d8337c5edc08e0a17",
"colab_type": "text",
"id": "SRqpP7BT6RC0"
},
"source": [
"### Reshape data for keras"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "8701036c065c956b55ad352ee989b6c5b0e2f700",
"colab": {},
"colab_type": "code",
"id": "NZ8IbbJIF4X4"
},
"outputs": [],
"source": [
"# Append a single-channel axis so every set has shape (samples, 100, 100, 1)\n",
"X_train = X_train.reshape((len(X_train), 100, 100, 1))\n",
"X_valid = X_valid.reshape((len(X_valid), 100, 100, 1))\n",
"X_test = X_test.reshape((len(X_test), 100, 100, 1))\n",
"\n",
"# Turn the integer class indices into one-hot rows, one column per class\n",
"n_classes = len(np.unique(y))\n",
"y_train = to_categorical(y_train, num_classes=n_classes)\n",
"y_valid = to_categorical(y_valid, num_classes=n_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "fee6f6c8bc1272d8aa221ec3e03364699d46c8ff",
"colab_type": "text",
"id": "hoyrXviWXChU"
},
"source": [
"### Augment data through transformations"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "10ed402c06d61694b9ca603b122b25769b0d1b43",
"colab": {},
"colab_type": "code",
"id": "dbyRs1olgKO8"
},
"outputs": [],
"source": [
"# Number of samples per batch drawn from the augmenting generator during training\n",
"BATCH_SIZE = 32"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "cf24b6b012702f22920da760ce18e2ba00f678d0",
"colab": {},
"colab_type": "code",
"id": "pxVYEpMkK_Lw"
},
"outputs": [],
"source": [
"# Re-use a previously pickled generator when one exists, otherwise build it.\n",
"# NOTE: pickle.load can execute arbitrary code -- only load a file written by this notebook.\n",
"path_to_datagen = os.path.join(DIR, \"datagen\")\n",
"\n",
"try:\n",
"    with open(path_to_datagen, \"rb\") as f:\n",
"        datagen = pickle.load(f)\n",
"\n",
"except FileNotFoundError:\n",
"    # Create datagen: random rotations drawn from [-30, 30] degrees plus random horizontal flips\n",
"    datagen = ImageDataGenerator(rotation_range=30,\n",
"                                 horizontal_flip=True,\n",
"                                 )\n",
"    datagen.fit(X_train,\n",
"                seed=SEED\n",
"                )"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "c32357a210996697c593ecc1081a64b1fb08a65f",
"colab": {},
"colab_type": "code",
"id": "tGNR_Mu4rNG0"
},
"outputs": [],
"source": [
"# Pickle the generator so that a later run can load it instead of rebuilding it\n",
"path_to_datagen = os.path.join(DIR, \"datagen\")\n",
"with open(path_to_datagen, mode=\"wb\") as datagen_file:\n",
"    pickle.dump(datagen, datagen_file)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "82b8eda4c60322dd74a3d084320c00ce6d8c1d6a",
"colab": {},
"colab_type": "code",
"id": "tLhID8jIgPT8"
},
"outputs": [],
"source": [
"# Iterator that yields randomly augmented (images, labels) batches from the training set\n",
"X_train_transformed = datagen.flow(x=X_train,\n",
"                                   y=y_train,\n",
"                                   batch_size=BATCH_SIZE,\n",
"                                   seed=SEED)"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "d4b2b1880306eea4d0b37db83da596f17671343f",
"colab_type": "text",
"id": "vJzd2-ddsWVs"
},
"source": [
"## 2. Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "53cf761fa51b0dd1d719e534753528043f9b3a6f",
"colab_type": "text",
"id": "Xl2y_o6V2jJc"
},
"source": [
"### CNN"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"_uuid": "756ef2d4dbdab8ed73c05869d8adaf38b4bc4275",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 2622
},
"colab_type": "code",
"id": "G0E_T-Aa2iSw",
"outputId": "89adf3b7-44e3-42bc-c570-fa26a6d27cf1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/75\n",
" - 26s - loss: 3.4000 - acc: 0.0466 - val_loss: 3.3884 - val_acc: 0.0572\n",
"Epoch 2/75\n",
" - 22s - loss: 3.3901 - acc: 0.0523 - val_loss: 3.3792 - val_acc: 0.0572\n",
"Epoch 3/75\n",
" - 22s - loss: 3.3823 - acc: 0.0559 - val_loss: 3.3660 - val_acc: 0.0564\n",
"Epoch 4/75\n",
" - 22s - loss: 3.3435 - acc: 0.0665 - val_loss: 3.2844 - val_acc: 0.0656\n",
"Epoch 5/75\n",
" - 22s - loss: 3.2781 - acc: 0.0815 - val_loss: 3.3462 - val_acc: 0.0668\n",
"Epoch 6/75\n",
" - 22s - loss: 3.2233 - acc: 0.0904 - val_loss: 3.1801 - val_acc: 0.1020\n",
"Epoch 7/75\n",
" - 22s - loss: 3.1760 - acc: 0.0967 - val_loss: 3.1188 - val_acc: 0.1208\n",
"Epoch 8/75\n",
" - 22s - loss: 3.1446 - acc: 0.1009 - val_loss: 3.0972 - val_acc: 0.1076\n",
"Epoch 9/75\n",
" - 22s - loss: 3.1117 - acc: 0.1110 - val_loss: 3.1106 - val_acc: 0.1064\n",
"Epoch 10/75\n",
" - 22s - loss: 3.0848 - acc: 0.1131 - val_loss: 3.0844 - val_acc: 0.1044\n",
"Epoch 11/75\n",
" - 22s - loss: 3.0702 - acc: 0.1222 - val_loss: 3.0531 - val_acc: 0.1324\n",
"Epoch 12/75\n",
" - 22s - loss: 3.0029 - acc: 0.1491 - val_loss: 2.8860 - val_acc: 0.1812\n",
"Epoch 13/75\n",
" - 22s - loss: 2.8905 - acc: 0.1706 - val_loss: 2.9766 - val_acc: 0.1352\n",
"Epoch 14/75\n",
" - 22s - loss: 2.8030 - acc: 0.1952 - val_loss: 2.7868 - val_acc: 0.1740\n",
"Epoch 15/75\n",
" - 22s - loss: 2.7497 - acc: 0.2042 - val_loss: 2.6058 - val_acc: 0.2284\n",
"Epoch 16/75\n",
" - 22s - loss: 2.6533 - acc: 0.2350 - val_loss: 2.5227 - val_acc: 0.2496\n",
"Epoch 17/75\n",
" - 22s - loss: 2.5674 - acc: 0.2492 - val_loss: 2.4525 - val_acc: 0.2632\n",
"Epoch 18/75\n",
" - 22s - loss: 2.5015 - acc: 0.2664 - val_loss: 2.3620 - val_acc: 0.2940\n",
"Epoch 19/75\n",
" - 22s - loss: 2.4456 - acc: 0.2780 - val_loss: 2.3450 - val_acc: 0.2964\n",
"Epoch 20/75\n",
" - 22s - loss: 2.3751 - acc: 0.3027 - val_loss: 2.3370 - val_acc: 0.3136\n",
"Epoch 21/75\n",
" - 22s - loss: 2.3572 - acc: 0.3093 - val_loss: 2.4342 - val_acc: 0.2752\n",
"Epoch 22/75\n",
" - 22s - loss: 2.2875 - acc: 0.3326 - val_loss: 2.2770 - val_acc: 0.3328\n",
"Epoch 23/75\n",
" - 22s - loss: 2.2602 - acc: 0.3372 - val_loss: 2.1578 - val_acc: 0.3600\n",
"Epoch 24/75\n",
" - 22s - loss: 2.2152 - acc: 0.3453 - val_loss: 2.4199 - val_acc: 0.2872\n",
"Epoch 25/75\n",
" - 22s - loss: 2.1619 - acc: 0.3664 - val_loss: 2.2176 - val_acc: 0.3432\n",
"Epoch 26/75\n",
" - 22s - loss: 2.1415 - acc: 0.3729 - val_loss: 2.1399 - val_acc: 0.3716\n",
"Epoch 27/75\n",
" - 22s - loss: 2.1027 - acc: 0.3872 - val_loss: 2.2814 - val_acc: 0.3300\n",
"Epoch 28/75\n",
" - 22s - loss: 2.0819 - acc: 0.3923 - val_loss: 1.9929 - val_acc: 0.4116\n",
"Epoch 29/75\n",
" - 22s - loss: 2.0310 - acc: 0.4072 - val_loss: 2.0946 - val_acc: 0.3884\n",
"Epoch 30/75\n",
" - 22s - loss: 2.0042 - acc: 0.4126 - val_loss: 2.0247 - val_acc: 0.4092\n",
"Epoch 31/75\n",
" - 22s - loss: 1.9849 - acc: 0.4216 - val_loss: 1.9591 - val_acc: 0.4432\n",
"Epoch 32/75\n",
" - 22s - loss: 1.9260 - acc: 0.4357 - val_loss: 1.9129 - val_acc: 0.4548\n",
"Epoch 33/75\n",
" - 22s - loss: 1.8921 - acc: 0.4393 - val_loss: 1.9010 - val_acc: 0.4500\n",
"Epoch 34/75\n",
" - 22s - loss: 1.8885 - acc: 0.4521 - val_loss: 1.8530 - val_acc: 0.4688\n",
"Epoch 35/75\n",
" - 22s - loss: 1.8620 - acc: 0.4527 - val_loss: 1.9407 - val_acc: 0.4548\n",
"Epoch 36/75\n",
" - 22s - loss: 1.8289 - acc: 0.4702 - val_loss: 2.0116 - val_acc: 0.4160\n",
"Epoch 37/75\n",
" - 22s - loss: 1.8066 - acc: 0.4765 - val_loss: 1.8501 - val_acc: 0.4856\n",
"Epoch 38/75\n",
" - 22s - loss: 1.7652 - acc: 0.4809 - val_loss: 1.9134 - val_acc: 0.4660\n",
"Epoch 39/75\n",
" - 22s - loss: 1.7429 - acc: 0.4897 - val_loss: 1.7785 - val_acc: 0.4920\n",
"Epoch 40/75\n",
" - 22s - loss: 1.7355 - acc: 0.5036 - val_loss: 1.7783 - val_acc: 0.4848\n",
"Epoch 41/75\n",
" - 22s - loss: 1.7108 - acc: 0.4957 - val_loss: 1.8340 - val_acc: 0.4696\n",
"Epoch 42/75\n",
" - 22s - loss: 1.7125 - acc: 0.5056 - val_loss: 1.7989 - val_acc: 0.4728\n",
"Epoch 43/75\n",
" - 22s - loss: 1.6727 - acc: 0.5118 - val_loss: 2.1937 - val_acc: 0.3636\n",
"Epoch 44/75\n",
" - 22s - loss: 1.6641 - acc: 0.5199 - val_loss: 1.7716 - val_acc: 0.4936\n",
"Epoch 45/75\n",
" - 22s - loss: 1.6244 - acc: 0.5279 - val_loss: 1.7167 - val_acc: 0.5192\n",
"Epoch 46/75\n",
" - 22s - loss: 1.6108 - acc: 0.5368 - val_loss: 1.7427 - val_acc: 0.5276\n",
"Epoch 47/75\n",
" - 22s - loss: 1.5840 - acc: 0.5369 - val_loss: 1.7339 - val_acc: 0.5172\n",
"Epoch 48/75\n",
" - 22s - loss: 1.5703 - acc: 0.5364 - val_loss: 1.8054 - val_acc: 0.4888\n",
"Epoch 49/75\n",
" - 22s - loss: 1.5661 - acc: 0.5410 - val_loss: 1.8507 - val_acc: 0.4984\n",
"Epoch 50/75\n",
" - 22s - loss: 1.5360 - acc: 0.5529 - val_loss: 1.6361 - val_acc: 0.5548\n",
"Epoch 51/75\n",
" - 22s - loss: 1.5206 - acc: 0.5584 - val_loss: 1.6115 - val_acc: 0.5492\n",
"Epoch 52/75\n",
" - 22s - loss: 1.5119 - acc: 0.5608 - val_loss: 1.6899 - val_acc: 0.5492\n",
"Epoch 53/75\n",
" - 22s - loss: 1.5150 - acc: 0.5635 - val_loss: 1.6747 - val_acc: 0.5476\n",
"Epoch 54/75\n",
" - 22s - loss: 1.4974 - acc: 0.5636 - val_loss: 1.5890 - val_acc: 0.5548\n",
"Epoch 55/75\n",
" - 22s - loss: 1.4784 - acc: 0.5718 - val_loss: 1.6427 - val_acc: 0.5472\n",
"Epoch 56/75\n",
" - 22s - loss: 1.4665 - acc: 0.5738 - val_loss: 1.6516 - val_acc: 0.5392\n",
"Epoch 57/75\n",
" - 22s - loss: 1.4500 - acc: 0.5793 - val_loss: 1.7005 - val_acc: 0.5352\n",
"Epoch 58/75\n",
" - 22s - loss: 1.4353 - acc: 0.5838 - val_loss: 1.6790 - val_acc: 0.5504\n",
"Epoch 59/75\n",
" - 22s - loss: 1.4281 - acc: 0.5857 - val_loss: 1.6089 - val_acc: 0.5472\n",
"Epoch 60/75\n",
" - 22s - loss: 1.4501 - acc: 0.5890 - val_loss: 1.6866 - val_acc: 0.5444\n",
"Epoch 61/75\n",
" - 22s - loss: 1.4025 - acc: 0.5885 - val_loss: 1.6510 - val_acc: 0.5336\n",
"Epoch 62/75\n",
" - 22s - loss: 1.3906 - acc: 0.6001 - val_loss: 1.6181 - val_acc: 0.5596\n",
"Epoch 63/75\n",
" - 22s - loss: 1.3800 - acc: 0.5954 - val_loss: 1.5918 - val_acc: 0.5600\n",
"Epoch 64/75\n",
" - 22s - loss: 1.3950 - acc: 0.6009 - val_loss: 1.6006 - val_acc: 0.5700\n",
"Epoch 65/75\n",
" - 22s - loss: 1.3817 - acc: 0.6016 - val_loss: 1.6468 - val_acc: 0.5716\n",
"Epoch 66/75\n",
" - 22s - loss: 1.3489 - acc: 0.6160 - val_loss: 1.7986 - val_acc: 0.5664\n",
"Epoch 67/75\n",
" - 22s - loss: 1.3506 - acc: 0.6128 - val_loss: 1.6805 - val_acc: 0.5588\n",
"Epoch 68/75\n",
" - 22s - loss: 1.3637 - acc: 0.6135 - val_loss: 1.6333 - val_acc: 0.5652\n",
"Epoch 69/75\n",
" - 22s - loss: 1.3152 - acc: 0.6181 - val_loss: 1.5735 - val_acc: 0.5900\n",
"Epoch 70/75\n",
" - 22s - loss: 1.3193 - acc: 0.6218 - val_loss: 1.5246 - val_acc: 0.5700\n",
"Epoch 71/75\n",
" - 22s - loss: 1.3229 - acc: 0.6225 - val_loss: 1.6264 - val_acc: 0.5720\n",
"Epoch 72/75\n",
" - 22s - loss: 1.2847 - acc: 0.6201 - val_loss: 1.5936 - val_acc: 0.5624\n",
"Epoch 73/75\n",
" - 22s - loss: 1.3131 - acc: 0.6262 - val_loss: 1.6841 - val_acc: 0.5512\n",
"Epoch 74/75\n",
" - 22s - loss: 1.3026 - acc: 0.6253 - val_loss: 1.5475 - val_acc: 0.5740\n",
"Epoch 75/75\n",
" - 22s - loss: 1.2590 - acc: 0.6329 - val_loss: 1.5790 - val_acc: 0.5704\n"
]
}
],
"source": [
"# Load CNN if already saved\n",
"path_to_cnn = os.path.join(DIR, \"cnn.h5\")\n",
"if os.path.isfile(path_to_cnn):\n",
"    cnn = models.load_model(path_to_cnn)\n",
"    # Nothing was trained in this run, so there is no training history to keep\n",
"    losses = None\n",
"\n",
"else:\n",
"    # Create CNN\n",
"    # Architecture: (Conv2D*2 -> MaxPool -> Dropout)*2 -> Flatten -> Dense(256) -> Dropout -> softmax probabilities\n",
"    cnn = models.Sequential([\n",
"        # Convolutional block 1\n",
"        layers.Conv2D(filters=32, kernel_size=(3, 3), activation=\"relu\", padding=\"same\",\n",
"                      input_shape=(100, 100, 1)),\n",
"        layers.Conv2D(filters=64, kernel_size=(3, 3), activation=\"relu\", padding=\"same\"),\n",
"        layers.MaxPooling2D(pool_size=(2, 2)),\n",
"        layers.Dropout(0.25),\n",
"        # Convolutional block 2\n",
"        layers.Conv2D(filters=64, kernel_size=(3, 3), activation=\"relu\", padding=\"same\"),\n",
"        layers.Conv2D(filters=64, kernel_size=(3, 3), activation=\"relu\", padding=\"same\"),\n",
"        layers.MaxPooling2D(pool_size=(2, 2)),\n",
"        layers.Dropout(0.25),\n",
"        # Classifier head\n",
"        layers.Flatten(),\n",
"        layers.Dense(256, activation=\"relu\"),\n",
"        layers.Dropout(0.5),\n",
"        layers.Dense(n_classes, activation=\"softmax\"),\n",
"    ])\n",
"\n",
"    cnn.compile(optimizer=keras.optimizers.Adadelta(),\n",
"                loss=\"categorical_crossentropy\",\n",
"                metrics=[\"accuracy\"])\n",
"\n",
"    # Fit CNN on the augmented batches; one epoch covers every full batch of X_train\n",
"    losses = cnn.fit_generator(X_train_transformed,\n",
"                               epochs=75,\n",
"                               steps_per_epoch=X_train.shape[0] // BATCH_SIZE,\n",
"                               verbose=2,\n",
"                               validation_data=(X_valid, y_valid),\n",
"                               )"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "8a7295de1d97cdc773aa58216751837e18fb72dc",
"colab": {},
"colab_type": "code",
"id": "nSqxmi2F5BU4"
},
"outputs": [],
"source": [
"# Store the model as cnn.h5 inside DIR so the cell above can load it next time\n",
"path_to_cnn = os.path.join(DIR, \"cnn.h5\")\n",
"cnn.save(path_to_cnn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"_uuid": "2731504a1df20287f591dd5736147349d7ad8859",
"colab_type": "text",
"id": "EpRcW6IrBvZE"
},
"source": [
"### Save test predictions"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"_uuid": "298a6edbcefb143bd7ecb6311bf9b29069ab6b0e",
"colab": {},
"colab_type": "code",
"id": "NDxlSqIY_LKI"
},
"outputs": [],
"source": [
"# Pick the most probable class index for every test image\n",
"predictions = cnn.predict(X_test).argmax(axis=1)\n",
"\n",
"# Translate class indices back into the original category names\n",
"predictions_string = le.inverse_transform(predictions).astype(\"U15\")\n",
"\n",
"# Write an (Id, Category) CSV in which the row position serves as the Id\n",
"submission = pd.DataFrame({\"Category\": predictions_string})\n",
"submission.index.name = \"Id\"\n",
"submission.to_csv(os.path.join(DIR, \"submission.csv\"))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"YNEEyLtrXYLQ",
"RlemX_aQ0D5Q"
],
"name": "CNN2.ipynb",
"provenance": [],
"toc_visible": true,
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}