{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Probability calibration\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Probability calibration for classification models\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp proba_calib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Motivation\n", "\n", "A classification model trained on a task such as detecting spam email or credit card fraud\n", "outputs prediction scores that usually fall in the range (0, 1).\n", "Because these scores are also called predicted probabilities,\n", "it is easy to mistake them for the true probability of the positive class.\n", "\n", "In production, a common approach is to set a threshold and make a decision based on whether the score exceeds it.\n", "What that threshold means depends heavily on whether the scores are actual class probabilities.\n", "\n", "For a spam classifier that has learned the true class probabilities,\n", "we expect about 90 of every 100 emails scored at 90%\n", "to actually be spam.\n", "If fewer than 90 are spam, the model is overconfident;\n", "if more than 90 are spam, it is underconfident.\n", "\n", "Let's measure how far a model's scores deviate from the true probabilities\n", "and experiment with calibration methods that bring the scores closer to them.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# exporti\n", "import matplotlib.pyplot as plt\n", "import japanize_matplotlib\n", "from sklearn.calibration import CalibratedClassifierCV, calibration_curve\n", "from sklearn.metrics import brier_score_loss, roc_auc_score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Libraries\n", "\n", "We use sklearn to draw reliability curves and to calibrate probabilities.\n", "LightGBM serves as the uncalibrated base model.\n", "Figures are drawn with Matplotlib and Plotly.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import plotly.express as px\n", "from lightgbm import LGBMClassifier\n", "from sklearn import datasets\n", "from sklearn.datasets import make_circles\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sample data\n", "\n", "As a simple experimental dataset, we use\n", "[sklearn](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html#sklearn.datasets.make_circles)\n", "to define a circle, treating points inside it as positive examples and points outside it as negative examples.\n", 
"ノイズを投入したり、正例の割合を減らしてデータを不均衡化しています。\n", "\n", "実験データを学習用・補正用・試験用に分けました。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ノイズ率(ラベルが反転) 0.08449600000000002\n", "正例率 0.20687921246278282\n", "学習・補正・テスト用データ比率 [0.74365384 0.00634497 0.25000119]\n" ] } ], "source": [ "n_samples = 1_000_000\n", "pos_rate = 0.1\n", "radius = 0.3\n", "\n", "X, y = make_circles(n_samples=n_samples, factor=radius, noise=0.2)\n", "df = pd.DataFrame(X, columns=[\"x1\", \"x2\"])\n", "df[\"y\"] = pd.Series(y).astype(\"category\")\n", "df[\"r\"] = df.x1 ** 2 + df.x2 ** 2\n", "print(\"ノイズ率(ラベルが反転)\", 1 - sum((df.r <= radius) == (df.y == 1)) / n_samples)\n", "\n", "# ランダムに正例をドロップし、インバランスクラス化\n", "df = pd.concat(\n", " [\n", " df[(df.y == 0) | (df.r >= radius)],\n", " df[(df.y == 1) & (df.r < radius)].sample(int(pos_rate * sum(df.y))),\n", " ]\n", ")\n", "X = df[[\"x1\", \"x2\"]]\n", "y = df.y\n", "\n", "print(\"正例率\", sum(y) / len(y))\n", "\n", "# 学習・補正・テストデータ分割\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.25, random_state=42\n", ")\n", "X_train_, X_calib, y_train_, y_calib = train_test_split(\n", " X_train,\n", " y_train,\n", " test_size=4000 / len(X_train),\n", " random_state=2020,\n", ")\n", "print(\"学習・補正・テスト用データ比率\", np.array([len(X_train_), len(X_calib), len(X_test)]) / len(X))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# exports\n", "def plot_calibration_curve(named_classifiers, X_test, y_test):\n", " fig = plt.figure(figsize=(10, 10))\n", " ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)\n", " ax2 = plt.subplot2grid((3, 1), (2, 0))\n", "\n", " ax1.plot([0, 1], [0, 1], \"k:\", label=\"完全な補正\")\n", " for name, clf in named_classifiers.items():\n", " prob_pos = clf.predict_proba(X_test)[:, 1]\n", " auc = roc_auc_score(y_test, prob_pos)\n", " brier = brier_score_loss(y_test, prob_pos)\n", " print(\"%s:\" % 
name)\n", " print(\"\\tAUC : %1.3f\" % auc)\n", " print(\"\\tBrier: %1.3f\" % (brier))\n", " print()\n", "\n", " fraction_of_positives, mean_predicted_value = calibration_curve(\n", " y_test,\n", " prob_pos,\n", " n_bins=10,\n", " )\n", "\n", " ax1.plot(\n", " mean_predicted_value,\n", " fraction_of_positives,\n", " \"s-\",\n", " label=\"%s (%1.3f)\" % (name, brier),\n", " )\n", "\n", " ax2.hist(prob_pos, range=(0, 1), bins=10, label=name, histtype=\"step\", lw=2)\n", "\n", " ax1.set_ylabel(\"正例の比率\")\n", " ax1.set_ylim([-0.05, 1.05])\n", " ax1.legend(loc=\"lower right\")\n", " ax1.set_title(\"信頼性曲線\")\n", "\n", " ax2.set_xlabel(\"予測値の平均\")\n", " ax2.set_ylabel(\"サンプル数\")\n", " ax2.legend(loc=\"upper center\", ncol=2)\n", "\n", " plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "