{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Prompt Templating with MLflow and Transformers\n", "\n", "Welcome to our in-depth tutorial on using prompt templates to conveniently customize the behavior of Transformers pipelines using MLflow. " ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Download this Notebook" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Learning Objectives\n", "\n", "In this tutorial, you will:\n", "\n", "- Set up a text generation pipeline using TinyLlama-1.1B as an example model\n", "- Set a prompt template that will be used to format user queries at inference time\n", "- Load the model for querying\n", "\n", "### What is a prompt template, and why use one?\n", "\n", "When dealing with large language models, the way a query is structured can significantly impact the model's performance. We often need to add some preamble, or format the query in a way that gives us the results that we want. It's not ideal to expect the end-user of our applications to know exactly what this format should be, so we typically have a pre-processing step to format the user input in a way that works best with the underlying model. In other words, we apply a prompt template to the user's input.\n", "\n", "MLflow provides a convenient way to set this on certain pipeline types using the `transformers` flavor. As of now, the only pipelines that we support are:\n", "\n", "- [feature-extraction](https://huggingface.co/transformers/main_classes/pipelines.html#transformers.FeatureExtractionPipeline)\n", "- [fill-mask](https://huggingface.co/transformers/main_classes/pipelines.html#transformers.FillMaskPipeline)\n", "- [summarization](https://huggingface.co/transformers/main_classes/pipelines.html#transformers.SummarizationPipeline)\n", "- [text2text-generation](https://huggingface.co/transformers/main_classes/pipelines.html#transformers.Text2TextGenerationPipeline)\n", "- [text-generation](https://huggingface.co/transformers/main_classes/pipelines.html#transformers.TextGenerationPipeline)\n", "\n", "\n", "If you need a runthrough of the basics of how to use the `transformers` flavor, check out the [Introductory Guide](https://mlflow.org/docs/latest/llms/transformers/guide/index.html)!\n", "\n", "Now, let's dive in and see how it's done!" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: TOKENIZERS_PARALLELISM=false\n" ] } ], "source": [ "# Disable tokenizers warnings when constructing pipelines\n", "%env TOKENIZERS_PARALLELISM=false\n", "\n", "import warnings\n", "\n", "# Disable a few less-than-useful UserWarnings from setuptools and pydantic\n", "warnings.filterwarnings(\"ignore\", category=UserWarning)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pipeline setup and inference\n", "\n", "First, let's configure our Transformers pipeline. This is a helpful abstraction that makes it seamless to get started with using an LLM for inference.\n", "\n", "For this demonstration, let's say the user's input is the phrase \"Tell me the largest bird\". Let's experiment with a few different prompt templates, and see which one we like best." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Response to Template #0:\n", "Tell me the largest bird you've ever seen.\n", "I've seen a lot of birds\n", "\n", "Response to Template #1:\n", "Q: Tell me the largest bird\n", "A: The largest bird is a pigeon.\n", "\n", "A: The largest\n", "\n", "Response to Template #2:\n", "You are an assistant that is knowledgeable about birds. If asked about the largest bird, you will reply 'Duck'.\n", "User: Tell me the largest bird\n", "Assistant: Duck\n", "User: What is the largest bird?\n", "Assistant:\n", "\n" ] } ], "source": [ "from transformers import pipeline\n", "\n", "generator = pipeline(\"text-generation\", model=\"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\")\n", "\n", "user_input = \"Tell me the largest bird\"\n", "prompt_templates = [\n", " # no template\n", " \"{prompt}\",\n", " # question-answer style template\n", " \"Q: {prompt}\\nA:\",\n", " # dialogue style template with a system prompt\n", " (\n", " \"You are an assistant that is knowledgeable about birds. \"\n", " \"If asked about the largest bird, you will reply 'Duck'.\\n\"\n", " \"User: {prompt}\\n\"\n", " \"Assistant:\"\n", " ),\n", "]\n", "\n", "responses = generator(\n", " [template.format(prompt=user_input) for template in prompt_templates], max_new_tokens=15\n", ")\n", "for idx, response in enumerate(responses):\n", " print(f\"Response to Template #{idx}:\")\n", " print(response[0][\"generated_text\"] + \"\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving the model and template with MLflow\n", "\n", "Now that we've experimented with a few prompt templates, let's pick one, and save it together with our pipeline using MLflow. Before we do this, let's take a few minutes to learn about an important component of MLflow models—signatures!\n", "\n", "### Creating a model signature\n", "\n", "A model signature codifies a model's expected inputs, outputs, and inference params. MLflow enforces this signature at inference time, and will raise a helpful exception if the user input does not match up with the expected format.\n", "\n", "Creating a signature can be done simply by calling `mlflow.models.infer_signature()`, and providing a sample input and output value. We can use `mlflow.transformers.generate_signature_output()` to easily generate a sample output. If we want to pass any additional arguments to the pipeline at inference time (e.g. `max_new_tokens` above), we can do so via `params`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024/01/16 17:28:42 WARNING mlflow.transformers: params provided to the `predict` method will override the inference configuration saved with the model. 
If the params provided are not valid for the pipeline, MlflowException will be raised.\n" ] }, { "data": { "text/plain": [ "inputs: \n", " [string (required)]\n", "outputs: \n", " [string (required)]\n", "params: \n", " ['max_new_tokens': long (default: 15)]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import mlflow\n", "\n", "sample_input = \"Tell me the largest bird\"\n", "params = {\"max_new_tokens\": 15}\n", "signature = mlflow.models.infer_signature(\n", " sample_input,\n", " mlflow.transformers.generate_signature_output(generator, sample_input, params=params),\n", " params=params,\n", ")\n", "\n", "# visualize the signature\n", "signature" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Starting a new experiment\n", "\n", "We create a new [MLflow Experiment](https://mlflow.org/docs/latest/tracking.html#experiments) so that the run we're going to log our model to does not log to the default experiment and instead has its own contextually relevant entry.\n", "\n", "### Logging the model with the prompt template\n", "\n", "Logging the model using MLflow saves the model and its essential metadata so it can be efficiently tracked and versioned. We'll use `mlflow.transformers.log_model()`, which is tailored to make this process as seamless as possible. To save the prompt template, all we have to do is pass it in using the `prompt_template` keyword argument.\n", "\n", "Two important things to take note of:\n", "\n", "1. A prompt template must be a string with exactly one named placeholder `{prompt}`. MLflow will raise an error if a prompt template is provided that does not conform to this format.\n", "\n", "2. `text-generation` pipelines with a prompt template will have the [return_full_text pipeline argument](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.inference._text_generation.TextGenerationParameters.return_full_text) set to `False` by default. This is to prevent the template from being shown to the users, which could potentially cause confusion as it was not part of their original input. To override this behavior, either set `return_full_text` to `True` via `params`, or include it in a `model_config` dict passed to `log_model()`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024/01/16 17:28:45 INFO mlflow.tracking.fluent: Experiment with name 'prompt-templating' does not exist. Creating a new experiment.\n", "2024/01/16 17:28:52 INFO mlflow.transformers: text-generation pipelines saved with prompt templates have the `return_full_text` pipeline kwarg set to False by default. To override this behavior, provide a `model_config` dict with `return_full_text` set to `True` when saving the model.\n", "2024/01/16 17:32:57 WARNING mlflow.utils.environment: Encountered an unexpected error while inferring pip requirements (model URI: /var/folders/qd/9rwd0_gd0qs65g4sdqlm51hr0000gp/T/tmpbs0poq1a/model, flavor: transformers), fall back to return ['transformers==4.34.1', 'torch==2.1.1', 'torchvision==0.16.1', 'accelerate==0.25.0']. 
Set logging level to DEBUG to see the full traceback.\n" ] } ], "source": [ "# If you are running this tutorial in local mode, leave the next line commented out.\n", "# Otherwise, uncomment the following line and set your tracking uri to your local or remote tracking server.\n", "\n", "# mlflow.set_tracking_uri(\"http://127.0.0.1:5000\")\n", "\n", "# Set an experiment name that is indicative of the runs that will be created within it\n", "mlflow.set_experiment(\"prompt-templating\")\n", "\n", "prompt_template = \"Q: {prompt}\\nA:\"\n", "with mlflow.start_run():\n", " model_info = mlflow.transformers.log_model(\n", " transformers_model=generator,\n", " artifact_path=\"model\",\n", " task=\"text-generation\",\n", " signature=signature,\n", " input_example=\"Tell me the largest bird\",\n", " prompt_template=prompt_template,\n", " # Since MLflow 2.11.0, you can save the model in 'reference-only' mode to reduce storage usage by not saving\n", " # the base model weights but only the reference to the HuggingFace model hub. To enable this, uncomment the\n", " # following line:\n", " # save_pretrained=False,\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading the model for inference\n", "\n", "Next, we can load the model using `mlflow.pyfunc.load_model()`.\n", "\n", "The `pyfunc` module in MLflow serves as a generic wrapper for Python functions. It gives us a standard interface for loading and querying models as Python functions, without having to worry about the specifics of the underlying models.\n", "\n", "Using [mlflow.pyfunc.load_model](https://www.mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.load_model), we load our previously logged text generation model via its unique model URI. This URI is a reference to the stored model artifacts. MLflow efficiently handles the model's deserialization, along with any associated dependencies, preparing it for immediate use.\n", "\n", "Now, when we call the `predict()` method on our loaded model, the user's input will be formatted with our chosen prompt template prior to inference!" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "709dd14c0bd5433e95fcbb60755f7ed0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading artifacts: 0%|          | 0/23 [00:00