{"nbformat":4,"nbformat_minor":0,"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.6.1"},"colab":{"name":"2-Advanced-MLP.ipynb","provenance":[],"collapsed_sections":[]}},"cells":[{"cell_type":"markdown","metadata":{"id":"ZbIvihxRm80P"},"source":["# Advanced MLP\n","- Advanced techniques for training neural networks\n"," - Weight Initialization\n"," - Nonlinearity (Activation function)\n"," - Optimizers\n"," - Batch Normalization\n"," - Dropout (Regularization)\n"," - Model Ensemble"]},{"cell_type":"code","metadata":{"id":"W12bh526m80Q"},"source":["import matplotlib.pyplot as plt\n","\n","from sklearn.model_selection import train_test_split\n","from tensorflow.keras.datasets import mnist\n","from tensorflow.keras.models import Sequential\n","from tensorflow.keras.utils import to_categorical"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"1m-pC5sOm80U"},"source":["## Load Dataset\n","- MNIST dataset\n","- source: http://yann.lecun.com/exdb/mnist/"]},{"cell_type":"code","metadata":{"id":"fH78vu3rm80V"},"source":["(X_train, y_train), (X_test, y_test) = mnist.load_data()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"RroxZgNjm80Y","executionInfo":{"status":"ok","timestamp":1604684633389,"user_tz":420,"elapsed":1426,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"2f1f1c0b-1218-41c7-cfb3-69b15968d66b","colab":{"base_uri":"https://localhost:8080/","height":282}},"source":["plt.imshow(X_train[0]) # show first number in the dataset\n","plt.show()\n","print('Label: ', 
y_train[0])"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAOZ0lEQVR4nO3dbYxc5XnG8euKbezamMQbB9chLjjgFAg0Jl0ZEBZQobgOqgSoCsSKIkJpnSY4Ca0rQWlV3IpWbpUQUUqRTHExFS+BBIQ/0CTUQpCowWWhBgwEDMY0NmaNWYENIX5Z3/2w42iBnWeXmTMv3vv/k1Yzc+45c24NXD5nznNmHkeEAIx/H+p0AwDag7ADSRB2IAnCDiRB2IEkJrZzY4d5ckzRtHZuEkjlV3pbe2OPR6o1FXbbiyVdJ2mCpH+LiJWl50/RNJ3qc5rZJICC9bGubq3hw3jbEyTdIOnzkk6UtMT2iY2+HoDWauYz+wJJL0TE5ojYK+lOSedV0xaAqjUT9qMk/WLY4621Ze9ie6ntPtt9+7Snic0BaEbLz8ZHxKqI6I2I3kma3OrNAaijmbBvkzRn2ONP1JYB6ELNhP1RSfNsz7V9mKQvSlpbTVsAqtbw0FtE7Le9TNKPNDT0tjoinq6sMwCVamqcPSLul3R/Rb0AaCEulwWSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiCJpmZxRffzxPJ/4gkfm9nS7T/3F8fUrQ1OPVBc9+hjdxTrU7/uYv3Vaw+rW3u893vFdXcOvl2sn3r38mL9uD9/pFjvhKbCbnuLpN2SBiXtj4jeKpoCUL0q9uy/FxE7K3gdAC3EZ3YgiWbDHpJ+bPsx20tHeoLtpbb7bPft054mNwegUc0exi+MiG22j5T0gO2fR8TDw58QEaskrZKkI9wTTW4PQIOa2rNHxLba7Q5J90paUEVTAKrXcNhtT7M9/eB9SYskbayqMQDVauYwfpake20ffJ3bI+KHlXQ1zkw4YV6xHpMnFeuvnPWRYv2d0+qPCfd8uDxe/JPPlMebO+k/fzm9WP/Hf1lcrK8/+fa6tZf2vVNcd2X/54r1j//k0PtE2nDYI2KzpM9U2AuAFmLoDUiCsANJEHYgCcIOJEHYgST4imsFBs/+bLF+7S03FOufmlT/q5jj2b4YLNb/5vqvFOsT3y4Pf51+97K6tenb9hfXnbyzPDQ3tW99sd6N2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs1dg8nOvFOuP/WpOsf6pSf1VtlOp5dtPK9Y3v1X+Kepbjv1+3dqbB8rj5LP++b+L9VY69L7AOjr27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQhCPaN6J4hHviVJ/Ttu11i4FLTi/Wdy0u/9zzhCcPL9af+Pr1H7ing67Z+TvF+qNnlcfRB994s1iP0+v/APGWbxZX1dwlT5SfgPdZH+u0KwZGnMuaPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4exeYMPOjxfrg6wPF+ku31x8rf/rM1cV1F/zDN4r1I2/o3HfK8cE1Nc5ue7XtHbY3DlvWY/sB25tqtzOqbBhA9cZyGH+LpPfOen+lpHURMU/SutpjAF1s1LBHxMOS3nsceZ6kNbX7aySdX3FfACrW6G/QzYqI7bX7r0qaVe+JtpdKWipJUzS1wc0BaFbTZ+Nj6Axf3bN8EbEqInojoneSJje7OQANajTs/bZnS1Ltdkd1LQFohUbDvlbSxbX7F0u6r5p2ALTKqJ/Zb
d8h6WxJM21vlXS1pJWS7rJ9qaSXJV3YyibHu8Gdrze1/r5djc/v/ukvPVOsv3bjhPILHCjPsY7uMWrYI2JJnRJXxwCHEC6XBZIg7EAShB1IgrADSRB2IAmmbB4HTrji+bq1S04uD5r8+9HrivWzvnBZsT79e48U6+ge7NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2ceB0rTJr3/thOK6/7f2nWL9ymtuLdb/8sILivX43w/Xrc35+58V11Ubf+Y8A/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEUzYnN/BHpxfrt1397WJ97sQpDW/707cuK9bn3bS9WN+/eUvD2x6vmpqyGcD4QNiBJAg7kARhB5Ig7EAShB1IgrADSTDOjqI4Y36xfsTKrcX6HZ/8UcPbPv7BPy7Wf/tv63+PX5IGN21ueNuHqqbG2W2vtr3D9sZhy1bY3mZ7Q+3v3CobBlC9sRzG3yJp8QjLvxsR82t/91fbFoCqjRr2iHhY0kAbegHQQs2coFtm+8naYf6Mek+yvdR2n+2+fdrTxOYANKPRsN8o6VhJ8yVtl/Sdek+MiFUR0RsRvZM0ucHNAWhWQ2GPiP6IGIyIA5JukrSg2rYAVK2hsNuePezhBZI21nsugO4w6ji77TsknS1ppqR+SVfXHs+XFJK2SPpqRJS/fCzG2cejCbOOLNZfuei4urX1V1xXXPdDo+yLvvTSomL9zYWvF+vjUWmcfdRJIiJiyQiLb266KwBtxeWyQBKEHUiCsANJEHYgCcIOJMFXXNExd20tT9k81YcV67+MvcX6H3zj8vqvfe/64rqHKn5KGgBhB7Ig7EAShB1IgrADSRB2IAnCDiQx6rfekNuBheWfkn7xC+Upm0+av6VubbRx9NFcP3BKsT71vr6mXn+8Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwzj7OufekYv35b5bHum86Y02xfuaU8nfKm7En9hXrjwzMLb/AgVF/3TwV9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7IeAiXOPLtZfvOTjdWsrLrqzuO4fHr6zoZ6qcFV/b7H+0HWnFesz1pR/dx7vNuqe3fYc2w/afsb207a/VVveY/sB25tqtzNa3y6ARo3lMH6/pOURcaKk0yRdZvtESVdKWhcR8yStqz0G0KVGDXtEbI+Ix2v3d0t6VtJRks6TdPBayjWSzm9VkwCa94E+s9s+RtIpktZLmhURBy8+flXSrDrrLJW0VJKmaGqjfQJo0pjPxts+XNIPJF0eEbuG12JodsgRZ4iMiFUR0RsRvZM0ualmATRuTGG3PUlDQb8tIu6pLe63PbtWny1pR2taBFCFUQ/jbVvSzZKejYhrh5XWSrpY0sra7X0t6XAcmHjMbxXrb/7u7GL9or/7YbH+px+5p1hvpeXby8NjP/vX+sNrPbf8T3HdGQcYWqvSWD6znyHpy5Kesr2htuwqDYX8LtuXSnpZ0oWtaRFAFUYNe0T8VNKIk7tLOqfadgC0CpfLAkkQdiAJwg4kQdiBJAg7kARfcR2jibN/s25tYPW04rpfm/tQsb5ken9DPVVh2baFxfrjN5anbJ75/Y3Fes9uxsq7BXt2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUgizTj73t8v/2zx3j8bKNavOu7+urVFv/F2Qz1VpX/wnbq1M9cuL657/F//vFjveaM8Tn6gWEU3Yc8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkGWffcn7537XnT767Zdu+4Y1ji/XrHlpUrHuw3o/7Djn+mpfq1ub1ry+uO1isYjxhzw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTgiyk+w50i6VdIsSSFpVURcZ3uFpD+R9FrtqVdFRP0vfUs6wj1xqpn4FWiV9bFOu2JgxAszxnJRzX5JyyPicdvTJT1m+4Fa7bsR8e2qGgXQOmOZn327pO21+7ttPyvpqFY3BqBaH+gzu
+1jJJ0i6eA1mMtsP2l7te0ZddZZarvPdt8+7WmqWQCNG3PYbR8u6QeSLo+IXZJulHSspPka2vN/Z6T1ImJVRPRGRO8kTa6gZQCNGFPYbU/SUNBvi4h7JCki+iNiMCIOSLpJ0oLWtQmgWaOG3bYl3Szp2Yi4dtjy2cOedoGk8nSeADpqLGfjz5D0ZUlP2d5QW3aVpCW252toOG6LpK+2pEMAlRjL2fifShpp3K44pg6gu3AFHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IIlRf0q60o3Zr0l6ediimZJ2tq2BD6Zbe+vWviR6a1SVvR0dER8bqdDWsL9v43ZfRPR2rIGCbu2tW/uS6K1R7eqNw3ggCcIOJNHpsK/q8PZLurW3bu1LordGtaW3jn5mB9A+nd6zA2gTwg4k0ZGw215s+znbL9i+shM91GN7i+2nbG+w3dfhXlbb3mF747BlPbYfsL2pdjviHHsd6m2F7W21926D7XM71Nsc2w/afsb207a/VVve0feu0Fdb3re2f2a3PUHS85I+J2mrpEclLYmIZ9raSB22t0jqjYiOX4Bh+0xJb0m6NSJOqi37J0kDEbGy9g/ljIi4okt6WyHprU5P412brWj28GnGJZ0v6Svq4HtX6OtCteF968SefYGkFyJic0TslXSnpPM60EfXi4iHJQ28Z/F5ktbU7q/R0P8sbVent64QEdsj4vHa/d2SDk4z3tH3rtBXW3Qi7EdJ+sWwx1vVXfO9h6Qf237M9tJONzOCWRGxvXb/VUmzOtnMCEadxrud3jPNeNe8d41Mf94sTtC938KI+Kykz0u6rHa42pVi6DNYN42djmka73YZYZrxX+vke9fo9OfN6kTYt0maM+zxJ2rLukJEbKvd7pB0r7pvKur+gzPo1m53dLifX+umabxHmmZcXfDedXL6806E/VFJ82zPtX2YpC9KWtuBPt7H9rTaiRPZniZpkbpvKuq1ki6u3b9Y0n0d7OVdumUa73rTjKvD713Hpz+PiLb/STpXQ2fkX5T0V53ooU5fn5T0RO3v6U73JukODR3W7dPQuY1LJX1U0jpJmyT9l6SeLurtPyQ9JelJDQVrdod6W6ihQ/QnJW2o/Z3b6feu0Fdb3jculwWS4AQdkARhB5Ig7EAShB1IgrADSRB2IAnCDiTx/65XcTNOWsh5AAAAAElFTkSuQmCC\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}},{"output_type":"stream","text":["Label: 5\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"bGgToRjIm80c","executionInfo":{"status":"ok","timestamp":1604684633391,"user_tz":420,"elapsed":1420,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"c0e35f90-7ffb-4a1c-97a1-4a04ff47ab59","colab":{"base_uri":"https://localhost:8080/","height":282}},"source":["plt.imshow(X_test[0]) # show first number in the dataset\n","plt.show()\n","print('Label: ', y_test[0])"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANiklEQVR4nO3df4wc9XnH8c8n/kV8QGtDcF3j4ISQqE4aSHWBRNDKESUFImSiJBRLtVyJ5lALElRRW0QVBalVSlEIok0aySluHESgaQBhJTSNa6W1UKljg4yxgdaEmsau8QFOaxPAP/DTP24cHXD7vWNndmft5/2SVrs7z87Oo/F9PLMzO/t1RAjA8e9tbTcAoD8IO5AEYQeSIOxAEoQdSGJ6Pxc207PiBA31c5FAKq/qZzoYBzxRrVbYbV8s6XZJ0yT9bUTcXHr9CRrSeb6wziIBFGyIdR1rXe/G254m6auSLpG0WNIy24u7fT8AvVXnM/u5kp6OiGci4qCkeyQtbaYtAE2rE/YFkn4y7vnOatrr2B6xvcn2pkM6UGNxAOro+dH4iFgZEcMRMTxDs3q9OAAd1An7LkkLxz0/vZoGYADVCftGSWfZfpftmZKulLSmmbYANK3rU28Rcdj2tZL+SWOn3lZFxLbGOgPQqFrn2SPiQUkPNtQLgB7i67JAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJGoN2Wx7h6T9kl6TdDgihptoCkDzaoW98rGIeKGB9wHQQ+zGA0nUDXtI+oHtR2yPTPQC2yO2N9nedEgHai4OQLfq7sZfEBG7bJ8maa3tpyJi/fgXRMRKSSsl6WTPjZrLA9ClWlv2iNhV3Y9Kul/SuU00BaB5XYfd9pDtk44+lvRxSVubagxAs+rsxs+TdL/to+/zrYj4fiNdAWhc12GPiGcknd1gLwB6iFNvQBKEHUiCsANJEHYgCcIOJNHEhTApvPjZj3asvXP508V5nxqdV6wfPDCjWF9wd7k+e+dLHWtHNj9RnBd5sGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQ4zz5Ff/xH3+pY+9TQT8szn1lz4UvK5R2HX+5Yu/35j9Vc+LHrR6NndKwN3foLxXmnr3uk6XZax5YdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRP8GaTnZc+M8X9i35TXpZ58+r2PthQ+W/8+c82R5Hf/0V1ysz/zg/xbrt
3zgvo61i97+SnHe7718YrH+idmdr5Wv65U4WKxvODBUrC854VDXy37P964u1t87srHr927ThlinfbF3wj8otuxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATXs0/R0Hc2FGr13vvkerPrr39pScfan5+/qLzsfy3/5v0tS97TRUdTM/2VI8X60Jbdxfop6+8t1n91Zuff25+9o/xb/MejSbfstlfZHrW9ddy0ubbX2t5e3c/pbZsA6prKbvw3JF38hmk3SFoXEWdJWlc9BzDAJg17RKyXtPcNk5dKWl09Xi3p8ob7AtCwbj+zz4uIox+onpPUcTAz2yOSRiTpBM3ucnEA6qp9ND7GrqTpeKVHRKyMiOGIGJ6hWXUXB6BL3YZ9j+35klTdjzbXEoBe6DbsayStqB6vkPRAM+0A6JVJP7Pbvltjv1x+qu2dkr4g6WZJ37Z9laRnJV3RyyZRdvi5PR1rQ/d2rknSa5O899B3Xuyio2bs+b2PFuvvn1n+8/3S3vd1rC36u2eK8x4uVo9Nk4Y9IpZ1KB2bv0IBJMXXZYEkCDuQBGEHkiDsQBKEHUiCS1zRmulnLCzWv3LjV4r1GZ5WrP/D7b/ZsXbK7oeL8x6P2LIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKcZ0drnvrDBcX6h2eVh7LedrA8HPXcJ15+yz0dz9iyA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASnGdHTx34xIc71h799G2TzF0eQej3r7uuWH/7v/1okvfPhS07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBeXb01H9f0nl7cqLL59GX/ddFxfrs7z9WrEexms+kW3bbq2yP2t46btpNtnfZ3lzdLu1tmwDqmspu/DckXTzB9Nsi4pzq9mCzbQFo2qRhj4j1kvb2oRcAPVTnAN21trdUu/lzOr3I9ojtTbY3HdKBGosDUEe3Yf+apDMlnSNpt6RbO70wIlZGxHBEDM+Y5MIGAL3TVdgjYk9EvBYRRyR9XdK5zbYFoGldhd32/HFPPylpa6fXAhgMk55nt323pCWSTrW9U9IXJC2xfY7GTmXukHR1D3vEAHvbSScV68t//aGOtX1HXi3OO/rFdxfrsw5sLNbxepOGPSKWTTD5jh70AqCH+LoskARhB5Ig7EAShB1IgrADSXCJK2rZftP7i/Xvnvo3HWtLt3+qOO+sBzm11iS27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZUfR/v/ORYn3Lb/9Vsf7jw4c61l76y9OL887S7mIdbw1bdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgvPsyU1f8MvF+vWf//tifZbLf0JXPra8Y+0d/8j16v3Elh1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuA8+3HO08v/xGd/d2ex/pkTXyzW79p/WrE+7/OdtydHinOiaZNu2W0vtP1D20/Y3mb7umr6XNtrbW+v7uf0vl0A3ZrKbvxhSZ+LiMWSPiLpGtuLJd0gaV1EnCVpXfUcwICaNOwRsTsiHq0e75f0pKQFkpZKWl29bLWky3vVJID63tJndtuLJH1I0gZJ8yLi6I+EPSdpXod5RiSNSNIJmt1tnwBqmvLReNsnSrpX0vURsW98LSJCUkw0X0SsjIjhiBieoVm1mgXQvSmF3fYMjQX9roi4r5q8x/b8qj5f0mhvWgTQhEl3421b0h2SnoyIL48rrZG0QtLN1f0DPekQ9Zz9vmL5z067s9bbf/WLnynWf/Gxh2u9P5ozlc/s50taLulx25uraTdqLOTftn2VpGclXdGbFgE0YdKwR8RDktyhfGGz7QDoFb4uCyRB2IEkCDuQBGEHkiDsQBJc4nocmLb4vR1rI/fU+/rD4lXXFOuL7vz3Wu+P/mHLDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ79OPDUH3T+Yd/LZu/rWJuK0//lYPkFMeEPFGEAsWUHkiDsQBKEHUiCsANJE
HYgCcIOJEHYgSQ4z34MePWyc4v1dZfdWqgy5BbGsGUHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSSmMj77QknflDRPUkhaGRG3275J0mclPV+99MaIeLBXjWb2P+dPK9bfOb37c+l37T+tWJ+xr3w9O1ezHzum8qWaw5I+FxGP2j5J0iO211a12yLiS71rD0BTpjI++25Ju6vH+20/KWlBrxsD0Ky39Jnd9iJJH5K0oZp0re0ttlfZnvC3kWyP2N5ke9MhHajVLIDuTTnstk+UdK+k6yNin6SvSTpT0jka2/JP+AXtiFgZEcMRMTxDsxpoGUA3phR22zM0FvS7IuI+SYqIPRHxWkQckfR1SeWrNQC0atKw27akOyQ9GRFfHjd9/riXfVLS1ubbA9CUqRyNP1/SckmP295cTbtR0jLb52js7MsOSVf3pEPU8hcvLi7WH/6tRcV67H68wW7QpqkcjX9IkicocU4dOIbwDTogCcIOJEHYgSQIO5AEYQeSIOxAEo4+Drl7sufGeb6wb8sDstkQ67Qv9k50qpwtO5AFYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dfz7Lafl/TsuEmnSnqhbw28NYPa26D2JdFbt5rs7YyIeMdEhb6G/U0LtzdFxHBrDRQMam+D2pdEb93qV2/sxgNJEHYgibbDvrLl5ZcMam+D2pdEb93qS2+tfmYH0D9tb9kB9AlhB5JoJey2L7b9H7aftn1DGz10YnuH7cdtb7a9qeVeVtketb113LS5ttfa3l7dTzjGXku93WR7V7XuNtu+tKXeFtr+oe0nbG+zfV01vdV1V+irL+ut75/ZbU+T9J+SLpK0U9JGScsi4om+NtKB7R2ShiOi9S9g2P4NSS9J+mZEfKCadoukvRFxc/Uf5ZyI+JMB6e0mSS+1PYx3NVrR/PHDjEu6XNLvqsV1V+jrCvVhvbWxZT9X0tMR8UxEHJR0j6SlLfQx8CJivaS9b5i8VNLq6vFqjf2x9F2H3gZCROyOiEerx/slHR1mvNV1V+irL9oI+wJJPxn3fKcGa7z3kPQD24/YHmm7mQnMi4jd1ePnJM1rs5kJTDqMdz+9YZjxgVl33Qx/XhcH6N7sgoj4NUmXSLqm2l0dSDH2GWyQzp1OaRjvfplgmPGfa3PddTv8eV1thH2XpIXjnp9eTRsIEbGruh+VdL8GbyjqPUdH0K3uR1vu5+cGaRjviYYZ1wCsuzaHP28j7BslnWX7XbZnSrpS0poW+ngT20PVgRPZHpL0cQ3eUNRrJK2oHq+Q9ECLvbzOoAzj3WmYcbW87lof/jwi+n6TdKnGjsj/WNKfttFDh77eLemx6rat7d4k3a2x3bpDGju2cZWkUyStk7Rd0j9LmjtAvd0p6XFJWzQWrPkt9XaBxnbRt0jaXN0ubXvdFfrqy3rj67JAEhygA5Ig7EAShB1IgrADSRB2IAnCDiRB2IEk/h9BCfQTVPflJQAAAABJRU5ErkJggg==\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}},{"output_type":"stream","text":["Label: 7\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"cKDggu2Gm80f"},"source":["# reshaping X data: (n, 28, 28) => (n, 784)\n","X_train = X_train.reshape((X_train.shape[0], -1))\n","X_test = X_test.reshape((X_test.shape[0], -1))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MUFEC2K_m80k"},"source":["# converting y data into categorical (one-hot encoding)\n","y_train = to_categorical(y_train)\n","y_test = to_categorical(y_test)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Y1D9nxJ_m80m","executionInfo":{"status":"ok","timestamp":1604684633546,"user_tz":420,"elapsed":1561,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"cd5049e9-e452-4751-faa9-7d723eca60ee","colab":{"base_uri":"https://localhost:8080/"}},"source":["print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["(60000, 784) (10000, 784) (60000, 10) (10000, 10)\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"X7X7pN-Sm80p"},"source":["## Basic MLP model\n","- Naive MLP model without any alterations"]},{"cell_type":"code","metadata":{"id":"OPUXkSYmm80p"},"source":["from tensorflow.keras.models import Sequential\n","from tensorflow.keras.layers import Activation, Dense\n","from tensorflow.keras import optimizers"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"b32x9pvIm80r"},"source":["model = Sequential()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"XPnz8u9ym80u"},"source":["model.add(Dense(50, input_shape = (784, 
)))\n","model.add(Activation('sigmoid'))\n","model.add(Dense(50))\n","model.add(Activation('sigmoid'))\n","model.add(Dense(50))\n","model.add(Activation('sigmoid'))\n","model.add(Dense(50))\n","model.add(Activation('sigmoid'))\n","model.add(Dense(10))\n","model.add(Activation('softmax'))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ua5rjdGYm80y"},"source":["# plain SGD; `learning_rate` replaces the deprecated `lr` alias\n","sgd = optimizers.SGD(learning_rate = 0.001)\n","model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Lj23wJ2lm800","executionInfo":{"status":"ok","timestamp":1604684711031,"user_tz":420,"elapsed":79023,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"ac601aeb-351d-4689-ae21-fb2bfbe6a27b","colab":{"base_uri":"https://localhost:8080/"}},"source":["history = model.fit(X_train, y_train, batch_size = 256, validation_split = 0.3, epochs = 100, verbose = 1)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Epoch 1/100\n","165/165 [==============================] - 1s 7ms/step - loss: 2.4838 - accuracy: 0.0995 - val_loss: 2.4486 - val_accuracy: 0.0966\n","Epoch 2/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.4176 - accuracy: 0.0995 - val_loss: 2.3965 - val_accuracy: 0.0966\n","Epoch 3/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3761 - accuracy: 0.0995 - val_loss: 2.3634 - val_accuracy: 0.0966\n","Epoch 4/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3496 - accuracy: 0.0995 - val_loss: 2.3422 - val_accuracy: 0.0966\n","Epoch 5/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3327 - accuracy: 0.0995 - val_loss: 2.3283 - val_accuracy: 0.0966\n","Epoch 6/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3217 - accuracy: 0.0997 - val_loss: 2.3193 - val_accuracy: 0.0970\n","Epoch 7/100\n","165/165 
[==============================] - 1s 5ms/step - loss: 2.3145 - accuracy: 0.1019 - val_loss: 2.3135 - val_accuracy: 0.0989\n","Epoch 8/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3099 - accuracy: 0.0827 - val_loss: 2.3096 - val_accuracy: 0.1059\n","Epoch 9/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3070 - accuracy: 0.1140 - val_loss: 2.3072 - val_accuracy: 0.1079\n","Epoch 10/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.3050 - accuracy: 0.1143 - val_loss: 2.3055 - val_accuracy: 0.1079\n","Epoch 11/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3038 - accuracy: 0.1143 - val_loss: 2.3044 - val_accuracy: 0.1079\n","Epoch 12/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3029 - accuracy: 0.1143 - val_loss: 2.3036 - val_accuracy: 0.1079\n","Epoch 13/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3023 - accuracy: 0.1143 - val_loss: 2.3031 - val_accuracy: 0.1079\n","Epoch 14/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3019 - accuracy: 0.1143 - val_loss: 2.3028 - val_accuracy: 0.1079\n","Epoch 15/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.3016 - accuracy: 0.1143 - val_loss: 2.3025 - val_accuracy: 0.1079\n","Epoch 16/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3014 - accuracy: 0.1143 - val_loss: 2.3023 - val_accuracy: 0.1079\n","Epoch 17/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3013 - accuracy: 0.1143 - val_loss: 2.3021 - val_accuracy: 0.1079\n","Epoch 18/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3011 - accuracy: 0.1143 - val_loss: 2.3020 - val_accuracy: 0.1079\n","Epoch 19/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3010 - accuracy: 0.1143 - val_loss: 2.3018 - val_accuracy: 0.1079\n","Epoch 20/100\n","165/165 
[==============================] - 1s 4ms/step - loss: 2.3009 - accuracy: 0.1143 - val_loss: 2.3017 - val_accuracy: 0.1079\n","Epoch 21/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3008 - accuracy: 0.1143 - val_loss: 2.3017 - val_accuracy: 0.1079\n","Epoch 22/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3007 - accuracy: 0.1143 - val_loss: 2.3016 - val_accuracy: 0.1079\n","Epoch 23/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3007 - accuracy: 0.1143 - val_loss: 2.3015 - val_accuracy: 0.1079\n","Epoch 24/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3006 - accuracy: 0.1143 - val_loss: 2.3014 - val_accuracy: 0.1079\n","Epoch 25/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3005 - accuracy: 0.1143 - val_loss: 2.3013 - val_accuracy: 0.1079\n","Epoch 26/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3004 - accuracy: 0.1143 - val_loss: 2.3012 - val_accuracy: 0.1079\n","Epoch 27/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.3003 - accuracy: 0.1143 - val_loss: 2.3011 - val_accuracy: 0.1079\n","Epoch 28/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3002 - accuracy: 0.1143 - val_loss: 2.3010 - val_accuracy: 0.1079\n","Epoch 29/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3001 - accuracy: 0.1143 - val_loss: 2.3009 - val_accuracy: 0.1079\n","Epoch 30/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3001 - accuracy: 0.1143 - val_loss: 2.3008 - val_accuracy: 0.1079\n","Epoch 31/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.3000 - accuracy: 0.1143 - val_loss: 2.3007 - val_accuracy: 0.1079\n","Epoch 32/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2999 - accuracy: 0.1143 - val_loss: 2.3007 - val_accuracy: 0.1079\n","Epoch 33/100\n","165/165 
[==============================] - 1s 5ms/step - loss: 2.2998 - accuracy: 0.1143 - val_loss: 2.3006 - val_accuracy: 0.1079\n","Epoch 34/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2997 - accuracy: 0.1143 - val_loss: 2.3005 - val_accuracy: 0.1079\n","Epoch 35/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2996 - accuracy: 0.1143 - val_loss: 2.3004 - val_accuracy: 0.1079\n","Epoch 36/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2995 - accuracy: 0.1143 - val_loss: 2.3003 - val_accuracy: 0.1079\n","Epoch 37/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2995 - accuracy: 0.1143 - val_loss: 2.3002 - val_accuracy: 0.1079\n","Epoch 38/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2994 - accuracy: 0.1143 - val_loss: 2.3002 - val_accuracy: 0.1079\n","Epoch 39/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2993 - accuracy: 0.1143 - val_loss: 2.3001 - val_accuracy: 0.1079\n","Epoch 40/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2992 - accuracy: 0.1143 - val_loss: 2.3000 - val_accuracy: 0.1079\n","Epoch 41/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2991 - accuracy: 0.1143 - val_loss: 2.2999 - val_accuracy: 0.1079\n","Epoch 42/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2991 - accuracy: 0.1143 - val_loss: 2.2998 - val_accuracy: 0.1079\n","Epoch 43/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2990 - accuracy: 0.1143 - val_loss: 2.2998 - val_accuracy: 0.1079\n","Epoch 44/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2989 - accuracy: 0.1143 - val_loss: 2.2997 - val_accuracy: 0.1079\n","Epoch 45/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2988 - accuracy: 0.1143 - val_loss: 2.2996 - val_accuracy: 0.1079\n","Epoch 46/100\n","165/165 
[==============================] - 1s 5ms/step - loss: 2.2987 - accuracy: 0.1143 - val_loss: 2.2995 - val_accuracy: 0.1079\n","Epoch 47/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2987 - accuracy: 0.1143 - val_loss: 2.2994 - val_accuracy: 0.1079\n","Epoch 48/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2986 - accuracy: 0.1143 - val_loss: 2.2994 - val_accuracy: 0.1079\n","Epoch 49/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2985 - accuracy: 0.1143 - val_loss: 2.2993 - val_accuracy: 0.1079\n","Epoch 50/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2984 - accuracy: 0.1143 - val_loss: 2.2992 - val_accuracy: 0.1079\n","Epoch 51/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2983 - accuracy: 0.1143 - val_loss: 2.2991 - val_accuracy: 0.1079\n","Epoch 52/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2983 - accuracy: 0.1143 - val_loss: 2.2991 - val_accuracy: 0.1079\n","Epoch 53/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2982 - accuracy: 0.1143 - val_loss: 2.2990 - val_accuracy: 0.1079\n","Epoch 54/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2981 - accuracy: 0.1143 - val_loss: 2.2989 - val_accuracy: 0.1079\n","Epoch 55/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2980 - accuracy: 0.1143 - val_loss: 2.2988 - val_accuracy: 0.1079\n","Epoch 56/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2979 - accuracy: 0.1143 - val_loss: 2.2988 - val_accuracy: 0.1079\n","Epoch 57/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2979 - accuracy: 0.1143 - val_loss: 2.2987 - val_accuracy: 0.1079\n","Epoch 58/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2978 - accuracy: 0.1143 - val_loss: 2.2986 - val_accuracy: 0.1079\n","Epoch 59/100\n","165/165 
[==============================] - 1s 4ms/step - loss: 2.2977 - accuracy: 0.1143 - val_loss: 2.2985 - val_accuracy: 0.1079\n","Epoch 60/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2976 - accuracy: 0.1143 - val_loss: 2.2984 - val_accuracy: 0.1079\n","Epoch 61/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2975 - accuracy: 0.1143 - val_loss: 2.2983 - val_accuracy: 0.1079\n","Epoch 62/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2975 - accuracy: 0.1143 - val_loss: 2.2982 - val_accuracy: 0.1079\n","Epoch 63/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2974 - accuracy: 0.1143 - val_loss: 2.2981 - val_accuracy: 0.1079\n","Epoch 64/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2973 - accuracy: 0.1143 - val_loss: 2.2981 - val_accuracy: 0.1079\n","Epoch 65/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2972 - accuracy: 0.1143 - val_loss: 2.2980 - val_accuracy: 0.1079\n","Epoch 66/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2971 - accuracy: 0.1143 - val_loss: 2.2979 - val_accuracy: 0.1079\n","Epoch 67/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2971 - accuracy: 0.1143 - val_loss: 2.2978 - val_accuracy: 0.1079\n","Epoch 68/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2970 - accuracy: 0.1143 - val_loss: 2.2977 - val_accuracy: 0.1079\n","Epoch 69/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2969 - accuracy: 0.1143 - val_loss: 2.2976 - val_accuracy: 0.1079\n","Epoch 70/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2968 - accuracy: 0.1143 - val_loss: 2.2976 - val_accuracy: 0.1079\n","Epoch 71/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2967 - accuracy: 0.1143 - val_loss: 2.2975 - val_accuracy: 0.1079\n","Epoch 72/100\n","165/165 
[==============================] - 1s 5ms/step - loss: 2.2967 - accuracy: 0.1143 - val_loss: 2.2974 - val_accuracy: 0.1079\n","Epoch 73/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2966 - accuracy: 0.1143 - val_loss: 2.2973 - val_accuracy: 0.1079\n","Epoch 74/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2965 - accuracy: 0.1143 - val_loss: 2.2972 - val_accuracy: 0.1079\n","Epoch 75/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2964 - accuracy: 0.1143 - val_loss: 2.2972 - val_accuracy: 0.1079\n","Epoch 76/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2963 - accuracy: 0.1143 - val_loss: 2.2971 - val_accuracy: 0.1079\n","Epoch 77/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2963 - accuracy: 0.1143 - val_loss: 2.2970 - val_accuracy: 0.1079\n","Epoch 78/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2962 - accuracy: 0.1143 - val_loss: 2.2969 - val_accuracy: 0.1079\n","Epoch 79/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2961 - accuracy: 0.1143 - val_loss: 2.2969 - val_accuracy: 0.1079\n","Epoch 80/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2960 - accuracy: 0.1143 - val_loss: 2.2968 - val_accuracy: 0.1079\n","Epoch 81/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2959 - accuracy: 0.1143 - val_loss: 2.2967 - val_accuracy: 0.1079\n","Epoch 82/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2958 - accuracy: 0.1143 - val_loss: 2.2966 - val_accuracy: 0.1079\n","Epoch 83/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2958 - accuracy: 0.1143 - val_loss: 2.2965 - val_accuracy: 0.1079\n","Epoch 84/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2957 - accuracy: 0.1143 - val_loss: 2.2964 - val_accuracy: 0.1079\n","Epoch 85/100\n","165/165 
[==============================] - 1s 5ms/step - loss: 2.2956 - accuracy: 0.1143 - val_loss: 2.2963 - val_accuracy: 0.1079\n","Epoch 86/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2955 - accuracy: 0.1143 - val_loss: 2.2963 - val_accuracy: 0.1079\n","Epoch 87/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2954 - accuracy: 0.1143 - val_loss: 2.2962 - val_accuracy: 0.1079\n","Epoch 88/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2953 - accuracy: 0.1143 - val_loss: 2.2961 - val_accuracy: 0.1079\n","Epoch 89/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2952 - accuracy: 0.1143 - val_loss: 2.2960 - val_accuracy: 0.1079\n","Epoch 90/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2952 - accuracy: 0.1143 - val_loss: 2.2959 - val_accuracy: 0.1079\n","Epoch 91/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2951 - accuracy: 0.1143 - val_loss: 2.2958 - val_accuracy: 0.1079\n","Epoch 92/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2950 - accuracy: 0.1143 - val_loss: 2.2958 - val_accuracy: 0.1079\n","Epoch 93/100\n","165/165 [==============================] - 1s 4ms/step - loss: 2.2949 - accuracy: 0.1143 - val_loss: 2.2957 - val_accuracy: 0.1079\n","Epoch 94/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2948 - accuracy: 0.1143 - val_loss: 2.2956 - val_accuracy: 0.1079\n","Epoch 95/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2947 - accuracy: 0.1143 - val_loss: 2.2955 - val_accuracy: 0.1079\n","Epoch 96/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2946 - accuracy: 0.1143 - val_loss: 2.2954 - val_accuracy: 0.1079\n","Epoch 97/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2945 - accuracy: 0.1143 - val_loss: 2.2953 - val_accuracy: 0.1079\n","Epoch 98/100\n","165/165 
[==============================] - 1s 5ms/step - loss: 2.2945 - accuracy: 0.1143 - val_loss: 2.2952 - val_accuracy: 0.1079\n","Epoch 99/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2944 - accuracy: 0.1143 - val_loss: 2.2951 - val_accuracy: 0.1079\n","Epoch 100/100\n","165/165 [==============================] - 1s 5ms/step - loss: 2.2943 - accuracy: 0.1143 - val_loss: 2.2950 - val_accuracy: 0.1079\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"w4egvBDNm803","executionInfo":{"status":"ok","timestamp":1604684711175,"user_tz":420,"elapsed":79159,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"7cf5d450-074c-4d5f-8530-bcade9e6deaf","colab":{"base_uri":"https://localhost:8080/","height":265}},"source":["plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper left')\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"8g6__3Onm805"},"source":["In the run recorded above, training and validation accuracy plateau at roughly 11% (chance level for 10 classes) after about 10 epochs and do not improve afterwards; the loss decreases only marginally, so this naive model essentially fails to learn"]},{"cell_type":"code","metadata":{"id":"TwlF_-OSm805","executionInfo":{"status":"ok","timestamp":1604684711638,"user_tz":420,"elapsed":79615,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"9386ef36-2632-4ab1-d50d-70896eed329e","colab":{"base_uri":"https://localhost:8080/"}},"source":["results = model.evaluate(X_test, y_test)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["313/313 [==============================] - 0s 1ms/step - loss: 2.2944 - accuracy: 0.1135\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"DSeFp3jfm808","executionInfo":{"status":"ok","timestamp":1604684711640,"user_tz":420,"elapsed":79611,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"4648dd6d-4283-4d19-812c-a28e24876e78","colab":{"base_uri":"https://localhost:8080/"}},"source":["print('Test accuracy: ', results[1])"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Test accuracy: 0.11349999904632568\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"aQLGUyPsm80-"},"source":["## 1. 
Weight Initialization\n","- Changing weight initialization scheme can sometimes improve training of the model by preventing vanishing gradient problem up to some degree\n","- He normal or Xavier normal initialization schemes are SOTA at the moment\n","- Doc: https://keras.io/initializers/"]},{"cell_type":"code","metadata":{"id":"0kOFoXgtm80-"},"source":["# from now on, create a function to generate (return) models\n","def mlp_model():\n"," model = Sequential()\n"," \n"," model.add(Dense(50, input_shape = (784, ), kernel_initializer='he_normal')) # use he_normal initializer\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50, kernel_initializer='he_normal')) # use he_normal initializer\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50, kernel_initializer='he_normal')) # use he_normal initializer\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50, kernel_initializer='he_normal')) # use he_normal initializer\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(10, kernel_initializer='he_normal')) # use he_normal initializer\n"," model.add(Activation('softmax'))\n"," \n"," sgd = optimizers.SGD(lr = 0.001)\n"," model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])\n"," \n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"scrolled":true,"id":"T_4QYzNGm81A","executionInfo":{"status":"ok","timestamp":1604684947977,"user_tz":420,"elapsed":315938,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"e05cac1e-6c6e-431e-b9a1-86090ed6eb74","colab":{"base_uri":"https://localhost:8080/"}},"source":["model = mlp_model()\n","history = model.fit(X_train, y_train, validation_split = 0.3, epochs = 100, verbose = 1)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Epoch 1/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.3656 - accuracy: 0.1143 - val_loss: 2.3063 - val_accuracy: 
0.1079\n","Epoch 2/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.3001 - accuracy: 0.1143 - val_loss: 2.2984 - val_accuracy: 0.1079\n","Epoch 3/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2958 - accuracy: 0.1143 - val_loss: 2.2955 - val_accuracy: 0.1079\n","Epoch 4/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2929 - accuracy: 0.1143 - val_loss: 2.2927 - val_accuracy: 0.1079\n","Epoch 5/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2902 - accuracy: 0.1143 - val_loss: 2.2900 - val_accuracy: 0.1079\n","Epoch 6/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2876 - accuracy: 0.1143 - val_loss: 2.2872 - val_accuracy: 0.1079\n","Epoch 7/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2849 - accuracy: 0.1143 - val_loss: 2.2844 - val_accuracy: 0.1079\n","Epoch 8/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2821 - accuracy: 0.1143 - val_loss: 2.2815 - val_accuracy: 0.1079\n","Epoch 9/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2791 - accuracy: 0.1143 - val_loss: 2.2785 - val_accuracy: 0.1080\n","Epoch 10/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2760 - accuracy: 0.1176 - val_loss: 2.2755 - val_accuracy: 0.1084\n","Epoch 11/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2731 - accuracy: 0.1170 - val_loss: 2.2724 - val_accuracy: 0.1083\n","Epoch 12/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2700 - accuracy: 0.1160 - val_loss: 2.2692 - val_accuracy: 0.1257\n","Epoch 13/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2668 - accuracy: 0.1228 - val_loss: 2.2659 - val_accuracy: 0.1266\n","Epoch 14/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2634 - accuracy: 0.1285 - val_loss: 2.2626 - val_accuracy: 
0.1390\n","Epoch 15/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2599 - accuracy: 0.1586 - val_loss: 2.2588 - val_accuracy: 0.1309\n","Epoch 16/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2561 - accuracy: 0.1680 - val_loss: 2.2549 - val_accuracy: 0.1638\n","Epoch 17/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2520 - accuracy: 0.1872 - val_loss: 2.2507 - val_accuracy: 0.2122\n","Epoch 18/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2476 - accuracy: 0.2439 - val_loss: 2.2460 - val_accuracy: 0.1938\n","Epoch 19/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2429 - accuracy: 0.2310 - val_loss: 2.2410 - val_accuracy: 0.2599\n","Epoch 20/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2377 - accuracy: 0.2752 - val_loss: 2.2356 - val_accuracy: 0.2886\n","Epoch 21/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2321 - accuracy: 0.2968 - val_loss: 2.2296 - val_accuracy: 0.3260\n","Epoch 22/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2260 - accuracy: 0.3393 - val_loss: 2.2234 - val_accuracy: 0.3254\n","Epoch 23/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2194 - accuracy: 0.3507 - val_loss: 2.2164 - val_accuracy: 0.3616\n","Epoch 24/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2121 - accuracy: 0.3837 - val_loss: 2.2088 - val_accuracy: 0.3750\n","Epoch 25/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.2041 - accuracy: 0.3929 - val_loss: 2.2003 - val_accuracy: 0.4189\n","Epoch 26/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.1953 - accuracy: 0.4090 - val_loss: 2.1911 - val_accuracy: 0.4568\n","Epoch 27/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.1858 - accuracy: 0.4468 - val_loss: 2.1812 - 
val_accuracy: 0.4503\n","Epoch 28/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.1753 - accuracy: 0.4504 - val_loss: 2.1702 - val_accuracy: 0.4673\n","Epoch 29/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.1637 - accuracy: 0.4713 - val_loss: 2.1582 - val_accuracy: 0.4663\n","Epoch 30/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.1507 - accuracy: 0.4746 - val_loss: 2.1445 - val_accuracy: 0.4826\n","Epoch 31/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.1362 - accuracy: 0.4836 - val_loss: 2.1293 - val_accuracy: 0.5059\n","Epoch 32/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.1204 - accuracy: 0.4987 - val_loss: 2.1129 - val_accuracy: 0.5090\n","Epoch 33/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.1032 - accuracy: 0.5108 - val_loss: 2.0953 - val_accuracy: 0.5013\n","Epoch 34/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.0844 - accuracy: 0.5070 - val_loss: 2.0756 - val_accuracy: 0.5133\n","Epoch 35/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.0637 - accuracy: 0.5155 - val_loss: 2.0542 - val_accuracy: 0.5234\n","Epoch 36/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.0411 - accuracy: 0.5282 - val_loss: 2.0305 - val_accuracy: 0.5218\n","Epoch 37/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 2.0164 - accuracy: 0.5228 - val_loss: 2.0052 - val_accuracy: 0.5376\n","Epoch 38/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.9896 - accuracy: 0.5318 - val_loss: 1.9774 - val_accuracy: 0.5453\n","Epoch 39/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.9607 - accuracy: 0.5416 - val_loss: 1.9474 - val_accuracy: 0.5428\n","Epoch 40/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.9295 - accuracy: 0.5428 - val_loss: 1.9155 
- val_accuracy: 0.5483\n","Epoch 41/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.8963 - accuracy: 0.5511 - val_loss: 1.8813 - val_accuracy: 0.5532\n","Epoch 42/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.8610 - accuracy: 0.5515 - val_loss: 1.8453 - val_accuracy: 0.5653\n","Epoch 43/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.8240 - accuracy: 0.5595 - val_loss: 1.8076 - val_accuracy: 0.5727\n","Epoch 44/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.7854 - accuracy: 0.5665 - val_loss: 1.7686 - val_accuracy: 0.5749\n","Epoch 45/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.7452 - accuracy: 0.5699 - val_loss: 1.7278 - val_accuracy: 0.5812\n","Epoch 46/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.7038 - accuracy: 0.5786 - val_loss: 1.6858 - val_accuracy: 0.5847\n","Epoch 47/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.6616 - accuracy: 0.5838 - val_loss: 1.6430 - val_accuracy: 0.5901\n","Epoch 48/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.6185 - accuracy: 0.5881 - val_loss: 1.5999 - val_accuracy: 0.5997\n","Epoch 49/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.5749 - accuracy: 0.5949 - val_loss: 1.5558 - val_accuracy: 0.6107\n","Epoch 50/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.5313 - accuracy: 0.6037 - val_loss: 1.5124 - val_accuracy: 0.6128\n","Epoch 51/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.4881 - accuracy: 0.6118 - val_loss: 1.4686 - val_accuracy: 0.6141\n","Epoch 52/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.4447 - accuracy: 0.6153 - val_loss: 1.4259 - val_accuracy: 0.6274\n","Epoch 53/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.4026 - accuracy: 0.6263 - val_loss: 
1.3840 - val_accuracy: 0.6333\n","Epoch 54/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.3612 - accuracy: 0.6324 - val_loss: 1.3434 - val_accuracy: 0.6470\n","Epoch 55/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.3206 - accuracy: 0.6457 - val_loss: 1.3039 - val_accuracy: 0.6506\n","Epoch 56/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.2821 - accuracy: 0.6517 - val_loss: 1.2672 - val_accuracy: 0.6568\n","Epoch 57/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.2454 - accuracy: 0.6589 - val_loss: 1.2306 - val_accuracy: 0.6644\n","Epoch 58/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.2104 - accuracy: 0.6664 - val_loss: 1.1962 - val_accuracy: 0.6756\n","Epoch 59/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.1771 - accuracy: 0.6752 - val_loss: 1.1638 - val_accuracy: 0.6823\n","Epoch 60/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.1453 - accuracy: 0.6855 - val_loss: 1.1338 - val_accuracy: 0.6892\n","Epoch 61/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.1154 - accuracy: 0.6917 - val_loss: 1.1042 - val_accuracy: 0.7019\n","Epoch 62/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.0874 - accuracy: 0.7005 - val_loss: 1.0775 - val_accuracy: 0.7031\n","Epoch 63/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.0615 - accuracy: 0.7063 - val_loss: 1.0526 - val_accuracy: 0.7127\n","Epoch 64/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.0366 - accuracy: 0.7143 - val_loss: 1.0286 - val_accuracy: 0.7214\n","Epoch 65/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.0128 - accuracy: 0.7223 - val_loss: 1.0067 - val_accuracy: 0.7251\n","Epoch 66/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.9912 - accuracy: 0.7271 - 
val_loss: 0.9850 - val_accuracy: 0.7341\n","Epoch 67/100\n","1313/1313 [==============================] - 3s 2ms/step - loss: 0.9702 - accuracy: 0.7352 - val_loss: 0.9644 - val_accuracy: 0.7395\n","Epoch 68/100\n","1313/1313 [==============================] - 3s 2ms/step - loss: 0.9500 - accuracy: 0.7442 - val_loss: 0.9455 - val_accuracy: 0.7449\n","Epoch 69/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.9313 - accuracy: 0.7496 - val_loss: 0.9282 - val_accuracy: 0.7540\n","Epoch 70/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.9141 - accuracy: 0.7563 - val_loss: 0.9115 - val_accuracy: 0.7555\n","Epoch 71/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.8965 - accuracy: 0.7618 - val_loss: 0.8964 - val_accuracy: 0.7647\n","Epoch 72/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.8794 - accuracy: 0.7692 - val_loss: 0.8789 - val_accuracy: 0.7761\n","Epoch 73/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.8636 - accuracy: 0.7744 - val_loss: 0.8655 - val_accuracy: 0.7739\n","Epoch 74/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.8482 - accuracy: 0.7787 - val_loss: 0.8504 - val_accuracy: 0.7805\n","Epoch 75/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.8342 - accuracy: 0.7845 - val_loss: 0.8375 - val_accuracy: 0.7871\n","Epoch 76/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.8214 - accuracy: 0.7887 - val_loss: 0.8230 - val_accuracy: 0.7866\n","Epoch 77/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.8069 - accuracy: 0.7934 - val_loss: 0.8115 - val_accuracy: 0.7932\n","Epoch 78/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7945 - accuracy: 0.7967 - val_loss: 0.8010 - val_accuracy: 0.7973\n","Epoch 79/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7816 - accuracy: 0.8010 
- val_loss: 0.7875 - val_accuracy: 0.8012\n","Epoch 80/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7692 - accuracy: 0.8051 - val_loss: 0.7758 - val_accuracy: 0.8077\n","Epoch 81/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7577 - accuracy: 0.8087 - val_loss: 0.7662 - val_accuracy: 0.8058\n","Epoch 82/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7457 - accuracy: 0.8120 - val_loss: 0.7550 - val_accuracy: 0.8089\n","Epoch 83/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7345 - accuracy: 0.8165 - val_loss: 0.7434 - val_accuracy: 0.8163\n","Epoch 84/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7244 - accuracy: 0.8196 - val_loss: 0.7346 - val_accuracy: 0.8154\n","Epoch 85/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7130 - accuracy: 0.8231 - val_loss: 0.7243 - val_accuracy: 0.8217\n","Epoch 86/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.7017 - accuracy: 0.8275 - val_loss: 0.7152 - val_accuracy: 0.8234\n","Epoch 87/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6929 - accuracy: 0.8316 - val_loss: 0.7069 - val_accuracy: 0.8229\n","Epoch 88/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6835 - accuracy: 0.8341 - val_loss: 0.6979 - val_accuracy: 0.8276\n","Epoch 89/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6726 - accuracy: 0.8356 - val_loss: 0.6890 - val_accuracy: 0.8304\n","Epoch 90/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6645 - accuracy: 0.8394 - val_loss: 0.6823 - val_accuracy: 0.8307\n","Epoch 91/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6549 - accuracy: 0.8445 - val_loss: 0.6707 - val_accuracy: 0.8353\n","Epoch 92/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6459 - accuracy: 
0.8457 - val_loss: 0.6651 - val_accuracy: 0.8388\n","Epoch 93/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6375 - accuracy: 0.8502 - val_loss: 0.6558 - val_accuracy: 0.8389\n","Epoch 94/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6291 - accuracy: 0.8533 - val_loss: 0.6479 - val_accuracy: 0.8426\n","Epoch 95/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6208 - accuracy: 0.8541 - val_loss: 0.6406 - val_accuracy: 0.8453\n","Epoch 96/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.6119 - accuracy: 0.8574 - val_loss: 0.6317 - val_accuracy: 0.8479\n","Epoch 97/100\n","1313/1313 [==============================] - 3s 2ms/step - loss: 0.6037 - accuracy: 0.8609 - val_loss: 0.6270 - val_accuracy: 0.8495\n","Epoch 98/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.5967 - accuracy: 0.8618 - val_loss: 0.6185 - val_accuracy: 0.8507\n","Epoch 99/100\n","1313/1313 [==============================] - 3s 2ms/step - loss: 0.5882 - accuracy: 0.8646 - val_loss: 0.6135 - val_accuracy: 0.8517\n","Epoch 100/100\n","1313/1313 [==============================] - 3s 2ms/step - loss: 0.5817 - accuracy: 0.8655 - val_loss: 0.6051 - val_accuracy: 0.8533\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"k-ttTMOwm81C","executionInfo":{"status":"ok","timestamp":1604684947982,"user_tz":420,"elapsed":315935,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"2d83304a-8ada-4d26-9c66-e830a4a27f47","colab":{"base_uri":"https://localhost:8080/","height":265}},"source":["plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper 
left')\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"2hwzerhum81E"},"source":["Training and validation accuracy start to improve after around 10 epochs and reach roughly 86% by epoch 100"]},{"cell_type":"code","metadata":{"id":"mzo9TrI5m81E","executionInfo":{"status":"ok","timestamp":1604684948276,"user_tz":420,"elapsed":316223,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"fb2d1917-506d-417d-ee2d-3e98540bf119","colab":{"base_uri":"https://localhost:8080/"}},"source":["results = model.evaluate(X_test, y_test)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["313/313 [==============================] - 0s 1ms/step - loss: 0.5939 - accuracy: 0.8625\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"_JDbnxfRm81G","executionInfo":{"status":"ok","timestamp":1604684948277,"user_tz":420,"elapsed":316219,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"3d9794e1-46d3-4574-8adb-1249943d0555","colab":{"base_uri":"https://localhost:8080/"}},"source":["print('Test accuracy: ', results[1])"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Test accuracy: 0.862500011920929\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"OMfsNCzcm81I"},"source":["## 2. Nonlinearity (Activation function)\n","- Sigmoid functions suffer from gradient vanishing problem, making training slower\n","- There are many choices apart from sigmoid and tanh; try many of them!\n"," - **'relu'** (rectified linear unit) is one of the most popular ones\n"," - **'selu'** (scaled exponential linear unit) is one of the most recent ones\n","- Doc: https://keras.io/activations/"]},{"cell_type":"markdown","metadata":{"id":"fO6HhMaLm81I"},"source":["\n","
**Sigmoid Activation Function**
\n","\n","
**ReLU Activation Function**
"]},{"cell_type":"code","metadata":{"id":"phc_0Rgdm81J"},"source":["def mlp_model():\n"," model = Sequential()\n"," \n"," model.add(Dense(50, input_shape = (784, )))\n"," model.add(Activation('relu')) # use relu\n"," model.add(Dense(50))\n"," model.add(Activation('relu')) # use relu\n"," model.add(Dense(50))\n"," model.add(Activation('relu')) # use relu\n"," model.add(Dense(50))\n"," model.add(Activation('relu')) # use relu\n"," model.add(Dense(10))\n"," model.add(Activation('softmax'))\n"," \n"," sgd = optimizers.SGD(lr = 0.001)\n"," model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])\n"," \n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"yeEYkRCwm81L","executionInfo":{"status":"ok","timestamp":1604685184126,"user_tz":420,"elapsed":552058,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"37cbe657-7e64-42d7-be4e-7422a71f5ab0","colab":{"base_uri":"https://localhost:8080/"}},"source":["model = mlp_model()\n","history = model.fit(X_train, y_train, validation_split = 0.3, epochs = 100, verbose = 1)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Epoch 1/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 1.1128 - accuracy: 0.7314 - val_loss: 0.5504 - val_accuracy: 0.8466\n","Epoch 2/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.4489 - accuracy: 0.8700 - val_loss: 0.4168 - val_accuracy: 0.8827\n","Epoch 3/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.3434 - accuracy: 0.8988 - val_loss: 0.3719 - val_accuracy: 0.8918\n","Epoch 4/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.2895 - accuracy: 0.9141 - val_loss: 0.3231 - val_accuracy: 0.9103\n","Epoch 5/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.2528 - accuracy: 0.9248 - val_loss: 0.2951 - val_accuracy: 0.9153\n","Epoch 6/100\n","1313/1313 
[==============================] - 2s 2ms/step - loss: 0.2266 - accuracy: 0.9321 - val_loss: 0.2940 - val_accuracy: 0.9170\n","Epoch 7/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.2094 - accuracy: 0.9368 - val_loss: 0.2848 - val_accuracy: 0.9194\n","Epoch 8/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1935 - accuracy: 0.9423 - val_loss: 0.2735 - val_accuracy: 0.9215\n","Epoch 9/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1820 - accuracy: 0.9458 - val_loss: 0.2632 - val_accuracy: 0.9277\n","Epoch 10/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1702 - accuracy: 0.9486 - val_loss: 0.2506 - val_accuracy: 0.9304\n","Epoch 11/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1605 - accuracy: 0.9518 - val_loss: 0.2467 - val_accuracy: 0.9319\n","Epoch 12/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1527 - accuracy: 0.9535 - val_loss: 0.2390 - val_accuracy: 0.9327\n","Epoch 13/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1447 - accuracy: 0.9558 - val_loss: 0.2948 - val_accuracy: 0.9189\n","Epoch 14/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1380 - accuracy: 0.9576 - val_loss: 0.2315 - val_accuracy: 0.9362\n","Epoch 15/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1324 - accuracy: 0.9604 - val_loss: 0.2351 - val_accuracy: 0.9369\n","Epoch 16/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1266 - accuracy: 0.9624 - val_loss: 0.2313 - val_accuracy: 0.9373\n","Epoch 17/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1217 - accuracy: 0.9626 - val_loss: 0.2308 - val_accuracy: 0.9385\n","Epoch 18/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1172 - accuracy: 0.9646 - val_loss: 0.2286 - val_accuracy: 0.9398\n","Epoch 19/100\n","1313/1313 
[==============================] - 2s 2ms/step - loss: 0.1130 - accuracy: 0.9659 - val_loss: 0.2267 - val_accuracy: 0.9398\n","Epoch 20/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1095 - accuracy: 0.9664 - val_loss: 0.2253 - val_accuracy: 0.9408\n","Epoch 21/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1049 - accuracy: 0.9676 - val_loss: 0.2281 - val_accuracy: 0.9411\n","Epoch 22/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.1017 - accuracy: 0.9695 - val_loss: 0.2252 - val_accuracy: 0.9417\n","Epoch 23/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0984 - accuracy: 0.9697 - val_loss: 0.2267 - val_accuracy: 0.9412\n","Epoch 24/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0946 - accuracy: 0.9718 - val_loss: 0.2310 - val_accuracy: 0.9422\n","Epoch 25/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0921 - accuracy: 0.9720 - val_loss: 0.2272 - val_accuracy: 0.9419\n","Epoch 26/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0891 - accuracy: 0.9733 - val_loss: 0.2260 - val_accuracy: 0.9438\n","Epoch 27/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0860 - accuracy: 0.9749 - val_loss: 0.2251 - val_accuracy: 0.9429\n","Epoch 28/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0823 - accuracy: 0.9757 - val_loss: 0.2363 - val_accuracy: 0.9436\n","Epoch 29/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0800 - accuracy: 0.9759 - val_loss: 0.2431 - val_accuracy: 0.9423\n","Epoch 30/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0776 - accuracy: 0.9765 - val_loss: 0.2289 - val_accuracy: 0.9447\n","Epoch 31/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0759 - accuracy: 0.9773 - val_loss: 0.2300 - val_accuracy: 0.9434\n","Epoch 
32/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0735 - accuracy: 0.9784 - val_loss: 0.2301 - val_accuracy: 0.9450\n","Epoch 33/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0715 - accuracy: 0.9792 - val_loss: 0.2318 - val_accuracy: 0.9441\n","Epoch 34/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0692 - accuracy: 0.9794 - val_loss: 0.2324 - val_accuracy: 0.9446\n","Epoch 35/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0676 - accuracy: 0.9804 - val_loss: 0.2306 - val_accuracy: 0.9457\n","Epoch 36/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0651 - accuracy: 0.9807 - val_loss: 0.2379 - val_accuracy: 0.9436\n","Epoch 37/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0639 - accuracy: 0.9815 - val_loss: 0.2304 - val_accuracy: 0.9456\n","Epoch 38/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0613 - accuracy: 0.9825 - val_loss: 0.2417 - val_accuracy: 0.9455\n","Epoch 39/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0600 - accuracy: 0.9825 - val_loss: 0.2418 - val_accuracy: 0.9459\n","Epoch 40/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0588 - accuracy: 0.9832 - val_loss: 0.2460 - val_accuracy: 0.9445\n","Epoch 41/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0572 - accuracy: 0.9828 - val_loss: 0.2414 - val_accuracy: 0.9452\n","Epoch 42/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0552 - accuracy: 0.9837 - val_loss: 0.2456 - val_accuracy: 0.9441\n","Epoch 43/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0544 - accuracy: 0.9844 - val_loss: 0.2437 - val_accuracy: 0.9461\n","Epoch 44/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0532 - accuracy: 0.9843 - val_loss: 0.2438 - val_accuracy: 
0.9466\n","Epoch 45/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0511 - accuracy: 0.9855 - val_loss: 0.2429 - val_accuracy: 0.9460\n","Epoch 46/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0492 - accuracy: 0.9862 - val_loss: 0.2568 - val_accuracy: 0.9448\n","Epoch 47/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0485 - accuracy: 0.9865 - val_loss: 0.2469 - val_accuracy: 0.9459\n","Epoch 48/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0469 - accuracy: 0.9863 - val_loss: 0.2509 - val_accuracy: 0.9459\n","Epoch 49/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0467 - accuracy: 0.9863 - val_loss: 0.2481 - val_accuracy: 0.9472\n","Epoch 50/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0447 - accuracy: 0.9870 - val_loss: 0.2511 - val_accuracy: 0.9464\n","Epoch 51/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0442 - accuracy: 0.9868 - val_loss: 0.2573 - val_accuracy: 0.9456\n","Epoch 52/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0425 - accuracy: 0.9875 - val_loss: 0.2563 - val_accuracy: 0.9455\n","Epoch 53/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0419 - accuracy: 0.9881 - val_loss: 0.2541 - val_accuracy: 0.9458\n","Epoch 54/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0404 - accuracy: 0.9886 - val_loss: 0.2624 - val_accuracy: 0.9472\n","Epoch 55/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0396 - accuracy: 0.9885 - val_loss: 0.2640 - val_accuracy: 0.9445\n","Epoch 56/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0388 - accuracy: 0.9888 - val_loss: 0.2718 - val_accuracy: 0.9448\n","Epoch 57/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0375 - accuracy: 0.9895 - val_loss: 0.2643 - 
val_accuracy: 0.9468\n","Epoch 58/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0361 - accuracy: 0.9897 - val_loss: 0.2728 - val_accuracy: 0.9439\n","Epoch 59/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0360 - accuracy: 0.9897 - val_loss: 0.2633 - val_accuracy: 0.9461\n","Epoch 60/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0347 - accuracy: 0.9905 - val_loss: 0.2640 - val_accuracy: 0.9468\n","Epoch 61/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0336 - accuracy: 0.9906 - val_loss: 0.2716 - val_accuracy: 0.9468\n","Epoch 62/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0330 - accuracy: 0.9908 - val_loss: 0.2772 - val_accuracy: 0.9449\n","Epoch 63/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0325 - accuracy: 0.9907 - val_loss: 0.2772 - val_accuracy: 0.9453\n","Epoch 64/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0320 - accuracy: 0.9911 - val_loss: 0.2770 - val_accuracy: 0.9471\n","Epoch 65/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0309 - accuracy: 0.9917 - val_loss: 0.2720 - val_accuracy: 0.9482\n","Epoch 66/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0302 - accuracy: 0.9918 - val_loss: 0.2828 - val_accuracy: 0.9473\n","Epoch 67/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0288 - accuracy: 0.9926 - val_loss: 0.2858 - val_accuracy: 0.9460\n","Epoch 68/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0287 - accuracy: 0.9923 - val_loss: 0.2855 - val_accuracy: 0.9459\n","Epoch 69/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0277 - accuracy: 0.9925 - val_loss: 0.2984 - val_accuracy: 0.9451\n","Epoch 70/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0269 - accuracy: 0.9929 - val_loss: 0.2891 
- val_accuracy: 0.9447\n","Epoch 71/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0264 - accuracy: 0.9929 - val_loss: 0.2865 - val_accuracy: 0.9466\n","Epoch 72/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0251 - accuracy: 0.9933 - val_loss: 0.2896 - val_accuracy: 0.9472\n","Epoch 73/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0245 - accuracy: 0.9934 - val_loss: 0.2961 - val_accuracy: 0.9467\n","Epoch 74/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0248 - accuracy: 0.9934 - val_loss: 0.2961 - val_accuracy: 0.9462\n","Epoch 75/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0243 - accuracy: 0.9937 - val_loss: 0.2974 - val_accuracy: 0.9469\n","Epoch 76/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0226 - accuracy: 0.9943 - val_loss: 0.2977 - val_accuracy: 0.9474\n","Epoch 77/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0223 - accuracy: 0.9946 - val_loss: 0.3038 - val_accuracy: 0.9480\n","Epoch 78/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0220 - accuracy: 0.9940 - val_loss: 0.3051 - val_accuracy: 0.9472\n","Epoch 79/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0216 - accuracy: 0.9944 - val_loss: 0.3116 - val_accuracy: 0.9462\n","Epoch 80/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0207 - accuracy: 0.9950 - val_loss: 0.3021 - val_accuracy: 0.9489\n","Epoch 81/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0203 - accuracy: 0.9945 - val_loss: 0.3098 - val_accuracy: 0.9484\n","Epoch 82/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0196 - accuracy: 0.9952 - val_loss: 0.3139 - val_accuracy: 0.9468\n","Epoch 83/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0193 - accuracy: 0.9953 - val_loss: 
0.3157 - val_accuracy: 0.9466\n","Epoch 84/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0188 - accuracy: 0.9954 - val_loss: 0.3180 - val_accuracy: 0.9458\n","Epoch 85/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0191 - accuracy: 0.9953 - val_loss: 0.3196 - val_accuracy: 0.9485\n","Epoch 86/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0182 - accuracy: 0.9953 - val_loss: 0.3267 - val_accuracy: 0.9452\n","Epoch 87/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0176 - accuracy: 0.9960 - val_loss: 0.3224 - val_accuracy: 0.9472\n","Epoch 88/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0177 - accuracy: 0.9957 - val_loss: 0.3195 - val_accuracy: 0.9476\n","Epoch 89/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0167 - accuracy: 0.9961 - val_loss: 0.3250 - val_accuracy: 0.9476\n","Epoch 90/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0163 - accuracy: 0.9964 - val_loss: 0.3281 - val_accuracy: 0.9472\n","Epoch 91/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0155 - accuracy: 0.9966 - val_loss: 0.3259 - val_accuracy: 0.9470\n","Epoch 92/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0154 - accuracy: 0.9964 - val_loss: 0.3387 - val_accuracy: 0.9447\n","Epoch 93/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0152 - accuracy: 0.9968 - val_loss: 0.3323 - val_accuracy: 0.9466\n","Epoch 94/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0147 - accuracy: 0.9968 - val_loss: 0.3295 - val_accuracy: 0.9470\n","Epoch 95/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0147 - accuracy: 0.9970 - val_loss: 0.3338 - val_accuracy: 0.9472\n","Epoch 96/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0143 - accuracy: 0.9970 - 
val_loss: 0.3401 - val_accuracy: 0.9477\n","Epoch 97/100\n","1313/1313 [==============================] - 3s 2ms/step - loss: 0.0143 - accuracy: 0.9968 - val_loss: 0.3434 - val_accuracy: 0.9459\n","Epoch 98/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0136 - accuracy: 0.9971 - val_loss: 0.3385 - val_accuracy: 0.9469\n","Epoch 99/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0130 - accuracy: 0.9975 - val_loss: 0.3472 - val_accuracy: 0.9472\n","Epoch 100/100\n","1313/1313 [==============================] - 2s 2ms/step - loss: 0.0127 - accuracy: 0.9977 - val_loss: 0.3448 - val_accuracy: 0.9450\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"Fig-BD1Qm81O","executionInfo":{"status":"ok","timestamp":1604685184396,"user_tz":420,"elapsed":552321,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"93c42cd3-cfb0-4708-d293-145d66ba0e2f","colab":{"base_uri":"https://localhost:8080/","height":265}},"source":["plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper left')\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"TutiUGFIm81Q"},"source":["Training and validation accuracy improve instantaneously, but reach a plateau after around 30 epochs"]},{"cell_type":"code","metadata":{"id":"koMrvUyJm81R","executionInfo":{"status":"ok","timestamp":1604685184602,"user_tz":420,"elapsed":552520,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"d28b3257-f557-4ca2-8644-96cb7ec19ed0","colab":{"base_uri":"https://localhost:8080/"}},"source":["results = model.evaluate(X_test, y_test)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["313/313 [==============================] - 0s 1ms/step - loss: 0.3152 - accuracy: 0.9488\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"p7YpK2Pvm81T","executionInfo":{"status":"ok","timestamp":1604685184603,"user_tz":420,"elapsed":552515,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"76391ac6-eaf0-48c5-ccba-13c2769ad449","colab":{"base_uri":"https://localhost:8080/"}},"source":["print('Test accuracy: ', results[1])"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Test accuracy: 0.9488000273704529\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"fcqGvLFsm81V"},"source":["## 3. Optimizers\n","- Many variants of SGD are proposed and employed nowadays\n","- One of the most popular ones are Adam (Adaptive Moment Estimation)\n","- Doc: https://keras.io/optimizers/"]},{"cell_type":"markdown","metadata":{"id":"94XG6lbrm81V"},"source":["\n","
**Relative convergence speed of different optimizers**

"]},{"cell_type":"code","metadata":{"id":"Po9y2epXm81W"},"source":["def mlp_model():\n"," model = Sequential()\n"," \n"," model.add(Dense(50, input_shape = (784, )))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(10))\n"," model.add(Activation('softmax'))\n"," \n"," adam = optimizers.Adam(lr = 0.001) # use Adam optimizer\n"," model.compile(optimizer = adam, loss = 'categorical_crossentropy', metrics = ['accuracy'])\n"," \n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Fzn5iX2Um81Y"},"source":["model = mlp_model()\n","history = model.fit(X_train, y_train, validation_split = 0.3, epochs = 100, verbose = 0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"1greMMf5m81a","executionInfo":{"status":"ok","timestamp":1604685416949,"user_tz":420,"elapsed":784848,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"0f66a1bb-5b29-47c8-bf25-75f64fe96045","colab":{"base_uri":"https://localhost:8080/","height":265}},"source":["plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper left')\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"xCEHlMxrm81c"},"source":["Training and validation accuracy improve instantaneously, but reach a plateau after around 50 epochs"]},{"cell_type":"code","metadata":{"id":"Zm-4ybkLm81d","executionInfo":{"status":"ok","timestamp":1604685417478,"user_tz":420,"elapsed":785370,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"cc6df5d1-bcfd-4656-8bc0-7588a149ab61","colab":{"base_uri":"https://localhost:8080/"}},"source":["results = model.evaluate(X_test, y_test)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["313/313 [==============================] - 0s 1ms/step - loss: 0.1801 - accuracy: 0.9465\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"DBHaJT5Um81f","executionInfo":{"status":"ok","timestamp":1604685417479,"user_tz":420,"elapsed":785364,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"f41904bb-c268-462e-b010-d33004890a07","colab":{"base_uri":"https://localhost:8080/"}},"source":["print('Test accuracy: ', results[1])"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Test accuracy: 0.9465000033378601\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"DNuN5ygtm81g"},"source":["## 4. Batch Normalization\n","- Batch Normalization, a method for mitigating the \"internal covariate shift\" problem, has proven to be highly effective\n","- Normalize each mini-batch before the nonlinearity\n","- Doc: https://keras.io/api/layers/normalization_layers/batch_normalization/"]},{"cell_type":"markdown","metadata":{"id":"aTk5ndUCm81h"},"source":["\n","\n","
Batch normalization layer is usually inserted after dense/convolution and before nonlinearity"]},{"cell_type":"code","metadata":{"id":"wXD9H1wfm81h"},"source":["from keras.layers import BatchNormalization"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"mZEOsHbVm81k"},"source":["def mlp_model():\n"," model = Sequential()\n"," \n"," model.add(Dense(50, input_shape = (784, )))\n"," model.add(BatchNormalization()) # Add Batchnorm layer before Activation\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(BatchNormalization()) # Add Batchnorm layer before Activation\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(BatchNormalization()) # Add Batchnorm layer before Activation\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(BatchNormalization()) # Add Batchnorm layer before Activation\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(10))\n"," model.add(Activation('softmax'))\n"," \n"," sgd = optimizers.SGD(lr = 0.001)\n"," model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])\n"," \n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ziG6hzqZm81m"},"source":["model = mlp_model()\n","history = model.fit(X_train, y_train, validation_split = 0.3, epochs = 100, verbose = 0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"4cVWf4x8m81o","executionInfo":{"status":"ok","timestamp":1604685663470,"user_tz":420,"elapsed":1031344,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"b9ace8eb-0ab3-4096-b772-49d788799a0f","colab":{"base_uri":"https://localhost:8080/","height":265}},"source":["plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.legend(['train acc', 'valid acc', 'train loss', 
'valid loss'], loc = 'upper left')\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"IvyC1A4Nm81q"},"source":["Training and validation accuracy improve consistently, but reach plateau after around 60 epochs"]},{"cell_type":"code","metadata":{"id":"89wGTNv0m81r","executionInfo":{"status":"ok","timestamp":1604685663946,"user_tz":420,"elapsed":1031814,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"8a58f27e-1b76-41a4-9574-7220cb03b7c7","colab":{"base_uri":"https://localhost:8080/"}},"source":["results = model.evaluate(X_test, y_test)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["313/313 [==============================] - 0s 1ms/step - loss: 0.1866 - accuracy: 0.9481\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"2OlzhJIrm81t","executionInfo":{"status":"ok","timestamp":1604685663947,"user_tz":420,"elapsed":1031809,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"e00f0778-d942-46d6-b8a6-9a14d100e740","colab":{"base_uri":"https://localhost:8080/"}},"source":["print('Test accuracy: ', results[1])"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Test accuracy: 0.9480999708175659\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"M_n7L3yrm81w"},"source":["## 5. Dropout (Regularization)\n","- Dropout is one of powerful ways to prevent overfitting\n","- The idea is simple. 
It is disconnecting some (randomly selected) neurons in each layer\n","- The probability of each neuron to be disconnected, namely 'Dropout rate', has to be designated\n","- Doc: https://keras.io/layers/core/#dropout"]},{"cell_type":"markdown","metadata":{"id":"cgmuU0_Em81w"},"source":[""]},{"cell_type":"code","metadata":{"id":"idwYUgbqm81x"},"source":["from keras.layers import Dropout"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"H6sbqxj0m81z"},"source":["def mlp_model():\n"," model = Sequential()\n"," \n"," model.add(Dense(50, input_shape = (784, )))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dropout(0.2)) # Dropout layer after Activation\n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid'))\n"," model.add(Dropout(0.2)) # Dropout layer after Activation\n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dropout(0.2)) # Dropout layer after Activation\n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dropout(0.2)) # Dropout layer after Activation\n"," model.add(Dense(10))\n"," model.add(Activation('softmax'))\n"," \n"," sgd = optimizers.SGD(lr = 0.001)\n"," model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])\n"," \n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"JWT0Rcwom811"},"source":["model = mlp_model()\n","history = model.fit(X_train, y_train, validation_split = 0.3, epochs = 100, verbose = 0)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"28sHxARXm813","executionInfo":{"status":"ok","timestamp":1604685863059,"user_tz":420,"elapsed":1230909,"user":{"displayName":"Buomsoo 
Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"7361997b-5c42-4e85-9e4d-43446e488a72","colab":{"base_uri":"https://localhost:8080/","height":268}},"source":["plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.legend(['train acc', 'valid acc', 'train loss', 'valid loss'], loc = 'upper left')\n","plt.show()"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["
"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"R6YP4RHVm816"},"source":["Dropout does not improve the validation results here, because the model had not yet shown signs of overfitting.\n","
Hence, the key takeaway is to apply dropout when you see signs of overfitting."]},{"cell_type":"code","metadata":{"id":"aAPAlxKEm816","executionInfo":{"status":"ok","timestamp":1604685863413,"user_tz":420,"elapsed":1231257,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"f85e06a9-e19c-40f6-9afd-4883e4523561","colab":{"base_uri":"https://localhost:8080/"}},"source":["results = model.evaluate(X_test, y_test)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["313/313 [==============================] - 0s 1ms/step - loss: 1.6782 - accuracy: 0.4227\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"47lVW1hXm818","executionInfo":{"status":"ok","timestamp":1604685863414,"user_tz":420,"elapsed":1231250,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"2c690eb9-22c7-4e51-c3d3-5aa2b061b505","colab":{"base_uri":"https://localhost:8080/"}},"source":["print('Test accuracy: ', results[1])"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Test accuracy: 0.4226999878883362\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"YtjeI1Erm81-"},"source":["## 6. 
Model Ensemble\n","- Model ensemble is a reliable and promising way to boost performance of the model\n","- Usually create 8 to 10 independent networks and merge their results\n","- Here, we resort to scikit-learn API, **VotingClassifier**\n","- Doc: http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html"]},{"cell_type":"markdown","metadata":{"id":"TSAYvxCmm81-"},"source":[""]},{"cell_type":"code","metadata":{"id":"Pb8HFb6Tm81_"},"source":["import numpy as np\n","\n","from tensorflow.keras.wrappers.scikit_learn import KerasClassifier\n","from sklearn.ensemble import VotingClassifier\n","from sklearn.metrics import accuracy_score"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kKsR6al3m82B"},"source":["y_train = np.argmax(y_train, axis = 1)\n","y_test = np.argmax(y_test, axis = 1)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"nRxuL1cwm82D"},"source":["def mlp_model():\n"," model = Sequential()\n"," \n"," model.add(Dense(50, input_shape = (784, )))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(50))\n"," model.add(Activation('sigmoid')) \n"," model.add(Dense(10))\n"," model.add(Activation('softmax'))\n"," \n"," sgd = optimizers.SGD(lr = 0.001)\n"," model.compile(optimizer = sgd, loss = 'categorical_crossentropy', metrics = ['accuracy'])\n"," \n"," return model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NsC1_PiCm82E"},"source":["model1 = KerasClassifier(build_fn = mlp_model, epochs = 100, verbose = 0)\n","model2 = KerasClassifier(build_fn = mlp_model, epochs = 100, verbose = 0)\n","model3 = KerasClassifier(build_fn = mlp_model, epochs = 100, verbose = 0)\n","model1._estimator_type = \"classifier\"\n","model2._estimator_type = \"classifier\"\n","model3._estimator_type = 
\"classifier\""],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"2pBnptVLm82H"},"source":["ensemble_clf = VotingClassifier(estimators = [('model1', model1), ('model2', model2), ('model3', model3)]\n"," , voting = 'soft')"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vOx2UEAjm82I","executionInfo":{"status":"ok","timestamp":1604686456466,"user_tz":420,"elapsed":1824279,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"8ff8b858-c620-4bed-fe25-806aedebd854","colab":{"base_uri":"https://localhost:8080/"}},"source":["ensemble_clf.fit(X_train, y_train)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["VotingClassifier(estimators=[('model1',\n"," ),\n"," ('model2',\n"," ),\n"," ('model3',\n"," )],\n"," flatten_transform=True, n_jobs=None, voting='soft',\n"," weights=None)"]},"metadata":{"tags":[]},"execution_count":131}]},{"cell_type":"code","metadata":{"id":"uikYSUBmm82L","executionInfo":{"status":"ok","timestamp":1604686457386,"user_tz":420,"elapsed":1825192,"user":{"displayName":"Buomsoo Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"44fa4445-330f-47aa-ff35-3190b964558a","colab":{"base_uri":"https://localhost:8080/"}},"source":["y_pred = ensemble_clf.predict(X_test)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/wrappers/scikit_learn.py:264: Sequential.predict_proba (from tensorflow.python.keras.engine.sequential) is deprecated and will be removed after 2021-01-01.\n","Instructions for updating:\n","Please use `model.predict()` instead.\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"1_5lAEucm82O","executionInfo":{"status":"ok","timestamp":1604686457388,"user_tz":420,"elapsed":1825189,"user":{"displayName":"Buomsoo 
Kim","photoUrl":"","userId":"18268696804115368229"}},"outputId":"cca4d7a3-600e-47e7-d55f-d9620245d139","colab":{"base_uri":"https://localhost:8080/"}},"source":["print('Test accuracy:', accuracy_score(y_pred, y_test))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Test accuracy: 0.9002\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"UeDuUWLPm82Q"},"source":["## Summary\n","\n","The table below summarizes the evaluation results so far. It turns out that every method improves test accuracy over the baseline on the MNIST dataset. Why don't we try them out all together?\n","\n","|Model | Baseline | Weight initialization | Activation function | Optimizer | Batch normalization | Regularization | Ensemble |\n","|----------------|-------------|------------|-------------|-------------|------------|-----------|------------|\n","|Test Accuracy | 0.1134 | 0.8625 | 0.9488 | 0.9465 | 0.9481 | 0.4227 | 0.9002 |\n","\n","
\n"]},{"cell_type":"code","metadata":{"id":"MuUkwGu4T0si"},"source":[""],"execution_count":null,"outputs":[]}]}