{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 4: HG Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import keras\n",
    "from keras.utils import to_categorical\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "import tensorflow as tf\n",
    "\n",
    "tf.random.set_seed(seed)\n",
    "\n",
    "os.environ['PATH'] = os.environ['XILINX_VITIS'] + '/bin:' + os.environ['PATH']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fetch the jet tagging dataset from Open ML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you haven't finished part 1 already, uncomment the following lines to download, process, and save the dataset\n",
    "\n",
    "# le = LabelEncoder()\n",
    "# y = le.fit_transform(y)\n",
    "# y = to_categorical(y, 5)\n",
    "# X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "# # print(y[:5])\n",
    "# scaler = StandardScaler()\n",
    "# X_train_val = scaler.fit_transform(X_train_val)\n",
    "# X_test = scaler.transform(X_test)\n",
    "# np.save('X_train_val.npy', X_train_val)\n",
    "# np.save('X_test.npy', X_test)\n",
    "# np.save('y_train_val.npy', y_train_val)\n",
    "# np.save('y_test.npy', y_test)\n",
    "# np.save('classes.npy', le.classes_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_val = np.load('X_train_val.npy')\n",
    "X_test = np.load('X_test.npy')\n",
    "y_train_val = np.load('y_train_val.npy')\n",
    "y_test = np.load('y_test.npy')\n",
    "classes = np.load('classes.npy', allow_pickle=True)\n",
    "\n",
    "# Convert everything to tf.Tensor to avoid casting\n",
    "with tf.device('/cpu:0'):  # type: ignore\n",
    "    _X_train_val = tf.convert_to_tensor(X_train_val, dtype=tf.float32)\n",
    "    # We don't make explicit y categorical tensor:\n",
    "    # Use SparseCategoricalCrossentropy loss instead.\n",
    "    _y_train_val = tf.convert_to_tensor(np.argmax(y_train_val, axis=1), dtype=tf.int32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct a model\n",
    "This time we're going to use HGQ layers.\n",
    "\n",
    "HGQ is \"High Granularity Quantization\" for heterogeneous quantization at arbitrary granularity, up to per-weight and per-activation level.\n",
    "\n",
    "https://github.com/calad0i/HGQ\n",
    "\n",
    "Depending on the specific task, HGQ can achieve more than 10x resource savings comparing to QKeras. (For example, on this dataset and requiring an accuracy of around 0.72~0.74)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.optimizers import Adam\n",
    "from keras.losses import SparseCategoricalCrossentropy\n",
    "from HGQ.layers import HQuantize, HDense, HActivation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For any layer that needs to be quantized (i.e., layers that perform the actual computation), add a `H` in front of the layer name. For example, `HDense`, `HConv2D`, `HActivation`, etc.\n",
    "\n",
    "HGQ requires the input number to be quantized. To achieve it, you can simply add a `HQuantizer` layer at the beginning of the model. You may refer to https://calad0i.github.io/HGQ/ for full documentation.\n",
    "\n",
    "As all quantization bitwidths are learnt, you don't need to specify them. Instead, for each `H-` layer, you need to specify the `beta` parameter that controls the trade-off between accuracy and resource savings. The higher the `beta`, the more aggressive the quantization will be."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = 3e-6\n",
    "# The bigger the beta, the smaller the models is, at the cost of accuracy.\n",
    "\n",
    "model = Sequential(\n",
    "    [\n",
    "        HQuantize(beta=beta),\n",
    "        HDense(64, activation='relu', beta=beta),\n",
    "        HDense(32, activation='relu', beta=beta),\n",
    "        HDense(32, activation='relu', beta=beta),\n",
    "        HDense(5, beta=beta),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train sparse\n",
    "\n",
    "No need to do anything. Unstructured sparsity comes for free with HGQ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is a empty code cell, you don't need to put anything here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model\n",
    "We'll use the same settings as the model for part 1: Adam optimizer with categorical crossentropy loss.\n",
    "\n",
    "However, we can skip the softmax layer in the model by adding `from_logits=True` to the loss function. `Softmax` is expensive in hardware, so we want to avoid it if possible.\n",
    "\n",
    "For any HGQ model, it's essential to use `ResetMinMax` callback to reset the quantization ranges after each epoch. This is because the ranges are calculated based on the data seen so far, and we want to make sure they are recalculated after each epoch.\n",
    "\n",
    "It is recommended to use the `FreeBOPs` callback to monitor the number of (effective) bits operations in the model. This is a good proxy for ressource usage in FPGA (BOPs ~ 55*DSPs+LUTs) for **post place&route resource**. Notice that CSynth tends to overestimate at least by a factor of 2."
   ]
  },
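  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the rule of thumb above (the numbers below are hypothetical, not taken from this model):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# BOPs ~ 55*DSPs + LUTs (post place & route), so with hypothetical post-route numbers:\n",
    "dsps, luts = 10, 5950\n",
    "print(f'resource-equivalent BOPs: {55 * dsps + luts}')  # -> 6500"
   ]
  },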
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from HGQ import ResetMinMax, FreeBOPs\n",
    "from keras.callbacks import LearningRateScheduler\n",
    "from keras.experimental import CosineDecay\n",
    "from nn_utils import PBarCallback\n",
    "\n",
    "_sched = CosineDecay(2e-2, 200)\n",
    "sched = LearningRateScheduler(_sched)\n",
    "pbar = PBarCallback(metric='loss: {loss:.3f}/{val_loss:.3f} - acc: {accuracy:.3f}/{val_accuracy:.3f}')\n",
    "\n",
    "callbacks = [ResetMinMax(), FreeBOPs(), pbar, sched]\n",
    "\n",
    "# ResetMinMax: necessary callback for all HGQ models\n",
    "# FreeBOPs: recommended callback\n",
    "# pbar: progress bar callback, useful when the number of epochs is high\n",
    "# sched: learning rate scheduler. Cosine decay in this case."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notice\n",
    "\n",
    "- Due to the stochasticness of surrogate gradient on the individual bitwidth, it is recommended to train the model with a large batchsize over more epochs.\n",
    "\n",
    "- HGQ is jit-compiled for many parts. The first epoch will take longer to compile.\n",
    "\n",
    "- We train for 200 epochs here, which takes ~1min on a 3070-maxq GPU, similar to the time taken part 4.\n",
    "\n",
    "- Parameters used in this tutorial are not optimized for the best performance. Please refer to [HGQ-demos](https://github.com/calad0i/HGQ-demos) for more advanced examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = True\n",
    "if train:\n",
    "    opt = Adam(learning_rate=0)\n",
    "    loss = SparseCategoricalCrossentropy(from_logits=True)\n",
    "    model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])\n",
    "\n",
    "    model.fit(\n",
    "        _X_train_val,\n",
    "        _y_train_val,\n",
    "        batch_size=16384,\n",
    "        epochs=200,\n",
    "        validation_split=0.25,\n",
    "        shuffle=True,\n",
    "        callbacks=callbacks,\n",
    "        verbose=0,  # type: ignore\n",
    "    )\n",
    "    model.save('model_3.1/model.h5')\n",
    "else:\n",
    "    from keras.models import load_model\n",
    "\n",
    "    # No need to use custom_objects as the custom layers are already registered\n",
    "    model: keras.Model = load_model('model_3.1/model.h5')  # type: ignore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare for conversion\n",
    "\n",
    "HGQ model cannot be converted to hls4ml model directly, and we need to convert it to a proxy model first. The proxy model also serves as a bit-accurate emulator of the hls4ml model that takes numerical overflow into account.\n",
    "\n",
    "To convert to a proxy model, we need to set appropriate ranges of the model internal variables. This is done by using the `trace_minmax` function. You can add a scaler factor `cover_range` to the ranges to make sure the model more robust to numerical overflow. `trace_minmax` also resturns the exact (effective) BOPs of the model (the number provided during training is approximated).\n",
    "\n",
    "If you keep all parameters the same and everything goes correctly, total BOPs of the model should be around 6500. This means, after running place&route (or vsynth), the model should take around 6500 LUTs, which means DSPs*55+LUTs used should be around 6500."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from HGQ import trace_minmax, to_proxy_model\n",
    "\n",
    "trace_minmax(model, X_train_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check that the model is indeed sparse without explicit pruning or `l1` regularization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for layer in model.layers:\n",
    "    if layer._has_kernel:\n",
    "        k = layer.fused_qkernel.numpy()\n",
    "        print(f'{layer.name}: {np.mean(k==0):.2%} sparsity')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, convert the model to a proxy model using the `to_proxy_model` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "proxy = to_proxy_model(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hls4ml\n",
    "import plotting\n",
    "\n",
    "hls_model = hls4ml.converters.convert_from_keras_model(\n",
    "    proxy, output_dir='model_3.1/hls4ml_prj', part='xcu250-figd2104-2L-e', backend='Vitis'\n",
    ")\n",
    "hls_model.compile()\n",
    "\n",
    "X_test = np.ascontiguousarray(X_test)\n",
    "y_keras = model.predict(X_test, batch_size=16384, verbose=0)\n",
    "y_proxy = proxy.predict(X_test, batch_size=16384, verbose=0)\n",
    "y_hls = hls_model.predict(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check bit-accuracy\n",
    "If you are unlucky, `y_keras` and `y_hls` will not fully match due to numerical overflow (for a few entries). However, `y_keras` and `y_proxy` should match perfectly. (Sometime mismatch could also happen - only due to machine precision limit.\n",
    "\n",
    "For newer nvidia GPUs, TF32 is enabled by default (fp32 with reduced mantissa bits), which could cause this issue). This will make this issue more prevalent."
   ]
  },
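  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To rule out TF32 as a source of mismatch on such GPUs, you can force full-precision fp32 matmuls with TensorFlow's public toggle and re-run the prediction cell above. This has no effect on CPUs or pre-Ampere GPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disable TF32 so float32 ops keep the full 23-bit mantissa (Ampere+ GPUs enable TF32 by default).\n",
    "tf.config.experimental.enable_tensor_float_32_execution(False)"
   ]
  },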
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.mean(y_keras == y_hls), np.mean(y_proxy == y_hls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The plotting script assumes 0-1 range for the predictions.\n",
    "y_keras_softmax = tf.nn.softmax(y_keras).numpy()\n",
    "y_hls_softmax = tf.nn.softmax(y_hls).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from sklearn.metrics import accuracy_score\n",
    "from keras.models import load_model\n",
    "\n",
    "model_ref = load_model('model_1/KERAS_check_best_model.h5')\n",
    "y_ref = model_ref.predict(X_test, batch_size=1024, verbose=0)\n",
    "\n",
    "print(\"Accuracy baseline:  {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_ref, axis=1))))\n",
    "print(\"Accuracy pruned, quantized: {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_keras, axis=1))))\n",
    "print(\"Accuracy hls4ml: {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_hls, axis=1))))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 9))\n",
    "_ = plotting.makeRoc(y_test, y_ref, classes)\n",
    "plt.gca().set_prop_cycle(None)  # reset the colors\n",
    "_ = plotting.makeRoc(y_test, y_keras_softmax, classes, linestyle='--')\n",
    "plt.gca().set_prop_cycle(None)  # reset the colors\n",
    "_ = plotting.makeRoc(y_test, y_hls_softmax, classes, linestyle=':')\n",
    "\n",
    "from matplotlib.lines import Line2D\n",
    "\n",
    "lines = [Line2D([0], [0], ls='-'), Line2D([0], [0], ls='--'), Line2D([0], [0], ls=':')]\n",
    "from matplotlib.legend import Legend\n",
    "\n",
    "leg = Legend(ax, lines, labels=['baseline', 'pruned, quantized', 'hls4ml'], loc='lower right', frameon=False)\n",
    "ax.add_artist(leg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthesize\n",
    "Now let's synthesize this quantized, pruned model.\n",
    "\n",
    "**The synthesis will take a while**\n",
    "\n",
    "While the C-Synthesis is running, we can monitor the progress looking at the log file by opening a terminal from the notebook home, and executing:\n",
    "\n",
    "`tail -f model_3.1/hls4ml_prj/vitis_hls.log`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hls_model.build(csim=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check the reports\n",
    "Print out the reports generated by Vitis HLS. Pay attention to the Utilization Estimates' section in particular this time.\n",
    "\n",
    "## Notice\n",
    "We strip away the softmax layer compare to part 4, which takes 3~5 cycles to compute. The overall latency could be comparable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hls4ml.report.read_vivado_report('model_3.1/hls4ml_prj')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Print the report for the model trained in part 4. You should notice that the resource usage is significantly lower than the model trained in part 4.\n",
    "\n",
    "**Note you need to have trained and synthesized the model from part 4**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hls4ml.report.read_vivado_report('model_3/hls4ml_prj')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NB\n",
    "Note as well that the Vitis HLS `csynth` resource estimates tend to _overestimate_ on chip resource usage. Running the subsequent stages of FPGA compilation reveals the more realistic resource usage, You can run the next step, 'logic synthesis' with `hls_model.build(synth=True, vsynth=True)`, but we skipped it in this tutorial in the interest of time."
   ]
  },
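  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Commented sketch of the logic-synthesis step described above (this is slow):\n",
    "# hls_model.build(csim=False, synth=True, vsynth=True)\n",
    "# hls4ml.report.read_vivado_report('model_3.1/hls4ml_prj')"
   ]
  },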
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}