{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Module 4.4: A FIRRTL Transform Example\n",
"\n",
"**Prev: [Common Pass Idioms](4.3_firrtl_common_idioms.ipynb)**\n",
"\n",
"This module builds `AnalyzeCircuit`, a FIRRTL Transform that walks a `firrtl.ir.Circuit` and records the number of add ops it finds in each module.\n",
"\n",
"## Setup\n",
"\n",
"Please run the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"val path = System.getProperty(\"user.dir\") + \"/source/load-ivy.sc\"\n",
"interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"// FIRRTL compiler infrastructure, IR classes, and map functions are\n",
"// imported inside the AnalyzeCircuit class below.\n",
"\n",
"// Scala's mutable collections (used by Ledger)\n",
"import scala.collection.mutable"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Counting Adders Per Module\n",
"\n",
"As described earlier, a FIRRTL circuit is represented as a tree:\n",
" - A FIRRTL `Circuit` contains a sequence of `DefModule`s.\n",
" - A `DefModule` contains a sequence of `Port`s and, unless it is an external module (`ExtModule`), a body `Statement`.\n",
" - A `Statement` can contain other `Statement`s or `Expression`s.\n",
" - An `Expression` can contain other `Expression`s.\n",
"\n",
"To visit every FIRRTL IR node in a circuit, we write functions that recursively walk down this tree. To record statistics, we pass along an instance of a `Ledger` class and update it whenever we come across an add op:"
]
},
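{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same recursive idea can be sketched in plain Scala, without FIRRTL. The classes below (`TExpr`, `TRef`, `TAdd`, `TStmt`, `TConnect`, `TBlock`, `TModule`, `TCircuit`) and the `CountAdds` object are hypothetical stand-ins that only mirror the shape of the tree described above; they are not part of the `firrtl.ir` API. To keep the sketch short, each walk function returns its add count directly instead of recording it in a `Ledger`.\n",
"\n",
"```scala\n",
"// Hypothetical toy IR: mirrors the shape of FIRRTL's tree, not the real firrtl.ir classes\n",
"sealed trait TExpr\n",
"case class TRef(name: String) extends TExpr\n",
"case class TAdd(a: TExpr, b: TExpr) extends TExpr\n",
"\n",
"sealed trait TStmt\n",
"case class TConnect(loc: String, expr: TExpr) extends TStmt\n",
"case class TBlock(stmts: Seq[TStmt]) extends TStmt\n",
"\n",
"case class TModule(name: String, body: TStmt)\n",
"case class TCircuit(modules: Seq[TModule])\n",
"\n",
"object CountAdds {\n",
"  // An expression contains other expressions: count this node, then recurse\n",
"  def walkExpr(e: TExpr): Int = e match {\n",
"    case TAdd(a, b) => 1 + walkExpr(a) + walkExpr(b)\n",
"    case TRef(_)    => 0\n",
"  }\n",
"\n",
"  // A statement contains other statements or expressions\n",
"  def walkStmt(s: TStmt): Int = s match {\n",
"    case TConnect(_, expr) => walkExpr(expr)\n",
"    case TBlock(stmts)     => stmts.map(walkStmt).sum\n",
"  }\n",
"\n",
"  // A circuit contains modules: one count per module, like the Ledger\n",
"  def walkCircuit(c: TCircuit): Map[String, Int] =\n",
"    c.modules.map(m => m.name -> walkStmt(m.body)).toMap\n",
"\n",
"  // Top computes (a + b) + c; Leaf has no adds\n",
"  val exampleCircuit = TCircuit(Seq(\n",
"    TModule(\"Top\", TBlock(Seq(TConnect(\"out\", TAdd(TAdd(TRef(\"a\"), TRef(\"b\")), TRef(\"c\")))))),\n",
"    TModule(\"Leaf\", TConnect(\"o\", TRef(\"i\")))\n",
"  ))\n",
"\n",
"  def main(args: Array[String]): Unit =\n",
"    println(walkCircuit(exampleCircuit)) // prints Map(Top -> 2, Leaf -> 0)\n",
"}\n",
"```\n",
"\n",
"Each walk function handles one level of the tree and delegates to the next, which is the same division of labor as `walkModule`, `walkStatement`, and `walkExpression` in the transform."
]
},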
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Ledger {\n",
"  import firrtl.Utils\n",
"  // Name of the module currently being walked, if any\n",
"  private var moduleName: Option[String] = None\n",
"  // Every module visited, kept in visit order\n",
"  private val modules = mutable.LinkedHashSet[String]()\n",
"  // Number of add ops found in each module\n",
"  private val moduleAddMap = mutable.Map[String, Int]()\n",
"\n",
"  // Record one add op in the current module\n",
"  def foundAdd(): Unit = moduleName match {\n",
"    case None => sys.error(\"Module name not defined in Ledger!\")\n",
"    case Some(name) => moduleAddMap(name) = moduleAddMap.getOrElse(name, 0) + 1\n",
"  }\n",
"\n",
"  def getModuleName: String = moduleName match {\n",
"    case None => Utils.error(\"Module name not defined in Ledger!\")\n",
"    case Some(name) => name\n",
"  }\n",
"\n",
"  // Start recording add ops for module myName\n",
"  def setModuleName(myName: String): Unit = {\n",
"    modules += myName\n",
"    moduleName = Some(myName)\n",
"  }\n",
"\n",
"  // One line per module: \"<name> => <count> add ops!\"\n",
"  def serialize: String = {\n",
"    modules map { myName =>\n",
"      s\"$myName => ${moduleAddMap.getOrElse(myName, 0)} add ops!\"\n",
"    } mkString \"\\n\"\n",
"  }\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's define a FIRRTL Transform that walks the circuit and updates our `Ledger` whenever it comes across an add op (a `DoPrim` whose `op` is `Add`). Don't worry about `inputForm` or `outputForm` for now.\n",
"\n",
"Take some time to understand how `walkModule`, `walkStatement`, and `walkExpression` together traverse every `DefModule`, `Statement`, and `Expression` node in the FIRRTL AST.\n",
"\n",
"Questions to answer:\n",
" - **Why doesn't `walkModule` call `walkExpression`?**\n",
" - **Why does `walkExpression` do a post-order traversal?**\n",
" - **Can you modify `walkExpression` to do a pre-order traversal of `Expression`s?**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class AnalyzeCircuit extends firrtl.Transform {\n",
"  import firrtl._\n",
"  import firrtl.ir._\n",
"  import firrtl.Mappers._\n",
"  import firrtl.PrimOps._\n",
"\n",
"  // Requires the [[Circuit]] form to be \"low\"\n",
"  def inputForm = LowForm\n",
"  // Indicates the output [[Circuit]] form to be \"low\"\n",
"  def outputForm = LowForm\n",
"\n",
"  // Called by [[Compiler]] to run your pass. [[CircuitState]] contains\n",
"  // the circuit and its form, as well as other related data.\n",
"  def execute(state: CircuitState): CircuitState = {\n",
"    val ledger = new Ledger()\n",
"    val circuit = state.circuit\n",
"\n",
"    // Execute the function walkModule(ledger) on every [[DefModule]] in\n",
"    // circuit, returning a new [[Circuit]] with a new [[Seq]] of [[DefModule]].\n",
"    // - \"higher order functions\" - using a function as an object\n",
"    // - \"function currying\" - partial argument notation\n",
"    // - \"infix notation\" - fancy function calling syntax\n",
"    // - \"map\" - classic functional programming concept\n",
"    // - discard the returned new [[Circuit]] because circuit is unmodified\n",
"    circuit map walkModule(ledger)\n",
"\n",
"    // Print our ledger\n",
"    println(ledger.serialize)\n",
"\n",
"    // Return an unchanged [[CircuitState]]\n",
"    state\n",
"  }\n",
"\n",
"  // Deeply visits every [[Statement]] in m.\n",
"  def walkModule(ledger: Ledger)(m: DefModule): DefModule = {\n",
"    // Set ledger to current module name\n",
"    ledger.setModuleName(m.name)\n",
"\n",
"    // Execute the function walkStatement(ledger) on every [[Statement]] in m.\n",
"    // - return the new [[DefModule]] (in this case, it's identical to m)\n",
"    // - if m does not contain [[Statement]], map returns m.\n",
"    m map walkStatement(ledger)\n",
"  }\n",
"\n",
"  // Deeply visits every [[Statement]] and [[Expression]] in s.\n",
"  def walkStatement(ledger: Ledger)(s: Statement): Statement = {\n",
"\n",
"    // Execute the function walkExpression(ledger) on every [[Expression]] in s.\n",
"    // - discard the new [[Statement]] (in this case, it's identical to s)\n",
"    // - if s does not contain [[Expression]], map returns s.\n",
"    s map walkExpression(ledger)\n",
"\n",
"    // Execute the function walkStatement(ledger) on every [[Statement]] in s.\n",
"    // - return the new [[Statement]] (in this case, it's identical to s)\n",
"    // - if s does not contain [[Statement]], map returns s.\n",
"    s map walkStatement(ledger)\n",
"  }\n",
"\n",
"  // Deeply visits every [[Expression]] in e.\n",
"  // - \"post-order traversal\" - handle e's children [[Expression]] before e\n",
"  def walkExpression(ledger: Ledger)(e: Expression): Expression = {\n",
"\n",
"    // Execute the function walkExpression(ledger) on every [[Expression]] in e.\n",
"    // - return the new [[Expression]] (in this case, it's identical to e)\n",
"    // - if e does not contain [[Expression]], map returns e.\n",
"    val visited = e map walkExpression(ledger)\n",
"\n",
"    visited match {\n",
"      // If visited is an adder, increment our ledger and return it.\n",
"      case DoPrim(Add, _, _, _) =>\n",
"        ledger.foundAdd()\n",
"        visited\n",
"      // If visited is not an adder, return it unchanged.\n",
"      case notadd => notadd\n",
"    }\n",
"  }\n",
"}"
]
},
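{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to explore the last two questions is with a small experiment. The sketch below uses a hypothetical toy expression tree (`TRef`, `TAdd`) and a hypothetical `mapChildren` helper that, like `e map f` from `firrtl.Mappers`, applies `f` only to an expression's immediate children. Recording the order in which nodes are handled shows the difference between the post-order walk used by `walkExpression` and a pre-order variant.\n",
"\n",
"```scala\n",
"import scala.collection.mutable\n",
"\n",
"// Hypothetical toy expression tree: not the real firrtl.ir classes\n",
"sealed trait TExpr\n",
"case class TRef(name: String) extends TExpr\n",
"case class TAdd(a: TExpr, b: TExpr) extends TExpr\n",
"\n",
"object Traversals {\n",
"  // Hypothetical helper, analogous to `e map f` from firrtl.Mappers:\n",
"  // applies f to the immediate children of e only\n",
"  def mapChildren(e: TExpr)(f: TExpr => TExpr): TExpr = e match {\n",
"    case TAdd(a, b) => TAdd(f(a), f(b))\n",
"    case leaf       => leaf\n",
"  }\n",
"\n",
"  def label(e: TExpr): String = e match {\n",
"    case TRef(name) => name\n",
"    case TAdd(_, _) => \"add\"\n",
"  }\n",
"\n",
"  // Post-order (as in walkExpression): handle the children, then the node\n",
"  def postOrder(order: mutable.Buffer[String])(e: TExpr): TExpr = {\n",
"    val visited = mapChildren(e)(postOrder(order))\n",
"    order += label(visited)\n",
"    visited\n",
"  }\n",
"\n",
"  // Pre-order: handle the node, then its children\n",
"  def preOrder(order: mutable.Buffer[String])(e: TExpr): TExpr = {\n",
"    order += label(e)\n",
"    mapChildren(e)(preOrder(order))\n",
"  }\n",
"\n",
"  // (a + b) + c\n",
"  val example: TExpr = TAdd(TAdd(TRef(\"a\"), TRef(\"b\")), TRef(\"c\"))\n",
"\n",
"  // Run one traversal over the example and return the visit order\n",
"  def run(walk: mutable.Buffer[String] => TExpr => TExpr): List[String] = {\n",
"    val order = mutable.Buffer[String]()\n",
"    walk(order)(example)\n",
"    order.toList\n",
"  }\n",
"\n",
"  def main(args: Array[String]): Unit = {\n",
"    println(run(order => e => postOrder(order)(e))) // prints List(a, b, add, c, add)\n",
"    println(run(order => e => preOrder(order)(e)))  // prints List(add, add, a, b, c)\n",
"  }\n",
"}\n",
"```\n",
"\n",
"For counting adds, both orders find the same nodes. The order matters once a transform rewrites expressions: in post-order, the `match` sees a node whose children have already been rewritten, while in pre-order it sees the original children."
]
},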
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running our Transform\n",
"\n",
"Now that we've defined it, let's run it on a Chisel design! First, let's define a Chisel module."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"// Chisel stuff\n",
"import chisel3._\n",
"import chisel3.util._\n",
"\n",
"class AddMe(nInputs: Int, width: Int) extends Module {\n",
" val io = IO(new Bundle {\n",
" val in = Input(Vec(nInputs, UInt(width.W)))\n",
" val out = Output(UInt(width.W))\n",
" })\n",
" io.out := io.in.reduce(_ +& _)\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's elaborate it and emit the resulting FIRRTL as a string."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"val firrtlSerialization = chisel3.Driver.emit(() => new AddMe(8, 4))"
]
},
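{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's compile our FIRRTL to Verilog, adding our custom transform to the compilation. Note that it prints the number of add ops it found in each module! Because `io.in.reduce(_ +& _)` combines 8 inputs with a chain of 7 adds, `AddMe(8, 4)` should report 7 add ops."
]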
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"val verilog = compileFIRRTL(firrtlSerialization, new firrtl.VerilogCompiler(), Seq(new AnalyzeCircuit()))"
]
},
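{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the count printed above, the plain-Scala sketch below (no Chisel involved; `ReduceCount` and `countOps` are hypothetical names) counts how many times `reduce` applies its operator. Reducing `n` values applies a binary operator `n - 1` times, so `io.in.reduce(_ +& _)` in `AddMe(8, 4)` should create 7 adders.\n",
"\n",
"```scala\n",
"object ReduceCount {\n",
"  // Hypothetical helper: counts how many times reduce applies its binary\n",
"  // operator to n values. Plain Ints stand in for the UInt inputs of AddMe.\n",
"  def countOps(n: Int): Int = {\n",
"    var ops = 0\n",
"    (1 to n).reduce { (a, b) => ops += 1; a + b }\n",
"    ops\n",
"  }\n",
"\n",
"  def main(args: Array[String]): Unit =\n",
"    println(countOps(8)) // prints 7\n",
"}\n",
"```"
]
},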
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `compileFIRRTL` function is a helper defined only for this tutorial; a future section describes how custom transforms are inserted into the compiler.\n",
"\n",
"That's it for this section!"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Scala",
"language": "scala",
"name": "scala"
},
"language_info": {
"codemirror_mode": "text/x-scala",
"file_extension": ".scala",
"mimetype": "text/x-scala",
"name": "scala211",
"nbconvert_exporter": "script",
"pygments_lexer": "scala",
"version": "2.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}