{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Before You Start\n", "\n", "The current set of notebooks is under constant development.\n", "\n", "## Update Tutorial Repository\n", "\n", "If you have previously cloned the tutorial repository, you may need to get the latest versions of the notebooks.\n", "\n", "First, check the status of your repository:\n", "```\n", "cd hls4ml-tutorial\n", "make clean\n", "git status\n", "```\n", "\n", "You may have some _modified_ notebooks. For example:\n", "\n", "```\n", "# On branch csee-e6868-spring2022\n", "# Changes not staged for commit:\n", "# (use \"git add ...\" to update what will be committed)\n", "# (use \"git checkout -- ...\" to discard changes in working directory)\n", "#\n", "#\tmodified: part1_getting_started.ipynb\n", "#\tmodified: part2_advanced_config.ipynb\n", "#\n", "no changes added to commit (use \"git add\" and/or \"git commit -a\")\n", "```\n", "\n", "If you made significant changes to those notebooks, make a copy of them first; otherwise, the easiest thing to do is to discard the changes.\n", "\n", "**ATTENTION** You will lose your local changes!\n", "\n", "```\n", "git checkout *.ipynb\n", "```\n", "\n", "At this point, you can update your copy of the repository:\n", "```\n", "git pull\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Part 1: Getting started\n", "\n", "## Model Training\n", "\n", "### Setup\n", "\n", "Import packages from [TensorFlow](https://www.tensorflow.org), [scikit-learn](https://scikit-learn.org), and [NumPy](https://numpy.org)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.utils import to_categorical\n", "from sklearn.datasets import fetch_openml\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder, StandardScaler\n", "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use a [magic function](https://ipython.readthedocs.io/en/stable/interactive/tutorial.html#magics-explained) to display matplotlib graphs inline in the notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Force deterministic behaviour with a constant seed. In TensorFlow, `tf.random.set_seed` sets the global random seed; you can also specify operation-level seeds. See the [documentation](https://www.tensorflow.org/api_docs/python/tf/random/set_seed) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "seed = 0\n", "np.random.seed(seed)\n", "import tensorflow as tf\n", "tf.random.set_seed(seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specify where to find the Xilinx Vivado HLS executable. On the Columbia servers, the path is `/opt/xilinx/Vivado/2019.1/bin`; change it if you run this notebook on a machine with its own Vivado HLS installation."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import shutil\n", "\n", "# Prepend the Vivado HLS bin directory to the PATH of this notebook's process\n", "os.environ['PATH'] = '/opt/xilinx/Vivado/2019.1/bin:' + os.environ['PATH']\n", "\n", "def is_tool(name):\n", "    # shutil.which replaces the deprecated distutils.spawn.find_executable\n", "    return shutil.which(name) is not None\n", "\n", "print('-----------------------------------')\n", "if not is_tool('vivado_hls'):\n", "    print('Xilinx Vivado HLS is NOT in the PATH')\n", "else:\n", "    print('Xilinx Vivado HLS is in the PATH')\n", "print('-----------------------------------')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fetch the jet tagging dataset from OpenML\n", "\n", "The [jet tagging dataset](https://www.openml.org/d/42468) is publicly available on [OpenML](https://www.openml.org)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# as_frame=False returns NumPy arrays, which the following cells assume\n", "data = fetch_openml('hls4ml_lhc_jets_hlf', as_frame=False)\n", "X, y = data['data'], data['target']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's print some information about the dataset (e.g., the feature names and the dataset shape)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "print('-----------------------------------')\n", "print('Feature names')\n", "print(data['feature_names'])\n", "print('-----------------------------------')\n", "print('Shape of the data and label (target) arrays')\n", "print(X.shape, y.shape)\n", "print('-----------------------------------')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's print some data and labels."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "print('-----------------------------------')\n", "print('\\nFirst five samples in the data set')\n", "display(pd.DataFrame(data=X[:5]))\n", "\n", "print('\\nFirst five labels (targets) in the data set')\n", "display(pd.DataFrame(data=y[:5]))\n", "print('-----------------------------------')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can visualize the data with [boxplots](https://en.wikipedia.org/wiki/Box_plot) and notice that some features spread over a much wider range of values than others. You can also plot the outliers with `showfliers=True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.boxplot(X, showfliers=False)\n", "_ = plt.xticks(np.arange(1, X.shape[1] + 1), data['feature_names'], rotation=30, ha=\"right\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preprocessing\n", "\n", "As you saw above, the `y` target is an array of strings, e.g. \\['g', 'w', ...\\].\n", "For training, we need to convert it to a _one-hot encoding_." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print('-----------------------------------')\n", "print(y[:5]) # Target labels (strings)\n", "print('-----------------------------------')\n", "le = LabelEncoder()\n", "y = le.fit_transform(y) # Encode the string labels as integers 0..4\n", "print(y[:5])\n", "print('-----------------------------------')\n", "y = to_categorical(y, 5) # Convert those integers to a one-hot encoding\n", "print(y[:5])\n", "print('-----------------------------------')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Split the dataset into training (80% of the samples) and test (20%) sets."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", "print('-----------------------------------')\n", "print('*** Shape of the split arrays ***')\n", "print(X_train_val.shape, X_test.shape, y_train_val.shape, y_test.shape)\n", "print('-----------------------------------')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As before, let's plot the boxplots, this time for the training-validation set only." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.boxplot(X_train_val, showfliers=False)\n", "_ = plt.xticks(np.arange(1, X_train_val.shape[1] + 1), data['feature_names'], rotation=30, ha=\"right\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Preprocess the data with the [StandardScaler](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html), which transforms each feature as\n", "```\n", "Zi = (Xi - u) / s\n", "```\n", "where `u` is the mean and `s` is the standard deviation of that feature over the training samples. After scaling, each feature of the training set has a mean of 0 and a standard deviation of 1. The test set is transformed with the same `u` and `s`, so its mean and standard deviation are only approximately 0 and 1." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "scaler = StandardScaler()\n", "X_train_val = scaler.fit_transform(X_train_val) # Compute u and s on the training set, then scale it\n", "X_test = scaler.transform(X_test) # Reuse the training u and s for the test set" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's plot the boxplots for the training-validation set after standard scaling."
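, "\n", "\n", "Before plotting, we can check numerically that `StandardScaler` applies the formula above. The next cell is a small, self-contained sketch on synthetic data. The random two-feature array and the names `demo_train`, `demo_scaler`, `u`, `s` and `manual` are illustrative and not part of the tutorial. Following the scikit-learn documentation, it assumes that `s` is the population standard deviation (NumPy's default `ddof=0`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "# Illustrative synthetic data: two features with very different scales\n", "rng = np.random.default_rng(0)\n", "demo_train = rng.normal(loc=[10.0, -3.0], scale=[5.0, 0.5], size=(1000, 2))\n", "\n", "demo_scaler = StandardScaler().fit(demo_train)\n", "u = demo_train.mean(axis=0)    # per-feature mean of the training samples\n", "s = demo_train.std(axis=0)     # per-feature standard deviation (ddof=0)\n", "manual = (demo_train - u) / s  # Zi = (Xi - u) / s\n", "\n", "print('Matches StandardScaler:', np.allclose(demo_scaler.transform(demo_train), manual))\n", "print('Mean after scaling:', manual.mean(axis=0).round(6))\n", "print('Std after scaling: ', manual.std(axis=0).round(6))"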
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.boxplot(X_train_val, showfliers=False)\n", "_ = plt.xticks(np.arange(1, X_train_val.shape[1] + 1), data['feature_names'], rotation=30, ha=\"right\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Save the NumPy arrays to files so that this notebook and the next ones can load them without repeating the preprocessing." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.save('X_train_val.npy', X_train_val)\n", "np.save('X_test.npy', X_test)\n", "np.save('y_train_val.npy', y_train_val)\n", "np.save('y_test.npy', y_test)\n", "np.save('classes.npy', le.classes_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Construct the model\n", "\n", "Import additional [Keras](https://keras.io) packages. Keras is a deep learning API written in Python that runs on top of the TensorFlow machine learning platform. It was developed with a focus on enabling fast experimentation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense, Activation, BatchNormalization\n", "from tensorflow.keras.optimizers import Adam\n", "from tensorflow.keras.regularizers import l1\n", "from callbacks import all_callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create a simple _multi-layer perceptron (MLP)_ model. An MLP consists of at least three layers of nodes (an input layer, one or more hidden layers, and an output layer). Here, each dense layer is followed by an activation function.\n", "\n", "- We use the [Sequential API](https://www.tensorflow.org/guide/keras/sequential_model), which is essentially a stack of layers where each layer has exactly one input tensor and one output tensor. See also the [Functional API](https://www.tensorflow.org/guide/keras/functional).\n", "- We use _3_ hidden layers with _64_, then _32_, then _32_ neurons. 
See a plot of the model a few cells below.\n", "- Each hidden layer uses a ReLU activation.\n", "- Add an output layer with _5_ neurons (one for each class), followed by a Softmax activation.\n", "- [Initializers](https://keras.io/api/layers/initializers) define how the initial random weights of Keras layers are set. In this case, we choose [LecunUniform](https://www.tensorflow.org/api_docs/python/tf/keras/initializers/LecunUniform).\n", "- [Regularizers](https://keras.io/api/layers/regularizers) help produce models that generalize to new, unseen data (see the [overfitting problem](https://en.wikipedia.org/wiki/Overfitting)). They apply penalties on layer parameters or layer activity during optimization, and these penalties are added to the loss function that the network optimizes. In this case, we choose [L1 regularization](https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/L1), which produces _sparse models_, i.e., models in which many weights are pushed to (or very close to) zero. The corresponding connections then contribute little or nothing to the predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Sequential()\n", "model.add(Dense(64, input_shape=(16,), name='fc1', kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))\n", "model.add(Activation(activation='relu', name='relu1'))\n", "model.add(Dense(32, name='fc2', kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))\n", "model.add(Activation(activation='relu', name='relu2'))\n", "model.add(Dense(32, name='fc3', kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))\n", "model.add(Activation(activation='relu', name='relu3'))\n", "model.add(Dense(5, name='output', kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))\n", "model.add(Activation(activation='softmax', name='softmax'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the model. 
See this post on [How do you visualize neural network architectures?](https://datascience.stackexchange.com/questions/12851/how-do-you-visualize-neural-network-architectures)\n", "\n", "In the plotted shapes, the question mark `?` (or `None`) stands for the batch size, which is not fixed when the model is defined." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model\n", "If this is the first time you run the notebook, set `train = True`. If you have restarted the notebook kernel after training once, set `train = False` to load the trained model from file." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the model is created, configure it with a loss and metrics using `model.compile()`, then train it with `model.fit()`.\n", "\n", "- We use the [Adam optimizer](https://keras.io/api/optimizers/adam) with the [categorical cross-entropy](https://www.tensorflow.org/api_docs/python/tf/keras/losses/CategoricalCrossentropy) loss.\n", "- We use [callbacks](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback), which are utilities executed at given stages of the training procedure. Callbacks can help you prevent overfitting, visualize training progress, debug your code, save checkpoints, generate logs, etc. Here, the callbacks decay the learning rate and save the model into the `model_1` directory.\n", "- The model isn't very complex, so training should take only a few minutes, even on the CPU."
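, "\n", "\n", "`all_callbacks` comes from the `callbacks` module imported above, which we assume is a helper file in the tutorial repository. As a rough orientation only, the next cell sketches how its arguments *might* map onto standard Keras callbacks (`EarlyStopping`, `ReduceLROnPlateau`, `ModelCheckpoint`). This mapping is our assumption and not the actual implementation of `all_callbacks`, so check `callbacks.py` for the exact behaviour. The cell only builds the callback objects (`keras_callbacks_sketch` is a hypothetical name) and does not train anything." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import tensorflow as tf\n", "\n", "# Hypothetical equivalent of all_callbacks(...) built from standard Keras callbacks\n", "keras_callbacks_sketch = [\n", "    # stop_patience: stop if val_loss does not improve for this many epochs\n", "    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=1000),\n", "    # lr_factor, lr_patience, lr_epsilon, lr_cooldown, lr_minimum (assumed mapping):\n", "    # halve the learning rate when val_loss stops improving\n", "    tf.keras.callbacks.ReduceLROnPlateau(\n", "        monitor='val_loss', factor=0.5, patience=10,\n", "        min_delta=0.000001, cooldown=2, min_lr=0.0000001),\n", "    # outputDir: keep the model with the best validation loss\n", "    tf.keras.callbacks.ModelCheckpoint(\n", "        os.path.join('model_1', 'KERAS_check_best_model.h5'),\n", "        monitor='val_loss', save_best_only=True),\n", "]\n", "for cb in keras_callbacks_sketch:\n", "    print(type(cb).__name__)"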
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if train:\n", "    # learning_rate replaces the deprecated lr argument\n", "    adam = Adam(learning_rate=0.0001)\n", "    model.compile(optimizer=adam, loss=['categorical_crossentropy'], metrics=['accuracy'])\n", "    callbacks = all_callbacks(stop_patience=1000,\n", "                              lr_factor=0.5,\n", "                              lr_patience=10,\n", "                              lr_epsilon=0.000001,\n", "                              lr_cooldown=2,\n", "                              lr_minimum=0.0000001,\n", "                              outputDir='model_1')\n", "    model.fit(X_train_val, y_train_val, batch_size=1024,\n", "              epochs=30, validation_split=0.25, shuffle=True,\n", "              callbacks=callbacks.callbacks)\n", "else:\n", "    from tensorflow.keras.models import load_model\n", "    model = load_model('model_1/KERAS_check_best_model.h5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check performance\n", "Check the accuracy." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y_keras = model.predict(X_test)\n", "\n", "from sklearn.metrics import accuracy_score\n", "print('-----------------------------------')\n", "print(\"Keras Accuracy: {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_keras, axis=1))))\n", "print('-----------------------------------')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Accuracy may not be the best (or the only) metric to consider for a classification problem, especially with a skewed (imbalanced) dataset.\n", "\n", "A [confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix) is a tool that gives you a better understanding of how a classifier performs on each class." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import plotting # Import local package plotting.py\n", "from sklearn.metrics import confusion_matrix\n", "cm = confusion_matrix(y_true=np.argmax(y_test, axis=1), y_pred=np.argmax(y_keras, axis=1))\n", "plt.figure(figsize=(9,9))\n", "_ = 
plotting.plot_confusion_matrix(cm, le.classes_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another tool that you can use is the [ROC curve](https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc).\n", "\n", "A ROC curve typically shows the true positive rate (TPR) on the vertical axis and the false positive rate (FPR) on the horizontal axis. The top-left corner of the plot is the _ideal_ point: an FPR of zero and a TPR of one. This is why a larger area under the curve (AUC) is usually better." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(9,9))\n", "_ = plotting.plotMultiClassRoc(y_test, y_keras, le.classes_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert the model to FPGA firmware with hls4ml\n", "\n", "Now we will go through the steps to convert the trained model into low-latency, optimized FPGA firmware with hls4ml.\n", "\n", "- First, we will evaluate its classification performance to make sure we haven't lost accuracy by using fixed-point data types.\n", "- Then we will synthesize the model with Vivado HLS and check its latency and FPGA resource usage.\n", "\n", "hls4ml comes with a [Python API](https://fastmachinelearning.org/hls4ml), so all of the next steps, including HLS, run from within this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Import the hls4ml package\n", "import hls4ml" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create an hls4ml configuration & model\n", "\n", "hls4ml is controlled through an _hls4ml configuration dictionary_. In this example, we'll use the simplest variation (`granularity='model'`). Later exercises will look at more advanced configurations."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Generate an hls4ml configuration dictionary from the Keras model\n", "config = hls4ml.utils.config_from_keras_model(model, granularity='model')\n", "\n", "print('-----------------------------------')\n", "# Show the generated configuration dictionary for hls4ml\n", "plotting.print_dict(config)\n", "print('-----------------------------------')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create an HLS model from the Keras model and the hls4ml configuration dictionary. Note that in these notebooks we target (with the `part` argument) four different boards equipped with Xilinx SoC chips: a [ZCU106](https://www.xilinx.com/products/boards-and-kits/zcu106.html), an [Ultra96](http://zedboard.org/product/ultra96-v2-development-board), a [Pynq-Z1](https://reference.digilentinc.com/reference/programmable-logic/pynq-z1), and an even _smaller_ [MiniZed](http://zedboard.org/product/minized). To target a different board, comment and uncomment the corresponding `part` lines in the next cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hls_model = hls4ml.converters.convert_from_keras_model(model,\n", " hls_config=config,\n", " output_dir='model_1/hls4ml_prj',\n", " #part='xczu7ev-ffvc1156-2-e') # ZCU106\n", " part='xczu3eg-sbva484-1-e') # Ultra96\n", " #part='xc7z020clg400-1') # Pynq-Z1\n", " #part='xc7z007sclg225-1') # MiniZed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize the HLS model that we created. The model architecture is annotated with the layer shapes and [data types](https://github.com/Xilinx/HLS_arbitrary_Precision_Types). Note that we are converting the trained model from a floating-point implementation to a fixed-point implementation. _Post-training quantization_ is a conversion technique that can reduce the resource requirements and latency of the final hardware accelerator, with little degradation in model accuracy."
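, "\n", "\n", "Before plotting, here is some intuition for what the fixed-point conversion does to a single number. The next cell is a minimal NumPy sketch (not hls4ml code; `to_ap_fixed` is a hypothetical helper) that emulates an `ap_fixed<W,I>` value: `W` total bits, of which `I` are integer bits including the sign bit, which leaves `W - I` fractional bits. It assumes the default `ap_fixed` quantization mode (`AP_TRN`, truncation toward minus infinity) and the default overflow mode (`AP_WRAP`, two's-complement wrap-around). The `<16,6>` defaults are only an example, so check the `Precision` entry in the configuration printed above for the type actually used. With `ap_fixed<16,6>`, the representable range is [-32, 32 - 2**-10]. As a result, 0.3 is truncated to 0.2998046875, -0.3 to -0.30078125, and 40.0 wraps around to -24.0." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def to_ap_fixed(x, width=16, integer=6):\n", "    # Sketch of ap_fixed<width, integer> with AP_TRN quantization and AP_WRAP overflow\n", "    frac = width - integer  # number of fractional bits\n", "    scale = 2.0 ** frac     # one LSB is 2**-frac\n", "    q = np.floor(np.asarray(x, dtype=np.float64) * scale).astype(np.int64)  # AP_TRN\n", "    half = 2 ** (width - 1)\n", "    q = (q + half) % (2 * half) - half  # AP_WRAP: keep the low `width` bits, signed\n", "    return q / scale\n", "\n", "values = np.array([0.3, -0.3, 31.999, 40.0])\n", "print(to_ap_fixed(values))"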
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "hls4ml.utils.plot_model(hls_model, show_shapes=True, show_precision=True, to_file=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compile & Predict\n", "\n", "Because of the quantization, we now need to check that the HLS model still performs well. We first compile the `hls_model`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "hls_model.compile()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we use `hls_model.predict` to execute the FPGA firmware with bit-accurate emulation **on the CPU**." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "y_hls = hls_model.predict(np.ascontiguousarray(X_test))\n", "# An alternative to np.ascontiguousarray() is to make a C-ordered copy of X_test:\n", "#X_test = X_test.copy(order='C')\n", "#y_hls = hls_model.predict(X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare Performance\n", "That was easy! 
Now let's see how the performance compares to Keras:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print('-----------------------------------')\n", "print(\"Keras Accuracy: {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_keras, axis=1))))\n", "print(\"hls4ml Accuracy: {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_hls, axis=1))))\n", "print('-----------------------------------')\n", "\n", "# Set to True to use a logarithmic scale on the TPR and FPR axes\n", "logscale_tpr = False # Y axis\n", "logscale_fpr = False # X axis\n", "\n", "fig, ax = plt.subplots(figsize=(9, 9))\n", "_ = plotting.plotMultiClassRoc(y_test, y_keras, le.classes_, logscale_tpr=logscale_tpr, logscale_fpr=logscale_fpr)\n", "plt.gca().set_prop_cycle(None) # reset the colors\n", "_ = plotting.plotMultiClassRoc(y_test, y_hls, le.classes_, logscale_tpr=logscale_tpr, logscale_fpr=logscale_fpr, linestyle='--')\n", "\n", "from matplotlib.lines import Line2D\n", "lines = [Line2D([0], [0], ls='-'),\n", " Line2D([0], [0], ls='--')]\n", "from matplotlib.legend import Legend\n", "leg = Legend(ax, lines, labels=['keras', 'hls4ml'],\n", " loc='center right', frameon=False)\n", "_ = ax.add_artist(leg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The AUC values of the Keras and hls4ml implementations are very close: they agree up to the second decimal place. You can see the difference in the ROC curves if you apply a logarithmic scale to the FPR axis (`logscale_fpr=True`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Synthesize\n", "Now we'll actually use Vivado HLS to synthesize the model (_C-Synthesis_). We run the build with the `build()` method of our `hls_model` object.\n", "\n", "After this step, the generated IP could be integrated into a workflow that compiles it for a specific FPGA board.\n", "Here, we'll just review the reports that Vivado HLS generates and check the latency and resource usage."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "hls_results = hls_model.build(csim=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**This takes approximately 15 minutes on the Columbia servers.**\n", "\n", "While C-Synthesis is running, you can monitor its progress by opening a terminal from the notebook home page and running:\n", "\n", "`tail -f model_1/hls4ml_prj/vivado_hls.log`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Check the results\n", "\n", "You can print the HLS results from the synthesis in the previous step." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print('-----------------------------------')\n", "#print(hls_results) # Print the whole dictionary\n", "print(\"Estimated Clock Period: {} ns\".format(hls_results['EstimatedClockPeriod']))\n", "print(\"Best/Worst Latency: {} / {} clock cycles\".format(hls_results['BestLatency'], hls_results['WorstLatency']))\n", "print(\"Interval Min/Max: {} / {} clock cycles\".format(hls_results['IntervalMin'], hls_results['IntervalMax']))\n", "print(\"BRAM_18K: {} (Avail. {})\".format(hls_results['BRAM_18K'], hls_results['AvailableBRAM_18K']))\n", "print(\"DSP48E: {} (Avail. {})\".format(hls_results['DSP48E'], hls_results['AvailableDSP48E']))\n", "print(\"FF: {} (Avail. {})\".format(hls_results['FF'], hls_results['AvailableFF']))\n", "print(\"LUT: {} (Avail. {})\".format(hls_results['LUT'], hls_results['AvailableLUT']))\n", "print(\"URAM: {} (Avail. {})\".format(hls_results['URAM'], hls_results['AvailableURAM']))\n", "print('-----------------------------------')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(hls_results)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also view the full reports generated by Vivado HLS. Pay attention to the _Latency_ and _Utilization Estimates_ sections."
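, "\n", "\n", "The _Utilization (%)_ figures in these reports are the used resources divided by the available ones. As a small sketch, the next cell recomputes them from a dictionary with the same keys as `hls_results`. It assumes the flat layout used by the print cell above, and `utilization_percent` and `worst_latency_ns` are our own hypothetical helpers, not part of hls4ml. The percentages are truncated to integers, which matches the summary tables in the Conclusions section: with the ZCU106 row of those tables as input, the cell should report 0, 226, 5, 38 and 0 for BRAM_18K, DSP48E, FF, LUT and URAM. If `hls_results` exists, the cell also estimates the worst-case latency in nanoseconds as `WorstLatency` (in clock cycles) times `EstimatedClockPeriod` (in ns)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "RESOURCES = ['BRAM_18K', 'DSP48E', 'FF', 'LUT', 'URAM']\n", "\n", "def utilization_percent(report):\n", "    # Used / available in %, truncated to an integer like the Vivado HLS report\n", "    result = {}\n", "    for name in RESOURCES:\n", "        used = float(report[name])\n", "        available = float(report['Available' + name])\n", "        if available == 0:\n", "            result[name] = 0 if used == 0 else math.inf  # resource absent on this device\n", "        else:\n", "            result[name] = math.floor(100 * used / available)\n", "    return result\n", "\n", "def worst_latency_ns(report):\n", "    # Latency is reported in clock cycles; multiply by the estimated clock period (ns)\n", "    return float(report['WorstLatency']) * float(report['EstimatedClockPeriod'])\n", "\n", "# Example input: the ZCU106 numbers from the summary table in the Conclusions section\n", "zcu106 = {'BRAM_18K': 4, 'AvailableBRAM_18K': 624, 'DSP48E': 3911, 'AvailableDSP48E': 1728,\n", "          'FF': 26921, 'AvailableFF': 460800, 'LUT': 88404, 'AvailableLUT': 230400,\n", "          'URAM': 0, 'AvailableURAM': 96}\n", "print('ZCU106 utilization (%):', utilization_percent(zcu106))\n", "\n", "if 'hls_results' in globals():\n", "    print('This run, utilization (%):', utilization_percent(hls_results))\n", "    print('This run, worst-case latency: {:.1f} ns'.format(worst_latency_ns(hls_results)))"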
] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "hls4ml.report.read_vivado_report('model_1/hls4ml_prj/')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Under the Hood\n", "\n", "The `hls_model` and, in particular, all of the hls4ml-generated files are in the `model_1/hls4ml_prj` directory.\n", "\n", "In this tutorial we use the [Python API](https://fastmachinelearning.org/hls4ml/autodoc/hls4ml.html) of hls4ml, but the tool also comes with a [command-line interface](https://fastmachinelearning.org/hls4ml/command.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusions\n", "\n", "With the current hls4ml configuration, the resource usage that HLS _estimates_ for this design exceeds the available resources on each of the boards ([ZCU106](https://www.xilinx.com/products/boards-and-kits/zcu106.html), [Ultra96](http://zedboard.org/product/ultra96-v2-development-board), [Pynq-Z1](https://reference.digilentinc.com/reference/programmable-logic/pynq-z1), and [MiniZed](http://zedboard.org/product/minized)). In particular, the DSP48E utilization is above 100% on all four boards.\n", "\n", "In the next notebooks, we will learn how to reduce the hardware-resource usage without significantly affecting the model accuracy.\n", "\n", "Here we summarize the expected latency (in clock cycles) and resource costs for each of these boards, from previous synthesis runs.\n", "\n", "```\n",
"+-----------------------------------------------------------+\n",
"| ZCU106                                                    |\n",
"+-----------------+---------+-------+--------+--------+-----+\n",
"|       Name      | BRAM_18K| DSP48E|   FF   |   LUT  | URAM|\n",
"+-----------------+---------+-------+--------+--------+-----+ +-----+-----+-----+-----+----------+\n",
"|Total            |        4|   3911|   26921|   88404|    0| |  Latency  |  Interval | Pipeline |\n",
"+-----------------+---------+-------+--------+--------+-----+ | min | max | min | max |   Type   |\n",
"|Available        |      624|   1728|  460800|  230400|   96| +-----+-----+-----+-----+----------+\n",
"+-----------------+---------+-------+--------+--------+-----+ |    9|    9|    1|    1| function |\n",
"|Utilization (%)  |       ~0|    226|       5|      38|    0| +-----+-----+-----+-----+----------+\n",
"+-----------------+---------+-------+--------+--------+-----+\n",
"\n",
"+-----------------------------------------------------------+\n",
"| Ultra96                                                   |\n",
"+-----------------+---------+-------+--------+--------+-----+\n",
"|       Name      | BRAM_18K| DSP48E|   FF   |   LUT  | URAM|\n",
"+-----------------+---------+-------+--------+--------+-----+ +-----+-----+-----+-----+----------+\n",
"|Total            |        4|   3911|   49742|   88564|    0| |  Latency  |  Interval | Pipeline |\n",
"+-----------------+---------+-------+--------+--------+-----+ | min | max | min | max |   Type   |\n",
"|Available        |      432|    360|  141120|   70560|    0| +-----+-----+-----+-----+----------+\n",
"+-----------------+---------+-------+--------+--------+-----+ |   14|   14|    1|    1| function |\n",
"|Utilization (%)  |       ~0|   1086|      35|     125|    0| +-----+-----+-----+-----+----------+\n",
"+-----------------+---------+-------+--------+--------+-----+\n",
"\n",
"+-----------------------------------------------------------+\n",
"| Pynq-Z1                                                   |\n",
"+-----------------+---------+-------+--------+--------+-----+\n",
"|       Name      | BRAM_18K| DSP48E|   FF   |   LUT  | URAM|\n",
"+-----------------+---------+-------+--------+--------+-----+ +-----+-----+-----+-----+----------+\n",
"|Total            |        4|   3911|  270258|   90772|    0| |  Latency  |  Interval | Pipeline |\n",
"+-----------------+---------+-------+--------+--------+-----+ | min | max | min | max |   Type   |\n",
"|Available        |      280|    220|  106400|   53200|    0| +-----+-----+-----+-----+----------+\n",
"+-----------------+---------+-------+--------+--------+-----+ |   52|   52|    1|    1| function |\n",
"|Utilization (%)  |        1|   1777|     254|     170|    0| +-----+-----+-----+-----+----------+\n",
"+-----------------+---------+-------+--------+--------+-----+\n",
"\n",
"+-----------------------------------------------------------+\n",
"| MiniZed                                                   |\n",
"+-----------------+---------+-------+--------+--------+-----+\n",
"|       Name      | BRAM_18K| DSP48E|   FF   |   LUT  | URAM|\n",
"+-----------------+---------+-------+--------+--------+-----+ +-----+-----+-----+-----+----------+\n",
"|Total            |        4|   3911|  270258|   90772|    0| |  Latency  |  Interval | Pipeline |\n",
"+-----------------+---------+-------+--------+--------+-----+ | min | max | min | max |   Type   |\n",
"|Available        |      100|     66|   28800|   14400|    0| +-----+-----+-----+-----+----------+\n",
"+-----------------+---------+-------+--------+--------+-----+ |   52|   52|    1|    1| function |\n",
"|Utilization (%)  |        4|   5925|     938|     630|    0| +-----+-----+-----+-----+----------+\n",
"+-----------------+---------+-------+--------+--------+-----+\n",
"```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercise\n", "Since `ReuseFactor = 1`, we expect each multiplication used in the inference of our neural network to use 1 DSP. Is this what we see? (Note that the Softmax layer should use 5 DSPs, or 1 per class.)\n", "Calculate how many multiplications are performed for the inference of this network.\n", "(We'll discuss the outcome.)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 2 }