{ "metadata": { "name": "" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [
"# Self-training on Covertype with a normality filter\n",
"\n",
"Outline of the cells below:\n",
"\n",
"1. Load Covertype, hold out 10% as a test set, keep 10,000 labeled samples and treat the rest of the development split as unlabeled.\n",
"2. Train a *normality* classifier that separates original rows from rows whose columns were shuffled independently.\n",
"3. Corrupt the labeled samples, keep those the normality classifier still scores above 0.5, and use them as the unlabeled set of a self-training loop."
] },
{ "cell_type": "code", "collapsed": false, "input": [
"%matplotlib inline\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from sklearn.ensemble import ExtraTreesClassifier\n",
"from sklearn.datasets import fetch_covtype\n",
"try:\n",
"    from sklearn.model_selection import train_test_split\n",
"    from sklearn.model_selection import cross_val_score\n",
"except ImportError:\n",
"    # Releases that predate sklearn.model_selection only provide\n",
"    # the sklearn.cross_validation module.\n",
"    from sklearn.cross_validation import train_test_split\n",
"    from sklearn.cross_validation import cross_val_score\n",
"from sklearn.utils import shuffle\n",
"from sklearn.base import BaseEstimator\n",
"from sklearn.base import clone"
], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 99 },
{ "cell_type": "code", "collapsed": false, "input": [
"covtype = fetch_covtype()\n",
"X, y = covtype.data, covtype.target\n",
"\n",
"print(X.shape)\n",
"print(np.unique(y))"
], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "(581012, 54)\n", "[1 2 3 4 5 6 7]\n" ] } ], "prompt_number": 37 },
{ "cell_type": "code", "collapsed": false, "input": [
"# Hold out 10% of the samples for the final evaluation.\n",
"# The fixed random_state makes the split identical on every run.\n",
"X_dev, X_test, y_dev, y_test = train_test_split(X, y, test_size=0.1, random_state=0)"
], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 13 },
{ "cell_type": "code", "collapsed": false, "input": [
"# Keep 10,000 labeled samples; the labels of the remaining development\n",
"# samples are discarded so that they can play the role of unlabeled data.\n",
"X_small, X_unlabeled, y_small, _ = train_test_split(X_dev, y_dev, train_size=10000, random_state=0)"
], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 30 },
{ "cell_type": "code", "collapsed": false, "input": [
"%%time\n",
"\n",
"# Supervised baseline fitted on the small labeled set only.\n",
"etrees = ExtraTreesClassifier(n_estimators=80, n_jobs=4, random_state=0)\n",
"scores = cross_val_score(etrees, X_small, y_small, cv=5)\n",
"\n",
"print(\"5-folds cv score: %0.3f+/-%0.3f\" % (np.mean(scores), np.std(scores)))"
], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "5-folds cv score: 0.836+/-0.005\n", "CPU times: user 14.5 s, sys: 1.99 s, total: 16.5 s\n", "Wall time: 6.12 s\n" ] } ], "prompt_number": 36 },
{ "cell_type": "code",
"collapsed": false, "input": [ "np.random.uniform(size=(3, 4))" ], "language": "python", "metadata": {}, "outputs": [ { "metadata": {}, "output_type": "pyout", "prompt_number": 68, "text": [ "array([[ 0.29906499, 0.61963946, 0.07382687, 0.51198165],\n", " [ 0.75008411, 0.32665691, 0.38846908, 0.26959562],\n", " [ 0.56896242, 0.1422773 , 0.06123208, 0.77610519]])" ] } ], "prompt_number": 68 },
{ "cell_type": "code", "collapsed": false, "input": [
"def shuffle_columns(X, copy=True, seed=0):\n",
"    \"\"\"Permute the values of every column of ``X`` independently.\n",
"\n",
"    Each column keeps its own set of values, but the pairing of values\n",
"    across the columns of a row is broken.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    X : 2D numpy array\n",
"    copy : bool\n",
"        If False, ``X`` is shuffled in place.\n",
"    seed : int\n",
"        Seed of the ``RandomState`` that draws the permutations.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The shuffled array (``X`` itself when ``copy`` is False).\n",
"    \"\"\"\n",
"    rng = np.random.RandomState(seed)\n",
"    if copy:\n",
"        X = X.copy()\n",
"    for i in range(X.shape[1]):\n",
"        # X[:, i] is a view, so the shuffle writes straight into X.\n",
"        rng.shuffle(X[:, i])\n",
"    return X\n",
"\n",
"\n",
"def corrupt(X, copy=True, rate=0.1, seed=0):\n",
"    \"\"\"Overwrite a random subset of the entries of ``X`` with shuffled values.\n",
"\n",
"    Every entry is selected independently with probability ``rate`` and\n",
"    replaced by the entry found at the same position in a column-shuffled\n",
"    copy of ``X``.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    X : 2D numpy array\n",
"    copy : bool\n",
"        If False, ``X`` is modified in place.\n",
"    rate : float\n",
"        Probability that an entry is replaced.\n",
"    seed : int\n",
"        Seed used for both the entry selection and the column shuffling.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The corrupted array (``X`` itself when ``copy`` is False).\n",
"    \"\"\"\n",
"    rng = np.random.RandomState(seed)\n",
"    if copy:\n",
"        X = X.copy()\n",
"    # Forward the caller's seed so the shuffle is not silently pinned to 0.\n",
"    X_shuffled = shuffle_columns(X, seed=seed)\n",
"    mask = rng.uniform(size=X.shape) < rate\n",
"    X[mask] = X_shuffled[mask]\n",
"    return X\n",
"\n",
"\n",
"def make_normality_problem(X, seed=0):\n",
"    \"\"\"Build a binary problem: original rows (1) versus column-shuffled rows (0).\n",
"\n",
"    The rows of ``X`` are stacked on top of a column-shuffled copy of ``X``\n",
"    and the stacked data and its target are then shuffled jointly.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The jointly shuffled data and target, as returned by\n",
"    ``sklearn.utils.shuffle``.\n",
"    \"\"\"\n",
"    data = np.vstack([X, shuffle_columns(X, seed=seed)])\n",
"    # The builtin int replaces the np.int alias, which recent NumPy no longer has.\n",
"    target = np.zeros(X.shape[0] * 2, dtype=int)\n",
"    target[:X.shape[0]] = 1\n",
"    return shuffle(data, target, random_state=seed)\n",
"\n",
"\n",
"X_normal, y_normal = make_normality_problem(X_small)"
], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 74 },
{ "cell_type": "code", "collapsed": false, "input": [
"%%time\n",
"\n",
"# Fixed random_state so that the normality model is reproducible.\n",
"etrees_normality = ExtraTreesClassifier(n_estimators=80, n_jobs=4, random_state=0)\n",
"scores = cross_val_score(etrees_normality, X_normal, y_normal, cv=5)\n",
"\n",
"print(\"5-folds cv score: %0.3f+/-%0.3f\" % (np.mean(scores), np.std(scores)))"
], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "5-folds cv score: 0.978+/-0.001\n", "CPU times: user 32 s, sys: 464 ms, total: 32.5 s\n", "Wall time: 10 s\n" ] } ], "prompt_number": 75 },
{ "cell_type": "code", "collapsed": false, "input": [ "%%time\n", "\n", "_ = etrees_normality.fit(X_normal, y_normal)"
], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "CPU times: user 8.47 s, sys: 55.8 ms, total: 8.53 s\n", "Wall time: 2.38 s\n" ] } ], "prompt_number": 76 }, { "cell_type": "code", "collapsed": false, "input": [ "X_corrupted = corrupt(X_small, rate=0.2)\n", "\n", "predicted_normality = etrees_normality.predict_proba(X_corrupted)[:, 1]\n", "_ = plt.hist(predicted_normality, bins=30)\n", "\n", "X_new_unlabeled = X_corrupted[predicted_normality > 0.5]\n", "print(X_new_unlabeled.shape)" ], "language": "python", "metadata": {}, "outputs": [ { "metadata": {}, "output_type": "display_data", "png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEACAYAAABbMHZzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAEwpJREFUeJzt3X+MHOVhh/Fni+0WWh/GdWVsnyu7xm4xIlFCsEkimkuT\nWhfU2lYrYWhDncSqKrlN0qpNgymC4w8oSdU0RJGpGmQwUbnKJcgyKVgYEquoCTghhBgO1z8Ut75L\nfSSBYrdqwJa3f7zv5cbH3e3M7N3s3r7PRxp29t13Z17Gt999552ZHZAkSZIkSZIkSZIkSZIkSdIM\ntQMYBg6OKf848DLwIvCZTPk24AhwCFiXKb8qLuMIcM90NVaSNDWuBd7B+eH/fmAfMDs+/6X4uBr4\nbixfBhwFavG1A8CaOP8Y0DttLZYkTYllnB/+u4DfGKfeNuDTmed7gWuARYS9hBE3AH8/tU2UJBXx\nMyXesxL4deAZYD/wrli+GBjM1BsEloxTPhTLJUktMqvkey4h9OqvJuwJ/MpUNkqSNL3KhP8g8Eic\n/xZwDlhA6NEvzdTrjnWH4ny2fGi8Ba9YsaJ+7NixEk2SpKQdAy4r8oYywz67GR3zXwXMAX4E7CGM\n588BlhOGhw4AJ4FTwFrCAeCb4jLe4tixY9Trdad6ndtvv73lbWiXyW3htnBbTD4BK4oGeaOefz/w\nPuAXgRPAbYTTP3cQDgK/CfxBrDtAGAIaAM4CW4F6fG0r8ABwIeFsn71FGypJmjqNwv/GCcpvmqD8\nrjiN9RxwZd5GSZKmV5lhH1Wgp6en1U1oG26LUW6LUW6L5tQaV6lUPY5fSZJyqtVqUDDP7flLUoIM\nf0lKkOEvSW2mq2s+tVot91SGY/6S1GZCoBfJQsf8JUk5GP6SlCDDX5ISZPhLUoIMf0lKkOEvSQky\n/CUpQYa/JCXI8JekBBn+kpQgw1+SEmT4S1KCDH9JSlCj8N8BDBNu1j7WnwPngPmZsm3AEeAQsC5T\nflVcxhHgnrKNlSRNjUbhfz/QO075UuA3gf/IlK0GNsXHXmA7oz8xei+wBVgZp/GWKUmqSKPwfxp4\nbZzyzwF/OaZsA9APnAGOA0eBtcAiYC5wINZ7ENhYrrmSpKlQZsx/AzAIfG9M+eJYPmIQWDJO+VAs\nlyS1yKyC9S8CbiEM+Yxot7uBSZIaKBr+K4BlwAvxeTfwHGF4Z4hwLIDMa4OxvHtM+dBEK+jr6/vp\nfE9PDz09PQWbKEmdbn+cysvTa18GPApcOc5r3yecyfMq4UDvQ8AawrDOk8BlhBtRPgt8gjDu/y/A\nF4C94yzPe/hKS
l473MO3H/gGsAo4AXx0zOvZ1g0Au+Lj48DWzOtbgfsIp3oeZfzglyRVpN3G6+35\nS0peO/T8JUkdyPCXpAQZ/pKUIMNfkhJk+EtSggx/SUqQ4S9JCTL8JSlBhr8kJcjwl6QEGf6SlCDD\nX5ISVPT3/Kfd448/nqve1VdfzYIFC6a5NZLUmdruVz0vvrjxvd1/8pPD3HbbFm655ZYKmiRJ1ari\nVz3bruf/+uuNe/612q2cO3eugtZIUmdyzF+SEmT4S1KCDH9JSpDhL0kJahT+O4Bh4GCm7G+Al4EX\ngEeAizOvbSPcpP0QsC5TflVcxhHgnuaaLElqVqPwvx8Ye+7lE8AVwNuBw4TAB1gNbIqPvcB2Rk89\nuhfYAqyMU+PzOSVJ06ZR+D8NvDambB8wcp7ls0B3nN8A9ANngOPAUWAtsAiYCxyI9R4ENjbTaElS\nc5od8/8Y8FicXwwMZl4bBJaMUz4UyyVJLdLMRV5/BbwJPDRFbYn6MvM9cZIkjdofp/LKhv9HgOuA\nD2TKhoClmefdhB7/EKNDQyPlQxMvuq9kkyQpFT2c3zG+o/ASygz79AKfIozx/yRTvge4AZgDLCcc\n2D0AnAROEcb/a8BNwO4S65UkTZFGPf9+4H3AAuAEcDvh7J45hAO/AN8EtgIDwK74eDaWjfwy0Vbg\nAeBCwjGCvVP1PyBJKq5R+N84TtmOSerfFaexngOuzNsoSdL08gpfSUqQ4S9JCTL8JSlBhr8kJcjw\nl6QEGf6SlCDDX5ISZPhLUoIMf0lKkOEvSQky/CUpQYa/JCXI8JekBBn+kpQgw1+SEmT4S1KCDH9J\nSpDhL0kJMvwlKUGNwn8HMAwczJTNJ9y8/TDwBDAv89o24AhwCFiXKb8qLuMIcE9zTZYkNatR+N8P\n9I4pu5kQ/quAp+JzgNXApvjYC2wHavG1e4EtwMo4jV2mJKlCjcL/aeC1MWXrgZ1xfiewMc5vAPqB\nM8Bx4CiwFlgEzAUOxHoPZt4jSWqBMmP+CwlDQcTHhXF+MTCYqTcILBmnfCiWS5JaZFaT76/HaQr1\nZeZ74iRJGrU/TuWVCf9h4FLgJGFI55VYPgQszdTrJvT4h+J8tnxo4sX3lWiSJKWkh/M7xncUXkKZ\nYZ89wOY4vxnYnSm/AZgDLCcc2D1A+JI4RRj/rwE3Zd4jSWqBRj3/fuB9wALgBHAbcDewi3D2znHg\n+lh3IJYPAGeBrYwOCW0FHgAuBB4D9k5R+yVJJTQK/xsnKP/gBOV3xWms54Ar8zZKkjS9vMJXkhJk\n+EtSggx/SUqQ4S9JCTL8JSlBhr8kJcjwl6QEGf6SlCDDX5ISZPhLUoIMf0lKkOEvSQky/CUpQYa/\nJCXI8JekBBn+kpQgw1+SEmT4S1KCDH9JSlAz4b8NeAk4CDwE/CwwH9gHHAaeAOaNqX8EOASsa2K9\nkqQmlQ3/ZcAfAu8k3Jj9AuAG4GZC+K8CnorPAVYDm+JjL7C9iXVLkppUNoBPAWeAi4BZ8fEHwHpg\nZ6yzE9gY5zcA/fE9x4GjwJqS65YkNals+L8K/C3wn4TQ/29Cj38hMBzrDMfnAIuBwcz7B4ElJdct\nSWrSrJLvWwH8KWH453Xgn4EPj6lTj9NEJnitLzPfEydJ0qj9cSqvbPi/C/gG8OP4/BHg3cBJ4NL4\nuAh4Jb4+BCzNvL87lo2jr2STJCkVPZzfMb6j8BLKDvscAq4BLgRqwAeBAeBRYHOssxnYHef3EA4I\nzwGWAyuBAyXXLUlqUtme/wvAg8C3gXPAd4B/AOYCu4AthAO718f6A7F8ADgLbGXyISFJ0jQqG/4A\nn41T1quEvYDx3BUnSVKLea69JCWo1uoGjFHPMxpUq93K7Nmf5803/zfXQufOvYRTp15ttm2SVIla\nrUaxkfHaT/+TVzPDPi0Vgj/fxjl9ut2+4ySptRz2kaQEGf6SlCDDX5ISZPhLUoI
Mf0lKkOEvSQky\n/CUpQYa/JCXI8JekBBn+kpQgw1+SEmT4S1KCDH9JSpDhL0kJMvwlKUGGvyQlqJnwnwc8DLxMuDH7\nWmA+sA84DDwR64zYBhwBDgHrmlivJKlJzYT/PcBjwOXA2wihfjMh/FcBT8XnAKuBTfGxF9je5Lol\nSU0oG8AXA9cCO+Lzs8DrwHpgZyzbCWyM8xuAfuAMcBw4CqwpuW5JUpPKhv9y4IfA/cB3gC8BPw8s\nBIZjneH4HGAxMJh5/yCwpOS6JUlNKnsD91nAO4E/Ab4FfJ7RIZ4RdSa/w/oEr/Vl5nviJEkatT9O\n5ZUN/8E4fSs+f5hwQPckcGl8XAS8El8fApZm3t8dy8bRV7JJkpSKHs7vGN9ReAllh31OAicIB3YB\nPgi8BDwKbI5lm4HdcX4PcAMwhzBktBI4UHLdkqQmle35A3wc+EdCoB8DPgpcAOwCthAO7F4f6w7E\n8gHCweGtTD4kJEmaRrVWN2CMep7vhFrtVur1O8n//VGjXve7RtLMUKvVKNY/rv30P3l5rr0kJcjw\nl6QEGf6SlCDDX5Iq0NU1n1qtlmuqQjNn+0iScjp9+jWKnKQy3ez5S1KCDH9JSpDhL0kJMvwlKUGG\nvyQlyPCXpAQZ/pKUIMN/jCIXYnR1zW91cyWpFC/yGqPIhRinT7fbj6JKUj72/CUpQYa/JCXI8Jek\nBBn+kpQgw1+SEtRs+F8APA88Gp/PB/YBh4EngHmZutuAI8AhYF2T65UkNaHZ8P8kMMDouZE3E8J/\nFfBUfA6wGtgUH3uB7VOwbklSSc0EcDdwHXAfo3ceWA/sjPM7gY1xfgPQD5wBjgNHgTVNrFuS1IRm\nwv/vgE8B5zJlC4HhOD8cnwMsBgYz9QaBJU2sW5Jart1uzVhE2St8fwt4hTDe3zNBnTqTXyo7wWt9\nmfmeSRYvSa3Vulsz7o9TeWXD/z2EIZ7rgJ8DuoAvE3r7lwIngUWELwiAIWBp5v3dsWwcfSWbJEmp\n6OH8jvEdhZdQdtjnFkKYLwduAL4G3ATsATbHOpuB3XF+T6w3J75nJXCg5LolSU2aqh92G9nvuRvY\nBWwhHNi9PpYPxPIB4Cywlfz7SpKkKdZuRyHqeb4TarVbqdfvpMhYW72er244MDP1y5XUeYrmxfTU\nHalfLM89174ps3If6ff3/yW1E3/PvylnKfLt7O//S2oX9vwlKUGJ9PxnteVFFpLUKomEf5HhGb8k\nJHU+h30kKUGGvyQlyPCXpAQZ/pKUIMNfkhJk+EuakYr8lr5X179VIqd6Suo0RX5L//Tp2bmv9Zk7\n9xJOnXq1iZbNDIa/pATkv9YnlZ9hcdinTblLK2k62fNvU8V2adPoqUiaOvb8JU2bInuw7sVWy56/\npGlT7Abn7sVWyZ6/JCWobPgvBb4OvAS8CHwils8H9gGHgSeAeZn3bAOOAIeAdSXXq3Hlv6OYu9WS\noHz4nwH+DLgCuAb4Y+By4GZC+K8CnorPAVYDm+JjL7C9iXXrLUZOY2s8hd1wSRPL35maycoG8Eng\nu3H+f4CXgSXAemBnLN8JbIzzG4B+wpfGceAosKbkuiVpGuXvTM1kU9H7Xga8A3gWWAgMx/Lh+Bxg\nMTCYec8g4csiMWn0KNQevFZEk2n2bJ9fAL4CfBI4Pea1Rl+NM/trsxTvKKbqeK2IJtNM+M8mBP+X\ngd2xbBi4lDAstAh4JZYPEQ4Sj+iOZePoy8z3xEmt0NU1P/cxglR+D0XF/i40XfbHqbyyX/c1wpj+\njwkHfkd8NpZ9hnCwd158XA08RBjnXwI8CVzGW7sl9Tw9lVrtVur1O8d5+2TNbXXddmlHjXo9X90w\n/JR3ubMJezb5+GVRTvHgbfXfRdHPSLG/o3b4PLW+7kj9Ynletuf/XuDDwPeA52PZNuBuYBewhXBg\n9/r42kAsHyD8y24lyWGfTlZkSKvYMIN7IKO
KXTQ1E4dyHBqtSrttPXv+FdRtjx7e9LUj73KLKNbj\nnk04sS2fIl9Y0/dv0j5/F+3wGZlZdUfqV9Pzl5JSvMftTxqovXmhlSQlyPCXpAQ57KMOM8vb9ZWS\nf7upMxj+yen0D7m36yvHs2xSY/gnxw+5JMNfagOdvjemdmT4K2HtErrujal6hr8SZugqXYa/WqRd\net1Smgx/tYi9bqmVvMhLkhJk+EtSggx/SUqQ4S9JCTL8JSlBhr8kJcjwl6QEGf6SlKCqw78XOAQc\nAT5d8bolSVGV4X8B8EXCF8Bq4Ebg8grXL0mKqgz/NcBR4DhwBvgnYEOF65ckRVWG/xLgROb5YCyT\nJFWsyh92y/UrXl1dv92wzhtvHOKNN5pujyQlq8rwHwKWZp4vJfT+s46dOvXVFfkXWeTXHtuhbru0\nox3qtks72qFuu7SjHeq2SztmWl2OFalctVmEBi4D5gDfxQO+kpSEDwH/Tjjwu63FbZEkSZI03fJc\n7PWF+PoLwDsqalcrNNoWv0/YBt8D/g14W3VNq1TeCwCvJtwG7HeqaFSL5NkWPcDzwIvA/kpa1RqN\ntsUCYC9hGPlF4COVtax6O4Bh4OAkddo6Ny8gDPssA2Yz/tj/dcBjcX4t8ExVjatYnm3xbuDiON9L\nZ26LPNthpN7XgK8Cv1tV4yqWZ1vMA14CuuPzBVU1rmJ5tkUf8NdxfgHwYzr39rTXEgJ9ovAvlJut\n+G2fPBd7rQd2xvlnCX/sCytqX5XybItvAq/H+WcZ/cB3krwXAH4ceBj4YWUtq16ebfF7wFcYPVvu\nR1U1rmJ5tsV/AV1xvosQ/mcral/VngZem+T1QrnZivDPc7HXeHU6MfSKXvi2hdFv9k6S929iA3Bv\nfJ737u8zTZ5tsRKYD3wd+DZwUzVNq1yebfEl4ArgB4Shjk9W07S2VCg3W7F7lPdDO/Yk1078sBf5\nf3o/8DHgvdPUllbKsx0+D9wc69YofvL4TJFnW8wG3gl8ALiIsHf4DGGst5Pk2Ra3EIaDeoAVwD7g\n7cDp6WtWW8udm60I/zwXe42t0x3LOk2ebQHhIO+XCGP+k+32zVR5tsNVhN1+CGO7HyIMBeyZ9tZV\nK8+2OEEY6vm/OP0rIfA6LfzzbIv3AHfG+WPA94FfJewRpabtczPPxV7ZAxfX0JkHOSHftvhlwrjn\nNZW2rFpFLwC8n8492yfPtvg14EnCAdGLCAcAV1fXxMrk2RafA26P8wsJXw7zK2pfKywj3wHfts3N\n8S72+qM4jfhifP0Fwi5up2q0Le4jHMR6Pk4Hqm5gRfL8TYzo5PCHfNviLwhn/BwEPlFp66rVaFss\nAB4l5MRBwsHwTtVPOLbxJmHv72Okm5uSJEmSJEmSJEmSJEmSJEmSJEmSNLP9P5m/8m5+bQgcAAAA\nAElFTkSuQmCC\n", "text": [ "" ] } ], "prompt_number": 90 }, { "cell_type": "code", "collapsed": false, "input": [ "predicted_normality = etrees_normality.predict_proba(X_unlabeled)[:, 1]\n", "_ = plt.hist(predicted_normality, bins=30)" ], "language": "python", "metadata": {}, "outputs": [ { "metadata": {}, "output_type": "display_data", "png": 
"iVBORw0KGgoAAAANSUhEUgAAAY0AAAEACAYAAABPiSrXAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAFPVJREFUeJzt3VuMVdd9x/HvsTEEN1w8psJcBowc3Ia0qR0acJu0OREy\nHvcBcIsMaYtpgypLtI7VqlJNHswQt039kDhEkXmIibm0JaCgGtxQLgGjRFUxToRtbEK5yEjMYHCK\nzSVKL6CcPqz/MJvJMbPmds6Z8fcjbfY6/32ZdXac/WPvtTcDkiRJkiRJkiRJkiRJkiRJkvrgQ8DL\nwKvAEeDLUW8F2oBDMT1Y2GYFcBw4Cswt1GcCh2PZ6kJ9BLA56geAqYVlS4FjMT3SD99HkjTAbo35\nMNJJ/dPASuCvqqw7gxQwtwB3AieAUiw7CMyK9g6gJdrLgWejvQj4drSbgJPA2Jg62pKkOrkpY52f\nxXw4cDPwXnwuVVl3PrAJuAKcIoXGbGACMIoUHAAbgAXRngesj/ZWYE60HwB2Axdi2kNn0EiS6iAn\nNG4iXT2cA14C3oz6Y8BrwFo6rwAmkm5bdWgDJlWpt0edmJ+O9lXgInD7DfYlSaqTnND4OXAPMBn4\nXaAMrAGmRf1t4CsD1D9JUgMZ1oN1LwLfBX4T2F+oPwe8GO12oLmwbDLpCqE92l3rHdtMAc5Ef8YA\n56NeLmzTDOzr2qm77rqrcvLkyR58DUkSaZz4Iz3dqLsrjXF03noaCdxPelrqjsI6D5GeigLYDiwm\njX9MA6aTxjHOApdI4xslYAmwrbDN0mgvBPZGezfp6auxwG3xs3d17eDJkyepVCpOlQorV66sex8a\nZfJYeCw8FjeegLu6Of9X1d2VxgTSIPVNMW2Mk/oG0q2pCvAW8GisfwTYEvOrpCejKrFsObCOFD47\ngJ1RXxv7PU66wlgc9XeBp4BX4vMq0oC4JKlOuguNw8AnqtRv9M7E38fU1Y+AX69S/1/g4ffZ1/Mx\nSZIaQM5AuAaJcrlc7y40DI9FJ49FJ49F31V712KwqcT9OUlSplKpBL3IAK80JEnZDA1JUjZDQ5KU\nzdCQJGUzNCRJ2QwNSVI2Q0OSlM3QkCRlMzQkSdkMDUlSNkNDkpTN0JAkZTM0JEnZDA1JUjZDQ5KU\nzdCQpCFi9OgmSqVS1tRb/hImSRoiUhjkng/9JUySpAFmaEiSshkakqRshoYkKVt3ofEh4GXgVeAI\n8OWoNwF7gGPAbmBsYZsVwHHgKDC3UJ8JHI5lqwv1EcDmqB8AphaWLY2fcQx4JPM7SZIGSHeh8T/A\nZ4F7gI9H+9PAE6TQuBvYG58BZgCLYt4CPEvn6PwaYBkwPaaWqC8DzkftGeDpqDcBTwKzYlrJ9eEk\nSaqxnNtTP4v5cOBm4D1gHrA+6uuBBdGeD2wCrgCngBPAbGACMAo4GOttKGxT3NdWYE60HyBdxVyI\naQ+dQSNJqoOc0LiJdHvqHPAS8CYwPj4T8/HRngi0FbZtAyZVqbdHnZifjvZV4CJw+w32JUmqk2EZ\n6/ycdHtqDLCLdIuqqEL+2yQDorW19Vq7XC5TLpfr1hdJakz7Y+qbnNDocBH4LmlA+xxwB3CWdOvp\nnVinHWgubDOZdIXQHu2u9Y5tpgBnoj9jSGMc7UC5sE0zsK9ax4qhIUmqpsz1p9RVvdpLd7enxtE5\n+DwSuB84BGwnPdlEzF+I9nZgMWn8YxppcPsgKVwukcY3SsASYFthm459LSQNrEMaz5gbP/+2+Nm7\nevj9JEn9qLsrjQmkQeqbYtpIOqkfAraQnnw6BTwc6x+J+hHS+MRyOm9dLQfWkcJnB7Az6mtjv8dJ\nVxiLo/4u8BTwSnxeRRoQlyTVif9goSQNEf6DhZKkhmJoSJKyGRqSpGyGhiQpm6EhScpmaEiSshka\nkqRshoYkKZuhIUnKZmhIkrIZGpKkbIaGJCmboSFJymZoSJKyG
RqSpGyGhiQpm6EhScpmaEiSshka\nkqRshoYkKZuhIUnKZmhIkrIZGpKkbN2FRjPwEvAm8Abwhai3Am3AoZgeLGyzAjgOHAXmFuozgcOx\nbHWhPgLYHPUDwNTCsqXAsZgeyftKkqSBUupm+R0xvQp8GPgRsAB4GLgMfLXL+jOAfwY+CUwCvgdM\nByrAQeAvYr4D+DqwE1gO/FrMFwEPAYuBJuAVUtgQP3smcKHLz6xUKpXMrytJQ1epVCKdbrPWvvZH\nT3R3pXGWFBgAPwV+TAqD9/th84FNwBXgFHACmA1MAEaRAgNgAyl8AOYB66O9FZgT7QeA3aSQuADs\nAVq6/0qSpIHSkzGNO4F7SbeQAB4DXgPWAmOjNpF026pDGylkutbb6QyfScDpaF8FLgK332BfkqQ6\nGZa53oeB7wCPk6441gBfimVPAV8BlvV77zK1trZea5fLZcrlcr26IkkNan9MfZMTGreQbhv9I/BC\n1N4pLH8OeDHa7aTB8w6TSVcI7dHuWu/YZgpwJvozBjgf9XJhm2ZgX7UOFkNDklRNmetPqat6tZfu\nbk+VSLefjgBfK9QnFNoPkZ6KAthOGsQeDkwjDYIfJI2NXCKNb5SAJcC2wjZLo70Q2Bvt3aSnr8YC\ntwH3A7uyv5kkqd91d6XxKeCPgddJj9YCfBH4HHAPaZj+LeDRWHYE2BLzq6QnojqG8pcD64CRpKen\ndkZ9LbCR9MjteVLoALxLuvX1SnxexS8+OSVJqqEeP27VgHzkVpJojEduJUm6xtCQJGUzNCRJ2QwN\nSVI2Q0OSlM3QkCRlMzQkSdkMDUlSNkNDkpTN0JAkZTM0JEnZDA1JUjZDQ5KUzdCQJGUzNCRJ2QwN\nSVI2Q0OSlM3QkCRlMzQkSdkMDUlSNkNDkpTN0JAkZTM0JEnZuguNZuAl4E3gDeALUW8C9gDHgN3A\n2MI2K4DjwFFgbqE+Ezgcy1YX6iOAzVE/AEwtLFsaP+MY8Ejmd5IkDZDuQuMK8JfAx4D7gD8HPgo8\nQQqNu4G98RlgBrAo5i3As0Aplq0BlgHTY2qJ+jLgfNSeAZ6OehPwJDArppVcH06SpBrrLjTOAq9G\n+6fAj4FJwDxgfdTXAwuiPR/YRAqbU8AJYDYwARgFHIz1NhS2Ke5rKzAn2g+QrmIuxLSHzqCRJNVB\nT8Y07gTuBV4GxgPnon4uPgNMBNoK27SRQqZrvT3qxPx0tK8CF4Hbb7AvSVKdDMtc78Okq4DHgctd\nllViqpvW1tZr7XK5TLlcrltfJKkx7Y+pb3JC4xZSYGwEXojaOeAO0u2rCcA7UW8nDZ53mEy6QmiP\ndtd6xzZTgDPRnzGkMY52oFzYphnYV62DxdCQJFVT5vpT6qpe7aW721MlYC1wBPhaob6d9GQTMX+h\nUF8MDAemkQa3D5LC5RJpfKMELAG2VdnXQtLAOqTxjLmkwe/bgPuBXT35cpKk/lXqZvmnge8Dr9N5\nC2oFKQi2kK4QTgEPkwarAb4IfJ40PvE4nSf6mcA6YCSwg87Hd0eQrmLuJV1hLI59Avxp7A/gb+kc\nMC+qVCp1vTsmSQ2hVCqRP1pQuvZHj35GTzdoQIaGJFGb0PCNcElSNkNDkpTN0JAkZTM0JEnZDA1J\nUjZDQ5KUzdCQJGUzNCRJ2QwNSVI2Q0OSlM3QkCRlMzQkSdkMDUlSNkNDkpTN0JAkZTM0JEnZDA1J\nUjZDQ5KUzdCQJGUzNCRJ2QwNSVI2Q0OSlM3QkCRlywmNbwHngMOFWivQBhyK6cHCshXAceAoMLdQ\nnxn7OA6sLtRHAJujfgCYWli2FDgW0yMZfZUkDaCc0HgeaOlSqwBfBe6N6d+iPgNYFPMW4FmgFMvW\nAMuA6TF17HMZcD5qzwBPR70JeBKYFdNKYGz2N5Mk9buc0PgB8F6VeqlKbT6wCbgCnAJOALOBCcAo\n4GCstwFYEO15wPpobwXmR
PsBYDdwIaY9/GJ4SZJqqC9jGo8BrwFr6bwCmEi6bdWhDZhUpd4edWJ+\nOtpXgYvA7TfYlySpTob1crs1wJei/RTwFdJtprpobW291i6Xy5TL5Xp1RZIa1P6Y+qa3ofFOof0c\n8GK024HmwrLJpCuE9mh3rXdsMwU4E/0ZQxrjaAfKhW2agX3VOlMMDUlSNWWuP6Wu6tVeent7akKh\n/RCdT1ZtBxYDw4FppMHtg8BZ4BJpfKMELAG2FbZZGu2FwN5o7yY9fTUWuA24H9jVy/5KkvpBzpXG\nJuAzwDjS2MNKUlzdQ3qK6i3g0Vj3CLAl5leB5bEO0V4HjAR2ADujvhbYSHrk9jwpdADeJd36eiU+\nryINiEuS6qTaE1CDTaVSqXS/liQNcaVSic6/p3e79rU/esI3wiVJ2QwNSVI2Q0OSlM3QkCRlMzQk\nSdkMDUlSNkNDkpTN0JAkZTM0JEnZDA1JUjZDQ5KUzdCQJGUzNCRJ2QwNSVI2Q0OSlM3QkCRlMzQk\nSdkMDUlSNkNDkpTN0JAkZTM0JEnZDA1JUjZDQ5KULSc0vgWcAw4Xak3AHuAYsBsYW1i2AjgOHAXm\nFuozYx/HgdWF+ghgc9QPAFMLy5bGzzgGPJLRV0nSAMoJjeeBli61J0ihcTewNz4DzAAWxbwFeBYo\nxbI1wDJgekwd+1wGnI/aM8DTUW8CngRmxbSS68NJklRjOaHxA+C9LrV5wPporwcWRHs+sAm4ApwC\nTgCzgQnAKOBgrLehsE1xX1uBOdF+gHQVcyGmPfxieEmSaqi3YxrjSbesiPn4aE8E2grrtQGTqtTb\no07MT0f7KnARuP0G+5Ik1cmwfthHJaa6aW1tvdYul8uUy+W69UWSGtP+mPqmt6FxDrgDOEu69fRO\n1NuB5sJ6k0lXCO3R7lrv2GYKcCb6M4Y0xtEOlAvbNAP7qnWmGBqSpGrKXH9KXdWrvfT29tR20pNN\nxPyFQn0xMByYRhrcPkgKl0uk8Y0SsATYVmVfC0kD65DGM+aSBr9vA+4HdvWyv5KkfpBzpbEJ+Aww\njjT28CTwD8AW0pNPp4CHY90jUT9CGp9YTuetq+XAOmAksAPYGfW1wEbSI7fnSaED8C7wFPBKfF5F\nGhCXJNVJqftVGl6lUqnrkIokNYRSqUT+EHPp2h894RvhkqRshoYkKZuhIUnKZmhIkrIZGpKkbIaG\nJCmboSFJymZoSJKyGRqSpGyGhiQpm6EhScpmaEiSshkakqRshoYkNbjRo5solUrdTrXgP40uSQ0u\n/588959GlyQ1EENDkpTN0JAkZTM0JEnZDA1JUjZDQ5KUzdCQJGUzNCRJ2foaGqeA14FDwMGoNQF7\ngGPAbmBsYf0VwHHgKDC3UJ8JHI5lqwv1EcDmqB8Apvaxv5KkPuhraFSAMnAvMCtqT5BC425gb3wG\nmAEsinkL8CydbyOuAZYB02Nqifoy4HzUngGe7mN/JUl90B+3p7q+hj4PWB/t9cCCaM8HNgFXSFco\nJ4DZwARgFJ1XKhsK2xT3tRWY0w/9laSG0Ej/plSu/rjS+B7wQ+DPojYeOBftc/EZYCLQVti2DZhU\npd4edWJ+OtpXgYuk21+SNOhdvvwe6TTa3dQ4hvVx+08BbwO/TLoldbTL8pp849bW1mvtcrlMuVwe\n6B8pSYPM/pj6pq+h8XbMfwL8C2lc4xxwB3CWdOvpnVinHWgubDuZdIXRHu2u9Y5tpgBnoq9jgHe7\ndqIYGpKkasoxdVjVq7305fbUraSxCIBfIj0NdRjYDiyN+lLghWhvBxYDw4FppMHtg6RwuUQa3ygB\nS4BthW069rWQNLAuSaqTvlxpjCddXXTs559Ij9j+ENhCevLpFPBwrHMk6kdI4xPL6bx1tRxYB4wE\ndgA7o74W2Eh65PY8KXQkSXXSWMPyveMvYZI0KPX/L1fylzBJkhqIoSFJymZoSJKyGRqSpGy
GhiQp\nm6EhScpmaEiSshkakqRshoYkKZuhIUnKZmhIkrIZGpKkbIaGJCmboSFJymZoSJKy9fXXvUrSB8Lo\n0U1cvvxexpq3AFcGujt14y9hkqQM/f8Lk3qyrr+ESZI0CBkakqRshoYkKZuhIUnKZmhIkrINhtBo\nAY4Cx4G/qXNfJA0So0c3USqVMqbhWespafTQuBn4Bik4ZgCfAz5a1x41sP3799e7Cw3DY9Hpg3os\n0jsVlS7TS1VqV6rUqk2Cxg+NWcAJ4BTpf9lvA/Pr2aFG9kE9OVTjseg0GI5Ff18VvP+Vwf5afq0h\nqdHfCJ8EnC58bgNm16kvkujJm9HQs7ejB+TlNfWzRr/SqPs14YEDB7L/ZrNkyZJ6d1eD3ED8jXvV\nqqf6dZ/Vb/u83+Stn6Gm0aP4PqCVNKYBsAL4OfB0YZ0TwF217ZYkDXongY/UuxP9bRjpi90JDAde\nxYFwSdINPAj8J+mKYkWd+yJJkiRpqMl5ye/rsfw14N4a9aseujsWf0Q6Bq8D/w58vHZdq7nclz8/\nCVwFfr8WnaqDnONQBg4BbzC0nz3t7liMA3aSbne/AfxJzXpWe98CzgGHb7DOkDxv3ky6PXUn6Rm+\namMbvwfsiPZs4ECtOldjOcfit4Ax0W7hg30sOtbbB/wr8Ae16lwN5RyHscCbwOT4PK5WnauxnGPR\nCnw52uOA8zT+6we99TukIHi/0OjxebPRH7ntkPOS3zxgfbRfJv2fZHyN+ldLOcfiP4CL0X6ZzhPF\nUJP78udjwHeAn9SsZ7WVcxz+ENhKetcJ4L9q1bkayzkWbwOjoz2aFBpXa9S/WvsBcKOXanp83hws\noVHtJb9JGesMxZNlzrEoWkbn3ySGmtz/LuYDa+LzUHwpIOc4TAeaSP+Oxg+BofpSUc6x+CbwMeAM\n6ZbM47XpWkPq8XlzsFyS9fYV0KF4gujJd/os8HngUwPUl3rLORZfA56IdUs0/rtJvZFzHG4BPgHM\nAW4lXY0eIN3LHkpyjsUXSbetyqR3vPYAvwFcHrhuNbQenTcHS2i0A82Fz810Xma/3zqTozbU5BwL\nSIPf3ySNaeT+mw+DTc6xmEm6RQHp/vWDpNsW2we8d7WTcxxOk25J/XdM3yedKIdaaOQci98G/i7a\nJ4G3gF8hXYF90AzZ82bOS37FAZ37GLqDvznHYgrpvu59Ne1Z7fX05c/nGZpPT+Uch18FvkcaKL6V\nNDA6o3ZdrJmcY/FVYGW0x5NCpalG/auHO8kbCB9y581qL/k9GlOHb8Ty10iX4kNVd8fiOdLg3qGY\nDta6gzWU899Fh6EaGpB3HP6a9ATVYeALNe1dbXV3LMYBL5LOE4dJDwkMVZtIYzf/R7ra/Dwf3POm\nJEmSJEmSJEmSJEmSJEmSJEmSJEmD0/8DKGxa/p8Kz/YAAAAASUVORK5CYII=\n", "text": [ "" ] } ], "prompt_number": 138 },
{ "cell_type": "code", "collapsed": false, "input": [
"class SelfTrainingClassifier(BaseEstimator):\n",
"    \"\"\"Self-training wrapper that refits a classifier on its own pseudo-labels.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    base_estimator : estimator or None\n",
"        Classifier that is cloned at fit time. ``None`` selects an\n",
"        ``ExtraTreesClassifier`` with 100 trees.\n",
"    n_iter : int\n",
"        Number of times the model is fitted.\n",
"    clamp_true_target : bool\n",
"        If True, the labeled samples keep their true labels in every round.\n",
"        If False, they are relabeled with the model's own predictions\n",
"        after each round.\n",
"\n",
"    Attributes\n",
"    ----------\n",
"    estimator_ : the model as fitted in the last round.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, base_estimator=None, n_iter=10, clamp_true_target=False):\n",
"        self.base_estimator = base_estimator\n",
"        self.n_iter = n_iter\n",
"        self.clamp_true_target = clamp_true_target\n",
"\n",
"    def fit(self, X, y, X_unlabeled=None, X_val=None, y_val=None):\n",
"        \"\"\"Fit on ``(X, y)``, then repeatedly refit with pseudo-labeled data.\n",
"\n",
"        The first round uses the labeled data only. Every later round is\n",
"        fitted on the labeled data stacked with ``X_unlabeled``, the latter\n",
"        labeled by the predictions of the previous round.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        X, y : labeled samples and their targets.\n",
"        X_unlabeled : array-like or None\n",
"            Samples without targets. When missing or empty, the model is\n",
"            fitted once on the labeled data.\n",
"        X_val, y_val : optional validation set; when both are given the\n",
"            validation score is printed after every round.\n",
"\n",
"        Returns\n",
"        -------\n",
"        self\n",
"        \"\"\"\n",
"        if self.base_estimator is None:\n",
"            model = ExtraTreesClassifier(n_estimators=100)\n",
"        else:\n",
"            model = clone(self.base_estimator)\n",
"\n",
"        has_unlabeled = X_unlabeled is not None and len(X_unlabeled) > 0\n",
"\n",
"        X_train, y_train = X, y\n",
"\n",
"        for i in range(self.n_iter):\n",
"            model.fit(X_train, y_train)\n",
"\n",
"            if X_val is not None and y_val is not None:\n",
"                print(model.score(X_val, y_val))\n",
"\n",
"            # Nothing to pseudo-label, or no further round that would\n",
"            # consume the pseudo-labels: skip the extra predictions.\n",
"            if not has_unlabeled or i + 1 == self.n_iter:\n",
"                break\n",
"\n",
"            if self.clamp_true_target:\n",
"                y_labeled = y\n",
"            else:\n",
"                y_labeled = model.predict(X)\n",
"\n",
"            X_train = np.vstack([X, X_unlabeled])\n",
"            y_train = np.concatenate([y_labeled, model.predict(X_unlabeled)])\n",
"\n",
"        self.estimator_ = model\n",
"        # Returning self allows `clf = SelfTrainingClassifier(...).fit(...)`.\n",
"        return self\n",
"\n",
"    def predict(self, X):\n",
"        \"\"\"Predict with the model fitted in the last round.\"\"\"\n",
"        return self.estimator_.predict(X)\n",
"\n",
"    def score(self, X, y):\n",
"        \"\"\"Score ``(X, y)`` with the model fitted in the last round.\"\"\"\n",
"        return self.estimator_.score(X, y)"
], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 133 },
{ "cell_type": "code", "collapsed": false, "input": [
"ssc = SelfTrainingClassifier(etrees).fit(X_small, y_small, X_new_unlabeled, X_val=X_test, y_val=y_test)"
], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "0.840573474235\n", "0.834170940759" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n", "0.835203607449" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n", "0.83325875185" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n", "0.832897318509" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n", "0.835117551892" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n", "0.833620185192" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n", "0.833275962962" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n", "0.833585762969" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n", "0.833293174073" ] }, { "output_type": "stream", "stream": "stdout", "text": [ "\n" ] } ], "prompt_number": 137 },
{ "cell_type": "code", "collapsed": false, "input": [], "language": "python", "metadata": {}, "outputs": [] } ], "metadata": {} } ] }