{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
\n",
" \n",
"## Открытый курс по машинному обучению\n",
"Автор материала: Екатерина Ширяева slackname: Katya.Shiryaeva"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Есть такая либа H20"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Согласно исследованиям [Gartner в феврале 2018](https://www.gartner.com/doc/reprints?id=1-4RQ3VEZ&ct=180223&st=sb), H2O занимает уверенное место в лидерах рынка среди DataScience и Machine Learning платформ.\n",
"Gartner считают H2O.ai технологическим лидером, эта платформа мспользуется более чем 100000 data scientistами и удоволетверенность клиентами самая высокая (поддержка, обучение и продажи). \n",
"В этом обзоре я хочу показать отличия от реализаций алгоритмов в обычном sklearn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Подготовка данных"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Для примера можно взять датасет из [1ой лекции: отток клиентов телекома](https://habrahabr.ru/company/ods/blog/322626/) "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('data/telecom_churn.csv')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"RangeIndex: 3333 entries, 0 to 3332\n",
"Data columns (total 20 columns):\n",
"State 3333 non-null object\n",
"Account length 3333 non-null int64\n",
"Area code 3333 non-null int64\n",
"International plan 3333 non-null object\n",
"Voice mail plan 3333 non-null object\n",
"Number vmail messages 3333 non-null int64\n",
"Total day minutes 3333 non-null float64\n",
"Total day calls 3333 non-null int64\n",
"Total day charge 3333 non-null float64\n",
"Total eve minutes 3333 non-null float64\n",
"Total eve calls 3333 non-null int64\n",
"Total eve charge 3333 non-null float64\n",
"Total night minutes 3333 non-null float64\n",
"Total night calls 3333 non-null int64\n",
"Total night charge 3333 non-null float64\n",
"Total intl minutes 3333 non-null float64\n",
"Total intl calls 3333 non-null int64\n",
"Total intl charge 3333 non-null float64\n",
"Customer service calls 3333 non-null int64\n",
"Churn 3333 non-null bool\n",
"dtypes: bool(1), float64(8), int64(8), object(3)\n",
"memory usage: 498.1+ KB\n"
]
}
],
"source": [
"df.info()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Описание всех признаков можно посмотреть в [1ой лекции](https://habrahabr.ru/company/ods/blog/322626/) \n",
"В качестве предварительной обработки заменю все значения churn на 0/1, \n",
"для International plan и Voice mail plan сделаю замену yes / no на 0 / 1 \n",
"а категориальную переменную State пока удалю из обработки (данных недостаточно, чтобы сделать OHE, и укрупнение категорий не стоит в целях этого тьюториала)"
]
},
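One caveat about the dictionary-based replacement used below (a toy illustration, not part of the original notebook): `Series.map` silently turns any value missing from the mapping dictionary into `NaN`, so it is worth checking for missing values after mapping. The sample values here are hypothetical:

```python
import pandas as pd

# Miniature stand-in for a yes/no column; note the lower-case 'yes'
# that the mapping dictionary does not cover.
s = pd.Series(['Yes', 'No', 'yes'])
mapped = s.map({'No': 0, 'Yes': 1})

# Unmapped values become NaN silently -- check for them explicitly.
print(int(mapped.isna().sum()))  # 1
```

If such `NaN`s appear, the dictionary (or the raw data) needs fixing before modeling.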
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Account length \n",
" Area code \n",
" International plan \n",
" Voice mail plan \n",
" Number vmail messages \n",
" Total day minutes \n",
" Total day calls \n",
" Total day charge \n",
" Total eve minutes \n",
" Total eve calls \n",
" Total eve charge \n",
" Total night minutes \n",
" Total night calls \n",
" Total night charge \n",
" Total intl minutes \n",
" Total intl calls \n",
" Total intl charge \n",
" Customer service calls \n",
" Churn \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 128 \n",
" 415 \n",
" 0 \n",
" 1 \n",
" 25 \n",
" 265.1 \n",
" 110 \n",
" 45.07 \n",
" 197.4 \n",
" 99 \n",
" 16.78 \n",
" 244.7 \n",
" 91 \n",
" 11.01 \n",
" 10.0 \n",
" 3 \n",
" 2.70 \n",
" 1 \n",
" 0 \n",
" \n",
" \n",
" 1 \n",
" 107 \n",
" 415 \n",
" 0 \n",
" 1 \n",
" 26 \n",
" 161.6 \n",
" 123 \n",
" 27.47 \n",
" 195.5 \n",
" 103 \n",
" 16.62 \n",
" 254.4 \n",
" 103 \n",
" 11.45 \n",
" 13.7 \n",
" 3 \n",
" 3.70 \n",
" 1 \n",
" 0 \n",
" \n",
" \n",
" 2 \n",
" 137 \n",
" 415 \n",
" 0 \n",
" 0 \n",
" 0 \n",
" 243.4 \n",
" 114 \n",
" 41.38 \n",
" 121.2 \n",
" 110 \n",
" 10.30 \n",
" 162.6 \n",
" 104 \n",
" 7.32 \n",
" 12.2 \n",
" 5 \n",
" 3.29 \n",
" 0 \n",
" 0 \n",
" \n",
" \n",
" 3 \n",
" 84 \n",
" 408 \n",
" 1 \n",
" 0 \n",
" 0 \n",
" 299.4 \n",
" 71 \n",
" 50.90 \n",
" 61.9 \n",
" 88 \n",
" 5.26 \n",
" 196.9 \n",
" 89 \n",
" 8.86 \n",
" 6.6 \n",
" 7 \n",
" 1.78 \n",
" 2 \n",
" 0 \n",
" \n",
" \n",
" 4 \n",
" 75 \n",
" 415 \n",
" 1 \n",
" 0 \n",
" 0 \n",
" 166.7 \n",
" 113 \n",
" 28.34 \n",
" 148.3 \n",
" 122 \n",
" 12.61 \n",
" 186.9 \n",
" 121 \n",
" 8.41 \n",
" 10.1 \n",
" 3 \n",
" 2.73 \n",
" 3 \n",
" 0 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Account length Area code International plan Voice mail plan \\\n",
"0 128 415 0 1 \n",
"1 107 415 0 1 \n",
"2 137 415 0 0 \n",
"3 84 408 1 0 \n",
"4 75 415 1 0 \n",
"\n",
" Number vmail messages Total day minutes Total day calls \\\n",
"0 25 265.1 110 \n",
"1 26 161.6 123 \n",
"2 0 243.4 114 \n",
"3 0 299.4 71 \n",
"4 0 166.7 113 \n",
"\n",
" Total day charge Total eve minutes Total eve calls Total eve charge \\\n",
"0 45.07 197.4 99 16.78 \n",
"1 27.47 195.5 103 16.62 \n",
"2 41.38 121.2 110 10.30 \n",
"3 50.90 61.9 88 5.26 \n",
"4 28.34 148.3 122 12.61 \n",
"\n",
" Total night minutes Total night calls Total night charge \\\n",
"0 244.7 91 11.01 \n",
"1 254.4 103 11.45 \n",
"2 162.6 104 7.32 \n",
"3 196.9 89 8.86 \n",
"4 186.9 121 8.41 \n",
"\n",
" Total intl minutes Total intl calls Total intl charge \\\n",
"0 10.0 3 2.70 \n",
"1 13.7 3 3.70 \n",
"2 12.2 5 3.29 \n",
"3 6.6 7 1.78 \n",
"4 10.1 3 2.73 \n",
"\n",
" Customer service calls Churn \n",
"0 1 0 \n",
"1 1 0 \n",
"2 0 0 \n",
"3 2 0 \n",
"4 3 0 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d = {'No': 0, 'Yes': 1}\n",
"df['Churn'] = df['Churn'].astype(int)\n",
"df['International plan'] = df['International plan'].map(d)\n",
"df['Voice mail plan'] = df['Voice mail plan'].map(d)\n",
"df.drop('State', axis=1, inplace=True)\n",
"\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Разделим на тест и обучающую выборки"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(df.iloc[:,:-1], df.iloc[:,-1], test_size=0.3, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Для работы h2o необходимо установить эту библиотеку и запустить"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install h2o"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Checking whether there is an H2O instance running at http://localhost:54321. connected.\n"
]
},
{
"data": {
"text/html": [
"H2O cluster uptime: \n",
"6 days 0 hours 34 mins \n",
"H2O cluster timezone: \n",
"Europe/Moscow \n",
"H2O data parsing timezone: \n",
"UTC \n",
"H2O cluster version: \n",
"3.18.0.4 \n",
"H2O cluster version age: \n",
"1 month and 12 days \n",
"H2O cluster name: \n",
"H2O_from_python_katya_lpda3c \n",
"H2O cluster total nodes: \n",
"1 \n",
"H2O cluster free memory: \n",
"6.316 Gb \n",
"H2O cluster total cores: \n",
"8 \n",
"H2O cluster allowed cores: \n",
"8 \n",
"H2O cluster status: \n",
"locked, healthy \n",
"H2O connection url: \n",
"http://localhost:54321 \n",
"H2O connection proxy: \n",
"None \n",
"H2O internal security: \n",
"False \n",
"H2O API Extensions: \n",
"XGBoost, Algos, AutoML, Core V3, Core V4 \n",
"Python version: \n",
"3.6.4 final
"
],
"text/plain": [
"-------------------------- ----------------------------------------\n",
"H2O cluster uptime: 6 days 0 hours 34 mins\n",
"H2O cluster timezone: Europe/Moscow\n",
"H2O data parsing timezone: UTC\n",
"H2O cluster version: 3.18.0.4\n",
"H2O cluster version age: 1 month and 12 days\n",
"H2O cluster name: H2O_from_python_katya_lpda3c\n",
"H2O cluster total nodes: 1\n",
"H2O cluster free memory: 6.316 Gb\n",
"H2O cluster total cores: 8\n",
"H2O cluster allowed cores: 8\n",
"H2O cluster status: locked, healthy\n",
"H2O connection url: http://localhost:54321\n",
"H2O connection proxy:\n",
"H2O internal security: False\n",
"H2O API Extensions: XGBoost, Algos, AutoML, Core V3, Core V4\n",
"Python version: 3.6.4 final\n",
"-------------------------- ----------------------------------------"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import h2o\n",
"import os\n",
"h2o.init(nthreads=-1, max_mem_size=8)\n",
"# nthreads - количество ядер процессора для вычислений\n",
"# max_mem_size - максимальный размер оперативной памяти "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Все данные надо перевести в специальную структуру h2o : H2OFrame \n",
"h2o [поддерживает](http://docs.h2o.ai/h2o/latest-stable/h2o-docs/getting-data-into-h2o.html) большое количество источников, однако у меня не получилось перекодировать csr-матрицу из задания про Элис :) очень долго висел :)"
]
},
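As a possible workaround for sparse data (a sketch under assumptions, not something tested in this notebook): instead of converting a csr matrix in memory, dump it to SVMLight format, which H2O's file parser can read, and import the file. The tiny matrix and the file name here are illustrative:

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.datasets import dump_svmlight_file

# A tiny stand-in for a large csr feature matrix plus a target vector.
X_sparse = csr_matrix(np.array([[0.0, 1.5, 0.0],
                                [2.0, 0.0, 3.0]]))
y = np.array([0, 1])

# Write the sparse data to disk in SVMLight format...
dump_svmlight_file(X_sparse, y, 'sparse_train.svmlight')

# ...and then, with a running H2O cluster, it could be imported with:
# frame = h2o.import_file('sparse_train.svmlight')
```

Whether this is faster than the in-memory conversion will depend on the data size, but it avoids densifying the matrix on the Python side.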
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parse progress: |█████████████████████████████████████████████████████████| 100%\n",
"Parse progress: |█████████████████████████████████████████████████████████| 100%\n"
]
}
],
"source": [
"training = h2o.H2OFrame(pd.concat([X_train, y_train], axis=1))\n",
"validation = h2o.H2OFrame(pd.concat([X_test, y_test], axis=1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Посмотрим на структуру (это обязательно надо делать, т.к. не всегда корректно происходит переход форматов) \n",
"Здесь будут описаны основные параметры каждой переменной (тип, максимум, минимум, среднее, стандартное отклонение, количество нулевых и пропущенных значений и первые 10 наблюдений)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rows:2333\n",
"Cols:19\n",
"\n",
"\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
" Account length Area code International plan Voice mail plan Number vmail messages Total day minutes Total day calls Total day charge Total eve minutes Total eve calls Total eve charge Total night minutes Total night calls Total night charge Total intl minutes Total intl calls Total intl charge Customer service calls Churn \n",
" \n",
"\n",
"type int int int int int real int real real int real real int real real int real int int \n",
"mins 1.0 408.0 0.0 0.0 0.0 2.6 30.0 0.44 0.0 0.0 0.0 23.2 33.0 1.04 0.0 0.0 0.0 0.0 0.0 \n",
"mean 100.37848264037734 436.71924560651513 0.09515645092156022 0.2726103729104158 8.032576082297462 180.0195027861121 100.62280325760835 30.60383197599656 200.95752250321465 100.05400771538778 17.081633090441525 200.67038148306924 99.94813544792119 9.030210030004287 10.242777539648538 4.444492070295745 2.7660522931847398 1.568795542220319 0.14573510501500214 \n",
"maxs 232.0 510.0 1.0 1.0 51.0 346.8 165.0 58.96 363.7 170.0 30.91 395.0 175.0 17.77 20.0 18.0 5.4 9.0 1.0 \n",
"sigma 39.815132404220186 42.11342758508376 0.2934938203721862 0.4453975630897078 13.722524774971957 54.503148533784056 19.89235683817998 9.265512312565042 50.771196810709434 20.081856448464883 4.315580894314363 50.935130537595086 19.586623410722094 2.292113726067487 2.791145550244814 2.4515950038863097 0.7536382712677757 1.3337241215350106 0.35291609524195916 \n",
"zeros 0 0 2111 1697 1697 0 0 0 1 1 1 0 0 0 13 13 13 493 1993 \n",
"missing 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
"0 80.0 510.0 0.0 0.0 0.0 202.4 118.0 34.41 260.2 67.0 22.12 177.4 112.0 7.98 9.2 5.0 2.48 3.0 0.0 \n",
"1 63.0 510.0 0.0 0.0 0.0 132.9 122.0 22.59 67.0 62.0 5.7 160.4 121.0 7.22 9.9 2.0 2.67 3.0 0.0 \n",
"2 116.0 510.0 0.0 1.0 12.0 221.0 108.0 37.57 151.0 118.0 12.84 179.0 80.0 8.06 9.0 6.0 2.43 2.0 0.0 \n",
"3 71.0 415.0 0.0 0.0 0.0 278.9 110.0 47.41 190.2 67.0 16.17 255.2 84.0 11.48 11.7 7.0 3.16 0.0 1.0 \n",
"4 120.0 510.0 0.0 1.0 43.0 177.9 117.0 30.24 175.1 70.0 14.88 161.3 117.0 7.26 11.5 4.0 3.11 1.0 0.0 \n",
"5 132.0 510.0 0.0 0.0 0.0 181.1 121.0 30.79 314.4 109.0 26.72 246.7 81.0 11.1 4.2 9.0 1.13 2.0 0.0 \n",
"6 105.0 415.0 0.0 0.0 0.0 156.5 102.0 26.61 140.2 134.0 11.92 227.4 111.0 10.23 12.2 2.0 3.29 2.0 0.0 \n",
"7 117.0 510.0 1.0 0.0 0.0 198.4 121.0 33.73 249.5 104.0 21.21 162.8 115.0 7.33 10.5 5.0 2.84 1.0 0.0 \n",
"8 64.0 510.0 0.0 0.0 0.0 216.9 78.0 36.87 211.0 115.0 17.94 179.8 116.0 8.09 11.4 5.0 3.08 3.0 0.0 \n",
"9 143.0 510.0 0.0 1.0 33.0 141.4 130.0 24.04 186.4 114.0 15.84 210.0 111.0 9.45 7.7 6.0 2.08 1.0 0.0 \n",
" \n",
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"training.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Зависимая переменная для бинарной классификации должна быть не количественной переменной, а категориальной, преобразуем с помощью метода asfactor()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"training['Churn'] = training['Churn'].asfactor()\n",
"validation['Churn'] = validation['Churn'].asfactor()\n",
"\n",
"training['International plan'] = training['International plan'].asfactor()\n",
"training['Voice mail plan'] = training['Voice mail plan'].asfactor()\n",
"\n",
"validation['International plan'] = validation['International plan'].asfactor()\n",
"validation['Voice mail plan'] = validation['Voice mail plan'].asfactor()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Random Forest"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### sklearn RandomForestClassifier"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.metrics import roc_auc_score"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC для sklearn RandomForestClassifier: 0.9404\n"
]
}
],
"source": [
"forest = RandomForestClassifier(n_estimators=800, random_state=152, n_jobs=-1)\n",
"forest.fit(X_train, y_train)\n",
"print('AUC for sklearn RandomForestClassifier: {:.4f}'.format(roc_auc_score(y_test, forest.predict_proba(X_test)[:, 1])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### H2ORandomForestEstimator"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from h2o.estimators import H2ORandomForestEstimator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"необходимо задать список зависимых переменных и предикторов (это будут X и y в коде)"
]
},
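Unlike sklearn's `fit(X, y)` on arrays, H2O estimators take column names. Selecting the predictors is just list manipulation over the frame's columns; a minimal sketch (the column names below are a shortened illustrative subset, `ntrees=800` mirrors the trees used above, while `seed` and the commented training call are assumptions that require the `training`/`validation` frames and a running cluster):

```python
# Shortened stand-in for the column list; in the notebook this would be
# columns = training.columns
columns = ['Account length', 'Area code', 'International plan',
           'Voice mail plan', 'Customer service calls', 'Churn']

y = 'Churn'                          # name of the target column
X = [c for c in columns if c != y]   # every other column is a predictor
print(X)

# With a live H2O cluster the model is then trained like this (sketch):
# from h2o.estimators import H2ORandomForestEstimator
# drf = H2ORandomForestEstimator(model_id='tutorial1', ntrees=800, seed=152)
# drf.train(x=X, y=y, training_frame=training, validation_frame=validation)
```

Passing a `validation_frame` is what makes H2O report the validation metrics shown below alongside the training ones.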
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"drf Model Build progress: |███████████████████████████████████████████████| 100%\n",
"Model Details\n",
"=============\n",
"H2ORandomForestEstimator : Distributed Random Forest\n",
"Model Key: tutorial1\n",
"\n",
"\n",
"ModelMetricsBinomial: drf\n",
"** Reported on train data. **\n",
"\n",
"MSE: 0.04684101277272021\n",
"RMSE: 0.21642784657414169\n",
"LogLoss: 0.20074033630931348\n",
"Mean Per-Class Error: 0.10923157521914939\n",
"AUC: 0.8994753696762197\n",
"Gini: 0.7989507393524393\n",
"Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.4256055363321799: \n"
]
},
{
"data": {
"text/html": [
" \n",
"0 \n",
"1 \n",
"Error \n",
"Rate \n",
"0 \n",
"1972.0 \n",
"21.0 \n",
"0.0105 \n",
" (21.0/1993.0) \n",
"1 \n",
"77.0 \n",
"263.0 \n",
"0.2265 \n",
" (77.0/340.0) \n",
"Total \n",
"2049.0 \n",
"284.0 \n",
"0.042 \n",
" (98.0/2333.0)
"
],
"text/plain": [
" 0 1 Error Rate\n",
"----- ---- --- ------- -------------\n",
"0 1972 21 0.0105 (21.0/1993.0)\n",
"1 77 263 0.2265 (77.0/340.0)\n",
"Total 2049 284 0.042 (98.0/2333.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum Metrics: Maximum metrics at their respective thresholds\n",
"\n"
]
},
{
"data": {
"text/html": [
"metric \n",
"threshold \n",
"value \n",
"idx \n",
"max f1 \n",
"0.4256055 \n",
"0.8429487 \n",
"146.0 \n",
"max f2 \n",
"0.2397901 \n",
"0.8105802 \n",
"204.0 \n",
"max f0point5 \n",
"0.4746598 \n",
"0.8958924 \n",
"137.0 \n",
"max accuracy \n",
"0.4256055 \n",
"0.9579940 \n",
"146.0 \n",
"max precision \n",
"0.9790575 \n",
"1.0 \n",
"0.0 \n",
"max recall \n",
"0.0034502 \n",
"1.0 \n",
"397.0 \n",
"max specificity \n",
"0.9790575 \n",
"1.0 \n",
"0.0 \n",
"max absolute_mcc \n",
"0.4256055 \n",
"0.8233476 \n",
"146.0 \n",
"max min_per_class_accuracy \n",
"0.1664149 \n",
"0.8470588 \n",
"235.0 \n",
"max mean_per_class_accuracy \n",
"0.2397901 \n",
"0.8907684 \n",
"204.0
"
],
"text/plain": [
"metric threshold value idx\n",
"--------------------------- ----------- -------- -----\n",
"max f1 0.425606 0.842949 146\n",
"max f2 0.23979 0.81058 204\n",
"max f0point5 0.47466 0.895892 137\n",
"max accuracy 0.425606 0.957994 146\n",
"max precision 0.979058 1 0\n",
"max recall 0.00345018 1 397\n",
"max specificity 0.979058 1 0\n",
"max absolute_mcc 0.425606 0.823348 146\n",
"max min_per_class_accuracy 0.166415 0.847059 235\n",
"max mean_per_class_accuracy 0.23979 0.890768 204"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gains/Lift Table: Avg response rate: 14,57 %\n",
"\n"
]
},
{
"data": {
"text/html": [
" \n",
"group \n",
"cumulative_data_fraction \n",
"lower_threshold \n",
"lift \n",
"cumulative_lift \n",
"response_rate \n",
"cumulative_response_rate \n",
"capture_rate \n",
"cumulative_capture_rate \n",
"gain \n",
"cumulative_gain \n",
" \n",
"1 \n",
"0.0102872 \n",
"0.9293375 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0705882 \n",
"0.0705882 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"2 \n",
"0.0201457 \n",
"0.8840787 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.1382353 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"3 \n",
"0.0300043 \n",
"0.8429612 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.2058824 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"4 \n",
"0.0402915 \n",
"0.8114106 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0705882 \n",
"0.2764706 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"5 \n",
"0.0501500 \n",
"0.7814183 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.3441176 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"6 \n",
"0.1003000 \n",
"0.5665646 \n",
"6.3925842 \n",
"6.6271745 \n",
"0.9316239 \n",
"0.9658120 \n",
"0.3205882 \n",
"0.6647059 \n",
"539.2584213 \n",
"562.7174460 \n",
" \n",
"7 \n",
"0.1500214 \n",
"0.2865951 \n",
"2.9576572 \n",
"5.4109916 \n",
"0.4310345 \n",
"0.7885714 \n",
"0.1470588 \n",
"0.8117647 \n",
"195.7657201 \n",
"441.0991597 \n",
" \n",
"8 \n",
"0.2001715 \n",
"0.1805324 \n",
"0.6451232 \n",
"4.2169732 \n",
"0.0940171 \n",
"0.6145610 \n",
"0.0323529 \n",
"0.8441176 \n",
"-35.4876823 \n",
"321.6973170 \n",
" \n",
"9 \n",
"0.3000429 \n",
"0.1057973 \n",
"0.1472482 \n",
"2.8623361 \n",
"0.0214592 \n",
"0.4171429 \n",
"0.0147059 \n",
"0.8588235 \n",
"-85.2751830 \n",
"186.2336134 \n",
" \n",
"10 \n",
"0.3999143 \n",
"0.0762791 \n",
"0.0883489 \n",
"2.1695826 \n",
"0.0128755 \n",
"0.3161844 \n",
"0.0088235 \n",
"0.8676471 \n",
"-91.1651098 \n",
"116.9582624 \n",
" \n",
"11 \n",
"0.5002143 \n",
"0.0559567 \n",
"0.1172951 \n",
"1.7580700 \n",
"0.0170940 \n",
"0.2562125 \n",
"0.0117647 \n",
"0.8794118 \n",
"-88.2704877 \n",
"75.8069963 \n",
" \n",
"12 \n",
"0.6000857 \n",
"0.0424644 \n",
"0.2355971 \n",
"1.5046870 \n",
"0.0343348 \n",
"0.2192857 \n",
"0.0235294 \n",
"0.9029412 \n",
"-76.4402929 \n",
"50.4686975 \n",
" \n",
"13 \n",
"0.6999571 \n",
"0.0315723 \n",
"0.2355971 \n",
"1.3236105 \n",
"0.0343348 \n",
"0.1928965 \n",
"0.0235294 \n",
"0.9264706 \n",
"-76.4402929 \n",
"32.3610461 \n",
" \n",
"14 \n",
"0.7998285 \n",
"0.0229590 \n",
"0.2061474 \n",
"1.1840773 \n",
"0.0300429 \n",
"0.1725616 \n",
"0.0205882 \n",
"0.9470588 \n",
"-79.3852562 \n",
"18.4077297 \n",
" \n",
"15 \n",
"0.8997000 \n",
"0.0136415 \n",
"0.2650467 \n",
"1.0820601 \n",
"0.0386266 \n",
"0.1576941 \n",
"0.0264706 \n",
"0.9735294 \n",
"-73.4953295 \n",
"8.2060085 \n",
" \n",
"16 \n",
"1.0 \n",
"0.0 \n",
"0.2639140 \n",
"1.0 \n",
"0.0384615 \n",
"0.1457351 \n",
"0.0264706 \n",
"1.0 \n",
"-73.6085973 \n",
"0.0
"
],
"text/plain": [
" group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n",
"-- ------- -------------------------- ----------------- --------- ----------------- --------------- -------------------------- -------------- ------------------------- -------- -----------------\n",
" 1 0.0102872 0.929337 6.86176 6.86176 1 1 0.0705882 0.0705882 586.176 586.176\n",
" 2 0.0201457 0.884079 6.86176 6.86176 1 1 0.0676471 0.138235 586.176 586.176\n",
" 3 0.0300043 0.842961 6.86176 6.86176 1 1 0.0676471 0.205882 586.176 586.176\n",
" 4 0.0402915 0.811411 6.86176 6.86176 1 1 0.0705882 0.276471 586.176 586.176\n",
" 5 0.05015 0.781418 6.86176 6.86176 1 1 0.0676471 0.344118 586.176 586.176\n",
" 6 0.1003 0.566565 6.39258 6.62717 0.931624 0.965812 0.320588 0.664706 539.258 562.717\n",
" 7 0.150021 0.286595 2.95766 5.41099 0.431034 0.788571 0.147059 0.811765 195.766 441.099\n",
" 8 0.200171 0.180532 0.645123 4.21697 0.0940171 0.614561 0.0323529 0.844118 -35.4877 321.697\n",
" 9 0.300043 0.105797 0.147248 2.86234 0.0214592 0.417143 0.0147059 0.858824 -85.2752 186.234\n",
" 10 0.399914 0.0762791 0.0883489 2.16958 0.0128755 0.316184 0.00882353 0.867647 -91.1651 116.958\n",
" 11 0.500214 0.0559567 0.117295 1.75807 0.017094 0.256213 0.0117647 0.879412 -88.2705 75.807\n",
" 12 0.600086 0.0424644 0.235597 1.50469 0.0343348 0.219286 0.0235294 0.902941 -76.4403 50.4687\n",
" 13 0.699957 0.0315723 0.235597 1.32361 0.0343348 0.192897 0.0235294 0.926471 -76.4403 32.361\n",
" 14 0.799829 0.022959 0.206147 1.18408 0.0300429 0.172562 0.0205882 0.947059 -79.3853 18.4077\n",
" 15 0.8997 0.0136415 0.265047 1.08206 0.0386266 0.157694 0.0264706 0.973529 -73.4953 8.20601\n",
" 16 1 0 0.263914 1 0.0384615 0.145735 0.0264706 1 -73.6086 0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"ModelMetricsBinomial: drf\n",
"** Reported on validation data. **\n",
"\n",
"MSE: 0.04181281142387741\n",
"RMSE: 0.20448181196350304\n",
"LogLoss: 0.17373883641178992\n",
"Mean Per-Class Error: 0.07810625780287395\n",
"AUC: 0.9416814224282135\n",
"Gini: 0.883362844856427\n",
"Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.29251602564007045: \n"
]
},
{
"data": {
"text/html": [
" \n",
"0 \n",
"1 \n",
"Error \n",
"Rate \n",
"0 \n",
"838.0 \n",
"19.0 \n",
"0.0222 \n",
" (19.0/857.0) \n",
"1 \n",
"21.0 \n",
"122.0 \n",
"0.1469 \n",
" (21.0/143.0) \n",
"Total \n",
"859.0 \n",
"141.0 \n",
"0.04 \n",
" (40.0/1000.0)
"
],
"text/plain": [
" 0 1 Error Rate\n",
"----- --- --- ------- -------------\n",
"0 838 19 0.0222 (19.0/857.0)\n",
"1 21 122 0.1469 (21.0/143.0)\n",
"Total 859 141 0.04 (40.0/1000.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum Metrics: Maximum metrics at their respective thresholds\n",
"\n"
]
},
{
"data": {
"text/html": [
"metric \n",
"threshold \n",
"value \n",
"idx \n",
"max f1 \n",
"0.2925160 \n",
"0.8591549 \n",
"120.0 \n",
"max f2 \n",
"0.2635938 \n",
"0.8644537 \n",
"129.0 \n",
"max f0point5 \n",
"0.58375 \n",
"0.8959538 \n",
"81.0 \n",
"max accuracy \n",
"0.3804167 \n",
"0.961 \n",
"108.0 \n",
"max precision \n",
"0.9825 \n",
"1.0 \n",
"0.0 \n",
"max recall \n",
"0.0138066 \n",
"1.0 \n",
"381.0 \n",
"max specificity \n",
"0.9825 \n",
"1.0 \n",
"0.0 \n",
"max absolute_mcc \n",
"0.2925160 \n",
"0.8358744 \n",
"120.0 \n",
"max min_per_class_accuracy \n",
"0.1825 \n",
"0.9090909 \n",
"164.0 \n",
"max mean_per_class_accuracy \n",
"0.2635938 \n",
"0.9218937 \n",
"129.0
"
],
"text/plain": [
"metric threshold value idx\n",
"--------------------------- ----------- -------- -----\n",
"max f1 0.292516 0.859155 120\n",
"max f2 0.263594 0.864454 129\n",
"max f0point5 0.58375 0.895954 81\n",
"max accuracy 0.380417 0.961 108\n",
"max precision 0.9825 1 0\n",
"max recall 0.0138066 1 381\n",
"max specificity 0.9825 1 0\n",
"max absolute_mcc 0.292516 0.835874 120\n",
"max min_per_class_accuracy 0.1825 0.909091 164\n",
"max mean_per_class_accuracy 0.263594 0.921894 129"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gains/Lift Table: Avg response rate: 14,30 %\n",
"\n"
]
},
{
"data": {
"text/html": [
" \n",
"group \n",
"cumulative_data_fraction \n",
"lower_threshold \n",
"lift \n",
"cumulative_lift \n",
"response_rate \n",
"cumulative_response_rate \n",
"capture_rate \n",
"cumulative_capture_rate \n",
"gain \n",
"cumulative_gain \n",
" \n",
"1 \n",
"0.011 \n",
"0.9075 \n",
"6.9930070 \n",
"6.9930070 \n",
"1.0 \n",
"1.0 \n",
"0.0769231 \n",
"0.0769231 \n",
"599.3006993 \n",
"599.3006993 \n",
" \n",
"2 \n",
"0.02 \n",
"0.8775500 \n",
"6.9930070 \n",
"6.9930070 \n",
"1.0 \n",
"1.0 \n",
"0.0629371 \n",
"0.1398601 \n",
"599.3006993 \n",
"599.3006993 \n",
" \n",
"3 \n",
"0.03 \n",
"0.8400375 \n",
"6.9930070 \n",
"6.9930070 \n",
"1.0 \n",
"1.0 \n",
"0.0699301 \n",
"0.2097902 \n",
"599.3006993 \n",
"599.3006993 \n",
" \n",
"4 \n",
"0.04 \n",
"0.7927500 \n",
"6.9930070 \n",
"6.9930070 \n",
"1.0 \n",
"1.0 \n",
"0.0699301 \n",
"0.2797203 \n",
"599.3006993 \n",
"599.3006993 \n",
" \n",
"5 \n",
"0.05 \n",
"0.7630625 \n",
"6.9930070 \n",
"6.9930070 \n",
"1.0 \n",
"1.0 \n",
"0.0699301 \n",
"0.3496503 \n",
"599.3006993 \n",
"599.3006993 \n",
" \n",
"6 \n",
"0.1 \n",
"0.5127500 \n",
"6.4335664 \n",
"6.7132867 \n",
"0.92 \n",
"0.96 \n",
"0.3216783 \n",
"0.6713287 \n",
"543.3566434 \n",
"571.3286713 \n",
" \n",
"7 \n",
"0.15 \n",
"0.2634844 \n",
"3.9160839 \n",
"5.7808858 \n",
"0.56 \n",
"0.8266667 \n",
"0.1958042 \n",
"0.8671329 \n",
"291.6083916 \n",
"478.0885781 \n",
" \n",
"8 \n",
"0.2 \n",
"0.1755025 \n",
"0.8391608 \n",
"4.5454545 \n",
"0.12 \n",
"0.65 \n",
"0.0419580 \n",
"0.9090909 \n",
"-16.0839161 \n",
"354.5454545 \n",
" \n",
"9 \n",
"0.3 \n",
"0.1100417 \n",
"0.1398601 \n",
"3.0769231 \n",
"0.02 \n",
"0.44 \n",
"0.0139860 \n",
"0.9230769 \n",
"-86.0139860 \n",
"207.6923077 \n",
" \n",
"10 \n",
"0.4 \n",
"0.0825485 \n",
"0.0 \n",
"2.3076923 \n",
"0.0 \n",
"0.33 \n",
"0.0 \n",
"0.9230769 \n",
"-100.0 \n",
"130.7692308 \n",
" \n",
"11 \n",
"0.5 \n",
"0.0586597 \n",
"0.0 \n",
"1.8461538 \n",
"0.0 \n",
"0.264 \n",
"0.0 \n",
"0.9230769 \n",
"-100.0 \n",
"84.6153846 \n",
" \n",
"12 \n",
"0.6 \n",
"0.0426986 \n",
"0.1398601 \n",
"1.5617716 \n",
"0.02 \n",
"0.2233333 \n",
"0.0139860 \n",
"0.9370629 \n",
"-86.0139860 \n",
"56.1771562 \n",
" \n",
"13 \n",
"0.7 \n",
"0.0325182 \n",
"0.3496503 \n",
"1.3886114 \n",
"0.05 \n",
"0.1985714 \n",
"0.0349650 \n",
"0.9720280 \n",
"-65.0349650 \n",
"38.8611389 \n",
" \n",
"14 \n",
"0.8 \n",
"0.0238202 \n",
"0.0699301 \n",
"1.2237762 \n",
"0.01 \n",
"0.175 \n",
"0.0069930 \n",
"0.9790210 \n",
"-93.0069930 \n",
"22.3776224 \n",
" \n",
"15 \n",
"0.9 \n",
"0.0162297 \n",
"0.1398601 \n",
"1.1033411 \n",
"0.02 \n",
"0.1577778 \n",
"0.0139860 \n",
"0.9930070 \n",
"-86.0139860 \n",
"10.3341103 \n",
" \n",
"16 \n",
"1.0 \n",
"0.0012500 \n",
"0.0699301 \n",
"1.0 \n",
"0.01 \n",
"0.143 \n",
"0.0069930 \n",
"1.0 \n",
"-93.0069930 \n",
"0.0
"
],
"text/plain": [
" group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n",
"-- ------- -------------------------- ----------------- --------- ----------------- --------------- -------------------------- -------------- ------------------------- -------- -----------------\n",
" 1 0.011 0.9075 6.99301 6.99301 1 1 0.0769231 0.0769231 599.301 599.301\n",
" 2 0.02 0.87755 6.99301 6.99301 1 1 0.0629371 0.13986 599.301 599.301\n",
" 3 0.03 0.840037 6.99301 6.99301 1 1 0.0699301 0.20979 599.301 599.301\n",
" 4 0.04 0.79275 6.99301 6.99301 1 1 0.0699301 0.27972 599.301 599.301\n",
" 5 0.05 0.763062 6.99301 6.99301 1 1 0.0699301 0.34965 599.301 599.301\n",
" 6 0.1 0.51275 6.43357 6.71329 0.92 0.96 0.321678 0.671329 543.357 571.329\n",
" 7 0.15 0.263484 3.91608 5.78089 0.56 0.826667 0.195804 0.867133 291.608 478.089\n",
" 8 0.2 0.175502 0.839161 4.54545 0.12 0.65 0.041958 0.909091 -16.0839 354.545\n",
" 9 0.3 0.110042 0.13986 3.07692 0.02 0.44 0.013986 0.923077 -86.014 207.692\n",
" 10 0.4 0.0825485 0 2.30769 0 0.33 0 0.923077 -100 130.769\n",
" 11 0.5 0.0586597 0 1.84615 0 0.264 0 0.923077 -100 84.6154\n",
" 12 0.6 0.0426986 0.13986 1.56177 0.02 0.223333 0.013986 0.937063 -86.014 56.1772\n",
" 13 0.7 0.0325182 0.34965 1.38861 0.05 0.198571 0.034965 0.972028 -65.035 38.8611\n",
" 14 0.8 0.0238202 0.0699301 1.22378 0.01 0.175 0.00699301 0.979021 -93.007 22.3776\n",
" 15 0.9 0.0162297 0.13986 1.10334 0.02 0.157778 0.013986 0.993007 -86.014 10.3341\n",
" 16 1 0.00125 0.0699301 1 0.01 0.143 0.00699301 1 -93.007 0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Scoring History: \n"
]
},
{
"data": {
"text/html": [
" \n",
"timestamp \n",
"duration \n",
"number_of_trees \n",
"training_rmse \n",
"training_logloss \n",
"training_auc \n",
"training_lift \n",
"training_classification_error \n",
"validation_rmse \n",
"validation_logloss \n",
"validation_auc \n",
"validation_lift \n",
"validation_classification_error \n",
" \n",
"2018-04-20 23:38:58 \n",
" 0.005 sec \n",
"0.0 \n",
"nan \n",
"nan \n",
"nan \n",
"nan \n",
"nan \n",
"nan \n",
"nan \n",
"nan \n",
"nan \n",
"nan \n",
" \n",
"2018-04-20 23:38:58 \n",
" 0.020 sec \n",
"1.0 \n",
"0.3299550 \n",
"3.7593817 \n",
"0.8308720 \n",
"4.0336049 \n",
"0.1088271 \n",
"0.3270305 \n",
"3.6650092 \n",
"0.7946936 \n",
"4.3246227 \n",
"0.107 \n",
" \n",
"2018-04-20 23:38:58 \n",
" 0.031 sec \n",
"2.0 \n",
"0.3199525 \n",
"3.3116478 \n",
"0.8127797 \n",
"4.3355429 \n",
"0.1018051 \n",
"0.2733811 \n",
"1.5744127 \n",
"0.8568025 \n",
"6.0808756 \n",
"0.075 \n",
" \n",
"2018-04-20 23:38:58 \n",
" 0.041 sec \n",
"3.0 \n",
"0.3209383 \n",
"3.0195030 \n",
"0.8090340 \n",
"4.5637209 \n",
"0.1284884 \n",
"0.2540667 \n",
"1.0172508 \n",
"0.8787240 \n",
"6.7476383 \n",
"0.069 \n",
" \n",
"2018-04-20 23:38:58 \n",
" 0.052 sec \n",
"4.0 \n",
"0.3150823 \n",
"2.7257284 \n",
"0.8089239 \n",
"4.7902886 \n",
"0.1039834 \n",
"0.2466377 \n",
"0.8890110 \n",
"0.8823225 \n",
"6.9930070 \n",
"0.062 \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
"--- \n",
" \n",
"2018-04-20 23:39:01 \n",
" 3.860 sec \n",
"117.0 \n",
"0.2207916 \n",
"0.3406937 \n",
"0.9061982 \n",
"6.8617647 \n",
"0.0462923 \n",
"0.2058721 \n",
"0.1769192 \n",
"0.9388051 \n",
"6.9930070 \n",
"0.041 \n",
" \n",
"2018-04-20 23:39:01 \n",
" 3.921 sec \n",
"118.0 \n",
"0.2204195 \n",
"0.3402015 \n",
"0.9067309 \n",
"6.8617647 \n",
"0.0458637 \n",
"0.2059082 \n",
"0.1769675 \n",
"0.9385603 \n",
"6.9930070 \n",
"0.04 \n",
" \n",
"2018-04-20 23:39:02 \n",
" 3.988 sec \n",
"119.0 \n",
"0.2202526 \n",
"0.3402213 \n",
"0.9065265 \n",
"6.8617647 \n",
"0.0462923 \n",
"0.2058509 \n",
"0.1771015 \n",
"0.9384256 \n",
"6.9930070 \n",
"0.04 \n",
" \n",
"2018-04-20 23:39:06 \n",
" 7.993 sec \n",
"780.0 \n",
"0.2166458 \n",
"0.2007882 \n",
"0.8997262 \n",
"6.8617647 \n",
"0.0424346 \n",
"0.2045416 \n",
"0.1739691 \n",
"0.9412571 \n",
"6.9930070 \n",
"0.039 \n",
" \n",
"2018-04-20 23:39:06 \n",
" 8.543 sec \n",
"800.0 \n",
"0.2164278 \n",
"0.2007403 \n",
"0.8994754 \n",
"6.8617647 \n",
"0.0420060 \n",
"0.2044818 \n",
"0.1737388 \n",
"0.9416814 \n",
"6.9930070 \n",
"0.04
"
],
"text/plain": [
" timestamp duration number_of_trees training_rmse training_logloss training_auc training_lift training_classification_error validation_rmse validation_logloss validation_auc validation_lift validation_classification_error\n",
"--- ------------------- ---------- ----------------- ------------------- ------------------- ------------------ ----------------- ------------------------------- ------------------- -------------------- ------------------ ------------------ ---------------------------------\n",
" 2018-04-20 23:38:58 0.005 sec 0.0 nan nan nan nan nan nan nan nan nan nan\n",
" 2018-04-20 23:38:58 0.020 sec 1.0 0.3299549533419618 3.759381658451759 0.8308720112517581 4.033604928457869 0.10882708585247884 0.32703054843001095 3.6650091908267997 0.7946936377508139 4.324622745675378 0.107\n",
" 2018-04-20 23:38:58 0.031 sec 2.0 0.31995248719448316 3.3116478265664515 0.8127796965211596 4.335542873865965 0.10180505415162455 0.273381134867804 1.5744126828608989 0.856802474072019 6.0808756460930375 0.075\n",
" 2018-04-20 23:38:58 0.041 sec 3.0 0.3209383181496875 3.0195030090134423 0.8090340035860656 4.563720865704773 0.12848837209302325 0.2540666971937434 1.0172508387501398 0.8787239598208093 6.747638326585696 0.069\n",
" 2018-04-20 23:38:58 0.052 sec 4.0 0.31508233994601303 2.725728410467888 0.8089238889980269 4.790288568257491 0.10398344542162442 0.24663767870350212 0.8890110203449179 0.8823224616690194 6.993006993006993 0.062\n",
"--- --- --- --- --- --- --- --- --- --- --- --- --- ---\n",
" 2018-04-20 23:39:01 3.860 sec 117.0 0.22079158962413375 0.3406936623575952 0.9061981641628051 6.861764705882353 0.04629232747535362 0.20587207920559789 0.17691917336902896 0.9388050689100865 6.993006993006993 0.041\n",
" 2018-04-20 23:39:01 3.921 sec 118.0 0.22041952436491502 0.3402014866020256 0.9067309111301319 6.861764705882353 0.04586369481354479 0.20590815841200824 0.17696748670339232 0.9385602728659904 6.993006993006993 0.04\n",
" 2018-04-20 23:39:02 3.988 sec 119.0 0.2202526177493472 0.34022129663205214 0.9065265192880966 6.861764705882353 0.04629232747535362 0.20585093865188814 0.17710146258395315 0.9384256350417378 6.993006993006993 0.04\n",
" 2018-04-20 23:39:06 7.993 sec 780.0 0.21664584404283624 0.20078819748633736 0.8997262477494761 6.861764705882353 0.04243463351907415 0.2045416002818827 0.17396907936194805 0.9412571092851139 6.993006993006993 0.039\n",
" 2018-04-20 23:39:06 8.543 sec 800.0 0.21642784657414169 0.20074033630931348 0.8994753696762197 6.861764705882353 0.04200600085726532 0.20448181196350304 0.17373883641178992 0.9416814224282135 6.993006993006993 0.04"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"See the whole table with table.as_data_frame()\n",
"Variable Importances: \n"
]
},
{
"data": {
"text/html": [
"variable \n",
"relative_importance \n",
"scaled_importance \n",
"percentage \n",
"Total day minutes \n",
"20748.4902344 \n",
"1.0 \n",
"0.1363729 \n",
"Total day charge \n",
"19548.7832031 \n",
"0.9421786 \n",
"0.1284876 \n",
"Customer service calls \n",
"19157.3359375 \n",
"0.9233123 \n",
"0.1259148 \n",
"International plan \n",
"14547.7001953 \n",
"0.7011450 \n",
"0.0956172 \n",
"Total eve charge \n",
"9294.6621094 \n",
"0.4479681 \n",
"0.0610907 \n",
"Total eve minutes \n",
"9159.7636719 \n",
"0.4414665 \n",
"0.0602041 \n",
"Total intl calls \n",
"8102.6518555 \n",
"0.3905177 \n",
"0.0532560 \n",
"Total intl minutes \n",
"6614.3745117 \n",
"0.3187882 \n",
"0.0434741 \n",
"Total intl charge \n",
"6516.2104492 \n",
"0.3140571 \n",
"0.0428289 \n",
"Total night charge \n",
"5165.4580078 \n",
"0.2489558 \n",
"0.0339508 \n",
"Total night minutes \n",
"5084.6352539 \n",
"0.2450605 \n",
"0.0334196 \n",
"Total night calls \n",
"5060.0112305 \n",
"0.2438737 \n",
"0.0332578 \n",
"Total day calls \n",
"4819.5454102 \n",
"0.2322841 \n",
"0.0316773 \n",
"Number vmail messages \n",
"4619.8515625 \n",
"0.2226596 \n",
"0.0303647 \n",
"Total eve calls \n",
"4476.2211914 \n",
"0.2157372 \n",
"0.0294207 \n",
"Account length \n",
"4363.5600586 \n",
"0.2103074 \n",
"0.0286802 \n",
"Voice mail plan \n",
"3168.4050293 \n",
"0.1527053 \n",
"0.0208249 \n",
"Area code \n",
"1697.5781250 \n",
"0.0818169 \n",
"0.0111576
"
],
"text/plain": [
"variable relative_importance scaled_importance percentage\n",
"---------------------- --------------------- ------------------- ------------\n",
"Total day minutes 20748.5 1 0.136373\n",
"Total day charge 19548.8 0.942179 0.128488\n",
"Customer service calls 19157.3 0.923312 0.125915\n",
"International plan 14547.7 0.701145 0.0956172\n",
"Total eve charge 9294.66 0.447968 0.0610907\n",
"Total eve minutes 9159.76 0.441467 0.0602041\n",
"Total intl calls 8102.65 0.390518 0.053256\n",
"Total intl minutes 6614.37 0.318788 0.0434741\n",
"Total intl charge 6516.21 0.314057 0.0428289\n",
"Total night charge 5165.46 0.248956 0.0339508\n",
"Total night minutes 5084.64 0.24506 0.0334196\n",
"Total night calls 5060.01 0.243874 0.0332578\n",
"Total day calls 4819.55 0.232284 0.0316773\n",
"Number vmail messages 4619.85 0.22266 0.0303647\n",
"Total eve calls 4476.22 0.215737 0.0294207\n",
"Account length 4363.56 0.210307 0.0286802\n",
"Voice mail plan 3168.41 0.152705 0.0208249\n",
"Area code 1697.58 0.0818169 0.0111576"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": []
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = training.columns\n",
"X.remove('Churn')\n",
"y = 'Churn'\n",
"\n",
"rf1 = H2ORandomForestEstimator(model_id='tutorial1', ntrees=800, seed=152)\n",
"rf1.train(X, y, training_frame=training, validation_frame=validation)\n",
"rf1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Получился большой вывод информации о модели. \n",
"Сразу следует обратить внимание на AUC при тех же заданных параметрах (ntrees=800) "
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9416814224282135\n"
]
}
],
"source": [
"print(rf1.auc(valid=True))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Вся информация идет сначала об обучающей, потом о валидационной выборке\n",
"* Отображаются все метрики (MSE, RMSE, LogLoss, AUC, Gini и т.д.)\n",
"* Строится матрица ошибок (confusion matrix), которая приводится для порогового значения спрогнозированной вероятности события, оптимального с точки зрения F1-меры _(для справки : F1- мера = 2 x точность x полнота/(точность + полнота))_\n",
"* _Maximum Metrics:_ Рассчитываются на ней различные метрики и соответствующие пороговые значения\n",
"* _Gains/Lift Table:_ Таблица выигрышей создается путем разбиения данных на группы по квантильным пороговым значениям спрогнозированной верроятности положительного класса\n",
"* _Variable Importances:_ Информация о важностях предикторов. Информация выводится также в отмасштабированном и процентном видах. На сайте с документацией указано, что важность рассчитывается, как относительное влияние каждой переменной: была ли переменная выбрана при построении дерева и на как изменилась среднеквадратичная ошибка (рассчитывается на всех деревьях)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# можно нарисовать графиком \n",
"import matplotlib.pyplot as plt\n",
"plt.rcdefaults()\n",
"fig, ax = plt.subplots()\n",
"variables = rf1._model_json['output']['variable_importances']['variable']\n",
"y_pos = np.arange(len(variables))\n",
"scaled_importance = rf1._model_json['output']['variable_importances']['scaled_importance']\n",
"ax.barh(y_pos, scaled_importance, align='center', color='green', ecolor='black')\n",
"ax.set_yticks(y_pos)\n",
"ax.set_yticklabels(variables)\n",
"ax.invert_yaxis()\n",
"ax.set_xlabel('Scaled Importance')\n",
"ax.set_title('Variable Importance')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"___Описание параметров вызова функции___ \n",
"1. Параметры, определяющие задачу \n",
" * __model_id:__ идентификатор модели\n",
" * __training_frame:__ датасет для построения модели\n",
" * __validation_frame:__ датасет для валидации\n",
" * __nfolds:__ количество фолдов для кросс-валидации (по умолчанию 0)\n",
" * __y:__ имена зависимой переменной\n",
" * __x:__ список названий предикторов\n",
" * __seed:__ random state \n",
"2. Параметры, задающие сложность дерева\n",
" * __ntrees__: количество деревьев\n",
" * __max_depth:__ максимальная глубина дерева\n",
" * __min_rows:__ минимальное количество наблюдений в терминальном листе\n",
" \n",
"3. Параметры, определяющие формирование подвыборок\n",
" * __mtries:__ количество случайно отбираемых предикторов для разбиения узла. По умолчанию -1: для классификации корень квадратный из р, для регрессии р / 3, где р - количество предикторов\n",
" * __sample_rate:__ какую часть строк отбирать (от 0 до 1). По умолчанию 0.6320000291 \n",
" * __sample_rate_per_class:__ для построения модели из несбалансированного набора данных. Какую часть строк выбирать для каждого дерева (от 0 до 1) \n",
" * __col_sample_rate_per_tree:__ какую часть столбцов выбирать для каждого дерева (от 0 до 1, по умоланчанию 1)\n",
" * __col_sample_rate_change_per_level:__ задает изменение отбора столбцов для каждого уровня дерева (от 0 до 2, по умолчанию 1), например: (factor = col_sample_rate_change_per_level)\n",
" * level 1: col_sample_rate\n",
" * level 2: col_sample_rate * factor\n",
" * level 3: col_sample_rate * factor^2\n",
" * level 4: col_sample_rate * factor^3\n",
" \n",
"4. Параметры, определяющие биннинг переменных\n",
" * __nbins__: (Numerical/real/int only) для каждой переменной строит по крайней мере n интервалов и затем рабивает по наилучшей точке расщепления\n",
" * __nbins_top_level:__ (Numerical/real/int only) определяет максимальное количество интервалов n на вершине дерева. При переходе на увроень ниже делит заданное число на 2 до тех пор, пока не дойдет до уровня nbins \n",
" * __nbins_cats:__ (Categorical/enums only) каждая переменная может быть разбита максимум на n интерваловю. Более высокие значения могут привести к переобучению.\n",
" * __categorical_encoding:__ схема кодировки для категориальных переменных\n",
" * auto: используется схема enum.\n",
" * enum: 1 столбец для каждого категориального признака (как есть)\n",
" * one_hot_explicit: N+1 новых столбцов для каждого признака с N уровнями\n",
" * binary: не более 32 столбцов для признака (используется хеширование)\n",
" * eigen: выполняет one-hot-encoding и оставляет k главных компонент\n",
" * label_encoder: все категории сортируются в лексикографическом порядке и каждому уровню присваивается целое число, начиная с 0 (например, level 0 -> 0, level 1 -> 1, и т.д.)\n",
" * sort_by_response: все категории сортируются по среднему значению переменной и каждому уровню присваивается значение (например, для уровня с мин средним ответом -> 0, вторым наименьшим -> 1, и т.д.). \n",
" * __histogram_type:__ тип гистограммы для поиска оптимальных расщепляющих значений\n",
" * AUTO = UniformAdaptive\n",
" * UniformAdaptive: задаются интервалы одинаковой ширины (max-min)/N\n",
" * QuantilesGlobal: интервалы одинаквого размера (в каждом интервале одинаковое количество наблюдений)\n",
" * Random: задает построение Extremely Randomized Trees (XRT). Случайным образом отбирается N-1 точка расщепления и затем выбирается наилучшее разбиение\n",
" * RoundRobin: все типы гистограмм (по одному на каждое дерево) перебираются по кругу \n",
" \n",
"5. Параметры остановки\n",
" * __stopping_rounds:__ задает количество шагов в течение которого должно произойти заданное улучшение(stopping_tolerance:) для заданной метрики (stopping_metric) \n",
" * __stopping_metric:__ метрика для ранней остановки (по умолчанию логлосс для классификации и дисперсия - для регрессии). Возможные значения: deviance, logloss, mse, rmse, mae, rmsle, auc, lift_top_group, misclassification, mean_per_class_error\n",
" * __stopping_tolerance:__ относительное улучшение для остановки обучения (если меньше, то остановка)\n",
"\n",
" \n",
" \n",
" \n",
"Полный список всех параметров с описанием можно найти на сайте с [официальной документацией](http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/drf.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Отдельно стоит отметить, что биннинг переменных в h2o позволяет добиться улучшения качества и скорости построения леса. \n",
"\n",
"__Количественные переменные__\n",
"В классической реализации случайного леса для количественного предиктора с k категориями может быть рассмотрено k-1 вариантов разбиения. (Все значения сортируются и средние значения по каждой паре рассматриваются в качестве точек расщепления). \n",
"В h2o для каждой переменной определяются типы и количество интервалов для разбиения, т.е. создается гистограмма (тип : histogram_type), и количество регулируется двумя параметрами: nbins и nbins_top_level. Перебор этих параметров позволит улучшить качество леса.\n",
"\n",
"__Качественные переменные__\n",
"В классической реализации случайного леса для категориального предиктора с k категориями может быть рассмотрено ${2^{k-1}} - 1$ разбиения (всеми возможными способами). \n",
"В h2o для каждой качественной переменной также определюется типы и количество интервалов для разбиения. Количество регулируется : nbins_cats и nbins_top_level (при переходе на уровень ниже значение nbins_top_level уменьшается в 2 раза до тех пор пока значение больше nbins_cats). \n",
"Пример как происходит разбиение на бины: \n",
"Если количество категорий меньше значения параметра nbins_cats, каждая категория\n",
"получает свой бин. Допустим, у нас есть переменная Class. Если у нее\n",
"есть уровни A, B, C, D, E, F, G и мы зададим nbins_cats=8, то будут\n",
"сформировано 7 бинов: {A}, {B}, {C}, {D}, {E}, {F} и {G}. Каждая категория\n",
"получает свой бин. Будет рассмотрено ${2^6}-1=63$ точки расщепления. Если\n",
"мы зададим nbins_cats=10, то все равно будут получены те же самые\n",
"бины, потому что у нас всего 7 категорий. Если количество категорий\n",
"больше значения параметра nbins_cats, категории будут сгруппированы\n",
"в бины в лексикографическом порядке. Например, если мы зададим\n",
"nbins_cats=2, то будет сформировано 2 бина: {A, B, C, D} и {E, F, G}. У\n",
"нас будет одна точка расщепления. A, B, C и D попадут в один и тот же\n",
"узел и будут разбиты только на последующем, более нижнем уровне или\n",
"вообще не будут разбиты. \n",
" \n",
"Этот параметр очень важен для настройки, при больших значениях параметра мы можем получить дополнительную случайность в построении "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"И в завершение этой части сделаем подбор параметров и попытаемся улучшить модель по разным параметрам"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"from h2o.grid.grid_search import H2OGridSearch"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"drf Grid Build progress: |████████████████████████████████████████████████| 100%\n",
" max_depth model_ids auc\n",
"0 14 rf_grid_max_depth_model_1 0.9433909148028168\n",
"1 18 rf_grid_max_depth_model_2 0.9421669345823371\n",
"2 10 rf_grid_max_depth_model_0 0.9416895822963501\n",
"3 24 rf_grid_max_depth_model_3 0.9414692658566638\n",
"\n",
"0.9433909148028168\n",
"Hyperparameters: [max_depth]\n",
"[14]\n"
]
}
],
"source": [
"rf_params = {'max_depth': [10, 14, 18, 24]}\n",
"\n",
"rf_grid = H2OGridSearch(model=H2ORandomForestEstimator(ntrees=800, seed=152),\n",
" grid_id='rf_grid_max_depth',\n",
" hyper_params=rf_params)\n",
"rf_grid.train(X, y, training_frame=training, validation_frame=validation)\n",
"\n",
"# модели, отсортированные по AUC\n",
"rf_gridperf = rf_grid.get_grid(sort_by='auc', decreasing=True)\n",
"print(rf_gridperf)\n",
"\n",
"# выберем лучшую модель и выведем AUC на тесте \n",
"best_rf = rf_gridperf.models[0]\n",
"print(best_rf.auc(valid=True))\n",
"print(rf_gridperf.get_hyperparams(0))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"drf Grid Build progress: |████████████████████████████████████████████████| 100%\n",
" mtries model_ids auc\n",
"0 7 rf_grid_mtries1_model_2 0.9422444533296342\n",
"1 5 rf_grid_mtries1_model_1 0.9417263017029645\n",
"2 3 rf_grid_mtries1_model_0 0.9406940783836933\n",
"3 9 rf_grid_mtries1_model_3 0.9390417050860458\n",
"\n",
"0.9422444533296342\n",
"Hyperparameters: [mtries]\n",
"[7]\n"
]
}
],
"source": [
"rf_params1 = {'mtries': [3, 5, 7, 9]}\n",
"\n",
"rf_grid1 = H2OGridSearch(model=H2ORandomForestEstimator(ntrees=800, seed=152, max_depth=14),\n",
" grid_id='rf_grid_mtries1',\n",
" hyper_params=rf_params1)\n",
"rf_grid1.train(X, y, training_frame=training, validation_frame=validation)\n",
"\n",
"# модели, отсортированные по AUC\n",
"rf_gridperf1 = rf_grid1.get_grid(sort_by='auc', decreasing=True)\n",
"print(rf_gridperf1)\n",
"\n",
"# выберем лучшую модель и выведем AUC на тесте \n",
"best_rf1 = rf_gridperf1.models[0]\n",
"print(best_rf1.auc(valid=True))\n",
"print(rf_gridperf1.get_hyperparams(0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"теперь к настройке более специфических параметров"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"drf Grid Build progress: |████████████████████████████████████████████████| 100%\n",
" histogram_type model_ids auc\n",
"0 RoundRobin rf_grid_hist_type_model_3 0.9433868348687484\n",
"1 UniformAdaptive rf_grid_hist_type_model_0 0.9422444533296342\n",
"2 Random rf_grid_hist_type_model_1 0.9391192238333429\n",
"3 QuantilesGlobal rf_grid_hist_type_model_2 0.9380339613711842\n",
"\n",
"AUC for the best model: 0.9433868348687484\n",
"Hyperparameters: [histogram_type]\n",
"['RoundRobin']\n"
]
}
],
"source": [
"rf_params2 = {'histogram_type': ['UniformAdaptive', 'Random', 'QuantilesGlobal', 'RoundRobin']}\n",
"\n",
"rf_grid2 = H2OGridSearch(model=H2ORandomForestEstimator(ntrees=800, seed=152, max_depth=14, mtries=7),\n",
" grid_id='rf_grid_hist_type',\n",
" hyper_params=rf_params2)\n",
"rf_grid2.train(X, y, training_frame=training, validation_frame=validation)\n",
"\n",
"# модели, отсортированные по AUC\n",
"rf_gridperf2 = rf_grid2.get_grid(sort_by='auc', decreasing=True)\n",
"print(rf_gridperf2)\n",
"\n",
"# выберем лучшую модель и выведем AUC на тесте \n",
"best_rf2 = rf_gridperf2.models[0]\n",
"print('AUC for the best model: ', best_rf2.auc(valid=True))\n",
"print(rf_gridperf2.get_hyperparams(0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"По типу гистограммы RoundRobin оказался самым оптимальным"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"drf Grid Build progress: |████████████████████████████████████████████████| 100%\n",
" col_sample_rate_per_tree sample_rate model_ids \\\n",
"0 0.9 0.6 rf_grid3_model_9 \n",
"1 0.7 0.7 rf_grid3_model_12 \n",
"2 0.9 0.7 rf_grid3_model_14 \n",
"3 0.7 0.5 rf_grid3_model_2 \n",
"4 0.8 0.6 rf_grid3_model_8 \n",
"5 0.9 0.8 rf_grid3_model_19 \n",
"6 0.9 0.9 rf_grid3_model_24 \n",
"7 0.9 0.5 rf_grid3_model_4 \n",
"8 0.8 0.5 rf_grid3_model_3 \n",
"9 0.6 0.5 rf_grid3_model_1 \n",
"10 0.7 0.6 rf_grid3_model_7 \n",
"11 0.7 0.8 rf_grid3_model_17 \n",
"12 0.8 0.7 rf_grid3_model_13 \n",
"13 0.6 0.8 rf_grid3_model_16 \n",
"14 0.5 0.7 rf_grid3_model_10 \n",
"15 0.8 0.8 rf_grid3_model_18 \n",
"16 0.5 0.5 rf_grid3_model_0 \n",
"17 0.6 0.7 rf_grid3_model_11 \n",
"18 0.8 0.9 rf_grid3_model_23 \n",
"19 0.7 0.9 rf_grid3_model_22 \n",
"20 0.6 0.6 rf_grid3_model_6 \n",
"21 0.6 0.9 rf_grid3_model_21 \n",
"22 0.5 0.6 rf_grid3_model_5 \n",
"23 0.5 0.8 rf_grid3_model_15 \n",
"24 0.5 0.9 rf_grid3_model_20 \n",
"\n",
" auc \n",
"0 0.9416855023622818 \n",
"1 0.9415875839446435 \n",
"2 0.941518225065483 \n",
"3 0.9405920800319867 \n",
"4 0.9405472007572357 \n",
"5 0.9403554438560272 \n",
"6 0.9402412057021159 \n",
"7 0.9402412057021159 \n",
"8 0.9401188076800678 \n",
"9 0.9398862514381767 \n",
"10 0.9396904146029 \n",
"11 0.9396659349984905 \n",
"12 0.9393640198774388 \n",
"13 0.9392212221850494 \n",
"14 0.9385194735253078 \n",
"15 0.9375321294807876 \n",
"16 0.9374015715906031 \n",
"17 0.9373322127114426 \n",
"18 0.9373281327773744 \n",
"19 0.9370751768651419 \n",
"20 0.9361816713041917 \n",
"21 0.9354268835015626 \n",
"22 0.9351494479849205 \n",
"23 0.935018890094736 \n",
"24 0.9342763420943118 \n",
"\n",
"AUC for the best model: 0.9416855023622818\n",
"Hyperparameters: [col_sample_rate_per_tree, sample_rate]\n",
"[0.9, 0.6]\n"
]
}
],
"source": [
"rf_params3 = {'col_sample_rate_per_tree': [0.5, 0.6, 0.7, 0.8, 0.9], \n",
" 'sample_rate': [0.5, 0.6, 0.7, 0.8, 0.9] }\n",
"\n",
"rf_grid3 = H2OGridSearch(model=H2ORandomForestEstimator(ntrees=800, seed=152, max_depth=14, mtries=7,\n",
" histogram_type='RoundRobin'),\n",
" grid_id='rf_grid3',\n",
" hyper_params=rf_params3)\n",
"rf_grid3.train(X, y, training_frame=training, validation_frame=validation)\n",
"\n",
"# модели, отсортированные по AUC\n",
"rf_gridperf3 = rf_grid3.get_grid(sort_by='auc', decreasing=True)\n",
"print(rf_gridperf3)\n",
"\n",
"# выберем лучшую модель и выведем AUC на тесте \n",
"best_rf3 = rf_gridperf3.models[0]\n",
"print('AUC for the best model: ', best_rf3.auc(valid=True))\n",
"print(rf_gridperf3.get_hyperparams(0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logistic Regression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### sklearn LogisticRegression"
]
},
{
"cell_type": "code",
"execution_count": 213,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression"
]
},
{
"cell_type": "code",
"execution_count": 214,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC для sklearn LogisticRegression: 0.8271\n"
]
}
],
"source": [
"logreg = LogisticRegression().fit(X_train, y_train)\n",
"print('AUC for sklearn LogisticRegression: {:.4f}'.format(roc_auc_score(y_test, logreg.predict_proba(X_test)[:, 1])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"проверим на отмасштабированных данных"
]
},
{
"cell_type": "code",
"execution_count": 215,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC для sklearn LogisticRegression Scaled: 0.8283\n"
]
}
],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"scaler = StandardScaler()\n",
"X_train_scaled = scaler.fit_transform(X_train)\n",
"X_test_scaled = scaler.transform(X_test)\n",
"\n",
"logreg = LogisticRegression().fit(X_train_scaled, y_train)\n",
"print('AUC for sklearn LogisticRegression Scaled: {:.4f}'.format(\n",
" roc_auc_score(y_test, logreg.predict_proba(X_test_scaled)[:, 1])))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"результаты хуже, чем в лесу \n",
"займемся подбором параметров в H2O"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### H2OGeneralizedLinearEstimator"
]
},
{
"cell_type": "code",
"execution_count": 216,
"metadata": {},
"outputs": [],
"source": [
"from h2o.estimators.glm import H2OGeneralizedLinearEstimator"
]
},
{
"cell_type": "code",
"execution_count": 217,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"glm Model Build progress: |███████████████████████████████████████████████| 100%\n"
]
}
],
"source": [
"# создаем экземпляр класса H2OGeneralizedLinearEstimator\n",
"glm_model = H2OGeneralizedLinearEstimator(family= \"binomial\", seed=1000000)\n",
"# обучаем модель\n",
"glm_model.train(X, y, training_frame= training, validation_frame=validation)"
]
},
{
"cell_type": "code",
"execution_count": 218,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"ModelMetricsBinomialGLM: glm\n",
"** Reported on train data. **\n",
"\n",
"MSE: 0.09918009133349871\n",
"RMSE: 0.3149287083349162\n",
"LogLoss: 0.3263867097273659\n",
"Null degrees of freedom: 2332\n",
"Residual degrees of freedom: 2312\n",
"Null deviance: 1937.5065769088035\n",
"Residual deviance: 1522.9203875878893\n",
"AIC: 1564.9203875878893\n",
"AUC: 0.8190335291166141\n",
"Gini: 0.6380670582332282\n",
"Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.24799189912357045: \n"
]
},
{
"data": {
"text/html": [
" \n",
"0 \n",
"1 \n",
"Error \n",
"Rate \n",
"0 \n",
"1773.0 \n",
"220.0 \n",
"0.1104 \n",
" (220.0/1993.0) \n",
"1 \n",
"149.0 \n",
"191.0 \n",
"0.4382 \n",
" (149.0/340.0) \n",
"Total \n",
"1922.0 \n",
"411.0 \n",
"0.1582 \n",
" (369.0/2333.0)
"
],
"text/plain": [
" 0 1 Error Rate\n",
"----- ---- --- ------- --------------\n",
"0 1773 220 0.1104 (220.0/1993.0)\n",
"1 149 191 0.4382 (149.0/340.0)\n",
"Total 1922 411 0.1582 (369.0/2333.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum Metrics: Maximum metrics at their respective thresholds\n",
"\n"
]
},
{
"data": {
"text/html": [
"metric \n",
"threshold \n",
"value \n",
"idx \n",
"max f1 \n",
"0.2479919 \n",
"0.5086551 \n",
"178.0 \n",
"max f2 \n",
"0.1228917 \n",
"0.6103074 \n",
"263.0 \n",
"max f0point5 \n",
"0.3054631 \n",
"0.5099502 \n",
"150.0 \n",
"max accuracy \n",
"0.6080174 \n",
"0.8649807 \n",
"47.0 \n",
"max precision \n",
"0.9870110 \n",
"1.0 \n",
"0.0 \n",
"max recall \n",
"0.0069831 \n",
"1.0 \n",
"396.0 \n",
"max specificity \n",
"0.9870110 \n",
"1.0 \n",
"0.0 \n",
"max absolute_mcc \n",
"0.2479919 \n",
"0.4180577 \n",
"178.0 \n",
"max min_per_class_accuracy \n",
"0.1400732 \n",
"0.75 \n",
"248.0 \n",
"max mean_per_class_accuracy \n",
"0.1557006 \n",
"0.7547866 \n",
"237.0
"
],
"text/plain": [
"metric threshold value idx\n",
"--------------------------- ----------- -------- -----\n",
"max f1 0.247992 0.508655 178\n",
"max f2 0.122892 0.610307 263\n",
"max f0point5 0.305463 0.50995 150\n",
"max accuracy 0.608017 0.864981 47\n",
"max precision 0.987011 1 0\n",
"max recall 0.00698315 1 396\n",
"max specificity 0.987011 1 0\n",
"max absolute_mcc 0.247992 0.418058 178\n",
"max min_per_class_accuracy 0.140073 0.75 248\n",
"max mean_per_class_accuracy 0.155701 0.754787 237"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gains/Lift Table: Avg response rate: 14,57 %\n",
"\n"
]
},
{
"data": {
"text/html": [
" \n",
"group \n",
"cumulative_data_fraction \n",
"lower_threshold \n",
"lift \n",
"cumulative_lift \n",
"response_rate \n",
"cumulative_response_rate \n",
"capture_rate \n",
"cumulative_capture_rate \n",
"gain \n",
"cumulative_gain \n",
" \n",
"1 \n",
"0.0102872 \n",
"0.7456359 \n",
"5.4322304 \n",
"5.4322304 \n",
"0.7916667 \n",
"0.7916667 \n",
"0.0558824 \n",
"0.0558824 \n",
"443.2230392 \n",
"443.2230392 \n",
" \n",
"2 \n",
"0.0201457 \n",
"0.6584274 \n",
"4.1767263 \n",
"4.8178348 \n",
"0.6086957 \n",
"0.7021277 \n",
"0.0411765 \n",
"0.0970588 \n",
"317.6726343 \n",
"381.7834793 \n",
" \n",
"3 \n",
"0.0300043 \n",
"0.5984535 \n",
"4.1767263 \n",
"4.6071849 \n",
"0.6086957 \n",
"0.6714286 \n",
"0.0411765 \n",
"0.1382353 \n",
"317.6726343 \n",
"360.7184874 \n",
" \n",
"4 \n",
"0.0402915 \n",
"0.5507832 \n",
"2.8590686 \n",
"4.1608573 \n",
"0.4166667 \n",
"0.6063830 \n",
"0.0294118 \n",
"0.1676471 \n",
"185.9068627 \n",
"316.0857322 \n",
" \n",
"5 \n",
"0.0501500 \n",
"0.5157454 \n",
"2.9833760 \n",
"3.9293866 \n",
"0.4347826 \n",
"0.5726496 \n",
"0.0294118 \n",
"0.1970588 \n",
"198.3375959 \n",
"292.9386626 \n",
" \n",
"6 \n",
"0.1003000 \n",
"0.3692585 \n",
"3.1083208 \n",
"3.5188537 \n",
"0.4529915 \n",
"0.5128205 \n",
"0.1558824 \n",
"0.3529412 \n",
"210.8320764 \n",
"251.8853695 \n",
" \n",
"7 \n",
"0.1500214 \n",
"0.2860657 \n",
"2.9576572 \n",
"3.3328571 \n",
"0.4310345 \n",
"0.4857143 \n",
"0.1470588 \n",
"0.5 \n",
"195.7657201 \n",
"233.2857143 \n",
" \n",
"8 \n",
"0.2001715 \n",
"0.2245553 \n",
"1.7594268 \n",
"2.9386573 \n",
"0.2564103 \n",
"0.4282655 \n",
"0.0882353 \n",
"0.5882353 \n",
"75.9426848 \n",
"193.8657262 \n",
" \n",
"9 \n",
"0.3000429 \n",
"0.1520522 \n",
"1.4430321 \n",
"2.4408277 \n",
"0.2103004 \n",
"0.3557143 \n",
"0.1441176 \n",
"0.7323529 \n",
"44.3032063 \n",
"144.0827731 \n",
" \n",
"10 \n",
"0.3999143 \n",
"0.1133559 \n",
"0.8540394 \n",
"2.0445558 \n",
"0.1244635 \n",
"0.2979636 \n",
"0.0852941 \n",
"0.8176471 \n",
"-14.5960616 \n",
"104.4555829 \n",
" \n",
"11 \n",
"0.5002143 \n",
"0.0863671 \n",
"0.6157994 \n",
"1.7580700 \n",
"0.0897436 \n",
"0.2562125 \n",
"0.0617647 \n",
"0.8794118 \n",
"-38.4200603 \n",
"75.8069963 \n",
" \n",
"12 \n",
"0.6000857 \n",
"0.0648764 \n",
"0.3239460 \n",
"1.5193908 \n",
"0.0472103 \n",
"0.2214286 \n",
"0.0323529 \n",
"0.9117647 \n",
"-67.6054027 \n",
"51.9390756 \n",
" \n",
"13 \n",
"0.6999571 \n",
"0.0473250 \n",
"0.2944963 \n",
"1.3446202 \n",
"0.0429185 \n",
"0.1959584 \n",
"0.0294118 \n",
"0.9411765 \n",
"-70.5503661 \n",
"34.4620151 \n",
" \n",
"14 \n",
"0.7998285 \n",
"0.0337193 \n",
"0.2061474 \n",
"1.2024636 \n",
"0.0300429 \n",
"0.1752412 \n",
"0.0205882 \n",
"0.9617647 \n",
"-79.3852562 \n",
"20.2463590 \n",
" \n",
"15 \n",
"0.8997000 \n",
"0.0205791 \n",
"0.2650467 \n",
"1.0984054 \n",
"0.0386266 \n",
"0.1600762 \n",
"0.0264706 \n",
"0.9882353 \n",
"-73.4953295 \n",
"9.8405403 \n",
" \n",
"16 \n",
"1.0 \n",
"0.0017895 \n",
"0.1172951 \n",
"1.0 \n",
"0.0170940 \n",
"0.1457351 \n",
"0.0117647 \n",
"1.0 \n",
"-88.2704877 \n",
"0.0
"
],
"text/plain": [
" group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n",
"-- ------- -------------------------- ----------------- -------- ----------------- --------------- -------------------------- -------------- ------------------------- -------- -----------------\n",
" 1 0.0102872 0.745636 5.43223 5.43223 0.791667 0.791667 0.0558824 0.0558824 443.223 443.223\n",
" 2 0.0201457 0.658427 4.17673 4.81783 0.608696 0.702128 0.0411765 0.0970588 317.673 381.783\n",
" 3 0.0300043 0.598454 4.17673 4.60718 0.608696 0.671429 0.0411765 0.138235 317.673 360.718\n",
" 4 0.0402915 0.550783 2.85907 4.16086 0.416667 0.606383 0.0294118 0.167647 185.907 316.086\n",
" 5 0.05015 0.515745 2.98338 3.92939 0.434783 0.57265 0.0294118 0.197059 198.338 292.939\n",
" 6 0.1003 0.369259 3.10832 3.51885 0.452991 0.512821 0.155882 0.352941 210.832 251.885\n",
" 7 0.150021 0.286066 2.95766 3.33286 0.431034 0.485714 0.147059 0.5 195.766 233.286\n",
" 8 0.200171 0.224555 1.75943 2.93866 0.25641 0.428266 0.0882353 0.588235 75.9427 193.866\n",
" 9 0.300043 0.152052 1.44303 2.44083 0.2103 0.355714 0.144118 0.732353 44.3032 144.083\n",
" 10 0.399914 0.113356 0.854039 2.04456 0.124464 0.297964 0.0852941 0.817647 -14.5961 104.456\n",
" 11 0.500214 0.0863671 0.615799 1.75807 0.0897436 0.256213 0.0617647 0.879412 -38.4201 75.807\n",
" 12 0.600086 0.0648764 0.323946 1.51939 0.0472103 0.221429 0.0323529 0.911765 -67.6054 51.9391\n",
" 13 0.699957 0.047325 0.294496 1.34462 0.0429185 0.195958 0.0294118 0.941176 -70.5504 34.462\n",
" 14 0.799829 0.0337193 0.206147 1.20246 0.0300429 0.175241 0.0205882 0.961765 -79.3853 20.2464\n",
" 15 0.8997 0.0205791 0.265047 1.09841 0.0386266 0.160076 0.0264706 0.988235 -73.4953 9.84054\n",
" 16 1 0.00178946 0.117295 1 0.017094 0.145735 0.0117647 1 -88.2705 0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"glm_model.model_performance().show()"
]
},
{
"cell_type": "code",
"execution_count": 219,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.8190335291166141'"
]
},
"execution_count": 219,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"str(glm_model.model_performance().auc())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"___Parameter descriptions___ \n",
"1. Parameters defining the task \n",
" * __model_id:__ model identifier \n",
" * __training_frame:__ dataset used to build the model\n",
" * __validation_frame:__ dataset used for validation\n",
" * __nfolds:__ number of cross-validation folds (0 by default)\n",
" * __y:__ name of the target variable\n",
" * __x:__ list of predictor names\n",
" * __seed:__ random state \n",
" * __family:__ model family (gaussian, binomial, multinomial, ordinal, quasibinomial, poisson, gamma, tweedie) \n",
" * __solver:__ \n",
" * IRLSM: Iteratively Reweighted Least Squares Method, works well with a small number of predictors and with l1-regularization\n",
" * L_BFGS: Limited-memory Broyden-Fletcher-Goldfarb-Shanno algorithm, suited to data with many columns\n",
" * COORDINATE_DESCENT, COORDINATE_DESCENT_NAIVE: experimental\n",
" * AUTO: sets the solver based on the given data and parameters (default)\n",
" * GRADIENT_DESCENT_LH, GRADIENT_DESCENT_SQERR: used only with family=Ordinal \n",
"2. Regularization parameters \n",
"For reference, the ElasticNet formula combining $L_1$ and $L_2$ regularization:\n",
"$$\\large \\begin{array}{rcl}\n",
"L &=& -\\mathcal{L} + \\lambda R\\left(\\textbf W\\right) \\\\\n",
"&=& -\\mathcal{L} + \\lambda \\left(\\alpha \\sum_{k=1}^K\\sum_{i=1}^M \\left|w_{ki}\\right| + \\left(1 - \\alpha\\right) \\sum_{k=1}^K\\sum_{i=1}^M w_{ki}^2 \\right)\n",
"\\end{array}$$ where $\\alpha \\in \\left[0, 1\\right]$\n",
"\n",
" * __alpha:__ mix between $L_1$ and $L_2$ regularization (1 for pure $L_1$, 0 for pure $L_2$)\n",
" * __lambda:__ regularization strength\n",
" * __lambda_search:__ True / False; whether to run a search over $\\lambda$, starting from its maximum value\n",
" * __lambda_min_ratio:__ minimum $\\lambda$ used in the $\\lambda$ search\n",
" * __nlambdas:__ number of steps in the $\\lambda$ search (100 by default)\n",
" \n",
"3. Parameters affecting predictor preprocessing\n",
" * __standardize:__ whether to scale the predictors\n",
" * __missing_values_handling:__ how to treat missing values (skip them or impute with the mean)\n",
" * __remove_collinear_columns:__ whether to automatically drop collinear columns while building the model \n",
" * __interactions:__ list of columns from which all possible pairs will be formed and used in the model \n",
" * __interaction_pairs:__ list of ready-made column pairs for the model \n",
" \n",
"\n",
"The full list of parameters with descriptions is available in the [official documentation](http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/glm.html)\n",
" \n",
" \n",
"Let's use grid search over the parameters with [H2OGridSearch] to find the best model"
]
},
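{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the role of `alpha` concrete, here is a minimal numpy sketch of the ElasticNet penalty above (a simplified form; H2O applies additional internal scaling):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def elastic_net_penalty(w, lam, alpha):\n",
"    # alpha=1 -> pure L1 (lasso), alpha=0 -> pure L2 (ridge)\n",
"    l1 = np.sum(np.abs(w))\n",
"    l2 = np.sum(w ** 2)\n",
"    return lam * (alpha * l1 + (1 - alpha) * l2)\n",
"\n",
"w = np.array([0.5, -2.0, 1.0])\n",
"elastic_net_penalty(w, lam=0.1, alpha=1.0)  # lam * 3.5\n",
"elastic_net_penalty(w, lam=0.1, alpha=0.0)  # lam * 5.25\n",
"```"
]
},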
{
"cell_type": "code",
"execution_count": 223,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"glm Grid Build progress: |████████████████████████████████████████████████| 100%\n",
" alpha model_ids auc\n",
"0 [0.0] gridresults_model_0 0.8284550921657106\n",
"1 [1.0] gridresults_model_20 0.8283938931546866\n",
"2 [0.9500000000000001] gridresults_model_19 0.8283490138799356\n",
"3 [0.9] gridresults_model_18 0.8282755750667069\n",
"4 [0.8500000000000001] gridresults_model_17 0.8282388556600926\n",
"5 [0.8] gridresults_model_16 0.8282347757260242\n",
"6 [0.7000000000000001] gridresults_model_14 0.8282102961216147\n",
"7 [0.75] gridresults_model_15 0.828202136253478\n",
"8 [0.65] gridresults_model_13 0.8281613369127955\n",
"9 [0.6000000000000001] gridresults_model_12 0.8280185392204062\n",
"10 [0.55] gridresults_model_11 0.8278186224510612\n",
"11 [0.5] gridresults_model_10 0.827785982978515\n",
"12 [0.45] gridresults_model_9 0.827753343505969\n",
"13 [0.35000000000000003] gridresults_model_7 0.8275615866047604\n",
"14 [0.4] gridresults_model_8 0.8275493468025557\n",
"15 [0.30000000000000004] gridresults_model_6 0.8274432685167807\n",
"16 [0.25] gridresults_model_5 0.8272963908903233\n",
"17 [0.2] gridresults_model_4 0.8272392718133674\n",
"18 [0.15000000000000002] gridresults_model_3 0.8270801543847051\n",
"19 [0.05] gridresults_model_1 0.8270679145825003\n",
"20 [0.1] gridresults_model_2 0.8270107955055446\n",
"\n",
"AUC for the best model: 0.8284550921657106\n",
"['Ridge ( lambda = 3.014E-5 )']\n"
]
}
],
"source": [
"hyper_parameters = {'alpha': np.arange(0, 1.05, 0.05).tolist()}\n",
"\n",
"gridsearch = H2OGridSearch(H2OGeneralizedLinearEstimator(family='binomial', lambda_search=True, standardize=True),\n",
" grid_id=\"gridresults\", hyper_params=hyper_parameters)\n",
"gridsearch.train(X, y, training_frame= training, validation_frame=validation)\n",
"\n",
"gridperf = gridsearch.get_grid(sort_by=\"auc\", decreasing=True)\n",
"best_model = gridperf.models[0]\n",
"print(gridperf)\n",
"\n",
"print('AUC for the best model: ', best_model.auc(valid=True))\n",
"print(best_model.summary()['regularization'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"H2O lets you tune the regularization strength. \n",
"To get a model with a higher AUC, try different scaling approaches first \n",
"and experiment with one more interesting parameter: interactions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gradient Boosting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### sklearn GradientBoostingClassifier\n"
]
},
{
"cell_type": "code",
"execution_count": 224,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import GradientBoostingClassifier"
]
},
{
"cell_type": "code",
"execution_count": 225,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"AUC for sklearn GradientBoostingClassifier: 0.9357\n"
]
}
],
"source": [
"grb = GradientBoostingClassifier().fit(X_train, y_train)\n",
"print('AUC for sklearn GradientBoostingClassifier: {:.4f}'.format(\n",
" roc_auc_score(y_test, grb.predict_proba(X_test)[:, 1])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Gradient Boosting Machine"
]
},
{
"cell_type": "code",
"execution_count": 226,
"metadata": {},
"outputs": [],
"source": [
"from h2o.estimators.gbm import H2OGradientBoostingEstimator"
]
},
{
"cell_type": "code",
"execution_count": 232,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gbm Model Build progress: |███████████████████████████████████████████████| 100%\n",
"\n",
"ModelMetricsBinomial: gbm\n",
"** Reported on train data. **\n",
"\n",
"MSE: 4.97643761962381e-09\n",
"RMSE: 7.05438701775272e-05\n",
"LogLoss: 2.387073099550759e-05\n",
"Mean Per-Class Error: 0.0\n",
"AUC: 1.0\n",
"Gini: 1.0\n",
"Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.9991932827883621: \n"
]
},
{
"data": {
"text/html": [
" \n",
"0 \n",
"1 \n",
"Error \n",
"Rate \n",
"0 \n",
"1993.0 \n",
"0.0 \n",
"0.0 \n",
" (0.0/1993.0) \n",
"1 \n",
"0.0 \n",
"340.0 \n",
"0.0 \n",
" (0.0/340.0) \n",
"Total \n",
"1993.0 \n",
"340.0 \n",
"0.0 \n",
" (0.0/2333.0)
"
],
"text/plain": [
" 0 1 Error Rate\n",
"----- ---- --- ------- ------------\n",
"0 1993 0 0 (0.0/1993.0)\n",
"1 0 340 0 (0.0/340.0)\n",
"Total 1993 340 0 (0.0/2333.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum Metrics: Maximum metrics at their respective thresholds\n",
"\n"
]
},
{
"data": {
"text/html": [
"metric \n",
"threshold \n",
"value \n",
"idx \n",
"max f1 \n",
"0.9991933 \n",
"1.0 \n",
"158.0 \n",
"max f2 \n",
"0.9991933 \n",
"1.0 \n",
"158.0 \n",
"max f0point5 \n",
"0.9991933 \n",
"1.0 \n",
"158.0 \n",
"max accuracy \n",
"0.9991933 \n",
"1.0 \n",
"158.0 \n",
"max precision \n",
"1.0000000 \n",
"1.0 \n",
"0.0 \n",
"max recall \n",
"0.9991933 \n",
"1.0 \n",
"158.0 \n",
"max specificity \n",
"1.0000000 \n",
"1.0 \n",
"0.0 \n",
"max absolute_mcc \n",
"0.9991933 \n",
"1.0 \n",
"158.0 \n",
"max min_per_class_accuracy \n",
"0.9991933 \n",
"1.0 \n",
"158.0 \n",
"max mean_per_class_accuracy \n",
"0.9991933 \n",
"1.0 \n",
"158.0
"
],
"text/plain": [
"metric threshold value idx\n",
"--------------------------- ----------- ------- -----\n",
"max f1 0.999193 1 158\n",
"max f2 0.999193 1 158\n",
"max f0point5 0.999193 1 158\n",
"max accuracy 0.999193 1 158\n",
"max precision 1 1 0\n",
"max recall 0.999193 1 158\n",
"max specificity 1 1 0\n",
"max absolute_mcc 0.999193 1 158\n",
"max min_per_class_accuracy 0.999193 1 158\n",
"max mean_per_class_accuracy 0.999193 1 158"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gains/Lift Table: Avg response rate: 14,57 %\n",
"\n"
]
},
{
"data": {
"text/html": [
" \n",
"group \n",
"cumulative_data_fraction \n",
"lower_threshold \n",
"lift \n",
"cumulative_lift \n",
"response_rate \n",
"cumulative_response_rate \n",
"capture_rate \n",
"cumulative_capture_rate \n",
"gain \n",
"cumulative_gain \n",
" \n",
"1 \n",
"0.0102872 \n",
"0.9999999 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0705882 \n",
"0.0705882 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"2 \n",
"0.0201457 \n",
"0.9999990 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.1382353 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"3 \n",
"0.0300043 \n",
"0.9999975 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.2058824 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"4 \n",
"0.0402915 \n",
"0.9999958 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0705882 \n",
"0.2764706 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"5 \n",
"0.0501500 \n",
"0.9999929 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.3441176 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"6 \n",
"0.1003000 \n",
"0.9999602 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.3441176 \n",
"0.6882353 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"7 \n",
"0.1500214 \n",
"0.0001398 \n",
"6.2702333 \n",
"6.6657143 \n",
"0.9137931 \n",
"0.9714286 \n",
"0.3117647 \n",
"1.0 \n",
"527.0233266 \n",
"566.5714286 \n",
" \n",
"8 \n",
"0.2001715 \n",
"0.0000579 \n",
"0.0 \n",
"4.9957173 \n",
"0.0 \n",
"0.7280514 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"399.5717345 \n",
" \n",
"9 \n",
"0.3000429 \n",
"0.0000241 \n",
"0.0 \n",
"3.3328571 \n",
"0.0 \n",
"0.4857143 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"233.2857143 \n",
" \n",
"10 \n",
"0.3999143 \n",
"0.0000117 \n",
"0.0 \n",
"2.5005359 \n",
"0.0 \n",
"0.3644159 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"150.0535906 \n",
" \n",
"11 \n",
"0.5002143 \n",
"0.0000058 \n",
"0.0 \n",
"1.9991431 \n",
"0.0 \n",
"0.2913453 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"99.9143102 \n",
" \n",
"12 \n",
"0.6000857 \n",
"0.0000030 \n",
"0.0 \n",
"1.6664286 \n",
"0.0 \n",
"0.2428571 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"66.6428571 \n",
" \n",
"13 \n",
"0.6999571 \n",
"0.0000013 \n",
"0.0 \n",
"1.4286589 \n",
"0.0 \n",
"0.2082058 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"42.8658910 \n",
" \n",
"14 \n",
"0.7998285 \n",
"0.0000006 \n",
"0.0 \n",
"1.2502680 \n",
"0.0 \n",
"0.1822079 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"25.0267953 \n",
" \n",
"15 \n",
"0.8997000 \n",
"0.0000002 \n",
"0.0 \n",
"1.1114817 \n",
"0.0 \n",
"0.1619819 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"11.1481658 \n",
" \n",
"16 \n",
"1.0 \n",
"0.0000000 \n",
"0.0 \n",
"1.0 \n",
"0.0 \n",
"0.1457351 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"0.0
"
],
"text/plain": [
" group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n",
"-- ------- -------------------------- ----------------- ------- ----------------- --------------- -------------------------- -------------- ------------------------- ------- -----------------\n",
" 1 0.0102872 1 6.86176 6.86176 1 1 0.0705882 0.0705882 586.176 586.176\n",
" 2 0.0201457 0.999999 6.86176 6.86176 1 1 0.0676471 0.138235 586.176 586.176\n",
" 3 0.0300043 0.999997 6.86176 6.86176 1 1 0.0676471 0.205882 586.176 586.176\n",
" 4 0.0402915 0.999996 6.86176 6.86176 1 1 0.0705882 0.276471 586.176 586.176\n",
" 5 0.05015 0.999993 6.86176 6.86176 1 1 0.0676471 0.344118 586.176 586.176\n",
" 6 0.1003 0.99996 6.86176 6.86176 1 1 0.344118 0.688235 586.176 586.176\n",
" 7 0.150021 0.000139845 6.27023 6.66571 0.913793 0.971429 0.311765 1 527.023 566.571\n",
" 8 0.200171 5.79219e-05 0 4.99572 0 0.728051 0 1 -100 399.572\n",
" 9 0.300043 2.41348e-05 0 3.33286 0 0.485714 0 1 -100 233.286\n",
" 10 0.399914 1.17004e-05 0 2.50054 0 0.364416 0 1 -100 150.054\n",
" 11 0.500214 5.77469e-06 0 1.99914 0 0.291345 0 1 -100 99.9143\n",
" 12 0.600086 2.98531e-06 0 1.66643 0 0.242857 0 1 -100 66.6429\n",
" 13 0.699957 1.30511e-06 0 1.42866 0 0.208206 0 1 -100 42.8659\n",
" 14 0.799829 6.00854e-07 0 1.25027 0 0.182208 0 1 -100 25.0268\n",
" 15 0.8997 1.69372e-07 0 1.11148 0 0.161982 0 1 -100 11.1482\n",
" 16 1 2.73715e-10 0 1 0 0.145735 0 1 -100 0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"gbm = H2OGradientBoostingEstimator(ntrees=800, learn_rate=0.1, seed=1234)\n",
"gbm.train(X, y, training_frame=training, validation_frame=validation)\n",
"gbm.model_performance().show()"
]
},
{
"cell_type": "code",
"execution_count": 237,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation AUC: 0.9146600190940914\n"
]
}
],
"source": [
"print('Validation AUC: ', gbm.auc(valid=True))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"___Parameter descriptions___ \n",
"only the GBM-specific ones that differ from RandomForest\n",
"1. GBM parameters\n",
" * __learn_rate:__ learning rate (from 0 to 1)\n",
" * __learn_rate_annealing:__ reduces the learning rate after each tree is built\n",
" \n",
"For tuning boosting parameters there is the following advice from the experts (Mark Landry, Dmitry Larko, and [github](https://github.com/h2oai/h2o-3/blob/master/h2o-docs/src/product/tutorials/gbm/gbmTuning.ipynb)): \n",
"* fix the number of trees to a constant and tune the learning rate (for example, learn_rate=0.02, learn_rate_annealing=0.995) \n",
"* at the very end you can come back to tuning the number of trees\n",
"* then tune the tree depth (max_depth), most often 4-10\n",
"* try changing the histogram type and nbins/nbins_cat\n",
"* adjust the subsampling settings (sample_rate, col_sample_rate), usually 70-80%\n",
"* for imbalanced datasets, tune the parameters responsible for class balance (sample_rate_per_class)\n",
"* define the stopping criteria "
]
},
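{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick calculation shows why the suggested pair learn_rate=0.02, learn_rate_annealing=0.995 works (assuming, as I read the docs, that the rate applied to tree $k$ is learn_rate multiplied by annealing to the power $k$): early trees take large steps while late trees only fine-tune:\n",
"\n",
"```python\n",
"def effective_learn_rate(learn_rate, annealing, k):\n",
"    # learning rate applied when building tree k (0-based)\n",
"    return learn_rate * annealing ** k\n",
"\n",
"effective_learn_rate(0.02, 0.995, 0)    # 0.02\n",
"effective_learn_rate(0.02, 0.995, 800)  # roughly 55x smaller\n",
"```"
]
},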
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start tuning"
]
},
{
"cell_type": "code",
"execution_count": 239,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gbm Grid Build progress: |████████████████████████████████████████████████| 100%\n",
" learn_rate model_ids auc\n",
"0 0.001 gbm_grid_learn_rate_model_1 0.9319711793457418\n",
"1 0.01 gbm_grid_learn_rate_model_2 0.9310858336529282\n",
"2 1.0 gbm_grid_learn_rate_model_4 0.9299924113226331\n",
"3 1.0E-4 gbm_grid_learn_rate_model_0 0.9189194702613606\n",
"4 0.1 gbm_grid_learn_rate_model_3 0.9146600190940914\n",
"\n",
"0.9319711793457418\n",
"Hyperparameters: [learn_rate]\n",
"[0.001]\n"
]
}
],
"source": [
"gbm_params = {'learn_rate': [0.0001, 0.001, 0.01, 0.1, 1]}\n",
"\n",
"gbm_grid = H2OGridSearch(model=H2OGradientBoostingEstimator(ntrees = 800, seed = 1234),\n",
" grid_id='gbm_grid_learn_rate',\n",
" hyper_params=gbm_params)\n",
"gbm_grid.train(X, y, training_frame=training, validation_frame=validation)\n",
"\n",
"# models sorted by AUC\n",
"gbm_gridperf = gbm_grid.get_grid(sort_by='auc', decreasing=True)\n",
"print(gbm_gridperf)\n",
"\n",
"# pick the best model and print its validation AUC \n",
"best_gbm = gbm_gridperf.models[0]\n",
"print(best_gbm.auc(valid=True))\n",
"print(gbm_gridperf.get_hyperparams(0))"
]
},
{
"cell_type": "code",
"execution_count": 240,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gbm Grid Build progress: |████████████████████████████████████████████████| 100%\n",
" max_depth model_ids auc\n",
"0 6 gbm_grid_max_depth_model_1 0.9400943280756583\n",
"1 8 gbm_grid_max_depth_model_2 0.9301229692128176\n",
"2 10 gbm_grid_max_depth_model_3 0.9272466156946904\n",
"3 4 gbm_grid_max_depth_model_0 0.9173772551835563\n",
"\n",
"0.9400943280756583\n",
"Hyperparameters: [max_depth]\n",
"[6]\n"
]
}
],
"source": [
"gbm_params1 = {'max_depth': [4, 6, 8, 10]}\n",
"\n",
"gbm_grid1 = H2OGridSearch(model=H2OGradientBoostingEstimator(ntrees = 800, seed = 1234, learn_rate=0.001),\n",
" grid_id='gbm_grid_max_depth',\n",
" hyper_params=gbm_params1)\n",
"gbm_grid1.train(X, y, training_frame=training, validation_frame=validation)\n",
"\n",
"# models sorted by AUC\n",
"gbm_gridperf1 = gbm_grid1.get_grid(sort_by='auc', decreasing=True)\n",
"print(gbm_gridperf1)\n",
"\n",
"# pick the best model and print its validation AUC \n",
"best_gbm1 = gbm_gridperf1.models[0]\n",
"print(best_gbm1.auc(valid=True))\n",
"print(gbm_gridperf1.get_hyperparams(0))"
]
},
{
"cell_type": "code",
"execution_count": 242,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gbm Grid Build progress: |████████████████████████████████████████████████| 100%\n",
" histogram_type nbins model_ids auc\n",
"0 RoundRobin 16 gbm_grid_hist_model_11 0.9450596078367374\n",
"1 RoundRobin 8 gbm_grid_hist_model_7 0.9447291331772079\n",
"2 RoundRobin 32 gbm_grid_hist_model_15 0.9445985752870233\n",
"3 UniformAdaptive 32 gbm_grid_hist_model_12 0.943439874011636\n",
"4 UniformAdaptive 8 gbm_grid_hist_model_4 0.9427993243629184\n",
"5 UniformAdaptive 16 gbm_grid_hist_model_8 0.9425912477254368\n",
"6 Random 16 gbm_grid_hist_model_9 0.9421424549779276\n",
"7 Random 32 gbm_grid_hist_model_13 0.9394864178994867\n",
"8 Random 8 gbm_grid_hist_model_5 0.9377973251952249\n",
"9 QuantilesGlobal 32 gbm_grid_hist_model_14 0.9370343775244592\n",
"10 QuantilesGlobal 16 gbm_grid_hist_model_10 0.9347373746440257\n",
"11 QuantilesGlobal 8 gbm_grid_hist_model_6 0.9315958254114614\n",
"12 RoundRobin 2 gbm_grid_hist_model_3 0.9300209708611109\n",
"13 Random 2 gbm_grid_hist_model_1 0.9293069823991644\n",
"14 UniformAdaptive 2 gbm_grid_hist_model_0 0.9261654331666\n",
"15 QuantilesGlobal 2 gbm_grid_hist_model_2 0.8943541872363342\n",
"\n",
"0.9450596078367374\n",
"Hyperparameters: [nbins, histogram_type]\n",
"[16, 'RoundRobin']\n"
]
}
],
"source": [
"gbm_params2 = {'nbins': [2, 8, 16, 32],\n",
" 'histogram_type': ['UniformAdaptive', 'Random', 'QuantilesGlobal', 'RoundRobin']}\n",
"\n",
"gbm_grid2 = H2OGridSearch(model=H2OGradientBoostingEstimator(ntrees = 800, seed = 1234, \n",
" max_depth=6, learn_rate=0.001),\n",
" grid_id='gbm_grid_hist',\n",
" hyper_params=gbm_params2)\n",
"gbm_grid2.train(X, y, training_frame=training, validation_frame=validation)\n",
"\n",
"# models sorted by AUC\n",
"gbm_gridperf2 = gbm_grid2.get_grid(sort_by='auc', decreasing=True)\n",
"print(gbm_gridperf2)\n",
"\n",
"# pick the best model and print its validation AUC \n",
"best_gbm2 = gbm_gridperf2.models[0]\n",
"print(best_gbm2.auc(valid=True))\n",
"print(gbm_gridperf2.get_hyperparams(0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the end, the validation AUC for GradientBoostingMachine (0.9450) came out slightly higher than for RandomForestClassifier (0.941686)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### XGBoost"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another algorithm from the Gradient Boosting Machine family available in H2O is XGBoost. This library is a very popular tool for Kaggle competitions. It builds trees in parallel, which addresses training-speed problems. \n",
"\n",
"__Model features__\n",
"\n",
"The implementation supports the features of the scikit-learn and R implementations, with newer additions such as regularization. Three main forms of gradient boosting are supported:\n",
"\n",
"* Gradient Boosting, also called the gradient boosting machine, including the learning rate.\n",
"* Stochastic Gradient Boosting with subsampling at the row, column, and column-per-split levels.\n",
"* Regularized Gradient Boosting with L1 and L2 regularization.\n",
"\n",
"__System features__\n",
"\n",
"The library provides a system for use in a range of computing environments, not least:\n",
"\n",
"* Parallelized tree construction using all of your CPU cores during training.\n",
"* Distributed computing for training very large models on a cluster of machines.\n",
"* Out-of-core computing for very large datasets that do not fit in memory.\n",
"* Cache-aware optimization of data structures and algorithms to make the best use of the hardware.\n",
"\n",
"__Algorithm features__\n",
"\n",
"The implementation was engineered for efficient use of compute time and memory. The design goal was to make the best use of available resources to train the model. Some key features of the implementation include:\n",
"\n",
"* Sparsity-aware implementation with automatic handling of missing data values.\n",
"* Block structure to support parallelized tree construction.\n",
"* Continued training, so you can further boost an already fitted model on new data. \n",
"\n",
"Under the hood H2O uses the native [XGBoost](http://xgboost.readthedocs.io/en/latest/get_started/index.html) algorithm, so I will not try to tune its parameters here, only show how to run it"
]
},
{
"cell_type": "code",
"execution_count": 246,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"from h2o.estimators import H2OXGBoostEstimator\n",
"# the library is not available on all platforms, so check availability first\n",
"is_xgboost_available = H2OXGBoostEstimator.available()\n",
"print(is_xgboost_available)"
]
},
{
"cell_type": "code",
"execution_count": 249,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"xgboost Model Build progress: |███████████████████████████████████████████| 100%\n",
"\n",
"ModelMetricsBinomial: xgboost\n",
"** Reported on train data. **\n",
"\n",
"MSE: 0.04233021547018061\n",
"RMSE: 0.2057430812206831\n",
"LogLoss: 0.1893648656423194\n",
"Mean Per-Class Error: 0.0834117942209498\n",
"AUC: 0.9795984475074526\n",
"Gini: 0.9591968950149052\n",
"Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.38141971826553345: \n"
]
},
{
"data": {
"text/html": [
" \n",
"0 \n",
"1 \n",
"Error \n",
"Rate \n",
"0 \n",
"1960.0 \n",
"33.0 \n",
"0.0166 \n",
" (33.0/1993.0) \n",
"1 \n",
"65.0 \n",
"275.0 \n",
"0.1912 \n",
" (65.0/340.0) \n",
"Total \n",
"2025.0 \n",
"308.0 \n",
"0.042 \n",
" (98.0/2333.0)
"
],
"text/plain": [
" 0 1 Error Rate\n",
"----- ---- --- ------- -------------\n",
"0 1960 33 0.0166 (33.0/1993.0)\n",
"1 65 275 0.1912 (65.0/340.0)\n",
"Total 2025 308 0.042 (98.0/2333.0)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum Metrics: Maximum metrics at their respective thresholds\n",
"\n"
]
},
{
"data": {
"text/html": [
"metric \n",
"threshold \n",
"value \n",
"idx \n",
"max f1 \n",
"0.3814197 \n",
"0.8487654 \n",
"163.0 \n",
"max f2 \n",
"0.1430533 \n",
"0.8405723 \n",
"274.0 \n",
"max f0point5 \n",
"0.6040332 \n",
"0.9194529 \n",
"118.0 \n",
"max accuracy \n",
"0.4644568 \n",
"0.9597085 \n",
"143.0 \n",
"max precision \n",
"0.9170414 \n",
"1.0 \n",
"0.0 \n",
"max recall \n",
"0.1096087 \n",
"1.0 \n",
"316.0 \n",
"max specificity \n",
"0.9170414 \n",
"1.0 \n",
"0.0 \n",
"max absolute_mcc \n",
"0.4626352 \n",
"0.8304986 \n",
"144.0 \n",
"max min_per_class_accuracy \n",
"0.1715132 \n",
"0.8966382 \n",
"251.0 \n",
"max mean_per_class_accuracy \n",
"0.1430533 \n",
"0.9165882 \n",
"274.0
"
],
"text/plain": [
"metric threshold value idx\n",
"--------------------------- ----------- -------- -----\n",
"max f1 0.38142 0.848765 163\n",
"max f2 0.143053 0.840572 274\n",
"max f0point5 0.604033 0.919453 118\n",
"max accuracy 0.464457 0.959709 143\n",
"max precision 0.917041 1 0\n",
"max recall 0.109609 1 316\n",
"max specificity 0.917041 1 0\n",
"max absolute_mcc 0.462635 0.830499 144\n",
"max min_per_class_accuracy 0.171513 0.896638 251\n",
"max mean_per_class_accuracy 0.143053 0.916588 274"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gains/Lift Table: Avg response rate: 14,57 %\n",
"\n"
]
},
{
"data": {
"text/html": [
" \n",
"group \n",
"cumulative_data_fraction \n",
"lower_threshold \n",
"lift \n",
"cumulative_lift \n",
"response_rate \n",
"cumulative_response_rate \n",
"capture_rate \n",
"cumulative_capture_rate \n",
"gain \n",
"cumulative_gain \n",
" \n",
"1 \n",
"0.0102872 \n",
"0.8952582 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0705882 \n",
"0.0705882 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"2 \n",
"0.0201457 \n",
"0.8631429 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.1382353 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"3 \n",
"0.0300043 \n",
"0.8510975 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.2058824 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"4 \n",
"0.0402915 \n",
"0.8373678 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0705882 \n",
"0.2764706 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"5 \n",
"0.0501500 \n",
"0.8046734 \n",
"6.8617647 \n",
"6.8617647 \n",
"1.0 \n",
"1.0 \n",
"0.0676471 \n",
"0.3441176 \n",
"586.1764706 \n",
"586.1764706 \n",
" \n",
"6 \n",
"0.1003000 \n",
"0.6370200 \n",
"6.8031171 \n",
"6.8324409 \n",
"0.9914530 \n",
"0.9957265 \n",
"0.3411765 \n",
"0.6852941 \n",
"580.3117144 \n",
"583.2440925 \n",
" \n",
"7 \n",
"0.1500214 \n",
"0.2845466 \n",
"3.0759635 \n",
"5.5874370 \n",
"0.4482759 \n",
"0.8142857 \n",
"0.1529412 \n",
"0.8382353 \n",
"207.5963489 \n",
"458.7436975 \n",
" \n",
"8 \n",
"0.2001715 \n",
"0.1906352 \n",
"0.5864756 \n",
"4.3345195 \n",
"0.0854701 \n",
"0.6316916 \n",
"0.0294118 \n",
"0.8676471 \n",
"-41.3524384 \n",
"333.4519461 \n",
" \n",
"9 \n",
"0.3000429 \n",
"0.1235101 \n",
"1.1779854 \n",
"3.2838445 \n",
"0.1716738 \n",
"0.4785714 \n",
"0.1176471 \n",
"0.9852941 \n",
"17.7985357 \n",
"228.3844538 \n",
" \n",
"10 \n",
"0.3999143 \n",
"0.1064968 \n",
"0.1472482 \n",
"2.5005359 \n",
"0.0214592 \n",
"0.3644159 \n",
"0.0147059 \n",
"1.0 \n",
"-85.2751830 \n",
"150.0535906 \n",
" \n",
"11 \n",
"0.5002143 \n",
"0.0975546 \n",
"0.0 \n",
"1.9991431 \n",
"0.0 \n",
"0.2913453 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"99.9143102 \n",
" \n",
"12 \n",
"0.6000857 \n",
"0.0918106 \n",
"0.0 \n",
"1.6664286 \n",
"0.0 \n",
"0.2428571 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"66.6428571 \n",
" \n",
"13 \n",
"0.6999571 \n",
"0.0864580 \n",
"0.0 \n",
"1.4286589 \n",
"0.0 \n",
"0.2082058 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"42.8658910 \n",
" \n",
"14 \n",
"0.7998285 \n",
"0.0819798 \n",
"0.0 \n",
"1.2502680 \n",
"0.0 \n",
"0.1822079 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"25.0267953 \n",
" \n",
"15 \n",
"0.8997000 \n",
"0.0772951 \n",
"0.0 \n",
"1.1114817 \n",
"0.0 \n",
"0.1619819 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"11.1481658 \n",
" \n",
"16 \n",
"1.0 \n",
"0.0684245 \n",
"0.0 \n",
"1.0 \n",
"0.0 \n",
"0.1457351 \n",
"0.0 \n",
"1.0 \n",
"-100.0 \n",
"0.0
"
],
"text/plain": [
" group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n",
"-- ------- -------------------------- ----------------- -------- ----------------- --------------- -------------------------- -------------- ------------------------- -------- -----------------\n",
" 1 0.0102872 0.895258 6.86176 6.86176 1 1 0.0705882 0.0705882 586.176 586.176\n",
" 2 0.0201457 0.863143 6.86176 6.86176 1 1 0.0676471 0.138235 586.176 586.176\n",
" 3 0.0300043 0.851098 6.86176 6.86176 1 1 0.0676471 0.205882 586.176 586.176\n",
" 4 0.0402915 0.837368 6.86176 6.86176 1 1 0.0705882 0.276471 586.176 586.176\n",
" 5 0.05015 0.804673 6.86176 6.86176 1 1 0.0676471 0.344118 586.176 586.176\n",
" 6 0.1003 0.63702 6.80312 6.83244 0.991453 0.995726 0.341176 0.685294 580.312 583.244\n",
" 7 0.150021 0.284547 3.07596 5.58744 0.448276 0.814286 0.152941 0.838235 207.596 458.744\n",
" 8 0.200171 0.190635 0.586476 4.33452 0.0854701 0.631692 0.0294118 0.867647 -41.3524 333.452\n",
" 9 0.300043 0.12351 1.17799 3.28384 0.171674 0.478571 0.117647 0.985294 17.7985 228.384\n",
" 10 0.399914 0.106497 0.147248 2.50054 0.0214592 0.364416 0.0147059 1 -85.2752 150.054\n",
" 11 0.500214 0.0975546 0 1.99914 0 0.291345 0 1 -100 99.9143\n",
" 12 0.600086 0.0918106 0 1.66643 0 0.242857 0 1 -100 66.6429\n",
" 13 0.699957 0.086458 0 1.42866 0 0.208206 0 1 -100 42.8659\n",
" 14 0.799829 0.0819798 0 1.25027 0 0.182208 0 1 -100 25.0268\n",
" 15 0.8997 0.0772951 0 1.11148 0 0.161982 0 1 -100 11.1482\n",
" 16 1 0.0684245 0 1 0 0.145735 0 1 -100 0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Validation AUC for GradientBoostingMachine: 0.9146600190940914\n",
"Validation AUC for XGBoost: 0.9335786733686384\n"
]
}
],
"source": [
"param = {\n",
" \"ntrees\" : 100\n",
" , \"max_depth\" : 10\n",
" , \"learn_rate\" : 0.02\n",
" , \"sample_rate\" : 0.7\n",
" , \"col_sample_rate_per_tree\" : 0.9\n",
" , \"min_rows\" : 5\n",
" , \"seed\": 4241\n",
" , \"score_tree_interval\": 100\n",
"}\n",
"\n",
"model = H2OXGBoostEstimator(**param)\n",
"model.train(X, y, training_frame=training, validation_frame=validation)\n",
"model.model_performance().show()\n",
"\n",
"print('Validation AUC for GradientBoostingMachine: ', gbm.auc(valid=True))\n",
"print('Validation AUC for XGBoost: ', model.auc(valid=True))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### LightGBM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[LightGBM](https://github.com/Microsoft/LightGBM) grows deep asymmetric trees by repeatedly splitting a single leaf, instead of splitting every leaf of a level, until the maximum depth is reached. \n",
"XGBoost uses pre-sorting and histograms to find the best split, i.e.\n",
"- for each node, enumerate all features; \n",
"- for each feature, sort all values (optionally binning them and sorting the bins); \n",
"- find the best split for that feature; \n",
"- pick the best split across all features.\n",
"\n",
"LightGBM takes a different approach: the gradient is the slope of the loss function, so points with larger gradients matter more for finding the optimal split. The algorithm keeps all points with the largest gradients and draws a random sample from the points with small gradients. \n",
"Suppose there are 500K rows of data, 10K of which have large gradients; the algorithm then keeps those 10K rows plus x% of the remaining 490K rows, chosen at random. With x = 10%, the total number of selected rows is 59K out of 500K. \n",
"The key assumption here is that rows with small gradients have small training error and are already well learned by the model. This reduces the amount of training data while preserving the accuracy of the trees that are already trained.\n",
"\n",
"----------------\n",
"H2O is not integrated with LightGBM, but it provides a way to emulate the LightGBM algorithm using a specific set of parameters:\n",
"\n",
"\n",
"tree_method=\"hist\"\n",
"grow_policy=\"lossguide\""
]
},
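{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sampling scheme described above is called GOSS (Gradient-based One-Side Sampling) in the LightGBM paper. Here is a numpy sketch of the idea; the function and parameter names are mine, and the reweighting of the small-gradient rows follows the paper:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def goss_select(gradients, top_frac=0.02, rest_frac=0.1, seed=0):\n",
"    # keep the top_frac rows with the largest |gradient| and add a random\n",
"    # rest_frac sample of the rest, reweighted to keep the split unbiased\n",
"    rng = np.random.default_rng(seed)\n",
"    n = len(gradients)\n",
"    n_top = int(n * top_frac)\n",
"    order = np.argsort(-np.abs(gradients))\n",
"    top = order[:n_top]\n",
"    rest = rng.choice(order[n_top:], size=int((n - n_top) * rest_frac), replace=False)\n",
"    weights = np.ones(n)\n",
"    weights[rest] = (1 - top_frac) / rest_frac  # amplify small-gradient rows\n",
"    return np.concatenate([top, rest]), weights\n",
"\n",
"grads = np.random.default_rng(0).normal(size=500_000)\n",
"idx, w = goss_select(grads)  # 10_000 + 49_000 = 59_000 rows, as in the text\n",
"```"
]
},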
{
"cell_type": "code",
"execution_count": 150,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"H2O session _sid_addf closed.\n"
]
}
],
"source": [
"h2o.cluster().shutdown()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__References__\n",
"\n",
"* https://www.h2o.ai/ \n",
"* http://statistica.ru/local-portals/actuaries/obobshchennye-lineynye-modeli-glm/ \n",
"* https://en.wikipedia.org/wiki/Generalized_linear_model \n",
"* https://alexanderdyakonov.files.wordpress.com/2017/06/book_boosting_pdf.pdf \n",
"* https://github.com/h2oai/h2o-3/blob/master/h2o-docs/src/product/tutorials/gbm/gbmTuning.ipynb \n",
"* https://ru.bmstu.wiki/XGBoost \n",
"* Artem Gruzdev. Predictive Modeling in IBM SPSS Statistics, R and Python; lectures\n",
"* https://towardsdatascience.com/catboost-vs-light-gbm-vs-xgboost-5f93620723db"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}