{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# AutoGluon and probability calibration\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 確率補正に必要なsklearn APIの取り付け\n", "\n", "本連載では分類モデルの予測値を信頼性曲線にプロットしたり、\n", "クラス確率に近づける確率補正について取り上げています。\n", "前回、sklearn APIが元々備わっていたLightGBMモデルを対象にしましたが、\n", "今回は、[AutoGluon Tabular](https://pypi.org/project/autogluon.tabular/) という AutoML を対象に、\n", "確率補正に必要なAPIメソッドを取り付ける\n", "具体例をご紹介します。\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ライブラリの用意\n", "\n", "sklearn を使って信頼性曲線を書いたり確率補正します。\n", "図形は Matplotlib と Plotly で作ります。\n", "補正対象のモデルとしては \n", "[AutoGluon Tabular](https://pypi.org/project/autogluon.tabular/)\n", "を使います。\n", "AutoGluon内部では \n", "[lightgbm](https://lightgbm.readthedocs.io/en/latest/),\n", "[catboost](https://catboost.ai/),\n", "[xgboost](https://xgboost.readthedocs.io/en/latest/)\n", "などが使われます。\n", "\n", "Condaやpipでそれぞれインストールできます。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## サンプルデータ\n", "\n", "[Adult Census Income (国勢調査の成人収入)](https://archive.ics.uci.edu/ml/datasets/census+income)\n", "を使います。5万ドル以上の年収があるかどうかを分類するデータセットです。\n", "このデータは元々正例の割合が少なく、偏っていますが、\n", "更にノイズを投入しました。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original dataframe shape (32561, 12)\n", "Noisy dataframe shape (32561, 1212)\n", "Classes [False True]\n", "学習・補正・テスト用データ比率 [0.74997697 0.12284635 0.12717668]\n" ] } ], "source": [ "import numpy as np\n", "import shap\n", "import sklearn\n", "from sklearn.model_selection import train_test_split\n", "\n", "## Census income\n", "X, y = shap.datasets.adult()\n", "X = X.values\n", "\n", "print(\"Original dataframe shape\", X.shape)\n", "n_samples, n_features = X.shape\n", "\n", "# Add noise\n", "random_state = np.random.RandomState(0)\n", "X = X + 4 * random_state.randn(n_samples, n_features)\n", "X = np.c_[X, 
random_state.randn(n_samples, 100 * n_features)]\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.25, random_state=random_state\n", ")\n", "print(\"Noisy dataframe shape\", X.shape)\n", "print(\"Classes\", np.unique(y))\n", "\n", "X_test, X_calib, y_test, y_calib = train_test_split(\n", " X_test,\n", " y_test,\n", " test_size=4000 / len(X_test),\n", " random_state=random_state,\n", ")\n", "print(\"学習・補正・テスト用データ比率\", np.array([len(X_train), len(X_calib), len(X_test)]) / len(X))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": "
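<div>\n<table border=\"1\" class=\"dataframe\">\n<thead>\n<tr><th></th><th>0</th><th>1</th><th>2</th><th>3</th><th>4</th><th>5</th><th>6</th><th>7</th><th>8</th><th>9</th><th>...</th><th>1203</th><th>1204</th><th>1205</th><th>1206</th><th>1207</th><th>1208</th><th>1209</th><th>1210</th><th>1211</th><th>y</th></tr>\n</thead>\n<tbody>\n<tr><th>0</th><td>55.150679</td><td>0.190716</td><td>11.639823</td><td>0.984050</td><td>4.207856</td><td>-0.438618</td><td>4.075127</td><td>-0.177213</td><td>4.419674</td><td>2261.400319</td><td>...</td><td>-0.421482</td><td>-0.492354</td><td>2.188931</td><td>1.890532</td><td>1.197646</td><td>-0.915642</td><td>-0.745367</td><td>1.309365</td><td>-0.558290</td><td>False</td></tr>\n<tr><th>1</th><td>44.324625</td><td>4.568088</td><td>14.691406</td><td>4.085347</td><td>-0.986166</td><td>5.562246</td><td>-0.689750</td><td>-8.195224</td><td>1.908657</td><td>-2.469216</td><td>...</td><td>-0.248147</td><td>0.188788</td><td>0.759042</td><td>1.183562</td><td>0.641865</td><td>0.095657</td><td>1.481117</td><td>-1.294230</td><td>-1.477307</td><td>True</td></tr>\n<tr><th>2</th><td>49.161768</td><td>3.504227</td><td>10.344601</td><td>2.898027</td><td>18.864769</td><td>3.734059</td><td>5.944307</td><td>0.445238</td><td>3.814409</td><td>4.375051</td><td>...</td><td>-0.108891</td><td>-1.597474</td><td>-0.557484</td><td>0.274719</td><td>0.834568</td><td>-0.155843</td><td>0.292481</td><td>0.363087</td><td>0.687930</td><td>True</td></tr>\n<tr><th>3</th><td>38.375623</td><td>10.634179</td><td>14.367692</td><td>11.401001</td><td>11.490676</td><td>-1.866726</td><td>-0.675371</td><td>8.023515</td><td>4.178944</td><td>-7.578039</td><td>...</td><td>1.123319</td><td>-0.363197</td><td>-0.915543</td><td>-0.721841</td><td>1.290744</td><td>-0.400617</td><td>-0.254988</td><td>-1.682269</td><td>-1.519599</td><td>False</td></tr>\n<tr><th>4</th><td>51.941775</td><td>4.504722</td><td>8.233571</td><td>0.416077</td><td>0.989725</td><td>-2.964122</td><td>-0.766150</td><td>2.836963</td><td>1.554746</td><td>-7.509606</td><td>...</td><td>0.700531</td><td>0.465055</td><td>1.754451</td><td>-0.510019</td><td>-0.118499</td><td>0.867805</td><td>-0.419116</td><td>-0.296483</td><td>0.815914</td><td>False</td></tr>\n</tbody>\n</table>\n<p>5 rows × 1213 columns</p>\n</div>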
", "text/plain": " 0 1 2 3 4 5 6 \\\n0 55.150679 0.190716 11.639823 0.984050 4.207856 -0.438618 4.075127 \n1 44.324625 4.568088 14.691406 4.085347 -0.986166 5.562246 -0.689750 \n2 49.161768 3.504227 10.344601 2.898027 18.864769 3.734059 5.944307 \n3 38.375623 10.634179 14.367692 11.401001 11.490676 -1.866726 -0.675371 \n4 51.941775 4.504722 8.233571 0.416077 0.989725 -2.964122 -0.766150 \n\n 7 8 9 ... 1203 1204 1205 \\\n0 -0.177213 4.419674 2261.400319 ... -0.421482 -0.492354 2.188931 \n1 -8.195224 1.908657 -2.469216 ... -0.248147 0.188788 0.759042 \n2 0.445238 3.814409 4.375051 ... -0.108891 -1.597474 -0.557484 \n3 8.023515 4.178944 -7.578039 ... 1.123319 -0.363197 -0.915543 \n4 2.836963 1.554746 -7.509606 ... 0.700531 0.465055 1.754451 \n\n 1206 1207 1208 1209 1210 1211 y \n0 1.890532 1.197646 -0.915642 -0.745367 1.309365 -0.558290 False \n1 1.183562 0.641865 0.095657 1.481117 -1.294230 -1.477307 True \n2 0.274719 0.834568 -0.155843 0.292481 0.363087 0.687930 True \n3 -0.721841 1.290744 -0.400617 -0.254988 -1.682269 -1.519599 False \n4 -0.510019 -0.118499 0.867805 -0.419116 -0.296483 0.815914 False \n\n[5 rows x 1213 columns]" }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from autogluon.tabular import TabularDataset\n", "\n", "tr_data = TabularDataset(X_train)\n", "tr_data[\"y\"] = y_train\n", "tr_data.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": "
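<div>\n<table border=\"1\" class=\"dataframe\">\n<thead>\n<tr><th></th><th>0</th><th>1</th><th>2</th><th>3</th><th>4</th><th>5</th><th>6</th><th>7</th><th>8</th><th>9</th><th>...</th><th>1203</th><th>1204</th><th>1205</th><th>1206</th><th>1207</th><th>1208</th><th>1209</th><th>1210</th><th>1211</th><th>y</th></tr>\n</thead>\n<tbody>\n<tr><th>0</th><td>62.255383</td><td>5.831143</td><td>12.638373</td><td>-4.375959</td><td>11.950415</td><td>2.638438</td><td>-2.181352</td><td>-1.075052</td><td>-1.040339</td><td>-4.046030</td><td>...</td><td>0.627686</td><td>-1.381202</td><td>0.235047</td><td>0.781438</td><td>0.352366</td><td>0.583754</td><td>1.224902</td><td>0.350085</td><td>0.407340</td><td>False</td></tr>\n<tr><th>1</th><td>42.304418</td><td>1.507765</td><td>9.225236</td><td>7.568787</td><td>15.296958</td><td>4.998232</td><td>3.006109</td><td>0.383081</td><td>5.057561</td><td>3.583399</td><td>...</td><td>-1.687568</td><td>0.528663</td><td>-1.178646</td><td>1.333547</td><td>-0.557322</td><td>0.489001</td><td>0.372340</td><td>0.810345</td><td>-0.431861</td><td>False</td></tr>\n<tr><th>2</th><td>25.362705</td><td>0.099063</td><td>4.370707</td><td>17.376790</td><td>9.849307</td><td>0.636375</td><td>7.183748</td><td>-0.959921</td><td>0.136696</td><td>4.054935</td><td>...</td><td>-0.892634</td><td>0.120163</td><td>0.872914</td><td>0.651086</td><td>0.625578</td><td>-1.243911</td><td>0.635322</td><td>1.444986</td><td>0.077606</td><td>False</td></tr>\n<tr><th>3</th><td>10.007130</td><td>12.341750</td><td>4.399064</td><td>-1.052642</td><td>10.981154</td><td>-1.703968</td><td>3.493526</td><td>-0.566090</td><td>-1.110927</td><td>7.241747</td><td>...</td><td>-1.264475</td><td>-2.235201</td><td>0.915248</td><td>1.041231</td><td>0.544913</td><td>1.468259</td><td>-1.761762</td><td>3.205364</td><td>0.742915</td><td>False</td></tr>\n<tr><th>4</th><td>25.622787</td><td>4.829471</td><td>7.509705</td><td>3.995892</td><td>5.798364</td><td>2.085136</td><td>0.069773</td><td>-0.438765</td><td>-4.660486</td><td>1.304545</td><td>...</td><td>0.425200</td><td>1.008140</td><td>0.618921</td><td>1.234493</td><td>-0.278199</td><td>0.851198</td><td>0.133332</td><td>-0.467102</td><td>1.612342</td><td>False</td></tr>\n</tbody>\n</table>\n<p>5 rows × 1213 columns</p>\n</div>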
", "text/plain": " 0 1 2 3 4 5 6 \\\n0 62.255383 5.831143 12.638373 -4.375959 11.950415 2.638438 -2.181352 \n1 42.304418 1.507765 9.225236 7.568787 15.296958 4.998232 3.006109 \n2 25.362705 0.099063 4.370707 17.376790 9.849307 0.636375 7.183748 \n3 10.007130 12.341750 4.399064 -1.052642 10.981154 -1.703968 3.493526 \n4 25.622787 4.829471 7.509705 3.995892 5.798364 2.085136 0.069773 \n\n 7 8 9 ... 1203 1204 1205 1206 \\\n0 -1.075052 -1.040339 -4.046030 ... 0.627686 -1.381202 0.235047 0.781438 \n1 0.383081 5.057561 3.583399 ... -1.687568 0.528663 -1.178646 1.333547 \n2 -0.959921 0.136696 4.054935 ... -0.892634 0.120163 0.872914 0.651086 \n3 -0.566090 -1.110927 7.241747 ... -1.264475 -2.235201 0.915248 1.041231 \n4 -0.438765 -4.660486 1.304545 ... 0.425200 1.008140 0.618921 1.234493 \n\n 1207 1208 1209 1210 1211 y \n0 0.352366 0.583754 1.224902 0.350085 0.407340 False \n1 -0.557322 0.489001 0.372340 0.810345 -0.431861 False \n2 0.625578 -1.243911 0.635322 1.444986 0.077606 False \n3 0.544913 1.468259 -1.761762 3.205364 0.742915 False \n4 -0.278199 0.851198 0.133332 -0.467102 1.612342 False \n\n[5 rows x 1213 columns]" }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te_data = TabularDataset(X_test)\n", "te_data[\"y\"] = y_test\n", "te_data.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": "count 24420\nunique 2\ntop False\nfreq 18577\nName: y, dtype: object" }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tr_data.y.describe()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": "count 4141\nunique 2\ntop False\nfreq 3098\nName: y, dtype: object" }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te_data.y.describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model training" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "We fit an AutoGluon model to the training data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Warning: path already exists! This predictor may overwrite an existing predictor! path=\"trained-model\"\n", "Beginning AutoGluon training ... Time limit = 30s\n", "AutoGluon will save models to \"trained-model/\"\n", "AutoGluon Version: 0.1.0\n", "Train Data Rows: 24420\n", "Train Data Columns: 1212\n", "Preprocessing data ...\n", "AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).\n", "\t2 unique label values: [False, True]\n", "\tIf 'binary' is not the correct problem_type, please manually specify the problem_type argument in fit() (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n", "Selected class <--> label mapping: class 1 = True, class 0 = False\n", "Using Feature Generators to preprocess the data ...\n", "Fitting AutoMLPipelineFeatureGenerator...\n", "\tAvailable Memory: 6459.76 MB\n", "\tTrain Data (Original) Memory Usage: 236.78 MB (3.7% of available memory)\n", "\tInferring data type of each feature based on column values. 
Set feature_metadata_in to manually specify special dtypes of the features.\n", "\tStage 1 Generators:\n", "\t\tFitting AsTypeFeatureGenerator...\n", "\tStage 2 Generators:\n", "\t\tFitting FillNaFeatureGenerator...\n", "\tStage 3 Generators:\n", "\t\tFitting IdentityFeatureGenerator...\n", "\tStage 4 Generators:\n", "\t\tFitting DropUniqueFeatureGenerator...\n", "\tTypes of features in original data (raw dtype, special dtypes):\n", "\t\t('float', []) : 1212 | ['0', '1', '2', '3', '4', ...]\n", "\tTypes of features in processed data (raw dtype, special dtypes):\n", "\t\t('float', []) : 1212 | ['0', '1', '2', '3', '4', ...]\n", "\t4.5s = Fit runtime\n", "\t1212 features in original data used to generate 1212 features in processed data.\n", "\tTrain Data (Processed) Memory Usage: 236.78 MB (3.4% of available memory)\n", "Data preprocessing and feature engineering runtime = 5.39s ...\n", "AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n", "\tTo change this, specify the eval_metric argument of fit()\n", "Automatically generating train/validation split with holdout_frac=0.1, Train Rows: 21978, Val Rows: 2442\n", "Fitting model: LightGBM ... Training model for up to 24.61s of the 24.6s of remaining time.\n", "\t0.8084\t = Validation accuracy score\n", "\t3.28s\t = Training runtime\n", "\t0.05s\t = Validation runtime\n", "Fitting model: CatBoost ... Training model for up to 21.28s of the 21.27s of remaining time.\n", "\t0.8157\t = Validation accuracy score\n", "\t1.74s\t = Training runtime\n", "\t0.02s\t = Validation runtime\n", "Fitting model: XGBoost ... Training model for up to 19.51s of the 19.5s of remaining time.\n", "ntree_limit is deprecated, use `iteration_range` or model slicing instead.\n", "\t0.8215\t = Validation accuracy score\n", "\t16.77s\t = Training runtime\n", "\t0.24s\t = Validation runtime\n", "Fitting model: NeuralNetMXNet ... Training model for up to 2.47s of the 2.46s of remaining time.\n", "\tTime limit exceeded... 
Skipping NeuralNetMXNet.\n", "ntree_limit is deprecated, use `iteration_range` or model slicing instead.\n", "Fitting model: WeightedEnsemble_L2 ... Training model for up to 24.61s of the -5.42s of remaining time.\n", "\t0.8219\t = Validation accuracy score\n", "\t0.35s\t = Training runtime\n", "\t0.0s\t = Validation runtime\n", "AutoGluon training complete, total runtime = 36.23s ...\n", "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"trained-model/\")\n" ] } ], "source": [ "from autogluon.tabular import TabularPredictor\n", "\n", "save_path = \"trained-model\"\n", "predictor = TabularPredictor(label=\"y\", path=save_path).fit(\n", "    tr_data, hyperparameters=\"toy\", time_limit=30\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sklearn wrapper\n", "\n", "To calibrate probabilities with sklearn, we provide the interface it requires." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class below exposes only what sklearn needs here: the `classes_` attribute and the `fit` and `predict_proba` methods." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastcore.basics import store_attr\n", "\n", "\n", "class AutoGluonWrapper:\n", "    \"\"\"\n", "    Provides the interface that sklearn needs in order to\n", "    draw reliability curves and calibrate probabilities.\n", "    \"\"\"\n", "\n", "    def __init__(\n", "        self,\n", "        trained_model_path,  # path where the trained AutoGluon model is saved\n", "        classes_,  # attribute required by the sklearn API\n", "    ):\n", "        store_attr()\n", "\n", "    def load_model(self):\n", "        \"\"\" Load the trained AutoGluon model \"\"\"\n", "        self.ag_model = TabularPredictor.load(self.trained_model_path)\n", "\n", "    def fit(self):\n", "        \"\"\" Method required by the sklearn API \"\"\"\n", "        return True\n", "\n", "    def predict_proba(self, X):\n", "        \"\"\" Method required by the sklearn API \"\"\"\n", "        X = TabularDataset(X)\n", "        proba = self.ag_model.predict_proba(X)\n", "        return proba.values" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ag_ = AutoGluonWrapper(save_path, classes_=np.unique(y))\n", "ag_.load_model()" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Reliability curve" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We measure the AUC and Brier score on the test data and draw the reliability curve.\n", "We use the `plot_calibration_curve` function defined in the previous installment." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from kowaza.proba_calib import plot_calibration_curve" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ntree_limit is deprecated, use `iteration_range` or model slicing instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "AutoGluon:\n", "\tAUC : 0.824\n", "\tBrier: 0.153\n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": "
" }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_calibration_curve(dict(AutoGluon=ag_), X_test, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Performing probability calibration\n", "\n", "For calibration we use [sklearn.calibration.CalibratedClassifierCV](https://scikit-learn.org/stable/modules/generated/sklearn.calibration.CalibratedClassifierCV.html#sklearn.calibration.CalibratedClassifierCV). Specifying `cv = \"prefit\"` tells it that the base model is already trained and that all of the data passed to the calibrator is to be used for calibration.\n", "\n", "This time, let's try both `sigmoid` and `isotonic` as the calibration method." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ntree_limit is deprecated, use `iteration_range` or model slicing instead.\n" ] }, { "data": { "text/plain": "CalibratedClassifierCV(base_estimator=<__main__.AutoGluonWrapper object at 0x7fbb9497d7c0>,\n cv='prefit')" }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.calibration import CalibratedClassifierCV, calibration_curve\n", "\n", "sigmoid = CalibratedClassifierCV(ag_, cv=\"prefit\", method=\"sigmoid\")\n", "sigmoid.fit(X_calib, y_calib)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ntree_limit is deprecated, use `iteration_range` or model slicing instead.\n" ] }, { "data": { "text/plain": "CalibratedClassifierCV(base_estimator=<__main__.AutoGluonWrapper object at 0x7fbb9497d7c0>,\n cv='prefit', method='isotonic')" }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "isotonic = CalibratedClassifierCV(ag_, cv=\"prefit\", method=\"isotonic\")\n", "isotonic.fit(X_calib, y_calib)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's draw the reliability curves of the model before and after calibration in the same figure and compare them." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": 
"stderr", "output_type": "stream", "text": [ "ntree_limit is deprecated, use `iteration_range` or model slicing instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "AutoGluon:\n", "\tAUC : 0.824\n", "\tBrier: 0.153\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "ntree_limit is deprecated, use `iteration_range` or model slicing instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Sigmoid:\n", "\tAUC : 0.824\n", "\tBrier: 0.131\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "ntree_limit is deprecated, use `iteration_range` or model slicing instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Isotonic:\n", "\tAUC : 0.824\n", "\tBrier: 0.131\n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": "
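" }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plot_calibration_curve(\n", "    dict(\n", "        AutoGluon=ag_,\n", "        Sigmoid=sigmoid,\n", "        Isotonic=isotonic,\n", "    ),\n", "    X_test,\n", "    y_test,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calibration improved the reliability curve, and the Brier score improved as well, from 0.153 to 0.131, while the AUC stayed at 0.824.\n", "Admittedly, if AutoGluon is trained on this dataset without the injected noise, it yields a clean reliability curve even without calibration.\n", "In production, however, data always contains noise and its distribution drifts gradually over time, so it is necessary to plot the reliability curve and calibrate when needed.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "In the first installment, because a classifier's outputs are not necessarily class probabilities, we checked the reliability curve and tried probability calibration.\n", "This time, we ran a calibration experiment on a trained model that lacks the interface sklearn prescribes.\n", "Next time, I hope to write about the theory behind probability calibration.\n" ] } ], "metadata": { "kernelspec": { "display_name": "kowaza", "language": "python", "name": "kowaza" } }, "nbformat": 4, "nbformat_minor": 4 }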