{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# A New Tree Booster: PART\n", "\n", "__12 Feb 2018, marugari__\n", "\n", "PART (Peeking Additive Regression Trees) aims to\n", "* optimize non-differentiable metrics\n", "* avoid over-fitting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train a PART booster, we need to split the data into 3 parts.\n", "1. training set: to search for optimal splits\n", "2. peeking set: to determine whether a new tree is committed\n", "3. validation set: to get the validation score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Repository (https://github.com/marugari/LightGBM/tree/part)](https://github.com/marugari/LightGBM/tree/part)\n", "\n", "[Main contribution (part.hpp)](https://github.com/marugari/LightGBM/blob/part/src/boosting/part.hpp)\n", "\n", "This is implemented as a LightGBM custom booster.\n", "The following is a fork of [the Kaggle Zillow Prize Kernel](https://www.kaggle.com/guolinke/simple-lightgbm-starter-lb-0-06487/code)."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import lightgbm as lgb\n", "import gc" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Raw inputs: one row per transaction (train) and one row per parcel (prop).\n", "train = pd.read_csv('input/zillow/train_2016_v2.csv', engine='python')\n", "prop = pd.read_csv('input/zillow/properties_2016.csv', engine='python')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(90275, 55) (90275,)\n" ] } ], "source": [ "# Downcast float64 property columns to float32 to halve their memory footprint.\n", "for c, dtype in zip(prop.columns, prop.dtypes):\n", "    if dtype == np.float64:\n", "        prop[c] = prop[c].astype(np.float32)\n", "\n", "# Left join: every transaction row is kept, with its parcel's properties attached.\n", "df_train = train.merge(prop, how='left', on='parcelid')\n", "\n", "# Columns removed from the feature matrix: the join key, the target ('logerror'),\n", "# and other columns that are not used as features.\n", "col = [\n", "    'parcelid',\n", "    'logerror',\n", "    'transactiondate',\n", "    'propertyzoningdesc',\n", "    'propertycountylandusecode'\n", "]\n", "x_train = df_train.drop(col, axis=1)\n", "y_train = df_train['logerror'].values\n", "print(x_train.shape, y_train.shape)\n", "train_columns = x_train.columns\n", "\n", "# Object-dtype columns become boolean flags: True only where the value equals True,\n", "# False everywhere else (including missing values).\n", "for c in x_train.dtypes[x_train.dtypes == object].index.values:\n", "    x_train[c] = (x_train[c] == True)\n", "del df_train" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Positional split: the first `split` rows are used for training, the rest for validation.\n", "split = 80000\n", "xt, xv = x_train[:split], x_train[split:]\n", "xt = xt.values.astype(np.float32, copy=False)\n", "xv = xv.values.astype(np.float32, copy=False)\n", "yt, yv = y_train[:split], y_train[split:]\n", "ds_train = lgb.Dataset(xt, label=yt, free_raw_data=False)\n", "ds_valid = lgb.Dataset(xv, label=yv, free_raw_data=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Settings shared by both boosters; the PART cell derives its own copy from these.\n", "prm = {\n", "    'learning_rate': 0.002,\n", "    'boosting_type': 'gbdt',\n", "    'objective': 'regression',\n", "    'metric': 'mae',\n", "    'sub_feature': 0.5,\n", "    'num_leaves': 60,\n", "    'min_data': 500,\n", "    'min_hessian': 1,\n", "}\n", "num_round = 500" ]
}, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Baseline: a GBDT booster trained on ds_train with the shared settings in `prm`.\n", "clf_gbdt = lgb.train(prm, ds_train, num_round)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Copy the shared settings instead of aliasing them: the assignments below would\n", "# otherwise rewrite `prm` itself, and re-running the GBDT cell afterwards would\n", "# silently train a PART booster.\n", "prm_part = dict(prm)\n", "prm_part['boosting_type'] = 'part'\n", "prm_part['learning_rate'] = 0.002\n", "prm_part['drop_rate'] = 0.0\n", "prm_part['skip_drop'] = 0.0\n", "\n", "# Fixed seed so the random training/peeking partition is reproducible.\n", "np.random.seed(20180212)\n", "# Each training row goes to the PART training set with probability 0.7,\n", "# otherwise to the peeking set (the two masks are exact complements).\n", "flg_part = np.random.choice([True, False], len(yt), replace=True, p=[0.7, 0.3])\n", "flg_peek = np.logical_not(flg_part)\n", "ds_part = lgb.Dataset(xt[flg_part], label=yt[flg_part], free_raw_data=False)\n", "ds_peek = lgb.Dataset(xt[flg_peek], label=yt[flg_peek], free_raw_data=False)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1]\tvalid_0's l1: 0.0683414\n", "[2]\tvalid_0's l1: 0.0683379\n", "[3]\tvalid_0's l1: 0.0683343\n", "[4]\tvalid_0's l1: 0.068331\n", "[5]\tvalid_0's l1: 0.0683291\n", "[6]\tvalid_0's l1: 0.0683264\n", "[7]\tvalid_0's l1: 0.0683249\n", "[8]\tvalid_0's l1: 0.0683225\n", "[9]\tvalid_0's l1: 0.06832\n", "[10]\tvalid_0's l1: 0.0683163\n", "[11]\tvalid_0's l1: 0.0683139\n", "[12]\tvalid_0's l1: 0.0683106\n", "[13]\tvalid_0's l1: 0.0683076\n", "[14]\tvalid_0's l1: 0.0683049\n", "[15]\tvalid_0's l1: 0.0683014\n", "[16]\tvalid_0's l1: 0.0682984\n", "[17]\tvalid_0's l1: 0.0682964\n", "[18]\tvalid_0's l1: 0.0682937\n", "[19]\tvalid_0's l1: 0.0682904\n", "[20]\tvalid_0's l1: 0.0682873\n", "[21]\tvalid_0's l1: 0.0682854\n", "[22]\tvalid_0's l1: 0.0682819\n", "[23]\tvalid_0's l1: 0.0682799\n", "[24]\tvalid_0's l1: 0.068277\n", "[25]\tvalid_0's l1: 0.0682755\n", "[26]\tvalid_0's l1: 0.0682727\n", "[27]\tvalid_0's l1: 0.0682709\n", "[28]\tvalid_0's l1: 0.0682689\n", "[29]\tvalid_0's l1: 0.068267\n", "[30]\tvalid_0's l1: 0.0682641\n", "[31]\tvalid_0's l1: 0.0682614\n", "[32]\tvalid_0's l1: 0.0682589\n", 
"[33]\tvalid_0's l1: 0.0682562\n", "[34]\tvalid_0's l1: 0.0682533\n", "[35]\tvalid_0's l1: 0.0682514\n", "[36]\tvalid_0's l1: 0.0682481\n", "[37]\tvalid_0's l1: 0.0682448\n", "[38]\tvalid_0's l1: 0.0682428\n", "[39]\tvalid_0's l1: 0.0682411\n", "[40]\tvalid_0's l1: 0.0682378\n", "[41]\tvalid_0's l1: 0.0682362\n", "[42]\tvalid_0's l1: 0.0682346\n", "[43]\tvalid_0's l1: 0.0682316\n", "[44]\tvalid_0's l1: 0.0682295\n", "[45]\tvalid_0's l1: 0.0682262\n", "[46]\tvalid_0's l1: 0.0682243\n", "[47]\tvalid_0's l1: 0.0682211\n", "[48]\tvalid_0's l1: 0.0682192\n", "[49]\tvalid_0's l1: 0.0682163\n", "[50]\tvalid_0's l1: 0.068214\n", "[51]\tvalid_0's l1: 0.0682118\n", "[52]\tvalid_0's l1: 0.0682092\n", "[53]\tvalid_0's l1: 0.068208\n", "[54]\tvalid_0's l1: 0.0682054\n", "[55]\tvalid_0's l1: 0.068203\n", "[56]\tvalid_0's l1: 0.0682009\n", "[57]\tvalid_0's l1: 0.0681992\n", "[58]\tvalid_0's l1: 0.0681973\n", "[59]\tvalid_0's l1: 0.0681945\n", "[60]\tvalid_0's l1: 0.0681932\n", "[61]\tvalid_0's l1: 0.0681908\n", "[62]\tvalid_0's l1: 0.0681888\n", "[63]\tvalid_0's l1: 0.0681869\n", "[64]\tvalid_0's l1: 0.0681849\n", "[65]\tvalid_0's l1: 0.0681839\n", "[66]\tvalid_0's l1: 0.0681821\n", "[67]\tvalid_0's l1: 0.06818\n", "[68]\tvalid_0's l1: 0.0681773\n", "[69]\tvalid_0's l1: 0.0681753\n", "[70]\tvalid_0's l1: 0.0681734\n", "[71]\tvalid_0's l1: 0.0681712\n", "[72]\tvalid_0's l1: 0.0681695\n", "[73]\tvalid_0's l1: 0.068168\n", "[74]\tvalid_0's l1: 0.0681664\n", "[75]\tvalid_0's l1: 0.0681642\n", "[76]\tvalid_0's l1: 0.0681616\n", "[77]\tvalid_0's l1: 0.068159\n", "[78]\tvalid_0's l1: 0.0681568\n", "[79]\tvalid_0's l1: 0.0681557\n", "[80]\tvalid_0's l1: 0.068154\n", "[81]\tvalid_0's l1: 0.0681519\n", "[82]\tvalid_0's l1: 0.0681503\n", "[83]\tvalid_0's l1: 0.0681483\n", "[84]\tvalid_0's l1: 0.0681466\n", "[85]\tvalid_0's l1: 0.068144\n", "[86]\tvalid_0's l1: 0.068142\n", "[87]\tvalid_0's l1: 0.0681402\n", "[88]\tvalid_0's l1: 0.0681379\n", "[89]\tvalid_0's l1: 0.0681364\n", 
"[90]\tvalid_0's l1: 0.0681347\n", "[91]\tvalid_0's l1: 0.0681329\n", "[92]\tvalid_0's l1: 0.0681311\n", "[93]\tvalid_0's l1: 0.0681291\n", "[94]\tvalid_0's l1: 0.0681269\n", "[95]\tvalid_0's l1: 0.0681248\n", "[96]\tvalid_0's l1: 0.0681233\n", "[97]\tvalid_0's l1: 0.0681207\n", "[98]\tvalid_0's l1: 0.0681182\n", "[99]\tvalid_0's l1: 0.0681167\n", "[100]\tvalid_0's l1: 0.068115\n", "[101]\tvalid_0's l1: 0.0681135\n", "[102]\tvalid_0's l1: 0.0681127\n", "[103]\tvalid_0's l1: 0.0681113\n", "[104]\tvalid_0's l1: 0.0681093\n", "[105]\tvalid_0's l1: 0.0681076\n", "[106]\tvalid_0's l1: 0.0681069\n", "[107]\tvalid_0's l1: 0.0681051\n", "[108]\tvalid_0's l1: 0.0681029\n", "[109]\tvalid_0's l1: 0.0681011\n", "[110]\tvalid_0's l1: 0.068099\n", "[111]\tvalid_0's l1: 0.0680975\n", "[112]\tvalid_0's l1: 0.0680951\n", "[113]\tvalid_0's l1: 0.068093\n", "[114]\tvalid_0's l1: 0.0680912\n", "[115]\tvalid_0's l1: 0.0680896\n", "[116]\tvalid_0's l1: 0.0680887\n", "[117]\tvalid_0's l1: 0.0680876\n", "[118]\tvalid_0's l1: 0.068086\n", "[119]\tvalid_0's l1: 0.068084\n", "[120]\tvalid_0's l1: 0.0680817\n", "[121]\tvalid_0's l1: 0.0680803\n", "[122]\tvalid_0's l1: 0.0680782\n", "[123]\tvalid_0's l1: 0.0680766\n", "[124]\tvalid_0's l1: 0.0680747\n", "[125]\tvalid_0's l1: 0.068073\n", "[126]\tvalid_0's l1: 0.0680719\n", "[127]\tvalid_0's l1: 0.0680702\n", "[128]\tvalid_0's l1: 0.0680692\n", "[129]\tvalid_0's l1: 0.0680678\n", "[130]\tvalid_0's l1: 0.0680666\n", "[131]\tvalid_0's l1: 0.0680654\n", "[132]\tvalid_0's l1: 0.0680643\n", "[133]\tvalid_0's l1: 0.0680626\n", "[134]\tvalid_0's l1: 0.0680606\n", "[135]\tvalid_0's l1: 0.0680589\n", "[136]\tvalid_0's l1: 0.0680576\n", "[137]\tvalid_0's l1: 0.0680556\n", "[138]\tvalid_0's l1: 0.0680547\n", "[139]\tvalid_0's l1: 0.0680536\n", "[140]\tvalid_0's l1: 0.0680521\n", "[141]\tvalid_0's l1: 0.0680502\n", "[142]\tvalid_0's l1: 0.068049\n", "[143]\tvalid_0's l1: 0.0680474\n", "[144]\tvalid_0's l1: 0.0680462\n", "[145]\tvalid_0's l1: 0.0680447\n", 
"[146]\tvalid_0's l1: 0.0680432\n", "[147]\tvalid_0's l1: 0.068042\n", "[148]\tvalid_0's l1: 0.0680408\n", "[149]\tvalid_0's l1: 0.0680398\n", "[150]\tvalid_0's l1: 0.0680388\n", "[151]\tvalid_0's l1: 0.068038\n", "[152]\tvalid_0's l1: 0.0680362\n", "[153]\tvalid_0's l1: 0.0680349\n", "[154]\tvalid_0's l1: 0.0680338\n", "[155]\tvalid_0's l1: 0.068032\n", "[156]\tvalid_0's l1: 0.0680314\n", "[157]\tvalid_0's l1: 0.0680303\n", "[158]\tvalid_0's l1: 0.0680284\n", "[159]\tvalid_0's l1: 0.0680273\n", "[160]\tvalid_0's l1: 0.0680254\n", "[161]\tvalid_0's l1: 0.068024\n", "[162]\tvalid_0's l1: 0.068023\n", "[163]\tvalid_0's l1: 0.0680221\n", "[164]\tvalid_0's l1: 0.0680214\n", "[165]\tvalid_0's l1: 0.0680201\n", "[166]\tvalid_0's l1: 0.068019\n", "[167]\tvalid_0's l1: 0.0680168\n", "[168]\tvalid_0's l1: 0.0680159\n", "[169]\tvalid_0's l1: 0.0680149\n", "[170]\tvalid_0's l1: 0.0680135\n", "[171]\tvalid_0's l1: 0.0680126\n", "[172]\tvalid_0's l1: 0.0680116\n", "[173]\tvalid_0's l1: 0.0680097\n", "[174]\tvalid_0's l1: 0.0680083\n", "[175]\tvalid_0's l1: 0.0680072\n", "[176]\tvalid_0's l1: 0.0680061\n", "[177]\tvalid_0's l1: 0.0680045\n", "[178]\tvalid_0's l1: 0.0680033\n", "[179]\tvalid_0's l1: 0.068002\n", "[180]\tvalid_0's l1: 0.0680012\n", "[181]\tvalid_0's l1: 0.0679996\n", "[182]\tvalid_0's l1: 0.0679988\n", "[183]\tvalid_0's l1: 0.0679978\n", "[184]\tvalid_0's l1: 0.0679963\n", "[185]\tvalid_0's l1: 0.0679949\n", "[186]\tvalid_0's l1: 0.0679938\n", "[187]\tvalid_0's l1: 0.0679921\n", "[188]\tvalid_0's l1: 0.0679913\n", "[189]\tvalid_0's l1: 0.0679897\n", "[190]\tvalid_0's l1: 0.0679884\n", "[191]\tvalid_0's l1: 0.0679876\n", "[192]\tvalid_0's l1: 0.0679867\n", "[193]\tvalid_0's l1: 0.067986\n", "[194]\tvalid_0's l1: 0.067985\n", "[195]\tvalid_0's l1: 0.0679837\n", "[196]\tvalid_0's l1: 0.0679831\n", "[197]\tvalid_0's l1: 0.067982\n", "[198]\tvalid_0's l1: 0.0679806\n", "[199]\tvalid_0's l1: 0.0679798\n", "[200]\tvalid_0's l1: 0.0679792\n", "[201]\tvalid_0's l1: 
0.0679783\n", "[202]\tvalid_0's l1: 0.0679775\n", "[203]\tvalid_0's l1: 0.0679762\n", "[204]\tvalid_0's l1: 0.0679754\n", "[205]\tvalid_0's l1: 0.0679747\n", "[206]\tvalid_0's l1: 0.0679734\n", "[207]\tvalid_0's l1: 0.0679719\n", "[208]\tvalid_0's l1: 0.0679705\n", "[209]\tvalid_0's l1: 0.0679694\n", "[210]\tvalid_0's l1: 0.0679686\n", "[211]\tvalid_0's l1: 0.0679671\n", "[212]\tvalid_0's l1: 0.0679664\n", "[213]\tvalid_0's l1: 0.067965\n", "[214]\tvalid_0's l1: 0.0679636\n", "[215]\tvalid_0's l1: 0.0679631\n", "[216]\tvalid_0's l1: 0.0679615\n", "[217]\tvalid_0's l1: 0.0679604\n", "[218]\tvalid_0's l1: 0.0679593\n", "[219]\tvalid_0's l1: 0.0679587\n", "[220]\tvalid_0's l1: 0.0679575\n", "[221]\tvalid_0's l1: 0.0679571\n", "[222]\tvalid_0's l1: 0.0679562\n", "[223]\tvalid_0's l1: 0.0679546\n", "[224]\tvalid_0's l1: 0.0679541\n", "[225]\tvalid_0's l1: 0.0679529\n", "[226]\tvalid_0's l1: 0.0679517\n", "[227]\tvalid_0's l1: 0.0679506\n", "[228]\tvalid_0's l1: 0.0679495\n", "[229]\tvalid_0's l1: 0.0679487\n", "[230]\tvalid_0's l1: 0.0679476\n", "[231]\tvalid_0's l1: 0.0679466\n", "[232]\tvalid_0's l1: 0.0679455\n", "[233]\tvalid_0's l1: 0.0679443\n", "[234]\tvalid_0's l1: 0.0679429\n", "[235]\tvalid_0's l1: 0.0679426\n", "[236]\tvalid_0's l1: 0.0679417\n", "[237]\tvalid_0's l1: 0.067941\n", "[238]\tvalid_0's l1: 0.0679398\n", "[239]\tvalid_0's l1: 0.0679386\n", "[240]\tvalid_0's l1: 0.0679374\n", "[241]\tvalid_0's l1: 0.0679364\n", "[242]\tvalid_0's l1: 0.067936\n", "[243]\tvalid_0's l1: 0.0679352\n", "[244]\tvalid_0's l1: 0.0679339\n", "[245]\tvalid_0's l1: 0.0679328\n", "[246]\tvalid_0's l1: 0.0679323\n", "[247]\tvalid_0's l1: 0.0679315\n", "[248]\tvalid_0's l1: 0.0679302\n", "[249]\tvalid_0's l1: 0.0679295\n", "[250]\tvalid_0's l1: 0.0679293\n", "[251]\tvalid_0's l1: 0.0679285\n", "[252]\tvalid_0's l1: 0.0679276\n", "[253]\tvalid_0's l1: 0.0679266\n", "[254]\tvalid_0's l1: 0.0679256\n", "[255]\tvalid_0's l1: 0.0679245\n", "[256]\tvalid_0's l1: 0.0679231\n", 
"[257]\tvalid_0's l1: 0.0679217\n", "[258]\tvalid_0's l1: 0.0679206\n", "[259]\tvalid_0's l1: 0.06792\n", "[260]\tvalid_0's l1: 0.0679191\n", "[261]\tvalid_0's l1: 0.067918\n", "[262]\tvalid_0's l1: 0.0679175\n", "[263]\tvalid_0's l1: 0.0679169\n", "[264]\tvalid_0's l1: 0.0679156\n", "[265]\tvalid_0's l1: 0.0679149\n", "[266]\tvalid_0's l1: 0.067914\n", "[267]\tvalid_0's l1: 0.0679134\n", "[268]\tvalid_0's l1: 0.0679128\n", "[269]\tvalid_0's l1: 0.0679111\n", "[270]\tvalid_0's l1: 0.0679098\n", "[271]\tvalid_0's l1: 0.0679091\n", "[272]\tvalid_0's l1: 0.067908\n", "[273]\tvalid_0's l1: 0.0679075\n", "[274]\tvalid_0's l1: 0.067907\n", "[275]\tvalid_0's l1: 0.0679059\n", "[276]\tvalid_0's l1: 0.0679047\n", "[277]\tvalid_0's l1: 0.0679036\n", "[278]\tvalid_0's l1: 0.0679025\n", "[279]\tvalid_0's l1: 0.0679011\n", "[280]\tvalid_0's l1: 0.0679006\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[281]\tvalid_0's l1: 0.0679003\n", "[282]\tvalid_0's l1: 0.067899\n", "[283]\tvalid_0's l1: 0.0678986\n", "[284]\tvalid_0's l1: 0.0678976\n", "[285]\tvalid_0's l1: 0.0678962\n", "[286]\tvalid_0's l1: 0.0678956\n", "[287]\tvalid_0's l1: 0.0678946\n", "[288]\tvalid_0's l1: 0.0678936\n", "[289]\tvalid_0's l1: 0.0678932\n", "[290]\tvalid_0's l1: 0.067892\n", "[291]\tvalid_0's l1: 0.067891\n", "[292]\tvalid_0's l1: 0.0678904\n", "[293]\tvalid_0's l1: 0.0678898\n", "[294]\tvalid_0's l1: 0.0678894\n", "[295]\tvalid_0's l1: 0.0678881\n", "[296]\tvalid_0's l1: 0.0678875\n", "[297]\tvalid_0's l1: 0.0678871\n", "[298]\tvalid_0's l1: 0.0678866\n", "[299]\tvalid_0's l1: 0.0678863\n", "[300]\tvalid_0's l1: 0.0678855\n", "[301]\tvalid_0's l1: 0.0678842\n", "[302]\tvalid_0's l1: 0.0678832\n", "[303]\tvalid_0's l1: 0.0678825\n", "[304]\tvalid_0's l1: 0.0678819\n", "[305]\tvalid_0's l1: 0.0678813\n", "[306]\tvalid_0's l1: 0.0678804\n", "[307]\tvalid_0's l1: 0.0678798\n", "[308]\tvalid_0's l1: 0.0678791\n", "[309]\tvalid_0's l1: 0.0678784\n", "[310]\tvalid_0's l1: 0.0678776\n", 
"[311]\tvalid_0's l1: 0.0678769\n", "[312]\tvalid_0's l1: 0.0678758\n", "[313]\tvalid_0's l1: 0.0678749\n", "[314]\tvalid_0's l1: 0.0678739\n", "[315]\tvalid_0's l1: 0.0678729\n", "[316]\tvalid_0's l1: 0.0678719\n", "[317]\tvalid_0's l1: 0.0678716\n", "[318]\tvalid_0's l1: 0.067871\n", "[319]\tvalid_0's l1: 0.0678702\n", "[320]\tvalid_0's l1: 0.0678695\n", "[321]\tvalid_0's l1: 0.0678693\n", "[322]\tvalid_0's l1: 0.0678688\n", "[323]\tvalid_0's l1: 0.0678677\n", "[324]\tvalid_0's l1: 0.0678674\n", "[325]\tvalid_0's l1: 0.0678671\n", "[326]\tvalid_0's l1: 0.0678669\n", "[327]\tvalid_0's l1: 0.0678659\n", "[328]\tvalid_0's l1: 0.0678649\n", "[329]\tvalid_0's l1: 0.0678643\n", "[330]\tvalid_0's l1: 0.0678633\n", "[331]\tvalid_0's l1: 0.0678626\n", "[332]\tvalid_0's l1: 0.067862\n", "[333]\tvalid_0's l1: 0.0678616\n", "[334]\tvalid_0's l1: 0.0678616\n", "[335]\tvalid_0's l1: 0.0678609\n", "[336]\tvalid_0's l1: 0.0678601\n", "[337]\tvalid_0's l1: 0.0678599\n", "[338]\tvalid_0's l1: 0.0678591\n", "[339]\tvalid_0's l1: 0.0678585\n", "[340]\tvalid_0's l1: 0.0678585\n", "[341]\tvalid_0's l1: 0.0678579\n", "[342]\tvalid_0's l1: 0.0678569\n", "[343]\tvalid_0's l1: 0.0678566\n", "[344]\tvalid_0's l1: 0.0678558\n", "[345]\tvalid_0's l1: 0.0678551\n", "[346]\tvalid_0's l1: 0.0678541\n", "[347]\tvalid_0's l1: 0.0678533\n", "[348]\tvalid_0's l1: 0.0678529\n", "[349]\tvalid_0's l1: 0.0678523\n", "[350]\tvalid_0's l1: 0.0678516\n", "[351]\tvalid_0's l1: 0.0678511\n", "[352]\tvalid_0's l1: 0.0678501\n", "[353]\tvalid_0's l1: 0.0678497\n", "[354]\tvalid_0's l1: 0.0678491\n", "[355]\tvalid_0's l1: 0.0678482\n", "[356]\tvalid_0's l1: 0.067848\n", "[357]\tvalid_0's l1: 0.0678474\n", "[358]\tvalid_0's l1: 0.0678465\n", "[359]\tvalid_0's l1: 0.0678462\n", "[360]\tvalid_0's l1: 0.0678458\n", "[361]\tvalid_0's l1: 0.0678452\n", "[362]\tvalid_0's l1: 0.0678444\n", "[363]\tvalid_0's l1: 0.067844\n", "[364]\tvalid_0's l1: 0.0678436\n", "[365]\tvalid_0's l1: 0.0678428\n", "[366]\tvalid_0's l1: 
0.0678425\n", "[367]\tvalid_0's l1: 0.0678422\n", "[368]\tvalid_0's l1: 0.0678411\n", "[369]\tvalid_0's l1: 0.0678402\n", "[370]\tvalid_0's l1: 0.0678394\n", "[371]\tvalid_0's l1: 0.0678392\n", "[372]\tvalid_0's l1: 0.0678387\n", "[373]\tvalid_0's l1: 0.0678385\n", "[374]\tvalid_0's l1: 0.0678378\n", "[375]\tvalid_0's l1: 0.0678377\n", "[376]\tvalid_0's l1: 0.0678369\n", "[377]\tvalid_0's l1: 0.0678363\n", "[378]\tvalid_0's l1: 0.0678357\n", "[379]\tvalid_0's l1: 0.0678353\n", "[380]\tvalid_0's l1: 0.0678346\n", "[381]\tvalid_0's l1: 0.0678345\n", "[382]\tvalid_0's l1: 0.0678338\n", "[383]\tvalid_0's l1: 0.0678334\n", "[384]\tvalid_0's l1: 0.0678329\n", "[385]\tvalid_0's l1: 0.0678327\n", "[386]\tvalid_0's l1: 0.0678322\n", "[387]\tvalid_0's l1: 0.0678315\n", "[388]\tvalid_0's l1: 0.0678308\n", "[389]\tvalid_0's l1: 0.0678302\n", "[390]\tvalid_0's l1: 0.0678297\n", "[391]\tvalid_0's l1: 0.0678289\n", "[392]\tvalid_0's l1: 0.0678286\n", "[393]\tvalid_0's l1: 0.0678281\n", "[394]\tvalid_0's l1: 0.0678279\n", "[395]\tvalid_0's l1: 0.0678275\n", "[396]\tvalid_0's l1: 0.0678265\n", "[397]\tvalid_0's l1: 0.067826\n", "[398]\tvalid_0's l1: 0.0678253\n", "[399]\tvalid_0's l1: 0.0678248\n", "[400]\tvalid_0's l1: 0.0678241\n", "[401]\tvalid_0's l1: 0.0678235\n", "[402]\tvalid_0's l1: 0.0678233\n", "[403]\tvalid_0's l1: 0.0678231\n", "[404]\tvalid_0's l1: 0.0678229\n", "[405]\tvalid_0's l1: 0.0678224\n", "[406]\tvalid_0's l1: 0.0678222\n", "[407]\tvalid_0's l1: 0.0678221\n", "[408]\tvalid_0's l1: 0.0678213\n", "[409]\tvalid_0's l1: 0.0678209\n", "[410]\tvalid_0's l1: 0.0678207\n", "[411]\tvalid_0's l1: 0.0678202\n", "[412]\tvalid_0's l1: 0.0678197\n", "[413]\tvalid_0's l1: 0.0678194\n", "[414]\tvalid_0's l1: 0.0678192\n", "[415]\tvalid_0's l1: 0.0678191\n", "[416]\tvalid_0's l1: 0.0678187\n", "[417]\tvalid_0's l1: 0.0678185\n", "[418]\tvalid_0's l1: 0.0678183\n", "[419]\tvalid_0's l1: 0.0678181\n", "[420]\tvalid_0's l1: 0.0678178\n", "[421]\tvalid_0's l1: 0.0678173\n", 
"[422]\tvalid_0's l1: 0.0678168\n", "[423]\tvalid_0's l1: 0.067816\n", "[424]\tvalid_0's l1: 0.0678155\n", "[425]\tvalid_0's l1: 0.0678149\n", "[426]\tvalid_0's l1: 0.0678144\n", "[427]\tvalid_0's l1: 0.0678138\n", "[428]\tvalid_0's l1: 0.0678132\n", "[429]\tvalid_0's l1: 0.0678125\n", "[430]\tvalid_0's l1: 0.0678119\n", "[431]\tvalid_0's l1: 0.0678115\n", "[432]\tvalid_0's l1: 0.0678112\n", "[433]\tvalid_0's l1: 0.0678111\n", "[434]\tvalid_0's l1: 0.0678109\n", "[435]\tvalid_0's l1: 0.0678107\n", "[436]\tvalid_0's l1: 0.0678104\n", "[437]\tvalid_0's l1: 0.0678094\n", "[438]\tvalid_0's l1: 0.0678092\n", "[439]\tvalid_0's l1: 0.067809\n", "[440]\tvalid_0's l1: 0.0678083\n", "[441]\tvalid_0's l1: 0.0678081\n", "[442]\tvalid_0's l1: 0.0678078\n", "[443]\tvalid_0's l1: 0.0678076\n", "[444]\tvalid_0's l1: 0.0678072\n", "[445]\tvalid_0's l1: 0.067807\n", "[446]\tvalid_0's l1: 0.0678068\n", "[447]\tvalid_0's l1: 0.0678065\n", "[448]\tvalid_0's l1: 0.0678063\n", "[449]\tvalid_0's l1: 0.0678058\n", "[450]\tvalid_0's l1: 0.0678057\n", "[451]\tvalid_0's l1: 0.0678052\n", "[452]\tvalid_0's l1: 0.067805\n", "[453]\tvalid_0's l1: 0.0678048\n", "[454]\tvalid_0's l1: 0.0678044\n", "[455]\tvalid_0's l1: 0.067804\n", "[456]\tvalid_0's l1: 0.067804\n", "[457]\tvalid_0's l1: 0.0678037\n", "[458]\tvalid_0's l1: 0.0678036\n", "[459]\tvalid_0's l1: 0.0678033\n", "[460]\tvalid_0's l1: 0.0678029\n", "[461]\tvalid_0's l1: 0.0678023\n", "[462]\tvalid_0's l1: 0.0678023\n", "[463]\tvalid_0's l1: 0.067802\n", "[464]\tvalid_0's l1: 0.0678019\n", "[465]\tvalid_0's l1: 0.0678017\n", "[466]\tvalid_0's l1: 0.0678017\n", "[467]\tvalid_0's l1: 0.0678012\n", "[468]\tvalid_0's l1: 0.0678012\n", "[469]\tvalid_0's l1: 0.0678011\n", "[470]\tvalid_0's l1: 0.0678009\n", "[471]\tvalid_0's l1: 0.0678004\n", "[472]\tvalid_0's l1: 0.0678\n", "[473]\tvalid_0's l1: 0.0677998\n", "[474]\tvalid_0's l1: 0.0677996\n", "[475]\tvalid_0's l1: 0.0677992\n", "[476]\tvalid_0's l1: 0.0677992\n", "[477]\tvalid_0's l1: 
0.067799\n", "[478]\tvalid_0's l1: 0.0677988\n", "[479]\tvalid_0's l1: 0.0677984\n", "[480]\tvalid_0's l1: 0.0677984\n", "[481]\tvalid_0's l1: 0.0677982\n", "[482]\tvalid_0's l1: 0.067798\n", "[483]\tvalid_0's l1: 0.0677978\n", "[484]\tvalid_0's l1: 0.0677977\n", "[485]\tvalid_0's l1: 0.0677977\n", "[486]\tvalid_0's l1: 0.0677973\n", "[487]\tvalid_0's l1: 0.0677966\n", "[488]\tvalid_0's l1: 0.0677965\n", "[489]\tvalid_0's l1: 0.0677965\n", "[490]\tvalid_0's l1: 0.0677964\n", "[491]\tvalid_0's l1: 0.0677958\n", "[492]\tvalid_0's l1: 0.0677958\n", "[493]\tvalid_0's l1: 0.0677956\n", "[494]\tvalid_0's l1: 0.0677955\n", "[495]\tvalid_0's l1: 0.0677953\n", "[496]\tvalid_0's l1: 0.0677951\n", "[497]\tvalid_0's l1: 0.0677947\n", "[498]\tvalid_0's l1: 0.0677941\n", "[499]\tvalid_0's l1: 0.0677939\n", "[500]\tvalid_0's l1: 0.0677933\n" ] } ], "source": [ "clf_part = lgb.train(prm_part, ds_part, num_round, valid_sets=ds_peek)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_absolute_error\n", "\n", "\n", "def get_score(x, y, clf, ii):\n", "    \"\"\"Return the mean absolute error of `clf` on (x, y).\n", "\n", "    The prediction is made with `num_iteration=ii`, which LightGBM uses to\n", "    limit how many boosting iterations contribute to the prediction.\n", "    \"\"\"\n", "    predicted = clf.predict(x, num_iteration=ii)\n", "    return mean_absolute_error(y, predicted)\n", "\n", "\n", "# Score both boosters on the validation split at every 5th iteration count,\n", "# from 70% of num_round up to and including num_round.\n", "lab, val_gbdt, val_part = [], [], []\n", "for ii in range(int(0.7 * num_round), num_round + 1, 5):\n", "    lab.append(ii)\n", "    val_gbdt.append(get_score(xv, yv, clf_gbdt, ii))\n", "    val_part.append(get_score(xv, yv, clf_part, ii))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GBDT: 0.06612165068883384\n", "PART: 0.06612067704950389\n" ] } ], "source": [ "print(f'GBDT: {np.array(val_gbdt).min()}')\n", "print(f'PART: {np.array(val_part).min()}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }