{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grid Search in REP\n",
    "\n",
    "This notebook demonstrates the tools for optimizing classification models provided by the __Reproducible Experiment Platform (REP)__ package:\n",
    "\n",
    "* __grid search for the best classifier hyperparameters__\n",
    "\n",
    "* __different optimization algorithms__\n",
    "\n",
    "* __different scoring models__ (optimization of an arbitrary figure of merit)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data \n",
    "We use the 'magic' dataset from the UCI Machine Learning Repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File `magic04.data' already there; not retrieving.\r\n"
     ]
    }
   ],
   "source": [
    "!cd toy_datasets; wget -O magic04.data -nc --no-check-certificate https://archive.ics.uci.edu/ml/machine-learning-databases/magic/magic04.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy, pandas\n",
    "from rep.utils import train_test_split\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "columns = ['fLength', 'fWidth', 'fSize', 'fConc', 'fConc1', 'fAsym', 'fM3Long', 'fM3Trans', 'fAlpha', 'fDist', 'g']\n",
    "data = pandas.read_csv('toy_datasets/magic04.data', names=columns)\n",
    "labels = numpy.array(data['g'] == 'g', dtype=int)\n",
    "data = data.drop('g', axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple grid search example\n",
    "In this example we:\n",
    "* optimize the parameters of GradientBoostingClassifier\n",
    "* maximize RocAuc (the area under the ROC curve)\n",
    "* run in 4 threads (4 classifiers are trained at a time)\n",
    "* use 3-fold cross-validation (folding) to estimate quality\n",
    "* use only 30 trees so that the example runs fast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy\n",
    "import pandas\n",
    "from rep import utils\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "from rep.report.metrics import RocAuc\n",
    "from rep.metaml import GridOptimalSearchCV, FoldingScorer, RandomParameterOptimizer\n",
    "from rep.estimators import SklearnClassifier, TMVAClassifier, XGBoostRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# define grid parameters\n",
    "grid_param = {}\n",
    "grid_param['learning_rate'] = [0.2, 0.1, 0.05, 0.02, 0.01]\n",
    "grid_param['max_depth'] = [2, 3, 4, 5]\n",
    "\n",
    "# use random hyperparameter optimization algorithm \n",
    "generator = RandomParameterOptimizer(grid_param)\n",
    "\n",
    "# define folding scorer\n",
    "scorer = FoldingScorer(RocAuc(), folds=3, fold_checks=3)"
   ]
  },
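  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate what random search over a grid means, here is a minimal sketch in plain Python. It is not REP's implementation: `sample_grid_points` and its `seed` argument are hypothetical names introduced for illustration, and the sketch only assumes that the optimizer draws distinct points from the grid. The grid defined above contains 5 × 4 = 20 combinations.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import random\n",
    "\n",
    "def sample_grid_points(grid, n_evaluations, seed=0):\n",
    "    # Hypothetical helper (not part of REP): enumerate the full grid,\n",
    "    # then draw n_evaluations distinct points from it at random.\n",
    "    names = sorted(grid)\n",
    "    all_points = [dict(zip(names, values))\n",
    "                  for values in itertools.product(*(grid[name] for name in names))]\n",
    "    rng = random.Random(seed)\n",
    "    return rng.sample(all_points, min(n_evaluations, len(all_points)))\n",
    "\n",
    "grid_param = {'learning_rate': [0.2, 0.1, 0.05, 0.02, 0.01],\n",
    "              'max_depth': [2, 3, 4, 5]}\n",
    "points = sample_grid_points(grid_param, n_evaluations=10)\n",
    "```\n",
    "\n",
    "Each element of `points` is a dict such as `{'learning_rate': ..., 'max_depth': ...}` that could be passed to `set_params` of an estimator before scoring it."
   ]
  },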
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing grid search in 4 threads\n",
      "4 evaluations done\n",
      "8 evaluations done\n",
      "10 evaluations done\n",
      "CPU times: user 38.5 s, sys: 609 ms, total: 39.1 s\n",
      "Wall time: 14.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time \n",
    "estimator = SklearnClassifier(GradientBoostingClassifier(n_estimators=30))\n",
    "grid_finder = GridOptimalSearchCV(estimator, generator, scorer, parallel_profile='threads-4')\n",
    "grid_finder.fit(data, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Looking at results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.917:  learning_rate=0.2, max_depth=3\n",
      "0.914:  learning_rate=0.1, max_depth=4\n",
      "0.903:  learning_rate=0.1, max_depth=3\n",
      "0.888:  learning_rate=0.01, max_depth=5\n",
      "0.885:  learning_rate=0.05, max_depth=3\n",
      "0.874:  learning_rate=0.01, max_depth=4\n",
      "0.870:  learning_rate=0.05, max_depth=2\n",
      "0.854:  learning_rate=0.01, max_depth=3\n",
      "0.850:  learning_rate=0.02, max_depth=2\n",
      "0.834:  learning_rate=0.01, max_depth=2\n"
     ]
    }
   ],
   "source": [
    "grid_finder.params_generator.print_results()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizing the parameters and threshold \n",
    "Many applications require optimizing a binary classification metric (F1, BER, misclassification error). Such metrics depend on the threshold applied to the predicted probabilities, so after each training we have to find the optimal threshold (the default one is usually a poor choice).\n",
    "\n",
    "In this example we:\n",
    "* optimize AMS (a binary metric used in the Higgs Boson challenge on Kaggle)\n",
    "* tune the parameters of TMVA's GBDT\n",
    "* use Gaussian processes to make good guesses about which points to check next"
   ]
  },
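  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The threshold search described above can be sketched in plain Python. This is an illustration under stated assumptions, not REP's `OptimalMetric`: `best_threshold` is a hypothetical helper, the `ams` function below follows the regularized AMS definition from the Higgs Boson challenge (assumed, not verified here, to match `rep.report.metrics.ams`), and the toy labels and probabilities are made up.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def ams(s, b, br=10.0):\n",
    "    # Regularized approximate median significance (Higgs challenge definition);\n",
    "    # br is the regularization term added to the background.\n",
    "    return math.sqrt(2.0 * ((s + b + br) * math.log(1.0 + s / (b + br)) - s))\n",
    "\n",
    "def best_threshold(labels, proba, expected_s, expected_b):\n",
    "    # Hypothetical helper: try every predicted value as a cut\n",
    "    # and keep the one with the highest AMS.\n",
    "    n_sig = float(sum(1 for y in labels if y == 1))\n",
    "    n_bkg = float(sum(1 for y in labels if y == 0))\n",
    "    best = None\n",
    "    for t in sorted(set(proba)):\n",
    "        # rescale the passed events to the expected signal/background yields\n",
    "        s = expected_s * sum(1 for y, p in zip(labels, proba) if y == 1 and p >= t) / n_sig\n",
    "        b = expected_b * sum(1 for y, p in zip(labels, proba) if y == 0 and p >= t) / n_bkg\n",
    "        score = ams(s, b)\n",
    "        if best is None or score > best[1]:\n",
    "            best = (t, score)\n",
    "    return best\n",
    "\n",
    "labels = [0, 0, 0, 1, 1, 1]\n",
    "proba = [0.1, 0.2, 0.6, 0.4, 0.7, 0.9]\n",
    "threshold, score = best_threshold(labels, proba, expected_s=100, expected_b=1000)\n",
    "```\n",
    "\n",
    "On this toy input the best cut is the lowest one that rejects all background events, which shows why a fixed default threshold can be far from optimal for such a metric."
   ]
  },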
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from rep.metaml import RegressionParameterOptimizer\n",
    "from sklearn.gaussian_process import GaussianProcess\n",
    "from rep.report.metrics import OptimalMetric, ams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Performing grid search in 3 threads\n",
      "3 evaluations done\n",
      "6 evaluations done\n",
      "9 evaluations done\n",
      "12 evaluations done\n",
      "CPU times: user 8.39 s, sys: 1.75 s, total: 10.1 s\n",
      "Wall time: 1min 17s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# OptimalMetric is a wrapper that checks all possible thresholds;\n",
    "# the expected numbers of signal and background events are set to arbitrary values\n",
    "optimal_ams = OptimalMetric(ams, expected_s=100, expected_b=1000)\n",
    "\n",
    "# define grid parameters\n",
    "grid_param = {'Shrinkage': [0.4, 0.2, 0.1, 0.05, 0.02, 0.01], \n",
    "              'NTrees': [5, 10, 15, 20, 25], \n",
    "              # you can pass different sets of features to be compared\n",
    "              'features': [columns[:2], columns[:3], columns[:4]],\n",
    "             }\n",
    "\n",
    "# using GaussianProcesses \n",
    "generator = RegressionParameterOptimizer(grid_param, n_evaluations=10, regressor=GaussianProcess(), n_attempts=10)\n",
    "\n",
    "# define folding scorer\n",
    "scorer = FoldingScorer(optimal_ams, folds=2, fold_checks=2)\n",
    "\n",
    "grid_finder = GridOptimalSearchCV(TMVAClassifier(method='kBDT', BoostType='Grad',), generator, scorer, parallel_profile='threads-3')\n",
    "grid_finder.fit(data, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Looking at results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.348:  Shrinkage=0.4, NTrees=20, features=['fLength', 'fWidth', 'fSize', 'fConc']\n",
      "4.253:  Shrinkage=0.4, NTrees=25, features=['fLength', 'fWidth', 'fSize']\n",
      "4.222:  Shrinkage=0.4, NTrees=20, features=['fLength', 'fWidth', 'fSize']\n",
      "4.201:  Shrinkage=0.4, NTrees=10, features=['fLength', 'fWidth', 'fSize', 'fConc']\n",
      "4.188:  Shrinkage=0.4, NTrees=15, features=['fLength', 'fWidth', 'fSize']\n",
      "4.152:  Shrinkage=0.2, NTrees=20, features=['fLength', 'fWidth', 'fSize']\n",
      "4.130:  Shrinkage=0.2, NTrees=15, features=['fLength', 'fWidth', 'fSize']\n",
      "4.064:  Shrinkage=0.1, NTrees=15, features=['fLength', 'fWidth', 'fSize', 'fConc']\n",
      "4.060:  Shrinkage=0.1, NTrees=15, features=['fLength', 'fWidth', 'fSize']\n",
      "3.983:  Shrinkage=0.05, NTrees=10, features=['fLength', 'fWidth', 'fSize', 'fConc']\n",
      "3.845:  Shrinkage=0.01, NTrees=10, features=['fLength', 'fWidth', 'fSize', 'fConc']\n",
      "3.696:  Shrinkage=0.1, NTrees=15, features=['fLength', 'fWidth']\n"
     ]
    }
   ],
   "source": [
    "grid_finder.generator.print_results()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's look at how the score evolves over time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1101333d0>]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXUAAAEACAYAAABMEua6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XeYlOW5x/HvTZMmdo2CCkeNQSxIDLaYjC1iN5rYNShH\nN/YoAlGjbjTBSizRKJZYMLZgiRxRLDgaExXFpegGsQQLEjAoTerKff54ZmVZZndnZmfmnXnn97ku\nLmd333nnngV/++xTzd0REZF4aBN1ASIikj8KdRGRGFGoi4jEiEJdRCRGFOoiIjGiUBcRiZGMQt3M\n2ppZjZmNaeaaH5hZnZkdmb/yREQkG5m21M8DaoG0k9rNrC1wDfAsYPkpTUREstViqJtZD+Ag4C6a\nDuxzgNHAF/krTUREspVJS/0GYAiwMt0Xzaw7cDhwW+pTWqIqIhKRZkPdzA4B5rh7DU230m8Efu1h\nvwFr5joRESkwa27vFzMbDpwE1AEdgW7AY+5+coNrPmJVkG8ILAZOc/enGt1LLXgRkRy4e+aNZXfP\n6A/wY2BMC9fcAxzZxNc8zi6//PKoSygovb/yFef35h7/95fKzoyzOtt56g5gZlVmVpXlc0VEpMDa\nZXqhu78MvJx6PLKJa07JU10iIpIDrSjNk0QiEXUJBaX3V77i/N4g/u8vW80OlOb1hcy8WK8lIhIX\nZpbVQKla6iIiMaJQFxGJEYW6iEiMKNRFRGJEoS4iEiMKdRGRGFGoi4jEiEJdRCRD770X/pQyhbqI\nSIYuuACefz7qKpqX8d4vIiKV7LXXYOpUePzxqCtpnlrqIiIZuOwy+M1vYK21oq6keQp1EZEWvPIK\nfPghnFIG+9Aq1EVEmuEOl14aWurt20ddTcsU6iIizRg/Hv7zHzjxxKgryYxCXUSkCfWt9Msvh3Zl\nMq1EoS4i0oRnn4X58+GYY6KuJHMKdRGRNNxDP3p1NbRtG3U1mcso1M2srZnVmNmYNF873Mwmp74+\n0cz2yX+ZIiLF9dRTsHw5HHVU1JVkJ6Pj7MzsAuD7wNruflijr3Vx969Tj3cAnnD3rdPcQ8fZiUhZ\nWLkSdt4ZrrgCDj882lryfpydmfUADgLuAta4cX2gp3QF/pvpi4uIlKLHHoMOHeCww1q+ttRkMp57\nAzAE6NbUBWZ2BHAVsCnwk/yUJiJSfN98E2a7jBgBlnH7uHQ021I3s0OAOe5eQ5pWej13f9LdewOH\nAqPyW6KISPE8/DCsuy4MGBB1JblpqaW+B3CYmR0EdAS6mdn97n5yuovd/e9m1s7MNnD3uY2/Xl1d\n/e3jRCJBIpHIuXARkXyrq4Pf/hZuuy26VnoymSSZTOb8/IwGSgHM7MfAhe5+aKPPbwV85O5uZv2A\nv7r7Vmmer4FSESlp99wD990HL71UOl0v2Q6UZrtGylMvUgXg7iOBo4CTzWwFsAg4Nst7iohEbsUK\nuPLKEOqlEui5yLil3uoXUktdRErYHXfA6NHw3HNRV7K6bFvqCnURqXjLlsE228Cjj8Juu0Vdzery\nPk9dRCTu7rwTdtih9AI9F2qpi0hFW7IEtt46bAvw/e9HXc2a1FIXEcnCbbdB//6lGei5UEtdRCrW\nokWhlf7cc7DjjlFXk55a6iIiGbrlFvjxj0s30HOhlrqIVKQFC0Ir/eWXoXfvqKtpmlrqIiIZuPFG\nOOCA0g70XKilLiKttnx5aPE+9VRYjTliBLRvH3VVTfvqqzAv/fXXQ2u9lKmlLiJFMXcujBoFRx8N\nG28cjn7bdFP48EM47riw7L5UjRgRDr8o9UDPhVrqIpKx6dNDa3zMGKipgX32CQdJHHwwbLJJuGbZ\nsnAEXKdO8OCDpddi/+9/YdttYeJE6Nkz6mpapm0CRCRv6urgtddWBfmCBXDooSHI99knBHc6y5bB\nkUdCly7wl7+UVrAPGwbz58Ptt0ddSW
YU6iLSKgsXwrhxIcjHjoXNNw8hfuih0K8ftMmw03bp0hDs\nXbuGFnu7bPeELYDZs8PA6OTJ4X2VA4W6iGTtk09CS/ypp0LLfI89QpAfcghssUXu9126FH76U+jW\nLbTYow72888Px9XdfHO0dWRDoS4iLVq5MvQp1wf5zJlw0EEhyH/yE1h77fy9Vn2wr7MOPPBAdME+\nc2bYtOvdd8OAbrlQqItIWkuWwIsvhhD/v/8LIVvfP7777tC2beFee+lSOOIIWG+9MGMmimA/66ww\nBnD99cV/7dZQqIvIt+rqQog++WQ4oq1fv1X949tsU9xali4N0wg32ADuv7+4wf7xx+G9T5sGG21U\nvNfNB4W6iHzr6qvhr3+FCy6AAw+E9dePtp4lS0Kwb7RRODauWMF+2mnhNYcPL87r5ZNCXUQA+OCD\ncOjDm29Cr15RV7NKw2C///7CdvtAWAy1665hjn3UP9RyUbAVpWbW1sxqzGxMmq+dYGaTzWyKmf3D\nzGK055lI+XGHX/4SLrqotAIdQr/23/4Gc+bAL34RZqMU0hVXwNlnl2eg5yKbbQLOA2qBdM3tj4Af\nufuOwJXAHXmoTURyNGoUfPklnHde1JWkVx/s//kPDBxYuGCfNi3MtT///MLcvxRlFOpm1gM4CLgL\nWOPXAHd/zd3npz58A+iRtwpFJCtffAFDhoRzN6OeF96czp3DTJzPP4dTTilMsF9xRQj0ddbJ/71L\nVaYt9RuAIcDKDK4dBIzNuSIpObW1MGFCWI2nYZHSd8EFcNJJ5XE8W+fOYa78zJlw6qn5DfZ33glT\nOM85J3/3LAct/hw3s0OAOe5eY2aJFq7dGzgV2DPd16urq799nEgkSCSavZ2UiJ/+FDp0gFmzYPHi\nsMKwZ0/Ycss1/7vpppkvI5f8e+45ePXVEGjloj7YDzkEBg2Cu+/Oz+BpdTVceGF+F1IVQzKZJJlM\n5vz8Fme/mNlw4CSgDugIdAMec/eTG123I/A4MMDdP0hzH81+KUPz50P37jBvXvhVftGiMOd3xozV\n/1v/eN486NFjVcg3Dv7u3Uu7S6CcLV4M228Pt94api+Wm6+/DsHesyfcdVfrgn3SpPA9+PDD8EOj\nnBV0SqOZ/Ri40N0PbfT5LYDxwInu/noTz1Wol6GXXw4zKP75z8yuX7Ik7COSLvhnzAgzHjbbLH0r\nv2fPsMlShw6FejfxNnQofPZZ2DyrXNUHe69eIdhz/a3vsMNg331Ld6A4G9mGei5tJk+9UBWAu48E\nLgPWA24zM4AV7t4/h3tLiZk4MazEy1SnTmGv6m23Tf/15cvh009Xb92/8sqq4J85M3ThnH9+WNZd\nSlu2lrKaGrj3Xpg6NepKWqdLl7CFwcEHw//+b27BPmFC+H48+mhhaix1WnwkzTrhBNhvvzA7oRjq\n6kJ/8NChYVbEzTeHfbulad98ExbXnHVW8f6eCu3rr8MGY1tvHWbxZBPsAwaExU1nnFG4+opJx9lJ\nXr39dnFnUbRrB337hv28r7wyDJz9/OehS0fS++Mfw2DgwIFRV5I/XbrA00/D++/D6aeHXSUz8Y9/\nhLnpgwYVtr5SplCXJi1cGMJ0u+2K/9pmYdZNbW3YLrVfvzDneMmS4tdSyj7+GH73Oxg5MnzP4qRr\n17BwaPp0qKrKLNgvvTT8qeRxGYW6NGnSpDCbIsrZKp06hQONJ06EKVOgT5+w46B68sL34IwzwvjD\nd78bdTWFUR/s06aFbQ+aC/aXXgqNkJNPbvqaSqBQlyZNnFg6C1i23BJGjw79qxdfHPpNp02Luqpo\nPfpoGHQeMiTqSgqrPthra8MPsXTB7h5a6JdfrsF1hbo0qdj96ZnYd99wvuSBB8Jee4XFJQsWRF1V\n8X31VWih33FHZXQ1rL02PPNMOLXozDPXDPbnnoP//heOPz6a+kqJQl2alO10xmJp3x5+9aswS+bL\nL+
F73wtbuGY6mBYHQ4aEQ5133z3qSoqnPtinTg0zfer/vt1DF111deG38S0HmtIoaX39ddjvet68\n0m8JTpgQtlZt1y7MBCm13y7yLZkMe7u8+2440LnSLFwYut922imsnn366bBAbvLkeG5RoSmNkheT\nJoVByVIPdID+/eH118NilUMOCVPgvvgi6qoKY+nSMBPkj3+szECHVS32SZNCi/2yy+C3v41noOdC\n3wZJqxT705vTpk3Y5e9f/wpznPv0CcFXVxd1Zfk1fHiYkXTEEVFXEq1u3eDZZ0OwQ5j+KoG6XySt\ngQNhjz1Cq7ccvfsunHtuaLHffDPEYUPQd98N72PSpLAxmoRNzOq7CuNK3S+SF6U0nTEXffrACy+E\nKW4DB8Ixx4Tpf+Vq5crwA/aKKxToDXXuHO9Az4VCXdaweHHYsnT77aOupHXM4Kijwvzm730Pdt4Z\nfv/70C9dbkaODP+tqoq2Dil9CnVZw5Qp0Ls3rLVW1JXkR+fOYSDtzTfDbyB9+oRj1MqlN3DmzDAY\neMcdGgyUlumfiKyhVOent1avXvD443D77TBsWNgFcPr0qKtq2bnnhpWUffpEXYmUA4W6rKHc+9Nb\nsv/+4beR/feHPfcMAf/111FXld6TT4ZFVhdfHHUlUi4U6rKGcpvOmIv27cMBzVOnhu6NHXeE8eOj\nrmp1CxaERVV33AEdO0ZdjZQLTWmU1SxdCuuvH5bfV1KQPP106OIYMACuuw7WWSfqikKgL10aTv+R\nyqUpjdIqU6aEbVwrKdAhHJ/2zjthIHL77cPp9lF67bXQ/3/dddHWIeVHoS6riXt/enO6dQuDqKNG\nha6Z44+PZruB5cvhtNPghhtgvfWK//pS3jIKdTNra2Y1ZrZG+8XMvmdmr5nZUjMbnP8SpZgqoT+9\nJYlE2Byqe/dw6tJDDxV3+uN114X9448+univKfGRUZ+6mV0AfB9Y290Pa/S1jYAtgSOAr9x9RBP3\nUJ96GejXD/70J9htt6grKQ0TJoTzLnv1gttuK/xqzunTw/YMEyeGYBfJe5+6mfUADgLuAta4sbt/\n4e5vASuyKVRKz7Jl4TShnXaKupLS0b//qnn7ffuGk5cK1TZxDytGf/MbBbrkLpPulxuAIUAFHUFQ\nmd55B7beOpwLKqt06BAOYBg/PoT6vvuGbRTy7d57YdEiOOec/N9bKkezRwqb2SHAHHevMbNEa1+s\nurr628eJRIJEHLbOi5FKHiTNxA47hFkpN94Iu+4Kl1wSVnvm47Sd2bPDIqhx43R6T6VLJpMkk8mc\nn99sn7qZDQdOAuqAjkA34DF3X+O8bjO7HFikPvXyVVUVpvOppdiyDz4Ih3IsXQp33936JfzHHw89\nesC11+anPomPvPapu/vF7r65u/cCjgXGpwv0+tfOok4pQWqpZ27rrUN3zCmnhNkyV14ZpiLm4pln\nwslNDX6RFclZtvPUHcDMqsysKvX4O2b2KXA+8Bsz+8TMuua5Timw5cvDFrUaJM1cmzbht5u33w6h\n/IMfwFtvZXePr7+GM88M8+M7dy5MnVJZtE2AAFBTAyeeGE7Xkey5w4MPhkVLAweGVncmA86DB8Oc\nOWHBk0g62iZAcqKul9YxgxNOCBuEzZgRfuN55ZXmnzNxIjzwAPzhD0UpUSqEQl0AhXq+bLwxPPJI\nGPA87rhw2v3ChWteV1cXtgK49lodxyb5pVAXQNsD5NsRR4R5/8uWhRlFzz67+tdvuinshnlyU9MO\nRHKkPnVhxQpYd90wV7qrhrjz7oUXQqv8Rz8KXS0LFoRB1ddfD7NoRJqjPnXJWm0tbLGFAr1Q9tsv\n9LWvt15YwPTzn8OFFyrQpTBiGep1daH1KZlR10vhde0aVqKOHh1a6YO1n6kUSOxC/YMPQkBVVUVd\nSfmI60HTpWiPPcJuj+3bR12JxFWsQv3JJ8P/NCecEB5/8knUFZUH
zXwRiY9YDJTW1YXT1h95BB59\nNGy2NHgwrFwZTo+RptXVhfM4Z80KJ/+ISGnJdqC07EN91iw49thwpuZf/gIbbhg+/9ln4YT4Dz4I\nU8ckvXfegaOOgvfei7oSEUmnoma/vPwy7LIL7L03jB27KtAh7Hh3+OFw663R1VcO1J8uEi9lGeru\n4RzHY44J255WV6ffg3roULjlFli8uOgllg31p4vES9mF+vz5cOSRYWrYhAkwYEDT1/buDbvvDvfc\nU7z6yo2mM4rES1mF+uTJIYA22yxslrTFFi0/Z9gwuP76MCAoq/vmm/A93XnnqCsRkXwpm1C/996w\nMu+KK0I/+VprZfa83XeHzTcPs2Jkde+9B5tsErYIEJF4aPaM0lKwdGk4Xu3vf4dkMrdjw4YNC+dJ\nHndc2CJVAvWni8RPSbfUP/ooLCZasADefDP3cyAPOih0NYwbl9/6yp3600Xip2RDfcwY2G03+MUv\n4OGHYe21c7+XWZgJc801+asvDjSdUSR+Sm7xUV0dXHZZON7rkUdCSz0fVqwIu+L99a/Qv39+7lnO\nVq4MfekzZmhxlkgpK8jiIzNra2Y1Zjamia/fbGbvm9lkM8t5LsXs2fCTn4SpihMn5i/QIWygNHiw\nWuv1pk8Pi7UU6CLxkmn3y3lALbBGU9vMDgK2dvdtgNOB23Ip5NVXQ//unnuGvu+NN87lLs0bNCgM\nuGpJvPrTReKqxVA3sx7AQcBdQLpfAQ4D7gNw9zeAdc1sk0wLcA+bbh11FIwcCVdemX51aD506QJn\nnhnmrVc69aeLxFMmLfUbgCHAyia+3h34tMHHnwE9MnnxBQvg6KPDRlxvvAEHH5zJs1rn7LPhscfg\n888L/1qlTNMZReKp2XnqZnYIMMfda8ws0dyljT5OOyJaXV397eMttkhwzTUJ9t47DIp27JhZwa21\n4YZw4onhFJprry3Oa5aalSuhpkahLlKKkskkyWQy5+c3O/vFzIYDJwF1QEegG/CYu5/c4JrbgaS7\nP5z6eBrwY3ef3ehe385+GTUKLrgARoyI5jT1GTNCoH34YWWupnz/fdh///B9EJHSltfZL+5+sbtv\n7u69gGOB8Q0DPeUp4OTUi+8GzGsc6PWWLYMzzgj95uPHRxPoAD17woEHwu23R/P6UVN/ukh8Zbv4\nyAHMrMrMqgDcfSzwkZl9AIwEzmzqyT/8IcyZA2+9FU5Vj9LQoXDTTWEbgkqj/nSR+Crq4qMRI5zz\nzy+d/VcOPjgcpHH66VFXUlz77gtDhjS/bbGIlIaKO86uNV55JcxdnzatcNMoS417WHD03nuFWQsg\nIvlVUcfZtdZee4XZME88EXUlxfPRR9C1qwJdJK4qOtTNwra811wTWrCVQP3pIvFW0aEOcNhhsGhR\nmI1TCbQ9gEi8VXyot2kTBg0rZaMvTWcUibeKHiitt2wZbLUVPPVUvAPPHTbYAGpr4TvfiboaEcmE\nBkpzsNZa8KtfxX/bgI8/hk6dFOgicaZQTzn9dHjhhbB1QFyp60Uk/hTqKd26QVVV2I8mrjTzRST+\nFOoNnHtuOA91zpyoKykMhbpI/CnUG9hkEzjmGLj55qgryT93TWcUqQSa/dLIhx/CrrvCv/8Na68d\ndTX588kn4cDtWbNKZ+8dEWmZZr+00lZbhQ2v7rgj6kryq77rRYEuEm8K9TSGDQvnpi5fHnUl+aOu\nF5HKoFBPo18/2G67cHZqXGiQVKQyqE+9CS++GA6pfvfdsJVAOXMPC47eegs23zzqakQkG+pTz5N9\n9oEuXWDMmKgrab2ZM0Ow9+gRdSUiUmgK9SbEaVve+v50DZKKxJ9CvRlHHglffAGvvhp1Ja2j7QFE\nKkeLoW5mHc3sDTObZGa1ZnZVmmvWM7MnzGxy6to+hSm3uNq2hQsvhKuvjrqS1tEgqUjlyGig1Mw6\nu/tiM2sHvApc6O6vNvj6dcAC
d7/SzLYFbnX3/Rrdo6wGSustXQq9esFzz8EOO0RdTW423RRefx22\n3DLqSkQkWwUZKHX3xamHHYC2wJeNLukNvJS69j2gp5ltlGkRpaxjx7AnTLluyztrFqxYAVtsEXUl\nIlIMGYW6mbUxs0nAbOAld69tdMlk4MjUtf2BLYHYzLU44wwYOzbsR15u6vvTNUgqUhnaZXKRu68E\n+prZOsA4M0u4e7LBJVcDN5lZDTAVqAG+aXyf6urqbx8nEgkSiUTOhRfTuuvCoEHwhz/ATTdFXU12\n1J8uUl6SySTJZDLn52e9+MjMLgWWuPv1zVzzb2AHd1/U4HNl2ade7/PPYfvt4f33w5Fw5eLww+Gk\nk+BnP4u6EhHJRd771M1sQzNbN/W4E7A/oSXe8Jp1zKxD6vFpwMsNAz0ONtssTHG85ZaoK8mOWuoi\nlaXFlrqZ7QDcR/gB0AYY5e7XmVkVgLuPNLPdgXsBB94BBrn7/Eb3KeuWOsB778Fee4Vtebt0ibqa\nls2eDb17w9y56lMXKVfZttS190uWjjwSEokwI6bUjR0bxgFeeCHqSkQkV9r7pcCGDQvnmK5YEXUl\nLdN2uyKVR6GepV13DYuRHnkk6kpapu0BRCqPQj0Hv/51WIxU6r1JGiQVqTwK9RwccEDYY/2ZZ6Ku\npGlffAELFoTj+USkcijUc2AGQ4eGbXlL1dtvayWpSCVSqOfo6KPhk0/CRlmlSP3pIpVJoZ6jdu1g\n8ODSba2rP12kMinUW+HUU+Gf/4R//SvqStak6YwilUmh3gqdO8NZZ8F110Vdyermzg1/tt466kpE\npNgU6q101lnw5JPw2WdRV7LK22/DzjuHGToiUln0v30rbbBB2Jb3kkuirmQV9aeLVC6Feh5cfnk4\nnHrMmKgrCdSfLlK5FOp50LUr/PnP8MtfwpeND/qLgKYzilQu7dKYR+eeC199BaNGRVfDV1+F80jn\nzYO2baOrQ0TyQ7s0Ruiqq+C11+Bvf4uuhpoa6NtXgS5SqRTqedSlC9xzTzioeu7caGrQIKlIZVOo\n59lee4UtBKI6REP96SKVTaFeAMOHw4QJ8MQTxX9ttdRFKpsGSgvkH/+An/0Mpk6FDTcszmvOnw/d\nu4f/qk9dJB7yPlBqZh3N7A0zm2RmtWZ2VZprNjSzZ1PXvGNmA7OsO3b23BOOPx7OPrt4r1lTAzvu\nqEAXqWQthrq7LwX2dve+wI7A3mb2w0aXnQ3UpK5JACPMrF2+iy03v/tdCNrRo4vzeup6EZGM+tTd\nfXHqYQegLdB4ic0soFvqcTdgrrvX5aXCMtapE9x7L5xzTjiJqNAU6iKSUaibWRszmwTMBl5y99pG\nl9wJ9DGzz4HJwHn5LbN87b47nHhi2Pir0LQ9gIhk1EXi7iuBvma2DjDOzBLunmxwycXAJHdPmNlW\nwPNmtpO7L2x4n+rq6m8fJxIJEolEK8svD1dcEaYZPvpomO5YCAsXwqefQu/ehbm/iBRHMpkkmUzm\n/PysZ7+Y2aXAEne/vsHnxgK/d/d/pD5+ERjm7m81uKaiZr809sYbcPjhMGUKbLxx/u//yivh3NRS\nPV5PRHJTiNkvG5rZuqnHnYD9gZpGl00D9ktdswmwLfBRpkVUgl13hYED4cwzoRA/29T1IiKQWZ/6\npsD4VJ/6G8AYd3/RzKrMrCp1zXBgFzObDLwADHX3EtivsLRUV4ej7x55JP/31iCpiIAWHxXdhAlw\n6KEweTJ85zv5u+9228GDD4bNvEQkPrLtflGoR+Dii0OL/fHHwTL+q2raokWwySZhu9327Vt/PxEp\nHdp6twxcfjm8/z489FB+7jd5MvTpo0AXEYV6JNZaKyxKOv98mDWr9ffTzowiUk+hHpFddoHTToOq\nqtbPhtEgqYjUU6hH6NJLYcYMeOCB1t1HoS4i9TRQGrG334YBA2DSJNhss+yfv3hx2Np33jzo0C
H/\n9YlItDRQWmb69QvH3+XaDTN5ctgaQIEuIqBQLwmXXBL2bbn//uyfq64XEWlIoV4COnQIs2GGDIGZ\nM7N7rrYHEJGGFOolom/fsD3vaadl1w2j6Ywi0pAGSkvIihXQvz+cey6cckrL1y9ZAhtsAF9+CR07\nFr4+ESk+DZSWsfbtQzfM0KGhj70lU6bAttsq0EVkFYV6idlpp9BSz6QbRv3pItKYQr0E/frX4UzT\nu+9u/jr1p4tIYwr1ElTfDXPRRfDJJ01fp+mMItKYBkpL2PDhkEzCuHFrbtG7bBmstx7MnQudOkVS\nnogUgQZKY2ToUPjqK7jzzjW/NnUqbLONAl1EVqdQL2Ht2oVumEsugY8/Xv1r6k8XkXQyOXi6o5m9\nYWaTzKzWzK5Kc82FZlaT+jPVzOrqD6uW1unTBwYPhkGDVp8No/50EUmnxVB396XA3u7eF9gR2NvM\nftjomuvdfWd33xm4CEi6+7yCVFyBLrwQFi6EkSNXfU7TGUUknXaZXOTui1MPOwBtgS+bufx4IE8H\ntQmEbph77oEf/QgOOAC6d4fa2jCnXUSkoYz61M2sjZlNAmYDL7l7bRPXdQYOAB7LX4kCsN12YeB0\n0KAwSPo//wOdO0ddlYiUmoxC3d1XprpfegA/MrNEE5ceCryqrpfCGDw47Pdy9tnqehGR9DLqfqnn\n7vPN7GlgFyCZ5pJjaabrpbq6+tvHiUSCRCKRzctXvLZtQzdM375w3HFRVyMihZBMJkkmkzk/v8XF\nR2a2IVDn7vPMrBMwDvitu7/Y6Lp1gI+AHu6+JM19tPgoT154IfSnb7RR1JWISKFlu/gok5b6psB9\nZtaG0F0zyt1fNLMqAHevn5NxBDAuXaBLfu23X9QViEip0jYBIiIlTNsEiIhUMIW6iEiMKNRFRGJE\noS4iEiMKdRGRGFGoi4jEiEJdRCRGFOoiIjGiUBcRiRGFuohIjCjURURiRKEuIhIjCnURkRhRqIuI\nxIhCXUQkRhTqIiIxolAXEYkRhbqISIwo1EVEYqTZUDezjmb2hplNMrNaM7uqiesSZlZjZu+YWbIg\nlYqISIuaDXV3Xwrs7e59gR2Bvc3shw2vMbN1gVuBQ919e+BnhSq2lCWTyahLKCi9v/IV5/cG8X9/\n2Wqx+8XdF6cedgDaAl82uuR44DF3/yx1/X/zWmGZiPs/LL2/8hXn9wbxf3/ZajHUzayNmU0CZgMv\nuXtto0u2AdY3s5fM7C0zO6kQhYqISMvatXSBu68E+prZOsA4M0u4e7LBJe2BfsC+QGfgNTN73d3f\nL0TBIiKbIspDAAAD3ElEQVTSNHP3zC82uxRY4u7XN/jcMKCTu1enPr4LeNbdRzd6buYvJCIi33J3\ny/TaZlvqZrYhUOfu88ysE7A/8NtGl/0NuMXM2gJrAbsCf2hNUSIikpuWul82Be4zszaE/vdR7v6i\nmVUBuPtId59mZs8CU4CVwJ1p+t1FRKQIsup+ERGR0laUFaVmNsDMppnZ+6k++Ngws81TM3/eTS2+\nOjfqmvLNzNqmFpeNibqWfDOzdc1stJn9K7XAbreoa8onM7so9W9zqpk9aGZrRV1Ta5jZn81stplN\nbfC59c3seTObbmbPpdbOlKUm3t91qX+fk83s8dSklSYVPNRTfe23AAOA7YDjzKx3oV+3iFYA57t7\nH2A34KyYvT+A84BaII6/1t0EjHX33oQFdv+KuJ68MbOewGlAP3ffgbDO5Ngoa8qDewhZ0tCvgefd\n/bvAi6mPy1W69/cc0MfddwKmAxc1d4NitNT7Ax+4+wx3XwE8DBxehNctCnf/j7tPSj1eRAiFzaKt\nKn/MrAdwEHAXEKvB7lSLZy93/zOAu9e5+/yIy8qnBYRGR2cza0eYcjwz2pJax93/DnzV6NOHAfel\nHt8HHFHUovIo3ftz9+dTU8sB3gB6NHePYoR6d+DTBh9/lv
pc7KRaRjsTvvFxcQMwhDAIHje9gC/M\n7B4ze9vM7jSzzlEXlS/u/iUwAvgE+ByY5+4vRFtVQWzi7rNTj2cDm0RZTIGdCoxt7oJihHocf2Vf\ng5l1BUYD56Va7GXPzA4B5rh7DTFrpae0Iyyc+5O79wO+prx/dV+NmW0F/AroSfjtsauZnRBpUQXm\nYeZHLDPHzC4Blrv7g81dV4xQnwls3uDjzQmt9dgws/bAY8AD7v5k1PXk0R7AYWb2b+AhYB8zuz/i\nmvLpM+Azd38z9fFoQsjHxS7AP919rrvXAY8T/k7jZraZfQfAzDYF5kRcT96Z2UBCN2iLP5SLEepv\nAduYWU8z6wAcAzxVhNctCjMz4G6g1t1vjLqefHL3i919c3fvRRhgG+/uJ0ddV764+3+AT83su6lP\n7Qe8G2FJ+TYN2M3MOqX+ne5HGPCOm6eAX6Qe/wKIU8MKMxtA6AI9PLVzbrMKHuqpFsLZwDjCP6hH\n3D02MwyAPYETCdsS16T+NB69jos4/lp7DvAXM5tMmP0yPOJ68sbdJwP3ExpWU1KfviO6ilrPzB4C\n/glsa2afmtkpwNXA/mY2Hdgn9XFZSvP+TgX+CHQFnk/ly5+avYcWH4mIxIeOsxMRiRGFuohIjCjU\nRURiRKEuIhIjCnURkRhRqIuIxIhCXUQkRhTqIiIx8v8Vgvfb/wL3IwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x10e4d6b50>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot(grid_finder.generator.grid_scores_.values())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimizing complex models + using custom scorer\n",
    "__REP__ supports the sklearn way of combining estimators and getting/setting their parameters, so you can tune complex models using the same approach.\n",
    "\n",
    "In this example we:\n",
    "* optimize a BaggingRegressor over an XGBoost regressor, selecting appropriate parameters for both\n",
    "* write a new scorer that evaluates every candidate on a dedicated hold-out part of the dataset\n",
    "* use the same data, split once into train and test parts (this testing scenario is sometimes needed)\n",
    "* optimize MAE (mean absolute error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.ensemble import BaggingRegressor\n",
    "from rep.estimators import XGBoostRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from rep.utils import train_test_split\n",
    "# splitting into train and test\n",
    "train_data, test_data, train_labels, test_labels = train_test_split(data, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_absolute_error\n",
    "from sklearn.base import clone\n",
    "\n",
    "class MyMAEScorer(object):\n",
    "    def __init__(self, test_data, test_labels):\n",
    "        self.test_data = test_data\n",
    "        self.test_labels = test_labels\n",
    "        \n",
    "    def __call__(self, base_estimator, params, X, y, sample_weight=None):\n",
    "        cl = clone(base_estimator)\n",
    "        cl.set_params(**params)\n",
    "        cl.fit(X, y)\n",
    "        # Returning with minus, because we maximize metric\n",
    "        return - mean_absolute_error(self.test_labels, cl.predict(self.test_data))"
   ]
  },
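  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sign convention used by the scorer above can be written out explicitly. As a sketch, with $y_i$ the test labels, $\\hat{y}_i$ the predictions and $n$ the number of test events (notation introduced here for illustration):\n",
    "\n",
    "$$\\mathrm{MAE} = \\frac{1}{n}\\sum_{i=1}^{n} \\left| y_i - \\hat{y}_i \\right|, \\qquad \\mathrm{score} = -\\mathrm{MAE}$$\n",
    "\n",
    "The search maximizes the score, so returning $-\\mathrm{MAE}$ makes it prefer the parameters with the smallest error."
   ]
  },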
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 27.4 s, sys: 300 ms, total: 27.7 s\n",
      "Wall time: 28 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# define grid parameters\n",
    "grid_param = {\n",
    "    # parameters of sklearn Bagging\n",
    "    'n_estimators': [1, 3, 5, 7], \n",
    "    'max_samples': [0.2, 0.4, 0.6, 0.8],\n",
    "    # parameters of base (XGBoost)\n",
    "    'base_estimator__n_estimators': [10, 20, 40], \n",
    "    'base_estimator__eta': [0.1, 0.2, 0.4, 0.6, 0.8]\n",
    "}\n",
    "\n",
    "# using Gaussian Processes \n",
    "generator = RegressionParameterOptimizer(grid_param, n_evaluations=10, regressor=GaussianProcess(), n_attempts=10)\n",
    "\n",
    "estimator = BaggingRegressor(XGBoostRegressor(), n_estimators=10)\n",
    "\n",
    "scorer = MyMAEScorer(test_data, test_labels)\n",
    "\n",
    "grid_finder = GridOptimalSearchCV(estimator, generator, scorer, parallel_profile=None)\n",
    "grid_finder.fit(data, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.158:  n_estimators=3, max_samples=0.6, base_estimator__n_estimators=40, base_estimator__eta=0.8\n",
      "-0.161:  n_estimators=3, max_samples=0.6, base_estimator__n_estimators=40, base_estimator__eta=0.6\n",
      "-0.168:  n_estimators=3, max_samples=0.4, base_estimator__n_estimators=40, base_estimator__eta=0.6\n",
      "-0.169:  n_estimators=3, max_samples=0.4, base_estimator__n_estimators=40, base_estimator__eta=0.4\n",
      "-0.179:  n_estimators=1, max_samples=0.8, base_estimator__n_estimators=40, base_estimator__eta=0.2\n",
      "-0.182:  n_estimators=1, max_samples=0.4, base_estimator__n_estimators=40, base_estimator__eta=0.2\n",
      "-0.184:  n_estimators=1, max_samples=0.6, base_estimator__n_estimators=40, base_estimator__eta=0.2\n",
      "-0.184:  n_estimators=1, max_samples=0.6, base_estimator__n_estimators=20, base_estimator__eta=0.6\n",
      "-0.190:  n_estimators=1, max_samples=0.8, base_estimator__n_estimators=10, base_estimator__eta=0.8\n",
      "-0.321:  n_estimators=1, max_samples=0.2, base_estimator__n_estimators=10, base_estimator__eta=0.1\n"
     ]
    }
   ],
   "source": [
    "grid_finder.generator.print_results()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Summary\n",
    "Grid search in __REP__ extends sklearn grid search and uses optimization techniques to avoid an exhaustive search over estimator parameters. \n",
    "\n",
    "__REP__ has predefined scorers, metric functions and optimization techniques. Each component is replaceable, and you can optimize complex models and pipelines (Folders/Bagging/Boosting and so on). \n",
    "\n",
    "## How the components fit together\n",
    "* _ParameterOptimizer_ is responsible for generating the next set of parameters to check\n",
    "  * RandomParameterOptimizer\n",
    "  * AnnealingParameterOptimizer\n",
    "  * SubgridParameterOptimizer\n",
    "  * RegressionParameterOptimizer (this one can use any regression model, such as Gaussian processes)\n",
    "  \n",
    "* _Scorer_ is responsible for training an estimator and evaluating the metric\n",
    "  * FoldingScorer (works with metrics that follow the __REP__ interface) averages the quality over the k folds\n",
    "  \n",
    "* _GridOptimalSearchCV_ makes all of these work together and sends tasks to a cluster or to separate threads.\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}