{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "## Открытый курс по машинному обучению\n", "
Автор материала: Екатерина Ширяева slackname: Katya.Shiryaeva" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Есть такая либа H2O" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Согласно исследованию [Gartner в феврале 2018](https://www.gartner.com/doc/reprints?id=1-4RQ3VEZ&ct=180223&st=sb), H2O уверенно входит в число лидеров рынка платформ для Data Science и Machine Learning.\n", "Gartner считает H2O.ai технологическим лидером: платформа используется более чем 100000 data scientist'ами, а удовлетворённость клиентов (поддержка, обучение и продажи) самая высокая. \n", "В этом обзоре я хочу показать отличия её реализаций алгоритмов от привычного sklearn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"Drawing\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Подготовка данных" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Для примера можно взять датасет из [1ой лекции: отток клиентов телекома](https://habrahabr.ru/company/ods/blog/322626/) " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv('data/telecom_churn.csv')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 3333 entries, 0 to 3332\n", "Data columns (total 20 columns):\n", "State 3333 non-null object\n", "Account length 3333 non-null int64\n", "Area code 3333 non-null int64\n", "International plan 3333 non-null object\n", "Voice mail plan 3333 non-null object\n", "Number vmail messages 3333 non-null int64\n", "Total day minutes 3333 non-null float64\n", "Total day calls 3333 non-null int64\n", "Total day charge 3333 non-null float64\n", "Total eve minutes 3333 non-null float64\n", "Total eve calls 3333 non-null int64\n", "Total eve charge 3333 non-null float64\n", "Total night minutes 3333 non-null float64\n", "Total night calls 3333 non-null int64\n", "Total night charge 3333 non-null float64\n", "Total intl minutes 3333 non-null float64\n", "Total intl calls 3333 non-null int64\n", "Total intl charge 3333 non-null float64\n", "Customer service calls 3333 non-null int64\n", "Churn 3333 non-null bool\n", "dtypes: bool(1), float64(8), int64(8), object(3)\n", "memory usage: 498.1+ KB\n" ] } ], "source": [ "df.info()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Описание всех признаков можно посмотреть в [1ой лекции](https://habrahabr.ru/company/ods/blog/322626/) \n", "В качестве предварительной обработки заменю все значения Churn на 0/1, \n", "для International plan и Voice mail plan сделаю замену yes / no на 0 / 1, \n", "а категориальную переменную State пока удалю из обработки (данных недостаточно, чтобы сделать OHE, а укрупнение категорий не входит в цели этого тьюториала)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Account lengthArea codeInternational planVoice mail planNumber vmail messagesTotal day minutesTotal day callsTotal day chargeTotal eve minutesTotal eve callsTotal eve chargeTotal night minutesTotal night callsTotal night chargeTotal intl minutesTotal intl callsTotal intl chargeCustomer service callsChurn
01284150125265.111045.07197.49916.78244.79111.0110.032.7010
11074150126161.612327.47195.510316.62254.410311.4513.733.7010
2137415000243.411441.38121.211010.30162.61047.3212.253.2900
384408100299.47150.9061.9885.26196.9898.866.671.7820
475415100166.711328.34148.312212.61186.91218.4110.132.7330
\n", "
" ], "text/plain": [ " Account length Area code International plan Voice mail plan \\\n", "0 128 415 0 1 \n", "1 107 415 0 1 \n", "2 137 415 0 0 \n", "3 84 408 1 0 \n", "4 75 415 1 0 \n", "\n", " Number vmail messages Total day minutes Total day calls \\\n", "0 25 265.1 110 \n", "1 26 161.6 123 \n", "2 0 243.4 114 \n", "3 0 299.4 71 \n", "4 0 166.7 113 \n", "\n", " Total day charge Total eve minutes Total eve calls Total eve charge \\\n", "0 45.07 197.4 99 16.78 \n", "1 27.47 195.5 103 16.62 \n", "2 41.38 121.2 110 10.30 \n", "3 50.90 61.9 88 5.26 \n", "4 28.34 148.3 122 12.61 \n", "\n", " Total night minutes Total night calls Total night charge \\\n", "0 244.7 91 11.01 \n", "1 254.4 103 11.45 \n", "2 162.6 104 7.32 \n", "3 196.9 89 8.86 \n", "4 186.9 121 8.41 \n", "\n", " Total intl minutes Total intl calls Total intl charge \\\n", "0 10.0 3 2.70 \n", "1 13.7 3 3.70 \n", "2 12.2 5 3.29 \n", "3 6.6 7 1.78 \n", "4 10.1 3 2.73 \n", "\n", " Customer service calls Churn \n", "0 1 0 \n", "1 1 0 \n", "2 0 0 \n", "3 2 0 \n", "4 3 0 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d = {'No' : 0, 'Yes' : 1}\n", "df['Churn'] = df['Churn'].apply(lambda x : int(x))\n", "df['International plan'] = df['International plan'].map(d)\n", "df['Voice mail plan'] = df['Voice mail plan'].map(d)\n", "df.drop('State', axis=1, inplace=True)\n", "\n", "df.head() " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Разделим на тест и обучающую выборки" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(df.iloc[:,:-1], df.iloc[:,-1], test_size=0.3, random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Для работы h2o необходимо установить эту библиотеку и запустить" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install h2o" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checking whether there is an H2O instance running at http://localhost:54321. connected.\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
H2O cluster uptime:6 days 0 hours 34 mins
H2O cluster timezone:Europe/Moscow
H2O data parsing timezone:UTC
H2O cluster version:3.18.0.4
H2O cluster version age:1 month and 12 days
H2O cluster name:H2O_from_python_katya_lpda3c
H2O cluster total nodes:1
H2O cluster free memory:6.316 Gb
H2O cluster total cores:8
H2O cluster allowed cores:8
H2O cluster status:locked, healthy
H2O connection url:http://localhost:54321
H2O connection proxy:None
H2O internal security:False
H2O API Extensions:XGBoost, Algos, AutoML, Core V3, Core V4
Python version:3.6.4 final
" ], "text/plain": [ "-------------------------- ----------------------------------------\n", "H2O cluster uptime: 6 days 0 hours 34 mins\n", "H2O cluster timezone: Europe/Moscow\n", "H2O data parsing timezone: UTC\n", "H2O cluster version: 3.18.0.4\n", "H2O cluster version age: 1 month and 12 days\n", "H2O cluster name: H2O_from_python_katya_lpda3c\n", "H2O cluster total nodes: 1\n", "H2O cluster free memory: 6.316 Gb\n", "H2O cluster total cores: 8\n", "H2O cluster allowed cores: 8\n", "H2O cluster status: locked, healthy\n", "H2O connection url: http://localhost:54321\n", "H2O connection proxy:\n", "H2O internal security: False\n", "H2O API Extensions: XGBoost, Algos, AutoML, Core V3, Core V4\n", "Python version: 3.6.4 final\n", "-------------------------- ----------------------------------------" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import h2o\n", "import os\n", "h2o.init(nthreads=-1, max_mem_size=8)\n", "# nthreads - количество ядер процессора для вычислений\n", "# max_mem_size - максимальный размер оперативной памяти " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Все данные надо перевести в специальную структуру h2o : H2OFrame \n", "h2o [поддерживает](http://docs.h2o.ai/h2o/latest-stable/h2o-docs/getting-data-into-h2o.html) большое количество источников, однако у меня не получилось перекодировать csr-матрицу из задания про Элис :) очень долго висел :)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parse progress: |█████████████████████████████████████████████████████████| 100%\n", "Parse progress: |█████████████████████████████████████████████████████████| 100%\n" ] } ], "source": [ "training = h2o.H2OFrame(pd.concat([X_train, y_train], axis=1))\n", "validation = h2o.H2OFrame(pd.concat([X_test, y_test], axis=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Посмотрим на структуру (это обязательно надо делать, т.к. не всегда корректно происходит переход форматов)
\n", "Здесь будут описаны основные параметры каждой переменной (тип, максимум, минимум, среднее, стандартное отклонение, количество нулевых и пропущенных значений и первые 10 наблюдений)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Rows:2333\n", "Cols:19\n", "\n", "\n" ] }, { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Account length Area code International plan Voice mail plan Number vmail messages Total day minutes Total day calls Total day charge Total eve minutes Total eve calls Total eve charge Total night minutes Total night calls Total night charge Total intl minutes Total intl calls Total intl charge Customer service calls Churn
type int int int int int real int real real int real real int real real int real int int
mins 1.0 408.0 0.0 0.0 0.0 2.6 30.0 0.44 0.0 0.0 0.0 23.2 33.0 1.04 0.0 0.0 0.0 0.0 0.0
mean 100.37848264037734436.719245606515130.09515645092156022 0.27261037291041588.032576082297462 180.0195027861121 100.6228032576083530.60383197599656 200.95752250321465 100.0540077153877817.081633090441525200.67038148306924 99.94813544792119 9.030210030004287 10.242777539648538 4.444492070295745 2.7660522931847398 1.568795542220319 0.14573510501500214
maxs 232.0 510.0 1.0 1.0 51.0 346.8 165.0 58.96 363.7 170.0 30.91 395.0 175.0 17.77 20.0 18.0 5.4 9.0 1.0
sigma 39.81513240422018642.11342758508376 0.2934938203721862 0.445397563089707813.722524774971957 54.503148533784056 19.89235683817998 9.265512312565042 50.771196810709434 20.0818564484648834.315580894314363 50.935130537595086 19.586623410722094 2.292113726067487 2.791145550244814 2.45159500388630970.7536382712677757 1.3337241215350106 0.35291609524195916
zeros 0 0 2111 1697 1697 0 0 0 1 1 1 0 0 0 13 13 13 493 1993
missing0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 80.0 510.0 0.0 0.0 0.0 202.4 118.0 34.41 260.2 67.0 22.12 177.4 112.0 7.98 9.2 5.0 2.48 3.0 0.0
1 63.0 510.0 0.0 0.0 0.0 132.9 122.0 22.59 67.0 62.0 5.7 160.4 121.0 7.22 9.9 2.0 2.67 3.0 0.0
2 116.0 510.0 0.0 1.0 12.0 221.0 108.0 37.57 151.0 118.0 12.84 179.0 80.0 8.06 9.0 6.0 2.43 2.0 0.0
3 71.0 415.0 0.0 0.0 0.0 278.9 110.0 47.41 190.2 67.0 16.17 255.2 84.0 11.48 11.7 7.0 3.16 0.0 1.0
4 120.0 510.0 0.0 1.0 43.0 177.9 117.0 30.24 175.1 70.0 14.88 161.3 117.0 7.26 11.5 4.0 3.11 1.0 0.0
5 132.0 510.0 0.0 0.0 0.0 181.1 121.0 30.79 314.4 109.0 26.72 246.7 81.0 11.1 4.2 9.0 1.13 2.0 0.0
6 105.0 415.0 0.0 0.0 0.0 156.5 102.0 26.61 140.2 134.0 11.92 227.4 111.0 10.23 12.2 2.0 3.29 2.0 0.0
7 117.0 510.0 1.0 0.0 0.0 198.4 121.0 33.73 249.5 104.0 21.21 162.8 115.0 7.33 10.5 5.0 2.84 1.0 0.0
8 64.0 510.0 0.0 0.0 0.0 216.9 78.0 36.87 211.0 115.0 17.94 179.8 116.0 8.09 11.4 5.0 3.08 3.0 0.0
9 143.0 510.0 0.0 1.0 33.0 141.4 130.0 24.04 186.4 114.0 15.84 210.0 111.0 9.45 7.7 6.0 2.08 1.0 0.0
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "training.describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Зависимая переменная для бинарной классификации должна быть не количественной переменной, а категориальной, преобразуем с помощью метода asfactor()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "training['Churn'] = training['Churn'].asfactor()\n", "validation['Churn'] = validation['Churn'].asfactor()\n", "\n", "training['International plan'] = training['International plan'].asfactor()\n", "training['Voice mail plan'] = training['Voice mail plan'].asfactor()\n", "\n", "validation['International plan'] = validation['International plan'].asfactor()\n", "validation['Voice mail plan'] = validation['Voice mail plan'].asfactor()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Random Forest" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### sklearn RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import roc_auc_score" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC для sklearn RandomForestClassifier: 0.9404\n" ] } ], "source": [ "forest = RandomForestClassifier(n_estimators=800, random_state=152, n_jobs=-1)\n", "forest.fit(X_train, y_train)\n", "print('AUC для sklearn RandomForestClassifier: {:.4f}'.format(roc_auc_score(y_test, forest.predict_proba(X_test)[:, 1])))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### H2ORandomForestEstimator" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from h2o.estimators import H2ORandomForestEstimator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "необходимо задать список зависимых переменных и предикторов (это будут X и y в коде)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "drf Model Build progress: |███████████████████████████████████████████████| 100%\n", "Model Details\n", "=============\n", "H2ORandomForestEstimator : Distributed Random Forest\n", "Model Key: tutorial1\n", "\n", "\n", "ModelMetricsBinomial: drf\n", "** Reported on train data. **\n", "\n", "MSE: 0.04684101277272021\n", "RMSE: 0.21642784657414169\n", "LogLoss: 0.20074033630931348\n", "Mean Per-Class Error: 0.10923157521914939\n", "AUC: 0.8994753696762197\n", "Gini: 0.7989507393524393\n", "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.4256055363321799: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
01ErrorRate
01972.021.00.0105 (21.0/1993.0)
177.0263.00.2265 (77.0/340.0)
Total2049.0284.00.042 (98.0/2333.0)
" ], "text/plain": [ " 0 1 Error Rate\n", "----- ---- --- ------- -------------\n", "0 1972 21 0.0105 (21.0/1993.0)\n", "1 77 263 0.2265 (77.0/340.0)\n", "Total 2049 284 0.042 (98.0/2333.0)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Maximum Metrics: Maximum metrics at their respective thresholds\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
metricthresholdvalueidx
max f10.42560550.8429487146.0
max f20.23979010.8105802204.0
max f0point50.47465980.8958924137.0
max accuracy0.42560550.9579940146.0
max precision0.97905751.00.0
max recall0.00345021.0397.0
max specificity0.97905751.00.0
max absolute_mcc0.42560550.8233476146.0
max min_per_class_accuracy0.16641490.8470588235.0
max mean_per_class_accuracy0.23979010.8907684204.0
" ], "text/plain": [ "metric threshold value idx\n", "--------------------------- ----------- -------- -----\n", "max f1 0.425606 0.842949 146\n", "max f2 0.23979 0.81058 204\n", "max f0point5 0.47466 0.895892 137\n", "max accuracy 0.425606 0.957994 146\n", "max precision 0.979058 1 0\n", "max recall 0.00345018 1 397\n", "max specificity 0.979058 1 0\n", "max absolute_mcc 0.425606 0.823348 146\n", "max min_per_class_accuracy 0.166415 0.847059 235\n", "max mean_per_class_accuracy 0.23979 0.890768 204" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Gains/Lift Table: Avg response rate: 14,57 %\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
groupcumulative_data_fractionlower_thresholdliftcumulative_liftresponse_ratecumulative_response_ratecapture_ratecumulative_capture_rategaincumulative_gain
10.01028720.92933756.86176476.86176471.01.00.07058820.0705882586.1764706586.1764706
20.02014570.88407876.86176476.86176471.01.00.06764710.1382353586.1764706586.1764706
30.03000430.84296126.86176476.86176471.01.00.06764710.2058824586.1764706586.1764706
40.04029150.81141066.86176476.86176471.01.00.07058820.2764706586.1764706586.1764706
50.05015000.78141836.86176476.86176471.01.00.06764710.3441176586.1764706586.1764706
60.10030000.56656466.39258426.62717450.93162390.96581200.32058820.6647059539.2584213562.7174460
70.15002140.28659512.95765725.41099160.43103450.78857140.14705880.8117647195.7657201441.0991597
80.20017150.18053240.64512324.21697320.09401710.61456100.03235290.8441176-35.4876823321.6973170
90.30004290.10579730.14724822.86233610.02145920.41714290.01470590.8588235-85.2751830186.2336134
100.39991430.07627910.08834892.16958260.01287550.31618440.00882350.8676471-91.1651098116.9582624
110.50021430.05595670.11729511.75807000.01709400.25621250.01176470.8794118-88.270487775.8069963
120.60008570.04246440.23559711.50468700.03433480.21928570.02352940.9029412-76.440292950.4686975
130.69995710.03157230.23559711.32361050.03433480.19289650.02352940.9264706-76.440292932.3610461
140.79982850.02295900.20614741.18407730.03004290.17256160.02058820.9470588-79.385256218.4077297
150.89970000.01364150.26504671.08206010.03862660.15769410.02647060.9735294-73.49532958.2060085
161.00.00.26391401.00.03846150.14573510.02647061.0-73.60859730.0
" ], "text/plain": [ " group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n", "-- ------- -------------------------- ----------------- --------- ----------------- --------------- -------------------------- -------------- ------------------------- -------- -----------------\n", " 1 0.0102872 0.929337 6.86176 6.86176 1 1 0.0705882 0.0705882 586.176 586.176\n", " 2 0.0201457 0.884079 6.86176 6.86176 1 1 0.0676471 0.138235 586.176 586.176\n", " 3 0.0300043 0.842961 6.86176 6.86176 1 1 0.0676471 0.205882 586.176 586.176\n", " 4 0.0402915 0.811411 6.86176 6.86176 1 1 0.0705882 0.276471 586.176 586.176\n", " 5 0.05015 0.781418 6.86176 6.86176 1 1 0.0676471 0.344118 586.176 586.176\n", " 6 0.1003 0.566565 6.39258 6.62717 0.931624 0.965812 0.320588 0.664706 539.258 562.717\n", " 7 0.150021 0.286595 2.95766 5.41099 0.431034 0.788571 0.147059 0.811765 195.766 441.099\n", " 8 0.200171 0.180532 0.645123 4.21697 0.0940171 0.614561 0.0323529 0.844118 -35.4877 321.697\n", " 9 0.300043 0.105797 0.147248 2.86234 0.0214592 0.417143 0.0147059 0.858824 -85.2752 186.234\n", " 10 0.399914 0.0762791 0.0883489 2.16958 0.0128755 0.316184 0.00882353 0.867647 -91.1651 116.958\n", " 11 0.500214 0.0559567 0.117295 1.75807 0.017094 0.256213 0.0117647 0.879412 -88.2705 75.807\n", " 12 0.600086 0.0424644 0.235597 1.50469 0.0343348 0.219286 0.0235294 0.902941 -76.4403 50.4687\n", " 13 0.699957 0.0315723 0.235597 1.32361 0.0343348 0.192897 0.0235294 0.926471 -76.4403 32.361\n", " 14 0.799829 0.022959 0.206147 1.18408 0.0300429 0.172562 0.0205882 0.947059 -79.3853 18.4077\n", " 15 0.8997 0.0136415 0.265047 1.08206 0.0386266 0.157694 0.0264706 0.973529 -73.4953 8.20601\n", " 16 1 0 0.263914 1 0.0384615 0.145735 0.0264706 1 -73.6086 0" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "ModelMetricsBinomial: drf\n", "** Reported on validation data. **\n", "\n", "MSE: 0.04181281142387741\n", "RMSE: 0.20448181196350304\n", "LogLoss: 0.17373883641178992\n", "Mean Per-Class Error: 0.07810625780287395\n", "AUC: 0.9416814224282135\n", "Gini: 0.883362844856427\n", "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.29251602564007045: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
01ErrorRate
0838.019.00.0222 (19.0/857.0)
121.0122.00.1469 (21.0/143.0)
Total859.0141.00.04 (40.0/1000.0)
" ], "text/plain": [ " 0 1 Error Rate\n", "----- --- --- ------- -------------\n", "0 838 19 0.0222 (19.0/857.0)\n", "1 21 122 0.1469 (21.0/143.0)\n", "Total 859 141 0.04 (40.0/1000.0)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Maximum Metrics: Maximum metrics at their respective thresholds\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
metricthresholdvalueidx
max f10.29251600.8591549120.0
max f20.26359380.8644537129.0
max f0point50.583750.895953881.0
max accuracy0.38041670.961108.0
max precision0.98251.00.0
max recall0.01380661.0381.0
max specificity0.98251.00.0
max absolute_mcc0.29251600.8358744120.0
max min_per_class_accuracy0.18250.9090909164.0
max mean_per_class_accuracy0.26359380.9218937129.0
" ], "text/plain": [ "metric threshold value idx\n", "--------------------------- ----------- -------- -----\n", "max f1 0.292516 0.859155 120\n", "max f2 0.263594 0.864454 129\n", "max f0point5 0.58375 0.895954 81\n", "max accuracy 0.380417 0.961 108\n", "max precision 0.9825 1 0\n", "max recall 0.0138066 1 381\n", "max specificity 0.9825 1 0\n", "max absolute_mcc 0.292516 0.835874 120\n", "max min_per_class_accuracy 0.1825 0.909091 164\n", "max mean_per_class_accuracy 0.263594 0.921894 129" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Gains/Lift Table: Avg response rate: 14,30 %\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
groupcumulative_data_fractionlower_thresholdliftcumulative_liftresponse_ratecumulative_response_ratecapture_ratecumulative_capture_rategaincumulative_gain
10.0110.90756.99300706.99300701.01.00.07692310.0769231599.3006993599.3006993
20.020.87755006.99300706.99300701.01.00.06293710.1398601599.3006993599.3006993
30.030.84003756.99300706.99300701.01.00.06993010.2097902599.3006993599.3006993
40.040.79275006.99300706.99300701.01.00.06993010.2797203599.3006993599.3006993
50.050.76306256.99300706.99300701.01.00.06993010.3496503599.3006993599.3006993
60.10.51275006.43356646.71328670.920.960.32167830.6713287543.3566434571.3286713
70.150.26348443.91608395.78088580.560.82666670.19580420.8671329291.6083916478.0885781
80.20.17550250.83916084.54545450.120.650.04195800.9090909-16.0839161354.5454545
90.30.11004170.13986013.07692310.020.440.01398600.9230769-86.0139860207.6923077
100.40.08254850.02.30769230.00.330.00.9230769-100.0130.7692308
110.50.05865970.01.84615380.00.2640.00.9230769-100.084.6153846
120.60.04269860.13986011.56177160.020.22333330.01398600.9370629-86.013986056.1771562
130.70.03251820.34965031.38861140.050.19857140.03496500.9720280-65.034965038.8611389
140.80.02382020.06993011.22377620.010.1750.00699300.9790210-93.006993022.3776224
150.90.01622970.13986011.10334110.020.15777780.01398600.9930070-86.013986010.3341103
161.00.00125000.06993011.00.010.1430.00699301.0-93.00699300.0
" ], "text/plain": [ " group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n", "-- ------- -------------------------- ----------------- --------- ----------------- --------------- -------------------------- -------------- ------------------------- -------- -----------------\n", " 1 0.011 0.9075 6.99301 6.99301 1 1 0.0769231 0.0769231 599.301 599.301\n", " 2 0.02 0.87755 6.99301 6.99301 1 1 0.0629371 0.13986 599.301 599.301\n", " 3 0.03 0.840037 6.99301 6.99301 1 1 0.0699301 0.20979 599.301 599.301\n", " 4 0.04 0.79275 6.99301 6.99301 1 1 0.0699301 0.27972 599.301 599.301\n", " 5 0.05 0.763062 6.99301 6.99301 1 1 0.0699301 0.34965 599.301 599.301\n", " 6 0.1 0.51275 6.43357 6.71329 0.92 0.96 0.321678 0.671329 543.357 571.329\n", " 7 0.15 0.263484 3.91608 5.78089 0.56 0.826667 0.195804 0.867133 291.608 478.089\n", " 8 0.2 0.175502 0.839161 4.54545 0.12 0.65 0.041958 0.909091 -16.0839 354.545\n", " 9 0.3 0.110042 0.13986 3.07692 0.02 0.44 0.013986 0.923077 -86.014 207.692\n", " 10 0.4 0.0825485 0 2.30769 0 0.33 0 0.923077 -100 130.769\n", " 11 0.5 0.0586597 0 1.84615 0 0.264 0 0.923077 -100 84.6154\n", " 12 0.6 0.0426986 0.13986 1.56177 0.02 0.223333 0.013986 0.937063 -86.014 56.1772\n", " 13 0.7 0.0325182 0.34965 1.38861 0.05 0.198571 0.034965 0.972028 -65.035 38.8611\n", " 14 0.8 0.0238202 0.0699301 1.22378 0.01 0.175 0.00699301 0.979021 -93.007 22.3776\n", " 15 0.9 0.0162297 0.13986 1.10334 0.02 0.157778 0.013986 0.993007 -86.014 10.3341\n", " 16 1 0.00125 0.0699301 1 0.01 0.143 0.00699301 1 -93.007 0" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Scoring History: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
timestampdurationnumber_of_treestraining_rmsetraining_loglosstraining_auctraining_lifttraining_classification_errorvalidation_rmsevalidation_loglossvalidation_aucvalidation_liftvalidation_classification_error
2018-04-20 23:38:58 0.005 sec0.0nannannannannannannannannannan
2018-04-20 23:38:58 0.020 sec1.00.32995503.75938170.83087204.03360490.10882710.32703053.66500920.79469364.32462270.107
2018-04-20 23:38:58 0.031 sec2.00.31995253.31164780.81277974.33554290.10180510.27338111.57441270.85680256.08087560.075
2018-04-20 23:38:58 0.041 sec3.00.32093833.01950300.80903404.56372090.12848840.25406671.01725080.87872406.74763830.069
2018-04-20 23:38:58 0.052 sec4.00.31508232.72572840.80892394.79028860.10398340.24663770.88901100.88232256.99300700.062
------------------------------------------
2018-04-20 23:39:01 3.860 sec117.00.22079160.34069370.90619826.86176470.04629230.20587210.17691920.93880516.99300700.041
2018-04-20 23:39:01 3.921 sec118.00.22041950.34020150.90673096.86176470.04586370.20590820.17696750.93856036.99300700.04
2018-04-20 23:39:02 3.988 sec119.00.22025260.34022130.90652656.86176470.04629230.20585090.17710150.93842566.99300700.04
2018-04-20 23:39:06 7.993 sec780.00.21664580.20078820.89972626.86176470.04243460.20454160.17396910.94125716.99300700.039
2018-04-20 23:39:06 8.543 sec800.00.21642780.20074030.89947546.86176470.04200600.20448180.17373880.94168146.99300700.04
" ], "text/plain": [ " timestamp duration number_of_trees training_rmse training_logloss training_auc training_lift training_classification_error validation_rmse validation_logloss validation_auc validation_lift validation_classification_error\n", "--- ------------------- ---------- ----------------- ------------------- ------------------- ------------------ ----------------- ------------------------------- ------------------- -------------------- ------------------ ------------------ ---------------------------------\n", " 2018-04-20 23:38:58 0.005 sec 0.0 nan nan nan nan nan nan nan nan nan nan\n", " 2018-04-20 23:38:58 0.020 sec 1.0 0.3299549533419618 3.759381658451759 0.8308720112517581 4.033604928457869 0.10882708585247884 0.32703054843001095 3.6650091908267997 0.7946936377508139 4.324622745675378 0.107\n", " 2018-04-20 23:38:58 0.031 sec 2.0 0.31995248719448316 3.3116478265664515 0.8127796965211596 4.335542873865965 0.10180505415162455 0.273381134867804 1.5744126828608989 0.856802474072019 6.0808756460930375 0.075\n", " 2018-04-20 23:38:58 0.041 sec 3.0 0.3209383181496875 3.0195030090134423 0.8090340035860656 4.563720865704773 0.12848837209302325 0.2540666971937434 1.0172508387501398 0.8787239598208093 6.747638326585696 0.069\n", " 2018-04-20 23:38:58 0.052 sec 4.0 0.31508233994601303 2.725728410467888 0.8089238889980269 4.790288568257491 0.10398344542162442 0.24663767870350212 0.8890110203449179 0.8823224616690194 6.993006993006993 0.062\n", "--- --- --- --- --- --- --- --- --- --- --- --- --- ---\n", " 2018-04-20 23:39:01 3.860 sec 117.0 0.22079158962413375 0.3406936623575952 0.9061981641628051 6.861764705882353 0.04629232747535362 0.20587207920559789 0.17691917336902896 0.9388050689100865 6.993006993006993 0.041\n", " 2018-04-20 23:39:01 3.921 sec 118.0 0.22041952436491502 0.3402014866020256 0.9067309111301319 6.861764705882353 0.04586369481354479 0.20590815841200824 0.17696748670339232 0.9385602728659904 6.993006993006993 0.04\n", " 2018-04-20 23:39:02 3.988 sec 119.0 0.2202526177493472 0.34022129663205214 0.9065265192880966 6.861764705882353 0.04629232747535362 0.20585093865188814 0.17710146258395315 0.9384256350417378 6.993006993006993 0.04\n", " 2018-04-20 23:39:06 7.993 sec 780.0 0.21664584404283624 0.20078819748633736 0.8997262477494761 6.861764705882353 0.04243463351907415 0.2045416002818827 0.17396907936194805 0.9412571092851139 6.993006993006993 0.039\n", " 2018-04-20 23:39:06 8.543 sec 800.0 0.21642784657414169 0.20074033630931348 0.8994753696762197 6.861764705882353 0.04200600085726532 0.20448181196350304 0.17373883641178992 0.9416814224282135 6.993006993006993 0.04" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "See the whole table with table.as_data_frame()\n", "Variable Importances: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
variablerelative_importancescaled_importancepercentage
Total day minutes20748.49023441.00.1363729
Total day charge19548.78320310.94217860.1284876
Customer service calls19157.33593750.92331230.1259148
International plan14547.70019530.70114500.0956172
Total eve charge9294.66210940.44796810.0610907
Total eve minutes9159.76367190.44146650.0602041
Total intl calls8102.65185550.39051770.0532560
Total intl minutes6614.37451170.31878820.0434741
Total intl charge6516.21044920.31405710.0428289
Total night charge5165.45800780.24895580.0339508
Total night minutes5084.63525390.24506050.0334196
Total night calls5060.01123050.24387370.0332578
Total day calls4819.54541020.23228410.0316773
Number vmail messages4619.85156250.22265960.0303647
Total eve calls4476.22119140.21573720.0294207
Account length4363.56005860.21030740.0286802
Voice mail plan3168.40502930.15270530.0208249
Area code1697.57812500.08181690.0111576
" ], "text/plain": [ "variable relative_importance scaled_importance percentage\n", "---------------------- --------------------- ------------------- ------------\n", "Total day minutes 20748.5 1 0.136373\n", "Total day charge 19548.8 0.942179 0.128488\n", "Customer service calls 19157.3 0.923312 0.125915\n", "International plan 14547.7 0.701145 0.0956172\n", "Total eve charge 9294.66 0.447968 0.0610907\n", "Total eve minutes 9159.76 0.441467 0.0602041\n", "Total intl calls 8102.65 0.390518 0.053256\n", "Total intl minutes 6614.37 0.318788 0.0434741\n", "Total intl charge 6516.21 0.314057 0.0428289\n", "Total night charge 5165.46 0.248956 0.0339508\n", "Total night minutes 5084.64 0.24506 0.0334196\n", "Total night calls 5060.01 0.243874 0.0332578\n", "Total day calls 4819.55 0.232284 0.0316773\n", "Number vmail messages 4619.85 0.22266 0.0303647\n", "Total eve calls 4476.22 0.215737 0.0294207\n", "Account length 4363.56 0.210307 0.0286802\n", "Voice mail plan 3168.41 0.152705 0.0208249\n", "Area code 1697.58 0.0818169 0.0111576" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = training.columns\n", "X.remove('Churn')\n", "y = 'Churn'\n", "\n", "rf1 = H2ORandomForestEstimator(model_id='tutorial1', ntrees=800, seed=152)\n", "rf1.train(X, y, training_frame=training, validation_frame=validation)\n", "rf1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Получился большой вывод информации о модели. \n", "Сразу следует обратить внимание на AUC при тех же заданных параметрах (ntrees=800) " ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9416814224282135\n" ] } ], "source": [ "print(rf1.auc(valid=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Вся информация идет сначала об обучающей, потом о валидационной выборке\n", "* Отображаются все метрики (MSE, RMSE, LogLoss, AUC, Gini и т.д.)\n", "* Строится матрица ошибок (confusion matrix), которая приводится для порогового значения спрогнозированной вероятности события, оптимального с точки зрения F1-меры _(для справки : F1- мера = 2 x точность x полнота/(точность + полнота))_\n", "* _Maximum Metrics:_ Рассчитываются на ней различные метрики и соответствующие пороговые значения\n", "* _Gains/Lift Table:_ Таблица выигрышей создается путем разбиения данных на группы по квантильным пороговым значениям спрогнозированной верроятности положительного класса\n", "* _Variable Importances:_ Информация о важностях предикторов. Информация выводится также в отмасштабированном и процентном видах. На сайте с документацией указано, что важность рассчитывается, как относительное влияние каждой переменной: была ли переменная выбрана при построении дерева и на как изменилась среднеквадратичная ошибка (рассчитывается на всех деревьях)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# можно нарисовать графиком \n", "import matplotlib.pyplot as plt\n", "plt.rcdefaults()\n", "fig, ax = plt.subplots()\n", "variables = rf1._model_json['output']['variable_importances']['variable']\n", "y_pos = np.arange(len(variables))\n", "scaled_importance = rf1._model_json['output']['variable_importances']['scaled_importance']\n", "ax.barh(y_pos, scaled_importance, align='center', color='green', ecolor='black')\n", "ax.set_yticks(y_pos)\n", "ax.set_yticklabels(variables)\n", "ax.invert_yaxis()\n", "ax.set_xlabel('Scaled Importance')\n", "ax.set_title('Variable Importance')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "___Описание параметров вызова функции___ \n", "1. Параметры, определяющие задачу \n", " * __model_id:__ идентификатор модели\n", " * __training_frame:__ датасет для построения модели\n", " * __validation_frame:__ датасет для валидации\n", " * __nfolds:__ количество фолдов для кросс-валидации (по умолчанию 0)\n", " * __y:__ имена зависимой переменной\n", " * __x:__ список названий предикторов\n", " * __seed:__ random state

\n", "2. Параметры, задающие сложность дерева\n", " * __ntrees__: количество деревьев\n", " * __max_depth:__  максимальная глубина дерева\n", " * __min_rows:__ минимальное количество наблюдений в терминальном листе\n", "

\n", "3. Параметры, определяющие формирование подвыборок\n", " * __mtries:__ количество случайно отбираемых предикторов для разбиения узла. По умолчанию -1: для классификации корень квадратный из р, для регрессии р / 3, где р - количество предикторов\n", " * __sample_rate:__ какую часть строк отбирать (от 0 до 1). По умолчанию 0.6320000291 \n", " * __sample_rate_per_class:__ для построения модели из несбалансированного набора данных. Какую часть строк выбирать для каждого дерева (от 0 до 1) \n", " * __col_sample_rate_per_tree:__ какую часть столбцов выбирать для каждого дерева (от 0 до 1, по умоланчанию 1)\n", " * __col_sample_rate_change_per_level:__ задает изменение отбора столбцов для каждого уровня дерева (от 0 до 2, по умолчанию 1), например: (factor = col_sample_rate_change_per_level)\n", " * level 1: col_sample_rate\n", " * level 2: col_sample_rate * factor\n", " * level 3: col_sample_rate * factor^2\n", " * level 4: col_sample_rate * factor^3\n", "

\n", "4. Параметры, определяющие биннинг переменных\n", " * __nbins__: (Numerical/real/int only) для каждой переменной строит по крайней мере n интервалов и затем рабивает по наилучшей точке расщепления\n", " * __nbins_top_level:__  (Numerical/real/int only) определяет максимальное количество интервалов n на вершине дерева. При переходе на увроень ниже делит заданное число на 2 до тех пор, пока не дойдет до уровня nbins \n", " * __nbins_cats:__ (Categorical/enums only) каждая переменная может быть разбита максимум на n интерваловю. Более высокие значения могут привести к переобучению.\n", " * __categorical_encoding:__ схема кодировки для категориальных переменных\n", " * auto: используется схема enum.\n", " * enum: 1 столбец для каждого категориального признака (как есть)\n", " * one_hot_explicit: N+1 новых столбцов для каждого признака с N уровнями\n", " * binary: не более 32 столбцов для признака (используется хеширование)\n", " * eigen: выполняет one-hot-encoding и оставляет k главных компонент\n", " * label_encoder: все категории сортируются в лексикографическом порядке и каждому уровню присваивается целое число, начиная с 0 (например, level 0 -> 0, level 1 -> 1, и т.д.)\n", " * sort_by_response: все категории сортируются по среднему значению переменной и каждому уровню присваивается значение (например, для уровня с мин средним ответом -> 0, вторым наименьшим -> 1, и т.д.). \n", " * __histogram_type:__ тип гистограммы для поиска оптимальных расщепляющих значений\n", " * AUTO = UniformAdaptive\n", " * UniformAdaptive: задаются интервалы одинаковой ширины (max-min)/N\n", " * QuantilesGlobal: интервалы одинаквого размера (в каждом интервале одинаковое количество наблюдений)\n", " * Random: задает построение Extremely Randomized Trees (XRT). Случайным образом отбирается N-1 точка расщепления и затем выбирается наилучшее разбиение\n", " * RoundRobin: все типы гистограмм (по одному на каждое дерево) перебираются по кругу \n", "

\n", "5. Параметры остановки\n", " * __stopping_rounds:__ задает количество шагов в течение которого должно произойти заданное улучшение(stopping_tolerance:) для заданной метрики (stopping_metric) \n", " * __stopping_metric:__  метрика для ранней остановки (по умолчанию логлосс для классификации и дисперсия - для регрессии). Возможные значения: deviance, logloss, mse, rmse, mae, rmsle, auc, lift_top_group, misclassification, mean_per_class_error\n", " * __stopping_tolerance:__ относительное улучшение для остановки обучения (если меньше, то остановка)\n", "\n", " \n", " \n", " \n", "Полный список всех параметров с описанием можно найти на сайте с [официальной документацией](http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/drf.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Отдельно стоит отметить, что биннинг переменных в h2o позволяет добиться улучшения качества и скорости построения леса. \n", "\n", "__Количественные переменные__\n", "В классической реализации случайного леса для количественного предиктора с k категориями может быть рассмотрено k-1 вариантов разбиения. (Все значения сортируются и средние значения по каждой паре рассматриваются в качестве точек расщепления). \n", "В h2o для каждой переменной определяются типы и количество интервалов для разбиения, т.е. создается гистограмма (тип : histogram_type), и количество регулируется двумя параметрами: nbins и nbins_top_level. Перебор этих параметров позволит улучшить качество леса.\n", "\n", "__Качественные переменные__\n", "В классической реализации случайного леса для категориального предиктора с k категориями может быть рассмотрено ${2^{k-1}} - 1$ разбиения (всеми возможными способами). \n", "В h2o для каждой качественной переменной также определюется типы и количество интервалов для разбиения. Количество регулируется : nbins_cats и nbins_top_level (при переходе на уровень ниже значение nbins_top_level уменьшается в 2 раза до тех пор пока значение больше nbins_cats). \n", "Пример как происходит разбиение на бины: \n", "Если количество категорий меньше значения параметра nbins_cats, каждая категория\n", "получает свой бин. Допустим, у нас есть переменная Class. Если у нее\n", "есть уровни A, B, C, D, E, F, G и мы зададим nbins_cats=8, то будут\n", "сформировано 7 бинов: {A}, {B}, {C}, {D}, {E}, {F} и {G}. Каждая категория\n", "получает свой бин. Будет рассмотрено ${2^6}-1=63$ точки расщепления. Если\n", "мы зададим nbins_cats=10, то все равно будут получены те же самые\n", "бины, потому что у нас всего 7 категорий. Если количество категорий\n", "больше значения параметра nbins_cats, категории будут сгруппированы\n", "в бины в лексикографическом порядке. Например, если мы зададим\n", "nbins_cats=2, то будет сформировано 2 бина: {A, B, C, D} и {E, F, G}. У\n", "нас будет одна точка расщепления. A, B, C и D попадут в один и тот же\n", "узел и будут разбиты только на последующем, более нижнем уровне или\n", "вообще не будут разбиты. 
\n", " \n", "Этот параметр очень важен для настройки, при больших значениях параметра мы можем получить дополнительную случайность в построении " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "И в завершение этой части сделаем подбор параметров и попытаемся улучшить модель по разным параметрам" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from h2o.grid.grid_search import H2OGridSearch" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "drf Grid Build progress: |████████████████████████████████████████████████| 100%\n", " max_depth model_ids auc\n", "0 14 rf_grid_max_depth_model_1 0.9433909148028168\n", "1 18 rf_grid_max_depth_model_2 0.9421669345823371\n", "2 10 rf_grid_max_depth_model_0 0.9416895822963501\n", "3 24 rf_grid_max_depth_model_3 0.9414692658566638\n", "\n", "0.9433909148028168\n", "Hyperparameters: [max_depth]\n", "[14]\n" ] } ], "source": [ "rf_params = {'max_depth': [10, 14, 18, 24]}\n", "\n", "rf_grid = H2OGridSearch(model=H2ORandomForestEstimator(ntrees=800, seed=152),\n", " grid_id='rf_grid_max_depth',\n", " hyper_params=rf_params)\n", "rf_grid.train(X, y, training_frame=training, validation_frame=validation)\n", "\n", "# модели, отсортированные по AUC\n", "rf_gridperf = rf_grid.get_grid(sort_by='auc', decreasing=True)\n", "print(rf_gridperf)\n", "\n", "# выберем лучшую модель и выведем AUC на тесте \n", "best_rf = rf_gridperf.models[0]\n", "print(best_rf.auc(valid=True))\n", "print(rf_gridperf.get_hyperparams(0))" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "drf Grid Build progress: |████████████████████████████████████████████████| 100%\n", " mtries model_ids auc\n", "0 7 rf_grid_mtries1_model_2 0.9422444533296342\n", "1 5 rf_grid_mtries1_model_1 0.9417263017029645\n", "2 3 rf_grid_mtries1_model_0 0.9406940783836933\n", "3 9 rf_grid_mtries1_model_3 0.9390417050860458\n", "\n", "0.9422444533296342\n", "Hyperparameters: [mtries]\n", "[7]\n" ] } ], "source": [ "rf_params1 = {'mtries': [3, 5, 7, 9]}\n", "\n", "rf_grid1 = H2OGridSearch(model=H2ORandomForestEstimator(ntrees=800, seed=152, max_depth=14),\n", " grid_id='rf_grid_mtries1',\n", " hyper_params=rf_params1)\n", "rf_grid1.train(X, y, training_frame=training, validation_frame=validation)\n", "\n", "# модели, отсортированные по AUC\n", "rf_gridperf1 = rf_grid1.get_grid(sort_by='auc', decreasing=True)\n", "print(rf_gridperf1)\n", "\n", "# выберем лучшую модель и выведем AUC на тесте \n", "best_rf1 = rf_gridperf1.models[0]\n", "print(best_rf1.auc(valid=True))\n", "print(rf_gridperf1.get_hyperparams(0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "теперь к настройке более специфических параметров" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "drf Grid Build progress: |████████████████████████████████████████████████| 100%\n", " histogram_type model_ids auc\n", "0 RoundRobin rf_grid_hist_type_model_3 0.9433868348687484\n", "1 UniformAdaptive rf_grid_hist_type_model_0 0.9422444533296342\n", "2 Random rf_grid_hist_type_model_1 0.9391192238333429\n", "3 QuantilesGlobal rf_grid_hist_type_model_2 0.9380339613711842\n", "\n", "AUC for the best model: 0.9433868348687484\n", "Hyperparameters: [histogram_type]\n", "['RoundRobin']\n" ] } ], "source": [ "rf_params2 = {'histogram_type': 
['UniformAdaptive', 'Random', 'QuantilesGlobal', 'RoundRobin']}\n", "\n", "rf_grid2 = H2OGridSearch(model=H2ORandomForestEstimator(ntrees=800, seed=152, max_depth=14, mtries=7),\n", " grid_id='rf_grid_hist_type',\n", " hyper_params=rf_params2)\n", "rf_grid2.train(X, y, training_frame=training, validation_frame=validation)\n", "\n", "# модели, отсортированные по AUC\n", "rf_gridperf2 = rf_grid2.get_grid(sort_by='auc', decreasing=True)\n", "print(rf_gridperf2)\n", "\n", "# выберем лучшую модель и выведем AUC на тесте \n", "best_rf2 = rf_gridperf2.models[0]\n", "print('AUC for the best model: ', best_rf2.auc(valid=True))\n", "print(rf_gridperf2.get_hyperparams(0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "По типу гистограммы RoundRobin оказался самым оптимальным" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "drf Grid Build progress: |████████████████████████████████████████████████| 100%\n", " col_sample_rate_per_tree sample_rate model_ids \\\n", "0 0.9 0.6 rf_grid3_model_9 \n", "1 0.7 0.7 rf_grid3_model_12 \n", "2 0.9 0.7 rf_grid3_model_14 \n", "3 0.7 0.5 rf_grid3_model_2 \n", "4 0.8 0.6 rf_grid3_model_8 \n", "5 0.9 0.8 rf_grid3_model_19 \n", "6 0.9 0.9 rf_grid3_model_24 \n", "7 0.9 0.5 rf_grid3_model_4 \n", "8 0.8 0.5 rf_grid3_model_3 \n", "9 0.6 0.5 rf_grid3_model_1 \n", "10 0.7 0.6 rf_grid3_model_7 \n", "11 0.7 0.8 rf_grid3_model_17 \n", "12 0.8 0.7 rf_grid3_model_13 \n", "13 0.6 0.8 rf_grid3_model_16 \n", "14 0.5 0.7 rf_grid3_model_10 \n", "15 0.8 0.8 rf_grid3_model_18 \n", "16 0.5 0.5 rf_grid3_model_0 \n", "17 0.6 0.7 rf_grid3_model_11 \n", "18 0.8 0.9 rf_grid3_model_23 \n", "19 0.7 0.9 rf_grid3_model_22 \n", "20 0.6 0.6 rf_grid3_model_6 \n", "21 0.6 0.9 rf_grid3_model_21 \n", "22 0.5 0.6 rf_grid3_model_5 \n", "23 0.5 0.8 rf_grid3_model_15 \n", "24 0.5 0.9 rf_grid3_model_20 \n", "\n", " auc \n", "0 0.9416855023622818 \n", "1 0.9415875839446435 \n", "2 0.941518225065483 \n", "3 0.9405920800319867 \n", "4 0.9405472007572357 \n", "5 0.9403554438560272 \n", "6 0.9402412057021159 \n", "7 0.9402412057021159 \n", "8 0.9401188076800678 \n", "9 0.9398862514381767 \n", "10 0.9396904146029 \n", "11 0.9396659349984905 \n", "12 0.9393640198774388 \n", "13 0.9392212221850494 \n", "14 0.9385194735253078 \n", "15 0.9375321294807876 \n", "16 0.9374015715906031 \n", "17 0.9373322127114426 \n", "18 0.9373281327773744 \n", "19 0.9370751768651419 \n", "20 0.9361816713041917 \n", "21 0.9354268835015626 \n", "22 0.9351494479849205 \n", "23 0.935018890094736 \n", "24 0.9342763420943118 \n", "\n", "AUC for the best model: 0.9416855023622818\n", "Hyperparameters: [col_sample_rate_per_tree, sample_rate]\n", "[0.9, 0.6]\n" ] } ], "source": [ "rf_params3 = {'col_sample_rate_per_tree': [0.5, 0.6, 0.7, 0.8, 0.9], \n", " 'sample_rate': [0.5, 0.6, 0.7, 0.8, 0.9] }\n", "\n", "rf_grid3 = H2OGridSearch(model=H2ORandomForestEstimator(ntrees=800, seed=152, max_depth=14, mtries=7,\n", " histogram_type='RoundRobin'),\n", " grid_id='rf_grid3',\n", " hyper_params=rf_params3)\n", "rf_grid3.train(X, y, training_frame=training, validation_frame=validation)\n", "\n", "# модели, отсортированные по AUC\n", "rf_gridperf3 = rf_grid3.get_grid(sort_by='auc', decreasing=True)\n", "print(rf_gridperf3)\n", "\n", "# выберем лучшую модель и выведем AUC на тесте \n", "best_rf3 = rf_gridperf3.models[0]\n", "print('AUC for the best model: ', best_rf3.auc(valid=True))\n", "print(rf_gridperf3.get_hyperparams(0))" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "### Logistic Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### sklearn LogisticRegression" ] }, { "cell_type": "code", "execution_count": 213, "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression" ] }, { "cell_type": "code", "execution_count": 214, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC для sklearn LogisticRegression: 0.8271\n" ] } ], "source": [ "logreg = LogisticRegression().fit(X_train, y_train)\n", "print('AUC для sklearn LogisticRegression: {:.4f}'.format(roc_auc_score(y_test, logreg.predict_proba(X_test)[:, 1])))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "проверим на отмасштабированных данных" ] }, { "cell_type": "code", "execution_count": 215, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC для sklearn LogisticRegression Scaled: 0.8283\n" ] } ], "source": [ "from sklearn.preprocessing import StandardScaler\n", "scaler = StandardScaler()\n", "X_train_scaled = scaler.fit_transform(X_train)\n", "X_test_scaled = scaler.transform(X_test)\n", "\n", "logreg = LogisticRegression().fit(X_train_scaled, y_train)\n", "print('AUC для sklearn LogisticRegression Scaled: {:.4f}'.format(\n", " roc_auc_score(y_test, logreg.predict_proba(X_test_scaled)[:, 1])))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "результаты хуже, чем в лесу \n", "займемся подбором параметров в H2O" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### H2OGeneralizedLinearEstimator" ] }, { "cell_type": "code", "execution_count": 216, "metadata": {}, "outputs": [], "source": [ "from h2o.estimators.glm import H2OGeneralizedLinearEstimator" ] }, { "cell_type": "code", "execution_count": 217, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "glm Model Build progress: |███████████████████████████████████████████████| 100%\n" ] } ], "source": [ "# создаем экземпляр класса H2OGeneralizedLinearEstimator\n", "glm_model = H2OGeneralizedLinearEstimator(family= \"binomial\", seed=1000000)\n", "# обучаем модель\n", "glm_model.train(X, y, training_frame= training, validation_frame=validation)" ] }, { "cell_type": "code", "execution_count": 218, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "ModelMetricsBinomialGLM: glm\n", "** Reported on train data. **\n", "\n", "MSE: 0.09918009133349871\n", "RMSE: 0.3149287083349162\n", "LogLoss: 0.3263867097273659\n", "Null degrees of freedom: 2332\n", "Residual degrees of freedom: 2312\n", "Null deviance: 1937.5065769088035\n", "Residual deviance: 1522.9203875878893\n", "AIC: 1564.9203875878893\n", "AUC: 0.8190335291166141\n", "Gini: 0.6380670582332282\n", "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.24799189912357045: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
01ErrorRate
01773.0220.00.1104 (220.0/1993.0)
1149.0191.00.4382 (149.0/340.0)
Total1922.0411.00.1582 (369.0/2333.0)
" ], "text/plain": [ " 0 1 Error Rate\n", "----- ---- --- ------- --------------\n", "0 1773 220 0.1104 (220.0/1993.0)\n", "1 149 191 0.4382 (149.0/340.0)\n", "Total 1922 411 0.1582 (369.0/2333.0)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Maximum Metrics: Maximum metrics at their respective thresholds\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
metricthresholdvalueidx
max f10.24799190.5086551178.0
max f20.12289170.6103074263.0
max f0point50.30546310.5099502150.0
max accuracy0.60801740.864980747.0
max precision0.98701101.00.0
max recall0.00698311.0396.0
max specificity0.98701101.00.0
max absolute_mcc0.24799190.4180577178.0
max min_per_class_accuracy0.14007320.75248.0
max mean_per_class_accuracy0.15570060.7547866237.0
" ], "text/plain": [ "metric threshold value idx\n", "--------------------------- ----------- -------- -----\n", "max f1 0.247992 0.508655 178\n", "max f2 0.122892 0.610307 263\n", "max f0point5 0.305463 0.50995 150\n", "max accuracy 0.608017 0.864981 47\n", "max precision 0.987011 1 0\n", "max recall 0.00698315 1 396\n", "max specificity 0.987011 1 0\n", "max absolute_mcc 0.247992 0.418058 178\n", "max min_per_class_accuracy 0.140073 0.75 248\n", "max mean_per_class_accuracy 0.155701 0.754787 237" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Gains/Lift Table: Avg response rate: 14,57 %\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
groupcumulative_data_fractionlower_thresholdliftcumulative_liftresponse_ratecumulative_response_ratecapture_ratecumulative_capture_rategaincumulative_gain
10.01028720.74563595.43223045.43223040.79166670.79166670.05588240.0558824443.2230392443.2230392
20.02014570.65842744.17672634.81783480.60869570.70212770.04117650.0970588317.6726343381.7834793
30.03000430.59845354.17672634.60718490.60869570.67142860.04117650.1382353317.6726343360.7184874
40.04029150.55078322.85906864.16085730.41666670.60638300.02941180.1676471185.9068627316.0857322
50.05015000.51574542.98337603.92938660.43478260.57264960.02941180.1970588198.3375959292.9386626
60.10030000.36925853.10832083.51885370.45299150.51282050.15588240.3529412210.8320764251.8853695
70.15002140.28606572.95765723.33285710.43103450.48571430.14705880.5195.7657201233.2857143
80.20017150.22455531.75942682.93865730.25641030.42826550.08823530.588235375.9426848193.8657262
90.30004290.15205221.44303212.44082770.21030040.35571430.14411760.732352944.3032063144.0827731
100.39991430.11335590.85403942.04455580.12446350.29796360.08529410.8176471-14.5960616104.4555829
110.50021430.08636710.61579941.75807000.08974360.25621250.06176470.8794118-38.420060375.8069963
120.60008570.06487640.32394601.51939080.04721030.22142860.03235290.9117647-67.605402751.9390756
130.69995710.04732500.29449631.34462020.04291850.19595840.02941180.9411765-70.550366134.4620151
140.79982850.03371930.20614741.20246360.03004290.17524120.02058820.9617647-79.385256220.2463590
150.89970000.02057910.26504671.09840540.03862660.16007620.02647060.9882353-73.49532959.8405403
161.00.00178950.11729511.00.01709400.14573510.01176471.0-88.27048770.0
" ], "text/plain": [ " group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n", "-- ------- -------------------------- ----------------- -------- ----------------- --------------- -------------------------- -------------- ------------------------- -------- -----------------\n", " 1 0.0102872 0.745636 5.43223 5.43223 0.791667 0.791667 0.0558824 0.0558824 443.223 443.223\n", " 2 0.0201457 0.658427 4.17673 4.81783 0.608696 0.702128 0.0411765 0.0970588 317.673 381.783\n", " 3 0.0300043 0.598454 4.17673 4.60718 0.608696 0.671429 0.0411765 0.138235 317.673 360.718\n", " 4 0.0402915 0.550783 2.85907 4.16086 0.416667 0.606383 0.0294118 0.167647 185.907 316.086\n", " 5 0.05015 0.515745 2.98338 3.92939 0.434783 0.57265 0.0294118 0.197059 198.338 292.939\n", " 6 0.1003 0.369259 3.10832 3.51885 0.452991 0.512821 0.155882 0.352941 210.832 251.885\n", " 7 0.150021 0.286066 2.95766 3.33286 0.431034 0.485714 0.147059 0.5 195.766 233.286\n", " 8 0.200171 0.224555 1.75943 2.93866 0.25641 0.428266 0.0882353 0.588235 75.9427 193.866\n", " 9 0.300043 0.152052 1.44303 2.44083 0.2103 0.355714 0.144118 0.732353 44.3032 144.083\n", " 10 0.399914 0.113356 0.854039 2.04456 0.124464 0.297964 0.0852941 0.817647 -14.5961 104.456\n", " 11 0.500214 0.0863671 0.615799 1.75807 0.0897436 0.256213 0.0617647 0.879412 -38.4201 75.807\n", " 12 0.600086 0.0648764 0.323946 1.51939 0.0472103 0.221429 0.0323529 0.911765 -67.6054 51.9391\n", " 13 0.699957 0.047325 0.294496 1.34462 0.0429185 0.195958 0.0294118 0.941176 -70.5504 34.462\n", " 14 0.799829 0.0337193 0.206147 1.20246 0.0300429 0.175241 0.0205882 0.961765 -79.3853 20.2464\n", " 15 0.8997 0.0205791 0.265047 1.09841 0.0386266 0.160076 0.0264706 0.988235 -73.4953 9.84054\n", " 16 1 0.00178946 0.117295 1 0.017094 0.145735 0.0117647 1 -88.2705 0" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "glm_model.model_performance().show()" ] }, { "cell_type": "code", "execution_count": 219, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'0.8190335291166141'" ] }, "execution_count": 219, "metadata": {}, "output_type": "execute_result" } ], "source": [ "str(glm_model.model_performance().auc())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "___Описание параметров вызова функции___ \n", "1. Параметры, определяющие задачу \n", " * __model_id:__ идентификатор \n", " * __training_frame:__ датасет для построения модели\n", " * __validation_frame:__ датасет для валидации\n", " * __nfolds:__ количество фолдов для кросс-валидации (по умолчанию 0)\n", " * __y:__ имена зависимой переменной\n", " * __x:__ список названий предикторов\n", " * __seed:__ random state \n", " * __family:__ тип модели (gaussian, binomial, multinomial, ordinal, quasibinomial, poisson, gamma, tweedie \n", " * __solver:__ \n", " * IRLSM: Iteratively Reweighted Least Squares Method - используется с небольшим количеством предикторов и для l1-регулярзации\n", " * L_BFGS: Limited-memory Broyden-Fletcher-Goldfarb-Shanno algorithm - используется для данных в большим числом колонок\n", " * COORDINATE_DESCENT, COORDINATE_DESCENT_NAIVE - экспериментальные\n", " * AUTO: Sets the solver based on given data and parameters (default)\n", " * GRADIENT_DESCENT_LH, GRADIENT_DESCENT_SQERR: используется только для family=Ordinal

\n", "2. Параметры, определяющие регуляризацию \n", "для справки формула ElasticNet, объединяющая $L_1$ и $L_2$ регуляризацию\n", "$$\\large \\begin{array}{rcl}\n", "L &=& -\\mathcal{L} + \\lambda R\\left(\\textbf W\\right) \\\\\n", "&=& -\\mathcal{L} + \\lambda \\left(\\alpha \\sum_{k=1}^K\\sum_{i=1}^M \\left|w_{ki}\\right| + \\left(1 - \\alpha\\right) \\sum_{k=1}^K\\sum_{i=1}^M w_{ki}^2 \\right)\n", "\\end{array}$$ где $\\alpha \\in \\left[0, 1\\right]$: при $\\alpha = 1$ остается только $L_1$-штраф, при $\\alpha = 0$ - только $L_2$ (Ridge)\n", "\n", "    * __alpha:__ распределение веса между $L_1$ и $L_2$ регуляризацией (1 - чистая $L_1$, 0 - чистая $L_2$)\n", "    * __lambda:__ сила регуляризации\n", "    * __lambda_search:__ True / False. Определяет, стоит ли запускать поиск $\\lambda$, начиная с максимального значения\n", "    * __lambda_min_ratio:__ минимальное значение $\\lambda$, используемое при поиске $\\lambda$\n", "    * __nlambdas:__ количество шагов при поиске $\\lambda$ (по умолчанию 100)\n", "<br>
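"\n",
"Для наглядности - минимальный набросок того, как перечисленные параметры регуляризации можно передать в H2OGeneralizedLinearEstimator (используются те же X, y, training и validation, что и в примерах выше; конкретные значения alpha и nlambdas выбраны условно, исключительно для иллюстрации):\n",
"\n",
"```python\n",
"from h2o.estimators.glm import H2OGeneralizedLinearEstimator\n",
"\n",
"# elastic net с перевесом в сторону L1 и автоматическим подбором lambda\n",
"glm_reg = H2OGeneralizedLinearEstimator(family='binomial',\n",
"                                        alpha=0.7,           # 70% веса на L1, 30% на L2\n",
"                                        lambda_search=True,  # поиск lambda, начиная с максимального значения\n",
"                                        nlambdas=50,         # число шагов при поиске lambda\n",
"                                        seed=1000000)\n",
"glm_reg.train(X, y, training_frame=training, validation_frame=validation)\n",
"print(glm_reg.auc(valid=True))\n",
"```\n",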

\n", "3. Параметры, влияющие на предобработку предикторов\n", "    * __standardize:__ использовать ли масштабирование признаков\n", "    * __missing_values_handling:__ как обрабатывать пропущенные значения (пропускать строки или импутировать средним)\n", "    * __remove_collinear_columns:__ удалять ли автоматически коллинеарные столбцы при построении модели \n", "    * __interactions:__ список колонок, из которых будут составлены все возможные пары и использованы для построения модели \n", "    * __interaction_pairs:__ список уже готовых пар для модели \n", "<br>
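"\n",
"Небольшой набросок с параметрами предобработки (набор колонок для interactions взят из того же датасета оттока и выбран произвольно, только для примера):\n",
"\n",
"```python\n",
"# пропуски заполняем средним и добавляем попарные взаимодействия выбранных колонок\n",
"glm_inter = H2OGeneralizedLinearEstimator(family='binomial',\n",
"                                          standardize=True,\n",
"                                          missing_values_handling='MeanImputation',\n",
"                                          interactions=['Total day minutes', 'Total eve minutes', 'Customer service calls'],\n",
"                                          seed=1000000)\n",
"glm_inter.train(X, y, training_frame=training, validation_frame=validation)\n",
"```\n",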

\n", "\n", "Полный список всех параметров с описанием можно найти на сайте с [официальной документацией](http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/glm.html)\n", " \n", " \n", "Воспольуземся перебором параметров [H2OGridSearch] для поиска лучшей модели" ] }, { "cell_type": "code", "execution_count": 223, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "glm Grid Build progress: |████████████████████████████████████████████████| 100%\n", " alpha model_ids auc\n", "0 [0.0] gridresults_model_0 0.8284550921657106\n", "1 [1.0] gridresults_model_20 0.8283938931546866\n", "2 [0.9500000000000001] gridresults_model_19 0.8283490138799356\n", "3 [0.9] gridresults_model_18 0.8282755750667069\n", "4 [0.8500000000000001] gridresults_model_17 0.8282388556600926\n", "5 [0.8] gridresults_model_16 0.8282347757260242\n", "6 [0.7000000000000001] gridresults_model_14 0.8282102961216147\n", "7 [0.75] gridresults_model_15 0.828202136253478\n", "8 [0.65] gridresults_model_13 0.8281613369127955\n", "9 [0.6000000000000001] gridresults_model_12 0.8280185392204062\n", "10 [0.55] gridresults_model_11 0.8278186224510612\n", "11 [0.5] gridresults_model_10 0.827785982978515\n", "12 [0.45] gridresults_model_9 0.827753343505969\n", "13 [0.35000000000000003] gridresults_model_7 0.8275615866047604\n", "14 [0.4] gridresults_model_8 0.8275493468025557\n", "15 [0.30000000000000004] gridresults_model_6 0.8274432685167807\n", "16 [0.25] gridresults_model_5 0.8272963908903233\n", "17 [0.2] gridresults_model_4 0.8272392718133674\n", "18 [0.15000000000000002] gridresults_model_3 0.8270801543847051\n", "19 [0.05] gridresults_model_1 0.8270679145825003\n", "20 [0.1] gridresults_model_2 0.8270107955055446\n", "\n", "AUC for the best model: 0.8284550921657106\n", "['Ridge ( lambda = 3.014E-5 )']\n" ] } ], "source": [ "hyper_parameters = {'alpha': np.arange(0, 1.05, 0.05).tolist()}\n", "\n", "gridsearch = H2OGridSearch(H2OGeneralizedLinearEstimator(family='binomial', lambda_search=True, standardize=True),\n", " grid_id=\"gridresults\", hyper_params=hyper_parameters)\n", "gridsearch.train(X, y, training_frame= training, validation_frame=validation)\n", "\n", "gridperf = gridsearch.get_grid(sort_by=\"auc\", decreasing=True)\n", "best_model = gridperf.models[0]\n", "print(gridperf)\n", "\n", "print('AUC for the best model: ', best_model.auc(valid=True))\n", "print(best_model.summary()['regularization'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "H2O позволяет настроить силу регуляризации. 
\n", "Для получения модели с более выоским AUC - надо попробовать предварительно разные способы масштабирования \n", "И попробовать настроить еще один интересный параметр : interactions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient Boosting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### sklearn GradientBoostingClassifier\n" ] }, { "cell_type": "code", "execution_count": 224, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import GradientBoostingClassifier" ] }, { "cell_type": "code", "execution_count": 225, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC для sklearn GradientBoostingClassifier: 0.9357\n" ] } ], "source": [ "grb = GradientBoostingClassifier().fit(X_train, y_train)\n", "print('AUC для sklearn GradientBoostingClassifier: {:.4f}'.format(\n", " roc_auc_score(y_test, grb.predict_proba(X_test)[:, 1])))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Gradient Boosting Machine" ] }, { "cell_type": "code", "execution_count": 226, "metadata": {}, "outputs": [], "source": [ "from h2o.estimators.gbm import H2OGradientBoostingEstimator" ] }, { "cell_type": "code", "execution_count": 232, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm Model Build progress: |███████████████████████████████████████████████| 100%\n", "\n", "ModelMetricsBinomial: gbm\n", "** Reported on train data. **\n", "\n", "MSE: 4.97643761962381e-09\n", "RMSE: 7.05438701775272e-05\n", "LogLoss: 2.387073099550759e-05\n", "Mean Per-Class Error: 0.0\n", "AUC: 1.0\n", "Gini: 1.0\n", "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.9991932827883621: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
01ErrorRate
01993.00.00.0 (0.0/1993.0)
10.0340.00.0 (0.0/340.0)
Total1993.0340.00.0 (0.0/2333.0)
" ], "text/plain": [ " 0 1 Error Rate\n", "----- ---- --- ------- ------------\n", "0 1993 0 0 (0.0/1993.0)\n", "1 0 340 0 (0.0/340.0)\n", "Total 1993 340 0 (0.0/2333.0)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Maximum Metrics: Maximum metrics at their respective thresholds\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
metricthresholdvalueidx
max f10.99919331.0158.0
max f20.99919331.0158.0
max f0point50.99919331.0158.0
max accuracy0.99919331.0158.0
max precision1.00000001.00.0
max recall0.99919331.0158.0
max specificity1.00000001.00.0
max absolute_mcc0.99919331.0158.0
max min_per_class_accuracy0.99919331.0158.0
max mean_per_class_accuracy0.99919331.0158.0
" ], "text/plain": [ "metric threshold value idx\n", "--------------------------- ----------- ------- -----\n", "max f1 0.999193 1 158\n", "max f2 0.999193 1 158\n", "max f0point5 0.999193 1 158\n", "max accuracy 0.999193 1 158\n", "max precision 1 1 0\n", "max recall 0.999193 1 158\n", "max specificity 1 1 0\n", "max absolute_mcc 0.999193 1 158\n", "max min_per_class_accuracy 0.999193 1 158\n", "max mean_per_class_accuracy 0.999193 1 158" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Gains/Lift Table: Avg response rate: 14,57 %\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
groupcumulative_data_fractionlower_thresholdliftcumulative_liftresponse_ratecumulative_response_ratecapture_ratecumulative_capture_rategaincumulative_gain
10.01028720.99999996.86176476.86176471.01.00.07058820.0705882586.1764706586.1764706
20.02014570.99999906.86176476.86176471.01.00.06764710.1382353586.1764706586.1764706
30.03000430.99999756.86176476.86176471.01.00.06764710.2058824586.1764706586.1764706
40.04029150.99999586.86176476.86176471.01.00.07058820.2764706586.1764706586.1764706
50.05015000.99999296.86176476.86176471.01.00.06764710.3441176586.1764706586.1764706
60.10030000.99996026.86176476.86176471.01.00.34411760.6882353586.1764706586.1764706
70.15002140.00013986.27023336.66571430.91379310.97142860.31176471.0527.0233266566.5714286
80.20017150.00005790.04.99571730.00.72805140.01.0-100.0399.5717345
90.30004290.00002410.03.33285710.00.48571430.01.0-100.0233.2857143
100.39991430.00001170.02.50053590.00.36441590.01.0-100.0150.0535906
110.50021430.00000580.01.99914310.00.29134530.01.0-100.099.9143102
120.60008570.00000300.01.66642860.00.24285710.01.0-100.066.6428571
130.69995710.00000130.01.42865890.00.20820580.01.0-100.042.8658910
140.79982850.00000060.01.25026800.00.18220790.01.0-100.025.0267953
150.89970000.00000020.01.11148170.00.16198190.01.0-100.011.1481658
161.00.00000000.01.00.00.14573510.01.0-100.00.0
" ], "text/plain": [ " group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n", "-- ------- -------------------------- ----------------- ------- ----------------- --------------- -------------------------- -------------- ------------------------- ------- -----------------\n", " 1 0.0102872 1 6.86176 6.86176 1 1 0.0705882 0.0705882 586.176 586.176\n", " 2 0.0201457 0.999999 6.86176 6.86176 1 1 0.0676471 0.138235 586.176 586.176\n", " 3 0.0300043 0.999997 6.86176 6.86176 1 1 0.0676471 0.205882 586.176 586.176\n", " 4 0.0402915 0.999996 6.86176 6.86176 1 1 0.0705882 0.276471 586.176 586.176\n", " 5 0.05015 0.999993 6.86176 6.86176 1 1 0.0676471 0.344118 586.176 586.176\n", " 6 0.1003 0.99996 6.86176 6.86176 1 1 0.344118 0.688235 586.176 586.176\n", " 7 0.150021 0.000139845 6.27023 6.66571 0.913793 0.971429 0.311765 1 527.023 566.571\n", " 8 0.200171 5.79219e-05 0 4.99572 0 0.728051 0 1 -100 399.572\n", " 9 0.300043 2.41348e-05 0 3.33286 0 0.485714 0 1 -100 233.286\n", " 10 0.399914 1.17004e-05 0 2.50054 0 0.364416 0 1 -100 150.054\n", " 11 0.500214 5.77469e-06 0 1.99914 0 0.291345 0 1 -100 99.9143\n", " 12 0.600086 2.98531e-06 0 1.66643 0 0.242857 0 1 -100 66.6429\n", " 13 0.699957 1.30511e-06 0 1.42866 0 0.208206 0 1 -100 42.8659\n", " 14 0.799829 6.00854e-07 0 1.25027 0 0.182208 0 1 -100 25.0268\n", " 15 0.8997 1.69372e-07 0 1.11148 0 0.161982 0 1 -100 11.1482\n", " 16 1 2.73715e-10 0 1 0 0.145735 0 1 -100 0" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "gbm = H2OGradientBoostingEstimator(ntrees = 800, learn_rate = 0.1,seed = 1234)\n", "gbm.train(X, y, training_frame= training, validation_frame=validation)\n", "gbm.model_performance().show()" ] }, { "cell_type": "code", "execution_count": 237, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "AUC на валидационной выборке: 0.9146600190940914\n" ] } ], "source": [ "print('AUC на валидационной выборке: ',gbm.auc(valid=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "___Описание параметров вызова функции___ \n", "только специфические, отличаюшиеся от RandomForest\n", "1. 
Параметры для GBM\n", " * __learn_rate:__ скорость обучения (от 0 до 1)\n", " * __learn_rate_annealing:__ уменьшает скорость обучения после построения каждого дерева\n", " \n", "Для настройки параметров бустинга есть следующие советы от экспертов (Марк Лэндри, Дмитрий Ларько и [github](https://github.com/h2oai/h2o-3/blob/master/h2o-docs/src/product/tutorials/gbm/gbmTuning.ipynb)): \n", "* заификисровать количество деревьев константой, подобрать leraning rate (например, learn_rate=0.02, learn_rate_annealing=0.995) \n", "* в самом конце можно снова вернуться к настройке количества деревьев\n", "* потом необходимо настроить глубину деревьев (max_depth) (чаще всего это 4-10)\n", "* попробовать изменить тип гистрограммы, nbins/nbins_cat\n", "* изменить настройки, определяющие подвыборки (sample_rate, col_sample_rate), чаще всего это 70-80%\n", "* для несбалансированных наборов необходимо настроить параметры, отвечающие за баланс классов (sample_rate_per_class)\n", "* определить критерии остановки " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "начинаем настройку" ] }, { "cell_type": "code", "execution_count": 239, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm Grid Build progress: |████████████████████████████████████████████████| 100%\n", " learn_rate model_ids auc\n", "0 0.001 gbm_grid_learn_rate_model_1 0.9319711793457418\n", "1 0.01 gbm_grid_learn_rate_model_2 0.9310858336529282\n", "2 1.0 gbm_grid_learn_rate_model_4 0.9299924113226331\n", "3 1.0E-4 gbm_grid_learn_rate_model_0 0.9189194702613606\n", "4 0.1 gbm_grid_learn_rate_model_3 0.9146600190940914\n", "\n", "0.9319711793457418\n", "Hyperparameters: [learn_rate]\n", "[0.001]\n" ] } ], "source": [ "gbm_params = {'learn_rate': [0.0001, 0.001, 0.01, 0.1, 1]}\n", "\n", "gbm_grid = H2OGridSearch(model=H2OGradientBoostingEstimator(ntrees = 800, seed = 1234),\n", " grid_id='gbm_grid_learn_rate',\n", " hyper_params=gbm_params)\n", "gbm_grid.train(X, y, training_frame=training, validation_frame=validation)\n", "\n", "# модели, отсортированные по AUC\n", "gbm_gridperf = gbm_grid.get_grid(sort_by='auc', decreasing=True)\n", "print(gbm_gridperf)\n", "\n", "# выберем лучшую модель и выведем AUC на тесте \n", "best_gbm = gbm_gridperf.models[0]\n", "print(best_gbm.auc(valid=True))\n", "print(gbm_gridperf.get_hyperparams(0))" ] }, { "cell_type": "code", "execution_count": 240, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm Grid Build progress: |████████████████████████████████████████████████| 100%\n", " max_depth model_ids auc\n", "0 6 gbm_grid_max_depth_model_1 0.9400943280756583\n", "1 8 gbm_grid_max_depth_model_2 0.9301229692128176\n", "2 10 gbm_grid_max_depth_model_3 0.9272466156946904\n", "3 4 gbm_grid_max_depth_model_0 0.9173772551835563\n", "\n", "0.9400943280756583\n", "Hyperparameters: [max_depth]\n", "[6]\n" ] } ], "source": [ "gbm_params1 = {'max_depth': [4, 6, 8, 10]}\n", "\n", "gbm_grid1 = H2OGridSearch(model=H2OGradientBoostingEstimator(ntrees = 800, seed = 1234, learn_rate=0.001),\n", " grid_id='gbm_grid_max_depth',\n", " hyper_params=gbm_params1)\n", "gbm_grid1.train(X, y, training_frame=training, validation_frame=validation)\n", "\n", "# модели, отсортированные по AUC\n", "gbm_gridperf1 = gbm_grid1.get_grid(sort_by='auc', decreasing=True)\n", "print(gbm_gridperf1)\n", "\n", "# выберем лучшую модель и выведем AUC на тесте \n", "best_gbm1 = gbm_gridperf1.models[0]\n", "print(best_gbm1.auc(valid=True))\n", 
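"# get_hyperparams(0) возвращает значения перебиравшихся гиперпараметров (здесь - max_depth) для лучшей модели сетки\n",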
"print(gbm_gridperf1.get_hyperparams(0))" ] }, { "cell_type": "code", "execution_count": 242, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gbm Grid Build progress: |████████████████████████████████████████████████| 100%\n", " histogram_type nbins model_ids auc\n", "0 RoundRobin 16 gbm_grid_hist_model_11 0.9450596078367374\n", "1 RoundRobin 8 gbm_grid_hist_model_7 0.9447291331772079\n", "2 RoundRobin 32 gbm_grid_hist_model_15 0.9445985752870233\n", "3 UniformAdaptive 32 gbm_grid_hist_model_12 0.943439874011636\n", "4 UniformAdaptive 8 gbm_grid_hist_model_4 0.9427993243629184\n", "5 UniformAdaptive 16 gbm_grid_hist_model_8 0.9425912477254368\n", "6 Random 16 gbm_grid_hist_model_9 0.9421424549779276\n", "7 Random 32 gbm_grid_hist_model_13 0.9394864178994867\n", "8 Random 8 gbm_grid_hist_model_5 0.9377973251952249\n", "9 QuantilesGlobal 32 gbm_grid_hist_model_14 0.9370343775244592\n", "10 QuantilesGlobal 16 gbm_grid_hist_model_10 0.9347373746440257\n", "11 QuantilesGlobal 8 gbm_grid_hist_model_6 0.9315958254114614\n", "12 RoundRobin 2 gbm_grid_hist_model_3 0.9300209708611109\n", "13 Random 2 gbm_grid_hist_model_1 0.9293069823991644\n", "14 UniformAdaptive 2 gbm_grid_hist_model_0 0.9261654331666\n", "15 QuantilesGlobal 2 gbm_grid_hist_model_2 0.8943541872363342\n", "\n", "0.9450596078367374\n", "Hyperparameters: [nbins, histogram_type]\n", "[16, 'RoundRobin']\n" ] } ], "source": [ "gbm_params2 = {'nbins': [2, 8, 16, 32],\n", " 'histogram_type': ['UniformAdaptive', 'Random', 'QuantilesGlobal', 'RoundRobin']}\n", "\n", "gbm_grid2 = H2OGridSearch(model=H2OGradientBoostingEstimator(ntrees = 800, seed = 1234, \n", " max_depth=6, learn_rate=0.001),\n", " grid_id='gbm_grid_hist',\n", " hyper_params=gbm_params2)\n", "gbm_grid2.train(X, y, training_frame=training, validation_frame=validation)\n", "\n", "# модели, отсортированные по AUC\n", "gbm_gridperf2 = gbm_grid2.get_grid(sort_by='auc', decreasing=True)\n", "print(gbm_gridperf2)\n", "\n", "# выберем лучшую модель и выведем AUC на тесте \n", "best_gbm2 = gbm_gridperf2.models[0]\n", "print(best_gbm2.auc(valid=True))\n", "print(gbm_gridperf2.get_hyperparams(0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "В итоге AUC = 0.9450 на валидационной выборке получился немного выше для GradientBoostingMachine, чем для RandomForestClassifier AUC = 0.941686" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### XGBoost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Еще одним из алгоритмов, доступных в H20, из семейства Gradient Boosting Machine является XGBoost. Эта библиотека - очень популярный инструмент для решения Kaggle-задач. Он позволяет строить деервья параллельно, что позволяет решать проблемы скорости обучения. \n", "\n", "__Особенности модели__\n", "\n", "Реализация модели поддерживает особенности реализации scikit-learn и R с новыми дополнениями, такими как регуляризация. 
Поддерживаются три основные разновидности градиентного бустинга:\n", "\n", "* классический Gradient Boosting (Gradient Boosting Machine) с настраиваемой скоростью обучения;\n", "* Stochastic Gradient Boosting с сэмплированием строк и столбцов, в том числе отдельно для каждого уровня дерева;\n", "* градиентный бустинг с L1- и L2-регуляризацией.\n", "\n", "__Системные возможности__\n", "\n", "Библиотека рассчитана на работу в разных вычислительных средах и поддерживает:\n", "\n", "* параллельное построение деревьев с использованием всех ядер процессора;\n", "* распределенное обучение очень больших моделей на кластере машин;\n", "* out-of-core вычисления для наборов данных, не помещающихся в память;\n", "* оптимизацию структур данных и алгоритма под кэш процессора для наилучшего использования аппаратного обеспечения.\n", "\n", "__Особенности алгоритма__\n", "\n", "Реализация спроектирована так, чтобы эффективно расходовать время и память и максимально использовать доступные ресурсы при обучении модели. Ключевые особенности:\n", "\n", "* учет разреженности данных (sparsity-aware) с автоматической обработкой пропущенных значений;\n", "* блочная структура данных, поддерживающая распараллеливание построения дерева;\n", "* возможность дообучать уже построенную модель на новых данных. \n", "\n", "В реальности в H2O используется нативный алгоритм [XGBoost](http://xgboost.readthedocs.io/en/latest/get_started/index.html), поэтому не буду пытаться подобрать параметры, покажу только, как он запускается" ] }, { "cell_type": "code", "execution_count": 246, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "from h2o.estimators import H2OXGBoostEstimator\n", "# библиотека доступна не для всех платформ, сначала надо проверить ее доступность\n", "is_xgboost_available = H2OXGBoostEstimator.available()\n", "print(is_xgboost_available)" ] }, { "cell_type": "code", "execution_count": 249, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "xgboost Model Build progress: |███████████████████████████████████████████| 100%\n", "\n", "ModelMetricsBinomial: xgboost\n", "** Reported on train data. **\n", "\n", "MSE: 0.04233021547018061\n", "RMSE: 0.2057430812206831\n", "LogLoss: 0.1893648656423194\n", "Mean Per-Class Error: 0.0834117942209498\n", "AUC: 0.9795984475074526\n", "Gini: 0.9591968950149052\n", "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.38141971826553345: \n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
01ErrorRate
01960.033.00.0166 (33.0/1993.0)
165.0275.00.1912 (65.0/340.0)
Total2025.0308.00.042 (98.0/2333.0)
" ], "text/plain": [ " 0 1 Error Rate\n", "----- ---- --- ------- -------------\n", "0 1960 33 0.0166 (33.0/1993.0)\n", "1 65 275 0.1912 (65.0/340.0)\n", "Total 2025 308 0.042 (98.0/2333.0)" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Maximum Metrics: Maximum metrics at their respective thresholds\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
metricthresholdvalueidx
max f10.38141970.8487654163.0
max f20.14305330.8405723274.0
max f0point50.60403320.9194529118.0
max accuracy0.46445680.9597085143.0
max precision0.91704141.00.0
max recall0.10960871.0316.0
max specificity0.91704141.00.0
max absolute_mcc0.46263520.8304986144.0
max min_per_class_accuracy0.17151320.8966382251.0
max mean_per_class_accuracy0.14305330.9165882274.0
" ], "text/plain": [ "metric threshold value idx\n", "--------------------------- ----------- -------- -----\n", "max f1 0.38142 0.848765 163\n", "max f2 0.143053 0.840572 274\n", "max f0point5 0.604033 0.919453 118\n", "max accuracy 0.464457 0.959709 143\n", "max precision 0.917041 1 0\n", "max recall 0.109609 1 316\n", "max specificity 0.917041 1 0\n", "max absolute_mcc 0.462635 0.830499 144\n", "max min_per_class_accuracy 0.171513 0.896638 251\n", "max mean_per_class_accuracy 0.143053 0.916588 274" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Gains/Lift Table: Avg response rate: 14,57 %\n", "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
groupcumulative_data_fractionlower_thresholdliftcumulative_liftresponse_ratecumulative_response_ratecapture_ratecumulative_capture_rategaincumulative_gain
10.01028720.89525826.86176476.86176471.01.00.07058820.0705882586.1764706586.1764706
20.02014570.86314296.86176476.86176471.01.00.06764710.1382353586.1764706586.1764706
30.03000430.85109756.86176476.86176471.01.00.06764710.2058824586.1764706586.1764706
40.04029150.83736786.86176476.86176471.01.00.07058820.2764706586.1764706586.1764706
50.05015000.80467346.86176476.86176471.01.00.06764710.3441176586.1764706586.1764706
60.10030000.63702006.80311716.83244090.99145300.99572650.34117650.6852941580.3117144583.2440925
70.15002140.28454663.07596355.58743700.44827590.81428570.15294120.8382353207.5963489458.7436975
80.20017150.19063520.58647564.33451950.08547010.63169160.02941180.8676471-41.3524384333.4519461
90.30004290.12351011.17798543.28384450.17167380.47857140.11764710.985294117.7985357228.3844538
100.39991430.10649680.14724822.50053590.02145920.36441590.01470591.0-85.2751830150.0535906
110.50021430.09755460.01.99914310.00.29134530.01.0-100.099.9143102
120.60008570.09181060.01.66642860.00.24285710.01.0-100.066.6428571
130.69995710.08645800.01.42865890.00.20820580.01.0-100.042.8658910
140.79982850.08197980.01.25026800.00.18220790.01.0-100.025.0267953
150.89970000.07729510.01.11148170.00.16198190.01.0-100.011.1481658
161.00.06842450.01.00.00.14573510.01.0-100.00.0
" ], "text/plain": [ " group cumulative_data_fraction lower_threshold lift cumulative_lift response_rate cumulative_response_rate capture_rate cumulative_capture_rate gain cumulative_gain\n", "-- ------- -------------------------- ----------------- -------- ----------------- --------------- -------------------------- -------------- ------------------------- -------- -----------------\n", " 1 0.0102872 0.895258 6.86176 6.86176 1 1 0.0705882 0.0705882 586.176 586.176\n", " 2 0.0201457 0.863143 6.86176 6.86176 1 1 0.0676471 0.138235 586.176 586.176\n", " 3 0.0300043 0.851098 6.86176 6.86176 1 1 0.0676471 0.205882 586.176 586.176\n", " 4 0.0402915 0.837368 6.86176 6.86176 1 1 0.0705882 0.276471 586.176 586.176\n", " 5 0.05015 0.804673 6.86176 6.86176 1 1 0.0676471 0.344118 586.176 586.176\n", " 6 0.1003 0.63702 6.80312 6.83244 0.991453 0.995726 0.341176 0.685294 580.312 583.244\n", " 7 0.150021 0.284547 3.07596 5.58744 0.448276 0.814286 0.152941 0.838235 207.596 458.744\n", " 8 0.200171 0.190635 0.586476 4.33452 0.0854701 0.631692 0.0294118 0.867647 -41.3524 333.452\n", " 9 0.300043 0.12351 1.17799 3.28384 0.171674 0.478571 0.117647 0.985294 17.7985 228.384\n", " 10 0.399914 0.106497 0.147248 2.50054 0.0214592 0.364416 0.0147059 1 -85.2752 150.054\n", " 11 0.500214 0.0975546 0 1.99914 0 0.291345 0 1 -100 99.9143\n", " 12 0.600086 0.0918106 0 1.66643 0 0.242857 0 1 -100 66.6429\n", " 13 0.699957 0.086458 0 1.42866 0 0.208206 0 1 -100 42.8659\n", " 14 0.799829 0.0819798 0 1.25027 0 0.182208 0 1 -100 25.0268\n", " 15 0.8997 0.0772951 0 1.11148 0 0.161982 0 1 -100 11.1482\n", " 16 1 0.0684245 0 1 0 0.145735 0 1 -100 0" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "AUC на валидационной выборке GradientBoostingMachine: 0.9146600190940914\n", "AUC на валидационной выборке XGBoost: 0.9335786733686384\n" ] } ], "source": [ "param = {\n", " \"ntrees\" : 100\n", " , \"max_depth\" : 10\n", " , \"learn_rate\" : 0.02\n", " , \"sample_rate\" : 0.7\n", " , \"col_sample_rate_per_tree\" : 0.9\n", " , \"min_rows\" : 5\n", " , \"seed\": 4241\n", " , \"score_tree_interval\": 100\n", "}\n", "\n", "model = H2OXGBoostEstimator(**param)\n", "model.train(X, y, training_frame=training, validation_frame=validation)\n", "model.model_performance().show()\n", "\n", "print('AUC на валидационной выборке GradientBoostingMachine: ',gbm.auc(valid=True))\n", "print('AUC на валидационной выборке XGBoost: ',model.auc(valid=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### LightGBM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[LightGBM](https://github.com/Microsoft/LightGBM) строит глубокие асимметричные деревья, повторно разбивая один лист вместо разбиения всех листьев одного уровня до достижения максимальной глубины. \n", "XGBoost использует предварительную сортировку и гистограммирование для расчета наилучшего разбиения, т.е.\n", "- для каждого узла необходимо пронумеровать все признаки; \n", "- для каждого признака необходимо провести сортировку всех значений (здесь можно разбить на бины и провести сортировку бинов); \n", "- ищем налиучшее разбиение для признака; \n", "- выбираем наилучшее разбиение среди всех признаков.\n", "\n", "но LightGBM использует другой способ: градиент - это угол наклона функции потерь, таким образом, если градиент для каких-то точек больше, то эти точки важнее для поиска оптимального разбиения. 
Алгоритм находит все такие точки с максимальным градиентом и делает рандомное расщепление на точках с маленький градиентом. \n", "Предположим, есть 500K строчек данных, где у 10k строчек градиент больше, таким образом алогритм выберет 10k строчку большим градиентом + x% от отсавшихся 490k строчек, выбранных случайно. Предположим, x = 10%, общее количество выбранных строк = 59k из 500K\n", "Важное предположение здесь: ошибка на тренировочном наборе с меньшим градиентом меньше и эти данные уже хорошо обучены в модели. Таким образом мы уменьшаем с одной стороны количество данных для обучения, но при этом сохраняем качество для уже обученных деревьев.\n", "\n", "----------------\n", "H2O не интегрирован с LightGBM, но предоставляет метод для эмуляции LightGBM алгоритма, используя определенный набор параметров:\n", "\n", "\n", "tree_method=\"hist\"\n", "grow_policy=\"lossguide\"" ] }, { "cell_type": "code", "execution_count": 150, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "H2O session _sid_addf closed.\n" ] } ], "source": [ "h2o.cluster().shutdown()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "__Список литературы__\n", "\n", "* https://www.h2o.ai/ \n", "* http://statistica.ru/local-portals/actuaries/obobshchennye-lineynye-modeli-glm/ \n", "* https://en.wikipedia.org/wiki/Generalized_linear_model \n", "* https://alexanderdyakonov.files.wordpress.com/2017/06/book_boosting_pdf.pdf \n", "* https://github.com/h2oai/h2o-3/blob/master/h2o-docs/src/product/tutorials/gbm/gbmTuning.ipynb \n", "* https://ru.bmstu.wiki/XGBoost \n", "* Артем Груздев. Прогнозное моделирование в IBM SPSS Statistics, R и Python; Лекции\n", "* https://towardsdatascience.com/catboost-vs-light-gbm-vs-xgboost-5f93620723db" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }