{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Metrics* for training fastai models are simply functions that take `input` and `target` tensors, and return some metric of interest for training. You can write your own metrics by defining a function of that type, and passing it to [`Learner`](/basic_train.html#Learner) in the [code]metrics[/code] parameter, or use one of the following pre-defined functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predefined metrics:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predefined metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- [`accuracy`](/metrics.html#accuracy)`(input:Tensor, targs:Tensor) → Rank0Tensor`: compute accuracy with `targs` when `input` is `bs * n_classes`.\n",
    "- [`accuracy_thresh`](/metrics.html#accuracy_thresh)`(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True) → Rank0Tensor`\n",
    "- [`dice`](/metrics.html#dice)`(input:Tensor, targs:Tensor, iou:bool=False) → Rank0Tensor`: Dice coefficient metric for binary targets. If `iou=True`, returns the IoU metric instead, classic for segmentation problems.\n",
    "- [`error_rate`](/metrics.html#error_rate)`(input:Tensor, targs:Tensor) → Rank0Tensor`: 1 - [`accuracy`](/metrics.html#accuracy).\n",
    "- [`fbeta`](/metrics.html#fbeta)`(y_pred:Tensor, y_true:Tensor, thresh:float=0.2, beta:float=2, eps:float=1e-09, sigmoid:bool=True) → Rank0Tensor`: computes the F-beta score between `y_pred` and `y_true`.\n",
    "- [`exp_rmspe`](/metrics.html#exp_rmspe)`(pred:Tensor, targ:Tensor) → Rank0Tensor`: exp RMSE between `pred` and `targ`."
   ]
  },
"on_epoch_begin
(for initialization)\n",
"- on_batch_begin
(if we need to have a look at the input/target and maybe modify them)\n",
"- on_batch_end
(to analyze the last results and update our computation)\n",
"- on_epoch_end
(to wrap up the final result that should be stored in `.metric`)\n",
"\n",
"As an example, is here the exact implementation of the [`AverageMetric`](/callback.html#AverageMetric) callback that transforms a function like [`accuracy`](/metrics.html#accuracy) into a metric callback."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AverageMetric(Callback):\n",
" def __init__(self, func):\n",
" self.func, self.name = func, func.__name__\n",
"\n",
" def on_epoch_begin(self, **kwargs):\n",
" self.val, self.count = 0.,0\n",
"\n",
" def on_batch_end(self, last_output, last_target, train, **kwargs):\n",
" self.count += last_target.size(0)\n",
" self.val += last_target.size(0) * self.func(last_output, last_target).detach().item()\n",
"\n",
" def on_epoch_end(self, **kwargs):\n",
" self.metric = self.val/self.count"
]
},
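  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what this wrapper does, here is a small manual walk-through (added purely for illustration; during training fastai drives these callback methods for you). The per-batch values are weighted by batch size, so the epoch metric is an average over samples, not over batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from fastai.metrics import accuracy\n",
    "\n",
    "am = AverageMetric(accuracy)\n",
    "am.on_epoch_begin()\n",
    "# Batch 1: 2 samples, both predicted correctly -> batch accuracy 1.0\n",
    "am.on_batch_end(torch.tensor([[0.9, 0.1], [0.1, 0.9]]), torch.tensor([0, 1]), train=False)\n",
    "# Batch 2: 4 samples, half predicted correctly -> batch accuracy 0.5\n",
    "am.on_batch_end(torch.tensor([[0.8, 0.2]] * 4), torch.tensor([0, 0, 1, 1]), train=False)\n",
    "am.on_epoch_end()\n",
    "am.metric  # (2*1.0 + 4*0.5) / 6 = 0.6667, weighted by batch size"
   ]
  },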
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And here is another example that properly computes the precision for a given class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Precision(Callback):\n",
" \n",
" def on_epoch_begin(self, **kwargs):\n",
" self.correct, self.total = 0, 0\n",
" \n",
" def on_batch_end(self, last_output, last_target, **kwargs):\n",
" preds = last_output.argmax(1)\n",
" self.correct += ((preds==0) * (last_target==0)).float().sum()\n",
" self.total += (preds==0).float().sum()\n",
" \n",
" def on_epoch_end(self, **kwargs):\n",
" self.metric = self.correct/self.total"
]
},
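  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again purely for illustration (this cell is not part of the original example), driving the callback by hand on one small batch shows the computation: out of the samples predicted as class 0, what fraction actually belong to class 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Logits for 3 samples and 2 classes: predicted classes are [0, 1, 0], targets are [0, 0, 1].\n",
    "last_output = torch.tensor([[2.0, 1.0], [0.1, 3.0], [1.5, 0.2]])\n",
    "last_target = torch.tensor([0, 0, 1])\n",
    "\n",
    "p = Precision()\n",
    "p.on_epoch_begin()\n",
    "p.on_batch_end(last_output, last_target)\n",
    "p.on_epoch_end()\n",
    "p.metric  # 2 samples predicted as class 0, 1 of them correct -> tensor(0.5000)"
   ]
  },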
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Undocumented Methods - Methods moved below this line will intentionally be hidden"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## New Methods - Please document or move to the undocumented section"
]
}
],
"metadata": {
"jekyll": {
"keywords": "fastai",
"summary": "Useful metrics for training",
"title": "metrics"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}