{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import panel as pn\n",
    "\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.metrics import accuracy_score\n",
    "from xgboost import XGBClassifier\n",
    "\n",
    "pn.extension(sizing_mode=\"stretch_width\", template=\"fast\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example was adapted from an example by [Bojan Tunguz](https://twitter.com/tunguz). It shows how to build a simple ML demo that illustrates how hyper-parameters affect the accuracy of XGBoost's `XGBClassifier`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pn.state.template.param.update(title=\"XGBoost Example\")\n",
    "\n",
    "# `load_iris(as_frame=True)` returns a scikit-learn Bunch, not a DataFrame:\n",
    "# `.data` is the feature DataFrame and `.target` holds the class labels.\n",
    "iris_df = load_iris(as_frame=True)\n",
    "\n",
    "n_trees = pn.widgets.IntSlider(start=2, end=30, name=\"Number of trees\")\n",
    "max_depth = pn.widgets.IntSlider(start=1, end=10, value=2, name=\"Maximum Depth\")\n",
    "booster = pn.widgets.Select(options=['gbtree', 'gblinear', 'dart'], name=\"Booster\")\n",
    "\n",
    "\n",
    "def pipeline(n_trees, max_depth, booster):\n",
    "    \"\"\"Fit an XGBClassifier on the Iris data and report its training accuracy.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_trees : int\n",
    "        Passed to ``XGBClassifier`` as ``n_estimators``.\n",
    "    max_depth : int\n",
    "        Passed to ``XGBClassifier`` as ``max_depth``.\n",
    "        NOTE(review): presumably only meaningful for the tree-based\n",
    "        boosters; confirm how 'gblinear' treats it.\n",
    "    booster : str\n",
    "        Passed to ``XGBClassifier`` as ``booster``; the Select widget above\n",
    "        offers 'gbtree', 'gblinear' and 'dart'.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pn.indicators.Number\n",
    "        Indicator showing the accuracy in percent, rounded to one decimal.\n",
    "    \"\"\"\n",
    "    model = XGBClassifier(max_depth=max_depth, n_estimators=n_trees, booster=booster)\n",
    "    model.fit(iris_df.data, iris_df.target)\n",
    "    # The model is scored on the same rows it was fitted on, so this is a\n",
    "    # training accuracy, not a held-out test score; the label says so.\n",
    "    accuracy = round(accuracy_score(iris_df.target, model.predict(iris_df.data)) * 100, 1)\n",
    "    return pn.indicators.Number(\n",
    "        name=\"Training accuracy\",\n",
    "        value=accuracy,\n",
    "        format=\"{value}%\",\n",
    "        colors=[(97.5, \"red\"), (99.0, \"orange\"), (100, \"green\")],\n",
    "        align='center'\n",
    "    )\n",
    "\n",
    "\n",
    "pn.Row(\n",
    "    pn.Column(booster, n_trees, max_depth, width=320).servable(area='sidebar'),\n",
    "    pn.Column(\n",
    "        \"Simple example of training an XGBoost classification model on the small Iris dataset.\",\n",
    "        iris_df.data.head(),\n",
    "        \"Adjust the hyperparameters to re-run the XGBoost classifier. The training accuracy score will adjust accordingly:\",\n",
    "        pn.bind(pipeline, n_trees, max_depth, booster)\n",
    "    ).servable(),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}