{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Common Gotchas in JAX", "version": "0.3.2", "provenance": [], "collapsed_sections": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "metadata": { "id": "uFfowrpjheff", "colab_type": "text" }, "cell_type": "markdown", "source": [ "##### Copyright 2019 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", "https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "metadata": { "id": "hjM_sV_AepYf", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# 🔪 JAX - The Sharp Bits 🔪" ] }, { "metadata": { "id": "4k5PVzEo2uJO", "colab_type": "text" }, "cell_type": "markdown", "source": [ "*levskaya@ mattjj@*\n", "\n", "When walking about the countryside of [Italy](https://iaml.it/blog/jax-intro), the people will not hesitate to tell you that __JAX__ has _\"una anima di pura programmazione funzionale\"_.\n", "\n", "__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. 
As such it needs to control the _unwanted proliferation_ of __side-effects__ in its programs so that analysis and transformation of its computations remain tractable!\n", "\n", "This requires us to write code in a _functional_ style with _explicit_ descriptions of how the state of a program changes, which results in __several important differences__ from how you might be used to programming in NumPy, TensorFlow or PyTorch.\n", "\n", "Herein we try to cover the most frequent points of trouble that users encounter when starting out in __JAX__." ] }, { "metadata": { "id": "n2r1_5_KeiF1", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Installs and Imports" ] }, { "metadata": { "id": "ndFipU9xeTj3", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "!pip install --upgrade -q git+https://github.com/google/jax.git\n", "!pip install --upgrade -q jaxlib" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "GoK_PCxPeYcy", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "import numpy as onp\n", "from jax import grad, jit\n", "from jax import lax\n", "from jax import random\n", "import jax\n", "import jax.numpy as np\n", "import matplotlib as mpl\n", "from matplotlib import pyplot as plt\n", "from matplotlib import rcParams\n", "rcParams['image.interpolation'] = 'nearest'\n", "rcParams['image.cmap'] = 'viridis'\n", "rcParams['axes.grid'] = False" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "oBdKtkVW8Lha", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# 🔪 In-Place Updates\n", "---" ] }, { "metadata": { "id": "JffAqnEW4JEb", "colab_type": "text" }, "cell_type": "markdown", "source": [ "In NumPy you're used to doing this:" ] }, { "metadata": { "id": "om4xV7_84N9j", "colab_type": "code", "outputId": "25ed90e1-74f9-420c-ba06-21e5d6a3b58e", "colab": { "base_uri": "https://localhost:8080/", "height": 153 } }, "cell_type": "code", "source": [ "numpy_array = onp.zeros((3,3), 
dtype=np.float32)\n", "print(\"original array:\")\n", "print(numpy_array)\n", "\n", "# In place, mutating update\n", "numpy_array[1, :] = 1.0\n", "print(\"updated array:\")\n", "print(numpy_array)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "original array:\n", "[[0. 0. 0.]\n", " [0. 0. 0.]\n", " [0. 0. 0.]]\n", "updated array:\n", "[[0. 0. 0.]\n", " [1. 1. 1.]\n", " [0. 0. 0.]]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "go3L4x3w4-9p", "colab_type": "text" }, "cell_type": "markdown", "source": [ "If we try to update a JAX device array in-place, however, we get an __error__! (☉_☉)" ] }, { "metadata": { "id": "2AxeCufq4wAp", "colab_type": "code", "outputId": "7013374b-041f-4270-db19-cfb4ab992f52", "colab": { "base_uri": "https://localhost:8080/", "height": 198 } }, "cell_type": "code", "source": [ "jax_array = np.zeros((3,3), dtype=np.float32)\n", "\n", "# In place update of JAX's array will yield an error!\n", "jax_array[1, :] = 1.0" ], "execution_count": 0, "outputs": [ { "output_type": "error", "ename": "TypeError", "evalue": "ignored", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# In place update of JAX's array will yield an error!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mjax_array\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mTypeError\u001b[0m: '_FilledConstant' object does not support item assignment" ] } ] }, { "metadata": { "id": "7mo76sS25Wco", 
"colab_type": "text" }, "cell_type": "markdown", "source": [ "__What gives?!__ \n", "\n", "Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program. \n", "\n", "Instead, JAX offers the _functional_ update functions: __index_update__, __index_add__ and the __index__ helper.\n", "\n", "__NB__: _Fancy Indexing_ is __not__ yet supported, but will likely be added to JAX soon.\n", "\n", "️⚠️ inside `jit`'d code and `lax.while_loop` or `lax.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation." ] }, { "metadata": { "id": "m5lg1RYq5D9p", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "from jax.ops import index, index_add, index_update" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "X2Xjjvd-l8NL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## index_update" ] }, { "metadata": { "id": "eM6MyndXL2NY", "colab_type": "text" }, "cell_type": "markdown", "source": [ "If the __input values__ of __index_update__ aren't reused, __jit__-compiled code will perform these operations _in-place_." ] }, { "metadata": { "id": "ygUJT49b7BBk", "colab_type": "code", "outputId": "c1dc7528-4a4a-4ee6-c9a2-c7e39f95ccb1", "colab": { "base_uri": "https://localhost:8080/", "height": 221 } }, "cell_type": "code", "source": [ "jax_array = np.zeros((3, 3))\n", "print(\"original array:\")\n", "print(jax_array)\n", "\n", "new_jax_array = index_update(jax_array, index[1, :], 1.)\n", "\n", "print(\"old array unchanged:\")\n", "print(jax_array)\n", "\n", "print(\"new array:\")\n", "print(new_jax_array)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "original array:\n", "[[0. 0. 0.]\n", " [0. 0. 0.]\n", " [0. 0. 
0.]]\n", "old array unchanged:\n", "[[0. 0. 0.]\n", " [0. 0. 0.]\n", " [0. 0. 0.]]\n", "new array:\n", "[[0. 0. 0.]\n", " [1. 1. 1.]\n", " [0. 0. 0.]]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "7to-sF8EmC_y", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## index_add" ] }, { "metadata": { "id": "iI5cLY1xMBLs", "colab_type": "text" }, "cell_type": "markdown", "source": [ "If the __input values__ of __index_add__ aren't reused, __jit__-compiled code will perform these operations _in-place_." ] }, { "metadata": { "id": "tsw2svao8FUp", "colab_type": "code", "outputId": "2492b20d-0b8e-4f61-816d-00b8a08ce29f", "colab": { "base_uri": "https://localhost:8080/", "height": 221 } }, "cell_type": "code", "source": [ "print(\"original array:\")\n", "jax_array = np.ones((5, 6))\n", "print(jax_array)\n", "\n", "new_jax_array = index_add(jax_array, index[::2, 3:], 7.)\n", "print(\"new array post-addition:\")\n", "print(new_jax_array)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "original array:\n", "[[1. 1. 1. 1. 1. 1.]\n", " [1. 1. 1. 1. 1. 1.]\n", " [1. 1. 1. 1. 1. 1.]\n", " [1. 1. 1. 1. 1. 1.]\n", " [1. 1. 1. 1. 1. 1.]]\n", "new array post-addition:\n", "[[1. 1. 1. 8. 8. 8.]\n", " [1. 1. 1. 1. 1. 1.]\n", " [1. 1. 1. 8. 8. 8.]\n", " [1. 1. 1. 1. 1. 1.]\n", " [1. 1. 1. 8. 8. 
8.]]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "MUycRNh6e50W", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# 🔪 Random Numbers\n", "---" ] }, { "metadata": { "id": "O8vvaVt3MRG2", "colab_type": "text" }, "cell_type": "markdown", "source": [ "> _If all scientific papers whose results are in doubt because of bad \n", "> `rand()`s were to disappear from library shelves, there would be a \n", "> gap on each shelf about as big as your fist._ - Numerical Recipes" ] }, { "metadata": { "id": "Qikt9pPW9L5K", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## RNGs and State\n", "You're used to _stateful_ pseudorandom number generators (PRNGs) from NumPy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:" ] }, { "metadata": { "id": "rr9FeP41fynt", "colab_type": "code", "outputId": "180b7c87-7050-4123-dc42-2356da6f14a2", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "print(onp.random.random())\n", "print(onp.random.random())\n", "print(onp.random.random())" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "0.7117779558041075\n", "0.014396253746679077\n", "0.7717174868106601\n" ], "name": "stdout" } ] }, { "metadata": { "id": "ORMVVGZJgSVi", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Underneath the hood, NumPy uses the [Mersenne Twister](https://en.wikipedia.org/wiki/Mersenne_Twister) PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by __624 32-bit unsigned ints__ and a __position__ indicating how much of this \"entropy\" has been used up." 
] }, { "metadata": { "id": "7Pyp2ajzfPO2", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "onp.random.seed(0)\n", "rng_state = onp.random.get_state()\n", "#print(rng_state)\n", "# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,\n", "# 2481403966, 4042607538, 337614300, ... 614 more numbers..., \n", "# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "aJIxHVXCiM6m", "colab_type": "text" }, "cell_type": "markdown", "source": [ "This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, \"consuming\" 2 of the uint32s in the Mersenne twister state vector:" ] }, { "metadata": { "id": "GAHaDCYafpAF", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "_ = onp.random.uniform()\n", "rng_state = onp.random.get_state()\n", "#print(rng_state) \n", "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", "# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)\n", "\n", "# Let's exhaust the entropy in this PRNG statevector\n", "for i in range(311):\n", " _ = onp.random.uniform()\n", "rng_state = onp.random.get_state()\n", "#print(rng_state) \n", "# --> ('MT19937', array([2443250962, 1093594115, 1878467924,\n", "# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)\n", "\n", "# Next call iterates the RNG state for a new batch of fake \"entropy\".\n", "_ = onp.random.uniform()\n", "rng_state = onp.random.get_state()\n", "# print(rng_state) \n", "# --> ('MT19937', array([1499117434, 2949980591, 2242547484, \n", "# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "N_mWnleNogps", "colab_type": "text" }, "cell_type": "markdown", "source": [ "The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's _very easy_ to screw up 
when the details of entropy production and consumption are hidden from the end user.\n", "\n", "The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexchange.com/a/53475) of problems: it has a large 2.5kB state size, which leads to problematic [initialization issues](https://dl.acm.org/citation.cfm?id=1276928). It [fails](http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf) modern BigCrush tests, and is generally slow. " ] }, { "metadata": { "id": "Uvq7nV-j4vKK", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## JAX PRNG" ] }, { "metadata": { "id": "COjzGBpO4tzL", "colab_type": "text" }, "cell_type": "markdown", "source": [ "\n", "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/master/design_notes/prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", "\n", "The random state is described by two unsigned-int32s that we call a __key__:" ] }, { "metadata": { "id": "yPHE7KTWgAWs", "colab_type": "code", "outputId": "6c2db189-d971-4d60-eb6b-c7ee3a4704b7", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "from jax import random\n", "key = random.PRNGKey(0)\n", "key" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([0, 0], dtype=uint32)" ] }, "metadata": { "tags": [] }, "execution_count": 196 } ] }, { "metadata": { "id": "XjYyWYNfq0hW", "colab_type": "text" }, "cell_type": "markdown", "source": [ "JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! 
\n", "\n", "Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:" ] }, { "metadata": { "id": "7zUdQMynoE5e", "colab_type": "code", "outputId": "9e1e1f08-19c9-4d22-c78f-4d3e113e185d", "colab": { "base_uri": "https://localhost:8080/", "height": 85 } }, "cell_type": "code", "source": [ "print(random.normal(key, shape=(1,)))\n", "print(key)\n", "# No no no!\n", "print(random.normal(key, shape=(1,)))\n", "print(key)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "[-0.20584233]\n", "[0 0]\n", "[-0.20584233]\n", "[0 0]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "hQN9van8rJgd", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a new pseudorandom number:" ] }, { "metadata": { "id": "ASj0_rSzqgGh", "colab_type": "code", "outputId": "ea3fae99-6642-4016-b0c0-938214384fe7", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "print(\"old key\", key)\n", "key, subkey = random.split(key)\n", "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", "print(\" \\---SPLIT --> new key \", key)\n", "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "old key [0 0]\n", " \\---SPLIT --> new key [4146024105 967050713]\n", " \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "tqtFVE4MthO3", "colab_type": "text" }, "cell_type": "markdown", "source": [ "We propagate the __key__ and make new __subkeys__ whenever we need a new random number:" ] }, { "metadata": { "id": "jbC34XLor2Ek", "colab_type": "code", "outputId": "436713d1-06a3-408e-fbaa-1fedeea73c73", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "print(\"old key\", key)\n", "key, subkey = 
random.split(key)\n", "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", "print(\" \\---SPLIT --> new key \", key)\n", "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "old key [4146024105 967050713]\n", " \\---SPLIT --> new key [2384771982 3928867769]\n", " \\--> new subkey [1278412471 2182328957] --> normal [-0.5866507]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "0KLYUluz3lN3", "colab_type": "text" }, "cell_type": "markdown", "source": [ "We can generate more than one __subkey__ at a time:" ] }, { "metadata": { "id": "lEi08PJ4tfkX", "colab_type": "code", "outputId": "7599b43d-930e-4c20-d549-b7694281a59a", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "key, *subkeys = random.split(key, 4)\n", "for subkey in subkeys:\n", " print(random.normal(subkey, shape=(1,)))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "[-0.37533447]\n", "[0.9864503]\n", "[0.1455319]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "rg4CpMZ8c3ri", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# 🔪 Control Flow\n", "---" ] }, { "metadata": { "id": "izLTvT24dAq0", "colab_type": "text" }, "cell_type": "markdown", "source": [ "✔ __python control_flow + autodiff__ ✔\n", "\n", "If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)." ] }, { "metadata": { "id": "aAx0T3F8lLtu", "colab_type": "code", "outputId": "1f75bb41-2d50-451e-c05d-cb946b580d8d", "colab": { "base_uri": "https://localhost:8080/", "height": 51 } }, "cell_type": "code", "source": [ "def f(x):\n", " if x < 3:\n", " return 3. * x ** 2\n", " else:\n", " return -4 * x\n", "\n", "print(grad(f)(2.)) # ok!\n", "print(grad(f)(4.)) # ok!" 
], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "12.0\n", "-4.0\n" ], "name": "stdout" } ] }, { "metadata": { "id": "hIfPT7WMmZ2H", "colab_type": "text" }, "cell_type": "markdown", "source": [ "__python control flow + JIT__\n", "\n", "Using control flow with `jit` is more complicated, and by default it has more constraints.\n", "\n", "This works:" ] }, { "metadata": { "id": "OZ_BJX0CplNC", "colab_type": "code", "outputId": "d75b0e66-273d-461a-814d-a95c40d41ef4", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "@jit\n", "def f(x):\n", " for i in range(3):\n", " x = 2 * x\n", " return x\n", "\n", "print(f(3))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "24\n" ], "name": "stdout" } ] }, { "metadata": { "id": "22RzeJ4QqAuX", "colab_type": "text" }, "cell_type": "markdown", "source": [ "So does this:" ] }, { "metadata": { "id": "pinVnmRWp6w6", "colab_type": "code", "outputId": "f7829934-8cdd-4bba-b540-d9df38c71e95", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "@jit\n", "def g(x):\n", " y = 0.\n", " for i in range(x.shape[0]):\n", " y = y + x[i]\n", " return y\n", "\n", "print(g(np.array([1., 2., 3.])))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "6.0\n" ], "name": "stdout" } ] }, { "metadata": { "id": "TStltU2dqf8A", "colab_type": "text" }, "cell_type": "markdown", "source": [ "But this doesn't, at least by default:" ] }, { "metadata": { "id": "9z38AIKclRNM", "colab_type": "code", "outputId": "f911fb55-f489-4300-f9b1-9142d252f3f9", "colab": { "base_uri": "https://localhost:8080/", "height": 54 } }, "cell_type": "code", "source": [ "@jit\n", "def f(x):\n", " if x < 3:\n", " return 3. 
* x ** 2\n", " else:\n", " return -4 * x\n", "\n", "# This will fail!\n", "try:\n", " f(2)\n", "except Exception as e:\n", " print(\"ERROR:\", e)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "ERROR: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.\n" ], "name": "stdout" } ] }, { "metadata": { "id": "pIbr4TVPqtDN", "colab_type": "text" }, "cell_type": "markdown", "source": [ "__What gives!?__\n", "\n", "When we `jit`-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.\n", "\n", "For example, if we evaluate an `@jit` function on the array `np.array([1., 2., 3.], np.float32)`, we might want to compile code that we can reuse to evaluate the function on `np.array([4., 5., 6.], np.float32)` to save on compile time.\n", "\n", "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/master/jax/abstract_arrays.py), and different transformations use different abstraction levels.\n", "\n", "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), np.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. 
That means we can save on compile time.\n", "\n", "But there's a tradeoff here: if we trace a Python function on a `ShapedArray((), np.float32)` that isn't committed to a specific concrete value, when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), np.bool_)` that represents the set `{True, False}`. When Python attempts to coerce that to a concrete `True` or `False`, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.\n", "\n", "The good news is that you can control this tradeoff yourself. By having `jit` trace on more refined abstract values, you can relax the traceability constraints. For example, using the `static_argnums` argument to `jit`, we can specify to trace on concrete values of some arguments. Here's that example function again:" ] }, { "metadata": { "id": "-Tzp0H7Bt1Sn", "colab_type": "code", "outputId": "1435a6a3-2b1c-4acd-be81-c1361021f3c4", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "def f(x):\n", " if x < 3:\n", " return 3. 
* x ** 2\n", " else:\n", " return -4 * x\n", "\n", "f = jit(f, static_argnums=(0,))\n", "\n", "print(f(2.))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "12.0\n" ], "name": "stdout" } ] }, { "metadata": { "id": "MHm1hIQAvBVs", "colab_type": "text" }, "cell_type": "markdown", "source": [ "Here's another example, this time involving a loop:" ] }, { "metadata": { "id": "iwY86_JKvD6b", "colab_type": "code", "outputId": "469a4aeb-2dbd-4f03-9aef-9fd646a717d7", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "def f(x, n):\n", " y = 0.\n", " for i in range(n):\n", " y = y + x[i]\n", " return y\n", "\n", "f = jit(f, static_argnums=(1,))\n", "\n", "f(np.array([2., 3., 4.]), 2)" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array(5., dtype=float32)" ] }, "metadata": { "tags": [] }, "execution_count": 206 } ] }, { "metadata": { "id": "nSPTOX8DvOeO", "colab_type": "text" }, "cell_type": "markdown", "source": [ "In effect, the loop gets statically unrolled. JAX can also trace at _higher_ levels of abstraction, like `Unshaped`, but that's not currently the default for any transformation." ] }, { "metadata": { "id": "wWdg8LTYwCW3", "colab_type": "text" }, "cell_type": "markdown", "source": [ "️⚠️ **functions with argument-__value__ dependent shapes**\n", "\n", "These control-flow issues also come up in a more subtle way: numerical functions we want to __jit__ can't specialize the shapes of internal arrays on argument _values_ (specializing on argument __shapes__ is ok). As a trivial example, let's make a function whose output shape happens to depend on the value of the input variable `length`." 
] }, { "metadata": { "id": "Tqe9uLmUI_Gv", "colab_type": "code", "outputId": "dbb43bac-8141-40a3-c760-95656181b598", "colab": { "base_uri": "https://localhost:8080/", "height": 85 } }, "cell_type": "code", "source": [ "def example_fun(length, val):\n", " return np.ones((length,)) * val\n", "# un-jit'd works fine\n", "print(example_fun(5, 4))\n", "\n", "bad_example_jit = jit(example_fun)\n", "# this will fail:\n", "try:\n", " print(bad_example_jit(10, 4))\n", "except Exception as e:\n", " print(\"error!\", e)\n", "# static_argnums tells JAX to recompile on changes at these argument positions:\n", "good_example_jit = jit(example_fun, static_argnums=(0,))\n", "# first compile\n", "print(good_example_jit(10, 4))\n", "# recompiles\n", "print(good_example_jit(5, 4))" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "[4. 4. 4. 4. 4.]\n", "error! `full` requires shapes to be concrete. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.\n", "[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]\n", "[4. 4. 4. 4. 4.]\n" ], "name": "stdout" } ] }, { "metadata": { "id": "MStx_r2oKxpp", "colab_type": "text" }, "cell_type": "markdown", "source": [ "`static_argnums` can be handy if `length` in our example rarely changes, but it would be disastrous if it changed a lot! \n", "\n", "Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. 
A common gotcha is trying to print arrays inside __jit__'d functions: " ] }, { "metadata": { "id": "m2ABpRd8K094", "colab_type": "code", "outputId": "06fe7d4e-2c59-4499-c04e-94166916be74", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "cell_type": "code", "source": [ "@jit\n", "def f(x):\n", " print(x)\n", " y = 2 * x\n", " print(y)\n", " return y\n", "f(2)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Traced\n", "Traced\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "array(4, dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 12 } ] }, { "metadata": { "id": "uCDcWG4MnVn-", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Structured control flow primitives\n", "\n", "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids unrolling large loops. Then you can use these 4 structured control flow primitives:\n", " - `lax.cond` _will be differentiable soon_\n", " - `lax.while_loop` __non-differentiable__*\n", " - `lax.fori_loop` __non-differentiable__*\n", " - `lax.scan` _will be differentiable soon_\n", "\n", "*_these can in principle be made to be __forward__-differentiable, but this isn't on the current roadmap._" ] }, { "metadata": { "id": "Sd9xrLMXeK3A", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## cond\n", "python equivalent:\n", "\n", "```\n", "def cond(pred, true_operand, true_fun, false_operand, false_fun):\n", " if pred:\n", " return true_fun(true_operand)\n", " else:\n", " return false_fun(false_operand)\n", "```" ] }, { "metadata": { "id": "SGxz9JOWeiyH", "colab_type": "code", "outputId": "b91c6e01-c3a7-41a0-b4d2-f815f273c8a7", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "from jax import lax\n", "\n", "operand = np.array([0.])\n", "lax.cond(True, operand, lambda 
x: x+1, operand, lambda x: x-1)\n", "# --> array([1.], dtype=float32)\n", "lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)\n", "# --> array([-1.], dtype=float32)" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array([-1.], dtype=float32)" ] }, "metadata": { "tags": [] }, "execution_count": 207 } ] }, { "metadata": { "id": "xkOFAw24eOMg", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## while_loop\n", "\n", "python equivalent:\n", "```\n", "def while_loop(cond_fun, body_fun, init_val):\n", " val = init_val\n", " while cond_fun(val):\n", " val = body_fun(val)\n", " return val\n", "```" ] }, { "metadata": { "id": "jM-D39a-c436", "colab_type": "code", "outputId": "496ba1d8-e1d9-4432-d44b-c1104e1e966d", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "init_val = 0\n", "cond_fun = lambda x: x<10\n", "body_fun = lambda x: x+1\n", "lax.while_loop(cond_fun, body_fun, init_val)\n", "# --> array(10, dtype=int32)" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "array(10, dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 208 } ] }, { "metadata": { "id": "apo3n3HAeQY_", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## fori_loop\n", "python equivalent:\n", "```\n", "def fori_loop(start, stop, body_fun, init_val):\n", " val = init_val\n", " for i in range(start, stop):\n", " val = body_fun(i, val)\n", " return val\n", "```" ] }, { "metadata": { "id": "dt3tUpOmeR8u", "colab_type": "code", "outputId": "3155b3ce-589c-437c-a456-de81b3db0a64", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "init_val = 0\n", "start = 0\n", "stop = 10\n", "body_fun = lambda i,x: x+i\n", "lax.fori_loop(start, stop, body_fun, init_val)\n", "# --> array(45, dtype=int32)" ], "execution_count": 0, "outputs": [ { "output_type": 
"execute_result", "data": { "text/plain": [ "array(45, dtype=int32)" ] }, "metadata": { "tags": [] }, "execution_count": 209 } ] }, { "metadata": { "id": "SipXS5qiqk8e", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Summary\n", "\n", "$$\n", "\\begin{array} {r|rr} \n", "\\hline \\\n", "\\textrm{construct} \n", "& \\textrm{jit} \n", "& \\textrm{grad} \\\\\n", "\\hline \\\n", "\\textrm{if} & ❌ & ✔ \\\\\n", "\\textrm{for} & ✔* & ✔\\\\\n", "\\textrm{while} & ✔* & ✔\\\\\n", "\\textrm{lax.cond} & ✔ & \\textrm{soon!}\\\\\n", "\\textrm{lax.while_loop} & ✔ & ❌\\\\\n", "\\textrm{lax.fori_loop} & ✔ & ❌\\\\\n", "\\textrm{lax.scan} & \\textrm{soon!} & \\textrm{soon!}\\\\\n", "\\hline\n", "\\end{array}\n", "$$\n", "
* = argument-__value__-independent loop condition - unrolls the loop
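
To make the footnote concrete, here is a toy sketch in plain Python of why a Python `for` loop with a value-independent bound gets unrolled under tracing. The `ToyTracer` class and the `toy_trace` helper are hypothetical names invented for this illustration; they are not part of JAX and only mimic the general idea of recording operations on a stand-in value.

```python
# Toy illustration only: this is NOT JAX's tracer, just a sketch of the idea.
class ToyTracer:
    '''Stands in for an abstract value and records each operation applied to it.'''

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def __rmul__(self, other):
        # Called for expressions such as 2 * x: record the op, return a new tracer.
        out = ToyTracer('t%d' % len(self.log), self.log)
        self.log.append('%s = %r * %s' % (out.name, other, self.name))
        return out


def toy_trace(fun):
    '''Run fun once on a ToyTracer and return the list of recorded operations.'''
    log = []
    fun(ToyTracer('x', log))
    return log


def f(x):
    # The loop bound is a plain Python int, so Python itself runs the loop
    # while tracing and the recorded program contains no loop at all.
    for i in range(3):
        x = 2 * x
    return x


ops = toy_trace(f)
for op in ops:
    print(op)
```

The recorded program holds one multiply per iteration, which is the sense in which the loop is unrolled. A bound that depended on the traced value itself could not drive `range`, which is roughly where the structured `lax` loop primitives come in.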
" ] }, { "metadata": { "id": "bxuUjFVG-v1h", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# 🔪 Convolutions\n", "---" ] }, { "metadata": { "id": "0pcn2LeS-03b", "colab_type": "text" }, "cell_type": "markdown", "source": [ "JAX and XLA offer the very general N-dimensional __conv_general_dilated__ function, but it's not very obvious how to use it. We'll give some examples of the common use-cases. There are also the convenience functions `lax.conv` and `lax.conv_general_padding` for the most common kinds of convolutions.\n", "\n", "A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285) is highly recommended reading!\n", "\n", "Let's define a simple diagonal edge kernel:" ] }, { "metadata": { "id": "Yud1Y3ss-x1K", "colab_type": "code", "outputId": "1674482b-501a-43eb-91c6-0bef42a73d6d", "colab": { "base_uri": "https://localhost:8080/", "height": 286 } }, "cell_type": "code", "source": [ "# 2D kernel - HWIO layout\n", "kernel = onp.zeros((3, 3, 3, 3), dtype=np.float32)\n", "kernel += onp.array([[1, 1, 0],\n", " [1, 0,-1],\n", " [0,-1,-1]])[:, :, onp.newaxis, onp.newaxis]\n", "\n", "print(\"Edge Conv kernel:\")\n", "plt.imshow(kernel[:, :, 0, 0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Edge Conv kernel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQ8AAAD8CAYAAABpXiE9AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADe5JREFUeJzt3X+snmV9x/H3Zy1gJkwqJdKUKj+j\nc24GPEGUxTRDEySGLpEl8IeC0XQ4yZRoMtQEE5Nl6h8uYxpJA0RYDDaCgeNSQ2DAcFmKVFIohSCF\nuLS1EyyuyHSysu/+ODfm8XB+9Xru8zzP0fcrefJc931f576+vdp8ev9sU1VI0pH6vXEXIGllMjwk\nNTE8JDUxPCQ1MTwkNTE8JDUZKjySvDbJXUme7L7XzNPvpSQ7u8/0MGNKmgwZ5jmPJF8CnquqLyS5\nGlhTVX8zR78XqurYIeqUNGGGDY8ngI1VdSDJOuC+qnrjHP0MD+m3zLDh8V9VdXzXDvCzl5dn9TsM\n7AQOA1+oqtvn2d9mYDPAq38/b3vTGUc31ybt+tmJ4y5h4r24d99Pq6ppolYv1iHJ3cBJc2z67OBC\nVVWS+ZLoDVW1P8lpwD1JdlXVU7M7VdUWYAvA1FtfVd+/c8OivwBpPqdvvWLcJUy8H33iU//R+rOL\nhkdVvXu+bUl+kmTdwGnLM/PsY3/3/XSS+4CzgFeEh6SVY9hbtdPAZV37MuCO2R2SrElyTNdeC5wH\nPDbkuJLGbNjw+ALwniRPAu/ulkkyleT6rs8fAjuSPAzcy8w1D8NDWuEWPW1ZSFUdBM6fY/0O4CNd\n+9+BPx5mHEmTxydMJTUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwk\nNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1\nMTwkNeklPJJckOSJJHuSXD3H9mOSbO22P5DklD7GlTQ+Q4dHklXAV4H3Am8GLk3y5lndPgz8rKrO\nAP4e+OKw40oarz6OPM4B9lTV01X1IvBNYNOsPpuAm7r2rcD5SdLD2JLGpI/wWA/sHVje162bs09V\nHQYOASf0MLakMZmoC6ZJNifZkWTHswdfGnc5khbQR3jsBzYMLJ/crZuzT5LVwGuAg7N3VFVbqmqq\nqqZOPGFVD6VJWi59hMeDwJlJTk1yNHAJMD2rzzRwWde+GLinqqqHsSWNyephd1BVh5NcCdwJrAJu\nrKrdST4P7KiqaeAG4J+S7AGeYyZgJK1gQ4cHQFVtA7bNWnfNQPt/gL/oYyxJk2GiLphKWjkMD0lN\nDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0M\nD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU0MD0lNDA9JTQwPSU16CY8kFyR5Isme\nJFfPsf3yJM8m2dl9PtLHuJLGZ/WwO0iyCvgq8B5gH/BgkumqemxW161VdeWw40maDH0ceZwD7Kmq\np6vqReCbwKYe9itpgg195AGsB/YOLO8D3j5Hv/cneRfwQ+Cqqto7u0OSzcBmgNev76O0316nb71i\n3CVMvDOu2j7uEibej4b42VFdMP0OcEpV/QlwF3DTXJ2qaktVTVXV1IknrBpRaZJa9BEe+4ENA8sn\nd+t+raoOVtWvusXrgbf1MK6kMeojPB4EzkxyapKjgUuA6cEOSdYNLF4EPN7DuJLGaOgLC1V1OMmV\nwJ3AKuDGqtqd5PPAjqqaBv46yUXAYeA54PJhx5U0Xr1clayqbcC2WeuuGWh/Gvh0H2NJmgw+YSqp\nieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpi
eEhqYnhIamJ4SGpieEhqYnhIamJ\n4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIamJ4SGpieEhqYnhIalJL+GR5MYkzyR5\ndJ7tSXJtkj1JHklydh/jShqfvo48vg5csMD29wJndp/NwNd6GlfSmPQSHlV1P/DcAl02ATfXjO3A\n8UnW9TG2pPEY1TWP9cDegeV93brfkGRzkh1Jdjx78KURlSapxURdMK2qLVU1VVVTJ56watzlSFrA\nqMJjP7BhYPnkbp2kFWpU4TENfLC763IucKiqDoxobEnLYHUfO0lyC7ARWJtkH/A54CiAqroO2AZc\nCOwBfgF8qI9xJY1PL+FRVZcusr2Aj/UxlqTJMFEXTCWtHIaHpCaGh6QmhoekJoaHpCaGh6Qmhoek\nJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmhoekJoaHpCaGh6Qm\nhoekJoaHpCaGh6QmhoekJoaHpCaGh6QmvYRHkhuTPJPk0Xm2b0xyKMnO7nNNH+NKGp9e/qNr4OvA\nV4CbF+jzvap6X0/jSRqzXo48qup+4Lk+9iVpZejryGMp3pHkYeDHwKeqavfsDkk2A5sBVq1Zw+lb\nrxhheSvLGVdtH3cJ+h03qgumDwFvqKq3Av8I3D5Xp6raUlVTVTW16thXj6g0SS1GEh5V9XxVvdC1\ntwFHJVk7irElLY+RhEeSk5Kka5/TjXtwFGNLWh69XPNIcguwEVibZB/wOeAogKq6DrgY+GiSw8Av\ngUuqqvoYW9J49BIeVXXpItu/wsytXEm/JXzCVFITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NS\nE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1ITw0NSE8NDUhPDQ1IT\nw0NSE8NDUhPDQ1ITw0NSE8NDUpOhwyPJhiT3Jnksye4kH5+jT5Jcm2RPkkeSnD3suJLGq4//6Pow\n8MmqeijJccAPktxVVY8N9HkvcGb3eTvwte5b0go19JFHVR2oqoe69s+Bx4H1s7ptAm6uGduB45Os\nG3ZsSePT6zWPJKcAZwEPzNq0Htg7sLyPVwaMpBWkt/BIcixwG/CJqnq+cR+bk+xIsuOlF/67r9Ik\nLYNewiPJUcwExzeq6ttzdNkPbBhYPrlb9xuqaktVTVXV1KpjX91HaZKWSR93WwLcADxeVV+ep9s0\n8MHursu5wKGqOjDs2JLGp4+7LecBHwB2JdnZrfsM8HqAqroO2AZcCOwBfgF8qIdxJY3R0OFRVf8G\nZJE+BXxs2LEkTQ6fMJXUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTw\nkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ1MTwkNTE8JDUxPCQ\n1MTwkNTE8JDUZOjwSLIhyb1JHkuyO8nH5+izMcmhJDu7zzXDjitpvFb3sI/DwCer6qEkxwE/SHJX\nVT02q9/3qup9PYwnaQIMfeRRVQeq6qGu/XPgcWD9sPuVNNlSVf3tLDkFuB94S1U9P7B+I3AbsA/4\nMfCpqto9x89vBjZ3i28BHu2tuH6sBX467iIGWM/CJq0emLya3lhVx7X8YG/hkeRY4F+Bv62qb8/a\n9gfA/1XVC0kuBP6hqs5cZH87qmqql+J6Mmk1Wc/CJq0emLyahqmnl7stSY5i5sjiG7ODA6Cqnq+q\nF7r2NuCoJGv7GFvSePRxtyXADcDjVfXlefqc1PUjyTnduAeHHVvS+PRxt+U84APAriQ7u3WfAV4P\nUFXXARcDH01yGPglcEktfr60pYfa+jZpNVnPwiatHpi8mprr6fWCqaTfHT5hKqmJ4SGpycSER5LX\nJrkryZPd95p5+r008Jj79
DLUcUGSJ5LsSXL1HNuPSbK12/5A92zLslpCTZcneXZgXj6yjLXcmOSZ\nJHM+g5MZ13a1PpLk7OWq5QhqGtnrEUt8XWOkc7Rsr5BU1UR8gC8BV3ftq4EvztPvhWWsYRXwFHAa\ncDTwMPDmWX3+Criua18CbF3meVlKTZcDXxnR79O7gLOBR+fZfiHwXSDAucADE1DTRuCfRzQ/64Cz\nu/ZxwA/n+P0a6RwtsaYjnqOJOfIANgE3de2bgD8fQw3nAHuq6umqehH4ZlfXoME6bwXOf/k29Bhr\nGpmquh94boEum4Cba8Z24Pgk68Zc08jU0l7XGOkcLbGmIzZJ4fG6qjrQtf8TeN08/V6VZEeS7Un6\nDpj1wN6B5X28cpJ/3aeqDgOHgBN6ruNIawJ4f3cIfGuSDctYz2KWWu+ovSPJw0m+m+SPRjFgd0p7\nFvDArE1jm6MFaoIjnKM+nvNYsiR3AyfNsemzgwtVVUnmu4f8hqran+Q04J4ku6rqqb5rXWG+A9xS\nVb9K8pfMHBn92ZhrmiQPMfPn5uXXI24HFnw9Yljd6xq3AZ+ogfe8xmmRmo54jkZ65FFV766qt8zx\nuQP4ycuHbt33M/PsY3/3/TRwHzMp2pf9wODf2id36+bsk2Q18BqW92nZRWuqqoNV9atu8XrgbctY\nz2KWMocjVSN+PWKx1zUYwxwtxyskk3TaMg1c1rUvA+6Y3SHJmiTHdO21zDzdOvvfDRnGg8CZSU5N\ncjQzF0Rn39EZrPNi4J7qrjgtk0VrmnW+fBEz57TjMg18sLujcC5waOB0dCxG+XpEN86Cr2sw4jla\nSk1NczSKK9BLvCJ8AvAvwJPA3cBru/VTwPVd+53ALmbuOOwCPrwMdVzIzNXop4DPdus+D1zUtV8F\nfAvYA3wfOG0Ec7NYTX8H7O7m5V7gTctYyy3AAeB/mTlX/zBwBXBFtz3AV7tadwFTI5ifxWq6cmB+\ntgPvXMZa/hQo4BFgZ/e5cJxztMSajniOfDxdUpNJOm2RtIIYHpKaGB6SmhgekpoYHpKaGB6Smhge\nkpr8P9IpB0Tn+nMHAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "dITPaPdh_cMI", "colab_type": "text" }, "cell_type": "markdown", "source": [ "And we'll make a simple synthetic image:" ] }, { "metadata": { "id": "cpbGsIGa_Qyx", "colab_type": "code", "outputId": "44f0c042-3c74-4f39-9ed2-cd651cbc13fc", "colab": { "base_uri": "https://localhost:8080/", "height": 286 } }, "cell_type": "code", "source": [ "# NHWC layout\n", "img = onp.zeros((1, 200, 198, 3), dtype=np.float32)\n", "for k in range(3):\n", " x = 30 + 60*k\n", " y = 20 + 60*k\n", " img[0, x:x+10, y:y+10, k] = 1.0\n", "\n", "print(\"Original Image:\")\n", "plt.imshow(img[0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Original Image:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQMAAAD8CAYAAABzYsGzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADO1JREFUeJzt3V2MXOV9x/Hvr6ZwkSIBhVoInNog\nJxJE1ZYgEqkBkbZJAFU19ILaiho3QTVIWGqlShWkUoPam6oNRYqaEBnVwkgNL2pFsCIScK0q3IQG\nO7F4CwRDjPDW2AUqSJsoic2/F/NsM3F2s7M7c3Zmh+9HOppznjkz53k865/Oy8z5p6qQpF8adwck\nTQbDQBJgGEhqDANJgGEgqTEMJAEdhkGSq5I8n+Rgklu62o6k0UgX3zNIsgb4LvAR4DDwBLClqp4d\n+cYkjURXewaXAQer6qWq+jFwH7Cpo21JGoFTOnrf84BX+pYPAx9YaOUkfg1S6s5rVXXOYit1FQaL\nSrIN2Dau7UvvIC8PslJXYTALrOtbPr+1/b+q2gHsAPcMpEnQ1TmDJ4CNSTYkORXYDOzuaFuSRqCT\nPYOqOp5kO/AIsAbYWVXPdLEtSaPRyaXFJXfCwwSpS/ur6tLFVvIbiJIAw0BSYxhIAgwDSY1hIAkw\nDCQ1hoEkwDCQ1BgGkgDDQFJjGEgCDANJjWEgCTAMJDWGgSTAMJDUGAaSAMNAUrPsMEiyLsm/J3k2\nyTNJ/rS135ZkNsmBNl0zuu5K6sowN0Q9Dvx5VX0ryenA/iR72nN3VNVnh++epJWy7DCoqiPAkTb/\n/STfoVdJSdIqNJJzBknWA78J/Edr2p7kySQ7k5w5im1I6tbQYZDkV4B/Bf6sqt4C7gQuBGbo7Tnc\nvsDrtiXZl2TfsH2QNLyh6iYk+WXgK8AjVfUP8zy/HvhKVb1vkfexboLUnW7rJiQJ8E/Ad/qDIMm5\nfatdBzy93G1IWjnDXE34LeCPgKeSHGhtnwa2JJkBCjgE3DhUDyWtCMurSdNvoMOErkqyT4SlJEw6\n64W0Ovh1ZEmAYSCpMQwkAYaBpMYwkAQYBpIaw0ASYBhIagwDSYBhIKmZ6q8j+xVjaXDuGUgCDANJ\njWEgCTAMJDWGgST
AMJDUGAaSgBF8zyDJIeD7wAngeFVdmuQs4H5gPb2bol5fVf897LYkdWdUewYf\nrqqZvpsu3gLsraqNwN62LGmCdXWYsAnY1eZ3Add2tB1JIzKKMCjg0ST7k2xrbWtbYVaAV4G1J7/I\n8mrSZBnFbxM+VFWzSX4N2JPkuf4nq6rmq4tQVTuAHWDdBGkSDL1nUFWz7fEY8CBwGXB0rsxaezw2\n7HYkdWuoMEjyriSnz80DH6VXW3E3sLWtthV4aJjtSOresIcJa4EHezVYOQX4UlV9LckTwANJbgBe\nBq4fcjuSOmatRWn6dVuSXdJ0MQwkAYaBpMYwkAQYBpIaw0ASYBhIagwDSYBhIKkxDCQBhoGkxjCQ\nBBgGkhrDQBJgGEhqDANJgGEgqTEMJAFD3AMxyXvplVCbcwHwV8AZwJ8A/9XaP11VDy+7h5JWxEju\ngZhkDTALfAD4JPA/VfXZJbzeeyBK3VnReyD+DvBiVb08oveTtMJGFQabgXv7lrcneTLJziRnzvcC\ny6tJk2Xow4QkpwL/CVxcVUeTrAVeo1eD8W+Ac6vqU4u8h4cJUndW7DDhauBbVXUUoKqOVtWJqnob\nuIteuTVJE24UYbCFvkOEuRqLzXX0yq1JmnBDlVdr9RU/AtzY1/x3SWboHSYcOuk5SRPK8mrS9LO8\nmqTBGQaSAMNAUmMYSAIMA0mNYSAJMAwkNYaBJMAwkNQYBpKAIX+boFVgKV/0Tme90CrgnoEkwDCQ\n1BgGkgDDQFJjGEgCDANJjWEgCRgwDFr9g2NJnu5rOyvJniQvtMczW3uSfC7JwVY74ZKuOi9pdAbd\nM7gbuOqktluAvVW1EdjblqF36/SNbdoG3Dl8NyV1baAwqKrHgDdOat4E7Grzu4Br+9rvqZ7HgTNO\nun26pAk0zDmDtVV1pM2/Cqxt8+cBr/Std7i1aRyyhEnvaCP5bUJV1VJvd55kG73DCEkTYJg9g6Nz\nu//t8VhrnwXW9a13fmv7GVW1o6ouHeR+7pK6N0wY7Aa2tvmtwEN97Z9oVxU+CLzZdzghaVJV1aIT\nvVqKR4Cf0DsHcAPwq/SuIrwA/BtwVls3wOeBF4GngEsHeP9ycnLqbNo3yP9zy6tJ08/yapIGZxhI\nAgwDSY1hIAkwDCQ1hoEkwDCQ1BgGkgDDQFJjGEgCDANJjWEgCTAMJDWGgSTAMJDUGAaSAMNAUmMY\nSAIGCIMFSqv9fZLnWvm0B5Oc0drXJ/lhkgNt+mKXnZc0OoPsGdzNz5dW2wO8r6p+A/gucGvfcy9W\n1UybbhpNNyV1bdEwmK+0WlU9WlXH2+Lj9GojSFrFRnHO4FPAV/uWNyT5dpKvJ7l8BO8vaQUMVV4t\nyV8Cx4F/bk1HgHdX1etJ3g98OcnFVfXWPK+1vJo0QZa9Z5Dkj4HfAz5ec5VQqn5UVa+3+f30Cqm8\nZ77XW15NmizLCoMkVwF/Afx+Vf2gr/2cJGva/AXARuClUXRUUrcWPUxIci9wJXB2ksPAZ+hdPTgN\n2JME4PF25eAK4K+T/AR4G7ipqt6Y940lTRTLq0nTz/JqkgZnGEgCDANJjWEgCTAMJDWGgSTAMJDU\nGAaSAMNAUmMYSAIMA0mNYSAJMAwkNYaBJMAwkNQYBpIAw0BSYxhIApZfXu22JLN9ZdSu6Xvu1iQH\nkzyf5GNddVzSaC23vBrAHX1l1B4GSHIRsBm4uL3mC3N3S5Y02ZZVXu0X2ATc1+onfA84CFw2RP8k\nrZBhzhlsb1WYdyY5s7WdB7zSt87h1iaNWS1hemdabhjcCVwIzNArqXb7Ut8gybYk+5LsW2YfJI3Q\nssKgqo5W1Ymqehu4i58eCswC6/pWPb+1zfcelleTJshyy6ud27d4HTB3pWE3sDnJaUk20Cuv9s3h\nuihpJSy3vNqVSWboHWAdAm4EqKpnkjwAPEuvOvPNVXWim65LGiXLq+kdYil/YumsF
2NieTVJgzMM\nJAGGgaTGMJAEGAaSmkUvLUrTYequEIycewaSAMNAUmMYSAIMA0mNYSAJMAwkNYaBJMAwkNQYBpIA\nw0BSYxhIAgwDSY1hIAlYfq3F+/vqLB5KcqC1r0/yw77nvthl5yWNziA/Yb4b+EfgnrmGqvrDufkk\ntwNv9q3/YlXNjKqDklbGomFQVY8lWT/fc0kCXA/89mi7JWmlDXvO4HLgaFW90Ne2Icm3k3w9yeUL\nvdDyatJkGfZOR1uAe/uWjwDvrqrXk7wf+HKSi6vqrZNfWFU7gB1g3QRpEix7zyDJKcAfAPfPtbVS\n7K+3+f3Ai8B7hu2kpO4Nc5jwu8BzVXV4riHJOUnWtPkL6NVafGm4LkpaCYNcWrwX+Abw3iSHk9zQ\nntrMzx4iAFwBPNkuNf4LcFNVvTHKDkvqhrUWpelnrUVJgzMMJAGGgaTGMJAEGAaSGsNAEmAYSGoM\nA0mAYSCpMQwkAYaBpMYwkAQYBpIaw0ASYBhIagwDSYBhIKkxDCQBhoGkxjCQBBgGkpphKyqNymvA\n/7bHaXM20zkumN6xTdu4fn2QlSbiVukASfYNcjvn1WZaxwXTO7ZpHddiPEyQBBgGkppJCoMd4+5A\nR6Z1XDC9Y5vWcf1CE3POQNJ4TdKegaQxGnsYJLkqyfNJDia5Zdz9GVaSQ0meSnIgyb7WdlaSPUle\naI9njrufi0myM8mxJE/3tc07jvR8rn2GTya5ZHw9X9wCY7styWz73A4kuabvuVvb2J5P8rHx9Lp7\nYw2DJGuAzwNXAxcBW5JcNM4+jciHq2qm7/LULcDeqtoI7G3Lk+5u4KqT2hYax9XAxjZtA+5coT4u\n1938/NgA7mif20xVPQzQ/h43Axe313yh/d1OnXHvGVwGHKyql6rqx8B9wKYx96kLm4BdbX4XcO0Y\n+zKQqnoMeOOk5oXGsQm4p3oeB85Icu7K9HTpFhjbQjYB91XVj6rqe8BBen+3U2fcYXAe8Erf8uHW\ntpoV8GiS/Um2tba1VXWkzb8KrB1P14a20Dim5XPc3g5zdvYdyk3L2BY17jCYRh+qqkvo7TrfnOSK\n/ierd/lm1V/CmZZx9LkTuBCYAY4At4+3Oytv3GEwC6zrWz6/ta1aVTXbHo8BD9LbpTw6t9vcHo+N\nr4dDWWgcq/5zrKqjVXWiqt4G7uKnhwKrfmyDGncYPAFsTLIhyan0TtTsHnOfli3Ju5KcPjcPfBR4\nmt6YtrbVtgIPjaeHQ1toHLuBT7SrCh8E3uw7nFgVTjrHcR29zw16Y9uc5LQkG+idJP3mSvdvJYz1\nV4tVdTzJduARYA2ws6qeGWefhrQWeDAJ9P5tv1RVX0vyBPBAkhuAl4Hrx9jHgSS5F7gSODvJYeAz\nwN8y/zgeBq6hd3LtB8AnV7zDS7DA2K5MMkPv0OcQcCNAVT2T5AHgWeA4cHNVnRhHv7vmNxAlAeM/\nTJA0IQwDSYBhIKkxDCQBhoGkxjCQBBgGkhrDQBIA/weZCejC+N5rZwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "_m90y74OWorG", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## lax.conv and lax.conv_with_general_padding" ] }, { "metadata": { "id": "Pv9_QPDnWssM", "colab_type": "text" }, "cell_type": "markdown", "source": [ "These are the simple convenience functions for convolutions\n", "\n", "️⚠️ The convenience `lax.conv`, `lax.conv_with_general_padding` helper function assume __NCHW__ images and __IOHW__ kernels." ] }, { "metadata": { "id": "kppxbxpZW0nb", "colab_type": "code", "outputId": "2c872f2b-b71a-4821-d870-0b3a4f1eeee9", "colab": { "base_uri": "https://localhost:8080/", "height": 629 } }, "cell_type": "code", "source": [ "out = lax.conv(np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n", " np.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor\n", " (1, 1), # window strides\n", " 'SAME') # padding mode\n", "print(\"out shape: \", out.shape)\n", "print(\"First output channel:\")\n", "plt.figure(figsize=(10,10))\n", "plt.imshow(onp.array(out)[0,0,:,:]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "out shape: (1, 3, 200, 198)\n", "First output channel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGXdJREFUeJzt3X+spndZ5/HPtTPaxKmbtjvdpkKh\nhZQGNd2CYyVRCLtIbYmh4h/Qxigq2UICxGZNFDRZiImJq1a7ullMCQ2Q1AIuVhvTKl3WlWxilSk2\nY6EMTLEN0x3aDgWxxbB2uPaPc81wZjjTDnPOc55hzuuVnJz7+T4/7u/cuc/0Pff3eU6ruwMAQPKv\nlj0BAIBThTACABjCCABgCCMAgCGMAACGMAIAGMIIAGAsLIyq6sqq2ltV+6rqbYvaDwDARqlF/ILH\nqtqW5DNJXplkf5KPJ7m2uz+14TsDANgg2xf0upcn2dfdn0uSqvpAkquTrBlG287c0dvPOWdBUwEA\ntrKnHn88h554sk7ksYsKo2cl+fyq2/uT/NBxJ3HOOfmeX7x+QVMBALay/3vDjSf82KW9+bqqrquq\n3VW1+9ATTy5rGgAARywqjB5OcsGq28+esSO6+6bu3tXdu7aduWNB0wAAOHGLCqOPJ7m4qi6qqu9M\nck2S2xe0LwCADbGQ9xh191NV9ZYkf5FkW5Kbu/uTi9gXAMBGWdSbr9PddyS5Y1GvDwCw0fzmawCA\nIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYw\nAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABjCCABgCCMAgCGMAACGMAIAGCcdRlV1QVX9ZVV9qqo+WVW/MOPvrKqHq+re+XrVxk0XAGBxtq/j\nuU8l+cXu/kRVfXeSe6rqrrnvd7v7t9c/PQCAzXPSYdTdB5IcmO1/qqr7kzxroyYGALDZNuQ9RlV1\nYZIXJfmbGXpLVe2pqpur6uyN2AcAwKKtO4yq6swkH05yfXd/Jcm7kjw/yWVZuaJ0w3Ged11V7a6q\n3YeeeHK90wAAWLd1hVFVfUdWouiW7v7jJOnuR7r7UHd/Pcm7k1y+1nO7+6bu3tXdu7aduWM90wAA\n2BDr+VRaJXlPkvu7+3dWjZ+/6mGvSXLfyU8PAGDzrOdTaT+c5KeT/H1V3Ttjv5Lk2qq6LEkneTDJ\nG9c1QwCATbKeT6X9nyS1xl13nPx0AACWx2++BgAYwggAYAgjAIAhjAAAxno+lcZxnHX/ynvSd+5Z\n/C+uPHjpN34H1Jdf2AvfHwCczlwxAgAYwggAYFhKW4AjS2h37/nG4EsuXci+Vi+fnXvJwSTJY3t3\nLmRfAHC6c8UIAGC4YrRIq64S7Xvddy1kF9dfcec3jd2496qF7AsATneuGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAIzty57A6ejgpTuSJF9+YR8Zu/6KOxeyr7ee/dCR7d//0nMXsg8A2Cpc\nMQIAGMIIAGBYSluAw0to515ycOH7Wr18dstDP7jw/QHA6cwVIwCAIYwAAIaltAV6bO/OI9s37r1q\niTMBAE6EK0YAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwh
BEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABj+3pfoKoeTPJPSQ4leaq7d1XV\nOUk+mOTCJA8meW13f2m9+wIAWKSNumL077v7su7eNbffluSj3X1xko/ObQCAU9qiltKuTvK+2X5f\nkp9Y0H4AADbMRoRRJ/lIVd1TVdfN2HndfWC2v5DkvGOfVFXXVdXuqtp96IknN2AaAADrs+73GCX5\nke5+uKr+bZK7qurTq+/s7q6qPvZJ3X1TkpuS5IznXPBN9wMAbLZ1XzHq7ofn+6NJbktyeZJHqur8\nJJnvj653PwAAi7auMKqqHVX13Ye3k1yR5L4ktyd5/Tzs9Un+dD37AQDYDOtdSjsvyW1Vdfi1/rC7\n/7yqPp7kQ1X1hiQPJXntOvcDALBw6wqj7v5ckn+3xvgXk7xiPa8NALDZ/OZrAIAhjAAAhjACABjC\nCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACG\nMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMII\nAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCA\nIYwAAIYwAgAY20/2iVV1SZIPrhp6XpL/nOSsJP8xyWMz/ivdfcdJzxAAYJOcdBh1994klyVJVW1L\n8nCS25L8XJLf7e7f3pAZAgBsko1aSntFkge6+6ENej0AgE23UWF0TZJbV91+S1Xtqaqbq+rstZ5Q\nVddV1e6q2n3oiSc3aBoAACdv3WFUVd+Z5NVJ/miG3pXk+VlZZjuQ5Ia1ntfdN3X3ru7ete3MHeud\nBgDAum3EFaOrknyiux9Jku5+pLsPdffXk7w7yeUbsA8AgIXbiDC6NquW0arq/FX3vSbJfRuwDwCA\nhTvpT6UlSVXtSPLKJG9cNfybVXVZkk7y4DH3AQCcstYVRt39ZJJ/c8zYT69rRgAAS+I3XwMADGEE\nADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMDYvuwJAItz1v11ZHvnnicXuq+Dl+44sv3lF/ZC9wWwKK4YAQAMYQQAMCylwWns\nqOWzu/esfH/JpQvZ1+rls3MvOXhk+7G9OxeyP4BFcMUIAGAIIwCAYSkNtopZQtv3uu9ayMtff8Wd\na47fuPeqhewPYBFcMQIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAxvZlTwBYnIOX7jiy/eUXdpLk\n+ivuXMi+3nr2Q0e2f/9Lz13IPgAW7YSuGFXVzVX1aFXdt2rsnKq6q6o+O9/PnvGqqt+rqn1Vtaeq\nXryoyQMAbKQTXUp7b5Irjxl7W5KPdvfFST46t5PkqiQXz9d1Sd61/mkCACzeCS2ldffHqurCY4av\nTvLy2X5fkv+d5Jdn/P3d3Unurqqzqur87j6wERMGTtzh5bMkOfeSgwvd1+rls1se+sGF7gtgUdbz\n5uvzVsXOF5KcN9vPSvL5VY/bP2NHqarrqmp3Ve0+9MST65gGAMDG2JBPpc3VoX7GBx79nJu6e1d3\n79p25o5nfgIAwIKt51Npj
xxeIquq85M8OuMPJ7lg1eOePWPAEj22d2eS5Ma9Vy15JgCnrvVcMbo9\nyetn+/VJ/nTV+M/Mp9NekuQfvb8IAPh2cEJXjKrq1qy80XpnVe1P8o4kv5HkQ1X1hiQPJXntPPyO\nJK9Ksi/JV5P83AbPGQBgIU70U2nXHueuV6zx2E7y5vVMCgBgGfwvQQAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIA\nGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY\nwggAYDxjGFXVzVX1aFXdt2rst6rq01W1p6puq6qzZvzCqvrnqrp3vv5gkZMHANhIJ3LF6L1Jrjxm\n7K4k39/dlyb5TJK3r7rvge6+bL7etDHTBABYvGcMo+7+WJLHjxn7SHc/NTfvTvLsBcwNAGBTbcR7\njH4+yZ2rbl9UVX9XVX9VVS893pOq6rqq2l1Vuw898eQGTAMAYH22r+fJVfWrSZ5KcssMHUjynO7+\nYlX9QJI/qarv6+6vHPvc7r4pyU1JcsZzLuj1zAMAYCOc9BWjqvrZJD+e5Ke6u5Oku7/W3V+c7XuS\nPJDkBRswTwCAhTupMKqqK5P8UpJXd/dXV42fW1XbZvt5SS5O8rmNmCgAwKI941JaVd2a5OVJdlbV\n/iTvyMqn0M5IcldVJcnd8wm0lyX5tar6lyRfT/Km7n58zRcGADjFPGMYdfe1awy/5ziP/XCSD693\nUgAAy+A3XwMADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABD\nGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEE\nADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwti97AgCc/s66v45s79zz5ML3d/DSHUmSL7+wF74vTi/P\neMWoqm6uqker6r5VY++sqoer6t75etWq+95eVfuqam9V/diiJg4AsNFOZCntvUmuXGP8d7v7svm6\nI0mq6nuTXJPk++Y5/72qtm3UZAEAFukZl9K6+2NVdeEJvt7VST7Q3V9L8g9VtS/J5Un++qRnCMC3\nvaOWz+7e843tl1y6kP0dXkI795KDR8Ye27tzIfvi9LKeN1+/par2zFLb2TP2rCSfX/WY/TP2Tarq\nuqraXVW7Dz2x+PVmAIBncrJh9K4kz09yWZIDSW74Vl+gu2/q7l3dvWvbmTtOchoAABvnpD6V1t2P\nHN6uqncn+bO5+XCSC1Y99NkzBgArVi2f7Xvddy1kF9dfcec3jd2496qF7IvTy0ldMaqq81fdfE2S\nw59Yuz3JNVV1RlVdlOTiJH+7vikCAGyOZ7xiVFW3Jnl5kp1VtT/JO5K8vKouS9JJHkzyxiTp7k9W\n1YeSfCrJU0ne3N2HFjN1AICNdSKfSrt2jeH3PM3jfz3Jr69nUgAAy+B/CQIAMIQRAMAQRgAAQxgB\nAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAx
hBAAw\nti97AgCc/g5euuPI9pdf2Ee2r7/izoXs761nP5Qk+f0vPXchr8/pyxUjAIAhjAAAhqU0ABZu9fLZ\nuZccXPj+Di+h3fLQDy58X5xeXDECABjCCABgWEoDYFM9tnfnke0b9161xJnAN3PFCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIA\nGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAxjOGUVXdXFWPVtV9q8Y+WFX3zteDVXXvjF9YVf+86r4/\nWOTkAQA20vYTeMx7k/y3JO8/PNDdrzu8XVU3JPnHVY9/oLsv26gJAgBslmcMo+7+WFVduNZ9VVVJ\nXpvkP2zstAAANt9632P00iSPdPdnV41dVFV/V1V/VVUvPd4Tq+q6qtpdVbsPPfHkOqcBALB+J7KU\n9nSuTXLrqtsHkjynu79YVT+Q5E+q6vu6+yvHPrG7b0pyU5Kc8ZwLep3zAABYt5O+YlRV25P8ZJIP\nHh7r7q919xdn+54kDyR5wXonCQCwGdazlPajST7d3fsPD1TVuVW1bbafl+TiJJ9b3xQBADbHiXxc\n/9Ykf53kkqraX1VvmLuuydHLaEnysiR75uP7/yPJm7r78Y2cMADAopzIp9KuPc74z64x9uEkH17/\ntAAANp/ffA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMKq7lz2HVNVjSR5KsjPJwSVP51Ti\neBzN8Tia4/ENjsXRHI+jOR5H24rH47ndfe6JPPCUCKPDqmp3d+9a9jxOFY7H0RyPozke3+BYHM3x\nOJrjcTTH4+lZSgMAGMIIAGCcamF007IncIpxPI7meBzN8fgGx+JojsfRHI+jOR5P45R6jxEAwDKd\naleMAACW5pQIo6q6sqr2VtW+qnrbsuez2arqgqr6y6r6VFV9sqp+YcbfWVUPV9W98/WqZc91s1TV\ng1X19/Pn3j1j51TVXVX12fl+9rLnuRmq6pJV58C9VfWVqrp+K50fVXVzVT1aVfetGlvzfKgVvzd/\nn+ypqhcvb+aLcZzj8VtV9en5M99WVWfN+IVV9c+rzpM/WN7MF+M4x+O4Px9V9fY5P/ZW1Y8tZ9aL\nc5zj8cFVx+LBqrp3xk/78+NbtfSltKraluQzSV6ZZH+Sjye5trs/tdSJbaKqOj/J+d39iar67iT3\nJPmJJK9N8kR3//ZSJ7gEVfVgkl3dfXDV2G8meby7f2MC+uzu/uVlzXEZ5ufl4SQ/lOTnskXOj6p6\nWZInkry/u79/xtY8H+Y/gG9N8qqsHKf/2t0/tKy5L8JxjscVSf5Xdz9VVf8lSeZ4XJjkzw4/7nR0\nnOPxzqzx81FV35vk1iSXJ/meJP8zyQu6+9CmTnqB1joex9x/Q5J/7O5f2wrnx7fqVLhidHmSfd39\nue7+f0k+kOTqJc9pU3X3ge7+xGz/U5L7kzxrubM6JV2d5H2z/b6sxONW84okD3T3Q8ueyGbq7o8l\nefyY4eOdD1dn5T8I3d13Jzlr/vFx2ljreHT3R7r7qbl
5d5Jnb/rEluQ458fxXJ3kA939te7+hyT7\nsvLfodPG0x2Pqqqs/KP71k2d1LeRUyGMnpXk86tu788WjoKp9xcl+ZsZestcGr95qywdjU7ykaq6\np6qum7HzuvvAbH8hyXnLmdpSXZOj/0LbqudHcvzzwd8pyc8nuXPV7Yuq6u+q6q+q6qXLmtQSrPXz\nsdXPj5cmeaS7P7tqbKueH2s6FcKIUVVnJvlwkuu7+ytJ3pXk+UkuS3IgyQ1LnN5m+5HufnGSq5K8\neS4NH9Era8Bb6iOVVfWdSV6d5I9maCufH0fZiufD8VTVryZ5KsktM3QgyXO6+0VJ/lOSP6yqf72s\n+W0iPx9ruzZH/+Nqq54fx3UqhNHDSS5YdfvZM7alVNV3ZCWKbunuP06S7n6kuw9199eTvDun2eXe\np9PdD8/3R5PclpU/+yOHl0Tm+6PLm+FSXJXkE939SLK1z49xvPNhy/6dUlU/m+THk/zUxGJmyeiL\ns31PkgeSvGBpk9wkT/PzsZXPj+1JfjLJBw+PbdXz4+mcCmH08SQXV9VF8y/ia5LcvuQ5bapZ831P\nkvu7+3dWja9+X8Rrktx37HNPR1W1Y96EnqrakeSKrPzZb0/y+nnY65P86XJmuDRH/Utvq54fqxzv\nfLg9yc/Mp9NekpU3mR5Y6wVOJ1V1ZZJfSvLq7v7qqvFz5037qarnJbk4yeeWM8vN8zQ/H7cnuaaq\nzqiqi7JyPP52s+e3JD+a5NPdvf/wwFY9P57O9mVPYD5B8ZYkf5FkW5Kbu/uTS57WZvvhJD+d5O8P\nf4Qyya8kubaqLsvKEsGDSd64nOltuvOS3LbSi9me5A+7+8+r6uNJPlRVb0jyUFbeQLglTCC+Mkef\nA7+5Vc6Pqro1ycuT7Kyq/UnekeQ3svb5cEdWPpG2L8lXs/LpvdPKcY7H25OckeSu+dm5u7vflORl\nSX6tqv4lydeTvKm7T/SNyt8WjnM8Xr7Wz0d3f7KqPpTkU1lZcnzz6fSJtGTt49Hd78k3v0cx2QLn\nx7dq6R/XBwA4VZwKS2kAAKcEYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAADj/wNFwYmlDN8g\nuwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "aonr1tWvYCW9", "colab_type": "code", "outputId": "63727dd7-1758-4aa0-f93f-557758a160a8", "colab": { "base_uri": "https://localhost:8080/", "height": 629 } }, "cell_type": "code", "source": [ "out = lax.conv_with_general_padding(\n", " np.transpose(img,[0,3,1,2]), # lhs = NCHW image tensor\n", " np.transpose(kernel,[2,3,0,1]), # rhs = IOHW conv kernel tensor\n", " (1, 1), # window strides\n", " ((2,2),(2,2)), # general padding 2x2\n", " (1,1), # lhs/image dilation\n", " (1,1)) # rhs/kernel dilation\n", "print(\"out shape: \", out.shape)\n", "print(\"First output channel:\")\n", "plt.figure(figsize=(10,10))\n", "plt.imshow(onp.array(out)[0,0,:,:]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "out shape: (1, 3, 202, 200)\n", "First output channel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGr1JREFUeJzt3X+s5XV95/HXe2dakg7dAAtLKKKg\nQaJtWLRTJGk17lpZMI3U/qGQxtLWLJqoKdluWrXJapo0cdvauu1mbTASMaGoXUolDbSybrdmk9I6\nWIL8cOpgIUIQGH9UwcYWfO8f9z0zZ3DGGeaec+905vFIbu73fM6P72e+fO/wnO/3fM+t7g4AAMm/\n2uwJAAAcLYQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABjZWFUVZdU1c6q2lVV71jVegAAlqVW8QGP\nVbUlyd8leXWSh5J8JskV3X3v0lcGALAkW1f0uhcm2dXdX0ySqvpoksuSHDCMtpy4rbeecsqKpgIA\nHO/+6UsP7e7u0w71uFWF0ZlJvrRw+6EkLzvoJE45JT/0y1evaCoAwPHugav/y4OH87hNe/N1VV1V\nVTuqasfTTzy5WdMAANhrVWH0cJKzFm4/Z8b26u5runt7d2/fcuK2FU0DAODwrSqMPpPk3Ko6p6q+\nP8nlSW5e0boAAJZiJe8x6u6nquptSf48yZYk13b3PatYFwDAsqzqzdfp7luS3LKq1wcAWDaffA0A\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgB\nAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAw\nhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ
RgAAQxgBAAxhBAAwhBEAwBBG\nAABDGAEADGEEADCEEQDAOOIwqqqzquovqureqrqnqn5pxt9TVQ9X1Z3z9ZrlTRcAYHW2ruO5TyX5\n5e7+bFX9YJI7quq2ue93u/u31z89AICNc8Rh1N2PJHlklr9ZVfclOXNZEwMA2GhLeY9RVZ2d5CVJ\n/nqG3lZVd1XVtVV18jLWAQCwausOo6o6McmNSa7u7m8k+UCSFyS5IGtHlN53kOddVVU7qmrH0088\nud5pAACs27rCqKq+L2tRdH13/3GSdPej3f10d38nyQeTXHig53b3Nd29vbu3bzlx23qmAQCwFOu5\nKq2SfCjJfd39OwvjZyw87HVJ7j7y6QEAbJz1XJX240nemORzVXXnjL0ryRVVdUGSTvJAkjeva4YA\nABtkPVel/b8kdYC7bjny6QAAbB6ffA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABjPR/wyCGc\ndN++j3k69a7V/j643efv+7UqX39Rr3RdAHCscsQIAGA4YrRC+x0luv2ute8Xnb+SdS0eJTrtvN1J\nksd3nrqSdQHAscoRIwCAIYwAAIZTaRtlTqHtesMPrOTlr7741u8ae//OS1eyLgA4VjliBAAwhBEA\nwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABD\nGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMLZu9gSOZbvP37Z3+esv6iTJ1RffupJ1vf3kB/cu//7X\nnreSdQDAsc4RIwCA4YjRCu05SpQkp523e6XrWjxKdP2DP7bSdQHAscoRIwCAIYwAAIZTaRvk8Z2n\nJknev/PSTZ4JAHAwjhgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA2LreF6iqB5J8M8nTSZ7q7u1VdUqS\njyU5O8kDSV7f3V9b77oAAFZpWUeM/n13X9Dd2+f2O5J8qrvPTfKpuQ0AcFRb1am0y5JcN8vXJfnp\nFa0HAGBplhFGneSTVXVHVV01Y6d39yOz/OUkpz/zSVV1VVXtqKodTz/x5BKmAQCwPut+j1GSn+ju\nh6vq3ya5rao+v3hnd3dV9TOf1N3XJLkmSU547lnfdT8AwEZb9xGj7n54vj+W5KYkFyZ5tKrOSJL5\n/th61wMAsGrrCqOq2lZVP7hnOcnFSe5OcnOSK+dhVyb5xHrWAwCwEdZ7Ku30JDdV1Z7X+sPu/rOq\n+kySj1fVm5I8mOT161wPAMDKrSuMuvuLSf7dAca/kuRV63ltAICN5pOvAQCGMAIAGMIIAGAIIwCA\nIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYw\nAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABhbj/SJVXVeko8tDD0/yX9NclKS/5Tk8Rl/V3ffcsQzBADYIEccRt29M8kFSVJVW5I8nOSmJL+Q\n5He7+7eXMkMAgA2yrFNpr0pyf3c/uKTXAwDYcMsKo8uT3LBw+21VdVdVXVtVJx/oCVV1VVXtqKod\nTz/x5JKmAQBw5NYdRlX1/Ulem+SPZugDSV6QtdNsjyR534Ge193XdPf27t6+5cRt650GAMC6LeOI\n0aVJPtvdjyZJdz/a3U9393eS
fDDJhUtYBwDAyi0jjK7Iwmm0qjpj4b7XJbl7CesAAFi5I74qLUmq\naluSVyd588Lwb1bVBUk6yQPPuA8A4Ki1rjDq7ieT/JtnjL1xXTMCANgkPvkaAGAIIwCAIYwAAIYw\nAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgbN3sCQCrddJ9lSQ59a4nV76u3edvS5J8/UW98nUBrIIjRgAAQxgBAAyn0uAY\nt/cU2u137Ru86PyVrGvPKbTTztu9d+zxnaeuZF0Aq+CIEQDAEEYAAMOpNDheLJw+2/WGH1jJKq6+\n+NbvGnv/zktXsi6AVXDECABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAcVhhVFXXVtVjVXX3wtgpVXVbVX1hvp88\n41VVv1dVu6rqrqp66aomDwCwTId7xOjDSS55xtg7knyqu89N8qm5nSSXJjl3vq5K8oH1TxMAYPW2\nHs6DuvvTVXX2M4YvS/LKWb4uyf9N8qsz/pHu7iS3V9VJVXVGdz+yjAkDz87u87clSb7+ot47dvXF\nt65kXW8/+cEkye9/7XkreX2AVVvPe4xOX4idLyc5fZbPTPKlhcc9NGMAAEe1pbz5eo4O9SEfuKCq\nrqqqHVW14+knnlzGNAAA1uWwTqUdxKN7TpFV1RlJHpvxh5OctfC458zYfrr7miTXJMkJzz3rWUUV\ncPj2nEI77bzdK1/XnlNo1z/4YytfF8AqrOeI0c1JrpzlK5N8YmH85+bqtIuS/IP3FwEA/xIc1hGj\nqroha2+0PrWqHkry7iTvTfLxqnpTkgeTvH4efkuS1yTZleRbSX5hyXMGAFiJw70q7YqD3PWqAzy2\nk7x1PZMClu/xnafuXX7/zks3cSYARy+ffA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEA\nDGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCE\nEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAOGQYVdW1VfVYVd29MPZb\nVfX5qrqrqm6qqpNm/Oyq+sequnO+/mCVkwcAWKbDOWL04SSXPGPstiQ/0t3nJ/m7JO9cuO/+7r5g\nvt6ynGkCAKzeIcOouz+d5KvPGPtkdz81N29P8pwVzA0AYEMt4z1Gv5jk1oXb51TV31bVX1bVy5fw\n+gAAG2Lrep5cVb+W5Kkk18/QI0me291fqaofTfInVfXD3f2NAzz3qiRXJcmWk09ezzQAAJbiiI8Y\nVdXPJ/mpJD/b3Z0k3f3t7v7KLN+R5P4kLzzQ87v7mu7e3t3bt5y47UinAQCwNEcURlV1SZJfSfLa\n7v7WwvhpVbVllp+f5NwkX1zGRAEAVu2Qp9Kq6oYkr0xyalU9lOTdWbsK7YQkt1VVktw+V6C9Ismv\nV9U/J/lOkrd091cP+MIAAEeZQ4ZRd19xgOEPHeSxNya5cb2TAgDYDD75GgBgCCMAgCGMAACGMAIA\nGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAG
AI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY\nwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgj\nAIAhjAAAxiHDqKqurarHquruhbH3VNXDVXXnfL1m4b53VtWuqtpZVf9xVRMHAFi2wzli9OEklxxg\n/He7+4L5uiVJqurFSS5P8sPznP9ZVVuWNVkAgFXaeqgHdPenq+rsw3y9y5J8tLu/neTvq2pXkguT\n/NURzxCAf9FOuq/2Lp9615MrXdfu87ftXf76i3ql6+LYtJ73GL2tqu6aU20nz9iZSb608JiHZuy7\nVNVVVbWjqnY8/cRqf1AAAA7HkYbRB5K8IMkFSR5J8r5n+wLdfU13b+/u7VtO3HboJwAArNghT6Ud\nSHc/ume5qj6Y5E/n5sNJzlp46HNmDIDj1H6nz26/a9/yRecvfV2Lp89OO2/33uXHd5669HVxbDqi\nI0ZVdcbCzdcl2XPF2s1JLq+qE6rqnCTnJvmb9U0RAGBjHPKIUVXdkOSVSU6tqoeSvDvJK6vqgiSd\n5IEkb06S7r6nqj6e5N4kTyV5a3c/vZqpA/AvzsJRol1v+IGlv/zVF996wPH377x06evi2HQ4V6Vd\ncYDhD32Px/9Gkt9Yz6QAADaDT74GABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY\nwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAMbWzZ4AAMe23edv\n27v89Rf13uWrL7516et6+8kP7l3+/a89b+mvz7HPESMAgCGMAACGU2kArNTi6bPTztu90nUtnj67\n/sEfW+m6ODY5YgQAMBwxAmDDPL7z1L3L79956SbOBA7MESMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY\nwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYBwy\njKrq2qp6rKruXhj7WFXdOV8PVNWdM352Vf3jwn1/sMrJAwAs09bDeMyHk/yPJB/ZM9Ddb9izXFXv\nS/IPC4+/v7svWNYEAQA2yiHDqLs/XVVnH+i+qqokr0/yH5Y7LQCAjbfe9xi9PMmj3f2FhbFzqupv\nq+ovq+rlB3tiVV1VVTuqasfTTzy5zmkAAKzf4ZxK+16uSHLDwu1Hkjy3u79SVT+a5E+q6oe7+xvP\nfGJ3X5PkmiQ54bln9TrnAQCwbkd8xKiqtib5mSQf2zPW3d/u7q/M8h1J7k/ywvVOEgBgI6znVNpP\nJvl8dz+0Z6CqTquqLbP8/CTnJvni+qYIALAxDudy/RuS/FWS86rqoap609x1efY/jZYkr0hy11y+\n/7+SvKW7v7rMCQMArMrhXJV2xUHGf/4AYzcmuXH90wIA2Hg++RoAYAgjAIAhjAAAhjACABjCCABg\nCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIA\nGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCC
MAgCGMAACGMAIAGMIIAGAIIwCAccgw\nqqqzquovqureqrqnqn5pxk+pqtuq6gvz/eQZr6r6varaVVV3VdVLV/2HAABYhsM5YvRUkl/u7hcn\nuSjJW6vqxUnekeRT3X1ukk/N7SS5NMm583VVkg8sfdYAACtwyDDq7ke6+7Oz/M0k9yU5M8llSa6b\nh12X5Kdn+bIkH+k1tyc5qarOWPrMAQCW7Fm9x6iqzk7ykiR/neT07n5k7vpyktNn+cwkX1p42kMz\nBgBwVDvsMKqqE5PcmOTq7v7G4n3d3Un62ay4qq6qqh1VtePpJ558Nk8FAFiJwwqjqvq+rEXR9d39\nxzP86J5TZPP9sRl/OMlZC09/zoztp7uv6e7t3b19y4nbjnT+AABLczhXpVWSDyW5r7t/Z+Gum5Nc\nOctXJvnEwvjPzdVpFyX5h4VTbgAAR62th/GYH0/yxiSfq6o7Z+xdSd6b5ONV9aYkDyZ5/dx3S5LX\nJNmV5FtJfmGpMwYAWJFae3vQJk+i6vGsxdWpSXZv8nSOFrbFPrbF/myPfWyLfWyLfWyL/dkea57X\n3acd6kFHRRjtUVU7unv7Zs/jaGBb7GNb7M/22Me22Me22Me22J/t8ez4lSAAAEMYAQCMoy2Mrtns\nCRxFbIt9bIv92R772Bb72Bb72Bb7sz2ehaPqPUYAAJvpaDtiBACwaY6KMKqqS6pqZ1Xtqqp3bPZ8\nNlJVnVVVf1FV91bVPVX1SzP+nqp6uKrunK/XbPZcN0pVPVBVn5s/944ZO6WqbquqL8z3kzd7nqtW\nVect/Pe/s6q+UVVXHy/7RlVdW1WPVdXdC2MH3A/mA2V/b/4OuauqXrp5M1+Ng2yP36qqz8+f+aaq\nOmnGz66qf1zYR/5g82a+fAfZFgf9uaiqd86+sbOq/uPmzHo1DrItPrawHR7Y8xmEx/p+sSybfiqt\nqrYk+bskr87aL5z9TJIruvveTZ3YBplfp3JGd3+2qn4wyR1JfjprH5j5RHf/9qZOcBNU1QNJtnf3\n7oWx30zy1e5+78Tzyd39q5s1x402PycPJ3lZ1j409ZjfN6rqFUmeSPKR7v6RGTvgfjD/E3x71j5c\n9mVJ/nt3v2yz5r4KB9keFyf5P939VFX9tySZ7XF2kj/d87hjzUG2xXtygJ+LqnpxkhuSXJjkh5L8\n7yQv7O6nN3TSK3KgbfGM+9+Xtd9A8evH+n6xLEfDEaMLk+zq7i929z8l+WiSyzZ5Thumux/p7s/O\n8jeT3JfkzM2d1VHpsiTXzfJ1WYvH48mrktzf3Q9u9kQ2Snd/OslXnzF8sP3gsqz9j6G7+/YkJ80/\nOo4ZB9oe3f3J7n5qbt6etd9Necw7yL5xMJcl+Wh3f7u7/z5rv5XhwpVNboN9r20xv9Lr9VkLQw7T\n0RBGZyb50sLth3KchsHU/EuS/PUMvW0OkV97PJw6WtBJPllVd1TVVTN2+sLv3PtyktM3Z2qb5vLs\n/5fb8bpvHGw/8PdI8otJbl24fU5V/W1V/WVVvXyzJrXBDvRzcTzvGy9P8mh3f2Fh7HjcL56VoyGM\nSFJVJya5McnV3f2NJB9I8oIkFyR5JMn7NnF6G+0nuvulSS5N8tY5VLxXr53/PW4up6yq70/y2iR/\nNEPH876x1/G2H3wvVfVrSZ5Kcv0MPZLkud39kiT/OckfVtW/3qz5bRA/F9/tiuz/D6rjcb941o6G\nMHo4yVkLt58zY8eNqvq+rEXR9d39x0nS3Y9299Pd/Z0kH8wxdOj3ULr74fn+WJKbsvZnf3TPqZH5\n/tjmzXDDXZrks939aHJ87xs5+H5w3P49UlU/n+SnkvzsxGLmtNFXZvmOJPcneeGmTXIDfI+fi+Ny\n36iqrUl+JsnH9owdj/vFkTgawugzSc6tqnPmX8aXJ7l5k+e0YeYc8IeS3Nfdv7Mwvvj+iNclufuZ\nzz0WVdW2eRN6qmpbko
uz9me/OcmV87Ark3xic2a4Kfb7V9/xum+Mg+0HNyf5ubk67aKsvdn0kQO9\nwLGkqi5J8itJXtvd31oYP23esJ+qen6Sc5N8cXNmuTG+x8/FzUkur6oTquqcrG2Lv9no+W2Cn0zy\n+e5+aM/A8bhfHImtmz2BuZribUn+PMmWJNd29z2bPK2N9ONJ3pjkc3suqUzyriRXVNUFWTtV8ECS\nN2/O9Dbc6UluWuvFbE3yh939Z1X1mSQfr6o3JXkwa28oPOZNHL46+//3/83jYd+oqhuSvDLJqVX1\nUJJ3J3lvDrwf3JK1K9J2JflW1q7cO6YcZHu8M8kJSW6bn5nbu/stSV6R5Ner6p+TfCfJW7r7cN+s\nfNQ7yLZ45YF+Lrr7nqr6eJJ7s3a68a3HyhVpyYG3RXd/KN/9vsTkGN8vlmXTL9cHADhaHA2n0gAA\njgrCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAMb/B4Qd82ed6HqvAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "lyOwGRez_ycJ", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## Dimension Numbers define dimensional layout for conv_general_dilated\n", "\n", "The important argument is the 3-tuple of axis layout arguments:\n", "(Input Layout, Kernel Layout, Output Layout)\n", " - __N__ - batch dimension\n", " - __H__ - spatial height\n", " - __W__ - spatial height\n", " - __C__ - channel dimension\n", " - __I__ - kernel _input_ channel dimension\n", " - __O__ - kernel _output_ channel dimension\n", "\n", "⚠️ To demonstrate the flexibility of dimension numbers we choose a __NHWC__ image and __HWIO__ kernel convention for `lax.conv_general_dilated` below." ] }, { "metadata": { "id": "oXKebfCb_i2B", "colab_type": "code", "outputId": "0b80fca6-0eb7-4baf-d824-458c3739d052", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n", " kernel.shape, # only ndim matters, not shape \n", " ('NHWC', 'HWIO', 'NHWC')) # the important bit\n", "print(dn)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n" ], "name": "stdout" } ] }, { "metadata": { "id": "elZys_HzFVG6", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## SAME padding, no stride, no dilation" ] }, { "metadata": { "id": "rgb2T15aFVG6", "colab_type": "code", "outputId": "93fed3a7-69d2-4046-de2f-487ff34b5ee2", "colab": { "base_uri": "https://localhost:8080/", "height": 629 } }, "cell_type": "code", "source": [ "out = lax.conv_general_dilated(img, # lhs = image tensor\n", " kernel, # rhs = conv kernel tensor\n", " (1,1), # window strides\n", " 'SAME', # padding mode\n", " (1,1), # lhs/image dilation\n", " (1,1), # rhs/kernel dilation\n", " dn) # dimension_numbers = lhs, rhs, out dimension permutation\n", 
"print(\"out shape: \", out.shape)\n", "print(\"First output channel:\")\n", "plt.figure(figsize=(10,10))\n", "plt.imshow(onp.array(out)[0,:,:,0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "out shape: (1, 200, 198, 3)\n", "First output channel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGXdJREFUeJzt3X+spndZ5/HPtTPaxKmbtjvdpkKh\nhZQGNd2CYyVRCLtIbYmh4h/Qxigq2UICxGZNFDRZiImJq1a7ullMCQ2Q1AIuVhvTKl3WlWxilSk2\nY6EMTLEN0x3aDgWxxbB2uPaPc81wZjjTDnPOc55hzuuVnJz7+T4/7u/cuc/0Pff3eU6ruwMAQPKv\nlj0BAIBThTACABjCCABgCCMAgCGMAACGMAIAGMIIAGAsLIyq6sqq2ltV+6rqbYvaDwDARqlF/ILH\nqtqW5DNJXplkf5KPJ7m2uz+14TsDANgg2xf0upcn2dfdn0uSqvpAkquTrBlG287c0dvPOWdBUwEA\ntrKnHn88h554sk7ksYsKo2cl+fyq2/uT/NBxJ3HOOfmeX7x+QVMBALay/3vDjSf82KW9+bqqrquq\n3VW1+9ATTy5rGgAARywqjB5OcsGq28+esSO6+6bu3tXdu7aduWNB0wAAOHGLCqOPJ7m4qi6qqu9M\nck2S2xe0LwCADbGQ9xh191NV9ZYkf5FkW5Kbu/uTi9gXAMBGWdSbr9PddyS5Y1GvDwCw0fzmawCA\nIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYw\nAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABjCCABgCCMAgCGMAACGMAIAGCcdRlV1QVX9ZVV9qqo+WVW/MOPvrKqHq+re+XrVxk0XAGBxtq/j\nuU8l+cXu/kRVfXeSe6rqrrnvd7v7t9c/PQCAzXPSYdTdB5IcmO1/qqr7kzxroyYGALDZNuQ9RlV1\nYZIXJfmbGXpLVe2pqpur6uyN2AcAwKKtO4yq6swkH05yfXd/Jcm7kjw/yWVZuaJ0w3Ged11V7a6q\n3YeeeHK90wAAWLd1hVFVfUdWouiW7v7jJOnuR7r7UHd/Pcm7k1y+1nO7+6bu3tXdu7aduWM90wAA\n2BDr+VRaJXlPkvu7+3dWjZ+/6mGvSXLfyU8PAGDzrOdTaT+c5KeT/H1V3Ttjv5Lk2qq6LEkneTDJ\nG9c1QwCATbKeT6X9nyS1xl13nPx0AACWx2++BgAYwggAYAgjAIAhjAAAxno+lcZxnHX/ynvSd+5Z\n/C+uPHjpN34H1Jdf2AvfHwCczlwxAgAYwggAYFhKW4AjS2h37/nG4EsuXci+Vi+fnXvJwSTJY3t3\
nLmRfAHC6c8UIAGC4YrRIq64S7Xvddy1kF9dfcec3jd2496qF7AsATneuGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAIzty57A6ejgpTuSJF9+YR8Zu/6KOxeyr7ee/dCR7d//0nMXsg8A2Cpc\nMQIAGMIIAGBYSluAw0to515ycOH7Wr18dstDP7jw/QHA6cwVIwCAIYwAAIaltAV6bO/OI9s37r1q\niTMBAE6EK0YAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABj+3pfoKoeTPJPSQ4leaq7d1XV\nOUk+mOTCJA8meW13f2m9+wIAWKSNumL077v7su7eNbffluSj3X1xko/ObQCAU9qiltKuTvK+2X5f\nkp9Y0H4AADbMRoRRJ/lIVd1TVdfN2HndfWC2v5DkvGOfVFXXVdXuqtp96IknN2AaAADrs+73GCX5\nke5+uKr+bZK7qurTq+/s7q6qPvZJ3X1TkpuS5IznXPBN9wMAbLZ1XzHq7ofn+6NJbktyeZJHqur8\nJJnvj653PwAAi7auMKqqHVX13Ye3k1yR5L4ktyd5/Tzs9Un+dD37AQDYDOtdSjsvyW1Vdfi1/rC7\n/7yqPp7kQ1X1hiQPJXntOvcDALBw6wqj7v5ckn+3xvgXk7xiPa8NALDZ/OZrAIAhjAAAhjACABjC\nCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACG\nMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMII\nAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCA\nIYwAAIYwAgAY20/2iVV1SZIPrhp6XpL/nOSsJP8xyWMz/ivdfcdJzxAAYJOcdBh1994klyVJVW1L\n8nCS25L8XJLf7e7f3pAZAgBsko1aSntFkge6+6ENej0AgE23UWF0TZJbV91+S1Xtqaqbq+rstZ5Q\nVddV1e6q2n3oiSc3aBoAACdv3WFUVd+Z5NVJ/miG3pXk+VlZZjuQ5Ia1ntfdN3X3ru7ete3MHeud\nBgDAum3EFaOrknyiux9Jku5+pLsPdffXk7w7yeUbsA8AgIXbiDC6NquW0arq/FX3vSbJfRuwDwCA\nhTvpT6UlSVXtSPLKJG9cNfybVXVZkk7y4DH3AQCcstYVRt39ZJJ/c8zYT69rRgAAS+I3XwMADGEE\nADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMDYvuwJAItz1v11ZHvnnicXuq+Dl+44sv3lF/ZC9wWwKK4YAQAMYQQAMCylwWns\nqOWzu/esfH/JpQvZ1+rls3MvOXhk+7G9OxeyP4BFcMUIAGAIIwCAYSkNtopZQtv3uu9ayMtff8Wd\na47fuPeqhewPYBFcMQIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhj
AAAxvZlTwBYnIOX7jiy/eUXdpLk\n+ivuXMi+3nr2Q0e2f/9Lz13IPgAW7YSuGFXVzVX1aFXdt2rsnKq6q6o+O9/PnvGqqt+rqn1Vtaeq\nXryoyQMAbKQTXUp7b5Irjxl7W5KPdvfFST46t5PkqiQXz9d1Sd61/mkCACzeCS2ldffHqurCY4av\nTvLy2X5fkv+d5Jdn/P3d3Unurqqzqur87j6wERMGTtzh5bMkOfeSgwvd1+rls1se+sGF7gtgUdbz\n5uvzVsXOF5KcN9vPSvL5VY/bP2NHqarrqmp3Ve0+9MST65gGAMDG2JBPpc3VoX7GBx79nJu6e1d3\n79p25o5nfgIAwIKt51NpjxxeIquq85M8OuMPJ7lg1eOePWPAEj22d2eS5Ma9Vy15JgCnrvVcMbo9\nyetn+/VJ/nTV+M/Mp9NekuQfvb8IAPh2cEJXjKrq1qy80XpnVe1P8o4kv5HkQ1X1hiQPJXntPPyO\nJK9Ksi/JV5P83AbPGQBgIU70U2nXHueuV6zx2E7y5vVMCgBgGfwvQQAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIA\nGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY\nwggAYDxjGFXVzVX1aFXdt2rst6rq01W1p6puq6qzZvzCqvrnqrp3vv5gkZMHANhIJ3LF6L1Jrjxm\n7K4k39/dlyb5TJK3r7rvge6+bL7etDHTBABYvGcMo+7+WJLHjxn7SHc/NTfvTvLsBcwNAGBTbcR7\njH4+yZ2rbl9UVX9XVX9VVS893pOq6rqq2l1Vuw898eQGTAMAYH22r+fJVfWrSZ5KcssMHUjynO7+\nYlX9QJI/qarv6+6vHPvc7r4pyU1JcsZzLuj1zAMAYCOc9BWjqvrZJD+e5Ke6u5Oku7/W3V+c7XuS\nPJDkBRswTwCAhTupMKqqK5P8UpJXd/dXV42fW1XbZvt5SS5O8rmNmCgAwKI941JaVd2a5OVJdlbV\n/iTvyMqn0M5IcldVJcnd8wm0lyX5tar6lyRfT/Km7n58zRcGADjFPGMYdfe1awy/5ziP/XCSD693\nUgAAy+A3XwMADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABD\nGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEE\nADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwti97AgCc/s66v45s79zz5ML3d/DSHUmSL7+wF74vTi/P\neMWoqm6uqker6r5VY++sqoer6t75etWq+95eVfuqam9V/diiJg4AsNFOZCntvUmuXGP8d7v7svm6\nI0mq6nuTXJPk++Y5/72qtm3UZAEAFukZl9K6+2NVdeEJvt7VST7Q3V9L8g9VtS/J5Un++qRnCMC3\nvaOWz+7e843tl1y6kP0dXkI795KDR8Ye27tzIfvi9LKeN1+/par2zFLb2TP2rCSfX/WY/TP2Tarq\nuqraXVW7Dz2x+PVmAIBnc
rJh9K4kz09yWZIDSW74Vl+gu2/q7l3dvWvbmTtOchoAABvnpD6V1t2P\nHN6uqncn+bO5+XCSC1Y99NkzBgArVi2f7Xvddy1kF9dfcec3jd2496qF7IvTy0ldMaqq81fdfE2S\nw59Yuz3JNVV1RlVdlOTiJH+7vikCAGyOZ7xiVFW3Jnl5kp1VtT/JO5K8vKouS9JJHkzyxiTp7k9W\n1YeSfCrJU0ne3N2HFjN1AICNdSKfSrt2jeH3PM3jfz3Jr69nUgAAy+B/CQIAMIQRAMAQRgAAQxgB\nAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAw\nti97AgCc/g5euuPI9pdf2Ee2r7/izoXs761nP5Qk+f0vPXchr8/pyxUjAIAhjAAAhqU0ABZu9fLZ\nuZccXPj+Di+h3fLQDy58X5xeXDECABjCCABgWEoDYFM9tnfnke0b9161xJnAN3PFCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIA\nGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAxjOGUVXdXFWPVtV9q8Y+WFX3zteDVXXvjF9YVf+86r4/\nWOTkAQA20vYTeMx7k/y3JO8/PNDdrzu8XVU3JPnHVY9/oLsv26gJAgBslmcMo+7+WFVduNZ9VVVJ\nXpvkP2zstAAANt9632P00iSPdPdnV41dVFV/V1V/VVUvPd4Tq+q6qtpdVbsPPfHkOqcBALB+J7KU\n9nSuTXLrqtsHkjynu79YVT+Q5E+q6vu6+yvHPrG7b0pyU5Kc8ZwLep3zAABYt5O+YlRV25P8ZJIP\nHh7r7q919xdn+54kDyR5wXonCQCwGdazlPajST7d3fsPD1TVuVW1bbafl+TiJJ9b3xQBADbHiXxc\n/9Ykf53kkqraX1VvmLuuydHLaEnysiR75uP7/yPJm7r78Y2cMADAopzIp9KuPc74z64x9uEkH17/\ntAAANp/ffA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMKq7lz2HVNVjSR5KsjPJwSVP51Ti\neBzN8Tia4/ENjsXRHI+jOR5H24rH47ndfe6JPPCUCKPDqmp3d+9a9jxOFY7H0RyPozke3+BYHM3x\nOJrjcTTH4+lZSgMAGMIIAGCcamF007IncIpxPI7meBzN8fgGx+JojsfRHI+jOR5P45R6jxEAwDKd\naleMAACW5pQIo6q6sqr2VtW+qnrbsuez2arqgqr6y6r6VFV9sqp+YcbfWVUPV9W98/WqZc91s1TV\ng1X19/Pn3j1j51TVXVX12fl+9rLnuRmq6pJV58C9VfWVqrp+K50fVXVzVT1aVfetGlvzfKgVvzd/\nn+ypqhcvb+aLcZzj8VtV9en5M99WVWfN+IVV9c+rzpM/WN7MF+M4x+O4Px9V9fY5P/ZW1Y8tZ9aL\nc5zj8cFVx+LBqrp3xk/78+NbtfSltKraluQzSV6ZZH+Sjye5trs/tdSJbaKqOj/J+d39iar
67iT3\nJPmJJK9N8kR3//ZSJ7gEVfVgkl3dfXDV2G8meby7f2MC+uzu/uVlzXEZ5ufl4SQ/lOTnskXOj6p6\nWZInkry/u79/xtY8H+Y/gG9N8qqsHKf/2t0/tKy5L8JxjscVSf5Xdz9VVf8lSeZ4XJjkzw4/7nR0\nnOPxzqzx81FV35vk1iSXJ/meJP8zyQu6+9CmTnqB1joex9x/Q5J/7O5f2wrnx7fqVLhidHmSfd39\nue7+f0k+kOTqJc9pU3X3ge7+xGz/U5L7kzxrubM6JV2d5H2z/b6sxONW84okD3T3Q8ueyGbq7o8l\nefyY4eOdD1dn5T8I3d13Jzlr/vFx2ljreHT3R7r7qbl5d5Jnb/rEluQ458fxXJ3kA939te7+hyT7\nsvLfodPG0x2Pqqqs/KP71k2d1LeRUyGMnpXk86tu788WjoKp9xcl+ZsZestcGr95qywdjU7ykaq6\np6qum7HzuvvAbH8hyXnLmdpSXZOj/0LbqudHcvzzwd8pyc8nuXPV7Yuq6u+q6q+q6qXLmtQSrPXz\nsdXPj5cmeaS7P7tqbKueH2s6FcKIUVVnJvlwkuu7+ytJ3pXk+UkuS3IgyQ1LnN5m+5HufnGSq5K8\neS4NH9Era8Bb6iOVVfWdSV6d5I9maCufH0fZiufD8VTVryZ5KsktM3QgyXO6+0VJ/lOSP6yqf72s\n+W0iPx9ruzZH/+Nqq54fx3UqhNHDSS5YdfvZM7alVNV3ZCWKbunuP06S7n6kuw9199eTvDun2eXe\np9PdD8/3R5PclpU/+yOHl0Tm+6PLm+FSXJXkE939SLK1z49xvPNhy/6dUlU/m+THk/zUxGJmyeiL\ns31PkgeSvGBpk9wkT/PzsZXPj+1JfjLJBw+PbdXz4+mcCmH08SQXV9VF8y/ia5LcvuQ5bapZ831P\nkvu7+3dWja9+X8Rrktx37HNPR1W1Y96EnqrakeSKrPzZb0/y+nnY65P86XJmuDRH/Utvq54fqxzv\nfLg9yc/Mp9NekpU3mR5Y6wVOJ1V1ZZJfSvLq7v7qqvFz5037qarnJbk4yeeWM8vN8zQ/H7cnuaaq\nzqiqi7JyPP52s+e3JD+a5NPdvf/wwFY9P57O9mVPYD5B8ZYkf5FkW5Kbu/uTS57WZvvhJD+d5O8P\nf4Qyya8kubaqLsvKEsGDSd64nOltuvOS3LbSi9me5A+7+8+r6uNJPlRVb0jyUFbeQLglTCC+Mkef\nA7+5Vc6Pqro1ycuT7Kyq/UnekeQ3svb5cEdWPpG2L8lXs/LpvdPKcY7H25OckeSu+dm5u7vflORl\nSX6tqv4lydeTvKm7T/SNyt8WjnM8Xr7Wz0d3f7KqPpTkU1lZcnzz6fSJtGTt49Hd78k3v0cx2QLn\nx7dq6R/XBwA4VZwKS2kAAKcEYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAADj/wNFwYmlDN8g\nuwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "E4i3TI5JFVG9", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## VALID padding, no stride, no dilation" ] }, { "metadata": { "id": "1HQwudKVFVG-", "colab_type": "code", "outputId": "9edf1704-b920-4666-8a49-2db7dd622fd1", "colab": { "base_uri": "https://localhost:8080/", "height": 629 } }, "cell_type": "code", "source": [ "out = lax.conv_general_dilated(img, # lhs = image tensor\n", " kernel, # rhs = conv kernel tensor\n", " (1,1), # window strides\n", " 'VALID', # padding mode\n", " (1,1), # lhs/image dilation\n", " (1,1), # rhs/kernel dilation\n", " dn) # dimension_numbers = lhs, rhs, out dimension permutation\n", "print(\"out shape: \", out.shape, \"DIFFERENT from above!\")\n", "print(\"First output channel:\")\n", "plt.figure(figsize=(10,10))\n", "plt.imshow(onp.array(out)[0,:,:,0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "out shape: (1, 198, 196, 3) DIFFERENT from above!\n", "First output channel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGXFJREFUeJzt3X+spndZ5/HPtTNK4tRNpzvdpkKh\nhZQGNd2qYyVRCLtIbYmh4h/Qxigq2UICxGZNFDRZiImJq1a7ullMCQ2Q1AKK1WbTKl3WlWxilSk2\ntVBGptiG6da2Q0FsMawdrv1jrhme6ZxhxjnnOc845/VKTs79fJ8f93fu3Gf6nvvHaXV3AABI/tWq\nJwAAcLoQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEAjKWFUVVdWVV7q2pfVb19WesBANgotYxf8FhV\n25L8TZJXJdmf5BNJru3uT2/4ygAANsj2JX3u5Un2dffnkqSqPpjk6iRrhtG2s3b09nPOWdJUAICt\n7Jknn8zBp56uk3ntssLouUk+v/B4f5LvO+4kzjkn3/az1y9pKgDAVvZ/b7jxpF+7souvq+q6qtpT\nVXsOPvX0qqYBAHDEssLokSQXLDx+3owd0d03dffu7t697awdS5oGAMDJW1YYfSLJxVV1UVV9c5Jr\nkty+pHUBAGyIpVxj1N3PVNVbk/xJkm1Jbu7uTy1jXQAAG2VZF1+nu+9IcseyPh8AYKP5zdcAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgB\nAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAw\nhBEAwBBGAABDGAEADGEEADCEEQDAOOUwqqoLqupPq+rTVfWpqvqZGX9XVT1SVffO16s3broAAMuz\nfR3vfSbJz3b3J6vqW5PcU1V3zXO/2d2/vv7pAQBsnlMOo+5+NMmjs/wPVfVAkudu1MQAADbbhlxj\nVFUXJvmuJH8xQ2+tqvuq6uaq2nmc91xXVXuqas/Bp57eiGkAAKzLusOoqs5K8pEk13f3l5O8O8mL\nklyWQ0eUbljrfd19U3fv7u7d287asd5pAACs27rCqKq+KYei6Jbu/oMk6e7Huvtgd38tyXuSXL7+\naQIALN967kqrJO9N8kB3/8bC+PkLL3ttkvtPfXoAAJtnPXelfX+SH0/y11V174z9QpJrq+qyJJ3k\noSRvWtcMAQA2yXruSvs/SWqNp+449ekAAKyO33wNADDWcyqNBWc/8PWDZ7vuW/6vHzhw6dfv5PvS\nS3rp6wOArcARIwCAIYwAAIZTaRvkqNNnd9/39eWXXrqU9S2ePjv3kgNJkif27lrKugBgq3DECABg\nOGK0DAtHifa9/luWsorrr7jzmLEb9161lHUBwFbhiBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMDYvuoJnCkOXLrjyPKXXtJHlq+/4s6lrO9tOx8+svzbX3zBUtYBAFuNI0YAAEMYAQAMp9I2\nyOLps3MvObD09S2ePrvl4e9d+voAYCtwxAgAYAgjAIDhVNoSPLF315HlG/detcKZAAD/HI4YAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEA
DCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAGP7ej+gqh5K8g9JDiZ5prt3\nV9U5ST6U5MIkDyV5XXd/cb3rAgBYpo06YvTvu/uy7t49j9+e5GPdfXGSj81jAIDT2rJOpV2d5P2z\n/P4kP7Kk9QAAbJiNCKNO8tGquqeqrpux87r70Vn+uyTnbcB6AACWat3XGCX5ge5+pKr+bZK7quoz\ni092d1dVP/tNE1HXJcm2nTs3YBoAAOuz7iNG3f3IfH88yW1JLk/yWFWdnyTz/fE13ndTd+/u7t3b\nztqx3mkAAKzbusKoqnZU1bceXk5yRZL7k9ye5A3zsjck+aP1rAcAYDOs91TaeUluq6rDn/W73f3H\nVfWJJB+uqjcmeTjJ69a5HgCApVtXGHX355L8uzXGv5Dklev5bACAzeY3XwMADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABj+6m+saouSfKhhaEXJvnPSc5O8h+TPDHjv9Ddd5zyDAEANskph1F3701yWZJU\n1bYkjyS5LclPJfnN7v71DZkhAMAm2ahTaa9M8mB3P7xBnwcAsOk2KoyuSXLrwuO3VtV9VXVzVe1c\n6w1VdV1V7amqPQefenqDpgEAcOrWHUZV9c1JXpPk92bo3UlelEOn2R5NcsNa7+vum7p7d3fv3nbW\njvVOAwBg3TbiiNFVST7Z3Y8lSXc/1t0Hu/trSd6T5PINWAcAwNJtRBhdm4XTaFV1/sJzr01y/was\nAwBg6U75rrQkqaodSV6V5E0Lw79aVZcl6SQPPes5AIDT1rrCqLufTvJvnjX24+uaEQDAivjN1wAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEA\nwBBGAABj+6onACzH2Q/UkeVd9z299PUduHTHkeUvvaSXvj6AZXDECABgCCMAgOFUGpyhjjp9dvd9\nX19+6aVLWd/i6bNzLzmQJHli766lrAtgWRwxAgAYjhjBVrBwlGjf679lKau4/oo7jxm7ce9VS1kX\nwLI4YgQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADC2r3oCwHIcuHTHkeUvvaSPLF9/xZ1L\nWd/bdj58ZPm3v/iCpawDYNkcMQIAGMIIAGA4lQZnqMXTZ+decmDp61s8fXbLw9+79PUBLMNJHTGq\nqpur6vGqun9h7JyququqPjvfd854VdVvVdW+qrqvqr57WZMHANhIJ3sq7X1JrnzW2NuTfKy7L07y\nsXmcJFcluXi+rkvy7vVPEwBg+U7qVFp3f7yqLnzW8NVJXjHL70/yv5P8/Ix/oLs7yd1VdXZVnd/d\nj27EhIF/vif27jqyfOPeq1Y4E4DT23ouvj5vIXb+Lsl5s/zcJJ9feN3+GTtKVV1XVXuqas/Bp55e\nxzQAADbGhtyVNkeH+oQvP
Po9N3X37u7eve2sHSd+AwDAkq0njB6rqvOTZL4/PuOPJLlg4XXPmzEA\ngNPaesLo9iRvmOU3JPmjhfGfmLvTXprk711fBAD8S3BSF19X1a05dKH1rqran+SdSX4lyYer6o1J\nHk7yunn5HUlenWRfkq8k+akNnjMAwFKc7F1p1x7nqVeu8dpO8pb1TAoAYBX8L0EAAIYwAgAYwggA\nYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABg\nCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGCcMIyq6uaqeryq7l8Y+7Wq+kxV3VdVt1XV2TN+YVX9Y1XdO1+/s8zJAwBspJM5\nYvS+JFc+a+yuJN/Z3Zcm+Zsk71h47sHuvmy+3rwx0wQAWL4ThlF3fzzJk88a+2h3PzMP707yvCXM\nDQBgU23ENUY/neTOhccXVdVfVdWfVdXLNuDzAQA2xfb1vLmqfjHJM0lumaFHkzy/u79QVd+T5A+r\n6ju6+8trvPe6JNclybadO9czDQCADXHKR4yq6ieT/HCSH+vuTpLu/mp3f2GW70nyYJIXr/X+7r6p\nu3d39+5tZ+041WkAAGyYUwqjqroyyc8leU13f2Vh/Nyq2jbLL0xycZLPbcREAQCW7YSn0qrq1iSv\nSLKrqvYneWcO3YX2nCR3VVWS3D13oL08yS9V1T8l+VqSN3f3k2t+MADAaeaEYdTd164x/N7jvPYj\nST6y3kkBAKyC33wNADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABD\nGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEE\nADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwtq96AgBsDWc/UEeWd9339NLXd+DSHUeWv/SSXvr6ODM4\nYgQAMBwxAmBTHHWU6O77vr780kuXsr7Fo0TnXnIgSfLE3l1LWRdnjhMeMaqqm6vq8aq6f2HsXVX1\nSFXdO1+vXnjuHVW1r6r2VtUPLWviAAAb7WROpb0vyZVrjP9md182X3ckSVV9e5JrknzHvOe/V9W2\njZosAMAynfBUWnd/vKouPMnPuzrJB7v7q0n+tqr2Jbk8yZ+f8gwBOPMsnD7b9/pvWcoqrr/izmPG\nbtx71VLWxZljPRdfv7Wq7ptTbTtn7LlJPr/wmv0zdoyquq6q9lTVnoNPLf/uBACAEznVMHp3khcl\nuSzJo0lu+Od+QHff1N27u3v3trN2nPgNAABLdkph1N2PdffB7v5akvfk0OmyJHkkyQULL33ejAEA\nnPZOKYyq6vyFh69NcviOtduTXFNVz6mqi5JcnOQv1zdFAIDNccKLr6vq1iSvSLKrqvYneWeSV1TV\nZUk6yUNJ3pQk3f2pqvpwkk8neSbJW7r74HKmDgCwsU7mrrRr1xh+7zd4/S8n+eX1TAoAYBX8L0EA\nAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIY
wAgAY\nwggAYGxf9QQA2BoOXLrjyPKXXtJHlq+/4s6lrO9tOx8+svzbX3zBUtbBmccRIwCAIYwAAIZTaQBs\nisXTZ+decmDp61s8fXbLw9+79PVxZnDECABgCCMAgOFUGgCb7om9u44s37j3qhXOBI7miBEAwBBG\nAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEA\nDGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCE\nEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAADjhGFUVTdX1eNVdf/C2Ieq\n6t75eqiq7p3xC6vqHxee+51lTh4AYCNtP4nXvC/Jf0vygcMD3f36w8tVdUOSv194/YPdfdlGTRAA\nYLOcMIy6++NVdeFaz1VVJXldkv+wsdMCANh8673G6GVJHuvuzy6MXVRVf1VVf1ZVLzveG6vquqra\nU1V7Dj719DqnAQCwfidzKu0buTbJrQuPH03y/O7+QlV9T5I/rKrv6O4vP/uN3X1TkpuS5DnPv6DX\nOQ8AgHU75SNGVbU9yY8m+dDhse7+and/YZbvSfJgkhevd5IAAJthPafSfjDJZ7p7/+GBqjq3qrbN\n8guTXJzkc+ubIgDA5jiZ2/VvTfLnSS6pqv1V9cZ56pocfRotSV6e5L65ff/3k7y5u5/cyAkDACzL\nydyVdu1xxn9yjbGPJPnI+qcFALD5/OZrAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY\nwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgj\nAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGBUd696DqmqJ5I8nGRXkgMrns7pxPY4lm1y\nLNvkaLbHsWyTY9kmRzvTt8cLuvvck3nhaRFGh1XVnu7evep5nC5sj2PZJseyTY5mexzLNjmWbXI0\n2+PrnEoDABjCCABgnG5hdNOqJ3CasT2OZZscyzY5mu1xLNvkWLbJ0WyPcVpdYwQAsEqn2xEjAICV\nOS3CqKqurKq9VbWvqt6+6vmsQlVdUFV/WlWfrqpPVdXPzPi7quqRqrp3vl696rlulqp6qKr+ev7c\ne2bsnKq6q6o+O993rnqem6WqLlnYD+6tqi9X1fVbbR+pqpur6vGqun9hbM39og75rfm75b6q+u7V\nzXw5jrM9fq2qPjN/5tuq6uwZv7Cq/nFhX/md1c18eY6zTY77c1JV75h9ZG9V/dBqZr1cx9kmH1rY\nHg9V1b0zviX2k+NZ+am0qtqW5G+SvCrJ/iSfSHJtd396pRPbZFV1fpLzu/uTVfWtSe5J8iNJXpfk\nqe7+9ZVOcAWq6qEku7v7wMLYryZ5srt/ZSJ6Z3f//KrmuCrzc/NIku9L8lPZQvtIVb08yVNJPtDd\n3zlja+4X8x+/tyV5dQ5tq//a3d+3qrkvw3G2xxVJ/ld3P1NV/yVJZntcmOR/HH7dmeo42+RdWePn\npKq+PcmtSS5P8m1J/meSF3f3wU2d9JKttU2e9fwNSf6+u39pq+wnx3M6HDG6PMm+7v5cd/+/JB9M\ncvWK57TpuvvR7v7kLP9DkgeSPHe1szotXZ3k/bP8/hyKx63olUke7O6HVz2RzdbdH0/y5LOGj7df\nXJ1D/yHo7r47ydnzj5Azxlrbo7s/2t3PzMO7kzxv0ye
2QsfZR47n6iQf7O6vdvffJtmXQ/9dOqN8\no21SVZVD/wi/dVMndZo6HcLouUk+v/B4f7Z4EEytf1eSv5iht84h8Zu30qmjJJ3ko1V1T1VdN2Pn\ndfejs/x3Sc5bzdRW7poc/ZfYVt1HDjvefuHvl+Snk9y58PiiqvqrqvqzqnrZqia1Imv9nNhHkpcl\neay7P7swtmX3k9MhjFhQVWcl+UiS67v7y0neneRFSS5L8miSG1Y4vc32A9393UmuSvKWORR8RB86\nD7zlbqusqm9O8pokvzdDW3kfOcZW3S/WUlW/mOSZJLfM0KNJnt/d35XkPyX53ar616ua3ybzc3J8\n1+bof2ht5f3ktAijR5JcsPD4eTO25VTVN+VQFN3S3X+QJN39WHcf7O6vJXlPzsBDvMfT3Y/M98eT\n3JZDf/bHDp8Kme+Pr26GK3NVkk9292PJ1t5HFhxvv9iyf79U1U8m+eEkPzaxmDld9IVZvifJg0le\nvLJJbqJv8HOyZfeRJKmq7Ul+NMmHDo9t5f0kOT3C6BNJLq6qi+ZfwtckuX3Fc9p0c473vUke6O7f\nWBhfvB7itUnuf/Z7z0RVtWMuQk9V7UhyRQ792W9P8oZ52RuS/NFqZrhSR/3rbqvuI89yvP3i9iQ/\nMXenvTSHLi59dK0POJNU1ZVJfi7Ja7r7Kwvj586F+6mqFya5OMnnVjPLzfUNfk5uT3JNVT2nqi7K\noW3yl5s9vxX6wSSf6e79hwe28n6SJNtXPYG5a+KtSf4kybYkN3f3p1Y8rVX4/iQ/nuSvD98ymeQX\nklxbVZfl0KmBh5K8aTXT23TnJbntUC9me5Lf7e4/rqpPJPlwVb0xycM5dMHgljGR+KocvR/86lba\nR6rq1iSvSLKrqvYneWeSX8na+8UdOXRH2r4kX8mhO/jOKMfZHu9I8pwkd83P0N3d/eYkL0/yS1X1\nT0m+luTN3X2yFyn/i3GcbfKKtX5OuvtTVfXhJJ/OodOObznT7khL1t4m3f3eHHu9YrJF9pPjWfnt\n+gAAp4vT4VQaAMBpQRgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA+P+2dY6FyWLi+wAAAABJ\nRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "VYKZdqLIFVHB", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## SAME padding, 2,2 stride, no dilation" ] }, { "metadata": { "id": "mKq2-zmmFVHC", "colab_type": "code", "outputId": "73b80162-fdca-4f6a-ec06-4645d6fdc9f6", "colab": { "base_uri": "https://localhost:8080/", "height": 629 } }, "cell_type": "code", "source": [ "out = lax.conv_general_dilated(img, # lhs = image tensor\n", " kernel, # rhs = conv kernel tensor\n", " (2,2), # window strides\n", " 'SAME', # padding mode\n", " (1,1), # lhs/image dilation\n", " (1,1), # rhs/kernel dilation\n", " dn) # dimension_numbers = lhs, rhs, out dimension permutation\n", "print(\"out shape: \", out.shape, \" <-- half the size of above\")\n", "plt.figure(figsize=(10,10))\n", "print(\"First output channel:\")\n", "plt.imshow(onp.array(out)[0,:,:,0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "out shape: (1, 100, 99, 3) <-- half the size of above\n", "First output channel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAj8AAAJCCAYAAAAvEKYoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFHtJREFUeJzt3V+sZWdZx/HfYw8tzKi0RZnUtsoY\niIaYKHZKEKwxU00QiJ0LgjRqiqnpjX9QNLZ6Y0ww0cSAXBhMQ9VekAKppG2M0RBaknrTdEpNKq1K\nU4TOpP8MFk1rgBMfL85WBuaUczpn7/OH5/O5mbPevU7fN1lZk2/Xfvee6u4AAEzxbXu9AACA3SR+\nAIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBglB3FT1W9uar+paoeraqblrUoAIBVqXP9hueq\nOi/Jvyb56SSnktyf5NrufviFfmft0OF+ycsvPqf5AAC+ma9+6YtZf/652uq8tR3M8fokj3b3Y0lS\nVR9Jck2SF4yfl7z84hx913t2MCUAwOY+91fv29Z5O3nb69Ikj59xfGoxBgCwb618w3NV3VBVJ6vq\n5Przz616OgCAb2on8XM6yeVnHF+2GPs63X1zdx/r7mNrhw7vYDoAgJ3bSfzcn+Q1VXW0qs5P8s4k\ndy1nWQAAq3HOG567e72qfjXJ3yc5L8lfdPdnlrYyAIAV2MmnvdLdf5vkb5e0FgCAlfMNzwDAKOIH\nABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCA\nUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF\n/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQP\nADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAA\no4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK\n+AEARhE/AMAo4gcAGEX8AACjiB8AYJS1vV7AQXPk/i+fNbZ29wMrmWv9+BWbjj915QUrmQ8AJvDk\nBwAYRfwAAKOIHwBgFPEDAIwifgCAUXza60Xa7JNdp29840rmOnHtvZuO33HbVSuZDwAm8OQHABhF\n/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQP\nADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMsrbXCzho1o9fcdbYiWvvXclc733lQ5uO35Gr\nVjIfAEzgyQ8AMIr4AQBGET8AwCjiBwAYZcv4qarLq+qeqnq4qj5TVe9ejF9cVZ+oqs8u/rxo9csF\nANiZ6u5vfkLVJUku6e5PV9V3JHkgyYkk70ryxe7+o6q6KclF3X3jN/tvveySy/vou96znJUDAJzh\nc3/1vvz3E4/XVudt+eSnu5/o7k8vfv6vJI8kuTTJNUluXZx2azaCCABgX3tRe36q6lVJXpfkviRH\nuvuJxUtPJjmy1JUBAKzAtuOnqr49yV8n+Y3u/s8zX+uN9842ff+sqm6oqpNVdXL9+ed2tFgAgJ3a\nVvxU1UuyET4f7u6PL4afWuwH+r99QU9v9rvdfXN3H+vuY2uHDi9jzQAA52w7n/aqJLckeaS733fG\nS3cluW7x83VJ7lz+8gAAlms7/7bXm5L8YpKHquofF2O/l+SPknysqq5P8vkk71jNEgEAlmfL+Onu\nf0jyQh8bu3q5ywEAWC3f8AwAjCJ+AIBRxA8AMIr4AQBGET8Aw
CjiBwAYRfwAAKOIHwBgFPEDAIwi\nfgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIH\nABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCA\nUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF\n/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQP\nADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAA\no4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK\n+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gf\nAGAU8QMAjCJ+AIBRxA8AMIr4AQBG2Xb8VNV5VfVgVf3N4vhoVd1XVY9W1Uer6vzVLRMAYDlezJOf\ndyd55IzjP07y/u5+dZL/SHL9MhcGALAK24qfqrosyVuTfGhxXEmOJ7l9ccqtSU6sYoEAAMu03Sc/\nf5rkd5L8z+L4FUme7e71xfGpJJdu9otVdUNVnayqk+vPP7ejxQIA7NSW8VNVb0vydHc/cC4TdPfN\n3X2su4+tHTp8Lv8JAIClWdvGOW9K8rNV9ZYkL03ynUk+kOTCqlpbPP25LMnp1S0TAGA5tnzy092/\n292Xdferkrwzyd3d/fNJ7kny9sVp1yW5c2WrBABYkp18z8+NSd5TVY9mYw/QLctZEgDA6mznba//\n192fSvKpxc+PJXn98pcEALA6vuEZABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCji\nBwAYRfwAAKOIHwBgFPEDAIwifgCAUV7Uv+oO7B9H7v/ypuNrdz+wkvnWj19x1thTV16wkrkAVsmT\nHwBgFPEDAIwifgCAUcQPADCK+AEARvFpLzigXuhTXadvfONK5jtx7b1njd1x21UrmQtglTz5AQBG\nET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTx\nAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGCUtb1eAHBu1o9fsen4iWvvXcl8733l\nQ2eN3ZGrVjIXwCp58gMAjCJ+AIBRxA8AMIr4AQBGseEZDqinrrxg0/E7blvNJmSbm4FvFZ78AACj\niB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4\nAQBGET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8A\nYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBG\nET8AwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTx\nAwCMIn4AgFHEDwAwyrbip6ourKrbq+qfq+qRqvqxqrq4qj5RVZ9d/HnRqhcLALBT233y84Ekf9fd\nP5jkh5M8kuSmJJ/s7tck+eTiGABgX9syfqrq5Ul+IsktSdLdX+nuZ5Nck+TWxWm3JjmxqkUCACzL\ndp78HE3yTJK/rKoHq+pDVXU4yZHufmJxzpNJjmz2y1V1Q1WdrKqT688/t5xVAwCco+3Ez1qSH03y\nwe5+XZLn8g1vcXV3J+nNf
rm7b+7uY919bO3Q4Z2uFwBgR7YTP6eSnOru+xbHt2cjhp6qqkuSZPHn\n06tZIgDA8mwZP939ZJLHq+oHFkNXJ3k4yV1JrluMXZfkzpWsEABgida2ed6vJflwVZ2f5LEkv5SN\ncPpYVV2f5PNJ3rGaJQIALM+24qe7/zHJsU1eunq5ywEAWC3f8AwAjCJ+AIBRxA8AMIr4AQBGET8A\nwCjiBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCM\nIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCji\nBwAYRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4A\ngFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAY\nRfwAAKOIHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHE\nDwAwivgBAEZZ2+sFAHDwHbn/y2eNrd39wErmWj9+xVljT115wUrm4luTJz8AwCjiBwAYRfwAAKOI\nHwBgFBueAdixzTY3n77xjSuZ68S19541dsdtV61kLr41efIDAIwifgCAUcQPADCK+AEARhE/AMAo\n4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+\nAIBRxA8AMIr4AQBGET8AwChre70AAA6+9eNXnDV24tp7VzLXe1/50Fljd+SqlczFtyZPfgCAUcQP\nADCK+AEARtlW/FTVb1bVZ6rqn6rqtqp6aVUdrar7qurRqvpoVZ2/6sUCAOzUlhueq+rSJL+e5LXd\n/d9V9bEk70zyliTv7+6PVNWfJ7k+yQdXuloA9qWnrrzgrLE7blvNJmSbm9mp7b7ttZbkZVW1luRQ\nkieSHE9y++L1W5OcWP7yAACWa8v46e7TSf4kyReyET1fSvJAkme7e31x2qkkl65qkQAAy7Jl/FTV\nRUmuSXI0yfckOZzkzdudoKpuqKqTVXVy/fnnznmhAADLsJ23vX4qyee6+5nu/mqSjyd5U5ILF2+D\nJcllSU5v9svdfXN3H+vuY2uHDi9l0QAA52o78fOFJG+oqkNVVUmuTvJwknuSvH1xznVJ7lzNEgEA\nlmc7e37uy8bG5k8neWjxOzcnuTHJe6rq0SSvSHLLCtcJALAU2/q3vbr795P8/jcMP5bk9UtfEQDA\nCvmGZwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAw\nivgBAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOI\nHwBgFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgB\nAEYRPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBg\nFPEDAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYR\nPwDAKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPED\nAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGEX8AACjiB8AYBTxAwCMIn4AgFHEDwAwivgBAEYRPwDA\nKOIHABhF/AAAo4gfAGAU8QMAjCJ+AIBRxA8AMIr4AQBGET8AwCjiBwAYRfwAAKOIHwBgFPE
DAIwi\nfgCAUcQPADCK+AEARqnu3r3Jqp5J8vnF4Xcl+fddm5ydcr0OHtfs4HHNDhbXa//5vu7+7q1O2tX4\n+bqJq05297E9mZwXzfU6eFyzg8c1O1hcr4PL214AwCjiBwAYZS/j5+Y9nJsXz/U6eFyzg8c1O1hc\nrwNqz/b8AADsBW97AQCj7Hr8VNWbq+pfqurRqrppt+dna1V1eVXdU1UPV9Vnqurdi/GLq+oTVfXZ\nxZ8X7fVa+ZqqOq+qHqyqv1kcH62q+xb32ker6vy9XiNfU1UXVtXtVfXPVfVIVf2Ye2x/q6rfXPyd\n+E9VdVtVvdR9djDtavxU1XlJ/izJzyR5bZJrq+q1u7kGtmU9yW9192uTvCHJryyu001JPtndr0ny\nycUx+8e7kzxyxvEfJ3l/d786yX8kuX5PVsUL+UCSv+vuH0zyw9m4du6xfaqqLk3y60mOdfcPJTkv\nyTvjPjuQdvvJz+uTPNrdj3X3V5J8JMk1u7wGttDdT3T3pxc//1c2/lK+NBvX6tbFabcmObE3K+Qb\nVdVlSd6a5EOL40pyPMnti1Ncr32kql6e5CeS3JIk3f2V7n427rH9bi3Jy6pqLcmhJE/EfXYg7Xb8\nXJrk8TOOTy3G2Keq6lVJXpfkviRHuvuJxUtPJjmyR8vibH+a5HeS/M/i+BVJnu3u9cWxe21/OZrk\nmSR/uXir8kNVdTjusX2ru08n+ZMkX8hG9HwpyQNxnx1INjzzgqrq25P8dZLf6O7/PPO13viYoI8K\n7gNV9bYkT3f3A3u9FrZtLcmPJvlgd78uyXP5hre43GP7y2L/1TXZCNfvSXI4yZv3dFGcs92On9NJ\nLj/j+LLFGPtMVb0kG+Hz4e7++GL4qaq6ZPH6JUme3qv18XXelORnq+rfsvFW8vFs7Ce5cPF4PnGv\n7Tenkpzq7vsWx7dnI4bcY/vXTyX5XHc/091fTfLxbNx77rMDaLfj5/4kr1nsjj8/G5vF7trlNbCF\nxX6RW5I80t3vO+Olu5Jct/j5uiR37vbaOFt3/253X9bdr8rGPXV3d/98knuSvH1xmuu1j3T3k0ke\nr6ofWAxdneThuMf2sy8keUNVHVr8Hfl/18x9dgDt+pccVtVbsrE/4bwkf9Hdf7irC2BLVfXjSe5N\n8lC+tofk97Kx7+djSb43yeeTvKO7v7gni2RTVfWTSX67u99WVd+fjSdBFyd5MMkvdPeX93J9fE1V\n/Ug2Nqifn+SxJL+Ujf8hdY/tU1X1B0l+LhufiH0wyS9nY4+P++yA8Q3PAMAoNjwDAKOIHwBgFPED\nAIwifgCAUcQPADCK+AEARhE/AMAo4gcAGOV/AV4aIGDshI8OAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "gPxttaiaFVHE", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## VALID padding, no stride, rhs kernel dilation ~ Atrous convolution (excessive to illustrate)" ] }, { "metadata": { "id": "_pGr0x6qFVHF", "colab_type": "code", "outputId": "a0f489eb-bab7-42c4-c030-fb756f63a4a4", "colab": { "base_uri": "https://localhost:8080/", "height": 629 } }, "cell_type": "code", "source": [ "out = lax.conv_general_dilated(img, # lhs = image tensor\n", " kernel, # rhs = conv kernel tensor\n", " (1,1), # window strides\n", " 'VALID', # padding mode\n", " (1,1), # lhs/image dilation\n", " (12,12), # rhs/kernel dilation\n", " dn) # dimension_numbers = lhs, rhs, out dimension permutation\n", "print(\"out shape: \", out.shape)\n", "plt.figure(figsize=(10,10))\n", "print(\"First output channel:\")\n", "plt.imshow(onp.array(out)[0,:,:,0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "out shape: (1, 176, 174, 3)\n", "First output channel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkUAAAJCCAYAAADOe7N5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGwpJREFUeJzt3X+spXd92Pn3Zz2xU2xtjeOGEtu7\ndm2TikTNBk1ZV9lWCXQbk0Zx/oiCadq4LZLlLk0TSjeBRFp2/0BK2qpuot2CvIHitAhMKQ1WlP6g\nLilaaQ2ZkITfhAECHsvERAm0OFoTO9/94x52r+yZ2sy512fmzuslje45z3nOPZ+HZ+b6zXOe59xZ\nawUAcKH7r3Y9AADAuUAUAQAkigAAKlEEAFCJIgCAShQBAFSiCACgOsQompmbZ+YTM3NyZl59WK8D\nAHAQ5jA+vHFmLqp+q/ofq1PVr1YvW2t99MBfDADgABw7pO/7wurkWuvTVTPztuqW6rRRdNFll65j\nV1xxSKMAABeyrzxw6nfXWn/iqdY7rCi6qnpg3/1T1X9/xiGuuKJvetWPHdIoAMCF7Ld/7O999ums\nt7MTrWfm9pk5MTMnHv/yI7saAwCgOrwoerC6Zt/9qzfL/j9rrbvWWsfXWscvuuzSQxoDAODpOawo\n+tXqxpm5bmYurm6t7j2k1wIA2NqhnFO01npsZv529W+ri6o3rbU+chivBQBwEA7rROvWWr9c/fJh\nfX8AgIPkE60BABJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEA\nQHWIvxD2a3HJA490wyvvP5DvdfLOm067/KC+/9k400wAwLnDkSIAgEQRAEAligAAKlEEAFCJIgCA\nShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBA\nJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAqo7teoCqR6+5\ntJOvuulQX+PknYf7/QGA85sjRQAAiSIAgEoUAQBUoggAoBJFAACVKAIAqM6RS/IveeCRbnjl/Qfy\nvc6nS+8/9dI37Oy1r7/njp29NgCcixwpAgBIFAEAVKIIAKDaIopm5pqZec/MfHRmPjIzP7pZfsXM\nvHtmPrn5+uyDGxcA4HBsc6TosepVa63nVzdVr5iZ51evru5ba91Y3be5DwBwTjvrKFprPbTW+sDm\n9n+uPlZdVd1S3b1Z7e7q+7cdEgDgsB3IOUUzc2317dX7questR7aPPT56jkH8RoAAIdp6yiamcuq\nf1n92FrrP+1/bK21qnWG590+Mydm5sQf9ui2YwAAbGWrKJqZr2sviN6y1nrnZvHvzMxzN48/t3r4\ndM9da9211jq+1jr+dV2yzRgAAFvb5uqzqd5YfWyt9Y/2PXRvddvm9m3Vu85+PACAZ8Y2v+bjO6q/\nVn1oZn5js+wnq5+u3j4zL68+W/3gdiMCABy+s46itdb/Vc0ZHn7x2X5fAIBd8InWAACJIgCAShQB\nAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoA\nACpRBABQiSIAgKqO7XqAqkevubSTr7pp12M8466/545djwAAbDhSBACQKAIAqEQRAEAligAAKlEE\nAFCdI1efPRM+9dI37Oy1z3SV2bk4EwBcqBwpAgBIFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpR\nBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUo\nAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAA1QFE0
cxcNDO/PjO/tLl/3cy8b2ZOzsw9\nM3Px9mMCAByugzhS9KPVx/bd/5nqzrXWDdXvVy8/gNcAADhUx7Z58sxcXf3l6nXV352ZqV5U/ZXN\nKndX/2v1+m1e5yBcf88dux7hSc7FmQDgQrXtkaJ/XP149Ueb+99QfXGt9djm/qnqqtM9cWZun5kT\nM3Pi8S8/suUYAADbOesompnvrR5ea/3a2Tx/rXXXWuv4Wuv4RZdderZjAAAciG3ePvuO6vtm5nuq\nr6/+6+pnq8tn5tjmaNHV1YPbjwkAcLjO+kjRWus1a62r11rXVrdW/2Gt9UPVe6of2Kx2W/WuracE\nADhkh/E5RT/R3knXJ9s7x+iNh/AaAAAHaqurz75qrfUr1a9sbn+6euFBfF8AgGeKT7QGAEgUAQBU\noggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAq\nUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACV\nKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBK\nFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqLaMopm5\nfGbeMTMfn5mPzcyfm5krZubdM/PJzddnH9SwAACHZdsjRT9b/Zu11p+uvq36WPXq6r611o3VfZv7\nAADntLOOopn549VfqN5Ytdb6ylrri9Ut1d2b1e6uvn/bIQEADts2R4quq75Q/dOZ+fWZ+fmZubR6\nzlrroc06n6+ec7onz8ztM3NiZk48/uVHthgDAGB720TRseoF1evXWt9ePdIT3ipba61qne7Ja627\n1lrH11rHL7rs0i3GAADY3jZRdKo6tdZ63+b+O9qLpN+ZmedWbb4+vN2IAACH76yjaK31+eqBmfnm\nzaIXVx+t7q1u2yy7rXrXVhMCADwDjm35/B+p3jIzF1efrv5Ge6H19pl5efXZ6ge3fA0AgEO3VRSt\ntX6jOn6ah168zfcFAHim+URrAIBEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQR\nAEAligAAqu1/ISxckG545f07e+2Td9502uXn4kwA5xNHigAAEkUAAJUoAgCoRBEAQCWKAAAqUQQA\nUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIA\nqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUNWxXQ8A56OT\nd9606xGe5FycCeB84kgRAECiCACgEkUAAJUoAgCoRBEAQOXqMzgrN7zy/p299vl0ldmnXvqGnb32\n9ffcsbPXBs5PjhQBACSKAAAqUQQAUG0ZRTPzypn5yMx8eGbeOjNfPzPXzcz7ZubkzNwzMxcf1LAA\nAIflrKNoZq6q/k51fK31rdVF1a3Vz1R3rrVuqH6/evlBDAoAcJi2ffvsWPXHZuZY9azqoepF1Ts2\nj99dff+WrwEAcOjOOorWWg9W/7D6XHsx9KXq16ovrrUe26x2qrpq2yEBAA7bNm+fPbu6pbqu+qbq\n0urmr+H5t8/MiZk58fiXHznbMQAADsQ2b5/9xeoza60vrLX+sHpn9R3V5Zu306qurh483ZPXWnet\ntY6vtY5fdNmlW4wBALC9baLoc9VNM/OsmZnqxdVHq/dUP7BZ57bqXduNCABw+LY5p+h97Z1Q/YHq\nQ5vvdVf1E9XfnZmT1TdUbzyAOQEADtVWv/tsrfXa6rVPWPzp6oXbfF8AgGeaT7QGAEgUAQBUoggA\noBJFAACVKAIAqEQRAEAli
gAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQA\nUIkiAIBKFAEAVHVs1wPA+ejknTfteoTzwvX33LHrEQCeNkeKAAASRQAAlSgCAKhEEQBAJYoAACpX\nn8GR8amXvmFnr32mq8zOxZkAzsSRIgCARBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgC\nAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQB\nAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUD2NKJqZN83MwzPz4X3LrpiZd8/MJzdfn71ZPjPz\nczNzcmY+ODMvOMzhAQAOytM5UvTm6uYnLHt1dd9a68bqvs39qpdUN27+3F69/mDGBAA4XMeeaoW1\n1ntn5tonLL6l+s7N7burX6l+YrP8F9Zaq7p/Zi6fmeeutR46qIGB07v+njt2PcKTnIszAZzJ2Z5T\n9Jx9ofP56jmb21dVD+xb79RmGQDAOW3rE603R4XW1/q8mbl9Zk7MzInHv/zItmMAAGzlbKPod2bm\nuVWbrw9vlj9YXbNvvas3y55krXXXWuv4Wuv4RZddepZjAAAcjLONonur2za3b6vetW/5D2+uQrup\n+pLziQCA88FTnmg9M29t76TqK2fmVPXa6qert8/My6vPVj+4Wf2Xq++pTlZ/UP2NQ5gZAODAPZ2r\nz152hodefJp1V/WKbYcCAHim+URrAIBEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIA\nqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEA\nVKIIAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAA\nKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAA\nlSgCAKhEEQBAJYoAACpRBABQiSIAgOppRNHMvGlmHp6ZD+9b9g9m5uMz88GZ+Vczc/m+x14zMydn\n5hMz892HNTgAwEF6OkeK3lzd/IRl766+da31Z6rfql5TNTPPr26tvmXznH8yMxcd2LQAAIfkKaNo\nrfXe6veesOzfrbUe29y9v7p6c/uW6m1rrUfXWp+pTlYvPMB5AQAOxUGcU/Q3q3+9uX1V9cC+x05t\nlj3JzNw+Mydm5sTjX37kAMYAADh7W0XRzPxU9Vj1lq/1uWutu9Zax9daxy+67NJtxgAA2Nqxs33i\nzPz16nurF6+11mbxg9U1+1a7erMMAOCcdlZHimbm5urHq+9ba/3BvofurW6dmUtm5rrqxur9248J\nAHC4nvJI0cy8tfrO6sqZOVW9tr2rzS6p3j0zVfevte5Ya31kZt5efbS9t9VesdZ6/LCGBwA4KE8Z\nRWutl51m8Rv/C+u/rnrdNkMBADzTfKI1AECiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBK\nFAEAVKIIAKASRQAA1dP43WcAXJhueOX9O3vtk3fedNrl5+JMHB2OFAEAJIoAACpRBABQiSIAgEoU\nAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWK\nAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoKpj\nux4AgHPTyTtv2vUIT3IuzsTR4UgRAECiCACgEkUAAJUoAgCoRBEAQCWKAAAql+QDcAY3vPL+nb32\n+XTp/ade+oadvfb199yxs9c+ihwpAgBIFAEAVKIIAKB6GlE0M2+amYdn5sOneexVM7Nm5sr
N/ZmZ\nn5uZkzPzwZl5wWEMDQBw0J7OkaI3Vzc/ceHMXFP9pepz+xa/pLpx8+f26vXbjwgAcPieMorWWu+t\nfu80D91Z/Xi19i27pfqFtef+6vKZee6BTAoAcIjO6pyimbmlenCt9ZtPeOiq6oF9909tlp3ue9w+\nMydm5sTjX37kbMYAADgwX/PnFM3Ms6qfbO+ts7O21rqruqvqkv/mmvUUqwMAHKqz+fDG66vrqt+c\nmaqrqw/MzAurB6tr9q179WYZAMA57Wt++2yt9aG11jeuta5da13b3ltkL1hrfb66t/rhzVVoN1Vf\nWms9dLAjAwAcvKdzSf5bq/+7+uaZOTUzL/8vrP7L1aerk9X/Wf1PBzIlAMAhe8q3z9ZaL3uKx6/d\nd3tVr9h+LACAZ5ZPtAYASBQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKII\nAKASRQAAlSgCAKhEEQBAJYoAACpRBABQiSIAgKqO7XoAAM5NJ++8adcjnBeuv+eOXY/AAXGkCAAg\nUQQAUIkiAIBKFAEAVKIIAKASRQAAlUvyATiPfOqlb9jZa5/p0vtzcSbOjiNFAACJIgCAShQBAFSi\nCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAJYoAACpR\nBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQBAFRPI4pm5k0z8/DM\nfPgJy39kZj4+Mx+Zmb+/b/lrZubkzHxiZr77MIYGADhox57GOm+u/vfqF766YGa+q7ql+ra11qMz\n842b5c+vbq2+pfqm6t/PzPPWWo8f9OAAAAfpKY8UrbXeW/3eExb/reqn11qPbtZ5eLP8lupta61H\n11qfqU5WLzzAeQEADsXTOVJ0Os+r/vzMvK76f6q/t9b61eqq6v59653aLAOArV1/zx27HuFJzsWZ\nODtnG0XHqiuqm6o/W719Zv7U1/INZub26vaqi5797LMcAwDgYJzt1WenqneuPe+v/qi6snqwumbf\neldvlj3JWuuutdbxtdbxiy679CzHAAA4GGcbRb9YfVfVzDyvurj63ere6taZuWRmrqturN5/EIMC\nABymp3z7bGbeWn1ndeXMnKpeW72petPmMv2vVLettVb1kZl5e/XR6rHqFa48AwDOB08ZRWutl53h\nob96hvVfV71um6EAAJ5pPtEaACBRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEE\nAFCJIgCAShQBAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAIBKFAEAVKIIAKASRQAAlSgC\nAKhEEQBAJYoAACpRBABQiSIAgEoUAQBUoggAoBJFAACVKAIAqEQRAEAligAAKlEEAFCJIgCAShQB\nAFSiCACgEkUAAJUoAgCoRBEAQCWKAAAqUQQAUIkiAICqZq216xmamS9Un62urH53x+PswoW43Rfi\nNpftvtBciNt9IW5z2e5z3X+71voTT7XSORFFXzUzJ9Zax3c9xzPtQtzuC3Gby3bveo5n2oW43Rfi\nNpft3vUcB8XbZwAAiSIAgOrci6K7dj3AjlyI230hbnPZ7gvNhbjdF+I2l+0+Es6pc4oAAHblXDtS\nBACwE+dEFM3MzTPziZk5OTOv3vU8h2VmrpmZ98zMR2fmIzPzo5vlV8zMu2fmk5uvz971rAdtZi6a\nmV+fmV/a3L9uZt632ef3zMzFu57xoM3M5TPzjpn5+Mx8bGb+3AWyr1+5+fv94Zl568x8/VHc3zPz\nppl5eGY+vG/Zaffv7Pm5zfZ/cGZesLvJt3OG7f4Hm7/nH5yZfzUzl+977DWb7f7EzHz3bqbe3um2\ne99jr5qZNTNXbu4f6f29Wf4jm33+kZn5+/uWn9f7e+d
RNDMXVf9H9ZLq+dXLZub5u53q0DxWvWqt\n9fzqpuoVm219dXXfWuvG6r7N/aPmR6uP7bv/M9Wda60bqt+vXr6TqQ7Xz1b/Zq31p6tva2/7j/S+\nnpmrqr9THV9rfWt1UXVrR3N/v7m6+QnLzrR/X1LduPlze/X6Z2jGw/Dmnrzd766+da31Z6rfql5T\ntfn5dmv1LZvn/JPNz/zz0Zt78nY3M9dUf6n63L7FR3p/z8x3VbdU37bW+pbqH26Wn/f7e+dRVL2w\nOrnW+vRa6yvV29r7H/vIWWs9tNb6wOb2f27vP5JXtbe9d29Wu7v6/t1MeDhm5urqL1c/v7k/1Yuq\nd2xWOYrb/Merv1C9sWqt9ZW11hc74vt641j1x2bmWPWs6qGO4P5ea723+r0nLD7T/r2l+oW15/7q\n8pl57jMz6cE63Xavtf7dWuuxzd37q6s3t2+p3rbWenSt9ZnqZHs/8887Z9jfVXdWP17tP0H3SO/v\n6m9VP73WenSzzsOb5ef9/j4Xouiq6oF9909tlh1pM3Nt9e3V+6rnrLUe2jz0+eo5OxrrsPzj9n5o\n/NHm/jdUX9z3Q/Qo7vPrqi9U/3TztuHPz8ylHfF9vdZ6sL3/1/i59mLoS9WvdfT391edaf9eSD/n\n/mb1rze3j/R2z8wt1YNrrd98wkNHerur51V/fvOW+H+cmT+7WX7eb/e5EEUXnJm5rPqX1Y+ttf7T\n/sfW3uWAR+aSwJn53urhtdav7XqWZ9ix6gXV69da31490hPeKjtq+7pqcw7NLe1F4TdVl3aatxwu\nBEdx/z6Vmfmp9k4TeMuuZzlsM/Os6ier/2XXs+zAseqK9k4D+Z+rt2/eATjvnQtR9GB1zb77V2+W\nHUkz83XtBdFb1lrv3Cz+na8eWt18ffhMzz8PfUf1fTPz2+29Nfqi9s61uXzz9kodzX1+qjq11nrf\n5v472ouko7yvq/5i9Zm11hfWWn9YvbO9vwNHfX9/1Zn275H/OTczf7363uqH1v//WS9Hebuvby/+\nf3Pz8+3q6gMz8yc72ttdez/f3rl5e/D97b0LcGVHYLvPhSj61erGzdUpF7d3kta9O57pUGxK+o3V\nx9Za/2jfQ/dWt21u31a965me7bCstV6z1rp6rXVte/v2P6y1fqh6T/UDm9WO1DZXrbU+Xz0wM9+8\nWfTi6qMd4X298bnqppl51ubv+1e3+0jv733OtH/vrX54c1XSTdWX9r3Ndt6bmZvbe4v8+9Zaf7Dv\noXurW2fmkpm5rr0Tj9+/ixkP2lrrQ2utb1xrXbv5+XaqesHm3/6R3t/VL1bfVTUzz6subu+Xwp7/\n+3uttfM/1fe0d8XCp6qf2vU8h7id/0N7h9M/WP3G5s/3tHeOzX3VJ6t/X12x61kPafu/s/qlze0/\n1d4/lpPVv6gu2fV8h7C9/111YrO/f7F69oWwr6v/rfp49eHqn1WXHMX9Xb21vfOm/rC9/yC+/Ez7\nt5r2rrL9VPWh9q7O2/k2HOB2n2zvXJKv/lx7w771f2qz3Z+oXrLr+Q9yu5/w+G9XV14g+/vi6p9v\n/o1/oHrRUdnfPtEaAKBz4+0zAICdE0UAAIkiAIBKFAEAVKIIAKASRQAAlSgCAKhEEQBAVf8vMwxh\njUiRM0EAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "v-RhEeUfFVHI", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## VALID padding, no stride, lhs=input dilation ~ Transposed Convolution" ] }, { "metadata": { "id": "B9Ail8ppFVHJ", "colab_type": "code", "outputId": "03d00b5a-ec38-435a-81f7-79d4737239c0", "colab": { "base_uri": "https://localhost:8080/", "height": 629 } }, "cell_type": "code", "source": [ "out = lax.conv_general_dilated(img, # lhs = image tensor\n", " kernel, # rhs = conv kernel tensor\n", " (1,1), # window strides\n", " 'SAME', # padding mode\n", " (2,2), # lhs/image dilation\n", " (1,1), # rhs/kernel dilation\n", " dn) # dimension_numbers = lhs, rhs, out dimension permutation\n", "print(\"out shape: \", out.shape, \"<-- larger than original!\")\n", "plt.figure(figsize=(10,10))\n", "print(\"First output channel:\")\n", "plt.imshow(onp.array(out)[0,:,:,0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "out shape: (1, 399, 395, 3) <-- larger than original!\n", "First output channel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGYRJREFUeJzt3H+s5XV95/HXuzOIDZgC1SUU2JVF\nNsY26UimlKZN42pskX/GJq7iJpU1JiO7mKhxN6L/1CY1sZu1bMzuQmikYtMWiNZIGvqDVZKmf4CO\nFpEf2l78EZgdYetvMEvD9L1/3Dd6Z5xh7sy9557p3Mcjubnf8znfc8/nfPhe8pzzPd9b3R0AAJKf\nWPYEAABOFsIIAGAIIwCAIYwAAIYwAgAYwggAYCwsjKrqiqr6clWtVNV1i3oeAIDNUov4O0ZVtSPJ\n3yV5dZLHknw2yRu7+6FNfzIAgE2yqHeMLkuy0t1f6e5/THJrkj0Lei4AgE2xc0E/9/wkj665/ViS\nXzzazjvOPKN3nnPOgqYCAGxnz3zrWzn45FO1nn0XFUbHVFV7k+xNkh1nn52fedc7ljUVAOAU9n8+\n+N/Xve+iTqXtT3LhmtsXzNgPdfdN3b27u3fvOPOMBU0DAGD9FhVGn01ySVVdVFXPS3JVkjsW9FwA\nAJtiIafSuvuZqnpbkr9MsiPJzd394CKeCwBgsyzsM0bdfWeSOxf18wEANpu/fA0AMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEA\nwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABD\nGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEE\nADCEEQDAEEYAAEMYAQCMnRt5cFV9Lcn3kxxM8kx3766qc5LcluTFSb6W5PXd/e2NTRMAYPE24x2j\nf9vdu7p799y+LsmnuvuSJJ+a2wAAJ71FnErbk+SW2b4lyWsX8BwAAJtuo2HUSf6qqj5XVXtn7Nzu\nPjDb30hy7pEeWFV7q2pfVe07+ORTG5wGAMDGbegzRkl+pbv3V9W/SHJXVX1p7Z3d3VXVR3pgd9+U\n5KYkOf1fXnjEfQAAttKG3jHq7v3z/Ykkn0hyWZLHq+q8JJnvT2x0kgAAW+GEw6iqzqiqFzy7neTX\nkjyQ5I4kV89uVyf55EYnCQCwFTZyKu3cJJ+oqmd/zh93919U1WeT3F5Vb0ny9SSv3/g0AQAW74TD\nqLu/kuTnjzD+zSSv2sikAACWwV++BgAYwggAYAgjAIAhjAAAhjACABgb/cvXHMVL3nnPlj7fyvWX\nb+nzAcCpyDtGAADDO0YLtlXv5Dzyhhtz8W3XbMlzAcCpyjtGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABj57IncKpauf7yJMkjb7hxyTMBANZLGC3Yxbdds+wp\nAADr5FQaAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMI4ZRlV1c1U9\nUVUPrBk7p6ruqqq/n+9nz3hV1YeqaqWq7q+qSxc5eQCAzbSed
4w+kuSKw8auS/Kp7r4kyafmdpK8\nJskl87U3yQ2bM00AgMU7Zhh1918n+dZhw3uS3DLbtyR57Zrxj/aqe5KcVVXnbdZkAQAW6UQ/Y3Ru\ndx+Y7W8kOXe2z0/y6Jr9HpuxH1NVe6tqX1XtO/jkUyc4DQCAzbPhD193dyfpE3jcTd29u7t37zjz\njI1OAwBgw040jB5/9hTZfH9ixvcnuXDNfhfMGADASe9Ew+iOJFfP9tVJPrlm/E1zddrlSb675pQb\nAMBJbeexdqiqP0nyiiQvrKrHkvxWkg8kub2q3pLk60leP7vfmeTKJCtJfpDkzQuYMwDAQhwzjLr7\njUe561VH2LeTXLvRSQEALIO/fA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCE\nEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAM\nYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQR\nAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAOOYYVRVN1fVE1X1wJqx91XV\n/qq6b76uXHPfe6pqpaq+XFW/vqiJAwBstvW8Y/SRJFccYfz67t41X3cmSVW9LMlVSX52HvO/qmrH\nZk0WAGCRjhlG3f3XSb61zp+3J8mt3f10d381yUqSyzYwPwCALbORzxi9rarun1NtZ8/Y+UkeXbPP\nYzP2Y6pqb1Xtq6p9B598agPTAADYHCcaRjckuTjJriQHknzweH9Ad9/U3bu7e/eOM884wWkAAGye\nEwqj7n68uw929z8l+f386HTZ/iQXrtn1ghkDADjpnVAYVdV5a27+RpJnr1i7I8lVVXV6VV2U5JIk\nn9nYFAEAtsbOY+1QVX+S5BVJXlhVjyX5rSSvqKpdSTrJ15K8NUm6+8Gquj3JQ0meSXJtdx9czNQB\nADbXMcOou994hOEPP8f+70/y/o1MCgBgGfzlawCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjC\nCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgLFz2RMA\nFu8l77xny55r5frLt+y5ADabMIJtYquC5ZE33Jgkufi2a7bk+QA2k1NpAABDGAEADGEEADCEEQDA\nEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMY\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwDhmGFXVhVV1d1U9VFUPVtXbZ/ycqrqrqv5+vp89\n41VVH6qqlaq6v6ouXfSLAADYDDvXsc8zSd7V3Z+vqhck+VxV3ZXkPyT5VHd/oKquS3JdkncneU2S\nS+brF5PcMN+BJVm5/vI88oYblz0NgJPeMd8x6u4D3f352f5+koeTnJ9kT5JbZrdbkrx2tvck+Wiv\nuifJWVV13qbPHABgk63nHaMfqqoXJ3l5knuTnNvdB+aubyQ5d7bPT/Lomoc9NmMH1oylqvYm2Zsk\nO84++zinDRyvi2+7ZtlTADjprfvD11V1ZpKPJ3lHd39v7X3d3Un6eJ64u2/q7t3dvXvHmWccz0MB\nABZiXWFUVadlNYr+qLv/dIYff/YU2Xx/Ysb3J7lwzcMvmDEAgJPaeq5KqyQfTvJwd//emrvuSHL1\nbF+d5JNrxt80V6ddnuS7a
065AQCctNbzGaNfTvKbSb5YVffN2HuTfCDJ7VX1liRfT/L6ue/OJFcm\nWUnygyRv3tQZAwAsyDHDqLv/Jkkd5e5XHWH/TnLtBucFALDl/OVrAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACG\nMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMII\nAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCA\nIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYw\nAgAYxwyjqrqwqu6uqoeq6sGqevuMv6+q9lfVffN15ZrHvKeqVqrqy1X164t8AQAAm2XnOvZ5Jsm7\nuvvzVfWCJJ+rqrvmvuu7+7+t3bmqXpbkqiQ/m+Rnkvzvqvo33X1wMycOALDZjvmOUXcf6O7Pz/b3\nkzyc5PzneMieJLd299Pd/dUkK0ku24zJAgAs0nF9xqiqXpzk5UnunaG3VdX9VXVzVZ09Y+cneXTN\nwx7LEUKqqvZW1b6q2nfwyaeOe+IAAJtt3WFUVWcm+XiSd3T395LckOTiJLuSHEjyweN54u6+qbt3\nd/fuHWeecTwPBQBYiHWFUVWdltUo+qPu/tMk6e7Hu/tgd/9Tkt/Pj06X7U9y4ZqHXzBjAAAntfVc\nlVZJPpzk4e7+vTXj563Z7TeSPDDbdyS5qqpOr6qLklyS5DObN2UAgMVYz1Vpv5zkN5N8sarum7H3\nJnljVe1K0km+luStSdLdD1bV7UkeyuoVbde6Ig0A+OfgmGHU3X+TpI5w153P8Zj3J3n/BuYFALDl\n/OVrAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABg\nCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYxwyjqnp+VX2mqr5QVQ9W1W/P+EVVdW9VrVTVbVX1vBk/\nfW6vzP0vXuxLAADYHOt5x+jpJK/s7p9PsivJFVV1eZLfTXJ9d78kybeTvGX2f0uSb8/49bMfAMBJ\nb+exdujuTvLk3DxtvjrJK5P8+xm/Jcn7ktyQZM9sJ8nHkvyPqqr5OQBscy955z1b9lwr11++Zc/F\nqeGYYZQkVbUjyeeSvCTJ/0zySJLvdPczs8tjSc6f7fOTPJok3f1MVX03yU8n+YdNnDcA/4xtVbA8\n8oYbc/Ft12zJc3FqWNeHr7v7YHfvSnJBksuSvHSjT1xVe6tqX1XtO/jkUxv9cQAAG3ZcV6V193eS\n3J3kl5KcVVXPvuN0QZL9s70/yYVJMvf/VJJvHuFn3dTdu7t7944zzzjB6QMAbJ71XJX2oqo6a7Z/\nMsmrkzyc1UB63ex2dZJPzvYdcztz/6d9vggA+OdgPZ8xOi/JLfM5o59Icnt3/1lVPZTk1qr6nSR/\nm+TDs/+Hk/xhVa0k+VaSqxYwbwCATbeeq9LuT/LyI4x/JaufNzp8/P8l+XebMjsAgC3kL18DAAxh\nBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAA
whBEA\nwBBGAABDGAEADGEEADCEEQDAEEYAAGPnsicAwPaycv3leeQNNy57GnBEwgiALXfxbdcsewpwRE6l\nAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQCMY4ZRVT2/qj5TVV+oqger\n6rdn/CNV9dWqum++ds14VdWHqmqlqu6vqksX/SIAADbDznXs83SSV3b3k1V1WpK/qao/n/v+S3d/\n7LD9X5Pkkvn6xSQ3zHcAgJPaMd8x6lVPzs3T5quf4yF7knx0HndPkrOq6ryNTxUAYLHW9RmjqtpR\nVfcleSLJXd1979z1/jlddn1VnT5j5yd5dM3DH5sxAICT2rrCqLsPdveuJBckuayqfi7Je5K8NMkv\nJDknybuP54mram9V7auqfQeffOo4pw0AsPmO66q07v5OkruTXNHdB+Z02dNJ/iDJZbPb/iQXrnnY\nBTN2+M+6qbt3d/fuHWeecWKzBwDYROu5Ku1FVXXWbP9kklcn+dKznxuqqkry2iQPzEPuSPKmuTrt\n8iTf7e4DC5k9AMAmWs9VaecluaWqdmQ1pG7v7j+rqk9X1YuSVJL7klwz+9+Z5MokK0l+kOTNmz9t\nAIDNd8ww6u77k7z8COOvPMr+neTajU8NAGBr+cvXAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgB\nAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAw\nhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBG\nAACjunvZc0hV/d8kTyX5h2XP5STywliPtazHoazHoazHoazHj7Mmh9pu6/GvuvtF69nxpAijJKmq\nfd29e9nzOFlYj0NZj0NZj0NZj0NZjx9nTQ5lPY7OqTQAgCGMAADGyRRGNy17AicZ63Eo63Eo63Eo\n63Eo6/HjrMmhrMdRnDSfMQIAWLaT6R0jAIClWnoYVdUVVfXlqlqpquuWPZ9lqaqvVdUXq+q+qto3\nY+dU1V1V9ffz/exlz3NRqurmqnqiqh5YM3bE11+rPjTHzP1VdenyZr4YR1mP91XV/jlG7quqK9fc\n955Zjy9X1a8vZ9aLU1UXVtXdVfVQVT1YVW+f8W15jDzHemzLY6Sqnl9Vn6mqL8x6/PaMX1RV987r\nvq2qnjfjp8/tlbn/xcuc/2Z7jvX4SFV9dc3xsWvGT+nfl+PW3Uv7SrIjySNJ/nWS5yX5QpKXLXNO\nS1yLryV54WFj/zXJdbN9XZLfXfY8F/j6fzXJpUkeONbrT3Jlkj9PUkkuT3Lvsue/RevxviT/+Qj7\nvmx+d05PctH8Tu1Y9mvY5PU4L8mls/2CJH83r3tbHiPPsR7b8hiZ/85nzvZpSe6d/+63J7lqxm9M\n8h9n+z8luXG2r0py27Jfwxatx0eSvO4I+5/Svy/H+7Xsd4wuS7LS3V/p7n9McmuSPUue08lkT5Jb\nZvuWJK9d4lwWqrv/Osm3Dhs+2uvfk+SjveqeJGdV1XlbM9OtcZT1OJo9SW7t7qe7+6tJVrL6u3XK\n6O4D3f352f5+koeTnJ9teow8x3oczSl9jMx/5yfn5mn
z1UlemeRjM3748fHscfOxJK+qqtqi6S7c\nc6zH0ZzSvy/Ha9lhdH6SR9fcfizP/ct9Kuskf1VVn6uqvTN2bncfmO1vJDl3OVNbmqO9/u183Lxt\n3uq+ec2p1W21HnPa4+VZ/Vfwtj9GDluPZJseI1W1o6ruS/JEkruy+q7Yd7r7mdll7Wv+4XrM/d9N\n8tNbO+PFOnw9uvvZ4+P9c3xcX1Wnz9gpf3wcj2WHET/yK919aZLXJLm2qn517Z29+n7ntr2EcLu/\n/nFDkouT7EpyIMkHlzudrVdVZyb5eJJ3dPf31t63HY+RI6zHtj1Guvtgd+9KckFW3w176ZKntFSH\nr0dV/VyS92R1XX4hyTlJ3r3EKZ60lh1G+5NcuOb2BTO27XT3/vn+RJJPZPUX+/Fn386c708sb4ZL\ncbTXvy2Pm+5+fP5n909Jfj8/OhWyLdajqk7LagT8UXf/6Qxv22PkSOux3Y+RJOnu7yS5O8kvZfWU\n0M65a+1r/uF6zP0/leSbWzzVLbFmPa6YU7Dd3U8n+YNsw+NjPZYdRp9NcslcOfC8rH4I7o4lz2nL\nVdUZVfWCZ7eT/FqSB7K6FlfPblcn+eRyZrg0R3v9dyR501xJcXmS7645nXLKOuyc/29k9RhJVtfj\nqrnS5qIklyT5zFbPb5Hm8x8fTvJwd//emru25TFytPXYrsdIVb2oqs6a7Z9M8uqsfu7q7iSvm90O\nPz6ePW5el+TT847jKeEo6/GlNf+IqKx+3mrt8XHK/r4cr53H3mVxuvuZqnpbkr/M6hVqN3f3g8uc\n05Kcm+QT89m/nUn+uLv/oqo+m+T2qnpLkq8nef0S57hQVfUnSV6R5IVV9ViS30rygRz59d+Z1aso\nVpL8IMmbt3zCC3aU9XjFXF7bWb2K8a1J0t0PVtXtSR5K8kySa7v74DLmvUC/nOQ3k3xxPjeRJO/N\n9j1GjrYeb9ymx8h5SW6pqh1Z/Qf/7d39Z1X1UJJbq+p3kvxtVmMy8/0Pq2olqxc5XLWMSS/Q0dbj\n01X1oqxefXZfkmtm/1P99+W4+MvXAABj2afSAABOGsIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAxv8H2emSG61MbfQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "A-9OagtrVDyV", "colab_type": "text" }, "cell_type": "markdown", "source": [ "We can use the last to, for instance, implement _transposed convolutions_:" ] }, { "metadata": { "id": "5EYIj77-NdHE", "colab_type": "code", "outputId": "d2e82a42-9c8e-4973-f760-511a14805527", "colab": { "base_uri": "https://localhost:8080/", "height": 629 } }, "cell_type": "code", "source": [ "# The following is equivalent to tensorflow:\n", "# N,H,W,C = img.shape\n", "# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))\n", "\n", "# transposed conv = 180deg kernel roation plus LHS dilation\n", "# rotate kernel 180deg:\n", "kernel_rot = np.rot90(np.rot90(kernel, axes=(0,1)), axes=(0,1))\n", "# need a custom output padding:\n", "padding = ((2, 1), (2, 1))\n", "out = lax.conv_general_dilated(img, # lhs = image tensor\n", " kernel_rot, # rhs = conv kernel tensor\n", " (1,1), # window strides\n", " padding, # padding mode\n", " (2,2), # lhs/image dilation\n", " (1,1), # rhs/kernel dilation\n", " dn) # dimension_numbers = lhs, rhs, out dimension permutation\n", "print(\"out shape: \", out.shape, \"<-- transposed_conv\")\n", "plt.figure(figsize=(10,10))\n", "print(\"First output channel:\")\n", "plt.imshow(onp.array(out)[0,:,:,0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "out shape: (1, 400, 396, 3) <-- transposed_conv\n", "First output channel:\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkYAAAJCCAYAAAAlTAh6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGXVJREFUeJzt3W2s5nV95/HPtzOIDWPKUF1Cgaws\nsDG2SUcypdO0aVyNFXkyNrGKD5QYk5FdTKrpbop9UpvUpN2ssjHZ1WCkYtMWiNZIDL1hlcT4ABQt\nIje1Pd4FZkfY1ls0S8P0uw/OFz3Dzs2ZOec61zjn9UpOzv/6Xf/rXL/rx/+Y91z/63+s7g4AAMlP\nLXsCAACnC2EEADCEEQDAEEYAAEMYAQAMYQQAMBYWRlV1VVV9uapWquqGRT0PAMBmqUX8HaOq2pHk\nH5K8IsljST6X5PXd/fCmPxkAwCZZ1DtGVyZZ6e6vdve/JLk1yf4FPRcAwKbYuaCfe2GSR9fcfizJ\nLx9r5x27zumd5523oKkAANvZ09/6Vg4/+YNaz76LCqMTqqoDSQ4kyY7du/Nzv/O2ZU0FADiD/e93\n//d177uoU2kHk1y85vZFM/Yj3X1Td+/t7r07dp2zoGkAAKzfosLoc0kur6pLquo5Sa5JcseCngsA\nYFMs5FRadz9dVW9N8jdJdiS5ubsfWsRzAQBsloV9xqi770xy56J+PgDAZvOXrwEAhjACABjCCABg\nCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIA\nGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYOzcyIOr6utJvp/kcJKnu3tvVZ2X5LYkL0zy9SSv7e5vb2yaAACLtxnvGP2H7t7T\n3Xvn9g1JPtndlyf55NwGADjtLeJU2v4kt8z2LUlevYDnAADYdBsNo07yt1X1+ao6MGPnd/eh2f5m\nkvM3+BwAAFtiQ58xSvJr3X2wqv5Nkruq6u/X3tndXVV9tAdOSB1Ikh27d29wGgAAG7ehd4y6++B8\nfyLJx5JcmeTxqrogSeb7E8d47E3dvbe79+7Ydc5GpgEAsClOOYyq6pyqet4z20l+I8mDSe5Icu3s\ndm2Sj290kgAAW2Ejp9LOT/Kxqnrm5/x5d/91VX0uye1V9eYk30jy2o1PEwBg8U45jLr7q0l+8Sjj\n/5zk5RuZFADAMvjL1wAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDA2Oj/iSzH8ZXXvX9L\nn+/S267b0ucDgDONMNoCWxUsl739nqzcuG9LngsAzkROpQEADGEEADCEEQDAEEYAAEMYAQAMYQQA\nMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEAjJ3LnsCZ7NLbrkuSXPb2e5Y8EwBgPYTRFli5cd+ypwAA\nrINTaQAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQCME4ZRVd1cVU9U1YNrxs6rqruq\n6h/n++4Zr6p6b1WtVNUDVXXFIicPALCZ1vOO0YeSXPWssRuSf
LK7L0/yybmdJK9Kcvl8HUjyvs2Z\nJgDA4p0wjLr700m+9azh/Ulume1bkrx6zfiHe9U9Sc6tqgs2a7IAAIt0qp8xOr+7D832N5OcP9sX\nJnl0zX6PzRgAwGlvwx++7u5O0if7uKo6UFX3VdV9h5/8wUanAQCwYacaRo8/c4psvj8x4weTXLxm\nv4tm7P/T3Td1997u3rtj1zmnOA0AgM1zqmF0R5JrZ/vaJB9fM/7GuTptX5LvrjnlBgBwWtt5oh2q\n6i+SvDTJ86vqsSS/n+SPktxeVW9O8o0kr53d70xydZKVJD9M8qYFzBkAYCFOGEbd/fpj3PXyo+zb\nSa7f6KQAAJbBX74GABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAh\njAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjAC\nABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABg\nCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGM\nAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIBxwjCqqpur6omqenDN2Dur6mBV3T9fV6+5\n7x1VtVJVX66qVy5q4gAAm2097xh9KMlVRxm/sbv3zNedSVJVL05yTZKfn8f8z6rasVmTBQBYpBOG\nUXd/Osm31vnz9ie5tbuf6u6vJVlJcuUG5gcAsGU28hmjt1bVA3OqbfeMXZjk0TX7PDZjAACnvVMN\no/cluTTJniSHkrz7ZH9AVR2oqvuq6r7DT/7gFKcBALB5TimMuvvx7j7c3f+a5AP58emyg0kuXrPr\nRTN2tJ9xU3fv7e69O3adcyrTAADYVKcURlV1wZqbv5nkmSvW7khyTVWdXVWXJLk8yWc3NkUAgK2x\n80Q7VNVfJHlpkudX1WNJfj/JS6tqT5JO8vUkb0mS7n6oqm5P8nCSp5Nc392HFzN1AIDNdcIw6u7X\nH2X4g8fZ/11J3rWRSQEALIO/fA0AMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCE\nEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAGPnsicA\nbI2vvO79W/Zcl9523ZY9F8BmEkawjWxVsFz29nuycuO+LXkugM3kVBoAwBBGAABDGAEADGEEADCE\nEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYA\nAEMYAQAMYQQAMIQRAMAQRgAA44RhVFUXV9XdVfVwVT1UVb894+dV1V1V9Y/zffeMV1W9t6pWquqB\nqrpi0S8CAGAzrOcdo6eT/E53vzjJviTXV9WLk9yQ5JPdfXmST87tJHlVksvn60CS9236rAEAFuCE\nYdTdh7r7C7P9/SSPJLkwyf4kt8xutyR59WzvT/LhXnVPknOr6oJNnzkAwCbbeTI7V9ULk7wkyb1J\nzu/uQ3PXN5OcP9sXJnl0zcMem7FDAZbm0tuuy2Vvv2fZ0wA4ra07jKpqV5KPJnlbd3+vqn50X3d3\nVfXJPHFVHcjqqbbs2L37ZB4KnKKVG/ctewoAp7V1XZVWVWdlNYr+rLv/coYff+YU2Xx/YsYPJrl4\nzcMvmrEjdPdN3b23u/fu2HXOqc4fAGDTrOeqtErywSSPdPd71tx1R5JrZ/vaJB9fM/7GuTptX5Lv\nrjnlBgBw2lrPqbRfTfKGJ
F+qqvtn7PeS/FGS26vqzUm+keS1c9+dSa5OspLkh0netKkzBgBYkBOG\nUXd/Jkkd4+6XH2X/TnL9BucFALDl/OVrAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY\nwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgj\nAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYJwyjqrq4qu6uqoer\n6qGq+u0Zf2dVHayq++fr6jWPeUdVrVTVl6vqlYt8AQAAm2XnOvZ5OsnvdPcXqup5ST5fVXfNfTd2\n939bu3NVvTjJNUl+PsnPJflfVfXvu/vwZk4cAGCznfAdo+4+1N1fmO3vJ3kkyYXHecj+JLd291Pd\n/bUkK0mu3IzJAgAs0kl9xqiqXpjkJUnunaG3VtUDVXVzVe2esQuTPLrmYY/l+CEFAHBaWHcYVdWu\nJB9N8rbu/l6S9yW5NMmeJIeSvPtknriqDlTVfVV13+Enf3AyDwUAWIh1hVFVnZXVKPqz7v7LJOnu\nx7v7cHf/a5IP5Menyw4muXjNwy+asSN0903dvbe79+7Ydc5GXgMAwKZYz1VpleSDSR7p7vesGb9g\nzW6/meTB2b4jyTVVdXZVXZLk8iSf3bwpAwAsxnquSvvVJG9I8qWqun/Gfi/J66tqT5JO8vUkb0mS\n7n6oqm5P8nBWr2i73hVpAMBPghOGUXd/Jkkd5a47j/OYdyV51wbmBQCw5fzlawCAIYwAAIYwAgAY\nwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgj\nAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAA\nhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjC\nCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMA\ngCGMAACGMAIAGCcMo6p6blV9tqq+WFUPVdUfzPglVXVvVa1U1W1V9ZwZP3tur8z9L1zsSwAA2Bzr\necfoqSQv6+5fTLInyVVVtS/JHye5sbsvS/LtJG+e/d+c5NszfuPsBwBw2jthGPWqJ+fmWfPVSV6W\n5CMzfkuSV8/2/rmduf/lVVWbNmMAgAXZuZ6dqmpHks8nuSzJ/0jylSTf6e6nZ5fHklw42xcmeTRJ\nuvvpqvpukp9N8k+bOG8AfkJ95XXv37LnuvS267bsuTgzrCuMuvtwkj1VdW6SjyV50UafuKoOJDmQ\nJDt2797ojwPgJ8hWBctlb78nSbJy474teT5+8p3UVWnd/Z0kdyf5lSTnVtUzYXVRkoOzfTDJxUky\n9/9Mkn8+ys+6qbv3dvfeHbvOOcXpAwBsnvVclfaCeacoVfXTSV6R5JGsBtJrZrdrk3x8tu+Y25n7\nP9XdvZmTBgBYhPWcSrsgyS3zOaOfSnJ7d3+iqh5OcmtV/WGSv0vywdn/g0n+tKpWknwryTULmDcA\nwKY7YRh19wNJXnKU8a8mufIo4/83yW9tyuwAALaQv3wNADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQ\nRgAAQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgA
AQxgB\nAAxhBAAwdi57AgBsL5fedl0ue/s9y54GHJV3jAAAhneMANhyKzfuW/YU4Ki8YwQAMIQRAMAQRgAA\nQxgBAAxhBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxh\nBAAwhBEAwBBGAABDGAEADGEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAAQxgBAAxhBAAwhBEA\nwBBGAABDGAEADGEEADCEEQDAEEYAAOOEYVRVz62qz1bVF6vqoar6gxn/UFV9rarun689M15V9d6q\nWqmqB6rqikW/CACAzbBzHfs8leRl3f1kVZ2V5DNV9Vdz33/p7o88a/9XJbl8vn45yfvmOwDAae2E\n7xj1qifn5lnz1cd5yP4kH57H3ZPk3Kq6YONTBQBYrHV9xqiqdlTV/UmeSHJXd987d71rTpfdWFVn\nz9iFSR5d8/DHZgwA4LS2rjDq7sPdvSfJRUmurKpfSPKOJC9K8ktJzkvyuyfzxFV1oKruq6r7Dj/5\ng5OcNgDA5jupq9K6+ztJ7k5yVXcfmtNlTyX5kyRXzm4Hk1y85mEXzdizf9ZN3b23u/fu2HXOqc0e\nAGATreeqtBdU1bmz/dNJXpHk75/53FBVVZJXJ3lwHnJHkjfO1Wn7kny3uw8tZPYAAJtoPVelXZDk\nlqrakdWQur27P1FVn6qqFySpJPcnuW72vzPJ1UlWkvwwyZs2f9oAAJvvhGHU3Q8keclRxl92jP07\nyfUbnxoAwNbyl68BAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAI\nIwCAIYwAAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwA\nAIYwAgAYwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY\nwggAYAgjAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAYwggAYAgj\nAIAhjAAAhjACABjCCABgCCMAgCGMAACGMAIAGMIIAGAIIwCAIYwAAIYwAgAY1d3LnkOq6v8k+UGS\nf1r2XE4jz4/1WMt6HMl6HMl6HMl6HMl6HGk7rse/7e4XrGfH0yKMkqSq7uvuvcuex+nCehzJehzJ\nehzJehzJehzJehzJehyfU2kAAEMYAQCM0ymMblr2BE4z1uNI1uNI1uNI1uNI1uNI1uNI1uM4TpvP\nGAEALNvp9I4RAMBSLT2MquqqqvpyVa1U1Q3Lns8yVNXXq+pLVXV/Vd03Y+dV1V1V9Y/zffey57ko\nVXVzVT1RVQ+uGTvq669V753j5YGqumJ5M1+MY6zHO6vq4Bwj91fV1Wvue8esx5er6pXLmfXiVNXF\nVXV3VT1cVQ9V1W/P+LY8Ro6zHtvyGKmq51bVZ6vqi7MefzDjl1TVvfO6b6uq58z42XN7Ze5/4TLn\nv9mOsx4fqqqvrTk+9sz4Gf37ckq6e2lfSXYk+UqSf5fkOUm+mOTFy5zTktbh60me/6yx/5rkhtm+\nIckfL3ueC3z9v57kiiQPnuj1J7k6yV8lqST7kty77Plv0Xq8M8l/Psq+L57fm7OTXDK/TzuW/Ro2\neT0uSHLFbD8vyT/M696Wx8hx1mNbHiPz33nXbJ+V5N757357kmtm/P1J/uNs/6ck75/ta5LctuzX\nsEXr8aEkrznK/mf078upfC37HaMrk6x091e7+1+S3Jpk/5LndLrYn+SW2b4lyauXOJeF6u5PJ/nW\ns4aP9fr3J/lwr7onyblVdcHWzHRrHGM9jmV/klu7+6nu/lqSlaz+Xp0xuvtQd39htr+f5JEkF2ab\nHiPHWY9jOaOPkfnv/OTcPGu+OsnLknxkxp99fDxz3Hw
kycurqrZougt3nPU4ljP69+VULDuMLkzy\n6Jrbj+X4v+Bnqk7yt1X1+ao6MGPnd/eh2f5mkvOXM7WlOdbr387HzFvnre6b15xa3VbrMac9XpLV\nfwVv+2PkWeuRbNNjpKp2VNX9SZ5IcldW3xX7Tnc/Pbusfc0/Wo+5/7tJfnZrZ7xYz16P7n7m+HjX\nHB83VtXZM3bGHx8na9lhxKpf6+4rkrwqyfVV9etr7+zV9zu37eWD2/31j/cluTTJniSHkrx7udPZ\nelW1K8lHk7ytu7+39r7teIwcZT227THS3Ye7e0+Si7L6btiLljylpXr2elTVLyR5R1bX5ZeSnJfk\nd5c4xdPassPoYJKL19y+aMa2le4+ON+fSPKxrP5iP/7M25nz/YnlzXApjvX6t+Ux092Pz//Y/WuS\nD+THp0K2xXpU1VlZjYA/6+6/nOFte4wcbT22+zGSJN39nSR3J/mVrJ4S2jl3rX3NP1qPuf9nkvzz\nFk91S6xZj6vmFGx391NJ/iTb8PhYr2WH0eeSXD5XDzwnqx+Eu2PJc9pSVXVOVT3vme0kv5Hkwayu\nw7Wz27VJPr6cGS7NsV7/HUneOFdS7Evy3TWnU85Yzzrn/5tZPUaS1fW4Zq60uSTJ5Uk+u9XzW6T5\n/McHkzzS3e9Zc9e2PEaOtR7b9RipqhdU1bmz/dNJXpHVz13dneQ1s9uzj49njpvXJPnUvON4RjjG\nevz9mn9EVFY/b7X2+Dhjf19Oxc4T77I43f10Vb01yd9k9Qq1m7v7oWXOaQnOT/Kx+ezfziR/3t1/\nXVWfS3J7Vb05yTeSvHaJc1yoqvqLJC9N8vyqeizJ7yf5oxz99d+Z1asoVpL8MMmbtnzCC3aM9Xjp\nXF7bWb2K8S1J0t0PVdXtSR5O8nSS67v78DLmvUC/muQNSb40n5tIkt/L9j1GjrUer9+mx8gFSW6p\nqh1Z/cf+7d39iap6OMmtVfWHSf4uqzGZ+f6nVbWS1YscrlnGpBfoWOvxqap6QVavPrs/yXWz/5n+\n+3LS/OVrAICx7FNpAACnDWEEADCEEQDAEEYAAEMYAQAMYQQAMIQRAMAQRgAA4/8Bvw2GDv/f1K0A\nAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "v8HsE-NCmUxx", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## 1D Convolutions" ] }, { "metadata": { "id": "WeP0rw0tm7HK", "colab_type": "text" }, "cell_type": "markdown", "source": [ "You aren't limited to 2D convolutions, a simple 1D demo is below:" ] }, { "metadata": { "id": "jJ-jcAn3cig-", "colab_type": "code", "outputId": "614ed589-e097-4bfe-f596-3421e3492698", "colab": { "base_uri": "https://localhost:8080/", "height": 680 } }, "cell_type": "code", "source": [ "# 1D kernel - WIO layout\n", "kernel = onp.array([[[1, 0, -1], [-1, 0, 1]], \n", " [[1, 1, 1], [-1, -1, -1]]], \n", " dtype=np.float32).transpose([2,1,0])\n", "# 1D data - NWC layout\n", "data = onp.zeros((1, 200, 2), dtype=np.float32)\n", "for i in range(2):\n", " for k in range(2):\n", " x = 35*i + 30 + 60*k\n", " data[0, x:x+30, k] = 1.0\n", "\n", "print(\"in shapes:\", data.shape, kernel.shape)\n", "\n", "plt.figure(figsize=(10,5))\n", "plt.plot(data[0]);\n", "dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n", " ('NWC', 'WIO', 'NWC'))\n", "print(dn)\n", "\n", "out = lax.conv_general_dilated(data, # lhs = image tensor\n", " kernel, # rhs = conv kernel tensor\n", " (1,), # window strides\n", " 'SAME', # padding mode\n", " (1,), # lhs/image dilation\n", " (1,), # rhs/kernel dilation\n", " dn) # dimension_numbers = lhs, rhs, out dimension permutation\n", "print(\"out shape: \", out.shape)\n", "plt.figure(figsize=(10,5))\n", "plt.plot(out[0]);" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "in shapes: (1, 200, 2) (3, 2, 2)\n", "ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))\n", "out shape: (1, 200, 2)\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlYAAAEyCAYAAAA4KJ7OAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XuwJVd13/Hf6jOMEA9J4BkI1owY\nYQscxXEMNQFSYMfYciIpseTEiUuqpPwIsSoVk9jlR0ouEkzh5A/sipMiUWyLMuVH2cjgVyZluXBs\nMKacCGsAAXogGAvZGllIgwQSDghJp1f+6NN9j65n5uy9u/c99571/VSpZubOGU2fvrdvr9m/vVab\nuwsAAADjNes+AAAAgE1BYQUAADARCisAAICJUFgBAABMhMIKAABgIhRWAAAAE6GwAgAAmAiFFQAA\nwEQorAAAACayb11/8YEDB/zIkSPr+usBAACSfehDH/qsux9c9bq1FVZHjhzR8ePH1/XXAwAAJDOz\nP0t5HVEgAADARCisAAAAJkJhBQAAMBEKKwAAgIlQWAEAAEyEwgoAAGAiFFYAAAATWVlYmdk7zOwh\nM7v9DL9vZvY2MzthZh8zs1dMf5gAAAC7X8qK1S9Iuvwsv3+FpEsW/10n6WfGHxYAAMDes3Lyurv/\nkZkdOctLrpb0S+7ukm4xswvM7EXu/sBEx4gdcvv9j+rj9z+67sPY9UzS677mBXrhec/M/rN/8fkv\n6f2fPDX9Qe0RX/H5j+mCxz6Z9NrDz3+WXnR+/jneCNZIL7tSevaB/D/7yKelT79/+mPaRBf9Heng\ny9Z9FNgwUzzS5kJJ9y39+uTiY3+lsDKz69Staumiiy6a4K/GlH7k3R/VJz7zhXUfxp7wva85oh//\ntr+R/ef+23s/pXf+yX2rX7ihPrD/B3W4iVtYZvnG+6RvfmP+n3vvT0i3/8b0x7OJXvI66bt+e91H\ngQ2zo88KdPcbJd0oSUePHvWd/Lux2uNPzvWtl75QP3H11677UHa1f/C2D+jxJ9uiP/v4k62+8vxn\n6jf/9WsmPqq94eDPSV+66Dv0hW/4D2d93X/6nbv02b/8sn7lX75qh45sl3nby6WnvlT2Z598XDrw\nMgqGVd79PdJTj6/7KLCBpiis7pd0eOnXhxYfwx7TuvScc/bpr0WNXxLtm5natuzfBa279u9rAp/j\nVuc+5wKde+HFZ33Vl899RKf+319K533lDh3XLtPsk9qy4l0+l/adE/fcpXrGudITX1z3UWADTTFu\n4Zik71p0B75a0qPsr9qb5q3LbN1HsfvNzNR6WWE1b11N5JPsrdTMVr5s1pgKa9fNYLPuXJVIPMfh\njTnHwFmsXLEys3dK+iZJB8zspKQfl/QMSXL3n5V0s6QrJZ2Q9EVJ31vrYFFX665Z5Jt+IjPTvLCw\nat3VNIHPcdt2G7NXMFPxquBGMOtWnkq086RzHJ415ecYOIuUrsBrV/y+S/r+yY4Ia9O6axb5pp9o\n1pgK6yq1rWIXr952KwUrdCtWgQurZuSKVcI5Dm/MOQbOgn/WYDBvu9UYnF1jXaRXYu7B41afK+UE\nNCNWBTeCNd3KUwlnxSqJNeX72ICz4OrDoFuxWvdR7H5NMyIKbIOvCrbzpP0/jVnse57NRkSB7LFK\nQhSISriNYsAeqzQzM/mIPVahC6vkKFBEgaOiQL61r0QUiEq4+jDougID3/QTNWYjosDgcWtiTDXm\nHG+EMTEVUWCaMXErcBZcfRiEj6kSNY1pXnjPa1vXLOopdk8eBdBE37w+JqZKjFvDGxO3AmdBYYVB\n66KwSjBrRBRYoj9nKVGgBZ9jRVdgfUSBqITCCoPwHWuJxnSshY5b+9WBpCiwvPNyI9AVWB9dgaiE\nqw+DLqYKetPPMGb/T+gGgb5QaBIKq6b8sUEbga7A+ogCUQmFF
QahY6oMowaERo5b+9glOQoMXFjR\nFVhf0xAFogquPkjq9gy10TvWEo0aEBr5eYw5UeCIWWEbga7A+ugKRCVcfZC0ta84bEyVYcweK4+8\nKjhEgQwIXYmuwPqIAlEJhRUkaSgUot7zc3RRYPkjbZqoxWtOFMiAULoCa6MrEJVQWEHSVrTVUFmt\nNGpAaCsKq9QBoZELK7oC66MrEJVw9UHSUhRIYbVSN7yy7M965Ocx9oVVSlegdQ0CpSuDe56NXLEi\nClxtzDkGziLqt3hsQxSYrrHymGreBo4C25w5Vt05CjtxYdQeK7oCk5ixxwpVcPVB0lIUGPWmn2E2\n6lmBHjduHboC0/ZYSYGHhDYzugJra2Z0BaIKrj5I2opciAJXGxcFBu68HKLAtGcFSoE3sNuIGUtE\ngWmIAlEJhRUksWKVozEVTwXvosCJD2ivKIoCIxdWI8YtsGK12phzDJwFVx8kLe2xCnvXTzcbMbxy\n3kaOAvMmr0vRo8AxXYGsWK3Uj1uIWryjGgorSGJAaI5mxONWPPKzAouiwJoHtIvRFVhfX3xSWGFi\nFFaQtBwFrvlA9oBuKjgDQrMNUeDq999/HYZ9EDNdgfX154g4EBPj6oOkrb0sYWOqDGOiwNYDn+Os\nrsBFFBh1NYGuwPr6eWp0BmJiXH2QtPU9POxqSoYxz7FrI29ez4kC2bxOFFjbEAXSGYhpUVhB0tbK\nQNip4BlGDQjlIcx5XYFR73l0BdZHFIhKuPogaSkKZMVqpVlTPiC0jTx5vWRAaNQVK7oC6+tX9YgC\nMTEKK0ja2iQc9qafwax8QGjrgc9xXyQlrKbYsGIVtLAaGwWyYrXasGIVdVkUtXD1QdJyFBj0pp9h\n1ox7VmDYuLVfGUh4CPMs/B6rWVlE1abvYwuPPVaoJOq3eGzD5vV0sxFzrNrQzwrMGBAafY5VUzjH\nKuMch9ewYoU6KKwgaXmP1ZoPZA+wEQ9hbiPPsfL0zev9KQo7ed2asp37nj4rLDxj3ALqoLCCpK0b\nGFHgarNmxIDQNvDk9SEKzFmxClxYFUWB6ec4vCEKpLDCtCisIIkBoTlmzcjN61HPccGzAsMWVkSB\n9TXssUIdFFaQxLiFHGZlYwDa6I8NyooCgz+E2ZqyiCrjHIdHFIhKuPogaWuTcNiYKsOs8FmBffEa\n9hxndKwNUWDUxYTirkCiwGR0BaISCitI4iHMObooML+wmkePW4eYKmHcQr+YQBSYZ5gVRmG1ElEg\nKqGwgqSlmCrqTT9DPyDUM2/64UdalESBUQsrugLr688RUSAmRmEFSUtRIIXVSlsbq/P+XBv9eYw5\nXYFMXicKrI2uQFQS9Vs8tpkzxypZaUw1j94gwIDQdHQF1kcUiEoorCCJZwXmKO1YC3+OGRCajq7A\n+ugKRCVcfZC0HFMFveln6M9R7vaf8HFrf8IyosDcfWwbw0auWBEFrkZXICqhsIKk5a7AoDf9DH1d\nlLuxOnznZZu+mtI3UYTevD5mjxUrVqv154jCChNLuvrM7HIzu9vMTpjZ9af5/YvM7H1m9hEz+5iZ\nXTn9oaImBoSma0qjwPDjFjIKq+gDQvsVp9zOQPZYpWuIAlHHyu9wZjaTdIOkKyRdKulaM7t028v+\nvaR3ufvLJV0j6X9MfaCoK3xMlWErCiwrrMIOCM2IqUrj1o1RGlMN55gVq5WIAlFJytX3Skkn3P0e\nd39C0k2Srt72Gpd03uLn50v6i+kOETshfEyVoXQ1JXzcmhMFht+8vjgBuXEgUWC6IQpkxQrT2pfw\nmgsl3bf065OSXrXtNW+W9Htm9m8kPVvSZZMcHXZM+JgqQ+n+n2FAaNRzPESBq1esmugDQocoMPOm\nn3GOwys9x8AKU/2z5lpJv+DuhyRdKemXzf7qP5nM7DozO25mx0+dOjXRX40phI+pMmx1rOX9ufAD\nQouiwKCF1egokMJqJaJAV
JLyLf5+SYeXfn1o8bFlr5f0Lkly9/8r6ZmSDmz/H7n7je5+1N2PHjx4\nsOyIUcU8+uNWMpTGVOEHhLbpzwrciltrHtAuVhpTEQWmIwpEJSlX362SLjGzi81sv7rN6ce2vebP\nJX2LJJnZX1dXWLEktYdsRYFrPpA9YIgCMwsrj15YZXQF9qt6RIG5USBdgclKOy+BFVZ+h3P3pyS9\nQdJ7JN2lrvvvDjN7i5ldtXjZD0v6PjP7qKR3SvoeD7uGvzeFnwqeoTQKDL8qmBFTNQwI7X7Mff90\nBaZjjhUqSdm8Lne/WdLN2z72pqWf3ynpNdMeGnbSnMnryZrC1ZR+hSvsHqusrsDgc6zoCqyPKBCV\ncPVB0tYcq7CrKRlGDwiNeo4zOtZmhXHrxqArsD66AlEJhRUkLUeBaz6QPWDsgNC4hVVGFMiA0O5H\nugLroSsQlVBYQdJyTBX0pp+hdMZS+HOc1RXY/Rh28zpdgfURBaISrj5IYkBojr6wym0mGuLWqOc4\no2OtbxBooxZWzcgVK6LA1ZrCBgFgBQorSCKmytHXRbk3/a1zPPUR7RHD/p/VJ8CG4jXoTa9fTSke\nt8C39pVKzzGwAlcfJG2NAmDy+mqlG6uHKDDqOW7n3c0s4f2H37xeuv+nLxIYt7AaUSAq4eqDJAaE\n5uijvOIVq6hLVt4mR1RbUWDNA9rFmsIZS0SB6UrjVmAFbqOQxIDQHE3h/p82/IDQeXJENaQ0Ufe/\nFEeBbF5PRhSISrj6IGlrZSBsTJVhVvgcu/APYW7nyWMAZuEHhParKYVdgYxbWK30HAMrRP0Wj236\ntnbqqtWawtWUrXMc9CS7p0eBTfQokK7A6ogCUQmFFSR1UWBjgW/6GZrCjrU2+ub1nCiwsPNyY9AV\nWN9wjimsMC2uPkjqbmBhB1dmGjrWsjevP/3Ph9POk7sjiALHdgWyYrUSXYGohMIKkroigdWqNE1h\nx1pfJIQ9zTldgYWdlxtjdFcg39pXIgpEJVx9kNTFVGEjqkzDgNDChzCHXbHKigIZECqJrsCa6ApE\nJVx9kNStvoS94WcqHV45FFZRC9iMrkCpO89xnxVIV2B1dAWiEgorSOqKhKj3+1ylc6y2osCgJzqj\nK1DqCtCoC1Z0Be4AokBUQmEFSWxez1E8IJQoMCuiMiMKJAqsiK5AVMLVB0mLwirqSkqm0hlLbfTn\nMXqb9cykWWNxN6+XdgX254socLXScwysQGEFSd0U8bARVaZ+wSn7IczRh7C2eStWjVn2dPuNYYVd\ngS0rVsn6C5E9VpgYVx8kSe4e91ErmUofwuxEgVl7fxpj3AJRYEVm3XmiKxAT4+qDpG71JezDgTPN\nijevdz+GPc/eZncFhi2siqPA/ouMKDCJzYgCMTkKK0jqYqqwN/xMTeFDmPsoMGOb0WYpigKjFlaF\nU8GJAvNYQxSIyXH1QVK35zVsRJVpSGkyb/oefY5VxuR1qYtcw65Y9StOxVEgK1ZJmhlRICZHYQVJ\nfRS47qPYG0oft9KvvoRdGcztCjSL2wlPFLgzbLbVSQlMhMIKkhZRIJVVkiEKLC2sop7n7Cgw/xxv\njOKuQJ4VmIUoEBVw9UHSoisw6kpKpqbwOXZ9jRA2cs3tCmws7oDQ0VEg39qTNHQFYnpcfZBEV2CO\n0gGhw+b1qKeZrsB0pStW3nZ/lms5DV2BqIDCCpK6DrewEVWm4gGh0fdYlXQFBq2rRnUFslqVjigQ\nFXAFQhIDQnMwILRQbldg5GcFjokC6QhMR1cgKuBWCknMscrBgNBCPCsw3ZiuQDoC09EViAoorCCp\n2y8U9oafqXRAaBt9jxUDQtMVR4EtUWAOokBUwBUISV3kEvaGn2kYEJr5L93WvXs8WdQCNjsKtOwG\ngY3RrzrlrqZknuPwmobN65gchRUkdRurw+79yVQ6bmHeBh9p4ZkrVk3ghzD356loj1Xgr7F
cPIQZ\nFVBYQVJ3AyMKTDMrHBDaevDOy3aeN26BKLCsK5A9VulsRhSIyVFYQRKFVY6mcI5Vd44rHNBe4c6z\nAlM1IzavEwWma5hjhelRWEESUWCuklEARIF5MVW3xypoYTUqCuTbejKiQFTAFQhJxFS5Zo0VRIHB\nn8dIFJhuGLdQ0BVIFJiOyeuogMIKkoipcpWsprTRHxuU2xXY5MetG4MocGfQFYgKKKwgiZgqV2P5\nDwiee/C4NbcrsOAcb4whCswtrOgKzEIUiAoorCCJKDDXrLGCAaHBh7DmRoEFcevGoCtwZ9AViAoo\nrCCJAaG5GisYEBr9HOd2BUYeEGq2mApOFFgVXYGoIKmwMrPLzexuMzthZtef4TXfaWZ3mtkdZvar\n0x4magsfU2UqGQUQvvMyOwoM/BBmqSymoiswD1EgKti36gVmNpN0g6RvlXRS0q1mdszd71x6zSWS\nfkzSa9z9c2b2gloHjDqYY5WnpGONKHCe/RDmsF2BUllMRRSYh65AVJDyXe6Vkk64+z3u/oSkmyRd\nve013yfpBnf/nCS5+0PTHiZqC9+xlqlbscr7M924hTrHsycUPSswcGFVElNlxq3h0RWIClK+zV8o\n6b6lX59cfGzZSyW91Mz+2MxuMbPLT/c/MrPrzOy4mR0/depU2RGjitYVO6bKVBJTtR6887KkKzBy\nYWUNXYG1EQWigqn+/bxP0iWSvknStZLebmYXbH+Ru9/o7kfd/ejBgwcn+qsxhXnrfD/OMCu46c+j\nrwp63vDKWcGq4EYpiakyz3F4RIGoIKWwul/S4aVfH1p8bNlJScfc/Ul3/7SkT6ortLBHhF9NyWTG\n5PVsbZu1YmXhN69b2R4rNq+ns4ZxC5hcyhV4q6RLzOxiM9sv6RpJx7a95rfVrVbJzA6oiwbvmfA4\nUVlLV2CWWZM/vLJtFbt49XnW/p/Qc6ykbuWpqCuQFatkJecYWGFlYeXuT0l6g6T3SLpL0rvc/Q4z\ne4uZXbV42XskPWxmd0p6n6QfdfeHax00pjdvu1UYpCmJqeYePG71Nq8rMPweK6LA6mzWbfgHJrRy\n3IIkufvNkm7e9rE3Lf3cJf3Q4j/sQd2K1bqPYu8wU34UGH2OVWZMZWbZe7c3SklMlRm3hlcStwIr\ncAVCEnuscs0KnmMXPm7NjgIVe45VMyvsCuTbejKiQFTAFQhJfVdg4Jt+plnJ5HUPHrcWdQUGLqyI\nAuujKxAVUFhBEjFVLrOChzC3rlnUU+y+GBCaGQWGLqzoCqyOrkBUwBUISQwIzTVrCh7CHDkK7FcF\ncqLAgscGbRS6AusjCkQFFFaQRMdartIBoWGjwL6wynxWYOS6iihwB9AViAoorCBJcjavZ7GC1RT3\nwHOs+lUBBoSmoyuwPqJAVMAVCEk8biVX2eb1wA9h7m9euVFg5NWE4igw6hdZgYZnBWJ6XIGQu6t1\nxX7cSqZZwYyl0MXrEAXSFZisJKYiCsxDVyAqoLDC8L07bExVoGRAqEfevF4UBUYfEEpXYHVEgaiA\nKxBDgRD1nl9i1pi8JAqMWryWdAUWdF5ulKZw8zpdgelKzjGwAoUVhk3YRIHpmoLN6/NWFFYZqylN\n9D1WVrD/hz1WeazJn24PrMAViK0okMIqWdOY5rnbXyI/j7EvEDJ27zdm3VzRqMWVzcq6Atljla7k\nHAMrRP02jyVEgflmln/DZ/O6MqPA7lyFnbhAFFhf0xAFYnIUVtiKAqPe9AsURYHuceNWz9+83p+q\nsNPXS2Iqn4tJvxlK4lZgBQorDCsvRIHpmoYBoVmGKDB9NaUZVqwCF1YlXYFEgemIAlEBhRVYsSow\nW+z/ydFFgXWOZ9crfFagFLiwIgqsj65AVEBhhWEPS9iYqkDT5M+xakNHgWVdgVL0KJCuwKpssccq\navGOKrgCMawIRL3nl2jMsp9j10bevD4mCoy6oEBXYH3
96h6rVpgQhRWGFYGw+38KlD4rMOw5Llix\nmi1OFVFgBqLAPP34DworTIjCClsrVixZJSsZXhn6eYwlXYGLcxV2SChdgfX1X490BmJCFFYYvneH\njakKNAXPsWsjb14viQL7zeuR91jRFVjXEAVSWGE6FFYYVgTCTgUvUPIcu3nkhzAzIDQfUWB9DXus\nMD1upVjavB70pl+gZEBo6M3rYwaEho4C6QqsiigQFXAFYohawt70CzSNZa+ktB74HPfFEVFgOitc\nsSIKTEdXICqgsMJSFBj0pl+gsYIosOUhzDkbq5voA0Jz91i1+Z2X4RldgZgeVyDYvF5gVhIFhh4Q\n2hdW+Xuswg4IbWZ5XYEF5zi8higQ06OwAgNCCzQFc6xaj7zHqq/eS54VWOOA9oDcKHA4x3xbT0YU\niAq4ArE1IJTKKlnJ5PV5G3hAaFu+eT1uFGiZUWD+OQ5viAJZscJ0uALBgNACs4YBoVlKosDozwps\nZnkRFVFgvn4FlSgQE6KwAuMWCjSW1xW41XlZ6YB2u1FRYNDCqjgKpLBKRhSICiisMBQIYWOqAo3l\njQHoi4Ow57igY21r3EKNA9oDsrsCiQKz0RWICrgCMUQtYVdTCuRGgfPocWvBgNB+NEXYAaHZXYH5\n0+3DIwpEBRRW2Iqpot70CzRmcpc88aYffqRFSRQYfo4VXYHVsWKFCrgCsRUFUlgl27rpp72+jf48\nxqKuwOiT1+kKrI6uQFTAFYitmIq6KtkQUyXe9OfRGwQYEJqPrsD6iAJRAYUVeFZggdyOtfDneFQU\nWOOA9gC6AuujKxAVUFhhKaYKetMvkLv/J3zcWtQVuPijYfdY0RVYHVEgKuAKxFJXYNCbfoFZ5mpK\n+M5Lzy+sZtHnWDWFK1ZEgen61b2oX2OogsIKDAgt0J+q1P0/4afbF4xbsOiT1/tzlTpyoaB4Da+/\nkNljhQlxBYKYqsCwmpJZWMUdELq4cWXs/wm/YjXs/0m86Q/nmG/ryXLPMZAg6Qo0s8vN7G4zO2Fm\n15/ldd9hZm5mR6c7RNQWPqYqkHvTDx+3FsRUs+iT15vMGUtEgfkaNq9jeisLKzObSbpB0hWSLpV0\nrZldeprXPVfSD0j64NQHibrCx1QFhpgqsbDqXxb2HBdFgd2PYSevD1Fg4mpKwTkOL/ccAwlSrsBX\nSjrh7ve4+xOSbpJ09Wle9xOS3irp8QmPDzuAPVb5cldTwq8KjokCw+6xKo0CWbFKRhSIClIKqwsl\n3bf065OLjw3M7BWSDrv770x4bNgh80VxEHb/T4F+QGhyFBh9pEV/ngoGhEatq7JjKqLAfESBqGD0\nmrGZNZJ+WtIPJ7z2OjM7bmbHT506NfavxkS2osA1H8gektux5tFXBYeYKv39N0SB3Y/JUSBdgdly\nOy+BBClX4P2SDi/9+tDiY73nSvpaSX9oZvdKerWkY6fbwO7uN7r7UXc/evDgwfKjxqTCTwUvMMsc\nEDqP/hDmgpiKZwVmrqbQFZiPAaGoIOUKvFXSJWZ2sZntl3SNpGP9b7r7o+5+wN2PuPsRSbdIusrd\nj1c5YkwufExVIDem6le2wj6EuaQrMPq4BboC6yMKRAUrv827+1OS3iDpPZLukvQud7/DzN5iZlfV\nPkDU1xcHYVdTChQPCI16jgs61hoGhHY/0hVYD12BqGBfyovc/WZJN2/72JvO8NpvGn9Y2Elt9I61\nArmrKeELq5IoMPqKFV2B9dEViAr4pw2WYqqgN/0C+Xusgp/jkq7AzOcxbhy6AusjCkQFFFZgQGiB\n3K7AIW6Neo6LosDuR6JAosBq6ApEBVyBIKYq0K88paZUW+e41hHtcv2KQEbHWjOc46iFVe6KVV+9\ns2KVzDIbBIAEFFbYeggzhVWy3NWUfh9b2HPczrNXUti8nnnTb/NnhYXHuAVUQGGF4cbF9+N0/WpK\n6vDK/nUW9ST7PHv
vz2x4HmONA9oDmtIokBWrZP3qHl2BmBCFFbZWU8LmVPn6m35qTNVv4Qh7jr3N\njqiGMU5EgWmvH+JWCqtkuecYSEBhBaLAAlsxVdrr22EIa60j2uWIAvPlxlQtm9ezEQWiAq5ALMVU\naz6QPaRfTUm96RMFtvlRYGbcunFyYyqiwHxEgaiAwgpqW1djgW/6BfKjwOCb173NfoZdY3mdlxuH\nrsD6hnMc9YsMNVBYQa173L0/hXI3rw9xa9TzXBQFdj8SBeZ2BfJtPVn/Dx2iQEyIKxCau7NalSl3\n/0/4zsuSrsAm+B6r4q5Avq0nIwpEBVyBUNt63IiqUOmA0LArVgVdgWYmM7oC6QqsiK5AVEBhBbUe\n+IZfKHtAqAffY1UQBUrdymDYzet0BdZHVyAq4AqE5q3HjagKDVFg5kOYw0auBV2BUleIpo602Dh0\nBdZHFIgKKKwgZ/N6tlnmc+w8+ub1gq5AqfsjRIFEgdXQFYgKKKyguTsPYM6UOyC0X7GKWleNigKj\nbl7PjgLbp/85rEYUiAq4AqF5KwqrTP0E9TbzWYFhz/OIKDBqXTWsPKWupjiFVbYmc6QFkIArEIso\ncN1Hsbf0e6VSCysP3xVYtmJlln6ON05fhDNuoS5r2GOFSXEFQvOWKDDXLHuOVfdj2PPczov2/sya\nyFFgv2KV2RXIHqs8NiMKxKQorMAeqwL9ylPqPX+IAqNecaVRYGNxV6yaws3rdAXmaWZEgZhU1G/z\nWOLMsco2pDSJlZVHn2PlbWEUGLiwMiav7wiiQEyMKxCLKHDdR7G3zDKfFbjVFRj0RLfzouW6Weiu\nwNwokHELRYwVK0yLwgpdFEhllWWWuXl9KKyinudRUWCF49kLirsCKayyNA2FFSZFYYWuKzDqSkqh\noSswOQrsfgwbuY7pCoxaWRVHgUG/xkoRBWJiFFagK7DAEAWmdgU6A0KLuwKj77HK6Qq0hsIqF12B\nmBiFFboBoWHv+GW2osC014ffY8WA0HwlXYHEgPnoCsTEKKzAgNACljl5nQGhpV2BRIFZUSAdgfmI\nAjExrkIwx6oAA0IzlXYFMiA0LwqkIzAfXYGYGIUV1HrgG36h3AGhbfQ9VoUxVRN5jlV2FOhEgSXo\nCsTEKKygljlW2YYBoYk3/dZdZlvdhOEUxlShC6shCkwtrIgCixAFYmJchdC89bh7fwrNMsctzNvg\nIy28Le4KjJoEbnUFZmxeD/vMpBGIAjExrkKoZY9Vtv58pY4CaD1452VbumKVvo9t45SOW0Aeaxi3\ngElxFYLCqkBfJKWuWHXnuOYR7XKle6x4CHNmVyB7rLI1M6JATIrCCkSBhXJiKqLAsphqFnqPVcEc\nK7oC8xEFYmIUViCmKtRYThQ7JbU6AAARdElEQVQY/HmMxVFg5HELuVFg2ayw8IyuQEyLqxDEVIUa\ns/QoMPpjgwpjqqZJb4rbOEMUmNMVyIpVtoauQEyLwgrEVIVmGft/5h48bh3VFRh1xaokCuRbejai\nQEyMqxBEgYW6mCrtteGHsI6JAsMWVouvF7oC66IrEBPjKgQDQgs1ljEgNPo5HjN5PeweK8sbXkkU\nWIauQEyMwgpqo8dUhXJiqvDnuLQrMPKAUCkvpqIrsAxRICZGYQXN3eM+amWEnI61eUsUyIDQAjkx\nFVFgGboCMbGkq9DMLjezu83shJldf5rf/yEzu9PMPmZmf2BmL57+UFFLy+b1IjnDK7txC5UPaDcr\n7QqMPMdKyoupCuPW8OgKxMRWfqs3s5mkGyRdIelSSdea2aXbXvYRSUfd/esk/bqkn5z6QFFP64od\nUxWamSV3wrcevHilK7CMzaTU909XYBmiQEws5Sp8paQT7n6Puz8h6SZJVy+/wN3f5+5fXPzyFkmH\npj1M1DRvXZHv+aVyBoTOo8+xYkBoGaLA+ugKxMRSrsILJd239OuTi4+dyesl/e7pf
sPMrjOz42Z2\n/NSpU+lHiarCr6YUyo8CA59j9+JnBUZesFKTsf+HKLBMw4oVpjXpP2/M7J9LOirpp073++5+o7sf\ndfejBw8enPKvxgjhO9YK5U1eV+zi1UdsXo9cWWWPW2DFKlvOOQYS7Et4zf2SDi/9+tDiY09jZpdJ\neqOkv+vuX57m8LAT5q3oCiwwa0zz1Icwe/C4tZ0XP4Q5dhQ4y4sCGbeQjz1WmFjKd7pbJV1iZheb\n2X5J10g6tvwCM3u5pJ+TdJW7PzT9YaKmbsVq3Uex9+QOCA29Klg6IDR8FJgzx6osbg0vJ24FEqy8\nnbr7U5LeIOk9ku6S9C53v8PM3mJmVy1e9lOSniPp3WZ2m5kdO8P/DrsQe6zKZEWB0ePWMVFg6BWr\njKdQ+1yxl0ULEQViYilRoNz9Zkk3b/vYm5Z+ftnEx4Ud1HUF8g0516zJGBDqwePWwpiqi1sjF1aZ\nUeC+c+oezybKOcdAAgIgyJljVaQbXpn2WnfXLOopdpdU2BVoJo9cWNEVWB9dgZgYhRUWM5bWfRR7\nT9Ok77EKPceqv2kxxyofXYH1EQViYlyF0Dz6jKVCOR1r8zbwOe5vWoUPYY5dWNEVWB1dgZgYhRUW\nMVXQm/4IOQNC3QPPsRpWrEqjwImPZy+hK7A+ugIxMQorxI6pRsh5QPA88kOY+xUXBoTmy44CuY6z\nEQViYlG/1WPB3dW64sZUI2RHgVFvekMUWNgVGD4KTH3SN1FgEboCMTEKq+D6xYCwMdUI3eb1tNd6\n5DlWY6LA8ANC6Qqsjq5ATIzCKrg+Zol6zx8jZ0Do3AOvWI3qCiQKpCuwMqJATIyrMLg+ZiEKzDfL\n2Lw+b0VhVRIFZuxj20g5UaC3RIElbKbYy6KYGoVVcEMUSGGVzSz9Icwe+XmM/WpAQWFpi67AsENC\nrckYt9CyYlUi5xwDCbgKgyMKLDczpUeBkTevD12BZZvXpcDPC2xmmVEgK1bZGqJATIvCKrghCox6\n0x8hKwqMPIR1TBS4OGdR66qsmMrboiGs4TEgFBPjKgyuj1iIAvNZxriF0ANC2/I5Vv0pC7vPyixv\n8jpRYD6iQEyMqzA4VqzK5WysDv08xjFRoBEFEgVWlnOOgQQUVsH196uwMdUIXRSY9to2dBTYf5GN\niQKDFlZ0BdZnM0lOZyAmQ2EVXMvm9WKWsXm9jbx5fVQUuCisom6BoSuwvv6csc8KE+EqDK6PWMLu\n/xlh1ljy8Mp55Addj3hW4GxxysIOCSUKrK/f8E8ciIlQWAU3rFixZJUtZ49V6OcxTtIVGLSwoiuw\nvr4YZcUKE+EqDK6PWMLGVCOYWXJE1UbevD5JFBi1sKIrsLohCmTFCtPgKgxuPoxbWPOB7EGzJr1b\nbR76IcwTDAiNumJFFFhfv5JKFIiJcDsNbmvzetCb/gg5A0JDb14f+axAKfqAULoCqyIKxMQorIJr\nmWNVzHL3WEU9x31eOmZAaNTKKrUr0L0rDIgC89EViIlxFQY3Z/J6sW7zetpr523ghzB7eWEVfvN6\nk7hi1Z8fosB8DStWmFbUb/VYYPN6ucbS91jFHhBavnm9iT553Zq0IV4jznF4w7Ioe6wwDa7C4BgQ\nWq5pLH1AqAfeY9XfsAr2/zTRV6xslhYFDueYb+nZhj1WFFaYBldhcC1RYLHcOVZxB4T2USCb17M1\nTWIUWH6OwyMKxMQorILjIczlmpzJ65HnWI2KArsfY0eBKZvXiQKLGZPXMS2uwuCYvF6uSRwQOnRe\nRj3Hbfm4hf6cxS2scqNAVqyyEQViYhRWwfX3q7Ax1QizJm3vzxC3Rj3HE0SBUbdYpXcFEgUWG6LA\nqF9kmBqFVXBbUeCaD2QPaiwtCpxHXxUcYqr899/vxQ47eT25K7B8pEV4RIGYGFdhcOFjqhEas24u\n44qbfviRFmO6AsOPW6ArsDqeFYiJcRUGN0SBF
FbZtoZXnv11bfTnMY6JAps+CgxaWNEVWB9dgZhY\n1G/1WJgzx6pYasfaPPrzGBkQWo6uwPqIAjExrsLgeFZgudThleHP8RRRYNQVK7oC66MrEBOjsAqO\nAaHltoZXriisosetI55jtxUFTnlAewhdgfURBWJiFFbBMSC0XGpMFb7zckxXIANCuxv+qsqSrsBy\nQxRIYYVpcBUG19+vKKzyNYmb1z38uIXxA0JDPytQSi+siALzDV2BFFaYBoVVcFuT19d8IHtQXyet\nehBz+M3r7fjN63ELq8RRAG35qmB4jFvAxLidBtdHLGGngo/Q7/9ZtbE6/DkeosDyyevzqIsJTWLH\n2ohzHF6/ykdXICZCYRUczwosl7qa0v922HM8Kgrsfoy7YpW4sZoosFzqOQYSJRVWZna5md1tZifM\n7PrT/P45ZvZri9//oJkdmfpAUUcbPaYaYSisVnw/Dr95fYooMPLmdSkjCuTfytmIAjGxlVehmc0k\n3SDpCkmXSrrWzC7d9rLXS/qcu3+1pP8i6a1THyjq6COWsDHVCP0k9ZVRYPSRFhNMXg87xyo1piIK\nLDecY1asMI19Ca95paQT7n6PJJnZTZKulnTn0muulvTmxc9/XdJ/NzPzNT6H4rN/8Wc6eccfr+uv\n3zOefOAxXdY8pGfd+4R06px1H86ecuGDn9Vlzb360w98Vqee+Ywzvu6RLz6hy5o/14s+8znpE8/f\nwSPcJR68vfuxoEOiX7H66H2f1zn74hUNhx/8or5G0sf/8N2a73v2GV/3rC/cq5dK+vDJx/Tw/MEd\nO75NcN7Dn9erJN3zkffqsXsfWvfhoNCBi79Oh776a9d9GJLSCqsLJd239OuTkl51pte4+1Nm9qik\nr5D02eUXmdl1kq6TpIsuuqjwkNPcd/sH9PL/8/1V/45N8PWSrt0v6X+t+0j2ntdKeu1+SR9a/dpv\n3i/plsV/Ee17ZvdfpvPO3Scz6e0f+LTe/oFPVziw3e3bm0f0X/dLf/ODP5L0+rf8/gO6zY9XPqrN\n8mL7jN5/jvSSO29Y96FghFse+Ld7qrCajLvfKOlGSTp69GjV1ayX/O3LdeKFv1Pzr9gY5537DL3g\nOaxW5XK5/vyRL+nJhJa1c/Y1OvS8c2UKGgc++6D0jHOz/9gLnvtM/dGPvk6PfunJCge1B/hr9KnP\nfbusXf3+233P0n+84Kt24KA2zycfO6rmiS+s+zAwwle/sO5iTY6Uwup+SYeXfn1o8bHTveakme2T\ndL6khyc5wkLnP++Azn/ea9d5CNhwJunFX7nuo9h8h5//rKd9Awrn0PaAAJO78G+t+wiwQVI2Pdwq\n6RIzu9jM9ku6RtKxba85Jum7Fz//J5Leu879VQAAAOuwcsVqsWfqDZLeI2km6R3ufoeZvUXScXc/\nJunnJf2ymZ2Q9Ii64gsAACCUpD1W7n6zpJu3fexNSz9/XNI/nfbQAAAA9hamyQEAAEyEwgoAAGAi\nFFYAAAATobACAACYCIUVAADARCisAAAAJkJhBQAAMBFb14B0Mzsl6c8q/zUHtO1B0MHw/nn/Ud9/\n5Pcu8f55/3Hff833/mJ3P7jqRWsrrHaCmR1396PrPo514f3z/qO+/8jvXeL98/7jvv/d8N6JAgEA\nACZCYQUAADCRTS+sblz3AawZ7z+2yO8/8nuXeP+8/7jW/t43eo8VAADATtr0FSsAAIAdQ2EFAAAw\nkY0trMzscjO728xOmNn16z6e2szssJm9z8zuNLM7zOwHFh9/s5ndb2a3Lf67ct3HWoOZ3WtmH1+8\nx+OLjz3fzP63mX1q8ePz1n2cNZjZy5Y+v7eZ2WNm9oOb/Lk3s3eY2UNmdvvSx077+bbO2xbfCz5m\nZq9Y35FP4wzv/6fM7BOL9/hbZnbB4uNHzOxLS18HP7u+Ix/vDO/9jF/rZvZji8/93Wb299dz1NM5\nw/v/taX3f
q+Z3bb4+EZ97qWz3ut2z/Xv7hv3n6SZpD+V9BJJ+yV9VNKl6z6uyu/5RZJesfj5cyV9\nUtKlkt4s6UfWfXw78P7vlXRg28d+UtL1i59fL+mt6z7OHTgPM0mfkfTiTf7cS/pGSa+QdPuqz7ek\nKyX9riST9GpJH1z38Vd6/39P0r7Fz9+69P6PLL9ur/93hvd+2q/1xffAj0o6R9LFi/vCbN3vYer3\nv+33/7OkN23i537xns50r9s11/+mrli9UtIJd7/H3Z+QdJOkq9d8TFW5+wPu/uHFz78g6S5JF673\nqNbuakm/uPj5L0r69jUey075Fkl/6u61n2qwVu7+R5Ie2fbhM32+r5b0S965RdIFZvainTnSOk73\n/t3999z9qcUvb5F0aMcPbAec4XN/JldLusndv+zun5Z0Qt39Yc862/s3M5P0nZLeuaMHtYPOcq/b\nNdf/phZWF0q6b+nXJxWoyDCzI5JeLumDiw+9YbEE+o5NjcMkuaTfM7MPmdl1i4+90N0fWPz8M5Je\nuJ5D21HX6OnfVCN87ntn+nxH/H7wL9T9K713sZl9xMzeb2bfsK6Dqux0X+vRPvffIOlBd//U0sc2\n9nO/7V63a67/TS2swjKz50j6DUk/6O6PSfoZSV8l6eslPaBumXgTvdbdXyHpCknfb2bfuPyb3q0J\nb/RsETPbL+kqSe9efCjK5/6viPD5PhMze6OkpyT9yuJDD0i6yN1fLumHJP2qmZ23ruOrJOzX+jbX\n6un/sNrYz/1p7nWDdV//m1pY3S/p8NKvDy0+ttHM7BnqvtB+xd1/U5Lc/UF3n7t7K+nt2uPL4Gfi\n7vcvfnxI0m+pe58P9ku+ix8fWt8R7ogrJH3Y3R+U4nzul5zp8x3m+4GZfY+kfyjpny1uLlrEYA8v\nfv4hdfuMXrq2g6zgLF/rkT73+yT9Y0m/1n9sUz/3p7vXaRdd/5taWN0q6RIzu3jxr/hrJB1b8zFV\ntcjWf17SXe7+00sfX86S/5Gk27f/2b3OzJ5tZs/tf65uE+/t6j7n37142XdL+p/rOcId87R/rUb4\n3G9zps/3MUnftegOerWkR5cig41hZpdL+neSrnL3Ly59/KCZzRY/f4mkSyTds56jrOMsX+vHJF1j\nZueY2cXq3vuf7PTx7ZDLJH3C3U/2H9jEz/2Z7nXaTdf/Onf31/xPXSfAJ9VV6G9c9/HswPt9rbql\nz49Jum3x35WSflnSxxcfPybpRes+1grv/SXqOn8+KumO/vMt6Ssk/YGkT0n6fUnPX/exVjwHz5b0\nsKTzlz62sZ97dQXkA5KeVLdn4vVn+nyr6wa6YfG94OOSjq77+Cu9/xPq9pL01//PLl77HYvr4jZJ\nH5b0bes+/grv/Yxf65LeuPjc3y3pinUff433v/j4L0j6V9teu1Gf+8V7OtO9btdc/zzSBgAAYCKb\nGgUCAADsOAorAACAiVBYAQAATITCCgAAYCIUVgAAABOhsAIAAJgIhRUAAMBE/j9lFTjh2FpkrwAA\nAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlUAAAEyCAYAAADTHyXNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3XmQJGl53/HfU2f2VT0sO+x9wXJo\nwcDiAaEDJBAWC5LAQiEZrAOMHBsKIYfQYYwChwx2KGxZssJhCVteB4SwQAJhaQ0hQAIUyAosrlm0\n7AnscggWL8vsLtNV3dN1v/4j6+iuzqzqrMqqrK73+4mY6Jnqruq3srorf/Pk+z6vOecEAACA2eSy\nHgAAAMAqIFQBAACkgFAFAACQAkIVAABACghVAAAAKSBUAQAApIBQBQAAkAJCFQAAQAoIVQAAACko\nZPFNL774Ynfttddm8a0BAAASue222x52zp2e9HWZhKprr71WZ8+ezeJbAwAAJGJmf3+cr+PyHwAA\nQAoIVQAAACkgVAEAAKSAUAUAAJACQhUAAEAKCFUAAAApIFQBAACkYOZQZWaBmX3azD5nZneb2VvS\nGBgAAMBJkkbzz4akFzrnds2sKOnjZvYh59wnU3hsAACAE2HmUOWcc5J2e/8s9v64WR8XGWnUpHve\nL3VbWY9kuTzuBumq56TzWM5J9/xvqb6TzuMto3xJuuHlUmkj65EAwMKksk2NmeUl3Sbpeklvdc59\nKuJrbpZ0syRdffXVaXxbzMMd75E+8CtZj2L5bF4i/eoX03msb94hvfc16TzWMnNOuvEnsx4FACxM\nKqHKOdeR9EwzOyXpVjN7mnPurpGvuUXSLZJ05swZKlnL6sKj4cfX3ynlMtkacvl8/D9LZ98WhgSz\n2R/vwiPhx3/yTumKfzj74y2b1r70u88aPk8A8ESqZ03n3Hkz+5ikmyTdNenrsYTqO1JxQzpFNXFg\n+wqp25ZaF9K5nNW/7HfRE6TK5bM/3rJxTrL8al/eBIAIaaz+O92rUMnM1iT9I0mfn/VxkZH6eSmo\nZD2K5VLuHY+0QkL/cVb1OJuFz41QBcAzaVSqLpP0jt68qpykP3HO/XkKj4ss1HekYDvrUSyX/vGo\n76RTWRqEqhU+zsE2oQqAd9JY/XeHpBtTGAuWAaHqqIOhKg31HclyUmkzncdbRoQqAB6iozoOI1Qd\nFZwKP6YZqoLtdCa9LytCFQAPEapwGKHqqHlUqlb9GBOqAHiIUIXD6tXVP+EnlXqo8uAYB9tSo5r1\nKABgoQhVGHIuDA7lFV2VNq1gDqv/Vv0Yl6lUAfAPoQpDzT3JdVa/ipJUoSwVAi7/JRFsS81dqdPO\neiQAsDCEKgz5sNR/WmnOEarvDCe/r6r+zxCXAAF4hFCFIUJVvNRD1Yof48E8tPPZjgMAFohQhSFC\nVby0QlWnJbX2Vv8Ypz25HwBOAEIVhvqXalb90tQ00gpVjdrw8VYZoQqAhwhVGKJSFS+tFgH9y2Gr\nfowHoYo5VQD8QajC0Kpv9DuLckobBPtyjNNuQwEAJwChCkP9Ksqq91CaRv/yn3OzPY4v1UAu/wHw\nEKEKQ/WdsB9TMch6JMsn2JY6Taldn+1xfAlVpS1JRqgC4BVCFYZ8WOo/rbQqL76EqlwuvARIqALg\nEUIVhnzYk25aqYWq/gpLD44zmyoD8AyhCkNUquL120zMupqtviPJepfHVhybKgPwDKEKQ4SqeGle\n/gsq4eWxVRecolIFwCsevLPj2Oo7rPyLM2gRMOO2K/UdqexJcE2rDQUAnBCEKgxRqYqXaqXKk2PM\nnCoAniFUIeScXyf8pAhVyRGqAHiGUIVQuy
51W/6c8JMqBFK+NHtIaHi0wrI/Ub3byXokALAQhCqE\nfOmfNC2zdCovvlWqJFYAAvAGoQohQtVkabQI8DFUsakyAE8QqhAiVE0262q2bqd3+c+TFZZsqgzA\nM4QqhAhVk816+a/hUTd1iU2VAXiHUIUQoWqyWUOVb8eYUAXAM4QqhPpNLX054U+DUJUMoQqAZwhV\nCPm00e+0Zg5Vnh1jQhUAzxCqEKrvhH2YCkHWI1lewXbYz6vdmO7+vlWqykxUB+AXQhVC/aX+ZlmP\nZHnN2iLAt1CVy4fBij5VADxBqEKIzZQnm/VyVv9+Ph1nNlUG4BFCFUI+NaWcFqEqOfb/A+CRmUOV\nmV1lZh8zs3vM7G4z+8U0BoYFI1RNNghV56e7f31HKm1J+UJ6Y1p2hCoAHkmjUtWW9CvOuRskPVfS\n68zshhQeF4vk00a/05q1UuXjMQ62pw+hAHDCzByqnHMPOuc+2/t7TdK9kq6Y9XGxYFSqJkvj8p9v\nx5hKFQCPpDqnysyulXSjpE+l+bhYAB9P+En1j8+0q9l8PMbBNhsqA/BGaqHKzDYl/amk1zvnjryL\nmtnNZnbWzM6eO3curW+LNLTqYf8l3074SRXXpVxhhkrVef+OcbAdhtBuN+uRAMDcpRKqzKyoMFC9\nyzn3Z1Ff45y7xTl3xjl35vTp02l8W6TFt41+p2U2W4uA+o4UeLTyTwqfr+tKzd2sRwIAc5fG6j+T\n9DZJ9zrnfmf2IWHhfGtKOYtZ5gj5evlPYl4VAC+kUan6Hkk/LemFZnZ7789LU3hcLIpve9LNYtpQ\n1e1KjZp/x5hQBcAjMzfMcc59XBJ7m5xk/SXvvp3wpzFtqGruhpfBfDvGhCoAHqGjOrj8l8S0ocrX\nY0yoAuARQhX8PeFPY9oWAb4e41nbUADACUKogp970k1r1kqVb8e4TKUKgD8IVQhPeJaXShtZj2T5\nBdtSa0/qtJLdz9tKVS9EEqoAeIBQheFSf2O9wUSDOUIJL2f5GqryRam4QagC4AVCFfzc6Hdag1CV\ncJPgQYPVU+mO5yRgU2UAniBUwc+mlNOadjXboFLl2ZwqiU2VAXiDUAVCVRKzhKriRng5zDeEKgCe\nIFSBUJXEtC0CfNxMuW/aNhQAcMIQquDnRr/TKk+5ms3nYxzMsAk1AJwghCr0TvgeTqCexiyX/7yu\nVBGqAKw+QpXvOi2pdcHfE35SpU3JclOEKo9XWPZDlXNZjwQA5opQ5bv+XBdfT/hJ5XLhJUAqVccX\nbEuuIzX3sh4JAMwVocp3/f5Bvp7wpzHN5SzfQ5XEJUAAK49Q5TtfO33PIulqNucIVRKbKgNYeYQq\n3xGqkktaqWruhZe/fD3GVKoAeIJQ5bv+ia7s6XL/aSQNVb4f4zKhCoAfCFW+azBRPbGkocr3Y0yl\nCoAnCFW+4/JfctNWqnw9xoQqAJ4gVPmuvhP2XSptZj2SkyPYlpo1qdM+3tcPQpWnDVb7neT7K00B\nYEURqnxX3wnn+uT4UTi2pKvZfK9UFcpSYY1KFYCVx5nUdz4v9Z8WoSo5NlUG4AFCle983uh3Wkk3\nVR40WPX4OLOpMgAPEKp8x2bKySWdeF3fkQpBeBnMV2yqDMADhCrf+bzR77QShyqOMaEKgA8IVb5j\nTlVy01SqfD/GhCoAHiBU+Y4TfnKEquQIVQA8QKjyWacd9lvy/YSfVLkiyQhVSfRDlXNZjwQA5oZQ\n5TPft0+ZVi4XBqvjtgggVIXPv9uS2vWsRwIAc0Oo8pnvG/3OIkmLgH6DVZ8lbUMBACcQocpnVKqm\nd9w5Qs6Fx9n3Y8z+fwA8QKjyGZ2+p3fcUNWuS50mx7jfC41QBWCFpRKqzOztZvYtM7srjcfDghCq\npnfcUMUxDlGpAuCBtCpVfyDpppQeC4vCCX96hKpkCFUAPFBI40Gcc39jZtem8VhYIE74iXW6Tq1O\nV4XSlv
KNHdmkOwyOsb9bATXaHbnCpgLpWKHKOadGuytJyudMxTyzFCZptrvq9tpVlAs5mU38yQQw\nB6mEKpxQ9R1Jxsq0Y2p1unreb35M36zW9UuFh/UvClW5Tke5fD7+Tp4H17d//Cv6t39+j8pq6guB\ndPeXv66nPnv8fW7+w9v0kXsekhQGhI/80vfp6seuL2C0J9NnvvqoXnXLJ9XuhqHqNd99rd78sqdm\nPCrATwv7L6CZ3WxmZ83s7Llz5xb1bTFOvSqVt8K+S5jo2xea+ma1rh+84RJdevpxyslpt3Z+/J0G\nocrP4Hrvg1VVgoJ+8aZ/oKYrqHr+kYn3uef/VfW0Kyr66edeo0a7q688sreAkZ5c9z20q3bX6XUv\neIKuvmhd9z54zP5pAFK3sLOpc+4W59wZ59yZ06dPL+rbYhyaUiZS3W9Lkn7o6Zfp8VddLkna23l4\n/J08r1RV6y1dtr2mn//+61WzDeWaky//VestnbnmIv3Md10T/nu/Ne9hnmjVenh8XveC6/WkS7ZU\nrbczHhHgL0oUPiNUJVLrnbwqQVGFjXCO1H7t2+Pv5HmoqtXbqqyFswwu2IYKzfFVlG7XabfRViUo\nqLJWHDwG4tXqLRVyprViXpW1wuDnFMDipdVS4Y8lfULSk83sATP72TQeF3NGqEqkXwGorBVU3HiM\nJKlefXT8neo7Ur4kFYJ5D28pVestVYIwHO3nN1Vs1cZ+/W6zLeekylpxcL8qIWGs6n5blbWizEyV\noEhlD8hQWqv/XpXG42DB6jvSqauyHsWJ0T9ZbQVF2eZFkqTm3jEqVcG25OlqrOp+W9efDt9m6oUt\nldvjQ9XwGBcUFHMq5IyQMEG13tJWEB7jSlBQrdFWt+uUy/n5Mwdkict/PqNSlUj1wOW/9UoYqlp7\nx5io7vExrtZbg8t4rcKWgu74Sef9eWuVoFd5WStSqZqguj+sBlbWinIurPgBWDxClc/Y6DeRwQl/\nraCNXqjq7B8jVHl6jJ1zh074ndKWNrq7Y+8zCK69IFYJCoPjjmjVA/PWBpdMqe4BmSBU+arbZaPf\nhKoHJgRvboehyu1PWM3m8THea3bUdRqc8LulijbdpErVsBooiUrVMRyuVBV6txFEgSwQqnzVrEly\n3p7wp1HrXcoyMxWKJe25QNaYEKo8vvw3GpAUnFJgLTXq8cGqdmAxQP++VF3GO7gYgMn9QLYIVb7y\nfKn/NKr74VL/vl3bUL4xodGix6FqGJDCE72thcdhdyd+xeTBeWvhfQu0VJjgYNsK2lAA2SJU+YpQ\nldjBSdeSdCG3qUKLUBVnNCDl18PeXhd24ruq9y9bDVezcflvnFanqwvNztFKFdU9IBOEKl8RqhKr\n7g+XrktSPb+p0ri+S6261K57e4wPtkeQpOIxGqZW6y2tl/Iq9DZR3mKi+lj9ilT/GPc/EkSBbBCq\nfEWoSqxabw/nB0lqFLZU7oxZzda/NOjpMR5dyVfq9fZq7I4JVQcmXUth5WW/1VGz3Z3jSE+uwby1\n3jEehCqCKJAJQpWv6v0Tvp/L/acxesJvFze1Nq5FQN3zUDXoORWe6Ne2wi704xqmhpdYh9XA4Rwh\nKi9RRi+xFvI5bZTyVKqAjBCqfDWoVJ3KdhwnyMEJwZLUKVW0Ma5FgOfVwIMd6CVpvfJYSVJ7TMPU\n2kg1cNAigInXkYa90w4eM1ZMAlkhVPmqf8L3tDFlUs12V/utzqETfre8rS23J9eNuTRV74UHX0NV\nvaW1Yl6lQvg2M+ztFR+qRhcDMPF6vOEl1gPVPSb3A5khVPmqviOVNqV8Kts/rrzayPwgKWwRkDen\nC7sxvao8r1SNVvbW1rfUcnlpTG+v0bYVtAgYrzZy+U+iDQWQJUKVrzxe6j+N6khTSknKrYWXTner\nMS0CPA9VB5tSSpLlctq1DVk9vg1FbKWKykukyMt/VKqAzBCqfFU/7+3J
fhqD+UHl4clr0HepGjPx\n2vdQtd8+1IJCkvZsQ/mYNhT9vQIP3me4mo2QEKVabyln0kYpP7iNNhRAdghVvvJ4T7ppjLYHkKTS\nRriarV6L6RDeqEq5glRcn/v4ltFo1UmS9nObKsY0TB3sFRgcnnTdfywcFYbQcOukPvZLBLJDqPJV\nfYdJ6gmM7kknSeXNMFTF9l3qH+MDJzyfjLagkKRGYUPldnQbiqh5axulvHJG36U41ZF5a9Jwv0Tn\nXEajAvxFqPIVc6oSObI5sKS1SriarR3Xd8nzYxx1wm8WKgo60Zf/hn2tDszDMqPyMkZUcK2sFdR1\nYeUPwGIRqnzl+Qk/qajLfxu9UNW5ENMiwONj3J8fNXrCb5e2tN6N7u0V1R5AGlZecNToYgCJNhRA\nlghVPnLO6xP+NKr77SMTgje3w2aW3bi+Sx4f4/1WR+2uOzKnqluqaDOmYWpUNVAKQxbNP6NV9yMu\n/zEPDcgMocpHzV3Jdb094U+jP+n64ITgUjnQBVeWNWJaBHgcqgZz0EYCkgu2tW4NtZqNI/eJqgb2\nH4NtaqLVxlSq6FUFLB6hykeeL/WfxuhS/75d21COUHXEcIuaw8fMetsi7e4cXTHZn1M1eh9aBMSr\n1tuDbYD6aEMBZIdQ5SM2U05sdE+6vv3chgoxLQJU97dtRVzVKb8eHo+9naMNU/vVqNFQRTPLaO1O\nV7sNLv8By4RQ5SMqVYlFTQiW+n2XIlazdVpSa8/bYzxcyXf4hF/s9fbaj+jtVa23FRRzKhfyh25n\ng+Bou43oS6yVQaWK6h6waIQqHxGqEouaECxJjcKmgk5E36VBNdDPYxxXqSpt9humHm1DEbVaUApD\nw16zo3YnZuNqT0VtUSNpcDmQIAosHqHKR4NQdSrbcZwgcZWqVrGitchQ1VsR6GuoilnJt7YVhqrW\nXlSl6mgHdmnYYoGJ14cNgutINbBUyGmtmOfyH5ABQpWPqFQlVt2PPuF3SlvaiGoR4Pkx7rdAGJ0f\ntVYJ21C0I3p7VffbRwKCxKbKcQbBNSaIcvkPWDxClY/6J3y2qTmWdqervWYnslLVLW9rw+3JdUcu\nTXkfqloqF3IKiofnRw17e+1E3ic6INAiIEo1pm1F/7ZagxAKLBqhykeNnXCT30Ip65GcCP0JwVEt\nFRRsq2Qd1fdHqlUNz+dU7R9d6i9JG5vb6jiTqx8NVbWI9gASLQLiVGNWS/Zvo1IFLB6hykdsppxI\n3IRgScqthaFpd7RFgOfVwLDqdPRkb7mcdm1duYhQFU5U5/LfcY2//EcbCiALhCofedyUchpxE4Il\nqbAeTva/EBeqPD3OcSv5JGnPNpVvHu7t5ZybOFGdysth1XpbZtJWOTqIUtkDFo9Q5SNCVSLjKgKx\nfZfqO5LlpNLm3Me3jKr1duTxkqQLuU0VRnp71VtdtTouen4QzSwjVfdb2iwXlMvZkc+xXyKQDUKV\njwhViQwrVUdP+OVe36XG7kjfpf4l1pyfv2K1mEt5ktTIb6rcPhyqhn2tjt5ns1SQGXOqRsW1+ZCG\nlSrn3IJHBfjNz3d83xGqEhnOqTp6wl+rXCRJau1FhCqPj3HcpTxJaha3jjRMjetrJUm5nGmrTOVl\nVNiQNiZUrRXV7jrttzoLHhXgt1RClZndZGZfMLP7zeyNaTwm5sjjPemmEdcdXJLWe32XOqMtAjw/\nxtWYvRKlXsPU7t6Rr5eij3H/di7/HVarx1cD+8eeNhTAYs0cqswsL+mtkl4i6QZJrzKzG2Z9XMyJ\nc95XUZLqTwjeLEVcmtoOK1Xd/ZFmlh4f43qro2a7G1nZk6RuuaJNNxqq4hcDhLcXmag+Yty8teHk\nfoIosEjR72DJPEfS/c65L0uSmb1b0ssl3ZPCY0/l07f+ri67479m9e2XmsnpStfS1/eLuiri83/3\ntW/rDf/rDrW7zMXoe2S3ETshOFjb
UN0VdcNX/kBff8utg9svcd/SQ5c8P/IYn7/Q1Kvf/umVvZzV\n6f3sRPWckiRX3tam7evrb/mOwW2Pd07vLW2pkv9g5H0qawV9/P5zesFv/3Xq410W3/ek03rzy556\n5HbX7er23/4hXbz/1UO3v7Xr9Im110o6c+Q+/UrVq9/+aZVHGrD67LLtQO947XNUzB+tJ3zibb+q\nKx/4QAajwqwefd5b9IwX/kTWw5CUTqi6QtLXD/z7AUnfOfpFZnazpJsl6eqrr07h28Yrn7pUD21+\nx+Qv9NSnzl+v/dxz9VMRn/v0Vx7Vfd/a1Q89/TLl7WiI8NUzr4rfJ/H2J75OhYfuPHTbXbXrdK+9\nWL8c8fVffGhXn3tgR9/9hMfq4s1yyiNdDs++9iK98CmPi/zc5d/9Sn1m5yuy7nC+z2Pa39Kz9+9U\nJ39O0ukj9/ln33OdPnDHg/MabubueOC8PnTXg5GhqlG/oBsv/K2+lH+8vr1+3eD2p+19Qi8O7op8\nvBuvPqVXPecq7TWYU9X3tUcv6G+/9IjO1Rq6/NTakc9f8o2PqOQa+vrWMzMYHWax2ZuGsQzSCFXH\n4py7RdItknTmzJm5lkGe8YIfl17w4/P8FifaT/7rD+k1dlnk56r1lnIm/d6rbpQRqo7luT/1liO3\nvfKWT2h055q+/iWZf3XTU/SMMWFtVV3zlGfpmqe89/CN9/+V9M5XHOlf1ffip16qFz/10gWMLhu/\n8YF79M5Pfi3yc7vnH1Eg6eEn/1N950/8y+En/vvzFeTrkffZCor69694+hxGenJ98M4H9fPv+qyq\n9ZYu19FQtd7d09e2n61n/9J7MhgdVkUaE9W/IR26ynFl7zYsqXDSb/Slp1pvngaBajaVIH5idX9P\ntrj5MF4KeuGyHh2qVl0lKGq/1VGrczSJ7/V6oOU3RgJ4sD3cDgkTDTrzx8zNW3d76ni6AwLSk0ao\n+oykJ5rZdWZWkvRKSe9P4XExJ5WgEHvCH9cJG8dXWSvGrrwatGiImZTtpf6k/ojta3wwbtPofmPZ\n0vpjDn8i2Pb2eE2jP3m/FvHe1+10tKl9ubKfi0uQnplDlXOuLekXJP2lpHsl/Ylz7u5ZHxfzszVm\nC4twRREn+1mFG9rGB9fwawivA0GvQlA/P/7rVtS41XrNWtgDrbQ5UqkqE6qSGLeHZK36beXMyQIq\nVZhNKmdP59wHJUUv28HSqawVtTPmhL9V5mQ/q0pQVK3RVqfrlB9ZNVittxQUcyoV6L070L/s4mlI\n6P/ORZ3wmxfCULW2ddHhT1CpSmQriN9Dcm/nEW1Lyq37N8cR6eJd3UOVoKBabKWqRaUqBf3LObsR\nl3Oq+/GNMb1VDKRC4G1IGOxvGHHC7+yF1buN0RVOwbbU3JU6q9maI21bgzlVR9/79qvhhuiF0Uus\nQEKEKg+N607NCT8d/flSUcd53BYuXvO48jK4/Bc136fXWHbzVESokpisfkylQk5rxXzkMa7HXWIF\nEiJUeShcmRa3+o8TfhoGlYeo+Rv1NpPUo/gcqsZUUVSvqukKKgfrh28fTO73cx7aNCprhchqYLO3\nd2ewedGRzwFJEKo8VFkrqNnuqj6y2Wq709Ves0OlKgXjlm9TqYrhcYuAcSE816yqZhuy3Mjb9SBU\n+XnMphHX6qTVu8S6vkRNJHEyEao8tBWzCqY22NSWKsqstsZd/qNtRTSPK1UbpbxyFt1SodCsas82\njt7J8zYU04hrddK/xLqxTaUKsyFUeagSswpmuKktJ/xZba/FX86hbUWMcsXbgGBm4VzHiJ+XYqum\nen7z6J0Cv1dMTiOuR5/rHcONLSaqYzaEKg/FXWroh6wt5vvMbNgT53Bwdc6FbSsIrkd5XKmSer3N\nIqoo5XZNjchQRaUqqbgefVbf0a5bU6HI7yVmQ6jyUNyk2EGlivk+M9sMops57rc6ancd1cAo/VDl\n
5ro16NKqxJzwg+6uWsWto3cgVCVWWYsOrvlmVbtRl1iBhAhVHtoebNdw+M2lxuW/1ORzpq1yIeIY\nM28tVrAtdZpSO3qT4FUXN4l6o7urdimi03dpS5IRqhLoB1c3EtwLzar2cxHVQCAhQpWH4rZrGOxJ\nxwk/FVH9wPqVCIJrBM8rL3HL/TfdnrpRoSqXC+dVeXq8plFZK6rdddofWflcate0H3WJFUiIUOWh\nrZjl/lz+S1fU/n8c4zE8bxEQValq1C8osJZcELPRr8dtKKYR1+qk3NlTsxBxiRVIiFDloaCYUzFv\nkVUUM2mzRKUqDVEnyUE1kMUARwW9btaeVl6ilvvv7jwqScqtxXT69nxyf1KVwdSHw7+X691dtUqE\nKsyOUOUhM4ucFFutt7VVLig3sgEwphN1OYdK1RietwioBEXtNtpqd7qD2y709qTLr8VUqsqEqiTi\npj5suF11oxYDAAkRqjwVzvcZOeGz1D9V0ZWq8N+0rYjg+bYr/Z+J3cbw93K/tyddMW5POipViWxF\n9OjrdjradBfUjbvECiRAqPJUJWa+DxWU9EQ1c+wHWSaqR/B+ovrR+T6NWnj5r7wR0+mbUJVIVI++\nvd0d5c3JCFVIAaHKU+H8jaMnfOb6pKcSFLTbaKvbHS7frtZbKhVyCor5DEe2pHwPVRFbGw02+q0Q\nqtIQ1aNvbye8xBo7bw1IgFDlqajuzdV9KlVp2gqK6jpprzk8ztX9NlWqOIVAype8Xc1WidjaqH0h\nDExr40JVoyZ1u9GfxyHDPTmHv5MXqmE1sLBBqMLsCFWeipqoXqtzwk9Tf6XRwTfw8BIr1cBIZl5X\nXqImUfc3+t2sPDb6TsG2JOdtEE0qKOZVLuQOvffVe5dYixvs+4fZEao8FdeYkhN+eqIuNVT3WwTX\ncTzeVDkqhKu+o7bLaX0jZmWa5ysmpzG6SKe5FwbXYJNQhdkRqjxVCQqqt7pqtMPOwp2uU61BpSpN\nUZdzqvU2l1jH8blSFfHzkmtUVbMNWS7mrdrzeWjTqASFQ/+hbPXmra1txVxiBRIgVHmq/wbebza4\n2/vIUv/0DC/nDP9XXNtvcYzH8ThUbZYKMjv885JvVrU3bqNfQlViWyNTHzoXwkrVxnbMJVYgAUKV\np/on/H6ooill+qK6N1eZtzaex6EqlzNtlg+3Oim2Jmz0S6hKbPTyn+sdu81tKlWYHaHKU8MmeOEb\n+CBUccJPzVbUnComqo/ncaiSjjaMLbd31SgQqtJUCQqqHfidtPqO9lxZxVI5w1FhVRCqPDXaBG+w\nJx0n/NSMLt+utzpqtrsE13GCbW83VJb6DWOHVZRyZ1fNQiX+Dv1Qxeq/YxtdpJObdIkVSIBQ5anR\n3dqpVKWvmM9pvZQ/Wg3kEmu8YFtq70vtRtYjycToJOqN7q7a4zb6LbP6L6mwncwwuBaaNe2Nu8QK\nJECo8tRw+Xa/UhV+3OaEn6rnPWZCAAASi0lEQVSDl3MG1UAmqscbXM7ys/IS7nQwPOFvuj11S2Mq\nVbm8VNoiVCVQWSuo2emq3gpXPpfaVTXyhCqkg1DlqdEeSuxJNx+VtcLRaiDBNZ7nc4QONuVtNRta\nt4bcpD3pPJ+HltRok9WJ89aABAhVnlov5ZXP2eB/xf0VaptUUVJVCYqqNcJjW6tTqZrI81C1deDy\n3+5O2Ol74ka/hKpEhot0wt/Hte6uWsUx1UAgAUKVp8zs0Bt4db+tzXJB+ZxlPLLVshUcqFTtM29t\nokGoOp/tODJSWSsONuHub/Sbn7TRL6EqkdFFOhtub/y8NSABQpXHDl5qqNZbVFDm4OBKIy7/HYPn\nlapKUJBzUq3R1n5vT7rCpO1Tgm1vQ+g0Dk59cN2uttyeXGlCNRA4JkKVxyprhcFcqnDfP072aTsU\nXPeZtzaR5y0CDm5VU98Nt08pbxynUuXn8ZrG9oE9Fi/s1VSwrr
RGqEI6CFUeO1qp4mSftn5wdc6p\nWm+pmDcFRX7tYnneIuDgJOr+nnTBpD3pAn83oZ7GwUrVbjW8xJqbNG8NOKaZ3t3N7MfN7G4z65rZ\nmbQGhcUYXe5P48/0VYKiOl2nC81OWA0MijJj3lqs0oZkeW9DwqDVyX5b7b3wkt56ZVKo2g4re93u\nvIe3Eg7OqbrQWwxQmFQNBI5p1v8y3yXpFZL+JoWxYMEqa4Xh6r8Glap5OLhxda3e5hLrJGZeT7we\n7snZUnf/mBv9BtuS60rN3XkPbyWUCzmV8jnV6m01evPWShsT5q0BxzRTqHLO3euc+0Jag8FibQVF\nnas19PPvuk0P7TQGS42Rnv4xfcOf3qFPfvkRjvFxEKr01r/+ku7/2jfUcaaNzWO0VJC8PWZJ9Vc+\nf/DOB3XrJ+6RJJUmLQYAjmlhkzvM7GYzO2tmZ8+dO7eob4sxnvfEi/WE05u676FdXXfxhp7/pNNZ\nD2nlPOPKU3r6ldt68Py+tteKevFTL816SMvP41B16Xag773+Yl1otLXp9lTPbyqXz4+/E6EqsR95\nxuUq5XNSIzxmV1zG7yXSMfG/zWb2UUlRP3Fvcs6977jfyDl3i6RbJOnMmTPu2CPE3Hz/kx+n73/y\n47Iexkq76qJ1vf8XvjfrYZwsHq9mKxVyeuc//87wH3/2R9LXJsynkrxfMTmNN7/sqeFfPv1F6YPS\nZuXibAeElTExVDnnXrSIgQCApDAkPHxf1qPIXn1nGJjGoVI1vf4xC+iojnSwthvAcqFFQOi4ocrz\nNhQzqe9IhUAqlLMeCVbErC0VftTMHpD0XZI+YGZ/mc6wAHgrOEVAkBJUqk4Nvx7JHPcYA8c001Ik\n59ytkm5NaSwAEJ7kWntSpyXlPW5BcexQRaVqaoQqpIzLfwCWy2COkOcTr497ws8XpeIGoWoahCqk\njFAFYLkMQpXHmwR3O1KzdvwTPpsqT4dQhZQRqgAsF1oEDJ97olDl8fGaVqNKqEKqCFUAlgstAg4s\n9U8Sqjw+XtOiUoWUEaoALBdaBAyfe/mY/ZNoQ5Gcc+ExO+4xBo6BUAVguVCpolK1CO261GlSqUKq\nCFUAlguhilC1CEmPMXAMhCoAy6W0KVnO75AwbahybKt6bIQqzAGhCsByyeXCeS6EqmShynWk5t78\nxrRqBsf4VLbjwEohVAFYPr63CKhXJVmCieq0oUisnrBtBXAMhCoAy8f31Wz1Ham8FVbtjoMVk8n1\nm6UGrP5DeghVAJaP75sqJ+2fxOT+5JhThTkgVAFYPr6vZkscqk4N74fjIVRhDghVAJYPoYpK1bzV\nd6R8SSoEWY8EK4RQBWD5EKoIVfPWP8ZmWY8EK4RQBWD5BNtSsyZ1O1mPJBuNpKGKieqJsZky5oBQ\nBWD5+N4iIGmlqlCWCmuEqiTYTBlzQKgCsHx8bhHQ7YY9lJJu9Ot7G4qk2EwZc0CoArB8fJ4j1KxJ\ncsmrKL7PQ0uKShXmgFAFYPn4HKqmXepPqEqGUIU5IFQBWD6EKkLVvBGqMAeEKgDLh1BFqJqndkNq\n1wlVSB2hCsDyGYQqD1f/TbvRb7Dt72rJpNhMGXNCqAKwfMoVSeZn5WXWSpVz6Y9p1QyO8alsx4GV\nQ6gCsHxyOam8RahKolyROs3wshbGGxxjWiogXYQqAMvJ1zlC/eecuE+Vx/PQkqqfDz9y+Q8pI1QB\nWE4+h6rSppQvJLsfoer4pq0GAhMQqgAsJ59D1TQn+/78IB+PWVKEKswJoQrAcvI2VJ2fMlRRqTo2\nQhXmhFAFYDkF21LDw4DQqBKq5q1RlXIFqbie9UiwYghVAJZT2dMNgqfd6DfweBPqpPrH2CzrkWDF\nEKoALKdgO2zS2O1mPZLFmnpOFZWqY2OLGszJTKHKzH7LzD5vZneY2a1mRic1AOkItiU5qVnLeiSL\nNe0JvxBI+RKh6jgIVZiTWS
tVH5H0NOfc0yV9UdKvzT4kAJCflRfnpj/hm/k7uT8pQhXmZKZQ5Zz7\nsHOu3fvnJyVdOfuQAEB+hqrmruS605/wCVXHQ6jCnKQ5p+q1kj4U90kzu9nMzprZ2XPnzqX4bQGs\nJB83VZ51o182VT6e+pQrLIEJJrbsNbOPSro04lNvcs69r/c1b5LUlvSuuMdxzt0i6RZJOnPmDDt+\nAhjPx0rVrP2TqFQdD5UqzMnEUOWce9G4z5vZayT9sKQfcI7t0QGkxMcWAbNu9FuuSDsPpDeeVdRp\nSa09QhXmIuHmUoeZ2U2S3iDp+5xzF9IZEgDIz21XqFTN36yXWIExZp1T9XuStiR9xMxuN7PfT2FM\nADBsgOlTSBiEqim70xCqJqufDz8SqjAHM1WqnHPXpzUQADgkX5BKm36FhDQqVe261KpLxSC9ca0S\n9v3DHNFRHcDy8q3y0n+u02xTIw2DAisA4xGqMEeEKgDLy7dNlRs74Sa/hdJ09x/MQyNUxWowpwrz\nQ6gCsLx8rFTNcrL3sQ1FUlSqMEeEKgDLq1zxKyDUd6a/9CcdaENxPp3xrKJZL7ECYxCqACwvKlXJ\nUKmarL4jWS5cBAGkjFAFYHkRqpIhVE3WrwbmOP0hffxUAVhe/VDly2YNhKr5Y4sazBGhCsDyCrYl\n15Wau1mPZDFmPeEX16VcgVA1DqEKc0SoArC8BpUXD1oEOBc+z1lO+Ga9NhQeHK9pzXqMgTEIVQCW\nl0+bKrf2pW5r+s2U+3xbMZkUlSrMEaEKwPLyaY5QWv2TfJvcnxShCnNEqAKwvAhVyRGqxiNUYY4I\nVQCW12DbFQ9CAqFq/jptqVkjVGFuCFUAlpeXlapTsz0OoSoe+/5hzghVAJZXfysRHzZVTuuEH2z7\nsVpyGoQqzBmhCsDyKpTC3ks+VF76+/XNHKpOSa09qdOafUyrhs2UMWeEKgDLzZcWAWlt9DtoQ0G1\n6gg2U8acEaoALDdf5gjVd6R8WSoGsz3OYB7a+dnHtGqoVGHOCFUAlptPoSqNk71Pk/uTIlRhzghV\nAJYboSoZQlU8QhXmjFAFYLkRqpIhVMWr70gy5lRhbghVAJabLy0C0trot/8YbKp8VL0aBqocpz7M\nBz9ZAJZbv1LlXNYjmS8qVfPHFjWYM0IVgOUWVKRuS2rtZz2S+arvDNshzKK0KVmOUBUlrWMMxCBU\nAVhuvlRe0qqimPnT2yspKlWYM0IVgOXmQ6hq1aVOI70Tvi+T+5MiVGHOCFUAlpsPoSrtpf6EqmiE\nKswZoQrAcgtOhR9XOSQMQtWpdB6PUBWNUIU5I1QBWG4+tAjoP7dUK1UrfLym0e2Gx5lQhTkiVAFY\nbv1Gjau8l13/uaXVlJJK1VHNmiRH40/MFaEKwHJjTlVyhKqj2KIGC0CoArDcioGUL692SJhHqGrW\npE47ncdbBYQqLMBMocrM/p2Z3WFmt5vZh83s8rQGBgADq155mUeoklZ7HlpShCoswKyVqt9yzj3d\nOfdMSX8u6ddTGBMAHOZDqMoVpeJaOo/nwyXTpAhVWIDCLHd2zh38b9CGpBXfnAtAJoJt6dEvS5//\nYNYjmY+H7g6fo1k6j9cPDl/8C+nUNek85kn31f8bfiRUYY5mClWSZGa/IelnJO1IesGYr7tZ0s2S\ndPXVV8/6bQH4ZPsK6Z73Se9+VdYjmZ9Ln57eY1WuCD/+xRvTe8xVkC9JGxdnPQqsMHMTdn43s49K\nujTiU29yzr3vwNf9mqTAOfdvJn3TM2fOuLNnzyYdKwBfNfekh+/LehTzdepqaf2i9B7v4ful5m56\nj7cKNi6Wtq/MehQ4gczsNufcmUlfN7FS5Zx70TG/57skfVDSxFAFAImUNqTLn5n1KE6Wi6/PegSA\nd2Zd/ffEA/98uaTPzzYcAACAk2nWOVX/wcyeLKkr6e8l/dzsQwIAADh5Zl3992NpDQQAAOAk
o6M6\nAABACghVAAAAKSBUAQAApIBQBQAAkAJCFQAAQAoIVQAAACkgVAEAAKRg4t5/c/mmZucUNgudp4sl\nPTzn77HMeP7+Pn+fn7vE8+f5+/v8fX7u0nyf/zXOudOTviiTULUIZnb2OJsfriqev7/P3+fnLvH8\nef7+Pn+fn7u0HM+fy38AAAApIFQBAACkYJVD1S1ZDyBjPH9/+fzcJZ4/z99fPj93aQme/8rOqQIA\nAFikVa5UAQAALAyhCgAAIAUrGarM7CYz+4KZ3W9mb8x6PPNkZleZ2cfM7B4zu9vMfrF3+5vN7Btm\ndnvvz0uzHuu8mNlXzezO3vM827vtIjP7iJnd1/v4mKzHOQ9m9uQDr/HtZlY1s9ev8utvZm83s2+Z\n2V0Hbot8vS30X3rvBXeY2bOyG3k6Yp7/b5nZ53vP8VYzO9W7/Voz2z/wc/D72Y18djHPPfZn3cx+\nrffaf8HMXpzNqNMT8/zfc+C5f9XMbu/dvmqvfdy5brl+951zK/VHUl7SlyQ9XlJJ0uck3ZD1uOb4\nfC+T9Kze37ckfVHSDZLeLOlXsx7fgo7BVyVdPHLbf5T0xt7f3yjpN7Me5wKOQ17SNyVds8qvv6Tn\nS3qWpLsmvd6SXirpQ5JM0nMlfSrr8c/p+f+gpELv77954Plfe/DrTvqfmOce+bPeex/8nKSypOt6\n54V81s8h7ec/8vn/JOnXV/S1jzvXLdXv/ipWqp4j6X7n3Jedc01J75b08ozHNDfOuQedc5/t/b0m\n6V5JV2Q7qqXwcknv6P39HZL+cYZjWZQfkPQl59y8dyvIlHPubyQ9OnJz3Ov9ckn/04U+KemUmV22\nmJHOR9Tzd8592DnX7v3zk5KuXPjAFiDmtY/zcknvds41nHNfkXS/wvPDiTXu+ZuZSfoJSX+80EEt\nyJhz3VL97q9iqLpC0tcP/PsBeRIyzOxaSTdK+lTvpl/olT3fvqqXv3qcpA+b2W1mdnPvtkuccw/2\n/v5NSZdkM7SFeqUOv6H68vpL8a+3j+8Hr1X4P/S+68zs78zs/5jZ87Ia1JxF/az79to/T9JDzrn7\nDty2kq/9yLluqX73VzFUecnMNiX9qaTXO+eqkv6bpCdIeqakBxWWhVfV9zrnniXpJZJeZ2bPP/hJ\nF9aCV7p3iJmVJL1M0nt7N/n0+h/iw+sdx8zeJKkt6V29mx6UdLVz7kZJvyzpj8ysktX45sTbn/UR\nr9Lh/1St5Gsfca4bWIbf/VUMVd+QdNWBf1/Zu21lmVlR4Q/Zu5xzfyZJzrmHnHMd51xX0v/QCS97\nj+Oc+0bv47ck3arwuT7UL/X2Pn4ruxEuxEskfdY595Dk1+vfE/d6e/N+YGavkfTDkn6yd3JR79LX\nI72/36ZwXtGTMhvkHIz5WffptS9IeoWk9/RvW8XXPupcpyX73V/FUPUZSU80s+t6/3t/paT3Zzym\nueldR3+bpHudc79z4PaD145/VNJdo/ddBWa2YWZb/b8rnLB7l8LX/NW9L3u1pPdlM8KFOfS/VF9e\n/wPiXu/3S/qZ3kqg50raOXCpYGWY2U2S3iDpZc65CwduP21m+d7fHy/piZK+nM0o52PMz/r7Jb3S\nzMpmdp3C5/7pRY9vQV4k6fPOuQf6N6zaax93rtOy/e5nNZN/nn8Uzvr/osJk/qasxzPn5/q9Csud\nd0i6vffnpZL+UNKdvdvfL+myrMc6p+f/eIUrfD4n6e7+6y3psZL+StJ9kj4q6aKsxzrHY7Ah6RFJ\n2wduW9nXX2F4fFBSS+E8iZ+Ne70Vrvx5a++94E5JZ7Ie/5ye//0K54/03wN+v/e1P9b7vbhd0mcl\n/UjW45/Dc4/9WZf0pt5r/wVJL8l6/PN4/r3b/0DSz4187aq99nHnuqX63WebGgAAgBSs4uU/AACA\nhSNUAQAApIBQBQAAkAJCFQAAQAoIVQAAACkgVAEAAKSA
UAUAAJCC/w+oKapj4mrseQAAAABJRU5E\nrkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "7XOgXqCTmaPa", "colab_type": "text" }, "cell_type": "markdown", "source": [ "## 3D Convolutions" ] }, { "metadata": { "id": "QNvSiq5-mcLd", "colab_type": "code", "outputId": "eecbad0f-f443-43c1-83d6-f8fba22c7383", "colab": { "base_uri": "https://localhost:8080/", "height": 530 } }, "cell_type": "code", "source": [ "# Random 3D kernel - HWDIO layout\n", "kernel = onp.array([\n", " [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n", " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n", " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n", " dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis]\n", "\n", "# 3D data - NHWDC layout\n", "data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32)\n", "x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j]\n", "data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None]\n", "\n", "print(\"in shapes:\", data.shape, kernel.shape)\n", "dn = lax.conv_dimension_numbers(data.shape, kernel.shape,\n", " ('NHWDC', 'HWDIO', 'NHWDC'))\n", "print(dn)\n", "\n", "out = lax.conv_general_dilated(data, # lhs = image tensor\n", " kernel, # rhs = conv kernel tensor\n", " (1,1,1), # window strides\n", " 'SAME', # padding mode\n", " (1,1,1), # lhs/image dilation\n", " (1,1,1), # rhs/kernel dilation\n", " dn) # dimension_numbers\n", "print(\"out shape: \", out.shape)\n", "\n", "# Make some simple 3d density plots:\n", "from mpl_toolkits.mplot3d import Axes3D\n", "def make_alpha(cmap):\n", " my_cmap = cmap(np.arange(cmap.N))\n", " my_cmap[:,-1] = np.linspace(0, 1, cmap.N)**3\n", " return mpl.colors.ListedColormap(my_cmap)\n", "my_cmap = make_alpha(plt.cm.viridis)\n", "fig = plt.figure()\n", "ax = fig.gca(projection='3d')\n", "ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)\n", "ax.axis('off')\n", "ax.set_title('input')\n", "fig = plt.figure()\n", "ax = fig.gca(projection='3d')\n", "ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)\n", "ax.axis('off')\n", 
"ax.set_title('3D conv output');" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)\n", "ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))\n", "out shape: (1, 30, 30, 30, 1)\n" ], "name": "stdout" }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvcnLbWle7/l52tXt5m1PFxEZkZFN\niVkXi4K6WJC3qEFiUYLinagoWorgKMdOFMmBIjh3Ig6cJE5qpCNBKP8AKfRSpqhpNhFx2rfZ7eqe\ntgbrPZGNV0zNNE5ExvrA4exz1l6btfZmfddvfX/NI3LOzMzMzMx8MMhXfQAzMzMzHydm0Z2ZmZn5\nAJlFd2ZmZuYDZBbdmZmZmQ+QWXRnZmZmPkD0v7J9Lm2YmZmZ+bcj/qUNc6Q7MzMz8wEyi+7MzMzM\nB8gsujOvlM997nP85V/+5as+jJmZDwzxr3SkzZ7uzEeeX/mVX+H111/nd37nd171ocx8fJg93ZmZ\nmZkPA7PozrxS3nrrLf7iL/6CL33pS/zsz/4sv/zLv8xyueRzn/scf/VXf/Ud7/u93/s9fvRHf5TT\n01N+9Vd/lWEYAPjjP/5jPv/5z3/H5woh+OpXv8of/uEf8uUvf5nf//3fZ7FY8FM/9VMf6PnNzHw3\ns+jOfGj40z/9U37+53+e7XbLT//0T/PFL37xO7Z/+ctf5s///M/5p3/6J/7hH/7he7ILfv3Xf51f\n/MVf5Dd+4zc4Ho/82Z/92X/U4c/MfE/MojvzoeHzn/88P/mTP4lSil/6pV/ib/7mb75j+xe/+EXe\neOMNzs7O+M3f/E3+5E/+5BUd6czMv59ZdGc+NDx48OD913VdMwwDIYT3/++NN954//Wbb77JkydP\nPtDjm5n5QTCL7sxHhnfffff91++88w6PHj0CoGkauq57f9uzZ8++Yz8h/sVE8szMB84sujMfGf7g\nD/6A9957j9vbW373d3+Xn/u5nwPgx37sx/jbv/1b/vqv/5phGPjSl770Hfvdv3+fr33ta6/giGdm\n/jmz6M58ZPiFX/gFfuInfoK3336bT33qU/zWb/0WAJ/97Gf57d/+bb7whS/wmc985p9VMvzar/0a\nX/nKVzg5OeFnfuZnXsWhz8y8z9wcMfOR4K233uKP/uiP+MIXvvCqD2Vm5nthbo6YmZmZ+TAwi+7M\nzMzMB8hsL8zMzMz84JnthZmZmZkPA7PozszMzHyAzKI7MzMz8wEyi+7MzMzMB8gsujMzMzMfILPo\nzszMzHyA/GtLsM/M/JsYQ6DzHiEEjTEYpQDIOc+DZ2ZmmOt0Z34A5JxxMeJioHUeqyehjSmxLEpu\nxpYhOpSQPKhXlMrgUyDmiJEaJdQrPoOZmR84/2KEMYvuzL+bnDMpZ/bjgI+JvXOQEudNjUuR3Thw\nCANLa1naipAiPkWWVvN0uCbnjBKKtxePKJShDVtSTtR6iZXlqz69mZnvh1l0Z36wjCFwGEc673Ap\ncVHXHMeR7dDjUmTMHoXgNvaclzX3ygUbt+NZf4sXI59evkapLGN0+DhQ6p5j2C
IyKGn49PJ/QuSe\ng/saQihW9jNYtX7Vpz0z870yi+7M90/Omd57hhA4upFlUTDEwK4fKJVmE3q+dnvNNgycFjUPqgVe\neG7GAwM9hRRYrRlSy3mx4txWXLlvcN0/Y2kkn17+J5QwtOEFMvcU4golCjIJISSvN/8n3v+/OP93\nSLlmUf0faPXg7tgGoJh945kPC7Pozvz7iSkRUuLoRkJKxJzZ9D0nZQUi85XrZ1x3HVJArS1KwZAD\nVimu/Q3P+z1RdJwVax6WZ3T5Gc+H5xSqY2UuWWiJ4ECtVizVgTE+wYWnrOwp96v/jCDS+7+npqMQ\nEaUfkVMHInNS/1di93+T0y1CnmGa/wupXyfHa8gdyAuErF/1Vzjz8WMW3Zl/H2MIbIeBkCKbYeCy\nrpFC8OSwZzsO9NERQmZMjkMaOdU1o+j5x8MLutRSSMvD+hSlB3rfIUTHkLe4NGBlR63PeFiccgz/\nDRefslSg5QMurQYOlLKkFs/Iqcfm51j9Jovyx8nxihS+QiUSWn8Wqe6R0xZyxhT/C4z/DyBBlIjm\n10HdA/93QAD9SYQ8f8Xf7MwPObPoznzv5Jw5jCNDCOyGnpNqEtrr9ohPkUjmvcOOTd8RiTS64qwu\n+erxim/sr8kEam0R0iOkZF0W7MIzbv01mpZC1bze3EOJG4awQ4trFBkrByrZY9QlZ0rgw19TiANW\nGpR8wJkSxLSjlAqLQyAocUj9CG3+ZwhfhfANlKyQ9sdBNuS4JSMQcgnx8d2lUCIXXyQB3v0VIDDF\nf0aph6/2i5/5YWIW3Zl/nZASIUUOw0gkYYTiRddilabSmr+/veZZeyDkyFJb1mXBxre86Dr2Y88x\nDex8i1SCR/UKJT3f7J/Qpi1WKRaqYG0cSQQMkSyfQT5SyR6rGh6V98nxq4S0oRYtRhWspEOLiBYr\nKrGH9IxGJBAFSt6nkZKQbrEopLAoLFJKhLgg67eJ4SuI+AKhTlH2v4CsCOE9kijx2SOEJpMRSKrF\nF+nCY9rx/0PKkpPqf6fQrwGQckTOpW0z3zuz6M78y6Sc8XGyD3JO3LQdi6JgZS2PD3seHw/4GAHQ\nUjDg8TGSgW+0Gx7vb0FmVrZkZQ2HfORmOBBFT0yRwowoJbi0J3hesAvPMLKllIqFrjhTA54WIw6U\nsqOko1IeLRasVU1K72AZKIXDyIq1kqQMAosUR2QeqIQGoVHiDVI+EvIGLQSaBqnWQAKxYhRLXHxC\nTgeEOqcs/jeysHT+azhWBAqsfkTOIymPXCx+gWfDP3HwzzCi5I3mx1na+4QU8NmjhcZI80p/v5kP\nJbPozvxzcs7sxoE+BDZ9x8IUrMqCq7blMA5kMrf9yGEcGbPHIDlvKh53B/5xc8UYPUYqjFYYDS55\nrvo9t2GHEz1GwaltWBvBLl7T5wOFHDFScGI8WioqJUnpBVLsqMVAqWAhCmrd49OeUg5URBZqRAGG\nEiU1Ph+phKMASlVSUeAJxJTIwqPIGAFCWKR4kzY9xecjEkkhGoS6T8KRWbDLFSG1uDSi9eucVv8r\nPktux3+kz5eU5jXW9hE+9bjY8lrzX/hG+x4xR4QQfHrxGU7tKX1w+ByxUlOqWYg/5syiO/MtXIwM\n3nN0jpQTi2ISWhcjjTY8b1ve3W/xIWCV5mLRMMbAk8OOq74lkzn4EaGh0oqc4clww5NhT86BwmjW\n1qCUp0s9fTqA6DHSUyjBmS4RqsOlDVoMLJSjUoJTHUk5oERCscfKgVIEGhWppUGkHp89lYzUItGo\nQEKh0QxZQ/YY4bBSsJBLMpI+OhwehUSLRBIJJWqS+ARX/jEueaQwnOiKpD5LF4+EXHHMDZDockWj\n3+BR/T9yiIGn3Xt4TvhE82lWZoVPniH2PCzf5sVwfP9Ke7O55MTWtN4RcqZUilLPQvwxYhbdjzs5\nZ0JKuBjZDQNGSW77npQzF3XDTdfy9c0tLk
SkFBRaY6ViMw6EGLkdep50O0JI1NZwWpUcfM+77YFb\ndyThkUqglWBhNX0euXUbhtxRqkhjDSdaEhkZ8h4lBkrlsSrQKEEtIeWWlEYK7TmTLY2GQkKIHiEC\nhQhY4VEis5YRJSxtCkQSSxFYqYgmE5CQBYdcIIgIApUULOQJ++Rpk6NPioUUIAwuK4yo6MUjrtwz\nxixRwnJpT4jqR9m6HT4X+LygUgWeMxbmjDfrT7D1jnePLyjkKT+yfoOlKSdffBy4X50xxogSgpgy\nl3VDYyyDD6ScMEpR6Hn8yQ8ps+h+nMk5sx0Geu/ZjwMCuFws2A8DL45HQkoMzuNSIjC19i6N4egc\n/7i5Ztf3FEZTGkNpNYN3PO+OXI8tbfQgEuvSUmvDs3HD9XAgKY+1UApBYyRKRg6pI3MEAkvjWRuN\nlZIxHkiMrHVHpQNaQCUyWnhCDkCikIF7+kilBCFLXEoIkVlIhyKRUZzKSJ8VuyhJQnAmBhY6krOm\nz+CzYMgaScahaQRU6oznPnCIGZ81Z1oS5Rk+G4zU7OM9bsYb+lRQSMsb9SOk/CTvds/JqaRQDZfl\nGisvKJThteqUdgzc9C1L0/DJ1TmVMYSUaJ3jvKiQQiKFIKTEwhqM0pNHnjNKSoyeE3Y/BMyi+3Fk\nDIGDG+mcI6TEaVWzGwb2rqdUhs4F3tttiDGTRGZdFZTa8ni/49l+R8qCQMIqRW0Mo3c87VuuugN7\nP1KWmpOygAQ3vuV2PNClkUikNpK11QgFt+7AEDsK46ltotSGQiQyPTEPaOEQMnNmHEsNPsMQA1o4\nLsyeUmcEmoJAIJCSIApBJQL37REjFIegGJLAishaDQRhyWhW0nEdDGMSjNlwrlqWWtDlBcfg6ZJF\nikwG+rxkpTSluuBrw8jRQZYFD2yNNq8zBMg50Yc1h9gRYkWtDT+yehPJCV/bX1PmgkVZ8ubinKVs\nSAIubYMPgS46zkzDSV1htSbGhIuBxli0kog7IdZSopUkpenyE0Ig5dxp9xFjFt2PC9+a+BXZDj2F\n1hzGkePouGwqxpD4+6sXjD5glMRozUlZsRsGDv1AHwM3XUeKGVMaFkrhU+Lp8cBVd2CMgaLQFFJh\njWIfHO8dNuxdT5aJqlCclAVjcGxcyygGlHZIFEsNTSHpQ8+QB8iRZTGyNAktNSkEknBYNaJlRAk4\n1Y5Ge3bBMkSohOey2FPITMwGRCAEEDITsqYSgXu2w2fDtbeMSbPQPWvlGGlIWWKF44WvCBlcspzq\ngVNruXJn3PoeFw2FjohcMHLOUlu0WPON3tGNCasa3lpecGIesnUBnxwmLYk5ITEsbcVn1vfIXvH0\nsKdWBYvS8KnVOYXUpByptSUDMSdqYyjvhDelRMoZLRVKyfd/05ei+/J6FULMLc8fbmbR/TiQc+a2\n73AxcnSOIQQeLJYM3vPubsvopmYFlwMlBp8iSIFImWe7A7dDjzEaIyW10kSZebrbcdOPOCI5RVZV\nwaKyPNkfeHzY06epprcpDaVU+OS5Dke65IkxYE1mXSlqLdk6xxCPSO0pTcQIyUIJlBoYQiTmiBSJ\n02rgREfGbHDeg8g0esDIiECx1iNKeG59hYuKWjkuyyOFgC4WhBwJWWJVZEiGksSl6dmGkqtQEpJi\nbToWOuLzGUNMiBS4jRURgUuWE5M4MwsejytuhiM5GxojaNQpmTO0FBgWPO1HCIKlWfEjZw+4b064\n6nr64FirhoU1lLKg1Ip7VQ1I9u30pFGWigerNRqB8xGrFVJIhASjJtEVQkzT3FJCSoGU37nuwCy8\nH1pm0f1hZgh+8mzvfNnLuqb3nhdti5ZAFjzZ7wkxooTEFJITVfH82PJsvyeKDBkqq1mVJYP3PG9b\ntseBox+xVnFal6SYuR06XvQthzgSc6KUmkVpidnxojtyNbREkSgKyaqAQlja1NOljpgDWSRqKzkp\nppvEfn
QIHNZ4FoVHS4PJGSkGhpTIgJKJEzNyaj07Z+mCQJFYFQNWTZ5tJT0pJ7apwgdJpT2XRYdB\ncuNLfMokBKXy+GzRKJa659pV7EKJT4KFDpxZyZgv2Y6emD1DLEBoUq44KyrOTcU7R83N0FJgWZSG\nT9avIXLDGB2Wkt4HFrLkzDR86uyclanYdR0hJc6KmkVlKaRBApW1SBLeJ7RSaK1YViVCQEpThCsE\n70e63x3hzqL7oWUW3R82Us6MITCGwG4cqI1h8J6bl4NocuarN9fsR0clNVkJXlsu6cfA88OeLgRS\nDLiYOa0qhJIEH7htO150R0KK2EJTSkNtDUc38s52x2boyBpKZTipLTEE3j3uuOlbvIpoBauipNSa\nNrXc9I6YRpJKrEtojCWlzDEdyCmgdUTJTGMlCxPoXaQNGU2gKANnZgQMIUDMjgQIKdAC1taxMgNP\n+5ohSqxMLMuRSgRCLtB4+iRoY0FMklInLoqBmBXPx4qcEglBoz05FyAqVO659RVt0MQkaaziUd3g\nwgnP2o6UM5mSRloqfcLaliyM4cVesB2OLHTDeVPwn05eIzhBFwYWqibFyL2moVElF6uKQhi6cUSQ\nWJQltS0wRpFiRmuJlt/p6RaFJqU0JduUurMc5Cy6H15m0f1hIedMzJnbvsfHSO9HDqPn4XJJyomv\nbzbs+5FGa9oYOC8rQsyMcaoXbbuBTd9jpcIYQ2U0KgtetAeujj1ZJASSZWmnRonuyLubAwfXgRQ0\nheakbDi4gRdjy25sOY4OrRVNUWA1HMeBjWtxOaEsVFqwtIosMnvf0YcRQUBrwVklMFpw6B0xB5CB\nygQKCYXUGDlwjDB6iVSJhY1cFAOD07RBkUVAiYyWCSUlS+3RIvDELXBhshjW1lGrSOdLHAkXBTEb\nAoJKak6LQB8Ez4aSkBJKwtomSk4YUoHLHcehwMWMlZaFXfDG4gQ3Wp4cjwgElax4vTnlxDQYpdES\nYi9ox8BJVXHe1Hzm9JR+CFMkbC1KwGXdoLScRFcr+t6jtaC0Bms1OWdyzmitvsPPBVBqrnL4EDOL\n7g8DnXPsxoHubqbt/cUSFwNPDzvIAglcHztiSlilQcC6LNgfB54e9ricaIwGqXjQ1LiYePd2y9GN\nuBAwSnFvvSSmxL7reNF2HLwjkyh0QaOn4qxn7YEX7Z6eRKENdaFoTMGm3/O8P9ClqUW4tHBW1SBG\nbvuRzo8ImbFFZlUorNK0YSTEHpcjBYnCwrqQRByHHnKOGBuoTaKQAIqYPaMXd/W1kaZInBc9B1ew\n9QpIFDphRUIJjZaJSOC6a4hMrcwnlaPKsPElQ4CYIQlBIQyFqShVohszL3pLTgmjFRe14UScsPGS\n/dhDVCAE5/qE03rJZVURBrg59hRScV6v+dTpKStdTmvE5YxOmpwDlS1Z1SX3lwvG0eFCwBqD0ZpV\nVQDiztOFEDzWTo0VWuvZXvhoMIvuR5WY0jQ4PEb2w8CqLHAx8OzQsiosEsHXbzfs+p6ltbiceWO1\nIsTMe/st7TB1aB39yMNmjbWK4+jZHTt67xjjlASzymCEoPWem+7IpuvIYhLVk6pBKPjmzYbrvsXl\nTFbTsjyNtTze7rgaDvSuJ5JYNBXLskKIyIvDgT4FAlCoRF1JKi0YYqQfO2KMCJNY1JKFkowh0oaI\nzAGEpzJQGH0XqTo6p4hkSp1YFQ6DYkiSkCMuqbslgGBhM0vjuB4snZ+G2hgTWegEuSSKjIuBnSsn\nMZSa0xp0yty4gt6nuxuZYF00FHJBItKPicMo0UlSWcNrqxWnsuGqixxdRykshTF8enFBrQqaQhNH\nQeccq6Lkoq5542yJStM6cRIopURpiVGGurSU1uB9QIjJ01VKYY1+3899KbQfxgqG747GP8bMovtR\nI6ZETInboZ+E13l248hrqxWZzD++uGY7DCyLki44HjRLcowc3MgQPcEl
bvuBZWEprEIJhUrQu8CT\n454soJKKwhhOqpJ2dHzj5pbOj8i7ltWLVcPgPU+3e3au4+g9SknO6gatFNuh5Um7Zd8PZAWFNSwL\nQ0pTbe7RexJxelxvJm/40PfsxkBMI1JLFtX0iI9IHHxHCgFyQtvMqtCUKrN3HhcyQgW0yNQSlNHI\n5OlCYgyKnAVlmVkXDoJg5y0pR1IWU3QtNYUWGDNwcyzovUJqgZVwWmZCKmgjuOgZvEFmQaFL1qUh\nEdgeDV3wWFGwLDQPixVaVByDw7lIdIqlKTgtGz5xtsZmze1+IIZIYwuWheFTpxeoJFAKEBBc4rSo\nOWksq2UFWdwl1CRaC6yZutW0lhhtSGkaOvTSy/2w2QsvrZCXFRdSyskOSwlgagr5+NQbz6L7UWI/\njuzHqRqh956H6xXOB54e9oSY0RJuuwElBEpIEJlaWwbveLKb3rMoC2JO3G9qSILH2z2b/ojVhpQT\nnzg9RWTBVb9n242MMTDGyHndUGqFi4nb4ciL3YEhJIyVNGXJwhYc/cA3Nxv2YSRnMALOmgUpR949\nbGndkYBAW8mqLCkLydGPbPtAThGpIlWhqY2c5jiMHS5EIKAtnJRTZr+LIy5Esg9olSmtpLaKlBxH\nl/FZonPEFtDITBYalz1ulIxRIUiUleSkivRj5rY3CJFBCowEi0Frhc8D+97ggsIoSVMaVoXADYpN\nAB88AkkhLKflgtIYWtczHCVBRNaq4aRa8HqzJCdN7x1hTJRScVHXXNZLzhfT79DvR7KCWhvOThoe\n1A0pJ3JOpCSByKKoqCtLVVpSiqSUMFqjlEQp9c8i3pf8R0aXL8X0X9r28s/LkraX/w4x3x0rZKb6\n40xm9IGUJ6unMN+yTFLOyB+OKHkW3Q87PkZaP9XWtqNjXZb03vO8O7K0FpkFX9vccuwd66JgIPH6\ncgUx8+5hy6F31ErTRs9rzZLCanbDwO3hgJGKTT9y2pQ0xhJzxEVBO4zcDi1KCEpraQqDFoLnxyPX\n+yOd92ijOV/U1EXBpj3y3u7AZmhJZJqi4Lyp6Zzj6fHA0U3NE8YWLOuSxip2Q8em73AZlMgYKziv\nSjye3eDpR4dQGatgWWqU1vRhYBhGfApYBYVRrCvF6D3HkAghIYkUOlEZizaSYRw5OkFMCqsSVSmo\nNbio6GPEe4GPYJSgLgylSez6xGGQCC3RZAojKaUlZk0fHcMoiVHQmIJ1XbLQkl0facdEDhmrFOuy\n5n6xIkvFvusIAcjwRr3m/nLJed0Qxml5epkElbS8frqgNiWLwhBjZHQBI6EqKpaLkkVpGV1CqczL\nS7Au7OT53kW/OYMQ30qmvRRFIQTpLrL8t9gP364D371PSun96PXlZ758/7dHtt9uLfiQSDkRYqYw\nGikFKWVcCPiUMOquFTqmqT5ZSW7vyuq0klzUDfq7apI/Ysyi+2El3K0/dtt35Ay99+yGnkerFSnB\n310/59CNrMqCNgbeaJZ3sxR6Ou9JMbMfOtampi4NCYgxEXJisz/iEyxrg5SK87pmcIEn21v2Y6DU\nCqM1j5YrkoSvX9+yG3uCT2SVub9coAVsxoGrY8eu60kqU6qCppSkmLkajtz0PYN3CKNYVQWN1hyC\nZzN2DG5ASEFtNSdNQ8pTx1vnA5mM1YnGSAqrOA4eHz1DiEgNlVKsa03vwxQ9ukASmdJkSltQ6syx\n8xxDRpBQApoClNRkKejHyOAEGTBGsCwUQkA3CAbSNJM3cdfmXJBUoh0j7TglJQttWZcWjWQImT5m\nYpj2uSgbLuoFMme2g8eN0wCblSl4UC44uasaaceRGKeE3qdXp9N3ajSjH8kpU4vpZrdeVJSmoCgk\n4xDJMaOkoC4tTVNgjSaGeCe0GiFBSfV+sm26jsVdSdk/F6v/Xifbdwvnt7/+9qaM//4+0+e+3K6U\nxPtAiHF6GlMSraeZEilN0evgHT5l
Qkw01lAWBhfjtEpJDNTWUCiNi9N53m8WH2VveBbdDxs5ZzZD\nz9E5WueIKfNguWDwgXe326l0CcHtMNIoiRAS7rzJYZhqbROJRVGSROayqFBC82y/5arrqLUmicib\npxcYqXi637IZe1QUDDFxvlzSaMUYA5tu4BAG2sFRKUtRaiqpGVLg+X7PzdChhbrLrFu00rxze8PV\n2JNyIpG5aBpqW/Cs23F13E/dbiKzKGtOK0tC8KLb0vYeIRPGaJalpigU+0NL6xMuOoyWFIWhthIf\nBEMYiN6RSKhCsTIGpWDvA6ML5DANillYSVVUIDyH0dN7EEiMgNoCGAaRGH0mjKCkoLSapjY452kd\nDClj7rzSWpVTwo/MrvPEqKbaYFOyLDUg6AdwIaKyoFSGB82C02KBEInbXU9KmdOi5l694GJZY1HE\n6BlcJmeojObBasH9qiYhGUaHVNCYiqrULJoSQcZIDUxzGrTSU7KtsECeBF1KePkIn0FJibqb3RDv\nfGB9Z0t8Oyl9KzqdmjBeRqwA3/JkvxXNTrmGnACR0XdRdoyREBIxJ7QShJDJYrINumHkOHoG5yms\nZlVPT3A+RLSU7P1ITJk+BRpruKibSaBD4MFiMZ3bR5NZdD8sjCHQekfnPC5GVmXJ0TmujgdWZUUK\nkX/a3DA6T20KApnXViuCjzze7Ti4kcpYeu94a7WitJbbtuO6O2KzZOdG7i3WLEuLS5He9Ywhs297\n6kJjtMVYRSWnhof3NnsQGSUVZWE4KQyd83zzZkuXRsiaolTcqxtcSLyzvWbfO/rsUFJzf7EgkdgM\nA9ftgdaPKA1VWXBqKxLwrNszDh0+ZbTRnCxrCj2VXbX9wJgSVmTq0lCXBTF5toNnHBOCgCk1K6OQ\nSrEZPNGPZB9Agi30VNERHPsY8D6RI5QKalugjaGLI10fCQkUgqrUVFbjQqRLmRgyIYBUgpWpaWrB\nYfDTwBsEFoHVkqUt0VLRjZ42OlIU1MryqDqhNoreB/wQcTFNlkRhOasaLqqaFBO3uw6h4LyqeViv\nWC8qiIkUEj4mRJLYWnBeLzmtqkmU2h5rJYVSrBcNUk8iNP0lSBkSmdIoyqIAIIT0/rL16q7JQkqB\n0YoQIzFmhOB90fx2YZ0aMqao+WUU7UMgxqk70GgFWZByQiDwIRJinG6+KdNU5TQUaPS4mBhCoDKW\nMXgyUBvN7TBwfTzSpWle81unp2xcz9F5TsuChS2wSvFwuZwj3Zl/Hy9n2foUuek6lFRs+p7Bex6t\nVoze87dXLxhGx6Kw9HHya2OI7J1jN/SICJvuyEWzYlkVuBwZuhGjFLf7FiQsigKpFYuyhBh4fHtL\nFwO1KTFacrleY6Xkm9c3bMcBmTJBCF5brzHWcLXdcdsP0woI0XNSN1Rmyt4/P3Rsjh1dHCiLu0jM\nFByHkXd3Ww5Dj9SSUmvOFgtyily3B3Z+JOVEYQzLqqDUgsMQ2A57QspILSiV4qxZMqaBTdsz+ACR\nqV250Sgh2Y2Jfpy66KRISFlythQMAXbjSHKRlCJGw6pskEYxuI7dkMgiI5Kg0oa60sQEXXKMw9SR\nZrRkUZQYBQcXcTkRUwYUhYK1KjDasPcdY4yQBQtVsbIVVivwER8TbfBICQtd8vbyAis0226ahaGF\n4KQoWJYVldGc6RqL5HnfkXLisq546+ScWltcjOSYEDITE5TWsCg1p3WNEopNd0QrzaK2NEWBEPLu\nO55uJClNvm5hNGVhyHfLMQl694OvAAAgAElEQVSmCgL1/gSz6dxH5whxGrRTFtOoyZQm71zIPNkw\nOeNjQmuFVpKuHxl9IAZBUxm0UYzj9MQWU6IbPf5ufvPZokYowdWxpR1HDsFxv2kIOdOHQGUNicxN\n13JaNSgp+OTJKRdN84qv3O+LWXRfFTElrvsOFwKbYeoEu2gWHMeRr28201jBCNtxoNYagbibtwqH\n
1nF7PJKFYFUVZAlrW2KF5Ml+z6ZrWWhDVII3T0+pleHJZsv10FJkRZ8ij05XLMqC1o9sDgMpRW76\nnotmRWMVWcDgI9uu47bvkEJRlYpVMT3aXvUtzw8HYkggJad1RWMM237gvd2eLropqWYt53XBYfDc\nuJ5u7PEpUlcVJ7agsobbvuXQ94zJIZRgWdScNVOX2NXtlmF0JAHaCk7qGqUE+7FnGMCljL5LHJ01\nk9hu20hIDnJCGM3KllgDh9GxGyOEAAqWsqSqNFFMqxy7FO+WojQsKoVCEHKijYlAQgK1VCxNiROB\nfvREMkgwUlFoxVIW6Kg55AGfHSorTvSKi6IhEPExY4KgFwPIzMJUfLI8pxLFNAIzOCptubQL1lVJ\nVrAUloW0PB4PQOa0qXh7fUmt7yoiUkBLBWSsNmgtOasrrLZcHVqEypTKcr5qyGkaQ6mVwsU42QJx\n6mxryoKUEoML5JRATMKecgYBWkqcj4SY8NFTWUNpLaPzdN6TUyJnsNrgwiTmlVXcHHsOw4hLgdOy\nYbkouD60OB/wObH3IxLJdhi4t1qwKgzf3B+m8ZaFYV1UvHGyRoupquXhcvVqL97vj1l0P2g65+i8\np3UjCEFjCzZ9z7brWJclh2Hka9tbfAwsbUUm86CexPjZsWXfH1kWJaOPvHl6QmUMz48tT293VEIz\nxMDDkxXLoqD3gYPvESmxaQcaq2hsQVGVCBJt57jqpkRdYTR1pVibBp8C37jZ0IURIRRVoXmwWJJS\n4p3dtMS6DwElBJfLE4yVvDgeuDm2dG4AIVhVNVYIInA7tNz2Iz45ysJyVlZUtmQ7tmzbAR8GkpCc\nLmtObEXMgRftkbY7khQUSrEqF1gD29Zx7HtinB77y4VkqQu60XEIGd8GshSgNEsrqCrYO8FhcIQh\noQrBqtQUQuFkpg+RLgVEBKEVtRRYI3Ap08W7hJHMFEZQGT0t1ukTSXmEjkgJZapZyIYuDYzZISMo\nm7BojNFUWMxY08sDQQ1ooTkVp1wWa7ow4mPCRonTI1JJrNG8ZS5Zq5Jb17JxHaWyfKI+YVWV9ClR\nCcWJLnnqDoCgKks+tVqzLit2fY+LESOnumprDZHESVlSaDP5/mGydO4ta7RUjC4gJYScp2g3AQJW\ndQnAtuvJaUqEvSw7TCmihKLtR8Y4RfWF0SzL6ffYdz0+TFKxrA2dn/apCsOT/YFt13PwnotFzRsn\np3x9e8t+HKnuqhYWZYGWklJpHqxW2Lvpag8Wy1d2/f4AmEX3gyDdFYJ3zrEdB7RSvGiPKOB+s2Q7\nDPy3Z88IOVIpTcpMQ2h8YD84Nl1LiJl91/HoZM1pVTJ4x7Z1KODQD1gjqa1BakNjDeTEs+uWNnSs\nbYUsNffKmkIZ3tttuR4OaDTaKB4tV1Sl5flhz9WxhSQYo+OsXrCqNGNKXLUdx66n9YnKSOrSUqmC\nQxh4cWjZdy1aW4yCy3pBFJF3tgc27kjyGWkNrzULrJbsnON6t8eliDCSk6LkrG6IKfC4PdB3PTEG\nqqrgZFljEmxHx747TgkiBbYwnBY1jsh11+M6CHkS4otGYG3BTec4DJEUEwhFU0oqBdLCdsy0Pkxt\nyVJQG4UpBK0PuJBJerIfrM4YbZBZEpInyqllWSKoRYm1Ep8TvheU9QFjAlKB8CuquGZIA04c0RmM\nnmYxCAUGjW1P6Mwt2RyR0rLM93loTulioPU9tbQENY3d1FLyyJ7yUJ3yeNxw61tqrXlUnfOgWdJF\nDylzokteuOPk22rDp07OOCsrNkPL4OM0P8MW1GXBGDyl0TRG83R7JOaMUoJ7iwVVYTl2AyFOVQjI\nu1q0DIu7ZN3m0NF7j5CCB+s1PgYObY9RmuPoGGO8E3HBw1XF0QUe77eEEOmJvHV2xqbv6Z2nKSx7\nN9I5j7WKWhv+h8tLboeezTBwXtdoKfnM2Tmrsny1F/T3xyy6/9
G4GLnujoSYed62nFcVy7Lgtuv4\n5naLFpLhbmaClRKpJAkgZq72RzZ9T6GmvnshoNEFCni+3bNtR0pjMQW8cXZGLRWPbzZctQNWQBSC\nBycNl8sVx+OR521LTpn92HHeLFk3FSA4Okc3OnZ9jxZqSlwVFqs0N+2Rm7anD4HGaqqiZFkodqPj\n6e5AFyIxjjRVwcNmRRs87+w2HLqewPT4fFZXCA2bY8/t0E2zcKXk3rLhom4YXeBF13FwLUQoK81Z\n3WCl5PbYsxmPeDcikqReWM4WS0LIPDvu8EMiJJAFnBcFlanYuQO3fcbdlYWtSliWkiQlN72nD4Ek\nBFaALcFIQyLSZU+UU72YkYJaliAzfR5IKVOWwxTZGgFZ44YCKUaK+kChIikbVFyhdSLgGTrD5WKL\nNg6pJNGvEMM9XBrAbO5WflA0rJmW0gTTXeDtDcnuSRhWPOKRfsAxOI5hTyUqog5YaRESTs2CT5oL\nnow7no07rLG8WZzzcLFmzIG+9zxYrHg+HCYPW0heb1Y8Wq54dujYDkeM0JwWJYu6pB8DSMFKa666\nnhCnm8+j9YqmsNwc2jufF9TLSgTnpmSsVLy73TD6iA+Z10/W2ELzbHMk+EgfPV5Gcp7anx+tFow5\n8852i1JTSd8nTtesipLrrruzNyyV1lwuGrSQnFTlNC3vo8ssuv8R5Jw5OscYAjd9x9JajNY8OewZ\nfWRpLS+6I1+/3UzlQ1WJkop1Ybnpeh7vdgzB3y1W6PnkyRmVMTze7nlvs6PSkpwFb5yccFKWbLsj\nh2FEJNj1A7UpuFiUKGuARN8Hbg5HhE4sqprKFJTWQAg8OxzYDiNGQWELHpys0Rm+cXvLwftpSLZW\n3KtrKq256jqebreMdwnAy+WCRVnS+5HrQ8thcIwxYCvN2lQsrGHTdzzfH+nGDmMLzqqai2WDj5Gr\n/Z7OexyJWhku64a6KLkZ9lxtjzg3gBKcrhZclAt6H3je7TkOI9kDEh6uasqiZjseudp7hpBRCFYW\n1suGNnraYeQQp+SP1Im6kNS6YEgj+3y3IGSRkUlQWgk6M/YJYUYWzYDVkZQLsm+QMpBET0qJe4sD\nQkhKAzkLNoczrDxyf3WDkZE+VTh3ihKCRE/X1zxcXWOMJ6HwoWJs3yTmgCqeIxAkUVDFC6IUwIjs\nLsFuCXpPEJoqXvBm8Ra9d1z5DZUsyFqwVs1UsSAtn60e8HjY8bi/xSjDJxaXvL06pfOBXddzr1lw\n3XVIJYgZLsqKTyzXPG+P3BwHpIKzsuZiUdGOnpgTS1PwvD1MFkvIXC4bLpuaZ7s9u34k5EylNaum\n5PY4oKWikJKv77cMMeBC4M31CYvG8tUXt3Q+YI3EGsX5oqF1HiUky8IwhERpNVopzquS11ZrEKCE\n5LyuX/EV/n0xi+4Pkpc2wnYYaL1DC8k7+y2nVcVJWfHubsvfvniOQk11p0qxLqb6xJuuZTc62sHR\nhpG3l6csq5LD4Lg+HKZymxhp7mpFjVUstGJwnqfbFhc8F9WSstKs6wKTMleHnqvuSCU1UiserE44\nXZa82Gx4d7tDy8lzvVgsuFguOLie7Thw3A90MbIoDSd1jbGGYz/wYrdlM4wUtqDShtOmIOfM82PH\n1eFATFMjwsVqyaqqeXHY8XS3w4VMFpmTsuayrhECnh6PHIfJJ6ys4Y3zU4zQXPUHbg8HhhRICe4t\nlqz0VPL0pN+yPXbTxY3iwekJpTVctUeujz1DmIrtFxqW1YJM4pg6bvqpKUQbWGmLLRXHHKYEnY1k\nnbB6Gm4jo8Fnjyg6jPJUZSDEglplpPYcB8NpceDBeoORGZ8Mt+05MicW5Q4lIpfFgZANSgokicf7\nBzTqwGdO3wOR2fkFm/4hGUMWLe1Q8Wh5i5YBhyUnzf7wNhmBKv9/9t6rx7IlydL7zNVWR4RIdXWJ\n7umaxhB84P//DT1DEsWeLt
VX5E0V6qgtXBkfdlTNvLBJTJMsVN1yIBGBQJ5MnB3Hl5svW7bWDxg1\nLLWn1deoOHK5oPMLTHMmmZEshk63/Mz9glThbfxEIw1iLS/DFWJAa+UX3Svu48h34z2I402z5R/3\nb4i18P58ZGMDY00MriNpZnCer7c7Pp4nPs1nVGHnO7662XCeM6dpYRM8D5eJKkrUSu8cX93s+f7x\nwMO4hp72jeWr/S0fzifOS2TwjkOawQjWWTY+8M3+iu9OB85LWoG7bfiH16+5G0/MKXPbDzTe8bP9\nFX0If96N/u9bfwPd/7fWGCOfptUd69Nl5Mv9nmAtPxwO/Hg+gKxXqpQLWWDrA7VWznHm29ORaYls\nfOCm68l55XaXrLw7HzhOkdZZdm3D17trai18+/DA4bKw9Q3GwJvtlpvtwDjOvD2dIGUuS+WLqw1v\nbq5YYuJpOhNTXmNh2sCmbRn6BqPCcRx5+/SEPM+832537LqeMU58//TEcZyw1uOc8PluB2L4w+MD\nD+cTSqV1gdfbPSE4PpxOHKaZ0zzRuEDrhZt2w5QjT9PMYZwAuOkG3uy2OAPHmHmaTlymiLXCPrS8\n2A4ULbx9OnFYRmqsNMHyarfD4zjmmbvTiVNWXBU2neWm27KQOS6r3WVEaa3Qtg0mWGKMLDUy+YTa\nStsYfG7wFiazkGtm318ITaaxK1BfpgFK5uXunt5FhhA5xp4WofEzT0vL580jX2/uECBj+f78hpgd\nr4d7GonchAtjaanqEcm8Pb/GkflP13+gIDzlDe/Gr4i5A3smF8sujBhRknSgcDj/Aq0Baf8VEWEp\nA039EqMNSz1T4hbjKouMgCMQ+Mb/DCee310+4o0B6/imucVay3EZ+dzvmVX5YXzCOstWAv949QYV\n+O7xid4FlpK5aluqUajCZ23PMSbeHS8Umxik4Ze315xK5ONhojeGpzzjvGMqCVOFN8OWd+cThzTT\nOofF8HcvrlmK8v58JFhH8JZXmy2Ncyw1swnt2tyrlZfDSi98dbWndf7Putf/netvoPvvWaVWzin+\nSfa1Cy0KfHc4rJ1WC98+PfHd0xPGClehpbcBscqPpxM/Xs5oLrR2NRb/D9e3WBF+e/+JHy9ngvWE\nAr/Y3+Kt4RBnni4zZOWUFgbf8vlmw6bxHGLkcJoY57gm7nY9TePW2JcqvHs6cLpc6NqGrg18tt3i\nrOH7xwfuziMOoesbXgw9+77lYVz48enINM9kVV5cbXi92/E4jrx/fGIslbks7Jstm84TnOU8T7w7\nnBnnBesMm67hs/2OeVr4/nDgskSstXTO82a3pzOGU8rcnw6M00LbGG77HS92O5aYeH8+8DSeSLXS\nGs/rzY5N6Dlo5N3THeO4TlVdd4HN0JPUcI4jx2VkAqw1dE1DZy1RKocys+gCNhN8WGkEm1miIJrY\nXB1oQibYyhS3hKq45oxS+HJ4ZPARbyrBFj6MO3Ky/OPNd+ztRLCZT3GHFk/nJi4l8Nqd+LJ7WgcV\nRPh2esVhHvhieKSVhc4tHFPHVAecjdwte1Ly/MPuO6IazrXj7fgV07JH3QlMwYtS1VJ0QCUzzZ9R\n4xWp+Q4xlbkMtPlLduy50xM1GpwPzDphJSAIX9nX7P2WXx/fryPXJvBNd8tVGLibj2ykW28j04XG\nCaLCr3YvcBh+9/iIOINmeDEMNN5wnBb2PhBz5u18wQoUFX51e8OUCt8ejjTWkKlcdR1D43k8TwzB\no8ga1dQEklbedFucNzxMM11wNNbx9dUVLzcblryavG+fhz3+QtffQPd/ZFVVqiofz2fSM7f54Xzi\ny90eYwz/9f4T//XhE71zaBWu2oYsSq3w9vzIMY48TAtZC7/cvmQbGj6OBz5OF2pSYq5sXIOxcBU6\nRIX7dOHD4Yio4VWzZd96Gucpqjw+XTinSGvW2fuvr695td3w/f0Df3g8MYihtYaXuw0v+oFz
XPh0\nORFjYcyJrfN8drNDnOFwGXmaZi7LQjCO/b6n8Ss3fDjNvDtdyGnBh8CL7YZ923CYFv718Z4pJawx\nvN5uuO03XNLMh9OF8zky1Zlt8Hx2fcVgG+6nkeOyGtg0xnG9HbgOHYjl43jm4XQg5cjQ9rwZNgxd\ny9248HRZZXPihZ3t6NsOI4ZTGrkbLyylYFvPznf4YDkTucSFXBOpTXhnaVqQ3KA5UdoL+MSunVEc\nrYGuGbksAaOJX718R+8ijcncLXtK9Fx3TzR25qvmidYkeM5Yu1u2PC0D/8v+LVszgxQ+pg1zaWjt\nTEbYijynElcyyvvlinfTFZ/3B7ZmoUrhMfWcywYxmSkHHtOOb/p3ZAxTbXg/fcZlfkX2JwwzhUDR\ngKm7NYw+btH0hpN7h5pMroFNecMb84q36YmlJIJ0RNbMtqVWXtsrPg9X/O+nj+RacNbzVbPli37P\nu3FEa6URx32c6INniZkvhz075/nt8QnVTKnwsh242nT8+HDAWsEayyUn9k3HWBO3Xcve97wdj1RR\nvLG87Hq+vL7i9w93TFnZNYF9F/jm+oZjnMlV1yLBWD7bbv5W6f7U1uM0rmGPKbOUzBfbHblWfvf0\nwJjWqPLHaaaqMmvmRdOTa+HH6ZF/PT5SFXrTcNN2XNLC4Bse4pm76cwpLTTG88Ls+Gy74Zgu/HA6\ncImZTgOt8Qxdw5tu4DCPvB1PzDHjVPg87PnqassxRY7zxLIIacmrvKsPvOp7BMuH05G744WA4Xpo\nuRp6gnMsOfLj45FLXAguELzjs6uBNrT8690dHw8HBME2ls93ezZNx/vxzMfDE8tSEDHc9gM3245L\niTydJu6OB8RYhsax7zd4azkvC6fzhaQQVLndbLnpW1QsH89HjpeJ/Gxy8mq7JzjPaV44LBOP0wiq\n7Nqe226L2sLdNHKaFuYasXY1S6ExSFYueeFoJqoB34En0Blh9usIdGMX+u2Ed4UhRC5LT1ng5dUD\nXZh43VyI6pDiuG5OzNVCqfzP++9pTKKVzF3eco49L8MTg1t4ZRNWlFwFbxKXGvgw7/hPw5nBZDKZ\nh+I5pg5rJ4wUrLZcu0RUZdbKMff8YbzlRXtk72YmtTymjlO8prrVU+Ld9IKX3T0Fy1IDT/MNl+kL\nZnsBe2HJPaU2dHrDQiJnQ0iveTJPJCJaA1dyw8/9G36cHnksE71syFRu/cBYEh2BX/obfjs9caoL\nHsubZuCb4Zof5xNTzAzGc0oLm7ZhipGtDbzuev7leECt0oqlsZ5fvbjhu8OJMWWcU1of+NnNNec5\nMuVM8BYnhq7xBOvwzvFi6LgskYxy03W01vF3t7d/yb4L8DfQ/X+2YinMOTHGyClFdqFhzpnvDgdu\n+o5Dmvjdwz1vT08MTUsnnhdDx9185j5d+O7whKnr1NSokV9tX7Mw89vzW+6mE4EeSsPPt1csmlhy\n5mk5kmplyQXvLd+4V3gDj+nAoUxMkyFow8umB2vY+ZZY1nSHU1xojeMzt+PNZkO0lR8fT0xjxGEZ\nOkfXBr4c9hyWid/c3ZFyQah8vt3zxfU1T9PEu4dHxpTJWtk0nt1mS9sa5hj58eHMaZ4Jdk2pvd70\nKIZ3Dw8c5xm0MriGq35gM3g+Hc48nGfmtDAEz812w2f9jnNedciX52bMJgSumw1t40la+HS6cBov\n69Rbt+G6C4yinKaFKS0c0oS4hq6x7HyPauVTOXNMC4SKC5bGGcQXSoaYFLc74prIJiRSbvDF0rRn\n1GR2fuLVcMSbwt7NnGLHcWn4D7v3bN3EK39krp5SHddunQrMxfH37QFLJZjKuTQccsvenmmtspO1\n8RNVMZIpGD6mlq9CoRUlauJYDY+po8oaKb8Uz8ZmFgxTNUT1/Mv5NdtmZLAjkzYcc89xeUESRSTz\nfr4l+AlVIdaGlDrG8XMWk0j2TMottQRueEEUZc6Rod7y
WC5kU9Dq2EvPPzSveT8e+ZQnet+iCT5r\nt6TnEeAvuy3v5iPnmDFi2DrPL3e3fFhWx7ldu3qDvGhbVAzjvLBrAktduX/jDDkXXg09YyosZLyz\n9Dbw9y9vGEvmaV653yEEvthuuep6Ui3cdj2Nc39mRPh3rb+B7r+1VJVYCj+eTojA47iOsH6133OM\nE//543vu5vNz0wte9T336Uwrnt+e3nFIF+6WM61tedNcs/GWH6Y7pjpxzjM5mZWCMJUX9opiTpzy\nI495oqSOUK64DT3FzRjgMZ2Zl4KtBvHwhfmMbeP4ED9yN5/R2BNyz8umZ+jXDK+76cK8ZFLJhOD5\nj8NrLMrHdOHxMhFToVPHTd8jfh1lHeeFd8cTc1zom5WX3TeBYyx8d3/PNE8E6+jbltdXV0Dl7fHA\naSrEMrGVjtdXG0JwfDqOPF5OTEtcm2DthqttTynKVBKfHk6oFtrQ8LrfsBs6HqeZ4+XCmBeM8Qxt\noGk8nQTGuPBxPjPmiHUN+25NGD5L4jwvLDkTQ4RgCI3iq8eqMNoL0US6LuGdYrDsupElG2KEz2/u\n2YSZ6zAx5UDKnlfNAWcjezvxIpwxUrm2E3PxPKaWn7f3DCZxbStRDak6tiYhUom14aVTBPBiiAqH\n4mjNTDBCYHUkiwqVlaL4lA0bm/EomcpYDfdpIErGSeFSA8YosTqWGqgY/nV8RTWVzi2MpeVSApf5\nFVEtxUwclz1ZFEHINUBxyPySKHDQM1UbpBpec4Mznrv5ws4MjDmTUYI4nBp+Hm5YaubH6cLGOZak\nfLHp8TQ8jGeu24GpJDLKrvGkVPis31JEeYqRbRNItfL1Zo9xwo+nA6JCaBxfbPfs2pYfjkesNbTW\nsms7rroGZE2KfjEMKMqrYfM30P1rXKVW7saRMSWe5old07JtAqdlXn0RrPBuPnKcZ6aS8MZw4wem\neuZ34yfulhM1W4Lx7JuGUxm59T0f03cc48TETGNaXL7liyHwUN4x5fl50KAlqKNpK1fyiiTvmcoT\nkcqSBnR6wefDlpEnpphYaiEtq7C8CZYXvCKXhbN54hhnytzTlx23fQduHSy6G0fGacGKwQfLN+1L\nWm/57nzPh+MZW2WNBN+0bJ5lax+OJ8YlAsp+0/HF5pq5ZD6enjhOibSsG+t2syX4hnGZ+XQ+MKWK\n09Vs+/VmoAgczpE5jcSk7PuWF0OPc55YMuMcOcWJWAu7dsvLvqH1DR/HM8fLmnRsgiM0HucdVF2V\nGXVmtBkXlMZ7OgKLmxlTpJqKGxaML2yahJZAnS395khoRnZhJtiMVsfL5kxWSEX4Znigt5FbdyKq\nZUmBN+GRYAudJDYmIQp7WykYLsVyZTPeKBsJVCBVQ2sAlKRC/3w1NggVuFQoumBFqaxa3wgUrYgo\nT9mTqDhTmNWR1HAXdyzqsLZwLi2LOmJ1jKVDpXI/X3EuAe8ycw3EYknzNVo7LnIh5sBUPA6DqT21\nKl3ZoDXwqZxxdY05+szuubEt7+YLXgxGPKVkBteSUuVNP4AqH8aJbbOqca6f897enc54sTgjtK3n\nzWbL/XmiaqF1HusdLzct07NHw6YJLKUQnCNYy1Xb8Wa74d35jDGGqzawbzt+cX3zl+wwBv8G6P5F\nHyX/o2vOiVJ1nYZBGYLnaR65H0eKFn5/vOOfnz7xkEa+6q4RDF9ud3x7uef385Fvzx+wYljqwqYJ\nfNlfM9UPXMbvOSwLqfZkGj5vXrLIiHUHDny/TkD5wrVvyfGKKzdS3TsK3zElx1x7Gunpuoh1M+K/\nJ+iZ6gWbe6b6OS/blmyeeMi/5VyVGj2Bhn6feIFjygee9IGpZIr29H7P19sbznXi98f3jKUQS8KJ\nYXs98DLsGaeJb+8/cFgmQnHcNhuuNg2pwh8eP/F4HqkOfBWutztuho7LMvHt/QNLVAxwPXTchoap\nFr4/HplLoS6VoXO8
2u/ZPjtgvT8/sqSIMbB1A5/vB6quTmHvTiemUslGGfqOrQ+oFg5p5hQT0WSM\nN2x9QPw6enpKC0kWSrcQQsWKweV2deiyC9JlmnbEm0rnKqU6luxIThnahZ2d8ZKZiqcYi5GVA7ZG\nMVLZmERlTR7uNWOlsrcVb/5beoMAIoWkishqTA5K0Wfzb614yhoVBDiEKiAqZF1jbLyJiEJWS64W\nK5XBjqQ6IIAzGVFYSmApDmsT3o805lnlEAPGVoodORWoUsmyekbYNNBoy309c6kHcmlwKmwlMGbl\nWCe0KI9pwhmHlMit7XhpB75fHvn2/Mhg1wilvQuc0sKcCw/jgorS945UIMXKHAuTJlrriFRcLmxc\nx+P8RCqFpRSGxvOzq2vOKXKIM/lUGELgVb9h27aUWiiquL9s0P2/XD850H0YRx7nVT/69njgy93V\nqlKQyj8/fOBcF4J4OhP4fONINWMl80+Pv+PjcuZczrywL2hdx5sQeDt/4IfpHWM9kavHO2XbRCQP\n3LS/Qc0HkiqPy56kA73Z4NwZ677nKvxArEJhNXN50D2ftU8E/wNjNRzihjlt8DowdJGtOeH9bwiy\nINXS5o55/Jq9dVzkjg/1gdk48hLopcdfKW1U3qUfOOqR1IPJDdu853U/8JQXfv3wHXOqGANDaBn6\nll4CPxyPPMWJpJXWWV60A1dtx6fpwj9/OpGjgihXvaMLDYLhh/OB87wOSPTB8PpqSxcscyl8ezxQ\nFqUYGJqWbevxxnGYR05TZCTiJdAPDb3xLGTu55GUIqNUaiP4xuHF4IplTIlRIuoV2xYcjmBWr9ip\nJoRE8IkhLBgjxNwSTca5yLZZCGENrXRUklqKGqbq6Gxk55ZVN6uOogZQvBRGXY3Rr42SV0MuYlVU\nYHWOeU5UePa6RStnTSTKM8mw/rw8m4SLFtIzEFdWAEchqxAJzOJwRqkqTKnB2kKQGbUtEUdVhxGl\nVGGpAgaKJKqNq1td6oYjC5sAACAASURBVKhSmczC03IkasVUT9WCV4etjlgnjrVw1IwTw2vbccqF\nQ5xpzIFzrgQDC+vtqbWOh2VmqQtVFKnQi+OsmUkzd9MZawyvNhtyWWmvf3n4iHn2x7ViOMeZ7w9P\ngNAEh7drRpozsvr//nfRP3+N6ycFuqkUHueJwQdEhN4H3h6fmIm8G0/cLSMHnfn5cI3DcCkX/o/T\nW5Z6IWehcy3eK9YueD9zLH8AGVE34/MthsDnrbLwgdb/jkpiKQ3eFD7vL8yx4xfDP9HYA0kt7+cb\nct3RSkvbPLJ1/8xtuGeuHl8dex/54XTFq/ZH9s175t7yIV5xWnaEuqFpFly9wzf37Mi06silYTJf\nM1jDST/x5D8xakNNgZ4WvzXImPjd5UeSW4gdhNbzghuMgU/Tme+WB2oGz8pfVydMy8ynpzNTKqgY\nhsGwtz1WlPeXkZgqtiqIZd8IXdMzlsj9MTI/j/FuOseN7/DecZwmFp3IdQ19bNse5wxGldM8c8oL\n0VVw4INlMC0LkXHJxDqhnVI8q75UDDY7lpghJFwouKasQY/FkYuiklfgMAZvI1Uh14ZJIq1JtCaR\nxFJrS2cy0zMgH2sDCF4SnRQclaKClRUY7rRQEXZm3UxWYCmVLIaFNWxzxdI/Vm3KqcIZWStk1tcl\nFUTXrxfWCcCslso6TaiiHEvHXBwVg6UypYAaEFOoNhOLJRcP6qhVSZJXk/YCSQui0JSGWISTJMZ0\nIFVoyppMogqTKaS6AuinZYICt01PVctTnHl7OTKXzK5pac1qavQ0TyQV3LOHwhgjd8cLxgvOGza+\npcjq1+udklFOS6T3npfDwLZp+XA68U7P9DHyetj8peej/ZvrJwW69b8P71NlGxp+ffeRmRlRw3XT\ns3Wej8uR++Ujj+lArJXMhc/7NwRjUfOR98tbao1cUoc1LZ1xuG6i0XteNN9TKXiTeY
o3VAJv/JHW\nvWPb/YaswlhaDPB5e+RgDP9T/19W2RGO9/Oei/YEbejDkd3+f2XjL1xKQyXwdffE93nPi/5b9u0d\nl03g/XLN47yn0T3OjZTwAWlObKoS1LLrEvP4GbVUZn1k7jLJNEhuuaWjhMrjdOR8yWSXkd7Qi2Oj\nPbEoD+PIkhWrQuMMV02gOMPdZSIvdQUFDxtvCdKS8sJpvjAva1Lw4GAbGsQYDnUhnc4khWp0jcUJ\nHothzDMPc1yjYZzBBYs3hgJclshoFxZfQdYgy0YdWjKpwGImjFWcUYJRSrFINUwU+m7GS0EMxGxp\ncUwVgsmcS2BUT28iRhSkMFVHbxYywhMdosqVrUwqtGSeFKbaIZLpJNGIUgCHUhQ+aWWpQie6mu3I\naiEZCVxUWVAC+gzESgYu1fNQN6hCxhCkkKulICS1HLVDETKGXC0Zg5jMOXdEdYzV46WSiidXKLJG\n4OTqoBqkeFQqU40kNaRSMdkhojTGrqkjOZL1glahE8tgHKcUuY8LjW0Qa8CCFkFrpWZQo0QDU4oM\nvoVSucQMHbhsaa3laut5d77w43Kgbxuu2o6fX18xpsTH03nlmr1bByp8APn/Ntn4z71+UqDrrQVV\nvn16WDO7zmc+2w5c8AQxvDtfWJbCb04fUE4YE3jRtGSxHMpHWnlPyhNOKo2d6NoNosoX3Q8Yuac1\nicc0oLRYMrfhxN+Z97wOdyhrA+UpbVmwfO4vbN2Zrv2OqVpmbSjqeNOcaI3wq/Y9iHJSx10cOJeB\nTg1tOLG9+t9wJnPJaxX2ZXegpi1b/3t27Ylp6/k4XfG07AllQzYLxb5DuowvBqmOXZNIlx1jHFnq\nQm4qiKfVlisGjnXkbj4yJ4MYJfTrBjLJ8JAicSyUCs5bOoE+eGYtHMeRkgBraFsYrKVxjjEmxqTP\nEfJC5xydcxSBc5xZSgKUaIQmeLxVHI45Z8Ya16usl7WaNA5KZqaiUijtCqjeFkwOlJKpLhJ8JBtl\nzg5vVmN4L5UxOwafWdRiUJwWSrVcCAw28pgHPugVXjJbN7M3E4uuke6H4on0GCqDKSxYoFIKHOrA\nokJnZwaeEytUWVR5qB3nGrCSsWTUKFYr59pyqQ2n2tDJyh9XDDOGRS0f8hW1ChdtaSRR1ax0CMIx\nt2Q1aIFSHAWlSmIuDal65uixBqRacoVFFSkVresB2mOJtTJSmOpC0kKjgVoUsQZjDBHQnJlRmiq0\nteOsC8dYSX4Fx95aclHGNK/Zb2GNnp9SZkyJ745nSq282Q1gDFNeuB/H1XHPO5ac6ELDTdfRWMdS\nyr8Z+f6Xvn5SoGtEMCKIWBwweIczhnNZ+JfjHe8vJx7iyJtuwyyF23bLp/nEnCpP+pEXdiTr6i/b\nWovwgS/DW5Q1P6q1kWt60Mjft9/TmRlD4aH0FHVrE8ad+GVz5NquiQtJM1ZalmT42iUGGyE8cqmW\nBU8qlhs/YRC+3qxDFydW0f2lbOhEadyFv9v/hopwLi0pB151IzkP2PAjbZhZNpanecNp2eNLx8TE\n4j9Sg2Cr4IrDh8p8Vt6nE8UUtAMTKkEbQoZLLpSUKRmMX9NoGy+USXmYMzkrRaAJQmuhM3AuhUMq\nlLR+2JpWaIxDgUteGHOlyppdFhpHby0VZYqRrBmVSm0MIhWD4LDMOZNMplrFecVWQdRQI6hUii04\ngblYBqtUNeSaV6McHxmrI8V+5W1dZOdmZilUPGMJeFMY7AIoS3WcaDmXwKl2OKm0pvDCXsislMBa\nLa9USyOJpRqqGLpS+bZccVGPM4W9LAgrGTwV4VS2PNQNVipRLRlLKwun3HAuDU/a00pCWHnbi6y0\nw0NqqcVwyB1WKhYlq5LVMKe1EWcUtBiWKgiZpQi1ChJBjKGKstTEVASHYGrBG2EnLRebmDRyFyuq\nlcH15FIpKKNUUq2IE6ImYlZOsWFKC9Y7jIM5Fi
4xEmtFzNrUa5xl33TMmnm4JN5fzrTe8/PtNYjh\n+NxIbeyau/fXCrjwEwPdUisV+Hq/B+Cqbfkv799xl87MqdBZzy/3t3TO8n458i+n3zOWxKQjr90N\nG6s403HMR84pchWOGCksNSB4ai1cuU/8PDxQVSgqBFPZPneefx4mvKyd2ZNasjqKKoOZ+bxTOlkN\nuicKRgwxC1/5SGsK0R05V0fCk4tnYyOv/CM3/kxV4agtl9xyyRuCZIJd+GI7s6jnUhpSatmHyBIX\ncjOxdYUeYY4Nl2VLTZazJko4I9Zi1GDVYqwQ58QlgYpBg2LDqj/VqIyLEstaORkvtM+JtBPKJUKp\nrBHsDkJYK9VJM0sGquINOCN4J1QKOSZigepWHtM3BiuswLkUollBFQRTDaYoYpWimazgmoox688R\nYVkECQk1oGpZUssQEs5nlPIswfJcdB05XdULGUdFbGXOnru6QUS5sheMVHIVHugJxXJXthjAivLK\nXggWUhUO1XOpe6yBlkSuwqO0bGXibbrmXFuiOm7smSAVwTDWwFPpuMsbHIVRGyKejZ1Z1HKMLee8\n+iq0NlGrkHEYClPya8Nt8agavCuoVGJZeW0tdi06MMSy3rpizagaGjxBLEsp3OmRXFfz9WpWLwlR\nQaxlKpn7NFEVBvU4URYykcykhR2GpIVTnvGzoAJX/cCroeXTeeQ3D3cMoWHftLzebak18zhN7Lue\nYAxWBDHyVw248BMDXSOCFcNxmUmlco4LzgjbpuXNdsdhWtMC/unxe851ZCmW1ni2rqUxFjF3JL7D\n2YxIJZc9qhOt8ayXtEorEUHJujY8ihb2ZuQzZymsjRIrSv9MN9z6BqGiqkwKFQG1NJL5WdDn/K5V\nEbClMOeGV+5CMJkbd+ZUApFALIHGZG78E6HLFBVG2zDnwpgGLBljMzebJ6biWKonTgPBFi51JnqL\nc6vPrC2QFkualFkL1QKNRaRiyxqlMsU1lBAFgkXqmlCbMlSFUtYmURuEZ/kqY9aVemDttAdh5Qmp\nzFEprKoA69dqzAI1QbVrkm8VKEZxajEUtJi1GmaN2hGEPHuCz1QHxlQWhJINjkDjEiLKJTY4U/AG\nlhw4aEtrElfN+HwoWj7FASMdqhZnMojgKTgZOZaWSw4ghhf+RGMzWS0f68CmOt7FLSJKFceNvdC4\nvJr05Ibvy46CYWNmkljuy4ZrznxIVxxyy1QCvY3sXcaUslIP2fOUBpyppLpWsoKiWjinniVbshpa\nVzGiTEVI0RKzIMWQkwE1qKmIVGoRqIJXT31ucGUtFMM6xKHQmkpvLPdx5pO9IAUa5+lw66Skd7Rq\nMWJYE+eUMReETAiBfdfyuEwc5guqhSknPtvuUFGiVDBwmDNeMirCbdf9yUs31/pnQoj/f9ZPCnRF\nhMF7vn16wFvL/TjSh4bgDJcceT8f+TSfV12lOP7j9huWmkl14ffz77gxM0qDp6XzQhbBSGFjH/80\nRXLOW7KdMQheIkolPJ/ca+QhVK0EqeyskHk23AYaWcX0WxtQKpXKXEHForXgpPDan1CEpAZB2JjE\nmBs2dsQaxZvEuTQkGubcYCUz+DNtsKzFn2IE5qmjsF7fm82CFEcplnluUaOrUYuzYEAUjAVNkPOz\n5EkE3JoqYFBUhSkDBWTNsMTKH2VVa8VLBV/AtWCeK+K5Vkpdv/dmjeERXXd+RolUUMGY9d80WVBT\nyRmqySvoR7DOYtwamJiirA2fqogKMXlwipOCd4VUhcdpvZq3LhNMYcx+pT185ZI6YnF4U3jdHmkk\ns+D5kLbcx55FPY1JZByOyo0cOZWWY1qr11t/YWNmZhUeSs9SDR/ijoowq2eQha5J1Ko85YEP80DS\nwM5NJLGcc0AVDrnhKQ2kYikYrsKMI3MsgUvqVxmZKCik7IG0ej1nQ8mGHD0uVMQqORYo8sd0dQTF\nUki6tvKyGo
wKvazc/EzlVBfUFLTaFSxLAs3rFF2ZmLFQYVsDsVZaa9fPc03kZwld1XW8vveeN8PA\n3TTxNE4UVRpr+MWLVxStXFLk/flM7x3X3V+0efn/7fpJgS7AUjI/u75BgBd9z7dPB07Twq8P71ly\nxhrlVXPFgnLIR+7iiVMeCdgVbM0bJl2oNaLmB6IqD3nAKni7nvhQ2doFRUkYTrVhbwqGVTyfdBX9\nPMs1Ya11EZTerGCjrEAUBIpmOrP+1QwsdX1NUYNKZeNmYrUkVolRazLn1OJMAgOdiUzZUzSwpAY1\nFe9mgg2UYkkFvApLshStJARpMsYoRixxdOS6glixdRWVWkEQSoGaKxS7PmBZwVEM1Cosdd14UldQ\nlbWZTgZKehZTVQjPeqs1whtKVdSszRRbDAZdHdxQtOgqGq2AAYuDtEq3jC/UAmkWjAmEsP5Oc1aK\nM9QEVCGrRQSCqVibmQt8HLdYKTSusvUzc/Gc0sqjHlPHmBtAed0e6ewqK/uQNjzkjqU4WpOYNWBR\nGpM5J8d92rFUy8Yu3IYRqXCqDWURHtNAUktSARVakzBUHvOO+2VgKoFtWDAUDqnnKbakDFNsKUWY\niqf1CW8K5wJLDJRi0Gr+9DmM8Vn/WwXU/OmrOmUq6zMsFYKss3PFrKbyVYVLXgc2dt4z5kxV5bHG\nddTcBaQqpVSiKkvNWAQxsCyFj+cT1lpu+55N03KcJ/5weKQLgauuYd92OCtrwwz5U7V9jsqr4W/0\nwl/dMrL+sg3CnBOHZeHvti85zQv7ruUPl3sOOfPd9BGDwQls3TX7UFjqA7mMpJpoqyOWBiNbZiIu\nR27DJxLwWFZlgX2GSkVpzWoVmUUZVeh0BRNQIgX7vFWA5y0DsF4JWQuatSIUZXq+bleVP2lAVYVU\nLCqyXp/FULCUYrFAzIZqCkUsGME8A1ue18632IR0iq2WmsxaUSehaFmHBIyAKEYFTRWtBgqs5TBr\nBZWhkp83t0Xqqke1FhKgFZgVrGCNYhVwKwjXvL4fi648rhGgrtykruCsVtaKbQZxAk4pJq3/vUKJ\nFi+rlpgCxRhMSKQK5eIRFZo2M/hIUSEmt1bvZY2jF/V4M2Ge5V/vxj2WgWCUm+bCVAOn1IIo59Rx\nTi2lwk07sfczogufloHH1LLUVYomCofUrtVxNXxY9uTnYYYXzZkW5SENvF92TMky14BBGXPASSGY\nzCU2lCrMyRNcpnWJKVvm7InVkJKjVkiLA1dwdm0qanbrM8ewPmygFsgr0CNrYKepyizKSdfbQidC\nYwNzyhzyaeXLjWepShHF2ZXnxcIawrRalQp19ZTeDNxdzjwsI4uuk2idei6pcNW2XLctH8cLb09H\nrtqOF13HZ5vdsz75r3cwAn6CoLsNDb/++JGKcokRFL7c7fHPp+7b4xNPaeI+XfgyfIkxht4Kv5/f\n8XGMWHdGBCwWkS1WjjT2jtU1FabcsNiGpA0iGSFzaycylUM1iK4cGGQqjvAsSgcl6rO6ArCqLOuU\n/p8AFwyV5yv0nwDaUHXlTZfqmCSs4ItDWTvWSzFUsRSzFohKJVZLKYYlG6qpJHVkDIpBVKjZUMuq\nbVZvkFXnj9Y/RnfLiqJ/XD6DFjSDFAdiEZ5pBOrzezaQBXGCe35P1UOJPE91gTfP6tU1moHi1mdT\ni2AtGF2rs+oMuVYExVSzVr/l+bWNRSStzmBZKdliEazPgKJ1bVimZJhGt0rifOWmv6wKgDnwsQo5\nrkdmkgYflvVY1Mq7ac/dtEGeQbNWGJPjTnqW7DikjqpCayubLhLMwrv5mlNpiMXT2khrI6fYczIt\ngczDvFIauRp27czGLYwlcDcNGFFislhTKUVIWKyrlGqJs6PkgBqh9YmShJwcSy5o9uvnR/54QvP8\neTJQ/5t+OFmoNWMw+AIF89x4XM1wZoRSwZJoXMfZRC46UaJgrSNkIWtlsG61
Y5RKrhHrDHNMOGPZ\neM/Pb655fz7z9nTgtCz0PvBqv8OIMJbCJSWcMfTy1zsYAT9B0C1aue46VGDbNtydL1Qq//nuB45L\n5JhmBh/4prnGO8vjMvLt+ci5Ltg20eavSRSuvONc3/GUhRtTcQKpehbtKHpk546AMqnjXALRJpKu\nneakhp0pVCpjXaFztQOpZC0EMWtDDSi6Cu8B5Pn7tZ4QVOW5Tl7/jNpgta4NvGooGKZi14m0aphq\noCDUashVSBgilioGhzJXQ0rmuXpex0+rrmJ4eO6w5OdK84+jVFTIee3AIFDtWnlXBbes1WoSyB6x\nwh+7akkKVNC8aqetrDcPrawgIM+FWV25+FCe7wtZUVfWK3MGiSvFYmRNh1CFUhQjBs2KNRZR0FD5\nY7GeF8s5d1CUNmScKCIw50AuMC8NZm4wXnnVn6DCvHg+FLs2pQxkMXhTid7gKXyYttxNAyKGq2ak\ntzOHNHAfB7RUDrFl9WOwNF3CS2WulqfTDjAYUbZ25lBbLkugVuU0B5bkqdXgXKZzkcVazkvDuDhi\ncnhXUTUUNcS60ktaDCrPB6N7PrirrLcSZeV+TCHqqgIpSRDjEYVghSkruVZGsYhZpX/xmSZKVKoo\nmUxVg2TQ0FEkU01mLBFVx+OyKnP+T/berNmxJLnv/LlHnAXAXXKptYtNSpQoiSM9zff/DBobjY2J\nNLLZXc1m15p3xXLOiQj3efBAVs+YyfRESVNFvGTmvYkL4ALHw/3v/+Wz21t2mlit8c3rCy6J3TBy\nO43kNHA/73hZVi618Hg+MQ0Dnx8O//yF4H/i7RdXdKs5h2lkTKnnnJ34+uWZ27SnKPzr9+84tpW1\nJP7z899Sm7F6Zac7Phs/Ianyw/rEN+srYzLMhe+XX+MokxpD+oFnywyWyBKBhkcb+MJeuU0LjrMZ\nvFhir0rzoCmd3dmL4zROHsuj6tHxNTeySIcbYpHVXD/+24nOtxKF1gkYQByOtifhwbZwpVoomlYb\nqJbYWmarGh2ihUlLaxIdb3JoYK1BX5ogEpXL/IolRBEuHcilQY6CGt9zqDlm/9ogbQEReMZrRmvg\n3N4Vaq2/JnXALO6fhE0c1Lq5jEJ1xASnIQiegmFg5mgFL4KmBKmhucbyb1PE4n4ijg6gyeM9OScu\nmhCE/bQxSDBAjtuIuPF62ZGkYpr47HBizoWnZccP5ztaP5RyalTrBdATD8sOOwsmyiGvfLI78bjs\neFxvOG0jr8sOF8csMQ8FGaCZciwzT+cZI3MzXdjWzFIHnn3msg20linVMZScDclGXTK+ZawqZEPV\nMDNo+tOprF0c4f2A7F+f1IMdohuLG8WUUVJgQp1IUBwkgdcVcdj5zKZCytB8o7hxroaokCV8HM62\ncakX1pzZloC7pjzwb96+ZT/ueFrO/OH5KVKtDwc+2x+oblQz/n8dSfnfuf3iiu4uZ759faVYY6mV\n07ry1e0tSYVPdnuel5WH05lv1g/ccI9l59N5zw/bK+cKF34TibNquN+SuJAQipw4WWMnYZ7y9fol\nhjDg3OZHjl5DhCBKMdhcWQ1u1HCB6o1nE+5VqYDiXNwZOrXsZKHf3zxcqX4i1XiHEaJD3Tz1e0dB\nr57YmuICmUaVxKUOsfCSWFpV4sJsvUt1hLIJIlH3grAJ3RorLt5KPxH6n8mJameIXgHc6JJVNxjA\nrV/oJmCKtoppYIAm4KYo0jm2sUwkQUmRWBHc3DgI0OiaXQVJ/WeuHjiyNCQ7kgzX4Et7AauKmKA7\nY8zB6309jyiGFkhjwyUKsbmwnTPnZURcmcbCLq80Mq9lYKnK82UmeaOSud9duJ0Kj5fMt683qByo\nljjklWLBnT3XzHGbOJUJ8YmU4dP9kfM2cNx2FMscl7EXw5A/15yowGXNCJmtZoaxMI2N8zKxrAOt\n9WlErpht8Gv7xyMORbz/vvoC0kJsYgLo
hhKHsokiCqM2ksLFLpw78WFix5CFtcAqC6vBYCCp4STm\nlMFhGGAYG8NFeVo3Dqbc7Qb+4u07vju+8ruXJ97tIkV70hwznXt8Jv3nLQGGX2DRHVLCcZobSYUx\nZe7Gkd+eHnk8n/n65RkV5910YJ+H8DpdFx6XC56emPQ9jvHZfMPJX/hhveWz3e8YPEQAl3pPsTNZ\nKsjK5nBuIzUn/qnMNAI7vsknjsQoNriGHt6d0eCQrpBDUMaSKNWjzoQRS0hFi2tgkJZ6DWrAENdZ\n/567sTBhHnACHRPe+qLs2i/XFpCG0LX1KK1JYLkOiEVnmXqhNbjiHpprXLwVaDHWiwnqHsucTAdp\nvYO0jsiKTYA7VhIMOeCF6rTkkBz3WDOmFr8ztEMq6fozo0u1kgKK0DgQfABUQjm3gslASg1Nhg1R\nqNeSoAluDpphMuaxslXh5TQh3vCamHLFNMCfppnLZaBYAheyGjeHjVIqp5rZTgfOy9SLHJE0PDTq\nojwc9zynidoGpmELLrQrr2XiUjKnZeAimWbK7X4h03i47Hk6zdSimCs5N8S7ws4drxossAYkCcpc\n9WCS1A4BiUFqeMtIEyBYJiqV5IqRKKa4C5nKKJkNMC3BlxUl5RErnconQtXGKkLtn0XxjPtG1ROl\nZi4bkEIV9+XNPVkT5sZLvZBUO2vH2U8jv7q/4/vTKWA+c/bjyOc36Z/r8v9f4vaLK7pba9zNM/th\nCAMcHvnd0wMv9cIPpxNf3dxGPHkW/svD77h45WE9gcPnu3tu88xrWfnhsrBJQacLf7x8GQ5LOlPl\nld8v9/y7m6+jqbORp3rPuS0MWshiHC3z2mbu9MK3LWMI1ZVDWjlRqC1ypDaLrndz4ZBqx3ATJ0+c\nLQe3Uq6dqlA8Cm41ZbUEnUMLcf01T5SqseX2RPNE7fhtxihNadYFGgExw5/+3YM1EKNpjPu4YR4L\nMq48XQTNNTovj+KGEhisVzwbLtrhgXAJCyJvDX7tIKADUkFNkD/p7D0aXrT1YqOCarAwSIAEjU1N\n0AKM8dxbSoh6HCin4EODkscV6d3wsYywGEkVqYruDJ0cKc7LMqFbom6ZXQ5qlwxOIXOpmfMlkzRe\n7v3ugjQ41YkP5x2XbcRrxjSK5GEXpj2vlx2ndaC0zJArkzYubeRSRvDGto5Ao9XEMFdSamzbwLZ0\nJzEnFmriYGNQ92o/3KSBKa4axdJjysgSnFzT6wLXGDxYI5IM0QU0dfg3Pj+71KjN8BTQgzJFzHv/\naCCNqpWLCdULOyaShUy46MZpFXAjS+YwZP763WckTbyWjd8/P3M3Tsy7ifeHA+5Oc+fnXHZ/cUU3\nqVKb8WIrW628rhu7YeCr6S2fzkFZ+W458pvHHzkXY7XGV7u3rNZ4k4Xvy285tcrFF7Jnkt8x6Y7i\nhR/XlZzhk7Hwt6dfk4AsAykd+e12z3/cf9NpCAOP7YZ7ceYUKQ0XH3lsE+/SxksbghxgA4fUDQJb\ndIIXG8Jv1ROjWIflEpslTm2ieoyH13SCQkbNqGRaC3zROzk+aFGh2KpNaJaoFkVZDJI0Gr2SuHCF\nAsWNJEYTga13VW4gkMWwIfBGagJzpEbnKkPDXQnj1/hVJBxPJcbhTCzdPCHNEAveAyIwjB+5uakJ\nrRuieOpwpQjUUO4pUWg8E1S4qF2UltACeXBQxQfHUsKbIVt0Y1hi3FdkjsPkdZ3wSx97NSEZZHKk\nOed14rI6tSqDNjKN5rD5QG1wXkaQIVy7DhuTNF7XgcfjjDeh1kS6LupzOIOtValbjOkqMAyx7Col\nx/KzxNLt/3UItcC4UQ0oKMV7gevHqUQ03lS32BdgwcH21hh0ZfExzHRa8K/n1NgaII2lLDRLaIDg\nqBVSmzrrzKi+oqakOpI0sx+F3QTLYnw4nhmGgbfjxPt55rVs/PHyyoAyaixRSzM+PczscqZ6+1l7\n6cIv
sOjOKXEphad1QcV5WC78+f0dr3WjufF///Adi8UY98X8lne7HSD816c/8rfPj6gmGsYn46eY\nKImN1X8TRSAZWxsp9Q7xiYWVzY2DZz7ZNf7L+aswZvGBOR/5x3Lg36W1X0TCU92TPLHXcOS/yMxD\n2zFL5WIZF+HSMvtUqMQSzBGObeolNPAw9z8pxHVkaWESfd2nWO/zIJZtsRi73j823c1DrvHxgvYY\nLyXH8rCVn3Bbs5XK2QAAIABJREFUpZFS8GlrSkhL0Dq9y400VVpS3PLHjlmtd8pqWEo/Yb3uBC4Q\n/wwtcEASNNBmNHVcO9Qh3bKzyU+QZnLSoHEQWFDStHNSPUERIY8NK0JbM6l5LPST4ZOwktCm2BZd\ndmuJ3a6SU1glntYJL45vSkotDLh3kaBxuQycq9KKkID9WCkqrHWgirAsU0hxLZPGxjRVlm3kso5Y\ny5QyhqjFEy0FK6PUhDVCWdIEn40sRl1HbL1ypT2KrYO3/v66IS4kgmpGkrBTdxjYSG5sIlQJtsMg\njayJak5iQwWaCEUSLpmhQ601F1ZZWVti1H18mpIxjRsvpQZ2X+LNezPsSCmRU2YcFIrycL7wbt7z\ndjfxxc0tD+cL359OXGrhzTSFG+DP+PaLK7rFjMM48GY3Y+5MKfPd6ch3y5FvT69IUg5Z+cvdPV+f\nP/Dj5cgfT6+ca2E/DNwMn5MlU1vjgz3wuJ7YDTeMWtnpnubwYdnxxeHvY4A15dQOPK0VJVOlsDns\n28TNuPB/Xj7DfKCachjPfFuVvxg3BsIs+7nuOEvlJoVTv4vy1HasPrC0DCJcLDOnhqFsrjRPvJS5\nd7SGSSzZKonaEkvJLDXuaxJjp0gwHNRT72ASVO2FEVLvrVqTj5IycSdriy7NwBhCktscqOgY3IJq\nqUMPsabLavjoVMm4Bz9MjGgnk+GqP3Vw3VLRa+C1Ruoijfh5uIRYQwzP/espHje1jj8jMHgv5A5N\n2I6JgQ6IqkRXLAk2o0VlBmCaCmk0NpStjtgloWqBkU6NPIUq67iMeHFaSSRtqECaHcvhY1FLprWg\nf+0mQ6TRXLiUgcslhA0iCVEjz2ClUWtmXSRodWKItoCYWrBQ4kSJ8cO7CfpHGKhdF1y579Ey1Mao\nThHhGlMhFkvXhDJIqN8aiZVMQ8jSmHVkMUNkw6mIZDRPpKaMYmRRVoyTb1QRct3RVBAK4845Lxvn\n84p44m6Y+Ku378lp4Fw3nteVeRwYU2I/DJ0B8/O+/eKKLgASOKK4oZI4bpXbPDIc7sk5Y+58d3zm\n2+OJ53Zi0swuJ/7y7ks+lA88bWe+vbyS1DnkzI4vyIxctsKRH9l8pZzfMadGYs9qwu/rxL+//Qcm\nQvX02m542BR3x7sxS6l7Psmv/M36KcVCaXQ7LBzbSNbGLJXiiee642KZQy7B4urd7kvbcakJRFk8\nkYXATj2WUpdtiNwuorNuJrQOJ1xVWXExR8eZtOKVwCItutdr4Ry14Wrxc4pC06CEWkOnKGStKFh0\nsYlG1oKN8bhmGemiDqxh2iArbikev9FZDDUMd/CgMCXFK6gHtHJlPknu0HOJoirSccwhHsMkloay\nAhqHYU1OHiWyzTpnONWEXhd5A5QUnb5fFPeEN5gng7HRNLFUwc7hg+GErwUDYLCumcuWsMsQ1DVA\nx4ZnYVuUes69U1dSriSc4srlkuLACc/LoOeNjmTBq8DWW04BUvs4JfiWEDNEuwik5o9KwVA2hn9H\naSGmqCSyNA6snH2mSCL3JIwksaQMKu8lcF6Nz0sSJ1UNyCJtFGbEFfGREbidhRnlqW18dzziwOfz\nG27GibVWntuCrwtZlKWEmvDXd3ccxomly431Z8xg+MUV3bGPsr97eiCJ8s3rM7++v6N4w5rz948f\nONaN786v7HTiV4e3HHLih/XEPzw+c5ETJztzP4yojHw+v+fF/shxXX
hpK0kzh7zH2xseN6F4RfMz\nNxn+7vwls1TM92zAq+/4j7dfA+A28VL3WBtINFBjcaXWHftU+O3yaSS+mnI3LpESa84gFp6vdcdL\nnYI7jKLAZsLLNnMuw0ecN9haikgsK7YGrQW31LoHQKi8PPBfJzi4ANIYchDSzIRah+hSWyNRY2RX\naJaxoqg5YhYJvCkYF9WDGpYa0As3A5AULCMIYh5ctUxE7rgEXtkEaUFLMwXRoYstQKtiHgwL8BBi\nJEAcq0EppkuD0xyNn3l4E2iL1526hYQORhVDimLrgKwBPZCAydlSgk1wT5Tq0Jw0NDRDSwIl4Vt0\npFGFwYfAUq1kTsXxEpeexkmFSfCjKanLBvu2MFnHZjNWuu9FDhjBLDBc9fg9c/XFkBz0vRaA0mhG\n6X4ZxRKqMBIZcSKOqaPWoYk0kjbj4Gde5IBIQPPehEmDlaey0nShloHCSNag+U0KFzcufkZzIlti\nSjt2ktmNibsh833d+PH1zJQT/+b+Pe/2M69r5bvTkcNWuJ8n0s+44MIvsOhecc/PD7eYW5fPNh6W\nM98cX3m+LKw0/urtZzxsR/Z55A8vzzxsC2st7Kd77tNbboaJl3bm9y8v1OSkvHCXBqrP7OwtRf7I\nYs5GQerI5HC2tzwarLWym165H4W/Pf+KUZ2tzRSc763xn27/CRGnMvNcDzyvytDNa6oIT3VGEF6W\nA4sPXGrmblxwgdUSSZxzHVgsNusm2uFN6R1vYi0xVuJdCAGAdXMYMMtxsVtcuBkjSReeWeC52jyo\nR8mwBNUzvvYCRmVKDrNF7IxJLxLOIBXJTkshqDDv3ZtUkguoUYfeJaugFrgtLTwsXBRSwCNBjTJa\nMRgF6S5lVxmxdB6xmiMDkEOySnGkRYeaqqHZsSle8rYp0o1jrhRgkYaloJF5SbAo6oZkj+591Iiv\nWTNuhpeE5NanZYc2sC0aeKd6V/RFsaUpvvTX6IqmcE8zoqBaS8HQ6ApEEaKjhVhSXnnSLnhX4Ik5\nKi3kvOaIxiGo4rgbYy6Upt1wZ8IEdqzBBc7KmjJsjTQ4WQesGGIRRU/WELtJYrS+YE1niliYMW07\nksOc4M1OOK6F75eKMrAfB766uWeXRzZrIVzUmK6iw/0Xnu7P7hY5aWFgbu5kUf6Pb/6JxRuzZIbD\nDe/mHVur/PZl429OP+IY59r4j2++ZEuFrVT+/ukDxRsblR1vuc97ZjLfbA/8dj1zmDO74cxeRtZ2\nQ2lCnv9IMfBsnNpE3jKrZ87mFHMO05lDNv72/AVZ4LzNmBpbE/7t7Y8oRmtTdLXrzJAKomGo8Npm\n1pZY68DmmaUmxtyuTK0oyFuiWKZWpXkOuW+QJoNNYIEV+nU1LhaYrUiIwzqMoO4kazAYyY0qKXDH\nZohUJjUYIlixSkZQ1BvZaggWktAkBQ+3gEpwpp1Yqhkd5+0Lrmt0OSkKukp0eG4BrlxVcp2F3Iu4\nk6pBVnJfupkCJWhV2iJI0RXqkHBpSAPZomvWIsgENrTwnEWQYxjmiIQdZ8vgA3hLyDlFZK6DThKQ\nSI7tvG8SBbldn6vHYWapK/a0y3UlOvPkXbabsBrvTxgS2U8cXCOMkMTjMVoXhXgYGg9qAR11eoe5\ns9M1DiRGCsGQSWLMWrm0hOb4WebEYlYUbVt0+ZqwDFZbsHYYMXHScKG2hvlE8cyo8HaKA/FsCz+u\nUJrwWX6HIFQzUhJetwi9fB5WbqaJr+7umFJmqeVnHdUDv8CimyRUYX94fsbF+fF85n7e8cmQGCTx\nzfGZy7rxm5dHtmL82f4NuyFw3sfLwvflicdypkno+//D4SuOdqJV+M35iUZjyIK0Tyh1Rsi82gMP\nfuGNz9xOK8l2nOuBZ3PeHL5BVUgKpzbR2oyacnGjuLMft4AXzp+AC+dtQLKzVOXzoZC0Ym3gWEZ+\nuOxjlJcoxKslLpbYykBFWUuM6R
HOGYePV6EWpdTwa/A+rmoyFAsCfh9j6UswlRXvcMDmii6CuDHo\nRppCvtws4U0YDJCGiiG7wCyj401gjUFDOlq1L7A8/BeSBf3LNSLqr1xhMYXWcG1RhMmdjyxRwCwY\nGSIgqoiEyISmSHFy9bClTMGZbTngAWlKrkIyhcFZp44Zm5JrTASyhZlL23d+siY4K1riPq0RHbx1\nWKEQVpItIXN/UtmjoJmCWWfChvrvIx9667xnta4gi296FZJ3DooHTi16TRU2Uo33rUhQv5yEemUn\nleYpjJFUPyoCHWGvhVE3ljRTfcAtkdXZ6UbpUmD3Ft2+JfCEtoJqC0hEUriO4cwmLOK0dAlqnwyo\n7blT5f3NzKQDD5cLv399ZlThr998zt00s7bCh9OJwzR1QdLPt+DCL7DoikioYvppOmqMmGbOt8sL\n3zy/8tJW3u/27HLii8MtT+vC83Lhn47PjCnzVm54u5tJmnm5LHy7HVlZURWEgX91+IzVXnheK99u\nzyQRhpyx+itey0AtxllfQDYWueF+XvE2cyx7Hovx5d33kVsFXNrEZb3BrFFFaeZMXply49vlnmoB\nF6TsbC0xZCOrU0zD8eo8xQhPeJ1G9LjQqtO601glTFyyG54citFaorWM1ChmSQ3tGGPrGn+x6BjH\noaIpFjBFwhRGLShIqoZlwyRhbej8UEh1Q0ejitBIneYk4cWQWsdcu8BDJFzE3KO7TXScKHXSv+M1\nujNS8Iv9SoOrFh2sRAKvpRxqNhrmCT1bNJmEzWbB8RRsgLwFEJ5M8NHZ1EICvQniKZRyJRR04oJo\np9mZwDm6YFA8x3NT6OIFPvKaXfvrQvpIwsciiwQ27V0ZqBbdumsXd7ggtaH9ZyAtzMyjtgOFjJFT\no1hGOgclWWOfTqy2Y2mJKhMCke+mYdnZOhUwDr1MMhjZ2HSC5BRt0DLJdlQqogtVK0sVhrIjq7NP\n8GYcOZXC95cj70bldpj5fDowpPjMiYLX8Jg+rgtvpvl/ZDn4n3KT/w4R+WfJUv7983MES4pwLoX/\n67tvOZeNl23hWFc+291wSBN/OD3z9dMDjvD96chXNzcMY+aQR3778Ejxwj9eoqgehoFP5pnHbeG4\nFS76QJPCoIroxHvds/gzp+K82oqqs8+FQxpZLVGqYemEpsLNdOF+LNQ68VxGltX56u3zR7VYZQiK\nkodvrLmQUmPMHjidKcsWOODaNPA8jXDDVoXLkiOfDfmY4ODNkQKV8NINAn3wYnNyvIVwoNUoAqk0\nJBsThZIz7sFWkBaYbVLHk2OeMJMw1m6Qc0NpVAVLYScpJawshaDkRr6cdtmqBH4p0AJrwOMfwf8P\n8BPVUGjl3MltfYmHCYOADF253ALzTd1DVi0kvnUQTD2Whg7ZovusIezCBzA1kgmsBI83VdpwhT78\no1m79/FcpujWP5oEXbXcueOw4kGNiyccImzh6r/ZK1JBiSBJFad5iCPEK3gieaERjBWldkaIM2kJ\nKAkikt4Su7SirqwuTMPCVic0O0MO68j76cRiIbAZc2OrmSmtoJnmyn5aOdcZFWHIA9ac27Fy9hlN\nMKQJMeGz8Z6zh/z4fphYm/D5+I7dMHM/7virN+85lUJy4cubN9zOE39+d9+DY5X3+59FcsR/s13/\nRRbdb15fOJXCWiuXrfDhckZEmIfM0xLJu18/PfFwOXHeCuM4MIgwp4GHcuK711del43NjV/d3GJq\n3KaRP74ceakLz3ZkTMqsife7A69l5XFdafmBlCsqivrEje5YeOK4VTaFpM4uRezJaUsstTGOC5Ia\n81C5yYW1DpzKwGUTPrk9d8pUxPe8Xoagg3kYcqsEJtqqhLtYja7WjMBNiYhwM8EKlJrBEkILRRoV\n7VCBmaINpMWCJmsIH4pnvDiQGFqJtAltVI2L0iyhbgwExFAlUXJAD2IxloanguESUIJblwhLCxy0
\nd55IFG6nN4pK0NiGKUZSB6vBW020ThzwGKm1+zcYpBIOFZ4bNiqlXLt5QTc+Ls8MocwgLaaOXIRm\n4ftg5tiguBakZcwcr/E96Z41OhCiBugsPAm+cNBuwyfi46fS+mxD4Ltd8BBEtPgdqIcVEj2l4UpB\njsSexMz28f8LhjHg1O4iJuzSBdGRU02MQ8NMGdPCLM6rzQxDxSx8e8dcuGwT87DiGofjboBLSbFU\nHZSKczsLSxlRG9kPM5bg7QSlzjQXbocD+5z4D29/xbFUjtvKJ/MNU1b++s2X7POIiPL5zYE5D8Gh\nn3f/7DXgf8Dtv1l0f3HwAsB+GPnt0yNZlNdtQxXeTjOrNV7Xle+OZ0aFLJn/9NknqMC5VP7zN3+g\nNadW4/2847PDLYMk/u75e/6wPtFaJED8b7dfItnYCvzT4xPVGydWbnnLjQzMMvDN8sqpNnxypsm4\nkYjTKSeljq9camyeVxsYbOC8CQ9FWZszzYWUGo/Lnjk3tqKspqxr4rCPwhf8T+Fyji6lWe8Q4xIm\ne41i20K8YH2kTVK5Wt+Yh+zUW4q8M6/oaAzeWFuGbpKdLaAHmYM1UOoUG3MTJl9RDfP1TSNxlwpD\nq1EUs1OJtAurGdxQYktuaKjOgA7OgntXzcV4Sh5iPHfgamSugubUkQfHSoc03EluIEIdovu2rTtl\nVSGXio+ZpoE9I1Foc3XUNSKLUoSNkrr4Yh0wHDJh9j5F961XQ6Crdlq9V8fe8UJ/P/otXZkiArUv\n1QjRh1zxaq5f7x0+GfVytTgiHIdzZ6M4ysZAY5LGiZlNwjFtEONWLrzaHlTZtLfbVwcy7+kkatQs\nHecXkjUURYcN1Yy3jNke88qUVxBnK3DJO8wbd2nHu3HH6oVvLk8c9Ia7YcebcceUEjfTAK5srfK6\nbZy3wpv53T/TVf+/zu0XWXTX1vjzuzeICFut/O7pgW9PR348n3lZN95PM3fzxDIHf/CyFj4sJ+6G\ngf1h4v38JQ/LmVoLf/PwgQ3nTmfu9zMqilTnj+dXnrcLbqBZ+Q/7Lyi+shb4p9MxBFq5src3DEu4\nZD0uJwrG5Mpu11CErU68Xpxp2tiaoyl8GKwObGK8lEQzJ2UjZeeyjSRtWA2oYN2ElGKxlnrsT9vC\n59ctFlW4kKQxioUktAqlxWo/E51rTiXifpqwtgEvypCc5BU6DFdbCrxTHfVKzr0IIzSbwJxsjSGF\neiwkpgOpj+WpVdLQHdI6lksNFoNqC0P0jxETfZRni3ojA+H81b0nFLxFMU1mpARDTrHIE4cOacyN\n4KgC2xgQhYswWTy2dilxHcMbgSTkNkAL395NjNZdzTK9LNY4IMLtrNPDNKizYSwTxTdYGf3AMHqX\nCh8VDRFu1jvgxk/2RYpjZCqZAj0bL/jZgtCYKB9/1uIDgpK9EHFL0W2rRWp1QBLGvV94lQMGLDYC\niYE18HhXTCouCZEJl0wSyGwkyTStqBiqA3vfc/RKkZUjL5w3mGVHzhvvbu757OaGl/PCN8cTn82H\niGTvxuXtZ54EDL/Qoit0TiRh8dgsaEa/urnlMC68GWeelo3jtvLd6cismdtp4s20Jyfl8Xziu5cT\n523j/TwhOvGrmztOa+HhdOTr11eGrNykmc9vDohmzpeVb9Zzp9c485D57PCGS9t4Ol04lgKSyIMw\n2RvaUTitW2yDvWEysN8V8Eg4OF8a085p5miKvqasgX0ubaBaXxrmCHJMxJ+1hbrMPL6v3bxGWqOW\nMZqsfnmn1BilYaJsNX00xBYa09CQ5LgJrWkkGldQbQwaIomiilnGN6LQWiFpYLNGwppGwXVjSEbJ\nQuEns3ZphEJKAU9RSQcBN6T0ktIEtKvYOr7tBrKEUCK7ISkWpwuGiKKrkZqTc9yt5hS0MAvjnHEN\nI54mPZI8S6fVasAdHTJoCSRrPHwLExyMLvbokGzHD6REsRfh
J+evzkIIFMSBEK1cO+Cw3QzWCEDY\n3KQu0L2GoAZUZMBE7aU3sngrI8FhcDIbO1kpnll1gqbUpuyGJdKhGbnIGMbnqcuOaaiFab4no2nG\nK2AhR1YJv2JrkNseSSNDgpwXxjqyNaMU4XbM/PntWx63he8uL/1XJHxxcxfpxWa4OZr0Z89cgF9o\n0T2MA797emAzYy2FxSp/dv8mwhBfnH94ekJN+PF44d+/+4TDMKIi/Ndvv6fUxofljKjy6/tb3u8O\nfL+c+PrxkctaOG4bX93ecDOOzHng26dXLmXhD6cTt9OePDn348zzZeX8tPJ9O+HqjFkY88zgiXVb\neF0LxQUVJw+J1DLnl4m1GCYxIq+rsNs13IzShLppGHCbRJS5BDaoYtQlUVtgo2lwEhY0LnfKphQb\nPipGE61fd0axoeOSAjTGwVEqTTPWrhv0sNuZx6vLVv9eU1I1Um4olTIMwZwQRYuTpOe1S4g6LElY\nIbSgoIXXQuCxdk2yrb2wXUdxkfDD9Sv3NfwExIRhJBaN16ia4gxupBbRPcWv1pLBXZg3utk6lKH7\nNKgw0ilczYN8oFFw8QjJiN1bOLh9jMch/pCrcXsKWpto+Bm7daghUIJebK7wQV+ecbWgjIKcejHN\nWI+nS72oFlJ/3wyl0bPRcCY2shiFkcUGGgOiUYTNx3CF40qtS5hkBivd8H7Ck0YahQvaQDXk4QXB\nPFF9h9FIsgZ32ZxXmxCM23zDbT7gXjhxYcphPzlK4mac+OLmhpe18Lpt5MuJMWXe/czj1+EXukh7\nXVe+Pb6SRCmt8f3xSDXj4Xzm8XJBBA7j3ON2jGVrPLyeeS0rCnx2e8NmlYnE7z8887puHNvKYRqZ\nNXE3jjysC0/nC6/LiqDcHEbmYSC78M3LK8ftzGtr3IyZlJX9PPDheOK8rVyk4NlImhhSWFEubmwl\nyqImQ1JDB8OqBYvhqlbKzjh6RMsYlFUwEtJVS5qCXWAuWAmlWHEhK2RrpOTQjGIZtxQ+BxCKuE7s\nrzUFFpoEt9bHykZl+CgASM1Rq0hPYWiagtJloNVJyWh+NePR4LxWCx6qX4uYYF3+q1fRQ+tMgI+0\ni07RcrrIQGIJ1/1ecSe1MOHJ3QPYh5A9Y8bQQEswHEgRS2M5ICE3YajRzZt2s7VuteXdhwKcNoRQ\nwoRoY/oyj2QhX4aQ2nbcNGpt79S7sM57kYwv2kfKbhjTKxHGBP1FfcTdRzbCSslCrUe071Mq/fkV\nRAc2y6hWRBKDbOzyymu7QfrBlsTZpyPndkA14JjVhN3UWNtAFieps3liN0HzDKYchoFTdeapIj7i\nlvlkfselbdzNiX06sJTGZ/u37HTmL+7e8sl04LhtvN8deD/vGYfEp7sDjvDJYc+ch3/mCvA/5PYv\ni7Q/vRnOnDNZE2NKiAoPlzP7ceBUG1/c7BATjuvK3/zwI5lQFh3GgV/f3dPc+YcfL7yeH0kpM2Tl\nf//sK5IKL5cLf/f9B5II61b4/LDndr8jeeI3jw9sW+FSVoZh5t/eTQwI369HPjyeuWwLFeF+PjCk\nxMbC46Wx1Sgyw+Ddf1UozWI0q05E1EAeGomGr8LWlNISKYWRo46tR5MLbRFqy8FaSrDTyqiNZtqN\nb2IZk8SZc42Is2aUmntKgzNK2DJeuabVZ6wZ6kJm6+Y0kVwrXSqq1RCNTK9iCUsZ0DAya2HSDkJL\nXQIrPSuNoGNxldBq7/6u+yM8qt6Q4vnIT8sqbcZwNebpxVZq4LmDQaoOWYKfKxJYcVXSEtYHYUkb\nBkGM8VDao+GueLgMAZEkE1rzn3AF0V5HPxpvhtLODXLCr4qy/hI+XqYeIgzt08UV7TV3XBLX3OaE\nd+A6UX1glEqi4BjqjY25c9AD9b3JC0udaZ45WwJXRtmw7mjnaYxpQTWi691JHRJSDS8OKuFlkcCl\n0KQBI24HxhQeH2MKV7Zj
KWQ1Dnnkq/kNH9aF75dXjIATPpdbltY4zBM5ZeyjS/DP+/aLLLpzyvzj\nunIphWbGw+XMX9y9YcoDN9OFPzw/02oYMH962HM37TgMA98cX/lwPvHHxyPFGnfTgc9ud1Rzni5n\nzlvhw8uR+3FmHjP//pNPeC4r67bx2x8eGMaEpsSnd/cx8hp8/fpEq4Zj3M+3jENi08aH05GyRVDj\nrEIaAv88N6OVEp1WFlJOaAbRghWLgEETpHswDEAatsgIQ8IgWx11Iw/dmtGUsoYYwsVRMca0xYjp\nipdEtYxXYRyc5AVyxHJbzVgJitborQsXEiaw2UiifbSADDsHpzAB1j0CanSr+tP7456CxSAeYZnW\nWQyJ2ORHLG3vvIkilXtn2YiO1wyqo7n7wXYeWLLoeocuslhH5SoHSwajBYxgQBt6XddwUGtr1y7Q\na2kihBYiYBqvB/mYbkG6QiHXO8QU4IMHLQ+Pzv06T/bcN5ECXdRMF4jgGnAQRiZM6asbEzmYClJJ\nbL28C5tNGIlJS3gF1TC7cQl4J9DzRmIDJgqZxaKDHWUFEiuJIplWBckJzXEMSKpUa2hSTCbQ8OZo\nNVJBTh4o8pv0llvZIdmpqTKPA2tp2CDcdnjhcVn4cDpj5sx54LP9zT/bdf+/yu0XWXRFhCmlHg1i\n3M8zzZwfTycezmdOy8aUBr66vwslUxp4PC28nha21fjyzQFz4f1u5uW88bpc+OH5xG4YeXtz4Iub\nPe7Cw/nMD69nmjlvbnbshpG73cjD+RxQxrJykwfaINzdjBxL4bic+XG5kCV2RrdDpubK4sJxtfA/\n9cyQMpYKUEA26qWbl3tiGB3NlTwY1EZdIiwzmsTKlFqolGrCilBrFIMsDdUWtn5NWdoQ2CzCoI00\ngUjE+vimmAT3dUoO0vgYDd/9dsfUwmhXvGONgQ8m9Y9YsyFBW0JR7SwGwgjFiMeQ/il1uxax3tF2\ndRZYJE1sfaHmEfYpk1CEMPKuFjADoV7YkuCpB3YapAqDEn2iEkowhGwSMG0LpOCKaFxVcdKiE/ar\nP3smjHb0yi3u/7dj12j7uCqM59r/lqN7v7ITPrI0rjQ/uSbaGW5DLLHIuGwItRdipYbujsgSXpmk\n0Hxk8zGWjE2Zh8KolZeWWHTCajBehg5VRL4yqAY84teFoHkP7kxgI9YS6Ay2wFRCeGIj0va4r6Sp\ncpGVtjg/sHAYR/7y7h1vxpmLNb6/XLgbRlJS7qcJEcV+5lE98Astus2MMWXe7DoJ2+HvPnxglzMv\ny8onNwfejTtw57988x3gbBYR33/91SdMmvnm5cjXD0/QjNdL4a+++JR9yiDGP3z/iJnw3esL9/PI\n/eGW+znzm8dHvnl85mnZSOb86v6OQ1aerPJyOfNyObLWxic6IFPGU+WlFJYVtgqZuADmIbPmLbbU\nm+Fn7dHLMuksAAAgAElEQVTXjg4tFlTeKJcII3Q00m5ki4u7NpYl+LvJhZwCXlAVSovC2Ew7g8FQ\nFVQq1QdaibQK8UAPU6qBvXrCWhR+MUNzF2GpYC1FMTHIyaFG6m6RgEWaBX3JO3WpqaDdC8atxwt1\nDq6rhMor9y63dAbAqp054FjuK8HmyBq9n7pEl521c4CFjDM2Op8WVrk2rdqXeTH9F4McupFIHyb4\nv0FAiMgg3H9iK0jHfKEXXD7iywGH5Pi6emAYV4ih/zdKUPxInUVBwCwmHmwCCce3gQruVBlwd7LU\nWIJS4/fXEpc292WkM+lGbWMsVLUbhhvBBRYLamENSEgt/EaybTF5iPUdZqb1cNEwl4hlJTYySiaJ\nshsSxy1zqhuTwk0+8HY6cCorr3Vhs0aWjI/GRSp/tr/lME4Uaz/PJdL/5/aLXKRtrfGH5ycurbLV\nxnenV27yyH4YWVvjw/lMLZUPr0dWqxzyxNvdnvO20bzx4enCcVuYJPFmtyMloalzXgoPxx
NtM9KY\n+fR2ZmuOWeMfH15Y68Zlq7w97CApQxL+8eWV03JhrYWbNCIjJMn82I5BGVuiQ0qdhrQqeN1oXsMN\nrEOHOlR0KLhbFL8ShWEeII/Rg5pBqWH911CGFEWTnpVVWqYZJJSkG6k7c7WWMEu0ImhykjcshRx1\nbQFl+FV44CDqPZk9/STZdQt+qPLRON26QXgGsFBNeadUIRHBEw5aAS34FSO9mkgg0egLYX4usSyM\nVIuEVAs/gyFgDe3YvBTp0eERsVNSLN4ShOdF/YnqdY10cycgjo4je4pDAoU20TnDTtV4n2h8LLjy\nsaHt4MQVepD2JzguAT2IBezQO/io5+FvjMSyTLQgpJBcp0KzEVdj7AU3+A+Z5okpha9uwklDYfOZ\n4jDkOCBvhgsXn2mu5Oy0JiF00TDIGVNlY0CTozJg3hjEKZrIKgxpwGpmN4xBaFO4myfMBu7Tnv2w\nY5eVr24/5XhZSAJv5wOf7g/8m7v3HEthTpmbaeJ2HPlVlwP/DG7/skj709ugYXh92jayCObExrTB\nsq78cDwySWa/m7kV5bP9jvNSeF03no8LN6Oyk4Gv7u9RjJe18PWPjxzyBA2+fHvDzW7PcVv55vQQ\nZM/m3KWZf/3VgbMZ37888+3TKcY1V/7123es3nhuJ745P4WRijnv54QPcMZYmpNLoxYYZA+54Vrw\nYUHM8CWMaiSFEGKcnDRUHNi2BFuMjGMKn1QXOlQAzRMCjMlJWmiitGq0MtAi1IGcC6Se/uvG2vLV\nxpXc+bxCpBP7lWbk0UWB06QHYyaAsF9xVZyG5e7kJSGGMItkAmtRbP06zrv0Fpj4U3unKS0KrCW0\ngpjTVGByZAj5MKEN6DlrsbBLPUVE7U8y1johwnpDioU8oeskgoFwZU8g/Tn2D9fVSyH1pR8/FfDA\noFPokENOx8eTk2BWSD8MHImON4FjmHRvXofUUohYPOS+WRvujSpC8aljvIGzqmy4DqyWyRrm8uOw\nMTicauKSRmqcdLjVj9NG9aCdWR6pWzfQSR7d9pBIltGmiMyYVFLasJq7XWckVhz2A6VUns3ZnV7Z\nDyN/fvOW22FmaZVTqYyaSUmZckJV/9uV6md0+0UW3WrGmAb+4s3uo1Dh7z88oCY8ni/MOfPJvOdu\nnvj7Dz/yuw8Ll62x1pU/e3fLm2lmbcYfnh+hwg+nM1/c3HI3T0zzwO8fHji9PPPdw5HDnJl3I2/e\nTTwsGy/Lhe9fT1it3O32vJlHFq88l5Ufl1eO5wv7lBimhI7Ca1k5m7EtMa+KKdOcqAlIK60YaZHo\n9DwzTBW0oTmsCcuqbGtiHMNyUnNkbRlK3VIkLqizyxuSjEaQ5q1kUk8SlikoS07q0uCEm5CGwHhD\nJXeN9O6ztIBqwSwFlUmuvq6E/wFRwcwMH1Jf3kdqxXW+shyjvih9edbhi77NZ5CfsNWNAFTdg2Y2\n8tOSzYjIAxcsCW2mm53LT4GWxaPQjj/tvNQ6NU1iL0cOFoNc0yDaFXf1zpq4PiY/bdyu0MJ1USg/\nPc/ImB+I04CPvzf6IlH6/dzCGU8I/N2opEGC/gYUEtlD+ltdydJFKC0SpZ0hDooWHGHxUCvKMFC1\nYWVgSME0KJapaaRZvKBkBhLx9aIevyDJmAXP273bU6bMYMKkE7NOnH3jWFdw5a0emNLEViulL66z\nKkst6CD8+nDHfriamvu/JEf8HG8qguFc1i38FpaNKWXu5om7eeZSNrIk/vHpmfNSEHPe73eYz0xT\n5sPpwsuycFkbd/PEX757h46gLnz78sJlrYjBn316R1Jlmga+fz3zupx5vVx4P8+QlcNu4tvjC8/L\nhefLmUGUTw83HIaRD3bmaVtZS0WrMWq3ODwQ2Jdv2KWn4MqIDBueGzmXMHBZQjiacQ7jRkoFcvp/\n2HuzHkmSJFvvE1FVM/MlltyqpjfOcECA///XDIa4dz
DdPVVdlVWZGRG+2KKL8EHUI+sSBPlC8AKd\nbUAhK2NLD49wMVGRc75DzgGtkbwKYRQkOrzGtHkEe3ZNbwrVT+uDx7RbC5RiiAZHQA4BtUyRQCu+\nNDLTzhlv6OAQbYeVeRdNdZVDFWfZltvyzHrag5dcV4bFbjbAHWZ0pxyxFzfBC3EHemN91husS9mc\nXqaba2SJXcGVcCUAhq1e4JvetLlejH/LTWidMWOhK7l8T9cLs58WJLl77rciBV5HBLe/98uafyGV\nPhfyv4uoF7BGR0TCzVccTJDikirFOo/Ctbpm2tkVPYeuszNK8+O+Nb9RDCGT1JhrYLURLBCkkmph\nbYFNHGrUcHKZqvYlqgGKBHXaWDFqqbQGNvbnIwdSHlkoDKG5xrwGRhvZx5HDEHkzTnyary7NjJF/\nPr7lcdox541fr1d2qfAwjn8vo4X/x+ubLLpBlSjKXy5PpBD4OF95P008DjtyLfz4fII2gzSmIfEv\nj4+oKX97eeHPv3zhLg6c88Y/v3vDfkw0Kv/+t0+MBF7WhYfdjg/7iRaUvzx94ddfLqylIQH+9999\nQEX56XrmPz5/olG5Lit/fHggiPBcN37IV0qutHnlbjeSJuVcNmaMVoyaK2qJGHvHt3PJDWTKGojV\nI17Gfca0otHYtoRsnlw7CqSxemFDKDXQMtQcicGQUGjROValBUoRAu5wE2mEndGaUVr0jXVwYpj0\nMMOG0Kovb8yzbhBzJ13Vm8NMutxWnAFh4tpPFaiupebmJLu1iyNfgxqz+J9mrzKs1xlpdc6uVHNo\neBcD3MQEmv1xqgiF6l+3B3JK64/Z8Iy23lFr85vDq4Li9soJX5dnX2WmveAGvt4sbm8X5XU9L97x\neiQPSHT5hoggpmhzmWCxSsPlb+CdeS2C1AhDJnCjkUFjoFkjeElmCBtmylYHqgnFAkl8ztvUb4j+\n/AjVknewOHO5iWCMWOv/XvBvS9WpDFIiGj3BOMvqvA8JZIGEchhG1mxQCi/bzC4Ffre/YzeMrLWh\nBlHD65LRQ4r//ovuN7lIMzP++uIc3IZxWlZ+eHlBKjwtC+uWeTvteTzs+TRfmLeNdc58nhc+HHfc\nTxMIfL5esWp8PJ0ZNXC/m7gbBj4vVxrCz08nSt3Yxcib/YGVxtI2fj5dOc8zgcb9uEej8bRlflmv\nPF2uTvTC2E2RpW6cpXBZMmJKyIUYtcNiMi1tSGm07pwKwZULsqtIcUF9K4JYJTRlmjIW/Bd8WSOJ\nRsnqnW1QTNUJW9mVpNWcO6yhkDW5BbV057+qW5JNOq8Xh6M0vJOrzgqgF1BTn3N6nHrwYkP/JQu/\n6ViNr0Dv2/uc+eIfU/v7b53k6As1f18fyCJfixtuIdbSy3eFpq3jEv0NobjdAG1YBUuCBjduAEjx\nTvpWGFrvQk2EIr9RLtw0urcFWg/KvBVhuY2C+4cF7Ts36eMNaySD2grRDEuR1rPJRIEtdHSmuQlY\nC4IQrPY5r7pKBV/sSWhoTRSM5FMLBs1YDGxV0eCuxDH5k5ObMkQfM6g6n7iUypjMecYod9PAvEIU\nIyWf/d6PiWYDewamkNik8nbaUSs8hD2P+x2DBP7l4Q1bcePH74737FPiXx7eMKbEUjLfH+/+P3yl\n/0+9/rFI+7+7VFzgD8LWKruQeHuYmMfIfRpZtpXLPPN83bgbEm8OIw93O8hwWmY+v1w5jgOP08T9\nfmKXBp7nK0/zTNkqxzES48jjNPFlXfl8nflyujANkcdh4O3jPad15fNy5eP1SlDjzW7gLk6sQ+Xz\n9cKpZMiFqTWIQjwGSjbWmGmlIVskZkgjFCmYrRCNUH3+J6Wh2kjRQxQ3UyiNnAOpL32GsdtTMbYq\nUFwBEaMRY8WCkGtAS2XtJCtUeqFw335ugvSKE5IXYGKf86oPOs3kNdiQZhR6ZxPpw95bS9r8Y7R9\nXf3LbxZoCuzxop
r751b6+b+PGJL/XKle1cScdWti2CAg5vPKrtiqPaPNghGSqyS09ha2GtUMJunL\nPHt9XNYMHbpjLfB1kXYbTYtLv25dbwd8+bhXvnojBE+fcIu0eu6cuITKLJHMKLUrGaR3nUUwItUi\nIgWh9oWeaxgUnzv7aNnHDqtFmiW0RKz2JWd1A0mMt+ewz8b7zckQf1s3z+RNqM0YRn8+pQradmAF\nonfkoSqxJo7DwJtp4i5OHlB5vTKGwD/fvWUXIrnWPl6I3H8DqRHwjRZdEWGfEv/t0ydElC/zhV0c\n+NPhCKr8+6+f+K/zC1KMl1z43757xz4mTtvCf/v4iUOM/HK+8P3xjg/He9IA/8fHX1G58jKvpBj5\nw8MDh2nir09f+OvzM2tpXOaZ/+XNI7sx8lw2fnh6ZikbX9aFPxyO3pAp/Nf5xLbAy7xwSCM6jmCV\nU1tZW2NrRty6fVYEmyrFCiVWpAyE3NyqmRzvF6aKWaC0SiuBIQtJM2MQWvIjZ6uRuoEGIyLUUWg3\nNkL2YtlEGMT1rk2qoxyBgqsOnKLlLFpVnzcavG76pQcymnknifBqEeZ2rEzWZ57+8Vb0K+BGQaJB\n/E03DF61Ar5ij3yVka03k0HBohcqwI0OzXAekMBAB+y44OA2Q3bXVaCljnoQl9X6yOAm/XL1xG0m\n/dq638YL/Y8bsfA1SKKPdSv+kIPdblABxOVtSkBqpXY7coh9eBOUUPV1dCGWaSoEDT5/Rz28kwZS\nUetUN0YUp3rdpBDaH4wbIjwPLW/9ziCCSqSaIW3w5aI0CEZoShJf3jYatTXW0oixUUtjFxJ3aeJa\nK9eyMcnAISW+mw6ohm4t1i61c9xqDH/vtgi/vsnxAsDHy4WXdaGZsZTCl3kmmHBaN56uF/Yh8bCb\nyM1YWyavjefLFQHuxh37MXDOG63Ap8sFs8oogTfHPeeyEqLw49OZXDyF4X7v6QZVjJ/OZ67LRimV\nFIT9OLC1wi955tPLGYtQtsJxULLBooWXdXauS6lM4rPGpuZzXgrkRsL1uSIVhrlrXg3L4kf51ojB\nY8xDkA4tqZSs7KQhMfSvG7qLVjzl9zYhCH3Zo8LWfJtu0hdCoVGBWrqDKdwgNT5drKgX3X4GN/eR\nAPgYBM9we41tMO/SpAgWPFKHvkCiSKd+8aq8uikGzIAs/mLukjKSf01BsA2vns3HCzeFgPRi6rIx\nfWWkuxQ3eLEVXouEdA1ulS5Z6+/n9XHyqmQIvWG+1dQe6/Y/YCTk67ft7mEzkvmNrKn1XDRHWop6\n99qxC6DOG8vNW2dJLiETusTLiRwMomzWSKEnTwiIDrRmfdTRaDRiCp4shBJjJNdMij4bGSIMUcgb\n7OLkNyoxjkNCCNzFPaEbSB53E+vaeLc7so+J+3HiT8d75lJIGviwP3AcJv5wf+/SPZW/J8rYP8YL\n/9ertcZhGMCMMQT+dvKZblIhxcB3d0dowvl64qfrmccwUaTyu4d79nHgsq78crqQ+kl4GkY+HO5Z\ny8bpujG/FMYQMWv84f0jW6vObfj8zN00Mgm8e/uGDHxZX/jhfEJVUYH3+3tsXznnlZf5ii0bQV3Q\nn8aRWjeuihd0EXQLpORLENlvFKloC65sKLgpgYwO7twnwLq4hVdrYj+tvJ6k8WNnXgMpKTVm10BE\nwWpPgG2Kqva5bl/hNKF0bKTvvrQfuZuDu6ViKBJ8ASY3rKEIjN6N2U3ixe39FfaCqhO1rOez+SKr\nEnuuWVPxxVoPcETN9bm92TNAVh9BaOiTjIO8QnEk49ZlehCk+CjHivS5bPV4nqjoCObeY58HG18L\nrtBvUP5kdgInVYwUe1eNf472Me/N59ElvSS82V9wo4hER0KOwFq8OEoflYdu3KjFH09UqOonhSDS\nHWRuhlGpr/Ku1hLRA4u6tMz8ORbPoPA4qdpvfBAkMQ2wZaNkIRJo4iGl4O8fmfps
WojqZhlpkYcp\n8P3+gBnM28bH64UhBL4/3BNFWWrmaV2YYuBx/LuI6fl/vb7ZortLztStzbjmDWvG7+/umWIkvCg/\nnl+wAs/zhd8fH7gbJj7oHX/58pmXNvPpOnM3Drwf9uwPIz++vPDz6ZnneWVrhe+OR+72E1/OM397\nfmHdMudc+MP9HftxZKXwt5cLWyv88PTM++OOURMfjgd+XWeutfDxOrOTiO6UKSZO5Uq2yhmf/Y0S\nPKk2ueRrjZ5pJjWi1cGwTSsybNQqRIG8BnTL/sIMhRgqS4sQlLoYJo2hDgxD8aOsBoeOZ/O2SqQT\np7wAWDO3ASNoVFq4bd/tVZOrTgH0GfBtyUZ7xRpK8K5UqtKsgiiacmcduBStVodma3A0o6buVCju\nzLKmEBtxaFjvpNuKa39Ll1mNhoWG1JuzDe+YoyHSaMkLpBSHjJv5clIEGPXVy2Dm3yOtF+hu9rgB\na27ks6A+WlGRG/cc8fG5y/FCb9almw+sM9ojvvTTXstXkEQfH1QfTZgyIJSuvb3Nhl1W7GaPZp1H\n3ue8vgbtM2xxa3TqaSF02p6UfpLQSNT2iirWErrRxXyhK8oxJpYqbpBBmbeV4+iJJGMIHMPIaoU1\nF4bkrrX7cSKGwBTcDFFao5TKpRnvd8q3cH2zRVdxuUror7/SjLlWnueZp+tMKQ4b34+RJJG5rLxc\nV/JWGIaRPz7eewKAKD8+PfNSViiVD/s9psIwqAdYzivLlnk4jtzbjjREnpeFT5cLc84MMfKvbx5J\nMfBcNn65eD7aVjK/2x+9+wyNL9cTq1VKbhyDH/lDgnOtNIFcjFAiWhrRlKbldaFUWuoKB0W1YKKE\nwYMLt9aoVYnZFyfjqNTo8eTNPEOM4F1QUx92moS+X5JXm6wKmLbX5ZKIuO5V+uebfNX/SyMNuMTM\n8LRgoeMnDcO5vqUKWpRqjaCKpsIQm0vcmlDN2cLOYqhI6IVs6zNkUUwbYecKC5IHcILPYJsITOat\nZfYC7Q2et6gSvbs1BSlORWtNUPOEhxa8mIvQlRX29ZdLofQ5dTSPirxNYm5mi2B9ErFB6Dze7lom\nKOwaLAaWutxNYLRALTDcRrpdcRX9B/B1QeeokNcd5S2N4rbrw8RBPQSk9b1jf/ghQK4QojJI4Jor\n9fbvx8Cgzt9oOfi3GoWowt0wsGsjDSNIwqSSc2UOBRPh7TRySGOPZL/yfrfnfhz4cDyA4cvKb+D6\nZotuMePNbs8QAqVW/n35hZ9fnkkamevGPz88Mg0jp+uV//j0mWOcOC8rj4c9H/Z3WDD+49dPWDPW\nvJFC5I9v3yICP15e+PnLQhLFGvzLd+9QVZ6XK3/+/JmkkWUt/OHNI0MMvGwz//n0hWbGujbe7g6M\nR2HOmZ/mC7lt5GwkjdxNkdoaF1m45EbBO419HLCl+Dw4VJo1aELIEcluC7WYXWc6uvDdimE5IrHQ\nBMIe5tJTtrL6AVSNoJ786idpf5E29W19+23iQd/qi4KMwKt+tI8UojEN/rEiXgh8F+UhnMPgLiwx\nc7WECToURmk0kV6Iff4qZsQAqpVhbNTNTQK14Eu41ghjD5AMUDb/WXjVN2RsaEdCtta/ZvXHb8k7\nS9frmmeedcCNanNJb5/jym1EotDP8a+/Y57z2EcYv5mmGK+JPq6tDXiX2YzBnx0ohiTn+lTrfB/z\nrjdGL8a3Mjr9hvVQizeqKfE6o+7gs9cJyG0ELrj++NWxLA6mT6oUa67PrUIMsIuJrVZHhmpirRui\nShIjWmAwpYVADB5pVKuT4t5MO/5pd8dcK5d15ZMoSZWH/ZFmxjVnt+Nr8HHfN3B9s0V3CIGPZ+fi\nbrWy1MwfHx6Y0sDdPPDpOqPzwq/nC292I/sh8f3DgV8uV369nvj15QoKD2ngTx8eeF4XPm8XXi4L\nl7rxdtozhsjbw54v25XzdeN5XbkfBvbDxIf7
HU/XlV+XMz+fThzDQIiBPxzveKob17Lyw+mFKQSk\nRt4cdsw1U6TxvF3dHlvgENSLSqlcY/V8MAS26B2WQUiVqoZYgBqRxbtb1QLJMEmOMyx+xC9FCDFT\nCegQyM2QZmSLRDFqxLf2pr0Ua+9ohRB5nWveLKcSG0OCFFvPZvMXdGue4TZNhVq1d1/eUk0pY51L\njnYzRk3UoqTUSLubkDdQWsO4HX+dg3uYKi0ruQm5CARDaIQktKkrGbZuSmgGFpCpoQ2aOGGLbojw\nO4OfGtpNudB6q9mLrUHvOo0biMGjgLzoNrygv+ZUtq7Coi/DzEcMiUBAuIbSzRzCASEarDRK9Ll0\nrD7xMLzQEr8ydoZe2Kt+Ze+k/l+D11l36g//BvQJ5qqvEH35GUwIoogmRk206otU1/p6mu81N1IL\nmCl5K4TRTRk7HZhkcLOJCGMMiCnHYSBp5GF0uFSuxiV70f3++PfP0oVvuOhGVbZWKa11boqgomy5\ncF0y5+vC/TDxuJ84TiOjBE7LwmmZGYPw5m7CrPEw7nieF57mmaVWjjGyG0fup8g5N57mK19OV8Yh\n8nY3cZgGtlr5eLrwMm9A5Z92Di+vCj9dT1y2jSU3Pkw71ALhqHyeTyxizMvKEEeCNeKYuLZMoXGp\nIBYJGEM2aiy0qmxjhSJYDcQt9SRaQ6aVVgIaK5IFW11zWxy1RdPB0283L65iAQ1GjeZQbBPMKrSA\njHgiROpSrs7KFRHCUEnRFzWvZ1+DYXD7cYqeWhDU/IhZvHDv4kZuHqJZe8TFEFd2UbpsSbFmbAVq\nC8SxkqKvqmpRTDzanWpMWskWGA+Zlj0rrlTBRD36KAo2FL+RbEBrXQKmkOzV+utXV01Ig+SGCfuK\nEQPr/AgB+i1J+lxFeEU5UIIb4QSPgB81sFG+cnxVOZBYybTakDDQ2BiCU+CSuuTNgsvPBO9mQ+xS\n4b7Ek9vesf9bA84P3uiqNlViaRyGSOmV11BCC+yGiGhgy8XHKBhTSEwpkqugNTB2GWFSAYlElGxG\niglFmWvmZV3ZReXddOCQRq5549N8ZT+MvNnt+HA4UFvrYbF//9c3W3SrGY/TjqHDzMH4y8sXRhI/\nX878/v6e+8HlNP/90y+MIfH5OnMYAv90/8AuDvz1+RN/fn7yYEgz/nR/38cFhf88PRGLclpX/vT+\nkUkHnuvMXz49EUR4WTYeDom78EDWxl+enshWeb6sHMfAu+MOU+XTdmW+Lly3Aiq8nY5EU05yZd4y\nW3Np1SFGmp+t2WJ1FUIEKcCGH//wDXYZDJo7iWRTNzJYRYaGFC8OLXuCriEQ1YMYgx9QW3VGgqpv\n8+st5bZYb/KENPrSK/SgTKtCjJUoPhaYYun0MXyBhjIOlbTLLrVtAVMhttIpjsIurlgTljK6+w1l\nTIXjsDqgRZTSjCKRUgMxVOJQCWqE2scjfeQzqpGtoZPPe2tWatVXKLqqg8klCK14N+/Ftp/1ocsO\n5DY09cnCbbrQAe+I0JoRbpAcXvG9KEJEXLOK0cQNvYL/vIL58iyLUWmMJGJXy8ytYH0RN6gjKRuB\nBbcMBxXHOUpkMcc9ikSCVVISz74LvhAzySjedAxBGUSp0ogSMBNiCCQN1FIYNbp6AV/YSTXipIxh\nINXKXhOLVbacuQsDjynx3X7PnCuXLaM6o9IZHa2SayG3V13HN3F9s0VXBUqtrDWz5Mo5bzzu9tyl\ngd0YmHPlaZn5cp4REwZV/uXtW+ayct4yPz6/sOaNyQKPD3fU2ri0ys+nK5dlZZ8GDsfEm+OOS8m8\nbGc+Pp/8l1oT7z/sOa+Zp7rww6cTEzBE5X9998bDMGvmb+cXkiilZN4PB0oXs/+0PoMJrTVGUULX\naZ4wLBq1RaiVWAOxNIo1Si8U1br2dbvFfXvUtg36mu5L88IkCiSPIqe6ksCTYxuyczqZH60dSqPB\nSMnQULzI
GF2G5//O3ZjRUHt2Y6BWYYyV3ZBppl6I1ee1WhvSjHGo7MJKM2GrkZnILm205rbUXcyo\nwcs6dcNaIMTG/TizlohJ6G+vlCrEUJGDMcSKrJ2aVj3tIkqhqTrHQcE2XGLXvs6Brd0mseLt4g28\n0G28cpOL4dIrAfQ3Bfc2Y5UGoSMcnd8WiFLZGZgEFi1dYqbsNLEHZioqiY3shZCu1+Wr3XrXqQul\nvy9oYuhozqaJpYBoJKnn2WlTsjYkBEIzoimjRqpWZyRHn9eOcaDU6ioRFaIIh3HkykqwiJUKIVA7\nJWyXEhrURypNGFPADFKI7ELi/f7Iedu45szTPDOlSNJ/qBf+rq+kgaUUTnlFEOZt4w93dwyauObM\nr+cXDmnAgnGcRt5Ne7Za+OnkyoZDilQJfPd4RxPnMHy5zhzjyC5F3u12oPDpeuWXlzOqgX0K3E97\nJMDnZePTy4XWMvcpMITEOCif54Vry1y2zJ5IUOHx+MA5r2SrPF2vhBCRUjhMB3LOVGmc8+qLqU3Y\n3QaMtTEP5kBzxdUNFWITajJK63HuFpDNux/HFHjxNjN3aFWhtUBzJiMWvbuz2vm3CQjGsOuLMHwE\nUTIMo7EfM0kLMVhntxqTZjQ0DmP2/LQmrDVRqrCLlbu4cG0DSbtXSipRmnet2tiHQmvGbCNzDuyH\nzcWlGtwAACAASURBVKVTDUIUBqnklsgV1wkH4820kHN0i6053L3m4CoFze6m6iONUhw6gzX/fs1n\nupTmw9LSdV8d6+gbsa7C6G94XaLhZbibvHpB9jFCxZzjY6BErIPKtX+8AUm0qz/85hXdBkNyuwSV\nRqVgCH6o9882GrV67xwlUjUyaGOywCzeeARRBo2MRDbtwjaD0ALHaWRr1V3M1RgITCkQQsSKIqUR\nCQwhMgYPiFd1ud1OE9rjsJ7Kyk4j7w8HpjSw5ZVf5yt348hdGnmzd8RqNeNbKLvfbNEtrXEcBx52\nO1prpKD88PzCROCny5nDOPBut2MMd/z1+QsfL2eeLyvFjN/d3XO/G/jl/MLfLs8sa2G2zPe7O+6m\niazGD5cX6lr4dLrw/f2RwzAhWvnz0wkwfj3P7JPyMD2gUfnh9MK8VH6dL4waeT/uSFH5smauufBp\nvqKmPAwToUVsV7luG1fLUGBMEzVn736ssEVhQ1zHa42w+aa6VmOb8I18UMgTKubRNqnRsiFDg+wv\n9JrBktGCjxkMj4wBkB5qJmrEwaBrQ8WUYVeY9pW7aXNFgxlRGq0qQ6zcjyspVJIU1jog0nhIM0Er\no/pIYF8zlzqyVWGIjftx5VyTz5gFhmhomwnmkd1T9Cj6Ux1Zc2CMhTGCmce3p+jUtNYh6daE427B\nqrJZoNz0u5gbDcaKBsOKk9ismRfc161T31QVuLlL3GLhrr1bukXzEXTXcQjV3MTQx+ckEkmUc9sQ\nSSjGhLIncaGw4qYUNWEvzjm+WqZiFBoRYUckE1g7SLMCO5QWlbm4JDKExo5AEmU1n7NrVIbmUrTW\nK56oEGNksohiZPFxUtDIXpLn14kzNhVhUiVpYCSyw8cZ7oCLjFF5TANz8WW1JxV70vOSszc2r0qM\nb+P6ZouuiHSNaGNrha0WUgjsh4Hf6R3FjFyNz+dnrstKIPL+cGSzggbjhy/PPF0Xgnqw5Zt4AODL\neuXTPEOFMUb+9cM7Sm2cy8ZPpxNqPlv809t7cm1kq/z8+US1RrHC76d7QlQIwn+9vADGUlbexT1V\nGmOKfFo3aslkq0ySIMLEwFUBqc5GLcYgHsh402zWAKUr8qUJYcHtwDRKKC6RioIuCRXI1bAhu/FB\noBU/7hoNvWugzYtSlzJZhDC41XiIlUG7ttUa1oRxKDyMhSl4VwbGoM3P8SLcxYUUGhMb5zZgmvgQ\nz0T62AFjDJlTmTy1Qo3jsJLDxqX4WkrUOLCwE2VpgTE0xBqnOrHmQAjGTj
v8pQVShHUxX7710cCY\nCiL4ws3iK50xiNFC7/bNl4yYvDrb6B07HTjun2bOpu1T0JuUN1SIQdh64cQaUaTTEdxSvVG6vEwY\nTF/T5reOVBtJ+PlDqPg2bTQfZ2R8UTyKMBKYojK3RubmGoNJIiKBmer4yiIkDUwxcGmFgi9XDzEy\nhcCMJzebwRQDh5i4lkrER0l+X1YGAkF8DDFKZAoJM1i3wrSL3I0jb8cdp7Jx2dbX/Ur8x3jh7/tK\nqlSMH55fSCHw0+nMH+8euB8nrnHj3z7+xKCJZcsM48jvd0dEjL9+ufD5kl/98b+/fyAG5Wle+On5\nxBQjtRof7g4c0sDTOvPXl5PrL83Yp5HHu5HneePzfGXOGWmNIQa+mx641IWXdeV02nDMd+B3h3uW\n0jiXhS/rQtkymPC437OuvmBbLGMSWLfK0M/2KY4sskEy1uZHxNiMUITcWTBVzD38NaDEVzDLRnbw\nN851JXt6QIsNGZ1MZaZOqkoujzocVlSNGCoqLtVKCtPQGLWwjxsqxojHATUTdiHzdlgQlLUF1CAF\n446NWeCgC1Eae914LiNFdnwYz4xSeqZaYBWlSWDOiSCwTytqxpd1R7NAJTFI5TBsvJTmm38q1zKx\nlOggn5QJUlmrAx2aQdnSK1QmxEqIRl2N2sL/yM4d5PV/6bCcm57MF2duY9be6d5gXtUaVSCZj3iS\nCIFCwY/qBWg09k29E6ax0TANpCZdkTCwUnxJquJzfBST+srKDbGPK9RhRUmVXH3EVgWieLEt4iML\n7YDzMQhmoYPVXSY2xkhpjahCM2MXhENIZDNUlGKOnHxIE2spXG0lbsKkkXeHPRJ9rPeybexT5Djt\nuB8SXWz3TXS832zRdYWn8MeHe0q3LX68nHi6zDytC4HAMSb+cHfPx8uZX5Yzp8vGtRbeDiMPhz1r\n3vhyncm58mk+82bacZz2vD8IH+eZz+eZT9crxxQZx8Qfdnf8eLrw08uJn08XxpR4CCNvH/d8PJ95\n2WZ+uV4BuIuBXdxxaivnbeNp22glMwV/uwGyNbLAmjMDQGscJTqKMAhb7R1bEXYNtHhXWgO++Inq\nIBdTQu0LM4xMl3dZQFqC6NrcFrMjHat6SGQVmhTS2EhD6QQrF8fvhkyMlff7Kyk44yqFypwDROMY\nV0YtHMJKFWWwlSm6kmIfMoe0MLfEue6xnjDxJs1IFUZx9uubuHEuI1l2vI9nhqFQTFkZyEWYUmUu\njnsZU2WnmdyUjeguPTGOaSZJ8gRgM7QIS/bvIYZGSJXSlNKJabX5Rt/dJDjKsPKahNwxQN729RFp\n+I1nwnBMhOHoAmmuGReDjQpVsVAYUXYIM0pWc7MLlT0RiCxkX/rRSE0YJJJ7UZaopKZ+mmvGLI2o\nyRdlJAZVlhBoHdI7kLgLka12C4w1hhA4xJHWNrdyiyLW2KfAUuR1TrxRCRr8+1ZjJ5FNvfseYiCm\ngTEk11IX2KwS1U8VS608AKo3BdG3cX27Rbf/kIcQUalEFXJrHFLiyEBrjX0c+HKdeb5c2YqTlIao\n3O8nzuvKdd14mmeOKfF+f2A3jFQaP5/OHtkjwveHPSDEIfCfzy8sy0Ypxu/2dyBGHCJ/PZ3IJXNZ\nNx7HHWIwBOXTsrDUyloKh+jWy2lIvCwbJo3VjJCNnQYmidRQKLmwRqXkikggLJWDAhawJNTaATEG\nafMupzVj6wsQAgw1OQ/coNAlW8EQ80h3AZdapYKE6hvqCmWLLrkKlRQLhyEzSKG0SDGYNHM/FN6N\nV5JWolWmUDjnAQmNu7AxxMJdWDi3kbuwsddCA5JU9rqx15VP9a6HMsJ9mrt5wgMmD7ow18qL7nhI\nM9+PLzSES5uoJq4LLu3V1LCPG7II1zqwtUCzwC5tVAnk6N29FsOym0ZEG6HLxkp280TLfZRQfQl5\ns0VLX6QZPdtNPD7dUFLFI88R5pZJol
Qzdj1VQ0SoZmQpfS0WiASESG4VxelmpbryQYMiVGJwpc3a\nCnF0s8VaCqMKc48EaqYEjDEmqhmbwIrnsO1iJIZA2bL36iZMKkwxYXmlVVfQTJo4DAFbG02MWhqH\nkBiiUFoj18wUI5MFjnHgWjOnsnmy9jiyGxLXvPGyrcyl8GG/+yaieuAbLrpBlRiUPz9/YdDIL5cz\nxzTw3f6ImfFvv/zKdTszb5lq8Md3b9gF5Zfzmf/49VeGEHmaF/54f8d+8HTTvz49k1Q5zytvDjvu\nx4nNKn95fsKuxrwV9tPIw93AZsbH8wvr9cJWfOj6x8d7cq6cS+bn60yzihXj3W7CKhQNPK9XrApz\nLjzuB0qDGAKzOb92A2Jr/eUZsQStVObkbiNrxliEZJ6GW3E84W3DftPp3qhj/kIQWHFgdTRK9BGA\nSx3wsUOJ7I4zadiIoRIEclY2cbV+UmOXNiYtHHVmI5JbYK8b74YTgzaiNCKVSbMf7wnswkrEuNcr\nz23HFCq/C09UUxSYJHOUlR/yGxxmKOzjSpTKqUzOT5eGysKlJB9njDNY46XsKT2vTKuRpEEojKmS\nQ6MsSqnBLcla2Q3Wvx8HxCPRb97Bl4gMwOqaZcrX+WQVe3WLheqdsCLQGlkh9UVVaj7TvWrmWsGi\n84qP+OhnkcpiBYuBWGAqhmpkkcza4b27EGhWKWLEpmwBJg0EjUhbnBcShURk0MCKoVRfppoTwgRh\niJGogjQlhUQQV9gMmmitoOrqlykmdhpZoiEouRhDSNwPA2stXOtGLIEUA4/jCKKcckbErcX7NHDs\n44Vv5fpmiy74j/ntbkdrxnf7A6dt5ePpxMu6seWNkcgf7x+Ya+G6LfyyrJzWhcMwcT8OvN8feM4b\n1+uZj6eZUZUxBf71/Ru+rAtftpUfn54cEkLkw9sd51x5yRs/v7zg7tLA7w8PXOrqmsVlYc2FSORx\nmliyu+avlrksC1NItCgcomP4NvUiHU3JLXMIyX3vWRxAkwKbGTE3rArBAjX4jM6Ku5nAQS90of5C\nl5+aS4oSjklcxNxaGsXDCMGXbIeKjpkUsz+r9SvE/G5cOaSFpMVP3M3RkEGMKSzcxZmojQe9cmkj\nrSlBKu/iymoDDSVSGKXyqFee2p5II2jjXhde2o6FxB+Gz1QC2ZRBHIJzbSNbjQRtJCm8G7IbNUyp\nBHYxs5XAopH9bgMzLnVia5F206PW6owJupwpu2vOl4qNkLxztB43pHgkuyfl/nah5k9qw5kTzbzI\nWYOhP/8VmKWSxZuCgcBmsJgRpNDMiAQojSrCYkqtjSAuEKsG15yZxsAk2ueu3sX6qs275URgbavL\njFXYSeQ4JNbqCBxVJTY4jCO5uXSwSSWIcpgSbWmEpkhw2s4QkhssREjR1RnBvBsO0fUcpVRkhK06\ncawJbGZ8lxJjcC31t3J900XXDB7GiWbGEAJ/u5yIfSaXYuDtfoc0eL4uPF2v3E+ugnx3PKACX+aF\nj6cTg0T2KXG3G4hR+HVe+Xy5QBUe0khMynHa8cvlzMuysqyVKQ5EUfZT4ul65doy162wl4iK8LhL\nnEpxuM3m0OcBZQyBVl04v2V/wVMrGuCgo4vvg7KkTG0GW2NsSi1G6KAU8OXZrfOKosydEFYwkika\n3d/fzFjEWDodKwFDEYiVGl2GhBgJoW4u58pSme6vHNLKXVowUUqNDDFjBse4cowrSSoDhbVFTISD\nbsTQOOhCJfDIM6c2kfti523Y0AJXG1CMQStv5Eo2Tx1WyUxszG3kE0c+pBdIwmqK4DPJz7pnbjfk\nCzwOV0qDYq7f9RyETNbEYfDF4LJG5jZSikcOqRgaPClIIsgirvy4sXvVG1+9qRqK09gcu+tz9Sie\nxpGAYg3Bv572VVsyzxEz6XB4GpHBUz1oNGuU0J2B1ZMnCobERlWjVFevVDzLLplwxXGTKu5GSwpZ\nPH
pnLY0Y3LzQBEII5OqPYT8mpP9cl+JpGqMGppS45I3NHIq+j15A51aYa2VMgX0MDHHgWgqnbeU4\nTDyMI0GFJRee5hlV4Xd39/+/vvb/Z17fdNGdYuS/Xp4RgedlIWngnw4Hokb+88tnfp0vzGvlumx8\nfzxwHAYOY+aH5xdqNZ7nCx8OR45DIiblz08v1Fr5cr5wnAbeHI+MU+TPnz/xvC6c5oVmwodpRxwj\nX9bZRwylsq0b7w9HQhA2a/w8X8GU07rxdhhoDXa7O67bQpXAdV3Zp4jlyj6OVPMX9KW6RMqaMUnE\npQnGMiqleTyCFDhIcBup9VFCzy074jZQAc7WqOoSpEMVJPj8zhRsdOmUmqHB0KwMU0PTRgjNZfvV\n2EqimmJWuR9WHuLCY7p4kWvKqMYomZ1m9iETaExSmFtCBO7Dgkpjj7EivAsXxlZYuhl1rxkLwi/t\nri+AGlFnsgWu5mqGo1U2S3wpB+7iyqPM5KasNrjpIjTm7PPNJsLdkFExlpzINVAsoFIdfzgUrBPB\n6jI4WN3cJKHBeofMb1KM6VgvF6NKEU+qyL6r9PFC914IjAilVFa8uzQtjDaQJLJQmbs0LJgSqqcr\nZ4wmxaVcLSJAEWNtHreUxBiCMpvPjTcMlcCkA7ls7hwDhMAh+HIsu9UNzFwlIYoEY1Jlro0QAio+\nzjiGxKUUmlWqCQOR/ZBYt8pZCgcJjKrshwEwTmvmYTcxDQNTigwh/mOR9q1cog4WaRhjSC6sXzfO\n84nTuoEYD7uJh2nEzPi0rjxdr5ScOewm7sdHCo2XvPHrlwu5VnYx8M/v37DkQtXGf/v1V9a8oU35\n3eGeuTaKVn48PZNzxlrkLiY4jGQaT+vqM16EqQlvB9c4bjROl3P/ZTeOaYAKd9PENS9oGFi26i/g\nCqNGiroo/1x9bieqTEVJY8Ba5dxcyF8NL7bRKMXI0SjNeb2uj1XCEMg0rmTHYRUhZUXF550WG+PO\nO0PFSArrNhKnE8chM2gl9akrtacJYySpjKHwNlyZSdTmL+69br4BFyNgHLWxNN/mPwQvnCOwENjp\nyjuMxUaUxl4LMZ74z/ze2QTS2EmBcPHRhhqTQmjGpQ4EhXfjmWZwrq73VYxsqUPXYaCiY6Ouipmw\n5tiB3mCpOfSmCmQfG2hzPbbJV2wi4K41gxL65wIU1+wakFthNhxbaRkYvDi7CZyK9TFF10y3Ds8M\nzt5YW2Ekudll9IVZrcZiFalGisqofgNerSARxl4UN6us1Y0d05DYp0RulSqOd0wWGdWpc61CkUYg\nMA4DpQOCGkYUYR+jF2oximVEBnYhcs4Z1caSN6aYOA7jN6PPvV3fdNEttfFuvyeospXCv/3yC885\nI1VYW+b3xzuiRj5dLvz4/MIUE/NW+af7I4dx4Hle+Pnp8go1eRi8QJ/rxsfrlVx80z3pwOP9jjln\nXtrKl9OFQZTYhIeD6xnnlnm6LuwGbz93qogKRYRrdg1mE2PEiEa3kPoMzyxQc2G0SiQSdoFtKyw1\ns4qh1tiZMDKw7BplK6wijGKIBkaJSPTvYUkFK54QcIgJU2EWuOpGk+bg96YkAjkYRStVGkOo5HVw\nHCSNu/1MGjfuxpXBdWkkrVzzgCbjMVwYpJKkOobWhIGGqIPMFeMgcLZbmi3sxXoWpbu5jmqMrbDJ\nwDFs3LGRrLIysIrxIZ64tAEDRq2MsvJrvWMxL5hRjMdwJTcPPKumVApzTVSJHNKMiHDdEhuRVpRa\nQ29avaiGZA6Hb9I1zu7ztYhbhnvCBD4CpqqPbFQ9zkcbaPDZ6y1FPiqk4tjLtRZPXKYRGIitYapd\nw1toIThbwWDFiBG2VrGghNbYtKdEBKXVglp04I40NA20rTh5DofmhKCdVdFotbrTLKbXTKGKa3qP\nQ3TOj8Elb4godymiBDYrvOSNQxjYDSNNHOH4nDd2MfJu2pFplNb4
PF+Jqvz+H+OFb+MaY+RvLy9k\nayzbhgG/v7tnlyI/vUSe1pllqTzNs6dITIn3+5EfT2d+ebnwqcvF9mHkzXd3/PB04ud15ucXd549\nxJGH48Sv68rH64XneSHnykMa2Y3OePh0vlBD47RsvPFMa1JQ5pZprfGyLNwNkbYoh+CphhXh0gqj\nKlYao4TufIosZuRSWUrz4x8DGlx72VSxqtAq94NzGFKIlNK4VpcI7XXAIgwSyNJYpNIkd5dRdDMF\nmTpW0ELA0KTEMhLd6kaIzg/Yh0KpkWseweDddOaYVu7DQlTP4IrSWG6aWdmcQiaOe4yi3IvL+yLe\nMr63xksTcne0TQq76tHwATiExtAWZonsNHMIG9YqMyNXRu7jzNYimwVPVpCCZuNcxy6TU45hpQVh\nlUDDO+NgjcUiMRZUhFJh2Qa2irN1rSMVAyANqz6dlVs+Dj5WaF2WFzbDrKM0wUcOXcnWKqzicUdF\nKtGMKEq1wiKhF2CXlUXEddVWCKaerqxCqZk1jNRWGTQSNGKhsAYj9ZPIKJFMYbVKCK4/3qXAUirN\nPLjSpLHTQFHDqjM1rILGiJQKQZkksFlG20CLQuoutHkrSIBBE4PClBIYXPpJETPGGEkaPP7oG7m+\n6aI7qLI1n4XdfPMYfL44H/e8ZnZD4kM4MKXIdct8vM6clo0YhO8OB8AhLP/902fO64a0xvfTjk38\nhfJflyvn9UreGvfTSJHCMEa+LFfO20bJjYOMPA4DJg5JuZSeT1Yah5ho5oVsLoUYEzlvXX+5sdPI\nVgrTbuQyLxQaLTf26rPG+ylxaUaRSs7GPkZSCAwNluCRNktwmEkMkRAitTaurbJaQwLcpYnaKiLG\nqsYmihWBEBlQTCuFwjDNTKEQxAhqLOueGq+E2EiaMYEpeF7OpYy+yEwL+7Bw1OLEsuaLntp1uIM4\ntaonqxNEeezWY3Ce7RstfLHI2mPMRzWOtvHZYl9awYGFTPw/2XuzJjmyJEvvU733mpm7xwIkMmuZ\nXqZlhkLh//8jHIpwEZLdLBlWVXblAiAWX8zuosoHtUAWH+aBMz0jTUHaCwQBwCPgEaamV/Wc7zBJ\n5y5vMUqwhdUzRQd33KheqDhJIOng1hbchUYKKleqrF4YpOhu96/Xkn/pZLU7hu4jFEAlrMHs0uDw\nODBUIu5nhOIsoD7x81dlj9mJ3zJZZNJ1nCLCYATs3GWPkY+HUnVjTkKmfMnlUVcandViZp9L6LH7\n0FArSOw2skdwZh+g6kwScBvbjOaxLJ1zjBz6qCEbQ1iSctBg8VZtiOX4t5KxErSJdXTuSmG20CRX\nBdk27peFhzmK79fC0oWvvOhWi2icouE2+sdPn/jT6wuY8VxXfnu64zQvXG4r//T0kULi0+XGh7uZ\n+2mhmwdPdwwut86sibv7EwJ8vpx53Ta8O9UHvzsd6cPoRfnnyxnvRu3Oh3nGtWDSedpCbbm1Rtnn\nzS5CM8WzUAf4aCR7888Xhg96cs63LTpCV07TjPWKTYnn2igp4935LmVUMykpl94QC0PInRZSWgDj\nusNVUDiZQlY24pjacYY4c1JEJ5pHpRnemeZBbQXvJVIKpopq4265sqSKIyQ1XrcTNb2SxCkpXGyz\nQCZRXTCc4p1Zo4OPKzSgebcbZARVIXl0eUMG72Xg3vejMxylcZXKzSeMCGA86conO1E9XneWUFNM\nOoJa5pWncaSOTPPEIdfdgRbSOFFnGwXc6XuLOs/OunbcAwH5FmTpKWbXyJ6P9vZDt08bYtvFl1nv\nIMbBE7HYnDJYj7yyDd8LbTB6XYSKUUi42xcCg6vHe+gj3jE3qgjDiEhKdWYVzAZVIsx0mJFEUU2o\nOSrGNoRpCqXGrMKcMuswusTJZU6ZJReGGdswVDpF4KhLLOFG5+KNYy4cilLNuI2GjsyhJL49njjX\nytoaH69Xsgp38/xf+3b/V3N9
1UU3i7CNzrkadTRab3w4Hjjkwv02c143/vz8zMfrK9kyy1z49x8W\nPm4rP1yv/PR8RgRKzvzDtw+8rJVzq3y8XjAb5JR4NxU2hyfbeK2Va60smjikiccZnmpFZfDx9cox\nF1SVh+PMbWtYVp63WxgZBpyk0FsnL4lLbcySWdugJGW0wd0y0Ybh2alpgh4qgoMI7+YZRbl4x2on\nFQDlPi84ShPn3G+IhOX2nc604bTkNLuh4rExT04SpaWGS2AlJUPpmTTmfetuZG8cp8raC5uVnam6\nomqc0sYh1S/Oq9e+8EE7SSwAOLt5YJZC9zdEYsxxVeTL70WUiczqFlwDEaadP/Mqyl2q3Hmlu9BJ\nzKmzWEjKDMjiHFNFunO1spd2Z5Iei7/kEb8OrGMK6ZoGdGHKgzZ2Vq8FcYu3+J43RVrguPavNq4d\nMQt1twoTIZAj7SqGvb1tgdRg21m4SNRo29UEDrSwwjAwFlKkRaSY/Xd3tuRYN4omisTI4mIxdio5\ncVClirD2iqY5HlbLQr/dsKF0IoR0KgkzRwk7ryDMnmge/ua3E2MCsiSkJBJwa5UpR/pvyRNzjv/8\nua4cyxxYy5TI6dfxwldzTSlxa43aB+ZOtcHjsuDmrHXw8/XKIU/clYWHZcZd+Lxe+OH8ijsclkJG\nWUriL9cbL9cbt61xmme6ZO7mwtO2cm4rl7VxyBOLFo5zojbjpXYufeVoYZUMOpVzHp3hg1EHeW+L\nTqKstjLNE+feMHPO7cpdPrCNyuPdkWqG5MTLWrnPIWf69uGBW+94SZxbDYG8wskzPTtDZO9OoGhg\n+4qG1Mi10XrnsFuQc1Ze6VRpe6pEIu+LooEhU+WYGuCk3GnbgZZvVJSkDfXEt+VG80zrAY2Zcgdt\nTBJIR/Nwb1XLmAY6MIl9kRQtMtHddmpXNIqLKGvs9YF4mMascy/EGOrO2QoH6ZxSpTusLHRXigR9\nq3pI5UrqZB28tIx7SMayDIbHn4lEVFCzxGj7kkz2hEfbxbp75luMnn+RLzj8AojfPzz2ht6IUElL\n8Xe0/DLjFQ09bdb4f0GEWLo6A6cS1u5g2IJJRA7Z/sqOojuTVwjObRfAlZITSRJdjdoaWZUszlyU\ntQ22LRgeRWOvcHVjjdhkjrkwpcJwYRPHce7SxFwSl97YetjhH+bEJBK2dgHzyrenEw/LElyGX8cL\nX8fV3XiYZsohYR6wwT++POPD+el25d184FASj3PhD0/PAbbZrpxS4X6Pk/7j6xN/vrxw3Rq9O98e\nFual8Pm68ufrC6MZa9t4LEc0Benp521luFHr4KHMmIVtdrONpsrztcYs08JSufaBLUrdhFor6gG4\nPqaZIOgpN29xc6/Gt9OMdWO+T3xsG7kkel15TFOoAJaZV9udbqMxoRQVSsqsrTJQNhquyiEXVnXM\nBzeMgUWgoc+MIaDGViqjNMZIIIOMYl1Z9+72Ibd9vDB42Y6s+kolkbXjBu/zigE3i0WhSBRiR5kl\n2Fx9n/M67AukwCXGXkqYiPhx87Dhnuhse6wMRFdbeEs0lnCb2WDzTBLjXboyHJ79yLBoLVWEOuRL\nVzzlwWWdcJStZ3BnKs5qEWsPsksQZC+otluEFXabbrhtf+ly96RO8N2SHa9C2jv2xL6cwwOv+Vc/\nvxWQncNAtgDjpIhNf4tNMhlczUkoQ3c5tkSs/TY6ziBrxm3EGEWhdpinhGigP+dSuPWQUA5VsjvH\npFz3UUJK8TBYVNlGjBJchWMqLFOhDePSAlI0l8K7w4HXuvGyrqhC0cz7w+G/yj3+r/H6qouuEKT/\ndQx679xG536aWVLiNBdurfPcGj+/vrBtA83C350euVqlMfhfP/7A1io2hN89PHCrFXP40/XMFL9L\nBAAAIABJREFU9bbRh3MqicPhxDYalz64rAEjz5p4LIU6Kqjyqa1xW3bhPiu1G8tx4vN1ZZ4Kz2
0j\nucIwljJxGw2OE7c2yBpC9LucmaZEnoTXFGhAkjKZcEgTss+IP203LIdd9ZSmSIhQ5WoNmZRRO4tm\nmhlIpvsa81x85x0UVrlh09uSJRJhZSys7YhjLMcLh+nCZgUxR93plugmTKnxqBeGCObKczvxQQaD\nCJ0UHxwkEIFtb4DeOrbug1lSnN7372LE2sSyrctbFhosHq/5xmS4042PftrNGjF7zbsIraGYKMmD\nkOUOp1Q5qvBcZyqZulO+zTyijoiHjge2LZI3jEh+FKL4jn3DhseiMPFLJO9QXAKdiLaorlnwEcUN\nCKhNi9j1t1qefnnFnaHu6B4XlBCaBJ93tYYgaI4OfozOWhLZAw5UUqK6ffm+JgIA1WSLB5plksRc\nt++W39Y7OWkkQIwNTRLBoe7MWhg0cooTW+2wFA/OiQTrJIp94zTNOEbW/KtO92u6iipb67xsKyLO\neVv59niHinDtnR+vF9SFrML9/ZE5Z65t4y+fLrQRlqKiicfTgVutvPaNn6/XyHpS4TfHhXU4F1t5\nGjd0KKKZRRUkcfHKVTo63shPQRe7jk6ZCi+jY0m49pVjmmjdOBwKN4+u4dN2415nvHfezxMbQIIn\nG1E0u/OQE6JRYJ57C4hJThQJgpa4sMmge7jV2oh5n5mCOJ3BkhLrcA44N5TVGl1idvslhN1mTC/k\nvOE6wA3vM2s78nl9xMV4t5x5LBeuPpE8juGNFBxdMU7aiBxN4TpmmgS4OwT2xixCkX2WuLu4DAvX\n1F+NBAXCGhxE2VhcubCROUgQxDZXhieWHY7eKHsnrSQJffVmkaFmCOLG8My0Q8xxGF5obS9Xvts9\n3hiOu2b3l0NziFq/KBpGAOBlx1QC+ywhuvBhHuxHU3L+Rau8v1JYI3JIgR0YZrgKwxpZJ5qHxnYM\nZ3KnaxTfRMx7s8qehCH4rNiIBZ33EewFLSQEc2NlICIBzkmJ2gfPdUN1Vz5oZh2Na9swdWbP5JLZ\nemdrDRP45nBEkOiEmzGs8rv7Ox7m+dfxwtd0DXcOpXCaw544pcyfXl65tcrHy5WSE0vO/PZ05M/n\nFz5fVn64vJJEeVyOvD/M/Hi58sPtlad1Y+2RgLrMM2N0fmgX6tY598pRJ3JJHA2eRwWJYj/t34JD\nEpo7V6tcvZNHh+5MqqEmEKcn50bfo1sijkXFyUvmIp3hznVzjiikwl2KVABRoRrMKWM9dMLVKj3D\npW0ccmbURnElC7GEoePqCBZxOaq00SPhV5TETK8x39xy5aqNCSelQfZwS611Ro6fuJ8uSIq+dGsz\nl3TkJ3sEjIdy4yHfuFrIwSBy1FZCCnXUKBJdnNWFg0v0pmZsdGaJ6J5A18qXDngW2zkTb5dTPWNv\n79leAG+eKDpItlLJ3NhtvTjVMt0CFaPqzGnwOqLIdnsrtAJ5D5fT/Xw+5EtemiTHd5faPuAFHElv\nHfB+qcDoUdB3zi0SwZgbukvI9AtIB0LtoPrLCLkPRzMMb6hm2CHpzYTk/YtbbLjRbATzIcFkRMR6\nj/d62KBoROooJaD3PqBMJBfy7mBbW/yMuDhFM1NObDa4+uAOZc6FuSRuvfK8btwvEw/TRMnKasbT\nbaMN41Am0lfU7X7VRVckOkDdn8CtRwjf3XLgUDJ1GLfW+cenT3xaV8wH385HmhiTKP/x+YnnulLN\neDws3NuMqvC5rrxuN9pwEoO7XEgqVKt87DXmjhZxJ8MiBubTuIUqAKMgdDce88xlVKTAs1WyKOs2\neFgO9DRYSubcB1nBm1NEOZKZj5m1dVqC2h3djNkLQ405J57aii7CVo3jlKjNecgzFwuDRLWwi269\nk0TptIDCiOGyQ1qImO+bRIpFEfB24rYewgpcVrRsbCiPuWMWx/en7R6WH3nIV1Cnu3LtMxeZ+dgL\nLsasjfu0sfogewCyzWFzY3PjEERJbm5cPFQVEV/pbMT89U
uAg/OlW13YOMv8JdUYgbozIKLAvFls\nlYFy0EoTJcJ9wqnVR9iyu+vOWxhI2zvavn9huk+Oe9rj2eNhInuRVX4puI5Entqbu+Ktj9XobV3G\n7kZw3HIUuZHoxAPSiAewviEf3nZ3NLpk3AyKhLaYODEgQTbTpHjvUXyHkKbMVBLdnDYGljKTCKVk\nWnO2XklayAlOpcTfM6cPoySY0hKQ+/0haKNxmDJLyuRckP0Ud5/DEowISRNJ/0pS9xVcX3XRfSPY\n/+n8zKSJn25XHg8Ls2ZuPfN/fPxIG40+jMd54Thn3Jw/vHzmZb0xhrEO4/1yBB9s0vnL5QwuoYQo\nMyYFH41P/cbA2awzeQCjayDCMTVG91hSuWIa3eRL2rDmdIeTJ7oKpzJz6xtlSfxsKxPKdut8WGbW\nNpBT5mNfKVmxzTiViF8RdSpGl0FlUHpGR2y9RY0XN0Y2ejPmMrH1wSElbj4omuneOWphHZ2RhWYb\nI1kcqZMz2YzYOTirOsIqbMpa73m5PeIMDtPGIQeg/DflmW4ZR/mxvuPvyxN36YYLXKxwHguPUnnp\naT9aZ+60s/lG9om0d6SN2NYv6rtKy9gcrpZ2aZlQiW7Y+EX3O0zoKBOhId4800zBleZvSb1RQN6G\nBM0yObVQLzisLQXW0fbXfUtYTLFDc2VPXQjOAnn8FWc3vrZfevO/lsP99SAhJtZo5PnFLMdRCbKc\njYSl6GaHR7Kw7UVvSuASYBqVmPEPgdGNPKUvc+Fw0ik0w1KQyKaUSO5BMvPQZs8pk0XpHsqbSROH\nUuI13blZj0VtzpHuponbaEwamNJqg1trbKNzKIXf392xlELtga1MX4ls7Ksuuu5Bwf/d6US32Hr/\neDnz/fbKx9sVd0OS8u8+fOCn65mXbeWH64V1dA5p4uH+wLVufN4ufKob120lS2JJmVNWXtqKKHzu\nV5SEAo8lMq1cjJfRvnQrk0Yimia4EfpaG0JJsh//jJwSNRuthcSp9ISkxDzBkzd8Fl76hYNn6mbc\np8I6nGk2zq1zKjPX3nmcZ1o3ZFIuozKj3CwWVOaw9cHAOEtIrlyD4zraoOyFOJFQy2ScTQYrhusg\nlU5xQcQZ9cSYbiCGpEH1zGzCS7vnp/U9LsYxbxzyxpPN/Ju80Tw0mz+2E79NnUUiQubJ046HhB+7\n4UT44kEGg07ykK+5C3VXAiQBxzGPzLMnO+wdapQyQ4gsXQIl44lISANxZ7MJM6FbwMQjDSIYuMNC\n3RBsotCrxh9IoBx973qz7wwGDT0Y8qXcsmNr3lKCfS+2icA5pgx0xXZUpCgxyvBwwg0iNknEISkW\nuZGM3WosDmadLplMDmVJCtXHsI5Z3P4pZXyEAsMImE7RYEko4Rarw7hbMj5CzTyXFEtgmSLKXWP3\nYcOp7iwpUlZKyWxb49P1xnEufDieyCqkFIqhNgZpH2V8LddXXXQBEGFKEzoGRRvDYMqZ75YDqwX4\n5fvLM9+fz1zHxlEzkjInLXyuNz5vN576jUmV03HmKIWbVz7XlYts6AgplrijnjjnFbOYm+kOmS5D\nqBlUBxci4nwzOEksVTQ5WzJkDEYX7jR4CfOsbPQwMDRnGhoqgiRoMV5pdDFeN2PKE2urHHTi3BqS\nYR09ID5tcCiZrRmahJUW3X5rHFTZzEFtz9TaragSj5HVe1hNdeD9QGuGijBSRApV4G6+ReS5wMvt\nkW/LFRdHk3Gzmckb53Hgfzo/IhgldZZc+Xkk/q603fkGn2zi3gazBEDnxYXuiQXn074EW61ExI+z\nd7LRAVcLiZftN/cOn+RlzFSLbjARfAUk4DeFzkrwAsISYNQ+4+b0Edu0nITR4f9tNXv72XLUd5ku\nfBHdyhf24xs2R/deWijad5D7blqT+JMkYNLJ6rgH+SvGEhawekBy8HHzrsqxFAsqV2f4Xqw9Srwm\n3SOBbJ89K1mVlIQ2jD
oMFWXSwlImuqy0MUBg0bKTygqD3aYtyp1m1jyYJEfKiMMply8PmJIzy5T5\ncDjytN6oo7OUwv00/WqO+FouEaGo8qeXZ1SEj9cLc1K+mQ+4O//46WcurfK6bWRV/ubwiKrww+2V\n/+vyiXVU1h5a25yEas6P6wtdQn62SEamWIKcbQNZubZQDCgaoz91mnaGBmRELW7mWZQxDSTteswe\nE8GDCI1BWuDJb1EKmnDUgpmTZ+VilZyFWzeOkkmSvmTAOTdWDG1xo4weyb3XYXga0WGpcOuNU15Y\nvVJSWDnnVLiOhiRo7qxS967NOXpB3TBPVB8oGS0raz+yvhxjjpkangY3S3x7eA0JmSf+cv2Wb+/O\nGJCS82JHZDg3yfyH2zckd9DBrJ2fx+Dvyh77jvAyElWckw6QQd3xkJsrzWIWu1pi1igYzZSsMbet\nVr7Md5snukenfO2ZantsjQ7WESkdW8+IjpgHY3tk0JsOF+gputYC2gZYwm1ftgn84kvbi6QQ6RLS\n0UCW47LPclUI8ZqSS9/ntXkfOIAmQ8QwMZJGx+5u6C4FdAPzEYYbNYyCqVEZDDfEM4jvcuHImIsF\nnqCqpKK4xWmwjcYiAS4fY2DJqMSSd97jmIY7Fx2cysSHw5FLbax9iyVyLnx398jYl3TdjMdl4XEO\nt9rXdn3VRReiC3m3zAxz5Hji6Xrl+5cXftqu1NpoGP/wzXuebxvntvH96xMvrSICH44nbDiXXvnU\nLzyvZ5BEEniYp73IDT73C7oDw4+a6LrTtQj+bO+OenRQGbC5k0Ro3oMRODSWZapYjq2zuuBdMXeO\nqnRtkOCzR2Lv2uBBZvowDkvm1TfKolyacZcKrUORoPznBLcxKGTMHbdwXN3YkARDogva3FhyYiUY\nD2bOkhObdToSEqfSkB0im9oBz8awRJNGcuWYLzy3B17aPWoOJUA3T3bkb5dPkbyrmT/dPvBw3OJ1\nxXged9z7lYN0/ufbCZVgyy6psQ3nqJ0gFAoXy1yscNSGijGYOFvidRxYrWAjCnHSgLv0EQWveQqX\nmYcmdTXFTBke0Jqxd9xJfV/AQmu7fswF0kCHBJncFNO3gUEU0vArhGuLffz7Nmwo2eh9HyFIpP9O\nSWLQoPEwyyl4CdkSLoFx3OoATZj1XYrWSZr3z6j0seM4JaRgkworysAYFuO1CLU0GmE3FxWOmqiE\nw63v4N8lF9akqIQSZvUGyUgpmApzLtTWeL2tHJaJudwxZWFOhUMJi/utVU7TxJLzV6VY+Ovr16IL\nnMq8D/KVHy8XEHhfZs4e1tdP1yt/vDzxud6YJCEZfluOrKPz2S58HlfMjTIXDpLpbmy+cpMtOp0c\ny4gsCcsd9c5II8T8DrPOMaEsSpcVsbDmq8/YcJbsNFnJGjHZQqF74lCgGbTUqeakLIyrMKmQRKmy\n0TNcR0URWleOzGzdmYvwYhs5J84DDjnF59IcLrWcuA5jcaGPjgrBdUjsmEXhMAq9O0UTK41iBa9C\nEaUzuDkUM/K0Mrki6qyXB/pxsFrGc0Oq825+5XN74KnehTkh9ln8pT3w7w8/0iwoX3+p75kJpmtm\n8NmP3LGySON/3x4jUgZhSZ3qiUUCPfm2KHvuM2/3uYuwWuLcJ9ZeAiIzouNVjUWburBZou/Q9Yxx\n6/Oui42hqUhIpqIVDbuC+x7r40Aau5wgxhMmupMeIx7Hdydk7zGjnXTgo6NJUBkMKagryGBI2KZH\nj8TgYUJKYDKYUhDCUkmREJyDwZBLzKs7oR03hNEHaSpoEmofYZwj1C+qYD0oc9mVaZpYNNGJBWza\nT4dCzIJ1H0ucUuY0z7zsQB53+HA68OF45OP1ytobp2nit3f3TF9hd/vX11dfdJec+cPTZ5IorzUi\n/L49nTikzB9ePvPz9cyn7UZ35/enB5LAa1/5/vrMZdy4WuWYQmeYpPDUXhjSWH2jmOA5
EHydAbpS\nrcKbI8qjOKbpFlljgFkK3/1IJKkwAzTSyFSrKBNgTFmpciFlYXPIzFgTDrPQzMlz4zI8btgWSg0T\nobFS1VkdRBQzY7Ycm/rsXK2R5jCHHEowIkpR1h7z39UjsPPWw346PJIGJs1hs+2FKhYb/lyp7cgY\nhaKBHuyWuLlwOpwZljCBp/M3HI6wWYI8sA4fplfOduA/vP7bWCyJYCL8sT7yPxx/imQLdz7Wux1Q\nE/POZzvEggfhj9uCkahDmbLtyMhEkUF3pVnitUcGmzuYa3y8p6CJSSzUfO9JI4vN6V13OZoiw/Gx\nzyNHqGE8gfpuZxjyxRJsCEkcsd3fq4aNIImlFLpfleBhhA4sIZZIqX0RRgR410k5XGS9J7LAGLtp\nQ4LFZp1wtzloCVau7BjKPCmSdqFF0tDydue4JKYcOMxFMt1HFO6cUFMmTVRzmjuzOPdl5t1xoY7O\nS2t4StwfFr45HBkes30BHpeFh3lmzvmrWpj9p66vvugKcCwFd+NUCmbG87byf17P3Hrj3BrfHY+8\n94VtDP58feLTdqESaoA7mejduMnKU/9EVUe8c0wTbh0VZ7MLJY1wSHmhmjAhIOFnNxpjBFwlewjd\n57TiaeAmMV9MzixCyhtdHOEKVtiGI33CxcjZqHQowrlHIXaHpWigAKfBS4eTZrbNOKTEcKelwUYl\nu+IetuHsQh+OFKFjlF2/OVmmJaMQyMelpNhAW2KTWJ4N9xDdj4miid4WXmWEsH+68bLd00ahiLCZ\ncLPMQ898c3ylWaKR+OHyAV2Um2VSHjRPfDO9Upn4H1//dofUKEMUZfDfH3/8InV96QfOfeGQKyrO\nlZk6BnVkrjbRUbaRySm4v70LWZ1qCfPEtU10i8UcezF2I5ZnDlmdYW+qCN/NZtHJ7rtGviDH1OM4\n4xndE5E9RXcc8tuE+EC97xNfR7Pu01wj5w5iqCb6eAP9JFT7ri1OiFikESOIVzQVVD0SJlq8ZoxD\nBjnDaCH9tRSGlGkKCH7Dab3Fx0rBd1NM0cRqjaIZ0U5JhaWkvbA6S5mDWpYTU0o8zDPuxmvdAHi3\nLMz5qy81X66v/p0wnMd5RiQie15uK1trHFJm7Y3vDkfWsfH99Ykft1eSCKaN35d7ugzO25WzvGJe\n0Rx6WlFlWMPyhnn9JYF1l+0UN5JuiAawpVmk22YdqK4Rec2gtgknMQsMrSzFGdrxUWhDyCkE/SVX\nugyw6JTNEzrynl0GmwxchGsP+2+1zlQSG5WpKOdunHxiszAebFi4hvoWHY85Zege1QLs+tXZE7To\nvZsY0hTLsCShmSCubHbGdIQ4SiFvJ9Q3Xrc7Kh3VwTxf+dQe2M6FSZ2bKZc+c8grvzs9UYcCEz9c\n3tOmM5sVUupso/DN/Mqizv92+R0KUThV2YbyON1QjBsT2yj86fbInHo49Dyx9ZjTrn2imwbom5Bg\nuYcCYOthBW4WS7YxdiaCCftKa1cP7JIxMowR0IS9Amtykoygfu2vm76kSQxy6oxdsyzaGcMpJR7e\nNjJpFqwborAkp3qPZIoEssc6pxGgc9F9Abbv7pJG12275s1MwnF2DC0xGEUSpkaWPR0Ex1IEm2YN\njXNJKbgXJfPucORuynxcN259cCyJ3989cJonXreNbQymlPibh3ccSvlvfUv/q7+++qI7aeL720vk\ng/XBNgbvjwsP5cCfz898f37lp+3KuVcep2DPniTz83bmYq/c7EYRJWfn3o9c7IamiuuFTNwISsF2\ndldOK05s0ZUQXxZtkI2sbXdSKes2M2cLfi0X+p4V1rcFAYoozoW7A3EMHDMdZUZpbkzHxrZvxVtV\nsmRkCJJDd9kltu+rDTISyQ+aWKXtagVnKYnaQxjfMLIq1z44pkzd4dfNW/BQLVFywntECBWFTSqp\nTUjpHFSoFcbIbHpDpoZagMHHeke3jR/bIyMFNeZ4
uPCp3dFeM4t0zpZ5HQvD4e/vPgbLNis/rw+k\nXfaUs3Hthcc5+L9/uH5AgGoBlhkeIZpJBpXEcOEv14c4/UvMbdsI5HgbKeafprukLFgF5oKNkJQB\nCLarFyR+tbdctNDgioSe1z1MLplQDBjClN/CK+N9UB9ftMXuhqoiruAtHtgp+A4moV6IlVyMdXZX\nMo5+cRwnwAuQNUbLJgyJB+qkSsmF1gejd9yc++NM2ZXD92ni6p2k0b2WrHx3OnFrlXNdKXrkw+HI\noWRyUorq3uE699PEcZq++tntf+r66otuCLMV1ChJKDlRe+d/ef0Ll1Z5bit3U+E4PYILP21nft6u\nrLaSkvOQDgw3kjSu9jOeKzA4pD3+h4nBjTlXhA0n0Vz2mPOGiDPpjToKjYR6EKselht4dEgNJVli\nbYXT3IFO0k4du7Vzm/dFHXi5cJBgLagXxogFnmPMx0aVgVpm68IsSmvRdW0Myi5ba+b4yJjGpn6M\ngU5Bv1pSdMQBnhlkVeqwiO6uHVWj7ZlmRTJzmmAjNJ5qVO2MemDWyqLOrSvbWkhaSdO6y7Fgvd6R\ny8Zzf0fLsex5OJ15HQf+4/lbFm1cR+F1LMzS+YeHjyEVS5nneuR1m8k6SMnZRuJYGirw83rCHTbT\n2PrjuCSyDnrPgPJpm6hdQSLaJnjBMWYIWReRtku8J8FetL/iKwySRgZacsNs18aKkyTGA3mXnplB\nyYLQcU2UNDBLId3CSTooyRgjvGpDJfLRGOSo3l8YC6qRxmE1zA1ZoA+Q3TShGY6SuMhgmFA8ZF8p\npd2+Kywpx/dfIrlkyomclCklvpkPfFKhW7jhlinxtw+PrL1xqY2tD94fjtxN03/r2/j/V9dXX3S7\nWQz5U6aPwVY7P98uGM7WB4/LAcT5cXvin6+fMXdWbnwzH0k6s43O5/GE+Y00NWYckbLfMCvIMyJB\nu/qSgeWZrFemXMnSWG1C1ChDUb2gCEU3zv1Id1ALqPT76RrzO5Rtd0mtI3E3R4ecGWF/7YrZBLIH\nGM5XVITWAZsYbkxMDOlMi1CHUySzVaOQ6QazOKttzDl+rz0hHkVDCcF95HQpE8a2L+2qDCYNq+iJ\nmbWuOJHvpS7Mlpk0wfnExQdDBiNXXrcjd3JjVuPSE7d1Qg7OvNxippqU19c7etn41B4Z2XFzHk9X\nzIQ/Xt5x0M6lz7zaRO/C390/BydBlVsv/Hw5kdSQLLQBOTlZnUtLYJl1pODimuz6WaGNcLPVpoyx\nL0Bj2rvPcxNID0+ZOrQ49kcqRjjLVCS0xuzOtT1Zl333lRjxf3yzMqNMqaHuiGdEApxjKswyuLmQ\nkiOmSIn5ubZ4PdXEvJ/ou+xjrSkxWmQCLstC37Z9TAYlK/dTYd0B6UK4IBHjfp745nhg0szztvKX\n9cLDvPDd8ZGcMmOEfXdKmYf7hbvp6wLX/OdeX33RzarcauNprAx3Xlvl/XHhb8s9T8uNP19e+WF9\n4afbC5Nmhgy+kxOXtnHxJ272iirk3DhwYtBxVkzOzLqBhFHARYLGr2dg0D064ZycNBpNjWV6RSXo\nr6/jwIGAySi3vfMZfK53b9J6cup8kzfcO53COjKFTFO4m0YwEByuI2GW6FbiGy4DKTeEGHOITAwz\nJtV9IQfVnYyydqNYENAWjZTiY47MrKPHKGMWIrE2B0A7uUQ2nAxUEi313WLsTKNQfcNpDKKjmnph\nkkJ7PnGxwVDDp41zPUB2Fhm89sRlPTBMOR6uuCVGEl6uBzLw6XzPjhvm4bTiDn+53nPQxjYyV5u5\nWuLdtO3FMDGG8vPLHNu3FLpjXMjJ8CGMEUV37B2svy3MRsx0BcF3RKWIILYf+Z3QIEtIutQjn+xN\nBxGQmhTkr10r3BDEoZtSdKA+diBTjDUGQhEnJ4tRgQrmjg2LOJ/9JK9JGX03XhhogmPJjH3pt9Yw\nu9wtMzbgkArH
MjHqxlwyyRNzTnw43lF7cCZKzrzXA4epMKfE3TRzP898vN0wdx7mmWMpX5Wr7L/k\n+uqLbtIQ2Ye20sJzjvN/vz7xabty3jZUhL+5e2Ri4nM78+P2mVc/k3RjyaHxXDRmrElfSdxIGmGI\nKgvDG0tayVzJ2hmuNJ9wh8Lgm+mZm4X2VnZ//wOVrltYXtUZpnxuR+6myBZLXHdrqvNpvY8AAoSS\nKsccNtl1FLaRKVoigHIxOmFeWJuiI+/HzP14miP224llzBiROKA5bn6zwSQac15TbntXu5pwSsFP\nnSQ+VyFA2QcUZyJ1pXvnhoVTazZmD15usoXeVzqdIdE95jqjafD68chHHNWBTZWXuiB5sKROa8rr\neqCkzv3pyrBQM5xvM70psDA06GCnpaE4z+vCnBq1JzZPVE9MaojHbNYRXq8lZrqqKFGAZRciDNsL\ntscgJDrehAzDB8E8IKLKRSBZQ1JijHDbCfvYQALjaGRMOsoAUZbUqL3gWsANSaHnxRwzaF0hxQNY\nJUVApzq2z+uR4FIcigSO1+FA4aZxSiklZIJ3eUEK9GFsNphSCV5CyjxOM785nfh4vVL74JjDCPQ3\n9w+81sqtN5aS+e505HFefi22/x+vr77oDjPup4npkIPiBPzT54+c28q1hYXRk3EZnT9cv6f7YOXC\nfZq5Xybchaf2RPeNpBcWrbhMJFcWvSH6mUlWJulkscBIemYmoscn2RhEBMo6ZtyMos4pf+ZTv4sO\nZdd3vsvXoEdJdD7VMs/twKnU3cF/pRE5Z0/rEXbMTs4bS3bcN3xEzE6WgqmzpEjzLR7ys2yhjCgO\nQ40pOdsIufAYsm/dwZPGDNw0WAzdUZRqg5ygifK4Pxqm4WwjrKwRCzRBPSBmbB6wHLWJXDZKiqM/\nbWb4yvBQXgxPpDYBjafrt5F/pobMld6VUgdLbtQuvGwRnX5/XAOj6cJW0+4eEz5zxIYxTU5SY2uJ\nLB6drWss6djNDTuKa2yKjbcudRCwRgWPNFyPCkrCSLtJTSRaTfMgjZklSjJIUUz7CCYHSLAfUoDI\nu8CiAzNFXCgysBRwm5ScNizwmZ4w3RdrKlh3UgnkY8mF4jO1dS6tYq58OB046Uztzn3bM3BXAAAg\nAElEQVQp3IaRs7CUQlLl93cPDDNuvfG03niYJ+7nhayJQy5MOXNHNCrvloXpKwuU/Je6vvqim1Rp\nZpy3K82Mp/XGu+OBD3pk7Z0fzmd+bM/88+UTXTqu8O10R7dGtc+s/kROIfU6pCOHPf5m889MemHS\nG5M6zTMqzlEvHLkiMshuzNpplrlI4j6/ULRR6Fx85l42thFJsykNFtn4oT5+4QUkMd7lK54itbWj\n3Lrw2mfup5i3kTeqJ8ZIXNq8x3EJpWyR1+Ux/x1jkGXCZZALuwMrUUdYfluL5dlqzikHi2GxmAef\nRNm8M+fg4S4yMZqFhtYGqxpJM+pwr8J1gLjtxUM4pkSSgtSFPgYyBlty3BfSslKAOgy7lQhpyKGB\nHSS0ZhTn6fZA9wTJ0KnjZtxaZskNQ7isM6MnprmHnladMZztlkASV5fdtCChDOgeR/geDjPzvxop\n4GGZe3sAmYbpQTwCPc1RnBEOXTIdISOlIxYjBPXCJC2WZzLoZNqI8UMWUI9/46J0UTqxZMvJyIFU\nAAvzwSTGzWMsdMoTZ4uxkoowZeF+nqkVihfu84GbNK4+UIN3y4l3ZSEl5WHZ6XPAIU0cysTfP77H\ncT5db9xa47hnnP1qcvjPv34tuvvMbB0B7HbiTflUL3xcrzzVG5t1fnt3x0G/4eadl3rhxZ5xeSUn\nUCkcs+43ymeKvnJUJ0tlliWiaKiovHLUlUyn+kwlgSjf5DMLka+W3JhTRy2ScL/JZ3Rnqj73E/d5\no1oJ5YM62Qff11BWdBJZjfflAj6ozGxDoR+4mXMqA/MY+K0kvCVuLZN15wCkFo
oFUXqfwiFlGtSy\nHbA9EfrW2Ql7LsZtDO5KpCzMMlHbYJLE6p37tNBll5nVSB1mdLpALop44oBy6S3cXsMpOUPqJGa4\nTNysBkJSCTBMrmRxmht9KwGGKSNmtQ7UsKneXo68vkHFSxByexfmbAwXWs+0JlB2BoKG1Mt7OFSa\nZ0RiyYXKFxaYmQYcXPRL/ptYnDVwR5IxLEwU2QaWo3CHXjajKigdPCGueN6dXwpZjT7AJfLssNDi\nFrWgmL25H5MFe9d9d0NGx1qyMoVQhCGRp/dhueeslUUzIrEwez/N6CR8uxz5zeGen24Xfrpc+HA4\n8t99+I5TKbQxaGNwmArfnU58ezySVH/tbv8Lr6++6A53ppz5/TRhROLqP336yKd25WlbySlxN4UF\n8p+3j5zbhXWsnErhrjxy1Dsu45XGK51PnNIroCRVDuJkLhz0zKSVLEam4SwUjPdyo8hKwpi9crZD\nyMaA79IzSWJppm6B9EsrSZyiZ8BolvncT9znRh+JIw3bmaY/1QccwTxTUqfQQW6sFijDUjM1OQcx\nuneyJOregdUWcqKBMJXA/M05Y00o+0OKnMnJKFLA42MqQrOOSqLjPJaCoEwIW43Mrm6Vd9OCiLNo\n4bWuWJcd8i0cSkZVmYdS3XB10giRfmKAFMYtWK4JxXYIjABTimghb1Ok2+boCA2DVnfb7BRKC2Bk\nRWkh/9obVzelVSISKdkXLgHd8A6DRMYwB3CS2a7ZFfCxx5yzz/Md/ytmuY3ElFtIybJRe0LcEA8d\ncEotCu14iwwKPe6U4dYESYFL7EOY4As6csqHSJXwwdZCI31/nOm1sGjhmGfcNJa2OfFeFv7N8R2X\nViO1QYX7aea3ez7g3VT4zemOT7crTgDN7w/zr8qEf6Hrqy+6b8L1l7rRRufzbeUwT/zN8p6/v4fz\ntvF53PjT9UeudqXa4P1yYM7CnDpP9c8MVlRvvJOF+/SOJAc2e6LojUleOchGWBCUWRoqF7LEGGHa\n5UarH5j1SpYokA3lvZ55HkcGSpHKb8sz3+/jBYBZBw9ppUgnZ2e4cLWZV5855U43ZZJBl2AcfK4n\n3JXhmZIbxQRoXG3CuzAxBd9XheGNJBObw0zCGxSPmJ5lylTrPEwzfZPIy+qNU0k4iWPO0BUhlm91\nZw2oCN9Mp7DNmnFtGzOZljsfyhGxWArdWuNqQnaHlEhzWGBnz7Q2aKK4T0gSJpzmnXEL5cIicNM3\nptcINYgPGCnsuFMEgA4GabRYeg2JB44ZI6VdQcJucAhUIqY7J2GAxYItnhQRVa6EhA1JKH33Rygi\ng7TPfLNGnEQn4SMkfkMlSGltwjyju5275EG3kKh1zwiJopHJrKpklOaFUmCWTE+JRGdKkaD8jb7H\nFuE2KtvYEM+8ywcWnXg8Hvjt3R2fLlfW0VCE35/u+bfv3nPtjTo6a49k7ICO/1ps/yWvX4vuflR6\n2VZmzVTb028x1t742K48tyvv5ol/t/zD/9PeufRWll7n+fmu+3ZuJIt1aXWrJcudGLENZxAgg/w+\n/4Agv8SDIKMgkwAxPAgEBLEEJzEs9a2qurp4OZe993dbGXyHrVJbSuxEotqq/UxInkMS5OHhe9Ze\n31rvy5wTpxL4Or7ikA6gRgDWtmOwhlbtkfIzvAl4NdGYlk4ZwFHkns7MeJkARVAWwdGrGcOEotSx\nMhWZi0GUo7f7assnmVE8WzuyT10dtVKJwYx8Xi4p5+nRxhSKJBq/J1Mdsg6l41A8vc41caIEgtST\n82No0GiS1Gj0BChjIFssYHM1RS9KSDnjrCbnuiIaguAVpJLpWwvFsGs9Ya4uZqcws3ItWYTenHsY\npdoNTgWsquGFF86TsmLUkWlOtMaDifhzYGGWwrEETilgjcVoQ+tq9JEqBY0hUM20tShWqiBGCLOl\niKIxLVOpB2/kgDaZmoigUEFRrD3nqgk2J6
wRcqriWBMlqsGN07UtYar0Ycp5caT6MZ4tLc+CTY2T\n1+efSemCMYUgFiVgVEGSxriaWZyxOJ3QpRBxKA1WNLOoc+R69XpwWuGdQtkGrzookTHMFDq2bYvX\na2KSuoqdLcoLK9+TNXy8vsQqw10YuZsm+sbzYbP9xnqxsZbGGqaY2DXVL2Gpbn/zvPeiK1LNQL63\n3lJEGLzjq9OR1+M9xxwYnCMox7N+IHAky8Sb6S2JzKVvuWr+hPt4QGTklH9Ob24pKqAwtKah14ZG\n3eHZYyg4MlY7tFL0ApqJRid6AsezGYtg2JhELgHOctrqiBTNoDOb5hZFFYD70jLkmZM053/yRGNm\nXqZtjZrBYTWsTKKYxFxcNfeOHaFAbxWx1NngUArWeOYZnLKUAl5ngkDTGHKsiw1KGVw2zDGCcTij\naaT2UcNYrQxzMgzO17501zGHhFLCsQQG3+JF09vhPHqlSWkGqakdK+/ZqYaAZkonJCkG22CVxbja\n4jiVxFwChzzhjceZBkxhzHPdxE2aRtX2is4wmAxGM8faMnCq5ZQLSue6tefkF0m7cwHtcOR6WKVj\nnWQ4++yqTPXNQJ1TelUdK1OCOqcAO1XNikQXsoaSDNYoPIGiauR81h6RGnuvVUFLqmYJWVGKpYir\nrSFdCAqsNnjVkSTTAllroqqiOiWwqmPnVkw6cigzVmme2DWXdgVeuOhaNDUbrTeelW344e4KqzVv\nxgNTCrTW8WKzWVZ4f4u896Jbd+P1uSdY98z3ceaKFT9wDV5rbuPEz49f8SbdUkS4anvGrHjRtwj3\nuHLHzfyG1mh27gJv/oB9usOrPVk+p9PHesCjLE6DVwYnh7qmSh0jK9R/SpGIo4AuXDByK/683aS5\n0DOzNNUiSoHTmViEjQtcMlGkcBLPvXS0OjNh8VKwaqYYxeuwqu2F4vAanC0IR6biSVlTQk8p0PuG\nlBJaWeYitLahzIreWFKE1kJUke3QkqOw8/XU22bFlDIr2+AU36QdTyGjFFgM66Zj5R2DdUwxMcbI\nGAMX7cAFgvEQ5gQoJAWmovEONq5BG8NxnLiTCSuKxvc0VEMYXQwnNeKLsC8zg/YErSg2EkqpEwap\n0MgaySOqtKx9JsmptgqC4GkZzyNgSQqiyjkRvQaSSqliqRH0OdfB6DrNYMy5naAzVurMbVEad843\nE33OYSsGMfY8iXa+OtGROWuKslhq2m5djEhIUeTS1BRepXC2pQ2RnAOBzOBartor3p5mBmtxSjEp\neNFfkbLislnz/dWOr04Hvjzuedpv+GR3ya4dyJIJOdM5x0Xb86QfFvvFR+C9F12o1nNvp5GYMkXg\nab9C5sLa1cHvxjpuwh2r5jlr39Mqy6fTa+7Dz5nKaxSGXTNgFWycQnGL6DfEfMugPYN5QqYllAOW\ne6zc0aq6u6awOF3/+bTMGH2+nBRdx9OkIOf6twAbfeSu9HXsi8LOHjhGRwZQioZMpwWxEyAkEe7z\nwKH4ajoj1WshlEjBcDs76hyWpzNQVAE9Uc7WJy55rGicr3HeylTj9K3rUEnjDYyhxsEbrXnSNegE\n27ZjTAGnFGPMNWXZKLw2pJIZQ0IEeutYeU/nOzon7KeCmOr0dtWtuegE4woxCWNOdM4SgkUbw7rz\nqKHn7XhgL0c68QymoReH8dXu8lSOBImMBVbWMCkouqNIIuaMVRZdPJpApmGlIomEkUQSjRPLJKo2\nMFRdIqmTYgadqziaVP82SqfqToZFziu+RTROEkYLVhWCbuqkCBqtC9pkxBgkexTxnKlmaF1gyoqi\nNY1bQQQIpBLAQOcHTFS0NLSq5bIxZBUoKJ74HR+219zNE0Pr8MYw+IYX/RpnHLu253vrNTfzSEgZ\nEeHpsFrsFx+J5VEGGmu57gdyKSil6qWVEV6d9nU8RuCPL1/w5fSKC78BQNQTPisv+cD9MYPtMErx\n8/FvmM
sblLzGqAZrG3oz0Nk95DuyvkEk4nVHo5p6+CIBQx0jK0qRESwWr+rqrah60GGVZirQaujU\nTKH+g01i6NTMUTpAY3Vmo44cZlt7umisKvTW4jkRlCZkxT4NjMXSG1eFWGAsEaMthxk6bbHisEZx\nkprwW7CsjEEVR+8c+Rw+WKRWhLumQReLcoVTyngczhoudnXIf9M0HGPAGk0MkaerFUoEpxVjTBzG\nmkK8axpi03G18iiBuylwm0da7RhWPRelJ+mM5MwhZta6IZDoTUvvPaJ73k63HNLISnu0HZg4YYyl\n14WjTJTsIFkGo5iVkGhQciJJ3ZSLeIxkwNOrVCN7zitnjoaQBSOpujjq2t8tStVJBK1qllqpyxcZ\ng2AJJLwqxHPoZSmepBVSHM5ZnMlkyaiiSapDaUfraiRSzhplCojDoHjeXXM0mSAzMUdQwoXf0uuO\nVdPwbBhodb1qS67ww80VH+92nGIkFSHkTO8aPtp0i9g+MsujfcZq/UuntNftmpVtSVLw2mC1Zioj\nN+G+7qNry7/Y/SHH9JbB7gCY8wtK/jsG+28YTEMse+7m/47ke2CPpkFpwekLGhVp5UhR1T/Xqw5N\nR2SqDlIkVghHqYcooOm1YSxQCPUSVxVSsaxsYS0TuQQCjhMOrwujOBwOZQSdTnxVBpJY5mJxSuOc\nAxfZp5oMbGRAFcPOdUSpT46QEo00zFF40rQY0RgDhzizbltsslx1lpIUW9tyTLGmAafM4D1bX60o\nQ85MKeGdo1OaZ+sVusDgPcd5rn6voniyGbAosmTuYiTFjDOaD4YNIWUu24apZO7niTfhQG8c23bD\ntWyZ9IyKhfssXJo1OWnWvqe3jlA6btLXHENkpRxKOzo/YpSmFThJIpcecqI1rto46gkpU61W0SRp\nqotXsXidSedZ3KwUxrRkyRhJ5xfT6jmsRIMyCA6vCrmU6kCmG5Jp0SqeH2eFWI2lpQCNainakEic\n5gnRjo3tQDqQUo3EraMxmsG0pKL5ZPW8+iaXiVMKeG/5gVux8y29tXTOMTjP7TTRuvrCuQju46Pk\nm3zoX8n/8c73jSKFfTohIvS2o0jgJ3d/yVxGBNi4S1T6r1jdY3V3TqH4K9bytzj7FK/WxPw5Of8t\nO45A4ixtNOoJVkWUTBTmGh+jNEU6IsezucrMvhSOqGrMomr1+TJ1pFI9ZYsYjmXgvgwYLHOO7KXh\nPnnexh1JHEo1TLnOjd4Eg5SGUBydafGqAYG3UwbR2DzQKE+TPaKqSY+IwhWLLorn/ers7auY58im\n62ikJseCsG1aTjGilSalyFU/0FtDyYqj1MtpraG1Hq3qi1/jDHf7kViEkBPbbgANKScOYWIMqf5B\nxIAqbBpHQHh5OHAXDrTKc9H33Ewzk5ygKO7ShBV4edxz0QxYZZll5D69ZSpgikLRMOcJZyFm4SSC\nEuFUMoNyxAxBjSABMYYS6yJJoSCqQ5hJUn0URDuU9KBmnJ5QOMTPSGnq4oTTGNvVqRU1cQqerrFo\nGooqXHUwRk0ks7KeMWWera9Ym5b7+UjjQGfHuhm4bDYoI/xoe4UWy1fjgWf9hrX1fHJ1hdeGN+OJ\nbdPROseuaegX+8XfNr+2Mb6I7v8nscwc0x0KzcpdcIqf8vP9X1CtqAsX7ofo+B+x5jlKWUqZSOG/\nMJBQ+glGr4npf6DLDQ0RzlUtFLLagMzACSGSiuKEJtOS5UTBkGTiVeo5Sl1E0MqSiufz2BHOvg1F\nGo5lIOaOojxjitykhjEZ5rRFxJ9P8yNaPMegsHSUIjxxW6x2IMLbY6BTjo3q2fgGlQXnHFIUtiiU\ngk47rtsBhWJOkVIKm66htxalNTZr2sZynGaUqqLypBkYnOeUq0G2EgXK0DqIRdE7jUZzcxyZS6aU\n6tsacz00izFwmGobJpa6WrzpW7LOfHm35+20ZzAtvfHcziNBzUiG2zih
C7ydTwyuxRnPmI6M5Z5Z\nQIlC0xDTRFGQizBTc9NGKXTa1lViRlSuaQtFDE5DVoKmq1tyakIhWGfQtCQRvB8pyaFMxNmWkGus\nklctsUSUDeRiuGg7dn7D2+nEddtgaRll5sVqyxgKz/ot3+s3fDUeaL1l63ou254n7YAgXPU9u7Zj\nHyKXbUvnHG6ZTHgMFtF9TOZ8w5y/xqiO3n7AYfoPHE7/Hs6b++vmX6Hn/4QyH1b3rvwG4k8xqkHp\nHUp15PhTkJnMkfOSKiKao+oRCWS5J0tdJ77nEsEQ81hXf8vMZ2HHKB6jWopqEen4fII5W6JolFqR\nSoPTLVkaDjHy9WRIybBSO5xp8MVwSgkjFsmKjqY6rvUXGFVjwO8OE7umY2MaVrb6B/TeM8VAb+qc\n7MZZNn2HKnDIgZKFla8iKCqjUHSN4zQHkhTiXHi27VHacJhnbqeper1qhVeKjKI11Sry7elUhTgJ\nu7Yjl8xdnIl55i5EnDhSCVhtGZq6UvzF3R334UCjWwbbchOPRAmkrPh6PtFZw810otUeZxzHfGJO\nR5KuQpyLr1cjuZrNnM57EkGgMXWiIamE5Iwxish5NtdBo0ztzarEMSu8qS2QMU50vm74nSSzdh0h\nQe89z8/bY5kTTjdszMD3V1ccJXDd9zxp1nx5f8ezYY03no82Gz7cbnlzOhJS5skwsG1ahqW6fUwW\n0f1dIiLE/DNKucHoa6x5Tjz8O3L88bnnZ/HuzyD8FZgX9WvyzyjpM1ADYi4pRSjpp5wkEeRwHloS\nkIG36impzMT8liCKqawY1fdrzlqeOTGwDxMvw0U9eNMXFN3h8Hx2HImlIYmm0xcY0Vy5LYcEc4rc\njhknlmf2kt7Vg61cCqZoFJq19nht+GDYIApyzowhsGs7etPiDcSSaZwn5ERnLFqEvmnpzv3Eu3lE\nFaHvG7yp1bkUwWvL/TiBUpQcebLZkEribgoc57naGlJ9ByKZwXlyEb467BlrpjkX5ymK+zQRU+Lr\nU12COeWZ3nl633CYZ77Y33AfR1a6p7GWN2FPlEwBbqcRlzV3JdAYi1OWU5kYy7FaaipHzpqiZ6TY\nGl8vZ0MhpVj7mi6Cmgil3mZtfRHtnaVzHacUESZQte/60eaa14c9zhbWvuOYEh8OWwRDpzw/urjg\nfp7Zx4nrfsOzvuf76wvGlNBa83xYkaSw8Q2rplnGwB6fRXS/a4hkSvprkBFlfoBSLXL/51Be1rAr\nvUP0BZI/Q+krAGL4CafyEvRzlLkk5z1z+p98nRtOZULhq+WhesatfMicR47pDbNYslxQ9MfEEkgl\ns889d2HiNvTEbLhun2NocUrz+f4OaLHiuHAbrFJ80G05xUKMkXHK9NbzUbOh8Z6QIvpsdjgYi0HT\neMuTricVYU6JmIRd17JyDdrCHDJNY5hjxp+TbFtva86awP00ogs03tF5SyhSK0dtuT+dACEl2K1a\nigi3p4lTCHX6xNZU45givfeEUnh9f8ecEqVoVk2txN/OI0rgzfGAwTJJwGvP4D37OPLl4YapZFw2\nYBz7sCfqjBLNbQjoUhhLROsWqxWnMjOlE41RaNuSi0KpiLO++lPYSEgFUZqn/Zo5RJTLiBSyEgbf\nkxM4rfneZseYIqILTteg0X9+cc0xRYyF728uuT2d2DU9g/OsfcOPLi/JRXhzOnI9DGyW6vZ3ySK6\n/xSQcoL010AB+0dQbsn7PwcZ6236Y47lLQqP0gOlFKb4Y97mjDUfoM0LxvSKY/yUz+MFUwlo1TOL\nMLgfccjXHOKBr+ZXZHpafcnafsghjohSHILnfj6RQgNK86P+Bd44ckncjhOmWAbbcOUHjFF80G44\nxURKNeZo0zZctj2tVZxCwltDTIWhcVgUjbcMTQNFmGIi58Kqa3GmpvJmMvY8x9s6R0yJznuM0aRS\nuNmf8M7ijMaos91jShhjOU2BUgpR
CperNVkyr273hJjqSq3ShJzJRegbS8yFL97ecUwJjaK3nkOa\nuI0jRhSvjyccilEiDkPnHMdSeHl6yzRHrK1V9ZgnQqnbbocYzrE9dZFEK4uyma+nPcYYBtdQiqGz\n0DeOfUhgBI1CDHyyecJxDuzLzKr1pJjZNQOdtzgx/LPL67qaPp24Xm3YWc8fXDxBEPZh5qrrcdaw\naVq2Tfu7fjq/7yyi+08VKV8j8W9QyoH7U2L4CePh356jXwra/Rmvw0uceYpWjlIib6Yfc1Muae1z\nvH7BTficm/iKl/MVo8w0esMxKT7uPyHlgbfhwKfHVzgzcG0veOaf8HU40ZmWw1S3xoiKwbb84aae\nhs8lM84BI5q1b1k7U2Nd+p45BOZUUKWwapu6OGE0MZdzTpymaTQlKxrncLZWulOINf3Ca7y1NaWh\nVEewOWZa78kl4UxN2EilcLc/YZ2l8bYe3sWHUEjF/f2JrApFFCvfkHLkzfFIjEKSuhhRpJCk0FvH\ncQ58cX/HmBOaajA+lZm3c92A+yrM38wzWwyt1Zxy5KvxwJwL1mlSUhibyRq0KLLKCDX2/robiCJk\nMmOZsdpw3a9IAnPKPBtWFBGyJC66jvsQ+dH6AmsMxxy47geM0hilue4GtNJ8tN2x9p63pxO7rmPd\nNEvs+XeDRXR/n8j5c3L6O5RaY92fcDf9Z94c/wJ1nnpo/b/ms+mOtXuKUppQRn52+CmTPGXrr+ns\nFZ+fXnITDuzDijHPrPSWqQh/uvkYoxyvT/d8drhnsC0ftBuu/MA+1hj6acqEnHBF01nHB9sNBphS\npJQaSTO0DofFWE3rHaUIKScowtB3WF3XYFOpMfFFBG+rWBijMEYjIkxzRGuNswajNSkVSikoJaQi\nWKPJuWBtPZEPKXGaEiihOd82h1R9chXsjyPlbDC+ampI55e3e6ZUSCXjVTVmr5HocAgTL+9GjinU\n2TYjTLlwM52wSnMXTiSBeI5e98bQGsUX86EeYvU9SmmUhr6x3E6BLIXeWeac+GhzidOauzhSVEEp\nWLmOy6anIHx/vcEZy8vjgauuwxvLJ1dPuGw6Xh0PGK256DpWzrPrut/RM3LhV7CI7u87U/wZMb/B\nmh2t/QN+fvxLXp7+G0ppBOFJ8y/5dNpz6Wt/+JiO/K/95zh1yVVzQWcGPj1+TchCnGsSxM70FKX5\no801Vilen47cjTMb57jqVmydr8sGTUNMuW70IdVruGtQ1NBPbeqpf9sYFBqrDdZWUc2lehhbY7HW\nUIoQUz5H3tT4ZKUU+myeLVLvVwqM0eeU44JIOSc8lDpZQam/+zmMcZwiAN4bpAipZFIUSob744kp\n10OzwTmKZD6/PzDFyBwSRtfHMJxDKEPJfHHYcz9POKOYBLJk9jGek5812lhCjmyblkBGFcWRmZwL\nzzZrOuN5NR657Ns6EUHh6bDidpy47FqeDivu5wltDZvzcsOzfkUohYuu4cVqy2GeGbxn3TTLksN3\nj0V03zdEhH38klCOdGZHZ6/46f4nvJpeYc5WhB+0P+SraebCDwDczEdenu5Z2xVXvqc1ni+OexyG\nFDIFYW08zlherNcA3J0mYs54Y9h0Hm8MIRVab0mxGvdQwFlD1/qzaArWKqSAd/bs9FaFFfgmq05r\n9Y3Q5lzQWlMtdtU3X/Pwu+ZSRdfo2nqIKZNzrt8vyzmmSKrXuAhzTIxj9X9QuqbqpgIpBlISvrof\nz4kVCmNrDPvr44HDPDPlXCcmpM6KhVJIWfhyOnI/zTxZtShVUyS8s5xiIKh6+DjnzK7pWLWOMSVE\nK6YUWTnP9zZbDmHmqu+5aFs+3++5aGtG2Yv1ih9uL+to33jk+WrN4D0XbbckOXw3WUR3oW7U3YQb\nkiRWdkWjW3569wX3aURTBe6D5oop1fErgNu5eq9uXMfWN/VSeJxw1lBiRtC0usZ0D60DhCnUXDKt\n
hM559LnatlZX/1ltQArGGIwxvySqUMX2XVGFKsQP9z8I8cPnfPvtw+eXIpRSMKYm4BaBmCIKXYM0\ndW1HCIIUxTjNjCGRc0YQ0tnI5hAmUoFX+wPHMGOMBlPNw19Oe26nmZircUwU4cK3HCWQqIdkJ8lc\n+46LoeXv7u/rwaI2xJJ5sd4ypRmN4oPVmkNKxJIYmobLpuGD9ZaQawrF8/WWXAqds1x2/eJ1+91m\nEd2FX02Wwl04URAG2+CU5bP9LaGUmpGmNZdNV6vV8ybTcQ6EWJ3FWlejZuaQsKamKyitMLpe/vvz\nZW8IdYpAa4V9ZyPqQUgfBPPbQvvtz/s23654Hz5+EOlyNjF6uK+U2oLQShNCJIIB754AAAcaSURB\nVJXCPEdykZo8HFU9zFJCiNUNbc6BmAqFQoilOrfNE1POfB1O3JxOtM4xS8EBjbccQ6zthW7gLs30\nzrJuGpJkPIZJCiElPr7YcYyJVeN4vl7xan/EW8Pae1rr+KMn16AUr497dm3PtmnYtO0yd/vdZxHd\nhX846RzDLUBnLRrF3VjjzKGmbXS2TgtoXZ9bc0xIAWv1N+Jcq1fOwqfRuornu22EB6H9TV0iPzyf\nH4T3QWxra6K+fWg75JzJuRBzImcBUfX9JOzHmZgLkUxOBaM0QYR5njnGyDFEYskEnTjNBU3h6zCz\nTzO7puV+Duwaz9Ptmk9v71GmGrSHELlerxCgiPDRbsthDhxi4KJr6Z3j4+0FSmvuxpGroacx1aym\nHsotYvtPhF/7h1q67wt/D6s1a9/80m3briWdK09nDEgV2tpLrT1b+46g/gLBWotSvxDCBx4Ox36T\nvPv93q2iHz7+ps/7Td9YY0SjkPMLQ/WD6NqCz4I2isMpkkpimkbmVKo1hoK194zaU8rEIY0oKayt\nY209kZpMfAyBdVcDIrVWvE6RofWkVKN8rNY4Z9johsu+I0ud9lj7BkToXT0oW3m/CO7vCYvoLvyD\nMFr/cg9RQePsN9Wv0Q+91F9cHGmt/p4IvyscjyUi77YfHipepdS5v6zIRdetsFzNcozRtMoxx4RR\nmq4VUtbMKeF1nceFkZKEKYzczBOtMxSr6bFcDD1pqnHyFJhzZugcWhueD2uuuo6QhJtpJOSE1YpP\nLq657Hu+Oh3ZTzNWay6Hnut+WFoJv2csorvw/4zWCv2tq6hf0uV3DrUe3v9tVLf/WB7aGQ8VsDUG\naywp1e06ZxXTWH/mkDMUwWnD0DacQsICQ+sJIeGK4Unb4qxmzplQMmPJiFE8cT1dozncJpz1eKvP\n88q13bLrWi67nrlU8W+dOydNG3ZtR2vt7/yxWvjNs4juwm+UXyUSD9Xlr7v/sXhob7z7AvDu5IPW\nGmMEVL3ERyDEyJQUxgrTGEm5xqbnkumahpUq7CcwVrNuG8aU6qREykSfQAzP1j3P1ysMik8P92Sg\nSObD7ZaPtztu55mb04kxBjrnuB6W6vb3mUV0Fx6F70rF9u4LwMP7IaR3DtoMRQreOUKIWGtxTaKo\nzLrtiC4RcmKKtSUxpsihBFw2nEpm17Rc9h3hNmOMZfCeKUamEEFpnvcDT9Zrxhi++Zk6a1hvt1wN\nA+47cCWw8NtlGfRbeO/49rSEtQZj6oGWSF1jHseZQp3AsEbTWI931fdg8J6ucYhSrBrP066nNw5d\nCnOJTKWw7jzP1wNP+xVZBAwUJazaho+2G36wu0CAQww01vF0tcIbswjue8BS6S6814j8ouIVOVfk\nCqwzKMB4U/0eyNydRpQ2cF5dvuhbdFLcjzPeCYfkEaD3liiJU4pMKfNsveLFas2cM7EUYq4eCx+t\nNzxdrZYlh/eMRXQX3muUgvywJiyCtQZBziY9hSLVYEeoFW7RZ+fgImQpvN4fyWScMRhnedEObNuO\n+znUmWSjQSkuhwGrFK+OB1LJrH3DRdctgvsesojuwnuPMaq6o6
lqsKOVZs6RXKrgojWtMeiVJuTq\nRjYlTRahbxwxGdrWMknmLkfmU2bXN3xvvcVozc04cppnWud4Ogw8W62Xg7L3mEV0F957lFJnO8na\nw81n60hvdJ00UIIxmtOYSTkRcsZozbbzKGdIqd5+1IGohNZbQir0zrFp22p8Yy2bpmG9ROe89yyi\nu/De865/g9YKQdM4943xzhwyc8ooDU4Z0IqYM04b5pzYh5k5R4w1fLzZsOs6Xh/23IcZozW7ruPZ\n0rtdOLOI7sJ7zbvbag9TDVpDemeu2HmDymC0x1pLKcL9NHGMkZwL3mqsadiHmcbWMbGrfsAbzUXX\n0zm3CO7CNyyiu/De8+0RMvOOhy+czdK1JeRf+Pz23uOMBqNZN54swlfHA3fzROccnXNLdbvwK1lc\nxhYW/oFMITKnanxujcYYzcvDgcE5lFKcYkArzfXQ441derfvN4u148LCb4IHQ58HS8ub8cTdPAPQ\nGMP1sMIu1e3CIroLC789Ys4ILCu8C++yiO7CwsLCI/JrRXe5DlpYWFh4RBbRXVhYWHhEFtFdWFhY\neEQW0V1YWFh4RBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4\nRBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhE\nFtFdWFhYeEQW0V1YWFh4RBbRXVhYWHhEFtFdWFhYeEQW0V1YWFh4ROz/5f5fGyO8sLCwsPCPZ6l0\nFxYWFh6RRXQXFhYWHpFFdBcWFhYekUV0FxYWFh6RRXQXFhYWHpFFdBcWFhYekf8NqyMgE3YvbLEA\nAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvVnMZVtd7v0b3WxX8663qW5X7QYQ\nUOCwP/X4fR/iufiUCzWgwYgK2EW8IeHKLiqJQY0GNeiVRsVAJAE0JhpNNBoU7KLGaDjm6AE3u6ld\ne1f7NqudzWi/i1nUkdCogJtmz19SSb1rrDnWqnrXfNacz3j+/yFSSoyMjIyMPDPIz/UbGBkZGXk2\nMYruyMjIyDPIKLojIyMjzyCj6I6MjIw8g4yiOzIyMvIMov+d8THaMDIyMvKfR3yygfFKd2RkZOQZ\nZBTdkZGRkWeQUXRHRkZGnkFG0R3h9a9/PRcvXmQ2m/H85z+ft7/97ffGPvCBDyClZDKZMJlMuHz5\nMq95zWv4+7//+8/hO/6v5cEHH+R973vf5+18I1/YjKI7wo/+6I/yxBNPsF6v+f3f/33e/OY38w//\n8A/3xi9dusR2u2Wz2fC3f/u3vPCFL+RrvuZr+NM//dPP4bseGfnCZBTdEV70oheR5zkAQgiEEDz6\n6KMf9zwhBJcvX+Ynf/InecMb3sCP/MiPfNI5/+qv/oqXvexl7O3tceXKFd75zncCsFqt+K7v+i6O\njo544IEH+Omf/mlijAC8853v5OUvfzk/+IM/yGKx4KGHHuKP/uiPAPit3/otvvIrv/JjXuMXf/EX\nedWrXvUJX//69eu86lWvYn9/n+c973n8+q//+r2x7/me7+HNb37zvZ8/8IEPcPnyZQC+8zu/kyef\nfJJXvvKVTCYTfu7nfo4nnngCIQS/9mu/xqVLl7h48SK/8Au/8GnPN/LsZhTdEQDe+MY3UlUVL3zh\nC7l48SLf8A3f8Cmf/+pXv5p//Md/ZLfbfdzY1atX+fqv/3re9KY3cefOHT74wQ/y8MMPA/CmN72J\n1WrFY489xp//+Z/zm7/5m7zjHe+4d+zf/d3f8YIXvIDj42N++Id/mO/7vu8jpcQrX/lKPvzhD/PI\nI4/ce+673/1uXvva137C9/ft3/7tXL58mevXr/M7v/M7/NiP/Rh/9md/9u/+P7zrXe/i/vvv5w/+\n4A/Ybrf88A//8L2x97///TzyyCP8yZ/8CW9961v/Q5bBp5pv5NnJKLojAPzyL/8ym82Gv/zLv+TV\nr371vSvfT8alS5dIKbFcLj9u7N3vfjdf93Vfx3d8x3dgjOHg4ICHH36YEALvfe97+dmf/Vmm0ykP\nPvggP/ADP8C73vWue8c+8MADfP/3fz9KKb77u7+bGzducOvWLaqq4pu+6Zt4z3veA8AjjzzChz70\noU94pXvt2jX++q//mre+9a0URcHDDz/MG97wBn7zN3/zM/o/+omf+AnquuYlL3kJ3/u933vvvYyM\n/GcYRXfkHkopXv7yl/PUU0/xK7/yK5/yuU8//TRCCPb29j5u7Nq1azz3uc/9uMePj49xzvHAAw/c\ne+yBBx7g6aefvvfzhQsX7v29qioAttstAK997WvvCd273/1uvvmbv/nec/4t169fZ39/n+l0+klf\n59PhypUrHzPf9evXP6P5Rp6djKI78nF47z+hp/tv+d3f/V2+/Mu/nLquP27sypUrn/D4w8NDjDFc\nvXr13mNPPvkk991333/ofb3iFa+4Z1e85z3v+aTWwqVLlzg9PWWz2XzC16nrmqZp7o3dvHnzY44X\n4hMXE127du1j5rt06dJnNN/Is5NRdJ/l3L
59m/e+971st1tCCPzxH/8x73nPe/jar/3aj3tuSomn\nn36at7zlLbz97W/nZ37mZz7hnK973et43/vex2//9m/jvefk5IQPfvCDKKV4zWtew4//+I+z2Wy4\nevUqb3vb23j961//H3qvxhi+9Vu/lR/6oR/i9PSUV7ziFZ/weVeuXOFlL3sZP/qjP0rXdfzTP/0T\nv/Ebv3HvdR5++GH+8A//kNPTU27evMkv/dIvfczx58+f57HHHvu4eX/qp36Kpmn453/+Z97xjnfw\nbd/2bZ/RfCPPUlJKn+rPyBc5t2/fTv/jf/yPNJ/P03Q6TS9+8YvTr/3ar90bf//735+EEKmu61RV\nVbp48WL6lm/5lvQ3f/M3n3Lev/iLv0hf9VVflabTabp8+XJ65zvfmVJK6fT0NL3uda9Lh4eH6fLl\ny+ktb3lLCiGklFJ6xzvekb76q7/6Y+YB0iOPPPIx8wLpjW9846d8/WvXrqVv/MZvTIvFIj3nOc9J\nv/Irv3JvrG3b9JrXvCZNp9P0kpe8JL3tbW9L9913373x3/u930tXrlxJ8/k8/fzP/3x6/PHHE5B+\n9Vd/NV28eDGdP38+vfWtb/205xt5VvBJdVWkT71dz9jwZuRZzxNPPMFDDz2Ecw6t/70eUSMjwNjw\nZmRkZOTzg1F0R0ZGRp5BRnthZGRk5LPPaC+MjIyMfD4wiu7IyMjIM8gouiMjIyPPIKPojoyMjDyD\njKI7MjIy8gwyiu7IyMjIM8hYXjPyWaV1lq1zAMzynFwNH7GU0tj4ZWSEMac78lkgpoQNns57VrYl\n1wYS+BQ4LCZ03mODRwnJLM8xShFTIqaEFAI5ivHIFx+f9EM9iu7Ip01KiZASp11DHwLLbkdCcL4e\n+thue0tKiVmeU2hNiBEfI3WWsbMWAIFgXhRoKXExklJCS4mSo/M18gXNKLojn11aZznrO7a2p4ue\ni/WMdddx2u4QQlDctRUIgkVdURlNjJFV2w/Nz6sCJSUhRkJMaKXovUcIgQT2ypKUEjYEAAqtRyEe\n+UJiFN2Rz5yYEjtnab3jrG1YFBWNdxx3OyYqw6fAv54c40NinufMioKjYoL3HoQg12r4QCWo84xM\nKVwIrLqOQhtmZY4QAus9keFTK4W4d8xeUdCHQH/XqqizDH1XiEfPeOTzjFF0Rz59Qoy4GFj2HT4G\nfEjcbrccljVGKv738S02XY9Wkggc5SVGqsEmQLHrO6TUVJlhXuakmOitR0pJpiUI0FKTaYWS4GNk\n1ffM85zCGIQQ9N4TY7on3iFGUoJZUbCzPT5GlJTM8sGqCDGSAHV3d+ORkWeYUXRHPj067zlpd7gQ\nudPsuDidIJFcWy/Z2h6BZOstpEh0CZUp5llO3wRcCtQ642hSkitNEoOHSxJsbIcSktwo9oqSTd8T\nXMQYhVISqSRaCpSQCAEuRna9ZVGVGK1IKdFaRwRyrTBK4WMkxkRlDJu7nrEUgkVZooS4Z1WMnvHI\nM8AouiP/cWJKrPuOxjlOuob9vERJya3dBp8iuVQ8uVqx6nuMkAghuDyZsrWOJ87uIIRinhVIJbiQ\nT9mrCgiJXe8QCiZFzl5WEBGkFIhx2EesC54kErXJyLXmpGkRgNYSpSVKChQSJQQg6IIjJpjmGUpJ\nUkpsOgsiMckGq8LdFVopBC5GICGFZFGWwPClAoNnrEchHvnsMYruyL+PCwEfI8u+JcYhRXB9t6FQ\nmonOeHR1yrX1ilxKEolLkzkpws3NBusdAjhrexZ5ztF0Qq4U9BIbeso8o1CaIjeURpMLQ+8dLkZk\ngllVUBYZgUjwAZE0SOijQyCos5xE4HTXQ0ooJSgyQ7ybdiCBkNB6j0RQGoOUghgTq64lN5o6G7aV\nt96TBMMxd62HRGK/rHAh0HqHFJLaGIxSw/joGY/85xhFd+ST89Gc7WnbQoKbzYZZljPPcq6ul9zc\nbZER2u
SptEYlRevdcMvuAk+tzpBCMSszMimZZSWEiEQQIkx1zqTImNUFzkesc/QuMi0MRmlKZUgK\nZEogBd4mjJFUVYZRis4OwhtTRN31c6WUZEJhg2fV9QAoKZkWGTvnkAgkgiSGvLASklzrewt1W2up\nM3PPM+7ccMWbSGQfzRGTWBQlO2fpvUcKybwo/s/4mDMe+eSMojvy8aSUWPUdrfectQ21yZjlBTe2\nG9Z9i0JwZntWuwYpJY7IharGWs9HTk/og2eW5VgS900mSCc47hq2XY8WktJoHtjbZ6/M8RZiSBgt\n0Rlk0lBkhkJrGttjbUBKqIsSrQbP1SePkmootPAebRRVnpNI2D4QYyIzCqkkISVEGvza1jnWticl\nyJViVhactg0iJbTQeBFRchBLIwd/uA+B3gcqYyiyIe7WOEcioYWkMObugmJkluds+v5uwiKxV5Tk\nWuNCIKaEknK0KkZG0R35P7gQaJ1jay2JRJ1l3N5usSEyyQw3dxseOzsbfvsycb6ekQvFtdWSk3ZL\npXPWfcskKzgoKxrbse4cfWuxLjKpDAf1hEobpsrgU8DahJQwK3IuzOcEEtumxQiBySSlLsiMRGuN\nksMVtHPDLX2Ra5SSBB/wMaCUHBbYJCgl0VLhQ8T5QIyRIjcYPTwmkiQmTyCx6S0uRqZZRpFpbm22\nEIezw2QKdVeAtRAkBC4GbAhDFZ0erm43bU9IkUlRkGk53CX4wCTL2FiLFBATLIqCXGv64PExYeRw\npT3yrGEU3ZEhivXRXKyWkrO2xafIYVVzvNvyryd38CkBgjIz1NJw0u4IKdI5x63tDhcSB2VJnilU\nFHTWcdLsCD6RZwotDZfnUzI067ajD5Z9U3JQTqkqTUxAHG7dRYIkYV4UzKuSpnOElJAC6jKnzMwQ\nE0MgBFjvCGGIhlVFAQl6Z3Eh3LMOpBR3I2iD6AafQAypiMxovE/46ElREMXw5ZOSoM4MCcGd7ZYY\nA1JqJoVBKInzfrAQIvQpIoDKZORG0QfPsukQAg6qGqOHuFrnHIU2uBgRAkJMLMqSTCm21hJTpNCG\n0pjP7Ydi5L+KUXSfzQw2Qj94mX1PEnBQVpy1Lbe2WxyBtnf0cbhld8EzLwtaF3js9JSTdsthXqO0\nZJ7nZElwa7fltBsW3GISPGdvj0mW8dRqza7rmMicaZkxKyr285xZXmB9ousdNnn2qpKDukQaRdP3\nKCJGaYSQIASzsiIzmm3b03uHVopZUZDnit4O/qsQghgSKUWkEhRZhvORtu9JMZFnGvnReJhIQ9ly\nAOcCUgqMkeSZobUW20eEBG0k3keSGKwJGwK3t1tiCCit2a9qPIFd24MYFhSNVggpMEKSa0PrLauu\nRwk4N52S62GenbPkUqGVGtIUITLPc5SUrLqOBFTGMMmy4d9299wcPeMvSEbRfTbiQmBnLa1z+BCZ\nlwWbvmfZdeRac9ruuLZaY4NDIDmoC4zUPLlccn23xoihgqxQinlW0TvHzXZD23Z0PrFfFhwWBS0Q\nvCd2kTPbIqVgL6+4OKvYK0pubbYkJ5hkmoOyZjYtEEgKrbF+aJSTCcXepGReFWz7ntZ5MqkoC01C\nEkOizjMiic5ZrA0YpZlVOUoJ+j6ASIi7uV5iQqpBVLve0TpLSlBl2d0ijoRzgZQgkgg+kYhMygwp\nJeumxYdASomqyAhpiJ9lStE4y2nb4mIk15pL0ylr7zjb7PAECp0zLbKhcWpKZEqz9Y6dGyJ2Fycz\nqsywc5Zt32O0ZmpylBB03jMrClJK9+yfUhv2igIY7lZg8K7HNMXnNaPoPltIKeHuNpbZ9j1GKXZ9\nT+Mc86Jg3Xd86M4d1n1PaQyFNkzywUY46yxdtCybDusdi7Km1po2OE62Leu+QybJvMgJRA7LGgU8\nfnrGpm8pyZiWOQ/tLXAx0FhPip65KsgyTVUYDqsZvfMcrxukhr0s42g+
oTI5LiUSHkEiJYVSgmmZ\nU+c5p7uWtuspTU5daiTDbbwSCk/A2oCLgVIb5nWFT4Gu9yBASXEv+oVIGKVoesuud0CiyjIKk2GD\npWk9yIhAIBJ4EmWWkVJk2fR0vUNoOKynBBGx3uFDonGOLjjaEJhkhguTKbfbHTc3WwKBRVFxYTrF\nh0AbA7mSbF1PvBtFu2+yxzTLOO62rPuO2uRcKKdIKem8ozYZISVa50AIMiXZLytg+HIdUhd6vCr+\n/GEU3WcDKSXWXY8Ngc45eu85nNS0znL1bKggCwlCGm61++golGHpO66dLTm1DTOdMzEZWil89Jw0\nHWddC0BIkQv1lEvVhEfXJ9zcbCmSQSE4X9eUWUZre6wPeJcIIrKX5Txn74BFkXGn7WldT6kU86wk\nN5qiyKhzw3bXsbJ2iJwVJRf351gfaPoerRWZlGhpECqRZ4Zcak53DV1nyXPDrCjQStL7SIqelASR\niHeRosiYFBmtdWzaHhg6neVZNkS/7pYQb9uOzsV740Wh2XSWs3WDUIJcm2GhTECmFJ31nHYtXXBk\nUnJpMccDp01D5zx9chQqo4+WwhgOq4onN0tuNltCCtxXTXlw74DGWU76HVoNfrQRmkTiYjXnoKh5\ncnfMum/Yy2runxyi5eAlZ1LiY8Lf9ZmVkBxUFQkGcQZKY8YkxeeGUXS/mLEhsOstvfO4GJkXOb33\nHO92IGDteo6bhk3fDx28ihIhI4+cnXJju0EpQSY1Ripqo/EhcnV3xrrtaENkYXLuq+d00XPc7LAh\n4FxAoyiMYZ4ZSq05aTpu7FbkQXKunnI4qSgyjfOJaIfbYiUki0nGcxaHtCFwZ7PFJc9hPuGgqjBG\nEQVoCdvW4mNEKMWiLDk/nXLWdOxsR6YUs7xA6kRKkkxJlFKcbHcQwRjNfl2Qkhi+bHyEJBAEQhIU\nRlPmGeu2Y9tZEkMUrC5yeueHeFjwtC4QUsCGyKKqKLThqeUZ62bwxvfKjDwvaL1FJNi4jrO2xxPR\nWvHQfI8+ea5uV2x9j5KCC+UUz1Apt1fkPL4546xriAQuVHNesn+RO82GG/0SrWGmCvb0hB7LUT7j\noKj5yPYpdrZnz0x5/t5lMmnY2g4hExKF/DebwhxU1dCs6G5pdJ1lY5Liv55RdL/YSCnhQrjXk6Aw\nms45lm3PLM9YOcujJ3c47RpqbShNRp1n3Gw33N5uWbuhSczWWRZFxTzLWNqWJ9anbFyPEYq5qdBC\nogSkFLndbFj1FqUUM5VzZTrHBc8Tm1N6G5kkTaYyZnXGUVnT+8BZ09LHwEwbrkz32S8ydiGytR06\nCDKlKLOMvbriqKpZ2Y6zpsUFz/3TPfYnFShoenu325gkEXEkztcTpnnBjfWaro8URrJXFuRZhvP+\nbj8Gz6brCSGRac3hpMYRWW0brHfkOkdJQAikSGTScLLbse2HW/9FXbCoajbWselaepuAODThSYFF\nXiI1fPj4mFXfIYTg4mTCpMw56Vq2tmcbLDYFIgGjNc+d7bMJLR9e3aINPbXJed7eOXwI9KmjVprr\n7YrOB4SIHBQzvnz/Ck/uznhyd51MS+4rDzmXLdiFljrLOchqnmxv4HykVDkv2nsepSo469ckApKM\n2nzUFx4SKCHGez0qJllGnWUA4wLeZ4dRdL9YSHetgY21+BDpnaN1nv26pA+eR09OWdp2qJRKEqUE\nO2+JMXHS77jerDhpGuosZ25yMqE5jVuOmy0r26GTIcnAROVcqGY83Z3x1HYJJDIyDnXNpChY2i07\n64gp0XaW0uQcZRMqozlxWza9wwjJvqw5X9UUWYYRkqVtaG0EAnu65LmLQ1QmubHeYb1jqhXzsqTK\nczKdMc8ybjVbmt4htOD+2R5Hk5rGedZ9S/JDBVkSw635flWTKcG1szW9sxQy4+J8hpSCJlg0AusD\nNvh7C2HnplN2veXOtsE5S6H1UIQh
Jd57hEgc73bsekcSgr2y5Pxkws3dlpvbDd47Mq2Y5DlbIqVU\nuBj5yPrOEA+TkQdn+xzUFY8uj1m5hk60FLJAKFBIrkzmbOKSJ7Y36GJgbgpeNH9gSJaE2ygEp84j\nQoaUgWlW89K9h3hkdZ0b/S1yDfeX93GlPM82dEiRmGcVd/pjYpBopXjx3guYqJrj/hQbHZkoOSzm\nAHTesShKbAw0drAmKmOY/5sFvASYcQHvP8ooul8M9M7RWEfvh56y+1WJC4Fbuy0uRNa2Z2u7u1cv\nw0kXCPzL6hY3dmuUkFQyQ2uBkgmXIk9uz1j2O3xKFMpwKZ9jsdzqTmjuBvvzWFAag5GCDMlp3LGx\nHT4KJrLiwFRoI9i6hk3XkYLGJEWZGx4o5khluNad0jtPLTMuZDMuTub0wpNcovEO7wPGaI6qmgdn\nC7bO8tTqDBfhMK84mgyecSKhJay6ntZFjFJcnk05mk44bVqONxtyqSlNhtESR2JiMrxP3NltsCGQ\nS8WVxQwpFcu2w/YeJKQIQSTKXLMoSo63W24shyY/dZ6zX9cEApuuo+08Z3bIOSOH3S8Oy4pHNqfc\n2m2wMTDPS85NKlahJ8TALvbcaNeE6Igq8eB0n3NFxT+vn2DjNhjpqNV88JpjZD+XRE652Z3RBcXU\nGF48fz429Vj/EWxK7PyEjAXISGFyXjR7iA+tnuTY3iZXksvFFZ43eYCtb9j5LXvZhJ1vIBqkSrxg\n9iXsZwuebm6ycy2lqrl/cgEhBl94nheDELthgTNTmoOqQjDYWomEkWrs2vbxjKL7hcpQ8eTpfaDz\nnspoXIicNi2lMaxty9XlKcddQ6YNM5VTFpqnd2tubtcsXQsIGmupy4y5zjnu1zy+u0Pje7RQTFVF\noTIcPV1s2IWWjY1IJKU27JsJQVpO7Cmt96iUoWPJVGfUuWRrG9be0seIipI9PeG8ntOIlmXY4Lyk\niAW1LKlLzaGesPYdJ32Di449XfGcyRGLouTMd3SdxYc0NJwxmr2i4mJZcafZcX27QQnJ+XLKuekU\nI2BjLQrJzloiw0LbhWrCXplzfbPjdLsjN5pFWVDoHJc8BkFjPadNgw2BQiuec3iAC5Fb6w29DxAg\nyxTKSFKEeZlzY73m6WaDc5GJ0VyZL2iC5+r6lDZ62hAxUpB0pNYZs6LkX9Y3OW7XJJWYZyXP3Vtw\nares7JYoWjbeIpEobblSzZjpgkd2j5JSQ6V6anWI0hN637IwHUadcdwl1jGnlpEX772QLvUY/z85\nC4YuHlCq+0kyoITgS+r7+Jf1dZZ2NVgT2SW+bO/5nDZrTvwJ06xGC0MuSvroeM70MueLc1zdXWfV\nNUzMhOfN7kNLSec9hdKEu/06QKCk4LCqEQxCHRm6tmUfTYw8OxlF9wuNjzZU2Xb9vYY0u97dsxEe\nOT7mtGvQSqOiJM8ly77Fp8Stds3Nds2pbaiVZi+foAXc7M84tit2vkdFhRSKItNMM83Srji2K2yK\naCGZiQmlyghyzSbuCAE6b9BCMdMlGZqWDW3c4aNExowiTsm1RElP6zu6CD5KCpVzqOfs6YLbYUhR\nkBR7YsJU15RGUUjFWdex8ZYgEufNlOfNjsiM4undirb3aKE4zAqqqqSUkllW8PR6zWnXooXkfD3j\n/vl8KE5oO4L3CBRSSnSmOVeU1EXBY8fHbLueItMcFPWQDXaOFBKbrmfdWUIKVGXGffWE1geeWC3p\nbQ9SsV9VCC3xISCAG82ak64jikShDc9d7HNqd3xoeYcuWKIMzLMKqSNSQCbhyf6YnevRxrLISx6o\njzh1t2jCGUZ0dEFjhCHXHRcKRaEKnmpuIGmZ65ZaHWDlEZ1fcaiXFGLHDTthGWtqGXjh7AouWkz4\nF264ChvnzIoX3u3KtuNSccSH1ht2vkNJyYXiiP82/1JutkuudzeZ6Jq53mPP7NFHy0Ex5/7qIo9u\n
n2bZNdS65MsWV8hVRnfXgkkp3WvTmVLiqK6HVIjtCTFRak111zd+FjCK7hcSjbU01mPv7qJ7UA8t\nB2+uN3TBsfYO6zwbZ7HRcVhU9Dj+5+kNbrdrFIpCGyplcMLThYYb7ZKNbXAkMqk5V1QkmTixN2hT\njyQSYkFGQWESWvT0cUcbPT6JoVCBColCm1O66HBB0tgSIzImyqBFZB1bbHIQFYaCadpHqIgXW9oQ\n8D4jJUOtCi5kU2TUPOVP6J0nE4ZDveAon+BlwneebRyKJ4yWHJUznlcv6GPk8fUpMQQqlXNltiAz\nEiUU0sFxt2Pte4xUXKpnXKnnHPctdzYblJSopKhzjTKSiclRSfD4esmm6ym14b7pnFmecdq37BrL\nxll6b0ErCiM5KCtWtuOx5ZI+BLQUnL9rf9zpdvQxcKtf0/uA1JBnivvrGXfslid2t4nJIk3iMJuR\n6UgXW3JlWbktfQzUqmev1JzPjlj7p0isKEVLE0syOSEXO47yDkXOtbZD4TgwHftZyVm6QmNPuGBO\nKQVccwtuhwm18jyvPgAsfX+Np12NEjkH1UtxSWLSbfazKf+6FfReIJVmkU34v/dfwtXNKdfaGxQ6\n53x2jivVOba+J1OKhyYXuLo7pXHD1fqXLi4y1QWdt3R+KCbRUg39NHy4Vwq97IZdSEo9xP0+umj3\nRdRCcxTdz3d8jHRuENo+eGqT4ULkrG3IjGZlW55erjhuWnKtmJmCIlM80ZxyfbNi7VoQkhCHblyT\nLOOsX/NEc5sm9qgUmWcVRuRENvRpjaWjDRCTolSSqdRE6fGc4lKkEAHraxIls7xF0GGjpQkGnzSV\ngpKaNjhyvcQCMWi2/ZxMZtQ6EGJilyx9kCAUE1kw44iOLV3YYqMkhQIdS2Z5wdxU9H3gelwRbaJU\nJZeyOYusZht6mt7ShwRJUGrDuWrKuaxmaTtutRsIib285MpknywbGqOHLrC0Pb1PFEZyvppxqap4\ncrfmzmaHuttrYr+sCBJUAO8j13drbPCYzHDfdMbEGK6uzzjdNrQEFIqyHEqNjdKctlue3K7wIqI1\nXKpnVLnmie0JTehpUodAUmcKrSL7Zc6yX7J0G4y25Dqyly3IZE8X10xUy9YnXICJtuznkZnZo/E3\nELGhUg6fNEocotlwkG2QSfNYXyITHOieKyXcCA+ydLe5pE9RGI7DgqfdgozAg5MSTc9Jf8pNN8EI\nw/31S+lTAfEJJqrmybbCxRIpJBNT8P8evJAn1kuutbeRMuNSsc8L5/excxabPA9Uh9zptvQhIoXg\noek++0VN7/29HUMqbYbNSEOgMppZXrDuOmIELQXTIv9C94lH0f18ZdgNN7Lu+6GBi3dses/RpKJ1\nlkdOTjhpmyHkjiDPNGeuwQXP9e2aW3bLWd8wMxmVzsiV4Gp/h7N+iU0WLTSgKQ2UJrD1G9Z+hQcK\nEShkhpZTcnGbKLZo6Vm5Ch+5vq5gAAAgAElEQVQ1tZHUMuDpSKInJEEtHL2f0MeK/eyMJBwuCU5s\nTUIxUQkVS2zqkarFRYVMip3dJxM5hepoo2cXBH3QaKHZ1wVZ3GcZz7CpwydNFieUYkKtFbkwrHvH\niWuQKGY654HqkFLkHLsN675HBYURillRMjM5c11w0u240WwxQnBUTbl/MkdGwXHfEGOi6yxRQmYE\n+1XNQuY8vl5x1jUIMSxcHU0rOhdpnafvPXfslqSgyAwH5YRSKj6yOuXY7iBFjFYsigqLowmWJvQc\n2wZUoNCCg3JKLuCGu4PHIoSDpJkXCiM7SpNofEdrLZXuqI1jlk/R9NiwZq4bViGjdxm1chwWHROV\ns3INEkchPEZGtLpIClsmevidfrjZRybNTDmeUzluxwe52d3hvDkbfj9pn1thiK1dKiK5EDzVNZy4\nCUYqXrT3ZXRpTuMeJxMZp/0ELeYIDJlS/Pej5/DUZsPVzQlGaM
6Vc77i4AqN92x9y6Vqn03fD03r\ngaOq5qisCSmx6TpKZSizoWm8D+HujtHl5/js/IwYRffzjY9Wj3Xe0zlHiImDuqT3gadXS9oQ6L2l\n95E+BWxwLIqKre/4p+V1bjUbDJoy10ykoadnHXbc6c/Y+Z4+RmqlOLi7G8My3MTHjlL3+JiRUsXE\nOAq1JdHSxkSfDAvdUwhF40sOshvDqri03OrndLFkbhwlHZaEjQKPYqFbGj+jjQUH6hgvwUbF7W5G\nJGeuIjEZOiw+BUgSjca5wyEfy5Y2StogsT4jVzlHJifYCSdxiY0OMEypWegFQg654U3r2HmHloZF\nVvLQZJ8YEjebLVtnKYViogsOyhKBQoXEsrWcuRajFEdlyf3zBba33GiGuFgSktwYslwxURk6SR5d\nnbKyFpNgWmYcTaZsveW0a2itZ+t6tBaUeUaZGQRwdXvCyvXILFKbnKO8Yht3LMOOEDxd8BgFZeaH\n7YaEZBNuI4VDy4BIGXtGIsWWSdbRuMTOKkrjWJgdiyLDukSkYa4b1i6nSyWlcCyyjj0dudFrXJSU\n0jLVFqPO43yHUTuMcDzS7ZNCRq0sD1Q9u/QAjzcrarnGCIGQe2zj/aydZ5F3ZLLkqcbR+AkCwVcc\nPIcojrjZPYWICRcnLNQhSiiEErx47wKnjeXq+gwlFIu85L+fu0zrPWvbc1BUqCSRAnKlKY1mVpTI\nuxuR7tfVF3JWeBTdzxdsCEOdvnOEGJlkOa0btjQvteHUdtxYn3Gn6ahzw1TlFLnhsc0x17dr1n2L\n1sM2NFmmyLTkpFtytT2mCR1GBOZ5QS5LfFrTc0oSDdZLLJqpTsxNxMaAZEkQcD5b08eclZtxqLdU\nukEry4kraWPOBbOhkokTW7PQS5xQLNSWW36PtZtwaIaTuE2SpSsJSM5nOzZ2Ths0+9kxPQYbNLf6\nOVAyU44+aHYxYONw4lVS4+05Ah7Hjt4rbMwIMadWGUdZwW6nuBO3d4sAMg7MhIWe4aKn955tH/Ep\nUKqMvaLkcjFh2zluN1t676lUzn5VcVBX7DpHcI5172m9I9eag6rk4mTGpm25tl3hQkLrwfvN86Gh\nurOBp5o1jbUordirC/ayktN+y812h2PIAE8yMxyjwAfPnX6Djw5lEpNMssgnNOGMJu2QItA7QaGg\nNo66ABkTbdihpSWTESUkE23QYsU027LuS9Y2p9KWhWk5VwY21uCSo1YdfdS4UGBkpNKWQ9PxRD9l\nZzWl9OxnDYU+Tx92BDokiet2TggVRjkuFj1CXOZ/bS2SFoOiNlNK9SDXu264MtdTbrUREUsigv/r\n8D5m+gKPbJ4eYoCUXKmOyJQhhcTF6ZTkBU9t1lRCkSvDV9x3meQTrbNMinyIJyrFoiq/kP3dUXQ/\nl6SU7kZsAqt26GW7uXuVezSp2FrLv96+w6rvKLMMKaAsMpZti/OOa82KY9uw7nqmZcYkUxAT17pj\nTtxgIxRKEZNiohW56Vj7NTu3xSJZZB21Ap8mFPIUrQa/cOMKNrFm33Qc6JZdHMQ8Sngov4NLOTf7\nfS6aJUoGjPRcd1N2IedytqQQgTtugpGRmBIXzJIb/R5nYcaFbE0k0kbDbVsTkVzMdmz6GdsoKeUW\nLzQhKU66PRIlE9mx84pdknRBoYRkrgzCH7ANHR0t3mlIGXksyU3OREg2XWLpe3yEmc44X0yYqwnL\nrqX1Fu+Grdind9tMTpVhvW251bckJFNtOJpOqbTiuGlxvRs6swmYaMOkypkZw0nTcXO3IQiJUYJF\nWZBJySpYGm8561usCxS5ZlrlGCk5bbecxh0pASowy3IyBU5ZfPLs+hZBosgjhRFMtcHHNU50aBHp\nvaIUkolpqXI3dB9zkAuLUZFpFtEpR6cz6rxnaSuWXUGlPHPTcanecdzXtFFSqQ4fJFKaYTOjZDnM\nG652e6xdjk6RvbznsNhn2e
/oo8cLydpWKLEYSqUzz8xc4J9Wdmi5KTSLYsaD1UM8vt3Sh4ZJNmfT\nRyaqwgd4/uKAK9V5PnJ2jA2eSubcP93jXDXDu0SdSRamYtX3TPIcJQUXZ1NmZfG5PnU/E0bR/VwR\nUxpKYYNn3fUYJTmoarZ9z/XVGhsDffT0faAXnpgiM52xdpb/tbzJnW7YGDI3monK6HCc2iW3uyVt\nsHQxMTOK/VLRx8SZu4WjZ5btENHgUsFUB0p9BrKn9dCmjAeLMyoZOHEzDtQpUiUumBXHvuLEzTmX\nbVjIHdtYsYkFLsGXlnfwUfGkXXBJb4lCUMiep+yMVSi4kp2R4bnlZvRRo1TgPrPktp1z7KYcmiUO\nwS7m3O6nJDQXsx3rfsIySKQIIAQCya7fx6cMI1q2TtFEifeaTGbsZwbhJpzZjjZaiIpC5tRUlEIj\nE2w7zzZ5iLCXFZwvp+RRs7IdW9sjo0QryUE5RNZEhGXTs/EdMiqq3HBUT5HJc9x1NJ2lixGpYC8v\nMIUixciqs5y0OyKBLM/ZLyu0FNzq12xCj/OOGIcy26IS+OTYOUcbe4QMGA21MmQq4eQWgaV1ApkS\npUnUpacUEh8boggoGQhBkQtJqTumRU8KgeOuxAhHrmDvbrIBNhTasXEVS1tQkKhMx/3Vmpt2ytJl\n5MKCkEy1wIVsKOrIGm7aKZu+IAlJpSL3TRactB2b0NGTE7xhnp3HxSEieJQt+Nf1cNegRM65suL/\nOfoSPnS2vNuspyJExQUzhSg5mla8YH6Os6bBxcSlekqdZ5yb1GiliCT2q+pzffp+Joyi+0zTuuFq\nadv3CCmZZBmrrmPddkzyfAj6r9asXc9enlNog1GSRzen3NltWPcdWabx3lPnGUkljpstTzS3aZPD\nSM+8yNFkOHbs0hmJBu8TDsMi8ywyzy5EUtqSROLB6g4+as7cnAO1pc535Mlx6mvWseDLyltMZM91\nu0+pPDHB84sld0LBdTvjotlQiZ5tnHDbV7go+G/1LUKEx+2C87qhT4qZ2nGtn3ISplzJTyBFbro9\nNi7DKM+lfMOJnXJsJ8yzDTZompBxZifEZDjMLCubs/QaH4edJLQwCL9H5wVCdDRO0wdFipJaFSwy\ng2s1S9tjUxiKPnTJoZ6AT3TJ09uADQmD4KAomRb5EDFrG6wPaDH03z2sKkDQ9pbWWTbOoeTQ0nJW\nFPQ+cKvZ0PqeJBRGKmZ5jjCw8y3L1tGEHiMEea4piwxInNoNbQpAQqTE1BhM5nDKYZ3DRY+RicyE\noTBERILcIpWlaXNEipQaplVHKR2tT0RAkRAxUiiDlpYya9Ai8tRuTkbEmMhRsUOmoX2lUJ42FHQu\nQ6ExquN8uWVlJ9zqC4wIRCRHBfRhxia0VNqyshWtK4jRILXgcj1j7eCk2xBTAcnwYHUBS4b1lnP5\nlDuNY9M5SkoWZcX/d/l5XF2tONk1XKonGGV4/uKQaZajlOKgLsnU0GVtFN2Rf5eYEiFGWudY9R1G\nKo53O6SAw3rCqm35l9vHdMFRGQUoJrlh7Vr63vPUds2pb1h1HQdlSZVnxOh5bHPKid/SpZZCDH0G\nZsaQVMvKD5VNHsl+sWOiJT6VaLFGiA2LbEfrNOtYcTHfsm9a1k5ikyQiebh6iiAVT7ULjsyWTHhq\nBTdczTKUvLg8ZqZaHusPCEkhSLyoPOMsaJ60M+4zG6SINDHnWj/FJ3hpfZMQ4dFuQa06vJAs9I4b\n3YzjMOdCtiREwU07Ze0LlEicz3es+inHrqLUHX2Q2GjYuYIUKirjaKxm5SQ+CJSQFNqQxym7PtJH\nSwgCkiFDU6uMWuf4zrPsLJHAxBTMTMm+KbDWs4uOzg6VZ7nRHBYFCYnzQ066dx5jDLOyYD/L2VnP\n0jVYl+iCo9Sa2aRAqqGf7p3dDhs9QgiyzDAzBT2Os7jFuUifoJCRIpeYHFyALjbYFFAiopOk
znKk\n3kHW4XtwLmEUaB2Y5gJJJIoWoyy7pkAAuYFp0TA1LWddgQsSJQVaOCaau7tbWLSyXN/tIUJCSsFh\n1VEKwdIlAsOXdQoaIyo8HfO8oQ0Ft9sCoiQJyYVKocQ+N9oNiID1GSkWGCZE4EJV4GPJ9e0KkxSa\njJceXSGj5KxrWBQVWVI4HzhfzZkYw0svXsL7xNb2TIqcTCou780ov7ALKUbR/a/GhsBxs8WHxJ3d\njoOqYprnLNuWp9ZriInWO1xM2OjIpcYoxdp1/O+T22ysxWhBZTJKqbBEbvQrjps1jXf0KTDPNPNa\n0YX/n703a5Iry64zv73PcAd3jwlAZlZlTSyRErtEUU3TU///lzaj6aHFboks1ZxZSAyBGHy40xl2\nP5woUbTupplEtWhVlccMFrDwCDiA8Lt83bXXXqvwzfSRxTbGOKMI0LP3heCP5Lqy1MpWA/9if8+g\nxlMa6eSC18yPhgeOpeN+u+J78Zkbf+YxjzyVHQXlf9t/A1L5+fKKW3cBHHcu8zaN3OeRnwyP7HXm\n59stz6kjaOWvdvcci/KL+Yo34YghTNXz9XpFwvGvx3cUE34+v8IQoHIbZh7yjo/rNTfxQqrCp+3A\naetAhNtuY1oHHpYOdZlclWoOywO1DjhJTMlYtgAGXgM75xmt43RJzGSkKIM6di4yug4nyrZtTGul\nUrnuew6hp1PPvKzMZSO3FEqu+r7lNtTCaduYlgQUdt1A3zu8a1U/D9u5RUeKo3eOXd+T3MYxr5yW\nlSyV6KALjuACSRaeUgYr1FwYQkffges2csnMaUM1ExRUlUEc3k1It7GukDePVyGGxNWYqKU1Jwe3\nMS1t0cCLsu8u3AwXPkwHlqQvYe6J22gs1WGy4V3hftojRTA8V+PKlTc+LspWhGSKF8fe3XKxGe9W\nqkUeFw8lUHF8vht4Fe/46fMzJVdUIx2Bz+It01a4HgM3/pqPpzPXoWeIHT+5+4wfX9/xtCxE53nd\n74jOc7Pr8SIMMXLou3++C/qffv4/QffbUM1/wmlpXytrLjzME4fYLENxddxPE2vOfLycefv8zGaZ\nu25H8I59HPjN8Zn7+cJxWwii9Bp4NXRkNd6djvz89MBWMi4aN0NPdMqcV359/kSqM2s1MEek59DD\nZUs8bTM1Vb6/m+hR5uxY8g4JZ/bumVPpmLaR73Qnbt3M5CJPJXAqN/wvu4/c5IW36y33eWAnCztW\nfrXctSHY4T1/2j2RzTiWykPa84N45D0D79KexxQZZEak8nfza4IU/nL/jmu38PPllscUqQjXYebd\ncmC2keuwcuVmPrLj47LDSeU6THRS+LTumZOykUGVOXkgMvrKSuG8rhStiDmcL4w6YIswpcJzvmCl\nNT/snGPX9Wg2Pp0vrJbxVdn3PQfp6MQz5ZX75UhJLc7wthvZj56UK/eXia2sSDWiixzGAWfGZSsc\ny4ltzogXruJAHBypFJ6mE0dZEBOiCHd+QL1n0gvPZWHbWk1Q7zv6EaC2pRhLaK5QBad7YixIt5Lq\nxml26BLxqgRXCeJx0iSGdVWWdaRze0JY2A2JnJWnreOSlal0eATJHpNKYeZp61i2ARPofeJuKByT\n47I6llx4WjtcCWQCFjeQM9PqWHJPMUenyhfDHQ/bzMO88LTcs25CtKFVM+07vrPb8XfrE18/nzh1\ncOMi/+bNFxzXxPvLiaiemo0vr3qKVaJXxhBo4Zl/uHzvW9D97zi/kxEel5nzthFUuZ+nlr+qSq6V\nv/vwHqeO6JRDN9DHVk74OF34sMw8TBNz2nhz2DF2rcfrl0+PPKaJc13ofaD3nusussjM++nM/XKi\nYOwH464PiAW2MvN+OtK7BadGrZFURnY+sbHxsCou7fizwwd2YeNhqzxsIxf1XLuZh23gqQxcumeu\ndeaTjvx63hF05C/HD1yHma/WGz5skU6EQTZ+Nd9S8HwZ
Z74fjhyL55tNqXbgB90DBz/w9XLNY+qJ\nbJQq/Dx9hqfyp+MHfjRmfjHdcb+NVFN2fuOchFMa6FzFayZVuKQBURhCpdPK02wcDaplxAlaFakB\nL8acVubNUSp4HJ0Xduopm/FwmkgvWm4nyiH0RPEsNfN0njAcXlo1UNTGwt8dJ7ZtI5giKlzvdsTg\nuZTEZV5IueC9cDvsCUHQSgtYt4W0Zpxz3A17JLR0sTktLLUiVRico/cQnHKsG2st2JYx8fTRE0bD\ncmauhXwBbwoFvOsIVIgTuW4cp4irHvVKFyuRSkU5ZSWtji0F1PaEsLE/LGyp8rgOnNbIZp5OK8qI\n1TNT3Tglx7L1gOK1cjcmTskzJSEVuCRHtA6pkUzmMR85b455c2DK4B0/urrl07LycbpwyQVJ8J3+\nluu4wzlBvFIX45Iyz5eNwxj48vqK07JxXJZWCOodn+/3/9yX+f9v51t54b/xXLaN99OZXAofpgs/\nvL6lc463xyPvTie8Oi5pJeUCCLsYsVqZc+Znj4/M24p64Tr2FAORyodt4tP5zMM8g4PrPnAzjpzS\nylfHTzyVCQ2FTttAaR+UhTOXsrDkjJnxxWGm9441O7KtdNo0OauVqQx8tz8x+Jnj1nPKHZ7CX1y9\nxavxzXJFLwXnCrd+4mEbeM57/u3+Awc38bfTGx7TwOAq/3b3DcmEX2933LgVEWMp8PV6TSLw7/Zf\n0enG31w+ZzOPmvFF/8xSO94uV9zFCbHCb9c7ntKIUPl8vCAGX01XVJRalSJCzcp5G/C+gMGneaAW\nBwjqFF+FafFsFWo11DzOPLF4wFqVzSIoSkQ49J6onnVJlFzJZowaiKHJEuqUOW1My4piXIWBXT9Q\nLJNrZS6FkjPqHdHg0A9UB8eamC8zJpVOPfthJFtLXTtpqy+SCqJwHTqIxlNNpFyxUnEiuJfNOBBW\nf6FYxlawqhxixA8z1TZSafX1ClCNIAGvhTBcMBUuZw8meA9dZ3RUlmJISOQU2JIjakBcZj/O1Op4\nnjyqkM3ROcNLR2YmhswlRdY14OhBCtdjYcsDx3V76alTdm6kk6EN3HpPWaBW5RAHvCl//upznATe\nX07cxB27LvIXt5/xut9xSYk348gYImbCVd+hKnx+2BN/v9stvtV0/ykn18pxXVhK4nFeuOl6DPjV\n8yOd8+xD5O3xxK8fHwnOMQTH4COdd7w7nvgwXZjSRieepRa+d31FovL2+ZlfPD82H69UvtjvMIwl\nZ94uzxzLwpZXggZ2I+xC4Hm58JQ2MjN9zAweqJ7eZ3AXihVSEajw/asj0VXOKbAVGHXj9XhBrfCc\nRr4bn4ku8bSNfEo7IoX/9eZrolR+tdygJm0IGM6cS+AxHfjz4Z6dn/m/Lp/zPu3ZSeIvrt6iwK+X\nG7wUVIStKg95YMmRP9u/J0jmP12+4JxHIonrbkWl8n46EH2zit3PO6YcKOYYu4ya8TAPLEURa9tv\nDli3SKqKUdjmAC1LDE9EFba5UjaHVnDeMagQCqy1ZeZahiBNY70KERNj2RI1Ny/A3nvUu7asITCl\nxJoXorR69LHruLAxrWurIqISnWeIDpV2p3OUTN5aDdDgHWPoWVxizpmtVko1nIEqDF5IrjJphlyx\nrRL6iDeQWKFmsluoUilLY/KjC2g3Yb6wZsGS4Z1gFaK2n5t2zQN8miKgOHV03cbgCqfkafAtpO3v\nw5GGoWV4PF8CIo5UHLtgDK7nWFagUKqnJE8nI4nKrq84Op7nRJSAE3gz7vn++IavTo+oV3Zu4CZ0\n/Hj/imyV62Hgu+M1p3XlTT+yjx1fXO25GVu4Ux8CYwz/fBf9P/18C7r/Ped3uQi/PR9JtVXjvD09\n86PrOwYX+MXDA3/76QO3w44lbbweBryP1Jz55nJhXhYeloVcjc/2O/YxcpxnPswTT/PEJW90ISJS\nuekHLiXxbjryzfmE
UdHeeB0HKpUlJ+7ziSQZLBHFcRiM4ITztrDkRHSF3bAx+kKpAW8b3q8gULJg\npvzg6oFOE5+WkSVHBrfx+XgiSuJx3fFZd8FJ5jHtuF9HvBh/ef1bes38ar5ltYgD7sKJYsLDduDL\n/oFeEv95/oz7bU/nEj/e3dNp4ev5hmQCphTa0sOcO14NFzpN/PL8ilPqcBhdrHSaeZxHNnOoGKc1\nUrKjmqeo4aQyXwJbca1q3RxeoCSom4IKrAIS6HB4YKsVsqBFcbRp/xA8pTQLWa2GVqEXiKHJOhuV\ntVZqKnS+rVp7heJgxUjFyOtG8HC92+FMOJI4rTNmRtBWCBk84BypZi62ta44MXqn7GLkKBtbgVIy\nJuBra2bwalisLLpgG1CMzgWcL+hLkNBmieBrY/Mv210uTBAT6+YpuWXdWhb6WDBzmFuRYEznAOJx\n4gh+Zegyxy1QKy04KTt2Glip4DeCU06ToNZhCN4Jr/uBD9OGScZrQIvny90155TRUHnd7ZmXzM73\n7Hzk1bjnJ28+5+tPz6wl89l4xath4N989gVzThSDN7sBr45Xu/FbpvvHdMyMD9OFh2XikjaWkvjR\n1R3FjJ89fGRaE6MP3J8nRCB4JajHDNYt8fXzM5ctEYPnbhjYcmYfI9+cT9xPM/fTxOgdfR95PQ48\nLTNvj0c+rBPFFXqn9L5j1ytPy4X3c4tFxDKxh+suUGvlnBeqrCCF6DKK5zAkRCrrltmyp/cbV7uV\na78wpYCn4NQwAWpjLd/fP9K5jQ/TFc+pZ3Qbn49nDrrwaRs5uAbcx9zzlEcw+FdX7xld4pfnWy4v\nq6aHuBLJfNoGbuKMk8LX0x2PacRb5fXuws6tfJj3XEpERMimiBlz7lBXCVK5n3YsqWUYmIeolWl2\nTDngFLZVQRQtQi4tico2qFtrMFBz7SVfaNW9FVyW1ir8ch2nDJghyZr+GxSctt45MyhCFBicR7zD\nibBZZbOCVCWqcOMDycFixlpabkPOiV4d+77HRDjZwlITliveKb14TCsbDThXAVeBWulUCF45OaNS\nqNlanU9W1BmKULuNLBuWjbpBpz3OJ4jtc2s1nDdsVdSE6AISFly3MW2RsjqCc1QrRN8S5orbcK5w\nmQJqAScBiQt9V5iWBsSmDcT3sa1zV2aCV+ZN6a0jek+VzJe7A0/LSiqV0XccYsdfvfkel7Vw3Cbu\n+gO983z/cMNd16MIX15fs+VW1XnT90Tv+fL66tuUsT+Gs+TMJW1ctpWnbeGmG7ikjV88PfDl/gAV\nfvHwyNfHI5/vdswl8939AanG87ry60+PGCBOqMX4k1e3lGL84uM9Hy4TiJEo/OD6hlKNVDK/OR45\n5ZnjurGLgXGI7KPn3XTmcZp4zgvioetgxCO+cCmFS56xmnAC3U7YeyOlSrYNtAHrPi7UEtj3G1Yr\nWxaW7Oh84mbYuIkzl82TKyiC80a0zFo8392dCJJ4O93wtA2MfuOz3YUbvfCUB8wMRLiUniV7ahW+\ns3tmH1Z+c7rlKe+Ikum7xF5XntcBcU2L/LSMXF7Ybewzg0scp45j6sGELIIXoxTHYopgLJMnF48C\nFYdJpWYhzw4nIAXQpiFb0mZLy6DF4U1xvskRWPtaMWnPr4p1Qt4qTsGqIBn2vdIF92Lxa8M7L4p3\nEY+AN5zoS6VRIYgjqLKPgSLCsW5YyS9lmsYQmm5aRTjbSsGaX9Zg8JFsG3M1EGMTcMXQWnHO0XWO\ni26kUiAbqkqUSJUmR2RNaJexDGUVQuhREtqvWDGW5PAR6ia4Cl4cJWS029i2QNmUIJFkldAVqIHN\nEuogJYdWh3M9lZluMJZVKKl1wpVaOXSeIJFjWum9ItXzWX/gtht5XFde7Ts0OZx4Xg8DY+z4yas3\nRPU8Tgs3/ch1H/lif+AwdKRSueq73/fG4m9B9x87Zs1D+8vjI4LwcbqQSuZPb1+zls
L/8e4tD9PM\nbez5lBf+9OoWexlofPX0TC2V+/OFIXhe7w+MwfPxdKHUyrvzmZQLwbUUqn0I5Fr4+njim+cTpkYI\nLdzDFNa08tXpmVPeWGrmKniGMdI75f104rxlsm44B0NfGOip2nb/jQ0FvBr9aAxa2LKRV3CxoM54\n1c9syTHEtp66pMBWhN4VDruFO3/huHZMNbwMeIy9LizZczvMeCrfLAee15HOJ26HhVfhzNPaMZdA\nFWWtHi2QquMwLPQucX/Z87DtCFLQYFz5lcvmWSViFZYtsmZFEExAfSWtymXpaPZ+BQdkIRVHrYJu\nBvnv2ZCJw5JQM8TqEFUcBk6oibZivLW8VufsZXgEooImIxiIhyBK9ZArODNUlGhC7DtUIRlsJVFL\nJYowdhEpSg3GUgpZjJorXXCN3XrPVhKnkqkp4bzHi9D55scuubJqoeaKSEUQRg0sUtikQoWqFVcF\nK4Z6hzdY+o1SKqwF5yJRA4WVaoVERX1LdCuz0ElEQ6WGdteyror3AllAjKCRpebGmGugbkKnA6sk\nQqxYVZZsDBG24vHmuO661hwdarvTS8JnwxXZaluj7g58uFwYXWRwgT+5ueXH16/5xeMDUT3Xfcfn\nw54vr66wagxd5HZoW2i3u+H3ve7nW9D9fzupFt6ej5zTxofpwpth5KrreV4X/vbTR65DZEuVb6YT\nvfjmIRTD41iWxKfpzDhA85oAACAASURBVHlLjKFj7DwlV16PO96fj3x4vvC8zuy6njEEvnu44n6+\n8Oky883zM84JSYw348gQA6dl5eePH7mkjaSVsYvcjj2pFD7NJ445s9aEKeyjZ+yhlsJTml4M6QUf\nlX2fcOZbUpdVrBa8VDo1Ql8ZfWJZlXlzhFBx3njVn8gp4H0moaxrIFelk8J+t3IdZo5bxylFBEW1\nctvNrFug8wkEHueeU+oJrjJ0G3dx4rJFnlNHJVANBIMqjfFqZVoDz+uIYqDQh0zNwnkLZBSSstXm\ndd1MKGZQFFvaxSgGiFAqSIZSHN4El2kAbU3Hxhq4dCpNB1Ze1jPa9zna1hcqmBQEbV9UoQ8KUgkI\nqxrVpNX/VKUTxalSvGHq2fJGhbaMESOYMdXMlBOiipVC1NbQHJxns8KFhKzlxS8bCSrMdaMA6aX5\nWc1wHvqiLK6weBry05qQ62aYA6lKHVaqVdZLJbqBUT2LLFTJbAV8MDQraTU637d/ZpgxJ2wLOPGo\nc+Raib4t4qjLiHpyMvYukgE89CrMW7N4RZToAz/c3fB+Ob0k6PXsQ+DPr75oTSMGb/oRMbjtBw5+\n4Grs+OH1LZ8uF1SEfdd65b53c/0/HQ/+B59vQfe/Ppe0kWvl3XSi1NYW+/OnT+Ra+XJ/xXFe+ffv\nvybnwqvdnnNK/KvrW5a1ckwzP73/xFWInLbEZ7uB7x1uWFLib+8/sq0JMWUT4wc31xhGLpW3TyfO\n68zTvPDqMDB0Pbvo+ObpzDEtfDyfkKCoCjfDCNo60b4+PbNWEFfbLW/XbESPy8xWEuYy6oSdU0Ln\nSCWxpQUrgg8ZF5SrIWFZSZbI1SNWCVoYfEWD0buNefOsOSBihFh5M5xJqWmnc9XGLosQtdAPhb2f\nuMw9T1vz0jpnXMWZnB3VGdWUaY3MKeKk4ALs48q2Cc/rQMaDgfMFyUoyKKKUpCyLRwVMhOoKFCVt\nCtk1pDRFrLb/2ypQFd0EVwV1DagKDjFFsyKmRCqIUJ2QiuCtgW0UQ0Jj0tkEtYqvjd120VEwCGDV\n2jDOK9FFnNSmRZcCqjhROu8YEfChsUZp1eZSjT507LpAToWnPJO1IsWhVujEU/zLnp40oBZrUkw0\nj1fjIpVqzeYmAlKBanTqWFwlq0E2xCA6T840SyJQfBuelVmREulVSW4luUROoB561zFtLYS9AuiK\nBaGsDm+O4IWtwhhhy4BC9IoUeNMdKAqnbeJm6M
lJWpAQHgT+5O71S2Je4U2/57qL/NV3vseSC8d1\n5fOxSQ6vhoFd35FL5fV+h/9W0/3DOO8uJz5MZwz4xfMDf3bzikPoeH858+/ff02nHifCnBI/OtyB\nwTEtfDpP5Fx5OM28uRp4NezxDj4cL/SuyQlz3hBxbQLrHa4aXz8deZomVisc4sjQK71r1rG3T0c+\nXs4t/7WL3I4jXef4zacH3k9HVgwvgu+FN/2eqaw8zguXlMEMFwtD54lRmNfMWpt261wD04OrFFWy\npab3ZSPGigtw6DM5NYfGZhG1TAyVXjMuVESMdfXMW4fTgouVV+PUQLE6VguUAoIQqPiY6X1mWj3H\ndWzh5K6wC03GKObYENISWKvHUzEvOJ+xJJznSBHFWRvyOYytOHJpTbxsDjAEowhQHayVWgSH4KuB\nM4oJpQquBLQIoW5YhKKK1YCY0mVQ3yp/qoJRqdJ8v95K03a7xopzdgQHndDM/UCqRhVpPwNzDCE2\n36xrTHjNBVOITggERm3gvZApZg04rT0+hEDN8JQvJDFUDMUzViVLoQTDijFpRTMolV4CLhvPJLK2\ni9SLIrn93omQ1Ni0oo0y04XIlitZIKiyyYoGKMmhBWIUVipVmwvYgFGU2QwvoC+gO3jaR3OMsWMt\nG/veU7NQq7CPHcE5vrfbE33P2/MTB9fTOc+/vHvFIfY8TjN3w0jU1vTxehgIzvF63LEfGui+2o1/\nsKD7e61U/7eetWQ+TGeuYtvp3sfIL54/0WvkYZ5YckbFsR86BMeUNp7OE89pY15WXu/3XO0DVz4y\nhsDjeeLd85FcjOt+4GYc+N71Lae59Zk9X2a8UwTlz17fEYPnNK385/cfWHPhUjbe7Ed2Y49WeDef\n+Opx5nE9gyivxrZ/XjDen56ZVqOINa2wE0YiazWO54VcNjAh9EZwni4Ip0t5qYTxLQQ7Gr1rQ/2n\nS7tQSlGGF0bchcK6ebZV2CwSKIQu02u7GJ+WjnUJLKWxLxeMq26lZjgvgQdGpLaLL2oBYCmOlByX\nJVJRxFmTJIBUPJelQ6uQi8NJW/4sTkhVKauD0kZnaG1yQVVsayu+WgTftgSoCrkIUoSYAU2oqyQH\niGKp4qsSpCK6Ug2KRjCHmNBbAq8g1nyoqyLa3BRSheyUhP0Xn69HGaIjOCVTKcCWG2h67xhQOu1Y\nNTOX8lJNXtmF0GqVDDarHNcNJ+AkMJjQOY+qUUQ4rxdqMrS2gPexKotk1lKotUJQehRyZoyODFyk\nsJpRDbricaXJWDVBNYHeyFRkbc83mVFcZVUjpebWWF9Y9n9xgPj2ZiIFBhehJtQJToRqkFOz6u1i\nx3fGA/fLiW/OFw4hc+UHvr+7JdfKWjKHFx635EyIjpuh52YceX85Uy4njnnjs/H3nuX+o+ePCnTr\nS9OoiFCqcR06/ub+HfuXIcI+RH64u2XNmd8+P/PzZSbgeVoW/uqLL7gNAw/LzH/88I6DO5JL5tAP\nvBkGhhB5Wmd+c3/Pw3lm2gpvDiNXYw9iPB0nfrvMPK4rToXD2POj3S2ocllnfvp4zzSvmHd8sb9i\n6COqjq9O91wuidXAORgH5XroOW8LD5dCyQkLRug8nW+OgKUsrM/5ZX/dtc0k36xKT5NitGry6IU4\nNL0ymfBw6rDiKAWGLqO+DeWm1ZFyZK2B4Co+VjoqJpXnNbIuni17gqtoqPQxk7KQFs9Smlm+AJ1U\nUGOrjpSUvHqqNY1YtAGX5RcpoShSBJTGSEWgCJIMTaCuIt7IRRrbLRCrgRqYgVSKKJYdkhzRCqoT\nppARqimuGP4FRc1DeRnYqYBapXP2cqdA26xwgg+CFxBXyLWyVcgiaK0M0hFDIDpPEWPKK8tScE7Z\ndR2hVhyOzeBisG4bmBFixz46/NasfkcHqpXOBXYWsCDwotFvgIkSVAhZGCxwsfYGKw4sCEMV1tSq\n3hVYpCkzlo
U+N2fIhLFQGhB7QBrQalNT2GiacjVwtS19BAdBHWuFYi2KMqrnNu6Yt4RJ5ZISpQiH\n2CGidN5zPXR8czrxcVooBW6Hgb94/TlzyXy6XHAi7EPkEOPLQsQf3A32Pzh/VPJCqZX/9OkD31xO\nVDG+Pj1zFweuQk+uld+cnxjE8+H5zOO28Pmw47obOeeVlDOhOh6mM6d1o/eB714d8CqIKF8/PDGt\nBSRx1x8A4/PDFcdl5tPpzFfHI70Iro98eXtARXj3eOaXT59IpaABBhd4fbVjqZm3x2cezud2wXu4\n63tiFzhNE885k3OllkqIHj8I0Qlz3tiWjKUEPcQeHA41a0O43BhjNdgN4L1StTDPgtQKCp03fGe4\nUskKOXusQilCFxvgimbWzZOzks0TtIGtp1JfrF5b8pSiqDSpw4VKKUpaha04nDboDwIixlaVkqXJ\nCPJCbF3TL+vaAFcrQAVrIC78DnDbsEmkIAJWlIogWQnkZoStSjXBvOJraDqwrW3gJkoRBYkESagJ\nOAUE0zZ5C9qeIwZHkbZ2zNZsc13wBKe0v5g1Q4UJXh0BuIrNmZKrsZWWNOeCEFAG55CkJF+Za6aW\ntho8dB3eKZorc8lc6oY1Is6gkZQ2srTgnM2guoqiWDH2LnDeEpMBAuJg54XLYsTYWPRChWDUInSA\neM+lbK0KytqOSe8dy1oZgsfEsZWVfedbq7IEOg2glS/2By5rYiqJXezZec+fXN1yyYnjkrjpI8GU\n71/f4NRjVvhif02umd4FbseBIQS+d31N5z1rzny+3/8+V/XAt/JCO07btFlU8Ah731KbqhnndeXT\n+UJNldthYNPA926usARb2vi7Dw9cxYhzyvfvrrnrR07Lyi8fHllSQUwotfCT737O6CO/fT7y17/8\nDeqFNSV+9PqGsetwBr9+fmLdEt+cjnTiubkeuO17Fit8er7w1fRMrZk+RobYMUZlqcbH05F1qZQq\n+F7Z7z3OC6dj4bmsVDFEDBc7uqiYCVueKellCBQdnXeo96xpY9mMWhRRqOY5+Io4SFvlmCKu1BZA\nEipd1xjksjXWW5t6QPRtS6waTCmSVqE6xYkh3hBtt+Tz7EnVYUlRMUwKqtZWWKuHpI1mSftVnUCS\ndltbBQxMMgAFBxnEjJCtgaoZVVx7pZf2x3hNLyvWzY3gUHQznJuoophXam5ML5g2xmuV6qVJErmB\nvw+Go7G+S8qIKU6hi83i5tWopVBVKFJwSej6jk4DaqWtC1uheIji2PnAXiN4aY/VRFoq0QshRLro\nkCKc8sZaN0Q9XfUE8c1rnBrYpiJoAM3GjY5sa2aNhS1nMhC9UCt0FXwNiFuZXEVKC9GJ4rmwkUWQ\nWvAKAaFKc0IECxRJYIbVjaiBQToqqQUyqeOYE8/bSinGq37Pdei5lIXjuoC1cJ/OBXrnuek7KsKH\n88b95UTvIz++e4UCx23jw/lMdI43v/+A+4+ePz6m+/iBq9C6l56Wif/9m69xJpy3jTknfnL9GUP0\n/Or4yDfPJ7qqfFwmvn91YPQ9uy7y8XzCCXx9f6Rq26J5vd+zpo21ZJ7nledlQWohaOTusOMQHO+n\nibcPj5xTYvCBu/2Id57OO37z9MiH5cK8LUR17IaOzw97Ps0z3zw/sq61Ta174WYYEKk8TRtzqqSU\nEbHWUrATUoZl2Sg5Yy/aqzqH89o2ndJGTYLVgnZC3xkSPOuam/2qCiaGOqHTimnFFKY1IKVJNNEX\n6AwrQilQarNxVXuRRl0biG3FYZtQpN22WqhYbbJBSdYAtwg4Q6qBa0yV0jQJob2RWNHWM2ZGw97W\nviC8AHt+kR/M8LVSgqGlUlxoAFyFmgzXbZgJ4gJWFTGPyxkfM8kUKw6T2JiotucrtSLeYeJwtba8\nWq9EVbZiFANoa8RD8BhK71y7tcfIpdCJo9OI9wq1YlZZSsXU6GLEVxhdxKQw
lUrFWLZM7xVTGFXZ\n5sKRzJIXnHdEjUQTEgmpymyFuWR8y83hyvWcc+FiCxo8qRRCJ6TSvL9dDEw1UXxFcdRUuY4DF2uD\n2j4ElpzZO6WqUitcx55LWtvcwCmrGZ93A5eS2YeO69jzME9cdY2gfGe84svDFb89n16KNSN348CX\nVzeUmlFRbseRnDP7oSeo47P9jrvh97o1Ar5luu2oCE4cD8vEWgqnbaVzjrs48J3DgYdpJkTH83lm\nmzJaoe8CP+gOXHcj3uDD45Hfno6oCm+ud6y18KPbW87TxuMy8fbpxKHr6ELkx69vqTnz/nThp/dH\nPMZmlVf7HW9urgkI//HDe54vM3NZGLuRP//8C8SE87Ly0/f3bGUlpUoXlf0wEGPgabkwnQpzSXjn\n2e08vVM2jMu5staMZEHVE4PSRSVZZTlviFmLMuwNcT3wYntaM3V9uX+Njt5XCsKUhZz15Xa5LRI4\nFbIJdRW21BYSxCvRF9RXclXSIpi8eGTNEFGyVLRCztLcBy/gK668fJm23IT6u5+YYFjzLLx8Tq0x\nYaFiIlhxSK4vbQjtDSIXbfYwE1xu1jKTikTFxFHVI1kItVCpFF/IxWPOcK4tVKi0bbINh6qDLASp\nhADFKikZWQpVHEENVx3y4mjDwSWtLRfXB276HSUlvArUzIpSSzMTH3xHqEJV5bxmqmXWktnvevYx\nMhQhUTmX0lJ6NuPajy3yMgamaWVOmeIdzgoHF+lwXPLMua4sVqgBRCuxKqN2nEgsumFkisEVHVmF\n5CvlZYjm9WVYWduG2/qy/RhEEATnHYpwhXLb7ajLxJITQRyHruNHV7c8LDPHZUU5U0vlzdXuRW9v\niWvHecNry5p4tRv57v4AIuRS/58X7x/Q+aNiugBvT8/89fuv8c7xab4w+sAPd7fkUvkPH9/xPM/s\nXOS8rfzrN1+wU88pbfyHt28Zvec0b3Sd54f7G3bB8bePDzzPK6k0X+2/eH33MoByfDhemNLCx+OF\n6AOHceDVruMyb9zPMx+OZ2qtmIfrIbLvBtTgZ4+fOE1nqhljP3CzH+i94+Fy4eE4kaiYCb1XxjHg\nfeT+fGZbCkUqqsI4OqI6lppZ5kKW1G6VRXFDwIm1RtrJkFIRFaSrL8Z4o1rFSltmMAS8EF0lYRRa\niArWLh5x7WNpVJSSmwcXEUTb7b+ZvDgRtF148OJvaqyrPWZgv/v4IkgiWGnOhcZ6aQbUAlIbIFML\n+vKYSZsaqRXEmhxhQV8kB3kB1IploaoBShWPV48UUJeoCKUI6jw4j1dw1ajt/QNzilobuPUiRK8t\n5mGriDYPVxwCDqG31vJgTtlyoZbC4AO7vm/LIqU1RU8po6Gx9V4c3kGuhbUYWy2U0uyBToSYPXNZ\nmaTgTNmssA+BlCvewZSNJ5khKF1ulroSlCknnFdKKZg3QoxsW2Kkx3xlsUTEsVLpzNGHwFwyo/Mk\nFaxkrmLHkgs3fd8SynLi9bjjshXG6BlcQEX50fUVD8vaLGWx49D1/OTuDe/mC6dlYT/0dOr4yZvP\nyC/ZFa/HPWMI3I0jN/3wPxkZ/oefb326vzs/e/pEtcafpm3jp0/3aBUezzPHdeHNcOC2HzinhWlb\nIcPzslBK5abbcbPvmFNi1MA3z08c55Wlblx3IyEG3gwjb88n3h9PPM4LvShXY0/sHbvY8+HpyNfn\nI9uWiDGwC5Evrq95uJz46vjE0zTjRNCgfDbsCdHzvK68Oz033bDCvo/cHkbOy8Jxbkw4lUoIQqce\njVCozGslTRkJL3mtnUOCIy2JVCtWMiLgvKIhIFLYSNhaX275wUdBPKQsbcgj1gD0d8MmbRIG1VFr\nA0yBhq0CIFh9CZOFFzBuyx6UF5Pp716f9b8C4/KSBGNN6iAL8jJEoziw3EAYa0BrIJvgKI1tSkFw\nbW1W2+OuVLwWijlKUKiNrTmriBRqUQzX
trtwqEQk5aaLiLThnWkDYdc25UptzNDUoaq42sLJdz4i\nDuYtY6Wi4hn7gAK9C0RzbGosOWEpo6q8HndkK+Sl+W/PJROlrc916vCiLNvGQmHJFVTZ+RZsA8ol\nJ2Y2RBRxQHCkrdB7ZaYyS0YUFGWXA9rDY0q09yQjiOM2DNynqXmBnZBr4aYbOOWVIJ59DMxr4jr2\niGtNHbdjx1YqX3Qj3jvu15lDiERRvthfsY+B47qy6zr6EPEK1/2IU/ju7sBGZUmZV+MOFfjx7S37\n7ve6fh2+lRf+4VGBWsGLsKTmKL8ae3Dw5dWOkuG8VH7z8YmbvidbM2t/93DFed345tOJJSfGLuC8\n49999n3MhMdp5m9++54osKbCj1/d8mo34qrwf354z2/Tkad1pveeL1+9wkfHcV352ccPLxtmhet+\n4HY34oPycJn56uMTqyWCCMM4sNNGKz8dT5xToW4FUeUwBK73A5d55Thn8lYRCrFXBh8QhbVWltNG\nSeAQQuzxnVLESHMil9wYr3hQh6i05KqUIL+ArVNUW/pXA0Pfwm+Q9jJTeQl5EfgdE8W9gGvLQcBo\nVq6XlVzkxY7QcsX+/nMijdVWwBusDiQjnTXGmdvihJQ2lHtJkGzfhm9SiBn6EpuYgpDNYdayC6iG\n1gKSSXiKc6gKUgORRK0LxRlWAghNehDQWqnAIi93AVWJpeB7wasnp8yUF+yl2j14jzOhqw71wnlt\nf7YBu+DxsQMrLJeNWXPbYqMx1KuhR6RwWTdWE+ZS6Zyyj45ePZZhqitTXVmdsDnYV2XQyEpmDi/b\nbBhDVq5C5DEnNi1ti06V0SkrYLXZ4HpaXKYTIdUXDV8dnQY68cwvMg7iOMTAXRh5splT3ujMs/cd\nX+wOPM0Ta07tOiuZWAPkxJdXLdLx3eXMb87PHGLPZ+OO7xwOmFVS/YPjev/g/NEx3Yd54q/ffc3/\nzd57NdmVJVl6n29x1BWhAKSqLF2tyKYNbYxP5C/gG382zcaGxukedlV1iSxkJlSoK4/YyvmwL9D9\nQBqf2ENmzjFDAomIQETcuNePb/e1vpW1po9OOfC3V58hCn857Hg/HhnU8+Z45CfbK666nkaEb/bP\nmKw87kewwnXb8XK75hwnTufAKSzs5pmVb+ibhu3QEVL13X/34ZlAxnrLTb+mvYjqf/v+PU/TyBIS\nq65jO3T03vE8nXk4jpzjjBWL9ZavNhsKwpvTnuN5gZyRxrIZGlZ49rFqg5f4cUkirL0FK5xjZp4S\npSgOwXcGaw0pVS1mzrF2Y2qR1tWjrUnkUCgl17sUBtyloMpluRa5SKrqrPfi7KUW4AIYlDpygFps\nP4lBK9rmov26FNhSgTMVO1llXmTQWKCpBVmESqnJtXk2EdBCaaRqTYtAAEl1wfZxhli4iFYRfChI\nrrPOjEfVIArW1EWcZpBcyM6BEUQstlzkbSUTjUGKYi6jB3eZPTtTUHUUsbhLU75uHM42SE5QCsV7\npCitteRcuPENGHgKMxbDkmHrqwrDOyHEQjZ17u6NgyJsraVoZFbDsURmMq21tCokX1OGcy4kW2e0\nja/9vy2W3hn2OYKri9K8KK+aDQedSZLpnGeKmbVxYA1zStx1PceY6J3QOscYE1ddR0xKbyy3qxX7\nacY3lt55OuP4er3hwzwyhqWGf7Ydf313x5vjkUkLqwvo/ydXVxQUZwyfrzc4Y7nuOq66/9rp/mCu\nrIWXwwqAue14N544xZnn88KH04mQEpuh5avrLZ21qCbuT4Hn4xlvLF/cbQml8MVmy9PpyIfTyP3u\nwPWqo2s6fnK9gaI8HEb+9PSAtQ7r4fP2ii9fXLM7j/z5+ZHnYxXqX3U9X77a0BjHu/ORP71/YE6B\nqImXwxW36545Z74/7gkhM8VE74V+vWWwhrlkHseR8xwRoO8cfeNwBc6aCaeFkAsGofGGwXrEZM5R\nCaGg
RbHicG19oSeFEGKdzxYB09XGU6DkRKGgqqjUBUgtpLW08hE082mGcLlni3wcz17MDnwqypS6\n5JGmXLpGRYqpn9tW5QLdRUp2+ZrKR1WDE7Sr1U0isIB4RQyoy+Tmwt3FIAtIVIRabIvYS1cuiFST\nRJmrVK04R/EOmx1OwWiuScS5UEQQTJWJUQMUU8mkAtFcEnepDIOkWjmxubBorsV5Tly3PQ2GKMrz\nPJFEKvjcQWuFlqosOSyFoIUUhbXUgllEOYVAoDDZRGMMrfV4EUwyaImMEtHGYEToouXatOzKzGQi\n0TlMFobsEGM528BZJwqCxzEYx0LBGQsiDM7TGcNi6vdtRBic48717DUQS2JMAWeFV01PKvVkdD+e\nUZTPN1ta58hZeZpn1ECHpXOe1lnWbcNuniutbHK01vL5DzgfDX6ERTeUzKbp6Jwj5cy784k/7J7p\njGPSyN+9/IzWOMY48x/fvKVRQ0yFVef59c0LjMA3jzv+8e1bSkyMMfO3P/miQkQE/vz+kRATT+cz\n10PPetVx1fY8n0d+9/Yt73fHKudZNdyuBgqG0zLz9vjI8XjGNw0vhy23mx7NmecQePe8qz56MdwN\nPXf9QMiZ7097SiyUAl3j2A4dK2N5zpHn80iOGXWGjXd03pNUOcVCWZRYBK8G01XNpgJLVlLIqBqM\nt3gjqK3b81IipUjFJ1bfKpSKIqwzXuoIoFA740txFXKdidbYsqpCyFV7Ky5he0U8lI8zW2qTK7bg\nV7WgFq3WW10MuEpLqxrg6rLCCKko1hRwIL1QStW66gzYgrhSRxKdkE2dQbhFkUUREsE4pDWokapy\nURCNFSWpQmlAG4PLpiooCojJxKVgxCKuflyRqueN6VKIfe2avQi9ayq4fJlR33Isc33ciuPaWPpi\nmET5ME31JEJm41qKoxbipDyVhUR1mLXiaC7yvn01IpNQvDgG9ahVRpM4EkkorfH0OI6mUsPQjDHg\nrCWmgqUu8rxKpadd5HKKYIHeebw6ZiJLqdr0ddOyNp5Tif+y/DT1+VFSYeM8WeCQZ2SsVLRfXN9i\nveFpmni9e6ZrWl72Gz5br8ilEEum/QGXph/ud/Z/c619y3fHPXNKF2j5wm9uXtA7x3XXcz+fcNlw\nfzqzdZ5NO3A9NBznUIMi7w+MYcGJcrVa8/l1g8Hw4XTm8ThyCjOd9fzi81d0TVUQ/Pn+scZ/x8j1\nMNC0Da82a97vd7w7HHmazjTGsR46fra9JVvlw/HAu9OIloSh8LP1Dbd9x/048va8ZxlrIGXrDdfb\nGhO0O534Nk0w19nsVd9hHVjTcI5n5glKrnPVqwvRLEnddhP1Mjr1qKtHdwVSqgkHaupcWKjTgaz5\nops14C8drZGP27NaJZuafIABjYAFMVU94JqMbaoUrJQq1coBfJtoVgGxQsqQQy0ARUBdod9GMEop\nhhKEPIN66NcZBsBVFoAXJS/1c+IU6aR2yUWw40X76+qCrniLuosTblEklLq09J7sBYxDEKQoxSkm\nKCkp2hmkraAck0GNIppYshCp7rcmywVfoMxhQY1h1ozJE75ULW3SXJN/c+ZsM1iDUBjE1YVbSdzH\nhWKqW2zbtIjL5CJEVSZJF/pYhbKbUmlpS6lchWCUDAxYenWcXCKSqqZaHRvTs5ipnmKMYq29MBgq\n2cwgGCt0YjDOMM0FSsEIrEzD2rfEEhlLwmN40Q+8aHs+zGfenUd6b7nrBl6uVzyOE98d9wxtAyI0\nrrncnz9OMuvN64d8/eiKbnNxoGVVrDG01tE6w7QEzvPCt7s9d22P85abYcWN7TnFwPvjkXnKrDpL\nh+evX75g0czj6cwfPjwyGMsxRX52d83d0DPGyG/fvieWxBIznfX8+icvsFjenI/847ffM5bAEpWf\nXt8wNC1FlG9PKUGQEgAAIABJREFUR87jxGmZ2bQtV6srusZzDpE/PT
5zWiZMgbX3XK0HrDGcp8D9\nYU+MGSdKMzSsmw6ryu48s0vnqmNFWTWepneEUBjnQrw4y6ypETTWwaIQilap1kVHKfLRXVYdZnWg\nWmp3K1xms3UkgEmY9vKAm9rdihWsZNo+YWzBGCUGgxaLNeBcxPWZpimkUsE2FmVWR9ssrDYRjLBk\nS5hrRHixgumFfh3BQopS04LnqjHu1wmG6pBLseoqTKidd1EwK4MOVXnhplJt0k5q59561AlSDEJC\nRndBRirZg20tFIPkTJY6t9YMsxekFUQFFzLiHVJglGoj1qI0pk60nRNMiSxkZq1zTSmGG+NQI3V0\nFBZmydWsomC9g6xEhR01+NKIpRXHGkOxsNPELAkj0GbLVlr2BEbNJOrX0UtDECVrIpRIo1JZupcf\nWyp1ArT1LVdtS5oyoSg+Z3rruG57nsPMkiJnJ1gcd33PFDPzEtiLxYplM3gsFit1ZGSAoEpTCtd9\nz0+vrnh9eOb9+UDWzMq3fLXZ/BtXhX/b60dXdEPO3HYDK99QVPmDPPLPTw84Nbw9nfj51S03bYeI\n8s8PT7xPB45zxRN+fXPFpm04xsi3+yf2U+S4zLxc9az6jl+0LbvxzLfPe7593OOdwRrPrz+7pWAY\nY+Lbh3tOywJSuB56rm/XeCt8OJ14v9uzaMEby5dX19ytVhQKf3l+4nge0csY4Wa1Ytt2vD8f2J1n\nSqk4wtuh43a9ZloCu2kmLBNzVHpnaLYdViGHxPmYmLXWx9ZRk3R9RfYt4cJhNYJiaS9FM+RLSINU\nPYJ+HNkqfKKCS8K4evTHgGap2/sm4V2qvv9yeX+UxlV+xNVmwZqKk0zJkILDmULbBrbdRNMoc7Zo\ngV4iRTyNy9xcTRQMoThOo8eJUEzCrKDpCuKEnISUFBuFOSjdKiN9pYblqOQimJxIro48dF2Ljapi\nA+hSo4O1jWix0NRvPqni5ippywKpE2IGEYPPdSZSnLDERHGGbKpsT5LSYUhJ2ZnC6TKRqWsjxZqK\nUMwEdpqxpg7Db12HCBwJPGuilIqZ7IzBuwoYHzUTc7mcgCzlogpJWh/b6hKs46BWDdjELEK2QlJo\njeCMx9pIY4XFQS5UXbZUWHlFWFpaU+29onWybY1lJQ1JZmYKUwlYMbzqBp7nyH6ZOKfItmv49dUN\nSTOHJfDH3SObpqX3nhfDGqg3th9yYfohf2//l5c1lpQL+zKz5MxpXrhue3rXsGprNEwuiQ/7E6d5\nxmD44mpNzoVN1/DucOJ5GXk+z9x2LbergRerNbkoH/YHvj8ekaxsVw191/BqteLD4cS3+2eO55nW\nGbq25Vcvb1mWzPvzgTeHIzknQPn6+oaX2yt25xN/2j+xzIElRrbNwBfbNbnA0zLyx8d75jlhXY36\naZ0lq/JwPHE8TUhRvHPcriy971i0MM0LMQIJOg/tYDBYlpCZQl2iaQEstJc6GsslweCjCkFMjZS5\nROdI9dgiJtcqHivFTaVgrGLJ9E1EjNK4TM6GEIWhiazaUP9ODUu2qELrqi72bjNiTUapc44YOwqG\nTTdz2414V5hKQ0qG1iZwSijCq5uJpELAczx7tFiczTAo/Qqkq7lfkhWvhelkcX2h6wrJeFIy5EWx\nlE/LttIqYgXVgkSDjFRQe1vI2f6LakPApBpto/XhwfYX/kGuh4IksI8FvBDlkymPBkdJmYOJ7DUi\nWhdqjRVwljEmlMxoyiUt2NbnUnGElJm0jpQQgxfDtnjOrnA2kcgCImy1xRnDo86MLpO0xrlvra/o\nSa03G4+h9Y4ctZLNcsbjuLItYoWneWaXZ4wabvqe3nqel4nHZcQYw6t+RescuzBzfxzBwqapLs05\nRY7LQir54jZUlpJ41axZec+SE/8Piqr/318/uqLbO8ecE0/TiAD7ZeFvbu6qLAjD/37/hqZYihZu\n1yu+3twAhT88PPCf339AsjKmyF
+/fEHrHVqU/+PtezRndsvCzdBxux1YdQ3vdkd+9/6Bp8MeNZav\n7q5ZdR2FwuuHPY/zkdO4sO0btqsbVkPDKcz887u3PI8zRpSubflis8E5y7gk3p+OxCVhBa7XK243\nPaYYHuYjp+OMpoJxnpuVZ3At5xTYLTN5jsRUIS2brcOIYymROWbCopf5GuDrGDTWMS/JXKJgBLyp\nM2EVapKCU0x7kZRd4Cvq60JsM8w4l/FGmRYHakgBGh/p25l1k8gI3mZ6IiH19C7yxbBHt4aswiF1\nlGzobeSqnbhqF5zJRByNyZSxpglvmxnfJcTCObbE6PGSue0yO2243UyAsmjDNDlSrlZm24C5LthO\niclgc8aRmE4eWqW9KZRsSMVWvi6gF/g6DRQP5AzFYkcQVVILJWVEKtWsLhOrmy1p1dEWD04UX4QB\nWAzscvokSW6lwtWlKJIN+1JjgJxRRByrYnDOcI6FRRbK5eezwVOsMObAwdSbqCvCxrScciCQoUCD\nYIxgS+2ES1Faa+ia6iaLoc6cjAiDb+iNI0mqa7okNNawtg3HEolZ8a7avIemZUoJtNTH11iKUyzC\n2rdcdx3fjpF3xwN91/DFesPLYeDteOT744FjDNx1lb/wQ75+dEU3lszGN9x2A0ULg2v47nggxsz9\nNDLg2LQtN8PA+9ORD+OB97sTY4oMjeO6bbG+CuDfHw58OJ5IWfHW8NdfvCKnhBjDP377nikthKzc\nbTYMrWPbrfjuuGN3GtmNI2vf8vndDS9XPanAd8/PPO6PiELbOV4OW66HnnfjkfvdnrJklGrk+Opq\ny1ISH/Zn5rwQ50jfNFxvrnENjCHwNJ85nQMO8N5zMzjUQE6F/bJUILcIrTW0rorgI8o5V56qoFWb\nqoAK0RdEL/bZC0HMeSUvlzloVmyTuFmdwQjz7BiGgGsDU/Z8vt5z5ScMMGXLIfWE2dIPkS/7PVtZ\nOMcG5wvXfiZnS7Hw8/6BsrJELLvYkbKl1cRNd+Yzs8egTNrS2ozNhYBl5WaGLnC3Npxyx3lpEQrX\n3UROwmaIOJuZi2eJljS2GFNwXtHriGlqQS9BaXKqicAC/qqmVRS1lKkC6g1VJ6wKNJCsYBTcrFgV\nkoXFajWTXDLfiqlys3OGWevx3VtoFXpg1srcnSWRtN7wOrEUEaIxTDmQRBGExjokVW7GkjLpMuop\nWDqg0ZqOHCg1NBMY1DFJIpjMYipQqMl1DMJlH1qfG57GCZoMqKFQ2LiGu3ZgKUeWkiBCZyyv2pY3\nOfMcFyYKvXF8OVyxm0cepxOHvLDyDS+vrohSGGPgYRYG19B5x9r5qgT5r4u0H94lUqn3Sp2znVJg\n6ztu6BicZ+UaHs8j7w9HplAYGkdjLD+/veEUArtp5I+PT/TisGL44nbNunUc58A3uz3jEmv4oG/4\nqy9uSTny7nTidx9eI6VqL39xe8eqa5lL5rcP91DgNJ64XW95sd5grPI4Ljw+PbI/jngxvFj1NG1D\nKcq3hwPn4wQi9A7WV2t6qYGUD08TQSpToW88m64aNg7zzGkKdXlWoPGG/iKcz1k5Ff20WOsMOFct\ntqFUKZQWyLYuz6rHwSAqGFcoqlWdIFoTKKKrHXM09KvAz4d7nC3cnza86M+89DM6G+66E192ezQJ\nxmTeLNeMocG6wlfdM9sycb9ckbzwqjthSuEZ4ZfdIyLKqJ5DbEnR4rVw201s/IwYYcwOZzIv5ExY\nLNYWbvzIpl2YSsMxtORi2PqF3AnOF7zLLLkhIxz3TeXuukKzCag1JOoNxmiuc+sEpS3QAjhYqJI4\nIDdKSoK2VBQllVljYt07zh+3VlrjgCRDdrAPSjBSlVdA46HLllIyExk1iSLKyjS4VJMlRhJLLpjW\nYbMyqK8LOgeqC1kja9MjDsaciFRVSmMsLZZFArNRWoHOWjamJeaZMQcMVat707WMpRBi4jEsOFPH
\nCxSIpXBMNXhzZRust4QlcR4nQsn4xlZnWqmpIVOIxFxZFX3j+cV6w6bpmGIka10O/lCvH13Rba1D\nFf749Iizhm8PO355dUtvHCFnfvf4wFOZOI4Lzjl+db1m8I79tPD794/MaeF5XHi1Hmh9y03veb07\n8vaUebvb0XnP7arnxXrgGAIfzgfens5M00zvHNthzXbdcQ4Lfzk887A74JzDi+WXt6+wjSeEyPf3\nO6JW7ePVuuem7+l8w4exjhFSrkfYu1XPauiJc+RxOn0yPKy8Zb1Z41QYNfNwPDHnggU23tI6Ryy5\ndluhLriMh8bVxIGslfo1UyhGCZdARKdgVchI3eQXcEP+tPXWYKAYvM2YLpOwTNFx0yxM2bJ1MzEY\nupXhv795zcbPfHN+wUpmXviZpEeKtfyk2ZGDcNWNeJP5Pt5SkvBZc+Rv9QN/ml5wNi2fdzvWzUIs\njs+bIxs7sy895+w5ppZc4MaNfLV6RkSYSt3Q37mRGAzBKOt2oWsCU245LQ0hWToXWXeRJSu+zcRU\nF4vT2UOk5rv1iVxMNVlEgZpRiaJkr1Vb5wRJdaFojFTFgCgFg7lwfzF1Fh6dkNA64lHobBWIRIVj\nzpQ6YqYFrHWUoixSVRfF1n2Fu8B6soWKAaoYSdVq87XZYwSSKEGUAYOxgmRH61y9eRalkLDO0Bpl\nZW0l0WEoJWEMGFFiVNriyEY5p8SYI4jwWdeREZ5sYl8CrXe86gawhrfnI6+PexoxfLm9om8ansPI\nd8cD23bhph2q0uEHfP3oiq5S51+frzd1OaSV5nQ/R57GE/tpxBrHV9fbi/vL8eZwYL9MnFLgqmvZ\nDis2jWMuyh8fdzwcT0SFV6sV1llebnpePx94fzoxzQtr17AdVvzkasM5Rd4cdtyfTjUCxRm+vL5m\n03d82O959/xImmtw4bZb8fV2wzEGHuaJ8+OOXDKtFb6+u8Ua4RwyHw5HxnHEGMvgPaumQayHkvkw\nj5RcdeutNQze4cVzTjMpKrMqWoS+oXaqqc4VY4GFOs9FaudbLJgoBFWSyRQHRgp5Mp+64mKVKTpe\n3pxYsqvc2MUyRk/rA3YofBi3LFPD3wxvWJnAT9tHDnPPhOOXqye+bnb8dvycUIS1DXzV7bDO8qoJ\nlJC5cSPD6g3/MH9JiobrbuJ/HP7En88v+UZf8KI/8FUzErPDS+aLZs+jXTEXz7xYluJYNYGXw5GM\nYcoNoXi2bqJEiFi8qwvAMTWMiycmjzPK0AfOVJdh+djRLoYcpaZydAW1FwBEBBDUgsk1OQILWQSM\nIqkeo3OB2eonQ0lbBDVVMQKQMri2qiq8sUiGWBJJhWwU7x02QecqgCeYTNGIYhnEXzS31cWW9BL3\nbj2x1IVWTNUW7UrV884KxVhE6wK3NY5RZ+YSgMy26bhtO+7TxK5MEIW197zsVjyMIx/GM85aeu+4\nHTpOS+TDNNK3nmvf0bYVvh9yppWCt1WWGFPBdv9Vp/uDu4oqTgzbvvvk+f6Pb75HELIWbtYrPh/W\nxFz4bn/gw+mMMcKcIn/36hVqYI6Bf3p3T9HCFBPXq54X6zWNNfzx6Yl/eHfkPAVSgZ+8uKWzdZv9\nu/sPjCkyz4Hbviaf9q3nw/nMm8OewzzTGctV27DZVHfO96cDD/MECXon9KZn1XnmFDhMgWlJNAba\n1nPVt6yagcMyM4YzcQ6g4H3LddOhokw5cAgT6dJl9U4w3iJFCSiZ8ml+6ajLFC5ohHNRRqOVu+DB\nl/pYalSSMajNiFdck3k6dxgMxmSsVz6EFf9uu+c59GyGmZQsu9Bz15wwbeH1fMu78zUvNyNbt/B3\nw1u+K9ccS0tvlP/l9o/85/EF7/KKbbdwLTNndljrkWjYmsR/t/6ef5i/IASHG2b+vvuW+/mKPx/v\n8E3hVXsEVWxSPnN7DqarCgh1HKMhG8tNP9L4xFwaxuRpTcI6
JQWDNUrTJIqFEFzl9qrgm1wXjPbC\ngXWVI2G02nnVKrmhGklitT4jkI0iH80kH+vMZZkWLv8rVbwAEbyncpEvWpIaIW/xGAKJY061A6bQ\nSkvUUvGQRTAozlsyqUrjtOJHzUWDrSI14sgordbAy0kNpyVQLp33te/ZLxMxC2PK4ISt71hMxoqQ\nSsFYQ+st8VLMTTJMKWGtYYyRddPxoq/Ls1Op4J3rtuMX21taZznHgF6yDH+o14+u6FoRrBi+PewB\neBjP3HQrrtoWI8Kb44HdNPP98ciUIp+v1wxtTUx4mkcep5mH0wkjisfym7sXTCVyjoF/fP9IyIqW\nwqvrbQ3ZU70sz87kBK31fPVioDdVHvPHh0fOISKqXLUNq3Zg23renU8c55mS6kig71q+uNpyGGfu\n54UQEykpq9Zy26wwpkbF78c9c8q01rBqO1a+rfM7Mg/zhCsV0GKdoXMW56pcKGhg0UpUNEINPhQh\nGIi5MF9oj94C1LhvSnVm0VWtpvG1ozMIIViSrRrTxhWu2sDr0zVzblDA+8L7fMXf+zc8R0/fJQaN\nPKceI8rGzezo+fZ8w1/1z3h55u+HBzYpsk8tEcP/fPMtf1qu+afplmEIdDbxq/LII1c0Cl1RfjPc\n07nA6+UWCnzV7LjOI9+dXjCK5cvhQEtiyY61m4FCVEchcc4NQYVVt5BECMUyhTqe6H2ixGqVxSq2\nqy65HKgdrs9oIzUFA6rIWQVswaRqFVahEn6oowku7r0glzd+dFOX+uZDqe5rbKVqSq4jg7lU+Zf1\nFwduNliFRZUoCTVCyVrHGc4xXxImkmZ676ulOpoazU7GGYcRQzE1y01F8VZojKkNfA6MRXEKN03L\n87Iwa2IXJ5waXg0rTiFwiIEP8UznLK+2a8aQeA4T/7yLNE74fH2FNcIUI9+ddly3HRvf/aALLvwI\ni66I4K35dDdtrMFdaFgfxjPvzyPHZeKm6xm842W/4n6eOMaZPz890TnHuvO8GFa01vLudOT17sCS\nMq2xeKf86u6WQ4i8Ox14dzrSFkexhp9fX9G2Dftl5A+7R1LI5ALb1vP5toJ03o9H3p4SMSlO4PNh\ng/GGKSde7/c1eLIoQ+PpVg1OYCmR3WEkW0OjwqbxeOfojOOUl7rRLooXAw7WtqFIAa3ch5BLhYgB\nrXM4JyyhMJMZS6HGiGnN8FIhAXMpNdZMFFuoc8FUM9fmybHeLCzF1ON0qikOIXmiGhYcTpSfrh/4\ny3zLPvZM2dPbxBM9iBLUkZzhC7tjpgZcelMYrfAf9j/lys/8T+z4Vbtn2Ca+m1se88B/u3nimE/8\n/vwCL5mNW/iCA7M0ODJ5Mbxsz3Qu8fvxM0K0rHzg6+aZd+OWpzRw153p28iSHIXqpHO2UBCWYknF\n0LeJpo2E5EgFyiWEEyNk95FBUCVZkkFTBlfnBcVyQV/WKCE1ETVSCWoKUOlmtiq8qiSstrYAl6QJ\nrT8HU+ORjEKr1SYcRTloQKmNQWMsZ0pNBM5al7wOYqo2bzUWb2uhkwwzkbFEnLXc2B5jDccwX3L4\nhNuhw4vltATu54mUMp13rJqW81J3HhmldZ62MZSszEuFJQ2+KhXEVNVFkUJOkErhaRl50f+wYTfw\nIyy6ALEoP7u6RkQY44p/ePeOKUamFClkfnF9S2sNT+PEf3r/jlQKD+PIT262NK6hd5Y/PT9xDoH3\nh2O9k697bvqBfVj45+cnnqaFkBJb7+manm3reDyPHMc9u/NMb6FrWl6sBuYceUwz7w9HNEFrDNd9\ny6ptyaXwMJ+YlyoYtwJ36562ccwx8RgCJdYXe2uFu9UKzZm5FO5PJyKKsYa1t4hWaHeQSisLGVox\nWCu0NdiMjDLHxKJU15WpcqXWORZRppzq8sZVjoIzgokeVSWR0QaiCHO2Fd6YazzP/XHFr+4eOKWG\nnC1ZIWXDLI5UDPvcsc+G
3/TveRNXHFPPmBt6v7AonApsDbzTDT9dPzIXjwKNyViX+Q/jzxGF/2b9\nzJ2f+Xerd/zhdMO7sGHrR/59+xe+mW44xIGVW2hs5IWuiOqYgsOjfLU+kE41yXjtAtfNyGFa8zSu\nsA6GJlAawylI1dRqPY6bXOW66usSjI+JG3OthnrhP1QnhHwaJUiJn6hslXYG9R8ACOSPbGJnIFd4\nOqUWVf1XHa/Vy+RCq+MsS8FZW0ceuRZaEUhUeI1olbKhBes8iBA1oQZKKZcwSceyZCYbcFqlcRXY\nE5BSo+iNu4BxrFSkozhOEohkphK5cz1b8bzJZ8JcwCh3Q8+rYcXb8ci3+2c2fcdtN/D15hprDLGk\nf8NK8F/m+lEWXW+EY1yIqTCGQGMtnXN84bfs5xmD8s3zM0/zRFKlaxy/6m5preV+OvG7hyPnEAil\n8NPb68siyvHNbscxjpxDZONbNq3npus5xoU/nZ44TgtGhVVrWXcDg3U8zmf2SyCnasdddZbPtxsO\nc+BxOTLGQqlOVF4Maxon7JaJh3EmpfqCGxrLle8IZM4x8bQsmAKtM6yMwzuwUrvllDO5xnfRWKFx\nDq3ILOYSOEetoG4LW+vIRihaGFOq4JRPndUF8pKFTKFohqYaBKyF8dSitkqr3JDJVtiFnqTmYrM1\nvD7d8rNXzzQlc582RBGW5MjFUIBRPe/nDV/7Hc/F8D55jtnS24g1me+T4ddt4o/hms9Xe1BhXzxf\nuYXilX9aXnBKLX+3/Z6tX/h5/8zrKDyGARXDL4cHHpYVb+drroYJI4XbruExrhmjo2TPyi8UCqfQ\n0qjBu4iPjjg7QnT4JiGukuiKVkeCUEE+8pH/Y/gEbyeVakvzipp/labx6Ve8QNy5sCz+BSIUNWLE\nkC14MZ9omcolht0AlNr1imUSIdvqVktGGURQp5SoF/dhnVFEPiIwLzZm53ACKkpRJRmhxbC1jrMu\n7GMgFI+3wufdwGOYOS6BqRQGY7npeg5ROIaZqRQ2Tcvt0LOLM8c5MOcdTgwvViswhTEv7MPE4Bo2\nTcsP/fpRFt1V0/DtbocYw3EJOGu4aluWlDnME/fnM85YGmv51c0dU0ksKfC/vXtLSoWYa3Lpy36g\nCPz5+ZGn5zPlMnv79d0dKjDHzO8PD2jOnGPmumtZ+5bGWd6PJz7MkTlUn/lN52mcZyqB1+c9YyhI\ngZU1NJ2laRxzmHkImRgr07uzwuANXdewmybmXCphKoPzsHE92RZyKZxiIJXa+QwXxq8XS1BlzkpI\nC9lA56XmfxlLkcRcIudyqRumkqZELEkLs2rt6rqqtzRah8I5VBiOALRQSh27PIxrQq40r65NiIe3\nYUsuBoPSaOHNfMW/377m2sxYKRhRpuLJl6Ozlcxv589Zy4xZfcvvF8t9unAjrfIhe74S5UNcM3SB\nKxkZcXiTcSiPOvB6fMHGTVivbNpAzEd2oUfF0vnEDSO7pQEyXVPIBkZtmJMjBIc1ivGZlCr0Bqt1\nBpoF5jqINb7qmiUBuXb8COAu1VguCzXVC7pCP+rHLr8LaI0lQnLFaRqoYTv1j5nKh/ZiPsHdPoqt\nQkkkLTRNdQiWWMA1UECtUHwh5Roi2l6IcpYK6gkxMdsqIbtqerIoi2YOMWDUsGkbjAgxJc4psaRE\n4x3GCEvOl0gnaJqLsYJUOclUVYYKGGe57Qf2aWJcIk/TxNEGPht+2LAb+JEW3SVlvrqMF0JKvN7t\nuD+PvD8dOS6BoWm4alpu6Pj+eOBpnniczqxcg22FV6s1H8YThzDxh+MzOdfj3M2mx1nLOSy8mU5M\nywyAl4avbzZkzcwl8d15T8mVG3u99ngMzhjuw0hJEGPtJgdvGRpLSJn7eSZfBPadFVaNrZ8rBY7z\nzBK1YmsdDM0lyrtMHBdFymVHY2DtLMZb5pwZNZCz1gWfdziRilgsyjlGgtbo9dZCY6QCvl
WruF4A\nW8cPLtSOLOklviw4NtulbsOBtBjO0rDpZ9QoUR1zdliU+3nDMbbM2dM3iSs38jrc4qka2LVZ2KeO\nhOHKTShKI4mgnqLVNXbtTvyv559jFH7ZPvCHZcW7OBDUXMA9jqAWVSFYz1fbZw6xw9lcT/sO3p+u\naqJDX1CBm37h/rRmjpaohs4lYjQVDem1JmZ4JSVLufAsjCtoI7WTLf9KlaCXGAkxdR6AoBdgDVLR\niWovcUifdkhayewfYe+XQlwu/02SyOoQUyhSTw9GKj0tCRhTlRRa6uJYTCGQ0UvaL7Q4uVD3EIqx\nFR15oZk5YwkaUQq5bvzq53aGtW9QqlrhkCZUhWvfolZ5XgoPUz0hvvA9nbW8G89EPWKM4eVqTd96\n7s8n/nx4Yu0b7oYVP91cAUIs+f/dF///B64fZdH9qAPU8hGarcSS+WJzxbZdWPuGN+ORh+nEt8c9\nVoTrrmPrGzDCN4enmjiRAtumBVU+X295Wk7cT0fuxyPeWIzzvOw7jBh285nnFCipLrAGCy82a+ac\nOMSZMVQYuRH4bNOiKOecuc+5RlJn6J1lZU0NndTMOMeKK1Rh5esLd7COXQiIQsmXbbeHtXiKgZAr\nbyFfntteoG8NrffEUtjHhVj+xT3VtA1ODGOcmUsV1CPVQGENQKVrlY+aXhTXKOOpwVysruLgEFqu\nNhM+V0ZrSpbn88D2amKwgYmGY25Zm5lD6ngOPQ9xTWMzKxf4NmzZmJZz6djamVkdb8KGX/WPfDdf\n0UgmYwilOs/WduJ9/Jp58fymf8/7tGFMnkNqGVzE2sJUPFs78xwHXqyPnGOHiuBdQQtMpSFHYSq+\n3nx8wsZCTHWjVbQWt5zkkp5xCdGsYo6q+7IKbZ1/a/r4NsVfAEGaucQFXey3pdSPEajF9l/GC1aU\nrKF2ikYwJVcT8kWpFqjvbmyNSh+LYlGiKSgVSqRFKP9KfVKfIErOqTbdRhhsQ6OGo1bOCFJz9m5c\nz/145mEe6ayjs47P+hVP88humWi9Y3AN667hHALnlEilsOk9XdNyDDNP00ibLYqhNXVDWOE3NdzJ\n/LCFC8CPtOiumobXu1119KR6DPvJ1RUCvD0k/rx/Yi6RwzTzq5s7jBM8hv90/5ZTCjydR4w1fN6v\n2bYtu2XYWg5IAAAgAElEQVTi9/v3nGJkyYnroaf1ng7Lu3BkiZFjSnTi6FvL0DTMOfIhHBhTukBR\nHGtfNY7nHDmXWBdZuNqh9jUUcSYz5oyUekR1BlYWKBAEHnMkUhujrZVLp+o450jOdfbny8dOuqtK\ngRJ4WiYytSlrBKy3tEYYRdmFBaiv/RZorEOtYYmJbDPJ167NGsUg1c6apG7dXbUGN03h3WGLlMuY\nwFdVxFdicUZxuSAK9+OKn/ePrN3CU16zzwMWZSwt78Oa92FNVsGZzF4bvlmu+T5eszYzGcM3yy3/\nw+Y1U3Y0kqu+OLcUFZwtxGL5w/mGjZtwFvax5ykOeFNTFKLWr/ccG5omkYohxMo8yJj6KxtSFnKy\nmCZjHaSFmlQRP3arpea95YtJ4jJCodROkkvig3KRHgBoRqxDTXV/oZd0jEtQo5gK18HVG2rtnhNi\nlCL1fUUbNGt9Dkg1ZmSTahrHpzgPJUshSR09OXuhluFYSgXbTK4mh2y7niVVR9usVQVz5RuKCqWU\nGkWUC94KUQu2QE/DkQlVZRFPk5Wta4g5cUoBEqxbxy9u7nh7PrCPM9+fDrTG8cV6+29WB/5LXT/K\noptL4W4YUK1/vh9HplC5CYdlrq4v7/mr2zuWXHg7HXk3HjmnQEb5m5d3lf6vwm+P9ywxMMVIbxs2\n3cDKWvZx4X0cmWJ94V93jtZZnHXs4shcMlkLzhicdawbxxgy+zySVeuG2FgGY3AWppQJWiVmRizO\nCNtGCEVZrNbineuWubNC6/i06T5XgjfFCL1C6y2gzC
xEVealLr8aYGg9RYSlRJ5CoUgt4E7ANRaj\nylQS88dXvWjV/lohR0OkbvKbVbxweRVyTXlQqXKzmkphWHnlL8cbQnQoQuPrv/mcVlz5EauFgcj9\ntKasYO0iXVz4Lt7Rl0BUx5+XWx5Tzy6t8CZz40f+tLxgHzqsZLzPnEqDNxlLQQRWdmbMLVnPFAy9\ni7w+3+Ek0TeJ86wssWGMjsYpxpYKuNF6FDe+IBeZYS4O0YKmC/IymTq3bA0sF5dJqRwMY0FsLaKa\n6iFfpBZi8bXbQ9NlTlA7UG9jjYw3H2OGgKw4ozU11xck187YWl8TOiSDdXURJlU3bQqQ6ijCmlzj\n7XM1VmiqygmRRNFCLoJo/bhOLEEKOeea9tsYbtoVS1o4xsBzmDACd92KUBL7uPAhHIhJuVsN5FI4\npMDr057WOr5ab7HWsA+Bb54fWLUd63bgy9UalRow8EO/fpRFt6jiTE3kLaXwNMK788i2rV3C7bpj\nFxeO88Q/3H+o1smSuW5bbvsVx7jw7rTjOdTwPUT4ze1LFiJLDHwzPqIoSZV16+hdixDZhYmYMiEX\nrHHc+BYxQiiRvc4kUyU9DQ7beCwQ8sK55EvxszTG4TC4RjiVUI/7BQwO42BthER1LUU+qguEwUiN\n1TZCIDNlJaX6BOgboXUWEcOYA3OsHbEFWie0YkmSmVOuuEepnTIGrFg0lbo9px5X1UBZDGKFkiAX\nocSem6szIlKLVxaWyWHaTMqWbC0xG66bibfLFX88vGDJHn8pxH9Z7vjlcE8snmsz8rSsOEfPplnY\n5om/LC8wWflZeeLdsuGYWj7EDV4KWz/xOtwiBebsaWwhSySrYeUq6WrtZo6xwZeKZPQuM00r5kwN\npVTBUEhZakz9hYtQFEqsy7SPc17NUs/62XyKJwJF1NWkYVUwpTaq5FpsPuYjST1mc/l71Yt/oihG\nEzk7rK+dp3WVHlaMqYU/x9o924waU+s3NZJJHWSbUc31BpLquAILpVT4zUJNU9n4hgbDUioWNFEV\nCBvrOS6R56m6NFsMm7ZnjjMxJnBC7xyt9+QcmWLk/2Tv3XrsyJIrzc9s7+1+zolgkExmskqXUle3\nGt1oYP7/DxjM4wAzQKNHrWlJdVFV3nmLOBd339vM5sE8qId5l4CkHEiAySCDEX7imNs2W+tbRSuz\nVg61cBmDWzeKezr1KKw2+Op4ZK6N4XlPfunXF1l051r5/vGJ1XLm9GlZ+euHByhwPx/44fETP68X\nvr+c+fp4Igp8Nc388fyR359/5k/XtJPOtfFymlBRPqwXPviZ29homhCRX909sNiN4Vc+jStRFDHh\n1SGXQCGdqy37mE9RVQ61UkvQbWONQWgQQ2heMt5FjC1SrpbNUKVKSRE6xhIbWzg20vSBCi8ruDoj\nhCcbue3ynSRWsjBcfDB68gEE4SQwz5UALj5SwA6oBgdNShsRbO5su/ZTG+gOxhmb4ip55K0Capxv\nh+wPAwbK6o0Xdyu1dq4GEcLT9cgL3Rhe2FCuvXHfFj6OO/73n15yGUdaGWgJfrd8w/82fct5HHhd\nLzz2Az8u9/z27h1IsNwa5yhUMR7lwOKND/2EGhR1PvbgPBrvlztKDY5t0F25nzrLpszF6a6so6Aa\n+1JOGJtinv8VyYTg2Kl1sWu4xDNVomieMHDdx0ieioV9Nh7PqRsRVHGGF4o4UpWBZ+rEVlAZVLEs\n7MURKsOgiSe3V/O+h41UPexqCSnpbvMIiijuu/U4Umu8Rc8RFpn11sPpEvsiTjnWymUMCoUgbcin\nAttwHqaJg8LjgE03woL7OvFKDqxTZ7FB1WRYf3O4o98uvF/OHOeZV/OBv3l4ybfXJ76/POXst878\n5sWrf7O68K91fZFFV4BWUrpUI3gxGYLy7eMjH5YrH643NpzfPrwClCU2/nh5x0/LmY/9xtvjPS6D\n1/ORb5cPPC5nPm
6dWjJz7VU7EGI89g/cYgMygqcU5b7MLHHDPNUBokpx5e652/GVNXqeSimUKLyY\nGyED90EXJyyPhLM2QKjibHHFPDAraBSq5qIiinPdI8ItoLqiCqcp74XhXMaudQulaY4mRJSnbVcp\nwG59VQpgbslp2A0CWjJmhpGKDDOFlqoJKYGEQaQ12A1KAVrQmvPT5Q4Nz6RhLYwOb09P3M8r26af\nC/FWL5gXXOBDP3Gk06Xyf7z/W57WA+BMLfhhvOAbP7N45aSd1YN364lv5gsiHST41E/7dDYYPnP1\nxrZUWkmJ2qdVuS6N7qkaqJqauc97dQ187KhLzSobzzdqk2yBW6A9u2T1rLGqgxDJ47/XHKAXkgch\nRlOe/yAjcjE3qWX3W6A0xSwDKN0diXTBRQANgpLxQ24J3dFApp2FO/IbMA/q5IQasQ6aZ1DQHjrB\ntgXVjcmFVgszyhbBeV2xqTBX+Opw4mntXMZGvw6KKK8ORy69cx2D79ZHqhQe5sbNnNu28s/h3LXK\nV6cHpCjdBt+en7ibG00rXx1OqCT/5F+S2n6Z1xdZdD2CQ6u8rPsPHMI//PwzPYytG2/v7wnNH+z/\n++fvufaVS1+Rqvznu7ccauHbywf+4el7BivncN7eHaklpTY/9w902zJsUAqlNO5L4WadJR4zs6oa\ncyiF1MW63ggG2xCKF9QLRxEsDNEN8+yC3DVBNFOhSMciWCX2QgzHz6ojYzCwCDYP1BstlHlWRIJb\nBMtImpmYMKtynAuGsYSzbilhKCIZpCg5PrjtqMFduMAsiS3cIhiRUT245wJeshi5F4RAj5a/p2S3\nZ0JfK71LzjubUyfn2+tLGoNulRBhHcqnw5FvThfWW6Xp4NNlZq2VYQUR54Pd0TbjL0+d//74Vzyt\nE4tNHOeRI6R+4r6sVIK7svBhPRHzjak58xicx5FlwFSNEcnMvfRK0xyPRCS/YHRNgHt5vs/ZPabr\nN+VmTkq4PIRCFknhmacQhOy+XoUmhotQClhkocwhU97jJvuooCZ8PKRSxCk4o0IURzzHNar5EK/V\nIQYSQnjdgRqGac7bzZ3QPYhSkijnJWf1FLjXilBYbOPT2LiF86IVDlJZxuC2dnrv1CIUrVgM2GE5\nUyloUW59cO/pVJSa6caPW6p9bt3YLCE4tyH8+uGBl9ORzVNB8Uu/vsiiqyJs5pzXK2bO+9uNN8cT\nrsGv71/w6Xrjz7eP/OPjOzwMEeE/v36THAPb+L/efccaA3BO08yrw5xwbXtktXMyTLXw1aFR1BCC\nVT6i4nTTNHpKYy5gthE6CHRPD4j9yG8gC2AsvQIlNbgEnUGRjU7gvRCuNAKdKioDo7NZOo7chLua\n8ekRzhIdDyFHcIUmynQQguAcnW3bExIUGkJLkxVb7BlqpGNtAoh9Mb8HNSIp8hcVYhMC39kF2VHr\nULQ4MQQbyubO4cWaDFpJLWisSqdyvk478TtozfhxvefTemAbLcHqVvjT+YH/9OoDYzvQ3NhG5WmZ\nmXSgGlxH47Y0XkwLf7694rpVLn3mOI3saPuBVyys1jho52yNYcpUjR5CoCwm9B3BqLqHNPYCkWMS\nwdGxZ6ahiKXjTCMSMlOcaAIjCCuZsEwgml215tMsLcX7UPc4b4ztWYdcExqvQQnwyILrqiBOrcHa\nQUnSVzxH3lsFG0R4qiKag5ZcqBmZbFw6o1ZiKNXSDGTkxyIS/PA8uz2VGZHCSufcF7YI3pQTx1L4\naRm8W69ECA/TxEkqq+ZSTavwUGfu5okfble+vT1xaJWvj3e8Oh55d7vw7eWRy9h4PR9p+svucuEL\nLbpVFYlg7Zm2GhEc55kP24X3txvfn888joVfH1+gqkiBn24Xfuyf+PF2pqgwS/AXx9dcuXLbzny3\nXCj7XO2uTryUmo6uuFHKykSwMvG6aW6JWUBWxhD6UBotJWAtiFiIMLo0IKN0jm0Q
YYgYswTXraGj\nIBWOAUM6SGeLILZGWC5K5lpANyxST5v5l4W7ll2OEtx2LbDbs3RIaFXoHpwjMA8EZeLZ8ZTlYfOd\nGiAwZWgNwGcTB6FoJUtvEWwtYCXhLiWwELZtoraBD6UvyjaEJkaZgu6ZnjC2wtDK+XrAdx7BNA+u\nceDvP37DNgqOYqFsW+W/vPkRLLvEHsqn64HXxxsiWTQ/LkcmHRjKh9uRy5bqBimZxFsiGLbPa00z\nOrw+N6qahLCeS65I72wOvIHwkvPQXa4WUWA44ZLyriK06Fg0IBeTWgMiUBykEFFS4yyBitGt0NQh\nPO9n2Vm+UTLXrY58qKrga86E8UFU39USmiOGknN4SmBq2IgMEt2/BSToo3OdCodItc2pTXxabjyt\nV1pt3Kny1Xzi09i4jI2lCxPCizJz7guXsbIxaCocWqWH87gs9MiYn6kpooXVnEvfOE0zkybXoaj+\n4glj8IUWXYuglcpfvJwzrrpWfvfhPR+3K++uVw5S0OOBV9PMP376wLvzmR+XM4POXx5eMtXkkn6/\nvGflExfv3M8ZR/Ki3rNwRuLMUw903wjPpdAQRFZCNk5ly+NkTBwbVL3iAYizmjJ6RSLTcFsx0MAM\n1qiEK4cyqDWXY6JODVh6RXuBGjmXZTDobKaY1czn0jzKShl08zxahqBauNNM5A0iFROAu9AkUYFV\noIfQI9GEhcRE1gi6gHjGpCtBaOaruWehGub76EEyol0ExBlbpZ/bnrKQ5fxpOXA6bOCF9aL0rTLF\nhTY5iylUYdsqBzXO6wkLTZZEC6R2/uHxDdGFLZQQZZjy8rgy14FugWI83iZeHZe9E3eu/YAOQwvc\ntsrouTiSSs5zA9SyaEnErv2Cz26xUGQ4Yk557nqr7LPUBCvk8jFPQWF5C5o42xAOk+WyTR0Tydls\ngVkHLnlCyHGBUGK3WDPQYgiSQZi7SUPVEfYxiSjhe0GPkiOKfbxTqVQKmxujOpskO/JIQSXYzFn6\nRt8Gh9MhqWBmSTIbtp/IkpNxqMriBSvBEMOH8NVh4uN6I1TYfLBF8DfHl3wYCz1SX3xslb99+BUv\n5onFBh7x7xDzX+L1/JKuW6fj3NaNWZWvDie+Pt5z2zY+2o3/8fEHPq03Pm0rXx0PiGQy6j9e33Ee\nNxa/cGzwst4zF8VZucV75nbhOgovpkIV51Aa6BNE52mb881T0nzQiiG6IOJMOti8gk+c5kHVzh5n\nxtYbEcph1+rWaiiOWKoARhTm4pw0UuKmjplio1BMoMJMYJHys3UApkBh1rQAp/ohUnjv6fG/r7l1\nHw4X9/TuI0z7ll4VNhPG7opqkR+X3fkathdulGi5kIqxi39NM0F4CDFlwTUJYhQ+/XyHliwQWoKP\ny4E7WcGVvgnr2tABh2NnDAitrCMoKOt6oI+CCdTmFHH+fHlBdWfr6cwbDqtVTtNg65Wqg2WptMnT\naVacW2/omg8lSM6EDUFajmOeDQzhkou5tlt0raFqqVaQDRHNsY+1jK4PQcQp6lSygIeUlNyJ0thA\n2j52ULT5voATMKhqqaIoucizLV+rpmlp1GYwKqVDbAmt0dlxyc5WTQgTvA562RgG0xCkOFMIFU1i\nmXuGrrbKqc1g8Ng7dbsyRsKBRAsfbys/bgvig4d6olThkZX3241WhIfpiBbhw3rjT5dPzKXx9u6e\nQ2uct5Xvzp943GbeHI/5YPqFX19k0S2qqMCP1ysqwvvrlfvDzCaDZet8d3nivd2400qd7/jti6/Y\novPn9Wf+7vwdmye67u3hgVZXJJw/3z4R2pnKxgHhTicaypALU/3IqW18XI88zJ1Jsus7lBvhzvv1\nBAgulYNGfk6Se3ASY7EKVThUYdJ0jq1eWbaGinDUweBZDwpuhWXk8mfSoB0GjuUbyYTeUxVR6vMb\nLAHm/bnYRmWqOUjYMMbI1U6hfgZf4XliWA3E
c2H0DMcKgeE510R2eqGAjKRvqUqaPJ6rcno1chm0\nFaKCWmp+XTxHGFb49PMJKc9uDeNxbZliIcLosG0NE+H+oSfoJxSzXOS5CLfR8FCipnvusgnXdWL0\nlLahyuhQ5+RRNI2Ekj9zgSPlXdYVPPaTeqCWpoOcyxZ29Ww2wKKAY5Jfq3gWTUeomg+kqHsnrMk3\naBKIDrxJdtldmUtPDEORz8459lNFkZx7K8qwfI3zwRG0YvSR9zZUwQ3ZYUYeKdPL1OdUOdzGwmVk\n1M+hTBxL48lWLrcFbZVTKZy0cSmdcx8UTenihHBxWMZGkYqEUlXo5pz7wpGJu1aZ6oQW6GbczRNz\nremcE09GxL8X3V/mFfuK9Ncv7nHgbmr8eL7w43Lh26dHRgRG8DcPr/nxduHjeuN3l3esceauHHl5\nKByKch0XNj5xsQulVRTnZT0x6xWTjY+bUdSYSYXELCkBUt140c40dX5cX/DmeKOQxeNYN0YXflrv\nwASthWMxJjFUnNVT3E+ATnlczY5YOffGticUHLQQpFrAAe+FYXn8n4sw12DD2Sw5DGFKk4KUikgu\n64YHYgKhFAptn/utuyEkHNQlYToi+yk6wyrLrmrQnSMQkQWmWp64y346j5FxP6LKbs5C3PGWHbeg\nxG2HybhggLZnr77y9OkueTIaSIMbQizCJHlcX7fsRu/vNooGw0E8U4qjSH7vsUPId4ziuIH3VB0g\nhciyj2pK29Rj7+TzNfPIuWkhdqVAp1QhPPCRC1L1oEmqG/LhWLD9zOUh3NXO2CpgjChYzZmyuiGx\nqz80X2/2fLYSMIljasi0d+Fj/8zuWAmkVWLbRz77TCgfakk5m71yxfHidLHM7RPF3HCHmxm3pXM4\nlD3qKhduT71zM6OJcWwTL+cjBtzCUOtUhAe942e7ssTARjBr45tT40PfWMaKqHA/T/yHh9ccWuXW\n+792Kfg3ub7Iovv8NJUd1S8CfQyqFn776jXbNtiK8Q9PP/P99Ykfb2fmmt3rN4cX/LA+8e525pN/\n4MW8UuTAy1I51aDHE3fTOzaD+5gRgQMTB91o0xPvt5kSwn1V1J0Tg0kHVZ0X9YJT+K6/5JvDlaqG\nhXJog+taebfcpfVUC4cymN1wnJvPgDOp0UrgFYqubCZcx5THd1XmiPTpk2MV6wWzirrmfK5A940t\nwExQT7txJYX6izuMBLRoJG5bNDIVIlJ07yGpzvB80PhuZ2VAFdkJVs7YuS4Wgu6BbOnuSxRkLM8R\n5DnfjGHIMWVsPgosWTwgwdtSciEqKizXxm1MWTRboBFc15azcRfGyGM6B0/HLVlsfcs0kYgE3jg1\nDSgCjKCM/FZih4hb5INC9kKqNnCUMGXsc14kk5Jn3RjWoBhDSuIfSQj76qlAkZLYx1DHeqVWRxVK\nCaI5W6/YKBx6p1IoUyY1J1Iy70VIUGpChXCh31KNUnTvnnuyMYYZosndGB6M0ZBZqE2ZSuWpL5gE\nt9HRWrlrBxbbuPQ1E4s9eHM80jVYtsGPXCgK921GxLkN5/24cV8r09SSj9xXvr+eadp4e3dHrcra\njR8uZ+5a49Xh+K9XBP4Nry+y6AIcWuXPHx5B4GldmFrjq1IpIvxx+8BP5wvndaUE/NdXXxMSbDzx\nu6c/86F3br5wXybuqjNz4oftkWVbmfSGYliceDEpjc5cLvxq/sDH7YC1522/cFc3Jul8vzyk06tW\nGoOHsjBpSgBOutCpPPKSrw8XqjqGUkpwXivvl3sEz9FE6agG5sLVJyJSCTAVwaog0rMQbxM+ctt9\nLKn1TMNDFhLraf0NQCVYtCeKMAQphbpv6jspqg9LFxuRUjLbl2w9MjCBEEokrcAFGCmbC0uJHAFm\nQRTZdcMgmyJTjjBcc9bJLTvo8HRYeQn0lOkLYTm+kAh859N6za42JIMkey+55S85Cll6pRYHEyQ8\nu3p1aJmq
6+SIJGfQOboRDwYVEctRQQzEBK2Bld3w4AVxy9FDycJpUXM+LUlnM4e5dswKRQMvOdpA\n4MTgIhVV251uICOXmRSj1UztHaE7nleYas/OOhqDmvwEglpzqWWyz4fjeS6fggsr7PQzwTZjtcFF\nFkyFkxYOJTvQ61jZDFQrsyg3NbwImw0GTvPgcVtpdyUlZ5L3+NwHD1IoTZlr49QmUGXSwiQV074H\npnbell8+Sxe+4KIbAa9OMx6JTPy0bLxbz/z+8SPbGHy0hd88vMx8Jzf+/ulHzuM9PYJjUd6eXgNG\n0StX+yNOAzEmnajM3Gnhk3VWyQ604HhUHupg1o17vfL14cK3ywOvp3OOHzBetpWjdH5//SoDImej\nYrwqV6oYLsqsnWvMuMx8vXfEa6SR4LZVzuO4b9aFgw5QGGHcbMIil3eKYuF4NboLo9dcsEQ6kcId\nL3kcxxI9WMitu5V0sGUcT44XUni0H9/3xdtutvosfwr4l1kushfeYCjgQjGhOAzSQDFGVm8lC6r2\ngit4jfwk6vjtmTWrBJ6n+yOfFQXas1BSfJ8dBhEpbxNXbMtllgMmmvbdnooM89TEhgdeKwPfRy2+\n12DbLbXO6C07YlOKbJmyq0Z4LvRkZzcUDWqkI617hlKGOBqDyoR7YE0pxYmSXXjqsPdFWC1Egb6l\n5boWoyO5d+uCbjlGkACZPGfKwxFLN+BQw2rJ8EoX6AnuCe+Ypq47dpXGEGHrnWVsKT0M51BnWhOW\nDT4sCwrctYm71nIpPTaQkvbfUvl5u/FUgiaJfXxoM5/6yk+XCy9PRx7mxl/dPyQjIn75LF34gouu\nhXNoaaPtpfDz5cp1G7w9nHjqG1+3Ez9vZ/58/cDvnz7uOkjnN6e3LL5xs5Xv1vd8dXzConJXC3et\nIGGc6gXlZ67LSxQnfWeFO+18GBNLzNyXC9Pu5HkzLTQZPJQbL6aN39trvpnPn4+f923hUDq/u7xB\nCF7MQsV4Xa4UdbaozLJxsZmQiVfTQlVn8cIWjbXDtp6QcEpUigTSOm7GsJYLIXYB/pQhgRZ5DNfx\nLyXV6FnsXJEUqOYPUICJEbbTrDzhOtlrJTNWAByK75HfGnunLJ8h61WELciCWfLP6j6TRcBq7BSq\nPN6KFKR78inqLp+ogncoI7v00B0go3v7HYIMp4XjaujeSbpklyoopfc9Q2zgkhHHblA90hbMwE2w\nfXmW1JighH1WdEg4oXVXcARTyde6agJ+pGQ3mAPpRpOeSbxtVySYUEOZWXEUnQNbK+6SC8eAefJk\nGq/CGImcpAaHYlx7yRkvhREw14FIQSmE7QzeCBxjG4N5nnKJp8IkyuqdpW9oK4gI97Xx1Dc+rQuL\nN7wbb1++YPXOah0fCWM/lpmVDqGc+8aL6cDdsbGZsWyDn+VCEeHrFw+pEe9pTjrWyuvj6V/p3f9v\ne32xRfdQKt89PUHA0vNo9ub+xKk2vn985MflzPfXM49b56/uHnB1jvXAd8vP/LRdeOoLrQQqldfT\nKz5uC5d+o+qZSW9cbOblZMwyqLLx9fSRB1Uu/isgUjQPvNIL3/eXBAdelIUpOgq8nc6oOPcl/53f\nbV/zdnpEEEpx5jK4aeUPlzcQ0CZnYvC6XQmC1SeOku6vIRP3h44QdN+4esO6JlzFEyMpajkvHopR\nEy0ZqakNSz2wimTX67t8gWA8++SNzzIxITO4kPicpVZ3dhYSuUDK5jO74JGLsFH+BUUofS/UkhK0\nsXejKkJZFCFTDrxm4cAl5Wc9LbjChk+CoAxJaHfZDCnZpRbJAjtK8iTUjcksGcDkEsw0SW9uysRg\ncmfgmLYsmDv3AoFaAx/gdbc9S6I/m6Yeu2jQrdA9CzHAqQy6F5zBoOZ4RZ0pBt0rTQyLgtbAd0qk\njCSg9ep4SYtyoiwcxTENeilIz51BYFQtUPcHxJJqA89aj5e834xMGF7MuE
TO/O+lMmlycK9jsLlR\nS00zRSsUgW0EIzqTVrrDyzaTPspsNswGL33OcUMJas104hc1F4mbW+qEt+BXd+Vf5b3/b319sUVX\nBCbNzTRRsQhiDP7n409cto3365lXbeZ1m5lq5feX9/x4u/LRLgSWiRAIL9tgG9/tMi54oYWIhvod\nqwdD4HW7csfGB3/g6+nCJIODLPzF9Mgkg5/7C2L3+yvBN/XMH5bXWCj/6dSZuDHJ4GHaQIRD3XCp\nfFpf8vbwCDvntRRnscKfz695Rv5VgYe64sDFJxSleDqfDtWg5Sjg1lseLaOgI3HXA8ee8YKRcz8J\nh5ILpLC0sz5DuiELpZONpeyyJgkY+3ghZP+hM2Eio4merasa0J6B3QhD94XVXtdrT1hP9oyZfAt5\ndE5WGoMAACAASURBVKcLdeRAODSLCVaQasgODq+MRF1KY5ALu2IbTbMKKYKG0+tEkMqBGSPCMck4\n9kF2m/nGCWbveCgjagJx9nulCE02hle0ec6Ba1rCqxjmNWesJZkWgWNec2oikVrckrS4rRfmsuF4\n6m1rEAvYBq3AKI5OhlthLJXiRg9nnnKgPbowFqUaWAumfEFQB+uO7cyIDoTkktOG49VZvXPtRisd\nD+Nuqhyk8OiDH65XQoOHOjOXypXBp7HSxTmVxizC4+i8Xy/U2nhojYc6c7XOj5cbD4cDr+aZv7h7\nkTPm5wiMX/j1xRZdi+B+njKq2p3Lh4+8X6778Tj4yxcvWaLzYTnz/7z7kTU2Vi68qq+pc27W/7z8\nQNvOtLpSy8T9BI3GQYIX85+53L6ih2Be6NGo0Th7ZUU5tZUDxsVm/np+ZNKNo2z85XTGo6QUaJek\nNnG+aWf+6fY1FsJfHD8x68apDCbpCaYWZ2Hm3GfeHK4AbFFxEboV3l3usf33FMk3kQpXqwTZ7RVA\nimM1ZV90hZ4dIZbHfAMQRy1BPWz/oquUHUMpEalIePYP7Pe7khKz3PgHayQKUQj0edbrgivpcItc\nZKllhxsWjAIxGSGpgxVXimexyJwxf5YTUMMo3SnSs9vVZPkKOWMvArNk2OKqjc0FpFKtUyXTGyQ0\nZ7b7DFyBWVY0AkMZUhmekJvKIFyYy8jvWXPOHpHJvRrPQB1FKnRKohYFjjoISSfYoGKquxwvkY9F\nd/cgGV3vIsylp7YV2LzQN6Vqzo15RnF62qPrzlMwB1PPBw47FY5giNPUmbyCpEju6sbBhCrC1AqX\nxfl061xb8nBPh2OmTJil3E7yRaxoqmOicD/N3E+NmxtPvaNSEA3uDi3/jc24bhutFOoXkAQMX3DR\nrao89YW+rmwjnTevTkd+Ve95uq18uN34w/UD310+0Wq2Tn91fMN1nPm5n/lpfQIxXs0bEm+ZJXAf\n3PiIyEIIPLSVprnwuCsrD8dv+T/Pv2XEhHlhRBbon63x5DO/aY+5GfYDf3P4yCTGsay8bRcWmzho\nZ0TGvhyK8VU980/XNzjKi3mjhnHUTtEVU2WEcIkTt964P2wQyVnoVMzgvBzAS2o+ZV82qWOjZKyO\nk/AaA3/WBpvg9lwgY4er5xs/bJ+f2r5x189uXybfC2c2zLuLIiBSihWkfhcEs53NS6SSgZSkecsZ\nbYkCljwIwrGSBos0jjXEjGKehoeatlqiJMRFgkPkw2pQsJgYnjzcSTu4UmWlhLJQMSkEBYnBHT3H\nAAAOQ+WzuWDSlRapb92sQnF8QNNOUShqbNbYiuKjkEMcZ4jiXlij7stBYB8L1V20YWIM0YwSGoVW\nBx1DWhpacpmZizpqzstdgU0Zu3rj+Z0uHkQvSEAvjokhw7EtmZuLd2avDMvQUY0koi3d6LupRzwo\nUrk7zFwuG6ss9C0lei8P91zGmmyTXcHR9JhpKKQ+eyoTXx2P3Mag++C8F91f3/+7euEXfdVS2Nwx\ni8/Lj0Mp/HC98P524cfzOX8QTi94OR
14P658XD/yQ3/PdQzuasVVeVVf0u3K2Ts3Dw6aG+dbf8Ni\nFYvOXX3kqAuPduCb9kSVYBbjoSzM8+AP62sGMGpheOVejG/7PQuF/ygfaKEsfuBvjp+YxKjSuasr\ny2icaqbmNoxD60gEf7x9lSaCkkXpVDaIjU0KooMnOzGsciyGl/TQX0fNtNwlGQ1quvv003TAnvWV\no4Y9GVjIRIPuyNgrqXyeNKRzzUhtr4BJPJu7cq+1L6t9l82yO9Gqgey+4vH8+XaTRUPwsU+OdaSS\ngfyaxIJCuh9UOlGyEocUJIIWG5H9HYO2A3UsSVoxOMZgQ1liStoWwpRxj1QZicFF6VFTf+vKMW4Y\nJdMfzPPXKGpwqIZEIOL0XV1iJjQGIypFUiNmNUEM7oK4MItTcagwevITNHI+XqtRa4ClHRqUKAkA\nsm1imNJtHy2p0SxYfWCeHXlKkgcRlWL6nCZEFGGzTonA3OgYcwiLbyxj7MhKuJ8ah9q4js63t0ck\nhLvSqFrpGDdbGeactKIinMfCdEuJ2IvTkbs2s4zBu8uV0zzz5njHr073KQ38EriOfMFF19y5myZa\n2XWFT8GfHj/yYb3yuKy8PMysxRAm/tfHn3gcC4/2gbnecXco3LeJH5ePvF/PvDh8QH3moQkSlUaj\n6SfOfuJp5P8vduCxv+DmM02MuQyOCj+NA7+ZP6IaHNl4XVfwM/9zfYu7Jz2LwoNsfLe94ErlN9N7\nKje6F/56/phqBHVcCj/YHce67RCYDK40NX64PdBd6FIRh4N2aht0GuvYRQld00Gm+2Z9pEIhtvLM\ndSTYjQzmgKJjV2eJ45R9EUaebfcRXejebe3OLJWM8YnIeW5kDaVugUo8eymy8FU+px/ocIzYJWPZ\nwqmUlFL5SDi67oWsVMKUqXRkX/oRkYWPCgTF4aT58R7KlSmP5JKoxMbgIIYTXOyY2WZFmGTbl4mJ\n5DRRem9IseRjyJKxRCJs0VBPqZ+6c2gjNbTudFpqkIk0T8QRFWdITShQBCJpttB0ghMK2240nsTw\ncLwoG4r7DiYiu+045LyefSyRMUGkocL3ReKwfOC3QaGhrskNtmBrqUVGhHkqfFpWbp6pEW7G7AcW\nMw5TpbUgerCFMzCGFw4i3M9H7qaZxY3bulF3Y5KK4JY8ir6Pbb6U64stuiKCe3C1je7GdXQOU+W3\nx9esW+eyDf7X7Sf+dH7H41hYwvir40tcOjfv/OPTz3TbeH28svQ3NC1MwCpPXB3ezhfuvHHUjqJU\nCf5y+pmfrr/hk90zESx+pPuRNQ6Ide6q0yh8jCN/O78DgoNsvKkrNl34u+Utz4qB4ZWXuvLt9sA1\nJt5MZ05lw115O59RIa3BHHjcGbNVCmKpvyWCSz/SXdmsIgZVx45OTAdRdrd7i/MMUNh2mdMAemQU\njWja0wY5f91VXc9LsPJ8at6h6a6B7Z9PA8Rit9rmiNiEz/E4DXDLpdmQLJ+5Ppd93hpkGFhu82W0\ndHbhFDFKcWgDWyrUirhS6UnvIpBIs4nLTOxz3oNcKcAqlcUrFqlRnXTkIswyRmnhyPDYH26DyQPT\nwL0ixeie82MEDroyrKZjjbRHa0Q69jS1uAkhD8KCzab9gWGJc6zO6DleSNay74u5iq/ZyZtnJ5xT\nF8eXHA1Rco6qW85w3aBH0starXgHH074YOwPuOEJ2W+aMJ2bdYYP1u5oq0wop6mx+kjQee/UKnzN\nkccd0L6G0aygAnNJPTAKxzrzzfGOi21c+sZ0u3Fo9Ytg6cIXXHSbKpsNrmu2ausYvJgPfNwWzmPj\nn58e6eI8tCO/Pr7ANDj3G39cfsensRAYWoLX0wMlOh+t89NYs/7oxg/LW9aYaTIQScXDVJyX9cY9\nGxVnFuNvpk/84/qKJz9xlCc2P+A+sfiMY9w1o1L5aCf+yzEDL2dWvp4ueMDfL28hgu4FU+VFWfhx\ne8
HFJ2ZNqycRvJxuhCibFz6OI2vPhUmTnXdbHXMlomCDz5FA6I5iDElGAnvMeMhn7CFrurqqZRfj\nEjlb2Lus8nmtz15s8yruCbZxiCmLbTwzGxDUnRFkss3z36mk+B/2AMZdHuGpCZaSc8tSO0jBvWCr\n5lyVLeODGEg1xlZxLbgU5rEylxUpBTdlhDNsRqQTwKl0mndWUTaZMr8snAnDAqaaJ4VOoT8rszVo\nw5IdjOJNWclbIwEH3bjucqotys6CUGbZS7OQKz8RzGrqaLEsulbpFMbYgTU4tqtNxmAv4vk5wi3v\nhWTLa5LGEESeU5oy6ilGnoKK4ERCgvqKh2E64xE8tANRM27n3XLbRz5BnRphxvWzuC/YxqAzaEM5\nSOXV6cipTWy9835duJsmXk4Tr08nPAKL+CI63i+26A5Pc8TdNOERtKJ8+/jEn28feXddKBVm0nn2\nz0+feHe98EN/T9XGfVVetSPnsXLePnI3f5daziokqeCOKsYw49ELpzITKH9ev+LmRybpoIbKIGLi\nTb3xioTazGr8dnrifyxvOLtwpx/oPiMUFp/pOFPpaMCjHfnbw0/Z9WHcTxsSwe/tKyRgo9DEOZaN\nD+uJW8w7wzc34ndtMDBKUx63AzZyu19UsAhKM2x/c9Ll2QkLk4BFLpE6WU99n0ZoniK0k0ng7LV3\n9xDsDG2qJZe2E0TdXW2FZBwkwoy17uYKEabY58Wb4ZrgBt+ZsyIVIjPPtGTkjGvZRai2+yIGVQJn\n4NZycRVKqwsSgzIPMFhc8CggjUbnVG44ykZj88Lq6YorLkylM4VzE+XWk7PRXZlJTbRK0CW9yI7S\nxBGxtDFTuDIRxUET1bndKqc5C2Wp6SQ0E7wXCpmGweRsJJIzelBcMXF6VbyTD8SR9yqeb77t8j6X\nHNeQbAnzHNeMgLrPYCuCx4ZrsNqWYx/J8NNrrNyi40sqa1pRxnAe7u5BhF4zpidIq/WxNUSVu9K4\njcFqRqwbJZ/ldLPnqQ/y/3+L/mKvL7boPkNvPCITUC07pm8Od7xuJ7obP6yP/NOnd3x/e+LqG19N\nB9BCk8J3t0fOdqW1M9pfIKG8rBNb3Ph5O/C393/gtrwGn8ALEZWjdEY47+yeN3rj5jM/bEcGSpWB\ni1Mxljjwl+2MRWeWwUEHv53P/PfrW27k0uTmMyLCzWc2lPtyQyK4euOv54+58Ml3L499YvWGEglX\nkaCosfTC5o3NcmbbNBmrTsGL4ks6rdw02S05xiRbU0F2z36kaS3bUQftSScLAWt5nBXJjrfsEqlR\nd05tjob38MbALBgtcoApu7NNspsbWMblaA6Byx5x7tZz2QeMUSg1cvsmTqlBpRMhmCqxVRQDddSd\nqoMwWOwAIzv5ua00SbOARWULoZuiKhR3XumVzSsbE7dQLHJsUMKZilE9KXWrt7xvCNXTLjAXp1My\nsh2wUZKN7EHVYGieIjxSylgJrFqeRLaGbYW2d9dzs1yajYr3ZyZHJh2PrkSHz0P2nfuL5eLs+TU0\nyWmRSyaEjJJFeXikEoRAW2DRCYEn70woTStzrWy+8qHfQCqzwovTzNa3XSYoaZlGqVrofXA6TdzV\nxpvDiXPfOPcNu8HDfKD++3jhl31VVYY7Hy43WlXeXa48nA7cvLN14x8+fOJxrETArw8P1FaAwZ9u\nf+LH9cI5VgbGV/XEXGYY8P1yQxRet40fl6/Z7MhRAysbH8bMfzj8xA92h5NLLgl4VRb+cXvNe7/j\nr+uFR7/y82iMyOOpYRQxLnbiN/MjI4yjGPdl46+mT/zd9desIfx6isRGAptXlmgU9T16pvB6uhJa\nOPjKxY5sllv4Z+RiNpLJFnDP0ULsRS1JK/t8UNjnvDmfC7F9ZrBLEFKVhU05khB2V5qDex5Z0ZSa\nVWKXH4GVPF5G08+qB62RWV4ebMVpRUAzAZcIqNAZ+ec8CJV0swWE
JiYwj80l2Q4lQd1lNjQGZso6\nJiIKbs5ce8qawnGCqx33qUhh0sEsK6s2Fp/wUlJap0bgvCgrEco1Gldpn8cDLYwaholiLqzUvNcO\nU9lY5IA7bN7wkt9vRQkvSA4jQIWOEi6ZWIyju/27j5IOu53N68Uwf9bPAft94Tkc03b9NELUnKmP\nSK5DiD4LzzAF9ysjKgsrW51A4KXM+fqF8Tg2quheVh20Zugl4CW42MIclWvpNC28Ph2Za6N7hlTe\ntcbdPPMwZTT1vk/9xV9fbNF9jgX5+v60/xp+Pt/44+2Rd7cz7sGI4L++esvPtysf+4U/XD7SBUpp\n/LqeiAjcFqz8gZvHnhAQTHJgHTOLFz50pxRF1Pnd8pbFZ06SEdi3CF7Wzr1uGAm2LgRv6pW/W97w\noZ/4j9OZR1t4NybMQTRz0hTjYi/4zfwutbtivKwLFeOfrt+wRUb6TJ5MXUNYeyHLcvJg5zoS0yiO\n+8ToJbs2S5KXaOLECorvEd5Z7PZjoQHe8o1iDjJgyg6HUZG92OH58Mg0y4xVb5ZMhiCBN8+1YFJJ\npURGixElO7V5h+08F9yhgetIa7KUTNwVI+puxBDNJdXIjhcxppaCYLeyf42KiGWCw7wxiWXkjM2Y\npFri0Doe2R3HgE0nPFLDPBXjICtbqWxjwoTPOmpJMCQCmAhDdH+3PRMpwKWg+9c8gD4qp50fKTuc\nvUda8MSTf0FN+tlY0xUXu/3PxTAtxLbb8SR2SUpamvPkkeOF53BKtYwODQepz3xkRyOXfesuL1PP\nlOeMdtpYn2fqlovPX7cHLJJUt1gmZUQETRqzCrXk/mQZgz6cQ6vgzmKD13KgqH5mXH8J1xdbdCGP\nrc8dr0hhcePFNHGSl6w+WBl8e/3EH84f+NgX5gKiM7+eXvOxX/nUL5z9wqsy42K8LhMjgg+r8R/v\n/4ltO7ExJyfAJ5bRuPjEj9aYtbMenX/a7tkoPGimAjudo1ZelhXIhrKI86t25n/c3vDoD/xm/sCj\nn/g4ZkYkJDtZAs7VZn49f8qFThSmaqyj8N31JQOlR2G3FVBxFtcsQLGbEdTRFqD7sXVLwAqxL9HY\n26pBVkrrOS9Mwk2eWl2QMCRGStlKasMkan5DGZyG1d33v//VImBjZBQxRpSWmtuQ3YhhKS+r5Bx7\n7MWiRqoENB8uVYURBdUgSlA0Rf3DCjFkTx/OcUA7dCSEHhPXnt+AaFDdme6TWXw2ZVjDFMSdYxl5\ndC7KOpQ1ao5kIjGOB1uQWrhFxUj2sIgh5vl6SMNd9mRjwI25JDfByWWmSyRm0yO/XmUHrgtu6Ztu\nZTAGuCh9lJ0OT6o5YrcCPstINE8VYfvHAY+W4xrJmW8AFFjFsQBRzQddyYy1YPC0F+W7NnGownU4\nn7YbPYzTfEiinUcGgzbokokYRKWbc5wnHuaZl/PMpXfOW+dqxten0y8+G+35+mKLrkpGg3//eKYW\n4eN14b5V7mpFAn7//gOP68qHdeWoE69eHDnVwo/9Pd+e3/PBrqze+WpuTPLAxMz75QKyEq3z0/ZA\nj8aDFq4efLe84r+9+D1P119x00yv7T6xeuXRKo924qGsfIoPfOozneBVXUAGykaVA6/rSvPsbirG\nN+3M/3t7yyUaX7dHJjWe7ECP1I5WyeXSFpXXhxvdChermFR6KJflxHDhZslfVQ0mjJs0HMEt9mEs\noE7sybhsqTgQ993mq1k1Bzus3JIP8P+x92ZNliTHleanaov7XSJyqSoUAZDdHIr0//85LTMcYTcJ\ngAQqKyszI+Iu7mamOg9qkeA8UGYeukkKk/6CQt2KiBuLq6upnvOdObaA6NSiGIyIcVFAK680P50f\njgTMJsTAsc12cXoJQ0ak+7zGuhgjhePPNODcuAYqsTiSLRJ5A7BLI4pSTZ3qDVLCbEb4dCBFR1jS\nhgCXvvIyNcwioT7RMq21CPee
Q5/sSpUNR2k4d13A7RX5ErZjEYZk9q+F1ElpsO2ZXKLzltfC26CP\nxKKdpLFojOeBBL4RMDE2AVNhbCAp+mt0amvn7NaH4EqAeewVTxSOQdWBkuOhKRtmhWaJIvEUnMIS\nEOcmGw6sutAcduvk+bmHdMQzOuYylFh2/rLfOefKvXdqyrxdDyyqXFtnL4NDKRxL5ljrNzFWeL2+\n2aL7ej2sFXfj7XHh+R6cz59uF7Z748k3/tub77i0nXtv/M+Xz3wZG7sY7+oRFQ37rP2JL/tHNkAY\n1FS5tbfcRuXmnaSNS4HfbT+w+8KjGHdXPvQjf7N+4I/7kWepdIKx2hyeRuZTf88P5caH/sSnXhkM\nHtMGdJJ0lMy7cmHxQFRmHbwpF/7+9p6bZ1btrNm4dWVYOOVKnsd6S5TcUE/TbJZoXdhHiWLZY+aM\nGJoE29O8A/3Pmqe5PBM8iOUjcsVcwIOpGOaJAZ4nFtI1tKAmSA9J26Bjkqb+dsHNEAkew9BZ4FVg\nl68ef1fHFsfrREd6ifmzWcDLFcYuqMaoJONw6AEUNxiSkSGRuTYCAp5Tpydl9MKQFE8Bj6990BGM\njn4IOVZYu6jJSDQgMQZhE5agWSQJf9oQ6JOr+xrr42bh+lX+LBkbMZ5KEh36QOgW3TkplnPLpJn1\nUZB9BL4xgYgFj2IobhpdOU6SAOG8VjUjuCEMmY+IMTvpEtFzMojJOKQ8GF3pA0oKXTva2duIh2Z3\nNAkP+URHGN7mSQpQ5yiJQ06RONw7zY3b3nhTl3BMunN8NSj96972/6bXN110X22NhpC08dPLFTfj\nIS3oAudUudrG714+8fP9FjcDwt8sf0Gj8bnt/PH2keMSVs03KSEcaL3ysPyB+9ix+SNWqXzYVj71\nJUaWuXMx5e/u39FJvNc7d+Bmyq+XZ/7n9jaiZyzTPGMYn/rC5/HID/mZD/3Cx3bCXDjqDkRe1Ubm\nbb6yeOXumZJCjfDTyyP3Ma2oRFbaJCRQBHaXGB9odGHDY+RiTfGmk+Q4YTJEB+w2nWIei51IJQ+t\nqcz5YBTpmJnbDHZ87ZI1xWZORGLZ4yk4tDU4FqYJ1VBWWAdPigzDK1MpIdBmQY72H88+C6tEttg0\nWRg9lk09xXwTQWyQ86AkGFOfbB6KDRfIeSe70VV52ZdIxPAojIiQpun3QsG8xhgFodAiZojEJsTM\n2QUZxqo7GwtO5u4hHUtz2bZZ4BeDEOeMpIyW0BEjFpWIlQ9lQ8z/5XUvNvQr84IR4ZOiY/7tWXxP\nHuhIFWPM5ZejM9uuxxJPDdM9XGx+CpB+YTKQ4dY3EKFKpdRE64PnfmcfxqFWTMakzGWqxpgk5Xiw\n3fadh2Xh7bqylsLejS+3G0kTvzqf/tXu+3/r65suujVlPrw8Yw7P9w0VeH88sZbCH7584cP9wu+f\nX2jm/Op4pqiwe+VPtwsf2wtP7ca5KCoHfsg/8PF2pUmjcSWnBdR4S+Jqyqf7O/7r+Xdc/cx9HBhW\nMC/84X7mo51pI9xEn+0Dn++POMJ7vUw3rfGrcuFv7x2Vys0XmmUE55d+4IutvMtXsjqf20ojk9Q4\nssWNSeKUG0mdl1ZJ2RjAfVvoI7O50kZCpourk8gpYtrdpkIh0ieRFmmzjHBzRf0RVJRBvJYsDAgk\nw7PikvBJI1OJI67lxCiTm+vAMFzGtB8LnmY32WO+7LHgxkZ04DpRkTO9Muytw2GWElHCmZbiQeKa\nsD1AL5oUtYbkmFPfe0IGDE0UcWrd5pw3sW2KS6RCqBmeBtUGe1F6K9y7MuZC7DU+XEzpGfpsL8Og\nZZgKTUqcNFKoOoaF+y6eXxPu3sJBUjSQj1qd5pHirLvHDF4dL+B7zF0DDQqaRzDgbXa401yiYoRr\nIhKMRWJcIxa/ryQbSKUniedq/Jbo1sgujFTivZYyTwhhQXaBzXdEC1k1xgwTrvzp3nizVvbRqF
p4\nWxdqyjztneMiLElZSmHJkaDxrVzfdNEVQDXMAsdakAbPfeMfnr/wsu183q/89vTAZgfUhX+8PfN5\nb7z0G4sWfjxkqigiFz7cnnm2G1089Lr9DeYHPvWGpw3nzmM74yTe6OC5V/64veNvjv/I5Vq4yJkx\nlLtVftlP/GmcuI/MmnZ+GgujnwDhfYpCrGq81yvl3ikSC7TdboDyqR+4jIVDaizeubdCmyaHY20h\nT5IUgYcMvMU8t7swrDCGMFpssyWN6ESHRrG1SMnVeXwdhGliTJ1t8oETnavrPKLbiGOzWiyJ0oyo\nMTDVYAIIk7IlWAtmhPfOtP7DqGG8IJjD6KvhIhah2SVmngQfIUtEpNtOzIndIo1CPFxpRfC90OZ8\nmmIsNsk8prSRYdd4CAzQukMyGpWbKeOeonNMPmfhxjaiOO0UHCMrpG6kFJyF7gT7V0BHqCYGGZMg\nwjmhuIiJitAtooVanzZpE7QY3sJ6PPapKAm4MUqCFp2ym864IA/r8PRyuxlD47SyWMPwoMnPhVpA\n2yFlI6ULWGZ4mekegjK4fwXnbCjCqZzwEaCcPpFxm3WWkjnkipvSZdBwnu83flxP7K2xlsy5VpLG\ng/Jbub7pojvceFxXkgjdjL/98DPP9x3rRvfBXxzOkJTPz5/4/dMXEOG57/zV8VeIOm0Yf3/5GbKz\ns3PQhaTK6gJ84G5X7mKIOVkKH+6/4mWstDHIaedjvvHQ3oAk3mrjS1/4uD/w2/VnPl4O3FmRAZex\n8qkd+KmduFrmmHb+qT3QTYNZmnaaRbx7Tnfq1tm1cO1LOIVceRmVWy8hPHCNeBcDV6EW4zpiphht\nZ4xRSER6rCV8Zo0lmTftzCXTAeIWoJzX7YvMYtBCG5oxqE4fOY76w7+iHIQGsxt+RVlpjgWZuMxF\nWJlQHaJzg6+Q9JA5JSgztkcj+daHgBTcRowTViclZ3SQLezBMjrUGD8wUZgumW0Lf3PSTtGO5YR7\nwi0jFrZhXJE0yDK4u4YNmxQhl+qkKYGylOiSeN1JKlNHmxPNdcJnNHQhpoypeTa1afNNqMcJRJR4\nj57CdfaqTMgzxbjFz0uGxe/Pg3PrY3wtxJqcSqNLjkDMWdjDS9KoGhAkEJrHiSmXEZJBdfoA14Qm\nQVUxc7bRadY4poohiGQKiZIUk8FaDuzW2XrnXBfeHo9kCYTnl9uOKvzwn+OFb+MqKfH5+UKzwb13\nBs7bw8pfvnnkw+XCx/uF3z195qfbhbUslCy8kcrz/c7HHq9riiyrX5ff8NJ2eh/81L/wSELSzqMW\nNnP27ZF1+Yj7nT3FDWqe+buX3/JlLAxzltL5Uzuh0knAW9n40hduo/Kb9TM/txO7H5DhXPrK5/3A\nRztzu2eOZSfvRjOlWaGKoXmjJqOq8dzC8XRpC6SZgeaFrcUmfvS4gaLqpYjpaYEH807MRmc0DR18\nBDxHic6XKX3yqOow4ut6Dl2uzQ5VWsSei4CV6CrdFB/+Fd8oIqiHFMyHIDQs4Al4+FhjlguTvB7H\naxFHGrjOJSAGJea8oolhYQ5IYvhwNKXoKk1Qc7yvsfVLkeJrIgyP35Nvr2nFUQBdYiSwtYKVXOt2\n1wAAIABJREFUzNTtgTjVjQ1BeooEC4JXLO5zLC6BlSwWeW6TI4yEBvdV55x8wivEI8GjaTj1ptJC\nhSDktcIUUQNBiJMR83dDEYFFGpsL2WIxqm50iW4dnOydrvEQwJ1hiQI0G2Q6XXuoPPQYgckJ3EOm\nuPkdGYmyFpSK4HTrvPQN1RPb2Dlo5U1ZyUm49J0f1jOGU3PINr8NL1pc33TRTRIZTUxzRJK4cf7p\n8sLPlwuf7lfOpVLTW44587RvfLi88KFdGWY8LgeSODVXPl83nuzGjlE10caZwyjczGm+c9EbZakk\nHTy6cxuZp+sj358+kYdy10IfhUblv7/8JV/6IZKK68Y/tQ
dunsjinGTneVS6Kz+sX/jl5UBj4dIr\nVYyXVnm2hXsvrLmz+aCNxB53P4ey01G0JrZrxMpsLdMg7KdDw+8/QHrMFCXHLBVkxqsLyCBNyBev\ns1dLYIMsjlSP4j1B52kEu9YXwWYEkM90CdEWCcQec1CxCGeUWexdiSOwaRRsk8ALTNebG6HlHcQy\nKYOU6em3eB0DIbpGT4LmGfvuoQ4ISIuQNKGph0vPEvcWGEhJRsUiTJNYZvkuUDRA69rj/WviTrCJ\ncUebYNXwIXQN1i5EEKj0MGpojlDPMf8mo82fS0hJcfR+ZVyEQwaYUj0L5YLYlOhlm2GbU3NsMf/V\nOfpBiGBLVYrvNNKErEc0EJ4wQtmScUwyzZShSpJEEqfT2Eee90zkDUKCEctPn2aNQyo8SCRuNwnw\n+nU3Dlq5951jXTgtkRbxivX4Fq5vuuh2j842q2LuXHrnp+cXntuNl7bz/nikaOJpu/N3Xz7hNvjp\nfuH7euSwVLIr/3D5wn3buNhGQsLayIFuV17G4O47DeOUEp9v78ELexugg3tWzv2CJudsg6d+4NPt\nDe8Pn7la5slWWlu45pU/Xt/zZBUz4VB2fmpnDrai4hykcxkV6o33hyvPzwuO8LxVSoZtKNtI7COR\nUyAAx4i4bVMnl45ZDqF8m6qEQVh2gVfLmFosf2CgU6nwCikPolULuJjkWGZZRyZbgLmcMZ3r9qis\nZImYGtM4EmuP+bIzYhQxWQR8nc1GgGZKkbOM8TViyVXx+tp1CyBf48jdQOucGw6FZujk2arEiEHU\nMVXYa6Tl9qhvXqJotZFxEneLbjBpSAT3rDBSyOSGvjpapkbZEBf6V5kB4InhI+beRNcbwW6K6QQN\na3Sb0cD+s48N30h0vVPexesCkjCtuAsoJBeS2px1hxBMdYQJxSB7i25XwnGSslE8kJcqA7QgLiSN\njtltZ+jAKaEQIUY8Yzg3u5L0gCBkrWRdWRQ8w5oT3Yx9DA6l8u640g26dZ5uG0mE0+P5f/Pd/u/n\n+qaLbhYJa6IZe+/svfO4Vn54OPG8bXy53/nT9Zl/ujwFfCQn/o83b7m3wcu28cfLM45hyfirw6+i\nazbnp/slmj7tHHKmGNSRgCu3Ntg1QhCLZH5//ZE+Qj6T68DcOZQrKs6Jxqd+5Gk7cSpXXvbETVZ6\ng0M68YdL5TIL8VoHn/uR1DuvDvpBZpHOsXS2ccRRLndlSGhXB4m2CZolCGM9jrhppkJYl7kZh0L4\n8/FZz4a8nnyDWSuT20DCPAT6WYBlxBFY9Wt+mveYDWuZ4HNXpM8oGBmQZnJF8q+plLH02/BXrsCI\nhF90FkZivIBJmCdmB4gK5OgghxkJDULaCDu0lmlpnskX4mkKqYAaRdVF8VboBiqKmCN5TBt4Ag2p\nVxgcQFtQ0IYoTomxzZS4ael//jkSpweaTjpbPBxijjKXf2PSgCy63hi3xEhHiOLPRGoyZ+BBuNMp\ncZx66fn7jMHWdKGljPR4mMII1V0xdofRCyqFJDvVY8dBUpxElhiJSNRzum8kKVSpgJKGY77xy+4s\nmkgqvMkH3h4OiAv30XkoYacuSckq/2kD/laurMq1NfoYwVqwwakW9tF52u58vF0porxdD7yrB5o5\nX/Ybv7t/pNmgpESSzFIT3uFpu3LrLbCDvvAgD3g37n3wYjvHVRHtrBLaSNtOaLrRTdhV2LtStPC3\nT79l7xGvUmvjJ878JjcKzoHGl7Zy1cYxN15aYZNK30O7ed0rm4XDac2dSwtd6etRHnUSwcjtO6DK\ntkkAt90RV7pNzW6PrVXSiAhPFgs2fYVz43OjBaOHbEy8xxomxUwxHLwy89OiwOkyWy3JWBdUG8kn\noLwE1Qzh/xXpHop+ojKNgLygBDx95qiFSi2WUq+UMiVcWSKQPYcpQQVWn0yIHMGWZshIkJSUDevx\ncIic8lgO6swXUvdYD3
bi6K/z65gzhqEpEjR4HXlEfY8Zq+RYgul8jdcB99QOS55WXQLSnmCiyrDh\nUSBN+EqLfyXFy3w6TtVGcpnuvIT4oKYo9okA1MScBhJGSS0eal7ZdIUusUSUzmY5YDskqkKRzKZ3\nek84hVIzRXPAjNwoKdHU6TgHTRzSAdxp3rn2HRvOQ6ns6rxZ/oxW/XZK7jdedIc7b6d6IfKZjN8/\nPfHcdj7crjyUyloSj175u89f2Efn4+3KMVUeD5VTrvzjyzPX7c6H/ZndjLUkDrnSxs6+OXc6t97Q\nDHs/sXCODbAbX3zj3XkD7VSTcKNdEylvDAuYVxuZmjL/56dfc+8ZcaEsO7/0E+/TM5UgPj3vCwVj\nyTu3ttI9c9kVUdhbOJuM6BiHTxLXkEjP1UhhcCIaRkbc4O4S9lHP856eSQ/SSHO5gxHJstLB94Ci\nSwQ2RlcGYDF+mBpZlyh22mJ5OAjdqiB0j4+R+b4wizTdqcFVwLyHESItKMFv8LmFd3U0v4owFLdB\nNn1t13EN91NqIScQAx8eWMHFcHN6j68kHeC18BnewvbrNrGJwgS8A+aoxuyzE9I6NBQfwejRSXL7\nc1f6dWbg+rU4+6tGTgD3rw8M5r8O+52E+HbYP3uQRtpvJl4zDXkXvYEGfvKVkyySA2uZBkMV0wwz\nYbh4SAdzckR1zu9rPKjcMLkiXkkphx68gxelS8N8o43BSU4c0xpjoGxUKfQeI6tjLrw5rPTuvOwN\nSRtZhPNS/zff7f9+rm+66ELAzHeg9c6lD9ZSeDysvFkXrq3xtDU+XJ659J2M8pvHR7YewJDfffnM\n894YGL89vmezRnbl6X6jW2LnjkwBeDXFe2drzi4Ri0NKfLk8UkRorUfWlnbe1B0wiinbntj6AVIs\nau4Ce8sIxqfrj2y9ROji0njqlYeyUQgc4q0llhL+/WaZ4UpvMSv04dgkd5EJUI04yviKXxRLgZgc\nHp3TdECFPjdkSwKoR/pGEZms1liSBfnLKWL02RnG/DU4q6lYfHGZCzmX2U33ecB3SlFGJyLdVTCf\nvIexxB+vRTFwD0tqytP19totSgrNbBUGTiIcZzILWkpz9GGGtIDEZNfARMY3N9GXKcwMEwD8ajlm\nOD1PfbFqyLR6mDJG9qmJFSCRxhzR2OzEX8EUr8vIfz4qCCFsjKenaEQERMp0B2awNs/4MZMevhBP\n9BgPCUJOr9jGig1hSTvdOpLApMZphIiSV2vUdKf7SvcYBZgPqkbm2dfvQzNLiQfxMBhsZK+sueK+\nUFFG7zRxbiPkGd8tJx7rigP3vXFaFsziYZe+EY7u6/VNF92iyr13nrYNFed523h/POI419b45Xqf\non7lN6cHDrmw98bfvVy59sbuRsnC27KSLMYDv2w3fGZ5va3n6CBHfG7JSpdbEPfFyL1iOrAW44WQ\nwCsfnx9RD6H8EKOnxpvjM5KcarDtld3mPSvOhjBaaE+fLqcwQmDkYtw6LDmkXWrOPsIa60IcCUmw\nTyWB9SmRiq4x0g9ivOBf88hSzFVHkM2YFmBXMAuwgdhrxE/spC3n0O9q4ARTsvhUkhg9RhSvMe6m\nIB5QRMGhC6pRaM0VM0hpCZnXiI+1mXr5FVloEoWmO6koUhVIRBhDMGtF4rWhFmMENJZ1E+KNQROd\neluJyCETcg5X27CI4BGZZo15hXU53Hmv8+UxqWr2tZXV+fPIfM14g/geJMYkHl4G3BxNjow59B3x\nMHMsuBEWP7eYoa/gN0odeFNUEkMqYj1quCokIc/kCpdXE8c9unPNdHmMzltDnhfPrxGpEnIMWJE7\nrcWoaS2wlgXpmepRVLepvT45POganF519tFpNnizLGx98P544FhraIX/Fe73fy/XN110hztrzpxq\nwYEsyu+fn7lsGx+ul6+Omh+OJ/7h+YlPtxs/Xy44wvfHE+/XAx+uL1y2nU/bnZd9oxaN
HKiRGX3Q\nrfG0j5lHnnmbHhgt/sguWyOvjZtsZJTmg9oXzGJDv4nhKnRTfn56DC2qJ4YYW0+8Od4RcbLDti1I\n7piHRKtPAIt45nYNQ0San0+GIaF1j5ufwfAcbjEPFKGN+LxuhhOFWDw6bFUja/j3k1kI+luMZ0Ti\nGP860gwLcOShJQXNzkiCW8Y9VBCYQ7GIG4dZFAXfnbTYBMUEm7dIxscgidJFSXNssWqI7aWBliia\nOemUWSnSwzkmw6EKkgpjONo03rPHQjSneAANgdSMpCm6/MwktskEASnJo6AmizEICqSw1zL3X6nF\n8fs1lidmHHO+7p3AcuU5dnDSiIItjD/L3MzDZt18jiASokLyMkcxQpLONrYJJ6qUYtioiBXQCyl1\nbHc8LUEkG9D7wFVJUjBN9KF0O2JcWVNnCIglRB4ifDk5igbIp4A24SBnFqu01AK9KfCYT5y1oqKs\nJYBS99EwN96sR4650IHr3qY+PZHkP8cL38QV2/L4521EhlMSeDweWWtm74Nmg98/PfHp+gIkvjsc\n2MwoSfnw8sLne7DFvqsHHmulSOLeGq11+pxD1gxZE3200IR6C+xfDaLVMWXaPmLJ4QNd7tw8rLNt\nQLUSsh8x7u6R7OqZT88nzBQ1YSRjGKzLHt3ZCA9/niJ2x9g95rVYjlhwDweaScwXBKUwJpVr0G2a\nI0zDHmphDIBgvUaHEoU4qyEp5q/hfpIZCd4pga+aZoOYz7pZJGO4IFkYc7njBlmcoYN0mPNTCzaC\nKlg3TAyVTHaFblTmkV1n0RZhEUFcGAO0O5KdJBqqfgF2J3s442QGt6kruxkyoGrMYFUdmzKtLtHx\nBvid6ESZJDTRCMk0yPPt9ARjhmomn/KvPKcDAkkzw5ws8zb0cEnGKCGTCLh3ckHFMTGShLZaPX4n\nzWKUkCRTUqZYkNVsgGpjN6NoIdmJnuaJZWSSvlAS3Ds0fQNDKLqFlloWXBaSxCKu5ox7j8VjctQT\nj+nE7kpzp3qiqLPWQrKFd/lM1cJujTYGKSu/Wd/wdlmnYVFY542Xk8Tv5Ru6vumim1UZ7vzh+Ymi\nmT9dXni7HjiWzLUl/vaXn8HDQHFeDrxbDijw958/8+W+cRlxDjzXNUYGbfDlfmMMoAtv1gVQHsfg\nvjeMyqd2R7PQMVYSzSG1UAwk4KYDvyeSRDFVj1GBLRstDrT4SFOWAwjcPSAmwxPjcgztqsdRvVkj\nF52z1GkWYC6dbMqWPDqq5AMbkSTRSXPe52QiejxpsGNDrpBi7GA2tbY6P2eMIQQjyyAlxzyO0TZi\nnqopOuIuUdRtvB7DI504FnSKjYTkcI2Bx9FfneLr/JjY3r9uv5MHIzlPOzHu5BTFSWcRFQPvIdZK\nWYlgnkyOFpsyQgGQs8bs08Hm+EFtxqfneADsU7JmAnmaOFwg7cIoTgFIkKd6IavQ50F6NsMzdy3+\nuQGLhPwsOUF/m3+rblDR+B49FqQJi8WYQxWZduaYsyJQpISEzTO1wBiVrM42oOiZlOBVuhz/UxCE\nIsqaQjrXMdocT5zripAYnkgcyRo7i5qcVU+8TWd2BjuNB1lY0sKaFs4l83hYUFFa69y7sabED48H\nlpzpY0yVyrdxyf+HPu4/9KjF3fn90xfcI5rnedv448szt9b5+XaB4RxL5c268KfLC7fWeb7eudjO\nURceD5XrtvGy3bntxvN2R1Q5L5UxjL03BOHTy0bOsdiwNBhzLnfZO4LxMna6OnfriBhmsCBsbnSc\nqzSEWH7kIVMdMOi14wSvFQvdsXfBUp9HVKb0KrpYtR5OLyGg327YiAQFm4s0HymO96/Lrthwxdfx\ngo0A1WQZBE68M3rE/PpUdrk5OUWku07YOE0jxYGOpRRz7xEzSR+xfbegbU/4efnzwyMTyzSJI6tI\ndNliry6yiFjPM8PNJ5Qno5FoEd/ClDZFYFF+NUMQ
AZChzohia95pqtPNFvNvHUREunhgJXZIOcwH\ne45lG2VwlwmNmeObakpXZ9dZZGUuEyXed5vzYyFUGCsSpxmgiLB5LChTtN1UM+7Ee5jBwThQtTC8\n49mwCYY/itLEJ3IyCvehwm1C447Lwn1spAS7xYPkWBa2vrPWErFE3qglXA7fr4+4OLe+cygVMfj+\n9BDsYFW+ryduY+OQCu/XM2+Whd8cH/nS7iTgu8Mjb2phrQsqcFqCpSsIp/946oV/8RnyTXe6cQlF\nlfDsRCz0UpTvWbm3waKFj9crH683Wh8ciiJj4aEu7Nvgcm/su3MumUWPSFLElaf7FetxDD6ssehp\nBvJK0Bpw1Mx9NA6pcG+Ns6700WkiXK0xstNscBTFR6V64uYd1cHFhNQzinAA7mMedcsWCqe5mIqx\nQBSS3RMQOt3cBu4aRVAiq9uZSbMkqjT61M7uI0d8uTtVCQODhxogDK1BHdMpnk3ZGT1meT4ivSGC\nETtDcrjdSIFK1CiQnhUnFm5Kim7NLCLjzckcopPrIDZBLhqLrq/wGw2bbBJHU450Y5ia18hQK5lJ\n+ooHRZ1rMNNYrDZximdUlG6DJSVco9D5HFRXh1JjjKApY9KQFMWsVCgudJ//fZojayalS5SRPDr/\nWGOyoCHCcJCcKXts/LMI6s5BC10Es05OmRppoqypTFnX4GodF+EgBUmhA1+XjI+OsIUuXOGQDmSM\njUG3eAdrily3JRcOZXndfyIoD8sj57Jw3XeaG0kSp5I5lMIqlbflARXn0u6oOu/Tid8cH+dCUtGk\nPLDyUCqnWjmtC4/LyvN9o49BTZlDLf9qd/u/h+ubLroiQk3K//j8maLKny7PLKXw/XrA3fnvn37m\nab9z2Rsq8FcPjxxq4Y/PL/xyvWDdeGkb75YDp1LZvfP55UKWEjD06cB5IHHrYWt9bndwuMugjYHg\nrJZZdCbzJiXZTgZqz7zRAwnnljov7QYlkH/nrEiL5dg2tZodJ7dKUlhd2Q0s3TEJqZeKUC1Fn+SK\nmWGaQ24lmdwHyRNJdzo5lkskagor6+KdMUX8w2JkMNwjpdcbJqEuiNctZrllhGVMBt4Lw6JQihtD\n59hDJVxZEgXWX40VRWFUssV7HyMKvxKa2DoUzU7fLRxoCpKVIjLHMJBzpAt7SSwpzSWY4218VRMs\npdJt4BJH/UpiiLCshWzGnpyDvQrZIhsNhVvbA/cw2biSlYJy8swLjaYRZZ5zorgE2Fwk1BFEDCkI\nx1zYrdNwkgmeM2suJFOGD9ac2dzYm5BrYTXIOSPJue4dNHOuGRvOuSwMN7besDaQpJzTEZHMbTSW\ntNLHxkmjJU9WeMwLdx346PRhlFw45krz4Co85Bp6Z83z8514v5zY+obboC6ZhRPnfOBxWfjx/IDh\nfLrfGcN5vx747hQUsTAhGae1fsU6fmvXNz1eAPinlyeetw1zuLadP11e6N345X7lvu0cy8L744Ev\n25VbH9zvnU+3G0fJPK4Lmw2e7jdkCH/68sRaC4e8QILb1qia+HS5stRCotMdLn0EbWm7oT3RE3Q6\nL60zfGA4b1JlI6RBN9vZadxHZF1VhSyJz33jZi34tETPWT2K3uZO1z0E6uoUByQKNjS6BF0LEqt2\nRlOSNloXLAc/N4uA9xhbeGNIDi6DJTQ0Xyh7RJwz4TEenaRK2GRF498LYTMWDEnRljrG6AvJZ+Sl\nMBMkQJLgbYnkiMmqMCmRbafODPaNbtqEJYdcqwCFGFkMjUUNUjjlGJuMEajJNS9cx8aqma5h+lCT\nSPUwIx8y9/sWmmWDKgFyGYTEdnNjDNhSjEsEY8+DTTw0xSkUDYWMJKFrLDdd4ntdLFxrXT24XALa\n4ZALNzdWSWiCrTcOKEhmCBw9/l4KSse57zvHHLrZZkYWZ/OY99ck3PaNw7KyaOI2DE+NWzfWUigp\nHmQP5wOX253G
ICXIWvnt4cQX27nujTXF9/Dr05k2nJoS3x9OfL5vnHLmfFh4Vw58v5740u4cU+Fx\nPfBYCseysOTEm/WAqrL3zsNSqaX8Rw+i/M/xwr90uTvnunyNZP/98xdAeFziyPOmLGy98/Hlxr11\n1lQ45MSxRmd633bcharC9w8nRJUMvOyN5Bab3nVlt4aTadaoQBrK+/WBPgY325GR+NWpcL83Dlpo\ndIY6n24bkmCRhR9L5bk1bAxe5E4Wp0hiVaFO88KnMYKbAKxasS5UwL3RPF5zqRy845bQZLSe0LTT\nPQXRyzNHbVO94OxD0JQZPVOSY2aU7AyLccEYEs4rCSeX6sYYBYgZMczjvXbccxgGupDJaI436zg+\nF2v4SuqEYgCLkY0qdRZA5jJwqFNcSUWoOcMYM6dskFWpKiQyVTIyRmz8U/ycXOGcjiDOaM4pZUoO\nXnJZCvsYFC0Bysmx6NMsPLfOfWp0l5oQEjsdlfhvWwp0JQKq0eF2iRltSrGwax7qCE2Jm7VIpB5Q\nspBLZh2DIkrRKKy1LOQUOuSqic/7jY5xzAvHNXEQjSgjGeE6uzfWmjkvlaecEfOpYoCSVpJ1zrmQ\na+Jp29n3RtLMY17JNXHZd25mZFfeLitLLqgLx7Ri6jQGrQ/erivvykrN+evuA+CUKqe88O545O1h\n5WXb2cbgkJR3p2O4/77h65svumsu/N+fPiICz1sc/X/z8MCSE//Xzz/zy3ZjdOM+Or8+v+GUE89t\n45++vJBRfrltvD+sPNaFI87n6y0Kr3S0FHLKLO5cDbwPdBTICWTQ3NjNqCinvJASPBwW7mNER+Gd\n3x7OmBtVM/feSbnzMjo1JRYKvy6Fazc22biOnTULW4e3RUgm9GJ83oGUMTEWqZgpGaFzwYawO2Rb\nWcTC6qvObopIp/tcXHnioHtwEZJxH4pQGSbkOZ1MqTPcGB6UMyXYrkRzDJRY8rRCABMtRhSioAm1\nAyphbBjiaGGiCwkpVAon2ZjGheME84oIYx+hdU0exU4yS65ID4nbDiwKGYWcp2IgqGULTveYbZeS\npkEvURdl1cRL34MUhlFSjjl0j4VVrcpFwL3TsrOgVEsMN3b1mKWKsJJYHDwlhg3anDGfUujEn5qR\n3VhV6b2TRMkkHqrykDPb1AEfSp7ZZc6pHLjbDRVlGxvJC9+VladJDSMJRy/klLjYjonwsK6UPFh8\nsI8we5zqwnVvvFtWcqkhNyPUHm/riXMp3HywaKK5kzTzUGI88JenN+xm/LLdedp2vjue+PF8xsxJ\nEmaPQymcl0pNaaY4f9vXN190RYRjqZgbVlaaO3+6vfB837hsdxYpvDuunNaF69jpvfPxcuGQM2/q\nwptT5em+sY3Gh5crVZRaM0s98Nw2coLPLxuLCEte0Qe4NaMPA0v8xanS90ZeCvfR2M1pbedxLSRf\nOZfC533n7gHmOUvhdFxYJW7mbdvZ5ca9O4daqV7Iq9B6585G25U1xSjgnA7YcJrC522QZwrrWQvD\nMsXudHZ6d7olSiqsRJS4+2CkKCZGCQCMK0tqbCPW+/soCB662hTdq2iPoEoPPm0mSJGNALoEsDzB\nVpHMXDABKJKmFMqiQA8IIE2CRZU0puzP4rWU4FxWzJ3sQms7ahnNzpKVNS2RWNCjgG298VASqpk+\nbJIdHOnCoUik14qhEpiu7plTifd8oyPmbN7myEMQjZl2MBiE5I2cEosrA6NLIYtzUKVYAo3FLZ44\nTHhNQslaOORKmsuoh5oZW2PvgyKJU1Xq7BbFEqecUQJAfigFlzPDjXtvtG6cDmu4BD2CWPvotCRk\nSbyTwo/HEz+lK5fRUTfOdeG7w5HntnFcMsdcaC040asK57LwuCxsI3LfllT4y1x4eziy5sJ5WXBz\nLvuGOzyuCyWlf4vb+9/l9c0X3eHG23VFRdh65+fbC7fWqEl4ceFhqRTNfLze+O
V24SFVBDjUQloy\n90snz0ju7+oRm53WrW1UTWgnoDqaSQk2M9ZMpPYeMqdc2dbOx9vGmgpJOj+8e0fbB2R43ho5Q78Z\nvz48kFRZkvLhcqVLCPC/zydY4KFWbn3jsnU+jei41yIcJGEi9N64Y9hwTiWOxuesbM25K7yMRBoZ\nTcIxFWwUSrpy9x4R2y2z5FBF4NDNaRLAbacgrlQBSzFT7ALSIvoFKeGuYw8G93jVVoT06BUbgEdH\nm4lsL6ZqQjReKznGNzridzdskHImeeaYIya+eDAgunWOpXBYVnof9B6q3GSKV1gkITmjCU6WabvR\nxVmrclOjNyf3FDIvEo8pVCUB+rKvBLU3acHFuNngRXe2CdUpXjl6oefIDyvqkW/nifOhsvWOZSUn\nZaDoCH2siPC4xMjrebvRLHEulZSd87LwvLUAjGWhduVtPbJo4+o7jQiMPKcS4P22c1oKugsuTu+G\nqPLDcmCzHikTEr/vkTKqsWw85FhyNev4gFOprKWyaOZXDycey8LPlxf6MM7Lwl+cTxyWhdu+s/dO\nzZkfHx5YyrelTPj/c33zRXdNmd89fWGYces7m3V+c3rDean8j/SJn68vlKH8cr/xw/HEm7qyj3Cp\ntbvxS3vhTa4c6hFZ4Ol+o2hiN6Fk4VQXkmSetpgJH8kc1krrhqjQGYgqbw8LGDwuDxRRPqeNX9qd\nd+cDbs5vzo9YN24MnvYt8IGj89fH9yGf8sFPlxuWEqad3x7POJlHzXxqG812nvqgLGG6OGrhboaN\nzmaCJjhQqEvhkGJE8ULnxTIJSFqpmoN9wJWNcJP5SFQtuGesGG2M4AII6MiIriElkx5PZP+eAAAg\nAElEQVQyqpZwWwIcDhE1PmegqFBfeQLmeBrBhskRR5kmDnHESJiUw9zxIAsi0JrxmmhRSyXLQpWC\nt05BaAzubhzyylIyvvcwFzQLXW6FREJSIpM4SgC6YxQAm0RBkwAZhGojwQ3D3BlZWL2QhpKTsIux\nTzvEqonFQk/cNeRakjNnhNWFpsKiGc0Z6Z1DUrbhrFp4SIVdXpUTAIMlV0rKXPKgY+SceXQhp8yt\nv2ApU0rmrcIxLdxzj7lxSZxF+NV64snufNo3vuwbzZ3/en7g7oPdjEWF3TzGZnUhIfx4euA6Gp9v\nG0Uy358feLcevkJrzrWCO6e6sJb8n93tv3B980U3aRwXRSOqZUmF53bn5+uFT/cL1TSe8o+Z+xhc\nu/Hz9YVVEo/rwuOp8Pl6pfvg6XbDmrNK4cfzA7fRIptrNN4sK6fcOdWKSCJl4cPLFVQ5W+GwJq59\nBxGuFtvl77JwyImMoLnwh5cv0JzHsvLr05newuv+ed/YzNGiFHf+2+P3iAvP+8bTfieps5nyXx7f\n4j1RD8rTbUPZeDFYctgc6pLYbLA14+5QUkYdllQpCjcLw8aNhQTklF6phyAbNgibLweygk1hfpuO\nN5mYxNe5ntuM4mnxe3jdZruAlzALxNz1zxljTV/pWeG0OpBo1tE5O9YUNuQqGcfYmiE5SFxLrZTh\nFNWvxVNGaFYXSWHnNdibkzLs5hEc6RpAdoS1xLLo3tvc6Tl3G9QMwwaqKSRnk2ujqrH0G45kDd3v\niJ9ts8Fw8Jo4dOW4ZNaSufXGrQ2WklmPB05lYWsb7sZwWFLmbVlCPigLReGl7RxS5lQz97HysC7x\n3obRaJRpQFhr4pPf+dRvOMK7ZeWUC1/GRlehEoN0ScKRhWMt1KwsWnh/WMmbYtk4lcR5Xfnt4yP3\nPrjsG613HteVU/0PZ3T4X3p980W3jcGbZWVNiTYGl7bz4eWFqpnbGPx6PfNYKx9fbvxyv/I2rbgZ\ny1JZlsy+NySFNvQxr6Rj4pQLA6dK4bbdOKbl6wjj0ndcDHrix8OZgXEuCXPhuFT+dHnhsS4soqgK\nv+xXJEVk0PeHI7fSIpvNoGfnH26fOdbEsW
f++s1bLrc7AB/uO6qx9X/UE78+rxSHn+9XXlrHZFAk\n89u1kr3iGJ9b6C434JACqL0clOu+s5uH5lQFtcQxRz7ibrA7IAuiUDzYCU44ydyYCMUwhQwNatWE\ny4bDbBb9V2paQGRSQHlG2GxNBElhj5UEdSglJdyczQKcc6oL3YzqiWaRwqEzwTYVIUkil8G+98nd\nMFJUS0L0pSidpIlbi4ihlJ02YlaiVRnd8QSaIw9uuPJGcnyNXLj6Fh1xgoXEaoon4aKNHtHJFDLV\nZboClToXXTqcVIRTLpxrwUXoPXByyZxFMm+WhS8WOm/1wkNdOKX5/cZBgIdD5SCJz32nC5xzPEzf\n1oVUMy9jUHLipTce9MDDUul7LL0EqDlRNPNQEo/rkZKVz7cbP18vPB4O/Hh6ZJ3z/eFOzYnz8sDD\nsvxHl4H9L7m++aJbU+Kyb/zcG92M533jtw8B53i8rvzjyxdul8YfX555t6y8X468WQr/eHuGi/Pp\ndmPNidNpYT0VLtsWJgUXjjnzvr6jpLB8DnVKz5gbh5x5PFZuPSDnJReqKH/95j1DI+akD0Oy88t+\n57+8eWQthad746XdaRqA9N88PLK3xruHI91hpfL3Xz7xeFiQUfmr4wPPY6Oi/O56JedMssFfLA8c\n84E+nD/en6bw33msB94R7IfuO88tuquO8FByHP1TdPxdLfSgE2FbcwRXbu7s3ckilJn4IAkGwWcV\nV9xi+VQ08JVDXj2/QtEUhLP4v0hyxhCqhNU5dehqtGFIFpZaIhLIhZxmLLiMsJrmBesjKG5tZyPe\nb7fgHvw/7b3ZkmTZeaX3/Xs8g7uHR0Rm1oQCihi6CbNW9wXNdKEX0I0euR9AuugLtSSyQXBCoVBD\nVmZkDO5n2KMutmeiiRYpsptIAlXnM6uLqkgrS/NwX77Pv9e/1qgUplimlNFAop3Oq8BgFEuq+ArF\nVNaU28JIuXSMUXBGES+rxIkWWqRFMLXZ2oK07IWuajppNrpIoRiLzYLRQqc15wCxCia1E+7eN6tV\nqQVRBWc0Tpp3WowwascpBNYC+8PAvh9akak33M8Tk1JoUdx4w20/UGrlriyY1TBozSe7I1/NJ7LA\nqloe77VxrBSMNuw7yxoSTjWb2m03cOg8Tmv2znHwnrt5ptbK6ByDtZsz4Z/I9150tVK/La8WLmJT\n+OZ04vUlf8EaxcdXV5RLPsLTsiBJ01nNZ1cH1lpx1hBzJteMZOE47Nh5g2jNm2WG2k4+h32zhHmn\nqQW8awLklKLzDm8Vd9PKKgVjFM+6A4cy4hRQNUYMWQqxZK5tx955vl1OLCEQLjXmP7m5vtjMLEtK\nkAzfPp144Xt662GoPM4zYoTX61ML+EmVH10dkCyEEvlmPrfkLyrXbmwvkFJEFZiSojlVFb3RbaQg\niaW8DWARvBaMqHcFhoFMLu0SS4m0vV2BIPVSO8Ml6Uq9axkvtSK6ZSi4S5ZBy0tr3tmUKv2liUKl\nykqkUrHGYEWTVQupt1oTS2G99K+JclAyuihCyBTVxDvSRhhWmhhOa2tCjheHRDFtqWElYy/NGYnC\nmYhSwlozTnRbfNDSPLhag27Zvk6p1jWmakv4ol04llwZrGHUis5ZHpbAktNF+C29dUQFORWmXOiq\nYdQeesU5BpJUYi3sdUeOiVKFa2+YcqRWCDlhtGaops3JSyWQ2Xt7cU9ICxrSgkVz3Xe8GPd8fXpi\nyglrDS92I5/sr3gKK3OO9MXwbBy48t0mtv9Mvveim0vh4BzdMLbgkwp/df+KnXjerDPP+5EbOzDF\nwK8eHtkZy5oz3mieHfYYZXg9nzgvK6Px/Ojqllxza0rVLZqvV4paC8fR461nCYGHsNI7x6AUz3YD\n8xrpO0NMlau+/ZnBObyz5Fp4+XjG29aO8MPDFYmKVcIcAs8YeVlmeg8701aK/+7hoVnQauuq+uz6\nGl0q2l
nCunJOgVOIfLzf4YpnlcC8JiKFhxBwSpMRPhmvWEJhlcjdurQvJirHbkBKBaU4p5WVlvZl\nRGFNW2/NZEK9BG5Xwaq2lqsuzoZUK7W06EJDq05HFwKt4LHEgqnNkaAvOehFAdQmbqqQS235w80g\nQSrtZBmkYlZYVKJSUFqwRqOkXRBZrVhzourWHzfa1tXlsiak9rppDdRMKYIzDkoiXEo3q2mVRJnK\nQKuw7y+ZyKtql6ODaKzo1rcmzbmCCL2xjMqxxLmVRehW6WOdRdXmHvBoxHBZZdbYrNHG4IzlPs2c\nSVgNOww75TmTONWVzhr2qufGj9R45imsvE4LORV+fHXLXCNv1rnFhJbCwRisc2RjOA49MbUR290y\nc911jM7itG3tJ8awlxZTeuy6zXf738n3XnSVKELOvFlm1pz5+vzER/0VN0PP1dDzxeMb7vPM16dH\nrLYcveOD3cDDsnAKgRIjo/WYrnAcBq58T6Dw6nwipMxoNTdXV4SUWkJVTmil6JSlV5rR+5YTK60l\ndXCOo+mIuSPk2nIDChw7RxXBaYMzirVkXj6c6DqLN4of9x1LjjhxnMPKR7sD30wnjtbiqsGowt89\nnogxMpXCB7uB61QYxTLVzBLgVFYShY/HHVYbQi48hpWiC9OaGYwhVc2HXcdUKnNcOOXYlkFqYXBd\nyzZTlTUnIplaWzi8EUWS3PyxbQXt0iIBIO1ELKU9ceSKru1GvFCoRlgAo9t8tcuaovIlVrFeOsza\naqwTCDljpDDrirOKJRf2zlMKpFBQRBZVENOabZVVhFrQAmtNgFClbRMGEZy0eX8qbTSgDeQibT22\ntgYFqXJJhW9jFFsrobQLP6OkZfxa3f4OKTGpxGAsRhl6bXlKMzk2R8tgLb3zlLQQVGXJiSUEnvUj\nnTM8RYU3ipgTIq0JxFQYtcUbw31euI8TuRZG5xmt5fU68yY3Z81gPb04ZpUouj2NaKX5oB95DCun\nGNFKUFrzw+MtCng1t7aUnbUcu/57mZnwL8X3XnTf3mQvKTWDu7TsvfMSLtm4lSSV292O3hmOxhNj\nxSlDDoWbvuPZsKNQmp1mnakFbvuBGAvXY48xBmMjd6cT1np2VnMcR+YU2wWRUgxoJLV9eas1Vjvi\nulBqpneO2/3ItEZCiu1CqMLt0CMIzjRrTlwqb+YzXhuGzvDj/oYQCkrD07JwO/acw8rt0F8Csyuf\nP7xBUAQp/GB3ZCmJnTY8pYJSkTfzhLGK592I1ZopB+aaSTURVMVdWgyuteZUM5HEnNs81SiNVYoq\nQjUtBDtRIbWsWLRqj/aqtEf70vy1WrWYSW0AWnh3yS0Hzooi6pbOlVRL7soUbNUUaRmJxTT/sqI1\naFgUS0yX0t2CUpoqle5ymg+xbYDVWtBav/siKLQksSVnaq5UVzDFEYm4CjFHsm6P6pU2m+7E4NGU\nehmpqNzi4Qt0WjOXDEpjlbBKxQl4JUwISjTWtNqnWNrIYsQwKEOwlrUkJMNV17E3nlc5tZN1LWij\n2ntOKhIDicJcEtd2YPAda8mUKgRa8JIylaPtccpwcI4368RvTo/svOez43VL0SstxH80lhfDjufj\n2BL5ttPt/xDfe9FNteCN5UdXXUvpF+EXb17RVcur5cxeez4Ydzij+M3jE9+uE7ZqrrqOq0NH71pw\nRy2KOAWqaK4Gx+g9c06cY8TmjIhwM44g0nyPuonteV1xWjOOll1veVzWSwh1ZDCGkKG3LVvWG8MS\nWxPvvuu53Y08zXObjabK6A2at6eQZoma48ISI85oDn3PuvZtrliF+7BwO+4IMfJCjaSaWKrli6cn\nlMASEz88HpiXNoN+EwMOxeO0YKzmRg+Ig6kEZkmXlKpLroCA14qJRJH2WG5EYWlimpUgupBzu0Ar\nUWEvl05ZKkYVYi2oai5ZuO1Uq3S95OBWQF3KLoXw7gTa4hoTBVs0WUWA
5jhQXHowm8Ct6RJSrluE\nZBFFrytrKUgSgiqtzFIE5xUpCZiCy8L6th8ewYghkNC1jVSWKmgt+Ivf14iw1MS0BKyzOKvw3rHM\nC7EmlqywxrJ3lqIqa0k4bZjnBa0VzhpMilhtQSopZoLJeGVwVA6dJ5TEy+Xcxh3G8EG/54vlkYUE\naQEjPDMd59oaUpRWxBS5dT1OG3bWs+/afHb0jo92e15NZ0opeGM4+O57n5nwL8X3XnQVQimF13Fl\nTYlv5zNH57mxI8/HgVOIGKM4rxGvDIY2g93bHpHMq9OMk7YpdbUbUdKsXzVXRuWIMWGMYnAWqw3T\nGpnWAAq8ac3DUisag7XQp0Sp0BmPNYolJKbU0spKLYzWkKvgTfv3zlie1oAzwuhGzACvn85Y55nm\nleuh4zS3R9ZQKl1n+PYUENolzW3X8bQmCok1G2JcuekGSik87wxzCuATX85nDG2j6qPxwDklBm94\ntc5YDdNacM7iTdt8WqWtIefcmomdasHlzkAorQ8hUtGqhbK33ZCMMamFyieFQlNTRqtKUQVjhERC\nFd+ayIFqK9R8qV16W9HTTs1BcqsOVxVXFNGAiYooCcntQs4aAdX616QUlkRbZTagS7N1aRShtEWN\nlGkxlUrhlWatCSMGVw25RgLgdMEofQlTB7RqPnBtWstFfnvZqDBV0yvDVBNLTuiiGURhlcZaSyIz\np4xTmqN1nFIg5khVjqgKo7LoSxmmVe0ycC2ZIIWddyyxeZiXnFik1be/6Ho+2B34zfmR+7SyF7ga\nBn56vOEUI0uKzDFx5TteDLttyeFfmO+96L415b9ZZjptWvSc8wzOEXLm9Tzxcg1cdz3/9viMSjsh\nCIU1tksqpRRXfU9nNSEmXt/PdL5detyOI1Wgs6151TvFvBR6Y/HeUmthWiNzXFEZOuuaT1UEpRS9\nt6Ta2g4G61BaMa2RJUSsVu1W2jRB64xmjonOe9aY2I0eJ4qD73laZnojzDHyfOiYljY/PqUVrYX7\nU8IZxU45nl8NfDufaLHjhhAr18ZRRXFlYK7thPj1OrVb76q46fvWMeeEx7Qg0qIPO+MotEaMQCbr\nTI2tZUJf5r+9gSQBqRchRlNUxUqimIpVF5tVESiaWjOiK7UWrGmzYJUdpQAI1VaU5JZs1kp9iKk5\nJFYSWjTVFGxVrLpV9BQKhdZLp42hIlwq5Ii1UGKiGrkkjYOlElWhpEoogVSbi8KrZu3iUteepJBT\ni0PcW8tjiuSciLHNu53VaK1JMdJ3ljUXQm7xnoPRxFLxWvEQF+5j287bu5698YRaeYqR+ZyhVF7s\nrzjlwOtlYiJwSoFr36MvCyOjb1+8U0k8rgsH22ONYI1hZxzeGLw2zNFw2/d0xm6n298D33vRrZcP\n0U+PN6RSOXrPl+cTb5aJXOBZN2Cq5nbs8dpyioE3y4RVhhvfc319YAoBRfNu1qoYO4Mzht4btBKm\nJfLwNNE5gzGa3SBosTjd8rm6DDEn/CWVLMZETLltWmlhsL5dmqi2Odc7QykFZ9r/r/OGpymwpoKq\nwq5zQCt/FNMuqURa6ePeWbT0dDYyhUTvPbUEnu171iXhOsNdWEAbHpcJqxUjllvf881yQoxiDone\nalLVGN0Cdc51JRM5xYK6XBztnCfUTO80jzlipbDmgnXtAkqp0NwIkpB0admtCnRisIKoAEWx1pZ9\nUGrBa8g643Um1Uou+lJe+VaI2zggq4Kulnh5TYqpaEnviiGhdYWpLCw5YkRRTCtZnEl04sglU6tQ\nS8Q7Q9Htzi6n1hWXYqLq0mptiGitSLlQa2thgALFXPrWMudaWuGl0mjVSiWplTW1zGIrQtWaqTQb\n27wm9taiRREzXPkWi5lTJtRWm9QbS98ZHtaZ1+uEMsLRdXg0T1SmHPDa0jvDR+ORN+HMw7qy1ITT\n8PPrD/Da8uX5kVOM7K3lR8drvP7e
S8Pvje/9KysiGNUsTp3RdMZwDhGMcHCe0XpiztxNE08hohBu\nhoFSKr13rJe57bSu7JzjpuvxzhNiRhTEVDBKIdZgrMY5C7UyzYEQEtZptFZY61tJoaI11tJ8qkYp\ntFEQFSFmyK2csvOWFAvOChXF0FnWOeN7jQgYbXicptZbBlwPHXfnCRELinbrLRFVhIPvUFJ5kMA5\nrfRGo6mI6wm5IFZ4tZwxTnO/TOycZ1kTh7Hjy3DGGUUJl9XWmloDsVYkWZCSeUoVU9vpdTCaLBGr\n2gKDV5WYKsq0yl2n23qtUomU1GXRRFFVYuxLO+VWYS2KeqmEdypTtMKbTCyQkkFEiGv70kIKuiqK\nARUNubQFDaUqSEZZoV6WMdaSUEVYcmt5VgZM0SySsUUTSmrFTpLaIoYIrqoW5F4LQZofN9M2yTqv\nqUUIKYFJJGk5wk4173CtgtXNczzFNjsfrcdmQ6mRcwyt8sZZdtqy6MpdXHExMcfAi3HfQnBCAA1r\nbk9JR99xLR60YrCOU1j59eM9nbN8cjhwsJ5YMyFnDr7jg2HHR7s9vbabM+H3zPdedAFe9CNfnp9a\n6pPAx/srSqnNEiSC1Ypz1C3azhqs0iwp8bSuVApWDMe+b3fYRjjF9vgWY+Fm6Dh0XQtkuQhmLAVr\nVUu80tIM9EWY1hlBY6zgrSXVhDXtVyTWUmrLddWXGaGUyjSvbVZIwftmbeqswWhh13WElBhcW9w4\njorHZUFqq/m5cj2v5zNWGZIC4zUqK/bGkpVh5+Cr0wl0KxEMMbF3PWuOKK94mSeshlNeuHYdU44M\n1vK6nOisIpaKwmBLRNvUQrbtgkRFqBXJmmIL3laoK94UIhkvsKT2BaR1wqtMqoLSra14iRZVK1Ui\nnS+ISu1EnHXL5VVtTKA6wZpWEROygapJreESfYlSLAZMsoRSUEiriq8tg8Hw28UOElTJrTJeVZzy\nzQlRhZgTSUDVViBpTPsyz1KYU8YowTmNM44pJbQqrCmRS8FohUVzUpnOWkLKUDOm95issAidM5zX\nwmOKlAw74zh2HUESj3FhXTOjcVwbz11cCZKZSyLUyofdiHdtTLHzllJbpc+fHG8viXqt4ePT/RWd\n2RLB3geb6NKCzD/dXzVRFLlcTq08ritahELlo/2BOcV3b8wCqBgZbbv9VSJMIXA3tyD0QVviJcov\nUziHyGlZsSiuh4Gx71oObG6XLbmW1morunkkVdsOmud26WaNprOWWDJWm0taYps71lrwujVZxFRY\n10xWl0ZZrVhTaolnxiFUYix0rt14H1THPCekCKZqbvqOr06nNtIgcdU7HkPkxg88yUKvCl/OES0w\nKkeSRM6GRSLVZO7zikYTa2ConkUlBms4y8xgMiHTLFNEbBeQmuhdICbFGoWMQUyiNwUlAadaypdT\nEHLb4utcYNCJWKFI2+hbL43Emoh3oE2gYglJU6pGVMZQyL5V+KRaSckAqkUcFoVIaZX3riLRtS+G\n0u7Zqmpbcq3IUlFKphSQmsiXuh65fJx0lXcXtMUUpGoy7cnFakGkZWKoWkkltpxe4zAI55yolZbv\nK8KV6/FieF0WOm+JsYXqeG8Yi2O99K1PdaWrFmuEwXhux5EynXgzTwzVcTN0/OR4y1NYCTkzp8jo\nPD+7PtBvYvte2UT3glUaq357S7v3Hm9aaWMTQ6EuMKd4SSFUPN+NxFze+WR1FmqqHIdm26oUznPg\n9TSjpDJ4R8ktM0Fn4XFdWULAKcv10DM630RY0RoVUsXo5ot8e+EXY2IKC/qSw+q0IdcWUpJKZugt\n0xqw0uphcimsaSUlIdTQRhimlUv22uEKmDqTUr14ThVXY8e0tss7VRWjM3xznvDOsNbMret5jCvP\n/cDLeGLnHC/zE/YSQiP1kizWBWxJTLmJetYFIxZMpjeFJAsHk5lLayJ2NuHUiqjK3q2sSTivnozG\n
SGEwCW0WrFQCgq3Cmi2xKLxb6XQiF9WaKKhMc5uZamlfOsYFSnLErJCiWqAOiWzb00wuhVwMNWkq\niZoUSlcqrR/NFCFJ21BTNOtZUqX5cqU9eZTa7GQWwVpNd7lUq7XZz0KKeOPorLk4WyxFKiFksvF0\n+tJ1plqU5Sm1RpBBaa59z71amGLkfl1YBZ53A0XBXBLaCWtua9OmwO5Sa25VK7rcOc/Beb5dJkbr\nOPpuO93+K7CJ7j+C+x2rzLHrGHJ7k1rdig5fT1PLN6A1NxitLrPZFtJdL/XWe+ean7fC47oyhYBV\nlp0fWOLKWjM5Fp5CIObEYBxXXY8xtpn3Vbuk0abN2+oloUtpIYfIeWrzS2MU3lgqtZUVpkLvCnOI\nWGWaoFOZ5ubdfIoZtGKOgc5aelPp7I5X5UxMmWoqtSr6zlIAnSpRVZQVvlwf6ZxhrolnauSxRJ5p\ny7flxGgVU53ahaECpKIlM/QzMVXmrJBqSCphq0ZMYrQBCOxMYi2KXDs6F7Cq5fOOdiVlxePaEWvL\nph1NRJv2RLJWQ66KNbfeNmcjXmcKEIuhUAmJy4k4tbB0F8nZkHJrX8gFlEmUDLrXlNKcFlQhSoWk\nEF2b46C2pLZSK6mAUxUplcxl8eJtMptu4TqqFkQ0lMKacpvxW0vNhaW2+XJRMKh2EVtDACM4a1nW\nwJrbE42U9iRUQuasAgjcdD0fjQfuwsTD2lZ/nbH8/PoZnTH85vTI47owGMfPrm45+O49fYo2fpdN\ndP8ZiAjemP/6P/BsHFtoCE2kQ868niZSaQlcx843/+Vlk6dc2pedMfTWtg+j9jysM6oITjt6o9so\nI0emDFMI5JzZXaxs3ipiSmitybm0nFpJTYRFtQ20aeE8B7zTaGk+3ZYZrEg6UbIhxIg1GqOapWgK\nK1KEpzUgRlhSYuw8VcGYLV8/PWKN52WcsEbjShMqI+0EWE3k27RgtUF0Zpd2nGvk4BwTT1z5QJYV\nLT1et0tJozNH98SShVNw1OrodUJX2NmZnQ4UKfSmMGfNEpvzYk+hSmWwgVSEN9OOglA19CbgTKS0\noEZS0eTU7GbWJpwqLQEsO1BCjAKlCbHoiphMzYZMhdRWd6kVpVtfW6mtCUIrTSoJ8tuGC4NIs4fl\nKuRcKK4CzatspHXjiVKXU3FlSYmddRyUwitDKZkklcd1piDcdAPaKM4pkVVmToneWK59R7qMmqxX\nrCFxN09orfhwt2N0bX6LwN51fDAWPhz27JzfnAn/ymyv/v8g6neE2BvDi92OXAoicomODDys6yWz\nq3I7DjytK1pfeq6UXG66Pe4yThANb5aWjdvp5jg4p4Bz5tLfFhCpXHUdnbdYp8mpolTrRzO6BbIo\naZXoxijuzzMV2oaVcaBgMIaWfVDJKXNaI0a1+u9d3zGHllT1ZpkRJZzywofjnrUkDq7j86c7Omt5\nWc44I6y0FlspgvIrOkVOJaKVRbFA3hHIHBwgZ279GSSx5JHOpLZ5pxLX3ZkpGe6XjiwWr2acZA52\nZlCBImB1ZsqOJXl6F3C1EEXoTGuvuJvGS8OxMOhEZ88kY1jFkLIikahZYXVGm0wthVQcRVo7b62g\nVaZSwRV0NuTSut1qgljKZW4OsTavr2DItQIZZdrvI6e2yBFrbYsgWmOd5jwFjq5lb1BbKI5SQhEY\nXE9cFs4pttBy3zzXc22hOufSLlk/HHdUC7/JjyhdWWvmo37Hj69u+Xp64ikEdi7xye7IcTvd/kGw\nie7vAaPU3zOVj641oZbLfNiolj52jqEJLMLtODLHiFJvxwftw7f3vokkmilE7qYzOUNvmj3rKUaM\nMsw58TQvqALHYWAcOkqppNxms9Ma2im3ZIx1UAteGb45zRQKRhSdtyjb2g2WUjBh5WleCbU1CXTW\nkktlipFYCq/TCWU0gYVP+xvOJXDjB76Y79g5mDnTaUeVtporAq
M/oVIhZkA5DGfW3KEFjt2CloVb\nNwGVKVu8SQgJryLXbmKJmm+WHalqjt1KJwntMp5EUW3Wesod52hxJjESCaKwF44w6D8AABQlSURB\nVHvYOThSFTLgVGZwMzE7QlWEpKBkJFeUabXxVTLUFpZTSqEkhdYFlEJsxRTVatpLq0EqNVO0xhVp\nIluhN7olrelKiZlcFClXeuta1RK1BZaXQqqZg+246h2nvLaSy0tY+GB7sm7VRVe+436d+WY+0xfN\nx7s9H+8OnC7vqVgKB9/x2WHg6LvNBvYHxCa674nfXaU89n3LVKCJtFw+dHOMQH0Xn5dqfSfESl3q\n1cceqNhimhAvEzlXvNWkBOcUEQVPMXBeVnRVHHyrFioVcm7B5NOcUKpSc7M0AZii+OZ0ItZCrpmb\ncWCumYNz3IcFFzV/tbxu7gkljNaSqOSaSCnyWGeMKITI0dyw5sreOh7TG/YmYswjhYFSW7oVUrl1\nT5yycI6eWns+cCfWbFEUPhkmNCtHu5JEc447nEp4Ek4CV34hROHr5YpYDTsX8Kx4G8DJ5bFf85A7\nQlQYXXDSOs2qaJQuxAilqBagLhnXB0qyzFVRc8tXliqX+Mhyiapsl6K5Qi0FJS3tAt1WkBF518Cb\nVUZUi3lMqq0fd0YzhUIEuiIYaV/O5EQoiSlHDtazc465RmKOoIRS4MZ69s4x5YAzFqUqXhue9SM3\n3cBvTo9UKs/7HVfObwE1f2BsovuvyN+bDwPPhoFwmQ/by8n41XRmSe0RvzMGo+Td0gSaJioFroYe\npYSUM6c18HqKQJvlnmMkXuxQ5xh4nBeMUq10sO8IKTQ/akqt64pm2h9812rUc+U35xMxR6JkPh4P\nRDKf2Wu+XU+ErPjr5WuM1aRY2PkOY2KLV0wLp7JSsHgdqfUGVTXPupXCE8/sic6cieuBhYwWgwDP\n3SMWz0PoCPXAc/c1pQhGJT4bz+iaGW1gxfIqDU0sL//s/Eqp8PV0YCkOawsdgcHCXkeqqqRieMg9\nKSuQgjPpsm5sQCrhUjPUNucKRhdUMSyXxPtaCrUIkFGiW7WPhpyFqlvgkLGqBcvkgq4CBqwxDNq0\nBmAFTlouZUXo+468zq0e6LKy3F8WKfrecN33sMIpraS1MlrPT29vyaXwJsycU8Rqzc9vn3Pbj+/z\nrbzxz2AT3T8gfveiTonwfBhJLVTg3UXdq+lMKm3d9NB1LG9XhKUtWogI2mpG2xwTzlqewsIpRKiV\n3hnmlKhaKFKYa+FpWjBKGLzns6FrJ24R3swzoVTW9HYzqgMqsWa+np84x0BQKz8cXpCJfLa74eXy\nhJGB+/o5zlnW2Mz7RhxGMkmtrCURi6MfEiEdKLXno34BeeAD+0CvF6Z0Q0HopIXL39ozuiZehT1z\nsfxb9w2lKCyFz3Z3aApGMkEcb4JHgFGvTdBsQhP5dr5iSg5E4UnsbKZXLWIzJsNTFkrSVCpaCiKX\n/AaprKVl/bYiU7Cm4qqQSiVlgdJGBejaVndLaZVENA+11oJk1WIfL1bDzjfLX42ZmFs/22gNe+c4\nTYnOOqq0J5NCu4B1tq31vr0zuOrHVtvuPEffMdqtGPIPmU10/8DRl3rrt3hj+GC3J5WM0ER6CqFd\ndF0yYI99x3QJohZpflEVFdZLc0wg9CFyioEQM1UqfWdYYsb3jpQzJcH9PAOVY9/xbDdwDhGthb89\n3TOvkcew0hvLrmvBOyEpXk0nlpzpuzM38oJYhQ89vI6n1iqsf0WSkSVoOnFItQw6E8qZJQuhDOy6\nb5hzR8Dzo/4BLYkP3SNeB87JsRaNvWTYHvUZbTMv1z1PxfMn4x25ZByFH+7u2kIEwiIdp6DJKDqJ\nZNFoXTGSmINlLoaMab1nNtJpRxIhZ805ZmqWy3m0tguy3GIkUxYkt7UIkYx1pjUhZy4hDwVtXGuM\nsJbHvOK1Q2tFTgVtNKMoVt
WWKWIpeAOhFp6NHc93PbFUvjmfmWKgKvjZzTNuu57fnB54WFaMUny8\nO/CsH7ZRwh8Bm+j+EfK7F3WDc9hLO6uS1i6rl4WnsAAKauX5buQUAu5iF6oOVIocek/nLAKc1pUl\nJp7WhXI5EceSud2NLCmx5MzdMnHQnv3e8snxwBQjSlX+cvqanBP3aWbnPEc7YLXjFBP3YSVU8PpE\nKR9icPywzzylBacU1/YLUrnmTdrhRVNw9Kpypc5M2bLUjp/33zIXx1IcP+tfI6rwwj7hdWLJhqk4\n/GUsctALjsg36567PHDbzZiS6ET4pL9HSyVUw1Ntc94EWApKIlVZrM6sE4SqKEVhqGjb1odX10Q2\n1laWCaWNFxTooltHXm1+bVGtwVjV5tueL2HpQ9fRSWur8KpdxGltOA4dd/OCd22ef16b68UZw4f7\nHVe9Z02ZnDPeGG77kb33HH1PZ7aP8h8L22/qO4LVGstvL+uuuo7OmHeOCata1sM5BIRWf/7h4cBT\nCO9GGkOtxLrybDfSW9OCzueJmAuvp4lCy2AA4Th0zLH5h18tM7d6j/jCJ7sjS6p0+sjn699SEO7j\nib3pGcyAZkda4dtQqCT2+oE5PSdwxY/6lXNKWFF8YL9iLc/4OuzwIuRq8Shu9MJD6pmy5qfdHUsx\nzLXjp8MrlBQOZsHpxMt1xzk1V0StQqcSHYFvw56n7DFGMDWjVOZFt4DKxOx4KJWcNUkUAhgyWSlE\nVXJWlAiqVlQFayvdJQFMCm3Wm9qffRu6PlrXrHxV6JwmxAK1sOu7dtqtQsqJ3tl3l6e3/cjOe04p\nYE3zCz/vO35yvOEhBF5OJ05h5dj12+n2j5BNdL/D/O5F3XXXvytgNJexRXjnmGjG/w93O6YY3zkr\nSu14WBY+PR4ZrCGVyrfTmZoLr6YzFRi9ZTAju96ypsQ35ydeLys79SG6W3nRPWNNmp3uOeW/QMTz\nGDNXdofTO0SOEDu+Wo4Y9Zpn5oGndORUnvGTfmIq7QvlA/eaKd/y18stTgq5OgyKa7Vyn3rO1fCx\nf2Qpljl3fDa28YJWGW0Kj6vnnPxlSQUMFWcX7lfHFB1FDKZWnIooL2QBkuYxW0rRrdEDECmt4YKC\nVIsUmi8ZsE7jrEahGbBIBl3baOc4DLw+z4zWIVVQojj2lpjb7P3QdTyFlcewkCh8NO54sd+zXBou\nYm2LF//m+pZnw7jZwP5I2UT3e8R/s1EHPB+G1opb22lZgFDa5hO14ozm48OBcHmkhZY5cDdP/Oz6\nRWukyImvpydMhV+f3yC6tevedjd0DmJJfPH0wNfTgd7+lIOb+VPjCGVHkoqR/xujrrnPnoMcgT2i\njuS651fhhkHd86FdmfOeh3LDn/b3TNmi0Xxg71iL8IvlOU4qqfjmY1Uzb/LAm+L5tH9gLYq1dHzU\nP2IlkUSTxBCzcE4tgSxLC6rpTCCXri1QFIOqLUjduESsgi6WkKEWTY5QE4it9NqzEJuoZkXOtLXt\n1LbNLJrn474Fv1NRCAfjWaQweM+h75lTxBuNt61p5MXQKp6+enpqSXFdx7HrW439xh8lm+h+z5FL\nqtp/zfNhJORM5bf5E9+eT+9OxKMzjO6KUn/7cy2Ku/nM/3T9CTvfGilezk9Y4C+nr/BGk8Xzgf85\nzrZYw1+d3/CL04944Sq9mfnxsONUb3li5aD+HC+Gu3Qk6xsKO5LsyOXIF8kzqtd8ZANz2XGfb/h5\nf8+5WBSKj90dquz5L/OHGIG1rOSqOOiZu7LjKXtG28otz6XjxgeMRGJ1nLGUoliSIhdDKbqtOuuA\noSOnSkGhKyhVsV0lZoVBY6umBqFmWvuFzXywG3mcMoPSaKXIKeM6y6jbE8fNMHKKkTkG7ueJ637g\n08OBqoT7uTlOvDF8enXF82HcRgnfATbR3fhv0ErR/86j6/Nxx5oS0PzCuVa+OZ1YaqHWFrhy
0/Xv\nQuFrrXhl+XZ64j8cP2PvPE/pzKvlhBPhL+fP6ZWhVsdg/hesjRyq5m+evuZ/f/p3/Gmn0CbwaXfg\nLB9zVxZu1V/TS8fLOHDQzym8JrIjlmu+DJ6dfs3HbmIpHW/KLT/tH5iKxQIf+m+5Dz2/nF4gta05\n1wo7vZCLYSmGKgZTFKlYOgNKL9RieUoVqZY1GyQLCk0WAzqxE89cClIN1IqxbdHEJsVOGTyeG1+o\nQtv0i4mPhj1GdCvXvORh7Jzn2LUK+861RDCpMFjLoevYObcJ7neETXQ3/kkYpTDut/5PTZv/Likh\ntMbiVArfnE+k0mpzPuoPvPA7lG5ui1Jv+Rt5yav1iX939TOOvuc+3PMYZpzAL8+fs9M9K5lZ/les\nzTxzwunxa/7j07/nzwYFJvOJ73msH/G6zNyoX9FLz5dxz1X5kMwdK3tSvebLODCoVzy3rwjF8Zhv\n+bSbmZKhimXnHklZ8/n5BqqlqJEKeDORS2UuIHg6LZSiKaJQLuNFOMWKrh6SwVaFItOLZqqZZ/2u\nXbiVllJGFQav6J2j05brcURX4fU0t3VgY/j0eOTYdXx7PvO0BrTS3AwDz8dxGyV8x9hEd+O/G6v1\n31tv1krx4W7PnFqC2GDbxdq300RWbfX2s90zfjDcYnW7yMv1A/7q4Td8uzzw8/1PufEjd/ENUwoY\nVfib0+fs9EAEvuJ/w9jKtYHH05f8x9N/4H/uK6KFT73lvv6AhzJxpf4OT88X65G9/oh/M5xZ6xWx\nHrmLPV59y7WdCNUy5SMvOlhzYqoF1ISl8iprSu4RuYIK3k6tfj5mjPLsnW2z3No2ypyymGI44Bjx\nGNPq5/ed5c155YW74vm4Y4mZNUYG24LFn48ja2opdJ21XA8DXut37pPtdPvdQ+olavAf4B/94cbG\nP4U1JeYYESXsrGOOkdfzhNG6VdYoIZfmtlAipJL4L09/x6v5DiWV227P6/CKXCtGwV89/QopipWJ\n590t3hhSLvx6/gbNxJ/1/4kpwy+XW071R/zZ7q/p5W/59fKC/+PhOXtd+Nn+xFoGznHHfXKo+hJj\nPOdcgD1rcsQsrGHkRGJnhHntSEkxqKtm9VILu7rjIQSeuZEBT4qVwVlU0ew7zwfDjpzhg8PIrR/5\n6vGRm6HHKM2Hux0fXR24X1ce55kXux3WGJ4Nw3a6/ePnH/wFbqK78a/CHCNzim2eeRHiu2XGqrbk\noUUIOdMZ2xozSuQvHn/JN9MdRlVu+wOvw2tKVXil+cunz6FUgqzcuiOj3RFy4vPpGzSBn/r/RK2J\nv5yfsfIZPx4eqfVX3K1H/vxxx6grz3pPrBapPUtSnMsrLCP3qeBlQFfPmqEkxzlldsYw1IElwq0e\nGLTjgZlPhiPzHLjpR352fM6UV3Kp7J2ns4Yf7K6YUqKzlg8PO2LOaBGu+r7FYm6C+11gE92NP3ym\nEJhTQouw855zCLxZFrxWpFIvp+DfCnEogf/n4S94Nd9TSVz3R97EO1Sx9LbjF49fUHImkDm6PTf+\niiVFfjV/g1C50n+BI/GraQ98zLPecSpfsSbP5w+GwQhHfyQU6KWDqngZ7tmpnvs104nnA9cEVGeB\noBic5SfXt6QMvTX84LDnq8cT176n15arseNH19fEUnh5OnEcegZjOfZb/OJ3jE10N/44Oa0rU0oY\n1arizyFwvyytnaEUECHlSG89SoS1rPzn+z/nzXIi1Mhtd+QhntBYdq7nF/dfkkoi1szBjnw43HKK\nkc/PL1GiWcuv6XTh68nQyQ23/ZG79Q05K+6ntrb7vLtmLpm99hy054vzEwfrqEEzKMO/f/4xU1xZ\nS2HvHQfX89FupIiQU+HZvlm/tAjXlz69je8cm+hufDeotfIUAlMMGFFcdS3c526e6S9WtlQLuWR6\nbTFas+bA/3n/S57CmXNeObodT2nBiuGmu+IX91+xpESu
0FvPn+xueAyRL6Y3aNG8Xu/wRnGawaue\nT/ob7sKZkDI1gFeWz/bXxFjQWvHDq2vuTmd657h2Hm8tn93eoER4dT6zd47BuxZQvwnud5VNdDe+\nu9RaeQwr57CiUFz3PUtK3M0zg7WUWlhSolDojMUpw5oD//nhbzmtK09xZu9GppzQKD4ejvzy4RXn\nuAIKpww/ubplTonfnB6xSvFynnBKsNliRPjp/hkxZs45ctV1HJzn0+M1GuFxWTn2HdYajFLcDv02\nt/3us4nuxveLWisPy8IpBLQSrvueNWXul5netNLGOUVyzTht6I0j5MT/df8FU1x5s67srKdULla3\nG379+MBdmHBojFb85OoZtcA38xmnhBIqIPzwcEUulR/eXNFbx/000znL4CzjpRV64zvPJrobG2+F\n+CkEBLjpe2It3E0zo7VU4BQClYySFtgeS+YXDy+Zwsrj2hLZLJo1F35yvGYJiZenM15rvLL88HjF\nzjvenGes0ey8w2jN9bDlJXzP2ER3Y+P/i1or98vM4xoAuO57FPBqmhguSWtPIVBrpiBcdz0xZ359\nuuccIyVVjCiuXcccEx8fD1x17cLPqubVdUZvgvv9YxPdjY1/jLefAxF5J8QP6wrAlfc4bXh5PjNa\ni4jwtK4oAYWwd54CPM4La0rcDANaq+10+/1mE92NjX8u5fLZeCucb+bpnRAPpgXRPEwz/rKuu8aE\nUsKh8xi9nW6/52yiu7HxL0EuhQrv6pJOa+AcAlQwWm1LDhtv2UR3Y+P3xVsh1iKbFWzjLZvobmxs\nbLxH/kHR3Z6DNjY2Nt4jm+hubGxsvEc20d3Y2Nh4j2yiu7GxsfEe2UR3Y2Nj4z2yie7GxsbGe2QT\n3Y2NjY33yCa6GxsbG++RTXQ3NjY23iOb6G5sbGy8RzbR3djY2HiPbKK7sbGx8R7ZRHdjY2PjPbKJ\n7sbGxsZ7ZBPdjY2NjffIJrobGxsb75FNdDc2NjbeI5vobmxsbLxHNtHd2NjYeI9soruxsbHxHtlE\nd2NjY+M9sonuxsbGxnvE/P/8/B+sEd7Y2NjY+OeznXQ3NjY23iOb6G5sbGy8RzbR3djY2HiPbKK7\nsbGx8R7ZRHdjY2PjPbKJ7sbGxsZ75P8FyFNRFeuu/9UAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [] } } ] }, { "metadata": { "id": "DKTMw6tRZyK2", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# 🔪 NaNs\n", "---" ] }, { "metadata": { "id": "ncS0NI4jZrwy", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Debugging NaNs\n", "\n", "If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:\n", "- setting the `JAX_DEBUG_NANS=True` environment variable.\n", "- adding from jax.config `import config` and `config.update(\"jax_debug_nans\", True)` near the top of your main file\n", "- adding `from jax.config import config` and `config.parse_flags_with_absl()` to your main file, then set the option using a command-line flag like `--jax_debug_nans=True`.\n", "\n", "This will cause computations to error-out immediately on production of a NaN.\n", "\n", "⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!\n", "\n", "### Fast math mode on CPU and disabled NaN/inf handling\n", "At the moment, XLA's CPU backend defaults to enabling fast math mode, which does not preserve nan/inf semantics. (The GPU backend does not use fast math by default!) If fast math mode is enabled, the semantics of __inf__ and __nan__ are not preserved by XLA/LLVM, and the behavior of inf/nan values is unpredictable. \n", "\n", "To disable fast math mode on CPU, set the environment variable:\n", "```\n", "XLA_FLAGS=--xla_cpu_enable_fast_math=false\n", "```" ] }, { "metadata": { "id": "YTktlwTTMgFl", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# 🔪 Double (64bit) precision\n", "---\n", "\n", "At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!" 
] }, { "metadata": { "id": "CNNGtzM3NDkO", "colab_type": "code", "outputId": "211d9880-4518-4a7d-f652-e3663274825f", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "cell_type": "code", "source": [ "x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)\n", "x.dtype" ], "execution_count": 0, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "dtype('float32')" ] }, "metadata": { "tags": [] }, "execution_count": 14 } ] }, { "metadata": { "id": "VcvqzobxNPbd", "colab_type": "text" }, "cell_type": "markdown", "source": [ "To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable __at startup__.\n", "\n", "There are a few ways to do this:\n", "\n", "1. You can enable 64-bit mode by setting the environment variable `JAX_ENABLE_X64=True`.\n", "\n", "2. You can manually set the `jax_enable_x64` configuration flag at startup:\n", "\n", "```\n", "# again, this only works on startup!\n", "from jax.config import config\n", "config.update(\"jax_enable_x64\", True)\n", "```\n", "\n", "3. You can parse command-line flags with `absl.app.run(main)`:\n", "\n", "```\n", "from jax.config import config\n", "config.config_with_absl()\n", "```\n", "\n", "4. If you want JAX to run absl parsing for you, i.e. 
you don't want to do `absl.app.run(main)`, you can instead use:\n", "\n", "```\n", "from jax.config import config\n", "if __name__ == '__main__':\n", "    # calls config.config_with_absl() *and* runs absl parsing\n", "    config.parse_flags_with_absl()\n", "```\n", "\n", "Note that #2-#4 work for _any_ of JAX's configuration options.\n", "\n", "We can then confirm that `x64` mode is enabled:" ] }, { "metadata": { "id": "HqGbBa9Rr-2g", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "from jax import numpy as np, random\n", "x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)\n", "x.dtype # --> dtype('float64')" ], "execution_count": 0, "outputs": [] }, { "metadata": { "id": "6Cks2_gKsXaW", "colab_type": "text" }, "cell_type": "markdown", "source": [ "### Caveats\n", "⚠️ XLA doesn't support 64-bit convolutions on all backends!" ] }, { "metadata": { "id": "WAHjmL0E2XwO", "colab_type": "text" }, "cell_type": "markdown", "source": [ "# Fin.\n", "---\n", "If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory _advisos_!" ] } ] }