{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Variational Auto-Encoder\n", "\n", "An auto-encoder is a type of neural network that learns a compact representation (encoding) of a dataset.\n", "\n", "Auto-encoders can be used for several purposes, including:\n", "- Dimensionality reduction, for plotting or to make further computations easier.\n", "- Data generation: creating new data points that resemble instances of the training set.\n", "\n", "A VAE takes a variational Bayesian approach: it makes assumptions about the distribution of the encoding and uses probability theory and statistical tools to train the encoder and decoder.\n", "\n", "We'll use a VAE to cluster and classify instances of the [Epileptic Seizure Recognition Data Set](https://archive.ics.uci.edu/ml/datasets/Epileptic+Seizure+Recognition).\n", "\n", "Each data point has 178 dimensions, so we'll use a VAE to find a more compact representation, perhaps in as few as 2 or 3 dimensions. 
We will then apply clustering and classification (seizure vs. no seizure) both before and after the dimensionality reduction, and compare the results.\n", "\n", "Resources:\n", "- [VAE original paper](https://arxiv.org/pdf/1312.6114.pdf)\n", "- [Good explanations of some of the maths in the paper](https://blog.evjang.com/2016/08/variational-bayes.html)\n", "- [Implementing a VAE with TensorFlow](https://danijar.com/building-variational-auto-encoders-in-tensorflow/)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/Cellar/python/3.6.5/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n", " return f(*args, **kwds)\n" ] } ], "source": [ "import collections\n", "\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import sklearn.cluster\n", "import sklearn.manifold\n", "import sklearn.model_selection\n", "import sklearn.preprocessing\n", "import tensorflow as tf\n", "\n", "from classifier import BinaryClassifier\n", "from vae_model import VariationalAutoEncoder\n", "\n", "sns.set(font_scale=1.5, palette='colorblind')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "dataset = pd.read_csv(\"data/epileptic_seizure_dataset.csv\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| | id | X1 | X2 | X3 | X4 | X5 | X6 | X7 | X8 | X9 | ... | X170 | X171 | X172 | X173 | X174 | X175 | X176 | X177 | X178 | y |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | X21.V1.791 | 135.0 | 190.0 | 229.0 | 223.0 | 192.0 | 125.0 | 55.0 | -9.0 | -33.0 | ... | -17.0 | -15.0 | -31.0 | -77.0 | -103.0 | -127.0 | -116.0 | -83.0 | -51.0 | 4 |
| 1 | X15.V1.924 | 386.0 | 382.0 | 356.0 | 331.0 | 320.0 | 315.0 | 307.0 | 272.0 | 244.0 | ... | 164.0 | 150.0 | 146.0 | 152.0 | 157.0 | 156.0 | 154.0 | 143.0 | 129.0 | 1 |
| 2 | X8.V1.1 | -32.0 | -39.0 | -47.0 | -37.0 | -32.0 | -36.0 | -57.0 | -73.0 | -85.0 | ... | 57.0 | 64.0 | 48.0 | 19.0 | -12.0 | -30.0 | -35.0 | -35.0 | -36.0 | 5 |
| 3 | X16.V1.60 | -105.0 | -101.0 | -96.0 | -92.0 | -89.0 | -95.0 | -102.0 | -100.0 | -87.0 | ... | -82.0 | -81.0 | -80.0 | -77.0 | -85.0 | -77.0 | -72.0 | -69.0 | -65.0 | 5 |
| 4 | X20.V1.54 | -9.0 | -65.0 | -98.0 | -102.0 | -78.0 | -48.0 | -16.0 | 0.0 | -21.0 | ... | 4.0 | 2.0 | -12.0 | -32.0 | -41.0 | -65.0 | -83.0 | -89.0 | -73.0 | 5 |

5 rows × 180 columns
\n", "