{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# BentoML Example: PyTorch GPU Serving\n", "\n", "BentoML makes moving trained ML models to production easy:\n", "\n", "- Package models trained with any ML framework and reproduce them for model serving in production\n", "- Deploy anywhere for online API serving or offline batch serving\n", "- High-performance API model server with adaptive micro-batching support\n", "- Central hub for managing models and the deployment process via Web UI and APIs\n", "- Modular and flexible design, making it adaptable to your infrastructure\n", "\n", "BentoML is a framework for serving, managing, and deploying machine learning models. It aims to bridge the gap between Data Science and DevOps and to enable teams to deliver prediction services in a fast, repeatable, and scalable way. Before reading this example project, be sure to check out the Getting Started guide to learn about the basic concepts in BentoML.\n", "\n", "This notebook demonstrates how to serve a PyTorch model with BentoML and build a Docker image with GPU support. 
Please refer to the [GPU Serving guide](https://docs.bentoml.org/en/latest/guides/gpu_serving.html) for more information.\n", "\n", "This is an extension of [PyTorch's text_sentiment_ngrams_tutorial](https://github.com/pytorch/tutorials/blob/master/beginner_source/text_sentiment_ngrams_tutorial.py)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "!pip install -q bentoml==0.12.1 torch==1.8.1+cu111 torchtext==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "We are building a simple news classification model with PyTorch, using the **AG_NEWS** dataset provided by the `torchtext` library." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import os\n", "import time\n", "from collections import Counter\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import random_split\n", "\n", "from torchtext.datasets import AG_NEWS\n", "from torchtext.data.utils import get_tokenizer\n", "from torchtext.vocab import Vocab\n", "\n", "from bentoml import BentoService, api, artifacts, env\n", "from bentoml.adapters import JsonInput, JsonOutput\n", "from bentoml.frameworks.pytorch import PytorchModelArtifact\n", "from bentoml.service.artifacts.pickle import PickleArtifact" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CUDA: True -> cuda:0\n" ] } ], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"CUDA: {torch.cuda.is_available()} -> {device}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocessing Data Pipelines" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "### Setting up the tokenizer and vocab" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Please refer to the [`torchtext` API documentation](https://pytorch.org/text/stable/index.html). Note that the `Vocab(counter, min_freq=1)` constructor used below matches the torchtext 0.9.1 release pinned above; later torchtext releases changed the `Vocab` API." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def get_tokenizer_vocab(dataset=AG_NEWS, tokenizer_fn='basic_english', root_data_dir='dataset'):\n", "    print('Getting tokenizer and vocab...')\n", "    tokenizer = get_tokenizer(tokenizer_fn)\n", "    train_ = dataset(root=root_data_dir, split='train')\n", "    # Count token frequencies over the training split to build the vocabulary\n", "    counter = Counter()\n", "    for (label, line) in train_:\n", "        counter.update(tokenizer(line))\n", "    vocab = Vocab(counter, min_freq=1)\n", "    return tokenizer, vocab" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Getting tokenizer and vocab...\n" ] } ], "source": [ "tokenizer, vocab = get_tokenizer_vocab()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setup pipeline...\n" ] } ], "source": [ "def get_pipeline(tokenizer, vocab):\n", "    print('Setup pipeline...')\n", "    # Map raw text to a list of token IDs, and AG_NEWS labels (1-4) to 0-based class indices\n", "    text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]\n", "    label_pipeline = lambda x: int(x) - 1\n", "    return text_pipeline, label_pipeline\n", "\n", "text_pipeline, label_pipeline = get_pipeline(tokenizer, vocab)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def get_train_valid_split(train_iter):\n", "    # Hold out 5% of the training data for validation\n", "    train_dataset = list(train_iter)\n", "    num_train = int(len(train_dataset) * 0.95)\n", "    split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])\n", "    return split_train_, split_valid_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generating data batches and iterators" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are going to 
use [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader) with a custom `collate_batch` function, which concatenates the token IDs of all entries in a batch into a single tensor and records the offset at which each entry starts." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def collate_batch(batch):\n", "    label_list, text_list, offsets = [], [], [0]\n", "    for (_label, _text) in batch:\n", "        label_list.append(label_pipeline(_label))\n", "        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)\n", "        text_list.append(processed_text)\n", "        offsets.append(processed_text.size(0))\n", "    label_list = torch.tensor(label_list, dtype=torch.int64)\n", "    # Turn entry lengths into start positions: [0, len_0, len_0 + len_1, ...]\n", "    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)\n", "    # Concatenate all entries into one flat 1-D tensor of token IDs\n", "    text_list = torch.cat(text_list)\n", "    return label_list.to(device), text_list.to(device), offsets.to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining our Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model is composed of an `nn.EmbeddingBag` layer followed by a linear layer for classification.\n", "\n", "`nn.EmbeddingBag` with the default mode of “mean” computes the mean value of a “bag” of embeddings. 
Although the text entries have different lengths, the `nn.EmbeddingBag` module requires no padding because the text lengths are saved in `offsets`.\n", "\n", "Source: [Text classification with the torchtext library](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![sentiment_model.png](./text_sentiment_ngrams_model.png)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class TextClassificationModel(nn.Module):\n", "\n", "    def __init__(self, vocab_size, embed_dim, num_class):\n", "        super().__init__()\n", "        # Averages the embeddings of each bag of tokens (default mode='mean')\n", "        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)\n", "        self.fc = nn.Linear(embed_dim, num_class)\n", "        self.init_weights()\n", "\n", "    def init_weights(self):\n", "        init_range = 0.5\n", "        self.embedding.weight.data.uniform_(-init_range, init_range)\n", "        self.fc.weight.data.uniform_(-init_range, init_range)\n", "        self.fc.bias.data.zero_()\n", "\n", "    def forward(self, text, offsets=None):\n", "        embedded = self.embedding(text, offsets=offsets)\n", "        return self.fc(embedded)\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def summary(model):\n", "    count_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "    print(f'\\nThe model has {count_params:,} trainable parameters')\n", "    print(f\"Model summary:\\n{model}\\nDetails:\")\n", "    for n, p in model.named_parameters():\n", "        print(f'name: {n}, shape: {p.shape}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparing Hyperparameters" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Hyperparameters\n", "EPOCHS = 10  # number of training epochs\n", "LR = 5  # learning rate\n", "BATCH_SIZE = 64  # batch size for training\n", "EMBEDDING_SIZE = 64  # embedding size\n", "\n", "train_iter = AG_NEWS(root='dataset', split='train')\n", "num_class = 
len(set([label for (label, text) in train_iter]))\n", "vocab_size = len(vocab)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "The model has 6,132,228 trainable parameters\n", "Model summary:\n", "TextClassificationModel(\n", " (embedding): EmbeddingBag(95812, 64, mode=mean)\n", " (fc): Linear(in_features=64, out_features=4, bias=True)\n", ")\n", "Details:\n", "name: embedding.weight, shape: torch.Size([95812, 64])\n", "name: fc.weight, shape: torch.Size([4, 64])\n", "name: fc.bias, shape: torch.Size([4])\n" ] } ], "source": [ "model = TextClassificationModel(vocab_size, EMBEDDING_SIZE, num_class).to(device)\n", "summary(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining our train and evaluate loops" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def train(model, data_loader, optimizer, criterion, epoch):\n", "    model.train()\n", "    total_acc, total_count = 0, 0\n", "    log_interval = 500\n", "\n", "    for idx, (label, text, offsets) in enumerate(data_loader):\n", "        optimizer.zero_grad()\n", "        predicted = model(text, offsets=offsets)\n", "        loss = criterion(predicted, label)\n", "        loss.backward()\n", "        # Clip the total gradient norm to 0.1 before the update\n", "        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)\n", "        optimizer.step()\n", "        total_acc += (predicted.argmax(1) == label).sum().item()\n", "        total_count += label.size(0)\n", "        if idx % log_interval == 0 and idx > 0:\n", "            print(f'| epoch {epoch:3d} | {idx:5d}/{len(data_loader):5d} batches | accuracy {(total_acc / total_count):5.3f}')\n", "            total_acc, total_count = 0, 0\n", "\n", "\n", "def evaluate(model, data_loader, criterion):\n", "    model.eval()\n", "    total_acc, total_count = 0, 0\n", "\n", "    with torch.no_grad():\n", "        for idx, (label, text, offsets) in enumerate(data_loader):\n", "            predicted_label = model(text, offsets)\n", "            loss = criterion(predicted_label, label)\n", "            total_acc += 
(predicted_label.argmax(1) == label).sum().item()\n", "            total_count += label.size(0)\n", "    return total_acc / total_count" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training and saving our model locally" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "| epoch 1 | 500/ 1782 batches | accuracy 0.689\n", "| epoch 1 | 1000/ 1782 batches | accuracy 0.856\n", "| epoch 1 | 1500/ 1782 batches | accuracy 0.876\n", "-----------------------------------------------------------\n", "| end of epoch 1 | time: 8.45s | valid accuracy 0.889\n", "-----------------------------------------------------------\n", "| epoch 2 | 500/ 1782 batches | accuracy 0.896\n", "| epoch 2 | 1000/ 1782 batches | accuracy 0.901\n", "| epoch 2 | 1500/ 1782 batches | accuracy 0.904\n", "-----------------------------------------------------------\n", "| end of epoch 2 | time: 7.70s | valid accuracy 0.898\n", "-----------------------------------------------------------\n", "| epoch 3 | 500/ 1782 batches | accuracy 0.919\n", "| epoch 3 | 1000/ 1782 batches | accuracy 0.913\n", "| epoch 3 | 1500/ 1782 batches | accuracy 0.915\n", "-----------------------------------------------------------\n", "| end of epoch 3 | time: 7.74s | valid accuracy 0.904\n", "-----------------------------------------------------------\n", "| epoch 4 | 500/ 1782 batches | accuracy 0.924\n", "| epoch 4 | 1000/ 1782 batches | accuracy 0.921\n", "| epoch 4 | 1500/ 1782 batches | accuracy 0.922\n", "-----------------------------------------------------------\n", "| end of epoch 4 | time: 7.77s | valid accuracy 0.905\n", "-----------------------------------------------------------\n", "| epoch 5 | 500/ 1782 batches | accuracy 0.932\n", "| epoch 5 | 1000/ 1782 batches | accuracy 0.927\n", "| epoch 5 | 1500/ 1782 batches | accuracy 0.929\n", "-----------------------------------------------------------\n", "| end of epoch 5 | time: 
7.16s | valid accuracy 0.894\n", "-----------------------------------------------------------\n", "| epoch 6 | 500/ 1782 batches | accuracy 0.940\n", "| epoch 6 | 1000/ 1782 batches | accuracy 0.942\n", "| epoch 6 | 1500/ 1782 batches | accuracy 0.943\n", "-----------------------------------------------------------\n", "| end of epoch 6 | time: 7.77s | valid accuracy 0.911\n", "-----------------------------------------------------------\n", "| epoch 7 | 500/ 1782 batches | accuracy 0.944\n", "| epoch 7 | 1000/ 1782 batches | accuracy 0.945\n", "| epoch 7 | 1500/ 1782 batches | accuracy 0.944\n", "-----------------------------------------------------------\n", "| end of epoch 7 | time: 8.14s | valid accuracy 0.914\n", "-----------------------------------------------------------\n", "| epoch 8 | 500/ 1782 batches | accuracy 0.944\n", "| epoch 8 | 1000/ 1782 batches | accuracy 0.945\n", "| epoch 8 | 1500/ 1782 batches | accuracy 0.944\n", "-----------------------------------------------------------\n", "| end of epoch 8 | time: 8.92s | valid accuracy 0.913\n", "-----------------------------------------------------------\n", "| epoch 9 | 500/ 1782 batches | accuracy 0.947\n", "| epoch 9 | 1000/ 1782 batches | accuracy 0.946\n", "| epoch 9 | 1500/ 1782 batches | accuracy 0.946\n", "-----------------------------------------------------------\n", "| end of epoch 9 | time: 8.10s | valid accuracy 0.913\n", "-----------------------------------------------------------\n", "| epoch 10 | 500/ 1782 batches | accuracy 0.945\n", "| epoch 10 | 1000/ 1782 batches | accuracy 0.948\n", "| epoch 10 | 1500/ 1782 batches | accuracy 0.947\n", "-----------------------------------------------------------\n", "| end of epoch 10 | time: 9.61s | valid accuracy 0.913\n", "-----------------------------------------------------------\n", "Checking the results of test dataset.\n", "test accuracy 0.911\n" ] } ], "source": [ "criterion = nn.CrossEntropyLoss()\n", "optimizer = 
optim.SGD(model.parameters(), lr=LR)\n", "# Decay the learning rate 10x whenever validation accuracy falls below the best seen so far\n", "scheduler = optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)\n", "total_accu = None\n", "os.makedirs('model', exist_ok=True)  # make sure the checkpoint directory exists\n", "\n", "train_iter, test_iter = AG_NEWS(root='dataset')\n", "test_dataset = list(test_iter)\n", "split_train_, split_valid_ = get_train_valid_split(train_iter)\n", "\n", "train_data_loader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)\n", "valid_data_loader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)\n", "test_data_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)\n", "\n", "for epoch in range(1, EPOCHS + 1):\n", "    epoch_start_time = time.time()\n", "    train(model, train_data_loader, optimizer, criterion, epoch)\n", "    accu_val = evaluate(model, valid_data_loader, criterion)\n", "    if total_accu is not None and total_accu > accu_val:\n", "        scheduler.step()\n", "    else:\n", "        # Validation accuracy did not drop: keep this checkpoint as the best so far\n", "        total_accu = accu_val\n", "        torch.save(model.state_dict(), 'model/pytorch_model.pt')\n", "    print('-' * 59)\n", "    print(f'| end of epoch {epoch:1d} | time: {time.time() - epoch_start_time:5.2f}s | valid accuracy {accu_val:8.3f}')\n", "    print('-' * 59)\n", "\n", "print('Checking the results of test dataset.')\n", "accu_test = evaluate(model, test_data_loader, criterion)\n", "print('test accuracy {:8.3f}'.format(accu_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining our BentoService\n", "\n", "Please refer to our [GPU Serving guide](https://docs.bentoml.org/en/latest/guides/gpu_serving.html) to set up your environment correctly.\n", "\n", "We will use the Docker image provided by *BentoML*, `bentoml/model-server:0.12.1-py38-gpu`, as the base for our CUDA-enabled image." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting bento_svc.py\n" ] } ], "source": [ "%%writefile bento_svc.py\n", "\n", "from bentoml import BentoService, api, artifacts, env\n", "from bentoml.adapters import JsonInput, JsonOutput\n", "from bentoml.frameworks.pytorch import PytorchModelArtifact\n", "from bentoml.service.artifacts.pickle import PickleArtifact\n", "import torch\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "\n", "def get_pipeline(tokenizer, vocab):\n", "    # Same preprocessing as in the notebook, defined here so that bento_svc.py\n", "    # does not depend on a separate training module\n", "    print('Setup pipeline...')\n", "    text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]\n", "    label_pipeline = lambda x: int(x) - 1\n", "    return text_pipeline, label_pipeline\n", "\n", "\n", "@env(conda_dependencies=['pytorch', 'torchtext', 'cudatoolkit=11.1'], conda_channels=['pytorch', 'nvidia'], docker_base_image=\"bentoml/model-server:0.12.1-py38-gpu\")\n", "@artifacts([PytorchModelArtifact(\"model\"), PickleArtifact(\"tokenizer\"), PickleArtifact(\"vocab\")])\n", "class PytorchService(BentoService):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # AG_NEWS class IDs (1-4) mapped to human-readable categories\n", "        self.news_label = {1: 'World',\n", "                           2: 'Sports',\n", "                           3: 'Business',\n", "                           4: 'Sci/Tech'}\n", "\n", "    def classify_categories(self, sentence):\n", "        text_pipeline, _ = get_pipeline(self.artifacts.tokenizer, self.artifacts.vocab)\n", "        with torch.no_grad():\n", "            text = torch.tensor(text_pipeline(sentence), dtype=torch.int64).to(device)\n", "            # A single entry, so the only bag starts at offset 0\n", "            offsets = torch.tensor([0]).to(device)\n", "            output = self.artifacts.model(text, offsets=offsets)\n", "            # Shift the 0-based class index back to the 1-based AG_NEWS label\n", "            return output.argmax(1).item() + 1\n", "\n", "    @api(input=JsonInput(), output=JsonOutput())\n", "    def predict(self, parsed_json):\n", "        label = self.classify_categories(parsed_json.get(\"text\"))\n", "        return {'categories': self.news_label[label]}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packing our BentoService" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Getting tokenizer and vocab...\n" ] }, { "data": { "text/plain": [ "TextClassificationModel(\n", " (embedding): 
EmbeddingBag(95812, 64, mode=mean)\n", " (fc): Linear(in_features=64, out_features=4, bias=True)\n", ")" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Rebuild the model with the same architecture and load the checkpoint saved during training\n", "tokenizer, vocab = get_tokenizer_vocab()\n", "train_iter = AG_NEWS(root='dataset', split='train')\n", "num_class = len(set([label for (label, text) in train_iter]))\n", "vocab_size = len(vocab)\n", "model = TextClassificationModel(vocab_size, EMBEDDING_SIZE, num_class).to(device)\n", "\n", "model.load_state_dict(torch.load(\"model/pytorch_model.pt\", map_location=device))\n", "model.eval()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-06-04 09:59:40,285] WARNING - Using BentoML installed in `editable` model, the local BentoML repository including all code changes will be packaged together with saved bundle created, under the './bundled_pip_dependencies' directory of the saved bundle.\n", "[2021-06-04 09:59:40,334] INFO - Using user specified docker base image: `bentoml/model-server:0.12.1-py38-gpu`, usermust make sure that the base image either has Python 3.8 or conda installed.\n", "[2021-06-04 09:59:40,335] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n", "[2021-06-04 09:59:43,136] INFO - Detected non-PyPI-released BentoML installed, copying local BentoML modulefiles to target saved bundle path..\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/aarnphm/.pyenv/versions/3.8.8/lib/python3.8/site-packages/setuptools/distutils_patch.py:25: UserWarning: Distutils was imported before Setuptools. This usage is discouraged and may exhibit undesirable behaviors or errors. 
Please use Setuptools' objects directly or at least import Setuptools first.\n", " warnings.warn(\n", "warning: no previously-included files matching '*~' found anywhere in distribution\n", "warning: no previously-included files matching '*.pyo' found anywhere in distribution\n", "warning: no previously-included files matching '.git' found anywhere in distribution\n", "warning: no previously-included files matching '.ipynb_checkpoints' found anywhere in distribution\n", "warning: no previously-included files matching '__pycache__' found anywhere in distribution\n", "no previously-included directories found matching 'e2e_tests'\n", "no previously-included directories found matching 'tests'\n", "no previously-included directories found matching 'benchmark'\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "UPDATING BentoML-0.12.1+52.g55c7bfb/bentoml/_version.py\n", "set BentoML-0.12.1+52.g55c7bfb/bentoml/_version.py to '0.12.1+52.g55c7bfb'\n", "[2021-06-04 09:59:48,513] WARNING - Saved BentoService bundle version mismatch: loading BentoService bundle create with BentoML version 0.12.1, but loading from BentoML version 0.12.1+52.g55c7bfb\n", "[2021-06-04 09:59:48,580] INFO - BentoService bundle 'PytorchService:20210604095940_7515CC' saved to: /home/aarnphm/bentoml/repository/PytorchService/20210604095940_7515CC\n" ] } ], "source": [ "# Import the custom BentoService defined in bento_svc.py\n", "from bento_svc import PytorchService\n", "\n", "bento_svc = PytorchService()\n", "\n", "# Pack the trained model, tokenizer, and vocab as artifacts, then save the bundle\n", "bento_svc.pack(\"model\", model)\n", "bento_svc.pack(\"tokenizer\", tokenizer)\n", "bento_svc.pack(\"vocab\", vocab)\n", "saved_path = bento_svc.save()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## REST API Model Serving\n", "\n", "To start a REST API model server with the BentoService saved above, use the `bentoml serve` command below. While the server is running, you can send a JSON request such as `{\"text\": \"some news text\"}` to the `/predict` endpoint, for example from the Swagger UI at http://localhost:5000." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"[2021-06-04 09:59:49,906] INFO - Getting latest version PytorchService:20210604095940_7515CC\n", "[2021-06-04 09:59:49,912] INFO - Starting BentoML API proxy in development mode..\n", "[2021-06-04 09:59:49,913] INFO - Starting BentoML API server in development mode..\n", "[2021-06-04 09:59:49,944] WARNING - Using BentoML installed in `editable` model, the local BentoML repository including all code changes will be packaged together with saved bundle created, under the './bundled_pip_dependencies' directory of the saved bundle.\n", "[2021-06-04 09:59:49,944] WARNING - Using BentoML installed in `editable` model, the local BentoML repository including all code changes will be packaged together with saved bundle created, under the './bundled_pip_dependencies' directory of the saved bundle.\n", "[2021-06-04 09:59:49,965] WARNING - Saved BentoService bundle version mismatch: loading BentoService bundle create with BentoML version 0.12.1, but loading from BentoML version 0.12.1+52.g55c7bfb\n", "[2021-06-04 09:59:49,965] WARNING - Saved BentoService bundle version mismatch: loading BentoService bundle create with BentoML version 0.12.1, but loading from BentoML version 0.12.1+52.g55c7bfb\n", "[2021-06-04 09:59:49,967] INFO - Your system nofile limit is 4096, which means each instance of microbatch service is able to hold this number of connections at same time. You can increase the number of file descriptors for the server process, or launch more microbatch instances to accept more concurrent connection.\n", "======== Running on http://0.0.0.0:5000 ========\n", "(Press CTRL+C to quit)\n", "[2021-06-04 09:59:50,658] INFO - Using user specified docker base image: `bentoml/model-server:0.12.1-py38-gpu`, usermust make sure that the base image either has Python 3.8 or conda installed.\n", "[2021-06-04 09:59:52,782] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. 
To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n", " * Serving Flask app 'PytorchService' (lazy loading)\n", " * Environment: production\n", "\u001b[31m WARNING: This is a development server. Do not use it in a production deployment.\u001b[0m\n", "\u001b[2m Use a production WSGI server instead.\u001b[0m\n", " * Debug mode: off\n", "INFO:werkzeug: * Running on http://127.0.0.1:34823/ (Press CTRL+C to quit)\n", "INFO:werkzeug:127.0.0.1 - - [04/Jun/2021 09:59:53] \"GET / HTTP/1.1\" 200 -\n", "INFO:werkzeug:127.0.0.1 - - [04/Jun/2021 09:59:53] \"\u001b[36mGET /static_content/main.css HTTP/1.1\u001b[0m\" 304 -\n", "INFO:werkzeug:127.0.0.1 - - [04/Jun/2021 09:59:53] \"\u001b[36mGET /static_content/readme.css HTTP/1.1\u001b[0m\" 304 -\n", "INFO:werkzeug:127.0.0.1 - - [04/Jun/2021 09:59:53] \"\u001b[36mGET /static_content/swagger-ui-bundle.js HTTP/1.1\u001b[0m\" 304 -\n", "INFO:werkzeug:127.0.0.1 - - [04/Jun/2021 09:59:53] \"\u001b[36mGET /static_content/swagger-ui.css HTTP/1.1\u001b[0m\" 304 -\n", "INFO:werkzeug:127.0.0.1 - - [04/Jun/2021 09:59:53] \"\u001b[36mGET /static_content/marked.min.js HTTP/1.1\u001b[0m\" 304 -\n", "INFO:werkzeug:127.0.0.1 - - [04/Jun/2021 09:59:54] \"GET /docs.json HTTP/1.1\" 200 -\n", "Setup pipeline...\n", "[2021-06-04 10:00:16,167] INFO - {'service_name': 'PytorchService', 'service_version': '20210604095940_7515CC', 'api': 'predict', 'task': {'data': '{\"text\":\"WASHINGTON — President Biden offered a series of concessions to try to secure a $1 trillion infrastructure deal with Senate Republicans in an Oval Office meeting this week, narrowing both his spending and tax proposals as negotiations barreled into the final days of what could be an improbable agreement or a blame game that escalates quickly.\"}', 'task_id': 'a8019712-49a5-4163-96a8-d0cd12286f97', 
'http_headers': (('Host', 'localhost:5000'), ('User-Agent', 'Mozilla/5.0 (X11; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0'), ('Accept', '*/*'), ('Accept-Language', 'en-US,en;q=0.5'), ('Accept-Encoding', 'gzip, deflate'), ('Referer', 'http://localhost:5000/'), ('Content-Type', 'application/json'), ('Origin', 'http://localhost:5000'), ('Content-Length', '357'), ('Dnt', '1'), ('Connection', 'keep-alive'), ('Cookie', 'username-localhost-8888=\"2|1:0|10:1622774823|23:username-localhost-8888|44:NzFiOWVlZjA2YmI3NGI0NmJlMmExNDc3YTU0MDE1MGM=|cb3ea22ef6f3140a1224188b484f48ed62e14b0d50082b1a236251c3a980df86\"; _xsrf=2|33d87053|2d1249d56e8ad5d63c884ea9a243cc2a|1622740800'))}, 'result': {'data': '{\"categories\": \"Business\"}', 'http_status': 200, 'http_headers': (('Content-Type', 'application/json'),)}, 'request_id': 'a8019712-49a5-4163-96a8-d0cd12286f97'}\n", "INFO:werkzeug:127.0.0.1 - - [04/Jun/2021 10:00:16] \"POST /predict HTTP/1.1\" 200 -\n", "^C\n" ] } ], "source": [ "!bentoml serve PytorchService:latest" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check that the `BentoService` is running on the GPU: the API server's Python process should appear in the `nvidia-smi` process list." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fri Jun 4 10:01:07 2021 \r\n", "+-----------------------------------------------------------------------------+\r\n", "| NVIDIA-SMI 465.31 Driver Version: 465.31 CUDA Version: 11.3 |\r\n", "|-------------------------------+----------------------+----------------------+\r\n", "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n", "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\r\n", "| | | MIG M. |\r\n", "|===============================+======================+======================|\r\n", "| 0 NVIDIA GeForce ... 
Off | 00000000:01:00.0 Off | N/A |\r\n", "| N/A 71C P8 7W / N/A | 849MiB / 6078MiB | 0% Default |\r\n", "| | | N/A |\r\n", "+-------------------------------+----------------------+----------------------+\r\n", " \r\n", "+-----------------------------------------------------------------------------+\r\n", "| Processes: |\r\n", "| GPU GI CI PID Type Process name GPU Memory |\r\n", "| ID ID Usage |\r\n", "|=============================================================================|\r\n", "| 0 N/A N/A 1191 G /usr/lib/Xorg 4MiB |\r\n", "| 0 N/A N/A 405631 C ...sions/3.8.8/bin/python3.8 841MiB |\r\n", "+-----------------------------------------------------------------------------+\r\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you are running this notebook in Google Colab, start the dev server with the `--run-with-ngrok` option to access the API endpoint through a public URL managed by [ngrok](https://ngrok.com/):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!bentoml serve PytorchService:latest --run-with-ngrok" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Containerizing our model server with Docker\n", "\n", "A common way to distribute this model API server for production deployment is via Docker containers, and BentoML provides a convenient way to build them.\n", "\n", "Note that Docker is not available in Google Colab. 
You will need to download and run this notebook locally to try out the Docker containerization feature.\n", "\n", "If you already have Docker configured, run the following command to build a Docker image that serves the GPU-enabled PytorchService created above:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-06-04 10:08:54,054] INFO - Getting latest version PytorchService:20210604095940_7515CC\n", "\u001b[39mFound Bento: /home/aarnphm/bentoml/repository/PytorchService/20210604095940_7515CC\u001b[0m\n", "[2021-06-04 10:08:54,079] WARNING - Using BentoML installed in `editable` model, the local BentoML repository including all code changes will be packaged together with saved bundle created, under the './bundled_pip_dependencies' directory of the saved bundle.\n", "[2021-06-04 10:08:54,094] WARNING - Saved BentoService bundle version mismatch: loading BentoService bundle create with BentoML version 0.12.1, but loading from BentoML version 0.12.1+52.g55c7bfb\n", "Containerizing PytorchService:20210604095940_7515CC with local YataiService and docker daemon from local environment|^C\n", "\b \r" ] } ], "source": [ "!bentoml containerize PytorchService:latest -t pytorch-service-gpu:latest" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021-06-04 03:48:28,799] INFO - Starting BentoML proxy in production mode..\n", "[2021-06-04 03:48:28,801] INFO - Starting BentoML API server in production mode..\n", "[2021-06-04 03:48:28,833] INFO - Running micro batch service on :5000\n", "[2021-06-04 03:48:28 +0000] [8] [INFO] Starting gunicorn 20.1.0\n", "[2021-06-04 03:48:28 +0000] [8] [INFO] Listening at: http://0.0.0.0:52545 (8)\n", "[2021-06-04 03:48:28 +0000] [8] [INFO] Using worker: sync\n", "[2021-06-04 03:48:28 +0000] [9] [INFO] Booting worker with pid: 9\n", 
"[2021-06-04 03:48:28,884] WARNING - Using BentoML not from official PyPI release. In order to find the same version of BentoML when deploying your BentoService, you must set the 'core/bentoml_deploy_version' config to a http/git location of your BentoML fork, e.g.: 'bentoml_deploy_version = git+https://github.com/{username}/bentoml.git@{branch}'\n", "[2021-06-04 03:48:28,914] WARNING - Saved BentoService bundle version mismatch: loading BentoService bundle create with BentoML version 0.12.1, but loading from BentoML version 0.12.1+52.g55c7bfb\n", "[2021-06-04 03:48:28 +0000] [1] [INFO] Starting gunicorn 20.1.0\n", "[2021-06-04 03:48:29 +0000] [1] [INFO] Listening at: http://0.0.0.0:5000 (1)\n", "[2021-06-04 03:48:29 +0000] [1] [INFO] Using worker: aiohttp.worker.GunicornWebWorker\n", "[2021-06-04 03:48:29 +0000] [10] [INFO] Booting worker with pid: 10\n", "[2021-06-04 03:48:29,012] WARNING - Using BentoML not from official PyPI release. In order to find the same version of BentoML when deploying your BentoService, you must set the 'core/bentoml_deploy_version' config to a http/git location of your BentoML fork, e.g.: 'bentoml_deploy_version = git+https://github.com/{username}/bentoml.git@{branch}'\n", "[2021-06-04 03:48:29,044] WARNING - Saved BentoService bundle version mismatch: loading BentoService bundle create with BentoML version 0.12.1, but loading from BentoML version 0.12.1+52.g55c7bfb\n", "[2021-06-04 03:48:29,047] INFO - Your system nofile limit is 1048576, which means each instance of microbatch service is able to hold this number of connections at same time. 
You can increase the number of file descriptors for the server process, or launch more microbatch instances to accept more concurrent connection.\n", "[2021-06-04 03:48:29,801] INFO - Using user specified docker base image: `bentoml/model-server:0.12.1-py38-gpu`, usermust make sure that the base image either has Python 3.8 or conda installed.\n", "[2021-06-04 03:48:32,065] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n", "Setup pipeline...\n", "[2021-06-04 03:48:54,455] INFO - {'service_name': 'PytorchService', 'service_version': '20210604095940_7515CC', 'api': 'predict', 'task': {'data': '{\"text\":\"WASHINGTON — President Biden offered a series of concessions to try to secure a $1 trillion infrastructure deal with Senate Republicans in an Oval Office meeting this week, narrowing both his spending and tax proposals as negotiations barreled into the final days of what could be an improbable agreement or a blame game that escalates quickly.\"}', 'task_id': 'e723d038-0bb3-41be-912d-5d605192918e', 'http_headers': (('Host', 'localhost:5000'), ('User-Agent', 'Mozilla/5.0 (X11; Linux x86_64; rv:88.0) Gecko/20100101 Firefox/88.0'), ('Accept', '*/*'), ('Accept-Language', 'en-US,en;q=0.5'), ('Accept-Encoding', 'gzip, deflate'), ('Referer', 'http://localhost:5000/'), ('Content-Type', 'application/json'), ('Origin', 'http://localhost:5000'), ('Content-Length', '357'), ('Dnt', '1'), ('Connection', 'keep-alive'), ('Cookie', 'username-localhost-8888=\"2|1:0|10:1622774823|23:username-localhost-8888|44:NzFiOWVlZjA2YmI3NGI0NmJlMmExNDc3YTU0MDE1MGM=|cb3ea22ef6f3140a1224188b484f48ed62e14b0d50082b1a236251c3a980df86\"; _xsrf=2|33d87053|2d1249d56e8ad5d63c884ea9a243cc2a|1622740800'))}, 'result': 
{'data': '{\"categories\": \"Business\"}', 'http_status': 200, 'http_headers': (('Content-Type', 'application/json'),)}, 'request_id': 'e723d038-0bb3-41be-912d-5d605192918e'}\n", "^C\n", "[2021-06-04 03:56:05 +0000] [1] [INFO] Handling signal: int\n", "[2021-06-04 03:56:05 +0000] [10] [INFO] Worker exiting (pid: 10)\n", "[2021-06-04 03:56:05 +0000] [8] [INFO] Handling signal: term\n", "[2021-06-04 03:56:05 +0000] [9] [INFO] Worker exiting (pid: 9)\n" ] } ], "source": [ "!docker run --gpus all --device /dev/nvidia0 --device /dev/nvidiactl --device /dev/nvidia-modeset --device /dev/nvidia-uvm --device /dev/nvidia-uvm-tools -p 5000:5000 pytorch-service-gpu" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deployment Options\n", "\n", "If you are on a small team with limited engineering or DevOps resources, try out automated deployment with the BentoML CLI, which currently supports AWS Lambda, AWS SageMaker, and Azure Functions:\n", "- [AWS Lambda Deployment Guide](https://docs.bentoml.org/en/latest/deployment/aws_lambda.html)\n", "- [AWS SageMaker Deployment Guide](https://docs.bentoml.org/en/latest/deployment/aws_sagemaker.html)\n", "- [Azure Functions Deployment Guide](https://docs.bentoml.org/en/latest/deployment/azure_functions.html)\n", "\n", "If the cloud platform you are working with is not on the list above, try out these step-by-step guides on manually deploying a BentoML-packaged model to cloud platforms:\n", "- [AWS ECS Deployment](https://docs.bentoml.org/en/latest/deployment/aws_ecs.html)\n", "- [Google Cloud Run Deployment](https://docs.bentoml.org/en/latest/deployment/google_cloud_run.html)\n", "- [Azure Container Instance Deployment](https://docs.bentoml.org/en/latest/deployment/azure_container_instance.html)\n", "- [Heroku Deployment](https://docs.bentoml.org/en/latest/deployment/heroku.html)\n", "\n", "Lastly, if you have a DevOps or ML Engineering team that operates a Kubernetes or OpenShift cluster, use the following guides as references for 
implementing your deployment strategy:\n", "- [Kubernetes Deployment](https://docs.bentoml.org/en/latest/deployment/kubernetes.html)\n", "- [Knative Deployment](https://docs.bentoml.org/en/latest/deployment/knative.html)\n", "- [Kubeflow Deployment](https://docs.bentoml.org/en/latest/deployment/kubeflow.html)\n", "- [KFServing Deployment](https://docs.bentoml.org/en/latest/deployment/kfserving.html)\n", "- [Clipper.ai Deployment Guide](https://docs.bentoml.org/en/latest/deployment/clipper.html)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 1 }