{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Metadata Routing\n\n.. currentmodule:: sklearn\n\nThis document shows how you can use the metadata routing mechanism\nin scikit-learn to route metadata to the estimators,\nscorers, and CV splitters consuming them.\n\nTo better understand the following document, we need to introduce two concepts:\nrouters and consumers. A router is an object which forwards some given data and\nmetadata to other objects. In most cases, a router is a :term:`meta-estimator`,\ni.e. an estimator which takes another estimator as a parameter. A function such\nas :func:`sklearn.model_selection.cross_validate`, which takes an estimator as a\nparameter and forwards data and metadata, is also a router.\n\nA consumer, on the other hand, is an object which accepts and uses some given\nmetadata. For instance, an estimator taking into account ``sample_weight`` in\nits :term:`fit` method is a consumer of ``sample_weight``.\n\nIt is possible for an object to be both a router and a consumer. 
For instance,\na meta-estimator may take into account ``sample_weight`` in certain\ncalculations, but it may also route it to the underlying estimator.\n\nFirst, a few imports and some random data for the rest of the script.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import warnings\nfrom pprint import pprint\n\nimport numpy as np\n\nfrom sklearn import set_config\nfrom sklearn.base import (\n    BaseEstimator,\n    ClassifierMixin,\n    MetaEstimatorMixin,\n    RegressorMixin,\n    TransformerMixin,\n    clone,\n)\nfrom sklearn.linear_model import LinearRegression\nfrom sklearn.utils import metadata_routing\nfrom sklearn.utils.metadata_routing import (\n    MetadataRouter,\n    MethodMapping,\n    get_routing_for_object,\n    process_routing,\n)\nfrom sklearn.utils.validation import check_is_fitted\n\nn_samples, n_features = 100, 4\nrng = np.random.RandomState(42)\nX = rng.rand(n_samples, n_features)\ny = rng.randint(0, 2, size=n_samples)\nmy_groups = rng.randint(0, 10, size=n_samples)\nmy_weights = rng.rand(n_samples)\nmy_other_weights = rng.rand(n_samples)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Metadata routing is only available if explicitly enabled:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "set_config(enable_metadata_routing=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This dummy utility function checks whether a given metadata is passed:\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def check_metadata(obj, **kwargs):\n    for key, value in kwargs.items():\n        if value is not None:\n            print(\n                f\"Received {key} of length = {len(value)} in {obj.__class__.__name__}.\"\n            )\n        else:\n            print(f\"{key} is None in {obj.__class__.__name__}.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A utility function to nicely print the routing information of an object:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def print_routing(obj):\n    pprint(obj.get_metadata_routing()._serialize())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Consuming Estimator\nHere we demonstrate how an estimator can expose the required API to support\nmetadata routing as a consumer. Imagine a simple classifier accepting\n``sample_weight`` as metadata in its ``fit`` method and ``groups`` in its\n``predict`` method:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class ExampleClassifier(ClassifierMixin, BaseEstimator):\n    def fit(self, X, y, sample_weight=None):\n        check_metadata(self, sample_weight=sample_weight)\n        # all classifiers need to expose a classes_ attribute once they're fit.\n        self.classes_ = np.array([0, 1])\n        return self\n\n    def predict(self, X, groups=None):\n        check_metadata(self, groups=groups)\n        # return a constant value of 1, not a very smart classifier!\n        return np.ones(len(X))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above estimator now has all it needs to consume metadata. This is\naccomplished by some magic done in :class:`~base.BaseEstimator`. There are\nnow three methods exposed by the above class: ``set_fit_request``,\n``set_predict_request``, and ``get_metadata_routing``. There is also a\n``set_score_request`` for ``sample_weight`` which is present since\n:class:`~base.ClassifierMixin` implements a ``score`` method accepting\n``sample_weight``. 
The same applies to regressors which inherit from\n:class:`~base.RegressorMixin`.\n\nBy default, no metadata is requested, which we can see as:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print_routing(ExampleClassifier())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The above output means that ``sample_weight`` and ``groups`` are not\nrequested by `ExampleClassifier`, and if a router is given those metadata, it\nshould raise an error, since the user has not explicitly set whether they are\nrequired or not. The same is true for ``sample_weight`` in the ``score``\nmethod, which is inherited from :class:`~base.ClassifierMixin`. In order to\nexplicitly set request values for those metadata, we can use these methods:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "est = (\n ExampleClassifier()\n .set_fit_request(sample_weight=False)\n .set_predict_request(groups=True)\n .set_score_request(sample_weight=False)\n)\nprint_routing(est)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ ".. note ::\n Please note that as long as the above estimator is not used in a\n meta-estimator, the user does not need to set any requests for the\n metadata and the set values are ignored, since a consumer does not\n validate or route given metadata. A simple usage of the above estimator\n would work as expected.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "est = ExampleClassifier()\nest.fit(X, y, sample_weight=my_weights)\nest.predict(X[:3, :], groups=my_groups)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Routing Meta-Estimator\nNow, we show how to design a meta-estimator to be a router. 
As a simplified\nexample, here is a meta-estimator, which doesn't do much other than routing\nthe metadata.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):\n    def __init__(self, estimator):\n        self.estimator = estimator\n\n    def get_metadata_routing(self):\n        # This method defines the routing for this meta-estimator.\n        # In order to do so, a `MetadataRouter` instance is created, and the\n        # routing is added to it. More explanations follow below.\n        router = MetadataRouter(owner=self.__class__.__name__).add(\n            estimator=self.estimator,\n            method_mapping=MethodMapping()\n            .add(caller=\"fit\", callee=\"fit\")\n            .add(caller=\"predict\", callee=\"predict\")\n            .add(caller=\"score\", callee=\"score\"),\n        )\n        return router\n\n    def fit(self, X, y, **fit_params):\n        # `get_routing_for_object` returns a copy of the `MetadataRouter`\n        # constructed by the above `get_metadata_routing` method, which is\n        # called internally.\n        request_router = get_routing_for_object(self)\n        # Meta-estimators are responsible for validating the given metadata.\n        # `method` refers to the parent's method, i.e. `fit` in this example.\n        request_router.validate_metadata(params=fit_params, method=\"fit\")\n        # `MetadataRouter.route_params` maps the given metadata to the metadata\n        # required by the underlying estimator based on the routing information\n        # defined by the MetadataRouter. The output of type `Bunch` has a key\n        # for each consuming object and those hold keys for their consuming\n        # methods, which then contain keys for the metadata which should be\n        # routed to them.\n        routed_params = request_router.route_params(params=fit_params, caller=\"fit\")\n\n        # A sub-estimator is fitted and its classes are attributed to the\n        # meta-estimator.\n        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)\n        self.classes_ = self.estimator_.classes_\n        return self\n\n    def predict(self, X, **predict_params):\n        check_is_fitted(self)\n        # As in `fit`, we get a copy of the object's MetadataRouter,\n        request_router = get_routing_for_object(self)\n        # then we validate the given metadata,\n        request_router.validate_metadata(params=predict_params, method=\"predict\")\n        # and then prepare the input to the underlying `predict` method.\n        routed_params = request_router.route_params(\n            params=predict_params, caller=\"predict\"\n        )\n        return self.estimator_.predict(X, **routed_params.estimator.predict)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Let's break down different parts of the above code.\n\nFirst, the :func:`~utils.metadata_routing.get_routing_for_object` function takes our\nmeta-estimator (``self``) and returns a\n:class:`~utils.metadata_routing.MetadataRouter` or, if the object is a consumer, a\n:class:`~utils.metadata_routing.MetadataRequest`,\nbased on the output of the estimator's ``get_metadata_routing`` method.\n\nThen in each method, we use the ``route_params`` method to construct a\ndictionary of the form ``{\"object_name\": {\"method_name\": {\"metadata\":\nvalue}}}`` to pass to the underlying estimator's method. The ``object_name``\n(``estimator`` in the above ``routed_params.estimator.fit`` example) is the\nsame as the one added in the ``get_metadata_routing``. 
``validate_metadata``\nmakes sure all given metadata are requested to avoid silent bugs.\n\nNext, we illustrate the different behaviors and notably the type of errors\nraised.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est = MetaClassifier(\n estimator=ExampleClassifier().set_fit_request(sample_weight=True)\n)\nmeta_est.fit(X, y, sample_weight=my_weights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the above example is calling our utility function\n`check_metadata()` via the `ExampleClassifier`. It checks that\n``sample_weight`` is correctly passed to it. If it is not, like in the\nfollowing example, it would print that ``sample_weight`` is ``None``:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we pass an unknown metadata, an error is raised:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "try:\n meta_est.fit(X, y, test=my_weights)\nexcept TypeError as e:\n print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And if we pass a metadata which is not explicitly requested:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "try:\n meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)\nexcept ValueError as e:\n print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Also, if we explicitly set it as not requested, but it is provided:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est = MetaClassifier(\n estimator=ExampleClassifier()\n .set_fit_request(sample_weight=True)\n .set_predict_request(groups=False)\n)\ntry:\n 
meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)\nexcept TypeError as e:\n print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another concept to introduce is **aliased metadata**. This is when an\nestimator requests a metadata with a different variable name than the default\nvariable name. For instance, in a setting where there are two estimators in a\npipeline, one could request ``sample_weight1`` and the other\n``sample_weight2``. Note that this doesn't change what the estimator expects,\nit only tells the meta-estimator how to map the provided metadata to what is\nrequired. Here's an example, where we pass ``aliased_sample_weight`` to the\nmeta-estimator, but the meta-estimator understands that\n``aliased_sample_weight`` is an alias for ``sample_weight``, and passes it as\n``sample_weight`` to the underlying estimator:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est = MetaClassifier(\n estimator=ExampleClassifier().set_fit_request(sample_weight=\"aliased_sample_weight\")\n)\nmeta_est.fit(X, y, aliased_sample_weight=my_weights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Passing ``sample_weight`` here will fail since it is requested with an\nalias and ``sample_weight`` with that name is not requested:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "try:\n meta_est.fit(X, y, sample_weight=my_weights)\nexcept TypeError as e:\n print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This leads us to the ``get_metadata_routing``. The way routing works in\nscikit-learn is that consumers request what they need, and routers pass that\nalong. Additionally, a router exposes what it requires itself so that it can\nbe used inside another router, e.g. 
a pipeline inside a grid search object.\nThe output of ``get_metadata_routing``, shown here as a dictionary\nrepresentation of a :class:`~utils.metadata_routing.MetadataRouter`, includes\nthe complete tree of requested metadata by all nested objects and their\ncorresponding method routings, i.e. which method of a sub-estimator is used\nin which method of a meta-estimator:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print_routing(meta_est)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the only metadata requested for method ``fit`` is\n``\"sample_weight\"`` with ``\"aliased_sample_weight\"`` as the alias. The\n:class:`~utils.metadata_routing.MetadataRouter` class enables us to easily create\nthe routing object which would create the output we need for our\n``get_metadata_routing``.\n\nIn order to understand how aliases work in meta-estimators, imagine our\nmeta-estimator inside another one:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_meta_est = MetaClassifier(estimator=meta_est).fit(\n    X, y, aliased_sample_weight=my_weights\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the above example, this is how the ``fit`` method of `meta_meta_est`\nwill call the ``fit`` methods of its sub-estimators::\n\n    # user feeds `my_weights` as `aliased_sample_weight` into `meta_meta_est`:\n    meta_meta_est.fit(X, y, aliased_sample_weight=my_weights):\n        ...\n\n    # the first sub-estimator (`meta_est`) expects `aliased_sample_weight`\n    self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight):\n        ...\n\n    # the second sub-estimator (`est`) expects `sample_weight`\n    self.estimator_.fit(X, y, sample_weight=aliased_sample_weight):\n        ...\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Consuming and Routing Meta-Estimator\nFor a slightly more complex example, 
consider a meta-estimator that routes\nmetadata to an underlying estimator as before, but it also uses some metadata\nin its own methods. This meta-estimator is a consumer and a router at the\nsame time. Implementing one is very similar to what we had before, but with a\nfew tweaks.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):\n    def __init__(self, estimator):\n        self.estimator = estimator\n\n    def get_metadata_routing(self):\n        router = (\n            MetadataRouter(owner=self.__class__.__name__)\n            # defining metadata routing request values for usage in the meta-estimator\n            .add_self_request(self)\n            # defining metadata routing request values for usage in the sub-estimator\n            .add(\n                estimator=self.estimator,\n                method_mapping=MethodMapping()\n                .add(caller=\"fit\", callee=\"fit\")\n                .add(caller=\"predict\", callee=\"predict\")\n                .add(caller=\"score\", callee=\"score\"),\n            )\n        )\n        return router\n\n    # Since `sample_weight` is used and consumed here, it should be defined as\n    # an explicit argument in the method's signature. All other metadata which\n    # are only routed, will be passed as `**fit_params`:\n    def fit(self, X, y, sample_weight, **fit_params):\n        if self.estimator is None:\n            raise ValueError(\"estimator cannot be None!\")\n\n        check_metadata(self, sample_weight=sample_weight)\n\n        # We add `sample_weight` to the `fit_params` dictionary.\n        if sample_weight is not None:\n            fit_params[\"sample_weight\"] = sample_weight\n\n        request_router = get_routing_for_object(self)\n        request_router.validate_metadata(params=fit_params, method=\"fit\")\n        routed_params = request_router.route_params(params=fit_params, caller=\"fit\")\n        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)\n        self.classes_ = self.estimator_.classes_\n        return self\n\n    def predict(self, X, **predict_params):\n        check_is_fitted(self)\n        # As in `fit`, we get a copy of the object's MetadataRouter,\n        request_router = get_routing_for_object(self)\n        # we validate the given metadata,\n        request_router.validate_metadata(params=predict_params, method=\"predict\")\n        # and then prepare the input to the underlying ``predict`` method.\n        routed_params = request_router.route_params(\n            params=predict_params, caller=\"predict\"\n        )\n        return self.estimator_.predict(X, **routed_params.estimator.predict)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The key difference between the above meta-estimator and our previous\nmeta-estimator is that it accepts ``sample_weight`` explicitly in ``fit`` and\nincludes it in ``fit_params``. Since ``sample_weight`` is an explicit\nargument, we can be sure that ``set_fit_request(sample_weight=...)`` is\npresent for this method. The meta-estimator is both a consumer and a\nrouter of ``sample_weight``.\n\nIn ``get_metadata_routing``, we add ``self`` to the routing using\n``add_self_request`` to indicate this estimator is consuming\n``sample_weight`` as well as being a router. This also adds a\n``$self_request`` key to the routing info as illustrated below. 
Now let's\nlook at some examples:\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- No metadata requested\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est = RouterConsumerClassifier(estimator=ExampleClassifier())\nprint_routing(meta_est)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- ``sample_weight`` requested by sub-estimator\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est = RouterConsumerClassifier(\n    estimator=ExampleClassifier().set_fit_request(sample_weight=True)\n)\nprint_routing(meta_est)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- ``sample_weight`` requested by meta-estimator\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(\n    sample_weight=True\n)\nprint_routing(meta_est)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note the difference in the requested metadata representations above.\n\n- We can also alias the metadata to pass different values to the fit methods\n  of the meta- and the sub-estimator:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est = RouterConsumerClassifier(\n    estimator=ExampleClassifier().set_fit_request(sample_weight=\"clf_sample_weight\"),\n).set_fit_request(sample_weight=\"meta_clf_sample_weight\")\nprint_routing(meta_est)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, ``fit`` of the meta-estimator only needs the alias for the\nsub-estimator and addresses its own sample weight as `sample_weight`, since\nit doesn't validate and route its own required metadata:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": 
false }, "outputs": [], "source": [ "meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Alias only on the sub-estimator:\n\nThis is useful when we don't want the meta-estimator to use the metadata, but\nthe sub-estimator should.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "meta_est = RouterConsumerClassifier(\n estimator=ExampleClassifier().set_fit_request(sample_weight=\"aliased_sample_weight\")\n)\nprint_routing(meta_est)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The meta-estimator cannot use `aliased_sample_weight`, because it expects\nit passed as `sample_weight`. This would apply even if\n`set_fit_request(sample_weight=True)` was set on it.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple Pipeline\nA slightly more complicated use-case is a meta-estimator resembling a\n:class:`~pipeline.Pipeline`. Here is a meta-estimator, which accepts a\ntransformer and a classifier. When calling its `fit` method, it applies the\ntransformer's `fit` and `transform` before running the classifier on the\ntransformed data. 
Upon `predict`, it applies the transformer's `transform`\nbefore predicting with the classifier's `predict` method on the transformed\nnew data.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class SimplePipeline(ClassifierMixin, BaseEstimator):\n    def __init__(self, transformer, classifier):\n        self.transformer = transformer\n        self.classifier = classifier\n\n    def get_metadata_routing(self):\n        router = (\n            MetadataRouter(owner=self.__class__.__name__)\n            # We add the routing for the transformer.\n            .add(\n                transformer=self.transformer,\n                method_mapping=MethodMapping()\n                # The metadata is routed such that it retraces how\n                # `SimplePipeline` internally calls the transformer's `fit` and\n                # `transform` methods in its own methods (`fit` and `predict`).\n                .add(caller=\"fit\", callee=\"fit\")\n                .add(caller=\"fit\", callee=\"transform\")\n                .add(caller=\"predict\", callee=\"transform\"),\n            )\n            # We add the routing for the classifier.\n            .add(\n                classifier=self.classifier,\n                method_mapping=MethodMapping()\n                .add(caller=\"fit\", callee=\"fit\")\n                .add(caller=\"predict\", callee=\"predict\"),\n            )\n        )\n        return router\n\n    def fit(self, X, y, **fit_params):\n        routed_params = process_routing(self, \"fit\", **fit_params)\n\n        self.transformer_ = clone(self.transformer).fit(\n            X, y, **routed_params.transformer.fit\n        )\n        X_transformed = self.transformer_.transform(\n            X, **routed_params.transformer.transform\n        )\n\n        self.classifier_ = clone(self.classifier).fit(\n            X_transformed, y, **routed_params.classifier.fit\n        )\n        return self\n\n    def predict(self, X, **predict_params):\n        routed_params = process_routing(self, \"predict\", **predict_params)\n\n        X_transformed = self.transformer_.transform(\n            X, **routed_params.transformer.transform\n        )\n        return self.classifier_.predict(\n            X_transformed, **routed_params.classifier.predict\n        )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note the usage of 
:class:`~utils.metadata_routing.MethodMapping` to\ndeclare which methods of the child estimator (callee) are used in which\nmethods of the meta-estimator (caller). As you can see, `SimplePipeline` uses\nthe transformer's ``transform`` and ``fit`` methods in ``fit``, and its\n``transform`` method in ``predict``, and that's what you see implemented in\nthe routing structure of the pipeline class.\n\nAnother difference between the above example and the previous ones is the usage\nof :func:`~utils.metadata_routing.process_routing`, which processes the input\nparameters, does the required validation, and returns the `routed_params`\nwhich we had created in previous examples. This reduces the boilerplate code\na developer needs to write in each meta-estimator's method. Developers are\nstrongly recommended to use this function unless there is a good reason\nagainst it.\n\nIn order to test the above pipeline, let's add an example transformer.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class ExampleTransformer(TransformerMixin, BaseEstimator):\n    def fit(self, X, y, sample_weight=None):\n        check_metadata(self, sample_weight=sample_weight)\n        return self\n\n    def transform(self, X, groups=None):\n        check_metadata(self, groups=groups)\n        return X\n\n    def fit_transform(self, X, y, sample_weight=None, groups=None):\n        return self.fit(X, y, sample_weight).transform(X, groups)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that in the above example, we have implemented ``fit_transform`` which\ncalls ``fit`` and ``transform`` with the appropriate metadata. 
This is only\nrequired if ``transform`` accepts metadata, since the default ``fit_transform``\nimplementation in :class:`~base.TransformerMixin` doesn't pass metadata to\n``transform``.\n\nNow we can test our pipeline, and see if metadata is correctly passed around.\nThis example uses our `SimplePipeline`, our `ExampleTransformer`, and our\n`RouterConsumerClassifier` which uses our `ExampleClassifier`.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pipe = SimplePipeline(\n    transformer=ExampleTransformer()\n    # we set transformer's fit to receive sample_weight\n    .set_fit_request(sample_weight=True)\n    # we set transformer's transform to receive groups\n    .set_transform_request(groups=True),\n    classifier=RouterConsumerClassifier(\n        estimator=ExampleClassifier()\n        # we want this sub-estimator to receive sample_weight in fit\n        .set_fit_request(sample_weight=True)\n        # but not groups in predict\n        .set_predict_request(groups=False),\n    )\n    # and we want the meta-estimator to receive sample_weight as well\n    .set_fit_request(sample_weight=True),\n)\npipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(\n    X[:3], groups=my_groups\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deprecation / Default Value Change\nIn this section we show how one should handle the case where a router also\nbecomes a consumer, especially when it consumes the same metadata as its\nsub-estimator, or a consumer starts consuming a metadata which it didn't consume in\nan older release. 
In this case, a warning should be raised for a while, to\nlet users know the behavior has changed from previous versions.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):\n    def __init__(self, estimator):\n        self.estimator = estimator\n\n    def fit(self, X, y, **fit_params):\n        routed_params = process_routing(self, \"fit\", **fit_params)\n        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)\n        # return the fitted estimator, following the scikit-learn convention\n        return self\n\n    def get_metadata_routing(self):\n        router = MetadataRouter(owner=self.__class__.__name__).add(\n            estimator=self.estimator,\n            method_mapping=MethodMapping().add(caller=\"fit\", callee=\"fit\"),\n        )\n        return router" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As explained above, this is a valid usage if `my_weights` aren't supposed\nto be passed as `sample_weight` to `MetaRegressor`:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True))\nreg.fit(X, y, sample_weight=my_weights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now imagine we further develop ``MetaRegressor`` and it now also *consumes*\n``sample_weight``:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):\n    # show warning to remind user to explicitly set the value with\n    # `.set_{method}_request(sample_weight={boolean})`\n    __metadata_request__fit = {\"sample_weight\": metadata_routing.WARN}\n\n    def __init__(self, estimator):\n        self.estimator = estimator\n\n    def fit(self, X, y, sample_weight=None, **fit_params):\n        routed_params = process_routing(\n            self, \"fit\", sample_weight=sample_weight, **fit_params\n        )\n        check_metadata(self, sample_weight=sample_weight)\n        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)\n        # return the fitted estimator, following the scikit-learn convention\n        return self\n\n    def get_metadata_routing(self):\n        router = (\n            MetadataRouter(owner=self.__class__.__name__)\n            .add_self_request(self)\n            .add(\n                estimator=self.estimator,\n                method_mapping=MethodMapping().add(caller=\"fit\", callee=\"fit\"),\n            )\n        )\n        return router" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The above implementation is almost the same as ``MetaRegressor``, and\nbecause of the default request value defined in ``__metadata_request__fit``\nthere is a warning raised when fitted.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with warnings.catch_warnings(record=True) as record:\n    WeightedMetaRegressor(\n        estimator=LinearRegression().set_fit_request(sample_weight=False)\n    ).fit(X, y, sample_weight=my_weights)\nfor w in record:\n    print(w.message)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When an estimator consumes a metadata which it didn't consume before, the\nfollowing pattern can be used to warn the users about it.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class ExampleRegressor(RegressorMixin, BaseEstimator):\n    __metadata_request__fit = {\"sample_weight\": metadata_routing.WARN}\n\n    def fit(self, X, y, sample_weight=None):\n        check_metadata(self, sample_weight=sample_weight)\n        return self\n\n    def predict(self, X):\n        return np.zeros(shape=(len(X)))\n\n\nwith warnings.catch_warnings(record=True) as record:\n    MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)\nfor w in record:\n    print(w.message)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At the end we disable the configuration flag for metadata routing:\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "set_config(enable_metadata_routing=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Third Party Development and scikit-learn Dependency\n\nAs seen above, information is communicated between classes using\n:class:`~utils.metadata_routing.MetadataRequest` and\n:class:`~utils.metadata_routing.MetadataRouter`. It is possible, though strongly\ndiscouraged, to vendor the tools related to metadata routing if you strictly\nwant to have a scikit-learn compatible estimator, without depending on the\nscikit-learn package. If all of the following conditions are met, you do NOT\nneed to modify your code at all:\n\n- your estimator inherits from :class:`~base.BaseEstimator`\n- the parameters consumed by your estimator's methods, e.g. ``fit``, are\n  explicitly defined in the method's signature, as opposed to being\n  ``*args`` or ``**kwargs``.\n- your estimator does not route any metadata to the underlying objects, i.e.\n  it's not a *router*.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 0 }