{ "cells": [ { "cell_type": "markdown", "id": "8ea1e6eb", "metadata": {}, "source": [ "# Launching Multi-Node Training from a Jupyter Environment\n", "> Using the `notebook_launcher` to use Accelerate from inside a Jupyter Notebook" ] },
{ "cell_type": "markdown", "id": "915d7904", "metadata": {}, "source": [ "## General Overview\n", "\n", "This notebook covers how to run the `cv_example.py` script as a Jupyter Notebook and train it on a distributed system. It will also cover the few specific requirements for making sure your environment is configured properly, your data is prepared properly, and finally how to launch training." ] },
{ "cell_type": "markdown", "id": "00dafb31", "metadata": {}, "source": [ "## Configuring the Environment\n", "\n", "Before any training can be performed, an Accelerate config file must exist on the system. Usually this is done by running the following in a terminal:\n", "\n", "```bash\n", "accelerate config\n", "```\n", "\n", "However, if general defaults are fine and you are *not* running on a TPU, Accelerate has a utility to quickly write your GPU configuration into a config file via `write_basic_config`.\n", "\n", "The following cell will restart Jupyter after writing the configuration, since CUDA code was called to perform it. CUDA cannot be initialized more than once in the same process (it is initialized once for the notebook's default single-GPU use, and would need to be initialized again when `notebook_launcher` is called). It is fine to debug in the notebook and make calls to CUDA, but remember that in order to finally train, a full cleanup and restart needs to be performed, such as what is shown below:" ] },
{ "cell_type": "code", "execution_count": 1, "id": "ae835e68", "metadata": {}, "outputs": [], "source": [ "#import os\n", "#from accelerate.utils import write_basic_config\n", "#write_basic_config() # Write a config file\n", "#os._exit(00) # Restart the notebook" ] },
{ "cell_type": "markdown", "id": "e067d11a", "metadata": {}, "source": [ "## Preparing the Dataset and Model\n", "\n", "Next you should prepare your dataset. As mentioned earlier, great care should be taken when preparing the `DataLoaders` and model to make sure that **nothing** is put on *any* GPU.\n", "\n", "If some code does need to run on the GPU, it is recommended to put that specific code into a function and call it from within the notebook launcher interface, which will be shown later. 
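\n", "\n", "For example, here is a minimal sketch of the pattern, with a toy `nn.Linear` standing in for the real model. It is only meant to illustrate where GPU-touching code should live and is not part of the example that follows:\n", "\n", "```python\n", "from torch import nn\n", "from accelerate import Accelerator\n", "\n", "# Top level of the notebook: keep everything on the CPU\n", "model = nn.Linear(8, 2)   # fine, the model stays on the CPU\n", "# model.to(\"cuda\")        # avoid: this would initialize CUDA in the notebook process itself\n", "\n", "def training_function():\n", "    # Anything that touches the GPU belongs in the function handed to\n", "    # `notebook_launcher`, which runs it once per process\n", "    accelerator = Accelerator()\n", "    return nn.Linear(8, 2).to(accelerator.device)\n", "```\n", "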
\n", "\n", "Make sure the dataset is downloaded based on the directions [here](https://github.com/huggingface/accelerate/tree/main/examples#simple-vision-example)." ] },
{ "cell_type": "code", "execution_count": 2, "id": "4805cb1a", "metadata": {}, "outputs": [], "source": [ "import os, re, torch\n", "import PIL.Image\n", "import numpy as np\n", "\n", "from torch.optim.lr_scheduler import OneCycleLR\n", "from torch.utils.data import DataLoader, Dataset\n", "from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor\n", "\n", "from accelerate import Accelerator\n", "from accelerate.utils import set_seed\n", "from timm import create_model" ] },
{ "cell_type": "markdown", "id": "9938f8e4", "metadata": {}, "source": [ "First let's look at a sample filename to see how the class name is stored in it:" ] },
{ "cell_type": "code", "execution_count": 3, "id": "bcd4907f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "beagle_32.jpg\n" ] } ], "source": [ "import os\n", "data_dir = \"../../images\"\n", "fnames = os.listdir(data_dir)\n", "fname = fnames[0]\n", "print(fname)" ] },
{ "cell_type": "markdown", "id": "39caa398", "metadata": {}, "source": [ "In the case here, the label is `beagle`, which we can extract with a small helper function:" ] },
{ "cell_type": "code", "execution_count": 4, "id": "43e28d35", "metadata": {}, "outputs": [], "source": [ "import re\n", "\n", "def extract_label(fname):\n", "    stem = fname.split(os.path.sep)[-1]\n", "    return re.search(r\"^(.*)_\\d+\\.jpg$\", stem).groups()[0]" ] },
{ "cell_type": "code", "execution_count": 5, "id": "fab40fe9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'beagle'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "extract_label(fname)" ] },
{ "cell_type": "markdown", "id": "0bf6a733", "metadata": {}, "source": [ "Next we'll create a `Dataset` class:" ] },
{ "cell_type": "code", "execution_count": 6, "id": "72242e1b", "metadata": {}, "outputs": [], "source": [ "class PetsDataset(Dataset):\n", "    def __init__(self, file_names, image_transform=None, label_to_id=None):\n", "        self.file_names = file_names\n", "        self.image_transform = image_transform\n", "        self.label_to_id = label_to_id\n", "\n", "    def __len__(self):\n", "        return len(self.file_names)\n", "\n", "    def __getitem__(self, idx):\n", "        fname = self.file_names[idx]\n", "        raw_image = PIL.Image.open(fname)\n", "        image = raw_image.convert(\"RGB\")\n", "        if self.image_transform is not None:\n", "            image = self.image_transform(image)\n", "        label = extract_label(fname)\n", "        if self.label_to_id is not None:\n", "            label = self.label_to_id[label]\n", "        return {\"image\": image, \"label\": label}" ] },
{ "cell_type": "markdown", "id": "222fc93a", "metadata": {}, "source": [ "And build our dataset:" ] },
{ "cell_type": "code", "execution_count": 7, "id": "8a0f2499", "metadata": {}, "outputs": [], "source": [ "# Grab all the image filenames\n", "fnames = [\n", "    os.path.join(\"../../images\", fname)\n", "    for fname in fnames\n", "    if fname.endswith(\".jpg\")\n", "]\n", "\n", "# Build the labels\n", "all_labels = [\n", "    extract_label(fname)\n", "    for fname in fnames\n", "]\n", "id_to_label = list(set(all_labels))\n", "id_to_label.sort()\n", "label_to_id = {lbl: i for i, lbl in enumerate(id_to_label)}" ] },
{ "cell_type": "markdown", "id": "343e66fd", "metadata": {}, "source": [ "> Note: The dataloaders will be built inside of a function, as we'll be setting our seed during training."
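, "\n", "\n", "Before doing so, we can optionally sanity-check the dataset on the CPU. The following is only an illustrative sketch (it assumes the `../../images` folder used above); pulling a single item confirms the image tensor shape and that labels map to integer ids:\n", "\n", "```python\n", "# Optional CPU-only check of the Dataset and the label mapping\n", "sample_ds = PetsDataset(\n", "    fnames,\n", "    image_transform=Compose([Resize((224, 224)), ToTensor()]),\n", "    label_to_id=label_to_id,\n", ")\n", "item = sample_ds[0]\n", "print(len(label_to_id), \"classes\")\n", "print(item[\"image\"].shape, item[\"label\"])  # e.g. torch.Size([3, 224, 224]) and an integer id\n", "```"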
] }, { "cell_type": "code", "execution_count": 8, "id": "25430a5e", "metadata": {}, "outputs": [], "source": [ "def get_dataloaders(batch_size: int = 64):\n", "    \"Builds a set of dataloaders with a batch_size\"\n", "    random_perm = np.random.permutation(len(fnames))\n", "    cut = int(0.8 * len(fnames))\n", "    train_split = random_perm[:cut]\n", "    eval_split = random_perm[cut:]\n", "\n", "    # For training we use a simple RandomResizedCrop\n", "    train_tfm = Compose([\n", "        RandomResizedCrop((224, 224), scale=(0.5, 1.0)),\n", "        ToTensor()\n", "    ])\n", "    train_dataset = PetsDataset(\n", "        [fnames[i] for i in train_split],\n", "        image_transform=train_tfm,\n", "        label_to_id=label_to_id\n", "    )\n", "\n", "    # For evaluation we use a deterministic Resize\n", "    eval_tfm = Compose([\n", "        Resize((224, 224)),\n", "        ToTensor()\n", "    ])\n", "    eval_dataset = PetsDataset(\n", "        [fnames[i] for i in eval_split],\n", "        image_transform=eval_tfm,\n", "        label_to_id=label_to_id\n", "    )\n", "\n", "    # Instantiate dataloaders\n", "    train_dataloader = DataLoader(\n", "        train_dataset,\n", "        shuffle=True,\n", "        batch_size=batch_size,\n", "        num_workers=4\n", "    )\n", "    eval_dataloader = DataLoader(\n", "        eval_dataset,\n", "        shuffle=False,\n", "        batch_size=batch_size * 2,\n", "        num_workers=4\n", "    )\n", "    return train_dataloader, eval_dataloader" ] },
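{ "cell_type": "markdown", "id": "a3d5f0c2", "metadata": {}, "source": [ "As another optional, CPU-only smoke test (again assuming the image folder used above), we can build a small pair of dataloaders and pull one batch through, just to confirm the roughly 80/20 split and the batch shapes. This sketch can be skipped entirely:\n", "\n", "```python\n", "train_dl, eval_dl = get_dataloaders(batch_size=8)\n", "print(len(train_dl.dataset), \"training images /\", len(eval_dl.dataset), \"evaluation images\")\n", "batch = next(iter(train_dl))\n", "print(batch[\"image\"].shape, batch[\"label\"].shape)  # e.g. torch.Size([8, 3, 224, 224]) and torch.Size([8])\n", "```" ] },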
{ "cell_type": "markdown", "id": "91084504", "metadata": {}, "source": [ "## Writing the Training Function\n", "\n", "Now we can build our training loop. `notebook_launcher` works by passing in a function to call that will be run across the distributed system.\n", "\n", "Here is a basic training loop for our animal classification problem:" ] },
{ "cell_type": "code", "execution_count": 9, "id": "ebab7267", "metadata": {}, "outputs": [], "source": [ "from torch.optim.lr_scheduler import CosineAnnealingLR" ] },
{ "cell_type": "code", "execution_count": 10, "id": "4366f90e", "metadata": {}, "outputs": [], "source": [ "def training_loop(mixed_precision=\"fp16\", seed: int = 42, batch_size: int = 64):\n", "    set_seed(seed)\n", "    # Initialize accelerator\n", "    accelerator = Accelerator(mixed_precision=mixed_precision)\n", "    # Build dataloaders\n", "    train_dataloader, eval_dataloader = get_dataloaders(batch_size)\n", "\n", "    # Instantiate the model (we build the model here so that the seed also controls new weight initializations)\n", "    model = create_model(\"resnet50d\", pretrained=True, num_classes=len(label_to_id))\n", "\n", "    # Freeze the base model\n", "    for param in model.parameters():\n", "        param.requires_grad = False\n", "    for param in model.get_classifier().parameters():\n", "        param.requires_grad = True\n", "\n", "    # We normalize the batches of images to be a bit faster\n", "    mean = torch.tensor(model.default_cfg[\"mean\"])[None, :, None, None]\n", "    std = torch.tensor(model.default_cfg[\"std\"])[None, :, None, None]\n", "\n", "    # To make these constants available on the active device, we move them to the accelerator device\n", "    mean = mean.to(accelerator.device)\n", "    std = std.to(accelerator.device)\n", "\n", "    # Instantiate the optimizer\n", "    optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-2 / 25)\n", "\n", "    # Instantiate the learning rate scheduler\n", "    lr_scheduler = OneCycleLR(\n", "        optimizer=optimizer,\n", "        max_lr=3e-2,\n", "        epochs=5,\n", "        steps_per_epoch=len(train_dataloader)\n", "    )\n", "\n", "    # Prepare everything\n", "    # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n", "    # prepare method.\n", "    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(\n", "        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler\n", "    )\n", "\n", "    # Now we train the model\n", "    for epoch in range(5):\n", "        model.train()\n", "        for step, batch in enumerate(train_dataloader):\n", "            # We could avoid this line since we set the accelerator with `device_placement=True`.\n", "            batch = {k: v.to(accelerator.device) for k, v in batch.items()}\n", "            inputs = (batch[\"image\"] - mean) / std\n", "            outputs = model(inputs)\n", "            loss = torch.nn.functional.cross_entropy(outputs, batch[\"label\"])\n", "            accelerator.backward(loss)\n", "            optimizer.step()\n", "            lr_scheduler.step()\n", "            optimizer.zero_grad()\n", "\n", "        model.eval()\n", "        accurate = 0\n", "        num_elems = 0\n", "        for _, batch in enumerate(eval_dataloader):\n", "            # We could avoid this line since we set the accelerator with `device_placement=True`.\n", "            batch = {k: v.to(accelerator.device) for k, v in batch.items()}\n", "            inputs = (batch[\"image\"] - mean) / std\n", "            with torch.no_grad():\n", "                outputs = model(inputs)\n", "            predictions = outputs.argmax(dim=-1)\n", "            accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch[\"label\"])\n", "            num_elems += accurate_preds.shape[0]\n", "            accurate += accurate_preds.long().sum()\n", "\n", "        eval_metric = accurate.item() / num_elems\n", "        # Use accelerator.print to print only on the main process.\n", "        accelerator.print(f\"epoch {epoch}: {100 * eval_metric:.2f}\")" ] },
{ "cell_type": "markdown", "id": "117a7f5f", "metadata": {}, "source": [ "All that's left is to use the `notebook_launcher`.\n", "\n", "We pass in the function, the arguments (as a tuple), and the number of processes to train on. (See the [documentation](https://huggingface.co/docs/accelerate/launcher) for more information.)" ] },
{ "cell_type": "code", "execution_count": 11, "id": "88a096cf", "metadata": {}, "outputs": [], "source": [ "from accelerate import notebook_launcher" ] },
{ "cell_type": "code", "execution_count": 13, "id": "439fb3e6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Launching training on 2 GPUs.\n", "epoch 0: 88.12\n", "epoch 1: 91.73\n", "epoch 2: 92.58\n", "epoch 3: 93.90\n", "epoch 4: 94.71\n" ] } ], "source": [ "args = (\"fp16\", 42, 64)\n", "notebook_launcher(training_loop, args, num_processes=2)" ] },
{ "cell_type": "markdown", "id": "62fab197", "metadata": {}, "source": [ "And that's it!" ] },
{ "cell_type": "markdown", "id": "b83c964a", "metadata": {}, "source": [ "## Conclusion\n", "\n", "This notebook showed how to perform distributed training from inside of a Jupyter Notebook. Some key notes to remember:\n", "\n", "- Make sure to save any code that uses CUDA (or CUDA imports) for the function passed to `notebook_launcher`\n", "- Set `num_processes` to the number of devices used for training (such as the number of GPUs, TPU cores, or CPUs being used)" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }