{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-tuning Llama-2 Model with Intel Gaudi\n", "\n", "In this Jupyter notebook, we will:\n", "- fine-tuning a [Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) model by using Intel Gaudi accelerators with DDP method\n", "- fine-tuning a [Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) model by using Intel Gaudi accelerators with DeepSpeed method\n", "\n", "We will use PyTorch for model training and Ray for distributed training. We will use dataset [tatsu-lab/alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca).\n", "\n", "[Intel Gaudi AI Processors (HPUs)](https://habana.ai) are AI hardware accelerators designed by Habana Labs. For more information, see [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/index.html) and [Gaudi Developer Docs](https://developer.habana.ai/).\n", "\n", "Basic features for this fine-tuning example are:\n", "- Running on HPUs, support three execution mode: [\"lazy\", \"eager\", \"eager.compile\"](https://docs.habana.ai/en/latest/PyTorch/Reference/PyTorch_Gaudi_Theory_of_Operations.html).\n", "- LoRA training.\n", "- DDP or DeepSpeed based method.\n", "- [`GaudiTrainer`](https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/trainer.py) based training.\n", "- Llama-2-7b/Llama-2-70b model.\n", "- Ray based resource scheduling and management." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare environment\n", "This example run on single node with 4 HPUs.\n", "\n", "We recommend using a prebuilt container to run these examples. To run a container, you need Docker. See [Install Docker Engine](https://docs.docker.com/engine/install/) for installation instructions.\n", "\n", "Next, follow [Run Using Containers](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html?highlight=installer#run-using-containers) to install the Habana drivers and container runtime.\n", "\n", "### Get docker image\n", "``` bash\n", "docker pull vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest\n", "```\n", "### Run docker image\n", "``` bash\n", "docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.15.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest\n", "# maybe should mapping your workspace volumns\n", "```\n", "### Install dependency\n", "``` bash\n", "# \"optimum-habana>1.11.1\" if exection mode \"eager\" or \"eager.compile\" \n", "# \"ray>=2.20.0\"\n", "pip install ray[train] notebook transformers datasets evaluate peft accelerate scikit-learn optimum-habana\n", "\n", "# install deepspeed\n", "pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0\n", "\n", "# this notebook verfied with packages' version:\n", "# transformers==4.38.2\n", "# datasets==2.19.1\n", "# evaluate==0.4.2\n", "# peft==0.4.0\n", "# accelerate==0.27.2\n", "# scikit-learn==1.4.2\n", "# optimum-habana==1.11.1\n", "\n", "# deepspeed==0.12.4+hpu.synapse.v1.15.0\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import necessary libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import copy\n", "from typing import Dict\n", "\n", "import torch\n", "\n", "import datasets\n", "import transformers\n", "from transformers import 
DataCollatorForLanguageModeling\n", "\n", "from tqdm import tqdm\n", "\n", "import peft\n", "\n", "from optimum.habana import GaudiTrainer, GaudiConfig, GaudiTrainingArguments\n", "from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare Dataset Function\n", "\n", "Preprocess each line of the raw dataset into the specified prompt format." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "def preprocess_dataset(raw_datasets):\n", "\n", "    PROMPT_DICT = {\n", "        \"prompt_with_input\": (\n", "            \"Below is an instruction that describes a task, paired with an input that provides further context. \"\n", "            \"Write a response that appropriately completes the request.\\n\\n\"\n", "            \"### Instruction:\\n{instruction}\\n\\n### Input:\\n{input}\\n\\n### Response:\"\n", "        ),\n", "        \"prompt_without_input\": (\n", "            \"Below is an instruction that describes a task. \"\n", "            \"Write a response that appropriately completes the request.\\n\\n\"\n", "            \"### Instruction:\\n{instruction}\\n\\n### Response:\"\n", "        ),\n", "    }\n", "\n", "    def create_prompts(examples):\n", "        prompts = {}\n", "        prompts[\"source\"] = []\n", "        prompts[\"target\"] = []\n", "        for example in examples:\n", "            prompt_template = (\n", "                PROMPT_DICT[\"prompt_with_input\"] if example[\"input\"] != \"\" else PROMPT_DICT[\"prompt_without_input\"]\n", "            )\n", "            source = prompt_template.format_map(example)\n", "            prompts[\"source\"].append(source)\n", "            prompts[\"target\"].append(example[\"output\"])\n", "        return prompts\n", "\n", "    # Preprocessing the datasets.\n", "    for key in raw_datasets:\n", "        prompts = create_prompts(raw_datasets[key])\n", "        columns_to_be_removed = list(raw_datasets[key].features.keys())\n", "        raw_datasets[key] = raw_datasets[key].add_column(\"prompt_sources\", prompts[\"source\"])\n", "        raw_datasets[key] = raw_datasets[key].add_column(\"prompt_targets\", prompts[\"target\"])\n", "        raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset to Tokenizer Function\n", "\n", "Tokenize each line in the dataset with the model tokenizer.\n", "\n", "In the example code, we concatenate the tokenized lines to speed up training.\n", "\n", "All data is processed as the \"train\" dataset; no evaluation dataset is sampled from raw_datasets."
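, "\n", "As intuition for the concatenation step, here is a minimal sketch (plain Python with made-up token ids) of how tokenized samples are packed into fixed-length blocks, mirroring the `concatenate_data` helper defined in the next cell:\n", "\n", "``` python\n", "# illustrative only: pack tokenized samples into max_seq_length-sized blocks\n", "max_seq_length = 4\n", "samples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]\n", "flat = [tok for sample in samples for tok in sample]  # [1, 2, ..., 9]\n", "blocks = [flat[i * max_seq_length : (i + 1) * max_seq_length]\n", "          for i in range(len(flat) // max_seq_length)]\n", "# blocks == [[1, 2, 3, 4], [5, 6, 7, 8]]; the leftover tail ([9]) is dropped\n", "```"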
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "def preprocess_dataset_to_tokenizer(raw_datasets, tokenizer):\n", " max_seq_length = 512\n", " tokenizer.pad_token_id = 0\n", " tokenizer.eos_token_id = 1\n", " tokenizer.bos_token_id = 2\n", "\n", " def tokenize(prompt, add_eos_token=True):\n", " results = tokenizer(\n", " prompt,\n", " truncation=True,\n", " max_length=max_seq_length,\n", " padding=False,\n", " return_tensors=None,\n", " )\n", " for i in range(len(results[\"input_ids\"])):\n", " if (\n", " results[\"input_ids\"][i][-1] != tokenizer.eos_token_id\n", " and len(results[\"input_ids\"][i]) < max_seq_length\n", " and add_eos_token\n", " ):\n", " results[\"input_ids\"][i].append(tokenizer.eos_token_id)\n", " results[\"attention_mask\"][i].append(1)\n", "\n", " results[\"labels\"] = copy.deepcopy(results[\"input_ids\"])\n", " results[\"input_id_len\"] = [len(result) for result in results[\"input_ids\"]]\n", " return results\n", "\n", " def preprocess_function(examples):\n", " keys = list(examples.data.keys())\n", " if len(keys) != 2:\n", " raise ValueError(\"Unsupported dataset format\")\n", "\n", " st = [s + t for s, t in zip(examples[keys[0]], examples[keys[1]])]\n", "\n", " examples_tokenized = tokenize(st)\n", " input_ids = examples_tokenized[\"input_ids\"]\n", " labels = examples_tokenized[\"labels\"]\n", " return {\n", " \"input_ids\": input_ids,\n", " \"labels\": labels,\n", " \"attention_mask\": examples_tokenized[\"attention_mask\"],\n", " }\n", "\n", " tokenized_datasets = raw_datasets.map(\n", " preprocess_function,\n", " batched=True,\n", " load_from_cache_file=True,\n", " )\n", "\n", " def concatenate_data(dataset, max_seq_length):\n", " concatenated_dataset = {}\n", " for column in dataset.features:\n", " concatenated_data = [item for sample in dataset[column] for item in sample]\n", " reshaped_data = [\n", " concatenated_data[i * max_seq_length : (i + 1) * max_seq_length]\n", " for i in range(len(concatenated_data) // max_seq_length)\n", " ]\n", " concatenated_dataset[column] = reshaped_data\n", " return datasets.Dataset.from_dict(concatenated_dataset)\n", "\n", " tokenized_datasets_ = tokenized_datasets[\"train\"].remove_columns([\"prompt_sources\", \"prompt_targets\"])\n", " tokenized_datasets[\"train\"] = concatenate_data(tokenized_datasets_, max_seq_length)\n", "\n", " return tokenized_datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare training arguments\n", "\n", "here some arguments are hard coded, you can pass arguments from `config`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "def prepare_training_args(config: Dict):\n", " # prepare execution mode config\n", " execution_mode = config[\"execution_mode\"]\n", " use_lazy_mode = True if execution_mode == \"lazy\" else False\n", " torch_compile_backend = \"hpu_backend\" if execution_mode == \"eager.compile\" else None\n", "\n", " deepspeed = config[\"deepspeed\"] if \"deepspeed\" in config else None\n", "\n", " return GaudiTrainingArguments(deepspeed=deepspeed,\n", " output_dir=config[\"output\"],\n", " do_train=True,\n", " do_eval=False,\n", " per_device_train_batch_size=config[\"batch_size_per_worker\"],\n", " bf16=True,\n", " learning_rate=config[\"lr\"],\n", " save_strategy=\"no\",\n", " torch_compile_backend=torch_compile_backend,\n", " evaluation_strategy=\"no\",\n", " lr_scheduler_type=\"cosine\",\n", " num_train_epochs=config[\"epochs\"],\n", " 
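"        # HPU-specific options: use_lazy_mode follows execution_mode, use_habana targets Gaudi\n",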
"        use_lazy_mode=use_lazy_mode,\n", "        use_habana=True,\n", "        pipelining_fwd_bwd=True,\n", "        save_only_model=True,\n", "        gradient_checkpointing=True,\n", "        warmup_ratio=0.03,\n", "        throughput_warmup_steps=3,\n", "        logging_steps=5,\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare model\n", "\n", "1. Download the model from Hugging Face or read the model from a local directory.\n", "2. Convert the model to a LoRA model.\n", "3. Move the model to the HPU device.\n", "\n", "If you don't want to fine-tune with LoRA, just remove the LoRA conversion step." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "def prepare_model(config: Dict, device):\n", "    # prepare the pretrained model\n", "    deepspeed = config.get(\"deepspeed\")\n", "    if deepspeed is not None:\n", "        auto_config = transformers.AutoConfig.from_pretrained(config[\"model\"], use_cache=False, revision=\"main\", use_auth_token=None, trust_remote_code=False)\n", "        model = transformers.AutoModelForCausalLM.from_pretrained(config[\"model\"], config=auto_config, **config[\"model_config\"])\n", "        model.generation_config.attn_softmax_bf16 = True\n", "        model.generation_config.use_flash_attention = True\n", "    else:\n", "        model = transformers.AutoModelForCausalLM.from_pretrained(config[\"model\"], **config[\"model_config\"])\n", "    model.enable_input_require_grads()\n", "\n", "    # convert to a peft model for lora training\n", "    peft_config = peft.LoraConfig(**config[\"lora_config\"])\n", "    model = peft.get_peft_model(model, peft_config)\n", "\n", "    model.to(dtype=config[\"model_config\"][\"torch_dtype\"], device=device)\n", "\n", "    return model\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training Function\n", "\n", "This function will be executed by each worker during training, with the following steps:\n", "\n", "- preparing the training arguments, an instance of `GaudiTrainingArguments`.\n", "- loading and preprocessing the dataset; only the first 4096 items are loaded for training.\n", "- loading the pretrained tokenizer and tokenizing the datasets.\n", "- loading the pretrained model.\n", "- preparing the data collator and gaudi_config.\n", "- preparing an instance of `GaudiTrainer`.\n", "- calling `train()` to train the model.\n", "- saving the model results.\n", "\n", "Compared to a training function for GPU, no changes are needed to port it to HPU. Internally, Ray Train does these things:\n", "\n", "- Detect HPU and set the device.\n", "- Initialize the habana PyTorch backend.\n", "- Initialize the habana distributed backend."
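, "\n", "If you want to verify the device assignment from inside your own training function, a minimal sanity check (assuming `habana_frameworks` is installed, as in the container above) could look like this:\n", "\n", "``` python\n", "# optional, illustrative check inside train_func_per_worker\n", "import habana_frameworks.torch.hpu as hthpu  # registers the HPU backend\n", "\n", "print(hthpu.is_available())  # True when this worker sees an HPU\n", "print(hthpu.device_count())  # number of HPUs visible to this worker\n", "```"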
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "def train_func_per_worker(config: Dict):\n", " # adapt transformers to gaudi\n", " adapt_transformers_to_gaudi()\n", "\n", " # prepare training arguments\n", " training_args = prepare_training_args(config)\n", "\n", " # prepare datasets\n", " # here we use dataset \"tatsu-lab/alpaca\" from huggingface\n", " raw_datasets = datasets.DatasetDict({\"train\": datasets.load_dataset(\"tatsu-lab/alpaca\", split='train[0:4096]')})\n", " preprocess_dataset(raw_datasets)\n", "\n", " # prepare tokenizer\n", " tokenizer = transformers.AutoTokenizer.from_pretrained(config[\"model\"])\n", " tokenized_datasets = preprocess_dataset_to_tokenizer(raw_datasets, tokenizer)\n", "\n", " # prepare model\n", " model = prepare_model(config, training_args.device)\n", "\n", " # prepare data collator\n", " data_collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8, return_tensors=\"pt\", mlm=False)\n", "\n", " # prepare gaudi config\n", " gaudi_config = GaudiConfig()\n", " gaudi_config.use_fused_adam = True\n", " gaudi_config.use_fused_clip_norm = True\n", "\n", " # instance GaudiTrainer\n", " trainer = GaudiTrainer(\n", " model=model,\n", " gaudi_config=gaudi_config,\n", " args=training_args,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " eval_dataset=None,\n", " tokenizer=tokenizer,\n", " data_collator=data_collator,\n", " compute_metrics=None,\n", " preprocess_logits_for_metrics=None,\n", " )\n", "\n", " train_result = trainer.train()\n", " print(f\"train_result = {train_result}\")\n", " trainer.save_model()\n", "\n", " return train_result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Main Training Function\n", "The `train_llama` function sets up the distributed training environment using Ray and starts the training process. To enable training using HPU, we only need to make the following changes:\n", "- Set the exectuion mode for training, supported execution mode are:\n", "\n", " - \"lazy\": Deferred execution of graphs, comprising of ops delivered from script op by op similar to Eager mode. It gives the Eager mode experience with performance on Gaudi. Unlike Eager Mode with torch.compile, graph is analyzed in each iteration leading to a higher CPU usage.\n", " - \"eager\": Op-by-op execution as defined in standard PyTorch Eager mode scripts.\n", " - \"eager.compile\": Eager mode extended with `torch.compile` - Similar to Eager mode but extended with wrapping complete or part of model (such as a function) into a graph. 
Parts that are not wrapped are executed eagerly.\n", "\n", "  More detailed theory can be found [here](https://docs.habana.ai/en/latest/PyTorch/Reference/PyTorch_Gaudi_Theory_of_Operations.html), and detailed performance results can be found [here](https://developer.habana.ai/get-started/habana-models-performance/).\n", "- Set the training method. Supported methods are:\n", "    - \"ddp\"\n", "    - \"deepspeed\"\n", "- Require an HPU for each worker in ScalingConfig.\n", "- Set the backend to `hccl` in TorchConfig." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "def train_llama(num_workers, execution_mode, training_method):\n", "    import ray\n", "    from ray.train import ScalingConfig\n", "    from ray.train.torch import TorchTrainer, TorchConfig\n", "\n", "    # deepspeed config; this could also be placed in a separate config file\n", "    deepspeed_config = {\n", "        \"steps_per_print\": 64,\n", "        \"train_batch_size\": \"auto\",\n", "        \"train_micro_batch_size_per_gpu\": \"auto\",\n", "        \"gradient_accumulation_steps\": \"auto\",\n", "        \"bf16\": {\n", "            \"enabled\": True\n", "        },\n", "        \"gradient_clipping\": 1.0,\n", "        \"zero_optimization\": {\n", "            \"stage\": 3,\n", "            \"overlap_comm\": False,\n", "            \"contiguous_gradients\": False,\n", "            \"stage3_gather_16bit_weights_on_model_save\": True\n", "        }\n", "    }\n", "\n", "    # Preparing train configurations\n", "    train_config = {\n", "        \"execution_mode\": execution_mode,\n", "        \"model\": \"/root/models/models--meta-llama--Llama-2-70b-chat-hf/snapshots/e9149a12809580e8602995856f8098ce973d1080/\",\n", "        \"model_config\": {\"torch_dtype\": torch.bfloat16, \"trust_remote_code\": False, \"use_auth_token\": None},\n", "        \"lora_config\": {\"task_type\": \"CAUSAL_LM\", \"r\": 8, \"lora_alpha\": 32, \"lora_dropout\": 0.1, \"target_modules\": [\"q_proj\", \"v_proj\"]},\n", "        \"lr\": 1e-4,\n", "        \"epochs\": 2,\n", "        \"batch_size_per_worker\": 8,\n", "        \"output\": \"/tmp/ray/\",\n", "        \"deepspeed\": deepspeed_config if training_method == \"deepspeed\" else None,\n", "    }\n", "\n", "    # Configure computation resources\n", "    # In ScalingConfig, require an HPU for each worker\n", "    scaling_config = ScalingConfig(num_workers=num_workers, resources_per_worker={\"CPU\": 1, \"HPU\": 1})\n", "    # Set the backend to hccl in TorchConfig\n", "    torch_config = TorchConfig(backend=\"hccl\")\n", "\n", "    # start the Ray cluster\n", "    ray.init()\n", "\n", "    # Initialize a Ray TorchTrainer\n", "    trainer = TorchTrainer(\n", "        train_loop_per_worker=train_func_per_worker,\n", "        train_loop_config=train_config,\n", "        torch_config=torch_config,\n", "        scaling_config=scaling_config,\n", "    )\n", "\n", "    result = trainer.fit()\n", "    print(f\"Training result: {result}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Start Training\n", "\n", "Finally, we call the `train_llama` function to start the training process. You can adjust the number of workers and the execution mode for HPU."
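, "\n", "Note that the `\"model\"` entry of `train_config` points at a local snapshot directory. If you have not pre-downloaded the weights, the Hugging Face Hub id should also work (Llama-2 is a gated model, so this assumes approved access and a configured token):\n", "\n", "``` python\n", "# hypothetical alternative to the local snapshot path used in train_llama\n", "\"model\": \"meta-llama/Llama-2-70b-chat-hf\",\n", "```"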
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# set some environment variables\n", "os.environ[\"RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES\"] = \"0\"\n", "# if using RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES env var\n", "# you must set HABANA_VISIBLE_DEVICES, such as\n", "# os.environ[\"HABANA_VISIBLE_DEVICES\"] = \"0,1,2,3\"\n", "\n", "# execution_mode are [\"lazy\", \"eager\", \"eager.compile\"]\n", "execution_mode = \"lazy\"\n", "os.environ[\"PT_HPU_LAZY_MODE\"] = \"1\" if execution_mode == \"lazy\" else \"0\"\n", "\n", "# training_method are [\"ddp\", \"deepspeed\"]\n", "training_method = \"deepspeed\"\n", "if training_method == \"deepspeed\":\n", " os.environ[\"PT_HPU_MAX_COMPOUND_OP_SIZE\"] = \"10\"\n", " os.environ[\"DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED\"] = \"1\"\n", "\n", "# here use 4 HPUs\n", "train_llama(num_workers=4, execution_mode=execution_mode, training_method=training_method)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Final output\n", "\n", "### For DDP on HPUs\n", "- Llama-2-70b-chat-hf\n", "- 4 HPU\n", "- LoRA\n", "\n", "``` bash\n", "(RayTrainWorker pid=123181) {'loss': 1.8051, 'grad_norm': 0.6015625, 'learning_rate': 9.938441702975689e-05, 'epoch': 0.16, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.6754, 'grad_norm': 0.408203125, 'learning_rate': 9.567727288213005e-05, 'epoch': 0.32, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.568, 'grad_norm': 0.4453125, 'learning_rate': 8.885729807284856e-05, 'epoch': 0.48, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.4934, 'grad_norm': 0.4609375, 'learning_rate': 7.938926261462366e-05, 'epoch': 0.65, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.3965, 'grad_norm': 0.3515625, 'learning_rate': 6.7918397477265e-05, 'epoch': 0.81, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.3461, 'grad_norm': 0.34765625, 'learning_rate': 5.522642316338268e-05, 'epoch': 0.97, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.2924, 'grad_norm': 0.32421875, 'learning_rate': 4.2178276747988446e-05, 'epoch': 1.13, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.2643, 'grad_norm': 0.33203125, 'learning_rate': 2.9663167846209998e-05, 'epoch': 1.29, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.263, 'grad_norm': 0.318359375, 'learning_rate': 1.8533980447508137e-05, 'epoch': 1.45, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.2502, 'grad_norm': 0.275390625, 'learning_rate': 9.549150281252633e-06, 'epoch': 1.61, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 
1.2161, 'grad_norm': 0.2734375, 'learning_rate': 3.3209786751399187e-06, 'epoch': 1.77, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=123181) {'loss': 1.2517, 'grad_norm': 0.294921875, 'learning_rate': 2.7390523158633554e-07, 'epoch': 1.94, 'memory_allocated (GB)': 13.64, 'max_memory_allocated (GB)': 48.92, 'total_memory_available (GB)': 94.62}\n", "```\n", "\n", "### For DeepSpeed on HPUs\n", "- Llama-2-70b-chat-hf\n", "- 4 HPU\n", "- LoRA\n", "\n", "``` bash\n", "(RayTrainWorker pid=110856) {'loss': 1.6627, 'grad_norm': 0.35921376943588257, 'learning_rate': 9.938441702975689e-05, 'epoch': 0.16, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.6085, 'grad_norm': 0.35271379351615906, 'learning_rate': 9.567727288213005e-05, 'epoch': 0.32, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.5051, 'grad_norm': 0.4277978837490082, 'learning_rate': 8.885729807284856e-05, 'epoch': 0.48, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.4157, 'grad_norm': 0.5138524770736694, 'learning_rate': 7.938926261462366e-05, 'epoch': 0.65, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.3233, 'grad_norm': 0.3451262414455414, 'learning_rate': 6.7918397477265e-05, 'epoch': 0.81, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.2728, 'grad_norm': 0.38564223051071167, 'learning_rate': 5.522642316338268e-05, 'epoch': 0.97, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.1989, 'grad_norm': 0.36078131198883057, 'learning_rate': 4.2178276747988446e-05, 'epoch': 1.13, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.1552, 'grad_norm': 0.47946077585220337, 'learning_rate': 2.9663167846209998e-05, 'epoch': 1.29, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.1413, 'grad_norm': 0.3357600271701813, 'learning_rate': 1.8533980447508137e-05, 'epoch': 1.45, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.129, 'grad_norm': 0.2777070701122284, 'learning_rate': 9.549150281252633e-06, 'epoch': 1.61, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.0876, 'grad_norm': 0.25669950246810913, 'learning_rate': 3.3209786751399187e-06, 'epoch': 1.77, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "(RayTrainWorker pid=110856) {'loss': 1.1238, 'grad_norm': 0.2423330545425415, 'learning_rate': 2.7390523158633554e-07, 'epoch': 1.94, 'memory_allocated (GB)': 32.88, 'max_memory_allocated (GB)': 43.56, 'total_memory_available (GB)': 94.62}\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": 
"Python 3 (ipykernel)", "language": "python", "name": "python3" }, "orphan": true, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }