{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Train a PyTorch Lightning Image Classifier\n", "\n", "This example shows how to train a PyTorch Lightning module using Ray Train {class}`TorchTrainer `. It trains a basic neural network on the MNIST dataset with distributed data parallelism.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install \"torchmetrics>=0.9\" \"pytorch_lightning>=1.6\" " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import random\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from filelock import FileLock\n", "from torch.utils.data import DataLoader, random_split, Subset\n", "from torchmetrics import Accuracy\n", "from torchvision.datasets import MNIST\n", "from torchvision import transforms\n", "\n", "import pytorch_lightning as pl\n", "from pytorch_lightning import trainer\n", "from pytorch_lightning.loggers.csv_logs import CSVLogger" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare a dataset and module\n", "\n", "The PyTorch Lightning Trainer accepts either a `torch.utils.data.DataLoader` or a `pl.LightningDataModule` as data input. You can keep using them unchanged with Ray Train. 
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class MNISTDataModule(pl.LightningDataModule):\n", "    def __init__(self, batch_size=100):\n", "        super().__init__()\n", "        self.data_dir = os.getcwd()\n", "        self.batch_size = batch_size\n", "        self.transform = transforms.Compose(\n", "            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n", "        )\n", "\n", "    def setup(self, stage=None):\n", "        with FileLock(f\"{self.data_dir}.lock\"):\n", "            mnist = MNIST(\n", "                self.data_dir, train=True, download=True, transform=self.transform\n", "            )\n", "\n", "        # Split data into train and val sets\n", "        self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])\n", "\n", "    def train_dataloader(self):\n", "        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)\n", "\n", "    def val_dataloader(self):\n", "        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)\n", "\n", "    def test_dataloader(self):\n", "        with FileLock(f\"{self.data_dir}.lock\"):\n", "            self.mnist_test = MNIST(\n", "                self.data_dir, train=False, download=True, transform=self.transform\n", "            )\n", "        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Next, define a simple multi-layer perceptron (MLP) as a subclass of `pl.LightningModule`." 
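, "\n", "\n",
"As a size check, the MLP maps each `28 * 28 = 784`-pixel image to `feature_dim` hidden units and then to 10 class scores. The following plain-Python sketch counts its trainable parameters. It assumes the default `feature_dim=128` and 4-byte `float32` weights. The `ReLU` layers and the `Accuracy` metric add no trainable parameters. The result agrees with the `101 K` trainable parameters and `0.407` MB that Lightning reports in the training log further down.\n",
"\n",
"```python\n",
"def linear_params(in_features, out_features):\n",
"    # nn.Linear stores an (out x in) weight matrix plus one bias per output\n",
"    return in_features * out_features + out_features\n",
"\n",
"feature_dim = 128  # default used by MNISTClassifier\n",
"total = linear_params(28 * 28, feature_dim) + linear_params(feature_dim, 10)\n",
"size_mb = total * 4 / 1e6  # float32 parameters take 4 bytes each\n",
"\n",
"print(total)              # 101770\n",
"print(round(size_mb, 3))  # 0.407\n",
"```"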
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class MNISTClassifier(pl.LightningModule):\n", "    def __init__(self, lr=1e-3, feature_dim=128):\n", "        torch.manual_seed(421)\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "\n", "        self.linear_relu_stack = nn.Sequential(\n", "            nn.Linear(28 * 28, feature_dim),\n", "            nn.ReLU(),\n", "            nn.Linear(feature_dim, 10),\n", "            nn.ReLU(),\n", "        )\n", "        self.lr = lr\n", "        self.accuracy = Accuracy(task=\"multiclass\", num_classes=10, top_k=1)\n", "        self.eval_loss = []\n", "        self.eval_accuracy = []\n", "        self.test_accuracy = []\n", "        pl.seed_everything(888)\n", "\n", "    def forward(self, x):\n", "        x = x.view(-1, 28 * 28)\n", "        x = self.linear_relu_stack(x)\n", "        return x\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        y_hat = self(x)\n", "        loss = F.cross_entropy(y_hat, y)\n", "        self.log(\"train_loss\", loss)\n", "        return loss\n", "\n", "    def validation_step(self, val_batch, batch_idx):\n", "        loss, acc = self._shared_eval(val_batch)\n", "        self.log(\"val_accuracy\", acc)\n", "        self.eval_loss.append(loss)\n", "        self.eval_accuracy.append(acc)\n", "        return {\"val_loss\": loss, \"val_accuracy\": acc}\n", "\n", "    def test_step(self, test_batch, batch_idx):\n", "        loss, acc = self._shared_eval(test_batch)\n", "        self.test_accuracy.append(acc)\n", "        self.log(\"test_accuracy\", acc, sync_dist=True, on_epoch=True)\n", "        return {\"test_loss\": loss, \"test_accuracy\": acc}\n", "\n", "    def _shared_eval(self, batch):\n", "        x, y = batch\n", "        logits = self.forward(x)\n", "        # The model outputs raw scores, not log-probabilities, so use the same\n", "        # cross-entropy loss as training instead of F.nll_loss\n", "        loss = F.cross_entropy(logits, y)\n", "        acc = self.accuracy(logits, y)\n", "        return loss, acc\n", "\n", "    def on_validation_epoch_end(self):\n", "        avg_loss = torch.stack(self.eval_loss).mean()\n", "        avg_acc = torch.stack(self.eval_accuracy).mean()\n", "        self.log(\"val_loss\", avg_loss, sync_dist=True)\n", "        self.log(\"val_accuracy\", avg_acc, sync_dist=True)\n", "        self.eval_loss.clear()\n", "        self.eval_accuracy.clear()\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n", "        return optimizer" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "You don't need to modify the definition of the PyTorch Lightning model or datamodule." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Define a training function\n", "\n", "This code defines a {ref}`training function ` for each worker. Compared with the original PyTorch Lightning code, the training function has three main differences:\n", "\n", "- Distributed strategy: Use {class}`RayDDPStrategy `.\n", "- Cluster environment: Use {class}`RayLightningEnvironment `.\n", "- Parallel devices: Always set `devices=\"auto\"` to use all available devices configured by ``TorchTrainer``.\n", "\n", "The code also passes the trainer through `prepare_trainer` before calling `fit`, which prepares it for distributed execution with Ray Train. See {ref}`Getting Started with PyTorch Lightning ` for more information.\n", "\n", "For checkpoint reporting, Ray Train provides a minimal {class}`RayTrainReportCallback ` class that reports metrics and checkpoints at the end of each training epoch. For more complex checkpoint logic, implement a custom callback. See {ref}`Saving and Loading Checkpoint `." 
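, "\n", "\n",
"`RayDDPStrategy` runs distributed data-parallel training, so Lightning shards each dataloader across the workers with a `DistributedSampler`. Each worker therefore iterates over roughly `1 / num_workers` of the data per epoch. The plain-Python sketch below estimates the number of batches per worker. It assumes the values this example uses: 55,000 training and 5,000 validation samples, `batch_size=128`, and four workers. It also assumes the `DistributedSampler` default of padding every shard to equal size. The result matches the 108 training steps and 10 validation steps per epoch shown in the progress bars of the training log.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def batches_per_worker(num_samples, num_workers, batch_size):\n",
"    # DistributedSampler (drop_last=False) pads so every worker gets an equal shard\n",
"    shard_size = math.ceil(num_samples / num_workers)\n",
"    # The DataLoader keeps the last partial batch by default (drop_last=False)\n",
"    return math.ceil(shard_size / batch_size)\n",
"\n",
"train_steps = batches_per_worker(55000, num_workers=4, batch_size=128)\n",
"val_steps = batches_per_worker(5000, num_workers=4, batch_size=128)\n",
"print(train_steps, val_steps)  # 108 10\n",
"```"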
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "use_gpu = True  # Set to False if you want to run without GPUs\n", "num_workers = 4" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "from ray.train import RunConfig, ScalingConfig, CheckpointConfig\n", "from ray.train.torch import TorchTrainer\n", "from ray.train.lightning import (\n", "    RayDDPStrategy,\n", "    RayLightningEnvironment,\n", "    RayTrainReportCallback,\n", "    prepare_trainer,\n", ")\n", "\n", "def train_func_per_worker():\n", "    model = MNISTClassifier(lr=1e-3, feature_dim=128)\n", "    datamodule = MNISTDataModule(batch_size=128)\n", "\n", "    trainer = pl.Trainer(\n", "        devices=\"auto\",\n", "        strategy=RayDDPStrategy(),\n", "        plugins=[RayLightningEnvironment()],\n", "        callbacks=[RayTrainReportCallback()],\n", "        max_epochs=10,\n", "        accelerator=\"gpu\" if use_gpu else \"cpu\",\n", "        log_every_n_steps=100,\n", "        logger=CSVLogger(\"logs\"),\n", "    )\n", "\n", "    trainer = prepare_trainer(trainer)\n", "\n", "    # Train model\n", "    trainer.fit(model, datamodule=datamodule)\n", "\n", "    # Evaluation on the test dataset\n", "    trainer.test(model, datamodule=datamodule)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now put everything together:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)\n", "\n", "run_config = RunConfig(\n", "    name=\"ptl-mnist-example\",\n", "    storage_path=\"/tmp/ray_results\",\n", "    checkpoint_config=CheckpointConfig(\n", "        num_to_keep=3,\n", "        checkpoint_score_attribute=\"val_accuracy\",\n", "        checkpoint_score_order=\"max\",\n", "    ),\n", ")\n", "\n", "trainer = TorchTrainer(\n", "    train_func_per_worker,\n", "    scaling_config=scaling_config,\n", "    run_config=run_config,\n", ")" ] }, { "attachments": {}, "cell_type": 
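"markdown", "metadata": {}, "source": [
"With this `CheckpointConfig`, Ray Train keeps at most three checkpoints on disk and ranks them by the reported `val_accuracy`, where higher is better. The plain-Python sketch below is a simplified illustration of that retention rule, not Ray's implementation. The per-epoch accuracies are hypothetical values chosen for the example, and details such as tie handling are ignored.\n",
"\n",
"```python\n",
"def retained_epochs(reports, num_to_keep=3):\n",
"    # Rank by the score attribute, highest first (checkpoint_score_order='max')\n",
"    ranked = sorted(reports, key=lambda r: r['val_accuracy'], reverse=True)\n",
"    # Keep only the best num_to_keep checkpoints, listed in epoch order\n",
"    return sorted(r['epoch'] for r in ranked[:num_to_keep])\n",
"\n",
"# Hypothetical per-epoch metrics, for illustration only\n",
"reports = [\n",
"    {'epoch': 0, 'val_accuracy': 0.91},\n",
"    {'epoch': 1, 'val_accuracy': 0.94},\n",
"    {'epoch': 2, 'val_accuracy': 0.93},\n",
"    {'epoch': 3, 'val_accuracy': 0.96},\n",
"    {'epoch': 4, 'val_accuracy': 0.95},\n",
"]\n",
"print(retained_epochs(reports))  # [1, 3, 4]\n",
"```"
] }, { "attachments": {}, "cell_type": 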
"markdown", "metadata": {}, "source": [ "Now fit your trainer:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " Tune Status\n", " \n", "\n", "Current time:2023-08-07 23:41:11\n", "Running for: 00:00:39.80 \n", "Memory: 24.2/186.6 GiB \n", "\n", "\n", " \n", " \n", " \n", " System Info\n", " Using FIFO scheduling algorithm.Logical resource usage: 1.0/48 CPUs, 4.0/4 GPUs\n", " \n", " \n", " \n", " \n", " \n", " Trial Status\n", " \n", "\n", "Trial name status loc iter total time (s) train_loss val_accuracy val_loss\n", "\n", "\n", "TorchTrainer_78346_00000TERMINATED10.0.6.244:120026 10 29.0221 0.0315938 0.970002 -12.3466\n", "\n", "\n", " \n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(TorchTrainer pid=120026)\u001b[0m Starting distributed worker processes: ['120176 (10.0.6.244)', '120177 (10.0.6.244)', '120178 (10.0.6.244)', '120179 (10.0.6.244)']\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m Setting up process group for: env:// [rank=0, world_size=4]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m [rank: 0] Global seed set to 888\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m GPU available: True (cuda), used: True\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m TPU available: False, using: 0 TPU cores\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m IPU available: False, using: 0 IPUs\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m HPU available: False, using: 0 HPUs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=120178)\u001b[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120178)\u001b[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to 
/tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw/train-images-idx3-ubyte.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/9912422 [00:00, ?it/s]\n", "100%|██████████| 9912422/9912422 [00:00<00:00, 94562894.32it/s]\n", " 9%|▉ | 917504/9912422 [00:00<00:00, 9166590.91it/s]\n", "100%|██████████| 9912422/9912422 [00:00<00:00, 115619443.32it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=120179)\u001b[0m Extracting /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_3/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_3/MNIST/raw\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=120177)\u001b[0m Missing logger folder: logs/lightning_logs\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m | Name | Type | Params\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m ---------------------------------------------------------\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m 0 | linear_relu_stack | Sequential | 101 K \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m 1 | accuracy | MulticlassAccuracy | 0 \n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m ---------------------------------------------------------\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m 101 K Trainable params\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m 0 Non-trainable params\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m 101 K Total params\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m 0.407 Total estimated model params size (MB)\n" ] 
}, { "name": "stdout", "output_type": "stream", "text": [ "Sanity Checking: 0it [00:00, ?it/s])\u001b[0m \n", "Sanity Checking DataLoader 0: 0%| | 0/2 [00:00, ?it/s]\n", "Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 2.69it/s]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m /mnt/cluster_storage/pypi/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_accuracy', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m warning_cache.warn(\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120179)\u001b[0m [rank: 3] Global seed set to 888\u001b[32m [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 0%| | 0/108 [00:00, ?it/s] \n", "Epoch 0: 12%|█▏ | 13/108 [00:00<00:02, 39.35it/s, v_num=0]\n", "Epoch 0: 25%|██▌ | 27/108 [00:00<00:01, 59.26it/s, v_num=0]\n", "Epoch 0: 26%|██▌ | 28/108 [00:00<00:01, 61.03it/s, v_num=0]\n", "Epoch 0: 27%|██▋ | 29/108 [00:00<00:01, 62.76it/s, v_num=0]\n", "Epoch 0: 42%|████▏ | 45/108 [00:00<00:00, 81.02it/s, v_num=0]\n", "Epoch 0: 53%|█████▎ | 57/108 [00:00<00:00, 86.01it/s, v_num=0]\n", "Epoch 0: 64%|██████▍ | 69/108 [00:00<00:00, 88.63it/s, v_num=0]\n", "Epoch 0: 81%|████████ | 87/108 [00:00<00:00, 98.04it/s, v_num=0]\n", "Epoch 0: 81%|████████▏ | 88/108 [00:00<00:00, 98.69it/s, v_num=0]\n", "Epoch 0: 82%|████████▏ | 89/108 [00:00<00:00, 99.34it/s, v_num=0]\n", "Epoch 0: 96%|█████████▋| 104/108 [00:00<00:00, 104.14it/s, v_num=0]\n", "Epoch 0: 97%|█████████▋| 105/108 [00:01<00:00, 
104.71it/s, v_num=0]\n", "Epoch 0: 98%|█████████▊| 106/108 [00:01<00:00, 105.22it/s, v_num=0]\n", "Epoch 0: 100%|██████████| 108/108 [00:01<00:00, 105.79it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 171.69it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 200.99it/s]\u001b[A\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 221.66it/s]\u001b[A\n", "Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 215.50it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 194.14it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 205.63it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 215.27it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 216.26it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 198.67it/s]\u001b[A\n", "Validation DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 205.79it/s]\u001b[A\n", "Epoch 0: 100%|██████████| 108/108 [00:01<00:00, 79.84it/s, v_num=0] \n", " \u001b[A\n", "Epoch 1: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 1: 11%|█ | 12/108 [00:00<00:02, 32.36it/s, v_num=0]\n", "Epoch 1: 23%|██▎ | 25/108 [00:00<00:01, 50.16it/s, v_num=0]\n", "Epoch 1: 37%|███▋ | 40/108 [00:00<00:01, 65.95it/s, v_num=0]\n", "Epoch 1: 38%|███▊ | 41/108 [00:00<00:00, 67.05it/s, v_num=0]\n", "Epoch 1: 50%|█████ | 54/108 [00:00<00:00, 75.52it/s, v_num=0]\n", "Epoch 1: 51%|█████ | 55/108 [00:00<00:00, 76.40it/s, v_num=0]\n", "Epoch 1: 62%|██████▏ | 67/108 [00:00<00:00, 81.72it/s, v_num=0]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120178)\u001b[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "Epoch 1: 77%|███████▋ | 83/108 [00:00<00:00, 
89.48it/s, v_num=0]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120178)\u001b[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw/t10k-labels-idx1-ubyte.gz\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "Epoch 1: 78%|███████▊ | 84/108 [00:00<00:00, 89.21it/s, v_num=0]\n", "Epoch 1: 91%|█████████ | 98/108 [00:01<00:00, 93.27it/s, v_num=0]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120178)\u001b[0m Extracting /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw\u001b[32m [repeated 15x across cluster]\u001b[0m\n", "Epoch 1: 92%|█████████▏| 99/108 [00:01<00:00, 93.94it/s, v_num=0]\n", "Epoch 1: 93%|█████████▎| 100/108 [00:01<00:00, 94.57it/s, v_num=0]\n", "Epoch 1: 100%|██████████| 108/108 [00:01<00:00, 98.06it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 320.27it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 291.99it/s]\u001b[A\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m \u001b[32m [repeated 19x across cluster]\u001b[0m\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 291.61it/s]\u001b[A\n", "Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 268.90it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 290.07it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 293.52it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 299.70it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 304.80it/s]\u001b[A\n", "Validation DataLoader 0: 
90%|█████████ | 9/10 [00:00<00:00, 310.16it/s]\u001b[A\n", "Validation DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 303.63it/s]\u001b[A\n", "Epoch 1: 100%|██████████| 108/108 [00:01<00:00, 76.12it/s, v_num=0]\n", " \u001b[A\n", "Epoch 2: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 2: 7%|▋ | 8/108 [00:00<00:03, 25.23it/s, v_num=0]\n", "Epoch 2: 16%|█▌ | 17/108 [00:00<00:02, 39.73it/s, v_num=0]\n", "Epoch 2: 17%|█▋ | 18/108 [00:00<00:02, 41.60it/s, v_num=0]\n", "Epoch 2: 18%|█▊ | 19/108 [00:00<00:02, 43.49it/s, v_num=0]\n", "Epoch 2: 18%|█▊ | 19/108 [00:00<00:02, 43.46it/s, v_num=0]\n", "Epoch 2: 19%|█▊ | 20/108 [00:00<00:01, 45.27it/s, v_num=0]\n", "Epoch 2: 27%|██▋ | 29/108 [00:00<00:01, 53.08it/s, v_num=0]\n", "Epoch 2: 42%|████▏ | 45/108 [00:00<00:00, 69.12it/s, v_num=0]\n", "Epoch 2: 43%|████▎ | 46/108 [00:00<00:00, 70.31it/s, v_num=0]\n", "Epoch 2: 44%|████▎ | 47/108 [00:00<00:00, 71.51it/s, v_num=0]\n", "Epoch 2: 44%|████▍ | 48/108 [00:00<00:00, 72.71it/s, v_num=0]\n", "Epoch 2: 59%|█████▉ | 64/108 [00:00<00:00, 83.97it/s, v_num=0]\n", "Epoch 2: 75%|███████▌ | 81/108 [00:00<00:00, 93.77it/s, v_num=0]\n", "Epoch 2: 90%|████████▉ | 97/108 [00:00<00:00, 99.35it/s, v_num=0]\n", "Epoch 2: 100%|██████████| 108/108 [00:01<00:00, 101.71it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 212.13it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 184.45it/s]\u001b[A\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 228.42it/s]\u001b[A\n", "Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 225.00it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 250.65it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 251.36it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 
268.85it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 256.15it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 269.87it/s]\u001b[A\n", "Epoch 2: 100%|██████████| 108/108 [00:01<00:00, 77.52it/s, v_num=0] it/s]\u001b[A\n", " \u001b[A\n", "Epoch 3: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 3: 8%|▊ | 9/108 [00:00<00:03, 25.68it/s, v_num=0]\n", "Epoch 3: 9%|▉ | 10/108 [00:00<00:03, 28.26it/s, v_num=0]\n", "Epoch 3: 20%|██ | 22/108 [00:00<00:01, 48.10it/s, v_num=0]\n", "Epoch 3: 21%|██▏ | 23/108 [00:00<00:01, 49.73it/s, v_num=0]\n", "Epoch 3: 22%|██▏ | 24/108 [00:00<00:01, 51.34it/s, v_num=0]\n", "Epoch 3: 23%|██▎ | 25/108 [00:00<00:01, 52.98it/s, v_num=0]\n", "Epoch 3: 37%|███▋ | 40/108 [00:00<00:00, 69.67it/s, v_num=0]\n", "Epoch 3: 51%|█████ | 55/108 [00:00<00:00, 80.93it/s, v_num=0]\n", "Epoch 3: 64%|██████▍ | 69/108 [00:00<00:00, 87.15it/s, v_num=0]\n", "Epoch 3: 65%|██████▍ | 70/108 [00:00<00:00, 88.04it/s, v_num=0]\n", "Epoch 3: 66%|██████▌ | 71/108 [00:00<00:00, 88.92it/s, v_num=0]\n", "Epoch 3: 77%|███████▋ | 83/108 [00:00<00:00, 92.62it/s, v_num=0]\n", "Epoch 3: 86%|████████▌ | 93/108 [00:01<00:00, 92.33it/s, v_num=0]\n", "Epoch 3: 87%|████████▋ | 94/108 [00:01<00:00, 92.93it/s, v_num=0]\n", "Epoch 3: 88%|████████▊ | 95/108 [00:01<00:00, 93.61it/s, v_num=0]\n", "Epoch 3: 100%|██████████| 108/108 [00:01<00:00, 97.43it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 308.50it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 344.87it/s]\u001b[A\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 375.98it/s]\u001b[A\n", "Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 335.26it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 327.34it/s]\u001b[A\n", 
"Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 317.66it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 332.79it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 188.14it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 201.21it/s]\u001b[A\n", "Epoch 3: 100%|██████████| 108/108 [00:01<00:00, 75.94it/s, v_num=0]6it/s]\u001b[A\n", " \u001b[A\n", "Epoch 4: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 4: 10%|█ | 11/108 [00:00<00:03, 30.09it/s, v_num=0]\n", "Epoch 4: 20%|██ | 22/108 [00:00<00:01, 46.96it/s, v_num=0]\n", "Epoch 4: 21%|██▏ | 23/108 [00:00<00:01, 47.88it/s, v_num=0]\n", "Epoch 4: 35%|███▌ | 38/108 [00:00<00:01, 65.26it/s, v_num=0]\n", "Epoch 4: 36%|███▌ | 39/108 [00:00<00:01, 65.73it/s, v_num=0]\n", "Epoch 4: 53%|█████▎ | 57/108 [00:00<00:00, 81.51it/s, v_num=0]\n", "Epoch 4: 54%|█████▎ | 58/108 [00:00<00:00, 82.56it/s, v_num=0]\n", "Epoch 4: 68%|██████▊ | 73/108 [00:00<00:00, 89.69it/s, v_num=0]\n", "Epoch 4: 69%|██████▊ | 74/108 [00:00<00:00, 90.53it/s, v_num=0]\n", "Epoch 4: 83%|████████▎ | 90/108 [00:00<00:00, 98.32it/s, v_num=0]\n", "Epoch 4: 98%|█████████▊| 106/108 [00:01<00:00, 103.12it/s, v_num=0]\n", "Epoch 4: 100%|██████████| 108/108 [00:01<00:00, 103.78it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 268.49it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 298.62it/s]\u001b[A\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 282.88it/s]\u001b[A\n", "Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 256.50it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 276.28it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 268.05it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 
[00:00<00:00, 276.18it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 290.08it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 261.92it/s]\u001b[A\n", "Validation DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 274.00it/s]\u001b[A\n", "Epoch 4: 100%|██████████| 108/108 [00:01<00:00, 78.54it/s, v_num=0] \n", " \u001b[A\n", "Epoch 5: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 5: 5%|▍ | 5/108 [00:00<00:06, 15.52it/s, v_num=0]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m \u001b[32m [repeated 9x across cluster]\u001b[0m\n", "Epoch 5: 17%|█▋ | 18/108 [00:00<00:02, 42.53it/s, v_num=0]\n", "Epoch 5: 26%|██▌ | 28/108 [00:00<00:01, 52.36it/s, v_num=0]\n", "Epoch 5: 27%|██▋ | 29/108 [00:00<00:01, 53.91it/s, v_num=0]\n", "Epoch 5: 28%|██▊ | 30/108 [00:00<00:01, 55.45it/s, v_num=0]\n", "Epoch 5: 29%|██▊ | 31/108 [00:00<00:01, 56.96it/s, v_num=0]\n", "Epoch 5: 37%|███▋ | 40/108 [00:00<00:01, 61.48it/s, v_num=0]\n", "Epoch 5: 38%|███▊ | 41/108 [00:00<00:01, 62.61it/s, v_num=0]\n", "Epoch 5: 39%|███▉ | 42/108 [00:00<00:01, 63.79it/s, v_num=0]\n", "Epoch 5: 40%|███▉ | 43/108 [00:00<00:01, 64.96it/s, v_num=0]\n", "Epoch 5: 48%|████▊ | 52/108 [00:00<00:00, 67.96it/s, v_num=0]\n", "Epoch 5: 49%|████▉ | 53/108 [00:00<00:00, 68.88it/s, v_num=0]\n", "Epoch 5: 50%|█████ | 54/108 [00:00<00:00, 69.77it/s, v_num=0]\n", "Epoch 5: 62%|██████▏ | 67/108 [00:00<00:00, 76.43it/s, v_num=0]\n", "Epoch 5: 78%|███████▊ | 84/108 [00:00<00:00, 85.56it/s, v_num=0]\n", "Epoch 5: 79%|███████▊ | 85/108 [00:00<00:00, 86.17it/s, v_num=0]\n", "Epoch 5: 93%|█████████▎| 100/108 [00:01<00:00, 92.27it/s, v_num=0]\n", "Epoch 5: 100%|██████████| 108/108 [00:01<00:00, 94.81it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 255.91it/s]\u001b[A\n", 
"Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 206.50it/s]\u001b[A\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 214.91it/s]\u001b[A\n", "Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 236.35it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 240.05it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 234.90it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 240.78it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 252.00it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 255.18it/s]\u001b[A\n", "Epoch 5: 100%|██████████| 108/108 [00:01<00:00, 72.36it/s, v_num=0]1it/s]\u001b[A\n", " \u001b[A\n", "Epoch 6: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 6: 15%|█▍ | 16/108 [00:00<00:02, 44.27it/s, v_num=0]\n", "Epoch 6: 25%|██▌ | 27/108 [00:00<00:01, 57.21it/s, v_num=0]\n", "Epoch 6: 38%|███▊ | 41/108 [00:00<00:00, 70.86it/s, v_num=0]\n", "Epoch 6: 39%|███▉ | 42/108 [00:00<00:00, 71.82it/s, v_num=0]\n", "Epoch 6: 55%|█████▍ | 59/108 [00:00<00:00, 85.97it/s, v_num=0]\n", "Epoch 6: 68%|██████▊ | 73/108 [00:00<00:00, 91.53it/s, v_num=0]\n", "Epoch 6: 81%|████████ | 87/108 [00:00<00:00, 96.88it/s, v_num=0]\n", "Epoch 6: 92%|█████████▏| 99/108 [00:00<00:00, 99.33it/s, v_num=0]\n", "Epoch 6: 93%|█████████▎| 100/108 [00:01<00:00, 98.66it/s, v_num=0]\n", "Epoch 6: 94%|█████████▎| 101/108 [00:01<00:00, 99.34it/s, v_num=0]\n", "Epoch 6: 100%|██████████| 108/108 [00:01<00:00, 102.79it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 197.51it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 143.68it/s]\u001b[A\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 156.17it/s]\u001b[A\n", "Validation DataLoader 0: 
40%|████ | 4/10 [00:00<00:00, 180.52it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 205.25it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 212.20it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 195.64it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 211.21it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 225.13it/s]\u001b[A\n", "Epoch 6: 100%|██████████| 108/108 [00:01<00:00, 76.04it/s, v_num=0] it/s]\u001b[A\n", " \u001b[A\n", "Epoch 7: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 7: 11%|█ | 12/108 [00:00<00:02, 33.31it/s, v_num=0]\n", "Epoch 7: 20%|██ | 22/108 [00:00<00:01, 46.90it/s, v_num=0]\n", "Epoch 7: 22%|██▏ | 24/108 [00:00<00:01, 50.49it/s, v_num=0]\n", "Epoch 7: 31%|███▏ | 34/108 [00:00<00:01, 58.20it/s, v_num=0]\n", "Epoch 7: 32%|███▏ | 35/108 [00:00<00:01, 59.59it/s, v_num=0]\n", "Epoch 7: 33%|███▎ | 36/108 [00:00<00:01, 60.97it/s, v_num=0]\n", "Epoch 7: 48%|████▊ | 52/108 [00:00<00:00, 74.69it/s, v_num=0]\n", "Epoch 7: 64%|██████▍ | 69/108 [00:00<00:00, 85.96it/s, v_num=0]\n", "Epoch 7: 80%|███████▉ | 86/108 [00:00<00:00, 94.41it/s, v_num=0]\n", "Epoch 7: 81%|████████▏ | 88/108 [00:00<00:00, 95.91it/s, v_num=0]\n", "Epoch 7: 97%|█████████▋| 105/108 [00:01<00:00, 102.61it/s, v_num=0]\n", "Epoch 7: 100%|██████████| 108/108 [00:01<00:00, 103.00it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 215.46it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 246.46it/s]\u001b[A\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 264.39it/s]\u001b[A\n", "Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 256.84it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 
218.46it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 230.90it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 243.53it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 253.83it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 249.22it/s]\u001b[A\n", "Epoch 7: 100%|██████████| 108/108 [00:01<00:00, 78.36it/s, v_num=0] it/s]\u001b[A\n", " \u001b[A\n", "Epoch 8: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 8: 7%|▋ | 8/108 [00:00<00:03, 25.72it/s, v_num=0]\n", "Epoch 8: 19%|█▊ | 20/108 [00:00<00:01, 47.54it/s, v_num=0]\n", "Epoch 8: 31%|███ | 33/108 [00:00<00:01, 62.61it/s, v_num=0]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m \u001b[32m [repeated 7x across cluster]\u001b[0m\n", "Epoch 8: 44%|████▎ | 47/108 [00:00<00:00, 74.45it/s, v_num=0]\n", "Epoch 8: 55%|█████▍ | 59/108 [00:00<00:00, 81.84it/s, v_num=0]\n", "Epoch 8: 56%|█████▌ | 60/108 [00:00<00:00, 80.82it/s, v_num=0]\n", "Epoch 8: 57%|█████▋ | 62/108 [00:00<00:00, 82.74it/s, v_num=0]\n", "Epoch 8: 58%|█████▊ | 63/108 [00:00<00:00, 83.69it/s, v_num=0]\n", "Epoch 8: 70%|███████ | 76/108 [00:00<00:00, 88.60it/s, v_num=0]\n", "Epoch 8: 85%|████████▌ | 92/108 [00:00<00:00, 96.53it/s, v_num=0]\n", "Epoch 8: 86%|████████▌ | 93/108 [00:00<00:00, 96.21it/s, v_num=0]\n", "Epoch 8: 87%|████████▋ | 94/108 [00:00<00:00, 96.72it/s, v_num=0]\n", "Epoch 8: 88%|████████▊ | 95/108 [00:00<00:00, 97.32it/s, v_num=0]\n", "Epoch 8: 89%|████████▉ | 96/108 [00:00<00:00, 98.03it/s, v_num=0]\n", "Epoch 8: 100%|██████████| 108/108 [00:01<00:00, 102.15it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 228.96it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 220.63it/s]\u001b[A\n", "Validation DataLoader 
0: 30%|███ | 3/10 [00:00<00:00, 220.41it/s]\u001b[A\n", "Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 208.74it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 221.74it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 243.64it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 253.60it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 254.93it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 207.23it/s]\u001b[A\n", "Epoch 8: 100%|██████████| 108/108 [00:01<00:00, 78.28it/s, v_num=0] it/s]\u001b[A\n", " \u001b[A\n", "Epoch 9: 0%| | 0/108 [00:00, ?it/s, v_num=0] \n", "Epoch 9: 11%|█ | 12/108 [00:00<00:02, 33.03it/s, v_num=0]\n", "Epoch 9: 21%|██▏ | 23/108 [00:00<00:01, 48.82it/s, v_num=0]\n", "Epoch 9: 31%|███ | 33/108 [00:00<00:01, 58.62it/s, v_num=0]\n", "Epoch 9: 31%|███▏ | 34/108 [00:00<00:01, 58.61it/s, v_num=0]\n", "Epoch 9: 32%|███▏ | 35/108 [00:00<00:01, 59.89it/s, v_num=0]\n", "Epoch 9: 33%|███▎ | 36/108 [00:00<00:01, 61.11it/s, v_num=0]\n", "Epoch 9: 46%|████▋ | 50/108 [00:00<00:00, 71.95it/s, v_num=0]\n", "Epoch 9: 61%|██████ | 66/108 [00:00<00:00, 82.62it/s, v_num=0]\n", "Epoch 9: 70%|███████ | 76/108 [00:00<00:00, 83.77it/s, v_num=0]\n", "Epoch 9: 71%|███████▏ | 77/108 [00:00<00:00, 84.54it/s, v_num=0]\n", "Epoch 9: 72%|███████▏ | 78/108 [00:00<00:00, 85.33it/s, v_num=0]\n", "Epoch 9: 88%|████████▊ | 95/108 [00:01<00:00, 93.18it/s, v_num=0]\n", "Epoch 9: 100%|██████████| 108/108 [00:01<00:00, 98.27it/s, v_num=0]\n", "Validation: 0it [00:00, ?it/s]\u001b[A76)\u001b[0m \n", "Validation: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 0%| | 0/10 [00:00, ?it/s]\u001b[A\n", "Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 305.42it/s]\u001b[A\n", "Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 337.39it/s]\u001b[A\n", "Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 368.65it/s]\u001b[A\n", 
"Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 361.22it/s]\u001b[A\n", "Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 250.96it/s]\u001b[A\n", "Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 271.98it/s]\u001b[A\n", "Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 289.64it/s]\u001b[A\n", "Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 304.16it/s]\u001b[A\n", "Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 184.87it/s]\u001b[A\n", "Validation DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 196.99it/s]\u001b[A\n", "Epoch 9: 100%|██████████| 108/108 [00:01<00:00, 74.63it/s, v_num=0]\n", " \u001b[A\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m `Trainer.fit` stopped: `max_epochs=10` reached.\n", "100%|██████████| 4542/4542 [00:00<00:00, 48474627.91it/s]\u001b[32m [repeated 14x across cluster]\u001b[0m\n", "100%|██████████| 9912422/9912422 [00:00<00:00, 90032420.31it/s]\u001b[32m [repeated 2x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m \u001b[32m [repeated 5x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120178)\u001b[0m Missing logger folder: logs/lightning_logs\u001b[32m [repeated 2x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120179)\u001b[0m LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\u001b[32m [repeated 3x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m [rank: 0] Global seed set to 888\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 9: 100%|██████████| 108/108 [00:01<00:00, 66.61it/s, v_num=0]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m /mnt/cluster_storage/pypi/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:225: PossibleUserWarning: Using `DistributedSampler` with the 
dataloaders. During `trainer.test()`, it is recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets evaluated exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates some samples to make sure all devices have same batch size in case of uneven inputs.\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m rank_zero_warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Testing DataLoader 0: 25%|██▌ | 5/20 [00:00<00:00, 146.57it/s]\n", "Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 163.98it/s]\n", "Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 125.34it/s]\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m ┃ Test metric ┃ DataLoader 0 ┃\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m │ test_accuracy │ 0.9740999937057495 │\n", "\u001b[2m\u001b[36m(RayTrainWorker pid=120176)\u001b[0m └───────────────────────────┴───────────────────────────┘\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-08-07 23:41:11,072\tINFO tune.py:1145 -- Total run time: 39.92 seconds (39.80 seconds for the tuning loop).\n" ] } ], "source": [ "result = trainer.fit()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Check training results and checkpoints" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Result(\n", " metrics={'train_loss': 0.03159375861287117, 'val_accuracy': 0.9700015783309937, 'val_loss': -12.346583366394043, 'epoch': 9, 'step': 1080},\n", " path='/tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31',\n", " 
 checkpoint=LegacyTorchCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/checkpoint_000009)\n", ")" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation Accuracy: 0.9700015783309937\n", "Trial Directory: /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31\n", "['checkpoint_000007', 'checkpoint_000008', 'checkpoint_000009', 'events.out.tfevents.1691476838.ip-10-0-6-244', 'params.json', 'params.pkl', 'progress.csv', 'rank_0', 'rank_0.lock', 'rank_1', 'rank_1.lock', 'rank_2', 'rank_2.lock', 'rank_3', 'rank_3.lock', 'result.json']\n" ] } ], "source": [ "print(\"Validation Accuracy: \", result.metrics[\"val_accuracy\"])\n", "print(\"Trial Directory: \", result.path)\n", "print(sorted(os.listdir(result.path)))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Ray Train saved three checkpoints (`checkpoint_000007`, `checkpoint_000008`, `checkpoint_000009`) in the trial directory. The following code retrieves the latest checkpoint from the fit result and loads it into an `MNISTClassifier` instance.\n", "\n", "If you lose the in-memory `result` object, you can still restore the model from the checkpoint file on disk. In this run, the checkpoint path is: `/tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/checkpoint_000009/checkpoint.ckpt`." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Global seed set to 888\n" ] }, { "data": { "text/plain": [ "MNISTClassifier(\n", " (linear_relu_stack): Sequential(\n", " (0): Linear(in_features=784, out_features=128, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=128, out_features=10, bias=True)\n", " (3): ReLU()\n", " )\n", " (accuracy): MulticlassAccuracy()\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint = result.checkpoint\n", "\n", "with checkpoint.as_directory() as ckpt_dir:\n", " best_model = MNISTClassifier.load_from_checkpoint(f\"{ckpt_dir}/checkpoint.ckpt\")\n", "\n", "best_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## See also\n", "\n", "* {ref}`Getting Started with PyTorch Lightning ` for a tutorial on using Ray Train and PyTorch Lightning \n", "\n", "* {doc}`Ray Train Examples <../../examples>` for more use cases" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }, "orphan": true, "vscode": { "interpreter": { "hash": "a8c1140d108077f4faeb76b2438f85e4ed675f93d004359552883616a1acd54c" } } }, "nbformat": 4, "nbformat_minor": 4 }