{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Release Highlights for scikit-learn 1.8\n\n.. currentmodule:: sklearn\n\nWe are pleased to announce the release of scikit-learn 1.8! Many bug fixes\nand improvements were added, as well as some key new features. Below we\ndetail the highlights of this release. **For an exhaustive list of\nall the changes**, please refer to the `release notes `.\n\nTo install the latest version (with pip)::\n\n pip install --upgrade scikit-learn\n\nor with conda::\n\n conda install -c conda-forge scikit-learn\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Array API support (enables GPU computations)\nThe progressive adoption of the Python array API standard in\nscikit-learn means that PyTorch and CuPy input arrays\nare used directly. This means that in scikit-learn estimators\nand functions non-CPU devices, such as GPUs, can be used\nto perform the computation. As a result performance is improved\nand integration with these libraries is easier.\n\nIn scikit-learn 1.8, several estimators and functions have been updated to\nsupport array API compatible inputs, for example PyTorch tensors and CuPy\narrays.\n\nArray API support was added to the following estimators:\n:class:`preprocessing.StandardScaler`,\n:class:`preprocessing.PolynomialFeatures`, :class:`linear_model.RidgeCV`,\n:class:`linear_model.RidgeClassifierCV`, :class:`mixture.GaussianMixture` and\n:class:`calibration.CalibratedClassifierCV`.\n\nArray API support was also added to several metrics in :mod:`sklearn.metrics`\nmodule, see `array_api_supported` for more details.\n\nPlease refer to the `array API support` page for instructions\nto use scikit-learn with array API compatible libraries such as PyTorch or CuPy.\nNote: Array API support is experimental and must be explicitly enabled both\nin SciPy and scikit-learn.\n\nHere is an excerpt of using a feature engineering preprocessor on the CPU,\nfollowed by :class:`calibration.CalibratedClassifierCV`\nand :class:`linear_model.RidgeCV` together on a GPU with the help of PyTorch:\n\n```python\nridge_pipeline_gpu = make_pipeline(\n # Ensure that all features (including categorical features) are preprocessed\n # on the CPU and mapped to a numerical representation.\n feature_preprocessor,\n # Move the results to the GPU and perform computations there\n FunctionTransformer(\n lambda x: torch.tensor(x.to_numpy().astype(np.float32), device=\"cuda\"))\n ,\n CalibratedClassifierCV(\n RidgeClassifierCV(alphas=alphas), method=\"temperature\"\n ),\n)\nwith sklearn.config_context(array_api_dispatch=True):\n cv_results = cross_validate(ridge_pipeline_gpu, features, target)\n```\nSee the [full notebook on Google Colab](https://colab.research.google.com/drive/1ztH8gUPv31hSjEeR_8pw20qShTwViGRx?usp=sharing)\nfor more details. On this particular example, using the Colab GPU vs using a\nsingle CPU core leads to a 10x speedup which is quite typical for such workloads.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Free-threaded CPython 3.14 support\n\nscikit-learn has support for free-threaded CPython, in particular\nfree-threaded wheels are available for all of our supported platforms on Python\n3.14.\n\nWe would be very interested by user feedback. 
Here are a few things you can\ntry:\n\n- install free-threaded CPython 3.14, run your favourite\n  scikit-learn script and check that nothing breaks unexpectedly.\n  Note that CPython 3.14 (rather than 3.13) is strongly advised because a\n  number of free-threading bugs have been fixed since CPython 3.13.\n- if you use estimators with an `n_jobs` parameter, try changing the\n  default backend to threading with `joblib.parallel_config` as in the\n  snippet below. This could potentially speed up your code because the\n  default joblib backend is process-based and incurs more overhead than\n  threads.\n\n```python\ngrid_search = GridSearchCV(clf, param_grid=param_grid, n_jobs=4)\nwith joblib.parallel_config(backend=\"threading\"):\n    grid_search.fit(X, y)\n```\n- don't hesitate to report any issue or unexpected performance behaviour by\n  opening a [GitHub issue](https://github.com/scikit-learn/scikit-learn/issues/new/choose)!\n\nFree-threaded (also known as nogil) CPython is a version of CPython that aims\nto enable efficient multi-threaded use cases by removing the Global\nInterpreter Lock (GIL).\n\nFor more details about free-threaded CPython, see the\n[py-free-threading documentation](https://py-free-threading.github.io),\nin particular [how to install a free-threaded CPython](https://py-free-threading.github.io/installing-cpython/)\nand the [ecosystem compatibility tracking page](https://py-free-threading.github.io/tracking/).\n\nIn scikit-learn, one hope for free-threaded Python is to leverage multi-core\nCPUs more efficiently by using thread workers instead of subprocess workers\nfor parallel computation when passing `n_jobs>1` in functions or estimators.\nEfficiency gains are expected from removing the need for inter-process\ncommunication. Be aware that switching the default joblib backend and testing\nthat everything works well with free-threaded Python is an ongoing long-term\neffort.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Temperature scaling in `CalibratedClassifierCV`\nProbability calibration of classifiers with temperature scaling is available in\n:class:`calibration.CalibratedClassifierCV` by setting `method=\"temperature\"`.\nThis method is particularly well suited for multiclass problems because it\nprovides better-calibrated probabilities with a single free parameter.
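Concretely, temperature scaling learns a single scalar temperature $T > 0$ and\nrescales the model's decision scores (logits) $z$ before the softmax:\n\n$$p_k = \\frac{\\exp(z_k / T)}{\\sum_j \\exp(z_j / T)}$$\n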
This is in\ncontrast to all the other available calibration methods, which use a\n\"One-vs-Rest\" scheme that adds more parameters for each class.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.calibration import CalibratedClassifierCV\nfrom sklearn.datasets import make_classification\nfrom sklearn.naive_bayes import GaussianNB\n\nX, y = make_classification(n_classes=3, n_informative=8, random_state=42)\nclf = GaussianNB().fit(X, y)\nsig = CalibratedClassifierCV(clf, method=\"sigmoid\", ensemble=False).fit(X, y)\nts = CalibratedClassifierCV(clf, method=\"temperature\", ensemble=False).fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following example shows that temperature scaling can produce better\ncalibrated probabilities than sigmoid calibration in a multi-class\nclassification problem with 3 classes.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n\nfrom sklearn.calibration import CalibrationDisplay\n\nfig, axes = plt.subplots(\n    figsize=(8, 4.5),\n    ncols=3,\n    sharey=True,\n)\nfor i, c in enumerate(ts.classes_):\n    CalibrationDisplay.from_predictions(\n        y == c, clf.predict_proba(X)[:, i], name=\"Uncalibrated\", ax=axes[i], marker=\"s\"\n    )\n    CalibrationDisplay.from_predictions(\n        y == c,\n        ts.predict_proba(X)[:, i],\n        name=\"Temperature scaling\",\n        ax=axes[i],\n        marker=\"o\",\n    )\n    CalibrationDisplay.from_predictions(\n        y == c, sig.predict_proba(X)[:, i], name=\"Sigmoid\", ax=axes[i], marker=\"v\"\n    )\n    axes[i].set_title(f\"Class {c}\")\n    axes[i].set_xlabel(None)\n    axes[i].set_ylabel(None)\n    axes[i].get_legend().remove()\nfig.suptitle(\"Reliability Diagrams per Class\")\nfig.supxlabel(\"Mean Predicted Probability\")\nfig.supylabel(\"Fraction of Class\")\nfig.legend(*axes[0].get_legend_handles_labels(), loc=(0.72, 0.5))\nplt.subplots_adjust(right=0.7)\n_ = fig.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Efficiency improvements in linear models\nThe fit time has been massively reduced for squared-error-based estimators\nwith an L1 penalty: `ElasticNet`, `Lasso`, `MultiTaskElasticNet`,\n`MultiTaskLasso` and their CV variants. The improvement is mainly achieved by\n**gap safe screening rules**, which enable the coordinate descent solver to\nset feature coefficients to zero early on and not look at them again. The\nstronger the L1 penalty, the earlier features can be excluded from further\nupdates.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from time import time\n\nfrom sklearn.datasets import make_regression\nfrom sklearn.linear_model import ElasticNetCV\n\nX, y = make_regression(n_features=10_000, random_state=0)\nmodel = ElasticNetCV()\ntic = time()\nmodel.fit(X, y)\ntoc = time()\nprint(f\"Fitting ElasticNetCV took {toc - tic:.3} seconds.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## HTML representation of estimators\nHyperparameters in the dropdown table of the HTML representation now include\nlinks to the online documentation. Docstring descriptions are also shown as\ntooltips on hover.\n\n
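The same representation can also be generated outside of a notebook, for\nexample to embed the interactive diagram in a standalone HTML report, with\n:func:`utils.estimator_html_repr`. A minimal sketch (the output file name is\narbitrary):\n\n```python\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.utils import estimator_html_repr\n\n# Render the diagram to a string of HTML and save it to a file\nwith open(\"estimator.html\", \"w\", encoding=\"utf-8\") as f:\n    f.write(estimator_html_repr(LogisticRegression()))\n```\n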
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression\nfrom sklearn.pipeline import make_pipeline\nfrom sklearn.preprocessing import StandardScaler\n\nclf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0, C=10))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Expand the estimator diagram below by clicking on \"LogisticRegression\" and then on\n\"Parameters\".\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "clf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DecisionTreeRegressor with `criterion=\"absolute_error\"`\n:class:`tree.DecisionTreeRegressor` with `criterion=\"absolute_error\"`\nnow runs much faster. It now has `O(n * log(n))` complexity, compared to\n`O(n**2)` previously, which allows scaling to millions of data points.\n\nAs an illustration, on a dataset with 100_000 samples and 1 feature, performing\na single split takes on the order of 100 ms, compared to ~20 seconds before.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import time\n\nfrom sklearn.datasets import make_regression\nfrom sklearn.tree import DecisionTreeRegressor\n\nX, y = make_regression(n_samples=100_000, n_features=1)\ntree = DecisionTreeRegressor(criterion=\"absolute_error\", max_depth=1)\n\ntic = time.time()\ntree.fit(X, y)\nelapsed = time.time() - tic\nprint(f\"Fit took {elapsed:.2f} seconds\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ClassicalMDS\nClassical MDS, also known as \"Principal Coordinates Analysis\" (PCoA)\nor \"Torgerson's scaling\", is now available in the `sklearn.manifold` module.\nClassical MDS is closely related to PCA: instead of approximating distances,\nit approximates pairwise scalar products, a problem that has an exact analytic\nsolution in terms of an eigendecomposition.
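In the standard formulation (Torgerson scaling), the matrix of squared\ndissimilarities $D^{(2)}$ is double-centered into a Gram (scalar product)\nmatrix\n\n$$B = -\\tfrac{1}{2} J D^{(2)} J, \\qquad J = I - \\tfrac{1}{n} \\mathbf{1}\\mathbf{1}^\\top,$$\n\nand the embedding coordinates are the top eigenvectors of $B$, scaled by the\nsquare roots of the corresponding eigenvalues.\n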
Let's illustrate this new addition by using it on an S-curve dataset to\nget a low-dimensional representation of the data.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\nfrom matplotlib import ticker\n\nfrom sklearn import datasets, manifold\n\nn_samples = 1500\nS_points, S_color = datasets.make_s_curve(n_samples, random_state=0)\nmd_classical = manifold.ClassicalMDS(n_components=2)\nS_scaling = md_classical.fit_transform(S_points)\n\nfig = plt.figure(figsize=(8, 4))\nax1 = fig.add_subplot(1, 2, 1, projection=\"3d\")\nx, y, z = S_points.T\nax1.scatter(x, y, z, c=S_color, s=50, alpha=0.8)\nax1.set_title(\"Original S-curve samples\", size=16)\nax1.view_init(azim=-60, elev=9)\nfor axis in (ax1.xaxis, ax1.yaxis, ax1.zaxis):\n    axis.set_major_locator(ticker.MultipleLocator(1))\n\nax2 = fig.add_subplot(1, 2, 2)\nx2, y2 = S_scaling.T\nax2.scatter(x2, y2, c=S_color, s=50, alpha=0.8)\nax2.set_title(\"Classical MDS\", size=16)\nfor axis in (ax2.xaxis, ax2.yaxis):\n    axis.set_major_formatter(ticker.NullFormatter())\n\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.14" } }, "nbformat": 4, "nbformat_minor": 0 }