{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This is the second part of a tutorial describing the Iteration/Expression Tree (IET), an intermediate representation based on Abstract Syntax Trees (AST) used in the Devito compiler. In this part we will learn how to build, compose, and transform IETs.\n", "\n", "# Part II - Bottom Up" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Dimensions` are the building blocks of both `Iterations` and `Expressions`.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'i': i, 'j': j, 'k': k, 't0': t0, 't1': t1}" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from devito import SpaceDimension, TimeDimension\n", "\n", "dims = {'i': SpaceDimension(name='i'),\n", " 'j': SpaceDimension(name='j'),\n", " 'k': SpaceDimension(name='k'),\n", " 't0': TimeDimension(name='t0'),\n", " 't1': TimeDimension(name='t1')}\n", "\n", "dims" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Elements such as `Scalars`, `Constants` and `Functions` are used to build SymPy equations." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'a': a, 'b': b, 'c': c[i], 'd': d[j, k], 'e': e[t0, t1, i], 'f': f[t, x, y]}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from devito.types import Array, Scalar\n", "from devito import Grid, Constant, Function, TimeFunction\n", "\n", "grid = Grid(shape=(10, 10))\n", "symbs = {'a': Scalar(name='a'),\n", "         'b': Constant(name='b'),\n", "         'c': Array(name='c', shape=(3,), dimensions=(dims['i'],)).indexify(),\n", "         'd': Array(name='d',\n", "                    shape=(3, 3),\n", "                    dimensions=(dims['j'], dims['k'])).indexify(),\n", "         'e': Function(name='e',\n", "                       shape=(3, 3, 3),\n", "                       dimensions=(dims['t0'], dims['t1'], dims['i'])).indexify(),\n", "         'f': TimeFunction(name='f', grid=grid).indexify()}\n", "symbs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An IET `Expression` wraps a SymPy equation. Below, `DummyEq` is a subclass of `sympy.Eq` with some metadata attached. Which metadata is attached, and when and how, is not relevant here." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n" ] } ], "source": [ "from devito.ir.iet import Expression\n", "from devito.ir.equations import DummyEq\n", "from devito.tools import pprint\n", "\n", "def get_exprs(a, b, c, d, e, f):\n", "    return [Expression(DummyEq(a, b + c + 5.)),\n", "            Expression(DummyEq(d, e - f)),\n", "            Expression(DummyEq(a, 4 * (b * a))),\n", "            Expression(DummyEq(a, (6. / b) + (8. * a)))]\n", "\n", "exprs = get_exprs(symbs['a'],\n", "                  symbs['b'],\n", "                  symbs['c'],\n", "                  symbs['d'],\n", "                  symbs['e'],\n", "                  symbs['f'])\n", "\n", "pprint(exprs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An `Iteration` typically wraps one or more `Expression`s. 
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from devito.ir.iet import Iteration\n", "\n", "def get_iters(dims):\n", "    return [lambda ex: Iteration(ex, dims['i'], (0, 3, 1)),\n", "            lambda ex: Iteration(ex, dims['j'], (0, 5, 1)),\n", "            lambda ex: Iteration(ex, dims['k'], (0, 7, 1)),\n", "            lambda ex: Iteration(ex, dims['t0'], (0, 4, 1)),\n", "            lambda ex: Iteration(ex, dims['t1'], (0, 4, 1))]\n", "\n", "iters = get_iters(dims)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, we can see how blocks of `Iterations` over `Expressions` can be used to build loop nests. " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n" ] } ], "source": [ "def get_block1(exprs, iters):\n", "    # Perfect loop nest:\n", "    # for i\n", "    #   for j\n", "    #     for k\n", "    #       expr0\n", "    return iters[0](iters[1](iters[2](exprs[0])))\n", "\n", "def get_block2(exprs, iters):\n", "    # Non-perfect simple loop nest:\n", "    # for i\n", "    #   expr0\n", "    #   for j\n", "    #     for k\n", "    #       expr1\n", "    return iters[0]([exprs[0], iters[1](iters[2](exprs[1]))])\n", "\n", "def get_block3(exprs, iters):\n", "    # Non-perfect non-trivial loop nest:\n", "    # for i\n", "    #   for t0\n", "    #     expr0\n", "    #   for j\n", "    #     for k\n", "    #       expr1\n", "    #       expr2\n", "    #   for t1\n", "    #     expr3\n", "    return iters[0]([iters[3](exprs[0]),\n", "                     iters[1](iters[2]([exprs[1], exprs[2]])),\n", "                     iters[4](exprs[3])])\n", "\n", "block1 = get_block1(exprs, iters)\n", "block2 = get_block2(exprs, iters)\n", "block3 = get_block3(exprs, iters)\n", "\n", "pprint(block1), print('\\n')\n", "pprint(block2), print('\\n')\n", "pprint(block3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And, finally, we can build `Callable` _kernels_ that will be used to 
generate C code. Note that `Operator` is a subclass of `Callable`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "kernel no.1:\n", "void foo()\n", "{\n", " for (int i = 0; i <= 3; i += 1)\n", " {\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " a = b + c[i] + 5.0F;\n", " }\n", " }\n", " }\n", "}\n", "\n", "kernel no.2:\n", "void foo()\n", "{\n", " for (int i = 0; i <= 3; i += 1)\n", " {\n", " a = b + c[i] + 5.0F;\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " d[j][k] = e[t0][t1][i] - f[t][x][y];\n", " }\n", " }\n", " }\n", "}\n", "\n", "kernel no.3:\n", "void foo()\n", "{\n", " for (int i = 0; i <= 3; i += 1)\n", " {\n", " for (int t0 = 0; t0 <= 4; t0 += 1)\n", " {\n", " a = b + c[i] + 5.0F;\n", " }\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " d[j][k] = e[t0][t1][i] - f[t][x][y];\n", " a = 4*a*b;\n", " }\n", " }\n", " for (int t1 = 0; t1 <= 4; t1 += 1)\n", " {\n", " a = 8.0F*a + 6.0F/b;\n", " }\n", " }\n", "}\n", "\n" ] } ], "source": [ "from devito.ir.iet import Callable\n", "\n", "kernels = [Callable('foo', block1, 'void', ()),\n", "           Callable('foo', block2, 'void', ()),\n", "           Callable('foo', block3, 'void', ())]\n", "\n", "print('kernel no.1:\\n' + str(kernels[0].ccode) + '\\n')\n", "print('kernel no.2:\\n' + str(kernels[1].ccode) + '\\n')\n", "print('kernel no.3:\\n' + str(kernels[2].ccode) + '\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An IET is immutable. It can be \"transformed\" by replacing or dropping some of its inner nodes, but doing so actually creates a new IET. IETs are transformed by `Transformer` visitors. A `Transformer` takes as input a dictionary encoding replacement rules, mapping each node to be replaced to its replacement."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "void foo()\n", "{\n", " for (int i = 0; i <= 3; i += 1)\n", " {\n", " a = b + c[i] + 5.0F;\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " d[j][k] = e[t0][t1][i] - f[t][x][y];\n", " }\n", " }\n", " }\n", "}\n" ] } ], "source": [ "from devito.ir.iet import Transformer\n", "\n", "# Replaces a Callable's body with another\n", "transformer = Transformer({block1: block2})\n", "kernel_alt = transformer.visit(kernels[0])\n", "print(kernel_alt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specific `Expression`s within the loop nest can also be substituted." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "for (int i = 0; i <= 3; i += 1)\n", "{\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " d[j][k] = e[t0][t1][i] - f[t][x][y];\n", " }\n", " }\n", "}\n" ] } ], "source": [ "# Replaces an expression with another\n", "transformer = Transformer({exprs[0]: exprs[1]})\n", "newblock = transformer.visit(block1)\n", "newcode = str(newblock.ccode)\n", "print(newcode)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "for (int i = 0; i <= 3; i += 1)\n", "{\n", " a = b + c[i] + 5.0F;\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " // Replaced expression\n", " {\n", " }\n", " }\n", " }\n", "}\n" ] } ], "source": [ "from devito.ir.iet import Block\n", "import cgen as c\n", "\n", "# Creates a replacement node (a Block holding a C comment) for an expression\n", "line1 = '// Replaced expression'\n", "replacer = Block(c.Line(line1))\n", "transformer = Transformer({exprs[1]: replacer})\n", "newblock = transformer.visit(block2)\n", "newcode = 
str(newblock.ccode)\n", "print(newcode)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "for (int i = 0; i <= 3; i += 1)\n", "{\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " // This is the opening comment\n", " {\n", " a = b + c[i] + 5.0F;\n", " }\n", " // This is the closing comment\n", " }\n", " }\n", "}\n" ] } ], "source": [ "# Wraps an expression in comments\n", "line1 = '// This is the opening comment'\n", "line2 = '// This is the closing comment'\n", "wrapper = lambda n: Block(c.Line(line1), n, c.Line(line2))\n", "transformer = Transformer({exprs[0]: wrapper(exprs[0])})\n", "newblock = transformer.visit(block1)\n", "newcode = str(newblock.ccode)\n", "print(newcode)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [default]", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }