{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Embeddings\n", "\n", "https://www.youtube.com/watch?v=wSXGlvTR9UM" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "from keras.models import Sequential\n", "from keras.layers import Dense, Activation, Embedding, Merge, Flatten\n", "\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style>\n", " .dataframe thead tr:only-child th {\n", " text-align: right;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: left;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Age</th>\n", " <th>Education</th>\n", " <th>H_education</th>\n", " <th>num_child</th>\n", " <th>Religion</th>\n", " <th>Employ</th>\n", " <th>H_occupation</th>\n", " <th>living_standard</th>\n", " <th>Media_exposure</th>\n", " <th>contraceptive</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>24</td>\n", " <td>2</td>\n", " <td>3</td>\n", " <td>3</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>2</td>\n", " <td>3</td>\n", " <td>0</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>45</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>10</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>4</td>\n", " <td>0</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>43</td>\n", " <td>2</td>\n", " <td>3</td>\n", " <td>7</td>\n", " <td>1</td>\n", " 
<td>1</td>\n", " <td>3</td>\n", " <td>4</td>\n", " <td>0</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>42</td>\n", " <td>3</td>\n", " <td>2</td>\n", " <td>9</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>3</td>\n", " <td>0</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>36</td>\n", " <td>3</td>\n", " <td>3</td>\n", " <td>8</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>2</td>\n", " <td>0</td>\n", " <td>1</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Age Education H_education num_child Religion Employ H_occupation \\\n", "0 24 2 3 3 1 1 2 \n", "1 45 1 3 10 1 1 3 \n", "2 43 2 3 7 1 1 3 \n", "3 42 3 2 9 1 1 3 \n", "4 36 3 3 8 1 1 3 \n", "\n", " living_standard Media_exposure contraceptive \n", "0 3 0 1 \n", "1 4 0 1 \n", "2 4 0 1 \n", "3 3 0 1 \n", "4 2 0 1 " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('data/cmc.data',header=None,names=['Age','Education','H_education',\n", " 'num_child','Religion', 'Employ',\n", " 'H_occupation','living_standard',\n", " 'Media_exposure','contraceptive'])\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Age False\n", "Education False\n", "H_education False\n", "num_child False\n", "Religion False\n", "Employ False\n", "H_occupation False\n", "living_standard False\n", "Media_exposure False\n", "contraceptive False\n", "dtype: bool" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.isnull().any()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x12287b630>" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAE4RJREFUeJzt3X+sX3ddx/Hn23bA7MV2MLw2bbVLXEgmFVhvthKIuZdF\nUzayLnHgyNy6ZaaJDsQw4wqJGvwRyx+ATA3aMLIO0csywdVu0yxdr8gfG7Y418FAyihZb8oqW1e4\nbGqqb//4fsDL5bbf8/3e8+33ez8+H8nNPedzPud8P+/zuX3d0/P9cSMzkSTV60eGPQBJ0mAZ9JJU\nOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVa5R0EfEmoi4NyK+HBFPRsQbIuIVEfFQRHy1fL+g\n9I2IuCMijkTE4xFx6WBLkCSdTTR5Z2xE7AH+KTM/FhEvAX4UeB/wXGbuioidwAWZeXtEXAm8C7gS\nuBz4SGZefrbjX3jhhblx48a+Cvjud7/LqlWr+tp31FjL6KmlDrCWUbWUWg4dOvStzHxV146ZedYv\nYDXwdcovhXntXwHWluW1wFfK8l8A71is35m+Nm/enP06cOBA3/uOGmsZPbXUkWkto2optQAHs0uG\nZ2b3K/qIeB2wG/gS8FrgEPBuYDYz15Q+AZzMzDURsQ/YlZmfK9v2A7dn5sEFx90B7AAYHx/fPD09\n3fWX0mLm5uYYGxvra99RYy2jp5Y6wFpG1VJqmZqaOpSZE107dvtNAEwAp4HLy/pHgN8Hnl/Q72T5\nvg9407z2/cDE2R7DK/oOaxk9tdSRaS2j6lxc0Td5MvYYcCwzHy3r9wKXAs9ExFqA8v1E2T4LbJi3\n//rSJkkagq5Bn5nfBJ6OiFeXpivo3MbZC2wvbduB+8ryXuDG8uqbLcCpzDze7rAlSU2tbNjvXcAn\nyytungJupvNL4p6IuAX4BvD20vcBOq+4OQK8UPpKkoakUdBn5mN07tUvdMUifRO4dYnjkiS1xHfG\nSlLlDHpJqpxBL0mVa/pkrCRVa+PO+4f22HdtHfxHOXhFL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNe\nkipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWp\ncga9JFXOoJekyhn0klQ5g16SKmfQS1LlGgV9RByNiMMR8VhEHCxtr4iIhyLiq+X7BaU9IuKOiDgS\nEY9HxKWDLECSdHa9XNFPZebrMnOirO8E9mfmxcD+sg7wFuDi8rUD+Ghbg5Uk9W4pt262AXvK8h7g\nmnntd2fHI8CaiFi7hMeRJC1BZGb3ThFfB04CCfxFZu6OiOczc03ZHsDJzFwTEfuAXZn5ubJtP3B7\nZh5ccMwddK74GR8f3zw9Pd1XAXNzc4yNjfW176ixltFTSx1gLWdzePZUa8fq1UWrV/Rdy9TU1KF5\nd1nOaGXD470pM2cj4seBhyLiy/M3ZmZGRPffGD+4z25gN8DExEROTk72svv3zczM0O++o8ZaRk8t\ndYC1nM1NO+9v7Vi9umvrqoHPS6NbN5k5W76fAD4DXAY8871bMuX7idJ9Ftgwb/f1pU2SNARdgz4i\nVkXEy7+3DPwC8ASwF9heum0H7ivLe4Eby6tvtgCnMvN46yOXJDXS5NbNOPCZzm14VgJ/lZl/HxH/\nDNwTEbcA3wDeXvo/AFwJHAFeAG5ufdSSpMa6Bn1mPgW8dpH2Z4ErFmlP4NZWRidJWjLfGStJlTPo\nJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16S\nKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJaly\nBr0kVW5l044RsQI4CMxm5lsj4iJgGnglcAi4ITP/KyJeCtwNb
AaeBX4pM4+2PnKpcht33t/q8W7b\ndJqbGhzz6K6rWn1cDV8vV/TvBp6ct/4B4MOZ+dPASeCW0n4LcLK0f7j0kyQNSaOgj4j1wFXAx8p6\nAG8G7i1d9gDXlOVtZZ2y/YrSX5I0BJGZ3TtF3Av8EfBy4DeBm4BHylU7EbEBeDAzXxMRTwBbM/NY\n2fY14PLM/NaCY+4AdgCMj49vnp6e7quAubk5xsbG+tp31FjL6BlmHYdnT7V6vPHz4ZkXu/fbtG51\nq487CG3PS9vnuhcXrV7Rdy1TU1OHMnOiW7+u9+gj4q3Aicw8FBGTfY1mEZm5G9gNMDExkZOT/R16\nZmaGfvcdNdYyeoZZR5P76b24bdNpPni4+9NyR6+fbPVxB6HteWn7XPfirq2rBv4z1uTJ2DcCV0fE\nlcDLgB8DPgKsiYiVmXkaWA/Mlv6zwAbgWESsBFbTeVJWkjQEXe/RZ+Z7M3N9Zm4ErgMezszrgQPA\ntaXbduC+sry3rFO2P5xN7g9JkgZiKa+jvx14T0QcofMSyztL+53AK0v7e4CdSxuiJGkpGr+OHiAz\nZ4CZsvwUcNkiff4DeFsLY5MktcB3xkpS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIq\nZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqlxPf0pQ2rjz/kb9\nbtt0mpsa9m3i6K6rWjuW9P+NV/SSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFWua9BHxMsi\n4vMR8a8R8cWIeH9pvygiHo2IIxHxqYh4SWl/aVk/UrZvHGwJkqSzaXJF/5/AmzPztcDrgK0RsQX4\nAPDhzPxp4CRwS+l/C3CytH+49JMkDUnXoM+OubJ6XvlK4M3AvaV9D3BNWd5W1inbr4iIaG3EkqSe\nNLpHHxErIuIx4ATwEPA14PnMPF26HAPWleV1wNMAZfsp4JVtDlqS1FxkZvPOEWuAzwC/DdxVbs8Q\nERuABzPzNRHxBLA1M4+VbV8DLs/Mby041g5gB8D4+Pjm6enpvgqYm5tjbGysr31HzXKo5fDsqUb9\nxs+HZ15s73E3rVvd3sF6MMw5aXqum2o6J8M6171oe17aPte9uGj1ir5rmZqaOpSZE9369fShZpn5\nfEQcAN4ArImIleWqfT0wW7rNAhuAYxGxElgNPLvIsXYDuwEmJiZycnKyl6F838zMDP3uO2qWQy1N\nP6jstk2n+eDh9j4z7+j1k60dqxfDnJM2PxQOms/JsM51L9qel7bPdS/u2rpq4D9jTV5186pyJU9E\nnA/8PPAkcAC4tnTbDtxXlveWdcr2h7OX/zZIklrV5JJrLbAnIlbQ+cVwT2bui4gvAdMR8QfAvwB3\nlv53Ap+IiCPAc8B1Axi3JKmhrkGfmY8Dr1+k/SngskXa/wN4WyujkyQtme+MlaTKGfSSVDmDXpIq\nZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIG\nvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIq1zXo\nI2JDRByIiC9FxBcj4t2l/RUR8VBEfLV8v6C0R0TcERFHIuLxiLh00EVIks6syRX9aeC2zLwE2ALc\nGhGXADuB/Zl5MbC/rAO8Bbi4fO0APtr6qCVJjXUN+sw8nplfKMvfAZ4E1gHbgD2l2x7gmrK8Dbg7\nOx4B1kTE2tZHLklqpKd79BGxEXg98CgwnpnHy6ZvAuNleR3w9LzdjpU2SdIQRGY26xgxBvwj8IeZ\n+emIeD4z18zbfjIzL4iIfcCuzPxcad8P3J6ZBxccbwedWzuMj49vnp6e7quAubk5xsbG+tp31CyH\nWg7PnmrUb/x8eObF9h530
7rV7R2sB8Ock6bnuqmmczKsc92Ltuel7XPdi4tWr+i7lqmpqUOZOdGt\n38omB4uI84C/AT6ZmZ8uzc9ExNrMPF5uzZwo7bPAhnm7ry9tPyAzdwO7ASYmJnJycrLJUH7IzMwM\n/e47apZDLTftvL9Rv9s2neaDhxv9eDVy9PrJ1o7Vi2HOSdNz3VTTORnWue5F2/PS9rnuxV1bVw38\nZ6zJq24CuBN4MjM/NG/TXmB7Wd4O3Dev/cby6pstwKl5t3gkSedYk0uuNwI3AIcj4rHS9j5gF3BP\nRNwCfAN4e9n2AHAlcAR4Abi51RFLknrSNejLvfY4w+YrFumfwK1LHJckqSW+M1aSKmfQS1LlDHpJ\nqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlWvvA8OH5PDsqaF9lvTR\nXVcN5XElqRde0UtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEv\nSZUz6CWpcga9JFXOoJekynUN+oj4eESciIgn5rW9IiIeioivlu8XlPaIiDsi4khEPB4Rlw5y8JKk\n7ppc0d8FbF3QthPYn5kXA/vLOsBbgIvL1w7go+0MU5LUr65Bn5mfBZ5b0LwN2FOW9wDXzGu/Ozse\nAdZExNq2BitJ6l2/9+jHM/N4Wf4mMF6W1wFPz+t3rLRJkoYkMrN7p4iNwL7MfE1Zfz4z18zbfjIz\nL4iIfcCuzPxcad8P3J6ZBxc55g46t3cYHx/fPD093VcBJ547xTMv9rXrkm1at7rV483NzTE2Ntbq\nMdt2ePZUo37j59PqvLR9rpsa5pw0PddNNZ2TYZ3rXrQ9L22f615ctHpF37VMTU0dysyJbv36/Zux\nz0TE2sw8Xm7NnCjts8CGef3Wl7Yfkpm7gd0AExMTOTk52ddA/uST9/HBw8P507dHr59s9XgzMzP0\nex7OlaZ/n/e2TadbnZe2z3VTw5yTtv8WctM5Gda57kXb8zKsvzsNcNfWVQP/Gev31s1eYHtZ3g7c\nN6/9xvLqmy3AqXm3eCRJQ9D113tE/DUwCVwYEceA3wV2AfdExC3AN4C3l+4PAFcCR4AXgJsHMGZJ\nUg+6Bn1mvuMMm65YpG8Cty51UJKk9vjOWEmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQ\nS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0k\nVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcgMJ+ojYGhFfiYgjEbFzEI8hSWqm9aCP\niBXAnwFvAS4B3hERl7T9OJKkZgZxRX8ZcCQzn8rM/wKmgW0DeBxJUgODCPp1wNPz1o+VNknSEERm\ntnvAiGuBrZn5K2X9BuDyzHzngn47gB1l9dXAV/p8yAuBb/W576ixltFTSx1gLaNqKbX8VGa+qlun\nlX0e/GxmgQ3z1teXth+QmbuB3Ut9sIg4mJkTSz3OKLCW0VNLHWAto+pc1DKIWzf/DFwcERdFxEuA\n64C9A3gcSVIDrV/RZ+bpiHgn8A/ACuDjmfnFth9HktTMIG7dkJkPAA8M4tiLWPLtnxFiLaOnljrA\nWkbVwGtp/clYSdJo8SMQJKlyyyLoI+LjEXEiIp44w/aIiDvKRy48HhGXnusxNtWglsmIOBURj5Wv\n3znXY2wqIjZExIGI+FJEfDEi3r1In5Gfm4Z1LIt5iYiXRcTnI+JfSy3vX6TPSyPiU2VOHo2Ijed+\npN01rOWmiPj3efPyK8MYaxMRsSIi/iUi9i2ybbBzkpkj/wX8HHAp8MQZtl8JPAgEsAV4dNhjXkIt\nk8C+YY+zYS1rgUvL8suBfwMuWW5z07COZTEv5TyPleXzgEeBLQv6/Brw52X5OuBTwx73Emq
5CfjT\nYY+1YT3vAf5qsZ+jQc/Jsriiz8zPAs+dpcs24O7seARYExFrz83oetOglmUjM49n5hfK8neAJ/nh\nd0GP/Nw0rGNZKOd5rqyeV74WPhG3DdhTlu8FroiIOEdDbKxhLctCRKwHrgI+doYuA52TZRH0DdT2\nsQtvKP9dfTAifmbYg2mi/Ffz9XSuuuZbVnNzljpgmcxLuUXwGHACeCgzzzgnmXkaOAW88tyOspkG\ntQD8YrkteG9EbFhk+yj4Y+C3gP85w/aBzkktQV+TL9B5W/NrgT8B/nbI4+kqIsaAvwF+IzO/Pezx\n9KtLHctmXjLzvzPzdXTelX5ZRLxm2GPqV4Na/g7YmJk/CzzE/10Vj4yIeCtwIjMPDWsMtQR9o49d\nWA4y89vf++9qdt6PcF5EXDjkYZ1RRJxHJxw/mZmfXqTLspibbnUst3kByMzngQPA1gWbvj8nEbES\nWA08e25H15sz1ZKZz2bmf5bVjwGbz/XYGngjcHVEHKXzab5vjoi/XNBnoHNSS9DvBW4sr/DYApzK\nzOPDHlQ/IuInvndvLiIuozNHI/mPsIzzTuDJzPzQGbqN/Nw0qWO5zEtEvCoi1pTl84GfB768oNte\nYHtZvhZ4OMuzgKOkSS0Lnu+5ms7zKyMlM9+bmeszcyOdJ1ofzsxfXtBtoHMykHfGti0i/prOqx4u\njIhjwO/SeWKGzPxzOu/CvRI4ArwA3DyckXbXoJZrgV+NiNPAi8B1o/iPsHgjcANwuNxHBXgf8JOw\nrOamSR3LZV7WAnui8weAfgS4JzP3RcTvAQczcy+dX2qfiIgjdF4YcN3whntWTWr59Yi4GjhNp5ab\nhjbaHp3LOfGdsZJUuVpu3UiSzsCgl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcv8LAfci\ngdNBX0MAAAAASUVORK5CYII=\n", "text/plain": [ "<matplotlib.figure.Figure at 0x122831a90>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "df.Education.hist()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1473, 10)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.axes._subplots.AxesSubplot at 0x11a97b6a0>" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFJxJREFUeJzt3X+MndV95/H3tzYkXU/WAyGdtWxv7VWtXZHQJHhESRNV\nM0G7NU67ZqUUEaHGIEuWWrZK1f2BW6mt2u5K5A82TdBuulaJMJWbAdGwtgzJFjme7WYjnOKUYH4k\nmwlxikfEVrCZZgJt5fTbP+4hXIzH89w795eP3i/pap7nnHPv/T4Ph88899y515GZSJLq9WPDLkCS\n1F8GvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyq4ddAMBVV12VmzZt6uq+P/jB\nD1izZk1vC+oB6+qMdXVuVGuzrs6spK5jx459LzPfsezAzBz6bevWrdmtI0eOdH3ffrKuzlhX50a1\nNuvqzErqAp7IBhnr0o0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFVuJL4C\nYSWOzy9w255HhvLcJ+760FCeV5I64RW9JFXOoJekyhn0klQ5g16SKmfQS1LlGgV9RIxHxEMR8fWI\neC4i3hcRV0bEYxHxzfLzijI2IuJTETEXEU9FxLX9PQRJ0sU0vaL/JPCFzPxXwLuB54A9wOHM3AIc\nLvsANwJbym038OmeVixJ6siyQR8Ra4GfA+4FyMy/z8yXgR3AvjJsH3BT2d4B3F/+AZTHgfGIWNfz\nyiVJjUTrX6O6yICI9wB7gWdpXc0fAz4GzGfmeBkTwNnMHI+IQ8Bdmfml0ncYuDMznzjvcXfTuuJn\nYmJi68zMTFcHcPrMAqde7equK3bN+rVL9i0uLjI2NjbAapqxrs6Mal0wurVZV2dWUtf09PSxzJxc\nblyTT8auBq4Ffi0zj0bEJ3l9mQaAzMyIuPhvjPNk5l5av0CYnJzMqampTu7+I/fsP8Ddx4fzAd8T\nt04t2Tc7O0u3x9RP1tWZUa0LRrc26+rMIOpqskZ/EjiZmUfL/kO0gv/Ua0sy5efp0j8PbGy7/4bS\nJkkagmWDPjO/C7wQEf+yNN1AaxnnILCztO0EDpTtg8BHy1/fXA8sZOaLvS1bktRU0zWPXwP2R8Tl\nwPPA7bR+STwYEbuA7wA3l7GPAtuBOeCVMlaSNCSNgj4znwQutOB/wwXGJnDHCuuSJPWIn4yVpMoZ\n9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEv\nSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVbnWTQRFxAvg+\n8EPgXGZORsSVwAPAJuAEcHNmno2IAD4JbAdeAW7LzK/2vnRJ6o1Nex4Z2nPft21N35+jkyv66cx8\nT2ZOlv09wOHM3AIcLvsANwJbym038OleFStJ6txKlm52APvK9j7gprb2+7PlcWA8Itat4HkkSSsQ\nmbn8oIhvA2eBBP5nZu6NiJczc7z0B3A2M8cj4hBwV2Z+qfQdBu7MzCfOe8zdtK74mZiY2DozM9PV\nAZw+s8CpV7u664pds37tkn2Li4uMjY0NsJpmrKszo1oXjG5tl2Jdx+cXBlzN6zavXdX1+Zqenj7W\ntsqypEZr9MAHMnM+In4CeCwivt7emZkZEcv/xnjjffYCewEmJydzamqqk7v/yD37D3D38aaH0Vsn\nbp1asm92dpZuj6mfrKszo1oXjG5tl2Jdtw15jb7f56vR0k1mzpefp4GHgeuAU68tyZSfp8vweWBj\n2903lDZJ0hAsG/QRsSYi3vbaNvBvgKeBg8DOMmwncKBsHwQ+Gi3XAwuZ+WLPK5ckNdJkzWMCeLi1\nDM9q4E8z8wsR8ZfAgxGxC/gOcHMZ/yitP62co/Xnlbf3vGpJU
mPLBn1mPg+8+wLtLwE3XKA9gTt6\nUp0kacX8ZKwkVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5\ng16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlWsc\n9BGxKiL+KiIOlf3NEXE0IuYi4oGIuLy0v6Xsz5X+Tf0pXZLURCdX9B8Dnmvb/zjwicz8KeAssKu0\n7wLOlvZPlHGSpCFpFPQRsQH4EPDHZT+ADwIPlSH7gJvK9o6yT+m/oYyXJA1B0yv6PwT+M/APZf/t\nwMuZea7snwTWl+31wAsApX+hjJckDUFk5sUHRPwCsD0zfzUipoD/CNwGPF6WZ4iIjcDnM/NdEfE0\nsC0zT5a+bwE/k5nfO+9xdwO7ASYmJrbOzMx0dQCnzyxw6tWu7rpi16xfu2Tf4uIiY2NjA6ymGevq\nzKjWBaNb26VY1/H5hQFX87rNa1d1fb6mp6ePZebkcuNWN3is9wP/NiK2A28F/inwSWA8IlaXq/YN\nwHwZPw9sBE5GxGpgLfDS+Q+amXuBvQCTk5M5NTXVoJQ3u2f/Ae4+3uQweu/ErVNL9s3OztLtMfWT\ndXVmVOuC0a3tUqzrtj2PDLaYNvdtW9P387Xs0k1m/mZmbsjMTcAtwBcz81bgCPDhMmwncKBsHyz7\nlP4v5nIvGyRJfbOSv6O/E/iNiJijtQZ/b2m/F3h7af8NYM/KSpQkrURHax6ZOQvMlu3ngesuMOZv\ngV/qQW2SpB7wk7GSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQ\nS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUueH8q9rSJeL4/MLQ/uHoE3d9aCjPq/p4RS9J\nlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqt2zQR8RbI+IrEfG1iHgmIn6vtG+OiKMRMRcRD0TE5aX9\nLWV/rvRv6u8hSJIupskV/d8BH8zMdwPvAbZFxPXAx4FPZOZPAWeBXWX8LuBsaf9EGSdJGpJlgz5b\nFsvuZeWWwAeBh0r7PuCmsr2j7FP6b4iI6FnFkqSONFqjj4hVEfEkcBp4DPgW8HJmnitDTgLry/Z6\n4AWA0r8AvL2XRUuSmovMbD44Yhx4GPht4L6yPENEbAQ+n5nvioingW2ZebL0fQv4mcz83nmPtRvY\nDTAxMbF1ZmamqwM4fWaBU692ddcVu2b92iX7FhcXGRsbG2A1zVhXZ0Z1fsHonrNLsa7j8wsDruZ1\nm9eu6vp8TU9PH8vMyeXGdfRdN5n5ckQcAd4HjEfE6nLVvgGYL8PmgY3AyYhYDawFXrrAY+0F9gJM\nTk7m1NRUJ6X8yD37D3D38eF8Zc+JW6eW7JudnaXbY+on6+rMqM4vGN1zdinWNazvMwK4b9uavp+v\nJn91845yJU9E/Djwr4HngCPAh8uwncCBsn2w7FP6v5idvGyQJPVUk0uVdcC+iFhF6xfDg5l5KCKe\nBWYi4r8AfwXcW8bfC/xJRMwBZ4Bb+lC3JKmhZYM+M58C3nuB9ueB6y7Q/rfAL/WkOknSivnJWEmq\nnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ\n9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuWWDfqI2BgR\nRyLi2Yh4JiI+VtqvjIjHIuKb5ecVpT0i4lMRMRcRT0XEtf0+CEnS0ppc0Z8D/kNmXg1cD9wREVcD\ne4DDmbkFOFz2AW4EtpTbbuDTPa9aktTYskGfmS9m5lfL9veB54D1wA5gXxm2D7ipbO8A7s+Wx4Hx\niFjX88olSY10tEYfEZuA9
wJHgYnMfLF0fReYKNvrgRfa7naytEmShiAys9nAiDHg/wD/NTM/FxEv\nZ+Z4W//ZzLwiIg4Bd2Xml0r7YeDOzHzivMfbTWtph4mJia0zMzNdHcDpMwucerWru67YNevXLtm3\nuLjI2NjYAKtpxro6M6rzC0b3nF2KdR2fXxhwNa/bvHZV1+drenr6WGZOLjdudZMHi4jLgD8D9mfm\n50rzqYhYl5kvlqWZ06V9HtjYdvcNpe0NMnMvsBdgcnIyp6ammpTyJvfsP8DdxxsdRs+duHVqyb7Z\n2Vm6PaZ+sq7OjOr8gtE9Z5diXbfteWSwxbS5b9uavp+vJn91E8C9wHOZ+d/aug4CO8v2TuBAW/tH\ny1/fXA8stC3xSJIGrMmlyvuBXwaOR8STpe23gLuAByNiF/Ad4ObS9yiwHZgDXgFu72nFkqSOLBv0\nZa09lui+4QLjE7hjhXVJknrET8ZKUuUMekmqnEEvSZUz6CWpcga9JFXOoJekyhn0klQ5g16SKmfQ\nS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJalyBr0k\nVc6gl6TKGfSSVDmDXpIqt2zQR8RnIuJ0RDzd1nZlRDwWEd8sP68o7RERn4qIuYh4KiKu7WfxkqTl\nNbmivw/Ydl7bHuBwZm4BDpd9gBuBLeW2G/h0b8qUJHVr2aDPzL8AzpzXvAPYV7b3ATe1td+fLY8D\n4xGxrlfFSpI61+0a/URmvli2vwtMlO31wAtt406WNknSkERmLj8oYhNwKDPfVfZfzszxtv6zmXlF\nRBwC7srML5X2w8CdmfnEBR5zN63lHSYmJrbOzMx0dQCnzyxw6tWu7rpi16xfu2Tf4uIiY2NjA6ym\nGevqzKjOLxjdc3Yp1nV8fmHA1bxu89pVXZ+v6enpY5k5udy41V09OpyKiHWZ+WJZmjld2ueBjW3j\nNpS2N8nMvcBegMnJyZyamuqqkHv2H+Du490exsqcuHVqyb7Z2Vm6PaZ+sq7OjOr8gtE9Z5diXbft\neWSwxbS5b9uavp+vbpduDgI7y/ZO4EBb+0fLX99cDyy0LfFIkoZg2UuViPgsMAVcFREngd8F7gIe\njIhdwHeAm8vwR4HtwBzwCnB7H2qWJHVg2aDPzI8s0XXDBcYmcMdKi5Ik9Y6fjJWkyhn0klQ5g16S\nKmfQS1LlDHpJqpxBL0mVM+glqXIGvSRVzqCXpMoZ9JJUOYNekipn0EtS5Qx6SaqcQS9JlTPoJaly\nBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZUz6CWpcga9JFWuL0EfEdsi4hsRMRcRe/rx\nHJKkZnoe9BGxCvjvwI3A1cBHIuLqXj+PJKmZflzRXwfMZebzmfn3wAywow/PI0lqoB9Bvx54oW3/\nZGmTJA3B6mE9cUTsBnaX3cWI+EaXD3UV8L3eVNWZ+PhFu4dW1zKsqzOjOr/Ac9apkaxr+uMrqusn\nmwzqR9DPAxvb9jeUtjfIzL3A3pU+WUQ8kZmTK32cXrOuzlhX50a1NuvqzCDq6sfSzV8CWyJic0Rc\nDtwCHOzD80iSGuj5FX1mnouIfw/8b2AV8JnMfKbXzyNJaqYva/SZ+SjwaD8e+wJWvPzTJ9bVGevq\n3KjWZl2d6XtdkZn9fg5J0hD5FQiSVLmRDfqI+ExEnI6Ip5foj4j4VPmahaci4tq2vp0R8c1y2zng\num4t9RyPiC9HxLvb+k6U9icj4okB1zUVEQvluZ+MiN9p6+vbV1Y0qOs/tdX0dET8MCKuLH39PF8b\nI+JIRDwbEc9ExMcuMGbgc6xhXQOfYw3rGvgca1jXwOdYRLw1Ir4SEV8rdf3eBca8JSIeKOfkaERs\nauv7zdL+jYj4+RUXlJkjeQN+DrgWeHqJ/u3A54EArgeOlvYrgefLzyvK9hUDrOtnX3s+Wl8
DcbSt\n7wRw1ZDO1xRw6ALtq4BvAf8CuBz4GnD1oOo6b+wvAl8c0PlaB1xbtt8G/P/zj3sYc6xhXQOfYw3r\nGvgca1LXMOZYmTNjZfsy4Chw/XljfhX4o7J9C/BA2b66nKO3AJvLuVu1knpG9oo+M/8COHORITuA\n+7PlcWA8ItYBPw88lplnMvMs8BiwbVB1ZeaXy/MCPE7rcwR91+B8LaWvX1nRYV0fAT7bq+e+mMx8\nMTO/Wra/DzzHmz/BPfA51qSuYcyxhudrKX2bY13UNZA5VubMYtm9rNzOf0N0B7CvbD8E3BARUdpn\nMvPvMvPbwBytc9i1kQ36Bpb6qoVR+gqGXbSuCF+TwJ9HxLFofTJ40N5XXkp+PiLeWdpG4nxFxD+h\nFZZ/1tY8kPNVXjK/l9ZVV7uhzrGL1NVu4HNsmbqGNseWO1+DnmMRsSoingRO07owWHJ+ZeY5YAF4\nO304X0P7CoTaRcQ0rf8JP9DW/IHMnI+InwAei4ivlyveQfgq8JOZuRgR24H/BWwZ0HM38YvA/8vM\n9qv/vp+viBij9T/+r2fm3/TysVeiSV3DmGPL1DW0Odbwv+NA51hm/hB4T0SMAw9HxLsy84LvVfXb\npXxFv9RXLTT6CoZ+ioifBv4Y2JGZL73Wnpnz5edp4GFW+HKsE5n5N6+9lMzW5xwui4irGIHzVdzC\neS+p+32+IuIyWuGwPzM/d4EhQ5ljDeoayhxbrq5hzbEm56sY+Bwrj/0ycIQ3L+/96LxExGpgLfAS\n/ThfvXwDotc3YBNLv7n4Id74RtlXSvuVwLdpvUl2Rdm+coB1/XNaa2o/e177GuBtbdtfBrYNsK5/\nxuufm7gO+Oty7lbTejNxM6+/UfbOQdVV+tfSWsdfM6jzVY79fuAPLzJm4HOsYV0Dn2MN6xr4HGtS\n1zDmGPAOYLxs/zjwf4FfOG/MHbzxzdgHy/Y7eeObsc+zwjdjR3bpJiI+S+td/Ksi4iTwu7Te0CAz\n/4jWJ2+305rwrwC3l74zEfEHtL5zB+D3840v1fpd1+/QWmf7H633VTiXrS8smqD18g1aE/9PM/ML\nA6zrw8CvRMQ54FXglmzNqr5+ZUWDugD+HfDnmfmDtrv29XwB7wd+GThe1lEBfotWiA5zjjWpaxhz\nrEldw5hjTeqCwc+xdcC+aP1DTD9GK8QPRcTvA09k5kHgXuBPImKO1i+hW0rNz0TEg8CzwDngjmwt\nA3XNT8ZKUuUu5TV6SVIDBr0kVc6gl6TKGfSSVDmDXpIqZ9BLUuUMekmqnEEvSZX7RwDj7+S1+Ez4\nAAAAAElFTkSuQmCC\n", "text/plain": [ "<matplotlib.figure.Figure at 0x106e9ef60>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "df.contraceptive.hist()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Age int64\n", "Education int64\n", "H_education int64\n", "num_child int64\n", "Religion int64\n", "Employ int64\n", "H_occupation int64\n", "living_standard int64\n", "Media_exposure int64\n", "contraceptive int64\n", "dtype: object" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.dtypes" ] }, { "cell_type": "code", 
"execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def one_hot_encoding(idx, num_classes=None):\n", "    \"\"\"One-hot encode a 1-D array of zero-based integer class indices.\n", "\n", "    idx: integer class indices, each in [0, num_classes).\n", "    num_classes: width of the encoding; defaults to max(idx) + 1. Pass it\n", "        explicitly when some classes may be absent from idx.\n", "    Returns a float array of shape (len(idx), num_classes).\n", "    \"\"\"\n", "    idx = np.asarray(idx)\n", "    if num_classes is None:\n", "        # An empty input has no maximum; fall back to zero columns.\n", "        num_classes = int(idx.max()) + 1 if idx.size else 0\n", "    y = np.zeros((len(idx), num_classes))\n", "    y[np.arange(len(idx)), idx] = 1\n", "    return y" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# NOTE(review): the scaler is fit on every row before the train/test split below,\n", "# so test-set statistics leak into the scaling. Fit on the training rows only\n", "# for a leakage-free evaluation.\n", "scaler = StandardScaler()\n", "df[['Age','num_child']] = scaler.fit_transform(df[['Age','num_child']]) " ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "collapsed": true }, "outputs": [], "source": [ "x = df[['Age','num_child','Employ','Media_exposure']].values\n", "y = one_hot_encoding(df.contraceptive.values-1)\n", "\n", "liv_cats = df.living_standard.max()\n", "edu_cats = df.Education.max()\n", "\n", "liv = df.living_standard.values - 1\n", "liv_one_hot = one_hot_encoding(liv)\n", "edu = df.Education.values - 1\n", "edu_one_hot = one_hot_encoding(edu)\n", "\n", "train_x, test_x, train_liv, \\\n", "test_liv, train_edu, test_edu, train_y, test_y = train_test_split(x,liv_one_hot,edu_one_hot,y,test_size=0.1, random_state=1)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "collapsed": true }, "outputs": [], "source": [ "train_x = np.hstack([train_x, train_edu, train_liv])\n", "test_x = np.hstack([test_x, test_edu, test_liv])" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1325, 12)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_x.shape" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1325, 4)" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_edu.shape" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1325, 4)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_liv.shape" ] }, { 
"cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1325, 12)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_x.shape" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:2: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=12, input_dim=12)`\n", " from ipykernel import kernelapp as app\n", "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:4: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n", "/Users/sachin/anaconda/lib/python3.5/site-packages/keras/models.py:826: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n", " warnings.warn('The `nb_epoch` argument in `fit` '\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "0s - loss: 1.1002 - acc: 0.3894\n", "Epoch 2/100\n", "0s - loss: 1.0557 - acc: 0.4158\n", "Epoch 3/100\n", "0s - loss: 1.0371 - acc: 0.4370\n", "Epoch 4/100\n", "0s - loss: 1.0242 - acc: 0.4649\n", "Epoch 5/100\n", "0s - loss: 1.0140 - acc: 0.4762\n", "Epoch 6/100\n", "0s - loss: 1.0059 - acc: 0.4891\n", "Epoch 7/100\n", "0s - loss: 0.9989 - acc: 0.4913\n", "Epoch 8/100\n", "0s - loss: 0.9930 - acc: 0.4936\n", "Epoch 9/100\n", "0s - loss: 0.9877 - acc: 0.5034\n", "Epoch 10/100\n", "0s - loss: 0.9827 - acc: 0.5072\n", "Epoch 11/100\n", "0s - loss: 0.9781 - acc: 0.5147\n", "Epoch 12/100\n", "0s - loss: 0.9742 - acc: 0.5155\n", "Epoch 13/100\n", "0s - loss: 0.9706 - acc: 0.5185\n", "Epoch 14/100\n", "0s - loss: 0.9671 - acc: 0.5253\n", "Epoch 15/100\n", "0s - loss: 0.9642 - acc: 0.5230\n", "Epoch 16/100\n", "0s - loss: 0.9613 - acc: 0.5260\n", "Epoch 17/100\n", "0s - loss: 0.9587 - acc: 0.5283\n", "Epoch 18/100\n", "0s - loss: 0.9563 - acc: 0.5260\n", "Epoch 19/100\n", "0s - 
loss: 0.9539 - acc: 0.5313\n", "Epoch 20/100\n", "0s - loss: 0.9518 - acc: 0.5313\n", "Epoch 21/100\n", "0s - loss: 0.9497 - acc: 0.5343\n", "Epoch 22/100\n", "0s - loss: 0.9479 - acc: 0.5381\n", "Epoch 23/100\n", "0s - loss: 0.9461 - acc: 0.5389\n", "Epoch 24/100\n", "0s - loss: 0.9443 - acc: 0.5449\n", "Epoch 25/100\n", "0s - loss: 0.9427 - acc: 0.5419\n", "Epoch 26/100\n", "0s - loss: 0.9413 - acc: 0.5426\n", "Epoch 27/100\n", "0s - loss: 0.9397 - acc: 0.5457\n", "Epoch 28/100\n", "0s - loss: 0.9383 - acc: 0.5449\n", "Epoch 29/100\n", "0s - loss: 0.9370 - acc: 0.5442\n", "Epoch 30/100\n", "0s - loss: 0.9357 - acc: 0.5442\n", "Epoch 31/100\n", "0s - loss: 0.9344 - acc: 0.5449\n", "Epoch 32/100\n", "0s - loss: 0.9332 - acc: 0.5487\n", "Epoch 33/100\n", "0s - loss: 0.9321 - acc: 0.5442\n", "Epoch 34/100\n", "0s - loss: 0.9309 - acc: 0.5472\n", "Epoch 35/100\n", "0s - loss: 0.9299 - acc: 0.5479\n", "Epoch 36/100\n", "0s - loss: 0.9289 - acc: 0.5532\n", "Epoch 37/100\n", "0s - loss: 0.9279 - acc: 0.5540\n", "Epoch 38/100\n", "0s - loss: 0.9269 - acc: 0.5517\n", "Epoch 39/100\n", "0s - loss: 0.9260 - acc: 0.5502\n", "Epoch 40/100\n", "0s - loss: 0.9251 - acc: 0.5525\n", "Epoch 41/100\n", "0s - loss: 0.9242 - acc: 0.5540\n", "Epoch 42/100\n", "0s - loss: 0.9234 - acc: 0.5540\n", "Epoch 43/100\n", "0s - loss: 0.9225 - acc: 0.5517\n", "Epoch 44/100\n", "0s - loss: 0.9217 - acc: 0.5540\n", "Epoch 45/100\n", "0s - loss: 0.9209 - acc: 0.5555\n", "Epoch 46/100\n", "0s - loss: 0.9202 - acc: 0.5570\n", "Epoch 47/100\n", "0s - loss: 0.9194 - acc: 0.5570\n", "Epoch 48/100\n", "0s - loss: 0.9186 - acc: 0.5585\n", "Epoch 49/100\n", "0s - loss: 0.9179 - acc: 0.5592\n", "Epoch 50/100\n", "0s - loss: 0.9171 - acc: 0.5600\n", "Epoch 51/100\n", "0s - loss: 0.9164 - acc: 0.5608\n", "Epoch 52/100\n", "0s - loss: 0.9158 - acc: 0.5608\n", "Epoch 53/100\n", "0s - loss: 0.9150 - acc: 0.5600\n", "Epoch 54/100\n", "0s - loss: 0.9144 - acc: 0.5615\n", "Epoch 55/100\n", "0s - loss: 0.9137 - acc: 
0.5585\n", "Epoch 56/100\n", "0s - loss: 0.9131 - acc: 0.5592\n", "Epoch 57/100\n", "0s - loss: 0.9126 - acc: 0.5570\n", "Epoch 58/100\n", "0s - loss: 0.9119 - acc: 0.5562\n", "Epoch 59/100\n", "0s - loss: 0.9114 - acc: 0.5600\n", "Epoch 60/100\n", "0s - loss: 0.9108 - acc: 0.5600\n", "Epoch 61/100\n", "0s - loss: 0.9102 - acc: 0.5608\n", "Epoch 62/100\n", "0s - loss: 0.9097 - acc: 0.5600\n", "Epoch 63/100\n", "0s - loss: 0.9092 - acc: 0.5608\n", "Epoch 64/100\n", "0s - loss: 0.9086 - acc: 0.5608\n", "Epoch 65/100\n", "0s - loss: 0.9081 - acc: 0.5630\n", "Epoch 66/100\n", "0s - loss: 0.9077 - acc: 0.5623\n", "Epoch 67/100\n", "0s - loss: 0.9072 - acc: 0.5608\n", "Epoch 68/100\n", "0s - loss: 0.9068 - acc: 0.5660\n", "Epoch 69/100\n", "0s - loss: 0.9064 - acc: 0.5645\n", "Epoch 70/100\n", "0s - loss: 0.9059 - acc: 0.5660\n", "Epoch 71/100\n", "0s - loss: 0.9055 - acc: 0.5645\n", "Epoch 72/100\n", "0s - loss: 0.9051 - acc: 0.5645\n", "Epoch 73/100\n", "0s - loss: 0.9047 - acc: 0.5660\n", "Epoch 74/100\n", "0s - loss: 0.9044 - acc: 0.5660\n", "Epoch 75/100\n", "0s - loss: 0.9040 - acc: 0.5645\n", "Epoch 76/100\n", "0s - loss: 0.9036 - acc: 0.5691\n", "Epoch 77/100\n", "0s - loss: 0.9032 - acc: 0.5706\n", "Epoch 78/100\n", "0s - loss: 0.9029 - acc: 0.5683\n", "Epoch 79/100\n", "0s - loss: 0.9025 - acc: 0.5698\n", "Epoch 80/100\n", "0s - loss: 0.9022 - acc: 0.5698\n", "Epoch 81/100\n", "0s - loss: 0.9018 - acc: 0.5668\n", "Epoch 82/100\n", "0s - loss: 0.9015 - acc: 0.5691\n", "Epoch 83/100\n", "0s - loss: 0.9012 - acc: 0.5713\n", "Epoch 84/100\n", "0s - loss: 0.9008 - acc: 0.5728\n", "Epoch 85/100\n", "0s - loss: 0.9005 - acc: 0.5691\n", "Epoch 86/100\n", "0s - loss: 0.9001 - acc: 0.5728\n", "Epoch 87/100\n", "0s - loss: 0.8999 - acc: 0.5683\n", "Epoch 88/100\n", "0s - loss: 0.8996 - acc: 0.5691\n", "Epoch 89/100\n", "0s - loss: 0.8993 - acc: 0.5713\n", "Epoch 90/100\n", "0s - loss: 0.8989 - acc: 0.5721\n", "Epoch 91/100\n", "0s - loss: 0.8987 - acc: 0.5691\n", "Epoch 
92/100\n", "0s - loss: 0.8984 - acc: 0.5675\n", "Epoch 93/100\n", "0s - loss: 0.8981 - acc: 0.5691\n", "Epoch 94/100\n", "0s - loss: 0.8978 - acc: 0.5691\n", "Epoch 95/100\n", "0s - loss: 0.8975 - acc: 0.5698\n", "Epoch 96/100\n", "0s - loss: 0.8972 - acc: 0.5698\n", "Epoch 97/100\n", "0s - loss: 0.8970 - acc: 0.5713\n", "Epoch 98/100\n", "0s - loss: 0.8967 - acc: 0.5706\n", "Epoch 99/100\n", "0s - loss: 0.8964 - acc: 0.5691\n", "Epoch 100/100\n", "0s - loss: 0.8962 - acc: 0.5721\n" ] }, { "data": { "text/plain": [ "<keras.callbacks.History at 0x1212fee10>" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Baseline: a dense network on the numeric features plus one-hot Education and living_standard.\n", "model = Sequential()\n", "model.add(Dense(units=12, input_dim=train_x.shape[1]))\n", "model.add(Activation('relu'))\n", "model.add(Dense(units=3))\n", "model.add(Activation('softmax'))\n", "\n", "model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])\n", "model.fit(train_x, train_y, epochs=100, verbose=2)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "dense_1 (Dense) (None, 12) 156 \n", "_________________________________________________________________\n", "activation_1 (Activation) (None, 12) 0 \n", "_________________________________________________________________\n", "dense_2 (Dense) (None, 3) 39 \n", "_________________________________________________________________\n", "activation_2 (Activation) (None, 3) 0 \n", "=================================================================\n", "Total params: 195.0\n", "Trainable params: 195.0\n", "Non-trainable params: 0.0\n", "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "code", 
"execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(12, 12)\n", "(12,)\n", "(12, 3)\n", "(3,)\n" ] } ], "source": [ "for w in model.get_weights():\n", " print(w.shape)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "148/148 [==============================] - 0s\n" ] }, { "data": { "text/plain": [ "[0.85758495330810547, 0.587837815284729]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(test_x, test_y, batch_size=256)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.39239073, 0.2203065 , 0.38730279],\n", " [ 0.82135081, 0.10566427, 0.07298491],\n", " [ 0.25139767, 0.17943899, 0.56916332],\n", " [ 0.3676849 , 0.33580911, 0.29650599],\n", " [ 0.75309271, 0.13132168, 0.11558564],\n", " [ 0.16729502, 0.54943871, 0.28326628],\n", " [ 0.18573713, 0.45595431, 0.35830855],\n", " [ 0.8188768 , 0.10733887, 0.0737843 ],\n", " [ 0.73907691, 0.04200678, 0.21891631],\n", " [ 0.64818466, 0.11329354, 0.2385218 ]], dtype=float32)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict(test_x[:10])" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([2, 3, 3, ..., 3, 1, 3])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "liv" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "train_x, test_x, train_liv, \\\n", "test_liv, train_edu, test_edu, train_y, test_y = train_test_split(x,liv,edu,y,test_size=0.1, random_state=1)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:16: UserWarning: The `Merge` layer is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. `add`, `concatenate`, etc.\n", "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:18: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=12)`\n", "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:20: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n" ] } ], "source": [ "# Embedding input for living_standard (zero-based category index)\n", "encoder_liv = Sequential()\n", "encoder_liv.add(Embedding(liv_cats,4,input_length=1))\n", "encoder_liv.add(Flatten())\n", "\n", "# Embedding input for Education (zero-based category index)\n", "encoder_edu = Sequential()\n", "encoder_edu.add(Embedding(edu_cats,4,input_length=1))\n", "encoder_edu.add(Flatten())\n", "\n", "# Dense projection of the numeric features in x (Age, num_child, Employ, Media_exposure)\n", "dense_x = Sequential()\n", "dense_x.add(Dense(4, input_dim=x.shape[1]))\n", "\n", "# Concatenate both embeddings with the dense projection, then classify into 3 classes.\n", "# NOTE: `Merge` is deprecated in Keras 2; `keras.layers.concatenate` with the functional API replaces it.\n", "model = Sequential()\n", "model.add(Merge([encoder_liv, encoder_edu, dense_x], mode='concat'))\n", "model.add(Dense(units=12))\n", "model.add(Activation('relu'))\n", "model.add(Dense(units=3))\n", "model.add(Activation('softmax'))\n", "\n", "model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/sachin/anaconda/lib/python3.5/site-packages/keras/models.py:826: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n", " warnings.warn('The `nb_epoch` argument in `fit` '\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "0s - loss: 1.1477 - acc: 0.3442\n", "Epoch 2/100\n", "0s - loss: 1.0408 - acc: 0.4242\n", "Epoch 3/100\n", "0s - loss: 1.0129 - acc: 0.4679\n", "Epoch 4/100\n", "0s - loss: 0.9960 - 
acc: 0.4755\n", "Epoch 5/100\n", "0s - loss: 0.9841 - acc: 0.4906\n", "Epoch 6/100\n", "0s - loss: 0.9750 - acc: 0.4981\n", "Epoch 7/100\n", "0s - loss: 0.9672 - acc: 0.5049\n", "Epoch 8/100\n", "0s - loss: 0.9604 - acc: 0.5034\n", "Epoch 9/100\n", "0s - loss: 0.9540 - acc: 0.5140\n", "Epoch 10/100\n", "0s - loss: 0.9490 - acc: 0.5223\n", "Epoch 11/100\n", "0s - loss: 0.9446 - acc: 0.5260\n", "Epoch 12/100\n", "0s - loss: 0.9408 - acc: 0.5260\n", "Epoch 13/100\n", "0s - loss: 0.9377 - acc: 0.5253\n", "Epoch 14/100\n", "0s - loss: 0.9350 - acc: 0.5374\n", "Epoch 15/100\n", "0s - loss: 0.9326 - acc: 0.5358\n", "Epoch 16/100\n", "0s - loss: 0.9301 - acc: 0.5374\n", "Epoch 17/100\n", "0s - loss: 0.9283 - acc: 0.5336\n", "Epoch 18/100\n", "0s - loss: 0.9266 - acc: 0.5404\n", "Epoch 19/100\n", "0s - loss: 0.9249 - acc: 0.5404\n", "Epoch 20/100\n", "0s - loss: 0.9236 - acc: 0.5434\n", "Epoch 21/100\n", "0s - loss: 0.9223 - acc: 0.5419\n", "Epoch 22/100\n", "0s - loss: 0.9213 - acc: 0.5457\n", "Epoch 23/100\n", "0s - loss: 0.9201 - acc: 0.5457\n", "Epoch 24/100\n", "0s - loss: 0.9190 - acc: 0.5464\n", "Epoch 25/100\n", "0s - loss: 0.9181 - acc: 0.5479\n", "Epoch 26/100\n", "0s - loss: 0.9172 - acc: 0.5464\n", "Epoch 27/100\n", "0s - loss: 0.9165 - acc: 0.5472\n", "Epoch 28/100\n", "0s - loss: 0.9157 - acc: 0.5547\n", "Epoch 29/100\n", "0s - loss: 0.9151 - acc: 0.5509\n", "Epoch 30/100\n", "0s - loss: 0.9143 - acc: 0.5562\n", "Epoch 31/100\n", "0s - loss: 0.9137 - acc: 0.5577\n", "Epoch 32/100\n", "0s - loss: 0.9132 - acc: 0.5592\n", "Epoch 33/100\n", "0s - loss: 0.9125 - acc: 0.5540\n", "Epoch 34/100\n", "0s - loss: 0.9121 - acc: 0.5555\n", "Epoch 35/100\n", "0s - loss: 0.9115 - acc: 0.5555\n", "Epoch 36/100\n", "0s - loss: 0.9111 - acc: 0.5592\n", "Epoch 37/100\n", "0s - loss: 0.9105 - acc: 0.5600\n", "Epoch 38/100\n", "0s - loss: 0.9101 - acc: 0.5592\n", "Epoch 39/100\n", "0s - loss: 0.9096 - acc: 0.5630\n", "Epoch 40/100\n", "0s - loss: 0.9093 - acc: 0.5592\n", "Epoch 
41/100\n", "0s - loss: 0.9088 - acc: 0.5600\n", "Epoch 42/100\n", "0s - loss: 0.9084 - acc: 0.5623\n", "Epoch 43/100\n", "0s - loss: 0.9081 - acc: 0.5600\n", "Epoch 44/100\n", "0s - loss: 0.9077 - acc: 0.5562\n", "Epoch 45/100\n", "0s - loss: 0.9075 - acc: 0.5585\n", "Epoch 46/100\n", "0s - loss: 0.9071 - acc: 0.5570\n", "Epoch 47/100\n", "0s - loss: 0.9068 - acc: 0.5585\n", "Epoch 48/100\n", "0s - loss: 0.9064 - acc: 0.5600\n", "Epoch 49/100\n", "0s - loss: 0.9061 - acc: 0.5623\n", "Epoch 50/100\n", "0s - loss: 0.9060 - acc: 0.5600\n", "Epoch 51/100\n", "0s - loss: 0.9056 - acc: 0.5592\n", "Epoch 52/100\n", "0s - loss: 0.9053 - acc: 0.5577\n", "Epoch 53/100\n", "0s - loss: 0.9050 - acc: 0.5577\n", "Epoch 54/100\n", "0s - loss: 0.9048 - acc: 0.5577\n", "Epoch 55/100\n", "0s - loss: 0.9045 - acc: 0.5585\n", "Epoch 56/100\n", "0s - loss: 0.9043 - acc: 0.5608\n", "Epoch 57/100\n", "0s - loss: 0.9040 - acc: 0.5547\n", "Epoch 58/100\n", "0s - loss: 0.9037 - acc: 0.5600\n", "Epoch 59/100\n", "0s - loss: 0.9035 - acc: 0.5608\n", "Epoch 60/100\n", "0s - loss: 0.9032 - acc: 0.5555\n", "Epoch 61/100\n", "0s - loss: 0.9030 - acc: 0.5600\n", "Epoch 62/100\n", "0s - loss: 0.9030 - acc: 0.5585\n", "Epoch 63/100\n", "0s - loss: 0.9026 - acc: 0.5608\n", "Epoch 64/100\n", "0s - loss: 0.9023 - acc: 0.5585\n", "Epoch 65/100\n", "0s - loss: 0.9021 - acc: 0.5592\n", "Epoch 66/100\n", "0s - loss: 0.9019 - acc: 0.5592\n", "Epoch 67/100\n", "0s - loss: 0.9017 - acc: 0.5555\n", "Epoch 68/100\n", "0s - loss: 0.9016 - acc: 0.5570\n", "Epoch 69/100\n", "0s - loss: 0.9012 - acc: 0.5577\n", "Epoch 70/100\n", "0s - loss: 0.9011 - acc: 0.5615\n", "Epoch 71/100\n", "0s - loss: 0.9009 - acc: 0.5585\n", "Epoch 72/100\n", "0s - loss: 0.9007 - acc: 0.5608\n", "Epoch 73/100\n", "0s - loss: 0.9005 - acc: 0.5585\n", "Epoch 74/100\n", "0s - loss: 0.9003 - acc: 0.5577\n", "Epoch 75/100\n", "0s - loss: 0.9001 - acc: 0.5562\n", "Epoch 76/100\n", "0s - loss: 0.8999 - acc: 0.5562\n", "Epoch 77/100\n", "0s - 
loss: 0.8998 - acc: 0.5562\n", "Epoch 78/100\n", "0s - loss: 0.8996 - acc: 0.5555\n", "Epoch 79/100\n", "0s - loss: 0.8993 - acc: 0.5592\n", "Epoch 80/100\n", "0s - loss: 0.8992 - acc: 0.5570\n", "Epoch 81/100\n", "0s - loss: 0.8990 - acc: 0.5577\n", "Epoch 82/100\n", "0s - loss: 0.8988 - acc: 0.5592\n", "Epoch 83/100\n", "0s - loss: 0.8987 - acc: 0.5585\n", "Epoch 84/100\n", "0s - loss: 0.8986 - acc: 0.5585\n", "Epoch 85/100\n", "0s - loss: 0.8984 - acc: 0.5623\n", "Epoch 86/100\n", "0s - loss: 0.8983 - acc: 0.5608\n", "Epoch 87/100\n", "0s - loss: 0.8982 - acc: 0.5608\n", "Epoch 88/100\n", "0s - loss: 0.8979 - acc: 0.5630\n", "Epoch 89/100\n", "0s - loss: 0.8979 - acc: 0.5623\n", "Epoch 90/100\n", "0s - loss: 0.8977 - acc: 0.5630\n", "Epoch 91/100\n", "0s - loss: 0.8974 - acc: 0.5630\n", "Epoch 92/100\n", "0s - loss: 0.8973 - acc: 0.5615\n", "Epoch 93/100\n", "0s - loss: 0.8972 - acc: 0.5608\n", "Epoch 94/100\n", "0s - loss: 0.8971 - acc: 0.5615\n", "Epoch 95/100\n", "0s - loss: 0.8970 - acc: 0.5623\n", "Epoch 96/100\n", "0s - loss: 0.8968 - acc: 0.5608\n", "Epoch 97/100\n", "0s - loss: 0.8967 - acc: 0.5615\n", "Epoch 98/100\n", "0s - loss: 0.8966 - acc: 0.5600\n", "Epoch 99/100\n", "0s - loss: 0.8964 - acc: 0.5608\n", "Epoch 100/100\n", "0s - loss: 0.8964 - acc: 0.5600\n" ] }, { "data": { "text/plain": [ "<keras.callbacks.History at 0x121d8ddd8>" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit([train_liv[:,None], train_edu[:,None], train_x], train_y, nb_epoch=100, verbose=2)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "dense_3 (Dense) (None, 4) 20 \n", "=================================================================\n", "Total params: 
20.0\n", "Trainable params: 20\n", "Non-trainable params: 0.0\n", "_________________________________________________________________\n" ] } ], "source": [ "dense_x.summary()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "embedding_1 (Embedding) (None, 1, 4) 16 \n", "_________________________________________________________________\n", "flatten_1 (Flatten) (None, 4) 0 \n", "=================================================================\n", "Total params: 16.0\n", "Trainable params: 16.0\n", "Non-trainable params: 0.0\n", "_________________________________________________________________\n" ] } ], "source": [ "encoder_liv.summary()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "merge_1 (Merge) (None, 12) 0 \n", "_________________________________________________________________\n", "dense_4 (Dense) (None, 12) 156 \n", "_________________________________________________________________\n", "activation_3 (Activation) (None, 12) 0 \n", "_________________________________________________________________\n", "dense_5 (Dense) (None, 3) 39 \n", "_________________________________________________________________\n", "activation_4 (Activation) (None, 3) 0 \n", "=================================================================\n", "Total params: 247.0\n", "Trainable params: 247.0\n", "Non-trainable params: 0.0\n", "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": 
"code", "execution_count": 23, "metadata": {}, "outputs": [ { "ename": "AttributeError", "evalue": "'list' object has no attribute 'shape'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m<ipython-input-23-8e2f9d3764a7>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mw\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'shape'" ] } ], "source": [ "for w in model.get_weights():\n", " print(w.shape)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[array([[-0.10935115, -0.23031998, 0.23564951, 0.3424432 ],\n", " [-0.06168354, -0.05825301, 0.111118 , 0.16586818],\n", " [-0.15200721, -0.0934339 , -0.10459773, 0.00161025],\n", " [-0.04321436, 0.05898349, -0.01321368, -0.08363315]], dtype=float32)]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[[array([[-0.10935115, -0.23031998, 0.23564951, 0.3424432 ],\n", " [-0.06168354, -0.05825301, 0.111118 , 0.16586818],\n", " [-0.15200721, -0.0934339 , -0.10459773, 0.00161025],\n", " [-0.04321436, 0.05898349, -0.01321368, -0.08363315]], dtype=float32)],\n", " [],\n", " [array([[ 0.3322866 , 0.19930699, -0.24251796, 0.24181931],\n", " [ 0.19793722, 
0.12741166, -0.18478565, 0.16072831],\n", " [ 0.00148587, -0.0822821 , -0.00326628, 0.04515119],\n", " [-0.09646862, -0.2633568 , 0.16113301, -0.19863592]], dtype=float32)],\n", " [],\n", " [array([[-0.59816033, 0.87062579, -0.2976734 , 0.08506605],\n", " [ 0.28670722, 0.01459732, -0.67885333, 0.91728342],\n", " [-0.72324103, -0.32768196, 0.8207261 , 0.69913387],\n", " [-0.20462862, 0.6349501 , -0.34206402, 0.06973144]], dtype=float32),\n", " array([ 0.08511487, -0.05479484, 0.09653779, 0.01787456], dtype=float32)],\n", " [array([[ 0.54720604, 0.20560427, 0.38046002, 0.06325047, -0.53751922,\n", " -0.06238003, -0.07022222, 0.33957773, 0.19087823, -0.3943564 ,\n", " 0.07500011, 0.28335726],\n", " [ 0.66748869, 0.43769655, 0.27608281, -0.03321984, -0.43980169,\n", " -0.39402112, 0.41173798, 0.4726226 , -0.10022564, -0.39347824,\n", " -0.65959972, -0.18946621],\n", " [-0.23607142, 0.28477719, -0.28498721, 0.41613233, 0.15728261,\n", " -0.17061651, -0.1701867 , -0.03094537, -0.45825756, 0.35427347,\n", " -0.20373616, 0.4309096 ],\n", " [-0.68063968, 0.42256722, 0.08898511, 0.44825551, 0.745767 ,\n", " -0.15242647, -0.27827993, -0.52321166, -0.14772928, 0.31725135,\n", " 0.41540977, 0.20260219],\n", " [ 0.11326507, 0.29595992, -0.66411144, 0.03979665, 0.78415918,\n", " -0.12158706, -0.40152088, -0.05005921, -0.10276611, 0.52777141,\n", " 0.22086678, 0.61187786],\n", " [-0.81348175, -0.54370284, -0.06375849, 0.55344445, -0.18702853,\n", " -0.17707211, -0.28065667, -0.51284409, -0.58049625, 0.61143988,\n", " -0.22997195, 0.21264738],\n", " [ 0.52765691, 0.48952475, 0.47147927, -0.46761283, -0.35750964,\n", " 0.13695255, 0.52752954, 0.71626884, 0.65563965, -0.7301867 ,\n", " -0.139073 , 0.04272588],\n", " [-0.51113981, -0.36798111, -0.55281842, -0.30524713, 0.39200947,\n", " 0.44922075, -0.4600637 , -0.09046047, -0.42538816, 0.43586993,\n", " -0.00330608, 0.40157044],\n", " [ 0.07566521, 0.77746761, -0.05694147, -0.00362857, -0.10556948,\n", " 0.01500559, -0.02078186, 
-0.28349164, -0.32414779, 0.51029384,\n", " -0.64134562, -0.24186224],\n", " [ 0.08324727, 0.33779496, 0.23709755, -0.3381888 , 0.16273189,\n", " 0.47784749, 0.04162871, 0.30976474, -0.31543013, -0.28823802,\n", " 0.62342113, 0.42441621],\n", " [-0.28952643, 0.35541049, -0.15565668, 0.2239645 , 0.46108943,\n", " 0.58152092, 0.39327046, 0.1154022 , 0.09159836, 0.26722667,\n", " -0.09292367, 0.15957126],\n", " [-0.01987584, -0.06251051, 0.0218032 , -0.24661671, -0.08240297,\n", " -0.84416002, 0.33008894, -0.28735343, 0.23902059, 0.50154585,\n", " 0.0804465 , 0.17921968]], dtype=float32),\n", " array([ -1.18326025e-04, 1.35194603e-02, -4.79373112e-02,\n", " -2.96827797e-02, 1.12170853e-01, 5.46757542e-02,\n", " 4.18406017e-02, -1.61119565e-01, 3.99559699e-02,\n", " 1.43326208e-01, 5.45555726e-03, -3.65096107e-02], dtype=float32)],\n", " [],\n", " [array([[-0.60570943, 0.19913501, 0.43104902],\n", " [-0.37961754, -0.03606199, 0.33234817],\n", " [-0.05652641, 0.38592601, 0.12417495],\n", " [ 0.49130049, -0.45519802, -0.04951563],\n", " [ 0.11304783, -0.82271665, 0.11088244],\n", " [ 0.85395575, -0.00952558, -0.67157847],\n", " [-0.13419652, 0.45367789, 0.18071729],\n", " [-0.14250514, -0.29016131, -0.04574362],\n", " [-0.43520772, 0.43990734, 0.22246923],\n", " [-0.57091314, -0.7859264 , -0.24046065],\n", " [ 0.02203327, -0.20768185, -0.85033274],\n", " [ 0.2845436 , -0.80705798, -0.51982474]], dtype=float32),\n", " array([-0.08315417, -0.05141847, 0.17288126], dtype=float32)],\n", " []]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = model.get_weights()\n", "a" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "148/148 [==============================] - 0s\n" ] }, { "data": { "text/plain": [ "[0.86288201808929443, 0.60810810327529907]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"model.evaluate([test_liv[:,None], test_edu[:,None], test_x],test_y, batch_size=256)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.29773653, 0.26741281, 0.43485063],\n", " [ 0.77942479, 0.11813645, 0.10243875],\n", " [ 0.18719749, 0.21930861, 0.59349388],\n", " [ 0.37331343, 0.36969444, 0.25699213],\n", " [ 0.69447452, 0.15733764, 0.14818783]], dtype=float32)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p = model.predict([test_liv[:,None], test_edu[:,None], test_x], batch_size=256)\n", "p[:5]" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "merge_1 (Merge) (None, 12) 0 \n", "_________________________________________________________________\n", "dense_4 (Dense) (None, 12) 156 \n", "_________________________________________________________________\n", "activation_3 (Activation) (None, 12) 0 \n", "_________________________________________________________________\n", "dense_5 (Dense) (None, 3) 39 \n", "_________________________________________________________________\n", "activation_4 (Activation) (None, 3) 0 \n", "=================================================================\n", "Total params: 247.0\n", "Trainable params: 247.0\n", "Non-trainable params: 0.0\n", "_________________________________________________________________\n" ] } ], "source": [ "model.summary()" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/sachin/anaconda/lib/python3.5/site-packages/ipykernel/__main__.py:4: UserWarning: Update your `Dense` call to the Keras 2 API: `Dense(units=3)`\n", 
"/Users/sachin/anaconda/lib/python3.5/site-packages/keras/models.py:826: UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.\n", " warnings.warn('The `nb_epoch` argument in `fit` '\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "1325/1325 [==============================] - 0s - loss: 0.6348 - acc: 0.6531 \n", "Epoch 2/100\n", "1325/1325 [==============================] - 0s - loss: 0.6224 - acc: 0.6644 \n", "Epoch 3/100\n", "1325/1325 [==============================] - 0s - loss: 0.6165 - acc: 0.6662 \n", "Epoch 4/100\n", "1325/1325 [==============================] - 0s - loss: 0.6129 - acc: 0.6694 \n", "Epoch 5/100\n", "1325/1325 [==============================] - 0s - loss: 0.6105 - acc: 0.6699 \n", "Epoch 6/100\n", "1325/1325 [==============================] - 0s - loss: 0.6086 - acc: 0.6699 \n", "Epoch 7/100\n", "1325/1325 [==============================] - 0s - loss: 0.6073 - acc: 0.6737 \n", "Epoch 8/100\n", "1325/1325 [==============================] - 0s - loss: 0.6062 - acc: 0.6704 \n", "Epoch 9/100\n", "1325/1325 [==============================] - 0s - loss: 0.6053 - acc: 0.6707 \n", "Epoch 10/100\n", "1325/1325 [==============================] - 0s - loss: 0.6045 - acc: 0.6707 \n", "Epoch 11/100\n", "1325/1325 [==============================] - 0s - loss: 0.6038 - acc: 0.6709 \n", "Epoch 12/100\n", "1325/1325 [==============================] - 0s - loss: 0.6032 - acc: 0.6722 \n", "Epoch 13/100\n", "1325/1325 [==============================] - 0s - loss: 0.6026 - acc: 0.6727 \n", "Epoch 14/100\n", "1325/1325 [==============================] - 0s - loss: 0.6020 - acc: 0.6719 \n", "Epoch 15/100\n", "1325/1325 [==============================] - 0s - loss: 0.6015 - acc: 0.6717 \n", "Epoch 16/100\n", "1325/1325 [==============================] - 0s - loss: 0.6010 - acc: 0.6735 \n", "Epoch 17/100\n", "1325/1325 [==============================] - 0s - loss: 0.6006 - acc: 0.6740 \n", "Epoch 18/100\n", 
"1325/1325 [==============================] - 0s - loss: 0.6002 - acc: 0.6730 \n", "Epoch 19/100\n", "1325/1325 [==============================] - 0s - loss: 0.5998 - acc: 0.6752 \n", "Epoch 20/100\n", "1325/1325 [==============================] - 0s - loss: 0.5994 - acc: 0.6762 \n", "Epoch 21/100\n", "1325/1325 [==============================] - 0s - loss: 0.5990 - acc: 0.6770 \n", "Epoch 22/100\n", "1325/1325 [==============================] - 0s - loss: 0.5987 - acc: 0.6770 \n", "Epoch 23/100\n", "1325/1325 [==============================] - 0s - loss: 0.5983 - acc: 0.6777 \n", "Epoch 24/100\n", "1325/1325 [==============================] - 0s - loss: 0.5979 - acc: 0.6770 \n", "Epoch 25/100\n", "1325/1325 [==============================] - 0s - loss: 0.5976 - acc: 0.6785 \n", "Epoch 26/100\n", "1325/1325 [==============================] - 0s - loss: 0.5973 - acc: 0.6790 \n", "Epoch 27/100\n", "1325/1325 [==============================] - 0s - loss: 0.5970 - acc: 0.6792 \n", "Epoch 28/100\n", "1325/1325 [==============================] - 0s - loss: 0.5966 - acc: 0.6787 \n", "Epoch 29/100\n", "1325/1325 [==============================] - 0s - loss: 0.5963 - acc: 0.6800 \n", "Epoch 30/100\n", "1325/1325 [==============================] - 0s - loss: 0.5960 - acc: 0.6795 \n", "Epoch 31/100\n", "1325/1325 [==============================] - 0s - loss: 0.5957 - acc: 0.6790 \n", "Epoch 32/100\n", "1325/1325 [==============================] - 0s - loss: 0.5954 - acc: 0.6790 \n", "Epoch 33/100\n", "1325/1325 [==============================] - 0s - loss: 0.5951 - acc: 0.6790 \n", "Epoch 34/100\n", "1325/1325 [==============================] - 0s - loss: 0.5949 - acc: 0.6810 \n", "Epoch 35/100\n", "1325/1325 [==============================] - 0s - loss: 0.5946 - acc: 0.6792 \n", "Epoch 36/100\n", "1325/1325 [==============================] - 0s - loss: 0.5943 - acc: 0.6800 \n", "Epoch 37/100\n", "1325/1325 [==============================] - 0s - loss: 0.5941 - acc: 0.6818 
\n", "Epoch 38/100\n", "1325/1325 [==============================] - 0s - loss: 0.5938 - acc: 0.6803 \n", "Epoch 39/100\n", "1325/1325 [==============================] - 0s - loss: 0.5935 - acc: 0.6813 \n", "Epoch 40/100\n", "1325/1325 [==============================] - 0s - loss: 0.5933 - acc: 0.6808 \n", "Epoch 41/100\n", "1325/1325 [==============================] - 0s - loss: 0.5930 - acc: 0.6813 \n", "Epoch 42/100\n", "1325/1325 [==============================] - 0s - loss: 0.5928 - acc: 0.6820 \n", "Epoch 43/100\n", "1325/1325 [==============================] - 0s - loss: 0.5926 - acc: 0.6835 \n", "Epoch 44/100\n", "1325/1325 [==============================] - 0s - loss: 0.5923 - acc: 0.6843 \n", "Epoch 45/100\n", "1325/1325 [==============================] - 0s - loss: 0.5921 - acc: 0.6840 \n", "Epoch 46/100\n", "1325/1325 [==============================] - 0s - loss: 0.5919 - acc: 0.6835 \n", "Epoch 47/100\n", "1325/1325 [==============================] - 0s - loss: 0.5917 - acc: 0.6848 \n", "Epoch 48/100\n", "1325/1325 [==============================] - 0s - loss: 0.5915 - acc: 0.6840 \n", "Epoch 49/100\n", "1325/1325 [==============================] - 0s - loss: 0.5913 - acc: 0.6848 \n", "Epoch 50/100\n", "1325/1325 [==============================] - 0s - loss: 0.5911 - acc: 0.6843 \n", "Epoch 51/100\n", "1325/1325 [==============================] - 0s - loss: 0.5909 - acc: 0.6853 \n", "Epoch 52/100\n", "1325/1325 [==============================] - 0s - loss: 0.5907 - acc: 0.6850 \n", "Epoch 53/100\n", "1325/1325 [==============================] - 0s - loss: 0.5905 - acc: 0.6850 \n", "Epoch 54/100\n", "1325/1325 [==============================] - 0s - loss: 0.5903 - acc: 0.6855 \n", "Epoch 55/100\n", "1325/1325 [==============================] - 0s - loss: 0.5901 - acc: 0.6853 \n", "Epoch 56/100\n", "1325/1325 [==============================] - 0s - loss: 0.5899 - acc: 0.6863 \n", "Epoch 57/100\n", "1325/1325 [==============================] - 0s - loss: 
0.5897 - acc: 0.6858 \n", "Epoch 58/100\n", "1325/1325 [==============================] - 0s - loss: 0.5895 - acc: 0.6858 \n", "Epoch 59/100\n", "1325/1325 [==============================] - 0s - loss: 0.5894 - acc: 0.6870 \n", "Epoch 60/100\n", "1325/1325 [==============================] - 0s - loss: 0.5891 - acc: 0.6868 \n", "Epoch 61/100\n", "1325/1325 [==============================] - 0s - loss: 0.5890 - acc: 0.6865 \n", "Epoch 62/100\n", "1325/1325 [==============================] - 0s - loss: 0.5888 - acc: 0.6875 \n", "Epoch 63/100\n", "1325/1325 [==============================] - 0s - loss: 0.5886 - acc: 0.6881 \n", "Epoch 64/100\n", "1325/1325 [==============================] - 0s - loss: 0.5884 - acc: 0.6881 \n", "Epoch 65/100\n", "1325/1325 [==============================] - 0s - loss: 0.5882 - acc: 0.6883 \n", "Epoch 66/100\n", "1325/1325 [==============================] - 0s - loss: 0.5881 - acc: 0.6891 \n", "Epoch 67/100\n", "1325/1325 [==============================] - 0s - loss: 0.5879 - acc: 0.6893 \n", "Epoch 68/100\n", "1325/1325 [==============================] - 0s - loss: 0.5877 - acc: 0.6896 \n", "Epoch 69/100\n", "1325/1325 [==============================] - 0s - loss: 0.5875 - acc: 0.6891 \n", "Epoch 70/100\n", "1325/1325 [==============================] - 0s - loss: 0.5874 - acc: 0.6891 \n", "Epoch 71/100\n", "1325/1325 [==============================] - 0s - loss: 0.5872 - acc: 0.6896 \n", "Epoch 72/100\n", "1325/1325 [==============================] - 0s - loss: 0.5871 - acc: 0.6893 \n", "Epoch 73/100\n", "1325/1325 [==============================] - 0s - loss: 0.5869 - acc: 0.6893 \n", "Epoch 74/100\n", "1325/1325 [==============================] - 0s - loss: 0.5867 - acc: 0.6901 \n", "Epoch 75/100\n", "1325/1325 [==============================] - 0s - loss: 0.5866 - acc: 0.6901 \n", "Epoch 76/100\n", "1325/1325 [==============================] - 0s - loss: 0.5864 - acc: 0.6901 \n", "Epoch 77/100\n", "1325/1325 
[==============================] - 0s - loss: 0.5863 - acc: 0.6901 \n", "Epoch 78/100\n", "1325/1325 [==============================] - 0s - loss: 0.5861 - acc: 0.6901 \n", "Epoch 79/100\n", "1325/1325 [==============================] - 0s - loss: 0.5859 - acc: 0.6903 \n", "Epoch 80/100\n", "1325/1325 [==============================] - 0s - loss: 0.5858 - acc: 0.6901 \n", "Epoch 81/100\n", "1325/1325 [==============================] - 0s - loss: 0.5857 - acc: 0.6901 \n", "Epoch 82/100\n", "1325/1325 [==============================] - 0s - loss: 0.5855 - acc: 0.6901 \n", "Epoch 83/100\n", "1325/1325 [==============================] - 0s - loss: 0.5854 - acc: 0.6896 \n", "Epoch 84/100\n", "1325/1325 [==============================] - 0s - loss: 0.5852 - acc: 0.6903 \n", "Epoch 85/100\n", "1325/1325 [==============================] - 0s - loss: 0.5851 - acc: 0.6901 \n", "Epoch 86/100\n", "1325/1325 [==============================] - 0s - loss: 0.5850 - acc: 0.6903 \n", "Epoch 87/100\n", "1325/1325 [==============================] - 0s - loss: 0.5849 - acc: 0.6903 \n", "Epoch 88/100\n", "1325/1325 [==============================] - 0s - loss: 0.5847 - acc: 0.6906 \n", "Epoch 89/100\n", "1325/1325 [==============================] - 0s - loss: 0.5846 - acc: 0.6906 \n", "Epoch 90/100\n", "1325/1325 [==============================] - 0s - loss: 0.5845 - acc: 0.6906 \n", "Epoch 91/100\n", "1325/1325 [==============================] - 0s - loss: 0.5844 - acc: 0.6911 \n", "Epoch 92/100\n", "1325/1325 [==============================] - 0s - loss: 0.5842 - acc: 0.6918 \n", "Epoch 93/100\n", "1325/1325 [==============================] - 0s - loss: 0.5841 - acc: 0.6911 \n", "Epoch 94/100\n", "1325/1325 [==============================] - 0s - loss: 0.5840 - acc: 0.6916 \n", "Epoch 95/100\n", "1325/1325 [==============================] - 0s - loss: 0.5838 - acc: 0.6906 \n", "Epoch 96/100\n", "1325/1325 [==============================] - 0s - loss: 0.5837 - acc: 0.6916 \n", "Epoch 
97/100\n", "1325/1325 [==============================] - 0s - loss: 0.5836 - acc: 0.6908 \n", "Epoch 98/100\n", "1325/1325 [==============================] - 0s - loss: 0.5835 - acc: 0.6913 \n", "Epoch 99/100\n", "1325/1325 [==============================] - 0s - loss: 0.5833 - acc: 0.6906 \n", "Epoch 100/100\n", "1325/1325 [==============================] - 0s - loss: 0.5832 - acc: 0.6911 \n" ] }, { "data": { "text/plain": [ "<keras.callbacks.History at 0x1227b34a8>" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Baseline for comparison: a dense-only network on train_x (no embedding inputs).\n", "# train_y is one-hot over 3 classes and the output is a 3-way softmax, so the loss\n", "# must be categorical_crossentropy. With binary_crossentropy Keras scores each output\n", "# unit as an independent binary problem and 'accuracy' becomes per-unit binary\n", "# accuracy, which is inflated and not comparable to the embedding model above.\n", "model = Sequential()\n", "model.add(Dense(4, input_dim=train_x.shape[1]))\n", "model.add(Activation('relu'))\n", "model.add(Dense(units=3))\n", "model.add(Activation('softmax'))\n", "\n", "model.compile(optimizer='adagrad', loss='categorical_crossentropy', metrics=['accuracy'])\n", "model.fit(train_x, train_y, epochs=100)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "148/148 [==============================] - 0s\n" ] }, { "data": { "text/plain": [ "[0.56608289480209351, 0.70945948362350464]" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(test_x,test_y,batch_size=256)" ] }, { "cell_type": "code", "execution_count": 95, "metadata": { "collapsed": true }, "outputs": [], "source": [ 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.1" }, "latex_envs": { "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 0 } }, "nbformat": 4, "nbformat_minor": 1 }