{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimize PyTorch Models using IntelĀ® Extension for PyTorch (IPEX) Quantization\n",
    "This code sample will quantize a ResNet50 model while using Intel's Extension for PyTorch (IPEX). The model will run inference with FP32 and INT8 precision, including static INT8 quantization and dynamic INT8 quantization. During Static Quantization, the model calibrated with the CIFAR10 dataset. The inference time will be compared, showcasing the speedup of INT8 Quantization.\n",
    "\n",
    "## Environment Setup\n",
    "Ensure the PyTorch kernel is activated before running this notebook.\n",
    "\n",
    "## Imports, Dataset, Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from time import time\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import intel_extension_for_pytorch as ipex\n",
    "from intel_extension_for_pytorch.quantization import prepare, convert\n"
   ]
  },
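  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The optional cell below is a quick environment check added for convenience (not part of the benchmark itself): it prints the PyTorch and IPEX versions to confirm the imports resolved in the active kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional environment check: print the versions of the imported packages\n",
    "print(\"PyTorch version:\", torch.__version__)\n",
    "print(\"IPEX version:\", ipex.__version__)"
   ]
  },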
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters and constants\n",
    "LR = 0.001\n",
    "DOWNLOAD = True\n",
    "DATA = 'datasets/cifar10/'\n",
    "WARMUP = 3\n",
    "ITERS = 100\n",
    "transform = torchvision.transforms.Compose([\n",
    "torchvision.transforms.Resize((224, 224)),\n",
    "torchvision.transforms.ToTensor(),\n",
    "torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "])\n",
    "test_dataset = torchvision.datasets.CIFAR10(\n",
    "        root=DATA,\n",
    "        train=False,\n",
    "        transform=transform,\n",
    "        download=DOWNLOAD,\n",
    ")\n",
    "calibration_data_loader = torch.utils.data.DataLoader(\n",
    "        dataset=test_dataset,\n",
    "        batch_size=128\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get model from torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = torch.rand(1, 3, 224, 224)\n",
    "model_fp32 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)\n",
    "model_fp32.eval()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference with FP32 model\n",
    "\n",
    "The function below will test the inference time with input model and return the average inference time for 1 iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(model, WARMUP, ITERS, data):\n",
    "    print(\"Warmup before benchmark ...\")\n",
    "    for i in range(WARMUP):\n",
    "        out = model(data)\n",
    "\n",
    "    print(\"Inference ...\")\n",
    "    inference_time = 0\n",
    "    for i in range(ITERS):\n",
    "        start_time = time()\n",
    "        out = model(data)\n",
    "        end_time = time()\n",
    "        inference_time = inference_time + (end_time - start_time)\n",
    "\n",
    "    inference_time = inference_time / ITERS\n",
    "    print(\"Inference Time Avg: \", inference_time)\n",
    "    return inference_time"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Static Quantization \n",
    "The function below staticQuantize will calibrate the fp32 model with calibration dataloader and return the quantized static int8 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def staticQuantize(model_fp32, data, calibration_data_loader):\n",
    "    # Acquire inference times for static quantization INT8 model \n",
    "    qconfig_static = ipex.quantization.default_static_qconfig\n",
    "    # # Alternatively, define your own qconfig:\n",
    "    # from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig\n",
    "    # qconfig = QConfig(activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),\n",
    "    #        weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric))\n",
    "    prepared_model_static = prepare(model_fp32, qconfig_static, example_inputs=data, inplace=False)\n",
    "    print(\"Calibration with Static Quantization ...\")\n",
    "    for batch_idx, (data, target) in enumerate(calibration_data_loader):\n",
    "        prepared_model_static(data)\n",
    "        if batch_idx % 10 == 0:\n",
    "            print(\"Batch %d/%d complete, continue ...\" %(batch_idx+1, len(calibration_data_loader)))\n",
    "    print(\"Calibration Done\")\n",
    "\n",
    "    converted_model_static = convert(prepared_model_static)\n",
    "    with torch.no_grad():\n",
    "        traced_model_static = torch.jit.trace(converted_model_static, data)\n",
    "        traced_model_static = torch.jit.freeze(traced_model_static)\n",
    "\n",
    "    # save the quantized static model \n",
    "    traced_model_static.save(\"quantized_model_static.pt\")\n",
    "    return traced_model_static\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dynamic Quantization \n",
    "The function below dynamicQuantize will quantize the fp32 model with dynamic quantization and return the quantized dynamic int8 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dynamicQuantize(model_fp32, data):\n",
    "    # Acquire inference times for dynamic quantization INT8 model\n",
    "    qconfig_dynamic = ipex.quantization.default_dynamic_qconfig\n",
    "    print(\"Quantize Model with Dynamic Quantization ...\")\n",
    "\n",
    "    prepared_model_dynamic = prepare(model_fp32, qconfig_dynamic, example_inputs=data, inplace=False)\n",
    "\n",
    "    converted_model_dynamic = convert(prepared_model_dynamic)\n",
    "    with torch.no_grad():\n",
    "        traced_model_dynamic = torch.jit.trace(converted_model_dynamic, data)\n",
    "        traced_model_dynamic = torch.jit.freeze(traced_model_dynamic)\n",
    "\n",
    "    # save the quantized dynamic model \n",
    "    traced_model_dynamic.save(\"quantized_model_dynamic.pt\")\n",
    "    return traced_model_dynamic\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantize the FP32 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists('quantized_model_static.pt'):\n",
    "    # Static Quantizaton & Save Model to quantized_model_static.pt\n",
    "    print('quantize the model with static quantization')\n",
    "    staticQuantize(model_fp32, data, calibration_data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists('quantized_model_dynamic.pt'):\n",
    "    # Dynamic Quantization & Save Model to quantized_model_dynamic.pt\n",
    "    print('quantize the model with dynamic quantization')\n",
    "    dynamicQuantize(model_fp32, data)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference With FP32 Model, Static INT8 Model and Dynamic INT8 Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Inference with FP32\")\n",
    "fp32_inference_time = inference(model_fp32, WARMUP, ITERS, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Inference with Static INT8\")\n",
    "traced_model_static = torch.jit.load('quantized_model_static.pt')\n",
    "traced_model_static.eval()\n",
    "traced_model_static = torch.jit.freeze(traced_model_static)\n",
    "int8_inference_time_static = inference(traced_model_static, WARMUP, ITERS, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Inference with Dynamic INT8\")\n",
    "traced_model_dynamic = torch.jit.load('quantized_model_dynamic.pt')\n",
    "traced_model_dynamic.eval()\n",
    "traced_model_dynamic = torch.jit.freeze(traced_model_dynamic)\n",
    "int8_inference_time_dynamic = inference(traced_model_dynamic, WARMUP, ITERS, data)"
   ]
  },
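  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before summarizing the timings, the optional cell below is a small sanity check (a sketch added for illustration, not part of the original benchmark): it compares the FP32 and INT8 outputs on the same random input to confirm the quantized models produce similar logits. Since `data` is random noise, this checks numerical closeness only, not accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: compare FP32 and INT8 outputs on the same random input\n",
    "with torch.no_grad():\n",
    "    out_fp32 = model_fp32(data)\n",
    "    out_static = traced_model_static(data)\n",
    "    out_dynamic = traced_model_dynamic(data)\n",
    "print(\"Max |FP32 - static INT8| logit difference: %.4f\" % (out_fp32 - out_static).abs().max().item())\n",
    "print(\"Max |FP32 - dynamic INT8| logit difference: %.4f\" % (out_fp32 - out_dynamic).abs().max().item())"
   ]
  },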
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary of Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference time results\n",
    "print(\"Summary\")\n",
    "print(\"FP32 inference time: %.3f\" %fp32_inference_time)\n",
    "print(\"INT8 static quantization inference time: %.3f\" %int8_inference_time_static)\n",
    "print(\"INT8 dynamic quantization inference time: %.3f\" %int8_inference_time_dynamic)\n",
    "\n",
    "# Create bar chart with training time results\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.title(\"ResNet Inference Time\")\n",
    "plt.xlabel(\"Test Case\")\n",
    "plt.ylabel(\"Inference Time (seconds)\")\n",
    "plt.bar([\"FP32\", \"INT8 static\", \"INT8 dynamic\"], [fp32_inference_time, int8_inference_time_static, int8_inference_time_dynamic])\n",
    "\n",
    "# Calculate speedup when using quantization\n",
    "speedup_from_fp32_static = fp32_inference_time / int8_inference_time_static\n",
    "print(\"Staic INT8 %.2fX faster than FP32\" %speedup_from_fp32_static)\n",
    "speedup_from_fp32_dynamic = fp32_inference_time / int8_inference_time_dynamic\n",
    "print(\"Dynamic INT8 %.2fX faster than FP32\" %speedup_from_fp32_dynamic)\n",
    "\n",
    "\n",
    "# Create bar chart with speedup results\n",
    "plt.figure(figsize=(4,3))\n",
    "plt.title(\"Quantization Speedup\")\n",
    "plt.xlabel(\"Test Case\")\n",
    "plt.ylabel(\"Speedup\")\n",
    "plt.bar([\"FP32\",\"Static INT8\", \"Dynamic INT8\"], [1, speedup_from_fp32_static, speedup_from_fp32_dynamic])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('[CODE_SAMPLE_COMPLETED_SUCCESFULLY]')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  },
  "vscode": {
   "interpreter": {
    "hash": "4678fb2792a22465205165c52aab2f7cff7494375a364749bf16e0ac11f2a502"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}