{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Embeddings\n",
    "\n",
    "https://www.youtube.com/watch?v=wSXGlvTR9UM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Activation, Embedding, Merge, Flatten\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Age</th>\n",
       "      <th>Education</th>\n",
       "      <th>H_education</th>\n",
       "      <th>num_child</th>\n",
       "      <th>Religion</th>\n",
       "      <th>Employ</th>\n",
       "      <th>H_occupation</th>\n",
       "      <th>living_standard</th>\n",
       "      <th>Media_exposure</th>\n",
       "      <th>contraceptive</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>24</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>45</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>10</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>43</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>42</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>36</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Age  Education  H_education  num_child  Religion  Employ  H_occupation  \\\n",
       "0   24          2            3          3         1       1             2   \n",
       "1   45          1            3         10         1       1             3   \n",
       "2   43          2            3          7         1       1             3   \n",
       "3   42          3            2          9         1       1             3   \n",
       "4   36          3            3          8         1       1             3   \n",
       "\n",
       "   living_standard  Media_exposure  contraceptive  \n",
       "0                3               0              1  \n",
       "1                4               0              1  \n",
       "2                4               0              1  \n",
       "3                3               0              1  \n",
       "4                2               0              1  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('data/cmc.data',header=None,names=['Age','Education','H_education',\n",
    "                                                     'num_child','Religion', 'Employ',\n",
    "                                                     'H_occupation','living_standard',\n",
    "                                                     'Media_exposure','contraceptive'])\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Age                False\n",
       "Education          False\n",
       "H_education        False\n",
       "num_child          False\n",
       "Religion           False\n",
       "Employ             False\n",
       "H_occupation       False\n",
       "living_standard    False\n",
       "Media_exposure     False\n",
       "contraceptive      False\n",
       "dtype: bool"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.isnull().any()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x12287b630>"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAE4RJREFUeJzt3X+sX3ddx/Hn23bA7MV2MLw2bbVLXEgmFVhvthKIuZdF\nUzayLnHgyNy6ZaaJDsQw4wqJGvwRyx+ATA3aMLIO0csywdVu0yxdr8gfG7Y418FAyihZb8oqW1e4\nbGqqb//4fsDL5bbf8/3e8+33ez8+H8nNPedzPud8P+/zuX3d0/P9cSMzkSTV60eGPQBJ0mAZ9JJU\nOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVa5R0EfEmoi4NyK+HBFPRsQbIuIVEfFQRHy1fL+g\n9I2IuCMijkTE4xFx6WBLkCSdTTR5Z2xE7AH+KTM/FhEvAX4UeB/wXGbuioidwAWZeXtEXAm8C7gS\nuBz4SGZefrbjX3jhhblx48a+Cvjud7/LqlWr+tp31FjL6KmlDrCWUbWUWg4dOvStzHxV146ZedYv\nYDXwdcovhXntXwHWluW1wFfK8l8A71is35m+Nm/enP06cOBA3/uOGmsZPbXUkWkto2optQAHs0uG\nZ2b3K/qIeB2wG/gS8FrgEPBuYDYz15Q+AZzMzDURsQ/YlZmfK9v2A7dn5sEFx90B7AAYHx/fPD09\n3fWX0mLm5uYYGxvra99RYy2jp5Y6wFpG1VJqmZqaOpSZE107dvtNAEwAp4HLy/pHgN8Hnl/Q72T5\nvg9407z2/cDE2R7DK/oOaxk9tdSRaS2j6lxc0Td5MvYYcCwzHy3r9wKXAs9ExFqA8v1E2T4LbJi3\n//rSJkkagq5Bn5nfBJ6OiFeXpivo3MbZC2wvbduB+8ryXuDG8uqbLcCpzDze7rAlSU2tbNjvXcAn\nyytungJupvNL4p6IuAX4BvD20vcBOq+4OQK8UPpKkoakUdBn5mN07tUvdMUifRO4dYnjkiS1xHfG\nSlLlDHpJqpxBL0mVa/pkrCRVa+PO+4f22HdtHfxHOXhFL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNe\nkipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWp\ncga9JFXOoJekyhn0klQ5g16SKmfQS1LlGgV9RByNiMMR8VhEHCxtr4iIhyLiq+X7BaU9IuKOiDgS\nEY9HxKWDLECSdHa9XNFPZebrMnOirO8E9mfmxcD+sg7wFuDi8rUD+Ghbg5Uk9W4pt262AXvK8h7g\nmnntd2fHI8CaiFi7hMeRJC1BZGb3ThFfB04CCfxFZu6OiOczc03ZHsDJzFwTEfuAXZn5ubJtP3B7\nZh5ccMwddK74GR8f3zw9Pd1XAXNzc4yNjfW176ixltFTSx1gLWdzePZUa8fq1UWrV/Rdy9TU1KF5\nd1nOaGXD470pM2cj4seBhyLiy/M3ZmZGRPffGD+4z25gN8DExEROTk72svv3zczM0O++o8ZaRk8t\ndYC1nM1NO+9v7Vi9umvrqoHPS6NbN5k5W76fAD4DXAY8871bMuX7idJ9Ftgwb/f1pU2SNARdgz4i\nVkXEy7+3DPwC8ASwF9heum0H7ivLe4Eby6tvtgCnMvN46yOXJDXS5NbNOPCZzm14VgJ/lZl/HxH/\nDNwTEbcA3wDeXvo/AFwJHAFeAG5ufdSSpMa6Bn1mPgW8dpH2Z4ErFmlP4NZWRidJWjLfGStJlTPo\nJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16S\nKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJaly\nBr0kVW5l044RsQI4CMxm5lsj4iJgGn
glcAi4ITP/KyJeCtwNbAaeBX4pM4+2PnKpcht33t/q8W7b\ndJqbGhzz6K6rWn1cDV8vV/TvBp6ct/4B4MOZ+dPASeCW0n4LcLK0f7j0kyQNSaOgj4j1wFXAx8p6\nAG8G7i1d9gDXlOVtZZ2y/YrSX5I0BJGZ3TtF3Av8EfBy4DeBm4BHylU7EbEBeDAzXxMRTwBbM/NY\n2fY14PLM/NaCY+4AdgCMj49vnp6e7quAubk5xsbG+tp31FjL6BlmHYdnT7V6vPHz4ZkXu/fbtG51\nq487CG3PS9vnuhcXrV7Rdy1TU1OHMnOiW7+u9+gj4q3Aicw8FBGTfY1mEZm5G9gNMDExkZOT/R16\nZmaGfvcdNdYyeoZZR5P76b24bdNpPni4+9NyR6+fbPVxB6HteWn7XPfirq2rBv4z1uTJ2DcCV0fE\nlcDLgB8DPgKsiYiVmXkaWA/Mlv6zwAbgWESsBFbTeVJWkjQEXe/RZ+Z7M3N9Zm4ErgMezszrgQPA\ntaXbduC+sry3rFO2P5xN7g9JkgZiKa+jvx14T0QcofMSyztL+53AK0v7e4CdSxuiJGkpGr+OHiAz\nZ4CZsvwUcNkiff4DeFsLY5MktcB3xkpS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIq\nZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqlxPf0pQ2rjz/kb9\nbtt0mpsa9m3i6K6rWjuW9P+NV/SSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFWua9BHxMsi\n4vMR8a8R8cWIeH9pvygiHo2IIxHxqYh4SWl/aVk/UrZvHGwJkqSzaXJF/5/AmzPztcDrgK0RsQX4\nAPDhzPxp4CRwS+l/C3CytH+49JMkDUnXoM+OubJ6XvlK4M3AvaV9D3BNWd5W1inbr4iIaG3EkqSe\nNLpHHxErIuIx4ATwEPA14PnMPF26HAPWleV1wNMAZfsp4JVtDlqS1FxkZvPOEWuAzwC/DdxVbs8Q\nERuABzPzNRHxBLA1M4+VbV8DLs/Mby041g5gB8D4+Pjm6enpvgqYm5tjbGysr31HzXKo5fDsqUb9\nxs+HZ15s73E3rVvd3sF6MMw5aXqum2o6J8M6171oe17aPte9uGj1ir5rmZqaOpSZE9369fShZpn5\nfEQcAN4ArImIleWqfT0wW7rNAhuAYxGxElgNPLvIsXYDuwEmJiZycnKyl6F838zMDP3uO2qWQy1N\nP6jstk2n+eDh9j4z7+j1k60dqxfDnJM2PxQOms/JsM51L9qel7bPdS/u2rpq4D9jTV5186pyJU9E\nnA/8PPAkcAC4tnTbDtxXlveWdcr2h7OX/zZIklrV5JJrLbAnIlbQ+cVwT2bui4gvAdMR8QfAvwB3\nlv53Ap+IiCPAc8B1Axi3JKmhrkGfmY8Dr1+k/SngskXa/wN4WyujkyQtme+MlaTKGfSSVDmDXpIq\nZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIG\nvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIq1zXo\nI2JDRByIiC9FxBcj4t2l/RUR8VBEfLV8v6C0R0TcERFHIuLxiLh00EVIks6syRX9aeC2zLwE2ALc\nGhGXADuB/Zl5MbC/rAO8Bbi4fO0APtr6qCVJjXUN+sw8nplfKMvfAZ4E1gHbgD2l2x7gmrK8Dbg7\nOx4B1kTE2tZHLklqpKd79BGxEXg98CgwnpnHy6ZvAuNleR3w9LzdjpU2SdIQRGY26xgxBvwj8IeZ\n+emIeD4z18zbfjIzL4iIfcCuzPxcad8P3J6ZBxccbwedWzuMj49vnp6e7quAubk5xsbG+tp31CyH\nWg
7PnmrUb/x8eObF9h5307rV7R2sB8Ock6bnuqmmczKsc92Ltuel7XPdi4tWr+i7lqmpqUOZOdGt\n38omB4uI84C/AT6ZmZ8uzc9ExNrMPF5uzZwo7bPAhnm7ry9tPyAzdwO7ASYmJnJycrLJUH7IzMwM\n/e47apZDLTftvL9Rv9s2neaDhxv9eDVy9PrJ1o7Vi2HOSdNz3VTTORnWue5F2/PS9rnuxV1bVw38\nZ6zJq24CuBN4MjM/NG/TXmB7Wd4O3Dev/cby6pstwKl5t3gkSedYk0uuNwI3AIcj4rHS9j5gF3BP\nRNwCfAN4e9n2AHAlcAR4Abi51RFLknrSNejLvfY4w+YrFumfwK1LHJckqSW+M1aSKmfQS1LlDHpJ\nqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlWvvA8OH5PDsqaF9lvTR\nXVcN5XElqRde0UtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEv\nSZUz6CWpcga9JFXOoJekynUN+oj4eESciIgn5rW9IiIeioivlu8XlPaIiDsi4khEPB4Rlw5y8JKk\n7ppc0d8FbF3QthPYn5kXA/vLOsBbgIvL1w7go+0MU5LUr65Bn5mfBZ5b0LwN2FOW9wDXzGu/Ozse\nAdZExNq2BitJ6l2/9+jHM/N4Wf4mMF6W1wFPz+t3rLRJkoYkMrN7p4iNwL7MfE1Zfz4z18zbfjIz\nL4iIfcCuzPxcad8P3J6ZBxc55g46t3cYHx/fPD093VcBJ547xTMv9rXrkm1at7rV483NzTE2Ntbq\nMdt2ePZUo37j59PqvLR9rpsa5pw0PddNNZ2TYZ3rXrQ9L22f615ctHpF37VMTU0dysyJbv36/Zux\nz0TE2sw8Xm7NnCjts8CGef3Wl7Yfkpm7gd0AExMTOTk52ddA/uST9/HBw8P507dHr59s9XgzMzP0\nex7OlaZ/n/e2TadbnZe2z3VTw5yTtv8WctM5Gda57kXb8zKsvzsNcNfWVQP/Gev31s1eYHtZ3g7c\nN6/9xvLqmy3AqXm3eCRJQ9D113tE/DUwCVwYEceA3wV2AfdExC3AN4C3l+4PAFcCR4AXgJsHMGZJ\nUg+6Bn1mvuMMm65YpG8Cty51UJKk9vjOWEmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQ\nS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0k\nVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcgMJ+ojYGhFfiYgjEbFzEI8hSWqm9aCP\niBXAnwFvAS4B3hERl7T9OJKkZgZxRX8ZcCQzn8rM/wKmgW0DeBxJUgODCPp1wNPz1o+VNknSEERm\ntnvAiGuBrZn5K2X9BuDyzHzngn47gB1l9dXAV/p8yAuBb/W576ixltFTSx1gLaNqKbX8VGa+qlun\nlX0e/GxmgQ3z1teXth+QmbuB3Ut9sIg4mJkTSz3OKLCW0VNLHWAto+pc1DKIWzf/DFwcERdFxEuA\n64C9A3gcSVIDrV/RZ+bpiHgn8A/ACuDjmfnFth9HktTMIG7dkJkPAA8M4tiLWPLtnxFiLaOnljrA\nWkbVwGtp/clYSdJo8SMQJKlyyyLoI+LjEXEiIp44w/aIiDvKRy48HhGXnusxNtWglsmIOBURj5Wv\n3znXY2wqIjZExIGI+FJEfDEi3r1In5Gfm4Z1LIt5iYiXRcTnI+JfSy3vX6TPSyPiU2VOHo2Ijed+\npN01rOWmiPj3efPyK8MYaxMRsSIi/iUi9i2ybbBzkpkj/wX8HHAp8MQZtl8JPAgEsAV4dNhjXkIt\nk8C+YY+zYS1rgUvL8suBfwMuWW5z07COZTEv5TyPleXzgEeBLQv6
/Brw52X5OuBTwx73Emq5CfjT\nYY+1YT3vAf5qsZ+jQc/Jsriiz8zPAs+dpcs24O7seARYExFrz83oetOglmUjM49n5hfK8neAJ/nh\nd0GP/Nw0rGNZKOd5rqyeV74WPhG3DdhTlu8FroiIOEdDbKxhLctCRKwHrgI+doYuA52TZRH0DdT2\nsQtvKP9dfTAifmbYg2mi/Ffz9XSuuuZbVnNzljpgmcxLuUXwGHACeCgzzzgnmXkaOAW88tyOspkG\ntQD8YrkteG9EbFhk+yj4Y+C3gP85w/aBzkktQV+TL9B5W/NrgT8B/nbI4+kqIsaAvwF+IzO/Pezx\n9KtLHctmXjLzvzPzdXTelX5ZRLxm2GPqV4Na/g7YmJk/CzzE/10Vj4yIeCtwIjMPDWsMtQR9o49d\nWA4y89vf++9qdt6PcF5EXDjkYZ1RRJxHJxw/mZmfXqTLspibbnUst3kByMzngQPA1gWbvj8nEbES\nWA08e25H15sz1ZKZz2bmf5bVjwGbz/XYGngjcHVEHKXzab5vjoi/XNBnoHNSS9DvBW4sr/DYApzK\nzOPDHlQ/IuInvndvLiIuozNHI/mPsIzzTuDJzPzQGbqN/Nw0qWO5zEtEvCoi1pTl84GfB768oNte\nYHtZvhZ4OMuzgKOkSS0Lnu+5ms7zKyMlM9+bmeszcyOdJ1ofzsxfXtBtoHMykHfGti0i/prOqx4u\njIhjwO/SeWKGzPxzOu/CvRI4ArwA3DyckXbXoJZrgV+NiNPAi8B1o/iPsHgjcANwuNxHBXgf8JOw\nrOamSR3LZV7WAnui8weAfgS4JzP3RcTvAQczcy+dX2qfiIgjdF4YcN3whntWTWr59Yi4GjhNp5ab\nhjbaHp3LOfGdsZJUuVpu3UiSzsCgl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcv8LAfci\ngdNBX0MAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x122831a90>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "df.Education.hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1473, 10)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x11a97b6a0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFJxJREFUeJzt3X+MndV95/H3tzYkXU/WAyGdtWxv7VWtXZHQJHhESRNV\nM0G7NU67ZqUUEaHGIEuWWrZK1f2BW6mt2u5K5A82TdBuulaJMJWbAdGwtgzJFjme7WYjnOKUYH4k\nmwlxikfEVrCZZgJt5fTbP+4hXIzH89w795eP3i/pap7nnHPv/T4Ph88899y515GZSJLq9WPDLkCS\n1F8GvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyq4ddAMBVV12VmzZt6uq+P/jB\nD1izZk1vC+oB6+qMdXVuVGuzrs6spK5jx459LzPfsezAzBz6bevWrdmtI0eOdH3ffrKuzlhX50a1\nNuvqzErqAp7IBhnr0o0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFVuJL4C\nYSWOzy9w255HhvLcJ+760FCeV5I64RW9JFXOoJekyhn0klQ5g16SKmfQS1LlGgV9RIxHxEMR8fWI\neC4i3hcRV0bEYxHxzfLzijI2IuJTETEXEU9FxLX9PQRJ0sU0vaL/JPCFzPxXwLuB54A9wOHM3AIc\nLvsANwJbym038OmeVixJ6siyQR8Ra4GfA+4FyMy/z8yXgR3AvjJsH3BT2d4B3F/+AZTHgfGIWNfz\nyiVJjUTrX6O6yICI9wB7gWdpXc0fAz4GzGfmeBkTwNnMHI+IQ8Bdmfml0ncYuDMznzjvcXfTuuJn\nYmJi68zMTFcHcPrMAqde7equK3bN+rVL9i0uLjI2NjbAapqxrs6Mal0wurVZV2dWUtf09PSxzJxc\nblyTT8auBq4Ffi0zj0bEJ3l9mQaAzMyIuPhvjPNk5l5av0CYnJzMqampTu7+I/fsP8Ddx4fzAd8T\nt04t2Tc7O0u3x9RP1tWZUa0LRrc26+rMIOpqskZ/EjiZmUfL/kO0gv/Ua0sy5efp0j8PbGy7/4bS\nJkkagmWDPjO/C7wQEf+yNN1AaxnnILCztO0EDpTtg8BHy1/fXA8sZOaLvS1bktRU0zWPXwP2R8Tl\nwPPA7bR+STwYEbuA7wA3l7GPAtuBOeCVMlaSNCSNgj4znwQutOB/wwXGJnDHCuuSJPWIn4yVpMoZ\n9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEv\nSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVbnWTQRFxAvg+\n8EPgXGZORsSVwAPAJuAEcHNmno2IAD4JbAdeAW7LzK/2vnRJ6o1Nex4Z2nPft21N35+jkyv66cx8\nT2ZOlv09wOHM3AIcLvsANwJbym038OleFStJ6txKlm52APvK9j7gprb2+7PlcWA8Itat4HkkSSsQ\nmbn8oIhvA2eBBP5nZu6NiJczc7z0B3A2M8cj4hBwV2Z+qfQdBu7MzCfOe8zdtK74mZiY2DozM9PV\nAZw+s8CpV7u664pds37tkn2Li4uMjY0NsJpmrKszo1oXjG5tl2Jdx+cXBlzN6zavXdX1+Zqenj7W\ntsqypEZr9MAHMnM+In4CeCwivt7emZkZEcv/xnjjffYCewEmJydzamqqk7v/yD37D3D38aaH0Vsn\nbp1asm92dpZuj6mfrKszo1oXjG5tl2Jdtw15jb7f56vR0k1mzpefp4GHgeuAU68tyZSfp8vweWBj\n2903lDZJ0hAsG/QRsSYi3vbaNvBvgKeBg8DOMmwncKBsHwQ+Gi3XAwuZ+WLPK5ckNdJkzWMCeLi1\nDM9q4E8z8wsR8ZfAgxGxC/gOcHMZ/y
itP62co/Xnlbf3vGpJUmPLBn1mPg+8+wLtLwE3XKA9gTt6\nUp0kacX8ZKwkVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5\ng16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlWsc\n9BGxKiL+KiIOlf3NEXE0IuYi4oGIuLy0v6Xsz5X+Tf0pXZLURCdX9B8Dnmvb/zjwicz8KeAssKu0\n7wLOlvZPlHGSpCFpFPQRsQH4EPDHZT+ADwIPlSH7gJvK9o6yT+m/oYyXJA1B0yv6PwT+M/APZf/t\nwMuZea7snwTWl+31wAsApX+hjJckDUFk5sUHRPwCsD0zfzUipoD/CNwGPF6WZ4iIjcDnM/NdEfE0\nsC0zT5a+bwE/k5nfO+9xdwO7ASYmJrbOzMx0dQCnzyxw6tWu7rpi16xfu2Tf4uIiY2NjA6ymGevq\nzKjWBaNb26VY1/H5hQFX87rNa1d1fb6mp6ePZebkcuNWN3is9wP/NiK2A28F/inwSWA8IlaXq/YN\nwHwZPw9sBE5GxGpgLfDS+Q+amXuBvQCTk5M5NTXVoJQ3u2f/Ae4+3uQweu/ErVNL9s3OztLtMfWT\ndXVmVOuC0a3tUqzrtj2PDLaYNvdtW9P387Xs0k1m/mZmbsjMTcAtwBcz81bgCPDhMmwncKBsHyz7\nlP4v5nIvGyRJfbOSv6O/E/iNiJijtQZ/b2m/F3h7af8NYM/KSpQkrURHax6ZOQvMlu3ngesuMOZv\ngV/qQW2SpB7wk7GSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQ\nS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUueH8q9rSJeL4/MLQ/uHoE3d9aCjPq/p4RS9J\nlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqt2zQR8RbI+IrEfG1iHgmIn6vtG+OiKMRMRcRD0TE5aX9\nLWV/rvRv6u8hSJIupskV/d8BH8zMdwPvAbZFxPXAx4FPZOZPAWeBXWX8LuBsaf9EGSdJGpJlgz5b\nFsvuZeWWwAeBh0r7PuCmsr2j7FP6b4iI6FnFkqSONFqjj4hVEfEkcBp4DPgW8HJmnitDTgLry/Z6\n4AWA0r8AvL2XRUuSmovMbD44Yhx4GPht4L6yPENEbAQ+n5nvioingW2ZebL0fQv4mcz83nmPtRvY\nDTAxMbF1ZmamqwM4fWaBU692ddcVu2b92iX7FhcXGRsbG2A1zVhXZ0Z1fsHonrNLsa7j8wsDruZ1\nm9eu6vp8TU9PH8vMyeXGdfRdN5n5ckQcAd4HjEfE6nLVvgGYL8PmgY3AyYhYDawFXrrAY+0F9gJM\nTk7m1NRUJ6X8yD37D3D38eF8Zc+JW6eW7JudnaXbY+on6+rMqM4vGN1zdinWNazvMwK4b9uavp+v\nJn91845yJU9E/Djwr4HngCPAh8uwncCBsn2w7FP6v5idvGyQJPVUk0uVdcC+iFhF6xfDg5l5KCKe\nBWYi4r8AfwXcW8bfC/xJRMwBZ4Bb+lC3JKmhZYM+M58C3nuB9ueB6y7Q/rfAL/WkOknSivnJWEmq\nnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ\n9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuWWDfqI2BgR\nRyLi2Yh4JiI+VtqvjIjHIuKb5ecVpT0i4lMRMRcRT0XEtf0+CEnS0ppc0Z8D/kNmXg1cD9wREVcD\ne4DDmbkFOFz2AW4EtpTbbuDTPa9aktTYskGfmS9m5lfL9veB54D1wA5gXxm2D7ipbO8A7s+Wx4Hx\niF
jX88olSY10tEYfEZuA9wJHgYnMfLF0fReYKNvrgRfa7naytEmShiAys9nAiDHg/wD/NTM/FxEv\nZ+Z4W//ZzLwiIg4Bd2Xml0r7YeDOzHzivMfbTWtph4mJia0zMzNdHcDpMwucerWru67YNevXLtm3\nuLjI2NjYAKtpxro6M6rzC0b3nF2KdR2fXxhwNa/bvHZV1+drenr6WGZOLjdudZMHi4jLgD8D9mfm\n50rzqYhYl5kvlqWZ06V9HtjYdvcNpe0NMnMvsBdgcnIyp6ammpTyJvfsP8DdxxsdRs+duHVqyb7Z\n2Vm6PaZ+sq7OjOr8gtE9Z5diXbfteWSwxbS5b9uavp+vJn91E8C9wHOZ+d/aug4CO8v2TuBAW/tH\ny1/fXA8stC3xSJIGrMmlyvuBXwaOR8STpe23gLuAByNiF/Ad4ObS9yiwHZgDXgFu72nFkqSOLBv0\nZa09lui+4QLjE7hjhXVJknrET8ZKUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQ\nS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0k\nVc6gl6TKGfSSVDmDXpIqt2zQR8RnIuJ0RDzd1nZlRDwWEd8sP68o7RERn4qIuYh4KiKu7WfxkqTl\nNbmivw/Ydl7bHuBwZm4BDpd9gBuBLeW2G/h0b8qUJHVr2aDPzL8AzpzXvAPYV7b3ATe1td+fLY8D\n4xGxrlfFSpI61+0a/URmvli2vwtMlO31wAtt406WNknSkERmLj8oYhNwKDPfVfZfzszxtv6zmXlF\nRBwC7srML5X2w8CdmfnEBR5zN63lHSYmJrbOzMx0dQCnzyxw6tWu7rpi16xfu2Tf4uIiY2NjA6ym\nGevqzKjOLxjdc3Yp1nV8fmHA1bxu89pVXZ+v6enpY5k5udy41V09OpyKiHWZ+WJZmjld2ueBjW3j\nNpS2N8nMvcBegMnJyZyamuqqkHv2H+Du490exsqcuHVqyb7Z2Vm6PaZ+sq7OjOr8gtE9Z5diXbft\neWSwxbS5b9uavp+vbpduDgI7y/ZO4EBb+0fLX99cDyy0LfFIkoZg2UuViPgsMAVcFREngd8F7gIe\njIhdwHeAm8vwR4HtwBzwCnB7H2qWJHVg2aDPzI8s0XXDBcYmcMdKi5Ik9Y6fjJWkyhn0klQ5g16S\nKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJaly\nBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFWuL0EfEdsi4hsRMRcRe/rx\nHJKkZnoe9BGxCvjvwI3A1cBHIuLqXj+PJKmZflzRXwfMZebzmfn3wAywow/PI0lqoB9Bvx54oW3/\nZGmTJA3B6mE9cUTsBnaX3cWI+EaXD3UV8L3eVNWZ+PhFu4dW1zKsqzOjOr/Ac9apkaxr+uMrqusn\nmwzqR9DPAxvb9jeUtjfIzL3A3pU+WUQ8kZmTK32cXrOuzlhX50a1NuvqzCDq6sfSzV8CWyJic0Rc\nDtwCHOzD80iSGuj5FX1mnouIfw/8b2AV8JnMfKbXzyNJaqYva/SZ+SjwaD8e+wJWvPzTJ9bVGevq\n3KjWZl2d6XtdkZn9fg5J0hD5FQiSVLmRDfqI+ExEnI6Ip5foj4j4VPmahaci4tq2vp0R8c1y2zng\num4t9RyPiC9HxLvb+k6U9icj4okB1zUVEQvluZ+MiN9p6+vbV1Y0qOs/tdX0dET8MCKuLH39PF8b\nI+JIRDwbEc9ExMcuMGbgc6xhXQOfYw3rGvgca1jXwOdYRLw1Ir4SEV8rdf3eBca8JSIeKOfkaERs\nauv7zdL+jYj4+RUXlJkjeQN+DrgWeHqJ/u3A54EArgeOlvYrgefL
zyvK9hUDrOtnX3s+Wl8DcbSt\n7wRw1ZDO1xRw6ALtq4BvAf8CuBz4GnD1oOo6b+wvAl8c0PlaB1xbtt8G/P/zj3sYc6xhXQOfYw3r\nGvgca1LXMOZYmTNjZfsy4Chw/XljfhX4o7J9C/BA2b66nKO3AJvLuVu1knpG9oo+M/8COHORITuA\n+7PlcWA8ItYBPw88lplnMvMs8BiwbVB1ZeaXy/MCPE7rcwR91+B8LaWvX1nRYV0fAT7bq+e+mMx8\nMTO/Wra/DzzHmz/BPfA51qSuYcyxhudrKX2bY13UNZA5VubMYtm9rNzOf0N0B7CvbD8E3BARUdpn\nMvPvMvPbwBytc9i1kQ36Bpb6qoVR+gqGXbSuCF+TwJ9HxLFofTJ40N5XXkp+PiLeWdpG4nxFxD+h\nFZZ/1tY8kPNVXjK/l9ZVV7uhzrGL1NVu4HNsmbqGNseWO1+DnmMRsSoingRO07owWHJ+ZeY5YAF4\nO304X0P7CoTaRcQ0rf8JP9DW/IHMnI+InwAei4ivlyveQfgq8JOZuRgR24H/BWwZ0HM38YvA/8vM\n9qv/vp+viBij9T/+r2fm3/TysVeiSV3DmGPL1DW0Odbwv+NA51hm/hB4T0SMAw9HxLsy84LvVfXb\npXxFv9RXLTT6CoZ+ioifBv4Y2JGZL73Wnpnz5edp4GFW+HKsE5n5N6+9lMzW5xwui4irGIHzVdzC\neS+p+32+IuIyWuGwPzM/d4EhQ5ljDeoayhxbrq5hzbEm56sY+Bwrj/0ycIQ3L+/96LxExGpgLfAS\n/ThfvXwDotc3YBNLv7n4Id74RtlXSvuVwLdpvUl2Rdm+coB1/XNaa2o/e177GuBtbdtfBrYNsK5/\nxuufm7gO+Oty7lbTejNxM6+/UfbOQdVV+tfSWsdfM6jzVY79fuAPLzJm4HOsYV0Dn2MN6xr4HGtS\n1zDmGPAOYLxs/zjwf4FfOG/MHbzxzdgHy/Y7eeObsc+zwjdjR3bpJiI+S+td/Ksi4iTwu7Te0CAz\n/4jWJ2+305rwrwC3l74zEfEHtL5zB+D3840v1fpd1+/QWmf7H633VTiXrS8smqD18g1aE/9PM/ML\nA6zrw8CvRMQ54FXglmzNqr5+ZUWDugD+HfDnmfmDtrv29XwB7wd+GThe1lEBfotWiA5zjjWpaxhz\nrEldw5hjTeqCwc+xdcC+aP1DTD9GK8QPRcTvA09k5kHgXuBPImKO1i+hW0rNz0TEg8CzwDngjmwt\nA3XNT8ZKUuUu5TV6SVIDBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZX7RwDj7+S1+Ez4\nAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x106e9ef60>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "df.contraceptive.hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Age                int64\n",
       "Education          int64\n",
       "H_education        int64\n",
       "num_child          int64\n",
       "Religion           int64\n",
       "Employ             int64\n",
       "H_occupation       int64\n",
       "living_standard    int64\n",
       "Media_exposure     int64\n",
       "contraceptive      int64\n",
       "dtype: object"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.dtypes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def one_hot_encoding(idx):\n",
    "    y = np.zeros((len(idx),max(idx)+1))\n",
    "    y[np.arange(len(idx)), idx] = 1\n",
    "    return y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "scaler = StandardScaler()\n",
    "df[['Age','num_child']] = scaler.fit_transform(df[['Age','num_child']]) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "x = df[['Age','num_child','Employ','Media_exposure']].values\n",
    "y = one_hot_encoding(df.contraceptive.values-1)\n",
    "\n",
    "liv_cats = df.living_standard.max()\n",
    "edu_cats = df.Education.max()\n",
    "\n",
    "liv = df.living_standard.values - 1\n",
    "liv_one_hot = one_hot_encoding(liv)\n",
    "edu = df.Education.values - 1\n",
    "edu_one_hot = one_hot_encoding(edu)\n",
    "\n",
    "train_x, test_x, train_liv, \\\n",
    "test_liv, train_edu, test_edu, train_y, test_y = train_test_split(x,liv_one_hot,edu_one_hot,y,test_size=0.1, random_state=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_x = np.hstack([train_x, train_edu, train_liv])\n",
    "test_x = np.hstack([test_x, test_edu, test_liv])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1325, 12)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1325, 4)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_edu.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1325, 4)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_liv.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1325, 12)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:2: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=12, input_dim=12)`\n",
      "  from ipykernel import kernelapp as app\n",
      "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:4: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n",
      "/Users/sachin/anaconda/lib/python3.5/site-packages/keras/models.py:826: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
      "  warnings.warn('The `nb_epoch` argument in `fit` '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "0s - loss: 1.1002 - acc: 0.3894\n",
      "Epoch 2/100\n",
      "0s - loss: 1.0557 - acc: 0.4158\n",
      "Epoch 3/100\n",
      "0s - loss: 1.0371 - acc: 0.4370\n",
      "Epoch 4/100\n",
      "0s - loss: 1.0242 - acc: 0.4649\n",
      "Epoch 5/100\n",
      "0s - loss: 1.0140 - acc: 0.4762\n",
      "Epoch 6/100\n",
      "0s - loss: 1.0059 - acc: 0.4891\n",
      "Epoch 7/100\n",
      "0s - loss: 0.9989 - acc: 0.4913\n",
      "Epoch 8/100\n",
      "0s - loss: 0.9930 - acc: 0.4936\n",
      "Epoch 9/100\n",
      "0s - loss: 0.9877 - acc: 0.5034\n",
      "Epoch 10/100\n",
      "0s - loss: 0.9827 - acc: 0.5072\n",
      "Epoch 11/100\n",
      "0s - loss: 0.9781 - acc: 0.5147\n",
      "Epoch 12/100\n",
      "0s - loss: 0.9742 - acc: 0.5155\n",
      "Epoch 13/100\n",
      "0s - loss: 0.9706 - acc: 0.5185\n",
      "Epoch 14/100\n",
      "0s - loss: 0.9671 - acc: 0.5253\n",
      "Epoch 15/100\n",
      "0s - loss: 0.9642 - acc: 0.5230\n",
      "Epoch 16/100\n",
      "0s - loss: 0.9613 - acc: 0.5260\n",
      "Epoch 17/100\n",
      "0s - loss: 0.9587 - acc: 0.5283\n",
      "Epoch 18/100\n",
      "0s - loss: 0.9563 - acc: 0.5260\n",
      "Epoch 19/100\n",
      "0s - loss: 0.9539 - acc: 0.5313\n",
      "Epoch 20/100\n",
      "0s - loss: 0.9518 - acc: 0.5313\n",
      "Epoch 21/100\n",
      "0s - loss: 0.9497 - acc: 0.5343\n",
      "Epoch 22/100\n",
      "0s - loss: 0.9479 - acc: 0.5381\n",
      "Epoch 23/100\n",
      "0s - loss: 0.9461 - acc: 0.5389\n",
      "Epoch 24/100\n",
      "0s - loss: 0.9443 - acc: 0.5449\n",
      "Epoch 25/100\n",
      "0s - loss: 0.9427 - acc: 0.5419\n",
      "Epoch 26/100\n",
      "0s - loss: 0.9413 - acc: 0.5426\n",
      "Epoch 27/100\n",
      "0s - loss: 0.9397 - acc: 0.5457\n",
      "Epoch 28/100\n",
      "0s - loss: 0.9383 - acc: 0.5449\n",
      "Epoch 29/100\n",
      "0s - loss: 0.9370 - acc: 0.5442\n",
      "Epoch 30/100\n",
      "0s - loss: 0.9357 - acc: 0.5442\n",
      "Epoch 31/100\n",
      "0s - loss: 0.9344 - acc: 0.5449\n",
      "Epoch 32/100\n",
      "0s - loss: 0.9332 - acc: 0.5487\n",
      "Epoch 33/100\n",
      "0s - loss: 0.9321 - acc: 0.5442\n",
      "Epoch 34/100\n",
      "0s - loss: 0.9309 - acc: 0.5472\n",
      "Epoch 35/100\n",
      "0s - loss: 0.9299 - acc: 0.5479\n",
      "Epoch 36/100\n",
      "0s - loss: 0.9289 - acc: 0.5532\n",
      "Epoch 37/100\n",
      "0s - loss: 0.9279 - acc: 0.5540\n",
      "Epoch 38/100\n",
      "0s - loss: 0.9269 - acc: 0.5517\n",
      "Epoch 39/100\n",
      "0s - loss: 0.9260 - acc: 0.5502\n",
      "Epoch 40/100\n",
      "0s - loss: 0.9251 - acc: 0.5525\n",
      "Epoch 41/100\n",
      "0s - loss: 0.9242 - acc: 0.5540\n",
      "Epoch 42/100\n",
      "0s - loss: 0.9234 - acc: 0.5540\n",
      "Epoch 43/100\n",
      "0s - loss: 0.9225 - acc: 0.5517\n",
      "Epoch 44/100\n",
      "0s - loss: 0.9217 - acc: 0.5540\n",
      "Epoch 45/100\n",
      "0s - loss: 0.9209 - acc: 0.5555\n",
      "Epoch 46/100\n",
      "0s - loss: 0.9202 - acc: 0.5570\n",
      "Epoch 47/100\n",
      "0s - loss: 0.9194 - acc: 0.5570\n",
      "Epoch 48/100\n",
      "0s - loss: 0.9186 - acc: 0.5585\n",
      "Epoch 49/100\n",
      "0s - loss: 0.9179 - acc: 0.5592\n",
      "Epoch 50/100\n",
      "0s - loss: 0.9171 - acc: 0.5600\n",
      "Epoch 51/100\n",
      "0s - loss: 0.9164 - acc: 0.5608\n",
      "Epoch 52/100\n",
      "0s - loss: 0.9158 - acc: 0.5608\n",
      "Epoch 53/100\n",
      "0s - loss: 0.9150 - acc: 0.5600\n",
      "Epoch 54/100\n",
      "0s - loss: 0.9144 - acc: 0.5615\n",
      "Epoch 55/100\n",
      "0s - loss: 0.9137 - acc: 0.5585\n",
      "Epoch 56/100\n",
      "0s - loss: 0.9131 - acc: 0.5592\n",
      "Epoch 57/100\n",
      "0s - loss: 0.9126 - acc: 0.5570\n",
      "Epoch 58/100\n",
      "0s - loss: 0.9119 - acc: 0.5562\n",
      "Epoch 59/100\n",
      "0s - loss: 0.9114 - acc: 0.5600\n",
      "Epoch 60/100\n",
      "0s - loss: 0.9108 - acc: 0.5600\n",
      "Epoch 61/100\n",
      "0s - loss: 0.9102 - acc: 0.5608\n",
      "Epoch 62/100\n",
      "0s - loss: 0.9097 - acc: 0.5600\n",
      "Epoch 63/100\n",
      "0s - loss: 0.9092 - acc: 0.5608\n",
      "Epoch 64/100\n",
      "0s - loss: 0.9086 - acc: 0.5608\n",
      "Epoch 65/100\n",
      "0s - loss: 0.9081 - acc: 0.5630\n",
      "Epoch 66/100\n",
      "0s - loss: 0.9077 - acc: 0.5623\n",
      "Epoch 67/100\n",
      "0s - loss: 0.9072 - acc: 0.5608\n",
      "Epoch 68/100\n",
      "0s - loss: 0.9068 - acc: 0.5660\n",
      "Epoch 69/100\n",
      "0s - loss: 0.9064 - acc: 0.5645\n",
      "Epoch 70/100\n",
      "0s - loss: 0.9059 - acc: 0.5660\n",
      "Epoch 71/100\n",
      "0s - loss: 0.9055 - acc: 0.5645\n",
      "Epoch 72/100\n",
      "0s - loss: 0.9051 - acc: 0.5645\n",
      "Epoch 73/100\n",
      "0s - loss: 0.9047 - acc: 0.5660\n",
      "Epoch 74/100\n",
      "0s - loss: 0.9044 - acc: 0.5660\n",
      "Epoch 75/100\n",
      "0s - loss: 0.9040 - acc: 0.5645\n",
      "Epoch 76/100\n",
      "0s - loss: 0.9036 - acc: 0.5691\n",
      "Epoch 77/100\n",
      "0s - loss: 0.9032 - acc: 0.5706\n",
      "Epoch 78/100\n",
      "0s - loss: 0.9029 - acc: 0.5683\n",
      "Epoch 79/100\n",
      "0s - loss: 0.9025 - acc: 0.5698\n",
      "Epoch 80/100\n",
      "0s - loss: 0.9022 - acc: 0.5698\n",
      "Epoch 81/100\n",
      "0s - loss: 0.9018 - acc: 0.5668\n",
      "Epoch 82/100\n",
      "0s - loss: 0.9015 - acc: 0.5691\n",
      "Epoch 83/100\n",
      "0s - loss: 0.9012 - acc: 0.5713\n",
      "Epoch 84/100\n",
      "0s - loss: 0.9008 - acc: 0.5728\n",
      "Epoch 85/100\n",
      "0s - loss: 0.9005 - acc: 0.5691\n",
      "Epoch 86/100\n",
      "0s - loss: 0.9001 - acc: 0.5728\n",
      "Epoch 87/100\n",
      "0s - loss: 0.8999 - acc: 0.5683\n",
      "Epoch 88/100\n",
      "0s - loss: 0.8996 - acc: 0.5691\n",
      "Epoch 89/100\n",
      "0s - loss: 0.8993 - acc: 0.5713\n",
      "Epoch 90/100\n",
      "0s - loss: 0.8989 - acc: 0.5721\n",
      "Epoch 91/100\n",
      "0s - loss: 0.8987 - acc: 0.5691\n",
      "Epoch 92/100\n",
      "0s - loss: 0.8984 - acc: 0.5675\n",
      "Epoch 93/100\n",
      "0s - loss: 0.8981 - acc: 0.5691\n",
      "Epoch 94/100\n",
      "0s - loss: 0.8978 - acc: 0.5691\n",
      "Epoch 95/100\n",
      "0s - loss: 0.8975 - acc: 0.5698\n",
      "Epoch 96/100\n",
      "0s - loss: 0.8972 - acc: 0.5698\n",
      "Epoch 97/100\n",
      "0s - loss: 0.8970 - acc: 0.5713\n",
      "Epoch 98/100\n",
      "0s - loss: 0.8967 - acc: 0.5706\n",
      "Epoch 99/100\n",
      "0s - loss: 0.8964 - acc: 0.5691\n",
      "Epoch 100/100\n",
      "0s - loss: 0.8962 - acc: 0.5721\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1212fee10>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Dense(input_dim=train_x.shape[1],output_dim=12))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dense(output_dim=3))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])\n",
    "model.fit(train_x, train_y, nb_epoch=100, verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_1 (Dense)              (None, 12)                156       \n",
      "_________________________________________________________________\n",
      "activation_1 (Activation)    (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 3)                 39        \n",
      "_________________________________________________________________\n",
      "activation_2 (Activation)    (None, 3)                 0         \n",
      "=================================================================\n",
      "Total params: 195.0\n",
      "Trainable params: 195.0\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(12, 12)\n",
      "(12,)\n",
      "(12, 3)\n",
      "(3,)\n"
     ]
    }
   ],
   "source": [
    "for w in model.get_weights():\n",
    "    print(w.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "148/148 [==============================] - 0s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.85758495330810547, 0.587837815284729]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate(test_x, test_y, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.39239073,  0.2203065 ,  0.38730279],\n",
       "       [ 0.82135081,  0.10566427,  0.07298491],\n",
       "       [ 0.25139767,  0.17943899,  0.56916332],\n",
       "       [ 0.3676849 ,  0.33580911,  0.29650599],\n",
       "       [ 0.75309271,  0.13132168,  0.11558564],\n",
       "       [ 0.16729502,  0.54943871,  0.28326628],\n",
       "       [ 0.18573713,  0.45595431,  0.35830855],\n",
       "       [ 0.8188768 ,  0.10733887,  0.0737843 ],\n",
       "       [ 0.73907691,  0.04200678,  0.21891631],\n",
       "       [ 0.64818466,  0.11329354,  0.2385218 ]], dtype=float32)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(test_x[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 3, 3, ..., 3, 1, 3])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "liv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_x, test_x, train_liv, \\\n",
    "test_liv, train_edu, test_edu, train_y, test_y = train_test_split(x,liv,edu,y,test_size=0.1, random_state=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:16: UserWarning: The `Merge` layer is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. `add`, `concatenate`, etc.\n",
      "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:18: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=12)`\n",
      "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:20: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n"
     ]
    }
   ],
   "source": [
    "# Input layer for religion\n",
    "encoder_liv = Sequential()\n",
    "encoder_liv.add(Embedding(liv_cats,4,input_length=1))\n",
    "encoder_liv.add(Flatten())\n",
    "\n",
    "# Input layer for religion\n",
    "encoder_edu = Sequential()\n",
    "encoder_edu.add(Embedding(edu_cats,4,input_length=1))\n",
    "encoder_edu.add(Flatten())\n",
    "\n",
    "# Input layer for triggers(x_b)\n",
    "dense_x = Sequential()\n",
    "dense_x.add(Dense(4, input_dim=x.shape[1]))\n",
    "\n",
    "model = Sequential()\n",
    "model.add(Merge([encoder_liv, encoder_edu, dense_x], mode='concat'))\n",
    "# model.add(Activation('relu'))\n",
    "model.add(Dense(output_dim=12))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dense(output_dim=3))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sachin/anaconda/lib/python3.5/site-packages/keras/models.py:826: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
      "  warnings.warn('The `nb_epoch` argument in `fit` '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "0s - loss: 1.1477 - acc: 0.3442\n",
      "Epoch 2/100\n",
      "0s - loss: 1.0408 - acc: 0.4242\n",
      "Epoch 3/100\n",
      "0s - loss: 1.0129 - acc: 0.4679\n",
      "Epoch 4/100\n",
      "0s - loss: 0.9960 - acc: 0.4755\n",
      "Epoch 5/100\n",
      "0s - loss: 0.9841 - acc: 0.4906\n",
      "Epoch 6/100\n",
      "0s - loss: 0.9750 - acc: 0.4981\n",
      "Epoch 7/100\n",
      "0s - loss: 0.9672 - acc: 0.5049\n",
      "Epoch 8/100\n",
      "0s - loss: 0.9604 - acc: 0.5034\n",
      "Epoch 9/100\n",
      "0s - loss: 0.9540 - acc: 0.5140\n",
      "Epoch 10/100\n",
      "0s - loss: 0.9490 - acc: 0.5223\n",
      "Epoch 11/100\n",
      "0s - loss: 0.9446 - acc: 0.5260\n",
      "Epoch 12/100\n",
      "0s - loss: 0.9408 - acc: 0.5260\n",
      "Epoch 13/100\n",
      "0s - loss: 0.9377 - acc: 0.5253\n",
      "Epoch 14/100\n",
      "0s - loss: 0.9350 - acc: 0.5374\n",
      "Epoch 15/100\n",
      "0s - loss: 0.9326 - acc: 0.5358\n",
      "Epoch 16/100\n",
      "0s - loss: 0.9301 - acc: 0.5374\n",
      "Epoch 17/100\n",
      "0s - loss: 0.9283 - acc: 0.5336\n",
      "Epoch 18/100\n",
      "0s - loss: 0.9266 - acc: 0.5404\n",
      "Epoch 19/100\n",
      "0s - loss: 0.9249 - acc: 0.5404\n",
      "Epoch 20/100\n",
      "0s - loss: 0.9236 - acc: 0.5434\n",
      "Epoch 21/100\n",
      "0s - loss: 0.9223 - acc: 0.5419\n",
      "Epoch 22/100\n",
      "0s - loss: 0.9213 - acc: 0.5457\n",
      "Epoch 23/100\n",
      "0s - loss: 0.9201 - acc: 0.5457\n",
      "Epoch 24/100\n",
      "0s - loss: 0.9190 - acc: 0.5464\n",
      "Epoch 25/100\n",
      "0s - loss: 0.9181 - acc: 0.5479\n",
      "Epoch 26/100\n",
      "0s - loss: 0.9172 - acc: 0.5464\n",
      "Epoch 27/100\n",
      "0s - loss: 0.9165 - acc: 0.5472\n",
      "Epoch 28/100\n",
      "0s - loss: 0.9157 - acc: 0.5547\n",
      "Epoch 29/100\n",
      "0s - loss: 0.9151 - acc: 0.5509\n",
      "Epoch 30/100\n",
      "0s - loss: 0.9143 - acc: 0.5562\n",
      "Epoch 31/100\n",
      "0s - loss: 0.9137 - acc: 0.5577\n",
      "Epoch 32/100\n",
      "0s - loss: 0.9132 - acc: 0.5592\n",
      "Epoch 33/100\n",
      "0s - loss: 0.9125 - acc: 0.5540\n",
      "Epoch 34/100\n",
      "0s - loss: 0.9121 - acc: 0.5555\n",
      "Epoch 35/100\n",
      "0s - loss: 0.9115 - acc: 0.5555\n",
      "Epoch 36/100\n",
      "0s - loss: 0.9111 - acc: 0.5592\n",
      "Epoch 37/100\n",
      "0s - loss: 0.9105 - acc: 0.5600\n",
      "Epoch 38/100\n",
      "0s - loss: 0.9101 - acc: 0.5592\n",
      "Epoch 39/100\n",
      "0s - loss: 0.9096 - acc: 0.5630\n",
      "Epoch 40/100\n",
      "0s - loss: 0.9093 - acc: 0.5592\n",
      "Epoch 41/100\n",
      "0s - loss: 0.9088 - acc: 0.5600\n",
      "Epoch 42/100\n",
      "0s - loss: 0.9084 - acc: 0.5623\n",
      "Epoch 43/100\n",
      "0s - loss: 0.9081 - acc: 0.5600\n",
      "Epoch 44/100\n",
      "0s - loss: 0.9077 - acc: 0.5562\n",
      "Epoch 45/100\n",
      "0s - loss: 0.9075 - acc: 0.5585\n",
      "Epoch 46/100\n",
      "0s - loss: 0.9071 - acc: 0.5570\n",
      "Epoch 47/100\n",
      "0s - loss: 0.9068 - acc: 0.5585\n",
      "Epoch 48/100\n",
      "0s - loss: 0.9064 - acc: 0.5600\n",
      "Epoch 49/100\n",
      "0s - loss: 0.9061 - acc: 0.5623\n",
      "Epoch 50/100\n",
      "0s - loss: 0.9060 - acc: 0.5600\n",
      "Epoch 51/100\n",
      "0s - loss: 0.9056 - acc: 0.5592\n",
      "Epoch 52/100\n",
      "0s - loss: 0.9053 - acc: 0.5577\n",
      "Epoch 53/100\n",
      "0s - loss: 0.9050 - acc: 0.5577\n",
      "Epoch 54/100\n",
      "0s - loss: 0.9048 - acc: 0.5577\n",
      "Epoch 55/100\n",
      "0s - loss: 0.9045 - acc: 0.5585\n",
      "Epoch 56/100\n",
      "0s - loss: 0.9043 - acc: 0.5608\n",
      "Epoch 57/100\n",
      "0s - loss: 0.9040 - acc: 0.5547\n",
      "Epoch 58/100\n",
      "0s - loss: 0.9037 - acc: 0.5600\n",
      "Epoch 59/100\n",
      "0s - loss: 0.9035 - acc: 0.5608\n",
      "Epoch 60/100\n",
      "0s - loss: 0.9032 - acc: 0.5555\n",
      "Epoch 61/100\n",
      "0s - loss: 0.9030 - acc: 0.5600\n",
      "Epoch 62/100\n",
      "0s - loss: 0.9030 - acc: 0.5585\n",
      "Epoch 63/100\n",
      "0s - loss: 0.9026 - acc: 0.5608\n",
      "Epoch 64/100\n",
      "0s - loss: 0.9023 - acc: 0.5585\n",
      "Epoch 65/100\n",
      "0s - loss: 0.9021 - acc: 0.5592\n",
      "Epoch 66/100\n",
      "0s - loss: 0.9019 - acc: 0.5592\n",
      "Epoch 67/100\n",
      "0s - loss: 0.9017 - acc: 0.5555\n",
      "Epoch 68/100\n",
      "0s - loss: 0.9016 - acc: 0.5570\n",
      "Epoch 69/100\n",
      "0s - loss: 0.9012 - acc: 0.5577\n",
      "Epoch 70/100\n",
      "0s - loss: 0.9011 - acc: 0.5615\n",
      "Epoch 71/100\n",
      "0s - loss: 0.9009 - acc: 0.5585\n",
      "Epoch 72/100\n",
      "0s - loss: 0.9007 - acc: 0.5608\n",
      "Epoch 73/100\n",
      "0s - loss: 0.9005 - acc: 0.5585\n",
      "Epoch 74/100\n",
      "0s - loss: 0.9003 - acc: 0.5577\n",
      "Epoch 75/100\n",
      "0s - loss: 0.9001 - acc: 0.5562\n",
      "Epoch 76/100\n",
      "0s - loss: 0.8999 - acc: 0.5562\n",
      "Epoch 77/100\n",
      "0s - loss: 0.8998 - acc: 0.5562\n",
      "Epoch 78/100\n",
      "0s - loss: 0.8996 - acc: 0.5555\n",
      "Epoch 79/100\n",
      "0s - loss: 0.8993 - acc: 0.5592\n",
      "Epoch 80/100\n",
      "0s - loss: 0.8992 - acc: 0.5570\n",
      "Epoch 81/100\n",
      "0s - loss: 0.8990 - acc: 0.5577\n",
      "Epoch 82/100\n",
      "0s - loss: 0.8988 - acc: 0.5592\n",
      "Epoch 83/100\n",
      "0s - loss: 0.8987 - acc: 0.5585\n",
      "Epoch 84/100\n",
      "0s - loss: 0.8986 - acc: 0.5585\n",
      "Epoch 85/100\n",
      "0s - loss: 0.8984 - acc: 0.5623\n",
      "Epoch 86/100\n",
      "0s - loss: 0.8983 - acc: 0.5608\n",
      "Epoch 87/100\n",
      "0s - loss: 0.8982 - acc: 0.5608\n",
      "Epoch 88/100\n",
      "0s - loss: 0.8979 - acc: 0.5630\n",
      "Epoch 89/100\n",
      "0s - loss: 0.8979 - acc: 0.5623\n",
      "Epoch 90/100\n",
      "0s - loss: 0.8977 - acc: 0.5630\n",
      "Epoch 91/100\n",
      "0s - loss: 0.8974 - acc: 0.5630\n",
      "Epoch 92/100\n",
      "0s - loss: 0.8973 - acc: 0.5615\n",
      "Epoch 93/100\n",
      "0s - loss: 0.8972 - acc: 0.5608\n",
      "Epoch 94/100\n",
      "0s - loss: 0.8971 - acc: 0.5615\n",
      "Epoch 95/100\n",
      "0s - loss: 0.8970 - acc: 0.5623\n",
      "Epoch 96/100\n",
      "0s - loss: 0.8968 - acc: 0.5608\n",
      "Epoch 97/100\n",
      "0s - loss: 0.8967 - acc: 0.5615\n",
      "Epoch 98/100\n",
      "0s - loss: 0.8966 - acc: 0.5600\n",
      "Epoch 99/100\n",
      "0s - loss: 0.8964 - acc: 0.5608\n",
      "Epoch 100/100\n",
      "0s - loss: 0.8964 - acc: 0.5600\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x121d8ddd8>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit([train_liv[:,None], train_edu[:,None], train_x], train_y, nb_epoch=100, verbose=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "dense_3 (Dense)              (None, 4)                 20        \n",
      "=================================================================\n",
      "Total params: 20.0\n",
      "Trainable params: 20\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "dense_x.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (None, 1, 4)              16        \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 4)                 0         \n",
      "=================================================================\n",
      "Total params: 16.0\n",
      "Trainable params: 16.0\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "encoder_liv.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "merge_1 (Merge)              (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 12)                156       \n",
      "_________________________________________________________________\n",
      "activation_3 (Activation)    (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 3)                 39        \n",
      "_________________________________________________________________\n",
      "activation_4 (Activation)    (None, 3)                 0         \n",
      "=================================================================\n",
      "Total params: 247.0\n",
      "Trainable params: 247.0\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'list' object has no attribute 'shape'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-23-8e2f9d3764a7>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mw\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'shape'"
     ]
    }
   ],
   "source": [
    "for w in model.get_weights():\n",
    "    print(w.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[-0.10935115, -0.23031998,  0.23564951,  0.3424432 ],\n",
       "        [-0.06168354, -0.05825301,  0.111118  ,  0.16586818],\n",
       "        [-0.15200721, -0.0934339 , -0.10459773,  0.00161025],\n",
       "        [-0.04321436,  0.05898349, -0.01321368, -0.08363315]], dtype=float32)]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[array([[-0.10935115, -0.23031998,  0.23564951,  0.3424432 ],\n",
       "         [-0.06168354, -0.05825301,  0.111118  ,  0.16586818],\n",
       "         [-0.15200721, -0.0934339 , -0.10459773,  0.00161025],\n",
       "         [-0.04321436,  0.05898349, -0.01321368, -0.08363315]], dtype=float32)],\n",
       " [],\n",
       " [array([[ 0.3322866 ,  0.19930699, -0.24251796,  0.24181931],\n",
       "         [ 0.19793722,  0.12741166, -0.18478565,  0.16072831],\n",
       "         [ 0.00148587, -0.0822821 , -0.00326628,  0.04515119],\n",
       "         [-0.09646862, -0.2633568 ,  0.16113301, -0.19863592]], dtype=float32)],\n",
       " [],\n",
       " [array([[-0.59816033,  0.87062579, -0.2976734 ,  0.08506605],\n",
       "         [ 0.28670722,  0.01459732, -0.67885333,  0.91728342],\n",
       "         [-0.72324103, -0.32768196,  0.8207261 ,  0.69913387],\n",
       "         [-0.20462862,  0.6349501 , -0.34206402,  0.06973144]], dtype=float32),\n",
       "  array([ 0.08511487, -0.05479484,  0.09653779,  0.01787456], dtype=float32)],\n",
       " [array([[ 0.54720604,  0.20560427,  0.38046002,  0.06325047, -0.53751922,\n",
       "          -0.06238003, -0.07022222,  0.33957773,  0.19087823, -0.3943564 ,\n",
       "           0.07500011,  0.28335726],\n",
       "         [ 0.66748869,  0.43769655,  0.27608281, -0.03321984, -0.43980169,\n",
       "          -0.39402112,  0.41173798,  0.4726226 , -0.10022564, -0.39347824,\n",
       "          -0.65959972, -0.18946621],\n",
       "         [-0.23607142,  0.28477719, -0.28498721,  0.41613233,  0.15728261,\n",
       "          -0.17061651, -0.1701867 , -0.03094537, -0.45825756,  0.35427347,\n",
       "          -0.20373616,  0.4309096 ],\n",
       "         [-0.68063968,  0.42256722,  0.08898511,  0.44825551,  0.745767  ,\n",
       "          -0.15242647, -0.27827993, -0.52321166, -0.14772928,  0.31725135,\n",
       "           0.41540977,  0.20260219],\n",
       "         [ 0.11326507,  0.29595992, -0.66411144,  0.03979665,  0.78415918,\n",
       "          -0.12158706, -0.40152088, -0.05005921, -0.10276611,  0.52777141,\n",
       "           0.22086678,  0.61187786],\n",
       "         [-0.81348175, -0.54370284, -0.06375849,  0.55344445, -0.18702853,\n",
       "          -0.17707211, -0.28065667, -0.51284409, -0.58049625,  0.61143988,\n",
       "          -0.22997195,  0.21264738],\n",
       "         [ 0.52765691,  0.48952475,  0.47147927, -0.46761283, -0.35750964,\n",
       "           0.13695255,  0.52752954,  0.71626884,  0.65563965, -0.7301867 ,\n",
       "          -0.139073  ,  0.04272588],\n",
       "         [-0.51113981, -0.36798111, -0.55281842, -0.30524713,  0.39200947,\n",
       "           0.44922075, -0.4600637 , -0.09046047, -0.42538816,  0.43586993,\n",
       "          -0.00330608,  0.40157044],\n",
       "         [ 0.07566521,  0.77746761, -0.05694147, -0.00362857, -0.10556948,\n",
       "           0.01500559, -0.02078186, -0.28349164, -0.32414779,  0.51029384,\n",
       "          -0.64134562, -0.24186224],\n",
       "         [ 0.08324727,  0.33779496,  0.23709755, -0.3381888 ,  0.16273189,\n",
       "           0.47784749,  0.04162871,  0.30976474, -0.31543013, -0.28823802,\n",
       "           0.62342113,  0.42441621],\n",
       "         [-0.28952643,  0.35541049, -0.15565668,  0.2239645 ,  0.46108943,\n",
       "           0.58152092,  0.39327046,  0.1154022 ,  0.09159836,  0.26722667,\n",
       "          -0.09292367,  0.15957126],\n",
       "         [-0.01987584, -0.06251051,  0.0218032 , -0.24661671, -0.08240297,\n",
       "          -0.84416002,  0.33008894, -0.28735343,  0.23902059,  0.50154585,\n",
       "           0.0804465 ,  0.17921968]], dtype=float32),\n",
       "  array([ -1.18326025e-04,   1.35194603e-02,  -4.79373112e-02,\n",
       "          -2.96827797e-02,   1.12170853e-01,   5.46757542e-02,\n",
       "           4.18406017e-02,  -1.61119565e-01,   3.99559699e-02,\n",
       "           1.43326208e-01,   5.45555726e-03,  -3.65096107e-02], dtype=float32)],\n",
       " [],\n",
       " [array([[-0.60570943,  0.19913501,  0.43104902],\n",
       "         [-0.37961754, -0.03606199,  0.33234817],\n",
       "         [-0.05652641,  0.38592601,  0.12417495],\n",
       "         [ 0.49130049, -0.45519802, -0.04951563],\n",
       "         [ 0.11304783, -0.82271665,  0.11088244],\n",
       "         [ 0.85395575, -0.00952558, -0.67157847],\n",
       "         [-0.13419652,  0.45367789,  0.18071729],\n",
       "         [-0.14250514, -0.29016131, -0.04574362],\n",
       "         [-0.43520772,  0.43990734,  0.22246923],\n",
       "         [-0.57091314, -0.7859264 , -0.24046065],\n",
       "         [ 0.02203327, -0.20768185, -0.85033274],\n",
       "         [ 0.2845436 , -0.80705798, -0.51982474]], dtype=float32),\n",
       "  array([-0.08315417, -0.05141847,  0.17288126], dtype=float32)],\n",
       " []]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = model.get_weights()\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "148/148 [==============================] - 0s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.86288201808929443, 0.60810810327529907]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.evaluate([test_liv[:,None], test_edu[:,None], test_x],test_y, batch_size=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.29773653,  0.26741281,  0.43485063],\n",
       "       [ 0.77942479,  0.11813645,  0.10243875],\n",
       "       [ 0.18719749,  0.21930861,  0.59349388],\n",
       "       [ 0.37331343,  0.36969444,  0.25699213],\n",
       "       [ 0.69447452,  0.15733764,  0.14818783]], dtype=float32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p = model.predict([test_liv[:,None], test_edu[:,None], test_x], batch_size=256)\n",
    "p[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "merge_1 (Merge)              (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 12)                156       \n",
      "_________________________________________________________________\n",
      "activation_3 (Activation)    (None, 12)                0         \n",
      "_________________________________________________________________\n",
      "dense_5 (Dense)              (None, 3)                 39        \n",
      "_________________________________________________________________\n",
      "activation_4 (Activation)    (None, 3)                 0         \n",
      "=================================================================\n",
      "Total params: 247.0\n",
      "Trainable params: 247.0\n",
      "Non-trainable params: 0.0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:4: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n",
      "/Users/sachin/anaconda/lib/python3.5/site-packages/keras/models.py:826: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n",
      "  warnings.warn('The `nb_epoch` argument in `fit` '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6348 - acc: 0.6531     \n",
      "Epoch 2/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6224 - acc: 0.6644     \n",
      "Epoch 3/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6165 - acc: 0.6662     \n",
      "Epoch 4/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6129 - acc: 0.6694     \n",
      "Epoch 5/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6105 - acc: 0.6699     \n",
      "Epoch 6/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6086 - acc: 0.6699     \n",
      "Epoch 7/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6073 - acc: 0.6737     \n",
      "Epoch 8/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6062 - acc: 0.6704     \n",
      "Epoch 9/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6053 - acc: 0.6707     \n",
      "Epoch 10/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6045 - acc: 0.6707     \n",
      "Epoch 11/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6038 - acc: 0.6709     \n",
      "Epoch 12/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6032 - acc: 0.6722     \n",
      "Epoch 13/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6026 - acc: 0.6727     \n",
      "Epoch 14/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6020 - acc: 0.6719     \n",
      "Epoch 15/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6015 - acc: 0.6717     \n",
      "Epoch 16/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6010 - acc: 0.6735     \n",
      "Epoch 17/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6006 - acc: 0.6740     \n",
      "Epoch 18/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.6002 - acc: 0.6730     \n",
      "Epoch 19/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5998 - acc: 0.6752     \n",
      "Epoch 20/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5994 - acc: 0.6762     \n",
      "Epoch 21/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5990 - acc: 0.6770     \n",
      "Epoch 22/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5987 - acc: 0.6770     \n",
      "Epoch 23/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5983 - acc: 0.6777     \n",
      "Epoch 24/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5979 - acc: 0.6770     \n",
      "Epoch 25/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5976 - acc: 0.6785     \n",
      "Epoch 26/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5973 - acc: 0.6790     \n",
      "Epoch 27/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5970 - acc: 0.6792     \n",
      "Epoch 28/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5966 - acc: 0.6787     \n",
      "Epoch 29/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5963 - acc: 0.6800     \n",
      "Epoch 30/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5960 - acc: 0.6795     \n",
      "Epoch 31/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5957 - acc: 0.6790     \n",
      "Epoch 32/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5954 - acc: 0.6790     \n",
      "Epoch 33/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5951 - acc: 0.6790     \n",
      "Epoch 34/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5949 - acc: 0.6810     \n",
      "Epoch 35/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5946 - acc: 0.6792     \n",
      "Epoch 36/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5943 - acc: 0.6800     \n",
      "Epoch 37/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5941 - acc: 0.6818     \n",
      "Epoch 38/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5938 - acc: 0.6803     \n",
      "Epoch 39/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5935 - acc: 0.6813     \n",
      "Epoch 40/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5933 - acc: 0.6808     \n",
      "Epoch 41/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5930 - acc: 0.6813     \n",
      "Epoch 42/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5928 - acc: 0.6820     \n",
      "Epoch 43/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5926 - acc: 0.6835     \n",
      "Epoch 44/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5923 - acc: 0.6843     \n",
      "Epoch 45/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5921 - acc: 0.6840     \n",
      "Epoch 46/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5919 - acc: 0.6835     \n",
      "Epoch 47/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5917 - acc: 0.6848     \n",
      "Epoch 48/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5915 - acc: 0.6840     \n",
      "Epoch 49/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5913 - acc: 0.6848     \n",
      "Epoch 50/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5911 - acc: 0.6843     \n",
      "Epoch 51/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5909 - acc: 0.6853     \n",
      "Epoch 52/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5907 - acc: 0.6850     \n",
      "Epoch 53/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5905 - acc: 0.6850     \n",
      "Epoch 54/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5903 - acc: 0.6855     \n",
      "Epoch 55/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5901 - acc: 0.6853     \n",
      "Epoch 56/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5899 - acc: 0.6863     \n",
      "Epoch 57/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5897 - acc: 0.6858     \n",
      "Epoch 58/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5895 - acc: 0.6858     \n",
      "Epoch 59/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5894 - acc: 0.6870     \n",
      "Epoch 60/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5891 - acc: 0.6868     \n",
      "Epoch 61/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5890 - acc: 0.6865     \n",
      "Epoch 62/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5888 - acc: 0.6875     \n",
      "Epoch 63/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5886 - acc: 0.6881     \n",
      "Epoch 64/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5884 - acc: 0.6881     \n",
      "Epoch 65/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5882 - acc: 0.6883     \n",
      "Epoch 66/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5881 - acc: 0.6891     \n",
      "Epoch 67/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5879 - acc: 0.6893     \n",
      "Epoch 68/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5877 - acc: 0.6896     \n",
      "Epoch 69/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5875 - acc: 0.6891     \n",
      "Epoch 70/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5874 - acc: 0.6891     \n",
      "Epoch 71/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5872 - acc: 0.6896     \n",
      "Epoch 72/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5871 - acc: 0.6893     \n",
      "Epoch 73/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5869 - acc: 0.6893     \n",
      "Epoch 74/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5867 - acc: 0.6901     \n",
      "Epoch 75/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5866 - acc: 0.6901     \n",
      "Epoch 76/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5864 - acc: 0.6901     \n",
      "Epoch 77/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5863 - acc: 0.6901     \n",
      "Epoch 78/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5861 - acc: 0.6901     \n",
      "Epoch 79/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5859 - acc: 0.6903     \n",
      "Epoch 80/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5858 - acc: 0.6901     \n",
      "Epoch 81/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5857 - acc: 0.6901     \n",
      "Epoch 82/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5855 - acc: 0.6901     \n",
      "Epoch 83/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5854 - acc: 0.6896     \n",
      "Epoch 84/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5852 - acc: 0.6903     \n",
      "Epoch 85/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5851 - acc: 0.6901     \n",
      "Epoch 86/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5850 - acc: 0.6903     \n",
      "Epoch 87/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5849 - acc: 0.6903     \n",
      "Epoch 88/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5847 - acc: 0.6906     \n",
      "Epoch 89/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5846 - acc: 0.6906     \n",
      "Epoch 90/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5845 - acc: 0.6906     \n",
      "Epoch 91/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5844 - acc: 0.6911     \n",
      "Epoch 92/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5842 - acc: 0.6918     \n",
      "Epoch 93/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5841 - acc: 0.6911     \n",
      "Epoch 94/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5840 - acc: 0.6916     \n",
      "Epoch 95/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5838 - acc: 0.6906     \n",
      "Epoch 96/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5837 - acc: 0.6916     \n",
      "Epoch 97/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5836 - acc: 0.6908     \n",
      "Epoch 98/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5835 - acc: 0.6913     \n",
      "Epoch 99/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5833 - acc: 0.6906     \n",
      "Epoch 100/100\n",
      "1325/1325 [==============================] - 0s - loss: 0.5832 - acc: 0.6911     \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x1227b34a8>"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Dense(4, input_dim=train_x.shape[1]))\n",
    "model.add(Activation('relu'))\n",
    "model.add(Dense(output_dim=3))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "model.compile(optimizer='adagrad', loss='binary_crossentropy', metrics=['accuracy'])\n",
    "model.fit(train_x, train_y, nb_epoch=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "148/148 [==============================] - 0s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.56608289480209351, 0.70945948362350464]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Returns [loss, accuracy], in the order given by model.metrics_names\n",
    "model.evaluate(test_x, test_y, batch_size=256)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  },
  "latex_envs": {
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 0
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}