{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MAML Tutorial with JAX\n", "\n", "Eric Jang\n", "\n", "Blog post: https://blog.evjang.com/2019/02/maml-jax.html\n", "\n", "\n", "21 Feb 2019\n", "\n", "Pedagogical tutorial for implementing Model-Agnostic Meta-Learning with JAX's awesome `grad`, `vmap`, and `jit` operators.\n", "\n", "### Overview\n", "\n", "In this notebook we'll go through:\n", "\n", "- how to take gradients, and gradients of gradients\n", "- how to fit a sinusoid function with a neural network (and do auto-batching with `vmap`)\n", "- how to implement MAML and check its numerics\n", "- how to implement MAML for the sinusoid task (single-task objective, batching task instances)\n", "- how to extend MAML to handle batching at the task level\n" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "PaW85yP_BrCF" }, "outputs": [], "source": [ "!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.21-cp36-none-linux_x86_64.whl\n", "!pip install --upgrade -q jax" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# import jax.numpy (an almost drop-in replacement for numpy) and the grad operator.\n", "import jax.numpy as np\n", "from jax import grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gradients of Gradients\n", "\n", "JAX makes it easy to compute gradients of Python functions. 
Here, we thrice-differentiate $e^x$ and $x^2$" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f = lambda x : np.exp(x)\n", "g = lambda x : np.square(x)\n", "print(grad(f)(1.)) # d/dx e^x = e^x, so this is e^1 ~= 2.718\n", "print(grad(grad(f))(1.)) # still e^1\n", "print(grad(grad(grad(f)))(1.)) # still e^1\n", "\n", "print(grad(g)(2.)) # d/dx x^2 = 2x = 4 at x = 2\n", "print(grad(grad(g))(2.)) # d^2/dx^2 x^2 = 2 (constant)\n", "print(grad(grad(grad(g)))(2.)) # d^3/dx^3 x^2 = 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sinusoid Regression and vmap\n", "\n", "To get you familiar with JAX syntax first, we'll fit a small neural network $f_\\theta$ to $\\sin(x)$ on a fixed set of inputs, optimizing its parameters $\\theta$ with a mean-squared error loss." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from jax import vmap # for auto-vectorizing functions\n", "from functools import partial # for binding params before calling vmap\n", "from jax import jit # for compiling functions for speedup\n", "from jax import random # stax initialization uses jax.random\n", "from jax.experimental import stax # neural network library\n", "from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax # neural network layers\n", "import matplotlib.pyplot as plt # visualization" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Use stax to set up network initialization and evaluation functions\n", "net_init, net_apply = stax.serial(\n", " Dense(40), Relu,\n", " Dense(40), Relu,\n", " Dense(1)\n", ")\n", "\n", "rng = random.PRNGKey(0)\n", "in_shape = (-1, 1,)\n", "out_shape, net_params = net_init(rng, in_shape)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def loss(params, inputs, targets):\n", " # Mean squared error, averaged over the batch\n", " predictions = net_apply(params, inputs)\n", " return np.mean((targets - predictions)**2)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW4AAAD8CAYAAABXe05zAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd8VFXex/HPSe8hJIHQEyAJJXQIJQREpKhYQMSyiNhAXV2766r7WNeGItYH3WV9dAUVsa2gqIh0BEIMNSShhJBQ0nufOc8fN0RUSoCZ3JnJ7/165RUmuTPzmzD55txzT1Faa4QQQjgPN7MLEEIIcXYkuIUQwslIcAshhJOR4BZCCCcjwS2EEE5GglsIIZyMBLcQQjgZCW4hhHAyEtxCCOFkPOzxoGFhYToyMtIeDy2EEC5p69at+Vrr8KYca5fgjoyMJCkpyR4PLYQQLkkpdbCpx0pXiRBCOBkJbiGEcDIS3EII4WTs0sd9MnV1dWRnZ1NdXd1cT+nyfHx86NixI56enmaXIoRoRs0W3NnZ2QQGBhIZGYlSqrme1mVprSkoKCA7O5uoqCizyxFCNKNm6yqprq4mNDRUQttGlFKEhobKGYwQLVCz9nFLaNuW/DyFaJnk4qQQQpwvrWHvj7BuXrM8nQT3OQoICADg8OHDTJ069bTHzps3j8rKysbbl1xyCcXFxXatTwjRDLSGtOXwr7Hw4RRIWgB1VXZ/WgnuE1gslrO+T/v27VmyZMlpj/l9cH/zzTe0atXqrJ9LCOFAaspg4VT46BqoyINJ8+CuJPD0tftTt5jgzszMpEePHvzpT3+iZ8+eTJ06lcrKSiIjI/nrX//KwIED+fTTT9m3bx8TJ05k0KBBJCYmsmfPHgAOHDjA8OHD6dOnD48//vhvHjcuLg4wgv/BBx8kLi6Ovn378sYbb/D6669z+PBhxowZw5gxYwBjSYD8/HwA5s6dS1xcHHFxccybN6/xMXv27Mltt91G7969GT9+PFVV9v8rLoRoovI8+L9JsO8nmPgi3J0Mg28CD+9mefpmGw54oqe+3sXuw6U2fcxe7YN44rLepz0mLS2NBQsWkJCQwM0338zbb78NQGhoKMnJyQCMHTuW+fPnEx0dzaZNm7jzzjtZuXIl99xzD3fccQczZszgrbfeOunjv/vuu2RmZpKSkoKHhweFhYW0bt2auXPn8tNPPxEWFvab47du3cp7773Hpk2b0FozdOhQRo8eTUhICBkZGXz00Uf885//ZNq0aXz22WdMnz7dBj8pIcR5KTxgdIuUHoFrF0HsxGYvocW0uAE6depEQkICANOnT2fdunUAXHPNNQCUl5ezYcMGrr76avr378/s2bM5cuQIAOvXr+e6664D4IYbbjjp469YsYLZs2fj4WH8PWzduvVp61m3bh2TJ0/G39+fgIAApkyZwtq1awGIioqif//+AAwaNIjMzMzzeOVCCJuoKob/XAlVRXDjf00JbTCpxX2mlrG9/H743PHb/v7+AFitVlq1akVKSkqT7m9P3t6/nnK5u7tLV4kQZtMavvozlGTDzG+gU7xppbSoFndWVhYbN24EYNGiRYwcOfI33w8KCiIqKopPP/0UMGYnbtu2DYCEhAQ+/vhjABYuXHjSxx83bhzvvPMO9fX1ABQWFgIQGBhIWVnZH45PTEzkyy+/pLKykoqKCr744gsSExNt8EqFEDa38U3YsxTGPQ2dh5paSosK7tjYWN566y169uxJUVERd9xxxx+OWbhwIQsWLKBfv3707t2br776CoDXXnuNt956iz59+pCTk3PSx7/11lvp3Lkzffv2pV+/fixatAiAWbNmMXHixMaLk8cNHDiQm
TNnEh8fz9ChQ7n11lsZMGCAjV+1EOK8Zf0MPzwBPS+DYXeaXQ1Ka23zBx08eLD+/UYKqamp9OzZ0+bP1VSZmZlMmjSJnTt3mlaDPZj9cxXC5dVVwVvxoNxh9mrwCbbL0yiltmqtBzflWFP6uIUQwmmsfx2Ks+DGpXYL7bPVYrpKIiMjXa61LYSws+IsWDcXel0JUY5z/anFBLcQQpy17/8OKBj/jNmV/IYEtxBCnMyBNbD7Sxh5H7TqbHY1vyHBLYQQv2e1wvK/QXBnSPiL2dX8gVycFEKI30v7Bo7thMnvNMuiUWerRbW4jy/FKoQQp6Q1rH0FQiIh7vRLNpulRQW3EEKc0f6f4HAyJNwL7o7ZKdEig1trzUMPPURcXBx9+vThk08+AeDIkSOMGjWK/v37ExcXx9q1a7FYLMycObPx2FdffdXk6oUQdrXmFQhsD/2vN7uSUzLnz8m3j8DRHbZ9zIg+cPELTTr0888/JyUlhW3btpGfn8+QIUMYNWoUixYtYsKECTz22GNYLBYqKytJSUkhJyencQy47FwjhAvL+hkOroMJzzfb2trnokW2uNetW8d1112Hu7s7bdu2ZfTo0WzZsoUhQ4bw3nvv8eSTT7Jjxw4CAwPp2rUr+/fv5+6772b58uUEBQWZXb4Qwl7WvAx+oTDoRrMrOS1zWtxNbBk3t1GjRrFmzRqWLVvGzJkzuf/++5kxYwbbtm3ju+++Y/78+SxevJh///vfZpcqhLC1vHTY+wOMeRy8/M2u5rRaZIs7MTGRTz75BIvFQl5eHmvWrCE+Pp6DBw/Stm1bbrvtNm699VaSk5PJz8/HarVy1VVX8eyzzzbulCOEcDFJ/wY3Txg00+xKzsgxL5na2eTJk9m4cSP9+vVDKcVLL71EREQE77//PnPmzMHT05OAgAA++OADcnJyuOmmm7BarQA8//zzJlcvhLC52gpIWQS9roCAcLOrOaMWs6yrq5KfqxA2kPwB/PduuGk5dBluSglns6xri+wqEUKIRlrD5n9Cm17QeZjZ1TSJBLcQomXL2QpHt8OQW6AZ95U9H00KbqVUK6XUEqXUHqVUqlLKnHMJIYSwtS0LwCsA+l5jdiVN1tSLk68By7XWU5VSXoCfHWsSQojmUVUEOz+DAX8C70Czq2myMwa3UioYGAXMBNBa1wK19i1LCCGawa4vwVIDA2eYXclZaUpXSRSQB7ynlPpFKfUvpZRjj04XQoim2PYxhPeAdv3NruSsNCW4PYCBwP9qrQcAFcAjvz9IKTVLKZWklErKy8uzcZnnr7i4mLffftvuz7Nq1So2bNhg9+cRQpynwv1w6Gfod63TXJQ8rinBnQ1ka603NdxeghHkv6G1fldrPVhrPTg83PEGsJ9tcGutGyfdnA0JbiGcxLZPAAV9ppldyVk7Y3BrrY8Ch5RSsQ1fGgvstmtVdvDII4+wb98++vfvz3333cfYsWMZOHAgffr04auvvgIgMzOT2NhYZsyYQVxcHIcOHWLBggXExMQQHx/Pbbfdxl133QVAXl4eV111FUOGDGHIkCGsX7+ezMxM5s+fz6uvvkr//v1Zu3atmS9ZCHEqWsO2j6DraAjuYHY1Z62po0ruBhY2jCjZD9x0Pk/64uYX2VO453we4g96tO7BX+P/esrvv/DCC+zcuZOUlBTq6+uprKwkKCiI/Px8hg0bxuWXXw5ARkYG77//PsOGDePw4cM888wzJCcnExgYyIUXXki/fv0AuOeee7jvvvsYOXIkWVlZTJgwgdTUVG6//XYCAgJ48MEHbfr6hBA2lPUzFB+EMY+aXck5aVJwa61TgCZNxXQGWmseffRR1qxZg5ubGzk5ORw7dgyALl26MGyYMXtq8+bNjB49mtatWwNw9dVXk56eDsCKFSvYvfvXE4/S0lLKy8ub+ZUIIc7Jto/A0x96TDK7knNiyiJTp2sZN4eFCxeSl5fH1q1b8fT0JDIykurqagD8/Zs2YMZqtfLzzz/j4
+Njz1KFELZWV20MA+x1OXg75z60LWbKe2BgIGVlZQCUlJTQpk0bPD09+emnnzh48OBJ7zNkyBBWr15NUVER9fX1fPbZZ43fGz9+PG+88Ubj7ZSUlD88jxDCAe1dATUl0Odqsys5Zy0muENDQ0lISCAuLo6UlBSSkpLo06cPH3zwAT169DjpfTp06MCjjz5KfHw8CQkJREZGEhwcDMDrr79OUlISffv2pVevXsyfPx+Ayy67jC+++EIuTgrhqHZ/Cb6tIWq02ZWcM1nW9QzKy8sJCAigvr6eyZMnc/PNNzN58mSzy2rkrD9XIUxRVwVzukOfqXDZa2ZX8xuyrKsNPfnkk427vkdFRXHllVeaXZIQ4lztXQG15dDLuX+PW+QOOGfj5ZdfNrsEIYSt7PrC2Aw4MtHsSs5Ls7a47dEt05LJz1OIs1BbCWnLoefl4O7cbdZmC24fHx8KCgokbGxEa01BQYEMRxSiqfb+AHUV0NtxrlGdq2b7s9OxY0eys7NxxAWonJWPjw8dO3Y0uwwhnMOuL8A/HLokmF3JeWu24Pb09CQqKqq5nk4IIX5VWwnp30G/65y+mwRkVIkQoiXYuwLqKqHXFWZXYhMS3EII17dnGfiGuEQ3CUhwCyFcnaUO0pdDzMUu0U0CEtxCCFd3cANUF0OPS82uxGYkuIUQrm3PUvDwhW4Xml2JzUhwCyFcl9ZG/3b3seDlZ3Y1NiPBLYRwXUdSoDTHpbpJQIJbCOHKUpeCcoOYiWZXYlMS3EII17VnmTEE0K+12ZXYlAS3EMI1FeyDvFSn3VfydCS4hRCuKe1b43PsxebWYQcS3EII15S+HNr0gpAuZldicxLcQgjXU1VsTLxxsYuSx0lwCyFcz94VoC0u2U0CEtxCCFeUvtzYoqzDILMrsQsJbiGEa7HUQ8YPED0B3NzNrsYuJLiFEK7l0CZjUalY1+zfBgluIYSrSV8Obp7QdYzZldiNBLcQwrWkL4fIkeATZHYldiPBLYRwHQX7ID/dZYcBHifBLYRwHenfGZ9jJphbh51JcAshXEfG9xAWC62jzK7EriS4hRCuoaYcDq6H6HFmV2J3EtxCCNewfxVYal2+mwQkuIUQriLje/AKhM7Dza7E7iS4hRDOT2tjtmS3MeDuaXY1difBLYRwfkd3QNnhFtFNAhLcQghXkNEwDLC761+YBAluIYQryPgB2vWHwLZmV9IsmhzcSil3pdQvSqml9ixICCHOSmUhZG9pMd0kcHYt7nuAVHsVIoQQ52Tvj6CtxjKuLUSTglsp1RG4FPiXfcsRQoizlPE9+IVB+wFmV9Jsmtringc8DFjtWIsQQpwdq8XYpqz7ReDWci7ZnfGVKqUmAbla661nOG6WUipJKZWUl5dnswKFEOKUcpKhqrBFTHM/UVP+RCUAlyulMoGPgQuVUh/+/iCt9bta68Fa68Hh4eE2LlMIIU4i43tQbtDtQrMraVZnDG6t9d+01h211pHAtcBKrfV0u1cmhBBnkvE9dIwHv9ZmV9KsWk6nkBDCtZQdgyMpLa6bBMDjbA7WWq8CVtmlEiGEOBt7Vxifo8ebW4cJpMUthHBOGd9DQARE9DG7kmYnwS2EcD6WOtj3k9FNopTZ1TQ7CW4hhPM5tBlqSlpkNwlIcAshnFHG9+DmAV0vMLsSU0hwCyGcT8b3xk43PkFmV2IKCW4hhHMpPgS5u1vUaoC/J8EthHAuGd8bn1to/zZIcAshnE3G99CqC4TFmF2JaSS4hRDOo64a9q82WtstcBjgcRLcQgjnkbkO6qtadP82SHALIZxJxvfg4QuRI82uxFQS3EII56C1sZt71Cjw9DW7GlNJcAshnEPBXijKbJGrAf6eBLcQwjmkf2d8bsHDAI+T4BZCOIf05RDeA0K6mF2J6SS4hRCOr6oYsjZCzESzK3EIEtxCCMe3byVY6yH2YrMrcQgS3EIIx5e+HHxbQ8chZlfiECS4hRCOzWoxxm9Hjwc3d7OrcQgS3
EIIx5a9BaqKWvxsyRNJcAshHFvat8amCd3Hml2Jw5DgFkI4tvTvoMsI8Ak2uxKHIcEthHBcRZmQlyrDAH9HglsI4biOz5aU4P4NCW4hhONK+xZCoyG0m9mVOBQJbiGEY6oqhsy10ONSsytxOBLcQgjHlPGDMVtSgvsPJLiFEI4pbRn4t4EOg82uxOFIcAshHE99jdHi7nEJuElM/Z78RIQQjufAWqgth1jpJjkZCW4hhOPZsxS8AoxtysQfSHALIRyL1Qpp30D3i8DTx+xqHJIEtxDCsRxOhvJjMprkNCS4hRCOZc9SY1Ep2RT4lCS4hRCOQ2tI/RoiR4JviNnVOCwJbiGE48jdDQV7odcVZlfi0CS4hRCOY9eXoNygx2VmV+LQJLiFEI5j91fQJQECws2uxKFJcAshHENuKuSnSTdJE5wxuJVSnZRSPymldiuldiml7mmOwoQQLczurwAFPS83uxKH59GEY+qBB7TWyUqpQGCrUuoHrfVuO9cmhGhJdn1pbFEW2NbsShzeGVvcWusjWuvkhn+XAalAB3sXJoRoQfLSjC3Kel1pdiVO4az6uJVSkcAAYJM9ihFCtFC7/2t87imjSZqiycGtlAoAPgPu1VqXnuT7s5RSSUqppLy8PFvWKIRwdbs+h07DIKid2ZU4hSYFt1LKEyO0F2qtPz/ZMVrrd7XWg7XWg8PDZSiPEKKJju40Jt70mWp2JU6jKaNKFLAASNVaz7V/SUKIFmXHYlDu0Huy2ZU4jaa0uBOAG4ALlVIpDR+X2LkuIURLYLXCjs+MJVz9w8yuxmmccTig1nodoJqhFiFES5O1AUqzYdxTZlfiVGTmpBDCPNsXg6c/xF5sdiXnpd5iZevBQr5KyWmW52vKBBwhhLC9+hrY/SX0nARe/mZXc9ayiypZk57P2ow81u/Np7S6nkAfDy7t0w4Pd/u2iSW4hRDmyPgBqkugzzSzK2mSytp6ft5fwJr0fNak57E/vwKAdsE+XBzXjlEx4SR0D7V7aIMEtxDCLDsWg384dL3A7EpOymrVpB4tbWxVJ2UWUWux4uPpxtCoUP40rAujY8LoFh6AMfiu+UhwCyGaX0UBpH0Lg28Bd8eJobyyGtbtzWsI63zyy2sA6BERyMyESEZFhzM4MgQfT3dT63Scn5gQouXY/glYamHgDaaWUVNvYWtmEasz8libns/uI8ak8Nb+XiRGh5EYHc6o6DDaBDnWbvMS3EKI5qU1JH8AHQZD297N/NSa/fkVrE3PY01GPhv3FVBVZ8HDTTGoSwgPTYhldEw4vdoF4ebmuKOgJbiFEM0re4uxEuBlrzfL05VU1bFhbz5rMowukJziKgAiQ/24enBHEqPDGd4tlABv54lD56lUCOEakt83xm7HTbHLw1usmm3ZxaxJz2NNeh4ph4qxagjw9mBEt1DuuKAbo6LD6RzqZ5fnbw4S3EKI5lNdCjs/NxaU8g602cMeLq5iTXoeazPyWbc3n5KqOpSCvh1bcdeY7iTGhNO/Uys8m2GoXnOQ4Ba/qimDgxtg/2o4tgMq8qEiD+qqjGFbgRHQqjNEJkL3sRDU3uyKhbPZ+RnUVcLAG8/rYSpr69l0oLCxVb0vzxhTHRHkw/hebRkVE87I7mGE+HvZomqHI8Hd0mkN+1bCpndg349grQcPH4joA627Qqehxu2KXCjPhf2rjBEBAG3jYMgt0O868PQ19WUIJ5H8PrTpBR0GndXdtNakHiljTUYeazPy2HLAGFPt7eHG0K6hXBffmVEx4US3af4x1WaQ4G6prFbY9hGsnwf56eDfBob/GbqNNcLa8xTDn7Q21k7eu8I45V16H6x8FuJnGfe34emvcDGHtsDhX+DiOdCEcM0vr2FdRn5DWOeTV2aMqY5tG8iNI7qQGB1OfFRr08dUm0GCuyXK2gTfPgxHUiCiL0x+x1gL2cP7zPdVyhjC1bY3jPgLHFwPG96EVc/D1v+Dic8b+wa2gFaPOEs/vwXewdD/+
pN+u7beytaDRY2t6p05xpjqED9PRkaHkxgdxqjocCKCHWtMtRkkuFuS6hJY/jdIWQiB7WDKP6HP1eceskpB5Ejj49AWWHY/fDoTuo6By9+AVp1sWr5wYsWHjH0lh98J3gGA0f2RWVDZ2E+9cX8BlbXGmOqBnUN4cHwMo2LC6d0+GHcHHlNtBgnuluLgRvh8lrH28cj7IPHBxl8gm+g0BGatgi0L4MenYf5IuPJ/oYfsuSGAze8CUNbvFtbvPNrYqj5UaIyp7hLqx5SBHRjVMKY60MfTzGodntJa2/xBBw8erJOSkmz+uOIcWC2w+kVYM8cYETLln9Ap3r7PWbAPltwER7bBsD/DRU+Ch2te3RenZ7Fqdh7IoceioSR5DGRG2R1YrBp/L3eGdwtjdEwYo2LC6RLqfMu62ppSaqvWenBTjpUWtyurLoHPboOM74yRH5fMaZ6Lh6Hd4JYf4PvHjX7No9vhmv+Ab4j9n1uY7khJVUP3hzGm+orapTztWc6XAVdwx8BuJEaHMbBLiMuMqTaDwwV3nbWOzJJM0orSyCnL4VjlMXIrcymrLaPOWketpRY35Ya/pz/+nv6E+ITQIaADHQM7EhUURXRINF7u0rojPwM+ug6KDsAlL8OQW5v3gqGHt/GHosMg+OouWDAerl8MraOarwYno7UmuyzbeO+X53C4/DDHKo9RXldOVX0V1fXVeLh54OPug4+HD6E+oUT4RxDhH0HX4K70aN2DAC8bdn81UVWthU0HClibYaxTnZFbDkCbQG/G9wzj4YM/UR80iDmzb2322lyVwwR3nbWO6d9MJ6MogzprXePXQ7xDaOPXhiDvIHw9fPF098SqrVTUVXCs8hi7C3aTV5XXeLyHmwexIbH0C+9HQocEBrcdjJ+n805tPScH1sLHfwJ3T5jxlXHx0Cz9roXgTvDJn+BfY43w7tiks0GXZ9VWUgtS2XB4A5uPbmZ3wW5Ka0sbv+/v6U87/3YEeAbg7+FPqE8o9dZ6aiw1lNeWc6DkALmVuVi0pfE+XYK60C+8H8PbD2dYu2GE+dp+A16tNWnHyhpnKm46UEhtvRUvDzeGRrVm2uBOJMaEEds2ELXjU9idBZc+a/M6WjKH6uN+dO2jhPmGEdM6htiQWDoHdcbb/cxD1Krrqzlcfpi9xXvZWbCTXfm72JG/g6r6KrzcvBgSMYSLoy5mbOexprRImtWOJfDlHRASBdOXGP3ajqBgH3x4lTGJ57pFDrt4vr1ZrBaSc5NZtn8ZK7NWUlRTBEBsSCxxYXH0DutNz9Y96RTYiSCvoDNOJrFYLeRW5pJRnMGewj3sLthN0rEkSmpKAIgLjWNi1EQmRE4gwj/inOsurKhlbUZe46YCuQ1jqqPbBDAqJpxRMeEM/f2Yaks9vBUPnn4wew24SdfI6ZxNH7dDBbct1Vhq2HpsK+tz1vNj1o/klOfg7e7NhZ0uZFrsNAa1HeRaM6y0hg2vww//A10S4NqFjtenXHYU/jMFCjJg6r+h52VmV9RsjlYcZXHaYr7a+xW5Vbn4evhyQacLSOyQyPD2w23aMrZYLewp3MOGwxtYkbWC3QW7AYiPiOfq2KsZ22ksnu6nH7VRW2/ll6yixskvO3JK0Bpa+XkysrtxQTExOox2waeZMZuyyGhEXLsIelxqs9fnqiS4f0drzba8bSzdv5RvD3xLaW0psSGxXN/zeiZ1neT8feJWK/zwd9j4pjGRZvI7TZtMY4bKQlg0DXK2wpXzod81ZldkV8nHkvlg9wf8dOgnABI7JDKp2yRGdxyNr0fzLBNwsPQgyw8s54u9X5BTnkOYbxjTYqZxXY/raOXTqvG4zPwK1mbksTo9n4378qmoteDuphjYuZWxoUBMOH06NHFMtaUO3hwMPsEwa7VMyGoCCe7TqKqvYtn+ZSzas4iMogza+LXh5ribuSr6Knw8nHBGlqUe/ns3bFtkTDuf+KLjn5LWlMPH1xl98Ve8CQOmm12RTWmt2XhkI
+9uf5etx7YS4h3ClOgpTIudRvsA8xbmslgtrD+8no/3fMzanLX4evgSH3op3hVj2LLPQlZhJQCdWvsyKjqcxOhwRnQPJehcxlQnf2C8L6/7BGIn2viVuCYJ7ibQWvPzkZ+Zv20+ybnJhPmGcXvf25kSMwVPNycZ/F9XDUtuhrRlcMHfYPRfnadlU1cFH19vLHB12WswaKbZFdlESm4Kc7fO5ZfcXxobBVOipzRb6/p0LFbNzpwS1mbk8UPGdtJrv8I9MAW0O+3dxjO12w2M6xFFZKjf+XUj1tfAG4PBPwxuW+k870mTSXCfpS1Ht/DmL2+SnJtMZFAk9w26jzGdxjh2H3hNmTHcL3OtsWjP0FlmV3T26qrhk+mw9we4dK6x0qCTyizJZF7yPH7M+pEw3zDu6HcHV3a/0vRuuKMl1Q07v+Sxfm8+RZXGiK24DkGMig4ntmMtGwoW8W3mNwR6BTKr7yyu73H9GfvAT2vNy7DyGbjhC+h2oY1eieuT4D4HWmtWHVrFq8mvcqDkAMPbDedvQ/9GVLADjjuuKICFV8GR7TB5PvSdZnZF566+BhbPgPTlxnjz+NvMruisVNZV8u72d3l/9/t4u3tzU++buKHXDaYNQa2us7D5+DrVGXmkHzPGVIcHepMYHcbohnWqQwN+ew1kT+Ee5m2dx/rD64kKjuKR+EcY0X7E2RdQnAVvxkP0RXDNh7Z4SS2GBPd5qLfWszhtMW/+8iZVlipm9p7JrL6zHOJUF4CSHPjPZCg+CFe/7xr9h/U1sPhGSP/WqcJ7xcEVvLD5BY5VHuPybpdz36D77DJu+nS01mTklrMmPY/V6XlsPlBITb0VL3c3hkSFMKrhomKPiMAmnUGuyV7DC5tf4FDZIcZ1Gccj8Y/Qxq9N0wv6ZDpkrIC7tsgiY2dJgtsG8qvymZs0l6/3f02nwE48NeIphkQMMbeovDRjOF1NKVz3kbkTa2ytvtZYWTBtGVz8EgydbXZFp5RXmcc/Nv2DH7N+JDYklseHPU7/Nv2b7fmLKmpZtze/cQLM0dJqALq3CTCWPo0JZ1hUKL5e57ZOdY2lhg92fcA729/B082T+wffz1XRV+GmznDRO2OFcSY49n8g8YFzeu6WTILbhjYf2cyTG5/kUNkhro65mvsH3W/OJJ7sJFg4Fdw8Yfpn0K5v89dgb/W1xuJUe5bCuGcg4S9mV/QbWmu+3Pslc5LmUGup5c7+dzKj1ww83Ow7AbkIPgcXAAAY50lEQVTOYiXlUMPmtxn5bM8uRmsI8vFgZMMa1Ykx4XRoZduzwqzSLJ7a+BSbj25mcNvBPJ3wNJ0CT9GKrquG/x0OKLhzo+MOR3VgThvczyzdjbeHG639vQjx8zI++3vR2s+LVv6eBHp7mHLBsKq+ijd/eZMPUz8kwi+Cf4z8B4MjmnHadtpyI9AC2sINnxtbirkqSx18fhvs+gLGPA6jHzK7IsA4A3tqw1Osyl7FoLaDeGrEU3QJ6mK358sqqGy8qLhxXwFlNfW4KRjQOYTE6DASo8Pp1zEYDzsv1KS15ou9XzBnyxws2sJDQx5iavTUP/4eLnsQtvxTLkieB6cMbq01w59fSV55DRbryWvycFONQR7i70lrfy9a+R2/7UVrf89fA7/ha/5e7jYL+5TcFB5d9yjZZdnM7D2TuwbcZf9RA5vegeWPGDvVXL8YAtva9/kcgaUevvozbP8YRt5vnHqbOMLnx4M/8tTGp6ioq+DeQffyp55/OnO3wVkqr6ln476ChmnleWQWGGOqO7TyNaaUR4cxonsYwb7mDFU9Un6Ev2/4O5uObCKxQyJPJzz9a3/+7v/C4htg+F0w4R+m1OcKnDK4j9NaU1pdT1FFLUWVxkdBeS3FlXUUVtb++vUK43ZhRS3FlbWcIuvxcncj5MRA9/cixM/zhLBvCPkT/hj4ep467CvrKnk56WU+Tf+UHq178OKoF+kabIcWsNUC3z0Km+ZD7KVw1T/BqwWtWWy1w
rL7jO3QBtwAk+aBe/OuiVZVX8VLW15iSfoSerbuyfOJz9OtVTebPLbVqtl1uLSxVZ2cVUSdRePr6c7wbqGNfdVdw/wdZliqVVv5aM9HvLr1Vfw9/Xk24VkSA7oYm2aERBlL+cq66+fMqYP7XFitmtLqOgoraimqrKOoorbh37/ebvwjUGH8ETht2Hu4/aYV/5tWvZ8nIf5eZNck8eG+OdRZa3ho0CNM6zHFdr9gFQXw2S2w/yejFTPuaXBreRuiojX89A9jE4jYS2HqgmbbTT6tMI2H1zzM/pL93Bx3M3f1v+v8xjYDuaXVrGlY+nTd3nwKK2oB6NUuiMSYMEZHhzMoMgRvD8f+v95btJeH1z5MRlEG061+3HckB6/b17h2F14zaHHBfS6sVk1JVd1pW/WFFXUUVtQ0fr2kqo4Tf1zKoxSf9h/j4b8fa1l/gsqvpbVfUGM3zfGQP7HPvpWfZ+Ptk+5OnZNsjGsuz4VLX4aBM5rvh+KoNr1rbG7cKd4YGxxwFsPTzpLWmk/TP+XFzS8S5B3EcyOfY3j74ef0WNV1FpIyixrW/8hjz9EyAMICvBrW/ghjZPdwwgOd70JeTV0Vcz+9jEV1x+jl146XJy449YVL0SSuE9w1ZXBsFxzdAYX7oSQbSg9DVRFYao3xv0qBdxD4BIFfmLFQf0gUhHWH9gPBr/X519HA0hD2hSe06AvKq/jx6MdsKf4IX9WWKOsd1FS2bWz1l1bXn/Lx/Lzcf+3C8XXn8tpvuDJvPpVerVk78FVoN4CQE8K/lZ+nw7fG7GbXl/DF7cb/57ULof0Amz9FRV0FT218im8PfEtC+wSeS3yO1j5Nf/9ordmXV9E4+eXn/QVU11nxdIexHRUT25UzuFU57b0qcasqhOpf197Gzd1YzdEv1Jgq3robhHYHLwdcS15rWPYAJC1g5dAbebw4Ca01Tyc8zbgu48yuzmk5b3BXFBhTuDPXGgsQ5acDDfV5+kNwBwjqYPzyunsb/WnaavwC1JQZrdSiA1Bb/utjhkRCxyEQNcrYfdxOkwK2HN3Cw2sepqy2jEeHPsqU6CkA1FusFFcZ3TUFDf3xhRUNLf2KWgora6Ekmxm5c+hfl8JaPYC/1MymiKCTPk+At4fRF+/XcGG2sTXv+esInMb+fKNv32W2iDqy3VjfpCLP2EXehjNGM4oyuH/V/WSVZfHn/n/m1j63NukCZElVHRv25jf0VeeTU1xJpDrKhKAsxgQdpod1H8Fl6agT35MAyt3YRu74c1gt0LCG9gkHGe/XDoOg01DjjCOiX7P39f/Biidh3auQcA9c9BQ5FYd5aPVD7MjfwfSe07l/0P3n3a3UEjlncNdVwwudwVJjhHSX4cabNaIvRPSBoPZNG1mgtfGLnZsKh5ON5UMPbYbyY8b3Q7tDzETjo/Nwm/4S5Ffl88jaR9h0ZBOTu0/m0aGPnn7FQavFuPi24imw1htX5AfNpNaiKa5qaNU3hHxBRS3FDX34hRU1Rt99YxdPLRW1llM+TaCPR2P3zckuzDb24zfcbuXn6bhhX54Hn94IB9dDn6uN7dHOc93xr/d9zTM/P4Ofhx9zRs857UQri1WzLbthTHV6HimHimlPLuO8dnNxYAZxdTvxq2nYkcnT3xhvH9EHQqMhtCu0ijRa1N5Bf1zF0VIPVYXGe7Vgr7H9XO5uyN4KJVnGMT7BRgOk+0UQM8Gu3UZ/UF9rjHBKWgCDboJJrzb+TtZZ6pi7dS4fpn5Iv/B+vDz65fPauKElcs7gBvhlIYRFG6fBtvyLrTXk7YH9q2DvCjiwxuhq8WkFsRcbC/p3u9AmF74sVgtvpbzFP3f8kx6tezD3grkn7/vbvxqW/w1yd0GXkXDFG+d1caem3tIY8kWVtY0XaotPbOkfD/2G4ypPE/ZBPh6/Drc8Sas+5MSvNbTym7ROsy1Y6mHdXGP3ev82xs+u+0Vn/TC1llpe2vISn6R9wqC2g5gzag7hfuF/OO5oSXXjl
PJ1e/OpryplhPtuJgfuYTjbCak+ZBwY2M7YxCIywWgUhMXY7qJy6WHI2misprj3Ryg7AijoPAx6TDLewyH2G1dOSY7xBzN7C4z4C1z05Elf2/LM5Tyx/gm83b15afRLDGs3zH41uRibB7dSaiLwGuAO/Etr/cLpjnf4i5M15cYvQNo3kPYtVBcb2yt1H2v8EkSPP+++8TXZa3hk7SMAvJD4AqM6jjL+gOz7ETa8aYwYadXZmCHY6wpTxilX11mMC68Vv4Z98QkXZX/fqi+qrKOq7uRhrxQE+x7vwvE8+SSq419vuB3s64nb+YT94V/g89mQn2YE90VPGq3bJjhacZQHVj3A9vztzOw9k3sG3tM4A/L3CzXtPVZKnDrARN9Uxvuk0rVqJ266zmhRRyUaLeBuFxqNjub4f9Qaju2EPcsgdSkc22F8vf0A473U83IItc2wRaxW2PkZfPc3YyneK96C3lee9i77S/Zz/0/3c6D0AHcPuJtb4m5xmCGNjsymwa2UcgfSgXFANrAFuE5rvftU93H44D6RpQ4y10Hq10aQlx0x+h87xUO3sUaYt+t3Ti2n7LJs7lt1H2mFadwRNoTZB3fhlptqzIAcdicMvR08nWvzhqpaS2PIHw/84sq6E/rvT2jtN3Tx1NZbT/pYSkErX88TunEaunIaW/onduEYxwX7ev52tmBdNWx+F9a+AtUlEDfFOI3vknDKDSU2H9nMQ2seorq+mmdHPsu4LuPYn1fOyj25rE7PI+XAUbpbDjDUI53x/nvpXb8b73pjRAgRfYyQ7n6R0ZXnCFO7Cw/A7q+Mj8PJxtfCexrbhcVMMC7Sn22XoNaQ8T38+Izxh6FtH2M4Znhsk+5eWVfJExueYHnmcsZ2HsuzCc+6/n6v58nWwT0ceFJrPaHh9t8AtNbPn+o+ThXcJ7Ja4cgvsOcbo0vlSIrxda+AhgtE8dC2d8MV/24nnxBjtUBRpjES5uh2qvat5JnaLL4O9GeUxYPn+91NUL8/OcYvfDPQWlNVZ/lDoJ9uzH1h5anDHoyWfYifEerHP7f1rObCgkX0P7oEL0sFVf4dKYu+ktBeF+DecRD4tUZrzfu73mde8jy6BHbmz5H3cmhvKVn7duNXdpAodYSBXllEWbNwp+HMIrQ7dBkBUaONj4A/dqU4lKKDRgNkzzI4uAG0BbwCje6bTkON92+bXhDc8bdnB1YrVOYbF4DTv4X076DkkDFCa8xjEHfVWe+spLXmw9QPeSXpFToFdmLemHk2m8Dkimwd3FOBiVrrWxtu3wAM1Vrfdar7OG1w/15FvtEvnvUzHNpknJ7qEwLFK8AIby9/o9+1usRYue/4SBjlDu37o2Mv5WM/T17a/R7tAtoxb8w8YkJizHhFTuHEsC9u6K4xLtSeEPwNn49/v6SqjrLqenyoYbxbElPd1zDSbSduyvi/yPNpw7OtPFnp686FlTU8l5uL/+/e+xa/cNzb9YX2/aFdfyPonHmJgcpC43rO/lVwYLUxpPY4Nw9jVIt3oNG6LjsKVmOTBTz9jO6fnpOMC8Dneb0p6WgSD6x+4DdnOOKPTAlupdQsYBZA586dBx08ePBcandstRVQsA8K9xmfq4qMYYi15caqfT7B4NvKGLLYrq9xunpCV0hKbgr3r7qf8rpynhz+JJd0vcTEF+N6LFZNWXVdY5gfPpZL9q6NFBxZzYrQbeR71XFBfhj9S8IJbh1BZOfO9OjWFf+23YwLw96BZr8E+6oqMkZbHdsFpTnGtZ7aciO4g9pBYHtj5EuXBJvPUD1acZQHVj/A9rzt3BR3E38Z8Be7r6robKSrxIHlV+XzwKoHSM5NNsa8Dr7fefa4dEI/Zv3IY+sew1158OTQ5xnRYTjeHm52X1VP/FGtpZYXNr/Ap+mfMjRiKC+NfumsJji5urMJ7qa8e7cA0UqpKKWUF3At8N/zKbAlC/MN418T/sX0ntP5MPVDbv3uVvIq88wuy+VYrBZeS36Ne3+6l
8igSJZc9injuibi7+0hoW0SL3cv/mf4//D0iKf5JfcXrll6DTvzd5pdllM64ztYa10P3AV8B6QCi7XWu+xdmCvzdPPkr/F/5cXEF0ktTGXa0mkkHZUzFFsprC5k9orZ/GvHv7gq+irev/h92gW0M7ss0WBy9GT+c8l/cMONGd/OYHHaYuwxn8SVOdYEnBYovSid+1fdT3ZZNvcOvJcbe98oY17Pw7a8bTyw6gGKa4p5bOhjTI6ebHZJ4hSKq4t5ZN0jrM9Zz+XdLufxYY87zt6uJrB1V4mwo5iQGD6+9GMu7Hwhr2x9hftW3UdpbemZ7yh+Q2vNwtSFzFw+E083T/5z8X8ktB1cK59WvD32be7sfydf7/ua65ddz/6S/We+o5DgdgQBXgG8MvoVHhr8EKsPreaar69hV770RjVVaW0p9626jxc2v8DIDiP55LJP6Bna0+yyRBO4KTfu6HcH8y+aT0FVAdcuvZZl+5eZXZbDk+B2EEopZvSewXsT36Ne1zP92+ksTF0ofX9nsDN/J9O+nsbqQ6t5cPCDvD7mdYK8Tr6yonBcIzqM4NPLPqVn6548svYRntzwJFX1VWaX5bAkuB1M/zb9WXLZEhLaJ/DC5hf4y8q/UFRdZHZZDseqrby38z1u+OYGLNrC/138f3J9wMm19W/LggkLuCXuFj7P+Jxrl15LWmGa2WU5JAluBxTsHcwbF77BX4f8lfWH13PVf6/i5yM/m12Ww8irzOP2H25n7ta5jOk8hiWXLaFfeD+zyxI24OHmwb2D7mX+uPmU1pZy/bLrWZS6SM48f0eC20EppZjeazofXfoRAV4BzPp+FnO2zKHGUmN2aab6PvN7pvx3Cr/k/sITw5/gldGvEOwdbHZZwsZGtB/BksuWEN8unuc3P88dK+4gtzLX7LIchgwHdAJV9VW8kvQKn6R9Qrfgbvwj8R/0Du1tdlnNqrS2lOc2Pcey/cvoHdqb50Y+R9dWsjmtq9Na80naJ7yS9Ape7l78fdjfmRA5wSW7xJx3IwVxWhtyNvD3DX+nsKqQm+JuYna/2Xi7u/4qgyuzVvKPn/9BQXUBs/vO5ta+t8oyAS3MgZIDPLbuMXbk72Bs57E8PuxxwnzDzC7LpiS4XVhJTQkvbXmJ/+77L5FBkTw54kkGtR1kdll2kV+VzwubX+C7zO+ICYnh6RFP0zusZZ1piF/VW+t5f9f7vJ3yNj4ePjw05CGu6HaFy7S+JbhbgA05G3j656fJKc9hSvQU7hl4j8ss2GOxWlicvpg3fnmD6vpqbu93OzfF3SStbAEYre8nNjzBL7m/MLDNQB4b9phLLJMswd1CVNZV8nbK2yxMXYivpy939b+LabHTnHq5zJTcFJ7b9ByphakMbTeUR4c+Stdg6csWv2XVVr7I+IJ5yfMoqy3j+p7XM7vvbKe+UC3B3cLsK97H85ufZ9ORTUQFR3HPwHu4sNOFTnUKmVmSyWvJr7EiawVt/Nrw8JCHGd9lvFO9BtH8iquLmZc8j88zPifIO4jZfWdzbey1eNpys/FmIsHdAmmtWZm1knnJ88gszaRfeD/uHnA38RHxDh1+OeU5LNixgM8zPsfb3ZuZcTO5sdeN+Hn6mV2acCJphWm8kvQKG49spGNAR2b1ncWkbpOcqntNgrsFq7fW8+XeL3k75W3yqvLoG96X2/rcxuiOox0qwDNLMlmwcwFL9y0FBVOjp3J7v9sJ9Q01uzThxDbkbGBe8jxSC1Np79+eW/rcwuXdLsfHw/E35ZbgFtRYavgy40ve2/UeOeU5RAVHcU3sNVzW7TLT1vKwWC2sP7yeRXsWsT5nPd7u3kyNmcrM3jOJ8I8wpSbherTWrM1Zyzvb3mF7/naCvYO5Kvoqrom9hvYB7c0u75QkuEWjOmsdyw8s5+M9H7M9fzu+Hr6M6zKOiZETGdZ+WLOcSmYUZfDtgW/55sA35JTnEO4bztSYqUyLneZyY3GF49Bak3QsiUWpi1h5aCUAQ
yOGMqnbJMZ2Hou/p7/JFf6WBLc4qd0Fu1mctpjvM7+nrK6MVt6tuKDTBQxvN5yh7YbarJui1lJLcm4yG3I2sDZnLXuL9+Km3BgaMZQpMVMY23msU/U9Cud3pPwISzKWsGz/MnLKc/Bx9yG+XTwjO4xkZIeRdArsZHaJEtzi9GottazPWc/yzOWsy1nXuHFDt+Bu9AztSY/WPYgOiaadfzsi/CNOuSuJxWohryqPw+WHySrLIrUglV0Fu0grTKPaUo2HmwcD2gzgos4XMT5yvLSuhem01mzL28Y3B75hbfZassuzAWjj24a4sDjiwuLo2qorHQM60jGw4ylb5VZtpbC6kNzKXA6WHiStMI20ojRqLDX8e8K/z6k2CW7RZBarhdTCVDYe3si2vG2kFqSSW/XbxXwCPAPw8fDB290bDzcPquqrqKqvorKuEou2NB7n6+FLz9Y96RXai6HthhIfES+jQ4RDO1h6kPU569mev51d+bvILM38zfe93b3x9/THz8MPN+VGrbWWWkstpbWl1FvrG4/zcPNobPg8PeLpcxoIIMEtzkt+VT4HSg5wtOIoRyuOUlBdQHV9NTWWGuqt9fh4+ODr4UuAZwAR/hF0COhAh4AOdArshLubu9nlC3HOymrLyCrLIqcsh+zybIpriqmoraCivgKrtuLl5oWXuxeBXoG09WtLW7+2dAzsSNfgruc9dvxsgtt5p9gJuwnzDZNuDdEiBXoF0ju0t8OvvinrcQshhJOR4BZCCCcjwS2EEE5GglsIIZyMBLcQQjgZCW4hhHAyEtxCCOFkJLiFEMLJ2GXmpFIqDzho8we2rzAg3+wimpm85pZBXrNz6KK1Dm/KgXYJbmeklEpq6nRTVyGvuWWQ1+x6pKtECCGcjAS3EEI4GQnuX71rdgEmkNfcMshrdjHSxy2EEE5GWtxCCOFkJLhPQin1gFJKK6VcflFqpdQcpdQepdR2pdQXSqlWZtdkD0qpiUqpNKXUXqXUI2bXY29KqU5KqZ+UUruVUruUUveYXVNzUUq5K6V+UUotNbsWe5Hg/h2lVCdgPJBldi3N5AcgTmvdF0gH/mZyPTanlHIH3gIuBnoB1ymleplbld3VAw9orXsBw4A/t4DXfNw9QKrZRdiTBPcfvQo8DLSIzn+t9fda6+Ob5/0MdDSzHjuJB/ZqrfdrrWuBj4ErTK7JrrTWR7TWyQ3/LsMIsg7mVmV/SqmOwKXAv8yuxZ4kuE+glLoCyNFabzO7FpPcDHxrdhF20AE4dMLtbFpAiB2nlIoEBgCbzK2kWczDaHhZzS7EnlrcnpNKqRVAxEm+9RjwKEY3iUs53WvWWn/VcMxjGKfXC5uzNmFfSqkA4DPgXq11qdn12JNSahKQq7XeqpS6wOx67KnFBbfW+qKTfV0p1QeIArYppcDoMkhWSsVrrY82Y4k2d6rXfJxSaiYwCRirXXN8aA7Q6YTbHRu+5tKUUp4Yob1Qa/252fU0gwTgcqXUJYAPEKSU+lBrPd3kumxOxnGfglIqExistXa2hWrOilJqIjAXGK21zjO7HntQSnlgXHgdixHYW4Drtda7TC3MjpTR+ngfKNRa32t2Pc2tocX9oNZ6ktm12IP0cYs3gUDgB6VUilJqvtkF2VrDxde7gO8wLtItduXQbpAA3ABc2PD/mtLQEhUuQFrcQgjhZKTFLYQQTkaCWwghnIwEtxBCOBkJbiGEcDIS3EII4WQkuIUQwslIcAshhJOR4BZCCCfz//LLJWxJnAQTAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# batch the inference across K=100\n", "xrange_inputs = np.linspace(-5,5,100).reshape((100, 1)) # (k, 1)\n", "targets = np.sin(xrange_inputs)\n", "predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n", "losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss\n", "plt.plot(xrange_inputs, predictions, label='prediction')\n", "plt.plot(xrange_inputs, losses, label='loss')\n", "plt.plot(xrange_inputs, targets, label='target')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import numpy as onp\n", "from jax.experimental import optimizers\n", "from jax.tree_util import tree_multimap # Element-wise manipulation of collections of numpy arrays " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)\n", "opt_state = opt_init(net_params)\n", "\n", "# Define a compiled update step\n", "@jit\n", "def step(i, opt_state, x1, y1):\n", " p = get_params(opt_state)\n", " g = grad(loss)(p, x1, y1)\n", " return opt_update(i, g, opt_state)\n", "\n", "for i in range(100):\n", " opt_state = step(i, opt_state, xrange_inputs, targets)\n", "net_params = get_params(opt_state)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzs3Xd4FNX+x/H32c2m94QESCGFNAi9EwhFmiggHQQVKSqK9Yf96rVeFb2IKF5EQUHpIkW69F4C0kkgQIDQk0AI6bt7fn9spCg9m0yyOa/n2Ydkdnbmk2X3u2dnzpwjpJQoiqIoFYtO6wCKoihK6VPFX1EUpQJSxV9RFKUCUsVfURSlAlLFX1EUpQJSxV9RFKUCUsVfURSlAlLFX1EUpQJSxV9RFKUCstM6wK34+vrKkJAQrWMoiqKUKzt27EiTUla603pltviHhISQkJCgdQxFUZRyRQhx/G7WU4d9FEVRKiBV/BVFUSogVfwVRVEqoDJ7zF9RFNtRWFhIamoqeXl5WkexGY6OjgQGBmIwGO7r8ar4K4pS4lJTU3FzcyMkJAQhhNZxyj0pJenp6aSmphIaGnpf21CHfRRFKXF5eXn4+Piowm8lQgh8fHyK9U1KFX9FUUqFKvzWVdzn0+aK/5V8I58vS+R4erbWURRFUcosmyv+2flGftyYwieLE7WOoiiKDXN1dQXg9OnT9OrV67brjhkzhpycnKu/d+7cmUuXLpVovjuxueLv7+7I8FbhLN1/ls1H0rWOoyhKOWIyme75MVWrVuXXX3+97Tp/L/6LFy/G09PznvdlTTZX/AGGxYdR1cORDxcewGSWWsdRFKUMSElJITo6mgEDBhATE0OvXr3IyckhJCSE119/nfr16zN79myOHDlCp06daNCgAS1btiQx0XIU4dixYzRr1oxatWrxr3/964btxsbGApYPj5EjRxIbG0vt2rX5+uuvGTt2LKdPn6ZNmza0adMGsAxfk5aWBsDo0aOJjY0lNjaWMWPGXN1mTEwMw4YNo2bNmnTo0IHc3FyrPh822dXT0aDnjc4xvDD9T+bsSKVPoyCtIymKUuT93/dz4PRlq26zRlV3/t2l5h3XS0pKYuLEicTFxTF48GC+/fZbAHx8fNi5cycADzzwAOPHjyciIoKtW7fy7LPPsmrVKl588UWGDx/O448/zrhx4266/QkTJpCSksKuXbuws7MjIyMDb29vRo8ezerVq/H19b1h/R07dvDjjz+ydetWpJQ0adKEVq1a4eXlxeHDh5k+fTrff/89ffr0Yc6cOQwcOLCYz9Q1NtnyB+hSuwr1gz0ZtSyJK/lGreMoilIGBAUFERcXB8DAgQPZsGEDAH379gXgypUrbNq0id69e1O3bl2efvppzpw5A8DGjRvp378/AI899thNt79ixQqefvpp7Ows7Wpvb+/b5tmwYQPdu3fHxcUFV1dXevTowfr16wEIDQ2lbt26ADRo0ICUlJRi/OX/ZJMtf7B0g3q3S00eGbeRcauTeb1TtNaRFEWBu2qhl5S/d4/863cXFxcAzGYznp6e7Nq1664eX5IcHByu/qzX661+2MdmW/4AdYM86VEvgIkbjnEyI+fOD1AUxaadOHGCzZs3AzBt2jRatGhxw/3u7u6EhoYye/ZswHIl7e7duwGIi4tjxowZAEydOvWm22/fvj3fffcdRqPlaENGRgYAbm5uZGVl/WP9li1bMm/ePHJycsjOzmbu3Lm0bNnSCn/pndl08Qd4tVMUeiH4dInq+qkoFV1UVBTjxo0jJiaGixcvMnz48H+sM3XqVCZOnEidOnWoWbMm8+fPB+Crr75i3Lhx1KpVi1OnTt10+0OHDiU4OJjatWtTp04dpk2bBsBTTz1Fp06drp7w/Uv9+vUZNGgQjRs3pkmTJgwdOpR69epZ+a++OSFl2ewN07BhQ2mtyVzGrDjEmBWHmfV0MxqH3v4YnKIo1nfw4EFiYmI
0zZCSksLDDz/Mvn37NM1hTTd7XoUQO6SUDe/0WJtv+QM8HR9OFQ9HPli4H7Pq+qkoilIxir+TvZ7XO0Wz79Rl5uxM1TqOoigaCAkJsalWf3FViOIP0LVOVeoGefL5siSyVddPRVEquApT/HU6wbtdanA+K5/xa49oHUdRFEVTFab4A9QP9qJb3apMWHeU1Iuq66eiKBVXhSr+AK8VXez12dIkjZMoiqJop8IV/wBPJ56OD+P33afZcTxD6ziKopSSv4ZgViwqXPEHeLpVOP7uDnzw+wHV9VNRlAqpQhZ/Fwc7XusYze7UTObvvvmVeoqi2CYpJa+++iqxsbHUqlWLmTNnAnDmzBni4+OpW7cusbGxrF+/HpPJxKBBg66u++WXX2qc3npsdmC3O+leL4DJm1P4bEkSHWtWxtm+wj4VilK6lrwBZ/dad5uVa8GDn97Vqr/99hu7du1i9+7dpKWl0ahRI+Lj45k2bRodO3bk7bffxmQykZOTw65duzh16tTV6wO0nn3Lmipkyx+Kun4+XIOzl/P4bu1RreMoilJKNmzYQP/+/dHr9fj7+9OqVSu2b99Oo0aN+PHHH3nvvffYu3cvbm5uhIWFcfToUZ5//nmWLl2Ku7u71vGtxirNXSHEJOBh4LyUMvYm9wvgK6AzkAMMklLutMa+i6NhiDcP167Cd+uO0LdREFU9nbSOpCi27y5b6KUtPj6edevWsWjRIgYNGsQrr7zC448/zu7du1m2bBnjx49n1qxZTJo0SeuoVmGtlv9PQKfb3P8gEFF0ewr4n5X2W2xvPBiNWcJnS9Won4pSEbRs2ZKZM2diMpm4cOEC69ato3Hjxhw/fhx/f3+GDRvG0KFD2blzJ2lpaZjNZnr27MlHH310dbYvW2CVlr+Ucp0QIuQ2q3QDpkjLEKJbhBCeQogqUsoz1tj/vTBLM5fzL5ORl4GLwYUATz+GtQxl3OojDGoeQr1gr9KOpCilosBo5mjGeY5fOkOUX2WCPfzRiYp35Ld79+5s3ryZOnXqIIRg1KhRVK5cmcmTJ/P5559jMBhwdXVlypQpnDp1iieffBKz2QzAJ598onF667HakM5FxX/hLQ77LAQ+lVJuKPp9JfC6lPKWYzZba0hnKSWJGYksS1nGyhMrOZl1EpM0Xb3f1eBKiHsoB4/5EWjXht+f6Vqqs/UoSknYdfIS/1uTzNns85w1ryfHbjfS7gJCn3dtJanHUfjQpHIcT9fvTaxvbIm99svCkM62qDhDOpepLi5CiKewHBYiODi42Ntbl7qOL3d8SfKlZPRCT+PKjWlfrT0+Tj54OXhxueAyRy4dIflSMmb31aTIVfSaO5/Xmz1D4yqNi71/RdHCqsRzPPvrAhx8V2B2PAhC4q2PpLJja6q6BFLJuRInLqVxIvM0p3KOs+b076w9Oxd/xyBeavgsD4U9pBpAFUBpFf9TQNB1vwcWLbuBlHICMAEsLf/73VlKZgqjto9i/an1hLiH8G6zd2kX3A4vx1sf0jmTdYZe08Zw+NIGhiwfQvfq3RnZaCTu9rZzdl+xfb9sPcRHm77EELgRD0cvekZaXsvB7jdvTBUYzUzZepDxCfM47bCeNze8yS/7Z/Fx/L8J9wwv5fRKaSqtA34LgMeFRVMgs6SO96dkptB9QXd2nt/JyIYj+a3rb/SO7H3bwg9Qxa0Ko9u/xuXDr1LHtQcLjiyg+7zurEtdVxIxFcXqPlqxkE/2DsHgvYEeEb1Y1GMhL9Z/8ZaFH8DeTsfQuJpsfO4N/q/mOERaL/alJdJ9fk/G7fyBsjrTn1J8Vin+QojpwGYgSgiRKoQYIoR4RgjxTNEqi4GjQDLwPfCsNfZ7MyEeIbzS4BUWdl/IEzWfwKA33PVjm4T58GDNIHbsasbYVj/i4ejBiJUjmHZwWknFVRSr+GD1L8w4+S+cDQ5MbP8T78e9i5u9210/3sFOz5CW4awb/ibdfMdizIph/N6vGPHHmxSaCksuuKKZCjG
H7704kZ5Du9Fr6VKnKh/3iOK1da+x+uRqhsQO4cX6L6pjoUqZIqXknTVjmX/iB5zNESzo/QP+rsWfp3pP6kUGz/+IfNflBDnVYnrX/+Hh6HHf21MnfEuGmsPXioJ9nBncIpQ5O1M5dDaP0a1H0zuyNxP3TeSdje9glmatIyoKYCn8r6/+mPknfsCpoBGL+vxslcIPUDvQi5VPfkoYQzmRfYAuvz7OlYIrVtm2Ujao4n8Tz7UJx9fVng9+P4Be6Hmn6TsMrzOc+Ufm89+E/2odT1EA+O+28Sw5ORP77JYs6PMNvi4uVt2+h7OBuY+9QHO3V8gwptB3/lPkm/Ktuo/ScunSJb799tsS38+aNWvYtGlTie/HGlTxvwk3RwMjO0SRcPwii/aeQQjB8DrDGRAzgCkHpjB5/2StIyoV3JS9s5ic+C0iux4ze31KZQ/nEtmPTif4X4/HCJVDOJGzlyGLX8BoLn9zYN9r8ZdSXr2w616o4m8DejcMIqaKO58sTiSv0IQQglcbvkr7au35IuELlhxbonVEpYJadmw1n+/4GHNOBBM7f0F1v5LtjqzXCab1H45bdm92Z2zizbUflOj+SsIbb7zBkSNHqFu3Li+//DIPPPAA9evXp1atWsyfPx+AlJQUoqKiePzxx4mNjeXkyZNMnDiRyMhIGjduzLBhwxgxYgQAFy5coGfPnjRq1IhGjRqxceNGUlJSGD9+PF9++SV169Zl/fr1Wv7Jd1SmLvIqS/Q6wTsPx/Do91uZuOEYz7Wpjl6n55OWn5Cem87bG94mxD2EGB91EkspPUcvnuC1ta9jzq/CF/GjaRTiVyr7dXM0MLPfSB76JY2lJ+YSl1yfR6o/cl/b+mzbZyRmWHcsrWjvaF5v/Pot7//000/Zt28fu3btwmg0kpOTg7u7O2lpaTRt2pSuXbsCcPjwYSZPnkzTpk05ffo0H374ITt37sTNzY22bdtSp04dAF588UVefvllWrRowYkTJ+jYsSMHDx7kmWeewdXVlZEjR1r17ysJquV/G83DfelY059xq5M5f9lyWbyD3oExbcbg5ejFyLUj1UkwpdTkFuYx8PfnMJklL9b6iAdrhpTq/oO8nXmp/gsYs8P4YPOHJGWUz3mwpZS89dZb1K5dm3bt2nHq1CnOnTsHQLVq1WjatCkA27Zto1WrVnh7e2MwGOjdu/fVbaxYsYIRI0ZQt25dunbtyuXLl7lypXzVAtXyv4O3OsfQbvRavliexKhelk99L0cvRsWPYsiyIby/+X1GxY9SXUCVEjdg7ttkyRQ6+b/BU83v2JOvRDzaJJTvNj5JvtMXvLLmFWY8POOericAbttCLw1Tp07lwoUL7NixA4PBQEhICHl5lsady12eNDebzWzZsgVHR8eSjFqiVMv/Dqr5uDA4LpTZO1LZdyrz6vIG/g0YUW8ES1OWMvvQbA0TKhXBW8t/5nDucqrbP8znnR/VLIejQc8LreuRdaI/qVmpvL/5fc2y3As3NzeysrIAyMzMxM/PD4PBwOrVqzl+/PhNH9OoUSPWrl3LxYsXMRqNzJkz5+p9HTp04Ouvv776+65du/6xn7JOFf+78Fzb6ng72/PBwgM3XO4+OHYwcVXj+GzbZxzNVLOBKSVjw7EjLEj9GlcZzvTe72v+LbN3gyCqOMbgkf8wy1KWsSxlmaZ57oaPjw9xcXHExsaya9cuEhISqFWrFlOmTCE6OvqmjwkICOCtt96icePGxMXFERISgoeH5UK3sWPHkpCQQO3atalRowbjx48HoEuXLsydO7dcnPBVV/jepalbj/P23H38b0B9HqxV5erytNw0us3rRnXP6vzY6ccKOT66UnJy8o3ET3mCfMNBpnaaSe3KEVpHAmDW9pO8NudPYhpMJsecxtxuc/F2vPUFZuX1Ct8rV67g6uqK0Wike/fuDB48mO7du2sd6yp1hW8p6NswiOjKbvxnyUHyCq/NB+Dr5MtrjV5j5/mdzEqapWFCxRa
NWPAj+fZ76B4yuMwUfoAe9QMI8XGj8Gwfsgqy+M/W/2gdqUS899571K1bl9jYWEJDQ3nkkfvr4VQWqeJ/l+z0Ot55uAYnM3L5cWPKDfd1De9K86rN+XLHl5y5UuqTkyk2asHeQ2y9PBFvu+r8O3641nFuYKfX8WK7CJJPudK28kCWpSxjecpyrWNZ3RdffMGuXbtITExk7Nixmh9ysyZV/O9BXHVf2sVYun5eyLp2mbsQgnebvYtE8sGWD9QwuEqxFRjN/Hvjf9Dp8/i2w6fodXqtI/1D1zoBhFdyYdfeusR4x/DJtk9u2/VZvS+sq7jPpyr+9+jth2LIN5oY/ceNfZwDXAN4sf6LbDi1gT+O/6FROsVWfLlhKUanHXQM7E/NSlFax7kpvU7wUrtIks/n0srnGdJy0/huz3c3XdfR0ZH09HT1AWAlUkrS09OL1dVUnfC9Dx8tPMDEjcdY+HwLala9NsytyWyiz8I+ZBdmM6/bPBztym8fYEU7uQWFNJ3SBaHPYdPApTgbSmbcHmswmyWdx66nwGgmrukqFh79nTnd5hDmEXbDeoWFhaSmpl7tT68Un6OjI4GBgRgMN85ZUi7n8C0vnn8ggjk7U/lw4QGmD2t69TigXqfn9UavM2T5EKYcmMJTtZ/SOKlSHr27eiJmwymerP5umS78YBn47aV2kTzzyw6eMPTByW4Fn237jPHtxt9wfNxgMBAaGqphUuXv1GGf++DhZOCV9pFsOZrB8gPnbrivcRXLJPE/7P2Bc9nnbrEFRbm59JxLLDv9E47GCF5q1lPrOHelY01/alZ1Z+LaCzxTZzibTm9i1clVWsdS7kAV//vUv3EwEX6u/GfxQfKNphvue6XBK5jMJsbsHKNROqW8enXl55hFDi/UexWdrny8PYUQvNI+khMZOdhdaUF1z+p8vv1zNf1jGVc+Xl1l0F9dP4+n5zB5U8oN9wW6BfJEzSdYeHQhey/s1SagUu4czjjG9vSFuBe2YGC9plrHuSdto/2oE+TJuNXHeLHeK5y6ckoNe1LGqeJfDPGRlWgb7cfXK5NJu3LjDEdDag3B29Gbr/78SqN0Snnz2srPkVLPO3Evlbv+5H+1/k9dyuVEaiCNKzfmuz3fkV2YrXU05RZU8S+mtzrHkFtoYvQfh25Y7mJwYWitoWw9s5UtZ7ZolE4pLzae2E1yznoCdR15sEak1nHuS3yELw2reTFuzRGG136ejLwMpuyfonUs5RZU8S+m6n6uDGxajRnbTpB49vIN9/WJ6kNll8qM3TlW9W9WbuvttaOQJif+2/FFraPcNyEEr3SI5NzlfPYc8aBdcDt+2v8T6bnpWkdTbkIVfyt4qV0Ebo4GPvzbqJ8OegeG1xnO3rS9qveDckvT96wm3byH+u49qVm5stZxiqV5uC/Nwnz4ds0RhtV6ljxTHt/v/V7rWMpNqOJvBZ7O9rzcLoKNyemsPHj+hvu6hnclxD2Eb/78BpPZdIstKBWV2WxmdMIYMHow+sHntI5jFa90iCTtSj7r9+voXr07s5JmcTb7rNaxlL9Rxd9KBjStRnglFz5efJACo/nqcjudHc/Ve47kS8ksTVmqYUKlLPp682Ly9EfpFDgAXxdXreNYRaMQb1pG+DJ+7REGRA1GSsmkfZO0jqX8jSr+VmLQ6/jXwzU4lpbNz1tunBmoQ7UOVPeszvd7vscszbfYglLRFBhN/HTge3RmT95v86TWcazq/zpEcTGnkOV7CuhavStzDs3hfM75Oz9QKTWq+FtRmyg/WkVW4qsVh8jILri6XCd0PFX7KY5kHmHF8RUaJlTKklFrF2O0P0r30AE429vWOFB1gzx5INqPCeuO0i9yECZp4sd9P2odS7mOKv5W9q+HYsguMDFmxY1dPztU60CIewgT9kxQPX8UruQbmZX8I3bSndfjntA6Tol4uX0kmbmFLP2zkIfDHmb2odmk5aZpHUspooq/lUX4uzGwSTBTt57g0LlrEznrdXq
G1hpK0sUk1qau1TChUhZ8+McipONh+kQ8hpPBSes4JSI2wIOONf2ZtOEY/SKfxGg2qtZ/GWKV4i+E6CSESBJCJAsh3rjJ/YOEEBeEELuKbkOtsd+y6qV2kbjY6//R9bNzWGcCXAP4bvd3qvVfgWXmFrLo5BQMuPFi48e1jlOiXm4fyZUCI0v+LOShsIeYlTSLjLwMrWMpWKH4CyH0wDjgQaAG0F8IUeMmq86UUtYtuv1Q3P2WZV4u9rzYLpL1h9NYnXTtJJdBZ2BoraHsS9/H5tObNUyoaOnbTWsRzofoWf3RMj9kc3FFV3bnoVpV+HFjCj3DHiPPlMeMxBlax1KwTsu/MZAspTwqpSwAZgDdrLDdcu2xptUI83Xho0UHKTRd6+HTLbwbfs5+TNqvur5VRFJK5hz5GSEdeaGRbR7r/7uX2kWQV2hi6S5J66DWTEucRk5hjtaxKjxrFP8A4OR1v6cWLfu7nkKIPUKIX4UQQTfbkBDiKSFEghAi4cKFC1aIph17Ox1vPxTD0QvZTL2u66dBb2BgzEC2ntnKgfQDGiZUtPD7/n3k2f9Js0oP42bvpnWcUlHdz41udQOYsjmFHmEDyczPZG7yXK1jVXildcL3dyBESlkb+AOYfLOVpJQTpJQNpZQNK1WqVErRSk7baD9aRvjy5YrDXMq51vWzV2QvXAwu/LTvJ+3CKZr4ZudEBDrebjFM6yil6sUHIig0SdbtdaWeXz0m759MoVmN968laxT/U8D1LfnAomVXSSnTpZR/jXn8A9DACvst84QQ/OuhGmTlFTJmxeGry93s3egd2Ztlx5eRmpWqYUKlNB26cJbTpnWEOrUk2KOq1nFKVYivCz3rBzB16wkeCR3ImewzLEtZpnWsCs0axX87ECGECBVC2AP9gAXXryCEqHLdr12Bg1bYb7kQVdmNR5sE8/OW4ySfv3J1+YCYAeiEjp8P/KxhOqU0fbT+B4SukNeaPK11FE083zYCs1my82Blwj3CmbRvkur1pqFiF38ppREYASzDUtRnSSn3CyE+EEJ0LVrtBSHEfiHEbuAFYFBx91uevNwuEieDnlFLE68uq+xSmc6hnZmbPJdLeZc0TKeUhsv52fx5aSEesi4tQmK1jqOJIG9n+jQKYmZCKl1D+3P44mE2n1G93rRilWP+UsrFUspIKWW4lPLjomXvSikXFP38ppSyppSyjpSyjZQy8fZbtC0+rg480yqM5QfOkZByrY/zoJqDyDXmMjNppobplNLwxcbpoM9mYIxt9+u/kxFtqiMQHDxcHR9HH3458IvWkSosdYVvKRncIhQ/Nwc+WZJ49atuhFcEzas2Z2bSTDXZtQ2TUrLo+Cz0hYEMbfiA1nE0VdXTiUebBPPbznN0Cu7B+lPrOZp5VOtYFZIq/qXE2d6Ol9tHsuP4RZYfOHd1+cCYgVzIvcCy4+rkl62atmcFBboztAvojZ1eveWebR2OnU6QeqIu9jp7ph6YqnWkCkm9EktR7waBhFdy4bOliRiLLvyKC4gjxD2EXw78ok5+2ajv9/yENLrxRnxfraOUCX7ujjzWtBqLd2XRsmoHFhxZQGZ+ptaxKhxV/EuRnV7H652iOXohm1kJli6eOqFjYMxA9qfvZ/eF3RonVKwt4VQi6eY91HDthK+Li9ZxyoxnWofjaNBz+Wwz8kx5zD40W+tIFY4q/qWsfQ1/GlbzYsyKQ+QWWKZ17BLeBTd7N9Xt0wZ9tvkHpNmON+Jsa7KW4vJ1deCJ5iGs3qentk8jpidOVxd9lTJV/EuZEILXH4zmfFY+kzYeA8DZ4EyviF6sOLGC01dOa5xQsZYL2RdJzFqFr2hG/cCbjmhSoT3VMgwXeztMF+M4n3OelcdXah2pQlHFXwONQrx5INqP8WuPXB32oX90fwRCjXhoQ0ZtnAK6QobVfkzrKGWSl4s9g+NC2LLfD3+nAKYnTtc6UoWiir9GXu0UxZV8I9+uOQJAFdcqtA1uy2/Jv5FnzNM
4nVJcJrOJFafmYSioTv+6TbWOU2YNaRmGu6M9Drkt2Xl+J4kZFeoSIE2p4q+R6MrudK8XwE+bUjh9KReAflH9yMzPZMmxJRqnU4pryu6lGHVpdAzugU4ntI5TZnk4GRjWMoz9SVE46BxV678UqeKvoVfaR4KEL/+wzPfbqHIjwj3CmZ44XXX7LOd+2jcVafTgtZY9tY5S5g2KC8HT0R03UxMWHV2khjspJar4ayjQy5nHmlVjzs5UDp/LQghB/+j+HMw4yJ60PVrHU+7TjtNJZJj3UtO1I17Otj1TlzW4ORp4Oj6c48fqkW/KV2P9lxJV/DX2XJvquNjbMWpZEmDp9ulqcFVff8uxzzf/iJR6Xo+rGDN1WcMTzavhZReMizmKmUkzMZlNWkeyear4a8zbxZ6n4sP448A5dhzPwNngTLfq3ViWsoy03DSt4yn36FLuFfZfXom3bET9wGCt45QbzvZ2DG8dTtrphpy6cop1qeu0jmTzVPEvA4a0DMXX1YFPiwZ96xvVF6PZyJxDc7SOptyjzzdNA10eg2If1TpKuTOwaTW8qIed9FQj3ZYCVfzLAGd7O15sF8H2lIusSjxPqEcoTas0Zfah2RjNRq3jKXdJSsnSE3PRFwbwRP1WWscpdxwNeka0iSI7rSEbT2/k5OWTd36Qct9U8S8j+jUKIsTHmVFLkzCZJf2i+nEu55z6+luOzNm/gQJdKm2qPoJejd55X/o1DsLL1BKkjlmHZmkdx6apV2gZYdDrGNkxiqRzWcz78xStglrh5+zHrCT1BigvJuz6BWly5I14dcjnfjnY6XmhdQMKs2oyK2mOuuCxBKniX4Z0jq1CrQAPRv9xCKNJ0Cuil/r6W04kp5/ldOFWwpzi8Xd11zpOuda7QRCexlbkGLPUJO8lSBX/MkSnE7zxYDSnLuXyy5bj9IjogV7o1dffcuDT9ZMROhMvN1bdO4vL3k7Hyy0exJRfie93qYleSooq/mVMXHVfWkb4Mm51Mk56b9oEtWFe8jzyTflaR1NuIa+wkG3pi3GRkbQJr611HJvQs34gbgXxHM8+yP4L+7WOY5NU8S+DXu8UzcWcQiasPUrf6L5cyr/E8pTlWsdSbuHbLYtOMBtoAAAgAElEQVSRdhn0DO+tdRSbYafX8WLTfkizgdFbp2gdxybZZvHPuwzleGyc2AAPutSpysQNxwh1rkOIe4jq91yGzT48G2Fy4/mmPbSOYlP61o/CqaAh29NWkJmXpXWc0pN3GUwlP7GN7RX/9CPwTUPYU76Pk/9f+0gKTWa+Xp1Mr8he7L6wm6SMJK1jKX+z9sghsnR7aODTEUeDvdZxbIpeJxhcuz9SFPDZhmlaxykdUsLcZ2ByVzCbS3RXtlf8vULAKxQWj4SLx7VOc99CfF3o3ziY6dtOUterPfY6ezXPaRn05ZZfAHituTrRWxKeatwKgzGIxcd/o9BYAcb72fETJC2C6IdAV7Ll2faKv04PPSZYfv7tKTCV3ytkn3+gOg52OiasOUvHkI4sPLqQnMIcrWMpRc5n5ZCctxJ/u9rEVArROo5N0ut1dA3ricnuNOM2r9I6Tsm6cAiWvgnhbaHpsyW+O9sr/gBe1aDzF3ByC2z4Uus0983PzZGhLUJZtOcM9b0eJLswW030UoaMWjcXYXeZJ2v31zqKTRsZ1xchHfl5/wwKTSV7KEQzxnyYMwTsneGR/5V4qx9stfgD1O4DsT1hzSdwfLPWae7bsPgwvF3s+W2LHdU9q6s+/2WE0WRmReo8DNKTfrEdtY5j01ztXWnu3558h538vO2g1nGsT0pY9jac3QNdvwG3yqWyW9st/kLAQ6Mt3wJ+6QFJS7VOdF/cHA2MaFOdTckZNPDqzIH0A+xPU/2etTZz126MDkm0CeiKnc5O6zg276XGjyN0Rr7dPoN8Wzr2byqEec/C9u+h2QiI7lxqu7ZK8RdCdBJCJAkhkoUQb9zkfgchxMyi+7cKIUKssd87cvKEwcv
ANxJm9LecTCmHBjQNJtDLiU27q+God1St/zJgyr4ZCOCVJgO1jlIhRPtEE+pagxzHjczcdkLrONaRnwXT+sDuadD6LejwUanuvthNFiGEHhgHtAdSge1CiAVSygPXrTYEuCilrC6E6Ad8BvQt7r7viqsfDFoEswfB7y/Cia0QPxJ8wou3XWM+XD4NmamQdRbyLllu+VlgNl27zsDgBPYuYO8Kju7g6GG5OfuASyVw9Lzj8T0HOz2vtI/klVm7aRPSmiXHljCy4Ujc7N2K9zco9+V8VjanjOsIdGpAgHsVreNUGINr9+edTe8wdtMy+jQahqNBX/yNFuZB7kUoyIaCLMvvfxE6y3vXwRUc3MHJy3JEobjMJtj3G6z9FDKOQbdxUK/0GxHW+L7aGEiWUh4FEELMALoB1xf/bsB7RT//CnwjhBCytGYpd3CF/tNh5QewbQLsmQGxvaDuo1ClDjh73/xxJiNcToW0ZEhPhvTDRf8esRR9bhJf7wA6O8sLBwmFOSBvc5JKZwcufuBeBdyqWI73uVUB96rg6m/58HL1p1ttfyasc+NQck1yvZey6Ogi+kX3u/fnwlRoeaEX5kBhruVDzFRQdFHJdX+Pzg7sHMDOEQzOlufQ4FIqJ6LKuq82/4awu8LA2NJpv5Q7UhY1gExgNlp+FzpLTzyhL/r33otox9CO/Gfrp1y238DUrR0Z0iL0xhVMRsi/bCnmeZcs/+YWNcpyL0F2GmSdsdyunIPsdEvBv1t6e8t70q0yeAaDZ7WiruXVLD97BILecPPHmk1wIRFObIYt/7PUkUoxMHAOhLe55+fCGkRx668QohfQSUo5tOj3x4AmUsoR162zr2id1KLfjxStc8t5Chs2bCgTEhKKle2mss7B5q9h+yQozLYs86wG7gFFL0gBBVeKXiDnuaEgOriDT3XLzTsMPIMsj3OvamkVOHpYCub1pLQU2IIrlhdmXqblRZmTAdkXLLesc5B1Gi4XvTDzLt00utHOhTOFTrwY5ILQ6/lVVEXoHSxvrL8+bEyFljecqcBS3P+6FVyxFH1zMa8ctHcDBzfLtxiDc9GHg6PlwwJx7Tm89gQUfQu67t8bFK371+OuLwq3em3+Yx/X7euuHls8vbOSuWBnZIVjFHbiXj4Mb/Ves0Jr8vp9SHmtwXHD83qH5+yWuf72vJoLwVgAxjzLrTAHCnLAmHtt+Z22a+dgaSj91cCwc7h20zuA3s7ymvrrdVWU4xPjGWbJTL4+4UBzXwd0xqJ951+25Lgdg8u1Rparn6XR5eIDTt6W17S9q+W1/Nf+zKai980Vy/v2yrlr79VLJywNwOsnWxI6cPa1fKN39rY856ZCy/sv7fC1euNfC1q9CtFdSqQxJYTYIaVseKf1ytSZKiHEU8BTAMHBJTT/qZu/5dha/Ktwagec2Q2nd0FO+rXi5OILlWtZXiQegeAbYSn4LpXuvXgIYXlBGRwt270bBTnXPnyyz1v+zclAn3uRo3sOUy/zLLN8L7PHlEudwtwb3+x6w7VWu5N30b6drx16sne2vAkMTkXF297SotHbX/vbJEVv8LyiD65syy0/q+iNcBnyMy1fkY15RYe6jDfm+PtzcH0BurqfvwrE3z4gbniO//58y1sX9qv7usNjb/V/eLv7ihw2FZDoUsDgbAN2hWduu+6d83HdB6IVPwCE7rpiX/R/csNz9rf93exv/vtzfP06OoPl9eXgZmkJG5wsrys7x2s3veFaS1+IogxmS0E1FVz7xmnML7pd98FhKrD8bM65sbgCvfV2TDPAchcIMPsQWjXasm+HokOqDu6Wc32Onv/81+BYzCf2b0xGuHwKLh23XFB66XhRgy7NchPC8jzYu0BwU6haHwLqW2qJNQ4fFZM1iv8pIOi63wOLlt1snVQhhB3gAaT/fUNSygnABLC0/K2Q7dYcPSwXU4S3LdHd3Bd7Z8s5ib+dlxCAW82LTBy/Gu9KnzA7sjl1WpTuSaKK7t1f30Z
eWUj7nvPBP+jOD1CsqjpQf8kTLJanWHDiWTYMa4erg0ZtWL2d5ZCPVzUIvfPqZY01vnNsByKEEKFCCHugH7Dgb+ssAP66/r0XsKrUjvfbmPrBXnSMCaYgsy5Lji0lMz9T60gVRp4xj/1ZK/GQdYlVhV8zvSJ7kS/OkUUSP208pnWccqvYxV9KaQRGAMuAg8AsKeV+IcQHQoiuRatNBHyEEMnAK8A/uoMqd+/VjlHkpTeiwJzPwqMLtY5TYfywcz5Sl02XUDV6p5Y6hHTAw8GDwGq7mLDuKJm5JT8Cpi2yytkGKeViKWWklDJcSvlx0bJ3pZQLin7Ok1L2llJWl1I2/qtnkHJ/qvu50TO2Kea8IKYdnIn6ElU6Zif9iizwYXiTTlpHqdAc9A50C+9GpviTrMKLTNqgWv/3Q/XbK6deah+BObMpJ7KOsfP8Tq3j2Ly955PIMCcS5dIeDyeHOz9AKVG9InthkiZqRiUxacMxLuUUaB2p3FHFv5yq4uHEo7FdkCZHflDznJa4L7dOQUo9IxqpQdzKglCPUJpUbkKu40auFBQwYZ06mHCvVPEvx15oHYu40pCNZ1ZzMe+i1nFsVq4xlx3pf+BUUJfW1cthtw4b1SuqF+dzz9C0Zho/bUoh/Yqa5/peqOJfjnk4G+gX3QcpjHy9rYLMdKSBn3bPwyxyeTikJ6IM9M9WLB4IegBvR2+cfRLIKzTxnWr93xNV/Mu5/2sTjy4/jHlH5mAy29Boh2XIjIOzMOf7MaJ5B62jKNcx6A30iOhBwoUNdKrjxJTNKZzPyrvj4xQLVfzLOUeDni6hPSjUXeDbrcu0jmNzdp/fT4YpmQin9vi4qhO9ZU3PiJ5IKQkI3kOhSfLt6iNaRyo3VPG3AW+26oMwuzB57zSMtjrTkUa+2jYFabZjeMM+WkdRbiLQLZC4gDhWpf5O93r+TNt6gjOZuVrHKhdU8bcBLvZOxFd5kDz7vfy0dbfWcWzGlYIrJKStxCG/Ae2j1InesqpvVF/O556nQcwZJJJvViVrHalcUMXfRrzW/EmEMPPtjmnkFapj/9bww65fkSKfbmG90OnUid6yqmVAS6q4VGHl6Xn0aRjErISTpF68wwifiir+tiLYPZiaXo3Id9rEpI2q5VNcUkpmJc3AnBfAiLgyOPifcpVep6dXZC+2ntnKI40NCCH4eqV6D9yJKv425Om6j6EzXGb89t/JzFHjnRTHhtRtZJlPEev2IN4u9lrHUe6gR0QP7IQdq08v4NHGwfy6M5WUtGytY5VpqvjbkPjAeHwd/Sl02cj/1qpeD8UxdvsUpMmRl5upE73lga+TLw9Ue4D5yfMZ3DIAg14wdtVhrWOVaar42xC9Tk//mD7YuSTz47atnM1UfZ7vR1pOGomXN+JhbE6TkMpax1HuUt+ovlwuuMyOtNU83iyEeX+eIvn8Fa1jlVmq+NuYv77+6jy28NXKQ1rHKZe+2jYVhIkBNfqpK3rLkYb+DQn3CGdG0gyeahmKo0HPVytV6/9WVPG3Mb5OvrSv1h4n7z+ZmXBEtXzukdFsZMnx3yA3gicbN9Y6jnIPhBD0i+7HgfQDnM47xJNxISzcc5rEs5e1jlYmqeJvg/pF96NQZuPsvYcvliVpHadcmZO4nHwyiPd/BCd7vdZxlHvUJbwLLgYXpidOZ1jLMFzt7fjyD/UN+GZU8bdB9fzqEeUVhU+V7Szdf4adJ9SIn3fr+11TMBd68nprNVtXeeRicKFreFeWpSzDJLIY0jKUZfvPse+Umu7071Txt0FCCB6NeZQM43G8fVL5bEmimu3rLuy7kMS5wv2EO3Qg2MtV6zjKfeoX3Y9CcyG/Hf6NwS1C8XQ2MFq1/v9BFX8b9WDog7jbuxMWtoutxzJYc+iC1pHKvM82TUKa7Xi56QCtoyjFEOYRRtMqTZmZNBNne8FT8WGsSjzPjuPqG/D1VPG3UU52TvSM6Ely9hYCffMZtTQJs1m1/m8lM+8yuy6uwM3
USE3YYgP6R/fnXM451pxcwxPNQvBxsefzZeob8PVU8bdhfaL6YJZm6tVM5OCZyyzYfVrrSGXWl1umgShgYMyjqnunDWgV2IqqLlWZenAqLg52vNQugi1HM1i2/6zW0coMVfxtWKBbIK2CWrE7cyk1qjrxxfIk8o1q0Le/M0szvx+bjcgPYViTeK3jKFag1+npF92PhHMJJGYk0r9xMNGV3fho0UE18GERVfxt3ICYAVzMv0ir+qdIvZjLtK0ntI5U5nyfsIgC3Xk6BfXG3k69JWxFj4geONk58cuBX7DT63j34RqkXszlh/VqukdQxd/mNanchAivCLZlzKNZuDdfr0rmSr5R61hlhpSSSXungNGTd9v21TqOYkUeDh50De/K4mOLSctNo3l1XzrVrMy41UfU0Ceo4m/zhBA8FvMYhy4eomuTHDKyC/heTXR91aw9W8nRJ9LS/xFcHdQ0jbZmQMwACs2FzD40G4C3OsdgkpJPlxzUOJn2VPGvADqHdcbb0ZuNF+byUK0qfL/+KBey8rWOpTkpJV8n/Ahme95rM1jrOEoJCPUIpUVAC2YmzqTAVECwjzPDWoYyb9dpdhzP0DqeplTxrwAc9A70ierD2tS19GvuSL7RzDdquFuWHDjMJd1W6nm1w8/FS+s4Sgl5LOYx0vPSWZqyFIBnW1fH392B938/UKG7P6viX0H0jeqLQWdg3bm59G0UxLRtJzieXrEnu/hi808InYm345/SOopSgppVbUa4RzhT9k9BSomLgx1vPBjNntRMft2ZqnU8zajiX0H4OvnSObQz84/MZ3BLP+x0Ov67vOJe8r73dBrnxCqqOTUkyjtc6zhKCRJC8HjNx0m6mMTmM5sBeKRuAPWDPRm1NImsvIo5612xir8QwlsI8YcQ4nDRvzf97iyEMAkhdhXdFhRnn8r9e6zGY+Qac1l1ej6DW4SwYPfpCjvg1agNP6Ozy+aVRqrVXxE8HPYwvk6+/LTvJ8DygfBe15qkXcnnm1UVc77f4rb83wBWSikjgJVFv99MrpSybtGtazH3qdynKO8oWgS0YOrBqQyKC8DT2cCoCjjkc1ZePrsuz8ddhNMmpKnWcZRSYK+3Z0DMADaf2UxiRiIAtQM96d0gkEkbj3GsAs73W9zi3w2YXPTzZOCRYm5PKWGDYweTkZfB6lNLeK51ddYdusCm5DStY5WqLzbMAUM6j8UMUkM5VCB9ovrgbOfMj/t+vLrs1U5RONjp+WjhAQ2TaaO4xd9fSnmm6OezgP8t1nMUQiQIIbYIIdQHhIYa+jeklm8tftr/E482CaSqhyOfLa04A15JKVl4Yhp6kx9D6z+sdRylFLnbu9MrshfLUpZx+oplnCs/N0deeKA6KxPPsybpvMYJS9cdi78QYoUQYt9Nbt2uX09aqsetKkg1KWVD4FFgjBDipmfYhBBPFX1IJFy4oIYgLglCCAbHDuZk1kk2nFnNy+0j2Z2ayZJ9FWPAqxl7V1OgP84DVXpjp7fTOo5Syh6r8RgCwc8Hfr66bFDzUEJ9Xfhw4QEKTWYN05WuOxZ/KWU7KWXsTW7zgXNCiCoARf/e9KNTSnmq6N+jwBqg3i3WmyClbCilbFipUqX7/JOUO2kT1IZq7tWYtG8S3esFEOnvyhfLkirEC3/Cnh+QRlfejB+odRRFA5VdKvNg6IPMOTyHjDzLRV72djr+9VAMRy5kM2XzcY0Tlp7iHvZZADxR9PMTwPy/ryCE8BJCOBT97AvEARXvAFsZotfpGVRzEAfSD7Dt7BZe6xjN0bRsZiWc1Dpaidp4Yidppr3EuDyEr4uaqauiGlp7KHnGvBta/22j/WgVWYkxKw6RdqViXP1e3OL/KdBeCHEYaFf0O0KIhkKIH4rWiQEShBC7gdXAp1JKVfw11jW8K/7O/ozfM5620ZVoWM2Lr1YcJrfAdoe7/c/mb5AmJ96NH6Z1FEVDYR5hdAjpwPTE6WTmW7o6CyF45+Ea5BaYKsz1L8U
q/lLKdCnlA1LKiKLDQxlFyxOklEOLft4kpawlpaxT9O9EawRXisdeb8+QWkP48/yfbD+3nTcejOZ8Vj6TNh7TOlqJ2H1+PyfythOo60itqrfql6BUFMNqDSO7MJtpB6ddXVbdz5UnmocwY/uJCnH9i7rCtwLrEdEDPyc//rf7fzQM8aZdjB/j1xzhYnaB1tGs7qMNXyNNDoxsNkTrKEoZEOUdRZugNvx88GeuFFy5uvyFByLwdrbng98P2HwPOFX8KzAHvQODaw1mx7kdbD+7nVc7RpNdYOTbNbZ1xeOhjEMkZm3Eo7AtD0RW0zqOUkY8XftpsgqymJE04+oyDycDIztGsS0lg4V7ztzm0eWfKv4VXM+Invg6+TJ+93iiKrvRo34gkzcf59SlXK2jWc3Hm75BmuwZXu9JdVGXclVN35rEBcQxef/kG1r/fRoGUbOqO58sPmjT58BU8a/gHO0cebLmk2w7u43tZ7fzcvtIAL78wzZOeiVlJLEzbQ2GnHj61I/SOo5SxoyoO4JL+Zdu6Pmj1wn+3aUmpzPzGL/2iIbpSpYq/gq9o3rj5+THmJ1jqOrhyBPNqvHbzlQOncvSOlqxvb/hC6TZgcE1n1Tz8yr/EOsbS7vgdkw+MJmLeRevLm8c6s3Dtaswfu0Rm/oWfD31blBwsnNieN3h7Lmwh1UnV/Fs6+q42Nsxamn5HvQt4ewO9l7cgmN2O4a1qKl1HKWMer7e8+Qac5m498aOiG92jkEI+M9i25zyURV/BYBHqj9CiHsIY3eOxc1JxzOtw1lx8BwJKeVzqjspJf9ePwpzoRuvNh+Co0GvdSSljArzDKNLWBemJ07nbPa1YU4CPJ14plU4i/acYevRdA0TlgxV/BUA7HR2vFD/BY5mHuX3I7/zZFwIfm4OfLqkfA76turEGk7kHMCn8CF611OTtSi3N7zucMyY+W7Pdzcsfzo+nABPJ977/QAmG5vyURV/5ap2we2o5VuLcbvGodMZebFdBAnHL7LyYPka7dBkNvHhpv9iLvDhg7aD0elUDx/l9gJcA+gb1Ze5h+eSfPFaV2cnez1vdo7m4JnLzNxuW8OfqOKvXCWE4KX6L3Eu5xy/HPyFPg2DCPV1YdSyxHLV6pm8bwbpBccJ1feidVQVreMo5cQztZ/BxeDCqO2jbvi2+1CtKjQO9eaL5Ulk5trOlI+q+Cs3aFylMW2D2jJhzwTS884zskMUh85dYe6fp7SOdlcy8zP55s9vMOeEMbrz41rHUcoRT0dPnq37LJvPbGbNyTVXlwsh+HeXGlzMKeCrFYe1C2hlqvgr//Bqo1cxSzOjE0bTuVZl6gR6MHp5EnmFZf+Cl7fWfEGBzKZ7teeIrOyudRylnOkT1Ydwj3A+T/icAtO1YU5qVvWgX6NgpmxOIfl8+e8CDar4KzcR6BbIk7FPsiRlCQnnEni9UzSnM/P4ZUvZHut83/kk1p1ZgFNeHP9q/4DWcZRyyKAz8Fqj1ziZdZJfDv5yw30jO0TiZK/ng4UHy2UniL9TxV+5qcGxg6niUoVPtn1C4zBPWkb48s3qZC7nlc1jnlJKXlrxHtLswCdtXlVdO5X71jygOa0DW/Pd7u9u6Prp4+rAS+0iWXfoQrnrBHEzqvgrN+Vk58SrjV7l8MXDTDs4jdc7RXMpp5Dvyujl7t9sn8G5wn3UdOpDu6gwreMo5dxrjV9DIvlwy4c3tPIfb1aN8EoufLToAPnGsn8Y9HZU8VduqV1wO+ID4/n6z69xd8uka52qTNxwjPOX87SOdoPUy2f5fv+X6PLDGN/tea3jKDYgyC2IEXVHsC51HUuOLbm63KDX8W6XmqSk5/DTxhTtAlqBKv7KLQkheLfpuxh0Bt7Z+A4vt6+O0ST5amXZ6fEgpWTIwjcxU8i7Td/Dy9lB60iKjRgQM4BavrX4dNunN4z70yqyEu1i/Ph6VTLns8pWQ+heqOKv3Ja/iz+vNX6Nned3sunCAh5tEsyM7Sc5euH
KnR9cCsZunc3pwgRqOfehZ+16WsdRbIhep+f95u+TVZjFZ9s/u+G+tx+qQb7RxOflePwrVfyVO+oW3o2WAS0Zs2MMPRo74mCnKxPznB5OO8XEA6OxKwxmQrdXtI6j2KAIrwiG1RrGoqOLWHps6dXlob4uDI4LZfaOVHafvKRhwvunir9yR0II/t3s3xh0Bj7Z8TaDWgSyaO8Z9qRq96LPLshn4IIXMIsCPmrxEW6O6nCPUjKG1R5GnUp1eG/ze5y4fOLq8hFtq+Pr6sB7v+8vl10/VfFX7oq/iz8ftfiIA+kHyHL+FR8Xez5cqM08p1JK+s36gBz9IXqFvMhD0epwj1JyDDoDn8d/jl7oGbl2JPmmfADcHA281imKP09cYt6u8nEF/PVU8VfuWtvgtgyOHcy8I3Po1DSV7SkXNXnRv7poFimmBUQ5P8B7bZ4o9f0rFU8V1yp83OJjDmYc5IvtX1xd3qt+ILUDPfh0SSLZ+UYNE947VfyVe/J8vedpVLkRy899S3S1K/xncSJZpXjh139Xb2LJudG4iiB+fuSTUtuvorQOas3jNR5nRtIMfjv8GwC6oikfz13O539ryuY1MLeiir9yT+x0doyKH4W7vTs5Xt+Rnne2VAa7klLy6fLtTEp+C3s7wZSHv8HJ4FTi+1WU673U4CXiqsbxweYPWJe6DoAG1bzoXi+ACeuPciI9R+OEd08Vf+We+Tr58r/2/8Mo8/CLmMyPW/eW6Hy/Uko+XLSTKcf+hZ19FpM6jSfCW13Fq5Q+g87Af1v/l0ivSEauHcm+tH0AvN4pGr0Q5WrKR1X8lfsS6RXJtw98i0l3CeegSYyYvqlExv0xmyVvzt3J9BMfYOd4jq/bfkk9/7pW34+i3C0XgwvftvsWb0dvnlv5HIcvHqayhyPPtQln6f6zbEpO0zriXVHFX7lvdf3qMqbNGPSO5znl9CVPTF5h1WGfjSYzz8/awIJz/8bO5SgfxX1AfFC81bavKPfL18mX8e3Goxd6nlj6BAlnExjaMoxALyfe//0ARpNZ64h3pIq/UixxAXGMbfsVjs4ZHNJ/xJBp86zyws83mhg8dQlrst7B3vk0X7T6gq7Vu1ohsaJYR4hHCL90/gUfRx+e/uNp1p9exb8eiiHpXBbTtp248wY0poq/UmzxgfFMe+hnPJzs2W3+mEenf0d2/v0fArqYXcAjk75nR+EHODsW8NODk+gY0tGKiRXFOqq6VuXnB3+mhk8N/m/N/3EwbwZNw9357/JDXMwuuPMGbuJk1kl2nNth5aT/VKziL4ToLYTYL4QwCyEa3ma9TkKIJCFEshDijeLsUymboryjmN99FpUdQzlo+pZWUway/ti9n/zamXqcdlOHkOowjiqufvzabTp1/dQxfqXs8nT05PsO39M9ojuT9k/iotenZItkvlxxb0Og5BTmMHbnWB6Z9wgfbP6gxC+gFMXZgRAiBjAD3wEjpZQJN1lHDxwC2gOpwHagv5TywO223bBhQ5mQ8I/NKWWc0Wzk4/U/8OvRH5CYaF6pKyObDyDSO/K2jzt++Thfbf2FP1LnAUZ6hj/J2y2GY9AZSie4oljBptObeH/T+5zOPoMxqwbvtR5K75pt0Ylbt7Mv5V1i+fHlfLfnO87nnOehsId4uf7L+Lv431cGIcQOKeUtG+NX17PGp4sQYg23Lv7NgPeklB2Lfn8TQEp52yt0VPEv3/adPcnTi94jU5+AEGaqe0TStlprAlwDqOxcGXu9PalXUjmZdZKtZ7az+8KfSKnDsTCWsR3eoXm1aK3/BEW5L9mF2YzbOYGf988EfTbV3KrRtGpTwj3DCfMIQyLJyM3gQu4FNp/ZzNbTWzFKIzV8avBm4zeL/U33bou/XbH2cncCgJPX/Z4KNCmF/Soaiq0cxLonf+DrtTv5LmEuyfk7Sc78HrixsSHQIQr9KLjYib4x3XmrYxM1BaNSrrkYXHitycv4Gh/mP2tnYohOYvHRxWQV/vNamAD
XAB6r+RidQjoR4x2DEKLUct6x+AshVgCVb3LX21LK+dYMI4R4CngKIDg42Fxn/NAAAAYlSURBVJqbVjSg1wleatOALrFRvP7rHhIOXkAYLqOzy0SnK8RY4I0s9KJGFS8+7BNLg2peWkdWFKt5vGk4M7fFk5bcnJUvxZNlzOBo5lH0Qo+Pow/ejt54OHiUasG/3h2Lv5SyXTH3cQoIuu73wKJlN9vXBGACWA77FHO/ShkRXsmV2c8040xmHklns0g6l0V2vpE6gZ7UC/bEx1UNx6zYHju9jne71GDAD1uZuOEYI9pG4Ofsp3Wsq0rjsM92IEIIEYql6PcDHi2F/SpliBCCqp5OVPV0ok102XkDKEpJiqvuS8ea/ny75gi9GgRR2cNR60hXFberZ3chRCrQDFgkhFhWtLyqEGIxgJTSCIwAlgEHgVlSyv3Fi60oilI+vN25Bkaz5LOliVpHuUGxir+Ucq6UMlBK6SCl9P+rR4+U8rSUsvN16y2WUkZKKcOllB8XN7SiKEp5EezjzLCWocz98xQ7jl+88wNKibrCV1EUpYQ927o6/u4OfPD7fsz/3969hVhVxXEc//0cb1ma4oXEmZpCU4eym1khlDQimlP2EERRUT12YQRFNKMoyIeK8sHAIoIiQaILXeyighhBF8u0GE3RNPMSjkmWok7qv4dzBB/U0mbtNTP7+3k6+5wD67c48GOx9mGvYx3jdiblDwCJnduru2ZPGaW12/fp7dXbc8eRRPkDQCFuv3KYrr6wv579dEOhp9+dCuUPAAWwK0c+7tl/WAtWbModh/IHgKJcUddfd1xTq9e+2KItew5kzUL5A0CBZk0eqZ413fTMktM+2zI5yh8ACjSkb2892jhCy9fv1sqNrdlyUP4AULAHxterfmAfPf1hi/7OdOQj5Q8ABevVvUaPT23Q5tYDeuPLX7JkoPwBIIPG0UN046WDNX/5Rv2+/3Dh41P+AJCBbT3RNFoH247q+aVnduRje6D8ASCT4UP66r4b6rV41Ta17NxX6NiUPwBk1DxxhAb06amnPliX/ND2E1H+AJDR+ef00MxJI/XN1r1a8uOuwsal/AEgszuvrVPD0H6at2S9DrYdLWRMyh8AMqvpZj15a4N27juklz/fXMiYlD8AdADXXTJQU8cM1cKVm7Xjj4PJx6P8AaCDmDNllCKkeR+vTz5WEQe4AwD+g9oBfdQ8cYQOtR1VRMh2srEofwDoQB6aMLyQcdj2AYASovwBoIQofwAoIcofAEqI8geAEqL8AaCEKH8AKCHKHwBKyEU+P/pM2G6VlOdwy/9nkKQ9uUMUjDmXA3PuHC6KiMH/9qUOW/6dle1vI2Js7hxFYs7lwJy7FrZ9AKCEKH8AKCHKv/29kjtABsy5HJhzF8KePwCUECt/ACghyj8h2zNsh+1BubOkZvs52z/Z/sH2e7b7586Ugu3JtjfY3mR7du48qdmus73C9jrbLbabc2cqiu0a29/b/ih3lhQo/0Rs10maJGlb7iwFWSbpsogYI2mjpDmZ87Q72zWSXpI0RVKDpLtsN+RNldwRSTMiokHS9ZIeLsGcj2uWlP48xUwo/3RelDRLUiluqkTE0og4Ur38SlJtzjyJjJO0KSJ+jog2SYslTcucKamI2BURq6uv/1KlDIflTZWe7VpJUyW9mjtLKpR/AranSdoREWtzZ8nkQUmf5A6RwDBJv55wvV0lKMLjbNdLukrS13mTFGK+Kou3Y7mDpMIZvmfJ9nJJF5zko7mSHlNly6dLOd2cI+L96nfmqrJVsKjIbEjL9nmS3pE0PSL+zJ0nJdtNknZHxHe2J+TOkwrlf5YiYuLJ3rd9uaSLJa21LVW2P1bbHhcRvxUYsd2das7H2b5fUpOkxuia/yHeIanuhOva6ntdmu0eqhT/ooh4N3eeAoyXdJvtWyT1ltTP9psRcU/mXO2K//knZnurpLER0dkeDnVGbE+W9IKkmyKiNXeeFGx3V+VmdqMqpb9
K0t0R0ZI1WEKurGBel7Q3IqbnzlO06sp/ZkQ05c7S3tjzR3tZIKmvpGW219hemDtQe6ve0H5E0meq3Ph8qysXf9V4SfdKurn6u66projRybHyB4ASYuUPACVE+QNACVH+AFBClD8AlBDlDwAlRPkDQAlR/gBQQpQ/AJTQPxU17EKr1ubwAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# batch the inference across K=100 evenly spaced inputs on [-5, 5]\n", "xrange_inputs = np.linspace(-5, 5, 100).reshape((100, 1)) # shape (K, 1)\n", "targets = np.sin(xrange_inputs)\n", "predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n", "losses = vmap(partial(loss, net_params))(xrange_inputs, targets) # per-input loss\n", "plt.plot(xrange_inputs, predictions, label='prediction')\n", "plt.plot(xrange_inputs, losses, label='loss')\n", "plt.plot(xrange_inputs, targets, label='target')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MAML: Optimizing for Generalization\n", "\n", "Suppose a task loss function $\\mathcal{L}$ is defined with respect to model parameters $\\theta$, input features $X$, and labels $Y$. MAML optimizes the following:\n", "\n", "$\\mathcal{L}(\\theta - \\nabla \\mathcal{L}(\\theta, x_1, y_1), x_2, y_2)$\n", "\n", "$x_1, y_1$ and $x_2, y_2$ are drawn from the same distribution as $X, Y$. The MAML objective can therefore be thought of as a differentiable cross-validation error (w.r.t. $x_2, y_2$) for a model that learns from $x_1, y_1$ via a single gradient descent step (with an implicit step size of 1 in the formula above). Minimizing this cross-validation error provides an inductive bias toward generalization.\n", "\n", "The following toy example checks the MAML numerics on $g(x, y) = x^2 + y$, with parameter $x$ and input $y$."
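, "\n", "\n", "As a hand check of the values printed below, take $g(x, y) = x^2 + y$ and a unit inner step. The adapted parameter is $x' = x - \\partial_x g(x, y) = x - 2x = -x$. Differentiating through the inner step brings in the gradient-of-gradient term $\\frac{dx'}{dx} = 1 - \\partial_x^2 g(x, y) = 1 - 2 = -1$, so\n", "\n", "$$g(x', y) = x^2 + y, \\qquad \\frac{d}{dx}\\, g(x', y) = \\partial_{x'} g(x', y)\\,\\frac{dx'}{dx} = 2(-x)(-1) = 2x.$$\n", "\n", "At $x_0 = 2, y_0 = 1$ this predicts an objective of $5$ and a gradient of $4$. One outer step of size 1 therefore moves $x_0$ to $2 - 4 = -2$."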
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "grad(g)(x0) = 4.0\n", "x0 - grad(g)(x0) = -2.0\n", "maml_objective(x,y)=5.0\n", "x0 - grad(maml_objective)(x0, y0) = -2.0\n" ] } ], "source": [ "# gradients of gradients test for MAML\n", "# check numerics\n", "g = lambda x, y : np.square(x) + y\n", "x0 = 2.\n", "y0 = 1.\n", "print('grad(g)(x0) = {}'.format(grad(g)(x0, y0))) # 2x = 4\n", "print('x0 - grad(g)(x0) = {}'.format(x0 - grad(g)(x0, y0))) # x - 2x = -2\n", "def maml_objective(x, y):\n", " # one inner gradient step (step size 1), then evaluate g at the updated parameter\n", " return g(x - grad(g)(x, y), y)\n", "print('maml_objective(x,y)={}'.format(maml_objective(x0, y0))) # (x - 2x)**2 + 1 = x**2 + 1 = 5\n", "print('x0 - grad(maml_objective)(x0, y0) = {}'.format(x0 - grad(maml_objective)(x0, y0))) # x - 2x = -2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sinusoid Task + MAML\n", "\n", "Now let's re-implement the sinusoid regression task from Chelsea Finn's [MAML paper](https://arxiv.org/abs/1703.03400).\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from jax.tree_util import tree_multimap # applies a function leaf-wise across several pytrees\n", "\n", "alpha = .1 # inner-loop (adaptation) step size\n", "def inner_update(p, x1, y1):\n", " # one SGD step on the task's training split (x1, y1)\n", " grads = grad(loss)(p, x1, y1)\n", " inner_sgd_fn = lambda g, state: (state - alpha*g)\n", " return tree_multimap(inner_sgd_fn, grads, p)\n", "\n", "def maml_loss(p, x1, y1, x2, y2):\n", " # loss of the adapted parameters on the held-out split (x2, y2)\n", " p2 = inner_update(p, x1, y1)\n", " return loss(p2, x2, y2)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(4.015528e-05, dtype=float32)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1 = xrange_inputs\n", "y1 = targets\n", "x2 = np.array([0.])\n", "y2 = np.array([0.])\n", "maml_loss(net_params, x1, y1, x2, y2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try minimizing the MAML loss (without batching across multiple tasks, which we will do in the next section).\n" ] }, { "cell_type": "code", 
"execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n", "1000\n", "2000\n", "3000\n", "4000\n", "5000\n", "6000\n", "7000\n", "8000\n", "9000\n", "10000\n", "11000\n", "12000\n", "13000\n", "14000\n", "15000\n", "16000\n", "17000\n", "18000\n", "19000\n" ] } ], "source": [ "import numpy as onp # original NumPy, used for sampling tasks on the host\n", "from jax.experimental import optimizers\n", "\n", "opt_init, opt_update = optimizers.adam(step_size=1e-3) # this LR seems to be better than 1e-2 and 1e-4\n", "out_shape, net_params = net_init(rng, in_shape)\n", "opt_state = opt_init(net_params)\n", "\n", "@jit\n", "def step(i, opt_state, x1, y1, x2, y2):\n", " p = optimizers.get_params(opt_state)\n", " g = grad(maml_loss)(p, x1, y1, x2, y2)\n", " l = maml_loss(p, x1, y1, x2, y2)\n", " return opt_update(i, g, opt_state), l\n", "\n", "K = 20 # number of inner-split examples per task\n", "\n", "np_maml_loss = []\n", "\n", "# Adam optimization, sampling one new task per step\n", "for i in range(20000):\n", " # sample a task: amplitude A and phase of a sinusoid\n", " A = onp.random.uniform(low=0.1, high=.5)\n", " phase = onp.random.uniform(low=0., high=np.pi)\n", " # meta-training inner split (K examples)\n", " x1 = onp.random.uniform(low=-5., high=5., size=(K,1))\n", " y1 = A * onp.sin(x1 + phase)\n", " # meta-training outer split (1 example). 
Like cross-validating with respect to one example.\n", " x2 = onp.random.uniform(low=-5., high=5.)\n", " y2 = A * onp.sin(x2 + phase)\n", " opt_state, l = step(i, opt_state, x1, y1, x2, y2)\n", " np_maml_loss.append(l)\n", " if i % 1000 == 0:\n", " print(i)\n", "net_params = optimizers.get_params(opt_state)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAD8CAYAAABzTgP2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzsnXd4FNX6xz9na3bTSIUECAk9lBAghF4CUixUUUCliIiIlKs/uWIDRFS8oHDBgspFREEQVFQEpQuICAFDwIQWCJAO6dkkW+f3x4aYkApsqPN5nnmyO3POmTOT3X3nvOc931dIkoSMjIyMjMwVFLe6AzIyMjIytxeyYZCRkZGRKYVsGGRkZGRkSiEbBhkZGRmZUsiGQUZGRkamFLJhkJGRkZEphWwYZGRkZGRKIRsGGRkZGZlSyIZBRkZGRqYUqlvdgevB29tbCgwMvNXdkJGRkbmjOHz48GVJknyqKndHGobAwEAiIyNvdTdkZGRk7iiEEOerU052JcnIyMjIlEI2DDIyMjIypZANg4yMjIxMKe7IOYbyMJvNJCQkUFhYeKu7InOX4OTkRL169VCr1be6KzIyN5W7xjAkJCTg6upKYGAgQohb3R2ZOxxJkkhPTychIYGgoKBb3R0ZmZuKQ1xJQogVQog0IcTxCo4LIcQSIcQZIUS0EKJdiWNjhRCni7ax19uHwsJCvLy8ZKMg4xCEEHh5eckjUJl7EkfNMawEBlRy/H6gSdE2EfgYQAjhCcwGOgLhwGwhhMf1dkI2CjKORP48ydyrOMQwSJK0B8iopMhgYJVk5wBQSwjhB/QHtkmSlCFJUiawjcoNzA1hzDdTkGsq/6Akgc0C5gIwGcBmq6luyMjcfpgMkHgYUo7ZX8vc09ysqKS6wMUS7xOK9lW0vwxCiIlCiEghROSlS5euuQOSJFGQZyY3sxCz0XplJxjzIOMspETbvxSXTsDlU/b3l05AdgJYKjAmtzm9evWqciHg4sWLyc/Pv0k9Kp9x48axYcMGACZMmEBMTEyFZXfv3s3+/fuL3y9btoxVq1bVeB/vSi6dgo3PwZJ28HZd+Kw3LOsGb/vDe8Hw/SRIi73VvZS5Bdwxk8+SJH0KfAoQFhYmXWt9IQRuXk5kJBvIuVyAh6cNRV4ymPNBKEHnCSoNKDWAsO83GcBw2b65+IBLbVDc2C2zWq0olcobasORLF68mCeeeAK9Xu/Qdi0WCyrVtd+r5cuXV3p89+7duLi40KVLFwAmTZp0Xf27p0mPg9/ehWPrQaWDxr0hZATUbmEfNafH2R+OYn6Eo19Dsweg10zwa3Orey5zk7hZI4ZEoH6J9/WK9lW0v0ZQKBW4eTlhtdjIu5QLNiu414PaLaFWffsPv84DdLXAzR+8m4BvsH1fXhqkxkBhTrltx8fH07x5cx5//HGCg4MZPnx48ZN4YGAgL730Eu3atWP9+v
XExcUxYMAA2rdvT/fu3Tlx4kS5bbq4uBS/3rBhA+PGjQPsT9iTJk0iLCyMpk2bsmnTJgAKCgoYOXIkwcHBDB06lIKCguL6zz77LGFhYbRs2ZLZs2cDsGTJEpKSkoiIiCAiIgKArVu30rlzZ9q1a8cjjzxCXl5emX716tWL6dOnExoaSqtWrTh48CAAc+bMYfTo0XTt2pXRo0djtVqZMWMGHTp0ICQkhE8++QSwj96mTJlCs2bNuO+++0hLSyvV9pVRzi+//EK7du1o06YNffr0IT4+nmXLlrFo0SJCQ0PZu3cvc+bMYeHChQBERUXRqVMnQkJCGDp0KJmZmcVtvvTSS4SHh9O0aVP27t0LwN9//014eDihoaGEhIRw+vTp8j84dxNRa+CjThD7E3SZCv+KhhFfQa+XIHggtBwKPV6EYZ/C88eh18tw4Q/4rA9Efn6rey9zk7hZI4YfgSlCiLXYJ5qzJUlKFkL8CrxdYsK5H/DyjZ7sjZ/+JiapvB9wCcyFWG0CG2qUKoFCmV2tNlvU0TO7sxoy4sC9Pjh7lylz8uRJ/ve//9G1a1fGjx/PRx99xIsvvgiAl5cXR44cAaBPnz4sW7aMJk2a8OeffzJ58mR27tx5TdcYHx/PwYMHiYuLIyIigjNnzvDxxx+j1+uJjY0lOjqadu2Kg79466238PT0xGq10qdPH6Kjo5k2bRrvv/8+u3btwtvbm8uXLzNv3jy2b9+Os7Mz7777Lu+//z6zZs0qc/78/HyioqLYs2cP48eP5/hxe0BaTEwM+/btQ6fT8emnn+Lu7s6hQ4cwGo107dqVfv368ddff3Hy5EliYmJITU2lRYsWjB8/vlT7ly5d4umnn2bPnj0EBQWRkZGBp6cnkyZNwsXFpfi+7tixo7jOmDFjWLp0KT179mTWrFm88cYbLF68GLCPYA4ePMjmzZt544032L59O8uWLWP69Ok8/vjjmEwmrFbrNf0P7ihsVtg2C/74AIJ6wLDl4Fq78jp6T/tIIXwifPc0bPoXJB2BBxaCSntz+i1zS3CIYRBCfA30AryFEAnYI43UAJIkLQM2Aw8AZ4B84MmiYxlCiDeBQ0VNzZUkqbJJ7BvDXAiSFaVai2QBq0VCKKTqRZ8oVODdFDLPQfZFsBrB1R9K1K1fvz5du3YF4IknnmDJkiXFP2AjRowAIC8vj/379/PII48U1zMajdd8KY8++igKhYImTZrQsGFDTpw4wZ49e5g2bRoAISEhhISEFJf/5ptv+PTTT7FYLCQnJxMTE1PqOMCBAweIiYkpvgaTyUTnzp3LPf+oUaMA6NGjBzk5OWRlZQEwaNAgdDodYB99REdHF88fZGdnc/r0afbs2cOoUaNQKpX4+/vTu3fvMu0fOHCAHj16FK8h8PT0rPR+ZGdnk5WVRc+ePQEYO3ZsqXs8bNgwANq3b098fDwAnTt35q233iIhIYFhw4bRpEmTSs9xx2IuhG/GwOlfocPTMOAdUF7Doj29Jzz2Dex6C/a+B5dPwxPfgcax7keZ2weHGAZJkkZVcVwCnqvg2ApghSP6cYXZA1uWf8CUb/8hV+uwWmxkJBtQqhR41NFXPzTRs5F9QjovDRB2l1MRV7dR8r2zszMANpuNWrVqERUVVaqs1Wqlffv2gP3Hde7cuaXqXx1PX9m5rubcuXMsXLiQQ4cO4eHhwbhx48qNz5ckib59+/L1119X2FZV579ynVfaW7p0Kf379y9VdvPmzVW272i0WvsTrlKpxGKxAPDYY4/RsWNHfv75Zx544AE++eSTco3UHY3NBt8/YzcKDyyE8Kevrx2FEvrMAt8W9tHDhidhxGpQ3jHTlDLXwL2llaTRg9r+NKtU2ecbLCYreZnX8MQuhH1eQu8FeamQn1586MKFC/zxxx8ArFmzhm7dupWp7ubmRlBQEOvXrwfsP55Hjx5FqVQSFRVFVFQUc+fOBaB27drExsZis9n4/vvvS7Wzfv16bDYbcXFxnD17lm
bNmtGjRw/WrFkDwPHjx4mOjgYgJycHZ2dn3N3dSU1NZcuWLcXtuLq6kpubC0CnTp34/fffOXPmDAAGg4FTp06VexvWrVsHwL59+3B3d8fd3b1Mmf79+/Pxxx9jNpsBOHXqFAaDgR49erBu3TqsVivJycns2rWrTN1OnTqxZ88ezp07B0BGRkaZ/pbE3d0dDw+P4vmDL7/8snj0UBFnz56lYcOGTJs2jcGDBxffr7uKba9DzEbo++b1G4WStB5uNzCnfoGfptsj+2TuOu5pc6/Vq9G5WCnINaHRqdDqqnk7rhgHiwmyLhZFMkGzZs348MMPGT9+PC1atODZZ58tt/rq1at59tlnmTdvHmazmZEjR9KmTdmIj/nz5/PQQw/h4+NDWFhYqYnggIAAwsPDycnJYdmyZTg5OfHss8/y5JNPEhwcTHBwcPEIpE2bNrRt25bmzZuXcncBTJw4kQEDBuDv78+uXbtYuXIlo0aNKnZvzZs3j6ZNm5bpm5OTE23btsVsNrNiRfkDvgkTJhAfH0+7du2QJAkfHx82btzI0KFD2blzJy1atCAgIKBcd5WPjw+ffvopw4YNw2az4evry7Zt2xg4cCDDhw/nhx9+YOnSpaXqfPHFF0yaNIn8/HwaNmzI559XPln6zTff8OWXX6JWq6lTpw6vvPJKpeXvOA4ss88phE+0TzQ7ig5P2UfMv823z1P0KTsHJXOHI0nSHbe1b99eupqYmJgy+6qDzWqT0hPzpLQLOZLFbL22ylazJKXGSFLSUenc6RNSy5Ytr6sP18rYsWOl9evX35RzlUfPnj2lQ4cO3bLz30yu93N1yzm7R5Jmu0vS149JktXi+PZtNkn6YYokzXaTpJO/Or59mRoBiJSq8Rt7b7mSykEoBG7eTiBBTnoh0rUMjRUq8Gxof51dY1G2MjLXRmG2fXGaZ0N72KmiBtbNCAH3LwDflvDDZMi79kWnMrcv97Qr6QoqjRIXTy256YXk55hwdr+GUDyVFtzrEShZOf7HjqrLO4CVK1felPNUxO7du2/p+WWqYPO/ITcZntoGGueqy18vaid4eDl82gt+nAKj1paK0pO5c7nnRwxXcHJWo9WrMWQZMRVarq2yzgOc3O1fRnNB1eVlZGqKv7+H6LXQYwbUa1/z56vdAvq9aZ+MjvxfzZ9P5qYgjxiKEELgWhSllHO5EE8/PQplNe2mEPZFb6YTkHkefJqCkG2uzE0mNxU2PQ/+7eyrl6uJMd/MpQu5pF3IBQmca2lx8dDa/9bSotJU4YoKnwint8Kvr0GTflAr4AYvROZWIxuGEigUAjdvHZkpBnIzCnHz1lV/fYNSbTcOmefsERuudWq2szIyV7PjDbso5NBPqlzAZiywcPJACjG/J5GeUFb2pCROzup/jIWH3VhcMRrOHlpcPJzQPLQY8WE4/PoqjPjSkVclcwuQDcNVqLVKnGtpMWQZKcwzo3PVVL+yrhbku9vXN+i9rm11qYzMjZBwGKJWQ9fp9hFrBaQn5nFsdwInD6ZiMVrxbeBKx0EN8Q10xTfADYVKYMgykpdpLP6bl3XldSFp53MoyDWXaVelVeKi/R+uv8fimr4T18CGuPk44e6tx91Hh9ZZJee3uIOQDUM56N00mAut5GYaUWuVVQ+lgaysLNasWcPkieMh7YR9vqEGhtS7d+9Go9EUq4vKyGCzwZZ/20Uge8woe9hq42zUZY7tTiDpdBZKtYImYb606lmP2oFuZcpr6qjwqFPxpLXVbMOQbTcYeZmFGDJNGLKM5KYbyDuZyuW/cymIPlu6TZ0Kdx8dbt463H3sW63aOuo0qoVCIRuM2w3ZMJSDEAJXbycykw1kXy7Ao45zlR/erKwsPvroIyZPngzOPmBIA713hXoyxfHCimubi7hadlpGhmPfQGIkDPkYtK7Fu40FFmJ/T+LozovkZRhx83aiy7DGBHfxw8nl+kezSrUCN2/7j3wZTibD1yMx93mHnEZjyLlUQPalAv
vfywVcvpjLuahL2Gz2sHCPOno6PBRE43a+CNlA3DbIhqEClEoFbl46stLyycssxM2rnC9BCWbOnElcXByhoaFE9OpJ9OEDZGbnYZYUzJs3j8GDBxMfH0///v3p2LEjhw8fZvPmzWzfvp13332XWrVq0aZNG7RaLR988AGXLl1i0qRJXLhwAbDnTahbty7Lli1DqVTy1VdfsXTpUrp3734zbofM7YoxD7bNhrrtIWQkAIZsI0d3XOT4nkTMhVb8m9Si+6NNCQzxrvmn86YDoHFf1PvewavdCLz8fcoUsVlt5GUaSTmXTeTm82xd/jeR/vF0HtqIwNZlVYtlbj53p2HYMtOeje0G0QBeFhtWqw1L3TaoBv6nwrLz58/n+PHjREVFYbFYyL90ATcpm8tWNzpF9GfQoEEAnD59mi+++IJOnTqRlJTEm2++yZEjR3B1daV3797F0hjTp0/n+eefp1u3bly4cIH+/fsTGxtbRnZa5h7nwMeQlwIjvsJotHHg+9PE7k/GZrXRqL0vbfsG4NugrLuoxhAC+r8FH3aE/Uug7xtliiiU/4w4GrevTdzhNA5uOsfPH0YT0NKLbo80rtSVJVPz3J2GoRwkJGxmM5IkodJUfwGbQiWw2QSmQgvCbEOprtr1I0kSr8x7jz27tqNQCBITE0lNTQWgQYMGdOrUCYCDBw/Ss2fPYknpRx55pFi0bvv27aVSXObk5JSbNEfmHqYw266F1OwBDG4h/PTeETKSDAR38aNtvwBq+d4iWWyfZtD6ETj4KXSeYs9+WAEKhaBJh9o0bOvDsd0JHNp0jrVzD9K6dz3CHwxCU139MhmHcnfe9fvnl90nSeScO4tRslHL0wsnj8r1/a8gAIXZhiHZQOHlgmpJdK9evZpLly9z+I89qA1JBHYeXCxzXVKWujJsNhsHDhzAycmpWuVl7kH+/AQKs8hq9SI//ucwBXlmHnwuhAYtvW51z6DnS3B8A/y+2D6CqAKlSkHofQE0Da/DgR/iOLrjIqcPptJlWCOadqwjRzTdZO6ZVVhCCNzrB6AEcjLSsZaTj6AilGpF8eI3Q1b5Et0l5aCzs7Px9fVF7ebLrj+iOH/hYrnyxB06dOC3334jMzMTi8XCt99+W3ysX79+pdRDr+RvqEh2WuYeo2i0kFHvMb79Ih+LycqQ59veHkYBwLuxfc7j0HLITal2Nb2bht6jgxn+7zBcPLRsXxnLpg+Okp9jqsHOylyNQwyDEGKAEOKkEOKMEGJmOccXCSGiirZTQoisEsesJY796Ij+VIRCrcbdtw4SkJV4EekaUjk6OatxclGTn2PCVFBWMsPLy4uuXbvSqlUroqKiiIyMpHVICKs2bqN540Awlv0xr1u3Lq+88grh4eF07dqVwMDA4rwGS5YsITIykpCQEFq0aMGyZcsAGDhwIN9//31xzmOZe5QDyzDlG/klfiRCwLAX25cbenpL6TkDrGbYt+iaq9YOcmP4S2H0GNmUxJNZrHvrIAknai65o0xpxDWpiZbXgBBK4BTQF0jAnqZzlCRJMRWUnwq0lSRpfNH7PEmSXMorWxFhYWHSlYTxV4iNjSU4OLha9fNSU8jLy0WvVOPaoEG1h6k2m0RmigHJJuHh54yyOpIZkgRpsXaJDJ9mZUTG8vLycHFxwWKxMHToUMaPH8/QoUOr1R+ZmudaPlc3jYIspMUh7Ch8lZNpTRg0PZT6zavnGr3p/DgVjq6F6dHg5nddTVxOyOXXz/4mKy2frg83JvQ+WXLjehFCHJYkKayqco4YMYQDZyRJOitJkglYCwyupPwooOrckTWIs29tNCo1BVYzxrTUatdTKATu3jpsNsi9XE2JbiHsyUwsBWDMKXN4zpw5hIaG0qpVK4KCghgyZMi1XIrMvcih5cRmhnEytQkdHgy6fY0CQLcXwGaBg59cdxPe9Vx59JUONGrrw+8bznB0x0UHdlCmPBwx+VwXKPmfSgA6lldQCNEACAJ2ltjtJISIBC
zAfEmSNlZQdyIwEezZy24EIQTudeuRfiGenNwcVDo9KrfqDcNVGiWuHlpyMwopyDWhd6tGhJPO0+5nzUu1q7CWYOHChddzCTL3KhYj6Xs3sSfvFeo19yDsgcBb3aPK8QyC4IEQuQK6vwjaa3IOFKPWKun7VEuQ/mbf+tMolILWveo5uLMyV7jZk88jgQ2SJJV07jcoGto8BiwWQjQqr6IkSZ9KkhQmSVKYj0/F4W/VRalS4V7bD5sQ5KQkYzNWP++zk4sarU5FXqYRs7Ea8xRC2FdDmwz2TUbmOjH99T2/Jo1Hq1PSd3zLO0NOovNU+2T5X1/dUDNKpYK+T7UkMMSbPWtPceKPZAd1UOZqHGEYEoH6Jd7XK9pXHiO5yo0kSVJi0d+zwG6grQP6VC20zs44u7ljUiowXLhQ7cnoKxLdCqWCnMsFxcv7K0XvZZ9nMMiZrmSuE0li73cJZFr96ft0W/Ru1yDweCup3wHqd4IDH4H1GnOdXIVSpWDA062o28yD374+SfYlOf9JTeAIV9IhoIkQIgi7QRiJ/em/FEKI5oAH8EeJfR5AviRJRiGEN9AVqHh5cQ3g4u2DqaCAfLMJ1cWLOFVzMlqhVODm5VRtyQwUSrtxMFwGNxMo75Avtcxtw4mfdnMisx0d2uVQr4J5BaPVSKohlRRDCin5KVzMvcj5nPMk5yXj5+JHC88WBHsF09a3LZqb+RnsMhXWPQ4nfoKWNxZcoVQr6DM2mK/n/smur2IZPL2trLPkYG7YMEiSZBFCTAF+BZTACkmS/hZCzMWeePpKCOpIYK1UesY2GPhECGHDPnqZX1E0U00hhMDdz5/0i+fJMxtRpaairlO9XAoanQq9m4b8HBNanQqtvgphMmcf+4jBcBnc/B3Qe5l7hcwUA7/9YqKu01nCxo4r3p9ryuW709/xd/rfnMg4QXx2PBL/fMUUQoGfsx91nOsQlRbFlnNbAPDWeTOy2UgebfYoHk4eNX8Bze6356DevxRaDLnhFKCunk50fbgxu1ef5O99SbTqUddBHZUBB618liRpM7D5qn2zrno/p5x6+4HWjujDjaBSq3HzrU12agp52Vm46XQo3d2rrog925WxwEJeppFJz03k559/xtfXl+PHj5dzIq198tlw2S6RXCJJe2BgIJGRkXh7V09E7GbKb69cuZLIyEg++OADli1bhl6vZ8yYMeWWjY+PZ//+/Tz2mH3QGBkZyapVq1iyZEmN9/NuxWK28uvHh1FJ+fTtX4hCax+dHrt0jBl7ZpCYl0gd5zoEewYzIHAAdV3qUtu5NnX0dfB38S81MsgszCQqLYp1p9bxQdQHfHbsMwY1GsQTLZ6goXvDmrsIhRI6Pwc//x9cPAgB5canXBMtuvlz5nAa+787Q4NWXrh6yioBjuLulMS4DnQurpjy8ynIzSE/KQlnjQaFrgr3EEXzDR52l9LIRx5n6tSpFf5oAuDsa5+IK8iwjyCukxuV37ZYLKhU1/7vnzRpUqXH4+PjWbNmTbFhCAsLIyysyrBpmUrY/20c6akWHvT6GOcea7BJNr6M+ZLFhxfjo/fhy/u/JNQ3tFpteTh5EBEQQURABGcyz/BV7Ff8cOYH1p9aT/e63RnbcizhdcJrRoKizSjY/oY9N7QDDIMQgognmvP1mwfZ9dUJBk5tI0tnOIh7RhLjCpWtPXD19kGlVlOgUWG8cAHJUr2JMo1OhVavol2rjri71aq0rMEMD459njbh3WnVqhXr1q0rPrZ06VLatWtH69atOXHiBAAZGRkMGTKEkJAQOnXqRHR0NPHx8SxbtoxFixaVuwJ6zpw5jB49ms6dO9OkSRM+++wzwG5MunfvzqBBg2jRogUAX331FeHh4YSGhvLMM89gLZqA//zzz2natCnh4eH8/vvvpdq+EmJ75swZ7rvvPtq0aUO7du2Ii4tj5syZ7N27l9DQUBYtWsTu3bt56KGHKryWK22OHz+eXr160bBhw+LRhcFg4MEHH6
RNmzZl7tW9wtmoSxzbnUAb1y0EhjXEqvNg9v7ZLIxcSM/6PVk/cH21jcLVNPZozJwuc9g6fCuTQyfzd/rfTNg6gZE/j+SXc79gsd3YRHEZNM7QZiT8/T0Y0h3SpJu3ji5DG3ExJkOOUnIgd+WI4d2D73Ii40SZ/WabXV21skk3yWbDbDIiJAllrAJFkYhdc8/mvBT+UoX1XGo5kV5gwJBduabLL7/+in+9+vz8xSLwakJ24T+RUN7e3hw5coSPPvqIhQsXsnz5cmbPnk3btm3ZuHEjO3fuZMyYMURFRVUpvx0dHc2BAwcwGAy0bduWBx98EIAjR45w/PhxgoKCiI2NZd26dfz++++o1WomT57M6tWr6du3L7Nnz+bw4cO4u7sTERFB27Zlg8Uef/xxZs6cydChQyksLMRmszF//nwWLlzIpk2bALsxukJF1wJw4sQJdu3aRW5uLs2aNePZZ5/ll19+wd/fn59//hmwa1DdS+RmFLJzVSw+3kY6K1dgbf8zs/bP4se4H3km5BmeC33OIU/IXjovnm3zLONbjeenuJ/44u8vmLFnBvVc6jG+9XgGNxrsuInqsPF21dWor+xpSB1Aqx51OXM4jX3rz1A/2AsXj+qrJ8uUzz01YpAkCbPNjMla8Y+3UChQqtRICGw2G5KpeuJdSrUCZzcNpgJzeXp5xbRu3Zptu/by0ltL2LtjS7E2EsCwYcMAaN++PfHx8QDs27eP0aNHA9C7d2/S09PJySm7gvpqBg8ejE6nw9vbm4iICA4ePAhAeHg4QUFBAOzYsYPDhw/ToUMHQkND2bFjB2fPnuXPP/+kV69e+Pj4oNFoGDFiRJn2c3NzSUxMLJbvcHJyQq+vXOa5smt58MEH0Wq1eHt74+vrS2pqqv1ebdvGSy+9xN69e0vdq7sdq9XG1uV/Y7NK9PP9BOo047XzP/Bj3I9MDp3MlLZTHO420Sq1DG86nB+G/MDiXouppa3F3D/mcv+397Pq71Xkm/Nv/CS+wRDQBSI/t6ckdQBCIYgY3RybxcZva05UT5FAplLuyhFDRU/2kiSRbEgmszATX70vPvryffySJJGVkowp34BzoQldvXrVmozWu2lQqhTYrDYkSUIIwcWLFxk4cCBg989PmjSJI0eOsHn9l7z25nz6HPybWXPmAKDV2p90lEollmq6sSri6h+NK+9Lyn5LksTYsWN55513SpXduLHcxec1ypVrh3+uv2nTpvZ7tXkzr732Gn369GHWrFmVtHL38Mf3caSczabfMD0u+3fxSpu+bD67ialtpzIxZGKNnlshFPRp0IfeAb05kHyAz459xoLIBXx27DMeD36cUc1H4a69ASPd4Sn49ik4uxMa3+eQPtfy1dNxcEN+33CGUwdTadaxepGFMuVzT40YhBD4OfvhrnUnLT+N9ILy/ZxCCNx8fREqJQVOGkxJSdWabxAKgd5dgyRBQa4ZgPr16xMVFVXs/klKSkKv1/PE+KeZMWkMRyL/rLTN7t27s3r1asDulvH29sbNza1K+e0ffviBwsJC0tPT2b17Nx06dChTpk+fPmzYsIG0tDTAPgdw/vx5OnbsyG+//UZ6ejpms5n169eXqevq6kq9evWKjYjRaCQ/P7/SflV0LRVRfK+eeIJquwYhAAAgAElEQVQZM2Zw5MiRCsveTcT9lcbR7Rdp3aseQflf8krt2mzOOcn0dtNr3CiURAhBZ//OrOi/gi/v/5IQnxA+jPqQ/t/2Z9HhRVwuuHx9DQcPtOdDP7TCof0N6V2fOg3d2LvuFIbs6isZyJTlnjIMYP+w13Wpi6vGlRRDClnGrHLLKZUq3H3qYAUKlAJzUlKVQ9RRo0bRq3cP4s6epnGzoOJJ35IcO3bMPtnboQtvLF7Oa9OeLDdXwxXmzJnD4cOHCQkJYebMmXzxxRdA1fLbISEhRERE0KlTJ15//XX8/cuum2jRogXz5s2jX79+hISE0LdvX5KTk/Hz82POnDl07tyZrl
27Vqgu+uWXX7JkyRJCQkLo0qULKSkphISEoFQqadOmDYsWlZZbruhaKqL4XoWG8sYbb/Daa69VWv5uICstn51fxOIb6EbH+2vxSvI2tui1/Kvdv5jQesIt61eobygf9vmQ9QPt0UufH/+cAd8O4O0/3yY57xonfVVaaPsEnNoC2QkO66NCIeg9JhiLycaetadkl9INcMOy27eCG5XdBrBJNs7nnKfAXEB9t/q4alzLLZebfhlDViZ6kxm9nz+qWpVHHYE97jwjyYCTsxo370pCXvPTIesCeDW5bnGx8pgzZ46cF9pB3EzZbZtN4rsFh8lKzWfYzLa8tW8MW3PP8HyTEYzvcnsZxfjseFYcX8FPcT8B8FCjh3iq1VMEugdWr4GMc7AkFCJes+dtcCCHf4nnwMaz9H+6FY3b+zq07Tudmym7fUeiEAoCXAPQqrQk5CZUOLHm4umJWqulQKPGlJyMzWyusm2VWoneTUOhwYzZWIkLyqkWCCXkX+eQXOauInrnRVLP5dD5kSBePzaTrblneNGsu+2MAkCgeyBzu85l87DNPNrsUbac28LDPz7M/qT91WvAMwgCu0PU6kpHzNdD274B+DZwZc/akxTkypnfrod71jAAKBVKAtwCUClUXMi9gNFS1i8phAJ33zogBPkqRbVcSgB6dy0KpSA3w1hxeYUSdB5QkA226meTq4o5c+bIo4U7jOxL+fz5w1nqt6rFouw32Juwl9cvZzD2FrqPqoOfix8vd3yZLcO2EOgeyLSd0/gzufJ5s2JCH4fMc3Dhj6rLXgMKpYLeY4Ix5lvYu+6UQ9u+V7gro5KuBbVCTQO3BpzNPsv5nPMEuQehVpbWPFJpNLh6+5BzKY2CgnyU2dlVupQUCoGLhxM5lwsozDOjc60gDlzvaR8xFGSCc/XkMGTuLiRJYtdXJxBKwQ/+n3I4LZK33EMZeH4ztB5edX2rlYKj0eTt3k3+kcMIoUA4OaFwckLh7oaqVi2UHp64DXwItW/NuFZ89D581u8znvr1KabsmMJH931EhzplAx5K0WIQbH4R/loNDRwr7eJV14WwBwI5+NM5GoddomHojUv130vc0yOGK2iUGhq4NcAqWTmfex5rOU/vOlc3nJydKVSrMKakVCtKSatXodYqycsyYrNWELOt1oNSa5fIkLknidmXROLJLE4228PhvAP8p9s7DDy93y48p684O5tks5GxejWnu/fg/GOPkf6//4HZ/rm0ZmVhPHcWw297yPhiFWkLFnBu4CByfvm1xq7D08mTz/p9hr+LP8/teK7cRaal0DhDyyH2ldDGPIf3p92ABnjVc+G3NScpNFTtApb5B9kwFKFT6ajvWh+TxcTF3IvYpNI/5EII3Hxqo1AqMagUmJKSqmxTCIGLpxOSTcKQVYGvUwj7l99kgHJcWTJ3NzarjUObz5HjmcpO/Xcs6LmAfibJrsIb+niF9Yxnz3F+9BhS35yHtllT6r7/Hk3/2E/gurU0+HIVQeu/odGmTTTZu4dm0Udp+PMm1AEBJP7rXyS99BLWaiySvB68dd4s77ccV40rU3dOrTqkNfQJMBsg9sfKy10HSqWCPmOCKcgzs2/9aYe3fzcjG4YSuGhc8Hf1x2A2kJiXWGZuQKFU4u5bB5sQGAryq/XlUmuU6Fw1FOSZMJsqmEfQFT0V5sujhnuNM1GpGDJN/OmzmYW9FtK3QV/7hKyzLzTuU26drO83cm7IEIxnzuA3/x0CVqzA7YEHUFawJkQIgbZRIwLXrMb7uefI3vQzZx8aSO7OXTVyTT56H5b0XkJWYRbP73q+UqUBAjqBZyO7O6km+hLgSrv+AZw8kEL8MTnIo7rIhuEqamlrUdu5NjnGHFIMKWWMg1avx9m9FiaVkvyU5GKX0sWLF4mIiKBFixa0bNmS//73v8V1nN01CIUgL6Ow/IlolYbATgO5nBBX7QiN3bt3s39/NSNAbpCVK1cyZcoUAJYtW8aqVasqLHtFXfUKkZ
GRTJs2rcb7eCciSRI///QHudp0nho4gj4N+tgl2U/9AiGPgrJsfo/0/60g+eWX0bVvR6OfN1FryJBqS2MItRqfqVMIXLsWZa1aJEyeTOL/vYglM9PRl0ZLr5a82e1Noi5F8eaBNysOwBACQh+D8/vsIaw1QIcHgvD0d2b36pMYCxwsDHiX4hDDIIQYIIQ4KYQ4I4SYWc7xcUKIS0KIqKJtQoljY4UQp4u2sY7oz43irfPGS+dFRmFGuUNhFy8vVGoN+QqBMdm+uEelUvHee+8RExPDgQMH+PDDD4mJseccUigVuNTSYjZaMRoq+GAKBdhMYKqer/VGDcP1Sm5MmjSpUlnxqw1DWFiYnIuhAr7atx5NSi30bUwMbGxXoOXYBrBZ7D+WJZAkibSFC0lbsADX+wdQ/5NPUFUzd8fV6Fq3ImjDerynTCFn61bODRlK/uHDN3o5ZRgQOIBJbSax8cxGVsVU/DBBm1GAgOiaUc9VqhX0Hh1MfraR/Rtkl1J1uGHDIIRQAh8C9wMtgFFCiBblFF0nSVJo0ba8qK4nMBvoCIQDs4vSfd5yautrF0tnZBWWXh0thIJadfxAIcgrzMeam4ufnx/t2rUD7HIRwcHBJCb+k/rayUWNSqMkNSmdBx4oR0paKFi6Yh3twjvLstv3gOz2geQD/Ln9FDaFlQmPlIg8il4LdUKgdsviXZLFQvKrr5G+/H/UGjmCugsXotDcmNqp0GjwmfIcgWu/Rmi1nB8zlsuffobkIGG7Kzzb5ln6NujL+4ffZ29C2RX6ALjXhcBucGy9w9c0XKF2kBuh9wUQ83syF2Nkl21VOCJcNRw4I0nSWQAhxFpgMFCdFJ39gW2SJGUU1d0GDAC+vpEOpbz9NsbYKiIiqoEEqKyFJNms5LYMof7rs4uPqTQaXL18yLl8ibyUZNz0eoTSnpEtPj6ev/76i44d/0lGIoTA1VPLxh++w9e7Nps3l5WS9q7jz5FfVvPRd/tk2e27WHY7z5TH7J1zefDydJqF10HvWiQgePkMJP0F/eYVl7UVFJD4wv+Rt2sX3pMn4z3VsaqqupYtCfruW5Jff51L77+P6cJ5/OfNq7piNVEIBfO6zuNCzgX+veffrH5wdfmZ4lo/Aj9Ns19/3XYOO39JwgcGcS76Mru+OsHIWeFonO75aP0KcYQrqS5wscT7hKJ9V/OwECJaCLFBCFH/GuveEgR2KWKFUJBnyi2zOlrn5o7WSUeBQlBY5FLKy8vj4YcfZvHixWUE4tRaFaFt27Bz5w5mvDijjJT0sOEjQbLRvlVTWXb7Lpbd/jDqQ7wvNEZpVdO2T+A/B45vAAS0ehgAa3Y2F56aQN7u3dSe9To+06bWSIYypYsLdd9/H6+nnyZ7w7dkffe9Q9vXq/Us6b0EjVLD1B1TyTaWY+BbDAalxj5qqCFUGiW9RzcnN7OQP76Pq7Hz3A3cLJP5E/C1JElGIcQzwBdA72tpQAgxEZgIEBAQUGnZOq+8cp3dLB+zzcy57HNcyL1AkFsQWpX2Sp9wr+PH5fPnyC3Mh4xMHh41kscff7w4t8LVsttPPz2R7Vv2sGvP9jJS0lpXT7CaUFoMsuz2XSq7fSLjBF/HrmV8+lv4NXbHp36RRpck2X8UA7uBmz/mxEQuPPMM5vMXqLvofdwGDKjRfgkh8PnXdAqOHiXlzTfRhbRG27ixw9r3d/FnccRixv86nv/77f/4+L6PUStKTK7rakGTfnD8W/uIqUQ+dEfi17gWIRH1iN6ZQOP2vtRtelt4rm87HDFiSATql3hfr2hfMZIkpUuSdCVIfznQvrp1S7TxqSRJYZIkhfn43NxVjGqFmgauDQA4n3ses+2fxTIKpRL32n5YgSefepLmzZvzwgsvFB+/WnY7NTUF79oeDH3oEf417fnSUtJC2CUyTAbsjixZdvtukt22STbePPAmrbO6osxzom2/Bv8cTPoL0s9A60cojInh3MiRWFLTqL98eY0bhSsIpRL/BQ
tQ6HQkPv88toICh7bf1rctszvP5s/kP1lwaEHZAiGPQl4qnPvNoee9mk6DG+Hm7cTOL09UHEJ+j+MIw3AIaCKECBJCaICRQKnVKkIIvxJvBwGxRa9/BfoJITyKJp37Fe277dCqtAS4BmC1WbmQc6HU6mitszPRf8ewbuMP7Ni6ldDQUEJDQ9m8eXOZdo4dO0av+7rR54FuzJ37Jq+88mrpArqiJ5iifLuy7PbdI7v9/enviU47RvdLQ/D0dyawldc/B49tAKWGvJw6nH9iNEKlJnDNapw7ht/UPqpr++L/n/9gPBNHigPnGq4wpPEQxrQYw9cnvmb9qaseOJr0B60bRNecOwlArVXSe3QwOZcK+POHszV6rjsWSZJueAMeAE4BccCrRfvmAoOKXr8D/A0cBXYBzUvUHQ+cKdqerM752rdvL11NTExMmX01QY4xRzp+6bh0LuucZLVZi/fbbFbp0tkzUsrpk5IpJ6fKdowFZik1PlvKyywsfcBmk6SU45J06fR193H27NnSggULrru+zD846nOVUZAhdf26qzR9xavSB8/skE4cSP7noNUiSQuaSjlzHpBiW7WW4gYNlkwpqQ457/WS+v4iKaZZcynrhx8c3rbFapGe2faMFPpFqHQw+WDpg99PlqS36kqSKd/h572a3atPSB9M2iElncmq8XPdLgCRUjV+Yx2yjkGSpM2SJDWVJKmRJElvFe2bJUnSj0WvX5YkqaUkSW0kSYqQJOlEiborJElqXLR97oj+1CSuGlfqutTFYDaQYkgp3i+EAne/uiAEOakpVYb9aZxUaPVqDDkmrOYSZYWwr4Q25YJV1ne5W/jvkf+SZ8wjLGkArl5ONAkrIWYXv4+c2CwSvolH26wZDb5Yibr2rc0j4DN1Crqw9iTPeQPjWcc+VSsVShb0WEB9t/q8sPsFLuaWiD8JecT+2T+5xaHnLI/Owxrh6uHEzlWxWGSXUinklc/XQS2nWnjrvMkszCSj8J+YaLWTE3pnF8wCCpKrzmrl4qFFAHmZhaUPXHEnFVzfilRZdvv2IvpSNN+d/o4xHpPIvmiibd8AFMp/vno5qz8gcb8HulatCPh8BcpqJIOqaYRKVbxeIvFfz2MrLKy60jXgqnFlae+l2CQb03ZOI+/Kws7A7uBS2y6sV8NonFREPNGcrNR8Dv1cM6uu71Rkw3Cd+Op9cdG4kGJIKRXG6uJbG6UQ5BUYsOaXn/znCkqVAr27BmOBpfRSfbUTqHTXbRhkbh+sNivzDszDR+dD0LkO6FzVBHf5Z8otP/IQSV8fQxfgTv3/rUDpWn4mwVuBuk4d/P/zLsZTp0h9+52qK1wjDdwasLDnQs5ln+PlvS/b5+0USnvo6umtYKw4uMJR1G/hSXBXP/7aeoHU+JoRFrwTkQ3DdXIld7RaoeZi7kXMRW4foVDgViS0l5uUWKVLSe+qQalSlNVR0nmAOV9WXL3DWX9qPbEZsUxvOIOk2GxCIuqh0thDMc2JiSRMeQ6V3kq9eTNQujhX0drNx6VHD7yenkDWN9+QXbTQ0JF09u/Mvzv8m90Ju/kg6gP7zpZDwVIIp25OHErX4U3Qu2vZuSq2tFv3HkY2DDeASqGivmt9bJKNi3n/SHVrXVxwcnKiUICxKBS0IoTCLs1ttdjIzymhQqkrcidcJcchc+eQbcxm6V9L6ejXkVpnglAoBS262ddvWvMMXHx2MlJhAfXvM6JqO/AW97ZifKZNQ9e2LSmvz8JUtPDSkYxqPoqHmzzM8mPL+TX+V6jfCVz9boo7CUCrU9Hr8WZkJBmI3BJ/U855uyMbhhvESeWEv4s/BeYCUvNTi/e71vazzx/kZlfpn9XqVGh0KvKzTVgtRU8sKi2odVAgG4Y7lU+jPyXPnMcLIS9y8kAKjdv7onfTIEkSya+8gjEujrrd89B2fMD+/75NEWo1dd9/D6FWk/D8C9iMjh3FCiF4peMrtPFpw+u/v87JrNPQYgic3gaFN8e9E9
jam2Yd63Dkl/OknZddSrJhcADuWnf06OnXvR+tQlrRsmVL5r75Ji4enlgUCvITE6rME12voS8SkJdZ4kvnVLE7aePGjcXqrTVNSeG8WbNmsX379grLRkVFlVq/8eOPPzJ//vwa7+PtxsXci6w5sYYhjYdgO+WKqdBK6171AMjdsoXcrVvxeeIBXLwy7K6T2xy1nx9+89/BGBtL2rvvOrx9jVLDol6LcFW7Mn3XdDKb3AdW402JTrpCt0eboHPTsP3zmHt+4ZtsGBxEgGcAa35aw/pd6/kz8k9++eUXok+eQqlUki/ZsFyuOkmI3k2DMd+MqbBoIroSd9KNGobrldyYO3cu9913X4XHrzYMgwYNYubMMkrsdz1LjixBrVAzuc1kju1OwCfAldpBblgyM0mZ9xZOrVrh1SQbtO7QMOJWd7dauEZE4DluHJlrvia3koeD68VH78PiiMVcyr/Ev05/icmt7k1zJwE4OavpMy6YzJR8/vj2zE077+2IbBgchFKhpHmd5gghOJd5DrPZjKLERLQhIx2b0UhycjI9evQgNDSUVq1alVq1/PZ/3qD3/V2LVxGj0hKfnEHv+wcTEhJCnz59uHDhAvv37+fHH39kxowZhIaGEhdXWhBs3LhxTJo0ibCwMJo2bVqsdLpy5UoGDRpE79696dPHnh1swYIFdOjQgZCQEGbP/kc99q233qJp06Z069aNkydPlmp7w4YNABw6dIguXbrQpk0bwsPDyc7OZtasWaxbt47Q0FDWrVtXKslPfHw8vXv3LnUtV9qcNm0aXbp0oWHDhsXtV3avbmeiL0XzS/wvjG05FkuihowkA6161kUIQepbb9tl2ufORpzeAsEPgerGJLRvJr4vPI9Ty5Ykvfoa5mqEZF8rrX1aM6/bPI6kHWFW3QCkuB031Z1av7knbe6rz7HfEu/pjG93pe7s3m9OcfmiY5OLe9d3ofujTSsto0DBoxGPcubMGcZNHEd4eLg9raJejzE/H2NiIqu/+47+/fvz6quvYrVayS8KaTUYDHTu3JlZr77Bi//3Ih9/+AlvvDmbqa+9y9iHH2DslJdYsWo106ZNY+PGjQwaNIiHHnqI4cOHl9uX+Ph4Dh48SFxcHBEREZw5Y38COnLkCNHR0Xh6erJ161ZOnz7NwYMHkSSJQYMGsWfPHpydnVm7di1RUVFYLBbatWtH+/btS7VvMpkYMWIE69ato0OHDuTk5KDX65k7dy6RkZF88IE9wmTlypXFdaZOncrYsWMZO3YsK1asKL4WsBuBffv2ceLECQYNGsTw4cNZs2ZNuffqVmC22jh6MYuoi1molQpctCpcnVS4OqmL/qpw0dq39yLfw8vJiydbPsmez+PQOqto2qE2ubt2kbNpE97PPYeTMgGM2XeEG6kkQqOh7nsLOTfsYRJnzKDBF18Uy807ivuD7ichN4Elfy0hwFXH5JObyyQuqkk6DW5IQmwGO1fFMvL1jujd7hzD7SjuSsNwq1AqlUQfjeZ00mkef/Rx9h/ZT9f2XXH19iX94nnyLSbaN2vG088/j9lsZsiQIYSGhgKg0WiKE9q0a9uenbt3YLPa+OPQEb5bNg8Ksxg9ejT//ve/q9WXRx99FIVCQZMmTWjYsGFx4p++ffvi6WnPMb1161a2bt1anGshLy+P06dPk5uby9ChQ4tltAcNGlSm/ZMnT+Ln51cszleZGN4V/vjjD7777juAMtcyZMgQFAoFLVq0IDXVPonfoUMHxo8fX+Ze3QxsNol8s5XcQjPjVx7iz7PpGKrhd1Y6n0IfcARVxnAefu9PHoiXKAzSczo1C/Wb89A2aYL3MxNh01RwcoegnjfhahyLJjCQOrNnkfTSTC5/vAyfKc85/BwTWk/gYu5FPj7zPT7HV/LITTQMKrWSvuNbsn5+JNtW/M3AaaEoFI6XO7+duSsNQ1VP9jVNY7/GdOvZjY2bNmIptDD9uenYrFb+b8pzPNyjJ7u3b2fLtm2MGzeOF1
54gTFjxqBWq/+RxnZ3wmK2kJdVNOl8JTpJ61ntPlRXdvvll1/mmWeeKVV28eLF13PZN0RJ2e0rE/U9evRgz549/Pzzz6Xu1Y0gSRJGiw2D0YLBaKXQYkWSijTDAJsklXqfXWAh/rKBIW3r0q2xNx2CPBFAbqHFvhnN/7wuMLH64ufkWrzoFTAQ7ZlCFBSyPjOT0zP+y/NJSVx46W0CsKE6eee5kUriPngweb//zuWPPsKlW1d0DjbaQghe7/w6ly/8zlxTMsbo5TwRMqHqig7Cq64LPUY2ZdeXJzj08zk6DiwnudBdjDzH4CAuXbpEVpbdF1pYWMjBPQdp3LQxfsF+RB6JJOroUR4Y0I9TaSl4WixMmDCBCRMmlCslrVQpUKoVFOaZ6dSpM2s3/QbmfFav+oLu3bsDVCm7vX79emw2G3FxcZw9e5ZmzZqVKdO/f39WrFhBXp7d7ZaYmEhaWho9evRg48aNFBQUkJuby08//VSmbrNmzUhOTubQoUOAPXGPxWKptF9dunRh7dq1AKxevbr4Wiri/Pnz1K5dm6effrrCe1UVZquNnAIzKdmFnLtsICY5h1OpuSRmFWAwWdAoFTipFTgXuYZq6TV4OmvwcXMi0MsZf3cndr7Yi7eGtub+1n54u2jxctES6O1M63rudGnkTf+WdRjevh7NGqaQajrJ/4VPZv6wtjTLV+DX2J0tM3vwbMJezvs04JkTal59/0MwZmNtfvuuXagOdV5/HVWd2iS+9BI2g8Hh7asVahZ3ms19hnze/eu/fBr9aZXRfY4kuIsfzTvVIXJzPBdi0m/aeW8H7soRw60gOTmZsWPHYrVasdlsPProo4wePpr4nHgS8xIJcA3A1cuHfd+s54lPJ6LR6XB1d2fVqvKTpKs0ShRKwVuz/8Pz/36WBYv/i49vHT5fZc9nMHLkSJ5++mmWLFnChg0baNSoUan6AQEBhIeHk5OTw7Jly3Bycipzjn79+hEbG0vnzp0BcHFx4auvvqJdu3aMGDGCNm3a4OvrW24uB41Gw7p165g6dSoFBQXodDq2b99OREQE8+fPJzQ0lJdffrlUnaVLl/Lkk0+yYMECfHx8+PzzyjUTd+/ezYIFC1Cr1bi4uFR4r65gs0kUmK3kmyzkm6wUmKyYrPZ1IQKBVq3AXadGr1HirFGhUSmqzIhWXReCJEksO7qM2vraDGk8hOQz2WSl5tOufzDKXdtwupxCl48+ZJlvC6QfvyBX0vH4Lxreccumpf+dmZlO6eqK//z5XBg7jtT/LMDvjTkOP4emQTcW5CuZ5erB0r+WAjAxZKLDz1MeQgh6PNaMtAu5bFsRw4hXO+DiUfZ7dDcibqYFdhRhYWFSZGRkqX2xsbEV5g24laQXpJNiSMFX72sX3ktOxFxQgKvJglPjxijU6grrFuSZyE0vxNXLCV1+nF1HxrtqN9m4ceMqnZi+W5AkiXyTldxCCwajhXyztfiJUqNUoNMo0WtU6DVKnNRKlNfhJ67u5+qPpD+YuG0ir3Z8lZHNR7JjZQxxUZcY91YnLg4bjNDrCfruW4TNivReU5I8OzIk9SkyDSaei2jMcxGN0ajuzAF86n8WkLFiBfWWfYxrr16OP8GmF7Ad/ZpXuj7B5vO/8kGfD+hRr4fjz1MBmSkG1r8TiVddF4b8X1uUyjvz/wQghDgsSVJYVeXu3Cu8Q/B08sRd605afhoGswFXLx8koFCpwJKUVOnQ2MlZjVqrxJBpxKYtyuwmS3FTaLaSnF3AyZRc4i7lcSm3EAkJbxcNgV7OBPu50dzPjQZezvi4anHWqq7LKFSXK6MFX70vw5oMw1hg4czhNJp2qE3Bzq2Yzp/He/Kz9tHJhf2I/HTqdh7Btud7MLCNP//dcZpBH+wjOuHOXOXu86/paJs2Jfm117Fml5PP+UZpMQiFOZ853p1o5tmMmXtncjHnYtX1HIRHHWciRjcn5Wz2PZMrWjYMNYwQAj9nP7RKLQl5CaBSoHd3x6
RUYMrLq/SLJITAxcMJm03CYHax7yys+ou3cuXKu3K0YDBaOHfZwKnUXC7nmtCqlQR46mnh70ZjX1f83HW46dSob/IT3aGUQxxJO8KE1hPQKDWcPpSKxWyjeec6XP54GdqmTXEtWjdCzI925dwmfaml17BoRCjLx4SRmW9iyIe/886WWArNd9aqW4VGg987b2PNyCCtJgIXGnQDnSdOJ7ewqNciBILndz9PgcWxqUcro0lYbVr3qsfR7ReJ+6ty/bO7AYd8g4QQA4QQJ4UQZ4QQZZa5CiFeEELECCGihRA7hBANShyzCiGiirYfr657N6BUKKnnWg9JkkjIS8DZwxOFUkmhkwZLcjJSJauQ1VolOhc1BQYbFoXLPamdZLbaOHspj7hLeRSYrNRxcyLYz5Ugb2dq6TUoFbf2+WZZ9DJ8dD4MazIMgNjfk/Cq64Lu9EFMZ8/i9cxEhEIBNhv8P3vnHRbV8f3h925j6d2KFAXpFuy991hij8ZobDG/xJiq+cb03ntiounRRGMvgMbesFeQIgjSpPcOuzu/PxYLAoLSje/z8LDeOzN3FnfvuXPmnM8J3QHOQwJ9H+AAACAASURBVEB1MzpsqEdz/n1uAFO7tuHHg5GM+uowJyKb1manoacnlo/OJHPtOgouXqzdweUKcBsNl3dhZ2jLR/0/4nLGZT46WfvSHHeizyRnmjmase/3EDKTGy6npj6o8TdKkiQ58B0wCvAAHpEkyeO2ZueArkKIDsAG4ONbzhUIITqV/pQPmL9PuC62l1+ST0phKqZWNmiAYglKrt05g9TYwgBJJpGjtUEU54L23uQsmiKFJVquJOeSX6yllbkhbi1MaWamRtFI/Lxnk85yKvEUj3s9joHcgNS4XJKjc3Dv3ZL0VatQOthjNnKkvnH8achJ0NcbuA1zQyUfTurAmvk90Oh0TFt5nNe2BJFT2HRch7bPPIPC1paEN9+848POPeE+HoqyIfIAfVv3ZY7XHDaGb+RU4qnavc4dkCtljFjgiSSX2LUq6L6u+lYb367uQIQQIlIIUQysBcp88oUQ+4UQ103sccCuFq7b5DA3MMdKbUVaQRrFBgKlWk2hSokmO+uOLiWZXIaJhQElGjlFOmN9xux/gNxCDVdSctEBbW2NsTE1aHSJRj9e/BErtRWT2+tdd6EBCcgUEnZSDIVBQVjPm3czMzh4K8iU4DK80vH6ONuw69n+zO3jxOoT0Yz88jDHrjSN1YPcxITmr/yPouAQMv76u3YHbzsADMwgRO9UeLLjk7Q2ac07x9+hWFtcRefaw8zakKFzPEiNzeXwusv1dt36pjYMQ2vg1p2guNJjlTEPuFUyUS1J0mlJko5LkjShFubTqGlu3BxDhSHXcq9haGmBTgiKDA0oqcKlpDZRolDKydXZoCu4/w1DdkEJUWl5KGUynG2NMVI1vsjqwJRAAq4FMNtzNoYKQ7QaHWEnEnHqYEvuHz+haNYM8wmlH2khIGQ7tB14UxyxEoxUCl4f68GGRb1RyiVm/HSc9/1CKNI0/idU0xEjMO7bl5SvvqIkqRZ98QoDvUEN8wetBkOFIct7LCcqK4pfgn6pvetUA0dvG3xGOhB8NIHQ47WvF9UYqNf1uCRJjwJdgU9uOexQGj41A/hSkqR2lfRdWGpATqekpNTDbO8NrVZL586db8hb3I5MkmFnaockSVwrTsLQzJwiwMbHh5LExErHlSQJEysDdEJOfr6CLZs23bey29kFJUSn52OolNHW1hiVona1eGqLHy/+iLmBOdNcpwFw9WIqhXkltGtZQP7x41jNmYNMVZrZnBQEmdH6bOdq0sXBEr8l/ZjR3Z5/9p5l8WvfcnTbZs76beXSwb1oa9tdUwtIkkSL119DaDQkfVjL5UDdH4L8NIg9DkA/u36McBzBqouriM6Ort1rVUGPsU60bm/BwTVhpMXXri5bY6A2DEM80OaWf9uVHiuDJElDgeXAOCHEjQIDQoj40t+RwAGgc0
UXEUKsFEJ0FUJ0tbW1rYVp1w1fffVVlXHvKrmK1iatKdIUkWtQjFypREgS2sxMtHfIZlapFajVEvk6czZt3HBfym7fNApyHG2MG81ewu2EpodyMO4gs9xnYazUbySHHEvA2MIAA//fkZmbYzF16i0dfAEJXEdX+xpCCFIvX6Lr5U3Mivsb18jdHF/zM/t/X8XO77/g79deIi2+/sI2q4vK3h6bRU+Q47+T3MNHam9g52EgN4CQHTcOLeu2DJVcxTvH36nXrGiZXMaweZ6oDBXsXBl0Uyr/PqE2vnWnABdJkpwkSVIB04Ey0UWSJHUGfkRvFJJvOW4pSZJB6WsboA9QP4/BdUBcXBy+vr7Mn1+5pst1Kel+3fsxqf8k9h7ch8zcEIBXv/+Ozt260bNHjxtCcrdLVaflpXLqzHG2+/rfd7LbPXr2wsPNhQM7t+NoY0RKUlKjld3+JfAXjJXGPOL+CAB5mUXEBKXh7GpA3r69WM2cWbaGc8gOsO8JJs2qNX5JYSHbPnuf9e8sJyHiMr0mP0Kz+e+wyv5xCqe9ydjnXiYrJYnVy5Zw1n97vd4Uq4PVvHmonJxIfOedKisYVhsDE2g3SG9kS9+vrZEtz/o8y4mEE+yI3FHFALWLsbkBw+d5kpWcz/7VoY3u/6Am1NhxK4TQSJL0NLALkAO/CCEuSZL0NnBaCLENvevIBFhfKkEQUxqB5A78KEmSDr2R+lAIUWPDsP+3lSRHR9Z0mDI0c2jLoDl3TsV/9tln+fjjj++oYXSrlLRGoyEsKYwkTRr5+fl07NSJt55ezBvffsuqVat49dVXy0lVP/f8c/z144+MGDqa8Q+P5ZGZ0yu8TlOT3Y6Ni2fVej/ioyJYPPcRFs97tFHJbt9KTHYMu6J3MdtzNmYqvaps6PEEhIBmwX7o1GosZz16s0PGVUgKhOHvVmv8vMwMNn/0NslRV+g3Yw4+o8ahUKnoDUTkK1l1LBrXKR2Z8+l3/Pvj1+z/7Ue0mhK6jZ1Y+2/2HpGpVLR443Vi5jxO2sqV2D7zTO0M7DYGLu+ExIvQsiMAU1ynsO3KNj459Qn9WvfDQn3nPZzapLWrJT3Gt+X4lkhaOVvcqNLX1KmVdboQwk8I0V4I0U4I8V7psddLjQJCiKFCiOa3h6UKIQKEEN5CiI6lv3+ujfk0BDt27KBZs2blbqC3061bN3799VfefPNNgoKCcG3pilySo1QpGTF8GAWGBnRq25ao0hv5sWPHmDFDLzk8a9Ysjhw5gqGZGgkthbnFCF3FTyl3K7vt4+NDaGgo4eHhHD58+IbstpmZWbVltxWKOz9nVPReADRaHT0GjUStVDCsT5cystvX/1aBgYGYmprecfz64rdLv6GQFMxynwXoXT6hxxJp6WCE1n89FlOmoLC0vNkh1Ff/263q/YXU2GjWLH+etPgYxr+0nO7jJ6NQ3VRgffUhD3q3s+aVTYGE50hMWPo67Xv25dCaX4k8V3+hm9XBuGdPzMaNJXXVTxRFRdXOoK6jQZKVcSfJJBmv93qd7OJsvjj7Re1c5y7wGe6Ao7c1R9aHkxR1f9SLbnyhHrVAVU/2dcHRo0fZtm0bfn5+FBYWkp2dzaOPPsrixYtvyFq//fbbN57Kb5WSnvzIZBRKBcUmMpRZJegMVBRnZyN0ugqvJRmYopQVodNBfk4xxublC8k3FdntwhItuUUaDA0NcLLV7ynUpex2TUnJT2FLxBYmOE/A1ki/15V4RS+Y52ytd+lZz5ldtlPIDmjuBVZOdxw7LT6Wf95+BZlMxvQ3P6J5W+dybZRyGd/N8GHst0d4cvUZti/uy8gnnyUj8Rq+X33CjHc/w9quTQWjNwzNly4ld99+kt57nzarVlYpWlglxjZg30ufKDh4+Y3DrlauPOb5GL8G/cq4duPo0vzOD2i1iSSTGDLHg3/eO8WuVUFMXd4NtXHlGmhNgca5s9cE+eCDD4iLi+Pq1ausXbuWwYMHs3r1anr06MH58+
c5f/4848aNq1BK2khphAwZ2eQhGakokSR0Oh2a5OSKpaplMszNTSnMTyc/qxitprwBaQqy23379uVqWh4g0dxUXU7KojZkt2ub1SGr0Qotczzn3DgWciwBpUqG8c6fMR8zGmXrW6K1c1P0UTRuY+44bmZSIhveWY4kSUx944MKjcJ1LI1V/PBoF9LzinlqzVlQqpjw0qvIlUq2fPI2Rfm1L4F9ryhsbLB9ZjF5R47UXp1ot4cgORjSyu6tLeqwSJ/bcOwdSnT1mxioNlYyYqEXedlF7PktuNKVfE3Iyi/hm73h6Opg7Nt5YBjqmQMHDtCxY0c6d+7MunXrWLJkyY1zZgZmpClyQSahlcvRpKby5Ucf8+uvv9KhQwf+/PNPvvrqKwCmT3+E7378hsGj+hB4tvy2zHXZ7VGjRt1RdnvGjBn06tULb29vJk+eTE5OThnZ7VGjRlUpu92xY0eGDRtGYWEhgwYNIjg4+Mbm86188803Zd7Li298QIlWYGIgrzD66E5/q4YgpziHdWHrGO4wHHszewBKirREnE7GzjgdeV4WVvPmle102R+E7o5upJy0VDa8uxxNcTGTX30Xq1ZV+6m9Wpvz4SRvTkSl84FfKGY2zRj73MtkJiYQtL+WbsC1hOWMGRi0b0/SBx+gK6gFfaPrRja07GazkdKIl7u/zJWsK/wV8lfNr3OXNHc0o+9kF6ID0zj7b+2Gz4Yl5jDuuyN8vS+ci/H1kMckhGhyP126dBG3ExwcXO5YU0Oj1YjL6ZdFRGKoSIi4LDIuh4mCsDCh02gqaFwiRPxZkZuYIpKuZomi/JIbp2bPni3Wr19fjzO/e+Iz8sWF2AyRllvU0FO5I7d+rtYErxFev3mJoNSgG8dCAq6Jb5/YK44PniZiFj1ZfoA1U4X4wksIna7C8UuKisTvSxeLr2dPFgnhYXc9vze2BgmHZTvEe77BQqPViTWvPC9+eW6R0FVyvYYi7+RJEezqJpK+/LJ2BlzRR4ifhlV46qk9T4nuq7uLpLyk2rnWXaDT6cSuVYHiu0V7RVxoeq2MuePCNeH+mr/o+u5ucfpqWo3GQh8QVOU99sGKoREhl8lpY9qGYoUOrYFEoVxCoympOPFNrgCVCUakIFfIyMkobDLhcul5xaTmFmFjYoCVcdMobSmEYEP4BjytPfG09rxxPCQgAVO1BtP4C1gvuC1MuSgXruzXrxYq8a0f+PNnUq5GMuaZpbRwvvuStK+OcWdWTwdWHorkiT/P0H7AMNLjY7kWFnLXY9UlRt26YTZ2LOk//UxxXLk0p7vH7SGIPQm55bOrl3Vfhkan4dPTn9b8OneJJEkMfNQN82ZG7Pr5EnlZRVV3qgStTvChfyhP/XUWtxam7Fjcly4O1S/vWxMeGIZGhlqhpqVJS7IMikCSKDQ0RJuRgTa7gmgHtQWSthATMxnaEh0FOXrNmMYsu51fpCE+swATAwUtzJtONayLqRcJzwi/oYkEkJmcz7XwTJrHHsKoSxeMfHzKdrqyF7RFle4vXD5+hAv/+tJ17ETa+pR311UHhVzGOxO8eGucJ/tCk3g9UIFCbcjFvTvvaby6pNmLLyCA9Coq91UL19GA0Ieu3kYb0zbM856Hf5Q/JxNO1vxad4lKrWDkQi9KCjT8+9MldNqKg0juRGZ+MXN+PckPB6/wSHd7/l7Yk+Zm9fd9ua8MQ1N5Yq4KCwMLLAwtyVGXoNFpKTZUUxIfjyi5bUNNrS8JqSIHlaGCvMyKN6IbCyVaHdHp+SjlEvZWRshqGqFSx9z6edpweQNGCiNGOY26cSz0WAISgubhe8qvFgBC/cDQEtr0LHcqMzGBXT98TUtnV/pOr3mk1ezejvz2eHficrRcMnQmJOAwhbmNS6pB2bw55g89ROamTWgyMmo2WAtvMLe/GQp8G3O95tLapDXvn3i/3jeiAaxbmzBwpivXwjM5se3uQnWDr2Uz9tsjnIhM54OJ3nww0RuDepaFuW8Mg1qtJi
0t7b4xDi2MWyA3NKBEJSiUQCsExfHxZd+fQgVKQ6TCLEwsDRAI8jLvfelal+h0gui0fLQ6gYN145W6uI4QgrS0NNRqNTnFOeyM2snotqNvyF/odPrcBZvCq5g5NMNkwICyA2g1+qfZ9iP1br8yp0rY8dXHSDKJMUuWIq8i/6O69G9vy/bFfcmw64zQlLDi53/qJYLlbrCe+ziioICMv2uovipJ+hoNV/brXXa3oVaoG3QjGsC1Z0s8+rbi7K5orl5MrVafbReuMXHFUYpKdKx9oiePdLev41lWzH2Tx2BnZ0dcXByNWWDvbtHqtKTmp6AulCGXZBgkJCBPTUV2Sy4ChVn6H7Niigp1FBdoMTJTIVc2rhtvRl4xecVarI1VXM1unKJ4t6NWq7Gzs2PjlY0UagvLuJHiwzLIyyzC8co+rF+aXz4+PyYACjMr1EY6um41SZHhjHv+FcybNa/VOTtYG/PH0kl8uXgP2acO8IFfD5Y/5Fl1x3rCwMUF4wH9yVi9Buu5c5FVEC1XbdzGwIkf4Mo+8CifhDmwzUD62/Xn+/PfM8ppFM2MqidHUpv0m+ZCcnQ2e34LZuor3TCzMaywnUar4+NdYaw8FElXB0u+f9SHZqYN52q9bwyDUqnEyenOCURNkcNxh3n/zxcZcN6WDgZmtLkQjNOmjRi0KxWhTQyEH0bA2K8p9pzJX2+ewMhMw+SXuzaa2gU/HY7kXd9Inh3qwrO9736DtSERQrD+8nrcrdzLbDpfPpGIQhTTUpWC2ahR5TuG+oFCra/WdgvRF89zattGOgwZiUuP3nUyZ0OVnNGTJ7Dnp+/5Z88JHG1NmNnDoeqO9YT13HnEzJ5N1patWE6fdu8D2fcGtQWE+VVoGABe7vYyE7ZO4LPTn/FR//qt+AagUMoZudDrRvLbxBe7lHtoS88rZvHfZzkakcZjvRx4dYwHKkXDPtg1rsfKB5Sjn10/ho6YSUyzfAK1ueSbmRD/0kuI4tLiJM299L7WMD9UagV9JjuTEpND8JFrDTvxUg6Hp/C+XwgjPJvzzGCXhp7OXROUGsTljMtlVgslxVqunEnENvE0tnMeQ1LeluUqBIT56msv3FLCMz87C//vP8eqlR0DH6tcaLE2cO87EKWBmqHyKF7feomDlxvPStqoezfUXl6k//orQluDGhNyhd5Vd3lnpVUN25i1Ya73XPyi/Oq12tutmNsaMWS2B8nRORzdEF7mXFB8FmO/OcKpqxl8PLkDb4/3anCjAA8MQ5PgqU5PUTTEkSJJw/mOzhQGh5DyjV6kTu9rHXPD1+rcpRmt21twfOsVCnMbtizk1dQ8nv7rHC7NTPl8aqdGs4K5GzaEb8BQYchop5suoasXUykpgZY5l7CY+HD5TklBkBlTxo0khGD3ym8ozMlmzJKlKGviQqkGKkMjXHv3o1lKCO42Kp5ac5bQxMah4yNJEtbz5lIcHU3ugQM1G8xtNBRkQMyxSpvM85rXoBvRAG0729JxaBsCD8Zz+ZQ+/HzLuXgmrQhAJwTrn+jF1K6NR8rkgWFoAshlct4f+SkhHTVkpKSSMKQvaT/9RH6pHAVuo/VhkVf2IUkS/aa1p7hAy4lttaswezfkFJYw/4/TSBKseqwrxgZNz2uZW5yLf5Q/o51GY6IyuXE89OBVDIoyaDuqS9n9nhsN/NDXXrjpYoo4dYyIU8fpPfVRmjm2rYfZg/fgEWiKClnWvgBjAznzfjtNck4tSWDXENNhw1C0aEHGmjU1G6jdEH2Nhkqik0C/Eb2021IiMiMabCMaoNfD7WjR1pz9f4by7t8XeXbdeTq2sWD74r50bFN/irDV4YFhaCJYqa147rGPiLct5GxmIsUObYhftkxf2OdWXyv6UDnvAa0JOhxPSkzlEuB1hU4neG7deaJS8/h+hg/21kb1PofawC/KjwJNQRk3UkFOMbHhOTRPPoPVozMr7hjmC22636i9UFyQz77fVmJr70iXMfVXvbaliys2bR
yICdjHz7O7kZ5XzILfT1PQCIrYSwoFltOnkRdwjKLIGjzAGJjoXXZhN2s0VMSgNoPo17ofKy6sICW/YdxqcrmMbjNcKNDp0BxO5vEeDqyZ3wMbk/IimA3NA8PQhOjcvDMdZk5Bg5bDLtaUJCWT+PY7Ffpau491wtBEyaG1l+s9hPeLPZfZE5LMa2Pc6e1sU6/Xri2ubzq7WbmV3XQOiEUgo52LEmXLluU7ZsZCwoUybqSA9WvITU9j6IKnay00tTpIkoT34OEkXgmnWUkaXz/SmYvxWTz/z/lGEcZqMXkyKJVk/FXD0FW30XrXXdKlSptIksTL3V+mWFvMZ2c+q9n17pGLcZlM+/MUO4yKsdXJ6JMpQ9FI3asPDEMTY3bPheT3aUlRcjaRo/qSvX07Wb6+5XytBkZKek5oR2JkFpdPVF5LurbxC0zgm30RTO1qx+zejvV23domOC2Y0PRQJrtMLhOKGrI7HOPceJwer6QoTpi//ndptnNS1BXO+m2nw5ARtGrvVtfTLod7/8HIlUoC9+1imEdzlo92xz8okY93hVXduY5R2NhgNnIkWVu2oM2tgSJs+1GAdGPFXBn2ZvbM9ZqLb6RvvW9Erz8dy+QfjiGTJL5c0otuYxwJO5HYaIJEbqdWDIMkSSMlSQqTJClCkqRyBX4lSTKQJGld6fkTkiQ53nLuf6XHwyRJGlEb87mfkSSJpfO+JKM5BCXEUuztQeJbb1Ni7Kn3td7y5XDv1ZJmjmYEbLpCcUHd16SNSM7hxfUX6GxvwTsTvGquvd+ArL+8Xr/p3Pbmk39mYi5puSrayOMw7Nix4o5hvmDtAjYuCJ2OvT99j6GZGf0emVM/E78NQxNTXLr3JuTwAUqKi5jX14kZPez54eAV/jnd8PWiLWc8gi43l+zt26puXBmmzcGu6x33Ga4zz3serYxb1dtGdIlWxxtbg3hpw0W6Oliy7ek+eLU2p+sYJ9q4W3J4XXiDuHurosaGQZIkOfAdMArwAB6RJMnjtmbzgAwhhDPwBfBRaV8P9DWiPYGRwPel4z3gDpgamDJ58asgBNuba9CVlHDttXcQTv3L1MOVZBL9p7cnP6eY43W8EZ1TWMLCP89gpJKzYmaXek/hr03ySvLwi/JjhOMITFU3q8aFbNY/ZbqP61Rxx4JMuHpEv3oDgg/vJyEijP4zH0dtYlJxn3qgw5ARFOXnEX4iAEmSeGucJ/1cbFi+OZDjkWkNNi8Aw06dUHt4kPHXXzVzebqOhoTzkHVngT5DhSHLui8jIjOCtaFr7/161SAlp4iZq07w+7Fo5vV14o+53bEu3U+QySSGzfVEbaJk56ogivIbNoLwdmpjxdAdiBBCRAohioG1wPjb2owHfi99vQEYIukfJ8cDa4UQRUKIKCCidLwHVEEnl560Gt0PwxQtx0Z4kH/yJOlXbCEzWl/EpJTmjmZ4D7Aj8EAciVF1o+MuhGDphotEp+XzzSM+TUocryJ2Xd1FgaaASS6TyhyPvJiGWX4cLccOqbhjxB7QacB1DMWFBRz++3daOLfHo9+geph15dh5eGNm24zQIwcAfRW4b2f4YG9lxKLVZ4hMaThNJUmSsJw5g6LwCPJP1EDw7rpQYRXuJNBvRPdt3Zfvzn9XZxvR52MzGfvNES7GZ/LV9E689pBHORkYQ1MVI+Z7kpNWyN7fQxqVnE9tGIbWwK1r0rjSYxW2EUJogCzAupp9AZAkaaEkSaclSTp9P8le1IRZM5ahbW1KUmw6mb07kLz+CIUZitJwyZv0HN8WY3MDDqwORXsPSo9V8f2BK/gHJfLySDd6tbOu9fHrmx2RO3Awc6Cj7U13UWZkIpnCCvtWIKkqkQoP9QVjW7DrysktG8jLSGfQ7IVIsobdypMkCdfe/bl68Rz52fqHA3NDJb/M6YZMkpj180li0/MbbH5mY8Ygt7SsmeqqTXuwalctwyBJEv/r/j+KtcV8fubze79mJaw7FcPUH46hkEtsfLI34z
tVeEsDoKWzBb0ntiPqQirn9zS8a+86TWbzWQixUgjRVQjR1dbWtqGn0yiQZDLmvPAhciFji2E6WJhx7XQrdEFlK1upDBX0n96etPg8zu+OqdU5/B5wlU92hTG+Uyvm92v6kiSJeYmcTjzNGKcxZTed1x0BwO3hSmoJa4r1K4b2I8lKTeX0jk249x3YIBvOFeHeZwBCp+Py8aM3jjlYG/PH3O7kFJbwyKrjxGfWQnW1e0CmVmM5cya5Bw9SFBFxb4NcF9WLOqzXDqsCezN7Hvd6nB2ROzidePrernkbxRodyzcHsmxjIN2drNj+dF88W5lX2a/jkDa07WTLsc1XSIjIrJW51JTaMAzxwK0pe3alxypsI0mSAjAH0qrZ9wF3oFlrB7pMmkyzVAPW9TOjKE1D8q5IyC4b7dC2ky1tO9tyyvcqmcm183S49mQMb2y7xHCP5nw6pWOT3my+jn+UPwLBmLY3aygIIYi6nI+JJp3mPSoRpLt6GIqywW0Mh9b8iiST0W/GnPqZdDWwsXfE2s6e0KMHyhz3am3O6vk9yCooYcaq4yRkNYxxsJw5A0mtJu2XGqwaXMeArkRvoKvBfO/5tDJuxXsn3qvxRnR8ZgHTVx5jzYkYFg1ox+9zu2NZzSJUkiQx+DE3TK0M2PXTpRt1VRqS2jAMpwAXSZKcJElSod9Mvj3EYBswu/T1ZGBfaZm5bcD00qglJ8AFqP/KGk2cQQ/PwtihFcpYCO5vT8ZlE3LXf1+uXf9p7ZHJJY6sD69glLtj+4Vr/G9zIAPa2/LNjM4oG7mMdnXxjfTF28b7Rk1ngIyTF8lQtcKx3R0SkcL8QGFIbJEtl48fofu4yZhaN54cDkmScOszgPjQYLJTy1Y962BnwR9zu5OWW8zYb442iK6SwtISi4kPk7V9OyVJ5auyVYs23cHIppwrtTIMFYYs7V7zjOjdwUmM/uowl5Ny+W6GDy+PckN+l/kJBkZKRi70pjC3hN2/XGrwPJMaf5tL9wyeBnYBIcA/QohLkiS9LUnSdcnDnwFrSZIigOeBl0v7XgL+AYKBncBTQoiGT8tsYshkciYvXo5aI8ffPJ8iK4lrK7aWK4ZibGFAt9FORAemcTWwevrwFRGbns/LG/Xhdz/OatoRSLcSkRFBWEZYmdUCQOjGYyDJcJ9YSVyEEBDmj67tIPav/h1TG1u6jqskz6EBceujrxkRevRQuXOd7S3Z+GRvrIyVzP7lJO/uCKZIU79fRas5c0CrJWP1n/c2gEyuT/QM3w3a6q0ABrcZfEOaOzHv7vJ9Cku0vLntEgv+OE0bK0N2LO7LmA4VJD1WE1t7U/pNcyE2JIPTvndX3Ke2qZXHPCGEnxCivRCinRDivdJjrwshtpW+LhRCTBFCOAshugshIm/p+15pP1chhH9tzOe/iE0bBzz6D8E1zpyPRqnQ5peQ+Oor5SIdOgy2w6K5EUf+CUdbcvcb0TqdYNnGi0iSxBfTOqFW3h9GAcA3yhe5JGeE4810Gl1BAdGxYCzLz3yyVAAAIABJREFUx7Z9JXr+CechO57AQldSoqMY8Og8lKrGJ3Ng0bwFLZ1dCQ0obxgAXFuYsu3pvszq6cBPR6KYvOIY0Wk1SDy7S1T29pgOG0bG2nX3nvDmNhqKsvRhw9Xgeka0Vmj5+NTH1b5MWGIOE747ym8BV5nT25GNT/bG0aYC3ay7xKNvK1x7tuCU31VighsulPj+WP9Xk9DEbBKzGoeIWF3QZ+qjKGRKWmQ1Y3s/yNl7gKxNm8u0kStk9JvmQlZKAef33v1G9JqTMQRcSWP5GHfsLJumBlJF6IQOv0g/erbsiY3hTRdQqv8eMkzb4eRhXvkeSqgfhTolR46GYufhRfuefepp1nePW98BpFyNJC2u4ggYtVLOOxO8WDmrCzHp+Yz5+gjbLtRfdq71vLnocnLI3LD+3gZoOwgUhtWKTrpOG9M2LPBewO7o3RyJv7NBEULw29Eoxn57hNTcIn6d04
03x3nW2qpZkiQGPOKKVUtjdv8STG5Gw9yv/lOG4fWtl+j70T6eXXuOwLi6ielvSMxsbPEZNRa7OBX+nmriHZQkvfcexTFlDYC9hzVOHW047Xf1rj54MWn5fOAXQj8XG6Z3azwSwbXBueRzXMu7Vs6NFO53HiGT4zrKu/LOYX4EFPagKC9fH57aiDfhXXv1Q5JkhAYcvGO74Z4t8FvSD9cWpjzz9zne3FY/fm/DDh0w6tqV9N//KF/jvDqojKDdYP0+w13kBTzu9TiOZo68f+J9cosrzutIzilkzq+neHN7MH2dbdj5bH8GudV+VTilgb64j7ZEx65Vl+okxLwq/lOG4dPJHZnVy4HdwUmM/fYIU384xs6gRLSNQFCstug+YQoGhsZMvGLHuw/pKEbLtaXLEJqykhh9p7ig0wlO+V6t1rg5hSUsXnsOmSTx4aQOjfrmdy9su7INQ4UhQ+xvJq+VXLtGfIYxhooSmjtVEnaYcZW0mEjOx8joMHREvUlq3yvGFpa08fQiLOBwlQlVrS0MWbuwJ/P6OvFbwFVe3RpUL8bBat5cNAkJZO/ceW8DuI2G7Di9mGE1UclVvNbzNa7lXmOW/yyu5ZZdJe0JTmLkl4c5HpnGO+M9+Xl21zpVRbVsYcygR91IjMzi2OYrdXadyvhPGQZ7ayPeGOvJsVeG8OoYd+IzC1i0+gyDPj3AqkORDZrkU1sYmpjSbdwkihOhZy58P0xDwfnzpP74Y5l2ZjaGePZtTWhAAtmpdw5RzCksYc6vp7gUn8VnUzvS2qLiurVNlbySPPyj/BnlNAoj5U33WNqW7aRZeeDU0QapkigTEeLLvqS2qNRqek99tL6mXCNce/UnIyGe5KtVy6Qo5TJeHePO/w1sx18nYnh9W1CdZ+iaDBiAql070n7+5d6u1X4kSLK7cicBdG/ZnRVDV5CUl8QM3xkEpQZRUKxl+eZA5v9xmhZmanYs7susXo718mDk0q053gNac2FPLFfO3WOk1j3ynzIM1zFTK5nfry0HXxrI9zN9sDFR8Z5fCP0+3s/orw7zxe7LXIzLbPCQsXvFZ9Q4jMzM6RhhRay3MSc7qEn9fgUFF8o+QfmMsAcZnNkZXelY143ChdhMvnmkMyM8W9T19Osd/yh/CjQFTHS5GUkkhODK7kB0chXOfStP3Is45EdMniW9pz2GkVnVyUyNAefuvZDJ5YQdO1yt9pIk8dIIVxYNaMfq4zG8tOEi2YV1p+0jyWRYz32cotBQ8gIC7n4AYxto06PaYau30qtVL1aPXo1aoWa2/xyG/vgNa07EsLB/WzY/1RuX5qZVD1KL9JnsQjMHU/b9HlJr+UfV4T9pGK6jkMsY7d2STf/XhwMvDmT5aHeMDeR8vS+ccd8epccHe1m64QI7gxLJK6p7ddLaQqlW033CVK7lmvFKpIZfhsvJMpMT/9JSdHk3oz1MLNV49mlV6aohK7+Ex345ecMojPK+91C8xsym8E04WzjTwabDjWMFZ8+SKFqhUghata+4upYmK5GDQUVYWxrSafjoCts0RozMzLH37lQtd9J1JEli2UhXFg92ZuPZOAZ/eoD1p2Pr7OHJbOxYFLa2pP/8y70N4DYGkgIh4+pdd3U0c2K45XsU5jUn2+xnHhsZyf9GuTVIWLZcKWPEAi8kmcSuVUFo6qnI0n/aMNyKo40xC/q3Zf2i3pxePpTPp3aku5MV/oGJLFp9hs5v7+axX07y14kYUnOLGnq6VdJx2ChMTAyJiTTihY6z+Xy0luLYGJI+/KhMO5+RDhWuGpJzCpm28hiX4rP5dobPfWsUwtLDCEwNZJLLpDLugfRNW0iz8caxoy3ySpL3zqz5hqwSNYOmTUUmb1phu669+pGdkkTilcvV7iNJEi8Md2XbU31pY2XESxsuMm3lMa7VgZSGTKXCctYs8gICKAwJufsBrhdKCru7CPiErAJm/nSCr/9NpI/RqwxpM4LN0St5I+ANSqqZG1HbmNkYMv
RxD1Jjczn8T82TU6vDA8NQAdYmBkz0seO7GT6cfX0Yfy3owezeDsSk5fHK5kC6v7eH6SuP8cexq42mhu7tKFQqekyYQnyBOR2CY/AeOo2tPSQy168nZ+/eG+1MLNV4XF81pOm/4LHp+Uz54Rgx6fn8MqcbI73uP/fRdTaFb0IlUzG23dgbx3T5+cQevUyJwpi2XSt+77kZ6Zw4cgFny1wcBk6usE1jxrlbT+QKBWGV5DTcCW87czYu6s3HkzsQfC2bMV8frpNsacvp05AZGd2bTIZ1O7B1q1aNhuv4Xkxg5JeHuRCXyceTOvDjrJ58PuhjnujwBJsjNvPknifJKmqYaEZHbxt8RjoQfOQaiZF1P4cHhqEKlHIZvdvZsHyMB/tfHIj/kn48PciZ1NxiXt96iR7v72Xqj8f47WgUSdmNy0h4jXwYU7Xg6OFAlnZbSuAEL6JbyIlb/gqaWxRqfUaUrhr8owlPymHyDwFk5pewZn4P+ro0HlmH2qZQU8j2yO0McRiCucHN/YGcvXtJNnFFLteH9lbEiY1r0GoFAwZ66wXcmhhqYxMcO3Uh7NgRhO7uwyFlMompXduwbXFfmpmqmfPrST7aGUpWLdYVkJuZYTFlCtn+/pQkJNz9AG5jIDoA8tPv2Cy3SMNL6y/w1F9ncbQ2wveZfkzt1gZJkpBJMp7u/DTv9X2PM8lnmOU/i9ichlFB7THWiYcWd6RF27rfy3pgGO4CSZJwb2nG88Nd2fP8AHY/158lQ1zIzC/mze3B9PxgL1N+CODXo1GNIpFOoVTSq483iTkK4o7s4tOhX/LzJFNKcnOI/d/LN/zLplb6VUNIwDXmfncMnYB1T/Sks71lA7+DumVPzB5yinOY7FL2iT9z23ZSm/tg72mN0qC8iygrOYmL+3bjbZGIRdcJ9TXdWse1d39y09OIv3wPrppS2tmasOWpPkzysWPFgSv0+nAvb267RPC1bJKyC8nIKya/WHPPkUxWsx8DIUj//Y+77+w2BoQWLu+qtMm5mAzGfH2YjWfjWDzYmQ1P9sapggzmce3GsWrYKtIL05npO5Nzyefufj41RCaX4eBZP7L2UmMqDlFdunbtKk6frh2p3NoiIjkHv8BEfC8mEJakL9XX1cGS0d4tGeXdgpbmDRPiqUuJ4LfnFyI3a85j36zhVNJp1n8wj7n/amj26nKsH9WHWB46n8D5H4KJNJV47uUeOFjXPL2/sTPbfzapBalsf3g7Mkn/jKRJTeXM6Ec57bOUwY+54967/N7Kzu+/JOzIXua6XsJ0eSgoqqei2dgoLixgxYJH8RwwmKHzn6rxeJeuZfHLkatsuxBPibbsfUWSwEgpx9xQSc+21gz1aE4/FxtM1coqx41/aSm5e/fifGA/cjOz6k9ICPjcA1r7wPQ1ZU5ptDq+P3CFr/aG08JMzRfTOtHdyarKIaOzo3lq71Mk5CbwTp93ypR+bQpIknRGCNG1qnaK+pjMfwHnZqY8M8SUZ4a4EJGci39gAr6BCby9I5i3dwTT5bqR8GpBq3rMA5DZOtPbWYdvcA6hRw/Svd8gwhct41zE+3T46ENMevUioMSURRvO85C5mva5YPkfWEiGZ4RzNvksL3R54YZRAMj28yPVyhtJAscO5Z/O0uJiCT60Dx/bFEy9BjdZowCgUhvi0r0XoQGHGPjYAhSVFSCqJp6tzPlsakeWjXTlSEQqhSU6ijRaCkt0FBRryC/WkpRTxL6wZDadi0cllzG+UysW9G9L+zuEgVrPfZzs7dvJWLcOmwULqj8hSdKvGs6thuJ8fVY0+j2059ad53R0BuM7teLt8V6YG1ZtoAAczBxYPWo1zx54lmWHlxGdE82iDovuu4TPByuGOiYyJRe/wAR8AxMJScgGoLO9BWO8WzLKu2W9JIuJfe/z52//UmzmxONfrkQmV/C+/1KGv7KDQntH5no9jWtLM1Y83IkdH5zGvXdLBs5sHAVm6or3T7zPxssb2TNlD5bqmy
6zqMlTOGw+GRN3Zya+WL4oz/bPPyDq3CnmOxzE6JGfwbPpupIAoi+eZ8N7rzJmyVLcevevl2tqtDrORGew/eI1NpyJo7BExyBXWxb0b0uvttYV3mRj5s6lKDyCdnv3ILsbAxZ5AP4YD9P/ArcxbD4Xx2tbLgHw7gQvJnSuvLranSjWFvPWsbfYdmUbD7V9iLd6v4VK3vgfEqq7Yrj/Hw0bmLa2Jjw92AX/Jf3Y/+JAXhrhSmGJjnd9Q+jz4T4mfHeUVYciicuou+QVyf0h+jW7SlZKCoF7/0WSJF4c/i5+w1pic/kq0/OP89eCnrSxMy3da0ggJ73h90jqivySfLZf2c5wx+FljEJRZBSZ4fHkKG1x7FB+0z0p6gqXTxylq7s5RgZycB5an9OuE+y9OmBqY8ulA9UrblMbKOQyerS15t0J3gS8PITnh7XnYlwWM1adYPx3R9l+4Rqa2/SBrObOQ5OSQvb2HZWMWgkOfUBtQXHQNpasPcdz6y7g1sIU/yX97tkogF5C490+77K482J2RO5gwb8LyCjMqLpjE+GBYahHnGyMeWqQcxkjUaLV8Z5fCH0/2s/4746y8tCV2pfmaOGNY2tTWlvJOLZpLVtPRTHlh1P8oZrHlZZyHjqyGW2ePkrJZ6QDAKf9r9buHBoR/lH+5JbkMtV1apnj2Tu2k2qjF8tzqsAwHN+4FgMjY7ooz0C7QWBgUi/zrUskmQzPAUO4evEcOWn3XqPjXrEyVvHMEBeOvjyY9x72IqdQw+K/zzHw0wP8ejSK/GJ9Yqlxn94YuLmR9usvdxdFJVeS2moQBUE78L8YxwvD2rN2YU/aWNVcGViSJBZ2WMgn/T8hKDWImX4zicpq2DoKtUWNXEmSJFkB6wBH4CowVQiRcVubTsAKwAzQAu8JIdaVnvsNGABcD8ydI4Q4X9V1m5IrqTpcTc3DLygBv8AEguL17qYOduaM8GxBX2cbvFqbI5dJ6HSC+MwCYtPzySwoIaughOyCEoo0N325RRotRSU6ijQ6Cku0N37PyPyBjmkH2RjtRYBlD9Kd+zGvrxMeBacwePI1zg9qzfTvdiGXyTn0dxiXDl9jxls9MLe9f6S1rzNtxzSKtcVsGrfphttCCMGV4SM46zSL4ubOzHy7ZxmXRkp0FH8sXUyvEYPpHfMWjPsWfGY11FuoVTITE/h5yQL6Tn+MHg9PrbpDHaLVCXYHJ7HqcCRnojMwN1Qyq6cDs3s7ojrwL9deWordiu8xHTSoyrFKtDq+3HOZyEN/s0L5JeGj/sKlx5gq+90L55PPs2T/EjQ6DV8O+pJuLbrVyXVqSn1tPr8M7BVCfChJ0sul/152W5t84DEhRLgkSa2AM5Ik7RJCXK96/ZIQYkMN59GkcbQx5v8GOvN/A52JScu/YSQ+2RXGJ7vCMFMraGVhyNW0PAorKa6jkEmoFDLUSjkGChkGZV7LuWjaj/EFW1BaW9M3L5CFC5/F2MwUcODoyG103HWKVVteZdHED+gy2pGQgARO7ohi2OOV1DhuogSlBhGcFszyHsvL3PgLzp+n4Foyac6t8O5gU87PfXzzP6gMDfFpmQOxMnAdVd9TrzMsWrTEzt2LSwf30H3ClAbdSJXLJEZ6tWCkVwvORKfz48FIvjsQwcrDkUzu2JpZzVuQ9vPPVRqGiORcnv/nPBfjspjpMwpx+Qdc0g8BdWMYOjXrxJrRa3hq71Ms3L2QN3q9wQTnprv/VFPDMB4YWPr6d+AAtxkGIcTlW15fkyQpGbAFMnlAOeytjVg0oB2LBrQjJaeIgCupHAlPJSW3iD7ONjg3M8HByghLYxXmhkpM1QoMlXIUVdVc1nWHTz9guo+OP3fnc9Z3E/0e0Zfh7vHGVwQdGYT1j1tZ79mZKa5T8R5kx7ndMfgMd8C6ddN3mVxnXdg6DBWG5eouZG/fTkYzb3RCKre/kBYXw+XjR+gxYQrqqB/AvpdeqO0+wnPgUHat+JJrYSG0dvNo6OkA0M
XBipWPWRGZkstPR6LYcCaOvGY9WHR6K2f8DuEzql85I6bTCX4LuMpHO0MxUslZMbNUzuXvwfos6JEf1llCop2pHX+O/pMXDrzAa0dfIyY7hqc7P10m6q2pUNMZNxdCXE9JTASa36mxJEndARVwq8D4e5IkXZQk6QtJkhpfPcQGxNbUgPGdWvPJlI789nh3XnvIg0e629Pb2Qb3lma0sjDEVK2s2iiAvh6u6yiaJe/HrVdfzvptIzdDnxGqsLTE7vmX8IoW/Pv7Oxy7dgyf4Q6oDOSc2Fa1NHNTIasoC/8ofx5q+xCmqpvhkaKkhGw/f7I8hqAyVNDSuWxm6fFN61CqDPDp3QmSL+lDIO8z2vfsg9JAzcW991gDoQ5pa2vC+w97E/DyYNrOfoRclRHnPvmGUV8d5vsDEcSm5xOdlsd3+yMY+dUh3t6hL6Sz67n+NzW+3MdCVixcq9vENDOVGd8P/Z5JLpNYFbiKlw6+RKGm6QVyVHlHkSRpjyRJQRX8jL+1ndBvVlS6YSFJUkvgT+BxIcR1f8j/ADegG2BFeTfUrf0XSpJ0WpKk0ykpta/L8p/AfSwU59C7Zzt0Wg3HN627ccp66jSUzu2Yc0Bi2Z7nidfE0GmYPVEXUkm6mt2Ak649tkRsoUhbxDTXaWWO5x45giYzi2SFPQ6eVmVE8zIS4gkLOEynEWMwit2vP+j2UH1Ou15QqQ3xGjSM0KMHyU5tnN8vGxMDljzUkdZzZtE78RL2eSl8vDOMfh/vZ8AnB/hkVximaiWfT+3IT7O70sxUfbNz+5EgySFke53PUylT8kavN3ix64vsjt7NvF3zSC2o/439mlClYRBCDBVCeFXwsxVIKr3hX7/xV1hNQpIkM8AXWC6EOH7L2AlCTxHwK9D9DvNYKYToKoToamtre3fv8gF6nAaAyhTL5EN4DRpG4N6dZCYlAiApFLRcvhyr9BJGndCw4N8FWHYTqE2UHN9S/xWkahud0LEubB0+zXxwtXItcy57+3ZyW3lRWEQ5N9IZv23I5DK6jJkAwdugZSewdKjPqdcbXR96GIAzvlsaeCZ3ptnsWciUSt4sOMfhpYN4ZbQbr4x248iyQWx8sjcTfezK75MYWYFTPwjZdlclP+8VSZKY7TmbLwZ9QXhmONN3TCcoNajOr1tb1NSVtA2YXfp6NrD19gaSJKmAzcAft28y32JUJGAC0HT+ck0RpRraD4dQX3o9PBWZXMHRdX/eOG3cqxcmQ4cwIUCDcWYRTxxYQNuB5sSFZhAbcmchssZOwLUAYnNiy60WtLm55OzdR7bPGCSZhP0tWjSFublcOrgHt74DMZbyIf60ftV1n2Jm2wy3PgO4uHcnBTmNd5WosLbGYvJksrZspUVBBgv7t2Nh/3bYWVYRQec+DtIiICW0fiYKDLEfwu8jf0cuyZntP5vN4Zvr7do1oaaG4UNgmCRJ4cDQ0n8jSVJXSZJ+Km0zFegPzJEk6XzpT6fSc2skSQoEAgEb4N0azucBVeE+DvLTMMkJpetDEwg9epCEiLAbp5svXYqk0fLJRQ+KdcV8kPsShhYKjm+5UuclHeuStaFrsVJbMcxhWJnjOf/uRhQVkaxypGU7c9TGN6URAvftQlNUhM+ocRBamljlUcaDet/RbdwkNEVFnNtZ9y6XmmC9YD5IEqkrV1W/k9tDgKRf+dUj7tburH1oLZ2bd+b1gNd59cirjd61VCPDIIRIE0IMEUK4lLqc0kuPnxZCzC99vVoIoRRCdLrl53zpucFCCO9S19SjQojcmr+lB9wRl2GgMITgbXQbNwkjcwsO/nmztq7K3h6b/3sSsecIq1TzyCePU3b+JEfncOVs4/Q9V0V8bjyH4g4xyWUSSnlZTZzsHdvROHmSka7F0fumG0mn1XJu5w7aeHagmWNb/c3E1h1sXOp7+vWKTRsH2nXtwbmdOygurP0CPLWFsmVLLCZNJHPTpupLcps2B/ue9bLPcDuWakt+GPoD873n4x
vpy+hNo/n+/PfklzTOOvNNL47qATVDZQwuQyFkOyoDNb2nzCQ+9BIRp29s/WA9fz5qDw9kn/3EJx1e46iJLxrzPE5si0SnvXvt/obmn7B/kCSpXKZzSVIyeceOk9tDH29+q2he+MkActJS6DJmPOSmQEwAeIyr13k3FN3HT6EwN4fAvZXLVTcGrgvqpa26i1WD+1h9yc/0+o+2U8gULPFZwtYJW+nbui8rLqxg9KbRrAlZQ7G2uN7ncyceGIb/Iu7jITcR4k7hPXg4Vq3bcHjNr2g1evkBSamk5QcfoM3Oxv6nXczxnsOeFn+RmZRPSMA9FExpQAo1hWwM38igNoNoYVy2Glu2ry8IQbJRO8xtDbFoftNHfcZvKxbNW9K2czcI8wWhu6/3F26lVXs37L06cGLzPxTmNt5FvLJ1aywmTCBz/QZKkpKq1+n6/2EDrBquY29mz+cDP2f16NU4mTvx4ckPGbt5LFsitqDRNY7a8g8Mw3+R9iNAroKQbcjkcgY8OpeMhGtc2O13o4natT22//ck2X7+PJ7shokzpJjFcHz7FUrqqSB5beAf5U9WURYz3GaUOS6EIHPTRhQdfUiIKcLR+2a2c0JEGAmXQ+k8ahySTKZ3I1k6QXOvhngLDcKAWfMpzM0lYMOaqhs3INZPLEQIQeoPP1Svg4W9PrKsnvcZKqKjbUd+GfELPw79EQu1Ba8dfY2Htz6Mb6QvWl3DfsceGIb/ImozaDtI/+UQAqfOXXHo0JmAf9aQl3lT6sp6/nzUnp6kvPMeH3b4H2ec/CnM1nB+T3QDTr76CCH4K/QvnC2cy2nXFAYGUhxxhcL+U9BqdGXcSOd27kBlaIjXwCFQkAFRB/VupPtMc/9ONHNsS4ehIzi/y5fU2Mb7/62ys8Ny+nQy/15L7uHD1evkPlYfYZYVV7eTqwaSJNG7dW/WjlnLlwO/RClX8vLhl5m0bRL/Xv0XnWgY1+0Dw/BfxWMcZMVAwnkkSWLw44vQFBdxaPUvN5pISiWtPvwAXW4usk9WsnDYTKIsL3LSP5KCnMblE62Ic8nnCE0PZYb7jHJx7ZkbNyGp1SQbtUOlltPS2QKA/KxMLh87jEf/IagMjSBsJ+g0evfbf4zeUx/FwNCI/b+tbNQRac1efAEDFxeuLXuZkuQKU6nK4qnP1yC4XHR9gyFJEkMchrBh7AY+GfAJOnS8cPAFpm6fyv6Y/fX+939gGP6ruI7WZ4KWfjmsWrWm69hJBB/eT1zwzXQSAxcXbJcsIXfPXvoFCVS9MtGVwK5NZxpq5tXmr9C/MFWZMsaprISFrqCAbF9fTIePICYkC3tPa+QK/VchcN+/aDUaOg0v7XNpM5jb68tD/scwMjOn97RHiQm6QGjAoYaeTqXI1Gpaf/E5uvx8ri1dhtBW4YaxbgctvOFS40vkk0kyRjqOZPO4zbzf933yNfk8s/8ZZvjO4Ej8kXozEA8Mw38VIytoO1B/4yv9sPV4eApmts3Y+8uKGxvRAFZzZmPo40PSe+/xYscZRLe6QOzxXJISGm9hksS8RPZE72Gi80SMlGUTn3J270aXm4tm4Hjys4tx9Na7kXQ6LRf2+GPv1QFruzZ6N9KVffoqbf8hN9KtdBw6ClsHJ/y+/oRfn3+SI2v/JCEiDF0D+8Bvx8DZmRavvUr+8eMkvPY6uYcOobmTdI7HBIg72SjcSRUhl8kZ224sWyds5a3eb5FemM6Te57kMf/HuJp1tc6v/8Aw/JfxmggZVyFBXwJDaaBm0OyFpMZGl0lwkuRyWn3wPkKjIfetjxg/rRdaScOa3/5ttC6Gf8L+QSd0THebXu5c5sZNKNu0IVHTHEkCey+9YYg8c4qc1BQ6DS/VQgr1A11Jky/fWRNkcjlTXn+fwXMXYWJpyckt6/lr+Qt8P38G2z57n/P/+pGREN8oPgfmEydiMW0aWZs2EbvwCcL79Se8X39inniC5K++IvvffymOK51rI3QnVYRSpmSiy0R2PL
yDV3u8Sk5xTpmqg3XFg5rP/2UKMuATF+j1fzDsbUC/Ybvl47eJDQ7i8S9WYGp1M+kr4++/SXzrbVq88Tp/5OpQnGmJ4dhU5o5p2OIutxOREcG0HdMY2GYgnw38rMy54thYrgwbju2SZ9iT1gWVWn6jtvOG914jLT6WBd/8jEwuh9WTITUMllz8z64Ybic/O4vowPPEBJ4n+uJ5ctL0T+WmNrbYe3XEwbsT9l4dMbao+5tXZWhzcigMCaEoJITC4GAKg0MoioyEUheTzNwctYc76uLzqFsYon5qNSoHByS5vMHmXF2EEDWql1HdQj0PDMN/nTVT9Noxt9z8MpMS+f2F/6Nt1x6Mffam4K0Qgth588k/d442Gzay8seLFBWW0OlpM0a2H95Q76AMxdpiHvF9hNSCVDaO24iNYVlRvJSvvyZ1xQ802+jP2q+v0GeyM52G2pN+LY5fn1s6aAmQAAAgAElEQVREn2mz6DlxGuSnw6cu0OupG0bzAWURQpCZeI2YoAtEB54nNugihXn6vAcbe0ccvDti790JO3cvVGrDBp2rrrCQorAwCkNCKAwOofDSJYrCQhAafdSPZGSE2tUVtbs7ak8P1B4eGLRrh6RSNei8a5v6quD2gKaO58Ow5f/bu++4qqv/geOvc7nAZSp7I6DiwhW4Z2ppZtpQM9t7Wja/WZnVr2FTrSyz0qwsK0daaTlSy60oCIKT5WY4QAVknN8f54oQIChcruB5Ph48gHvPvZ/3Lbnve9b7PAoHt0Kg+uTc2MeXzjeOYN0vs0npdy0h7ToCauWE39tvkXTDUDLGv8JNr37I75N3sOjHNXg/4slVPtafoP1468fsPr6bT/t9Wi4pyMJCTsydh1PPnhw4rEZRQ9urNrFLF2OwMdK2nznB7fxDrUY6N+SglSOEwM0vADe/ANpfM5ji4iLSk5NKehQxSxcT/cdCDDY2+DVvQXBEB5q07YBvs3BsjHX71mMwmXBo3x6H9u1LbpNHd5H/TnfyAm4j76w/eYkJnPz1V47/8IN6fba22Ddvjn3rVphat1ZJo0ULDI4N77jb/9KJ4UrXYrDa7LZjfkliAFVMLeHfv/l7xufc9f5UjLaqxpCtry++r7zMof+9iPfaPwjv2QnW9OTVBW/z/PVj6BPUx1qvhA2HNzArYRYjw0dWGMep1aspTE/H99XxRMdk4BHgRCMvRwry8tixegXhXXucHwLZsQDcQtRmKK1aDAYbfJs2x7dpc7rcOIKCs/kc2pWohp3iYlk/70fWz/0BW5MDQa0jzImiPR5BTaxynKjwaYGpZWtMxp3wwKcAyOJizqamkp+YSO6OHeQnJnJq+QpOzp137kViFxqqEsW5r1YtsXF1rfP4LUknhiudQ2No2l8t3bv2zZLhJKOdHf3vfYR570xgy6J5dL3l/CSu69Ch5CxfTsbkKXT94WcOxNnSe9+tPLV8LGOinuC+iPvq/A/96OmjjPt3HCGuITzX6bkK2xz/6SeM3t7YRPXg8Lz1RF4XAkDi2lXknzl9fonqmWOQtAp6PKnnFmrA1s6eJm1VL6EXkHsqh/07tpMWF0tafAxJWzcD4NTYjeAINewUHNEeV886PG+lzU2w4nU4kQaNgxEGA/ahodiHhuI6eDCghswKDx8+PwyVkMCZzZvJ/u38Ag3bwEBzojjfuzDW43NjdGLQ1B/H7iVwYDMEnT8rKaRDJOFde7Jh/hyad+2BR0AQoIYQfF9/naQhN5A54SX6vjyVJV8WMjz3YSZvnczu47uZ0G1CuWWilpJXmMdTK5/iTMEZpl8zHQdj+fHsgoMHOf3vGjwffYTUhBNIqYaRpJTELF2MV3AI/i1aqcaJi0AWqSWNWq1xcHYhvEsPwrv0ACA7M520ODU/kRoXQ+KaVQC4+QUQ3Fb1JoLatMPkZMEzxyNuVokhfh70fLrCJkIIbP39sfX3x6V//5LbC7OySia388wT3TlLl5bcb/TyUnMV5mTh0Lo1Rn9/q/
SOLpZODBq0uA5s7NUfR1DZQ/T63fswaXExLJ32Mbe+PhGDQa3cMLq74/fG6xx4Ygyea3+h6VX9SNluYMyIZ/k0+SN2HdvFR30/IqxxmEVDl1IyYd0EErISmHL1FJq7VVwW+/hcdUZU4+HD2bwwA2c3e7yCXTi0eycZKUkMeODx83+wcXPBozn4ta/wubTa4erpTcTV1xBx9TWq3tH+VPOwUwwJq1cQu/QPhDDg07RZyYon//BWGGtzQtgtBAI7QVzliaEyRg8PnHv1wrlXr5LbinJyyN+583zCSEhQpTqK1SS3oVEjNVdxbs6iTevLckWUTgyaqp0UPhDi58O1b4HN+X8WTo3d6Hv3g/z52SRi/lrMVdedrzDqMmAAjYYNI/OL6XT6ug8Hdhrw2NKOaaOmMW7NOEb9MYpXu73KkDDLnJEspWT69uksTl7Mkx2f5OrgqytuV1DAybnzcOrdCzx92J+wm1Y91Ce32KV/YOfgSKtefVXjkwchZQ30HaeHkeqQEAKv4BC8gkOIvP5GigoLOLx3N6nbY0iLj2Xzonls+vUXjLZ2BLRqU5IovEPCVKHDmmg7Apa8AOmJ4N2qRk9l4+KCY6dOOHY6X5urODeX/N27Va9iRwJ5iYkc//575FlVVqZkRVSpoShrr4iq0XJVIYQ78BMQAqQAI6WU5bbDCiGKUKe0AaRJKYeabw8F5gAeQDRwp5SyyiI8ermqBST+Bj/dAXfMh2b9y9wlpWT+xNc4mLiDuz+YSiNvn5L7irKzSRo6DIPJRO7zn7H6l2T63t4Cz6uMvPDPC2xN38qA4AG81OUlvBxrb8z1dMFp3lj/BouTFzM4dDATe02stIues3w5B54YQ+BnU8lwb8uSaXEMHdsBDz8D0x+7h3YDrqPfvQ+rxms/hmXjYcxWVTpBuyyczT3D/oR40uJjSYuLKSnsZ3J2IbhNO/PQUwca+fhe/FDNqXT4sAX0fAb6j7dA9OXJggLyk5JKehV5CQnkJyZSfEYd3GOpFVF1so9BCPEecExKOVEI8SLgJqX8XwXtTkkpyw0UCiF+BuZLKecIIaYBsVLKz6u6rk4MFlCYrza7tRwMN5UvYZydmc43zz6Of3hLbnnpjTJ/fKc3bSLt3vtwvrof0WH3kJGWw8iXO+HkYce3Cd8yddtU7I32jL1qLH0C++Dj5FPu+asVYnEhJ/JPkHwymdfWvcaBUwd4vMPj3B9xPzaGirvisrCQlFtHUZiVRbPly1g2axdpO7K49/2ebFk0lzVzvuWeDz9XJTAApvUCgxEeWnlJMWp149TxY+yPjyU1LpbU+BhOZamjMl29fMy9ifYER7THsVHj6j3hdzepw3uejLFaT7H0iqjSQ1FFJ04ggRPODmQF+dP/3Y9wDA+/pGvUVWLYBfSVUh4WQvgBq6SULSpoVy4xCPXOkgH4SikLhRDdgNeklAOruq5ODBay8Am1TPO5PWBX/pNJzF9/sGLG5wx8dCwRfQeUuS9r5jekv/sujk++yLK9obh6OnDL85HY2BpIOZnChHUT2Jq+FQAvBy/C3cNxtXXFZDRhMppwNDriaOuIvY09OWdzOJ53nGN5xziWd4zj+ern7PxsJOrfq7ejN+/2epco3wv/G8+c/iUZH31EwORJGLv3Y9a4tUT0CaDH8KZ8/eSDNPL2ZeSrb6vGGbtgamcYNBG6PloL/0G1uiCl5PjhQ+b5iW3s3xFH/pnTAHg1CS3pTQS2bIOtyVTxk2ybDQsfgwdWQGCV75t14mzuGZJjt7Jv7T8kb99GXl4uArht/Nv4RbS7pOesqw1uPlLKc0d6HQEq+yhoEkJsAQqBiVLKX1HDRyeklOeqtR0AAmoYj1YT7UbCtu9g959qtcZ/tL/mOnau+4dV335JSPurcHZzL7nP/Z67yd0eS86n79F9whesXJnDuvl76XVrOCGNQpg5aCbxmfHEZcYRnxnPvhP7OJhzkNzCXHILczlTeKbk9CqBwM3khpu9G24mN5
o3bo6byQ13k3vJ965+XWlk3+iCLyd/714yP/kEl2uvxXXQIKL/TKG4SBLRO4DkbdFkZ6TT5477zj8gbi4Ig97UVs8IIXD3D8DdP4AOA6+nuKiIo8l7S5bFxvz5G9G/L8BgY1Sn07VV8xO+TcNV6ROAVkPg96ch7herJobsjHT2RW9kX/QmDiTEUVRYiMnJmZDIzjSN7ExIh0jLrtIyq7LHIIRYDvhWcNfLwCwpZeNSbY9LKcsVSRFCBEgpDwohwoC/gf7ASWCDlLKZuU0QsERKWeExWUKIh4CHAIKDgyNTUy/fw0PqreIimBQB/h3gth8rbHLs0EG+feEJwjp2YuizL5V9+OnTJI+8lcLMTA7d+SEJ8flc93BbwjpWb26hoKiA3KJcnIxOlQ4NVZcsKiJl9GgKUtMI+/03DO4efPfKOhp7OzJsbEfmT3yN9JQkHvx0htqFKyV83EGtUrnr8i6spl2cgvw8Du5KJHX7NtLiY0lPSQIpsXNwIKhNu5KJbPc1LyH2b4RnEssswLAkWVzM4b27Sdq6iX3Rm8hMSwHUkt0wczIIaNH6fAKroVrrMUgpB1R2nxDiqBDCr9RQUoWnZEgpD5q/JwkhVgEdgXlAYyGE0dxrCAQOXiCO6cB0UENJVcWtXQKDjeopbPxCbfJydC/XxN0/gO4jbuffH75h94Y1hHftef7hTk4EfTGN/Q88iM+XYzg6+H1WzErA1SsSz8CqP+XY2thia2NbKy/l2DezyIvdjv8HH2D09CR5eyanjuXTc0RzTqYfITkmmq43jzpfmuFgtKo02/v5Wrm+dvmwtTcR0q5jSWmX3Jxs0uK3kxYfQ1pcLPu2bATAydmRYENjmsz7jOABo8sUkKxNBXl5pMRtY9+WjSRv28KZkycQBgMBLVvT5477CIvsgru/dQdPapoWFwF3AxPN38t91BJCuAFnpJT5QghPoAfwnpRSCiFWAsNRK5MqfLxWx9qNhPWfQsKvEHVfhU2ihtzE7g1rWDFjGoGt2+Loen5Ixy4wkCY//sCBxx6nxZJX2dZ7Ar99EsPNz0XSyKtuCqnl79tHxpQpOPfvj+v1avdq/OoDODWyI7SdJ2t++hYhBO36l5rOip2j9nK0tMzSWu3y4eDiSotuPWnRTX2oOZl+hNS4WNLitpKyOZvEeUth3lLc/QNL5ieC2rTF3tHpkq+ZnZlB0tbNJEVvJG3HdooKCrB3dCKkQ2TJEJGDs0ttvcQaq+nkswfwMxAMpKKWqx4TQkQBj0gpHxBCdAe+AIpR5z9MllJ+bX58GCopuAPbgDuklPlVXVdPPluQlPBZVzA1gvuXVtosIzWZ78c9TXjXHlz/ZPlP2cV5eRx64X8cWbOdrV1exMHDhZufj8Spkb0lo1erkG4bTcH+/YT9/htGT09OZpzh+/Eb6HxDKB2vDWT6o3cT0LINw557WT2oMB8+CFfLdIfPuPAFtAZN/v4MGevnkxb1Nqk7d3EgMZ7C/HyEMODbtHlJ2Q7/8JYX3Ggni4s5mrSXfeYhooyUJAAa+/iZh4i6ENCydZ0XE6yTyWcpZRZqvuC/t28BHjD/vA5oW8njk4DOFd2nWYkQ0GE0LHsVMveAZ8U7ib2ahNL1lltZ9/Nswrv2oHnn7mXuN5hMBEyehN0nn1A4ezIxHcfy26Rohj4ThaOr5TbuZH31NXlxcQRM+gijpxoKiFt1EINB0LqHP3s2rCE3J5v21w4+/6BdSyDvhHrd2hVNdLwd7y1f4x18lqibXlcb7XbvItU87LRp4S9sXPATRls7/MJbEtS6LU5u7tgYjQiDgeOHD3Fk326O7N1N3qkchDDg36IlvUbfQ9PILrgHBNaLkhj6PAatvJyj8FErVURuwGuVNisqLGT2y89w+vgx7vnwMxxcKq4wmf3nX8RNnMn28HtxbGTP0Gc74eZ76d3yyuTt2kXy8BG4DOhP4KRJ6tpZufwwYSPNorzpe3tzZj33GAYbI/d8MPX8jtnZI+HIdn
h6h5pn0a5cUsLULqq4ZAU95vwzZziQGKeKAe6IIyM1ueRoXAAhDHgEBePbNJyg1hGEdIgsM9Rqbfo8Bu3SufhA82vUuHu/8ZW+WdoYjQx6dCyzX3qav2d+UeGQEoDroIFEhjTB/rl32Fp4M3PfXM/gJyMJCK+9U77yEhI48NRYbFxd8X311ZLbNy5MAgFdhoYR/fsCThw5rDbonUsKOUdg73KVBHVS0M71mJdPgKx95Xa/2zs60jSyC00juwAqUeSfOU1xURHFRUW4uHtUvleiHtFnPmsV6zAacg7Dvr8v2Mw7JIyuN49i59rV7Nm4rtJ2ppYt6Tj7Y3rb/YsxO51FH0Wz69+UGocppeTY7Nmk3DoKefYsQVM/xeimEk56aja7Nx2lQ/8gZHE2G+b/RPPO3QlpX+pAoe0/q0qq7fUwkmbW7la1nyW24iXbpdk7OuLq6UVjH1/c/QMaRFIAnRi0yoRfBw7uEDO7yqadbxyBd2hTln01lTPZJyttZ9OoES0+e49rO+XgcjKJ5bOTWP/5yos6SF4WFHA2NZXspUtJnzKF1Dvv5Oj/vYlT9+6E/roAhw7qYB0pJWvn7sXBxZarBjZh9XczQEr63vVAqSeTEPODqq7pdWklBrQGyNUPmvZTPWZzVdQrjR5K0ipmtFNLV7fMqHRPwzk2RiPXPfY0348by4qvPmPI0y9WOsEmDAYCxjzMsB6xLJ28jq2xLTj+xAw6dzJidHXF4OyEzM+n+NQpinJOUZSVRWFWFoVZmRQcOEjBoUMlh7pjY4N9WBg+417E7c47y1TZTI7N5NCeE/S5LZzDe+PZvWEN3UfejquX9/lgDm2DjEQYMqlW/pNpDUj722De/ZC8GppWXLW3IdOJQatch9GwcZo6p6Hzgxds6hkcQrcRt7Pmx1nsWv8vLbv3vmB7l6vac9P0lqz4v0Xszgrl5B87aJM4BdvC3DLthIMDRg8PjB4eOLSNwHXwYOyCg1TlyfBwDBV03Q/tPcHf3ybi5utIk7aO/PjKFBr7+NHphlvKNtz2HRhN0KZ8+Q/tCtdyCDi4QfQ3OjFoWhl+7cG3LUTPgk4PVFl1stMNN7Nv8wZWfPUZvmHNaezrd8H2BpM917w1Ar/V+/n3J9jefAoDhjTGzc8ZG2dnDM7OGBwublPcvq3pLJuRgIuHicGPRrDkk7fIzT7JqDfeK7vuPD9HzS+0uVmtQNG00mxNat5p0xeqLLezd9WPaUD0HIN2YVH3wdE4dexnFQw2Ngx+8nkQgl/f/z/yzbXlqxLRJ4hhT3ck/yz89stxdqXYIhp7VDspnD6ZT8r2TNbM3cOfX8bjFezMLc9HEvPXHPYnxDHgwcfxCWtW9kFxc+HsKYi6t1rX0K5AkfdAcSFs+97akdQ5vY9Bu7D8HPiwpepa3/xFtR6SFh/L3LfGE9oxihufe6XaJ2xlZ+ayYlYih/acwNnNnqjBIXgEnK+xVFhQTEF+EQV5hRw/coaM/TlkpOVw5qT5bCcBzSK96X9XK/ZsXsPij9+nw8Ah9L/vkbIXkhK+6A2yGB5Zo09q0yo383rIPgBjtkFNT4q7DOh9DFrtsHdRy/e2fQ+D3rngJPQ5wRHt6XvXg6z85gvW/vw9PUfdVa1LuXo6cOMzHTmQeJwNi5JYNXtXpW2FQeDm60hQK3e8glzwCnbBM8gZO5OR9JQklk77mICWrcuuQjrn4Fa1oe36D3VS0C4s6l41CZ20stzJhg2ZTgxa1TrdD1u+VktXu4+p1kM6DhpCZloyGxf8jLObBx0GXl+txwkhCGrtTmArN44mZ5Ofaz6uQ4LR1oCtyQajnQ2uHiaMduU3pJ3JPsnCD97E5OzMDU+Pq7gWzZYZYOsEbUdWKybtCtbqBnD0gOiZOjFoWhk+bSCoq3pD7fp4tbrUQgj63/8Yp0+eYMWMz7E1mWjTp/p/WEIIfMMurpRAcVERf0x5l9MnjjPqtXdxalzBzurcE2qVVftbwVRxCQ9NK2G0V6
vz1n+mdsm7VHQ0TcNT/wfNtLrR6X51Jm7yqmo/xMZo5IaxLxIc0Y6/Pp9ywZ3RNVVcXMTfM6eRFr+dAQ88jm+zSjasxc6BwtxKS4prWjmR96rd8VtmWjuSOqMTg1Y9rYepLvWmry7qYUY7O4Y9Px7f5uH8PuVd1sz5jsKzZ2s1tDPZJ5n/zmvELltC1A03lzuPukRxkVp+GNhJLcXVtOrwaArNB8Lmr6Agz9rR1AmdGLTqMdqr5Xu7FqviYhfBzuTALeNep2WPPmxc8BPf/e9JDuzcUeOQigoLSIvfzvfjxnIgMZ5rH36y7BnO/7Vrier1dHu8xtfWrjDdHoMzmepM6CuAXq6qVV/OEZjcFq66S63ouQQpMdEs+2oq2RnpRFx9Db1G31OtssSFZ8+SmZbC0eS9HE3ex9GkvWSmpVJcVIiLpxdDn3kJ36YVnx1RYsZ155ce1tGZvloDISVM66m+P7q23q5mq5PlqkIId+AnIARIQZ3gdvw/ba4GShejaQmMklL+KoT4BugDnKu8do+UMqYmMWkW5OKrVvJsmw1Xv1ytpav/FdIhkrs/mMr6uT+ydfFC9m5aT7fht+HmH6gaSEnB2XzO5uaSf/o0mftTOJq0l6wDaRSbaySZnJzxDm1K5PXD8AlrRpN2HTE5VXGm9MFoSFsHA9/RSUG7eEJA10dh4eOQtKrBl8mo6dGe7wHHpJQThRAvAm5Syv9doL07sBcIlFKeMSeG36WUcy/murrHYEXpieroz6tfgT4Vn79QXVkH0lgxYxr7d2yvtI3JxRWf0KbqK6wZPmHNcPXyufhTsObeD3uWqsN49Gok7VIU5sOkCPDvALfXzyGlutrgNgzoa/55FrAKqDQxAMOBJVLK6tVK0C4/3q2g2QDYNF3tabC99PrzHoHBjBj/FhmpyWUmpG1NJuxMJmxNDji4uNb8KMQT+2HHAvWJTycF7VIZ7VXNsFVvQ8buBl2qvaaTzz5SysPmn48APlW0HwX89/SLt4QQ24UQk4QQlZ4UL4R4SAixRQixJSMjowYhazXWfQycTq+ViTghBN4hYfiHtyz58goOoZG3L46ujWrnfNyN09T3Lo9cuJ2mVSXqPlWRd+1ka0diUVUmBiHEciFEfAVfw0q3k2pMqtJxKSGEH9AW+KvUzeNQcw6dAHcu0NuQUk6XUkZJKaO8vLyqCluzpNA+qurq2ilqCejl7FSG2pgXcQs0DrJ2NFp95+yl9jXEzlEr3BqoKhODlHKAlDKigq+FwFHzG/65N/70CzzVSGCBlLKg1HMflko+MBPoXLOXo9UJIaD385C1R1UpvZytmwKFedDnBWtHojUUPZ4CgxH+vbSVefVBTYeSFgF3m3++G1h4gba38Z9hpFJJRQA3AvE1jEerKy1vAJ+2sHoiFBVaO5qKncpQG/LajgDPKpayalp1ufqpPT2xc+B4irWjsYiaJoaJwDVCiD3AAPPvCCGihBAlW2SFECFAELD6P4+fLYSIA+IAT+DNGsaj1RWDAa4ep7rT2+dYO5qKrZ0MRfnQW/cWtFrWcywIA/z7kbUjsYgarUqSUmYB5SqjSSm3AA+U+j0FCKigXb+aXF+zshaDwa8DrH5Xlea2sbV2ROedSofNX6t9F57Nqm6vaRfD1R+uultVXe39HDQOtnZEtUqXxNAunRBqo9uJNFWS+3Ky5lxvoWZ7LTStUud6DSvfsXYktU4nBq1mml8DAVGw6l04e9ra0ShZ+9Q+i/ajdW9Bs5xGgWpvTOwPcGibtaOpVToxaDUjBFz7JuQcunzGW/96WW1G6j/e2pFoDV2v58DJC5a8qOooNRA6MWg116SbGstf97H113bvXQ67l6ghpCvkUBXNikyu0O8V2L9B7a5vIHRi0GrHNW+AjR38+ZL1YigqUNd3D1NdfE2rCx3vVEu3l02AglxrR1MrdGLQaoern/qUvnsJ7FlmnRg2fwWZu2Dg22ooSdPqgsEGBr0NJ9NUNYAGQCcGrfZ0fR
Tcm8KSF+p+IvpYMvz9JjTtD+GD6vbamhbaGyKGwz/vw+HKqwXXFzoxaLXHaA83TFZv0ktfqbvrFhfBgkfU0sEbptTbQ1S0em7w++r42wUPqxLd9ZhODFrtCu0N3Z9Qhet2Lq6ba66doib/Bn+gC+Vp1uPoDkM/gfQEWFW/9zboxKDVvn7jVfXVRU9AzlHLXutwLKx8G1rfCO1GWvZamlaV8IFqMnrtFEjbaO1oLplODFrtM9rDLV+reYZfH7Fcae7c4+pkNkcPGDJJDyFpl4eBb6vNb7/cDdmHrB3NJdGJQbMMrxYwaCLs+xsWP1/7m38K82HOHXAiFYZ/fUnnT2uaRZhcYdSPkJ8DP9wK+aesHdFF04lBs5yoe1Xt+i1fq9UataW4GH59FFLXwLDPIKRn7T23ptUG3wgYPhOOxsO8By7/A63+QycGzbIGvA7tb4OVb8GWmTV/Pilh2XiInwf9J0C7ETV/Tk2zhPBr4br31N6e35+unXNLzp6p+XNUg04MmmUJoVZqNLtG/XH888GlDysV5MH8h2D9p+pQ9p5P126smlbbOj8IvZ6FrbNgzuhLH1aSEqK/gSnt66TsTI0SgxBihBBihxCiWAgRdYF2g4QQu4QQe4UQL5a6PVQIsdF8+09CCLuaxKNdpmxsYeS36tzlv/8Pfr5Tjb9ejJwj8M1giPtZ1aYZ/IGebNbqh/6vwvUfqTpeMwdd/IT02dNqn85vT4FPG7B3tUycpdS0xxAP3Az8U1kDIYQNMBW4DmgN3CaEaG2++11gkpSyGXAcuL+G8WiXKztHuOUrtWJj52L4sh8k/q7mCy6kMB82fQnTekH6Trj1e1V6QycFrT7pdD+M/hmOpcDUrqoScVXDQlJC8r/wZX/Y/hP0fQnumAdOnhYPV8haWC0ihFgFPGc+ue2/93UDXpNSDjT/Ps5810QgA/CVUhb+t92FREVFyS1byl1Kqy+S/4FFY9R5uV4tofsYCOoKbk1U76IgDzJ2Qtp6WPcJZB+E4O5qZ6lvhLWj17RLl7kHlo5X8w4ufqqMTHA38G0HtiY1D5FzGFL+hQ2fwZE4cPKGm7+ApjU/8FIIES2lrHR055waHe1ZTQHA/lK/HwC6AB7ACSllYanbyx3/qTVAob3hiWhI+FV9clr4uLpd2KhS2TlHQJpXcQR1gWFTIayv7iVo9Z9ncxg9B1LXwfLXYNmr6naDreoJnDoK0tyL9mql5ufajgBbhzoNs8rEIIRYDlRU2P5lKeXC2g+p0jgeAh4CCA5uWOerXpFsjNB2uJp3OLhVVUXN2gcn96vzc33aqFLGHk11QtAanibd4f6lkH0YDkbDwS3qnHJXf3ANUD3p4K5W+7dfZWKQUg6o4TUOAqUL2ASab8sCGgshjOZew7nbKyz0EjUAAAOFSURBVItjOjAd1FBSDWPSLhdCQGCk+tK0K42rH7gOgVZDrB1JGXWxXHUz0Ny8AskOGAUskmpyYyUw3NzubqDOeiCapmlaxWq6XPUmIcQBoBvwhxDiL/Pt/kKIxQDm3sATwF9AIvCzlHKH+Sn+BzwjhNiLmnP4uibxaJqmaTVXK6uS6ppelaRpmnbxqrsqSe981jRN08rQiUHTNE0rQycGTdM0rQydGDRN07QydGLQNE3TyqiXq5KEEBlAqrXjuEieQKa1g6hj+jVfGfRrrj+aSCm9qmpULxNDfSSE2FKdZWINiX7NVwb9mhsePZSkaZqmlaETg6ZpmlaGTgx1Z7q1A7AC/ZqvDPo1NzB6jkHTNE0rQ/cYNE3TtDJ0YrACIcSzQggphLD84a1WJoR4XwixUwixXQixQAjR2NoxWYoQYpAQYpcQYq8Q4kVrx2NpQoggIcRKIUSCEGKHEOIpa8dUF4QQNkKIbUKI360di6XoxFDHhBBBwLVAmrVjqSPLgAgpZTtgNzCuivb1khDCBpgKXAe0Bm4TQrS2blQWVwg8K6VsDXQFHr8CXjPAU6gjBBosnRjq3iTgBeCKmN
yRUi4tda73BtRJfQ1RZ2CvlDJJSnkWmAMMs3JMFiWlPCyl3Gr+OQf1Ztmgz20XQgQC1wNfWTsWS9KJoQ4JIYYBB6WUsdaOxUruA5ZYOwgLCQD2l/r9AA38TbI0IUQI0BHYaN1ILG4y6oNdsbUDsaQqz3zWLo4QYjngW8FdLwMvoYaRGpQLvWYp5UJzm5dRQw+z6zI2zfKEEM7APGCslDLb2vFYihBiCJAupYwWQvS1djyWpBNDLZNSDqjodiFEWyAUiBVCgBpS2SqE6CylPFKHIda6yl7zOUKIe4AhQH/ZcNdHHwSCSv0eaL6tQRNC2KKSwmwp5Xxrx2NhPYChQojBgAlwFUJ8L6W8w8px1Tq9j8FKhBApQJSUsj4W4qo2IcQg4COgj5Qyw9rxWIoQwoiaXO+PSgibgdGlzjdvcIT6hDMLOCalHGvteOqSucfwnJRyiLVjsQQ9x6BZ2qeAC7BMCBEjhJhm7YAswTzB/gTwF2oS9ueGnBTMegB3Av3M/29jzJ+mtXpO9xg0TdO0MnSPQdM0TStDJwZN0zStDJ0YNE3TtDJ0YtA0TdPK0IlB0zRNK0MnBk3TNK0MnRg0TdO0MnRi0DRN08r4f53EOBxvUXmpAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# batch the inference across the 100 points of xrange_inputs\n", "targets = np.sin(xrange_inputs)\n", "predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n", "plt.plot(xrange_inputs, predictions, label='pre-update predictions')\n", "plt.plot(xrange_inputs, targets, label='target')\n", "\n", "# sample K support points from a single test task (amplitude 1, phase 0)\n", "x1 = onp.random.uniform(low=-5., high=5., size=(K,1))\n", "y1 = 1. * onp.sin(x1 + 0.)\n", "\n", "# apply inner_update 4 times, plotting the adapted predictions after each step\n", "for i in range(1,5):\n", " net_params = inner_update(net_params, x1, y1)\n", " predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n", " plt.plot(xrange_inputs, predictions, label='{}-shot predictions'.format(i))\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batching Meta-Gradient Across Tasks\n", "\n", "It kind of does the job, but not that well. Let's reduce the variance of the outer-loop gradients by averaging across a batch of tasks (not just one task at a time).\n", "\n", "vmap is awesome: it handles batching cleanly at two levels, inner-level \"intra-task\" batching and outer-level batching across tasks.\n", "\n", "From a software engineering perspective, this is nice because the \"task-batched\" MAML implementation simply re-uses the code from the non-task-batched MAML algorithm, without losing any vectorization benefits." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def sample_tasks(outer_batch_size, inner_batch_size):\n", " # Select amplitude and phase for the task\n", " As = []\n", " phases = []\n", " for _ in range(outer_batch_size): \n", " As.append(onp.random.uniform(low=0.1, high=.5))\n", " phases.append(onp.random.uniform(low=0., high=np.pi))\n", " def get_batch():\n", " xs, ys = [], []\n", " for A, phase in zip(As, phases):\n", " x = onp.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))\n", " y = A * onp.sin(x + phase)\n", " xs.append(x)\n", " ys.append(y)\n", " return np.stack(xs), np.stack(ys)\n", " x1, y1 = get_batch()\n", " x2, y2 = get_batch()\n", " return x1, y1, x2, y2" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xt8VNW98P/PmplALsAEL5FJggWUIgKBICqV2lqwoCdyeahcrK3VXnx8tE8j9SBYFAKFGg89YvxVjz+tx1pOW4gpRWNawIJtvRQtEgjekBZpIRkMKBklF5JJ1vPHZEJmZu/JZe4z3/frxYtkz87slct8Z+3vWuu7lNYaIYQQqcUS6wYIIYSIPgn+QgiRgiT4CyFECpLgL4QQKUiCvxBCpCAJ/kIIkYIk+AshRAqS4C+EEClIgr8QQqQgW6wbYOa8887TI0aMiHUzhBAiobz11lsntdbn93Re3Ab/ESNGsGfPnlg3QwghEopS6p+9OU/SPkIIkYIk+AshRAqS4C+EECkobnP+Rtra2jh27BgtLS2xbkpSS09PJz8/n7S0tFg3RQgRIQkV/I8dO8bgwYMZMWIESqlYNycpaa35+OOPOXbsGCNHjox1c4QQEZJQaZ+WlhbOPfdcCfwRpJTi3HPPlbsrIZJcQgV/QAJ/FMjPWIjkl3DBXwghROgk+PdRQ0MDjz/+eL++dsSIEZw8eTLg+LZt2xgzZgwXX3wxpaWlhl975MgRfv3rX/fruldddVW/vi4eVB2uYmbFTAqeLWBmxUyqDlfFuklCJIWEGvCNB97gf+edd4bl+drb27nrrrt46aWXyM/P5/LLL2fOnDlceumlPud5g//Xv/71gOdwu93YbOa/ytdffz0sbY2W+7ceoPzd50m74AWUtRlvFsrZ6KTk9RIAikYVxa6BQkTA1upa1m8/SF1DM7nZGSydNYZ5hXkRu15S9/y3VtcyrXQXI5dXMa10F1ura0N+zuXLl/OPf/yDSZMmsWTJEmbMmMHkyZOZMGECzz//PAC
NjY0UFRUxceJExo8fz+bNm32eo7m5meuvv56nnnqKN998k4svvphRo0YxYMAAFi9e3PU8/td95ZVXmDRpEhs2bOAXv/gFc+bMYfr06cyYMYPTp08btgVg0KBBAPzpT3/immuu4cYbb+SSSy7h5ptvRmsd8s8knO7feoBN7z7PAMcWLLazgd+rpb2Fsr1lsWmcEBGytbqW+7YcoLahGQ3UNjRz35YDYYlZZpK25+/9YTa3tQNnf5hASO+mpaWlvP322+zbtw+3201TUxNDhgzh5MmTTJ06lTlz5rBt2zZyc3OpqvKkKFwuV9fXnz59msWLF3PLLbdwyy23UFFRwfDhw7sez8/P54033jC87k9/+lNefPFFAH7xi1+wd+9eampqOOecc3C73fzud78LaIv/4G11dTXvvPMOubm5TJs2jddee40vfvGL/f55hMPW6lpKXniHhuY2ALIu2o6ytJmef7zxuM/nVYerKNtbxvHG4wzLGkbx5GK5MxBxr3tP36IU7X4dsea2dtZvPxix3n/S9vzXbz/YFfi9vD/McNFa86Mf/YiCggKuvfZaamtr+eijj5gwYQIvvfQSy5Yt45VXXsFut3d9zdy5c7ntttu45ZZbQr7+V7/6Vc4555ygbfF3xRVXkJ+fj8ViYdKkSRw5ciTkdoRia3UtS5/b3xX4AVRaQ9CvGZY1rOvjqsNVlLxegrPRiUZ3pYZkbEDEM/+evn/g96praI5YG8IS/JVS1ymlDiql/q6UWh7kvK8ppbRSako4rhuM2Q8tnD/MX/3qV5w4cYK33nqLffv2ccEFF9DS0sLnP/959u7dy4QJE7j//vtZs2ZN19dMmzaNbdu2daVb8vLyOHr0aNfjx44dIy8vjzfeeINJkyYxadIkXnjhBcPrZ2Vl9dgWfwMHDuz62Gq14na7Q/45hGL99oO0dfj+4eu2bNPz063pFE8u7vq8bG8ZLe2+36ekhkS8M+qcGsnNzohYG0IO/kopK/AYcD1wKXCTUupSg/MGA8VAYE4jAsx+aKH+MAcPHsxnn30GeNI5OTk5pKWl8fLLL/PPf3oqqdbV1ZGZmck3vvENli5dyt69e7u+fs2aNQwdOpS77roLgMsvv5xDhw7x4Ycf0trayqZNm5gzZw5XXnkl+/btY9++fcyZM8fnukbM2hLPtlbXUmvwZnzmxCx0R2BpieyB2ZRcVeKT0vFPAXnVnXaGbZxHiHDrTSc0I83K0lljItaGcPT8rwD+rrU+rLVuBTYBcw3O+zHwEBCVpaNLZ40hI83qcywcP8xzzz2XadOmMX78ePbt28eePXuYMGECv/zlL7nkkksAOHDgAFdccQWTJk1i9erV3H///T7PUVZWRnNzM/feey82m42f/exnzJo1i7Fjx7Jw4ULGjRsXcN2CggKsVisTJ05kw4YNAY/ffPPNhm2JV97bXiPuTwtpcc6nozUbrSFDnUfp1aW8sviVgFx+9xRQd7otOyqDZkL0h1kn1KoUCsjLzuDB+RMiOttHhTrbQyl1I3Cd1vq7nZ9/E7hSa/39budMBlZorb+mlPoT8O9a66A7tUyZMkX7b+by3nvvMXbs2F63LdpTp5JJX3/WfbF610bKDz+JsjWg27I5c2IW7k8Lfc5JsyrW3zixx9+XN+ffPfWjO9Jocc7ves687AxeWz49/N+IEL3kPylh2jnfZNPL5/ukfjLSrGEJ+Eqpt7TWPabWIz7bRyllAR4Gbu3FubcDtwNceOGFIV97XmGeBPs4s3rXRp775wYsaZ4BXjWggXTHFlqgK1gPzUxj1exxvfrdee8EyvaWUXfaafhmEslBMyF64t9BcTY6ebHlURZ/5QfseDMvZp3TcAT/WmB4t8/zO495DQbGA3/qnHY4DHhBKTXHv/evtX4SeBI8Pf8wtE3Emd9++BTK5juNU1naGHj+dtyfFvarl140qoiiUUVMK91lOIaggWmlu+TOT8RE6ZulhpMSXvtkI68t3xGjVoUn5/8
3YLRSaqRSagCwGOianqK1dmmtz9Naj9BajwB2AwGBX6SGDuspw+MqrSHkMRmjcR4vyf+LWKg6XEXDGeOpy2aTFaIl5OCvtXYD3we2A+8B5Vrrd5RSa5RSc0J9fpFcLO1DDY9rd3bI+c55hXk8OH8CeSaDaeFe5yFET4JNOTabrBAtYZnnr7X+vdb681rri7TW6zqPrdRaB0xQ11pfI73+1PW1kd8LmMapO9JYOOr2sKRk5hXm8dry6ZgVpZb8v4imYL377utVYiFpV/iK+LRq+jdZ8LklKPdQtAblHsqCzy1h1fRvhvU6kVrnIURfmPXu7QPsMS9BIsG/jyJR0vnb3/42OTk5jB8/PiLX/bd/+zcaGoKXTIimVdO/Sc13/sLbtx6g5jt/CXvgh8it8xCiL4onF5NuTfc5lm5N574r74tRi86S4N9HoQRhM7feeivbtm3r93V7KtHw+9//nuxs85IJyah7/j9ai2aE8Fc0qoiSq0pwZDlQKBxZjoBV6rGStFU9Aagph51rwHUM7PkwYyUULAzpKbuXdP7KV75CTU0Np06doq2tjbVr1zJ37lwaGxtZuHAhx44do729nQceeIBFixZ1PUdzczPz589n/vz5fO973+NLX/pSjwXWul/3q1/9KkVFRTzwwAMMHTqU999/nw8++IB58+Zx9OhRWlpaKC4u5vbbbwc8dxx79uzh9OnTXH/99Xzxi1/k9ddfJy8vj+eff56MjORMhcg6DxEPvFOR447WOi7/XXbZZdrfu+++G3DM1P7NWq+9QOtVQ87+W3uB53gIPvzwQz1u3DittdZtbW3a5XJprbU+ceKEvuiii3RHR4euqKjQ3/3ud7u+pqGhQWut9ec+9zn94Ycf6hkzZuhnn33W9Hl7uq7WWr/88ss6MzNTHz58uOvYxx9/rLXWuqmpSY8bN06fPHmy67onTpzQH374obZarbq6ulprrfWCBQv0xo0bDa/Xp5+1EEJrrXXJzl/qCT+/Wo97Zrye8POrdcnOX0a9DcAe3YsYm7xpn51roM1vZkdbs+d4mOgYl3S+4oorGDlyZNfnjz76KBMnTmTq1KkcPXqUQ4cOBXzNyJEjmTRpEgCXXXZZxEo6R2IjHSHimXf1uradQinQtlM8988NrN61MdZNM5S8wd91rG/H+yEcJZ3NHD16tKuk8xNPPGF4TveSzn/605/44x//yF//+lf2799PYWFhzEo6x2JXor5Yu3stE385kQnPTmDiLyeydvfaWDdJJIHffvhUwCZEytLGbz98KkYtCi55g789v2/HeyncJZ3NDB8+vKuk8x133NGrks5Dhw4lMzOT999/n927d4f0fYYiGhvp9Nfa3WvZfHAzHboDgA7dweaDm+UNQITMbPW62fFYS97gP2MlpPkNZKZleI6HINwlnQFuuukmvvCFL3Dw4EHy8/N5+umng1536dKlAY9fd911uN1uxo4dy/Lly5k6dWpI32coorGRTn8998FzfTouRG+ZrV43Ox5rIZd0jpRwlHSOxGyfVNHfks5bq2u5p3y/4bZ08VBaecKzE0wfO/At4/0FhOgNb86/e+pHd6RFZBFjMHFT0jmmChZKsI8ib67fKPDHywIri7J0pXx8aM+LN5ovUpFcVk3/Juzy5P47rKewtA/lxpHfi9u/qeQO/iKqzPYltSoVNwusFnx+AZsPbg58QMFzh59k4tDpcdFOkZhWTf8mq4jPYO8veXP+IurMcvodWsdNQL1/6v3mD9oa4mJQWohokJ6/CJvc7AzDzVSiVUzNVVlJ/YZHcDud2BwOcpbcjX32bN+Tasq5ocbN9a/AuZ/Cx0Pg19coXhtnRbdlx8WgtEgMib5NrAR/ERZVh6tQFz7MoGH1PlspRivX76qsxPnASnTn2gZ3XR3OBzwzu7reAGrKcW1YwjfeHIyl3VP0+fxP4X//XmPp0Ow4Z5ZU/RS9srW6lqUV+2lr94xv1TY0s7RiP0DCvAFI2keEzLtHqau
tHqXA0rkv7/nD3olarr9+wyNdgd9Lt7RQv+GRswd2rqF+f2ZX4PdKd8MdL7fxwZmf8pK60zNLTIggVle+0xX4vdraNasr34lRi/pOgn8fRaKk87Zt2xgzZgwXX3wxpaWloTYRgCNHjgQtER1OZXvLAvYoVZY2hub/MWq9ILfT2fNx1zHcTcbbPNoaFdcNdzD1Ahsz31xFVcVNkWimSBKnmtr6dDweSfDvo3CXdG5vb+euu+7iD3/4A++++y6/+c1vePfdd8P2/NFgtltRNPcotTkcPR+352PLDJyNBPDxEIUzzYZWnv9LPquh6k8PRKKpQsSFpA7+VYermFkxk4JnC5hZMZOqw1UhP2f30spLlixhxowZTJ48mQkTJvD8888D0NjYSFFRERMnTmT8+PFs3uw7tbC5uZnrr7+ep556ijfffJOLL76YUaNGMWDAABYvXtz1PN0tXryYqqqz7b/11lupqKjgyJEjXH311UyePJnJkyfz+uuvh/w99pXZbkXR3KM0Z8ndqHTfTTNUejo5S+4+e2DGSnImNqGsvvP8W23wq2t8U0EtFgtl/9gSsfaKxLW1utZ0m9DsjDSTR+JP0g74evPQ3nSEs9FJyeslACHV1i4tLeXtt99m3759uN1umpqaGDJkCCdPnmTq1KnMmTOHbdu2kZub2xWsXS5X19efPn2axYsXc8stt3DLLbdQUVHB8OHDux7Pz8/njTfeCLjuokWLKC8vp6ioiNbWVnbu3Ml//dd/obXmpZdeIj09nUOHDnHTTTfhvzI60oonF/v8rMGzW1E09yj1DuoGne1TsBD7EuBn91G/x4K7yYots51HZ6Tx2rjAdNBxq9lLXKQq70JGo7oIaRZFyZxxUW9TfyVt8DfKQ7e0t1C2tyxsGyt4Szr/5S9/wWKx+JR0vueee1i2bBk33HADV199ddfXzJ07l3vvvZebb765T9e6/vrrKS4u5syZM2zbto0vfelLZGRk4HK5+P73v8++ffuwWq188MEHYfne+sL78yzbW8bxxuMMyxpG8eTiqG9gYZ89O3Bqp7+ChdifXEhXke2acv7x5irDU4e526EkW0qDiC7BFjKuXzAxYWb6QBKnfaKRhw5HSee8vDyOHj3a9fixY8fIy8vjjTfe6Crp/MILL5Cens4111zD9u3b2bx5c9fOYBs2bOCCCy5g//797Nmzh9bW1rB9f31RNKqIHTfuoOZbNey4cUd87lxkpGAhxZ82k97hmwpK7+ig+FQDoMF1FLbcDi/+MDZtFHEjERYy9lbSBv9I5aHDXdL58ssv59ChQ3z44Ye0trayadMm5syZw5VXXtlV0nnOnDmAJ/XzzDPP8Morr3Ddddd1tcHhcGCxWNi4cSPt7cYDmuEWifGUADXlsGG8p/e9YXzEpmAWXfsflHzswtHmRmmNo81NyclPKGps6naWhj1PyxtAijNbB5KI60OSNvgXTy4m3eo7ABiOPHS4SzrbbDZ+9rOfMWvWLMaOHcvChQsZN844bzhz5kz+/Oc/c+211zJgwAAA7rzzTp599lkmTpzI+++/77PBS6R4x1OcjU40ums8JaxvADXlUPkDT6/b2/uu/EFk3gAKFlI082F2fGal5sgxdhyr8wv83ez5b1kHkMKWzhpDRprv+FC8FC3sq6Qu6Vx1uCrmeehEFexnPbNiJs7GwHn1jiwHO27cEZ4GbBjfGfj92IfDkrfDc42+XjuabRBxK97LOkhJZzx5aAn24ReVef1R2IbT1IyVnhy/4ZyOKLVBxK15hXlxFez7K2nTPiJyojKvP0LbcPZKwUKY8u2Aw64jGRx6IYf3Njk4NH0GrsrKyLdFJD1XZSWHps/gvbGXRvXvSoK/6LNIjaf4iNA2nL12w8Mw5TvQuZzHdSQD59/suJs8N8vewnHyBiBC4S1I6K6rA62j+nclwV/0WdGoIkquKsGR5UChcGQ5KLmqJLwptoKFMPtRT34d5fl/9qPRnWt/w8Mw/0mwD6e
+ZjC63fflElA4Tog+qn9oXc8FCSMkqXP+InKiMp4SD9twdrbB/eSlGI0BuOvqPLN/Yt1OkXhqynGfbACDYhFmhQrDSXr+QvSCaeG4THfkpqCK5LZzjWmhQbO/t3CS4N9HkSjp/O1vf5ucnJywl2AeNGhQWJ8vlRkWjrN2kFPwGbQ1w841Jl8pRCBXZSWHNp7pLDHue0eprB2+BQkjRIJ/H4W7pDN4KnRu27YtrM8pwss+ezaOH6/x9PTR2DLd2Ec0UV8z2DP7Z+MZGfxNQpFYye56bAXO++7tnDygOv9pvH9Xji9be65RFQZJHfwjMYUq3CWdAb70pS9xzjnn9Hjdxx57rOvzkpISfvrTn3L69GnDNojws8+ezehvDmTsYic5BZ/hOpLZ9QJ2N9lk9k+SichK9ppy6v/7ObTb/wGFLbOd0V/7DPsdJSG0uveSdsC3V3u69kO4Szr31qJFi7j77ru7agKVl5ezfft20tPT+d3vfhfQBqWkHHFEzFgJlT8IOvsnGr02EXkRqQy8cw3uRuM+t7vJFtUZbUnb8+/Vnq4h8pZ0Ligo4Nprr/Up6fzSSy+xbNkyXnnlFez2rgLCzJ07l9tuu61PgR+gsLCQ+vp66urq2L9/P0OHDmX48OGmbQiniNz6xmhhS8g6p6B65/v7i8YsDREdEVnJ7jpmPsibmxvVWWNhCf5KqeuUUgeVUn9XSi03ePyHSql3lVI1SqmdSqnPheO6wfRqT9cQhaOks5mjR492lXR+4oknAFiwYAEVFRU+JZ3N2hAukbj1jeXClrAoWOh5oRqIxiwNER1hX8leUw7KQk7BZwG7ySkbURnk7S7k4K+UsgKPAdcDlwI3KaUu9TutGpiitS4AKoD/CPW6PenVnq79EO6SzmaGDx/eVdL5jjvuADypn02bNlFRUcGCBQuCtiFcgt369lc07srCzf/up/bmL/e8baRIWFurazl17Fp0h++2jP1dye56bAWHbruf936TQ33NYOwjms5OHshqx/F/5kc9XRiOnv8VwN+11oe11q3AJmBu9xO01i9rrb01cncDES/Q0qs9Xfsh3CWdAW666Sa+8IUvcPDgQfLz83n66acNrz1u3Dg+++wz8vLycHS+id18882GbQiXSNz6uuvqjI/HacrE6O5n6cBKjv/f/+W5A1AKa3Y2pKdTd++yxEpjiQDerRpPHB9Hi3M+Ha3ZaA32tJx+rWR3VVbi/K8tuButeCcHuI5kklPwGWNvqmf0M2ux37UuMt9MEOEY8M0Dute/PQZcGeT87wB/CMN1g+rVnq799Otf/zro4yNGjGDWrFkBx48cOdL18TPPPNP18W9+85teX/vAgQM+n5933nn89a9/NTz39OnTvX5eM8OyhhmWb+7vra/rsRWmj8VrysTs7ucn9lfZsWun8eSC++6Ff+2OyYtahKb7Vo3uTwtxf1oIgD07g6KvT+/z89VveCRgdo9ut3TeAZyI2erwqM72UUp9A5gCfNnk8duB2wEuvPDCkK/Xqz1dRVBh3Zy9c5obBG6WjlJxmzIxu8vxvikaprHcUP/fz2G/eqKUfkgwZls1mh3viemdbpM1OlVqTYQj7VMLDO/2eX7nMR9KqWuBFcAcrfUZoyfSWj+ptZ6itZ5y/vnnh6FpIlRhLeIWZJobWsftG3Wwu5y1u9eaTy5otMjK3wQUzq0aXZWVYDLt2pbVEb0qtQbCEfz/BoxWSo1USg0AFgMvdD9BKVUI/P94An99KBeL153Hkon/zzhsm7P3NM0tTgW7y3nug+eC1P1pl41fElA4t2qs3/AImMSsnG8viOldYcjBX2vtBr4PbAfeA8q11u8opdYopeZ0nrYeGAQ8p5Tap5R6weTpgkpPT+fjjz+WN4AI0lrz8ccfk+43WB4W9vy4mebWF8He7Dq0pw6L8kugdtX9ieFtveifeYV5PDh/AnnZGSggLzuDB+dP6NfuXcEmMcR6PCgsOX+t9e+B3/sdW9nt42vDcZ3
8/HyOHTvGiRMnwvF0wkR6ejr5+REIWjNWYm/6AeCivmYw7iYrtqwOcr69IG5TPl4WZaFDdxget8+eDf/aTf1/P4e70YIts52cgs+wj2iG1kYp+ZyAwrVVo83hMMz5x8OdbkKVd0hLS2PkyJGxboboI1dl5dlZV+eOIKfgU0bPqfP0imesTIjAuODzC9h8cLPhcfD04uxXT4Q/LIPmT86e0PyJp+QzJMT3KcIrZ8ndPjPBIH7WgyRteQcRHwJW85504Xx9IK7LNsKStxMmIN4/9X4WjVmERXleMhZlYdGYRdw/tdsajoKFMCAr8Iul5HPK6qoG27kexJabi+PHa+LiTlfFa/58ypQpes+ePbFuhgjRoekzTG97R+/aGYMWRVhJNkY7foGCkoZot0akIKXUW1rrKT2dJz1/EVHRqLEUV8wGeGXgV8QZCf4ioiJVYyluzVgJaX7zwdMyYjqfWwgjEvyFj3CXb45UjaW41VnyGftwQHn+j2KNdiF6K6Fm+4jIqjpcxQOvrqKtcwG2s9HJA6+uAoLPdQ8mkjWW4lbBQgn2Iu7JgK/o8sVfz8DVFrgA256Ww6tfT8LBWSHwVPFcv/0gdQ3N5GZnsHTWmLDM8Y+V3g74Ss9fdHG11nv2kjY6LkQS2lpdy9KK/bS1ezrBtQ3NLK3YD5DQbwC9ITl/0aWjLbtPx4VIdKsr3+kK/F5t7ZrVle/07YlqymHDeM9U3w3jPZ/HOQn+oktm4+yAnYt0RxqZjUmcnxcp7VRTW5+OG6op96zidh0FtOf/yh/E/RuABH/RZcWXb6aj/saunYs6WrPpqL+RFV++OdZNEyJ+7VzjWcXdXQKs6pacv+jiyXF+i/XbpybN4JcQwWRnpNHQHNjLz85IMzjbhFnZ7jgv5y3BX/gIVzVDIRJByZxxLH1uP20dZ/P+aRZFyZxxvX8Se35nysfgeByTtI8IjwQc8IqZF38Iq8+BErvn/xd/GOsWpax5hXmsXzDRp3b/+gUT+9YBStBV3dLzF6GrKcdVdg/11em4m4ZhyzxDztv3YC9GFjv5e/GHsOfps5/r9rOf3/BwbNqU4kK+2/X+je9c40n1JEipclnkJULmuvNSnH9uR7efvZFU1g4cX7Zif/zdGLYsDq0+xxPw/SkrrPok8LgQfSRVPUXU1O9u8wn8ALrdQv3uPkyXSxVGgT/YcSEiRNI+ImTuJuM/I7PjyajXJQKU1bznLxJHTXnCpXn8Sc9fhMx2nr1Px5PN1upa7ttygNqGZjSeEgH3bTnA1urawJMvu9X4ScyOi/iToIu6/EnwFyHLWbYCNcB3XrQakEbOshUxalF0rd9+kOY23958c1s767cfDDz5hodhynfO9vSV1fO5DPYmjgRd1OUvde7LRYCqw1WU7S3jeONxhmUNo3hycb9KN6dk2eZu6hqa+3ScGx6WYJ/IEnRRlz8J/imq6nAVJa+X0NLeAnhq95e8XgL0r3a/ffbslAn2/nKzM6g1CPS52RkGZ4uEl6CLuvxJ2idFle0t6wr8Xi3tLZTtLYtRixLX0lljyEgLHLB1upq5f+uBGLRIRFSCLuryJz3/FHW88Xifjgtz3lk9922pobmto+t4h4b/2f0vANbOmxCTtonwcVVWnk1tnjuCnIJPsefUyWwfkViGZQ3r03ER3LzCPFrdxgsmf/OGQYpAJBRXZSXOH63AXVcHWuM+6cL5ig3XZRthydsJF/hBgn/KKp5cTLrVd2P1dGs6xZOLY9SixNfeuVreNqSarItKGXTJcrIuKkUN3hvjlgnwjHPNrJhJwbMFzKyYSdXhql5/rXPlKnSb76JF3dbGR+t+Eu5mRo0E/xRVNKqIkqtKcGQ5UCgcWQ5Krirp90btAqxKYRtSTbpjC5YBDSgFlgENpDu29CnQiPDzTnBwNjrR6K4JDr35vbgqK9HNxjO32hsawt3UqJHgn4K8PaD7XrkPgAevfpAdN+6QwB+im64czsDzt6Msvj1EZWmTgfQ
YC2WCQ/2GRyLVrJiS4J9iQukBieDWzpuAZYBxT1AG0mPL2eg0PN6b34vbafy1ACo7cfe3luBV7s71AAAdnElEQVSfYmSKZ2Q5shyGx2UgPXaCdWx6+r24KivBYh4mHSt+1O92xZoE/xQjUzwjSwbS40+wjk2w34urshLnihXQblSIT5F90+KEXtgo8/xTzLCsYYa3wNIzDQ/vuEk4ymaI8AjWsQn2e6l/aB261aAsuUWR+9BDCR34QYJ/yimeXOxT1gF61zP1WeCSYrV7+qpoVJEE+zhi1uExS9F5uU+6jB/o0Enxty9pnxTTnymeztWrqbt32dkFLnV1OB9Y6cmHChHn+puKs2W6+3Q80UjPPwX1pWfqqqyk4TebAo7rlhbqNzySFD0gkdz6m4rLmZpmuD1pztS0IF+VOMIS/JVS1wFlgBX4uda61O/xgcAvgcuAj4FFWusj4bi2iKz6h9aZPhZsCpwQ8aSvqThXZSX1NUPQ7Q2gNGiwZbaTU9iC/Y7/jGBLoyfktI9Sygo8BlwPXArcpJS61O+07wCntNYXAxuAh0K9roiCmnLcJ81XMNocwXOmQiQiV2UlzgdWdub8FWiFsmpypqZhL/7PhKzjYyQcOf8rgL9rrQ9rrVuBTcBcv3PmAs92flwBzFBKqTBcW0TSH5ZhyzTZWFwpcpbcHd32CBEF9RseQbf4roXR7Rbq33ckTeCH8AT/PKB72cJjnccMz9FauwEXcG4Yri0iqfkTcgo+Q1k7/B7QZC9eJPn+XgqloJiIPrN0ZrKlOeNqwFcpdTtwO8CFF14Y49YIAPsIT0Gr+prBuJusnrxnwWfYV62KccsSQ7h3TBORZ3M4PDPbDI4nk3AE/1pgeLfP8zuPGZ1zTCllA+x4Bn59aK2fBJ4EmDJlinFxdBE9GedA8yfYRzR3vQl0HRe9EqycRq+Df025Z3Nw17GE3TgkUbgeW4H+uA7QwNnMtEpPT7o0ZziC/9+A0UqpkXiC/GLg637nvAB8C/grcCOwS2stwT0K+rtJu6uykvrKC3CfHHi2tz+iGawD4HoZr++tkMtp1JRD5Q+grfPN13XU8znIG0CYuR5bgfPx36LbFWcDv0YNysSxanXSpTlDzvl35vC/D2wH3gPKtdbvKKXWKKXmdJ72NHCuUurvwA+B5aFeV/SsvxU8/Wc7uJtsOP9mx1WfB3Mfk6DTB2ZlM5RSvcv971xzNvB7tTV7jouwqn9mS2fg705hpTnpAj+EaYWv1vr3WuvPa60v0lqv6zy2Umv9QufHLVrrBVrri7XWV2itD4fjuiK4/lbwTJXZDtFgtLoUoEN39K6UtutY346LfnOfNk5GmB1PdFLeIYn1N+VgNNgFyTfbIRq85TQsKvCl1qtS2vb8vh0X/WYbZDz73Ox4opPgn8T6s0m7c/Vq08eSbbZDtBSNKsJsiKvH3P+MlZCW4XssLcNzXIRVzm3zUVbf35OyanJumx+jFkWWBP8k1teCVq7KSho2bTZ+MlnUFZL+vBEDnjTb7EfBPhxQnv9nPyrpNxNbq2uZVrqLkcurmFa6i63V/hMPzdnvWofjzq9hGwSgsQ0Cx51fw36XeYmTRKbiddLNlClT9J49e2LdjITXl9k+70/9AjrIhtRj338vUs1Mev7z/cHzRtxTRVXRe1ura7lvywGa286uSs9Is/Lg/AnMK/Rfd5q8lFJvaa2n9HReXC3yEuHX24JWrsrKoIHflpsbzmalHNnkJfLWbz/oE/gBmtvaWb/9YEoF/96S4J+ktlbXsn77QeoamsnNzmDprDFBXwD1Gx4J+nyS8gmdbPISWXUNzX06nuok+Cch/9vf2oZm7ttyAMD0DSDYTJ5E36s03vT1jVn0Tm52BrUGgT43O8PgbCEDvkko2O2vGbOZPNbsbBxSxydsvG/MtQ3NaM6+MfdlYFIYWzprDBlpVp9jGWlWls4aE6MWxTcJ/kmoP7e/OUvuRqX7zgxS6elcsOJHYW1bquvPG7P
onXmFeTw4fwJ52RkoIC87I+UGe/tC0j5JpupwFYNHP0SH9RS6LZszJ2bh/rQQCH77603ryCbtkSV56ciaV5jXu2AvxfIk+CcT73RCbWtBAWpAA+mOLbQAac1Terz9tc+eLcE+wiQvHQekWB4gaZ+kYlTLR1nayLxgh9z+xgnJS8cBKZYHSM8/qThNSgVoW4Nx4Jdb36jz/h5ktk/49Hn2lOuoyfHUKpYnwT+JKHc22nbK8HiAmnJ4/i5ob/V87jrq+RzkDSDCep2XFj3q87TmmnI8tfoNKhukWLE8SfskkeaPZqI70nyO6Y40mj+aGXjyH5adDfxe7a2e40IkiD7Pntq5BsPAj0q5YnkS/JNIjuUqWpzz6WjNRmvoaM2mxTmfHMtVgSc3f2L8JGbHhYhDRoPnEGT2lGlqR6fcHa+kfZLI0lljuG9LK43/KOw6lpFmZel8GUwUyWdrda1ZAsd89pQ93zjnbx8eeCzJSfBPEqt3beS3Hz6F9aJTZLmzOVM/iwssVxkOfnn25x2Gu1H57s8Lsjm7SBjrtx80S+CYz56asdJ3miek7P4IEvyTwOpdG3nun/+JsrV75venNZDuqGDm5z7HvMLpPud69+fVLZ6Mn3d/XgD7KLdszi4ShllqR2New6ortSOz3CT4J4OKI4+hrL6DXsrSTsWRx1jFN32Om+7P+/ZQ7D9cnZIvApGYzBbM5RmlfGRacwAZ8E0C2tLY6+Nm1TvdjSrlXwwisfR6wZx3Ra/rKKDPruitKY9eY+OQBP8kYLa9tNFxs+qdsj+vSDS9LuQmK3oNSdonwW2trqWjPQOLLfD2N8M6OOBYzpK7O3P+Z1M/Kj1dNmsRCamnBXOuykrqN57B3eQInNyQYit6/UnPP8Gt336QMx/NQXf4/So7LKyatiLgfPvs2Th+vMazLaNS2HJzcfx4jRR0E0nHO7nB3WQDVNfkBteRzjGBFFvR6096/gmurqEZTSEtwMDzt6PSGtBt2bSemGW6ZaBU7xSpwHRyQ81g7KNJyemd3UnwT3DeGQ/uTwu76vaDyYwHIVKI6eSGJhvMfjTlJzikRNrHVVnJoekzeG/spRyaPgNXZWWsmxQ2UiJYCGOmkxtyc1M+8EMKBP+uvF9dHWiNu64O5wMrk+YNQLauE8KY2dakMrnBQ2lttEA69qZMmaL37NkT8vMcmj7DE/j92HJzGb1rZ8jPL4SInqrDVZTtLeN443GGZQ2jeHKx6dgWdM72SbGtSZVSb2mtp/R0XtLn/E3zfibHk4qsahRJxLtNqXe3Omejk5LXSwBkckM/JH3axzTvl6UTdoXf1upappXuYuTyKqaV7mJrdW3gSTXluMru4dDGM7y3aRiHNp7BVXZPwn7PqaTqcBUzK2ZS8GwBMytmUnW4KtZNigtG25S2tLdQtrcsRi1KbEkf/A3zftYOcsafSsgl3t6di2obmtGc3bnI/w3A9UQJzt2ZvnOcd2fieqIkFs0WveTt3TobnWh0V+9W3gA8PX0jx022LxXBJX3w71rUNAhAo9LasVg1dbuzOfTbwQkXDHu7c1H97jZ0u++vV7dbqN/dFvE2iv6T3q2xYG9+w7KGRbElySPpgz943gBG3+Akd2oDdCjaW6109Yb/3J5QM3/Mytj6H/f0+AOZHRfxwawXm+q922BvfsWTi6PYkuSREsEfAHs+9TWDjXvDD62LUaP6zmyHIv/jtvPshueZHRfxwawXm+q922BvfsFm+whzqRP8Z6zE3WQ1fMh90hXlxvTfzCtqGXRxKYMuWU7WRaXYhlQbLurKWbYCNcB3M3c1II2cZYH1fkT8KJ5cTLrVd4wq3Zoeeu+2phzX7RdxaPJo3rvkEg5dNgbXY4nzt2D25ufIkmq0/RVS8FdKnaOUekkpdajz/6EG50xSSv1VKfWOUqpGKbUolGv2W8FCbJnthg/ZMt1Rbkz/VB2u4sW6R1FpDSgFlgENZDi2sPgrJwIWddlnz8axbp1vAbd162T
aW5wrGlVEyVUlOLIcKBSOLAclV5WE1rutKcf1yBKcr9nOTgBotOB8/LcJ8wYQ7E0xmVfwR1JIi7yUUv8BfKK1LlVKLQeGaq2X+Z3zeUBrrQ8ppXKBt4CxWuuGYM8drkVe3bnuvBTnn9t9Uj/K2oHjy1bsj78b1mtFwhd/80VcrYF3KY4sBztu3BGDFomEsGE8hzaeMRzvsQ2C0Xvei0Gj+s5ogdcX3+kwLFGeypVqo7XIay5wTefHzwJ/AnyCv9b6g24f1yml6oHzgaDBPxLsd5TAmXuor07H3WT11PcubMF+x39Guyl9VnW4yjDwgwwGih64juFuMk6buE/H5wp/I0WjigLugA59d0Zg5c6WFuo3PJKywb+3Qg3+F2itvZNvjwMXBDtZKXUFMAD4R4jX7Z+ChdiLwe6z6nVdQqx6DTbbIdUHA0UP7PnYMk16/pntnrUuCfAaMJLSK/hD1GPwV0r9ETCKLj7JQq21VkqZdiOUUg5gI/AtrXWHyTm3A7cDXHjhhT01rX8KFibkH3qw3r1MdRNBzVhJzoElON8Y5DfbTdPepnA9UYL98cR7TYBnBb9h7S7ZlrRHPQ74aq2v1VqPN/j3PPBRZ1D3Bvd6o+dQSg0BqoAVWuvdQa71pNZ6itZ6yvnnn9+/7yhJmfXuswdmy1Q3EVzBQux3b8BxuQvrgHbA20dT6DZrwq116U4qd/ZfqFM9XwC+1fnxt4Dn/U9QSg0Afgf8UmtdEeL1UpbZbIflVyyPUYtEQilYiH3ieSibBpTPQ7rdQv2GR2LTrhDJtqT9F+psn3OBcuBC4J/AQq31J0qpKcAdWuvvKqW+ATwDvNPtS2/VWu8L9tyRmO2TqLyzHJyNTizKQofuwJHl6LGcrRA+asp5b+FK/IM/AEox9r34n/EmehaV2T5a64+BGQbH9wDf7fz4f4D/CeU6qcy/jG2H7uia3yyBX/RJwUJs5z1suKhRcuSpJ3VW+CYo00Jfux+EDeOhJNvzf4JVJxWxkbNsRWCOPC2N9qYmWSSVYqTKV5wzLfTV2gCuo55PXEc95akhIWcyiejx5sK9u1tZ7XbaXS5o8Cy7cdfVUXffj3zOFclJev4G4mm5uGmhL7dfqYq2Zs+uXUL0wD57NqN37WTse+/SAeA/7ud241z3k1g0TUSRBH8/8bbhu+Esn44Oik8ZLJB2HYtSq0Sy0A3GC+3NjovkIcHfT/2GR0yXi8eCf6GvGw4N4enHOxj1tJ1DL+TgOtKtlLM9PyZtFMlIy1hSkpOcv594WS6+tbqW9dsPUtfQTG52BktnPcNXju3F+eJKdAt0bUbzN099fvtoPBu0C9EH1uxs2g16+dYBHTKWlOSk5+/HdMqb1lHL/5vt0/vPh34aeFfSbqH+7aEw+1F5gYo+u2DFj1Bpfvs+WDQXTP7U84mMJSUt6fn7yVlyd0CJWC9v/h8iOxPCbJ9e20nD6hm4G5UE/hQSeFc4JmA/h97ymf1TV4sts51BjhbqawZTtzvbU/m24CSy/1vykZ6/H5/l4gaikf+vNdmntz4j2/C4LNBJHWZ3hVura/v9nF2zf/63jZyCz3AdyTy76UuTDeeeoTL3PwlJ8DfgfTGgDJbBE/n8v7XzunMsr/LqgB9weODXeXXADzgw/mIpYpXizO4K128/GPqTz1hJ/YEhgftcu0nY2j/CnAT/IMx61JHuabdrzRzLq5Sm/Zx8y0ksCvItJ/nOqG04vvdvUsQqhdWZ3BWaHe+TgoXm+1zX1cV8zYsILwn+QcSiXOzW6lrmWV/j4bQnyFStPo9lqlbsrc93LdAZvWunBP4Uk5ud0afjfWVzGKc7oXP179J7eX/qF8L6JrC1upZppbsYubyKaaW7Qkphid6TAd8g/JfC2xwOcpbcHbGAu7W6lld/9zg/sT2FTRnudyMLuVLc0lljuG/LAZ/UT0aalaWzxoTl+YNNePDSDQ1hm/iwtbqW3294hnUHqji/uYETGdn8+t0iWHJbvwe
xRe+EVNI5klKxpPO00l1sbvoe+ZaT5ifZh8OSt6PXKBF3wjnbx4irsrJz9k/gDlnd2XJzPWNjIai4Zh6XHj/oU2S6xZrGxqu+zsNPyV4V/RGtDdxTWteLJEx3BXUNzeQODBL40zJkIZdgXmFeRHvF9tmzsc+ezaHpM4K+AYQ68cG5enVA4AdIb29jzp7nAQn+kSQ5/34KWw2gmvKu0sx/TS/mlB5kfJ6yykIuEVVGY17dWe32kAogNpQ/Z7StjOfazVJbKNIk7dNPZr0i23l2Rn/ttCc3b8/39NTNAnZNuWf5fNvZmRqt2oZGM1Cdzem6renY5v5/EviFoUimgVyVlXy07icBJSBUWhpaa3C7ux3VZF/UiGNaO9zwSMDfq+uxFdQ/swX3aY1tkMJ92vy6beflUPDqn8PyPaSa3qZ9JPj303tjLw0shQuAZuzibrfD1gFgSYO2RlxHMqg/YMfdZMHmyCXnEif2nMCZDQ0M5nTHQHItH9OSMYzM69dI4BeGvIu+/AeAH5w/ITLjAJ0pzvamJpPKn57XhC2zg5y5hdiHvAOuY7jqzsX5Whq6XfmdG9j310De+v+QmWz9JDn/CLM5HMY9/0y/OvvtrdDeiutIBs6/2bsW0Ljr6nB+1AGXZ2Af4TtHO5vTZK/xzOrJjEzzRZLwLvqyDalm4PnbUWkN6LZs1v15NvMK7w3bdbzjAF7vjb3U5ExPMHc3WXGW74PLXdhHaOr3WPwC/9lz/WV+YaoE/iiQnH8/Ga4BsHaQU/CZ4fn1NYMDV062W6ivGRx4spRmFr1U19CMbUg16Y4tWAY0oBRYBjTQbN9E1eGqiF23Nwsdu/99my0eAw3WzsesVrJvWsyIZ54JUytFMNLz7yfDNQCXOLHnGK+0NF056X9cZvSIPsjNzqDh3O0oS5vPcWVpo2xvGUWjiiJy3Zwld1O3tOc7C+/fty2zvbNekC/bIMXoPTJ1ORak5x+C7tvhjd61E/sdJZ7gbSAgHeQ9fl62Z+4+yvO/zOgRfbB01hhUmvHMGLP9n8PBPns22Tct7vE87999TsFnKKvvwkVl1eTcNj8i7RM9k55/OHmD9s41ntk+GUOhuQHwpIO65/yhs1TEshUg+U3RT/MK8/jpezm42gLLfZvt/xwujlWryJw8mfp1K3F31RY6m8dXVt2VBrWPaAaLlfr9g3A3KmyDFDm3fQ37Xesi2kZhLqln+1QdrqJsbxnHG48zLGsYxZOLI3YbbKqmHP6wDJo/CZztE8FSESJ1VB2uouT1Elraz5ZkSLemU3JVSVT/3gMWPX5tKvbW53s37VmETcpP9YyXF4QQ0RAXHR0RF1I++M+smImzMXD5uSPLwY4bd4TStJDJC1UIESkpP8/fbLArkoNgveF/R+JsdFLyegmAvAEIIaImaWf7mA12RXoQrCdle8t8UlEALe0tlO0ti1GLhBCpKGmDf/HkYtKtgUWpnI1OZlbMjOgCmGDi9Y5EJB/ZJEUEk7RpH28KpWxvWUDuP5aplmFZwwzHImJ9RyKSi3/NH+9G74BskiKAJO75gyew77hxB46swKXo0U61eHthH37wJehI83ks3ZpO8eTiqLVFJL+IbvQukkLS9vy7i3WqxbcXVogG0nO2o9JcOGS2j4iAiG70LpJCSgT/WKda/Hth7k8LOf1pIXnZGexYPj0qbRCpJTc7g1qDQB+ujd5F4kvqtI+X0eBvNFMt0gsT0bZ01hgy0nyLBoZzo3eR+FKi59998DcWC6ukFyaizTuoG8mN3kViS9oVvvEkWrstCSFEyq/wjSfSCxNCxJuQgr9S6hxgMzACOAIs1FqfMjl3CPAusFVr/f1QrhtOkdz8urt5hXkS7EVckNpSAkIf8F0O7NRajwZ2dn5u5sfAX0K8Xlh50zG1Dc1oPAthlmzex/1bD8S6aUJEhLe2lLPRiUbjbHSy/JXlrN29NqTnnFkxk4JnC2K6el70TajBfy7wbOfHzwLzjE5SSl0GXAD
EtpymH6OFMBr41e5/9XspvLwQRDwzqi0FsPng5n79rRq9mZS8XiJ/9wkg1OB/gdbaO4H+OJ4A70MpZQH+E/j3np5MKXW7UmqPUmrPiRMnQmxaz4ymWtqGVJN5USkP7L+uT8F77e61FDxbwPJXlssLQcStYAsb73v5oT53eqRQYeLqMfgrpf6olHrb4N/c7udpz7Qho6lDdwK/11of6+laWusntdZTtNZTzj///F5/E/3lP9XSNqSadMcWLAMaQHlqAD3w6qoeg/fa3WvZfHAz2uDblxeCiCfBFjZ2WE9x35YDvXoD8N7hGi2eBClUmAh6DP5a62u11uMN/j0PfKSUcgB0/h+4kSh8Afi+UuoI8FPgFqVUaRi/h35bOmtMtx1HYeD521GWNp9z2vQZlvfQI3rug+eCXkdeCCJeBFvYqNuye1X/p3uqx4wUKox/oU71fAH4FlDa+f/z/idorW/2fqyUuhWYorUONjAcNfMK89jzz0/41e5/oQGV1mB4nraeYukfnmHFnu1Y0hpQyoKmA0eWg+LJxXTojqDXkReCiBdFo4qorq9m88HNPsd1RxpnTswCzqZDu8+Ey0iz0OzuQGvIurgUS1rguIGXFCpMDKHm/EuBryqlDgHXdn6OUmqKUurnoTYuGtbOm8CGRZPIy85At2UbnqPbMxjYLR2k8QT77qWhzcgLQcSb+6feT+nVpSj3ULSGjtZsWpzzcX9aCHjSof4z4ZraPIEfQNmMO0ng2SZV9slODLLCt5vLH/kPmu2bfFI/uiMN3ZGGxdZk+nWZtkya3IGPZ1gzWHXVKnkhiLgUbOX5+u0HDUuSAGRdVOrpCPmJh/2xRe9X+KZEYbfeWvHlm+mov5GO1myfHpGymgd+gGZ3M4vGLMKiPD9Oi7KwaMwi3vzGmxL4RdyaV5jHg/MnkJedgQLysjO6So4EKzp45sQstOxJkfCk5+9na3Utqyvf4VTT2d6/WU/HS3o8ItlMK91l2vMHz8y4gTnbsaa5ZJVwnJGefz/NK8yjeuVMHukcBwBoNejpeEmPRyQjo5LQ3bk/LeR/nfc4Nd+qYceNOyTwJyAp7Gaiey2erdWTWPfnATRlVRrO9pE/fJFs/IsRdp/tY1WKm64cztp5E2LcShEKSfsIIUQSkbSPEEIIUxL8hRAiBUnwF0KIFCTBXwghUpAEfyGESEES/IUQIgVJ8BdCiBQkwV8IIVJQ3C7yUkqdAP4Z63aE4DzgZKwbEUXy/Sa/VPueE/X7/ZzWusetEOM2+Cc6pdSe3qyySxby/Sa/VPuek/37lbSPEEKkIAn+QgiRgiT4R86TsW5AlMn3m/xS7XtO6u9Xcv5CCJGCpOcvhBApSIJ/FCil7lFKaaXUebFuSyQppdYrpd5XStUopX6nlMqOdZsiQSl1nVLqoFLq70qp5bFuTyQppYYrpV5WSr2rlHpHKZUS29YppaxKqWql1IuxbkukSPCPMKXUcGAm8K9YtyUKXgLGa60LgA+A+2LcnrBTSlmBx4DrgUuBm5RSl8a2VRHlBu7RWl8KTAXuSvLv16sYeC/WjYgkCf6RtwG4F0j6wRWt9Q6ttbvz091AfizbEyFXAH/XWh/WWrcCm4C5MW5TxGitnVrrvZ0ff4YnIObFtlWRpZTKB4qAn8e6LZEkwT+ClFJzgVqt9f5YtyUGvg38IdaNiIA84Gi3z4+R5MHQSyk1AigE3ohtSyLuETwdto5YNySSZAP3ECml/ggMM3hoBfAjPCmfpBHs+9VaP995zgo86YJfRbNtInKUUoOA3wJ3a60/jXV7IkUpdQNQr7V+Syl1TazbE0kS/EOktb7W6LhSagIwEtivlAJPCmSvUuoKrfXxKDYxrMy+Xy+l1K3ADcAMnZzziGuB4d0+z+88lrSUUml4Av+vtNZbYt2eCJsGzFFK/RuQDgxRSv2P1vobMW5X2Mk8/yhRSh0BpmitE7FQVK8opa4DHga+rLU+Eev2RIJSyoZnMHsGnqD/N+D
rWut3YtqwCFGensuzwCda67tj3Z5o6uz5/7vW+oZYtyUSJOcvwulnwGDgJaXUPqXUE7FuULh1Dmh/H9iOZ/CzPFkDf6dpwDeB6Z2/032dvWKR4KTnL4QQKUh6/kIIkYIk+AshRAqS4C+EEClIgr8QQqQgCf5CCJGCJPgLIUQKkuAvhBApSIK/EEKkoP8Hw9DsL+C+/ZMAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "outer_batch_size = 2\n", "x1, y1, x2, y2 = sample_tasks(outer_batch_size, 50)\n", "for i in range(outer_batch_size):\n", "    plt.scatter(x1[i], y1[i], label='task{}-train'.format(i))\n", "for i in range(outer_batch_size):\n", "    plt.scatter(x2[i], y2[i], label='task{}-val'.format(i))\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2, 50, 1)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x2.shape" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0\n", "1000\n", "2000\n", "3000\n", "4000\n", "5000\n", "6000\n", "7000\n", "8000\n", "9000\n", "10000\n", "11000\n", "12000\n", "13000\n", "14000\n", "15000\n", "16000\n", "17000\n", "18000\n", "19000\n" ] } ], "source": [ "opt_init, opt_update = optimizers.adam(step_size=1e-3)\n", "out_shape, net_params = net_init(rng, in_shape)\n", "opt_state = opt_init(net_params)\n", "\n", "# Batched MAML loss: vmap maml_loss over the leading (task) axis,\n", "# then average the per-task losses into a single scalar.\n", "def batch_maml_loss(p, x1_b, y1_b, x2_b, y2_b):\n", "    task_losses = vmap(partial(maml_loss, p))(x1_b, y1_b, x2_b, y2_b)\n", "    return np.mean(task_losses)\n", "\n", "# One meta-training step: differentiate the batched MAML loss with\n", "# respect to the shared initial parameters and apply an Adam update.\n", "@jit\n", "def step(i, opt_state, x1, y1, x2, y2):\n", "    p = optimizers.get_params(opt_state)\n", "    g = grad(batch_maml_loss)(p, x1, y1, x2, y2)\n", "    l = batch_maml_loss(p, x1, y1, x2, y2)\n", "    return opt_update(i, g, opt_state), l\n", "\n", "np_batched_maml_loss = []\n", "K = 20  # points sampled per task for each of the train and val sets\n", "for i in range(20000):\n", "    x1_b, y1_b, x2_b, y2_b = sample_tasks(4, K)  # meta-batch of 4 tasks\n", "    opt_state, l = step(i, opt_state, x1_b, y1_b, x2_b, y2_b)\n", "    np_batched_maml_loss.append(l)\n", "    if i % 1000 == 0:\n", "        print(i)\n", "net_params = optimizers.get_params(opt_state)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {},
"outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAD8CAYAAABzTgP2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzsnWdYVEcXgN9ZqlRpAgKKBcHesMTee4stGo35TGI0tlhjjClGTSzRWBJLjCYmGqMm9t7FHsXeBcUCSld6Xeb7cbEgiIWFBbnv8+wjOzN35iB799xz5sw5QkqJioqKiorKIzT6FkBFRUVFJX+hKgYVFRUVlQyoikFFRUVFJQOqYlBRUVFRyYCqGFRUVFRUMqAqBhUVFRWVDKiKQUVFRUUlA6piUFFRUVHJgKoYVFRUVFQyYKhvAV4He3t76e7urm8xVFRUVAoUp06dCpdSOrxoXIFUDO7u7vj6+upbDBUVFZUChRDi9suMU11JKioqKioZUBWDioqKikoGVMWgoqKiopIBVTGoqKioqGRAVQwqKioqKhnQiWIQQvwmhAgVQlx8Tr8QQswTQvgLIc4LIWo81fe+EMIv/fW+LuRRUVFRUXl9dGUxLAPaZNPfFvBIf30MLAQQQtgC3wB1gNrAN0IIGx3JpKKioqLyGujkHIOU8qAQwj2bIZ2BP6VSR/S4EKKoEMIZaALsllJGAgghdqMomL91IddLo02B0CsQFQjRQZCaCM5VoXh1MLHMU1FUVPKc2FAIvgChl8HAGGzLgF0ZKFoSNKq3uTCSVwfcXIC7T70PTG97XnsmhBAfo1gblChRQjdSxUfCqd/hxK8Qcz+LRTXgVAUajQGvDiCEbtZVUdE32hQ4twoOzYIHAVmPsXGHesOhWh8wMs1T8VT0S4E5+SylXAwsBvD29pY5mixNCz7T4cg8SE2A0k2h5WSwKw1WrqAxhHtnIMgXLq6F1X3BtRa0nAQl6+ni11FR0Q9SwoV/YP938OAWFK8BdQaCYyUoVgHSUiHyBoRdhTMrYOso5V5p8jnU7K8+HBUS8koxBAFuT713TW8LQnEnPd1+IFcliQmBtR/CrUNQqTs0HA2OFTKP82ihvBqOgXMrYf/38HtbaDoBGo1VbxCVgkdqMmwbDaf/VCzh3quhXOvMn2VLR+UBqGZ/CDgIPjNgy0i48x90mA3GZvqRXyXPyCsH4iagX3p0Ul0gSkp5H9gJtBJC2KRvOrdKb8sdbh2GXxpCoC90WQjdl2atFJ7GwBBq9INhp6FKL+VJa8Ng5SZTUSkoxIXDn50VpdBwNHzsA55tsn/AEQJKN4b3NysPROdXw2+t4cFLpdtRKcDoxGIQQvyN8uRvL4QIRIk0MgKQUi4CtgHtAH8gHuif3hcphJgMnEyfatKjjWidI6XiTzWxhPfWg2PFV7ve2AzeXgS2peHA9xB1F3r/rW5Oq+R/ooIUazc2BLougSo9Xu16jQYaf6YEZKwdAEtbwYe7wKZk7sironeEEihUsPD29pavlV01LgIMjXP+ZX5+DawfBGWaQe9VilWhopIfSYyC39rCwzvQbwO4eudsvtAr8FsbMLeHD3Yq/6oUGIQQp6SUL/wQFK5YNHM73TzhV+kJ7WeB/27YOT7n86mo5AapybD6PQi/Bu/8mXOlAFCsPLy7WgntXtkTkuNyPqdKvqNwKQZd4t0f6g2DE4vh+CJ9S6OikhEpYfNwCPCBTj8p1q2uKFEXuv+uRO798z9IS9Pd3Cr5AlUx5IQWk5TzDTvHw80D+pZGReUJp36Hc39Dk/FQ7V3dz+/VDtrOAL9dcHyB7udX0SuqYsgJGg10XQx
2ZWHDEMWfq6KibyJuwM4JyvmcRp/l3jq1PlIejPZ+q5ycVnljUBVDTjE2hy6LIOYe7PhC39KoFHa0qbDuYyW1RZcFuZvSQgjoOA+K2MLajyAlIffWUslTVMWgC1xrQoNRcHYFXNuub2lUCjOHf1RO7Hf4EayK5/565naKAgq7Cru/zv31VPIEVTHoisbjlLQCm4YrYbEqKnlN8AUlfUXlHlCpW96tW7Y51B2sBGLcPZF366rkGqpi0BWGxsoBuIRI2DdZ39KoFDakhO3jwMRK2RTOa5pOAMvisG2skotMpUCjKgZd4lQZag2A039AcJY1i1RUcofLG+D2EWj+FZjZ5v36JhbQajLcP6sk31Mp0KiKQdc0/gxMrWHnF8pTnIpKbpOSALu+UlyZNfRYBLFSNyhRT4lSSnigPzlUcoyqGHSNma1iVgf4qBvRKnnDkXlK7q4200BjoD85hIB2MxSlsH+q/uRQyTGqYsgNavYHBy/YNUHNwqqSu0QFweHZUKEzlGqob2kUd6r3B3DyVwj307c0Kq+JqhhyAwNDaP0dRN4E36X6lkblTebgD0pxnZb5KOCh8edgaKrUcVApkBQqxaCNjSM5MDBvFivbAtwbKk9z6sEfldzg4R1lo7dGv/yVAtvCAWp/rFSKC7umb2lUXoNCoxiklAQOGsTdgYNIi8ujjJBNPldy4J/+M2/WUylcHJql+PUbjtK3JJmpN1zJCnBgmr4lUXkNdKIYhBBthBDXhBD+QojPs+ifLYQ4m/66LoR4+FSf9qm+TbqQ5zkyYj9kMMkBAdz/+hvypA6FewMo2SDdakjM/fVUCg9PWwvWrvqWJjPmdlBnEFxaDyGX9C2NyiuSY8UghDAA5gNtgQpAbyFEhnqZUsqRUspqUspqwE/Auqe6Ex71SSk75VSe7DB/6y0chg8jeutWHq5alZtLPaHxZxBzH84sz5v1VAoHh2aB0ECDka98aVpyMinBwSRev0786TNoY2JyQUDgrSFK/RPVaihw6KL0WG3AX0p5E0AIsQroDFx+zvjeKKU/9YLdxx8Tf/o0Id9PxbRSJYpUrpy7C5ZqBCXeUqyGGv3A0CR311N583lkLdT83ytZC1Kr5cHKvwmbO5e02NjH7Rpzc4q+8w627/fDyNFRd3Ka2SqpMnymKVbDq5bTVdEbunAluQB3n3ofmN6WCSFESaAUsO+pZlMhhK8Q4rgQoosO5MkWodFQfPp0DBzsCfp0BNqHD198UY4WFEoepeggOPtX7q6lUjg4+jMglMSNL0nilSvc6tWbkO++o0jVqjhN+haXObNxXbAAiyZNiFy2DP8WLbn/1Vck37374glfljoDwcgMjs3X3ZwquU5ebz73Av6VUj6dTKVkeg3Sd4E5QogyWV0ohPg4XYH4hoWF5UgIQxsbXOfMISUsjHvjPkfmdgWq0k3ApaZyQ6vVrlRyQsIDxVqo3B2ss3z+ykBaXBwh02cQ0L0HKffvU3zWTNyW/IpNz55YtWmDZbOmuMyaSZldO7Hp0Z2oDRu50aYt98Z/QfKtWzmX18wWqvdV6qRH38/5fCp5gi4UQxDg9tR71/S2rOgF/P10g5QyKP3fm8ABoHpWF0opF0spvaWU3g4ODjmVmSJVquD4+ThifXyI+HVJjufLFiEUkzryhlLxSkXldTn1B6TEKf77FxCzfz83OnYk8vffKdqtG2W2bsG6fXuEEJnGGru64vT115TZswfbvn2I3raNG+3ac2/c5zlXEHU/AalVsq+qFAh0oRhOAh5CiFJCCGOUL/9M0UVCCC/ABjj2VJuNEMIk/Wd7oD7P35vQOTbvvotVu3aEzZ1L3PH/cnexCp3BykUtg6jy+mhT4L9foFRj5YTxc0gJCSFw+KcEfjIYA3NzSq78C+dJ32Jgbf3CJYwci+E4fjxl9+7B9v33id65M11BjCMpIOD15LYtrVR6810KSbEvHq+id3KsGKSUqcBQYCdwBVgjpbwkhJgkhHg6yqgXsEp
mjBMtD/gKIc4B+4FpUspcUwzaNC3ap1ICCyFwmjQJY3d3gkaPJiUkNLeWBgMj5dBPgI+aeVXl9bi0QakU+NbQLLtlWhqRK/7iZrv2xPr44DBiBKXWrsWsRo1XXsrQ3h7HcZ9Rds9ubPv1I3rnLm6270DQmLEk3XwNBVFvuFL6Vt1nKxCIPInn1zHe3t7S19f3la6RUjL5+GQeJj1kWsNpGBsYP+5L8vMjoOc7mFasQMllyxCGugjWyoKEB/BjBajYFbqom3Eqr4CUsLixcop+8H+ZSnamhIZy//PPiTt6DPN69XCa+A3GJUpkmiYxNZETwScoZVUKNyu3TP3PIzUigsjffydy5d/IlBTsBwzAbtBANMbGL774EUtbQUwwDD+j32R/hRghxKn0Pd1sKTQnn4UQuFu5s/v2bobvG05C6pM0FSYeHjhP+pYE31OEzZmTe0IUsYFqfeDCGojNRetE5c3j9hG4f07Zq3pGKcTs309A5y7Enz6D06RvcVu6JINSkFJyPuw8k45NotmaZgzZO4R269sxcPdA9t7eS2pa6guXN7Szo9iYMZTdvQurtm0IX7CAgLe7En/6zMv/Dm8NhYe31azDBYBCoxgA+lXsx6R6kzh2/xiDdg8iJvnJwR7rjh0p2usdIpYsJWbv3twTos4g0CbDSTW5nsorcOJX5cGiaq/HTVJKwn9ZTOAngzF0dKTU2n+x6dnz8eZyQmoC6/zW8c6Wd+izrQ+bb2ymiVsT5jefz+Bqg7nx8AYjDozg3a3vci3y5XIaGdrZ4TJjBm6/LkYmJHC7Xz+it7/kF71nO2WfTU0smf+RUha4V82aNWVO2BmwU1b7s5rssamHDI8Pf9yuTUyUN7t2k1e9a8mkO3dytEa2LO8m5UxPKVNTcm8NlTeH6GApv7WVcscXj5vStFoZPHWavOzpJQNHj5HaxMTHfbejbssZJ2bIeivryUrLKskuG7rI1VdXy5ikmAzTpmhT5Pab22WjVY1ktT+ryYVnF8pkbfJLi5UaHS0D3u0jL5evIB/888/LXbR/mpTfWEkZ7v/S66joDsBXvsR3bKGyGB7Ryr0VPzf7mYCoAP63438ExwUDoDExwWXuHBCCwE8/JS0pKXcE8O6vpMnw25k786u8WZxZrqTWrtkfAJmayv0vJhC5bBk2fftSfMZ0pJEhB+4eYNDuQbRf356VV1ZSr3g9lrVZxrpO6+jp2RMLY4sM0xpqDGlTqg0bO2+kZcmWzD87/5WsBwNLS0os+RXzevW4/+VXRCxb9uKLavQDYQCnXmKsit4oNJvPWXE65DRD9g7B0tiSxS0X427tDkDMvv0EDh5M0Z49cZ70bY7XyYQ2FeZUVlIE9P1X9/OrvDmkaWFuNbB1h/c3I6Xk/vgviNqwAfthQzH9qB+rr6/mn2v/cC/uHsWKFKOHZw+6eXTDwezVzvvsvb2XSccnEZ0czcAqA/mw8ocYaYxeLGJyMvfGjCVm1y7shwzBfuiQLM9KPGZ1X7h1BEZdASPTV5JRJWeom88vQQ3HGvzW+jeStEm8v+P9x09Kls2aYjdgAA/XrOHhhg26X9jAEGq8B/574MFt3c+v8ubgvxei7ihV0YCwuXOJ2rAB28Gf4NPSkQ4bOjD39FxcLV35scmP7Oi+g0FVB72yUgBoXrJ5Buuhz9Y+L2U9aIyNcflxFtZvv034/PmETpuWffZi7w8hIRIub3xlGVXyhkKtGADK25VnWZtlGGmM6L+zP2dDzwLg8OlwzGrVInjityRev677hau/p5yIVms1qGTHqd/BvBh4tufB338TsegXZMfmDC55gG+PfUtJq5L83f5vlrZeSsuSLV/qCT87ipoWZUajGcxpMoeQ+BB6be3FonOLSElLyfY6YWiI83dTsHnvPSL/+JP7X36J1GqzHlyqsXLozfe3HMmqknsUesUAUMq6FH+2/RNbU1s+3v0xR+8dRRga4vLjLDSWFkqyvVgdF/cp6gZlWyr+Y232N51
KISUqEK7vgOp9iTl0hODJU3hYswzvVTpMVHI0PzT+gT/a/EEl+0o6X/pZ66Hvtr7cjs7euhUaDY5fjMd+8GCi1q4jaPQYZHIWNc81GsUCuntcrdWQT1EVQzrFLYqzrM0y3CzdGLp3KAcDD2Lo4IDLrFkk377N/a++1H1xH+/+SoU3Na5bJStOLwcpSbJrStCYMYS4mDOsyS0almjC2k5raePeJntffg55ZD382ORHAmMC6bG5B5tubMr2PhBC4DB8GMXGjSNmxw7uDhlKWkIWpW2rvgsaIyUhoEq+Q1UMT2FfxJ7fWv+Gh40HI/aPwOeuD+a1a+MwcgQx23fw4K+Vul2wbEslrluN0FB5lrQ0OLsSrXMDboydSBSJfNclhTENv2R2k9lYm7w475GuaFmyJWs7raWCXQUmHJ7AN0e/yZBaJivs+v8Pp8mTiDt8mDsDBpAWH59xgLkdeLaF86shNQurQkWvqIrhGaxNrFnccjHlbMox4sAIDtw9gN2HH2LRtCkh06eTcO6c7hYzMFQOLN3cr6YkVsnI7cOkPbiD785otPfus+xdB+b2+oteXr1y1Up4Hk7mTixttZQBlQew3n89U09MfaEFbdOjB8Vn/kDCqdPc+2JC5vHV+kB8hJpxOB+iKoYssDaxZnGrxXjaeDLywEjOhJ2l+LSpGBUrRuCIkaQ+eKC7xaq+CzJNeXJSUUkn+cxfrLvhjNXlMA729GTa0I1UsKvw4gtzEQONAcNrDKd/pf6svraaBedenCnYun17io0ZQ8yOHYQveGZ82RZg4QhndWyJq+QYVTE8BytjK35p+QuuFq6MPDCSEIM4XObORRsezr3PxumuuI99WXCtDef+VhKlqRR6wqPuMPOwDxVPCYJaVOKTb9blqevoRYysMZKuHl1ZdG4Rf115cbZU2w/6Y92lC+E//Uz0zqesAwNDqPKOctAzNmfFt1R0i6oYssHaxJp5zeaRok1h2L5hSM9SOE6YQNyhQ0T88ovuFqrWG8Kuwr1XSEim8kZyNfIqo5f2oN0uQVI5Z5r/+Bcakb9uUyEEX9X9iuYlmjP9xHT23dn3wvFOk76lSLVq3Pv8c5L8/Z90VuujnOpWLeZ8Rf76xOVDSlmXYmbjmfg/9Gf8ofFY9eyOVaeOhM37ibijR3WzSMWuYGCiWA0qhZb9d/YzeG1fPlodjbEJVFq6GvEqaa3zEEONIdMaTqOiXUU+P/Q5VyKuZDteY2yM60/z0JiaKvsNqekZXYt5KWVvz/6lWsz5CFUxvAT1XOox1nss++7uY/7Z+ThPnIhJ2TIEjRlLSkhIzhcoUhS82sOFfyA1l/IzqeRbpJQsu7iMkXuHM3qDxD5GUurTThjqoIRtbmJqaMq8ZvOwMrZi2L5hhMVn7w4ydHDA6asvSTx/nsin8ypV6wOhl+H+2dwVWOWl0YliEEK0EUJcE0L4CyE+z6L/f0KIMCHE2fTXR0/1vS+E8Et/va8LeXKDPuX70M2jG79e+JWdIT64zJ1LWmIiQSNHIVN0cECt2rtKIZ/ramK9wkRcShxfHvmSWadm8YWvK6VvxONcK4oinYfpW7SXwsHMgZ+b/0x0cjSf7PkE/wf+2Y63bNsWy5YtCZv3E0k3biiNlbopFvP5NXkgscrLkGPFIIQwAOYDbYEKQG8hRFbhE6ullNXSX0vSr7UFvgHqALWBb4QQNjmVKTcQQjChzgRqFKvBV0e+wt8qAefJk0g4fZrQH2fnfIHSTcHCSXUnFSIuhF2gx+YebL6xmW+im1J57y1sKhtQtJk3WLvqW7yXxsvWi1mNZxEcH0z3zd354eQPxCZnXdtZCIHTN1+jKVKEe198oaTNKFIUyrWCi2uVpIEqekcXFkNtwF9KeVNKmQysAjq/5LWtgd1Sykgp5QNgN9BGBzLlCkYGRsxuOhs7UzuG7xtOctPa2Lz7LpG//070rhzGYhsYQuXu4Lcb4iN1I7BKvmXF5RX0296PlLQUlpX6mkpLDmJ
WtTyO5e9C5Z76Fu+VaejakM1dNtOlbBeWX15Opw2d2HlrZ5ZnHQzt7XH86isSz53nwcr0B6HKPZQsAAE+eSy5SlboQjG4AHefeh+Y3vYs3YQQ54UQ/wohHhWbfdlr8w22prbMazaPmJQYRh4Yic3YkZhWrsz9LyaQfDuHmVIrd4e0FLiyWTfCquRLVlxewfST02nk2oh/Wq/E+vulGNja4tLDHWFkDOU76lvE18LG1IaJ9Saysv1K7IvYM8ZnDEP2DiEoNijTWKv27TCvV4+wefNIjYgAj9ZgYg3n/9GD5CrPklebz5sBdyllFRSr4I9XnUAI8bEQwlcI4RsWpt+YZ09bT6bUn8K5sHNMPTMTl9k/goEBgSNGkpaY+PoTO1cDu7LKJrTKG8l6v/VMPzmd5iWaM6vJLFKW/0PyrVs4fzsRw1tbwKOV4lopwFSyr8TK9iv5rNZn+Ib48vbGtzkUeCjDGCEEjl9+SVpiIqGzflTqMlToqDwUpWSRW0klT9GFYggC3J5675re9hgpZYSU8lG4zRKg5ste+9Qci6WU3lJKb4d8EK3Ryr0VAyoPYK3fWjbEHaX49GkkXblCyHffvf6kQigm9a3DaoqMN5Adt3bwzdFvqF+8PjMazSDtbhARi37Bql1bLFzTFFdK5R76FlMnGGoMea/Ce2zsvBF3K3eG7xvOtpvbMowxKV0Ku/f7EbVuHfFnzii/e3KMklFWRa/oQjGcBDyEEKWEEMZAL2DT0wOEEM5Pve0EPAp63gm0EkLYpG86t0pvKxAMqTaEhi4NmfrfVK6Xt8Tu4495+M+/OSvuU6k7IOHSOp3JqaJ/fIN9GX9oPNWLVWd209kYaYwInjQZYWxMsXGfKy4UY0so11rfouoUZwtnlrZeSrVi1fj80OesuroqQ7/9J59g6OhIyOQpSLd6SgCG6k7SOzlWDFLKVGAoyhf6FWCNlPKSEGKSEKJT+rDhQohLQohzwHDgf+nXRgKTUZTLSWBSeluBwEBjwLRG03CxdGHUgVFoP+ypFPf5dlLG052vgn1ZxaWkupPeGG5H32bEgRG4Wrgyr9k8ihgWIWb7duKOHMFhxAiMbK3gyiao0AmMiuhbXJ1jaWzJopaLaOzWmO/++46F5xY+3pTWmJvjOO4zEi9f5uHa9Uroqt8uJXRbRW/oZI9BSrlNSllOSllGSvldetvXUspN6T+Pl1JWlFJWlVI2lVJefera36SUZdNfv+tCnrzEytiKeU3nkaRNYtShMdj/8D0aMzMCR4zInGr4ZancXUmPEXFDt8Kq5DlRSVEM2TsEgWBB8wVYm1gjU1IInTkL0woVsOndS/kiTIpW/u5vKCYGJsxuMptOZTqx4OwCpp+cTppU8o1Ztm2Lmbc3YXPmoC3VVgnAUMt+6hX15LMOKF20NN83+J6LERf53n8hxadPJ/nGTYKnvOZ+Q8WugIAL/+pUTpW8JVmbzIj9I7gXe4+5TefiZqVsp0Vv307KvXvYDxuKMDCAi/+CuQO4N9KzxLmLocaQyfUn816F9/jryl9MODyBlLQUZSN6whdoo6IIX3sYbMvApfX6FrdQoyoGHdGsRDM+qfoJm25sYqPdLew/GUTUunU8XP8a+w3WLlCyvvKFoeaPKZBIKfn66Nf4hvgyuf5kajjWUNrT0oj49VdMPDywaNwYkmLh+i6o0EU5y/KGoxEaxnqP5dMan7Ll5hZG7h9JYmoipuXLU7RHDyJXriTJrhkEHIS4cH2LW2hRFYMOGVR1EE1cmzDz5EzudH8Lszp1CP72W5L8/F59skpvQ/h1CM0+OZlK/uSnMz+x9eZWhlcfTvvS7R+3x/r4kOTnj92AjxAajZJyOjUBKr6tR2nzFiEEH1X+iK/qfsXBwIMM3D2Q6ORoHD4djsbUlJBd95QaJVc2vXgylVxBVQw6RCM0fN/we1wsXRh9aCwmk8ahMTcncOTIV99vKN8JhEY1qQsg6/zW8euFX+nm0Y2PKn+UoS/i1yU
YFS+OVdu2SsOl9UokTom6epBUv/T07MmMxjM4H36eD3Z8QLS5wH7IEOL+O0NMTCm4qEbm6QtVMegYS2NL5jSZQ3xqPGMufYfjjKnKfsOkya82kUUxxZ10eYPqTipA+D3wY/LxydQrXo8JdSdkKMMZf+oUCadPY9u/P8LICJJilBQoFTqDxkCPUuuPNu5tmN9svhK5tX8EFr16YFy6NCHHDUm7cQRidJC9WOWVURVDLlDWpiyT60/mXNg5FhkdxX7wYKI2bODh2ld8AqrYJd2ddDl3BFXRKdo0LROPTcTCyIKpDadipDHK0B+x+FcMbGwo2r2b0nB9J6QmKn/nQkw9l3pMaTCFM6FnmHTqe5y++YaUiDjCL5qr7iQ9oSqGXKK1e2t6efZixZUV+Hepruw3TJ5M4vXrLz/JY3dSDg7MqeQZq6+t5nzYeT6r9Rm2prYZ+hKvXiXWxweb9/qiKZJ+VuGRG8mt8LmRnqW1e2sGVxvMphubWGN+CeuuXYm4ZkniATUVtz5QFUMuMsp7FKWsS/Hlsa+x/P5rNBYWBI0YSVpc3MtNoLqTCgzBccHMPT2X+sXr06F0h0z9EYt/RWNmhm2fPkrDIzdSxS6gUW9DgEFVBtHGvQ2zT83mRp/6GBQxIXjTTWTUPX2LVuhQP5G5SBHDIkxrOI3IxEim+P1M8R9mkBwQQPCkSVmmI84S1Z2U75FSMuX4FCSSL+t+mWFfASD59m2id+zA5t3eGFhbK43XdoA2qVBFI70IIQST60/Gy9aLceemYDCwJwnhxjxc+L2+RSt0qIohl6lgV4Eh1Yew+/Zu9hULw37IEKI2biJq3UvuN6jupHzPieAT+AT68EnVT3C1zFxgJ2LJUoShIbbvP1Wg8NJ6sCwOrrXzUNL8j6mhKbObzkaj0TCm2AlMXA0IXblPNyV0VV4aVTHkAf0r9qdGsRpM+28aqf26YPZWXYInT3m5/YZH7qRL61V3Uj5ESsmCswsoZlaMd8u/m6k/JSSEhxs2ULR7tyc1nJNiwX+PUndBdSNlwsXChRkNZ3Aj6iZ/dS+O1KYR/NWEl7eyVXKM+qnMAww0BkypP4VUmcrE/yZRfMYMNJavsN9QoTNE+EHYtdwXVuWV+C/4P06Hnuajyh9hYmCSqT/yt98hLQ3bDz580ui3S3EjVeiUabyKQj2XegyrPow1Re5zp1YSsQePELNDTcedV6iKIY9ws3JjZM2RHL13lA0PDuDyw0ySb93i/rffvvhJqHxHQKihe/mMp62mxi0fAAAgAElEQVSFbh7dMvVro6J48M8/WHdoj7HrU4UJr2xSciOVeCsPpS14fFj5Q5q5NePzRuZonY0JnjyF1Adq1tW8QFUMecg7nu9Qx6kOM0/O5GElV+yHDCZ602Ye/vuCZHmWTuBWW1UM+Yzj949zJvQMAyoPwNjAOFP/w3/+QcbHY9u//5PGlEQlN5JX+0J7qO1l0QgN3zX4DhdjS75vk4w2KoqQqVP1LVahQFUMeYhGaJhUfxJCCL468hW2Az/GvN5bhEz5jsRrL9hvKN8Jgi9AZEDeCKuSLY+sBSdzJ7p6dM3cn5JC5Iq/MKtbF1MvrycdN/ZBSlyBreuc11gYWzDHezz+xTQcbGRB9KbNxOzfr2+x3nhUxZDHFLcozljvsZwMPskqvzXKfoOVJUEjRqCNzWa/oXx6bPyVzXkjqEq2HL13lLNhZ59rLUTv3EVqcDC27/fL2HFlE5hav/EptnVJGc+OTI6TLKwVS7SbLcHfTEQbHa1vsd5odKIYhBBthBDXhBD+QojPs+gfJYS4LIQ4L4TYK4Qo+VSfVghxNv1VKHwlXT260sClAXNOzSHIOA6XmbNIvn2b4IkTn7/fYOMOTlVUxZAPkFLy85mfKW5enLfLZj6HIKUkctkyjN3dldTaj9CmwLVt4NkODDMrE5XnIASty3Skc0Ii01vEkBoRTsi06fqW6o0mx4pBCGEAzAfaAhWA3kKICs8MOwN4Sym
rAP8CM57qS5BSVkt/FYowDSEEE9+aiJGBEV8e/hLTWjVxGDaU6C1bst9vqNAJAk9AtHoSVJ/4BPpwMeIiA6sOxMjAKFN/wunTJF68iO37/ZTU2o8IOAiJUaob6XUo34lhERHcK26Eb3M3otatI/bQIX1L9caiC4uhNuAvpbwppUwGVgGdnx4gpdwvpXyUd/o4kPkUUCHD0dyR8bXHczbsLMsvL8du4EDM6tYl9IeZaB8+zPqi8ul68+rWvBNUJQNpMo35Z+fjZulGxzJZf8FHLvsDjbU11p07Z+y4sgmMzKFMszyQ9A2jRF3si9gx0MCB2VUD0ZZw5v5XX6ONjdW3ZG8kulAMLsDdp94Hprc9jw+B7U+9NxVC+AohjgshnptmUgjxcfo437CwsJxJnE/oULoDTd2a8tOZn7gZHYDjF+NJi40l/JfFWV/g4An2nmo9XD2y985erkZeZVDVQZmypwKkhIYSs3cvNj17oDEze9KRpoWr28CjJRgVyUOJ3xA0BuDZjj63LuBU1I1FHQxJDQ0ldMYP+pbsjSRPN5+FEH0Bb+Dpv2ZJKaU38C4wRwhRJqtrpZSLpZTeUkpvh0cnSAs4Qgi+futrzIzM+PLwlxiWLY11ly48WLGClKCgrC8q3xFuH4G4iLwVVgVtmpYFZxfgbuVO+1LtsxwTvW0bpKVh/fYzew+BJyEuVHUj5YTyHTFOjmWsS0t8rO8T0rE2D9esIe7YMX1L9sahC8UQBLg99d41vS0DQogWwASgk5Qy6VG7lDIo/d+bwAGgug5kKjDYF7FnQt0JXIy4yG8Xf8Nh+DDQaAibNy/rC8p3UMoe+u3MW0FV2HlrJ/4P/RlcbTAGzzmDEL15C6YVK2JSunTGjqtbQGOkWAwqr0epRmBsSZOQm9R1rsvEitcwKOHG/S+/evmMxSovhS4Uw0nAQwhRSghhDPQCMkQXCSGqA7+gKIXQp9pthBAm6T/bA/WBQpdGtI17G9q4t2HhuYXcNI7Ctt97RG3aTOLVq5kHO1cDKxe4siXvBS3EpKalsvDcQsoWLUtr99ZZjkm6eZPES5ew6vhM2m0plb9X6cZKqKrK62FoAh4tEde2M7bmaB4Qx4H3KpJy7x6hs+foW7o3ihwrBillKjAU2AlcAdZIKS8JISYJIR5FGf0AWAD/PBOWWh7wFUKcA/YD06SUhU4xAEyoMwFrY2smHJ6AVf//obGyInTWj5kHCqGcmr2xD5JfsY60ymuz5eYWbkXfYmj1oWhE1rdN1ObNoNFg1a5dxo7QK/AgQPm7qeSM8h0gPpxysZF09+jOArkfTbf2PFi5MusHKZXXQhTEjIXe3t7S19dX32LonP139jN8/3AGVhlIL18Twmb9iPua1RSpUiXjwJs+8GcneOevJwffVHKNFG0KHTd0xNrEmlXtV2WqtwDK2YUbLVth4OZG5MSZ3AiLIzAynrsPEqgftJSu0cvpbLKUlCIOlClmQVkHC8oUs6C0vTnu9uZYmBjq4TcrgCRGww9loNYAIpuMof269tQ1r8QnUy9iUrYsJZb/meXfR0VBCHEqfU83W9RPYz6iaYmmdCrTiSUXltC09RKMllgTvnARbgsXZBxYsh6YFlX81qpiyHXW+68nKDYoUxGexBQt10NiuBocQ+ixEzQPDGSGU0N2LVQ2QzUCnK2LMDT1KDdNylO2dBmiElK4EBjFtgv3M2RRt7cwxsXGDNeiRahQ3IrO1YrjamP2rCgqplZQqjFc3YJt6+8YWGUgs07Nok//3pjPXk70tm1Yt1cts5yiWgz5jOjkaN7e+DaWRpYsCmzOg58XUGr9OkzLl884cN1AuL4Dxt4AA1W/5xZJ2iTarWuHi7kLU+os5NSdB5y+/ZDTdx5wNTgGbZpy/3x6YT3Nbp9k9+RlVCnnjKeTJc7WRTCODYQ5laHlJKj/6eN5E1O03I6IJyA8lhthcdyNjCfoYQKBDxIICFc2UuuVsaN7TVfaVnKmiLGacO8xp5b
B5k9h0GGSHTzpsrELJhgxa7kB2shIymzfljFUWOUxqsVQQLEytuLbet/yyZ5PWFOlDq0tLAhf9Auuc5/ZXPNqD+dXwZ2jSrSGis4JjU5k5vE/CI0P5eHt7jQ6eAAAc2MDqpUoyqDGpalU3BpP+yKkvD0Fi1YtGPP2M0F1V7cp/3pltOxMjQzwdLLE08ky07qBD+JZdzqIf08FMmrNOb7eeIn2lZ3p4e1KzZI2qqvEsx1sHgFXtmDsVJnR3qMZsX8EF/u9h8f43wn/ZTHFRo7Qt5QFGlUx5EMauDSgm0c3lvqvounb7UhbsZEkf39MypZ9MqhsczA0VaJdVMWgU3ZdCmb+fn/OBT7ArPRqjIQb9V3qUKuBHd4lbSjnaImB5smXc8y+/QQ+fJg5GgkUd59DebDL8nhOlrjamDG8uQdDm5bl5K1I/jkVyObz91jte5dS9uZ0r+nK29VdKF60kB6UsygGbnWUDABNx9PMrRm1nGrxw4PNLG/Xisg//sC2b58nFfNUXhk1u2o+ZWytsTiZOTG55FmEqWnm09DG6akVrm5VS37qiMAH8Xz0x0k+Xn6K2KRUejaMx8AkjO+aDWZe7xq8V7ck5Z2tMigFgOgtmzEoWhSLBg0yThgfCbePvnY0kkYjqFPajpk9qnJyQgt+6F6FYpYm/LDzGvWn7+P9306w9fx9klK1r/srF1y82kPIBXhwGyEEn9X6jKikKLY0s0SmpBC++Fd9S1igURVDPsXcyJxJ9SdxRRuIX5PSRG/dSvLt2xkHebWH6EC4f04/Qr4hRCem8OPu67T88SBH/CP4op0XO0Y04oHRHooVKfbccwsA2tg4Yvbtx7JtG4TRMykyru8EqQWvdllf/AqYmxjSw9uN1QPfwmdsE4Y2LYtfSAxDVp7mran72HSukCVWfKRsrynZdbxsvejq0ZWlkVsxaN+Ch6tWkRIcrEcBCzaqYsjH1HGuQy/PXkwvexVpaED44meshnJtQGiUVM4qL01SqpbgqEQu3Yti/n5/Gk7fz7y9fjTzKsae0Y35uFEZAqL9OX7/OL3L984yg+ojYvbsRiYmYt0xi1QX17aCpTM46/Ywf0k7c0a38uTQuGb8+UFtStqZMfzvMwz/+wxR8Sk6XSvfYlcGHLwUV106Q6sPxcTQhN+9Y5FA+KJF+pOvgKPuMeRzRtYcyZF7RzhcM4KGGzfhMHgwRi7pOQrN7dN9rdug6Rf6FTSf8yAumU3n7rH2dCDnA6My9DX3KsbIluWo5PLkVPKKyysoYliEHuV6ZDtv9OYtGLm4UKT6M1/+KYngvw+q9gJN7jx/GWgEjco5UK+MHQsP3GDuXj9OBEQys0dVGnjY58qa+QrPdnBkruKyM7PFvog9AyoPYM7pOfRs24CHa9dh99GAjPW2VV4K1WLI55gZmfFFnS9YWSOONNIIX7Ik4wDPdo99rSqZ8Q+NZeTqs9T+fg/fbLpEqlYyvLkH371diUV9a7BzRCOW/q9WBqUQkRDB1ptb6VSmE9Ymz09hkRoWRtyxY1h16JA5UijARynhqQM30oswNNAwrLkH6wfXx9zEgL5L/2PipkskJL/hew9e7RVXnd/ux03vVXgPN0s3fqyo7D2EP3sGSOWlUBVDAaB+8fqULlebw1WNefjvWlJCQp90Pva1qu6kp7kWHMOwv8/QcrYPOy4G07duSbZ/2pBtnzZkVMty9KlTkjaVnLMMF13nt47ktGT6lO+T7RqPM6k+LxrJ2BLcG+rqV3ohlV2t2Tq8If+r586yo7fo8NMhzt59Tm2PN4HiNcDCKYM7ydjAmLHeYznLXYJbViVqw8bMe3MqL6RQKYYHccmkaNP0LcYrI4Tg0xqfsqZWClKbSuRvS590Pva1qsV7pJQc9gun328naD3nIPuuhDCocRkOj2vKNx0rUt7Z6oVzpMk01vqtpbZTbUpZl8p2bNTmLZiUL58xjBggLQ2u7QCPFkritzzE1MiAiZ0qsuLDOsQna+m64Aj
Ttl8lMeUNtB40GvBsC/57FdddOk3cmlCveD2meV4HI0PC5s/Xo5AFk0KlGD5be552cw9x9Ea4vkV5Zao6VKVy1RYcqWRI5OrVpIY/9Tt4tlPCIuMj9Segnrj3MIG/T9xh1JqzNPphP32X/sfle9GMaVWOw+OaMa6NF3YWL//lfPzecYJig+hernu245JuBpB48SLWHbKwFoJ8ldoLnvpLzdDAw56dIxvRo6Ybi3xu0OGnw2w5f+/NC2316qC47AIOPm4SQjCu1jhCTJO42rQU0Zu3kHTjhh6FLHgUKsXQq5Ybiala3v31P4b9fYbgqMQXX5SPGF59OGvfApmUTMTvvz/pyMLXWhg47BdOix99GL/uAgeuhVHeyYoZ3apweFxThjbzwMbc+JXn/NfvX4qaFKV5iebZjovatFHJpJqVYri6BTSGeq+9YGVqxPTuVfjjg9okpWoZuvIMdb/fy6TNl/ELidGrbDqjVEPFZXc1Yxr60kVL08urFzPL+SNNTQhXrYZXolAphublHdk9sjGfNvdg56Vgms86wK8HbxYY91LpoqWpU+dtjlTUEPnXSlIj0y2ER77Wa4XHnbTj4n0+WHYSNxszdo1sxKkvW7C4nzc9a7lhavR6eYXCE8LZf2c/nct0xtjg+UpFpqURtWkT5vXqYeRYLPOAq9vAvQEUKfpacuiaxuUcODCmKX9+UJt6ZexZfvwWLWcfpOeiY2w4E1Sw3UyGJkoWgOs7FBfeUwyuNhhDW1uO1C9K9LbtJF67richCx6FSjGA4oMd2bIcu0c2ok5pO77bdoV2cw9x7EbBKJU5qMogNtY3RCYlEvnIanjka/Xbk8HX+qayxvcug/86TUUXK1YPrEs5R0ud5A/a6L+RVJlKt3Ldsh0X7+tL6r37WHfunLkz3A8i/PTqRsqKR6Gt8/vU4Pj45oxv60VoTCIjVp/lral7mbLlMjfCYvUt5uvh1R5iQyDoVIZmS2NLRtUcxdJKYWjNTQn/+Sc9CVjw0IliEEK0EUJcE0L4CyE+z6LfRAixOr3/PyGE+1N949Pbrwkhnn/EVMeUtDPnt//VYkk/bxJStPT+9TifrjpDaHT+/mJ1tnCmfr13OFZeQ8SKFaQ+eKB0eLXP5Gt9E1ly6Caf/Xue+mXt+eujOhQ1e3V3UVY82nT2dvR+8abzxo1ozMywbJGFu+lREIBnW53IlRvYWZgwsHEZ9o1uwooP6/BWGTuWHb1F81k+vLf0P474h1Ogsi57tARhkKXF3LFMR8qWqMbW2gbE7N5DwqVLehCw4JFjxSCEMADmA22BCkBvIUSFZ4Z9CDyQUpYFZgPT06+tgFIKtCLQBliQPl+e0aKCI3tGNWZ4cw+2Xwym2SwflhzK3+6lAVUGsLmhCTIxkchlfyiNpRqBscUbG7YqpWTmzmtM2XqF9pWdWfK+N2bGujufeSL4BHdj7r7QWkhLSCBmx04s27RBUySLJHbXtoFTFSjqlrkvn6HRCBp42LOgT02Ojm/G6JbluHI/hj5L/qPTz0fYczmkYCiIIjbgXv9xeoyn0QgNX9T5gnXVk0g2NyZ8nmo1vAy6sBhqA/5SyptSymRgFfCsjd0ZSP8G41+guVBs/87AKillkpQyAPBPny9PMTUyYFTLcuwa0QhvdxumbL1Ch3mH+e9m/nQv2Rexp1Gjvhz3FEQs/xPtw4eKr7VMM+XmSMu/Su110KZJvtp4kZ/3+9OrlhvzelfHxFC3zw8rr6zExsSGliWz3zCO2buPtLi4rN1IsWFw90SBLOFZzNKUYc09ODyuKVO7ViYmMYWP/vTlg2UnuR0Rp2/xXoxnewi7ChGZo48q2FWgY5V3+LeWllgfH+LPnNGDgLohr/aDdKEYXIC7T70PTG/Lckx6jegowO4lr80z3O3N+f1/tfjlvZrEJqXyzuLjjFx9ltCY/Ode+qDiB2xtYg7xCUT8ka5zvdpDbDDcK7gf/GdJTk1jxOqzrDh+h4GNSzO1a+VM2U1
zyp3oOxy4e4Aenj0wMcg+tDVq40YMiztjViuLWifXdwAyX7uRXoSpkQG9a5dg96jGfNm+PCcCImk5+yAzd14jJjEf52F69H/+HIt5WPVhHHurKHEWhoQVQKtBmyZZcugmDabvJ+hhQq6vV2A2n4UQHwshfIUQvmFhYbm5Dq0rOrFnVGOGNi3L1vP3aT7Th98OB5Caj9xLRU2L0rxJf455CcL/+EOxGjxaPdfXWhBJSNby8XJfNp+7x7g2XoxvWz5XitSsvLoSA40BvTx7ZTsuJTSUuCNHsO7UCZFV/qNr28DaTXElFXCMDDR81LA0+8Y0oV0lJ37e70+THw6w7EgAyan55z54jE1JcKz0pDDSM1ibWDPwrU9ZWzuN+GPHiD95Mo8FfH38QmLovugoU7ZeobKLFQZ5UKhJF4ohCHjaoeqa3pblGCGEIWANRLzktQBIKRdLKb2llN4OeVCAo4ixAWNae7JzZCOql7Rh0pbLdPjpMCdv5Z9DZP0q9GNX06KI+AQifl8GZrZKPejn3BwFiTsR8fRd+h8+18OY2rUynzR5+UI3r0JMcgzr/dbT1r0tDmbZf66it2xVUmB0ysKNlBwPN/YrT65vUIU1RytT5vSqzqah9SnnaMnEzZdpM/cgvvnoPniMZzu4exzisnYBdy3blbstKhJloSF4zpx8v3+SmKJl9u7rtJ93mFvhccx5pxq//a8WTtamub62LhTDScBDCFFKCGGMspm86Zkxm4D303/uDuyTyl9lE9ArPWqpFOABnNCBTDqjlL05f/SvxaK+NYlOSKHHomOMWnOWsJgkfYuGhbEFHVsNUayG5X8oEUqe7SDsCkTe1Ld4r0Viipa5e/xoOduHq/ej+bl3DXrXLpFr663zW0d8ajx9K/R94dioTZswrVoFk9JZRC3d3A+pCcr//xtIFdeirBxQh9//V4vk1DR6/HKMiZsuEZeUqm/RnuDVDmRauksvMwYaAz5r+CX/1oOkU6eJ2ZH1uPzAIb8w2sw5yNy9frSu5MTuUY3pUt0lz8q65lgxpO8ZDAV2AleANVLKS0KISUKITunDlgJ2Qgh/YBTwefq1l4A1wGVgBzBESpnvTtsIIWhTyYk9oxszuEkZNp+7R7NZilmtb/dSj3I9ONjKCRGfSMSyZU+yeRYwq+FuZDzz9/vTavZBZu+5TosKjuwd3YT2VZxzbc3UtFRWXllJTceaVLB7NpAuI4nXrpF09SrWnTplPeDqNjCxVg62vaEIIWjqVYydIxrx/lvu/HHsFq1mH2Tr+fv54+nbuRpYuWQbmVfFoQrm3d8mwEkQNGUy2ujoPBTwxQQ9TGDIytO8t1R5Pl7+YW1+6l0d+1dI66ILRL74g74i3t7e0tfXV2/r3wiLZeKmSxzyC8fLyZIpXSrh7W6rN3m23txKyKix1LltjOe+/Rj+3V4J4euff/capJRcD4ll79UQdl0KeZwFtGZJG0a2KJcn9QR2397NqAOjmNN0zgtTYITM+IHI5cvxOOiDoY1Nxs40LcwsB6WbQPelWV3+RuJ7K5KvNl7iyv1o6pa2felEhbnK1tFwdiV8dhOMsq6JHZEQwZAFbflqSQy2Pd/B+duJeStjFsQnp/LrwQAW+vgD8EnjsgxsXPq1T/E/DyHEKSllFpETz4xTFcPrIaVkx8VgJm25zP2oRLrVcGV8O6881+ygHM4auqQzQ370x7b//3D2ToFDs2DsDWXfIR8Rk5jCX//dYcXx2wQ+UKIrKrtY076KMx2qOONqY5Znsnyw8wPuxd5j69tbMdA8/waUqan4NW1KkSpVcZv/c+YBt4/B722g+29QKftzEG8a2jTJ3yfuMGvXNaISUni3TglGtfTE9jXyVOkE/72woiv0XpVtdNjyy8sJnjqVDiclJVeuxKyGbqvsvQxpaZLjARGsOx3E9gv3iUvW0r6KM1+0K49L0ayVWk55WcWgVnB7TYQQtK3sTGNPB37a58+SQzfZdTmYMa086V27BMa
GeRfwpREa+rYbz5EdH1FvxQoc2s7FUP6g1Byu1jvP5MiOh/HJLD0cwB9HbxGdmEr9snYMaVqWZl7FcLTK/c20Z7nx8AYng08yosaIbJUCQNyx42jDwrHu/Bw30rWtoDGCsvpNmqcPDDSCvnVL0qGKM3P2+LH8+G02n7vPiBYe9K5dQudPvC/EvSGYWCkn0LNRDL28evFuhzXUvx6A0ddfUfqff7I+sKhjpJScufuQLefus+3CfYKjE7EwMaR9FWfeqVWCmiVtXjxJHqBaDDrCPzSWbzZd5Ih/BHbmxrxd3YV3arnh4Zi5EExu8dnKfvSbfBKrXj0oYfEPuNSAXn/l2fpZEZuUym+HA/j14E1iklJpU9GJwU3LUMVVvwnmvv/ve/69/i97euzB1jR7qypo7GfEHjyIx6GDaIyfeRKWEn6qCTbu8N663BO4gHAtOIZJWy5xxD8CewsTPmjgTt+6JbEyfX7dbJ3zT3+4dQhGX4NslP7Re0dZuGgA4/+RmNeujduihbmiHJJStRy/GcmeyyHsuRLC/ahEjA00NCrnQMeqzrSq4EQR47xRoKorSQ9IKfG5HsaqE3fZcyWE1DRJZRdrOlcrTseqxXP9yfha5DV2f/I2TS4JPD9vjNHt9dn6WnOL5NQ0Tt6KZN/VUNafCSIyLplWFRwZ3cozy4ppeU18SjzN/2lOE7cmTG04Ndux2tg4/Bo0wLpLZ5wnTsw8IOwazK8N7WdBrY9yR+AChpSSYzcjWORzk4PXw7A0MWRAo9J80KAUFiZ54KS48C+s/RA+2Akl6mY79NN9nyJ2HmTg5mTMatXSmXK4H5XAvquh7L8axtEb4cQnayliZECjcva0rOBEq4qOeass01FdSXpACEETz2I08SxGeGwSG84EseFsEFO2XuH7bVd4q4wdnau50KaSU658KDxtPVnfpx1p47dy61gUHg7xcNMHPNvofK2nCYtJ4uStSC4GRXEhKIozdx4Sm5SKsaGGRh4ODGtWlqpu+SMFNcDWgK3EpsTyjuc7Lxwbs3s3MjEx67ML8KQOwBsapvo6CCGoV8aeemXsuRgUxby9fvy4+zrLjt5icJMy9K1bMnddTB4tFdfe1a0vVAzjao+jW/AJnMyd6LzqJHc/GYzbwgWvpRxCYxLZfiGYzefu4XtbSW7palOErjVcaOpZjPpl7fPetfaaqBZDHuAfGsums0FsOHuPO5HxGBtqaO5VjM7VXGjq5aDTvD8hcSH8PaglLU+l4tE5FuO6naBzFhumOeBRRNGeK4ppfPbuQ6QEQ43Aw9GSam5FaeZVjPpl7XSa6E4XSCnpvrk7GqFhTYc1L4wLv92/PylB9yizc0fWY39trsTOf7w/lyR+Mzh39yEzd13jkF84rjZF+KyNFx2rOOdeXP6fXSDqLgw79cKh225uY9yhcUyMbkqFBXswq1sHtwUvVg4P45O5fD+aI/7h+FwP42KQEvrq5WRJx6rFaV3RkTIOFnl29uBlUF1J+RApJWfvPmTj2XtsOX+P8NhkLE0Nebu6Cx/UL4W7vblO1lm8fzq1hi/D0MuayrXDX+hrfRlStIp7aHe6n/RupBJRVMXVmhblHWlUzgEvJ8t8/0R0NvQs721/j6/f+poe5XpkOzYlOBj/ps2wHzwYh2FDMw+ICYZZntDsS2g0NpckfrM47BfOd9uucOV+NFXdijKyhQeNyzno/svzxK+wbQwMOQkO5V44/ItDX7A1YCt/ig8x/n4R5m/VxXXBAjSmpsQnp3IzLI7rITFcC47hSnAM14KjCYlWDrkaaAQ1S9jQ2NOBlhUcKZeH+4qviqoY8jmp2jSO3Ihg/elAtl0IJiUtjVYVHHn/LXfqlLbLUaK4+JR4lgxpQvNDMZRsHYrFqBeb1FmRnJrGvqshbL8YzP6roUQnKu6h+mXsaFHBkeZejnlyPF+XTDg8gb139rKvxz7MjLIPjY1YsoTQmbMos3MHxiVLZh7g+xtsGQmfHAPH7A/IqTx
BmyZZfyaIH3dd415UIhWLWzGkaVlaV3TSXYLEqECYXRFaTIQGI184PDY5lh6be6CVWr4J64nNvB8JKFGeafU/4m78k0OsxgYayhazwMvJEs/0V42SNnrZL3gdVMVQgAiNTuSPY7dYcfwOUQkp2FsY06qiEx2rFKdOKVs0r3GzHLy0DfN3RxPvrKXhqHeg1eSXvvZORDwrT9zh36IHVmEAACAASURBVFN3CY9Nxs7cmKZexWhR3pGGHvaY58UGYi6QkJpAk9VNaFOqDd/W+/aF42926ozGzAz3VX9nPWBFd4jwh+Fn3qj8SHlFcmoaG84EsdDnBgHhcThbm9LT242etdx0E8f/SyMwMIaP9jx3SIo2jRMBkRy/GcGBAF8CTH5AG1eW+kcrMfLUGgI9qnJj+NeUdCyKp5MF7nbmGBoUmNyjmVAVQwEkIVnL/muhbL1wn31XQklI0eJSVNm86lbD9ZVdTcvHdsJ7sx8mPYwpPensC7+8LgZFsdDnBtsv3EcIQTOvYrxbuwSNyjnoPNW1PtgesJ3PDn7Gb61/o5ZTrWzHJl69SkCXt3H65mtsemdxFiQpBmaUhtofQ+vvckniwoE2TbL7cjArT9zlkF8YAqVO9bt1StLU0+H1v4h9ZsD+72H0VbB0ytDlHxrLP753WXs6iPDYJDT/b+/M42O89j/+PjOZ7AuJWBMSaoskYoutVFBVa9XSUkXpQm9v3bb8qLboTuuWSylaLVpVpVUtilIuVUqQWhJLkJBESEL2feb8/nhGrpBkEjPJZHner9e8MvM85znneyYz833OOd/z+Qplo6V7g2OEpn/BeL9JTLrUkPg5c3Dp359G/16A0FbuadLSoEYlVUEcbLUMCGjAgIAGZOXq2RUez6ZjMXy6N5Ilv0fSyac2Izt482hAfVxKMXTt93+Lid7zKFF/ZeObcA5Rt9U9ZeJTstkdcZ3tp67x58UknO1seL5nMyZ086ly00Sm2HppK/Uc69GhXgeTZVO2/Aw6HS79i4noitwN+lw1GskCaDWC/v4N6O/fgKs3M9lw9Crfh17lubWh1HO147F2jRgU0BD/Rq5lW4toNQj2vq9oJ3WcCEDEtVQW7DzHnrM3sNEoNz/DO3jRrZkHLvY6pOzO3EM3WRO+iqBeC2k3YwY35s/nmqMjDd57t2i59WqIOmKoAlxLyWLziVg2hcZwKVHJpuXuZEsDN3vqONuh0wo0QmCjFdhoNNhoBbZaDQ62Whru+YAHdxzn0JgWpHVbQJ7eQFJGDjdSc4i5lcW562kANPFwZFRHb8Z2aYKbQ9WYLy0LN7Nv0vv73oxvM55XOpQ85yz1eiJ7hWAfGFi0BAbApklwaR9MO2/2wr7KveTpDfx+9gbfHbnCgQuJ5BskTTwc6d+mPiGt6tKhSW10RYwk8vQG0rLzyc7Tk52bT6NvupPh1IQ/uqxgT8R1fv47TtlX0aMpTwY3xtPlXgmbHH0Oz+x4hsjkSNYPXI/r17+SuHQp7hMmUG/mjIrofrmhjhiqEQ3cHHix1wNMeagZx68kc+hiInEp2cSnZJOYnkO+XmKQknyDJF9vIE8vydUbyM7Vk2c7nLaOx3H74wLv5h4FvRseTnbUdbGjYS17hrZrSD+/yhdWZ2l2XN6BXuoZ1HSQybIZhw6Tn5BQvJJqfi5c2AV+Q1WnUE7otBoeaVOfR9rU51ZGLrvC49l68hpfHrzMiv2XcLG3wbeOE0IINAIyc/QkpOdwMyO3UD2zbPyZkLKDWesPkq9zZvJDzZjcsxlujsXf/Nhp7VjYayGjto7i1X2v8u0L31I7OZmbq1djU7cuHhOfKe/uWx3VMVQhhBB0aFK7THoqUkouvLkRlx/CmDR0M2+8aDp2vzqy7dI2Wrm3onnt5ibLpmzZgsbVFedeDxVdIGo/5KQqUxUq5U5tJ1ue6NSYJzo1Jj0nnz8uJLD3bAI30rIxSDBIiaezHR19auPpYoebgw4HnRZ7nZaGqWC7dxu7B+f
i3PHhUu+8rudUj/k95/PCby/w9uG3mff6B+QnJnLjo4+w8fTEbXD1/t+rjqGaI4TggSmzOL19BM1/OcPW/lsZ3Gywtc2qUKJTozmZeJJpHaeZLKtPzyDtt99we2woGrtilHLPbgOdkyKzrVKhONvZFKxHlApDfTjiSf243WD3VJna6tKgC/8I+gdLTiyhQ90OjPxoPldv3SJu1ixsPNxx6tbtPnpQNagZKyk1HI1XAA3b2xMYJdm06V0SMssvZ3ZlZOulrQgEj/oWr7Z5m7SdO5HZ2dR67LGiCxgMSlKe5n1BV70W56slGq0SIHDhN8gve9bFZwOepUejHsw/Op+ItEi8ln6Kna8vMS9PJfvc+XIwuHJglmMQQrgLIX4TQlww/r1njkMIESSEOCSEOCOEOCmEeOKOc6uFEJeFEGHGR5A59qgUj/vwwWBvYOC+DF4/8Dp6Q6VLlFcuGKSBXy7+QucGnanrWNdk+ZSffsK2SRPs27YtukDsMUiPV6eRqhKtBkFuGlw+UOZLNULDhz0+xMPBg+n7p5NlJ/BesRyNgwNXJ08m7/qNcjDY+pg7YpgJ7JFSNgf2GF/fTSYwTkrZBugPLBJC3KmoNl1KGWR8hJlpj0oxaAKHUbdVGgGXDSQfPcznpz63tkkVwrHrx4hNj2XoA8WI4N1BbkwsmUeP4jbsseLXYc7+AhobaN7PwpaqlBu+PcHW+X+Ch2XEzc6Nj3p+RFx6HO8cfgeb+vXxXrEcQ0oKV6dMxpCRYWGDrY+5jmEosMb4fA1wz/hbSnleSnnB+DwOuAF4mtmuSllp2I7aQW7YuOh4+aALy8OWcTT+qLWtKne2RG7BSedkMnUnQMrPWwBwG1zMGoyUELFVSQbjUHnUYlVMoLNXFFfPbVemAu+DdnXb8WLQi/x6+Vd+ivwJez8/Gi1aSM6588S88goyP9/CRlsXcx1DPSnlNePzeKBeSYWFEMGALXDxjsPvG6eYFgohKj4vZk1Bo0ETMADPgGQ8Lycz5KI7M/fPJCkrydqWlRuZeZnsit5Ff5/+ONiULLEgpSRlyxYcO3dG16hR0YUSzsHNi9BanUaqcrQaBOnXIebIfVcxyX8Snet35sMjH3Ip+RLOPXtSf/ZsMvYfIP6996iKe8KKw6RjEELsFkKcLuJRaGwulXel2HdGCNEA+Bp4Rkp5222/DrQCOgHuQLG7R4QQzwshQoUQoQkJNWvx1GK0GoSbdwr2D3gxem8+2enJTNw5kfiMeGtbVi7sit5FVn5WqaaRsk6EkRd9BbehJZSN+EX523KghSxUqTBaPKLoJt3+H94HWo2WD3t8iIONA9P2TyM7P5vaT4zC47lnSf5uAze//NKCBlsXk45BStlXSulfxGMLcN34g3/7h7/IlRghhCuwDXhDSnn4jrqvSYUc4CsguAQ7VkopO0opO3p6qjNR94XPgwiHWtTr4wEJSSy73o8bmTcYu30sl5IvWds6i7MlcgtNXJsQ5Gk6piH5h00IR0dc+pWwdhCxBbw7g2spQyVVKg92LtCsN4T/rEwJ3ieejp681/09Lty6wILQBcqxV17B5dH+3Ph4Aak7dljKYqti7lTSz8B44/PxwJa7CwghbIHNwFop5aa7zt12KgJlfeK0mfaolIRWB60G4pj9J679H8Huu1/5Mugj8g35jNsxjr8T/ra2hRbjatpVQq+HMrTZUJMb+vSpqaRu247bwIFonYsRKrx5GeJPQetidkOrVH5aD4aUK3DNvM95D68eTGgzgQ3nNvBb9G8IjYaG8+bh0K4dcTNmknXypIUMth7mOoZ5wMNCiAtAX+NrhBAdhRBfGMuMAnoCE4oIS10nhDgFnALqAO+ZaY+KKVoPgZwU6o7sBkLgvPhb1vZfi6utK5N2TmJP9B5rW2gRfrn4CwJRqs18KT//ouxdeKKEVJ+3pyDU9YWqS8sBILQQ8bPZVb3c7mX8PfyZc3AOsemxaOzs8Pp0CTZ16nD1xX+QFxdnAYO
th1mOQUqZJKXsI6Vsbpxyumk8HiqlfNb4/Bsppe6OkNSCsFQpZW8pZYBxamqslDLd/C6plEizELB1QZd0AM9/TSV93z7c/hvGNwO+oaV7S17Z9wpfh39tbSvNIk+fxw/nf6Brw67Ud6pfYlkpJckbNmDfpg0O/m2KLxjxMzRoC7V9LGusSsXh6A4+D5q1znAbnVbHRw99hEQyY/8M8gx52Hh44L1iOTI7m6tTXkSfXnXDWNWdzzUNGztlIe7sNtzHjMYhKIj49z/ANU3Pqn6r6NO4Dx8d/YgP/vqAfEPVDMHbdnkbN7JuMM5vnMmyWSfCyLlwgVpPjCq+UGocxBxVp5GqA60HQ+J5uHHW7Kq8XbyZ03UOfyf8zbKwZQDYPfAAjRYuJCcykrjp05H3GR5rbVTHUBPxGwKZSYjYv2jwwfvIrCyuvf02dlo7Fjy0gPF+41l/dj0v7XmJtNw0a1tbJqSUrDmzhha1W9CtoWktm+QN36FxcsJtYAmRRhHGjVGqY6j63N6xboFRA0B/3/4Mbz6cVadW8WfcnwA493iQejNnkr53L4mfLrVIOxWN6hhqIg/0BRsHCP8Zu6ZN8Xz5n6Tv3kPq1m1oNVqmdZrGnK5z+OvaXzy9/WmupF6xtsWl5kDsASKTI5nQZoLpRefkZFJ/3YHrkMFonErIjhfxM3i2KlVSeZVKjmsDJbLMAusMt5kRPANfN19mHZhFYlYiALXHPoXb44+TuGwZqbt2WaytikJ1DDURWyd4oI8iEWAw4D5hgjKl9Pbb5MbEADCixQiWP7ycG1k3GLZlGJ8c+6RKjB5Wn1lNfaf69PctJvPaHaRs2YLMzaV2SYvOGYkQfVAdLVQnWg+G+JNw0zIh2g42Dix4aAHpeenMOjALgzQghKD+nNnYBwYSN/N1ss9XLcE91THUVPyGQto1iA1F2NjQcMHHAMS9Ng2ZlwdA5wad+XHIj/T37c9Xp79i4I8DWX92PXmGPGtaXixnEs9wNP4oY1uPRacpOQudNBi49e16HNq2xb7VvSlPCzi7FaRB+TFRqR74GTcxnvnJYlU2r92cmcEzOXTtEF+cUgIyNXZ2eC1ZgsbJkZh/vIQ+Odli7ZU3qmOoqdzeCWr8cth6edHgnbfJ+vtvEu6YF63vVJ/3H3yfDYM20KxWMz746wMe3/I4+67uq1QSAFJKVpxcgYvOhREtRpgsn3HoELnR0dR+akzJBc9sBvdmUD/AQpaqWJ1ajaFRBwi3nGMAGN58OAN8B7A0bCmh8UrqYV29ungtXkxefDyxr01D6quGqrHqGGoq9m7KWkP4lgJhMdcBA3Ab/jhJK1eScehQoeJ+Hn58+ciX/CfkPwD88/d/Mn7HeHZG7awUI4jvzn3H3qt7mRgwESddCesFRm6tX4+2dm1cHnmk+EIZiXB5P7QZBjUw6121ps0wZaObhaaTQEmKNbvrbBq7NGbG/hnczL4JgGO7dtSf/RYZBw+SsHChxdorT1THUJNpMwxSY5RQTCP133gDW19fYl+bRt61a4WKCyHo3bg3Pw79kdeDX+dG5g2m/Xca/X/oz+LjiwmNDyVPX/FO4sSNE3x05CN6efViov9Ek+Xz4uJI/30vtUaMKD5LGyiRK9IAbYpJ2qNSdSmH6SQAJ50TCx5aQHJOMq8feB2DURau9siR1Br9JElfrCJl6zaLtlkeiMo0HVBaOnbsKENDQ61tRtUnOxU+fgA6ToRH5xUczrl0iaiRo7D19aXJN1+jsS86U5neoOdA7AG+jfiWv+L/wiANONg4EOgZSPNazWlWqxnNajWjsUtj3O3dyyXXdEJmAqO2jsLRxpH1g9bjautq8pobixaRtGIlzX77DVuvYpRUAdYMgdRYeClUHTFURz7vA4Y8eGG/xaveeH4j7xx6h3+2+yfPBz4PgMzNJfqZiWSfOkXjL1fh2LGjxds1hRDimJTSZMOqY6jpfPeUkpXslXDQ/G8AmbZnDzH/eAm3xx6jwYcfmPx
RT81N5Wj8UQ7HHeZU4ikupVwiKz+r4LyzzhlvF2+auDYp+NvQuSGNnBtR17EuNprC6cez87NJzkkmJSeF1NxUUnNSyczPJFufTXZ+NtczrnMl7Qpnks6QlpvGugHraF67ucnuytxcLoT0xiEwEO/PlhVfMD0B/t0CerwGvd80Wa9KFeTPJbDrTXj5BLg3tWjVUkpmHpjJjqgdfP7w5wQ3UPRB82/dInr0GPJv3cJn3TfYPfCARds1heoYVErHqU3wwyR4Zgc06VroVMLiJSQuW0a9N9/EfWzZEqkbpIG49DgupVziatpVolOjuZJ6hejUaOIy4gqG2AACgb2NPXZaO2w1tqTlpRVyKkVhp7XD28UbLxcvxrQaQ9eGXUssf5uUrduImzYN789X4tyjR/EFj66Cba/ClD+hXglSGSpVl+QrsCgA+syBHq9avPrMvEye3PYkablpbBy8kToOdQDIjYkh6snRCFsdPuu/Q1fPdMpZS6E6BpXSkZMOHzeD9uNhwEeFTkmDgZh/vET6/v14r1yBc/fuFmkyT59HbHoscRlxXEu/RnxmPFl5WWTrs8kz5OGic6GWfS1q2dXCzc4NV1tXXG1dcdQ5Yq+1x97GHhdbFzSi7Etk0ePGk3ftGs127kBoSrh+9SBIi4eXjpb7NFJeXh4xMTFkZ2eXazsqRZB2HZDgUrKm1v2SZ8gjMTMRnVaHu717wWdW5uWRn5iI0GrR1qlT8mfxPrC3t8fLywudrnDYdmkdg42pAirVHDtnJX9x+Bbo/yFotAWnhEZDw48/Jnr0aGL/9Qo+G77Drqn5Q26dVoePmw8+bj5m11UWcq9cIfPIETz/NbXkL2LadWVTW49pFbK2EBMTg4uLCz4+PuWyDqNSAukeyjpSXV+wKXotzVxSc1KJSYvB3saeJq5N0Bq/Y/r0dHKjo9HY2mHr64OwsczPsZSSpKQkYmJi8PX1va861KgkFSU6KT0erhy+55TW2Qmvzz5D6HRcnTKF/Fu3rGCgZUjevBmEwO0xE1FGET9XaDRSdnY2Hh4eqlOwBvbG3N1Z5bf5zNXOFS8XL7L12USnRqM3KHsZtM7O2DZpgiE3l9zLly2WN1oIgYeHh1kjUNUxqCib3XSOcHpTkadtvRrh9emn5Mdd4+pzz5N3vchEfZUaqdeTsvknnB58EF19E9MGpzZCXb8KXVtQnYKVsLEFW2fIumlWZjdTuNq54u3iTbY+m6jUqIKwbsU5NMaQm0fO5csYcnMt0p65nyfVMago2kktBygx3cXsQ3Bs345G//mPEso6YgRZYWEVbKR5ZPx5iPz4eGoNf7zkgrei4OpfEGB697RK+dGrVy9MrSMuWrSIzMxM8xtzqA35OZBXcsBDUUyYMIFNm5QbqmeffZbw8PBiyx778xhxp+PI1edyKeUSi5cuZu3atYpz8GkC+fnkXryIPsP6eRzMcgxCCHchxG9CiAvGv7WLKae/I3vbz3cc9xVC/CWEiBRCbDCmAVWxBoGjlLumyOIzuLn0DsFn/XqEvT3RT48j+cfNFWigeST/+ANaNzece/cuueDpH5S//qpjuBt9JZNzsJhjsK8FCMhSpknz73NK54svvsDPz6/Y8/v27SPsaBi+br4IBH2f7MvQJ5SNdlonJ2ybNgWtltyoKKtP2Zo7YpgJ7JFSNgf2GF8XRdYd2dvulKmcDyyUUj4A3AImmWmPyv3SrDc4uCvTKCVg37IFvhu/x6FjB67NmsWNTxZW+mQk+bdukb57D65DhqCxNXHvcWoTeHeB2k0qxrhKQFRUFK1ateKpp56idevWjBgxouAH18fHhxkzZtC+fXs2btzIxYsX6d+/Px06dKBHjx6cPVt0whtnZ+eC55s2bWLChAmAcoc9efJkOnbsSIsWLdi6Vcl1kZWVxZNPPknr1q0ZNmwYWVn/u3ufMmUKHTt2pE2bNsyZMweAxYsXExcXR0hICCEhIQDs2rWLrl270r59e0aOHEl6+r0JIXv
16sXUqVMJCgrC39+fI0eOgNaGuYu+5OlJk+nevTtPP/00er2e6dOn06lTJwIDA1mxYgWgLOy+9NJLtGzZkr59+3Ljxo1Cdd8e5ezYsYP27dvTtm1b+vTpQ1RUFMuXL2fhwoV06diFuFNxrPh4Be/Oe5eEzAROnDhBt4ceInj4cJ589VVuhIeTGxdHr169mDFjBsHBwbRo0YIDBw4AcObMGYKDgwkKCiIwMJALFy6U/R9fAuYugw8FehmfrwH2ATNKc6FQJsF6A7dVzNYAc4HPzLRJ5X7Q6pRF6L/XKyGsds7FF61Vi8YrVxL/7nskrVxJ7pUrNJz3YbE7pK1N6tZtyLw809NI18/AjXAYsKBiDCuCt385Q3hcqkXr9GvoypzBJa+XnDt3jlWrVtG9e3cmTpzIsmXLmDZtGgAeHh4cP34cgD59+rB8+XKaN2/OX3/9xYsvvsjvv/9eJnuioqI4cuQIFy9eJCQkhMjISD777DMcHR2JiIjg5MmTtG/fvqD8+++/j7u7O3q9nj59+nDy5ElefvllPvnkE/bu3UudOnVITEzkvffeY/fu3Tg5OTF//nw++eQTZs+efU/7mZmZhIWFsX//fiZOnMjp06dB50D4+Uj+2L8fh1qerFy5Ejc3N44ePUpOTg7du3enX79+nDhxgnPnzhEeHs7169fx8/Nj4sTCMiwJCQk899xz7N+/H19fX27evIm7uzuTJ0/G2dm54H3d+/te9DZ6bmTeYMzTY1j26TJCeoXw1ltvMW/1auZPnYohK4u8nByOHDnC9u3befvtt9m9ezfLly9n6tSpPPXUU+Tm5lp8NGfuiKGelPK2oE48UK+YcvZCiFAhxGEhxO1QDw8gWUp5e9wWAxSrTyCEeN5YR2hCQoKZZqsUScBIyMuEs6a1XIROR/2351J3+nTSdu4keuzTBbkcKhvJP/6IvZ9fyfLaACe/V5LFtxlWMYZVIry9velu3KcyduxY/vjjj4JzTxjzVaSnp/Pnn38ycuRIgoKCeOGFF7h2l55WaRg1ahQajYbmzZvTtGlTzp49y/79+xk7diwAgYGBBAYGFpT//vvvad++Pe3atePMmTNFzuMfPnyY8PBwunfvTlBQEGvWrCE6OrrI9kePHg1Az549SU1NJTk5GWzsGNKvFw4okTy7du1i7dq1BAUF0blzZ5KSkrhw4QL79+9n9OjRaLVaGjZsSO8ipiYPHz5Mz549C0JF3d3di7RDCIGbrRt2eXYkJyfj086HfEM+EyZM4I9jx7D19gYpGdQpGH1GBh06dCAqKgqArl278sEHHzB//nyio6NxcHAo5btfOkyOGIQQu4GiwjjeuPOFlFIKIYpb1m8ipYwVQjQFfhdCnAJSymKolHIlsBKUDW5luVallHh3BrfGynRS2xKS1xgRQuAxaSK2TRoTN/N1Lg97nAbvv4drv34VYGzpyA4PJycignqz3yq5oMGgrC806w1OdSrGuCIwdWdfXtwdxXLnaydjdjuDwUCtWrUIuyvwQK/X06FDBwCGDBnCO++8U+j6u8MmS2rrbi5fvsyCBQs4evQotWvXZsKECUWGYUopefjhh1m/fn1J3Sy+faHBybW2ss7g1ggpJUuWLOGRu9R3t2/fbrL+siCEwNPRExthQ1Z+FpdTLpOXb4xYcnNDY2+PnYM9uVFRSJ2uYP1jzJgxdO7cmW3btjFgwABWrFhRpJO6X0yOGKSUfaWU/kU8tgDXhRANjB1sABQZxyiljDX+vYQy3dQOSAJqCSFuOycvINbsHqncPxqNEo1z8XdFK6iUuPTti+/mH7H18SH25alce2s2efHx5Who6Un+4UeErW3JOZ1BiURKuaqMmmogV65c4ZBRav3bb7/lwQcfvKeMq6srvr6+bNyorENJKfn777/RarWEhYURFhbGO++8A0C9evWIiIjAYDCweXPhIIWNGzdiMBi4ePEily5domXLlvTs2ZNvv/0WgNOnT3Py5EkAUlNTcXJyws3
NjevXr/Prr78W1OPi4kJampJVsEuXLhw8eJDIyEgAMjIyOF9M1rQNGzYA8Mcff+Dm5oabm5tyQucIUg85aTzyyCN89tln5BmTVp0/f56MjAx69uzJhg0b0Ov1XLt2jb17995Tf5cuXdi/fz+XL18G4ObNm/fYeydubm64u7sTezIWvdSzdNVSuj1ozFeu0WDbqBFaZ2fyrl8HvR5pMHDp0iWaNm3Kyy+/zNChQwveL0th7lTSz8B44/PxwJa7Cwghagsh7IzP6wDdgXCpaHHsBUaUdL1KBRMwUvlynClbxJGttzc+677B/ZlnSP7xRyIf7kfcG2+Qc+lyORlqGkNODilbt+Ly8MNob3/5i+Pkd8oPQysTDqSa0rJlS5YuXUrr1q25desWU6ZMKbLcunXrWLVqFW3btqVNmzZs2VL0V3bevHkMGjSIbt260aBBg0LnGjduTHBwMI8++ijLly/H3t6eKVOmkJ6eTuvWrZk9e3bBCKRt27a0a9eOVq1aMWbMmILpLoDnn3+e/v37ExISgqenJ6tXr2b06NEEBgbStWvXYhfG7e3tadeuHZMnT2bVqlX/O2FjBxobyEzi2Wefxc/Pj/bt2+Pv788LL7xAfn4+w4YNo3nz5vj5+TFu3Di6dr1Xo8vTU1mjePzxx2nbtm3BVNzgwYPZvHkzQUFBBYvIt1mzZg1vvf4WI3uN5PyZ84x5eUxBPgeh1aJr3BgbDw+kwYDMyeH777/H39+foKAgTp8+zbhx44rs630jpbzvB8o6wR7gArAbcDce7wh8YXzeDTgF/G38O+mO65sCR4BIYCNgV5p2O3ToIFXKkWXdpVzR674vz7kaI6+9866MCGwrw/3ayOsffyz1mZkWNLB0JG/dKsNbtpLpBw+WXDA3U8oPvKT88YWKMewuwsPDrdLubS5fvizbtGlTIW2NHz9ebty4sULaKoqHHnpIHj16tPgCyVeljD0hZX5exRl1F/n6fBmVEiVPJ5yW19KvSYPBUHBOn5NT6nqK+lwBobIUv7FmjRiklElSyj5SyuZSmXK6aTweKqV81vj8TyllgJSyrfHvqjuuvySlDJZSPiClHCmlzDHHHhULETQG4o7DjYj7utzWqxH133qTB37fg9uwx0j6YhWXBg8hMyn1vAAAGJhJREFU/Y+DFja0ZFJ++BFdw4Y4dulScsGz2yAnVem3Ss3GwR2QkG29fQRajVbJYeLgTlJWElfTrhbIaJgMt7YQ6s5nlXsJHKUMqU98Y1Y1Nh4eNHzvPRqvXaNoLT37LKm7dlnIyJLJi40l49Ah3B5/3LRyZdg6ZdG9yb3z6jUBHx8fJWSzAli9ejUjRlhv8+C+ffvoWFKCHFtHsHGAzJsVZ1QRCCFo4NSABk4NSMtN43LqZXL1lpHLKA2qY1C5F6c60KI/nNxQrERGmaoLDsZ384/Ytw0kbsZMskuQDbAUyZuVlI21hpkQwkuJhYt7oe2ThRIVqdRgHN2VsO0868uguzu408S1CXn6PC6nXCYzzwI7vUuB+k1QKZqgpyAjASJ3W6Q6jb093p9+irZWLa5OebFchfhyIiNJ+vJLnHv2RNeohNSdoCw6IyFodLnZo1LFcDAq+2QlWdcOI862zoqMhhBEpUaRdR+aTmVFdQwqRdP8YXDyNHs66U5sPD3xXv4ZhrQ0Yl58EX1KmbaylAp9ejox/3wZjaMj9Y2hk8UiJYR9C427WTy1o0oVRqsDO1fIvFWuiqtlwd7GnqZuTfGw98C+nPJG3InqGFSKRquDwCfg/A7ISLRYtfYtW9Lw3wvIPn+ey6NGkWOMO7cEUkquvfEmuVeu0OiTf5tOmRgTCkmR6qKzyr04eoAhD3Lu3XdgLWw0NtRzqlchEu2qY1ApnqAxYMhXpCIsiEtICE3WrMaQkUnUqCdI21O8omtpkVKStGIlaTt3UvfVV3EKDjZ90Ym1yt4Fv6Fmt1+VSU5OZtmyZeXezr59+/jzzz/LvR2LYO9q3NNguZu
iqoTqGFSKp14baNgOjq+x+JDasX17fDdtxLZZM2L+8RJxM2be97qDPjmZ2Kn/ImHRIlwe7Y/7xGdMX5Sdoiip+g9XfgRqMGV1DFJKDPehqFulHIPQKIvQ2SlQgdFAlQXVMaiUTMeJkHAWrhyyeNW6+vVp8s3XeDz3HKnbt3Px0UdJXLESQxlSEqYfPMiloY+R9vvv1J32Go0WLCjdUPvk90rkSceJpstWc2bOnMnFixcJCgrilVdeoU+fPrRv356AgICCnc1RUVG0bNmScePG4e/vz9WrV1m1ahUtWrQgODiY5557jpdeeglQ1EWHDx9Op06d6NSpEwcPHiwkO13Uzt9KiaOH8tfKoavWwDLZp1WqL/7DYeebcHQVNOlm8eo1dnbUfe1Vao0cwfWPPiJh4UJufv01HpMmUfvJJ9AUoxqZefw4CUuWkHnoMLY+Pvh89x0O/qUUoJMSQr+EBkHQqL3p8hXJrzMh/pRl66wfAI/OK/b0vHnzOH36NGFhYeTn55OZmYmrqyuJiYl06dKFIUOUFCoXLlxgzZo1dOnShbi4ON59912OHz+Oi4sLvXv3pm3btgBMnTqVV155hQcffJArV67wyCOPEBERcY/sdKXHxh5sXZQ1Nud6UIPSr6qOQaVkbJ2UUM6jqyB9Hjh7lk8zjRvj/emnZB49SsLSZdyYP5+kL77AsX07dF7e6Bo1wpCWSu7VGHIuRpL990m0Hh7UnTmD2k8+WbZcEFePKHkXBi8ul75UZaSUzJo1i/3796PRaIiNjeX69esANGnShC7GXeRHjhzhoYceKpCUHjlyZIFo3e7duwtJY6emphaZNKdK4OShpHvNSQV7E3pb1QjVMaiYpsMz8NdyCPsGHnylXJty7NSJJqu/IvPYMW6u/ZqcyEjS/7sfaUySbuPpic7bm7rTp1N79JNoHB3L3kjol8qdoP9wC1tvAUq4s68I1q1bR0JCAseOHUOn0+Hj41Mgc31bftsUBoOBw4cPY19JEzeVCXs3ZRE6I0l1DCoqhajbSpGLCP0Kuk2tkB3Cjh064GhU2JQGA/qkJDQuLuZnicu8qSjHtn+6xCx1NYk75aBTUlKoW7cuOp2OvXv3FpvsplOnTvzrX//i1q1buLi48MMPPxAQEABAv379WLJkCdOnTwcgLCyMoKAgXFxcSE21bHa6ckdolLWG9OuQnws2NSMtvbr4rFI6Ok2E5Gi4aH5oaVkRGg02np6WSR0a9i3oc5RRkAqgpO7s3r07/v7+hIWFERoaSkBAAGvXrqVVMVnvGjVqxKxZswgODqZ79+74+PgU5DVYvHgxoaGhBAYG4ufnx/Lly4GSZacrNQWL0DUndFXISrKzryx07NhR3k66rVJB5OfCwjbKYu2YDda25v7Q58OSduDaCCbusLY1BURERNC6dWtrm1Fm0tPTcXZ2LshTMHHiRIYNq6ZpUW9eUnKh1/OvMppaRX2uhBDHpJQlqAgqVI0eqlgfG1sltPP8Dki8YG1r7o+zWyH5CnT9h7UtqRbMnTuXoKAg/P398fX15bHHTAgWVmWc6ioJrLJqRuiqusagUno6PQt/LIRDn8Lg/1jbmrJzaCnU9oGWA6xtSbVgwYIF1jah4rB1Ap2DIizp6FHtQ1fNGjEIIdyFEL8JIS4Y/9YuokyIECLsjke2EOIx47nVQojLd5wLMscelXLG2VORp/77O4vqJ1UIV49AzBHo8iJotNa2RqWqIYQyasjPrlT6SeWFuVNJM4E9UsrmKCk+Z95dQEq5V0oZJKUMAnoDmcCd2Vqm3z4vpQwz0x6V8qbrS8qX4+gX1rakbBxaqoQbBj1lbUtUqioOtUCjg4zyk4yvLJjrGIYCa4zP1wCmJhlHAL9KKSsm24SK5fFsoSTxObISKkAX3iLcioaIn5VIJDVEVeV+ERoliVVOWtX57N8n5jqGelLKa8bn8UA9E+WfBNbfdex9IcRJIcRCIYRdcRcKIZ4XQoQKIUITEhLMMFnFbLr
9EzKTlCmlqsDhz5QvdfDz1rZEparjWEf5LKVft7Yl5YpJxyCE2C2EOF3Eo5BWsVTiXouNfRVCNAACgJ13HH4daAV0AtyBGcVdL6VcKaXsKKXs6OlZPrIMKqWkSXdFZ+jPxUoIaGUm7Toc+woCRoGbiWxuNZiJEydSt25d/P39y3ytj48PiYmlX3OqSJXV1atXF4j7LV++nLVr1xZbNioqim+//bbgdWhoKC+//HLhQlobxTlk3VKmVKspJh2DlLKvlNK/iMcW4LrxB//2D39Jk2+jgM1SyoIkwlLKa1IhB/gKKIWIvorVEQIe+j8ltvuUZXM1WJyD/1HyVvesIsJtVmLChAns2FExezvMdQz5+fd3MzJ58mTGjRtX7Pm7HUPHjh1ZvLgIPS3nuoBGuemoppg7lfQzMN74fDywpYSyo7lrGukOpyJQ1idOm2mPSkXRcgA0aAv/na/88FZG0q5D6ColksqjmbWtqdT07NmzQBCvODIyMhg4cCBt27bF39+fDRv+t9FxyZIlBVLdZ8+eBeDmzZs89thjBAYG0qVLF06ePGlSfnvu3Lk8/fTTdO3alebNm/P5558DijPp0aMHQ4YMwc/PD4BvvvmG4OBggoKCeOGFF9Dr9QB89dVXBXLgBw8eLFT37RDbyMhI+vbtS9u2bWnfvj0XL15k5syZHDhwgKCgIBYuXMi+ffsYNGjQvX3p3oOTl29A1k3mzn6TiRMn0qtXL5o2bVrgSEp6r6oC5u5jmAd8L4SYBESjjAoQQnQEJkspnzW+9gG8gf/edf06IYQnIIAwYLKZ9qhUFEJAyBvw7ShFZqLDeNPXVDQHFylOq8dr1rak1Mw/Mp+zN89atM5W7q2YEVzsLG2p2bFjBw0bNmTbtm2Aoqt0mzp16nD8+HGWLVvGggUL+OKLL5gzZw7t2rXjp59+4vfff2fcuHGEhYWZlN8+efIkhw8fJiMjg3bt2jFw4EAAjh8/zunTp/H19SUiIoINGzZw8OBBdDodL774IuvWrePhhx9mzpw5HDt2DDc3N0JCQmjXrt09bTz11FPMnDmTYcOGkZ2djcFgYN68eSxYsICtW7cCijO6zT19+ccrhG1fA7npnD17lr1795KWlkbLli2ZMmVKie9VVcCsEYOUMklK2UdK2dw45XTTeDz0tlMwvo6SUjaSUhruur63lDLAODU1VkpZRbV5ayjN+0GjDrD/Y0UyozKRFq+oqKqjBYsREBDAb7/9xowZMzhw4ECBNhLA448/DkCHDh2IiooC4I8//uDpp58GoHfv3iQlJZVKRG/o0KE4ODhQp04dQkJCOHLkCADBwcH4+voCsGfPHo4dO0anTp0ICgpiz549XLp0ib/++otevXrh6emJra0tTzzxxD31p6WlERsbWyDfYW9vj6MJld57+3KTVL0d5GUz8NFHsLOzo06dOtStW5fr16+X+F5VBdSdzyr3jxAQMgu+GQ4nvoZOk6xt0f/4Y1GVXFuwxJ29pbh69SqDBw8GlPn5yZMnc/z4cbZv386bb75Jnz59mD17NgB2dkpAoVarve81gNvcnYHv9us7Zb+llIwfP54PP/ywUNmffvrJrLbLhDE3iZ3MKTh0u/8tWrQo9r2qCqhaSSrm0awPeHdWRg05lWTAl3gBjn4O7Z4C96bWtqbK4u3tTVhYWMH0T1xcHI6OjowdO5bp06dz/PjxEq/v0aMH69atA5RpmTp16uDq6lpI5rsotmzZQnZ2NklJSezbt49OnTrdU6ZPnz5s2rSJGzeUeJebN28SHR1N586d+e9//0tSUhJ5eXls3LjxnmtdXFzw8vIqcCI5OTlkZmaWaFeRfaldR5HKyM+C3IxC5cv6XlU2VMegYh5CQL/3IO0aHKgk2jk7Z4HOEXq/ZW1LqgyjR4+ma9eunDt3Di8vL1atWnVPmVOnThUs9r799tu8+eabJdY5d+5cjh07RmBgIDNnzmTNGmUvrCn57cDAQEJCQujSpQtvvfUWDRs
2vKeMn58f7733Hv369SMwMJCHH36Ya9eu0aBBA+bOnUvXrl3p3r17saq1X3/9NYsXLyYwMJBu3boRHx9PYGAgWq2Wtm3bsnDhwlL1BVsnZV9DSqySMvY+36vKhiq7rWIZNk+G0z/Ai4etO6d/fhd8O1JxVt3+aT07ykBVld0uD+bOnVu18kKDkt0t5QrUagKOJUd2VSSq7LaK9ek7F7R2sON169mQnws7Xwf3ZhD8gvXsUKlZOLqDjQOkxoHBYLp8FUBdfFaxDC71lU1vv70F53dCi0cq3oYjKyEpEsZ8X2NSMFY35s6da20Tyo4Qyq76pEhFKsO1gbUtMht1xKBiOTpPBo/m8Ov/VfxCdNJF2PsBPPCwdZySSs3GzgUcaiuOIbfqa4SqjkHFctjYwuBFiprpzlkV164+X1nj0NpUzQRCKtUDNy/Q2Ci50av4lJLqGFQsi8+D0H0qHF8DEVsrps2Di5QkPAP+rQrlqVgPjQ3UaqyI66VfM12+EqM6BhXLE/KGoqP08z+VHcjlSVwY7PsQ2jwOASPKty0VFVPYuyqpP9NvVJ59PfeB6hhULI+NLTz+hZLMZPNkMOjLp53Mm/DDJHDyhIH/rvZ5eMuLq1evEhISgp+fH23atOE//ynbdFyNkt0uDa6NQGsLt6Iqn1RMKVEdg0r54NkCHp0Pl/bCttcKbf6xCPk5sGEsJF+B4asqVfx4VcPGxoZ///vfhIeHc/jwYZYuXUp4eHi5tVflZbdNodEqO+6lXpGmL68bo3JEdQwq5UeH8fDgK0qinP/Ot1y9BoMyEok+CI99Bj7dLVd3DaRBgwa0b98eUOQiWrduTWxs7D3lVNntwn25XWeRstu5BgZO/D/a9hqCf5vWbPiuimQ7NKLuY1ApX/rMUeZb932oJDjpONG8+qRU9kqc+RH6vl3t1hXiP/iAnAjLym7btW5F/VmlixKLiorixIkTdO7c+Z5zquz2vX0Bipfd9mrMtk3fQEoMKXk65bNr7nSnQa+MSMoZdcSgUr4IoYSQNu8HW19VxPbuN5QvLwt+fA4OfQqdnlOin1QsRnp6OsOHD2fRokW4urrec16V3S66LwMHDixedvudBRw4GYWbLs+8aSUpISMRboQr06jljFkjBiHESGAu0BoIllIWKWAkhOgP/AfQAl9IKecZj/sC3wEewDHgaSll1VytUSkerQ5GrYWfX4bf31MiiR77TIngKC1p8fDdGIg9pojj9XitWi42l/bO3tLk5eUxfPhwnnrqqYIfeVV2u3Tc7jsUI7v94SL69OjK7BefUJR/3ZuWbWe+QQ8pV5U803YuimhfOWNuC6eBx4H9xRUQQmiBpcCjgB8wWgjhZzw9H1gopXwAuAVUIkF/FYuic4DHV8IjH8K5X+Hz3hDxi+nRQ34O/LUSlveAG2fhiXVKjoVq6BSshZSSSZMm0bp1a1599dWC46rstum+FMc979Xps4qGlz4HEs4qNzqmRg9SQk4aJJ5XnIJLfaUOra7k6yyAWSMGKWUE3Ovh7yIYiJRSXjKW/Q4YKoSIAHoDY4zl1qCMPj4zxyaVSowQ0PVFqB8AP7+kRBXVaamooDbuCrWbKB/6vGzly3PlEPy5BFJjoUl3ePQjqO9v7V5UOw4ePMjXX39NQEAAQUFBAHzwwQcMGDCgULlTp04xffp0NBoNOp2Ozz4r+at6e2E2MDAQR0fHQrLbI0aMYMuWLSxZsoQePXoUuu627HZiYmKB7Pb58+cLlblTdttgMKDT6Vi6dCldunQpkN2uVatWQX/u5uuvv+aFF15g9uzZ6HQ6Nm7cWEh2e8KECYXWJorrS3EU+V7Zuyqf99Q4RaY+I1FJ9mPrrIjwaTSKM9DnQW4apCcouR40NopDKMsI20wsIrsthNgHTCtqKkkIMQLof0f+56eBzihO4LBxtIAQwhv4VUpp8puvym5XA/T5EP4THPgEbpxRjgmtcleUFq+E+oGSBChkFvg+VG1HCars9v+okrL
b90NOuuIg8m4n+BGKAzDk/a+Mjb2yR8fBXXEaZcQc2W2TIwYhxG6gfhGn3pBSbim1lWYihHgeeB6gcePGFdWsSnmhtVEiivyHQ+xxSDynCOGlXFVkBer5Kw+PZtXWIajUYOyclb0++jwl+1tepvJcq1M2x9nYG5MAWeezb9IxSCn7mtlGLOB9x2sv47EkoJYQwkZKmX/H8eLsWAmsBGXEYKZNKpUFIcCrg/JQqfFUSdltc9DqwKGW8qhEVES46lGguRDCVwhhCzwJ/CyVOay9wO1A9PFAhY1AVFRUVFSKxizHIIQYJoSIAboC24QQO43HGwohtgMYRwMvATuBCOB7KaVxUpkZwKtCiEiUkNV7E82qqNQAqmKKXZXKi7mfJ3OjkjYDm4s4HgcMuOP1dmB7EeUuoUQtqajUWOzt7UlKSsLDw8NUhJ+KikmklCQlJWFvb3/fdaiSGCoqVsbLy4uYmBgSEhKsbYpKNcHe3h4vL6/7vl51DCoqVkan0xVIPaioVAZUrSQVFRUVlUKojkFFRUVFpRCqY1BRUVFRKYRFJDEqGiFEAhBtbTvKSB2g9PkPqwdqn2sGap+rDk2klJ6mClVJx1AVEUKElkajpDqh9rlmoPa5+qFOJamoqKioFEJ1DCoqKioqhVAdQ8Wx0toGWAG1zzUDtc/VDHWNQUVFRUWlEOqIQUVFRUWlEKpjsAJCiNeEEFIIUcfatpQ3QoiPhRBnhRAnhRCbhRCVS3jegggh+gshzgkhIoUQM61tT3kjhPAWQuwVQoQLIc4IIaZa26aKQAihFUKcEEJstbYt5YXqGCoYYwrTfsAVa9tSQfwG+EspA4HzwOtWtqdcEEJogaXAo4AfMFoI4Wddq8qdfOA1KaUf0AX4Rw3oM8BUlBQC1RbVMVQ8C4H/A2rE4o6UcpcxJwfAYZRMfdWRYCBSSnlJSpkLfAcMtbJN5YqU8pqU8rjxeRrKj2Uj61pVvgghvICBwBfWtqU8UR1DBSKEGArESin/trYtVmIi8Ku1jSgnGgFX73gdQzX/kbwTIYQP0A74y7qWlDuLUG7sDNY2pDxRZbctjBBiN1C/iFNvALNQppGqFSX1WUq5xVjmDZSph3UVaZtK+SOEcAZ+AP4lpUy1tj3lhRBiEHBDSnlMCNHL2vaUJ6pjsDBSyr5FHRdCBAC+wN/GLF1ewHEhRLCUMr4CTbQ4xfX5NkKICcAgoI+svvHRsYD3Ha+9jMeqNUIIHYpTWCel/NHa9pQz3YEhQogBgD3gKoT4Rko51sp2WRx1H4OVEEJEAR2llFVRiKvUCCH6A58AD0kpq22KMiGEDcrieh8Uh3AUGHNHfvNqh1DucNYAN6WU/7K2PRWJccQwTUo5yNq2lAfqGoNKefMp4AL8JoQIE0Ist7ZB5YFxgf0lYCfKIuz31dkpGOkOPA30Nv5vw4x30ypVHHXEoKKioqJSCHXEoKKioqJSCNUxqKioqKgUQnUMKioqKiqFUB2DioqKikohVMegoqKiolII1TGoqKioqBRCdQwqKioqKoVQHYOKioqKSiH+H3TL8NNoVj+6AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# batch the inference across K=100\n", "targets = np.sin(xrange_inputs)\n", "predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n", "plt.plot(xrange_inputs, predictions, label='pre-update predictions')\n", "plt.plot(xrange_inputs, targets, label='target')\n", "\n", "x1 = onp.random.uniform(low=-5., high=5., size=(10,1))\n", "y1 = 1. * onp.sin(x1 + 0.)\n", "\n", "for i in range(1,3):\n", " net_params = inner_update(net_params, x1, y1)\n", " predictions = vmap(partial(net_apply, net_params))(xrange_inputs)\n", " plt.plot(xrange_inputs, predictions, label='{}-shot predictions'.format(i))\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJztnXd8VFX2wL93Jo1eQi8SulIUlaJgQZAFVwVdCwoq6Coq4s9ddQVd27Ksi23trl1AZQVBXBSFRURBejH0FiBAQg0lIYG0mfv7472ZTJ83M28mk8z9fj755L37bjnzJjnvvHPPPVdIKVEoFApFYmCpbAEUCoVCETuU0lcoFIoEQil9hUKhSCCU0lcoFIoEQil9hUKhSCCU0lcoFIoEwpDSF0IMEULsEEJkCSEm+Lh+hRBivRCiXAhxs8e1UUKIXfrPKLMEVygUCkXoiGBx+kIIK7ATGATkAGuA26WUW13qZAB1gceBuVLKWXp5Q2At0BOQwDrgYinlSbM/iEKhUCiCY8TS7w1kSSn3SClLgS+BYa4VpJTZUsqNgN2j7WBgoZTyhK7oFwJDTJBboVAoFGGQZKBOS+CAy3kO0Mdg/77atvSsJIQYA4wBqFWr1sXnnnuuwe4DU1JuZ+eR04brdxd70V5I3Nkk25kij9d4Leu5j5ObD0BaspWOTWq7XSssKWdvXpFbO0d9z35cr/m7Hg6BxgunnkKhMI9169blSSkbB6tnROlHHSnlB8AHAD179pRr1641pd+so4Vc/a9fDNdfmzbCZ3lG8eumyOM13uRr3ceZMA+Ari3qMu//Lne7tjwrjxEfrXK2k1LS9snvffbj2pe/6+Hg6DNYf0brKRQK8xBC7DNSz4h7Jxdo7XLeSi8zQiRt44a6FFW2CF6cLimvbBEUCkUVxIjSXwN0FEK0FUKkALcBcw32vwD4nRCigRCiAfA7vSwmCGFOP28kv21ORwbxKbdH2fnP/y8msigUiupFUKUvpSwHxqEp623ATCnlFiHERCHEUAAhRC8hRA5wC/C+EG
KL3vYE8He0B8caYKJeVqXIEIcrWwSFQqEwBUM+fSnl98D3HmXPuhyvQXPd+Gr7CfBJBDJWOia9MCgUVYKysjJycnIoLi6ubFEUPkhLS6NVq1YkJyeH1T4uJnKjRVVV1sIEyTPSa5J9/IwJ0igSjZycHOrUqUNGRgbCLB+pwhSklBw/fpycnBzatm0bVh8qDUMVIdQHgdoaRxEuxcXFpKenK4UfhwghSE9Pj+gtTCl9A2RYjtBFZMdsvGj+rxWqqB+FAZTCj18i/W6qtdI38w/3+9SnGGpZblp/oVKvRmj+O1+ffFNOPt2eW8C8jYfMEUqhUFQ5qrXSD5VcmR7w+lPJX8REDl8Ku1aqVfudYjXUhy/3zuaD2krZpbuOhSmZQqGo6iil74KsAlO/Rt9e1H73iqrKqVOnePfdd8Nqm5GRQV5eXtB62dnZdOvWLaS+p0yZwsGDB4PWGTduXEj9OliyZAkXXXQRSUlJzJo1K6w+jFCtlX78q/DKQT0QFPFMJEo/mhhR+pFwzjnnMGXKFEaM8J0OxiyqdchmIiN9OHjUQ1ARKn/7dgtbDxaY2meXFnV57vqufq9PmDCB3bt306NHD6666io2btzIyZMnKSsrY9KkSQwbNoyioiJuvfVWcnJysNlsPPPMMwwfPtzZx9mzZ/nDH/7AH/7wB+677z6f45SXlzNy5EjWr19P165dmTZtGjVr1mTixIl8++23nD17lr59+/L+++8ze/Zs1q5dy8iRI6lRowYrVqxg8+bNPPLIIxQVFZGamsqiRYsAOHjwIEOGDGH37t3ceOONvPTSS4buS0ZGBgAWS3Rt8Wqt9M0OQBCVGAgZqnWurHlFVWXy5Mls3ryZzMxMysvLOXPmDHXr1iUvL49LLrmEoUOHMn/+fFq0aMG8eVpyv/z8iqyyhYWF3Hbbbdx1113cddddfsfZsWMHH3/8Mf369eOee+7h3Xff5fHHH2fcuHE8+6y29vTOO+/ku+++4+abb+btt9/mlVdeoWfPnpSWljJ8+HBmzJhBr169KCgooEaNGgBkZmby22+/kZqaSufOnXn44Ydp3bo1w4cPZ8eOHV5yPProowHlNJtqrfTNpqk4hTZFGmWb2YSnlVL6CjMIZJHHAiklTz31FEuWLMFisZCbm8uRI0fo3r07jz32GOPHj+e6667j8ssrstIOGzaMJ554gpEjRwbsu3Xr1vTr1w+AO+64gzfffJPHH3+cxYsX89JLL3HmzBlOnDhB165duf76693a7tixg+bNm9OrVy8A6tat67w2cOBA6tXT0op36dKFffv20bp1a2bMmGHKPYmUau3Tr5MW3jLlQJwrDgSvZJC8whKf5dF2w/hy/SgU8cgXX3zBsWPHWLduHZmZmTRt2pTi4mI6derE+vXr6d69O08//TQTJ050tunXrx/z58/HwK6AXufFxcWMHTuWWbNmsWnTJu67776QF0KlpqY6j61WK+Xl2tqY4cOH06NHD6+fadOmhdR/pFRrpd+wVorpfZrp4rHbo6d8B3VpCsDgrk2dZWq9jaIqUKdOHU6f1jY/ys/Pp0mTJiQnJ7N48WL27dNSxh88eJCaNWtyxx138Je//IX169c720+cOJEGDRrw0EMPBRxn//79rFixAoDp06dz2WWXORV8o0aNKCwsdIuicZWrc+fOHDp0iDVr1gBw+vRpp3L3x4wZM8jMzPT6iaVrB6q50gdoUNO4tZ8nY7vTUzTt7XT9gdfBYwcuhSLeSU9Pp1+/fnTr1o3MzEzWrl1L9+7dmTZtGo5d9TZt2kTv3r3p0aMHf/vb33j66afd+njjjTc4e/YsTzzxhN9xOnfuzDvvvMN5553HyZMnefDBB6lfvz733Xcf3bp1Y/DgwU73DcDo0aN54IEH6NGjBzabjRkzZvDwww9zwQUXMGjQoIgT1K1Zs4ZWrVrx1Vdfcf/999O1a3Rca8qn78J9pY+xJm1s0HqXWTbxeNJMbip9HhvGFktFSqAHRF
5hCWU2O83r1fC6ZkbyNoUi1kyfPj3g9YyMDAYPHuxVnp2d7Tz+9NNPA7bfvn27z2uTJk1i0qRJXuU33XQTN910k/O8V69erFy50q3O6NGjGT16tPP8u+++8yuDJ7169SInJ8dw/XCp9ko/lFQMx6gftI5E8EryezQTJ2lEPkdoGLZs/lyOobphek76EYBstT2hQqEIQvVX+ib3p8XumOOYcZ1Q3eyykbk5fQe4puZxFQnC8ePHGThwoFf5okWLSE8PnHalulLtlX48Y7NLPv51LyP7nMN1b/0alTFc3xqUq0eRaKSnp5OZmVnZYsQV1V7pRyNixawcPf/NPMjLC3Z4hW6a0buy5hUKhS+qffROPCcfcOS2zz9bFrRusJhjf8Tvp1coFJVBtVf6Zlv6XcQ+momT5nbqQdcWsQ0dVSgUiUO1V/pm81rKv03v09OINyO2PtCqW+X5USgSl2qv9GPl3piePInR1vkhtalQ9u5q+Lm5WziUf9YcwdxnchWKuCdR8+k7mD17NkII1q5dG1E//qj+Sj9Giq6vdSvPJ4eXQ8OXu37bIXPT2SoUVYVEzacPWjqHN954gz59+kRtjOofvRPH5m0oic9Cdcmo6B2FKfwwAQ5vMrfPZt3hmsl+LydqPn2AZ555hvHjx/Pyyy8bbhMq1d7SVygUVYvJkyfTvn17MjMzefnll5kzZw7r169n8eLFPPbYY0gpnfn0N2zYwObNmxkyZIizfWFhIddffz233367X4UPWnrksWPHsm3bNurWret8uxg3bhxr1qxh8+bNnD171plPv2fPnnzxxRdkZmZitVoZPnw4b7zxBhs2bODHH390y6c/Y8YMNm3axIwZMzhwQMvMGyzL5vr16zlw4ADXXhvdlfXV39KPX0OfWWu1PBtbY+zKUW8BCsMEsMhjQaLk07fb7Tz66KNMmTIlpPsTDtXe0o9jnc/xolIAdh0pjNoYws+xQlEVSJR8+qdPn2bz5s3079+fjIwMVq5cydChQ6MymVvtlX40MWtlru/9bP33He5CrUDjKRTxQiLm069Xrx55eXlkZ2eTnZ3NJZdcwty5c+nZs2cot84QSunHAUZ0uGudz1fuC14/AnkUisokUfPpx4pq79OvCoSqoH/aftRwXV9zGvEc0aRQQGLm03fl559/DqudEZSlHw+EqPVDqf76j7s4etrdAlHuHYUicVGWfhxQarOb36mLP2j8rI18enfvkDaUUSiqAyqfvjdK6VdB9p84E1L9MpuHZa8MfUUQpJTVwkiojvn0Iw3kUO6dKkPFF73nWFEItRWK0EhLS+P48eMRKxeF+UgpOX78OGlpaWH3oSz9BMDLYKv6BpwiirRq1YqcnByOHTtW2aIofJCWlkarVq3Cbm9I6QshhgBvAFbgIynlZI/rqcA04GLgODBcSpkthEgGPgIu0seaJqX8Z9jShkEsXlHTMXd/W4AZaw5w1blNTOnLy2BTBpwiAMnJybRt27ayxVBEiaDuHSGEFXgHuAboAtwuhOjiUe2PwEkpZQfgNeBFvfwWIFVK2R3tgXC/ECLDHNHjhxQCL8oIh+2HfadmcOy2FQxfb+bKwFcoFEZ8+r2BLCnlHillKfAlMMyjzjBgqn48CxgoNBNbArWEEElADaAUqDY5gztZtNw58W44e77sxLu8CoUiehhR+i2BAy7nOXqZzzpSynIgH0hHewAUAYeA/cArUsoTngMIIcYIIdYKIdZWJT/i5ynR91RJKTlaUKySpCkUClOIdvROb8AGtADaAo8JIdp5VpJSfiCl7Cml7Nm4ceMoi2Q+lijazp+v3EfvFxbx/pI9IbVzXYC1dJf7TkLKzaNQJC5GlH4u0NrlvJVe5rOO7sqphzahOwKYL6Usk1IeBZYB5mcQqkQEdjpYPG+HeSzLOg7ArHU5Efdls2sPgt15wUM+FQpF9cSI0l8DdBRCtBVCpAC3AXM96swFRunHNwM/SS3Idz8wAEAIUQu4BPCd8CJOsM
vQ7OD7rPOYlvJi8Ioxxpc7aGPuKQA2HDgVY2kUCkW8EFTp6z76ccACYBswU0q5RQgxUQgxVK/2MZAuhMgCHgUm6OXvALWFEFvQHh6fSik3mv0hjNCsbviLGQLR3bI3Kv1Gg2jdA4VCUXUwFKcvpfwe+N6j7FmX42K08EzPdoW+yiuD12/rwX8zc/nP6gMB64XqnS8lOXyhAhCN9QVtG9UGYEjXZqb3rVAoqgYJk4ahbloyY65oH7TedJt3cqZAmLWRitl4PrwKisucx5aE+dYVCoUn6t/fg2fLR/N02d2VLYbp3De1Yts1Ff6pUCQuCaP0jXpLJBZKouSyqUw25ebH9SbxCoUiNiSM0o8WTThZ2SI4KSm3OY89rXnXc2XpKxSJS0Ip/WgYuldYN0Wh1/ByZnd+er7//lTyBYVCQYIpfaMIXUHOsfWrNBm8Nj4xEeXmUSgSF6X0A1AqkzkrUypbjLDwtOztyr2jUChQSt8nZhvCTTjJ9tRRdBXZhttEY9ciZeArFAql9INgN0FV9rdmkibKuMv6P+PjmqzzXT+F8u8rFImLUvo+cKhEgWSfNG/1qghB2UasmNVe6AqFwgcJpfSNTmA6qkkED5Y9EvG4jlW7ISn9KGppoRw9CkXCUu2Vfu6pswCcKbUFqemNBPJlLZMlMsbJM6WRdSD8nyr3jkKRuFR7pe/g2OniyhYhpFDJiEM2fbh3VKimQqFIGKUfirvE1Q1T2QnVoqGoVcimQpG4JIzSDwcZkhfeQH8hdlZcZuOHzYfDGyusVgqForqTMEo/FCVorqoPn/yzZcEreeCaQlmhUCg8SRilHw6a6q889044rp1vfgu0X69y6isUiU7CKH0pwwtVDNen39+SyQNWz62EQyMseaXjd3y8rSgUivjC0HaJ1YFQwhTNsIenpLwEwHu2oUFq+iYjvWZYlr4/Ze8esqlQKBKVhLH0w0OYkoYhHGqmJIU1sj+FrhS9QqGABFL64Xo7KjNk0xKGqV/h3jFZGIVCUS1IGKUfCq7RO6Fa+o055fdaKA8QsxdTiTD6m/jtVjImzDNPCIVCUekkjNIPx/DV4vRD05Rr0sbSgIIwRnNn26ECthwMvR8j7h2jbwGfLNsb8vgKhSK+SRylL2UICdciW5FbTxSF3MYXC7aEvjDLMZHrS68H+yRFJeXYzc7prFAo4oqEUfrhIAlP6Zs1D2C2Xz5QdwXFZXR9bgH/WrjT3EEVCkVcoZS+D07LGgDkU6vSonfCJdwHRf4ZbSXvnICLuxQKRVUnYeL0Q2GuvS91ys7yle1KU6z2cNM6hJMC2dHGl/KvWo8vhUIRDRJG6YdiAUssfG4b5DwLeaw4de+49OxV4pjvUCt5FYrqTcK4d+xhKzPz7ONYxPz7+5il5XZEgJlsxzWl8hWK6k3CKP2qSHhhpo7fobVWrh+FIjFIGKUvZfiLnd4rvy60scIbxhSMvtBIKSksKQ+7vUKhqJokjNIH/Lo3aqZYg7U0XxgDhKOAA03kuvKf1Qfo9twC9uZpawqcPn0Dj6xymz10wRQKRVyQMEpfAvVqJANw/QUt3K49OqhT0LahjeX7ITE86ecQewodIw8KKWHhVm3h1968QqAijbOR9i98vz1s+RQKReWSQNE7ktqpSWz/+xCOFBTz7YaDzmvJ1sDPvm32NiGN1deyxe08/PcEc30tRwq0zeHzz5ZRJ0376h1KvsLSD87SXcdMlUuhUMQOQ5a+EGKIEGKHECJLCDHBx/VUIcQM/foqIUSGy7XzhRArhBBbhBCbhBBp5okfOmnJ1pA3J5lr78vvSl40XP+l5A/dzi3Ezh1SUm5n0L9+YfXeE17XJn67FYC1+05WROs4lL5ex4ilH072T4VCER8EtfSFEFbgHWAQkAOsEULMlVJudan2R+CklLKDEOI24EVguBAiCfgcuFNKuUEIkQ5UyiaukdrMO2XrsNuGqyLD8envPHyaXUcLfV4rdfHFR6K2fU0AKxSKqoERS7
83kCWl3COlLAW+BIZ51BkGTNWPZwEDhWZK/g7YKKXcACClPC6ltJkjemgM7tLMeVzqMREZr4armZE0LevXoFHtVP8VnPcg+KC5p86aIpNCoYg9RpR+S+CAy3mOXuazjpSyHMgH0oFOgBRCLBBCrBdCPOFrACHEGCHEWiHE2mPHouMvTkup+Kg2j0yS0Q5TDDcNQziEugjNUTuUiVyFQlF1iXb0ThJwGTBS/32jEGKgZyUp5QdSyp5Syp6NGzeOskjh5bSp6mjpFSo+99HTJQCs2nMcCG0iV6FQVF2MKP1cwNWh3Uov81lH9+PXA46jvRUskVLmSSnPAN8DF0UqdKQczi92O49b904YKvhsmW/vmQS6tawHQMNaKWzKzQfgs5X7ANeJXO8xVT4ehaL6YETprwE6CiHaCiFSgNuAuR515gKj9OObgZ+kpikWAN2FEDX1h8GVwFYqAdeIneKy2C4uCte9E86E6dJdeX6vXX++tj7hyk4Vb1NO947+5Dt5powzpe7jKp2vUFQfgip93Uc/Dk2BbwNmSim3CCEmCiGG6tU+BtKFEFnAo8AEve1J4F9oD45MYL2UstI3XW1cJyVmY7XAvxIOxvebQt85K2R8KPSHp//mdm5TWl+hqDYYWpwlpfwezTXjWvasy3ExcIuftp+jhW1WKpXlwmkmvOPlKwO/etvHfVmT7S5z+BlKFQpFvJEwaRjccdd04eq0ffYmBkaKD4UZyvyAZ44iu0q1o1BUGxJU6bsTqiV7QfEHnF/8AV/bLg9a14KMG8XvE187bHlY/8rSVyiqDwmTeyeQd8eoTrui5DVs0kI+tbV2Bta1WuJE4UvpGpPvUu7MyulfTqX0FYrqg7L0qVBqLeoFTgu0XzYlF+/Il0AMsa4OW64GNZPDbutJqGr7o6V7nMfKvaNQVB8SRukH2iqwTXotAO6/sn1IfRqx9O9OWkAxsYsW8ofdLrH50d7bDhXw4BfrneenzpQxad62irbK0lcoqg0Jo/RdaeZh0V90Tn0WPXYld10aWgplo8Rib9xgHC8qZfzsTV7lNrvkmjeW+szK6ayjlL5CUW1ISKXfsn4Nt3MJtG9cO+DbQCSEO5HbK6OhyZJ4YzcgmrL0FYrqQ8IofTMmcr37jK4ybNWgZlT7N4rS+QpF9SFhlH4gwk3AZjQyJ9yHwy87j4bVzmxOnQlvC4Tpq/abLIlCoYiUhFH6Rj03d/fLCKFPo0rfnWssq+hvyQzabvexIsOyhESIXqyUpPD+TJ6a4z2HoFAoKpeEidMPiIvu7tG6vuFmRnWnq6XfRhzm3ylvAPC7khcj2pErVqgsmwpF9SFhLP3oELp7J50C53EDfG9rGG/4muzdmHOKXwNk9FQoFPFJwlj6gSJzwrVjLxRZhur1tWxxHl9m2ewikwx/8Bjiy9If+vYyALInXxtrcRQKRQQoS5/wo1OusBrzWV9vXVkxlkt5XOfkUSgU1RKl9Int9okHZPDMnPGGejQpFNUHpfRN5FvbJUHr5MpGMZDEXNQ8rkJRfVBKn/CVWpFMdTs/KeuYIE30EXGQFkKhUFQOSulHQBHu6RysVM90lLF0fykUiuiSMNE7gQhVpU2/tw/Ldx9HLHdvaQlR6Ve3idzFO47yxKyNlS2GQqEIgFL6hL746KI2DejboRF5Hko/mKV/xsMdVFUwentemLeNY6dLoiuMQqGICOXeiQBPS90qbAHrS6p3yGb1+jQKRfVEKX1Cn8hNsWq37bGyB9zKt9jbMjvAvrm1RAmzUic6z+NlK8VgBLo/+WEmY1MoFJWDUvohMurSNlgsWvTLz/YLySiezhp7JwA22dtSII2nQ/4sZTIDLOuDV6xkAk3k9nrhx4p6KrZToYh7lNIH6qYZ34vWavG+ZTasACQJG6+W3wLAfFsvQ/0NsawxPHY8UlpePSOWFIrqSsIr/fXPDKJehBuQl0lN6VuxU0hNMoqn80DZnw21vTXpF5Ioj2j8UJm9Piek+sqAVy
iqDwmv9BvWinzT8p/tFwBwUKaH1X60dUHEMigUCoURqr3Sf+v2C7mknXl7zfryb39s+z29it9lj2wRVp91xNlIxYoL1AuBQhH/VHulf/0FLfhyzKWG64e3ObrgGMY3X/GkKSfCbhsP/Pvn3ZUtgkKhMEi1V/pmEw3/dmdLaD72WBPsM784f7teMfqyKBSKyFBKPw44KsN/S4gnzNL5Z0rLOXq62KTeFAqFK0rph8gVncxPjbzU3l3P2xOfpnKsE64Ne3sZvf+xKKZjKhSJglL6BqmRbGXbxCEMOLep6X1PSv6UPWl38ETSDNP7NgMjLq1ym529eUWmjLfraNXYO1ihqIoopW+QujWSqJFi9Xlt2YQBpowxNmkuVmzUIr6ieYpKypnzW+B5h6Mq0ZpCUSVQWTZNwJGLxwzeS36NQdb1ZBRPN63PSHli9kZyTsbXg0ihUISHsvTjjEHW+MvFoxS+QlF9MKT0hRBDhBA7hBBZQogJPq6nCiFm6NdXCSEyPK6fI4QoFEI8bo7YCoVCoQiHoEpfCGEF3gGuAboAtwshunhU+yNwUkrZAXgNeNHj+r+AHyIXN/qo3WPDw9dcb41k33MgCoWi8jBi6fcGsqSUe6SUpcCXwDCPOsOAqfrxLGCg0Je2CiFuAPYCW8wRuXI4Uxp4g5REp6TMnPuz7VCBKf0oFArfGFH6LYEDLuc5epnPOlLKciAfSBdC1AbGA38LNIAQYowQYq0QYu2xY8eMyh5TThfHNhNmdtoI2opDMR0zEu78eLUp/Yz6xJx+FAqFb6I9kfs88JqUMmDgtZTyAyllTyllz8aNG0dZJPMJK12PAZ5J+oxPkz09ZfFJ7ilzJnvjc3maQlF9MBKymQu0djlvpZf5qpMjhEgC6gHHgT7AzUKIl4D6gF0IUSylfDtiyaNEOEon2OKl3fbmtLeEbrUPsGZqB1V0R8KzYbh8VO5+hSK6GLH01wAdhRBthRApwG3AXI86c4FR+vHNwE9S43IpZYaUMgN4HXghnhW+Ih5w1/q3fbCCPi5bMtrtklnrcii3qR27FIpwCKr0dR/9OGABsA2YKaXcIoSYKIQYqlf7GM2HnwU8CniFdVYVwtnn1Yh755PyIWFIE5jhPVsHrxQHrNt3kn3HjaVosHvc/pV7TnCkoGK175drDvD4VxuYtmKfmSIqFAmDIZ++lPJ7KWUnKWV7KeU/9LJnpZRz9eNiKeUtUsoOUsreUso9Pvp4Xkr5irnix472jWuF3VYg+Xv5HfDAMhMlgj8N6mhqf9Hipn8v58qXfzZUN9hD93ih9gA4eaY0UrEUioRErcg1SHibq+htkUgs0KybiRKBqIarCoK9Z9n0h0Ik34dCkcgopR+Ato0qrPtIVIxb2xYXhtFD4sxuBvOuOdw/FqXzFYqwUErfg5V7KrYudHU1hGJYdmxSG4BH677qfbHXvSHLZEkopR/4szquV8e3HIUiFiil78Gm3FM+y0NRMum1UwA4bamjt3VRZBfeEbJMPURWyG3ikV1HTgetE9S9o5v6SVal9BWKcFBK3yCBLH3PS8l6qmWpXxERWup1xZmI2scLg15bwvRV+5m+ar/fOkUlgVc+O3z6VuXfUSjCQil9E/BU6Z2aahZ+rZRkU/q3YKe72MO91nmm9FeZPDVnE0/N2eT3umfIpic2m6701USuQhEWahMVD/y5lEOJFvnL4M50a1mXosO74aj2JvDRXT3DlsmKnW9TnwbgI9u1LjKF3WWl88cpa+jcrA5PDDmX5bvzaNuoFs3r1QjabocBF5FCofCPsvQNEki/el5LS7Zy44WtKExrzjxbb8aVPczVXcLfW1fbNL16sWj7Ud79eTcAIz5cxTVvLDXUbumuPAB2HfWv/LcczGfHYfVwUCh8oSz9ALga/eFY1VJYeajsTxHLYXVT+pKqlPXf6ArnU2dCSzDULMBbwbVv/gpA9uRr/dZRKBIVZel74EtH1a+ZzN9v8F5YNeuBS/l23GV+N0
z3S4gRPK5KP9JJ4Vjz1k+BI4/+m+mZuy8wN13UCoBW9YO7ghQKhTdK6XvgS6XPJQveAAAgAElEQVTOGduPi85p4FXeM6Mh3VvVo2ZKEqufGuh1/erzmvgeZMCzIclkd/maqlrM/r8W7gx4/ZEvM/1e+1V35bhSlecxFIp4IGGV/nnN65q6nV+TumleZR31KB5v3BV3gQxstR6WFQ+cqmbpR8IdH6/yf1HAsdMlZEyYx8KtR2InlEJRxUlYn/4Pj1zus9zVB+2Ity8zO42vDK2/M1Q8UFwt/UQ0el3db5tz8wH4aOkedh45zf1XtCPJmrB2jEJhiIRV+kYYcG4Tso4WUr+GsXj7p35/Ln3bNwp5nJ0ig55s83u9vqjYeCyRLP1AHDtdQiN95fOqvSdYtfcE9WsmM7JPm0qWTKGIb5RZ5IGrJTl+yLksnzDAp+vGF2OuaE+3lvUM1Kyw0RfZLuT/eCJg7ZeSPnBpqZQ+wMsLdnitnSguq36hrQqF2ShLPwBWi6BFNKJEdGWVJ+tyb9lj1LamBKze2lKxWbxS+hVY1KyuQhEyytKvFDRlJQGJJSTffFWL3jEbmeBzGgpFpCil78Gdl8bAJ1yjPgBbOo1jzti+ITXt0cqI+yg0erdtaHqfscDT0A9nq0uFItFQSt+DPrFQgEmp8Hw+/UeO50If8f+BmDqiM3+0fo+ZG6uM7pthWl+RYGizc5ePrdw7CkXoKJ++B5WhR0JR30lvdueZZNgrmyG52pTx4yVL8RcBUi77wlNsZegrFMFRlr4HlbL3ahjK6gLLbuymabn40PoFZ4Pn3wn2iZfsPBakhkKR2Cil70FlqL/2+vaKoVCD0qC5540SL5a+EVz99p4ff9vhAu76ZHVsBVIoqhhK6XvgsPTbpNeM2Zifju4VcpvzxD7sJmn9Snm78cGc33JpXs/Ymghf5IeYqVOhSESUT98Dh/qLpX+4QS2XOP06LeD0weBtRKFpMsaHyoc9eUVB67ilu46eKApFtUVZ+h44IkJkrOPhR+tbIabVNVS9FmeRpUVcadmAEJqsjyXNZJjl15CHjhNDP2RSk93/fNU8rkIRHKX0PXAoQHusV/TXaqz9lnb+U35V0OptLUc45/2OTE15kcGWdQA8nPQNb6S8C8D1F7QwPHRVCn10fbtJsngofR+vPoUl5Vz/1q9sP1wQbdEUiiqBUvpR5tfxV7H0ieBKHIvuabOV8WT5fXQqnsrAkpfZZm/tt4kjJcOt1p+9rl3TrZlxIauOznfDU8UfPFXsVWfF7uNsys3n5fk7YiOUQhHnKJ++Bw6j16zVna0aGJwQtuqZPO3lAJSSzG7ZklKCZ/gcYFnvVWY0MyhULUs/EIfyz3qVOb7HotLyWIujUMQlytL3oMKnH+uBKyx9V8KVI5R2VUHll5TbeGXBDs6W2Zxlng/mLi2850MO5WvW/8o9J3z2u/tYIRO/3WpaJJRCEe8oS98Dp08/1ss7U/VdttpdCW67BIauki3YQ4rsqQqG/o3vLGfrIXe/vOdH9PWZTxSVBuz3gc/WsetoISP6tKZDE387nSkU1Qdl6XvQoKYWPvnIwE6xHTi1DvzfbzD0LbdiaVDpp5PvPL7POs9v9NG5zbwVm6gCtr6nwgdvJb9q7wmP65I3Fu0K2O+uo9oGNaeLlftHkRgope9BWrKV7MnXMqLPObEfvGE7LRlbGJxrqchbc6v1Z7+Wfs0U732Bq4KlHw7HTpf4LD9w4gyLtx91K8srDPxGYAZ2u2TGmv2UlqvNXhSVh1L6cY7doBX+Rco/ncftLYewlpz0We8c2z5q4h7l0tmH9V81COzD8rfS+KpXfubuKWvcyopd5gp8UW6z89c5m8g95T1ZXFhSzvzNh4PICt9uPMj42Zt4e3FW0LoKRbRQSj/OMere8eTcFY8D0N/yG/9M+hCAJMp5/fiDvJv8hlvdRrXDe7uobEKZt3BNmV3uY9L2bBClv3rvCb5YtZ+/fLXB69r4WR
t54PN1ZB09HbAPhwspr9D3G0hlszevSL2FJACGlL4QYogQYocQIksIMcHH9VQhxAz9+iohRIZePkgIsU4IsUn/PcBc8as/yYTna65RdJDLLRuZkvIytyctBiANzYXRz7LZWe/Z67pELmQlEUznu85r1A0SwurL0t+bV8Tdn67mbKkNi56VrtzmPWr2cS19xNnSwArTGRkWhzmgT50p5apXfuapOZsqWxRFlAmq9IUQVuAd4BqgC3C7EMJTU/wROCml7AC8Bryol+cB10spuwOjgM/MEjwR+Oahflxg2RNW25r5u/gsZbLzvB6FpKKFgyYLTcFNuqEb91zWNnJBfXBB6/pR6TcUXFdVO96X/G3U4kvp/+3bLSzecYwVe/JIcih9vdNTZ0p5dEYmp4vLnNlOg82NOLKZ2uIwPLSwRDMulmflBalZPcg/U8af9e8v0TBi6fcGsqSUe6SUpcCXwDCPOsOAqfrxLGCgEEJIKX+TUjqyh20BagghqqYvoRLo0bo+e+1NTelrQ9oYUjzeGmoke0/qmsVdl0R/28lgBrOvsNs7Pl7ls64vK92hnC1CYNU1tqPsw6V7+Pq3XKYuz3Za7sEWua3fr82znC2LPxeKqKz1KZXEv3/ZzZzfcpm2Yl9lixJzjCj9lsABl/McvcxnHSllOZAPpHvUuQlYL6X0cmgKIcYIIdYKIdYeO6Y2wXDlyfL7yLY3BWvFs/KKktfC6itFaFZNgVXbotE1dU33lubtvfvr+KtiEhFUVBLY9eVL6bsu0np5wXbn8Zky774cCj7JYnHm+XHMByRbtfOjp0uc41j8/Dct2naEH7ceYebaHAB2HQns+/fHfzNzg647CJdIv66C4jKemLUh6HcSKSM+XMn7v+yOuB+b/saWFMZmEm8t2sWN7y6LWIbKIiYTuUKIrmgun/t9XZdSfiCl7Cml7Nm4ceNYiFRlWGnvQv/S12DgM1rBJQ+xX4Zn/f/y6OUASP1rj1b6hVYNasZE6XtG4HgS7E3gncUVyuNMibd7x6HgrRZvS9+qf8A9x4qc7h1/9/OPU9dy77S1Xv2GwpGCYh75MpP7P1sbvHIl8MEve5i5Nocpy7OjOs7y3cf55w/bg1cMgut3GyqvLtzJb/tPRSxDZWFE6ecCrlm/WullPusIIZKAesBx/bwVMAe4S0oZ+SM6UbE4cvNE4IMs18INa6YmUSvFSt/2jUwQzDfxkM/HVekHE+dMqbfSd6RmuP3DlSzfrfm6bXbJZyv38erCnQC0a1yrwtI3+JnD8ek7omrWZPsOxQ0Vz8lkh+yHC4qZtS4n5Mlmm14/0C2QUvKXrzbERfSSzeONLZEw8onXAB2FEG2FECnAbcBcjzpz0SZqAW4GfpJSSiFEfWAeMEFKWXXfh2JMqwY1GNu/vXuh1T03z12l4+Euz68hCO9fAUByUhJbJg6hcZ3qPb1ic1FcwVYdn/GRkM21/ax1mmvGZpc8881mt3rSaekbk6tMn0w+erqYjAnz6PrsfHYcDuzyMZoWpLjMRsaEeYybvp6TflxBm3Pzafvk925uJodrSkp4/KsNzpXKZvLmoiy+WpdDz0k/RtxXpG6kcC39eHhgRUpQpa/76McBC4BtwEwp5RYhxEQhxFC92sdAuhAiC3gUcIR1jgM6AM8KITL1nyamf4pqxq/jB/DEkHMB+Oiunkwc1tXL0l9ivwBa9wlvgKI82PGDW1HMN42JAdNWZBuu+8Pmw5wpLSffZXN2V4vcMdHpyzUTqqXveIPYe0wL9SwqtTEqwN6+RwuKufNjY3v/fr5Sm5j8buMhbv9wpc86/1mtrd5eudd3EjqAkhAnmx3PpEAP18wDFW8pJeWB10UE408zMiNq7/iegy3KcyXr6GlTHliVjaF3Gynl91LKTlLK9lLKf+hlz0op5+rHxVLKW6SUHaSUvaWUe/TySVLKWlLKHi4/RwONpXDn6i5NuevSDKitPytrNWHGmEv46bErwRJm9I2tBP5zG5
QG354wXFLi4LX502XZbufBrOkuzy7ggr/9z3nuqvRLdOXg6ZoRQJ6e7sH1Sv7ZMjImzOPBz9d5jeN4cLjWP1xQzEdLtfDcn7Yf4WiBtmr6ZFEpL87fwf4TZwLK7jqug+1+Pq/DVZRidVHQHs+yaBgBpS7hsp2fnh9RXwu3HuHv320Nu/28jYcA978Ru10GDFndd9zYdxDvVP5/psIYnYbATR9D/wn0aZdOu8a1tXTM7fqH3+fLHeGY5puuaS/Cio3W4gh1ifzV/nddQ9jEJUYMfn1J2G0d+/d6Kv28olKKfMwHOB4eP/hIz3BKV8yeHptJ87aRV1jCPVPWctN7ywG48O8Lmb0+x7CcRrxAjs8wfvYmZq3L4WypzVDG0oDjEtynX+ZjYVskfPzr3oj7cHXvTF2RzYiPVrFgi++UGuFM+sYjSulXFYSA7je7J2QTAga/4F7volEYpqwIVrwNwMyTw/kl9c8sTf0z/0sdH7aYDfVN3uPtHySceWVfbWwe2tBhMYLxlbYOS9uXNe24duCEd44fTx6dmcnbP+3i3qlrOHVG898b8f2XuTy4Hv9qA5N/2Baykvck00A0i+f92eYjc2ogXv9xZ0j1jZDk8rbjsORzTrrf+89WZLN01zGfk74fLNnN0l1VK8xcKf2qjud/a2GI3rNtc6FAWz/XSmivts1E4AiRwZY1tBZHvMqb10tj3v9dFtr4VYxAkTehBuX4UrShhAJ+vT6XV/63kx+3HeUzfZFRIBHWZJ+g3GZ3xqg7mLpiH4t3uP/dhPoMcKS1DvRsPXXGPfLMMbfgYN2+k9zy3nK/+X9e/zFwmuxwyHVR8A5DxfPh9Mx/t3Dnx6udE/AOpJS88P12w/Mt/ticmx80b5OZKKVf1Un1yJC5N0QXxtmT8K/zvIobUEDtVO89du6/oh3vp7zGwpQnnGUXt9EWey189Eqa16sR2vgxIhxL3+5D9/hL4wCh59TxVf2h6d5bX3ry237vh7KjK1+W/sacUzz3383c8t4KXl7ge6/gJ792z7kT7iZCnvd5/uZDzigiz4ggT+X+5NcbWZN9kj15obkXl+w8xux1xl1grri+kTqO/H12z+K5Gw76rBcq1731K1f/K3zXY6gopV/VadAG7pxTcT7gr9Diooi7rUEpv/ylv48r2l9+mqiw2mY/2Jfsydf6fEhUZXxt3FIQYLMVR6bOdfv8R8W44mtPX1f8hSXe+O5yfva0zB0KyYe+Gvr2MqbqbwIr9xw3JJvnG01Juc2QNbpi93HyCktYt+8kRwuKeeDz9TzgYzIbtMiZ9ftP8t3GgyzLyqPE5SHgCD2dufaAz7au3PXJah7zkf3UCHXSKv5mHUn17BKy84pYvjvP7UHucF06OOBjcn3EhysZ/NoS8s+WBV09XVxm8xtWG02q139potLeJXnpRaPgwjvhxB5ISoOZd0Geb+suEMvT/o8zWbWBNLfyJ1df6nY+5op24UhcpUmyCJ+hm0PfXsYPj1zOTf9eEbB9l+baXr6OXdr8EejtZPSn7quRHfMDwSz0DTn5tKgf/G3Mc0pm0nfb+GzlPn57ZhANavmXe/GOY86wxqVPXAV472jmYOHWI3yT6W0tC4TTlfPawp3c2rO1Vx0Hnm9XUkqKSm2GDZAOTWo7jx0ht3Yp6f/KzwCkJlXYxanJ7jayxce81fLd2kPVMZH/3PVduLuf76SGvf/xY0AjIlooS7+6kVob0upCix7Q5Fy48d9hd1Xzm3towknGWefw8eWFXrtu/fjoFTwxuHOkEseEvNPmWVSB0igYWW3reIMIpDxDxTFsAO9TSKS5JOM7W2rjMz3+f0kIk5YHTlZYwptz872u+4p6Ajh5ppT39Pw6wbxyby6q2JCmsKScj5bupdtzC3jH4EY1y7Iq3nwcOtz1OeL69uH5PPVl6Xvyt2/9h5VWhsIHpfSrPy0vjqh5fVHI48lfMXDNGLaOdrf6OzSpQ1IcxOMbwRFyaQRfq3ONkpZs/H
64WpG+8Od/98Wbi3aRnVdkyBfvK4zUk4Vbj/DrLm1i/8dtFZP2j3yZyZ5jhfxj3lZWB1jcBTDiw4qMpte99WvQMR18sWq/V9mK3b7dUq+5RPRMXZ7NJ8u0MM5g966lj7cdp6VvcEb+P6uDu54cfL5ynzOVhz9ufX9F1BLquVI1/mMVwfnTZnjcT3TDFX8Ju9uGwsWPO80jo/aHAysmjtd+AhtnVlzLnM6t1sU+++ws9tMcY75lswhlYvJ4DPbLheAyeS4uC8b6/Se93jR8be9ohNd/3OVMQ+250nhZVh4fLt3Lre8HdmOFy4YDFRFMB/O1VBX+Vhe7sjk3n0P5xUHrQeAtQo/62Vs5krDWp7/ZzIgPV5ExYR6ni8uYucb7gbF67wlDcxiRopR+daF+64pVu54MeBqe9369NsKXKZP8X8xdC1Ovh6wf4bs/w9f3aeW2cvjmQV5K/tBnswWpE1iR9nBY8oRLKBbUloMFYVv7RhRDo9qpnCwqDTnEMxgWIbyUfL/JP0XU5xOzNnhNOBdHeT8Ao6uPPfH1BlNcZuPF+dspLrNRVFLufHvxZFlWnnMR3FfrfCveYJa6PzzfHFbsPs4Tszf6rBuLTdXURG4i8eg2n+GZEbPDZUm9rQxm+l8g1koE9wkL7Fxu2cQS+/lEnuk9dPxFmxjBiCLPKyzhwr8v5OuxfcMexxdCwE/bzc1y4tgDwJVfq9DuWp8s28u/f95NrRQrm3MLmL/lMMsnDOCsx3zCyI8qXFGjLs3g/SXeO9b9b6v32hQjeC7oC7RxS7ihsqGgLP1Eom6L6PS7xsWiz/wCdsxzu1yHM9RBs95+TX0kaHcjrYuYlvIi11m0V/q6FGIh/nab8kUoOWvM3is3Vumsa6d524pD4iztRq0UK4u3H+WY7qp55X872amHnF724k+B3yb83MZgcxj+8HS5BXpoOnIuRRNl6Scaj+2EnDVgL4evQkjZYJRvvZX6prR7AdiZ0gVcvCydxAEOyXROU9OtfluhvaafX+8MC06WszFtDF+UD+Sv5X80X16TGf1J4I1dXDHbvROrPKlljoRtLhPRWcfMT8UcCUWlNq9NdioWX7nPdSwM04I3SijWu+dbQTRQln6iUacpnHcddL0hJsPdYV3oPO5U6h6+9r/U8WxKu5fstBGkUcId1oU0oIA/Jmlpn9OSrXQWWiTHyKRFXiGj8cjhECw1o1EiRvGXvsBsHG4O1+8jKwr5981G+HkTum+a+25k7//i7doxQqGfxXR//26b4T5i4dNXSj+Raeh7YdV1JZMoahZmrn4PJiV/aqje9rS7mZT8Kf9KrlhXkISN8UlfOs97ZTQM2k9dChmf9B9EFXAHuaZBNoP02imM7HOOqX0GQkptQ5jTxeZ+jqrIkp3H6PbcAp/XPHMMBSIlSBivGSj3TiLT4WpY/UHF+QW3c1v2dWw+IjjQ/3XO/fJS/22jxFXWiuX0O0sacJ6osJyn7h+EPVVwkHSW27qyyH4RB2RjUignU3YA4IuUF+huyaa5OM6fysbFXP5QGPNZ+BPGvqiRbKVlg9jlPpJSMvaL9VF3j5hFNN9GloUZ2ePJlZ2iv0e4UvqJTK973ZW+sLDyiPYKXFBe8eou67ZEFHhuixx9Ck4XIpI893KVtCKPW5N+4VZ+cZYfl3V4pGwc3S3ZAPSyuC/Oqc9pTlGbyogGihW3fRA8lt1MCorLq4zCjzah7jTmj8wDp+jfObqbCyr3TiLjcCBaU+CcvjD0Lecle1LF6lsx4JmKNo5tG2s0iLp47SyH6GHZbahuujjN5yn/dJ5PLf8djqnNduIgmWn384j1a263LnJz/aRSihUtfC9DHKKdMCdzoiKxmLI825R+vvIRIms2ytJPZJJ1V8C518Etmu+9S/O6bD1UQO3adaHZ+dBjJPS4Hfb+Ahv+A+P3VqRzfr6e9rv9QG3V76dDtPP+T8LP/yRSxiX9N+y2TyX/h6eS/+NW9ufk2QDkykbssrdyLh
DLtjdlim0wzydPA+CmkufYJttwqWULv9q7c444yuWWTXxiuyZseRQKI8Qi6lYp/USmQRu467/QqpezyOlMERZ4YGlF3Rvf035ceXCFlsGz643a+bnXwfbvtCyfafVhfvg7cEWTFMq52FKRsyXDcoTnLdOc57NT/+az3bPJn3FB8QfkU9utvJfYzkbZjhLMS6CmSEweuqpD1MdQSj/Radff7fSGHi3YdqiA5vXSfFZ3o2kX7cfBLVPh9CGo1xIueQDqNNPWApxzKVzyoOZO+moUXPsqzHvM1I8RCh+lvBp22w1pY8gons5Vlt9oKfIYZl1GL/0B0rb4c/am3QFARvF0rUwcYnHqYzxbNoqDshE/2iNLgKeo3iTFYJtRpfQVboy5oh2j+ma4pdY1jDVJywHkoOsNUHs+NO4MNfVwy6752pZUBYdg6Ssw7F3471hzhI8RV1vW+XxwOBQ+aEnljsgGvJisTZRPTJ6qlRdP4Q/Wpcy09ecqSyY/2i/CMbn8h4ta8vX62E+YK+IHlYZBEXOEEOEpfH+0ubRC4TuwWGDgM1oSuAtHwjl9kTd+wMtltwLwVY1b3aoPLpns3r7XvebJFwZG3hQWpE4gM+1+entEEe1IG80/kz9md9qdfJTyKg9ZK+Yt6lSznccUoRNorwazUEpfUfnc8wPiguG8Y7uBwSWTufz+N+G86wHoXfwOh6TLQ+Ppo3DtqyzsN51HSx9wFs+2XR5rqU3hL8kzqUsRVmzUKchiqGUZAI05xQ2WX3kl+T2/backv8hd1gW0F7n0s2h73NajkLrE/+rY6sAFreub3qeRTXgiRSl9RdzQqHYqO+Q5WKwCbplK/qM5HKUBZ123bExKBeDqq3/PiPu0zdl/s3fg6qfmOP3oAMusvdlg11YcL24WWc6eHNmILsWfRNRHIDam3cfutDt5fPco3kx5h7HWb1iTNpbXU97lZusSlqU+TAvy+GfSh2SnjWB96hiSKKe/dQMTk6eyKPUvfKGHq25IG8PGtDFu/Vuwc6HYxafJL5JGCRniEBniEK/V/ow2oiIdcTOXPQ6SKGdC0nQa4L1PcDhcdE6Fglz/zCDev70byYSWvjqUDWqizZ+v7sS13c1JMneevn0mxEbpq/dJRdxwe+/WvPVTlra/qcVKvbp1mHRDN22V4ptAx8HOukIIerZN5+qSl8iRjdleI9mtr5FFfwKgc/IRLm3Rh6sOf+w13o+2C7na+ptX+T57E9pYKlIUf1J+DWcwMLFtEk8kz3Q7bymOszzt/5znDUUhWWl3ebX7e1LFg+mXuxqTs30N/TY+xXFZh3R9M5zt1rsrGpTDjak/cGvJM8xM/TsAfyodyzf2yxhg+Y0Hkr7jgaTvnPfjN3sHHi57mBxZsWq0DmdIppwUyrjRuoxPbEO8ophSKGPqseHMSurLz+0ep2GtFAbPOZ9daTCqdDw3Wpfy57KxDLUsp7dlh1divU9G92TAuU0ByJjgnsE1XCYnfUBvy3YGlL5KOAv2Hrm6I1+s8p8iORSu6daMbfoWmkrpKxKKRwd14uEBHd3yj9xxSRvt4LEdWhioB1mylfP4Hzd2Y8g3k5n98JXwprZt3l57c65OtXJ+8QfUppiDNHLWt2Bnz7NXc+ToIZp+eCEAJQ/9xpWvbuXnlEfJsGirTQWSlvVr8OuN27isZBl8fS/y3kWIVj21Sel/toSy8Db+MJM7k350HreZOQj9zjkVvj8cCh/g9qSf+KG0t1vOI8cD8EJLFr+mPuJ8o0qhzJlBdYu9DV0t+xif/KV+XSKQSCy8k/wGdeRp7k5awK0jp2ub7OhMTXkRgBusy51l7cQhbi972nnuUPjY7cxLeZLXym/2ioK6oHV9tx23fHGlZQOnZQ1yZSNuS/oZgEGWdSy09/SqaxFaNs4McYjR1gU8Xz4agFbiKMUyFWzlXLn3dRpzMcfQFipmp41wtnd963RwhWUD5VhZbT+Xcl31PnF5I6489RWv0x07lpj49J
XSV8QNQghSkvxYXXV8v0pPubsX8zYeAmBknzaM7POgfkVT+uV2O6lJVgqoTYFHfL0dCyTXoGnLdhSN3UBqw1YkW6zANvqXvsa25wfx9MRn+Mbej90TBuitboHzb6mwDS0WGJ/N08+NZ76tN7XFGXo2T2XSqCHcOXkaW2QGU69Jo/lP/0crEf+bj/SxbGdH2uiAdbLTRvBs/Rf488kXnGVdLfvcrvuj1sap2kK/AFxq3Uq2VeujpNnF8PwIbd1IxuV0tezjo5RXnUq1MScppAbDe7amTe48VtvP5TDpWLAz2LKGBfZeNOYUzyR/xnVWbaOUP5Q87xzrRuuv/GS/EBta8EIqpUy++QL2b17JrzsO8VXqRABm267gDKksStW3Ht1Vn1bbP2Fi8ias2Hm4zH0nOCs2Z58A3cUepukPuEOyIV+WX8URGjD2yFbYv5xLLE+x3N4tJpa+MHsjh0jp2bOnXLt2bfCKCkUArnljqfOV+aWbzve7PV325Gu9yhwuhKx/XEOHv/7gt56vNgAf3Hkxv+vazFm2+qmBnDl7lrx3fsdM25XUo4g/Jc3m1tJnmZf6V5/97e7yEO23vhPkUyYuBbIGdUVFTvxTba+l/l5jrh/HW0lEtOoNOasj68OFU20GY89exqxLv2HMkF7BG/hACLFOSun92uKBsvQV1ZLLOqSz7VABjWqnckvPVqTXTqFH6/rcM2UNG3KM7RecZLXwwyOX+82T7gvXh8Pqvw5k6c48mtRNg7ppfNd/Ol3Tknlu7hY+tF2nVXLsXZy7Hj79PW+lP8Wb+9rwVpc+tG/dEpr3gCm/rxjg1mkw09ufv83emvMs0d9UO15wVfiAYYUPRK7wwVSFD1B/3wIQMGbl1TAkvP2sjRI/0+EKhYk0qaNNvN58cSuEEAw8rynptVP5x43dg7Yd2789M+/X0kqf17yuoTz+/mS46eKKOYdxAzoyqm+G78otL4KnD7Ov0ZWUkUStVCtc+hBk9IO/Hg4fbRoAAAnXSURBVNES4j2ZC12GwYT97L5vB22LP3c2v6ZUX8vQrj/cMgUe2Uj50HfoV/yGs87JSya4j3mP7/zvnhy+8BH++/u1lN+7OGC9KfWq1iK7REVZ+opqychLzmHv8SLGXtXerbyWvgCqcZ1U3r/Td0qEJ4acG1XZsidfy5Kdx3xujffU78+jZf0a9G1fMeFMchrc80PFeVo92resx97J1wOaVZj2zA90KJ5G1p1DnVm7rPXPIXfm92QUTyf7+ctokFoXhjzpPuDz+XByH7xxPgx9W0vCJ+3Q5Qb48Cpo1Ytm109kmGv9XQvhi5udXcy/eQeDuzRBrNzPX+cV849kH+Gtg1+ABU+5l/V/Es4fDg3baik6lr8J6R3hy9uh5z2wVu8nvSMUHoGSAuj4O9j1P2cXa5reSq8j7tFOhvnLHvh3Xyg8zFZ7G7pE+AYwqWwkTyd/EVEfv7R6gCsj6iE4yqevSDjmbz5E3w6NqJuWHLyyQfpN/oncU2eD+v6jxaH8sxwpKKGHx4Khd3/O4rzmdbnK7BztG76EH5+HdlfBjdpuZ1JKCkvKqVO0X9uVTQjYPg+2fasl6yvOh2VvwP6VcMfX2sPMF/k5ULdl4JSTpWfAVgo16sP+VbB+Kn1X9uEqayY77K2YNeo87eGRlAZXPw89RoA1ldKXOpFSlg9PHYSUWgAcO13CdxsP8rdvtzKizzm8cGN3+OYhrU39c+DsCe1zdBqiyb/1G7hzDrQfALYy/vbfjXy6WlvvkP3nDO1B0v1WaNqFsiWvM+j0M8weN4AH3/66IlJq9PcVbrsHl2ttgF9a3MuVY8LLDWXUp6+UvkJhAieKSjmUf5auLepVtigJy64jpzmYX0xGek3apNfSQkOFRYuwcmArg/ISSK3t1d5ul1iCJTyTEuw2Lc+UzvLdeYz4cBU/PHK520IrT977ZTcdm9Rm4Hl6CGp+LliStH2rgbLnGrKq+Qgue+Bt4x/aBTWRq1DEkI
a1UmhYS6VWrkw6Nq1Dx6Z1KgqsPtSbNVn78UFQhQ/a24dHv33bNzL0hvfAle6uRuq1dDstFilYbcVEG0MTuUKIIUKIHUKILCHEBB/XU4UQM/Trq4QQGS7XntTLdwghBnu2VSgUCgWUkEpSPCh9IYQVeAe4BugC3C6E6OJR7Y/ASSllB+A14EW9bRfgNqArMAR4V+9PoVAoFC6UkoLFVhL1cYxY+r2BLCnlHillKfAlVEzm6wwDpurHs4CBQgihl38ppSyRUu4FsvT+FAqFQuHCqbSW2NOiPydkxKffEnBd9ZED9PFXR0pZLoTIB9L18pUebVt6tEUIMQZwpAYsFELs8KwTAo2AeFzvruQKDSVXaCi5QiNO5fqlEWM/DleuNsGrxMlErpTyA+ADM/oSQqw1MoMda5RcoaHkCg0lV2gkslxG3Du5gMseeLTSy3zWEUIkAfWA4wbbKhQKhSJGGFH6a4COQoi2QogUtInZuR515gKj9OObgZ+ktgBgLnCbHt3TFugImJu0QqFQKBSGCere0X3044AFgBX4REq5RQgxEVgrpZwLfAx8JoTIAk6gPRjQ680EtgLlwENSSluUPosDU9xEUUDJFRpKrtBQcoVGwsoVdytyFQqFQhE9VJZNhUKhSCCU0lcoFIoEotoo/WCpIqIwXmshxGIhxFYhxBYhxCN6+fNCiFwhRKb+83uXNj5TUpgtuxAiWwixSR9/rV7WUAixUAixS//dQC8XQog39bE3CiEuculnlF5/lxBilL/xDMrU2eWeZAohCoQQf6qM+yWE+EQIcVQIsdmlzLT7I4S4WL//WXpbQztv+5HrZSHEdn3sOUKI+np5hhDirMt9ey/Y+P4+Y5hymfa9CS1IZJVePkNoASPhyjXDRaZsIURmJdwvf7qh0v/GAC0dalX/QZtg3g20A1KADUCXKI/ZHLhIP64D7ERLU/E88LiP+l10uVKBtrq81mjIDmQDjTzKXgIm6McTgBf1498DPwACuARYpZc3BPbovxvoxw1M/L4Ooy0mifn9Aq4ALgI2R+P+oEWoXaK3+QG4JgK5fgck6ccvusiV4VrPox+f4/v7jGHKZdr3BswEbtOP3wMeDFcuj+uvAs9Wwv3ypxsq/W9MSlltLH0jqSJMRUp5SEq5Xj8+DWzDx2pjF/ylpIiV7K6pMqYCN7iUT5MaK4H6QojmwGBgoZTyhJTyJLAQLX+SGQwEdkspA+1aEbX7JaVcghZl5jlexPdHv1ZXSrlSav+d01z6ClkuKeX/pJSO/RpXoq118UuQ8f19xpDlCkBI35tuoQ5AS99imlx6v7cC/wnUR5Tulz/dUOl/Y1B93Du+UkUEUsCmIrSsohcCq/Sicfpr2icur4T+ZIyG7BL4nxBindBSXAA0lVIe0o8PA00rQS4Ht+H+z1jZ9wvMuz8t9WOz5QO4B82qc9BWCPGbEOIXIcTlLvL6G9/fZwwXM763dOCUy4PNrPt1OXBESrnLpSzm98tDN8TF31h1UfqVhhCiNjAb+JOUsgD4N9Ae6AEcQnvFjDWXSSkvQsuM+pAQ4grXi7p1UCmxurq/dijwlV4UD/fLjcq8P/4QQvwVba2LYz++Q8A5UsoLgUeB6UII/zt4eGDCZ4y7782D23E3LGJ+v3zohoj6M4vqovQrJd2DECIZ7Uv9Qkr5NYCU8oiU0ialtAMfUpFV1J+MpssupczVfx8F5ugyHNFfCx2vtEdjLZfONcB6KeURXcZKv186Zt2fXNxdMBHLJ4QYDVwHjNSVBbr75Lh+vA7NX94pyPj+PmPImPi9HUdzZyR5lIeN3tcfgBku8sb0fvnSDQH6i+3fmFHnfzz/oK0s3oM2ceSYJOoa5TEFmi/tdY/y5i7Hf0bzb4K2p4DrBNcetMktU2UHagF1XI6Xo/niX8Z9Eukl/fha3CeRVsuKSaS9aBNIDfTjhibcty+Buyv7fuExsWfm/cF7ku33Ecg1BG1Fe2OPeo0Bq37cDu2fPuD4/j
5jmHKZ9r2hvfW5TuSODVcul3v2S2XdL/zrhvj4G4v0nzheftBmwHeiPcH/GoPxLkN7PdsIZOo/vwc+Azbp5XM9/jn+qsu3A5fZdjNl1/+gN+g/Wxz9oflOFwG7gB9d/ngE2iY5u3W5e7r0dQ/aRFwWLoo6AtlqoVl29VzKYn6/0F77DwFlaP7QP5p5f4CewGa9zdvoK9/DlCsLza/r+Bt7T697k/79ZgLrgeuDje/vM4Ypl2nfm/43u1r/rF8BqeHKpZdPAR7wqBvL++VPN1T635iUUqVhUCgUikSiuvj0FQqFQmEApfQVCoUigVBKX6FQKBIIpfQVCoUigVBKX6FQKBIIpfQVCoUigVBKX6FQKBKI/wcZg0B7GyPPqgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Compare maml_loss for task batch size = 1 vs. task batch size = 4 (each curve smoothed with a 20-step moving average)\n", "plt.plot(onp.convolve(np_maml_loss, [.05]*20), label='task_batch=1')\n", "plt.plot(onp.convolve(np_batched_maml_loss, [.05]*20), label='task_batch=4')\n", "plt.ylim(0., 1e-1)\n", "plt.legend()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2" } }, "nbformat": 4, "nbformat_minor": 2 }