{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Exercise 10.1\n", "## Signal Classification using Dynamic Graph Convolutional Neural Networks\n", "After a long journey through the universe, and before reaching Earth, cosmic particles interact with the galactic magnetic field $\\vec{B}$.\n", "As these particles carry a charge $q$ and move with velocity $\\vec{v}$, they are deflected in the field by the Lorentz force $\\vec{F} = q \\, \\vec{v} \\times \\vec{B}$.\n", "Sources of cosmic particles are located all over the sky, so, in general, their arrival directions are distributed isotropically. However, particles originating from the same source form street-like patterns on top of the isotropic arrival directions, caused by deflections in the galactic magnetic field.\n", "\n", "In this task, we want to classify whether a simulated set of $500$ arriving cosmic particles contains street-like patterns (signal) or originates from an isotropic background.\n", "\n", "Training graph networks can be computationally demanding, so we recommend using a GPU for this task." 
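] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two physical ingredients described above can be illustrated with a short NumPy sketch. The helpers `isotropic_directions` and `lorentz_force` are hypothetical names introduced here for illustration only and are not used by the rest of this notebook. The first draws isotropic arrival directions by normalizing 3D Gaussian samples, which yields unit vectors uniformly distributed on the sphere. The second evaluates $\\vec{F} = q \\, \\vec{v} \\times \\vec{B}$ with `np.cross`. This is a toy example under these assumptions and does not reproduce the simulation used to generate the data set." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Illustrative sketch only: the helpers below are hypothetical and not part of the exercise pipeline.\n", "rng = np.random.default_rng(seed=42)\n", "\n", "def isotropic_directions(n, rng):\n", "    # A 3D standard normal is rotation-invariant, so normalizing the samples\n", "    # gives unit vectors uniformly distributed on the sphere (isotropic arrival directions).\n", "    v = rng.normal(size=(n, 3))\n", "    return v / np.linalg.norm(v, axis=1, keepdims=True)\n", "\n", "def lorentz_force(q, v, B):\n", "    # Lorentz force F = q * (v x B); the force is perpendicular to both v and B.\n", "    return q * np.cross(v, B)\n", "\n", "# Isotropic background: 500 arrival directions, matching the set size used in this exercise.\n", "directions = isotropic_directions(500, rng)\n", "print(directions.shape)  # (500, 3)\n", "print(np.allclose(np.linalg.norm(directions, axis=1), 1.0))  # True: all unit vectors\n", "\n", "# Toy case: a unit positive charge moving along x in a field along z is deflected along -y.\n", "F = lorentz_force(1.0, np.array([1.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]))\n", "print(F)" 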
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "keras 2.4.0\n" ] } ], "source": [ "from tensorflow import keras\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "\n", "layers = keras.layers\n", "\n", "print(\"keras\", keras.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download EdgeConv Layer" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import gdown\n", "import os\n", "\n", "url = \"https://raw.githubusercontent.com/DeepLearningForPhysicsResearchBook/deep-learning-physics/main/edgeconv.py\"\n", "output = 'edgeconv.py'\n", "\n", "# download the layer definition only if it is not present yet\n", "if not os.path.exists(output):\n", "    gdown.download(url, output, quiet=False)\n", "\n", "from edgeconv import EdgeConv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download Data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "url = \"https://drive.google.com/u/0/uc?export=download&confirm=HgGH&id=1XKN-Ik7BDyMWdQ230zWS2bNxXL3_9jZq\"\n", "output = 'cr_sphere.npz'\n", "\n", "# download the data set only if it is not present yet\n", "if not os.path.exists(output):\n", "    gdown.download(url, output, quiet=True)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "f = np.load(output)\n", "n_test = 10000  # the last 10000 sets are held out as the test set\n", "x_train, x_test = f['data'][:-n_test], f['data'][-n_test:]\n", "labels = keras.utils.to_categorical(f['label'], num_classes=2)  # one-hot encode the binary labels\n", "y_train, y_test = labels[:-n_test], labels[-n_test:]" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_train.shape (40000, 500, 4)\n", "y_train.shape (40000, 2)\n" ] } ], "source": [ "print(\"x_train.shape\", x_train.shape)\n", "print(\"y_train.shape\", y_train.shape)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# define 
coordinates for very first EdgeConv\n", "train_points, test_points = x_train[..., :3], x_test[..., :3]\n", "\n", "# Use normalized Energy as features for convolutional layers\n", "train_features, test_features = x_train[..., -1, np.newaxis], x_test[..., -1, np.newaxis]\n", "train_features = np.concatenate([train_features, train_points], axis=-1)\n", "\n", "train_input_data = [train_points, train_features]\n", "test_input_data = [test_points, test_features]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot an example sky map" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": 7OdKmdeUiJ54MTVR4HN9OWs3pI9dxdrPlzUldaNCqkkHOrdPpEUKgUhU9gRIZlcCg1xaSmPQwJKZsGRd+/um1IiW4dDo94YFR2DpaFcjf8/5/TzBtyLxs9hXX52DvXLA1mhUU8sLoOX9z5NKdLLZO9f2ZNqx4tEG+evIm77X5HG3aw+z/Cb++Sas+jbh0/DrvtZ2WZbxvFW8+WzGWD7t8Q+jdrCVT35o5gJdfb5ntGjqdjiFDhsRs27ZtdVhY2IjiKFaVrf8nIIQo6+7uvnvbtm0e1aoZtn6lSqXC3DI9OaZF70b5HpNqTBYs2s3KVccyjw8fCeDjyd2KlPAwNA7ONny+8LV8OXdRfgi4dDkwi0gFuHEzjKioBBwdC39B9esX7rFnzQm2Lz9ETHgcJmYa3vi8Ny8NaZ6v13Utlb0Os4W1OdY59F1/Xu7GRbMv8BZ2Zma0LVUBc7Vy61B4cTrW988mVDvWLz7NK7Yu3ZdFpAJs+HkXrfo04sbZO9nG37pwDydPB1KSsjehio9OyGYDUKvV/P7773ajR4/u8/fff0shxMjiJlaVb5scEEJ4ubm57dm4caPBRWpJIiYmkVWrj2Wx7d13mYDrDalQ3t1IXhV/IqMSWLJkP4FB0VSq5MnA/o0xNUI4SV5xsM8urExM1Fg9o+JBYWDf+pNMf2sp+uSHtRDTUrTMHb+CynXLvlBt3WdRsXYZXh7Wig0ZFShUahVvzx6ExkD91Y8E3+G1HatI0qWHFtR09mB5h1ex0OSuZmtqciobFu0i8HoIFeuUoe2ApoWq5FdJQkrJ8W1nuXMlkHLVfKjVKufySQVFl0aVSdPq+PfgeVRCRf82tWhWreyzJxYRclqQeWArlUMMuVc5N9QaNY0712Tz0v2ZdrVGRb12T9YiQgjmzZtnExcX13vbtm1RwEcv7n3hofDfvQoYIYSLq6vr/tWrV3vXrl3b2O4UaRISU8npuS4uLrngnTECxigUn5qq5f0PlnEnowXrfydvcedOBJ983KNA/XgeKvl70qZVZXbuftixZeTrLTDLhyL2BzacYt0ve9Dp9HQc0IR2fRs+e9JTWPTZv+hzKPgtpSTg3N1cCdXYiDiklNjlcbteCMGYWQNp1achIXfCqVDLl1IVnt4cIi9MO7E7U6QCnA4PYnXAOQb5P/v7Ua/X83HPbzm160Km7ep/N3hrzlCD+aeQe+Z+8DvrF+7MPH51XBeGTjVeIpkQgp7Nq9OzeXWj+ZCfdBjSnM1L9qJNffj38/KI9KoG1Zv60fPtjvz9wxYAbBys+GB+eqODN77si1AJjm49i42DFUMn96BctacnxwohWLJkiV2PHj1
GOjs7x4SHh0976oQihBKj+ghCCDsXF5ejf/zxR4X27dsrj/wviF4vGTnqV27cCMu0OTlas+S3EVhaFv5VsuflyM6L/PTZv4SHxFK9QVk+nNEXpwJqP3viv5t8NOGvbPY1f7+D3WOB+IURvV5y+GgAQUHRVPL3pEplL4Nf48SuC0zpPzeLbdzcobTuWf+5z9ml7HtoU7Sg1WZ77Zu171O1YYUnzk1L1fLNa/PYveIgAE171GfiH29jav5cVfAMToOVcwlJylrqbGz1xrxfq9kz5148co33Wn2Wzb7q3jxsnZRmFAXJ/YBgXq85PotNCMHygO9wKObtsQsSKSWh99NLUbp6OXDlxA3+/Wk7qclptO3fhMYvZ33AC7weQmRIDGWqeGNlgHCdtLQ0OnbsGH3q1KkpkZGRP77wCQsByopqBhnF/Pf99NNP5RSRahhUKsFXX/Tm2zlbuRaQnkz1ztj22URqcHgsPyzby93gKPzKuPHWq82xsy78oion7t8KY9rbv2fGJZ0+FMD095cz4883C+T6T1rBLSoRwSqVoEmjJ4s6Q7DzsXAUgJ2rjr6QUK3bshJHtp1P7yb2SNvEbiNaUaVB+afOXTVrfaZIBTiw5hjLvlzD0M/6Prc/hqSJpy//XD+fxdbQPXelz5ITU3K05xSDp5C/RIdlbyYjpSQmIk4RqgYiOTGVaW//zol9VwCo28KP//0wiAm/Pvn737OcG57P6GSYF0xMTNi4caN9y5YtP7Ozs4uJiYn5/dmzCjeKUCW9Laqrq+vOGTNm+PXs2VP5TAyIi4stX017cgON5NQ0xkxbyf2Mfs9Xb4dxNyiKn6b0LZIJV+eP38oWPH/u2A3SUrUFUnasWlVvypZ1ybKK3aZ1ZWyLwGpqQWFqlv3fwfQFwwvemzWAOR8u5+S+y9g6WNK2Zz2adalNmVysCF95rHA/kK2YvzH5pH5bEtNS2XE3AGtTUz6s1ZzGHs9uQgFQuUF5XEs5EZoRigJQpXFFnHMouq+Qv5St5oODq22WFqXuvi54KfkCBmPF/F2ZIhXgxN4rrJy/m8HvdShQP8zNzdmxY4dDkyZN5lhbW8fFx8f/W6AOGJgSL8qEEGpXV9dNkydPrjFkyJDiux9dSLkQEJwpUh9w+sp9QiLicC+CpXUcXLJnp9vYW6IxKZi6tKamGr79pj9/LjtEYHA0lfw86d3r+VcKiyNdXm/Jnn9OkJqSvqqn1qjoPqLVC53T1sGKj38Z/lxzPXJYTfEoa7gVlhfF1tSM+a16oJcSwZNX7XPC3Mqc6VsmsmD8svRkqtpleGN6/yL5EFrUsbA254s1HzJ79C/cuRJI2Wo+fDB/uFHqdhdG0rQ6boVH4WBpgbOt1XOd4/rF+7myFQTW1tbs2bPHsVGjRj+bmprGpqam7jKKIwagxP+Gurq6/jR8+PAGb7/9draCh7FJyRy5eRczjYZGZUthqinxH5fBMXuCgDM1UMZyQVOnmR8NWlfi6K5LAAiV4M3JXQv0xmxra8GoN4tuG8L8pny1UszZPI7NfxxEp9XTrl9D/GuXMZo//Se9wqmd5zLL1fhU8mKwERNcnoTqOX+HPcu68emq9wzsjcLzUL5GaeYezB4zXNIJCA5n1C//EhQdhxDwRpsGvNWhcZ7PU6qsKyf2Xsli8y7naig384yDgwN79+51qlev3gohRFMp5VWjOfMClOhkKgcHhzebNm361bp16+wfFxK3IqIYvHgVoXHptcsqe7iydEgvrM2VRVdDotPrGfvVav67eDfT1qVFVf43smC3SgyJTqfn6K6LhAfHULVuGcpWUlpZGoMjYdc5E3UHL0sHOnpWQ6MquG5rNy8FsvPv4wiVoH2fBpQq//QV0tSUNM4fuIzUS6o18y/QRKrAkGhWrfuP+IQUmjUoT/NGFQvs2goKhYE+3/3JxXuhWWy/vtmL+uVK5ek8cTGJTH79F66eTb+fVaxeimm/DcfayKFXZ8+epV27drdDQ0NrSSm
jjOrMc1BihapGo2nm5+f37/Hjxx0tLbNn2r29Yj3bL2eNE3u7ZSPGtHyxMjYK2UlOTWP1ttPcDY7Cv4wbXVtVQ63UWVR4AX4N2Md3l7dnHrdw82NO3f6oRP7/Xl0+dYvxvX8gNaOlqbmlKbPWvEvZfKhg8KKERcQx7N0lRMcmZdo+HN2erh1qZBur10uW/nOEDTvOoVIJ+rxch16di18Jv8S0VOafPk5AVASVnFwYUaMe5spuWrGm5vjv0Or1WWzju7ZgULO8/37rtDqunb+PEFC+qnehacLy77//at94442ToaGhTaSU2cuTFGJK5F+fEMLX09Nz9bZt23IUqQCBMdkzJINi4vLbtRKJuakJA1+u9+yBCgq5IFmXxryrWcOx9oZc4VTkHeo4+eb79f9ZuDtTpEJ6JvC/P+/h/W8H5Pu188rW3ReziFSAv/49nqNQ/XvzKX5e/rA6wZxfdmFva0nbpv55uqZeSnbfuEFgbBy1PT2p4ma8rdHH0UvJ65vXcCQwfUVs042rnAoJ4pdOPYiNjOeXKSu5cf4upSp6MPzzPjh5KElhRZlbl+4zY/gChJ85OGVd9Szr6vRc51Rr1PjXzF1VjIKke/fumrNnz1aZO3fuQuB1Y/uTF0qcUBVC2Li6uu5Ys2aNq5fXk1c4apXy5EJQ1q2AWj7KFm5eSIhN4sCWs6QkpdKgTRXcvLO3eiypSClZPX8Xa3/Zi16vp+OrjRj4QSelY48BSNKlkqbPXnw/OjWxYK6fkL0kU1Jiag4jjY9Ol/1zStPqcxgJB45nr0Sw/1hAnoSqXkre2bCRTVfTQ+UE8GX7dvQpJB0AL0eEZYrUB+y6c4Ob0ZF812seV/67AUDA6dtcP3uHeQc/NVgHMIWCRafV8Unf7wi+FYZDkAXhnXzRW6ZX/xjUrBaNKxY+sfmiTJkyxero4WM9nRyczkRERXxnbH9yS4m6KwohVC4uLuu//fZbn/r1n54J/V6bJrT1LweAWgiGNqrNKzUrF4SbBcqf+07R/tOfafXxAuas349On/NNKq/ERMYztuu3zPloBT9N/YdRHWZw5Uz23sYllR2rj/Prl+uICIkhKiyO5d9vY+2v+4ztVrHA3sSS6vZZY8tsNObUcsxdSaUXpUXXh9uFUqtDpqbi4W1PYQyzatXEH7PHsr47t6ma41ibHFrZ2ljnLWb/2L17mSIVQAJf7N6TbdvVWOhkzn4E3wjLFKkPuHM5kIDTtwvCrWeSnJTK0b1XOLLnMsmF9KGosBFyO5zgW+ll/EzDk3BffgWXfwOY2bIZ47u2LJaVKfb9cwzdYUdbjdZiponGpLWx/cktJepR0MXFZfaQIUNqDxgw4JlFE63MTPmxX1fiU1IxUakwK4ZPzdtOX2X6mj2Zx7/uOoG1hRnD2754OaN1i/cTeCs88zgpIYWlMzcx7feCKXxf2Dn+SEvJR209hrcseGeKGUIIvq37Kp+dXcvZqLt4WtozserLOJo9X8mZvNKmZz1SklL5Y/p6ou6n1w9dOWsjceFxvDNncIH4kFt8vB357ou+LF11hPiEZJo1rEDvLnVzHDuoZ0OOnr5FUnJ6WS9bG3P6PWFsREwCZ67ex87aglp+3qhU6Tf90PiEbGMT0tJISkvDxsz4iaqVnVyp5uLGubC