{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Module 4.4: A FIRRTL Transform Example\n", "\n", "**Prev: [Common Pass Idioms](4.3_firrtl_common_idioms.ipynb)**
\n", "\n", "This `AnalyzeCircuit` transform walks a `firrtl.ir.Circuit` and records the number of add ops it finds in each module.\n", "\n", "## Setup\n", "\n", "Please run the following:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "val path = System.getProperty(\"user.dir\") + \"/source/load-ivy.sc\"\n", "interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "// Compiler Infrastructure\n", "\n", "// Firrtl IR classes\n", "\n", "// Map functions\n", "\n", "// Scala's mutable collections\n", "import scala.collection.mutable\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Counting Adders Per Module\n", "\n", "As described earlier, a Firrtl circuit is represented as a tree:\n", " - A Firrtl `Circuit` contains a sequence of `DefModule`s.\n", " - A `DefModule` contains a sequence of `Port`s and possibly a `Statement`.\n", " - A `Statement` can contain other `Statement`s or `Expression`s.\n", " - An `Expression` can contain other `Expression`s.\n", "\n", "To visit all Firrtl IR nodes in a circuit, we write functions that recursively walk down this tree. 
To record statistics, we pass a `Ledger` instance along the walk and update it whenever we come across an add op:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Ledger {\n", " import firrtl.Utils\n", " private var moduleName: Option[String] = None\n", " private val modules = mutable.Set[String]()\n", " private val moduleAddMap = mutable.Map[String, Int]()\n", " def foundAdd(): Unit = moduleName match {\n", " case None => sys.error(\"Module name not defined in Ledger!\")\n", " case Some(name) => moduleAddMap(name) = moduleAddMap.getOrElse(name, 0) + 1\n", " }\n", " def getModuleName: String = moduleName match {\n", " case None => Utils.error(\"Module name not defined in Ledger!\")\n", " case Some(name) => name\n", " }\n", " def setModuleName(myName: String): Unit = {\n", " modules += myName\n", " moduleName = Some(myName)\n", " }\n", " def serialize: String = {\n", " modules map { myName =>\n", " s\"$myName => ${moduleAddMap.getOrElse(myName, 0)} add ops!\"\n", " } mkString \"\\n\"\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's define a FIRRTL Transform that walks the circuit and updates our `Ledger` whenever it comes across an adder (`DoPrim` with op argument `Add`). 
Don't worry about `inputForm` or `outputForm` for now.\n", "\n", "Take some time to understand how `walkModule`, `walkStatement`, and `walkExpression` enable traversing all `DefModule`, `Statement`, and `Expression` nodes in the FIRRTL AST.\n", "\n", "Questions to answer:\n", " - **Why doesn't walkModule call walkExpression?**\n", " - **Why does walkExpression do a post-order traversal?**\n", " - **Can you modify walkExpression to do a pre-order traversal of Expressions?**" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class AnalyzeCircuit extends firrtl.Transform {\n", " import firrtl._\n", " import firrtl.ir._\n", " import firrtl.Mappers._\n", " import firrtl.Parser._\n", " import firrtl.annotations._\n", " import firrtl.PrimOps._\n", " \n", " // Requires the [[Circuit]] form to be \"low\"\n", " def inputForm = LowForm\n", " // Indicates the output [[Circuit]] form to be \"low\"\n", " def outputForm = LowForm\n", "\n", " // Called by [[Compiler]] to run your pass. 
[[CircuitState]] contains\n", " // the circuit and its form, as well as other related data.\n", " def execute(state: CircuitState): CircuitState = {\n", " val ledger = new Ledger()\n", " val circuit = state.circuit\n", "\n", " // Execute the function walkModule(ledger) on every [[DefModule]] in\n", " // circuit, returning a new [[Circuit]] with new [[Seq]] of [[DefModule]].\n", " // - \"higher order functions\" - using a function as an object\n", " // - \"function currying\" - partial argument notation\n", " // - \"infix notation\" - fancy function calling syntax\n", " // - \"map\" - classic functional programming concept\n", " // - discard the returned new [[Circuit]] because circuit is unmodified\n", " circuit map walkModule(ledger)\n", "\n", " // Print our ledger\n", " println(ledger.serialize)\n", "\n", " // Return an unchanged [[CircuitState]]\n", " state\n", " }\n", "\n", " // Deeply visits every [[Statement]] in m.\n", " def walkModule(ledger: Ledger)(m: DefModule): DefModule = {\n", " // Set ledger to current module name\n", " ledger.setModuleName(m.name)\n", "\n", " // Execute the function walkStatement(ledger) on every [[Statement]] in m.\n", " // - return the new [[DefModule]] (in this case, it's identical to m)\n", " // - if m does not contain [[Statement]], map returns m.\n", " m map walkStatement(ledger)\n", " }\n", "\n", " // Deeply visits every [[Statement]] and [[Expression]] in s.\n", " def walkStatement(ledger: Ledger)(s: Statement): Statement = {\n", "\n", " // Execute the function walkExpression(ledger) on every [[Expression]] in s.\n", " // - discard the new [[Statement]] (in this case, it's identical to s)\n", " // - if s does not contain [[Expression]], map returns s.\n", " s map walkExpression(ledger)\n", "\n", " // Execute the function walkStatement(ledger) on every [[Statement]] in s.\n", " // - return the new [[Statement]] (in this case, it's identical to s)\n", " // - if s does not contain [[Statement]], map returns s.\n", " s map 
walkStatement(ledger)\n", " }\n", "\n", " // Deeply visits every [[Expression]] in e.\n", " // - \"post-order traversal\" - handle e's children [[Expression]] before e\n", " def walkExpression(ledger: Ledger)(e: Expression): Expression = {\n", "\n", " // Execute the function walkExpression(ledger) on every [[Expression]] in e.\n", " // - bind the new [[Expression]] to visited (in this case, it's identical to e)\n", " // - if e does not contain [[Expression]], map returns e.\n", " val visited = e map walkExpression(ledger)\n", "\n", " visited match {\n", " // If e is an adder, increment our ledger and return e.\n", " case DoPrim(Add, _, _, _) =>\n", " ledger.foundAdd()\n", " e\n", " // If e is not an adder, return e.\n", " case notadd => notadd\n", " }\n", " }\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Running our Transform\n", "\n", "Now that we've defined it, let's run it on a Chisel design! First, let's define a Chisel module." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "// Chisel stuff\n", "import chisel3._\n", "import chisel3.util._\n", "\n", "class AddMe(nInputs: Int, width: Int) extends Module {\n", " val io = IO(new Bundle {\n", " val in = Input(Vec(nInputs, UInt(width.W)))\n", " val out = Output(UInt(width.W))\n", " })\n", " io.out := io.in.reduce(_ +& _)\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's elaborate it and serialize the result as FIRRTL text." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "val firrtlSerialization = chisel3.Driver.emit(() => new AddMe(8, 4))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's compile our FIRRTL into Verilog, adding our custom transform to the compilation. Note that it prints the number of add ops it found!"
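, "\n", "\n", "How many add ops should it find? `reduce` applies its operator once for every element after the first, so folding `nInputs` values takes `nInputs - 1` operations. For `AddMe(8, 4)` that means 8 - 1 = 7 `+&` operators, each of which Chisel emits as a FIRRTL `add`, so we would expect a count of 7 for `AddMe` (assuming no later pass removes any of them). The plain-Scala cell below sanity-checks that arithmetic by counting operator calls. It uses no Chisel or FIRRTL, and its `applications` and `total` names are only for illustration:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "// Plain Scala sketch (no Chisel/FIRRTL): count how often reduce calls its binary operator\n", "val nInputs = 8\n", "var applications = 0 // illustrative counter\n", "val total = (1 to nInputs).reduce { (a, b) =>\n", "  applications += 1 // one call per pairwise combination\n", "  a + b\n", "}\n", "// Prints \"7 applications, sum = 36\" for 8 inputs\n", "println(s\"$applications applications, sum = $total\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With that expectation in mind, let's run the compiler with `AnalyzeCircuit` included:"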
] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "val verilog = compileFIRRTL(firrtlSerialization, new firrtl.VerilogCompiler(), Seq(new AnalyzeCircuit()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `compileFIRRTL` function is defined only for this tutorial; in a future section, we will describe the process of inserting custom transforms.\n", "\n", "That's it for this section!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Scala", "language": "scala", "name": "scala" }, "language_info": { "codemirror_mode": "text/x-scala", "file_extension": ".scala", "mimetype": "text/x-scala", "name": "scala211", "nbconvert_exporter": "script", "pygments_lexer": "scala", "version": "2.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }