{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# How to scale LLM workloads to 20B+ with multi-node clusters on Amazon SageMaker using Hugging Face and PyTorch FSDP\n", "\n", "In this tutorial, we will fine-tune the new [GPT-NeoXT-Chat-Base-20B](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B) on the [ELI5](https://huggingface.co/datasets/eli5) dataset to improve the model's explanation and question-answering skills. The [ELI5](https://huggingface.co/datasets/eli5) dataset is an English-language dataset of questions and answers gathered from three subreddits where users ask factual questions that require paragraph-length or longer answers.\n", "\n", "[GPT-NeoXT-Chat-Base](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B) is a 20B-parameter open-source LLM, which makes it hard to fine-tune on a single GPU or even on a single node with multiple GPUs. We are going to use the Amazon SageMaker managed training platform as our infrastructure backbone to create a multi-node cluster and easily run our distributed training. As instances, we will use 2x p4d.24xlarge, each of which comes with 8x NVIDIA A100 40GB GPUs (16 GPUs in total).\n", "\n", "*Note: You might have to request a quota increase for these instances.*\n", "\n", "As the distributed training framework, we will use PyTorch FSDP + the Hugging Face Transformers Trainer, which makes it easy to distribute our model and data in a fully sharded way across all our nodes and GPUs.\n", "\n", "## What is PyTorch Fully Sharded Data Parallel (FSDP)?\n", "\n", "PyTorch FSDP (Fully Sharded Data Parallel) is an extension of data parallelism that enables efficient large-scale training of LLMs. With FSDP, each GPU stores only a shard of the model parameters, gradients, and optimizer states, and the sharded parameters can optionally be offloaded to CPUs. 
This reduces the memory footprint on each GPU, while FSDP overlaps the network communication needed to gather and reduce the shards with model computation to keep the GPUs busy.\n", "\n", "FSDP optimizations include:\n", "\n", "- Transformer Wrapping Policy\n", "- Mixed Precision (bf16)\n", "- Activation Checkpointing (Gradient Checkpointing)\n", "- Full Sharding Strategy\n", "\n", "PyTorch FSDP is natively integrated into the [Hugging Face Trainer](https://huggingface.co/docs/transformers/main_classes/trainer#pytorch-fully-sharded-data-parallel), making it easy to adapt and use. You can learn more about PyTorch FSDP in the [Efficient Large-Scale Training with Pytorch FSDP and AWS](https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/) and [Introducing PyTorch Fully Sharded Data Parallel (FSDP) API](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/) blog posts.\n", "\n", "## 1. Set up the development environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install \"transformers==4.28.1\" \"datasets[s3]==2.9.0\" \"sagemaker>=2.150.0\" --upgrade --quiet" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. 
You can find more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sagemaker\n", "import boto3\n", "sess = sagemaker.Session()\n", "# sagemaker session bucket -> used for uploading data, models and logs\n", "# sagemaker will automatically create this bucket if it does not exist\n", "sagemaker_session_bucket=None\n", "if sagemaker_session_bucket is None and sess is not None:\n", " # set to default bucket if a bucket name is not given\n", " sagemaker_session_bucket = sess.default_bucket()\n", "\n", "try:\n", " role = sagemaker.get_execution_role()\n", "except ValueError:\n", " iam = boto3.client('iam')\n", " role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n", "\n", "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n", "\n", "print(f\"sagemaker role arn: {role}\")\n", "print(f\"sagemaker bucket: {sess.default_bucket()}\")\n", "print(f\"sagemaker session region: {sess.boto_region_name}\")\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Load and prepare the dataset\n", "\n", "As the base dataset, we will use the [ELI5](https://huggingface.co/datasets/eli5) dataset, but before fine-tuning the model, we need to preprocess the data. We will create a \"chat\" version of the dataset by adding `<human>` and `<bot>` tokens and appending an end-of-sequence `<|endoftext|>` token to help the model learn to distinguish consecutive examples. Additionally, we create chunks of `2048` tokens (the [model max length](https://huggingface.co/EleutherAI/gpt-neox-20b)) to avoid unnecessary padding and computation.\n", "\n", "The first step is to load our dataset from Hugging Face. The `train_eli5` split contains `272,634` samples. We will downsample the dataset to `25,000` samples, which is closer to the dataset size of typical real-world use cases."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "from transformers import AutoTokenizer\n", "\n", "# Load Tokenizer\n", "model_id = \"togethercomputer/GPT-NeoXT-Chat-Base-20B\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "\n", "# Load dataset from huggingface.co\n", "dataset_id = \"eli5\"\n", "dataset = load_dataset(dataset_id, split=\"train_eli5\")\n", "\n", "# shuffle with seed 42 and downsample dataset to 25k samples\n", "dataset = dataset.shuffle(42).select(range(25_000))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "An [ELI5](https://huggingface.co/datasets/eli5) sample can include multiple answers to a “question”. We will select the answer with the highest user score as our explanation.\n", "\n", "*Note: This dataset is also a good candidate for reinforcement learning, i.e. training a model to generate answers that receive higher scores. Let me know if you are interested in an example of that.*\n", "\n", "The next step is to convert our dataset into a chat version. Here, we follow the prompt format from the [model card](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B#strengths-of-the-model) and append the EOS token."
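, "\n", "\n", "The template cell below always uses `answers[\"text\"][0]`, which is the highest-scored answer only if the answers are stored sorted by score. The following cell is a minimal sketch that makes that selection explicit. It assumes each sample's `answers` field holds parallel `text` and `score` lists. The `pick_best_answer` helper and the toy answers are hypothetical and only illustrate the idea; they are not part of the original workflow." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hypothetical helper (assumes parallel `text` and `score` lists in `answers`)\n", "def pick_best_answer(answers):\n", "    scores = answers['score']\n", "    # index of the highest score; ties resolve to the first occurrence\n", "    best_idx = max(range(len(scores)), key=lambda i: scores[i])\n", "    return answers['text'][best_idx]\n", "\n", "# toy answers mirroring the assumed ELI5 `answers` structure (made-up data)\n", "toy_answers = {'text': ['Short guess.', 'Air scatters blue light more than red light.'], 'score': [3, 42]}\n", "print(pick_best_answer(toy_answers))  # prints the second answer (score 42)\n", "\n", "# in template_dataset, pick_best_answer(sample['answers']) could replace sample['answers']['text'][0]"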
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from random import randint\n", "\n", "# dataset template for chat conversation (format from the model card)\n", "template = '''<human>: Explain like I am five: {question}\n", "<bot>: {answer}{eos_token}'''\n", "\n", "eos_token = tokenizer.eos_token\n", "\n", "def template_dataset(sample):\n", "    # fill the template with the question (title) and the first answer\n", "    sample[\"text\"] = template.format(\n", "        question=sample[\"title\"],\n", "        answer=sample[\"answers\"][\"text\"][0],\n", "        eos_token=eos_token,\n", "    )\n", "    return sample\n", "\n", "# apply prompt template per sample\n", "dataset = dataset.map(template_dataset, remove_columns=list(dataset.features))\n", "\n", "# print random sample\n", "print(dataset[randint(0, len(dataset) - 1)])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The last step of the data preparation is to tokenize and chunk our dataset. We tokenize our inputs, converting the text into token IDs that the model can understand. Additionally, we concatenate our samples and split them into chunks of `2048` tokens to avoid unnecessary padding."
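, "\n", "\n", "The packing idea can be illustrated on toy data. The following cell is a simplified sketch that uses made-up token IDs and a chunk length of `4` instead of `2048`. It concatenates all samples, splits them into full-length chunks, and keeps the leftover tokens as a remainder. The `chunk` function below does the same per batch and carries that remainder over to the next batch. The `pack` helper is hypothetical and only serves as an illustration." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from itertools import chain\n", "\n", "# hypothetical helper: concatenate token lists and split them into full chunks\n", "def pack(token_lists, chunk_length):\n", "    flat = list(chain(*token_lists))\n", "    # largest multiple of chunk_length that fits into the concatenated tokens\n", "    usable = (len(flat) // chunk_length) * chunk_length\n", "    chunks = [flat[i : i + chunk_length] for i in range(0, usable, chunk_length)]\n", "    leftover = flat[usable:]  # tokens that would be carried over to the next batch\n", "    return chunks, leftover\n", "\n", "# three tokenized samples of different lengths (made-up token IDs)\n", "toy_batch = [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]\n", "toy_chunks, toy_leftover = pack(toy_batch, chunk_length=4)\n", "print(toy_chunks)    # [[1, 2, 3, 4], [5, 6, 7, 8]]\n", "print(toy_leftover)  # [9, 10]"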
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from itertools import chain\n", "from functools import partial\n", "\n", "# token lists left over from the previous batch, prepended to the next batch\n", "remainder = {\"input_ids\": [], \"attention_mask\": []}\n", "\n", "def chunk(sample, chunk_length=2048):\n", "    # use the global remainder to carry leftover tokens from batch to batch\n", "    global remainder\n", "    # concatenate all texts and prepend the remainder from the previous batch\n", "    concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}\n", "    concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}\n", "    # get total number of tokens for batch\n", "    batch_total_length = len(concatenated_examples[list(sample.keys())[0]])\n", "\n", "    # number of tokens that fit into full chunks (0 if the batch is shorter than one chunk)\n", "    batch_chunk_length = (batch_total_length // chunk_length) * chunk_length\n", "\n", "    # split into chunks of chunk_length\n", "    result = {\n", "        k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]\n", "        for k, t in concatenated_examples.items()\n", "    }\n", "    # store the leftover tokens for the next batch\n", "    remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}\n", "    # prepare labels\n", "    result[\"labels\"] = result[\"input_ids\"].copy()\n", "    return result\n", "\n", "\n", "# tokenize and chunk dataset\n", "lm_dataset = dataset.map(\n", "    lambda sample: tokenizer(sample[\"text\"]), batched=True, remove_columns=list(dataset.features)\n", ").map(\n", "    partial(chunk, chunk_length=2048),\n", "    batched=True,\n", ")\n", "\n", "# Print total number of samples\n", "print(f\"Total number of samples: {len(lm_dataset)}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "After processing the dataset, we use the [FileSystem 
integration](https://huggingface.co/docs/datasets/filesystems) to upload our dataset to S3. We use `sess.default_bucket()`; adjust this if you want to store the dataset in a different S3 bucket. We will pass the S3 path to our training job later." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# save train_dataset to s3\n", "training_input_path = f's3://{sess.default_bucket()}/processed/eli-5/train'\n", "lm_dataset.save_to_disk(training_input_path)\n", "\n", "print(f\"uploaded training dataset to: {training_input_path}\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Fine-tune the GPT model using FSDP on Amazon SageMaker\n", "\n", "As mentioned in the beginning, we will use Amazon SageMaker and PyTorch FSDP to train our model. Amazon SageMaker makes it easy to create a multi-node cluster to train our model in a distributed manner. Recently, the `sagemaker` Python SDK added support for running training jobs with `torchrun`, which distributes the script across multiple nodes and GPUs.\n", "\n", "To use `torchrun` to execute our scripts, we only have to define the `distribution` parameter in our Estimator and set it to `{\"torch_distributed\": {\"enabled\": True}}`. This tells SageMaker to launch our training job with a command like the following (shown here for the second node, `--node_rank 1`):\n", "\n", "```bash\n", "torchrun --nnodes 2 --nproc_per_node 8 --master_addr algo-1 --master_port 7777 --node_rank 1 run_clm.py --bf16 True --dataset_path /opt/ml/input/data/training --epochs 3 --fsdp \"full_shard auto_wrap\" --fsdp_transformer_layer_cls_to_wrap GPTNeoXLayer --gradient_checkpointing True --model_id togethercomputer/GPT-NeoXT-Chat-Base-20B --optimizer adamw_apex_fused --per_device_train_batch_size 2\n", "```\n", "\n", "To use FSDP with the Hugging Face Trainer, we need to provide the `fsdp` sharding strategy as well as the transformer layer wrapping policy. 
\n", "\n", "In our example, we will use `full_shard auto_wrap` as the sharding strategy and `GPTNeoXLayer` as the transformer layer policy. If you run this example with a different model ID, make sure to also adjust the transformer layer policy.\n", "\n", "We prepared a [run_clm.py](scripts/run_clm.py) script, which implements causal language modeling and accepts our FSDP settings and other hyperparameters.\n", "\n", "To create a SageMaker training job, we create a `HuggingFace` Estimator and provide all our information. SageMaker takes care of starting and managing all the required EC2 instances for us, provides the correct Hugging Face container, uploads the provided scripts, and downloads the data from our S3 bucket into the container at `/opt/ml/input/data`. Then, it starts the training job with `torchrun` as described above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import time\n", "from sagemaker.huggingface import HuggingFace\n", "\n", "# define training job name\n", "job_name = f'huggingface-fsdp-{time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.localtime())}'\n", "\n", "\n", "# hyperparameters, which are passed into the training job\n", "hyperparameters={\n", " 'model_id': 'togethercomputer/GPT-NeoXT-Chat-Base-20B', # model id from huggingface.co/models\n", " 'dataset_path': '/opt/ml/input/data/training', # path where sagemaker will save training dataset\n", " 'gradient_checkpointing': True, # enable gradient checkpointing\n", " 'bf16': True, # enable bf16 mixed precision training\n", " 'optimizer': \"adamw_apex_fused\", # fused AdamW optimizer from NVIDIA Apex\n", " 'per_device_train_batch_size': 2, # batch size per device during training\n", " 'epochs': 3, # number of epochs to train\n", " 'fsdp': '\"full_shard auto_wrap\"', # fully sharded data parallelism\n", " 'fsdp_transformer_layer_cls_to_wrap': \"GPTNeoXLayer\", # transformer layer to wrap\n", "}\n", "\n", "# estimator\n", "huggingface_estimator = HuggingFace(\n", " entry_point='run_clm.py',\n", 
" source_dir='./scripts',\n", " instance_type=\"ml.p4d.24xlarge\",\n", " instance_count=2,\n", " volume_size=200,\n", " role=role,\n", " job_name=job_name,\n", " transformers_version='4.26.0',\n", " pytorch_version='1.13.1',\n", " py_version=\"py39\",\n", " hyperparameters=hyperparameters,\n", " distribution={\"torch_distributed\": {\"enabled\": True}} # enable torchrun\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We can now start our training job with the `.fit()` method, passing our S3 path as the `training` input channel." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# define a data input dictionary with our uploaded S3 URIs\n", "data = {'training': training_input_path}\n", "\n", "# start the training job with our uploaded dataset as input\n", "huggingface_estimator.fit(data, wait=True)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The training took `9407` seconds, which is about 2.6 hours. The `ml.p4d.24xlarge` instance we used costs `$37.688` per hour, and we used 2 instances, so the total cost for training `GPT-NeoXT-Chat-Base-20B` is `(9407 s / 3600) × $37.688/h × 2 instances ≈ $197`. We could reduce the cost by using spot instances or parameter-efficient fine-tuning (PEFT).\n", "\n", "_Note: Uploading the model can take a while, since SageMaker first packs the artifacts into an archive, which is slow. To speed this up, you can push the artifacts to the Hugging Face Hub instead._" ] } ], "metadata": { "kernelspec": { "display_name": "pytorch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146" } } }, "nbformat": 4, "nbformat_minor": 2 }