{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xDSYr1OFgFL2"
   },
   "source": [
    "# Making an RNN learn Addition and Subtraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "OW8z_KHAgFL3"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import TimeDistributed, Dense, Dropout, SimpleRNN, RepeatVector\n",
    "from tensorflow.keras.callbacks import EarlyStopping, LambdaCallback\n",
    "\n",
    "from termcolor import colored"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F5ywCLDKgFL5"
   },
   "source": [
    "## Generating Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "hP-OyQrivBNl",
    "outputId": "141b545a-ea9d-4da9-de2c-504bd2726de6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12"
      ]
     },
     "execution_count": 2,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary of the task: the ten digits plus the two supported operators.\n",
    "# A character's position in this string is its one-hot index.\n",
    "all_chars = \"1234567890+-\"\n",
    "num_features = len(all_chars)\n",
    "\n",
    "# Seed NumPy's global RNG so the operands and operators drawn below with\n",
    "# np.random.randint / np.random.choice are the same on every\n",
    "# Restart & Run All. This does not seed TensorFlow's own generators.\n",
    "random_seed = 42\n",
    "np.random.seed(random_seed)\n",
    "\n",
    "num_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "Z5LZ2kFBgFL9"
   },
   "outputs": [],
   "source": [
    "# Lookup tables between characters and integer ids; a character's id is\n",
    "# its position in `all_chars`.\n",
    "char_to_idx = dict(zip(all_chars, range(len(all_chars))))\n",
    "idx_to_char = dict(enumerate(all_chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BCB2D3ZXg98R",
    "outputId": "f8f96054-a739-4f4a-9f50-0515cd280c2c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'1': 0, '2': 1, '3': 2, '4': 3, '5': 4, '6': 5, '7': 6, '8': 7, '9': 8, '0': 9, '+': 10, '-': 11}\n",
      "{0: '1', 1: '2', 2: '3', 3: '4', 4: '5', 5: '6', 6: '7', 7: '8', 8: '9', 9: '0', 10: '+', 11: '-'}\n"
     ]
    }
   ],
   "source": [
    "# Show both lookup tables, one per line.\n",
    "print(char_to_idx, idx_to_char, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "CESgYbgagFL9"
   },
   "outputs": [],
   "source": [
    "def random_operation(num1, num2):\n",
    "    \"\"\"Pick '+' or '-' at random and apply it to the two numbers.\n",
    "\n",
    "    Args:\n",
    "        num1: Left operand.\n",
    "        num2: Right operand.\n",
    "\n",
    "    Returns:\n",
    "        A ``(result, operator)`` tuple, where ``operator`` is the symbol\n",
    "        drawn by ``np.random.choice`` and ``result`` is ``num1 + num2`` for\n",
    "        '+' and ``num1 - num2`` otherwise.\n",
    "    \"\"\"\n",
    "    operator = np.random.choice(['+', '-'])\n",
    "    result = num1 + num2 if operator == '+' else num1 - num2\n",
    "    return result, operator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "B1doFIX7gFL9",
    "outputId": "afe23ab3-78f2-43a3-ca27-ff2ee341ab3e"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, '+')"
      ]
     },
     "execution_count": 6,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: 5 and 5 give either (10, '+') or (0, '-').\n",
    "random_operation(num1=5, num2=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "L2b1Xg9hvBNr",
    "outputId": "5cef2d06-e48c-457a-8686-5d9be028bcac"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('90-27', '63')"
      ]
     },
     "execution_count": 7,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def generate_data():\n",
    "    \"\"\"Build one random arithmetic problem and its answer, both as strings.\n",
    "\n",
    "    Two operands are drawn with ``np.random.randint(0, 100)`` and combined\n",
    "    by ``random_operation``.\n",
    "\n",
    "    Returns:\n",
    "        A ``(sample, label)`` tuple such as ``('90-27', '63')``.\n",
    "    \"\"\"\n",
    "    operands = [np.random.randint(0, 100) for _ in range(2)]\n",
    "    res, opr = random_operation(*operands)\n",
    "    sample = f\"{operands[0]}{opr}{operands[1]}\"\n",
    "    return sample, str(res)\n",
    "\n",
    "# Quick look at one generated pair\n",
    "generate_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SkOdd4NFgFL-"
   },
   "source": [
    "## Creating the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "iZKl0LtdvBNy",
    "outputId": "53de59b0-f362-4da4-b9e0-a19f0442173c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "simple_rnn (SimpleRNN)       (None, 128)               18048     \n",
      "_________________________________________________________________\n",
      "repeat_vector (RepeatVector) (None, 5, 128)            0         \n",
      "_________________________________________________________________\n",
      "simple_rnn_1 (SimpleRNN)     (None, 5, 128)            32896     \n",
      "_________________________________________________________________\n",
      "time_distributed (TimeDistri (None, 5, 12)             1548      \n",
      "=================================================================\n",
      "Total params: 52,492\n",
      "Trainable params: 52,492\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Width of both recurrent layers, and the fixed number of characters the\n",
    "# decoder emits for every answer.\n",
    "hidden_units = 128\n",
    "max_timesteps = 5\n",
    "\n",
    "model = Sequential()\n",
    "# Encoder: reads the one-hot character sequence and returns a single vector.\n",
    "model.add(SimpleRNN(hidden_units, input_shape=(None, num_features)))\n",
    "# Feed that vector to the decoder once per output timestep.\n",
    "model.add(RepeatVector(max_timesteps))\n",
    "# Decoder: one hidden vector per output position ...\n",
    "model.add(SimpleRNN(hidden_units, return_sequences=True))\n",
    "# ... each mapped to a softmax over the vocabulary.\n",
    "model.add(TimeDistributed(Dense(num_features, activation=\"softmax\")))\n",
    "\n",
    "model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "goIaTKykgFL_"
   },
   "source": [
    "## Vectorize and De-Vectorize Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "Ci29YaA6vBN1"
   },
   "outputs": [],
   "source": [
    "def _one_hot_right_aligned(text):\n",
    "    \"\"\"One-hot encode `text`, right-aligned in a (max_timesteps, num_features) array.\"\"\"\n",
    "    pad = max_timesteps - len(text)\n",
    "    if pad < 0:\n",
    "        # A negative pad would wrap around through negative indexing and\n",
    "        # silently write several characters into the same rows.\n",
    "        raise ValueError(\n",
    "            f\"{text!r} has {len(text)} characters but max_timesteps is {max_timesteps}\"\n",
    "        )\n",
    "    encoded = np.zeros((max_timesteps, num_features))\n",
    "    # Left-pad with the '0' character so every row is a valid one-hot vector.\n",
    "    encoded[:pad, char_to_idx['0']] = 1\n",
    "    for offset, char in enumerate(text):\n",
    "        encoded[pad + offset, char_to_idx[char]] = 1\n",
    "    return encoded\n",
    "\n",
    "\n",
    "def vectorize_sample(sample, label):\n",
    "    \"\"\"One-hot encode a problem string and its answer string.\n",
    "\n",
    "    Both strings are right-aligned in an array of ``max_timesteps`` rows and\n",
    "    ``num_features`` columns; unused leading rows are filled with the '0'\n",
    "    character.\n",
    "\n",
    "    Args:\n",
    "        sample: Problem text, e.g. ``'78-26'``.\n",
    "        label: Answer text, e.g. ``'52'``.\n",
    "\n",
    "    Returns:\n",
    "        A ``(x, y)`` tuple of arrays, each of shape\n",
    "        ``(max_timesteps, num_features)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If either string is longer than ``max_timesteps``.\n",
    "        KeyError: If a character is not in ``char_to_idx``.\n",
    "    \"\"\"\n",
    "    return _one_hot_right_aligned(sample), _one_hot_right_aligned(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KhL2lywJgFL_",
    "outputId": "7c9130ed-08b7-430e-bb7d-a963bb9b0d51"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "78-26 = 52\n"
     ]
    }
   ],
   "source": [
    "# Round-trip demo, part 1: draw one problem, show it, and one-hot encode it.\n",
    "s1, l1 = generate_data()\n",
    "print(f\"{s1} = {l1}\")\n",
    "s1v, l1v = vectorize_sample(s1, l1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "_nARKq-bvBN9"
   },
   "outputs": [],
   "source": [
    "def devectorize_sample(sample):\n",
    "    \"\"\"Decode a one-hot (or probability) array back into its string.\n",
    "\n",
    "    Args:\n",
    "        sample: 2-D array with one row per timestep and one column per\n",
    "            character; the arg-max of each row selects the character.\n",
    "\n",
    "    Returns:\n",
    "        The decoded text with leading '0' characters removed. If the text\n",
    "        consists only of '0' characters, ``'0'`` is returned instead of an\n",
    "        empty string.\n",
    "    \"\"\"\n",
    "    decoded = ''.join(idx_to_char[i] for i in np.argmax(sample, axis=1))\n",
    "    stripped = decoded.lstrip('0')\n",
    "    # Stripping every leading '0' would also remove the single digit of a\n",
    "    # zero result (e.g. '00000'), so keep one zero in that case.\n",
    "    # NOTE: a text that starts with a genuine '0' followed by an operator\n",
    "    # (e.g. '0+5') still loses that digit; it cannot be told apart from padding.\n",
    "    if decoded and not stripped:\n",
    "        return '0'\n",
    "    return stripped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "KCGUxNwNvBOA",
    "outputId": "16cbe44e-9bfb-4510-9e5f-d7202088fe4e"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'52'"
      ]
     },
     "execution_count": 12,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip demo, part 2: decode the encoded label back to text.\n",
    "devectorize_sample(sample=l1v)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "K8kMTEwigFL_"
   },
   "source": [
    "## Creating the Training Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "T5rdrhiVvBOI"
   },
   "outputs": [],
   "source": [
    "def create_dataset(num_sample=1000):\n",
    "    \"\"\"Generate `num_sample` random problems and return them one-hot encoded.\n",
    "\n",
    "    Args:\n",
    "        num_sample: Number of (problem, answer) pairs to generate.\n",
    "\n",
    "    Returns:\n",
    "        A ``(x, y)`` tuple of arrays, each of shape\n",
    "        ``(num_sample, max_timesteps, num_features)``.\n",
    "    \"\"\"\n",
    "    shape = (num_sample, max_timesteps, num_features)\n",
    "    x, y = np.zeros(shape), np.zeros(shape)\n",
    "\n",
    "    for row in range(num_sample):\n",
    "        x[row], y[row] = vectorize_sample(*generate_data())\n",
    "\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "cRUvtd0IvBOO",
    "outputId": "be88c787-9b92-4c30-eb5f-e9b65d353be9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 5, 12) (10000, 5, 12)\n"
     ]
    }
   ],
   "source": [
    "# Generate the 10,000 encoded problems the model is fitted on.\n",
    "x_train, y_train = create_dataset(num_sample=10000)\n",
    "print(x_train.shape, y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Van9N7xrgFMA",
    "outputId": "2251f065-db06-4cdc-e0c2-62f89886915b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('4-9', '-5')"
      ]
     },
     "execution_count": 15,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the first training pair back to text as a sanity check.\n",
    "tuple(devectorize_sample(split[0]) for split in (x_train, y_train))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_TGUW30XgFMA"
   },
   "source": [
    "## Training the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "S8HWRdiOvBOS",
    "outputId": "260d4b0d-aaa0-4c0a-abe3-f89ebe18c52c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "->0.60->0.63->0.63->0.65->0.67->0.68->0.70->0.71->0.72->0.73->0.74->0.76->0.76->0.77->0.79->0.80->0.82->0.84->0.85->0.87->0.88->0.89->0.90->0.91->0.92->0.93->0.94->0.94->0.94->0.95->0.96->0.95->0.96->0.96->0.97->0.97->0.97->0.97->0.97->0.97->0.97->0.97->0.97->0.98->0.98->0.96->0.97->0.98->0.98->0.98->0.98->0.98->0.98->0.98->0.98->0.98->0.98->0.98->0.99->0.99->0.98->0.99->0.98->0.98->0.99->0.99->0.99->0.98->0.98->0.97->0.96"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f57da6a4390>"
      ]
     },
     "execution_count": 16,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Print the validation accuracy after every epoch, all on one line.\n",
    "lb_cb = LambdaCallback(\n",
    "    on_epoch_end=lambda epoch, logs: print('->{:.2f}'.format(logs['val_accuracy']), end='')\n",
    ")\n",
    "# Stop once val_loss has not improved for 5 epochs. When that happens,\n",
    "# restore the weights of the best epoch rather than keeping the weights of\n",
    "# the last (worse) epoch.\n",
    "es_cb = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
    "\n",
    "# Keep the History object in a variable so its repr is not shown as cell output.\n",
    "history = model.fit(x_train, y_train, epochs=100, batch_size=256, validation_split=0.1,\n",
    "                    callbacks=[lb_cb, es_cb], verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tlNtu_9ZvBOY",
    "outputId": "9e0c27ae-eca2-45a5-a84f-624542cd6a3e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32mSample: 89-19 Actual: 70 Predicted: 70\u001b[0m\n",
      "\u001b[32mSample: 3+32 Actual: 35 Predicted: 35\u001b[0m\n",
      "\u001b[31mSample: 54+74 Actual: 128 Predicted: 127\u001b[0m\n",
      "\u001b[32mSample: 55+72 Actual: 127 Predicted: 127\u001b[0m\n",
      "\u001b[32mSample: 9+93 Actual: 102 Predicted: 102\u001b[0m\n",
      "\u001b[32mSample: 76-39 Actual: 37 Predicted: 37\u001b[0m\n",
      "\u001b[32mSample: 56-59 Actual: -3 Predicted: -3\u001b[0m\n",
      "\u001b[32mSample: 37-29 Actual: 8 Predicted: 8\u001b[0m\n",
      "\u001b[32mSample: 78+85 Actual: 163 Predicted: 163\u001b[0m\n",
      "\u001b[32mSample: 52+86 Actual: 138 Predicted: 138\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Spot-check the model on 10 fresh problems: a line is printed in green when\n",
    "# the decoded prediction equals the decoded label, in red otherwise.\n",
    "x_test, y_test = create_dataset(10)\n",
    "y_pred = model.predict(x_test)\n",
    "\n",
    "for features, target, pred in zip(x_test, y_test, y_pred):\n",
    "    y = devectorize_sample(target)\n",
    "    y_hat = devectorize_sample(pred)\n",
    "    col = 'green' if y == y_hat else 'red'\n",
    "    out = f'Sample: {devectorize_sample(features)} Actual: {y} Predicted: {y_hat}'\n",
    "    print(colored(out, col))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "rGYe0TRzgFMB",
    "outputId": "121f4d62-5c1a-40b4-8534-4c8c122630a5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy :: 0.9648000001907349\n"
     ]
    }
   ],
   "source": [
    "x_test, y_test = create_dataset(1000)\n",
    "# Use a real name for the loss: `_` is IPython's last-output variable.\n",
    "test_loss, acc = model.evaluate(x_test, y_test, verbose=0)\n",
    "print('Test Accuracy ::', acc)\n",
    "\n",
    "# The Keras metric above scores every output character separately, padding\n",
    "# positions included, so it overstates how often the whole answer is right.\n",
    "# Also report the share of problems whose complete answer matches.\n",
    "pred_chars = np.argmax(model.predict(x_test), axis=-1)\n",
    "true_chars = np.argmax(y_test, axis=-1)\n",
    "exact_match = np.mean(np.all(pred_chars == true_chars, axis=1))\n",
    "print('Exact-match Accuracy ::', exact_match)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Simple_RNN.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "tf2",
   "language": "python",
   "name": "tf2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}