{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Training metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Metrics* for training fastai models are simply functions that take `input` and `target` tensors, and return some metric of interest for training. You can write your own metrics by defining a function of that type, and passing it to [`Learner`](/basic_train.html#Learner) in the [`metrics`](/metrics.html#metrics) parameter, or use one of the following pre-defined functions." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.basics import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Predefined metrics:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
accuracy
[source][test]accuracy
(**`input`**:`Tensor`, **`targs`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"accuracy_thresh
[source][test]accuracy_thresh
(**`y_pred`**:`Tensor`, **`y_true`**:`Tensor`, **`thresh`**:`float`=***`0.5`***, **`sigmoid`**:`bool`=***`True`***) → `Rank0Tensor`\n",
"\n",
"top_k_accuracy
[source][test]top_k_accuracy
(**`input`**:`Tensor`, **`targs`**:`Tensor`, **`k`**:`int`=***`5`***) → `Rank0Tensor`\n",
"\n",
"dice
[source][test]dice
(**`input`**:`Tensor`, **`targs`**:`Tensor`, **`iou`**:`bool`=***`False`***, **`eps`**:`float`=***`1e-08`***) → `Rank0Tensor`\n",
"\n",
"error_rate
[source][test]error_rate
(**`input`**:`Tensor`, **`targs`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"mean_squared_error
[source][test]mean_squared_error
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"mean_absolute_error
[source][test]mean_absolute_error
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"mean_squared_logarithmic_error
[source][test]mean_squared_logarithmic_error
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"exp_rmspe
[source][test]exp_rmspe
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"root_mean_squared_error
[source][test]root_mean_squared_error
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"fbeta
[source][test]fbeta
(**`y_pred`**:`Tensor`, **`y_true`**:`Tensor`, **`thresh`**:`float`=***`0.2`***, **`beta`**:`float`=***`2`***, **`eps`**:`float`=***`1e-09`***, **`sigmoid`**:`bool`=***`True`***) → `Rank0Tensor`\n",
"\n",
"explained_variance
[source][test]explained_variance
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"r2_score
[source][test]r2_score
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"class
RMSE
[source][test]RMSE
() :: [`RegMetrics`](/metrics.html#RegMetrics)\n",
"\n",
"No tests found for RMSE
. To contribute a test please refer to this guide and this discussion.
class
ExpRMSPE
[source][test]ExpRMSPE
() :: [`RegMetrics`](/metrics.html#RegMetrics)\n",
"\n",
"No tests found for ExpRMSPE
. To contribute a test please refer to this guide and this discussion.
class
Precision
[source][test]Precision
(**`average`**:`Optional`\\[`str`\\]=***`'binary'`***, **`pos_label`**:`int`=***`1`***, **`eps`**:`float`=***`1e-09`***) :: [`CMScores`](/metrics.html#CMScores)\n",
"\n",
"No tests found for Precision
. To contribute a test please refer to this guide and this discussion.
class
Recall
[source][test]Recall
(**`average`**:`Optional`\\[`str`\\]=***`'binary'`***, **`pos_label`**:`int`=***`1`***, **`eps`**:`float`=***`1e-09`***) :: [`CMScores`](/metrics.html#CMScores)\n",
"\n",
"No tests found for Recall
. To contribute a test please refer to this guide and this discussion.
class
FBeta
[source][test]FBeta
(**`average`**:`Optional`\\[`str`\\]=***`'binary'`***, **`pos_label`**:`int`=***`1`***, **`eps`**:`float`=***`1e-09`***, **`beta`**:`float`=***`2`***) :: [`CMScores`](/metrics.html#CMScores)\n",
"\n",
"No tests found for FBeta
. To contribute a test please refer to this guide and this discussion.
class
R2Score
[source][test]R2Score
() :: [`RegMetrics`](/metrics.html#RegMetrics)\n",
"\n",
"No tests found for R2Score
. To contribute a test please refer to this guide and this discussion.
class
ExplainedVariance
[source][test]ExplainedVariance
() :: [`RegMetrics`](/metrics.html#RegMetrics)\n",
"\n",
"No tests found for ExplainedVariance
. To contribute a test please refer to this guide and this discussion.
class
MatthewsCorreff
[source][test]MatthewsCorreff
() :: [`ConfusionMatrix`](/metrics.html#ConfusionMatrix)\n",
"\n",
"No tests found for MatthewsCorreff
. To contribute a test please refer to this guide and this discussion.
class
KappaScore
[source][test]KappaScore
(**`weights`**:`Optional`\\[`str`\\]=***`None`***) :: [`ConfusionMatrix`](/metrics.html#ConfusionMatrix)\n",
"\n",
"No tests found for KappaScore
. To contribute a test please refer to this guide and this discussion.
class
ConfusionMatrix
[source][test]ConfusionMatrix
() :: [`Callback`](/callback.html#Callback)\n",
"\n",
"No tests found for ConfusionMatrix
. To contribute a test please refer to this guide and this discussion.
class
MultiLabelFbeta
[source][test]MultiLabelFbeta
(**`beta`**=***`2`***, **`eps`**=***`1e-15`***, **`thresh`**=***`0.3`***, **`sigmoid`**=***`True`***, **`average`**=***`'micro'`***) :: [`Callback`](/callback.html#Callback)\n",
"\n",
"No tests found for MultiLabelFbeta
. To contribute a test please refer to this guide and this discussion.
auc_roc_score
[source][test]auc_roc_score
(**`input`**:`Tensor`, **`targ`**:`Tensor`)\n",
"\n",
"No tests found for auc_roc_score
. To contribute a test please refer to this guide and this discussion.
roc_curve
[source][test]roc_curve
(**`input`**:`Tensor`, **`targ`**:`Tensor`)\n",
"\n",
"No tests found for roc_curve
. To contribute a test please refer to this guide and this discussion.
class
AUROC
[source][test]AUROC
() :: [`Callback`](/callback.html#Callback)\n",
"\n",
"No tests found for AUROC
. To contribute a test please refer to this guide and this discussion.
on_epoch_begin
(for initialization)\n",
"- on_batch_begin
(if we need to have a look at the input/target and maybe modify them)\n",
"- on_batch_end
(to analyze the last results and update our computation)\n",
"- on_epoch_end
(to wrap up the final result that should be added to `last_metrics`)\n",
"\n",
"As an example, the following code is the exact implementation of the [`AverageMetric`](/callback.html#AverageMetric) callback that transforms a function like [`accuracy`](/metrics.html#accuracy) into a metric callback."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class AverageMetric(Callback):\n",
" \"Wrap a `func` in a callback for metrics computation.\"\n",
" def __init__(self, func):\n",
" # If it's a partial, use func.func\n",
" name = getattr(func,'func',func).__name__\n",
" self.func, self.name = func, name\n",
"\n",
" def on_epoch_begin(self, **kwargs):\n",
" \"Set the inner value to 0.\"\n",
" self.val, self.count = 0.,0\n",
"\n",
" def on_batch_end(self, last_output, last_target, **kwargs):\n",
" \"Update metric computation with `last_output` and `last_target`.\"\n",
" if not is_listy(last_target): last_target=[last_target]\n",
" self.count += last_target[0].size(0)\n",
" val = self.func(last_output, *last_target)\n",
" self.val += last_target[0].size(0) * val.detach().cpu()\n",
"\n",
" def on_epoch_end(self, last_metrics, **kwargs):\n",
" \"Set the final result in `last_metrics`.\"\n",
" return add_metrics(last_metrics, self.val/self.count)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here [`add_metrics`](/torch_core.html#add_metrics) is a convenience function that will return the proper dictionary for us:\n",
"```python\n",
"{'last_metrics': last_metrics + [self.val/self.count]}\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And here is another example that properly computes the precision for a given class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Precision(Callback):\n",
"    \"Precision for one class: correct predictions of `pos_label` divided by all predictions of `pos_label`.\"\n",
"    def __init__(self, pos_label=0):\n",
"        super().__init__()\n",
"        # the default of 0 keeps `Precision()` equivalent to the earlier example, which was hard-coded to class 0\n",
"        self.pos_label = pos_label\n",
"\n",
"    def on_epoch_begin(self, **kwargs):\n",
"        \"Reset the counters at the start of each epoch.\"\n",
"        self.correct, self.total = 0, 0\n",
"\n",
"    def on_batch_end(self, last_output, last_target, **kwargs):\n",
"        \"Accumulate true positives and predicted positives for `pos_label`.\"\n",
"        preds = last_output.argmax(1)\n",
"        predicted_pos = preds == self.pos_label\n",
"        self.correct += (predicted_pos & (last_target == self.pos_label)).float().sum()\n",
"        self.total += predicted_pos.float().sum()\n",
"\n",
"    def on_epoch_end(self, last_metrics, **kwargs):\n",
"        \"Add the epoch's precision to `last_metrics`; 0 if `pos_label` was never predicted.\"\n",
"        # without this guard, zero predictions of `pos_label` gives nan (or ZeroDivisionError if no batch ran)\n",
"        precision = self.correct/self.total if self.total > 0 else 0.\n",
"        return add_metrics(last_metrics, precision)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following custom metric callback measures the peak memory allocated by Python (as traced by the standard-library `tracemalloc` module) during each epoch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tracemalloc\n",
"class TraceMallocMetric(Callback):\n",
"    \"Metric reporting the peak memory traced by `tracemalloc` (blocks allocated by Python) during each epoch.\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.name = \"peak RAM\"\n",
"\n",
"    def on_epoch_begin(self, **kwargs):\n",
"        \"Start tracing memory allocations for this epoch.\"\n",
"        tracemalloc.start()\n",
"\n",
"    def on_epoch_end(self, last_metrics, **kwargs):\n",
"        \"Stop tracing and add the epoch's peak traced memory (in bytes) to `last_metrics`.\"\n",
"        _, peak = tracemalloc.get_traced_memory()  # only the peak is reported\n",
"        tracemalloc.stop()\n",
"        return add_metrics(last_metrics, torch.tensor(peak))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To deploy it, you need to pass an instance of this custom metric in the [`metrics`](/metrics.html#metrics) argument:\n",
"\n",
"```python\n",
"learn = cnn_learner(data, model, metrics=[accuracy, TraceMallocMetric()])\n",
"learn.fit_one_cycle(3, max_lr=1e-2)\n",
"```\n",
"And then the output changes to:\n",
"```\n",
"Total time: 00:54\n",
"epoch\ttrain_loss\tvalid_loss\taccuracy\tpeak RAM\n",
" 1\t0.333352\t0.084342\t0.973800\t2395541.000000\n",
" 2\t0.096196\t0.038386\t0.988300\t2342145.000000\n",
" 3\t0.048722\t0.029234\t0.990200\t2342680.000000\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As mentioned earlier, a custom metric class passed through the [`metrics`](/metrics.html#metrics) argument is limited: it can only access some of the phases of the callback system, it can only return one numerical value, and, as you can see, its output is hardcoded to 6 digits of precision, even if the number is an int.\n",
"\n",
"To overcome these limitations, use a regular callback class instead.\n",
"\n",
"For example, the following class: \n",
"* uses phases not available to metric classes\n",
"* reports 3 columns instead of just one\n",
"* reports ints instead of floats in its columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tracemalloc\n",
"class TraceMallocMultiColMetric(LearnerCallback):\n",
"    \"Report `tracemalloc` memory stats for each epoch in 3 extra columns: used, max_used and peak.\"\n",
"    _order=-20 # Needs to run before the recorder\n",
"    def __init__(self, learn):\n",
"        super().__init__(learn)\n",
"        self.train_max = 0\n",
"\n",
"    def on_train_begin(self, **kwargs):\n",
"        \"Register the three extra column names with the recorder.\"\n",
"        self.learn.recorder.add_metric_names(['used', 'max_used', 'peak'])\n",
"\n",
"    def on_epoch_begin(self, **kwargs):\n",
"        \"Reset the per-epoch training maximum and start tracing memory allocations.\"\n",
"        # reset every epoch: otherwise max_used carries over from earlier epochs (and fit calls) and can exceed this epoch's peak\n",
"        self.train_max = 0\n",
"        tracemalloc.start()\n",
"\n",
"    def on_batch_end(self, train, **kwargs):\n",
"        \"Update the largest traced memory size seen after a training batch.\"\n",
"        # track max memory usage during the train phase\n",
"        if train:\n",
"            current, _ = tracemalloc.get_traced_memory()\n",
"            self.train_max = max(self.train_max, current)\n",
"\n",
"    def on_epoch_end(self, last_metrics, **kwargs):\n",
"        \"Stop tracing and add current, max-during-training and peak traced memory (bytes) to `last_metrics`.\"\n",
"        current, peak = tracemalloc.get_traced_memory()\n",
"        tracemalloc.stop()\n",
"        return add_metrics(last_metrics, [current, self.train_max, peak])\n",
"\n",
"    def on_train_end(self, **kwargs):\n",
"        \"Stop tracing if training ended mid-epoch (e.g. it was interrupted), so tracing doesn't keep running.\"\n",
"        if tracemalloc.is_tracing(): tracemalloc.stop()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that it subclasses [`LearnerCallback`](/basic_train.html#LearnerCallback) and not [`Callback`](/callback.html#Callback), since the former provides extra features not available in the latter.\n",
"\n",
"Also, `_order=-20` is crucial: it tells the callback system to run this callback before the recorder, and without it the custom columns will not be added.\n",
"\n",
"To deploy it, you need to pass the class itself (not an instance!) in the `callback_fns` argument. This is because instantiating `TraceMallocMultiColMetric` requires the `learn` object, which doesn't exist yet; the system will instantiate it for us automatically as soon as the `learn` object has been created.\n",
"\n",
"```python\n",
"learn = cnn_learner(data, model, metrics=[accuracy], callback_fns=TraceMallocMultiColMetric)\n",
"learn.fit_one_cycle(3, max_lr=1e-2)\n",
"```\n",
"And then the output changes to:\n",
"```\n",
"Total time: 00:53\n",
"epoch\ttrain_loss valid_loss accuracy\t used\tmax_used peak\n",
" 1\t0.321233\t0.068252\t0.978600\t156504\t2408404\t 2419891 \n",
" 2\t0.093551\t0.032776\t0.988500\t 79343\t2408404\t 2348085\n",
" 3\t0.047178\t0.025307\t0.992100\t 79568\t2408404\t 2342754\n",
"```\n",
"\n",
"Another way to do the same is to use `learn.callbacks.append`. This time we need to instantiate `TraceMallocMultiColMetric` ourselves with the `learn` object, which we now have, since this call comes after the learner was created:\n",
"\n",
"```python\n",
"learn = cnn_learner(data, model, metrics=[accuracy])\n",
"learn.callbacks.append(TraceMallocMultiColMetric(learn))\n",
"learn.fit_one_cycle(3, max_lr=1e-2)\n",
"```\n",
"\n",
"Configuring the custom metrics in the `learn` object sets them to run in all future [`fit`](/basic_train.html#fit)-family calls. However, if you'd like to configure it for just one call, you can configure it directly inside [`fit`](/basic_train.html#fit) or [`fit_one_cycle`](/train.html#fit_one_cycle):\n",
"\n",
"```python\n",
"learn = cnn_learner(data, model, metrics=[accuracy])\n",
"learn.fit_one_cycle(3, max_lr=1e-2, callbacks=TraceMallocMultiColMetric(learn))\n",
"```\n",
"\n",
"And to stress the differences: \n",
"* the `callback_fns` argument expects a class (not an instance) or a list of classes\n",
"* the [`callbacks`](/callbacks.html#callbacks) argument expects an instance of a class or a list of those\n",
"* `learn.callbacks.append` expects a single instance of a class\n",
"\n",
"For more examples, look inside the fastai codebase and its test suite, and search for classes that subclass [`Callback`](/callback.html#Callback), [`LearnerCallback`](/basic_train.html#LearnerCallback), or subclasses of those two.\n",
"\n",
"Finally, while the above examples all add to the metrics, this is not a requirement: a callback can do anything it wants and doesn't have to add its outcomes to the metrics printout."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Undocumented Methods - Methods moved below this line will intentionally be hidden"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"on_batch_end
[source][test]on_batch_end
(**`last_output`**:`Tensor`, **`last_target`**:`Tensor`, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_batch_end
. To contribute a test please refer to this guide and this discussion.
on_epoch_begin
[source][test]on_epoch_begin
(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_begin
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
mean_absolute_error
[source][test]mean_absolute_error
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"mean_squared_logarithmic_error
[source][test]mean_squared_logarithmic_error
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"mean_squared_error
[source][test]mean_squared_error
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"root_mean_squared_error
[source][test]root_mean_squared_error
(**`pred`**:`Tensor`, **`targ`**:`Tensor`) → `Rank0Tensor`\n",
"\n",
"on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
on_train_end
[source][test]on_train_end
(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_train_end
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
on_train_begin
[source][test]on_train_begin
(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_train_begin
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
on_train_begin
[source][test]on_train_begin
(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_train_begin
. To contribute a test please refer to this guide and this discussion.
on_batch_end
[source][test]on_batch_end
(**`last_output`**:`Tensor`, **`last_target`**:`Tensor`, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_batch_end
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.
on_epoch_begin
[source][test]on_epoch_begin
(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_begin
. To contribute a test please refer to this guide and this discussion.
on_epoch_end
[source][test]on_epoch_end
(**`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
. To contribute a test please refer to this guide and this discussion.