{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(hpu_bert_training)=\n",
"# BERT Model Training with Intel Gaudi\n",
"\n",
"\n",
"
\n",
"\n",
"
\n",
"\n",
"In this notebook, we will train a BERT model for sequence classification using the Yelp review full dataset. We will use the `transformers` and `datasets` libraries from Hugging Face, along with `ray.train` for distributed training.\n",
"\n",
"[Intel Gaudi AI Processors (HPUs)](https://habana.ai) are AI hardware accelerators designed by Intel Habana Labs. For more information, see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/index.html) and [Gaudi Developer Docs](https://developer.habana.ai/).\n",
"\n",
"## Configuration\n",
"\n",
"A node with Gaudi/Gaudi2 installed is required to run this example. Both Gaudi and Gaudi2 have 8 HPUs. We will use 2 workers to train the model, each using 1 HPU.\n",
"\n",
"We recommend using a prebuilt container to run these examples. To run a container, you need Docker. See [Install Docker Engine](https://docs.docker.com/engine/install/) for installation instructions.\n",
"\n",
"Next, follow [Run Using Containers](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html?highlight=installer#run-using-containers) to install the Gaudi drivers and container runtime.\n",
"\n",
"Next, start the Gaudi container:\n",
"```bash\n",
"docker pull vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest\n",
"docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest\n",
"```\n",
"\n",
"Inside the container, install the following dependecies to run this notebook.\n",
"```bash\n",
"pip install ray[train] notebook transformers datasets evaluate\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Import necessary libraries\n",
"\n",
"import os\n",
"from typing import Dict\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"\n",
"import numpy as np\n",
"import evaluate\n",
"from datasets import load_dataset\n",
"import transformers\n",
"from transformers import (\n",
" Trainer,\n",
" TrainingArguments,\n",
" AutoTokenizer,\n",
" AutoModelForSequenceClassification,\n",
")\n",
"\n",
"import ray.train\n",
"from ray.train import ScalingConfig\n",
"from ray.train.torch import TorchTrainer\n",
"from ray.train.torch import TorchConfig\n",
"from ray.runtime_env import RuntimeEnv\n",
"\n",
"import habana_frameworks.torch.core as htcore"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Metrics Setup\n",
"\n",
"We will use accuracy as our evaluation metric. The `compute_metrics` function will calculate the accuracy of our model's predictions."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Metrics\n",
"metric = evaluate.load(\"accuracy\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
" logits, labels = eval_pred\n",
" predictions = np.argmax(logits, axis=-1)\n",
" return metric.compute(predictions=predictions, references=labels)"
]
},
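{
"cell_type": "markdown",
"metadata": {},
"source": [
"`compute_metrics` receives the model's logits (one row of 5 class scores per review) and the reference labels, picks the highest-scoring class with `np.argmax`, and passes both to the `accuracy` metric. The NumPy-only sketch below reproduces that reduction on made-up logits, so you can see it without downloading the metric. The logits and labels are illustrative values, and the last step assumes accuracy is the fraction of exact matches.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Made-up logits for 4 reviews over 5 classes (labels 0-4)\n",
"logits = np.array([\n",
"    [2.0, 0.1, 0.3, 0.0, -1.0],  # highest score: class 0\n",
"    [0.0, 0.2, 3.1, 0.5, 0.1],   # highest score: class 2\n",
"    [0.3, 0.1, 0.2, 0.4, 1.5],   # highest score: class 4\n",
"    [1.0, 0.9, 0.8, 0.7, 0.6],   # highest score: class 0\n",
"])\n",
"labels = np.array([0, 2, 3, 1])\n",
"\n",
"# Same step as compute_metrics: index of the largest logit in each row\n",
"predictions = np.argmax(logits, axis=-1)\n",
"\n",
"# Fraction of predictions that match the reference labels\n",
"accuracy = float((predictions == labels).mean())\n",
"print(predictions, accuracy)  # [0 2 4 0] 0.5\n",
"```\n",
"\n",
"For the same inputs, `metric.compute(predictions=predictions, references=labels)` returns `{'accuracy': 0.5}`."
]
},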
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Function\n",
"\n",
"This function will be executed by each worker during training. It handles data loading, tokenization, model initialization, and the training loop. Compared to a training function for GPU, no changes are needed to port to HPU. Internally, Ray Train does these things:\n",
"\n",
"* Detect HPU and set the device.\n",
"\n",
"* Initializes the habana PyTorch backend.\n",
"\n",
"* Initializes the habana distributed backend."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def train_func_per_worker(config: Dict):\n",
" \n",
" # Datasets\n",
" dataset = load_dataset(\"yelp_review_full\")\n",
" tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
" \n",
" def tokenize_function(examples):\n",
" return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
"\n",
" lr = config[\"lr\"]\n",
" epochs = config[\"epochs\"]\n",
" batch_size = config[\"batch_size_per_worker\"]\n",
"\n",
" train_dataset = dataset[\"train\"].select(range(1000)).map(tokenize_function, batched=True)\n",
" eval_dataset = dataset[\"test\"].select(range(1000)).map(tokenize_function, batched=True)\n",
"\n",
" # Prepare dataloader for each worker\n",
" dataloaders = {}\n",
" dataloaders[\"train\"] = torch.utils.data.DataLoader(\n",
" train_dataset, \n",
" shuffle=True, \n",
" collate_fn=transformers.default_data_collator, \n",
" batch_size=batch_size\n",
" )\n",
" dataloaders[\"test\"] = torch.utils.data.DataLoader(\n",
" eval_dataset, \n",
" shuffle=True, \n",
" collate_fn=transformers.default_data_collator, \n",
" batch_size=batch_size\n",
" )\n",
"\n",
" # Obtain HPU device automatically\n",
" device = ray.train.torch.get_device()\n",
"\n",
" # Prepare model and optimizer\n",
" model = AutoModelForSequenceClassification.from_pretrained(\n",
" \"bert-base-cased\", num_labels=5\n",
" )\n",
" model = model.to(device)\n",
" \n",
" optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)\n",
"\n",
" # Start training loops\n",
" for epoch in range(epochs):\n",
" # Each epoch has a training and validation phase\n",
" for phase in [\"train\", \"test\"]:\n",
" if phase == \"train\":\n",
" model.train() # Set model to training mode\n",
" else:\n",
" model.eval() # Set model to evaluate mode\n",
"\n",
" # breakpoint()\n",
" for batch in dataloaders[phase]:\n",
" batch = {k: v.to(device) for k, v in batch.items()}\n",
"\n",
" # zero the parameter gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # forward\n",
" with torch.set_grad_enabled(phase == \"train\"):\n",
" # Get model outputs and calculate loss\n",
" \n",
" outputs = model(**batch)\n",
" loss = outputs.loss\n",
"\n",
" # backward + optimize only if in training phase\n",
" if phase == \"train\":\n",
" loss.backward()\n",
" optimizer.step()\n",
" print(f\"train epoch:[{epoch}]\\tloss:{loss:.6f}\")"
]
},
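{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tokenize_function` calls the tokenizer with `padding=\"max_length\"` and `truncation=True`, so every example becomes a sequence of the same length, the tokenizer's `model_max_length` (512 for `bert-base-cased`). Fixed shapes also suit Gaudi, where new input shapes can trigger graph recompilation. The pure-Python sketch below illustrates only the pad-or-truncate rule on toy token IDs. `pad_or_truncate`, the pad ID `0`, and the toy IDs are illustrative assumptions; the real tokenizer also inserts special tokens such as `[CLS]` and `[SEP]` and keeps them when truncating.\n",
"\n",
"```python\n",
"from typing import List, Tuple\n",
"\n",
"\n",
"def pad_or_truncate(token_ids: List[int], max_length: int, pad_id: int = 0) -> Tuple[List[int], List[int]]:\n",
"    # Hypothetical helper mirroring padding='max_length' + truncation=True:\n",
"    # keep at most max_length tokens, then right-pad up to max_length.\n",
"    ids = token_ids[:max_length]\n",
"    # Attention mask: 1 marks a real token, 0 marks padding\n",
"    mask = [1] * len(ids)\n",
"    num_pad = max_length - len(ids)\n",
"    return ids + [pad_id] * num_pad, mask + [0] * num_pad\n",
"\n",
"\n",
"short_ids, short_mask = pad_or_truncate([11, 12, 13], max_length=6)\n",
"long_ids, long_mask = pad_or_truncate(list(range(1, 10)), max_length=6)\n",
"print(short_ids, short_mask)  # [11, 12, 13, 0, 0, 0] [1, 1, 1, 0, 0, 0]\n",
"print(long_ids, long_mask)    # [1, 2, 3, 4, 5, 6] [1, 1, 1, 1, 1, 1]\n",
"```\n",
"\n",
"Whatever the input length, both outputs always have exactly `max_length` entries, which is the property that keeps batch shapes constant."
]
},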
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Main Training Function\n",
"\n",
"The `train_bert` function sets up the distributed training environment using Ray and starts the training process. To enable training using HPU, we only need to make the following changes:\n",
"* Require an HPU for each worker in ScalingConfig\n",
"* Set backend to \"hccl\" in TorchConfig"
]
},
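{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `train_bert`, the per-worker batch size is `global_batch_size // num_workers`, so the effective global batch is `num_workers * batch_size_per_worker`. With 2 workers, each worker gets 4 examples per step and the global batch stays 8. If `num_workers` doesn't divide 8, integer division silently shrinks it: 3 workers would get 2 examples each, a global batch of 6. The helper below, `per_worker_batch_size`, is a hypothetical addition (not part of Ray) that makes the split explicit and fails loudly instead.\n",
"\n",
"```python\n",
"def per_worker_batch_size(global_batch_size: int, num_workers: int) -> int:\n",
"    # Hypothetical helper: split a global batch evenly across workers,\n",
"    # raising instead of silently dropping the remainder.\n",
"    if num_workers <= 0:\n",
"        raise ValueError('num_workers must be positive')\n",
"    if global_batch_size % num_workers != 0:\n",
"        raise ValueError(\n",
"            f'global_batch_size={global_batch_size} is not divisible '\n",
"            f'by num_workers={num_workers}'\n",
"        )\n",
"    return global_batch_size // num_workers\n",
"\n",
"\n",
"print(per_worker_batch_size(8, 2))  # 4\n",
"print(8 // 3, 3 * (8 // 3))  # 2 6: what plain // does with 3 workers\n",
"```\n",
"\n",
"You could use it in place of the `//` expression in `train_config`, for example `\"batch_size_per_worker\": per_worker_batch_size(global_batch_size, num_workers)`."
]
},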
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def train_bert(num_workers=2):\n",
" global_batch_size = 8\n",
"\n",
" train_config = {\n",
" \"lr\": 1e-3,\n",
" \"epochs\": 10,\n",
" \"batch_size_per_worker\": global_batch_size // num_workers,\n",
" }\n",
"\n",
" # Configure computation resources\n",
" # In ScalingConfig, require an HPU for each worker\n",
" scaling_config = ScalingConfig(num_workers=num_workers, resources_per_worker={\"CPU\": 1, \"HPU\": 1})\n",
" # Set backend to hccl in TorchConfig\n",
" torch_config = TorchConfig(backend = \"hccl\")\n",
" \n",
" # start your ray cluster\n",
" ray.init()\n",
" \n",
" # Initialize a Ray TorchTrainer\n",
" trainer = TorchTrainer(\n",
" train_loop_per_worker=train_func_per_worker,\n",
" train_loop_config=train_config,\n",
" torch_config=torch_config,\n",
" scaling_config=scaling_config,\n",
" )\n",
"\n",
" result = trainer.fit()\n",
" print(f\"Training result: {result}\")"
]
},
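{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ScalingConfig` asks for `{\"CPU\": 1, \"HPU\": 1}` per worker, so Ray can only start `num_workers` workers if the cluster advertises at least that many CPUs and HPUs; otherwise the workers can't be scheduled and the trainer waits for resources. After `ray.init()`, `ray.cluster_resources()` returns the cluster's total resources as a dict. The sketch below checks such a dict before launching. `can_place_workers` is hypothetical, the example totals are illustrative (not captured from a real cluster), and the check assumes Ray registered the Gaudi devices under the `HPU` resource name that this notebook requests. It compares cluster-wide totals only, which is enough for the single-node setup used here.\n",
"\n",
"```python\n",
"from typing import Dict\n",
"\n",
"\n",
"def can_place_workers(\n",
"    resources: Dict[str, float], num_workers: int, per_worker: Dict[str, float]\n",
") -> bool:\n",
"    # Hypothetical pre-flight check: do the cluster totals cover\n",
"    # num_workers copies of the per-worker request? Ignores per-node placement.\n",
"    return all(\n",
"        resources.get(name, 0) >= amount * num_workers\n",
"        for name, amount in per_worker.items()\n",
"    )\n",
"\n",
"\n",
"# Illustrative totals shaped like ray.cluster_resources() on one 8-HPU node\n",
"example_resources = {'CPU': 160.0, 'HPU': 8.0}\n",
"per_worker = {'CPU': 1, 'HPU': 1}\n",
"\n",
"print(can_place_workers(example_resources, 2, per_worker))   # True\n",
"print(can_place_workers(example_resources, 16, per_worker))  # False\n",
"```\n",
"\n",
"On the Gaudi node you could run `can_place_workers(ray.cluster_resources(), 2, {'CPU': 1, 'HPU': 1})` after `ray.init()`."
]
},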
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Training\n",
"\n",
"Finally, we call the `train_bert` function to start the training process. You can adjust the number of workers to use.\n",
"\n",
"Note: the following warning is fine, and is resolved in SynapseAI version 1.14.0+:\n",
"```text\n",
"/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_bert(num_workers=2)"
]
},
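{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training loop prints the per-batch cross-entropy loss. A useful reference point when reading those values is the loss of a classifier that assigns equal probability to all 5 classes: -ln(1/5) = ln 5, about 1.609. Losses that hover around this value suggest the predictions are, on average, roughly as uninformative as a uniform guess. The short check below computes that baseline.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"num_classes = 5  # Yelp Review Full: 1 to 5 stars, stored as labels 0-4\n",
"\n",
"# Cross-entropy of the true class when every class gets probability 1/num_classes\n",
"uniform_probs = [1.0 / num_classes] * num_classes\n",
"true_class = 3  # any class gives the same loss under a uniform prediction\n",
"uniform_loss = -math.log(uniform_probs[true_class])\n",
"\n",
"print(f'{uniform_loss:.4f}')  # 1.6094\n",
"```"
]
},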
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Possible outputs\n",
"\n",
"``` text\n",
"Downloading builder script: 100%|██████████| 4.20k/4.20k [00:00<00:00, 27.0MB/s]\n",
"2025-03-03 03:37:08,776 INFO worker.py:1841 -- Started a local Ray instance.\n",
"/usr/local/lib/python3.10/dist-packages/ray/tune/impl/tuner_internal.py:125: RayDeprecationWarning: The `RunConfig` class should be imported from `ray.tune` when passing it to the Tuner. Please update your imports. See this issue for more context and migration options: https://github.com/ray-project/ray/issues/49454. Disable these warnings by setting the environment variable: RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS=0\n",
" _log_deprecation_warning(\n",
"(RayTrainWorker pid=75123) Setting up process group for: env:// [rank=0, world_size=2]\n",
"(TorchTrainer pid=74734) Started distributed worker processes: \n",
"(TorchTrainer pid=74734) - (node_id=eef984cd0cd96cce50bad1b1dab12e19c809047f10be3c829524a3d1, ip=100.83.111.228, pid=75123) world_rank=0, local_rank=0, node_rank=0\n",
"(TorchTrainer pid=74734) - (node_id=eef984cd0cd96cce50bad1b1dab12e19c809047f10be3c829524a3d1, ip=100.83.111.228, pid=75122) world_rank=1, local_rank=1, node_rank=0\n",
"Generating train split: 0%| | 0/650000 [00:00, ? examples/s]\n",
"Generating train split: 7%|▋ | 45000/650000 [00:00<00:01, 435976.18 examples/s]\n",
"Generating train split: 15%|█▍ | 95000/650000 [00:00<00:01, 469481.51 examples/s]\n",
"Generating train split: 23%|██▎ | 150000/650000 [00:00<00:01, 477676.99 examples/s]\n",
"Generating train split: 31%|███ | 203000/650000 [00:00<00:00, 493746.70 examples/s]\n",
"Generating train split: 43%|████▎ | 279000/650000 [00:00<00:00, 499340.09 examples/s]\n",
"Generating train split: 55%|█████▍ | 355000/650000 [00:00<00:00, 498613.65 examples/s]\n",
"Generating train split: 66%|██████▋ | 431000/650000 [00:00<00:00, 497799.19 examples/s]\n",
"Generating train split: 78%|███████▊ | 506000/650000 [00:01<00:00, 495696.93 examples/s]\n",
"Generating train split: 86%|████████▌ | 556000/650000 [00:01<00:00, 494508.05 examples/s]\n",
"Generating train split: 94%|█████████▎| 609000/650000 [00:01<00:00, 490725.53 examples/s]\n",
"Generating train split: 100%|██████████| 650000/650000 [00:01<00:00, 494916.42 examples/s]\n",
"Generating test split: 0%| | 0/50000 [00:00, ? examples/s]\n",
"Generating test split: 100%|██████████| 50000/50000 [00:00<00:00, 509619.87 examples/s]\n",
"Map: 0%| | 0/1000 [00:00, ? examples/s]\n",
"Map: 100%|██████████| 1000/1000 [00:00<00:00, 3998.33 examples/s]\n",
"Map: 100%|██████████| 1000/1000 [00:00<00:00, 4051.80 examples/s]\n",
"Map: 100%|██████████| 1000/1000 [00:00<00:00, 3869.20 examples/s]\n",
"(RayTrainWorker pid=75123) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"(RayTrainWorker pid=75123) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"Map: 0%| | 0/1000 [00:00, ? examples/s] [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)\n",
"Map: 100%|██████████| 1000/1000 [00:00<00:00, 3782.66 examples/s] [repeated 2x across cluster]\n",
"(RayTrainWorker pid=75123) ============================= HABANA PT BRIDGE CONFIGURATION =========================== \n",
"(RayTrainWorker pid=75123) PT_HPU_LAZY_MODE = 1\n",
"(RayTrainWorker pid=75123) PT_HPU_RECIPE_CACHE_CONFIG = ,false,1024\n",
"(RayTrainWorker pid=75123) PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807\n",
"(RayTrainWorker pid=75123) PT_HPU_LAZY_ACC_PAR_MODE = 1\n",
"(RayTrainWorker pid=75123) PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0\n",
"(RayTrainWorker pid=75123) PT_HPU_EAGER_PIPELINE_ENABLE = 1\n",
"(RayTrainWorker pid=75123) PT_HPU_EAGER_COLLECTIVE_PIPELINE_ENABLE = 1\n",
"(RayTrainWorker pid=75123) PT_HPU_ENABLE_LAZY_COLLECTIVES = 0\n",
"(RayTrainWorker pid=75123) ---------------------------: System Configuration :---------------------------\n",
"(RayTrainWorker pid=75123) Num CPU Cores : 160\n",
"(RayTrainWorker pid=75123) CPU RAM : 1056374420 KB\n",
"(RayTrainWorker pid=75123) ------------------------------------------------------------------------------\n",
"2025-03-03 03:41:04,658 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/root/ray_results/TorchTrainer_2025-03-03_03-37-11' in 0.0020s.\n",
"\n",
"View detailed results here: /root/ray_results/TorchTrainer_2025-03-03_03-37-11\n",
"To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2025-03-03_03-37-06_983992_65223/artifacts/2025-03-03_03-37-11/TorchTrainer_2025-03-03_03-37-11/driver_artifacts`\n",
"\n",
"Training started with configuration:\n",
"╭─────────────────────────────────────────────────╮\n",
"│ Training config │\n",
"├─────────────────────────────────────────────────┤\n",
"│ train_loop_config/batch_size_per_worker 4 │\n",
"│ train_loop_config/epochs 10 │\n",
"│ train_loop_config/lr 0.001 │\n",
"╰─────────────────────────────────────────────────╯\n",
"(RayTrainWorker pid=75123) train epoch:[0] loss:1.979938\n",
"(RayTrainWorker pid=75123) train epoch:[0] loss:1.756611 [repeated 36x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[0] loss:1.643875 [repeated 180x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[0] loss:1.416416 [repeated 177x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[1] loss:1.272513 [repeated 107x across cluster]\n",
"(RayTrainWorker pid=75123) \n",
"(RayTrainWorker pid=75123) train epoch:[1] loss:2.086884 [repeated 155x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[1] loss:1.426217 [repeated 178x across cluster]\n",
"(RayTrainWorker pid=75122) train epoch:[1] loss:0.991381 [repeated 160x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[2] loss:1.294097 [repeated 28x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[2] loss:1.386306 [repeated 169x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[2] loss:1.190416 [repeated 181x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[3] loss:1.171733 [repeated 130x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[3] loss:1.287821 [repeated 152x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[3] loss:1.055692 [repeated 179x across cluster]\n",
"(RayTrainWorker pid=75122) train epoch:[3] loss:1.677789 [repeated 162x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[4] loss:0.942071 [repeated 19x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[4] loss:1.592500 [repeated 167x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[4] loss:0.936934 [repeated 180x across cluster]\n",
"(RayTrainWorker pid=75123) \n",
"(RayTrainWorker pid=75123) train epoch:[5] loss:2.465384 [repeated 141x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[5] loss:1.659170 [repeated 156x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[5] loss:1.850438 [repeated 180x across cluster]\n",
"(RayTrainWorker pid=75122) train epoch:[5] loss:1.101623 [repeated 160x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[6] loss:2.125591 [repeated 18x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[6] loss:1.612838 [repeated 170x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[6] loss:1.759160 [repeated 177x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[7] loss:1.338552 [repeated 139x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[7] loss:1.467959 [repeated 157x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[7] loss:1.682137 [repeated 181x across cluster]\n",
"(RayTrainWorker pid=75123) \n",
"(RayTrainWorker pid=75123) train epoch:[8] loss:1.395805 [repeated 162x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[8] loss:1.527835 [repeated 153x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[8] loss:1.672311 [repeated 177x across cluster]\n",
"(RayTrainWorker pid=75123) \n",
"(RayTrainWorker pid=75122) train epoch:[8] loss:1.093186 [repeated 166x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[9] loss:1.457587 [repeated 13x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[9] loss:1.727377 [repeated 171x across cluster]\n",
"(RayTrainWorker pid=75123) train epoch:[9] loss:1.694001 [repeated 182x across cluster]\n",
"\n",
"Training completed after 0 iterations at 2025-03-03 03:41:04. Total running time: 3min 53s\n",
"\n",
"Training result: Result(\n",
" metrics={},\n",
" path='/root/ray_results/TorchTrainer_2025-03-03_03-37-11/TorchTrainer_ca6cf_00000_0_2025-03-03_03-37-11',\n",
" filesystem='local',\n",
" checkpoint=None\n",
")\n",
"(RayTrainWorker pid=75122) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"(RayTrainWorker pid=75122) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"(RayTrainWorker pid=75122) train epoch:[9] loss:0.417845 [repeated 136x across cluster]\n",
"```"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.12 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orphan": true,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}