{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Metrics* for training fastai models are simply functions that take `input` and `target` tensors, and return some metric of interest for training. You can write your own metrics by defining a function of that type, and passing it to [`Learner`](/basic_train.html#Learner) in the `metrics` parameter, or use one of the following pre-defined functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predefined metrics:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
<h4><code>accuracy</code> [source]</h4>
\n", "\n", "> accuracy(`input`:`Tensor`, `targs`:`Tensor`) → `Rank0Tensor`\n", "\n", "Compute accuracy with `targs` when `input` is bs * n_classes. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
<h4><code>accuracy_thresh</code> [source]</h4>
\n", "\n", "> accuracy_thresh(`y_pred`:`Tensor`, `y_true`:`Tensor`, `thresh`:`float`=`0.5`, `sigmoid`:`bool`=`True`) → `Rank0Tensor`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(accuracy_thresh, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute accuracy of `y_pred` against `y_true` for multi-label models, based on comparing the predictions to `thresh`; `sigmoid` is applied to `y_pred` first if the corresponding flag is True." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
<h4><code>dice</code> [source]</h4>
\n", "\n", "> dice(`input`:`Tensor`, `targs`:`Tensor`, `iou`:`bool`=`False`) → `Rank0Tensor`\n", "\n", "Dice coefficient metric for binary targets. If `iou=True`, returns the IoU metric, classic for segmentation problems. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(dice)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
<h4><code>error_rate</code> [source]</h4>
\n", "\n", "> error_rate(`input`:`Tensor`, `targs`:`Tensor`) → `Rank0Tensor`\n", "\n", "1 - [`accuracy`](/metrics.html#accuracy) " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(error_rate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
<h4><code>fbeta</code> [source]</h4>
\n", "\n", "> fbeta(`y_pred`:`Tensor`, `y_true`:`Tensor`, `thresh`:`float`=`0.2`, `beta`:`float`=`2`, `eps`:`float`=`1e-09`, `sigmoid`:`bool`=`True`) → `Rank0Tensor`\n", "\n", "Computes the f_beta score between `y_pred` and `y_true`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(fbeta)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the [F1 score wikipedia page](https://en.wikipedia.org/wiki/F1_score) for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
<h4><code>exp_rmspe</code> [source]</h4>
\n", "\n", "> exp_rmspe(`pred`:`Tensor`, `targ`:`Tensor`) → `Rank0Tensor`\n", "\n", "Exp RMSPE between `pred` and `targ`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(exp_rmspe)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating your own metric" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creating a new metric can be as simple as creating a new function. If your metric is an average over the total number of elements in your dataset, just write the function that will compute it on a batch (taking `pred` and `targ` as arguments). It will then be automatically averaged over the batches (taking their different sizes into account).\n", "\n", "Sometimes, however, metrics aren't simple averages. Take precision, for instance: we have to divide the number of true positives by the number of positive predictions we made for that class. This isn't an average over the number of elements in the dataset, since we only consider the elements for which we made a positive prediction for a specific class. Computing the precision on each batch and then averaging those values will yield a result that may be close to the real value, but won't be exactly it (and it depends heavily on how you handle the special case of 0 positive predictions).\n", "\n", "This is why in fastai every metric is implemented as a callback. If you pass a regular function, the library transforms it into a proper callback called `AverageMetric`. The callback metrics are only called during the validation phase, and only for the following events:\n", "- on_epoch_begin (for initialization)\n", "- on_batch_begin (if we need to have a look at the input/target and maybe modify them)\n", "- on_batch_end (to analyze the last results and update our computation)\n", "- on_epoch_end (to wrap up the final result that should be stored in `.metric`)\n", "\n", "As an example, here is the exact implementation of the [`AverageMetric`](/callback.html#AverageMetric) callback that transforms a function like [`accuracy`](/metrics.html#accuracy) into a metric callback." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class AverageMetric(Callback):\n", "    def __init__(self, func):\n", "        self.func, self.name = func, func.__name__\n", "\n", "    def on_epoch_begin(self, **kwargs):\n", "        self.val, self.count = 0.,0\n", "\n", "    def on_batch_end(self, last_output, last_target, train, **kwargs):\n", "        self.count += last_target.size(0)\n", "        self.val += last_target.size(0) * self.func(last_output, last_target).detach().item()\n", "\n", "    def on_epoch_end(self, **kwargs):\n", "        self.metric = self.val/self.count" ] },
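{ "cell_type": "markdown", "metadata": {}, "source": [ "To actually use such a metric, you pass it to your [`Learner`](/basic_train.html#Learner) in the `metrics` parameter, and a plain function is wrapped in [`AverageMetric`](/callback.html#AverageMetric) for you. The cell below is only a minimal, hypothetical sketch: `data` (a `DataBunch`) and `model` are placeholders, not objects defined in this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch (not runnable as-is): `data` and `model` are placeholders for your own objects.\n", "def mean_absolute_error(pred, targ):\n", "    # A plain function metric: fastai averages it over the validation batches.\n", "    return (pred - targ).abs().mean()\n", "\n", "learn = Learner(data, model, metrics=[mean_absolute_error])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here is another example that properly computes the precision for a given class."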
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Precision(Callback):\n", "    # Here the 'given class' is hard-coded as class 0.\n", "    def on_epoch_begin(self, **kwargs):\n", "        self.correct, self.total = 0, 0\n", "\n", "    def on_batch_end(self, last_output, last_target, **kwargs):\n", "        preds = last_output.argmax(1)\n", "        self.correct += ((preds==0) * (last_target==0)).float().sum()\n", "        self.total += (preds==0).float().sum()\n", "\n", "    def on_epoch_end(self, **kwargs):\n", "        self.metric = self.correct/self.total" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Useful metrics for training", "title": "metrics" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }