{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial we will learn how to build, compose, and transform Iteration/Expression Trees (IETs).\n", "\n", "# Part II - Bottom Up\n", "\n", "`Dimensions` are the building blocks of both `Iterations` and `Expressions`.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'i': i, 'j': j, 'k': k, 't0': t0, 't1': t1}" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from devito import SpaceDimension, TimeDimension\n", "\n", "dims = {'i': SpaceDimension(name='i'),\n", " 'j': SpaceDimension(name='j'),\n", " 'k': SpaceDimension(name='k'),\n", " 't0': TimeDimension(name='t0'),\n", " 't1': TimeDimension(name='t1')}\n", "\n", "dims" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Elements such as `Scalars`, `Constants` and `Functions` are used to build SymPy equations." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'a': a, 'b': b, 'c': c[i], 'd': d[j, k], 'e': e[t0, t1, i], 'f': f[t, x, y]}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from devito import Grid, Constant, Function, TimeFunction\n", "from devito.types import Array, Scalar\n", "\n", "grid = Grid(shape=(10, 10))\n", "symbs = {'a': Scalar(name='a'),\n", " 'b': Constant(name='b'),\n", " 'c': Array(name='c', shape=(3,), dimensions=(dims['i'],)).indexify(),\n", " 'd': Array(name='d', \n", " shape=(3,3), \n", " dimensions=(dims['j'],dims['k'])).indexify(),\n", " 'e': Function(name='e', \n", " shape=(3,3,3), \n", " dimensions=(dims['t0'],dims['t1'],dims['i'])).indexify(),\n", " 'f': TimeFunction(name='f', grid=grid).indexify()}\n", "symbs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An IET `Expression` wraps a SymPy equation. Below, `DummyEq` is a subclass of `sympy.Eq` with some metadata attached. 
The details of what metadata are attached, and when and how, are irrelevant here. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n", "\n" ] } ], "source": [ "from devito.ir.iet import Expression\n", "from devito.ir.equations import DummyEq\n", "from devito.tools import pprint\n", "\n", "def get_exprs(a, b, c, d, e, f):\n", " return [Expression(DummyEq(a, b + c + 5.)),\n", " Expression(DummyEq(d, e - f)),\n", " Expression(DummyEq(a, 4 * (b * a))),\n", " Expression(DummyEq(a, (6. / b) + (8. * a)))]\n", "\n", "exprs = get_exprs(symbs['a'],\n", " symbs['b'],\n", " symbs['c'],\n", " symbs['d'],\n", " symbs['e'],\n", " symbs['f'])\n", "\n", "pprint(exprs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An `Iteration` typically wraps one or more `Expression`s. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from devito.ir.iet import Iteration\n", "\n", "def get_iters(dims):\n", " return [lambda ex: Iteration(ex, dims['i'], (0, 3, 1)),\n", " lambda ex: Iteration(ex, dims['j'], (0, 5, 1)),\n", " lambda ex: Iteration(ex, dims['k'], (0, 7, 1)),\n", " lambda ex: Iteration(ex, dims['t0'], (0, 4, 1)),\n", " lambda ex: Iteration(ex, dims['t1'], (0, 4, 1))]\n", "\n", "iters = get_iters(dims)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, we can see how blocks of `Iterations` over `Expressions` can be used to build loop nests. 
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n" ] } ], "source": [ "def get_block1(exprs, iters):\n", " # Perfect loop nest:\n", " # for i\n", " # for j\n", " # for k\n", " # expr0\n", " return iters[0](iters[1](iters[2](exprs[0])))\n", " \n", "def get_block2(exprs, iters):\n", " # Non-perfect simple loop nest:\n", " # for i\n", " # expr0\n", " # for j\n", " # for k\n", " # expr1\n", " return iters[0]([exprs[0], iters[1](iters[2](exprs[1]))])\n", "\n", "def get_block3(exprs, iters):\n", " # Non-perfect non-trivial loop nest:\n", " # for i\n", " # for s\n", " # expr0\n", " # for j\n", " # for k\n", " # expr1\n", " # expr2\n", " # for p\n", " # expr3\n", " return iters[0]([iters[3](exprs[0]),\n", " iters[1](iters[2]([exprs[1], exprs[2]])),\n", " iters[4](exprs[3])])\n", "\n", "block1 = get_block1(exprs, iters)\n", "block2 = get_block2(exprs, iters)\n", "block3 = get_block3(exprs, iters)\n", "\n", "pprint(block1), print('\\n')\n", "pprint(block2), print('\\n')\n", "pprint(block3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And, finally, we can build `Callable` _kernels_ that will be used to generate C code. Note that `Operator` is a subclass of `Callable`." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "kernel no.1:\n", "void foo()\n", "{\n", " for (int i = 0; i <= 3; i += 1)\n", " {\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " a = b + c[i] + 5.0F;\n", " }\n", " }\n", " }\n", "}\n", "\n", "kernel no.2:\n", "void foo()\n", "{\n", " for (int i = 0; i <= 3; i += 1)\n", " {\n", " a = b + c[i] + 5.0F;\n", "\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " d[j][k] = e[t0][t1][i] - f[t][x][y];\n", " }\n", " }\n", " }\n", "}\n", "\n", "kernel no.3:\n", "void foo()\n", "{\n", " for (int i = 0; i <= 3; i += 1)\n", " {\n", " for (int t0 = 0; t0 <= 4; t0 += 1)\n", " {\n", " a = b + c[i] + 5.0F;\n", " }\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " d[j][k] = e[t0][t1][i] - f[t][x][y];\n", " a = 4*b*a;\n", " }\n", " }\n", " for (int t1 = 0; t1 <= 4; t1 += 1)\n", " {\n", " a = 8.0F*a + 6.0F/b;\n", " }\n", " }\n", "}\n", "\n" ] } ], "source": [ "from devito.ir.iet import Callable\n", "\n", "kernels = [Callable('foo', block1, 'void', ()),\n", " Callable('foo', block2, 'void', ()),\n", " Callable('foo', block3, 'void', ())]\n", "\n", "print('kernel no.1:\\n' + str(kernels[0].ccode) + '\\n')\n", "print('kernel no.2:\\n' + str(kernels[1].ccode) + '\\n')\n", "print('kernel no.3:\\n' + str(kernels[2].ccode) + '\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An IET is immutable. It can be \"transformed\" by replacing or dropping some of its inner nodes, but this actually creates a new IET and leaves the original untouched. IETs are transformed by `Transformer` visitors. A `Transformer` takes as input a dictionary encoding replacement rules: each key is a node to be replaced, and the corresponding value is its replacement."
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "void foo()\n", "{\n", " for (int i = 0; i <= 3; i += 1)\n", " {\n", " a = b + c[i] + 5.0F;\n", "\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " d[j][k] = e[t0][t1][i] - f[t][x][y];\n", " }\n", " }\n", " }\n", "}\n" ] } ], "source": [ "from devito.ir.iet import Transformer\n", "\n", "# Replaces a Callable's body with another\n", "transformer = Transformer({block1: block2})\n", "kernel_alt = transformer.visit(kernels[0])\n", "print(kernel_alt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specific `Expression`s within the loop nest can also be substituted." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "for (int i = 0; i <= 3; i += 1)\n", "{\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " d[j][k] = e[t0][t1][i] - f[t][x][y];\n", " }\n", " }\n", "}\n" ] } ], "source": [ "# Replaces an expression with another\n", "transformer = Transformer({exprs[0]: exprs[1]})\n", "newblock = transformer.visit(block1)\n", "newcode = str(newblock.ccode)\n", "print(newcode)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "for (int i = 0; i <= 3; i += 1)\n", "{\n", " a = b + c[i] + 5.0F;\n", "\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " // Replaced expression\n", " {\n", " }\n", " }\n", " }\n", "}\n" ] } ], "source": [ "from devito.ir.iet import Block\n", "import cgen as c\n", "\n", "# Replaces an expression with an empty block preceded by a comment\n", "line1 = '// Replaced expression'\n", "replacer = Block(c.Line(line1))\n", "transformer = Transformer({exprs[1]: replacer})\n", "newblock = transformer.visit(block2)\n", "newcode = 
str(newblock.ccode)\n", "print(newcode)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "for (int i = 0; i <= 3; i += 1)\n", "{\n", " for (int j = 0; j <= 5; j += 1)\n", " {\n", " for (int k = 0; k <= 7; k += 1)\n", " {\n", " // This is the opening comment\n", " {\n", " a = b + c[i] + 5.0F;\n", " }\n", " // This is the closing comment\n", " }\n", " }\n", "}\n" ] } ], "source": [ "# Wraps an expression in comments\n", "line1 = '// This is the opening comment'\n", "line2 = '// This is the closing comment'\n", "wrapper = lambda n: Block(c.Line(line1), n, c.Line(line2))\n", "transformer = Transformer({exprs[0]: wrapper(exprs[0])})\n", "newblock = transformer.visit(block1)\n", "newcode = str(newblock.ccode)\n", "print(newcode)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 2 }