In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab

In [None]:
#default_exp callback.preds

In [None]:
#export
from fastai.basics import *

In [None]:
#hide
from nbdev.showdoc import *
from fastai.test_utils import *

# Predictions callbacks

> Various callbacks to customize the behavior of `get_preds`

## MCDropoutCallback

> Turns on dropout during inference, allowing you to call `Learner.get_preds` multiple times to approximate your model's uncertainty using [Monte Carlo Dropout](https://arxiv.org/pdf/1506.02142.pdf).

In [None]:
#export
class MCDropoutCallback(Callback):
    "Keep dropout active during validation/inference so repeated `get_preds` calls act as Monte Carlo samples"
    def _dropout_modules(self):
        "Lazily yield the modules of `flatten_model(self.model)` whose class name contains 'dropout' (case-insensitive)"
        return (m for m in flatten_model(self.model) if 'dropout' in m.__class__.__name__.lower())

    def before_validate(self):
        # Only dropout modules are switched to training mode; every other layer keeps the mode the learner set.
        for m in self._dropout_modules(): m.train()

    def after_validate(self):
        # Put the same dropout modules back into eval mode once the validation/inference pass is over.
        for m in self._dropout_modules(): m.eval()

In [None]:
learn = synth_learner()

# Each `get_preds` call with dropout enabled yields one Monte Carlo sample of the predictions.
# Stacking them gives a tensor of shape [n_mc_samples, n_items, ...]. The second axis spans every
# item `get_preds` predicts on (the whole dataset), not a single batch.
# NOTE(review): if the synthetic model has no dropout layers, all samples will be identical -- confirm.
n_mc_samples = 10
dist_preds = [learn.get_preds(cbs=[MCDropoutCallback()])[0] for _ in range(n_mc_samples)]

torch.stack(dist_preds).shape

torch.Size([10, 32, 1])

## Export -

In [None]:
#hide
# Convert the `#export` cells of the project's notebooks into library modules
# (the cell output lists each notebook converted).
from nbdev.export import notebook2script
notebook2script()

Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.l