{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# About \n",
"\n",
"This notebook demonstrates neural network (NN) classifiers provided by the __Reproducible experiment platform (REP)__ package.\n",
"REP contains wrappers for the following NN libraries:\n",
"* __theanets__\n",
"* __neurolab__ \n",
"* __pybrain__ \n",
"\n",
"\n",
"### In this notebook we show how to:\n",
"* train a classifier\n",
"* get predictions\n",
"* measure quality\n",
"* use pretraining and partial fitting\n",
"* combine classifiers using meta-algorithms\n",
"\n",
"Most of this is done in the same way as for other classifiers (see notebook [01-howto-Classifiers.ipynb](https://github.com/yandex/rep/blob/master/howto/01-howto-Classifiers.ipynb)). \n",
"\n",
"The parameters used here were chosen only to make training very fast; they are far from optimal."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download the particle identification dataset (MiniBooNE) from UCI"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File `MiniBooNE_PID.txt' already there; not retrieving.\r\n"
]
}
],
"source": [
"!mkdir -p toy_datasets && cd toy_datasets && wget -O MiniBooNE_PID.txt -nc --no-check-certificate https://archive.ics.uci.edu/ml/machine-learning-databases/00199/MiniBooNE_PID.txt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy, pandas\n",
"from rep.utils import train_test_split\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
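"# skip the first line of the file: it holds the numbers of signal and background events\n",
"data = pandas.read_csv('toy_datasets/MiniBooNE_PID.txt', sep=r'\\s+', skiprows=[0], header=None, engine='python')\n",
"# read only the first line to get the class counts; signal events come first in the file\n",
"labels = pandas.read_csv('toy_datasets/MiniBooNE_PID.txt', sep=' ', nrows=1, header=None)\n",
"labels = [1] * labels[1].values[0] + [0] * labels[2].values[0]\n",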
"data.columns = ['feature_{}'.format(key) for key in data.columns]"
]
},
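{
"cell_type": "markdown",
"metadata": {},
"source": [
"The label construction above relies on the file layout. The first line holds the number of signal events followed by the number of background events, and the signal events are listed before the background ones. The next cell is a minimal sketch of the same parsing steps on a hypothetical three-event sample. `sample` is an illustrative in-memory stand-in for `MiniBooNE_PID.txt` and is not part of the original notebook. It makes the bookkeeping easy to check: the number of labels must equal the number of event rows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import io\n",
"import pandas\n",
"\n",
"# Hypothetical miniature of MiniBooNE_PID.txt: 2 signal events, 1 background event.\n",
"# The first line holds the class counts; each following row holds one event's features.\n",
"sample = io.StringIO(''' 2 1\n",
" 0.1 0.2\n",
" 0.3 0.4\n",
" 0.5 0.6\n",
"''')\n",
"\n",
"# Class counts from the first line; the leading space makes column 0 empty.\n",
"counts = pandas.read_csv(sample, sep=' ', nrows=1, header=None)\n",
"n_signal, n_background = int(counts[1].values[0]), int(counts[2].values[0])\n",
"\n",
"# Rewind the buffer and read the events, skipping the counts line.\n",
"sample.seek(0)\n",
"events = pandas.read_csv(sample, sep=r'\\s+', skiprows=[0], header=None, engine='python')\n",
"\n",
"# Signal events come first, so the labels are n_signal ones followed by n_background zeros.\n",
"labels = [1] * n_signal + [0] * n_background\n",
"assert len(labels) == len(events)\n",
"print(labels)"
]
},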
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"130064"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### First rows of data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"  <thead>\n",
"    <tr><th></th><th>feature_0</th><th>feature_1</th><th>feature_2</th><th>feature_3</th><th>feature_4</th><th>feature_5</th><th>feature_6</th><th>feature_7</th><th>feature_8</th><th>feature_9</th><th>...</th><th>feature_40</th><th>feature_41</th><th>feature_42</th><th>feature_43</th><th>feature_44</th><th>feature_45</th><th>feature_46</th><th>feature_47</th><th>feature_48</th><th>feature_49</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>2.59413</td><td>0.468803</td><td>20.6916</td><td>0.322648</td><td>0.009682</td><td>0.374393</td><td>0.803479</td><td>0.896592</td><td>3.59665</td><td>0.249282</td><td>...</td><td>101.174</td><td>-31.3730</td><td>0.442259</td><td>5.86453</td><td>0.000000</td><td>0.090519</td><td>0.176909</td><td>0.457585</td><td>0.071769</td><td>0.245996</td></tr>\n",
"    <tr><th>1</th><td>3.86388</td><td>0.645781</td><td>18.1375</td><td>0.233529</td><td>0.030733</td><td>0.361239</td><td>1.069740</td><td>0.878714</td><td>3.59243</td><td>0.200793</td><td>...</td><td>186.516</td><td>45.9597</td><td>-0.478507</td><td>6.11126</td><td>0.001182</td><td>0.091800</td><td>-0.465572</td><td>0.935523</td><td>0.333613</td><td>0.230621</td></tr>\n",
"    <tr><th>2</th><td>3.38584</td><td>1.197140</td><td>36.0807</td><td>0.200866</td><td>0.017341</td><td>0.260841</td><td>1.108950</td><td>0.884405</td><td>3.43159</td><td>0.177167</td><td>...</td><td>129.931</td><td>-11.5608</td><td>-0.297008</td><td>8.27204</td><td>0.003854</td><td>0.141721</td><td>-0.210559</td><td>1.013450</td><td>0.255512</td><td>0.180901</td></tr>\n",
"    <tr><th>3</th><td>4.28524</td><td>0.510155</td><td>674.2010</td><td>0.281923</td><td>0.009174</td><td>0.000000</td><td>0.998822</td><td>0.823390</td><td>3.16382</td><td>0.171678</td><td>...</td><td>163.978</td><td>-18.4586</td><td>0.453886</td><td>2.48112</td><td>0.000000</td><td>0.180938</td><td>0.407968</td><td>4.341270</td><td>0.473081</td><td>0.258990</td></tr>\n",
"    <tr><th>4</th><td>5.93662</td><td>0.832993</td><td>59.8796</td><td>0.232853</td><td>0.025066</td><td>0.233556</td><td>1.370040</td><td>0.787424</td><td>3.66546</td><td>0.174862</td><td>...</td><td>229.555</td><td>42.9600</td><td>-0.975752</td><td>2.66109</td><td>0.000000</td><td>0.170836</td><td>-0.814403</td><td>4.679490</td><td>1.924990</td><td>0.253893</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>5 rows × 50 columns</p>\n", "