{ "cells": [ { "cell_type": "markdown", "id": "57be61cb", "metadata": {}, "source": [ "# JAX Parallelization Tricks\n", "\n", "\n", "While differentiability is a great reason to use JAX, its APIs\n", "for parallelization are arguably an equally good one.\n", "\n", "In standard..\n" ] }, { "cell_type": "markdown", "id": "5b03fa84", "metadata": {}, "source": [ "## An Example\n", "\n", "Let's say you have a function that computes something\n", "interesting for some input `x`.\n", "\n", "For example, we can use the textbook definition\n", "of a dense layer in a neural network:\n", "\n", "$$y = W x + b$$\n", "\n", "or, written with indices (with an implicit sum over the repeated index $j$):\n", "\n", "$$y_{i} = W_{ij} x_{j} + b_{i}$$\n", "\n" ] }, { "cell_type": "code", "execution_count": 90, "id": "b964aa1a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1.1 1.2] -> [0.90000004 1.53 ]\n" ] } ], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "def dense_layer(x):\n", "    # Fixed weights and bias; x is a single input vector of shape (2,)\n", "    W = jnp.array([[0.2, 0.4], [0.9, 0.2]])\n", "    b = jnp.array([0.2, 0.3])\n", "    return W.dot(x) + b\n", "\n", "inputs = jnp.array([1.1, 1.2])\n", "out = dense_layer(inputs)\n", "print(f'{inputs} -> {out}')" ] }, { "cell_type": "markdown", "id": "0a360820", "metadata": {}, "source": [ "In practice, however, we usually want to evaluate the layer on a mini-batch of inputs,\n", "so $x$ gains a leading batch index $b$ (not to be confused with the bias $b_{i}$):\n", "\n", "$$x = x_{bi}$$\n", "\n", "In index notation, the linear layer then reads\n", "\n", "$$y_{bi} = W_{ij} x_{bj} + b_{i}$$\n", "\n", "but this is not compatible with our code `W.dot(x)` above, because the batch index\n", "is in the way: `W.dot(x)` contracts the last axis of `W` with the first axis of `x`,\n", "which is now the batch axis." ] }, { "cell_type": "code", "execution_count": 28, "id": "6a837bf0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Incompatible shapes for dot: got (2, 2) and (3, 2).\n" ] } ], "source": [ "inputs = jnp.array([\n", "    [1.1, 1.2],\n", "    [2.1, 2.2],\n", "    [3.1, 3.2]\n", "])\n", "\n", "# dense_layer expects a single vector, so a (3, 2) batch fails\n", "try:\n", "    dense_layer(inputs)\n", 
"except TypeError as exc:\n", "    print(exc)" ] }, { "cell_type": "markdown", "id": "4ace1791", "metadata": {}, "source": [ "## Workarounds\n", "\n", "### Einsum\n", "\n", "We can work around this by adjusting our dense layer code and using `einsum`\n", "to spell out the new batch dimension explicitly:" ] }, { "cell_type": "code", "execution_count": 44, "id": "86ab5297", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.90000004, 1.53 ],\n", " [1.5000001 , 2.6299999 ],\n", " [2.1000001 , 3.7299998 ]], dtype=float32)" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def einsum_layer(x):\n", "    W = jnp.array([[0.2, 0.4], [0.9, 0.2]])\n", "    b = jnp.array([0.2, 0.3])\n", "    # 'ij,bj->bi': contract over j, keep the batch index b in front\n", "    return jnp.einsum('ij,bj->bi', W, x) + b\n", "\n", "einsum_layer(inputs)" ] }, { "cell_type": "markdown", "id": "f7e68267", "metadata": {}, "source": [ "### Transposition\n", "\n", "Another workaround you'll see frequently in textbooks is to interchange\n", "the order of `W` and `x` such that the batch dimension is \"up front\" and\n", "won't be in the way:\n", "\n", "$$ y = xW^T + b $$\n", "\n", "which, with indices, looks like\n", "\n", "$$ y_{bi} = x_{bj}W^T_{ji} + b_{i} = W_{ij}x_{bj} + b_{i}$$\n", "\n", "\n", "This works too, but again at the cost of changing the code to accommodate batching.\n" ] }, { "cell_type": "code", "execution_count": 45, "id": "0f974f37", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.90000004, 1.53 ],\n", " [1.5000001 , 2.6299999 ],\n", " [2.1000001 , 3.7299998 ]], dtype=float32)" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def transpose_layer(x):\n", "    W = jnp.array([[0.2, 0.4], [0.9, 0.2]])\n", "    b = jnp.array([0.2, 0.3])\n", "    # x has shape (batch, 2); b broadcasts over the batch axis\n", "    return x.dot(W.T) + b\n", "\n", "\n", "transpose_layer(inputs)" ] }, { "cell_type": "markdown", "id": "7cfa1b16", "metadata": {}, "source": [ "## The JAX Way\n", "\n", "In JAX, batching could not be easier!\n", "\n", "There is no 
change to the function itself: you just apply the\n", "batching transformation `jax.vmap` to it and receive a \"batched\" version of it." ] }, { "cell_type": "code", "execution_count": 48, "id": "f9371639", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[0.90000004, 1.53 ],\n", " [1.5000001 , 2.6299999 ],\n", " [2.1000001 , 3.7299998 ]], dtype=float32)" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# By default, vmap maps over the leading axis (axis 0) of the input\n", "batched_dense = jax.vmap(dense_layer)\n", "batched_dense(inputs)" ] }, { "cell_type": "markdown", "id": "320419ef", "metadata": {}, "source": [ "# Handling multiple inputs\n", "\n", "Sometimes a function has more than one input\n", "\n", "$$ f = f(x,y) $$\n", "\n", "and for such functions we must tell JAX more about\n", "how batching should take place." ] }, { "cell_type": "markdown", "id": "69a40d36", "metadata": {}, "source": [ "## Zipping\n", "\n", "When zipping a function, we want both arguments to be iterated \"in lock-step\".\n", "\n", "In standard Python it would look something like this:\n", "\n", "`out = [f(xi, yi) for xi, yi in zip(x, y)]`\n", "\n", "In JAX, we specify the batch dimension of each argument (here: 0) via the `in_axes` argument:" ] }, { "cell_type": "code", "execution_count": 71, "id": "03fdf36a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([5, 7, 9], dtype=int32)" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def func(x, y):\n", "    return x + y\n", "\n", "x_batched = jnp.array([1, 2, 3])\n", "y_batched = jnp.array([4, 5, 6])\n", "# Map over axis 0 of both arguments in lock-step\n", "jax.vmap(func, in_axes=(0, 0))(x_batched, y_batched)" ] }, { "cell_type": "markdown", "id": "a3635cc4", "metadata": {}, "source": [ "## Non-leading Batch dimensions\n", "\n", "It can happen that the batch dimension sits at a different position for different arguments,\n", "\n", "for example for\n", "\n", "```\n", "x = [[1,2,3]]     # shape (1,3)\n", "y = [[4],[5],[6]] # shape (3,1)\n", "```\n", "\n", "In this case our Python 
code would look like this:\n", "\n", "`out = [f(x[0][i], y[i][0]) for i in range(3)]`\n", "\n", "That is, for `x` the batch dimension is the second dimension of the array\n", "(index 1, counting from 0).\n", "\n", "For `y` it is the first dimension of the array (index 0).\n", "\n", "We can communicate this to JAX via `in_axes`, which says where the batch dimension is\n", "for each input, and `out_axes`, which says where we want the batch dimension to be in the output:" ] }, { "cell_type": "code", "execution_count": 72, "id": "18c04558", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[5 7 9]]\n", "(1, 3)\n" ] } ], "source": [ "x_batched = jnp.array([[1, 2, 3]])\n", "y_batched = jnp.array([[4], [5], [6]])\n", "out1 = jax.vmap(func, in_axes=(1, 0), out_axes=1)(x_batched, y_batched)\n", "print(out1)\n", "print(out1.shape)  # batch dimension of size 3 at second position (idx 1)" ] }, { "cell_type": "code", "execution_count": 73, "id": "889ab85c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[5]\n", " [7]\n", " [9]]\n", "(3, 1)\n" ] } ], "source": [ "out2 = jax.vmap(func, in_axes=(1, 0), out_axes=0)(x_batched, y_batched)\n", "print(out2)\n", "print(out2.shape)  # batch dimension of size 3 at first position (idx 0)" ] }, { "cell_type": "markdown", "id": "4a1ca4da", "metadata": {}, "source": [ "## Un-batched dimensions\n", "\n", "You can also have un-batched arguments that just \"go along for the ride\",\n", "\n", "e.g.\n", "\n", "```\n", "x = 3.0\n", "out = [f(x,y) for y in range(..)]\n", "```\n", "\n", "In JAX, you express this by passing `None` in `in_axes` for the argument that should not be batched:" ] }, { "cell_type": "code", "execution_count": 77, "id": "c4a7a0c8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([ 9., 12., 15.], dtype=float32, weak_type=True)" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def func(x, y):\n", "    return x * y\n", "\n", "# x is passed unchanged to every call; y is mapped over axis 0\n", "jax.vmap(func, in_axes=(None, 0))(3.0, jnp.array([3, 4, 5]))" ] }, { "cell_type": 
"markdown", "id": "2d241761", "metadata": {}, "source": [ "## Composition and Cartesian Products\n", "\n", "For a function $f(x,y)$ you may want to evalute it on a \"grid\"\n", "\n", "\n", "```\n", "x = 3.0\n", "out = [f(x[i],y[j]) for i,j in cartesian_product(...)]\n", "```\n", "\n", "by applying `vmap` multiple times" ] }, { "cell_type": "code", "execution_count": 99, "id": "d645f5a2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD8CAYAAAB0IB+mAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAA2LklEQVR4nO2df6xlV3Xfv2vem5n3xjM2Brs2tqdAGquqFVJoLdOISqWBpMZFuElDBagR5IemkbAKElFrgkQl/qKtlLQpKGREUEiLMFUSF7dMIdAQ0aiFenDND+O4nVo09tgG/6D2jOe9mXnvrf7x7nnvvPP22Xvtvdf+ce7dH+lp7rv33HPOu/Nd37P22uvsS8yMRqPRaMw/B0qfQKPRaDTy0Ay/0Wg0FoRm+I1Go7EgNMNvNBqNBaEZfqPRaCwIzfAbjUZjQYg2fCI6TkRfIaLvEtFDRPRewzZERL9JRGeI6FtE9Ndij9toNBoNP5YV9rEB4P3M/AARHQPwDSL6EjN/t7fNmwHcPPt5HYDfmv3baDQajUxEZ/jM/CQzPzB7fA7AwwBuHGx2J4Df422+BuAlRPTy2GM3Go1GQ45Ghr8DEb0SwGsBfH3w0o0AHuv9/vjsuScN+zgB4AQAHDmCv/4jf2kZBynsunSZt+yvj1zvNnhpz++XeHn0tcu8u4/NraXedjR7rvd6b9udU9ui3Z11j3s3P9PW3n+3t+u9zvtf3/uYjc9ja+8d1rQ1+Kw2Rz674XZDxt5nYsnx/3rA8LrhPdzf7gANXus/pn3P73m9/1bT693jPdvxvsd9uS7NPvSlA7ufyzLtvmfpwObO44PU32b3eQA4RBt7fh++DgAHYf/sQ+MICIslWxwNX3fF0fbz29sY4wjYjZ9+TNniwxBH/ddFcQTsiaV9cbR94vufSxhHL2w88wwzX2vaVM3wiegogD8A8D5mfiF0P8x8EsBJAHj1jx/kz526BgBw3dIh7319f/OS9fWnNlf3P7dx1b7nnrh89e4+L+99/elLR3ceP3tx+/FzF3f3+8L6ys7jc2u7jy+tb3/0m+u7/wW0ti30pfVdwS6tbT9eXt895tKF3cfd88tru6I7eGH38fJaz0Re3GsSyxcu7zw+8OJFDKFzF/Y9h/OG5wxsnTtnfP7AsWOi9wMAjh7Z9xQf2/vc1hWH9/y+ceTg7uMr9hrOxupuYFw+QrPndj/rjd3/HmweMTy3uv25bq7sfr68uvuZLq1sm/KhlV1zPra6/R905cruf+BLD6/tPH7Z4fM7j689tPsYAK47+PzO4xsO/nDPa9cvP48h1y+t7XtuSK44AvbHUj+OgL2xZIojwBxLpjgCdmOpiyNgN5a6OAJ2Y8YUR4B/LPXjCNgfS8Y4AkSxNBZHwHgsfeGZ3/6/Y+9RMXwiOohts/80M/+hYZOzAI73fr9p9pyYTnQSwboE
CoyLdMhQpH3GRFqCjVXaI1TjNlcs7TP9jq0rDhtNfx9Hj4iE6mXsY8cZMDR7H0xmn5MX1lf2mL6Jpy8d3Wf6HU9cvnqP6T+1cdU+039qc9Vp+v3YcMWSJI7G8DH7eYePHRk3fQfRcTQg2vCJiAD8DoCHmfnXRza7D8BdRHQPtidrn2fmfeUcCTEidGHK7lPQz0qkbKzszUJsXD5CO5nJxuqBPZmJDzFCjcJg9hL62X3NPHdxdSfLf/bi0T1Zvg+hpt+hEUvSxMlGP3Eaoz9SNrEZEFObR/Zm+aFsHDm4J8vXTp400cjwXw/g5wF8m4genD33awD+IgAw88cBnAJwB4AzAC4A+AWF4wYTmt3PW1ZSpVBHzN6U3Q/LOVpshg8kVPn+5av2lHWk+Jh+CmISp7FyztyS2fSjDZ+Z/xR7p7FM2zCA98QeSwNpvbGxzWiW3xmzplgDM3sTw/r9PDAs6wDmLB9Ib/o+mf08Jk628qgJ62g5o+kv1J22GsPPDskwtE9/oqkkIUZorZ1rmbRlP5Ls3lbO6dfvcxBSskuBpt6l+51q4rQhCE+bjob6M40+s8SRg4UxfF+R+mQluYehmiUHiVCdxIo1k9iHjHXolGSYSEiyYZvJapt+bBwN8U2cctDXhfYEf1LTP3rEuY+qDX+sT96HpzZXi2UkrommXGhkuM4OmRCxCgSqXbvP3aEjGdn5dHi5DNSElulr7EeznFN6JOUaLQcnT76x5PGeqg0fcBu2670hTCkr6WemWpnJmFBFpi8RnnA7aRvmcJRSa/2+nwDEjARN+nQlLjFx1L3f+rpC4iQZKddSGjUh7RIT6Vpq+p4XhzqKjQJ8J6Ek4paKdB4mmfoMJ5yG3TrRmER4/oJK6SZVZ05NDPvxQ7t1THRxIY2lmDjSSpxqGSmHMNb5Jmp5TlDqnIzhA7r1yFCRDtG64ap/d2ANRAnVhKd4Y26y6pN7wjYXPh07JkzGn2qSt2MeEqeY+1pqYD6jwUHM8NOWldTcN+wyPp+blrTM2Hf/pux+KuUcF1qJg6+2u1JPcNlUKXFKTbcsRgi28ui+5TsMcRRcIk3AQhr+GCaRzkNWYkJijLbySSqxpgyCEksqhCLp1hkz1VytkT7HccVR6aVJhvTnw1KS2/QXzvBzBYOk7hhyO7gPGgaX0/Rt+5Nk9z7kCmgbKUeEqXVu278kuy85Uh5rwZW25oaMlkskTyYWyvB9RWpbGROQZyVjq/ulwGZkmmWdDj52REWwvmZvwrY6pvV9GecEYyYgfbJ8IJ3px8ZRTWjc0yItI7pMP3UcAQti+E9tXBWdkcwjGkLtCBWrltC16BuA1oVA2ko4TCA02n81Td8VR6HYEqdaOnR8R8uho8/UcVS14W/wUrTAQt/vm5WUnLDVzFDHhKpt+lKBjh3XN6Biy1u+k345bwpyJSwaRi15vzRxquk+ltxI40gaS74J0yTaMjuxSVvO+u9xMVWRSpd27S+VHItkNU2T+Lo2Tt/sxcfsQ8s5mmyuL+98CYqU/lLJEsZ68k1tmkNC4qj/PhtjcaRZzslZGnUxbM80LaY2dn+LdFXaFCPfSRh+x1CwsVlLjEinMAw1ESPUUEKEO5UbrGhtac+3XsUyXB/f9KUosTdi9eNm33r6ATHlUxJ1JU41tTZLvlQoFPFS5MpMyvA7NOqIU6zbS78EZepCtZm9JLt3UUOHTgo6Tbsy/T4pu3liEycpqbvdpPgmTyVMv+oafipsZm8S6ZSykhBMhmmrkafMvlPsu7b+e98RoEl/NjPNmcykPFaJkbJtPmyooynexT29M45EQ6A+WUkNiz2FCjW36bv2KZ2o9QnEGpZFTpEw5DB97cRJg9JLlNSUPJlYOMO3UUqkNmJuCfchZEmCrSsOqwk21OynspSCz4VfmlC4SiYpTd/X7E0M/86pj5RD0YwjFyqGT0SfJKIfENF3Rl5/AxE9T0QPzn4+pHFcH564fHUSkQ6xDUNTdhakylQlWXWs
WDXFPsVhtoSxxCO36bviaIzYxCnlSNl285VrvsekN98svyPW+CXv14qO3wVwu2Ob/8rMr5n9fFiy00scb5ChAgVkIq01K9ESKpBGrN32kvfEZPf7y1njn0stX14O6K4to2X6kv3UfFetBjHzQdKSZGgsSVAxfGb+KoDnNPY1JNSwfd437yIF4icufcU6JkJfMcesl1Mzw5GgJHEIzfKB3XhIGUtj52E6b1c5p+bW5iHayVOHK1ZCRgQ5+5l+goi+CeAJAL/KzA/5vFnSchYi5hiRlkZ685UPptayndcC+vNjhqjWLyY3DZsrK+dcWl/GIc8bsULx6c13xZJmHKUg+XpUwvbn4P17xpFmyTOX4T8A4BXMfJ6I7gDwHwDcbNqQiE4AOAEAV9+w/w/VrEmmFqmt7lhD73DNX+bga/YmXKOaGjp0bAxvwgLMN2KFohVLtjiqNXHyMfXhfS3Su9e1kycNsqREzPwCM5+fPT4F4CARXTOy7UlmvpWZbz169aFk5+QrUhO1DUOHBjasV0vLOjZDzVFiCTmGJLv3ueEqxcXANfEonQ+ylXYWoTxZIyGjyxLlyiyGT0TXExHNHt82O+6zOY5tIiQoashKUlCbUF37nkobpgRJghCiu5ymnyNxquFellBces1t+lptmZ8B8N8B/GUiepyIfomIfoWIfmW2yc8B+M6shv+bAN7OzHkazAe4gqF0330tSISqLdZQszddtErdXZurVOfSaQ7TDzH7WhMn3w4tk77GkqeaTF9Fncz8DsfrHwXwUY1jxRBq9iaR+rZjak00xU4omeqPY7V8Ww1yZ5uZWGPrkalF7yrn1NKSaVo901TLB9z1/NhF1mz71SK2rTnVBXYYZynXpwLyxVFdbQ2JyFHbrHUYGrtQmLSEElR3n40SJO/1ye5Fx67jvyf5vI+m7qVxNLXsPhTNLH9nu4CRs08cVW34G7wUZdY+702Z3adkmJlKjMxHqIC/WG3i8xGn7/E7alosbTiyC00ExsxSUoLs4iDG/GPjSIrrAlh6HfwYfHQsiZGQi8NkPr1OcJIhqq+wW91+G1ubpqS8s+89CmUaa8dQZX33Wvh+KYpPq2Y/NmyxFHJxsMVRaOJUy0gZkJd1NONIu9Q5GcPv0C7N+IrUROl2zFyEmH7s8Xwxj17qyfhjGKvlA2H9+ZqxpBFHU2SsJ187edJiPlMkISEiTZGVpF7S1WR4Y2WPWrJmZ2dDxHnWUr/vMCUMNZUNXWiNkFMnTr4rz6bUSan24jqiuwA5yzi5645TF2qo2YfW7mvp0JFgy5afvnQ0e3nSdbyYxCk1wzgJ1UFo8lTC9BfS8DVFGpKVlFhWQTPL37hiKZlYtfcbUs5JPQKQjgDHTNFVIpmXOama6vcd2uXB3Ka/cIY/L8HQJ0WGKlqqQNH4pfvSzu4lxH4JjWSEp13OSK1zyWgideJUEzEl0pQJ1JCFMXzpcDd2CJojK9HMQG1GKf4qxEjBits+Pev2pmwsRfY+nIPRHMGFZvlAmhJPbBzFUKIl06QX3yw/Rxx173W9v2rDv8wHogXrI3pfkdaclWgI1fuY0iy9J04Ns6+p7z4nUr1qmL5WHGnW7mtYcbaPRvIE+Bm/70Wirk9shE5oPm1nviLPJdIQNlcZS2vpTM223GvIEsr6dfh0eUmpCdtzays4trp3jYwX1ldw5cr+dTNsffm2Vs0+w3iQxpJmHI1hSpxCRsoa3W4pvmOiwzeWUpR5JmH4Hanqkin7hGu7M3Ds5hFt08/FeO10cbJ+qen3SRFLrjgqnTiFYlq/KiSOtt9XNpaqLunkIFSkWlmJFqZMVbNevbF6oEiPvuYxS/bfSy/8Y2XC2s1SM46mjvOLdwre67LQhl8iI6mt7thHUgvPJVbJBWbq2b1mglDyblbtY5s+l1Qj5ZgkIEZnpUx/YQ0/RqRTz0psQq3B9CX7TzVRW9tduH1cCcizF49mN37J8eYhu/fVRQ1xZGIhDT9GpGPkzEoA
uQBTGViKEo90n/ZuiJGs3/A5pJywjRnJ2YxQostcpp8ijmrARxcayVNO418ow5dmQDaR1pyV5BTq7n7iBVtqfiA3pgRgrKwTq7OU2X7KfUvLXDWXRvuIv0daKQZc+5jGpxaJjzinmJGkwNVtMKQTmrQDIVTcWtn9lPFZPjmkg8e2Lx+0EqdaOt3Gvm1O89uw+nEhiSXfOKrjkxxhc2tpR2S+ok2RgYyJtMY1P/qECtXX9Lf3mS5Tn9cbrEw9+cB4Xz7gb/p9csTS1BMnzX78kDgC0sSSiuET0ScBvAXAD5j5xwyvE4B/DeAOABcAvJuZH/A5Ro66pKZIpVlJ6qWRXaQw/RS4W938LgZTWiFzDN8vSulIHUuuOCqZOJluYoz9nmhgOnGkdQn5XQC3W15/M4CbZz8nAPyW0nHVyCHSHHXHMSOLKWuUzqxjzN7n7zZO7EYunDbEtzzhKn3UlkmHxtEYtZRzOmx6ciUdl49Q8VhSMXxm/iqA5yyb3Ang93ibrwF4CRG9XOPYsTx3cVVdpLnQqk1LsuNSYi0dIDH4XOBticNUTD/mPGovi2pSUtO5WiNuBPBY7/fHZ8/tg4hOENFpIjp98f/5D1d9SBUotWUlHTHZSUcusUovMCHZfYpyTunSXEnTlyRNgF7ilKtDJ2S0XFscDamuF46ZTzLzrcx86+GXpBOxNEBsIs2ZlfiUFkINrQax+owkpnJH7ZCxhCAmywfkxqtJyjjySZxKX3D7+MRRbuPPZfhnARzv/X7T7Lns+ARFLWaviVYZqBOrlmC1xT9vrZiAPEPOZfwacVQLvnpxbe+TjOQ0/ly1h/sA3EVE9wB4HYDnmfnJTMcG4D/kDRXpWFYylRtFQnqKO7H6diGEijwms08xqa3JWItmCJ3mQzp5XPvUovbEydae6eru8Y2lfjykiiWttszPAHgDgGuI6HEA/wzAQQBg5o8DOIXtlswz2G7L/AXJfjeYglvPgHBxusy+JpGOiS6nUDtsojt4gVWyGInZ12LewPaFfmllY9/zl9aXccjwvItOm2P9+SY0jD9VLI0x9cSpI0UsxaDy6THzOxyvM4D3hO4/Z11yCsPPHGjePQjoCHhqZh+KJMu33ZQ1himO+hcB7TibUuJkIyZ5AvRjKYbqJm1LIjF7m0hr686xTd5KjHFjlaqZGNU4jyndbCUxQ43kpKv35zb7kow1QKRMFmqJo2b4M2LN3sbYMLR0Z4F4xc3CYpUePzRgR9s3lW+6GqKRINRmrC+sr8xd4gTEJ09A+TgCmuED0AmakiK1CU4rqy2V7WuZfcnsPrTuLE0wajF96XlMpZTjg4/plzT+hTZ8aTYC1CHSFBmndztaBrF2QZE6s9cmZMTmShSmYPqacWT7PGqfsPVaxqOQ8S+k4fsIFIgTaQ24stsQ008h1lT71RiOl8TH9H21HUsNowvN0mjsaLnGBKpP3U6VAF+Bxmb2ubKS2BX/Qt7fF2tMF0LUd4OW9xsRY+2ZQHiLpo2Q9s2Q/fsw9cRJim8sDfUfGkuSOKr6E97cih+AxGQgErOfikgl63vHXDTGxDYUr2ZGIzH7KXXm2Ai9IUvT+FPHko1ayjnSdfK1Yslm/iGxVMenaGEoMolwNYaZNdTsS6CxNvie/aUo/Qj/a7RLWSmRZPkxd+GaYmIslrTLNDUmTqZ18TtcMZDD9Hf2oRw/1Rv+kBw1Q6nZu0Rqy0pStGTWJNRU5DLpFBPktrIOIDd9ACrLL9QSS1MZJYdQWywt5KStDS2zT0WsEUlLHBsrdWXAgN/5lCrl5Lq3YgojUK1zrKWc08dHXzXFUTP8HppBVEqk2nXtGozf9xxSdFPkwieRqNX0z62tZEucUl1g5zGOgAmWdFLgGzjzPAQdoxNr7uGpb5DUPknrKusAfl07miWeWFLEUY3ZfR/fLzsvFUcdC53h+2QiHVMQacrulRxZSpcNpTL7GjItF75JRclsPySOasBVHtVqDhjbdwkd1n35TESo
OLUy+9Jr6HT4ZicdQ6FqZCux4tfM7FOvoSPJ8gH//vy+rnNl/CljqXTi5INGLOXI+qfziUYSm4FIzT6HSG0tZR3S7oBQoQ6P1SE5pnZmU9sEGq0tgVc3VfYVelNWKvPXyOSnVBL16bKJjaWQRMpXz1V/8pt8IFi4mkPMKQk0hM4wY40fyD9Mrc3spUizfCD+TlxTLNQeS5LEKddIOafpD4+rzaScrESd0MfsaxuC+vYAa4o1B7VP0GqivfxC7liqNWmSjJa991lxHC30pK2LFCLVykqkdeZ563IBts/R9zzFE3CJ6/d7juWpr1pN04XvedeWOAFhcVRjLKkYPhHdTkSPENEZIrrb8Pq7iehpInpw9vPLGsdNxaX15bkQaSi1ihUI74jIjfTCPu+mnyqOSjQ+hOiotjiKVg8RLQH4GICfAvA4gPuJ6D5m/u5g088y812xx0tNSEDVbvaht3dr1vZjqS1wNPGp5wO7GtVeYVOTKV2YUpR19uy/ojjSyPBvA3CGmR9l5ksA7gFwp8J+sxKS1fuinZX4lB9istySGX/ssb3u0M1YztEgh2Z9iTmn2hMnQCeOSiYvGoZ/I4DHer8/PntuyN8nom8R0e8T0fGxnRHRCSI6TUSnN59/UeH0xunEGRM0UxBph0ave2rB9o8Re5zSXTk+F/gYHZU2/txxVPo+Fg1dlTL/XCr5jwA+w8wXiegfAfgUgJ80bcjMJwGcBICVH71RPeXSCowpGX0frdX7TEINGbKmErz3JFsF2b1vaWfIUNupSz7zFEu+ZR3NJRI0YkkaRxqf9FkA/Yz9ptlzOzDzs71fPwHgX0h2zFtx7Wipsp4QgabKSkLqj6mWbK2hzl46qx/iexNWrOn30bwA1BJLpbP7IVOLJY3/xfsB3ExEr8K20b8dwDv7GxDRy5n5ydmvbwXwsM8BaqpT1pCNaFB6EacUhJp9Ddl9n05jWsbfUVMcAfXFUujkbW1r3tuI/sSZeYOI7gLwRQBLAD7JzA8R0YcBnGbm+wD8YyJ6K4ANAM8BeHfscUsQKtDUWUlMl8GUxDpGbVn9kNClFjSz/ZqoNY5imEoCpXKJZeZTAE4NnvtQ7/EHAHxA41gliMlEcok01vSB+sVqInoiurLsfkiqbL8UtWX1Q2JbNGuPpbo//cLULk5tahdrH5VOiYxmH7ug2tSz/dhYypnda/Tl1xpLi+VoQrSMPvcQVOsGktxLtkqpvXTjQsP0gWll+1NNmrRjqZY4mub/RiI0xVmq3qh912Bp809l8rWXcmz0dVqj+Wub/DzEUorvkAhhoQ0/VfZRenIp1a3iJvPVFG6uDL6k2WuulQ/UY/4tlvwodQFYGMPPNbQsLdCO1OuDdEytzFJDZq9t+h0mjWtfBHKWaBYplnLFUd2Gv0WTqgHWItCOXKY/BWow+j6pTH/IlOKnT4ulNExTDRVSm0A7OqObB7GGUJvR9+k0k8P4p0KtcQTMh+k3w1egZpF2LKLx12z2fZrxbzOlOAKmGUvN8AOZgjhNTF2wLqZi8iZylXlqY+qxNKU4aobvyVTFaWKKgh1jykbfZ1Gy/RZHZWiGL2SeBDpkaJZTEO68GPwYfb3Nk/kvUhwB9cVSM3wD8yxKCWNmWkK8827sEoZ6nNIFoMVSXReBhTb8RRejL81862BMtyUvBC2W5JSMo7oNf4uakBoNIS1WGi40vtO20Wg0GhOgGX6j0WgsCM3wG41GY0Foht9oNBoLgorhE9HtRPQIEZ0horsNrx8mos/OXv86Eb1S47iNRqPRkBNt+ES0BOBjAN4M4BYA7yCiWwab/RKAHzLzjwL4DQD/PPa4jUYpltbrupmmYab9P+1HI8O/DcAZZn6UmS8BuAfAnYNt7gTwqdnj3wfwRiJq/xuNydLMpF6W1qn9/4yg0Yd/I4DHer8/DuB1Y9sw8wYRPQ/gZQCeGe6MiE4AOAEAS1dfrXB602Zzpd3sVBN9I+ket/+j8pgMvv2/7Ke6G6+Y
+SSAkwBw+PhxXsQrdRNqvXT/NybjbzRqR6OkcxbA8d7vN82eM25DRMsArgLwrMKx545m9o1GIxUahn8/gJuJ6FVEdAjA2wHcN9jmPgDvmj3+OQB/zMzN2RqNRiMj0SWdWU3+LgBfBLAE4JPM/BARfRjAaWa+D8DvAPi3RHQGwHPYvig0Go1GIyMqNXxmPgXg1OC5D/UerwN4m8axGo1GoxFGu9O20Wg0FoRm+I1Go7EgVNeW2aiXEl+w0dZ4b8wjpb6spmrDp600XwfWvrnJTk1foTd2Lu1CME7Or9BrsTROTXHUUbXhp8IWEIso4BqF6WJ4zot4AajhC7LHzqHFUZ0spOHbGAp4XoU7BXH60P975tX8azB4KS2O6qQZvoO+cOdBtFMTaAjd3zgPxj8lk7fR4qgOmuF70Il2ioKdqkBjmLLxz4vRm2hxVI5m+AFMSbBTF6gGUzL+eTb6IS2O8tMMP4KaBTsvAtWEVzerNf1FMvohNccRMF+x1AxfgaU1qkqsuQW6tLKhsp/N9fRyrDHbX2Sz71NbHAF5Y0krjmxUb/jL63Hv31jROQ8XtYg1pUBTC9K0/xwXgVLkNPoWR/6kiqUcxj5G1dFEW/H7MAk9lXhLi1VboCWFaToHTfMvXd5Jafax5u6zzxSxVDqOAN1YqiGOOqo2/FQMxasp2lJinVeB9tE2/9Kmr0kKkw85rlYslTR9rViqMY4W0vCH9EWrIdgaMpQQahToGN25xhp/CdPXyu5LmbyN7pymGkcaZl9zHDXDH6Al2JxijRVpzQJ1oWH8OU1fw+xrNPohU4yjWKYQR83wR9DMVGplCgKVopXxpyTW7Kdg9EM04iiX6YcmTlOKo7YevoPl9fBAy9GFsQgi9SH076q51zpGg7VQ+9+wKHEUlQ4R0UsBfBbAKwF8D8A/YOYfGrbbBPDt2a9/zsxvjTluCZbXw7KU2oakKQV6yHPflxJl4zVm+6EX/5pNMoR5iSNgmrEUGxF3A/gvzPwRIrp79vs/NWy3xsyviTwWli7Itts8EnskM6FiTYVvVqIpUF9B+uxD60KwtLLhZfo1de2kNvqSsVRbuTQku9eKJY048tlPbGTdCeANs8efAvAnMBt+GFtyYfYZe4+GeGszfSkaAtUSp89xUo0CcuKb3WuafUj8uN6vdRHwjaVasvzYWMoVRyZia/jXMfOTs8dPAbhuZLsVIjpNRF8jor9n2yERnZhte3pz7cXI09vL0oXdnxh8AzJFLd8nK4kR6KGVjZ2fEsQe2/dvr7mWL0VL5zn2P7WSlUYslcSZPhHRlwFcb3jpg/1fmJmJaOzy+wpmPktEPwLgj4no28z8f0wbMvNJACcBYPW648ku551YQ7OVqWT6oQItLcwh3fmEZPy+pR1NcmX3qcxdetyYrN8nlrSz/ByJU02x5IwCZn7T2GtE9H0iejkzP0lELwfwg5F9nJ39+ygR/QmA1wIwGn5uYgRbUqgpqUmgQ0KNv6TpS5ma2Y+dQ+pYKkGI2dcYR7ElnfsAvGv2+F0APjfcgIiuJqLDs8fXAHg9gO9GHled0CFqiSGpNCvxFWkNQ04pIedacwtdiI5Slm1iqPGcYpgXswfiDf8jAH6KiP43gDfNfgcR3UpEn5ht81cAnCaibwL4CoCPMHN1ht8xL2INMfspkuK81dZSEZZzvOeEKjX6PiHnKP0ctObEUszX1J40RY1xmflZAG80PH8awC/PHv83AK+OOU5ufMs80uForWWdFAI9tjoevefWdMfuh1Y2xCWeKZR2bNRu9EOWLviVeGor7fgkTtpxZIuhDt9YmqTybZmA6sqXHmLNJVRJViIVqYZAJaKUvifmQjBV0/fJ7lOYfY7lw31Nf4rUFktj1KF6C77DXe3lWudVrDECDRGm735DzN/H9F3kuAkrp9n7HCvF2vc+o+ZasvzUiVOqOLJRteGPNnl6oHEBkJp+DUJNOTGZU6DdsXyNX2r6qbN8zXsvQs1eu6FAYxlx
rQQqtjyqVb8PMfsSRt9RteGnIPS27lqEqoGvSEsKNMT4NTP9VIgnKCvtHEu9PELp5EmSOE0pjjoWdrXMkNX7JMFX8s5BbZEeW12vQqSAf7BI/s6a2zSBsC6X3PpLFUe1M9U4WljD7/AVbEmxxg5DfUVaG76BU2t7nERvPjqrYenhFHFU6m/STARqi6OFN/wOTXGVEOo8i3RIrvMrta6Or9nXRM7kKcf3TZiQJhI1xlEz/B4p66qlkYg01dDzypX1nR8tpOfq+rtTlHVyGVFtZt9R63l12C7kLj3UYvahsVT3zFYBxDdROSZxbfvJPXErNfsYpAIc2+6F9bAZumOr6+o3cqXEZYYlSh3La4yNVd2LlHRSNyaOaiUmlnyMfLitJIYmY/jLa/4GGSriqYksNksNFahmxt7tK9T4bUyhawdIb/a2GLK9FnMxkMTSVO51SZU4acWRZD9VRwFthRl9R/+9vqKdF6FqT1xqmrxt/z7GL8nybaZf0523KYiJoeH7Q8w/NoHKlYDlTpxSx5KJ+VX5gE60PoKtKdNPNYHoI9KcAvU1/imUdmzZeYrsPtboXfvULAPVnjy5Eqda42jIwk3aLq+xVyDE1FxzTF7ZshItkWpPuPrgc2zX31Nrm6YEr6URPDUeimYchVKqU6fPVMweWEDD76hBrCXxMfsa0DL9McYunN5fFB9gQK7sXqq/XEYfekyNCesSaCQKJZOmPgtr+IBegNQo1FiR1iLQPrWdjw+hSYOP2ZdEy/S13zdk7AIeWr+XJBg16XahDb9DIlZtoZYcirpEWpNAh0jOzfb31VbWUfki8MJm3zGPyZNNL1Mze2CBJm1dSHqRrb31BSadQrKSlGb/0sNr3u957uKq93uuXFl3TuZOYRLXhSTJqMXsOyTNETU1Q6QkNJZSxtGkDP/gBbm4Lx8JaB9LcANKCXJmsSHiHHu/j/lLTH+MsRbNmtozU5p96jiSMHXT106cNOPIRlRJh4jeRkQPEdEWEd1q2e52InqEiM4Q0d3i/W9ti7P78SH0fa4gmvIErqZIX3p4LVqksft0nW8ta5mMlvUiyhe+Zl9bHNkY+1ySdfqMJEihiVPpOLIRW8P/DoCfBfDVsQ2IaAnAxwC8GcAtAN5BRLdEHtcLX8GGijWVUHMs4iUVaQ6Bapr+FHHpxUefIWatta95TZ5siYSv2ecmyvCZ+WFmfsSx2W0AzjDzo8x8CcA9AO6MOW4oWmKtWahjWYmGSHMK1OfCYjv/sb+7tslbKVKz1zT6mH3XNscQi1YclTB7IE+Xzo0AHuv9/vjsOSNEdIKIThPR6csXX1Q/mZSBkBOtVR4lIi0p0JLHTknIaNDH7HOgYfoa5a4UHW+pEoLSWnYaPhF9mYi+Y/hJkqUz80lmvpWZbz14+IoUhwAgE2tqoeYippZdWqAdrvPQKu2YLqTSklrpuz5LJDO5j1fr6Nqlv1oSF2dLAjO/KfIYZwEc7/1+0+y54hy8wM4uhCl17vhmJRKRTomxzp2xFs0cq2h6LYkQkd2XHLV2x7bFki2OcnXsmC7cPiPl0MSppjjK0YN2P4CbiehV2Db6twN4Z4bjipCIdQwfoZZqQxsTaQqzf9nh8+Jtn7141Hv/Lz28FtS3Xxve31Vbsdn3cSVQvslTqQXVtBOnEFLFUpThE9HPAPg3AK4F8HkiepCZ/w4R3QDgE8x8BzNvENFdAL4IYAnAJ5n5odBjLq9tObfZWNWdmkgl1NxfhCLF1+x9xDn2HqloXabvm+UvKpI4AvxjSTJqNp5P5X35uRKn2FhyxVGU4TPzvQDuNTz/BIA7er+fAnAq9DhScQ639xFraHZSi1B9shKbSH0EGiJO174kxj8vmX6fsXKOZnbvG0f992jF0pRKpLFIYylFHI1R9Vo6tMVBIu1YXtva+ZFQy9DYhbTumOrGo5cdPq8q0uG+JdiCyWeIPdX2TMCnS0YeA1r70IqlGpsgOjQSp1RxNEbVhq+JhunPW0/x
EIlIcwhUekHxHS5LL4Apvth8Z98eBmbTm4/Za6Jz8TCfu3hl0ISdOqYEwDdxksZRbrMHFsjwAblYvW8jN+ih1sxkLCupxex9jzd23jXcgZvKwGQtxfHG7Nq/i9pHzDEX9hh9lTD6joUy/I64MlF9ItbISlyUEmnJ4CjJmM6kZp+DGNOvMY40cCVOpfW8kIYPuMWaIjspfdNIaHZfWqSu4/tk+bUsqNahrYlcZp/qeKVjZIiPXmqPI2BiyyNrs7y2Ze0+8GkxS9Wtk3rhtBQivfaQ/T1PX/LvwX/Z4fNBvfsSctyABcjLfKHZfW6z7x83JI6kHTu5+vGlE/gh5ZwUcQT4x9IkDX/5xXET3LhiyW9fDrGa31OXUIeYspLUIpWIc2xbqWhtpj+PrZo+BLVcFo6jecKWOPmafWgsSeJoMoZvE6dpOx/B2sQaeiNJKjQ7SDRE6iNO1z4kgvXN9E03Yk3xJix795jc7H3jCJDHUkgcmZKnkNFyrpsYfROnEnFko+pLMm1tC08q0j7d+8QC9765q45JpxR95BKRXnvovIpIh/uUMHZ+Na1ZIsFUr/bVlVS3oXHk+17b+dTetTNEWr+P0V2KOLJRteFrESN2wKPnubIJpw5TVjImUqnZpyJ3AIwxHEm55lJSrpQ53ukiN3sNpHGUYy6h1ljr44qlEjpfCMPvcIm11KSXJqm7UHKJ1HUcnyxfMgzXHimFGJL2qDE20bHtNxTTxUvyd2vf1xJSGtVMnEolNQtl+EC46YcKdd97KshMpiLSUNOXkLs9M9SwQrP7FEbvs//ou3GV/3tcI7ThBT+lPkqOYBfO8IH0wdCn9B23WneclhJpyHGnUMuPMbTSZi89zjyMmG2EJE6ly5WT6dLRZvnFzdHug7FuA9EXpnh2GNSyRHIKkV538Hnj89+/fJXXfq49dH60g0fatTO2bHItmEaLIZOcOZOZ7ni+LZymOKp9FU1p4qQZR2Px08c3liZp+MsXLo++tnHkoHw/Aaa/f7tyQg0Zhvpmvz4ilQh0uJ1UsDbTNzGvffm2rDmom20QSz7x0z9ubBzNOyniqL+tNI4mZfg2ox9uIxVuSIYyFWKzEqlIfQQ69l6JYMdMP/Qu3Jr78b0X8PMwe1schV4AfOOotvtbfDElTlOIo0lcepcvXBaZfeh7xoLFlE2V7CVOsWxvjEivO/h8lEhT7WsM1wUw1dr4w3mc0Pr9WHYv7pGvKI7M2+2NreHnlGs+bDhSTr3yqnYc2ajb8LfYW6BDQkQeSw2dOCY0JzNTmbNrv2MXItOFawqTt0D6m/g0YkC6j1jTL4HvhV4zu0+d5AyJMnwiehsRPUREW0R0q2W77xHRt4noQSI6HXPMUFxijRGqK2BLd+qMESLSHJl4qOnPC+YW4LDsXjvZ0dxf7Gi51sSqozazB+Iz/O8A+FkAXxVs+7eZ+TXMPHphSI2WWGu8Rdw1DJVmuy6zz0XIsbSXnw0toeUyotxm39+vdR5ggll+DCG6K2H2QKThM/PDzPyI1snkoAl1OtiCQprlDy90NXwTVh9JOSdoJczMZcx9xw9dt8dRx9fA50I+9cRpSK4aPgP4IyL6BhGdyHTMUUJM37nPShZTk2LKSlKI9IaDP9zz44uv6ftmW7nvuHUZmHjdJotOc5m9xnFqGy2n1kOs2cfEEiBoyySiLwO43vDSB5n5c8Lj/E1mPktEfwHAl4joz5jZWAaaXRBOAMDKofEWowMvXsTWFYeFh9/P8oXLQT3HHa62MtsNWBpflpKqo8SEVKQSEfa3eeLy1eLj+95gMi+Ysnttsz/w4kUACIonWxyZWjXnrS/fJ3HyMXufWJLGESAwfGZ+k3hv4/s4O/v3B0R0L4DbMFL3Z+aTAE4CwJVHb9x3+e/EOXzcRyrcMbFOXaiuYWgKkYZkHD7mP2b6vjdkucj17Vcp8DF7U+yYnpPEUmzy5EOpLxWKpZY4Su5gRHQFER3rHgP4aWxP9oo58OLFnR+f
7RvpiRleDvejxfCCVrKOb+vQGpYBRV9QPjbPJDB73zjq3iNh7PiSEunw785ZHvUZKUvq96EdZLniKLYt82eI6HEAPwHg80T0xdnzNxDRqdlm1wH4UyL6JoD/AeDzzPwF0f63tqKMWyJuTaFqI/0+29i6Y2h2r2nSkv2VnOwqgWazQOo48mEqTRCuxEA6X5Q7jmxEjV2Z+V4A9xqefwLAHbPHjwL4qzHHicVV75cOSV1lHdu6OqWGopJyjgmbSFMK9IaDP7QOS02lndiyTs1LLPQJze61zDo0jmKXL9GY88qBKXGqyeyB2u+0VSRE9LlXHiyBr0hzCFTjGK6yTgl8WgzFyxFkMvv+/mz7FC/DMPj7aujWsY2UQxMnG7nNHlggwwfs4k/RyqbZQyztHU5Zn84pUNuxTBekqd99m8LwUs5j+cZSjuQp5ddMuqgxcTKxUIYPxAu1xsxEiiQrSVEnv375+Z0fH1IGRekbsLy/rNxgmLYkJUfTQspjlL6vZV4SpyHT7D+LJLaHvzTSzgJXGcMnK5aIVGLo3TZPbcj66sdq+pLe/NBlk6eOrxHTuQvgY7oTTLlaNUvU92MSJ6nZ+yRH0lgCJmj4dM69EplEvGOmHyvU3F+IonlnYIxIfbP3/vYuwbomcjtck7epvhRFu5QwHEX6ZPc+Zt+PpbG4csWST/I0nLwdNkFMdY18aeKUwuz720uMfzIlHTp3QWT23baauMo6U0RTpL4CNb0/ZB+p2jT7I6j+3ImkTTZk3sb7y04izL6LI81Yyn3PS+7VZ30n/EN1GRoHPu+v3vB9xOn7vjGh+k7g1lLHt9UdtVeSBOIFatrfGNLsKHTyNsUaKjUtix0aR917Q9CcvK19KWQXLv1qxpGNug1/Mz6TDjX9qaORldhEmkqgvqbvyqZsF7qcE7dSw5KUc0zYdKwx4nVdMELjyPodvQkmbvsjtv5ITnrBH+ppmGCExFEuswdqN3wlQsQ6zE40yjq1ZHySLLiE2efaf2lijMyUNac2e+n+JHE0BXIlAiV0vhCG36Et/jH6AT2W2eUcorrKOT41x9JmLMnybRe0Gm7A6pOyHJhK77GZvm3UkuLzkC5RYiJWLyVbME0slOHbmNfSTghjIs1p9prHSjF/kYN9o0qP7D51cuOz/ylm+RIk5ZwxSiVOk2vLjMWn53jYomlbE2SqLWVJbrRacmdFT2262yOvX37e2GombdOcGlPr/hqLJd/7XKay9Lhv4hCbOGnFUZ9pGv55S3Zx1G3mtQl1aY2wueo/lO1PNEnrjq76fahIJeI0be8S7JjpDxneiKW9Tn6NRGf3kXHkQ84180sgTZxENyd6xJI0jjrqv6z2OX/BLlLpNpAHxRSHo/26o0Y5Q9vsh+91vd90/JjaaG11fBOSco4Jka4zxJFzWXJh91HpJRbGCE2cXITGkiSOgKkYvlB8+94TgJZQU6H51YYa5ZwYs0+xH1/6I6Mc32/bN7CYCUqTTp1mHxpHjvdozxeMfS45Gh1SdujYEiepYTuP4dhH3Ya/tRVs3ACcYs3VtVMLIVmJS6Sa2PYnyfJtF7CSE7dV3DQUE0ca74d9lFLj/IV0pCxJnHLGkY26DV+LSNOXCrWfmdQ6FPWhhEg19zulJZOtNyAN9Oed3SuYtWs/puOnHC1r3tPiO7LTLOfkHtkuhuEDXqLXbtGsIsMbMMxKahLp2P61a/k2YktnMYakWjbUMnvB/hZtxCxlLHEqUcZcHMMHRsXqK9TcdXzpl58AehO2pUUaepx5/95br+xe2+wD9js831qaILTmwkITp1JzVrFfYv4viejPiOhbRHQvEb1kZLvbiegRIjpDRHfHHDMXtQrVhGSiSaO84SPS65YOjf5EnUPEDSv9C+AUOnWACN2lMnvH/mOy/Brr+LHE6NUWQ6GxFJvhfwnAjzHzjwP4XwA+MNyAiJYAfAzAmwHcAuAdRHRLzEG3zp0b/XGSQKg+1LCejisrCRWpVIhSsUov
MLasqtY6/t75nnCjM+rWw+yDYijgOGOMjZZ95sNSlEzHRso2PWlm91Iz9zX+KMNn5j9i5m5s9DUANxk2uw3AGWZ+lJkvAbgHwJ0hx5MIMsb0h1i/DrEn1HnMTIa4RBqSbUjEajqu64LkW9Yp/XWHfWzlQs25JVssBZl/j+FFaEqjZU1iEqdU79G80/YXAXzW8PyNAB7r/f44gNeN7YSITgA4Mfv1/Bee+e1HvM9EEhfPiPd2jdfW8037LHZZjM/CFUvbn8BifBYyavgsXjH2gtPwiejLAK43vPRBZv7cbJsPAtgA8OnQM+xg5pMATsbuRwsiOs3Mt5Y+jxpon8Uu7bPYpX0Wu9T+WTgNn5nfZHudiN4N4C0A3sjMpmLbWQDHe7/fNHuu0Wg0GhmJ7dK5HcA/AfBWZh4rjN8P4GYiehURHQLwdgD3xRy30Wg0Gv7Edul8FMAxAF8iogeJ6OMAQEQ3ENEpAJhN6t4F4IsAHgbw75n5ocjj5qSa8lIFtM9il/ZZ7NI+i12q/izIXIVpNBqNxryxWHfaNhqNxgLTDL/RaDQWhGb4HhDR+4mIieia0udSCulyGvPKFJcJSQERHSeirxDRd4noISJ6b+lzKg0RLRHR/ySi/1T6XMZohi+EiI4D+GkAf176XArjXE5jXkmxTMiE2QDwfma+BcDfAPCeBf4sOt6L7caUammGL+c3sN2CutCz3MLlNOYVtWVCpg4zP8nMD8wen8O20d1Y9qzKQUQ3Afi7AD5R+lxsNMMXQER3AjjLzN8sfS6V8YsA/nPpk8iIaZmQhTW5DiJ6JYDXAvh64VMpyb/CdkJY9cJammvpTBrbEhIAfg3b5ZyFIPdyGo3pQkRHAfwBgPcx8wulz6cERPQWAD9g5m8Q0RsKn46VZvgzxpaQIKJXA3gVgG8SEbBdwniAiG5j5qcynmI2FJbTmFfaMiE9iOggts3+08z8h6XPpyCvB/BWIroDwAqAK4no3zHzPyx8XvtoN155QkTfA3ArM5deEa8Is+U0fh3A32Lmp0ufT06IaBnbE9VvxLbR3w/gnRO7c1wF2s5+PgXgOWZ+X+HTqYZZhv+rzPyWwqdipNXwG74Yl9NYBOZgmRBNXg/g5wH85EwHD84y3EbFtAy/0Wg0FoSW4TcajcaC0Ay/0Wg0FoRm+I1Go7EgNMNvNBqNBaEZfqPRaCwIzfAbjUZjQWiG32g0GgvC/wfQdGt6KUf2+wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x_range = jnp.linspace(-5,5)\n", "y_range = jnp.linspace(-2,2)\n", "\n", "def func(x,y):\n", " return jnp.sin(x)*y\n", "\n", "out = jax.vmap(jax.vmap(func, in_axes = (None,0)), in_axes = (0,None))(x_range,y_range)\n", "\n", "grid = jnp.meshgrid(x_range,y_range, indexing = 'ij')\n", "plt.contourf(grid[0],grid[1],out, levels = 31)" ] }, { "cell_type": "code", "execution_count": null, "id": "67f77afd", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.11" } }, "nbformat": 4, "nbformat_minor": 5 }