{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Notebook_3_Feedforward_neural_networks.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "kZ2jDsmh2paw", "colab_type": "text" }, "source": [ "# Neural Networks for Data Science Applications\n", "\n", "## Lab session 3: Feedforward networks with tf.data and tf.keras.layers\n", "\n", "**Contents of the lab session:**\n", "\n", "* Building a feedforward neural network for a realistic binary classification task.\n", "* Using tf.data to iterate over the dataset.\n", "* Using tf.keras.layers to build a model by stacking several components.\n", "* The high-level Keras compile/fit interface." ] }, { "cell_type": "code", "metadata": { "id": "fsD4s0gW1mSz", "colab_type": "code", "outputId": "4795f478-94a7-47ea-a4d0-709f319879df", "colab": { "base_uri": "https://localhost:8080/", "height": 104 } }, "source": [ "# This lab uses the GPU, so enable it on Colab before starting the VM:\n", "# Runtime >> Change runtime type >> Hardware accelerator >> GPU.\n", "!pip install -q tensorflow-gpu==2.0.0" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "\u001b[K |████████████████████████████████| 380.8MB 74kB/s \n", "\u001b[K |████████████████████████████████| 3.8MB 32.3MB/s \n", "\u001b[K |████████████████████████████████| 450kB 52.0MB/s \n", "\u001b[31mERROR: tensorflow 1.15.0rc3 has requirement tensorboard<1.16.0,>=1.15.0, but you'll have tensorboard 2.0.0 which is incompatible.\u001b[0m\n", "\u001b[31mERROR: tensorflow 1.15.0rc3 has requirement tensorflow-estimator==1.15.1, but you'll have tensorflow-estimator 2.0.0 which is incompatible.\u001b[0m\n", "\u001b[?25h" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "6HtSkoCt3gL1", "colab_type": "text" }, "source": [ "## Load the dataset with tf.data" ] }, {
"cell_type": "markdown", "metadata": { "id": "46mB7pC63mR6", "colab_type": "text" }, "source": [ "We'll classify events from simulated particle-collision experiments: the task is to distinguish a signal process that produces supersymmetric particles from a background process that does not.\n", "\n", "+ The dataset is available here: https://archive.ics.uci.edu/ml/datasets/SUSY\n", "+ Before starting, it is *highly advisable* to read the reference paper: https://www.nature.com/articles/ncomms5308." ] }, { "cell_type": "code", "metadata": { "id": "7AboYpmr3bwN", "colab_type": "code", "outputId": "9fe8bd82-c7a2-4701-aba2-2379c4766c67", "colab": { "base_uri": "https://localhost:8080/", "height": 208 } }, "source": [ "# Download the dataset\n", "!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00279/SUSY.csv.gz" ], "execution_count": 2, "outputs": [ { "output_type": "stream", "text": [ "--2019-10-18 17:17:36-- https://archive.ics.uci.edu/ml/machine-learning-databases/00279/SUSY.csv.gz\n", "Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252\n", "Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 922377711 (880M) [application/x-httpd-php]\n", "Saving to: ‘SUSY.csv.gz’\n", "\n", "SUSY.csv.gz 100%[===================>] 879.65M 22.2MB/s in 41s \n", "\n", "2019-10-18 17:18:17 (21.7 MB/s) - ‘SUSY.csv.gz’ saved [922377711/922377711]\n", "\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "bLhmxdap4Cvv", "colab_type": "code", "colab": {} }, "source": [ "# Decompress the .gz file (gzip replaces SUSY.csv.gz with SUSY.csv)\n", "!gzip -d SUSY.csv.gz" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "o62k3djk5v8D", "colab_type": "code", "outputId": "e33477f6-f8e1-4362-f3f4-1cdc59631dde", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "!ls" ], "execution_count": 4, "outputs": [ { "output_type": "stream", "text": [ "sample_data SUSY.csv\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "CJyjFJCk5lD6", "colab_type": "code", "colab": {} }, "source": [ "# Load the CSV with pandas (the file has no header row)\n", "import pandas as pd\n", "susy = pd.read_csv('SUSY.csv', header=None)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "cRk12pka6Gja", "colab_type": "code", "outputId": "bca15a68-cb3b-4fa2-c230-920142f98ac7", "colab": { "base_uri": "https://localhost:8080/", "height": 222 } }, "source": [ "# Inspect the first five rows:\n", "# Column 0 is the target label (1 = signal, 0 = background).\n", "# Columns 1-8 are low-level features and columns 9-18 are high-level features (see the reference paper).\n", "susy.head()" ], "execution_count": 6, "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 0.972861 | 0.653855 | 1.176225 | 1.157156 | -1.739873 | -0.874309 | 0.567765 | -0.175000 | 0.810061 | -0.252552 | 1.921887 | 0.889637 | 0.410772 | 1.145621 | 1.932632 | 0.994464 | 1.367815 | 0.040714 |
| 1 | 1.0 | 1.667973 | 0.064191 | -1.225171 | 0.506102 | -0.338939 | 1.672543 | 3.475464 | -1.219136 | 0.012955 | 3.775174 | 1.045977 | 0.568051 | 0.481928 | 0.000000 | 0.448410 | 0.205356 | 1.321893 | 0.377584 |
| 2 | 1.0 | 0.444840 | -0.134298 | -0.709972 | 0.451719 | -1.613871 | -0.768661 | 1.219918 | 0.504026 | 1.831248 | -0.431385 | 0.526283 | 0.941514 | 1.587535 | 2.024308 | 0.603498 | 1.562374 | 1.135454 | 0.180910 |
| 3 | 1.0 | 0.381256 | -0.976145 | 0.693152 | 0.448959 | 0.891753 | -0.677328 | 2.033060 | 1.533041 | 3.046260 | -1.005285 | 0.569386 | 1.015211 | 1.582217 | 1.551914 | 0.761215 | 1.715464 | 1.492257 | 0.090719 |
| 4 | 1.0 | 1.309996 | -0.690089 | -0.676259 | 1.589283 | -0.693326 | 0.622907 | 1.087562 | -0.381742 | 0.589204 | 1.365479 | 1.179295 | 0.968218 | 0.728563 | 0.000000 | 1.083158 | 0.043429 | 1.154854 | 0.094859 |