{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Module 3.4: Functional Programming\n",
"**Prev: [Higher-Order Functions](3.3_higher-order_functions.ipynb)**\n",
"\n",
"**Next: [Object Oriented Programming](3.5_object_oriented_programming.ipynb)**\n",
"\n",
"## Motivation\n",
"You saw functions in many previous modules, but now it's time to make your own and use them effectively.\n",
"\n",
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"val path = System.getProperty(\"user.dir\") + \"/source/load-ivy.sc\"\n",
"interp.load.module(ammonite.ops.Path(java.nio.file.FileSystems.getDefault().getPath(path)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This module uses the Chisel `FixedPoint` type, which currently resides in the experimental package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import chisel3._\n",
"import chisel3.util._\n",
"import chisel3.iotesters.{ChiselFlatSpec, Driver, PeekPokeTester}\n",
"import chisel3.experimental._\n",
"import chisel3.internal.firrtl.KnownBinaryPoint"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"# Functional Programming in Scala\n",
"Scala functions were introduced in Module 1, and you saw them used a lot in the previous module. Here's a refresher on functions. Functions take any number of inputs and produce one output. Inputs are often called arguments to a function. To produce no output, return the `Unit` type.\n",
"\n",
"**Example: Custom Functions**\n",
"\n",
"Below are some examples of functions in Scala."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"// No inputs or outputs (two versions).\n",
"def hello1(): Unit = print(\"Hello!\")\n",
"def hello2 = print(\"Hello again!\")\n",
"\n",
"// Math operation: one input and one output.\n",
"def times2(x: Int): Int = 2 * x\n",
"\n",
"// Inputs can have default values, and explicitly specifying the return type is optional.\n",
"// Note that we recommend specifying the return types to avoid surprises/bugs.\n",
"def timesN(x: Int, n: Int = 2) = n * x\n",
"\n",
"// Call the functions listed above.\n",
"hello1()\n",
"hello2\n",
"times2(4)\n",
"timesN(4)        // omit n to use its default value\n",
"timesN(4, 3)     // positional arguments follow the order of the parameter list\n",
"timesN(n=7, x=2) // named arguments can be given in any order"
]
},
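{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example: Returning Several Values**\n",
"\n",
"Because a function produces exactly one output, a common way to return several related values is to bundle them into a tuple. This is a minimal sketch. The helper `divMod` is hypothetical and not part of the original module."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"// divMod is a hypothetical helper: its single output is a tuple holding two Ints.\n",
"def divMod(a: Int, b: Int): (Int, Int) = (a / b, a % b)\n",
"\n",
"// Destructure the tuple into two vals.\n",
"val (quotient, remainder) = divMod(17, 5)\n",
"assert(quotient == 3 && remainder == 2)"
]
},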
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Functions as Objects\n",
"Functions in Scala are first-class objects. That means we can assign a function to a `val` and pass it to classes, objects, or other functions as an argument.\n",
"\n",
"**Example: Function Objects**\n",
"\n",
"Below, the same operations are implemented both as ordinary functions (`def`) and as function objects (`val`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"// These are normal functions.\n",
"def plus1funct(x: Int): Int = x + 1\n",
"def times2funct(x: Int): Int = x * 2\n",
"\n",
"// These are functions as vals.\n",
"// The first one explicitly specifies the function type, Int => Int.\n",
"val plus1val: Int => Int = x => x + 1\n",
"val times2val = (x: Int) => x * 2\n",
"\n",
"// Calling both looks the same.\n",
"plus1funct(4)\n",
"plus1val(4)\n",
"plus1funct(x=4)\n",
"//plus1val(x=4) // this doesn't compile: the parameter name x is not part of the function value's type"
]
},
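{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example: Function Values Are Objects**\n",
"\n",
"Because a function value is an object, it has methods of its own. This minimal sketch uses `andThen` and `compose`, two methods that Scala's standard `Function1` type provides for chaining single-argument functions. The names `inc` and `dbl` are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"// Two function values (illustrative names).\n",
"val inc: Int => Int = x => x + 1\n",
"val dbl: Int => Int = x => x * 2\n",
"\n",
"// Calling a function value is shorthand for calling its apply method.\n",
"assert(inc(4) == inc.apply(4))\n",
"\n",
"// f andThen g applies f first, then g: dbl(inc(4)) = 10\n",
"val incThenDbl = inc andThen dbl\n",
"// f compose g applies g first, then f: inc(dbl(4)) = 9\n",
"val dblThenInc = inc compose dbl\n",
"\n",
"assert(incThenDbl(4) == 10)\n",
"assert(dblThenInc(4) == 9)"
]
},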
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why would you want to create a `val` instead of a `def`? A `val` holds a function value that you can store and pass around to other functions, as shown below. (Scala also converts a `def` into a function value automatically wherever one is expected, so `myList.map(plus1funct)` would work too.) You can even create your own functions that accept other functions as arguments. Functions that take or produce functions are formally called *higher-order functions*. You saw them used in the last module, but now you'll make your own!\n",
"\n",
"**Example: Higher-Order Functions**\n",
"\n",
"Here we show `map` again, and we also create a new function, `opN`, that accepts a function, `op`, as an argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"// create our functions\n",
"val plus1 = (x: Int) => x + 1\n",
"val times2 = (x: Int) => x * 2\n",
"\n",
"// pass it to map, a list function\n",
"val myList = List(1, 2, 5, 9)\n",
"val myListPlus = myList.map(plus1)\n",
"val myListTimes = myList.map(times2)\n",
"\n",
"// create a custom function that applies op to x, n times, using recursion\n",
"def opN(x: Int, n: Int, op: Int => Int): Int = {\n",
" if (n <= 0) { x }\n",
" else { opN(op(x), n-1, op) }\n",
"}\n",
"\n",
"opN(7, 3, plus1)\n",
"opN(7, 3, times2)"
]
},
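{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example: Functions That Return Functions**\n",
"\n",
"`opN` takes a function as an argument, but a higher-order function can also *produce* a function. This is a minimal sketch built around a hypothetical `makeAdder`, which is not part of the original module. It returns a new function each time it is called, and that result can be passed straight to `map`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"// makeAdder is a hypothetical example: it returns a function that adds n to its input.\n",
"def makeAdder(n: Int): Int => Int = (x: Int) => x + n\n",
"\n",
"val add10 = makeAdder(10)\n",
"assert(add10(5) == 15)\n",
"\n",
"// The returned function can be used anywhere a function is expected.\n",
"assert(List(1, 2, 3).map(makeAdder(10)) == List(11, 12, 13))"
]
},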
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example: Functions vs. Objects**\n",
"\n",
"A possibly confusing situation arises with functions that take no arguments. A `def` is evaluated every time it is called, while a `val` is evaluated only once, when it is defined."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import scala.util.Random\n",
"\n",
"// both x and y call the nextInt function, but x is evaluated immediately and y is a function\n",
"val x = Random.nextInt\n",
"def y = Random.nextInt\n",
"\n",
"// x was previously evaluated, so it is a constant\n",
"println(s\"x = $x\")\n",
"println(s\"x = $x\")\n",
"\n",
"// y is a function and gets reevaluated at each call, thus these produce different results\n",
"println(s\"y = $y\")\n",
"println(s\"y = $y\")"
]
},
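{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example: Counting Evaluations**\n",
"\n",
"The same difference can be shown deterministically, without random numbers. This is a minimal sketch, and the names `evaluations`, `once`, and `every` are illustrative only. A counter goes up each time a definition's body runs: the `val` body runs once, when it is defined, while the `def` body runs on every call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"var evaluations = 0\n",
"\n",
"// The body of a val runs once, at definition time.\n",
"val once = { evaluations += 1; evaluations }\n",
"// The body of a def runs every time it is called.\n",
"def every = { evaluations += 1; evaluations }\n",
"\n",
"println(s\"once = $once, once = $once\") // the val is not re-evaluated\n",
"val first = every\n",
"val second = every\n",
"println(s\"first = $first, second = $second\")\n",
"\n",
"assert(once == 1)\n",
"assert(first == 2 && second == 3)"
]
},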
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Anonymous Functions\n",
"As the name implies, anonymous functions are nameless. There's no need to create a `val` for a function if we'll only use it once. \n",
"\n",
"**Example: Anonymous Functions**\n",
"\n",
"The following example demonstrates this. Anonymous functions are often passed in curly braces instead of parentheses. This is especially common when the body spans several lines or uses `case` clauses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"val myList = List(5, 6, 7, 8)\n",
"\n",
"// add one to every item in the list using an anonymous function\n",
"// the underscore is a placeholder for the function's single argument\n",
"// both of these lines do the same thing\n",
"myList.map( (x:Int) => x + 1 )\n",
"myList.map(_ + 1)\n",
"\n",
"// a common situation is to use case statements within an anonymous function\n",
"val myAnyList = List(1, 2, \"3\", 4L, myList)\n",
"myAnyList.map {\n",
" case (_:Int|_:Long) => \"Number\"\n",
" case _:String => \"String\"\n",
" case _ => \"error\"\n",
"}"
]
},
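{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example: More Anonymous Function Forms**\n",
"\n",
"This minimal sketch shows two more forms you'll see often. The names `mul` and `labels` are illustrative only. First, each underscore stands for a *separate* argument, so `_ * _` is a two-argument function. Second, a body that needs several statements can be written inside curly braces."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"// Two underscores mean two parameters: this is the same as (a, b) => a * b.\n",
"val mul: (Int, Int) => Int = _ * _\n",
"assert(mul(3, 4) == 12)\n",
"\n",
"// A multi-statement anonymous function, written in curly braces.\n",
"val labels = List(5, 6, 7, 8).map { x =>\n",
"  val doubled = x * 2\n",
"  s\"$x doubled is $doubled\"\n",
"}\n",
"assert(labels.head == \"5 doubled is 10\")\n",
"\n",
"// filter with a placeholder: keep only the even values.\n",
"assert(List(5, 6, 7, 8).filter(_ % 2 == 0) == List(6, 8))"
]
},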
{
"cell_type": "markdown",
"metadata": {},
"source": [
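"**Exercise: Sequence Manipulation**\n",
"\n",
"A common set of higher-order functions you'll use are `scanLeft`/`scanRight`, `reduceLeft`/`reduceRight`, and `foldLeft`/`foldRight`. It's important to understand how each one works and when to use them. On a `List`, the undirected versions `scan`, `reduce`, and `fold` behave like their left-hand versions. However, the standard library documentation for `reduce` and `fold` explicitly leaves the order of operations unspecified. Use the directional versions whenever order matters, for example with a non-associative operation such as subtraction."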
]
},
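{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Example: Left vs. Right**\n",
"\n",
"This small sketch shows why direction matters, using an illustrative list `nums` rather than the exercise's data. Subtraction is not associative, so the left and right versions of `reduce` and `fold` give different answers on the same list. The comments spell out the order in which each version applies the operation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"val nums = List(1, 2, 3)\n",
"\n",
"// reduceLeft: (1 - 2) - 3 = -4\n",
"assert(nums.reduceLeft(_ - _) == -4)\n",
"// reduceRight: 1 - (2 - 3) = 2\n",
"assert(nums.reduceRight(_ - _) == 2)\n",
"\n",
"// foldLeft starts from the initial value on the left: ((10 - 1) - 2) - 3 = 4\n",
"assert(nums.foldLeft(10)(_ - _) == 4)\n",
"// foldRight starts from the initial value on the right: 1 - (2 - (3 - 10)) = -8\n",
"assert(nums.foldRight(10)(_ - _) == -8)\n",
"\n",
"// scan keeps every intermediate result, including the initial value.\n",
"assert(nums.scanLeft(0)(_ + _) == List(0, 1, 3, 6))\n",
"assert(nums.scanRight(0)(_ + _) == List(6, 5, 3, 0))"
]
},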
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"val exList = List(1, 5, 7, 100)\n",
"\n",
"// write a custom function to add two numbers, then use reduce to find the sum of all values in exList\n",
"def add(a: Int, b: Int): Int = ???\n",
"val sum = ???\n",
"\n",
"// find the sum of exList using an anonymous function (hint: you've seen this before!)\n",
"val anon_sum = ???\n",
"\n",
"// find the moving average of exList from right to left using scan; make the result (ma2) a list of doubles\n",
"def avg(a: Int, b: Double): Double = ???\n",
"val ma2 = ???"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert(add(88, 88) == 176)\n",
"assert(sum == 113)\n",
"\n",
"assert(anon_sum == 113)\n",
"\n",
"assert(avg(100, 100.0) == 100.0)\n",
"assert(ma2 == List(8.875, 16.75, 28.5, 50.0, 0.0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Solution**\n",
"\n",
"    def add(a: Int, b: Int): Int = a + b\n",
"    val sum = exList.reduce(add)\n",
"    \n",
"    val anon_sum = exList.reduce(_ + _)\n",
"    \n",
"    def avg(a: Int, b: Double): Double = (a + b)/2.0\n",
"    val ma2 = exList.scanRight(0.0)(avg)\n"