{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.data.all import *\n", "from fastai2.optimizer import *\n", "from fastai2.learner import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp metrics\n", "# default_cls_lvl 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Metrics\n", "\n", "> Definition of the metrics that can be used in training models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Core metric" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is where the function that converts scikit-learn metrics to fastai metrics is defined. You should skip this section unless you want to know all about the internals of fastai." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "import sklearn.metrics as skm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "import scipy.stats as scs\n", "import scipy.optimize" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export torch_core\n", "def flatten_check(inp, targ):\n", "    \"Check that `inp` and `targ` have the same number of elements and flatten them.\"\n", "    inp,targ = inp.contiguous().view(-1),targ.contiguous().view(-1)\n", "    test_eq(len(inp), len(targ))\n", "    return inp,targ" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(5,4),torch.randn(20)\n", "x1,x2 = flatten_check(x1,x2)\n", "test_eq(x1.shape, [20])\n", "test_eq(x2.shape, [20])\n", "x1,x2 = torch.randn(5,4),torch.randn(21)\n", "test_fail(lambda: flatten_check(x1,x2))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "mk_class('ActivationType', **{o:o.lower() for o in ['No', 'Sigmoid', 'Softmax', 'BinarySoftmax']},\n", "         doc=\"All possible activation classes for `AccumMetric`\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class AccumMetric(Metric):\n", "    \"Stores predictions and targets on CPU in `accumulate` to perform final calculations with `func`.\"\n", "    def __init__(self, func, dim_argmax=None, activation=ActivationType.No, thresh=None, to_np=False,\n", "                 invert_arg=False, flatten=True, **kwargs):\n", "        store_attr(self,'func,dim_argmax,activation,thresh,flatten')\n", "        self.to_np,self.invert_args,self.kwargs = to_np,invert_arg,kwargs\n", "\n", "    def reset(self):\n", "        \"Clear all targs and preds\"\n", "        self.targs,self.preds = [],[]\n", "\n", "    def accumulate(self, learn):\n", "        \"Store targs and preds from `learn`, using activation function and argmax as appropriate\"\n", "        pred = learn.pred\n", "        if self.activation in [ActivationType.Softmax, ActivationType.BinarySoftmax]:\n", "            pred = F.softmax(pred, dim=self.dim_argmax)\n", "            if self.activation == ActivationType.BinarySoftmax: pred = pred[:, -1]\n", "        elif self.activation == ActivationType.Sigmoid: pred = torch.sigmoid(pred)\n", "        elif self.dim_argmax is not None: pred = pred.argmax(dim=self.dim_argmax)\n", "        if self.thresh is not None: pred = (pred >= self.thresh)\n", "        self.accum_values(pred,learn.y)\n", "\n", "    def accum_values(self, preds, targs):\n", "        \"Store targs and preds\"\n", "        preds,targs = to_detach(preds),to_detach(targs)\n", "        if self.flatten: preds,targs = flatten_check(preds,targs)\n", "        self.preds.append(preds)\n", "        self.targs.append(targs)\n", "\n", "    def __call__(self, preds, targs):\n", "        \"Calculate metric on one batch of data\"\n", "        self.reset()\n", "        self.accum_values(preds,targs)\n", "        return self.value\n", "\n", "    @property\n", "    def value(self):\n", "        \"Value of the metric using accumulated preds and targs\"\n", "        if len(self.preds) == 0: return\n", "        preds,targs = torch.cat(self.preds),torch.cat(self.targs)\n", "        if self.to_np: preds,targs = preds.numpy(),targs.numpy()\n", "        return self.func(targs, preds, **self.kwargs) if self.invert_args else self.func(preds, targs, **self.kwargs)\n", "\n", "    @property\n", "    def name(self): return self.func.func.__name__ if hasattr(self.func, 'func') else self.func.__name__" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "`func` is only applied to the accumulated predictions/targets when the `value` attribute is asked for (so at the end of a validation/training phase, in use with `Learner` and its `Recorder`). The signature of `func` should be `inp,targ` (where `inp` are the predictions of the model and `targ` the corresponding labels).\n", "\n", "For single-label classification problems, predictions need to be transformed with a softmax then an argmax before being compared to the targets. Since a softmax doesn't change the order of the numbers, we can just apply the argmax. Pass along `dim_argmax` to have this done by `AccumMetric` (usually -1, the last axis, is the right choice). If your metric needs the probabilities rather than the predicted classes, pass `activation=ActivationType.Softmax` (the softmax is then taken over `dim_argmax`), or `activation=ActivationType.BinarySoftmax` to keep only the probability of the last class in a binary problem.\n", "\n", "For multi-label classification problems, or if your targets are one-hot encoded, predictions may need to pass through a sigmoid (if it wasn't included in your model) and then be compared to a given threshold (to decide between 0 and 1). `AccumMetric` does this if you pass `activation=ActivationType.Sigmoid` and/or a value for `thresh`.\n", "\n", "If you want to use a metric function from `sklearn.metrics`, you will need to convert predictions and labels to NumPy arrays with `to_np=True`. Also, scikit-learn metrics adopt the convention `y_true`, `y_pred`, which is the opposite of ours, so you will need to pass `invert_arg=True` to make `AccumMetric` do the inversion for you."
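] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two conventions above can be checked with plain PyTorch and scikit-learn. The cell below is a small illustrative sketch on made-up values and does not go through `AccumMetric`. It shows that taking the argmax of the raw outputs gives the same classes as taking it after a softmax. It also shows that swapping the two arguments of an sklearn metric changes its meaning (here precision becomes recall), which is why `invert_arg=True` matters." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "import sklearn.metrics as skm\n", "\n", "# Made-up logits: softmax is monotonic along the class axis, so the argmax does not change\n", "logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]])\n", "assert torch.equal(logits.argmax(dim=-1), F.softmax(logits, dim=-1).argmax(dim=-1))\n", "\n", "# sklearn expects (y_true, y_pred); with the arguments swapped, precision turns into recall\n", "targs = np.array([1, 1, 0, 0, 0, 1])\n", "preds = np.array([1, 0, 1, 1, 0, 1])\n", "assert skm.precision_score(targs, preds) == 0.5\n", "assert abs(skm.precision_score(preds, targs) - skm.recall_score(targs, preds)) < 1e-9"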
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#For testing: a fake learner and a metric that isn't an average\n", "class TstLearner():\n", "    def __init__(self): self.pred,self.y = None,None" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _l2_mean(x,y): return torch.sqrt((x.float()-y.float()).pow(2).mean())\n", "\n", "#Go through a fake cycle with various batch sizes and compute the value of met\n", "def compute_val(met, x1, x2):\n", "    met.reset()\n", "    vals = [0,6,15,20]\n", "    learn = TstLearner()\n", "    for i in range(3):\n", "        learn.pred,learn.y = x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]]\n", "        met.accumulate(learn)\n", "    return met.value" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "tst = AccumMetric(_l2_mean)\n", "test_close(compute_val(tst, x1, x2), _l2_mean(x1, x2))\n", "test_eq(torch.cat(tst.preds), x1.view(-1))\n", "test_eq(torch.cat(tst.targs), x2.view(-1))\n", "\n", "#test argmax\n", "x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))\n", "tst = AccumMetric(_l2_mean, dim_argmax=-1)\n", "test_close(compute_val(tst, x1, x2), _l2_mean(x1.argmax(dim=-1), x2))\n", "\n", "#test thresh\n", "x1,x2 = torch.randn(20,5),torch.randint(0, 2, (20,5)).bool()\n", "tst = AccumMetric(_l2_mean, thresh=0.5)\n", "test_close(compute_val(tst, x1, x2), _l2_mean((x1 >= 0.5), x2))\n", "\n", "#test sigmoid\n", "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "tst = AccumMetric(_l2_mean, activation=ActivationType.Sigmoid)\n", "test_close(compute_val(tst, x1, x2), _l2_mean(torch.sigmoid(x1), x2))\n", "\n", "#test to_np\n", "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "tst = AccumMetric(lambda x,y: isinstance(x, np.ndarray) and isinstance(y, np.ndarray), to_np=True)\n", "assert compute_val(tst, x1, x2)\n", "\n", "#test invert_arg\n", "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "tst = AccumMetric(lambda x,y: torch.sqrt(x.pow(2).mean()))\n", "test_close(compute_val(tst, x1, x2), torch.sqrt(x1.pow(2).mean()))\n", "tst = AccumMetric(lambda x,y: torch.sqrt(x.pow(2).mean()), invert_arg=True)\n", "test_close(compute_val(tst, x1, x2), torch.sqrt(x2.pow(2).mean()))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "def _l2_mean(x,y): return torch.sqrt((x.argmax(dim=-1).float()-y.float()).pow(2).mean())\n", "x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))\n", "tst = AccumMetric(_l2_mean, dim_argmax=-1, flatten=False, activation=ActivationType.Softmax)\n", "test_close(compute_val(tst, x1, x2), _l2_mean(F.softmax(x1, dim=-1), x2))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def skm_to_fastai(func, is_class=True, thresh=None, axis=-1, activation=None, **kwargs):\n", "    \"Convert `func` from sklearn.metrics to a fastai metric\"\n", "    dim_argmax = axis if is_class and thresh is None else None\n", "    if activation is None:\n", "        activation = ActivationType.Sigmoid if (is_class and thresh is not None) else ActivationType.No\n", "    return AccumMetric(func, dim_argmax=dim_argmax, activation=activation, thresh=thresh,\n", "                       to_np=True, invert_arg=True, **kwargs)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "This is the quickest way to use a scikit-learn metric in a fastai training loop. `is_class` indicates whether you are in a classification problem. If so:\n", "- leaving `thresh` as `None` indicates it's a single-label classification problem and predictions will pass through an argmax over `axis` before being compared to the targets\n", "- setting a value for `thresh` indicates it's a multi-label classification problem and predictions will pass through a sigmoid (which can be skipped by passing `activation=ActivationType.No`) and be thresholded at `thresh` before being compared to the targets\n", "\n", "If `is_class=False`, it indicates you are in a regression problem, and predictions are compared to the targets without being modified. In all cases, `kwargs` are extra keyword arguments passed to `func`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_single = skm_to_fastai(skm.precision_score)\n", "x1,x2 = torch.randn(20,2),torch.randint(0, 2, (20,))\n", "test_close(compute_val(tst_single, x1, x2), skm.precision_score(x2, x1.argmax(dim=-1)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2)\n", "x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))\n", "test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, torch.sigmoid(x1) >= 0.2))\n", "\n", "tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2, activation=ActivationType.No)\n", "x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))\n", "test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, x1 >= 0.2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_reg = skm_to_fastai(skm.r2_score, is_class=False)\n", "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "test_close(compute_val(tst_reg, x1, x2), skm.r2_score(x2.view(-1), x1.view(-1)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_close(tst_reg(x1, x2), skm.r2_score(x2.view(-1), x1.view(-1)))" ] }, { "cell_type": "code", "execution_count":
null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def optim_metric(f, argname, bounds, tol=0.01, do_neg=True, get_x=False):\n", "    \"Replace metric `f` with a version that optimizes argument `argname`\"\n", "    def _f(preds, targs):\n", "        def minfunc(x):\n", "            kwargs = {argname:x}\n", "            res = f(preds, targs, **kwargs)\n", "            return -res if do_neg else res\n", "        optres = scipy.optimize.minimize_scalar(minfunc, bounds=bounds, method='bounded',\n", "                                                options={'xatol':tol})\n", "        fun = -optres.fun if do_neg else optres.fun\n", "        return (fun,optres.x) if get_x else fun\n", "    _f.__name__ = f'opt_{f.__name__}'\n", "    return _f" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Single-label classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: All functions defined in this section are intended for single-label classification and targets that are not one-hot encoded. For multi-label problems or one-hot encoded targets, use the version suffixed with multi (for instance `accuracy_multi` or `F1ScoreMulti`)." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: Many metrics in fastai are thin wrappers around sklearn functionality. However, sklearn metrics can handle Python lists of strings, among other things, whereas fastai metrics work with PyTorch and thus require tensors. The arguments passed to metrics are the values obtained after all transformations, such as categories being converted to indices, have occurred. This means that when you pass a label to a metric (for instance `pos_label`), you must pass its index, not the string. This conversion can be done with `vocab.map_obj`."
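] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For instance, take a hypothetical two-class vocab `['cat', 'dog']`. Making `'dog'` the positive class means passing its index, 1. The sketch below uses a plain dictionary in place of a fastai vocab. It calls `skm.precision_score` (the function that `Precision` wraps) directly on made-up labels. With `Precision`, you would pass the same index as `pos_label`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import sklearn.metrics as skm\n", "\n", "vocab = ['cat', 'dog']  # hypothetical vocab: position in the list = class index\n", "o2i = {o: i for i, o in enumerate(vocab)}\n", "\n", "# Labels and predictions as a metric receives them: indices, not strings\n", "targs = np.array([0, 1, 1, 0])\n", "preds = np.array([0, 1, 0, 1])\n", "prec = skm.precision_score(targs, preds, pos_label=o2i['dog'])  # pos_label=1, not 'dog'\n", "assert prec == 0.5"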
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def accuracy(inp, targ, axis=-1):\n", "    \"Compute accuracy with `targ` when `inp` is bs * n_classes\"\n", "    pred,targ = flatten_check(inp.argmax(dim=axis), targ)\n", "    return (pred == targ).float().mean()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#For testing\n", "def change_targ(targ, n, c):\n", "    idx = torch.randperm(len(targ))[:n]\n", "    res = targ.clone()\n", "    for i in idx: res[i] = (res[i]+random.randint(1,c-1))%c\n", "    return res" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4,5)\n", "y = x.argmax(dim=1)\n", "test_eq(accuracy(x,y), 1)\n", "y1 = change_targ(y, 2, 5)\n", "test_eq(accuracy(x,y1), 0.5)\n", "test_eq(accuracy(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.75)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def error_rate(inp, targ, axis=-1):\n", "    \"1 - `accuracy`\"\n", "    return 1 - accuracy(inp, targ, axis=axis)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4,5)\n", "y = x.argmax(dim=1)\n", "test_eq(error_rate(x,y), 0)\n", "y1 = change_targ(y, 2, 5)\n", "test_eq(error_rate(x,y1), 0.5)\n", "test_eq(error_rate(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.25)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def top_k_accuracy(inp, targ, k=5, axis=-1):\n", "    \"Computes the Top-k accuracy (`targ` is in the top `k` predictions of `inp`)\"\n", "    inp = inp.topk(k=k, dim=axis)[1]\n", "    targ = targ.unsqueeze(dim=axis).expand_as(inp)\n", "    return (inp == targ).sum(dim=-1).float().mean()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(6,5)\n",
"y = torch.arange(0,6)\n", "test_eq(top_k_accuracy(x[:5],y[:5]), 1)\n", "test_eq(top_k_accuracy(x, y), 5/6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def APScoreBinary(axis=-1, average='macro', pos_label=1, sample_weight=None):\n", " \"Average Precision for single-label binary classification problems\"\n", " return skm_to_fastai(skm.average_precision_score, axis=axis, activation=ActivationType.BinarySoftmax,\n", " average=average, pos_label=pos_label, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def BalancedAccuracy(axis=-1, sample_weight=None, adjusted=False):\n", " \"Balanced Accuracy for single-label binary classification problems\"\n", " return skm_to_fastai(skm.balanced_accuracy_score, axis=axis,\n", " sample_weight=sample_weight, adjusted=adjusted)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html#sklearn.metrics.balanced_accuracy_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def BrierScore(axis=-1, sample_weight=None, pos_label=None):\n", " \"Brier score for single-label classification problems\"\n", " return skm_to_fastai(skm.brier_score_loss, axis=axis,\n", " sample_weight=sample_weight, pos_label=pos_label)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def CohenKappa(axis=-1, labels=None, weights=None, sample_weight=None):\n", " \"Cohen kappa for single-label classification problems\"\n", " return skm_to_fastai(skm.cohen_kappa_score, axis=axis, labels=labels, weights=weights,\n", " sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html#sklearn.metrics.cohen_kappa_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def F1Score(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", " \"F1 score for single-label classification problems\"\n", " return skm_to_fastai(skm.f1_score, axis=axis,\n", " labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) for more details." 
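] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As an illustration of what the single-label wrappers such as `F1Score` compute, the sketch below redoes the steps by hand on made-up logits rather than calling `F1Score`. It takes an argmax over the last axis (the default `axis=-1`), then calls the sklearn function with the targets first, as `invert_arg=True` arranges." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import sklearn.metrics as skm\n", "\n", "logits = torch.tensor([[2.0, -1.0], [0.3, 0.9], [1.5, 1.0], [-0.5, 0.5]])  # made-up model outputs\n", "targs = torch.tensor([0, 1, 1, 1])\n", "preds = logits.argmax(dim=-1)  # tensor([0, 1, 0, 1])\n", "f1 = skm.f1_score(targs.numpy(), preds.numpy())  # targets first\n", "# class 1: 2 true positives, 0 false positives, 1 false negative -> F1 = 0.8\n", "assert abs(f1 - 0.8) < 1e-9"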
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def FBeta(beta, axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", " \"FBeta score with `beta` for single-label classification problems\"\n", " return skm_to_fastai(skm.fbeta_score, axis=axis,\n", " beta=beta, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html#sklearn.metrics.fbeta_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def HammingLoss(axis=-1, sample_weight=None):\n", " \"Hamming loss for single-label classification problems\"\n", " return skm_to_fastai(skm.hamming_loss, axis=axis,\n", " sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html#sklearn.metrics.hamming_loss) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Jaccard(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", " \"Jaccard score for single-label classification problems\"\n", " return skm_to_fastai(skm.jaccard_score, axis=axis,\n", " labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html#sklearn.metrics.jaccard_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Precision(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", " \"Precision for single-label classification problems\"\n", " return skm_to_fastai(skm.precision_score, axis=axis,\n", " labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Recall(axis=-1, labels=None, pos_label=1, average='binary', sample_weight=None):\n", " \"Recall for single-label classification problems\"\n", " return skm_to_fastai(skm.recall_score, axis=axis,\n", " labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RocAuc(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='ovr'):\n", " \"Area Under the Receiver Operating Characteristic Curve for single-label multiclass classification problems\"\n", " assert multi_class in ['ovr', 'ovo']\n", " return skm_to_fastai(skm.roc_auc_score, axis=axis, activation=ActivationType.Softmax, flatten=False,\n", " average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RocAucBinary(axis=-1, average='macro', sample_weight=None, max_fpr=None, multi_class='raise'):\n", " \"Area Under the Receiver Operating Characteristic Curve for single-label binary classification problems\"\n", " return skm_to_fastai(skm.roc_auc_score, axis=axis, activation=ActivationType.BinarySoftmax,\n", " average=average, sample_weight=sample_weight, max_fpr=max_fpr, multi_class=multi_class)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details." 
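] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `ActivationType.BinarySoftmax`, `AccumMetric` applies a softmax and keeps only the last column, i.e. the probability of the positive class. That is the kind of score `roc_auc_score` expects for a binary problem. The hand-written sketch below uses made-up logits and does not call `RocAucBinary`. It also checks that this column equals the sigmoid of the difference between the two logits." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "import sklearn.metrics as skm\n", "\n", "logits = torch.tensor([[2.0, 0.0], [0.0, 1.0], [1.0, 0.5], [0.0, 3.0]])  # made-up model outputs\n", "targs = torch.tensor([0, 1, 0, 1])\n", "probs_pos = F.softmax(logits, dim=-1)[:, -1]  # probability of the last (positive) class\n", "assert torch.allclose(probs_pos, torch.sigmoid(logits[:, 1] - logits[:, 0]))\n", "auc = skm.roc_auc_score(targs.numpy(), probs_pos.numpy())\n", "# every positive example is scored above every negative one, so the AUC is 1\n", "assert abs(auc - 1.0) < 1e-9"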
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def MatthewsCorrCoef(sample_weight=None, **kwargs):\n", " \"Matthews correlation coefficient for single-label classification problems\"\n", " return skm_to_fastai(skm.matthews_corrcoef, sample_weight=sample_weight, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html#sklearn.metrics.matthews_corrcoef) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Perplexity(AvgLoss):\n", " \"Perplexity (exponential of cross-entropy loss) for Language Models\"\n", " @property\n", " def value(self): return torch.exp(self.total/self.count) if self.count != 0 else None\n", " @property\n", " def name(self): return \"perplexity\"\n", "\n", "perplexity = Perplexity()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))\n", "tst = perplexity\n", "tst.reset()\n", "vals = [0,6,15,20]\n", "learn = TstLearner()\n", "for i in range(3): \n", " learn.y,learn.yb = x2[vals[i]:vals[i+1]],(x2[vals[i]:vals[i+1]],)\n", " learn.loss = F.cross_entropy(x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]])\n", " tst.accumulate(learn)\n", "test_close(tst.value, torch.exp(F.cross_entropy(x1,x2)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi-label classification" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):\n", " \"Compute accuracy when `inp` and `targ` are the same size.\"\n", " inp,targ = flatten_check(inp,targ)\n", " if sigmoid: inp = inp.sigmoid()\n", " return ((inp>thresh)==targ.bool()).float().mean()" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "#For testing\n", "def change_1h_targ(targ, n):\n", " idx = torch.randperm(targ.numel())[:n]\n", " res = targ.clone().view(-1)\n", " for i in idx: res[i] = 1-res[i]\n", " return res.view(targ.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4,5)\n", "y = (torch.sigmoid(x) >= 0.5).byte()\n", "test_eq(accuracy_multi(x,y), 1)\n", "test_eq(accuracy_multi(x,1-y), 0)\n", "y1 = change_1h_targ(y, 5)\n", "test_eq(accuracy_multi(x,y1), 0.75)\n", "\n", "#Different thresh\n", "y = (torch.sigmoid(x) >= 0.2).byte()\n", "test_eq(accuracy_multi(x,y, thresh=0.2), 1)\n", "test_eq(accuracy_multi(x,1-y, thresh=0.2), 0)\n", "y1 = change_1h_targ(y, 5)\n", "test_eq(accuracy_multi(x,y1, thresh=0.2), 0.75)\n", "\n", "#No sigmoid\n", "y = (x >= 0.5).byte()\n", "test_eq(accuracy_multi(x,y, sigmoid=False), 1)\n", "test_eq(accuracy_multi(x,1-y, sigmoid=False), 0)\n", "y1 = change_1h_targ(y, 5)\n", "test_eq(accuracy_multi(x,y1, sigmoid=False), 0.75)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def APScoreMulti(sigmoid=True, average='macro', pos_label=1, sample_weight=None):\n", " \"Average Precision for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.average_precision_score, activation=activation, flatten=False,\n", " average=average, pos_label=pos_label, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def BrierScoreMulti(thresh=0.5, sigmoid=True, sample_weight=None, pos_label=None):\n", " \"Brier score for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.brier_score_loss, thresh=thresh, activation=activation, flatten=False,\n", " sample_weight=sample_weight, pos_label=pos_label)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.brier_score_loss.html#sklearn.metrics.brier_score_loss) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def F1ScoreMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n", " \"F1 score for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.f1_score, thresh=thresh, activation=activation, flatten=False,\n", " labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#sklearn.metrics.f1_score) for more details." 
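] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the multi-label wrappers such as `F1ScoreMulti`, predictions go through a sigmoid, are compared to `thresh` and keep their shape (`flatten=False`), so sklearn scores each column as a separate label. The sketch below redoes these steps by hand on made-up data with the defaults `thresh=0.5` and `average='macro'`. It does not call `F1ScoreMulti` itself." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import sklearn.metrics as skm\n", "\n", "logits = torch.tensor([[2.0, -1.0, 0.5], [-2.0, 1.5, -0.5], [1.0, 1.0, -3.0]])  # made-up model outputs\n", "targs = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 0, 0]])  # multi-label (one-hot style) targets\n", "preds = (torch.sigmoid(logits) >= 0.5).long()  # sigmoid, then threshold at 0.5\n", "f1 = skm.f1_score(targs.numpy(), preds.numpy(), average='macro')\n", "# per-label F1 scores are 1, 2/3 and 1, so the macro average is 8/9\n", "assert abs(f1 - 8/9) < 1e-9"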
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def FBetaMulti(beta, thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n", " \"FBeta score with `beta` for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.fbeta_score, thresh=thresh, activation=activation, flatten=False,\n", " beta=beta, labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html#sklearn.metrics.fbeta_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def HammingLossMulti(thresh=0.5, sigmoid=True, labels=None, sample_weight=None):\n", " \"Hamming loss for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.hamming_loss, thresh=thresh, activation=activation, flatten=False,\n", " sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.hamming_loss.html#sklearn.metrics.hamming_loss) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def JaccardMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n", " \"Jaccard score for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.jaccard_score, thresh=thresh, activation=activation, flatten=False,\n", " labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.jaccard_score.html#sklearn.metrics.jaccard_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def MatthewsCorrCoefMulti(thresh=0.5, sigmoid=True, sample_weight=None):\n", " \"Matthews correlation coefficient for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.matthews_corrcoef, thresh=thresh, activation=activation, flatten=False, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html#sklearn.metrics.matthews_corrcoef) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def PrecisionMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n", " \"Precision for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.precision_score, thresh=thresh, activation=activation, flatten=False,\n", " labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html#sklearn.metrics.precision_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RecallMulti(thresh=0.5, sigmoid=True, labels=None, pos_label=1, average='macro', sample_weight=None):\n", " \"Recall for multi-label classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.recall_score, thresh=thresh, activation=activation, flatten=False,\n", " labels=labels, pos_label=pos_label, average=average, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html#sklearn.metrics.recall_score) for more details." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RocAucMulti(sigmoid=True, average='macro', sample_weight=None, max_fpr=None):\n", " \"Area Under the Receiver Operating Characteristic Curve for multi-label binary classification problems\"\n", " activation = ActivationType.Sigmoid if sigmoid else ActivationType.No\n", " return skm_to_fastai(skm.roc_auc_score, activation=activation, flatten=False,\n", " average=average, sample_weight=sample_weight, max_fpr=max_fpr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "roc_auc_metric = RocAucMulti(sigmoid=False)\n", "x,y = torch.tensor([np.arange(start=0, stop=0.2, step=0.04)]*20), torch.tensor([0, 0, 1, 1]).repeat(5)\n", "assert compute_val(roc_auc_metric, x, y) == 0.5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score) for more details." 
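, "\n", "\n", "For reference, with the default `max_fpr=None`, the AUC of a single label equals the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counting one half. With $P$ and $N$ the sets of positive and negative examples and $s_i$ the score of example $i$:\n", "\n", "$$\\mathrm{AUC} = \\frac{1}{|P|\\,|N|}\\sum_{i \\in P}\\sum_{j \\in N}\\left(\\mathbf{1}[s_i > s_j] + \\tfrac{1}{2}\\,\\mathbf{1}[s_i = s_j]\\right)$$\n", "\n", "With `average='macro'`, the per-label AUCs are then averaged with equal weight."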
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Regression" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def mse(inp,targ):\n", " \"Mean squared error between `inp` and `targ`.\"\n", " return F.mse_loss(*flatten_check(inp,targ))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(4,5),torch.randn(4,5)\n", "test_close(mse(x1,x2), (x1-x2).pow(2).mean())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _rmse(inp, targ): return torch.sqrt(F.mse_loss(inp, targ))\n", "rmse = AccumMetric(_rmse)\n", "rmse.__doc__ = \"Root mean squared error\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(rmse, name=\"rmse\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "test_eq(compute_val(rmse, x1, x2), torch.sqrt(F.mse_loss(x1,x2)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def mae(inp,targ):\n", " \"Mean absolute error between `inp` and `targ`.\"\n", " inp,targ = flatten_check(inp,targ)\n", " return torch.abs(inp - targ).mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(4,5),torch.randn(4,5)\n", "test_eq(mae(x1,x2), torch.abs(x1-x2).mean())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def msle(inp, targ):\n", " \"Mean squared logarithmic error between `inp` and `targ`.\"\n", " inp,targ = flatten_check(inp,targ)\n", " return F.mse_loss(torch.log(1 + inp), torch.log(1 + targ))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = 
torch.randn(4,5),torch.randn(4,5)\n", "x1,x2 = torch.relu(x1),torch.relu(x2)\n", "test_close(msle(x1,x2), (torch.log(x1+1)-torch.log(x2+1)).pow(2).mean())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _exp_rmspe(inp,targ):\n", " inp,targ = torch.exp(inp),torch.exp(targ)\n", " return torch.sqrt(((targ - inp)/targ).pow(2).mean())\n", "exp_rmspe = AccumMetric(_exp_rmspe)\n", "exp_rmspe.__doc__ = \"Root mean square percentage error of the exponential of predictions and targets\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(exp_rmspe, name=\"exp_rmspe\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1,x2 = torch.randn(20,5),torch.randn(20,5)\n", "test_eq(compute_val(exp_rmspe, x1, x2), torch.sqrt((((torch.exp(x2) - torch.exp(x1))/torch.exp(x2))**2).mean()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def ExplainedVariance(sample_weight=None):\n", " \"Explained variance between predictions and targets\"\n", " return skm_to_fastai(skm.explained_variance_score, is_class=False, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.explained_variance_score.html#sklearn.metrics.explained_variance_score) for more details."
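, "\n", "\n", "For reference, with targets $y$ and predictions $\\hat{y}$, the score is defined as:\n", "\n", "$$\\mathrm{EV}(y, \\hat{y}) = 1 - \\frac{\\mathrm{Var}(y - \\hat{y})}{\\mathrm{Var}(y)}$$\n", "\n", "Unlike $R^2$, it does not penalize a constant offset between predictions and targets, since such an offset leaves $\\mathrm{Var}(y - \\hat{y})$ unchanged."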
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def R2Score(sample_weight=None):\n", " \"R2 score between predictions and targets\"\n", " return skm_to_fastai(skm.r2_score, is_class=False, sample_weight=sample_weight)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scikit-learn documentation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html#sklearn.metrics.r2_score) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates(AccumMetric)\n", "def PearsonCorrCoef(dim_argmax=None, **kwargs):\n", " \"Pearson correlation coefficient for regression problems\"\n", " def pearsonr(x,y): return scs.pearsonr(x,y)[0]\n", " return AccumMetric(pearsonr, invert_arg=False, dim_argmax=dim_argmax, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scipy documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html?highlight=pearson#scipy.stats.pearsonr) for more details."
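, "\n", "\n", "For reference, for paired samples $(x_i, y_i)_{i=1}^{n}$ with means $\\bar{x}$ and $\\bar{y}$, the coefficient is:\n", "\n", "$$r = \\frac{\\sum_{i=1}^{n}(x_i - \\bar{x})(y_i - \\bar{y})}{\\sqrt{\\sum_{i=1}^{n}(x_i - \\bar{x})^2}\\,\\sqrt{\\sum_{i=1}^{n}(y_i - \\bar{y})^2}}$$\n", "\n", "The Spearman coefficient is the same quantity computed on the ranks of the samples instead of their values."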
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randint(-999, 999,(20,))\n", "y = torch.randint(-999, 999,(20,))\n", "test_eq(compute_val(PearsonCorrCoef(), x, y), scs.pearsonr(x.view(-1), y.view(-1))[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates(AccumMetric)\n", "def SpearmanCorrCoef(dim_argmax=None, axis=0, nan_policy='propagate', **kwargs):\n", " \"Spearman correlation coefficient for regression problems\"\n", " def spearmanr(a,b=None,**kwargs): return scs.spearmanr(a,b,**kwargs)[0]\n", " return AccumMetric(partial(spearmanr, axis=axis, nan_policy=nan_policy),\n", " invert_arg=False, dim_argmax=dim_argmax, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [scipy documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html?highlight=spearman#scipy.stats.spearmanr) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randint(-999, 999,(20,))\n", "y = torch.randint(-999, 999,(20,))\n", "test_eq(compute_val(SpearmanCorrCoef(), x, y), scs.spearmanr(x.view(-1), y.view(-1))[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Segmentation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def foreground_acc(inp, targ, bkg_idx=0, axis=1):\n", " \"Computes non-background accuracy for multiclass segmentation\"\n", " targ = targ.squeeze(1)\n", " mask = targ != bkg_idx\n", " return (inp.argmax(dim=axis)[mask]==targ[mask]).float().mean()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(4,5,3,3)\n", "y = x.argmax(dim=1)[:,None]\n", "test_eq(foreground_acc(x,y), 1)\n", "y[0] = 0 # 0 is the background index, so these pixels are ignored and the accuracy is unchanged\n", "test_eq(foreground_acc(x,y), 1)" ] }, { 
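"cell_type": "markdown", "metadata": {}, "source": [ "For reference, `foreground_acc` only scores the pixels whose target differs from the background index $b$ (`bkg_idx`). With $t_i$ the target and $\\hat{y}_i$ the argmax prediction for pixel $i$:\n", "\n", "$$\\mathrm{acc}_{\\text{fg}} = \\frac{\\sum_i \\mathbf{1}[t_i \\neq b]\\,\\mathbf{1}[\\hat{y}_i = t_i]}{\\sum_i \\mathbf{1}[t_i \\neq b]}$$" ] }, {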
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Dice(Metric):\n", " \"Dice coefficient metric for binary target in segmentation\"\n", " def __init__(self, axis=1): self.axis = axis\n", " def reset(self): self.inter,self.union = 0,0\n", " def accumulate(self, learn):\n", " pred,targ = flatten_check(learn.pred.argmax(dim=self.axis), learn.y)\n", " self.inter += (pred*targ).float().sum().item()\n", " self.union += (pred+targ).float().sum().item()\n", "\n", " @property\n", " def value(self): return 2. * self.inter/self.union if self.union > 0 else None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1 = torch.randn(20,2,3,3)\n", "x2 = torch.randint(0, 2, (20, 3, 3))\n", "pred = x1.argmax(1)\n", "inter = (pred*x2).float().sum().item()\n", "union = (pred+x2).float().sum().item()\n", "test_eq(compute_val(Dice(), x1, x2), 2*inter/union)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class JaccardCoeff(Dice):\n", " \"Implementation of the Jaccard coefficient that is lighter in RAM\"\n", " @property\n", " def value(self): return self.inter/(self.union-self.inter) if self.union > 0 else None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x1 = torch.randn(20,2,3,3)\n", "x2 = torch.randint(0, 2, (20, 3, 3))\n", "pred = x1.argmax(1)\n", "inter = (pred*x2).float().sum().item()\n", "union = (pred+x2).float().sum().item()\n", "test_eq(compute_val(JaccardCoeff(), x1, x2), inter/(union-inter))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## NLP" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class CorpusBLEUMetric(Metric):\n", " def __init__(self, vocab_sz=5000, axis=-1):\n", " \"BLEU Metric calculated over the validation corpus\"\n", " self.metric_name = 'CorpusBLEU'\n", 
self.axis, self.vocab_sz = axis, vocab_sz\n", " self.pred_len,self.targ_len,self.samp_idx,self.corrects,self.counts, = 0,0,0,[0]*4,[0]*4\n", "\n", " def reset(self):\n", " self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4\n", "\n", " class NGram():\n", " def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n\n", " def __eq__(self, other):\n", " if len(self.ngram) != len(other.ngram): return False\n", " return np.all(np.array(self.ngram) == np.array(other.ngram))\n", " def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))\n", "\n", " def get_grams(self, x, n, max_n=5000):\n", " return x if n==1 else [self.NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]\n", "\n", " def get_correct_ngrams(self, pred, targ, n, max_n=5000):\n", " pred_grams,targ_grams = self.get_grams(pred, n, max_n=max_n),self.get_grams(targ, n, max_n=max_n)\n", " pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)\n", " return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)\n", "\n", " def accumulate(self, learn):\n", " if learn.training: return None\n", " else:\n", " last_output = learn.pred.argmax(dim=self.axis)\n", " last_target = learn.y\n", " for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):\n", " self.pred_len += len(pred)\n", " self.targ_len += len(targ)\n", " smooth_mteval = 1\n", " for i in range(4):\n", " c,t = self.get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)\n", " if c == 0:\n", " smooth_mteval *= 2\n", " c = 1 / smooth_mteval # exp smoothing, method 3 from http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf\n", " self.corrects[i] += c\n", " self.counts[i] += t\n", "\n", " @property\n", " def value(self):\n", " if max(self.counts) == 0: return None\n", " elif max(self.corrects) == 0: return 0.0\n", " else:\n", " precs = [c/t for c,t in zip(self.corrects,self.counts)]\n", " len_penalty = math.exp(1 - self.targ_len/self.pred_len) if 
self.pred_len < self.targ_len else 1\n", " return len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_vcb_emb(pred, targ):\n", " # one-hot encode the predicted token ids over the vocabulary so they look like model outputs\n", " vcb_sz = max(torch.unique(torch.cat([pred, targ])))+1\n", " pred_emb=torch.zeros(pred.size()[0], pred.size()[1], vcb_sz)\n", " for i,v in enumerate(pred):\n", " pred_emb[i].scatter_(1, v.view(len(v),1),1)\n", " return pred_emb\n", "\n", "def compute_bleu_val(met, x1, x2):\n", " met.reset()\n", " learn = TstLearner()\n", " learn.training=False\n", " for i in range(len(x1)):\n", " learn.pred,learn.y = x1, x2\n", " met.accumulate(learn)\n", " return met.value\n", "\n", "targ = torch.tensor([[1,2,3,4,5,6,1,7,8]])\n", "pred = torch.tensor([[1,9,3,4,5,6,1,10,8]])\n", "pred_emb = create_vcb_emb(pred, targ)\n", "test_close(compute_bleu_val(CorpusBLEUMetric(), pred_emb, targ), 0.48549)\n", "\n", "targ = torch.tensor([[1,2,3,4,5,6,1,7,8],[1,2,3,4,5,6,1,7,8]])\n", "pred = torch.tensor([[1,9,3,4,5,6,1,10,8],[1,9,3,4,5,6,1,10,8]])\n", "pred_emb = create_vcb_emb(pred, targ)\n", "test_close(compute_bleu_val(CorpusBLEUMetric(), pred_emb, targ), 0.48549)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The BLEU metric was introduced in [this article](https://www.aclweb.org/anthology/P02-1040) as a way to evaluate the performance of translation models. It's based on the precision of n-grams in your prediction compared to your target. 
See the [fastai NLP course BLEU notebook](https://github.com/fastai/course-nlp/blob/master/bleu_metric.ipynb) for a more detailed description of BLEU.\n", "\n", "The smoothing used in the precision calculation is the same as in [SacreBLEU](https://github.com/mjpost/sacrebleu/blob/32c54cdd0dfd6a9fadd5805f2ea189ac0df63907/sacrebleu/sacrebleu.py#L540-L542), which in turn is \"method 3\" from the [Chen & Cherry, 2014](http://acl2014.org/acl2014/W14-33/pdf/W14-3346.pdf) paper." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LossMetrics -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class LossMetric(AvgMetric):\n", " \"Create a metric from `loss_func.attr` named `nm`\"\n", " def __init__(self, attr, nm=None): store_attr(self, 'attr,nm')\n", " def accumulate(self, learn):\n", " bs = find_bs(learn.yb)\n", " self.total += to_detach(getattr(learn.loss_func, self.attr, 0))*bs\n", " self.count += bs\n", "\n", " @property\n", " def name(self): return self.attr if self.nm is None else self.nm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def LossMetrics(attrs, nms=None):\n", " \"List of `LossMetric` for each of `attrs` and `nms`\"\n", " if isinstance(attrs, str): attrs = attrs.split(',')\n", " nms = attrs if nms is None else nms.split(',') if isinstance(nms, str) else nms\n", " return [LossMetric(a, n) for a,n in zip(attrs,nms)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from fastai2.test_utils import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CombineL1L2(Module):\n", " def forward(self, out, targ):\n", " self.l1 = F.l1_loss(out, targ)\n", " self.l2 = F.mse_loss(out, targ)\n", " return self.l1+self.l2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = 
synth_learner(metrics=LossMetrics('l1,l2'))\n", "learn.loss_func = CombineL1L2()\n", "learn.fit(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }