(lightning_mnist_example)=

# Train a Pytorch Lightning Image Classifier

This example introduces how to train a Pytorch Lightning Module using Ray Train {class}`TorchTrainer <ray.train.torch.TorchTrainer>`. We will demonstrate how to train a basic neural network on the MNIST dataset with distributed data parallelism.


In [1]:
!pip install "torchmetrics>=0.9" "pytorch_lightning>=1.6" 

In [2]:
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import trainer
from pytorch_lightning.loggers.csv_logs import CSVLogger

## Prepare Dataset and Module

The Pytorch Lightning Trainer takes either `torch.utils.data.DataLoader` or `pl.LightningDataModule` as data inputs. You can keep using them without any changes for the Ray AIR LightningTrainer. 

In [4]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )

            # split data into train and val sets
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        with FileLock(f"{self.data_dir}.lock"):
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

Next, define a simple multi-layer perception as the subclass of `pl.LightningModule`.

In [5]:
class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr=1e-3, feature_dim=128):
        torch.manual_seed(421)
        super(MNISTClassifier, self).__init__()
        self.save_hyperparameters()

        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 10),
            nn.ReLU(),
        )
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.eval_loss = []
        self.eval_accuracy = []
        self.test_accuracy = []
        pl.seed_everything(888)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.linear_relu_stack(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss, acc = self._shared_eval(val_batch)
        self.log("val_accuracy", acc)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def test_step(self, test_batch, batch_idx):
        loss, acc = self._shared_eval(test_batch)
        self.test_accuracy.append(acc)
        self.log("test_accuracy", acc, sync_dist=True, on_epoch=True)
        return {"test_loss": loss, "test_accuracy": acc}

    def _shared_eval(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

You don't need to make any change to the definition of PyTorch Lightning model and datamodule.

## Define the Training Loop

Here we define a training loop for each worker. Compare with the original PyTorch Lightning code, there are 3 main differences:

- Distributed strategy: Use {class}`RayDDPStrategy <ray.train.lightning.RayDDPStrategy>`.
- Cluster environment: Use {class}`RayLightningEnvironment <ray.train.lightning.RayLightningEnvironment>`.
- Parallel devices: Always sets to `devices="auto"` to use all available devices configured by ``TorchTrainer``.

Please refer to {ref}`Getting Started with PyTorch Lightning <train-pytorch-lightning>`.


For checkpoint reportining, Ray Train provides a minimal {class}`RayTrainReportCallback <ray.train.lightning.RayTrainReportCallback>` that reports metrics and checkpoint on each train epoch end. For more complex checkpoint logic, please implement custom callbacks as described in {ref}`Saving and Loading Checkpoint <train-checkpointing>` user guide.

In [6]:
use_gpu = True # Set it to False if you want to run without GPUs
num_workers = 4

In [7]:
import pytorch_lightning as pl
from ray.train import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)

def train_func_per_worker():
    model = MNISTClassifier(lr=1e-3, feature_dim=128)
    datamodule = MNISTDataModule(batch_size=128)

    trainer = pl.Trainer(
        devices="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        max_epochs=10,
        accelerator="gpu" if use_gpu else "cpu",
        log_every_n_steps=100,
        logger=CSVLogger("logs"),
    )
    
    trainer = prepare_trainer(trainer)
    
    # Train model
    trainer.fit(model, datamodule=datamodule)

    # Evaluation on the test dataset
    trainer.test(model, datamodule=datamodule)

Now put everything together:

In [8]:
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

run_config = RunConfig(
    name="ptl-mnist-example",
    storage_path="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,
        checkpoint_score_attribute="val_accuracy",
        checkpoint_score_order="max",
    ),
)

trainer = TorchTrainer(
    train_func_per_worker,
    scaling_config=scaling_config,
    run_config=run_config,
)

Now fit your trainer:

In [9]:
result = trainer.fit()

0,1
Current time:,2023-08-07 23:41:11
Running for:,00:00:39.80
Memory:,24.2/186.6 GiB

Trial name,status,loc,iter,total time (s),train_loss,val_accuracy,val_loss
TorchTrainer_78346_00000,TERMINATED,10.0.6.244:120026,10,29.0221,0.0315938,0.970002,-12.3466


[2m[36m(TorchTrainer pid=120026)[0m Starting distributed worker processes: ['120176 (10.0.6.244)', '120177 (10.0.6.244)', '120178 (10.0.6.244)', '120179 (10.0.6.244)']
[2m[36m(RayTrainWorker pid=120176)[0m Setting up process group for: env:// [rank=0, world_size=4]
[2m[36m(RayTrainWorker pid=120176)[0m [rank: 0] Global seed set to 888
[2m[36m(RayTrainWorker pid=120176)[0m GPU available: True (cuda), used: True
[2m[36m(RayTrainWorker pid=120176)[0m TPU available: False, using: 0 TPU cores
[2m[36m(RayTrainWorker pid=120176)[0m IPU available: False, using: 0 IPUs
[2m[36m(RayTrainWorker pid=120176)[0m HPU available: False, using: 0 HPUs


[2m[36m(RayTrainWorker pid=120178)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=120178)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]
100%|██████████| 9912422/9912422 [00:00<00:00, 94562894.32it/s]
  9%|▉         | 917504/9912422 [00:00<00:00, 9166590.91it/s]
100%|██████████| 9912422/9912422 [00:00<00:00, 115619443.32it/s]


[2m[36m(RayTrainWorker pid=120179)[0m Extracting /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_3/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_3/MNIST/raw
[2m[36m(RayTrainWorker pid=120176)[0m 


[2m[36m(RayTrainWorker pid=120177)[0m Missing logger folder: logs/lightning_logs
[2m[36m(RayTrainWorker pid=120176)[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
[2m[36m(RayTrainWorker pid=120176)[0m   | Name              | Type               | Params
[2m[36m(RayTrainWorker pid=120176)[0m ---------------------------------------------------------
[2m[36m(RayTrainWorker pid=120176)[0m 0 | linear_relu_stack | Sequential         | 101 K 
[2m[36m(RayTrainWorker pid=120176)[0m 1 | accuracy          | MulticlassAccuracy | 0     
[2m[36m(RayTrainWorker pid=120176)[0m ---------------------------------------------------------
[2m[36m(RayTrainWorker pid=120176)[0m 101 K     Trainable params
[2m[36m(RayTrainWorker pid=120176)[0m 0         Non-trainable params
[2m[36m(RayTrainWorker pid=120176)[0m 101 K     Total params
[2m[36m(RayTrainWorker pid=120176)[0m 0.407     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s])[0m 
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  2.69it/s]


[2m[36m(RayTrainWorker pid=120179)[0m [rank: 3] Global seed set to 888[32m [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m


Epoch 0:   0%|          | 0/108 [00:00<?, ?it/s] 
Epoch 0:  12%|█▏        | 13/108 [00:00<00:02, 39.35it/s, v_num=0]
Epoch 0:  25%|██▌       | 27/108 [00:00<00:01, 59.26it/s, v_num=0]
Epoch 0:  26%|██▌       | 28/108 [00:00<00:01, 61.03it/s, v_num=0]
Epoch 0:  27%|██▋       | 29/108 [00:00<00:01, 62.76it/s, v_num=0]
Epoch 0:  42%|████▏     | 45/108 [00:00<00:00, 81.02it/s, v_num=0]
Epoch 0:  53%|█████▎    | 57/108 [00:00<00:00, 86.01it/s, v_num=0]
Epoch 0:  64%|██████▍   | 69/108 [00:00<00:00, 88.63it/s, v_num=0]
Epoch 0:  81%|████████  | 87/108 [00:00<00:00, 98.04it/s, v_num=0]
Epoch 0:  81%|████████▏ | 88/108 [00:00<00:00, 98.69it/s, v_num=0]
Epoch 0:  82%|████████▏ | 89/108 [00:00<00:00, 99.34it/s, v_num=0]
Epoch 0:  96%|█████████▋| 104/108 [00:00<00:00, 104.14it/s, v_num=0]
Epoch 0:  97%|█████████▋| 105/108 [00:01<00:00, 104.71it/s, v_num=0]
Epoch 0:  98%|█████████▊| 106/108 [00:01<00:00, 105.22it/s, v_num=0]
Epoch 0: 100%|██████████| 108/108 [00:01<00:00, 105.79it/s, v_num=0]
Vali

[2m[36m(RayTrainWorker pid=120176)[0m `Trainer.fit` stopped: `max_epochs=10` reached.
100%|██████████| 4542/4542 [00:00<00:00, 48474627.91it/s][32m [repeated 14x across cluster][0m
100%|██████████| 9912422/9912422 [00:00<00:00, 90032420.31it/s][32m [repeated 2x across cluster][0m
[2m[36m(RayTrainWorker pid=120176)[0m [32m [repeated 5x across cluster][0m
[2m[36m(RayTrainWorker pid=120178)[0m Missing logger folder: logs/lightning_logs[32m [repeated 2x across cluster][0m
[2m[36m(RayTrainWorker pid=120179)[0m LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3][32m [repeated 3x across cluster][0m
[2m[36m(RayTrainWorker pid=120176)[0m [rank: 0] Global seed set to 888


Epoch 9: 100%|██████████| 108/108 [00:01<00:00, 66.61it/s, v_num=0]


[2m[36m(RayTrainWorker pid=120176)[0m   rank_zero_warn(


Testing DataLoader 0:  25%|██▌       | 5/20 [00:00<00:00, 146.57it/s]
Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 163.98it/s]
Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 125.34it/s]
[2m[36m(RayTrainWorker pid=120176)[0m ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
[2m[36m(RayTrainWorker pid=120176)[0m ┃        Test metric        ┃       DataLoader 0        ┃
[2m[36m(RayTrainWorker pid=120176)[0m ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
[2m[36m(RayTrainWorker pid=120176)[0m │       test_accuracy       │    0.9740999937057495     │
[2m[36m(RayTrainWorker pid=120176)[0m └───────────────────────────┴───────────────────────────┘


2023-08-07 23:41:11,072	INFO tune.py:1145 -- Total run time: 39.92 seconds (39.80 seconds for the tuning loop).


## Check the Training Results and Checkpoints

In [10]:
result

Result(
  metrics={'train_loss': 0.03159375861287117, 'val_accuracy': 0.9700015783309937, 'val_loss': -12.346583366394043, 'epoch': 9, 'step': 1080},
  path='/tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31',
  checkpoint=LegacyTorchCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/checkpoint_000009)
)

In [11]:
print("Validation Accuracy: ", result.metrics["val_accuracy"])
print("Trial Directory: ", result.path)
print(sorted(os.listdir(result.path)))

Validation Accuracy:  0.9700015783309937
Trial Directory:  /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31
['checkpoint_000007', 'checkpoint_000008', 'checkpoint_000009', 'events.out.tfevents.1691476838.ip-10-0-6-244', 'params.json', 'params.pkl', 'progress.csv', 'rank_0', 'rank_0.lock', 'rank_1', 'rank_1.lock', 'rank_2', 'rank_2.lock', 'rank_3', 'rank_3.lock', 'result.json']


As we can see, three checkpoints(`checkpoint_000007`, `checkpoint_000008`, `checkpoint_000009`) have been saved in the trial directory. To retrieve the latest checkpoint from the fit results and load it back into the model, follow these steps.

If you lost the in-memory result object, you can also restore the model from the checkpoint file. Here the checkpoint path is: `/tmp/ray_results/ptl-mnist-example/TorchTrainer_eb925_00000_0_2023-08-07_23-15-06/checkpoint_000009/checkpoint.ckpt`.

In [12]:
checkpoint = result.checkpoint

with checkpoint.as_directory() as ckpt_dir:
    best_model = MNISTClassifier.load_from_checkpoint(f"{ckpt_dir}/checkpoint.ckpt")

best_model

Global seed set to 888


MNISTClassifier(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=10, bias=True)
    (3): ReLU()
  )
  (accuracy): MulticlassAccuracy()
)