{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "optical-forge",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from interpret.glassbox import ExplainableBoostingClassifier\n",
    "from interpret import show\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.metrics import precision_recall_fscore_support"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "solid-playing",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports for onnx conversion and inference\n",
    "import onnx\n",
    "import ebm2onnx\n",
    "import onnxruntime as rt\n",
    "import numpy as np\n",
    "import tempfile"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "offensive-soundtrack",
   "metadata": {},
   "source": [
    "# Binary classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "powerful-desktop",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "contemporary-staff",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_columns = ['Age', 'Fare', 'Pclass', 'Embarked']\n",
    "label_column = \"Survived\"\n",
    "\n",
    "# Drop only the rows that lack a value the model needs; a bare dropna() would\n",
    "# also discard every row with a gap in a column that is never used below.\n",
    "df = pd.read_csv('titanic_train.csv').dropna(subset=feature_columns + [label_column])\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "charged-connectivity",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "corporate-bleeding",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LabelEncoder expects a 1-D target, so select the label as a Series\n",
    "# rather than as a one-column DataFrame.\n",
    "y = df[label_column]\n",
    "le = LabelEncoder()\n",
    "y_enc = le.fit_transform(y)\n",
    "x = df[feature_columns]\n",
    "# Fixed seed so the split, and therefore the metrics below, are reproducible.\n",
    "x_train, x_test, y_train, y_test = train_test_split(x, y_enc, random_state=42)\n",
    "ebm = ExplainableBoostingClassifier(\n",
    "    interactions=2,\n",
    "    # one entry per column of feature_columns, in the same order\n",
    "    feature_types=['continuous', 'continuous', 'continuous', 'nominal']\n",
    ")\n",
    "ebm.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "canadian-telephone",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# A look at the generated model\n",
    "ebm_global = ebm.explain_global()\n",
    "show(ebm_global)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "beautiful-wireless",
   "metadata": {},
   "source": [
    "## Convert model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "vocational-globe",
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_model = ebm2onnx.to_onnx(\n",
    "    model=ebm,\n",
    "    dtype=ebm2onnx.get_dtype_from_pandas(x_train),\n",
    "    name=\"ebm\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "italic-authorization",
   "metadata": {},
   "source": [
    "## Predict with EBM implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "appropriate-powder",
   "metadata": {},
   "outputs": [],
   "source": [
    "ebm_pred = ebm.predict(x_test)\n",
    "pd.DataFrame(precision_recall_fscore_support(y_test, ebm_pred, average=None), index=['Precision', 'Recall', 'FScore', 'Support'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "centered-enough",
   "metadata": {},
   "source": [
    "## Predict with ONNX Runtime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "described-beatles",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reserve a temporary path through a context manager so the file handle is\n",
    "# closed before the model is written (tempfile.mkstemp hands back an open\n",
    "# descriptor that the caller must close). delete=False keeps the file on disk\n",
    "# so it can be saved to and then loaded by path.\n",
    "with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as model_file:\n",
    "    filename = model_file.name\n",
    "onnx.save_model(onnx_model, filename)\n",
    "\n",
    "sess = rt.InferenceSession(filename)\n",
    "# Feed one named input per feature column.\n",
    "onnx_pred = sess.run(None, {name: x_test[name].values for name in feature_columns})\n",
    "\n",
    "print(\"metrics of output {}:\".format(sess.get_outputs()[0].name))\n",
    "pd.DataFrame(precision_recall_fscore_support(y_test, onnx_pred[0], average=None), index=['Precision', 'Recall', 'FScore', 'Support'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}