{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.test import *\n", "from local.data.all import *\n", "from local.optimizer import *\n", "from local.learner import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp metrics\n", "# default_cls_lvl 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Metrics\n", "\n", "> Definition of the metrics that can be used in training models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Core metric" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is where the function that converts scikit-learn metrics to fastai metrics is defined. You should skip this section unless you want to know all about the internals of fastai." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "import sklearn.metrics as skm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export torch_core\n", "def flatten_check(inp, targ):\n", "    \"Check that `inp` and `targ` have the same number of elements and flatten them.\"\n", "    inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)\n", "    test_eq(len(inp), len(targ))\n", "    return inp,targ" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(5,4),torch.randn(20)\n", "x1,x2 = flatten_check(x1,x2)\n", "test_eq(x1.shape, [20])\n", "test_eq(x2.shape, [20])\n", "x1,x2 = torch.randn(5,4),torch.randn(21)\n", "test_fail(lambda: flatten_check(x1,x2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class AccumMetric(Metric):\n", "    \"Stores predictions and targets 
on CPU in accumulate to perform final calculations with `func`.\"\n", "    def __init__(self, func, dim_argmax=None, sigmoid=False, thresh=None, to_np=False, invert_arg=False,\n", "                 flatten=True, **kwargs):\n", "        store_attr(self,'func,dim_argmax,sigmoid,thresh,flatten')\n", "        self.to_np,self.invert_args,self.kwargs = to_np,invert_arg,kwargs\n", "\n", "    def reset(self): self.targs,self.preds = [],[]\n", "\n", "    def accumulate(self, learn):\n", "        pred = learn.pred.argmax(dim=self.dim_argmax) if self.dim_argmax is not None else learn.pred\n", "        if self.sigmoid: pred = torch.sigmoid(pred)\n", "        if self.thresh is not None: pred = (pred >= self.thresh)\n", "        targ = learn.y\n", "        if self.flatten: pred,targ = flatten_check(pred,targ)\n", "        self.preds.append(to_detach(pred))\n", "        self.targs.append(to_detach(targ))\n", "\n", "    @property\n", "    def value(self):\n", "        if len(self.preds) == 0: return\n", "        preds,targs = torch.cat(self.preds),torch.cat(self.targs)\n", "        if self.to_np: preds,targs = preds.numpy(),targs.numpy()\n", "        return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`func` is only applied to the accumulated predictions/targets when the `value` attribute is asked for (so at the end of a validation/training phase, in use with `Learner` and its `Recorder`). The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).\n", "\n", "For single-label classification problems, predictions need to be transformed with a softmax then an argmax before being compared to the targets. Since a softmax doesn't change the order of the numbers, we can just apply the argmax. 
Pass along `dim_argmax` to have this done by `AccumMetric` (usually -1 will work pretty well).\n", "\n", "For classification problems with multiple labels, or if your targets are one-hot encoded, predictions may need to pass through a sigmoid (if it wasn't included in your model) and then be compared to a given threshold (to decide between 0 and 1). `AccumMetric` does this if you pass `sigmoid=True` and/or a value for `thresh`.\n", "\n", "If you want to use a metric function from `sklearn.metrics`, you will need to convert predictions and labels to numpy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_pred`, which is the opposite of ours, so you will need to pass `invert_arg=True` to make `AccumMetric` do the inversion for you." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#For testing: a fake learner and a metric that isn't an average\n", "class TstLearner():\n", "    def __init__(self): self.pred,self.y = None,None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _l2_mean(x,y): return torch.sqrt((x.float()-y.float()).pow(2).mean())\n", "\n", "#Go through a fake cycle with various batch sizes and compute the value of met\n", "def compute_val(met, x1, x2):\n", "    met.reset()\n", "    vals = [0,6,15,20]\n", "    learn = TstLearner()\n", "    for i in range(3): \n", "        learn.pred,learn.y = x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]]\n", "        met.accumulate(learn)\n", "    return met.value" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "tst = AccumMetric(_l2_mean)\n", "test_close(compute_val(tst, x1, x2), _l2_mean(x1, x2))\n", "test_eq(torch.cat(tst.preds), x1.view(-1))\n", "test_eq(torch.cat(tst.targs), x2.view(-1))\n", "\n", "#test argmax\n", "x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))\n", "tst = AccumMetric(_l2_mean, dim_argmax=-1)\n", 
"test_close(compute_val(tst, x1, x2), _l2_mean(x1.argmax(dim=-1), x2))\n", "\n", "#test thresh\n", "x1,x2 = torch.randn(20,5),torch.randint(0, 2, (20,5)).bool()\n", "tst = AccumMetric(_l2_mean, thresh=0.5)\n", "test_close(compute_val(tst, x1, x2), _l2_mean((x1 >= 0.5), x2))\n", "\n", "#test sigmoid\n", "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "tst = AccumMetric(_l2_mean, sigmoid=True)\n", "test_close(compute_val(tst, x1, x2), _l2_mean(torch.sigmoid(x1), x2))\n", "\n", "#test to_np\n", "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "tst = AccumMetric(lambda x,y: isinstance(x, np.ndarray) and isinstance(y, np.ndarray), to_np=True)\n", "assert compute_val(tst, x1, x2)\n", "\n", "#test invert_arg\n", "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "tst = AccumMetric(lambda x,y: torch.sqrt(x.pow(2).mean()))\n", "test_close(compute_val(tst, x1, x2), torch.sqrt(x1.pow(2).mean()))\n", "tst = AccumMetric(lambda x,y: torch.sqrt(x.pow(2).mean()), invert_arg=True)\n", "test_close(compute_val(tst, x1, x2), torch.sqrt(x2.pow(2).mean()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def skm_to_fastai(func, is_class=True, thresh=None, axis=-1, sigmoid=None, **kwargs):\n", " \"Convert `func` from sklearn.metrics to a fastai metric\"\n", " dim_argmax = axis if is_class and thresh is None else None\n", " sigmoid = sigmoid if sigmoid is not None else (is_class and thresh is not None)\n", " return AccumMetric(func, dim_argmax=dim_argmax, sigmoid=sigmoid, thresh=thresh,\n", " to_np=True, invert_arg=True, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the quickest way to use a sckit-learn metric in a fastai training loop. `is_class` indicates if you are in a classification problem or not. 
In this case:\n", "- leaving `thresh` as `None` indicates it's a single-label classification problem and predictions will pass through an argmax over `axis` before being compared to the targets\n", "- setting a value for `thresh` indicates it's a multi-label classification problem and predictions will pass through a sigmoid (can be deactivated with `sigmoid=False`) and be compared to `thresh` before being compared to the targets\n", "\n", "If `is_class=False`, it indicates you are in a regression problem, and predictions are compared to the targets without being modified. In all cases, `kwargs` are extra keyword arguments passed to `func`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_single = skm_to_fastai(skm.precision_score)\n", "x1,x2 = torch.randn(20,2),torch.randint(0, 2, (20,))\n", "test_close(compute_val(tst_single, x1, x2), skm.precision_score(x2, x1.argmax(dim=-1)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2)\n", "x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))\n", "test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, torch.sigmoid(x1) >= 0.2))\n", "\n", "tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2, sigmoid=False)\n", "x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))\n", "test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, x1 >= 0.2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_reg = skm_to_fastai(skm.r2_score, is_class=False)\n", "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "test_close(compute_val(tst_reg, x1, x2), skm.r2_score(x2.view(-1), x1.view(-1)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def optim_metric(f, argname, bounds, tol=0.01, do_neg=True, get_x=False):\n", "    \"Replace metric `f` with a version 
that optimizes argument `argname`\"\n", "    def _f(preds, targs):\n", "        def minfunc(x):\n", "            kwargs = {argname:x}\n", "            res = f(preds, targs, **kwargs)\n", "            return -res if do_neg else res\n", "        optres = scipy.optimize.minimize_scalar(minfunc, bounds=bounds, method='bounded',\n", "                                                options={'xatol':tol})\n", "        fun = -optres.fun if do_neg else optres.fun\n", "        return (fun,optres.x) if get_x else fun\n", "    _f.__name__ = f'opt_{f.__name__}'\n", "    return _f" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Single-label classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: All functions defined in this section are intended for single-label classification and targets that aren't one-hot encoded. For multi-label problems or one-hot encoded targets, use the `_multi` version of them." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def accuracy(inp, targ, axis=-1):\n", "    \"Compute accuracy with `targ` when `inp` is bs * n_classes\"\n", "    pred,targ = flatten_check(inp.argmax(dim=axis), targ)\n", "    return (pred == targ).float().mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#For testing\n", "def change_targ(targ, n, c):\n", "    idx = torch.randperm(len(targ))[:n]\n", "    res = targ.clone()\n", "    for i in idx: res[i] = (res[i]+random.randint(1,c-1))%c\n", "    return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4,5)\n", "y = x.argmax(dim=1)\n", "test_eq(accuracy(x,y), 1)\n", "y1 = change_targ(y, 2, 5)\n", "test_eq(accuracy(x,y1), 0.5)\n", "test_eq(accuracy(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.75)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def error_rate(inp, targ, axis=-1):\n", "    \"1 - `accuracy`\"\n", "    return 1 - accuracy(inp, targ, axis=axis)" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4,5)\n", "y = x.argmax(dim=1)\n", "test_eq(error_rate(x,y), 0)\n", "y1 = change_targ(y, 2, 5)\n", "test_eq(error_rate(x,y1), 0.5)\n", "test_eq(error_rate(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.25)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def top_k_accuracy(inp, targ, k=5, axis=-1):\n", " \"Computes the Top-k accuracy (`targ` is in the top `k` predictions of `inp`)\"\n", " inp = inp.topk(k=k, dim=axis)[1]\n", " targ = targ.unsqueeze(dim=axis).expand_as(inp)\n", " return (inp == targ).sum(dim=-1).float().mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(6,5)\n", "y = torch.arange(0,6)\n", "test_eq(top_k_accuracy(x[:5],y[:5]), 1)\n", "test_eq(top_k_accuracy(x, y), 5/6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def APScore(axis=-1, average='macro', pos_label=1, sample_weight=None):\n", " \"Average Precision for single-label classification problems\"\n", " return skm_to_fastai(skm.average_precision_score, axis=axis,\n", " average=average, pos_label=pos_label, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def BalancedAccuracy(axis=-1, sample_weight=None, adjusted=False):\n", "    \"Balanced Accuracy for single-label binary classification problems\"\n", "    return skm_to_fastai(skm.balanced_accuracy_score, axis=axis,\n", "                         sample_weight=sample_weight, adjusted=adjusted)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html#sklearn.metrics.balanced_accuracy_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def BrierScore(axis=-1, sample_weight=None, pos_label=None):\n", "    \"Brier score for single-label classification problems\"\n", "    return skm_to_fastai(skm.brier_score_loss, axis=axis,\n", "                         sample_weight=sample_weight, pos_label=pos_label)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def CohenKappa(axis=-1, labels=None, weights=None, sample_weight=None):\n", "    \"Cohen kappa for single-label classification problems\"\n", "    return skm_to_fastai(skm.cohen_kappa_score, axis=axis,\n", "                         labels=labels, weights=weights, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html#sklearn.metrics.cohen_kappa_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def F1Score(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"F1 score for single-label classification problems\"\n", "    return skm_to_fastai(skm.f1_score, axis=axis,\n", "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def FBeta(beta, axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"FBeta score with `beta` for single-label classification problems\"\n", "    return skm_to_fastai(skm.fbeta_score, axis=axis,\n", "                         beta=beta, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html#sklearn.metrics.fbeta_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def HammingLoss(axis=-1, labels=None, sample_weight=None):\n", "    \"Hamming loss for single-label classification problems\"\n", "    return skm_to_fastai(skm.hamming_loss, axis=axis,\n", "                         labels=labels, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html#sklearn.metrics.hamming_loss) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Jaccard(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"Jaccard score for single-label classification problems\"\n", "    return skm_to_fastai(skm.jaccard_score, axis=axis,\n", "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html#sklearn.metrics.jaccard_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def MatthewsCorrCoef(axis=-1, sample_weight=None):\n", "    \"Matthews correlation coefficient for single-label binary classification problems\"\n", "    return skm_to_fastai(skm.matthews_corrcoef, axis=axis, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html#sklearn.metrics.matthews_corrcoef) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Precision(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"Precision for single-label classification problems\"\n", "    return skm_to_fastai(skm.precision_score, axis=axis,\n", "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Recall(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"Recall for single-label classification problems\"\n", "    return skm_to_fastai(skm.recall_score, axis=axis,\n", "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None):\n", "    \"Area Under the Receiver Operating Characteristic Curve for single-label binary classification problems\"\n", "    return skm_to_fastai(skm.roc_auc_score, axis=axis,\n", "                         average=average, sample_weight=sample_weight, max_fpr=max_fpr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Perplexity(AvgLoss):\n", "    \"Perplexity (exponential of cross-entropy loss) for Language Models\"\n", "    @property\n", "    def value(self): return torch.exp(self.total/self.count) if self.count != 0 else None\n", "    @property\n", "    def name(self): return \"perplexity\"\n", "\n", "perplexity = Perplexity()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))\n", "tst = perplexity\n", "tst.reset()\n", "vals = [0,6,15,20]\n", "learn = TstLearner()\n", "for i in range(3): \n", "    learn.y,learn.yb = x2[vals[i]:vals[i+1]],(x2[vals[i]:vals[i+1]],)\n", "    learn.loss = F.cross_entropy(x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]])\n", "    tst.accumulate(learn)\n", "test_close(tst.value, torch.exp(F.cross_entropy(x1,x2)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi-label classification" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):\n", "    \"Compute accuracy when `inp` and `targ` are the same size.\"\n", "    inp,targ = flatten_check(inp,targ)\n", "    if sigmoid: inp = inp.sigmoid()\n", "    return ((inp>thresh)==targ.bool()).float().mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#For testing\n", "def change_1h_targ(targ, n):\n", "    idx = torch.randperm(targ.numel())[:n]\n", "    res = targ.clone().view(-1)\n", "    for i in idx: res[i] = 1-res[i]\n", "    return res.view(targ.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4,5)\n", "y = (torch.sigmoid(x) >= 0.5).byte()\n", "test_eq(accuracy_multi(x,y), 1)\n", "test_eq(accuracy_multi(x,1-y), 0)\n", "y1 = change_1h_targ(y, 5)\n", "test_eq(accuracy_multi(x,y1), 0.75)\n", 
"\n", "#Different thresh\n", "y = (torch.sigmoid(x) >= 0.2).byte()\n", "test_eq(accuracy_multi(x,y, thresh=0.2), 1)\n", "test_eq(accuracy_multi(x,1-y, thresh=0.2), 0)\n", "y1 = change_1h_targ(y, 5)\n", "test_eq(accuracy_multi(x,y1, thresh=0.2), 0.75)\n", "\n", "#No sigmoid\n", "y = (x >= 0.5).byte()\n", "test_eq(accuracy_multi(x,y, sigmoid=False), 1)\n", "test_eq(accuracy_multi(x,1-y, sigmoid=False), 0)\n", "y1 = change_1h_targ(y, 5)\n", "test_eq(accuracy_multi(x,y1, sigmoid=False), 0.75)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def APScoreMulti(thresh=0.5, sigmoid=True, average='macro', pos_label=1, sample_weight=None):\n", " \"Average Precision for multi-label classification problems\"\n", " return skm_to_fastai(skm.average_precision_score, thresh=thresh, sigmoid=sigmoid,\n", " average=average, pos_label=pos_label, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def BrierScoreMulti(thresh=0.5, sigmoid=True, sample_weight=None, pos_label=None):\n", " \"Brier score for multi-label classification problems\"\n", " return skm_to_fastai(skm.brier_score_loss, thresh=thresh, sigmoid=sigmoid,\n", " sample_weight=sample_weight, pos_label=pos_label)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def F1ScoreMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"F1 score for multi-label classification problems\"\n", "    return skm_to_fastai(skm.f1_score, thresh=thresh, sigmoid=sigmoid,\n", "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def FBetaMulti(beta, thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"FBeta score with `beta` for multi-label classification problems\"\n", "    return skm_to_fastai(skm.fbeta_score, thresh=thresh, sigmoid=sigmoid,\n", "                         beta=beta, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html#sklearn.metrics.fbeta_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def HammingLossMulti(thresh=0.5, sigmoid=True, labels=None, sample_weight=None):\n", "    \"Hamming loss for multi-label classification problems\"\n", "    return skm_to_fastai(skm.hamming_loss, thresh=thresh, sigmoid=sigmoid,\n", "                         labels=labels, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html#sklearn.metrics.hamming_loss) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def JaccardMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"Jaccard score for multi-label classification problems\"\n", "    return skm_to_fastai(skm.jaccard_score, thresh=thresh, sigmoid=sigmoid,\n", "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html#sklearn.metrics.jaccard_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def MatthewsCorrCoefMulti(thresh=0.5, sigmoid=True, sample_weight=None):\n", "    \"Matthews correlation coefficient for multi-label classification problems\"\n", "    return skm_to_fastai(skm.matthews_corrcoef, thresh=thresh, sigmoid=sigmoid, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html#sklearn.metrics.matthews_corrcoef) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def PrecisionMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"Precision for multi-label classification problems\"\n", "    return skm_to_fastai(skm.precision_score, thresh=thresh, sigmoid=sigmoid,\n", "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RecallMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='binary', sample_weight=None):\n", "    \"Recall for multi-label classification problems\"\n", "    return skm_to_fastai(skm.recall_score, thresh=thresh, sigmoid=sigmoid,\n", "                         labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RocAucMulti(thresh=0.5, sigmoid=True, average='macro', sample_weight=None, max_fpr=None):\n", "    \"Area Under the Receiver Operating Characteristic Curve for multi-label binary classification problems\"\n", "    return skm_to_fastai(skm.roc_auc_score, thresh=thresh, sigmoid=sigmoid,\n", "                         average=average, sample_weight=sample_weight, max_fpr=max_fpr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Regression" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def mse(inp,targ):\n", "    \"Mean squared error between `inp` and `targ`.\"\n", "    return F.mse_loss(*flatten_check(inp,targ))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(4,5),torch.randn(4,5)\n", "test_close(mse(x1,x2), (x1-x2).pow(2).mean())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _rmse(inp, targ): return torch.sqrt(F.mse_loss(inp, targ))\n", "rmse = AccumMetric(_rmse)\n", "rmse.__doc__ = \"Root mean squared error\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
rmse
[source]exp_rmspe
[source]