{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercise 8.1 - Solution\n", "### Classification of magnetic phases using CNNs\n", "\n", "Imagine a 2-dimensional lattice arrangement of $n \\times n$ magnetic dipole moments (spins) that can be in one of two states ($+1$ or $-1$, Ising model).\n", "Since the interactions between spins are short-ranged, each spin interacts only with its four nearest neighbors.\n", "The probability to find a spin in one of the orientations is a function of temperature $T$ according to $p \\sim e^{-a/T}$ with $a = \\mathrm{const.}$\n", "\n", "At extremely low temperatures $T \\rightarrow 0$, neighboring spins are very unlikely to have different orientations, so the system adopts a uniform overall state (ferromagnetic state) in which all spins are either $+1$ or $-1$.\n", "At very high temperatures $T \\rightarrow \\infty$, a paramagnetic phase with random spin alignment results, yielding $50\\%$ of $+1$ and $50\\%$ of $-1$ orientations.\n", "Below a critical temperature $0 < T < T_c$, stable ferromagnetic domains emerge, with both orientations being equally probable in the absence of an external magnetic field.\n", "The correlation length of the spin-spin correlation function diverges at $T_c$, whereas the correlations decay exponentially with distance for $T > T_c$.\n", "\n", "The data for this task contain the $n \\times n$ dipole orientations on the lattice for different temperatures $T$.\n", "Classify the two magnetic phases (paramagnetic/ferromagnetic) using a convolutional neural network!" 
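] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The phase behavior described above can be illustrated with a minimal sketch of the Metropolis algorithm, a standard Monte Carlo method for sampling Ising configurations.\n", "It assumes a nearest-neighbor coupling $J = 1$, $k_B = 1$, periodic boundaries and an even lattice size $n$; it is not necessarily how the provided dataset was generated, and the helper `metropolis_sweep` is hypothetical, written only for this illustration.\n", "A spin flip that changes the energy by $\\Delta E$ is accepted with probability $\\min(1, e^{-\\Delta E/T})$. The absolute magnetization per spin $|m|$ acts as an order parameter: it should stay close to $1$ for an ordered lattice at low $T$ and close to $0$ for a random lattice at high $T$." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"\n",
"def metropolis_sweep(spins, T, rng):\n",
"    # One Metropolis sweep over an n x n Ising lattice using checkerboard updates\n",
"    # (assumes J = k_B = 1, periodic boundaries and even n).\n",
"    ii, jj = np.indices(spins.shape)\n",
"    for parity in (0, 1):\n",
"        # sites of the same checkerboard color are never neighbors,\n",
"        # so all of them can be updated simultaneously\n",
"        mask = (ii + jj) % 2 == parity\n",
"        nb = (np.roll(spins, 1, axis=0) + np.roll(spins, -1, axis=0)\n",
"              + np.roll(spins, 1, axis=1) + np.roll(spins, -1, axis=1))\n",
"        dE = 2 * spins * nb  # energy change if the spin were flipped\n",
"        accept = rng.random(spins.shape) < np.exp(-dE / T)\n",
"        spins[mask & accept] *= -1\n",
"    return spins\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n = 32\n",
"cold = np.ones((n, n), dtype=int)       # ordered start, T well below Tc\n",
"hot = rng.choice([-1, 1], size=(n, n))  # random start, T well above Tc\n",
"for _ in range(200):\n",
"    metropolis_sweep(cold, T=1.0, rng=rng)\n",
"    metropolis_sweep(hot, T=5.0, rng=rng)\n",
"\n",
"print(\"|m| at T=1.0:\", abs(cold.mean()))\n",
"print(\"|m| at T=5.0:\", abs(hot.mean()))" 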
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "keras 2.4.0\n" ] } ], "source": [ "from tensorflow import keras\n", "import numpy as np\n", "callbacks = keras.callbacks\n", "layers = keras.layers\n", "\n", "print(\"keras\", keras.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load and prepare dataset\n", "See https://doi.org/10.1038/nphys4035 for more information" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import gdown\n", "\n", "# download the Ising dataset (spin configurations \"C\" and temperatures \"T\")\n", "url = \"https://drive.google.com/u/0/uc?export=download&confirm=HgGH&id=1Ihxt1hb3Kyv0IrjHlsYb9x9QY7l7n2Sl\"\n", "output = 'ising_data.npz'\n", "gdown.download(url, output, quiet=True)\n", "\n", "f = np.load(output)\n", "# the first 20000 samples are used for training, the remaining ones for testing\n", "n_train = 20000\n", "\n", "x_train, x_test = f[\"C\"][:n_train], f[\"C\"][n_train:]\n", "T_train, T_test = f[\"T\"][:n_train], f[\"T\"][n_train:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "for i,j in enumerate(np.random.choice(n_train, 6)):\n", " plt.subplot(2,3,i+1)\n", " image = x_train[j]\n", " plot = plt.imshow(image)\n", " plt.title(\"T: %.2f\" % T_train[j])\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'frequency')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEGCAYAAACKB4k+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAT+UlEQVR4nO3df7Ad5X3f8ffHCLCDMeKHoihCtmitxEM7BuM7VLY7HtfEiSEuYlxMYZwgM2TUprTFk860StomTZo/7DYxNWmKRzVOhevaptguKoEkjCCTaWfAvmCMsbHLDTWVFIEuv4RtGrtivv3jPFqOL1fSkXT3HOne92vmzHn22efsfncW8bn74+xJVSFJEsCrJl2AJOnYYShIkjqGgiSpYyhIkjqGgiSps2zSBRyNs846q9auXTvpMiTpuPLAAw88XVUr5pt3XIfC2rVrmZ6ennQZknRcSfLEgeZ5+kiS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1Dmuv9Es6ZXWbv7Diaz3Ox/5+YmsVwvLIwVJUscjBUnHPY+OFk5vRwpJfjrJQ0OvF5J8OMkZSe5O8lh7P72NT5Ibk8wkeTjJBX3VJkmaX2+hUFXfrqrzq+p84K3Ai8CXgM3A9qpaB2xv0wAXA+vaaxNwU1+1SZLmN65rChcBf15VTwAbgK2tfytwWWtvAG6pgfuA5UlWjak+SRLjC4Urgc+29sqq2t3aTwIrW3s1sGPoMztb349IsinJdJLp2dnZvuqVpCWp91BIchJwKfBf586rqgLqcJZXVVuqaqqqplasmPeHgyRJR2gcdx9dDDxYVU+16aeSrKqq3e300J7WvwtYM/S5s1ufJB2TJnXXE/R359M4Th9dxcunjgC2ARtbeyNw+1D/1e0upPXA3qHTTJKkMej1SCHJKcB7gL831P0R4NYk1wJPAFe0/juBS4AZBncqXdNnbUvRYvyr5lC8f106PL2GQlV9HzhzTt8zDO5Gmju2gOv6rEeSdHA+5kKS1Fmyj7lYiqdSJOlQPFKQJHUMBUlSx1CQJHUMBUlSx1CQJHUMBUlSx1CQJHUMBUlSx1CQJHUMBUlSx1CQJHUMBUlSx1CQJHUMBUlSx1CQJHUMBUlSx1CQJHV6DYUky5PcluRbSR5N8rYkZyS5O8lj7f30NjZJbkwyk+ThJBf0WZsk6ZX6PlL4OPBHVfUm4DzgUWAzsL2q1gHb2zTAxcC69toE3NRzbZKkOXoLhSSnAe8Ebgaoqh9W1fPABmBrG7YVuKy1NwC31MB9wPIkq/qqT5L0Sn0eKZwDzAJ/kOSrST6Z5BRgZVXtbmOeBFa29mpgx9Dnd7a+H5FkU5LpJNOzs7M9li9JS0+fobAMuAC4qareAn
yfl08VAVBVBdThLLSqtlTVVFVNrVixYsGKlST1Gwo7gZ1VdX+bvo1BSDy1/7RQe9/T5u8C1gx9/uzWJ0kak95CoaqeBHYk+enWdRHwTWAbsLH1bQRub+1twNXtLqT1wN6h00ySpDFY1vPy/xHwmSQnAY8D1zAIoluTXAs8AVzRxt4JXALMAC+2sZKkMeo1FKrqIWBqnlkXzTO2gOv6rEeSdHB+o1mS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1DEUJEmdXkMhyXeSfD3JQ0mmW98ZSe5O8lh7P731J8mNSWaSPJzkgj5rkyS90jiOFP5WVZ1fVVNtejOwvarWAdvbNMDFwLr22gTcNIbaJElDJnH6aAOwtbW3ApcN9d9SA/cBy5OsmkB9krRk9R0KBfxJkgeSbGp9K6tqd2s/Caxs7dXAjqHP7mx9PyLJpiTTSaZnZ2f7qluSlqRlPS//b1bVriQ/Dtyd5FvDM6uqktThLLCqtgBbAKampg7rs5Kkg+v1SKGqdrX3PcCXgAuBp/afFmrve9rwXcCaoY+f3fokSWPSWygkOSXJqfvbwM8CjwDbgI1t2Ebg9tbeBlzd7kJaD+wdOs0kSRqDPk8frQS+lGT/ev5LVf1Rkq8Atya5FngCuKKNvxO4BJgBXgSu6bE2SdI8eguFqnocOG+e/meAi+bpL+C6vuqRJB2a32iWJHUMBUlSp+9bUiUA1m7+w0mXIGkEHilIkjqGgiSpYyhIkjpeU5C0ILxutDgc8kihPczuuv2/eyBJWrxGOX30d4GfBL6S5HNJfi7ta8qSpMXlkKePqmoG+OdJ/iXwPuBTwEtJ/gD4eFU923ONi46H2Yuf+1jHq5EuNCd5M/C7wL8FvgB8AHgBuKe/0iRJ43bII4UkDwDPAzcDm6vqB23W/Une0WNtkqQxG+Xuow+0h9u9QlW9f4HrkSRN0Cinj34pyfL9E0lOT/Lb/ZUkSZqUUULh4qp6fv9EVT3H4HcPJEmLzCihcEKSk/dPJHkNcPJBxkuSjlOjXFP4DLC93YIKg19E29pfSZKkSRnlewofTfIwL/9a2r+uqj/utyxJ0iSM9OyjqroLuKvnWiRJEzbKs4/en+SxJHuTvJDku0leGEdxkqTxGuVC878BLq2q06rqdVV1alW9btQVJDkhyVeT3NGmz0lyf5KZJJ9PclLrP7lNz7T5a49oiyRJR2yUUHiqqh49inVcDwx//qPADVX1RuA54NrWfy3wXOu/oY2TJI3RKKEw3f6Cv6qdSnp/kpG+yZzkbODngU+26QDvBm5rQ7YCl7X2Bl6+q+k24CKfxipJ4zXKhebXAS8CPzvUV8AXR/jsvwP+KXBqmz4TeL6q9rXpncDq1l4N7ACoqn1J9rbxTw8vMMkmYBPA61//+hFKkCSNapRbUq85kgUneR+wp6oeSPKuI1nGAerZAmwBmJqaqoVariRptLuPfirJ9iSPtOk3J/kXIyz7HcClSb4DfI7BaaOPA8uT7A+js4Fdrb0LWNPWsQw4DXjmMLZFknSURrmm8B+BXwX+H0BVPQxceagPVdWvVtXZVbW2jb+nqj4I3Atc3oZtBG5v7W1tmjb/nqrySECSxmiUUPixqvrynL59844czT8DfiXJDINrBje3/puBM1v/rwCbj2IdkqQjMMqF5qeT/FUGF5dJcjmw+3BWUlV/Cvxpaz8OXDjPmL9k8ItukqQJGSUUrmNwYfdNSXYB/xv4hV6rkiRNxCh3Hz0O/EySU4BXVdV3+y9LkjQJo/xG86/PmQagqn6rp5okSRMyyumj7w+1Xw28jx99bIUkaZEY5fTR7w5PJ/kdwN9TkKRFaJRbUuf6MQZfOpMkLTKjXFP4Ou12VOAEYAXg9QRJWoRGuabwvqH2PgaP0j6aL69Jko5Ro4TC3FtQXzf8ROuqen
ZBK5IkTcwoofAggwfVPQcEWA78nzavgL/SS2WSpLEb5ULz3cDfrqqzqupMBqeT/qSqzqkqA0GSFpFRQmF9Vd25f6Kq7gLe3l9JkqRJGeX00V+030/4z236g8Bf9FeSJGlSRjlSuIrBbahfYvATnCtanyRpkRnlG83PAtcnOaWqvn+o8ZKk49coP8f59iTfpD3vKMl5Sf5D75VJksZulNNHNwA/R/u95Kr6GvDOPouSJE3GSM8+qqodc7pe6qEWSdKEjXL30Y4kbwcqyYnA9fjobElalEY5Uvj7DH6SczWwCzi/TR9Uklcn+XKSryX5RpLfbP3nJLk/yUySzyc5qfWf3KZn2vy1R7pRkqQjc9BQSHIC8PGq+mBVrayqH6+qX6iqZ0ZY9g+Ad1fVeQyC5L1J1gMfBW6oqjcyeHTGtW38tcBzrf+GNk6SNEYHDYWqegl4w/6/5g9HDXyvTZ7YXgW8G7it9W8FLmvtDW2aNv+iDD95T5LUu1GuKTwO/M8k2xj6ac6q+tihPtiONB4A3gj8PvDnwPNDj97eyeC0FO19R1v2viR7gTOBp0fbFEnS0TrgkUKST7fmpcAdbeypQ69DqqqXqup8Br/UdiHwpqMpttW1Kcl0kunZ2dmjXZwkacjBjhTemuQnGTwm+/eOZiVV9XySe4G3AcuTLGtHC2czuHhNe18D7EyyDDiN9t2IOcvaAmwBmJqaqrnzJUlH7mDXFD4BbAd+Cpgeej3Q3g8qyYoky1v7NcB7GNzKei9weRu2Ebi9tbe1adr8e6rK/+lL0hgd8Eihqm4EbkxyU1X98hEsexWwtV1XeBVwa1Xd0R6Z8bkkvw18Fbi5jb8Z+HSSGeBZ4MojWKck6SiM8kC8IwkEquph4C3z9D/O4PrC3P6/BD5wJOuSJC2MkR5zIUlaGgwFSVLHUJAkdQwFSVLHUJAkdQwFSVLHUJAkdQwFSVLHUJAkdQwFSVLHUJAkdQwFSVLHUJAkdQwFSVLHUJAkdQwFSVLHUJAkdQwFSVLHUJAkdQwFSVKnt1BIsibJvUm+meQbSa5v/WckuTvJY+399NafJDcmmUnycJIL+qpNkjS/Po8U9gH/pKrOBdYD1yU5F9gMbK+qdcD2Ng1wMbCuvTYBN/VYmyRpHr2FQlXtrqoHW/u7wKPAamADsLUN2wpc1tobgFtq4D5geZJVfdUnSXqlsVxTSLIWeAtwP7Cyqna3WU8CK1t7NbBj6GM7W9/cZW1KMp1kenZ2tr+iJWkJ6j0UkrwW+ALw4ap6YXheVRVQh7O8qtpSVVNVNbVixYoFrFSS1GsoJDmRQSB8pqq+2Lqf2n9aqL3vaf27gDVDHz+79UmSxqTPu48C3Aw8WlUfG5q1DdjY2huB24f6r253Ia0H9g6dZpIkjcGyHpf9DuAXga8neaj1/RrwEeDWJNcCTwBXtHl3ApcAM8CLwDU91iZJmkdvoVBV/wPIAWZfNM/4Aq7rqx5J0qH5jWZJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1eguFJJ9KsifJI0N9ZyS5O8lj7f301p8kNyaZSfJwkgv6qkuSdGB9Hin8J+C9c/o2A9urah2wvU0DXAysa69NwE091iVJOoDeQqGq/gx4dk73BmBra28FLhvqv6UG7gOWJ1nVV22SpPmN+5rCyqra3dpPAitbezWwY2jcztb3Ckk2JZlOMj07O9tfpZK0BE3sQnNVFVBH8LktVTVVVVMrVqzooTJJWrrGHQpP7T8t1N73tP5dwJqhcWe3PknSGI07FLYBG1t7I3D7UP/V7S6k9cDeodNMkqQxWdbXgpN8FngXcFaSncBvAB8Bbk1yLfAEcEUbfidwCTADvAhc01ddkqQD6y0UquqqA8y6aJ6xBVzXVy2SpNH4jWZJUsdQkCR1DAVJUsdQkC
R1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1DAVJUsdQkCR1jqlQSPLeJN9OMpNk86TrkaSl5pgJhSQnAL8PXAycC1yV5NzJViVJS8sxEwrAhcBMVT1eVT8EPgdsmHBNkrSkLJt0AUNWAzuGpncCf2PuoCSbgE1t8ntJvn2E6zsLePoIP3u8cpuXBrd5CchHj2qb33CgGcdSKIykqrYAW452OUmmq2pqAUo6brjNS4PbvDT0tc3H0umjXcCaoemzW58kaUyOpVD4CrAuyTlJTgKuBLZNuCZJWlKOmdNHVbUvyT8E/hg4AfhUVX2jx1Ue9Smo45DbvDS4zUtDL9ucqupjuZKk49CxdPpIkjRhhoIkqbOoQyHJp5LsSfLIAeYnyY3tsRoPJ7lg3DUutBG2+V1J9iZ5qL1+fdw1LrQka5Lcm+SbSb6R5Pp5xiyqfT3iNi+qfZ3k1Um+nORrbZt/c54xJyf5fNvP9ydZO4FSF8SI2/uhJLND+/iXjnrFVbVoX8A7gQuARw4w/xLgLiDAeuD+Sdc8hm1+F3DHpOtc4G1eBVzQ2qcC/ws4dzHv6xG3eVHt67bvXtvaJwL3A+vnjPkHwCda+0rg85Ouu+ft/RDw7xdyvYv6SKGq/gx49iBDNgC31MB9wPIkq8ZTXT9G2OZFp6p2V9WDrf1d4FEG35Aftqj29YjbvKi0ffe9Nnlie829U2YDsLW1bwMuSpIxlbigRtzeBbeoQ2EE8z1aY1H/w2re1g5J70ry1yZdzEJqpwvewuCvqmGLdl8fZJthke3rJCckeQjYA9xdVQfcz1W1D9gLnDnWIhfQCNsL8HfaKdHbkqyZZ/5hWeqhsBQ9CLyhqs4Dfg/4b5MtZ+EkeS3wBeDDVfXCpOsZh0Ns86Lb11X1UlWdz+CJBxcm+esTLqlXI2zvfwfWVtWbgbt5+SjpiC31UFhyj9aoqhf2H5JW1Z3AiUnOmnBZRy3JiQz+5/iZqvriPEMW3b4+1DYv1n0NUFXPA/cC750zq9vPSZYBpwHPjLW4Hhxoe6vqmar6QZv8JPDWo13XUg+FbcDV7c6U9cDeqto96aL6lOQn9p9jTXIhg/8Gjut/NG17bgYeraqPHWDYotrXo2zzYtvXSVYkWd7arwHeA3xrzrBtwMbWvhy4p9oV2ePNKNs757rYpQyuLR2VY+YxF31I8lkGd2CclWQn8BsMLtZQVZ8A7mRwV8oM8CJwzWQqXTgjbPPlwC8n2Qf8X+DK4/UfzZB3AL8IfL2dfwX4NeD1sGj39SjbvNj29SpgawY/yPUq4NaquiPJbwHTVbWNQVB+OskMgxsurpxcuUdtlO39x0kuBfYx2N4PHe1KfcyFJKmz1E8fSZKGGAqSpI6hIEnqGAqSpI6hIEnqLOpbUqVxSnImsL1N/gTwEjDbpi+sqh9OpDDpMHhLqtSDJP8K+F5V/c6ka5EOh6ePJEkdQ0GS1DEUJEkdQ0GS1DEUJEkdQ0GS1PGWVElSxyMFSVLHUJAkdQwFSVLHUJAkdQwFSVLHUJAkdQwFSVLn/wM5oOPmFcC43wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.hist(T_test)\n", "plt.xlabel(\"T\")\n", "plt.ylabel(\"frequency\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Set up training data - define magnetic phases" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# critical temperature of the 2D Ising model (exact value 2/ln(1+sqrt(2)) = 2.269..., units J = k_B = 1)\n", "Tc = 2.27\n", "# labels: True (1) = paramagnetic phase (T > Tc), False (0) = ferromagnetic phase\n", "y_train, y_test = T_train > Tc, T_test > Tc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Task\n", "\n", "- evaluate the test accuracy for a convolutional network,\n", "- plot the test accuracy vs. temperature,\n", "- compare to the results obtained using a fully-connected network (Exercise 7.1)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", "Layer (type) Output Shape Param # \n", "=================================================================\n", "reshape (Reshape) (None, 32, 32, 1) 0 \n", "_________________________________________________________________\n", "conv2d (Conv2D) (None, 32, 32, 16) 160 \n", "_________________________________________________________________\n", "conv2d_1 (Conv2D) (None, 32, 32, 16) 2320 \n", "_________________________________________________________________\n", "max_pooling2d (MaxPooling2D) (None, 16, 16, 16) 0 \n", "_________________________________________________________________\n", "conv2d_2 (Conv2D) (None, 16, 16, 32) 4640 \n", "_________________________________________________________________\n", "conv2d_3 (Conv2D) (None, 16, 16, 32) 9248 \n", "_________________________________________________________________\n", "global_average_pooling2d (Gl (None, 32) 0 \n", "_________________________________________________________________\n", "dropout (Dropout) (None, 32) 0 \n", 
"_________________________________________________________________\n", "dense (Dense) (None, 1) 33 \n", "=================================================================\n", "Total params: 16,401\n", "Trainable params: 16,401\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model = keras.models.Sequential()\n", "# input: one 32x32 spin configuration; add a channel axis for the convolutions\n", "model.add(layers.InputLayer(input_shape=(32, 32)))\n", "model.add(layers.Reshape((32, 32, 1)))\n", "# two convolutional blocks ('same' padding keeps the lattice size), separated by 2x2 max pooling\n", "model.add(layers.Convolution2D(16, (3, 3), padding='same', activation='relu'))\n", "model.add(layers.Convolution2D(16, (3, 3), padding='same', activation='relu'))\n", "model.add(layers.MaxPooling2D((2, 2)))\n", "model.add(layers.Convolution2D(32, (3, 3), padding='same', activation='relu'))\n", "model.add(layers.Convolution2D(32, (3, 3), padding='same', activation='relu'))\n", "# global average pooling and dropout, then a single sigmoid unit for the binary phase label\n", "model.add(layers.GlobalAveragePooling2D())\n", "model.add(layers.Dropout(0.25))\n", "model.add(layers.Dense(1, activation='sigmoid'))\n", "\n", "model.summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare model for training" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "model.compile(\n", "    loss='binary_crossentropy',\n", "    optimizer=keras.optimizers.Adam(learning_rate=0.001),\n", "    metrics=['accuracy'])\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/50\n", "282/282 - 10s - loss: 0.0859 - accuracy: 0.9613 - val_loss: 0.0413 - val_accuracy: 0.9820\n", "Epoch 2/50\n", "282/282 - 14s - loss: 0.0448 - accuracy: 0.9805 - val_loss: 0.0366 - val_accuracy: 0.9835\n", "Epoch 3/50\n", "282/282 - 14s - loss: 0.0443 - accuracy: 0.9821 - val_loss: 0.0365 - val_accuracy: 0.9835\n", "Epoch 4/50\n", "282/282 - 14s - loss: 0.0458 - accuracy: 0.9795 - val_loss: 0.0373 - val_accuracy: 0.9850\n", "Epoch 5/50\n", "282/282 - 13s - loss: 0.0442 - accuracy: 0.9808 - val_loss: 
0.0367 - val_accuracy: 0.9850\n", "\n", "Epoch 00005: ReduceLROnPlateau reducing learning rate to 0.0006700000318232924.\n", "Epoch 6/50\n", "282/282 - 13s - loss: 0.0413 - accuracy: 0.9822 - val_loss: 0.0366 - val_accuracy: 0.9855\n", "Epoch 7/50\n", "282/282 - 14s - loss: 0.0404 - accuracy: 0.9816 - val_loss: 0.0365 - val_accuracy: 0.9855\n", "\n", "Epoch 00007: ReduceLROnPlateau reducing learning rate to 0.0004489000252215192.\n", "Epoch 8/50\n", "282/282 - 13s - loss: 0.0396 - accuracy: 0.9827 - val_loss: 0.0376 - val_accuracy: 0.9810\n", "Epoch 00008: early stopping\n" ] } ], "source": [ "results = model.fit(x_train, y_train,\n", " batch_size=64,\n", " epochs=50,\n", " verbose=2,\n", " validation_split=0.1,\n", " callbacks=[\n", " callbacks.EarlyStopping(patience=5, verbose=1),\n", " callbacks.ReduceLROnPlateau(factor=0.67, patience=2, verbose=1)]\n", " )" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(6, 4))\n", "plt.plot(results.history['loss'])\n", "plt.plot(results.history['val_loss'])\n", "plt.ylabel('loss')\n", "plt.xlabel('epoch')\n", "plt.legend(['train', 'val'], loc='upper right')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(6, 4))\n", "plt.plot(results.history['accuracy'])\n", "plt.plot(results.history['val_accuracy'])\n", "plt.ylabel('accuracy')\n", "plt.xlabel('epoch')\n", "plt.legend(['train', 'val'], loc='lower right')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Evaluate test accuracy vs. temperature" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import seaborn as sns\n", "\n", "# predicted phase labels (sigmoid output rounded to 0 or 1)\n", "preds = model.predict(x_test).round().squeeze()\n", "# per-sample correctness (1.0 = correct, 0.0 = wrong)\n", "acc = (preds == y_test).astype(float)\n", "\n", "# mean accuracy for each temperature in the test set\n", "ax = sns.regplot(x=T_test, y=acc, x_estimator=np.mean, fit_reg=False)\n", "ax.set_ylabel(\"accuracy\")\n", "ax.set_xlabel(\"T\")\n", "plt.axvline(x=Tc, color='k', linestyle='--', label='Tc')\n", "plt.legend()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }