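{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# AutoGluon and probability calibration\n", "\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Attaching the sklearn API needed for probability calibration\n", "\n", "In this series, we plot the predictions of classification models on reliability curves\n", "and apply probability calibration, which brings the predicted scores closer to the true class probabilities.\n", "Last time, we worked with a LightGBM model, which already comes with an sklearn-compatible API.\n", "This time, we take the AutoML library [AutoGluon Tabular](https://pypi.org/project/autogluon.tabular/)\n", "and walk through a concrete example of attaching the API methods that probability calibration requires.\n", "\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Preparing the libraries\n", "\n", "We use sklearn to draw reliability curves and to calibrate probabilities.\n", "Figures are drawn with Matplotlib and Plotly.\n", "The model to be calibrated is\n", "[AutoGluon Tabular](https://pypi.org/project/autogluon.tabular/).\n", "Internally, AutoGluon uses libraries such as\n", "[lightgbm](https://lightgbm.readthedocs.io/en/latest/),\n", "[catboost](https://catboost.ai/), and\n", "[xgboost](https://xgboost.readthedocs.io/en/latest/).\n", "The sample data below is loaded through `shap.datasets`.\n", "\n", "Each of these can be installed with Conda or pip." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Sample data\n", "\n", "We use the [Adult Census Income](https://archive.ics.uci.edu/ml/datasets/census+income) dataset,\n", "where the task is to classify whether a person's annual income exceeds 50,000 US dollars.\n", "The classes are already imbalanced, with relatively few positive examples,\n", "and we make the problem harder still by adding noise:\n", "Gaussian noise with standard deviation 4 is added to each of the 12 original features,\n", "and 1,200 pure-noise columns (100 times the original feature count) are appended.\n", "Finally, 75% of the rows are used for training, and the remaining 25% are split into\n", "4,000 calibration rows and a test set." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original dataframe shape (32561, 12)\n", "Noisy dataframe shape (32561, 1212)\n", "Classes [False True]\n", "Train/calibration/test split ratio [0.74997697 0.12284635 0.12717668]\n" ] } ], "source": [ "import numpy as np\n", "import shap\n", "import sklearn\n", "from sklearn.model_selection import train_test_split\n", "\n", "# Census income: X is a DataFrame with 12 features, y is a boolean label array\n", "X, y = shap.datasets.adult()\n", "X = X.values\n", "\n", "print(\"Original dataframe shape\", X.shape)\n", "n_samples, n_features = X.shape\n", "\n", "# Add noise: Gaussian noise (std 4) on the original features,\n", "# plus 100 * n_features columns of pure standard-normal noise\n", "random_state = np.random.RandomState(0)\n", "X = X + 4 * random_state.randn(n_samples, n_features)\n", "X = np.c_[X, random_state.randn(n_samples, 100 * n_features)]\n", "\n", "# Hold out 25% of the rows; the rest is used for training\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, test_size=0.25, random_state=random_state\n", ")\n", "print(\"Noisy dataframe shape\", X.shape)\n", "print(\"Classes\", np.unique(y))\n", "\n", "# Split the held-out rows into 4000 calibration rows and the remaining test rows\n", "X_test, X_calib, y_test, y_calib = train_test_split(\n", "    X_test,\n", "    y_test,\n", "    test_size=4000 / len(X_test),\n", "    random_state=random_state,\n", ")\n", "print(\"Train/calibration/test split ratio\", np.array([len(X_train), len(X_calib), len(X_test)]) / len(X))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": "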
|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 1203 | 1204 | 1205 | 1206 | 1207 | 1208 | 1209 | 1210 | 1211 | y |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 55.150679 | 0.190716 | 11.639823 | 0.984050 | 4.207856 | -0.438618 | 4.075127 | -0.177213 | 4.419674 | 2261.400319 | ... | -0.421482 | -0.492354 | 2.188931 | 1.890532 | 1.197646 | -0.915642 | -0.745367 | 1.309365 | -0.558290 | False |
| 1 | 44.324625 | 4.568088 | 14.691406 | 4.085347 | -0.986166 | 5.562246 | -0.689750 | -8.195224 | 1.908657 | -2.469216 | ... | -0.248147 | 0.188788 | 0.759042 | 1.183562 | 0.641865 | 0.095657 | 1.481117 | -1.294230 | -1.477307 | True |
| 2 | 49.161768 | 3.504227 | 10.344601 | 2.898027 | 18.864769 | 3.734059 | 5.944307 | 0.445238 | 3.814409 | 4.375051 | ... | -0.108891 | -1.597474 | -0.557484 | 0.274719 | 0.834568 | -0.155843 | 0.292481 | 0.363087 | 0.687930 | True |
| 3 | 38.375623 | 10.634179 | 14.367692 | 11.401001 | 11.490676 | -1.866726 | -0.675371 | 8.023515 | 4.178944 | -7.578039 | ... | 1.123319 | -0.363197 | -0.915543 | -0.721841 | 1.290744 | -0.400617 | -0.254988 | -1.682269 | -1.519599 | False |
| 4 | 51.941775 | 4.504722 | 8.233571 | 0.416077 | 0.989725 | -2.964122 | -0.766150 | 2.836963 | 1.554746 | -7.509606 | ... | 0.700531 | 0.465055 | 1.754451 | -0.510019 | -0.118499 | 0.867805 | -0.419116 | -0.296483 | 0.815914 | False |

5 rows × 1213 columns

|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 1203 | 1204 | 1205 | 1206 | 1207 | 1208 | 1209 | 1210 | 1211 | y |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 62.255383 | 5.831143 | 12.638373 | -4.375959 | 11.950415 | 2.638438 | -2.181352 | -1.075052 | -1.040339 | -4.046030 | ... | 0.627686 | -1.381202 | 0.235047 | 0.781438 | 0.352366 | 0.583754 | 1.224902 | 0.350085 | 0.407340 | False |
| 1 | 42.304418 | 1.507765 | 9.225236 | 7.568787 | 15.296958 | 4.998232 | 3.006109 | 0.383081 | 5.057561 | 3.583399 | ... | -1.687568 | 0.528663 | -1.178646 | 1.333547 | -0.557322 | 0.489001 | 0.372340 | 0.810345 | -0.431861 | False |
| 2 | 25.362705 | 0.099063 | 4.370707 | 17.376790 | 9.849307 | 0.636375 | 7.183748 | -0.959921 | 0.136696 | 4.054935 | ... | -0.892634 | 0.120163 | 0.872914 | 0.651086 | 0.625578 | -1.243911 | 0.635322 | 1.444986 | 0.077606 | False |
| 3 | 10.007130 | 12.341750 | 4.399064 | -1.052642 | 10.981154 | -1.703968 | 3.493526 | -0.566090 | -1.110927 | 7.241747 | ... | -1.264475 | -2.235201 | 0.915248 | 1.041231 | 0.544913 | 1.468259 | -1.761762 | 3.205364 | 0.742915 | False |
| 4 | 25.622787 | 4.829471 | 7.509705 | 3.995892 | 5.798364 | 2.085136 | 0.069773 | -0.438765 | -4.660486 | 1.304545 | ... | 0.425200 | 1.008140 | 0.618921 | 1.234493 | -0.278199 | 0.851198 | 0.133332 | -0.467102 | 1.612342 | False |

5 rows × 1213 columns