{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Module 4.3: Common Pass Idioms\n",
"\n",
"**Prev: [FIRRTL AST Traversal](4.2_firrtl_ast_traversal.ipynb)**
\n",
"**Next: [A FIRRTL Transform Example](4.4_firrtl_add_ops_per_module.ipynb)**\n",
"\n",
"### Adding statements\n",
"Suppose we want to write a pass that splits nested DoPrim expressions, thus transforming this:\n",
"```\n",
"circuit Top:\n",
" module Top :\n",
" input x: UInt<3>\n",
" input y: UInt<3>\n",
" input z: UInt<3>\n",
" output o: UInt<3>\n",
" o <= add(x, add(y, z))\n",
"```\n",
"into this:\n",
"```\n",
"circuit Top:\n",
" module Top :\n",
" input x: UInt<3>\n",
" input y: UInt<3>\n",
" input z: UInt<3>\n",
" output o: UInt<3>\n",
" node GEN_1 = add(y, z)\n",
" o <= add(x, GEN_1)\n",
"```\n",
"\n",
"We first need to traverse the AST to every Statement and Expression. Then, when we see a DoPrim, we need to add a new DefNode to the module's body and insert a reference to that DefNode in place of the DoPrim. The code below implements this (and preserves the Info token). Note that `Namespace` is a utility function located in [Namespace.scala](https://github.com/ucb-bar/firrtl/blob/master/src/main/scala/firrtl/Namespace.scala).\n",
"\n",
"```scala\n",
"object Splitter extends Pass {\n",
" def name = \"Splitter!\"\n",
" /** Run splitM on every module **/\n",
" def run(c: Circuit): Circuit = c.copy(modules = c.modules map(splitM(_)))\n",
"\n",
" /** Run splitS on the body of every module **/\n",
" def splitM(m: DefModule): DefModule = m map splitS(Namespace(m))\n",
"\n",
" /** Run splitE on all children Expressions.\n",
" * If stmts contain extra statements, return a Block containing them and \n",
" * the new statement; otherwise, return the new statement. */\n",
" def splitS(namespace: Namespace)(s: Statement): Statement = {\n",
" val block = mutable.ArrayBuffer[Statement]()\n",
" s match {\n",
" case s: HasInfo => \n",
" val newStmt = s map splitE(block, namespace, s.info)\n",
" block.length match {\n",
" case 0 => newStmt\n",
" case _ => Block(block.toSeq :+ newStmt)\n",
" }\n",
" case s => s map splitS(namespace)\n",
" }\n",
"\n",
" /** Run splitE on all children expressions.\n",
" * If e is a DoPrim, add a new DefNode to block and return reference to\n",
" * the DefNode; otherwise return e.*/\n",
" def splitE(block: mutable.ArrayBuffer[Statement], namespace: Namespace, \n",
" info: Info)(e: Expression): Expression = e map splitE(block, namespace, info) match {\n",
" case e: DoPrim =>\n",
" val newName = namespace.newTemp\n",
" block += DefNode(info, newName, e)\n",
" Ref(newName, e.tpe)\n",
" case _ => e\n",
" }\n",
"}\n",
"```\n",
"### Deleting statements\n",
"Suppose we want to write a pass that inlined all DefNodes whose value is a literal, thus transforming this:\n",
"```\n",
"circuit Top:\n",
" module Top :\n",
" input x: UInt<3>\n",
" output o: UInt<4>\n",
" node y = UInt(1)\n",
" o <= add(x, y)\n",
"```\n",
"into this:\n",
"```\n",
"circuit Top:\n",
" module Top :\n",
" input x: UInt<3>\n",
" output y: UInt<4>\n",
" o <= add(x, UInt(1))\n",
"```\n",
"\n",
"We first need to traverse the AST to every Statement and Expression. Then, when we see a DefNode pointing to a Literal, we need to store it into a hashmap and return an EmptyStmt (thus deleting that DefNode). Then, whenever we see a reference to the deleted DefNode, we must insert the corresponding Literal.\n",
"\n",
"```scala\n",
"object Inliner extends Pass {\n",
" def name = \"Inliner!\"\n",
" /** Run inlineM on every module **/\n",
" def run(c: Circuit): Circuit = c.copy(modules = c.modules map(inlineM(_)))\n",
"\n",
" /** Run inlineS on the body of every module **/\n",
" def inlineM(m: DefModule): DefModule = m map inlineS(mutable.HashMap[String, Expression]())\n",
"\n",
" /** Run inlineE on all children Expressions, and then run inlineS on children statements.\n",
" * If statement is a DefNode containing a literal, update values and\n",
" * return EmptyStmt; otherwise return statement. */\n",
" def inlineS(values: mutable.HashMap[String, Expression])(s: Statement): Statement =\n",
" s map inlineE(values) map inlineS(values) match {\n",
" case d: DefNode => d.value match {\n",
" case l: Literal =>\n",
" values(d.name) = l\n",
" EmptyStmt\n",
" case _ => d\n",
" }\n",
" case o => o \n",
" }\n",
"\n",
" /** If e is a reference whose name is contained in values, \n",
" * return values(e.name); otherwise run inlineE on all \n",
" * children expressions.*/\n",
" def inlineE(values: mutable.HashMap[String, Expression])(e: Expression): Expression = e match {\n",
" case e: Ref if values.contains(e.name) => values(e.name)\n",
" case _ => e map inlineE(values)\n",
" }\n",
"}\n",
"```\n",
"\n",
"### Add a Primop\n",
"Would this be useful? Let [@azidar](https://github.com/azidar) know by submitting an issue to [the firrtl repo](https://github.com/freechipsproject/firrtl)!\n",
"\n",
"### Swap a statement\n",
"Would this be useful? Let [@azidar](https://github.com/azidar) know by submitting an issue to [the firrtl repo](https://github.com/freechipsproject/firrtl)!\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Scala",
"language": "scala",
"name": "scala"
},
"language_info": {
"codemirror_mode": "text/x-scala",
"file_extension": ".scala",
"mimetype": "text/x-scala",
"name": "scala211",
"nbconvert_exporter": "script",
"pygments_lexer": "scala",
"version": "2.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}