{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Deploy Quantized Models using Torch-TensorRT\n\nHere we demonstrate how to deploy a model quantized to INT8 or FP8 using the Dynamo frontend of Torch-TensorRT\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports and Model Definition\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import argparse\n\nimport modelopt.torch.quantization as mtq\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch_tensorrt as torchtrt\nimport torchvision.datasets as datasets\nimport torchvision.transforms as transforms\nfrom modelopt.torch.quantization.utils import export_torch_mode\n\n\nclass VGG(nn.Module):\n def __init__(self, layer_spec, num_classes=1000, init_weights=False):\n super(VGG, self).__init__()\n\n layers = []\n in_channels = 3\n for l in layer_spec:\n if l == \"pool\":\n layers.append(nn.MaxPool2d(kernel_size=2, stride=2))\n else:\n layers += [\n nn.Conv2d(in_channels, l, kernel_size=3, padding=1),\n nn.BatchNorm2d(l),\n nn.ReLU(),\n ]\n in_channels = l\n\n self.features = nn.Sequential(*layers)\n self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n self.classifier = nn.Sequential(\n nn.Linear(512 * 1 * 1, 4096),\n nn.ReLU(),\n nn.Dropout(),\n nn.Linear(4096, 4096),\n nn.ReLU(),\n nn.Dropout(),\n nn.Linear(4096, num_classes),\n )\n if init_weights:\n self._initialize_weights()\n\n def _initialize_weights(self):\n for m in self.modules():\n if isinstance(m, nn.Conv2d):\n nn.init.kaiming_normal_(m.weight, mode=\"fan_out\", nonlinearity=\"relu\")\n if m.bias is not None:\n nn.init.constant_(m.bias, 0)\n elif isinstance(m, nn.BatchNorm2d):\n nn.init.constant_(m.weight, 1)\n nn.init.constant_(m.bias, 0)\n elif isinstance(m, nn.Linear):\n nn.init.normal_(m.weight, 0, 0.01)\n nn.init.constant_(m.bias, 0)\n\n def forward(self, x):\n x = self.features(x)\n x = self.avgpool(x)\n x = torch.flatten(x, 1)\n x = self.classifier(x)\n return x\n\n\ndef vgg16(num_classes=1000, init_weights=False):\n vgg16_cfg = [\n 64,\n 64,\n \"pool\",\n 128,\n 128,\n \"pool\",\n 256,\n 256,\n 256,\n \"pool\",\n 512,\n 512,\n 512,\n \"pool\",\n 512,\n 512,\n 512,\n \"pool\",\n ]\n return VGG(vgg16_cfg, num_classes, init_weights)\n\n\nPARSER = argparse.ArgumentParser(\n description=\"Load pre-trained VGG model and then tune with FP8 and PTQ. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Load the pre-trained model weights\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ckpt = torch.load(args.ckpt)\nweights = ckpt[\"model_state_dict\"]\n\nif torch.cuda.device_count() > 1:\n    # Checkpoints saved from an nn.DataParallel model prefix every key with\n    # `module.`; strip the prefix so the keys match this single-module model.\n    from collections import OrderedDict\n\n    new_state_dict = OrderedDict()\n    for k, v in weights.items():\n        name = k[7:]  # remove the `module.` prefix\n        new_state_dict[name] = v\n    weights = new_state_dict\n\nmodel.load_state_dict(weights)\n# Don't forget to set the model to evaluation mode!\nmodel.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the training dataset and define the loss function for PTQ\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "training_dataset = datasets.CIFAR10(\n    root=\"./data\",\n    train=True,\n    download=True,\n    transform=transforms.Compose(\n        [\n            transforms.RandomCrop(32, padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n        ]\n    ),\n)\ntraining_dataloader = torch.utils.data.DataLoader(\n    training_dataset,\n    batch_size=args.batch_size,\n    shuffle=True,\n    num_workers=2,\n    drop_last=True,\n)\n\ndata = iter(training_dataloader)\nimages, _ = next(data)\n\ncrit = nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define the Calibration Loop for Quantization\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def calibrate_loop(model):\n    # Calibrate over the training dataset; PTQ only needs forward passes,\n    # so gradients are disabled to save memory.\n    total = 0\n    correct = 0\n    loss = 0.0\n    with torch.no_grad():\n        for data, labels in training_dataloader:\n            data, labels = data.cuda(), labels.cuda(non_blocking=True)\n            out = model(data)\n            loss += crit(out, labels).item()\n            preds = torch.max(out, 1)[1]\n            total += labels.size(0)\n            correct += (preds == labels).sum().item()\n\n    print(\n        \"PTQ Loss: {:.5f} Acc: {:.2f}%\".format(\n            loss / len(training_dataloader), 100 * correct / total\n        )\n    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Quantize the pre-trained model with PTQ\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if args.quantize_type == \"int8\":\n    quant_cfg = mtq.INT8_DEFAULT_CFG\nelif args.quantize_type == \"fp8\":\n    quant_cfg = mtq.FP8_DEFAULT_CFG\nelse:\n    raise ValueError(f\"Unsupported quantize type: {args.quantize_type}\")\n# PTQ with in-place replacement of modules by their quantized counterparts\nmtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)\n# The model now contains INT8/FP8 Q/DQ (quantize/dequantize) nodes" ] },
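{ "cell_type": "markdown", "metadata": {}, "source": [ "Before compiling, it can help to confirm that calibration actually inserted quantizers. The optional diagnostic below is not part of the original flow: `mtq.print_quant_summary` is assumed to be available in your ModelOpt version, and the name-based module count is only a heuristic.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Optional diagnostic (a sketch): report which layers were wrapped with\n# quantizers and their calibrated ranges.\nmtq.print_quant_summary(model)\n\n# Heuristic count of inserted quantizer modules, matched by class name to\n# avoid depending on a specific modelopt import path.\nn_quantizers = sum(\n    1 for m in model.modules() if type(m).__name__ == \"TensorQuantizer\"\n)\nprint(f\"{n_quantizers} quantizer modules inserted\")" ] },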
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Inference\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Load the test dataset\ntesting_dataset = datasets.CIFAR10(\n    root=\"./data\",\n    train=False,\n    download=True,\n    transform=transforms.Compose(\n        [\n            transforms.ToTensor(),\n            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n        ]\n    ),\n)\n\ntesting_dataloader = torch.utils.data.DataLoader(\n    testing_dataset,\n    batch_size=args.batch_size,\n    shuffle=False,\n    num_workers=2,\n    drop_last=True,\n)  # drop_last=True drops the last incomplete batch, since static-shape `torchtrt.dynamo.compile()` expects a fixed batch size\n\nwith torch.no_grad():\n    with export_torch_mode():\n        # Compile the model with the Torch-TensorRT Dynamo backend\n        input_tensor = images.cuda()\n\n        exp_program = torch.export.export(model, (input_tensor,), strict=False)\n        if args.quantize_type == \"int8\":\n            enabled_precisions = {torch.int8}\n        elif args.quantize_type == \"fp8\":\n            enabled_precisions = {torch.float8_e4m3fn}\n        trt_model = torchtrt.dynamo.compile(\n            exp_program,\n            inputs=[input_tensor],\n            enabled_precisions=enabled_precisions,\n            min_block_size=1,\n        )\n        # Alternatively, you can compile via the torch.compile path:\n        # trt_model = torch.compile(model, backend=\"tensorrt\")\n\n        # Run inference with the compiled Torch-TensorRT model over the test dataset\n        total = 0\n        correct = 0\n        loss = 0.0\n        class_probs = []\n        class_preds = []\n        for data, labels in testing_dataloader:\n            data, labels = data.cuda(), labels.cuda(non_blocking=True)\n            out = trt_model(data)\n            loss += crit(out, labels).item()\n            preds = torch.max(out, 1)[1]\n            class_probs.append([F.softmax(i, dim=0) for i in out])\n            class_preds.append(preds)\n            total += labels.size(0)\n            correct += (preds == labels).sum().item()\n\n        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])\n        test_preds = torch.cat(class_preds)\n        test_loss = loss / len(testing_dataloader)\n        test_acc = correct / total\n        print(\"Test Loss: {:.5f} Test Acc: {:.2f}%\".format(test_loss, 100 * test_acc))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 0 }