{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#all_slow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai2.vision.all import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial - Migrating from pure PyTorch\n", "\n", "> Incrementally adding fastai goodness to your PyTorch models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We're going to use the MNIST training code from the official PyTorch examples, slightly reformatted for space, updated from AdaDelta to AdamW, and converted from a script to a module. There's a lot of code, so we've put it into migrating_pytorch.py!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from migrating_pytorch import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can entirely replace the custom training loop with fastai's. That means you can get rid of `train()`, `test()`, and the epoch loop in the original code, and replace it all with just this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataLoaders(train_loader, test_loader)\n", "learn = Learner(data, Net(), loss_func=F.nll_loss, opt_func=Adam, metrics=accuracy, cbs=CudaCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also added `CudaCallback` to have the model and data moved to the GPU for us. Alternatively, you can use the fastai `DataLoader`, which provides a superset of the functionality of PyTorch's (with the same API), and can handle moving data to the GPU for us (see `migrating_ignite.ipynb` for an example of this approach). \n", "\n", "fastai supports many schedulers. We recommend fitting with 1cycle training:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.129090 | 0.052289 | 0.982600 | 00:17 |
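To give a feel for what 1cycle training does, here is a minimal sketch of the learning-rate schedule in plain Python: a cosine warmup from a low starting rate up to the peak, followed by a cosine annealing down to a very small final rate. The default values (`div=25`, `div_final=1e5`, `pct_start=0.25`) are assumptions mirroring fastai's `fit_one_cycle` defaults, not a specification; the real callback also cycles momentum in the opposite direction.

```python
import math

def one_cycle_lr(pct, lr_max=1e-2, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pct in [0, 1] for a 1cycle schedule.

    Cosine warmup from lr_max/div up to lr_max over the first pct_start of
    training, then cosine annealing down to lr_max/div_final.
    """
    def cos_anneal(start, end, frac):
        # Cosine interpolation from start to end as frac goes 0 -> 1.
        return start + (end - start) * (1 - math.cos(math.pi * frac)) / 2

    if pct < pct_start:
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))

# The schedule starts low, peaks at lr_max a quarter of the way in, then decays:
print(one_cycle_lr(0.0))   # lr_max / div
print(one_cycle_lr(0.25))  # lr_max (the peak)
print(one_cycle_lr(1.0))   # lr_max / div_final
```

The brief warmup lets training start safely at a low rate, the high middle portion acts as a regularizer, and the long annealing phase lets the model settle into a good minimum.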