{ "cells": [ { "cell_type": "markdown", "id": "e720418d", "metadata": {}, "source": [ "# The FIL Backend for Triton: FAQs and Advanced Features\n", "\n", "## Introduction\n", "\n", "This example notebook focuses on the technical details of deploying tree-based models with the FIL Backend for Triton. It is organized as a series of FAQs followed by example code providing a practical illustration of the corresponding FAQ section.\n", "\n", "The goal of this notebook is to offer information that goes beyond the basics and provide answers to practical questions that may arise when attempting a real-world deployment with the FIL backend. If you are a complete newcomer to the FIL backend and are looking for a short introduction to the basics of what the FIL backend is and how to use it, you are encouraged to check out [this introductory notebook](https://github.com/triton-inference-server/fil_backend/blob/main/notebooks/categorical-fraud-detection/Fraud_Detection_Example.ipynb).\n", "\n", "While we do provide training code for example models, training models is *not* the subject of this notebook, and we will provide little detail on training. Instead, you are encouraged to use your own model(s) and data with this notebook to get a realistic picture of how your model will perform with Triton." ] }, { "cell_type": "markdown", "id": "e9ad97cc", "metadata": {}, "source": [ "\n", "# Table of Contents\n", "* [Introduction](#Introduction)\n", "* [Table of Contents](#Table-of-Contents)\n", "* [Hardware Pre-requisites](#Hardware-Pre-Requisites)\n", "* [Software Pre-requisites](#Software-Pre-Requisites)\n", "* [FAQ 1: What can I deploy with the FIL Backend?](#FAQ-1:-What-can-I-deploy-with-the-FIL-backend?)\n", " - [FAQ 1.1 Can I deploy non-tree Scikit-Learn models like LinearRegression?](#FAQ-1.1-Can-I-deploy-non-tree-Scikit-Learn-models-like-LinearRegression?)\n", " - [FAQ 1.2 Can I deploy Scikit-Learn/cuML Pipelines with the FIL backend?](#FAQ-1.2-Can-I-deploy-Scikit-Learn/cuML-Pipelines-with-the-FIL-backend?)\n", " - [FAQ 1.3 Can I deploy Scikit-Learn/cuML models serialized with Pickle?](#FAQ-1.3-Can-I-deploy-Scikit-Learn/cuML-models-serialized-with-Pickle?)\n", " - [FAQ 1.4 Can I deploy Scikit-Learn/cuML models serialized with Joblib?](#FAQ-1.4-Can-I-deploy-Scikit-Learn/cuML-models-serialized-with-Joblib?)\n", "* [Example 1: Model Serialization](#Example-1:-Model-Serialization)\n", " - [Example 1.1: Serializing an XGBoost model](#Example-1.1:-Serializing-an-XGBoost-model)\n", " - [Example 1.2 Serializing a LightGBM model](#Example-1.2-Serializing-a-LightGBM-model)\n", " - [Example 1.3 Serializing an in-memory Scikit-Learn model](#Example-1.3-Serializing-an-in-memory-Scikit-Learn-model)\n", " - [Example 1.4 Serializing an in-memory cuML model](#Example-1.4-Serializing-an-in-memory-cuML-model)\n", " - [Example 1.5 Converting a pickled Scikit-Learn model](#Example-1.5-Converting-a-pickled-Scikit-Learn-model)\n", " - [Example 1.6 Converting a pickled cuML model](#Example-1.5-Converting-a-pickled-Scikit-Learn-model)\n", "* [FAQ 2: How do I execute models on CPU only? 
On GPU?](#FAQ-2:-How-do-I-execute-models-on-CPU-only?-On-GPU?)\n", " - [FAQ 2.1: How do I fall back to CPU only if GPUs are not available?](#FAQ-2:-How-do-I-execute-models-on-CPU-only?-On-GPU?)\n", "* [Example 2: Generating a configuration file](#Example-2:-Generating-a-configuration-file)\n", "* [FAQ 3: How can I quickly test configuration options?](#FAQ-3:-How-can-I-quickly-test-configuration-options?)\n", "* [Example 3: Launching the Triton server with polling mode](#Example-3:-Launching-the-Triton-server-with-polling-mode)\n", "* [FAQ 4: My models are exhausting Triton's memory. What can I do?](#FAQ-4:-My-models-are-exhausting-Triton's-memory.-What-can-I-do?)\n", " - [FAQ 4.1 How can I decrease the memory consumed by a model?](#FAQ-4.1-How-can-I-decrease-the-memory-consumed-by-a-model?)\n", " - [FAQ 4.2 How do I increase Triton's device memory pool?](#FAQ-4.2-How-do-I-increase-Triton's-device-memory-pool?)\n", "* [Example 4: Configuring Triton for large models](#Example-4:-Configuring-Triton-for-large-models)\n", " - [Example 4.1: Changing `storage_type` to reduce memory consumption](#Example-4.1:-Changing-storage_type-to-reduce-memory-consumption)\n", " - [Example 4.2: Increasing Triton's device memory pool](#$\\color{#76b900}{\\text{Example-4.2:-Increasing-Triton's-device-memory-pool}}$)\n", "* [FAQ 5: How do I submit an inference request to Triton?](#FAQ-5:-How-do-I-submit-an-inference-request-to-Triton?)\n", " - [FAQ 5.1: How do I submit inference requests through Triton's C API?](#FAQ-5.1:-How-do-I-submit-inference-requests-through-Triton's-C-API?)\n", " - [FAQ 5.2: How do I submit inference requests with categorical variables?](#FAQ-5.2:-How-do-I-submit-inference-requests-with-categorical-variables?)\n", "* [Example 5: Submitting a request with the Triton Python client](#Example-5:-Submitting-a-request-with-the-Triton-Python-client)\n", "* [FAQ 6: How do I return probability scores rather than classes from a classifier?](#FAQ-6:-How-do-I-return-probability-scores-rather-than-classes-from-a-classifier?)\n", "* [Example 6: Using the `predict_proba` option](#Example-6:-Using-the-predict_proba-option)\n", "* [FAQ 7: Does serving my model with Triton change its accuracy?](#FAQ-7:-Does-serving-my-model-with-Triton-change-its-accuracy?)\n", "* [Example 7: Comparing results from Triton and native execution](#Example-7:-Comparing-results-from-Triton-and-native-execution)\n", "* [FAQ 8: How do we measure performance of the FIL backend?](#FAQ-8:-How-do-we-measure-performance-of-the-FIL-backend?)\n", "* [Example 8: Using perf_analyzer to measure throughput and latency](#Example-8:-Using-perf_analyzer-to-measure-throughput-and-latency)\n", "* [FAQ 9: How can we improve performance of models deployed with the FIL backend?](#FAQ-9:-How-can-we-improve-performance-of-models-deployed-with-the-FIL-backend?)\n", " - [FAQ 9.1: Does specifying preferred batch sizes help FIL's performance?](#FAQ-9.1:-Does-specifying-preferred-batch-sizes-help-FIL's-performance?)\n", "* [Example 9: Optimizing model performance](#Example-9:-Optimizing-model-performance)\n", " - [Example 9.1: Minimizing latency](#Example-9.1:-Minimizing-latency)\n", " - [Example 9.2: Maximizing Throughput](#Example-9.2:-Maximizing-Throughput)\n", " - [Example 9.3: Balancing latency and throughput](#Example-9.3:-Balancing-latency-and-throughput)\n", "* [FAQ 10: How fast is the FIL backend relative to alternatives?](#FAQ-10:-How-fast-is-the-FIL-backend-relative-to-alternatives?)\n", " - [FAQ 10.1 How fast is the FIL backend on CPU vs on 
GPU?](#FAQ-10.1-How-fast-is-the-FIL-backend-on-CPU-vs-on-GPU?)\n", " - [FAQ 10.2 How fast is the FIL backend relative to the ONNX backend?](#FAQ-10.2-How-fast-is-the-FIL-backend-relative-to-the-ONNX-backend?)\n", "* [Example 10: Comparing the FIL and ONNX backends](#$\\color{#76b900}{\\text{Example-10:-Comparing-the-FIL-and-ONNX-backends}}$)\n", "* [FAQ 11: How do I submit many inference requests in parallel?](#FAQ-11:-How-do-I-submit-many-inference-requests-in-parallel?)\n", "* [Example 11: Submitting requests in parallel with the Python client](#Example-11:-Submitting-requests-in-parallel-with-the-Python-client)\n", "* [FAQ 12: How do I retrieve Shapley values for model explainability?](#$\\color{#76b900}{\\text{FAQ-12:-How-do-I-retrieve-Shapley-values-for-model-explainability?}}$)\n", "* [Example 12: Retrieving Shapley Values](#$\\color{#76b900}{\\text{Example-12:-Retrieving-Shapley-Values}}$)\n", "* [FAQ 13: How do I serve a learning-to-rank model?](#FAQ-13:-How-do-I-serve-a-learning-to-rank-model?)\n", "* [Cleanup](#Cleanup)\n", "* [Conclusion](#Conclusion)" ] }, { "cell_type": "markdown", "id": "e7224487", "metadata": {}, "source": [ "# Hardware Pre-Requisites\n", "Most of this notebook is designed to run either on CPU or GPU. Sections that will only run on GPU will be marked in $\\color{#76b900}{\\text{green}}$. To guarantee that all cells will execute correctly if a GPU is not available, change `USE_GPU` in the following cell to `False`. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "58491c57", "metadata": {}, "outputs": [], "source": [ "USE_GPU = True" ] }, { "cell_type": "markdown", "id": "72b761ac", "metadata": {}, "source": [ "## Software Pre-Requisites\n", "\n", "Depending on which model framework you choose to use, you may need a different subset of dependencies. In order to install *all* dependencies with conda, you can use the following environment file:\n", "\n", "```yaml\n", "---\n", "name: triton_faq_nb\n", "channels:\n", " - conda-forge\n", " - nvidia\n", " - rapidsai\n", "dependencies:\n", " - cudatoolkit=11.4\n", " - cuml=22.04\n", " - joblib\n", " - jupyter\n", " - lightgbm\n", " - numpy\n", " - pandas\n", " - pip\n", " - python=3.8\n", " - scikit-learn\n", " - skl2onnx\n", " - treelite=2.3.0\n", " - pip:\n", " - tritonclient[all]\n", " - xgboost>=1.5,<1.6\n", " - protobuf==3.20.1\n", "```\n", "If you do not wish to install all dependencies, remove the frameworks you do not intend to use from this list. If you do not have access to an NVIDIA GPU, you should remove `cuml` and `cudatoolkit` from this list.\n", "\n", "In addition to the above dependencies, the Triton client requires that `libb64` be installed on the system, and Docker must be available to launch the Triton server. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "86510373", "metadata": {}, "source": [ "# FAQ 1: What can I deploy with the FIL backend?\n", "The first thing you will need to begin using the FIL backend is a serialized model file. The FIL backend supports **tree-based** models serialized to formats from a variety of frameworks, including the following:\n", "\n", "## XGBoost JSON and binary models\n", "XGBoost uses two serialization formats, both of which are natively supported by the FIL backend. All XGBoost models except for multi-output regression models are supported.\n", "
\n", "VERSION NOTE: Categorical variable support was added to XGBoost 1.5 as an experimental feature. The FIL backend has supported categorical variables since version 21.11.\n", "
\n", "
\n", "VERSION NOTE: The XGBoost JSON format changed in XGBoost 1.6. The first version of the FIL backend to support these JSON changes will be 22.07.\n", "
\n", "\n", "## LightGBM text models\n", "LightGBM's text serialization format is natively supported for all LightGBM model types except for multi-output regression models.\n", "\n", "
\n", "VERSION NOTE: Models trained on categorical variables have been supported since version 21.11 of the backend\n", "
\n", "\n", "## Scikit-Learn/cuML tree models and other Treelite-supported models\n", "\n", "The FIL backend supports the following model types from Scikit-Learn/cuML:\n", "- GradientBoostingClassifier\n", "- GradientBoostingRegressor\n", "- IsolationForest\n", "- RandomForestClassifier\n", "- RandomForestRegressor\n", "- ExtraTreesClassifier\n", "- ExtraTreesRegressor\n", "\n", "Since Scikit-Learn and cuML do not have native serialization formats for these models (instead relying on e.g. Pickle), we use Treelite's checkpoint format to support these models. This also means that *any* framework that can export to Treelite's checkpoint format will be supported by the FIL backend. As part of this notebook, we will provide an example of how to save a Scikit-Learn or cuML model to a Treelite checkpoint. [^](#Table-of-Contents)\n", "\n", "
\n", "\n", "VERSION NOTE: Treelite's checkpoint format provides no forward/backward compatibility guarantees. It is therefore strongly recommended that you save Scikit-Learn and cuML models to Pickle so that they can be reconverted as needed. The table below shows the version of Treelite which must be used with each version of the FIL backend.\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
FIL Backend VersionTreelite
21.081.3.0
21.09-21.102.0.0
21.11-22.022.1.0
22.03-22.062.3.0
22.07+2.4.0
\n", " \n", "
\n", "\n", "\n", "\n", "### FAQ 1.1 Can I deploy non-tree Scikit-Learn models like LinearRegression?\n", "No. The FIL backend only supports tree models and will continue to support only tree models in the future. Support for other model types may eventually be added to Triton via another backend. [^](#Table-of-Contents)\n", "\n", "### FAQ 1.2 Can I deploy Scikit-Learn/cuML Pipelines with the FIL backend?\n", "No. If you wish to create pipelines of different models in Triton, check out Triton's [Python backend](https://github.com/triton-inference-server/python_backend#python-backend), which allows users to connect models supported by other backends with arbitrary Python logic. [^](#Table-of-Contents)\n", "\n", "### FAQ 1.3 Can I deploy Scikit-Learn/cuML models serialized with Pickle?\n", "Pickle-serialized models can be converted to Treelite's checkpoint format using a script provided with the FIL Backend. This script is [documented here](https://github.com/triton-inference-server/fil_backend/blob/main/SKLearn_and_cuML.md#converting-to-treelite-checkpoints), and an example of its use will be included with this notebook. **Pickle models MUST be converted to Treelite checkpoints. They CANNOT be used directly by the FIL backend.** [^](#Table-of-Contents)\n", "\n", "### FAQ 1.4 Can I deploy Scikit-Learn/cuML models serialized with Joblib?\n", "JobLib-serialized models can be loaded in Python and serialized to Treelite checkpoints. At the moment, the conversion scripts for Pickle-serialized models do **not** work with Joblib, but support for Joblib will be added with a later version. **Joblib models MUST be converted to Treelite checkpoints. They CANNOT be used directly by the FIL backend.** [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "038cd3fa", "metadata": {}, "source": [ "# Example 1: Model Serialization\n", "\n", "In the following example code snippets, we will demonstrate how model serialization works for each of the supported model types. In the cell below, indicate the type of model you would like to use.\n", "\n", "If you are bringing your own model, please also provide the path to the serialized model. Otherwise, a model will be trained on random data in your selected format.\n", "\n", "In addition to information on where and how the model is stored, we'll use the following cell to gather a bit of metadata on the model which we'll need later on including the number of features the model expects and the number of classes it outputs. If you are using a regression model, use `1` for the number of classes." 
] }, { "cell_type": "code", "execution_count": null, "id": "977f2e89", "metadata": {}, "outputs": [], "source": [ "# Allowed values for MODEL_FORMAT are xgboost_json, xgboost_bin, lightgbm, skl_pkl, cuml_pkl, skl_joblib,\n", "# and treelite\n", "MODEL_FORMAT = 'xgboost_json'\n", "\n", "# If a path is provided to a model in the specified format, that model will be used for the following examples.\n", "# Otherwise, if MODEL_PATH is left as None, a model will be trained and stored to a default location.\n", "MODEL_PATH = None\n", "\n", "# Set this value to the number of features (columns) in your dataset\n", "NUM_FEATURES = 32\n", "\n", "# Set this value to the number of possible output classes or 1 for regression models\n", "NUM_CLASSES = 2\n", "\n", "# Set this value to False if you are using your own regression model\n", "IS_CLASSIFIER = True" ] }, { "cell_type": "markdown", "id": "c62c5b43", "metadata": {}, "source": [ "## Model Training/Loading\n", "\n", "In this section, if a model path has been provided, we will load the model so that we can compare its output to what we get from Triton later in the notebook. If a model path has **not** been provided, a model of the indicated type will be trained and serialized to a default location. We will not provide detail or commentary on training, since this is not the focus of this notebook. Consult documentation or examples for your chosen framework if you would like to learn more about the training process." ] }, { "cell_type": "code", "execution_count": null, "id": "49760159", "metadata": {}, "outputs": [], "source": [ "RANDOM_SEED=0" ] }, { "cell_type": "code", "execution_count": null, "id": "2e483d05", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn.datasets import make_classification\n", "# Create random dataset. 
Even if we do not use this dataset for training, we will use it for testing later.\n", "# If you would like to use a real dataset, load it here into X and y Pandas dataframes\n", "X, y = make_classification(\n", " n_samples=5000,\n", " n_features=NUM_FEATURES,\n", " n_informative=max(NUM_FEATURES // 3, 1),\n", " n_classes=NUM_CLASSES,\n", " random_state=RANDOM_SEED\n", ")\n", "X = pd.DataFrame(X)\n", "y = pd.DataFrame(y)" ] }, { "cell_type": "code", "execution_count": null, "id": "03943d23", "metadata": {}, "outputs": [], "source": [ "# Set model parameters for any models we need to train\n", "NUM_TREES = 500\n", "MAX_DEPTH = 10" ] }, { "cell_type": "code", "execution_count": null, "id": "a8a2a0d6", "metadata": {}, "outputs": [], "source": [ "model = None\n", "# XGBoost\n", "def train_xgboost(X, y, n_trees=NUM_TREES, max_depth=MAX_DEPTH):\n", " import xgboost as xgb\n", " \n", " if USE_GPU:\n", " tree_method = 'gpu_hist'\n", " predictor = 'gpu_predictor'\n", " else:\n", " tree_method = 'hist'\n", " predictor = 'cpu_predictor'\n", " \n", " model = xgb.XGBClassifier(\n", " eval_metric='error',\n", " objective='binary:logistic',\n", " tree_method=tree_method,\n", " max_depth=max_depth,\n", " n_estimators=n_trees,\n", " use_label_encoder=False,\n", " predictor=predictor\n", " )\n", " \n", " return model.fit(X, y)\n", "\n", "def train_lightgbm(X, y, n_trees=NUM_TREES, max_depth=MAX_DEPTH):\n", " import lightgbm as lgb\n", " \n", " lgb_data = lgb.Dataset(X, y)\n", " \n", " if classes <= 2:\n", " classes = 1\n", " objective = 'binary'\n", " metric = 'binary_logloss'\n", " else:\n", " objective = 'multiclass'\n", " metric = 'multi_logloss'\n", " training_params = {\n", " 'metric': metric,\n", " 'objective': objective,\n", " 'num_class': NUM_CLASSES,\n", " 'max_depth': max_depth,\n", " 'verbose': -1\n", " }\n", " return lgb.train(training_params, lgb_data, n_trees)\n", "\n", "def train_cuml(X, y, n_trees=NUM_TREES, max_depth=MAX_DEPTH):\n", " from cuml.ensemble import RandomForestClassifier\n", " model = RandomForestClassifier(\n", " max_depth=max_depth, n_estimators=n_trees, random_state=RANDOM_SEED\n", " )\n", " return model.fit(X, y)\n", "\n", "def train_skl(X, y, n_trees=NUM_TREES, max_depth=MAX_DEPTH):\n", " from sklearn.ensemble import RandomForestClassifier\n", " model = RandomForestClassifier(\n", " max_depth=max_depth, n_estimators=n_trees, random_state=RANDOM_SEED\n", " )\n", " return model.fit(X, y)\n", " \n", " \n", "if MODEL_FORMAT in ('xgboost_json', 'xgboost_bin'):\n", " if MODEL_PATH is not None:\n", " # Load model just as a reference for later\n", " import xgboost as xgb\n", " model = xgb.Booster()\n", " model.load_model(MODEL_PATH)\n", " print('Congratulations! Your model is already in a natively-supported format')\n", " else:\n", " model = train_xgboost(X, y)\n", "elif MODEL_FORMAT == 'lightgbm':\n", " if MODEL_PATH is not None:\n", " # Load model just as a reference for later\n", " import lightgbm as lgb\n", " model = lgb.Booster(model_file=MODEL_PATH)\n", " print('Congratulations! 
Your model is already in a natively-supported format')\n", " else:\n", " model = train_lightgbm(X, y)\n", "elif MODEL_FORMAT in ('cuml_pkl', 'skl_pkl', 'cuml_joblib', 'skl_joblib'):\n", " if MODEL_PATH is not None:\n", " if MODEL_FORMAT in ('cuml_pkl', 'skl_pkl'):\n", " # Load model just as a reference for later\n", " import pickle\n", " model = pickle.load(MODEL_PATH)\n", " print(\n", " \"While pickle files are not natively supported, we will use a script to\"\n", " \" convert your model to a Treelite checkpoint later in this notebook.\"\n", " )\n", " else:\n", " print(\"Loading model from joblib file in order to convert it to Treelite checkpoint...\")\n", " import joblib\n", " model = joblib.load(MODEL_PATH)\n", " elif MODEL_FORMAT.startswith('cuml'):\n", " model = train_cuml(X, y)\n", " elif MODEL_FORMAT.startswith('skl'):\n", " model = train_skl(X, y)" ] }, { "cell_type": "markdown", "id": "11b5b9e4", "metadata": {}, "source": [ "## The Model Repository\n", "Triton expects models to be stored in a specific directory structure. We will go ahead and create this directory structure now and serialize our models directly into the final directory or copy the serialized model there if the trained model was provided.\n", "\n", "Each model requires a configuration file stored in `$MODEL_REPO/$MODEL_NAME/config.pbtxt`, and a model file stored in `$MODEL_REPO/$MODEL_NAME/$MODEL_VERSION/$MODEL_FILENAME`. Note that Triton supports storing multiple versions of a model directories with different `$MODEL_VERSION` numbers starting from `1`. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "39eb4a7a", "metadata": {}, "outputs": [], "source": [ "import os\n", "import shutil\n", "\n", "MODEL_NAME = 'example_model'\n", "MODEL_VERSION = 1\n", "MODEL_REPO = os.path.abspath('data/model_repository')\n", "MODEL_DIR = os.path.join(MODEL_REPO, MODEL_NAME)\n", "VERSIONED_DIR = os.path.join(MODEL_DIR, str(MODEL_VERSION))\n", "\n", "os.makedirs(VERSIONED_DIR, exist_ok=True)\n", "\n", "# We will use the following variables to record information from the serialization\n", "# process that we will require later\n", "model_path = None\n", "model_format = None" ] }, { "cell_type": "markdown", "id": "e9b396b3", "metadata": {}, "source": [ "## Example 1.1: Serializing an XGBoost model\n", " [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "54bb67ad", "metadata": {}, "outputs": [], "source": [ "if MODEL_FORMAT == 'xgboost_json':\n", " # This is the default filename expected for XGBoost JSON models. It is recommended\n", " # that you stick with the default to avoid additional configuration.\n", " model_basename = 'xgboost.json'\n", " model_path = os.path.join(VERSIONED_DIR, model_basename)\n", " \n", " model_format = 'xgboost_json'\n", "elif MODEL_FORMAT == 'xgboost_bin':\n", " # This is the default filename expected for XGBoost binary models. 
It is recommended\n", " # that you stick with the default to avoid additional configuration.\n", " model_basename = 'xgboost.model'\n", " model_path = os.path.join(VERSIONED_DIR, model_basename)\n", " \n", " # This is the format name Triton uses to indicate XGBoost binary models\n", " model_format = 'xgboost'\n", "\n", "if MODEL_FORMAT.startswith('xgboost'):\n", " if MODEL_PATH is not None: # Just need to copy existing file...\n", " shutil.copy(MODEL_PATH, model_path)\n", " else:\n", " model.save_model(model_path) # XGB derives format from extension" ] }, { "cell_type": "markdown", "id": "736d87bb", "metadata": {}, "source": [ "## Example 1.2 Serializing a LightGBM model\n", " [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "32c82643", "metadata": {}, "outputs": [], "source": [ "if MODEL_FORMAT == 'lightgbm':\n", " # This is the default filename expected for LightGBM text models. It is recommended\n", " # that you stick with the default to avoid additional configuration.\n", " model_basename = 'model.txt'\n", " model_path = os.path.join(VERSIONED_DIR, model_basename)\n", " \n", " model_format = MODEL_FORMAT\n", " \n", " if MODEL_PATH is not None: # Just need to copy existing file...\n", " shutil.copy(MODEL_PATH, model_path)\n", " else:\n", " model.save_model(model_path) " ] }, { "cell_type": "markdown", "id": "acb7eb0d", "metadata": {}, "source": [ "## Example 1.3 Serializing an in-memory Scikit-Learn model\n", "The following will show how to serialize a SKL model from Python directly to a Treelite checkpoint format. This could be a model that you have just trained or a model that you have e.g. loaded from Joblib. Again it is strongly recommended that you **save trained models in Pickle/Joblib as well as Treelite** since Treelite provides no compatibility guarantees between versions. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "110924ce", "metadata": {}, "outputs": [], "source": [ "if model is not None and MODEL_FORMAT.startswith('skl'):\n", " import pickle\n", " archival_path = os.path.join(VERSIONED_DIR, 'model.pkl')\n", " pickle.dump(model, archival_path) # Create archival pickled version\n", " \n", " # This is the default filename expected for Treelite checkpoint models. It is recommended\n", " # that you stick with the default to avoid additional configuration.\n", " model_basename = 'checkpoint.tl'\n", " model_path = os.path.join(VERSIONED_DIR, model_basename)\n", " \n", " model_format = 'treelite_checkpoint'\n", " \n", " import treelite\n", " tl_model = treelite.sklearn.import_model(model)\n", " tl_model.serialize(model_path)" ] }, { "cell_type": "markdown", "id": "6e70cc57", "metadata": {}, "source": [ "## Example 1.4 Serializing an in-memory cuML model\n", "The following will show how to serialize a cuML model from Python directly to a Treelite checkpoint format. This could be a model that you have just trained or a model that you have e.g. loaded from Joblib. Again it is strongly recommended that you **save trained models in Pickle/Joblib as well as Treelite** since Treelite provides no compatibility guarantees between versions. 
[^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "16c0fb1b", "metadata": {}, "outputs": [], "source": [ "if model is not None and MODEL_FORMAT.startswith('cuml'):\n", " import pickle\n", " archival_path = os.path.join(VERSIONED_DIR, 'model.pkl')\n", " pickle.dump(model, archival_path) # Create archival pickled version\n", " \n", " # This is the default filename expected for Treelite checkpoint models. It is recommended\n", " # that you stick with the default to avoid additional configuration.\n", " model_basename = 'checkpoint.tl'\n", " model_path = os.path.join(VERSIONED_DIR, model_basename)\n", " \n", " model_format = 'treelite_checkpoint'\n", " \n", " model.convert_to_treelite_model().to_treelite_checkpoint(model_path)" ] }, { "cell_type": "markdown", "id": "7b4a73ee", "metadata": {}, "source": [ "## Example 1.5 Converting a pickled Scikit-Learn model\n", "For convenience, the FIL backend provides a script which can be used to convert a pickle file containing a Scikit-Learn model directly to a Treelite checkpoint file. If you do not have access to that script or prefer to work directly from Python, you can always load the pickled model into memory and then serialize it as in [Example 1.3](#example_1.3). [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "7bbf2e69", "metadata": {}, "outputs": [], "source": [ "if MODEL_PATH is not None and MODEL_FORMAT == 'skl_pkl':\n", " archival_path = os.path.join(VERSIONED_DIR, 'model.pkl')\n", " shutil.copy(MODEL_PATH, archival_path)\n", " \n", " !../../scripts/convert_sklearn {archival_path}" ] }, { "cell_type": "markdown", "id": "adadf2ce", "metadata": {}, "source": [ "## Example 1.6 Converting a pickled cuML model\n", "For convenience, the FIL backend provides a script which can be used to convert a pickle file containing a cuML model directly to a Treelite checkpoint file. If you do not have access to that script or prefer to work directly from Python, you can always load the pickled model into memory and then serialize it as in [Example 1.4](#example_1.4). [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "b22bfb7e", "metadata": {}, "outputs": [], "source": [ "if MODEL_PATH is not None and MODEL_FORMAT == 'cuml_pkl':\n", " archival_path = os.path.join(VERSIONED_DIR, 'model.pkl')\n", " shutil.copy(MODEL_PATH, archival_path)\n", " \n", " !python ../../scripts/convert_cuml.py {archival_path}" ] }, { "cell_type": "markdown", "id": "801f5015", "metadata": {}, "source": [ "# FAQ 2: How do I execute models on CPU only? On GPU?\n", "\n", "In addition to a serialized model file, you must provide a `config.pbtxt` configuration file for each model you wish to serve with the FIL backend for Triton. Within that file, it is possible to specify whether a model will run on CPU or GPU and how many instances of the model you wish to serve. 
For example, adding the following entry to the configuration file will create one instance of the model for each available GPU and run those instances each on their own dedicated GPU:\n", "\n", "```pbtxt\n", " instance_group [{ kind: KIND_GPU }]\n", "```\n", "\n", "If you wish to instead run exactly three instances on CPU, the following entry can be used:\n", "```pbtxt\n", " instance_group [\n", " {\n", " count: 3\n", " kind: KIND_CPU\n", " }\n", " ]\n", "```\n", "\n", "In the following example, we will create a configuration file that can be used to serve your model on either CPU or GPU depending on the value of the `USE_GPU` flag set earlier in this notebook. [^](#Table-of-Contents)\n", "\n", "
\n", "VERSION NOTE: CPU execution was introduced with version 21.07 of the FIL backend.\n", "
\n", "\n", "## FAQ 2.1: How do I fall back to CPU only if GPUs are not available?\n", "In addition to `KIND_GPU` and `KIND_CPU`, Triton offers a `KIND_AUTO`. To make use of this option, you must specify which GPUs you wish to make use of. If *any* of those GPUs are unavailable, Triton will fall back to CPU execution. An example of such a configuration is shown below: [^](#Table-of-Contents)\n", "```pbtxt\n", " instance_group [\n", " {\n", " gpus: [ 0, 1 ]\n", " kind: KIND_AUTO\n", " }\n", " ]\n", "```" ] }, { "cell_type": "markdown", "id": "44c067fd", "metadata": {}, "source": [ "# Example 2: Generating a configuration file\n", "\n", "Based on the information provided about your model in previous cells, we can now construct a `config.pbtxt` that can be used to run that model on Triton. We will generate the configuration text and save it to the appropriate location.\n", "\n", "For full information on configuration options, check out the FIL backend [documentation](https://github.com/triton-inference-server/fil_backend#configuration). For a detailed example of configuration file construction, you can also check out the [introductory notebook](https://nbviewer.org/github/triton-inference-server/fil_backend/blob/main/notebooks/categorical-fraud-detection/Fraud_Detection_Example.ipynb#The-Configuration-File). [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "a6c1be72", "metadata": {}, "outputs": [], "source": [ "# Maximum size in bytes for input and output arrays\n", "MAX_MEMORY_BYTES = 60_000_000\n", "bytes_per_sample = (NUM_FEATURES + NUM_CLASSES) * 4\n", "max_batch_size = MAX_MEMORY_BYTES // bytes_per_sample\n", "\n", "# Select deployment hardware (GPU or CPU)\n", "if USE_GPU:\n", " instance_kind = 'KIND_GPU'\n", "else:\n", " instance_kind = 'KIND_CPU'\n", " \n", "if IS_CLASSIFIER:\n", " classifier_string = 'false'\n", "else:\n", " classifier_string = 'true'\n", "\n", "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"{model_format}\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"false\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"{classifier_string}\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"AUTO\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", "config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", "with open(config_path, 'w') as file_:\n", " file_.write(config_text)" ] }, { "cell_type": "markdown", "id": "e051c189", "metadata": {}, "source": [ "# FAQ 3: How can I quickly test configuration options?\n", "Sometimes it is useful to be able to quickly iterate on the options available in the `config.pbtxt` file for your model. While it is not recommended for production deployments, Triton offers a \"polling\" mode which will automatically reload models when their configurations change. To use this option, launch the server with the `--model-control-mode=poll` flag. 
After changing the configuration, wait a few seconds for the model to reload, and then Triton will be ready to handle requests with the new configuration. [^](#Table-of-Contents)\n", "\n", "# Example 3: Launching the Triton server with polling mode\n", "In the following cell, we will launch the server with the model repository we set up previously in this notebook. We will use pulling mode in order to allow us to tweak the configuration file and observe the impact of our changes. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "193612d0", "metadata": {}, "outputs": [], "source": [ "TRITON_IMAGE = 'nvcr.io/nvidia/tritonserver:22.05-py3'\n", "!docker run --gpus all -d -p 8000:8000 -p 8001:8001 -p 8002:8002 -v {MODEL_REPO}:/models --name tritonserver {TRITON_IMAGE} tritonserver --model-repository=/models --model-control-mode=poll" ] }, { "cell_type": "code", "execution_count": null, "id": "ccbcce1b", "metadata": {}, "outputs": [], "source": [ "import time\n", "time.sleep(10) # Wait for server to come up\n", "!docker logs tritonserver" ] }, { "cell_type": "markdown", "id": "e543928f", "metadata": {}, "source": [ "In later sections, we'll take advantage of polling mode to make tweaks to our configuration and observe their impact." ] }, { "cell_type": "markdown", "id": "8c4f1883", "metadata": {}, "source": [ "# FAQ 4: My models are exhausting Triton's memory. What can I do?\n", "Tree-based models tend to have fairly modest memory needs, but when using several models together or when trying to process very large batches, you'll sometimes run into memory constraints.\n", "\n", "For models deployed on GPU, Triton allocates a device memory pool when it is launched, and the FIL backend only allocates device memory from this pool, so one option is to simply increase the size of the pool until you reach hardware limits on available memory.\n", "\n", "For models deployed on CPU or for deployments which have exceeded the hardware limits on available device memory, you may wish to instead reduce the memory consumption of models by tweaking configuration options. Let's look at this case first to showcase the polling mode introduced in the previous example. [^](#Table-of-Contents)\n", "## FAQ 4.1 How can I decrease the memory consumed by a model?\n", "There are two primary ways to reduce memory consumption by a FIL model:\n", "- Reducing max batch size\n", "- Changing the `storage_type` option to use a low-memory model representation\n", "We will take a look at both of these options in the next example.\n", "\n", "Note that reducing `max_batch_size` in `config.pbtxt` will only impact memory consumption if your model is receiving large enough batches to run up against this limit. Keep in mind that with Triton's dynamic batching feature, there is a distinction between *client* batch size and *server* batch size. With dynamic batching enabled, several small batches received from one or more clients can be combined into a larger server batch. The `max_batch_size` configuration option sets the maximum *server* batch that your model will receive.\n", "\n", "For most deployments, the more consistent way to reduce model memory usage is to change the `storage_type` option. FIL gives models an in-memory representation of type `DENSE`, `SPARSE`, or `SPARSE8`, in order of progressively less memory usage. Note that using a `SPARSE` or `SPARSE8` representation can reduce a model's runtime performance and that some models will fail to load as `SPARSE8`. 
In general, `SPARSE8` should be avoided unless absolutely necessary for memory management.\n", "\n", "By setting `storage_type` to `AUTO`, you allow FIL to select the memory representation that is likely to offer the best mix of runtime performance and memory footprint, but if memory consumption is a problem, explicitly setting `storage_type` to `SPARSE` will ensure that a sparse representation is used. Alternatively, if you wish to maximize runtime performance regardless of memory considerations, you may try explicitly using a `DENSE` representation. In Example 4.1, we will explicitly set the `storage_type` to `SPARSE` and then back to `AUTO`.\n", "\n", "Note that you will not be able to observe a reduction in system GPU memory consumption (e.g. by using `nvidia-smi`) when adjusting these options, since all device memory allocations for FIL models come from Triton's pre-allocated memory pool. [^](#Table-of-Contents)\n", "\n", "## FAQ 4.2 How do I increase Triton's device memory pool?\n", "By default, Triton allocates a device memory pool of 67,108,864 bytes on each GPU. You can adjust this value on server startup with the `--cuda-memory-pool-byte-size` flag. Note that this flag takes arguments in the form `--cuda-memory-pool-byte-size=$GPU_ID:$BYTE_SIZE` and must be repeated for each GPU. In Example 4.2, we will take down the server and restart it with this option. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "14619b97", "metadata": {}, "source": [ "# Example 4: Configuring Triton for large models\n", "\n", "## Example 4.1: Changing `storage_type` to reduce memory consumption\n", "In the following example, we generate a new configuration with `storage_type` set to `SPARSE` and write it out to `config.pbtxt`. Since we have turned on Triton's polling mode, this configuration change will automatically be picked up. 
[^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "1f77a8e9", "metadata": {}, "outputs": [], "source": [ "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"{model_format}\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"false\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"{classifier_string}\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"SPARSE\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", "config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", "with open(config_path, 'w') as file_:\n", " file_.write(config_text)" ] }, { "cell_type": "code", "execution_count": null, "id": "b1488aa6", "metadata": {}, "outputs": [], "source": [ "time.sleep(10) # Wait for configuration change to be processed\n", "!docker logs tritonserver\n", "# In the logs, we should see that example_model has been loaded and unloaded successfully\n", "# with the new configuration" ] }, { "cell_type": "markdown", "id": "8f9cdc14", "metadata": {}, "source": [ "## $\\color{#76b900}{\\text{Example 4.2: Increasing Triton's device memory pool}}$\n", "Now, let's take the server down entirely and bring it back up with a larger device memory pool. In the following example, specify the list of GPUs you wish to use and your desired memory pool size. Keep in mind the hardware limits of your system when specifying the pool size. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "256002d3", "metadata": {}, "outputs": [], "source": [ "!docker rm -f tritonserver\n", "time.sleep(10) # Wait for server to come down" ] }, { "cell_type": "code", "execution_count": null, "id": "43de9d5b", "metadata": {}, "outputs": [], "source": [ "POOL_SIZE_BYTES = 100_663_296 # Set based on available device memory\n", "GPU_LIST = [0, 1] # Set based on available device IDs" ] }, { "cell_type": "code", "execution_count": null, "id": "ac2392cf", "metadata": {}, "outputs": [], "source": [ "pool_flags = ' '.join([f'--cuda-memory-pool-byte-size={device}:{POOL_SIZE_BYTES}' for device in GPU_LIST])\n", "print(pool_flags)" ] }, { "cell_type": "code", "execution_count": null, "id": "20996b57", "metadata": {}, "outputs": [], "source": [ "!docker run --gpus all -d -p 8000:8000 -p 8001:8001 -p 8002:8002 -v {MODEL_REPO}:/models --name tritonserver {TRITON_IMAGE} tritonserver --model-repository=/models --model-control-mode=poll {pool_flags}" ] }, { "cell_type": "code", "execution_count": null, "id": "883c63b6", "metadata": {}, "outputs": [], "source": [ "time.sleep(10) # Wait for server to come up\n", "!docker logs tritonserver" ] }, { "cell_type": "markdown", "id": "554b1df1", "metadata": {}, "source": [ "# FAQ 5: How do I submit an inference request to Triton?\n", "Triton supports both GRPC and HTTP requests, and you can use any GRPC/HTTP client to submit requests to Triton. 
In practice, however, it is often difficult to construct correct inference requests using a generic client, so Triton provides Python, Java, and C++ [clients](https://github.com/triton-inference-server/client#client-library-apis) to make it easier to interact with the Triton server from each of those languages. For other languages, including [Go](https://github.com/triton-inference-server/client/tree/main/src/grpc_generated/go), [Scala](https://github.com/triton-inference-server/client/tree/main/src/grpc_generated/java), and [JavaScript](https://github.com/triton-inference-server/client/tree/main/src/grpc_generated/javascript), Triton provides a protoc compiler that can generate GRPC APIs for inclusion in your application. In this notebook, we will take a look solely at how to make use of the Python client for submitting inference requests. [^](#Table-of-Contents)\n", "\n", "## FAQ 5.1: How do I submit inference requests through Triton's C API?\n", "In addition to its use as an HTTP/GRPC server, Triton can be used as a library within another application using its C API. Example code for this use case with the FIL backend is under development and will be linked to once complete.\n", " [^](#Table-of-Contents)\n", "## FAQ 5.2: How do I submit inference requests with categorical variables?\n", "Triton accepts input as an array of floating-point values, rather than e.g. a dataframe. In order to accept categorical values, we must convert them to floating-point representations of the categorical codes used in training.\n", "\n", "Just as with most frameworks that accept categorical variables, a certain amount of care must be taken to ensure that the same categorical codes are used during inference as during training. For instance, if you trained your model using a dataframe with ten possible categorical values in a particular column, but the data you submit at inference time has only two of the ten possible values, Pandas/cuDF may assign different categorical codes from those used during training, leading to incorrect results. Therefore, it is important to keep some record of the categories used during training to correctly construct an input array. In this example, we will demonstrate how to convert categorical values to their floating point representation using the categorical codes provided by Pandas/cuDF. [^](#Table-of-Contents)\n", "\n", "## FAQ 5.3: Should I use any name besides `input__0` and `output__0` for inputs/outputs to the FIL backend?\n", "**No.** Unless you are using the FIL backend's Shapley value output feature (described later), the sole input array should *always* be named `input__0`, and the sole output array should *always* be named `output__0`. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "2186f89b", "metadata": {}, "source": [ "# Example 5: Submitting a request with the Triton Python client\n", "In the following example, we first use the Triton Python client to ensure that the server is live and the model is ready to accept requests. We then submit a single request and view the result as a numpy array. Next, we submit a larger batch of requests and examine the results. 
[^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "f75ecbdd", "metadata": {}, "outputs": [], "source": [ "# Create GRPC client instance\n", "import time\n", "import tritonclient.grpc as triton_grpc\n", "from tritonclient import utils as triton_utils\n", "\n", "# If you are running your Triton server on a remote host or a non-standard port, adjust the\n", "# following settings\n", "HOST = 'localhost'\n", "PORT = 8001\n", "\n", "# TIMEOUT sets the time in seconds that we will wait for the server to be ready before giving\n", "# up\n", "TIMEOUT = 60\n", "\n", "client = triton_grpc.InferenceServerClient(url=f'{HOST}:{PORT}')" ] }, { "cell_type": "code", "execution_count": null, "id": "83004e16", "metadata": {}, "outputs": [], "source": [ "# Check to see if server is live and model is loaded\n", "def is_triton_ready():\n", " server_start = time.time()\n", " while True:\n", " try:\n", " if client.is_server_ready() and client.is_model_ready(MODEL_NAME):\n", " return True\n", " except triton_utils.InferenceServerException:\n", " pass\n", " if time.time() - server_start > TIMEOUT:\n", " print('Server was not ready before given timeout. Check the logs below for possible issues.')\n", " !docker logs tritonserver\n", " return False\n", " time.sleep(1)" ] }, { "cell_type": "code", "execution_count": null, "id": "37abe292", "metadata": {}, "outputs": [], "source": [ "# Convert a dataframe to a numpy array including conversion of categorical variables\n", "def convert_to_numpy(df):\n", " df = df.copy()\n", " cat_cols = df.select_dtypes('category').columns\n", " for col in cat_cols:\n", " df[col] = df[col].cat.codes\n", " for col in df.columns:\n", " df[col] = pd.to_numeric(df[col], downcast='float')\n", " return df.values" ] }, { "cell_type": "code", "execution_count": null, "id": "16a73c41", "metadata": {}, "outputs": [], "source": [ "single_sample = convert_to_numpy(X[0:1])\n", "print(single_sample)" ] }, { "cell_type": "code", "execution_count": null, "id": "58aa08b6", "metadata": {}, "outputs": [], "source": [ "def triton_predict(model_name, arr):\n", " triton_input = triton_grpc.InferInput('input__0', arr.shape, 'FP32')\n", " triton_input.set_data_from_numpy(arr)\n", " triton_output = triton_grpc.InferRequestedOutput('output__0')\n", " response = client.infer(model_name, model_version='1', inputs=[triton_input], outputs=[triton_output])\n", " return response.as_numpy('output__0')" ] }, { "cell_type": "code", "execution_count": null, "id": "77d85efb", "metadata": {}, "outputs": [], "source": [ "# Perform inference on a single sample (row)\n", "if is_triton_ready():\n", " triton_result = triton_predict(MODEL_NAME, single_sample)\n", " print(triton_result)" ] }, { "cell_type": "code", "execution_count": null, "id": "3a7cf54e", "metadata": {}, "outputs": [], "source": [ "# Perform inference on a batch\n", "if is_triton_ready():\n", " batch = convert_to_numpy(X[0:min(len(X), max_batch_size, 100)])\n", " triton_result = triton_predict(MODEL_NAME, batch)\n", " print(triton_result)" ] }, { "cell_type": "markdown", "id": "47088c39", "metadata": {}, "source": [ "# FAQ 6: How do I return probability scores rather than classes from a classifier?\n", "To return confidence scores from 0 to 1 rather than just a final output class, we set the `predict_proba` option to `true` in `config.pbtxt`. Note that for multi-class classifiers, we also need to change the output dimensions, since a confidence score will be generated for each possible class in our output. 
For single-class classifiers, the output dimension is still 1, since we return the confidence score only for the positive class. In the following example, we adjust these options and perform inference on a single sample once again to demonstrate the change in output. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "95121614", "metadata": {}, "source": [ "# Example 6: Using the `predict_proba` option\n", "
\n", " Note: This example is not relevant for regression models.\n", "
\n", "\n", "In the following example, we will write out a new configuration file with `predict_proba` set to true and the correct output dimensions. We will wait for the change to be picked up by the server's polling feature and then submit an inference request to the model as we did before. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "ddedb01f", "metadata": {}, "outputs": [], "source": [ "if NUM_CLASSES <= 2:\n", " output_dim = 1\n", "else:\n", " output_dim = NUM_CLASSES\n", "\n", "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {output_dim} ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"{model_format}\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"true\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"{classifier_string}\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"AUTO\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", "if IS_CLASSIFIER:\n", " config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", " with open(config_path, 'w') as file_:\n", " file_.write(config_text)" ] }, { "cell_type": "code", "execution_count": null, "id": "52a29a41", "metadata": {}, "outputs": [], "source": [ "time.sleep(10) # Wait for server polling to reload the changed config\n", "if is_triton_ready():\n", " batch = convert_to_numpy(X[0:min(len(X), max_batch_size, 100)])\n", " triton_result = triton_predict(MODEL_NAME, batch)\n", " print(triton_result)" ] }, { "cell_type": "markdown", "id": "9b882395", "metadata": {}, "source": [ "# FAQ 7: Does serving my model with Triton change its accuracy?\n", "In general, no. Models served with Triton's FIL backend are loaded and executed according to the exact same rules as their native frameworks. However, inference with FIL is subject to floating point error, and the highly-parallel nature of FIL means that there may be slight differences in results between inference runs. If a confidence score is close to the decision threshold, this may even impact the output class reported for the model, so it is recommended to set `predict_proba` on if you are comparing model output on Triton vs. local execution.\n", "\n", "Additionally, LightGBM allows training and execution of models with double precision, while FIL currently uses single precision for all models. An update is under development to allow double precision execution, but for now, this may lead to slight differences in output between double precision models in their native framework and the same model loaded in the FIL backend.\n", "\n", "Regardless, these differences tend to be very marginal, and even with these caveats you can expect approximately the same accuracy for your model when served with Triton. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "2c5c4c32", "metadata": {}, "source": [ "# Example 7: Comparing results from Triton and native execution\n", "In the following example, we obtain results from both Triton and from our local copy of the model. 
We compare these results and demonstrate that they are the same to within the differences we might expect from floating-point error. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "fa64b258", "metadata": {}, "outputs": [], "source": [ "# Obtain results from native framework\n", "batch = X[0:min(len(X), max_batch_size, 100)]\n", "native_result = model.predict_proba(batch)\n", "print(\"Output from native framework:\")\n", "if IS_CLASSIFIER and NUM_CLASSES <= 2 and len(native_result.shape) == 2 and native_result.shape[1] == 2:\n", " native_result = native_result[:, 1]\n", "print(native_result)" ] }, { "cell_type": "code", "execution_count": null, "id": "fdda0dd5", "metadata": {}, "outputs": [], "source": [ "# Obtain results from Triton\n", "if is_triton_ready():\n", " numpy_batch = convert_to_numpy(batch)\n", " triton_result = triton_predict(MODEL_NAME, numpy_batch)\n", " print(\"Output from Triton:\")\n", " print(triton_result)" ] }, { "cell_type": "code", "execution_count": null, "id": "9da39908", "metadata": {}, "outputs": [], "source": [ "# Compare results and compute maximum absolute difference for any output entry\n", "import numpy as np\n", "max_diff = np.max(np.abs(triton_result - native_result))\n", "print(f'The maximum absolute difference between the Triton and native output is {max_diff}')" ] }, { "cell_type": "markdown", "id": "899f08d5", "metadata": {}, "source": [ "# FAQ 8: How do we measure performance of the FIL backend?\n", "For practical purposes, there are two numbers we generally care about when considering the performance of an inference server like Triton: **throughput** and **latency**. Throughput gives a measure of how many samples (rows) can be processed in a given interval, while latency gives a measure of how long a client will have to wait after sending a request before it receives a response.\n", "\n", "Depending on your specific application, throughput or latency may have a higher or lower priority. Imagine that you are developing an application to analyze credit card transactions for potential fraud and stop suspected fraudulent transactions before they can go through. In such cases, latency is extremely important because a fraud/non-fraud determination must be made before a response can be sent to the point-of-sale. If the latency is too high, you may not be able to make a determination before the transaction has been processed.\n", "\n", "On the other hand, imagine that you are developing an application to *retroactively* process all credit card transactions cleared in a day and look for fraudulent transactions that might already have occurred. In this case, throughput is the most important performance metric, since all we care about is the rate at which we can process the entire batch of daily transactions.\n", "\n", "As another way of thinking about this tradeoff, latency generally determines the *responsiveness* of your application, while throughput determines the *compute cost*. With higher throughput, you will be able to process more total samples in a shorter period of time, which means that you can use fewer instances to handle required traffic. With lower latency, you will be able to quickly return responses, but you may need more instances to handle the volume of traffic you expect.\n", "\n", "The FIL backend and Triton itself offer a range of configuration options to allow you to achieve exactly the right balance of throughput and latency for your application. 
We will go through a few of them in the following sections. For each option we consider, we will make use of Triton's `perf_analyzer` [tool](https://github.com/triton-inference-server/server/blob/main/docs/perf_analyzer.md#performance-analyzer) which allows us to measure latency and throughput with a variety of configuration options. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "d7ce4221", "metadata": {}, "source": [ "# Example 8: Using `perf_analyzer` to measure throughput and latency\n", "\n", "Triton's `perf_analyzer` offers a huge number of options that allow you to simulate real deployment conditions and measure both throughput and latency. In this example, we will focus on just a few of these options that will help us to assess overall performance of models deployed with the FIL backend." ] }, { "cell_type": "code", "execution_count": null, "id": "3c37e83e", "metadata": {}, "outputs": [], "source": [ "# The simplest invocation of perf_analyzer, submitting one sample at a time\n", "!perf_analyzer -m {MODEL_NAME}" ] }, { "cell_type": "markdown", "id": "9047766c", "metadata": {}, "source": [ "In the above output, we see the throughput in inferences/second (equivalently samples/second or rows/second) and the latency in microseconds. Several different latency scores are reported, including the average latency and p99 latency (the latency of the request in the 99th percentile of recorded latencies). Often, service level agreements are established around a particular latency percentile, so you can see the value which is most relevant to your application.\n", "\n", "While this gives us a starting point for understanding our model's performance, other factors can play into both throughput and latency. Let's see what happens when we switch from the default HTTP to GRPC" ] }, { "cell_type": "code", "execution_count": null, "id": "37513ea8", "metadata": {}, "outputs": [], "source": [ "!perf_analyzer -m {MODEL_NAME} -i GRPC" ] }, { "cell_type": "markdown", "id": "1de98aa2", "metadata": {}, "source": [ "Depending on your model, this may or may not have had much effect on the measured performance, but generally, Triton gets slightly better performance over GRPC than HTTP. Let's now look at what happens when we increase our client batch size from 1 to 16:" ] }, { "cell_type": "code", "execution_count": null, "id": "86f40963", "metadata": {}, "outputs": [], "source": [ "!perf_analyzer -m {MODEL_NAME} -i GRPC -b 16" ] }, { "cell_type": "markdown", "id": "b954e516", "metadata": {}, "source": [ "Depending on your model, throughput likely went up, but latency may have gone up as well. If you have the option of doing client-side batching in your application, exploring possible batch sizes with `perf_analyzer` can help you find the right balance of latency and throughput. In general, you should not be afraid to consider very large batch sizes when running FIL on the GPU. FIL offers throughput benefits for batches consisting even of millions of rows, but latency may suffer due to the overhead of transferring such large arrays over the network.\n", "\n", "While client-side batching is not always an option, we can take advantage of Triton's dynamic batching feature to combine many small client-side batches into a larger server batch. Let's use `perf_analyzer`'s `--concurrency-range` option to see what happens when many small concurrent requests are submitted to the server at once. 
As the name of the flag implies, it is generally used to explore a *range* of concurrencies, but here we will use it just to explore a single higher concurrency value" ] }, { "cell_type": "code", "execution_count": null, "id": "061501d8", "metadata": {}, "outputs": [], "source": [ "!perf_analyzer -m {MODEL_NAME} -i GRPC --concurrency-range 16:16" ] }, { "cell_type": "markdown", "id": "a79765de", "metadata": {}, "source": [ "There is some additional work required to combine these small batches and distribute responses, so the measured performance is likely somewhat worse than when we were doing a similar amount of batching client-side. Nevertheless, throughput likely increased relative to our original run, demonstrating the potential value of Triton's dynamic batching feature. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "dfed583c", "metadata": {}, "source": [ "# FAQ 9: How can we improve performance of models deployed with the FIL backend?\n", "\n", "A number of factors impact the throughput and latency we can expect for models deployed on the FIL backend including:\n", "- The complexity of the model\n", "- Whether GPUs or CPUs are used for inference\n", "- Client batch size\n", "- Whether or not Triton's dynamic batching feature is used\n", "- Whether or not Triton's shared memory mode is used\n", "- Whether HTTP or GRPC is used\n", "- On CPU:\n", " * The `use_experimental_optimizations` flag in `config.pbtxt`\n", "- On GPU:\n", " * The `storage_type` option in `config.pbtxt`\n", " * The `transfer_threshold` option in `config.pbtxt`\n", " * The `algo` option in `config.pbtxt` (advanced feature)\n", " * The `blocks_per_sm` option in `config.pbtxt` (advanced feature)\n", " * The `threads_per_tree` option in `config.pbtxt` (advanced feature)\n", "\n", "Making the correct choice for each of these depends on your model, hardware configuration and on whether you are prioritizing latency, throughput, or some combination of both. In the following examples, we will first look at a case where we are prioritizing latency above all else. Then we will look at ways to maximize throughput, and finally we will try to find a good balance of both. [^](#Table-of-Contents)\n", "\n", "## FAQ 9.1: Does specifying preferred batch sizes help FIL's performance?\n", "Triton allows specification of \"preferred\" batch sizes in the configuration of its dynamic batching feature, but this does **not** offer any benefit to performance of models deployed with the FIL backend. While there are theoretically extremely marginal performance benefits at specific batch sizes, the actual model evaluation in FIL is fast enough that waiting for a specific batch size instead of proceeding immediately with whatever input data is available is counterproductive. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "83f3f53b", "metadata": {}, "source": [ "# Example 9: Optimizing model performance\n", "\n", "## Example 9.1: Minimizing latency\n", "
\n", " WARNING: Using the configuration we consider here is NOT recommended for any sort of production deployment, unless your application truly requires minimal latency down to the microsecond and you have either extremely light traffic or the ability to scale out onto a very large number of instances.\n", "
\n", "\n", "In this example, we will focus on minimizing latency at all costs. In reality, there is a point of diminishing returns at which substantially higher throughput can be achieved by adding just a few microseconds to latency, but we will take the extreme case here to illustrate all the ways you can reduce latency for your model.\n", "\n", "Note that by taking these steps, you will likely have to invest in more compute instances to handle total incoming traffic. Taken to extremes, a minimal latency configuration could have a dedicated instance for each client, so look to Example 9.3 for a more balanced approach.\n", "\n", "### Step 1: Minimize batch size\n", "If all we care about is achieving the lowest possible latency, we will want to ensure that data is processed in batches that are as small as possible. To do this, we will turn off Triton's dynamic batching feature by removing the `dynamic_batching` entry from `config.pbtxt` and submit client batches of size 1. Note that doing this will substantially impact our ability to handle high server traffic unless we scale out to more compute instances." ] }, { "cell_type": "code", "execution_count": null, "id": "f322d78d", "metadata": {}, "outputs": [], "source": [ "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"{model_format}\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"false\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"{classifier_string}\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"AUTO\" }}\n", " }}\n", "]\n", "\"\"\"\n", "\n", "config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", "with open(config_path, 'w') as file_:\n", " file_.write(config_text)\n", "time.sleep(10)" ] }, { "cell_type": "markdown", "id": "70588003", "metadata": {}, "source": [ "### Step 2: Use shared memory mode (if possible)\n", "In the rare case that your input data originates on the same machine where the Triton server is running, you can substantially improve performance (both throughput and latency) by using Triton's shared memory mode. In this mode, rather than transferring the entirety of an input array over the network, the data are loaded into memory by the client, and then Triton is provided with a pointer to that data in memory.\n", "\n", "Usually, input requests do not originate from the server, so this option is not always possible. Nevertheless, let's take a look at this option with `perf_analyzer`, in part to get a sense of how much overhead is due to network transfer of the input array.\n", "\n", "
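\n",
"Since `perf_analyzer` takes care of the client-side details for us, the sketch below shows roughly how a request backed by *system* (host) shared memory could be issued from the Triton Python client. This block is not part of the original example and makes a few assumptions: it reuses the `MODEL_NAME` and `NUM_FEATURES` variables defined earlier in this notebook, assumes an HTTP endpoint at `localhost:8000`, and uses the host shared memory utilities (the CUDA equivalents live in `tritonclient.utils.cuda_shared_memory`).\n",
"\n",
"```python\n",
"import numpy as np\n",
"import tritonclient.http as triton_http\n",
"import tritonclient.utils.shared_memory as shm\n",
"\n",
"client = triton_http.InferenceServerClient('localhost:8000')\n",
"batch = np.random.rand(4, NUM_FEATURES).astype('float32')  # stand-in input batch\n",
"\n",
"# Create a host shared memory region, copy the batch into it, and register it with Triton\n",
"handle = shm.create_shared_memory_region('input_region', '/input_region', batch.nbytes)\n",
"shm.set_shared_memory_region(handle, [batch])\n",
"client.register_system_shared_memory('input_region', '/input_region', batch.nbytes)\n",
"\n",
"# The request references the region rather than carrying the array over the network\n",
"infer_input = triton_http.InferInput('input__0', list(batch.shape), 'FP32')\n",
"infer_input.set_shared_memory('input_region', batch.nbytes)\n",
"result = client.infer(MODEL_NAME, [infer_input])\n",
"\n",
"# Clean up the region when finished\n",
"client.unregister_system_shared_memory('input_region')\n",
"shm.destroy_shared_memory_region(handle)\n",
"```\n",
"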
\n", "VERSION NOTE: Prior to version 21.11, a bug resulted in incorrect results when CUDA shared memory mode was used with a model deployed on CPU. It is generally not recommended that you use CUDA shared memory mode with a CPU-deployed model prior to version 21.11.\n", "
" ] }, { "cell_type": "code", "execution_count": null, "id": "5ad2e415", "metadata": { "scrolled": true }, "outputs": [], "source": [ "if is_triton_ready():\n", " if USE_GPU:\n", " !perf_analyzer -m {MODEL_NAME} -i GRPC --shared-memory=cuda\n", " else:\n", " !perf_analyzer -m {MODEL_NAME} -i GRPC --shared-memory=system" ] }, { "cell_type": "markdown", "id": "4a5c5246", "metadata": {}, "source": [ "For our remaining experiments in this example, we will **not** make use of shared memory." ] }, { "cell_type": "markdown", "id": "abbb87f1", "metadata": {}, "source": [ "### Step 3: Activate experimental CPU optimizations\n", "For simple enough models we can sometimes achieve better minimum latency on CPU than on GPU because we can avoid the overhead of transferring data from host to device and back again. Whether this offers benefit for any particular application depends significantly on the model and on your available hardware (both CPU and GPU). Regardless, if minimizing latency is essential, you will want to set the `use_experimental_optimizations` flag to `true` in order to obtain the best possible CPU performance.\n", "\n", "This flag uses an alternate CPU inference method with substantially improved performance. Despite the name of the flag, this method is quite stable and will become the default CPU execution path in future versions of the FIL backend. Let's try activating this mode now, switching to CPU execution, and looking at the impact on latency.\n", "\n", "
\n", "VERSION NOTE: This option was added in version 22.04. Version 22.03 did not offer the use_experimental_optimizations flag, but it did include some CPU optimizations relative to earlier versions. Prior to version 22.03, CPU execution was not at all optimized. If you are using a version of Triton before 22.03, GPU execution should be used for optimal performance (both throughput and latency).\n", "
" ] }, { "cell_type": "code", "execution_count": null, "id": "9705a382", "metadata": {}, "outputs": [], "source": [ "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }}\n", "]\n", "instance_group [{{ kind: KIND_CPU }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"{model_format}\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"false\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"{classifier_string}\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"AUTO\" }}\n", " }},\n", " {{\n", " key: \"use_experimental_optimizations\"\n", " value: {{ string_value: \"true\" }}\n", " }}\n", "]\n", "\"\"\"\n", "\n", "config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", "with open(config_path, 'w') as file_:\n", " file_.write(config_text)\n", "time.sleep(10)" ] }, { "cell_type": "code", "execution_count": null, "id": "5c235042", "metadata": {}, "outputs": [], "source": [ "if is_triton_ready():\n", " !perf_analyzer -m {MODEL_NAME} -i GRPC" ] }, { "cell_type": "markdown", "id": "758a18ed", "metadata": {}, "source": [ "The impact of this change will depend on the complexity of your model and on your available CPU/GPU hardware, but you are likely to see a modest improvement in latency and a slight decline in throughput. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "c5c3b41f", "metadata": {}, "source": [ "## Example 9.2: Maximizing Throughput\n", "We now turn to the opposite extreme and attempt to maximize throughput without any consideration for latency. While the absolute-minimum latency configuration rarely makes sense, there are many applications for which latency can be arbitrarily high. If you are batch-processing a large volume of data without real-time client interaction, throughput is likely to be the only relevant metric.\n", "\n", "### $\color{#76b900}{\text{Step 1: Use GPU execution}}$\n", "Maximum throughput for any model deployed with the FIL backend will be substantially higher on GPU than on CPU. In fact, this increased throughput more than offsets the increased per-hour cost of most GPU cloud instances, resulting in cost savings of around 50% for many typical model deployments relative to CPU execution.\n", "\n", "### Step 2: Use dynamic batching\n", "By combining many small requests from clients into one large server-side batch, we can take maximal advantage of the highly-optimized parallel execution of FIL.\n", "\n", "### Step 3: Increase client-side batch size (if possible)\n", "If your application allows it, increasing the size of client batches can further increase performance by avoiding the overhead of combining input data and scattering output data.\n", "\n", "### $\color{#76b900}{\text{Step 4: Try storage_type DENSE (if possible)}}$\n", "If you have enough memory available, switching to a `storage_type` of `DENSE` *may* improve performance. This is model-dependent, however, and the impact may be small. 
We will not attempt this here, since you may be working with a model that would not fit in memory with a dense representation.\n", "\n", "### $\color{#76b900}{\text{Step 5 (Advanced): Experiment with algo options}}$\n", "We can use this configuration option to explicitly specify how FIL will lay out its trees and progress through them during inference. The following options are available:\n", "- `ALGO_AUTO`\n", "- `NAIVE`\n", "- `TREE_REORG`\n", "- `BATCH_TREE_REORG`\n", "\n", "It is difficult to say *a priori* which option will be most suitable for your model and deployment configuration, so `ALGO_AUTO` is generally a safe choice. Nevertheless, we will demonstrate the `TREE_REORG` option, since it provided about a 10% improvement in throughput for the model used while testing this notebook.\n", "\n", "
\n", " NOTE: Only ALGO_AUTO and NAIVE can be used with storage_type SPARSE or SPARSE8\n", "
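\n",
"For reference, Steps 4 and 5 only touch the `parameters` block of `config.pbtxt`. The fragment below is a sketch rather than a configuration used in this notebook (the full configuration we write out shortly keeps `storage_type` at `AUTO`); the `DENSE` and `TREE_REORG` values are simply the choices discussed above:\n",
"\n",
"```\n",
"parameters [\n",
"  {\n",
"    key: \"storage_type\"\n",
"    value: { string_value: \"DENSE\" }\n",
"  },\n",
"  {\n",
"    key: \"algo\"\n",
"    value: { string_value: \"TREE_REORG\" }\n",
"  }\n",
"]\n",
"```\n",
"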
\n", "\n", "### $\\color{#76b900}{\\text{Step 6 (Advanced): Experiment with the blocks_per_sm option}}$\n", "This option lets us explicitly set the number of CUDA blocks per streaming multiprocessor that will be used for the GPU inference kernel. For very large models, this can improve the cache hit rate, resulting in a modest improvement in performance. To experiment with this option, set it to any value between 2 and 7.\n", "\n", "As with `algo` selection, it is difficult to say what the impact of tweaking this option will be. In the model used while testing this notebook, no value offered better performance than the default of 0, which allows FIL to select the blocks per SM via a different method.\n", "\n", "### $\\color{#76b900}{\\text{Step 7 (Advanced): Experiment with the threads_per_tree option}}$\n", "This option lets us increase the number of CUDA threads used for inference on a single tree above 1, but it results in increased shared memory usage. Because of this, we will not experiment with this value in this notebook.\n" ] }, { "cell_type": "markdown", "id": "910977b1", "metadata": {}, "source": [ "Combining all of these options, let's write out a new configuration file and observe the impact on throughput with `perf_analyzer`. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "dd76a83e", "metadata": {}, "outputs": [], "source": [ "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"{model_format}\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"false\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"{classifier_string}\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"AUTO\" }}\n", " }},\n", " {{\n", " key: \"algo\"\n", " value: {{ string_value: \"TREE_REORG\" }}\n", " }},\n", " {{\n", " key: \"blocks_per_sm\"\n", " value: {{ string_value: \"0\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", "config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", "with open(config_path, 'w') as file_:\n", " file_.write(config_text)\n", "time.sleep(10)" ] }, { "cell_type": "code", "execution_count": null, "id": "e49095d5", "metadata": {}, "outputs": [], "source": [ "if is_triton_ready():\n", " !perf_analyzer -m {MODEL_NAME} -i GRPC -b {max_batch_size}" ] }, { "cell_type": "markdown", "id": "8433b206", "metadata": {}, "source": [ "### Step 8: Increase concurrency with a slightly smaller client batch size\n", "Because of the overhead of transferring large input arrays over the network, it can sometimes be beneficial to submit arrays of a size less than the maximum batch size but to do so at higher concurrency. This allows the FIL backend to begin working on a batch while network transfer continues, effectively \"overlapping\" data transfer and processing. Let's take a look at what happens when we submit batches of half the size we were before but with higher concurrency." 
] }, { "cell_type": "code", "execution_count": null, "id": "df68b5db", "metadata": {}, "outputs": [], "source": [ "# First submit batches of maximum size at concurrency 16\n", "!perf_analyzer -m {MODEL_NAME} -i GRPC -b {max_batch_size} --concurrency-range 16:16" ] }, { "cell_type": "code", "execution_count": null, "id": "61de18cd", "metadata": {}, "outputs": [], "source": [ "# Now submit batches of half the maximum size at the same concurrency\n", "half_max_batch_size = max_batch_size // 2\n", "!perf_analyzer -m {MODEL_NAME} -i GRPC -b {half_max_batch_size} --concurrency-range 16:16" ] }, { "cell_type": "markdown", "id": "03a2a1c9", "metadata": {}, "source": [ "Depending on your model, this may or may not have improved throughput, but in general there is some optimal client batch size below the maximum server batch size that allows for increased throughput while *also* reducing latency. This technique is so valuable we will revisit it in more detail in FAQ 11, where we will demonstrate how to do this with the Python client. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "40f7f278", "metadata": {}, "source": [ "## Example 9.3: Balancing latency and throughput\n", "Let's now imagine a somewhat more typical scenario, where rather than working toward absolute minimal latency or absolute maximum throughput, we have some latency budget and are looking to maximize throughput within that latency budget.\n", "\n", "Let's imagine that we have a p99 latency budget of 2 ms. We will start by writing out a configuration that makes use of features we explored in examples 9.1 and 9.2, then we will use `perf_analyzer` to see how our model will perform under various deployment scenarios." ] }, { "cell_type": "code", "execution_count": null, "id": "b853b08a", "metadata": {}, "outputs": [], "source": [ "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"{model_format}\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"false\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"{classifier_string}\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"AUTO\" }}\n", " }},\n", " {{\n", " key: \"algo\"\n", " value: {{ string_value: \"TREE_REORG\" }}\n", " }},\n", " {{\n", " key: \"blocks_per_sm\"\n", " value: {{ string_value: \"0\" }}\n", " }},\n", " {{\n", " key: \"use_experimental_optimizations\"\n", " value: {{ string_value: \"true\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", "config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", "with open(config_path, 'w') as file_:\n", " file_.write(config_text)\n", "time.sleep(10)" ] }, { "cell_type": "markdown", "id": "9098ea38", "metadata": {}, "source": [ "Let's assume for the moment that we must hold our client batch size fixed to just 1. Using `perf_analyzer`, we can explore a range of concurrencies to see how well our server will handle increasing amounts of traffic for the model. 
Using the `--binary-search` flag, we can perform a binary search over a range of concurrencies to find the highest value that does not exceed our 2 ms p99 latency budget (set with the `-l` flag)." ] }, { "cell_type": "code", "execution_count": null, "id": "673693dc", "metadata": {}, "outputs": [], "source": [ "if is_triton_ready():\n", " # Find the highest concurrency between 1 and 16 that does not exceed 2ms p99 latency\n", " !perf_analyzer -m {MODEL_NAME} -i GRPC --binary-search --percentile 99 --concurrency-range 1:16 -l 2" ] }, { "cell_type": "markdown", "id": "dfda9f90", "metadata": {}, "source": [ "For the model and hardware used to test this notebook, a concurrency of 8 was the highest we could go for the given p99 latency budget. This gives us a sense of how many concurrent requests our model can handle without threatening our latency target.\n", "\n", "Now, let's assume that we *can* increase client batch size. Can we set the client batch size at the given concurrency and still stay under our 2 ms p99 latency target? There is no batch size search option for `perf_analyzer`, but we can simply experiment with a range of batch sizes in the following invocation." ] }, { "cell_type": "code", "execution_count": null, "id": "612ac2d8", "metadata": {}, "outputs": [], "source": [ "if is_triton_ready():\n", " !perf_analyzer -m {MODEL_NAME} -i GRPC --concurrency-range 8:8 -b 3" ] }, { "cell_type": "markdown", "id": "d4f8d3c0", "metadata": {}, "source": [ "For the model and hardware used to test this notebook, we could increase the client batch size to 3 without meaningfully changing the p99 latency, thereby tripling our throughput. By exploring these kinds of tradeoffs with `perf_analyzer`, you can find a balance of throughput and latency which meets your latency budget but which also maximizes throughput (and therefore minimizes compute costs). 
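
As a convenience, the batch-size exploration described above can also be scripted directly in this notebook, since `perf_analyzer` accepts the client batch size via `-b`. The loop below is a sketch rather than part of the original example; the concurrency of 8 is simply the value found for the test setup above, so substitute whatever value your own latency budget allows.

```python
# Hypothetical sweep over client batch sizes at a fixed concurrency
for client_batch in [1, 2, 3, 4, 6, 8]:
    print(f'--- client batch size {client_batch} ---')
    !perf_analyzer -m {MODEL_NAME} -i GRPC --concurrency-range 8:8 -b {client_batch}
```
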
[^](#Table-of-Contents)" ] }, { "attachments": { "Throughput%20vs%20Batch%20Size%20for%20XGBoost_FIL%20Only.png": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAFzCAYAAADi5Xe0AABhiUlEQVR4XuydC5xdVXX/+UBBDShVsYBEfKUFI2hBEdRi1VZsbbVqK1g+raBFfGBtRZF3RAIhEoEIAYUkyDPIM4gk0ATygLzfCUNIIIQk5J1MXjOTx7zWP2vnvzZn1ux77t0r986c2fP75vP75N7zWPfs39x99rp773POAQQAAAAAAKrKAXoBAAAAAADYP5BgAQAAAABUGSRYAAAAAABVBgkWAAAAAECVQYIFAAAAAFBlkGABAAAAAFQZJFgAAAAAAFUGCRYAAAAAQJVBggUAAAAAUGWQYAEAAAAAVBkkWAAAAAAAVQYJFgAAAABAlUGCBQAAAABQZZBgAQAAAABUGSRYAAAAAABVBgkWAAAAAECVQYIFAAAAAFBlkGABAAAAAFQZJFgAAAAAAFUGCRYAAAAAQJVBggUAAAAAUGWQYAEAAAAAVBkkWAAAAAAAVabqCdZFF10Upew+PZnuLEN7ezvNmzdPL+529N+aNWDAAPrd735HL7/8st48iv0tc7X+Xo2NjfTkk0/SkCFD6LLLLqMrr7yShg4dSuPHj6ddu3Z12LZan2lh27ZtdPfddzv/r7jiCnr00Uf1JlXnnnvuceX99a9/TS0tLR3WtbW10Y033ujW33vvvR3WCeztc889R7/97W/p2muvpUsuuYR+8Ytf0K233koTJ06k3bt3610c+jtXze9dLSj3Xa6vr3dlYA8EXb6QsoSW5S23wMf5xz/+0dWFyy+/3NWFm266iZ5++mn3/dsfqnmcAHQVSLCqRHeW4YYbbui2z85D/62z+vnPf06vvvqq3qVi9rfM1fh7bd26lQYOHNipbKLrrruOduzY4bevxmda4eQie2yjR4/Wm1QdTpCuuuoq93ljx47tsG7ChAluOa/n7TR1dXUuKdKeZvXLX/6SXn/9db1rp+2y2t/vXS0o912ePn26W//QQw/5ZbpcIWUJLctbHsu0adNcAqyPQcQ/PubOnat3q5hqHScAXUnVE6wQ5SpHufU9ge4sQ3d+dh76uLjXYvv27a6h4OXc6FvRsWPZ3/2Z+++/38X4zW9+Q6+99hrt2bOHdu7cSa+88opvNB988EG9W7dw6aWXuuPhHpzW1tZOPUq1YsGCBe5zL774Ylq1apVbtnnzZtfg8nJer3nppZf832fEiBH04osvUkNDgztu/n/RokV08803u/Wc4JbrKazm964W6OPVSE/g/Pnz/bJy+2hKbV9qeQwzZ870cf7whz/QihUrXF3gv8vy5cvpvvvu8+vzeuryqMZxAtDVIMGqEt1Zhu787DxKHRf3WPByHkawUip2pezv/oz0sHDCoFm/fr1bx70sRUDKy8NRXY0eKrz99tvde16u4UaZPeP13MtVCk62pFeOhwuzlPrbVuN7VwtKHS/Dfy8eFuWeN04uhbx9QpTavtTySuFj4mSZjy8veZoyZYr7HK4z2XJUyv4eJwDdQaESLJ5T8cADD7gT4NVXX+2GFfjXp96Oh2Z4ngtX7Lvuusuvb25upqeeeooGDRrkuqt5iIbH//Wv9VLHE1rOx8TzCviXMs8pePjhh108vW1sGbiBeOSRR1xMjs1DNuV+iYeWy+usSrFhwwa3/vrrr9erHIMHD3brN27c6N7z0AT3zvCcHS7PLbfcQrNnz1Z7labU8TQ1Nbnl11xzjV7leie44WRf+G/I807Yf/5FLOjy6s9gH3kfjs89NzxnRw8lyX7l/l55SIKVHQbMQx+rLoOWwN8VntPFfx/2hL/f/D3X3+sQOqaOHVtnStW9PLhB5SSB95deJH4famhnzJjh1g8fPlyv6sTq1atdLyL3cGXRZRTyvneV+hC7bbk6pP8u+rj5e8vLeL5altC2eZTavtTySuHvJe/PdagcI0eOdNvyvDpBPp/POdxbyd8rHjb+05/+5L73ejv54cJ1IfRjgX/s8Hr+mwDQ3RQqweIKJq9F48aN67TdHXfc4f7nX02PPfaYW8eVkSfD6v1Z/Is5lORo9HLehxtnHS/b5a33rbQMPMlYb8fDHrHHqWOEts/CJ3veZu3atR2W8/ANL+cGgJFfnCFlT5B56OPhEyI3qjxsxsv/7//+L7P1GyfrkOTvzOh12c/gRk4mT2fFCRsPE+kY5f5eeUg5+Pu1Zs0avboT+lj152ox7Jk0TFpcD0KNTBa9Tza2pc6E6l4l8PCW7Mf/Z4e7skh8TrStZMvIlPvexfgQs20ldUgvZ2WRuWqccGQJbZtHqe1LLa+U2267ze1fycUDnAjztryPIJ8fmssYOm8y4j8PJWuef/55t27MmDF6FQBdTqESLL76hBt+7q3gyiXL9HY8v2XTpk2ZCPsmWfI6/mW8ePFi14vBFVp+OfMvSaHU8ejlc+bM8TE5FsdcuHCha6z1trFl4F+0POlTHyf/gtfbavRy/T4POenzL+4sfALn5eIT9+bwe+7251/sXB7xmH+5V4IcV0jce6d7iqRHiD9H5nA8++yzbhn7k6VUmbnh4uX8C5Z//XOvxZ133umWhZK0cn+vPPj4hg0b5mOxL9zDyckBe6YpdcwC7yMJH/9CZ+Q7yIkxJ8G8DffccCLMy3n+SyWEPttSZ0J1r1K4V4pj5PVOcY8lb5NNhmORYw0p9L2L8SFm20rrkBxbCBlOXbJkSYflulxamtjllSIXMYQuVNDw35S3zfYgyufz94q/1xyHe595mdSB7HaMzOvjuqKR75jusQagOyhUgsWTgwXpzucuY70dn9g08ksqm6AwcgUO90QJpY5HL5fGLnvSZORXUnbb2DLoeSNynHxCFfRnlFqu3+fBJzAe1sgOE/IvfBlO40naDA9ncMxKG/AQclwh8TAgT4YthwzHcs9HllJlll+32R4QHn7gZdlhA9m/3N+rHNxYcwPK3xWZSM7iZFEPp5Y6ZoYbaukt5R44GT6TBlaGbQUZCuHyVkLosy11JlT3KiE7TMj/h4YHGfEwOzwkyDGEVOl2oe9djA8x21Zah0JlYPi7z35wfdUJuy6XliZ2eaXIlYM6aQ3Bf1PelsskyOcvW7bML5NEjGPr7Rj+LO7x4nNCNtlnjzh2pT8AAag1hUqwsve14UZf7yfvQ7+WpPdD32+F54zwcu51EnTcUstlsq2OuWXLlk7bxpaB7xmTRY6TfxEKet9Sy/X7cvz+979328swIV+2zu95PovAPTESl5MHvhqIfxWG5pmUQh8X+8F/O26MOInhk2Hoknn+HG7AuMeJe5N0HCa0jJFGvNy8KNm/3N8rBj7B87AFH3foKrlSsTnZ4PsF8TpuoCXJZaQ8pcTrKyH02ZY6E6p7lcD34OL9pYeB34eQoaJQD5Yue1ah7YRy37sYH2K2rbQO6eMVeNiNl4eueiy1TylKbV9qeaVID5aeP
xoiNAdOPj9bDzmB0sel30tvM/d2CZz88zK+Lx0ARaBQCZZGL5f3oXknfBk4r9MnL+kBKfVrSAg1ruViho5No5fL+1IxLcep35fjhRdecNvLMCEnBPx+6dKlfhtOFp544gk3lCnxWXxC1T0Apcg7Lh5q5XU8xCZwo5odctPKElrGyN8s1AOSpdT+pZbHwr/IOQ5PCBdCsblh5t5EXs6NsO6pkPKUEq+vhNBnl/t+h76LobpXDpl/xd8dLp98bugWDdJDVGlPWahcoWVC6HsX40PMtpXWoVLHy/OIePkzzzyjV5XcpxSlti+1vFL4+837h34oafj8wtuGekY1erl+z+cK/ltwQisXwPCFQryN3A4EgO6mRyZYIeSXJTdYWaS3KXu5vMTJdmvzLzAdv1RM+bVaybHp5fJe/wKW4+Rf8EKlx6nfl4OTD+754IadY/MJn39VhhpPPnlxo8T3t5EevWzSkEfecfEvVl6XHY6Ty/l5KI/nhPH8I7nykZU9vlKxZX5cqSEoodT+pZZrSvVkCNy46vLp2DzsJ3OO+MKJUFIo5SnXI1cO/dlMqe93Xp2Jhf8O0svB848YSRx4uf47yfA7z5urhNBxhZYJoe9djA8x2wrl6lCp45XkZeXKlXpVyX1KUWr7UssrhS8Y4P1HjRqlV3VC6vfkyZP9slKfr5fr94zE46F4Pjfw0CCGB0GRSCbBkvkrem6ETCrNzlWReQPZk7t0x2fjy3129BwKuZS8kmPTy+W9np8jk8+zl75Xepz6fSVIrxX/wub/9Z22Q/Bx8LbZORR55B0X/4LndTwRWJA5K9nhIW5cJE6216BUbLkBJd8JXJBHjYTmYGlKLdfI1W76YgGBJ+zy+lKTmXm9JB488TqU3DLyvZ41a5ZeFUWoXDF1JrR/JUgjyJ8lZeS/owz96nth8TCS+MIXOOQhc+v0cYWWCaHvXYwPMduGCNWh0PGyDzzHiBO60Pym0D55lNq+1PJK4bolw+H6nJZFzpn8gyF7Piv1+Xq5fs9ILzF7Lr2k+mpLALqTZBIsSVC4ZyZ0dc/UqVP9tnK/J/71xb9oeS4S3wRRx+dGjd/zr06eW8PbcmxpACo5Nr1c3nNMvjKIj5OH7OSXMb8WKj1OGbbgRrtSsokLS0+ils/hhoM/m3+F8y9PXsZX/FSCPk6GGwuehxK607n4yldXciPMc1Y4QZE42cSrVJnlsnZuwPmeOTz3hpNWXpZ9PEzo2PKWa2S+BzeCjz/+OK1bt84ds3zvOJnj9dkELBtbeqbK3emdv7e8HX+P+bvBfwee4M7zTHg538KhEkLliqkzof3LIVd7cTKhv188jFNqqJDrhdzOgX/k8HruKeQePv57cl1k3+QHSLbXlwkda973LsaHmG0rrUOh77J4V2quWqiMeZTavtTyGLI/OLmXjs8t3IPL4te8TNbrx+WU+ny9XL8X2GP+rsh3IdTbB0B3kUyCxSffUvN3+BdO9lcg3yRQbyONcDY+7xO6D5Yoe2Wb3rfUcnkfiqtPppUepwwzsbK/zMshc3/k3ldZ+ESoP1ukG8RS6P20OKnMTvbPTgoOKXtDyVJl5kZMypWVHo6S5ZpSy0PI8Egp8d+4VK+b3jYkhr+Dpe67xP5xYlcJ2ZhCTJ0J7Z9H9jmEpXqiskOFevI899SG7o2UFdc/vsEl9/Zk0dtp6e9djA8x21Zah0LfZe7V5PcyrKqR7Sul1Pb6uLQqhXuvslfRanEvlx4JYEp9jl6u3wvyA4TFPpbqCQagO0gmwWK4ceWTNlc0/kXDPUB8tYmekMonSe4B4BM7T0DlpEUeo6Hj8y9UHk7jHie5E7NMEud9hdC+oeXynhsF/lw+8fBxcs+LHgqo9Dj5Fz0nFVxmLnul8ORZjqNvQyFwtzuXlxskPk5+nR16K4ccZ1ZyjPyrVt9Pib3mngUuL38mN2TcEE2cONHtm71/Ul6ZOZHi+NyrwCd9Hs7TiYj2sNzyUvAwBV99ycfAn8U9U+wTn/j13zMbW/sSksA9AdwTxr1iXF5OPPgzuYeuUnRModI6U2r/UsjQIN9yIjS3jMkOFd577716tfs+cKPMk/+5J5P95e8Fz03iY9a9YoL2kZX3vWMq9SF220rqUOi7LD2g5cpYKaW21z5pxcDz0niqAV8Ry/WAxX9//qHIc9RClPocvVy/F3huoqzj6Q4AFIkuSbBSQya0cvd0LKVOFAAAAOLIPiGgkqcpANCVIMHKQSZdczc935uIf4nzPCiZ/J69B0ulIMECAID9g0cAeJhX5rlVcpEOAF0NEqwc5Eq7kLgLX8/9qAQkWAAAsH9kz8U8JFlqGBqA7gQJVg48h4bnJ/G8CZnPw/Mt+FLg7N22Y0CCBQAA+wefh3lOG19FW+pedAB0N0iwAAAAAACqDBIsAAAAAIAqgwQLAAAAAKDKIMECAAAAAKgySLAAAAAAAKoMEiwAAAAAgCqDBAsAAAAAoMogwQIAAAAAqDJIsAAAAAAAqgwSLAAAAACAKlP1BOv222/Xi2rKnj17aMqUKXTvvffS8OHD6YEHHqC5c+dSe3u73tREV5cHAAAAAD2fHp9gjR8/nl544QXavXu3S6p27drlnh84c+ZMvamJri4PAAAAAHo+XZ5gcRI0e/Zs1+PED+rkByc3NDR02GbOnDl099130/33308vvvhibkyO0dLS0mFZU1OT21/gJ61PmDDBbfvwww/Txo0b/Tp+UOjTTz9N9913n+sBe/DBB2nFihV+vf5sOTY+fk7i+IHQAAAAAABZujzBWrhwIS1atMglRSx+P2nSJL9+yZIlNHnyZNcTxcnPI488khvzoYceogULFrihwlJwQrd06VKXaL300ksuyRIeffRRl8TxsXCMefPmueRJyH72yy+/7Lfl4+Nj5fcAAAAAAFm6PMHihIiTE2Hnzp2up0oYPXo0bd++3b9fs2ZNbsxNmza5eVfc+/TUU09RXV0drV+/vsM2/Jl5CVgWTsKyn5d9zcem42STNQAAAAAApssTrBEjRnQYVuPXvEzgYTxOcgROaMrF5GHHDRs2uN6nMWPGuGSLe8kEjp836Z0TMu7l4vlcnKyVSrB+//vfu/dZZY8dAAAAAIDp8gSLkxSdYN1xxx0d1mcTrObm5rIxNevWrXO9VsKdd95ZMsHiIUNOqvh/TtJ4PlipBIvjZI8NAAAAACBElydYPOdJDxFme4Eef/xx2rFjh3/PyVJeTJ5wrpMeniPFiZrAyRYnaiHuueceNyle2LJlS8kEi+eD6Qn5AAAAAACaLk+wFi9e3GmSe3Ye0yuvvELTpk1zCREnWjzvKS8mX8nHw3v19fUuXmNjo5v0/vzzz/tt+L5Yr776qkvEli1b1uHzuPeKJ6/zOp7PpSfVZ1/zhPb58+e7Y+MkkYckn3zySb8eAAAAAICpSYJVSgwP1c2YMcP1MN111130zDPPdOixYjhh4nWjRo1yQ3c8p6oUHI8TKJnozj1SfB+sbK8WJ17jxo1zPWXcg7Z582a/ToYTee4XJ1ec4JVKsPizOKHjXjMeLuTbO2R7vwAAAAAA
mKonWNVGX2UIAAAAAFB0Cpdg8T2ouBeJe6D4CkIeLuQeKQAAAACAnkLhEiy+79Vjjz3mhux4KG7q1KmdJrEDAAAAABSZwiVYAAAAAAA9HSRYAAAAAABVBgkWAAAAAECVQYIFAAAAAFBlkGABAAAAAFQZJFgAAAAAAFUGCRYAAAAAQJVBggUAAAAAUGWQYAEAAAAAVBkkWAAAAAAAVQYJFgAAAABAlUGCBQAAAABQZZBgAQAAAABUGSRYoPCs2b6H/vux5XTuqFfKirfj7UH1aGlp0YuiWdOwlf77/0bRuX+6s6x4O94edGb77jX0x5f+mx564dyy4u14ewBA94AECxSea8a/Tgf8ZGrF4u0rZcGCBXTmmWfSX/3VX9FFF11Er7/ecd93v/vdXsceeyx95StfoSVLlnRYHyK0/Pzzz6eJEyd2WDZ16lT6z//8T/9+7Nix7nN4//e97300adKkzNYdj+e9730vfelLX6JFixZ12Kba/O3f/q1eFM01U56kA649r2Lx9pXSlX9DJhtPlF0Xel0tnn31Gvr5/x1QsXj7Son57vV0HwHoCpBggcJz1dOrOiVReeLtK4Ebh7/+67+mMWPGUFNTEy1fvpz+9V//lUaNGuW3yZ7cuSfnnnvuoX/5l38Jrs8SWr5582b62te+1qFH6N/+7d/c5zLcwH3wgx90/zOPP/64SxqmTZvmt8/G3blzJ918881VSYDyCJUllquee6JTEpUn3r4SuvpvyJRaztQ6MRi/7KpOSVSeePtKiP3u9XQfAegKkGCBwlOrBOt73/se3X///R2WcQP905/+1L/XJ3dOav7yL//Sv9frhVLL//SnP9Ftt93mXj/wwAM0cOBAv457BZ566in/nvnjH/9It956q3+v427bto2OO+44/37r1q30rW99yy0755xz3PpK1knvxfvf/376p3/6J5oxY4Zbzp8n2h9qlWB1x9+w1HKm1olBrRIsy3evJ/sIQFeABAsUnlolWB/96EdpxYoVenEHsid37iH57W9/63qhQuuzlFrO/PCHP6RLL73U9bw0NDT45Xn7CNlt9uzZ4xrAbCN42WWX0YgRI6ixsZHuuOMO9zmVrOMG9u6776bW1lZ68skn6VOf+pRfV8lxlaNWCVZ3/A1LLWdqnRjUKsGq5FhT8hGArgAJFig8tUqwOKlobm7WizvAJ3cRz0v58pe/vF/zThjen9cPHjy4w3K9T/azQ8tEnCwJp556qu+Z4h6rj33sYxWt++Y3v0nnnXcePf/8865nIos+Lgu1SrC642+o/dfrQq+rRVclWKHypeQjAF0BEixQeGqVYJ100km0fv36Dsva29td8iGUO7n369ev01V2/J6Xl4J7jq6//nr6xCc+UVEPVqnGhhOLZ555xpVD4Iavra3NvebeKE5AKlnH88O+/vWvu4nzfOwLFy7060odVwy1SrC642+YF6/U36padFWCJcSUpyf5CEBXgAQLFJ5aJVg8VKfn7/Ck3s9+9rP+fbmTO0/yffXVVzsse+WVV1yyEoLnYD300EPu9RNPPFF2DpbMjRL08XBvE09GFrhXiocAmS1btrghtErWCbt373ae8PCloD/TQq0SrO74G+bFq3ViUKsEy/Ld0/QkHwHoCpBggcJTqwTrhRdecEnG+PHjadeuXe5y/9NPP53+8Ic/+G3KndynTJlCZ599tptYzXOiXn75ZXfbBZkknoV7ic466yzfi8R89atf9VcR8pVwfCWXNHT8f/bKLiZ7PDIH6z/+4z/8sp///Of04IMPujkyPHT4P//zPxWt44SEEz7ucdANK7/euHGjf2+hVglWV/8Nmbx4tU4MapVgxX73QvQkHwHoCpBggcJTqwSLmTBhAv3DP/wDfeADH6DTTjuN7rzzzg7rKzm5P/bYY25fHl775Cc/6XqpQnz3u9/tdG8hfR8sbujkXkT8f7aBY3i5iK8G5AZt1ao3yltfX+8SLh6W4XX8vpJ1c+fOpb//+793n6k/9/vf/75rbPeHWiVYTFf+DZm8eDoxCGl/qFWCxVTy3StHT/ERgK4ACRYoPLVMsEDXUMsEqzdRywQLAFBdkGCBwjNkwppOSVSeeHtQLIbMeLpTEpUn3h50ZvJrQzolUXni7QEA3QMSLFB4Gve00j2zN9JdszaUFW/H24Ni0di8m+55YRrdtWhqWfF2vD3ozJ7WRpq75h6as+ausuLteHsAQPeABAsAAAAAoMogwQIAAABANPPW3ku/er4fXfHMoXT77M9R/c6Ot+lg1uyYR5eMO4iWbHrjNiDLtzxHN049ga589q00Ys4XacfutZk90gEJFgAAAACiWLtjAV014Z20ctsM2tPaQA/XfYdGzv3HDtu07/03bMZpLpGSBKutvYUGTjqa5q+9n3a1bKNH6s6jUYvO7rBfKiDBAgAAAEAUr26ZRFNX3uLfc6/Ur57veDuXma/fQaMW/jv9ZvrHfIK1evscuu659/ptNjYuoV9MePu+7VcPpxunnkjt7W3U0raLBk1+D72w/lG/bU8DCRYAAIDkyBu+yhuimrbq1k5XY3KvC8jn2VevoYfr/su/b2revDeRep/zNptgLd74J9erJXDvF3vc3LbT9XjdPOMUd5HGxOWD9/5tzvDb9USQYAEAAEiKvOGrckNUT79yOY1bNsC/B+XheVbXTjqGtu1e7Zexr8+tuMG9ziZYi9Y/TLfN/LTfjnuqOMHa2bLv+aGvb5tFgyYfS1dP/Ava1LTUb9cTQYIFAAAgKfKGr/KGqJiH6r7thqo0qQ1fVYutu1Y5P1/ePN4vW7ltuvNKev4692Cd6reVHqyWtjduzXLrzE/S7+d92b/vqSDBAgAAkDTZ4au8ISqGhwyHTDmOLn+mD90wtf/e5GyyW57a8FU14B5ATqR4ODbLfQvP7DTMypq4/Dpas2O+6+0SNja+RFdNeId//2r9RBr83Pvpmknvdj2RPRkkWAAAAJJFD1+VG6Kq2zDaNfLNrU00ZeVvaOCko9w2TErDV/tLa3sz3T778zTptV/pVZ3I9mC1tbe6Hq9Zq0e6BI2Hbx984Ry3jmNyUrtw3YM04/Xf0a0zP+US254KEiwAAABJEhq+qmSIKsu1k/u6JE1IZfhqf3lp05hOPVQs9leTTbCYFVun0U3TPuIvMmjYs94t52RNehc5Ebth6oddItZTQYIFAAAgOUoNX+UNUfGcofnrRvl1TDbBSmn4CtQeJFgAAACSIm/4Km+Iihky5Xia8frtboiQJ7bzFYc8RJja8BWoPUiwAAAAJEW54atSQ1TM2h0L3TAVT3K/adpH92471S1PbfgK1B4kWAAAAAAAVQYJFgAAAABAlUGCBQAAAABQZZBgAQAAAABUGSRYAAAAAABVBgkWAAAAAECVQYIFAAAAAFBlkGABAAAAAFQZJFgAAAAAAFUGCRYAAAAAQJVBggUAAAAAUGWQYAEAAAAAVBkkWAAAAADYb5ZsGkvXTu7b6SHblYj34/1TAgkWAAAAAPabQZPf0ylxihHvnxJIsApMQ0MDrV27FoIgCIIKL50wWaRjdrW
43a0WSLAAAAAkD4avao/2zaKUSKs0AAAAQAAMX9Ue7ZlFKZFWaQAAoAewoXExPfvqNfSb6SfTK/XPdFg3Z81dNPi599OVz76Vfj/vy9SwZ4NfN23VrZ0apLb2lszeoBTaN4tAPtovi1IirdIAAEAP4IapH6ZnXr16byL1AXp583i/fH3Di3TVhHfQ6u1zaU9rAz1c9x16pO48v/7pVy6nccsG+PegcnRDbhHIR/tlUUqkVRoAAOhB3DTtox0SrC07V+x9P86/f2nTGLplxif8+4fqvk0zVw/37wVeduPUE6m9vY1a2na54awX1j+qN+vV6IbcIpCP9suilEirNAAA0IPQCVaWpuZ6um/BN1yvlTBizhdpyJTj6PJn+tANU/vT8i2T3fL2vf9unnGKG16cuHzw3u3O8PuAfeiG3CKQj/bLopRIqzQAANCDKJVgcSLFjc2vnu9HO5u3+OV1G0bTq/UTqbm1iaas/A0NnHSU67FiXt82iwZNPpaunvgXtKlpqd8H7EM35BaBfLRfFqVEWqX5/9TX19MxxxyjF1NdXR0dd9xxdMghh9DJJ59Ms2bN0ps48razrstywAFJ2g4AiKRUgsXsatlGf3zpx3T3/K/qVR6+fcCaHfP8+1tnftJNjAed0Q25RSAf7ZdFKZFWafaydOlSOumkk4JJzDnnnEODBw+mbdu20cUXX0xnnnmm3sSRt511XZbQsQEAeh86wdrY+JLroRLWN9S5HimGrxacv26UX8dkEyzej68+vGbSu2ntjgUdtgNo/LsC7ZdFKZFWafZy5JFH0siRI4NJDPcubdq0yb1esWIF9e/fX22xj7ztrOuyZI9t0KBBdOGFF2bWAgB6CzrBem3r83sTqiP3JkgLaU9rIz259Gd059wv+fVDphxPM16/3Q0R8sT2gZOOdkOEre3Nbk7WwnUP7l3/O7p15qfcvCzwBrohtwjko/2yKCXSKs1e1q9f7/4PJViHHXYYtbTsu2fM7t27qU+fPmqLfeRtZ12XRY5tyJAhdNZZZ1F7O06EAPRGdILFTFl5s5tLdcUzh9LIuf9AW3et8us48Ro24zQ3yZ33XbF1qls+6bVfueVMW3uruw3ErNUj/X4AjX9XoP2yKCXSKk2GUIJ10EEH+dec1Bx44IGZtW+Qt511XRY+tqFDh9IJJ5zgErFS8HOR5syZA0FQYlqzZo2v53v27KH58+fTwoULafv27ZkzALkfbK+99prbZ/Xq1dTW1ubX8Tlmx44d/j2v47jZH2zLly/v9Nm9Vboht0jHhDpK+2WRjtnV4na3WnTOQhIhlGAdccQR1Nzc7F5zYsM9TiHytrOuy8LHdvbZZ1Pfvn1p1ao3fp0CAACoDbohtwjko/2yKCXSKk2GUILFPUb8K5DhxOb4449XW+wjbzvruixybMOGDXMT4wEAANQW3ZBbBPLRflmUEmmVJkMowbrgggto4MCB1NjYSAMGDKBzzz1Xb+LI2866LoscW2trK5144om0YAGu+AEAgFqiG3KLQD7aL4tSIq3SZAglWLNnz6Z+/frRwQcf7K744/tWCdnt87azrsuS/axx48bRGWeE77rc0NDgxoMhCIKg/ZNuyC3SMaGO0n5ZpGN2tbjdrRads5Beyvnnn68XAQAASATdkFsE8tF+WZQSaZVmP7jkkkv0IgAAAImgG3KLQD7aL4tSIq3SAAAAAAF0Q24RyEf7ZVFKpFUaAAAAIIBuyC0C+Wi/LEqJtEoDAAAABNANuUUgH+2XRSmRVmkSA1cRQhAEVUe6IbdIx4Q6SvtlkY7Z1cJVhAAAAEAEuiG3COSj/bIoJdIqDQAAABBAN+QWgXy0XxalRFqlAQAAAALohtwikI/2y6KUSKs0AAAAQADdkFsE8tF+WZQSaZUGAAB6IEs2jaVrJ/ft1NhUIt6P9wf5aN8sAvlovyxKibRKAwAAPZBBk9/TqaGJEe8P8tGeWQTy0X5ZlBJplSYxcJsGCOod0o2MRTom1FHaL4t0TKijtF8W6ZhdLdymAQAAEkI3MhaBfLRfFoF8tF8WpURapQEAgB6IbmQsAvlovywC+Wi/LEqJtEoDAAA9EN3IWATy0X5ZBPLRflmUEmmVBgAAeiC6kbEI5KP9sgjko/2yKCXSKg0AAPRAdCNjEchH+2URyEf7ZVFKpFUaAADogehGxiKQj/bLIpCP9suilEirNAAA0APRjYxFIB/tl0UgH+2XRSmRVmkAAKAHohsZi0A+2i+LQD7aL4tSIq3SAABAD0Q3MhaBfLRfFoF8tF8WpURapQEAgB6IbmQsAvlovywC+Wi/LEqJtEoDAKgac9bcRYOfez9d+exb6ffzvkwNezZ0WF+3YTTdPP3jNODZwzssB/HoRsYikI/2yyKQj/bLopRIqzQAgKqwvuFFumrCO2j19rm0p7WBHq77Dj1Sd55f/+KGx90DhhdvfMKtB/uHbmQsAvlovywC+Wi/LEqJtEqTGHjYM9QdamlpoS07V9DLm8f57+JLm8bQLTM+4d/fNO0jtKz+Wf9emLl6ON049URqb2+jlrZdLgl7Yf2jbt26des6fRa0T7qRsUjHhDpK+2WRjgl1lPbLIh2zq4WHPQMAuoym5nq6b8E36OlXLnfvG5s30aXjD6H7F55FA559295k66O0Zsc8t65977+bZ5zihhcnLh9MI+ackQ0FSqAbGYtAPtovi0A+2i+LUiKt0gAAqsqIOV90J71fPd+PdjZvccs4meJl01YNc8ODE5YPoiFTjnfJFfP6tlk0aPKxdPXEv6BNTUuz4UAJdCNjEchH+2URyEf7ZVFKpFUaAEDV2dWyjf740o/p7vlfde9Xb5/jEiihrb2VLhv/Ztq2e7VfduvMT7qJ8aAydCNjEchH+2URyEf7ZVFKpFUaAEBV2Nj4Er1aP9G/X99Q53qkmMY9G11C1dK2272XBKthz3r3nvfjqw+vmfRuWrtjgY8BSqMbGYtAPtovi0A+2i+LUiKt0gAAqsJrW5/fm1AduTdBWkh7WhvpyaU/ozvnfsmvv33252nM0otoV8t2mrj8OjcPi2ltb6YbpvanhesepBmv/45unfkpP3QISqMbGYtAPtovi0A+2i+LUiKt0gAAqsaUlTe7ocArnjmURs79B9q6a5Vfx8OBw+d8Ye+6w1wStaFxsVs+6bVf0bAZp7nX3LN1w9QP06zVI/1+IIxuZCwC+Wi/LAL5aL8sSom0SgMAAD0Q3chYBPLRflkE8tF+WZQSaZWmAurr6+mYY47RiztQV1dHxx13HB1yyCF08skn06xZs/Z7XZYDDuh1tgMActCNjEUgH+2XRSAf7ZdFKZFWacqwdOlSOumkk8omOOeccw4NHjyYtm3bRhdffDGdeeaZ+70uS7nPBwD0LnQjYxHIR/tlEchH+2VRSqRVmjIceeSRNHLkyLIJDvdCbdq0yb1esWIF9e/ff7/XZcl+/qBBg+jCCy/MrAUA9DZ0I2MRyEf7ZRHIR/tlUUqkVZoyrF+/7zLycgnWYYcd5h4XwuzevZv69Omz3+uyyOcPGTKEzjrrLGpvx1VWAP
RmdCNjEchH+2URyEf7ZVFKpFWaCimXYB100EH+NSc/Bx544H6vy8KfP3ToUDrhhBNcIgYA6N3oRsYikI/2yyKQj/bLopRIqzQVUi7BOuKII6i5udm95gSIe6b2d10W/vyzzz6b+vbtS6tWvXHpu4YfPDlnzhwIghKXbmQs0jF7mubNm+fPn/e+MJ363XYZHXr9BfS5+35Nr27d2OF82NTU5LZ7+KU59L5hF9Pbb/gxff+pe2lPawu99tprbr4tX3TE27744ovuPd+fTXsWK33MUEdpvyzSMbta/D2rFvmZRqKUS7C4Z2n16n2P/eAE6Pjjj9/vdVnk84cNG+YmxgMAeje6kbEoFRZsWEXvvPF/acaa5dSwZzd958m76B//MFRvRmsbttFbh/yInn3tJdrQuIP+5p7B9KvpT+nNOnD/wm928i1GIB/tl0UpkVZpKqRcgnXBBRfQwIEDqbGxkQYMGEDnnnvufq/LIp/f2tpKJ554Ii1YgMeJANCb0Y2MRakwaeVSumX2s/79c6tepg/edmlmi32MqptJXxh1o38/fvliOmnE1e718PnP0YnDf0Ft7e20q6WZ3nPLz+nRJXNp4vLBnXyLEchH+2VRSqRVmgoJJVjZZbNnz6Z+/frRwQcf7K4M5K7m/V2XJftZ48aNozPOOCOzFgDQ29CNjEWpcs2UJ+m/nrxLL3bDiGeMusm/f7l+g+v5YvjxTKfceQ3dtWgqDZ72lNtuT2sDXffc+zr5FiOQj/bLopRIqzT7wfnnn68XdTsNDQ1uPBiCoGJo3bp1roc6e+Uvz7fU2+zatStTk/exY8eOTvFEupGxSMfsqWL/ZC7WvPUr6Zibf0ard2x1t8Dh9Rs2bHDrVmzbTH2u/yFNXLGEtuxqoq8/cptPsJhZa1+jY2/5Of3F0J/Q0vr1NHrxDzt5Fit9rFBHab8s0jG7WtzuVgskWP+fSy65RC8CAIAgo5fOo4/feQ0d/uv/1qs8PLR1wh2/cPOEvvjATW7OUCkuHXdwp4YmVqmxans9vXfYxW7orxQjFzzvEihOpH7w1H30oduv7LD+k3ddR19+6Bb3+hfP/nknz2IF8tF+WZQSaZUGAABqzONL57s5PU+8vMBNwg7R0tZGR//mp3R/3QzatnsnnTfmbjr78eF6M8+ulm3021mnd2psYpQS7BnPoeJhwEoZNHWM81ngnq3333oJvfs3P3MT52euHt7Js1iBfLRfFqVETUszadIkNyeJ4flJRx11lJufNHr0aLUlAAD0DD4y/Cp35Voec9atcL0vwpLN69ytBJhSE7CfW3FDp8YmRqnQ3NZKn7//1yWvCJy99rUO7zc1NdCdC6fQETf9L9VtXOOWcYz+tw+gBxfPpt/Nm0Sfuvs6atv7b/icv+/kW4xAPtovi1KipqX52Mc+Rg8//LC7Z8mhhx5KI0aMoEcffdQlWQAA0NPgxvyQwd+ns0bfTm/79X/TR4f/0s0T0vzplYV02l2D/Hvu6Trg2vNo596EKjQBu629lW6Z8YlOjU2MUmHMskXOKy32tH5nI73rpp/Q+sbtbtvpq19198rihGza6mU+Bidn4n9rext9+I4Bbjhx1uqRnXyLEchH+2VRStS0NG9605vcrQjGjx9PhxxyiJv4yY+SQYJVGZjkDnW3pi29j66ecHSnk2Al4v14fx2zJ0qeMcrJFDf2w+ZMcEkTD0sd/7srqLWt1W23ZcsWtx3fAPPTdw/2k7W5p4r327pr3w0y9QTs/e29Yulj7mniye3yqLEQ3JYIfG4Ub/mCA77QoL6+nrZu3eovQGhra6Nt27a5/5k9rY00cu4/dPItRvqYe6LkJq2LN611V2iePPJqeua1xc4/2YYv5GB++sxDnRJd1tx1K913PRuXLz64ZNyfdfIsVvp4u1o9ZpI7J1K33XYbff3rX6fPfOYzbtkDDzxAJ510ktoSAFBEBk1+T6cTYIx4/5TgoT9OjATuHXnzr37grnLLwr0tp/7+Wv9eerB2t76RQGQnYF/57Fs7eRcrkM9l49/UybNYpQT36l39/J/oA7demnshQRZOyvieZDy8HYLnEt428286+RajlKhpaWQ48Atf+AJt3LjRLevfvz/NnDlTbQkAKCL65GdRSmxs2uESKkmUJMGSISth/vpV7vYCwkub19E7bvwf/15PwB778iWdfIsVyEf7ZVGK8DB3pQnW+WPvoRtmjHOvMZewPGmVBgBQVfTJz6LU4Pk+Fz37MG3fvZOumzbWNVCCTMDmxIsnufO8H74ijh/3cs4Td7p1oQnYzW276IapH+7kXYxAPtovi1Kk0gRr884GdyGBDHNjLmF50ioNAKCq6JOfRanBw4H8iJbDhlzgkiMeNmH0BGyedM1XHMp9sGR5qQnYM17/XSfvYgTy0X5ZlCKVJljXTh3jerCy1GIuYUqkVRoAeil5k1E1fG+hfrdd5q6++tx9v6ZXt+4bvg+hT34Wgcp49MXvdfIuRiAf7ZdFKVJJgsW9rjzkvWjDar0KcwlzSKs0AABHqcmoPN+HHycyY81yN/Gah67+8Q9DO2yT5b6FZ9LF4w7sdBKMESjPTdM+0sm3WIF8tF8WpUglCdZ9L8ygz943RC/GXMIy1KQ0ffv2pW984xt044030vTp02nPnj16E1ABuE0DVE58aXUImYzKl7bzdnz5OjNp5VK6Zfazfjt+nAsnYoxMWt3T3Nxh0uoTS/6300kwRvqYoc7SnlmkY0Idpf2ySMfsyZLbXEiCxbdd4Ftc8O2Upq5Y6l7LLS/4sVB8LuDbXmzevNktKz2XsH8n32Kkj7OrVfjbNLz++uv00EMP0U9+8hM67bTT6G1vext98pOfpAsvvNDdeJTXAwBqg56MmgffB+e/nrzLvQ5NWmXGLRvQ6SQYI1Ae7ZlFIB/tl0Uponuw9FzC51e94n5s8XzBLKXmEk5f9dtOvsUoJbqkNHwTuGnTptENN9xA//Zv/0bHHHOM3gQAUCVCk1FD8E0zeV5F9h5OetLq9t1r6BcT3t7pJBgjUB7tmUUgH+2XRaA8jy3+fiffYpQSaZUGgF5O3mTULKu217vbCITmXmQnrd417yudToCxAuXRnlkE8tF+WQTywVzCjqRVGgB6OaUmo2bh+zLxXCu+mlCjJ61eO7lvpxNgrEB5tGcWgXy0XxaBfLRfFqVEWqUBoJcjk1E1cgNM7uHiG2Xy/AlNaNLqks1PdToBxgqUR3tmEchH+2URyEf7ZVFKpFUaAHoxpSajZietjlm2qNO9slj87LxSk1YfWPQfnU6CMQLl0Z5ZBPLRflkE8tF+WZQSNS3NpEmTqF+/fu717Nmz6aijjnLPJhw9erTaEgBQVGatHtnpJBgjUB7tmUUgH+2XRSAf7ZdFKVHT0nzsYx9zt2VoamqiQw89lEaMGOEfAA0AKD6rt8+lodNO6nQSjBEoj/bMohQZu+wF6nvzRZ16XCsR78f7C9ovi0A+2i+LUqKmpXnTm97kbnQ4fvx4OuSQQ9wNzFpaWpBgAdBD0Cc/i0B5tGcWpQgPeevEKUa8v6D9sgjko/2yK
CVqWhpOpG677Tb6+te/Tp/5zGfcsgceeIBOOukktSUAoIjok59FoDzaM4tSRCdMFgnaL4tAPtovi1KipqWR4cAvfOELtHHjvgfK9u/fn2bOnKm2BAAUEX3yswiUR3tmUYroZMkiQftlEchH+2VRSqRVGlBYHn5pDr1v2MX09ht+TN9/6l7a09qiN+kA32X8oEHn01OvvjGHAnQ9+uRnESiP9syiFNHJkkWC9ssikI/2y6KUSKs0idHTH/a8bt0697DQtQ3b6K1DfkTPvvYSbWjcQX9zz2B3S4Dt27d32ocfq8TPxOPbBfA+nGDxw8JZjY2NtGHDBnfRBL/n/9evX98pBlQ96ZOfRTpmKrpv1mQ6+qYLOzXolYj34/0llvbMIn18KUj7ZpHE0n5ZpI8P6ijtl0U6Zler8A97FrgBveiii+gDH/iAm+TO8EOfFy/u/HgOkC6j6mbSF0bd6N/z41lOGnF1ZouO3DH/Ofr3x++gj40c6Huwhu9dxncfb+Onvbc0u8mroRtqguqiT34WpQomYNce7ZlFgvbLIpCP9suilKhpab761a/SqFGjqK2tjQ44YN9H8aT3j3/842pLkDL8SJYzRt3k379cv4HeeeP/ZrZ4g807G9xQIvd6ZRMs7tU65c5r6K5FU2nwtKc6xAO1Q5/8LEoV3ZBbJGjPLEoR7ZdFgvbLIpCP9suilKhpad72tre5WzMwkmDx+8MOOyy7GUicFds2U5/rf+iec7dlVxN9/ZHbSiZY5425m26YMc69ziZYzKy1r9Gxe3/1/8XQn9DS+vV+Oagd+uRnUarohtwiQXtmUYpovywStF8WgXy0XxalRE1Lwz1YixYtcve+4gSrubmZ5s+fT//yL/+iNwWJw49c4cSIE6QfPHUffej2K/UmNH31q24YsKVt36NedILFfPKu6+jLD93SYRmoHfrkZ1Gq6IbcIkF7ZlGKaL8sErRfFqUIbuZaO2pamm3bttGPf/xjeu973+tuOnrMMcfQ9773Paqvr9ebgl7EoKljXE+V5szHftepArOumzbWrecesPffegm9+zc/owUbVqm9QS3QJz+LUkV/Ty0StGcWpYj2yyJB+2VRimAuYe1IqzSg0GxqaqA7F06hI276X6rbuMYvn732tTc2ypDtwWpua6X+tw+gBxfPpt/Nm0Sfuvs6Ny8L1BZ98rMoVXRDY5GgPbMoRbRfFgnaL4tSRPtlkaD9sigl0ioNKCw8/Hfo9RfQ5+//NU1bvcwvr9/ZSO+66Se0vnF7Zut9ZBMsvq0D37qBaW1vow/fMcANO4Laok9+FqWKbmQsErRnFqWI9ssiQftlUYpovywStF8WpURNSsPzrcqp6PAwJg9paurq6ui4445zt504+eSTadasWRWty9ITyg8Ao09+FqWKbmQsErRnFqWI9ssiQftlUYpovywStF8WpURapakSS5cudc9LDCVC55xzDg0ePNjNL7v44ovpzDPPrGhdllBcAKoFJq12Ddo7iwTtmUUpov2ySNB+WZQi2i+LBO2XRSmRVmmqxJFHHkkjR44MJkLcQ7Vp0yb3esWKFe7ZipWsy5KNO2jQILrwwgszawHYPzBptWvQvlkkaM8sShHtl0WC9suiFNF+WSRovyxKiZqWhh9r8u1vf5uOOuoo99Dn97znPe6qwh07duhNCwU/foUJJVh8Dy++7QTDj3Xp06dPReuySNwhQ4bQWWed5R4nA0C10Cc/iwR98rMoVbRnFgnaM4tSRPtlkaD9sihFtF8WCdovi1KipqU5/fTTaejQobRlyxaXeHDvzvXXX+/uj9UTCCVYBx10kH/NidGBBx5Y0bosHJd9OeGEE1wiBkA10Sc/iwR98rMoVbRnFgnaM4tSRPtlkaD9sihFtF8WCdovi1KipqU5/PDD3XykLPx8wre+9a0dlhWVUIJ1xBFHuBumMpwcZe9Kn7cuC8c9++yzqW/fvrRqVen7OfGDJ+fMmQNBUdInP4sklj75WaSPLxVpzyySWNozi/TxpSDtl0USS/tlkT6+FKT9skhiab8s0sfX1eJ2t1p0ziCqyIMPPkjf+ta36JVXXqE9e/bQ6tWr6YILLnC9Nz2BUILFvU5cDoaTo+OPP76idVkk7rBhw9zEeACqiT75WSTok59FqaI9s0jQnllUFMYu3kJ9fzmbDvjJ1Gjxfry/oP2ySNB+WVQU4HHPoKalyd6WoZSKTOj4OEEcOHAgNTY20oABA+jcc8+taF0Widva2konnngiLViwQG0BylG3YTTdPP3jNODZw/Uqz7RVt3aqvG3t++bIpYw++VkkaP8sShXtmUWC9syiovCeX87p1KjHiPcXtF8WCdovi4oCPO4ZpFWaKhNKsGbPnk39+vVzk/b5qkG+91Ul67Jk444bN47OOOOMzFpQjhc3PE6DJr+HFm98gva0NujVnqdfuZzGLRugFyePPvlZJOiTn0Wpoj2zSNCeWVQUdGNukY8V8CxWgvbLoqKg/bLIxwp4FitB+2VRSqRVmsRoaGhw48HQPu3cudP5ctO0j9Cy+me9TzzfLbudXKX6UN23aebq4X47gZfdOPVEamtrpZa2XS5Ze2H9o9TU1NTpM3ui9MnPIomlT34W6eNLRdoziySW9swifXzdJd2QW+RjBTyLlcTSflmky9pd0n5Z5GMFPIuVxNJ+WaTL2tXidrda1DTBeuyxx1xvjh4WDPUMAVAJjc2b6NLxh9D9C8+iAc++bW+y9VFas2Oe3swxYs4XaciU4+jyZ/rQDVP70/Itk91yfobhzTNOoTlr7qKJywfv3S6tHkR98rNI0Cc/i1JFe2aRoD2zqCjohtwiHyvgWawE7ZdFRUH7ZZGPFfAsVoL2y6KUqGlpOLkaPXq0XgyAGU6muBJOWzXMDQ9OWD5obxJ1fPDBzzxP69X6idTc2kRTVv6GBk46yvVYMa9vm0WDJh9LV0/8C9rUtFTt2bPRJz+LBH3ysyhVtGcWCdozi4qCbsgt8rECnsVK0H5ZVBS0Xxb5WAHPYiVovyxKiZqW5uijj65qdxsAq7fPcYmR0NbeSpeNfzNt2706s1WYayf37dDbdevMT9Lv5305s0Ua6JOfRYI++VmUKtoziwTtmUVFQTfkFvlYAc9iJWi/LCoK2i+LfKyAZ7EStF8WpURNS/P000/TVVdd5R6cDEA1aNyz0SVULW37btAqCVbDnn133xf4asH560Z1WJZNsLhna/Bz76drJr2b1u5I6ypOffKzSNAnP4tSRXtmkaA9s6go6IbcIh8r4FmsBO2XRUVB+2WRjxXwLFaC9suilKhpaf74xz/6hyZrgfJgkntHtbW1OV9un/15GrP0ItrVsp0mLr/OzcNieP3KrTO9fzx0OOP1290QIU9sHzjpaDdE2Nre7OZkLVz34N71v6NbZ36KWtta3SOS9Gf2ROmTn0USS5/8LNLHl4q0ZxZJLO2ZRfr4ukvSgH9o8Dy6fMxKmvt6I/3dbXV++T8NX0xbdrbQu66c5ZddMXYlLVrbRH/202nuvY8V8CxWEkv7ZZEua3cJHtdO1Rx1q2mmw3c2
f/jhh/3z+QAYu+wF6nvzRZ0qaCW6efa+Kwd5OHD4nC/QFc8c5pKjDY2L3fKm5nr65cR3+d6stTsW0rAZp7lJ7pyErdi671fbpNd+5ZYz3AN2w9QP06zVI937FNC+WSTok59FqaI9s0jQnllUFKRBr1vXRFc+tZJe3byL/v63bzT+rCfq6mnEjPXuNd+TqXFPK51+ywt+vY8V8CxWgvbLoqIAj3sGNS3NO9/5Tvf8QQCE99zy806Vs1K95Vc/oFlrX9MhgUL7ZpGgT34WpYr2zCJBe2ZRUcg28qwFaxo7Nf7vGziHmva00alDF9GD8zfRvXM2dljvYwU8i5Wg/bKoKMDjnkFNSzNixAj69re/TRs2bHAPPwZAV8xYHXzd93RIoNCeWSTok59FqaI9s0jQnllUFCpp/FkDnlpFK7fspm27WumoAW8MZaHxL4/2Eh4Xk5qWRs+7whwsoCumRSAf7ZdFgj75WZQq2jOLBO2ZRUVBN/KlGv93XjGTWtva6Y7p+4ax0PhXjvYLHheTtEoDCo+umBaBfLRfFgn65GdRqmjPLBK0ZxYVBd2Ql2r8r5+wmpZs2EkNu1vp2Ks7PlvPxwp4FitB+2VRUdBewuNiklZpQOHRFdMikI/2yyJBn/wsShXtmUWC9syioqAb+VDj33/wfNrV0kYfHbKARs7YQKMX1aPxjwAe9wxqWpo777zTXUmohwcxRFgZKd6mQVdMi3RMqKO0XxZJLH3ys0gfXyrSnlkksbRnFunj6y5V0vhPWradbn5u37Z8K4GtO1voS3cs9ut9rIBnsZJY2i+LdFm7S/C4duoxt2k48sgjXZLV3NxM73//+2nRokX0ne98h2666Sa9Kegl6IppkbBk01h381BdQSsR78f7p4j2yyJB+2ZRqmjPLBK0ZxYVhXKN/7fuf5k2NDTT4ZfO8Mt+9Ohyd6uBN1803b33sQKexUrQfllUFOBxz6CmpenTpw/t3r3vjttf+9rXaOzYsbR161Y69tg3HnUCehe6YlokDJr8nk6VM0a8f4povywStGcWpYr2zCJBe2ZRUdCNv0U+VsCzWAnaL4uKgvbLIh8r4FmsBO2XRSlR09J84hOfoEceecS9HjhwoHtszuzZs+m4445TW4Legq6YFgm6YlqUItoviwTtl0Wpoj2zSNCeWVQUdENukY8V8CxWgvbLoqKg/bLIxwp4FitB+2VRStS0NNOmTXOPymE2btxIn/70p+nwww+n0aNHqy1Bb0FXTIsEXTEtShHtl0WC9suiIjF28Rbq+8vZnRqbSsT78f6C9swiQXtmUVHQvlnkYwU8i5Wg/bKoKGi/LPKxAp7FStB+WZQSaZUGFB5dMS0SdMW0KEW0XxYJ2i+LigQ/MkQ3NDHi/QXtmUWC9syioqA9s8jHCngWK0H7ZVFR0H5Z5GMFPIuVoP2yKCXSKg0oPLpiWiToimlRimi/LBK0XxYVCd3IWORjBXyLlaA9s6goaL8s8rECnsVK0H5ZVBS0Xxb5WAHPYiVovyxKiZqUhi9zPP/8892zCJl58+bRiSeeSG9+85vpi1/8orsUEvROdMW0SNAV06IU0X5ZJGi/LCoSupGxyMcK+BYrQXtmUVHQflnkYwU8i5Wg/bKoKGi/LPKxAp7FStB+WZQSNSnNeeedRz/60Y/8FYTHH388XXfddbR9+3a65JJL6Ctf+YraA/QWdMW0SNAV06IU0X5ZJGi/LCoSupGxyMcK+BYrQXtmUVHQflnkYwU8i5Wg/bKoKGi/LPKxAp7FStB+WZQSNSnN29/+djepnXnxxRfpoIMOom3btrn39fX19Ja3vCW7OehF6IppkaArpkUpov2ySNB+WVQkdCNjkY8V8C1WgvbMoqKg/bLIxwp4FitB+2VRUdB+WeRjBTyLlaD9siglalIavv+V3A31jjvuoFNOOcWv4+FBvrs76J3oimmRoCumRSmi/bJI0H5ZVCR0I2ORjxXwLVaC9syioqD9ssjHCngWK0H7ZVFR0H5Z5GMFPIuVoP2yKCVqUhqeZ3XjjTe6JOsHP/gBfe9733PLd+zYQZdeeql/D3ofumJaJOiKaVGKaL8sErRfFhUJaVw+NHgeXT5mJc19vZH+7rY37oD9T8MX05adLe7RIrLsirEradHaJvqzn05Dw1QBuiG3yMcKeBYrQftlUVHQflnkYwU8i5Wg/bIoJWpSmmXLltGpp57qJrV/9atfpc2bN7vlPHTIk98bGxvVHqC3oCumRYKumBaliPbLIkH7ZVGRkMalbl0TXfnUSvfoEP0Mtyfq6mnEjPXuNd+WoXFPK51+ywtomCpEN+QW+VgBz2IlaL8sKgraL4t8rIBnsRK0XxalRFqlSQw87DksiaUrpkX6+FKQ9ssiiaX9skgfX3dKNzL6GW6s9w2cQ0172ujUoYvowfmb6N45Gzus97ECvsVKYmnPLNJl7S5pjy3ysQKexUpiab8s0mXtLmm/LPKxAp7FSmJpvyzSZe1q9ZiHPQOg0RXTIkFXTItSRPtlkaD9sqhI6EYmlGCxBjy1ilZu2U3bdrXSUQPeGC5k+VgB32IlaM8sKgraS4t8rIBnsRK0XxYVBe2XRT5WwLNYCdovi1IirdKAwqMrpkWCrpgWpYj2yyJB+2VRkdCNTKkE651XzKTWtna6Y/q+oUI0TJWj/bLIxwp4FitB+2VRUdB+WeRjBTyLlaD9sigl0ioNKDy6Ylok6IppUYpovywStF8WFQndyJRKsK6fsJqWbNhJDbtb6dirOz5ex8cK+BYrQXtmUVHQXlrkYwU8i5Wg/bKoKGi/LPKxAp7FStB+WZQSaZUGFB5dMS0SdMW0qCjgIcRdg/YulGD1HzyfdrW00UeHLKCRMzbQ6EX1Hdb7WAHfYiVozywqCtpji3ysgGexErRfFhUF7ZdFPlbAs1gJ2i+LUiKt0oDCoyumRYKumBYVBTyEuGvQvoUSrEnLttPNz+2bSMy3a9i6s4W+dMdiv97HCvgWK0F7ZlFR0B5b5GMFPIuVoP2yqChovyzysQKexUrQflmUEmmVBhQeXTEtEnTFtKgo6JOfRT5WwLNYCdovi4qE9kwnWN+6/2Xa0NBMh186wy/70aPL3e0c3nzRdPhcAdpji3ysgGexErRfFhUF7ZdFPlbAs1gJ2i+LUiKt0oDCoyumRYKumBYVBX3ys8jHCngWK0H7ZVGR0J5Z5GMFfIuVoD2zqChovyzysQKexUrQfllUFLRfFvlYAc9iJWi/LEqJtErTRfDzFI855hi9mOrq6ui4446jQw45hE4++WSaNWuW3sRxwAG913ZdMS0SdMW0qCjok59FPlbAs1gJ2i+LioT2zCIfK+BbrATtmUVFQftlkY8V8CxWgvbLoqKg/bLIxwp4FitB+2VRSqRVmi5g6dKldNJJJwWTpHPOOYcGDx7sHmx98cUX05lnnqk3cYT27S3oimmRoCumRUVBn/ws8rECnsVK0H5ZVCS0Zxb5WAHfYiVozywqCtovi3ysgGexErRfFhUF7ZdFPlbAs1gJ2i+LUiKt0nQBRx55JI0cOTKYJHHv1aZ
Nm9zrFStWUP/+/dUW+8juO2jQILrwwgsza9NGV0yLBF0xLSoK+uRnkY8V8CxWgvbLoiKhPbPIxwr4FitBe2ZRUdB+WeRjBTyLlaD9sqgoaL8s8rECnsVK0H5ZlBJplaYLWL9+vfs/lGAddthh1NLS4l7v3r2b+vTpo7bYh+w7ZMgQOuuss6i9vV1tkS66Ylok6IppUVGQkx4eQlxbdCNjkY8V8C1WgvbMoqKg/bLIxwp4FitB+2VRUdB+WeRjBTyLlaD9sigl0ipNFxJKsA466CD/mpOmAw88MLP2DXjfoUOH0gknnOASsd6ErpgWCbpiWlQU5KSHhxDXFt3IWORjBXyLlaA9s6goaL8s8rECnsVK0H5ZVBS0Xxb5WAHPYiVovyxKibRK04WEEqwjjjiCmpub3WtOnLhHKwTve/bZZ1Pfvn1p1apVerWHHzw5Z86cpKQrpkUSS1dMi/TxdZf0yU/fPoBV7iHEPlbAs1hJLO2XRbqs3Snts0U+VsC3WEks7ZlFuqzdJe2XRT5WwLNYSSztl0W6rN0l7ZdFPlbAs1hJLO2XRbqsXS1ud6tF5ywBVEQoweIeqdWrV7vXnDgdf/zxaot9yL7Dhg1zE+N7E7piWiToimlRUdAnv1CCxcJDiPcP7adFPlbAt1gJ2jOLioL2yyIfK+BZrATtl0VFQftlkY8V8CxWgvbLopRIqzRdSCjBuuCCC2jgwIHU2NhIAwYMoHPPPVdv4pB9W1tb6cQTT6QFCxaoLdJFV0yLBF0xLSoK+uRXKsHCQ4j3D+2ZRT5WwLdYCdozi4qC9ssiHyvgWawE7ZdFRUH7ZZGPFfAsVoL2y6KUSKs0XUgowZo9ezb169ePDj74YHdFId8XK0R233HjxtEZZ5yRWZs2umJaJOiKaVFR0Ce/UgkWHkK8f2g/LfKxAr7FStCeWVQUtF8W+VgBz2IlaL8sKgraL4t8rIBnsRK0XxalRFqlSYyGhgY3HpySdMW0SGLpimmRPr7ukj75hRKscg8h9rECnsVKYmm/LNJl7U5pny3ysQK+xUpiac8s0mXtLmm/LPKxAp7FSmJpvyzSZe0uab8s8rECnsVKYmm/LNJl7Wpxu1stkGCBLkVXTIsEXTEtKgr65BdKsPAQ4v1H+2yRjxXwLVaC9syioqD9ssjHCngWK0H7ZVFR0H5Z5GMFPIuVoP2yKCXSKg0oPLpiWiToimlRUdAnP51g4SHE1UH7bJGPFfAtVoL2zKKioP2yyMcKeBYrQftlUVHQflnkYwU8i5Wg/bIoJdIqDSg8umJaJOiKaVFR0Cc/i3ysgGexErRfFhUJ7ZlFPlbAt1gJ2jOLioL2yyIfK+BZrATtl0VFQftlkY8V8CxWgvbLopRIqzSg8OiKaZGgK6ZFRUGf/CzysQKexUrQfllUJLRnFvlYAd9iJWjPLCoK2i+LfKyAZ7EStF8WFQXtl0U+VsCzWAnaL4tSIq3SJEZRJrnfN2UZHT3gjaGpGPF+vL/E0hXTIomlK6ZFuqzdJe2bRT5WwLNYSSztl0W6rN0p7ZlFPlbAt1hJLO2ZRbqs3SXtl0U+VsCzWEks7ZdFuqzdJe2XRT5WwLNYSSztl0W6rF0tTHIHXQo/lkVXzhjx/oKumBYJumJaVBS0Zxb5WAHPYiVovywqEtozi3ysgG+xErRnFhUF7ZdFPlbAs1gJ2i+LioL2yyIfK+BZrATtl0UpkVZpQE3QFdMiHytQOWMl6IppUVHQflnkYwU8i5Wg/bKoSGjPLPKxAr7FStCeWVQUtF8W+VgBz2IlaL8sKgraL4t8rIBnsRK0XxalRFqlATVBV0yLfKxA5YyVoCumRUVB+2WRjxXwLFaC9suiIqE9s8jHCvgWK0F7ZlFR0H5Z5GMFPIuVoP2yqChovyzysQKexUrQflmUEmmVBtQEXTEt8rEClTNWgq6YFhUF7ZdFPlbAs1gJ2i+LioT2zCIfK+BbrATtmUVFQftlkY8V8CxWgvbLoqKg/bLIxwp4FitB+2VRSqRVGlATdMW0yMcKVM5YCbpiWlQUtF8W+VgBz2IlaL8sKhLaM4t8rIBvsRK0ZxYVBe2XRT5WwLNYCdovi4qC9ssiHyvgWawE7ZdFKZFWaUBN0BXTIh8rUDljJeiKaVFR0H5Z5GMFPIuVoP2yqEhozyzysQK+xUrQnllUFLRfFvlYAc9iJWi/LCoK2i+LfKyAZ7EStF8WpURapQE1QVdMi3ysQOWMlaArpkVFQftlkY8V8CxWgvbLoiKhPbPIxwr4FitBe2ZRUdB+WeRjBTyLlaD9sqgoaL8s8rECnsVK0H5ZlBJplSYxinIfLKmQHxo8jy4fs5Lmvt5If3fbG49x+afhi2nLzhb3fDxZdsXYlbRobRP92U+nufc+VqByxkpi6YppkS5rd0mf/CzysQKexUpiab8s0mXtTmnPLPKxAr7FSmJpzyzSZe0uab8s8rECnsVKYmm/LNJl7S5pvyzysQKexUpiab8s0mXtauE+WKBLkQpZt66JrnxqpXv+nX4Q8RN19TRixnr3mu971binlU6/5QW/3scKVM5YCbpiWlQU9MnPIh8r4FmsBO2XRUVCe2aRjxXwLVaC9syioqD9ssjHCngWK0H7ZVFR0H5Z5GMFPIuVoP2yKCXSKg2oCbpi6gcRs943cA417WmjU4cuogfnb6J752xEZY5Ae2yRjxXwLFaC9suiIqE9s8jHCvgWK0F7ZlFR0H5Z5GMFPIuVoP2yqChovyzysQKexUrQflmUEmmVBtQEXTFDCRZrwFOraOWW3bRtVysdNeCN4UJU5vJoLy3ysQKexUrQfllUJLRnFvlYAd9iJWjPLCoK2i+LfKyAZ7EStF8WFQXtl0U+VsCzWAnaL4tSIq3SgJqgK2apBOudV8yk1rZ2umP6vqFCVObK0X5Z5GMFPIuVoP2yqEhozyzysQK+xUrQnllUFLRfFvlYAc9iJWi/LCoK2i+LfKyAZ7EStF8WpURapQE1QVfMUgnW9RNW05INO6lhdysde3XH5xf6WIHKGStBV0yLioL20iIfK+BZrATtl0VFQntmkY8V8C1WgvbMoqKg/bLIxwp4FitB+2VRUdB+WeRjBTyLlaD9sigl0ioNqAm6YoYSrP6D59Ouljb66JAFNHLGBhq9qB6VOQLtsUU+VsCzWAnaL4uKhPbMIh8r4FusBO2ZRUVB+2WRjxXwLFaC9suioqD9ssjHCngWK0H7ZVFKpFUaUBN0xQwlWJOWbaebn9t3+TDfrmHrzhb60h2LUZkrRHtskY8V8CxWgvbLoiKhPbPIxwr4FitBe2ZRUdB+WeRjBTyLlaD9sqgoaL8s8rECnsVK0H5ZlBJplQbUBF0xdYL1rftfpg0NzXT4pTP8sh89utzdzuHNF01HZa4A7bFFPlbAs1gJ2i+LioT2zCIfK+BbrATtmUVFQftlkY8V8CxWgvbLoqKg/bLIxwp4FitB+2VRSqRVGlATdMW0yMcKVM5YCbpiWlQUtF8W+VgBz2IlaL8sKhLaM4t8rIBvsRK0ZxYVBe2XRT5WwLNYCdovi4qC9s
siHyvgWawE7ZdFKZFWaUBN0BXTIh8rUDljJeiKaVFR0H5Z5GMFPIuVoP2yqEhozyzysQK+xUrQnllUFLRfFvlYAc9iJWi/LCoK2i+LfKyAZ7EStF8WpURapQE1QVdMi3ysQOWMlaArpkVFQftlkY8V8CxWgvbLoiKhPbPIxwr4FitBe2ZRUdB+WeRjBTyLlaD9sqgoaL8s8rECnsVK0H5ZlBJplSYxivYswv2RjxWonLGSWLpiWqTL2l3SflnkYwU8i5XE0n5ZpMvandKeWeRjBXyLlcTSnlmky9pd0n5Z5GMFPIuVxNJ+WaTL2l3SflnkYwU8i5XE0n5ZpMva1cKzCEGXoiumRT5WoHLGStAV06KioP2yyMcKeBYrQftlUZHQnlnkYwV8i5WgPbOoKGi/LPKxAp7FStB+WVQUtF8W+VgBz2IlaL8sSom0SgNqgq6YFvlYgcoZK0FXTIuKgvbLIh8r4FmsBO2XRUVCe2aRjxXwLVaC9syioqD9ssjHCngWK0H7ZVFR0H5Z5GMFPIuVoP2yKCXSKg2oCbpiWuRjBSpnrARdMS0qCtovi3ysgGexErRfFhUJ7ZlFPlbAt1gJ2jOLioL2yyIfK+BZrATtl0VFQftlkY8V8CxWgvbLopRIqzSgJuiKaZGPFaicsRJ0xbSoKGi/LPKxAp7FStB+WVQktGcW+VgB32IlaM8sKgraL4t8rIBnsRK0XxYVBe2XRT5WwLNYCdovi1IirdKAmqArpkU+VqByxkrQFdOioqD9ssjHCngWK0H7ZVGR0J5Z5GMFfIuVoD2zqChovyzysQKexUrQfllUFLRfFvlYAc9iJWi/LEqJtEoDaoKumBb5WIHKGStBV0yLioL2yyIfK+BZrATtl0VFQntmkY8V8C1WgvbMoqKg/bLIxwp4FitB+2VRUdB+WeRjBTyLlaD9sigl0ioNqAm6YlrkYwUqZ6wEXTEtKgraL4t8rIBnsRK0XxYVCe2ZRT5WwLdYCdozi4qC9ssiHyvgWawE7ZdFRUH7ZZGPFfAsVoL2y6KUSKs0BaC+vp6OOeYYvbgDBxzQs2zXFdMiHytQOWMl6IppUVHQflnkYwU8i5Wg/bKoSGjPLPKxAr7FStCeWVQUtF8W+VgBz2IlaL8sKgraL4t8rIBnsRK0XxalRFql6WaWLl1KJ510UtkEqtz6oqErpkU+VqByxkrQFdOioqD9ssjHCngWK0H7ZVGR0J5Z5GMFfIuVoD2zqChovyzysQKexUrQfllUFLRfFvlYAc9iJWi/LEqJtErTzRx55JE0cuTIsglUdv2gQYPowgsvzKwtHrpiWuRjBSpnrARdMS0qCtovi3ysgGexErRfFhUJ7ZlFPlbAt1gJ2jOLioL2yyIfK+BZrATtl0VFQftlkY8V8CxWgvbLopRIqzTdzPr1693/lSZYQ4YMobPOOova29vVFsVCV0yLfKxA5YyVoCumRUVB+2WRjxXwLFaC9suiIqE9s8jHCvgWK0F7ZlFR0H5Z5GMFPIuVoP2yqChovyzysQKexUrQflmUEmmVpiBUkmANHTqUTjjhBNq9e7deXTh0xbTIxwpUzlgJumJaVBS0Xxb5WAHPYiVovywqEtozi3ysgG+xErRnFhUF7ZdFPlbAs1gJ2i+LioL2yyIfK+BZrATtl0UpkVZpCkIlCdbZZ59Nffv2pVWrVunVHn7w5Jw5cyAIgiAI6gJxu1st8jMBYKKSBIsZNmwYnXPOOR1XAgAAAKDHk58JABOVJlitra104okn0oIFC9QWAAAAAOjJ5GcCwESlCRYzbtw4OuOMMzJrAQAAANDTyc8EAAAAAABANEiwAAAAAACqDBIsAAAAAIAqgwQLAAAAAKDKIMEChaCSh2SDeEr5ylew/uAHP6BDDz2U+vXrR5MmTdKbgBxK+Tpt2jT667/+a3rLW95Cp512Gi1ZssSv49ef+tSnnOef+cxnaNmyZZk9AVPK17q6OjruuOPokEMOoZNPPplmzZrl1+V5LkycOLHsxUepwPdx4rJmlV33+c9/ng477DD693//d9q5c6dfl+dxXkwh5HFezN5AZ5cA6GIqfUg2iCPP11tuuYUuvfRSampqoocfftglWaAy8nz9q7/6K3riiSdcw8XPGf2bv/kbv+7Tn/40Pfnkk7Rr1y56/PHH6bOf/WxmT5DnK98vcPDgwbRt2za6+OKL6cwzz/Tr8jwX/vZv/zYYN0XYi6985St6seMb3/gGDR8+3NX7Bx54gC677DK/Ls/jvJhCyOO8mL2B3vGNA4Wm0odkgzjyfOXGnu9aDOLJ8zULN2LcWyW8+c1vdsuYxsZG1+MC3iDPV+4F2bRpk3u9YsUK6t+/v9piH9pzhntWOJkNxU2RX/ziF/TLX/5SL3a84x3v8N9BTvQ/9KEP+XV5HufFZEp5nBezN9A7vnGg0FT6kGwQR56vb3/72+naa6+lww8/nD7+8Y/T8uXL9SagBHm+ZuHHYLG3wumnn05jxoxxzx8dO3ZssKelN5PnKw9ptbS0uNfsX58+fdQW+9CeM9yzwkPgobgp8s///M9uqJTrNv+QWrlypV/Hy/bs2eNeb968uUMymudxXkymlMd5MXsDveMbB3oEunKC6hDy9cADD6Sf/exn1NDQ4IYI+QQK4gj5moUTWB4SFGbOnOnmovB+Bx98cK+bj1IpIV8POugg/7q9vd19f0Noz6VnhQnFTZGjjz6ann76aZfYTJ8+nb785S/7dVzPH3roIVq3bh197Wtf6+Bjnsd5MfM8zovZG+gd3zjQI9CVE1SHkK88XMXzIhj+Rcu/TEEcIV+FRYsW0YUXXthh2SmnnOJ6rmQOFnqwwoR8PeKII6i5udm95p4Q7hnRhDyXnhUmFDd12tra6G1ve5t/zz3V/D3k+n7TTTe584BQiceMjpnncaUxU6X3feNAYdGVE1SHkK/HH3+8G05huMH/8z//c7UFKEfIV4Z9/a//+i/XoGTJDpdgDlZpQr6ecMIJtHr1avea/eXvb5ZSnnMsrd4Ef9/e9a536cWOl156id797nf79+U8FnRM7W/W40pjpkrv+raBQtPbTn5dRchXvqKHr7jiBomvJvrmN7+pNwFlCPnKtwz4x3/8Rz+ROMvnPvc5ev75593VbtyTxbdsAJ0J+XrBBRfQwIEDXWI6YMAAOvfcc/26PM+zhOKmCF9VycN4nAhNmTKFvvvd7/p1f/mXf0mjR4+mrVu30nXXXdeh3ud5nBczi/Y4L2ZvoHd840CPQFdOUB1CvnIj/61vfct19fOk1TVr1uhNQBlCvn7gAx8o+Wte7oPFE335/9D9mkDY19mzZ7tbifDcNb4yje+vJOR5nqXU8tTgROgjH/kIvelNb3KJ55YtW/y6CRMm0Ac/+EHXm8q3TOD7jgl5HufFzKI9zovZG+gd3zgAAAAAgC4ECRYAAAAAQJVBggUAAAAAUGWQYAEAAAAAVBkkWAAAAAAAVQYJFgAAAABAlUGCBQAAAABQZZBgAQAAAABUGSRYAAAAAABVBgkWAAAAAECVQYIFAAAAAFBlkGABAAAAA
FQZJFgAAAAAAFUGCRYAAAAAQJVBggUAAAAAUGWQYAEAAAAAVBkkWAAAAAAAVQYJFgAAAABAlUGCBQAAAABQZZBgAQAAAABUGSRYAAAAAABVBgkWAAAAAECVQYIFAABlaG5u1osAACAXJFgAgF7BAQcc4HXIIYfQKaecQnPnztWbBfnQhz6kF3WC41bCsmXL6Itf/CL16dOH3vGOd9B3v/tdamho8OsrjQMAKDaoyQCAXkE2cWlqaqJBgwZVlDgxlSQ9lWzDnHrqqTRkyBCXVG3YsIF++MMf0ve//329GQCgh1PZGQEAAHo4OgH6f+3cPyi9URzH8RIh/5J/YRAGAwY2pewkYrAjMsgggyzKZGFSFspiIotktcggSSlZbJK/sSD01ffUOZ17PPrVrzPc3/29X3XrOed5zr3X4OnT93yf+/j4KMXFxW58dHQkHR0dUlBQINXV1bKxsWHm/cqXury8lObmZqmoqJDd3V23Xs9vbm5KWVmZ5OTkyM7Ojjvny8vLS6lYvb29mUqWZT/H/1x9ZWVlmfnn52cZGBiQkpIS6e/vl5eXF7cWQPogYAH4L/gBS0PN4uKieVlazdra2pLX11dZW1uT8vJyd85fq+FGK1D7+/vS2NiYcs3k5KSpjm1vb5uQlURD0cTEhKyvr8vFxUV4+kcQVAsLCzI/P2+Op6am5OzszHxPDXQzMzPB1QDSwc//ZADIQGFFSF/Ly8vhZYY2tftBxz/W4HV7e+vGll6jVTF/nOT+/l7Gxsakra3NVKU02J2cnLjz4TqtrPX09MjX15cZ19XVycfHhznWuYaGBv9yAGki+Q4AABnGDy7v7+9me6+mpsbNafDRSpFWqOrr638NWNnZ2fL5+enGVhiMwnESDWTaC9ba2urm/HW6/dfV1SV3d3duLjc3NyUkasM+gPTz5zsAAGSAMPDoVl5RUZEbd3Z2mm24vb09swX3W8CqrKw0YSwUvn84tkpLS1N6sLSnSvuyLH/dyMiIHB4eurGqra3lZyOAf0DyHQAAMowfXGwPVnd3t5vTpvHj42N5enqS4eHhlOu1n+rm5sYcDw4OysrKihwcHPzowfKFY0u3B+fm5uTq6spUsGZnZ83PNlh2nfZXTU9Pu3lrfHxcTk9Pzd+wurpqGvMBpJ/kOwAAZBh/W02fHtRQoyHH0gb3pqYmUyFaWlpKCUhDQ0Pmd6uUPkXY0tIiVVVVptplhYEqHFtavRodHTVPDubn50tvb69cX1+783ad9lb539nOPzw8SF9fnxQWFkp7e7ucn5+7tQDSR/IdAAAAAH+NgAUAABAZAQsAACAyAhYAAEBkBCwAAIDICFgAAACREbAAAAAiI2ABAABERsACAACIjIAFAAAQGQELAAAgMgIWAABAZAQsAACAyAhYAAAAkRGwAAAAIiNgAQAAREbAAgAAiIyABQAAEBkBCwAAIDICFgAAQGQELAAAgMgIWAAAAJF9Awo0Jk59m6PWAAAAAElFTkSuQmCC" }, "Throughput%20vs%20p99%20Latency%20With%20Triton.png": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAFzCAYAAADi5Xe0AABF5klEQVR4Xu3dCZwU1b3ocS9ENMjVl4QEFRKNIYEQSASz6M3Vm5d3Jau7gVzfTUBvoi8hyyckCm6IEhFBBQ3iBu6CCIKKIg6IgCwDwzLsOwz7OqwDDDPD/J//Q6qpOd3Vp3umaqam5/flcz5016murvrPmdP/OXWq+jQBAABAqE6zFwAAAKBmSLAAAABCRoIFAAAQMhIsAACAkJFgAQAAhIwECwAAIGQkWAAAACEjwQIAAAgZCRYAAEDISLAAAABCRoIFAAAQMhIsAACAkJFgAQAAhIwECwAAIGQkWAAAACEjwQIAAAgZCRYAAEDISLAAAABCRoIFAAAQMhIsAACAkJFgAQAAhIwECwAAIGShJli33357VsX/mvqsLo+hsrJSFi5caC+uV0pLS+Wdd96Rfv36yd133y1PPPGELFmyxF7NyGbdIFH8vHLh5+D3wAMPmBgdOHDArpJ9+/ZV+T1Otc7+/ftNnf6clB3zoHjZ62XDv0+ZFBd7vaB9BoBUSLBCUJfH8Oijj9bZe4ehvLzcJEl229Dy4YcfVnvddKL4edX3n4NtxIgR5nhWrFhhV0l+fr6pu+OOO8z/BQUF9iqyfPlyU6fbUXbMg+Jlr5cNu024iou9XtA+A0AqoSZYqdidlM1VXx/U5THU5XuH4aOPPjL7P3jwYNmxY4dUVFRIcXGxPP3009KrVy/ZuXNntdZNJ4qYRbHNuvT++++b45k8ebJdJa+88oqpe/31183/r732mr2K5OXlmbpJkybZVUZQvIKWV1eY2wtzWwByHwlWCOryGOryvcOgyZLu/6pVq6osLyoqMsvHjx+fWJbNuulEEbMotlmX9LSrHs9LL71UZbmeJrvvvvukd+/e5tSgJrZ9+/Y1y/1eeOEF8/qlS5dWWe4JilfQ8uoKc3thbgtA7otNgqVza0aNGmXm1ej8j4kTJ8qJEyeS1tO5HUOGDJG77rpLXnzxxUR9WVmZ+au7f//+pvN/6KGHzF/PelrJL2h/Ui3XfXr77bfNPJJ7771XxowZY7Znr5vtMejIy9ixY802dduaGBw7diyxnn9dm3+599hfguzatcvUDxw40K4yBgwYYOp3795tns+ZM0cef/xxueeee8zx/OMf/0h5KiiVbI5Tf4667tGjR6ss11jqcv1Ze7JZNx1XrPw00dARMj0GbVcPPvigaRPHjx9PrGP/DOxtaxx0JEhjrNvQNqptNaht6s9g+PDh5ng1eZkwYYLZhp/GUffj73//u9x5553y5JNPypYtW0ydjuTpdvT97MRH7d2719Tr70gQHRnUdXRf/TZv3myWP/XUU+a5vq8+997bo3HS5bod5Y+LHSt/vLznrt+lTNnbt3n1qfqVTPdZZdv/ZPIzBlC/xSbB0s7G7sD0NIO93rPPPmv+1/kf48aNM3XaMWmHb79eyzPPPJMyybHZy/U13oeHv7z66qtJ63rPMz2GN998M2k9nVuU7X7a20i1vp8mTLrO9u3bqyz3PjQ1iVIzZ85M2q5XZsyYUeW1qXjrZnKc+mGky/UDyk+TKF2uoyWebNZNx9sXF02K7P33itf2lF3n37YmON58JrtoW/YnQN5yTUbtdf3tSD+0H3vssaR1NAk8ePCgWcf7fVi5cmXidZ6PP/7Y1L333nt2VRW6PV3Pn9BOmTLFLNP/lRcj/xy4I0eOmGV9+vRJLPP20f/YX+z1XL9LmbK3b/PqU/Ur/tfa++LfZnX6H9fPGED9F5sEa9CgQeaDX0cGvPkbusxeTyea7tmzx7cFkdmzZ5s6/YDVSbn6171OstXnulxHYzxB+2Mvnz9/fmKbui3d5uLFixMfOv51sz0GHRVasGBB0n7q5GF7XZu93H6ejpc42fNi9K9nXe7FybuCTK+Y0mRGj8eLsT2ikUo2x6kjarrMHgHRU0u6XE9BebJZN51M
Y6YJgq6nx64x0OPQREKX2clc0Da9dqTJrSayGs+tW7eaZFaXz507N7Gutw1t47pOSUmJGaXSZToa5dEkV5fpKInGQhOa559/3izzkoPCwkLzXBMV23PPPZcyjjYvaVi3bl3SMj0WtWnTJvPcG9FSa9asMct05M9jx8d+bi93/S5lKuh9PF59qn7Ffq393FOd/sf1MwZQ/8UmwVq7dm1imfcXsA6f2+uluqpp2LBhps7/wa20Y9PlOhLlCdofe7n3F7S/c1TeX//+dbM9Bp2s7eftp/6167HfI2i5/Twd7cx1FMh/mlBHULzTTN5IhZ6W0W36P/yzkc1xvvvuu2aZjiDo6SQdndEPdN0nXa4jCtVZN51sYmbzThHb7xW0TT1WXe6devV4p+n8iYm3DX9CoyNSukx/bh4vyfHfnkK3r8u80346aqKjJLqf/sRBEzz9WWeSKOvtMHSb06dPN8812dH90MTTG3nT99Hnutw7bepdjKCJu8eOj/3cXu76XcpU0Pt4vPpU/Yr9Wvu5pzr9j+tnDKD+i02CpXMuPNp526/znmuSYPNGGuz78Xj34tFRJ4+93aDl999/f8pt+u8B5Mn2GLx5KR5vP3Uuhsd+bdBy+7mLN/nYO024fv1689x/JZjONfO2q4nmhg0bzGiHPZ8kSDbHqR+e3nwdf/ESXP+HTjbrpuO9LlN67PrhqaNDOoqS6vWplilvFCOo+EfCvGX+dqQJjL1tb5uHDh1KLEvFG/nRERKPJhK6TJNVFx191HV1LpTyXvvyyy9XWU+f63IvSdG2pM/19R77GOzn9nLX71KmXK/z6lP1K/Zr7eee6vQ/rp8xgPovNgmWzV7uPU81aVdPDWmdnQB4ow3+D157uypVB+7aZqp9s9nLvedB26zOftrPXbzTad5pQk0a9Pnq1asT6+goh45e6Ck+b/taNDHSK/ZcsjlOpUmrTizWEQpNbPXD35tk7Z/Hk+26QTKNmY4sDB06tEoM/MUv1TLltaOg4j+tGbQNe7m3TdekaN1/XVc/4L3RJb3QQF/rneJLR2+Foes+8sgj5rl3KsseqfHui+Ulct6pXP9tM+xjsJ9Xd7mL63Vefap+xX6t/dzj6itcv9fplgOov+pdgpWK9xek/sXo54026Qexx9uOf+Kpzpmwtx+0Te+v0kz2zV7uPbf/0vX207vrtcp0P+3nLvqhrCMg+iGo29akSU+xpfqA0Q9lnXem9zvyRvQyuVIvm+MMoqNmuu7DDz9sVyXJZl2Vacy8kRk97aanu3Q+lXc1ppZUE9Rt3pw912iTCtqGvdzb5uHDh31rpeYdg14BqvurpwYzOT2otH1oIqsJhCbd3uid/hz9vJ+r1mub0dOS+rpUk7uDnld3uYvrdenq7Tr7uSeor0jX/9iClgOov3IiwfKu9rP/svYmn/rnuXhXovk/nLxJuf7t6wRdfW7PQ/L+Ws9k3+zl3nP7dgfe5HP/bScy3U/7eSa8UStvjo1eAu+i+6Hr6vwdl2yOUz/sdZk9R8mbx6MTsj3ZrJtOpjHz5qJ5V+Ypb1K3Fv+IRdA2vbY5b948uypJ0Dbs5d7d7JctW5ZY5o3i2bde0Lk+ulx/BxYtWmQe++dGuXjv5U2aD0pidbnWe+/hXZHqsY/Bfl7d5S6u16Wrt+vs555s+p+gbQQtB1B/5USC5X1w68hMqqt4Zs2alVjXu9/TBx98YOZB6FwkPQVib18/EPW5/vWpl7rrurptHfGx17WfBy33nus29WaZup96ys77C9h/U8ZM99M7PaFXJGXKnyRosRMW7330A0LfW0cldKKzLtOrn1yyOU69T5Yu04nrOuKl76Vx9iau+3922aybjh3DIN7PWucSaTKlI2VekqfFn3gF/Rx0n3S5tkU9bt1nneDuTdj3vkpGBe2XvXzq1KnmuY4Y6Wk4nT+kSasuS3WzVf156qiSl7Trzz9T3q02vJu8+m9P4eedegxazz6GoHjZ67mWu7hel67ergva52z6H3ubruUA6q+cSLD0tFfQXBn969F/qsL7ChB/8T6c/NsPug+WV/xXkdmvDVruPU+1XfuO2Znup3/St95eIVPePBl7pEF5k5tTFR3JcPHWzeQ4NUnxEi+7aPLhn2eUzbrp2K+1i8c/2T9V0Q9RT9DPQdtR0D2S9Fh0npPHfv+g5ZqkeT8/f9GEMNVpQy/J06L7mep0cBD/iK0W/6iZn/fdg16xR3PsYwiKl72ea7mL63Xp6u26oH3Opv+xt+laDqD+yokES+mHjt44UTtB/UtdR4D0Kip74ql2hjp6oB9GOolbkxYdAUi1ff1LVP8S15EY727m3iRxfa0n1WtTLfee69Vw3kRt3U8dkfB3wirT/dTRNf2w1WPWY8+Ud8NI+zYUHj3Vo8erSYDupz4O+nC1ZXOcSkfn9PSerqd/9eutDTSRS5UIZLNuEG//gopHf/6jR482PwONg36I6nulOiWZ7ueg85f0ogI9faf1OgdNr7SzvzvRfv90yzWR0rlxGgM9baujev5kzU/nf3nb0NPC2fBuQqtFR3DsO/F7vFs4eOva99iyjyEoXvZ6ruUurtelq7frgvZZZdr/2Nt0LQdQf0WeYOUab+Kqd2VVNhpKJ9pQjrO+8OZF6ajrtm3b7GoAQARIsAJ4E5x1HpLegFNHlHT0xJv87r+3UKYaSuLRUI4z7nQEUSe/e3PqMrmYAQAQDhKsAN6VdqmKnurRD69sNZTEo6EcZ9z526xOPs90jhoAoOZIsALoXCGdn6Rzj7x5LjqvQi9x93/5bTYaSuLRUI4z7rS96lw1vQDAvicZACBaJFgAAAAhI8ECAAAIGQkWAABAyEiwAAAAQkaCBQAAEDISLAAAgJCRYAEAAISMBAsAACBkJFgAAAAhI8ECAAAIGQkWAABAyEJPsJ555hl7UeRq+z2PHz8uM2fOlFdeeUWee+45GTVqlCxYsEAqKyvtVaulto8HAACEiwSrGiZPnixLly6V0tJSk1QdO3bMfDH03Llz7VWrpbaPBwAAhKvWEyxNSAoKCszoz4gRI2TChAly+PDhKuvMnz9fXnrpJXnttddk+fLlzm266sN+T91GeXl5lWVHjhwxr/dUVFTI1KlTzbpjxoyR3bt3J+oOHDggkyZNkldffdWMgI0ePVqKiooS9fZ7e/um+69J3IkTJ6rUAwCAeKn1BGvx4sWyZMkSk6Bo0efTpk1L1K9atUqmT59uRoU0ERk7dqxzm676sN/zjTfekMLCQnOqMIgmdKtXrzaJ1sqVK02S5XnzzTdNEqf7ottYuHChSZ48/vdes2ZNYl3dP91XfQ4AAOKr1hMsTU40UfAcPXrUjBp5xo8fLwcPHkw837Ztm3Obrvqw33PPnj1m3pWOPr3//vuybNky2blzZ5V19D3TJWB+moT538//WPfN3o4/WQMAAPFT6wnW8OHDq5zi0se6zKOn1DTh8Ghy4dqmqz6K99TTjrt27TKjT++9955JtnSUzKPbTzfpXRMyHeXS+VyarAUlWC+88IJ57i/+fQcAAPFT6wmWJgx2svPss89Wqfc
nO2VlZYlt2omGJ8r3zNSOHTvMqJXn+eefD0yw9JShJlX6vyZpOh8s6Hh0O/59AwAA8VfrCZbOP7JP1/lHZN566y05dOhQ4rkmLq5tuurDfk+dcG4nPTpHShM1jyZbmqil8vLLL5tJ8Z59+/YFJlg6H8yekA8AAOKt1hOsFStWJE04988pWrt2rcyePdskJ5r06Bwk1zZd9WG/p17Jp6f3iouLzfZKSkrMpPePP/44sY7eF2v9+vUmEVu3bl2V99PRK528rnU6n8ueVO9/rBPaFy1aZPZNk0Q9Jfnuu+8m6gEAQPxEkmAFFaWnzfLz881oz4svvihTpkypMnqkNHnRupEjR5rTaDq/KR37faJ+T92eJlDeRHcdkdL7YPlHtTTxysvLMyNlOoK2d+/eRJ13OlHnfmlypQleUIKl76UJnY6a6elCvb2Df/QLAADET+gJVtjsK/5qQ128JwAAyB2xS7D0flA6oqOjQXo1n56609GhKNXFewIAgNwVuwRL70E1btw4c/pMT4vNmjUraUJ52OriPQEAQO6KXYIFAABQ35FgAQAAhIwECwAAIGQkWAAAACEjwQIAAAgZCRYAAEDISLAAAABCRoIFAAAQMhIsAACAkJFgAQAAhIwECwAAIGQkWAAAACEjwQIAAAgZCRYAAEDISLAAZKywsFC6dOkiX/va1+T222+XLVu2VKk///zzE+VLX/qSXH311bJq1aoq9akELY+LiRMnmuPR/bzwwgtl2rRpVepz9bgBVB8JFoCMaMJw8cUXy3vvvSdHjhyRDRs2yA033CAjR45MrONPGMrLy+Xll1+Wa665JmW9X9DyONDk6itf+Yr5X7311lsmwZw9e3ZinVw8bgA1Q4IFICO33XabvPbaa1WWaZL117/+NfHcThiOHj0qX/3qVxPP7XpP0PI40BGp999/v8qyt99+W5588snEc3v/c+G4AdQMCRaQg7qPXJtVycS3vvUtKSoqshdX4U8YdJTrqaeekuuuuy5lvV/Q8my9sbR7ViUTmexbXR83gPghwQJy0Gl/mZVVyYSO5JSVldmLq9CEwSs6V+mqq66q1blId3xwWlYlE/a++Y8x1bK6OG4A8ZNZDwOgXrETKFfJRMeOHWXnzp1VllVWVsr+/fsTz10JQ+vWrc0cJT99rsvDYCdQrpKJoGOyE6x0oj5uAPGTWQ8DoF6xEyhXycTvf//7pDlYOtH7Bz/4QeK5K9HQid/r16+vsmzt2rVy/fXXV1lWXXYC5SqZSDUHy7uq0FPXxw0gfjLrYQDUKy/O25VVycTSpUvNPKzJkyfLsWPHzC0bLr/8cnn99dcT67gSjZkzZ8pNN91kJscfP35c1qxZI7/61a8kPz/fXrVa5m97MauSCb1qUq8i9JIs/d9/VaGq6+MGED8kWAAyNnXqVPnxj38sF110kVx66aXy/PPPV6l3JRpq3Lhx5rUXXHCBXHbZZTJhwgR7ldjRJMu7D5b+70+uVK4eN4DqI8ECAAAIGQkWAABAyEiwAAAAQkaCBQAAEDISLAAAgJCRYAEAAISMBAsAACBkJFgAAAAhI8ECAAAIGQkWAABAyEiwAAAAQkaCBQAAEDISLAAAgJCRYAEAAISMBAsAACBkJFgAAAAhI8ECAAAIGQkWAABAyEiwAAAAQkaCBQAAEDISLAAAgJCRYAEAAISMBAsIsH37dnsRLMQIcUcbdSNG0SDBAgLMnz/fXgQLMULc0UbdiFE0SLCAAHQ6bsQIcUcbdSNG0SDBAgLQ6bgRI8QdbdSNGEWDBAsIQKfjRowQd7RRN2IUDRIsIACdjhsxQtzRRt2IUTRIsIAAdDpuxAhxRxt1I0bRIMECAtDpuBEjxB1t1I0YRYMEK8YOHz5s7k9CoVAoFAol+qKfu2EhwQIC8FedGzFC3NFG3YhRNEiwgAB0Om7ECHFHG3UjRtEgwQIC0Om4ESPEHW3UjRhFgwQLCECn40aMEHe0UTdiFA0SLCAAnY4bMULc0UbdiFE0SLCAAHQ6bsQIcUcbdSNG0SDBAgLQ6bgRI8QdbdSNGEUjJxOs4uJiadmypb1Yli1bJm3atJEmTZpIp06dZN68efYqRrr1qlvnd9ppORn2nEOn40aMEHe0UTdiFI2c+6RfvXq1dOzYMWUS061bNxkwYIAcOHBAevXqJV26dLFXMdKtV906v1T7hvih03EjRog72qgbMYpGzn3St2jRQkaMGJEyidHRpT179pjHRUVF0q5dO2uNk9KtV906P/++9e/fX3r27OmrRVzQ6bgRI8QdbdSNGEUjOQup53bu3Gn+T5VgNWvWTMrLy83j0tJSadq0qbXGSenWq26dn7dvgwYNkq5du0plZaW1BuKATseNGCHuaKNuxCgayVlIjkiVYDVu3DjxWJOaRo0a+WpPSbdedev8dN+GDBki7du3N4lYEP1eJG34FAqFQqFQoi/6uRuW5CwkR6RKsJo3by5lZWXmsSY2OuKUSrr1qlvnp/t20003SatWrWTz5s12NWJCf9mQHjFC3NFG3YhRNJKzkByRKsHSEaOtW7eax5rYtG3b1lrjpHTrVbfOz9u3oUOHmonxiCc6HTdihLijjboRo2gkZyE5IlWC1aNHD+nXr5+UlJRInz59pHv37vYqRrr1qlvn5+1bRUWFdOjQQQoLC601EAd0Om7ECHFHG3UjRtFIzkJyRKoEq6CgQFq3bi2nn366ueJP71vl8a+fbr3q1vn53ysvL086d+7sqz3l8OHD5nwwhUKhUCiU6It+7oYlOQtpoG699VZ7ERo4/qpzI0aIO9qoGzGKBgnWP/Xu3dtehAaOTseNGCHuaKNuxCgaJFhAADodN2KEuKONuhGjaJBgAQHodNyIEeKONupGjKJBggUEoNNxI0aIO9qoGzGKBglWjHEVIYVCoVAotVe4ihCoBfxV50aMEHe0UTdiFA0SLCAAnY4bMULc0UbdiFE0SLCAAHQ6bsQIcUcbdSNG0SDBAgLQ6bgRI8QdbdSNGEWDBAsIQKfjRowQd7RRN2IUDRIsIACdjhsxQtzRRt2IUTRIsGKM2zRQKBQKhVJ7hds0ALWAv+rciBHijjbqRoyiQYIFBKDTcSNGiDvaqBsxigYJFhCATseNGCHuaKNuxCgaJFhAADodN2KEuKONuhGjaJBgAQHodNyIEeKONupGjKJBggUEoNNxI0aIO9qoGzGKBgkWEIBOx40YIe5oo27EKBokWEAAOh03YoS4o426EaNokGABAeh03IgR4o426kaMokGCBQSg03EjRog72qgbMYoGCRYQgE7HjRgh7mijbsQoGiRYQAA6HTdihLijjboRo2iQYMUYX/ZMoVAoFErtFb7sGagF/FXnRowQd7RRN2IUDRIsIACdjhsxQtzRRt2IUTRIsIAAdDpuxAhxRxt1q60YlRzfLfuObkxZNh3Il/XFH0VSVux+R+ZvezGjEiYSLCBAbXU69RkxQtzRRt0yjZGXIG05MM8kLot3vmGSkmkbB8rkdX1l/IrfyxtLu8vzC34qTxf8QB6d1U4emnGh3DPlLLnjg9PqRQlTuF
sDckimnU5DRowQd7TRU4ISpFc//lNAgvSNUBOk+6c2N9tLVYbmX2reM4ry4sKrzXFlUsJEggUEoGN2I0aIu1xqo0EJUtII0sKfmcQiqgTpiTnfNtt/pfBG837vrv6ref9Zm/5h9mfF7glm/3YcXmr293hFeFfm1SckWECAXOqYo0KMEHdxaqPxS5BuMO83YsZ/kyBFoMElWMXFxdKyZUt7cRXLli2TNm3aSJMmTaRTp04yb968Gtf5nXZagwt7vRSnjjmuiBHiLsw26k6QetRKgvT4nEuqJEg1HUEKM0Y4pUF90q9evVo6duzoTHC6desmAwYMkAMHDkivXr2kS5cuNa7zc70/4oFOx40YIe78bTQoQZq+cVCtJEh9p34ukgSppvg9jkaD+qRv0aKFjBgxwpng6CjUnj17zOOioiJp165djev8/O/fv39/6dmzp68WcUGn40aMUJvsBGnJzjHOBOn+KefnfIJUU/weRyN9ppFjdu7caf53JVjNmjWT8vJy87i0tFSaNm1a4zo/7/0HDRokXbt2lcrKSmsNxAGdjhsxQjaqkyA1hBGkusbvcTTSZxo5ypVgNW7cOPFYk59GjRrVuM5P33/IkCHSvn17k4ghnuh03IhRw5JIkA4W1JsEafrcd6S0/JB9KPDh9zga6TONHOVKsJo3by5lZWXmsSZAOjJV0zo/ff+bbrpJWrVqJZs3b7arE/SLJ7XhUygUShhlVsEUk3Dk5b8s7895Rt6a/bCMndVXXvv4z/LSjFvl2eld5KlpV8ngj/5dBk29RP4+5SJziu2uvE8nJTvVKffknWO299CHXzfbf+KjH5r306vY9P1fn3mH2Z93Zg82+/fh3NfN/uYXzEg6FgoliqKfu2FJn2nkKFeCpSNLW7duNY81AWrbtm2N6/y89x86dKiZGI940l82pEeMaldJ2Z56MoL0zj9HkJaY/a3LESTaqBsxikb6TCNHuRKsHj16SL9+/aSkpET69Okj3bt3r3Gdn/f+FRUV0qFDByksLLTWQBzQ6bgRo+xUP0FqlpTsVKe4EqSZm56IXYJUU7RRN2IUjfSZRo5KlWD5lxUUFEjr1q3l9NNPN1cG6v2talrn53+vvLw86dy5s68WcUGn49bQYpRtgvTYrPYkSHWsobXR6iBG0UjONBqoW2+91V5U5w4fPmzOB1MolHDLtm1bZdnG6fLxqudlyvIhMmHxvTJu0V/ltfk3y0vzuspTc/5Tnpj1bzJwelv5+0dflLsnh3OKrc+Uz5jtPTLjm2b7z+X/zLzf6wtuM+8/cenfzf7MWv2SzFszVpZs+FBWFc2Voi1rko6BQqGEX/RzNywkWP/Uu3dvexEaOP6qc4t7jHTkRkebFm5/RSau6S0vL7rOnHbrnfeppOQnk2KPIL1ceL0ZQZqwqicjSDEV9zYaB8QoGiRYQAA6Hbe4xOhQ6Q5Zs3eyzN48VN5e+UcZPr+z9Jt2XlKC5C8DZlxkTuWRIOW2uLTROCNG0Yg0wZo2bZqZk6R0ftK5555r5ieNHz/eWhOIHzodt9qM0YnKCtlVskKW7Rov0zY+LKOX/lqeyP+O9PnwnKTkySt3T/60DJ79LXl1cRfJW3efLN4xWrYfKpSyE0ftzSNH1WYbra+IUTQiTbAuueQSGTNmjBw5ckTOOussGT58uLz55psmyQLijk7HLYoYVee0np66e3Luv8mby2+TGUWPysrd78reI+vsTaMBiqKN5hpiFI1IE6wzzjjD3Ipg8uTJ0qRJEzl27Jj5KhkSrMwwyZ2Sy2XtpkLJXzNa3l/6oIycf4s8Oes/5P6pLZKSJ3958KMLzAR0nRT+wbKHZf66t2XDluVJ26ZQKJTqlHozyV0TqWHDhsn1118vV1xxhVk2atQo6dixo7UmED/8VefmilFNT+vpvChO66EmXG0UxCgqkSZY3unAK6+8Unbv3m2WtWvXTubOnWutCcQPnY6bF6PUp/XaZXFa7zFzWq/46HrrHYCa4ffYjRhFI9IEC6jP6HSSeVfr6Vei6NV6j0291Hm13sMff8VcradX6uVveVqK9s+SI2V77U0DkeD32I0YRYMECwjQUDud4NN6ZyclT15JOq238w1zWq/8xDF780Ctaqi/x9kgRtEgwQIC5HqnU/W0Xq+MTuvdP7W5DJv7/cRpvQmzh3BaD7GW67/HYSBG0SDBAgLkSqdjn9bL5CacmZ7Wy5UYIXfRRt2IUTQiSbBatWolv/jFL+Sxxx6TOXPmyPHjx+1VkAFu00DJtGzdtiXx3XoTltwjL877hTz68cVy75R/TUqevHJX3pkycHo7GT73avM9eNNWPiNLNkyRLds2Jm2fQqFQGkKJ/W0atmzZIm+88Yb85S9/kUsvvVTOPvtsueyyy6Rnz57mxqNaD8RdHP+qO1Z+sMan9cK8Wi+OMQL8aKNuxCgakSRYttLSUpk9e7Y8+uijcuONN0rLli3tVYDYqctO5+RpvbxITuuFqS5jBGSCNupGjKJRKwkWUB9F3en4r9b7aMOAjK/WGzL74thcrRd1jICaoo26EaNokGABAcLqdLzTegu2vRyL03phCitGQFRoo27EKBokWECAbDsd+7Tec/OvzPK03jO1clovTNnGCKhttFE3YhQNEiwggKvT0dNzry/9Vb07rRcmV4yAukYbdSNG0Yg0wZo2bZq0bt3aPC4oKJBzzz3XfDfh+PHjrTWB+EnX6Rwo3Sp3T2laL0/rhSldjIA4oI26EaNoRJpgXXLJJea2DEeOHJGzzjpLhg8fnvgCaCDu0nU6Y5f9xiRVenVffTutF6Z0MQLigDbqRoyiEWmCdcYZZ0hFRYVMnjxZmjRpIseOHZPy8nISLNQLQZ3O7pKVZoL6nXmny54ja+zqBiUoRkBc0EbdiFE0Ik2wNJEaNmyYXH/99XLFFVeYZaNGjZKOHTtaawLxE9Tp6FWAOno1bsXv7KoGJyhGQFzQRt2IUTQiTbC804FXXnml7N692yxr166dzJ0711oTiJ9Unc6mA3Ok1wf/IvdMOctcNdjQpYoRECe0UTdiFI1IEyygPkvV6Tw17wozevXB2nvsqgYpVYyAOKGNuhGjaJBgxRhf9hyvMmvVyya5uu/Dz8nGLauT6ikUCoVSv0vsv+zZc/DgQbn99tvloosuMpPclX7p84oVK6w1gfjx/1VXWXlCHpvVwSRYH28a4lurYeMvX8QdbdQtihht2X9cNhaXSsnxCruqwYg0wbr22mtl5MiRcuLECTnttJNvpZPev/3tb1trAvHj73Tmb3vJJFcPzbhQKk4c963VsEXRMQNhoo1Wdaz8hEl8tKzbe0w+WntAnnlvnvn/5YLd8uK8XaYM+HCr9J202ZQ/jtsg3UeuNeWGF1bJD55cZsqlQ5bIhf3mm/KlB+bLaX+ZlbaccfucxPpa/u3xJYlt/fy5FYn30NLn/ZPvrWXI9O2J/Xpl/m6zr15ZvvNI4nh2HCqzD7dORZpgnX322ebWDMpLsPR5s2bN/KsBseR1zOUnSj9JrC4wCdbC7a9YazVsfHgh7upTG9128OSoj5bF2
44kkoiJK/YlEozh+TsTiYeWm0edSkp+9MzyRMLSYeCiRCLz2bvnJiU7UZdW9xeY927aa05SXW2Vln1P7oMWjYcXGy3dfMncX97amIhnmCJNsHQEa8mSJebeV5pglZWVyaJFi+Saa66xVwVix+uY9a7smlwNnv0tc6oQp9SnDy/kvn1HyxMJilfemVYgC7aUVBn18MrYxXsTiYu/DJq6rUoS45Xfj11fZZTFKz95dkWVD2+vtH1oYZURG6/oSI6dDNRW8Y8iXfT3BWY/Lxkwx/z/36+uSRxTrwlFieN+fMapEaQxhXsT8Zu98VAizkX7Su0fRxL/6JmWWZ+83tvWO8uKq/wM/HH/k28ETffRH+OvDzgV4xZ95iUdb7YlTJEmWAcOHJA//elPcsEFF5ibjrZs2VJuu+02KS4utlcFYkeTh2PlB6Tv1M+ZBGvVnon2Kg0eCVZusT8AvbJi59Gk5ETLe76RFX8Z+vGOpOREy1/f3piUnGi53nfayV86PlKYlJxo+V931f6ITG2U8+87NeLyzUGnRlx+/MzyRKxueX1tlZiOyD8V9/dX7k/8bJZsP3XqbG9Juf2jriJXf4+3Hjg1Iqjx8Lddf3t9dNqphDpMkSZYQH2mnc6kNXeZ5Orpgh/Y1ZDc7ZhrQv+StxMULXZyomWq1dF7RT807eREyz0TNyUlJ1r+r/VXvVd0joudnGg575MPcvvDPRfKZz5JvOxjPf/e2dLp0cKk2GjR+UR2LLX87Z1Tp4z85cmZO5J+Vlr0FJ79s9WyctfRpHagpbQ8XiPh/B5HgwQLCDBz3iRzQ1G9sejmA9wcN5UoOuZdh8uSPpC05BedOp3gL6MW7kn6wNPSf8qWpA9ILb8dvS7pA1XLfz6V/AGs5St/X5D0oa2lUc/kD/j6Xs60JiF7RU/D2HHR8tNnq05M9kqPN9cnxV3LIx9tS/o5aXlz8anTTv6ycGtJUjvQsv9o+hEZvyjaaK4hRtGIJMHS+VauEnd6GlNPadqWLVsmbdq0Mbed6NSpk8ybNy+jOr/6cPwQeXra9Wb06pXCG+wq/FO2HfP6vcfkg1X7ZdisHeZ00bXPrzSTT+tyImzY5YIHkhMULXZyouV/f1Ls5ESLfRrIK/3ytiQlJ1peta6s8orOcbGTEy3bDzacK2GzbaMNETGKBp/0Kaxevdp8X2KqRKhbt24yYMAAM7+sV69e0qVLl4zq/FJtF/Gy58hq6f3Bp8yXOu8uWWVX45/sjvl4RaWZrzNh+T5zabVe3q2jHG0eWiif+uvspGQkVfn8vfOSkhMt3x28OClB0fLLl1cnJSha7nx3U1KCouXZOTuTEhQtk1cnJyha9FJ2O0HRcqKyyqEjpuw2imTEKBp80qfQokULGTFiRMpESEeo9uzZYx4XFRWZ71bMpM7Pv93+/ftLz549fbWIg1cKbzSjV28uv82uavAOl1ZI4bYScwXWn16eK7e+sU7+z7Bl8sX73ffB0Uu3NSnS03R6nx29IklPAx081nBvRohokTy4EaNoJGcQIdq1a5fcfPPNcu6555ovff7iF79orio8dOiQvWqs7Ny50/yfKsHSe3jpbSdUaWmpNG3aNKM6P2+7gwYNkq5du0plJX8Kx4nOt9J5V3fmnSkHS7fZ1Q2CXnU0d9NhGblgjzkt9avX1pgJ067LoHVe0lf7LzBXPek8nMembZO3lxXLsh1HzBVqQG0jeXAjRtFIziBCdPnll8uQIUNk3759JvHQ0Z2BAwea+2PVB6kSrMaNGycea2LUqFGjjOr8dLsal/bt25tEDPHyTMH/NqNXL864xa7KKToP5+MNB82NC/V0WpeXVssljy52XgKvE6HbD1wk14xYKf/3mXxzZdWklftl7Z5jwt8KiBuSBzdiFI3kDCJE55xzjpmP5KffT/iv//qvVZbFVaoEq3nz5uaGqUqTI/9d6dPV+el2b7rpJmnVqpVs3hx83w394klt+JTaKxNmP2GSq3vyzpE5BdOS6utbeWtagQydME96j5or//1svvzgsTnS+v7ZnyRJyYmTvzS7Y5a07Tdb/nPwHOk+PF/ueX2uPDuxQCbOKEh6DwqFQsmVop+7YUnOIEI0evRo+fWvfy1r166V48ePy9atW6VHjx5m9KY+SJVg6aiTHofS5Kht27YZ1fl52x06dKiZGI940Lu0693aNcGaUfSo+WWLO3tS+R/e3GDuKv21/gulsWNSuU4m1+8S0zsj3//BZnltwR5zK4Q9JZl/n1d9iBEaNtqoGzGKRnIGESL/bRmCSpyl2j9NEPv16yclJSXSp08f6d69e0Z1ft52KyoqpEOHDlJYWGitgbqg3zN48gudLzDfPxiXTudQaYUs2npyUvnDH27NelK53gogqknlcYkREIQ26kaMopGcQSAhVYJVUFAgrVu3NpP29apBvfdVJnV+/u3m5eVJ586dfbWoCxUnjn+SWF1oEqz52140y2qz0/FPKn/ggy1mVOmyx5fIF+7NfFK5jl4Nnr7dfKeXfsN8bdwtujZjBFQHbdSNGEUjOYNAbBw+fNicD6ZEX95b+oBJrgZO/7ps27Y1qT6MsmjtFhk/b4M8+sFq+ePry+XqZxbLNwfMl7N7p7/J5pm3z5a2/QvkR08uktteWyb931slI2etl9krNn+yr8nvQ6FQKJTqFf3cDUukCda4cePMaI59WjDVyBBQV0rLD8n9H33eJFgrd7+bWF6dv+r0ppT2ncr1irtP35E+iTrnznzzfWm/eHGV9H63yFzZN33dQdkW8ztuVydGQG2ijboRo2hEmulocjV+/Hh7MRArH6y91yRXT827osryVJ2OnnbT0296Gk5Px2UzqVxP9+lpv5pMKo+bVDEC4oQ26kaMohFpgnXeeeeFOtwGhO1Q6Q65Z0ozc2PRov2z/7ns5KTygePnJSaV/3DYMjNh3E6a7KLr6LrepHKdmK7b0m3mIjpmxB1t1I0YRSPSBGvSpEnSt29f88XJQJzoqJGOHg2cfosZvfrTO1dmNKlcR6l0tEpHrepiUnnc0DEj7mijbsQoGpEmWG+//XbiS5PtAjcmudes6KTycXODJ5V/5u7R8rf3PyV/m9RYPnfva4kkKnhS+aak96BQKBRK7pQwz7pFmunonc3HjBmT+H4+IGz69Sz6NS36dS0939povr4lk0nl+nUw3d74kRm96vX+L1NOKuevOjdihLijjboRo2hEmmB97nOfM98/CNTE7sNlNZ5UrveW8iaV6z2nthwsMPOu7p7SVA6Unrz7vo1Ox40YIe5oo27EKBqRJljDhw+Xm2++WXbt2mW+/BjI1sQV+9LOi9K7meukcp2IrhPSM51U/mzB/zGjVxPX9LKrEuh03IgR4o426kaMohFpgmXPu2IOFjKl37H3l7c2yr/0PJlIfWfw4tAmla/e+4FJrvpO/awcLd9vVyfQ6bgRI8QdbdSNGEWDTAexs2rXUen4SKFJrE7/22wZOHWrhDUAWvnJv8fndDIJ1rSNA+3qKuh03IgR4o426kaMokGChVgZkb9LzuqVb5Kr1g8ukILN4V3RoRbt
GGmSq/7TvyjlJ47Z1VXQ6bgRI8QdbdSNGEUj0gTr+eefN1cS2qcHOUWYmYZ0m4ZVG7ea2yh4c6tuHL5Y1hSF+52AW7Ztkgc/usAkWFOWD06qp1AoFErDLvXmNg0tWrQwSVZZWZl8+ctfliVLlsgtt9wigwcPtldFAzZr4yG54IH5JrE6+858eXX+bnuVUMza9A+TXD066xtyojL9JHjFX3VuxAhxRxt1I0bRiDTBatq0qZSWlprH1113nUycOFH2798vX/rSl6w10RBVnKg0t0/41D9vt/DdwYtl/d70p+2q63jFYXngoy+YBGv57rft6pTodNyIEeKONupGjKIRaYL13e9+V8aOHWse9+vXz3xtTkFBgbRp08ZaEw3Nlv3H5T+GLjWJVaOes6T3u0VSfiKkmewpTF7X1yRXw+Z+364KRKfjRowQd7RRN2IUjUgTrNmzZ5uvylG7d++W73//+3LOOefI+PHjrTXRkIxbUiyfvXuuSa7Ov69Apqw5YK8SqsPHd8m9H/6rSbA27p9pVwei03EjRog72qgbMYpGpAkW4He07IT8vzHrExPZf/7cCvOly1F7a+UfTHL14sKr7aq06HTciBHijjbqRoyiQYKFWrFk+xH5xsOLTGJ15u1z5IkZ2+1VIlF8dL3cObmJ9M5rLDsPL7Or06LTcSNGiDvaqBsxikYkCZZe5njrrbea7yJUCxculA4dOsiZZ54pP/rRj8ylkGg4/vHxdpNUaXLVbsAiWbztiL1KZEYu/i8zevXGspvtKic6HTdihLijjboRo2hEkmD95je/kT/84Q+JKwjbtm0rDz30kBw8eFB69+4tV1+d3aka1E/6pcpXD1+ZOCWo3xd45Hj1vt6mOrYdWii98hrJXZPPlP3HNtvVTnQ6bsQIcUcbdSNG0YgkwfrMZz5jJrWr5cuXS+PGjeXAgZMTmYuLi+XTn/60f3XkoKlrD0jLvgUmsfrMXXPlzcV77VUiN3x+ZzN69d7q2+2qjNDpuBEjxB1t1I0YRSOSBEvvf+XdDfXZZ5+V73znO4k6PT2od3dHbtJbLdz57iZz6wVNri7/x1LZvP/kSGZtWls8xSRX9039jBwpK7arM0Kn40aMEHe0UTdiFI1IEiydZ/XYY4+ZJOt3v/ud3HbbbWb5oUOH5M4770w8R27ZUHxMvjdkiUms9Oah93+w2dxMtLbpFzo/MefbJsH6aMMAuzpjdDpuxAhxRxt1I0bRiCTBWrdunXzve98zk9qvvfZa2bv35OkhPXWok99LSkqsV6C+G7lgj/maG02u9GtvZm44ZK9SaxbvGG2SqwentZSyE0ft6ozR6bgRI8QdbdSNGEUjkgQL4agPX/a8tmirdBlxctTK3Nvq6ULzxc32erVVTn6h84Umwcpb/khSPYVCoVAoQaXefNkzctv8LYflq/0XmMTqrF758tycnfYqtW725idNcvXIzK9n9IXO6fBXnRsxQtzRRt2IUTRIsJC1ykqRQVO3yel/O/klzRc/Uigrd1X/VFxYjleUSL9p55oEa+mucXZ11uh03IgR4o426kaMokGChazsOFQmnZ9ebhKrf+k5S/48foOUltfeva3SmbL+AZNcPTn3MruqWuh03IgR4o426kaMokGChYxNXLFPvnDvPJNc6f/vffI8LkrK9kifD882CdaGfdPt6mqh03EjRog72qgbMYoGCRacjldUyl/e2mhGrDS5uvKp5bL94HF7tTr1zso/m+TqhYU/t6uqjU7HjRgh7mijbsQoGiRYSGvVrqPS8ZFCk1jpnKuBU7eaOVhxsu/oRvOFzvq1ODsOL7Grq41Ox40YIe5oo27EKBokWAg0In+XuTpQk6vWDy6Qgs3hXb4aplFL/tuMXo1e2s2uqhE6HTdihLijjboRo2iQYFWDfp9iy5Yt7cWybNkyadOmjTRp0kQ6deok8+bNs1cxTjst3mE/cKxCury0OnFvq1+9tkYOldbslgdR2X6o0PeFzpvs6hqh03EjRog72qgbMYpGvD/pY2j16tXSsWPHlElSt27dZMCAAeaLrXv16iVdunSxVzFSvTYuZm08JBf2m28SK70z+6vzT35pd1yNWPBjM3o1YVVPu6rG6HTciBHijjbqRoyiEd9P+phq0aKFjBgxImWSpKNXe/bsMY+LioqkXbt21hon+V/bv39/6dkz/OQgW/qdgQ98sMV8h6AmV98dvFjW7z1mrxYr64qnmuSqz4fnyJGyvXZ1jdHpuBEjxB1t1I0YRSM5S0BaO3eevFt5qgSrWbNmUl5ebh6XlpZK06ZNrTVO8l47aNAg6dq1q1TW8azxLfuPy38MXWoSq0Y9Z0mvCUVSXgdf0pwN/ULnf+R/1yRYH2540K4OBZ2OGzFC3NFG3YhRNJKzBGQkVYLVuHHjxGNNmho1auSrPUVfO2TIEGnfvr1JxOrS+CXF8tm755rk6vz7CmTKmgP2KrG0ZOcYk1z9fdr5UlZxxK4OBZ2OGzFC3NFG3YhRNJKzBGQkVYLVvHlzKSsrM481cdIRrVT0tTfddJO0atVKNm/ebFcn6BdPasOPosyaO19uGHbyCkEtlz8yR6bMKkhaL46lYP5ceWDyl05eOTjzrqR6CoVCoVCqU/RzNyzJWQIykirB0hGprVu3mseaOLVt29Za4yTvtUOHDjUT42vbku1H5BsPLzKJ1Zm3z5EnZoTXoGpD/panTXI1aGYbOVF58pRsFPSXDekRI8QdbdSNGEUjOUtARlIlWD169JB+/fpJSUmJ9OnTR7p3726vYnivraiokA4dOkhhYaG1RnSGfrzDJFWaXLUbsEgWb4vm9FpU9HRgv2nnmQRryc6xdnWo6HTciBHijjbqRoyikZwlICOpEqyCggJp3bq1nH766eaKQr0vVir+1+bl5Unnzp19tdHYW1IuVw9fmTgleOsb6+TI8Xh8SXM2dEK7JldD879nJrpHiU7HjRgh7mijbsQoGslZAmLj8OHD5nxwTcuY/PVy7r0n51ud03uODJ+2Nmmd+lA2bFku9045+YXO89aMTaqnUCgUCqUmRT93w0KClcP0Vgt3vrvJ3HrBTGT/x1LZvL9ur1qsCb2ZqCZXIxb8xK6KBH/VuREjxB1t1I0YRYMEK0dtKD4m3xuyxCRWevPQvpM2m5uJ1lf7jhbJXZPPMF+Lo1+PUxvodNyIEeKONupGjKJBgpWDRi7YY77mRpOrCx6YLzM3HLJXqXdGL/21Gb3SL3auLXQ6bsQIcUcbdSNG0SDByiGHSyuk28i1iYnsN764SvYfje42BrVlx+El//xC5zNk39GNdnVk6HTciBHijjbqRoyiQYIVY9lMct+0ZZu07V9gEqumd8yRRyatTlqnvpan5vznydGr+b9JqqNQKBQKJazCJHek9P/GrJeLHymUlbuO2lX11oZ90//5hc5nS0nZyS/Sri38VedGjBB3tFE3YhQNEqwcUritRErL69+9rdIZmn+pSbCmrO9nV0WOTseNGCHuaKNuxCgaJFiIraW7xpnkqt+0c+V4RYldHTk6HTdihLijjboRo2iQYCGWTlR
WyKCZbU2CNWfzMLu6VtDpuBEjxB1t1I0YRYMEC7E0d+tzJrka+PFXpaKyzK6uFXQ6bsQIcUcbdSNG0SDBQuyUnTgqD05raRKsxTtG29W1hk7HjRgh7mijbsQoGiRYiJ2PNjxkkqsn5nw78i90TodOx40YIe5oo27EKBokWDGWzX2wcqVs2LJC7p188gud5659I6meQqFQKJSoCvfBQs6qOHFcxi7/rQyf39muqnX8VedGjBB3tFE3YhQNEizEzpaDBbKrZIW9uNbR6bgRI8QdbdSNGEWDBAsIQKfjRowQd7RRN2IUDRIsIACdjhsxQtzRRt2IUTRIsIAAdDpuxAhxRxt1I0bRIMECAtDpuBEjxB1t1I0YRYMECwhAp+NGjBB3tFE3YhQNEiwgAJ2OGzFC3NFG3YhRNEiwgAB0Om7ECHFHG3UjRtEgwQIC0Om4ESPEHW3UjRhFgwQLCECn40aMEHe0UTdiFA0SrBhriN9FSKFQKBRKXRW+ixCoBfxV50aMEHe0UTdiFA0SLCAAnY4bMULc0UbdiFE0SLCAAHQ6bsQIcUcbdSNG0SDBAgLQ6bgRI8QdbdSNGEWDBAsIQKfjRowQd7RRN2IUDRIsIACdjhsxQtzRRt2IUTRIsIAAdDpuxAhxRxt1I0bRIMEKWXFxsbRs2dJeXMVppxH2+oBOx40YIe5oo27EKBp80odo9erV0rFjR2cC5apHPNDpuBEjxB1t1I0YRYNP+hC1aNFCRowY4Uyg/PX9+/eXnj17+moRF3Q6bsQIcUcbdSNG0UifCSArO3fuNP9nmmANGjRIunbtKpWVldYaiAM6HTdihLijjboRo2ikzwRQLZkkWEOGDJH27dtLaWmpXY2YoNNxI0aIO9qoGzGKRvpMANWSSYJ10003SatWrWTz5s12dYJ+8aQ2fAqFQqFQKNEX/dwNS/pMANWSSYKlhg4dKt26dataCQAA6r30mQCqJdMEq6KiQjp06CCFhYXWGgAAoD5LnwmgWjJNsFReXp507tzZVwsAAOq79JkAAAAAskaCBQAAEDISLAAAgJCRYAEAAISMBAsN2rJly6RNmzbSpEkT6dSpk8ybN89exdB7o+jFCf7S0GXyxeZAXeF3NljQ726m/SEyQ4tDg6b3IRswYIAcOHBAevXqJV26dLFXMd555x25+uqr7cUNVqZfbA7UFX5nU0v3u5tpf4jMJEcYaED0r7U9e/aYx0VFRdKuXTtrjZPuu+8+uf/+++3FDVamX2wO1BV+Z1NL97ubaX+IzCRHGGhAmjVrJuXl5eaxfi9k06ZNrTVO+vnPfy6XXnqpnHPOOfL9739fNm3aZK/SoGT6xeZAXeF3NrV0v7uZ9ofITHKEgQakcePGiceVlZXSqFEjX+0p5513nkyaNMl0PnPmzJGrrrrKXqVBStVJA3HA72x6qX53M+0PkZnkCAMNSPPmzaWsrMw81r/Y9C84lxMnTsjZZ59tL26QUnXSQNzwO5ss1e9udfpDBEuOMNCAtG/fXrZu3Woeb968Wdq2bWutkUz/Iv785z9vL26QUnXSQNzwO5ss1e9udfpDBEuOMNCA9OjRQ/r16yclJSXSp08f6d69u72K8bWvfc2cZtCOeubMmfLb3/7WXqVBStVJA3HA72x6qX53M+0PkZnkCAMNSEFBgbRu3VpOP/10cwWN3gfG4++AtKP+5je/KWeccYb85Cc/kX379iXqGrJUnTQQB/zOppfqdzddf4jsJUcYAAAANUKCBQAAEDISLAAAgJCRYAEAAISMBAsAACBkJFgAAAAhI8ECAAAIGQkWAABAyEiwAAAAQkaCBQAAEDISLAAAgJCRYAEAAISMBAsAACBkJFgAAAAhI8ECAAAIGQkWAABAyEiwAAAAQkaCBQAAEDISLAAAgJCRYAEAAISMBAsAACBkJFgAAAAhI8ECkDPKysrsRQBQJ0iwAIRq+/bt8sMf/lCaNWsm//Vf/yVHjx7NqM7vtNOq1zV9/etftxfVqnHjxsmNN95oL3byH28mj+Mqk3284YYbTJyAXOf+bQCALPziF7+Q5557To4cOSKjRo2Su+66K6M6v0w+qFOp7uvCcOzYMWndurVJImuiviVVfpnsr8bnq1/9qokXkMvcvw0AYNEP0tGjR8s555wjl19+uezatStR99nPftYkUEo/RP2jSunq/NJ9UOfn58tll10mZ511lpx33nny0ksvmeX6Gq+ogwcPyvXXX2/28dprr5VDhw4ltqHrPPbYY/LlL39ZTj/9dBk/fnyibs2aNfKNb3xDPv/5z8u7775rlrVp00bWr19vHq9duzblfg8fPlxuvvnmxHN9jxdeeEGaN29uyttvvy2TJk0y+2y/Z1BSFfRYR4B0G2eeeaZ873vfkxkzZiTq0h13qmNTxcXF8vOf/1zOPvtsueqqq2Tfvn2JunSx2rNnjxmR1Nfp8We6jxonXR/IZcG9GAAE0A/SX//613LgwAEZPHiw3HLLLYk6/WA/fvy4ebx3716TCGVS55cuwdLkZuzYsSZBGzFihElePP7X/fnPf5YlS5aY9UaOHCm33357lfX+8Ic/SElJibz55psmEfBocjJo0CCTDH3lK18xy+644w55/PHHzeMnnnhCevbsmVjf07lzZ5kyZUriub5H9+7dTYIzZswY6dixo/Tt21cOHz6c9J5BiVTQY33tsGHDpKKiwmxbR8486Y471bGpHj16mOPTfdNk6ve//32iLl2s/ud//keGDh1qkjp9nOk+Tp48WX70ox8lngO5KLgXA4AA+kGqIzlKRz/OPffcRJ2OhLzxxhuyY8cOue6666RRo0YZ1fmlS7D8dFJ7UBJywQUXSHl5uXl84sQJueiiixJ1up5/1M3/Ok3Ydu/enXiuZs2aJVdeeaV5/NOf/tQkCDYdFdKk0eN/D90Pfa6x8tdX97HuiyZLmtB5I4KedMed6tjUhRdemBi10n1s1apVoi5drHQ0bv/+/ebxhg0bMt5HHfnSeAG5LLNeDAB89IPUG4nSD3P/qIZ+0H7nO98xo1U6uqWniDKp80uXYGkS069fP/PhraetgpKQM844wzz3SpMmTVKuZz//1Kc+ZUZd/DRR+eIXv2iSky984QtSWlpapV5pDLzERuk2Kysrqzz3C9rvTB7rflxxxRXmmHQUcP78+Ym6dMed6tiUvkaPUWl90Oia/dy/PW0Pme6jJsb+9wByUXAvBgAB9IN069at5rEmPJp8pLJy5Uo5//zz7cVGujr7Q93v3//9382ptokTJ5pTYUFJSMuWLQNv22Bv3/9cEyj/SJTnN7/5jfzxj3+Uq6++2q4yWrRoITt37kw8T/ce9vNsH3v0NKBeNOAfQUx33EHHpiNWenpQab0ei8d+XzvG3vY2b96ctK5KtY8aJx39AnJZ8m8DADjoB6nO29H5RY888ohJPDx6hZhOhNZTRw899JD88pe/zKjOL9UHtUdHv3Q0RLehc7/86+qoiJfk3HbbbVJYWGhGm55++mkzMd5jb9//XG8j8OSTT8r06dOrzFOaMGGC2f5TTz2VWOb3k5/8pM
rE8XTvYT/P9nG7du3MRQY6YuZNJvekO+6gY7v11lvNhHydZ6VzsLp165aoS7ffOldr4MCBZg7W7373u4z3UWP5s5/9LPEcyEXBvRgABNAP0kcffdTcz+qaa64xH7CeqVOnmg9vrevSpUuVeUfp6vx0+6mK0gnuelWfjp5oMuD/UO/atas0bdrUPNZt677pe3Xq1EmWL1+eWC9d0qBX2rVv396M4ugomUcn9Ot6RUVFiWV+OkpjX0Xol+55to/nzJkj3/rWt8wpOk1c/PeVSnfcQcemc6J0bpmeytPJ5/rck26/NSY//vGPzeiUJmiZ7qPGSdcHchkJFoCs2R+6DUFeXp5cfPHF9uIEPRWmk8W9U6dIbcuWLWbivTeHD8hVDa+XBFBjDTHB0nt4vfPOO/biKvQKSb06EsH03lw6CgnkuobXSwIAAESMBAsAACBkJFgAAAAhI8ECAAAIGQkWAABAyEiwAAAAQkaCBQAAEDISLAAAgJCRYAEAAISMBAsAACBkJFgAAAAhI8ECAAAIGQkWAABAyEiwAAAAQkaCBQAAEDISLAAAgJCRYAEAAISMBAsAACBkJFgAAAAhI8ECAAAI2f8HXuNIYxUnAlsAAAAASUVORK5CYII=" } }, "cell_type": "markdown", "id": "3d629a02", "metadata": {}, "source": [ "# FAQ 10: How fast is the FIL backend relative to alternatives?\n", "There are relatively few inference servers which offer support for tree-based models, so it is somewhat difficult to find a useful point of comparison. Sometimes, users ask for a comparison of the FIL backend's performance to e.g. native XGBoost performance, but because XGBoost does not offer an inference *serving* solution, it is very difficult to create an apples-to-apples comparison.\n", "\n", "What we can do is compare the underlying FIL execution library directly to the underlying XGBoost library without invoking Triton as a server. In the following chart, we see this comparison for a fairly typical XGBoost model of moderate complexity. For this benchmark, GPU execution took place on a single A100 provided via a GCP a2 instance. CPU execution took place on a GCP n2-standard-16 instance.\n", "![Throughput%20vs%20Batch%20Size%20for%20XGBoost_FIL%20Only.png](attachment:Throughput%20vs%20Batch%20Size%20for%20XGBoost_FIL%20Only.png)\n", "\n", "As we can see, FIL significantly outperforms CPU XGBoost at small batch sizes and more dramatically outperforms it on GPU for higher batch sizes. [^](#Table-of-Contents)\n", "\n", "## FAQ 10.1 How fast is the FIL backend on CPU vs on GPU?\n", "\n", "Naturally, the answer to this question depends on the CPU and GPU hardware you have available. To give a sense of the typical performance differential one might expect, however, we present throughput-latency curves on CPU and GPU for a typical XGBoost model below. For this benchmark, GPU execution took place on a single A100 provided via a GCP a2 instance. CPU execution took place on a GCP n2-standard-16 instance.\n", "![Throughput%20vs%20p99%20Latency%20With%20Triton.png](attachment:Throughput%20vs%20p99%20Latency%20With%20Triton.png)\n", "\n", "In the very low latency domain (small batch size, light server traffic), we see that CPU execution can outperform GPU, but above around half a millisecond latency (slightly larger batches or increased traffic), GPU clearly dominates. [^](#Table-of-Contents)\n", "\n", "## FAQ 10.2 How fast is the FIL backend relative to the ONNX backend?\n", "\n", "While analysis of the underlying library is a useful starting place for understanding the performance of the FIL backend, it does not provide an entirely satisfactory point of comparison, since it does not provide a direct comparison of latency/throughput achievable on FIL vs another solution.\n", "\n", "Triton does offer one other backend that supports *some* tree-based models, however. Using ONNX, we can get a more direct sense of how the FIL backend performs in an end-to-end comparison. 
In Example 10, we will compare the performance of a Scikit-Learn model deployed with the FIL backend to that of the same model deployed with the ONNX backend, and demonstrate that the FIL backend generally outperforms ONNX on GPU by a significant margin. [^](#Table-of-Contents)\n", "\n", "\n", "
\n", " NOTE: We have restricted the following example to GPU execution because the relative CPU performance of ONNX vs FIL varies on a wide enough range of factors to make it difficult to draw generally-applicable conclusions. This topic will be covered separately elsewhere. CPU optimizations continue to be a focus for FIL backend development, and we expect its performance to continue to improve in areas where it currently underperforms the ONNX backend. For more details, see this Github issue.\n", "
" ] }, { "cell_type": "markdown", "id": "ea4fa3e5", "metadata": {}, "source": [ "## $\\color{#76b900}{\\text{Example 10: Comparing the FIL and ONNX backends}}$\n", "\n", "For this example, we will abandon our current example model and instead train up a Scikit-Learn RandomForestClassifier. This model can be converted to ONNX using the skl2onnx package. While *some* XGBoost models can also be converted to ONNX in a similar way, there is often difficulty with matching versions of each of the relevant packages, so we will stick to Scikit-Learn for the moment. We will then deploy the model using both the FIL and ONNX backends and compare their performance using `perf_analyzer`.\n", "\n", "\n", "
\n", "VERSION NOTE: Remember that when using Scikit-learn models, the version of Treelite used for serialization must match that used in the Triton server version you are using. See the compatibility table above for info on which Treelite version is required.\n", "
" ] }, { "cell_type": "code", "execution_count": null, "id": "64e2331f", "metadata": {}, "outputs": [], "source": [ "# Train model\n", "if USE_GPU:\n", " skl_model = train_skl(X, y, n_trees=5)" ] }, { "cell_type": "code", "execution_count": null, "id": "81a58e7e", "metadata": {}, "outputs": [], "source": [ "# Serialize model to Triton model repo and write out config file for FIL\n", "if USE_GPU:\n", " import treelite\n", "\n", " fil_name = 'fil_model'\n", " fil_dir = os.path.join(MODEL_REPO, fil_name)\n", " fil_versioned_dir = os.path.join(fil_dir, '1')\n", "\n", " os.makedirs(fil_versioned_dir, exist_ok=True)\n", "\n", " fil_path = os.path.join(fil_versioned_dir, 'checkpoint.tl')\n", "\n", " tl_model = treelite.sklearn.import_model(skl_model)\n", " tl_model.serialize(fil_path)\n", "\n", " config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"treelite_checkpoint\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"false\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"true\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"SPARSE\" }}\n", " }},\n", " {{\n", " key: \"algo\"\n", " value: {{ string_value: \"ALGO_AUTO\" }}\n", " }},\n", " {{\n", " key: \"blocks_per_sm\"\n", " value: {{ string_value: \"0\" }}\n", " }},\n", " {{\n", " key: \"use_experimental_optimizations\"\n", " value: {{ string_value: \"true\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", " config_path = os.path.join(fil_dir, 'config.pbtxt')\n", " with open(config_path, 'w') as file_:\n", " file_.write(config_text)" ] }, { "cell_type": "code", "execution_count": null, "id": "1222b126", "metadata": {}, "outputs": [], "source": [ "# Serialize model to Triton model repo and write out config file for ONNX\n", "\n", "if USE_GPU:\n", " from skl2onnx import convert_sklearn\n", " from skl2onnx.common.data_types import FloatTensorType\n", " onnx_name = 'onnx_model'\n", " onnx_dir = os.path.join(MODEL_REPO, onnx_name)\n", " onnx_versioned_dir = os.path.join(onnx_dir, '1')\n", "\n", " os.makedirs(onnx_versioned_dir, exist_ok=True)\n", "\n", " onnx_path = os.path.join(onnx_versioned_dir, 'model.onnx')\n", "\n", " onnx_model = convert_sklearn(\n", " skl_model,\n", " initial_types=[('input', FloatTensorType([None, NUM_FEATURES]))],\n", " target_opset=12,\n", " options={'zipmap': False}\n", " )\n", " outputs = {output.name: output for output in onnx_model.graph.output}\n", " onnx_model.graph.output.remove(outputs['label'])\n", " with open(onnx_path, 'wb') as file_:\n", " file_.write(onnx_model.SerializeToString())\n", "\n", " config_text = f\"\"\"platform: \"onnxruntime_onnx\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"probabilities\"\n", " data_type: TYPE_FP32\n", " dims: [ 2 ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: 
\"intra_op_thread_count\"\n", " value: {{ string_value: \"8\" }}\n", " }},\n", " {{\n", " key: \"cudnn_conv_algo_search\"\n", " value: {{ string_value: \"1\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", " config_path = os.path.join(onnx_dir, 'config.pbtxt')\n", " with open(config_path, 'w') as file_:\n", " file_.write(config_text)\n", "\n", " time.sleep(10)" ] }, { "cell_type": "markdown", "id": "a97a4945", "metadata": {}, "source": [ "Now, let's use `perf_analyzer` to see how each of these models perform under a bit of traffic." ] }, { "cell_type": "code", "execution_count": null, "id": "3f4df7be", "metadata": {}, "outputs": [], "source": [ "if USE_GPU and is_triton_ready():\n", " !perf_analyzer -m {fil_name} -i GRPC --concurrency-range 16:16\n", " !perf_analyzer -m {onnx_name} -i GRPC --concurrency-range 16:16" ] }, { "cell_type": "markdown", "id": "049d74e1", "metadata": {}, "source": [ "Your exact results will depend on available hardware, but on the GPU used to test this notebook, we see that FIL offers about 5x the throughput with about 5x better average latency and about 20x better p99 latency. As server load and client batch size grow, FIL's relative advantage over ONNX in both throughput and latency continue to increase. [^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "17b66ec7", "metadata": {}, "outputs": [], "source": [ "if USE_GPU and is_triton_ready():\n", " !perf_analyzer -m {fil_name} -i GRPC --concurrency-range 16:16 -b 128\n", " !perf_analyzer -m {onnx_name} -i GRPC --concurrency-range 16:16 -b 128" ] }, { "cell_type": "markdown", "id": "5697140c", "metadata": {}, "source": [ "# FAQ 11: How do I submit many inference requests in parallel?\n", "Because of the overhead of transferring an input array over the network, it is possible to achieve a significant performance improvement by submitting multiple requests in parallel. In this way, network transfer overlaps with inference, and the model can begin processing some samples while others are still being transferred.\n", "\n", "In the following example, we will make use of the `async_infer` method of the Triton Python client to break a large batch up into chunks and submit them concurrently. Correctly selecting a chunk size that results in improved performance depends on a number of factors. It is often best to proceed empirically, testing a range of chunks to find an optimum for your application. [^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "b9bf4960", "metadata": {}, "source": [ "# Example 11: Submitting requests in parallel with the Python client\n", "In this example, we will try processing a number of samples equal to the selected `max_batch_size` in two different ways. In the first, we will submit the entire dataset at once. In the second, we will break the dataset into 5 chunks and submit them asynchronously. In each case, we will look at how long it took to process the entire dataset.\n", "\n", "The results of this experiment will vary depending on hardware and your model details, so you may try experimenting with the number of chunks to find an optimum for your deployment." 
] }, { "cell_type": "code", "execution_count": null, "id": "ba86ed5f", "metadata": {}, "outputs": [], "source": [ "NUM_CHUNKS = 5" ] }, { "cell_type": "code", "execution_count": null, "id": "b2163c3a", "metadata": {}, "outputs": [], "source": [ "import concurrent.futures\n", "\n", "def triton_predict_async(model_name, arr):\n", " triton_input = triton_grpc.InferInput('input__0', arr.shape, 'FP32')\n", " triton_input.set_data_from_numpy(arr)\n", " triton_output = triton_grpc.InferRequestedOutput('output__0')\n", " \n", " future_result = concurrent.futures.Future()\n", " \n", " def callback(result, error):\n", " if error is None:\n", " future_result.set_result(result.as_numpy('output__0'))\n", " else:\n", " future_result.set_exception(error)\n", " \n", " client.async_infer(\n", " model_name,\n", " model_version='1',\n", " inputs=[triton_input],\n", " outputs=[triton_output],\n", " callback=callback\n", " )\n", " return future_result" ] }, { "cell_type": "code", "execution_count": null, "id": "9040051c", "metadata": {}, "outputs": [], "source": [ "large_input = np.random.rand(max_batch_size, NUM_FEATURES).astype('float32')\n", "chunks = np.array_split(large_input, NUM_CHUNKS)" ] }, { "cell_type": "code", "execution_count": null, "id": "1f783448", "metadata": {}, "outputs": [], "source": [ "%%timeit\n", "triton_predict(MODEL_NAME, large_input)" ] }, { "cell_type": "code", "execution_count": null, "id": "15414d3d", "metadata": {}, "outputs": [], "source": [ "%%timeit\n", "concurrent.futures.wait([triton_predict_async(MODEL_NAME, chunk_) for chunk_ in chunks])" ] }, { "cell_type": "markdown", "id": "90990473", "metadata": {}, "source": [ "Again, your results may vary depending on hardware and model, but for the configuration used while testing this notebook, splitting the data into 5 chunks in order to overlap transfer and data processing resulted in an approximately 2x speedup. [^](#Table-of-Contents)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b7b318b9", "metadata": {}, "source": [ "# $\\color{#76b900}{\\text{FAQ 12: How do I retrieve Shapley values for model explainability?}}$\n", "\n", "The FIL backend allows models to return [Shapley values](https://en.wikipedia.org/wiki/Shapley_value) for each feature in order to explain which features were most important in the model's decision. By configuring the model with an extra output, we will get Shapley values along with the main prediction output. This extra output must have the name `treeshap_output`, and its size should be equal to the number of features plus one, where the final extra column stores a \"bias\" term.\n", "\n", "Because of the increased latency from transferring such a large output array, always computing and returning Shapley values may have an undesirable impact on latency. When model explainability is required however, this additional output can provide powerful insight on how your model comes to its decisions for any given input. [^](#Table-of-Contents)\n", "\n", "\n", "
\n", "VERSION NOTE: GPU Shapley value support was added with version 22.03 of the FIL backend.\n", "
\n", "
\n", " NOTE: Experimental support for CPU Shapley values has been added in version 23.04.\n", "
" ] }, { "cell_type": "markdown", "id": "cddbd34a", "metadata": {}, "source": [ "# $\\color{#76b900}{\\text{Example 12: Retrieving Shapley Values}}$\n", "In order to have the FIL backend compute and return Shapley values, we simply add an additional output to the `config.pbtxt` file for that model with the correct number of dimensions. For instance, if our model has 500 input features, we would add the following output entry:\n", "\n", "```pbtxt\n", " {\n", " name: \"treeshap_output\"\n", " data_type: TYPE_FP32\n", " dims: [ 501 ]\n", " }\n", "```\n", "\n", "Let's try this now with our example model." ] }, { "cell_type": "code", "execution_count": null, "id": "bf0d177a", "metadata": {}, "outputs": [], "source": [ "shap_dim = NUM_FEATURES + 1\n", "\n", "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [ \n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ] \n", " }} \n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }},\n", " {{\n", " name: \"treeshap_output\"\n", " data_type: TYPE_FP32\n", " dims: [ {shap_dim} ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"{model_format}\" }}\n", " }},\n", " {{\n", " key: \"predict_proba\"\n", " value: {{ string_value: \"false\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"{classifier_string}\" }}\n", " }},\n", " {{\n", " key: \"threshold\"\n", " value: {{ string_value: \"0.5\" }}\n", " }},\n", " {{\n", " key: \"storage_type\"\n", " value: {{ string_value: \"AUTO\" }}\n", " }},\n", " {{\n", " key: \"algo\"\n", " value: {{ string_value: \"TREE_REORG\" }}\n", " }},\n", " {{\n", " key: \"blocks_per_sm\"\n", " value: {{ string_value: \"0\" }}\n", " }},\n", " {{\n", " key: \"use_experimental_optimizations\"\n", " value: {{ string_value: \"true\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", "config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", "with open(config_path, 'w') as file_:\n", " file_.write(config_text)\n", "time.sleep(10)" ] }, { "cell_type": "markdown", "id": "64edee4c", "metadata": {}, "source": [ "To make use of this new output, we will need to slightly adjust how we use the Python client to submit requests. The following function will allow us to retrieve the Shapley values in addition to the usual model output. 
[^](#Table-of-Contents)" ] }, { "cell_type": "code", "execution_count": null, "id": "38c612a9", "metadata": {}, "outputs": [], "source": [ "def triton_predict_shap(model_name, arr):\n", " triton_input = triton_grpc.InferInput('input__0', arr.shape, 'FP32')\n", " triton_input.set_data_from_numpy(arr)\n", " triton_output = triton_grpc.InferRequestedOutput('output__0')\n", " triton_shap_output = triton_grpc.InferRequestedOutput('treeshap_output')\n", " response = client.infer(model_name, model_version='1', inputs=[triton_input], outputs=[triton_output, triton_shap_output])\n", " return (response.as_numpy('output__0'), response.as_numpy('treeshap_output'))" ] }, { "cell_type": "code", "execution_count": null, "id": "c8378d72", "metadata": {}, "outputs": [], "source": [ "if is_triton_ready():\n", " batch = convert_to_numpy(X[0:2])\n", "\n", " output, shap_out = triton_predict_shap(MODEL_NAME, batch)\n", " print(\"The model output for these samples was: \")\n", " print(output)\n", " print(\"\\nThe Shapley values for these samples were: \")\n", " print(shap_out)\n", " print(\"\\n The most significant feature for each sample was: \")\n", " print(np.argmax(shap_out, axis=1))" ] }, { "cell_type": "markdown", "id": "790bafdd-06ba-446e-a996-699829b6fbe9", "metadata": {}, "source": [ "# FAQ 13: How do I serve a learning-to-rank model?" ] }, { "cell_type": "markdown", "id": "1f5901fe-baf9-4156-9aec-de63413264ae", "metadata": {}, "source": [ "Learning-to-rank models are treated as regression models in the FIL backend. In the configuration file, make sure to set `output_class=\"false\"`.\n", "\n", "If the learning-to-rank model was trained with dense data, no extra preparation is needed. As in Example 13.2, you can obtain predictions just like other regression models.\n", "\n", "Special care is needed when the learning-to-rank model was trained with sparse data. Since the FIL backend does not yet support a sparse input, we need to use an equivalent dense input instead. Example 13.3 shows how to convert a sparse input into an equivalent dense input. 
[^](#Table-of-Contents)" ] }, { "cell_type": "markdown", "id": "05c9c398-c960-4318-99c1-4f7101b99f63", "metadata": {}, "source": [ "# Example 13.1: Configuring an XGBoost ranking model in the FIL backend" ] }, { "cell_type": "code", "execution_count": null, "id": "4a41b760-8c9a-4d79-94c2-ae6996edbc26", "metadata": {}, "outputs": [], "source": [ "config_text = f\"\"\"backend: \"fil\",\n", "max_batch_size: {max_batch_size}\n", "input [\n", " {{ \n", " name: \"input__0\"\n", " data_type: TYPE_FP32\n", " dims: [ {NUM_FEATURES} ]\n", " }}\n", "]\n", "output [\n", " {{\n", " name: \"output__0\"\n", " data_type: TYPE_FP32\n", " dims: [ 1 ]\n", " }}\n", "]\n", "instance_group [{{ kind: {instance_kind} }}]\n", "parameters [\n", " {{\n", " key: \"model_type\"\n", " value: {{ string_value: \"xgboost_json\" }}\n", " }},\n", " {{\n", " key: \"output_class\"\n", " value: {{ string_value: \"false\" }}\n", " }}\n", "]\n", "\n", "dynamic_batching {{}}\"\"\"\n", "\n", "config_path = os.path.join(MODEL_DIR, 'config.pbtxt')\n", "with open(config_path, 'w') as file_:\n", " file_.write(config_text)\n", "time.sleep(10)" ] }, { "cell_type": "markdown", "id": "30772dbb-845a-4964-8c08-618f5998a971", "metadata": {}, "source": [ "# Example 13.2: Running inferece with an XGBoost ranking model and a dense input" ] }, { "cell_type": "code", "execution_count": null, "id": "bc8e1401-851f-407d-90ee-70cdf4d71d9a", "metadata": {}, "outputs": [], "source": [ "triton_result = triton_predict(MODEL_NAME, X_test) # X_test is a NumPy array" ] }, { "cell_type": "markdown", "id": "c0993c1a-a510-4958-9e2c-e2f9d3f4856e", "metadata": {}, "source": [ "# Example 13.3: Running inferece with an XGBoost ranking model and a sparse input" ] }, { "cell_type": "markdown", "id": "771525a1-1747-4b83-9c47-2456c50f6384", "metadata": {}, "source": [ "In the sparse matrix, zeros are synonymous with the missing values, whereas in the dense array, zeros are distinct from the missing values. So to achieve a consistent behavior, do the following:\n", "1. Convert the sparse matrix to a dense NumPy array, by calling `toarray()` method.\n", "2. Replace zeros with `np.nan` (missing value)." ] }, { "cell_type": "code", "execution_count": null, "id": "55d8fa91-ada4-4939-a5f0-8dd4debf607a", "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import load_svmlight_file\n", "\n", "X_test, y_test = load_svmlight_file(\"data.libsvm\")\n", "# X_test is scipy.sparse.csr_matrix type\n", "# Since Triton client only accepts NumPy arrays as inputs, convert it to NumPy array.\n", "X_test = X_test.toarray()\n", "X_test[X_test == 0] = np.nan # Replace zero with np.nan\n", "\n", "# Now X_test can be passed to triton_predict().\n", "triton_result = triton_predict(MODEL_NAME, X_test)" ] }, { "cell_type": "markdown", "id": "d56810d0", "metadata": {}, "source": [ "# Cleanup\n", "We will finish up by taking down the server container." ] }, { "cell_type": "code", "execution_count": null, "id": "5270f694", "metadata": {}, "outputs": [], "source": [ "!docker rm -f tritonserver" ] }, { "cell_type": "markdown", "id": "30a6c3be", "metadata": {}, "source": [ "# Conclusion\n", "\n", "By combining pieces from each of the code examples provided in this notebook, you should be able to handle most common tasks involved with deploying and making use of a model with the FIL backend for Triton. 
If you have additional questions, please consider [submitting an issue](https://github.com/triton-inference-server/fil_backend/issues/new) on Github, and we may be able to add an entry to this notebook to cover your specific use case.\n", "\n", "For more information on anything covered in this notebook, we recommend checking out the [FIL backend documentation](https://github.com/triton-inference-server/fil_backend#triton-inference-server-fil-backend), the introductory [FIL example notebook](https://github.com/triton-inference-server/fil_backend/tree/main/notebooks/categorical-fraud-detection#fraud-detection-with-categorical-xgboost), and [Triton's documentation](https://github.com/triton-inference-server/server/blob/main/README.md#documentation). [^](#Table-of-Contents)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.15" } }, "nbformat": 4, "nbformat_minor": 5 }