{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-Tuning Transformers with MLflow for Enhanced Model Management\n", "\n", "Welcome to our in-depth tutorial on fine-tuning Transformers models with enhanced management using MLflow." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Download the Fine Tuning Notebook" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### What You Will Learn in This Tutorial\n", "\n", "- Understand the process of fine-tuning a Transformers model.\n", "- Learn to effectively log and manage the training cycle using MLflow.\n", "- Master logging the trained model separately in MLflow.\n", "- Gain insights into using the trained model for practical inference tasks.\n", "\n", "Our approach will provide a holistic understanding of model fine-tuning and management, ensuring that you're well-equipped to handle similar tasks in your projects.\n", "\n", "#### Emphasizing Fine-Tuning\n", "Fine-tuning pre-trained models is a common practice in machine learning, especially in the field of NLP. It involves adjusting a pre-trained model to make it more suitable for a specific task. This process is essential as it allows the leveraging of pre-existing knowledge in the model, significantly improving performance on specific datasets or tasks.\n", "\n", "#### Role of MLflow in Model Lifecycle\n", "Integrating MLflow in this process is crucial for:\n", "\n", "- **Training Cycle Logging**: Keeping a detailed log of the training cycle, including parameters, metrics, and intermediate results.\n", "- **Model Logging and Management**: Separately logging the trained model, tracking its versions, and managing its lifecycle post-training.\n", "- **Inference and Deployment**: Using the logged model for inference, ensuring easy transition from training to deployment." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: TOKENIZERS_PARALLELISM=false\n" ] } ], "source": [ "# Disable tokenizers warnings when constructing pipelines\n", "%env TOKENIZERS_PARALLELISM=false\n", "\n", "import warnings\n", "\n", "# Disable a few less-than-useful UserWarnings from setuptools and pydantic\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparing the Dataset and Environment for Fine-Tuning\n", "\n", "#### Key Steps in this Section\n", "\n", "1. **Loading the Dataset**: Utilizing the `sms_spam` dataset for spam detection.\n", "2. **Splitting the Dataset**: Dividing the dataset into training and test sets with an 80/20 distribution.\n", "3. **Importing Necessary Libraries**: Including libraries like `evaluate`, `mlflow`, `numpy`, and essential components from the `transformers` library.\n", "\n", "Before diving into the fine-tuning process, setting up our environment and preparing the dataset is crucial. This step involves loading the dataset, splitting it into training and testing sets, and initializing essential components of the Transformers library. These preparatory steps lay the groundwork for an efficient fine-tuning process.\n", "\n", "This setup ensures that we have a solid foundation for fine-tuning our model, with all the necessary data and tools at our disposal. In the following Python code, we'll execute these steps to kickstart our model fine-tuning journey." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset sms_spam (/Users/benjamin.wilson/.cache/huggingface/datasets/sms_spam/plain_text/1.0.0/53f051d3b5f62d99d61792c91acefe4f1577ad3e4c216fb0ad39e30b9f20019c)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "361d4272e6144267a1566abed2cd674c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pick a name that you like and reflects the nature of the runs that you will be recording to the experiment.\n", "mlflow.set_experiment(\"Spam Classifier Training\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Starting the Training Process with MLflow\n", "\n", "In this step, we initiate the fine-tuning training run, utilizing the native auto-logging functionality to record the parameters used and loss metrics calculated during the training process.\n", " \n", "With our model, training arguments, and MLflow experiment set up, we are now ready to start the actual training process. This step involves initiating an MLflow run, which will encapsulate all the training activities and metrics.\n", "\n", "#### Initiating the MLflow Run\n", "\n", "- **Starting an MLflow Run**: We use `mlflow.start_run()` to begin a new MLflow run. This function creates a new run context, under which all the training operations and logging will occur.\n", "- **Training the Model**: Inside the MLflow run context, we call `trainer.train()` to start training our model. This function will run the training loop, processing the data in batches, updating model parameters, and evaluating the model.\n", "\n", "#### Monitoring the Training Progress\n", "During training, the `Trainer` object will output logs that provide valuable insights into the training progress:\n", "\n", "- **Loss**: Indicates the model's performance, with lower values signifying better performance.\n", "- **Learning Rate**: Shows the current learning rate used during training.\n", "- **Epoch Progress**: Displays the progress through the current epoch.\n", "\n", "These logs are crucial for monitoring the model's learning process and making any necessary adjustments. By tracking these metrics within an MLflow run, we can maintain a comprehensive record of the training process, enhancing reproducibility and analysis.\n", "\n", "In the next code block, we will start our MLflow run and begin training our model, closely observing the output to gauge the training progress." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ed21a1cc1ee4dc2bb66c903b8fccdb5", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1674 [00:00
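  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quick Sanity Check with a Pipeline\n",
    "\n",
    "With training complete, a quick way to validate the tuned model is to wrap it in a `transformers` `pipeline` and classify a sample message. This is a minimal sketch: the `tuned_pipeline` name and the sample text are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal post-training sanity check; the sample message is illustrative.\n",
    "tuned_pipeline = pipeline(\n",
    "    task=\"text-classification\",\n",
    "    model=trainer.model,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "tuned_pipeline(\"Congratulations! You've won a prize. Call now to claim it.\")"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}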