{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "AEQPh3NtawWA" }, "source": [ "# Getting started with JAX for AI\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)\n", "\n", "[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond." ] }, { "cell_type": "markdown", "metadata": { "id": "lN1DEDeMel9r" }, "source": [ "## Who is this tutorial for?\n", "\n", "This tutorial is for those who want to get started using JAX and JAX-based AI libraries (the JAX AI stack) to build and train a simple neural network model. It assumes some familiarity with numerical computing in Python using [NumPy](http://numpy.org), as well as a conceptual understanding of how machine learning models are defined, trained, and evaluated." ] }, { "cell_type": "markdown", "metadata": { "id": "1Y92oUSGeoRz" }, "source": [ "## What does this tutorial cover?\n", "\n", "JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. 
This tutorial introduces the part of that ecosystem designed for AI-related tasks, including:\n", "\n", "- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.\n", "- [Optax](http://optax.readthedocs.io): A gradient processing and optimization library for JAX that comes with built-in optimizers and loss functions.\n", "\n", "After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts." ] }, { "cell_type": "markdown", "metadata": { "id": "z7sAr0sderhh" }, "source": [ "## Example: A simple neural network with Flax\n", "\n", "We'll start with a quick example of using JAX with the [Flax](https://flax.readthedocs.io) framework to define and train a simple neural network that recognizes handwritten digits." ] }, { "cell_type": "markdown", "metadata": { "id": "pOlnhK-EioSk" }, "source": [ "### Loading the data\n", "\n", "JAX can work with a variety of data loaders, including [Grain](https://github.com/google/grain), [TensorFlow Datasets](https://github.com/tensorflow/datasets), and [TorchData](https://github.com/pytorch/data), but for simplicity this example uses the well-known [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hKhPLnNxfOHU", "outputId": "ac3508f0-ccc6-409b-c719-99a4b8f94bd6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "digits.data.shape=(1797, 64)\n", "digits.target.shape=(1797,)\n" ] } ], "source": [ "from sklearn.datasets import load_digits\n", "digits = load_digits()\n", "\n", "print(f\"{digits.data.shape=}\")\n", "print(f\"{digits.target.shape=}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "lst3E34dgrLc" }, "source": [ "This dataset consists of `8x8` pixelated images of hand-written digits and their corresponding labels. Let’s visualize a handful of them with [`matplotlib`](https://matplotlib.org/stable/tutorials/index):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y8cMntSdfyyT", "outputId": "9343a558-cd8c-473c-c109-aa8015c7ae7e" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAHiCAYAAAA597/kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9d3jsV3Uu/E6Rps9Io3L68XE57mBMudj4BmKaMdiBkBiTQDBwyZMvlxBKAsaQQkkg+UwCX0hiSmiB+IZqWiAQO2DTjDHYiZ3ENub4HMunqGuq6sx8f+i+W+9vaY80TXbMo/0880hHR5rZv73XXuVd71o71Gg0Gtge22N7bI/tsT22xyM6wo/0BLbH9tge22N7bI/tsW2Qt8f22B7bY3tsj/8WY9sgb4/tsT22x/bYHv8NxrZB3h7bY3tsj+2xPf4bjG2DvD22x/bYHttje/w3GNsGeXtsj+2xPbbH9vhvMLYN8vbYHttje2yP7fHfYERb+aV6vY5jx44hk8kgFApt9Zw6Ho1GA6VSCbt370Y4vOprbM9968f23B+ZsT33R2bYuT9a5g1sz/2RGj55b/aLm46xsbEGgEfNa2xsbHvu23Pfnvuj4PXzMPdH27y35/7Iz73ZaClCzmQyAICxsTFks9nA/9XrdaysrGBxcRGVSgVzc3OYmZnB7Ows5ubmMD09jRMnTmB8fBxTU1Mol8uo1WqIRqNIp9NIp9PIZrPI5XIYHBzEjh07sGvXLgwNDSGXyyGRSCAej6Ovrw/RaBTRaBThcBihUCjwAoBisYh9+/a5+W42dzsajQbq9TqWlpZQqVQwNTWFBx98EIcOHcLhw4cxMTGBQqGA+fl5LC0toV6vIxQKIR6PI5vNYmRkBDt37sTIyAgGBwcxMDCAgYEBDA4OIpvNIpFIIBaLIRKJrPOSejH3RqOBWq2GlZUVLCwsoFKpoFgsolAouL04duwYxsbGMDY2hsnJSVSrVdRqNfT19SGRSCCdTiOXy2F0dBT79u3DaaedhgMHDmDnzp3IZDKIx+OIRqMBj7SVuXN+9XodtVoNS0tLWFhYQKFQwOTkJM
bHxzE5OYnp6WnMzMxgfHwcJ06cwNTUFIrFIpaWlgAA0WgU8XgciUQCAwMD2LNnD04++WScfvrpOOmkkzA6OurmGYvFEI1GEYlEAnLSy3W368/zsLy8jKWlJVSrVRSLRUxMTOD48eM4fvw4isUiGo0G0uk09uzZg/3792P37t0YHBxEMplEX1+fk/GNRjtz57ovLCygWCzi+PHjOHLkCH72s5/h6NGjmJubw8LCQmCNU6mUk4Xdu3dj7969GB4eRiaTCcixPYvhcDjw827n3mytV1ZW3FmdnJzE2NgYDh06hCNHjmB8fBylUgkLCwtYWVlBKBRCX18fstksdu7ciQMHDuDgwYM46aSTMDIyglQqhVgs1tG6N5u3ysP8/DyKxSKmp6cxNTWF2dlZVCoVLC4uYmlpCcvLy1heXl53dovFIiqVCubn51GtVlGtVjE/P+/+rtFouGfr7+9HKpXC8PAw9u3bh9NPPx1nnHEGDhw4gKGhIaRSKSwsLOCxj31sS3PnOa1Wq5iensbx48dx7Ngxdy6pVzjH5eVlRKNRdzb1NTg4iKGhIYyMjCCfzyOdTiORSKC/vx99fX3o6+sLyFOn666ysbi4iFKphPHxcYyNjeHw4cM4evQoZmZmUK1WsbCw4Navv78f2WwWu3btwoEDB3DgwAHs2rULuVzO6Twr61bON5Ibn7z7RksGmR+UzWbXbRofnod5cXExoAjD4bD34EYiEUQiEUSjUWdsdVP4+/y9/v5+9Pf3B4xys4Pv+97O3TfUIIfDYSwsLCCZTDrlTsFZXl5GrVYLbAznHY1G3VxpOJLJJNLptDv0PoPczdytsVteXkY4HMbKyso6QVeB1698Bu6JGj51nHwGebO5ZzKZgHJaWVlxkBPXlJ/JvaWCicViiMfjCIfD7vBwnnwu/m4ikUAikUAqlUIymXSHfbOD3um6+/aA67+0tIRIJIJGo4GlpSUnQ/F43ClSzjOVSiGTySCbzbZlkNuZO2Wjv78f9XrdKUTKNfcfwDr55Lrz/a1CsmfVntFerrt1fBYXF1Gv151yV8XJZ+L3Kk+pVAqpVArpdBqZTAaZTAb9/f0drftG+pHz7OvrQ6PRwOLiojOodOipN6LRKJaXlxEKhdzf1Wo1NP5vd+N6vR4457o3PAfUo7FYbN3ZTaVS6Ovra3nuNMjhcBiLi4tIJpOIxWJOv6khjUQibk6q9/WlZ5xnlueC70eZ6XTd7ZrX63U3b58upLOmMm4/i3PSv9Vn4vebGWWdc7PRkkHeaNgDsrS05Ly5SqXiPLqFhQXnCXKza7VawIgsLi6iWq2iUCggHA5jeXnZGcVkMhlQINxYNeBbMXyRuDoiANxcGRnp9/qc/LteDp8xZvRZrVZRLpddlFwsFlEqlTA/P+/mWK/X3fvwPTjfXszZRu7cZ0YB5XI5ICeLi4tOCUWjUcRiMdRqNUQiETdX3Yd6vR5Ybxp8XfOHa/giZI2AdP481K0YrV7PUeeqip/ywLUDVpVrPB7H7OwsYrEYGo0GqtWq1+mmIYjH4w7Vssqz27nbNV5cXHTyxO8XFxedHNDocX7WsacifaSG1S8agMTjcWfkVE64h0TpAASeT42EGpJu5Ex1tcqL6gpgzZBbPUTZYECicud7vk7mqfKhc7XnUV8MrGq1mrNZhUIBMzMziEajWFxcRCKRcM6COhWU9Vgs5vaAe9rp6NogA2veN41xqVRCoVBwsHWxWES5XHYKl5PmhvBgVSoVAKubX6lUnBdLD5YRBDeWnhUjtq1Qvs28Hj4zP5NKgQKYTCadMqZi0Pl1KnQ6rIJixLCwsIByuYxCoYDp6WlMTExgamoKMzMzmJmZQbFYRLVaDRhk7oUeNH5vjXK7c+daMTqg4JfLZczNzQXkhDAeAMRiMRe5EHpUw0bZoSNo4b9oNOoOZyvea6fDZygoB3wtLS25w88orlkKZquHdeJ4dqmggFXYWh1oIkeFQsGdQY12NOXBFAUAZ2R69V
x2jVWefAGAojBEfOLxeCCafjgdIos08HvKAwD09/e731fjrKm7SCQS2C/Crnw2DVyscW5nUFasU63nTJ34UCgUCK74jGqwl5eXMT8/j0wm4+RrI8SznXmqk2kDAJUT6hlNZ3C9uUblctmhgpQXRvJE4rLZLNLptNuDdtfXjq4Nsi7A4uKiU7JTU1OYnJzE7OwsZmdnUS6XnWKi8KkyJbzHHBeFK5lMYmBgAMPDw8jn824BksnkuoijF0MNj76nwniMygg79fX1OcVbLpeRSqVQrVaRSCSwsLDghFYjzl7OlweFBo9eHnPGzPnQ8BUKhYCRAODgGxvpU8g7jZQpH1SgdNYYrTO3PT09jbm5OVSrVWeQqYDU2SuXyy73o4aYERKdIoXCeUh6ue52WBi1UqmgVCqhXC6jVCq59QbgDAEPtxrlrRw2Oqbc2DXk70WjUedMU57S6bRTUFxjntVsNot8Po+dO3euSyt0u/ZW2VLBqjxRpmiYFxcXEQ6HnePO9ADRNnXmH85hjbJGVpr2isfjAaNSqVTcvPv7+zE/P78OBaCh0DQbz4CehXaGrruNMmn4qIdCoZDT8fxbzp8Rc7lcRiaTQS6XC5wJdVA7kZdmCJWihUQMaZAVYVhcXHR/V6lUkMlkXMqSTg5TAZlMBvl83qXfFG3p5ix3bJCt50SFqwZ5amrKRT4kWNTrdZc/BNY2W2E9Gom+vj6kUikMDQ05w2ahKIWeehUhq+IC4ITLQqU0aIzG6J3Ta08mk+65GE33co4WqqZBpgKdmprCxMQEjh49isnJSRQKhQBSwUiYOSv1gDVK7gb6VVhZ5YOOgZIAi8WiMwiESqPRqMvD8nsaaAuNUUlwzRWZ2Gpj3IxQR6NM5QkgkENjXss3t62Ys+YjVclyfzSy5FyLxSJisZhTTgr9ErZLp9PI5/NYWlpyxo9GoZcybxG1Uqm0LiVDo0yyIqNH5XLQsPn4LVs1fPl3RsVE+BRZ4KC8VyoV9Pf3O+Pb39/vdAsAhwLw+XxGuR2DrDre6geNkClLGiDR0FJPapqqWCwinU6jWq26gEZly8dzaXWuPkdTURTKR7lcDjifoVDI6W7qKKZIfdwDyjpJbDTa3cp6VxGyXQD1WvWQVKtVLC8vuwe3HhsXXZUCoYRUKuVY2QozqRfJQ9+LXKeFPuz3qhR4UJhr0JyWZU5uRT7Tzsd6g8ViEbOzsy4CLRQKWFhYWAcx0eHQnIvNxepntjs0SiaURW+V0YwiCVQ4PAyElChj6n3TqGhOqpc58M2GNXDqGBGW5/PRICvBxRLOVP58+bStyMXqvOnYqHwo+VJfVFLJZBLZbBYrKyuIx+MYHBx0z6uOaLcRskUhKEdUssqRUMY4ZYkEy2Ykn60caoh1/3WfNQBQ+JZGjcQpTX8o05pOLJEAhVq75drYFAdfatAsEZAolpLU1HluNBqIxWLI5XKYn59HOp12ctfOnjRDfnyRMo0zK2V0/mrE6fwouZTyk81mneOZy+VccKN6stPRkUG2m8MH0SS+wkYrKysB+CoWiwVyw7FYDKFQKJBznZ+fdwo1Ho+jVCoFcshkq1omYqfPYzezWT5VDZV+ts236ks/p9fDeoQaoTNXYo0CczbqDKlBUIKPGrVerLEqaFVAJKVQcSkkCsA5ZTpXDhvdNPu6VcNGyAqvMzrW3Lcvf6yOX61WW6eQfM/Xi3kD69m7fAbulTUkmtdMJBLO0UgkEu7Ma5qGCrZXsDVzkBoha3mQOvTAWhkXoWolh7bKju10NCNskaxItNCmymjc+P/MzRKZYHTM3+P7JpNJ5HI5l9brlUH2PZc6DqpLKNv62TTKDAiIgmWzWVQqFWek4/F44L1aNcwbRfN8MViikabccn/4b2ANldAKH9qner2O/v5+lz6zKcluRtsG2QeV0hhrDo8P3mg0HDSkDEzS8dPptCsJmJ+fBwD34ADWeVb2/buFgjeCZPS5fJEvf5/z51BhtZBYr4adN9dH6x
UJP1rWLCNOPZxqfH1wdTeCpgqd8BnLlFTpU0mpx6+RJ+WJ72lL5xQG9kWfvR7WiFpCFw0ylRANsWW+Wvmjw2Rzjb7StW6Gz8BbR1sJcYpuAXClNtbJUgNvndNO5q1/r+kx5o+Z/igWi4EKAs6VUC4jR+1tsFEJYrdDEQ7uPSMtBik2oLDRnkLB1pHlPujzsY/A4OCgK23rBRKgZ9ieOY3yfSS0UCgUsBV0qOLxuAvcqK9YC66j3WjZF81rgKFEOKurNXBUQhrL0SjzWjVkA8JuHM+2DHIzY0yIjgZBDSYjYtajsh6OuZxkMoloNIqVlRWUSiVEo1FnCAgH64Za42jzhL14HmvY7Evz2Zo3sblB651uRT5QHRaFSbWEiHMjPV9LPwAEEA4b7XVrlDU6IDuRdcma88rlcgEh132mgbP1m9bJYwSkubOtioCskbDEOk3bLC4uuqjfKjYqMsofn0vlxefcdTvUyNqXD+XhvK3jqYrZkrjUqKih6dQoW9IcCYFTU1OYnp72knVsPX0mkwk0pVD52MooWeW/Vlst41MWu740KGC0r4GB6h7mMLPZLAYGBpDL5TAwMIB8Po+BgYF1vQ/aeUZdF1vaFovF3Nx5DoFVJ4H6PZ1OO8b40tKSOxM8x5b1TJjY5vQ7McbU6frV6jGuiS0VU6fZzsNG3r0KWjg6ipCVycaFVWOs0Gg0GkUmk8Hw8DB27ty5jinNpg9LS0uYnZ0NeI2kzdMT1LIeC4t1a4wt5KtwrxJFCIvRGbCwq+YbbMF7rw699aLVEGjZh+YsaQy0fk4JUyRYaD62F3lYPcwkBlEhJZPJgNdMhUNS2szMjFt7OkNUtoyyNY+pLFotEdkKhAIIpm589d+M3paXlx30zgjAQnmUPwCB+aox5nmy+bp2h0bGWg5jo2/ddxutK5RHh8jHXO5FhMxhDbLtfKV8FcLkdPoUkVNS11Yy3PXMcC6JRMI5ZxpMMGrT8kAtSaNe8pV0xeNx5HI5jIyMYGhoyBnnfD7vYOtuHFM1xlr9QmefKAlRLkboAwMDSCQSaDQaKJfLrpSO+pM2Q0vW6DzYGup2YGtFFqwxBoI6ScmJdCr5Vd9XUcaNjPzDDllzQgzjfRGNFnxHo1HHlN61axdGR0fX5TgAYGFhwQmplr4Q8mZZjnonnURwFlqzkLhGmXzRyGmuQ4kVQNAL9tVJ9xq6ts9hCQxUSjz8SvCgAiWphGvLvfNFSN0MPdDxeNwZFq1h5Pzn5+cRi8WwsrKC2dnZAEGKypZeuSoI+9oKR4jDB4tZtEhliGQnW+vIQ8+zxHyhrcek1w4E84udIhZAkDi0WWTAv1PYWjtC6Zorc1nXqltZUjTC9jsgZF0qlQK1pQoT01mzDttWQdZWQet8ePas7uI55nkEECDLqrNNkiywVo2ikTGdj0Qi0ZXTYREu3W9F3wjBEwEbHBzE8PAwUqmUy7kuLCxgdnYWwBr7WtOQ1K3aDaxVDouu90YvYI1USQfJV5tOp1KdQO6JGmTuW69GR5C1heksm1iZcgpT5vN5DA0NBXqZMkpjhFYsFp0Hy/fWjj+aF7CeT6vzt3kym/ezZRRKiiK0q88IBDeZHtdGdaa9MhDNvDOFdBOJhFs7KigqTgqczW12Ahn5hu9AA3CMUc1bMx+/srKCmZkZAEGFRFSCBsO24dP13uqmD81SN7YhCCMdGl/LsA6F1tidjUaQbKcG08Ld3UbIFoq0ToBvzez/+4yyyrtdr07HZoiQoicK5SpixQhela912Hp9JlVhay4YWN8whc6zflW9pKko6iH9fW18okzyboMBlRFF16ifVf/q7+m61+t1NxdN0ViyrA2yOiVJ+Z6xGfzOVAYRTZVdzpPrbCNijZofMchaP5xCo8xjKhVGXwoZsRickDUXgEI1Pz/vvNhEIuGiBi4gP9sn4JstRrNokopUYcZisehqqefm5hxkqpGxJd/Q+dACclt20MtDbyMdHgKSpfhsjEKZClBFoG
sJrHcqehXZ20PAeauAM/KhUSZES+VLOeM+qzG23Yi2CqIGNibTKfSm0QwAhxipAWPUW6vVHIynsLEqOP49/92JArDRrr6/XUPf+ilfQmVOUSE6ehZq7HYvfIiQnl9CuERPFEFRRrj2Tt4KbofP6bcva3SAYLdDNcDKRbCsfUUNrfHqxTNZOVG4mmeV51Oje414qd9t7rtX6Judq031WH3Af2uvbyIKdHqpd4hCWAO9FYaYoyuWtfUAgSDRIxQKOa9Uc3y2MJ9CqsaMm67euHqUOp9WhyVA6Uvhr9nZWZfDZEMNwtU+Vh2AgGeonvhW1DtSaNQR0DXj73B9KICaTiCkrSUuJDioM2EhyE7myq8KvapB1py1dRa0cxijYzUkSo7i6PVht0MjYzpzzHX7EBWunbJheYlEKBQKGBOula5Xf39/IF9nc1ztDHXk1BjbzmacM/fCDkvyUVTIV2rTC8TFRibcA0VaqCs0aldjbJtlbEUaSZ1+DVi0BEcbaqgOZQOQ2dlZzMzMuB4ClCtt+RiNRp1OsmRX1sl2ywNRHaOyy8+Yn59Ho9EIVJ8AcOe8Wq0CgCtNU36RRYF8AUA7+6LOg8q21RPcdzaz4Q1mtisgDbPVUdb+9VLXdNWpS71AhS14SLWLiRopPRQW/uLG80BRIfhgXx9Uu9F8NeKyXVvYe5s9n9U486o0ZVdToIAgW4/PSSeEhs0quV4Mm5tNpVIA1jpBaaQQi8UC8J62s6NDFAqFAuUTZGh2yxTnQeE+0tgooY4oiQq7KjIqXd9hU9jPpjPUkehVlGajGdu7nflMRsiss6d80CkNhUIOlmxWq8y9pcNEh7VdJaDPrqiKRXVsHk1zlfqZlAXf31vGda/hYGDNuVb9o8pWoVWtP94qJ1nnpyRRTXNp1YBWiCihSw0yryXV1sO2nNSXJrHd1LQss1WZUYNodQwdyOXlZZTLZdTrdYcQAXB6cn5+HslkEqFQCNVqFXNzc6hUKi4XqzKoL5++2UyGfKiPOq9aAkcnlwZ5x44dyOVyiEZXGzyVSiVHrtN0muorX1noIwJZW+jIYv1UkswdqGHiodWF53vaheTvAnDRQTeHp5kipec2MzPj+m+zhELhIntDEucMrKECmjtRY7xVkLVCh3ozjBKG9PcBuMNMp4T/pkOhBlk7GnUDP2pUxjlzLxQi0uFTIPr39tDaqETX24estDuszGtHLkVXmOKwN1epw0aESIktTIvoHvIcsBwqFosF+AudRjw2+tYUC78nDEkDwzmpgdXIwxrhXjufOmx+1hpkjZDVUdisRl2d7G7mRnm2vRmYxrCd6dQBXV5ediVdvBRmbm4O5XI5UIvPPSBKoF0OFxYWXOOkbmQFWE9W1ei4Uqm4s0x9SlIu22Sy0cfy8rL7OR1rdRq0EVCnqTJfGsZXksdn4eUQJMJFIhF3BsvlstMrClUTmVC507VVvdWJLHWdQ1YDzQXhIa/X6+sORTNYzAc3KJ1eN6YXkJdCjra7lb060kdaA9aXp6iQ8TmtJ95LaIzDRsqKBOj6z8/Pu7wTYRnmfRgJszUcI2TbX7bT+WuUbPeT0Z/1yjUvzuhRyTiUH6v8FHnRz+nUSNjIWBUgkZZCoeBeiqbweXV/uCckcrHHN5Wuwq6MUPl9L4yxlRcLQ1LWOQ86S4pk2RSSdVj0rBANaQZ/d7IX+nn2GS2/oJlDbCMbTQV1OjdgrQpFdQwNJeXF9pPn79KgaeqMUC/lhWdDn8eiSr6yxXbX3hcla5kQkT+ePzr4fDbtvNVoNNwZ1ny0nodOm/rYs24hcH0/60Rqio66kb/Dwb3UtaTcqUOl/9+pDLXNstbvLZZuvalGo9G0wXkzOELzCj6l2guDphvXzKOybF373L556PttFclIPTCbz/C9NN+mcIsyGuk10mPM5XKuNE07/XQ7fwub8quuF42w9kheXFwMKHVFVyziQcKTzfPz99uVo42cOH4mI2NGMoQW7ZnQKIDrSRY52f
1shsP8H5VEu+V9G+2BZZjmcrkANE2jRqKZ7ZBG5UkHS/sDaK0sHUD7+d0aPBsQqINi84Q6R83fKlFHz2i3kX0zHWllR9suKk+CBo23bGmpJVM+RMS0jIvntFm6oJuza9E4G2To+nLPqX8WFxcDv09jzMqbTCbT9HaqTvSm1cE+joRNc5HgGw6H13GE9Nl9+9pM/3Zqr7qKkG2UrB50PB4HAC+JQhfap1x6hcfbobkzsuyUXKF5SipQYC0CU69T5+r7fiuGXXct29LbVLSPMtmaGvlbuCwcDiMej7vSNL60048vh9/J0L+3qIhCSWx0wANNY8U94KFnI5HZ2VkXzWluXId1rjZ7Fl1rXWe9UYu3VenNZoS9aPxsGRy9b+ap2OSiUCg4ljkVrI0Au117rncsFkM6ncbg4KCLCginp9Np9xy2zA+AQ8G43tVqNQAHMlVDyI/ODDkhXNtOhz0DNspV55GGkJBxf3+/c7B8EZk10J2ssX5vkSAqfy3V0ooP7bjHG/IYGashprM6MDDgmnDkcrktbYxjnw3wnxH9OddZu6VRxlgGy3lrOVonLP1mgZbmjlUHsHvY7OysI4zyPGqeWx0c6hS1GVq6RUSI89hyg2wjNDVQNMiETUOhUCCXyofzCcdmXodCnp0MfiaNcTKZdPAPjTMFhgSd2dlZd4C16YldCx2+6LRTyMg3rDFmDpJKXRua8Dl8NdVUslyPTCaDoaEh7NixA8PDwxgaGsLQ0BAymUzPYGsOhZC5BySe8GBrJ6LBwcHA/AmN1eurzern5uYQCoVc/o23x1goHFjLr7cTIVulzmiWLPyJiQmMj49jbm4u0MWNqQC9f1dZ5IwoqBQKhQJqtRoSiQQikUiAO2HRlk7WnO/V19fnLiFgGiCbza67pY3Qoy2h0zPAPDnPB89qKBTsX5zNZgNKqtOzoLCwRS/sszJy4xqzaxTnpN30lCTI99Cv7a61jdLU4PPzbfctZUoTgdE1VSeKHQ95V/zIyAiGh4cdsmUrWTqJlH2IhC8l4fu55l45d3bxGhoaco7EyMiIexblrLTrSFiEVaNjrQJgJ7BGY7X3xczMjDtzRNe4P8Bal0NeNEFbZ1EzZZnbgLOdNe86QtaXGjz+W29W0dyefT99Xwuz0utotgmtDIVcGPlGIhFXI72wsIBsNovBwUEHQ05MTCAUCrmISC89t1GyGkplflrSW6djo8hYoVMyfC3zV/vF2rt52Qt3aGgIo6OjGB0ddV53Npt1NbO9ZKSqUdaf6ftHo6td3pibZSQaiURQLBbd81NmqHjp4fIw+SLjVp7Dwo2ElnVtp6ennUHmfc4AArkp20KV8qDRMSFvAE5BMFrzpXo6NRQKVzPqSiaTgaiMLxoLZQf7+ouT3EMniGdNmajKC6FC7HT4DESzvWM0WiwWA2UtVKC2UYjOizqn1bXeKBWmnBlbMqfdABkR0/lnxEmYd2BgADt27MDIyAgGBwfdK5/PuwslaNgsm7zdNbZrbdNe1ijruivvgLyIeDyOgYEBjI6OYmhoyM07n88jk8msu4WrE1n3OUNa+891YU6fhDTqCZtaJReKSJ3C0XT4FN3QippOAsiuy57U2OhiMCK2uP1mxtiSE2jYGM35RqseFIWScyPkRs9Gr4QkXb9SqWBqairQ8MAygnU9rPDaCLldj8m+v4VP9Q5qGgkyxwmpsuk+hZA5QTonZFUPDAy4yJjF8jzcnRzqzYbN2emLssNcEwv3qWQZQROipqJlnSMNTTqdDjgUGmX6nLxm687319aYPnY1P5sKSPOZGm1TEdh+6XxubeKiiqIbToKiRPw315jGIZPJBFrGUma0dWylUkGhUHBIQLVadblwRkMk82jOk/2Pe5XaUdTJp0eoLHk+iDpobTB/h4bZzk3Jg+0YZh83xRpmm3slRM35ceg5ZdBAozYwMODgal7co1FmJ5C11cU2ANB6Z60ioLxaPcj/5zNkMhk3b+paW3LXqTH2rb9Wvt
D5pB6hk6kkR6Zu1HmkzqGjryiXprI0h98MDd5odGyQmy0IN4VwoY+gpcMaGsXj+QLg4AT9LP3a6vz0cGmuWBuT9PWtdpOqVquBm2E2ciqs8KpD0a0S8ikZaxgYZTEi1pIuey0dYWI1xryJiw6JGuNe5qHs0FSENciaf2WpgbbRBOAMMJ+r0VitzyS0XalU3I1i+gz8zI32xaZnbJpAiUzM0VOO+vr6nOOje1apVNz5sI1EeKAp83x+22yjmz1QpIjG2cd4jcfjATi1UqkgFou5aFhrULVTFo15qVRCMplEtVpFMpkM9FfvVfrGyiMdHjW0dDSKxaKL2qlE+f90SNi4SN/PRsmbrb2NkpVQpHupqAkAN29Ni9E50xId9qvmiwRMnlnte9BJlNkMibMXp+glO0AQ2qWxsqkWXZNmxNdW1nizoXugzgw7RRJWZj8GLWUNhdZ6/2ubX4XjgTWSaK1WczpACbDW+Wp1dGyQddF0w21OR3MMzaBuhT1sNxvCHWpILSOy3Xnr3/PQqfBS8WvOoZlgN/MktXykW4asQvmaM6Yx5jV0fBG21kjHd21hIpFwB1oPta2h1ihhK4auqc138bPp0Wo7yv7+/nUsVLtG5XLZlTRwKJehldFMbi2ER6MUDoexuLjomMqxWCxAPOP9yDMzM67pAx0mPrfCbAp7d7sfPiTCevTWGZqfnw84A7VaDZVKxSklynwoFApcEqDf23PQzbCywX+rY88rLxUybjQagatVy+VyoIlPNpsNENjonFNHtLrmNkLjXpKjks1mnXPC6LhcLgNYcwLJP9BrI5VsqVGxOs++vgfdGGM9S3T4KbfkTDQaDcdLCIfDjoQGrLWK9Rkwze0qT0V5Bp2gQFw/jXq5fjyXDFL0og6ialxXEpPn5+cDnA463QBcfXWxWHRr7ktVtCrzHRlkjYRtWK45t5WVFcRiscBhVDgDQECpqdfKw8wyEIXZbKTTCbTBDdeFqtfrAXKHr4TAN2yuUXMKvSjO1+hYryfUpiYkF+m9sHovMo0xFTyjYd7KMjg46M3jbAVU3Wyoh2zhZFUQZCInk0lHsGJLQa6X5g6JrqhDojByO/OyL/1/YO0WG5UXjbQJj9brdRQKBcfOZqTP6IbKmIxUe1lAt2iFPTt6Jvh/jOoYTWteVhvG8Lm1pE67VNkopFPHVOepTgR1g/JP9POAtfI4za3ye67z0NCQc6r0xSh2szW3jg7lzjpwdNrIqVlZWXHXE9JxYFRH42vJW4ODg4Fb82wTonZhdg6rxzRFMTMzg/HxcUxMTDhHslwuo9FouGsled8zEStC1cy/ssaae6I8CZ+O7TRqpoz09/cjlUphcHDQ2ZJcLuca8ehVlprmymazjtBL55MOHt8nFAo5prbutT077ch62wbZKifr0dDro4dKOEsPpRpmzeFaY6x9T+mRqLB3C6Nu9CybQe0cXGwLJ/PVLEpuZ866riQWkQXO6HhiYiJgkJW5SeeHEAwjY+agSBAh5LUREaTTHPhmz8evCl1zn3UNGK0QwisUCojH4+7uZMoLGxUUCgUX1SlUT+O20WHxoUAbEXb4+2QdKxeCclEqlVwXO+4jc4d8dm1TqHf4KhzZrUHW5+I+c23UINv8N1nLihxZo8wzrV/VGPVivjTGmivl53EdNcfJSDQWiwWgVJJQc7lcoNyFQyNW1XOtzlH/rTLNXGU4HHaVAkQj6DjrHccsERoZGXHnNZvNrmOKbxY8NBvWcbCd6Obm5jA5OYljx47hxIkTrlkJkVBWEmhwYvPHTNMQqaB+p6FT1Mo6X+0M/TsaZBKD0+m0g9uVtEjiIXPNzCFzzuqUERWgfSsWi27dADhninvTjgPaVYRsc8RA0DABWGeYNC9MI2ONmCUNELKkkmhWPN6tYbbP1+rQCNZGyDY66HQosUhvpiLTl43oaZhojNUQca00J8XIWC8y5zMphOVb414YBP0MNcwcCp+yVI0HnM4DDTCVLgBHNmIUAi
BQx9nOnlgjoHlB7fbDw6fpFy2NYFkWPW+tF1ckSBujaF2p5vO7RS2sE6TPSSdIDS0ZpHr2LHtd91LTUJYJ3ek5sEraNlrhXNUoq3HgGvP3+DvxeNz1LybJza51q2iRQqa6pj5HjvDuzMxMIP8IwMGsLA2iUR4ZGXHVDyRwac64mzNp947nSas46PgXCgWHHmhZnzpvKv9cc/I/+JX6yJJH9dUpmkJ7QYJuf38/MpmM040atBCOVj4FuRI0zIS36ezx/3mBhn4G37Ndvd+WQfZFlFZQNQ/MXKy9hYQPyE2z19Up1MVcDpWg7X3ay2ih3aGGxOZebFTQDVTXLLdjuyNpCYU2/wAQKEcD1t/LCwQdJCozH/FCDUI368510TI3XR+FfKmEWYbAsbS0hHK57CI7vi/ljJ2DaIw1n9nKflhZVzmkcSchjrCXbbrCddXWmJw7zwX3xLbzs5exdANHclhny8oxgIBC1PNl8+jcC0VOdM26macOfW9NKdnyHmAtvaF/wxygEnS0LSnTa7lczqFFWq5Jhn+rMsM1tI6szXNzby3JybKqlcRFoqklcHUbmHCoUdaSP6bB2JEuFAohlUohHA67XLflFTDFpAxt6nginzxD6pTY9egEsqb80VHgmiocrxd9AGt3MfA5aLNsYx+mFhQRVjS4U73fEWRtoTtdNCpYRrd9fX2B0gl6E+Fw2JVKLCwsuHyo7STFw8R6WVLX1YvtVU7N9/etvqcqNhshdGqIm83Rvih4Fj6nYlKjYDtOkT1LgQuFQk452QhEI0QLV7Y7rCduc+0WeaFcaY7TwqUKmTYaDUfgiEaj6w5Lq3uin2WjV5aKlctlt2aVSiXwOdbJ0Dy3okVqXGiErfPZizSNde6s08jfUaOlJXYaVTD3RqfP6gWVnV6lmCzhjNAvoUcS/my0znNgIVnuQ19fX6CrHS9GUFZ0O9GOPiPPH4fKvYXxNbJjbttWPijPw5YJbcXwyYM2wGA0n81mXVrMpu/YvY7/5t8mEokAK1/lha9OgxmbKqBMcu153rQDoA7lNumLPweCF97oV5XzdmS+Y8g6HA7jH+77B3zgrg9gamEKBxIH8OLsixFvxAP0/UgkEqi15EUBANZBIqyhJeO00VirlWU+zfZZ7qRG9pYjt+Da71+LHx/7MY6Xj+PzV3wev3T6L3WyFADWw3++yLkbg8wN/as7/gpfvu/L+FnhZ+gL9eHk6Ml4ZvyZgTwSDylJFFQGPlIYowZbIsJ8Jd9Xr3NTlELJPpuN6350Ha67/TocnjsMADh75Gy85SlvwcX7Lg7kcqjcaQB1b7U0hAbBcg6q1apTeIymKHPtMn3V4VFj/MH//CDed9f78IKdL8CvjP4KgNUcml4aYJs7aGMNzXUCcOup116S/WtbCbareN/27bfh7Te/PfCzM4bOwJ2vujOgZCysbOVF69pZVqe399CJUkaxvpQh3o6COlo8iqtvvBpfv//rqC5XcSBzAH/wmD9AOrbWgpFOkJajUcfY/eRXzZcCcFGgVjBQ/vk87ThyB953AEcKR9b9/KqzrsKbH/vmQEc9wrdMDYXDa9dHsl+17b5lneJeGeNQKIQGGvjzH/05PnPPZzBZncRQ/xAuSFyAc0LnBJ6fiBVb3Q4ODjpmMo0v9Tvh3VKp5OQtFAo52WInNcqIpoUYrbYySosl/OG3/hA33HMDJioTOH/n+XjfJe/DE3c/MWCY1aEnqqX8By3zYmMiNh2ijuFc6TzZQLGTfek4Qv7i/V/En/zoT/DHT/xjnBo7FZ+47xN4z/h78Nrwa91BJjzHRSe7jqQW1mXSYPN3tM6UUCOLyUn7p8eobM9WH76yVMF5O87DKx/3SrzwMy/c8Fn1mW1Uqp+nh9TCer7faWfw835w7Ad42Vkvw8n9J2N6dhofuP8D+Gj1o/jfmf+NbCXrlCQFnn/LiJEOUKFQAACHTOhNRewrSwhKvXIaaX3vVtd8b3Yv/uyZf4aD+YOo1Wv4+J
0fx69+/ldx04tuwp6+PU7YqUg1R6jELi3FsP26+fz1et3BVLwlydaEtxsh0yDfMXEHPv/A53HGwBlIJBLYu3cv0uk0RkZGAo00rJOgl1Bo1EzFw25pLD/T/KCmFjqBJc8ZOQc3vuzGNT5AI+TOqDoJCj/TqVlYWAh0SWPJi3IVdM80slP5UbJgq3OfnZ/FRR+9CBcfuBhfufIrGOgbwN3H70ZmOYPIcsQ56Nq+kxE8n0eVLhUxo2Wdh0KzTP0QTrY971sZP/rNH6HWWEvF3HnsTvzS534Jz9j1DFc+pJwPdjkD1rq82XW0BrnXUTHl6n23vw8fv/vjeO9T34s9/Xvw/cPfx7v+412YD89jIDzgIlg6qESK8vm8I6pRvrSkizXsChFru1ZFPJWXwX1sZbzqK6/C3RN345O//EnszuzGp/79U3jWp56F//jt/8DuzG73Pmrk9TzZjnxKaJuennYVEYTro9GoQzF4drsxym3nkPm67o7r8JKzXoIrT78SpVIJv3f67+HWqVvxk8ZPsL++3x0QAIF6WQAONmCnIipSKi5CSUySa3cX0v0pnLZOuJVx6cFLcenBSzd9Rvsz+70d1hD3CqoGViPGL/7qF9ccl9gc3hx9M668/UoUM0Vks9nAxQb8G4V8ATj2e61Wc7fKqFFmXSOp/zQS9NT53gpdt/KMl59xeWCN3v7Ut+ODP/4gfnj0h3jO6HNcRzHKDHPeKtg0yL5Do8xJRmxkdna6DzZCrixX8OqbXo33PeN9eM9t70EikcCuXbswMDCwrsWkXudZKBQwPT2NRmOtDpbvb40xy9F8teCd5gij4Sh2pne6iJCpIiIidOLUUbAO3NTUFKamppxS4n4xQgbWLp2wDHHtUdzOWf3z7/059uX24aPP/6jjBYz2j2JmZgYnlk44GWWKgMZUe52rnCoJrVarOTiUc1GjzP1jU5N2eAcAMJIaCUC9X7v/azgpcxLOSZ3jqiHUICuBVWvQldjXi9aSm41QKIQfHvshnnvqc3HJKZegXC4jsy+Df3rwn3C8fByj/aPryIdaX81aZOp/psHY14HnmAEX02bFYtE5GyR5tZteml+ex+f/8/P40ou/hKee9FQAwNt+8W34yn1fwXW3X4d3/OI71ullSx4jp4lOU7FYDNRes5KDxEySxtSZZqDYCSLUEWS9XF/GnRN34jXnv2Ztc/r68Zj0Y/Bg6UEcCB9wD82IhlEMANfFSFmm9KZIQWd0QyiAxfGaS/EVwfdiqCKm4Ns8GA+2em6aN1Zo1AeRqiJoZz48rMlkEvW+VUU4mhlFeDHs4H4ebC3tUOYpv1diA2FW7TpFRaTPwDkQ5ejE0K3UVvDp//g0qstVnJc/LxBFUsETGWGU7FOaRFc0MuVB1yheuQ7tRpm612+46Q147mnPxaVnXIr3/vi96O/vRz6fD0C/augo8/Tyi8Wii9zV0KvibWaMu4EmfzrzU+z+i92IR+N48p4n423/820Yig65/da9VpmlciKbn0qJECNljXJhWxRaQ9KugvryvV/GJadeghd99kW4+cjN2J3ejZef+3I8f+/z3ZqR0Up5IEROx8MH7aoskOhl0yIKYVpkpd2xsLyAz9zzGbz8jJe76Fv7zrO3vC83rsQ+H5egV0PPxAV7LsBH7vwIDpcOY2ffThyqHMI91XtwafTSADdACW+a8+UZIwtb5wys8T1sAyW9oMHWrQObI4wr9RXUGjXEo3H3s0ajgXgkju8++N11ZbeWhMvzoDe5KVytl/LwmZuVJ3baQ7wjyHp6fhq1Rg070zsDrNN8fx6HQ4ddToyJf5K8KpUKGo3VEhzF6NmCTannkUjENbDQqM3m1nqZQ/HB0lYA2WbTZ2BVkdFwqBB0EzFb6DQWj+FvD/0tHpN7DM4ePhvHl4+76FjrHHV9aYiV3cyIlM+g6QaF/Pj59Ip9B2azcdf4XbjwIxdiYWUB6f40/v55f49Ts6diYmIikMOmkqSHzQOuEZ7CYQ
qbUuGq965GoRO4LxQK4TP/+RnceeJO3Pq/bkVfpM/tRTqddnOinDMa4NxXVlZQLBZdhMChBpnGmEiEHupumNVP3vNkfPz5H8cZw2fgaPEo3nHzO/Cs65+Fb13xLdSWa27NtbWqGiVGCzMzM05BseEMz4FdbyVf+ppWtPoMh2YP4brbr8PrL3g9rr7oavzwoR/i9//l91F/Sh3PGHqGc9SJTCjaxr2w5UA2B0od4iuntDLQyWg0GvjyfV9GYbGA55/0fCzNLgWcNa3FZTRJGSAKuJEx7qVR5njTU96EwkIBF/7DhYiEIqg1aviN3b+B8yvn477YfesMrJ5HAM7RsTl9S8rU8i/rdHYSYGViGVy490K885Z34qyRszCaHMX1d12PW4/eilMHTnVrrUETjbHKj7YcVkPMZ1P9qsGiRYM6QTA6ipA1OtQyDS4uJwjAJe9ZnsK8sjLu2GpNDwaVFC88yOVy3ns+e2GMrRHWiNi2MIzH485xYC6Kz2mZhdqQ3LJY25mzhU77+/vxzu+8E4fKh/DhCz+M2GLMCTxzGnr/sTJQSTCyXihrZOlIcI7WG2YE1Elt9RnDZ+CO37oDs/Oz+Ozdn8Wr/+XV+NQzP4VsPbuO0WjZ1OrwKFGKxoHwJA0hITTehqO1m+0a5bHCGF73jdfhmy/9JpL9SbcnlFfKCaMcVf7Ly8sBhWqVT39/f6A9ojbbtznXTuScqZlGo4FzR87FE3Y8AQf/5iBu+OkNeM7ocwIwryIl6jAz3UTITgl4qqConDRa2Ah632zUG3U8cfcT8e5nvhv1eh2PHXks7h6/G9ffez2ee/FzkUqlAiQ5vQCDLHvqFWANNeFXRb9sRKqVB93qmk/c9Qk8ff/Tke/L46GlhxwqoYiO7ncul1t38YIy1rcCquYIhUL4/D2fx6f/69P4u+f+HQ4kD+C2B2/Du37yLjQSDeT784Ge1exUNT09jZWVFVe/q+VAhKQVvaOOV+fNOiCdpGo++cufxCu//Ers+cs9iIQiOH/n+bjizCvwkxM/CTjv2hOD8s504PT0tEOCmM4Jhda6HPb19bnralmKRme6WfXPlkDWfOOR1AgioQimF6cDrMpyo4yh2BAG+wZdyQmN7fLysuv3CQR7XANrjD3m06iccrmcq8FTbL7T+rTNnk2NsXr9jGDUWADBPtMKWepVarYurdP5cW5v+e5b8K8P/Ss+d9nnMBIdQblcdlEKOw5RyTJHqJcg6OXnGsEzKubP1CDTm7W1vO08U3+kH6flT1tVrsOPxW1Hb8Mn7vkEfveU33VryVwhHRkaal1nW49NBayHnA0VduzY4YiARFtYXtGq/Pz4+I8xUZnAEz70BPezWqOGW47cgr/50d9g/i3zAcPJNVQ2Ohnp/H9grV8xnQe9sUdrMnsl541GA7lYDqcOnIoH5h7A0sAaRMeSQ5UZ5uSpUO2FAgCco0Fij/ZYppLyRQ2tjF2ZXTh75GwAa/J/5tCZ+OJ9X3S3gCkRjQaAjpFCjIzYdNDBJVNYnQklUXUT8RyeO4xvHfkW/u5ZfxeQVyWfhcNh1xlKO3TxxjWtN95qYwwAb77pzXjjhW/Ei856Eebn57Gvfx9+OvFTfO2hr+Gq+FWIxWLOIZufn8fMzAyWlpYwPT3t0ks8q1x3trFl0MDcK50POs00bpqSbEdmTs2fiptffjNKCyXMzc9hKDaEX7/h17E3tRczMzNex5NBlMLVhULBdW4jH0XJdtlsFvl83t26ZSt/2pV1jo4i5Fg0hsfvejxuGbsFl5122aoSXV7CT+Z+gucOPxd55F3+Rg+x7dwSCoXc5lDZs7fy0NCQg6pVUWm5zVYYY190rOUoGvFS8QLBCNk2OunWIKuX9Xs3/R6++rOv4p+u+CccyBzA8vJygO2o+VTOg3ugTHYtMWPErAXuVBQUQn7Vvey2nKuOOpbqS+4ZuY6M1mkQLEsaCBLoOBjZ0HvdsWMHdu7cGZAlHpp2GL/POPkZuOu37w
r87BVfegXOHD4TV190NaKRtRacKkObNTeg08dcKB2GTCYTaLbfKYTnG6XFEh4oPIDL9l8W4AxYcp8S5WhAKCeEqal4eEapUPXSAxs1tBNpXrTvItw7fa/7dygUws/mfob92f0Btj8jYMovuyixtpjwOpUr0xp6vmmQfTee8frOTnKCn/i3T2AkOYKn7306Jscn1/E2tDuYkleHh4cxMDAQSF/YQGQrjDIAVFeqiISDfQf6on1ACC5K5JpS30xPT69LrSiqxbNLu8COViyXYvMTRQbUILczGo0GUv0pxMIxHJ87jm89+C287pzXuVbDRNWoyxg4Uc71shcADiGls8BzqvdR88x2Wx/ecS/r11/werziS6/A43c+Ho8beRz+v9v+PyzUFlbzJDNLARr/yspK4Co35iYJFXHjGS3Q89A7eRUu6DZiKC+Vcf/M/e7fD8w+gDtP3InB+CD2pPesi5K1JlAPE3MlAAJQqoWtlaEJdFb+FAqF8JpvvAb/5+7/g8//6ueRT+UxtzKHeqOORCKBfDzvHAb1+GiQGRXzKkJ9X+aOGeVzxGIxZ9STyeS6fFA7OeRrbrwGlx68FPuy+1BYKOAf/v0f8L2HvoePP+PjAYPDSIcQEiFS1mlahiSRAb4I/zJa4/3OCv+1W5KQiWVw7ui5gZ+l+lIYSgyt+7k6ds3IWArHW6Ng22S2k3P1jd//5u/j8tMvx/7cfjxUeAhv+/bbEAlF8Lz9z8NyYTkgJ5ZRaqMJyjDPKgC35gpXK5mrWZexVsbrL3g9nvLRp+Bd33kXXnTOi3Dr2K34yL99BH/97L92nbM4WKJFA2wbzwBYVwqlJCrNedv5d1pDXavX8Pf//vf49bN/HWGE1+kIdW705jW9K9jWHffSOfONUCiEy0+/HH/2vT/D3vRenJI5Bbc+dCs+feTT+IXML6C/1B9okamtkdVB5/x0rVXebZc7RVX0ytt2n/kb938D9UYdBwcP4t7Je/Hmf30zTsmegufseA6mJ9egaCJBttGN6m0lpdGBoKNEo2xTCzYv3u7ouDHIi899MaaqU3jnd9+JE5UTOHf4XHzsGR/D7tBujM+PO+PAriysO2ZDcsJ1WnqgBBctt7EeYrcCefux23HxJy52/37DN98AALjqvKvwd5f9XSDC4QLrxljmr5KkLGOwk9rXZuMDt38AAPDMf3hm4Ocfeu6H8Otn/7ojW2lhOztx2b7PCqlrKRRTDKFQaB200ylUDQATlQm87IaX4Xj5OHKxHM4dORefff5ncX72fExOTq4zWLaMheupg/uhsqTlIhaCtESpXis2fT9FXZrlgNUoa91lpwbANx4qPoRf+/yvYXp+GiPJEVy450J8/Ve/jlwth+O14+s6cTFnTPKWRXjUWWWkqf1/Nf/a7NKDVp/nSXuehBuuvAHX3HQN3nHzO3Dy4Mn4y2f9JX79Mb/ukCEaOUsk0za8TJPxd+mQaqRs98C25OzEMbrx0I14sPggXnrOSwP8B55RRQq1MYy9wamT6oBuxl8956/wh9/6Q7z2m6/FRGUCo8lRvPCkF+JZ/c/Cg/MPOp0IwCGCWmrGZ9IqD66tzzD7OjB2mq4pLBZwzU3X4KHiQxiMD+KyUy7Dq898NSozFXdRBq+mVaNMp4IBlLZeBVbRLC3v4ksRIF9a4WGBrIHVBf+d//E7+O0n/LaLgJkQVxKQ1pBS4fNgR6PRdRtIpaqHvJcKCgB+8cAvovHH640J52ghbKX0+9iA/FvNc/IAKsO628E520hbP8tXZsCxtLTkvR1GWcz6Xtrezz5Tu87FR57/kcBcScJhbbp7RllDJVwoXArA1T+TxQ8EI2afketV+0kA+PbLv930/+xh1O9tYwk1cLofNrLudK7/+Kv/CACB81epVDAzMxNwJC26o41NVNFyXvpvVbz6DLYjUidrftnpl+Gy0y9zz0D5C8CpQqzzsaX1jDZDWbRe2X7fqTF89qnPxuI1i+42J36mlpZpBKkOfz
NHZquNMQBk41m895L34j3PfI9rHjQ+Po7Dhw+vk0tyPjTVxUGj2te3elOSBl8q85vtXTvP/aJzXoQrzr7C6S7KerFWdMGSyjfRIUUttJRPI351+PWOcjtnhe3bHS0ZZE6qWCyu+zm9PmLvmndSiFO9D62TJfylyoAeC/teLy8vtxTRcH5qKJrNvdlzUsB8FzUoyUhzmUAQstYNp/fIshd2kbIedydzt79rDamv97D1BPV5tH5aWYjaj5brwk5GjGZbnbs1yNqOTudq58z58n0VPlVjokac68+7WilHqlC6lRn7XMq0197I6oVrPk2JPtwvygfXtlm+u525q0FmFKzrro1N9EIAjY45bE6Q8sO/p6yw7wDLj9RB7FTefQ4Elaq2LeW8VY5VxvWl+6BIgS23W1pavbCEz2UdYztvysPCwkJgvRXmDYfDgdy3fj4jRaJavag/tuu+0dy5pzzviqpprbbqd2vMaIQZFet+WJ3PFpskRNooudW5+2Rde5Qrt8aigBp80Gmw/JqFhYVANEx9tpGz75N372i0MMbGxhoAHjWvsbGx7blvz3177o+C18/D3B9t896e+yM/92Yj1NjUZK96S8eOHUMmk3lYIJNOR6PRQKlUwu7daz1Lt+e+9WN77o/M2J77IzPs3B8t8wa25/5IDZ+8+0ZLBnl7bI/tsT22x/bYHls7WsohP1o8kZ8nrxvYnvvDMbbn/siMn6e5P1rmDWzP/ZEarUbI2znk/2av7blvz3177o+O189DLnN77o/M3JuNliJk9qUeGxtDNptd9//atUpvDWKzgampKZw4cQITExOYnZ1FoVAINPBmIT/rSbVrzvDwMPbs2YP9+/dj//79GB4eRjab9daqFYtF7Nu3z823lblvNBqGOVupVDA9PY0HH3wQ9913H+69916MjY1hdnbWtW/kM7Bb1L59+3Daaafh5JNPxs6dOzEwMOC9eL5UKrU890Zj7SYt1hrb23jYjYvXu9VqNcTjcbeeJ510kmu+onfvKqO3Va9zo3V/8MEH3QUMZDKSrclbb6ampjA5OYmpqanA3cx6G469PELrX9nhSK+NZFOQfD7v6gVtf2V2cjrppJM6lhnKCFntxWIRExMTOHHiBI4fP46pqSl39y1lf2lpCf39/dixYwcOHDiAs846C6eccgp27NiBbDa7rl97s9KPjdb98OHDSCQSrvphdnbWfT5vb+K6s3uR9oW2NbPKYtdSRXZzy2QyGBkZwUknnYSDBw/itNNOw549ezA4OOh6CeiZLZVK2L9/f1dnVSs8isUiJicnA+s+MzPjXnNzc1hcXEQsFsPo6ChOPvlkt+47d+50667s2WaRjF1337x5Rm1pma0A4Nnl1ZYzMzPupiH+u1qtYmVlZd1VfwMDAxgZGXHnec+ePRgeHnYthpsx8zebO+dPlnWlUsHU1BSOHTuGhx56CBMTE+5eApY1sS6dpUaUM3b0CoVCrvET+0zk83ns3LkTe/bswa5duzA0NOS6MfpqkNuZu2W48z5vruns7CwmJiacnJA5PTg4iN27d2PPnj3YsWOHu+6XL65tuw2qfGfVN1oyyPxQFkPrsEYLQKAkSMtqlFKujUBCoZAzyL4m6radJRWxNWp2vpvNfaPB52IJEen41mDpZ/EZ9Hu9LINdgGgAfcXvrcxd1zwajbr2e1xHLS3QXr+cr60bZbclX813O6PZ3NPpdKC8h+/NXrj2ViNtW8oWntqYwsqD1rpzzxqmzpRyw77kbM3HphGdyIwqXSonNqzw3crjq7XX7lC+1p4+WfM1GLHfs8sT5aNara5TItZo8Lzy/wAE1od9AyhLnB8bJrArWj6fd72Jtf2jXrDRjrw3G1puUq/XUalUmtaFan0x152dotgtiuveqvzbZ+C8uaa2DJEyr3pRr+6kI8QXELyv1zpnKtsMZPg8zQzyZnPnUIMcDoexsLAQcGapE0OhUODSlHq9HqgtZkdDztXOgTpS7wzY7KrOVu
ZOPdjX14d6ve7KCW0/Ccq57zMo46q/bYOhbnSkb3TcGIRDlSDrD+kdMVKjZ6I3aNB4axs2YO0GFh42W/+aSCRcjSb/hg+q/+7VM3EOjPi13tHe6sSDRCHV2jbfXZz2Gdodtn+23lvLPWD9YOP/1tWpZ04FZhVXrxqw6HqqkrJ1pNpYgP19Q6GQuwS8v7/fGQoOHgY6QOwwRoWn3YFUQbCwvxfrz2fT2lY+k64x6x1ZR60NZ3xrzffkz3xRcit7o4pfa4PZrUhrY9lpSS/y0DOpnbk4fzawYB/1fD4f6MGsfX1tn+NuZUvnpwpY+wBoTbI6EdbIWYPX7bwUMVQ55/oXi0VX28uuaIzief2fyozKKvVis/a8vZJrDitzPHPcSz1ztuZb58D+DhpksTWv7Vehhq6bPWmmx22HLr7q9fo6/cjGQslkcl0vhK0YXRlkKyQLCwsoFArOABM2nZ6edtAd2yACq4vN67r4fqqwbfP7QqHg7gzlZ6uA9Nog6xx4eHhHpjZA0Q3VgnjbB7jZoelk3jrH+fl5d6hPnDix2plGDrXCvVzPUqmERCIRUJL0Bqm0uL7driW/2shAjbF2hWJUzKhKn9c+P+VFG7Hwee3zEJHxNbvo5Lmsg6F3OlPxqlNEw6DohFVufD+NYhUZaFVBqXxQEVF+NSXAtBEbPuia0LHUm88Y2WvvcEabg4OD2LFjh7vMQ1MDm8HA3ewD5UqVrl4ewL2mrtB1t924ujUA1jnQSwtKpVIgtUSnSOXEXkyjKJev6Y+9wKaXxliHyp8aTBpionManFiHSZsOMZLOZrOBxjm8yEOd0E70kN0LQu98aetM6mlG1LzxiSgDL/8g0qfr3GsyWU8iZAoJczlTU1MYHx93gjczM4Pp6Wl35R+7VWnvUnasobBRoG1/Xe0hbSHuXgyNDjTqp4KlQdYuXnzpBvHOUL17U6+ctFFIO4fIChvzyDMzMy4Xq5efA6tt62j4tHMYBV/bwrE7UC+HdbZsblLXiodVFSYPpvXGeZDo+DCPxQhboVK9OrKbSMLKiE/5qgNXLBZdBACsOQraE1jXhgiQNcQ+eH2zOerZ5L7Pzc05A0DFpI4b58i58UYnhdT5IozJvCBhaxrkXrcstc9oHSKNcNgVi+iKGl/VG7ZVaTdztOuu8kC0cHx8HFNTUwHnXufLs0AnXwe5GHT46Uip46Fz6cVaW4hcz6WF2bWjoZ4tzk174vf39zs9qt3LdF+60UM2sFKjzM+11+SqQSbEzdy36g4NvnpplDs2yBaGpPBVKhVHSCBhhC9C1fS41SizRRkJDNrWUaEG5nn0HuZee4Z2Iy3UZ4VIja22ctRWfarEew2VMopXb5sQJAAHGQII/H65XA7kuFOpVMBY9WrYKJlz962F5qS0YTsHFR0VFhUCnQ0AzuFjTk2RCj1QvTrsCpNqakOhMTpGJKNpnrnRWLvLl/lQdUg0r841agURsnKsrVQ1irR9iLnenAPPKYkt5EHwxcsQtPG+5o07vYpuo+ey+mcjuJqIgxpiX565l5C1dRJoBIhk0SBrdMiLd+zFEzRMdEpp2HgOrEz38uxamJqyqKktdUxVdwNByJl2gu9p96mXCCL/TuemhDp1APj5zJVXKpXA5R68ict3012vR0cG2XcYFBZgNEtDTMhucXERwJrSZe4pk8k4bysUCjmcnpvKw0YFEo/HAz1Ve7k4PqhDI3Tmf2iQlTxFr8nmXazn3W4ucKOhDpH2RGZagI6L5lMBBNAHRo4Ke3EteuVh86uPZKO3vgCrRledNnUmCPvSeFjlp04R5cLKR7eG2BpjveaSEaf2/qU881kZXcZiMYRCIef0sdcvn5mKz14D2moEZw2X7d2sfbWV/6C5ViW08O7aTCYTIKHpLUv8mSUt9gqqbvY89mIM6hvrCNlbqHzXGnYTHfOr6kUqfOsI2QheI0NLhlUynRq4rTLEwPpzSzkkQZKIFOeipF0On7Nj7YeFt3s5bDCgAZ
KS5xROZwUIZUXvmbcOULdpDju6ipB9niAjLzVchKl5OAl1sTwll8s5SJWeP6MGHhgljinEsBVEBgv1EW7S0hGF+dQYa0mOwnW+23x8DNN258mv+vxUpMqA1RtKyDyks2ENci8dHE0vKFxIw5ROp91+hsNhd4+2euPAGilEkRNlp9pcmyoRGxG1W9Zlh0UmKO90QElc1BymQl/KtAeA+fl5zM7OurIcyo1eC8gz0y5nwucY2lygJZnZPdJSm8HBQcec53PofeX22sBmjOpOh9U7tum/vUylVqs5p4ZrqNH7VjsMNpLli8bVd0sQHWwljKp8A/CiDr3OZ/I9FV2gTOiclJmvuoOypGtComYzvdiMVNrps1lHi3LOZ7K6QFFRrYSw5DPyXHoZXAEdGGQrcNYzZZ6K0C7JLOphZbNZDA4OYmhoyNXCRiIRLC4uBmA5QpDMG2suYCsMiC86Vghe69aodPVmEK2NZRSkV0jaqxG7zan52I+cQzQaDcCMyWTSrSOfkyUJ1nj0Go6xc1NPnz+Px+PIZrMuotSojfuukZ0qX71lhkrOllIoX8FGRa0On3xoja/WG7O+ns5Ff38/0um0izA5l0aj4WRJc8t0nvgM2WzWKQEl3W227lQ8WuKmskhZ4XnWtaPxouOcz+cxNDQUgK0p5/YuYd8Vet0Ou/72liymCvgiSsSzkMlkkMvlnENBvsRW5rYVkbC5bL0KlfvN4ENhbj6LQtla8udz8Hs1bHRMB02jREbKapApd+rwhUIhN3/VkT7EollpabtzVx2j11tacqIiETTInEcymVyXXmB1kH4GP7Ob0VGEbNmlllJOQWKzAQABLz+Xy2FkZAQ7duzAjh07MDAwgGg0ivn5eaccAKBSqTjmGzdTvUaNjLqNkH2OBqELFpGPj48HivfpFDQaa9eL0YNktKBGWctAenGVGuCvb43H44hEIm6dh4aGkE6nHXGOuU0qrFgs5gyyXc9uYWsVWDpbesC1sYctQ9BrPBVaItTOl5ZK2bpvJSP5WL+tDl/kQxIjyXQkL87MzKBSqTjIlAaWpKdcLueu4lxaWnKyRHiVe0gYOJvNYnFxEaFQyP0fHZvN1p7KSGFvffFnGrVQhpmDT6fTLjrO5/POIKtC1Yh+K0lcNido9Q5flB1GwdQ7rI+29d5b4TSoQbYOI3kzkUjE5d4zmYxDr4jKqfOv+iYUCq2rudZn6NahVmOo0XEymXTnjZG8OvocNMJERimDPMdMV1od6bt2sdvnUN2sjkAikUC1WnXwu+bnuY/hcBjJZDIQYNrrF3sp520ZZDV6KnD2vkh6q4R0KTRUvAMDA+5w5/P5gEGmd0hCEvPKjAgoBL2Cq/V5NC+ocKTW99IYqyGwkZkaZSo+n7LqZCPVUHLu1oulYLNb1c6dO12HmHK5jHq97ohGXG9fDrlXQz1IYA1StwZZ5YckOjpiSn7SyIgpA65ns2YbjJI3azqw0ZpbVIjcAhpilvZRgWr+kgqIsh+JRJxjSeVbrVadYdTuY4xAaJz5zBvtkw+m45orr0BrtDlXJU2qY8noiGurMq75z2ZpmW6G75zyjFqiGmWD0RhlQg2fTxa4br2Yq8qMOrbq8FD+c7kchoaGMDAw4JrhFItF19RFIW7ml3mGNouQe+FQ+xxoRv3z8/OO46F5WRs5RqNR9/90UHku1amzEH6nhq5ZdEyDTBian0sHVzlBnLc6ekQTVXY0hfSIRMiWvOAzzgqx8PAzZ0gmJvNS2WzWLRw7wpAlTM9LIZFekBl8z6CHXEk6WrtJprUaA2BNABQWsfBgsxKLTubuyxtTudLbVAOQyWTcXHnBOp0eS1boNUFE88hUKjyY9XrdKSE6dcytknXPvWekoQgJDz/XmDJGsqDNc7bbiczCpOqs0SAXi0XMzs66ciKthbb5cnZSikRW20dy/SlbKysriMViAWcvGo26XHu7LHGbM6Zy0sjZKhSFGzcaNkffq3reZsNGyKprlDmrXA
PKRiKRcGtPeWinK1e7Qw2C5uhpFLh/TEewFSbr5IFVhJAyu7Cw4M4NsFaWptGxrlGv0S2NkumoLSwsBDge3BNfflwdRA3ObFvVXqY6rDOqxnh+fj7gUGq6TglezSooGCEr+VJtVaejK1KX/bcaZ8tCs63duBn0vAlvWIyfBxCA87bUmHY6dxv12PIq2zyhWSG5NnGwORPN91kj3KkhVuawKmUq0f7+ftTr9UD+lAqIkTGVr63p3Uo6P59ZoSDdA4Xxa7VgK1B18CysroqC3jtrYensaf7YwmEb7YPOT2vklcil1QS2Cx0/S2FzRgWhUMjB0FpLzn2NRCIBRdEpGuQzDEro4tDcn6IAWrZTKpVcIx+VN/6dksa2Ip/ZzDmyaSw6l3x+mxO3pK5ezlXXWhEzni/OjWQnzkvlk3oTwDoWPM+ORSK4Phoc8efd7ofCvnQolpaWnGMHIEBCA9YiYmAtZ64M7VQqhYGBAYcCtXs225m3ohKsJ9amU5peosOjZ0z7UfAMUG7UIPeCGNixQVbD4jMy6m0z96DRsXpH7L6lRAW+aIyB9RT6bqNLjep9JVuTk5OYnp520YslPylcbdfFV0ph16lTqJqfbedAoQfWSBOMlul5c15aQrRRDWCvhxpm/cq1WFlZcT+33qk6QlRayk2gbOXzeYyOjmJ4eBi5XC4AUbYbFVnkxNaTEq4mVM0olk5oLBZzpCjW59I50tybEtaUsGUNabcQnv69dTY4j2boSzgcds4Dm1YoxM3P2cqhBtm2XNUuUZwLlbGWZSlaYte0VxCvsuOtjIfDYceLYPROA6ZOhpZIUe/wM3jeVY65LjzL/D912jt5HjU8igASqtbzwUoZGmv+ng0QmNPP5/OOYNcr3oHKOeWTCIlGwmqQKTeUGc19c19I8CUqoIhor/LebRlkNSStGBt7IAjN8NYMzeMAa32stf6Sni6jIVVMnW6ezQcqm1q7i01OTmJ8fNwpW3pSVGBahsBn5vwABOaq69fpsBG9kkWU4R0OhwNEMgqMRqa2HlXzkt0iEK0MXRNVGGqgFBpWZ0jrepmDGx4extDQkCPtDA8PY3h42N32pGzOdiJky7j33ZZEIle5XHYkREJjTBvQIFP5kOxFFEDLYajAlQntI0ptJk/2/31Os8J0hEq1XEfbvvLmHJb9EcnQ6G6rhg/R0uoOLXNSNrONjmiYNXWh729HO2fWGgHWjGt0SKNGBj6dSgDuDLJShQgMETpyJVi/rrpGUQOWizYLltoZ+jmai6UBorHSrlc01Ew9cS0IyzMgI59IL/boFcGO8+baJ5NJ53RyfqpftImS5uzptC0vL6NcLmNmZiaA2vLZKE/dpgq6ap1p4TCfwVa6vObQyHLUsgsLqamHAqzVtdnPa3X4okwtX1HGrJJ0mBfU+meFdvUg65rY9eD3+rXduducPZ0CeqMsndAowLKK+fx8DtsgYisj5Fae0zpLJOloDSAQJEuRvEZyDHPnjEot8aiV9ec6qEFmvbFe7cf88fz8vJsX82SWSKTybj9HD7PCnVqe1K4Xbs+n7+zomvN7pg7oABG6ZuqGDUzoZDPq7zbC3GxYyFqjY9u2kc+r5V6KkgBr5NRmqJaNbltZa84TCObZ9f2j0ahj84ZCoYBzrE2IyFkh4VUVv66JdqJiOaNFVroZauD0PTWlRIPM9WXEyXOayWQwODjorjTUNI6PYNftUGeIuo1zr9VqrrthuVwOpJlsnXhfXx9qtdUb0zQ67u/vd/aM8q+Etk5G15dL8KtPgXNxtRxHae62faA9DKpE9P18xq3VBbBGTT1ty6hWko7Wx/pIE83m3QsIxh6+ZpAiHZ+VlRUvA1ahLS3HaEbo6tbba3fo5yhxz8LzGlEzSiZkTQSGhthGxu3UN1rngAZJLwXQW5MI1zEyUu9Z82265rbmGljLudEQK0u/XXmyUZuvPEmdX82h0fFhvT3PCSPjgYEB99xbxT3QYaNkbQzTrGSPv6/nXv
PP5K7oOthz2y4qQblUI0bjrnwIbRSiKISSSLXXOAcNhBpxjfasMe4V492nO325fHVkFCHVmvaBgQGHVmxFPl/3jnOgfNKJqFQqzqAynWd7guvZIKuckHw6nQ6gdmxCo/as3dF22RO/2ojNx3pWZaBsTJ+37zNsXFA1DD4j3e4z+A41IzH2rSZMqrfgWIiaQ+dqyRa9JCnws/hVYSQKPQ0Xcza2EQaw3iDry+4l12yrjbJ9Li3HsXkr31pamWm2F+1CeDZnqblklRFGavSwdS3VoNOA0QHUNAg9ePXMLUSoSMdGz2AVkrK99V5uklyA1fIgGmG7L4wogFWDwLlrp7qtQlasQ+ozCtZp0zSA9qOnYubPNaXjY4z7yFObDStfCpPW6/XAJRh6kQ7lolQquasYte2w6lPOmUa4XC4jlUo5XgLnoUx6Gsp2117lX9nGmttmVKmfS2dSS+ZI7NWyuV6Tuew+aOqCz7O0tOQgczo7iUQCi4uL7nf4/PpcdEwVtVMHS/VsJ05Qx2VPKvC2haFGMjZabMY65gPrYtrFtf9u52GtE9HMw9PbmyyByEJ/CoFo/lsNiC3a70Tg7Brx82xXJfX42abRdiOye8fn90VrD1eErB61jfZZn6zNOBRW4vw116x3mFoYtV1jrCkORVMsq1ehdFVceqOWwl6Li4sOhSmXyy5toHtrey+3k2OzDhtTRpqe4JxDoZBr+OBzrpUAyCiOpXJb0W51o/3gV2uU1aHkzwC4q0mnpqYQiayWVRJiVKSO8qY1ws1y+Js9p5VnjY55bzkbfqgs0SBXq1WH0mmKQN+X+0/DR2dVzzOAQGlUJ8EL15NODUv8mKZRhCQUCgV6rtMAM3Wj5Yfan2GrS89olBuNhqtCYbTOubOMiw4mc8S24oC2zf6c5Zo+NKWd5+q4daZPQVno07cwzYxSs0n7fq9dwbLz10NrF9UqW8vYJFynZRU0yhR+VaTt5i03Gqpkfc4FAFfj2NfXF8hbcvgcqY2MspZgPByGmQYkmUxiYGDAkS1CoZDL5bIpS71eD+R2CRWTIUnlqnBVJ0Nhcx9MSplXg8z8lP5b7/NeWlpy1/Gxrp2ypXC1r8tbq8qLMqk3eWn0zc9JpVKuXIuOjipiIgJ0jGyfgYfbGFu5b4b2rKysoFQqOfi4Wq0GmlFoD26tX1cmtt4m144zR+PJudIYz83NYXx8HMePH3fkODr+GiUzfVYoFFwTH+Y+9aV7QBSG70cCmOqhVvfJzl2v1mX74KmpKczNzbmrZxmdK4HXEhrVIPe67rjZPvjsBeWHKSYyrS2ypWeZVRT698ryX1hYWGeQ27VVHXXqsvlXRgR6SBWHb+e99d/6FcC6h+00StZcku+lc1f4RX8WCoWcseC8fHk6GulOo2MO3Vx9Px3RaNSVU7A9HxsgqKJVR0SNjA+6fjijZMuKzOVyzrjSiDHCUNIRGfLMgVI5UMlSqVIhdYqsNIP39fcYQbKhycrKirvJicqH0RJZ23rQLdxn2fKtKi9FG9gDAID7NxViLBZDJpNxTFOr7KmI6ejwdzZCVLZSXpoZYzsXfZZGo+EcIDW0LMPRlr40ICSeKjO31QjZDjVqNMhHjx7F9PR04N5yzlnr3bXMkmxfPhd1r8oezzKw1p4ynU535DRZg0yk4cSJE4GuhQsLCy46ZiDAF0sReQ6VU9Fu1UAnQ+Fj+zPOmd3w6FirDNEYT09Pu7+j462kYCJ33fKHurp+UQ8FFb3NJbUiCD7v1xpofUj7sJ1spM19W2Nrc2/8f2VRqvJRKN5HDukVJKOG2T47o0IaMUZZ/H0qMhvxcd98OeStjnx8z0YYnvNRuLdYLCKZTLoyECrbSqXi8nSRSMTdYUoYis+rsF+v5qtyyHNBgle9vtaMhZ/Js6L5ONYfq2Onxth3YcNmc9P3SiQSgeiY78X/00s6VNGwqxsJbQDWOa+at91K3sFGAYGmyzRipCNKZ06dHBpn5tKtrPC5lKVtGc
4bzZVfud9aOsfLaorFojMGCpcqAsO5UK8wz8k11r4NdBxSqZS72EdRjE6cCTt3wum8QIUOC/WPQtW+FsK25eRWwNU61CjrmaXjm0qlAvutkDRvw6vVag59YCSsa0N2OddBGx21MzqCrPUrsN7o+n5HDZjN5+pXazDUEOqDNmvL1uqwsIJtr0bPmL+r+Th+DufNNbDPoIegF7Ceenb6vDpPdvoB1uq6NSpWmNpGNQ+3AbbP1misdSHSUgUq1HK57LztarUKAM7IqlFmOQ6hbSrsTggXuvdKNNPWqIxe1YFYXl52UaU6cYrOaK0vFb+S9LSpixqEViNklW0OVYL8vb6+vkD7VMoISVzz8/Pus0kA80WqKvetEM/aHRZK93E+bOolFAq5Odt0jZ27Ov38vHA4HKheaKfWWvfbGlq9pYrpCuoczkv1oA6795YT40O5Ol1vdSgsoZFpI9XHjMp9N2lZnbtVUbFv+HQnz5o6lMCaLl9aWnLnml26qtVqQL6JZujtUKq/NL3YymjbIOuCqqLSPKk9jNYAq4Cq0FkISglVqqgsyaUdD6tZBEwCESFfHkJtRE4IiZBxo9EIKCgqCsvA8yEHnQ59TsJntk2dz9PTw69RgG9tOk0JdDv0oBBaVQVMSHFkZMTtEfOaAJwRYTctbQavSAHHZt6ryoq2/COMyHw2lRMdA743DzUAr8wrRGZlXPvuMtJotw+3rqf9GedkDbKmoPg8viYa1gCoIdwK2fEZB7KLyUa2F7749tE6JIoUEQWgEuX6EPL38WM2m691dH18E86XZ1fny78B1i7+UEfQ5r59lzV0E4FaNNQSGJWZblGdUGiN9+HThbo/D8ewaKLaAR9KSzngGeRLZYMpKZ4xnjfaKDXerYyOO3WpQVMmIr0iVQLq2fpefDCFoDQfTcGyzL12b+6xCp99eDWCZZI/l8u5tmqsM6WiJ0nKOhL0IGOxGMrlsovklK3dqyhUIz0aZX1G9ca1AF5JJBYB2Cjvob+7lYOfzUOiESVLFGq11Xt70+l0oE5zZWXFGV9e/0mjbMtBFGLeaC6ah2XrPR7GRqMRyL+rYdN5WyNskSF1bLVMhI10fFdHtmrw+AxWRviV3ry99pIypJUCHFTQmsekwlU4Eui+j7IOdXoZsczNzbnmPWygwUiTz6d3k9PhoUFTjgfz+nSkiHLQWU+n0+uMfbN58qtNa9G48+5v/tzqPJUhGgcrH9qKkrlv7YbFhjjt6Ej7DPY5KLd05EjA1Pwwg7Naba1UjtGz1YWPxFDdyefjsM+pXA4648DaLYRE7nTP+Kx0irY8Qm6lNEMPpY2CLbuXis32pmX0qQbZd3tPKwxm9ToVeraRMtm9Wt/HQn0yH9leTRmN/DeFPxJZ7ZbFu1l50HqdY1OFq54919Y+A+dDDxdY33DF/ruZMPU6ArLPRGOhz0QHjp1/ZmdnMT097TzxlZXVlpR0QEqlkqvPVFRH5bPZ0OiYUJzm5xkFM0Lv7+93zgFzskqc06hY91/PEiNjJcYw8mn3LmdfRKBwIdeCt0sx4qcx1o5PVh4ajTUSEY0x67DbRa1aGb7omCU4Wj5Gx4zOtY0itdGK7iOw6mBz/7TRSyqVwuDgYCAfu9lcbWRMA8bbnfL5PBqNhpsfkRxl8GteU/Oz6qSxE5w2xMlms65LHVtS9mI/9Fk0V83OW+xHraVxlKtGo+HuBKfheqQMMgedat+6qK7QfvlklFtEiegpkTtG0wwethyyVviExthGyao0bP6GD8JSIkJG6m2Tus/PUxiPHlmnRBdrnNUgq3JipJtKpVxj8ZWVFddkwBb8837QWCy27pD1miilEZ4+u40kbE9ce4+wXTebVqDQ2hd/3k2Oyn61SowGUWFDRnvM57G7kTIjbfMOGjNlv2+2FzZCZmSsESMNBFmwbJrBdVHCoG9NlQRCx9Y279BLWNqNdmyKw/6f5sbJUqchsvtk98462Qph2+iwm6HywT1WyJovNtCgs6
n5XxqxZt3rNC/Nc1Gvr9anDg4OriNItTJXC1Uzmsxmsy7S4r5Td9KIRaPRwOexaoKdrth6kkxmGmQ+J9u1dnNpg+98WL2ZTqcxODiIoaEhJJNJAHB3gROxIgmKPQW2QkbaHVb32J/5dBDtj6ZWKTPhcNg5sEQsOikLbBuyVnhTm1P4Sn34e5rL1JotKlQKn60DVhKF7fpCBdtujsQaFSohzSXzwDNvpxEJjTFvaFEUAFi7hkyjJKIAvTTGwHpFZUkj9vYqC+up4tE8Edef62PZiRba3ixi2GjeFiKyP9Nokt4qoS8qVyIDjJLD4fA60szCwoK78aqdQ6LKh/PQPede03lURa/ROH+fDoU6OXwuW+rkuyCk0+jTppt0/UkE9OWIbaMEe374++rEWTnvpVHmsJ/FOeizUj/wXBNZs01WFKKu1+sOsg+FQshkMs7Qtxoh26F7TKNKzgH1p9VJlGPqDkbXhLsZAWcymUBkrBdodMI78K21Gic1UPamtVQqFSD1MVIOhVbLi7SE65GIjjcLAOzPfaRirVDg+edFIcvLyy4QU8ejHZnpuJf1ZkZZI2UqS7LRKID0djXq9BEzlN3KKKLT+zMtjEdoi7CQloXo3OkJWcKED8bT/LhVGr0caohphCkgzK/xekB+r7dWkYCjDEq24AuFQoH8k+63Evl48Fqdr66VOgK6VvrV93zKnNZuWfw5FZrNb7ZbAqIKyMJOlF29itASssg9YLRuo2ZL7lHZVrTJOridGDereHSN1fgqMsQcPFMc6pDYSgfLQbDr2O2wKSeF95PJpHOKFI3gutr1pTPPJi1kwzLqVmeVDp0trWpn3jpnwsw8e+oUqJOozkWj0XDPy+iXt5oxQtZLDmzbXNWP7Rhln97Sffbp5JWVlQB6xLSjvUP+4Yas7fNYh5/rrf9PvWMjemCtgxcDHyIC8Xg80HJZo+lWRkcG2R4OpXqrh68N0XnQC4WC80j5MNHo6s0gVAQWduKB0vt9220l6HsGflVYSSMFklrsy+bJN/vsrRA8KzTVajXALubNVVNTU+6+XkbK7AzFBgLMNdOIs7yF0ae+lBRDY0EmcSvztYdB0xi2MF+NhRJ5CMHPzc1hbm7OORlUynTwmrW3tLn8zYYaIf6Nzp3Kn80lFEbVG6I0vUGFzAhOm1XoZRJc91Zy3putu66rpo54NommUG603pTVBzaaV9bvVnVesvqGbWEHBwdRLpcDzjTXSRELGmW+tD0mZYUGuVgsAkDAsVOHjsZ6o7n65kzIM5VKuZpzzg1AQJbUOHBoo5uBgQF31SghatsFS2vX242OfQbLOmPqKKqMUp8q4Y/rqYHKI2GMfY6+dRCsQWYKTNOo1CNEx9j6dnl5OWCQNUCgrt1sdJRD5tcP3vFB/OWtf4nxyjjOGDgDv7XvtwKKhZEDANc0vV6vO0OsXlyjsVa8z8MPIGAEfJ2LFDZuVehuOXILrv3+tfjxsR/jePk4brjyBrzgzBcAWMuz2WjQRiuaf1LF4/NGe5Ufefd33o0v3PMF3DN1DxLRBC7YewHe/j/fjj2xPSiXywEDxbaMfJH0QhSCbGUAbj9mZ2fdXb1ksdv6WzLdlfXbikG+7vbr8IHbP4DDc4cBAGcNnYU3XfAmXLz34nVlEbYGk0aDjQn4nGTYTk9PuyYFNBpqhDW6aZftrvvK8Rc//Av88Xf+GK8691V402Pf5Iwx15XOkaYMJicnHfNUoXVf31/Lj+hGlt727bfhHbe8I/Czg4MH8YOX/CDQDYrdzngH+IkTJzA7OxsoKaKzai8MsDeLdVtqo+No8SiuvvFqfP3+r6O6XMXJ2ZPxJ0/6E4zkRjA6Oop6vR5okcrPZjBgnWeeY86dZUfR6GqXOxpIJZBZpKUV+PHU95+KI4Uj637+G2f+Bt78mDcHInUAAUWvNcRcQyX75fN5d/83L2rQQMUazHado5XaCt727bfh+ruvx3hlHKOJUTxn13PwrP5nBZxzqwct4q
KkXeUXcGyFUS4tlvCH3/pD3HDPDZioTOD8nefjvZe8F0/Y+YR1Zbab9WVQFIxIEY2zXkRUKpWcjCQSCfe7NMpbapA5PvOfn8GbbnoT/urZf4XH5h+Lv7rtr3D13Vfj3bvf7WAZQmCERRna80F4iAn96iZSsSrcZBsltEPo0lFZquC8HefhlY97JV74mRe6n9s8m4VpVQj5sjlV+30vyQo3H7kZr37Sq/HEXU/EUm0Jb/3Xt+KXPvtLuOmXb8J8ZbUtH6Ph2dlZTE5OYnJyElNTU64nrkKshMy0/SQNrDbY58Gm4dCer2TobjQajQb2ZPbgTy/+U5wycApWVlbwqbs+hV/70q/hmy/8JvYn9jeFeRROL5VKgcb2JKopwxZYvRhdvV/rmbebPlCj/KNjP8JH//2jOHfkXESjUeRyOVcSRYVNQ8foeHZ21kVbhLEJkxL5oWHT6gELU7crS3zGc0bOwT9d+U9ra1BrBBwHKpXp6WlMTk7ixIkTGB8fx8zMTCAfBsBBp1qXaVsi2oip0zE7P4uLPnoRLj5wMf7p1/4Jg7FB/Of4f2I0OorsStZFmqpHqEiJdlkkTSM8zpuKmE4QsFYZYpX2ZhEysCovt73qNqzU10rc7hq/C8/79PPwK2f+CgYGBgKtUKkTifLoGeDgmmuEPDw87Ihbtm8+17+dfeC5+H+////iQ3d8CB94zgdwSvoU/ODBH+BN330TVkZXcF74vHWGXqNJu16UG5/B42f2crzqK6/C3RN345O//EnsSu/Cp/79U3j2p56NO151B3Ymd65Lz9C4+lKMdMqYBqSOomHmGee+LS8vO8dco+pWUUSgC8j6fbe+D686/1V4+Xkvx8LCAv7sf/4ZbnrwJnx//vt4TOox624fAhAgalFx+ZoOAOuL4C1k3En+mOPSg5fi0oOXtvScmie38IzCZDaC8UU0PmFsZ97//NJ/dn9Xr9fxoed+CPvfvx93jN+BUyKnBJpiWFiXZC4eDBpjYI0pXKlU0NfX53Ih6m1TEbP+TpnGrQjb5adfHmDlvuWCt+Dv/u3vcNux2zC8azjQ/UcPs6/MhTlxbQqh7FiOTiDeZiMUCqG8VMZVX7oKH3zeB/Gu777Llbb19/cHyvl4gIkyAEC5XA4YW0KqSlbUcr5elQ81Gg1EQhGMJkcDiAEVDCN4RVXYp3h2dtaVeZCYY42xJZ3ZqKmb8eff+3Psy+3Dx17wMWcg96X3uXkr7B8Or106T6RkeXk5UNaksCX3lOur55rDKmeNXDcbo+nRwHm/9gfX4pSBU/DM057p0iqcc6lUCkDoXEdGz3btlV1N5rjN5+uztasff/DQD3DZaZfhkpMvwcLCAgb2D+Bzw5/D/fP34/zI+YH0iSIJAAKEXM2D27n0Gj0EgPnleXz+Pz+PL734S3jqSU9Fo9HAHz31j/CV+76CD/7kg3jrBW8NnFFfv3CfQdZuXDS2WtZlI2Grv6j3WhkdGeSl2hJ+fPzHuPqiq9cIC339ePLok/Gz4s/wlMxT3MUA9A4IXSijkYaZwql5HiXIKNHFJ3i93FRgfVtNH3FNC+DtQQb8eQubo+h23sXF1XxXrj+HxkpQcajxU6gWWCNmqAKt1WpO6BqNhiu9oKLt7+93h457Q3JeK4SFAAy0sozP3fM5VJerOHfgXAdHM8pV6FoNiC3fIjTEw0TlbOfdq8jtd77+O3juwefi2ac9G+/+3rtdPpOle1wjJQNqy0t13ixTlZC1RYC6cSpoeO6fvR+n/M0piEViePzo4/F75/0ecsitSwGwnlfz8svLy+45mcO0FwZs1c09X773y7jk1EtwxWevwM2Hb8aezB785vm/iZec+RJHcOIzEoFjvwDWUXMemipgIEB4mvKu0S/1UTtcETv42Ysri7j+7uvxuie/LtAxjjKjDUq459Q5nAsNst5OxZePSd2N0btw74X48E8+jJ/O/hT7EvvwXzP/hX+b/TdctfMqoLKev8KLFfr6+lzOle
eRDhHXu9f8Ah0r9RXUGjXEo3EAa/OMR+L4/tj3sfT4pUBUrJwbZdkryctWrqhRbtZ5TEe7iFxHBnmqOoVao4ad6Z0BGGg0NYr75+7HwMCAU+ha9E6h5wSpSBktqGDy8NuoodtD0upoZozVOdDvNedDYVWj6CsLUU+93VFv1PHGm96IJ+96Ms4aOgtTU1OBaNYiCqp8fOxe0viZv9cmD/wdwqys1+YBaxWOuWv8Ljztk0/DwsoCUn0pfODiD2BffJ+L4IvForsGUBtp8DCoEabh5sHR8jhGb5SdjcpAWl37f7z7H/GT4z/Bj37zR2t/Ewq24APWonJl5mufZbv+XEsSc8iUtQz3TuHqJ+1+Ej546QdxIH0ADxUewrW3XYsrvn4FPvWUT2GxuBhAHBgV06ixoYPWz7LmNJ/PI5fLIZVKedNHvTiXh2YP4brbr8MbLnwDrrnoGtx27Da8/huvRxRR/OrBXw1ErOyYNDAw4PLd9tpCRs4KR9NA0uFj+kWDApu6aYdlDQBfuvdLmFuYwyvOf4XbewYrPiTEom+E5WmUlfRnOS267p3uwZue8iYUFgp44seeiEg4glq9ht8+/bfxtMTTMFYeQ6PRcDlRfiYDrFqt5pjq1OOsV/Yxv3s5MrEMLtx7Id55yztx5vCZGEmM4Pq7rscPj/0QJ+dOdqgQv9LB11JQm9ZSEiT1Ei9d0dSfLfvtNGjsOIfMQbinr68PfdE+RCNRjIyMIBwOI5lMOkXLqIaehUZySrbg32nBeyaT6bimrpPn8TEKNYettdA8IDQeNMaWpGFr0riRnT7Ha/75NfjPqf/E16/4OqIIKnZtnEAYGIA7RM1YkowUiFqoM6IlDfF43LUI3cwgq3AfzB/Ed1/6XUxXpnHDvTfg9777e7juf1yH7GLW3SKjbGkqTUY7LMuyJS4qO+wGxEYJbKKgEZ0qwVbGWGEMr/3n1+JffuNfnPcNACGsoSg0gHRm9GYc20UKWLuZizcNaaclX5vMTkaj0cCzT362c2j29e/D3170t3jGV5+Brx76Kh4fejzGx8cxOTnp5knjTKXKtc1kMhgeHsbo6CiGhoYwMjKCfD7vbT7Rq1Fv1PHE3U/Eu57xLgDA+bvOx93jd+Oj//5R/NrZv4Z4fC0SYttL6hgyr7nmSgxkSodGjgaFyhoIOnZKWiMJrJ3xkTs+gksPXoo92T3uLPgUtUZR1IkAnENER82H+LRLbPUN/u3n/utz+Mf/+Ed87LKP4dTMqfjRQz/C2299OyI7IzitcVpAt1G2qAsUjSAJlM1VfJ3Deq3HP/nLn8Qrv/xK7H3vXkRCETxux+PwK6f/Cu44cUcgxeXr8qbltroXtiRK5alWqzmk0HdVqr5aGR0Z5OHkMCKhCCaqEwH4bXZ5FjvTOzEyMuJKEyj4nJBGxzTKHDzYbItIhcqoYashDw4LV2sOVaEizaFpwbuvg5GNlLuBrl/z9dfgn376T7jppTdhT3IPKpWKU+7ayIGHheQXMnsVclf4TOsGAQQMMqM7Kqr5+XnXHKXVsqe+cB9Ozp2M3fHdOO3803D7sdvxDz/7B7xs8GWu3IYlWZZtreQKQkXAGrRIWIz3KLOjEes1Nd/WrmP34+M/xkRlAo//4OPdz2qNGm45cgv++ra/xvxb5gOsTEZbWnY2NzfnIGA6Y4SrGRlz3jRytqyvHTmxuU969n21PuyN78UDhQdwUu0kTExMOEY10wGlUinQtjYWiyGbzWJkZAS7du1yhCKyfLvpBrXR2JXZhbNHzg787KyRs3DDPTc4Y8VnXV5eRjabdQgKjV44HHbPoxeR8FzzjPPcar7c5so1/9/qODJ3BDceuhFfeNEX3M9a4ZfwXGqZVrNLI3qdj736pqvxxqe8EVeecyUWFhZwUuIk3DdxH75w+At4Y+KNAd2iEL9ybTR4SafTyOfzztm07Y57qcdPzZ+Km19+M0oLJczOz2Kofwi/fsOvY09qj2t9TEd5cnISExMTrk
pDS0FVL1uCHICA7SLSyK56DB4VOdpSg9wf6ccTdj8BNx26Cc8/4/mrgh8J47tHv4tXnPMKDA4OBvKMNMAK3S0tLa3D6nmIWK+n9XV8r602xhwWsrYtQvV7Gi3NJ9iEvkbGfFki22aj0WjgNV9/DW645wZ862XfwikDpzhoSGtZ6QhUq1XXxL7RaLgbWHTewFoBv0alXAM+G5EMlrORvMHnbnX+AYJMo47FlUVHRmMdNFmPRBzUqFjCCBWT1nlq60C9oEGZ4+3I0TNOfgbu+u27Aj97xZdegTOHz8TVF12NSDiC5dryujIJEqUIARNmbzTWrlrU3KzmkXsB7Wmuj179bGUWR+eP4gl9T3BRAmuOyQDn+lNpao93Rsc24ukVkUvHRfsuwr3T97p/h0Ih/HTmpzhp4KSAgqvX644Ux/3WUiXyHNgQh2eU+69RDM+RGkKNehhRtzo+dufHMJoaxfNOf17T37HcEiBo4GzDmK1wfjhCoRCqy1VEwmupmFAohEgognpjTXcpS5l6g+mtdDrt9AbRHx8BbSt1eKo/hVg4huNzx/HtsW/jdee8zlX30CBPTU1hYmIC4+PjrmxSdbhFSX3d8qhPrawodN0OGtcxZP2GC96Aq754FZ64+4l44q4n4n0/fB+qy1Vc9dirkIwkA2SK+fn5wEUQqmQs6UnrjpXQ1U0+zY7yUhn3z9zv/v3A7AO488SdyCfy2J/b736uRtnWJPtgX/WelKWnUfFGBIDNxqu/9mpcf9f1+NKLv4RMLIMT5RNYWVlBDDHnNLD1J40zPVXCzSpAVGratUvZpPw3HSUeQtu3uBWD/NZvvRXPPOmZ2BHfganSFD77X5/F7VO340/P/FMslddKhfhVG3hoHsc29rBELq1X19SCylG7xK5MLINzR88N/CzVl8JQYgjnjJwTgLQ0GmXKgC/Neauzoz2se610r/nWNXj6nqcjH8nj0OQhvP/u9yOMMB4XfRxm52ddXp5QHp1mzpHzpLOjF1406wbVq/H6C16Pp3z0KXjXd96FF53zItx29DZ8+Ccfxgcv+2AAcaOy9O07nU5lv2uuEAj2vCb0qORSq1xbYVkDq5D7x+78GK467ypEw35VqxUXGqCoMVBd40MIe73ul59+Od793XdjT3oPTsuehluP3Ip/OPQPuDh3MbAIpx80yCKJi8gOc+S+6L5T1KeV8Y37v4F6o46Dgwdx79S9ePNNb8apuVNx+b7LMTs1GyhJJHRNIiN7cKstUkTRwtGKVDSTFyVxtjI6NshXnnslJquT+KNv/xFOlE/gvB3n4StXfgW7c7ud4llZWVmX81DDajunAJuXGvViA28/djsu/sTF7t9v+OYbAABXnXcVPv6Cj7s8MudjDbPCF5q/0aGG13q/9vtWn+m6268DAPziJ34x8PO/ffbf4pdP/uWA02AZ69Fo1BlWLRujEVFngnunB95nINtxMCYrk/jNr/0mTpRPINOfwem50/H+//F+nNw4GQ8VHgqQJrQjksL7ijRwqNHY7Pl7LUd26Dy1vIIGwDoSlJ1msm5Zs52Mo6Wj+K1/+S3MLsxioH8A52TOwbtPfTfq03VMLE8EGl4orKufqzKj6FA3pYetjCfteRJuuPIGXHPTNXjHze/AyYMn432XvA8vfexLA5B0s31XR0EdZGXUAsEOfcxLW5na7Lz7xo2HbsSDhQfxyvNf2dLv23OkMqKvrUYI33/p+/EH//oHeO03XouJ6gR2JHbgV076FVwSvwRjh8cABPUEUSvOl/pfnU4buGzV/AuLBVxz0zV4qPgQ8vE8Ljv1Mrz2Ma/FUnGNQGzL/+io6V3xwJoMaKko9z4Wi3nheh+Zq+ekLgoK28pxvOzMl+FlZ74soIjK5bKDi7RmS2FOVbyE8DQC04sBtI6Q5J5mUB7np4Ltm/vj849H4fUF77Pqe2hHF7I4leXr66xjo2PmFFk7SVKIQvDMc202d52zfgZZf3q7ka651sapElN43Rb1Uy
g5B1tIzxuUSIjZaO6NRgPX/sK1WL4gWP9XKBQwNTUVYOHbnrf6rHxe/puKlixaPXBapsDblwjv6wFpVWbs+PILv+x+x+aOdS/4rNqJTNMYOlciSeoskVjlg/g2m3utVsNfPfWvXFMVbY15onrCEf5s7SjXVuepjPH5+XkXefL/2oXXW133p+58Kr73ku+t+1sl2Kj8W31j60JVvilfPBO2TFD1ESMqKmidbzN5uWDkAndm9XkZXZKkqDdVcS/q9XpAucdisUB0VyqVnEyrjmx33ZvN/e0Xvh1/9D/+yDWzmZiYwNjYWKD+1ufEUY9r9Mx5V6tVV+2xkVx3M/fn7HsOLrnqEqcreCb1TntbukSnWQMBIEjwjUQizvlQKJ/Pq7JFealUKgiHw67xi8656Wi0MMbGxhoAHjWvsbGx7blvz3177o+C18/D3B9t896e+yM/92Yj1NjUZK/CE8eOHUMmk9lyMlU3o9FooFQqYffu3QEIdnvuWzu25/7IjO25PzLDzv3RMm9ge+6P1PDJu2+0ZJC3x/bYHttje2yP7bG1o6Uc8qPFE/l58rqB7bk/HGN77o/M+Hma+6Nl3sD23B+p0WqEvJ1D/m/22p779ty35/7oeP085DK35/7IzL3ZaClCzmQyAICxsTFks9nA/zWkQxHp42QBFotFTE1N4fjx43jooYdw/PhxzM3NuU46rB1lOQUbI7C4n92L2IyA140lk0lvcXmxWMS+ffvcfDebe7PBZ9LWcHymiYkJHD58GPfeey/uu+8+HDlyBDMzM652FoBrNTg6OoqTTz4ZZ511Fg4ePIh9+/ZhaGjI25S/VCph//79Lc29IaU11WoVc3NzGB8fx/HjxzExMeG6Q7HovVAoYHFxEf39/RgaGsLevXuxb98+13Upn88jn8933Oxhs3XPZDKB5h6UFbaXnJqawvj4eOAWJ8uM1DpltnUku1ubtyQSCQwODmL37t3Yv38/9u3bh9HRUeTz+UAbzVQqhVgshkqlggMHDnQtM7o3lB/KzuzsLI4dO4YjR47gyJEjmJ6exvLyMpLJJHbs2IE9e/Zg9+7dyOfzrkGIXjrQrNSmVXlXeV5YWECxWMT4+DgeeOABHDp0CMeOHUOxWHS9nXmrEM+lr+TGNkxgZ77BwcFAFy92c7MM7E7OKp/DXhRRKpUwOTmJsbExHDp0CIcPH8bExESg+1Kj0XDyMTo6in379uHAgQPYs2cPRkZGXOMKbczSrETHzr1dedG9KJVKTv555SWvGOU95nwOsuBZJZFIJDA8PIzdu3fjlFNOwYEDB7Bz504MDg4GGippXXupVGpp7tQxPKdTU1M4evQoxsbGcOLECXdOyTBu/N+6XXZ1y+fzroEMdTl1u/Y/9/WXaKZzWl13zp2VPtPT0zhx4gSOHj3qutLZds5sbJJKpZDL5Zy+GBgYcM9CPUIZaaeEyyfvvtGSQeYH8pYXHWoc2NOTi2ELpH3vye/1/7V2SxuFaEcjW2C+0Xs3m3uzwWeiQWapTK1WC1yhx/pEW2umdcs6d3U2tPsY37/VuatxY92fGnit+9P6b3sAVKnqHHlg2u2m02zuNMh6nSJlZXl52TXG0G447IpkG76zZAyAKw2ytetaH6s/o7HhPmhL125lhnKjRoPPyL7fWv9dr9fXXb2osuG7E7nddc9ms4E58dq/Wq0W6CjE9w6Hw4G6eC3tsHW42tnK3lil1wM2M8itzN23vlqGxXOztLTkvT/a1g27nvtyyxblXY0GO7ptpnDtmW1VXrgXLAFi2RZb0rL8jRfzsOY1FFpr1KP7YZtSaJtfvajE166y2dzVIIfDYSwsLLi+3tq1jPPRsidbk8uvWotMB9qe11bqrFudO29eW1xc9Dbd0c/UuwVU/jlve/1lp+1iN/u9ri+X4IHXRbDXVGndrtbuRqPRdfW7tpiajdXttXwULh62XuYQGqY2UueoDR666brVzdxUyWrEwPo/1vHSm2ZTEBoC1udVKh
WnmClorKnmZ3S7rjpfNVZa82nrP7UwX9sH8v0oF4x8+Ls87MBab26iG+wW1MqVad08p9afa121to213ds4/2ZNBHqxB+oQac2lbcaiHbBoNNSR4/fAWkMNPvtWDqtnFJFTHaPtHLnWHNZIaz3pVjfb8A393GYNTnT9uSdca/6b61CtVl1fe19TnF7uk+6H1taHQiFUq1UkEglXtw1gHaqh3fPofFrDzDXqdlCnqd7RHgDUCVx/2/yGF9bo+WUzpV7LTVcGWQVDjTH7hbIlGXv56iX0LHyPRqPucFE56FcAbkFYPM+N4+hl5xd9JjXEajC4ufq7qli3avgMm1X+WrSv3XLUeSFUxgb97FDEHsD6bL2cu88Y66HQpg1MacRisUCnKP5MDRsPHJUaANeytVgsBqJToixbZZR50OmUsiUlm4TQMdV56+i0w0+z+TRz2FRe2PRCLyFhikP7niuiw57ONBT8vK0cdn01NcYUhzYjUllSB94aZtttbDPotJfDh+6oIYjH4+4aTM5Nm+NwH8vlMubm5lyKge/L9+P56HZY/agOHo0zAKdz6PyXy2UXsWtLX4tMUL64Nvq1k7mqE0d0Ts8AbRabJul55N9GIhHXk5vOhEV9emWDehIh245RvMaNL71jlYrJTUCgjf7+fgfRUEnUaquXeGcyGQwMDKyDU/naqmhH854+o+G7iq0d767dTfRFxopK0OnRvEg8Hg/Mk3lNGvJ6vY54PO5u11Jj1csoWdfTHgw9IGzDR8dLr7dUw0InkD9T2InX8dFw0MDzXt9eGmT7bGoseGGG5qvYK1ovyeh1tNZszbVrksoL8/TA2tWQepEKURR2EgPWrgi0zuhWGDJ9FvYJ5+UYvEaPeoYdu3j/MQ2epsA06ve1K91qY6yRMR0bzccTFUylUi54iUQiqFarTvdYLgZ7cSusTcPHfeqlvKveZ7cudrSiXiI6pXfI6y1QvF+bZ5ej3c5vzYZt8amdw6gz2bELQADR0r7ueouczyD3SvY7Nsi+A6+e2vT0tCMXMYmuCkk9Ph4GtuRjM3smwJPJJAYHB53SYF5KD1GvBM3CwTaK8xllHTa6sQePv6Nf251bswiZMDQb6Ouh1JaefAbuRTgcRiaTwcjISKBtJT3VXgw7Z1+7Tr18QZWoQqM2F02546EibEbPnFAd2w/y9qteGWTfOWBEYFEiXoTO9nvW8eHotVFWx822MqQx5rWXXCtGaZrj1vnZ9bNy3kujZtMBJOuwDShf6vhUKhWnZBmt+Xpd2z7ndt5baZj5WQpPx+Nxt0/pdDrQyhSAQ5joSNPxLBQKDq3gHlKHUld1e57t2tiAgHKtOpLtMtUJYv6eESeApr3RO0Uem0XHtjWz/ps2SdMgzENnMhkMDQ2hUqkEgkKb+uh2dGWQNVLTG25UEfGQkBmrB5mRpOZkCR1QOff19bmrsbjp2ou2W8jMws7cQGvstAG55rI5fMZXnQ0lzvQKirQGTqF+HmLmXrnOenjoGDHfo3cN8/17lcux62y5A/wM5ik1mtH0hB4wGj8lxRHOo+dORRuPxwP9gjVX3s0zWUdDIURCqWSk6vpqrladDf3a6+jdNta3CAWVErkdVFDs36zzUbh3o0sxgO5kp9kaMx3Aq/QUhdN1JtJiL8awF2RsRe6y2bC6wkbHJB7ymYE11I3PHwqFAnqXuWMavmQy6W4Y4971SqasUVakCljjedCRptOvz0sonsxmJbpyLxTx6mT4jLINrvRyFQYD+n/Aasq0WCwGeqUTddD8fi/QxI4Mss9jtbAp4TpC1EoyIvtUWXjAWvLdvi+FTskD3SotnyFWI2e9KYX2NBpTpa55KfUKlayxVRGEKnPOhZ/N9WZDe3rYNGj28g97YUYvlJSiAprDszAdAPeZuob8e80jkk8AwB02Rsj6u0rk4Pr0KjLWSyL0IgK91o1RG6MazefzedVQ8o5pkvB6oZQsKVENnMqwMmB91xkyB6j5QEKSNHI2Lw
t0Jjs+REgdfxpkRSCsviHszrmTBW6rEny3OT0csDXXinO0aQw6RFxLdagZyGiUqlCrliT1yhDbvHszVME6Z5wnAy46FalUyt1TTR2kOrNbQ7eR/CsXSH9HYf/+/n6HPOqlPeRY8KpJtSfdyE3bBtmXl1KWX7lcdsaY3io9VT4AvVYSRZT9q7f9cFBA7QJ2Y4ybGWJ9Lj5PsVhEoVBwh169cB8jWBmE6on7riDrRGH54HAKMGFGzoWKJxKJYHl52d2MwmfUyNJCyWoIu4Vk9DAzItBcEqEhzZmpsrSQHaFqRm+MlCORiMtlWZiyWQTX7mjmdStUPTs76+qqaZCXl5fdnCgXjHR4hvQGMJuS6YWBsO+hJSk0UMlk0hlblgWl02lnyJhL5u/p7+o9xL2A8azT41vn2dlZzM3NOYNMx4zPRRkbGBjA4OAgBgYGXC06S5z0btteOsvNhiJoWo4HrC/Poj6hrNPoKi9E9bJ+7VZXWgdagwtLtOTvcb1ZHhQOh11kTz0PwD2H3vZGJ4LleXSO2pm//V3V83Y9+Dxq/GlvlHeh/Iv5+fmmDt3DGiH7ImOfIfYVXavxVY+VzLp6ve6aFhQKBVSr1YAAAH6j1a6g+SBTm5dUtrgS1LRYnxEPD4XWZWoNIBWUGuZO7pH1RZj2cHBdmT8mgYl120tLS0gmkwiHV6+aq1Qq6+BehXT4f5rP6UTgVPn4SECMhjOZjCNRaImNHkpNI1QqFQBrdZxa707HSCM8HiJfM4J2hnVK9Vq8ubk5zMzMYHJy0jV6YE6fjoJGbeFw2OUBFSq2ZKNOI03rvFnHSHN6PJM0XDRafGkdqr4o3yrzG/UJaHWNfc6/IhAzMzOYmprC1NQUZmZmHKGL8s88ai6Xw/DwMEZHRzE0NOQal/BsaM13rxRrK4N7wKZIAFwlAMsQKVfpdNohJ0ra5JncSidC9Y2F1xXZYl6VBNyBgQEkk0mEQiEsLCygUCisC3z0Gke9cpef1Sso2D6PPhPtj/0MX5qEL8pMf38/VlZWAnqtm9FRhGyT+eymZEsQtJRCySEUNrLWeEcwu9ZwM8iwVsKFj7rf6kI0i4jVEKkXVCqVHDltcnIy4I2Xy2XnpfLwMzpVg8zvtfmFGppOFOxGxpiHg0SEfD6PXC6HaDSKpaUlR/BifgSAQy0IndIg2/xaJwfDJ/xqnDnvdDrtPGQOXSMLWfLg1ut1VCoVF1lqnioWi7n8FI2L3jfciRKzTinnooxqdl5iZzpG/yob/Eq+RLlcdg4HEIySujUSzdAUlR0SJQcHBzE6OoodO3Ygn887o0VZ1qhAXxbmVtJLp8MSK5W4ODc35xzkycnJdfljrh/ZsSMjI9ixYwdGRkaQz+cxNDTkqjasQd4qA2fTZMAaU53OGRW81UepVAqh0CrhaG5uztXUM9ixkHs3CJwOX65bHTHOgTnf/v5+t95DQ0NIpVIIh8OuJlmrVSzKqhyYvr6+QDljt0PXwabLNM3HtICissoNoa2zddR0th9WyFrDfVv/qhc/a7vDWq3mohMa4Gw2i8HBQQcfxWIxrKysYG5uDuFw2OWANJ/QK6/VEkTsc+gzKPTICJn5Ktu+0QfFMkJrFh23C4/5omSFjvi5wKq3OjAw4BRPPB53jEZGGISQKPCq/LQAvtsciTpSdt5qkPl5+neMzG00GgqFXOMKu6aNRsOth+YNtWynG2PhQ4nokDJvrM5bvV4PkIm0PaNGyFxrktA03dBt6mCjKJkOcSwWw8DAAEZHR7F7926Mjo4il8sFImPLi7AvTRF0Cv/6csfWKFvyKBnsjI45h0Qi4RxTtt9lW0RtL+k7l70Yem741ZL3uA8AXMqFz0w5oTEmDMx9aDSCNda9LOFSfePjFtBxJ2+j0Wi4nhHUPayU6evrc2Qops+aVV000zutBl42n60IHddE01haEkbngp/VrLSUuj2ZTPa0hLLjHLItCbJNBxjt8OFJwachVi+VjR7C4bCjlReLRU
cA6JUhbkZG00i/XC67l5ZVULlqAwIefoVgte0jX2qILQzZqXK1OR0ekmQy6f6dzWZdbTENArvoMIIB1hicPii/14QQNbJqlPUzORd+5aGgo8a/t3NVT5cGWfvn2ran3RgLSzBiyka/as09IzDNtdKjVkIXIx46dNrowa5lu3P25RMVrSCKRciR6IrNtVrEyn61UVqnw+oaLV0hikXnn5ExHUymBrje7Iufy+XWtcjsJb/Arrk+h0XnLFJnjZA1svqyUbzVLb16FotuqZ7hfmiDHi0n4wuA10HbSC57MSwqpPpG50kjrEabgQGDT0Vn2OQklUoFnr8XxK6uy55UmJQYpN5fX18fUqmUO+T6yuVyzntSOrkKk2+zOtk8Ve7azpAet4Xci8UipqenAzWOrNekE6LlIaqkfMQcFcheeK2qeHhItMOSGiLCu5YkpcpCP8MagF4MVR4qvJo79UVHAJzXzNwx94fO0dLSktsHzYtSGevFAXYNWh3WSWCdJedSKBRceQSNA5Wk9sHNZDKIx+NoNBpO/ui8sn0ljaM2NlGD2MrcfWupip+/0wyW1GhI4WrrUFpotBeQrzVmNMo2WvY1hqHx0CjG5sJ9eeNeG2NrhFU/asWI3RfLbVEEko6HVnlw3a3x7lbfWLmwLXZ5BrRTF88GHVVGm2RSc598c/bthZWrTp7BpsdYckUZ0ve1ZDgN3qh3tN0w67z1bx42yJrD59nYqIqLQZKCwka8WYjeNzeBB4SfoUpEP6ObOWstrpYysUyFxtfWkWqOytbr+iASn8LyKa9OhzXIXGPmgfv6+hwr1kZZFirSg20jb1901s3QZ6fw0nFTeVI2PQ94uVzG7OwsJicnHYOZ0DCNMrBG6EqlUg4lYGTU6sUBzYbNH9sGFbzNrF6vu0iXeTVl+TKnz4ivXC5jZWUFlUrFQe6EtLkGthxmI6fUGjRbq96sJGYj463Q5UY5yl7CvXY+VmbVaaPs0pmxZVqWYLkVJK5mDpBNjbGSROFaX0kO5Yz8mtnZ2UBQoEaAz6JphU7Pr55Lq2MUoma6JRQKubkXi0UXZRaLRTQaDVSrVXdONc2nkbe+uiUFWhSRBjSVSrnKH6bAAKxbe64rf4+ON5Et8hOy2ey6Wu9uRu9aMXlGKLRa6sToeHh4OEAWSafTiMViTuEo41cFWSEeIOgQcLSyEHwf5v00ymKumOQtssT5Yk7cduxiDsd+vsKKvTrsvvfnYaG3xpwkmwOwBKUZVL+8vOyIXgACnmovvOzNnoFfqVR83imjYl4zefToUXetJPeGrFM6dlQeuVzOpUn02sVOSHWqaFlqRbYv5YYRe6Oxmk8DgFQqhXw+j+HhYQwPDyObzSIUCjl4mwzUhYUFB98zUuWaWIdJ81ytztc2QlDDxvNmGa9EFPiyDncvnMuN5m9h3s3gXXvTkSWkWWPcq4jezludIMsQ1wYTuifW2bAM37m5OUxMTGB2dtbpI5KpgDUiYLPyynafj39DxIkliTT+jUYDCwsLrvkQZYpo59zcnMvPahkpUSDlveh+2ZvOuonuFUFMp9PIZrPOGdVKDUb7tDdqa4hizc3NuYCR+nVwcDBwmU+3UXJHBnmjyE9fSqzgHZm8f5ewXSQScYujuUx6J6y9s55ju3kHm/tjgl7rGaemptx9wuwuRjhb8+R8D2Xi+T7P/rxXUYQaMOu9cr2onKjULVxmYS++n+ZbmkXJWzFsxGfTCzR+09PTOH78OI4fP45CobCOyKOsYR5Akgn1ruF2D7uNfHylcXNzc46c1Wg0nPLXMpBcLudaIjLCYbmflqHZq/P02ZQZ2sqcbSWBRsiqgPhvJa/oFXm2Vhfo7cUudu4AAmdeUThrkBVls81MfIZ4K+TayojmvLmmmhKjUba3VCk0r7LGqg8ieCQ1MmLVs9ur6J/vyyjWGirqcEbLdIyBtZaldCoZRRI5tTlp3a9e1IXb9IUaZAZlih5qVG33lfs3NzfnAs1sNuugeO
6VIlOdzLnjCNk+gOL++pVKhLkHsqypGDlpmzvw5b9s7qVd0pF9PyWiaS0poy56oarIfGQn/WrzR/b/ezXUA9T8TjgcdsQnKh77vLY9ps8YN2tg8nAM377bCwXoNLG5iUZxXAvKHHOHtg65HWPsU7K+xjHMSVGh2Jp0OklUWrVaLXAZBZUHZc92UPPJVatr6TNuNNjW+SmXy4EGHwpV833tz3oZYerc7fprCotD5deWY+l+9yo3udHcrTNpjfHMzEygFaN2y1OWsa0C4XtoDwQiJXy2XkXHKltW19hqEZtnZfWGRS8Yadu69Y2McbcRshI8qQ+08RM/X9FOnjOdO7DW74DRMfdPDTLPFQmz7Y62DbI+qNLGbc0qDYEeEi0mp2LkIljDzr+jcVbBpJLSaKFVY6fzt+UaOk8KjzW+unFAUHA1YraGuZfGWJ9FiTjxeNxFwxbm48FX4dH3UAOue/hwG2P7fBxUcrZWvFarOdnTg6cHXevA23U0mhljZflqaoOojl1TyrE6RnrrDF/kADCa0LWw821nXzb6OyrUcDjsjPHc3JyLfNTw8XcZcRKZ4HnYqojTQrkaqTV7Vn61sLcPZevFvK0xtqUyXFftZbCRQaasqNzb24msjPh0aafG2Oew2bphJUepPKvjpE2hSPBlYKb9AWxZXTfGWBFEDViYA6ezRMeGRlZTORaFDYVCzoG2e6b7Zp2/dkZbBllhKn1QXzJeveaNomlrIBWa43sBcMQeCjD/n19bMXr8XIU1Cdtyg0g46O/vD/Qv1cXXxuNKcNDDuFkk3wslYL1AzTMymic8T4Wg5VpcDzpKGlX42I4Ph2G2sqIGjf+vStQSLAhPa2mLwpXtRMdqENSr5lqyvEnv+qaDwNSBQtyEpIm6aLtHKjHmqNSBbcZwbnctm0GZlFtGAESumCdXZi+vBFSnhznETpjrraw9z5CP0OVDqywpSvO0rELg/nAtVOl2O289dxrVkjTKngYKWytkbclgyqXQCg+dtwZIPv5HO/MHsG4NbYROx8K2E6au1L7ymr7JZrMYGhrydkyzqaReQO0aIZMdrlEvKyC0jl3ljMZWSXTqbKltYEMTnXu7RrntCFkPt7ZqVMhBcwd66O1LF0+VK5UQSQuhUMjlJ6gwrEOwWZTMz1CCAiNdrRXmxmnUowqXxk3zDPxsHkT1Jn0ElF4M6xzZ1m31ej1giNm0gnAX8yDhcNiVAtie2w83XK1G18LnREP03xR+pkLIYCaDn6zqZjBYq9GxMi2Vfa/5QBrWWm21EY5GNkoeYReppaUld48vm7SsrKxd7q4lJs1KtlrZFwvjaj5VI3c6aZRTGudMJoNSqeRSOAMDA67zWSaTCcClCp92O3xQtS/3bQ0Iz6BNR7HhD5nrFoXjWunoJLJU+Ja8B6Y02PVvYmICMzMzzkFWgqVFAfh+5BxwHXSdfQ6sjY7bOcc+Y8xGSezJoJybmZkZJyOKwpGcCKw6mLziddeuXa6T18jICAYHB71XGnY7aFMYfFHfK4rGtqp6E5wGZ8oTmZ2ddc2u9IY0fantszq5ldF2hKwwKQ3hysqKgwYtg7FZDkhDe108LiAVBwe99XK57JSLKr7NHtzOG1i7iF3zfLw0W1shag1ypVJxUbtl41lyjE9x9Bq65ro1Gg1XhF+v192hmJ+fD3Q00namNGq+/E2zesCHY1Cx0yDbRiuUDbI1CYHZywOUsdludMlhoXI6ONoeVvO9hOiUPMVrDdkiE1jNR1FJM+9GJcs9Yc0y66htydZGw4c82ZSRNcqKANkqBEKllUoFuVzOQX0WzbJwcTfD6g0aI9UhFtbm/2tqQZ1SzYkrgUcRI66dNXqtzFfTEiSE0iBTqTNKZvtdrqXqCl++31eiyK8aIXdaW+3L2ROOZpXD5OQkJicn3TOMj49jdnbWORc0aja/z0ZFIyMj2L17t6u2IdGxF/3Pddggj/urFSm5XC5wsQXPrN6EVy6XMTMz48hzuk+adm
LDJdsAqt1n6YjUpUJMxWhvNbIRpM/L1QOh0KuySZnrVMjG3kusB3OjDeJnMTLUKF/7CzNCLhaLXidD50JSA5/TOh4bkcC6HdaZoTKgQWNUR4VAxcq1p2Lmszczxg/n8KVF9KVGhZCjkkMIp9re4b5UyWbDp5wUqtJGB5pr1X/TmNOrpvzQOPOub0bO6jTqbUuKPrX6DKqU1LGxRpnnVh1IjQKYz9YzDMDlA7spJWu27nb9fUxri7ap4fIRNWmQrZHVNIM6bTZd0yqqYvPIGklpzpgvrrsiFJb9rrlc5bBoVGzZ8N2QMq382hI/Rsr2alHOSaF0yhzTSmwK5StF7KXzr0GYRsu6N9rYxFeiViwW0d/fj6WlJUfE0z22UbKieNYOtjI6InUBfkHQSfClm6pCyUWKRqPrvCn1MEgMAxDwlG302YqRU6Osm2X789JIa8mQbQ/Kyw0WFhYedqO12VBhsW0dSRzS3KtNNzwSuWMdvjyywtQ0Ioz02dlN52yVaqfPY42DJXmozPhIThZyVBhQo4p6fa3BTLPnbZfsYhEtzUtrF7NMJhOA0pVhysi+Wq06+JFQJCG/xcVFJBIJL+rVq7HR+dY9oePEs0mDTMVKY6fGUlEUS1BVgmO7ypXDOkbcCyJ1lmimxD6t7OBe6veahuPeqpPVybxVztV4ETXRq2jn5+cDxK1QKOT+BsA6jpF2gPPJdC9lxuZxuXZcY/Zf0Jy9IkSE0ZeXl1EoFJBOp1EoFAJBj4/bpMEkA6VWg7CO65D5oD6h1UNpiUWlUslFqMBqtMrf0/dultfpdlgl7cu1UEhUGVFx0gvyCZHPkPmUZ6+Ezgcx+eA6wu96NZ2219QL2xUJeCQdjY2MskbI/F2FkLQTkqYz2n0e3Ud1PK1By+VyjpBVr9fd79jGHkoS4f5o7lBlh+eqW9TCGmTOnfcDDw8Po1aruVIrzo+KS9EXIkJEVzKZTACm1Ja36vh2Ouw50ue3eoH6g+jP4uIiotGo667EignmBIvFIjKZTMAh1es61ajZGuxmytU6QMxdplKpdTB7LBZzUKlFVLT8zN4TTOeOjijn6asuUJJUO7wJ/V7lVhnj3Hc+SzgcRjKZXEdqpeOmXeeU4axr3Mv8sd0T5Trw2RSRUDSCFTw8z5VKxaWNeFkG13VlZcXJFPWnOhztGGOgB526mh0YhYMYDRQKBXflHCfK1pkKRVkCFN/fUvo7jeCs0edXq7hXVlZcrZq9QpHKJxKJBJh7qvx8kVuvh82fKYmFioetPxcXV+8a1vIgLT3oJezY7TPpsEqOeX8lwbC5BrAaYSaTycA9xJ2S6miQuGZajkRFtbKy2qCBZU/MV9Gx4eerMdZ8LRU1I1B7ppo5dpvNW89NvV53Rknn0mis3pvNaEcJRkpMY3qoVCo5pyObzTp2LaMJnWencqTnUp0TCwVqzprRCgBXm64BQ61WcxfX2PpX5Y+wgQydVBpoKvTNnkdTDqlUyjla6sTlcjkH85K5a+F2bedra9L5nGqMtRGOVhh0mpe1UbKvVK9Wq62rXlCDS/nis9NRYr09P0eDICJLvR7NdL3aIjo6lFs6+TTE3DdFsur1uss18xm4J4zAN3Li7Oi6MYh+bzF7YI1gVCqVMDMz42qP1WPp7+8P5EwoBEqa4uHyeazdGGV6TFYZAqs1xyQGaccfjVp8EXIz1ICjVzlkH/lCc2ZKJikUCi6Ko8BQyHx5HL6/dYy2atj8ug8eZlSg179pHpdGAwASiYRra9cK6c83rCNAqNCy0OkAaAN9jRJCoVCgZILKTXkGm/Ef+LUTOecz2Pdg/j2fz7tITEtr6NipDNFYR6NRp6BYfaDIkSq6To2yGlVNifnK1/hZNqfPs1GpVFxpjcKnNGapVAqDg4Prrmckk1zP8kYRMteal4eQ1UvSoeaQmR6w0dn8/Lxr5ZtIJBxpin/LHL7OX+t6WWGgnd7aQRqbRcmKQFGe+Wy8FhKA467QieAaAHA9rT
WVo5UuOtde6xtGyr731n83Gg2HPlhnhwaZvIOVlRWUy2XnSJAoTC4Lz0Sroye9rJvBixRebcJPz0iNqyoqW4enhtuSFnoRzSmMQaUFrDoStgZUo2NraNWbt2vQbHSqrPi3fHHtFFIiTK1s4Fqt5i46J9tQb7+hMaZDZA2BzncrYHc1wKoI1KCqHFDQCRtFIhGXV85ms66kpNMImbLBAwogEKnpXicSCWecbEtEztEaj3b4D50M63TaM0oOwcDAQKDZgzJHC4UCwuEwqtWqiwYINWoNtcLW6pB2O391cH18D31OG9VpZMccuDaRUViZ5S+soaVDBSDAEdhIjnzOD9dYHTFFHzR/rPBnOp1GNLpWy873C4VCAaiY0T3TJzSOzcrkWjm3Noq0uWS+uH6ZTAbDw8NIpVIuCta+B3Rk6/W6uwGKz82zw/vQNVXSy2Edfn6v58+isxulH5SgOT8/796zr2/1Uh9f965WRk8gazt5FXw+oDYd0HIjvVxC2wT6FkZzJr1gEuozqFFWD2mzz9xo2M22SrhbodODYglcfOnlCwDcelMpaZ5JjTsN8EYpgl7M36YqdH0UwmvWQUchVUZIsVjM5bhsHrnddVe4WOFTLRcC1tIAdIjYBERhVULFvtTLRp/fzbDwsTXKdMx8dbs0QMVi0eXT6GA06zC1srISaNTTqwhZdYqeQWXmUmY4NJWja65/q/vG/DPfU6NdOnabIRmUEdVZyhDXYEP/rdFxX1+fQxWTyaTTm1xjPqvvEg1t1mJRxHbkxerTZg4y14fOAA1yOLxG6CVqQR3P94tEVksWbbVCr4fP6deXzSOrI6dd81QeuQdE4JhGoCzxmR52gwysb+SgbE56mazrKpVK65jM9IwIy2iuhIdHoSZlBPeKCKBG2Xrl6pmrkCsUodEqFYG+NqvBbndYg8TyJl50wC462kGHJAQVfAqiQt40bDYaUSPSijHZbP6W6KRpC4Xx2Fhfn0U7R9GI8G+UaKQRRqeHXWFKVVSqaCmj9jID/nt+fn5dioPrTCew1+gDhy+1pFG//Tnz2arMFVIFEGgZaFsH2kijEyeIX316pVlHOe6TspUZjVlnUtENwst0TmjQtFyHLRdbTS9oOqOZ8tcyMpK7fDC8Ro0MaBhZ6h3PequVr+tVJ+kO39/aVJLqSK43dSn3Q880119rf62j3auha2hf1ilSva2IIwmxRE+UAc994blIJpMONaL+4bloZfTUICuZIZ1OO4+Pm6SdtxTaoRHgBe9KxlGDrExCrZ3tJWlKIyKrDGweyypXLr7t3qKCx83Rz2h32AiSBkvvCuZX3ojEshYt+VCiBp9nZWUlANFbmF7hWqAzFrzNS9ER0EiLB5nKis/HJi2sf9TCfh52GmtVoN0ecpULAIHIh94+IUSFERlJ8rIGi0rwGfke+nm9NMwazXPYqMfX4UoJaDyTQDD6tP0FVLl2M99mBllf/BlTFmqEKUP6ngCcY6pONUl5LOuhUeMtaspF2Oy51LnhOutXGgQ6EJRXygkNgKJblGcNTnitLW8Qs13p2g1WNC3VzPneyEDzrFkCGNeOqBwdC8uh6HV0bPWkBkVqeBXdUdKl1iLzli2mALV+HFhDHpm60sCS5YStjK4M8t/e/rd4z/ffgxPlEzgrfxZed/rrkEqkAtdcMa/XaDQcM0+jlpWVFSSTSRfZkFXITQTW+vrS0PtyJO0YhluO3IJrv38tfnzsxzhePo4brrwBLzjzBU0FUg2y1vfp5zLK5CFnDk7ziTx4nbKv3/2dd+ML93wB90zdg0Q0gSfvfjLe8qS3YLA+iEKhgMnJSRw/fhxTU1OYnp7G9PR0gD3LQ6AdpKrVKkKhkIOTeOWeEkbU61alaBXPRuO6H12H626/DofnDgMAzh45G1dfcDWetvtpTugV6lF0gXk1NiXg4VDPlQcjHA6va0PYq4POfbr2B9fird96K/73E/43/uSiP3HQZCKRwOLiosshad09FX0ikc
Dc3JyLdLjuzKcpxN2L8bZvvw1vv/ntgZ+dnj8dd/yvO5xjpoZXOQjanU6dZCB4X7k1xlbJdgNZn6icwFu+/RbcePhGzK/MY09iD/7X0P9aR85i10DqFZ+SV4PIqB9AAC7ms6tBHhgYCCA4rcjTgfcdwJHCkXU//+0n/jbe/5z3B9aQ8k2HUx1qyvnCwgIABCJ25m7ZYMNnkNuNjkOhEOqNOt5+y9tx/d3XY7wyjpH4CJ4+9HQ8IfSEdSlCdZw5R02dFYtF93Ot0++F0+YbpcUS/vBbf4gb7rkBE5UJPG7H4/CeZ74H5w2fF8jfW4RNnVA60dTf2rHR6hwS7Fh3rFwSvnc4HHa/t9no2CB/+u5P4/e++Xv4m0v/Bo8bfhzed+v78JrbXoO/PetvXStD5iI1hLf5AjKZAbg2j3oBQii01kpQSUiaJ2mX1l9ZquC8HefhlY97JV74mRd6f0ejeK3hVONs4Ut6u/RibZ5NjY16sK0K5c1Hbsarn/RqPHHXE7GwvIC3/utbccVXrsBnn/ZZx2IfHx/H1NRUALJWD1W9QMKozA1q1KH1qnR++AKCNdetzH9vdi/+7Jl/hoP5g6jVa/j4nR/HlTdciW++8JvYEd4RuGiBuV8tE2LrPj7b3NxcQOgpK5rj8injbsftx27Hh+/4MB47+liEQ6vlLEzB0CDZQ1+tVh0LVVEWKmPKu+anbEqjGyN9zsg5uPFlN67lWOsIGGMqGCoh5SAUCoUAXEc54pravGgvnaDCYgHP/D/PxC/s/QVc/9zrkWwkcdexuxApRlCJVwIOo/JOiDbQefDNRaNWm8OtVNbemzKmsGor40e/+SPUGmvR+d0Td+NZn3wWrjj7inWcAu2CRUd6amoKMzMz7lwwN0/dx0salBHuqz3uJK107Q+uxYfu+BA+9NwP4eTUyfjuA9/F1d+/GvPJeeyK7FrHoSDPg/C83oDG4Iq/y99pFm13O171lVfh7om78fcv+HvsSu/CJ//tk3ju/3kuvv/S72OobyhQQaCGl+dQq1Tsv/my0DUJwKFQyN2RrKm1aDS69Qb5L2/9S7zq8a/Cy897OZaWlvCeX3wP/uXIv+DbhW/jotRFSKfTzosgNKm4PbCWmyOUTeWl0Y52oFECg/bibVfoLj14KS49eKn3/1RIfJC1OgHWGdA6SGWtKmRtc20Wztpo/PNL/xnAWu7yA5d8AKd84BT8+/S/Y3Bx0AkM4VxVpACcQlGjTPIFGymo80HDqNAkowtFCFqZ++VnXO6es16v4+1PfTs++OMP4rZjt+GZQ890Ddwp8LpeytK3eXGSuXjYdV+6yZ/5RnmpjJd84SX48OUfxp/cshYZ8zM1N6VKnpe4UwFTPjTnx65Nmovr1byj4Sh2pneui2ZogKiUfKRAOkl0ohU9sikbdYJ6AVdfe+u12Jvdiw9c+gHnlA1gAMdwDD+L/yygDywRlLXIvny2foYaR+VSaPlXJznOkdRI4N9/9t0/w6mDp+IXD/ziOiiVz6b9rgmP8gwzsmSjDZYr8rIPG6B0irSEQiH84KEf4JdO/yU897TnYn5+HrkDOXzuns/hSPkI9vftd06lGmRtI2yrDGiUGo01kqimv7rlo3DML8/j8//5eXzpxV/CU096Kur1Ov7wF/4QX73vq/jIv30Er3vs69a1VKWx5RkgWqc/V+OtzXz4AlblKh6PryOgKmrUyujIIC/VlvDjYz/Gmy96sxPqvmgfnrLzKbi3dC+eNfwsl0PWPKcSb0gGYN5SiSPMMUSj0XX9fGmcu4FlWh1aHuKrX1SCiWV3ap5Coz2fQe5UeRUWCwCAbDQbIEGpEqLxBOC8Uxpgtv1k1KNEtlgsFiAlcO4+g9wqg5CjVq/h0//xaVSXq3hM/jHOEFAJ2ZtX6NGqk8E8s847Fou5iF4VVC8OOwC8+muvxvMOPg/PPOWZ+JNb/gQIIeCUWPIOkRAAge4/Ometqwbgopxedi766cxPsfsvdiMejePJe5
6MP77ojzEUHQrkLamcSJ7jVyIWmr/kWus57FXFg46v3vdVPOvkZ+FlX3kZvjP2HexI7sCLDrwI/zP5PwONPBQhUWfapi0cQoBgO0U1YL3WIcCqvvzUv38Kb7jwDc6IqYxoZE6nWiNzPXOMktlDwHfxSLc68cK9F+LDP/kw7p+9H/uS+3Bv4V7cNXcXrshdgb65vkAXKmCt7BBYu2VO5035sP3Z9WrUXsjOSn0FtUYN8Wjc/axeryMejePWY7di6cxgSZ8aX3Yf0zOgCJyPC0QdBGCdfNnR6jN1ZJCnqlOoNWrYkdrhPiwSiWA0OYqfzv7UdaPhYdCCcjUYTHir5xQKhRCPx93h58096XTakbm2stUah80fa05VO/jwTmXC73wWIAiHWQPdDbTHPM+bv/1mPGnHk3AwdxCHZw+vY/nqfJijpAHWQ0S4lGvJA6TCSNiVHqEq4Va9v7vG78KFH7kQCysLSPen8bHnfAwnJU7C0ZmjKBaLAShaO1mpUZ6fn3esWRpbll6wFnN0dBT5fB6ZTKbtCxmajX+8+x/xk+M/wY9+80dr+4Cg0tPDSWVEp0GjCT3MNG5kNlt+RLfy/eQ9T8bHn/9xnD50Oo6VjuEdN78Dz/4/z8aNL7wRS/NLTinp3c6MjGkkaPA4HxKK6PioE9Ftwx4dh2YP4YOzH8TvPul38Yb/8Qb8cOyHeOt334o3nvVGPDbzWAwODrpyPtaKsm6ajqVdb55Nzk1Jmup0a/9wWzrU7jN98Z4vYm5hDi9/3MvXoQgalWv6gOgUHQ01aHREtKlPJ1dzNhvX/MI1KC4W8dgPPRaRcAS1eg3/z8H/B78Q+gXcP3O/40owJUnEhbqdDhwDKgCu1SpvY8vlcg5qZ3DTLhfIjkwsgwv3Xoh33vJOnDl8Jobjw7j+7utx2/HbcCBzYJ0xVnmnQdbWoLZzHY2tyo5WWPiCRVszv9noitRloV0uKI0xh0bIyqJVWIkHQglFZBEODQ0hl8s1hap77dny/TRapOFdWFhwHXyolGikaJw0B6iesO3g1E2E/Lvf+F381/R/4QuXfQHR6hq7PZPJOENLY8XDwj3SXDcjH1VUfG5Grtqthu+tThFTDJuNM4bPwB2/dQdmKjP4zH98Bq/519fgQxd+CJFKBIVCweWHbZ5GyxEI/6jXTc5CJpNx/Zl37doVuGe1G8M2VhjDa//5tfiX3/iXgPfN9dLvlcSk3zNto7kqKjRlX5Ogo11+upFxpmYajQbOHTkX54+cj9OvOx03/PQGPH3g6c4R4mUBaoxVIYVCIZcvZzMROsuM0ixsqjLV7qDT+YTdT8CfPv1Psby8jHOHzsV/TP4Hvnzsy3jaWU9DtVpFo9FwRKdkMumUob1Wj2eTeok6iw6G7dplnQ2FVtsdH7njI7j04KXYndm9LqLS1IZWZmiEybPG3DGNGQ0b7/32lTp1Mj7zH5/B9Xdfj79//t/jYO4gfvjgD/FH3/sjRHdEcSB1wOk86gIaYIX/ATjOicLszHnzrNIoU7a61eef/OVP4pVffiX2vncvIqEIHrfjcfjl034Zd4zfEejVoJ0MlbiozHDVP0CQV6TlbCydY9RvibAkfLUyOjLIw8lhREIRjFfGAawJ98ziDHakdiCbzTpvlMqfXoft3Urvmx44LzzQTRseHkYulwu0d+xFl65mQx0NNg5QZjKjmHQ6jVKp5CjthMnoNSkD2+ZiO8khc/zO138HX/vp1/CNF38Dw9FhTK9MO4PJiw5sIwSdF40DAKesLIGIf0tlwPKPWq3mlAOVQKsGuT/Sj1MHT8WB7AGcNXAWfjj2Q3zqp5/CFYkr3OXnU1NTziAo5Mi58Xm0/adteZjP5zE8PIyhoaF1F593Mn58/MeYqEzg8R98vPtZrVHDLUduwV/f9tdY/INFRMJrXaOsUda0DRUCUYBQKOT6/LLdpxrkXsp5o9FALpbDqQOn4tDsIVzQf4EzyLxKT5WT9k3m/BjpDA0NBQwCDX
IvI7VdmV04e/hspwQbjQbOGjkLX3vgaxgYGHByrrwS7rNNDTDKoQyofFtjrOe722c6MncENx66EV940RfW7YV11pX8yTPF/eczMjpWg6x9q3shK2/8lzfi6ouuxovPfTEWFxdxIHkAP534KW44dAP+YOAPkMlk3Pzo7DMdyciRepPBDHU55UYva6BB7oWsn5o/FTe//GaUFkqYqc4g35fHS774EuxN7Q1Ex2SAq0GmI6Rlkxw0wGyWw6EIHZ9H2y3TILeqIzsyyP2Rfjxh9xNw06Gb8Pwznr8qpOEQvnP0O7jq7KtcGRNh6XK5HKDkaw9XjdZIXKBhsULHB7REgF5Hx1r6pJA1vR+SSfiVHqJleapSVsjaMoFbNcaNRgOv+fpr8MV7voibfuMmnJQ+ybUEZI12JpNxULWWgRCqVuhUn1OfX7/nOpPJGo1GkclkAl5kuzlkp4wadSzUFlwdNRub2Fp0Ond8TubSeNDz+TxGR0cxOjoagMRYCkJFZZ+v1fGMk5+Bu377rsDPXvGlV+DM4TNx9UVXO2Os66Z7qmQuJYeQL0H5UgVljUCvRmmxhMOFw3jOruc4iJRsatZ3M3emRoFnT69sVEPMqEDz5N0q14v2XYT7pu8LOGKHi4exL7vPOYdcPypMwr+2GYXWJVOeFGK3pX6a4+yk4xXHx+78GEZTo3je6c9b93/KUrfpLHWEqIO00kTLP3sZHQNAdbmKcMhcUxuJohFqOD2jxksha+oKK9O5XA6Dg4MYGhrC4OCgkxt17LuFrHVdU/0p9If6caJwAjc/dDNee/ZrvexqhahtfpiyohC16nbN66tNsNEx5ayV0TFk/YYL3oCrvngVnrD7CXjSrifhvbe+F9XlKl56zksRQ8zBML6DSphTjYV2cNE+y9pruVfs2fJSGffP3O/+/cDsA7jzxJ3IJ/LYn9u/Lkps1pjA1iarY6GH19clpl1jDKySiq6/63p88covItOfwXhlfN19pFQoluFtS0CUkOBbQyoLrgGjkXQ6va5HdCsG+Zobr8GlBy/F3sxezFXn8Ml/+yRuPXErrn3stVgaX2uMoOxGLbxXhjuAQJRMOGxwcBD5fN5xDuxB71RJZWIZnDt6buBnqb4UhhJD637ebB3VKGv5G2VIOQnaaalbCO/3v/n7uPz0y7E/tx9jhTG87dtvQzgUxiV7LkF5vOzK37TuUm/zUce3GZdCa9R72T3v9Re8Hk/56FPw7u++G1ecfQVufehWfPzuj+MvfvEvkEgkAgjK8vLqhQY6J154QTmykY2FIPU8a819p+TReqOOj935MVx13lWIhterWo2SLSOXZ1MRIV1//eoLUroZl59+Od71nXdhb3ovTh88HT988If4xH2fwHN2PCeg/yifrKW3zVasE6FOnDYy6RWpCwC+cf83UG/UcXDwIO6duhdvvunNODV3Ki7fdzlmp2YDZ1BrkflVu3fpsPNS2dFUq8qOIjRbnkO+8twrMVmdxB9/+49xonwC5+04D1/81S9iV2YX5ufnA20mlQBkhdqWJahHuBV9q4HVWtKLP3Gx+/cbvvkGAMBV512Fj7/g4wD85U+aK+cBttCYL6dq80adGGMAuO726wAAF//9xYGfX3vRtXha7mmBsiUlpDRTkBvlNZQxyL9T5rNlc282JioTeNkNL8Px8nHkYjmcPXQ2Pv6Mj2Pf8j781/H/cg6c1vBSQfHzGQUBQQZ8MyPRK+i0m6FKV8v+VEYo8/rqFT/ioeJD+LXP/xqm56cxkhzBBbsvwFdf8FUk5hMo1ovrmnto1EOHTJ1lX9c65Y/0ErV60p4n4YYrb8A1N12Dd97yTpw8cDKuffq1ePGZL3a9trn39jY2zpVz4kudUJvW0edoVj7Xzrjx0I14sPAgXnn+Kzf9XSsneqbsHFVefGvf7Xj/pe/HH3zrD/Caf34NJqoT2JnciRef9mJcMXoFJk9MBj5TeTLLy2ttYDlvRR8saW4ryLmFxQKuuekaPF
R8CPl4Hpedehle+5jXYqm4tM7xUT2m+sxXv26NMeVHK1PsXthXK6Mlg8yJ8b5Zjped+TL8xhm/EYBk2VWGcECzBvRUwBREFk9rDS/hBW4asMoIJknHPiTnpwvpm/vj849H4fUF77MWi8V1JQmMIghv2LIcbYjAw8SN1TpaNhxQGHJlZfUu3UqlsuncOWc9AGwqMD09HejhbOEXCpxtbmAPgS3dUS4AWxRqRyMyWTeb+3svfi9w8Vo+vVqtYmZmBmNjY+vyN4x81CATRVGDpm0bFY5i0xb+3kaGuVWZsePLL/zyhr+j+6MsTiXtKHeCcqZyzztWmbe3SquVuX/o2R9yP2fNa7lcxkR5IsAs9bXJJDNfjbXWhvOcVyoV54i2msdsdd2fuvOp+N5Lvhc4k0RQlIijcq8cDXsmbXmiyhTPq2X1M+/f19fnmrnofJvJywUjF7gzq8/LVJ5tj+krq6FTpOtOfVSpVBx65YvkW1n3ZnN/x4XvwNue/DaHXM3OzmJiYiJwTSfXW/UN0waqy1W25+fnXRMNNeK9mvtz9j0Hl1x1iZMV6kddZ20G4lt3Gyz59KXOkbKv0Tf3lvrRp999oyWDXCqVAAD79u1r5dcf8VEqlZDL5dz3wPbcH46xlXNnnnliYqIn72fH9ro/MuPnYe6PtnkD23N/pIbKu2+EGi3gjfV6HceOHUMmk3nEoL9WRqPRQKlUwu7duwMM5+25b+3YnvsjM7bn/sgMO/dHy7yB7bk/UsMn777RkkHeHttje2yP7bE9tsfWjpYg60eLJ/Lz5HUD23N/OMb23B+Z8fM090fLvIHtuT9So9UIGY0WxtjYWAPAo+Y1Nja2PfftuW/P/VHw+nmY+6Nt3ttzf+Tn3my0FCFnMhkAwNjYmOvC1WgE+7CWy2VMTk7i6NGjePDBB3Hs2DHMzMy469tKpVKgxRrvGmX3GbY/zOfzGBkZcR2L2KKSXzfqmlMsFrFv3z43X9/cdfA52LlK2dRs5Xj06FEcOXIER48exdzcnLumMJFIuOYTbAm3Y8cO7Nq1C8PDw8hms4Hr9jYrndho7g8++CDS6XSA9U3mZ6FQwPj4OB588EEcPnwYDz30EKanpwPN0bVmGAiWdFk6P6n72vmLDVoGBwexe/du7N27F/v27cOOHTvcNZunnHKKd+4PPPAAEomEu1pzZmbGvebm5tyVc+zQxUsY2NWHTQTYN9bW8eozsV4znU677l3Dw8MYGBgI1LLrnpTL5bZkxg4rQ3zO2dlZ94yzs7OYmprCxMQEpqamUK1WEQ6HMTAwgH379uGUU07Bvn37MDIyEri5xzJn25GZI0eOuN7jyobmNX/Hjh3D4cOHcejQIRw7dsyxQLUtpm1pqCUcvmYa7LDHawG1wYmV/3bOKnUNWbOlUilwZzDXenJyEpOTk5iZmUGlUkG9Xne1sNQhnOPw8DBGR0cxNDTk6mLZqEIvjvGVctm5+/SjsrYXFhYwNzeHY8eO4dChQ7j//vsxNjaGubk517tA+1QPDAxgZGQEu3btcvqEcmHrXNstx9ps7nbNWd9dLBadDLO9LRntfFaeAcqZ3ppEGeTZ7e/vRy6Xw65du3DyySfjpJNOws6dO9d1gNOOdaVSCfv371839wcffBCZTMbNV5vwVCoVzM7OBu6KZ2c6ns1isRi4QEVLKbWFM3sd8Ovw8DBGRkYwMjKyrgmRr5rDyrtvtGSQ+eY0kHpA2I93aWkp0BGJv0OFxTo6HmweZu14xYOgtW2k/bN5gm2R5xNIa2Q4d59BpjCxCxWwWp9rb9wB1u481uv1tDyh0Vi7XUkVlu9mqs3W2s6dytXOc2FhYd3h1PaALBHz9WO1ta5af8lWcGqQ8/m8U7bsiJXJZFwbzmZzTyQSrnyH3aG0FlHXl1/pfNAA60Ulel8y9zEUCrln4n7wZ9Z42BrldmTGyo/K0OLiIoBVRrg1olxjzoX9fdl1iZ+Vy+VcLXWrHaI2kxktG2TJme05bZ
+FxkTXVesvKV90lFT2qdD0Mphm57WVdVfjYMur1PBp+aH+LT+zIeUyPPO23EV1jd6e5KvztbJD/WgbwdAxsM4V58ESILunvqZEttFSp/W7vrn71pytdbW0SufPlw42KeI99ypDLKPjOdAb/NS500CNbTXtnK2sa2AFwJ1Jyqv+HeeYTqed/eJeaD8H6gvb00E7dNHZYy/1jfZks71quzGINcbN7pf0XSPGw0FDoReM0xAvLy+jUqmg0Wi4A8N6SJ9QU5l3k0PQgnFbD6ov1qyxUb0a50gk4i46iMViTslpkT+fvd25quBrIwd777J+T+VCQaKgMYqkwGnTE77U+cnlcoG7VwcHB5HJZFzT+FaeR5W9rrWvVlG7nQGrh6pcLrumIPTKWSera8tDTsXNRiE6T9tgoduhRsxepcdeuYwU9Axohx97U1Kv+lerstR6W90LNcBs5UmHT51sXT91qHkHLJ0jIhQLCwvr1t7n/HSyzlxjtlqdnp7G3Nyc64GuZ4A6hc9CdIIGmq19WdusclWv1901g1yTjWRG5VxlQds0am8Ge6sT5ZJ9CTgv3n5nHedeNmFp5XlYV2+76dnzTTlivTF1DedMHaroi34GX622F1bdaPs+8LIU1iHTHvGCGgZ42rmQIxRa35aV3d+0t0Y3N/fZ0ZZBbiZw3CD2xOVF24VCwXu5OQ+zekJURsvLy06BUTBXVlacR6NemrYW7HT4YGtupn3xcAFYdzMIDSANlRprdUSoCDoxyrbZCD9f+7NqYw2utbYEJAyjsJwaZO3tSyhPG8HTQNvbiNqZvzXG2oCB68vnY0MGHgptGKJRBeUhHo+7e57pdat3rU6KGv52h08J6b3NhOenp6edPBNJsJCY7V5kOyF1O6xDpy9t0qPrSujRzkWjA73flpcLUG5465g1yJ2uueod1qSzWQUNMi8KoKHg4J6HQsGrRwln8rrYcrkcaNjCvY3FYm7fWjEO9mYvvUVL755mC9pQKLSud7IiDHQG+Lm2i1gn+qSddVfonY7m3Nycc2B0PRRFoXNhER82ViF6ovui+qudToDqrDGdR4eY8LQ6bWwFnM1mA7c3cR84B643DfL8/LzTq7yLWxHSbo1yVxGyCh0NMheBkcHCwoIzVtqP1d6sQkhTbzsJh8OoVqsAELjeqllup5Nn8eV7mnm1/J7eFf+t98XSUClEZqP7bubKedpuZmqYCb/oWjM/pj3CedhpWLV1o+4R/0YhVsIzmxlkNQY2wtfn0Px9vV4PRJnamN2mQpjSoKGIx+MOXuP8aZS5HnqAOh3NIn7Om/mpmZkZF03wgNNx8BlkdRZ6EQFZ42I7U1kHCYA7182iY0Y3Gh3TAUwmk+688Jq+bvtbW86KGmTmknmPNiMXdYL5dzRemosmGjQwMOBu37LoiToim8m6Omjq3Nu7d9WBVsian18qlZwh03aUXHumHHoRlbWz7noZSbVaDVxSw4iSc2ME3Gg0As4ZO54pv8B3wUarZ5Tz1CCRTrF9FYtFh8LxMp50Ou10tgYIescznVQAzi4pwvGIR8gaqalhoHFWD5CkBSUvMDKmcg+Hw24x2HoMWPV8EokEBgYG3EG3V3V1E0nYyMEaCb2kQVsLhkIh17ieOSLeV5pKpZyxUmheDUk781XYUdde56VtAwkf0glSwWO+UslDCpVqhGxvMLFN7duBVn2wtY3QdH0JbW1EatLBgw6soRfaDlJzhZyDrm07Q3ORCrVZB1WJLXozkebKuK62N3Qv4UidbzMDTZnSKIHQohojjSa47pQbCzf62lfqPFqdu3XmuM40DPYObe5ts975hLPpzMViMec40ImjIdRexZT3zebbLKWnbSdVLm1OXtuSlkol5zDTeVCSZif6pNV11+dRfW9vLAOCeWIGYJR1PhvXBVjlMlD30NFT2WwFqrZrbpFb6wzRfiwtLTk4nVe48gpUEpR1HnpGmFqgbtGfP+wGWRfAZ8R8r3q97hQ8o6tMJuNIBC
T9AKubxkiCSsFGqlS0Nu+pEWmnz2MjHr54yBXKCIVCjjwQiUQCzkipVHJGi9czWi+qE6Osh92+FOpRj5VKRm9c4bozelSjrPA2nSclQuntVq3C1a08ExWMVeaUs41IQb6G+0pcs+S1XgwfBE8lq4qA0RBhSI3c9Yo/37WFWwVD6jM0e9FYAWswpDqZlBsl42iucKugdl+faX7VaM3muXnzF+FRhehpFAkpl8tld482ZT4Wi20KS/qMmKJBqk8UlgXgnEklhNEo0/mlc6365OEaPv1o03HWeaFOARDgS5Dj4eNP2BRJuzrS5qKtA0QjSqdSHQciFXSQLJrEuaiOssb4YYes+aEbedoqaHxg0vktQSiXyzlhp0EmM1jxfPUa9bouJTg8HMMnJIwqbIkJiVFaHtCNR2u9Nsvy5h7w/ZU4xLkobM2yFB+5iEJqDbYemnaMnOYQrSNFZ8rmovRQW4aj/r0yUePxuGOGs1RO71zV9+HndLoPNnWjkZC+yHQmiYTzY2kH0zCtlMd1MpRMZclVvv9Tp0vlQqF1GjofU1zlqhlc3erzqTG2Bk7RK8q98iBYKUAUjigdiVwAXErNXpZRLpcDa6DP08qc+bWZruS/gbWUFteaZ4vQOteP+Xk1yDQgWxEl61BZUVRE569rZStuFFHp6+tzfAOeWUWMbAVAO7LSbL1tgKWXpvDnJABaUq8y8Gm0e2mEdXSUQ7YPb42yGp2+vj5kMhkMDg66ur+BgQFnlFljyrzs4uKiS7zzPQif8aCo0lKPqpuFscpJjT0JIfxeDZ7CWJpnofIi41Q9yk6Gz/u2DpD1mBXm0usJ9ZpCRmhqeH13PW9kEFtZVzXIFgJUSJT7qFC/OgT2Wkk1xkQDstkshoaGMDo66uqQiQjYMpxO9kFzVgrjaVSsjE4qn3g87mrtWT5mCXK9NsiqRJs5Q/wdrqmF01kfS8dGCWlau2uZ+D4nqNVns+iJZeDSySUKR4VJBIKokNaOxuNxAKspDbKzAbgUD6FO5m8pU4oUtZIr9Dk5PueVss4zyjkz98oSNa4FdSmheatrezms80QjrM4Y99PHgbBRpqYSlECqZUN8MWWgzkk7Tn+zQEEje8oRnWj+ruVDLS4uumBRyaXN1qrb0fF9yIAf8uKgwWIp0NDQEHbv3r2u6UcikUAkEnG5ZpYzUBAptGTPlUqlgFKnV9lpUt1uoBodCpZCcSogCvHSc6JB5iHTYvh2WIMbrbfPIFvoRIWexkANr0bFCuv5csq+Wsx2hdBnjLX0SqNd5vWY+1bPWefO99AIjlGRNhbJ5XLukHdy0H37oKxOjax8V+rxcxi9Dw8PNzXIWwFVW+NrjTL3BIBbdzowNsXEi+W5JzTGarj52ixK3mz4SHP2pY4onQkqeDbsGRkZwdDQEFKpFBqNhksp8b2pkBkpU88oTEwHYCPI2hoxXWt+b79S1im3JFoyQlYGPJtpKHFzK/PI9jm4Htx3G+FquaemMzUS5T7pXhFdIXrKdVAntZ35qnxbEi0daeaLWRVB/cmmUGTsMy3Z398PYC21oGvTy9FTyFqFlQdd4erh4WHs2LED+XzeeUJM+hOGKZVK7pArpV5Zi1ok3wuGm43eNJ9qoTsKmioyCgwjecLuSiLoJudjo2Obx/ERZ4C1mkmrAKwQWdjXF4X6jHGrwmjXV42yrqG+P505ZXhbZrLmum1nnWw2GzDGPOA22m/3QHF9LdHFst31jl46HURMGEnS6NluVr0aPmTCci/4eyr7VLg0xIx6LdKghD+bT1bHrt08vuqRZrwOiwwRjWMAwHlT54yMjCCVSgEA5ubmUK/XnQNF9juVteZu6VRZMmAr625RNmsouCYaeTI6prPHKK5eryMejztHz5YH0aHYCvmx55ZGmZUcPJ8MTKirLSTMAIuBAuXMllc268i40bPZNVf9YhEx6lGuLdMWKhMkgZEFT4RiI4esF6NtljWHL29sqep6SOh1Dw4OBppLsPCexC
geZipPPZjKZuVGWy+xk2EVks2fKtzbaDQCLGaFyxW+ZA21wku9iI7pffLZlXRm2Yn0+nwMchVSPrfOsVk01Y1X2AyJsB6wdXY0ovYRihSi1AhJ84e+mutOn8MiFc0anCicCiDAXtdo3RqtrRi6n/wKIPCZVpkpUmLL50hKU2JasxSHfkY7z7dZeoyyzfdVJ07r5omWpNNpZxQqlYqDhwlLAmt6zRI6LSFzo3XWtaQhsy91iOz54nMRfeGzp1KpQLkN9aBGgVshQ3pmlVPAFIGiaNZh5VqqMWbplhpg5bdQ/7ea5vDJrzrryovp6+tz+WJdZ86Za+oLcBSF9KHCvRgd1yHTw7BQEgUZWJ8r0MNM5RmJRNzmKduOL5IWuOi9XAxrIPj5iUQCi4uLgYOdyWRcuROflZ/PiJlsTQpEKpXydnPpxiirE6RsTWWCq1Kht08YTIlIdB6UuKCfpYa5W3jGrrUaBzUW/GyNhBhlMrWhf0+oVYkxGi1bFmevcrR2L3zsWTUivnXQNej1UAPYLEq2bGh1epW5r6xmLcHRkhY14L1IcXBsdFY0JwlgHaqieUk6Z7XaavtWjeqppNVBtJBnO/PfCBGyyBPlhOvOZ2E+m9ExAFdpQn7C/Py8i0otMrYVRtk+C4BAgEIdxAiZyCDnpHKTSCScw8T9IYqlaY5WZacZrL64uOj2XoMqzglYyy0rm1rTdcAav0LPvQ8Z7nZ0bJD5EMpGs7lSPjQfzpbRkE1NGrpuOL/WarXAQfEddm5IJ0KoB4f5Is0PMsnPDlKRSCRQfkWvyeZyIpGIY0T68k/twks2OrZRmSIU/Mr8SDQadYxNhRr5fSaTcfDM8vLyOuFVY9ntsAbJt2/qTNApo6xR3ohaUJGSRwCsh+h7FeFz2BSCIhYbISG+w2yNdy9hRz0bzeBSdXT5HJQbAAGng41P2KGI59YaGf38Xjpy9v10XxV9UEdanX82hfE5/kCwNEdROh/svtl8FapWgiSNjSp7rjnzmgACdcoc1EWscWftLPfPrk8v5EidaH0BCKRsrBwpwZfnU8svFTFVjkc3xpg6PJFIBEqTWFs8MDDgEAW1OcCac0d2PgDXZ4JnmQZZ0RNfoNXNGe64MQgNgtYJ23wp2Wkavaiwa/7VKk+FMjWnq4eoF9EON5NUfBVshaHJAu/v73dJf81r8XcphDTcSsCwxKtO1l1rMW1zfIWoATiyysrKiuuJq0pHmailUinQSIPRBICuYV79G/s9/82DwefUDkaE2G2+mIcaQAAKtmz2Xhtifu+LkvnV7rFGQdoEhUaCB5vpD7tWnQ41xs2gU8q6rnujsUakZE6tUCi4aGZoaMgZDx/SwVcnyqlZdO9zsPTflAFtfkPnkrW8wFq0o0YzFAoFau/5srqqU+jUyi6NPeWHzj3XTJv88H1tS0jWzvKze5GOafZcavT4GSTE0XFQA6xwNvPFaoC1Ha+vVK4ddMVGxwwkdJ5ag9zX1+eiepujJgJEPaIBEM+1Eu0s273b0VWErO0bNa9BBWOFxWdU+f/2ECpcZPOIFhbrJjrW99doUJUShU1hPt1kH6mqvz/YgFydlE6iITUA6qHZPDoHoSI6CvPz84F1o5LgdXsUVn1/YH3bz3bXuhXoVOfF36d80UhYZjahyZWV1Z60Gvk3WxOdT6fD5jMtXM1Da3OOdGK10xEVNB0OOqBWrnsRYTYzymrYqGzo1Gk703K57EpUGCEDcIbGlrBp9NGpUd7IGFuFrcqYuW0tyaIM0+mxqJu+h6Y8fDWxrUTIivbZGm7tzmbLcAAEzqGiENwLdigjAczqVj1HncjORnlSPh/nx6sZy+Wy0xnRaNTtAedGiHpoaAhDQ0MBcqBeMNRpqoNrTja02g8iENrDWmvX1Q5QDhTOZuUMm89oQGRRsW6j5K5aZ1roVA2EXSibT1EylG/iNvem3qz11HoRIQNw+RgVOoWOONdGoxHouU1lq/kgXRNrOHtRqqAbz/WxsC
EdBypZLcHRg8LclXWSlMzTCya7jRx85Dmuq+bn+ZXRQjQadXlvGmOyT60T1GvihT0DzVqtWgII5UCbxqhzyrNQr9cDP+vGMKtStudPYVQaVOY0AQQcC80jc31DodX6U0aifA91mmj8eu1Q2DWxz2VreptFXep8K4fEvlqJjn1z5pw0QqahYjoLQCDasmRZi8JoGo210hrZK9LSzbqrnFudZVN6dBCIrrCUlfNTUi+vcc1ms+ta93bqiOpeUn9pUMVzx4uOiF5qUyX+jRITgVW9w/sUNLDyVbn0Iu3UcR2yjab04XyTajWXp8JnjYAuerfRsQ49mPx3o9FwjN2FhQWXy1F4j88Vi8Xc7zCSts+i8Da9Xr5XK/P3RZe2xpgkBmCNaEYDxs/QXCFhHR4qeqrcJxLTLCzTrsD5FBSVptYh8pYmzc1zKC+Bcsa1Z4lCtVp1HrePTNfNsJGxNgXRGmSmKdQpAFYPM3+nUCisg+ip4Bgp+MqGdD3bXXub07Td29LpNMLhsJuvfg6dUzpEmo5hTlPLyph7I+rU6VA9oUbZd141wtW0WKtwsz6vD0FoNeK0662OJ9c5l8u5K2a1tEpJsXpGQ6FQgG/DJkmUFeop/r0tz2lHXtQIq17XwMt3fwH7LdAh45yZQiA0rcxqbdJjoeN2h0a5/HzqOuqYXC7n5JJltbZhCfeKiBVJpNrdzWf3VM9b2WxndFyHzImpktLN1N+zB8e36GooNOoAsI6w06pxb2foYQQQOESETn0KLZPJuBtc+GKUo9COZaLrpqmybTY3dUYINdMAM3qhMmXUq/CLCqiygcPhsEMC2C6Qc2IE1E0dtRpjMtg1rcFaQEbpdIAsQY0Gi3KhRC+SkNiIX/dGkYtOvVcbLSjBiZAdb3eam5sL3ANOaIx52JmZGYRCocDcCWFTSWm7TyoHzYm18gy+6NiSK3nt4ODgIAC4FIDusz3f/HztckTjYOHZXiBBvujYZ4w1/aHRvw9uts6VRZuaORGdrjnXO51OY3Bw0N1olkwmHWHUpvyoA5XHwbO5tLTkmg/RuDDiZpRMncKvray/Ncaqh+2d93onMuevDGWev8HBQeTzedcEhzl9my/upQ5XxFMZ3RrBM4DSINJGyEQVw+EwFhYWUCgUEA6HAxUtPsNM56ITlKLj1pn67428YN+EfMZYFZ2W4wBrbDf+rTXIG31WO8N6WTR8jUbDJfq1sUM+n3f3nLIVH6/b4zwVwqZg01iHQqGWi815wNXrZD2dKqV4PB648pKHNhRay3vTkBCGpydYKpUArMGV/f39joihgtvusHNndKvRHxvIlMtllwpQo62tU0ulkjsUCpuR5EIonn2kadS7GT6YemFhAcViETMzM5iYmMDU1NS6KxfpYYfDYczNzSEUWi2bI8zH+6W1NIe1s2rQtEPQZg4chy9aU+PA+tzl5WUkEglnKPi8lF0tYwHWuhWRuc9rDK3B7waZsMZNX2qYVZlbh9kHhWowYeFGG2DY+bdyTnUu+vuKloVCIaRSKefIszSRjqlGoIycue42WguHw874zc/Pu5r7doMVdVIU/bH9vUkoY96Y86fOiETWaozZvnZkZAT5fD7QUrWdvHw7Q3W4BgEkf1I/alBAp0edKDKxSYwtl8uuExmAQGBje2NotN+uM9pV60wOX/K/2UR8xtgqO4VCKHCagPdFyb0aakipeBS+XVpacl6uCunU1JQrQ6AXpeUB6mUqTKMEjI3mpJBcMpkMsAQ1h8Y5cv5UTozKGNHxWkB64SsrK6hUKk64qDQqlYr3tqp211PhSxpnhRXZVlLLr6gYuG5zc3OIRCKBPI72HyYHgB2xlOHeaS5ZZVtRHDoCxWIxcC/v7OysuzWIDoHm9Ol1JxKJgFFmvWwul3M59GZedqvyrsaBz1Kv113EMDAwgNHRUQBwaIsaZMLUSmqhUSa8Tdiaxo9r322awz6DPfPWGKuh1ty8Kn19LjXCVMjNiJftyA3nAiCw7z
rPWCzmZN1eSGKjUf6OcnNYNcFojJE2DTjJXoqocB4bDZs+saQ+Bh/s81wqlZxM0BAxd8ye7TTIetmLchZ6aYx1D4A1vc3np360uWBdH00PLS8vo1wuY3l5GbOzs25dgWADGcsXsmv/sBtkHZ0qvWYkGUYUCln78ju93lTNRXBjfZA6e5/GYrFAGQAjYWCNMcwDw/9T8s5Gw0Y58Xg8kAPmOjBCZuTMaIUGmZGZktd44LUJRK1WQ39/fyDKU8FtR8hUEfFzfaxQOjzaFlAj0VKphFAo5KB1GjtGyTQa7JBmS/B6kUO2cJ72PS4UCpibmwsYY+3UxTkSqbBsTVXIoVAowA2wcq5GtpX1B9bkuV6vuwiWfZFppBUiBVaVDtdVyX71et195dwpawrT94IISFlTZWnXQs+/kj7159QfSsZRrorus35+N1C7Qpd6Rum0KFFOexsom1qNspIGFcXS3+HZabd0iM9ukUp1GPSqS43eLaxO+WJ/avaoJolrK42xrr9C14qc+Yhq+ncq2+Fw2PFrtN88EMwlU1fRDvhSIa2MnhvkzYYugIWH7MORJKD1bRsRLno1VIlphMdFVsYj4ZeVldU+3IQrAThiACE/vduUkSu93FagMM6DtauqrAAEDDKAQCciAA4+ZRkADzg9OUbWAAJEExvpdLKeNpqxSjMcXq1b5GdSsXDdmMexV7upR8+/U7i70znrsCjORjk2XTc9lCpTNpVBFidlnCQd3smrDpuuVyvPpZ8LIMBEZg650Wg4kotGyJSTSCQSyPVraRT3SFuG9nLt7XNs9HPf7yiqwdy3ygd1j+ZcbSVIu4QjC1daR4rrb9nr2ge9Uqk4FIWchEql4hw6K3fKgKfeVCRhs31QGVdkypa0Wrnm+gBrrWF97W19+fytHqozVYe3YixZ7rS8vBxgg1M2uL8Wul5eXvamSlodPTfIrX64NcYWAiBMCsB55EraUEhqq4yyHioeWJ0zN7perweYhMViMfA+jJzZNctGr5tFEqqEVaECQZIb8yXMU2nZRqPRcHV4JG/pYVJjY0vYehFhWidH11chJW14wvpcRvf2pibN+3EoiaWXMmFha5VZLVGx8BeAQN5f7+alYxEKhZzxoyJmJyY6az5SY6uD66BIhd6mBcA5cgrlslyE6QzO10ZSlrDYTYRg19yuva6vL0LWPaLhCofDbk6WBc95qrG05XidKlcr87oH0ehq9zzbHUw70WmvdqYFSMDk3H2RrELWlsG82XpbHaxOppK2uEaaI6deUtm2+/dwDw0G1DGx6VUdesZ1f7hH7IvASJoB1+LiYqAZSifP3BOD3K7isxGHQnh6sMkUZuSgXW98DL2t8Lr0IHHuNMLcEGWtMg8IrAkpFZvWeSokqXmOjebBg6xMaK37ZB0vD47WJDMKYD9i9Zx9EaDmV3p1mHR/bC6GP9PImEqlXq87WNRniKyjZsuFupELuz729f+z9+ZhkpbV2fhdS6/V1dX77DPMDMuw7wISFRRBwH1fAYmJ8TP5jHz+VDAmRGI00QioCWJUMKjRoAKuMYABXFmDAioM2+y9VlfX1mtV/f6Y3E/f7+mnumvrQbz6ua66erqnu+p5n+Us97nPOXwufhY/n89Gj5cxNRJbuIc855a8lk6n3e+owleEqNrYpp4JKmSeUWW20zvmHkxNTTno3Ido2dQPG4OtZ9i19yllnhOulYYUqIz5VUvh8qwBWGCo8KVM91oQOd+Z53mht8WXpgNqaVs2wmDYTMMCyvjPZrOIxWKBUJuiiYvthxoylkFPhVwqlQKsbpXJTLcitMv1VpKaRTYrDbs0YvjkeLnB/7P57VorQWsHqPGi+qmWe1C1QraHUR90KfiID6sWv5JkNHbIFx9aN99XZ3Y5B40C/be9VJr8ns/nEQqF3CEuFArIZDJOcAEIXHybc+v7fP08jYloZSGuIddYD5/GhnkhrJdXjoncSENH34sKzHoohNa51rRINX7DZ9R0F9/ZqM
dY83loPqWgwpzrzLlHIhGX7jIwMIDVq1ejq6vLCVeSZXK5nFOC6o1SqFGp8LJXYsT51l7PTXt7uyMZWaILhXypVHK9gQEEzrDmy1rGciMMOOsZW8Xs847Va2R8nOmTSmokCUqRB4XyFc3wtQKsdt05bJhGwzcMfc3NzTmFTF5FOp1GOBx2DGtNNWIcn3FnXwqYGi3l1trGj7UkMvecXA9F+eixkzSlxibnZT12GpTLDV3bO1zNsGEGGkw09OgsKLxPVIN3gvem0lGXh6wHywpa/r8On2DTeKHClYxFcXNVkJQr+r5coxz0pkJZlXI8HkexWAx4fBrfjEajjuVaaVqOHg6uOb0bXgxeZuvFEd6yhAQr4PiePqhrudAHVco6GFsvBxVa71QNk+Uy2HywKYksKpy41jSAyGru6elx/cBZWpVCmEKNipAGiBoe9EzqYbyrAaGGripXnlVCdhozo/fMUIySwPQz6j0vFlK0sb9yf6PxVfJPaHxybdPptKvapOgMvT/tTKd1sKsNFSw29D7ze94FnhsaSoSg5+bmXCijqanJGfxKatSKaVQglcbzfZC18gKA+ZoQFp3jOjMfnWEmzSzhM1He0dlqtGwpF+rweat6Vq1hpzLGhjOA+WYTygWgY2SRokpHTQrZWneWCOGL6ehLrWpfY3dVUoSGFT7wpTQsh9Iot7H8mSpSWoy80JpOxLisNhpnZRubC7fUuuuLB4fCn5dZoazZ2dkAFKXGAeeoEHc5o6dRa2sFre+reu1UED4vTJWLpk+pEF0Oo0I9TYU56UlQ0RUKBbeG9EiZx8711Tg590krHnFPtAhMPV6oKmVlTofD4QBZS5WvhpP4byqS5b6TlT6nnnl6jKFQyKVpaaYDvWPKGTWstNa0NoFQD7MRIRCLJugzasxToXiebRpJGvZRZIUv7XpU6ZlRGa1ymmdaQ2ZcCxoLANy949lhfJvraklhjR7W4bM8Dx/nhGvN5wGCDGo+v8/5VFRBP6PWZ6xKIeslo4JUYbhYGo9CIqqkCLNoXpsqKn6uWmR68ZfLO/bBZPYCcdN4yNRTprcUCs0TdrixLMKgaQqVKmQOa11yPvy3QlmMLzH/mAJJIW41eBSqU++gXgHrM8xUiCqZSzvbKNOUsTOuh3oCWuHKFquvZ94WYtRzr6UBlaVORmw549WuJ88Sz5EyuJVkp8KmnmHPtkKVjB0y55QQKb14fjaVmMY9ffBuI9Z+sUGlSo8sn88HPDm9p/x/hWFVlqk8s3BvPaEPlRc25UpDRlZp6l77nB/+Le+Meq2sSbBUSMw3rFHgeylxiTJfiX/qaasD0Miwhg57nlWpakjUer6KBvK56fna+6drWU75W8etmlFTDJmHwkdvV8o3h+/Sl0olF8+ZmJjA2NgYxsfHnbKg0NVDaF/LrYzVsrIWlhJbNN/VEoyA+VgKD6QqZFvzuJKhz2xRBx48Qlhq7KRSKYyMjGBsbAypVMoV/SgWi06xadGIRCIRyL+rx0OwhoyyYG3qBmNXOvdkMomJiYlAhTFNISGhjq3cfDmP1Q6LSOjeaniiq6sr0C+bcC5jsGp00GOgQepjtCscrC81umodvvNCaFIVMY23VCqF0dFRpFIpt/Y85/T6mW+qHXxszmYta09hpoLSd/Zt3JiV8rSuu6JD2suce2rDHKqM65U11njXPfcR4oAgM5sKlwaZEo1ozDNXn3/f1NTkymlWGuJQ2a6pmZpKaNeARjGHEshsTQnlKSyXMlbPXqFkLfNJWavGiy3WBOxfd61RTyeGMWRfep+Fv6sdNXnI3LCb99yM67dfj+RMEmvDa3FG+xmBfC3rsSmcNDs768oO8jUyMoJUKhVYNP28cgq5Wss1M53Bh//7w7jpdzdhODeM41cfj6tfcjVOXney+x3fJdJLo1815g0gMF8A7plZcYfMSS0GUImgvfyOy/G3d/5t4GeH9hyKBy5+IBCvoWBlfeXx8XHXw5nrTOOHlWWo1Lq6ut
DV1YX+/v4AI7he6PryOy7HR+76SOBnB3cdjDtfd6cjNmm+pa0IxLOSTqcdDKlsYbLb+WLN3FrKCNoRCoWwL7cPl95+KX705I8wOTuJTfFNuPz4yzHQNeCMGk0xoxCiIFOvn3eERpMtYkK42wquWqzucmfmvovuC3iLnJvvzLASGcua8v6xuEh3dzd6e3tdFx8Se+pFVvh3n3vgc/jULz+FodwQDk0ciotXXxxAFigUtWoYWb8A3Ppb75TKTfkHtuVirXLGypjjVh2Hj5/xcRyROMLJAuXOqLICEPhsGm8+ZcLnoGwhgkSeipaPXcwrVaMzFA7hX5/4V3xvx/eQnEmiK9KFU1pOwfMjzw+wtdXTVGY394XzoVK3CpmvRoU1isUiUvkULr/rcnx3+3cxMjmCI7qPwCWHX4J1oXWucA/lBzkD3G/b0KVY3J/dMTEx4ZqoEKWLRCIBtEiNxnrkTc0e8g93/hCf+s2n8P8O/X9Yh3W4cdeN+EbhG3hV66sWpKeoJ0nm4+TkJMbHxzE8POxqQI+NjbkFY2xUc2nLwdXVPvQ7vvsOPDz8MG541Q1YG1+Lr/z6KzjrhrPwm//zG6yNr12gjK0nq2xUVcoKXSv0QciaRJJYLBbwkPm8lYwj+4/ErW+7df5CFLFA4BPqHRsbw+joqDN0iEZMTEwgk8m4WuFMqUgkEk6w9vX1obu723k89ShkHtoj+4/E99/wfbdehdlCoPwk55XNZt2/WeJTLwQtde1YxBrQ6qkpfForGScUCmF8chxn/NsZeMGmF+CW192CruYu/Hbot+iP9qNzrtOlztAjnJmZWdA/VdOZlE2eyWQCcDCAwFnQ9a71kh/ZfyRuu+C2eUitMN8cguvKamMsAUoUhevP5yGaolkFPT096O3tdcYc6ynXykrW8R+/+Q+8//b34+oXX42ju4/G1XdfjQ8/+mG8r/V97nf0juVyORfTpFNgUSwATqZo608VzvWGxihj/u2V/4ZV7atww69uwMtufBluf/Xt6Ch2BPKh1YNUhWxRNpa6JUTMl9ZsoOff2tpaddlbyq3P/+bz+ObT38SHjvgQ+tGPX4/+Gv+8558RbY3iOBwXCDepUlaFrKiERYKWy0MulUp413++C4+MPIKrXnAVuiJduPHRG/FnP/8z/NOWf8JschapVMohtEQEGWqxHcK4nnQK1JBikRGNLRNdqOfM1+Qhh8Nh/Osj/4rXb309Xr3l1RgfH8efzP0JHnziQTzV9RR69/U6OAVAIF7MQ5XP550lTjjSQqiaoK9EhnqU8eTsJL71m2/hljfegudvej4A4PIzLsd3H/su/uXef8EVZ14RiHXb5HgN3tsYkCUg6c946eg9Kfta4+VLblg4ilWxVQvixBSuCjWOjY1heHgYw8PDrgsRD9bk5KSzpKPRaKDhQHd3N7q7uxd4O/XGjyOhCPrb+vdf0tAM8oU80tm0U8D0gPWlBfjpSVplrFA16+VaqL2eS/KPP/9HbOjcgC+9/EtOwG+Mb3TGhBJVCHOxAD+tcRJvGOejha0eMvPuNc9cwzW1GkQ8MzbeSCORhtr4+DhGRkbci2gVhRC9d21QQWSis7NzgSHUiNj9Vb+8Cu84/h246NiLkM/nccUpV+DHu3+M+wr3YXV09YLMAZalJRxNONsSuPgcGgPXDkT1cFRUxjxv4/NQKBRw2XMvw/ce+x6uf/h6XHzQxY4vw7toUTYlZ4VCIefEkIymoTP1jDlv3hnrHVfiId8/fD/OWn8Wzlx/JjKZDHrCPbhr/C7smt2F40PHL3BGOLjeViZqOAaoLQWpkpGfyeOWx27Bv537bzh54GTkcjlcvOVi3LbrNnx36Ls4NXcqxsbGkM/nAwpZ2fSqbwA4o1kh67m5uQCHR4mNNjul2lGThzxXmsOvR36NPzvyz5x1GWuPYVvzNoy0jGBN8xo0Nze7CTOGynzGSCTiFDJhBLX+aBmyiwkvOQVsPYSRueIcCqUCWqOtgZ+3Rd
vws10/CxwkKjvbLMCmFlnCGgUt4Q2FLrmRFv6u5ICWSiVsT27H+ivXoyXagueseQ7+6tS/Qleoy3lf9DTHx8cxOjqKZDLpfqZeucJ1tqhJV1dXWeFay5rz2R4ffxxb/2UrWiItOL7/eLznyPegbabNKWQ9D1YZq0BlbJhesSpjpqvY9m71jO88+h2cs/UcvPFbb8SdO+7E2vhavOPYd+BNh74pkIpEA4nx+mw26wQmYWiyf8Ph/Q1I1ENmShww7yFZxnstinl7cjvWfWodWqOteM6a5+BDp34ICSRcCpDC1DwvnD/nRRIPO/loeICdqnzrXo8hNFOYwf377scHTv+Ae59oJIoTu07EztRObGrZ5OQPMF8Vj/AtESp6a1xXvmjQMe9Y519PSIwypiWyPz2Gc2iJtOCewXvwhv43OHKlGvsMi3GeFj5V5MXXGYoKmelSlmG91OC5fM6a5+DLD38Ze7fuRW9zL3bP7cb2qe14VexVCOXm6xeosgfmDUjO31ccYzkUMYdb92iL+1mxWERzqBmPTT6GE2ZOcEYonRHl3GjoQuv/29r5fFaNtSuSW4+OqkkhJ6eSKJQKWB1f7Q711NQUepp7sDu623kohG/p9o+MjGBqagrh8Hx/Sc0JBOAEUCQSQSwWQ09PD7q6utDR0dEQyzveEsdp60/DFXddgcP7D8eq2Cp87aGv4Re7f4GDuw9eEPPLZDJOWGWz2bJMQR5MXn6SLPg3JGRo0J8Co9KDesq6U/Cll30JB3cfjD3pPfjoTz+Kc288F7ecfQsmJyYxOjrqCDgUrqOjowH0gZeWBycWiwXaSXZ3dztCF+OwlmlazeBznbz2ZFx77rU4qOMg7BrfhU/c9wm8+dY34/PHfT7QMUkbNNB7pAFET17jxn19fQ4upWJoNDv8yfEncc191+CS0y7BB0//IO7Zcw8uufUSRBHFq7a8KkDu4AWmMguHw658I4DA87BQBT0lYJ7MQzhVLfhaFN1z1j4HX3r5l3Boz6HYk96DK+66AufdeB5uecktyGfyjrRFRazGG5UxFTHra3d2dqK/vx99fX2BLj52jvV6yKP5URRKBazqWBUg1fW29mI7tru5MKZH71JJkhp3pxDl+WHBFhqhNOjq5UxQxvzdT/4Oh/Uehp7mHvz7I/+O+4fvx/r29c4ISqfTgZRPyh5gPo2S60iERRs9KNythX80Rl6NAuQav/fk9yI1mcIrbn0FwqEwiqUi3rz6zTgtdBqGc8MuRKBKn4gEmdfRaDRAVLXKqd64sW90tnbilLWn4Mr7r8TVz78a7aF2/Nfgf+HR3KPoj/a7O0g5rcRc1gTQ4kJqRPPe8m4qoVnDHco/qOX81EzqAoDmpmZXO3lmZsYp0ng8jkwm46y9qakpjI2NYWZmBq2trc7ysHnHtFo1LsgKR11dXU451Ctkb3jVDbj4Oxdj3afWIRKK4IQ1J+ANR74BD+x7wHnGhICTySSGhoYwPDyMdDod6P2pn2/Zq7w8ExMTjm2rMXFrQS11cUqlEs7Zeo47QAfHD8bh5x6OE244ATc9dhNOjp6MvXv3Ynh4GMlk0l14xgAZClBPR9ukUcD29fU5co5VyLWud6lUwtmbz3aIw4bmDfjsqZ/F2T88G/+5+z9xxNQRDl5n+0L1GoF5j0EZ4Ix5s8UbjQhbDL7e81IsFXHS2pPw9y/6e5RKJRy/5ng8MvIIrnvoOrzx8Dc6w5PQIuHnUml/ihvRH0WK1HCjocS9IZxK1IKGhg+GX2qce8i5TvBs696Go19xNI7+4tG4ZfstODl6MkZHRzE0NBSIGXP9WX2pra0tEMogiYtnJpFIBFCsRueuAwtTHyPhiOvnXCwWA2dGc7X5dxozZlMNnncaod3d3ejo6FhgzNUybnjVDbj4loux8eqNiIQiOKb/GJy/8Xz8auRXLpxE5rpNf+S8VTYok1zlpobHaHDU4oXq533nye/g5iduxlXPuwrrW9bjwcEHcdVvr0JLog
WbsTmQ0qTnF4AzCJh2BQRRieXKjOE+f+mlX8Kffv9PcdK/n4RIKILDuw7HCwdeiN+M/8YpTKIJGk6kXtNzwnnyd+gUhEIhhyp2dHQ4Q7UcQlTNqEkhD3QMIBKKYHx2HG1tbc4zyIfy6G7qdpampnnMzc1hYmIiMEkeNHX9VeDSc6MlTvJRPcoBALb2bMWdF92J3EwOE1MTWBVbhTd+6404KHFQQJky3WZwcBB79+5FMpkMWEoq6KmMeVg1KV5TLHTTq3kOZRLyILWUWrAptglPjD+BzaHNGBwcxJ49ezA6OhqAe5UIQoOns7PTxV77+/udcKW3SQVgGzrU4iHbeU9PT6O52Ix1reuwI7MDa3NrHQGNMR4l4WiN3EQigYGBASdIe3p6HKlIFYM1fOoZa+JrcET/EQDmLfvD+w/HTb+7aUG8V63vcDiMtrY2R6SbmJhwhp5W5iIUz5rAPoVsFV4lz+RLNWsPt+Og+EF4IvUEDms+DKOjoxgcHMTIyEjAs+ecGBro6+vD2rVrXbN5KjCGN2zBikYgE33tfYiEIhjODQcU8sTcBPpa+9Dd0u3igUqoo7KgbGEdbp4jciVozNFLZrlMVpuq5xm29mzFHRfdgfRkGqOZUcRKMVz4vQsxEB1waKGmHnK+VK7cPxuPpUxVbxpAwNi3KFw1IxwO469/8td4z0nvwWsPfS3y+TwOajsIOyd24gf7foC/CP9FIG5NhIdcCcLU2uSGcq8e7s9iQ42JQ/oOwa1vuRXjuXGMpkfRXmzHn932Z1jdshptkf3yjMaMxu2VTKcEYkUx9Twx1GGVsqJztcDWNUHWrU2tOGHNCbhr91146cEvRbFYREtrC36V+RXOiJ2B9un9EButbLIflSZuy7txIdrb291loWJWyLpRMCQAxJpjaG9qx1huDLc+eSs+8ryPBJjVqpSHh4cxOjrq4oIqINSK1biwQpN8ZmVwVgrrKeymMEsqn8Ku3C6cFjvNMWSHhoacQqZgVTIOY7CaQ0vBSuGqvUvrgV90/ro+MzMzSGaT2Du5F8fHjg/0EyaRiEYEY6r01Nvb29Hd3Y3+/v4As5cC1XINGnFWTt9wOh4de9R9HwqFsD25HRu7NgbSPegh8wzRCCOPguELssaVzMfLS++IHApefMLWtZx/RW/SU2nsyO7A87uej2w2G4gbk4RGg5P7zzvZ39+PtWvXOgTFB1XbAgv1jOZIM05ceyJuf+p2vOyQl7mUnLtH7sbLV78cHeH9HAJ6igpdK0+C6U+KDJUjMPry7usZ7U3tGGgfwM6Rnfj50M/xloG3IJPMBHgGmqGh95teMb+3pFHea8oVGoTAwg5TlQzu2+TsJKKRYI5uU6QJCM3nh6ucpOEPwN0/Quh8T33VYzAsNXdg/13qau9CW6QNu0d34+6xu3Hh2gvROj3PpPbViNC7Z2W0/szXBERJYfXUPa+5ucR7T30v3n7L23HC6hNwdM/R+PT9n8ZUYQpn9Z6FseyYsxIojGxSNgs68NJQKauy0Etfb4F3HT96/EcooYTDeg/DY2OP4f23vh+H9hyKtxzxFhRmCwtgZ1UYqpCtl6sXySaL62GpNb72gds/gHM2n4PVravxdPJp/OPd/4gwwjil/RQMDg46MhRjxrzMGtthbEqL6CvkohW6bOnMetb80v++FC9c90L0RHrw5MiT+MzDn0EYYRwTPgYjUyOBzjKa2qRKWfOlaTgotG7Z+NVap+XGe099L577pefi73/y93j9ka/HPXvuwb8+8K+49qXXBvafgopKlJ7M9PQ0JiYm3F2gsaehGsavgPlGBxrP8nmflYz/79b/D+dtPQ/rOtZhR3IHPvqzjyKCCP4o8UcYHh8OVHFTT4f51CTntLe3O+OY+ekWomvkmnNccuoluPDmC3Hi6hNx/MDxuOqXV2GyMIlXbHwFpsam3Jm17SF5h9kikJCkjc1XQuiqZfzo8R+hUCzg4O6D8duh3+Ky/74Mm2KbcHrsdDy17y
mXSsZ7SmdFvWQaRzwnVMQ2/5WGtk1B4r8rGfo35x18Hv7pnn/CmvY1OKj9INw3eB9u3HMjntfxPCC3sDMUFTOHcmYALFDGjfaQ9RluffJWFEoFbE1sxaMjj+KyOy7D5vhmnLv6XAzuHVxQ357rTEia8+V51tQz6jMtmkLCZbk7Wu2z1lzL+o1HvREjuRFc8dMrMJQbwhE9R+DTz/k04uk4UpFUwMLkQxMq4M9UUfB9NcapjLdGxI45JqYncOntl2J3ejd62nrwqsNehb/+o79GNBTFXGkucDF44Ag/M49aN00VsmVNUxDYS1LthQGA3enduOi7FyE5mURPaw+O7z0enzvpcyiMBouucJ7qfZEYop69sgO1VKD14hux5nsye/DOW9+J8alxdDV34cj4kfi7zX+Hwsh8LIfrrN6lxsQ4b9tIwlaIa/TFP3ndybjpDTfh0tsvxUfu/Ag2d2/GVedchbce81Y3N3rkWoDeNmcPhUIL8jMBOMVn456+dIpq92N3ejfedsvbMDY5ht7WXpw0cBJueMENKCVL2Du3N7DmtnUnB5+Likw9Aht+abSwfcNRb8BwbhiX33k5BnODOKrvKFz/ouvRj34MpgcDZ9beRYXsfZ6OCtJG125XGdPd0o1zNp6Dt619GwZ3DAa8XS06pB6yVhbj/9PAVli7qakp8Lc6anmGUCiEq86+Cn99x1/j/Xe+H6P5UfS39uPla1+OFze/GLtzuwEsROs4B6acaTrUciphOyamJ3DZjy/bv+6t3TjvoPPwzkPeicxIxpuSZDNkgPn0LRqlNHb0HGlcvFw6bi3PW5FCptBJp9OBn12w7QK8+eA3Y3p6GplMBsPDw9g5uXNBaTit0kWFDMDlKuvB0zgJPVQ+JBV7uUvD+SmpwTf3l2x4CV5y0UsAzCvR2elZpGfSgZKTLERB4gLjsfRqSJ2nVa7KWJUIAOcdaTk3fhYP8WJzL5VK+JcX/kugRCDh9F25XQ6BsMxHDjUWdD+U9k8vNRwOuxScSoTTUuteKBTw6ed/2s1Zy6Xuze9dUOtWcxY151vXT4vW85KRZVqNF1npmXn+6ufjZ2/52YK/td6Clv6kceSrh2tzz31569wT5i7rHWhqakI2m1107qVSCZ8763MBxCefzyOZTGJ3brebI9eTv0d26WJz4vlnpbdqkatK1x0ALth2Ad566FsxOzvfrWl0dDRA4tK4qi3eo/mw9vzk83nnYc7Oznq9/aXmXk7GnH3B2Q4RmZiYwL59+wIFQWzBDJ/naeWoKg49N3avuFYMj5Bslc/ny86dMqxUKOHDJ38YHzj2Ay5bZGRkBHv27AkUMdG5KcSuxrUyw3O5nENC9a7WemZ8637uxnNxzgXnBPL+k8nkglLFWiVN1wyAM47VQKbitQ6aniEAC84Q5ZDvvHtHqYKxa9euEoBnzWvXrl0rc1+Z+8rcnwWvP4S5P9vmvTL3Z37u5UaotKTK3u9d7d27F/F4/IDADrWOUml/M/W1a9c6z3Rl7ss/Vub+zIyVuT8zw8792TJvYGXuz9TwnXffqEghr4yVsTJWxspYGStjeUf9vP6VsTJWxspYGStjZdQ9KiJ1PVuggT8kGAxYmfuBGCtzf2bGH9Lcny3zBlbm/kyNSiHrFVLX79lrZe4rc1+Z+7Pj9YdALlqZ+zMz93KjIg85Ho8DAHbt2oXOzk7389L/FkLQtKddu3bh6aefxp49ezAxMRGoc2pzv0iVZ7GKWCzmKgL19PS48pnagYj5j77UlnQ6jQ0bNrj5LjZ3OzgnppVMTk4ilUphaGgIe/bswZ49e1y5u1Qq5ZpWF4tFtLS0YPXq1diyZQsOO+wwbNmyBatWrUJnZ+eCYgPlrKN65l5uaE4j+0/v2bMHTz31FJ5++mkMDQ1hamoKkUjE1YXWykVsOMGypVrjt1Fz13Vn2hAbHuzbtw+Dg4MYGxtzjQ+4/iyowER9bcXY19eHTZs2YevWrdi8eXNgL2zZyUrnXp
KUFFbb4hx3796Nffv2Bfp6s1Yx56m56b42dCyz2d3djTVr1mDjxo3YtGkTVq9e7XpUs6IaK5Ll83kcdNBBFa075z47O+tSQVgSlrWsde7JZDJQN0Dzv1kDvbu729U+57lhXfTe3l7E4/GyvbSXWvd4PB44F+wnnUwmsXv3bmzfvh2PPfYYdu3a5YqC9Pb2Yt26dVi7dm2g8YWWxLSV5zRvVPNKfWUe+dXOvZZ7qudJU7HYupNyZ/fu3RgbG0Mul3P1D7RDW3d3N1avXo21a9e6Kmq2H7WOauZuZSLTNNmdbXBwEDt27MDjjz+OHTt2YHR01JUy1Z7ZsVjMlSldt24dVq1ahZ6eHtc1TCu+LVbuttp1Vz2jZ4hzf/LJJ/Hoo4/iqaeewuDgoJPnLIbDbnKrVq3Cpk2bsHnzZmzYsMGdba5zJbWrfefdNypSyPwQVkbSB6ZCBoBcLreghJjWCVVoylc9ipvBg2pLC7K+L9+7XJEE/b7c3HVYpcA8NM6Hv8PfK5XmO8dofWg2xOjo6HCfx0OmJRYrWetK515u2MsUCu3vukWlysG90NxfPh+LQVDRqULzHbx61p355wBc1SQWoGBddM3jDYfD7vdDoZC7FFrwQdtKUiGXK4m41Nx1rizLaI0ULbrC8nqlUinQtk3XXXOudU21gAX3gOvB5+FeVLrunP/MzIwrL8luZqynrLmlAFxTDL2/LJJg88KZQ25rRrP+fDlhVW7uVMj8DJa/nJ6edjLGVwNAc4wnJycDe8OmB1pkRQuw8N8sErJUsQerqKu5p5RxPE98Dpb/1FKqWpKXckhlECvvqfxZKse3krmrHJ6amkKpVHLnR9fDlsakQrY1oXl2mLessp9nW52tcqPSuavBQx3Fs6TFdjgPfV/KfTXatLgMZYwtMbxULv5S0HpNlboWG6q4uBi2cpVOTg97oVBwSfT0lih0uWnWS9PCG7XOV+dK7zifz7sCISxmwR6m1nOw1a2WsrCXe9g9sAUR+GKRAs5RK6Kxy49WzuL7NuI5fMqY86Nw52dRSLa1tbmL3NTUtKBEH8+QVjxSL7Te9dQ563mh8OdaUSC1tra6WJe99CrotBECgIAistXL+F71DFVEWoGL6wnM3zctxM850MtWwUkFwcYltk1pPWvuW3tbDGNubg6ZTMY1kiBqx9KarFGtCpn3lMYTDU9ffW4V0PUOfSZbE5oyh1232PCDHnI0GnUGUlNTkyvs4ivZ26h5KnpF75itc1nHWp0nynU1CAqFgut+x8JD9q6oo9UIOaMeMuU6PXztK20Lr6iDQoNJC0bxjPmMkXrn3BCFXO7i8JCwcDeHwkMU/hSmbAbNnqG5XM5VddGa1mqN1aOMy20aL4X2Fk4mk4Gm7cB8I+6lSjjqZy6nYl5KGbOyjNbSpeXKedFrJjSpyqDR81QLlpVvuL4KffF7emBaAUsrwKnn4fMq6pmvrfykypgvKleGYNrb29178ExQEGt1NPWS9P21spCWKKz1edSrscqYc2xtbUVHR0fAMFXhq63/dO+KxSJaW1vR1dXlBF0tZ4bPtZgi1spb2kSFRn0qlQqgdXo3FX3Quu5sK9nd3e0MOlZg4po1Stnp/aSwZwU7hmc0RKNdrdhhLxKJuF70jb6jVjZOT087ZczQRjKZdO1ltb62GjyU1XNzc8jlciiVSk6xswwxkU9t7NHIuVOus2Y7jQnOQdE3IIheUb4QTUqlUt5yvktVd6t0LJuHrAKXB083jF/VylZPCdjfymtqagrhcBjxeNwVtScMTCVSr1LWQ5fP5wMbx96lvBz5fD4Qy9GayhpLKLcm9c63mmeyyo4CVhVyJpMBgEA5PmC/9U2hTIHXKOvbh6Ko966dYqiQKYAIE6mXzxKnNPw03KGesn5+peu/mLGp59WWVVW4U2vfcv8poFha1MaY7fvTW67XC1LvRRUy593S0oKOjg5X8jCbzSKVSqFYLLrmMLyfPFc0GiKRCDo7OwNGqz371Q
6fIaTlRy3qQ+VGBUwBb71desv0jBOJhPNCy8WRuXf1DCsbtfc6lTG7b6lCnpycBAAn94D98jMejwfQi0ajQjyLXFfKQ8rEdDrtFDLnZFFDhgApS+j5T09PIxqNOj6Cnu96561nhXKCTYL4ooHJxhJqCOp7kYOTyWQc6sVn5RlqaWlxeqEe+d5whWwnwk3lpeLP1NpUIUUPjvV6S6USOjo6kMlkFliCVsjWMsopZG4eoSPtLxwKhQLdqRhD9rVipGLR9fHFoRoxfIrOKmP9yu5K/Du9TO3t7e4C+ZRao+dra+IC8+iDGkCtra0B75RnhEJaz5vG3vQZ6kVUdN78ymG9A0KkKuBtLFetdP0sroWF3+tRxkpg0s5HikbQQCIhi3Aw564tAdXI5nlqhPHg8471jKgc4N2lMTc1NRXwWujF8N/aUIWxbsZuyf+wsUE2GuDcahnl0CtF5egEUGlkMhl3XwEEejtrWINnpJ75lZunQuqKGrJDGL1jGi/KeWCIkedZY7nhcBhdXV2Oq6BGaS33VO++RW/U0aIs17kD8/fDxr5VKWezWfd/9OyJXlEW6ftVOxqikG2AXx/Ixpm40D6FqgxWCqympib3faPjJHroFNbQS0GoTmM1tLpJRCAJiQeP70UBpoLZF29oZFzKKmO97L6GGbwgjLupYNMmFfUqg6XmzWGJUaqMFbKk4A2FQi6mqd2TfDCnCq1a4z1WqamXRSFAoeTrIMS5TU1NoVAoBOJmlixTjg1c67wtzKasaWDe4ifxJRwOY3Z2NtBwnevLWJs9D0pMq3f4lJc2T1FP2Rph2iqP66jGEuOWlEdsXM+zr554IwxSixha/gGVhToA2qCEZ1s5CWoM1mss2Hn60CufUiMqxW5mSr4lSSsUmm87Wo4MWK9ct3O28k9j8Wo4Ur6QNKgsdjpazc3N7i5MTk4GCF3cG56VemHruhWyz+q2fVxVUNFjVOuWCtpCgVwEC+nxc+sZqmBouVlIZmxszEEb9BQtJb6rqwuJRALt7e1uvtls1ikKbipj3wqbLcUmrPZZLEmKMIvGpJiupZ4MAOcJ+eKvyzUUPqVHqZ9JBmZra2tAMFIhs3PN1NSUI6BRQNnzxFc9JECfImajeyouKioNYfCz1OjkvG2LS/XotCcyjbpymQXVrjnvKeFq/oyxbAoqvd9qdNPYIBGK6UWWeFnLXCsJaTDm5/MK+bm8az5P2ULY6jRYdEUVRa3evvXclDyqiBzvqEKqFPg0iOydaVT7y3JGgzXsNf7KebW3t7v7qmxvxuApX2n467r7UNVa1pd3S1FBjc0rUY4cB95Vvkc0GnUkQCUC8jm4b83NzYFuaTMzM05xc9Ri9DfMQ+YhocUdi8Uc0UMFPK1qbryFwNQKZq9PC19YL7PWQ8g50PKZmJjA6OgoBgcHMTw87OI5JJbpgUskEi4Hk8zMaDTqoKV8Pu9Sdjo6OhwhTYWrPk+tYzFlnE6nMTY2htHRUSSTSYyMjGB0dNRddhKQfLGxRjAGFxsWUWlubgYw76URnlZo0irkXC4HAM7wIASlIRBtA6fpF1R+lcyTX1UQUhn39PSgUCigra3N9aDmOVHGKI0+CgoKMn3R4NN0Mysc6lF0Nn5MQ4cegqI6THNRaB6Yz0cm6S+RSDgyFO8DBXG9SllJOUqe03i9ZfbSgOE8qbCIUCkErfnrinKpjLLhgmqHVcZWwREC1vxvkqU0Fk9loalw2i9+MSJptXO1HrxmnFCxscYEAOd4UfYT9mfuNwBXB6GpqckZ0r6UoVoG563QOj14evSU5alUyilkevbkJYVCIae7lEAMILAubIdKtHFyctLlJPt0UjXP1TCFrGxjKiFeGg2u+7xjUvmpINRDssSYcilF1QyrxOhtpVIpDA8PY+/eva4ICDdWranOzs5AAYRYLOaYd9pTuKWlxSW/z8zMoKOjA3NzcwGorB6ySDllzEOpBsbo6ChGR0cxMjLimJHcH5+iqtfarmSoQm
YMirExxoqtl8JnnZ6edh5pNpt1CovrqpeUFqzNQa1GyFrvuFgsuiT/5uZmdHV1BQwcFTCE6+gJcc2twrIoE5WFNebqEV5ccypk3j+tAUBByzuqRrESoljEp6+vD319fe4+JBIJR7ysxbBTg10NTFXImmIGIPBMhBtt9gNDTZyXGn40qpkHy3nUqowVgbOGBRUcCXNEr6iMx8fHnYFJT457p3F/GhFUiPUqNSCoeHzxV3rvLKLB9aMCprxjIafW1lYUi0XkcjmntNLpNIrFYiAdzc5dw5uVzF3JZ2RDU/kqe10NdwCBcAyJrHwGEriUZc0sg0gkEuh7TnmkYUkNNVU6GgZZU7Cqh6xNoGlV8BLRuqVgVA+IB8/Gon0x2HqtQV4ULjg9yZGREWQymQClX4tNkBnY3d3tYD9txM3fp9eknwnMx/Fqga0VPvMpY42FJ5NJjI6OYnh42KUr8GBp3ANARWtcj4LWAq9LXgABAABJREFUeeuFo7HFdaG3qKQpFRhKoFBGbTQadUVQAASMQS0qUquQ1bPOPF16w/F4PACh6hwI1zElZzHPXBED+/Ip8GrnDyDw+RpbpTLm/1vYVlEFpsTRMKVS9nnItQxVyr4ceuWTAAgQbZjGRFlEBW1JdopGqFdUz7x17ooOUsGpt6kKTisAWlhV115DghoOsx5yvWuu8kRJoRqLJSLBEF5XV5eTiawqRziYZCga3nNzc4F9USOaKFOlMkdlOdeY8Xh9adybn6GIJcMvDEPSu8/n8xgfH3cGiqIHysWZmZkJ7IOGQSodDYWsNS4Vi8UCqSm66HrRADgBrF4SLXHGpGzMpxGwjE2doFLWDaQytZ4LL3AsFkNzc7ODUAmf0iMjm1CFqcKGPKSVbpq1vH3K2JdDzRQKwiyaSw0gMLflUsaqXC0E7SPPqALRNfJZoT5hZEMjjYAfOdRbBODCK9abJ+TOeanXqeSkSmOU9SIWVinr/s7Ozpb1tNSDZ2hKBRiFGJWxvk8j4sgazrKGFeencC5JN/Qk6U1ayNpWdfNBwNWsu52zjRkrKYp3lB6nkriUvKZomsohzauuF65ebM1tXJYhA8ouPQ+Eq3km2traHLFViV2MuZI1riEmvcd6TitZd4tE6EsL7AAIOJDW0aJCJgowMzODTCbjMiOs/lAUjntSjZfPUZdC1svmU8jcSB5wCkeSWnihlPiiMBdrW2saQj1egj1wlrSgebrKwlNyDS85L65CfYwv0APl5mmCPDDvBalnV+389dJYNiFjx6qINRneWt7cS6uM9d/1DJ/HoC+b8uRTSvoemqyvpBd6pxzWkKhXMajH6HsONSxUoFmmJ6Frnjl9dvWmLJFJK5PValToelBh8l76ECk1fvTuEQqmwtNsg3K15qsd+pxLxXR9ckjzrDlPenT0ihS2trHDah0AGwbTPG3KFOux8W5q+VIfq1vvo+6TL36se1zNWuuaW4NCi9Pwq4baiBSprFQGf6lUckqSMksNWpuZUizOF2VZiuthz4iVKTRq1LCJRqMLOD5qTHR2drqQGLDfS25vb3fxb8aLLYrDNVDOVDWjoTFkrXpDuDqfz3utICvk1HLlhnZ0dKCvrw+9vb0uHlEr07RcPEdjU7SitBIRlSmfi5eb8wWwgFDA/LxQKOQEncalNN7FhPKlhKvPevUlvRP6Gh8fx/DwsCN0KbtaLUXrxdhYfb3esV13Je0p6UrTrKyHzPdR71ILFWiOuv6twrHWg6j27PCz9UUlaZm/KshoLKXT6UCMkHClohV6PxgLI/RKpaGpd/UM3VdNB1JSHRWXJUdR+Krg5UsVRD3ems7Tns/FvFY1Kq0HrAaD72W96WqfR1ERZSXbkpipVMqlU9I4o6LiM1hDgGfXGkgWLapnlDP4+VVRJos48e+5L5RxPC9ah533F5gvdEJlrMiRnoFKFZtPjqoSpjwmgpJIJByywxe/p1dfKu3nBrEkLJVxc3MzQqFQgLDJM8dXtcZzw2LIVi
FTKOXzeRccp+VI9jUXiA+mQXUWDO/t7cWqVatcx6F6yAtqPVFQMqZDb0s9FmCeLKLsO3q2VLiEYhi7oPLjZvF9OAcSvpi7WolC5t+qYrPCfnR01HUZ0lg4Uyk0/9imkVnlxTVu1EW30LpWVqJBZIk6KgDteygRb3x8HNls1pGq7DP5eAiVejx2zRWe0tga40kaArGhEBsnVOONqBGVMTkARIU0zaqaM7PY8BliVLyqpKyyVS9NlbeucyOVsSpZ38uHhADB/VeFq169hsSouO3zaA2BxZ5LEUCeTRrD9Ip5Xm1TD63yxjVV41X3So0nq7gbrZR9aJA6Bjb8p3+ne6BdwgAEMhLm5uZcV7SpqSnEYjHnXerzVXLe7ZlRJ4+fR/nLDAF2t6OnzHNBhJMKOZ/PO64QgIBeYxaHOpS13NGGpz1pusr09HQgaK9wo24ulTFzwHp6elxN2Z6eHvT19TnmpvXcKh02rqdsZI0zaOwGgNs8JVDQ49cYIS1itYR5iCgQNNUiHo8HPMLFNq5czFgJDOPj4xgZGXHpWhQGVFZUeDZdhPvn8ygt+7pWI2gxZEKhXK6/kjpUIVvPVBUXU0T4u+Weq9qzY2FIKl4lz2kRBzU09Pe0SYnC1WqIUDiRj5BOp11ohMLMFlKod/jCTrbSknqLhBFt7FI9V7v3/Hc1614OtfHFTBky0vNiw0NK+tLcaY0bKxKg56XSc6PnnGdT0214L6mQ1TDWkB0AR1y0n6+Kphzy04gwgUXiyoUKrKesypj7SCOjra3NGdzRaDSA1FEZc02oxDWzoNzQZ/bdeRpVxWLR/Uz5D2ynyPQnWwykUCi4MsJ0IDlHq5BpOKvRXM09bZhCBoJdZHRBgIUejmVUA/OF7ekZ9/b2OuYec9rqzW30kaA0FmlhU91Q9RAYMw6F5itF6YtdlJjfqVWAbPnPSr1jG/9WD1M9MF8BEJ8At6SoxZjtjbjkGmux6RRsJEL0BAhCVfa5FSHQWtaau67PW0vMtdyaU5FqSoUqWSW/lKsWpMQdCil6ArwbRHJogNRybiodqpCVxasVi1paWlzxFRprOlcluywW6612TpaEZV9Mm1QUyjL01UhTwihj375CIYoQVaLs7Dn3oXDqABAV4XwoB1XIU5H45qIoVq2ooW/+Vhmr9+uTIVxX3Td+teeKue98H3XSaHSHw2GnvKs574sZcZQLnIu2q6Ri1lCF7sXc3Jz7XdYPICFZYXbypey8qxkN7fZkL6cKT1uOjl6QLjQVMtluxPd9zavr8ZAV86fy1IorPHy+QwbAxUCYZjMzM+MErWUh0hO11m25915s+Mhovjg4P58WoV5qjUMRPtdDq0QRC9XVM8rBv/QkqNhs3qUKWRt3VqNEiWq+WK5VFpVelHLeva2upCVW+TtK6tHa4VrQnvfAp3R8Hqj1BhoxrACj4NSYKhUz4928u7Ozs8jlcoG0Ii3Owf1TD2epefuEKudkkSoaCbOzswuMf1+FL/XK6Ayose0L1VQrb/Su8z31c1gvnEJfCZ4AnHJSPotPySh0Xa93rEiGKmNfVTRVshrm8KELembtGedzEWnkvWR3N97Vatfdht54DqmUdS/U6CQipETdUqkU2DfWfee8NUwRiUTKGs4HLIYMlK+UQlavjV3a+CU3TysUaZDdx7CuZajBwDiPNlvQlnG0uNXS5t/R++WlnpmZCRQR0QInhF60yIPm31VKFuFLFZLCohQ0VEq0ukOhkBMA9m9szMoKPV6eWslQvmfQuBTPC2F3esnKAFcPmda0j/VJIVYsFl1dYvVY2YhCEYClLokPbvcpY5ZY1dCAIiIKsVM56PlQJUyLnWdf+/M24g74hgpMZSYrfM75cN3pJWSzWUdeoRClsLbzpMKsZO4+CJ21gzXlsL293Slj3iUlxqlRw72MRCKu5aQqGTUEajF8VCHQA+eaKcmJbF3ll/CulkqlQI127aRk46IWyarXULPesc+I4b3U2LzPUPLF3X
1Ig3Xk6HGqMq7WO/YpfipjXUcfD4IyQs+LokbKvPbpC2tMHHAP2XqdrD6jpeDYPkwVhi6iHlYtLEJl3IiygRZOUshUYUQbY1WLkX+nFlGptL+ZBBWKdkiiMmb6FgkE9PotBL/U/BeD2zXuTYHA2D3/ns+sMVc1hJT0QotRD2gjvGR+VSVHDzmZTCKTyThrWeOSGkO2HjAvAK3W1tbWBWlGJAOqIOOlW2zY+LGWEKQyJnGOKWVaPEBjZPyqhBf1Ktva2hCPxx2HgoU2WGSh2jNTyaASsIpEvQAWNqGxxMI3s7OzSKfTzksmB0C9JztH3pml5qQkM4U5+fnMc9UeuzQGeCctIkODqFQquWIWjBvSONU56NdK15LKlQxeAK66n6ZVKrNe7+js7CwymQxCoZBDj1jmUZWNoig8D/WcCZ/xqQY8jV8AzphVFMVnPDIGro6NygB9lQsxVKrQyilj67XzrCv/wCIZSiKzxEAqZJs1wnlaD7naUZdCtpuoRAaWamTtZG11pVaoLpbCCNxsCyPUe+jU+lNSl08ZAwgYG4xVRaNR1zCcXhtLVZKJx80mgaurqws9PT3o6elBV1eXt95vJets48fqGQPz8BcVEBUzIbB0Or3AU9G6uFzzRq47h/69vfxUdKlUKgAp8nN5wDXuahmdfC5FARi3s6UUK7Fe1YhTZr56yIzZj42NOeFpyXNq3PH8h0KhAAxGQyiRSLjqV/39/e68dHV1IR6PL0BWGjHUG+WaqCClQtaSgzQ6pqf397dtaWlx1cgYo9MURf2cSj0eygQdhUIB+Xw+oJB5JzVP24Y3GDJguk1bWxu6urpcHWJ7zlVRVrOOSmAC4BS/IiUWBlZDe3Jy0qET7L/L9BvLzfF5yPV4ySpjLIGR5xmYv2fqPCmqs5jhqJ9heQb1KGMgiCBQl2ixDp8MU6WsMDuHetN8T+6VhdsBBHhC1cLVQIM8ZMXRlfVLVqF2S6ISpndgoRcbz7SMynotQAvL6Fe1QtVK0pxfq6ypHCioaflq+ggZsolEAp2dnc4qr6a8oJ27vuhN8vO4VgojhUIhTE5OIhqNuvnyuXQPbBpLrTF7O2w8ji/rxXCdNebNtbYxLd0zFaS+WLUVggqHVbruahETerYFZayRREMJmGf8qjLW8EwsFkMikUBvb28g04ApgDbtr5EKmfMjqsKfkS9BFipJdyq0GQ/lVxZY0HrWqkiWmrciZ1wv7sXc3Jx7/0QigWw265SVzSSw8+S5ampqQiKRCDSo91XnqmUdlTSkCJlP3liuDWUIAFefXeOZyiUox7LW/ax2+GSMolD2rqlBadE1RTQ5J6ugfAqr1udQD1l5B83NzQ4J5GdaJrj9XP1sG8qwSlvRLzUE1cCoZjQkhuyLCWo1IsJHPKCElnzEFRXAPqZmLZarDl1chSOU0KRMXS68pqUo9K3PTUtSFTtjRowfl4sHVvJMVpmpNUjhDsAJKC2GAMBBf/RoWHXGxl84J9/cal17nS/nrGQJGg5sQMJ4EoeSSlRY0MDjHirZpxEGhX12+xzWgNSYNwUvgIAA1edlzj0VHj1khjb4f7YzTq1KY7Hn5HuSbEQBoyURaZQqskHjJBwOu3StTCaD9vZ2p1B1zSoJFXDd6Y2RmKjhrEQi4ZAqa8DTW6b3ol5Me3u7C1PRiKKhY+HMateQf8dn5rxVlqnCU2VMNGBqampBtTAlh1aybo0cPrkDzJ9pe4erMW5871vLc+gZ43xYxIP3kncyFAoFDOZy8V5rLKgy596p0c9QSaXprL7RMA/ZEl80L1OJFDygQDClSC0OjV8wH8x6QfVsmCWPaU4vY920+GmdKvOaF9wqZY2Pa2zWppAopFwJFGwVJt9XhQyfj7VjCZdzzYH9gpa9POkJ2VhOubSVegwinTu9eAp6wvmE+i3ZRiFjGnrAfqFF4Wl5B1rcXiE0PrdauIs9i2/N+TlKLNPcSj3LakRobFFr5zKTIJFIOOXM/HvC1PUYcd
XukwpbfeZ4PO5QFWAeHWL6jsZpGb5hNTGeGVWYlc6Hc+Fn8t6ysQXRKBqf5A6ogcx502hgGENfGofli2e+Wk+Nc6YS4Br4UDplh8/NzS2oUW332Rd39XmdtQyLTGgcloagGlS6p9oEpRxErHIMQOAcq5Nkf17NvKmQWQuDhC5gPsWKziH1k6IAasDbkJUiY7beAO8F9cEzwrLWQ2AVMlM86FnaMpF62NWC4d9nMpkA8xAIbmC1ikEPGstyMh7GTdB4H6EOClsliaiHpt9znvaA2hJ9Fhmo5OCpYtA+tjx8VG5UZvw/Kn4eGMalrCen7GV91coW1Hnr2jc3NwfeT9csHo+7OKUaBIVCIdCvmmeGwri9vd01N6BC6+vrw8DAgMtl1yIQ1XiZemZYgc5nbMXjcRdTTiaTzpAjCsG/p8erdXNpPOjP+Sy+Mo6NInQttmdqvFIhKyKh8WXmzAJwJL1kMglgXkAR4uQzVEp4UeFNYUvCFJUZDRt65izUQuWsd9SWmmXxFY0j2rtZy/pxUNna8847SkPOfq46Ifw7NZgtoVHvC3+/VsiX95RGDuPaGi5RmewjmpWLx6s8sMgJ369aZM5nOHNdVCHTgOR6dXR0BBxHOn+chypj6iWGFYiwEGWh/LeM9GpHzQrZKmNO3Kbi0OqwXhmVr3oQvOCsvsSfETpWAkO1QyGk9vZ2954UqDo3/lyT+amgdKNtnICHQiFZhVBVSZdjoS42d14UnaPm5Np8QWUWEqlIpVIBL04PnVp8Gge1RlQta895tLS0BC4PBX5nZ6eDP/UzKUSz2SxGRkbcpSGC0dLSEujLS7KPEuio2DS+VQmUrevOzi/hcNiVsFSYmcp4bGwMLS0tDu7lxWQKjMLRnKcaDdqowXY6axSfwjd8e0vlS4MDmDc2w+Gw4yKwwAXvdDabdf9PQzYajQbIgtWcJStw29vb3V6QCEeonO0MmXYZDocDMWQAgVaCqVQKLS0tC4ilNne/luFTKBYCXcwrtH9DpUsFYXPsa72fOkf1MLVrXzabdeuknj/gL7trjRk7N+sRKzJTi5esBpt6pbOzs+68kuNEjkFHR0egyxblIs8r/16dTK22x7TeyclJtx5ax6IWuVm3h6yYupKflEShUB1hLGXx0dsjgSSbzTovTg8JhRMvY7UWoE8p8D3D4bA74HpR6B0rq1QVslqKFJiaxmIL7yv0U+mB4+FlrEwFpSWJ+GLuc3NzLh9XoXIAgVxqHj7bLKEehWw9ZDWMePEZD+RnWhSC3jGVMHM4Z2dn0dbWhp6eHgwMDGDNmjVOCRMG1k5htjTiUmdI583vaUBMT08HCFn0blmph54aFRK9TMaI+WJFOvYR1jOjUCEFjgqreof12iysqlwIolscMzMzrvkFy0DyHJL5TIMkGo0iFouhu7u7ZuVhES41SmOxWIBJzz0G4FjgNLY1ZqteNNdc68svx1CPsZK7z3loHFrjz5S5tSoAnRfPliJCDENoyMTOuVzs1Td3fQZ1ZNQzLeclLzZ3PR/6+ZR5oVAowPUpFouBxh8MWwAIVOrSEKxtnUmHjYYeUbt6ZGZNClkfWJWxxtX04vHiUGABcPmK9N6AecYaISVVxoQLLS5fqWBSK1uhEguXa74dAOe1sxgI56nUfYVflFlN75tCwsZhqhGsFuYh7KdEEf23Ihe0CK1nbvdRY6K+ijO1Dhsz4nprjJJl6ZQ9znnkcjmEQiFXoIKFFSjoGU/s6+tzdc+1WLwSZFSxVWII8cxw/vQQfQYWsD+2zTKv3HMaUhozV2Xc19eH7u7uQOOScvBlozxjXwzSnh0rPH2xRUWt9LzwayQSQUdHhzNoybOoRSHbO8Y90Zg97yoNT4WM+Wwqp/TlU24Hati11zXS720uvsYs60GxAH/5VC2dqoYO56x3lWtKB4dnhjJeX6onFCWtFQFSpc65MSZvyZYMo5LvkE6n0dbWhlBoP5lUq6apca1NQrQIEEM5jTgvVStkjW
XYNBCNOyoJyhKpuBiEuCmg9UIRPlCPpB5sHsACocYXL6kG6O3/2c/mpeX7qvdEchEhSB/zsBahqmQK/lv3xApWFTBKWlFlTA9ZL/hiDMR61l7XSqEhwoP8LEJyPBdqRKnACIVCrtQqvU9tp+YrKlPtZefv+djnvNxWMVlvFkAg1qzzVONB05osgmJhzHqGhgN03dX7svvBPdE63VowQnPDeb4ikcgCtKVepVFu8HmUEKrCn0qDv1uvR9moOfuMH/3ed7cVPfKRMGsZei9VXqtzoWdT2euaWaNQLxUznRo6WlqESasa8r5UiyCWewZrOCqhk0YCW7jyrGr9BWC/QiYngi9mD2lao6aOWuJlNXe2JoXMg0NFzMXVUoE8IBRUmq/G92FpO37PeCE9Ij5oLBZzFnYjvTWFKJQUBcBtJIULBU1TU1OA6csDQIuS8CthUyplH6u6FuHq+1ufYqbxwHnr7+r66e9SaNZbj7WSueuceUFIPuNnauxGK5IB8wqS1jtTyugVa3lSm/5U7brr79PjIvlPDQh9qfHG88EzpudDS8Oqx91oJcxRzqC236tBpqkdWvxH+/oylkzECJiH8BpxjvS82HQhWz2NPcDZbIX1AXRuajRZY6pRKESlz6MyVdfdOh+qtO0+Neqecj2okC0HhucTgHOoWHeCijqfz7szHYlEHHRMyFfrNehnqjKuNTSjhrvyAjQti3pkbm7O8QympqbcnKmQGcqbnJwMtNLUbAPrbNqCStXOvyqFzM2enZvF5Xdejq8/8nUM5Ycw0DqAs1edjRc3v9jFWNV7tBsMIKCs6W3oIeX3ZN9qechGQDPA/oPwT3f/E/7qjr/CO499Jy49/lKEQiEX/7VVkVpaWlwsk4dyenraGRxkyBKKJFmHUGQjUlbu2nEXPvHzT+D+vfdjX3YfbnrDTXjltlc6z9nC+YvBkertUwD4yCKNIIwcdNVB2DGxY8HP/89J/wefPe+zgT1VLgE7V7FlHTkHPFe8CPpS8pZVcLWse7m5X3jEhfjAUR8IxCLV+lfrWY0+Es98RWI0vt8IpXDNvdfgmvuuwdOppwEAR/Yficv+6DK8aOOLAmQnrVms3otCkRp7JeNde20TCmSckbHeWu/r5Xdcjr+9828DPzu051Dcc8E9Dm7UbmHDw8MYHBzE4OCgE57sP82MDyIrmrdOY9qm69Sz/uXuqQ5rHNmXz0tWha3GlMpcvnc1w3qX1//melx979UYyg3hkM5DcGH/hYFymKFQyGU/pNNpDA8PY25uf1lT21OYCpn7RG4FsP9ukMio5W2tcVRNPPmf7/lnfPIXn8RgdhCH9xyOvzz0LxFrjTkkjXJudnbWlTseGxtbUBSJvCZm/dCYIHpHmc94O/kkykiv9izV5CF/4hefwBce/AKuOuMqrG9ej3v33IvL/+dyTHdP47DZwwL5uBqHozsPzOcgc6IKxVABUDkSq29EPJMjFArh3j334gsPfgFHDxyNSCSCeDy+IAZM743GAedPuIaDApdNr0nU6ezsdIetESkruZkcjl11LC4+7mK8+j9eHXgejnLWt1XEKihV+VoozL5vLePeP7kXhdJ8lbOHhx/Gi294MV535OsCv8d15SXQim/pdNoRMnh+fHnetlJQPcpY5861/NXgr3D+18/HeZvOCzTH4EuL4WjIhgqZlbe04Idl3TfKQ1vfuR4fP+vjOKTnEBRLRVz/4PV4zY2vwU/e8hMcFDtoQQMOW0OAilgZ+Nruk52uWCNaz7dC/LUiQ0f2H4nbLrht/gzPFQOeOr3gsbExDA0NYd++fU4hp9PpQFlcfrZW0dOGCI2o1sVR7p7aoXd0qVCR5YboXW2Eh8y9+taj38Jld16GT73wUziq6yh85t7P4CNPfgSXdV4WgK4JRdPLJBubKBVLkhaLRWfE0bAmh6ijo8MhXTSUfBkFlY5vPPwN/L9b/x/++dx/xnF9x+GqX16Fv7jnL/Avh/+LmxflNgl/ExMTgb
tqi4kQbue9AOAMO+ooLR+q7TyrnX/VHjIA/GL3L3DulnNx5rozkclk8IL+F+CExAl4fOpxHFI8JADV2WC77/3UawPmyV2E/spZi/WM7EwWb73prfj8Sz+Pv7vr79yhUOIKY6s2XqZxbnoFjB8TstbYoE21qWece8i5OPeQc5f8PesZLxaXUqNosd+vZ/TH+gPff/ynH8fW7q14waYXBDx55RLYNAPNUdZUOptSZtnJ9Xo7nDvn97HtH8PmxGac2HsiRkdHA4UmtGsYYXh6YtrxiwLLwluNhktfdtjL3L9LpRKuOOMKXHv/tbh7791Yu2ltICTAf2uIQOOxtkY4FZ7WObaFInwktWpGNBzF6o7VgXNBpES94/HxcSSTSdfUhs0+FFVTlE6RFO1Q1CgjrpJ7au+o/ZkOO49yhnU9g59x9T1X44+P+2NcdOxFmJycxEdP+yju3Hsn7ivchy0tW9xa8S5OTU0hk8kEeAY03KiQ6UnTQy4UCm69Nb+dyEWtzsunfvkpvOOEd+CiYy/CzMwM/unMf8KtO27FHRN34Hmx56Gzs9MZMUSB9AxTlhPupvHD8886BFz3SCQSyLbwIavVjJrMwFPXnYq7dt2FJ1JPAAC2p7fjkcwjOKb1GGdlKKFF2b4ac+bCcChkQkHmi+00Qmi9+wfvxvmHnI8Xb32x+1wKTu0DqzCoQosad7KpTgpf+KylAxGjArCkMvYNJS0BqNtT8I2Zwgy+8uuv4OLjL/bGt22MkMqCl8aSTsqlNC1HPHCmMIOv/+breNO2NznrWdMi1DvmmVJPzHpkjTQclhpzhTl8/eGvIzebwwn9J7j4mDbL0IYZJLHQCyVKocYRBRTj5Eqy0w5AtiBLpc+5Pbkda/9pLQ7+zMG44JYL8PT40w5B0dxjQtOZTMbllmpREjsv9eJsvnclOeqNGvaeKaJgyYFKGLJxb72n9cx7tjiLB/Y9gLO2nDWPQjW34OTek/H03NMBg5IFVaiUteCKkqAY1uD9UIKfZXX7kIpK92KmMIP7996Pszaf5f6mKdqE09ecjkfzj7qQIs+kpuDSEGUNDJvPbs8THQIr7zXDohYZVFPa0/tOfR/G8+M481tnIhKKoFAq4O2b3o7nRZ+H3dndCwQjcXharMB8cj6tLCBYbpJ5m43sh8zx9Ye/jgf2PYB7/+Te+R+G5gk7ys6z6TJqcKhy1Q1SRU7hWw+7utphPeDFPF4bUtDQgsaLGimgbv7dzUhNpXDRcRcF5st/K1HKMmUVcuRaL3aJG73enPsbDnsD5qbmPQJtVMB5MjVHDTqiJb4Sg8t1Nh4aeginffE0TM1NoaO5A1992VexJb4FyWQyUHta21VqLFzT4JQABiCA/DQ1Nbk0NK1CxtCNrwPUYuOUdafg+ldcj0N7D8We9B585M6P4LxvnofvvOQ7AY+LIQ0aCiSXcQ8YTuLcBgYGypYmraaKW73Dxm1V5qjsYe0F3kvOsxwaxPeuZYzmR1EoFbCqY1VgXv3t/Xg89TgSiQQymQxKpVKgKAwAt+50vtiDHIDzMkmsI6KoCl5Ro1rkvZt7bBWAeeLeQGwA25Pb0dnZGWB1s4gN5b6GIH3yWtEuGpusqkdElGep1iYwVSlkvvm3H/02vvnoN3HNWddgY9tG/M++/8E/PPgPaF/Vjq2tW50SorLlg7PUIR9e04zUQ6WlwfzSRCIRUMr1CK9dE7vwnv98D259261ojbbOPxsW1lNd7DCoQlNmIuGwRvZyrmb44sd8abxJh5JcFP61RUwaNb74P1/EuYeci7XxtQvmbhmnNmWGSpes5cVSm5ZjXPer63D25rMx0DaA4cxwoIUnlRi9eCogtcotPHogDLXD+g7D/7zzfzA+OY4bH7kRf/aff4Z/P+ff0TXb5Riy6l3Sw+Td1BKxCq9SGfD8s+Ia86wpqKiUWYeg0vNEyLdUKuHIviNxbO+xOPzzh+O7T3wXpzSfgomJCYyNjTlyGeP3NIhoIJNIR3
Z7X18fVq9ejf7+ftdNywrS5b6rqowV5rdldrUYkiJ4VF56nhp5lkIIGgo8qz09PY4cp+eEypa8A/XiAQScGRqntu48+8STKFUrerRgbSPzc1dIenZ21nnF6rkD87UmFAVVBKirq8vVECBfyGbVaP5zpaNqDzkUCuGy/74Ml5xyCV53+OswOTmJrR1bsTO9EzftuAl/1fVXDpbTdKi5uf3NAbioFLhUIFQIWpeYbFTtH1yv9Xr/vvsxnBvGCdee4H5WKBVw14678Nl7Pov8pflFlbAKJA5umMan1EM+UDWIffPVVAnrOQPzFcy4/r7YXyPh3x2pHbjtydvw7dd/u+yc1UvW4gcAAhW+VMn5BGmj13pHagduf+p2fPXlXw148Rp7pUJWA40lNnWuyw2t62iONOPgnoNRKBRwdO/RuGf3Pfjiw1/Euza8K9AqlZ4yY+E2Y8JCqWrEtbS0OMVHIUUPlC/bA7yaUSqVkGhJYEvnFjyVfgpHtR3lUlHGxsZc0QaSy2gsxONx9PX1BUqrUpD29/c7trs1lJZbGZO7oUpPCYqKpJCtTvloUThf/LvW0dfeh0gogqHcUCCUl5pNYaB9f234UqnkjB1Cu4wTq/FMb5iykSQvngVtNaoK2ZchUevco9EokjNJrI6vRnd3tzN8AASY374a+hzhcNgZQQx7sDpgf3+/ewZbd+KAsKwBID+XRyQcFOIt0RYghEBsjIF/prH4iETccHpp9C6tMvbBALUcvBdtfhEeetdDgZ+9/Za3Y1vfNrz/ue9HNBLFbHEeuvClDS2m1HiBFiOKLPdYCrL2HTiFqpfTg7vuweswEBvA+YeeX3buakQog9SGBjQea1nKyzGue/A6DLQP4JzN5yCXyTljQYlPLKmqXrzmRR8oaL3cKJVKKBQLmJybdN69rUKkMXueE+sp8N6qgqCApaDSu6uQcLXniXNIT6XxdPppvLDvhQ6RUHa7tj/VtBTWOu/v73e1xG2dc+WGLDfKwmHDRWoQq7fc1LS/HSxlq/WefQZ/rXNvjjTjxLUn4vanbscrDnvF/vcMh/DzoZ/jDZvfgEQigUKh4NaICOjk5CRKpfnSpGpERyIRh45otoFCvSTAUs7X4h27uT95O1657ZVu7nftugt/cuyfIB6Pu/ecnZ3FxMTEgkp+tngN9RLvMiFq6iYLV9s7fkAg6/MPOR+f+OUnsK5jHbbGt+K+PffhhsdvwDmrzllAWQcQKCygCdVK3FLYVNmoCkk2QjnEW+I4auCowM9iTTH0tvXi6FVHe9N8rELW/9N1UfjJlhZs1AXPzmTxePJx9/1T40/hwcEH0dPWg42JjYHftYq5HHtTBQP3wzfveudfLBVx3YPX4cJjL0Q0vPjR8xkSSm6xcbblhhmLpSKu/9X1eOsxb0UkFFmQqqLQOhAss0nvpxzpbDnHpbddinMPORcbOjcgNZnCV3/1Vfx8789xzenXYHZqnjin5LlyqUJ8LhpHfD5bnU4FVrlyoJU89/v+63142aEvw4bODdiV2oXL77wckVAEZ/afibGdY47JyxdhUyAY7/NVcvOFEBq5L0vdU71TvHtqFKsc1e9tjFmrZzXq/F9y6iW48OYLceKaE3Hi6hNx1S+vwuTsJN546BsRmdqvXEm8JJOa0DR/roWc1Kgg+a8c16YeqFrnftLak/Ccdc/Blb+8ErnZHN5+/NvR1ry/PGahUAjEq60BYIflNtm4t885qPUc1eQhf/qcT+PD//1hXHLbJRjJj2BV+yq8/uDX43X9r8PQ3qEFkymVSk5oaYUjehIKiZUjNzyTHoUlHfkYyhp/1ovt25h65n7f3vtw5pfPdN9f8l+XAAAuPPZCXP/K673z15QK31DhoPNtNJx625O3YefETlx8/MXe//etsY+Aph79gToXnPuFR1+4YF7lapsvlS2gz7RcYzg3jAtuugD7svuQaEngyL4j8bXzvoZtTduwc+fOBRW6tMazpm6pwanztt6dj1HuKw
daydid3o03fetNGJscQ397P05Zcwq+fd63EZ4IY7gwHJiv5tnr+VVYXdntvvryjWRXV3JP9czqfVvqpfNVeWPnXetzvOGoN2AkP4LL77wcg9lBHDNwDL7xsm9gdcdqpAqpgBdvPXPlgBAtCofDgfCkynjdh0aEcTj3v77jrzGYHcRxq4/DD978A6ztXOsQt3K9BXQvfDLHh2KUI9fVatRVpJAdZJRO7xdExRL+5pS/wYdO/JBj02UyGYyMjDhLVS+4LfMGIFD8gwF223GICdlqgZEY5oMn0+l0YL527uXGd179HQBwCeJamMLmZKonpIJYmYWMKTJWoQXLuVl2VDr3E3pOwMR7J7zPkU6nA5dBqxmRgaoF0e38VRgzjYGQZiQScWtvD1ulcz+1/1Q3d/05L7Gm4WgnFc6Je97cvL9XdVNTUyDlqKWlxVnklXINqpn7+HvGXbm9xc4GP1trQGuZx2w2i1AoFPjdWkYlc7/yzCuBM+eNYhZvGB4edvPXZ9C2m/SQ+Z6qDLQ+tJbVJHw/OTnpWObcOz03lcz982d/HsB8eVKWMNyd3R2Ys8oYvo8SGbXevu4FjYvZ2dmqwkp27rXcU182Ae+bfTZbtEfTR3mmCLvyHpWL01cydwC4YNsFeNthbwvINZKfKEP0zKv81kp/dAYUIdVCM7y/KivLxcKrmfsF2y5w35dKJUcqZooT6wXovGxBJHrTymVR/aTrz7VnOMeeI995945SBWPXrl0lAM+a165du1bmvjL3lbk/C15/CHN/ts17Ze7P/NzLjVBpSZW930Ldu3cv4vH4AYOKaxmlUgmZTAZr16511uHK3Jd/rMz9mRkrc39mhp37s2XewMrcn6nhO+++UZFCXhkrY2WsjJWxMlbG8o6KYsjPFkvkD8nqBlbmfiDGytyfmfGHNPdny7yBlbk/U6NSD3klhvx79lqZ+8rcV+b+7Hj9IcQyV+b+zMy93KjIQ47H4wCAXbt2obOzs+zvlUyVJWUrs+YsW3CNjIxgaGgIg4ODGB8fRz6fdykLWqKsu7sba9aswYYNG7Bu3Tr09/e7Upo2hzCbzWLjxo1uvtXM3T6Ddi5hOcF0Oo3R0VHs3bsXTz31FHbs2IHh4WHk83kAcPlpzHXs7e3FmjVrsHbtWqxZswa9vb2uWpEtEJHJZLBhw4aK567rzBKknCfbFg4NDWHPnj0YGhpy/WHJDs5msy7XVPsds+RjIpHAwMAANm7ciPXr12PVqlWBMnG9vb3o7OxEW1sbJicnsXnz5qrmbtmbWkOZrGkWfGDBd54b1mAOh8Po6urC6tWr3Ty1+APLImqzdGudptPpqtZ9qXNjmbCZTAbJZBK7d+/GE088gSeeeAK7d+92fVi1eEJbWxt6enqwZs0abNq0CRs3bsTAwECgvCPPfVNTE7LZLDZt2lTVunOOWro2l8thYmICIyMj2LFjB5588kns2LEDIyMjmJiYCOQk6zxXr17t7mVfX59b8+7u7iXL3TZq3X3PxnM1Pj6OPXv24Omnn8aTTz6JvXv3IplMOgZz6X+zPGKxGHp6erB+/Xps2bIFBx10EFatWoXu7u4FhWd8d7VR50WzDJLJJIaHhzE8PIxkMolMJhPIkAD2V63r6urCunXrAueFDQ7KZaJUM3fKBZU3nGcqlcK+ffuwZ88e7NmzB2NjY4HCMuyE19bWhr6+PmzcuBFbtmzBhg0b0Nvbi3g8XlXxpMXmHo/HF5TcZUZBMpnEvn37sHfv3oAs1AYSPN9spag59f39/Vi3bh3Wr1+P1av3V/zSilyVlC31nXffqEgh80NYUcU3uGG6KEzUZ5Fx7VBi27Txb3yJ8Mxv1GRyVnXR/pwUtr6c38Xm7nsOFkMnNML6rb7cMxZN0PxAzofPqF2ktBrNYhV2Fpu7XmamIulhtAUpdB1nZ2ddOT5Nz+Fnat6d5tLy/bj+LBXH56907pqaFY1GA8/Cl6+bjc0rZQqU3Rc+p7ZFWyoVqp4zY8
8N004ALDg3nKO2+OTZ8Z15nnsaqLabTLVnRueoqYQ6J55HKiyeBwpnfo7NWdezzmpMSzUJqHfd+Vx6hnimbK4s07b0eTRnXOduu7Ut1sChnvPC+6pFlHy5uL68da0OqMVZOjs7A4aQb1Qyd66rdbKYDqT5w7ZuBIuAAHCV2lTRsb677QJWSQqgb+6qkJma1NTUhFKphKmpqUBXL5Z1pu5g210AAdmu58e+bInTxVJafXMvN2oqDGKHXnStymXzYOntML9U+2BSqVII6CFgbpjmiNmm3I0e+jyaL6depSotzoEF1pmjxxKEmovtayhe7TPo52m7Qs6VOXZMyFeh2dra6go+2OIKuv6qJG2+p9bGrnXoueH8beUo5maqZ8ALQTSFwlU9DW2KYEuFLsewz6I5itqogetLIcZ56WUu/W8uvO6N5kk2Ys0VAeI686V56nYdATgDlX+bzWadJ0nBy7PH/EyfwdmIYb1+3jnfM2lOO6tIcX42B5XrtBznxhoQ2rpQW0pqww9fNzEAgdzfRs1XZYv14vVcp9PpQN1zzoEGDr9SeS1Ww73egixqOJQrE8x7p0Zva2urQ2YBBIwMrgGdM551/p+9u2rE1jrqVsjWgrICNZvNYnx8HGNjY0gmk4FWb9qmjp6aChxN7rcvvTiNHhYO5mWhUaEdZeitEXrk31BpUCBQSba2tnq9v1oG14fKjOut0K9Wy2lubnaKuLm5OZCoT/iSFrCF7bkOqujrMYgWU2BcZwolNeIABOoOE20olUqu2Xhra+sCIaGGz3IpBVVUhNxZI5pGRbFYdIgDDQytLqXnXUtDUqjVs+4+yJrzZD9h28GHd43niHOlsZ1KpVxRB94HRbC0kIzPo2/EutszpGEb9nDWO0HjiO8RjUad0rEGvwr1RsoaK+h1ruwnPDo6irGxMdf8gAYFz3ChUEBbW5u7k1Ym1nLWrVesxrL2DNa+x5SJWrqU/QkIAdsWl42sGa7Pa51DNSSJkGgZTN5HOinqZPE+sphIKpVCOBwO1OumPG1UVcOGeMgKg1F5UTFoizRunh4g3Tj1JOklMBaglq31NBupmG0sSmPgGiMhtEToVyvU0KiIRqPo6Ohwlq92u9FXrYJVhREFkVrWqpA5Vzbw4Lpp3CocDmNycjJgCPH/GDOJx+OB/atn7taI4yXnRed653K5QLH6trY2d7EIF7HnthaCZxzL13Ky0UPPDGOyyWQSqVTKvRi3pEFEYaD1rwEEBDUrvimcWsvz+NZc7ydf4+Pj7v4SDdIesjzrWkGPxlCxWHTnS9vQ1dsQppLnoeyhQuN9pdxRxaxeJs+PrSZoZYsK/UbMXQ1+clMYL2b3LfJu2MWK8g+Ac2AWMz7rnZ8qY54VriWNZs6R6zo3N+cUn3Zl077T5Rqs1LumVhkrugPAzYkOFD1mGrqUL1xfGj3cIwDuflMm0QhVI6Meo79hHrJVYBSsqVTKHTZunJISuEi81FQIhPcUBtevCi/pXBox1PNUhaxN0Nlir7W11R0orgEVXTgcRjwedwKOMRONUdTrZdoymfTkKVAZFtA4Db9SuBKG4brSCrehh/b29gWQZr0espJEtPOQNp5nJxl6xEps4kUolUou5k+rXLvOLFdoQ59F7wDbGo6NjTmDjB4+9x2AK5+psK5dEyU51vo8VoFxT5UoR2+MBKJydaKB/a3r9AzR2CAhkO3oiAjRg2iUQtNnUW+finh8fByjo6MYGRkJPFc2m3Xnm8ZEc3OzV65YhdwIpazvxX2YmJhwZNGRkRFHfuVLFa7KzmKx6AxPol31OCflvGOu7fj4OIaHhzEyMuKMG75opHNuPnKU8n58dZ8boZSVx6SIgXIc2traAMA5KRrCsCFSOiVqbPN+FItF95yK1NUzGhpDVphN48a28TljNxQ0JAUROlJI1Bc/5s8tVNGI5/B5EhYK1p6rCsXx+aenpxGJRBbA1VrzuF6FpvNVwaQ1ZQE4BUUCjyWisSZ0oVBwcS
qFUnkhbYtBGkyNsMQ19ktvmevGteOzaItOwqG8dDMzMwiFQgtqAfsUynLkLVqPn+dFzwxDGXwe9Xw0hqz1cn3PUqsy9hlwavxQ2Pji1ZovzLUG5uvSNzc3o6urKxBXnJ2ddcKXRke9sUJFcGwsm2x8esf0MDXeqZ5NJBJZIF8aqYDLDbsX6XQayWTSKWSF2bkfRLuAefKXnp9GGZ7W01QjUw031ramsaDwLUmVsVjMZQhoq8NGd6iy89cXY9mEqikH1BDT+u3kG+i50Hrj7Hne2toaMEbopD1jHrIeWLUoVSmo8AbmA+DagosxKMJ3VGh6sCwb08YeGi1g9dJrHFUJXSTlkFg0OzuLXC7nrF9VNAq5W0is3qGMVz14wPya0kpUmBeY93QmJyed906v2q6DWo9WaFU7rPGjZ8ieJ517W1ubs7oZLuD+0JDQJgm8KMutlH1hBP18AIGYd0tLC6anpwNMWHuf9D0Ufqtm3X3KWIULERXGj6mwqDxpdNrPY3yNxlI0GnXGh8Y7OXclP9ZzZso9h8KnhN8V1aJhRwSLhhyfhV/LvfT3GjUUWbTkOnq91ijSM6YOS7330b6/D7bmOtN4y+VyC4x/Gs0tLS1OGROypne8HH3W7aBM1NBWW1tbQHbr3eJ5CoVCTv4R/eH9nZycdPIxGo0uUMaN4Ko0BLLWf9tDQUiasa+2tjZH5FIqOuNp09PTLp5cKBQWpBDZfsuNtq70mXzxCFpBjOuRqUyFTKGmh1k3vBzTupZBy48XgLE7AA6KUWKDtsqLRqNOoJZK+7uhcC+Yk2eZv/ZrPZdf02lsyo+uL9ObotGou+DMy+U5Yayca1uOc6CM3+USBL7no9JtbW0FMI88MOatECZhXRWKtUKnKrytElOmrHpjhNW5Vvyqd8J+tjVaaRBxH2zMsJK5+57FhgWovOgZE6ImBE/oN5PJBIxpDlUgNsXFki4bafz79ldDc1xj3muGkzh/xj+XU/7ZOSlSyDg9jTd6xjTeSJLifaVS1t7Yy8ErsPKEyJ8aLQpT64vhLiIPalgzxAHA7Q9j95bkSmPvGVPIdkF0YXigKDzj8bizUDTPkkotEom4y0bPTXPt+Dd6eQ4E7KEeFufD+AhhDxIFND1BIT0qdE1j8QnZataa60NrtLu7G5FIBLFYbAEsrn9H2r7GRzKZDGKxGDKZTCAlRD3vcrmRtQ61YOn9Mm1GYaO5uTkHh/b09LjiMJFIBJOTkwDgPAxatbYl4uzsbKBAxXIoZL6vCgTNceTPQqGQM95UaHBN1GP2vSodPo/SttVUNnsul3MoA9OYrPLneVKYj2uun6GKWXPua2k3qcpLeR2cN7220dFRDA0NBWBf9fx5h60nVy6X3c55uQ05PTtEq4D5UAhDYlQQmh9dr8dpz5h631Yhc91pzOsdVu84kUigq6sLXV1d6OzsXJBD38i1tMqYCtEq6La2tgXhIG3FSbmtyIWSGomsxGIxp4wtClePl1y3QrZwjwpZEp5IL+cE9eDx72mZEzptbW1FoVDwNjznJbKW1nIMFYB8ttbWVnR0dLiKLYXC/qbXJGhks9kANKoXyrd5tQ6uY0tLCzo6OpwXZj1jK8gVKiM7UtMSqAhpCNk0BV9xhGrnredAWdGac0voiOegu7sbAwMDTiGHw/urJs3NzSGTyQCY99aoHNRTbkQD8aWeST0usje5hjy/NEAnJvb3y6VgoDL0rZP+rJJhUR5NC7KeDhnIhJ+5TlZwqlKkl6351brmfCYtwMAQTy1nXj2WyclJV1mMBC6mCQ0PD2NsbMwZGIQVbbUxm4uqLy3ishzwqt1XixJpMRVyP4rFolOKzPLQeTbKaLAKWeP05EUokSsUCgVqAVAetbe3OzSrq6vLse6ZCtdIuW3lCc+twtVMc9LzScOUa0yjHYBD2tQxAObPju0dbsOQtXrJdSlk4uwqnPWwa7UYHwTEjSdUNz097bxmepz6vb7qsbirfUY1NMLhsINjEo
kEYrGY24h0Ou1ihIRAAARgKV86RbVej7XAaCQAcDEO3/vrPCjYGEZQoUlhYL0EHypRy3pa61VjPArxMzZFBIIlSWlx8/MzmYy7APq3NgbL9ddYZr0xN/tstnoS94FszPb2dkSjUcfCz+fzgXxdXRvLDbACd6n1V4THwtXqJWvOOu8sCTiKKvC+EsViHNkKcYUBLemolhi43h3N5GB5SdY4IGSdTCYDhCPORx0CGk42RMIzzp/bfNlGG/6Ui2rEcX3086mQ+T2RB61r0AijwYZJ1MC1Bh1RTYYgASzgehCyJlytVRWXw0MGEDBmiPS1tLQEEErWiqDMmJmZcWedd4ZKmfKD+6/n2IaU9Hxb/VjJaBhk7SMVcUIU5rSMgHkYlxfFxnPoIbNsoPWMfVDSckIgVBoAAiU829vbMTc3t6CMpxoKXAdfHLaaYQ8ABZXCKbwgFjKx8CXhdxVQamUCCBhWi4UK6llfVV70jvXy02rVXFy+AASsbRsvtYq4keSXcs+jgl4hX1XIPB/5fN55FxZpUhjQKodKzrovRlnOg1UYFIDLHY3FYu6zub68s0SBNBUHwILPUkivEexwZfymUimniJPJZCBurAVAbC4q19lyKtRL9TkQyyVfiJ4we4DMb86Fa0FloMQk3oVGeMkql9RD1vtEQ4v/1px6H0LEMBSRB67rchg31pDV86NKVsNGWoJWf8dC0cC88aQhjUYbFw1RyFxcq0x5mAhjaJoKCQoKDfBia64Ya8rqplYrnGp9Hl4WxjV5qW1hcavQ1Fqzno7+TjWwr/UUFCbRy2LJBToXTRWxBVf4tzzEFBaaA6skvHrW3xfeIKREgc/19R126/HbkAAF2mKEqEaOchA82Zi00vk8fAY1qDSko0Kaxp82JKlGACgKZdPi1ADiedH4XyKRQFtbW0Ah01gi2UtzOCkEy5Ehq/GO1btXYpF690xxIpuaMDWhdFaOsuz6cvu32F1t9ODncJ87OjrQ1dWFYrHo6s1bg4QyVosLaQ5sI5SDL4bsQ524p9xzDh9JUw2e5Q4xqkzlvDhH7ivRBzUyeb7UI1YimBJnOzs7A3nViirW+2wNiSGrtclJh0Ihp2AVlgHgguRMOGexDbKTm5qa0NnZiXB4f2ENxiC0KYPPC23Us1hlzMvC+dPq4/flBL4qHV2Dai1ZvSSaK0dhyINkK2hZqFPhS5v2onW6GUIgFM5i8CRmaLWdeqBrJdYQ3iUsmclkAmhKOBxekK9cKpUWGBaLKYHlUsgc+jyshkaFrHBXoVBYkKPOfVNSjy9tpJp1V6WmkK9NpeKaMC2LHW56enqcl0xDenZ2f5Uuev88M0QyNHSgxC81FDm3peathoTCpUpG0y5y2hGJMkafb6nPs79XrdFcyVAZQweGfBR6b6zXoDF6ZYgTfVQkg2ejHn5HufixRZrUyOFQWaPIgn0tpzLmsEpZlbN+XygUAhkHXPN8Pu9kCYBAhgdDZn19fejv70dXVxdisZjX+6/lWRviIXOxae0BcOxRjYlxEUqlUqA1Fin0tALpTbe1tSEej7uqP/QUGkliKPc86ukkEgkUi0WXx8a5qZBS6rt6OhbyVhJGpdaiVcZaOEOJHrawB/eEn8X3UehP90CLENDq4x5o7JbxoEYoZHpavNz0uljfnMKdBpwSMUql0gKPiDDaYl5Zo5SyDT2oB0Mjgt4OX4SIWZBD6wDTo+Y96uzsRFdXl2tnWO26+yBfm36n4Y6mpibH1h8YGMDAwIDLkOBe0FhiTI45x5OTkwHF7QsbVOshK4Jmiw3Zam48v0oyq0TpHwhDzX4uMA+ZU6ky3aatrc3B8aVSKdAYQ1EMIhm8k6znUOudLAdXW69Y11afxSpcfm8/wzosyz30M/TflIPqFDL9T+ty0/Dp7Ox0d4LtZ/v7+11LXV/YoJbRMJY1BSQwXx3KBrvVyqJC5qUiCYowHb0CesjsvasxzOWAlHwecjweR6lUckKIzwtgQVyFViyFs5
JGfPHASj18VaassqTF3plPSusOQOBzuV7qKbFgPD0M7aqkRC5tR+fLJaxljYGg4cNB71GVANeclmw2m3Xv4StEoSVJfXtb63nxxf59wolKmeeIyI9WlGIFO60epZ4q1zyRSDgv2VcbupK5+iBkVZKctxqgvb29WLVqFRKJhEtNVIVMQhoLbxAN4HmuB7K2hoTGvGmQUYhqNSvuv8ohXYNqxoFQFkqAIqeD8o2ZA4VCwT07lQO9ahpsRE/qZS/7PGTdPzpT9m75QnX6e+pRl5vXgVhvDkUKtamNGvdEJChD4/E4+vv7sWbNGtf7u6enx3nIGpKtx1FseAxZoUhraTEGpTABLxZp/CRvKUOvnHfcaKae73k0fSUcDruNsp6mxsAta5OXyKZrVUPQ0bVUj4F5gVqViGut87DEJ90H7fLD5yPSYVMH7PvVevD08nJwTXWfAbj11lxIrgvTLwhh6/5ZQdEIRayQqwodG3fjPDhohGp7PS1TCcB5P9o7Vvvy2vNfybwt/KtxQRWsvE88r1qDmM08ADg26tTUlENLWltbXXU9PrNdDzXOK11zG0NWFEShd1tMg3fXfj4HDSUloi52HxshY/jclpQJzBdPopyZmpoK5CDzWbkHKh8JWStpqpKzsdg8lQioL8skLvc3ltOh6UTKbbHy40AoZb0TSgBdjIBIA7mzsxM9PT1OIScSiQXGUL2obcNY1ur18HuNGzHNxpaH06Lp9BzJSI3H484raxQkUM3zaAxYvXIqBxt/YJoOvTsKVrKxlSluvfylRjnYWluhsVMPE/YVKvdR+inkFKqhd+3zaBoN71mL2iIINv5NT4kGHCE9eqChUGgBImHRiGovjL3AVqBaQWRZ3qwkZhuuaDcuAG6urPmrLeuUQFgr/4BGgg0P6LCkJg23aBUmZhVo3i4VCJUGkQ0rzKsZPkGvylXDZKXSfLMIPpsasJwL149npFwqZSORNzWKdF6LwcJK2OQzc85ET4gg+gpu1OMl++bMefsUsVW+5LNQzjP/3GaglIsxL+dQ5e8zyHyEPs0EoaGsdRtUVtX7DA2t1KXDQqyElyiYrHdAJjbJLFYo2fjIcpIDfPElTdMibEZFTMXIPMH29nYACMR47CZaa3Yp+NEqZIU/Vdizgg4hMCokKmQAzlCisNL4M/NjLfOwETml5YZeDsvS5GHnpc/lcigWi85LptdJMmA4HHYepXaWscgKP7dSpeaD7yxxyQpYzfu1KAahMbJqeSYSiQS6u7sDno/tIVuPp29f6vX4jAnmu9q0IDX0tOiHMlkBBIyYamO2VjlwzxgSINEGgFPK+re8J0rUITRvkQgazSpgGyVjdH0tW1nnSMSQ3ALbz1xJpuR0WLZvrWEkO1/fS/9PHS3KEZ9Mam1tdQ6MdUb0jtODXm7Uk5/B9eQ5YGEi1quORqMBA85ygRpZJU1Hw2pZ+yxAesVsFs5ydsrsVcXBC6KWn88SWY4NU8vaQsR6aawC1ALjhGbY+IBxB8YalKVsLatK5ueLe3A+2rZQCXK+A0NlTDjMVl3iYSR8xktGg6m9vd3BhI1QxqXSwoIN6sGwPCnjaoxh8llouVOxKSPcpqhVo9R850BzMNVgURhVDRg9NyzzqA0caMBRKHR1daG/vx/d3d3etIrlIjLyfNmKTEoi0z3jmVJuhMKRquy5P/Uab/xcImjd3d0oFosuxQyYF7YMEVD28GzRU7Nkxa6uroYrNq6rvbv6UuY415ylQNPptMsPJ/OeypgxTHIMlF9Qq3KoZm/4TKFQyIUrqLDa29uRTCYd8XVmZsadcdt+lnccQMBpWC5ni+eAyphGHVnV6nAxTGkVuc+LbtSoWSGXs5isEsvlcq4nspa044EE4DaFMIzGzjRgzg1rpOWqz1BOAFtlzLq57O88OzvrLG8KVl6cvr4+9PX1OfKFVqyxENNi8ywXR1ZvXVtEUkD51ssX76FBQWvcpkflcjm0tbU5AlWtxR6W2gPOWQtjcN
8p0DhXYD5eTzIgUQnGd2wOb7XQkiITarjoudDzrL+j7SNp1Cjprrm52cUBabB1d3e7GJWelUYYo+Vidlx7ngOF2K0hoOiC9XToUajhVy7uWOl8+ZVeCr3Ezs5OZzySP8CQBcNK+XweY2NjgfBZoVBwXpH1NLVfeb3wYzl0Tdty2hKm2qqW/0ev3nrHPT096OnpccpYc5EbLR/L/R8Nev08Ip10OoD9pWEV7eT/8ffs5ywXfK1Gv+Z/0xFRI45OgCJqPlm11DpVO2pSyL74goVkNM9V26IRUmVMi1Y2lRg3TiFHhasbsVHloDt9Bo3TqrKzKResWMQ50kPjRdfgv3prtQhYJSIoROpL3FfGN/8WwIJYosaFFHLUWOBi6TKN2AcfgYQCViHQ2dnZQK9dAA6J4FlRY85a5LUU1VCFbItTqCfJeVmehPWImE6hcHx7ezu6u7sDL807tnHjas6MCiEAC5SrChvNy6TxaSvjaToj31/Pi0V7qoWpF5u/kp86OjocKkLjjPNjyIVsfIbMFAFS41l79dabzqfPrOdaw0xEefjS9C0lKTKXWuFqRRA1ftwoo63SZ1Mjiw0ZQqEQcrmcy8NnLwNFBtiUgQa06gHOm/D1cnigisJxnh0dHU6+setda2urcxj12csR3Ro1qlbI9qDZ2JlNUSDZSBug0/LQGsbqEZcTpI1SxhaWtl4i4UbOn7E/tWQ1kRyAE7DMK2SlIzJRbaWxWljKChVqjrMvz1mJO6p4lWHL3/EJPF4YXznQRkN5moZluzRpjFaFmnbTYjlKjQdxvX3x42ovucbLbPhCzwQVMM8G+8XqszDflGfAQqdqvPnKlNaijBVx4D5aghjvA8NMfDaFF6PRqDOMLHPZft5i+17t/C25jF4VWemETpXAx/DR1NRUIDbMEBnfU9EqH8+g2mHPteafK8Oe8kRLgBI95FnR2CXvo5ajXEymLPewqIe27aTRwUYq9n4z7MXQlw9tteG8Rj2TT9bFYjG3R3QGicrSMeC51xrtTIOlfGiEbKxKIVtyAgWnflXWMRtas84sC4DQslU6OYWSCqNGpzipJ6wkHR8zkM9h48UUVEpIU3IAq1pRIduUrVoOmhVMtOx4kJRsFAqFnHVHZaBsTcYu9TLp4aRC09gan4PWuObc1eNFKPyuOaYK3SkcTLiaz0UFoalZCon5kJVqBJaFHa1g4fmgwUYvWfMZuT8cVCzW02l031irjPm5SiCzKVRK8mPxD81Hp0Lj7y2Glujn21c18+fctV4yFWpzc7PzIvXnoVDIxV7T6bRbSwrjxeZWj4yx8WIb8qIsHB0dDShkZd4zY4DGBu8lFZh68geijPBSz0vDjGEBvcustqdcFK1FT29ZGfRMf+XL8hfqHTb8QVk6MzPjRWYpUzVtMZvNBsholOuN8JRrUsi6wKp8KaBUmGrfVS5+JBJxm8K4mVaCogLTC1bvsPCjlmCkcFWmsq8AAaFSCiMATsBpPIqwoy01aT3MaoSTwizsd1wqlQLkCEJ55aBS7dKjh4ckGSoEGhSsRkOSET3+jo6OupSyCi6SKLRaDsMbLFiiPW21FZp6RbxY9Bp8zORazpEPtuY6kh/B9n8UqFpcQNOxNM7NCnR86draeuG1CiQ9Nxb50HQl9ZIZFkilUu7u8T1I0CmVSgtKrSqLWmHmWgwhnzKmR0xPtrW1NWAQAPNwfLFYdMaQJVDSONIzqGG3euB1K18sv4NcGraNVIIrSa5qVFMpaDjPlg9upCK27+F7T9/PbJycuoGpcGzBqSmJlFUMJfBeaZooP6+Rz6ehMJ4jdQw01t3S0uL4HuRWpNPpACNcaz1QKdcz15oUsqYy0dLTXFjC07akIZVxLBZDOBx25Ize3l6njOlR1uMdLDV/Xy4vLwvbuRGWZkyHObq8qApR22bcFnrUHOpavR2FruLxOIB5RapFHAidKntT45pUoqVSyaU48T17e3vdXnR3dwfKw/F59MDWyka1hl02mw
00CWBbPRIBU6mUeyaeI4sY8GV72jZCaJUj+/Hs0IAgcqLEM2WNNzU1ubrorIfb29sb4BlYIlc9ytgqNXpak5OTASIQlbJ6AVqUBZivR0+EhcQjHwqgnig/v9o9UEMAmM85Zg60zXHWvWKcmPCoRQL0/JWrRlXr0HNNI59e8NjYGIaHhzE0NISRkZEFoS/Ct9wPZmvYFFCWleVa+gyIRkK8amBZ+WXJmHo3CFdPTU0F5IQy5fv6+lyeOP+WNQY0RGediFqGGosqO3i3p6amFqRM0qCenZ1FNpt1xiGROSKjmgdej1KuGbKmBUgYRluhsR0aPRqFRilsADiKPL0yX+OCRg1LtFCCDhXy8PAw9uzZg6GhIQdBspykMpBptfrIIUpM0zxYe4GA6jxkeoMs6k9IS0kp8XjcCUoiE7FYDBMTE24eANzz870Zw+/u7kZ/f79j+fb29qKvry/A5qTlXiv5xe4DzxEFlp4lkly0iIwqCEUHtMqbTxnrWlY7FiP+2dxLLXZTKBQCqRwseci0FTXcyuWo1yNYrYdMZWaLeugc6Qno3VNYjopXwwn0klmYgwJPY9bVENNUAXBuVMiWSKpeLb+nV6PVq6wSsXupyr3W/HorXzTdTeXj8PAwhoeHA011eF4o0NntTsNJvop7NE513dQTrPb8WAWsJE+NafPfmtLGM0DFSnlDWUNFS2SPZCpgviOToh3MviGEX8/QODWf0xqrlB/6b553KmQ947zPFiWqZ9RM6vIVqFC4MZlMOoICN5c0eGUAK0HJCiELHzXK6vN5PKqYx8bGMDY2FqiRq/FidnoirKIkHZ/wUYHB56iWrKBEBH6vMUHNEeZ+0NLmYII+PWUiFgq500vu7e11RdSZo2nbX9ZDIrEKWWOxrDimpTx17XSNfYrZEnTqRVh8ccZy76tnS4kv1iqvpMd3OaVQybOokLbQmsaRdb24zjMzM8jlcgCC54xKPRwOO89OFbJ+piUaVmu88feUEa7ywL6UG0JPTdOybDqW/o2SO63grmbwb6zRb4l+/DfhWhqa9Lpo+FjPX8M8qvB8BSp8sfJKhypjXT/dTyorhjl03WzWBOeuRYWi0agrB8q7QOOLSCDPFclVtSq7xc4Nn9USB9Wo1xxy3stodH9VPXUUrCFXi8xpSNqTj2XNw6ceDT27tra2QJ7m5OSkU3JWIOnfNYpx54tv0cKjwtLUFh4kWkWh0P6OVLafsDLMp6amAhAZN5Gbr8+01POoEuDFUAGvrFHNb9Y0FlWkWlKSh5/kOsLVhFVJ5moE65dDLWoaQ75iK4SLOEems9CQUMNAn6cRc7TrrwJJjR+iIlq+MxwOu0tqYTcOqwz0Zc+NzwCoVCkvJnBUKfNsqqGkz06FRqE0OTnpUCSmE/F3+P5qdNSjlK13a7+qEgDgBL56eD7DyQrReuLH+r46L32/crAy113fQ6Fv8hXU6ORd0PVVo04N1krlpu+s8JzoWe/s7AwgbJyLnn91TkKhkHs+zcyhYZdOpwNhSn5ePB6vO4RgdZU1xmwKE/dDn4Fz5h1nWKGzs9OFGvhM9bKta1LIVz9wNW557BY8MfEEmkJN2BLdgheFXrRA+FHoclFY0UXZy8lk0sESjLt1dHS43F4eMo56Y2qfvPuTuOl3N+Gx5GNojbTimO5j8PYNb1+QXgUgkNql1nNTU5O7KOp1kIxARujMzIwT3JYVyVc1z7T1M1uxY2LHgp+//ai344pTrwhcxFJpf69aG+fVGArXl5A1yVxUyLZwfT1x8EKxgMvvuBxfeegrGMwOYk1sDV57yGvxprVvcpdS2crs/sWYKo0lGkEAApfevqzBVcugcLnyvitx8+9uxvbx7WgON+PIziNx4doL0dPT4y4hITgtEUulrLElxqq0ShONUp4HKhT7LBQStTzD5x/8PK68+0oM5YZwaOJQ/On6Pw146swJZZ60wnAcc3Nzrte5Nspg6h+fkyEdrSlQSy36u3bchU/8/BO4f+
/92Jfdh2+//tt45bZXuv9XRafer4VcF/ssFdD1wo16L655+Bp8/8nv46n0U2gKNWFzdDPObD4zkMlAw1I5Npw3iWmESW2ta8Y7bZqfIh70oIGlS1Ja4210ehSX33s57th9ByYLk1jVtAqvib8GPT09yOfzzsGIRCIOKaEyojLns5JMR8XMteZdYKydsoVQMEM/1ezNNfdeg2vuuwZPp54GABzZfyQu+6PL8OKDXrzAa9fUNE215JnX8ACdB3r4qpAXKydcreypWiGHQiH8Yu8v8LZtb8PG8EaMjo/iC09+AV+a/RIubr14QXoPN4OQbz6fRzQaRSqVcoqqWCw64dTV1eWsDTZ5p8BTBVaLkA2FQvjprp/inSe8E8f0HoPcZA5/94u/w/t//X5cefCVC6z5pqYmV4kGCFan4UHizwE4YcPnIdlqenp6QXMJ/g1hmqUOXCgUwj3vuAdzxfm8t4eHH8b53zgfr9n2GudBUrnPzMy4mBMQrJpEr5OHXyuLkSFOhWzr+9a69v/ws3/ANfddg+tecR229WzDL3f+Eu/6z3ehkCvg5OLJjhCYSqUC1Yna29udBc41t2Xt1IsrF/6oRTHzb362+2f40+P/FEd0HYHcZA7/cO8/4MPbP4wvHP8FZ4h1d3c7Ag+fI51OB/IzKWRzuZzbK3qSFJ5zc3OBM2gREGBeUFQyfwC48Tc34gM//gCuPutqHN1zND5z72dw6W8uxUdXf9QZkcwJpVLWuB+NIaZChULzaSDarpOlEkl01BKm5TINFhu5mRyOXXUsLj7uYrz6P169YA/13vjCCvZ3yw29e/V6yPzce4buwZu2vgkbIhuQmkjhup3X4cszX8Y74u9A11SXk2maEcH1VxkCwLWq5SuTyQTQK61wqHJG48iVyphwOIzMXAYvv+XleO6a5+K6s65D82wzHtr7ECITEbffhK1pZAFwKInyajo6OpwO0LADFfTc3Jxr48nB7BuuSTUx/fWd6/Hxsz6OQ3oOQaFYwJcf/DJec+Nr8NO3/hQHJw4O8AVUIStx11Y5JCxPR4DhkK6urgCxUTkItToCVSlkfsg3X/lNR+bqRz/+YsNf4B2PvQPjreNuI5iWwEWkp0m4K5vNIplMuliCHkoOhRbchI1XWQ38FQqF8IM3/yBAyrn6jKtx3NeOw57SHsTa5wt4tLW1BSwfa4UzPYTz5JzUutU+vQpp6O8Tvq5kDHQMBGCwT/7yk9jStQVnbj4zAJXw87lWCs0QKqdgpHVt86c1/axaz8Y3fr7r53jFYa/A+Yecj7m5OQwcPICvrf0aHhp/CEc1HxXoEVwsFp2SisfjTpnl83mnDOj9ax5moz1kYP8eff9N3w+EZD6b+CyO+coxGImOYNu6bYjH4+ju7nZxcK2JrDwKAK4aEDBP2KHBQ0E1Ozsb+LkaoxSa1SiOq+6+Cn983B/jouMuwtTUFD52+sdw+67b8YupX+DI9iNdqpy2oFNyFMM4Gg5RmDgS2V/ghEYUFbKvG5ESyJYa5x5yLs495NxFf4cC0ypfC+/7lLQPXlZZU+sIh8P49/P/3YVg0tE0/m/o/+KPf/fHyHRk0D3XjVBof9iLPBUr2IkoUs40Nze7rmxUytovXomE9p4rbL3YOvJsXXXvVdjQuQGfO/dzzhgYaB7A4OAg9hb2olgsunPJ0BgNNGUek+Da1NTk7gDlOQ3VQqHgUjH5+TyPNjZbyXjZYS8DML+nHznjI7j2gWvxy92/xPqW9QuK9GgdChJI1UvmHKmj6Cm3tLS4zBULWdtQRTXypyYP2ZJp5qL7LYqeth4gBof/01LQMnAAAnl6ExMT7lLxEKnVZR/MZ/VWo5T1YJZKJUyF9pdH64v1IZQPOWuTl4OXnbAR5wjM9ypVmIbxXW60xiV0voxNVMsgdIK9MIOvPfw1/OUpf+kIOYRhbKxO0zo4f8bCtUWkr+1fI5QxADx3w3Px+fs/j8fGHsOWxBY8NPIQ7h2+F+/a/C7MTc
wF4vCl0nzhGHqQXCsqYnoSfNn4cT2EM11r7pWe+7nc/vO+oW8D+tr6nPDJZDKOfQ/AwXq8yNwLluSj8GJFI1XIbFqvaEq1BhwAzBZncf+++/GB0z8w75U3NeOU/lPwRPoJnNx+sosN6h7wTlrio43Na5hHCWvqeSsZsFFxfbtP1hu0sLUaoQr72/gif1aLl2xhX42ll5r3v19/Rz/a0e7uIAU7vyovhcgEgACszTNCQh25NopM8u5oDHap5+Lcv//493HW5rNw8X9ejJ/s/AlWta/C6w96Pc5InLHAadIiMdxb5SbwHlMPKAFNwwV8H61sV60ytmOuMIdvPPwN5GZzOLbvWOckaSlgJckxfKSpfAqv8/cBBIou+ZRxraNmhexYmy3NuG7fddjWvg2Hdh2KoekhTE3N9+Nta2sLFEqgQlDYi2kWVMaEtrhRGnAvlUreEnfVesrhcBjhSBh/8/O/wUkDJ+HI/iOxI7/DFSpRBnJLS0sgeK9Kmp/LDVNqPJ+VytrHDK5lE0OhEG559BakplK48NgLF/y/HnItgEIDgfNg3JMvzb9rlGfM8cE/+iDS02kc8S9HIBKOoFAs4H0nvA9nd52NR9OPus/gvmuYgJfAIhE+ZWyVciPiyPzMUCgEhIAP/eRDOG3daThxw4lOyFAAKSuTa82YOI1SKuWmpiak0+mAEcfiBFqST4ld9KQrPS+j+VEUSgWs6lgVgPYH2gfw+PjjiMVjLqyiHgMVAo0EGo6EHtVzV4jSFmexDPJGK2PfflnFaEl/3EuNIyqxpx5yl1XILS0tmJmdwRd2fwGHxw7H4b2HI92UDpApSYDVojL8NzCflUIlrDwRKnYaher91fIsoVAIT44/ic+Pfx7vOeU9eN8p78Pdu+7GpXdditLRJZzaeapTyEQKqcAo9zQnF/D3Vrd7pcaTD82oZjw09BBO++JpmJqbQkdzB7587pexqW2Ty/sms13r85P8yiItGhvWkKtdq0af55oga7WGP/W7T2HH5A586qhPoWlyvoYyg95K81evkxcagLO+rTLW2sYUDhpT5kWrlujFZ3jvre/Fb8d+i5tffjNapltcdyaWb+vo6HAQESE9ClU1EihseGny+bz7LFW+1npl/LyW8aUHv4RzDz4Xa+Nr3Vws050WoVp9Cje1tra6gia+Ep+Ngn0B4D8e+Q989aGv4iuv+gq29WzDfbvvwwfu+ABaDm7BYa2HOYiTFifjTszRpHes7GVllSsD1ccGr2cosvGeH74Hvxn9DX781h87lEGNAyownleS5rQBBS10Nj8A5vN/s9ksEolEoMQpMO8d14KqAEAIwbvL+XZ2dro2qMB8YQsKIBoSFKZ8DyUDsigNC5xozHg5Kkp5n28RRazng94a4VNL6LEkomoNZauQW1tb8bf3/S12TO7AZ47/DGKFmPu5kvsUttYKh9oUwxr/VCIKV9ebugUAxVIRJ609CX//wr/H3Nwcju47Go+MPIKbdt2Ec046J2Awcy5Ec8gn4P0lK5+ynF6lEmRtPQWiKop0VTMO6zsMD/zpA0jmkviPR/4D777t3fjS876EnkJPoNugLf2sJT+1Nr0yyjXMZGW674xXO/eqFTIPW3NzM95/x/vxk6Gf4IYX3YCecA9yuZyLH/X29gbKxmUymUCRDd0UKl/G1iiIM5mMS8VRj4EvCkSgOvZ1KBTCn//wz/GDx3+AW998K1a3rkY6nXas2ZaWFnR1dQXiNVRqtp6yMlEJa7MEnrIOldBCY8Z6QZWOHakduO3J2/DN133TWflq5TG2QSXAnF6mYjE2y7Vl6UZVyI1UxgDw/936/+GDp38Qbzr6TZibm8OhXYfi8dHH8ZVHv4Irt17p1jscDrt+znNz+/sf22dkPEq9EM3FtpWuGvEMoVAIf/6DP8f3t38fd1xwBzZ1bXKGgbXu+ftNTftrVpNBrkabxmsnJycxMTHhUs/YlEIVshogihYtNfra+xAJRTCUGwoo5NRsCg
OxAXfmNSbIEY1GA56CEgKpjNnrW4vK8CzxzB8Izxjw589q3jV5IfQsidBRKfJs8Vnr8ZDJe7jsp5fhJ0M/wVde/BUMNA0EOh7x82zvbNuURxn7fH+eI75srn49535NfA2O6D/C3bFoNIrD+w7H9578HhKJhHs+Dc1EIhHHuOc8bC19yk+ebZ45rXRIQimNumpT5QCgKdyErd1bsbFjIw6NH4q7d92N6357HS7qvchVSFNHhQYP90ErTKqhQ+OSIR6tdudLtaxl7WuGrN9/x/vxn0//J2561U3YGNvo6pCSZKC5yExxImRge8oqxMeNY5GOrq4u15DCbowKQ6AyAVUqlfAXP/wL3Py7m/Hjt/0YW7q2YHZ2FrFYzC16PB4PzN16+ZqnbK1rjfEwzklFYSG+Wqu7XPfgdRiIDeC8g89z3jovNhUxDQltfM5KTS0tLa5kKYt/0LNhDfFGezT52TzCofnwQjgcRlO0CSWUXJUwkkJI7tI8cI2B0yDkheb6WqhULwk/t5bBM3PT727CHRfegS09W9z/+WKVVHqs1xuLxVyxB54rnhemfESjUce94HlXkhqNOC20UMlojjTjxLUn4sdP/RivOOwVLlTz88Gf480Hvxnd3fMEI0LnPFNUvgpnA3DlHHt6ely/b56j/v5+JBKJZenPW26oEaTescZwKTy1iInKGzW2azGS+fk6h0t/cil+tONHuPnVN2NTx6ZA7JLygbLCFhEZHx93ZDgiLJqWo0iMztXCwPqzSsfpG07Ho2OPBgy4pzJPYWPnRiQSiQDDWh2i1tbWwL1VL9QSBjXsEYvF0N3d7c6Qth+1qZaVDqKps7OzmCvMIV/IY3R0FIODg0gmk8540HAq94JfdU15lkiA1drivsZBtZ73miDr9/zXe/D133wdN776RvR39iNbyKIQKaAtPt9Am4eL7Fk+gNaITqfTjt6v6RU8lG1tbc4bYoxKY5uMCVUD4b37B+/G1x76Gm554y3obO3EyOQICoUC2iL762qzxqqW1VTYWr9SWWvcRzcWgGNSKrlFU0yqjSEXS0Vc9+B1uOCYCxANR52i4kW3Xas0PqXxHXpjZFXz8musr5HjZYe+DB/9yUexoXMDDu87HPftuQ/X/upavGbLa1yKDI0zsqjpIefzeacMqJjozamRo5WnbMyynufRMxNviWMoN4RSqYRESwItkRavUlZjgZkEmUwmEC8mfE3Drbm52YU76IHSyPAxOSsdl5x6CS68+UKcuPZEnLj6RFz9y6uRn8vjrUe+Fa2F1oBRqwpKBactjsO89e7ubqeQ6S0zd7xe4ZSdyeLx5OPu+6fGn8KDgw+ip60HGxMbA79bCWStZXuB+RioKoxaUSudRzgcxl/e+pf4xm+/gRtffSP64n3IFDMohovOu+I+2nAT08uYxqh8FTVQLVqopLZ60a33nvpePPdLz8XHfvoxvO6I1+Hu3Xfjul9fh8+c/RlnXBL6p+yhkcPYPBVxNpsNxGu5plo+mUqOxYjYsraW8ryX3nYpXnLwS7A2thbJbBJf+dVXcO/Ivfjw5g9jYmS+njgdFa02pwx1RYO0foQWBPJloTwjHvLn7v8cAODsfz878H+fO/dzePPhbw4cLioAa60RLiJxQVNzSJShACDxgQunuWTVQkvX3HcNAOCML58R+PkXXvoFvPmIN7v4zszMjLPyeZGVpEILm5/PeZdK86XkwuHwAgtRGc+1XPrbnrwNOyd24uLjL16QrqHWtlYO4894WTQ1wfYObpQSs+Mz534GH/7vD+PdP3w3hnPDWNOxBhcdfRHeue2dSI+n3VyIPBDepXFBNjsFPTAfVy1H6lLSXz2j3Jm57hXX4cJjLwzEVn0eChEYChhgXhlQwAJwnIpYLBYo22pjm9XGCN9w1Bswkh/B39zxNxjMDuLYgWNx02tuwrqudc7YYSxQY3itra2BkAzXkv+ndei1MYw9S0Bt6MR9e+/DmV8+031/yX9dAgC48NgLcf
0rr1/w+0spZb6UwKVeqiWQVrPGOgcAuPaBawEslJHXnnst3nrUW92dVflBh8TmrGs6lD0LPuPBdwarGSevOxk3veEmXHr7pfjInR/B5u7N+NSLP4W3Hfs2h5wA8ylLHR0dLsxClA6Yz6axhFhFrYhkWFKgVcaVPstwbhgX3nwh9mX3obO5E9u6t+Hqk69Gf7of2/PbHfrGOvl2ffUu00gmp0lJxz7jvxGEtNpKZ/5NUCjovwn7auoSBY+2L1RIkYpClaxNmdIL4/MsK704pb9Z+Huq1PQS00Ll3CxFv1ydXv6eha/1wtdKkT9769lu/dUg0c/jGqn3TKKUWtw8XD4Lr9Ej3hLHVS+5Cleec2WAaZxOp5GP5hcQb4BgvWsKqqampgVrZuFiXzy3nuE7M75BJVAqlVyslzC1jWlrqEENTBqDVkHUS9j58+f8Od598rsD8XjmqrPNnBqctqqcFVa+GK0tw1mvcDrjoDMqXnsdFsL2nQuVOdbTrDd1JRQKLZCR/FruZzwrZLLb8Ivuhw47TxvSs/9X6V689NCX4qWHvjTwOUrEKhQKAcVkm7rQ6VLmvipzhYP5XGo02SyJSscXX/FFd1/IzdizZw8eTz4eYFQrmU4VMoCAMraxecrOcvUO6nVkKlLIXLx0Ol32//hvVQ42QG7rPlvPUa0oPqQSHbLZrCvFBswnbBNqtYU6lpq7fQ61OJUkpTEmKgc1FHT+SgrRZ1QmJSFs28C7mrlbwcowAL/awv+awmLhbc6LVZhmZ2cXeDiLDc6vkrlznbm+yj7Wc6JrqZfZJvTTOOK6tra2ApivVLYUsauauS821HCjETEzMxN4NvtcajBR4SmxRBsS2AwE5q1WM3edo7LxtXynNYAVwuUe6Pe+UAnv59zcXNlz1Kh1t0ORB4aUlEClPA96acpYVhIqY+u2mImd+1Lz9j2jfs8Xz7LdCytnOH+F3JWzw/nyPRVWzWQyVc3dzlUdEj3b5c63vjQrRc8PPWkrj2h8szIWSb9Lzd0a/JSHWhXNpj2pQqaxSRjeOlKaB86wZmtrK0qlkkPwKjnv5RZ5ybFr164SgGfNa9euXStzX5n7ytyfBa8/hLk/2+a9Mvdnfu7lRqi0pMreb3Hs3bsX8Xh8WdmS9Y5SqYRMJoO1a9c663dl7ss/Vub+zIyVuT8zw8792TJvYGXuz9TwnXffqEghr4yVsTJWxspYGStjeUdFMeRniyXyh2R1AytzPxBjZe7PzPhDmvuzZd7AytyfqVGph7wSQ/49e63MfWXuK3N/drz+EGKZK3N/ZuZeblTkIcfjcQDArl270NnZ6X5eKs13wcjlchgbG8PevXuxe/duDA4OYmxsDKlUyrUhYwlKrYLC1CdWaxoYGMDatWuxevXqQFk+tgRkjpqvmlQ6ncaGDRvcfMvNvSRMcLLl2E5yZGTE9bNNpVJIJpMYGxtz1V2UrcdE9+bmZiQSCaxatQrr1q3DmjVr0NfX5yoX8d+sPuMrvlHp3IF5FqHS+p944gk8+uijeOqppzAyMhKo2cpCAqXSfEs2TXKPx+Po7+/HqlWrXJGH7u5u9Pb2uu87OzsDhUM0HWGxue/cudM1SyDDkdWpkskk9u7diyeffBLbt2/Hzp07kUwmXb9jsh41fYUpEqzOxfKfnCeLC/T09KC/vx89PT0uP9amj4TDYWQymYrX3Tf0LJF5mkqlMDIygtHRUYyNjWF0dBTDw8PYs2cPhoaGHIM3kUhg48aN2Lp1KzZt2oRVq1ahu7t7QU5vuWIt1ZwZ35x5jpgFkEqlMDQ0hH379mFwcBCpVMrlKWuREJbGXb16NTZs2IBVq1a5cpnM0V/KY6l17srwVob9+Pg4BgcHsWvXLleNiedIc0g7OjrQ19eH9evXY+PGjVizZg26u7vd3CtJs7FzX2zelpVMmTI+Pu7uwMjICIaHh10v8I6ODqxZswZbt27FoYceivXr16Onpycwx1pHNXPXoWeG52ViYgKDg4
PYs2cPBgcHMT4+7sre8s6yTnVPTw9Wr16NNWvWYGBgAIlEYkETinrnbs81ZXsmk8H4+DhGR0fduqfTafdi8SHLxAfg5k89pDJG5aNPti8l332jIoXMN2YBAN0kKmSm7zA9iVWX9O9ZkEJr4wLzSrlk0jK0mhXzwrSIRbmKNL5/c+4qQClEAbg6zwC8Je5Kpfmeq/YzNVFeqfysLhWJRFyNb16qcuXgFps7B+fP1nzZbNb9TiKRCFRZYr4xBZMWTmBqhF1zTd3iurOQiKYS+UqZ+ubOLjTMVdQ8xWw26/q6dnZ2ugIg2q6Tg5+n+X9WQOtlohDWspo2V7aadbeDa0cBwDQNpmzo+nDtOS8KK5ah1AIbiURiQbOPxRRFNXPX809lXPrf1BKFjvk7fEatG85KRbpvLJdZrcKodu6a28+Snvpcerd5Z5nbqn2aWWmMlerYNrOavFd7/33ysfS/KX48G+zyxTugOa/ca23aoeeiEQq50rnrUBnB5hE857ZeA1+2Wp32SOYzVauQF5s72/3quQYQKIhkC6sAcPNk5TDKFy3Ta+tR2DK+WmRJexYsJd99o6bCIPYD+BBapzkWi7nykfx/tjFkvpnmx2oRhVQq5RLk+f/0itQC0fzASi6RvbSaf8hSnuPj4wELdmJiwnnEnAsNiXA47DZGPW01FlgWjp6qKiQexGrjH/rM6u1S2LDDCoWoWn5aAIEGSTQaDbQB5P9rWz3Om4MFCiq9TLagBM+KbWY/N7e/05Z2GNLLoOtPwUwFyIsHwJUotSX4OIdGVfHyKTiiAOwak81mXb4mBQELC/i6VTW60pjOlfPVPHvm3hIV4iudTjuDVdfQtrdcrmIydu56b3Xe6XTaoVhE5Fj7mbnElD9ahcwq4HqLOpSbt64584W1HC+NClZqW64GKfXMX50NmyvPmvkTExPuvPBZfIV8lmOOinhqPjPPNZGqsbExpNPpQO60GpwAnDKnHNScZu4jC2CxFKrPuKjlXtStkPnB9EZoBdHipoXEnzPpXQsfUFHMzMy4h2YdbHZOooWrLdSAeUG/1KarYFcLm8KTynhkZMTBL7ZoBTeCn8sX35PQGb9nAQfCqh0dHc7CrLeqiyo4Chp+Tqk03/aPRR+01ZgPmqGi5kXj/hHW1oR3++xLrb165eppKWTe3d2NyclJV7OX60clrOEF7SvMGt08P9qbVeeubdL4qkdQWONOlXE+n3cdnhj+0DZ6QNAy125VtjFGowWxwtTaam5iYn+dX8Kno6OjyGQymJ2ddd4DhY5tP3cgFIYa7az3rMbD2NgYBgcH3dwJV2ulKxqXsVgsIEuW06BQhaZQL1+s9UwDOBqNuk5Cem4baZhVO3efslNjiOd8fHzcyY22tjYnP3jPyq2xoo/1zJNrrI2BaKwxFDM2NubOtULrdDY4D+4J0SNrCM7OzrrwjYbEbNWuap+tLoWsH6z9P2OxGAqFQkBZxGIxpxy4WBSKWlJTKzjxsIbDYdeAIBaLBWCBahohqBC1FW58lhRrKNM6VAjeCnQqZHrLtLLYFUdjFOpB16MUuO70BtnBiRBYNpsN1AFnBRzOlwJODRQaS4TE2evWwjHVlrXTGLB2TuGBZrenaDQaKGWn8aDJyUl3yBkS4fy5l2z/1tLS4qAxeskUzL76v9UMK6i00pCiLaowNIyjhhSVG722coquEQrDJ7jY5i+VSmF0dNSd/bGxMWSzWZRKJaccCP/ScPCVjV2OYUMDNNaTySRGR0cdz2NoaAjDw8MYHx9HPp8PIAxsqKJdeiy8uFxDPXsqC3rybKzDtdXQlq++/IEaFpmy6A+rmdGYIzpB+RYK7W8IRIXkU1T8nEY8lxprrNLGs00Ox+DgIEZHR925JrKoJT8ZqqN80JAlzxQ7A7a2tjq+yuTkZIDvwRBstaMhHjIXnLByLBZzMIDtR8q6vqHQfMebqakp51USRqUnFIlEXBN1K6jVO6vEU+PvlSv9pyU6+Xkaj6Qw18
PFDeTmUdnT4mV3K517oVCo20OzHqdC48B84+9sNuu6CNEAoKWnys6W6ItGo+js7PQSHqpRaHoJOV+to815M36sCpnrxf3hevMZGUKgMUHPmc9Ng8pXG1rPQzVjKWWsncBIGiFkTYuc3rvt4cxYZyPauPnmDSBw7i06RMVMoksul3PGZ7FYDOwbDYdyPI7lGCp0afRQGY+NjQUIOzxLbW1tTtlpMxXLZ1iu+VulRnIXHQ5tLGIbvvj4AwdCKVtlrOdFQwX6oiFPBEhjtOXqUjcKyrYGm7ahVcWcTCYd2Q+AK8Os55hz5zyVs0DDifKxq6vLKxutXqpmNCyGrBBcW1ubeyhbOJxwLVtzaVxSYcpisejavmkt1EoVb63PAizsImRJRFQShCB9NWaj0WgALrYKQZ+lHthakQmGCpTIwrVXIhUFrRJggHlSWiQSCXQb0nnXs/7WS6YAYoNyGipaM5wXjKRBbVpCZQzA7ZH+na51I4ZPGauQYixN+1BzHUkuUfg0FosF4FOf19lIAeyLZyq0RwiV55aCB0DgnKn35oPplmP44t8KWzOLgwKSzVQABOZN71N77S4XHKxCWUMF2tiA9437v1THo+Ueus52zhriUGOTtaHVedFMCNtNzn5OI4yNxaB1ohI0JCYnJ50sUvKqGgzqJVMG0dgIhUILOg9avkuto+EeMok6GijXNBN6xlq0nZYYh6a3aAcQSyaxh7SSDeVc1btUFjGhcWA/Q49xcO16Q8+eHpFtOEFFwcOs7N9GKDbrcVLAs8+qkoV07Uul/c0L9KID8+QFrr21+PQzaxUMds5qvMXjcecpkPinDT3YUIH/1tg1DRrdVz6beuOWhFTt0AuqHgPjxbTElRDIND+eCZ5jnjFfy8JyDPZGDTt/VcgUWtPT0wEPx0KpVqlZxvpyDZ27empWGdu0G54xsqsJwR+IGDLnraECKjjGKfUOkyTk67W73MN6mhrK8iEpDMXw7vKOaVtO9jdmb28r87n2tRJcdd7WWLaNK/SrjW/rsMihOooKtwPlO83pOGAxZP1AVQ6a96deJh+GrfTIUlbP0l4ixnwotBTvrxZu8nnzc3NzAWIRY6ydnZ3OY+PnqmJQL4gWGLuR0Guj8tUYtM9DrkUxa+yeMWQecK6fXS9gf3qXMkyBhVaxzkc/p1zLsUrmyq9qCCkkBsBB63wRDuOaWhYjlTHfj4qDL9vCzqfwKpm/DXMoLKZKmHFYxjEJ+9JIotHEPMaBgQGXw6hnbLlihlYxaLybSllhR01zYppQV1fXgrSsAxGHVWHLudMI0mwIDQuokU0uBOff0dGxwANdrnlbJacdtTiam5sDhhoVmIXVl3PoPDWUx3M+NjaG4eFhjI2NYWJiwoU3ZmZmXIy+paXF1QXo7+9Hf3+/U8qxWMwxmXnv9ZnqUco6f9UpfBZlRyuRVUM5RBC1e5+is4r4aj9k6zTWE8ZpCGQNzHu1/Jn1brnYhUJhQZqKbc1FZUzBxQuksR+1HKtVxkpxB+ahTxK4WlpaXFxblRthU5IyKIQnJibc+1l4VJWOKuN6Yw1ccypkKidlk1oLu1gsurQyCzXy8/UQ0ZP1oRS1CDG1hO3PGeMj6pBOp915KhQKrnevGna6BvQ86WHQy6Anp3HaagSctb4Vwstms46ZPzIy4lJvRkZGkEwmkclkXOybHmZvby96e3td4ZLu7m6nkC2zmntqDaRah4Ui+QxUaOQ60BAlW5aKQhWakuU0zrmccVifd0/vmJ49jR/N+OC8de4HEhK2665omhL8NP1PmbsHmjhnESCecxKjWHiIPBumADEsQGW8atUq9Pf3uxxqokCU+fxcjkbsgXUuKGcXQyX5M80xVm+aclvDNkrCtAZ/PfvVUA9Zv6cwVwHKB1WrTw8BlTrjobRqKVz1ElnFUA1crd8DQYiiVCq5JvEUpPxssvDS6fSCAh8kSzG24PM+GwVZ69yVrk/iGQ+JXePp6emAcaHrpzA+lTDXWJ
Wxb92r9ZJ1vmpUtLa2YmpqyjF6NUZb7pDzmTVGSIVMNq0q41rSSKyHo96lpguRWMRcR8ZhuX7t7e3o7u521X1sBTTdKz6bQvL1xNrU8NM4m42taTEbem1cT8KQ8Xg8EPs+EB4y526JO9wH3jsAC+KxPkW33Axma2j7oOCZmRl31qPRqEMi7BwPFKnLZ3Qybkxi1MjICIaGhpxnTL4MFRWNn+7ubnR3d7t0z5aWFnfnLFeIz8X/r+cs2dCaOozliGXqNasy5h6VSiUnNyhrKFcUhbR8o2fEQ9aFUMhBJ6VCQCevC0K4mmQATVFQuNpaINU+tC9eofOjtTc3N+fiDPQEIpGIMygIvTMmNDk5uaAAhT6f9Y4bpZTtmuuh0MvFlAtecvWcKfj5c5sXu5gyrna+wEJFo5wBpgsw1u1DQtSbp0Km4KXCIKJiIetqjAlfXIpkEXoOyuAkYkJPk8qNAkth366uLucZMxSioY5ysSkbc6t2WOiXgpcKTeFqCh6urcaOua4HwnvTuSviRKNCBScNTUKKGr7wGWcHCq7WM6QwKKFQ5bHwtZz56IvNVdOHfPnGzKunMtbsGptaxrUOhUIB2BiA+zvKAtUXQG0Fk6w8IUfIx0PinHiWAAQyMmhw8D0tMdCGUeuVj0ADFTIXxE6G1oXF2rko9u+VcaiXyFfBqNaHVoGmBKP29nZ3yGweNS1Zsu1sKUZfBS6reBsJV5dbc/X6leXI2LFCjHzxe1Vs6l2qIVTvgeOcFZLVNSPSoIKIP7cGjHrHDHHQKqdipmDzKeNKBveNwp98AbJMGbLQClGE8ej5EPGxMUINg7AYiiJL+tKCJrWmzOkaWk+NQggIxo5VCZdTaPXG/mp5BjVyyXIF5tngPBcWUmyU4Kx0nr67b5EydUSs13Ug48c2JECjU2s+80WuDQuAKClNCbs0/AC491VDlTUrVI4B1cWTrSJWxGxycnLBGdZ6EsViMVDrn2Q7har1PbWUqcb5G0XGbKhCBhZ6QdZaKRf85r81JUYvVK1ErqXmqpvIi828SxX4jH3Qg/alROm8LFTle9XrIfvWnJ+p7HGNd9g94D4wdq8eHOOElnDUKAjNztuupfVOKcioyCnE2JiEjTzYzMPGCi1cvdTcLYSnaR/qLWj+K1nVAJwioCJWQ4EFbkqlkvNE7B3gHtHg412wSEG1w5K7NLWNCo2ejsbjrTdwoJWxb/C+US6oYrOVz3TOy62M+dVH6tQUST6DCn3Ld6gVDal2WO+YULU2CNIaDZoSZ9ObaFhrPQeec7LLW1paXLojIWXdj0rPFuW4Elz1TJOEy0IsAFxOMeteMH2JITI6XmpUJxIJ13Snt7cXiUTCyUeLZNR6rhqukIGgoLVxSbVULWvXp4yXy1pUhaIWKgDngXAz+NmaM7fUsJ6xj2ndCGVc7nmsp1AObSAZjSU+6WGyq4lagRbmbvQ+WGRF4ST1hLg3FARkyA8MDGBgYMB1YVE2bS0esipkJUARomZVK8LVLMnHudHLpIFjGcoAXLhD8xyth8f9i8ViXmivkufgVwuhKoxKyLccwahcCEE/Q/d0uYZPCTPEpEVLtPhHOdnBs9To4fOKNTee661yRo0u9br0Pe06NHq+WniFZ50vm3fMoTKGsWLWO6DypVKk0VoqlRwBTB2IWvgpwDypuLW1NYA8hMNhl6bKyo+hUMgVRqKhQNSI4RvOm8Yn7zGNfu0k56uLXutYFoXMYT1fa/lbAc9/c3PqSVVZaqh3apUY468q/HwCrBxrupyF7Pu7RnnJHD74Ri+6rQhFhczSm7QCqYwXE8SNFAhWuKtAs//m7/OsxGKxQDu03t5ex162sbhaFTIJULbqD0sGsjUnL7Gm7pHZq3mw5CmwfjtzlgE4hjBf7e3tAWude1fNuVksnkm4mspBBax6xr4QgpJy7Lo2WmHYofJFQ0+2yIrlq9h5N3K+Pqha77uP1Kmyx8LqFg1RA8IaE/XM36ImmqOuRW60CI
gaxhpOAeA81FJpfwaL3p1cLodSqYR4PI5CoeCQOYW7ySWpJJ6s8s6uaSgUCnw+s2cYflQOguZdE0YPhULOQE0kEq6lbm9vL7q6uhY0r/m9VcgaI7RKWZWzT0n7IOBGX25fbMdeGvs33FiyOzXewJgDD7UlSWisTpPN+TmNhK7VwFAhZePdevGtEFZiiYWqG70felb4vX6G/TnnoV6FsoEVYrXCrZq5q0DVwvKMqTHlhjE17bKlebBKAFGkh5Z7JpNBMplEPp93IRGN5WutY8bA7Plc7Bl8XrE1EC36wPW1WRJap50Qt49w1igouxzXwnc2eOb1pXNXAwRAWQNtuRRzuZc1LGzIRpE5ey/K3ZV65mxllq006PPUlTej7VttVTUq5FBof7XGtrY253VrWWSGDCsZipjYdSkW99eNoKOh5ZDpFROd4rx9sX0bziMp02bb1LsHy+4h+y6KesD02jT3EUBAeGjZzHqHWskKHynjlIrTXihWCNK4Ia1HKmdlEerBpjVGcpUtudbIYY0gGjyqkNVTpqe5lOJt9DzLzV0tXp27Kle1SG1YRM9avciKCnMldTH3lYrYlmu051xZnTxf2tTEKmSeJZ5DxvirOTfWS7PVl9RIVAGqLFOb/kISmk0jU/KZfuWe1jv0WSq9M7yzWmuc8c1CoeANwTTCkLAhgkpe9jlV8VlDxGf8WKOonmFltoa3+D2NRCCYKgTAIYyMz2rOOOO4vN/K7q8nlMe5WKLw3Nycy26gQlYDp1gsugZG+uxKVLN9nEnosuhhIxyVZVXIgL9kosaJGRujRcJYrdZ81aB7PYrZB9nxwjLGQCFLAau/T5ZgJpMJlJCzrdR4WFUYMN+TXqcq5eWIJ6unRk9SU5pIHqG3wAPJ51TBbGE3Evb4WY2ct8J31sPUNArLaNf99Hl9ui6VDmuM8ZwQIdG2nPw8C6Gq5UySC7BfEORyOdccgUXvi8UimpubAw02ADhoXtmflSol9Wy1cYuyq/kKh8OuAYwVMsVi0bFWbYaBGk5MZQFQl1KuxLvkHvvitZrul8lkHIludnY2kA6nBoXGHmudN+euz8DvbegFmK9apfPm3/J5OHwogIWLa1HK9u6Rpcz6+DRo4vG4k880bigTaUzymXh+tQuati60oTu7bvxa6bNYpVwqlRyPI5FIuNrh1EGh0DwLnE4UZRvXwebf87VcHcOW3UNW2NoSjRi3ZPcNLgQD7RTK7GJUjwLTi6E5jPR6CBuyLBxrVCsjkgeMylW9JC3Kzy5V3Gz+fjqddkKAisYHAdW75kCwkYOmkWmhk/b2dufRk4ihbRqpcKwFSwuz0cpYPWNeJNa5ZjMSWqWMBQEIFOvI5/OBpH0bs6t2ztZD1oL1tnGIXmoVlJwjO24xxSmTybhCC+ohkzmu8DBJdxruWGre3Cu7PvyqTTz4HFwjKi49D9lsNpBjalOhNE/T7im/r3b4lLIlR/Ir10sNDyIQWjvA5lSTY2DTMeuFse29tt+rMtKay5R1amDyzvnCfTSKSqWSU4i1rLkSo7Thixq9lFe8f/w38+/pHVt+Ap9NM1mst1/rOutQ+acEReoOIk2a/08yZTgcdlkMkUjEcUDIqVFSplaOVLTuWeEhq9Wh9H5CCVTIFD6MrYVC+wkm8XjcKTofKaLamKAeECpjNjjft28fRkZGHMlGPQiryDkn21yC86MCV4IEUQG+93J4yFwTWvz0utTjJORCIUaPfmpqCk1NTQti5dokY7lSRsopZO5XNpt1yoB1iwE4diSLGBAFaGtrqyvcYSFf7r02B9DwBp/BxrqpCBg3C4f3d6XKZDIYHR11pTbpPbe0tLj3LRaLLl9T70ClkLU97+rZWw6EokK6pmpQMm5mzxEVsSIuGuuvRUFUGoNVRaFVpjTkQah9cnISHR0dAaVHb58wNj0/HwTcCMWse+PbH36mhtBoNKuBrcVPFIalEViNkuN5pXdMT5ifyXALiVpEA6mQ2RJVHR4tgM
I1sOExG1pqxLAOIPsyh8PzxT2I4tAAIh+CSpvMas08UWXsawjze+8hA/4UKG46UyvYMFqVFDe4tbU1YM2rEqtlEfQSaDWadDrtlPLg4KCz+FQRaeoCBSPnYyvv6OdpvicVd70xk8WGXkQbt1elrAKUc5menkY0GnUNHtSL0oIdNLKWQymTZNbW1ub2fGpqakE+rJKLNL5Lo4depoXEqlUK/FrOMyunHFWZ0zugFzw9PR3onkMGKDDfnETzKrkPlZ4ZO2+FcNWQ0HOtUL8qNhqW/F4VMAWaetg2hsyfVQPt+eKqqoRtDFa9ZMYEFaUIhULuzmsIjF5ca2sr5ubmAorZEtaoMCq9r+VixXwWAIGYPlEIPp+G7Oip2owJzluRDb6qiSurEtMiSepQcf85F+VPqJzU86Jz4dz5GXwGH+myER6zesn8np/FtKzx8XFXQIT7opC9ZkjYSnW2zvizQiFz2DiypoUwXSQUCjmvjN6PZTSr5aXKfqnFsJdX48e+UohkA5brZawKWOHDcjFEe1msEF8OpazIhA0XkJXMS8M6xvSMCFPyRUHmS0Fr1Nx9Z8RXro4XApgnkKhAI0ylHl+tkHU5kpn9N9/XetSKPnDQ4CTzlMRA9n0OhUJO0KoRUO2ZsWdNvUk9u7ZQBTCPVoVC8721qaTs0PtAD4heWqFQCBAIFRKuZPieUYW17qfebTLedV8U3aICIZSvYQ5bnc5W96pm3Xn2fMYE90JJTyQ6FYtFJ5d8CplhAhqtSlC1qY3A0jJSeScMFVKuEXng2nEulAVcW4sicV56hiKRSOBOW15Io5Sbha4J6XONtLaCldmKztm2nYuVymzEOGAKGQjCIoyTMS7GobG6QqGA5ub5lnxUzMT/dREqgTt8EKS9pOrFqhHAi67WuMZ/Cc3RItPKNRpns+xUaxE2elglx8tAogKFsHryAAL9nul5Njc3Bw5vo2EmKkxVfmQWaxoQL7L+HYsZ8O9o3dbDZvdBX0R19IxQaFLpKiMbQIDFqbFO5SHwvKvhoB6eTQfk/Cp9jsW+X2zofluPk54/4Uy+t8ag2dBE70A1Rpw+q3p8dj70enielexFQ1Nj4TzTzFVXGJ6C13qivLdKstKhZ8wX7/YZQiR80lObnZ11ypAI3vT0dMDY0XnRsKZs1JK3alxwL5daa1Vg1nggqhIKhZwxoKQmSwzlM1LORaNRbz137RrWSKXMZ1Kj2r63NZS4xnrXrTJWRdxIZQw8QwqZgXUKNvVcKKhocUejUdcXl7EsKga+ZzWemvUULGyn7+E7FHpAfcKTl4D5sNqInnR5slRt/tpyDLUUaf2xlCOr1VBJ0DJXViTzbVmzVi9tpR5DJUPfQ2NlhUIhALHTYiVBiueERTW4H+3t7QvY8rWsnV7Qzs5O9PT0OKufngOFJs+DsqmJKvD3ee4UmSE6wTiWPUuqIKqJV1lhpMqLewcEmdDWS9IaxTy/CtURWlUCDxVxLpdzCqOzs9MJvEru6mKK2JKZmpqanDem91obZ7D5CxXJ+Ph4wLjjGVMBbLuHtbe3uz1fbFjDX1nsGhLg92zfWiwWkUql3L+19rOSujR2zDinLXdLdjDvw1IyUs8JuSc0nhTdo4cbj8eRyWTcuqRSqUABGa6TKjn2Amd5WxbYYIlbXzONehVzub2wzhcNap59rqvC1eVquDdyNFwha5zEN5TJx0otFALFYhHZbNaRXmiNUSHzxTiEvayVzM23ObwkOj9rDRN65GWkcAXmvWP1Pnn4enp60N3d7b5qb+dGU+Y5L/u8HLR8yaAkXK0kOq5LOBwOKGSiEnwpbFOPQtbzUk4pa3ocBQ1jsYxxKgEvFAo5bkI9sLV6C6wGRoufaRPMJ6bg5FnK5/OOka+xR4V3LWnHxs/1LFkjrhIIkl+tIlM4k8KFZ53C04Y2WFghHo87Jq9lBkej0QAvg/Pv6upCsTjfT5Z3vZJhFYUqY/VeFYamJ6oGlab+MUaodaPpIKgQ1j
usMVoNP9jhkzFE4qyXrFAzz0pbWxui0WgAxePZ1efhGvAu9/T0IJ1OO/KRb80rkZE8g/bsKPcnHo97SX6EnzXDhHOgk9Lb24vVq1cH6s6zJrRtfFOvTNS90BClZtdo5gHJryR06XnQ+1cLSlXpaIhCtkK1HDyoAX6tOkTllsvlnFVGT41xZWV7trW1BS7ZYsSacnOwUAWVqsYuKaDIjKWBwPehUFGomgqDgkg7EJEgQIXcqIRyn1KzMXPOVQV+R0eHg3qJOhBGA+DgvUwm45iJFF4UFLV4nr456x5aVqZl5/NFZcwLRu80Go266llWSFdDRlMBaJmnXCvG/rLZrFOwFNrWk7IxXM3x1vNHC52K2LI6K0VVyikzq5D5f2psqLKjcCLLtKmpyaVxUdEw1EGvnwqZfBA+kzJcK5k/z4CFqW3d5+bmZgftMs7J+0tlpM+qRrdWYqIyo9fJPaQC5L4vNqx8WQyy1up9uVwukCqn94FD14KxTm34wDAb0UWtuLeUjNT15vf27FCxaj66Fgzh+hCV4H3RjmzqHbMe9GJVAWsZKl90vVUpK4JCmUcUQMMB5XKOlwPZrEsh+4SqxlYt/MthlYPGaGid8bCGw+EA7MR/07KtJa3FQnlKeGJ6iRIX1LrU2AO9Qz6HCi/bUEDhal995Xr2wEcgsfvB9VT4C5i/4HwmjfvQ82E8mS99n2pjs74z4xNWNCQ0dULjPAp9UikrY1YZ4pbRrnH/xYaeD1aT0/mTKZ1KpQJ1rBU+td4SvQbrsatnqhC9rypQNR6E9XJUKWsskkxUAIHSqvp7nFc0GnUGB5EK7hMA5ylTMTY1NbnwVK33tRxkbTsjAfNxeq6/hby5dqqYaeBRUGuBH9ZCoPe6mIesY7G7wfNNw5ahDc7PxmXtXkcikQV1pQE4mJipctWuuZUL9HAt0ZINFfiZapgxlZTyORwOu/NDtMWWn1S2daMVnpWFGlohdE1dw/Nim/Is1lilkaNmhWy9YY3daGoSELxQ/FuFi61SofDiBvuIVrWUW9N5aCyGcKjCuM3NzY5wQ9gxFAo5T0wZiJYmT3iaL1/+WiMo8+XWjN+roNR8aC0Ub/Npua6RSCQQ+7IVsKoVqjpfPTN6MfTcWC8inU4jl8s5o0BjP5qeQ+u8XDWqauEmG1PViz01NYW+vj7kcjlHQKSnrCxTrj1DHHwuIAiLU2gxtmnh02pRlcXgXq1bTqSEAsmyTzX+zPegIUolqOdCq7wxnm/z7isdNLpsLFzREub1T05OOiWha2zPINePcogyiO/Pvwf2GycdHR0uHq7pdkvN20dC059ZXoF69JRN6jFz/vYc8V5zL+gZWw5Fpcazxo0ZD9Zn4PrxDihkbTvD8Zl5vjUuryzr5VLG5YaVRSrT1UEjAnMglDFQo0Iup0ApGNWLAhCwsNTD1HKEVjHwsBECUTq9pdUvBpPr8F1oFQ60mtva2lzTgFQqhVQqFdgIq5BpNZL0o91Aenp60NPTE4CqtWRfvbCMhWJUGOrPuDf6XOx1mkqlXD1um2pjvVcfLF7pXO2Z0bJ66tHq59Eoy2azGBsbw+joqCtZyi5JVNT8O1XWqrQ1/kNBs9RQhaawH2OAPOfNzc3o7u52c+F5pnFA/gMAB+NxHjw/VML0IGjQ+TyJSmLI5VAgzdPmZxUKBUcuAxBQTGp82fdVxaPKhesSDoedd1xNURP7LJyThU+ZDke5o2vDZ9C/U4Gvv8f9zOfzThkRamXxIlYPXAyytt4819yGCvj/Kg85V6IlmubHuDvlrA3BTE1NBRRyLBZzaZu1oBK67nbfAQTi2LaEqk27st61stp9zkmjhw9h0XOgZ9vumy/Fya5NI0dNCvma+67B5+77HJ5OPQ0A2Na7De876X147sBzA91vCO2oZc5DqIeLEAeFMpWKD/uv1UPm4h9+7eHYMbFjwf+/btPr8Odb/tw1u2dDa0J0PosKmC+3pgzCNWvWoK+vz3nH9HK0gbdeTM
5vyXW/9xpcc981bt2P6D8Cl51+GV644YUBSF8NFyoEje1REbM/KZtl0LtTCEyHVcbVDN+ZueTES3Ba32mBzklqFKhCZmrIyMgIhoeHMTY25mJnetZCoZAzQqxStsUeKomp8aUK/NMPfBpX/OIKvO2Qt+Fdm9+F5uZmdHV1BaqbaboYDaDR0VH3bHx/FmEgg5xGnJJeGNNkb2c9N0uNcDiM3enduPT2S/GjJ3+EydlJbOzYiPcd8j5nQLI6FNefniz3WdEsyxuwHp+SeVizWPs9V6sY9qT34AO3fQA/fPyHyM/msblzM/72hL9FT1tPAEmgYU8Brznd6nFahaz7DOwvypLNZp13z8b0RJUWU8iKvEQiEYxMjeBDd34It++4HZNzk1jVtArnN53vZCG9XzXQKE9IIuzu7kZHR4eTQTRgdZ+4R9ls1jH16dXX4iHr+PhPP45Lb78U7znlPbjynCsXePzqbNmcbSUNKjO8XG/2RnjHmekMPvzfH8ZNv7sJw7lhHLfqOPzjC/8Rx/Qes2C+Sgy2CIw1Niz7W/e7kaNqhVwqlbAuvg4fPfOj2NK1BXNzc/i3X/8b3vr9t+Lml9yMvlIfkskk0um0g381DkV2Kq1pXiTtCavMQquUKWirLTvJhfzF23+Bmbl5KPHh4Yfx2u+8Fq/e9mqs6l6Fzs7OQG3rcDjsIEcLq3JD1Nvo6+vDqlWr0N/fH8izU2jGpp5UurHrO9fj42d9HIf0HIJiqYjr/+d6vPabr8Udb7oDG1s3Onhd472aw83nogIkDMzCFEQ2dKiw1X9XcygDZyaxBbNzs7jh1zfggh9cgG+++Jvometxni+ro9n0kHw+j4mJCYyOjrp64xQ4trSgDaEodE2Bofu31Lmh4I5Go7h/8H782yP/hiP7jkRLSwtWrVqFWCyGvr6+wJqTDEdl3Nzc7OBF3gvGJ8nwpQAmqkLSC1mo1aY+hUIhjE+O44x/OwMv2PgC3Pzam9EZ6cTD+x5GfC6OcDGMnp4eR54k5M494N7ZeH44HF5gtPnCUYxz+urQV3JnxyfHcfqXTseZB52J773xe+hq6sIjg4+gG91oKbUElDHTrLTpxczMjBO+mpKjglUNTH1Gyi01sqjwaawsdl7SM2mc/Y2z8Ufr/whfOfcrCE+Fce+T92JyzySmWqecPCQBi+eRc29vb0d3dzf6+/uRSCRcKhTRFt5zRQt5bzhvX7nhasa9e+7Ftfdfi2NWHeOezYYPfCENy+xW0pdVdOqYNELBveO778DDww/jhlfdgDUda3DDr27AS7/xUtxz4T3ojnYvMCT0LNiwjKIbBwpSr8lDftmhLwvAjh869UP44q++iPsG78PzOp7nSgKScUoaudLGeXEVSlUPWT1QhYoVMizntZYboVAIq+KrAjHKTz/waWxObMaLD3kx5ubmHJxCkks+n8f4+PgCZUr4PRwOB/I0Vaj6BKnmsFkYaMl1P+xl7t/FYhEfOeMjuPaBa3H3nrsxsG7AKVxteEElrd9TIVMJU4lw3fVS2YvmU8qVnhmNGX/wOR/EF3/9Rdw/eD9Oaz3NtSBUL03PGI02Qu2cu+YY81L5wgv2VQupKDebw8XfuxjXnHcNPvazj6GpqQldXV2uWpKm/JCZTrhxenoaExMTzigls5rekKZYELbWKkGMHyvsXsn6/8PP/gEbOjfgS6/4krtva9rWIJlMYnBuEJ2dna5kp8ZWATgv0YZG6AlrPNhySdRzU05DNV7aP/zsH7AhsX/u9LhXt67GxMQERuZGXPyb6JAavjQyAASY2Ly3hHcBBBQikRZ+ns0hVoSj3FkBgE/e/Uls6NyAz5/3eeTzeSSTSTQNNOF3yd9honXCeYo0fLiOGrvUtDPyF5hypoNniXtkqwtWq4gBIDuTxVu+/Rb868v+FX9319+VfVYloPni5OqFlns1yuucnJ3Et37zLdzyxlvw/E3PR6lUwl8//6/xvce+hy/8+gt4/4nvD3yWGhO8V4
tlIhyI+HbdaU+FYgHfevRbmJybxJGJIzGVm3L5q7a/K61Cxmx8ZCNajLw0yvb0eZb2tdjg//OQzJXm8I3ffgP/9+T/61J5+L5zc3Nl00xo9XFOWjxAhajtCGIVms6pmlEqlVAoFvCNR76B3GwOx/cd79aQSIPmbRN5oMLm/ig5ioKTB1LTXvRrrRCTGk2zc7P41qPfQn4uj20d2zCVnT8znJcSokgeUTKawqqaCkP4Vz0iC+/WcqlCoRD+4od/gfMOOQ/nHHIOPv7zjyMSiaC9vd15vzMzM4EceZ7xfD6/YD7quVGxqELW82OhvWrm/93Hvouzt56NN3zrDbhrx11Y07EGbz/y7Xj5+pc7IiOVkD2XLJSgBjRZvfTW1EOjZ6lrZt+zmrl/59Hv4Jyt5+D1N74ed+64E2s71uLiYy7Gqze92hGESEhj+h6hWk3fs0awZTAD8wQphiYikUggvKQKp5Kz8r3HvocXb3kx3vbdt+EnO3+CgdYBnD9wPra1b3N7y7SgaDQayNm1c7PvzWewykKdhaXeZ6nx7h+8G+cfcj7O2nJWQCH7ZK6FsGn06D6rI+Uzjul582stY644h0KpgNZoq/tMAGiNtuIXu3+BwnFBYxGY5xdoGq1PKR8IQhdQh0J+aPghPO/652Fqbgqxphg+d+bnsLFlI/aO7w0I/kJhf/cMpirxciipi54FvSMeTE21sOXglIZezSVX6+g7j30HqakULjz2wsCF1YOjXh0tTsZEIpGIy1+0jastm9oSCWrd2IeGHsJpXzwNU3NT6GjuwA3n34CDYgdhZGQk4EGm0+mAUibsRm9ZFZta51QyWhVLi1P4Wo5VMkqlEh4aeggvuOEF7sx8+vRPY33TeuwZ3eMUrhoKuu403NRgY0yQBgQFdE9PzwIilFVoehYqGV9/+Ot4YPAB3Psn9y6IN6mXQHLQ5ORkWcNRhYCFq7u7uwOFCOotKfjk+JP43H2fw3tPfS8+ePoHcffuu/G+W98HnA6c1X8WEomEu2/KrKZXD8ChReR7ELLWkIjPSLLEH2vYVjL3a+67Bpecdgk+ePoHcc+ee3DJrZcg9LwQXrLmJa4/NJUpzzeROSIo6hHzbusZUM+I35NdzXOkHJBK0p6eHH8S195/Ld7znPfgvSe9Fz976me4/JeX449X/THWda1DX1+fQ6N4P3m2AQS4E4T+AQRS/Eg8Y2lM7iOJbtahqHTtv/7w1/HAvv1nvdywZ9kaB1xvrrXKUeV1+O5IrSPeEsdp60/DFXddgW192zDQPoCv/vqruHvv3dic2BwgBNMBUX6Bks6UXf177SFzMtv6tuGei+/BeH4cNz5yI9730/fhsyd9FpHpiFPIExMTzuqkx6UbxQWyTR6YekFyAuE8wni+ouTVLlQoFMJ1D16Hlxz8EqyNrw0UR1cjgXPSakzAfMN4rebDKlyaM6oxkkYw9A7rOwz/887/QTKfxDcf+Sb+z3/9H3z5zC8jlo854tD4+LjzNhkK0BxuW6PbsiDb29sDRU20sImtMrbUuitEeUjPIfjJW3+CscwYvv3ot3HZ3Zfhk0d9EpHpiBPuhNh9zUR4qVUZ29KO8XjcVQMiu71cHdpK92LXxC685z/fg1vfdquzvvf/MQJKhmtJo05jkkrAoSAjk5Yt3np7ex07v6OjIyBQa7XQi6UiTlp7Ej521sdQLBZx7MCxeGT4EXz10a/iVVteFWATa/y1paXF9W3mHlOYAXAGEkMeSvoCEEgFLEfeqXTuf/+iv0epVMKxq47FQ8MP4Ybf3oBXb331AmXLGD3j3NocRglp9Pgt61bjmqwn0N/fj/7+fmcosYNYpXP/6As/iunpaRwaPxQPDT2EHw/9GO/te69j/be2tjpSJfkqZE2n02kUi0Xk8/kAVM1zpalZNKZaWlocc55nqBpkpexZ9ww1ZuyLn0MPmOdFyY42ta5e5BAAbnjVDbj4Oxdj/ZXrEQlFcNzq4/Daw16LBwYfCJSq1S
YZJOVy37UeuI1z/17GkAGgJdqCQ/sOxezsLA7vOhz37r0X33j6G3ht62uRz+edt0YYWHO5gPkC/Ap36YWmUPDF16iU66kHvSO1A7c9dRu++dpvOqtZDQNVEKrQSAziJlJhUfgrXGpZvUD9zLzmSDMO7jkYhUQBR/Uchbt3343rfnsdLuq9yClkkp7oJas3rOtt029YTYfsW82l9inkarzkUqmEpnATtiS2YG3LWmxp24L7992Pb+7+Jl4VfdUCQpR2mLIQEy8RIWotxMK5DwwMuLJ8GjqoJZ/w/n33Yzg3jBOuPcH9rFAq4K4dd+Gf7/lnTF426S6pjbdqGpp9Dnr1XV1d6O3tRV9fn/PuLaO61nOzJr4GR/QfAWBegB7efzhufuxmxGIxdzfVyOXdI2JFj0xRC8tgp2dM2E/RLcsfqVSg6dy5Zkf0HYGbH73ZNRdRBaBz5P/x/io3hQYd0TsaJCRnEhHSs8R9YbGNSuZ+eP/h7m41NTVhW+82/Nfu/0J/f79zOigngP13hAxvIoWTk5MYHx8PnFtFZdSY5nnSMqe+co+LjcXO+mfv+SymPjTfRcvOwZcepCmCLFdJ1M4yshsBC2/t2Yo7L7oTmakMUpMp9Lb04s03vRkbYhsWIDpUyK2trQiFQk7+EZH1ecm/l5C1LyZUQgnTxfmLy4cnuUIhDVpMmiNKaENxfCU1EDK1cHCtUMJ1D16HgdgAzjvkPJSKwZ6xmi6kDd1tAXLCjUrA8VXhqtfq8w2Xz1ssYHpuOtDCjXm6DBto+oOmk9lYlMK+fC5rBNmKUdU804IYUqmI6cI0Zkvz7Hnb+lFLUFqGJj1jKjSiFEqqs8rYhg4qGS/a/CI89K6HAj97+y1vx7a+bfjA6R9AJBwJpAnZUIdl5wPzkK7NP7ZGXbXevB2nbzgdj4496v4+HA7j8fHHsSmxyRHONHdUPyMSiSCTySAUCiGXyzlyHWP8en/5+wCcIWHDS9UaF3buALB9fDs2JTahtbV1QahAeQZUvjS0qeCI2DGWyPe2aAX3Q4v78CxVEpc9fcPpeGz0sYAXuTO/E+ti61zeN9eiUCg4459ZJnQMeEc1BEAki+dD74LeW9aHrubMV3LWlUSpMXlFG2zoz3e/tQNYrUzwciPWHENrpBWDE4P4753/jfcd976AXKccpM4JhUIL0rF8yvj3DrIGgEtvuxQvOfglWB9fj1Q+ha/8+iv45eAv8cnjPonCaDBWwNiHWrLMp9NCICRSaLxPLT99lVuoSkexVMR1D16HC465AJFQBHOlOW+RDU2XoTJWgaoxByv07QY2YhMvve1SnHvIufvXfTKFr/7qq/jFvl/gqpOuwmx6HkIkQY7CiTWdFTYFgik9FBrq3egBVUizFi/zQ//9Ibxo44uwum01RjOj+I/f/AfuH7sfl2+9HIXxILSra09DgsKe60qLXJswqFKzRVh8sFOlI94Sx1EDRwV+FmuKobetF0cNHLWAaawK2ebvqjCjMNUWnTb2Vy+y8t5T34vnfum5+Puf/D1ef+Trcc+ee/CF//kCPnfe5xzsr2QlDdlw7mRW09imwa0hJ2AeOuX+WKZ+tWu/2Nz5OSpPaKzTw+F90DAChT4dAwABhafhD1sljcpD87CXmvvHfvoxvPqwV+OXu36Jr/7uq7ji5Ctc7JueMPdcCXSMIfNMaXiG82eKn/6fPn8thvNSZ11lBzBPBPQRVoH5dDIqZZ9stUSreohdP3r8RyihhEO6D8Gjo4/ig7d/EAd3HYxXbHwF0uPpQOU+stO5RsrE94WKfm9JXcO5YVx484XYl92HREsCR/QegevPuh4HFQ7CYyOPBdIDNCmdMTZeclV0AFwlJLVc1IMrl8Nb7SLd9uRt2DmxExcffzGA8iUobQqHrVhk2da+4H8jN3A4N4wLbrogsO43nH0Dtoa24vHxxxdUNFMSAwkjmrKhQonf+3LxfM9YrbExkhvBn/7wTzGYHUS8OY5DE4fiqpOuwq
a5TdiR3BG4uDSM1KOnEOLa8yKp12yNNotU1KKMaxnKHPWl5+kzKMFLDdFGkUhOXncybnrDTbj09kvxkTs/gs3dm3HVOVfhrce+1TF8ATjUStdO14/PpfdDy+SGQqEF98N3Vxs9d2Zx+CpGWe6Grr/1xBTFW8wB0PWoaO63XYor7roCByUOwkf/6KN45cZXIpVKLXh/vjfnpnn0ykCm4lAPn8/IO+sLly2HZ6fIn/0c+1m+tDhNnas0dXWpMTE9gUtvvxS707vR09qDlx38Mlxy3CUo5osBGW/XDghWlLTyYrllBkdFCpmLlE6nAQBXnnklcOZ8qkA+n0cqlcKePXsCJRD50rgZFbQqYy3XZwk8tiRnPp93liH/j4ePpQl1U+3cAeDU/lMx8d4JlEolpNNpb51njT3pPHmofCQwwk7sZsPCBJVAXJzfYnO36z41NYWJiQns27fPQbycq41d8m94+HnIlE2uMKsPYuL+WEWXyWQWnXupVMInnvcJzJ46/56E14eGhgLxeUvi4iUCsOAS65zVANHKZDwrMzMz3nSMStbdN77z6u+43+FcmHqjZ0jDHYy1ao6unh/OmQZPoVBYNI5c6dyfv/r5+Nlbfhb420wmE0CFyKDWc2RLj9qytdbTCIVCC8q2qkfN+wXAGeG1zl2hUL4319rWL7fOAdeWXighU5U3dk+am5sdSsBe11a52/Py/NXPx0/f8lP3OYTUtR679RStQcqfkfDV1NTk7on+W+eu95bkqUJhvtZ6JXPXoWddz4zKPHvelRdkw4Gcl6aDcm3LpSrqHBeb+0s2vATnXHiOu1+aDqqIoa1qqKitclm47ryDtQ7fXfWOUgVj165dJQDPmteuXbtW5r4y95W5Pwtefwhzf7bNe2Xuz/zcy41QaUmVvd8j27t3L+Lx+AFz3WsZpVIJmUwGa9euddbMytyXf6zM/ZkZK3N/Zoad+7Nl3sDK3J+p4TvvvlGRQl4ZK2NlrIyVsTJWxvKOimLIzxZL5A/J6gZW5n4gxsrcn5nxhzT3Z8u8gZW5P1OjUg95JYb8e/ZamfvK3Ffm/ux4/SHEMlfm/szMvdyoyEOOx+MAgF27dqGzs7OSPwGAQOoHWXa5XA5jY2PYtWsXtm/fju3bt2Pv3r0uX7CzsxP9/f2u0EN3dzf6+vpcf1gW3/AxT9PpNDZs2ODma+cej8cdI5Ys4snJSaRSKezbtw9PP/00nnjiCezcudN1HtJGByVJO9DkfJ0LqfOshcuSiOyPzFxZ5ssyBzGXy+Gggw4qO3euO9eU66rPwxrK2WwWIyMj2L17N5566ins3LkTY2NjgdrEra2tgZaR69evx4YNG7BmzRrXwajSfONa1p2lAQcHB7Fr1y48/fTT2LNnj2vByEIOwP4cdq7lmjVrsGbNmkA5QxZD4FpWk5++1Nx13ZXVOzExgb179+LJJ5/E9u3bsXPnToyMjLicbw7mdmvxD1aB0iImrAKlud9LlYatdO7VDJ4tsmgzmQxGR0exd+9e7NmzB8PDw0ilUq51Zy6Xc+mMPT092LhxIw455BBs2bIFa9asQSKRCNSwr2TuO3bsQEdHR4CVa5nJtsd3KpXC8PAwhoaGMDw8jImJiUA1Ji0NG4/H0dPTg76+PleqlOVve3t7XQ9iztvHzNe5N2LNtYAJ5dHOnTvx9NNPY+/evRgdHQ3Up2d50M7OTqxevRobN27Exo0bsWrVKlcUR+u5U05lMhls3Lhx0bnbM5DL5TAyMoI9e/Zgx44dGBwcRCqVchX1tPIfAJd+xbx6lrPdsGHDAhlje1QvNipdd66nstu17SzbuO7btw+7d+92zzMzM4NQaL5HOeXKqlWrsH79ehx00EFYu3Ytent7A2taSUqf77z7RkUKmR/GZPlKhk8ZA3DpJ5rnZSu9aD4dC1KwYldnZ2dZhWzna+cej8cD+X1MCJ+ZmXGHQ/MXOR+mHQAIKGTmENrcNV5+JurbTj760mo6i829s7PTpS
zxsGl6BBCspWzfxxaY0EughRG0UIWm3FRy6Cpdd67Z7Ozsgspmmp9bkpzWkqRq6YvGD6uM+QyJSuCsxdZd15ylYGdmZhYYZFpAgOk1nLcO3QstzKFVmPgsthpdtXOvduidZQ9epqiocWDPk30GrfjG++rbi3Jzj8VizvjhOvI8lP43Z9rml+se2Dup5R01r15z2bXzlpadLHeGNA+31jXnc/Fs+e4F56+5yPwKIHCGaACq3GTlNy39WG7uVm43NTWhVCoFiuyo7NWqXL6hNQOsXIzFYhWd72rWXR0Vzh+AS4/11aTW57DniF95VrQLHg02nu1K0qKWesa62y+WG2r1ae4Za1zT2tO8QSBYOF1zNvVV63z0sGl+nHYVojesreS4qXZjbNNzHiz+Pz0dWyO4lpZe6g3bLkiaL8n83mQy6axYzQ3kWvASNzc3B0pVslWgFlOgsVBLcQfdM1vBSnOeNR9Tm60DcDWH2RGMF1kNoHIdhaziqHbYC27rONu8V86fzxyNRl1Xm6mpqUDlJFUSut7AfJGWSvPYGzEs6uLLH2UOp+bo0zPSvFO+D9/TZ5hUMhdFVfiZthEJ85vpQQMLi1XY0p1UglovQXNnVfkvV3zSh3TZO6GFfVg9jaWGmbtMZEzrwdO4oCwtFosVFTSxa+/LKeceEKXQPGrN8eY+NDU1BXLx+UzM7eaot2iP765q0yJt8cq+8Vr/wFfa1p63bDaLlpYWd4bYCa2WipG+sSwKWS0sFktIpVLuNTY2huHhYYyPjztYWNuL8dDZC66XpNZ56dzsQdPOU/xMVbb0ELSake3TrMXeWc5RWxiq51Mp3KFWtC1+wQPDvsc8bGNjYxgZGXGNJrLZrDN6LMTERhkTExOIxWKIRCIBL5zF161lWs26W3hdL7fWDFdokvNTj4Brrc1J9P84N66ZQqW1GhOK8tha2yqcFFKlh0zDQ40SFbpaVlCLh2jtYhY3WU7iijVWbZEJCiN2cbMtA9va2gJKTe8u37fS82O9Ri3UQAVMyNz+jDA1DR6fdxwOhwOKjAY0lUuj6yov9ay2wJDeB21wo3WY5+bmnLFHhTMxMeG8Nq1v3d7eXpFD41PEtjgQzwBfOh++N1FEyphIZH9LVy3MQdTEjlqVsnW01IjLZrOYmJhwMjGVSi1wWHgfASyoZsiOXMlkEq2t+ztgcX1isRgAOOOiXqOi4QpZhe/s7Kyr4jU8PIyRkRGMjIwgmUxicHAQQ0NDSCaTLiYCINCvUpWjesr1DF8sUwWqVuZieT7r5fLAq9VNRcGfM0arnZISiUSgtVel3W98Bo5WhCLqwKYS/397fx4ma1WeC+N3VXV1dw09T5u92Rs2m0lkMKAEcY5RQMQYTxwSEdRj/L7Ez2PwKIj+EocjDuE76hVNUAwOvwQSwahoNCbihBJBQDnggDK52WPPXV1Dj1Xv90fnXn2/T6/qrnrr7Q149bquunrvHqrWWu9az3A/9/M8jJVMT0+77/HQAStem60TTXiL5UvVU7BQYKv77lNs2gWmXC6HPGRCxPrcaHxUKpWQAlOFDIQboEQZKjAZT9P9t/2myTnQz0+n005JKVpEC13RDcbOFR7jOdlsb62ed1Eul1EsFkOCTIVxIrFS8S2fz4furypllmltZP71YoA82xoLZHtR3mVtq0chaeFpznd+ft55Nmynqn2d45A3jaxRlZ/1xqj4aIBo60AqOyocDQ20tbW55jdafna9+fhQCZ0PjTE6VsVi0dupT3k0NE7b29vR3d2NQqGA7u5udHR0rAndRL2rdu6K6PDcTExMOJ1DmcgzpP3XOR9FeCgfCfvTWGVZ0zjLbMaqkK0ypgCbnp7G6OioI4VMTExgamoKk5OTTlkAKxuh3XF4MRR2iXpJfN6xtUbV6+QhTqVSyGQyoXivr3Y1/02lTYXMnsIkWKgybpR8tB7iQG94cnLSHbZCoeAuEJWdem3JZBILCwvo7Ox08XFCXIwZ6eXSGG8U6N
F6PFqiziplCiKNjVPA2NDH7Oys6yamMZ168SC+VzNzt+eZaIQaD1q6kYaPKmQ909YIVLiU3l1bW5sjgSmSsVlDkad6MDXhPhp+7CBG5QfA28Pa3ttmlDLnQcFKZTA9Pe3kBxE2emg2pq3IlhrA1WrVoXJBECCdTqOrqyt094/EvtdDX9TwI7zKEJ9FHdjullAq73N3d3eoSUgzHrKdjyJwVGbFYjGkjLn/vHOKYHV2dro7Ozs7G2qmoTyiqIanle0KU5PENTo6ikOHDjnCrho4WlaY78c9I8eARgR1E/WDja23YjjHppB9cBcFFeGC0dFRjI6OYmJiIsTUJLuNnsRmFR5XxWAbMWhdWbXqCf0QembcWIkUVMjaLUnb6rF1ofZwbiaG7IOsqXApnEZHR3H48GGMj487RawCXy8MBRL/39HR4bwOtpfjmrQDUdS4Wj041Ae/82VjealUag3rdmFhAclkEplMxvVC1k5JvkYNUYb17m2YQ+uHKyMfWCWk+QbPkBp4+txUSGw2fFovVu5jOisSQKFE1EX5ARo/1vlvtBY9L/XigIoIzc/PO8Hpa9ighq82ygiCwLU4VGNCvb3N3G+fV2eRI5+jQEOVRgtRDDoLHR0d6O7uXpMh0ui+27rYNmSgrSIV3uX9UkQTgJufGhlKvtTGKq2GI62HXKlUnDFHWTk9Pe1CeDynbW1toTlzHarcC4WC4xMRgWF/7EwmE0uYIxaF7PM86Uno5ZmcnHQbwl69FL5tbW1eARSnMLJCx0foUMXFWDBZl729vc6L9DE8tdtMR0eHU8LantHHKI4K49HgISxDiJpwoipiVaJ6cdQ6p+etMShCXtYKbmXvlTBkiSNKjtIDrnNVAkkmkwnFzvl8lOjVyhmyikQFqSUuWQKiZWnajlQKSdfrzLOZcWO7Tgo0S3a0z98qE2VBx3Vf9T0stGuVlhK5yM+whrPuI9fJ4WvEEuda6q3N5yErF8E2n+DvKolOn5O+hzUumjGCLGSt76kGqHI86rV5JIQeBIHjA6hC1u5aeldbMfg1fqzIm5L/qIhpwGioi/ea95YdCpmqqcqYRo9Pd0W5uy0rZLsR2tWFMSeNZZIQwniZknI4LO08DmzeQqfWC1CI1KbSMG+xr6/P5Z9p+ocyZhlD5iFjHuB6vZKbmb+P+WgJXhpL8+2nL6a2tLSEcrkcovKzYXszVnYj+2+Vs0+xqYLj4D7zwqbTaXfWCIUxjmZjsFHgdn4mX7Y9JRES/R09E9xrKggSbIiUaP9akv7Yy7kZjkGrwz4TVcQ2v5T3woZA7J7ZvWtl+N6H8kYJTkR08vk8crmcO9v0+vTMqaFXz5PfzKH32aIAikAoWco3N2tEUQn7jKhG5lSPeEk5qeEY3gEbLgLgZCkA9zPG7YvFojvbmq/cCrvd7qUNhynZEkBo3jp3DVHpOulBE52ko6aQN/e7FXJXSwrZQksaRKfHNj4+jrGxMRcvVs9YlQaHxhTqNTdvhaBjDxyVGglPPCQAHFStRQQ0d04FsI/NqU3DLas6ijL2WYIqlJQMpwdDU2ionJSURsuURVIYF8xms+jt7Q3FWFodPuFgPZJ63pb+Xy+eogT0ijQ1je0wowzuoebGay6i7iNDLjQw+ftUuErwI9tec46JSPT397t8e8393Yzh89YsWgHAGRXMKWfscmFhwXkaNoTjyw9u9NyrEaSMc2WeKymIypj7l8/nXdtBnhPKG65V5xGX8dDosIiXEv0Yn1dlrHfZhyCqUcF72oyBoe+h5ENNcVOvmM9Fw3KZTMZ5kyqHksmkSw9iy1hLmCKSGAWFs8aNkiQ1G4DnVpFPkm05dwChO8DnwzURjaEzwGelhiHjyFEcgMgK2T7AxcVFx6gmm5oKeXR01BEwCKdqnJAbYQtUaGyh1TwvHkwrdDS/j4eHwiWXy7lqLay21dXVtUYh+5SyrsPCklHX4YN9Ndau0JT10PQrlZXuKwDH8OXByufzIeHAz2iV2m
8Foe4J99AaanYf1CAhqWVqasrtsRansJ5Co3PXOWnBBaY7aEUtZd8DCBVBYKWikZERbNu2zSkMhjE0R50GnCrkzfaSVbirMFWB2t7e7khPAEIpRUtLSyHDWX8W5czbs6Aogw370PNKp9Nu32jQkAhYKpVcDj7vDkc9g2EzQwZWdlologqQSovGnu+exzmveh67hsASiYQLB/X29mJwcBCDg4POWaEnaZ0uxuoJX9OQYzjQEquambc6KZqmx1g3lap6xgMDAxgeHsbIyAi6u7vd3bXkS2W6k1SXTqfX1HigPtNc92afUSSFbD01Wg4qGCcmJkJsarLaFAazMKi1iKlAol7s9eatSllJKLSgAIRy+aicrULmvFUhW6KXLUrQyiXS2KWFfblGAO7zKeSV8KEFTfie6l1XqyvNzPv7+9dAMmqRN7MOH6xr94zKh/+ngKh3QYnMkLhBJILxnZ6enlAKS5RhFUNHRweWlpbcpaaC4Hnlc1B0hF7b8PAwtm/fjuHhYVc2VY0kW3RGK6Vt9rB3Wo0YMklzuZxTbPwb7pFWrNM7EMU75nvafdezy31To5NzZP4/57G4uBjKrV9eXg558crIjsrIjzK438qyt1kemsHBF+WXDQ1YRKIZw9k+fzsfDVvRQCOCODQ0hG3btqG3t9cZSBYu9qWzJhIJ5PN59zs2Dtusl6zevWbPaEoT9zCbzaKvrw/btm1z5TCZY2xz36enpwHAfY9EOt0bIkrUVzRcNt1D9ik1LR6gzF/mLGp+pfXgFFrl5dA4rM9L5t+3Mnclrtg8Or73ekaAFfC6LhoWzV6KRuav6/DBU/xseg00KOjNUdCTwh8EgYMdNUakqWCW2KNwX7NKmRfCQsCcG2Er/RuF7uxnKndhdnbWeVLKUYga//bFj5W4py8qVcJiSu7j2oi29PX1Odha0RP1+tTjPJJQql07nxcRB00JomClkNY7GlUR62drqED3UbkBzHW1sDa9HRqdet/1nutZ3Iw7u97Qe6xxYIVtuf/0PH0esr1TvuyCRtajc7GkPs6LDguzG5jaOTg4iN7eXnR2dmJ5eTnEJaJio6yhLOno6AghcJaA1qhS9sl1DUny84DVzAZyOXp7e12fhEwmg0QiEeJCUUnT2SyVSl6uhe5TKylcTSlkGzNWxqPmqGlKQrFYdNV8aFlxYxKJhHPxbZzOlp20cWQgulJWz9IWMODPAIRiKSyZRrjCktDUq1cBzQfCucc1bByJQwWZXhhCpOp5pVIrSe5sEKDQi7LP7YFT0lKj8UDrCWvNZm22wEIT3C9CpD6CGr+vFi09ZY0bRYHB1pu/VcyK5qgl7kNNbGiD72WF6JFWDtbwUGMpm82GMhEUGiSEZ6FqH/TbrFJWA92S4BSdALCm/Cg9SPU+GaIC4AxS9b59oaUjMaw8U3nC+amMWlpacmtVFE6NWiWRNquUrXzk53Ju6fRKjXN7f9kwgh6yNf6pqHlmtDKiKuQocWQ7fx8x1J5v7hX7JORyOecA8D7zDM3OzjquB98nkUg4XUgUwIZtFLpuZDTtIWvMWN16rQylBK5yueysBqYk8IIw/lSr1UKCTouP6+XTw9XssN5lPaWsZIRUKuW63aTTK3VNteSbKlstTq+KL5PJhH6vGUXmm/9631fPgqSFvr4+DA0NhbohcS/pFbPMHeEdG6PV9Asb09zosNl4IPdpeXnZEULYEWtxcdHNnYUHALh8Y3piiUQixKDlWaRCy2QyazzkqFCYXYN6r5bk1dHR4T5HLyPPlTLhKaTUs9TPOlJesQ0jaKzc8i5UyCkRifCjNSKirkONA2XhUqEyr5/3DFgtr0rjjJ+pldQqlQqWlpbcXut7872iNDtoZljvT++GGno0WDOZjHNo+BxIwgQQkpkMrVG5aE7+RmtRb92GLjQGzHOfSCTWfCZ5DwDWdC1LJBIuns/Po6yt5xQ163hx3irf1XFRRVnvDpO/wt+jccEGFvPz8yHiLoAQB0CNdr
5HMyOSh6wx40KhgNnZWaeIJyYmQvmwvKw8NFyAEhhIIuIirdXiq/3cqlL2xWJ5ONSz4uctLy9jamrKlePje1BwMN1CuzixhRwFdJQHxFFvvXYvqKxYqWdoaAhHHXUUBgYG1qTT8LCl02ksLi46AgRhawpeTa3SC96MUqZlbT17Kir1XEhm0nUpe1kvFa1UVnsjlKY1alvxklUo8LLay6zpE5Y0RCXBC8uKSgBC1rt9lkdaKasytnApgDUCmukfJLkBCAmhVtEgnwFHI7FcLrs0sVKp5J4JUR4WfahWq641I+FGQqXMV6fhzzrzWqp0s/a/HqJloXnKQd5HKmTuBYBQqhfJUT09PWvaLja6Huuw6L1RdCuZXK3ux/3jZ9IbpgND4476gkVk9ExZpayhj0Y9e9+/dVgD2HIebKgIWDn3PFN9fX1YWlpyipv32DoE+l7kwjQqe5pWyMpm08o5rAvKAg0UiEEQhA4aYVIKZ0JLGquiUFYryzKbo456D069ASpnxgMBuFKT3GBeDgAutYUMT1pUyiRUdmhUOEZHPUiQB0k95OHhYUck0h6vFG5BELgSp21tbW5dvCjKRtd4fiMGhvW+9Pv8DMLiiUTC7RPDGVSmilrYzyYxhgo7l8uFPGRLFml2/3Wv7YWzcWKuh5+nLFotvGLJeBw8L77nu5lDjSb+X70Em5dKz0c9hbji3npmaFwqMkdDncY6jTXGKJnJwXCM1oEGVuFtOgl8PyrkVpC4eqMeQmeNH8ur4FeuXWPEXIvWS6AyZogqisdfTyHTQ6YMUMNBW82Sw0JEjgZ/oVBw+2vzwtVb5r0n6tGqvNTzpF8t34H3mn8DrMgnJQoSlgZWG0oQlbFGehTuSiRSl1pqPjYelRktHV6qfD4fEvj8Gwpa3RSNy1mCQpRhITQb27NxPiDM4qUXlkwmQ1YdsCKIstmsg3uVbEWPXz01KvUoUJ5Cgir4dA3cQ6bcEBJms3h6PbzgCwsLIe+A61IhSGGsrOJGD5saCjQGuHYiEqpQVQnwYjKWTMXO96NxqHFN2yTEXvIoSs53dpRAREOS8CKNMUWTSqWS23+fQlbITT+Tz2uzBveBApBGI+dvMx30pQrZd1dbua+6zxrLpvCnwlJWMpEW3lNNI6LxT8XC99SytlYhx+UA+IhH1ivUtW8ko6hMNFTGLBAidKqQo8K+esd5FlU+K8SucXgADvkMggClUimkoDV0oPuhYTEN+7RqHFklbFE+q7T5uRo64Fnh3ug6FxYWQvwnNTiaub+R85BVqCkUoHEQbgQPPGEixiN5+Osd+laIOPWGXnIlrVQqFUcYIcxO4eSLB1LQq5ff3t7uqtDwM/h+rVD6dai36WOIanyMgkxjPZoATyXIGBqVrZJGqOhsGT8bl2103lw3iSGcm+Y4EoYmQsE4GveNZ4bPgVwGCjYlpGmdaUtIa3bf9dL6zhFfSoYh5KWehXrOKvyWlpbcGdS9aAa6a3VYY4VnXsMXPM8AHKqVSCTWEAajKmWrkKwC0L2mcUjhR84Az4fmzwIIeTAaFqsXGos66nnEyl62ddktkdIS6bgGDdso4Y2wsS+GHPVZ6N3Wc69wr5VF9BzpSdKI4jzpPWroTEmCSuCNQozSYc+zVboWYdSfK4RtiZyUJXqXtZhO1FTLphWyXhROjkJ1fn4ePT09rgMPv/KwML5Aj8cSDiwkXi8tKcqw8+aFpFDk+5OARiagHjQeDFpyluZOC71SqThlyJZ0UZPeOXdr+FhGt8LIwGo8zyfIeFFIqtMYqKaRAHDxQkKVtBijxGXVE1PPJ5fLhWJiOnfCzxSq3APGZhkjTCQSKJfLAFaRDQo5ess2NhhF4KrXqnE/7ovGw/nMqby4n0qI5O/Pzc05D6erqyu0Vq53sxSyj8xjlQYL9DMspW0vKfi137fuNfet2eHba+uVKbud8+WaVI7wPXjv6UmS3U8lZgvmRJm3hr+sTLPQP3kFLD6hXcS0kYMSFLknSk
pjtSzlsFhS10Zr2SgO61Nq+tW+EolEyDFQNAJYzQmm0VqpVJyxDYQNj2ah36jDKmk1ChWBtIY04+KaImbDZI2MphSyT6l1dXWFFkLLggoaWC27RxcfQMg75oarIFXokcQRG9NodlCZUQlogQAqNQoWQlz6EBKJhBNWmuyue2M9oq6uLm/Se7NDPWMbu1RlS2+Bh8hn1dK4IDFDX6xTzAulUHClUlkDyTS7HvVOgyBw8Xdgtb4sY1G9vb112+txTiyXSQGsUCurx/Hlg6manbsaRzxLFPJaBIEKiyiEVj8ql8tujWSTs4Qfq8IBYeMkLuiO+67/Vg9OK9hpt7apqSmMj49jcnLSlb7lOSMSxPilLfkZVRkDqzC6jdvb829z6TWuTPSKBWPYErWvr8/VqGcNcZ8xEWVoDJZOhe6pFs7QHs/a5pDtJsnHUWRCnQrNb2d5VhaeIQGp1RCC79n4IF9VzkCYCEZYnQVEeP8ZDpydnQWwykC3qYCtDoYxNlqX/t/34lB0lFC79ZCblfVNK2RulBWkmntLhi/jabToCCMRrtBYoMambaUVG4ONSsyxChlYhbC0fCEVMudu48pavNzmupLpmUyupO/wd+KsGOUrpqF5h0rbt4JFlbMKKl2/ds9RhaywWiseP9+bawFWoc/FxUWX02hbEXJQ8JZKJYyNjQFA6Hlo7FY7A9miCVERFz1LmqZHHgGVP+dJlIHz0Bzsrq4ulEqlkDJmfEqfa1TeAfer3le19JUTosqYGRSsSV+pVBz5jNwQ1gWux+6NOmyIQBnufPGsc58ZUwYQUuCs3kaFPDAwULdXeVRjol6sWJEGvmyfYZLP+HMiEppTr2tSmJrrIsvaB8E3stf23z5l26hSszKL5125OOoh0/gCEEKeyFdpde4bjWY+Q+8O16nKOKpsbBqyVtiUtG8bR+Dm+5QVhbpCALT+FSpTIpHGUVpRAKoErBXGS0gyFL1aq9Q4NwpWVbhKLqIhQYu9VeVl4TsfkYJ7qn+33vv5iGAKyShEb1mQra4HWEUs9Lnws7q6utbEz/h5NHoKhQISiQRKpZLrsU1IT+NSirQozB/FgrUesipkhifK5XKITEPhnEyupoNoX2FNK6MypiehtbijeMkWjrYesaae8Gzb+gJaY4D9hynk6R0TKmVqXbNkonp7za8WGtVzyzNPY5mplIlEIgRrM7RGeJcvhattMY1mh40XKzJC5cuud6xkpUpaewYTymbfYd4TRZIIU9u1REl5qrfnimo1OvR9rJecz+cdasTzqE4Az5Q6Mc3c1UbnbuP8zQzqKxoW+ooCVXM07SEDq4KUH2ZjB/R2lOVKhQWsEo700KvS1thPHIvk3K3FxrkrJEKiBD0bfk+FPAsPaOyZv6tfqVyixhPssEQi9Wz18lFoW0td93KjQ8p/W0Gu828lpqNnSeF+Qli6t6qMaSCwsEylUgk1alBL26Zt+QgyUeeuRik9FaIHqmx5aS2BS2tCA6tN3dva2tDV1eXtIqNCtVGLX++UfSlHQ+ObVMiMz09PTzvvTZUdQx6ah6qxyzhh0npr8xkUXAuVAbBaYUrTKi13wirjVrxjTXfTZgeE/1lSmMqXL604pxwDIhLqGWuYwxZL4Xp8KNlGw7Lp1RCyXmIjXqUar0T06JwxtGDJo1YZNzL4vOvNncMapla2NfJ5/B0ayipbbSivGTkZmdQVBEEoZ0sPMA+/sgZpJVLoKoXfBwVYodyKd6xz10uq3+N700KjAOX3OC9acqwexZ+t93DjmrfGu61Q0fxg7pXCkPpSj4IKSslzNJy4fp8HHYeg1b1XS7be/vHfRC4WFhbWVASynqmyVX0XPYqHrIYRU8sYK+a+cU2dnZ1rwhbKlqW3z/fNZrNOGCv7VhVyI96KKmMqKMvg1abzykpXmL9YLGJqasp1utFUPypjwqWa/6pQMvet2eE7B5Z0pkiaCnPyP/hMVN5o6MKGMVoxInTPtZIZ95GNdyYnJ52Roz
WfieRw/1Vh6Xnr7e3FwMCAawvLGLimFSl/pFHP2DotSqSzaEejsk7fU+tR8BwlEgnntKih2IwT1ujcVc6r8rQvrrMR+bCerIoyIinkIAhCQsFi9Tw89pIDK8xTenS6Ufp+cS/SztNnUPArY79KCuFnLy8vOw+BjF5b51kfaDPxi42GDRUQ/mHNWLWIFWqm4CXEruzpIAhCAk3zd2n1Eda3rFafMRV1+Lw+3zPn8+AltqxbvpSQpgQ8qxSjKGN+PlEgkpkAuD1SeI6eJffXxuD5osdvGbZ8vqqQGzlT1iCjgtUet2T1qiemaIKSkGZnZ106oBoijMey5aGP1NXK8ClinZuS9hg7pnGpApjPTz1m36uV+6p7zjKRJO2xzj+74LGSoQ1daH1nLYZDsqkW+mFL2OHhYVcaV5GyKPdTFdp6CJx1nNa7S/Y9qZDXM7qjyHyVkVZeKXtbQwo+uNkq5PVefD89A63oqUh5yOoJ8/+WeNHR0RFiTKvHwoerDFLfaFUBrzd3awyo12MZ3Wr1AnCQtSU8qTXrg3qagRvtnClMNEZfrVZdrFrJP8Bq6g+VcalUcodSU3FUQOtXjSPzEllvPE5IUi+B/l+/z7n7rGElFdqYP8+gL4c6yjzVMMrn805Jab49y6hq1Tr1hNQIokFXLpdd2U9VnqyfbhXHRvNX5UClylglm8BMTU2FiENqYFpom5A6SZEk35HdyzhmXH2creBXz1hjsjRi7D3cSKjHDaVzrrYyG2PvZKqzLa09C5qbrrwDxoxzuRz6+vowMjKC7du3Y2hoCP39/Y6cRoUclZVs0R+b761KzSpkq5z1Hlt0j8pS0U+r1PRrM3O3n6G8AEU69WzrmVfypM+LtmgoPzuu0ZJCBlYVm1XIXKgmfzNuYNmR9j03c+jn2MA/YRVlEKsgSCRWc6vpgar3pcpYC0HEYX3rgWtvb3fGgK+Dk85P02z4LPizRCIR8sTU0+CFsVCTTyG3OupdPOsd86u99MoLUMKehax9HnIUL1lRIH4+jSQ1XLLZbCg+ODs7i87OTucZsbYyDYUgCNaQvHzFIZqB0lRBUClTOYyOjq5JZeL+KKvdGj2Eq8ksV0KRMnvrGVmNDPvsLVFKDRqtxqWhJt9+kOPRqidTb1BekFxGtrS2pJ2cnHRsdV91Qw7KRo0dszzv0NCQU8hKpotqKFsjVwuo2DtvlZr1MImsqdFo9YMNgel5A5pXxvZeWu9ew1iUfzYnnAVulKyl6Kd61PXOT6sOSsuVuoDVg6ObbnN3FxcXQ15MnLHIqHNXCFIfKL1Du/kUCJYoRCFAeFfr0MZRrN5C6rZIgoWS+fs0GOgh0zBaXl5GZ2cnEomEUwzWWufvWjKMxqqjKmSforX/1v8rRKRWrZLmuD86Hxtr5L+bJYvYoc/BPhM90xQOmrakMTK12imIbQikHrGx0X1Wb8vmQjM+PDk56eKZWocbgAvhsKsSQwK8JxR2RLx8vAob3mpmWENMDWSLfGx0FtTYaDVToN5c68HramQxNKQcAT5rNV4sb8BXeY+FWOIwku1Z9nnIHNbD9HFUNIa/ntEdR0jSOoWUh75QHo1UNehKpRISicQaJ5LyU+P7akDz8zgHldVRnkNkhcwP5VdVbKp8arVaaGOURGEh3XpjM6FrDiUUaaxBSzPS4mV6giVeaF4mvQZ6TTZmHmW+FqbVfdX9pedGRKJUKjkrlIKBCpnFNQiTkl2uh81anDaG3Oh6rIJd73v23+oh2UpXahTxb9RTUdKSwlBRhw3Z6OCZsaX0mCqkXqMV4KrI9My3ImDtuVFyE+erSsNyIegR0cBQg8EaOoT8VPBZ1muzCsP3u/WY4r5whoZofKmUcStlztkib7r/arzZO6QKWc+zvQeK4PnOTNQ5q0IjI5reI+esd9Gy8qm0tfCTes4WIrZGdatGshrBGvtWg5gynVkEDDstLCyE8tB5NxjiYccw8oiUIGiftZ7bZs57SwrZbogOboAeQH
2pUvZNOA6rqZnh8ybUk9BqOuPj45iamgpVkdKYYl9fH3p6etDf34/e3l7kcrmWYjscvtCAFj4gvEQ2bLVaRaVSQTKZdDFKwoyEZxYXFzE1NYXp6WlHPqKBwQOn8RjriTejjNd71WNtqgBSoaolB1lak0KWkJQyin357HHClvWEsBp69TwKG4/i8/V9bXTPrZeuzRkWFxfdWeCZ4Tn2GQc0bNra2jA3N4f29nZHPGMcl2mEvDuWJ8J/A35Dxjd/uyf6dz7lZPedFelomFYqFXR2dq4hgMWllPX52zSfXC7njGGtcdDe3u68L97TeoalEjTZrEFlqs9LbkYZ6L7RE+dZ1Rr9QLgELMlrrEtRrVZDZUgVHlblTY+Ta6bxF3Xfed59pC6eHc4jCALMzMygo2OllWSlUllTbpS/z054zDbg3/PMW9SoFbZ+ZIV8695bcfV/Xo27D96NQ6VD+PIrv4yXnvzSNVC2tQotbO2DWdZTxq1enPd87z147/ffG/reSQMn4d433rsG1iMpY3Jy0uUOapEEpoIkEgkXQ+zr68PQ0JCrADQwMOAad7cCXesF+9RPP4WP3P4RjJZHcVLvSfjz3X8earTONnNkg7OqlVb44kHkzyYnJ1EsFkMJ+irMtbRm1HzN2flZ/NX3/go3/+pmjFXGcMbwGfjr5/01fmfkd0KpWvblY9dqPJRKmVY3DUGNEan31yxcWe/M/PzPfu41HGx8zXr1JCFpv2YgzPy1RLUoaTnJZBIfveuj+Mr9X8Gvp3+NjmQHTus9DZdsvwR9fX1OkGpuMT9D48g0cLTjFo1PkriAlSYCvlxYrbSkMP1GoxbU8O7vvxvX33c9RsujGMmO4KKdF+GC7AVr5AXnnkgkQgYj0wBJxEyn0y6tjIoxLvhajalrf3ktvv7w1/FI8RGkkcae9j04r+c8DGIwNA8aNnxRgSmSw3OtxVpyuVwIKuVdVR5JM/wcVWjt7e349P/5ND7xk09grDKGPfk9eM3Aa1yMGlgNh+l8kslkKCyUz+ddXQF7bylr1AAhshd18N7c8OAN+OS9n8TE/ASOzRyLP8r9EdpSbe6e0jinsVmpVDAxMRHqeMd18uzQGeNdYQEpNTLjSJ+LrJDLi2WcMXIGXv+U1+NlN74s9DMLYVsvmd6iTVy3l8y+4hpPHnoyvvWabwFYgeySWA3g26Lvk5OTGB0dxejoqFPCCllXq9VQaT4q5IGBgVBupm3rFnXc+Isbcfm3L8fHX/hxnD5wOv7mx3+Dt9/zdlx9zNVOKc/NzaFWqzkLtFKpAFglH1HJEhbjgaPypmdhGYs2/aFRb5/P743/+kb8bPxn+PsL/x4j2RH808//CS++6cX4zz/5Twx1DoWEkA+S5KUmbFQsFkOGES84sOqp2WIRWqCjmTP15KEn45ZLbnHrSSXCMWGF4vTzbG6vepaM2dIj5fNRy76eQgbWF7R8v9v234b/66z/C6f1n4byXBlX3X4V3nX/u/C3T/pb9Jf7Xe/gIFhJA2S9chWW3C9684nESgUsGqcdHR0h6JuxTeVR8Pw1c24+fNuH8am7P4XPXPQZnNB7An6090d487ffjOUdyzgxeeIatIiGmHo4fEY8W8lkEt3d3SEvWc9dM0rMNyiY7xy7ExefeDF2d+zGbHEW1z58LT5d/jSuGLwiBAlr7nAqlQqlxum86CQQPu3s7Axxd3in1RvkzxpZj8rqG39xI/7qtr/C1c++Gk/qfhI+ec8ncdWjV+GdXe8MVQMk9EuyosLTfE8a2ppqx5rdep7ouCkq0ugzUGTi5gdvxvt//H689+z34vjO4/HZX34WHxv/GN6UfJOTdyoLFhZW+mdrHQMaw/x9ois8M7VazclPzcvXvPZmzrmOyAr5ghMuwAUnXNDQJlmIlQtWD1nJSArVrEdm0XhvM6Mt2YaR3Ij7HB+kQsuPdXyZO0g2Kq1rCgCmJXR3d7s0kFb6ktbb04/d/jG84XfegNc+5bWYn5/Hh5
/1Ydzy6C34YfmHODN3plPIXJPW26bAUthZ4+Q8XBobsTm+US3AuaU5fOn+L+Gm/3YTzt1xLqrVKi4/+3J846Fv4NP3fBpvfcpbQyQmVcSaT61KjaEE7XdL5WZh6VbJI23JNmzLb3N/z3OjqArnqIU16r14hjTXW42ejbzjRgXsN/7kGyEv/RP5T+C0//9pOIiDGOwZdDXDGe9VCFUVAz0KVQ7KHOeeq3enwlnvf6P7/6P9P8JLTnoJLjzxQiwtLWGkYwRf+MUX8EDlAZzSdkrIwKesoEJSYqZC8el0OrQ+i5q0opRV5t30Bze5M1rJV/CuznfhZbe/DDOZGYwkRkLkT3VMUqmUy0nXEAz3XRWgQsyExmlMk9NCRdfIejj3j9/1cbzujNfhtU95LSqVCv7X7/4vfP/Q93FX9S7sSe8JpVYuLy+7Gt38Po0SohNBEDiYWjtY0XimDuDfNnPO7d7/7U//FhefcjH++OQ/RrFYxOWnXI4f/eBHuDdxL3akdoS4E5SNrD3Pu6YZKzREGcIkkkQUiXNW9FdRimbXEVsMud5QGEetCAtfU7kBq96NsmmtYo6qjIMgwANTD+Dojx6NjlQHzt5+Nv7y6X+J4Y5hp4htzVkN5mv8h5eJipgvErqojK13HFUpL1YXcfehu3HFM65YhZfS7Xj68NPxYPFBPDv/bPT09Dirf2FhwV0K7V41Pz/v4h7ck6WlJQe92+5PqoiVpNPMOpaqS6gGVXSkOkKCsiPZgTsO3YHyCaulIi1RiC+foiP8aEudqvFneQtRnsMDUw9g+//ejs62Tpxz9Dn4X8/5Xzgqe5QTkopGaLce/UpkhcqY3gEVsDap0FQ2VcjNWN0WtkwkElhMrOTSb+vZhlzbSucjsuoZX/UVCtG+wvQcGIuenZ1137MoB5UvBV61Wg15QuuNc3eei2vvvhYPTD2A3d278fPJn+Pu8bvxZ7v/DJ2F1U5HzJunsFRInMqYMXANYdjqXrb4SpRhY8iE84NK4Pa9L9vnzVzg3eMz4LniveU9JXrnC29wzZoOyf1vZO6L1UX85NBPcMW5V6ymPnV04mn9T8NvCr/BKR2nuPPK58s7YHPlgyBwYQ6FqqmUafzwPmoMtlmFlkgksBws457Re/AXZ/3FaqpYZwZndJ2BA8UDOCF7AjKZjIOd1RFTnoblydBpof5RJzKdTocqJrZaD31TFbJajBpsp4BUWj0vqm6AlvZrtaADsKJ4zt5xNv7+xX+P43uPx4HZA/jAbR/A+Teej2+8+BuozdVCPV+Z1F8qlZwiJlyXSqUc8aG7uxsDAwMYGBhwvVU1jtZKipCOicoEqkEV2/LbQhd/ODeMB2YeQE9Pj4NUeHm1KAu9CF5uPhcg7MGwew9zG62XFgWK6erowjk7zsEHb/sgdp+3G33pPtz4yxtx99jd2JnbiZmZmRBr2laNsjFkfqXHrLFjjXuzB6umh2gsv5Hxuzt+F5/7g8/hpMGTcHD2IN77/ffief/wPNz26tuQXEq6mDANOiphnqPZ2VkHrROqJhqRSqVC6ArDHIxntYJKAOFC+7Wghnf94F04e9vZOOOoMzA6OoqBgQEEQeAUgRoXGr5Rg5TzrtVqrouWQoE2LxUId+9pNFzwjme+A7MLszjl705BKplCtVbF2896Oy4avggPP/ywCw8BCMUk6xU1YSxcQ1M8Q+RUcM8ANH1nVd5R1tVqNQQI8De//huc0XcGzjjqDJRKpTX9gVlMRtEfFj1hSIExWnrJ9JzV8KJy5HnSbIJGDKHJuUlUgypG8iMhuT2YGcQDMw+4OSurnrCuGvj8HtPlKpWKc3AYJgHg9t2icdbLbGRw7tvy20KhtqHMEPaW9qKnp8fl/pPIxbCjnmvWyvAZBnxfzb/XPHwa01HmD2yiQtY4lyV1EV7h10wm4w4WgBCb0Lbga1Uhn3fcee5SHpc7Dp9/4edx9j+djS/+8ot4Vu5ZDpqm8KS3TAuaa9CC+l1dXejr68
Pg4CD6+/vXFEhoJci/3v5aIgGFq8JFGnPlReGgRa3wtBYfIOTui383ug5FNK678Dq88etvxCnXnYJUIoUn9z8Z5x99Pn42+TOMj4+7ODaFpLIwFTGxbGWNHdNQ0t6rfLHeb7M5mxqaOXXoVJw5ciZO+NsT8MVffBEvOupFrhKTesbT09Mhz1iZ+hSwhBq7u7tdP1vWKNZ+tlEga3tWAOAt//4W/GLiF/jmK76JXCKH/v5+ACtCsa+vL5QvS6VFj187FHGvATjByowEhYL5zAC4tS4vLzdM3rnx5zfi+vuux/Uvux4nD5yMu/ffjcu/czm6T+vGU/ueipGREVSrVXR2doaMH87Rxvfb2tpCRD9Cp6VSKWSk2bsaxQCiYQgA7/zPd+Kh4kO44QU3oDfR61IjbeMJGnaMszL7gcYG7y85IopcKJzN6mkagoqKrmhMPpVMoSu70uqR3i9RKYZfNM+dLHD+nGvTzBQidaoXrCG60Zm3P9cwIhn/qbYUBgcHsbS05OQmYWfNvqC+UkSNMDYdsGw2GyoZqyFKkhyjZtVsOmRN6MIShHRxLAPJh0vmGxWEQsVRCDnA2qR9Xsx0NY1duV14cPJBnFw5GYcPH8b4+DgKhUKI9EHvq7Oz0yng/v5+9PT0OKiagl89yziV8WB2EKlECmOVsdDBmV6cxrb8NgwMDLg4YCKRcC0KtYa1WtUKGTE1g11kuDafQo4aKtjduxtff/nXMVWawkRxAvkgjzd//80YbhvGxMSEK1Ch9ZwpcCw6oqlShEg1FmtLO1Ih1zMwGl1DEATo6ejBnt49eGDqAZS6SpiamnLtH4mqkJVPSN1CkBpDJMLCCkzaLMCX9x1lvPnf3oyvP/h1fOc138HRuaOdgG9rW+kuZQtU0NBhriZrMFMpcx1UKm1tbS7lzuaX0kjSVpKN3N+3f+vteMcz3oFXnfoq1Go1PKn/SXhk+hF85hefwQue8QJXhCKXy2FyctJ56nNzcw4hotHGeCfvvy0KoQrAR6BrZFgPOZFI4O3ffTu+/ei38dWXfRU7sjuwtLTkPExVpHRACO1OTU05T6utrc3tOVEvxvC1KYXCwjSUlCTYiIfsZEx5LKSUp5emMdA54Hov8+wo5K8NNVShkVinZ4xII7Da31kdtSiFToZyQ0glUphYmHDM/uXlZRRrRQxlhjDYN+jOIw0FVlSjYgbgkAglIFNGqgPGO8vSpVTGvkItj4sYskI/fIBUzBo/0bZ5wGq1Im12YCvxRB1KyFlcXMRUaQr7Svtwbu5cJ3hGR0dRLBbdA1KLiwq5v78fw8PD6O/vd3AFe5GqMo4CW9Qb7al2nLX9LHz74W/jD076g5ULk0riBwd+gNc9+XXo6elxB2BxcdGxMTUewj3gS+Ei9rVlOzdrWLS6Bn5mJpXBYMcgDkwewO0Tt+NVA6/C9PR0qOg+yVq++LDGnNSrUU/ftgW0rQFbeS6z87N4pPAILtx5YYjtzS4+TJPTBgJaq1g9L3IQ6CGrJ6/FGOplIzSy52/+tzfjy/d/Gd+95LvY078nxIJtb29HPp8PVQRTg7VcLiOXy4Xiq7ynjMWR8EUlrR4Z00O0PGSjBnVlqYJkIlzeNt2WBhJAd3e3E+z0bpnC19bWtgY2VyNUkRVFYxgDbCU0xnkGQYC33vJW/OuD/4p//+N/x3G9x4W4E+q9KxmwUqmgq6vLxWgVGbQ57ABCrGwN7yWTSdf4pLu7u+F9dzLmkW/jJSe+ZGXvkwncMXYHXrr9pci2Z9154ftRNtPw4T7oOVWHCEBIWStLXAmNzdzRRCKBjrYOnHXUWfje3u/hJSe8ZAWFak/jjok78N+O/m/o7eoNkeQIo/NcaHEZXQcNI2vos8sZPWNm00RVxkALCrm0WMKDUw+6/z8y/QjuOXwP+jP92NWzK7RRPKS0turlJav3o4eVD7zVFKggCHDFt6/AC455AYbbh/Gbyd/gf//kfyOZSOLszNmYmZ4JwYuEuVQwqr
DXmB+FvcaN41TGHG8956249CuX4qnbn4qnHvVUfOyOj6GyVMGlp1+KTDLjBI7mgyqLURPweegoOOmxEbmwhdlbXcu3HvkWFhcXsTOzE/eP348P3vVB7MzsxO92/C4OjR9aQ4aip0yFrOxNCmE+F6Yh6NnS9VgorBnL+23/8TZcdOJF2NWzC/sK+/Ce770HqUQK5+84H/NT8yFPSwvIML2DEB2NCvVUuO9aDlGbhShRJsr+v+kbb8IN992Am191M7o7uzFaHkUQBOhKdznjxJcDzjQzxny1hKAqMypuKnGSw3j21KBuRhkDwEUnXoSrfnAVdvXswilDp+DuA3fjb+/+W/zxk/7Y9QTmvi4sLGBmZiaU9qNClueec6CyUwWnsWeVRc3ueSKRwJv//c34p5/9E770ii+hN9uLyYVJBEGA7vZuZNuzIeLb0tJKTXrGXGlc8D7wDDHnlXPm3yofhHckn887Y1ZlZyODMuasbWfhd4Z/Bx+7/WOYq87hJTtfgsXpRedMETJPJFYZ9sr14TPQoXLRyh9NG7J8lUaV8mVPvwyv/cprcdZRZ+HM4TPxsTs+hvnlebz8+JcjKAXuPJbL5RDDXdHTZDLpzjL1kqJvlCfkpSjUrujbESV13XXwLjzv889z/3/rf7wVAHDpGZficy/93JqNsorZ9+LkbbpCq4pYx4HiAbzh396Aqbkp9HX04Yy+M/Dx0z+OxEwC40vja6rp0LO3bFEtluGLe8QdM+Z45amvxHhlHH/1vb/C4dJhnDFyBr72yq9he892J/iaLXFpFZlNt4lDGQdBgMJ8Ae++9d04WDqInvYePHfkuXjFwCswfXg6xOZVsg2FCgUq51ar1RzsxAtl18NnpfBZlGezf3Y//vhf/hiTc5MYyg7h6Tuejm+87BvILmRRrpZDxWSoqHQdVFi0vhOJcBUqnafON44zdM1d1wAAnvv554a+f91LrsOlp1/qjAN98f61tbU5T4KGpq18BMAZzwBCys0K6Gbv8Mcv+Dj+8rt/iT//xp9jrDyG7fnt+O9P+e9461lvxfLCcigjQI1m7pcaAFaO8KWetC+TI+r45F2fBAA8/x+eH973i67DJadfEjKANHuBrOV6MgVYNTTI/GXmBFt11qtK1+h45amvxFh5DO/5/ntwuHwYpw6eis/9/ucwhCGMFkdDsVUNhVkSHc8RsIpsqaJWL3Q9fdDMeNWpr8J4eRzv/f57cbh8GKcNnYbrX3Q9RtpHMLkwWZc4xjXYkB73TZ1Je1f5Xj40q9nRkELmpGZnZ933zuw/E4XLCt7f19/j32vOmvZgtYQdDgpfZdIygTsIAhcX1UXzc/Xw8d+FQgHVahWfeO4nQvEjwo1jlbEQiYyEBSpkTb9RiInxMwDOo48iTNebu93PS06+BJecfElIqBBi1xrP6pkodAes1spVNqpvffT8CAP6GMobzZ2Ixwu2vwDP/sNnO+FRKBQwMTHhiFzK7rWCXfNDrdfCi6IKQONzc3NzISXCn6fTaZRKpQ33/doXXuu+T2+sXC5jojQRKhKjxee5BoUoOW/NkVV2ss21Zmk+m5LTzJmpd0eDIECpVAp5gvxq46w+opYl1nEtLCqiRpUyh+lxE3Ld6Ly/79z34X3nvi8kQ+Yrq++tYS2NgatyVSXLtSlxtFwur+mWtry8XJeYY/fdN2+777pO7ruuSau52SYGFubWHHiN5SorXgvQMEWKKVMbzR1YkTEXn3hxqPgHO1RpBoTGkOudeWDVKaMxzfnwPfSe6p4vLy87WV8sFhue+2tOeo3bU8bWdV81tVLPCj1kGg+2JoLKFr2rXBdJao3cVe8IGhj79u0LADxhXvv27dua+9bct+b+BHj9Nsz9iTbvrbk/9nOvNxLBhip7BV44ePAgurq6Yodg4xxBEKBYLGL79u0heGdr7ps7tub+2IytuT82w879iTJvYGvuj9XwnXffaEghb42tsTW2xtbYGltjc0dDMeQniiXy22R1A1
tzPxJja+6PzfhtmvsTZd7A1twfq9Goh7wVQ36cvbbmvjX3rbk/MV6/DbHMrbk/NnOvNxrykLu6ugAA+/btQ3d3dyN/AgAhFiHZmKVSCRMTE9i3bx8eeOABPPjgg9i/fz9mZmZCNVtZ2WdgYAC7du3Cnj17sHv3bmzbtg29vb2hXE2y8orFInbu3Onmu9Hcg/9i1ZFJqE3vC4UCxsfHceDAAezfvx8HDhxwlaTIhFSGntbqZoGNkZER7NmzByeffDJOPPFEbN++HX19fWt6IwMrLLxm5m73WAseVCoVTE9P4+DBg27u4+PjmJmZCZVvrNVq6OzsRH9/P4455hgcd9xx2LVrF4aGhlxNbla2IrXfxxxfb+6PPvoourq6QsUmyFKenZ3F+Pg49u3bh7179+LQoUOYmpoKNWDQjjBam5r1f7U8phY00QInPCu+lKdG5q6FZDj3QqGAQ4cOYe/evXjooYewf/9+11OaecdkXfb09GD79u047rjjsGfPHhx99NGhSmh6lpspcRj1zPjOEO+Csn2LxSLGx8fxyCOP4P7778evf/1rHDx4EPPz88hkMhgZGcHOnTtx9NFHu5ajrGCkVd7qMZWbvauaH62ldcnQZYc25u6y4hWrjPHcs547zxQL/ezcuRPHHnssjjnmGGzbti20Bk2TKRaL2LVrl5vvRvPmy95TZhqMjo7i4MGDOHToEMbHx0NFZmZnZ11t5b6+vtA9ZWEiFqZgLQSbzrPevjd6Xjh/sozn5uYwNTWFgwcPYt++fThw4ACmpqbWtFVkTvTQ0BB27NiBY489Fjt27EBfX58rrdloilCjc+c8ub+HDx/Gww8/jAceeAB79+51lRhZ0YyyXDNQKMtzuRwGBgawfft2HH300dixYwf6+/vR39/vZA773etZ92Wi2PPuGw0pZL4xS0Q2Omzyu9a9rdeS0Ka3MP+Lub9MyNaLYsvc6futN3c9YEx1YF6mrdTD99cE9uC/0lj4VR8E520rR1GR1KtE0+jcuT9KzWceo+4HBS2fBX/OVC1Wn2HnHFafoUK2/Vo3ymf2zZ0KWasSaXUlFkBgycN0Oh0qwcj91bKrWuzEdgfTAhysa20VXqP73tXVFaqFzHPCHMT1ChjYf2suvtbI1cYXUYrJNHNmfMMKW56NxcXFNVWT7P7pmlg4h2eKld70WUWdO8/v4uKiK/Ch51tLdepg+mRnZ6e75zwfmt+rX/XftrUhz4DO1zdvyg5NvdLUTp2/reDF+elZYVchFpDp6upy59Mq5EaU3Hpz950P3WfWZbD9lzWtEoA7D9r7Wavn5XK5phu9NDJ37md7ezuCIEC5XHbORS6Xc7XONQeca+Nzo7zh8w7+K3VT06W0NgLz4Xmv6+Uib7TOTSudqcrMfl8rAlGYas1ZLa+m+aiaB0mhaBV4s/NT74eez/T0tKtHTCuKF0kLOQAIXSAmyevl0+L2msOpAi5K/EOtbnqezIljpypWi2L1MVtkgx6CVlbSkqa+GsrNztUntGmkaIWq7u5uzM3NIZlMulKLeh70oiSTSecpcS+4x6wkFfxXfqatzKUCtZGhlats0xP15DUHVs+wekXM62Y+rhp3rbRsi2OoYlAkwzZ20dx1bSRQLpfds7SNYPiKui57V+fn51EsFkP1wrUtp+Ym01MjcsV8dT1P6nmzsxXLhVKwqpLeSM5Yr9gaozw7WnKVa9HqbooOaUVDX7GbVmudb7QWLamqldq0QQfrSzCXF1jt5qSyrtViNxsNfW8+P9sBjjKYMkjvLtdKec7eCrOzs66ynVZIU4NT63Pz85tZa6wKWQ+qD6qxiliVFgUAlYVWoLEJ6KrYrCfb7HztJZ+ensb4+LirRVwoFFx3FW4+P4+WlZYeBFYFBy+hnTc9cR7MqPNXRaF9eGdmZjA+Pu4aNkxOTrr2hj4I2MLAWgI0ropRqpBpSBE16OnpcXucyWRCvXd9Ao2Vr7QyFuHKfD6PUqkUauDgq8XdSG1uFU
aqULUzD4Wn1lu3RpoK+9nZWeRyObenFAqdnZ0hIXCkht5TXSeFLeFgaxSnUqk1BXvoGVOZtHI3ffPU8NLMzAzGxsZclzAtWKG1lXnXWJKyvb3d/cyW0SRMXygUXFEWYLW0I89towpZi+7wjiqsTkh9dHQUhw8fxujo6Jo2pMAq0qPVAbVqWiulGhsZFmq3RoWWip2bmwMAZ3By/1qtwBVlKMLARkXsspVMJl03KEXW7BliAZNyuYxEIoGFhQVXDIh16RUl0LOh5aAbHS0pZH6Y76vvUFpvUT1jWuVatowK2aeUfaXuosxfFTI721CJ8dJwXlpD2VZysVV0dF32RWXRiLW93txVmFQqlZB3Pzk56WJnbA9IyI4WNg8o4SNCp75ayq0oZF4MYBVF4NzpHROKZCcchdmp0IhgELrk/vJ9uaZKpYJEIuHWRzjYCoZmBKtVyNoq0nrJKrT59zp/KgeFrBmCadWbjDIsZK2tCblG7bbG/bceP+OxUWtXNzJHWznq8OHDmJmZCRkB/H1+5V3r6OgIxcjZCzcIAlfFigo5nU67c6slOjdqHWnhXVVipVLJIW/0iKenpzE6OorR0VFMTk66LnNq0CmaRGWsSFacJW59a9H1+DxjrT/PmtuJxEpTBhqdvprsm+0l11PIrMJGJE2fFc807zX1TKlUwsLCglsvOTh8Xw2dKqrS7DojKWQbH7EvjZtY78anYNUioVUCwJXhs1CghTFbGVYY0fqjEFL4nA+YD0KNDFqzVBAW3rPlK63Qjgq3q9dA65ttAAlXa4MDHk4eJOsdK1wdtQ1dvcHLCcB5G7wktDYzmYyDg7hHy8vLKJfLSKVSzloNgsApD0KQhLuDIEA2m0VfX5/zNhgDVJ6A7mW9od6BkokolOit6/uoMOC+UYnxOVHIc+1WkR/JsZ6HXK8MK2PpCuFzPzZjPTZEQwWn/AMa8yoQqdD0OdBDCoLA8RYopyiQqfi0peRGiJyFqe2ZIXpFg19DYzSatd+0bcJjPWRfjXaOuJVyI3B1uVx2DUkymUwIWbCcCw2BbZZiVoWshi/jy5QzwGrZY8oZYOWsUEfVajUXUqtUKqjVaujo6HDtXdnpic/EhlMbHU0rZB8MrYrGKiAqLIUWyWTmZVfomu/FOKG+v8Z94nqIlpxCpaHQKuENems27k3PUy+j3a9GFUCjwxdDVtiaFqvGOTWuSoYyWYLaHcqnjFvdbwsX61yohAG49ne6X6w7XautNDqg90JlzUuTSCSwtLSEdDrt4oq+9p2Noir1BCxffD/lFfD/CtfxgqvAJ7yr9X8fK4UMhI07Ww9Z7ynPPxWyvQdKxIvLO25kaJxSmx/wjGk8UGuEK9HKogQWkbOEMd/ge1ilZdE37Z89MzMTCn1Yw1k7DDE0oDHlzWxmo2tSrop9cb8o+5gpQyXYDDk0jqGynKQrZr+QTMq5cn3sS6D1qulFE/0CVs5QOp12bH7lLRC50PPSzD2IpJB9MWD1FHlBLWRDhcwDSctWBZLGkHUoIciyIeOIbTKmypSZIAicolDjQyEvHs6FhYWQQvYRGGwMxVpOUYldFk7Vlm08KLQEE4mEuyTd3d2Ovt/X1+d68DbTJSrKXnOo9dzZ2Ynl5WUHc6lBw/Ulk0ksLCyEmn/z/TRcAGANEmN5B81cEKuQKawppBkC4CUkNKpMWZ5ZhXp9nqdFTDYbtraolnq8hG4ZL9PUEGuUq+Gt+xunMrYQJPkHXV1dCILA3UFLlFOllUgkXOolACejFLmx8s0aSuutS+F8wuoaL6ZCZhiJ36cyVsKlpmUpimWVso/VG7d3bI01S6DjGeb9owOTz+dDKYnKum+Ew9HqUOZ/JpNBd3e382x1vhoKoVGv0Lv+Dp+98leUa8FnwyZImxpDtvCFsjBt1w6FpbX5drlcds3ox8fHHYtZvRd9UBqDUIsxDjIDLzhjj93d3RgYGEAikXBQonpUvGwkZ5C9zAepe8O5WWvd9vqMeoFsDF
nRB+YbE0KicuDB7OnpwdDQEEZGRjA4OIjBwUH09fWF8hg3y+Lm++ll4f4SUlbLkrARsNKM3RffBuAEpxqJPnZ+M7wDDQ2occn3pZfPPWWPWBXmXDNhUyo9VcpRDIVWhyoYNTQJBWvow+beU0Bbhc7324x1+O7q4OAgALjUNAuRan9dGkQLCwsoFApOOVMhU87wfPpCcY2cFyI2zOEmQZT7SZi6UCi4WCWhfkVbiKBks9lQyo7t7e1rkxrn0POhKCdDYczcUNaxytLh4WEMDg5iYGAAvb29yOVykdP7Gh3WeOvo6HD1BNrb20P6BlhVtsViEclkEvPz8+js7HROH/fBB92rgaIoRtQ7HUkhaxyHh0oZl9qSTq0pvkqlEmZmZjA1NeVIU8pWtpvrU8itkhn0ofGS9/T0OM+YD00vI2ENKj0eLLbTU/iLc+bP7atVhaceCoW8ZXAyHsWLwoPZ19eHoaEhHHXUURgcHHRWbD6f9+Z1xz18F4bzs54i9zwIAhfbUwa47qFVyJZcV8/jaWSfLQJE+CqdTrsUGR+bXtdD65uXWFvsxcmLqLcOu149P8oEJzmQ5COm4ug9PdJscIUfs9ksent7Ua1WXYyXpEtNdbNhF6avdHR0uJggzxb/vhXPTe/i7OwsJiYmcODAAUfY4r1UY1mfP4BQzjHDSkq81PoL9XgecdxbNUbVAWPMWBWycm1Ulg4ODmJkZARDQ0MOhVuvWEycw6IpQbBSfyGbza4hHDLEwHPBbBvb91sNT2vEatvMVjgUkSBry7jTh6SEIkIx/LmSYbRf53rEGFWaqpTjIB3xvWmR5vN5ACtxTMJHHHwAc3NzLtdMhT2LXTAGsZlxHauw9LLoXtOoAOBg1Ww2i56eHvT392NgYACDg4OhIiCbbb1y6IWhsPWR3ej5LC4uhoqCqHGj87Qwm+U3NOod2/cDVnPO+T70fAE4Bq+mutl+tnwvjXv7mNlxDhWsuhb+X2Omelfp0RG2ZuiD6wDCd9QaxnGHOmhUanUtAE7A1mq1EAFKQ1v8+6WlJQdXVyoVF9fkXSc3QUNOzQyNQRMJHB0ddVWsyJ1hCEA5CJovS6NTXzz763nEmyFnfB4y5YwNZfAcKFw9MDDgKrcdyTiynkuiUzQWVBmrEQXAebpWGXPo99RYsaHbZuUMRyQPWZWyWvwsSkHrWtl3+gAJZ9NDoJCzG6qxYlq/Psu3laGKgZcBWLFULQzHg1cvzmoPmC/ernmcmosaZej7q1CltUbhydiaUv9ZWrK7uzukjI9UcQqFrvkM9ADz31TUiohwqMFkBakKOZ+QbXZ99mIpn4H5hmqokslLY4KKWD1SzVn2oUOtDl+sV/kQKmjVqKYXZ9Eu9eSAsNBTRGizQh0qYLPZ7BrOgULUGs4CVlnuyWQSc3NzIeXQ1tbmDD818ux52mhNNtaqSCA9ZOV1KAGQSJHmG1u0ysZzeYb4O/Ty4+Qf2M+0pC7NcuBabEqlrSJ2JGUMZbuicZYHwT1cWFhYt7Kcpm7yrKjRp58bdUROe7KeiC/GoGxqKmRfoQz1DqxnbBduc0njuvzWq/KlKdnULY2TK5GNAoLf0z3RFAz92qwwtl6kL8ZpGcC86Ix1aMwjziIgjQ5+jgoQq5CplDkvPgvreWpMWlEUjSfq+dE5NDpPfgbf0/6tCn19LrqfzcYlow4rcGwcXb+nRXG0YpTmWvN3+awUufJVj4rbSwbWFnoAVtPnNLTl88BIwGM5UMs/4ZmwKVM+jkq9NakstFXdFCGkEuP8tTynesQsQmSLcZRKJRfj5DNW+RmXsrMyRuWfcgp8hr/mSlvv/kgoY8CvSyxKxLOjDpb+TFEMlaVK3tKv6jhEWWckhez7EPXUNJ5svWNbC9R6xz6Fa/8d12X3WX9aEpFCn19J6CJbkuQMhaGonIMgCDVRmJ6eRiaTAbBKQOJe8qBEmb8iFmrocE+VrWlTJ3hJeBiPBFTtG/UUMxmn/JklEyrCoqEHrl
MLnXDNvlzIjeamCAovH7Aq5JVYRgIa86p1LTYko183Y/jiXAqVKymFRStGR0ddgwMa1JqTr7Fa3Q+rSOKOESqiws8HEPII+ZxsKIP7Xk/BcqgyrleWcr3nxTunBjLPq89gVwie3mQ2mw15aITTiaBwbolEwt15Epa4T1pnv9WxnodMGcnP8qGY6xk2R2rwM3lWuC49UxzWqFLyF2FsypR8Po98Pu+QRqKNhLyjyNOmFLK1Eu3LJwAU2rDKmC99T+sRq7LwKeOoD9gXH2GuICv/KOlGIWFlTBYKhTU1dJkLy1J809PTyOfzSKVSIQNE19eKQlbDRlEHfoY25VDBuRn5xq0M+9lWoWnKGfdbU0UIlbGAPV+aB6mNKTZaqxqDhLtY95YhDfW4SRpaXl5GpVIBgNA5t16Mz/iMa9j4mIaPeE4tl2NmZgYTExMYGxsLlVtVQ1rnzz1hao6yf+P2ktWzofDnvzk22keL7NHI4/3RdSn83axCVoRNwxIq7+hNktdBYiULapCvwoInxWIxBA+T7Kg8EcuriFMhKwJq67hrHNwqYd/+6zhS8sZ+jg+JUxmjL+4vvX9Vwn19fejr63PpozSqrJxpdJ2RPWT7APQiaF6iJtX7SDv6nqqIfZT+OKEYa/nZmI+meqhBQeWrVbCUXadxNuazscayLcenXX6iQpe+uBK/WqKIesWbwczczEGY2sJmrJBm2alakID77IthNaKUaTgxBs/CAITKKXjoLbCiDxDOkVZPwoZg4t5/i5xoa1EqYcaI+T3WVmYJR4tqkayoZ5gKxe5xXBwPn+C0cL+PuKaGvqIrPhY8Ba6GPup5yBvN1SfnOBT2BMLGsjKpgyAIVUljfWiWcuTZ4vwTiYSrKsYYaRyx5PXCkspTsYaahX35txoWULlzJB0Bom06dJ58aTaFGvw0PqmUaUgx15qx8qhM8qYVsk8ZW4vSl3dL9jGHPiD1NLS6jo0FxuHF+YgumsJFoUSBZMsHkkFJK5ENGzQ+B6w2nqDFxd/n5dOSbK3EElXwasyD76d7qsaN3QNebP4Nx+NBUasAVguWwsCeGR87NUruup519QZpHPB3GHJIpVKOsAdgzSW3cc64aoWvt28aQ6ZXTD4D48Q0LMn7oBdt08U4LISvyqDVGJqdv72vGtvz5Xr69pH3gyiYolm2YhPXp4q50edj51FPufPnSgrkXmpDAnrB5XLZnXueQ9ZmZxqYsrY5jzidF5t/qwV3rJNl4e1KpYJkMumY8Ho2LBfoSMoba9xZA8IW7GF4geE/bSVJJE4LF206ZM1hBYsKQW2lp/EmpgURdlEr1bINte8noUZtDhAHTG0PmK0FPT097eJnenktqUELQeiD42XxMfZU+baqiH3r0hi179DZ6mmcXxAELjYXR1ggjlFvf3yCT1NeLKlLWcDNxI9VIfOcA6vesIZrqJQ1fKMKjedhPaJLXHvdiLDR869Gji/jwbcnluCiCrlV9MV3pjk3WzFN716996JCZqVAIlzMo1UeB+WbL/653jOy+6KhIm0YoRwPNZpUCShXgk0wAITCNVreU3Ps4xhWRlkvWQmk/H4ikQiF/5jPnkiskOmorMgeV2dNvckjDWPr5/nkqsoAPlf70g5crWQbNB1D9l1Gn5ufTCaRzWbXFB/XlCclQdkYIIPkDJRrGkArgssqYx5wVcYs9m7rIdOYUCvResW6L4TxaElpHDMuQaxCV71kjXPq89HSiEwdqVarIcOnFUZy3MMqAz17ZKFaREZ5BzY1odk1qPHJwiWM8ak3wjMMhMlUvBOcu57xejB63Hun99YX57NIF40Pm4qlwkmNFJ+BYfe7mXXZmJ4Kf72z6slbT9m+HxnwlUoF09PTOHToEKanp10dBKY+WS/ZKuON9pvhNiJhvb29qFQqoXBFqVRyKJV6wXRcqtWqk5lUvEEQrmSnhrddd1Qjv9FnYVEqfiXUXiwWMTk5iVwuBwAol8uuKIjlGuhLnZfNljU+yN
w6fCr/1KhXR9RmF7TKx4kEWSuEpwdNlVFfX5+Dw7ScoyrnYrHo3g+AKxfHwhV8EZtXXL7ZBfsC96qMOcepqSlXa5apCloAhIqNXy00rKlFGmPo6+tz/+7q6vLWZI46uC5VysrKJFxpYXn+jMYFCQmaZ8242mPpLVv4Tz1fjfH6LkYrRECf8QnACW/lSlCQE4FQJGVpackZk4w/MT/TEkD4ua3uswoVy4peXFx0BTGsd6UwsAopIkHcT30Omvah4aVW0Swazsrl0K5JZIITvfLFbe29Z1lLGt3sna0FRrhvzShjYJX0Q6Swr68v1Jtb7xS/T6+dXiSVLgmhLE1J71O9eN9X++9WhnqKGiqw0LWGj2g46Dp7e3udc8V/0wGgktZ5E9bfrEFjSM+3vqxxSqjd1/ZyI1nT7GjaQ1Z4E4Cz+pVt2dPT45QuU37Y5YSNrNPplb6ieomokHt7e9Hf34/BwcE1CrkVBWYVssbUNK7Gy07rWWFHJelo+gKFthI0enp6XKWagYEBp5R5IK2BEXVNPniPQpVrVdZ3JpNx3iXRCkKwevnqWX1HOtZTDwpkS0l6JRZi9AnTKB4yzzj/n06nQ54Cz68Kfe4rURQaC5lMBl1dXeju7nYknmaY363um9bZtaQVC93ZZ80GDupRr5f21Oo9VehWe+9OTU259KzJyUmUSqWQl6bhGg6eaeWLUNlx7VqdT5GDRgUtlS4rifX397s9410PgnDbUC0aUyqVnAKgw8DYMPOurXHaTFpW1GdhIWtNr9S1AKuhHHr+MzMz6OrqCsl0ltHs6uoKoYu6hs32lK1nrGdav6bTaRduUoO/HgO/1WcQOYZMxcjJa5uwrq6uUA9QKgCFAJiOof0nLdRD5ho9ylbYmz7F5bP26hX90DixxuM4F2VMUuD6vGNbpjIuQWyha8Z06N0oFFkoFNyeq2XLQcXia/F2pBSz/QwLHdMTIRzMn3HNcXj09sKqEapWtnoONs+XeclEK5SlqSSQuFjJdu6KIHDPLIqiawLCefIcyhy2XrIS6nyhmKhrsvuqaBZzpsfGxtY0qFEIV6FsG6vlvabXo6mCfObNrEPPJzuqEW7mOSEpVOdLw4h3UkNMJANqupQiHrzXrSIS6w1L2LKkOk1lpUKm4i6VSuju7nZrtoQp7pstzMG16tc4h4Wp1TumzON5Jvq5noEW1xwjQda6IBK3arXammpQLENGa0nbVNGS5qECEDrMbK1G5RW3F6EKTNemF5EPgJecUCUJIDYFhJaxKuOenh6Xs0aPaDPWY9dklfPCwkIo9so16GXi75OAoWQFG5eNQulvZVjiFj0y7p2vchb3JK5hPUgbl9cysSQL8bwkEokQnGkJixpzjVMp232zSktDTjwTVNaKuCjxyXoUqojjZFjzqyV2+QqbMP9bkSuFWfm87Lq4BkUErPfZKNdDFSWfs0LphKJZsZB7qemKinBxrjwbqiw0TMDP3Cwv2T4THfbZKIJKbgXXQcVGY5bvx7NJg4Q/22yil/WS1fvl/hI9VGRWDZNWs2TsaLlSl8YZlWDDQ7S0tBTqTqIEG2tpqKdNocW8rlZLO6ohYeNq9MxJJiuXy6jVVlt1UbDqS71QFbSaLN7X1+cKq5OcFtd66q1R34cWq8/q5GXRPGsiFqow+NJnaD2hOJUeh89Q0hhdJpNxSoJeBNcW10WxqIoiKhqPJ/eAX7W6FVEUzp3nTA0zreoT59B90xi4JXdpyo16/OoJaejGd4cU0osTQrWejE8pEYVQ2FrziwGsWZOeDWu4WA/JGhj11mTPqY1LqzdMGNumYXHuvj1QwpjKRx/aFudZ4nv5wkAqTzh33klFNxguY7YMlTgAd2aAVXTAR/TarGGfP+VLNpt1xh/XraEeNTZ8Dl6UEUkh60L4VS1aAI6coIdZY21KRLIHWYs7WPZvK3Ck3XiyZilwmG8MrHSRUZa1MsJtBZd0Ou3IaAMDAxgaGnL9P0lM0yYO9dJDoj4DFV
p6WXjgWU2HwpXeBRmR2kBdc+pIxKBnr/Wv+dw2c1iBSQSFAom5wEpEUmJSXErZEgEtOW5iYsKRhPj/YrHo4DvGoBWuVuPMKmQ73yj7bC1/EjB9KEM2m3VtCZPJZGgPlYNA41sVpE8hx6EU7PyVO8C7RgWmnowazbbMKh0EGh1W7ihpR/Oq9fmstyaVL5Qt6hkq1J/JZDAzMxPirbAYUSKRcB60fU8bEtOQniItUc+Nb00AvAaRomZWSeuLhkilUkEqlXKIi+aAU9ktLi668pP2vTZr+Awe7nG1WnVcFcb/bftURZ0AtCQbW1LIdlEcvgtp8yB9aQpKENNaxHFV/7ECit+jscBLmslkXPtIvqis7cOo1Wro6OhwxIWBgQFs27bNKWRC1spg9hVPaFV4aSxboR8VUsq4pgCmYKbxQ0Wh6yGxjoqE+8SLGEfuo1VC1uLUC8OQBhUe00oUlrRnLIpS1vdUMpCSFcfGxnD48GGMj4+7pgwUrvTcEolESCHb9A8Lh3F/+be6B/p1o8E94zrs/VKGPXtNJ5PJEHFnYWFhTTMGhVBtCkhc59pnPCuxkqhIJpMJVRRTshHvKwmbWjeAe8q4vr5sk4dGZY/OmdX3dM+p6ElcZZ7u+Pi4M4ascqNCVySPPc3JTYmDY1NvPb5noeEJX4qkoi/8Wa220n86CALnDDBljc/CFhqxcfw4kTif0asKmZW46PGTAEiSnvaCttUoiRpGeQaxKeT1hsYnaU1YWFEvuM9KtczNKN6xHix+Tw8S59HRsdIbmQQSVjRi1S6NMQdB4H6/p6fHwdRDQ0OhPGqFJync4jAwrCehFqvuPWEyErzm5uZCRC8eRCqN7u5uDAwMhCry2LrR6jVFGWpRWoSFiknPCb0LFXALCwtr8jPtGbOf2cx+q0JW6I3e8fj4OA4fPozR0VHMzMyEuvowRqmGknpiCscpeWc9b6PRmJo9F5wDWaMWMSHfY3l5OdSdyOYXc54qcFU4xxFDtneVZ9Q+S3qaWtvAFvspFovuvvO9uX6uQxWwVcbNGM/qaXGoQcEXlers7CxyuRzS6bQznpVsx+/x81VZMG3OtjaMO/ThexZcj68IhkUdGKbkeaPxpGVvebasU8d7QpnTylAZ5ZMxaqApSsszRZSRBoSSf5UgaFG5Zp/FpilkOxFLNtLLZR+mj1gVlycJYM0BovWlEBMFpxJd7APk39HqpVJmzLi3tzcUK9TUHBWuraxFL4uSttra2tyBtyxq3VOFBEul0pqUIp0nCSaqvLUwSqPDKmJrnNm4rU0947otnFWtVh1Bx6fkmx06DyVwaTczejmMIavVrMaevqd68YRQaYn7YnTW42iUUKdwI9/H7jnjdYzr0VO2ncAs38Pe1/UMzFYUsypk+xwJ4yqzXSvqlUolB5HSewZW7wPPvsZjtTOYr55+I3NWo0mVmBrAVPrJ5Aqpi4WJtNCJciM0bu4rNhR3DNkqR5UTtqyxZXnbs8qUQYV2tdiJEoD5OcyTZ8pgK6EnK1d8MkbDXFyzIo80juiU2Ewcnj8aFmpsN/MsjoiHDNS39O1Q5cf/xwlVqJDSz6SS0fqlTM2ym8sN7+joQCKRCEGRfPGy2A44cbJp1SJXQhw/TwtW2H2k4UFBTEuch4vxRgpmrl+hV/WqGhn1FLFCzGr4aHxfPU+N5ysbVT8jruG7uLasoXZNYiEQ7h/XQyuav18sFlGr1ZyHbxWa9TYoCPlsG12r9bR1TTQOaBBYBbSeItb3tIZEHOfbemYUyvpz3kGt4kUBSQSoVquhUqmEFBaFJhVjPp9fw5XwVa1rdE0qyNVgsf8HVohohULBeWQ0ii2BTpWgLdsYZ/pkveeg6KXC+zTYuNfAqgLnmboM55IAAEpgSURBVFUIV41v5bFQlvBZZ7NZl6fcyp22ssa+NBylPQt8VRl9CpmoaaVScZ26OKI8k01VyPowFfLQQ6lwsXqpPlazKupW5wWsKmWmbKkFyhQFHiat+U
shRo+XcVe1sn2K2AqtONbBPVT2MdMtNN5GJeEb6p3yAtAq1xhnEATOeuU629ranOex3rBWqaakaOxP44T8P2Oc9EJZ8Y0K0MYFuTdxCSeds+aw+ioWKdmPe8sLz1ac6XTapZZZD8MKbiXz0NBTVKeRoQpUIXs1kKwC9u1Bvffk17j226IDRD2AtQxxKi/uOw0lxmQXFxcxOzsbChNZLkJ3d7drn8cSj5ruZz3PRpVyvX3hPVheXnZGu5bVpeJlyEPhW+ud1qsRENdzsEYh01IJ52oYT3+f8+JXZZqrkUpju7293c07nU67BjyWMNXMUJmmRr/+X88Mw5MMUVLOaLouzxn/T6RsdnbWIQG6d83cU6AFhXzr3ltx9X9ejbsP3o1DpUP48iu/jJee/NI1B8LCHbSaCN9YyJHwklY6suX9NO7c7Dj2Y8dib2Hvmu+/4Yw34P1Pf/8atii9GO3tTHjRWtiqmKPEoBoZB2YP4PJvXY5vPvRNVBYrOKb7GFz5pCudkuzq6kK5XA4pN029AMKFXVQ46/4mEgkXf6MHRa+aMbBsNuuUy0bjmruuwSfv+iR+M/MbAMCTBp6Et5/9djxnx3NCRDlVzKrwtF8ve/Uyn1PZzOpp+va72b0PggAfufMjuPnXN+PBmQfRnmjH8Z3H48L2CxEsB2sUtKISylKem5vDzMyMUyAzMzMhmM4qY43NMqbf39/vnl2jsft695TDQnM64s6xbGYkEgkc/4njvXf1dae+Dh94xgcc0cuHXJC0yTg4zzv3mRUFWS+A1fR6e3sdy9dCwc2GmLifH/rhh/DO77wTb37am/HXv/fXAFYJlzSilTfDgjdU2hrn1lCCspzj9o4PzB7A5bdcjn9/8N9RWapgZ34n/uLYvwhlYPBc0zCmUcz9UoOGoSbNhqCXXC6XneMDAJ2dneju7nbywEcArjeuufMaXHPXNU7OnDJ0Cq58+pX4vV2/F+q1oNkyWsRqYmLCvShn2A2tXC67sJh69zMzM8jlciHugL3Tjd6hyAq5vFjGGSNn4PVPeT1eduPL1vxcLQTLxrRMYx4k9ZgopOfm5pDJZJzHZ2ONzR7CO//0TlSD1U4r9x6+F+ffcD5edtLLQkJQFRGhxnK57HJfeTF4OMlAtpZunMp4em4az/jMM/C8Y5+Hr73ya+hu68Z9B+5DbjGHakfVKWTGf1VQ6YGgwOd8fNYrDxznzd/p6OhAb2+vI2dQwaw3giDAjq4duOp5V+G4nuOwXF3GP9z3D/iTr/0J/u2l/4Yd6R0h1qLCQ8oCZp7v5ORkqDoT18S95jOMy2u4bf9teM1Jr8Hu9t2YnJ7EJx/4JP5m+m/wxuCN3ti38gs05WNmZgaJRMKdaSXhUNirgFWDj2k7FOA+kpNvbHRPo4x6HnTcoaUfv+HHWKouuX382djPcOEXLsTLT3k5stnsGm9HWeGJxEqKE/dKzwK9497eXleet7+/39UMoEKO4/7eeeBOXPuTa3H6yOlOFjJ2amFo/T9zxqmQLUztYzjH5R1Pz03jmZ99Jp57zHPx5Zd/GflEHv9n//9Bdi6LxGLCKWPKCBouGuqiIlajgevRLAKmFSl6k8/nQ4V1mjlbR3cfjQ/9/odwfN/xqAU1fO6ez+HlX3o5vvvK7+LY3LEhNEvDTuxORUWspZ61uxbnojA34/+KCvB5NktIi6yQLzjhAlxwwgXr/o6NAdkYpy0uQYvQeqTcQFv1hQ+pmUM4lBty/w6CAB968EPY07cHz9r5LEdpV2tTg/l88YJrXqmyHeNOAeH48G0fxs6enfjMH3zG5fINtQ1hbGwM+4v7XdxFe9nyq02lsSQV9er4u7a6V2dn54a9cuuNi068KCQ833XOu3Dd/7kOPz74Y1wwcoGDfrStnL6URMV4rcZqNeSgrOBWmPkc//KH/+LSnPqW+/CmHW/CG379BhxOHPbCmD7kgU0NADiloYaRnhEKW0LV8/PzSCaTyOVya0pEbjQavaf86i
Np6Rm2RJX1CDOtKujh/HDovf7f2/9f7Onbg9877vcArE2l5F7WarUQCqdESps+xFoBfX19Thn74OooHmhpsYRXf+nV+PRFn8b7b30/kAg3+7AvVWJkwwOrTSt8XnHUua03Pnzbh7Gzeyeue8l1zhnpT/RjfHwco4ujzkDU2vj06hn+0zkSslYuAI143g8SoqgoKVuaPUcXnXSRe99arYb3POs9+NTdn8LtB27HyM6RUPqq5YCQoDk9Pe2ga3bc0pxpevxqBKqit4zrZsamsqzVO9YYpy+tgApB8z2LxaITSvl8PgRd20MY5UAuVhdx/X3X4y1nvyX09xpb00uvpAVWE2PqgQ/mstZrq+Orv/oqzttzHl5x0yvw/b3fx1H5o3DpKZfiwm0XhqBz3SdCqQCchcdB4UVo3q6b76HKXUlUzRw2jR8vLi3ii7/8IipLFZzac2ooFmN7vTJswX7VvFCcG8+XMt0VqWiGIesbGhfjZy0mVyD6rrYuLLYvhjwb7huwWhNZ42UAnPCxxoJ+Bo1X/k0ul1uT89jqUMNhI+OZ/6Z3p1Ac70e9sxE1vMR5AcBSbQk3/OwGXHbOZWhrWxFbNtuBe23nAqym0QBwhjQbwGh6IuFjX2y22fGmb7wJF55wIX7/uN/H+299PxJYm0lST7mqV6noj/09+4pjfPVXX8UL97wQr/yXV+LWvbdiW24bLj7xYpw/fL5zQHg/S6USurq6XOnSanW1EYPdNzXuGG/lWtXQa9WABv7LkVpewj///J+dnCkWiyEZwxcVMo1uTXNVZVyr1ULEMxv2s58fZWwqqYsPgBeaJAoKVwpNrYSytLTkNkc9BY3Nas5aK8L2K/d/BTPzM3jNaa8JKQwf61e9F1XGcTfBWG88PP0wrrnrGlx2zmW44hlX4I59d+Btt7wNtXNqeFbXs1zcRQUgBVI6nQ4VM6HioFKud4Bs6MEnUBo9fPeN3ofn/MNzML88j1w6h2ueew12du7E1NSUy/mmRapsR2XOUhmrt2P7Z7OQCQuxtPIsrKJKt6fxucOfw4mdJ2J3ejf2Zfc5kou2ybPpE8xhJIRNYWtRC34OU8oSiUSISxFXBbJ6a9QysBS8NJBZ2EHJmUS16C2o994K18PO7+Zf3YyZ+Rm89imvDUGARHMoO9S404YGmjFAw7W7uztUuMfm87Zyf//5Z/+Mnxz6Ce780zvXrIVf9aUK1mZiWKLfZihhHQ9PP4xP3vVJXHbOZbj86Zfj9n2344rvXIHgaQGe2/tcd45JbqKRTFSNclINNj0LyphX9vZGlRkbHfcevhfnfuZcJ2euff612Jba5trq6tng/Ontz8zMOH6KxrCB1ZrbdCxt8Sqb2hpl7puikPUgqQLTmtBzc3OYnZ1FNpt1ghZY8R7YJ5lKmsQWWrBqfbVyKD/z08/g/OPPx1H5o9whU8talTHXAiBEsmEMigxNW8IuzktTC2p46van4oO//0HUajWcPnQ67hu9D//8wD/jRc95kSNX0QBSC7y9vX2N90lhqsaIGh8ayyRk5mN5rhcjUcF8Qv8J+OHFP8RkaRJf/tWX8bbb3oa/e9rfIVfJucugl0XriCtxSmFq7Z9NUo7WELfnpZmh55hw+Ed//VE8Ov8oPnjCB7E8tewEOwkfqVRqDbmLg/P3eQDWO6aXTRZ73JkGdp1coyItXANRK608RoGrvA8bxiBMGce47qfX4YITLsCO7h0AsMboZJx+dnbWxf7IyKcBlMlkkEqlXI0AVrqil+wrsBHl/u4r7MNbvvkWfOs130JnW+ean9cLbwBYc5d8Z2UzlLAOypkPPP8DqFarOH3odPx87Oe48eEb8ZLnvsTJBgChOgGpVArz8/NrYsSakqgGh/KK6KzxebTSlvSkwZNw1xvuwmR5Ejf+7EZcdutl+OjpH0V6Jh0q4KNlkTWWTIY1UxgBrMkBJ7rCs6NGncqcZue+qZC1KuRsNhsq7l6pVBxmr51aeOnppQbBSo6wFt
pQr0chnmbG3pm9uOWRW/DFP/riGthNU614iLiOZDLpCoCoQqZgtl5ZnOOorqNwytApof190tCTcPMDN7uSloRvFa6lsC0Wi64ACNdbL+bHPVX4UkMNSkJpJA85CAKkk2ns7tmN7Z3bcXzueNx16C7c8NANeHXPq13KgVqnWitW49S0sjOZjDOMtm3bhqGhIRcL5EXRy8G/bXSoYZJOp/G+u96HH4z+ANc96zrklnIYXR4NKWTGMenF63lX5MW330o4Yh1f7q0vwyCuYUNLWtrRhRj+i01LDxlAqAITFbamfcU5170ze3HLw7fgS6/40pqfaQySxD/1hEqlUkghd3R0uIp6tiWqsphbgarvPnQ3xspjOPNTZ7rvVYMqbt17Kz7x40+gcmVlzd/4FK79bAvnbpZiVjlDQ/GUoVPw1Qe/iu7u7lCGDBUyz75ts2j5JRpX1rK9DDWR+Z7P59co5EbX255qx56+PTg6ezR2n7Ubtz96O2546Aact3Qe9u/f70ih2lTHlltVZcz1ahVDKmEadtpASAmBzRp1m66QFYLTvFJaEkydSafT7sHygvEhd3d3r6kdury8HCJsNHv5P3vPZzGcG8YFx18A1Na2eeOL70njoq2tLdQZykJeUQ9RI+MZO5+BX03+KrTHD808hF3du5DNZgGsNkknkUK9SiCcF6tzs3unEKp6bkow4ddGhw0L1IIa5qvzjrTFmsOER23aAxUWP5eGnjWOtEEGkYKowpV/d/n3Lse//ebfcNOLb8JI2wgKhYITIvwsbQighDoAIZhP6/UqIkEFRwNQlZtN/YhbKfN801vhZ9PzLBQK7vv0fNVDJl8hKsdgvcG7euGJF4a+r+dJc71nZmacQlY5QiNOuQZaO8DGjqOO5+9+Pu77s/tC33vdza/DyYMn4/JzL0cqmUIVa4mQUT8zbsVMOcP3TSaTeKjwEI7pOcal9/CssiIa0554lllqUmUPnSZlIpMfxMIs2i8+6t3Vc7G8vIzl6jIqSxVMTU9hbGwMY2NjzgPWUJCmRCl5mHKfd4Nz1jNkWwVHReUiK+TSYgkPTj3o/v/I9CO45/A96M/0Y1fPrjVQL60h5t7ZnDoeRoXJEomV4vGMUVg4LKoVXgtq+Ow9n8Ulp1+CtmSbS61QEpcPslY2r7bo4kPazIo5AHDZOZfh3M+ciw/84AN4xZNfgdv33Y7r7rkOnzjvE660ID+XYQEKHK33S6HD4dtDXixlfSqZQZV1IwSjd333Xfj9Y34fI50jmChO4KZf3oS7Ju7C+09+P5ZKS06garobvS71ymhxq6FHC5uXmrF8FbBRjCP+/lv+4y34wi+/gH96yT+hN9uLmcoMSkEJaMOauusUQPTceKYAhNAX9ZR1r+n98x7wpc+pmfO+0T3Vtdq7yuIlSlZMp9PO49QYYT1SV6txZN7VS8+4FG3JteJKBa/mlNKoU4IdkR4lliri4ysTGmV0dXTh1OFTQ9/LpXMYyAzg1OFTYyHkbeawcuaO/Xfgunuuw99d8HeuOmEQrKRGarMUNmDg+QZWEQwbC6fsUORN47G+tNhGxpW3XInz9pyH7bntmChO4B/u+Qf8dPqneEv/W1x6EkMZnK/WEFfj1yISlIW2EYlNdW2FgxBZId918C487/PPc/9/63+8FQBw6RmX4nMv/VxoIUoC0hxRH4uRcc1kMrmm4IIq4lYs8FsevgWPFh7F63/n9aHvaxzVd2nsYdIEfZsXuBkK+Wk7noYvv/LLuPLbV+J9338fdvftxkde8BG8+rRXhzwpm9do56gCZ709VLjFPi/72miMl8fxp9/4UxwuHUZXexdO7DkRHz/749gd7Mb+wn5vwr56iJyzwkD6HBRW95UTjPo8EokEPvWTTwEAXnTTi0I/u+KkK3Bq26mhc62QXCqVcsKIgwLKKlv1HnyhhKjw70b3lGsEsOae2vQc/oyMZl2TvZdxwdX17qoONQr0/NDTUTKi3ll7b+NQxr8NwydnPn
reR3Hx6Re7Z6x8ErufqkT5bEgABMLGvhqBirpFlaNj5TG89ubX4lDpELrbu3Fiz4n48GkfRn40j/sW7wvlHhOWrhcS8s1Z77bvjvBnKjtjjyFzUrOzs+57Z/aficJlBe/vz87OrrFcmUSu5Q4VFuOm0LpS2E4D7qVSKZTLxpzOVCrl8jxVEPjmfs7QOShcVkAQBI70QbadtltUti8vO0k7ao2zKtDS0tKa+Hajg/PbaO7P3vZs3Pbq29zPgyBAsVgMKTOduy1LqXCMClCbSuOLq3O/lcmaSqVC1rBv7kEQ4OpnXY2lc8I1oAuFAsbGxlwoot58NbneFjzRPEA+Q8a2fPXDm933IAgw9T+mQsKe0OjExAT2798fMhzVyrYpc/alIRHOT2PNdo2671r9aL0z08g91Tistpjkc9KOXxousnfACmQ+Ax8E3Oh5512131eoWu+thjsY86ZgTSRWOrlpHeJyuezOcKMwo527b952fPVlX3W/w7g8521lonVC6JzofpdKJWQyGdRqNcdtaUTwNzp3lTMcrL+ubHYrZ+rJcz5/Nfp8Ffm4PhrWPEdtbW0ol8sbzv2jz/soas9Z5SlNTU1h3759uL90f4jApWfahoUAf5693lvVa5Q7NKY1rZHPxXfevSNoYOzbty8A8IR57du3b2vuW3PfmvsT4PXbMPcn2ry35v7Yz73eSAQbquwVItDBgwfR1dX1uIZygv/yFrdv3+4ssq25b/7YmvtjM7bm/tgMO/cnyryBrbk/VsN33n2jIYW8NbbG1tgaW2NrbI3NHfEmym6NrbE1tsbW2BpbI9JoiNT1RIEGfptgMGBr7kdibM39sRm/TXN/oswb2Jr7YzUahay3SF2Ps9fW3LfmvjX3J8brt4FctDX3x2bu9UZDHnJXVxcAYN++feju7m7kT9ZUS2FuYLlcxvT0NA4fPozf/OY32Lt3LyYmJlAulxEEgau+xAoovb29GBoawvDwMIaGhkLlM217tGKxiJ07d7r56tz37t3rqimxuUWxWHQdPsbGxrBv3z7s378fhw4dciUcNRWBtHdb1IFFB7TqTHd3N4aGhrBjxw4ce+yx2LlzJ0ZGRlx5TU0TSSaTKJVKdefOfbd7ygpXk5OTOHz4MPbv34/Dhw+HCqhrGUpbTSmVSqGzsxNdXV0YGhrCyMgIhoeHXW1u9ont6+sLNfbwpbBsNPeNzoqWaJycnMT+/fuxd+9e7N27F6Ojo65ZOOvQsg1mX19faI+HhobcnFnSTuuLR5079z6QFCFtC8nmGNPT0xgfH8fhw4cxOjoaKuHI58G0PrYA5FxZbWzHjh045phjsGPHDgwODiKfz4fqo7e677oOnqNKpYLJyUkcOHDA3cuxsTFMT0+HCiksLS2hra0NuVwOQ0NDOOqoozA8PIzBwUEMDg5ieHjY3VVburTRuT/66KPo6uoKVTbj2SiXy5iaWqm4dOjQIVebmL2yp6am3Fx5T7Woj1ba02p7g4OD2L59O3bt2oWjjz7a7btWi+LZsXNvVj7aO6w9dXmGRkdHsW/fPre+ZDKJ/v5+7Nq1C8cffzx27dqF4eFhdHd3N5V3X2/u3HNNb2QVrunpaRw4cAAPP/wwHnjgAezbtw/j4+NOPqps5LlS+agV9lj8o7u7O3TWKR8p37XYBtdl5ftG++7TQUy3PHz4MB588EHcf//9eOihh3Do0CEUi8VQI5JsNuvK8h533HE48cQTccwxx2BoaMiVWm201oHvvPtGQwqZH0RF08jQS88HzDwtbbDNHEFVdty8zs5O91Bt2TsWgff1K9WN0blTIdv2ccxZo9Lhi/lzrBClebl6CG2BDF+COau8sHiFr5LUenPv6uoKHS5WMltaWkI6nfbmzdFYYB1uKgKuwZZstHW8Abh9Z43W9SqR1Zt7owqZwn5xcdGVoaMy9RlgKmh1X3lGWLWrkfri681dLzb3nwUy2BTFllzlc9fCAZwn5649cJm3qN1kdB1UDK3uu1
XGbEiyuLi4pqCDFkvgWoDVM6EV0myhfa0lHuXMqHJgFy0abVoljM+Fc+Q54O/zfX29h+161bjWVrH1ClXYO9uofFS5yHzyubm5UOUzrofyR/tj05Djy9cdaSMI1zf3rq6ukLxmz+5qtYpisRgqSVupVNxZ0MpcPFe24pXKRyurtMiPrpNK2TawaXTffQ4M91tli684FWW3ynA+C632pjKpmX2vNza1/SI3wyZSawK4vpgArvWXk8nVxuxa15gb1cyoV21FG4NT+Hd0dDjjQZWc1uOm58x5UqloMRMmjbMIgSqGZgqI6F6qBcv3pYfGF8vCUfFqy0p9T66J3bd4AYMgcN2UtPPPZg59plrlxlZQ8hkvmzknq4y10A27C9FL0/3n2eD+p9PpUNUxLeigz3ezujpxPSqgWAiEiBFfLPrAakZaY9u2oKNHo2UPW616pftt58m5FgoFhz6wy1AqlXKCXJWBCk29c3yOHR0doT658/PzISNIq5nF8QwsKsQa3FNTU5ienl7TbcgaoT7DIK6he6TV8FRRMm5rlS/XZGu5c+j9sQWfiHxq7eg47kE9fWEVNvWUVhLTcr6qs9QJilKRq97YNIVcT4hZRcyHUSgUnOehjaEBuAbYcfSDVYVsFbC+2KoNWPV0gJUDpdVy2BCD82ENaXY+UUGXzWZD5dVoCGx06HzQiypjdkqanp52ECOLu9dTamwSEASB62Faq9VCnU7a2tocpMpuUpuhJOoNFQx66OM6/I0Oi0yoMp6amsL4+DjGx8cdnK6NUGgQ8XnzogNwZ497rx1yWjE811uHemeEJKnc9PwUCgUUCgUHSxIt0q43ikJocX1fTfdmnpcPXeN9m52ddc0jqMCKxaJrJ0pjQb0uVWLcZ+0iRuOTkDGNEQtXq1fe6jOg8b6wsIBSqeRgeBp33HtFW2wXqs24A9YQVkdFQ3KsEKad3nhHaNRQielZ5rpVGTN0qB3qWDNaz1IcwxrYagArYkQni+vQ0FQ2mw0ZavZ5tPJcNkUh23iblv6jEtEXG9MTmkqn066UXFtbG3p7e0OFy1v1jilY9LBZyJMHTmtWAyuQTLlcDnU5oqVIhZxIJEIKk4pZW7xRcDSq5KxS0PenQKWgKhQKrnwbhT6VgL4XGzdwXyuVCorFIiqVCqrVKjKZDPr7+12Rfhsbimvoeakn8I60Erbz8nnHxWLR8Q8Y69OC9RRWyWTSKV4tUcpzyM/SMqWbURfap5B5/+jhW2VMpaDesYXTCXcyrBFHxzNfnFWFN88797xarYaMBT33VKpqMAMI1bxmaUb1khn/tnWuW91/RbroHfMcTUxMoFAohBrq8JxQFlkPOY6zYRWK9Y4pG9mRKZFIIJvNOqOf82BJ0NnZWefE0DBVTzqZTIZKnhaLRXd2GL/VrmitDrv3vjK2tilGIpFwfeTVebSePF9xyMVNVcgWYqWloUpKoTH2Qm5ra3MWOUlJ9I51RLFI9LCpwqKgYfcSCi5CNRSojGFRoXFeqqw09mWbJZAEoYeCe7befvKrxtm1jjMFCQ8P62rTslarVxse0GLl3DV00N/f7zw9VRKMn8ahIK3F6nvFqZSizrFeuECNoZmZmVCDdioxnhd+1dgahZ5vDzYbslaiFIWievg0LuhBqkLWGDIJUoSKW+16tp4hpHXL9QUgdKdVKdMI51mv1WquzvXi4iKSyWSoxrGvznGrz8MnE7V2eKFQcKQ0evw8R7quevsa1/3wOS50Wvi8e3p6kEwmnTcJwO0vY88A1rQa5V7y/W2ddu69NgZp9f5b59BygdQI1i5snKfKbX0pwhLnfY1dIa+njHnZ+dLmDRpvqNVqLpasnpmOKBe93mGjRca+lgsLC46sYPv+8qLMzc2FCunTAgcQsi419qPwRhxD91oNAR5me2AI2enh1ufEQvYAHHHDGhBxesgWOtL5q/eiHmerIYuoc7QNHywfQo1GJeLQA9a/VQPOngm7prjXqOuxxfEVZlShw3
tj44j2pS1V1yPQNTu4R/y3Pfc829wrNbr1bwGEBCeNYSv04zYErUJQjolC8fT+iUpoWM3urfIOdE2t3k2FrC1pkrFjEmM1vk30Z35+3ilrOl71jId69zguNMwadVa2aKtI+/wBhBS3b66bYTTHqpDtgaZlq57E9PS0S02g98vNsl6DbelllVuUB6ckD7KnOzs7HXlpfn4eiUTCNdxW8hW9HCpjy/zj7xE6oydBsovtK6tCa7111IOSfCxd/o6ywRlf4/dVuVnrEFjxNNRL8B1KFX5Rhu+c2NZohInorZFXoB79evsVx6jn2aiRwHlw/9PptCM50ZBTw9SuYb24YNzG23rrUcOY3kwQBGsMVyVyaT/wKGxf37AKQYlkDC1pi1GebyoFrguAMygAhO4DnxtRDL1HVt60+gx8qARlIkmAikxQqSlUnMvlHKRLKNiHqsRhBPn2nvHj3t5eB1crYkmPl8YEw3p2rtxz/Qy+v/an9rGXo/AQfNwPyhIa1D4Y2/IFlAfCVxzkRTtiU8hWGdvcwZmZGUxMTGB8fBwTExNOKSuTkIvWeC4flG3vVo95W2/Us/zsvFOplGN1q1CiZVur1RwMBqxa2Zy7xtiY/tHT04Pe3l4Xa2N6VTOehM9iVYWvjb5pTHB+PHR8TlTGylrXdB0lOSisw4O7XjvDRs+KVcZKlNM4IV+MbTOeXc97acTAaWaeAOoqYyXckI2fSCQc0sJ0JQAhAhUNUUV+LKKymTFzPm/r7avXYGF19ZCY0tTd3R1KT9O80VYFlaaZZDKZEAu2Uqmgp6cH09PTKJVKIQVcq9Xc3lrkwT5HAE6xWy9fjeZW1sIzqs4Jz4GS0/SMaxvXbDbr5AZTnBjiUGNaQx9RjWWf4c80pFwu5857JpMJyRSuc3FxEbOzsy5NiuiQhub4/hqb5n3p6ekJcRE2SlVsZO8tX4KyhfeQ596SKH2KmDqJ592X/93qiEUhW6u7Hht1bGzMMQknJydRKBSc4uPlpyXsg8NombRivaqHrPA1v5fJZNDT0xPyFjRmVS8PUj0jMhG1uEZvby96e3vR09PjFLIWe9jIQ9Z5U/hT+JBgw71ifEwhXsuGVAWrRgehd/s7qpRpiTcL11gPTZWxpm5pTFaJRlqIgN6Q3ae4h8KNlhvAWDyNI1rN3d3dGBgYQG9vr3tWLGihc1RUohGvLI5QwXoeskJ4amiq4KQiVoVs48c+4lEzQ+Pv7e3toTBJEKzkfhcKBfT09LhCK4zd63PhejVur/tIQ4p3Vo0LOgD1cpCb2WueHcpDJWDSOeH5rlQqCILAkZu45729vY4USi4Cn50iZspfiDoszM/3IjlRnRXd47m5ObS3t4cIccAq+9oiEoxLUxlTXtJpoXHXiiHEfSJ5jntPhFb5MSrPrMxVvUTOhDUa4jCiW1bI9YSskrhYfWZqagqTk5OYmppy7EhVyLbIg4XCuPhmvWMOPVj1YLFMJuOEk8Z5aH3SItIYDmFHQtVaAUiLJVCA8UE2c9l1rtz3arUaOiAUKlTIlt3Iv2v0mQJrC15Ejd9aeM0abZq+ZV+0aNU75vOwzzZuz9IqMEvG4zPP5/MhZvrQ0BD6+vqQyWRQrVZdY3olDelYzxOzlnucECqfgXrJFobXc62pTlTG6i3YMx0VsqbBp9kMwIoRw3NOQ4AesRpKVMD63HQ99IwJxep6aPxblnYz8qZe+I5eGlO2VDnwfPOztCiPDYFQwVsZpoxy7mWze89h10ykRIlaigDQmJiZmXGGQyKRCCk79ThtcRlFApTh3uxZ0rvi85D5Io9JiXu6D9YJUkfRF6KJY7SkkPWBqAdhC1ZYJirhGXqc3Ih67GdfHLmVy86v6iGrMcDY6fz8vINdtNpYvYOmDE9VlLZwgg/Wa3TeVMiarqX5gT09PS7lSeOVvviverpcF4eNTdUjODQqnCiU+KLRo0ab74xoXI2xTZ8Fz+/5/h3H0PfS8AQvKX+WyWTQ29vrrP3Ozk
63/8VicU2YQiF2X3qNfW6b4SHzbGvsWOF4DZFQKdvyhhbijWPv9RlTxmicUUlONNC4HnvmGY4C4PXOCJfako1RFYIPMbTMfBqbLGzCOXM9Wh+BObqEhoFVhIUOge697l8UpawIoCp9Fuuw3jHTsyqVirfELudgq+r5ZKUquqhIiw2fai47FbLKFEVQ+LzXKxoVl/FpR2SF7FPG6hVrnWjmxhYKhVCqk3pwuhgVqD4lbIVjs8N+llXMJHYkEgmnPDTvzApTvo+vPKYq4XpxqWbnXY/92Nvbi/n5eaRSKReP0pinVa683BRUSmjgM1ahop6ytcTrDR88rdVvVEAxH3ZiYsJVK1Kmt6ZNqCGkFryNxervRBkWSVHWK/cikUigs7MTAJDNZh1fgPmKVGq+cACNIh9crc/KPpuoQsoqY2WJkynOswKsnjVbY14hO2swxyGcrOGsz9an9CmH1FPmflGR0GCmUaFhJIaV6J1Z77jRu7pe+I6ZJiS30jtWHo06JOqlk1RKFEM9N1vNy3q5up+tPAOeX4vY8AwvLy+H5KS9O8lkMkRSq+esWGUc9TypblKOCnWQTenkfhGdsLnfSjKO+7xzRFLI9eAYLQFHb2dsbAzj4+OuyL6SWjSHV72zesMuPM5Lr5Yg/68WkhX2nK8KDIU26h2wOOE8FTC9vb3Ytm0bUqkU8vl86MDZ3D5NwaFi5JoJgwEICRW+fLD1ekpZYSMqX222oHFjLa4/MTHhhBUvDT/bnhUKJzWmfIK0lf3WkAawNueVRW06OjocpNrZ2enYp4T1NHdciUXKZeCzUFRC049agcesgWSL9ainRoWm3rGSb5Tb0arwbGTe/MozZUMgdAoUdfNlPLCYCZuQEFkaGBhAT0/PmnVFgavrERY1fKexYxKzqLToGasy5h2mcZpMJpHJZEK57z7DIS6l7LvvqgfqhRP1/pCkpqiEhj18xl2UoefEFsHRimyWcQ8gVLqT8tsiJlFCpo2MljxkFdjsfDMzM4Px8XFMTk46wTo+Ph46fCpktTKKhed839usYa1K62n5WLD6dT2FrIcsroeonksul8Pg4CCSySTy+byLjWheqS1OYotbAAgdTJ+Vb5PhG3kmfA+eD54JrfusRAtloJJnoHnQ9pmp97SeBRtl6PuT2ERvl1C1zaFXeJe1q23OL180RJWIQ+GpMV6iGK0KAUt0scx2LQZCyJpGqVZpIh/CdgBrZa/rzdUXw9csAZuyxTNaq9XcOaDByhglCURKItI4Jrs82dSbZvbXesbKpVGZyHAMFVpbW1uIWEbjk7KSRUTY1CSbzYYaPRCJsd5tHM/FGrYqKyx6aZErokhE8vr6+twzofEaB9qyHnKr2RuUNUQnLGdCqzfWMxQ2w/hsWiHbi6IXgw0KWN+XBC4yqjVWws3SeKaFQexn8WebNeyBs3CLQtv2sChkrbGfzVDG/HtCcYxhptNpl0tti6Kz2IPWkKUyphdHxc3P8JGZbBx5o2E95EKhgMnJSVf7WePGhKipJLQetFX+9tlQ+NZjxrbiHXMdCn8qyqCxPDUQ6BVbkiON0mq16oQu44B2z6hkeH7oLanH0sxQQcU0HK2eZ6viaQyZsT7biSvOtA87V6uINUZsX1TGNI7UK+vr68PAwIBTBP39/Y4Jr+tRgloUzoovJGBr9s/MzLia1fPz8+458uwyTk65wdj43NycC+mQJMh0JCIzlvnL12bITZ8y5vetDKXn7yMG2s53rfKE7JnRdDMqZBqg2j4SWEXbLIdpM+LFvhELZE0PgBa31sZVr1hTKprxeDdTCdcbvvhJvQeiQtgqhs16kDonMhoJX6uBpMqZyrhUKjmLlLEfdqIicYTDGkXNPgtruPHztSGGKmSdqypjrtl+tQiFDS20uuf8e/UCLUyq8XVLJtHSlFom1gpiJZcorE1ioSI1Gh9v5jnY56HEOq3QRQHFz1C+gmY8xGX4+OZqvU2L8PiKmdh0NM6dMClTEPmiQr
YpW74mDo2uy+clc55aj5uxTBoP7Pxls0kUnqeXx45QtVrN3dve3l6USqVQD2rKojhSoRodeg5svN/neVqvuBVUi3rFxu5V7mn8mOddnQtFQ5XQdSSUMRATZK2WiEKhvORKWGBeIYctvwaESV02LnIkDhWHVbo+S1B/psSfzQz865zU8+aBp1Dv7Ox0h5EemsalgJW0Jn6fgkiFcT0YqtlhOQeWKMZLYZnr9gyoRa6/o7H9OBUyAOdh8L0sy1QVsY9Brl2U6Pmz/jK5ALwb2gmtVCqF6pArEhDVOALCZDFfYRA+G9Yk5tAzHheTer256n7WK/GpKILm3JNUpF6ywtJahCKbzdZlVTcLl+rcLYvdhiw4fxYo0nOcTK5W1aPRzDNBA4+GkyoZvi8VXbVadZkiR2roebUhvo3+xidfGxm8DzZUoPX97b+Vm0IjTuUJz4+uZbP1T0seslXKynDkRaF1xjgUBQ+9IB42XnQrfJplOcY96l1K+z3+3xdv9sHbcc6P7884FGOT6rWrRwMg9Lw0zaOtrS30PJRp2Ep8Ry1mTR9g1R+b0qTvq8QoX7zKKmO773HsuT5fYNUaZ4zKFxKgRU5yI2FKplwAWFOUgOvRDl1U+BQQfAZRhr236nnawiDWSLJ33Yd0KTkm6rBevFY50ypLygxXpELnxL3UHFKSiCzPIw4jrp6joh68Lb1KQ1PfQ9nBXJvG+Yk4ptPpULlZ7eXMe3akeTi+8J6iZAyP2SYNOr9mlbEaQJZExzNDw6Uej8muZbOJir7RlEK2sVybn2pJF8BqWUFCq4rpF4vF0GKphFX4K7FiM4PpG416n6sKpB4JbLPnq/BcPYWlcWwqEioTrRpGha7emF5uq5T18+vNTZU7PRVCttwvQqIK65bLZbS1tWFhYcEZbjxXdo0K52428YJnX3kTetkZsqEwUDLPxMSEazUKYE0MVNN1gHDHHN4lKpCoMWT1PhVW9RXbt5kUlrWvyjsuo1nnRzKg7YOsnBQ16JSXoh7yZmc/1Ju/Plt1VjRE0dbWtqaQCcMd3NelpaVQUx6GczSVUD1A7e3eKN+jlaEy3HI6UqmUOycsjpJOp5HL5UKKsdlQJofvvNQrTUpDjudYz7pF3fhvK+c2w7HiiMVD1pduKGMitOYJw9CCo3Kg9WTx+zgZs60M+6B8D2ijl/7eZs4RWFuJTJUUgDWWqsbP1DKkEtWk+GabB/B3qNjJ1qUXrp4yC5zwQvmUqnoTVgjYGOtm7bclqpHIqEJA67VrMQIqERoW1kNdXFx0woHhgyAInEGbz+fXeNWNzNd+pfFsFYS+ty8cpb2CNR/b3v1W912NAB8vRaF/W1ucn6+pcHp+16tTHZd37/OU1WlR5MF6x1RcAFw6HRWypqapJ21z/Ofn510N8KiKrtlh0U2V3Txbc3Nz7p5mMplQOVzlizSDtFiFrI07lJ9ia2Doeef72PXYfz8uIev1hno8JA2pcKdV19bW5g4Rc18p6H2dV440fODzwHwKuBlyzZEgp+ne2BAAEO54s7CwsIZYYeEmJapFCSGo96qNN2iAKXzd2dnprHsqIwoca/D5jB5fiMB3qVoZKmxtST7ml46Pj7u67VqiTxnM9Hw0n5aIQblcdlWMuC9UPFEFrM+ArqeUFaK2OcvlctmlO3V0dLgzpAZfI8hJI3NVCNL2UqdhwHlrWpwPrdJ/Hwk5oorZPis1eDlf7pmum7nJvANasc5yGCzSYdGLIyF7gLBxbB0pyn7+jN68xnKjGg6Wc2DjxVTCaniqd/x4GS0rZAuLkkCRy+UAwMUz6WmRLRgEgbtk7e3tWFpaWkOIqlcy80gOn3KqZyCoVewjLPF3jsSc9StHrVZbU/7Ner20ZtezCJtVcjwfhKVzuZyLd9Nws+ketVot1LicwonxHvWO+f/1kIk4h4V9lbBj4dVCoRDK7fWdBZ/hENecfWiWL7xkFTX/ljnkylLlM+Lz1O5nrTZ/sfO25CifsuHgHF
TekJSjz6te/Duu4UOnNBRHGcl1aZMZhpMsmkElo5kHPrLUeobAZg+fh8w1azhEvXlbhc/Gu5s5PxshtzpHRaD4/cfDaEoh+6BbFaraMzMIAmSzWecJ8NCRql+tVp3wIuNULStNG9L455EYuk7fhVKP0Vq2GmezpAW1ajd7LbzcOj+9JJorXS+FxedRabyQv7PRPBQ1IY8AWK0pTEXMOripVMrBwWQaq9e7EbR0JFAUK/DUELPpFtpq0e6HFpNh4QSyf/lvlhdUJnAza9xIGVvFrGsisYhwPLtXcS08M/oZvCcAYjOi6xla6hDwe6zOxTMdBEGI7MO9Zjw5Tg/SOii8Y3RSWFiFvIilpZXOZTzjwNpubHqmaNip4tP0yiPJXam3fivHlZtCOL69vb1uz3WLhDX6udx31Um+NCvWEeAz95G6HqvRtIdslbGm2nR1dTnrmuxZxr/4+0tLS65nZqlUClU0Um/NKuTNjgv61mnXaD1KPfi0bDU9Q3PdWoFjmhm+97fWuk8hU1mrMrZQGC+PrrsRti+FJJn2wGpJQzVg6IWlUiksLCygWCyugUPrPX/f9zfrnPhCFtYj9PEq6iliwtMsmjAwMIDBwUEMDg66QhZaZ7nRQhzWy7SFNTT9TJU1/07TR2ZmZhxjV6F3XW+1utLtioKU5ynKc1jP21KjmPsZBEHo+yxCwTAZyXfK9M9ms2tCAK2cGVUKRA/U0OFekmHPspc2dqqEME3p0jhrKpUKEbeswvHJzbiGjXlbmcNnonKGa+ZeW+a5RRRJbGzkmfiMoGw2696/Uqm4cql0Bu2c9Sz71rne9+IckSBrvSj8vyowejvWKwiClVrJmUwGS0tLmJ6edpY/PThLILKXb7NgSGAt3ANgjRJThaxzYkxRu4qwN7EeuM1QyvqeG723bz1qYXPwYmie6sLCwhpjZKP16LlgXisRE43lEJJj1bFKpYKpqak1HX0eL8Pn6do4uw7ugdbGZRoOSzoyT3ZgYABDQ0OucIWtaqS1ozcaViH7lDJ/rgqcz3RpaQnlctmdcaYgUZHw7lOJEAXSWGJURWcNYlU8RAsY7uI8VElxr4jGaRYCQyeMh0eFSevNmWiQnhP1fNva2lxlOmX96rOxDT8UduX7a41unifbCatV1MjKFvtVjU/OT4m5ClszFZbOmWWe23BCo8+D+oiGlvIgyOxmh0El/20kv+x98K0/zhHJQwZW4zWEWmwM2bLmgNUiFAAwOzvrmlBTOFtLeL3KOXEOq4jrEYiUIKIxV67NxkcY86GFG7cy3mjOwKqx5DMy+KoHg1rCjxIi6JU0AlmrAUfPSSE5fg4F6NLSEmZmZkJlADk/CrZmyHRxDusdWyNSjQf7uwpNE55m83ktuK8eMmtGE14lFLuR12PPhg+mXi+OSuFKZUfEhOfakvJULlAORM2X9oXEVNFyn3k29C4qFM05UCFzXW1tq/WtNaTU6pmy54L7oYaJEvqy2axLxWFWAZEJJbSRQwEgdNY0O0HPB8+IhqGiDFU8PhnDn9EYsyx9lTOKJPI8ahEaKw9sWGy9s677ThRH57mwsOBykiuVioOrlZ9guR1WER8JdBNowUMGwhWMFLv3pULQUkkmk45Ja8lE1nvjV5+HHMewilgPhR4yH6Tis5IUZrIdluIkkahlSqjcFkYAENpXhdUtNGnnZQ+kxh1tnKeRYaFdeipqxVLJA0C5XF5TK1mJQo/V0Pi1TyGrwqCioAFCYcECFdq/WrsOaZ1lfm+9NocbDSs4rSejis96Uvw9lhekwUlhSVg4m82G4GMtsRnV6/TFSNWgscqYXq8qJf6McDHn0tnZ6ZjaqpDV6IsqZ/Ss2/1U9jQAd2Y0fqyyUv9GFRRlJJWwrcNtc6zVkWnlzKiM8XENiA4SFVQ5w7XofScCZ+VXlDgydRBDnyrTFX0ql8uhEICGFNVz9s3jSCjlyCxr9bwAhCwiK2w1TmWFqh4Sm/fqS8mJQyn7PGK9LDxMqm
AZ31SISdNQOCy7z2dVxjV3HiLtaauxPXoqeuFtKolCZTYNwFqHrRxKDVtYqJsQFgBXOczC6Y812963FjVA61V+ImOccU0la3V3d4eaHlBBs0cvCV1xVpKyXpwaEgwdUXlZA5QCi9A72eSM1fL7mtrGZ9uMUraGAgmjPOe+3uJKHuUz4Pd5Rwj7qkLmPVamM9AaGc0qdTWGiTQB8DocXJOy2xl/pdIh3E4SrVYeo/GmCrmZ8+KTi5ao6Eu/43y1KA5DG7YkK9FQ6xVbo79RWaPGhuofxvC15SPRHZ5zTYlSD96GeTabmc8RS9qT/puLshAN/20Vnq2swwtoq+pYpnUr8RALwWjhA1W2WoeW+aZsI8nDxjXQEt/MYZWxrUyjvVXpMdB6ZtEPWt/aH5SGRj1oXQ98HMpAhbP+W+OwqpCVrMKz45vLkYCUVFlwblRAi4uLjjugCkNTXegFU/FSIWt7QLakU8+4VWWssUxNWSKZjAqUw3o1/BnZwTw7jMFyXvTc2C85qoes8XnuMe8iPUEljfnCB4lEwt0R3u/l5WV0dnaGCouQG6H7qzIpylAnA0Bov21RHGV9F4tFdHZ2AoArpEFZpISlfD7v2mDSyONd10I/zaIpVi4qUVVlBb1KX6x2cnISY2Njrna7RSNU4fn4DaoAqUca3XM1lLkG3juWG6Uc1CIiaiBoip06K5Z8txkjskK+de+tuPo/r8bdB+/GodIhfPmVX8ZLT35pKF6pEIUqNlu6TMlfegHV66AF2yqx65q7rsEn7/okfjPzGwDAyQMn421PfRueue2ZoW4gnKvOu1KpuNZphULBHTJ9UNZaizv2PTs/i7/63l/hK7/6CsbL43jywJPx9tPejm3VbRgbG8P09DQWFhaQSq10uaGnxXQiesjMmWU7OFsxR9eicbCoa1nvvChyoqxaTYmiUtZ4mk8Zbwa8dOzHjsXewt413//TM/4UVz3jqpBgKZVKzmOhckgkEu5ZDAwMrGkD2N/fj56eHuftkJyjqTvN7rvekYOlg3jnd96Jbz3yLVSWK9iZ24k373wzstms6w9Mr5iDAlbvMeOsVBRkwScSCXdmOjo6nPDL5/NNhTZ07olEAkgAH/rxh/CFX3wBo5VRDHUO4YXDL8Tzs89fo5CBcC9bxhIpX2xNaMZuqZSpkJuFdxt5Blf/6Gq867vvwp+f+ed479PfGzKWqUypYGnkMOeeECtloz4zsu+1r7Aq5SikriAIcM1d1+BTd38Kj84+iiAIcELvCXjD8W/Ak9JPcm1T2aecTgj3mox8Fsjh7yoawd9XFMZyVKxHutHc+fMf7vshrv7Pq/GTwz/B4dJhfP6Cz+PZw88OhQoY32cmB+fIeTG+bdNX1XvfrBFZIZcXyzhj5Ay8/imvx8tufFnd31OoRhUylZ1VyOoh02r0QXZRRhAE2NG1A+9/7vtxXM9xWFpewj/e94+4+OsX4+YLbsYQhkJ9eRUGJsQ7OzuLyclJ9zu0+uiBqqWurFuOVj37N/7rG/Gz8Z/h2guuRX9bP66/73r89+//d3xkz0dQOVzB2NgY5ufnkUql0NXVhf7+fvT29jpYkYJWi1iwzitjhQBCilGJGVGV8nrnRZWH9eKojMncJ+nLx5SsF+9pVTHf+ad3ohqsft69h+/Fedefhz865Y9crWAandrfVcMF9Br7+vowNDTk4sZ89fT0OCVs44BRjaBEIoGZ+Rn83j/+Hp6181n4wku+gFwih58f+jmyC1nU5mquYTxhVEVffHE0Na5LpVKo8EYqlXJpW75CKM3MO5FI4CM//gg+c+9ncM151+C4/HG4fd/tuPyHlyMxksBJHSeF4vT8O0sk0prP/Ldtg0mhm06nXQGaqJ69HXcdvAuf/umncdrwaS7urmlXlC1dXV0ol8vOeK7Vas5gJhGWCjmXyzmlrF6ykv9sDHmjoXdme3473vPM92Bnbifm5+dxw89vwFtufws+eOwHkZpMuQp03E8OwtY+uU
IvH8Ca82xDcFF4KsDK868sVXDGyBl47RmvxSv+5RVoa2tDJpNxOobnY3l5OUQs5pz4uTw79ZquNBvjbnREVsgXnHABLjjhgnV/x8IfvrwzLozK2JI4fGUdW9mAF5/w4hAj+oqnXYHr7r0Odxy8A8/tfq6rfapF3NWQYJPxUqnkPEolydgKRhrnapWQNLc0hy/d/yXc9N9uwrnbz8Xc3Bze9OQ34ZsPfxP/evhfcebMmZicnHQpHixOT3jGKmTCNew+xPCBMrCVhe2L4Tb6LDY6Lz6hatn2ipJwzzXWRUvbXmqOqJdnKDcU+v8HH/wg9vTtwXOOeQ6q1aqD0m31M2W6kgBFyFqVseYY0/j0EXKijL/+0V9jZ/dOXPuia53wH2obwsTEBA5WDjrBTrIW99CSnSh4ud9aMITPLJvNunvTSm1iYOUc/OjAj3DRiRfhopMuwuLiIrZnt+MrD3wFDy88jNPaT9uQbW75IWSMK4lIPTJfnLAVgVtaLOHVX3o1rn3xtbjqB1c5mJpnm/NSIho9+2Kx6MIBrPOgRqp2rqJX7Ksr0Oz5CYIAL9rzolVUMFnB/33i/40bHrgB907di6PHj8bhw4cxMzOD+fn5kNLkHmqZUzo0Kud5vlU+6ue3gm696MQX4YITLnD3PtW2EnPX51utVl1ushJHNVTBM+HL21eYXonNcYxNb5SpStnGDoDVuLFVBBuReZrdAF/8dX5hHv9y/7+gslzBSdmTXByWjQIYK9EXlZiWsaMipidnYcdWPR2O5doyqkEVnanO0Lrak+349fyvceriqU7Z8rPouRPGtgqZZDXmlRLiUrKP9fY3i1hVLx7oqyjG88OLRgNLjb5mYa9GxmJ1Edffez3+4nf/AslkMgTvA6vWvxozAJwQZYyY3o325FWYMY78UQD411//K16w+wW4+KsX4weP/gDbctvwJyf8Cc4fOt9VjqJgBeBihtqBh5CpGkFa6pbhHK213CoJJpFI4Nyjz8Wnf/JpPDTzEHZ378avCr/CTyd/ijce80akp8Pygc+BHjznaz2jjdCeOM/0m77xJlx4woV4wZ4X4KofXAUkEMomIdlPZUO1Wg3JDTXKLOnVZ/xZ4znK+VE0ZGFxAV/f+3XM1+axvbrdOSTT09MuJqvGi6IRqowBuPkx5U/zphUBbVW2UIYAq0gflTINSTViLPnNokS+inZERWu1Wqxn58h2rv6voV6QVohS73gz4q88ZPcevhfP/6fnY355Htm2LK4+62oMBoMYLYxiamoKExMTrqQnDxW9YUIy2rFH0xBIrydxR4s5tJq209XRhXN2nIMP/ucH8fcX/D26El342t6v4Rezv8Bwejh0MXkxmH9ZqVTchbcENno0yWTSXQ6FwCxSYZmbcQoxezZs+pu9OEpSo2VOxaaxxLjSpb5y/1cwMz+DS06/ZA0Lld4W18HzTMialbi0EIg2atiM6koPTz+MT01/Cv/j7P+Bt539Ntxx4A688/vvROKsBJ7Z9Uz09fW52DCVmBKguBZVyha+tq0F42KkXvmsKzG7MIvTPnUaUskUqrUq/udT/ifO6z0Pvy79OlR0AliNezO+GQSBUwg8O+pdKvmpHiIR9Rn888/+GT859BPc+ad3uu8lsJppYCFbesDrFZfhfHg+VPHWMzKiKuOfj/8c5910HhaqC8ikMrh81+XoHe/F1PxUqBKhZpko7MyQBbBauUuL4fT09GBwcNChQ6qY40KHuOfq5DEcqoaMyk29w0pKVoY8z7oiNHHd101XyHowrLBVaLparYasFT2QcQl+bvDu7t349z/8d4zNjuFrD34N777n3Xj3rncjNZVynXrYFEAD+oTxVEEDq2Ug8/k8+vv7MTQ0hOHhYUfeyefzLi8y6kXh73/+Dz6PN3ztDTjp2pOQSqTw5P4n44XbX4j7Ju5DPp937fkoiLgGrp8ekA+G0fi9xqas8IrrsvjWqGfEV52Jn0+Bq0ZSsVgMMZw1z7PZ1Jt647
qfXIfzjz8fR+WPcopIG8Rreg1TVFiIgsqYHjIFEQ2IKGkqG+1nLajhrO1n4arnrZDPTh8+Hb+c+CVueuQmvPjZLwawen5JQKPAURSLKIsqWPUa6kG/rYwbf34jbvjZDfjHP/xHnNx/Mu7cfyfe8b13oH1PO05sP3ENzKvnXTMygiBw54AFQXT/1ei06ESUsa+wD2/55lvwrdd8C51tnaGf6Xta5azzBVaJTzZuyffR+2LlbKvn5/i+4/HNP/wmxmbG8JUHvoKP7/s4Xp98vbtvNORVfnDuaoipIaRx797eXgwMDGBkZAQDAwPo6elxDkCz8W/f8IXAbBjMhyy0tbWFzq16y7YfuEU945AvR8RDtladVuTipWLXE4V3gdZbuXGoVZ8MktiR2YGupS68/pjX4yejP8FXDn0Fzys/DxMTE46trLC0BvRtNRqug6SdwcFBDA0NufhgPp93grlVS+r4gePxnUu+g0KlgMnSJPLI43XfeB22Z7ajJ9fjCCC+frE0KKw3Z2NTJPrQw7fQ+2YoZHuBFJpTwatKSw0MEkk0D5PriUM5AMDemb245ZFb8MU/+mLoktpWbzTUqJCZR6v5x/SQ4wxp+MZRXUfhlKFTQvfolKFT8LWHvobe3t7QcycEp6RLZdGqguB94l1Q5RwXw/3t33o73vGMd+BVp74K1WoVx3cfjwfHH8T1v7oeH9j+gZCxRsNIeQOaJ06DmaVJlZmsCjmOO3r3obsxVh7DmZ86032vGlRx695b8YkffwIL/78FJBNr03lsWE3ljQ91sMo3DhnJ0Z5qx+7u3RhMDGJkzwjuGbsHt83dhidVn+ScE8oUPnudl3ruKhtZhY7/puOiaGKcZT+tka+hJIvAsWKklReWA6UkQEW0HvcK2ecd+zxkuv6+mHGcw5I8nJVXXcbc0lyobd709LSDvvTA6YMixMtDl8lkQqQdJetoPCjq0DhlV2cXOlOdODB1AD8a+xHeuPuNyAUrXhgVLoWqxsOVMcj35IVJpVIOdtcUnPWgvTgGlSv/7YOslbnOz9UccsYx2UuYecG2gEErl+az93wWw7lhXHD8BQiqwZpzpK0iiTjwfNt6w8qojiODoN54xs5n4NcTvw4p+4dnH8YxPccgn8+HhFEQBC61b3p6OpQKR2Gk87NxfI3jxqGUK0sVJBPhIkKpZAoBgpCnQ6EYBEEobklWMs8QSXXrsZLjMIyev/v5uO/P7gt973U3vw4nD56MK55xxcoajBe2HtfGhgniMnjWG/b9a6hhOVgOzUsVsiXnAqsGEaHq7u5uV4VucHBwTcqfJTXGFbrx6R8lqypvw0caVcKohmY0QygOEiDQgkIuLZbw4NSD7v+PTD+Cew7fg/5MP3Z27wz9rh5u9Xx1Yxhzs1ZRXAKKG/buH7wbzxx5JnrQg0PTh/DVR76KX1R+gdenX+9thK4eANeiB0WVhxaJ4Msq41bX9B8P/QeqtSpO6D8Bvxr/Fa78zpU4ruc4/MGuP8D46DgymQzK5bLbUypmZb+q58M9Z7ECXiDrlUYtw8ex3nnZ1bPL7a3dY7VmrdJSL61e28so6RO+UQtq+Ow9n8Ulp1+CVCKFpWDJGQQaX9KYmkLvWuTEduXZLGUMAJedcxnO/cy5+OAPP4iXn/Jy3LH/Dlx3z3X4uwv+zqXgUJkxd9jCuJbTUQ+e9CmKVvb9ohMvwlU/uAo7u3fipP6TcOf+O/Hpn30aL97xYoeiqBJVPgGfTTq92glK6xswZrlRDDnK6OrowqnDp4a+l0vnMJAZWPN9u3dW6Sq73e7vZinld9/6bjxnx3MwmB7E2MwYbnrwJvyy8ku8JvEaLNWW1qAlaowB4UplfE5ErDRso4TGuKoyWjnzm5nf4J7Re9CT7sFI50gozq6xd5/u0aFGkg3NxIXAAS0o5LsO3oXnff557v9v/Y+3AgAuPeNSfO6ln/NOUB+SMgd9cZDNGEEQYLwyjrd87y0YrYwil8rhmM5j8Naht6
J3qhd7l/aGcs9UeWmsRy8RsArRKBTvY4nHsa7CQgFXfvtK7J/dj/7Ofrx4z4vx/5zy/2ChsBCCeZXGrxXHdE2cD2FGXUc9RRhVYG10XjhsjM13ZnRYoRB3HJPjlodvwaOFR/G6p7zOfa4KTZtDybXofq23r5t15p+242n48iu/jCu/fSXe9/33YXffbnz0vI/iNWe8JgRH+5jsFq1SdIfnxd4FPo84xscv+Dj+8rt/iTf925swVh7DUbmjcPHJF+NPdvwJDu47uIbExLUoAqTEIr2b9hkciWdRb1ijRr/an9ufbcYYr4zjTbe8CaPlUeTTeezO7sY7drwDnQc7cT/u9xoN/DcNIysbFfHSUq2+ututeMZWzvzPb/1PAMDFp16MT/z+J9ycfK96n1kPwbDy5YhB1vzQ2dlZ970z+89E4bKC9/dnZ2edkKTnol6nwnvqyShGz7icFtygBazxBfu5Ol879+XlZXzgdz+A8qlll2/MylsH5w6GPEjbUIKbrUKJeb3KtNXG9CSbsOzdemSRjebOcf7O83Hepee592WxEjK/tfSnbW1mS9PRM1ZFprCMwsAsgGEJM43OvZHzormtWkVJ18Jnwctgy55yvuVy2fXbTiQSa1iRze77OUPnoHBZAbVaDbOzs47NOzs7i2KxGCoRqC3z+LlaP7xUKjkm/Ea5tOuNRuf+7G3Pxm2vvi30t8ViMbR/3Guth66MaQudWo/OxtRJsmtra8Py8nKIF9LM3N937vvw3qe/13EFisUixsbGXIqV1jVQtELnpAgG7wdLx+r5aDR+bOfum7cdX33ZV9f8rYZcKCNZLVAJgpY8p+edcrJUKrmSm+utZb25c88+dO6HsHDWan/yqakp7N+/Hw/NPxRCgXy8AmCVrObzotU5UDlPMhXJUj452ci+q5yxBjPrLdhynrZ2gf4tn5HKF95fYJXZr33bG5Xv3hE0MPbt2xcAeMK89u3btzX3rblvzf0J8PptmPsTbd5bc3/s515vJIINVfaKFXDw4EpVnyMN5zQzgiBAsVjE9u3bnTW+NffNH1tzf2zG1twfm2Hn/kSZN7A198dq+M67bzSkkLfG1tgaW2NrbI2tsbnjsev2vjW2xtbYGltja2wNN7YU8tbYGltja2yNrfE4GFsKeWtsja2xNbbG1ngcjC2FvDW2xtbYGltjazwOxpZC3hpbY2tsja2xNR4HY0shb42tsTW2xtbYGo+DsaWQt8bW2BpbY2tsjcfB2FLIW2NrbI2tsTW2xuNg/H9f7IDYL14aPQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "fig, axes = plt.subplots(10, 10, figsize=(6, 6),\n", " subplot_kw={'xticks':[], 'yticks':[]},\n", " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n", "\n", "for i, ax in enumerate(axes.flat):\n", " ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')\n", " ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')" ] }, { "cell_type": "markdown", "metadata": { "id": "Z3l45KgtfUUo" }, "source": [ "Next, we split the dataset into training and test sets, and convert these splits into [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) before feeding them into the model.\n", "We’ll use the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6jrYisoPh6TL" }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "splits = train_test_split(digits.images, digits.target, random_state=0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oMRcwKd4hqOo", "outputId": "0ad36290-397b-431d-eba2-ef114daf5ea6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "images_train.shape=(1347, 8, 8) label_train.shape=(1347,)\n", "images_test.shape=(450, 8, 8) label_test.shape=(450,)\n" ] } ], "source": [ "import jax.numpy as jnp\n", "images_train, images_test, label_train, label_test = map(jnp.asarray, splits)\n", "print(f\"{images_train.shape=} {label_train.shape=}\")\n", "print(f\"{images_test.shape=} {label_test.shape=}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JzrixENjifiq" }, "source": [ "### Defining the Flax model\n", "\n", "We can now use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network - subclassing 
[`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) - with [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) layers and the *scaled exponential linear unit* (SELU) activation function, using the built-in [`flax.nnx.selu`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.selu):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U77VMQwRjTfH", "outputId": "345fed7a-4455-4036-85ed-57e673a4de01" }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from flax import nnx\n", "\n", "class SimpleNN(nnx.Module):\n", "\n", " def __init__(self, n_features: int = 64, n_hidden: int = 100, n_targets: int = 10,\n", " *, rngs: nnx.Rngs):\n", " self.n_features = n_features\n", " self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)\n", " self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)\n", " self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)\n", "\n", " def __call__(self, x):\n", " x = x.reshape(x.shape[0], self.n_features) # Flatten images.\n", " x = nnx.selu(self.layer1(x))\n", " x = nnx.selu(self.layer2(x))\n", " x = self.layer3(x)\n", " return x\n", "\n", "model = SimpleNN(rngs=nnx.Rngs(0))\n", "\n", "nnx.display(model) # Interactive display if penzai is installed." ] }, { "cell_type": "markdown", "metadata": { "id": "FIXmNs5-lrEf" }, "source": [ "### Training the model\n", "\n", "With the `SimpleNN` model instantiated, we can now choose a loss function and an optimizer from the [Optax](http://optax.readthedocs.io) package, and then define the training step function. We use:\n", "- [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels) as the loss function, because the output layer produces one logit per integer class label (the digits 0 to 9).\n", "- [`optax.sgd`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd) as the stochastic gradient descent optimizer.\n", "- [`flax.nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html) to bind the Optax optimizer to the model and keep track of the optimizer state." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QwRvFPkYl5b2" }, "outputs": [], "source": [ "import jax\n", "import optax\n", "\n", "optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))\n", "\n", "def loss_fun(\n", " model: nnx.Module,\n", " data: jax.Array,\n", " labels: jax.Array):\n", " logits = model(data)\n", " loss = optax.softmax_cross_entropy_with_integer_labels(\n", " logits=logits, labels=labels\n", " ).mean()\n", " return loss, logits\n", "\n", "@nnx.jit # JIT-compile the training step with XLA.\n", "def train_step(\n", " model: nnx.Module,\n", " optimizer: nnx.Optimizer,\n", " data: jax.Array,\n", " labels: jax.Array):\n", " loss_gradient = nnx.grad(loss_fun, has_aux=True) # Gradient w.r.t. the model's parameters; logits are auxiliary output.\n", " grads, logits = loss_gradient(model, data, labels)\n", " optimizer.update(grads) # Apply the SGD update to the model parameters in place." ] }, { "cell_type": "markdown", "metadata": { "id": "K2Tp-ym6sXEl" }, "source": [ "Note the use of [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) and [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad), which are [Flax NNX transformations](https://flax.readthedocs.io/en/latest/guides/transforms.html) built on the `jax.jit` and `jax.grad` [transformations](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n", "\n", "- `jax.jit` is a [Just-In-Time compilation transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation): it traces the function and compiles it with the [XLA](https://openxla.org/xla) compiler, so that repeated calls run fast.\n", "- `jax.grad` is a [gradient transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) that uses JAX's automatic differentiation to compute the gradients needed to optimize the network's parameters.\n", "\n", "We will return to these transformations later in the tutorial.\n", "\n", "Now that we have a training step function, let's 
define a training loop to repeatedly perform this training step over the training data, periodically printing the loss on the test set to monitor convergence:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "l9mukT0eqmsr", "outputId": "c6c7b2d6-8706-4bc3-d5a6-0396d7cfbf56" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0: loss=5.68\n", "epoch 50: loss=0.16\n", "epoch 100: loss=0.12\n", "epoch 150: loss=0.11\n", "epoch 200: loss=0.10\n", "epoch 250: loss=0.10\n", "epoch 300: loss=0.10\n" ] } ], "source": [ "for i in range(301): # Full-batch training epochs 0 through 300.\n", " train_step(model, optimizer, images_train, label_train)\n", " if i % 50 == 0: # Print metrics.\n", " loss, _ = loss_fun(model, images_test, label_test)\n", " print(f\"epoch {i}: loss={loss:.2f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "3sjOKxLDv8SS" }, "source": [ "After 300 training epochs, the test loss has converged to around `0.10`. We can check what this implies for the accuracy of the predicted label for each test image:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6OmW0lVlsvJ1", "outputId": "f8d7849b-4242-48e7-8120-82e5574b18f3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "438 labels match out of 450: accuracy = 97.333336%\n" ] } ], "source": [ "label_pred = model(images_test).argmax(axis=1)\n", "num_matches = jnp.count_nonzero(label_pred == label_test)\n", "num_total = len(label_test)\n", "accuracy = num_matches / num_total\n", "print(f\"{num_matches} labels match out of {num_total}:\"\n", " f\" accuracy = {accuracy:%}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "vTKF3-CFwY50" }, "source": [ "The simple feed-forward network has achieved approximately 97% accuracy on the test set.\n", "We can use a visualization similar to the one above to review some examples that the model predicted correctly (in green) and incorrectly (in red):" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "id": "uinijfm-qXsP", "outputId": "632f6e98-1779-4492-c2f7-125499c5b55f" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAHiCAYAAAA597/kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9eZjkWVXm/0ZE7vtaa1f1RjergCjQgiKILK0oqAg6oihu89NRBwYQVJRlxIVh3MdBUZxxQ1TAZVQElcaFHVHWppumu6przzUycs+M+P1Rfm6+35PfyIwtu8Un7/PEk1lZmRH3e++5Z3nPe84t1Gq1mg7H4Tgch+NwHI7Dcb+O4v09gcNxOA7H4Tgch+NwHBrkw3E4DsfhOByH49/FODTIh+NwHI7DcTgOx7+DcWiQD8fhOByH43Acjn8H49AgH47DcTgOx+E4HP8OxqFBPhyH43AcjsNxOP4djEODfDgOx+E4HIfjcPw7GF2N/FK1WtX58+c1PDysQqFw0HNqedRqNS0tLenEiRMqFq/6GodzP/hxOPf7ZxzO/f4Zce6fL/OWDud+f408ea/3i/uOs2fP1iR93rzOnj17OPfDuR/O/fPg9R9h7p9v8z6c+/0/93qjoQh5eHhYknT27FmNjIxIkmq1mqrVqra2trS2tqalpSVdvnxZZ86c0d133617771XV65c0dzcnBYWFlQul1WpVLS2tqatrS1JUldXl7q6utTf36/h4WEdPXpUp0+f1g033KBTp05penpaY2NjGhkZ0dDQkPr7+9XX16euri4Vi8VdXlG5XNapU6fSfOPch4eH05w3Nze1vr6u5eVlzc3N6d5779Udd9yh22+/XXfffbdmZma0vr6uarWqUqmk3t5eDQ0NaWJiQidOnNB1112n6667TidOnNDExIQGBwfV39+vnp4edXd3q1QqpTnGeeZ5c3vN/cyZMxoaGtLW1pY2Nja0tramlZUVLS8va2FhQZcuXdLZs2d199136/z585qZmVGlUtHy8rLW1ta0sbGhWq2m7u5uDQ4OanR0VJOTkxofH9f4+LiOHTumU6dO6dSpUzp+/LjGx8fTs+StczNzd5lBbra3t7WxsaGVlRXNzc3p4sWLunDhgmZmZrS0tKTt7W0VCgV1dXWpt7dXfX19Ghoa0vj4uKampjQ1NaWxsTH19/entW7VQ25m7szf5X51dVXLy8taWlpSuVzW7OysLly4oLNnz+ry5cuqVCqq1WpJdqampjQ5OanR0dG0/hMTExodHU1r3ujzNDv3arWaWftyuaz5+XnNz89rbm5OV65c0ZkzZ3TmzBnde++9WlhY0Pr6ugqFgvr7+zU0NKTh4WGNjo7qyJEjOnHihE6fPq3jx49rYmJCw8PDGhgYUF9fn3p7ezN70+xZjXOPg33Y3NzUysqK5ufn07qfPXtWs7OzWlxcVKVS0erqanqOwcFBHTlyRNddd51uuukmXXfddZqentbQ0JB6enrU1dWVO9+95r6fvCDzm5ubWltby8jN4uKiZmZmdOnSJV25ckXz8/OqVCpaWVlJ897Y2ND29rZqtVo6F93d3erv79fIyIiOHDmikydP6vjx45qcnNTIyIiGh4eTTurr61N3d7eKxaKWlpZ0+vTpPefO2m5vb2t9fV2VSkUzMzM6d+6c7rnnHl28eFFzc3Mql8sql8taWVnR2tqaNjc3VSgU1NPTo4GBAQ0PD2toaCjpm2PHjunkyZM6cuSIRkdHNTg4qJ6enrTu9WSlmXXPk4t77rlHn/nMZ/SJT3xCd911ly5fvqy1tTXVajUVi0V1d3cnW1QqlZJ9cTvDupdKJfX39y
e9ec011+jEiROamprSyMiIBgcH1dvbu0v+l5aWdsl73mjIILNAIyMjGhkZSQ/Ng5dKJW1tbamvr2/X4mKYisVieljes1Qqqbu7W93d3env4osNQyEgXHspLf853w8PD2cM8sbGhkqlkqrVahJaFpH5bG9vS1L6fN6DdRgdHdXo6GjaCJ4f4c8zyHwfv+4396GhoeREFAqFpFhZC16sM6/u7m5JVwWVg8zaIXR+wAcHB9NrL4Xa6LqzVgw3yKVSSRsbGxnFsbm5mQ627z+GeWBgIBmGgYGBtg1yM3P3+W9tbaV95hysr6+nw+37Iik5FwMDAxocHNwlSyMjIy0/T6Nzd4NcKpVUq9XSem9sbKhSqWhgYCAplK6uLm1tbWXOMDLtssM+9fb2qr+/PxnlRgxco3OPwxUv81xcXFRfX1/SEeghzoSk9Azu7PmcoxJuZN33kxc3buwtc2e//Ry6DFWrVUnS1taWav/W5Tiedf87zghy5nrJnet6c8eBQE92dXUlHemGhvn29vZqe3tb1Wo17XV3d3dGl6KH/Cz39/enM4/OrBdoNbru7vxsbGyoUChobW0t47i7DOPg+Psh226Y/d/YAtZ3aGgoYw9c9vP0537P1pBBzhMwN2zr6+taX1/X2tpa+n59fT0ddgSJg+Aehx8cSdre3tba2pqWl5eTt9HT05MOWrVaTQvfihLmcBAlo0g3NzczQl8sFpNCYfFHR0c1Njam0dHRFLFHBcr7s048a9z0ZjaJdXcve2VlRZVKRUtLS1peXtbKyorW19e1tbWVDrELlf/bFQJ7tbGxkV7b29va3t5OAtzoHNsdfqAYHAb2h8OPHNzXg89kn1nD1dVVra6upghoY2NDm5ubKdqXdpxQN2SOpNyXg89zAxuVe29vb5IlHA9Hl1ZWVlIkt7KyknkujEmr57SZwVlzZ4l5+jmXlPkZugu5d+epk3LvRnljYyNFx6BYyI1Hw9KOviSSQ+5d9/JCh/FvZM/1ZSPzdGPszlp8oWfQ4+hml3ECBz8nbiuiwXN92coa+zr7/NlzXz90dBzMGVnw5/HAIAZwfpYbdejyRlMGOW4YCwt8ioGIRmJ7eztBAy7oHiF3d3erUChoa2srwWlE0xhHvDQWstnDHoUZ6Aglurm5qWq1mg4C8MTY2JgmJibS1+npaU1OTmp4eDhjlPkMIrz4cuXHi7XYb90RstXVVVUqlQQ18iqXywmixglyR4c5sAe8V29vb9or/hYBdpTjvlCsPKsrWEZPT09G0dwfxpjh6BDOI5A1sr+6upqUFygF8k6kkRcZ3JeGmc90NAXnt7+/X6urq0n2OJvr6+vq6upK531xcVEDAwOZCBTlhVE+yOdCXvyMcLZd+WOQS6VScp5wKJaXl3ONMM/UifkxL3Tb4uJi+rq4uKilpSWtrKwkGa/VaunzMVYYQp5TUtJhOEarq6vJYPhZadQgR6eL948OJ3q9UCgkWXbH04l6nBN3RDDY0o7TjU5sRdfEuSMDLgesHy90fbQn/jNH6ojqQR5AInh+nNFolJt5lpYi5HqRWnxhILa2ttLG8ZBsAgqpt7dXpVJJ6+vrKpfLKhaLyTDwtw4HNLtpebkcDqMrUIQMhTI2Nqbp6WkdO3ZMk5OTmpiY0Pj4uCYnJzU2NpaieIyWR6j+rO4JorQkZTZ/r7lzIFdWVrSwsKCZmRnNzMxodnZWc3Nzmpub09LSktbW1pKH7dGKwzS1Wk0bGxvp8wcGBjKGZG1tLZPLbDaab3a4EXYPV1Jaz+7u7qRYm/X8Oz3PPAW7sLCgubk5zc/Pp9wlso9yKhaLGegdB9O96/ti+J7mGWMgdRxLj4jW1tYk7UDwwLzsm0OWKCnW7iCNckTsUMaei2V9Y/AwMDCwy0C4UWl13jGAWVtbS8707Oxs4tYsLCxk8t0eITvEXSwWk57ieUqlUnKoY8qhr68vGfD9zkqcKwYN/VipVNILHeGBFnpcUsZpAPFkfuVyWb29vU
kXSTvIHWhAu+vsPBscCRwddDxzY23RKRjgvHw9EPXw8HDiNpFuc3lvB/VqOkL26ADYCsEmQlhaWkrEhPX19eTtuQfhXz0nsrW1pUqlkuCd7e1tdXd3a3h4OCm4VuDKvOgej43oEBKXJHV3d2tgYEBjY2OJ9HTkyBGNj49nSBPRGPOKh8pfnkd3waw3HB5dXV3NEEEwApBB3HPlGYBWMMQ8K1EDyiASSdzbOyhY1b1T3x/3YKXdEbLLwH0VvfNZLkORHDU/P5+cG1IIrGPMn3GQY67poJ7F3zvmxjCi5MeGh4eTM1ypVFStVtNZRnHxPERw0s65GRoayuzTQUfIDlN6ZMfZxpGQlH6OgXG+hOf/OzVvlxcM8szMTMZ54/zhKDMH1hd9go5yNM8d6kqlkhAOd6QajZBZR/Sjo5/oSp8nBC7WUFKC5VdXV5OOcfIj0TQOIY7g9vZ2xhg2M2JaAB0X0wExQvYgpVgspnV2qBoHmlwxxEY3yPVy4c0+R8s55JgXyIM1UKDufcfQ3qEu6arwspGbm5sqFosaHh5OQhCVcTtzj/kRnA0XlMHBwcTuxSA7o5r5Y4R943kuJ0F4ZOwGZz/vNRqBpaWlXLjLo2OincHBQXV1dWl7e1urq6sJrmaOQ0NDKTp2g+x71U7efq/ninkf9675PxSS54H2W7NOjwiNsteuaNiL5eXlXekPh6pjDuq+hqwjeYXzCQo1PDycIjXPzXJeeC4iMhQR8uZIxkHuUXTmfI4oZVfO0tVz4UxnXijVRiPKZuaIHCMvQP1EyO4Iu06I5EDSSDH69OeIHJ6IJtV7pijf7ti4U8Oagvyw7xBuJSXnxx1o51qsrKwkgldfX1/TjkPe3PlaL4fsa+H7Gx1UAkTPF3uFAQaZyh9Ic05g83Rks6Npg+wPzsPVe5HHBH8Hd49lQXlEAvIoXrYQN61VAxEPcoRMPZeAIgWu4BncGEdj6cbOyQB8Lsa4GeGrB8tFIgiHl0OC4HR3d2fyTpAXmLfnWfiKM3UQ0WiMiN1B4vDwPOxFZGHel1GxH3Z36ID1iHCA8oggS6VSiiCcWeo55GaY7O2OaIwjQ3pwcFAjIyMpAmKgUDHIjth4DtOJgfdVWiFPEbs8e+qL7zEC/v+utzo557zozQlxGDk/Z46scXZdX0ajE0leToBsNofs7+1Qr68TESVwLlUPtVotzRcHw3O10Wmql4I6CFTF19c5TJ5+8SqOoaGh5KDCosYgg5BGRnW7+qklg+wjj7zkXrckDQ4OamxsTGNjY8mYeWToHlStVsscIAQ4D6psdb4u7J7X9QWMm+e5YHcimC+KeWFhIQPpEA1FQWjEa40jOg9+eImIMQJ9fX2JFU6NJfXIq6ur6YA4I9W94PX19V2lC52CraOBc+/Zoxn3outFlKzpQY96St9hPVI0kpKB6+np0cjISIadDxfCn+W+ZFpHZ9OZ644SeQkfOVB3/lx2ooK9r4h39WQpso9B2zxC4jniOfR5t2sYorPvsuORobSTu46ln1HvOErjKZ6YMvMAo5F58jWiiG70Y7DiNceDg4MpVcZ7SEqsZdBBSbvkrZ01j2SyWMkQ06JEwMViMaE6XsqEMfYAzOFqD868DLddhrXUokGORi162ngbtdrV3PHY2Fhq6DA8PJwOeWRFkjvGs3Jvzz0zqXm4Mg+mQ2kCoXjEWw/28BpOCAHr6+taWlrSwsKCZmdnVS6XU81hf3+/RkdH0xxYq2YOS71ncCgUKIsDTaQzMTGhkZGRBNVJyjgOOBTOOl9ZWUnGhHrsTkJ4fHWCneepNjc30zPhfXtzmEg4O8gR0aCIUHgucnl5OcGKkLaGhoYSMdCbB+TV7EsH71zknQFph1iDsiJSgEFNxMNX5uz7mBeV3RfRcURaInpE9ClplyPLmsQ1yvu+E3N1AhFGFSPGnLwc1FN5EdFj7eshla2gcL6W7myhAzzgAjkkeqSFJTJFumljYyMRGr0eupOykSfXeeVJ/L
8HLcwdI4yuQf6dWe38D/4PfdSJtFPTBnmv/FMM+yElTE1N6dixYzp69KhGR0cTfOuQ1/LysiSlwx4FoxOH3KHPSLBh07ze0ssmMFZE/TBQYZ4uLS2lbkfz8/PpOejM4szTdoyxG2JSAcxZUlpzUAk6KJVKJa2srGhra0sLCwuJHcvaYmQ8H0XphHvxnYCS8iA8Z75KO/kzPFYODd19Gm0i0O7IU/i+VhhkcvB43SMjI4mJ7+VyMPP3groOani+zImFscbS2aT9/f0qFArJYaYUir+tl77KMx4HwT9ww+SwrRtlHAXOXlwP9MJBOXh7pca8zMcjOQwy8/Ezk7e20dj7ZzSLwkVOR9RXBC4xzTEyMpLIi7xPsVhMqBE60KP+Tgy3ScgyxDdSRESyBBmFQkEjIyMZbhCQtEfCGHTOBn/vTU461ZxI6oBBZgGAFgn7KZ3p6+vT5OSkjhw5oqNHj2psbCzBZBxylAM1jg5nIwjtEnnyNg1CgUO97pESrXuJFOQoPFpIaF76Mjc3l7o21Wq1FCkNDg625LnWi+4pxmf9+D/yx2NjYxofH9fw8HA6KMvLyxmYhQPnxjGS8lpxIPJGPSVK7h2jhkPW1bXTVhUPNs8jPaixF4znZEavY0epDg0NaXJyMrX5HB8fT6USjbZYPajh5yCmmGJZFgqVUh2cNNCvqMAjVNpJ+Nffh695qQTnQWBUJOUa4/uDk8BnuSMQq1CiLsIhrGdc4zq3giTWcxxi+s7n7MFYf39/MtacDRAj5KVeE424Ns2OqB+r1WomQua8IeecUdArnGUiZiJkj6597r5f/LsTZ7gpg5zHSIvNBPCWOBwDAwOamJjQxMRE6uHb3d2tzc2dlpuU2PCQ/mCdPCz1HAl/+UEgx4rjUC6XJe04DtKOQYYxyQvYtb+/fxcZrZWNi84ERp78nkc6RDkIGLmdanWnBZ57dPVyRgdR8+sKNI8V644M0T4F+Biy+xLmjUo/j5Tiyp8zQd6J6Biv+9+DMZaUiU6cH+FpqFrtKt8gtsJ0UiJnhLXx792Z7qTzVC8qd6PskKs783nG/KAhdufX+BqDUEnKKHdfKwxxXk4Yoxhzl3n6s9F5+vdRNvPey52c+LxeMkngwPnI63DVzjng7zzt6BFxfBG4OGELNAvHmZLRmMtnrvHf92uEnFe/SCNx4Nrt7W319/eniwyIcngA7wTkB5mHlXbgtU48dIySI9wem4JLV4kr3jkMspNvvOeQV1ZWMoaMz3VHYC/vcK95Y3AhUWxubiZ4hv/3frbeYL5Wq2UcDlfIjkS4Mj0IQxzzsZHNSSpAyjaf8O5WDuXx3gdtlOOc60F67ANpA2Q+5o3vayJXIyNPkbq8uCKKBjGPAMSrkyS8elExXz2nGh0Cf4+8XG7kSnQCavdn9kCA6I3PceNKKgBnx1nheTolj8gUo7ZGnyG+l9fWMq96aS70pndBI0oGmo9co1hpENetlbWOTk/U8/39/YkEyxxi2hUd6sTG6PC06vjsNVoyyBE69UNCXoSLHHp7ezNRQn9/f8q7OlzJ5sHMQ6icWduOJ+UL6Hka3yR/EXVubW2l207W19czN6dIOx4seWSeoVAoZHoCR+ikGSfDDzNwNE4PBLgIPYJY0NqTOe31WVEReaqgnbGfEnXC3NbWVqbVpD+L55/uq6iy3twjiUna2SdXuLGJfj2o/aCdirzncWO0FwJAlBn/VlIGIo5EKl8jN+admHsePB3LcvKMa9zLWCoUa907MdxAOOKDswIU7fNkHp4iYY7oWHSMGzePAJuBUmOwggyj85DdWm3nQhJJWlpa0vz8vHp6elL3MJoXLSwsaHl5OQUDOHV+yUo91KudkafnvfUlrG+/TIS1z/vbvRzKTp/ZpiHraJClrFdFM421tbUEGzhN3I0InWtocEF3I0kZw7LfVW7NDPfUqtVqRuFDRuOqQ17lcjmRoWIbSkkZ5cbw2msnFkR2bSPGhTUnOsZZGBgYSPm8CI
c5LARcfZCw3H4jLw+7V92kIwKuZPI4BvU81XYPS14OOc8gx/WPvAqP7PPmzlx9Xw4SjvccfqzB9frq2CjGjZU7IsViMbU2xTC7oYyIQKvPFPchXlbgRtmhXSJkfuZRZ16v405WFbjOjEiipGTI4j6gH+M58YAlGuS8nKmn4PYbMdBC1zgqWKvVUicuDDC6tFwuq1C4SgCkSQ49D0CIID16WZEHOe2iRpwjN6wRNeT/IXfh4DnK4u+RFxEf1Gi57AnjGNnW4PJ41Q5RkGxHsMjNeoejyMhzAYvRRStRspR1IGKXIgwoRJ3t7e1ENqr3eXnMbb8GzXOge8E0+603Rfg4FAMDA3WFx2HpqEjrRb5xXduNaKT6fapjswEOgtc6uqcaS0BQTPXgo04dor2i+7y0hMOIkU2KEYz7xd/HueM8dlIRRGPsUbBHjH4RAOiVKyt3PmM+3TtF8b7ugLZi7Nw5iuVn9ZyGvM/hZ/7sGPN6HaM6AVu7YYBnUygUElHUnQJ3CNxpcgTSdU6EgOtxFPaSI5c55sk6r62tpWhSUlozR+5IT0pXGw9x2Q3NhaSrcoIhjDqx01UT0UH2CgLOrZffulNXj/NzXyBYLUHWkjIb7cbN28/x++6lAXWw0d4rlaia6MgNsh9on0cr84+b5cafSJncBx5iZBujiB1KHh4eThFwvBUk9jptBn73dabsis+MygMPkTkSQUdDEv/GmZNuVDohiHlQYx55LK/EguERHM/j845f83Kf7cw/RoYxSpR2ZDIv/8fB3ysXFY16Mwq10eeI+xBhZmePLy8vJ8Xq3fL8uX09IrnLc8mtVBfstQ9ec+wNQGKJZL3n5z28E1W9Oup2jbHrSLgFNA5yEp1/jjseRKBOBHPj6fB0bILRjMzEAKu3t1dbW1uZoIj/R76pMCE9STqSNCSlR729vYkf4kFQRAuZRzsjBhToeL9zfXt7O61rRCH2k6ODHC1HyHzFADh0ESMwfgeF6ochtt1zISPidiPWCQMRFV+MlNk4vEDqpT1C4H0cMgKWwgP2Up12cuGurPFSWSdnkPIVY8FFE0688G5LTiaJMLdHpZ0yBnmwryvtvXKEICfr6+t7GjAnpPAsrFer86+Xh6z3c+Qa9IfUSJyDOxEeVTs7NE9e2tkLP4fM0S9X8Zp7DDK5QK/rjXC764I8Alyn2Poxus9rkRkdJR9OXnSYMpLCOgVXezrDoWCMk+sf/p99iD8j/VQoFJJRdmPsa9Ts/PMClZ6eHm1ubmb4EES13u6TMwmiCC9obW0tzSPqnEgYjEhRJ9aerwQyBF1OHmX+/M7AwEBum+b7arTcOtMf2BWTw3N89UPKV+8MxWtrayuTM4x513qEmFbm7oIRiQyei3EmOIde2jGIbLizasfHx1OpF2VHngdvJfJkrv5vPFVfZzxXlCBpAW7j8isx2S8/gJ6PcuPcSWPs+UWPTDy6QilRblar1TK13zw3ysjrDB3tqNVqCVVo1yjnvfz/mffy8rLm5+fT4edyDz8rHjlFp9DTHLHFZjt7EZ01OrbB36D1pzvJfpMVOUFasDIvvrrTnGeUOxUhRxnxdq95kU3cJ6/qcOg6Mq07pYzdmQY2Zf16e3u1tram/v7+XUiF57hpntPX15fIo6SuXC/EMrBm1zwP9cSJoG53aWlJtVotEVhrtVoyvp6bdxQgTw7yeC3tIhJ56+4RP+vMWqHPa7WrZZY4G3kln/8uIeu8kWecpawx9sPEIXDIenV1VZIySjUyVDvJsHWvPrKuY3cXPFI2yQ8YCpS63+npaR09elQTExOpDhgoO9LnW11jlF6eMEdSCEqXhiWLi4sJSpKU8R7xgDEE7ToRzCnmw5yVG2EiV7SUkpVKV+98LRaLmYjG548B5iv1hXHt/Guro56MA+Ex542NDc3NzSX0JCJHeXlA0BU6B21vb2fk3/+22Tl7GgMm7JUrVzQzM5OcNVJH3kHPbxZDCSODGBnWHhmPUWi7EXKUo6
g/PMp3ngSf5UgSP3eoOg+ybne484UD7xwQrkjMY6Y7jErEWS6XNTc3p+7ubpXL5QxhDmMYKxZiPn2v5+JcOPE19jWYmppKnwu6sry8nEkb4Bg53O6ysF86o1PGL+pp5BSHEjkCZalWrzYTGRsby6QpD8JZqDc6YpB9eB6NrzHv4x2h/MYTIj4o6cAjB9mdyY1yZCz6Z7rS9ciAXPHo6Gi6pnF6elqTk5OpwxQGrhPEBSf6RGPswkMURFnC7Oys5ubmUq00Bs09d2+wntc4vRORWYQaY37eiSS0I435KdIHEDPivCEVStoF+XYipxxHfDbkhLtpvYkMyseRCVIaNBOZnJzMLYWRdpRMK4bNDTLEm5mZGV28eFFzc3OZe8zdwHnEDNTHPNyR4zki+S6PZNWqUWZtcdp8bm4I8iDrPMfVjfF+hLBm5xr3jRxyoXCV2BSNVB5ShIGrVCrJGPP/7iS7DHq03wpjPC+yRMdNTU2l/CvPsrGxkZEXPp/3chsQz3/k1TgS2IncfcwhDwwMpAjYAwLpqrwODAxkUMRY+XHQo+MGOQ4/RAiXwzAeKQEtQg5jAWPZyEHkGSJ8nWeE3ODFfB+QDh1faIRSrzNTO/P1f0cF5yxYIuRyubyrtIxI3yHeCNd3otwsTwnG/LFH9VtbV5sIeA4Wb50IlNu03CCTr+cgOVvUowjPwbeDUvjzuaHz51tZWUkIC3vjhzs2chkcHMzt++uf2ynYGqeYNZ6bm1O5XE7VDiilSHgBakV2fO5+Vn2OjUZoe83Z1xl5ccc+sqzrfZZH2Q5duxx2yhhHaFba6TXf3d2dMVSRMBqh+aWlJXV1dWljYyM5TTidDE8J1iNH7jdilBzLoIaHh9MZc1KXR+isocsqTgOGkJfnwGNKrhMlUOjpSNrd2Lh6SRBpPZxkJzA6asLcOu3Mx3GgBtm9tuhpe42jkzAQWIdPvXi8U5C1Dz9AfuDjgYqEGvci3bChnIDx6hn4Voe/R4SuYxS0srKipaWlpGg5TE522M8Ye+lOq0bZFaErwTziD/LC3FEMzsyHbU3E4MMjN++vG0vOnAiz33rvBXu77EBMxLnAAYjpHBQUigDl0dPTk2DApaWlTBOayL5uJUKOchIdCU8nYJC9OQhz93PKDTl+I5fnvTtFDPS1jrCu1+nyXFJ+a8do2GP+uB4hrJn5McfohLlTFf8ufuU9QA/X19cz19diUJzY6Ovdqr6JDqCjOV7i6rwNfo6j7zLiKRui6aWlpYxOceeNz/E5tDLyIGtIW2tra+lsunMRbdPGxkaGTCcdrFE+EIMchdIjNaI1iAnefYYNJNrkgHtP0U4y8qLXjRHDS2JToiLyzXEFFyGSPFZ1p0aMjvMUFc+Ccq9UKlpfX894vp4zxoHIY4Tzma0On6OUHzW588PcgfZIZzi86LlKNyZE1x6JbG5uZpwk4LZmobyo6KIjx/xd4fo6+lcn03mLPifjYdxRdK4Ymx15MB7nDUIXho1157NBKJAbUjV04YPASB/goaGhXaV+7Z7dvPMaS+fcsO71Hr5vMVJtFVrPS82AAhExOgkuOitxcJ6BhWPbWIeFnQviDnU7qGKU8Yhw4UzSDEpSujgob07V6k7ff0mZfeP9XWZAtJqFi93587JUGlbh8IM6sJZOSMPx54IVhiNeB2GUO26QYzSER7S4uJjYmrOzs6mUAhjMSQ8O/Y6NjWW87k6VPjFXBIF5wur1TjPkzThAniNzgoIb7k5HBns9g0eeeYQ5omMcIEnpwHLNnue6Y33gQSESUfG5YsQgF4vF1FzA19NzccgFsBkOID9bW1vL3HWKvDVjkOP+Mw9fe2QAowuk6/X0zppGYXiqhGfZ2NhIl5SQlwMmZm9aieRwBlCio6OjiVDpTGlHfzwvWCgUMreJTU1N6ejRo+niGGTJO9TlkQNbHX5mXd5due8FV/M1GpkYWfvvNzs3ZNBz26yvy4JHXnl6gu9xhDwf66gi6Q+cJIKYPA
5IM2sfUZR4bznM6t7e3kxbZAy1R9ZUSNBK00vu2L/R0dFM3XJfX18mKm12MAe/jAcnZnNzU0tLS5no3ZFcKlNIoyILeU55J0dHDXJepAY5Z35+XleuXNHs7KxmZmYSicTxexQY999yP6XnYtstwak3V5Q4jgM5tZWVlbSJPT09SZk749cVQszbHJQxjkrFy4lIDWCMea2srCRIyElE5F/jrT5RgcbnOohn8vwgd2RTC+nevxsvVzr8baVSSXn0SqWSnDwMJn/TiFHzAxidLebt0RUGj9w2uTccApxLV7AxQiNShUG8ubmZyI4orFbyg/GcTU5OJqXqfAdkBMWLjBcKBfX39ycm+JEjR3T8+HFNTU2lEj+vjPDa9nZTTvVkPkbIedFtRGHcGMeUSSu5V39vYGaUurcR9u56vIgCo1zFtIhHqMgH+tBzvJ46cJ3ZrDH25yFgoS59cXExPReyE52KvAi7UqmoUqkk3k2lUsnsXcxB86rnZNWbO8MRnWp1p3Z+Y2NDi4uL6uvryzg7DqkvLi5mUk7e68DJtZ0cHTPI9YwxD7ewsJCMMYxfWqthkN2DdKKIe3qdykPx1SPk1dXVTIkHETxGLBokIjI/IPW6/fA3nRgRdnPCHMYYLxalgGJFIXu9a16Td48g+Zx6B67Zsdd7OARNwT6/hwxglDzS9PcBIvScM8qDyJA0yH6H3J/ZCX8xbREJf76+MPB5DQ0NpT66UX5gDgMfE+EXCoV0qQgy1irczhoODQ0lIwtsznrHFAzzKRQKKQobGxtLd51PT0/n3vUcX61GFjG6RUbymsxEgxoVJ3/veV7/u04YZOdA4Ah7jtV/n5RAjGTznDWMI7qGvfU0Qr0qiWaeIy9gQa8sLi6qXC6nVBJniaifz+LvKbNDH4EY0CubKNpz36Rymt2L+LsOnfNe29vbCY72csLIvcFx8PQO0PtBGGOpQwbZBccjNa9h9NKb+fn5DFwB+cWVhrQDOUi7vUUEttVF8cOXZ9AoxfLcDxGVG9u8jlPR2+4kvBFTAu78sN7lcjljiL1DF9CREx2cMOfvy78xxn5goje/34h72wzsAyRHLoim9Bz+mAf06An4F2M8MjKSIf/s97k+dzcs9SBYVyheLUCEPj4+rpGRkQSFuREuFAra3t5O5UXU/G5vX+0TTMqhVSawPwcK3JW6pPS5TrjknLHe1Eo7isVXV8p5BKN2HOqY1vC9jlEnv5/393zvzUH87PrvtzLHyElhHzl77sx7R7Y8mUI/1XvxuwQz+/Vu2G/to37xiByj6n0jMGojIyMaHBxMhDOf+9ramiSl5k+gjpubm+rq6koOHshPb29vQ4z5vebuOoE1d26Pk9HifGPTKs7wfimRToy2DXL0WL02kNo5j4qphfXLJCJJgzKGlZWVXfXHeTmWVvIiLnB5+SjPbTtMykHyrkAxyvED7p10OpE/y3N+YmQ/Pz+vmZkZzc/PZ7pzEQ15xO5wDWsPNA9M6nlOvFkvH9pv+D45oclzla4wcBZg2jvc67wCFJxHSa5Y/fuurq7c5vGNRsgxHxZJexHGYm2dxOPsVCerReXLs2AQi8VixxpX+LzoVcx7bmxs7KoKwLlgjqVSKTGq2RdPd3ArUDyv7aIqPqLzFeF+fmevv42Qt3frirnkVucYz6rLHCQt7+kcO545Mzm2CEXWOQPOfs677anRdff1cZ0c2fegV8DlnE+vQ2d9kV/QF963WCzuurGr3iUhzax3lAnf7xhQ1XNAfF5RNg7KGEstGuSf+vuf0ls//VZ9eubT6u/q15dc8yV61Ze9Sqf6T2UMA233Ll26pIsXL2Zyx0A5sH7ZHK+L9C4v7rnmleM0Mmq1mra2t/TKd79Sv/vx39WlyiUdHTiqZ5x6hm4duDWjyKWd8hkgIAwVECIEhWjIfSMR2lhP2uyo1Wpp3W+fvV19XX16zLHH6OWPfrmOFI+kHD1Oz+XLl3X58mUtLCyk7ksInQtUzKEDq8Ue4m
5ccFB8bvWGP/PFlYv6iff9hG47f5vWttZ0pPuIvq74dRnICgUzMDCQ6rkhC5G3BBpFPrxXN85FLGFp1aBhVC6tXNKPvPtH9K6736XVrVWd7DupF0y+YJcidfjRocmoMJiLX+QQO07lpT+aHdf9/HW6Z/GeXT//ni/8Hv3ME34msWK3traSMZWUUWAY5a6ursQ7gFWd16u9njFudrznnvfodf/0On34/Id1oXJBb37Wm/Wk409KaxoVblynvM+PiBjwZFS8UVk3iwStVlf1C7f/gv72/N9qfnNe1/Zcq/808Z90Q+8NWl1dzSAtXnroeXfWcmtrK4N05TVAcZi3VWZ7Qv22NvXK216pN3/yzbq0fEnTfdN60viT9Ni1x2aa+Ug7ZarOj4DLgYOPznRyH/rcg7k8xnwz0XG1WtUr3/1K/eQ//mTm/24cvVHvfva7k57DqchrmerBoZfT5aEwBwFbt2SQb7vnNn3/o79fX3z8i7WxvaEf/dsf1de85Wv0rme9SyuLK6kdH/D0zMzMrhZ9zq7joSBVzczMqKenJ8EdCKAvBAoCY+6R317jZ/7xZ/SGj7xBb3j6G3TjyI1675n36sXvebHWj63r5q2bMxBdqVRKBJaxsbGkvJaXl1NEQRTjB5yXN6VoFuKtt+7/+VH/WY888kitrq/qVf/wKj33z5+r33/876s8W9b58+d16dIlzc7OZvL1GGRymy5YTqDCoIBMuGJwxTEwMCApe7nIXqNQKGhhfUHP/JNn6jFHHqNfvuWXVVor6ZOXPqnaXE3l7nKmnrurq0ujo6MpNwlhyKMyHAIgsGgA3aAUCoWM8o7Q5l7zlqTF9UV95e9/pb7s1Jfpzc94s/qr/frY+Y+puFjUUt/SLiaxz8XRF+QDB4TmIV4nHrtORdi82fHB7/6gtms7zuzHLn1MT/2dp+obHvQNKUpGBnAacXJQvJ5zBp6emprS+Ph4JlcZS3naHcsby3rE0UfoOx75HfqGt3xD+nl0cPLgSUc2PNr1NBVoXl9fX3KE6kVDjShf/8wfed+P6FOzn9LLH/Ry9W326Z2X3qnXXXydfmz8x9S/1Z/q6OtBzS5T1Wo1AxUzR4xHRGMcjm00Qvb1/Nl/+lm98aNv1C888Rd0qu+U3nvPe/Wqf3mVlrqXdHztePp89ttbvmKQkXdJydkjcoe74flmPyfNOs5RFh408SC9+avenPZRVWlpaSnNCWKZp0zdsWMOboybjdhbHS0Z5L963l9J2lF8b7j1Dbr2l6/VB899UNdsX6PLly/rwoULCaaenZ1NJU+xA5DnWGFkX7lyJbOpeGNueEnSuxLwKDpv1Go1vffe9+oZD3iGnnLdU7S2tqanXfM0/eHUH+ozy5/Rjds3ZjxOF7bR0dF0tWSxWExerpSFTGOjAi9RcWJUs6NWq+n/ffP/yyj31z3+dXrU7z9KH77wYU0sTejChQs6d+6cZmZmtLCwkHL1lUpF1epVdnGEdMmXeImQXyyBYfamJ7VaLXPYG4H3fv6DP6+TQyf1+i97fSL5DWwM6N7le7XavZqB23p7ezU6OqrJyUmdOHFC09PTCaYmIsMZw9HwwwxRpB6C0syBKhQK+h/v+x86NXJKv/7Vv56cl4nihO7Vvbqj/46UY/JcoTNDMXjk+XEinbnqiBEyz+85StEsBDk9OJ3590/9w0/pxvEb9cTrnphk0pEQSRkZYy3Zk4mJCY2Pj6cyJ5CjVktr9hq33nSrbr3p1oYRmLxab0m7DDXPGHOFeR2amuWBFAoFrVfX9Vf3/JV+6fG/pEf0P0LlclnfVPgmfWjpQ/rbpb/V49Yep0qlkhwuGNLeKpizhiw5QSy2p5SynQPz+h80Mnfk9b3n3qtbr79VX37iy7W0tKQvnfhSPXzw4bp7+W5NbU5lHBZQLW/mxHxADwme/Hx7UBVJYHvBynvNnb8rFUoaLY1qfWtdm7VNbW5taml1KdOdjrWMF0lEsl8e1H2QoyOkroXVBU
lSf6FflUolMaqBTOfn5xN0ury8nA48g8X2cgGv/yKycxKPwzsoKmps8wYLecvJW/TGf36jPjP7GV3Tf40+PvNxfXT+o/rW6W9VdTl7KbUzvyHoMEdXYDHvkkfu8o1sFupwL90/a35lXpLUs92TSstmZmY0MzOTab4CqYLo2CEickGs/+rqasbL9pwnl3vjqOBs7OcIFQoF/cVn/0JPPPVEff9t36/3nn+vpnqm9LTJp+kLer8gHVQ3/F7CQXRMs4m+vr4kB8wnEog85+1GzX+vUQX7Z5/5Mz3lhqfoeX/6PL3nnvfo2OAxPfeG5+pL+7809w5tJ4o4euIoBA6nN23xO7ilHTjQ18ej72YN38b2hn73X39XL7zlhWn9ItkHOUAJkecmQqY/AHB1pysg9hsxpx+RFb8pTlJGkbps8H+eL/RezLy8Vni/c8vvVVXVdm1bAz077OP19XX1FHv0uc3P6VFrj0qolXQ1BYdDgDGm5zKyjl7EIEem/V7cjGZGrVbTY48/Vm/61zfpjrk7NF2c1u2Lt+vTK5/WU2tPzRiuCN3Gz4pGLPIX2Ds3zHnvs5+8R/141+JdeuxbHqvuQre+YOwL9N03fLfGi+PJ+YqVJzjy0fjnOQf+eQcx2jbI29VtvfTvXqpHH320ruu/TndfujsRuqInArwLzMIGQUsnyiVycGWKMcZoeNMBPHTqVvcaL33cS7WwuqDH/N/HqFQoabu2re+96Xv1pd1fqjsu35F+L48kkAel+CGPAtOod9fMSM7L9pZe88HX6OHjD9fJ7pP67PpnMwLnTHEMZiQMkXMFfscY4wC5V9vX15cIeG4oSqVSxrmKg33+3MLn9LmFz+l7vuB79D0P/h594N4P6HUff52eP/F8new/mbqzcUidXMRzR0euWq1mohrP/0k7deMYlDzmaSPjrvm79IYPv0H/9bH/VS+55SV635n36WW3vUwvedBL9MD+ByYY3Rml5MLJo/mzrK9fvT/WYVPQIiIHX4d6jR6aHW//9Nu1sLag5z/i+Zno3XPaMW/nbHFgSVIH9TroHdgo7Cj1GJV5q0+P6nk2P39wOqRshyaPmFdWVpIxZHg0V3eKhYJG+0b16GOP1q9+4lf12i96rXp6e/SPS/+oz65/VlPFqcz+4Xw589ivXoQktb6+npopRfZ7dEKj09nsvrzoMS/S/Mq8vvLtX5l05HOmnqOHzT9Md+mu9HuetuNqSJ4HfcJFJU5Q5GzA9Hc9Hi+b8KqMRsYXHfsi/dwTfk7He47r3OI5/eqnflXf98Hv0y8/9JdVW6slmwTHyR2cQqGQdKWvaXQ2D1LO2zLItVpNP/hXP6hPzXxKf/RVf6RqOVv25K348MbZEA6F0/291SGRm3u0m5ubqlQqiWQFzMPfrays7DvnP/zkH+oPPvUH+vWn/7puGLpBHzr3If33D/13FaeLOqmT6bkwvHim1GLWarVdHcZiAwSfdydzatKOMPzIP/6I7li8Q//7Mf9bm7Obuc6DOxXSzgECOpWUYPVI4HI2NcZseHg4KTsOEdDrfqNaq+oLj36hfvzxP67V1VXdNHSTbp+/XbfN3KbvHf5ejYyMaGVlJX0mkDQXNFAi4h27IinKb4Hy0iTPcXE3dTNGpFqr6otPfLFe++TXamtrSw+deKg+dulj+pPzf6LXXPsajY+PJwdyeXk5U1K1vb2djDQG2Nt2OjxaKOx04opIAZeVxGs8mxm/8c+/oac/4Ok6NnhsV17bo0MUPTll7wAVO3Ed1C1seaOgnZpSSsnGxsYy0S/77nCkn4u88qhYZ0vpIDIYHZT9IPRisag3PO0N+r6//j49+S+erFKhpJuHb9YTxp+gzyx9JjWo6erqSmVsBBvkNal2wInb3NxMHQTpR8/nxdx9q4aY93vbHW/TH9/xx/qlJ/6STvWe0ofu/ZB+7tM/p1qxprHusYwzs7q6mu7+3tjYSNwOyue8WdHq6mo6m34uSUXh2EISbMXRe+r1T0
1OwPUD1+umgZv0zHc/U3978W/1RcUv0sLCQiIbg9oyT9aaNfWUXV6jpIMYbRnkH/jLH9Bf3PkX+n/f+P80rnGdXzifSc57noPDgjICtvA6PDdgkIvwGrnXFyOMYhgZGUlGGcW313j5375cL77lxfrGB3+jVldXdbrvtD4781m9/Z63678U/oukbFccBoaL75lbzI14yUG9tnWtbijv8dJ3v1Tvuudd+r2n/J6GtoZ0buFcXc+YCBilhTfrz4TxxRN0yAuhhNyG4BJtFgqFfQ1yoVDQ8aHjevDUgxNJZXt7WzdP3Ky/u/R3GpsaS1cO9vX1pXVdX1/X3NxcMsrunHmaI9aQ4kR4owSUt3fKavRwHR8+rodMPySj9B84+UD9+V1/nnLdMNO99hsD5wrW4TnWm+ci6gMtwPDRuW5sbGxXzrbRcc/CPXrXXe/SHz37jzIMWEezMAYQblCKvn7UHudFyOx1p0c6N8UdvTE0NKTx8fFMDpa9Hh0dTZGkNw9BTjxY8JIvbi6am5tLZ53/932qZ5Ad9bt5+mb9v2f/P80tzelK+YqGNKT/+g//VScHTurY+DH19fWlNY8OETW+PDfQOnlkutBFTgFzbFXX8Pc/dtuP6UWPfpGe/cBna3l5WSe7T+quubv0zkvv1HN7n5tkGPb3lStXtLGxkbgdrJszmB1Z5O+RbaopxsbGksw3g75ElMBLx0Z6RnSy76TOrpzVDZs36MqVK1pYWMi8SKOS2+Ys+oU7rdR0tzJaMsi1Wk0/8Jc/oLff/na981veqWv6r9HCwsIu3N3r+TDGbEh3d3fysvFAnB3rh6dSqWS8FRQEBBNaa+5lGFjAla0VFQvZXGOpWFJNO/kh5uERPjfv+CFxlqTfB+v51TzvtdnhOawXvvOF+vM7/1xve9bbdKz7mObn5zNC6MQOHBQcHLxq6aoHS4QbBY3vOexAg6Ojo9re3k4KkZuU/H3rzf1xpx6nO+fvTHBVrVbT+fXzOjFwIkWY29vb6unpSZENkNdejFo+w1Mb5Db7+/szZVMYNm7MadSoPf7U43X77O3ps0qlku4u361rhq7R6OhoclIg5HAWHEr3vLATBuM1o363NrINZI2yaqWF7Js++iYdGTyiWx9wq6rb2d7EKHnaGALzEwGTOyafz5z8zvD7In9cLBQzBhlyJ4odFrhDkURoTuLySJgObpD1iJ5cTlHQrqfyhp8ffm9yZFLDfcO6XL6sD85/UN9/8/frxMgJDQ0NpTnAs0HmHXrn5YRR743taFYndEyhUNDq5qq6SjsM/N7eXvV090gFJWSM9WUNKVN1HRpzx+hGjBxOHheU+L0FeVFyI0Y5ss3Xqmu6sHZBjyo9SgsLC6l9M6W5PAf7huzDQ8jrx/7vLkL+/r/4fv3ex35Pb3vu2zTUPaRLy5e0uLKota21XYqTyMyjAeoZvTUmRgN4wxsjEFFJO3cl0xSf3x8cHNzTMDCecdMz9LPv/VmdGDyhG4dv1AfOfkC/97nf05NGn6TCUtYg40HDqHZyEBvuUX6MjjsNc/zgO35Qv/+J39ebn/lmDfcMa2Z1Rosbi9oqbOWSOiKDEdYjxhmGcj2oy9MK3CHa3d2t0dFRLS0taXR0NLWS22sUCgW98JYX6vFverxe/4HX61k3PUvvO/s+/cGdf6Aff+SPa7j3qqFE0RDVUwbkzNeYJ+bweD6RyL23tzd54H5zmJejNbIvL7zlhXrcbz5OP/UPP6VnP/jZet+979P/+fj/0eu+/HWp/SRzAXpcXl7O5ImRJeDJrq6u1F8XR6+7uzsDUTuT2Z29WOe836jWqnrTR9+kb3v4t6lUKGmjurHLQDl7V9rJgQObOwvYm4AcJJRX2ajozrk707/vXrxbH5/5uIZKQzoycCTN0xWnE+xYdydswZMgHcW+eCqnXC5n+qWzBuxXI5D1X3/2r7Vd3daNozfqM7Of0Sve8wo9YOwB+tYv+FatLa8lo0QFBs19kJcYOXsuHF
nh8zq1/sz9GTc/Q697/+t0zfA1um7wOr3/yvv11gtv1eMGH6eulZ0gg3UjLZD3fhhI+ro7XA2CBdJJSsmrPJp5tmKxqB+97Uf1lGufoiM9R3T37N16/Ydfr2KhqIcXH65LS5dSx0i6GcKxYU68z35NVg7KKLdkkH/1Q78qSfqK//sVmZ//6EN/VDfXbs5l17ni5wEhijg5iIOCYedA4TESWVC/7MLRSPnNLzztF/SKv3uFXvSuF+nKyhUdGTiib7j2G/T0/qfrrspVwgKfHWF3ZwgSHXtk5lGqG8ZObeL//vD/liTd+pZbMz9/0Q0v0gNKD9hllD3ajblkYDj2R8p23XJvnxxuqVTKwGx7HcY4HnPNY/TW57xVP/I3P6Kf/Ief1LWj1+o1j3+NvubE1+jy5ctJkQIncqEC0QxOmn+mQ0tejoVj5v+X1xe9Ucj30Scfrbc99216+d+8XK++7dW6fux6/eyTf1bfdPM3qVKppJ68KCgULWuNsndZgitBvb2UbWfqisp5Eq1EQu+66106s3hG3/HI78g4zE5miqU0LtfeDCbvNrCDyqt96PyH9KT/86T07xe/88WSpOc97Hn6X0/9X2m/IxKFcsWYra+vq6+vLznWpMNwhFgPJzvyyuvuttdgHcobZf3o3/6o7l26VxN9E/ram75WL3/My9VT69FKaYewWq1WM3Px5hXe+5nPRuaJRjuRN47jF5/+i3rF371CL/ybF+rK8hUd6T+iZ13zLD2x8ER9buFzGdlmvbxXBOvg5DtkHa5EJIziTNXrqb/fmvM75yvn9V1/+V2aW53TRN+EHj72cL3uga/T8tnlXTfgIR8+N9bY03itlJC1OlqDrH/i6qJ7gf3CwoLuvfde3XnnnXv+bYww2TBJ6YChKHn/WI9crV7tS8yB8Tra/cZw77Be/5TX66ef+NMpnzo7O6szZ85Iyt5QQo6SjfKSFBRqBvrOMcCd3MCtH9upH8Zozc/P6/z587pr+a6Mgc373OgkSbtbX/rf8GxEdV4j63nbRgyyJD3j5mfoq2/66kxObHFxMVNehbKRlIlcPFrgc1FODjM5i9893UjOaDbH9oybn6Fn3PyMjLO2vr6+i0PgxDOewRsOAAlD6omsTs/bR15Cq7DkU298qmo/sdMQQ9rdQtK/RpKUK6Z6zuZBjCde98Ska/Lm7LAye4Lx9GfwNBN/FxESN8ru+CHrsS52r1EoFPTchz1Xz33Yc9PcXI85s3uvCxmcHOsyjYPNZ3VSzxQKBY30jejnnvZzet2TX6e1tTUtLi7qwoULuuuuuzKOPnvhCISX0uHIS1f1JfobY+3wcr3Woc08U6FQ0O8863cy6Zi5uTmdPXtWt2/fnuFz+F64k+rDA5yDlnVGQwaZzS+Xy5mfu0HOI7PEriuFQiGT4EdI8fYi+SLmTzhoGAcv6neyhR8Yn3s80J5PyuvU48rUS7VimYgfHsqOIK9Ry7ifMmVtG527f1bs7uQEiryXv/dexpSDA9TNy693xPvda+5RlngGn7/3sfXcvbck9c45RMiSdrUtdUUGHAs0yH5wyJaWlhqeO/+H7DrD28vM+Owow+7QOUfCo1QcD+Ta+4pHpbCfzMTh5xVolt4AsTUj5C6+Zz6VSiWhEO2UOzU7d/4/8kyQSb8ly6Hq2KvadZKXfBWLxcw+uIyT6+3u7k4GyOe7l6yz7l7SFKNwlwE/Z57moCyHCJ9n8KieyA+0Ma9MK657I7odGXHEIJLlvFpAUkYfIy8xD46ss86kEurVte81d9ePrInfehebvjj7XtIum+TpHM9lk1NvVvbz5D131BoYZ8+erUn6vHmdPXv2cO6Hcz+c++fB6z/C3D/f5n049/t/7vVGobavyb7qLZ0/f17Dw8MHHrK3M2q1mpaWlnTixIkM7H0494Mdh3O/f8bh3O+fEef++TJv6XDu99fIk/e80ZBBPhyH43AcjsNxOA7HwY6GcsifL57IfySvWzqc+30xDud+/4
z/SHP/fJm3dDj3+2s0GiEf5pD/nb0O534498O5f368/iPkMg/nfv/Mvd5oKEIeHh6WJN1zzz2pQQSF7PRXnZub08WLF3XmzBndfffdunjxohYWFjKF19TwDg8Pa2JiQkePHtXx48c1PT2dOhPRfWtsbCzTK7cRD6hcLuvUqVNpvj73s2fPamRkJENx9/IDrsGjg8vi4qKuXLmic+fO6cyZM7pw4UK6ztC7LvX09GhgYEBTU1M6ceKETp06pePHj6f7YulEMzo6uqv3b7Nzj6NWq+16DtiFzP/ChQu6ePGi5ubmUkN1OhlVq1X19vZqcnJSp0+f1gMe8ABde+21mp6eTh2tYtu4PO+ulbkzf57Bu7LNzMzozJkzuuuuu3TPPffo0qVLqRUi7Oqenh4NDw/r2LFjuv7663XjjTfq9OnTmp6e1tDQULqsZD+5aXXucQ9g+sZ1n5+fT3eA8wy1Wk29vb2amJjQNddco+uuu04nTpzQ5ORkuuu5UWZ+q+vuZT6rq6taWFjQhQsXdPvtt+sTn/iE7rjjDs3MzGhzc1MjIyM6cuSITpw4oaNHj2pyclJTU1OamppKZ7fZlp7trnve80SG7dzcnO69917ddddduuuuu3T+/HktLi4mlnSxWEw92qempnTq1CndeOONuv7663X06FGNjIykmvLI9vW5580b2a5ZkyG/ixddMzs7q3PnzuncuXO6cOGCFhcXU48FGpLQgWxiYkJHjhxJOnO/pjF5ctPo3FlPdOP8/LwuXryoCxcupDvX5+fnNTc3p7m5udSCEn2/VxkRtfbceX7y5EkdO3ZMR44cSfof3enXrVYqFZ0+fbru3OOauzyw7jDm5+fndeXKFV26dElXrlxRuVxOZ9NrqL2j2OTkZOqeR9c6mg7t1/0vT97zRkMGmU2lfy3t5ui56qVBXm7DBsRaXQaK2OsfvVUjyqnZvr2xllZSuqHGDRiCT0mTpMzhySuB8lrAWJtJfSAvSWlj/Xn2ujWm3tzzFJQ3eNjc3ExzoNbSP8P3hFIaSaklZnz5RQxeqL+XgWtm7tEYU6O4sbGxaz29Zpo1p3kGjenjNY2NGuRW5h6fgYNPuYk7kbFjnddo+gUS3lyf5iXNtA1sdt19ztvb27tqQL2fOY40Mux95L3DEnLSbO1os+senyfKkNdau/MR2/nGZiL0bPemLHkGOc43zjt+Lk11arVapqmG6xVpR76jnmG/YumQN3Jhjxoxyo3OHSe5Vqulq0OZu/c08AYgsUban8v7TwwMDGh0dFTT09M6cuSIjhw5khvE5OnMvLkPDw9nHAnkG30IxO0lcJRPMkdKtbwG3/fAy7xi33w6jCEr9c7AfmejqcYgfpBpxM4dvHhM3lnGi8QZXky+vLysxcXFtFjusVInxua2O/K8aK+Do5/szMyM5ubmkvc3OzubWqyxgWyYK69arZb64S4sLKSNpG0c3YIwhm5oWn0evnotNLWN1E/ijXu3K7xvb/COJxqN70F2p3FF6tfeEU16bS83hnGgvV1ijODvy3xSVGB5/ZK5l5q6UpTu/TFcuccr8qjXZJ6uWGmystcNOPd1Hi+e6XhbE3ppcXExU+/OGcBBjU63O9R8TqPPlmeM/YrCcrmcdCU6x1s4+vtQe+5zQMa5GYooLa9tsAc/jczf5+4OgHcO86sL0fUYPM4ic/SLa2IDHVppTkxMJKSFyyW8M12jDUKiLHitc7lc1sLCQlp3UML5+fnUscv7XHhgw5pLO133OCOFws7FJt5WtZ1z0JRBdhgDw3P58mVdunQpA12srq6mTkpxMfE0aGOH4HExBFBGbFfXCQOGMgLS8qL/+fn59CwzMzO7bgSJCtVbCxJlAFlKyvwuRm9wcPDAnAxXrp5K4OCTPpB2vGnvm+wRcexo1WlFGxUWBswPe7lcTs1m6LMMekKagHaY7kjcV0ZhL+XFs7jC9S5G3CQk7X/5+kHNGSfI4VPW229CwqmMncOaVZgH8SwOu9
OcCEeansWXL19O55m72UEGkPM8JCwa5VbmFtNIS0tLmpub0+XLl9OtQ6z7yspKMqoeFeNAeIMTjMzIyEi6mAWjGHVus5ch1Atc/OINdH2lUkkNZJBp7//tDnNsj8nLETnQoXi1516RPnN2uWbOOGhcKgHUjoPmTj8BoCOm3hiKPSyXyxoYGEhX/fb19e261rWd0VKETF4hPii3xnizfWmnG5S3E8Qj9M46pVJJw8PDmp6eTp5XpwxYPCSeM15YWNDMzEzKkXBYeB6iNA6MRwtAw7VaLXPvLYLa19ensbGxzPN0akRoiU5KGDZ3KMrlcrrhCZgUaGhycjJd7ec3+MQ2oJ0cHtkTPbgjgUHGSACrYtCAeevdy3vQEVvkIsTuPnjl3LmKwiRakHa3OL0v5uxnGHmBMxEb7ks7UB6RMYq13Xae7T5DjPRZc4zw7Oxs5vvFxcWEAGxvb6c9wPGPHe7anVtEf+L84NiQd0VfoDNxDsh3u5EBWRweHtbS0lJyMLxPe55MNRJlxrn7pRs4mRhkUpY4bN473q8PpU89PeUxXhhlvymPNrfxQodG5h3bCiPbs7OzunTpks6fP69Lly5lEKvYLjaiuZJS5FytVjOIZ09Pj0ZHR5OTh/y046A2bZCJcN3r8xZ8fjWY51SjoHv/XAxdf3//LngveqvtHPoYTXp0gOc3MzOjmZmZzNVcLDbKyaFqz3UCq2J4ISLkPU+7Iw8qjVB1bBvH53Z3d2fu2M273/Ygoci8PFWEeDHIOEPSTh90v9mpFXirE/Pnq8swyivuw/LysqrVasaA5eVr3UAf1DM4MoEj5HN1ReUEFyKf2LP7/ljzSGRkvT2FduXKlXSePQr1tpcgW7yf65p2zmge0dJbW4KcgFo5SZQ5oaeYl7deBaHEEEBwzLta1PkuraxvNMqcUaJjv/IVo0s+GgKUcz34N45/Xp/5+GrkPNTThQQnnlotl8tJp4AAuW538ip6nXUg8CoWixobG0syldfvvJUz0XR8XQ+q8ybifmcoSouv3jMUoYPs4D2NPenertfh83YvygUMIcMDBSpiYYmE/fowfkaE7IcaKMz7S7eal8p7Fl/bGGV6ZM+hlpTyOkSXbpD9ftuDvsWHuTvL1yNkWMmOTDAPSBR41XjW9frfHtSI0TE9imOfYpwhDriTufye1bxIs5Nwdp4TERWt51j5fU/NMF+HrO+LCDlGxU68ZP4YOXgfRHHIkzsaPJefx07Jel6U6X33PVXmnA6cHNJ46ChPH7BvrlO7u7uT81qpVBKhMZJpG3EwYoTv+t0Ns1/ygp7gLm8IllTMQLgibYfj78RFdx7Yh2aQrjwH39EE1/EgbvRid5mODjL7AOcD4ywp0y89D2FpNYBsCfB2FppfGQcM5AvseR7fWIeq+bkLoHuEsZF3O4YsLzcCTBcPrEPSzrqLwuIwdLFYzHjdvjntet7xWWJpgkPVkBWA5zgA7r1yKTgGOe/qs05DwFG5ev6PuQOforCQAS9FwBsnT9XslYqdnH9UuPGShmq1mimRQzF5tLCfUe7k3GOagyjZHVFPM+VdlZcHWXd6vswZJefOtKMQzgHhAnr4LMiRX3/pjlu8Xq8dYmA9Y+wEUr/gAEeT0hqY6oVCIWMA8y5bcXIt61AulzU4OJhh+rpjvV+UHNfcGcsenTvLGKgcdjpOPi83yJ5icuZ0no5pVp78PHqE7HliX3tJ6unpSVE7yII79pC4VlZWMrcJevTsa8SrncCgKYPshhhPH4iCCfNzHqxWq+Wymrllhyg4emLxZpboOTU7ogHDGLNhRGJeFkSkS/7M2XYupM6O9ENwkFBkVKyQF1BGwDKwHzk04+PjOnLkiKanpzU1NZWYjUNDQ7vg34MmdTlSwdwdVvLaY2knOvbyG3ckGiGAdHL+eTftAFHj4OGcdXd3a3BwUGNjY5qamtLExIQmJydTLWlexHAQzxDnTqogL0J2QhfrTprAlepBzTcaBpcXYEjy37
GuFMeOKCmPOe6QKa92I/4InaLv3EljfVk72Mbj4+MaGBjI3CoFygITnr3C2cYoY5D7+/tTCg1d5M7GfgFBdICinvPonKClv79fIyMjGUPM87huQX7yEC1f62bX3XW75+zd0cS5h7AFzE89/cjIiPr7+9O8pB2uU7lcTiVd2ADKcaVs+S4RdatOdcsGGW9/ZGREm5ubGhgYULFYTJ5Sf39/CvU93wy5gVyJRxEIcay1Y5FgwbUC90bv1SMav6i6VCplFDzKiA1wr2llZSVzyD0PQXR9UAorPkulUkl5kvn5+USEKhaLyfMeHBzU1NRUxiDDEEThxrKngzJu0SBDeGH+NC8hWuPwE03wclJXO9cBNjr2IpC4YcMjZ+6s//j4uKanpzU5OZkaDXjt8UES0+rNPToSXv7hBjky2z1tc5A8gzhfZMWNLxA1ZK7In8C5Z65+F3vendOtRP1xfZ1TUM/Z8QYZx44d0/DwcLrS1K8PxeGem5tLn+VROAQmImze35+rEWPs6+7Rn+fZWRN0HWcyNsvwvLajKnly3u7I46R4wOVVA9Sbj42NaXp6WidPntTExESmuY2kVFo2NzeXroakTK6vry89QwzQXA81ikowmoas3WsGfsPj7urqShszMDCQLmCnq1d/f7/K5XJ6KO5WlZTZcPfIOsF+ZNQjLDhEh4DhweLV9fX1pcWny9X6+npGmCK8Fy+s7xQU6RB4ZJnC7IW1DgGKHI4X3lNqgEHb66AclGHwrm8OWS8tLaX8jB9+yrVga3o+6r6ArJl/PY8cY+x5P2+G4OkCys3yiGkH9RysvRsMj+K82YSTzxyydmN8kDnkWMvrbGXKh+ig50SpxcXFTETqxFFphzwF8sWzxXK/VkZeOsNTcSCC0cGkC9fExETSP5xrHFYMNc+FznLomtQh+tmRgUZSZtERioY55kd9Db3UyR38SNrqNGJYL+9NUOdzx97gYI6Pj6c6aJwhR3ZpiML3pFt5JqBtt1XA1g5zNzpajpC9uFu6Gt6DyQPDdXV1aWtrS8vLyxni09raWsYbjYsaPTNenczB1ns2ryOLioj88MbGxq6owD3SmN88iPIQXy9nchIZbG9vp3kzJ4eOmFMeqzAe3lYJCvXmzdeYyyQaiMxFDnI0Cr6+0em5r6Jk50S44vVclUdDXhLCfqBEI5JykFFnXp6TKNLhTs5AZMS6MT4o9CeyZr2qgzpjUjRev+7G2IlQLhueinLHOZ7TdpA4h3w9xyjtpLacZAnUC+zMuSC/WavVUrlpb29vki/WCXi7v78/w4lpJJjx/6+nh/P0715RtQdVGKhOln3uNaLRj/wA72zm6RgiX0mpfHd1dTVDwAR59LXy/W3HVrVkkDmkg4ODWl9fTwsdk+REwtXq1TabkcUWB0qOB6pniFuBrH3+7h339fVllFFkRLuQSdrlbTIXYBAvJyICarYnd6PD5xC9cfL5eOIoHggfrux4LwhgngNHONmvThO7fB7espR9z4tm3BN3dup9AVf799GwOe/B5RblC0QanYl6JLqDfI4Is/FCYXqpU722mgeZ50YeY0mcw9M0AIlNP2JFQ3SceQZ3uDvZ7S3qB9dfThCVlDmbOM7ozsi8XltbS05cT09PqlAhjeeymMf6bSY6zpMPJ6oC/zpCVKlUMoQy9BJIljsjnUIL/f3inmJkV1ZWMgggf+OcIl4elEWEwIPD+PMYQEb70Oho2SADtVSr1QzEy+KTBHccnY10CMGNb1R4ed+3OqIxBm73Np8oACA99+aAIcgrEMFJO5H00NBQgiM9Rzg8PJyioE4qsXqeqa+ts+ExqkBcHKh60Q/fMw7CKLtAe3G959mc3Rthsfuy/Mbn7fOPiisvxeLcghiV3Ve1vPE5ItQGRyMaC+TgIGvT49w8HeCpGO8VAJMaYww86SgeLQ6JdjxCiqmlTu5HdALyIjbW2p1mnDbkxZHF5eXlDKmOYMcDnKgTmo3U9nLW8owPKQRgXWc3Ly
0tJYY1+WV0Uqs51npr7Skt2N7ob+cMeY39xsZGys1TpkrwhEOBo+EljJQxehDnTst9HiHDrMMIIxgOA0Da4iBg5MgFOTyWB4l2+sDHuQMRISCSUp6A/DCbhmLAqwJKYv7kmmmWfuzYsWSMJyYm0k1Pebnkdkc0DA6L+XPzjEBbeNyeQ/OI03/G+3R6fxyFcFjPjTGRgsNKHmHeXy0cY4QcFZfLtu8BRtkjzvt63ns5cdKO0YocCEe19oq62n0W5JkGGNygBTw9MzOjy5cvJ6jaWdTIDagdipP35f8jZNnJ1FIeTOqIE/OTdhw1/z/+7c2VIM66/GNcIvrWagCTh/p4+sWjZGSc50V3LiwsZC58IYU5Pj6uo0ePprxzRErbdX4ItPr7+5P+Yy3RdZD8eE5uOSuVStrY2EjkVjgptVot/Y73Q8fJ8/JcjHDMszc7Wo6Q8eJ6enoyNHgXNm/i7rAkGx2jUPcc4+e2O/Kie6BR5rq8vJwOMvAjuRwE3pUFhCOPuJ3F7GxDNvkgGal5BtkH+eZa7SqD0wk73l8WJRUhJkkZxdwpo+woSTTIHiHnRTP1jFqraY1G5xwjCYerYdLm5cvq5SbzDFyn0xt583fDzO8g624cpB3ipZ+NPJi91fyrz48z6D2qKYmDiU8nN3LFnEWPFvme53PjF+HiWGHAszQz4hrkGWVfs7h+rqd4HwIDR4b6+vq0traWnqPdMsWI+kTSq0eD6HN37nGevCaZ9OXo6KhWV1dVLO5U4cTGSq2sta+563ZJmXXZ2tpKLPRKpZJ05Pr6usrlcor0o6OPowE/wev04UftBVW3MloyyHzPQ6N4XBg87+qb6wQSV/h4g65c6+XUWtk43zTP9zLPlZWVlGN1BwLjFWEhh/cQQmrxJicnU2Sc12u5E4MNd+UahcFhYHLMCCJrynpAHuFWKn7X17tWq2U8207n2qIBc6MQo8ooE/HZUWQH5fy4ccLRjF3mHMKql2+mDzHzjPLPs3R6+DM4FBkdAnf0gP16e3slKZV4MP9odFqduxsGb3NLHtlLmfwcIsvIqJOe2CvW1m+u6jQs70Y5GuQ8+JrfzYO0JWX65sf5StlrJNtF35CLvL4Q6APXL6w/RhnWvkPXoIuwmsfGxlK9NXLSjrxEJ8ZlsFqtpkt/hoeHVS6Xk24nSuZ3YmqsWLxK4qUsMHaxO4jRUtmTb3wetduNA4c5r+0aHiyHCa8pHo52WI95c3dyDZsDUQLHgOfwiM2fHeMAcxbWbOzCFLvSdBKujiNGmThAKysrKhaLmaJ1Xw+IUuwPJQERigEewzhKatnBqOdQeO7VIzXkAGfJDVqMiA7CoEXHwZ0cr3n0rm8eacQGIhg1N4rR6fC5d0pm6j2Hv3z9vSxtYWFBfX192traytyV7UhFp+D4OE/2nUGOuFqtZgwTuqNQKKTfR/+48fb0QSSudeKM7mWM4/PlISP8vetId4b9d1iPvQx+o2se4eoYQKEL2WtfT0/9STt2oFgsamBgIFMbvra2ljHIUnvtS90o+3vA+SFY4kIML/Gj5C/WpXd1dSVjvba2lpGfyK+IhMdWn6Pl1pmSMsY4CpYr8b0K5MkRQjmPBeSdyrNFOAgjBOziuQNn40Wj7EoHWMY71ERDfF81rGBwoBAkYHjIWx69cIh4ZuAacrWQIGBKRjJVT8/VK9fa8RijcYhRrjsQTvSBOYmc5clLNNCdmmNsyOAXBlBD7QSSrq6udH/47OxsKlnhTHh/370itnaHn83Ibnfly++hgIgcSqWS1tfXU7MHZAIEKLbWZLQy/3hecX5JN0lXI0dHcvhd9p6LJCiti9E0ZznKTad0TYSt3dEHDYpwZ9577Tcn/6xWg5gIV+OIxRIyRxf4rFjpwPsh493d3ZmuWcvLy+rv799lPNtFtdxBYQ7Iy9jYWObWOL+7AHtEJRDBIUEH8iMpPWssW4wpj1aDr5Yvb6z3QZEy7wX9NPYGcuIBS6XSruu6vL60k0
aZHDHe0NbW1i5nwD/PCUeFQiGjePC6uGQ7XrCdV0pxUPAjX1l/FBCRcqVSSYfIYUYnkLiHCDOSA4kADg4OZhQgh6ideechKgzWjAPuTWZqtZ27tKMDxIv38K+tztVhZ5o20CHtypUr6YYh7yWOEahUKskYb29vp5wWEB5GznP5vga8T7vyk+dY4Bh4PlZSJtKHYzEzM5PIOsPDw+l7qiuQG+YdYdlGhhsXZJJa3VqtlkouI6GGgQwRDdG219MDbow92mkXso6GOEaP7gx5LtbTGvudp3rRdYzI8yLmvd6Tz3eZ8PvIkYvu7qttYPkeeQWpcK4QuggUiQsevN6XdXP93gmj7HyhyclJVas7pblwEjxCJrKP5Dt34nhWzq03V8ojmTb7LO3dpixlFHJebi02sIdCjoHDAADzRq8jj7DQ7kHhMGIAvPsQEYovqBsJ99ZHR0czfYnHxsb2vDXpvoiQPffGoVhdXU3zwcFxZyEaZ5oV0EsahRZLvdibTkXIHjnE34GQRpkFBfr+nNGR83lFCLjZ+dWLIJaWljQ7O6srV64kshGeN9FmrVbT0tJSgug2NjYyt1pRez0yMpLprJQXZbXyDPXW06F/7zXvDiiGmvIQJ+vgjOKIwqZlvdvJh/s5RalubW2pu7s70yOcPfG94Xkg3qysrGTgaIesnfyT5/g3s9b+u24U3Sj7PN0g+/PkoY3xjMTz4rqtVei0XoSMsSKQ8YoRvzSCM8m5IF8bUztukD1l2Sk0yNcC5w20jBSjO8a1Wi05IOg35kUARkqT9GTsdugIVzvP0rZBZuR53jHJH3v8kgfy+zKJkCOUFPMvrRplhzQQsLzP8ugPQeUz2ViY1TCph4aGcqGLg4ar43vHw8sB4xlpwUdtta8tELb33OU9EWoMNjnlToyogPxnMSpdWlpSqbRTVuclR85GjY5CO0bN19QdTu8gFe/dxeMuFAopfUBeynkUDsnH9Agy2Cl5isqddfPGMl594Oe3UqmkCGFkZCTly/k7hzS9qkDaMdCNrL1Hl7wXjSVQ+v4+7oT61Xu1Wi3lBmOpEWuJMc6LjtuJ0vwZ3EAyX19/rwV3kp3/rhPv6tW7uwPQSiAQzxtygUyir/m97u7uDErS3d2tra0tLS0tJfKo5//9vECidZZ7dKLbGZwfYGt+hhxsbW1pcXExnUfmhuyzlkTUyAe6L/Z0r4fm3i8GOUYPsRWis+1QVMC/8eYeasHq3SSDsvI8Y7MjRh15i+eCjpC6QmFz/OVR431hiPd6vviV53V4zmFdN4gcSPbREQPvkdsoxNbonH2tXJ44zH65QLFYzBCS+P/+/n6tr6+nQ+Pv1ypRql6k4obM87CRHMXvOxu7q6tLy8vLuwgkrCdzq1aruw681DqZrt7z5eXJMQrIAvIAGdLPYlTanoPD8eX/95t7lFecduBqPtedZuboV+Wtr69n8pv+/m649nN2mo2S84xxPejajXGsZyei9p9FAh4wrMOqeQSjZvWQP4dzbiSlkiCQTapJcOQLhZ22n8zRjTIOHs6dpx14dSJKlnacLxw7HDd3GNkHzi+f72VT9dICHsB1Aglt2yC7MSb3EC+b9xtwuDaMDfbCcSDfvDIhV/rRGHfC6EXP0A8KkWHcBJ9LNEo+p4Mwynnv7wQYSDVea+mUfuAV1tfZvtIONO85IKD6CK22Y5Cjc8S/3ehRpuXKxmEmcrkccC5xcGjTPfBGjEIce+2zH0LfC4/knHCEsqzVduo3OfhR4ebVn7ZLpsuDwuNzOBxKpA/vo1gspmifZ+J3PTcrKcPRcKLXfnN3RYrSLhaLGTIo55C8J93n+HckVGLg4rNHg9nJ6NgROBS8r7FXoRDdey0tjkXsl87P0Fucj726wO33XD5vh2pZewwun+e3PA0ODqpYLGptbS0hE5VKJYOGSjvETH+evGi/E6Pee0Xn0x3QPLSBNfPAwOuzHcFoVx92xCA7pMjNJPPz85qfn0/9Zr29Xa1WS0
pmZGQktZnkflhq1Nyr5XMYLvj8u5k550UEHsHEnKl7055b9jq9gxKsvUbMHXEoJWUa10fCEArSvXCexedfKBQyDETY1/6s7czb5878pWxf83hYUGAOdfE9jQcmJyczUTwvDEW7nmxUXl46Fut5UWZ+dSQyXqvVEvszXhm3urqaSeUMDw9nDEorMhYjn6jEcX7cyMa/ZW9wvqWde8153kKhkDlTMKMbmTu/4z0DpKvy7Dk+fg95wAna2NjIpI3c2Do57iCMMXNzY+xy6tC9pOQ8eOQILO9OW7x8BfkoFArpMgo+yytGGq3w8PXgHI2MjKQI3A0nz+gX6fT19alQKCQiHbLs6TLW39MjGLV4Ztod7kBHgrHrdy/pcseGQMYdSfLjPT09qS0oQSY2II/r1Iw8tWWQY87B7+XlJpa8+20RSi5sn5ycTOQovyg6wmEOZ7gxbibaqWeMY84bVqYfYPe43av1TYnwy0GPqGBR/BAQxsfH0yUX8cpCImRQAfe8EVjWYW1tTcViMUHCfojaidRcKTqMzp6w3x6R4yDgUDhJp7e3V6Ojo0kZ+IjebytGzQ+bQ3k4PJ43Y/A79NgdHR3NXOeJIgYGHBoaSo0wRkZGMn2APd/ZjkF2581bpbriRZGwPy7/klIHJI/meXYges4Fc3d0YK85+n5KO4zZSGJy58CNlxsjh4oZEXrMQwpaMcwxOu7t7U3Reqy8kHacS7+gASSAJkWbm5uphjc6bN3d3UlHecc9T/s14mjENYfFXiwWNTQ0lHG+IwxMn36cSuRiaWlJGxsbmUAGufAoMw+y7gRs7bbJycV+bzkOozPIcUgiSQu9I109E5zTlZWVRPqKzp07X42Mlg1yNGoOVy8uLiZjPDs7m8pAyCmw8Q5Zj46Oanh4OGOMeW/P23n+iGhKaswoR9jbI6eYi/ROQNFYoAgiaS0aqoMeeZ4+yr9QKGh0dFSTk5M6cuRIuuQC0hwRGuscWZUgHThaGxtXb4WKz9kqGhAj5Fh7K+1ED+x5hPg8R8nv9PX1JVIJcKcrphgVNTtfaUchoby8VSDr4gQtns/5EjS18PuI+V1qI12unOFKSqcVZ4KvrlQ9P51X+5z3ezilfsVhjF7ZL2mnPAaj2ohBrtVqSUaJGuPfYbBAcvJK39wg5Rnddo1wnLs7PDHSigZZUiYtxFW1BC7STltfN8ReOuoytldv7v2G54sHBweTzPkZdGa0k2F9zisrK5leDKQY0JuRnNZJLoqPGB2zvt6gBPn1NAjOM3nxuE+Sct8HneRGudlnaskg5xljjBOKPDZKwBuhdMEVcYR0UMR4fs5GZOEcamsWwvOozqHwvFyAHzAn1TBHlKZD3AclYHmD+TlchdPjt6yMjo5momQgLSC/9fX1zOGqVq/2d5WyhfEIXl79ZzNz5itKMw/ak5QOlP8NDGtkhWhaUqotJ1IeHx9P7HevOW0GxXBlFuXB0wIgQM5ORal5NIGiQlH4GcKweG6rVtu5Sc1lrF1HyB2LmON0795/J5K13PmQdm4+Q7aQTfL6jc7djSfyGPeLs4tzEA0xjpc/c4zC4qud4etaL4fM2vT09GQ6WhF9VSoVSUplW5JShOxVKq4PHWaNxjg6V/vN3XVytVpNka+0uz7Xc6zoZeadR2RzJ8RtyEHoyYjeOvsexyaioA7Dj4yMZO6hRteT3vNcvreUrUeqa3Q0bZB5UA6hG2Mv/+C6NBp6Y7DiRuRF2e71en4TT408B8zLZg1yxPfj9xESj5cZuBFbXl5OhfJssHvynYJf6j0D+UkiMK+58ybuUUG5oORFB5F5uLa2Jkm50XGrhiESSDBWKB0+0+WEA4b3idwgWzh87sHS9N6h2WYVQTRizv7F6QGa5cAityilmK+VlJszd3lhL/v7+zNXDLajyKLhiJCtGzHky7vpsfYYFGfCY1RIhzgRkGikUd6Bk7A4i3kGmUja6+w9Os5DQ5ClWGoUodNWhyMxMTVAtOxyDP+GdI
0zwzc3N5NuBWX0iM6Jm35RRjOkLt9vL43E6fWo39/TjZ7bBXdKt7a2Mu8T00zNGq39Rgy2XDYdafBmJ8yHOw6orSclsLq6qlKplJxmPwNumKP8NStLLRnkaEC9R68TuvzycMgukjLGHEOOABCpeZ4Kg48gYgBjE4JGvG5+J3qyHgW4QJKbiXATUC7e6eDgYKZd4kFGyBE6xRhTgsDaUXzv0Tx/S0TB97FUzaExh8faheVjpOuGbWhoKF2lFy9r2Ot93EmUlJ4VxIZWfU78ajT/zfvnRZQQX8bHx9NBxVkjP+b7lCdvESbmbOGMkDLhovVOy5fLUp5DikEmTz4wMJCUEXOUds400CCOa6lUSp2NWonu6znbvncereUZZQwbaJsraidltpuGiXOOc8LpZB3d+aINa7VaTYaZPUHXLCwsZJoq+XvG3LEjTY1C8dHhlJSMSjSkjlx6yZnnuV13IAfOdWmFCd7oiBEy/CbPH2OQkfFSqZTOM50XSZVUKhX19vZqZWUlEd2knbuW2bM8xKmZ1FhTBpmH/Ol/+Gm97dNv0+1zt6uv1KdHTDxC3339d6t3qTcxq2kjuLi4mGlNRimNP8ji4qKkq8w8J6w4lMPG1mo19fb2anx8PL0fG9uI512tVfUTt/2Efvdjv6tLy5d0dOConnn6mXrGyDMyVP+VlZUkGLHZh6SkGFG6/A2djg7KIJ8rn9MPv+uH9Zd3/qVWNld03ch1etUXvkrjg+MaHR3V+vq6enquXonpuZ2VlZVkeLnZhOgsEkuo9fWSNc/Ttcomz5v7Kx/5So33jycuAYfEyXKsp0eaCDpRhh/i6OxVKpWkqPxi90YjtUKhcHXuf/PDesed79DK1oquH7ler73ltbpm9Jq0NhgCJ2oB++UhEh490TlI2kkReHQxNDSUHL5mnaHMum+s6NqRa/XyB79cfYW+TF7QIUgcORQL6Mvw8HBKGTiHg2dD1iSlPeNigWYN8nvueY9e90+v04fPf1gXKhf0tue+Tc960LN27Q3zjXlb5s5zMZzZ7NCjy3ZMazVqJHyP3/CRN+j173u9Li1f0oPGH6QXPvCFGhkcSY6zpJSO2djYSLrS90DakWdQH9I1XV1daV8w8t6sIrKs93oGdyJ+6r0/pZ/8x5/M/P8Dxh6gDz7/g+n3PA+MsaOnO8houVxOxC4auxSLO5fZ5PX7j+vYzNiubuuV736lfudjv6OLlYs6PnRcz33gc/WCG1+QQXBp4UxrW6Bq0lyTk5M6evSoxsfHE5LBsy0tLSW749EzZDxHlLyJSqOjpQj5PWfeo+98+HfqwSMP1tLykn72Iz+rH/rgD+k1x16jxcXFFBlT7hRLQVxhlsvlVM4Aa1bagUzd8wLO7u/vT/AALNRGobCf+cef0Rs+/Ab95tf8pm4ev1nvO/M+/Zd3/hfVrqvp4T0Pz9Sy+gJ71O4HF7LT4OBgbgTTSaM8vzqvx//m4/Wk656kP/+mP9dY95g+eemTGte4+opXW3kCd3m3JebEmsUmE5ExjoC54AIFR4XV6PPlzf3jFz6u8dq4emo9qZ6RaMXzPU6ucy+U/cA4uOKMCA5MyJh/a2T+C2sL+rLf+jI98don6s++6c800TuhT135lI71HNNocTRBe+RNV1dXE3cCY+0K0T/TIwa8cc9Db21dvZyCfWgWsvZ1/7Pn/pmGS8P62LmPaWBtQJvFnYoHf7nckx6CgDk4OJhgVaI35uI5b37W1dWl0dHRDNze6LovbyzrEUcfoRc88gX6+rd8feb/8tIteRCxn12HWF3mvbrAG1mwV80ah0KhoLd88i16yd+8RL/0tF/SI6ceqV/8wC/qBz7wA/r1h/96IvahzHEOuFuY93BdCKEII8LzUULnnaO8fWwzEZqjIg+Zeoj+/Dl/voNCFHbaDTsiRdquXC4n3U8gRmtYnFLOCYheXu/ndqLjn/nHn9GvfuhX9VvP/C09aPJBev/Z9+t7//J7Vdwo6slDT85A1g
79o1PovEjlz+TkZIYlz/o40xo0C31Eioae/83agZYi5Lc/++3JKyqrrJc/+OV6xm3P0O1Lt6trpStzd6nXHvNAkYaOQeah/HBjkIk+C4VCIiXRB9hrIPcb7733vfraB36tvvrmr9bW1paO9R3TH3zyD3R75XY9euDRyRivra2lKMEjYxYYiAtF4DcjHYQxlq4K3KnRU/rNZ/5mQhlODJzQ4uKi5jSn4eHhlLNkPszJ4T33nHkmfpeoNPawlXbn2JuBl+LcNzc3dazvmBYWFnRl+0o6DCgdYGbgZeQjr6yAF5GDG2o3bBjiZsu1mPubnvWmpByvHb02yaYrbSKdkZERLSwsJFTHc8WeU6tnWL3hQF9fX9rHZg+5rzvnabI4qYsXL+qeuXt25R9R9KxXzM06jwLj53LgjnepVMqQHZtFVW696VbdetOt+/5ehFojY58Xw6M7coBxfdtFuH7+fT+v7/rC79J3PPI7tLGxodc/6fV65z3v1N/O/62eNPik1JsbBwG5d7Z6NMr8LmcXxZ/X/z8PCm5kHfnaXerWNWPXZAywO16sWx5vyB15AhZy3k48q2eQWzXK/3T2n/TMBz5TX33zV2t7e1snBk7o9z/2+/rX2X/VE/ufmPbcOQMY0UKhkGnl6RF8qVTKdH1zEqOnVDHG7RAvW2ZZI7TValXl9bIkqbfaq9WN1UydmdPKUZpsKE3InZ3n7+8Hxm9rAd6LkVojD/+4U4/Tr33413TH3B26YfQGfXLuk/rQpQ/pBx/0g+rf3IHlgCX8M5xU5KxYP0gHybD+09v/VE+78Wl6zh89R7fdfZtODJ/Qd37Bd+rZNzw7HU4g3FKplOAtVzbSTs49RpMIEvvGe3FgHA5rloyR5v6Hz9Ft99ymE0Mn9IIveIG+7tqvy9zmQ7QCJwGonXpGn3fMIxIJeTlGXoTQ7IGP635y+KS+51Hfo2976LclghhKnrtX6W+OM0BOjsOL/MbmBC7XKIx2iEZx3Y8PHde3Puhb9bSppyXl6Ll7HE0nz0XnzeeITPmoJ/8HcSbyyHYYZZSrw9ZEyV4O48/t6EmrY2N7Qx++8GG97EtftkOM6+rW448/Xp+qfEpfffKrU50vsg1/xpsReeTPWXP5hlBId8NI4GzFwPG7d8zdoWt/8Vr1dfXpsSceq1c94VU6PnA8Q+LlxjLQUK+qQSfy+bEkK6/Ot938Mbr9M7Of0Y1jN+pfL/+r3n/x/XrJw1+Syed7WoOmJTgM0k5KA2eaZ0ZO2C9kpKurKyM/7diAtjt1VWtV/dIdv6QHDTxIx4rH9NnqZzOwG98zyEOura1lLgnAE3GSjsNLRBkR7vD60kY282Vf+jKV18t6yP96iErFkrar23rZY16mZ13zLJ09e1YjIyOamJiQpAw7nOdhY3gVCoV0r3K7zOP9xl3zd+lXP/SretGXvEgve9zL9IFzH9CL3vkiFWoFPeOaZ6TaQc/zOBM2CpJDT86KdA+QvGtPT08iPHi9eKPrnpn741+m99/7fv23d/036Uulpx59aqakB2fN0xQcDj80/tWJPX7gneTiuapmnAmf+8sf/3J94PwH9MJ3vFDdhW4990HPzURc5AfHxsZULpcT+YkIbX19XYuLi0nuIUF5JOnphkim4fkbVV7M/YW3vFAvfdxL9f6z79dL/uYl2vriLd0ydIvGxsYSDOrODo4m3zvRrFS62ubTURhkxhVfJEN2Qun6iIS7WPoT71Z358zzyB5AdOIMz6zMaLu2raODRzP7d2TwiO6YvyM5n8g6OXccNCcDwlaHo+DPxz2/lDNSyujr3Mpa33LNLXrTM9+kmydu1vnyeb3m71+jr/zdr9Rtz71Npe1ShsTLNYZ+7WilUkkONOtPusOv6uSMtkJ+yhvo9gf/yoOTbv/RW35UX3/j1+vixYtp3fMQM0c+K5WK5ubmElQNA35hYSFxISSl8klSedFxvs8Msm/2a//ltfrc8uf06uterbXLO6xLjyiZGJ
4frDu8fy6bdzakk0y8uJ7uXn7VYTOG4S2feIt+92O/q9/9+t/Vg6cerA+f+7Be8q6XaLQ4qsePPF6Tk5MpJ0mLNEg6ROpupL1hwUEYYR/VWlVffOKL9donv1a1Wk2PPPZIffzyx/V/P/l/9ewHPDuxLzmQ8bBHNqkLo0c90s71it70YnR0VNPT05qcnEy1vV6n2MzcH3H0Efr45Y/rtz/123rmdc9MOTVpp9kDBqpUKmVykJ67Z3i5ht8eFm8R84ipUYXlc5ekLzz+hfrE5U/ojf/yRn3Lw74l1Spub2+nHNTU1FTqbsbcJaUcFgbBjQIKQlJSwMh9HiO1WZmpVqt6+PTD9bFLH9Nb7nqLnvYlT0vKxR0vnBbmDTEQB0nK9vXFWfJ9oN7d4dRmWb+NDDd47P/W1laG4ES0HA1yfAZHB9pRqj63TH67ePXzMQrIQKVSSU6PO/rIFVEx6bR4BSZOMh2z2s3HkiaoVqt66NRD9cjpR+pBb3iQ3vLJt+jWo7cmwqc3gJqbm1O5XM5UmuA8wLLnqs7R0dHMRUK+N+3IRtTtHzn/Eb34XS/WePe4vmLiKxLHJqIl8Gu2tra0vLycnFGqVDxAgKhZLO5cPQlKlrfvzY6WI+RCoaAff9+P6z0X36NffNQvqnelV2cLZ3dNwo1yjH7xCCPMwgtlRGcjmltw//Do6GjTkdpL3vkSvezxL9M3PeybVK1W9eCJB+tzc5/Tr3/y13XrV9yaFBRzqtV2eg2TM/EuXg4pSgcDyzGODx/XQ6YfImln0x88/WC97fa3pX62KCYiM1ICnibAofCIAGNcKFwt40KhjY6OpoPkbU7ZE4+qGp0783/w1IP19tvfrv7+/iTQTpxjjigkDgTeLfP18ilXWMDG8Z7telB2U3OffrDe+um3ps+u1WqpE9fY2FgmSgDG884+boRj0xsiC9ifeV2eGj3szN2h3QdPP1h/euefanT0KiEN+HNgYCCjqJwbAUzp6SRQAVAZziH7wNpHxdspY+z7gYz4/OAg4Dg6MdONH3Lm+9EoJyVvTA1MqVQo6dLypYzDMLs+q6MDR1N0DOqA4vc5eX4eR4OGFUTF1L+Pj4+nngOdqut1fTbUNaQbRm7QHbN36LHdj03RMNHx5cuXNT8/nwiYkhIqxZwnJyeTHuFcegvfTjhq6PZv/oJvTs7E3fN361c/9qv6mq/6Go2Pjyc5pRoCtBbHuFKpaHV1VXNzc4nPBPcC+XBUFLjayXjtPEPTBpkPe+m7X6p3nHmH3vTlb9JYdUwzazO7BCFCz/4zIh2Yvx4Re91eoVDQ0NBQgj3ozcyVX65gG1mIlc0VFQvZVnrdXd2qqZaEh7nivRIFxhIJDJ0/20GOx596vG6fvT39u1Ao6M75O3Xt6LUZFmChcLWMgrVhbk7EIPpx8pCkdPgxtlzUMDU1ldY9XgDSiGHzubP2d87fqdOjpxP0w0Au3CNljg7BeykU7E2UAK889mmjpSD11l2S7pi9Q9eOXptQHWTajbL3tUaWqCyI6x9JLygpz4dGR6LZuSMfdy3cpdMjp9P1lKBBQNHA6A7tEUXgxOEQ8b7k5ojsiOa8PrZTbFofvI/vAyRFrz3PY1pH2JrnbTc67in16ItOfJH+5q6/0TMf+MyrcyxIf3/v3+vbH/LtqfQOmXBHwZ1jd5AJTnA0iTTduCEjnnZoZ7BOi6uLurt8t548/eQEU8/NzWl2dlaXL19O9xUsLy+rVqslvUMLThwH5gpa1UpHsb0Gup1nh5yGbvegA9RzaWkpVUcgC6QSPBXluofOc0TJjiy2Ex1LLUbI//Wd/1V/ePsf6o1f+UYNlgY1tzSnha0FbWqz7mSicSayhNDl0TEKSFLqfQuDzevuWoHBvubmr9FP/v1P6vToaT1k+iH68LkP65c/9Mv6Tw/5T5k61fX1q3fqEqk4u9BzsfdFZMx44S0v1ON+83F67d+/Vs956HP0gXMf0K9/5Nf1v7/6f2cO4vb29i
5BZ929/jLe2oSDwVciNb8UIR7+RqPMOPf33/t+vfGjb9SvPP1X0uH1yMaNKlG952L5XRwOj5TzOhc5m7PZKC1v3X/tI7+mNzzjDbvylzFSh63upC5qo8ndEmHCvOYZcTS83KtZWC/O/X1n36ff+Jff0C8/9ZczF7dLO5dFcL6ogMApgoWP/PO3Xh3h+5DXUKeZuVc2Krpz7s7078/Nf04fvfhRTfRP6PTo6fRz3o992N7e3kXsY8+RbWl3689O8j9edMuL9Py3P19fdOKL9Ojjj9bPve/ntLK5ouc97Hnqru3UqfrauKMQ8+PR4Yy52FbXOG+8+K9frGfc9AydGjmlM/Nn9Kr3vErFQlFPnHqiVmeyTaD8mt3V1dX0+UNDQ0l/5N0y56WXnXLSom7/yPmP6Bc/+It63sOelxyaiExE7hLnkwoTT++BGmHAPc3k3B3GfQZZ//pHf12S9Jy/eE7m5986/K2a0tSu34+QbjTM5E84UNLOQefweI4wNrlvRgB/6dZf0iv+7hX6vr/4Pl1evqwTQyf0nY/8Tv23L/5v2lzbTO8dO93wHJEBe1+OR598tN723Lfp5X/zcr36tlfr+vHr9fNP+3k97+HPy7CNXYHnIRYe9UAgioQp986jYWsl55M39//5lP+pb/mCb0l54lhek8eW5rPYA1egeQYsft8K4WW/dXfn0p/DWb4YWWe0e0Tm7+HRUT1CV6Oj7ro//FsSB4Jay9h2MdZ6Iy8OU/v5dYKdr32rSvdD5z+kJ/2fJ6V/v+ivXyRJev4jnq/fetZvZX7XES8+O34fP98rRVqp2NhrPPdhz9WVlSv6iXf/hC5WLuoRRx+htz/77To2dCw1//C5SVnejeuWlIMOsuXISaeMsSTdW75X/+mt/0mzq7Oa6p/SY449Rm95ylvUXenW4sZiBuqPddzAuchH3lnOk4lOICa7dPvwCX3XF36XfviWH1Z1s5pZt1gW5/nkvIDFEURJmfRGnqwcaITMB5bLZVWrVV38zxdTrWqlUlG5XNbMzIzuvfdefXb9sw3VS7pR5iC5182hBz7wyM6bPRB5OIHJP8/nznj1416tVz/u1Zl80/rKeiZx7/2U+Xz3llA+XsYFHEwNdqFQaDjPyvz2m/sTjj1B//gt/5j526WlpUwEz3P4rUGeN/bncZo+e+CMcieyeYTK/1Ne1ezc2X9ulPLmCN5E30u23JGIZDRe7mh4Q/lY8kK5Qytzj3vGPPCuvW6+3po7FOYpHX8mnoU8FzA+vABIkq2sO01LXG59zSMBMG/NUbhu2DhPce5+9Z2kBP/tNfdHTTxKiy9cVN7w5/O/5wzwWaArkcjolQWugDm/3MGcd37jWc1bc0n6tgd9m77tQd+WgT7pR+3tad05Y53RgbG6w/WfE2GlHZb2XnnkRub+a0/9tcw5WVlZ0fz8vC4sXdgl395/gbJUlxv2388iJDZKSxt11BqZu+t2nmFteefuY9eLkczHWjt07Wkn51bwbF4Gxr4iezTQ6enpydUzuaPWwDh79mxN0ufN6+zZs4dzP5z74dw/D17/Eeb++Tbvw7nf/3OvNwq1fU32Va/9/PnzGh4e7jhDspOj9m+e/4kTJzLw6+HcD3Yczv3+GYdzv39GnPvny7ylw7nfXyNP3vNGQwb5cByOw3E4DsfhOBwHOxrKIX++eCL/kbxu6XDu98U4nPv9M/4jzf3zZd7S4dzvr9FohHyYQ/539jqc++HcD+f++fH6j5DLPJz7/TP3eqOhCHl4eFiSUq9nH4ml/G9MTa7gotH4/Py8rly5kgrIYRlGKnmpdPUCa3oB8xofH9f09LSOHTumI0eOaGxsLNXdxVKKcrmsU6dOpfnuN/eaMfH8EgtuMJmdndWFCxd09uxZXbhwQbOzs5k6zDy030tqoNhTM007wYmJCU1NTaVmGwMDA9rY2NCXf/mX7zl3PhPWJqzbmZkZnT9/Xvfcc48uXLigK1eupE469F+ltGhkZERTU1
M6cuRI6p4zPT2to0eP6ujRo6kLV39/f8P9h1tZd2qinZleLpe1sLCgy5cv68yZM7r77rt14cIFLS0tqaurS1NTU7r22mt1ww036JprrtHk5GTq4xvr0hstSWt07qw9DTKWlpbSPD/3uc/p3LlzunLlSpJ9OnNtbu7clcw8YykUMug3x8ACrf1bo4WxsTEdPXpUJ06cSHtUq9X06le/uqF1dzmHNXv58mWdP39ely5d0szMjGZnZ3Xp0iVdvnxZV65cUblcTi0FkW3OKX3Np6enNT09rYmJCR05ckTHjh3TyZMndfz48STb3o+5VZnJkx9nAc/Nzenee+/VmTNndO7cOc3Pz6fb4bzlpF9Af+zYsV1tYBuRmzh3nzctYJ3tz0UMc3NzunDhgs6fP6+LFy8mHUknt3K5nLlopPZvzN7e3l4NDQ1penpap06d0o033qhTp05peno608sa2dqPZZ039zNnzqQOYrC5YVO7LkReuGKRuS8tLWVuOfJ+5gMDA0k+Tp48mdadjovoQVohxxvEGln3KOveGpXrIWdnZ5NNQmeeP39eFy5cSPLCmnPX9NDQkKampnTy5EmdPn1aJ06cSF0iWXfqq5tZ93qjIYPMB9CuzR/clZSkTCtMb4PpNYlea1yr7bQLzKsZ9KYLtBFE8Pz36xVk15u7bxwHG0Hwshi/T5VnKpVKmVrRvPXymkiv12TO/jz9/f2ZPt715h47VtHkAAXv9XVeB4uyp7mK/x5GN86HlqTNXMbQzLrzDLT5ZL1j+1Rpp77Ym27QHIG2mLFrW7M14nvN3R0h1rFardbtje3KiJ/51YaudHjf2r+VU/iLkj7vKc5tZNwB3ui6u5wXCoV0nVze1XfItT8P++Byk3d7T+y1PDg4mGuQm5UZlx13YJCfeIkFe1D8t1aeKNfYVpXbkmhf2YzcMF++0irUyzXX19cz5zRej+i1sXR/okEMDr/XHNc745yHRjsXxrljXNDlQML0pPZ1jfoanU65FvLjNfQYOf8b5o5s04+bxhv15p839+g4M3e/nAZnxW98QqfQehhngo553uXP9azP3e9UaHTd6422b3vi4b0vKHW4RAl59+p63ohBNyC8Mw49tWMrKyvp4EhK3bwkNa2AfbiR8zo/r7/Nu54tKpMYMdPI3JVbV1dXes9Yn7rfHOM8qZeONbusKYqcmlW8vu7u7kyECirAC4cJpcC/O9XyMCITXl9L9x8uMiDScKOwV8OHg8olufPmtYpeH+o9kDFerF28fQilyxqjQOLfFwqFpBQwdrwa7bfMevszxPpJlyNvBci6uyKt12mplcYlzYxojGOdvDeooMUnshG7t8XGQp0cLid5DTS8IYx0dW3pmtbT05Op/eb8ITM4r9S88jx0ofLeD808V1xbR668B3u89519p6Oh95/wgITALe6Zvxfy3Ozc4zP4+jtK4fXfXIFa/LfOkARFvt60/uTZWJfl5eWMMW9n3ePoiEFmsdfX11WpVLSwsKDZ2dl0A4jfayspPbgPjLF3/uH38R6BUfn9arWq3t5eSfkGsdG5O+weDQNzp4ifg85zMF+HsGl2wiZ5EXmtttPMf2VlJTX038sg+2FBuJgTbesqlUoSNCBGIpPBwcF0QBA2SelSbaI24EVpp+EJz+QeazsCl7fmyAyw6czMTILaeR7m5aiJR0HRo273YPj75Blid9zceSOdQbQmKbU99J7ONH7AAHpTAknpd5AVT+MAUyKHjax3fA6gXi6XX1xczPTYpom+R0IoKb+0IUalrXRCa3YvWDdPL/nlHVwAQ/tXX3+HGA+it7Y3R3EZYV5+uQvnlMhtYGAg4zAhE5ISClOtZnsw+5n23tDtrC1BFSkk0l+Li4tJz8T5l0qlTMMb3o/z6o7E0tJSpje3X9bTqJOZN38PDh2hwLHg7ubFxcV0/3exWExBi3fbc+cN5GJjY0NLS0vp80hL4mijK9vRPR0xyN4hanFxUTMzMynnury8nLkpQ1
KC6zj0HimgEBBkoFmUAb/v/UPx0po1yNLuLktLS0tJALnVBEXlkb5HjfH9HG5ilEqlZFxKpVK6DQfPjLWpt8Z5TgPKFCHDgJG3BL6Kzes5MDFnwvoizI4ENApb7zdifgfewZUrV1L+240yjhyOFwrI22zWU6ruTLQ613o8g9iVB4Xrt90AfTk859ExZ8KjB5QCEZ3fuDU1NZVyto0aZF9zN8Z02VtYWEgXBczPz2d6bDv8ztwxxH6Lkz9Xpw1c3AtHVUDiOAN+DlZWVlKqAKULjO792PPayzKafYY8WfEoE+THHWdPZyDL/ox+nnt6erS9vZ26QLGnRKjsRyvRWtTj5I3n5uYyl0ngKNNWslgsqq+vL31O7GDF/JB3+qNzNuj37vqmFT3OM0RjjDOEXbpy5YoWFxdTcFUsFlMqw5EUvnqqknOB4y3tBIuNQtb7jbYMsnslbCRtNCGK+OHG+5B2Ii5vy+ehP0KGAfMbiBie0/KfNzN3jzydlIYgQlpwmCMvPyztRMYRgnTol8b3w8PDWl5e1tDQULpTea+5xogSJYQXiyJCuDHIXDvInZ3Aw268yV+6F46nirfYibtsY3TsF39fvnxZFy5c0KVLl5JDxGXgjn54PjleHMB7S9k8UzTMzUJ5eRAkkXFsJUj7QDxrz9k7ERHDzfngfVG83d3dmZt9IAHyGhkZ2dOJ8/XOewYMmsv74uJiMgLk0jiHICjki7lnGtmK9zXnoRbtjMj3wGhwDjyCA14tlUrpfLlBHhoayhBDmWM0Ys3KTdSH3iqSeWGQccYwaMyLO5BxvHkWvxBjdXVVtVotBQfchMc9w3kBQaPry7y51YmbnJys6I4ncDvGq1qtZs4HiBEyThtbgqzR0dEM8tiOMa4n5xjk+fl5zc7OqlwuZ8iWpPGQY3eQHP3lvJdKpdQ2eGhoKDlXrax7HC0bZA/PPe8Aow2DBiyKQSU6xggRATtkQaQAUay/v1+Li4saGRlJi+fJ+HY8K4eIvD83ESi5zHgJgxMyMMieX47eMsLe1dWV26d5L6gmvheCTaTsSgioDgUJcYU7hyuVSuZQ4yWybzg9EcpuV9jyoF8UVrlcTocFwxDhatY7NqmPTGWHyqQs6z0aiGaiB94/wsvsR4QX3WjFCzk8OkZ2XGYw6BBduGqPO6lheO7HO4jzdzjPUx/IOfLjDjTkFuZC6og8NgS7epe+dHLECDmml3BIgUBZbyfoAFm7LuJ965HO3OluxChHrkfMc/u5Z6/Z54GBAUlKtxEh585R4Fzw90NDQ7uizFbWNjo8e+lB1gSnn7kyP0/jgQKxb93d3bv6eEe92ezc+er2KE9XEmCR7sSGOHnLHYzt7e3MVanLy8uSrtoNbkTzC1eaRSbi6FiE7ALoJKyNjY0MkcKjXDYLYw6UJyltkufpeF9yWHEh2pl/NBJEPAigpEzE4JsXSQswWV0g/TOl7EFvBG7PM8p+2J3IwoHBoKJIeY61tbV0rZ57xg6TeZ4rwvDNClwelMQac+jL5XLKYyLk1Wo1w/52eNRLVJg/w9c2rrMjG80MVxZRccTP8DuRYY668geq5CIAnwsGJJYA8l5+7d76+npD8/b5OzkQefVm+p6iIFoghw0TF6MWSWZ5pWedJAFGpAhnAg5FRHwcbieyd+eI9Bf74Z/FvuYRCJuZK+8fL7ZgoE+4i31oaEiSklGL5w6DiePd29ubDFsMCFpda+TD0UtHLZFj4FqQEhBA7kVGV8X3jxeVuCFrd+TZJI9wcTylndSpV24gJ6TvOB+SkqNCWon3cjvE/7VqlFsyyHlC5xvohgnj4KUbPKy0u47ZIdN4ADE+edFlu0LoxsLZmhxuFA1eIeUTsL63t7cz+cV46wq5FAxLfO2XD4yC5mvtAs5auFLFgeGQQCZDEaO4JGW8y04aY49siIqdTASs55A7DkU0CERiCD4IRCShxRINog32tNkSF/aS9/TSCZCaUqmUgZmHh4czd2ojy7Vabd
d1dNvb26k0zeHq8fHxzF3UkUTSyB7499GhiAbIr/gjh+3zYA5EFbzcKEeyXacMs0dwzqWAaFSr1RJygsPAPCNK5/IS9Rnr4WsB0W6vdc9DYCJC45F7LFui6xQIYrztDKcbfeGpkhictGoU3MEEmYL41NfXl97Xy34gPm1ubqa7zV2HSsq9erGT6Y08hzmicug89rGe0+b6kK/IHbrcGefsSbtRcluQNcKL0YyGQcp6/AiceyA87Pr6usrlcoI9ULQYaPcMMcrRILf7HE4mwwPFMHCIocJT+4byoeGCG5a1tTX19vYmD6pUKmlwcDCxZHkNDw83xZj1Tc8jkSHcrky4Tg5ICWgyEsD4HTf8nfC4nYhD3ntubi4R5xYXF1NuCmOFQh0bG9Pk5GQqyCcvCKTkjhRRntfy+jqglFmjRp7Lo19/L4zSyMhIugq0VCqlJh6Qr8gLsgYgQV6b7DXH5AMhck1OTmp8fFwjIyOJjORwfaP7EJ/VDYQ7GDwbZ5ZGGjSicNjXX66cQcI6GSVHMmDkTzhrFiiXqDNG7zhHHqW5DpOy7Hhfr/1kxp236Ay6PCIDsX4bWUOuvbwOR4Jgx6+IbQf29XlzbnAUarVaIov5OWDPkQVJ6VpI9CFOEmcS2XJo+KCct7xoGZ3mZZTuWLuz4HrQU5qkHt0ORduHnDT7LB2LkOtBeb4J5EnIqzmJaHV1NeVPMGAUqAPdeO1nrAluVgj38qJcufsB5zlQlnQNI8mPgujp6Uk5F4e9YfQRbXg+sFHGrOez/N+ReBIjQzdgQ0NDiSVInt4boESjHyOpZoYbZEh/GOLZ2dnE3PS6YxyfkZGR3I4+GKVC4SojHIIPSpnDH40wKAFr00iqwKOFaNiJakZHR5Mz19PTo/HxcR07diwZ5K6urkw5CcY5MjpxjEAEeG5kDRatQ8LN7kV03LzMw5UmMPTo6KgmJyd17NgxHTt2TBMTExlSVF6TCv4vwvGtjLwzCprmBhnWr59TEAaQFS858zQN+oazj+wTZXN2OU/7DZcXuBhxr3F+WTNvyINcrq+vJzQEaFraidZYC56jneAkyjnyTcMekAfm4ueKiBIGeKlUSjloqiRwsh1R8WYbnoLq9Ii6PiJCXrIXq1Fi7TQGuaenJ1NH7QbZP+/ADfJe8Bf/55Be9IqcWOFMNgrjYTrTYSp62O61dAKqjofdF5iD6NEQEQOKEqgGz5A5sg7d3d3pvUqlkoaHhxP051HyfgSdCHvlCRI/w+jkeZvu/eJM4DHuVz7U6hp72sEZj7woK0OpeA6V1p5EyORkeU48caBv8uKuPByyd3LYfsrVnRyPDGKXHvYPIs7ExER6DQ0NJYdNUuoI506BM8VRzBhEUCVg12bZ7tFxjohKPKfFYjGdUQiBY2NjmXav7IFH6t7xbb9ytFbkJ/JUnH+wtLSUUi8YEkexMMbkjqUdVM4Jlu7sS0pRLO/rJUX15CWez9gVyjt08d55chDPtee4MRTFYnFX6i4vOGp0/flcj463trYS2oEO9OoBh/I3NzcTIjg/P58MN06yoy97Rcitjpgu2O8V/8ZtgaeXYrrUkQuvF3fY2smazTxTW5B1PW871ojW8+adJOEHOOYYPO8Qv7ayiT7veNidNYtDQC4NBeWMV6L6rq6uRGBD8NgYIk8vT/DyC2+dudfIi9Tc66ZMDIMsKaPEMAz8Du+ZV6qSlwtrZbiQx9wfBC4nRrCWPT09CVEBTfAe25J2RUvz8/NaWVlJBhnD5hEPjl6jzlzM/eUx0IeHh7W1tZWcTGp0iQKcQ+DOH46apAwRhM9yuc/bo/3Wfa88misQPh9ZdXlyJexOgivTPAcjGpJWZccdOiczIkcQPbe3t5OjRg7eHRk3xqSniOJily/WhNw/XbAaSd9EeYl5Uze6PKNH56CCrvA5G87tAEmKZLFolBsd0enEICM7pJC8Ta1zOSSlNF8eQuLr4TIWWfntIiruEHmQ4kiF2w7XTaSSgJthVc
eGLrHSwlMd7eaRWzbIeR6HKyovlwEqhdAjKUWjKFeiJ8+H8DkuIDHCRiCaPfweLbgxdo8Irw6iDcrIc1JAfXlQBx4whhPYnrwWCpv32G+tnYDQ39+fgaucREfEu729nWBicsWuzLyu2oU0Mkvbze3UavnNEmBa8xzSTk4nRjnkT8m1chAo1yInvbKyknJzkL04GOS23BA2MvKiHu80RL9r5NmJi96EJTYSiYfciYyRSY9SbianH41ZXvlNZNJG56OeIxCdQ5eVTijXmPsjJRDbvEJqkpScYT+r3vpQyhpjnEKXQUfGqtVqivCaWXc3CB48cL78PHidLHnJQqGQ5ucXlbh+JI/pXI/4ivux35w9xUgVC+sQDTJrzTM5LO3PGREFz1HXQwFalZtojD1tBfJE1M9cILniBOFEcAZXVla0sLCQLkXCLrgRjuRa9K872Y2Opg0yCx2VlJOHgLsQOgzB0tJSgjUwUhjkjY0NVSqVTEMIDgTGGLjX80LOmmwGwotQWMwVQKCQlIQxOgMc8rySqeXl5TR/ImNnznoejhzjXmvuuT7gXN9snkFSKofZ3NzU4uKi1tfXUx4okhMkpfcEespjQ7ZjjJ345wX2sbSM54xwsDswHCLeywv+L1++nCDrvr6+RBBD0fT19bWkXHG4HPJmjsPDwxmSB0gI+4Ax4bYfboTy+lkMM7KGki6Xyynq89QPMr/X/Hm+yG4H5vVWqxHyjAPlxNnACcbYRH3QCWMcnThkJq/JBr0OpJ3OSe44Ex2jLB1RQd9gjNlrzoKzmhvJz9aDrD1HynN544zt7avlQiAqXiPO8/LV24IyH3devEbZeRJ77Yk7WB4dEyx4UBJv24vvH4McbICjQXlOf6sjL1XgCCLyMDIyouXl5eQol0pXWxY7IdAdCc4NDj+OGzooGuO8XLLUXMqvpQiZxYt5NTYMWBfCTbVaTWG/R3r8raSkuBA8FggFDcGHq97Iwca8WiMjGgmHwzyCYG7An/FWIQyct7OklIeN53B77hiHwue+H6GBtSKPheH0XCiePYd2a2tLCwsLmWd1SJ6DNjAwoLGxsaQMIuuwnSi5nnJFwXrvbeTC8/UgCh51ggYsLy+nbkKXLl3SlStXUuE+HYtcPr1WsxlYz+UdJchBx7j7+yNfy8vLSfEuLi6mlqBzc3NJXlgHHED2z40xRsJL2PYzDnmpAhjufoGHO0XIgxtnbyJSqVRS7hPl6tF0u3Bj3twjO5+5k+4ganTZ9eiYMybtoHKxdzqBAg4XugyF3aiCzdOLMacu7UTpQO04CH5xDiRAR1GcbMo8PLhwGJUKlWbOrRtkZCGui/cBcAPsaIyXvfLCiDHyHP52HTl3KnAmCBC5dhNoGhIgsu16yvUVsDXnFefbjbHrVC/7agW2btsgcxg9asA75ZB7ntDr0hyqQOjx/ohyCoWd+0Ahl0xMTGRqMvHYm8mrRaMcvR1IEyhCvEaPjt3TxZHwukgia88dU77S7JVdCFoeGUfKMi95JjxslC6H2SNH2Kg9PVfb2KFg9yLmtGKUo3JH0TiJxlEAJ2MRFRaLxaSoUGKwtfkK/LS9vZ0cKW8i0woZ0OXdnwnnSFKKkB1WxdjSHx0D4BdnuJNQKpV2kd+Qbc5AM81w8iJkj7S8zzyks56enoz37w6n74FHO1F+O5U/RqZjdI8zgTF259n1ULzIA9lhP2ZnZ3XlypWMouVv/TkjEW6vESF/N8oeUToi5867Q9rO/mY/XHfyefy+79nW1lY6v41yJTxAIpCIqUPXCVGHRg6OG2QQCk+PdZJvEM8ojqOXJo6PjydHhd4HOKuesmDfWXcndOGcxEAjr/qHlGUzo61OXVH4vKvSwMBAUkqSUlkQhdVEeXETPPrD8BE1OYQJXO3QSTMRsgsTP2M4wcXJO9TbsSEcKu9M5kw8FJWTwtxz945B+xlkj7Y9TyIpIRGxJSBGOd6chMPDupZKpV2RaitEor3WOxKLXIDdk6z3mQ
5TOxpB2YvfVoQzgYJotRQk7oG0Q75yZdvT05PWFYhzbW0t1cc6rM4l6XjoKFl/VldwkXDUqDFm3VEabgBYe37OOUBm4u+vrq5m6nfdoXM5RKZQ6vx/syMqeiD8eJFHdDI9auflMC7nxG8xgnewvb2dUg2kd3zvm9Evrhcjy9qNGsY0wuKS0vd5ULkbMj4rRnU4GESLjeoYDIgHAHl7zT65QXLDlZeT9zXB6Y4oXDvD5S0vjToyMpIJCKvVaoqOOWvuACFTzmlyR4v3iGeLKDkGgI2MtkhdMXfkeRNfaITOm2543oGN8ByAb7xvYGyh2Iox9q/8nednBwcHU1QCYzY2FkD5EnX4lWR4RnjcTijwKLsZQYzChtLyQ+4etkc3TuJC4Uo7V/w58zoqonYPia+3G+b4ishFjIw4HChUN3SkCDBybjzJIzkjuN28FX/nkb/nWL35CR3JcBy8q5TnzSEoxtpMd1DaJb7EsxYdIDfgENQiQYfndDlC3rw1ZaPOZr0RlX0kc0XGa1SM5Loxwuig1dXVZIypf4doyhqgv7zlrz9Lo86zG59YOgkzn99zBMfzv54n9s9w1rCfeXd0iXAxHo3Ih6Q9P5u98fX2yol4PSM2gvV0gpXLeidyyXnr7yVc8D34PZ7BiZ8EUuSY2Qv0Fogt8462zWuzfR8bHR2JkPf7WcwJReiFBXQ4lveJ5VBRqXYC6mDTqDHGK+/t7U3NKJyEhcLCY+dFHpAcDNC9t+/zcoFWhTBvjfOgEzy7PIMn7eR++P8I5cbD2G6UGefgn8u/i8ViUr7z8/MJbiYKJRfqV++BvPBcfgghGIJOdKLXshunmJuFLETzE7+v2pmaPKujSq6w4Rp4y9BGUxyMGKn5Z/lnQnZj/Xiu6Nh5k35n/K6vryc55yUp5fybXWeXEc+lOpEr8gFQiuTfu7q6MlC259LL5XImdbAfmsW6e9603nrzNRpkHP2xsbFUCdDX15eQK9I3nFfWgeF5WiJOd3xYAxwSL4f0c90IdI2Bj3og7gsBCXl9bt0ql8uZJj8emNCf3a/BBNXrdC7Z+SPA1u5weQTNPuTVFnuuWNrppkd3MtbB15x2yM1Ex1KHLpfI+1lUtNEY+ML5ArqRdaMbIZpOMfPwojgsx48fTwqqp+dq1yVvGSjtEC6IfDwfVyqVUjMI6mjzypzaEcBo1Jw5TQThV5o5wYRnr9V2SrU8OvJ9y8ufOYTVieFRGfkajDHs89nZ2fQ9SjdeOAL05AQ8ukx5tyu/maYVGYqGgntjUfB+nSHlEs68p9RM2iELovydDepdpsh/0ZCjUc6EyzfGHqWNIh0aGtLy8nJSsBhhHIfu7m5VKpVd5YzDw8PJMFcqlcyciSRaTXe4bLuhxRGL/ZtRmt5UiJ4AREIx0oYYhtxA/qI7mVdzNCMzrldcP1arVY2NjSVUBCMQa1xjxObr4dCwvw/QMvtG9IYD1kyUHJ+jHsLlzo1fV8u92lSZgDTGnuicR2+l3MkoWcqS1AYHB1Wr7bTkBaUYHR3NOMoxoOFcsPbINoZcUmJiO9qFrm+Wr9KR6xc9unJIMq+MAuPKwrnH74Y45krzou52h5OIBgcHNTExIekqQxchdo+uu7s7RZ1AXwgfm4Xy8xabRDhAerE4vREBzIsuY67Pr9LzO0s99wwcw/O5cnfD6PCZG2ZHMpo5OO545O2rOwJemlCpVBK5z8umIgkDBQYpByPmLTedCNhO3spLIjDIly5dSkaZchqgdK8d5UCXSjtXLNJkhr7mfuUixpr/a7SqIDqcKEX2dnV1NRlRZ/s6McVTUiBTRDoYchwNIj9gungrV7NOnEPWnnbxa/uQTY/WFhYWVCgUtLy8nCn187ym13ejpJmzk0eRmWYNsqeWfC9cYY+MjCSn0lN5eblwjK2ncHAkvATOHQ72nDJTzm6rwx1n1hB9Q8khRnl+fj4hVug6d3Qg5vqFKZ7i6NRw/S7t3N1MPhmZBXFxY+zfe/0+ep
73KxQKqXZcUqpaISXRLH+lrV7WHgHnsZQdLkU5AKH4IXecHcHxn+V9nhNRpMaNQ8yhYaz8th5qqJ2l29/fn4wCBpkog+vGvGTHbygiOs4jMbQSGcfoGG/Vmb3kpPBQi8ViZo88H+LddiIxxI2e11I6y7PR9fa8jvMMgAJjRE5hPt49SsnJWpIyUaCz/Cm/4+X39rbqjfv6Yygoo0EZxeiYNWTdnEnOPFFQyIxHZ14D3+ilDS7fGFL2HUjaDTLnjnxqhE85v0RgQJJe5oJDikJyklKraxxJQ05uc4XnhD/SHsi8k40iWxmI0etsvSOf1zI3y/nwQMODDWfMOzM/GmTXf5yFpaWlBHWT5wT54v0wxjHibmf4frhj7Iz1mZmZxIKnF4I3RHL0BwezHmTd7uA93ClynpDLVIyKY8kWwY4jjzFdgo3r6+tL+rcV+W+pl3XcnLwHceHyEN+JBs48lJQUANR0POtYguF1dr7ojQ73WJ18VavtXLUIpOuEIJwL5uIerudKvMMUpDDP/0X4vZnhTklkLMd+q9JOvZ/DXqyxtNPc3mGvWI/tzSB4T0kNe9zspecwBwcHM1GuMxodftvcvNpUANlwyE5SBpHwCx88B+htLDtBIHHHqB5zPBoPSRlZA6r29qCUxMXuZH73diQY7TWiEwR8ubm5mXFY1tfXk/xioHAk2A/nHVBXirPkcjE0NJT2rZUcWr31js5azHFGWBqH08+pyw2GESXtxEtvKpJHwGxk3YlsIzLkzqPP12uMY4TMsxH142R5Mxlpx2GluYwbjkZzyPXWP0bHjlr4XeZLS0vJWUPecSg5j762nTqT9fZB2olkWXtQFc8N+/fRljnLn3RJXulmoVDIBEOtrHtLEbJ7rihs8np5LMhSqZTJdWCg80gJbDKCGXOkvFAC5IkcSm1k+N94fS9GmH97zV2et84hwvA5cQFIxg91o2xNH76p0Qj4Hjic5wcCgyvtRMBECUTQwFpEGR7FoQT4bHdS9lv3GKUBCeJBx0OIAkLJel4Tg8zPnF8Qy+LitXuRSNcOfOfKNTJpI/EwpmfcGAONkuf2qDj2w45r3miU5k4n6JSzY9kHkIUYCTiE7Q425TQuL729vZk2lCikdtfZnyE6tHF4sOBwtdeR4hw6gSvKTDvpJZ+/lJVxPw9bW1tpX2OZkH8WBpkyUNafqoJIxkOHDQ0NtV32F3WOM4qBz71hy8rKSkLmXB/C34gy3Ymqh71G3AP0FfKEPosIr0PXnpPnBfqC8wHCFJGOZp2gpgyye6qReYmn5IX7eKpuFDi4REoYZPKF8/PzKhQKWlpaSgfKSR1LS0spJ8CCN3pQ8uDq+G+o7tEbx/uMLz9A3gDEc4B5kFcrwpeHTHhk7BE7+SXPPUI0AU4hSpZ2ol68bgSM+bLneLu+bnsNh2ghdoAo+NVtzorkMPizOpTOvqBYOfQxPxUZsp1i5cdUB46A1/YSHaPMHBIlSiC/feTIEU1NTSWlFVu0IjsxkmgUskb5xIiSKDISjJAhlIsjJl4iJykppWLx6jWl3n2tFYXkc3eYEefQUw55rOf4vMyV55V2rhqNd23jEOU5cc060f4sDH8fnyd7i2zHVBxnzxE2UAz2BD3A8/b09GTg1VaMch6RyxvMoOshyJXL5cRs99SFk+T8YpJ20nfNDF9LninuQ7VaTQZ6a2srlczhCMUackkJWUIPS8qFtJsZTRtkjxS9/ozmDA5dcPOO377iNb3kxAqFq7Ww8/Pz6u7uzvST9kYQ5XJZg4OD6SCymB7FNjL4W98UNwjAozEi9lpIZ3mymTwnkQ9krk5c2L6fMXZlurGxkQ7tyMiIpqamNDo6mmj6HK4IzzhhyuE9/sajHs/z7xche96MUgjvi+vs72KxmAwCcKnLnhPKYkkDrVUxbs5KbidNEAdy4p/ttfWsL8+KgoyXZYyPj+vIkSM6evSopqamMhC1k6IaiQz3Wn/mG2UdZGF0dDSz7/HlXbLcWD
vUWigUMiVJyAoy1CypyyFGUAW/+s97fHtKBLn31JijIcgNxhh5mZqa0uTkZHLkvD623QjO/86DANdfyFOUcWkHkXR56O3tzQRCnGMQEAy25zKbdY6ivneYGkOMvodL4xUEXJ+KPoxOckSrmpWRdkdeUMZ+YJDz+gGQQltaWpKkpDtrtdouct6BQ9b1NshrQ704HGVARICiJBLASJD7WF1d1czMjIrF4i54jc9wCG+/O0rzNoGvCIDDn9VqNUXsLDZROp6hw8LOuHOv0CNTRwfaMQj18pYoRY9icAAGBwdTBOCXM2CEPQ/E+0CEkXbyv55HdBbtfmvv0eTAwEAGJudggoR47oxBFJf3cqjaGZzkYjvN3syL3ICfUYasHQgREQ8tSokUYOBjCGgH67eYeVTvyqBZRCgPCXLkYnBwMANNO1zHuXNnm3MI2sL5cLjOo+hWImR3kFlnvwKSnK8T0ryxChGyG+WIbLAHnA/6DdSDU9tFVuLz5UVpvlb8jT8H83Ey5Pr6zsU96KX+/v7k1LaSOsjjCkXnjJffROXIlzOrY/ouknbRw/7vThlnN4ruIMYXn7dXua101UFaXV1NiCNnnhROXg650dFShJyXQ/YkN5NaX1/f1Y+a3KoX3ddqtRThsWFskLN93WuPsEAzDx4Ph+ePY/6JPKZD8152IWXJOrErV6c60fja+yHJIyIAs3hpDfkxDj7rSYkRUTUKmRy05/k9L9ToQfcIjQiAKIW5kJJgTaNn7oeB3CXP5r1qPW9frwlIpwyzGzSifSA9N6o4HN70w+/UhlUdy2uiR+5zbzbSlLTrffy88BzO23BHmPPoFRU428DC0s6VmO2SiNzp8XM1MDCgtbW1tIY048HgegRNhINcu1OD4cZgeDWEs347KTf1nJO4P/Hn7JU7lf5eGxsbu0hdnClH8JrZi2jAcOA9b+w9+z1F4aiGE+Q8MnZD7GTBeo53O2vvz7AXMdCHrz26i/eAvY5eQf5xyCM65HyYRkZbLGs3CLGYGg/Zo09nVnsUgMBFocPzbXZBGx15G+0QrBPNIC+QH/duS3irnuvy5+uEMfbvXTF6fjWukUcEGA4MBLBWsXi1iT7EHEnpcOERe7TgEZU7Q3utMWuK4uDntVotOQXklv15cOrI65NO8Bygd+K6L8opmLsb5TyZJnpx4xK7hzkDPMLUkQjkX1uZb973kW2KYcVhihGmK9L4Nb78/1udcyRf4Sh6tYWzjf1CklKplBxUnGJ3Mj2fHxG7g2L9clZwWGK05s/uMsbP+FucJ3dGJGWIj+0a4+iYgRDSDIm2sBDLJGXWls5k0UnCsfMywGh8Y3TK8zc73E65bYpBnEfk/tkeETPXGDWjS706xEtL3V41MlquQ3bh8prjvES2ewlevoKCRjFHdmZ8n1ahu/2Gv4cbYzxObwtHP2IIOxhi768d836dmmc8LGy2E1f2clZcmDjIwC9RETlPwJ+lp6cn9YRtBpLBKKPokR0ITpTf8FysvQu/k2Acqs6LNqNy7aQx9u/rHVKewyF3or14tzbKKi+v1kmDIO2cXRxNulXh4ftz8Hue3vD8cUwzOCzcztwjAkFDEyl7bR/5bwwyBhxZ8/QNTFr+LpY4OYmu1WqIvdY86r96yJ47cM7QjwbDX9JOOSREzfgZrczTjTFpLa47nZ+fT10KvZKGc+ktgz0y9soJL13152OPGc4BaGbNHcl12Y3Bi3+OcxCcNOjIqa+3G3tHdH2feTUyWsshV7f1ug+/Tn98xx9rZm1G46VxPbbvsfri7S+WtJv57F4EeWcUF1EPUIgTEfBc2CRXyK3mZK/7+et0z+I9u37+fV/8ffrlr/plSTuLv7FxtV+1X51XqVQScUraqYPNi3A6GZmx7q9972v1B5/+A11ZuaLJnkk9YfQJ+pLql2QOnhs8N94IhXuA5DfzDDJGkQNRLBbV19eXacvZCBTjsBzze/37X69X/eOr9G03f5v+83X/OdWuMveNjY
10/64rImcr01CDvLHnqvze6nYjnXoy84IveIF+/It+fNc+ufwAmYICYWBcWUWZ6WRk9sp3v1Kvuu1VmZ/dNH6T/u4b/i5dXwl/wB1MkBTyk9SaxgsyPG3jjl47xswdGBCZN3ziDfrZj/ysvvn6b9b/d/3/lzgp8dIUNySlUimlzeCadHV1ZfLR0SBHVKXVsbS+pFf83Sv0tk+/TZeXL+uRRx+p//GV/0OPmHpELguXs+vG2BE2D4IiCuGGHiORFwU24jzXajX99D/8tN52+9v0mdnPqLfUq0dOPlIvOPUCdS126fLly7p48WLmBjl0CJwV77/ga4rBwijHNIwjTnCLXB5aWfdHHHmEXvP41+jmwZsT8ayeUcYpRl78XLI2npJxefM+Co4SH7hBlqSf/9DP63c+/Tt61SNfpcnqpD584cP6lXO/oi1t6fri9UmoPF9F3ok6Of4NHEkXGqeNMzDEscNTjEAb2bQPfvcHtV3bWZyPX/64nvLbT9E3PvQbJWUFHMiWy+UxyJHI5Y0pDrKu7vXvf73e9PE36XWPf52Ol47rg+c+qJ/61E9ps29Tpwqnch2hmFvGU/VIzEsvPMcDS5jfo5QiIhmNeOAeOX308kf1Wx//LT106qGJ8ToyMpIgNpog9PX1pbnycsY+LSd5xc5onSg1k7IyU61W9bFLH9PTf+/p+tobvzY3+mHNHQaWdhwaDILXYx6UzEjSQ6cfqnd+6zuTItxc38xcsjA7O5uuSXW0p1gsamNjI/Xo5qs7pURHGM94LqXWImWPVj5y6SN68x1v1kMmHqL+/n4dO3ZMY2NjGc6Kl/3B9ahWq8kgwFXBoXOCWIRVO5E7/q4/+y59/PLH9X+f9X91fOi4fvtffltf9ftfpX963j9pomsiU47kTq0jQc7fKBR2rpL0unA3vsgfkWgzaT3/3feceY++6+HfpYdNPEyVlYpe+/7X6oUfeaFeMfGK1CZzYWEhU9ftaQAcHAIU5xpEwxQNMsbYU52OPDW67v/nmf9HR/qP6Lf/5bf1DW//Bv3p0/5UXatdGWfSjXIegpVnlGNKwPdD2kkbOB8j7vFeoyWD/IELH9BTTj9FX37iyzU7O6vHjz9e7559t86vn9dNXTelh2AzJGWiY/7Nwd3e3k5t2IiUpB2DRx4pLz/bbEQxPTid+fdP/8NP68bxG/WE00/I5F/xsonciQ5ccTEfN8gRIu1k7vL959+vp1//dD351JOvXgxwdEh/ee9f6p7Ve3Rt8dpcZYgAkQMiKiYCyoO2Yp7RvcBWiXSM5c1lPf9Pn6//dev/0k/9w08lIzUwMJDSFsvLy7ugQxwHv0rNc8bxVq1ONh1wmalWq/rpf/hp3TB2g77k+JekFobOp4iRCs/AmfAbl7y0ppOy4qOr2KVjQ8euykL3hiq1imbKM6n1K/3Ya7VaysMS8WxsbKToGDLP8vJyek4pe/dsJwyaw7bLm8v6nr/6Hv3K039FP/Pen1FPT0+6kc07KXkbR5y66GiSf/VbrtiHdpuA+FjdXNUff/KP9Sff9Cd6wrVX9cqPfemP6c8/8+d640ffqB/6gh/aRbqK0D+GiTXFIINaxVpxh0mBe1s5n7VaTW9/9tt3KjBKy3rFw16hp/3N0/Tp8qelspIuJIcNcRcHKhowd0ojycnzt165wHnx/djvWVj3tz/37fqy01+mra0tveyWl+nP7/hz/e5nflffdOSbUkWQd9hiLjyL9zx3mWb+eZeB4Ai5Y+M68sAMcqFQ0GNPPFZv+tc36cz1ZzTWPaZzW+d058ad+trer00RAIvKwcToFgqFdFh4yFqtlg4UzURcKPOMXiwqb2VsbG/od/71d/TCW16Y5uHGx2tzYRiura1lYD0nkhwkIaRQKOiWk7foNz76G7r7IXdrujStz61+Tp9e/rS+YegbkhLlBpJ4NVi5XE4RM/OUlCBJb7DuXizGIkZArSjcQqGgH/jLH9BX3fRVesoNT0kGGaXpDlcsV4
oH1htsODTWSTJd3tjY3tDvffz39P1f9P1J8aO8Yiu9mHKJjmVeioNn7eS4Y+4OnfyfJ9XX1adHH3+0XvKFL1Hvdm+m7G1paSk5amtra0nJbm5uJkMH5MfVjIXCTvcxP597EdMaHfz+D77jB/VVN32Vnn7z0/W6978uISqORsB6hZwoZXN+EOqI5OJ1qDGH365jtFXd0nZtW31dfZJ2ULfeUq/ed/59+t6bvjcRQx3edPKQnwdymR4okNf1C0wwfO5Yt/osjvosri9Kkrq3urW4sph6MTjiRo4+loni1GGsPGp0QhWICAgYsuV7sp9Ri+vOc/SV+vTPs/+sb5z8xoxud6OKnHhrZrg17hRUq9V0uxPp1Uj6ygsSG92Dpgwym/viW16s+ZV5fe1ff62KhaKqtaq+5di36IldT9SlrUtpsyhr4WD7z51BB5RG6VOtVkteoefb/NXuRQ2S9PZPv10Lawt6/iOevysSzKvvZbNYcEqKYmTWaUPA+7zkS16ihdUFPemPn6RSoaTt2rZecN0L9BXdX6Ezq2cSwYVDwrpjjFdWVjIKU7qKVHA5AgQfjAnKNt4p3Opzvvnjb9ZHLn5E7//O96fn8sPor7z3RFFhkPOMcV7JUKdGrVZLMvPND/rmFK14Hb7XH6NU3ABHvsFBlGX5eOzJx+q3nvlbunnyZt27eK9efdur9aw/fZbe8uVvycDrsGWLxWIi8klK6QM3yChY5u99ivNu02r1mf7gE3+gf774z/rAd30gQxjj3IM+AEVLyuwJslwoFNJcUfj1bl/rBPdjuHdYX3LNl+g173mNHjj5QE31Ten3Pv57+uDFD+raoWszSAN6xiFsnE9kxw2ycxN4Tr9xzqsrWnmGmG/erm7r5z71c3pg/wM1uTGpue25jH6EtOjEKyBzdAu6fC9Ejn0dGBjQ2NiYJO1yzPfLw7Lu//3v/7seOPlATfZO6i2ffos+cuUjOjVwKpcIR301z+Gd6aI+kXYQx3K5rIWFBa2uriZUgvnGlGozeqjpCLlYLOrtd7xdb73zrfrlJ/2yTvWd0kcvfFT/4+P/Q2PHxvTQiYcmyLmnpydFvHiycbApnh8BOqPTVKwvrdf0vdnxG//8G7r1Abfq+NDxjDEmKs6Dhra2tpLX5G0T74tSm7fe/lb94e1/qDc89Q26bvA6/fOFf9ZPfvgnNXxyWA+eeHBiPHKzT7F4tbcqt7AAAWEApR0FRg4f4XTDx+UH3oGMDluNrv3ZxbP6ob/6Ib3zW9+pvq6+XXmzPAMac2BA1qy7d36L+ftWmJmNjN/86G/qaTc8TUf6j2h+fj5xDLjlaXFxMdPPl4jea9NjY/1OkIjqjVtvulXS1bV8yORD9IjJR+ghv/YQvePcO/QoPSoj9ygXZ5J6FI2Rk5SMLukD8vix+UOrz+Ty0t/dv/MfBSUjzBwxVjgPdI7Ku6YQefaWtp1MbzB+++t+Wy/4kxfo1M+fUqlQ0iOOPELPvOGZ+uilj6Y0AU01HF0BQnUCqyOJXqoW+0AQLDgEHAl2zT7bqz/8an1u+XP6sRM/prm75zKkOedJeJ2yM7KRE+afl+5y9Gt4eDihHKAXnI9GiFGs++lfOJ1d98sfzSBUHrUT1YPMEtlHQpevvV8HSvCCsx3brjpjfr/RUoT8o+/+Uf23x/43fdNDvklra2t6yORDdGn9kv7k7j/RE65/QoIAurq6UuSAkduri4kTGjwyg6wDYSePEdnsQbpn4R6966536Y++8Y8yUHWMjD1fgKNRKBR25TLrGeRODAzWy//25XrJLS/RN3/BN2t9fV0PmXyIzlXO6a13vVWvv/71SaF6n9u1tbVERIs5Yj8gzpp2RjtsVjoZwWRutgvWhy98WJeXL+tRb3hU+tl2bVv/cPYf9MZ/faM+9KwP7fobnx+y4bW8eeUq7cCk+417Fu7R33zub/T7z/z95Fm78sEgsw8e0TvSw6HNg3cPwij7GOsb0w2jN+hs5ay+cPALMx
UQGIbYV8Cb8WAsUDZ0KqP0LN7Z3Oo5qCcv77nnPfqVD/yK1n50LSPHOJbwPRYWFlIXr8g98K5R9Xqdt7sPN07cqHd/+7u1tLak2eVZjRZH9a1/+q063nc8sdtBHWJ/A6/RjecLwxBzsdLOfb/kX5vVi66La7Wafuyffky3nb9Nv/CoX9D2zLbmNJeBnL1fwerqalpH0KAI93KenSEuZS9eGRsbS7lb9D1dBz0Xvde63/Ydt2lpbUlzK3Oa7JnUt7z9W3R66HRCqJBvHAUv7/P19GoVAka+4gDyDN5NzlOrHik3Mpo2yMViUStbK+ru6s6wYPv7+lUoFhJs6n1VgcPwnjyZ7vWFQHtDQ0OZtoQRFs4r02l2vOmjb9KRwSP66pu+OuP9OCs5NjtxeCZP0cbr8To9VrZWkjJk9PX2SQVpZGQkrS2DiAGyXGzYUs8hcs+0v78/tXzEGDdzYTvjydc/WR/7/z4macfQfueffaduHL1R3/3g71apkBXYWNqBksTLJerMu8npoAzbmz76Jh0ZOKKn3fA0rVRWUq4sNo0BFfISDidyRRLRQUXHeaOyUdE9S/foqceeKql+f3rY9JwDHFacomq1mknbOGTdifpvlxfGd/zJd+hBUw/SDz/+h1UqlpLuwCBTFUGZFiV7g4ODdaF1j2IOYh8GugfUPditc7Pn9Pfn/17fc/33JHmhlpev7sw5yYtXniEjssTh8FRSsxUojFqtph++7Yf1V/f8ld70xDdpeGNYZ3V2V6lVbJGKA4AzH1OT8Uwzkg35t3afXV1dGh8fzzgozZQOSdJgz6D6Sn26VL6kd599t176hS9N64TB97ywN5uJTGlvJUsNNXYAXeRtW93JO9AIWboqBF9z89foZ/7pZ3R69LQeOP5Afejch/Sbn/pNfd11X5cUpS8uOQUnv3hEhmBxRaOkjKcVFVm75KJqrao3ffRN+raHf9vVXGx1pxGJ15bFF0pW0i5le9BGgXX/6X/8aZ0eOa0HTjxQHz73Yb3xE2/U11/39ZkyDqL7YrGYlCzwnQucvzfwmNcMYgA9Z+s5wmYIMMO9w3rYkYdJ2jHIg92Dmuib0APHH6jFxcXM7+c5DAg2sJPnbA6SpSxdlZnf+pff0vMe/jyVCjs9bCPxz3PwkcwFScW7eXUyKssbL/7rF+trbv4anR49rbMLZ/XKd79SxUJRTz/5dJUvliVlyYzerxhkyCMy9j2WinjEH1MZrTyXywtjsHtQk/2TetiRh2XkwptiRM5HT8/Vm+Eiwz3m+tqBdfPGO+58h6q1qh4w/gB9+vKn9fK/fbmuH75eXzn5lTp/7/lMG0rvfuVEIc5g3sAg8DzDw8OpcsJzyPxuM+OF73qh/vD2P9Qbv/KNGiwNanZpVgtbC9qs7dwO50bZ0U7nheStZz2noqurKzUCGh0d3WWMG21w8o4736Gaarpp/CbdMXuHXvqul+qm8Zv0nJufo5WllV19LJhLhOJBcv3f8Wz39V0lj/lFJ3n54wPNIUvSL936S/qxv/sx/eBf/aAur1zW8cHjev5Dn6/vvvm7tTi3uEvpuBfiJIZYA0ZpgrTjbfnixYdsVZG966536cziGb3gC1+Qfhbp6rGGjJ87E9bJSPUchE4qWdb9B97xA7q8fHXdv/1h367vuum7VJ4v72rfiFftRsPRCT+8OBeQZdygROPXidKQVkdc94OOihnIzLc//NszchJLwrxdoSunPPlth3zT6Li3fK+++Y+/WbOrs5oemNYtJ27Rnz3zz9S/2q+yyrsMm/M9/NYvf568iMiV0UE7GT5i1BbrwMk3e+TmOqWe4Wh3LK4v6uV/83LdW75X433j+qrrvkrffeN3q3y5XJc4ipFupOkOz4JOcvhUaq/c8o3/8kZJ0nP+4jmZnz9Lz1K/rubzPQ/sMu9jrzm4s1EoFJLjFC8oabbE0td9om9Cz7z5mXrZo1+mwkZBa8W1zPnjs/05IrHXu9Rht6SdRjgYZWxYu+
e7IYPMQpTL5fSzV3/Jq/WqW16VDC2NBuI1brFJQuyoQ66tUCjsqufEY4/XwEnKNI1nML+YC4lzv2X6Fi2+cFG1Wk3lcjlTSkDO2++HjU0e8ohfbgA3Nzcz0FEjo9G5v+Zxr9Grv+TVGYgRCr5f7BEhd4ddWHveHwgcocojjaAsIEL4M7In+83d/297e1t/+NV/mMrdYCiz5vFQFgqFjLEgCoI5DpOTeTXqlTYjM/M/NK/t7e1deT9Kxvy2rFqtlvJeXmaBDLPO7UDWjcz91576a+nnnKelpSVdrlzeJTPuVPhZRSGCdHn3oVgPTFqK5g6x33Kz6+7jT7/+T9P/I0M8T9wL1lzaOa9es8w+9PT0pHk2el7j3PPm/fRTT9fTnv+0NMfl5WXNzMxk9Eq8iMfPquda4yDAYU/YJ/97b5SytLSUjPbKykrdubO357/nfNJvlUpFc3Nzuvfee/XZz35Wd2zekamdznv5qLeeHtwg/93d3Rm0iXRbf39/KjdqZN2f/u1P35WGqSzv6BdebmRZf+dLOIfCdSgkrnhWXDdReVEqlXLXve6oNTDOnj1bk/R58zp79uzh3A/nfjj3z4PXf4S5f77N+3Du9//c641CbV+TfdWbOX/+vIaHh+9ziLKZUavVtLS0pBMnTmRqzQ7nfrDjcO73zzic+/0z4tw/X+YtHc79/hp58p43GjLIh+NwHI7DcTgOx+E42NFQDvnzxRP5j+R1S4dzvy/G4dzvn/Efae6fL/OWDud+f41GI+TDHPK/s9fh3A/nfjj3z4/Xf4Rc5uHc75+51xsNRcjDw8OSpLNnz2pkZCT3d2qh7ABm4dzcnM6dO6czZ87o3nvv1dzcXGpnR6kTNa5TU1M6duyYTp06paNHj2p0dDQVXTdCHS+Xyzp16lSa715zZ74w4BYXFzUzM6OLFy/qypUrmp+fT5dxz87OpntjvcbRm5qMjIxocnJSR44c0fT0tKanpzU1NaWjR4/qxIkTOnbsmMbHx1NXqcisbWbuzL9mTEJnHS8tLWl2dlaXLl3SuXPndOXKlXSD0uTkpE6dOqXTp0+na+xoKNBq+8Bm1z0yT8+fP697771Xly9fVqVSSf3OvXyLJid+6xbXQ9JMxu9IPnbsmE6ePKnjx49rcnJSg4ODuf23G5076w2bcmVlRfPz87p06ZIuXLigy5cva3Z2VvPz80l+uBGnp6dHo6OjOnr0qI4ePaqpqanUBtbb7VFTyrNwp3O9JhvNyky94bIUKw5mZmZ0zz336M4779RnP/tZXbp0Sdvb2xodHdXp06d144036oYbbtCRI0c0MjKSZKleCRSvveZ+5swZDQ0NZebizOi5uTldvnxZ58+f1+zsbOZKSGdZo2NocDMwMKCRkRFNTU3pmmuu0enTp3X8+HGNj4+njlBeR13vHMS5N6offZ39FjYahfAc9ElG79DpkK5j8d5e72bo54CWw1yFSOfE06dPNzR37/VMvTRzQs4vXLigc+fO6dKlS1pYWEjtO7nEw+Wbxizsw/j4eNL56PvY3GQvHdmKrOfti7OxecaZmZnUrIXKAVoTF4tFjYyM6MiRIzp9+rROnjypqampTFexRs5q3mjIIPPmtK/MeyjvdEWJii8owuh0fbrn0NIutsn0nsnN1HL579WbO3Oh3IFyGS9fqoW6vrz6Z/+d+D1U/tjuca8uRo3M3efPweb96LXNs2BkOcTeC5r3xVi12oa02XXHIBeLV3tte+ck78frihFlhELiLm3qS+M8kC8UgiulvNrAvebusouRqNWu3lDm60ZtKMqH3+/q6kqtVb2LD8NLiZAbv9VqvxK6RmWm3ojONOUlnAmvPed8cI3h0NBQ6is+MjKSud/ZazLrNUnI+x5nBceeMh/WEln39eAzKGOiuY1/dpQPauy9mVEzHdRir4G99KOvMXqSNfbn8/K52CkwtskkIKC5ht9JjK5hLV32G5m7y3t3d7eq1avXDnobTM4yv+sOF2voDX
GQHe9h4Jdo4DTtFxw0uu6NDJ6B3tbcHEZ5pzuq7B+fy/PhaIyMjNQ1yHHu9UZLjUHiA/lhxkulNtDbwvl9x94Vyq9vuy97+8Yo01vvUYvsdcjRSPt7eG3y8vKy+vv7M3XM3oyjk8MPudeD+qH2etLYki8verkvhn+Wz8Xb7HHYY12sryFOlRfi9/b2ZmpsUWp+UBp1OuopU69ZjJePlEolDQwMJKPAJQw0P+Cc0H8ZRwMD5/2XvR8wxvogRnSq/bm8+T5rz1p7wxPm5r2OY0MO1n6/cxAdIK9x5kVNr7fqpVlDDABwAtFPXge8vr6+q/cwf9+p9XbnOd7hDBJHVEad+8LCQqqv9v7WfkUmsjIwMFC3QUc7ww0Wdcl0FuPqR27x4/M5W9EJYz/pyNXb25vpcMdZiLqo1QCh0eeLsoYN4BmxBdzihpPh91A307xkr9Exg+zN9n3Drly5opmZmUwzdUmpw4lHyXn9oDt9MPLm784EjoQLHL2JY8Qc3wODToTa39+fuas0r3C+3bnnKS1XVBgJogufv0f8B9WxaL/hn+9KHUPMQfbm7+yDG0pvdk8KwXsy510J2OjhiQbZO5+5kfCr2IaGhjLPRqcz4PfYKB9HYmRkJEUlfqvZQZ2BCKUSIUfD5c0qPLpxY+tNYxwJI5Lwv21kTj4f5Jo+1TjN3OUNshJRFW/sQFTa19eXaSTiSAcvl7F21h3Z8bnQmIJrTy9fvqyLFy/q8uXLmp+fTzdVIVeeHuOcEMTQzndgYKCjTn+MDl23z87OJug63sfMZzNPj4w5x7w3t2+xD7ROjrqo00FMfMaIDlUqFc3Pz6cUFPqUueMkkx6JjlA7MtOWQfYHqvcwMzMzunLlSuZqOg5OhOb82rb4/hGm6MTwwxKb07tRQ9Ac6opeEZHF8vJyOtQDAwOZTcu74aqduUfYGiXKi7lH5b+XMb4vh0fGDj+TQvDuP961y9fPb6tinXt6etL+ecc4v0igkT1wY+XRI58VozbWG+gwT55x3Pz9MBQ9PT1aXl5WsVjU0NCQxsfHDwxZic/pcsQ5IAoF1fKWqtEY40w5igC8B+TpiMx+jqmfy3gFJB2c3Oms1WqZixaAdJm7X6yCbOR1KnMkqVORmcsPXZyIgp3rceHChcRV8Qje4Wo3yP39V9tY9vf37/o9qXVduZehWlhY0NzcXDLKCwsL6VIb0KHYc977n3uXt+7u7rSfKysruZdjHJRuykO+iI65+INAkmtpSS+Q2sgzyPdbhFxv0/CiuI4OQ8wBAqP3PLLffuMlEEBefFYnb2SJ+V7fGM8ZSDt9Y10wPHLnvTj0wC8OmUZD0omIx50BNxB+gQfK3p8jCvz9YZQjTI2j43vsHrpD1lIWxsJg85557SDzYKVGD1AeZB2hRz6TZ4K40t/fn1poElVXq9VMS1K/g7parWpoaCgTGd0XBrlehOwyzJllr7zHvMPVPCcIB1B9hC8bmVdME+AwoNTdUAHzE2l5a17mVavV0jPl9U3eq2VlJ9bWoV9g6bm5Oc3MzOjSpUuJxOWpAtc16EycHYfrPTptNRUVEZO8dB7ESlKQMYKPFwPRmpf3Qi7iPnBmyTmDNnV6xOjfUzR5F34sLS1lHA7OaNTtnZCblp42b8PyBA2m6dLSUuamDI+KPOFPJMF7YWwiMaATRjkvfxkFCfjTjYEr4VKplMldAXX52rTSIL3R9Y9C5REy6y0pra2zSNslcLUzfO0d0oTcsbGxkTG2eYoyOhK+Lh5Zu/FtZcT3zGs878aftXW2dLFYTH2VMQpuoFdWVhJMTZTM/vlzH2TqJkajETJFicbbq5wo5Oz5SqWSnpmz5GkJdxT3GtGoYJCIhPv7+5McQJ7kZ04K8/OCTLUrG40Oj5AdfQB+x8ChNyOqxsCh9iCFn/u65F2ewVruN89G0SBHoAiWPHKHqAuJFRSLvffz7Wc88l06pTPjM0a7xTORsvT0iD8nshYvSOpUKrJpg1zPGHvJze
LiYqLGA7/gXeAxx3sj3YuVdhiebDAvSR01yigZ7nUdHR3VysrVe4eJMj068qbrTlIjJ5KXAzmoEdMFeHd42HiYXA8GI9aZjPdX7tjzkM7+xhNlzf05Y37KFQg5Sr9lLO+WsPicjRK7XAa8mbzDoZ43hX2JQcaxAwVyZMnJd36xhnveB51Hc/lGjsjV4hyXSqV0Zp2R7PlBSggxLJI0ODiY0C6MCee83ohOO/LBmrM+7sChL4jISqWS1tbWMsibv7/LRitX5TW6tnkGmfX1S0pwxPJy9s6OdnY43BtSfpQY5RFkG5lvJPY5T8K/5uk9ctmjo6OanJzU5ORkMsikD5aXlzMpmoiGOfLRCprVyH7kcRPK5XLiORFEkq7BDkhXbY/Ln6M090uEHGFqZxZDhvKkP2QX3zivsYRsI12FfNlwck9EGqOjoxmIsxMGxA3C4OCgxsfHE2QyMTGRgcZQWBC/QADK5XLKCaKMHX4/qCg05v2YF542HjY1mL29vRoeHtbY2Fiqf/33cI1iVC5OfkPRukFmz6IsYLzJ3fLMMWLw521USdWDrJ0sxAGVlAwy3AgMC7LlhsnPEVC2OyWdJAHu92wxh+b5zGq1miErUsaFQWZtyuWyZmZmUk25JA0NDWWY8OxtvRGdNXSAtFNqQv2t3x/sOeqtrS2trKwkBy3vwngiyni1aKfTN/GcAonyAo3wm9rQOdJOegaHCMd6dHRUY2Nj6UXJKHsT9dBez7SXscI4lcvljONApMvn9PX1pfpiejEMDw8nZ9RJeNvb27v0fx661WlnNI+ATCUQcut8J7/jnDnGtGYn59iUQY7enhOJ2DAIXVeuXNHCwkKmCQjeGzVb1IYS7fiVgeR++vv7NT4+Lkm7YOt2IDw/9ETHk5OTiaELQcchLhQWRfEIlN9hKim3ZCXmnzsxEA5XpHh3IAxe503TjOHh4VTffV9D1m4MPfoZGBjI5EyBfyPL0hm0TizCcCFfQJceLUDqaiZi4Gs0yMioE4GIaPy5kAFJuTWuHp1K2lWmdpDRcXy2PIMMYadWq6X503jCS7lwVhcXF1PTjoWFBUnS6OhoplykEcgaJwYZlpQic4yRM8AdtQCxI4KOqAmRtKennICEjEntkUhjqiM6zp7Si8Q5h9QdyUO+x8fHNTY2pqmpKU1OTqZ/j46O7qp536su1udaD3GLUbxHx74n9JGgIdL09LRGRkbU1dWljY0NlcvlxHCn9NX7THiUHFNO7Y54htfW1lIDpStXrqSvly5d0uzsbGK6e46cfXCCYKedhrYi5FgfGHMidF1CObl3B8TiEB5eLR5YoVDQ4OCgarVa8og7WWvnURpweHd3t4aHh5PX7TkOOtb09fWljSEq9tIoFMl+Oe92HAo/7E6oIwciKUUGRMZ07omNAu7r6FjaHSHTTANhJ0cfn9VziTwj/x9zt7HRQ4Qlm4mQXe5dJmI+0p8Nwxyh0egUoIQiGnBfGOI8tMvP8OrqqiRlHCfOsEc3yCDG/NKlS5qfn5ekVHaGsejp6dnF/PfhDpuXSxGdew7f0QrPBW5vb2f2m7MQI2OMcp7D1q4xzgtg0G8YOWeL56Up4lo4Yjg2NpYxxN6pq5WeDu4cerDlX50EJylVk4yMjGhsbEwTExOanp5OHQvHxsZUKpW0vr6eyv74N9G1l0W5TB6UMXbHyCuBvBbcjTEOKXOP0Xsnz2lLEXIeXE3uMl4yj2dNe0MiNQ400bGXBAAXAOH19PRobGwsl9rfzoiRmqQEC3mE4jBHd3e3Njc3ValUNDg4mKJjSRkmajQCTizxw9bK8MjNPT5nWSM4OBvuCEXP+f6Aq/NyhDhtNAiQdiAmJwTGdfUIAjSA58wzyK08M78fSTLxfaKRY94OufO3vFfMb0oHSzSKDoY3OvHL4SGzeDkRskQEBlLlTH+gyYWFhZR6ApLNqzjIGy4fKHJykZ5CcmIObHacCD6DfY/RsUPV7ji166TGyNjX1y+vR1968y
HPSbIO6CgcTpwih6gjKtRs5708x9NzyRhhT83gJHkrzPHxcU1MTGhiYkKTk5MaHh5OufxabadGHWebPLLrWU/XtGvwopzX4zt5XbUT1iI3JOrwTp/Tpg1yzId4VOYPwkM4JIwnh2FgMxBEp5wvLy+nzxwaGkoeZKejBzcOHFp+7gKKB7dXSzcUqxfEe4MT9wB5Tz6r2eECgRB7CYHXfqKEUKLNkj0OarDuRCx9fX2ZPKukDOlJ2omIJWXqHXkP5MxbBrYT+USIHcUYYVDmi9LhbCADlHUQWSBv/p55ufFOD3eqY5OTvDp2Uh9wIxx94OfRKXQiTKFQyHS68zVtdM35WXSSnYiGkseJj1EnDjdRtqcx8hy2dtc3MpW9lMwNshOkfI3y1gOHwlGKiAY5Z6IVOfL5O4eBr157TvMaEDgM8tjYWCZaZx4bGxupZpr3ROZB+pwo2a5RjvsQHUbSBrxoXOWk2Ogg+X4cRDDTkkGOSXHvwkXXHGkn2hweHk6NxOMFC+SO19fXd3XIkq5u2ujoaOZw+SZ1ElZyDy0vQuZwOUwWvf56pVyRpRuh11bnH2FrhI7/c8/a673z8krtrmczI3r+DiOyZg4vra+vJ+MNaQr5gjzFa3R0VBMTExlYtR1j7JBnb2+vNjc3U27alSAGAei2p6cnOUfb29sp3+nwF0qWnJrzDjptlPOMcew2FhtmeO7X19lbg7rR8ZwuMLy04zyxz40oWNae3yW15WcWh4CvQJEwmCGk8fn0evYuaPUIf82ufR7yEJnKnpONQUxe7hT9gOw7SuGRsTva7ti1Kz9x/5DX7e3tDFGUqBgdD0eI9B5I6cDAQHKuPW2BPCIjbpTbMcbsg+8BDVnok0ETK2qO3UHKQ7byXp0aLRlkHhAM3huA0D6wUCikfMfY2Jimp6fTjUe9vb2ZSMK711AyUalU0oM6kSAy21o1InnwjNeUolDcIHue1hWYPzMHAqPMgXKjHr3XVjzYvTxxLgcAWnJj5wfW95R5eBTDOAgj7ZGgK2qcBeTDmY4Y6t7e3vRc5CbxyrmkhH/DU2gVno+5bucTxEsJMLZra2taWFjQ9vZ2ph0gaAaGxOF6jIYb+E6yffOMhctzLMFhzflsN8ju6PB+3pMY/oWkjMMVa/z3W3eGo0yORHA+MdRwPMgLUgtN+V9XV1fGiEVnqhMKdr+UnufnY3MNRwBZA3fccKwHBwdT6o/0jMt5O4Y4IobIvnM9aHbT39+fSpwgck1MTCQonXPBmrD3/f39KRADFXAovLu7O1Np0Cw8HHUjiInXenveOMLVoDuRWOechL1KKdsZHYmQvaOJ179Sl0Z+AS+KxtyQLxxGQGBpIdjV1ZW5pKEeA7VVWMaNsXdd8g1xgwyk4b1wqbf0DmTucaOIgWPyDHIzXqBHxrEenBelBv6+boSdth9zsfW8v4M0zJFwheNHFEc9KUaAA+IIzPj4eEqJ8GqHvObKCWhc2rnQwiOtvr6+TIMQ6i1jEw0ILBhfXpwZd0ryDnorDmieAxdTTigrSEY4wDhBHqFhxKSdyMZ5Ix5VRPQDZKERHog/pxti/0qUhR4BZcOpB652Y+Dwbp5cMPd21tqNMnrFnfjYvtObTPDsLnvIS7zNyZ0Lr5poBQ3y7/2zPcjwTmjUHDvbm3QR0THn1eUAeWetMMacC0cdW+UL5dkqao2pBCIyRvadyIVzxHqgm/IaVN1vBlnaXVPnOeTY0cQPAqUSNN0vFArpFhP3Il1oi8ViKisCHvbaTw58M8o2GjM3xhwUlKnnMdjY2NyeaILP97+J5Ag3lgh6s55fvfeuB6nHjlIOB7tj4K8YvUbyUacNc96exHpfr0lm3/1aS1ie3O/sMF4jZR/1RkwrAL2h3Pl8oLhCoZCcRu+25I4Pa1qtVhOJjQg5Xl3IZ7aKBuU5n17WggPsJXOcNyDE6ABzHqSdHH9eVy+Ul9f68iIaamT9/StnJaI6rpN4FhxlT5G5TMd1iXISP7
uRtWZu9c4+OiyvtW6MyJyjgOPmMufXlnaKIR6hWTdE3tCpq6srOQR+XS7lcF5BI2U5EzFFwD5IWf0Zna9G9yASK2PbTw8gqejJC/Z8HSLXw5+hk3qx5Trk6AHGVmKSMhsQlb/nPWM3GIQVdl68fQZjnR6iq7HHcGXiEasrJzylqEy97ImGCXhUDu25J+iHb3V1NQMr+aY2kkeOa++G1ktAOPzFYjFFDIuLi+lKunizTTwk/pVncRjZ59outOfRWtzjeO0lnxcZv66gHL7rtJJy4+iRAggQnAfv7uZoSYwueF+en+FRtHMYWkGEPFqLzqdD1TS5cSVFpACa5TlQaecKRT/D5GzJ6XpETARHdNFK9JYXuca8uJ83UjelUik5rZFtu7q6qt7e3vS+sfypU/Luc/V9deidz/AzCQnNu3K5E5jHqm53nm6MI2SODFer1cx8eMWzhxzHNeB7/7z4fbPDz0pM43lKNNZ+u5Pf1dWV0CqcTp4fVAVZzqtbb3e01Ms6L5Lh+zymIJAth0FSBsZxcoPXfhWLxUwfVYwLjMm4oXt5UvUcCZR/9Jy8UN//bmVlJTWEB1r36I35SDukNg4PTVA8rwbMvd/cXelEZmy8TAGUgqvSqN1cXFzcVQbk+SF/eRToBx4jg0Ju5fBEx8KbUWAYiLYiU5yyD9YUJieRghPXOskkj/k8FCWtVol2I+TFQXc5gOQFuclzh5Fw52U+/iz7PVOUGZd58mlcajAzM5O5oQ35pzkOt/wANQ8MDGSc6lgHDEzc3d29qxwHuWsm6vE9cGXuBs6dO39W/s6Z8f39/SqXy6kxEXvgzpyTCzFOjcp7RJg80o2Ik0OesUqDqgGIiuRmncHscHWjzn0zw/PI6AOMmLfsjE5XnEuURUcFeGYicQ8WGjXQecbYy8ycVe0XRhQKhVQ94Hu0urqa9pwgwHWP9zjoFP9AarExiC+AKwzPC/A77pHSI7pWq+1qqs5Xh8u6urp2EchQBnxOhKD2mndeTgFDT8tPlBOwtTMfndjFnJ3Q5QxtHA+IR3iNHDw8LryxRtbdUYXIjAV2cVifz9nY2NDc3Fw6NA63MA+Ezb1xJ46w3j7XVg6/K1H2lnaLXHdGcT5scdaNaNRbBqKc3CB3ou44b7hRRlmOj4+rVrvauGZ8fHwX05g184sP8L5p4OJwmRPA3Ond2trpBb2frPta57FMIT7Fe3jjhTAevRcKV8uOyuVy5npD5sg54ZlBBNgvr3/vRESRx6PwOeDsuZPM37ixrNVqCQHA2OEc4TB5Hn2vEWHemD+PLzfWEaru6elJTGrnSExMTGQ6csVrazsl725gIsLDutMRLJaRORztKEAeG9/5IOjKPId6r2dy5ywGXMg7QRQVQQRcQO+UXK6srCTd4akknBFgem9PGomy7YyWb3vy79kgXxwOsJM+KpVKWjSMsTfzhuRACYC00wGI+uSBgYEkeChGj972mnM0yF5mBeuO1mnkFjxCrtV22nsSXXskFJECSZk8J5uHt9VoXXVedO/5bjfIjligpJaXlzUzM5OBWDxKJ3rzUorh4eEkkHkKdD9EYj/5cYPM+l+8eDHD2CcVwEFHOeEoYIyjQjgowoXDibSQlHa6u3le0KPkYrGYlBnGbGNjIzlkNJehBpv5enqFvfS8WyNr7LlVr7+kb+/Fixd16dKlVLYIjE107I7m8vJyxqg67Crt9Jl2h847SLnyajSHXO/ZYnQcYX031jG3y14yDz/L6C72ADSuGXmPSIqjTnlGmXI5hjtwtLul0QbExVjy1AlEKKIv9VAh1sGRKT97MTrOQyZdvpyF78FLs89Tj6QLwx17Uy6Xk4Mbu81VKpXkgElKZ85b8uaVzTWDXu01mjbIruRcQeWF7RwML8hmwbzMggWL/Vx5D/K3ePgexcXa5L2Ge1ARyiBCvnz5cuplGovD3ZA4S5l1kXYgHmrpiPYjrJzHFN9vuGATDWOIyd+7QUYB0e0s5q05CHh+5E
JHRkYSnOM5k3ps7FZGHls/XmhAdALDNC93lgdRR4Japwf7K13dd4xzJO54BOAdogqFQkorLC0tpYPN86LMYgTIv1EYjTpxbpCRddaakg9YyTii3r8XmfLowffRIc3h4eG0L+T1cZpiFNfOcKPs+8JcImrmKSfk3gmOONAObUfD1Ii85+lFf8+8dJHzavje22SCMrghzmsG0ono2J/Z0UeH0XGm6nU5Yx4xcnVDibPva5LnrDQqK+6IOXITe0Z49y3SFwR5VEVUq9V0dkFqObuRMHq/Q9aSdm0QB9Cp4nxlYVZWVlIDhK2trQRR50V4/hm+uRHCa8agxQ3DKOeRynASvBlJhON9oJQdonKGaTx4reZd+RqjAg6Iz431x7CiyB2ujsQjP9BuBLlEwNmEngNrNlL2vYglON7pzXM6KCeP0PKcg05C1HnDHS/mgTHy3BUvIi6XC0kpTwXRj/SMr1G96K/R9ebMuEHOq4N1/obX30tKn+usawYyBlJVq9XSXcReL+uwarOErv0Ge+3IE0ZMUmYP3ADj4FOWI+2+FCYa5FbkvZ4DG+Ftf39HrnA+nVXthjgS0DqVmnFClxtbf/965y7OIZ51XuwLOig6LO1EnfUctZ6enT7qIBCDg4Pp/EGG9QZX/rfo9RgAdEqmmzLI7jUBX4Cpe/epQqGQck9Et7BQMcp467EQW8q2RHSWZl4k1IrydQWHQYuKL0abzh53AfUN8j6z1OiR82Gd2mVFxkPgxhTBQckAlfrhjQLlTgRGGijZy9FQZo4ItCOIeaQj1lraiVR8bZEDojaPRPO6+hy0UcYzJrpx8g4H2IldzlaH+FQvCnCnpVlDzN+zlxGFcOIc0GGEciXlckJQUr7WrIOzq1034Eh1MkL2MwDcOTQ0pImJCS0tLalWq2lwcDA9H3ONuWPSaP8/e+8dJmtVpYu/VdWpuqo6p3MOfQIHDjnrKKAoMoMiQUEJjoDida6/q4wj6KAY5prvqNcwOqPj6ABmB1FUdBwDKqAEBUHAgOCRkzt3pc4Vfn/0fXe/3+pd3ZUaBp/ez1NPn9Ohan97r73Cu961NrBkkC0kz8+rZN4q3xbh0PW0e6oOMVEV5Xao7lDnul5rqjrFGkc9r0zf+UqF7DpoEKQOKz+ThllRhWr0u0UnlBXOPHw+n3cRMb8fi8UQiUTcnKj3yGFR1EU/Yy1GVQaZCpLXgLGm1TITuXGzs7MIh8OOOa3MTWU0axTHQ80Fo5FT4awWqvFFVbqJWrxuI1CrBOg0MBpgzozMSLaVYzs5NcyVdNbRuVpj5TOy/D3fVYS2hab9HB485qnZEMOyIqthzFr+gTU8jCYbGxuXEUZ4WPP5vLtAgNGzbQS/1kMjZZUh/T/hL86TURfzUrrvvrVYSXGvNlR52pttaJDZyYpQOVEQ7qkiML6crI2i6Xjw3OrtYkRZVKnVuv4aHLS2tqKjowMDAwMoFhf731O3aEmgdbqnp6edUVHFC1Qn72qAbD29pqsUdVOZ1TOuhMt69apeaT35VWF2voiykQ9ER9PXZUwjX+uYKMeC68r3UoNcqZyoI6POsLKuw+Ew4vG4a8PLAKqlpQUAXN06qxB4bzOAim1MtaNig0zPiSUfbNBB3F/zNswTMB/FKJm5HELWmuBXAWTbTX2VuoC70ijTZ4RpsKLRaKCWVKNnCgs3U8krJF6wfZzebqV3lOoBKzdKtp6flgFpBMznIHxoGZlK2+dBI3tcb+iiEmF0TMFm5F0uma6cocqIz0cGc3t7uyNecJ56qIvFRYZzW1vbstaDT9RQ48yhhs3nVVtDTCXOaLtSfoEdfA+mZrTN7eTkpEOmFEkpZYh9aBGfwSpTRnRqkPXu5FqIRyonarjIeO/u7gaw2Aikp6cn0EmPiI+WC7K0i0YbWCq/oRHhma+EPGpzptqgR0l6/D2L6tgIWZ1oW/taLaTrG7qP1CVkItORJPeGwRbJjJpG1LSl5b1oDpnON3/fon2lIHDfvO
3crYEn8ZIOGD9bKwYymYyrJIjFYshms4FKCT2zSsKrp76pOkKmF6wPp5AnN4dCSE+cOS3i9MzTAkGWJo0YDbFeNab1d758w0rzV2OscC0FPx6PB3rxFotLtykBCBAutD6wq6sLfX19GBgYQE9Pj4OpNarXOj0lQ6wmdCqYeli09aK+6PGpo2B7PcfjcZf3m52ddWSfUGiJcMTcI40EDQafvdqItBRUpweK6Et3d7e7oJ6HiVEfkZdEIhFo5KKe+Vp7tHwe+xzcK2tEeAbs4falTWodltSl3fBIWOR+2ghdo0jCjCyhY7RBZ5zPrA1btGTHV5pT7Rpb+J7rTEJZKLTYQ7+jo2NZH2k2PUmlUo41zmCBz2QJRYQ2y5F3O0fLebHQdalhgwXlH/hYzPUaNsq0ekojZOp97YNAFILzpy2gPVBuBRs/abrEl4Kz8ytn7j7Hgk47U6P8XaIhJG/NzMwEApbGxkbHyPalleo9as4hc2N0A/SqRD6EhW5sHRqAAFNT600J95a63aRcYwxgmTHWGlx7zSOAgIPBv7fkEc6RDdb7+vpc7aWFmixJopJoweZ3bDkFDTKJXLaemOtJ54aXGkxNTSEUCgVY4YySafQowKUa4Zcz7O/6Ih4+F5Ug4X8qMaY8tCTK1wz+iRgr5XjVydL9VXm3BlAjZN/7VTtHC4UDCBgfn6LRHLT2E+c+aS6Z+6A5O9u5qdr0ks8Q63rRwHEOJOhEo1HMzs4ua+DA1BnROq1QaGpqcgac8yZ7vhLnU+fqcyCofxQip5HQtVHjZHPF9YZPS6XvbI2tRrvUB4x4Z2dnHVRMlEEjYiXRso0v11Vtiy9/XM6z2t/nmlEuLWGRv1soFFyUbwMnzumJgKuBKgyyRmktLS2BekJN3AOLNXz0yIGlEgm+F+EfboIaOEZ1esemjY4rLYa3MBdrgZUmr+3S1CtUxp2WCqnTwDnT6CmsbiN5hTArETbNL9Gh0Po9PhPhf8058fcVtuZhmJ2dDSguQjW2zMqWVFUKq/qUFNfV5yzxlcvllrFkATjvfDVySb2HL7K1SnclhWKhaq0ztu9VrQK2Z5U8BwCu5tYnjwDccylDW6MPvj/nrA6hyqI2eqiUfOQzwjY/q0aZukhhZcvu1RJAXX/N8VqSYaWOnoVQLXlSYWj9jFIOHNfiiRrWKKtR4lw06mVlCmuUtRUpnR+2Giajn8abekYjc5+uLGeoI8N5Up+U0leaNtCgSYemnfQM1DNdwFFVHbIaNSp0nRRzUqwzTqVSgbwHf4eLDyw20EgkEuju7kZfX5+7V1MNHvOy2i+1mhysktI4F+tN8f0pVIzm6X0zv828sd50okSqUsaYcyp3zflVcyJ6uLXHLRUTPTwOjY7494VCYRl7nR47ECy216+VKqlSkCiFXQ1xKQappjyY82N0U82cqh2qyNUwqKfv8/b1b33vQbRJSVLWuJeroCgrSniamZlBJBJBe3t7ANmycsm5EYnIZDKBSJq/y+duaGgIOHmWQOg7p6s9Q6loWPffdyuQIgz8no2+9LP5e5SvUk5mOY6nNcR6RrXnOomdTA3xM7Wneak1qJVbsNpQw6O6xp5FYMkxJkE3Eok4Q8v1XFhYvPmMDX9omBWFY1BmyamVOnD6DPbfjNiB5eWhfF77Wb4UEvWUrlE9o+eqImROikrjkw9+Eu+75324fMfleP2O17scMJt/xGKxQGOKxsbGZY0TmPfp6+vDxo0b0dfXh66urgBUrbko9WbKOeSc+5H/diR2p3cv+/n5m87HawZfEzg0hHRJClGDzPx2T0+Pu3qMTsNKF4avNs+V5h8Oh7E3vRfX3notvr/z+5hZmMFgfBBv3P5GF/2wzWUul3OeKiMdZXly7zg/NYB2fiqYNv9VrmLIF/L43z/93/jSQ1/C8NQwBloHcP7B5+PigYuXKQC7rz7jxYOsBrHSOVUyMnMZvOMn78DNv78ZI1MjOL7/eHzweR/EMV3HuHXVw22dCs
qfz8Doc6lBAYLIgVUYK8kRzyjbV+ZyOdw0dBOum74OL+h4AV7e/fLA+1sjxfmwzSadU50X5SgSibg6WTXKPKfVKthP3fsp/Ou9/4pdqV0AgMO6DsMbTngDTu07NVCaokxYy5KuBG5WB8U6KpU4QuFwGHcfuBsfuecjeGD4AQxPD+PDT/8wjkochfb2dnR2drqe3xyaIuPnWPmwjme95Vz16L/+6l/xkbs/guGpYRzWcRheu/W1aG4Ion3FYtHdgDc+Po5cLhdoEUynjqxl3j2cSqXccyqLXdNvtsyoXJm5fdft+NCdH8J9++/DgewB3HzxzXjx4S92DoIOi2jp93Xd7VlUJ0X/rh5GueoImeP+4fvxxd99EUd1H4Xm5mb09PQ4shM9J5bNaARBBh9ZjNFoFJ2dnejr68OGDRswMDCArq6uADvZwr8+r22lEQ6HcdcVd2FuYc4p9YeGH8Il/3kJXrzjxeiN9wY8/FAo5Fo7ZjIZR1igQdbomMxq9uxVoaoUeim17pMzk3ju55+L52x+Dm5+yc1IRBJ4eP/DiM3HUJgtOIPM3B/Xmp6s3vWspR7l5Ct9MHO5o1gs4gM//wA+fd+n8dmzP4tDOw7FPXvuwetvfT1CcyGc2nBqADGxETHhQ4URmVqwDWJ8XnA9xqtveTUeHnkYn3/x5zEQG8AXfv0FnPMf5+COS+5AR6QjUEevhBjLF7DOkX0uwmuAv3lEuQqK60nC1sOTD+P7Y9/HoW2Hor29HYcccoj3/biGlBn2rmaJkHbw4jzZqUxJjIyQbfOKcubOOWyKb8J7nvMebElswcLCAr70my/hiu9fgZv+8ib0h/tdL3pLCtVqD2U5K4GIw66vL39aLtdD92u2MIvj+o/DpUdeisu+e9kiN6Yl4VJyU1NTASNMxW8JRNYoK4S9VlHyjb+9Edfceg0+ceYncEzXMfj4Lz6Ov//13+P/bPo/AbkGFhGqTCaz+Mz/LzqmIeVzUY7YJ53lRYSsGRgoL8Zn8MoZU/NTOK7/OLzq+FfhghsvCPzMt39WHm16xOb/fXanmii+1KjKIAOLB35qYQqv+s9X4ZMv/CT+8ef/iMbGRrS1tQU6zLBchZ2wCFXwardiseggL63bJWSt3ZnsAamE4MDfGWgbCERbn7j/E9jWvg1nbD/DQXo8zLOzswGGNIkIfD5b2mE7EtXLGHN84OcfwGDbIP79vH93JQj9Tf0YHR3F3qm9LkomSYGKU2E+veGKkDaNtL2BRYcVxErHnXvvxLk7zsULtr8ACwsL6Nvehxt/dyN+m/otntXzrMCeAsHcH7BEvrFNLEoZqnodEACYWZjB13/7dXzrkm/htC2nIZ/P4+3Peju+8+h3cN1D1+HKI690pBBgKSXgy4fRSOjNZrYbnD7bShHySoN/39jYiOx8Flf/7Gp89IyP4qP3fhSxWAwbN270rpXC1ZSR+fl5pFIpd541XaWwuK+tYC3Q49mHnh0oHXrTCW/C537zOfxy6Jc4vf30ZfCn9gZgKoywu6+JjK6xbduojXQqIaPxPc/ecTZesP0FyOVyuOy7lznCGXUGa1yViZzP512OWw2SRvdraYg5/4/d/TG8+oRX4xXHvQKzs7P4x2f9I27dfSvunLkTxzQd49aYcybZj/W9NNZ0qDXPrBcHAUt6RZ3YWnTnWYeehbMOPavqZ1/J6bXnZC1GVa0zObHX/9frcfahZ+P5hzwfH7jzAy561GLy+fl5b9kPoS56Rza/wrIJhbsthFGp8lVnIhQKIVfM4WuPfA2vO/F17mo5KgB79yUFhYfe5oRWalZSL+Nwyx9uwZnbz8Ql37gEt++6HRviG3D54ZfjrP6zHCs5Ho+7yFcJWOy+RYIFvdNwOOxuXGLdJqM9HgrL5q5USQHAKQedgs/86jN4dOJRbE1sxcNjD+PekXvx+h2vXxadaakO157zp5NBZ8KWqdRzvTlyhRzyxTxaGlrc94rFIloiLbjnwD14zaGvCTDPge
VQsxJNmFfT6zu1tlehMR/rtBLoNBQK4U0/eRNeeMgLcfbhZ+Pj938czc3N6OjoWPY+uv5Eg+bn5713S6sh065cSrj0RZiV7g3ntJBbwDcf+yamc9M4In6Ea1rDnvPUJ0r2BJZga70lTS+vUQ6L9sj3dcaq1CjraGhc4p60tbW5FBijYjpjNGrUJcr25XvXW751zOfncd+B+/DmU9/s9q2xoRHP6H0GHks/hr+I/gWi0aibJx1Mzh1YyuMDQVa29q9WEqB13tbyLJcaim6og9DU1BRAEm36wOaYax1VGWQA+I/f/Ad+NfQr/OLVvwgcUi2gXlhYCHh7utB6UHmoLRnECmQlObRSQ434LY/eguRsEpcfe3ngvX0wBfNxJITZ0g5tfLBWArVzcif+9d5/xVXPvApvPuXNuGfPPfj7W/8ehb8o4FmJZ7l6XDo5jMKUlMaDTUeJnq1CSoS7W1paAoiAvV2mkmj5zae+GanZFI77zHGIhCPIF/J404lvwlkDZ2Hfvn1Livf/wY9cOzZ8J2TKRgV0/OgMVdJkpdKRaE7g5INOxntufw8O6z4MPS09+PLDX8Yvh37p4FQqHNtq1Ze3YtTGSx70Ig0f3F2tPIVCIXz14a/igaEHcPf/uBtNkSZ3RqPRaGBO/KrKR/P4Wv5HeFXJkepAK5lR96MaY/zQyEM448tnYDY3i9aGVnz0GR/FQc0HYTQ9GuioxIjXpj3oZBM6JTpB5EW7wmmnPXuTmPYMKHft1TA3Niwx3Ts6OpwzbKFyRptcTy2bXGvHEwDGpseQL+YxEB8IzKu3tRePJh9FLBFz/SfYYZHyrA4d90/5HwAcr4VEQ8qMMvGfKENsgzueDQ24qDuJNLKG2pZ/1ssoV2WQ96T24O/+6+/ww8t+iGjj0sFGaIlyrhGWXWAVRK0F1vt6bRvIegohocPrHrgOz9/+fGyIb3Aeq21zR/KTZUyqJ21z3Gt1YArFAp628Wl4/xnvR6FQwLG9x+KhkYfw1ce+ihc8+wXo6OhwMCLnYUs+2Nd4fHzcNQahkmKjhGKxGNgH2zGNefJKGO5f++3X8JWHv4LPv+jz2NG5A/ftuw9vvf2taM234ngc7wwVDzgNHDvp8EDkcjlnCHgHcVtbmzMCa7X2Xzj/C3jVt1+FwY8NIhKK4Pj+43H+Iefj/uH7A7eH2S5QjH6VDMeogYgF84lUBqqAS0WZ5TzfntQevOH7b8APL/shWpta3V7QydShbF+drzYG0Wsl+R5KcNSuXNy3WvdjR9cO/PSSn2IsM4Zv/uGbePt9b8f/Per/onmm2TX7yGQyziDwGXXwWRQt0qYoTU1Nrg+28kI0FVWNw6cGmWvF+5b5fzW40WjUlYi2tLQEPr9UPn4thxopfmZ7eztmZ2cBLJW1qrNjb7IDlvgM5BpYDo6e3ycrOuYcmX5h74Z8Pu8uXSkUCk5/RqPRQGOrerHfqzLI9x24DyNTIzjx0ye67+WLedy+63b8yy/+BdPXTi9THHrguQi2OYfmb+zG1HuDdqd249Y/3YobX3Kjm5u9qssSdbStZyljXCnRrJKxIbEBR/YeGfDojug5At9+9NuIxWIBcgtLl6j0+Vyaf1UCjzKFGxoaXIRgDTLbcNpbl1Yb1/zoGrz51DfjkqMvQS6Xw472Hdg5vhM3PHIDPnbIxwJQNaEtYKkkjc6Zth6lEtXWjOWy7isd27u247ZX3obMbAYT0xPoburGy29+OQZjg85gEfrnJQ4kHalzpwdXc/X2LPgIjJXK1Irn9Jf/grm3zyEcWoLS+VUNsZ4HbT9JWVeDzIin0j7tvsH5NEWacHDHwdjQvAHbW7fj/pH78fV9X8dLml7i0i+8cF5bU+p76KAjQaeVZ1i5K6yaUF5ItQaZg0hTsbhYG6uXRqhBZm5Zc846B5/Rqqec97T2IBKKYHhqGMCSMZ2cn0Rfa59rl0wHkcgV642tgWIww/4R7MCn6+1zqJ/oQbngvvAqWpLOmOog0t
jS0uICGFt6x+euZlRlkM/YdgYe+l8PBb53xbeuwOE9h+OaU65ZhCQRLEPRlxIqFKIrBfmuhbd0/QPXoy/Wh7O2n4VCbsmDVlamduhSOFEbVtjcaiWwVqXj1MFT8cj4IwCWvLo/Jv+IzW2bHXRIJcSDwny9tpv0Ndvni3Amcyda5kVjXM3dttML04E9DYfDiIQjKCJ4ebxGY4R+SRyKx+OBqJ0RBIl/tlxqLUasKYbmcDOGUkP46d6f4upjrnaRPaPkZDKJ0dFRjI+PBy44YJ6QERKVNPP/fM7VcmrlPt9K5/TNp74ZkXDwXmU1xlqKpUaZ8qPpJnIXtGnPWsCPxWIR+WIec/mlO8a1NaYlyalTp6knrjnPLudPqFrvHq41MuXvcq2ApRIfvrdG0VNTUy5No5cf+Moo10LOmyJNOGnjSfjxn36MFx32osXPCAF3Dt2Jl21/mTNWPKvZbBYAnH7RyzxI2NXnpPPDyhQibvVCVKoZCs3b5knUQ9pitVgsOufDEk2flAg50ZzA0X1HB74Xa4yhO9qNo/uO9tap+sJ5HupS3VnWalMKxQKuf+B6XH7s5WgIN2CuOBeA6LS21RfRa+cdH/t7reZ+1TOvwinXnYL33/F+XHjkhbhn7z247tfX4eN/9XGnXLTXs5JbaDR49zOVl7KqtVELo2lfjr+SlqUc5+44F++/4/0YbBvEYZ2H4d699+IzD38G52461/2OdociM5YkLkbkCu0q5GcNQL3X//uPfR9FFHFo56F4ZOwRvOXWt2B7x3act/k8ZJIZt76EtNLp9LJGCGokaIwTiYQj8VDGlNRVi2O62jn1DZv304YwtimMpnJ8ZU4612r24+0/fTvO2HwG+lv6MZYZw9d+9zX8avxXeOf2dyKfXKoc4AUSSkr0tWSk3ACLcCsQ7J/va/lZbZex7HwWj0085v6/K7ULD489jPamdmyMbXSOFlNGjCxJxmQUzZdPR67VuPqZV+MV33wFTtxwIk4aOAkfu/tjmM5N46JDL0LDXINzzLg2oVAoIP90imiQSfIKhUIBB8jHSan1uey6/2nyT3hg6AF0RbuwuX3zst9XNE1ha0Wq+AzkuGgppi27rHVUTeqqdthJa8TkI26txfjRzh9hd2o3rjj+isC8bO7MQuyWhKFKc62NMQA8fdPTcfPFN+PaW6/Fu297N7Z1bMOH//LDeNlRL8Ps7KzXSdCUgUY7jJa1dR0VFg8Tv6+QcbWO0yfO+gTe8ZN34MrvXYmRqRFsiG3ApYdfipcf9HLs37PfrbPm+6hYud7aSMbX1WstZSc1l8K1t16Lvem96GrpwrmHnIurj78auWwuYMRUMbHlJI2Ezp/RUnNzcyDXBmBZFLTWDiqHjZQVWqfiKZUftA1m6jHn0alR/M/v/U8MZYeQaEpgR/sOfOxpH8OW3BbsmtgVyAdzfcnOV4Os1QJce18TFz07tZZt3bv/Xpz+udPd/9/4wzcCAC4/9nJ85uzPuHVmRYe20uRn2eYytRDkKhkXH30xRqdH8c7b3omh7BCO7TsWN557I/pj/UjlU4GzpyVOtoRL0TpC9YxCLRJUr0DMrvvVP7gaAPCK416BG158g/dvrFHWUiy+mOKjQ22Jm08oy5oflk6nS/7Oty/4tvsdRjlUSAon8dCo0GkDcuY8yQJlzVo5bF7OTxfHN/dn9j4TyTckUSwWXdmEQl/0tBmlKVmBAscSikwmg8bGRkc2qtYolDv30wZOw89f/vOAskyn04EbtBgB2/nry8LW9NZ9RDBG1YwuCKlRSWQymbLm/u5T3o13PvOdjq2YyWQwMjLi1lrnS0+bETIjB42eKS/T09POwOncytmHctf9BYMvwPNf8fwAS5ryojJj193XmIKHm7JvG4XwmbLZrMs7Krmt0rnr0HOqf2NTNpRte3Y1SrZwNv+OcyUJxsfGX23uhUIBH3r2h7DwzAW31ywVGx4eDqRfOC89p9ofmqiKQvC63ipzhF2npqZQLB
YDvASuvZ27b81P7DoRqatSgWfmmVUiVKmzynlrqk8dCGBJ1ivROeXMHQAuP/xyXHbYZQECIuVc56uwrb0XgO9t+4Nr+mN6ejpQM77S2a123e3f66AOJbJFB1p5RLYhEXUR94/kwoaGBlcFUo68e0exjLFnz54igKfMa8+ePetzX5/7+tyfAq8/h7k/1ea9Pvcnf+6lRqi4qslehBH379/v7hv97zqKxSIymQw2btzoPJT1ua/9WJ/7kzPW5/7kDDv3p8q8gfW5P1nDJ+++UZZBXh/rY32sj/WxPtbH2o7KmxKvj/WxPtbH+lgf66PuoyxS11MFGvhzgsGA9bk/EWN97k/O+HOa+1Nl3sD63J+sUS5kvU7q+m/2Wp/7+tzX5/7UeP05kIvW5/7kzL3UKCtCTiQSAIA9e/agra3Nfb9o6i9ZksJSkMnJSQwNDWHnzp147LHHsHPnToyMjCCTySy7uq29vR39/f3YsmULDj74YAwODqKnpwft7e2uC5DeIuOrV0un0xgcHHTzXWnunL+dO3sLp9NpjI+PY9++fThw4ABGRkbcfZ/9/f3Ytm0bDjnkEGzatAkdHR1obW0N3MpS6Vhp7rt37w7cc8yyDc5zZGQEe/fuxa5du7Bv3z53CTibJbB7ju+OTyBYC84L7QcGBnDIIYfgyCOPxBFHHIHBwUF0dnaitbV1Ga2/mnW3/ZwzmYy7d3piYgJDQ0PYv38/9u/fH7hHlSVSLGmh/PCe2Y6ODvT19WHz5s3Ytm0btmzZgr6+PiQSCVfkb0tYKp071047nLEm0Xa3YrnI2NgYhoaGMDQ0hNHRUSSTSSSTSUxMTAQuSGB/7v7+fvT19aG3txfd3d3o6+vDwMAA+vv73Z3jc3Nz2LZtW9lz1/nb9pi893tkZMTNk41NMpmMkyPWyfLM6l3gnZ2d6O7udlen+lqs1rru2lubZTOc9549e7Bz507s2rULw8PD7jYtNt0oFhcvxEgkEujt7UV/fz/6+/vdnPv7+7Fhwwb09fU5WfeVFdm5V6NjWC4zMTGBkZERHDhwAMPDw0gmk8hms4HWvZTztrY2N2/OkVfWdnR0LGsh69NFpea+e/duxOPxZaVv2WzWlZrt2bPH6ZiJiQnXhY76pVhcukqX57GzsxNdXV3o7+/H4OAgtmzZgo0bN664vqVGJevuG9wHllyOjY1heHgYIyMjGBsbc/8/cOAADhw4gMnJyWV3gLODWltbG/r7+7Fx40Zs2rQpIEc8s52dna4laDabxZYtWwLy7htlGWQuFlvLAUsHRA8Je5vaYnH+jvY3LTX0fW07M70VZKVN9HUI0rnrZ3Hu7J88MzMTmLu2tWPDALa00/tNtRdrLaPU3Hmtot4lyjpEGkjb3pN1r6yPUwNCo6Lrzefkix2MtLWgzyBXuu40BryujbepAEsXsmtHK9u8QZsOqKLn9+hccP7sobvShRirzV3Xy9fVTX/H1ujaumStayzlGGkDBTaO4HPwJiA2+S933bn2aiCKxaUbtlQWtHNRc3OzO9f8yvnQyaEB4Lz1TvSVLg2o5KzyvM7Pzwe60C0sLLi52Ov8tK+7Ppd+VR3DHumJRMIZt1K6xjbpqFTHUM4p874r/VQmtDmFdkwLh8OuL7e9+KXU8M09Fos5g0VIWLv+6RW0fFHHsI8B9YavUxvbaGqP7mp6N5Sz7nZQfvT2ODqLuq4q//w7X4MeK+eUGV5KwdarNlBb7Tmr6tRljbF6f5lMxnn+ExMTrsk+vWt7841ejUbvMZPJIJ1Oo7m5eZlyYvvEenQuUu/V1/IwmUy6xiF689MT1TlJ52kPtt4UpI3PeUjZKJ2HRvdMDzWjbjZS0ENc7y40+ix23ZPJJMbHxzExMeHWnR2X+DzaEY09ujlP29hBW/hZJaeGu5I5qzHWjlw0uLZZCfeHkdDExATGx8eRTCZd0w3+rRr1Uo6s7kM1e8L30KYU2txEL8Xg3PQqzpaWxfuguf7seMU9ZN9zoh
jRaHQZIlGLw+qTX+vAqbKkIqRxsLKnaMZq12bWMmd1QFXHECUZHx93KBAbsdi7tfWspFIpF/zQGNLQRaNRr2Epd23V0dQmPNppTh0bdj/j9+nEAXDIi7bq9d0G9UQNu/9EWNjYg41P9JY23xzVKGtXN9vzvJruY1W3zrSwlyrWsbExjI6Ouq8TExMOhrGXO7MlIqHuTCaDyclJ19uXm08PlpA151CtUbSbw5t6xsfHHaSYSqWQTCYxPT3tFCYF0Xr7a22cda0Jg9Jx0faMhFYikQhaW1tdhKnzU2WmBn56etqtqRqFeg+77qlUCqOjoxgZGQnApNls1kVvNAI8JIzu+DzsCBWJRBCPx5c1f6cnX60zV0pZ0eDSMeL/qVz5f3XwuF+2p7gaDTXGFtmoxRhz/pT3yclJNy8aZO7BzMyMc9IYqau3z/nw9wAELptgpMb1LpVqqvY5bI9tINjKk8qSTgU/Vw2kdunSbmr1cEZVx/DcUseMjIy4tabOmZycDFyBCiCg3Am1Uv4JFTOF0NbWhoWFhUA71krna+F1fdl1JnqgzjHfB1hELqamphCJRJBIJFzgUO/+z+U8l9X3anMYQOoNbTyT/BsOi2LZvtfV9j7nqNgg+6Jj21Sf+TLi8sw32M0A4HoVRyIRZ2iSyaS7yFrzhIRty4G+y30ONXL2sDB3RmG0l8Y/UZGyhevUs9NWb4zimQvW6/yoDIElhURFxENDw6afaSHteg0qKaIqk5OTziAzOiPEp4deo1uF/ABgenra3Zij7R6tV15J9GCjYnuzE40anQh7/SKNshps2yJUUQB9pnoZY11zaxx4RlOpFNLpdCBnzLPJ27UYgXHtbfRTLBYdB2F6ejoA+eqrXs+hjj3fTy8HaG5udpEd0zaqtxh1shWkRsn1jpAZkWUyGYyPj2N4eBhjY2OYnJx0L3WCqBMBuAiUTqgiioVCAdFoFJ2dnU5PVeNIW4Nl29Tq5SiayiDky32lfGk03NDQEOCyPFERMt/foiLa9lLPKx1T3rilKAWwhO5oCs0aZIX2q7ENVUPWeij0AalcCdERnlPYF0DAOFgvkoqODxuPx93f1+OwaGSlAkQ4iN6Sz1vVBvRq5NZ6WMNg85E8KPTaQqGQW7tYLObyOjzQegFCKBQKXDH2RDgYuvZq3DKZDKanp52s0PGxqASRFZvH1b7K9jYWqwRWkyOf82nJWjzMhBxpoCn7mlJQ5aaQqeYCV5pDtXJvn8MiWnpOaWC5/uRMxGIxR1IJh8POGOhehUIhxGKxZb2ZeVbqoYTt36s+UY5HNBp1MkD0RHPllD3tUexz3GodagxUx4yPjwegar07m3Kv9zZT/mmYifS0tLQsu+FKDVE1SBB1hEUf6JxxjYlU8jzSHoRCITcfAM651l7oa2mM9fmtk2FTSkSv+NWSAFW+6Igox0adP99VnfzbckfVEbLCd5pLU8as5kTm5+fd5Hg4bYSpxpGLxnyUFbhalZONeviZmUzGvWZnZ93h4ILzmrYn4m5SOxQet+QfElJyuZxr7k9UgWQ4PdCEvHh49ErFtbw5yRpEPTDKLbBXWiq83tjYiJmZGXc5g76fRk31UK5qjDVHTAdC+RKUeeYGme7Q+51tyqbSudUS+fgcaJ5XRsXKlo1EIo7kxDuCW1tbASzuz9zcnFPKpS6b4H7Zm5OqHaoz9GYeJTXF43GnUJVfoMbFcgHW4vYeNQwqv1xv64DSgBEOpiOklSWaV/aR5OxaVTo0P2ojQRIkFxYWlhHliJpQT1M++Nya719rY6zOt9ooIouMivnifuiFJdRHwBJxVNMhjIb5Usi6ltvnKjLIPgiA0IYaUc2raZRJ+FkVrxoVZU3qwa73vZMWwvDNn4eEykSp/MrArFXBlDtUICgAFv4Jh8POKJNgE4vFHJuRkbDmixcWFpYxKHnYfczYWtfe51QotK55b36fuVXukcrV1NRUIGKzB6AaYgWHOp5KgqIRU74E0zKa1+fhto
RAX7pDoTG7XrUSCe15VadX4XTmWtXrj8ViaGtrc6VNJA2Fw2F3yxYdbD6Hnil+5TmqNUpWQ6yGQnVRoVAI3O1NZazOm94IZVMEqudqHdbp1HItTYcBcDeWaXUDS8f4LMqDaGlpQXt7u3O4K72jXEcpJ7+1tTWQNuHnKmKpKFU4HHZIJ9NgT4R+BILImw22NMWXTCZdqRORIWuM1XHgmdWImMEZq230bna1C2tO6rLRMfMM1hhTcSmDtKmpaZlBtYaGRpmeld6vWethtpGCL7rnxrE0gULZ0dHhai67urrQ1ta2rMZyrYYlEUSjUVd+QsVJh4E1uvwZLwDn4ZidnXXQtV5VqE6RQjM+Y1xJDrbU86iQR6NRJBIJdHZ2urmxpMDOnUaPMtbU1OSuf5yfnw9E1ZaZXSnzUaHN2dlZFxHzEE9OTjqDTM4BIV8fQxxAIBdO2WZ0oZ64QpZ27uWO1dAsi0IVi0W0tLS4ErdEIuHKOVh619DQ4GBJ5onp3PHsKuKk9wvT8avWqVa5oXwACKwPHdF4PO6cO5KL6LTp3vpKjnR+awFb2ytSC4VCgLTK89De3u7KkRobGwEs8T+ARVlqa2tDd3c32traAqWX1ZxNq2MKhaVeEVoSprwBzcGzDI93PM/NzTlEVMvi1gp50z1Vx9OmliYnJzE6Oor9+/e7ng1EROlAU05U3rgGsVgMra2tAQSS3yNsXa1zVDVkrdAXDzc9bns/rOY3S9WAas0ggIASqQXiKzV3G91ToPQeZC5qLBZDd3c3ent7nUGm52oFzObI6zHUaSGDUmtBE4lEwMNTIhSVYCgUwsLCArLZLAqFxVpINQDqGatBVhJOrblMHfo8sVgMHR0dyOVyiMViAOC+z7rnfD7vmkAw76aRA42EQqP2tRrEZ4caZKYztJyPBpmeNlnhCo0quUM/Wwl61hO3e1ItidCmZqy8q0Fmrre5uRldXV2u8QSNsiJCRLympqZcrTFlTqFLNtLRqKLaM8xntjwJy3RlVEkDBcDlyrPZbCBK5hr5CH+1Dn0vNRIawDA65jMw2mKKoKurC52dnYjH4w4pUtiXaSlFL6o1BNZJJuJGp55zUuNLzhChd85vbm7O7Xs+nw+kLCo9g5Wstw0UlScxOjqK4eFhh2ZNTEy4yg425qHeV1lQvciomDLGs8EXz4JFK9Ysh0xB9UWZCjHrZfNkalIB+5pn2JwFjbIKtJ1DpcPC7bpxllFIw0ZHobm52Xms7BxGb5RwqnpUPoejlmEPC+dFg6zwINEEXVPmjgkzacRm524Ps4V/CEv5frfa52HXJ8LSLOOg98nL5bPZrMuFq2FjBMsDZI1yNdGxrg0NjBK5NFdMD5tRMYDAZ6szwGifKAYjZRshqzH25fbLie7tMyhkTZlRB4KpkLa2NtdxiM1gqGSKxaIjfUWjUYdikDDFPKc6683NzcvSTtUMrg2HImuW40E2+PT0NEZHR53zxvXn2ljUp55jNURO860kYFLPsOtZd3c3EomEO/d0ICjr3C/NNVcDEdu1VR2jpDfVkzyPTIVR/1PPs+TMxyGop0G2sq7z1K6Lw8PDGBoacmQ6OtaMjpnz5p7peeQzqIzRAPMrjbF1PioZVUXI/KqHXb1AhQwIe6iiZNSgh0ChOh/caHOP1Qz1WjXvoWxLzYsQltWOSRpx0lDxWVZT/LUIoRpAHhY+A71RLTXQekFL6CGpiy9FKq3pFAABAABJREFUIRSKVmPHF50lFbpqn0U9z3g87owTPf9EIuEcn4WFBWeYFe7T8i4aNXXsqjHEOnywLx05zUFSyfMzuSdqTDVfTAM2Pz8fqK9Ww631tL4UQjlG2Td3NY6650xzxGIxZxTa29sDHY34+2oAaZCZuwUQkLN6pp3UcChMbV+hUAgzMzOB7n4WzbJrWS8j4TPEdu25PlwL7WLFaIu5Sa410SrljCi7t1Y4WNdWnUfqS5umjEQimJ+fRzabDegCPYOco7KPy5XfSkapVC
qjeFbPjI2NueY8trUq/9aXUlXEUV9K7qLRrgaN46jYIFv4zRpP3wTsQ1GwuJC6APrgttC6HvkHn4LiS0kdFsLSfBO9XCoo26aPguiLJKudP9eba6UbzgPDA6E5MUUBlBnMmjvmY8mEZ6MBJf3wbzS6IxzIqKmS51ADQIMci8Wcp62EtGg06mBSTSPo+3HNKSfqpdZqlDWa12gsGo1idnYWsVjMwaD8yohHOQaaRyTKwNp2OzeNMPTAV8Pqt8bBRqq6D1oyxByZjY45Z1vywfdVdEud3XrCwjxTnL86kfxca6B1jS1yor9Ta/Sm+kINr75s+SYdOJUZ5tqpa/iMajA1qNHPZXRXTfpM9ZWmqxgAcJ34PaI86mwwSCH0TQayj2lfD6Ns11yNsXbiUmSLfA+mDSxp2DpuPvRK9b2VqWodvIoMshoVVYIKSdjWYdbj5/d4QJWBSmXDEgafUqjlwPgiBqswONS4KWVeDwpzZHx+9ZR8eUy7jtUMCgUPHQ+2EmkUeuf/mS+0Blmbi5BpG4lEnGc5MTGBtrY2dyC1vtA6CZU8g0JBzEHyM2hUGR0AcJ9po3xFYbj2qgBqgcg4RxpiMo5peJQdzmYCVK6EctWQKdeiWCw650cPr+ZI1UBWA4X55N2y8vkZdLI4b41qNM1ERcyzzvlREXJultVdK1zt25vVntm++HfqhGj3P190U4nc+IyxPYeWxUsHjDLGPc7n865JCBv3qNKnvqGR5D76qlhWWq9Sa2sRTIX51SFWvaMGDlgqXSSkS2ejUqeynKHrrsEE9ZuSQUk+1jSfokUrrYnaKZ+hrhaJ46g6QrYRg23sroxK2+GKQ40hD4iSldg0nA3Ta9lMjcg1elSvToVYDTJbrDU3N7tcrH1WMhNV8NTTBVB3hqEVQpbjaGmO1pjSyPJnhGu0FzONW0NDg+u5SzKO9rNVz1AdmXKHjZABOGWiOZtIJOJ1Ovg82s3L1gRqRFnNsE5DW1sbisXFZhma72P+mCiFOhOcvzpDoVDIQfB2jnqmiBBYskg5Z8Dm1RQJUlKQlraQRGedSp4HTcsQBdD6d8qCojb8vHqWLerz6TP6UjI2GtV1tpF+Lbk/nQuf1/ZT0H7OhEZ1Pfn5AFyO3yIZhLbpCMXj8UAQY1nWtRgHRR34VddaI1Ht8MZ0jKajyLuxJUH1GHb/tUmVNcSWIU7ZWE0mrdEt5bjV+kxVGWSLqTNKoIemxooHXzfALiBrYVWhkcXGW0HUIFeTt9TPtd67Rsi6wBrZsHk+/2+jYio2vh+Ngh4ojYLqCb3Pz887AWRuhNCMdoxS5rt67NwHzWfRCSGbWSEwZUBX0zdXvUtGkFQqCuNxn5WoQWIVn8d65NwTzSXqWvsgqZXmSachGo06BUoloxdIUHny5zSgodBi45V0Ou1yxWSh6rnQr3wONcoayVUTIXMNlQHOSJwOEcs2lADlUzgK2SlkDSAQgavjq2mgaoc1wnxpKslWTfh4FRbhsw6IPm8l8lIKOtXzpq1S1UlQZU/nhX+rTrAaYt7OpN+3MlIvY+HT2baWnY4pgEBqh/nwWkuCypmfrrsiExoN+9afMrGSjOp62hd/rl+rGRVD1jp5jZJ1A3ioW1pa3O/T8wMQ8NgVLqAy0vIFEhzs1XnVPLQPUtKX3Qz+Hg0yIWLbfUgbddj3pSK3c642h2JzghRA28qRt8jwBiWtCydL1hc9qJEoFoPNOFpaWty1f/VgzKpyZ5St70kl6mtioQZZHSlNFXBtFbasJLemsm5LQVTR+JQmHTIArkkCa5lLRTEWSuVZ4lefsl1tqCLVr8ASX4MGWT+jVBTu0wE8Cxpt+KDjavPH+jc28reRsc8A8lyqcVX5Ux6M/Uz9vErW20LW6hxYvoquFfWQ1s+ydFQrPmhcALjGRfbqRR+HpZrhQyBsKkyjf66lMscVfbE55HpGyjpXLZfVs9XU1BRARDUA86U3rJyUIhzXY1TVy9rC1hopq6JiCYsKs7
I7Vflq7kxb4DGH7POsKlkIzYXYci31nKhUuCn0uHlxActDNKFPBaqF8pw3I1AOduPR+Zf7HCsZY42Qk8mkM8osySFxiwqCDhE/XyN+RvvaVEQj1noIoO/waCSjSo03QpV6Hs4pFAoF1scyixkxA+VdA+iL5DWC1fUiwkAYmIgOc67kG6ijYI2MHnzr4FaTEy9lGH0wqMpAOXnU1aIF/fxSc6tk/vZZbATOc2DTNqUuNLCf78s5W0eOz73anK3x0jInReMUraPe4LrPzs66HtfsHRAOL3Xi0/OrPeuZJtH9q9Zo+CBqheDpHGt7Y+a0iSrZwKoeXKBSw55XdWxZtUGeQ3Nzc+ACGs7dMq1pl2wQWiubutSo2iBbhqIaUraYZG5YI1CFFaynrkXo7BTEOsh6EAJW815t03/CRzwswFKZCudBw9zU1IRsNhuoUWP5CA8Oh+bUuZ7lzt3O3xpj1tfRcGm3NC1tUi9Q906JdFx/bRmqZST1OOhcX+6D1slqcX86nXZXevJmHOarFIFpamoK/C29eNYvc5QrQ5R1GlaN6vVw8hArIYuRP6MG7rdVxtYgK+NZX9VGyNYoa6So0WElhKaVomBd35UMeiXztkbMOlvMZfKSADVmNMoMBkiKVPiVMqdlSYSDy11vC6NTz6nSVx1DXgYRuMbGRle9QTSF8+dcNMVEvcJafgYuNrKrJWdr88U8i1xfwtRK9mSVRCKRQEdHBzo6OgItPmvRHaWGnh3aEQtNRyKLvdn1WlRF2kqRRTWy1sofDczqZZiryiHrAijERiJWV1eXa6PG1ob0plQg8/ml7koUKhrj9vZ2dHR0uJZ9tbapLJXf8SX6KeyhUMgZPUZIVK66DlY5E2ZkWzu+l86lsXGxFR7LBsqZOw+6OhNqjHm3Kvsq6x3J/CzOk+9HIWXdaUdHR6D7DCEnfqVhrnY/fMaYpBBGNAo7aunVxMQERkZG3H3V2gVOyzL0Agjur1VUClOtNGxEbeEr7jmNkTqpdOh4YG2+09YC83O0bIqK1uYHV4vwbb7VGjMbAaiTsRIRzmcofQxuG52tZqBLfZZFOmx1BBWoygkNBh1TKly+F/fFOm76fyplGwCUs+6+CNl3wQjL5Pj+bDkJIBCNKj+B81JSkhK9LAqk1S3lyLtvr3lO2QdaU2KE1FmaRXQnHo+js7MTPT096Orqcrq8tbW15uoHOzQqZjUEHWieoXg8jo6OjsDNTnoJjKbClLtiSYxaFmidi3o8S1URsi6C5pGZY+zs7HRQImFe5vqUqau4vj6oRma8yKEShmmpYeEXjcw0x8MXAHcItE+rMsWtgHNNaOQI29suOqqwys2rWahaDQ/r7HhYeHOPeq12//j9trY2dHZ2YmBgAL29va4bmW1QoA6TRSzKXX+rsBgJsB2m9V6VMa79aBkda16Xh9AacxpkNTiVssMtBKgGjAaZ68qf81ktlG6dUpvDt1AbEaJqzsBKETI/s5TxXO09rbFUXoHmLyvNtfmgUhvJ2s6AbG2q0RtlSktc6Jxq9Ko6gIaZv6/OymoGTdNiFsnS6Ns6F3xOOv58D85DS6T0mWkENYfc1LR0/7Q2sLDnv5I90CoOXqjCAIDXRnJdCVXTALL/f3t7e6DRTz0jZDXIJFJqlM42pMo9SaVSrg99a2srJiYmnI7g81N3+8jLSly2vRFqGVVD1gACUB6NqiptHh56HRQywgHcGC31sA/sK0eodiPtQVfISw8JDxQVqRJi9FDaPBMPDg1ELrfYRIHRpSVc6EGvZe4rMUsZBdscGD+f+afOzk709vZiYGDA9c/VhuncE0LbNldV7vx9BlnviFWokV/10g8qXUY8ClcT0tOIRJWXNtkolUv0De45Dz7Xj5CWPpvKJ2WIw8qSGjBN36hy0XLCWsu4Vno++542wubPrexrdKyGg86FNciVGGXC/dqukdCiNaJ6ty37VjNto1GmpqLUMefLwsqqD7j3lRhlTRFZh0jXTrtf2T
WwTpsiLKFQCK2trQ4+ZvCi5WjWCSt3WKeIjjHXmM6OBlhaHcDAisbY5pDrNfRc0iFhdMx95L5SVpLJZKA/ONeIzgedN+uAa7+NUmz2WkZNETK/+uBrVSTqXXKDAQQOr3rU+v6VClGlz2AXUj9PPS8l8mj0owQkHjC+j3aKyWazDq5RpVtLGZcvKuNcaaQ4TxpmNq/gPAit04ukR0tvVp0j202o3EjN50xoyoBQGEu0GOXQAGuplhJ1qCSt46SfY3OClglc7rDPaCFv33tZQ+YjVqmy5fsSdbKvanJVPrhY/16dAXUOdM/setkUihLx9DMVRVDHrRx50dQS0zJKStSSFv6bBtk6dMzBElFRI2Ojb21kYh2NctebX9UR8e2bzoVIohKIdP3s79Px5xlSVMhHIKsGqrb7oL0O1NFhhQz5E3oTku/SBV9p2UprWe66q7Pc0NDg5RzMzc0hGo0iHF7sRqjXpfoc33o7v6uNqg0yR6mDrg9S6t86uHBKVmIDjlI080oXS5UE4QYbgShxh1EKiWr08Cw5h4eYyiGXy7m5smxHG7EDCBjPcvLIutb6DCRkESKanZ1FOBz2toXjweFnNTQ0uKslmT/WPLKtf7Vd2SqBnWxezQe1a/5bD73m9+m5ag5U18YedDUqPuNS6fAZZjs0Wvatg8/A2ffUfa4myvQ5zGrcNadIpa1OixpZna8PAteX/TxbJlIumkKDOT097fKW2u5Q0SCNkpWvwqiZBsoqZ4WVrTG2z19OlGnPtnZys42TFEUBlm4ro6FlhMeIj7+jMLY6nRYVUiKn3bdyhnWgWe2gd4LT0SE8TjSQ1xFqCaztZW3hf98Z0K/lDMod19CHylF3z87OejvgWX6GyrmVFcubqEfwWLNB1mHzSqU8Z+bvVCiVVk+4iWUjFs6rJv+gn818N0uz2JOY+Q3WVDY0NLirAXnzjTbaB4KtNakACI/xMM3PzyOTyTiYkwKjTL3VNlPnT4jf5sPz+cW+2pqL1YiS0TuVJIl4PT09AfIFGZHacUwJYZVEayrQtpkAoS82fee9wiR4ce58NqY6+Ay6n762pfZg8Wu9UJfVnrvU93zKUSHJlSLaSuenMmMdKiB40QUZ6j7DrPO0BtnC1RZV8snLSs/D95ybm0M6ncbY2BgOHDiAyclJF/FaEhYNhtbFap2sOnJU1hY90eeuRuFapJApPNUxTFvpmVLiH88lyz6ZFwYW+SuM/Fn1oekrjfZV91bClbD7yvfmMyivg+dSS/14h7M69Nbg+Yyxrq86dfx/uUONvTrm+rmlZFLnomtBWVR5omNYLQpRatTFIFsvRIkMWoCtUIIqTGL37A5FfB9A4BCpMahUSdnIkgl/ChzvdqU3l8/n3e90d3djYGAAXV1diMfjgeQ/FRqZwsxbER5raGjAwsICMpmMWxt+nyVSmtMsNXd72LmeFCgeikQi4ebBEii2a1QSkjLBu7q60NPT46Lj9vb2QO23jdKqyR2r46LlKTTIo6OjAYNse/+qUtFn0VyreuKq8EohN2s5nmioq9QcVGa0ZpoKUqF+W55Do6SKUSNGG20yhaBQq68F52qD0cjs7CwymQxGR0exb98+DA0NBVi9Ci1rlOgjfmmZEAMCdSasA2KdjnKVrdUxmiNOJpOB2n6uCfdCy0dJju3o6HB3hE9NTTldQYdC16sUH6ZcJ9QXUSpczWCJBllrehWlU4NsS/U0wOLcFYGwxrgao2x/n5+rhtmXOvCtgU2dZLNZxyInZO9DIqo9/3WLkH1eldYfA8uxfT2gNMpkS2rdKI2NzUNU+uDWe6Xwzs3NBQSIETLzISxf0iiZ0HaxWHSeq/YB5kUNPIw03IVCwUXe3MxyDrxGO/y/MoZ98DpzxvTk+PzM75B0wYvQWdJky8wsXFopkcE6azTK2nJPb2TRqH5+fumKulAoFIjwqPhsOQJTC6XyQWttLG0EUMkoFUVX+n4+J44RsGX8KyypDOO5ubkAGkSDrI
ZLSWqqVPmZlM1K8t/qwJGAMzw87KJkni2ep1IETctmVn1hZdgq6GpSG3pGqWO4RnNzc46X4WuDStlW49be3o6enh5nkHkJC5EASzRTA1KpI2HX38LVNMhauVAoFJw+Vp1i9Ye+p5Vvy6lQhIWjGkTUN2waxzcPX2qGssYcektLS6AxkS9KrtYo12yQdcPt4VAvjYpBa2A1z6TGjSQOhVbj8XgAdqr0gS18Rxb0wsKCt1+wr9MLBU6p+5wzow0t69I1ofGlQeRGVuLBqiCpoVRCFx2CQqGA6enpQKctwjXsxUwyF8vLlFVtO+r4hLmS9fdBnBqZqFJX5aryY42wvtQgE+pTD91CZv8dIthSwwcJa44WKH/trVFubm52zW3sWmiOn/AcHV9Ck3ScrQFU5AdAACb3EbvKeX4l4hAu5Yufr7+nULMaJULuGrX7qjpKdSqrZNi0Eufmy6dSLpXMFQot3f6kMh0KLbKBfcQo6yjbn1Uy1Bj7HDXN3QNw54vBi16uQwdOme2+KFzTmqrHuJ61OLjlPrP9t52n9n7QVBoRPItGPKkRcqlF9nnNjCwpQBRMClqxWHTsN24O+wVXasDssPkt7YZExW5ZraWEms+jjoWuhxVqzlvzO9V44MBSuZnP0+Nn2Lt4uRf0xOmF6wUeihDYPKyFgKpde3XMlPTCcirWqyuUxWdWOE/LsfT2G9axa2exlaKS/27DRicKpVIeyzVs6oTqmitjnmtC2cjnl243UyeTBgtAoEWsku0AeBEcnwFZbd78qs4E507UifpCDTcNANeSBo7vpfwRNh/SpjgrnYNy5m0df1udYDs9sTOXrovKvhorG8GpHtXzpHyKcufucwLVGNumKQxCOA/lcajjpg1DfEZY9SOwdJd4a2tr4BnrZZRL6fJSP+e8tDJEb8fzcS50vSvVMXWDrH0bqkZAPR9uEIXGMpzn5+fdg0QiEcTjcWeQVUArGfaQKyytyomHhQqGBs7eh0zSF4BAnaQtP+Cc7cbUYgz0b325Xc3X2IihWFwsf6JS0npj2y+5XobLKlbWBy4sLAQ6uxUKi7XpyWQSDQ0NSKVSAOCUb0tLi/t9Es+Yh9dXa2trQNnqBSX1qmlfi2GhfVWGJAiWq5ysvNPoUuY1OuTlEjTI7B3OFBLXlMqWNan2akGeK57vUiRAnV+pufO9iOawgqCpqclVD1AXaL+D6elpp0OYY6VRpq6hw9bd3R2oLujr60N3d3fA4atGXlTerTOhDijTZjRaGo0p14KGiVcI2jQgnQw6teqEVlIrq3PQ3LESmtQQqSOn+od6h7A2nTS+py+lwPdhfrajo8M5NuVyDyod1oHyrZM6fHzmlW6RUoNcTf67KoOcL+Txzp++E1986IsYyg5hQ2wDLtpxES7bcllgUzX3xwPGyTFKI/zLOjx6v3xFIhG0tbUF2m7WEiGHQiEMTQ/hLT96C37wpx9gemEag7FBvH7z6wNQKNvmacRAZjPLsfg8ZFHrBQhkW/Mw8XerKQPh2PqxrdiV2rXs+6854TX44HM/6M39aDkWAHdAlMmpjMhq611XW/MjPn2Ed+4v2vgiXLbhMkQiEUcySyQSrkRM1y+RSKCrqwv9/f0un6/9t9XAUDHZnLKNkssZt++6HR+680O4b/99OJA9gJsvvhkvPvzFNa2Lb+jhV6iWt4tZJ6kco0wF+dkHP4uP3vNRDE8N44jOI/Cmo96EgdYBlwoKh8NOoZCwk8lklq1lQ8PiFaRsRMH8/+zsrDN6NBbWIJdLxqQBfc7Nz8He7N7Fb4YBbFx8HTt3LJ6TfY5zMpnm4vkLh8MObgfgPp+s546ODvT19bnOdDTKtu+yLdcp50yow3Fg6gCuvfVa/GDnDzCdm8bG5o24JHqJS6vo3cF03NUYp1IphEIhx6gmyY3BCZ9NuS7a4dCSqsoZhUIBH7zzg/jWo9/CY5OPoSnchCPiR+CclnOWtRgFEEgH8DNIyKMs8XsauNjOZXTiuD
+5XC6QgmPUvdoo56xamF/PldULivDOzc0hEoksa0qja+I7p5WMqgzyB37+AXzq3k/hhhfdgMO7D8c9e+7Ba773GoQXwji99fTA72p+kxGlEpCoLPnAGlkuLCygsbFxWS1htYnzUCiE1FwKz/38c/Gczc/BNy74BuLhOB7c+yCap5sx0zLjYCVCSTTAmUzGGV+bJyGLmq36UqkUpqenAzV6FnqrJkr75d/8EvniUlOVh4YfwvO/9Hycf9j5y3KvmvPhAaBy1pInX761XiQKrnkoFMJdV9yF+dzS9ZQPDj2Ii797MV50yIvQ19DnImAqQgABKLSxsRGxWAxdXV0YGBhwHcU0920b0vhY17YOt5znnJqfwnH9x+FVx78KF9x4QU1roWtiP1tzp5YlrAS1cpUT3/+m392EN//4zfj4mR/HCb0n4BO//ARee+dr8eVTvoxYLObWmOduamrKnS1yKIg00OBqAw6WwNAwUMZKlT2tNnhWfnDhDzAxuVgSNzExgd+N/w4fGv4QTus+Ddt6twWgat7bzSif7wNgGVmqp6cHAwMDGBwcxIYNG9DV1eWYwerA+ZyJckdyNonTv3A6nj34bNz4ohvRtNCEe/90L/KjeaAVzslRprRG/NPT0y5lQP3Bfgbq6Ou5IS+EjirPdDmOkELhP9v7M1x++OU4pPUQpNIpfPw3H8f/Hfq/uGzuskDelPPz5VvJ+yBiypJW7Stg868840SEmE6rJL1X7ln1GWRfSoifyzy4IlfqWGhtssp7pTaqKoN855478aLDXoSzd5yNfD6Pja0b8ZWHvoKHJh7CGfEzAgZY80iEpC2xgn2IeZDoJTJq0rKFWnLIwKIzMdg2iH8/79/d53SHu3HgwAHsHN/pDDIZkpwPoTAqAI2Qc7ncsi5Bs7OzKBaLjpxAI6jNASqNRHtjve7fxWIR//izf8TBHQfjlI2nOGWqRlh7cxM2VyiQCEWpi8PraZT7E/0BFujHf/VxbG3bitO3nY5MJuN+j5FaNptFPB5HJpNBLpcL5I+7urrQ29sbUKT6HEqYUURC17sSp+OsQ8/CWYeeVZe1WA0+VIRDmzwQ3iRKU27EEwqF8LF7PoZXn/BqXHH8FcjlcvjoGR/FD3f9EN8f+T7O6zrP1ZNSmVKOiKa0tra66I0pJ9tFbWFhwZ0FPqfyMSqBTrk/G9o2IB6Koy3chq6mLnxr/FsYaBrAaZtPC0D709PTaG5udkRGGi0OonF6Ex1lqL+/H93d3csa4fiQrEog6w/e+UEc1HYQPnvOZ12/9pbZFvxp9k+Ya51zZTPUabr+dMimp6cdEZPGgGeaqJFeQqJNOWybynL0DPXqjefduHRrVi6FKwevxCt/80rsL+4PoG/hcDjAk9E8vk+WuTfaQInPT7i6UCigtbU10H+gErZ4JWdVjbLqCLtO/GwNeuxz6/esjarEVlVlkE8ZPAX/dt+/4Q/jf8D2ju14aPQh/GLoF7jm+GsczEBBURKA9ewVYtV8p5YpqCHm39SS3L/lD7fgzO1n4pJvXILbd92OgdgALt1xKZ7X8bzAoaVXxJpFGlh6ejRgVEKsQ9baNB4CHhy9ucfeFFLpmM/P40sPfQmvf/rrlwk8oxYaaSDYPcjHUl5LwhPfi++9UFjAjb+/Ea894bWO9eu7apBrxK9soM+G8Z2dnYFoQHPgpZRpJYq1Hs+tzo1645YRCyzJthpFOlh0FCutLZ3Pz+O+A/fhzae+2SmcpsYmPHvTs/Fw6mFctPGiQM9f5vJV8SuDmQaZlRDaIEKZsZQ1Xx1yOZA15ZXnZi43h59O/BQXHnQhuru7A+ktEr189+0qb4HywzNueQalUjbVOKjUM3/9zb/G7btuR39rPy446AI8Lfo0F/mp8VKilhJAicpptMn52WdSlrNtyFHu0GiX+jm7kAUANOWbMFOYcfpaYWhe4EKdyZ9TbijH1I9askZ5sfK8FnljDl+ErLlq3WsLcfuQLV/ZVDWjKoP8lm
e9Bem5NI74lyMQCUeQL+TxtpPfhgsOucDdnhGPxwMGVpu2KyxDVjW/KptaIQQdtSjTnZM78a/3/iuueuZVuObka3DPnntwzY+vwcIJC3ha/GmB9pMsswLg5s6fqaBTMLWdHHMiWhvMV6mi+UrGzb+7GcnZJF52xMvcZ/PKOeaxWauoddzqFPgObj3hah2q1G559BYkZ5O49OhLAwQ0FWRnPP7fIafjQAXE3LcqIl1P61jUO+qv9Nkt8Y6okYXHeLDpyFoijTqn5Yyx6THki3kMxAcCKYv+WD8enXjU3RtL75/kLMKMWgqleVnyKkgy4lCoWhndlcq5pnii0Si++/h3kc1lceGhF6K9qT1QHwvA3XltqzgoQ2q8NM1BlEgjY6uUq5EZ6pk3POMNeONfvBF37roTb//Z2/F3B/8djmg/Aj09PYH10tSe1t4zElVjwHQZyU9KXNTmRpWmCnyjUCzg+gPX4+CGg9FT6EEKi0RLygvTeXpOeZ595WhaZcJ0HuWEZEzm8NcihQYEAwR+9UHWdOhY4aDVOaUY7LUGjFUZ5Bt/cyO+9NCX8KULvoQje47Er/b/Cm/80RvR3diNM3rOQHt7u8PcebiZP2CUqd1fOGi4NaorVcdYi4A9bePT8P4z3o9CoYBjeo7BQ8MP4abHb8IZJ52B3t5FWDgajQbu09UCcF10ekUqdDZPa5tv8ADp7U+VPE+xWMR1D1yH5x/8fPQ097gbkrT9JG+8yefzTjlSGdmG777Du1ZG2c19+/OxIb4B2Ww2wOgkV0DJIqpQqVT5bzXGpQ7Jk2GEdejh5vPY0joOngEiHjw39vaeSmA8nYNzDiJhhMKLNwWpDLOSgJExDbE6OYze9Ew0NTW5SMeyuDWNUG6ErNFtsVjENx7/Bp570HNxSP8hAbYuFabmSoEl54ZQKPPH6sjZudVT/qln3ve892FhYQFHdByBh4cfxn8O/Seevf3ZKBaLTj9oi9pMJuNSTjRgdMC4rpT79vZ29Pb2oru7O3DfsNZUVwO36+tjj34Mu2d342/jf4vJ8Un3HjTIvMkvn8+73hH8P/dAh6KGWoLGfWlvb3fsd+2KWC+jbO2HTWPZKJkyTX2uKJ4icr6ouZpRlUH++x/+Pd5y6lvwsmNehkKhgCN7jsSfJv+Ef3nwX3Deuee5JgHWm85kMstgaJ9HwUXQBh31iuA2JDbgyN4j3edEIhEc0XsEbnnsFudIMHedTqcxMTHhiC6Egek0qGLUDdUSKgoZL+vmVWS8jrEa2HpXchdu/dOt+MqLvoKFhQVXpjI6OoqhoSHX85fwIw2/trbz1edWSlypZnDu/3H+fwTKK/TWGuaXADgoS3PejPRt3nitHYpqho3WKMs+mdaSQfIb2NK1tbU1ECGXe+B7WnsQCUUwPDXs5hOJRDA2M4aB2IAjYfFcZrPZZU1vbJtKPbMKTzPCs3W3vo5pK+2PrlljYyP2pPfgjn134IazbkAsFnOcAkXWuJ5AsCsU58d0lK1N98lNPWSHekaf44ieI/C9Xd9Dd3e30zGK7DAKY79o7fbH+bEsqLOz012ZSoPsy4NXoy+5Bu9/4P24a+wuvG/H+5AbzSEdSbs9VPIWU3a6xz60S40Zn5+3zDH9xFQCDbJFPeoxLFytKIXPQNOBoFOsqEoltd7ljKoM8vTCNMKhJS8iHA6jMdKIIorO49E8CEuGGC0TGtNbS4rFojPAKlBKCKlHTuHUwVPxyPgjAdhiZ3InBtsGEYvFnPBTwbDBvZY/2RpjYKn8IBaLuU1Sz5xQjNbPVntorn/gevS19uHMbWdiOjsduFN4YmICExMTbm0pTCREaU2uGrInyphd/8D16Iv14QXbX4DcfC5gfLToXpEGYIkpqznvenVXWuvhO/yqgPmMQPCiFV8/5kpzyE2RJpy08STcuvNWvPjwFy+uTQi4bfdt+Jvj/gYtLS3uMwn7sgEI84Bab0yjzLlTSdH46fNxb6pxqDVV9eXffRm9rb
0497BzESqGHD+D+sU6k3Rs+D4ajdlUjXUS6iU7qme4749nH8dgYhDt7e0B9I/rrNGwnguNMhsaGgJNTXwOdrURP3+vWCziH+7+B/z4wI/xz0//Z0RnotgX3hcwVFx7OtOl9JgaM0K/DES0Z3dfX1+gdS/1VKkItB7DokZWDlQGfY7maqhPNVFyVQb53B3n4n13vA+b2zfjyN5FyPrjv/w4Lj36UkegIDGFDdFp4IAlWI5GTaMhVQjA8hueaoUir3rmVTjlulPw/jvejwuPvBD37L0H1z14Hf7pL//JkVt4oFlvTEeC0ahS95njISRM0ovCdwq3qkKoBq4uFAu44dc34NJjL0UkFAnkcmzjdwBobm4OwL62HveJgKrt3C879jI0hBswX5gPNJHR0gE6OgBWVPKVsEirHdn5LB6beMz9/0+Tf8IDQw+gK9qFze2bV/zbZVCxyQfqz7R7ka6HrkmlUDUAXP3Mq/GKb74CT9v4NDx949Pxsbs/huncNC4/9nInozbXS0OhjF/fDT8AAmQuiwRUK1/u90LA5x/8PC4/9nK0trQ6J5NGzDK4geV9wHVO2ohoLeVe9cxLj3gp7t5zNz7/8Ofxoed8yKUJgKWaXfZ0p0OqHBVC76pTtAmIbWJS65m45rZrcPMfb8YnTv0EYohhIjuBTDGDXCjnfofry3y3Tyb52TTGbMYELJ1pBixkv2tTFiXOlvscpc5qZ0snBtsGA3PjV58xVnmi/Fi59lUO1CpHVRnkT5z1CbzjJ+/Aa//ztRiZGsHGxEa8+oRX483PfDMKC4VA4tsy/ZTib3PGpPhrF5p6e69P3/R03Hzxzbj21mvx7tvejW0d2/Dhv/wwXnbUy1wtpRotrSezjUvUIAMIEG5UOakyKOW9lvtsP9r5I+xO7cYrj31loG6VSltZ7YTw1KCVOrBPRFTJuV9x3BXue8q01xefzRIvNNJcyVmr57h3/704/XNL9fVX/+BqAMArjnsFbnjxDSX/zjcfn3G2v8c91bNiWZyVjIuPvhij06P4h5/+A4ayQzi+/3h85+LvYENig6t+sMx0Di1doROthMuWlpbAXumz2Vc10OmPH/8xdqd341UnvMqdQ18Zm31vm7v0PaMq5XoP1TPvuf092NqxFR943gdwyWGXOCSCJCjtFmgbSwBY5lxYclEpJnu1Z+K6h64DALziJ68IfP+FDS9ES6hlGRKhDrTKJhETAM4Yq37Uqg9fbraaqo+Vzup1510X+F37vqVkVOegOr1eqK2OsgwyFzmdTrvvvfuUd+Pdp7w7YBRmp2bdoSUhSq/R0zyUrd8KhUKB/+vP2aqMEWA4HHYGRxeQ87OlVXbupw2chp+//OeBuWcymQA1n+0vtRaUEI12wVLGoH0+LV1hZE0SCru6cGNZi7va3J/Z+0xM/t2k65bEumdt10nkoVAoONamPgdvbIlGowDglHK1wlXuuj+z95lIviHpGkuwdIZRPaMwRglcYwCBzlX8u8bGRuRyuZrKx8qZ+4ldJyJ1VWrFv7eDhowsVDLwuTfaQlCNrtZ5ktjFdclkMohGo4hEIu5clLPuAHD54Zfj8sMvD8h8Op12n6P3T9tuSvZliYzaLIFnlXlQngsrX5XKDABkMhmHVKksUHZ0zoomKFNc/y6bzboeCESSyjkDdu6l1vy0gdPws7/+WWAO5HZQH/j6IquDSrkIhUIB5ESJkHwvEqxWOhMrzZ1r+/grH3d7yGtc9+/fjz/96U94bOGxZU6i5QOxTp5IBuVVmdY2JUO9Rf3I5kxqlMtZ91JntVgsOnnnjYKUG+20pfwmPpN+1T0g4ZH2rqGhwVUnqFMRCoW8+t07imWMPXv2FAE8ZV579uxZn/v63Nfn/hR4/TnM/ak27/W5P/lzLzVCxVVN9qK3v3//fiQSif92hBkdxWIRmUwGGzduDEDk63Nf27E+9ydnrM/9yRl27k+VeQPrc3+yhk/efaMsg7w+1sf6WB/rY32sj7UdZe
WQnyqeyJ+T1w2sz/2JGOtzf3LGn9PcnyrzBtbn/mSNciPk9Rzyf7PX+tzX574+96fG688hl7k+9ydn7qVGWRFyIpEAAOzZsweJRCJw6wWZaWSuTU5Oun7KvC81nU4jlUo5Vpu90Jn1aJ2dne5atIGBAXR3d7u7PbULzUoMwsHBQTdfO/e2trYVn7NYXLo+bHp6GpOTkzhw4AD27t2LvXv3Ynx83LEktZlJLBZDT08PBgcHsWXLFhx00EHo6upyhe3lMIArnXtRmJtkoJN1PTk5iZGREezduxePP/64mzs7/7AcgS9b32tr7VhH3d7ejr6+PmzZsgVbt27FwMAAOjo6sLCwgO3bt68696Ip02KHsfHxcYyMjGBsbAyTk5OYnJzE6OiokyNedG4vYtfuXazHZNOE7u5u9PT0oKenB21tbYhGo96672rWXZm+ZKNS1icnJzE8PIyhoSGMjIwgk8lgfn7elQixHp3tVPv7+7Fhwwb09fWhra0tUE+6WslHJXPnvLX/8OTkJMbGxjA+Po7x8XFMTExgeHgYBw4cwOjoKFKplKtn14se2tra0Nvbi4MOOgibN2/Gpk2b0NXVhfb29mUtWbXxRi3yrh3MqGso7+l0Gslk0nWq43WNrMrQKyLZ3UtLK/lsvLSE7W15b/LGjRuxadMm9PT0uB79p512mpvvavKia092Lu87Hhoawh//+Ec8+uij+OMf/4ihoSEnMwCczLe3t2PTpk04+OCDsWPHDmzdutUrM6XWu9S6lzN3vSCCjO50Oo3x8XHs3bsXu3btwr59+5x+1F7+LS0tSCQS6O/vx+bNm3HwwQdjcHAQ3d3d7t7zauXdZ5csoz6bzWJiYgIHDhzA/v37nXwkk0lnq2ibWGNflBIzrr92XGQjk4MOOghbt27F5s2b0dvb67ov+pqk+OTdN8oyyHxjtnzUloeNjY1OyC3d294upFdt6UubemudKet3teMVhW+l7i36ff6bbdlWGqpsGxoasLCwELjcngLPpgnAUo0vn5F9Wdmir9KSnHLnTlq+1hurwtL+uL76S/2erenlZ7NWkPtIhaxtCNva2pzyWGnuerhZ9gUgcNevFuTbJv96WG13JftMtnkCL2tfaS/KXXdVUNrzWds3qnJkKQQVP2Xa1oZbOSpXwZY7d86bzTTYmYtrr3Kie07D5auj1xe/r5eYlHM5QCXyns/nMTc350p7Zmdnl9X06xprvS7b+bJ8iGdd1511vXqlod6mxA5SLMWzdcw6b+oHX+kVgECHMdVlvq+UHTYg4hpzXnSAypWX1eauMqP9InhGefmJvWuc6873ZJMoOs3Ui3R8rD4vt67XN/d4PL7M6YlEFhsnTU9Pu/pm7q3WfvNiDJ5lykep2nXuie4FL1SyAaPdh9X2peLGINbbYw0Z62C19liv2aJQsauPbjZffC/2vdb2mdoykcp9LfMGakC07lg7J3HYelL+LQ/lWs6RX7WGlYffNpQAgsZN34d/r4ONGDisASx3/fn+peSGNavaw5e1gVoDqHeulirgD4fDroZR61HrNWzEY+tAtYsb11TloVgsBp6DNdWU64WFBe8NRFz/WuasDpGdO2vZ1XDxYgcAy9pM0gFntMruenSMVEnXY835lQ6zdrdKp9Pubmbb0lYVZyQSWdaDHsAy/aJGRo0mnVd7TlZabz2XXG/qN14Kw/pv7WOgtcfaIEbv3S3VNKbeelEjT60Ztn39gaWGH+wOqA5NqZvl9FXNKKUHS3W5UxmlQQXg9lYbPOlceXa1v4TWkNMJ4BpoQFHuqMggWyOljQsUnua/WayvV21RqAE4o0aPa2ZmZvEi7/93GQIXNxQKBVo/csPXyijbw2QbCvCZOPempiZ3Q1G17Q2rHdpQQhs06FVnOhd7CJRQo8qK3jCVsi8i5VhpD6wB0+YRlBeFjyYnJ50M0bBp4wHeS637opcfMAUSj8cDf1uPPdGDz8/ncySTSXftJdMadChUueqBzmazbp
3ZZpZtBJmqKbXmlczZOr9q0Lj+bDDDTlxqxAAElCwjiZmZGUxOTrrGD3xeGkHtCgfU1hHLOkFTU1NIJpMYHx/H5OQkUqmUS5Pp5SR0FIgK2PakTDs1NjYGri/USMq2AV3JydD1thAqoV7KSjKZxMjICCYmJgLyzr/heqlhsY2HammpWs6a+9ad+p3Os72dLRQKOUPHaFivnfV1K6zHXK3jYK8t5RxpMIn4cb6qP9WRJwKpThUDCW0ixTbL/AyuRyWj4ghZoSPttKSKlQeEuVb1SGh8qYRo6DSCA+CgVyoIhUu1X3S9h8+AaK5QDTLhG0Jivj7Mazmsg2Q7clmDrMpdoWpVIMDSlYfhcDgQafBvKvVo1bHRiIy3aY2NjWFkZATDw8NOqWq0w5yURthUjIxs1GMPh8OIxWKYmZlxh6teykrXW3Oxo6OjziCzC5DmpNRzZ06Z35udnUU6nXbGuLu72zmx3CN1jKqZsyIUPLtcf17ZqfyIhoYGdxMUEGwFS7mYmppy55jKmRdPaAtEdcJrXXfyO5gz3r9/vzNo2rFO++PTieBQlEAd/tbWVsdXIUfBtnKkYS9nvW3OWNebee6xsTEMDw8H1p8ySySioaEhEIVpVyntFrgWUbKeXep7dSi45rrejY2NDp7mhTqdnZ2Bi3XqeSHMSjpbO0Va/RaLxRxHxka+urZcTzqhzI3TMclkMu6qSD6PhbfLHVVFyHxowi+pVMp5qHyxTRmFnXg9IUVGOzQIjGa4MIxGmadm83Ft3VfvUcq7VWiPBpmkhVAoFPDANIpaqyjZfoavNaDPK1RhoYcKLN1tykOtikANsoXvyjlIdo42MpuYmHCkrpGRESSTSXfA9ao/3RfN3zMPSoPMtoHt7e1Ofmxrv1qgMc5ByXQ0yJOTk4F2fAp9cahDS+OYyWRcRNHZ2ekcjlJ3rlaSKrARsrb8y2Qybv0nJycxMzODfD4fcMgABBQK9zyfz7v15X7ybnMaNsKUDQ0NyxzuatZdORK8/3toaChAtqTcAEF+h/IUuH80ygACnAMaZCWoaf68lEG26211WSqVwtjYGA4cOIDh4WFHpFPZUUcaWIzwKcOlDLKVb86lXkZODTL1PIOubDbrjB0RNeaL9VrF7u5uZ5DttYX1MsqrRcicI40wz1gikXBrrulXdYKsA6eEQdoFe8FHNSmbig2yetp6j6323GVOhAQSHgYyXdmDmL1XNepQo0ziQ3t7u7vB6ImAZzRfWSpvAiBAenmiB5/fB42pdwcE83rFYtEpKipaklS4BzTaFGAaYwtDVppDVkeOssJImZ4mlRIVCg+s5sr0eQmXqjMRjUa9cHctMmPzVJQNRj5UVAo7AksQHqNc/p/5YiozKgc6FG1tbS7C53Oqt12J8rL5Yz23ylZm72AaZBohzte+D40N8//FYhGtra3o7u52c6cOqBZ2t0ZOe2/zutHx8XFMTU0F8t/6mYQnmY9X+VHSmjXIhFhtqmw1JeszDkTYksmkQ4SU7UuoXQ0s0Qfmvi1UrYjcWupF3Ws6cnxpeoBpSTLx9YpIkrjs/c/1MMacp09/q1OseWOeNw7LTdD5AXCOEgMxvSpW+5CX4hKVuy9V3fZkFaxtFq4kJz4YWYtkTS8sLDhPSRWU5jFDoRBisdgyz7FeXqB9HoUVfR4pn5WbomSQUhexr+Ww3rglMnCeRCeoSGgcuO40eHw2jSS5d1RMzO9bYkY5c6QxsA3l7UUGnJMq0WKx6H6Hv6+HLxxevLuahC5FXHwHpBqDZg+7j9DFFAthWmVQ6yEn0kSFRiPe2NgY8Lzn5uacN885VwKB+VAfVbB6mYrC5CSW6bWpnLPC3iSv5XI5NDU1OadKSZ00jpS3WiJ8GwjoWjF6Yd5YqyGAYOWAEnV4RhjZ6YvQteaS6cCWmrMlFtn52jwyYV9FtShDXDPdO43WNBderXyvtv4+Z5p7z/NKR4UlQqzAUD4EK1YsXF2PYdEjTbEoeYv8GBpplW
2iAD4HlE4H/8/98CETtYyqDbJOwheBcLFJfVdaOBWqJviVXKS1j3rvsBW8egwfxERjYZl06oE3Nzc7wdNLwmmw6ilsq83fFyHzQDNXwrwfjSyNcz6/WP7CdVdDTlRDS7l4cbjWmJYzRxtZ2bnyAOntO6oMAQQiDcum5SGiR6zesb5Udio1yj5Fa3NVNEBUAAp3UhFx3Xm7DWs3aaApb4xINAfKtSjXKJcyavoMfAFweXmWizEy5HsRbuecOV9C2HxppYV1UCt1KHzomTp2/Dflli/dKxpn60hS/7S0tCAWi7nITkt0KIdKRlptvW0+U294IstaKwr4HKrkeYZ96JB+tamceg77LHwOIoVcU65LNBp1eWNGxyR0+Urh6hm4cF+Jylp0jPqF8waWdAD1H8+XdV55rhcWFgKfoyWCpQKVSp6xKoNsF1LrQLX2GECgRqu9vd0tSC6XQzKZDBhlLhS9ET14a8UmtAZND44qF0ZcxWIRTU1NLkfCV29vLzo6OhCLxVZsiFDPYT04LcuiALa2tiIcDiMajQbgPCpF1o4DCJQJaS6oo6MD3d3d6OrqcodrpQYtdo52nS0DnBE7UxqRSMR9LovtATi4TJnMNAZUXgACCk3hPEVeKoXbVUZsvl4jfWDRqMViMXR0dCxraEPDQFZ1MplcFr0zT0rlTYMIBM/eSufAeu2+CEshNnXC2ASms7PToVqcN40w500Hid/Xc8O8mq+8ZbUz7EOtVB9oWkb5BszxEuZlvpt1tOoc0XlWY9zZ2RkwIjTMGt3pfvjW3OoTXRem9GyJnzUeNuLS9bB7WY+UzGrPYw0yy7QYnPBFue/p6UFXV5c7w5qPtwa51qHkKUbDqlfImI/FYoF8sj4nCVtM61m0ly++p/Ki7F3O5dTflxpVGWQuAiEBW4jPBQGWDDK9zZaWFqfIyLhraWlxdwUrEckuyloYYwDLBE4PjxpkQlVqkPv7+9Hd3Y3e3l7XnYtC90QYY5vT1yiXxBxeIk/BpbNAr5A5fRoVRqjMBXV2dqK7u9sZZe6jwvOrzdWnXJV8Q8Qhn1+8C5VODteUyn9ychLj4+NIJpOB/aEztxrEp5BxpeusqQyLntAx4LOwqxmdNObOmMOcn59HKpVyMl8oBKsPSJaik2LPW7nPUAqy9sFunHsikUBnZ6dzwLjXABz3o6GhwRlmAC5y1juKp6amAnKicy933jZKsdUMfAY9y5xPsbhIHFVHgIaYTXuUYU3iaE9PjyMg0ZAQqaGOW2nYFBKNsjr69u5pZUvbfdX/W8Nso+O1MMqabtIImSkWBlxEVejAsxsXG4BoCqreCKLC05RlrX5RRI5rzHXjs5HTRG6I3l3OxlZ8X6bxtHkMS+bsM1bynBUbZPVGfGG7Ci0hU63ra2lpcV6V5mX4fsCSx6Leb6kkea35Ej34Cldbyrx6VYRyaazoCTIa0pzZEzF8ipbKh0ZZX5wXlUUul8P09LTzcnnQmA+iomLkYJ2OSqNNq3Qo5IziY7EYurq60NfXh56eHiQSCRSLRWSzWYe80LFQ5a55HgvP+qKJchwJ+35KqNGXGjUyNzs7O9Hb2+tadzKPXygUMDs762A1OhXqUChao920eGbKZW+WMsa+taBB1tae3G/tSpbJZJDP5zE5OemUD5WXPUOzs7PurDc1NZUVzdl1t5wOa5D1b+wz678BONhRST2MelTeaZC1qxP1VCWQtYXXFZ5W7oTPGOu/fYbX/u5aBCv8bJ4lJTEBcGvDyFihauaR6YyuRf0xELRJ1EmEra3MW0dUU0WRSATz8/Ou6kHPHSPtQqHgdBWdNRpiWyK35hGyTZhbQ2x7IOvf+Dw73XT7ezz8vjyJvurBzvNFmTZ3zPwOF5kbbjdCI0Z7SNbCQPu8aY2mrDLh75Bwx2hYBY+Kmbl/ohvMA9l2d7U8l0btPEzaGpDQIQBH3GIOmfvU2NiI2dnZwLP58o78frm5b76fjbgtWYzPQQ+dEZj2RaYTQ4
NMElI2m3URtMLpGpFQOZCFXS5E6YM7rTHmunL/de1ZR0qFyughFFos+2BO0MoVz4uFxSuBVn0pDoWqfe+r+8X/cz5E3Rjl0BFVI6hETUZAvu5SpSJ8O4dSMmP/XSoy1mHLg3jGbSqgnkPX0rcPnBMNsubdCfcTXfCVOdVz2PXgOdNI2EL+yhkimqIcAe4H9T0dODqttpFMPdjjVUXIlr3GvDCNspYyUalMTU25jlasvbSlKXpA9FUKoqnVKPsiIGuUuWn8OYBlG2tffA4eXPWo10IQ+VWZvYyQqVxsDaYyxjVCUrYkX2TcajvHajxAC7vqa6X30Z/Zg2eNAY2Z5nltazuFfCsxDupRUwY1TxUKhQLwlfY/1giZ85yZmQlAXnxPsoR5fmwkpfmvcuau8q3nRx0hNUTar5zzY4RMroclOdk95H74zmwl89YzqRGmrchYDbJVOeffMopvaWlxEb6ujUWVKjF8+qzWIGh0pvBpqWHPOB0HleV65mR9z+Fzcnn2lMRoIX5rpCoh9FU6+OzkO6hu8DkXAJxOV1TDIjHFYjFQLdHa2hrIjdfLGANVRMgaCXATOHF6GHxQeh6pVAoAXL5yfn7edfRiLoXlE1wwVZg+YbaHsBajXMob18hKlTCVA3NBaqg0UtaxFkZZjZI2QdAcMg8Ia6a1XEiVPdeQhkWb62stpi3tWu15bIqDsqPKnDAtlZPmIqenp93B0mJ9hZ50j3R/NPVAwka1uUyf86WeMxseMKon616dGTqr/Ht62vF43NXxMl3A39MyL1tpUM3c1Qhw/iT1qddPNIT7r809uHcWEfPxJqwyLGfYedOpt6kkjcCJLtihcsffKRaLLk+oqbVEIhFoJlNqrcs5vytFsnxGjlLG2OeIWuKsolprVdnhi5JzuVygLalFCzW1Yue2lqk8u2b2GRQ5KRaDTa6oz8mAZzka0URFArq7u9HZ2Rm4uKYeddUVR8iqVBVCy+cXyzg016QRCtsDhsNhl0AfGxtDOp0OFMTbvJB62r5oVAW+2qHekSpA9ZL4cwCBnq6pVCoAZ1mUQCOeehplFW4KPQ8FlS0NRWtrq2Ofsi0l94Dwby6Xc1E1Gc9qjH0MwnLXnQKqioTIihofOguERNPptCNmhEIh18jClopoxE/HUC87IblIDQnJHysZCYugaHTAefKQ8ntkKLNcRokemnPn+WHecmZmBqHQYtc3rq06TUr8qSYPa6FSRhDMhzU0NCzrUqXMWBq0QqEQcGqsQfYR/co1xHb+6pAoIUob9Kj+4fnimaDxJWrCaAiAk3/VI+QusMaWzOzVoGoOPZPWCfVF23y/ldAaa9RpGJTZbHOW9TR4voBFIWvOnwGa3pql+oH7UArVXCvk0Pc8wBLxjrqc+oYXfrDb5Nzc4u1iSna0pDXqlnrA8VVFyLoBCuURllYvVKMVKiTCRfYCCsXs7aL6lAo/m59XzbACp4w8CxNqhMwSHL3/kkpWmynYYQ92LZtn94MHlT+jsm9tbXWEBeZaVcmxhAFAAOa2+XEfKaNcg6yOnF57plFVoVBwqEo2m3WGYn5+sfczZUYZqmqQCUn7aj8JLSlxbTW2LLC8ZEjLbbhe7PlcLBZdhKzX4tkaZGCpa5CWBXJfFEZbLWe62tytMtV8JQ0qZVadMHXEeKb5/NYg+9oF1hIR6bzVsbf13rbywkK6/NyGhoZl0HmhUHByT2Pe3t4ecPhsV6lydY3PICvRlfqB3/dxavR9+F6ch14TWaoh0VoY5VIEKRu9Wx6Nj3hnnRc10msRQVsHkak7bdiiXev473w+70r3CFXTGJPIW8/eE1XlkCmcZE3yoFIBaa0ZgEApE7AEZ2thOSMWJbboAypcwpd6npUaZI0gLByj/aBVEVKR2Y3U5gJcAwu587MI/dULwvFB1nx/zeMwCuV+EHLXqw6BYHs53003lXqB/D1rkNUoa9cwRoVsyUgDS4OsJSPa1MIqWy1jsx3BKiEYWeNQKsqk3N
rI0kKKfE/dL6YHCoWCQ5AULrUKsFr4V1M+3BNlP3POFnak46J5OTU49qVGoZZUko3KFLlSnUKdZKNHRRq4DjzP6nAXCgW0tLQEyuj0/FP+y1l3XxrJMnW5501NTQEHjCkXGieupToZ3A+fk1xvQ1xqP6w86vpqgKZ14D59p8+mz8vnX+tnsc6er/1zoVBwCKempEhcs3C1Pl81o6oIWY2yemqESMm2KxQW25TRk1VoSWFcJXvxxRyoGnXNLVooqZooWeELhcds32obJduN5PNnMplAvs1G9JpXUQNX7Sh1+GkceEh5YBhJkmSnLfD0b1ZTstXM0xpkG31r7R6wCCmyjR2dBX5PrzyjMbboikKeqsxXyg3aYeXDskypSCnD6gT5OvfQSbNRD42IkopUuZUiRVUSJVvnUPeE5SEKN9qIl+ugEKyN4Kwh9r3KGSvB7Vp+SJkn4kGFyXy+zp3yQHRLI36FxW2PYspnuYgK11YRBOoHIg8k8bHNqso499pG1zbdwn/7iF31GjZoUR2ue2Kb5JBUxfW2cyult6y+qYeBK/Vc1tnjy6bC+Hucjzp8tTQAKTWq7tSlCpbeNbtxsaaSTf41irGelW4qDUU2m3URMA13LpdzXYE0Ares2UqHeks8lGyWr/2ELSNSySasFeXhBZZqfAlHLiwsBHIstga7WshdHSQKts6TTOpCoRDopUsvcHp62sHVSvn3QdKVRGZ2furhk3swNzfnFCiVlcrH7OzitZyUJxvx6qEJhULLyEXWebLOVbkGzTpsmtNV40A40gfrW2VkzxCdEi2HsVwKa1TLHapYdV90vlbZW8cMwDKFaVNKfF/7u9WgQfqsVm/w/Qmz8z3ZG8ASbfh+lDne0KUwPnUUjXI1NdS+faWsM1BhKRzz1zy7dD4ZVfL7iihxfxR1sXyOtYwsuY7AEmJJpI3pO55T9jXw8U2sM2rJsOpI1RKw+OZu5apUKZqv4mctkCA7qoKs+dV6n21tbQ6eaG1tDTCoqcBU0ejfcmN5+wkFljlRkjDY3YiLoJ58pcMqW8KkvFLSGmTOWyEZHmIVooWFxbaIPIQzMzNOSRDWVqZrPSBr65xYpb6wsOAUEe/tpWGmR85DoXkfdZ70Vcm89QCSd0AFOTMz41jJhIgYOShaQudCoUY1rqHQUiMUhVltGVulxCiLiihhzPY3X60URVEcGyFYo0yko5ax0rNZmLdYLC4rY7IG2UdO4udYxWUj52qjCN0DPg8jlWg0ilwu53J48XjckW5sC1ueB+oYGj++FEFSboXyDspx4laSdQYcvJkqElms908mkwFonYZanXebSrAIxloZYt/zauqDxnhychKhUMjl3kuRzdTZ0AYbRDV4fpnW4X6vRZTsSwlZfcfn5VeLVvmcZOqDakbVrTN9nguFka0DFQKy/aD1YFMJp9NpjI6OorGxEZOTk65MCggaZP4dPVA1JJUOhTZ5YJjUJ+1d4Un+jY2aFB7WkihlBJPBSgWgHm61EbKNeigIyj6lJ5vJZNyVb6lUytWCF4tFF1Wy7aka5VLQaTlzVujJ/g2Z+ezoQ4MMwMkNnTj7Uu+Vyouyp60eGf1o7Wq5kDXnbA2y5t15gG2+zxdF2n2ysKSPFFWtXKw0+N42GlGDrPPUqNcStnQdfbJY72hCjV0ikXCf29zcvIz9yggNWJKDbDbr0gN0uNUQ2hIrOtBKIFspQlZZ1+/T4NOZo9PFteYZVWeXBtvXK9mHZKxldOxLIwBw148S1mdFg6bnSkXH2hktn8+7WnyuM393Lc4An8WXHrKfZ+F6S7C0OrGWPagasraesAoR+1WrcGtvUGAp/8Nc9OzsrPOymOsjjAMsQdbT09MuB0njVomC1WGhC42AlDjEkiCNFLSERY06DZzPcCgxTKEorfEsZ86+iNXmdmzLPrLCSecnEkDCVKFQcLlEOk2+9+b3K4WSFFHRNWNHqLa2Ngefc73oJGmDFn6+Rktq0JRUqLC1GuNKIG
sgyDFgdM5evlSgnEulhscarlIwd72HOhAaidiIthTcrt/3RbHcb/te1QxdH0b0NJKUw9bWVteqtKenx/Xg5rPR+SOEnE6nHd+D76EpNGW2KwemXCfUInZ01Nva2tx1sqp3mIpjmQ1RAMq06opSSMxaDzU61BHa9AmAW2Nf6oZrw32MRqMulanrpMZ4LR0NnyzbVBP3W/fKInVK+q313NZ0uQSHHmJljioExG443EjmR0gAmJmZAbDoccViMWQyGZejIwTry+FVa4x1qNenzUDoBfKZ+JzWY6UBoAJSA806N8JmVODaCpEoQLnztEbGlvloqkAdIkLxLB2iUHHvdD3VEK8UIZdrdIAgqkJoVts0MkLmGqonTohPI3dV9pQly/JXR8hXNrQaZG09aE0FKNPaGi3f8+v76vvbHKlPnusRXfrGSlF8qc+3f2Pn6/vdSueuDr4y/20jkGg0CmDJIHd1dQUiZMoA2dL5fH5ZdyWr9EtBkeXOGwjCrJRBvYaWAQodZcosUxV0lBhp2iYgPo7CE2WUKfd0Vsih0VSZks6s887zyiobRUotT0D5QfWMlNUZ5VyVuU6HgvPRfDmdJxL1iAjonlfrKFVtkPXBqEDV27eNAnio2FzDRsjhcNgZbs0/8PcpCBp5WgVWLWSt/1aFzzkrczoSiQTuCNbezsp+5BoUi4skBzoYjLbZRcvWUq40T1XcSkJhk5JkMomJiQkHRWvKgPlxwvFKtlMmvEYIq5U4VDKsUVaCF/s+0+gqkzEajTp42LfX1vMOhZaug9PIxzoW9ZAV3zP6vGPf7yoyox63NrvQM1Vvxcv311xZNYZIf9eX8y4115WeQY0xI0SNiOkgt7a2OhJUS0uLg6x5MQTL6fg7tlRSlbISsVR3VVNSpL9H/aZKn868vZJQjayytH2EuyfaCNuXBi9aLqpOCIAAgsS/4/NpmRc5Inw/rl09qlE41D6pbGl7YNvDXIMVBjREP6iryFXgMyqHpVKjXLNB1ofVCfigBnoc/D0lINmyoFJea7W5zNXmzX/rYSDZheQMzVuzifpq9aYqrDSQwGJNW1tbW1kG2UZSCq1ns1kkk0mMj49jZGQEY2NjSKVSgdy1dktjdKxGjh5tKbahdYCqXXs1VspCjcfjTtnyMov29vYAqU7z+DbVwMPMeXGftH7WEjX4/UrnrYdaI3Xfnq22j7aUT5u0KFnM522Xo4xL/Vy5AZwHAC88W2qNSv1M18SHGpSLqFA/sJSJRo3OMGWactPU1ORYzOyexD0qFJZaxlqERGXRluTZErZK4FPrKGqQ4utqRSMLIJDbp/NhHQRfSmEthxpim8pSpM5Wv/B5LOIWCoUCzr9GnywB437UyrPRoTqeqQ/qH6YO0uk04vG4mxOw2LAnnU4DWLoWmDaB6TFdJwCB75c76maQfUOjZWBpY4BgxyolONicgw5VEtVGOaXmaQ8O86kqYBQQrbXWi8tVCJmzymaz7muhUHDXT7a2tqKrqyvQBGKloQaISpzdZZLJJMbGxjA0NIShoSEkk8lld61qDTK/r5CrHjafMbawtZ1bJWvNz6NBVlZpIpFY1oOa8/DBx2rQyFHgM2kKoRpHwhoSm+dVo8x10Pf2RZxWkWkbVuYWdV98r3ooYOsYENYt5YCFQqGSzvBK61nNXFVvkGRIZEr5BFp2RllilEOdQ1jU1pBr6owOnEZLeqVeLfWm1olTPVKKMc311sjYRtFPVN4YCEa7Vm4ikUhAVmwKUR0N7gN/x3J2MpkMEomEQzK4F1yjeg2VLVa8ULfOzMwglUo5kimfbW5uDqlUKsDAp32wETI/o5p9qttTWmVp/81J+nIsq3l61cKklQwLtTM6BuA8dXrPNMZ6N7B6uQpT05PSnG5TUxM6OztdTr3SCJlGmR4coWiWjLFsjFEwjT6jTC39oVO0EiTqi5jqgUwo4sB1Vu4B18cyXOk8sLyMz0/nhxGfrl8tcLUiPhq52PWwTovPiVFjrC
U2fA7bMY2H2ocalTt3395aeWIU6WO18++so2Zz8vU8nzRI/Lc6b/r5ml7SaJIRGGWFa63d2vS9+VL4stRVo5VA1xqJU7/4oHH7/jZKpnG2jsFaGGUr1/o9X5RsU0L6vAqvq2HXXLGWg5HUxvuVbbOcegyre/RMptPpQB96yg15OqFQyN0F0NDQ4PgvrHu3OehKkcS6GGSbVygVVVnhpDHw/S5/n17IWhjjUrASPVj+jIZCb+fhJebxeDzgvfJ5aBgYRdF45nI5J3jVCJtV/DbPq3W3vq4z/CwqPMJ9fD4lqlm4rt5rb5EJKh9ddx9PgHLBw1soFFyvdDtPXyRXDdSuypQvGisL3WluTaN7AM7waW9mMt6z2ayLOrj+qpRrNcpcG1XoFhFR4iTXlMYrn88v6ye9EmPdOgGVGDJdc4taWYOs0L5FRWz3PS2/BJbgYQtVa9tZX/lOJcPKOvfShwral5U7jbrWOkL2nZdSaQr9HYWEbdteG0mrYaYz2tra6oxepRUR5T6Lrq12m/RdpkNGvH4NhUJIJBIuraZoZDklcqVGzQZZFaRtf2k9aEZkqnxVoarxUFhJPa96G+aVoCQOS6LQ0gt60fosfG5lCFJha6RRTcRmDZjmvegsEKbT/KBFKtTr1r9l5M/G6UQA6p27shA8I0Vf5GUhcj6PGhNbisDnBLAsmtU5lLPe9vCSIW4NQz6/eHWidnuiHJEhbnNV+spmsygWiwGWuGV/VmqUffKiJXsqryznYrMHGkDtljc7O4tkMrmMi2DTLqUi83IHf19LpxTRoUxYVi7TH+oEM7XD3gJaXkdn21c5Uc+blHzOZyljzGexjli1Dlk1o5TzSmOr87D/Vr3C0jPKkUK5bPqjMLbtP1CrMfbB7Rqh2yBRdaqWmlFvc162P76vpLKaeVdtkDNzGbzjJ+/Azb+/GSNTIziu7zi859T3YEdsRyAfpl40H5wGTQ2fLdfRulEqs3oZ433pfXjzj96M7z32PUwvTGNb2za868R3obup2xknNZ5qUCk8wNLF3LbeUclHlebcfIPCfvi/Ho7d6d3Lfn5O3zm4sOdCLCwsXolG4gv7+fLzGH1pvS4hOnbLYuTf1dWFvr4+1/VIc2nWQSlXOViZObb3WLzr5Hfh4OaDXTmWwvj67PZzGCGT1JZKpRyZjTk4yli1smOVzNf3fB3XP3I9JuYnsLlpMy6MX4i+Yl+AXFQsFp2zRrngvjBiI0Q3OTmJiYkJ95qennb7xSsz+V7MpVGprbbmNsr64iNfxCd//UmMzozi4NjBuKz7MkTDiyVDxWIxMK90Ou1usJqamgo4E9rAZ3Jy0qVl2MRHFZwq62ocuXfd9i6867Z3Bb63o2sH7r70bteekfXgGvXQIDNtxOoDrcFnuV9T0+JNVuwWxbVWg0zDWcn8b991Oz5054dw3/77cCB7AF950VdwxqYzAgZNI15fhKw/o1OgyOJawtV703vxlh+9BT/Y+QNM56axoWkDzg+fH3AUCoVCgMTKZ2toaHDrGYvF3O/Oz8+jubk54LDSCVeEz1eeWO741C8/hU/d+yk8nnwcAHBU71F427Pfhucf/PyAwVRESNEsAAGnQs8csMQatzA9dX01+l1H1Qb51be8Gg+PPIwbzrsB/a39+PwDn8dLvvkS3PS8mxCZirg8pkbKANwhILtNr9ij562dkKgI7OZUCx9Nzkzi1OtOxelbT8d3LvkOOho78Juh36ATnWgqNjljxs2hN635tGg0uox44SM3WIIDn79aGO+uK+7C3MKciwYfHHoQL//+y3HOwedgY2QjWlpa0N7ejmQyidHRUcRisQDxiAaZTgcPTDwedyUjHR0dzih3dnaip6fHlZGogqpm/VVm+qJ9+PwDn8eFt1yIL53yJRRTRYyPj7u2gkCQGKGRErAE/bJ72+TkpLu/FIDrsc7nLwWplrPukUgE3931XXz44Q/jmiOvwZbIFty4+0Z8fOLjuKb1GhQWCs4R0J
aIytLn+ishb2JiAqOjoxgbG8P4+LjjFzCib2xsdHtk85nlGAau1zcf+ybedfe78N6T34vD4ofh3x/+d7x/z/vx1sRbAwgO5zU5OYlwOOxaRyqqw98ZHx/H+Ph44BJ3X/7Q5hErHUf1HoUfXvbDJT2QW1zj6elp195WGa/8TEbIbOtIg8zueZQT7hNv76FR9jGsK0nbTM1P4bj+4/DK416Jl37tpStGyCrXuka+dXwiIuSJ6Qk89/PPxbMOeha+9MIvoWGuAb/c+UvM7JtBtiEbyI9qSsGWErGvOM+BvRKVzpDqSpuyrDSIOajtIPzjX/4jDu06FIViATc8cAMuuPEC3PXKu7CjY0cA1bJEV63k0KBFS5tUBlSvrATpVzKqMsgzCzP4+m+/jm9e/E08e/OzsbCwgDc97U34zh++gy/+/ot4YcsLMTo66iIWhZ/D4bArZejo6HAF/LlcznmvNOSaI7Sed7Vw2Ad+/gEMtg/iuhdd5xToQMsAUqkUxvJjAYOsMB4PMADEYrFltwwBSyxC28VFc7b2cPGZVhp8xv5Ef8CT/OcH/hlb27bizB1nun7ZnZ2dSCaTiMfjztFh1Dk/P+/qN9vb292BoTHu6elBV1eX+z5fvPNT2aaVKgYrM3Nzc7j6xKtxyx9uwZcf/TJOL5yOoaEhpFKpQDc3JTepQuLeEV6lsiWy0drainA47BreVyMzqkQ/8/BncMmhl+Cl21+KTCaDv438Le779X34xcIvcNzCcQE2eEPD0h3DhULBKSQaNMKoExMTzhhPTk4il1tslJJIJAKOq73O0SqGleYfDofxL7/6F1x+9OW49KhLMTU1hbcf/3b87Ac/wy8WfoHDI4cDWFIu7PXMfzOnT/mhQtWad+b+9DNtzrNaQ9IQbsBAfMBFNYzQ6Yglk0kXJYdCIScrNACZTMYhEGTJ8myHQiGHEMXj8YBBVg5FNYbwrEPPwlmHnrXMGVeDrOtk4WpfpGyRhnoPGr0P3vlBbEpswifP/CSmp6cxMTGBQkcBj44+ivGGcQdH24CE82PvBpYw6n0ETOMo8Y6GmntSyzj3sHMDz/Ke574Hn77v07hrz13YEt0SiIi1k6HmgOncaSkc/020S1GLam2Rb1RlkHOFHPLFPFoaFtmxjAibwk14YOIBnNZ2WuCSCL37mIeAuc75+XnHdNN2jkpCApZ3VtHDzp+XsxjffuTbeP725+Oir12E23bdho3xjXjVsa/CBVsucMaK3hKNMZ0CwqR8Hn6lEtBetUoe4QYrDFKp12296Fwxh5v+cBNed+Lr0N7e7nJfVCaEG8fGFp0M1tNp/9jOzk4HT/f09KCvr88ZZIXvtBVetcpVZUbzx02hJjyUfAgnFE5whonOj2Wh6p6rgmaO0JYi0AFUhmql5Stc6wdHH8Rrj3ktmpubMT8/j2hLFEfHjsbu6d04pnCM87a57szBFgoFF00oLDw1NeWgduZiC4WCM+JKNNKIrVzjwJ8tFBZw/9D9uPrpVy/xDZpbcGLHidiV2YVjGo8JRJVUmqFQyF2lR+dHlZeS0IAgL0HX2cqMzq2c8ejEo9j0kU1oaWjBX2z8C7ztGW9DPB8PlIsR9qSB0GdJp9MuOqbzwMhOkToya9UYq9zVrGxD/jyy7zzp51hD/ESM7zz6HZyx9Qxc/p3L8bM9P0NvSy/O6j4LOxp2OHmcnZ0NQNQ22ldnQ/sBMJiwt96Rl7EShF/u89MYL+QWcONvbsTUwhRO6Dlh2S1x2uNeb/jS4FG5NuRUMKhUNrWvJK2a/arKICeaEzj5oJPx3jveix1dO9DR0IGb/nATfj3+a2xo2RCAA0jYoXHiQedDMyoA4KAMvbSbD6osSJvfqQRK2jm5E5+691O4+uSr8ZZT34J79t6DN/7ojcCzgOcPPN95djTIMzMzzptTaI9N1ZuaFu9ypkFmjRqjIBIXKJT2xpZKNk8P5nd//10kZ5N45fGvdLAix8LCgoM4dZ2YU9XmG52dna
7lIP+tDU9KNUaoVNhUZg4971C0R9px0x9uwkOTD6G/sR/z+aWuY2yeYks/FP4EliBgliVo/S5hYzVovprS1YxaKBTCxOwE8sU8BuIDgZt3upu7sXt6dyCvzjkx9zo/Px8watrKNJPJOFmnctMuZb4LBSoxDqFQCOMz427uWtfa09KDnemdzolj1KCOJSNfGjd92aiUckW5U0KMVVTljmdsegZueNEN2NG9A/vS+/Du296NF970Qtxy1i0BRjibqXDvOWdyDPhidMzox5YwMoVm512viNSiB5bYVYkeW8uxc3Indk7uxOtOfB2uPO5K3LXrLrz3vvfiss7LsCG2wV0GQaeX55M6UqFgpnKApfOqJCj+PdnY6nxWU/tdLBbx4NCDOPX6UzGbm0W8KY7PnfU5DLYMIp1OB9KhtkMe5UX7GRCpYptjzokoosq7r393xYhQRb8t4wvnfwGv+tarsOXjWxAJRXBM7zE4a/AsPDj2YCDUV4Op+VUu3sLCgiN28TYWGjF6KJYFbFtWVpLTLBQLeNrGp+H9Z7wfhUIBx/Ufh4dHHsYXfvcFnLflvEA7SxpfjWDU82Y/Vs5fI2htWMF8IHOCCgFXmo/l7133wHV4wSEvwKa2TYFa0VK5UkZpjAoIVfNFCNvXDtSHRlQzKDNbP7EVkVAER3cfjTM3nokHxx5cJryE5rk2PqWuDGGNjhj1kC3OK/m0brzcUi5FJiINS44h6xAj4Ygr06B8LCwsuFapmUwmALGrEqBjGolEHImrvb3ddZuyTpUe9nKHi7rCwesWOXde6gEg4Ahr4w2F9+j40LBFo1HHS+jq6nJpKK3Pt0RAnddK46xDzwKwqCeO7j0aJ/SegMP+9TB8b9f38Oz4swOMWSVcajMHOj7qYKgDwfuJeQbIlai3QQ7Bf7vXap/jIwyt9SgUCzhpw0l412nvwuzsLA6OHoyHhh/C7aO34zUdr8H09LTjGHDNNZXEJhrhcBizs7OOpa/193SmmGKiY0g9T8eoGnRiR/cO/PJ//BITUxO46bc34XU/eh2ue/Z16Mp3uTSqtqmlTCsCpEa5oaHByTkRuLa2toC8k/hqe1IA5aO3QA0GeXvXdvz0lT9FeiaN8alxtIXacPktl2NTdJOLwJSMxQdWJUoDzOhRazkBBIyHRnNUrrzztJI62Q2JDTiy90gASwt1RM8R+OYj33RdW/gzYCnvQWXFbi4TExOuybgyrJUlqHA7N4t9m2mUy2XN6tid2o1b/3QrbnrpTQCCN5ForaV6ocCSImKLwc7OTkfk4pzU46sluvENykxqOoWxzBii+Siu+O4VGGgaQFN+KfKk06PlYaVyNerFMvLhvbi9vb0Ohuczsq2iPTSlRigUQm+sF5FQBJPzky4nnc/nkS1m0dXchbamNqdYiAQRRleFqs+jipVNBaLRKDo6OtDb2+uQCrsflaAUOvfRmVGXj29ubkYqn0JPtAedHZ0OQrTNZNSB4HMpJ4KRO3kXXPOurq5AyVy1eVg7Olo6sL1jO3Zld+E5bc8JyKSWomg6Sbu+UbnSMe7o6HC8CTpuKh/1hIn53DTGtje1dbSUyFSK6FTvwWfdEN+AI3qOCJC0Dus6DD8e+jH6NvQBWOTREHVg/rVYDPbtz2azbu+BYL8Jzp+6hXJOHozvnK62F1yrhlADtrZtxcaWjTj4pINx95678dmHP4tL2y/F6OhooIGQBl96RrVkVEsFCVe3tbWht7fXybsPYalGfmquQ443x9EcbsaByQP42YGf4crDrkSiOeEiSi4UDQYXgPCkDhVaCizZwNYg2xrZch/+1MFT8cj4I4HPeyz5GDa3bw50jKIQ0csmLMZeq8xX8nBYRnU4vHSJOskidCzUIFezcdc/cD36Yn144aEvRCG/vL81IXN6eWSEU4nqelIZ0cFR+MXmc+o1Yk0xNMQbsGdsD+4avQuXb7wcLZkWl7+3RozRpe6ZVXBUcq2trWhra0NfXx/6+vowMDAQyI
szT6jPV2rwM1oaW3DihhNxx747cM4h5wAAcvkcfpX8Fc7pOwcdjR3u8JKUpuRESwDkvOkccc5URn19fQGDvFKkudpoijThpA0n4ae7forzDj1vcZ2aGvGLsV/gJQe9BJ2JzsXf+39QeSaTcWx1JTQqzKgpJDLyKU86d3tRfa252KmFKTyeehznbj43EG1qZMaOZ7ZtLKM4RarInaBTqgTGetQeB0Zo+VWhapR9n2XPQDV9C6oZpwyegj9M/MHp4ZaWFuyb24dNsU3o7e1FQ0ODu5EvnU4jlUq5VCRlJp1Ol2QgU/55sQzlXKs72traAue0HGOsDowSchdyC8jOZ7Evuy/QWpgRcCnCsJJKmUZixUM8Hkd3d3fAIGslikU+yh1VG+TvP/Z9FIoF7Ojagd+P/h5vufUtOKTjELxk+0uQmki5ht00FjTA4XDYwUn0ZLkgqlT5UNo7mjCkEjAqhTWueuZVOOW6U/D+O96Pi466CPfsvQefvf+z+ORZn3Q0fC4iu2qpUgyFQs4LVBIAFTKhU867WCy6SJ7CZ+deCQRZKBZw/QPX4/JjL0ckFMFCccF5c1pWYDvHAAgcMM5Fyz18kXHdFBKWZOaQzkPw++Hf49qfXIuD2w7GWf1nYd/8Ppcz1egMQMBb5eA6A1iWDlAD0dXVha6urmV58UqZylc94ypc8e0rcOLAiTiu5zh84pefwGx+FudvPR+5VC6gjCgbrHtlhx++H/egWCw6A0CSnXWQam1OEQqFcNXJV+GV33wlTtpwEk7sOxH/dM8/YTY3iwsPvRCRmaXeycDSveMAAhGytiQtFotuLePxuJszeQjq3NUy9zf94E04d8e52Ny+GXtTe/HOn74TkVAE5x18HgqZQkBG6fSz8xkdZ43oeX41f8zI2EY4lZ5LO7LzWTw28ZjTbbtSu/DQ6EOIhWPoinR5yYocasB8Dv9aR8lXPfMqnHr9qfjgXR/E+TvOx9177sZXH/0q3v20d6Ozo9MFG9SXqnvIP+D/SaJTo6dIXTgcdtEwAxZfaqlc2Xnrj9+KM7edif6Wfoylx/CV33wFv5r4Fa7qvQpjw2M4cOAAxsfHHdzOOSlzmqgAndRIZLGPuqb1NB3G72lEXw765htVG+TUXArX3not9qb3oqulC+ceci7eeMIbUZguYKZpJsDG0wbqmvhXrJ4KlkoICF5vxRwDk/7V1sQ+fdPTcfPFN+PaW6/Fu297N7Z1bsNHn/9RXHbcZQHFVCgUAkQgFQ7tWGRJaGr0gMU2cGTaKsnI12SjnPGjnT/C7tRuXHH8FQGPUIvcteRK8/UUOh4IXU8lDemc6hkZq8x0tnTi7G1n438e8j+RHkkv6wFrmZmMDoClA6Qdmvi3XF+tsabDUa1xC4VCuOSYSzA6PYr3/uy9GJoawtE9R+OGv7wBGyIbMDo3GnBmADiCCA2ykhTZQIMyQnY/56qGoR5EkUuOvgSjU6N4123vwtDUEI7tPRZfOecrOCh6EJJIBnJpzA9y7TUFYzugUVFZ507lu5a5703vxcu+/jKMz4yjt7UXJ286GT+46AfoCnVhNDu6TEYtUkRHSBUuIx/tbqfISSVM9pXGvfvvxemfO939/+9/9PcAgJcd8TJ86NQPLWNa87N8sLWvJnctx18c9Bf4xkXfwFtvfSvee8d7sbV9K9737PfhJVtfgnQ6HeiSRgdUYWkGLETpNM+sQReNMbD8cpBKkU+O0elR/I/v/A8MTQ0h0ZjAIYlD8L7D34fEaAKPZB5x1T/T09PuTCpyobrbpsKYitWyUMvMr1V/Vm2QLzrqIlx01EWBcqCpqSkkZ5PLyh98ZBQedu0r64r/5RAprd4qbSvM5T78OTvOwTk7zgl8T9m5Cinxc3X+6lCo0Gl0zEhaN91XFlLppp25/UwU//cSLGPzTEpq8RFBlLXsm0+9o2IdFx11ES488kJ3aFmWlQ1n3fr4WKdWMakC5uChUShQvdxSbPFyRygUwpV/cSVe+7
TXOoeMNa5WxpXcRXSISkzbUFqHwtdDvFoZt+PKv7gSr3v665xhJdnJ1lpaObd5P34+ZZ1/a68TrMfcv/rSr7p5KEudECnnoU60r6WpjXStAi6lU2oZz936XBT/d/AyBeq76enpZfCozxjrV/vvtR7nHnYuzj707EB5IZt5sPTPOrjAEpql+lHr1Ckv4XDYOYIAlu1HNY5csVjEv539b44wNjU1hYmJCezevRuP7n80gCASsta0F88mdb7qTv5cKyB8Z7XWYKYsg8xJ8T5IHWqQeWsNc656H692u9J/q+cdCoWWNeXXZvfMjUYiEUdEUSWSyWQC811t7r7npECxbElrii1FXhmz2kMXWNxUW7OsJS8UZhIhyHQtd+5cd/4te/WybSnXn0KvDUW0HE3bNfJnlUIunF85c1eCn8qLr5e5lRNtkcjDYmVF94R7NzU15RjmRGn4bJXO3dY/61pbGeFcNEIAFg2CIhnKOtV9aWhocLlRO+9K564yo/XQelZXegbuBRWYXW9bWVDPueu6z8zMuJ7UnLOvvlR7GTB604YQOt+ZmRnHoC0HubJzX03HqIPGz2Qtupbh0IFQ59pXQsr668bGRuTzSz2iyzmvlcxd113l0sqL5uqVPa1MZjpGPL96zu1eKClK98Tqd527JU4qw553jds74qmzfQicVtAoYXZ2dta1/1RnhPtQSm588u4dxTLGnj17igCeMq89e/asz3197utzfwq8/hzm/lSb9/rcn/y5lxqh4qome9HD279/v2vp9991FItFZDIZbNy40UUj63Nf+7E+9ydnrM/9yRl27k+VeQPrc3+yhk/efaMsg7w+1sf6WB/rY32sj7UdZeWQnyqeyJ+T1w2sz/2JGOtzf3LGn9PcnyrzBtbn/mSNciPk9Rzyf7PX+tzX574+96fG688hl7k+9ydn7qVGWREya8X27NmDtra2wM+KwkzWq+XIhEwmkxgbG8PIyIhrN1koFFztJWu69E7enp4edHZ2BvrKlkMlT6fTGBwcdPMtNfei6eiiLDrOP51OY2xsDPv27cPu3buxf/9+JJNJ18uU9WjshdvT04OBgQFs2LABfX19rtNMubV05c5d528bgpC5SRbm5OQkxsbGMDExgXQ67faEjENtUcq94KurqwsDAwPYtGkTNmzY4PbD13Kykrlb2dF90LKWkZERDA8PY2hoyN3Wk0ql3BWdrBPv6+vDtm3bcPjhh2Pbtm3o6+tzd7CWwxSvdN2VZc2yJ16hyGv+eMfxxMQEcrkcWlpa0Nvbi61bt2Lbtm3YuHGja1ai/aqV3VvOWGnuu3fvds15yColq533GQ8NDWFoaMhdlWpbxHKwxIm19GyT2d/fj40bN6Kvry/Qf9vWZZZinVYiMyojZFqn02kn16lUCmNjYxgdHXXPw0sl+Nx6561eRNLc3IzOzk5s2rQJW7ZswZYtW7BhwwbX6IRtV7lXU1NT2Lp1q5vvarLu0zWsU5+YmMD+/fuxZ88e7NmzB2NjY+56SW2s0dTUhHg8jr6+PvT397sOdGzIwn+rzvRFYnbdy9ExOl+uO/XL6OgohoeHMT4+7ipF2ESjq6vL6UfbSKOaSzzKnTvnr3Xp2WwWqVTK3Z0+NDSEXbt2YefOndi1axdGRkbcPQp6aQpbN3d3d6O/vx+9vb3o6elBb2+vk/+BgQF0dHSgtbXVlXTZtffJu2+UZZC5SFTWOvTCBd5qY0sd7EJzsbhgPCRaq8nm72xjV0l9oP6OnTsPhy0rAOCEX2t5gaUa4ubmZncBPd+X7wcsNTJhswR2HKu0CcVK626NGGv82AZO+4Hrhdu2bpTNSorFpYsvVIi0plrvi13J0K02dx36HCxbYvkA91ufU+tgWZbFekA2ANGi/XINciVz55xzuRyam5sdbKZNWNhelcpGZUMvd+Dc9Yakctp5ljt3OrkLCwvufbWERS8OoVywLaa2vNVha8V5JvU99JKAcnr6liszKu/a7U/vqlV50A
tHKPO+chwAJctVKGt6BlpaWlwZldZCl5o311Gb9/Ced5Zuqq7UtrHa8pbzYIMKNjThPlP2VzPIdq1X0zG84ahYLLozqTqGcq89D9R5oryxBE3lSOvWbQ3vSqOcdVe7xPlz71l2Wqr/gs6BZ0Cfjc+qTaDY2Ge1C4NWe7aae1nz4TXC0bpYelOsuWOHFK3F5KtYLKK5udl1QGHbMl2gWor2dUG1jpLRJb0ntjzU1of8bNbSsbkCsHR5RHt7u/PCKai19u+187cCz8iHEcLk5KRbe3q02kYQgHMuKKicG9eioWHxdhPekOO7EKEez2KVFNseMuLhBfS8y1abDKhx00O1loP1z8BScwkeSO0ypM6N3jxk62a1mYaur+5JLUPXeKVe59plTtfSNjnRDk1UbESUtB8ADSDlq17DJ/tag0wDRmeeTltra6v37lveo0ydop2n6ExbB5tnsNJ5++Rda28Zxeve0FFVB4AOkTpGtm66FtkphR7OzMy4u6XHxsbcuZycnHR3erMNMtEHGkTbUx9YaiCi3RABLHN2ahl8Fqv3+dn8vm9P9e9Uz/JCodbW1gCKoY4JbUQ1z1GTQfZFa7wnmDAjN41Qo0KlVGjRaBRTU1PI5XLuth72f9Yoh55kNQ9qD4U266BBGx8fd9C6OhI0yLaBBh0KFoW3t7cvU3S1bpB9BhUsPShjY2PYv3+/g+s4Dz3Y2r6OkZC+p/adbWxsXAbzVaOMytkLdeII442Pj7tUx+TkJKampgJ33mpbVRbo17q+5Q5fC9J4PO66jzF6pqwACMhaNpt1kF1zc/Oyw1xPZWQdZYVx2faVyklbStIZLhaLgQtWbJOIVCrlYDquDZ+L+1Krg2EVq6ZpOC89h3pVHp9NjQudb54TDRD4THSqGhoaXMtEe3taJXO3TSvUCHNP1JG2bT+5topG2O5z9ZQdG2QxPcM0x8TEhEsLZDKZwHWF2gCJl0/w7m/VXfPz84jH4wHdqM9b7yDG15iK8gHAOZz6t+rMskUoZbytrW2ZjrTIRqWjbgZZu6MwumS0yY3TW5M0h9PU1OSuR+MVdG1tbctuXalV6HwRg+ZbR0dHXU6NnYC0/SVbqylkQyXV3NzsLsCmkqvHBpVaa/XaaJAPHDiAAwcOIJlMBi4DAJb6P4dCIdfa066HRsMtLS0BT72eEZzPWFAZUW5okMfHxzE5OYnZ2VkUi0t3k1Lp+y6LUAXOZ6/XUKWhkD4dPOZP6SBQDmynI8JmjKwrVfTlDNshSg2A9e55FnnpAuG3fD7vkCA6oFS6U1NTLrrhutAgMiKqZ9TvMxR6iQoAZ0ij0SiA4H3WTKdlMhlntMPhxTt7GSHTMbSKl2e6WqTIF9mrg6QcDwYnfB5C1mqI9eVrxVvNWq8UHWcyGcc72LdvH8bHxwPd0tT54Xyz2azjp7B9LPUOf5/fKwVZ16pndO35mRqgaETv+zvqx5mZmQByxJvRqO8tKsq1fMIi5JU8P03+s/2k5hE0v8DDnc/nnVHj32mUrDmraqJkX4RMxcgDQWMwOTmJbDYbaKavQs+DTSNRLC7e2mPvIdZ2bLUMX4SgJDQastHRUYyMjCCZTDpo1xos2xZTPUTOORRaam1Xqid2rc+iqIr1poms0Jmjd60XT1DxMxJTpatOUD29bR0KWdOw6qUnjJA5D1XEhI1bW1vdOajnGgOlHTjbIlOdYyp9RoQ0yAAwMzPjlKk965qLjkajy+7irhcEz+fSPdYzZpW69mxXfgtv2lKHm+9D8hLPfCwWcygZz0q1kLUSpNgCk/pH28gyktQoXS/ZWa1veD3WlzpaAxcin0StiHZaY6bIENcYWOIG6SVDmudXvaQOdj3RFftaaS/1WciRokPEM2zhaisflcp+VQbZZ4y1kT6jTuZElM0LLC0wYVJ68VTM6slbEkktgqdRmR4Keqm2tyyJT3rVGL/PK+o0ulQnQzemXobMOj6cu+bq6dAUCgU370gksuzSde4fB40XIy
VLdKjHKOVYULGrMdYUB+9IpnJlFMe8HiFTNT6aT663UdYoxHdJgcqoNYraR9k6b/WOkIHljHzlTmiPcDVqVDotLS1YWFhwMgQgAPOx7zCjNPb4rTWaXGkoUqbOGYc6nfy55rPpWPDOXvIrmJaiHmpsbFx2hWk1DnYpeWc1hCKHilhw/irr5NaQya4XHFR6EcNK8/Q5cdTJyk2Zm5sLkD8pIypr/D8ve8hkMgGEjqkfdao5F+4j9/3JGHZdtLe7Rvn1kvOKDbIecIVuCYmpp0cSDklP9EwJ9fJ3mDujgWa0pFAYFYUycCuNkG2uWw0YITxlVkcikcAVheFw2Hl7FFJLDuDwMfaqHTp3GmO9HJzlQOoAcc2i0ai7X9SyDamIleZvr0HzsQbrkQfngWdJQjKZdIQRzU8pWYcMU94d3NHRgXg87q7rpLNCx4LfB2ojA5YaCl/bW6p07y0EqJeO+CLkekeUpV6logXKDo2ezoXyQ6MxPT3tzmVLS0sgn1tOFFLusKkCzd3zsgv9uXWUqDOam5uRz+cdAseIXh1URW58VSDlPo865AqzE5HjueX92UqcI1LBsiGWhLKUKB6PuzuD60Vs9J1N1dN0IrjHdPr1elNC/5Rv7gsrDqamptDQ0LCsioWGl2dF75+v5ZmsDtZgoxKdoO+jf2u/1qrvqzLIPkiDXp6SE0iY0Lt3GxoakMvlHHmChAAqgVwu53K6VGYAnBemJIZKFBfnbY0ACWeZTCYA89K7JnwXjUYD+SY1yMpytPnuekSZK819YmLCMasJIfGgRKNRdHZ2oqenx10Gzhya3jDENaYBbm1tRUdHh6srXY3KX8kz2BTH1NSUq1UfHh52eeNkMomZmRlHuOHBZ813X1+fm2M0GkUoFHLGWOtMORQiq2Xwffhv+9J95+/w+S2Z0EKmaxEh++ZoDayS+1Yy3jby4d3JmlfTnFq98uIaTWktNEvklOdh0TSFc/P5xZuDFhYWkE6nXaRJBEZLB/U2olpSCmrkKO+pVCrgfGqkTiSIXJq+vj709vYGejR0d3c7ZzQWiwXu4q7WKGig5SMBKlrFeUYiEcTjcXR3d6O7u9vpmHw+v4wtzn2k46FkKE2ZUb6ok+iIV3NuNYBgJE69rjp7pfJI63TbVIEtnSpVPlXuqMggc+EsQUGvubOkET4EC8VbW1udp0Q2bSaTcfWSJFZMTEy4DeJhZ72mlhNVMne+nzVqWlLDqFIJHYSKwuGwg6oprKyLtEqgHpujc+ehJgGKESXZ7FrUzmilvb0d3d3dGBgYcAXp6XTaHSgahlAo5PKgrGfUQ19OXeNq87fQqRbsJ5NJjIyMYGhoCGNjY85Ro0Hm87BhCZsjsMEAr7HkYVeiks9DBmqHwOzf+wyz5jNt7p+KzV67t1ZGWedjI3lgOTSn6An/rSUj/BkHCWG+KgMiYL51q/QZaJAZGUej0QA5yDpFmouk8zw7O4tYLOYiZJ59+7z6rPUwyCrvY2Nj7sWyIRLmmLvu6OhwTUC6u7vR2dnpXkS99E7eWp1mS5jTlJim82hEGxoaEI/H0dvbi40bN6KzsxONjY2Ora6onRIa8/m8e0+ieZYDAgSdK3WEKxkWwSJqosbUpxtUd1i0RfP5+l6+c1XpqAmy1vwplQwXnwqHucv29nbXRQmAu8+TD0wvlYoVWLqAnoaCwtDc3Bw4gJXMm3NWMpdCo5rTIDzKDjjhcNgd3mQy6SA6H2RZDSyy0tx9h4X5HObt6LVSUSUSCRchJxIJF7Ekk0kUi0UHnVKZ0ujx0LM7ka8evNZnsGQu7XjF/BSdMSIV8XjcOXbsdEUojPsaCoVcxKTGRw3PWhG91CPXtdIzo3X3PuJcveejCknzqarAbQrIRsOKOGiOsVgsupxoQ0ND4Nz7CIG1rLcqR6ayeEY1ErfwIb/HCLhQKLiIUiMdBgOaI1bIuRpDbJEFTceRPMoae61mUBSAiBANMptgECKulzG2uW5bVma5AQAcsb
K9vR29vb3o7u52zZOy2SxaWlocNA/A2QjeTz49PY1QKORSZMoY1/K7atYfCMq/cg5KNQQp9z00faocJ18QVumo2iBbQpcmuxXm4cGJxWIuH8KIjAayWCy6KJlKiwefXhgNoa8Au9wHV2WibEcKGudK74fdcAgNccHT6XSA4QvAbbA1yvXIK+jal2KY8hDTKLODDyN81lGSQaqwHElRGiEzT8UItFqDXEpetAxISz/YwpH7ykPPfdBXS0sLAASaLJCdyv1U71pb9TFnVY+ha+JbH8qcErt8DQW4XvWckyoSlW1lhnNuyg7XvbIQojoaarjVyVBdYKOQekTJ/FrK4PucISUKad26ppkoLxbWrCb15AteyP+gvGvkqdAuETrdK80XW4JrvfSLjx9kSavUO1wnGk867/wd6tjm5mZMTU253DJTmYVCAdFoNEBs09xzLaiRdY4VMbEE4dWQP3UG9etKeeRqR1WQtd04W3PLh9DcEjcsFos5geJm04jopqhyJeyhjEetu+PnlfsMaiTUyFFxMEq2ByIcXmx3RyPF3+FBrzdUzWEjHW1IoUqdOe5wOOycH22lR8RBoTl+j3krRqGan1IPvJK11jW3LFPt8kSlpIQgKiSS0uhUWCeB70lojbJBog6fN5fLBZQasNTopdah72H/bSMPOiNqtHxMzXrlu2m4VGlaYxsOhx0zndEnjYOF/2nYOTd1eJQ0qd2orPNTberDpgT4PqX2UOekDo/C7zZ4UCY/iY22H8Jq+2Khf1/9MSFbRbi4F5oiUINI4zY/P78shcQ1qVVufNwBS/zTr9bRDIVCgTlbEiP/zWfl71gSYLVRsW/40hiqz8pdE01l8PlX+sxqRtURsk/o1Bgrq1F796pB5t/xILBWUMsNGhoaAoQC/b56JuUOH6yoUISFITR/EIlEXFSscFGxWAyQzerlseqcOUeuYXt7OwAEFC3rJUOhEFpbW9HZ2bmMXa1er0Y+JHMlEgm0t7fXPUK2LFOtU9f6Syp/wpFs7E8InRdchEJLbE42LUgmk66POtMQfFY6GLFYLABv1nOow2rzsFry5KsDtsqoXsaYEC/LZwAEIgOWxNGB4d+pouXPyMCnwVYjzTPIZ6XTRQPPOfFVi7JVZwDwG2RVmJqqsi1M1RgCcE4gGc4kNpbLoyiFCCkip8iQ5mVppAgZE94mOmcd6Vwu5wIGrgH1ka5Ttetrv6eoAyNgdgokCbepqQm5XC5wmQPPOgOr2dlZAAgEBPoZ9dShKnOaj7bpxFKf5zPGlHEbhNZjVF2HXMoYa3TJ6Nj30tpBGjx6TEo8CofDy2qEqXQrhYRtpKmJehLFSrE0faUUNMh8BoWs62mUNWpnziYcDrt/26YCAFyaoLW1NQBR++pfGXXHYrHAjU+EhbV+udKhykm7RWnZB6MEjY6bm5sdiYs57Y6ODsfU53tms1nXZW10dNTl9LWvOJUwySjqJdc7b2shUjpAAALGwNad8vfrHRUwOm5tbXUpACIFRFmYw7QpJ/03AIcGab7VGmVV1Cxx0fnUg1PB9/P926J4zB37+kZT7thtj2vDlA1lj+fAonKlhkWEfJ/Nl2UuUwey+oD7psacUSXPt5YP1ap7SkG89n0pH9PT045PMzs7G6ii0VagLC/Vskw+rwZC9YTj9e9sWkKfbyVZsukbhe+pQ+tVTQDU2KnLB2fYiFPJE/w3lbx6LoTN0um08xLJeia8qZGFdvAi3LTaoujnUVGxLlEJRCsRtNRQ8z2Yn1TDVe8ImUaTJT5NTU2IxWIBchSNreaVCcPb6IxGgO9vb+rRG4gqgXd8Q+E3lsbZ6yDVSSASQVILr5ajYmSExq5K7FI2NDTkGj9oqz6b8yIMuRZEKmC5UeY62+jMRsnqcVeaivENyg2jVI2W1SDHYrFlDXnUKNNp09SHpqoABPoDcJ+npqYCxDHbR2Cthk0TaHRq+yRoxGaJjVrzW45BtsihjY5t4yNWolAOuGeaax4fH0
ehUFh2djXlQXmxUGylzo9FDX2BiTXKMzMzSCaTiEQimJmZcalInnNNTWmuXNNfVp+qzqk2TWafyz5fue+niLDKvY/7UQ+ZronUZSNkToiHT6FcfdnOLABctENGr+ZdlLTDlzLyyjngpfKwVMxKAirVjk6hICXKFIvFZdT3eihUnTs/D1hq4M+ewb5mEwqZMgqzHZq4j77cnHUsfHmdcp7NRipUNqoctR5XIy51EMjCpDGmMrDtNqemppxcaakc38s+f70Mg43M+G8eYMKmKxllPeSaiqlGhmw6RmVbmdZ6DnR9yb7n33M/rKHjnG15ER1FLRHx5crXYviiYxuVak99YElnMW3DVAkZzeVcJamfbQmMalAVHdEcKoAAZM1OhVxTna/mwXledI7A8msKyxkWQVQjyX3k9wG4fSYpl/OnQ8Zn5YvPwDlZfaz2otbgxp5JX7rVymIp2fQZ9VK/U8uomdRF5ahKrpSnpRvNn9NwzM7OBu4/5obT4+bGc4PVIJeT2wGCkaZGBxQiwpkqhJobo1JVj5TztL1Y+Xn1Gqpc+TWfzweIZ4wyFVphv20l8dj6UZvj5Trb56HjU6nDoYfAEn5s1zB1eizTlZECn2NmZsYZY8JjLMFR46vvu9aRme+Zdf11ja1hnpubC9ycBNTOStb3sDCdza0RclaSIK+GVOVoYWBC8rpP+nO24NS9WEtkQuVZ4XNtWMSaX+WFaIWB5VGU2xxHUzS6BrrPXAu++PvcM1veCCzPtyohDVhqnGTRrHLTBL50HtEUEtz4Ig+DRFiea42c6XDoGbdRJedKI0ydqwFRrWijL5+velAdc59M6nrY6gQ6J7UQXu2oqVOXMioVU7dMa18EpoeXzGWF0hS6AJZ6zE5PTzso1fYPXumQq0EjHMpLrem98+/VKPN9lXWqMBGdC1uXVk9jbJ9BD45GJ5FIJMDALBaLTvH4cv2E+5UcNTk56fJA8/PzrikKGe/Vlg2pR6rKqFR5gz1IfA62NyWzmo1RaIxJauMzamrCV7O5Fvukz6xGmdyHhoaGQMTGF/Pj1qmjk1vN4DNaR9kiHVapNDQs1n6ro8rBZyHCwYY0atQtO3gt6605J5/M6OUrrPslJB+JRBxhq6WlxdW3M3dMUpfWyK4kMz5nQHPHWtZj0xRcE50/KyYUleBXLdmkYdSgQNHJcp1nqyNJIKUe6OjocMRJQtOMnFX/Wl1jP0M/S9v12jvFrUxWcgasPFgegeVv+OapDgOJsyS9EjlR0ms9jHJNjUFohH3RzmrGUQWHbEH1xFhgzofUw8UFYE6M9bflGmQyTguFxRaTnLvNhWuOmoLPiEDJUIwkLKy1VkaZ6+ZDI/iiMQaCZU6MCniAqXBnZmYwMTHhSoay2Sy6uroC+WbCmiSyleN4WAi3VKrD93sKfXJo6RSJL+Pj465+Wb1vdnTiy+bF65VO0H3RQdklygMg0L1OyS7pdNqla/S99X1rma86Kfb7lGOFsOPxuDO0OhfNEfOiAHszmkbIVGbWINfTKJcyxmT0k/SnbSqpczo6Opxh7u7uRn9/P/r6+lwTDu3TvFqEbKFy7XaltbbUkaX0JHUio051cLSZDgmRuVzO7aF1lFU3rDTUyeee8SzRAWY+uFgsOv0MBFn7+Xx+WYqS39PcMf9Gjb+9NEPLSCuVB35VPaI9D7QntwYFdk00gNMmLexkWK/GSRwVQ9Y+CMDmQnjwaBj0AX2GQ8uiaJR5NR3/nsy9dDod8KoIP5BcVWqoR0Zhi0QirpbXwhiqOHggOAcaZ76fwl7lwFr1GBZKtsLAQ8oDbR0NriGVLA1yoVBwgkuFbJ9F97ES5erL5dj31PkzutLWhpaYwxwy2/BRrpgPZHMU1jJXAj+WO6xRVhRD31+NVmNjoyO2sTxED7amemqJkH3ztJCm5papfGxXJg7Ck+R7cL217ziwdNFHY2NjQDespTG2JC468JOTk64LHOUEgHPMWX/f3d2N3t5e9PT0BE
iE2it6NeOgupFRrOauua6KyPFMcj94Jql3GAApMY2yQwNJTomtiPDxQUoN1Q38v6bsCLkDcD3A1alQ6J1Xd1LvMICz+p9yx0CMTrM2QKnG0NnAsVRvbuo43Q99dup2lmD29PS4F7sFqvPwhEPW9mE1N2mNmkZydtj8lSV8EW7hhtrCejWkPu/G93m6wMBSFOUr86BHr5CbLRni+9ncxxNhkPlM+hVY2hv1Lm1+SOEo1pjm83l3uAmrEYpXZ8OXdqhkvvbv1fCwXprvSWSCz8JyEL3I3Taxp1FTx077Fmvep1aP1hpilWcqFH6lctLoTa/QTKVSAf6ERX7WwijbqEjJV758rxoGNlfh7wJwaQVgKcfpc9TrZZDVsVNWs+0CpzejkYWvHBI6bV1dXS7yaW9vd85buT0GrGGysL06/mqMFerUM6HRoeaalRhG5Io5b4VQlR29mvOsZ1OheYuKEIVqbGwMkDGp90mO5Tro3hCRo+OsOp8Io653LSiWz0mjTNCR0Zy+7p2uCdFPbaPMEkyfg1/rqKnsSYelf+uCrJQ0tx6Tj4RAReAT7lK5Ct+gwHMuFAoqFc2FW/ajrUGjgrRCpJHRE2mUFQpiPs8KPD09Nj0Ih8OBNnY0gMAiSSSTySCRSATY76rEy4kYfMZKmb1qNGl8+d6WLEIonRExc2ja+UeNmYXBbKRT7f7o36lzwbXWqwHZQ5w5N4UBmSfUu2Y5d16aoGhTvYYvomcqo7Gxcdm5BYK94IlGkYhJxCgUCi1TzEroUUNUq1FWnaIRKR0dGmJf4xmeC9bxE0Hp6OgIIClag1+uA2fXy7eWXHslhdoz5avYUAPDyLSxsdEhLKlUCq2trc4ga4BTjsHg51hDmM/nA73kuYZaE839JYfFvocaRupee3Z8lSqV6lFrf6xsEHZXJ963P6pP6WCrA2cd/Hrp/KoMsjWiluTDxQCWhGil3JHPiGuu0bISfd72alCYNQz6HDTKlpTFaFFzOmqkNemvrO965f0qHSsZPrterGlWogkPiyISjDKU/auQ92pkOiAIi9L4ch21njIUCrl8MVEMRUBmZ2dd3kwhdaYftFRK77HWu53rGR3bqIKHVi/1YOpA7wenM6fktFQq5eanfdR5h3g9o2TOX/9NJVnKeebZpAGhs0ODTERMYUxrjH3kvWpGKWNsr39l4xm77pR97UqnlzYwh1kpAVBlwqYd9EVDybVULoySpPiVsqLQKwMI5qlJWuOdxNQDym0pZ93VweffUK7j8bhradzU1LSsmYzNc6s+BRbPMpEVDV6s8a3GEFv5KJXC0NpoJXb55FLn5+sBb2XErmE1oyKD7BM2y3bm4WbEokZUIQx61DSuWkJFY6Cdm/Q9bNRdyQFXiAiA99BQQPjZurEKOfGw2ChZvdp6KtFyny8UCgUMBBWpttlsa2sLNEjQKzO5npq3V2VdyUHnetM5iMViLnpSQQ6Hw64JPSFQRliUBZJ0stms8265viTrqQfL6Fi7w/mg/GrXWb17ri3JH729vZifn3flNIxgeM2eRv1UqFwPrlMikXByr45PPeVJHUcrq7qvNNg0aCpLjJAVniTSoqhSPUhdPmNMOJItHJPJZCA/z1QMDQQh6o6ODvcimmHTGpXKiRpDdYj5suQ5PRckHuplEtSpuVwucGUsyVwLCwvubvmxsTGnt4Cl/vR0Rstdd5UJdXQVWWtpafF2dlPHiDpCDTLRRYXGfXtbjQNqgzorH4pE2R7ipaJk6hbVYT59Ui9EtOIIuVTeSaMPXRT1nmyECywVvCscTWyfxiESiZT0YiqdO4caAn6GGmRgUZnw59qAgwZDhdV26tKx2pxr3Uifc0IB0s5eCjtpOQaJL5OTk0ilUi5azefzTtFRARImswd9pWfj4YtGo+4w0lnge1ApMb/Dg61NT3igtNyDBowEHSpWlmspSaQeuSk+E5UFn61QKLg5dHV1uTwZnZ/x8XG0tLS4+7c1FZLJZJychUKL+U
2WmWh5DPdV972ew66JdSrVIVNonoQ7Je5wzsz98zlqMcYcCkcqksOObWNjY65JDGWFkV5TU5OLjJkPZI90SyiqRtGqMabTooiN3mHMrnMkDJEkRMPM3w2Hw5ibm0MqlcLQ0BAAOPJloVBwDl1jY2OAv0O2vCItlTyHRX9Yf6wlcfrSJizqfFC2VX/6ApdSwVa56+8zxtZJIHrCyz20WYlNKfArdanmuhUhrFfFBlClQVbPyU5SWXk0ZDYvywPCSNr297V3qrIbln5+tVGO/q7NJ+t7kqrPZ+HGaqcZzXfaAnFVGr7cjY1Mys2D21EK5ldYnR6/5ia5H/SuI5GlmmTmlRnphEIhd1AikUgAcl7NSdJDzf9rqQM9aPXCWdJBh4DRse1uBcD9PaNTGmNVaOrNVhP1lHoulUOeBUJ7hEhpuBgl0CHlFZOMcLiXDQ2Ll9Nrq0F20FJUBkDVMmOHTxHZZ7W/o4qaDhU5ADw/luhZK2Tti47VIPNe7ZGRESSTSdeNS8mXiqLQedPI1JY3VapbNGBRg6xoDaNf6gY6XySUsdyKDPBwOOzy9YyUNWVDI0h2NCPYjo6OFSPAcp9HHTB1ghXdVAKsIqXU8XNzc460SDuhepH2QvWY6tLV9kFTl74qIMuy1m5tpWTShxTQ3vmal9QjSq4YsgaW37FqoUBdFB+jlEoZgGPvMTdo2W9cZLsotST/+bvW89fvq3FVQhlznKU8IysQpTbavmhAKxk2UvApP1vyxLUjm5bODutKm5ubXcRDr1YjI+adyyXU6aHm/9WJobwAS1E9HSUqWjoEmodkyoPPRPiUEQGV2VoYY649v6qis/NRx0cjdeaudH0BuAiOsBpvRyOzVSMLno9ansHKp56DlYbKrmUEqyJWhnk9Inp7vqyyTSaTrlnM7Oys0zW8ZEUDCMqGT4fVmspQ5a08BjXI5BcoCVG7YVGGaRBnZ2cDLGoAzqjzrNAp1FukakEmdJ/1uRQFsYHHwsKCQwE0WCtFBOW8rL7izyqJkNWoq0FmlMyX7emvQZFPti3xTPVJLfJiR9URsp0khZpDjfHU1BRSqRTGx8cds5e5RxIT6N0y50ODrPi9Cmw9GLP6NxoF8Pm4SbqhNMj08nQzFQkgCqCQjH6W3cxKlavm65gj4fx8xlIPlK6ZzZ/rfPXAAXBNVEqVxay2zsrS5udb1igPqyInREsoTxalUdKJwtWEzdbCGNsozdbj0znyRYVWcXANCKsyF9rW1uZgSIU6uVczMzNVzx+Ad/9Keft2vj7uhzKsfVUXtcLVdt31XNLZ1xInppy4dnoe1RHns1ULlfrWrlQOWXPEnJ/KvOoPjSRt2ZTKlToolvSq+1DLmvsiTxshlwoK7PvZ97bBQzXRvM8Y81XqPgQ+h2+eqqeUjGcNso2Qax1Vs6x9tHq+1ECx3/Do6KhTIoRhuME02hMTExgdHUU6nQ6UUtAYKyRJ2nk9mjzYv1MYWSMc5hxCoZBjM/LgaQTNWkfttKNRqhpGGkftRlXOoGJaWFhwDg9bSOpVg8ASoqFODaE5HiJ7iDXCZ/ROpaAHrtI1XsnoM5qho8PP4VDGPSEtkrmYG2cNqUbIa2GM1RlSzgNz8ul02rVpVBKJso6tI0Ly3cTEhOsOlc/nMTU15Z5Huxixz3E1wypADnWWbMRbKjfnKyWxtba1Ki07B+3WRmOsXc942QF1EyNVyprlrKh8KNJXzXx9HBuFrbmPRKIKhYK7oKFQWGp+k81mXYQ8Pz/vCGsasChaxM+16bdahs/xJ2vdVr9YtruWJVoOkepFdfCqMcoWrVLnhPP09RTX+ngO6nR1rHyBZ70QFTtqYll/4fdfwCd//UmMzoxiW+s2XBS/yD2MLgYnOzs7i7GxMZcT09wDBZK39pDQpaxTXlRPIoZeHF7uYf8/d/wffOP338Dvx36PaEMUpwyegg/85QdwWM9h7nd0Y9UY09iFw+FA1yXOkb8XDocdHKUCwgOuuTc6FO
VEO5/65afwqXs/hceTjwMAjug5AlefdDVOajsJo6OjGB4extjYmMsvUdFSOdBwsdaSOSElv1jmpEat1kOvJifFNfvAzz+At/3kbfhfJ/wvvOvkdwWcG6YzOLSsjkaZzGPmncmaZR9i5pHrUXd8+67b8aE7P4T79t+HA9kD+PqFX8fZ288OlJ2QsU7DkEwmXdmN/l+Vqda0A3BkQR5+7g17KvNFR7TcCDkzl8E7fvIO3Pz7mzEyNYLjB47HR/7yIzi+73h3BtVZ5HpT0dvUiEaldD7YkpIscos2VWOM7Vk9+aCT8Z7T3oPB1sGA4ldDrK98Ph/gGNgIdHp62jnWKsuqZCsxyhY9HJ0dxdvveDt+tOtHmFmYwYbmDXhZ68tczlpJTqwcYMc2hbmpT7LZrCOskefB82nheJX7SmXfystxfcfh3Se/G9uj2x0hStOK9mUdNUW6qAvpgNpUlI+JX84e8Pdv33U7PnLPR/DA8AMYmRnB+459H7bPbw84zXrTFnWc3XsNlqiTLF/KVyte66iaZX3zozfjXXe/C+955ntwSMshuO631+HDIx/GleEr3eTUEFGgtBcyN8LCT7wwnA+vNYOMgBglV9q27LZdt+F1T38dnr7x6cgVcnjrj9+KM794Jn772t+itbHV/Z5VQFpuwwhOIwnmb1hnRzKJems0kL5aNhJ7VhoHtR2Ef/zLf8QhnYcgl8/h+vuvx2XfvQxfee5X0JxqxoEDB7B//37n0Cixiyzr9vZ2Rx6Jx+MIhULLrkFUj1YPfClyTqVw2L3778Vn7v8Mjuk7xuWllSSkjgyNAuWFB52OHg0y5YNddOiwaS1ntRHa1PwUjus/Dq86/lW44MYLAk4WDTI5EMpYn5ycdIZKbxqyCoHrTIWlzzs1NeVuINLov6WlxZUWrTZefcur8fDIw/j8iz+PDfEN+OKDX8QLvvwC/OIVv0BPc08gmmU0oIpGkQybsyU8bFmrVKI2HVLJHuhZnc/N420/fhvOufEc3HHJHcjP55d14+JXOgaUU/IiLIrF1BmAwPNr9FPpoD5Iz6dx5n+ciWcf9Gx8+YVfRlOuCQ/sfgCYBKbiU0gkEs5ppnOgBlbrXmm8GLSkUim3ztwrIgBWr1TDFqe8fO5Fn0NftA9f+PUX8NJvvxQ3/9XNaJptcjJtjZuWPyn6oJc5UDas0bWRsc8orzaKxSKy81kc2XUkLth6Af6/n/5/KOQLy+al+liRSxsZK1/JlvfanhVPOmT9yV99EpcedSledvjLkMlkcM2R1+DOsTvx6+Kvsalhk4taKUiEr/kAesjVaFsjYFvcWVZkuV1oOP7r0v8K/P+GF92Avv/bh/sO3IdnDT7LC88RduEhpsHw5RkYWRDaVohPGYpkO5NYVU60c+5h5wJYchb+4dR/wGfu/wzuH70fx+aOdbnH8fFx1zwAWLo/ubW11RX2M3/P/BQjPO4VhRhAgOVe68jOZ/Hyb7wc/3bOv+F9d7zPpSQ0r05IUQ8ya5Dtnmv5DSN/Rsf1ukjirEPPwlmHnuX+byNGJS0yYrQGmRAeYWurJPS9KU/FYtEZP1Voc3NzrtxotTGzMIOv//br+ObF38SzNz8bhUIBbzv1bbjlkVvw6V99Gm884Y3Lmt1ofoxzUjkmAVCjUTqrtgpBIwlVXuUMPauFQgGfPeez2PRPm/DA8AM4Mnakt1zSVkMQzVKDrBGyT5b0YoVKYWv+7ofv+TAG2wbxmXM+4wxpW74Newp78Pjk44jFYoGby7iuWqerCBodNs1/Ktqm/Bq9A7xSdMjKSy6Xw1ue+RZ859Hv4KuPfRUX914c6HhluSu+UiiVCw4tMa1nhPlXW/8Kz+p/lgtwLDeAe6SG13IGmLu3OWM1wmsBVwNVRsgLhQXcP3w//u6kv3OHrqW5Bccnjse+zD4c2nqoU/40tCz1AIJ5CV8Oi1AkexDTADNvzNxgPaKf1FwKANDZ0hmYm87R1kor9Aosz8fNzs46KEzfo1
gsBuAz/oxdbyoZ+UIeN/3+JszkZnBMxzEojvrp/owwWSJBxRMOL9aHRiIRzM3NuQhD84B8RkvgszBNJev+uv98Hc4+9Gz81fa/cgZZUw6MaCw0xAPA/eb/yRinrGjf6lrqSVcaRSy/Kk9JI0QcGD1ScVmSmubvuU+8DUrLpJQoQwXOi1VWG7lCDvliHs0NzQGZbo404659dyG9Pe0QKU2p+Er4VMESCSAUb40Jc/u2ZWmlRlkHz2p7c/syEo8lqnHe1pngHtlSTUaZdJQtEa+SEQqFcMsfbsFfHfxXuPTbl+L2Xbejv7UfL93yUpwcO9mlV/g5NLQ0utSHnK+ma9RJDofDjihGfo3eUKVlheWeAcpLS0NL4PstkRbcP34/Lu69OCAPimoy5aXngvKq6Rg61dospRTxslI5sbZAmeF03PXsUAdyzTUQVIfS6iA1xjrHWnVMVRHy+Mw48sU8BuIDAQ+tJ9qDx7OPo729HdlsFoXCYktDn/ekBpmDkC+9PAouW9vZJg+23KLSUSgW8Ib/egNOHTwVR/cdHTDEms+wbFpfTo3KdmpqyvWTVaa2kpai0ahTxDTI5cKPDw0/hJP//WTM5mYRb4rjs3/1Wexo3oFd6V2u7IfvRWNLxcSaV+bwZ2ZmXITMemRtM0ilqofHp1zLHV99+Kv41YFf4Zd/88ulb4aWMxoVKrVQFhBsTqPQnjptqozqaYw5LPtUSXYaQTA6psHW+nqVAVX+c3NzAeatGmOS+KLRaFmlconmBE4+6GS89/b3YkfnDnQ1deHLD30Zvxz6JTbHNmNsbMzlBBnZWs/fyjt5FRodM1onpMczTERLCZjV7EehWMCbfvQmnLzpZBzZfSRSqZSXAav5UwDOyeM59UU7PJckWtVSJsSxc3InPn3fp/GGZ7wBb3rGm3DXrrvw1jveijcd/iac0HmCc+w1PaOX2ABw+6HRmyINjY2LrWg7OzvdlYC8z5nNTiqNkp283PFeHN5zOLqbu3Hj72/Er0Z/hc2xzcvqcDlP6hDNE2uOmMEAz2pTU5PXebDlRBzlzF31AufW2NiIaEM0QPok/4Q6W9NzfA/aNG36pFF9vQ0xR9WkLmAp30LPo6mpCZGGiOvhy4OpDd6VtasCr5EO+8vqdVfMeeoNLLWyZ1/33dfh4ZGHcccVdyzzuGmMVbDozdLroxBSCWj+Rg8Bv2ouV+tQm5qayoIfAeCwnsNw/2vux8T0BL72m6/hqtuuwvXPuR6JRALd3d2OWcrmElSW9FAXFhZcDoq5fGVPMjrmezCiJxnMKtdyD/qe1B783X/9HX542Q8D3ncIyxm4luFOh0dZyUCwU9oTGSG7uXs8ecoOeQfaeYwoBZ06q7g4qJDpRKnTNz097XLj5RjkYrGIG867Aa++5dXY+omtiIQiOKbnGLxw8IV4cOxB7N+/H+Pj405ONGdv98MaZU0pFItF1++8sbHRyQvPcjweD5yLSseV37sSvxn9DX74sh8Goh51yLQjFtdbUyBMI/HZ1PGwjrKFOqtxIJ628Wl43/Peh3w+j6O6jsJvRn+DWw7cgjOPO9NxJ9itS+WC3BPOU+FVvqhv2Bmur68Pvb296Orqch3I2trall3JWM5zfOH8L+BV334VBj82iEgoguP7j8f5h5yP+4fvd6kSOoqEhvWcatMglWuNjltbW53zoD3EfUhKpcaYpYLAYs/1REMCXV1djnnPrmN6Y5zqFuoVXz7eOpT11ilVRci9sV5EQhGMzY25/GIul0OmkEFftA89nT0AEDDUyqqjsBMq4AFjHpDeHj2/np6eAKvaGuNqxpX/eSW+8+h3cPsrb8dBbQctg52pfKg4NedNxQosZ+XRUGkhvEIfvEeUQk0BKbcOuSnShEO6DkGho4Dj+47HvfvvxVf+9BW8cccbXVet1tZW1zu5paUFmUwmUHLDKEvz9looHwqFXFogHA47Up12D6rU877vwH0YmRrBiZ8+0X0vX8zj9l23459/8c
+YvnapW5WFR7VUgUqL+69NFdidS2/qqTcLkkMPJCFFzl8NFnPLWqaiCl/lTQf3hs4a30vz6KsZZBqTgzsPxg9f/kMkp5MYTY8iVojhb37wN+gJ92Dfvn0YGhoKdLbS/CmwPA+njqrmi5laampqQjweDxDsSESrJqd/5X9eie8++l38+LIf46DYQY6LojlfLSdif21Gl8DSfb5K4tHyIDbe4BmpBbIGgA2JDTiy90iHkjU1NeHIviPx3T99F729vYGaZGAJniaqqDlXiwjppSnt7e3o6elBf38/+vr6XAUK4WttJFLumm/v2o7bXnkbsnNZTE5PoqelB3/9jb/G1rataG1tdfK7sLDY9pUohLa71ZJFJcqxX0A8Hnd3CjOatxc2lIumWKdFOS/Nzc1oj7e78wMsXYxCHgSNslYbMG2qUbLVd/XWKUCVOeTmhmactOEk/HTXT3HeoeehWCyisakR94zdg5cOvhQd8Q5naGmIKXDqXXHBVbFqeRMhGHugbUK9koUpFov42+/9LW7+/c346St+im2d27wKp1Ruis/hq7VUD1bZjwrTWAJNsVh0XXsq2QO+iigij7xr1xgKLV12QYdHladCpxqlawmAlq5paREVLlvgVWLozth2Bh76Xw8FvnfFt67A4T2H45pTrkEkHEGukHNrbdmajBStx635IWWva318vQ1yCEsHki8Ls3O+Nm9sGaQ+NilfWkvJWvepqSnn4JY7+H7RSBQ9zT3YNbILd4/ejZe0vQRjQ2MuSmbJkjXKpZiufG4iYVSEDQ0NDq7WnuKVQtZ6Vn9y+U+wrWObK5lR5qvde764D1xLjXapuPnSyxJ8DlKl49TBU/HI+CMBY/p4+nFsbt+Mzs5Od0bJGyD8T6eBZDSuOWWdBFemZ1jqx7ucWfLH37GOcyXnINYUQ7QhitHsKH6y5yd4+zPe7lIljI412KJu1A5YwFJ724aGxbaw9rrLUl31Kpkv15lrxdHU3OQ+k0ZXAyXdHy3XY35e4Xldw7WAq4EaWNZXnXwVXvnNV+KkDSfhxL4T8bF7PobZ3Cxeuv2lwBQCyp71flwIzZ0orKvwN70oZVT7ausqXYzX/efr8OWHvoxvXfItJJoTOJA5gGKxiERjAo2hxmWMPMu65r9pxFShqmLmhlM5LCwsuAbvaqgZXZUTIV/7o2tx1qFnYXP7ZqRmUvjyw1/GHXvuwNfO+5prnKLpADI3CS0SClO4kYdHPXIqKFVenG8pAV1tJJoTOLrv6MD3Yo0xdEe7Xf7errct7tdmJOoR26J9JWDUA6rOzmfx2MRj7v+Ppx7HgyMPornQjAQSbj4clkhEx0Llp5SR498yUtIcHNEUEr/KjeB+sPMHyOVy2Brfit8N/w7vvPOdGIwO4sTQiXg0/ahrN0mDoHltO09VSBr1qAHxMX+rqYu1Z3V4ahi5XA4toZZlRENtCavOmK6fEnfUKVV+ixrjWiLkq555FU657hS8/47348IjL8Q9e+/BdQ9eh39+/j+7Np7AYm8GvV+X5WycA7/SgdBSUL3ZzPbmrgbF4vj+Y99HEUXs6NqBRycexTU/vAY7unbg0qMuxcLcwrJ2mIyQKTcKWYdCi10IibbRebMXaFhiVzVnNxQKYTo3jT+M/cGt2/7p/fhj4x9RRDHAbNe5amWAyoglc61V6ktHVQYZAC45+hKMTo3iXbe9C0NTQzim9xh86YVfwkDzACYWJlbsYxqJRFxEqBEGf8f2mlU2ZCV5Bd/41L2fAgA893PPDXz/s+d8Fi8/6uUBg+wjdvjyafp7nBs3kkOdDtt8Q73JlcbI1Aguv/lyHMgeQHtzO47pOwa3XHQLTt1wqouc1PgrwYWOkDoU6s0yMuA8FdpTqn+tefuVhs8JstCub62VELYW5RT37r8Xp3/udPf/v//R3wMALj7sYrznpPcs+wzrxCk5sByykM1davkRndhyDXKxWERqNoW3/+Tt2JfZh47mDpy+4XSc33Y+9u/c73
gQ9o5Yu/Ycuta5XM4pfY3mKDO2XlOZ2+WMUmf1k2d+EhccfMEy2JnGgZ+jeoZrYREjZQNbpKyW8fRNT8fNF9+Ma2+9Fu++7d3Y1rkNHznzI7jsuMucs8Nzag0S15NwN9dU4XV9WRRupcsPyhmpuRSuvfVa7E3vRVe0C+cfdj7eceo70IxmFHKFZYZK11dlnUQuyq9FNnxVFLXq+Hv334vnff557v/v+Nk7AADnDp6L129+/bK10s/XHgCKeNbD7pQ7yjLIFIp0Oh34/uWHX47LDrvMeRos9dCWabZtmq+5hDINFapkBMc8HD1eUtXt4eb89DDZuaeuSi17Nh7UdDodqHllbkGZeHqILQEMWLpDk5/L4n19LoV2NLe+2tw/evpHgdOXfqZ5HEbCejmHRgC+GkEt6VKPVpuZKGysOVFtgKK3Fa0mMzq+fcG33e/wWdh6Uut29Xk4V2Ud237G4XC4pIzYUY7MnNh1opMbhf55KQoJi0py0lxkqTzxakONss7LGslSc+fvnbnpTDzvkuc54zsxMYF9+/a5OatzphGjzzjRMNDZDIfDASRAa/anpqZcZEQDRPgym82uuu56VtWYskxPO6TpOdUcpmWxcz21TI1dzyg/2WzW5egJyaoitjJTStZPGzgNP3/5zwPPwHuMiRyqzOicffOm0bBnV5udsNySxEybtiln7i8YfAFe8MoXAJDSvvkcMnOZZZ3p7NnkM/Ar5UVlRGVFm7vovH3OdDlzP6n7JCTfkHSyQl3O/gy+agfrNFPOdc7aa4B5f6KNikSU0jc+PeMdxTLGnj17igCeMq89e/asz3197utzfwq8/hzm/lSb9/rcn/y5lxqh4qome9FL2r9/PxKJxJqH7LWMYrGITCaDjRs3Ok9lfe5rP9bn/uSM9bk/OcPO/akyb2B97k/W8Mm7b5RlkNfH+lgf62N9rI/1sbajrBzyU8UT+XPyuoH1uT8RY33uT874c5r7U2XewPrcn6xRboS8nkP+b/Zan/v63Nfn/tR4/TnkMtfn/uTMvdQoK0JOJBZrLffs2YO2tjb3fWWyJZNJHDhwAH/84x/xyCOPYNeuXRgeHkYqlXKMWdbaso9pV1eXa4vJftVst8cbe1hbx5ft9mPZj4ODg26+K829KKU0fAa9sYfN84eGhnDgwAEcOHAA4+Pj7me2LaXWUfOaw4GBAWzfvh2HH344duzYgY0bN6KjowPRaHRZycBKc3/88cfR2trqOjWlUil3m9DExASGh4exb98+7N27FyMjI8hms64Ang0TWMeoa2aL21lb2tTUhERisd0cOwCxxZ3vXt7p6Wls2bKlrHVfbXBflEGazWYxPj6OPXv24NFHH8Vjjz2GvXv3Bu7NZrlcIpFAT08PBgcHcfDBB2Pbtm3o7+93605GJJmnmUym5Lrv2rULsVjMsSxVNiYmJrB//37s2rULf/zjH7Fv375Acw2t6WZb002bNmFwcBD9/f3o6elxHel4BvT+5nI8/krknWtbLBa97PRkMon9+/e78/vHP/7RsVL5DJ2dne6+6d7eXgwMDGDTpk3o6+sL9E72tbatZO67d+9GPB53bGSyw7PZrGOI79y5E48++ih27dqFsbEx1w1NWyDyc7Veur29HRs2bMC2bdtw6KGHYuvWrejr60NHR4e3W5Sv1MXOvdw112oF9gIfHx/H3r17sXv3buzbtw8TExOYmZlx1SfNzc2uDWlfXx82bdqELVu2YNOmTejs7Kz4zu9y585zSNknU3lsbMzdYsbLUyYnJ10ffLKu2UMfWOyW1dHRgf7+fmzcuBEDAwPo7u52Op/63vafr2Xu2l97fHzcXU07PDyMyclJp7/1vma9M5lfbe8DdqTr7u7Gli1bcOihh2L79u3YuHFj4B52vfyolJ7xjbIMMheHBpODipN9dZPJZKCXs9Yf0hhwg2zdnK2dsw3jtTXian2UfV1UdO5qjKmcCIHMz88H6ve0cYmtv7X1abYmVhud6G
0sPoO82tzZCQmAa/ihys+uHY2U1jZS0VuDzK86Z1/nI14YwCYEsVhsWSP/1WRmtWENMlsgzs7OBrr5aK0xD77uA5tEaAOFlVqvlpp7LBZzJVTFYtGV3XDNtf5Va2L5PlobrfWxutZcUxq0Sut1y113GgeuLedfKBSWyZNvfXz/p2zRkdMLSPj+K5JYSsw9Ho87h4zyxbpnWz+qcq912zpP+3vUKyrn1C/aEXCl2lMrO6XWXEvl9EIGX522Pb/6O1pzzK/qZGrnqdXGSnO3jltDQwMKhYLrzKXrziZCbLmrfSRYRsn+/bbmWvUjO4+pbq917ixZm5mZcfpTnXF7Jlmeqv3CrYOn+l2fQeWnVC99nXOpUXVjEB1UiPy3NkEAlhQ9D4x23aIQaY0h61r5Hqpc9SFVEVcyV2uM2bmKddTJZNJ5f9rrVBtpaF2pdQzq2ZSCcwaC1/3pdZDaZYiCYw1vKSOsxtjXVEP3U18U/HoP+776/1JzyefzrkZTb1CyNyXZ5iKrzV8PN6M1rfcm6qO3lqnhIuLANSYSo404WBO5lmvqeyYr/7ZrFZ+bzR20FrO5uTlw97PWGuuZUBmq9Iz6onmtIdU117roUutnHQp1+PiiXrLd92od6mQqKqF3CusdzrZenfNRGWT3Q+pO1sZW2nLSN9dSOpL7z3mzgQw7EobDYdeYSJsdsZUqa7sZcdN40yCrzahlWPkhOqH3SduWtrruNLal9BB/1+phW4tdzXmui0HWCWsnHHpI6m1zE+hF0LgWCgUHPzICpMIClto5KqTn865Xm6PC1GqMaYjHxsYwMTGBZDLpisl9kIztY73WbdV03oRTfJeD2/n4vDP7VaNoeuEKG6kTsBZdjThKvY/umyoJ7p86Rw0NDcsaomhTlMbGRhQKBeeBrzZ3fh4bOTBlwGsHCVH7jIP2SwaWmqik02l3DuLxeOACd3U+1kKWLISqSlYv8NCezkQI9OxpK0cATi6ZlqqlBSKAwD5bI0YDRnlnukg7AaqDoL8XDoedkta7q7V7FJ+znsZBjSl1DVuWssmJNlGirrSOyczMDDKZDKLRKEKhUMAZB+D6Y1cTEFhnTXUk14mX1mQymcBlGIRiqRdURwBwuiWXyyGTybif04jTHtTiDKk+0mfQtAebmmiQZZ1qnmHblc62JaaTSgeFAaMiBJXqyDUxyBrFAHDQgULTDO8ZQeghAZau6mpra0OxWAzAHnpxQrmCpwuj3rHm0EZHR3HgwAGMjIw4I5xKpQKXsFNp/f/tvXmYpFV5Nn5XVe/dVdXrdM8wMwzDjMMWERRFTIxEomIAcYngEvhCNCYhxrglYNQoMYrG71OjBiVRMIpG0SD6U2NcwSWRRTGAIhJgmK33rq33rnp/f7T3qft96lR3LW+zePW5rrp6ppeqc857zrPcz/08j41TqSCwQqhZwerzFlhViIdL699aA4EC1MJZCqlSoBFWIgwGIGQl+5Rys2uz67QKWM+UGgb0bHgOKBw0jmUrldUzbxXotOopRGm0ce951i0MpkprcXExVEErkUi49nnWS95oZWyRIQorVu7SSmg6FIXgM+HekFNBZECVZD1rskYDDSHGLe15p8GjEKLWsaZABuDmwrUXCgXXLEXDNW1tbU0/C+uksMKYxmJ5lmZnZ0NlS3medS947zOZTAi1YNlSDZk06hxY9IQykobo9PS066ENwMllbTrDz1Ungkayvh/bYbKbHEM2zcgUn0Ghcl57eCsqofKcDgrXx30GEDr32hKV50jLPhO5q2c9kUHWPiEKlIU+Yx1qRVDoUyFTaPJQsqQdA+m9vb1IJpNYWVkJYfO1Xpq1Dlsul8Pk5KQjcbFHLPvZWlhJS6xp/Hgtz6BZKInzVsueB1uJRJyLjbPZeenv2dgU42iExBQGbkYZqwdov1qYyZYI1JeWICUSowrZp5QVuq4VslZhygutXoKSFbmXViFxn5eWllxPVnrrvb29FYQkXuKNUsrq9dM71j
KxGgZRqN/Cx1bYUWixaxvveqNzpCygwieZUuFSNSYZjyfyFgSBuyv0jPn+9DZVkPLv2aQlKg/ZGnUzMzOYmJhwRl0mk3H3WEvY6nuoQubatMENERdFZBoNFfig6rm5OWSzWUxNTWFyctK1wEyn0460lU6n3d5TLnPduVwO09PTTtby2bS0tLhOTM0YpSo/1lLK6iVbVErjympU0JjjexI1su+pbTEVtq5nROYhc1hBp56xL9itZBlawzyYtFKSyaS7hNUUQq0L1wviix8TpiaUpAJdY7Q2hq2KuNq/m9lTa21ayFpbh/FCKARNlMIWm9fno70/6Skw9m+fb70Hjb9vYRy9QD6oyTIgq3Xp4f62tbWF2mNaI7GeeVsPhxdboV16M4TCgyBwCkLbJFqoHQB6enpCBpWe7UZir7XsP//NNan34lPEFpVQtMGeS3o7vb29Ia5Fo4abDdFo3I+fx31PJBJOmXZ3d7t4JfkoVAyWr8J7ZJ+BPZ/NPAcb9iBkTRSOWSjcU0W1VHZQOczOzgIoc2woW5PJZEg21qvQ+FWNYquUyR1YWFhAe3s7kslkiIXf09PjEEwaDey9TmNUMxG6uroqDFLeqUb3vZpyVn6JL6RqnRiGBGgAaY9ti7TyLPnOUD0jMoVshbyPiaZQqJIw9NLTA6RCtv01m4FJq1lQSrawhepJatGm8KVSqSLuoPCkrtd2SOJeNTKs5arxD6uME4lEqP2dry+qjQWqQlalrCS8Rjop+RAU63npyz4TWtVEKxYWFrwx7Xg87p6JPSvNxvctQU5Z0t3d3S4WBiDUiYeeGS8s19XS0lLRXEANjChDHr5RzSjieebv6D5yX+2e8ix0dHSEYqG1IhFrzc1nFARBuYWpsmeplLq7ux3JT6FdGq3WkNCz1+ic11uLogvaQEXhd8vAt93yALizxMFzZPc8ijnrHqli1sY4AFwIkumQqpA513w+7xQc0UYAXmPIPoNGz7+eTyIoRP+A1b1UZFM5H8rsn52ddQYdwwM+maaOhe9Vy4hEIVvIVtMKFBJV6wOAszJUyaj1wkNolb3+v96x1mby5/xMTTugIuKFsEpbUwCoBAlh+NK0GoVjrPGiQkrTO9RqZrNyTUNRw0lTQBT2slB2NeVey/x9glU9NH3+ltXJPOTJyUnXTUx7rmrMTYcahTbNoZZ5+y40e7nSuqYwotGo54bnV70ExptpUFnWJ3PGOT+9A80MWva6Lr2rmhpHQcvP5LnjYJy4paXF5WczH13DJ/ZeNUPSUViRXIeenh6XktPa2hrqD5xIJJyCmpubc8+dd0XfT+VJlMaPrl3jjrYzGxUDCVmaOqT5rDwLPGd0FCxDuxmDopqRpkYuvwIIyQ/KPsp9NZR5/vRut7S0VA0rNcMWt7KNpLF0Ou3Ogs8IsneAiATJaMof0rlZmWrPUz3riFQh08JgThY9Ao1jcmLceBXCfCAUdOqh2Tzbei/SegfUxlOpeBhXImGBFhPJXizCEY/HHX2feXFKVLD9nOsdPqWmyli94ng87oqTsJCD5rjqYVclbvOZ7fc1Z3C9XHCdt/UQKLS1VZ8S5lRRESYjEWZiYsJBZhZ60mepRqH21rax9FrPREdHB5LJpIM9+ZyVW8B5cE/VI5iZmXHe8dzcXEXogd4yU4gsihHV0Pe0ubgaolCjQAU818i7u7Cw4PZHofwoSH/WEKd86ezsBAAnKPlvKuSOjg7EYjF3RnK5nCN5AWHPSM/3evnpjQ6F3224ieELPZ/qGSvPxsYwY7FyG0mNh0bp3a81LHFRDXogHJ6yHKHFxUUXXtIQ1NLSUqjlZa3PQX9P4WciWH19fc4I6O3txdLSEgCEFLB99isrK65dKTMkeMZ17SofVcY0cnaaVsj2wmheGQ8K45S0vFUZ217DQVAuaKGeZrOwaS1z10INwKp30t3d7chkbW1tWFpaQiaTcfA110HiGqEbVp9h9R+dd6PDWq2WHUhCEdNp+vv7sWXLFl
dJifCeNWhsUQIbA7cwmsLxtSpkwvz0qLS3Mqtf6fc1V5csRjKcc7mcs1R9xIlqClnPTy2GnBqZZIIGQYCOjo5QH2FL4FCPgESueDyOxcVFFy+kUFWFzF6rvufA+TQ7fOdd7xn3y8fKBxBCI5RwpzFon3dc7xxVwOr5Y74tzziRNEVwqBB4nmgMc/58XzVIrDESlVJWb9PyEIgkKCHNFlbR4i2aWUBDRPOy1XONYth1+5BQDUuqctPwkyVkKipqzw9/R43CeuQ858i4Lw1pogrd3d0hVELnzM9QNCOfzzuiaDabxezsrCPRKRqpob16kTgdTSlk/TALfXZ2drrDw4VTSFFQ0QuidQeEq0WRoEHo18LH9cAB9nJVU8ZUqjxk6XQaAwMD6O3tdfAQLTvGgIrFors8Sizh3OuFeNcbvrgEBU4ikQhBqyxP6vOSLXStRo5eBP2ZHsRaYSWF2pWVS+YsmaZKcOHP6TkqO5IpC1YZW4/K53H4YvrVhnpRvNiECRXRqUakKRaLDoZeXFx0hhwFrO6Jluuz822U3OJbD7/qmVcPWSuFbQSMW89Qo5DGIIWhGj92LYlEwj0XzZ2298WGarS6WNTrtp4iX4yHE4KnAa/GI1COHRMRU4WsfJZ6Y5Z2rKWEfS9rOHKffWEZm4JolbUqZevA1LIehZCBcuqshrVYCpeGHOWEGmyae01428pOvqzBr0hLI/cnMshaLwZjCVTGtEZ4YNQq4lceTKtQfIe0We9YD5huLD9ToZPe3l73amtrC1Ur0vibwhcq4KxXz89vdN72clgvShWnkrWYK6gsd3uRfCEA+z2f4l5rVIt/a+1wstuZFqH51UynUYWlebK+5+qLCdWjiO1+83wQrmVJPY1HWiHIc866xNlsNpQGpDCmvQ/NsDRrWZOeHTVY1Aiwz1hj2arMeL7Us/ORBxuZp0U6SMbREI2uQxWYGjH2LFvPxs5d4fqoh3UESqVSKOWqp6enoryrnhOujd+zBMYozst6StjKADsnAKH7zTx3Ol82Jm3P//LyckVhjXruLOBXyrHYapVIzpGoIp837yydhlKpVIFW0LjTe6CwtYWsH1YPmYvSyamXSUWrcWJfLiPTVpSMxKLjhIt9cct6F+zziqmI+Zk0Ihj71prNvlieQkQKrzWqBNaau7Xq6XlRMfGwq+WpVieFPK3C9fbT92+f0q5lWKKI9ZjpLSubmpfZwlk2ncZnMKiwroYA1DKs4gLgGLxcC9dn18pwzfz8vLdwvu6FZZ6rcI1aKXNdPsSjmpGmaIkar8lkEr29vejr60Nvb6+r/62xT6A+I1SVsYaRgiBwBTs4HyWhqVJQ400LZwRBUIGGEYFTyDFqL1mNAcv21Xnw+xZRtIYtz10UXrFvnr6QlQ2lAOH6CJqvSyY507tyuZxLXVVkS2Pr9KjV26RnW+8a+FWdJfJrKC/5c/6+Gvk2C8eGpPie1fg3jcr8SDxkH+xLT1NTC0iAoYLWYhYaN06lUhgYGHDdZXp7ex1ZQy2aeocqYwpYpavTiqIxAZRTWLgOW2ZQWYH8DH61nmajQ+etwjCVSoXmr6xlra7T1dXlrHFam7zw9oJVm6cPyqp3DXbw0GsuOGPJ6h2rQaexcw6fwrX/btSA41eeb+7dWkqTz4O/Zz3HRCIRSplTJWy9nY1QxmsJ3Wp7pntA+L6vrw8DAwMYGhpyX7WzVqN31Z73rq4ut5+aUaAQqSphLXFqSzwy/MDQElG4np6eqg0Bmt1rXQuNi2Qy6Z4/syBs9UL1JDUrgd9ThR3FPK3DosrGhzQBCN1fVmojb4KhKNZ2yOVyjmjHwbUpn8KGEBq9C3Y9XBP3FijfVRrQ2tVNq3qpd28dL+sdW6O2nhGph0zyS09PT6jwPv/Ni0JFZlmpLS2ruYRs+8fYLVv/ae3rKLxjAKHP5vxtjVNVcqzSpLVQV1ZWQnEroNKr8Vm69cAwPgRB8yo1T25xcTFEjovFYiECHQDnISsEE4
Xx4Ju7/tu+qrFPrfGmUJdFJaxH7CPnNLo+vr/+P5FIVBUQigLE46tpNhYK1rlshMJda+geqNDwkWhUEat3Rwb/li1bsH37dtdOb3BwEENDQxgcHHTQqwrveoYibgxndXZ2hlJpgLJRp1yDmZkZx0lgDXpNTbTd19hdit3AmoEcq+21GhepVApzc3NO2dDAYWiOxrPPe9zIcAbnqgpZ0x9955ixbVYQSyQSmJ+fRyaTCfUBoHJmCIefo8YU60ETQbBIWFRDHTDdZ85BW0vSoFDDjvtk96gaKlrPOYoshqwQMK1PwpEUTFTIZLnxwRNKIJGrt7cXg4ODrk8sL0wU5CgVMDYGwM/XCj4sSkESVyaTcaxgy5RUGFM9OQtH0hCpVSnrXJlTl0qlnJJiGUAWbdfiCfTcNVZPoUTPzRKjNmL4rFV9juoFaNqMeo8AKuaqitgSt1S5NDNvoByTYkzPzoVDYUUaRRYK5rCXVj/n4Ry1zEkNQjLOBwYGMDIygp07d2JoaMhxLdgbuVGClJ53/p/3y6IHmh9Ko5/1oXlPiXYxDq09vZWn8nB4yFTITBcjDG9LjRJm1zQh3mGrpKIypPV5q7JRZcx/MxxDz5KlNAG4vtXsCaBFUCiD9C5o4RyGd6I2PmyoTF8aQiWyQoOOhgTROoW7fZB+s4ZcZB6yesk2zUQtKeabkpnMA0iFzsvO+DEvt4XAGlmwHji1/nnwCLfTG2aeI60jpuDw4Wg80xY4UcKavghHAag4mOvtsQpEhgQIGbEdG2E9CqpcLucuDi8/hYIyB9UTjFopV4OOfHCYWpkUwhzkGgDl9BtFOCw5KSrGrJ4bPjOgOgpiGddrfbYaiNYjbXbeaw0b0/cZP5yfzktha9aXJ5OfoSWr2Ood+nlAuUa4z6vhz6kYCDcqY5+yJpFIOAVMZaxZHMqyjlIhq4fMfSNptFQquf3UNpZUROoha8EViwxFdXer3VUlq7L/Mc8pZTvnTj6IDT0RnbMcG0UCbBpXM8rYKmHLXdJ9JXGUZ0cRFp4l/p1yiayB2Kzh0HTakxJrbCyPLx4srcrExejD5wPX1CHt8BSF5WphBApYy8iNxcrELqAcK7FkNM230wdLNID1ilnjlTAt2ZW1Du4lL293d7f7fCpkWqCcOwAnpLjmzs5OpNNpl0fLFAOfZxTlsMaPsnPp8ZPZyN+jF0FrWUMIPDvKPSAUqWcmKuFabW98sWOgDKXybChawvdR4afGiZ1zlM9DlbAvB9TmEFczCjQ84GNpa+y4GeNZFY4iT/QgtZ6B8hAIYVNxULakUqkQ+YwQabW9j2LoXjF+zJQsxsbVedGUKKuoeH70LlmYNAoDVPkFlvzKVovAakETyjTNK9azzvcgYsSv3HcqdV2nr/JYvcrOkvzUOKAc10JEyrvRFqva0ERJXyp/9Q7ZnPB65h5pcwn1Dqz1rYQVejc8VLwsFNK2hnKz1U/sUEGhUKi+v1qAyiS0Bfj5IiGgUCiErGxallS+/F0WH6GSXuuBqZWtCpl/o0n28XjcVU0idL2wsOAg7O7ubgcf0atW4cn4aFQCyVrcnL8WVdG5s9i8pj9pypPCk8pboFLXmKBlNkellO33LLlLWac6Z8sw1TvAvdE8xqgVA+dn4/ZEfpQTsV58XM+cffnCDM0oZf1bfj5hUmXnK0N/bm7OeTNsgkDkbWhoyDVCiDKd0jd/6yETnYrFYo6spt5oqVRyZ4br5P8pc5RAqxXWlNXe7Lz5zFRGExFZWlpyqUSKcuq5oZFMOdrR0RHK21e2uXrJtnhIozFkq4wZblT+jy1bS0KX9fC1mxnPFOWWOl8kHauRWi9LvCmFbL0DjZ/ZS2qLJ1ivmEKVXpNPGUcBP9ph38t691xbtbQUWrEAXCI5UI6JAmUrn0KQ7Gi1RNdTyBrDs1ayxjA6OzudYKJSowIDVlMsFM7T2snWuo5SMFmSjg8C7evrc3NnRxy1Uun9KLzOuD/LhPLFUIev0MVGDR
9rnPMmIsG7wH3RHNtqxJmo5qz308co9XkCGjNXlIvGBtfIhg42rhYFoqUKXZUxlbASiMjxoBHa3r7aU31wcNCVkSXE7qteF/WwZ5+GdHt7u/N4dX94T1n7XImOlilOIpoWH2qE1b7WvGlI8I4tLi4ikUi49oX8XQ3DUQ7F43EXAlR0VItAcd4MOWi8XL3MeoZFgUjGnZiYwOTkpFPK2txFjWclltJoVVSUMXRC3LlczrHjVZ7RUKrHoGjaQ7besMK3GvfwsaopiGwjBk3QjzpGUm1YKHstaNL+DQ+TeteqBDl/hS9jsVjIG1rvoamVzffkHiopraenxzUSj8fjrvwbPZeurq5QX1kVSM3Q9deat425KnLAEEU6nXbKS9tgTk1NOQiJzHi1mpUISEFrFXJUnsNaQ40zq5C1qw9JLdwHH+9iIzxkfiYNSkV01LukENK7qu/BM0zBRYFEw84HpUY1fws/8pyzzjmJXDQqOI9UKuVSsqiUWd+AOb8babCpQgbK3iPlospOemGUFzQ+5ufnEYvF3HyJkm0EOxzwZ8+k02nHQ2GHKVsyFYA7w93d3SGdYOtuA2WnQu+OVsFrFq5myI4K+ciRI5iennbseyplZbGr7tJ0UsptXY/eIeouGiGNFGyJBLK2nrEPU1cCk3pGvLj6tRpjbaOFKod+HuegXrOm1rS2lmtaU1Bp+hOtJb34POAKYdYCWfNi23iskqAo2GOxmGtqEATlNmhqFdLyoyJoBiKqZU/t3DlXEvkI/+RyuVCKG5UC93d2dtYJeiW5Kcs3mUx6a3dv5FAoV6EsTZFTwxRABcltI9Ege09JVLRGg+V2kNyo4Rs+CxaAULiORrYKo2bXoXMnZyKbzboKb0RTuAbGaLW5APOlGdLQgj8bpYz5nrz3vAft7e2hDAzKTcaUgTL7mC+Go1Qh+8rzRuXp6xlgmAlYlWmaZmarcGkohmdGoXeF322ox4YDGyVKWZ7E7OwsMpkMJicnMT4+HkKEqJC5/0oYtPJQ9ZeNfUfRXCWyGLIKI50grSclT9E7VHjFMtWiYKzVO6zyt6QGLcTPMnc8gEq4UKKJxujokeq+1PrQ9GLb+em8+fOFhQUHo6ggs8QJX7wmKiFq5855EnpjDErDFhQ49BZ4YVgwQI0QoFwaj0RAsmd9XsNGDqswtFOVQtYqhNZTAr57EIViU0/TClV68Boe0fQ5Gg1A2EAiREdFQeHcbBzZFwqjQrYEHHbRAsJZGyyio13YNiLFqdpQpEFlivJPgHKqk0+WkpUdi8VCqaUkMfpKlkYxLPeDclwr/WmcnAaF/pwKeX5+3jkPvhrcysmp17PUoWeOXrLmGJNJzTix1VN8Zlybj21OGWMbBzW795FB1ja5WgWR0t1JBuAB5YbpwdNYA18Pp6ds4ye81IQ0lCnb2tpa0WTbrk/z3dZSgLXMC0BFzFefAQ+/dpmiErTM36is0Vr3lIOXlLFzTYGIxVZZpgrfquFhmaTqaWtTj2p1lTdi2DuglYuUpannhwrOhnv0RQEHhHOFG11HtXlSIWtFK8b629raQrmv8Xjc7StQ7vVMgg7RDuslNLP3VklRxthqSjQmyK7XKlwblbVR6+DZ5V4QheB98xEBlQ28srLinBjGj0lU45qskxMFbK2OFI17VVKK3tGg0FAazzAROspAhvh0D6J2ynhu9Lz79piGnKK4avgTilaHrKenJ9TeloVdbK+Aep9BJB6yXhYtpsGLws1nn1JVaPybtra2kBJXL06tj4fjAmnsRPMtgbJHRo8umUy6OVOwMfXIV7uYlmCjB856wgBCOZqEuXwtK63yVpIaX5wnv0a91yos9HO4DrIy+Zw19qlGh+UiELXgq5rQjXo9Pg9OoVzeA1riqpB9YZ7FxcWKy6wGTLNcCus1MNZND5keDgt7KKSqsW/e5aWlJRQKBedBMNaoMcBm9tYakkpGI9cgk8mEWNW8o+l0OtQG1aY4RUWAqnXoZ6lMUKSBDF
8fGZAhGDK10+l0yOO3CjmqOVMBk8OhClnj4zwj/F2tQZFIJEJppFYG8Xs+udiMl6zva50R9dBVJ2m+NV9E3mgAkfeiz4BIDNn7jTgDkXnI6h1rHVNVULwMWmGKVjlTXjSHcGFhwcuS3QhFoUPjnFS6pOormSKZTDrYwwrfUqkUYsxuxKBw5gWgJagMPx9j1yfomoGIGh36DH0HV+dYrVCAEsRss4CNjh9Xg1NJ9GCBARqmRIv070qlcF9km4ZGj0Nh40ZjhHauGkPmPQ2CwJ0bMoL5Nyq4uK9UyBTE3d3d7r3qCcmstbcWwtV8UaayUMZw3iT6kby1nudi57hR8sUaolZu2mIUVMj8W1XIdn1aCIfvH8U6VCnz/wrjEp3i/aRHTUSACMvc3JxzyOydtvP0GS8bMVSeEP3RUqZEIXp7e10pZ0VcNI6vtRAa7bsQSdqTMjc1N1DLZKqFpfm8tFJisZiz1LV5A9mSlm29EcPG6hSyZp6s5pgxdskcNIUYWXBDCVfq5Uc5FGJXcpAS5KyyU6VcL3T+cAyNK6kVa4lCKhSqtb2M2ju2MSq1uJXMRQNNC9RrLqn1VFnf2GeU0KhTuL7ZOSvXg/eNQl89HGvAWY+V32ttXe0fzvBNFEaeKmRltqoCY8lYhRRZp9qXj25DPJbjsJZyaHZUew72zNAx4V7yWdBJoLKwMKk6LFaeNTLs3tAjVmiXBDW9l/x8hhsXFhacgtYcdk2v089o9Jz71uszKOzncD2U9+Qb0OgZGhpypEDdd0Ui7esR9ZCV0aZknKWlJacUeClIvFCquVZLsfETCgh7wTfqotjP0JhCW1ub1/JnvhvrsNKzUeasGhXNxDUt1Ftv3MUq4Gow0cMJ5wH+KljWU9LftZeJ+1uNpd/ssPutSkmVKhETNUy16Ib9OwpjrXpEJaQFcvRMWoFb7zrU6+SLSJWiDhRi1kCip6oef3t7e2idzZK6dJ4qxFVWaI4u75kaaIpQ8T24x3qG9BxVUwhRniOVmZrTbcN23F9+vsohK/j1jqhirKakmxlUcNxvXRu/cr817VG5NDxzSp6y6a7N7LmuX3WQrSjH+SkyRK+4v7/fdTIbGRnB4OCga3Sk5VbV+dJXvfInUpa1tbw1BkzLjgxaQhns/sEHpfCd9pHVRUetLKynr7EGHhxLNlNyl2/Yi66H1+b71vPQfErKMjL5svmBqnTXU8aPhuETmL6h3g7/H/Ua7F7Z3Epl/RKmJvSo3o6GaJhPm8vlMD09DQCh9BCts0w4HkBTgsrukQ1b2NRECnugfD/o/WquLIk7TH9pJoVO52YNB421q+EOoIJwSQWkcX3KHZ93pAiTKueoOAi+e2vJdVrekTFZlSFWyKsXymeylnGxlsxa71noOVGDxiKX/F29JzQ+bXoQ56eZLDSkGlXK+vvqTGnNi87OztB5JXeFqGcqlXLKmJ3MlMBl9ZLvvDTiDGxIHjIvDq27WKxM1dfgP2NtnDQJSVTIc3NzoYU3Uoqs1qEGhcYClZxDoWpfKmy1nKBeDi2h2SipxF5oHnwKKCoFLfyhbTDj8XiF4nq0DbVoVUjaAw+Ez50iNL7qcM2MagaQ3fNcLoepqSlMTk5icnLSNWa3uch8L56tiYkJAHDVp7QbkaZzUfBqfK6Zfa42fEKbykCNaaCc4lcsFt1Zs1WWmlXKKlOswcnvx+Nx99z1b7Q1YKlUcqk3VmhqcR17R6MMl1llrOiKMvKpkNUrVkhaHZ8gCCqMfOup8fs0zhuZr2aJUK4zFKnKx8fip6HBu8BKZFoumWddyXfWIKp1qKGgqZHJZNKdU2Wwx+NxF/Jir292HNTqf7yT1eqfNwu1N6SQi6Ui3vbdt+FTd34Ko4VRjHSN4Pxd5+Oc1DmhYL0KEAb+GY9lkL+9vd2lRfFCKeNQ47UKgzXqJV9161W46rar8GDmQQDAiUMn4s1PfzOedcyzQhar5q2xGpBCj9ZSVyVIAaW5smr9NWoBvu27b8
MVN18R+t7evr24+cU3h4hlMzMzrg8p01mIVNQKa0c9bt5/M/7hh/+A2w/fjiOFI7jhghtw/nHnA6jMx7WpBzTG6L3o7/OcrYUONHtm3vX9d+GGn9+Ae6buQWdLJ5689cl481PejK1tW11xCrJ9p6amMDY2hvHxcUxNTWF6ejrElqV3ScFIJbewsICZmRlHFtScWbbrIyzLvajF26m4q90jeOGxL8SLhl5U4Xnp+3FfVRHxrBJmZAychCqt2hRF/PjdP3g3bvjFDbh3+l60x9vxG72/gYu2XQQsoKJrTyKRCHnRVHI8L8vLyy5Gb7M2NM/WknJsPL3W8/Ou770L/37Pv+OeydUzc8aOM/Cu33kX9vTt8XrHtoaypqBRcVGeqDFIZ0fTAu3dUeOf4cK1Rn4xj7d85y244Z4bMD47jicMPwFXPuNKnNh7YkUBEA0las6x5uFrf2GtCKeGA+s6qCGqcfFa937X+3dhf3Z/xffP23oezk+e7/rIWySXCplM6cHBQQwPD2NwcNBVdkun04406uPnRBEea0ghv/sH78ZVt12Fa553DR7X+zj81/7/wp9/48+xtGMJ+0r7nIXEA0WFTNiN1hatUV5uWn4aj/MV0mhmbE9tx5VnXYm9/XsRIMC1d1yL53/2+bjlkluwt3dvhXdMBTcxMeG8ZFpV1sJVLwgox3FUIavHX0+ck4LtxKET8ZULvlIm1CwXK1inNCAIlbJykYWsqh2kjYgdzy7N4uThk3HJEy7BCz73Au/v2JhPtXiZIgtqjWuoo1r5unqVchAEuGn/TfiTJ/4JTtlyChaXF/G3N/8tXnTji/DFZ30Rc5k5VyN3enoaU1NT7v9TU1NOwCqJRSF2CjDOmd6wth+lMm5ra3NtN5knvN7Qu3pc/3H474f+G3/69T8F9gJPTjw5xHPQ0AYVhn0mAEIIBfefSs/C1c14xzcfuBl/9Bt/hONSxyFXyOG9P34vLrv7MvxN79+EyhxyT5XDsrCwUFHAhOlZNn9WmcIs9sP4vYXAa43D3rT/Jlx62qU4bdtpWCmt4E3fehOec91zcMcr70B7vN1r/NNzVBlD753nnuuhM0MjTRWyFrBQGRuPx53zs9Z4xZdfgbvG78K/nv+vGO4axqf+51M49/pz8Z0XfQcppJw8oRLj86KRqCQ1ZY5rOhe5QQBcARcqYyWq1eu03PrKW7FSWnHz+Mmhn+CFN74QZ209Cz0Lq7nDNBy1lS5T+ZjO1N/fj6GhIddWtNq8ouapNKSQf3jgh3jevufh9/b+HpaXl7GlbQs+c9dn8IvCL7AP+yoOrI2lAvBaFwBCUI6+orK6z913buj/7zjzHfjIbR/Bfx/8bxybOjYEWVPJ0gvSNC5LzqEy4DxViGk1Kl4SW92llhEEARKxBIY6h1bnGF/CfGk+VAOa6WYsoK4xPbU214t5RK2Uz957Ns7ee/a6v2e9FkUWbMyGZ0GNOL4U0tT4qBJkah1ffclXQwL0fc94H076xEn48ZEfY9vytlAtZb64/9rY3MK4FI78N70+TQWhQdfV1eWgNn2f9Qbv6jmPOwcrKyvY2rkVn7nrM/hZ9md4at9T3R6rpwMgpGgBuHgxf6aGBY09hZmjCBPc+KIb3R3MxXK47PjLcN73zsMDCw8gVoqFiGaWRzE/Pw8AobusissafiRj0oPke3IdNlyy3viPl/9H6P8fP+/jGPl/I7j9yO04fevpFV6y7bqlhgbvp8bDaVxoFoquh8gjzwrvDEMN1cb88jy+8LMv4MYLb8Rv7fwtFItFvOmMN+HL934Z1951LV6191VOzvFM8BlQITNMkM/nQ0rYFscBEMpkIWStkHC9Yb2h7iF3BpeXl/Hu/e/G0cmj8eQtT8bo6GioIiCdI/ahptNIL5lGcSqVcp6xjz0dpaxsSCGfseMMXH371bh36l4ckzoGd03chdsnbscf7/xjxKfCcT8eJMKKtMRppfisaZveslGjWCris3d9FrPLs3jyti
eH4sia86jKTiFpTQHh71M4aUzKNtDQtJxaHyr34b6Z+3DsPx2L9kQ7Th06FX/5G3+JnmKPMxwskYhxfFrOCmVZdmDU1l69wwpJbT5iu4ApZKvFIvL5fOhiqfFkjZ9a1ukj4UzPrhKw2kvtoQIV2m2IXoAtPKBQKdfLs25Ti/TVaNxf7+ru9G7cOX4nbh27Fa894bVoC9qcV8CUJxoJ1lumB8TfUbIgjQYbs6x1j2t9BrMrq329u2JdmMd86NkACKEkDAXwbFhl7ON36PNSZ4BwbKOx+yAIkF3MAgD6Ovq8hDUttat8FI3VKgmQRXRseVgaskRb6GDQq6OhUm2slFZQDIroaOkIzb8j0YFbRm/BHx79h87IB8qVt/j+5AGRH2H7CSvUzfutBX2qFfVpRC4tFZdw/T3X45LjLwkZVFa+sIgHP59ICfONbSU0NcyilpUNKeTLfvMy5BZzOOGfTkAinkCxVMTrn/B6PCv9LNyTvafC6maMjOXtisViRXyTw5IYLLEiinHn2J146seeioWVBfS09eD6F16P4waOc96JTWWxLerUerVwqHrEvgdri8HXYgFSAJy27TR89OyPYlfPLhzIHMB7b30vLvz6hbjmtGtQyBVCcDXhde6twli2Co0lKOizeLgHL4zmeXd3dztYnpeDVjq9SwoqnhktpaklRAHUbHWrEnQpKkuLeOsP34qT+07GcGwYB/MHXahAi4BQuKtAVW+Te03Bo8aHFVCaYqFCqpbBu3r8h493d/Xyp1yO5x/9fBw6dMh5A/SilARIYUuok+gW7zM9aN51W+40ivvqYGIE+OAvP4jju4/H9rbtuC92X0hI8/nQEy6VSk4RVwsPqfHX2dnp5BEViq3Qx/tar0IuBSW89j9fizO2n4ETh050itZ6yUpotQpZzxHj5axVr8Qiyh/lIvT09LgUtrm5uTXnmmxP4qnbn4q/u/nvsG9gHwbaB/CZuz+D28Zuw86enaHSkyQmKpxOZ4shR5uaxsImsVjMtXRklStL6Go29elL934J2cUsXrD7BSgWwo0r6KTQOw+CoCK329cGeCPqGuhoSCF/7u7P4bo7r8Onnv8pHNd/HG49eCsu++5laDu2Dce2HesECxXCysoK8vm8C54Xi0UnQCm8GE/QmI7CCpbN1szYN7gPd/zJHcjMZ/D5n30ef/TlP8LXLvgajuk5xushM2E/k8lUKGTr1fFyMxbBxHKtqmMrudSypiAI8KxjnuUs6R1tO/DhMz6Ms75yFr6y/ys4uXiyg05VMRBa0oIC+qKCs91iHm5lXC2mp/F6spAJ3y4sLABY9YzYJ5lKkBZvT09PRRWjeqxtyzL965v+Gvdm7sUHT/kgCpnyuaB3nM1mQ5Aj16PClC+N8/GcU0H29vair68vVJ6vkYYZvKvXveA6HD94PH586Md4w7fegN5EL34r+Vvo6+vDyspKiCyWz+dDiI+Nu3M/6FUTkbBwYzOIi4WV3/nTd+KB2Qdw5eOuxML4QoVwJFRKxIGK2ceCBcJIBdeeTCbdXSdkSRaz5q/Wq5Av/eqluHv8bnz3ou8CqMzVtbFkKjANf9Fr5j3IZDIhQ1PTKtkDuq+vL8TFYd399cYnn/9JXPKlS7Dj/TuQiCVw8paTcf6x5+MnYz8JoRB0YBTJUdmpKCjvJT1iGg3KYNaGDVHUGv/E/3wCzzz6mRhoG8Do0mgoVY6OCu9cPB539am1PrjyVjYKptbRkEJ+4zfeiMuedhkuPOlCFItF7EntwS/Hf4lP3fsp/P3I34cqmCgTEyjHiAm/8IEB5ebWFqq08YRmN6Qt0YY9/XsQBAFOGTkFtxy+Bf/043/Ce37rPRWpWwqFspwdGZC0sixpzZa2IzmHrFmf0FprVEu5aS22YnvHduzP78fRS0e7lBuiD8r2ts0X1Gu3VukjCVnbNAVNY9GWeywIT0EMIHTx29ra3DPQNns2NWStYVNv/uqmv8K3DnwLHzn9I+hc6MSR2SPuXN
BD5plmLFJZzOqtKcmPz0BJXf39/S5+ReWQTCbrVsh6V4MgwAkDJ+CBmQfw0bs/irPPOht9fX3OgCEJioQXwr0Ku2sMXtEMGj8+Y5O/X8850M94y3+9BTcduQlXn3E1OuY7cLD1YAg5oyHGM7C4uOjm5vOK+VxU4Xd0dDiDine1v78fKysrTnlwbfUQS//8q3+Or/zyK/juRd/F9tR2N09F1aqRQ9WbI9zMVFCbFsT10OPr7e11YRN2fmtra3NG7Frj2P5j8d2Lv4vCYgFTs1Poa+nDy298ObZ3bXdyUePdWlWRc9YzoxA1zzpTkPr6+tDf3++8ZMZpm5X3D2YexLf3fxsff/bHK1KwtBypZm/wrtl5PJwhvYYU8tzyHOKxMLTZklgt2GGJOLzQGhfR0plMyG5rawNQrstsyTy2qkqUIwgCLK4shmI7PvYu4yL0+oEylBWLxdDZ2RlK8SJsVC0eYeNa683RWtSZuQwOzR/CqS2nOsueyoGWPT0YEle00ISFrB/pOLIldHV0dISqMxFKooXd2toaqn7FM5VIJJBKpUJdlvg+Kshr4SjwTLzxO2/Efzz4H/jYb30M6WIa0wvT7kxobuXc3Fwo11M9YvWy1OhUuI5M0/7+fqeIeY4aYZ7yrirs2ZJoQYAAXV1dIYOYMG82m60gmmkpTCoxGqE0ntSraPYs8W/e+J034msPfg2ffvanMYABTC1POeNG5QHny/mpIeCDGPV7NJQYGqEXVSwW0drailQq5aD8WglrQRDg1V97NW645wZ856Lv4Ji+YyoUuRrZWnhIU7o07EFuCu+HNSoIwXZ3dztImUYn47zrkbp0dLV2oa2nDaOZUdx86Gb8xfF/UVH3gKE8Kn+eEw7OiQZBR0dHKJ3PNsiwsqhRD/kTP/0EhrqGcNbOszA9OR2C2jWcROOWTYSojG0TkodLLjakkM993Ln4++/9PXakdmBf/z7cevBW/PNd/4xzjjrHG//lRdc2hYzVAgjBQBb+jZp0dPk3L8fZe8/GzvRO5BZyuO7O63DT/ptwwwtvCFmtGt+hQrA5rrFYLCTM+IAVhqym/OpRxm7u37kcv3PU76A/0Y/7J+7HB+/6IOKI45TWU0IKgoeOlin3WNOv9KtCext16ApLBdw3fZ/7/wMzD+CO0TvQ39mPnemdIYWhMCKVst0/CiSgXKGJZREZLtEqPAqt1UsW/Mtv/CWu/8X1+OgzPoqOeAfG8+OYnptGYbEQOhd6ttV743Owz0BjxopcaGH7amGFWp8R7+rO9E6cMHQCbj90Oz5024fw0hNeGiLVMPWKyhQIt0ZVpjiVMQtCqBFt8zObGa/5z9fgc/d8Dp84+xNItiUxXZjGzPIMVrBSIaz1zlIhWMVgB58NlQbvO4fW5q6X3X7pVy/Fp+/8NG688EYk25MYLYyiVCqhO9GNBBIVskaNbUXplDwHwIX97BlQtIJ54WyEQxYxlel64+v3fR2loIS9fXvxi8lf4LJvXYZj08finO3nYGZypgJipzOgRgvlCIs5tbW1OYPChs60HnQUyrgUlPCJn34CLzvxZYgjXlGIRR0qsqttHfxHykFpSCF/8OwP4i3feQsu/dqlGJ8dx9burXj5cS/HS496KQ49dKgCLtKYE18K+ba0tFRAYdUgp2bH+Ow4LrrhIhwpHEG6PY3f2PIb+PIFX8bTj3q6i6+oN0pPmcJI4RgOGh1A+WJQ4foMi0bXdSh/CK/6xqswszCD3rZenJg8Ee/c/U4E0+XKPZrqQ8RCPU+FUPX//J2NOni3Hb4NZ37iTPf/1/3n6wAAF598Ma49/1oAlVCl7iUVke+yEvYj+UgvnY1xNcLa/5ef/gsA4CX/+ZLQ9y/suBCDxcEKAcpzzbXoevQcKONdc60VGVKv2p6dWgbv6p999c8wPjuObT3b8EdP+CO8/kmvx8riSuizbcjCnn8KWy1Yoc9Iz1YUd/bqn1wNADj/i+eHvv/Hw3+MbbFtFZ+h9Q/U6L
IxcA59NroWdpnTSk68T7Uac1fddhUA4BmfeEZ4Tc+9Ghfsu8D9X5WyMq+tzOFnKrPdev4tLS0OIdJCNJQJ8Xi8psIg2cUsLv/W5TiYO4j+jn6cc+w5+IuT/gJLuaXQHmj4jA6XFg1hCCSRKPcvUOOtWl3pZnks37z/m3go9xBefuLL3R5zXzWrh/OxctrKGB/CslGjJoXMyedyOfe9K864Am9/6tud95jP5zE+Ph5iCCqrTb1Nuym6WZrTOzc3h7a2NifcqNSreQmcn14YO/f3nfk+4Mzyz3iJaeFpU3mSK3QtemmAcjs1nxWmcSFN5FdIr9a5F4tF/OPT/xHz8/Muvy+bzWJ6ehqjc6MuZqxkHF5AJVvYko9MQ2uGQFHLvp/afyqyr82u+/caKtCCCZpbbL0HPhvCjxpmYAUkojBcP59BPp9fd98PvfKQ23eGBaampnD48GHcP39/qAKXemeqGNQw0HQ5vpaWltDW1hYScPSI1fDg77a2tqJQKKy778DqXb3ijCvCbPG5crlGjQVy3/XM635TGevP9Gxpt6uOjo7VcNDiYoXHvN6ZKZVKmLh0ouK8Tk9P49ChQ/jfxf8N3Um7v+p5VhtqqOoaLeNZc2p5X3S+vj3Xs265CCprbCMJa0TqWbFz5+C9peHB91KSGOetue/V5v6cHc/Bsy9+doi5TpKoks703ChZisaNxr+1kh4RvPn5+ZARyHh9NYPOnhnf3AHg9KHTMfnnk04v6T5r0yI6izQerNzg+dUWwI0O33n3jqCGceDAgQDAY+Z14MCBzblvzn1z7o+B16/D3B9r896c+yM/92ojFqyrslet/cOHDyOZTD5iDNxaRhAEyOfz2LZtm7NmNue+8WNz7o/M2Jz7IzPs3B8r8wY25/5IDd95942aFPLm2BybY3Nsjs2xOTZ21BRDfqxYIr9OVjewOfeHY2zO/ZEZv05zf6zMG9ic+yM1avWQN2PIj7LX5tw3574598fG69chlrk590dm7tVGTR5yMpkEABw4cACpVMp9P/gV63NxcRGFQgGTk5M4dOgQ9u/fj8OHD7sWgGTnMTWlu7sbg4OD2LlzJ3bu3ImRkRH09fW5KlbM2623Jm4ul8OOHTvcfH1zDyTFQCu4kDnLPrZsEqANG8jS41etCtXb24vh4WFs374dW7duxcDAgCsLNzAw4CpGaS5pvXNfawTC5tU1zczMYHp62pV3nJycxPj4OCYnJzE3N4d4PI7e3l7s2LEDu3fvxo4dOzA0NBSqKOZre9jI3Ln3yqZnv+mJiQlMTU2F5si+wmxfyOImrPKTTqcxMDCAo446yu17X1+fK8GnvaernaN69l33d3Z2FlNTUzhy5AgOHjyII0eOuB7UZKTm83nHKuWayarWXGtW6WK5w23btmH79u3YuXMnhoaGXNlMFlcgQ7xQKGDnzp11nXfuez6fx/T0tOtSpedjYmLC9XFmjWtbDnbLli3YsWMHjj76aLfvLNhSC1u/3vOuZ8eWmWSBipmZGUxNTbn7S0Y867sro5nsZebr8g7s2rULu3btcneY69UevXNzczj66KPdfPn1gQceQEdHB+bm5lzLVt41nunDhw/j8OHDmJycdOdDUyaBMvtbq1TZ9Dmt9saa0Hz19/djZGQE27dvx44dOzA8PIxUKoWOjg7Mzs5i165dFXNfS8b4ztDi4qJrqjI+Pu7kJtd65MgRjI2NIZvNutxpFsHhng4PD2PHjh3Ys2cPdu3ahZGREaRSqYrsgmpnhl8feughJJPJikYds7OzyGQyGBsbw/79+/HAAw/gwIEDmJqaclk0trWvloptaWlBMpnE0NAQtm/fjt27d2Pnzp3YsmWL65HMoiZaEKdWGekbNSlkfgBLi+mD0pq28/PzoeRqW/qMFyAwtHV+Bg8ZL7/tiFTr8FXk4cZpehKT5JmjxzlpuopWmqEQZR6dVuuisrXFx3VNPT09ztCotqZqc19LIXOumu8NAAsLCxUHhJ+rtY
u1w4mWalTDqJa0qLXmrvPjHINf5UlrWhvTO1h8IplMOniKSfw8H/acaaoRn5s28lhr/rXse7Xzrjm82sRDy5bSeNNzxTmyKAGLgWiTdi0a4msysdbc1zrvvIcUYjZ3m0JJ0/qAcjoXgFBJVhYxqVUh17Pv3Hs1iBYXF0PFHPQccV9VFnV0dLh0IxXEPIPcU82R1bPGamo8SzpfnXdHR4fLh5+dnXXV1XyNNyhPaPTYOgGUH/b3tcCMGqg0UukE9Pf3O8VBpbHW3H0yRu+tniE6V8GvagBQUWu+ug5bV8B2e9JmGNUUsj0nOveenh6XghuLxVwBG18BFTWGeVZsGmIQlBsF+V5a9lZLba6XHrXevWioMIgOFTL24emDYg4pN4VWbTabRUdHR0VVI14QCoYo5qneGT0X5jdqPWJa18xBBuCULfuMcm0UrlRcbJzBg9fe3u4UnRYXiDrm4bs42jqSnhu9TbUCbZK+VlqLoliInZut26vtC2dmZlAoFLC8vIzW1lZXv5kXRAsL8Hyw3GMikXDvrQUKfEn+Ue23LS+pRhqNAVVunIcKWgoolhVUo41KyObu17PntnhDoVDuDEavhnWc9dxTOScSCSfoKMCo1LSCVa1za2Ronq3umd4prSRGA4flENPptGuLSfSiUCi45vRUAHwfzUvWQjt8VdtzdTZ8hozPCdFOWVrvXI08Lb2qlfZoNGgJViItqoQbuce6BltTgd4n+39rP3D2Atfe3tX2zGdU1js4R5V7Ft2kR8x6AbFYzO0hP5vr1EperHIGoMID19rdUZ39phUyh8IatkGDJoPz99iGi/Wf+X0qPfXMopqfWthUxCz7RsFE6JSKC4C74OpFa/I+UK7Qtby8jHw+H3qgqVQqVKQ/SsGll0YPJdtdErKbmppyCpmGkW3HqBWiGq0Mtdb8VDlQGXN+ExMTGBsbQy6Xc+eEwsb3+dzvIAgwPz+PqampUCEAFtFXA4MWfbOX366H+62tRGk0WI+A82F5Sp6FWCzm7gSNkEQiEbLaVfGtd4Z8553KKJfLYXJyEmNjYw5aJMyuBRRo1GjlKBqg3d3dFQ3na5lXM4NGjdYh5/95TmkAcw2sMa4FfzKZDCYnJ92dYM1qoiy+bkX1VHrzOSlWmXPeVAz0tLREqlZro8FGQ4MeGRETLf+ozW2INDZSPc0qOq1dTSN6bGwMo6OjGB0ddaEPGnZaKIR7oN7tRhjHDMeogc92tJQLxWLRhSr4YrVFKvT5+XmHcNGTZk+G+fl5dHZ2uvsepUEamUIG/J4QDzc3i4eSjeS5WFotto8tvZwoHp6rUmQKo6ulNzU1henpaVdGk0XnCZHyYPsEs63QFYut9vzs7+8PXewoRzVDiAqPscHp6WlnJRIWZkzWp5DXal3XyBwVHqVCpnLQ2CUrZ3V0dFS0qlR4TCFgKkN623NzcyiVSk6hKx/Bwl31rkP3XGOyWqmI51hf2gmMAlLXo1Y7u1Mp+9injNcTANW8hkwmg6mpKUxMTGB0dBQTExPOq+H5VQOAQ+t0s5GGxt82UhkDYYXM/9PgoVJaXFxEKpVyytR6dTRG2tvbAcCdHzXY1JmwMH49Yy2lrJ4++QPpdDrU+YvnxgdNEz5XfoeFgsmlaKSLm95ZykyVJzxDR44ccUbd9PS0u3+M1bOCnk8Z66uZofeRRpjyZ6iQtS1qIpFAZ2en66LW1tbmzgKrslGG87xxH7Q6Gc+/NcIaXVOkHjK/VoOwKZCDYLUJA61VHhhat/Qmo4QCfEqBXjIfACEtWv6EjRhnpTC1glIVoDY0aG9vDx3M9eCbRtbEr6og6LFp4Xd6x7QO1TvW5he2MUgUUK8PTueBtnOcm5tzZ6G9vd3Fv6iQCQ3rVyqRIAhct6X29nak02n09fU5wapxOaBxmMx6n+ohq1XN+K/G6JVrQIuc8+O9UAOi2mfXOkc1QnW/tQwoQzRKVLRnS6FX9tXV0MBGK2MOPjN9liwLSUWkITOLyP
T09DgInvedz0x5CNVQiXrur88btAQtjcMzjqtdzaicqWAZJ6Y3rWRL3lUNgVj+RC1n3qJu3EMaNBru4FcSMumFcu/5rDg0hGSJa40OO1ctf0l5rogPgJAh1NfX58pkzs7OunNAR5EoreoODVHVg56sNyJTyHrIlJigG88HrVZ7oVBwHkNXV1dVGCwKqNEH6yrJQx8YY9naSlE9F/VuCGdorAFARf1rfWDNrseuy8ZntQ60QpH03kh48nVciVoZ23kqiqBNw3nIGbdhO8Le3l4Xi+W54bPj5eHzIxvb1mimEuTQfzeyHhXcVKhUqtxDClHtbEOkhfNWYRoEYSKJrwtXPc9DDQeF19WgsYQueur2M5TopchXPTB6FMMqNkLoGvbSZ8O5EtolGY/Ii93jaoScWtZmlS0VLr9qcxk1JpQcRDKeGnLaK1tbchJq1VAaP1s95nq842pnRvkeRBPpfRKizufz7h7zGfj2ROPiUTYQUkPK9h/QvdHWp5QvXV1dztikYudXygobzvD1ZWh2NK2QfRafxjRsTJIbREuGDQ7a2tqQz+crFHLUw8b+LAGIkCG9LF6AdDqN7u5uRyQiFMM56h7Ybkq+AxelMrZGhhoa+iJRih4/0w8I23R2dlYwqqOap3prnKteGCBMcqEA4oXp6OgAgIr+1HweVDBkQfPnNErYhlIvDoV4o+uxXAnC1Qo9q8Gjxhz/Xi+zT5jauHO9sLsPudKhkK+NcQJh49LHRK0XRo9i+O4R52D3l/ePd1aNTu4138M6E9Xu8Fr7rs+Qz48errZepVdPo19T+cga1rPD+6DtOTU2rPtv5VCthlw1ZeyDgZneRDhY275SUQEI7bPqBj3bPqOz0XNhnyH3mp9Lbg/j7JTv3d3dzkDmepmtQjlvkRerkB81HrIeAJsqYC1y9RYJC/Pw9vT0uI3YCCjMemlq6VCIkpkIrMYxCXumUil0dnYiFouFWHY8vNptCECIhOETplGsRREHC83RK1bIplQqhZqEa3oE41dRNJf3Deu1WGVMFimAUBpEb2+vy3GNxWKhOM7s7KxTJNrWkwYfERgSw6zyawaythA8z7cSdShEGcdT79KiGjQIlYVNQo+m79TLmLUeJe+pegk9PT0A4HoCq3CnALPhIwvhWqUc1bmpZW38PFVMPFeqmH3PXNE3NQhtf+da9ty+Bz3f7u5uRwTSfFXNxKCB3NfX5wxk5dJQKTN0ZnkRViGocqr1vFj5qFyPfD6PTCaD8fFxjI6OOriacWMa/Gqs6dlThaxhMirnZuSjPds2HNfV1eV0Cr3drq4u9Pb2uloRdMJisZgL68zOzoaQT3V8KL8edR4yEM631bxEbZVHD8LCAXNzcwBWBXIymXQKOWpsXoWGTWugZcSDDpRjDMyN6+npcQJL+34qExMos3rJmLV5o1EqOb1AGjdWZqkWQ+Bn09AYHBysqpCjnqfuu77Ukma8i4JIFXJ3d7djHZO8kc/n0draipWVFWd0kF3NtnEkD6oXpJBZrWfLKh4LifJ8c2gxjY6ODsRisRC8bcMZvEMUGLxPawmuWhQEv1rPm2zdVCrloDreO66Pa+N+au6pojPKQqZh8XAp5bWen52n9fA1RqyGkKYhaa/eWvdcFXJ3d7cLy2i+KnNfqaB43qmUlbylhEstdmPhan61Rlit+2jPtcoTZkOQnT85Oemgat49H+IFwGsE0lC1dQ4aHTzjikp0d3eHlCYAJwtItu3r60N/fz86Ozvd7+ZyOff8lVOhXCiVX1EiRJFB1kpQoNWt/XnpQQRBEMoHIyzQ2trqHuxGxqZ8MJvNfVZLjtARvSwlK9Bb07Xx8tLq5QWMKqdX1wGgwqK1cXGNzdKi1txoVptRNvNGefLWK+RhB8rFHIIgCHkWqpwJO3Pf+fsUGIVCwVm5VH608C1k1qhlaxWRZdhTGVnYTMl/1psA4ObG3FoV7I0oBg6fMtZCJFqghQqZz4Ye/8rKap9axsgBuP3zGV
gUrFGeI99zsB67/tuXeaAEOp0vWbdAGcLnXqmyWG89Pi+NimF+ft4pIBqe6olryhIRIr4oQ3Q+1qP0neN6QhscFkHUftAKWatnbGW2nYMam9wXNSxU5jTiCPh0EPfSZguoQmZ4gKgo76eV1zbso45i1OGaSCFrhV+SyaQTOnrJdfO0pKAyVX1QQFQWt40zUBACZaGoF4QWHZUBLSJVgJrbRkuQylirSUUtoHxWreZPUhHT+OFeaslGrlEP4UYIUquQrYfMz+X8fFVxSMqhZwGsViNTUpoS85Rd3N7eXlEmr17vWL1BDVVoXIlGm3pgPPv0ODRdiKOlpQUrKytoaWlxws3GlusVWPascy+Xl5fR1dXl5kAjTY1LVnkLglWSGgtocN+Ufct4PYUdFTyf60ac+WpwOX9H2eWco74IYVpWuY2/Kuyr+7renlt4VtnS6iErucmSvSwXx5LC9Cyo4qhlnmvtqzXweWaVnc98f0torLYnDOf5ql5FIRt9sLgNexIdjMVizujRcrSKlNg98a3H3sUoznlTClknZWEaKgEd9ILa2tpQKBQAwOUmK+yl/+blahYGs/Ok4cB0iVisnAqkyfgKbWr8UqFhwqRWoduL1KgFuNZQRaHGj01Y5+/Z/VADaSO9mfW8FV/cCQinSVAY8ffInlWBx0tIj4eesrKK60FfrDK2MXAL12ouZD6fdwaaKmQqBMv+JOu3q6srlD5HRU2FWCt0qoJQSS26bkVzVGHxvDIso5/P72luM/kVurc0mqI6W2tB5daLsbCr1hxgulO1Yj32btTqaa4Vy9QUJv6bhgv31YY/aKwq0qLGgZ1Ps3ushowqZF/uraYbUhmv56VbY0ch92Zko89DpmxXo7S1tdUVRSJKqHUwVHZqWhPXpmz5akhFs88gEg9ZYwSMuwLhSlCEb5LJpGvkEI/Hkclk3CZZwc1/W0ul0YemD4y1TwnFAXDKlJtOq5MCiLmLLMfG/M3FxUUnSFWx+zzkKIcKId03mz/p+zsr0Hz5llEKUb3s3Eubz0eWsio631osuqHkDSowFXR60er1jrkGJXD5jAkls8zNzbn8+oWFBQetU0HYvwuC1XQnFRbkU9g8UmUNr7cG9RoI06unouElGpn0gpTIQuHD+0DDNJ/PY2pqyt13TYXi59ej0Op9Fvo87dnls5qbmwvlW7OghRZCocHD96/FC19rz63hr4pYwzB6x5aWllAoFFx6njo3GpONci/tvvruqIa/eF8V+awnrOhzAPjvZoc1gnjWtYIZHcUgKOf7M5xEHgqdLBrNRIqAssOloQfLEm92RBZDVqgWQKgcnCa0M4+N1rPWt7b4vM8CbmaefGBdXV1OCPb09ITgZiUXUPBo6UGtfMUcPCoAy7C2BIyovGMLz/m8ZH61+0ZhpbF9FupXRcjnwb1rdr7W8rb5x7zoVMgKhak3qvuoioVnjZ4Gn6cKcCWdNOId2z3TXETd71wuh3h8tcFANpsNlcqkArYGFNdDxMY2M1DhVS/jlwYBjUXGNfUZaI1rTRFjqhDPNxGIxcVFZLNZTExMOM+DQox3y8KwzZ4hfZZrFWdQpILsYO1oxRc5K1bp+gzUWob11PS9NHOAX/nzICjXY+daiZgkk0kvrL4Rwxqelo9is2Ssp96sjG50+M460djl5WV31tWY0LnSwNaa/yTEUjYCcLqD6WesKfCoVci8yAqD0RrkIczlcq4qCoUAY2+AnxmpF6MZyJqwBd+HUCItUIWDisViqGIYc6bpIbOyVKFQwMrKivPO+D5rwRpRDJ8AsZ6gXnj+jY39zc/PO8NhcXExVCHKMnqbmb9CuqqUVbBSkaoSVk+IQtIHUdEDIUSt5CIlH+mZqnWfdX81Pq8eGj+HFngQBK7IiaIkTH8CEDKMCF8CcCX9CG8rasPCBet5Fep58LzzXtqzQsWcz+edIlH4VOdAI2d5ebWb0eTkJGKxmNsXACF0SAk7zd4BPUM0kufn592ztcbT/P
x8qE69FrHIZDKhMrL6vHlfrOzR31lr6L4zM0NTQVOplKt3zz1bWFhwIQKehY2sge8bKk80rKSKDCgT39Tg8jlOvjPqQx1qvYvVhj3rnCOdDCsX7blfXFx0df+z2ayLj5N8SYNYq+8pbyWqPGpgA0hdFq6xL2XGTk1NhSoocSE+RdwslKoPjcKRgXxrefJAMhapEJgqET5cFQTVIOKox3remxbK0L3jYaTHb8tlKlKh32v0wFUzHCyxi694PF4BSSpsrcrIB7+SpEQhpsLCKuN6vWRVoBYC1xczB6igl5aWHHmO94RDjYVicbWEow85sLF2zm29wT0ipKfz5L7SG6KHbg0OEufm5+fd3QBWPYvZ2VmnxIFyDXL2cCY/g6MRpawGusaGtcytGmzcz7m5OWSz2VARC8aP2WSFd5yKxTJrdZ9r3W+uk3KwWCyG8mIp0JeWlpwBqrnpsdgqC9imgDYrA9fbY371wfSWwdzZ2QkAjoVvHQDfHC3Xot67uN6odtZtNgTPO8+3bZrBfaehTGSJIVlyAHwprc2OyBQygNChVitKJ1sqlTA3N1eRDqRWKbB2haFmlbLO0QfpUpArFK0vTZFSdh4QJhGpQLUEoEZjQVbBqcepMRAlD/HzATjPhqlCNIZUYVKBaFUdSyhpdu728lYbPqREP98yWqksYrFYhfLSz6rl8uvvWtKcEm8svGkFuhqpNICAcvUrez6sAWLXXS+Eyjlp7JmvYrHonm8QBCEYm/PTuTE9iu/J2CefQzKZRF9fnyvIorHZRvJNfeedZ51sX3oyNI4pfPk7LGLB1DjtPKdchKiK4lil7Es36+npCZGhNPzR2trqSGf6/SiU1npz1jul94o566wVzzRVyhi9DxZiB8J3iL+n98cn5+udv5515f+okcX/897xjDCcobW4LQfDlsDVIi/NnhmOSGtZA/B6UmrZKt2dHjW9Ijs24gCq8WBhDv6bQpCKlyxJejvKNmRcjQd4eXnZlQMlbNLT0+MUo4XwdE7rDRVMFpqzByuTybh6ztr4gAVDpqenndFh85cJx/DgKfwYBEGIWNSoYcSvGhfVVz0xUi1Ko+kXSjDi/unXWobut4XQLZud87IwuuZSk9kJIJRapJ5ItVejkJg1/nQf1MhaWVlxZ5VIA+8A74F2zVHByqppJE+x3CnXp/tT6zPwKWNC6OwSNj09jXw+XxHPpxfNGHImkwk1f+HardfnM0AbMZ71WennUBmnUqkQcVWrBjLNyCpka8BG7SVbdJNhOH1W8fhqjXkaXQzj6XyVpKlhI12npmRag6MZma96yBrglnmvneYoM1mPG4DLtNGccN5hLYf7qPOQddgNWUvgVlNKG2UF2jmqRWU/WysakSxkKfstLS2uSw4F/9LSEvL5fMijoDBmPnMzKVA+mHphYcGR5cbHx90B05aLLPJAZjuhMsbEWa+bB45N3Xt7e0OCgN4d4I8RrbXnilBwLhqLorFC42Y9ZrqmM6hCrjdOvNawcLvCXhZOB8q9s3t6ekJFV3R/iapQiXGQ0GXzNC3S1KiC4HoshAyUW3GS5JhIJEKV6lhEhrwJ7ZLGTkCJxGq1u5mZGSSTSed1a+zfBwmvt/8a6uBnsYzj2NgYstms85I1FKJkTCpjJdCpImbDBhJ11srLr3XfrUJmDJKxYQ1xkPEdBKsMYCVSqSdpofWohp0nWeBKAmTlPJ4BKjS+iFbwxXARgFDIjwYHww2WixHFWoAwEZCGHOHpmZkZTE1NuX7OOv/l5WUn31likxUD9R4r6ZLPpdkRuUKu5oWsdQGbjTE1M3yfpfEf9mTWOAKFJtmwVHo2T69UKrkGFb29vRUF4dV4WW/4oEab2sFDNjEx4Q6YkhM0Vk94srOzM6SUWbIvnU6HoEnfc6nlOdl1qlFDZcO91H22hRmqva8S6PhslL2snlkjxg+/WuKcxpGVFKiXeGhoyDXv4KWmQi4WiyE2Np+lFt631ZkseaSRtfj+r54RU/
eonOfn50MCicSX6elpJBIJRw6kN8RqTj09PSHjje/PZ1uPh2xJdYwNs5/z9PR0qAa+IhkaOuJz4h4TveJ5pzGqnk8UxXK4v5oWyvg10QXCqKzvr7LEhjQsWS+KYT15Gi2cbzqdRn9/f6iDXCaTwcTEBI4cOeLOaqFQcPKFUDxQ5kpoY5hqlRmbHRZJVEOOsm56ehoTExMYGxvD2NgYZmZmnMGmsr+np8eV1+Q9YHUvn4f8qIGsgUqISaFVjY3p5qtVZK3QKBZY71DlwcPJ/9vKMoSkCeUBcJYYcxu1xzLZqtZDbgZC1fixNqBXZayVuuhRM8aq3gTnrrWvtbmBhU0paGrdVwsz0wrVF4WfFYTVoENVzJb9GYVxZ5WCje9yaBUiesRs3sGLrAqZYQ6mu8zOzrr30aI09ZZvXGsNvn/r9+LxuCuUQOWxuLjoyjnyRULP/Pw8pqamEARBiChIggwFli1lWE/8nv+2SpklHfW8a5lYNcq0SJGWIrUIAIlojA2qYVivPPLBr5QXJHQx3EQvS1EYjcta8mOzBoJvKFrJM0DUxCcjZmdnkUqlHElXuQ42tMOfK/FUc5t9pLVmhy/Mof2ctQQo+QUMQRIlIcpFo1q7cCnDOsqU1sgUsk8R64Ox1ZJ4wPi3PmFvXxwPh5JWi9FC7zoUauWhU8+MVi8PhEIb9Splqxws89fGgukxqOVpvVVeIHoe/Jx4PO7IHNoHWveHe1LL3FVpKtFFlbHC2PxdG+aoFnfXfbFxNiV6NHpxFJlQr1vj10C5qId6xf39/Q4l6enpQUtLi/PatLaxNSisd9wIeYT7ol99L31O/CxbsY6fTbidBDWeIQChcos2h7WeHPC11mOficYnNQNCY7RkUwMIGRxaz50CV7trNVLa0eeQ8AzSyGcHIgp/TR+yBEKF4XkWNiKWrPeL99Cug95mZ2enK2bDEIbGkLXaG+9NsVgMpRrZXP4olLHuvc7XcmyohLU5BiF6NarVK7ZkrqjrSwARKWRLvLCel4UpeEkV7gPCtXd9ArRaHCzqoUpL4TVVnjbWrA+GqQ7aLIDxQj3gut56YmpWOVh6v72savmyMpPGy5RoRlhpdnYWnZ2dztOx5BxVlLXsp4WX1fuzitfG6m2RCX32PrKVveTrGXi17Dk/S9dk44LLy6ulPAlx8UXPmPtNdMJ6/dbbVwShEa/fJ5x8rHWuTz9X76ESYmhEEKqjErHGoc/La1QZ6z2x3I5UKhVKT+G69aXnSmPlhOFpMBGK5LOygrcWY8giWCoLCTdzDlpO01bz07CU7inlSVQQabW95r5Z1JOpebFYDIuLi06WUFlpXXN1PjSbwJfOZxGnRmT8WuihEriUdU89BMDxJghTsxMe7zCNNR9M/ahTyPZC6kt79DLXSwP6tF7XIn893NA1EPbq+H/9Gb0I5pjykhUKBZeLSkinUCg4b1nXrMUW1hrVwgF6kH2CC4ATBNoHmS3eSCris1PBS4iHv6MKX5XkWnP3Qcrcs8XFxZAXxrxMq8B9CslCUloEXwk+Ft3wwd21DvUe+Owp2BOJ1daQHR0dGBwcxNDQEAYHB9HX1xcqYk8h54O/1YjSs1GN2LWeYvDdS5v3bb03Pg/O1Wcg8P86B1/cTlNb7BrXG9ZQ0RBSd3c3ent7MTg4iJWVFWcUZTIZVyefDTwUIUokEi59h+1Ht2zZgoGBgQo4Uivt1cqkrWXP6YUpC18/b35+3p0zPd982bOwEUMdD32+aoCr3NNMAna1Yoe1eDzu0BPLsraktWbQE2s0cO+VxMVezuodk1HNZ63nYmhoyMlKKuSoC4HY0bRCtofQxge0EQMVMjeDlXKqeTFA2CumYtHf19+JalgIT2FSFRC01HkgCVVqDiEvJIUEySdA+RCw+tJ687FK2ecJ87IzjsaLlUgknOW3ZcsWjIyMoLe311ncZM2yKTcVMr1nFa6E/KhI17tAKtSVEa2woF52Gw/2sa6t18aLx/1lOESVsT
X0Ghn8ez77ZDKJUqnk+mirQh4YGHB9prnGIAgqjFZf6hQVkBK7GvWQNbXNFjVRo0UZtVTQfB89g/q12meuBZXXM+zZ4Z4zLkzSJGtUK9xsCUMkrTG+PzQ0hJGREaeQfTWKbQy/Vg9ZCWVa5YrGFjMCVBlbQ0jPCtEG7UzE/dwoZ8Uar3x2GlLRtFAWM1FkQY0L7gmfja3TYI22etdmvWNNbyKj+tChQ5iZmQkVXeFaKB8HBwcxMjKCLVu2oK+vz6EnGrpTvbTWvtU7IveQldZOxUSiB2MNJLHwwthFWUubFpZeiPViio2ug1+t4OHnahxK8021RKgW6mfcgoqOnnJHRwfS6bS7rPUIKivc1FunQiarlJ6bMjz7+/sxPDyM/v5+dzlaW1vd86PAnp+fD3nwVOyE/BSGXGvo3lHJqDLmZ/j+zqdELZyn4RBWYALK8X315huB+nxwssKmABzRiTmaJIKwzzRhak1VU4VhDQhLeKt3/tZQVjKNzxigkcWCCLFYzMXFbUjApntZAzYqBcH3USOM3jD3iyk6XV1dIQMCKJd45O8mEglH4mJsf3BwEIODg64nrpYIVfJmLcaQemhKJtLQnBpWVPb6mYrIqbdXrVobP/fhQhB990AbvPhKpnItsVi57oGFrX1k31oJo9ZRUbhaa7RPTEy4DBSmnVGGE3lJpVIujEFCJs+X8mjUePZ9tf+udUTmIWvSNaFOsnzpHTNv0bY/U0akjTvTe1SWqS/G2Cx847PkqyloKpd4PO7mRsWi1a0AuHKDRA14gAnZN0Nm4OXQSkCM49DzoXer89OaulSKrGOtMcPFxcWQQiADkdZ6LRCTHlSraLQ4jIU8LZxr4+O8cFRsWqGMKV4a/2uUGOWD3FUIUdAwDsUYFK1plo4slcodw3yFH6op4kaUMYdPQeieKcuVCpioCgD3b83hrMYBsQhENfZ8vUJKlbIaQvxcfg8od0zq7OwM5RzzGTHEYJnVhKuVXOcz/ms1hGwoRY1jeuoAvDF8fX89977wRiMyo5mhXrJVyop8qZGtChkI129Xo9CS19ThquXMqB6yhq92MWPa3tLSkgvjMYNEq6nxRcdBQ00aBrOobrOOYlMKmSDeH14AAFGLSURBVA+HB4cwJ/O86BUrZK2eo+YOxuPxUK4Yu86wjKMKJmWAcsN0Po2sQ9mY9tDbw2/hcj2Yqvz4ULg3VJS+Cjy1DH3o1kKlIiCcx/nOzs5W/J3+vRXyFCaqaHwMeTVgap03DRmfx0FIi3F2G2tiLF6VG88SX1TKpVIppIxtXm+tys3OWw0urp+/R4XMz9B5Mo5mwzds8A7AnRkbt2zUqweqx9WU9c9YaxAEoVDHysqK8+w1150FaLR9IfdA44rao7qeOGy15wBUttjTu8r0Ie69nnklgpG4pcRGJVZZ7op+fr2ohGY9aH4rFTJ/Zlnh1ZyDR0IJ+4YPMfIZkeokUZkRsbOKWVnXCnUDtdem8MXv+b68e2T+a9VIizbo+9CwonzS82HlqBqjlhhZ64jEQ+ZkWdJucnISY2NjDqpVy5qeAQUSNycIAnfpJyYmnHBgNRRbRELZwjwcjZIB1Iuw8Em1WK0SWxRCVValj3VK765WuJefWU3I0CtmFZ35+XkA5brixWIR8/PzIRRDFR2Vhq8AAf/eRwKq1YjQNfiMASAcX2JMh0prdnY21NITKJND5ubmXMk75SWopWpJdxrbqtVDpsChBwmUhQT3mYqMwpb3gWebHWXUkydKAsCxUuldM6WoHiKXb/iUsjJPuWc0Yjin+fl5Fxfn3LXLGQ1uJS8ylqieJysaaZW6RmP4eoZUFlikxfIdmMai1edszrEa/M3Ajj5uAw0XolmE1hcWFjAzM+PQRKIOVFpWNkQZDohi+BSyLxtC16HyUWFl3ofOzs6QMq4VlVgLDVJ4XOUc5ZsSEIkGFQoFd/5Z4tYaaxY95IvGP89kPaMhhfyu770L/37Pv+OeyXvQ2dKJp2x7Cv
7q1L9CammVXDE6OopDhw65S2sbXPNF65wKoK2tzXWAKhaLyGQyodQceicsdUdWrsLH642rbr0KV912FR7MPAgAOHHoRLzpN9+EZ+58ZsgTVMWkTGZ7ANULVojENvZWaNKn7GtVzEdmj+Dyb12Or9//dcwvz+Po5NF42ylvw5beLQ4ao4CiIlALnPvOZuitra0O8mV8UVM0fHHC9Yg91eYei8XwyV98Eh/5n49gcmESuzp24UXdL0KimAjBe7FYzEFMREpYSCMIgtA62OuWrfVoYROG0l7JFMC15g/y5+/84Tvxju+/I/SzY5LH4Prfvt5BWbFYzHEd6A3Te+T/NVyjDGAiPiQoKSu7UQXGZ9P2zncieeWVGJaf5Y86Cp//u79zZBe2AmxpaXGV51KplEMCOHctOKO5p0C5/znjs4zRstygr7fzemPX+3dhf3Z/xfdfdeqr8N4z3xtCW6iMLYOXhgINVsYHdU71oCbr7TmVwv+9/f/ig3d+MPTzkZYRXLn9SndmFhcX3fklYsjiQT6SXzOGWT3j5v034x9++A+4/fDtOFI4ghsuuAHnH3d+6Hd8qJuNu1ulzDOp+cFECWiA897aVy1K+T3/9R588RdfxC+nf4m2eBtOTJ2IF/e/GMFiuViJOhZEruiQ0FHMZrNOGc/NzVXcQ1/4yuon3aNaY+FAgwr5pv034dLTLsUTR56IxZVF/M23/wYv+cpL8MnTP4lMJoPJyUmMjo5ifHzcFR63cQL+m0IrCALXlKGlpQVLS0uYmZkJ4fiE81KplPMEifkT0lxvbE9tx5VnXYm9/XtRCkq49o5r8cLrX4jvv/z7OKbnmAoFamFltc55kYEyLE2FTE9IC/E3yzadmZ/BM/71Gfjto38bN/7+jeht68XPx36OoZYhpFZSoXw6em30xlQhM8ZPCG1packZTrQkgcq8W/13PUKBv/vFX34R77jlHXjrE9+KY9uPxbW/uBbvn3w//rj0xyFWajwed9WeMpmMK6RBQWaT/ellsIZ4R0eHg8e0L7dVyOvxDnSdJwyegC+96EtlEspyER2ljlCKj4WGNRbPUAzvAz1trWLFYgSpVCpUlq9ZAby8bx/2/8u/OO82+yv+BskuMzMzbn97enrcvhIWnp+fr6iIxRgtORSaF9zb2+tyOJle10gz91tfeSuKQbk5yJ1jd+JZn3oWXnjcC0PKgINGsXpE9Fg0t5TMWZ/R08w+W09tV9cuvGn7m1x6TWmlhKmpKTdX1qCnUURDTZ0LX6jJxiejVs6zS7M4efhkXPKES/CCz72g6u/ZkI6PAGdha4WDeVeokFn7gLKVL6J1a8nMIAjw/QPfxx+e+IfY27UXmXwG77/z/bji/ivw2vbXVhRIsh41w2Ksk06Sq6//gEUo+WIhJSDcM7oeed+QQv6Pl/8HgHK88Z9+95+w5+o9uGv6LnQtdrngOespUyFwU7gRCgNTeBH2XllZcdaJxovZdSkej7t8RJs2stY4d9+57t9BEODvnvF3+OjtH8WPDv0I23dvD3lqJF3ZIhOajqIKmXCHrdVK2JJrVUi1HuX2nh++BztSO/Dx8z7u9n5ncqcToHroisWii+fncjkXq2SMdm5uzh2Y5eXlCuOBcC9QWYO6Ho9Nf++qO67CS497KX5/z+8jl8vhNbtfgx9N/wg/jf8UIxipgPuoxKiQSYpRaIkemxIEgTLZhxarGnb1CGAHT8VbsKNvRwhqIwdCYU5rdFJBqEJm6EKFjnIAFLKOQlGgpQXFoSGstLVhqaUFi/E4iuPjoRoBc3NziMfjzuAgya9UKoWqHBHS5l3gvdT8XtaFVm+/HkOIY6h7KPT/d33/XTi271g8fefTHdrEoYRAriEIAoc+cH6awuLLKY1i8A4mYgl0B91YXF5EcXFVPuRWciEloD14LYtalV2zfIJ6xtl7z8bZe8+u6XdVllXz4qs9J19ozxc6q3X8+wv+vXxWgwxef+zrccFtF+Bgy8HQXFV+0V
tWI5+GMAv96BnRkImvBoUSCJU3UeuIJO0pt7RaxznZksRyadl5MRSuhAWUlckNopDXiauVqEKLcWVglc2qEGu9Dw8AiqUiPnv3ZzG7PIsnjjwxpAzUU1YYFwhbP4wHqaBW5i9h5FgsFrKoGim79qVffAnPPvbZuPALF+Km/TdhW3IbXnHyK/CSx70klIpE44AwI1OtyF6k4cMDtry8HPKQCV9yrZYUZYXDeiMWi2G5tIw7xu/Anz/hzx0K0tHegZM6T8KR4hEc3X402tranHfOOebzeQBwBRNsLJRzpvClt0YlzK+2RGWt+87fuW/mPuz64C60t7TjyVufjL95yt+gN9Hrfq6GgjKYFZYjZ4LPgWeACoNkIw3TqAKrRwirMGy5/34c+/Sno9jaipnjjsOPX/QiTMn7qlDi39DzLZVKzpvmXjOcoegD+yCzKhmVHve/WVh4qbiE6/7nOvzlU/7SC4Na5EnjfPRc1EjYiNZ5nAfHoflDuPT+S5EIEtiO7fjt5d9Gx2JHCFYn0VWRKfU4ed8azRKIelikTP/tczZ8jodV0MpL8THIa/Uu9czHYjHMl1ZR1J5ED5balkLQsv6ekvCY6hmLrRI1acBbw0MVMkmBvBvMRuns7KybEd+0Qi4FJVx+0+V40pYnYU9qD3459suQ5UABz4nx0vNwccGagK9pAfR+lXDV0tLi7RdK72i9cefYnXjqx56KhZUF9LT14DPnfwaP630cCoVCCDrVDk66DnvYAITmqOQB1khlTmcjsUyO+2fux1W3XYXXPfV1uOxpl+GWQ7fgdd94HVrQgufvfn7Iul5eXnZxPsKRNHQAOMXAw6jEO6BMWtJcax4yX+m49cb0wjSKQREjPSOh9IKBjgE8NPcQkskkCoUCgPIFXFhYQCaTcXCSksAU9gLK7ObW1tZQ9x5fsYd60+RO3346rnneNXhc/+NwOHcYV9x8BX7vC7+Hrz/v6xXcAR9py6Z4cH81l5klG3XOzXpvsVgMxdNOQ+FDH0J+2zYs7d+PgQ99CGe+9a34wtvfHjJUNB9Ula6SEBkCoAegSo75m6xuRKVn+ws3qkS+eM8XkVnI4OKTL67IivCFlJhuRs+lWi/bZkhm1fY8Fovh5IGT8Vf7/gpdc104kDmAL+e+jGvj1+IPlv8Ay3NlmFZZv/x73i8N1VUzhh8ppVxNGftgbJWX/H0ldvJsKAJXb7xcPyORSCDRksBH938Ux3Udh909u7G/c787Bwxn8YwDcORWyj86BJqyB4TlIu8vz1gymcTKSrlyHFGYesKTTSvk13z9Nfj51M/x+ed+HkE2cNBbKpVyDMx4PB6CbymMKJTJMK120KiYqczb29tDClmZ0bWMfYP7cMef3IGZ+Rlcf/f1+OOv/jFuOO8GbG3ZWlHlRQWskqNsaoLmyGqsgmvhOqksfLDkeqMUlPCkbU/CO5/5TgRBgFO2noK7J+7GNXdegwuPvzBkLPCAEbrr7Ox0MSolGWn+OCF6Ph96QRS8tj1drcJML6OylUm6SbQk0N/fj6WlJQdPA3DcA728agjphWbeLxsGDA4OuhrF3GufIVHLvhO+C4IAJw2dhFO2nIJ9H9mHL/3vl/A7vb/jvF/Cj2TNajUgnod4PB4Kv9hGFOrBqcFW7+A+Fc8+G8HSEhLz8wj27cNDJ52Efc9+No6/807MHH+8K2MbBIG7nzQe7KCg4dd0Ou28Yo0ds/KV7YjTjBL52E8+hufseQ629mx1SJuGk1Qh8xwkEuVCIITSFUZXyJp71sxQpXDmjjMx3j6O8fFxDKwMYGRlBO8ovAM/T/wcR80f5UJ5VMjMmdYsEh/C0yxT/eEY1mFhcRYiQ0rG0/oIzVSls57re+5+Dx6cexDvPeG9KGVK7gzMzs46NE0dRfXStSiSNSKAMLGX4Sa2XF1ZWXFyfnFxtQnHw+Yhv/o/Xo2v3vdVfO2Cr2EgPoDxhXEXq+nv73fCXV1+emkU9KqYqEg0IZ7CjMoYQK
gspSVf1TLaEm3Y078HpVIJJw+djFsP3Yqr/+dqvPnkN3sTyrPZrIOW7GdbqFzjrXxY9HpoqLAZuvWC1htbk1txwtAJAMrC4/ih43HDPTdUxHt5yGgFdnZ2uqR4FlWn96zeES10Xh6rkK1yq0UZx2IxDHUPIRFLYGppylmNKysrmMUsBjsGMdAz4D5T21WSAcyYmhYhoGLgBaCwpaLo7+8PEYpswYJGBVpvRy+O7T0WD2QfwGJnuRoQiU/cY40X0xBhGU2NGWt7RlUWCpc1qpR5Dp2BeNRRWNq1CwMzM0in08hmsy7tCYBj/zL3UkMthOa0/CQNH+2lTcON97pZ73h/Zj++ef838fnf/3zIAPbxO4jOMc+U87FFQNSojNLTVC9NkcJ0exoDswPIxDMYXBwMsdS5BmXbK/dBU/YaQac2ctjP5/+5r5oKxZ/bdCElRlXrsrXWWtUAaGlpwVv/+6246fBNuPbMa5FcSWK8OO4MXxqfiUS557TqD8p+39qscqbxR7SICF8ymcTg4GBIP2yoQg6CAK/+2qvxxXu+iG++/Js4uudoFAqFEATHLiyqXLkRsVjMlZsk29HGYqkYAbiF8dLZfLJGY8gcpaCEpeKSN6mcXjILIbDCGC1brhEI90wms5PrUsWmpdjqgU+ftuNp+MXUL9z/Y7EYfjn9S+zs3emUIw+wroWGUSy2mprDghRMNbBkqFisXFieMThlKnd2dtYlFGKxGDpaO3Dq1lNx00M34dw9566mqCwv4bbp23Du8LlIx9OOuUtvbWlpKVTzXJEVKjTuIS1UekIKWUepjAEgv5jHg9kH8dyjnhtCU8gMJ6Exm806ghyhXp51LbnK+8KzwUvebM4uFbJ+Lz43h/aDB4FnPct9riInPM/KfeC8ebe1FvTw8LDziG1jBgpeC1fWO6654xps6d6C39v7ewhKldXcNG9eeQT0NLW0rQ27ROltWsWgsqCYKGIGMzhm5ZhQJUIlyPG+KUytNa43IuYd5eCc1oKwgTLsa9nUusZ6ERX+3hu/80Z87YGv4fPnfR4jrSPI5XLOGeKLqKEadgzN2FDIWp9HQ4MKuVgsorOzE1u2bHFhCF8++VqjIYV86Vcvxafv/DS+eMEXkWxLYmx2DHPzc1gJVkICs6urq6JMH4cNiNOK15QbQh1AuAqLVcJKBlhvXP7Ny3H23rOxM70T2fksrrvzOnzvwPfw6bM/7TZPrW+9PFTIFsbmHPUyETIjDZ7GilYFqldBvPb01+KMj5+Bd37vnXjxiS/GLYduwT//+J/x0XM+GoI26X1SiRI2XVxcRDabdYqZ6+LBUQIXUDYwrGBoRJDFYjG87qmvw//54v/BE0eeiJOHTsY/3vKPWCgu4Lwd52FpZsl1iyHj10LrtEi530xN4HkiBKzFKKop43oE2hv+8w0493HnYmd6Jw5kD+Bt330bErEEzt5xNuYm5kIwr4Y6eD6ojFjgxO6rluvzxVwbFb6xWAx4wxuQOOccYPt2xA8eROcVVwCJBPLnnIOOhQVvwRTeNZZOVe9T47GE2plGpBB1M/utoxSUcM0d1+Cix1+ERCyBIsLyQMNGVIY8F9xf7dzE/d1I6DcWi+Gdt70TT+l7CjpXOvHA4gP4VPZTiCGGvQt7MbkyGap5ACBUYYzKXJW6r/rTRinlwlIB903f5/7/wMwDuGP0DvR39mNHase6a+dX3V9CxOpdKpqgnnS92Rz8nb/4+l/g3372b/jMeZ9Bb2cvMgsZ5Et5oAUV5TCZRaAOiDL0NSPIQs78PPWSSQDT0Gy9yhhoUCFfddtVAIAz//XM0PevPP1KnNF1RoXVYwW4Mqs1sZqepipEfSCEX7lJNq+3lsWPz47johsuwpHCEaTb0zhpy0n49xf8O5488GRMT0+H3lvnYcsOUnEQ8qXAUiRASWC0mpthS5521Gm44YIbcPm3LscVN12BY/qOwfuf/X68/PEvd2vX2IzNk9OyjowdKxLBeXMP7IXRS9
OIULjgxAswMTuBK753BUZnR3HS4Em45pnXYBjDGC2MhgQlvXytdatwmCpm9Sz0pUK33rixjoO5g3jJF16CqfkpDHUN4fRtp+Mrz/8Kupe6USgVQtWHtNCAloXVDAM9//bVzP76RuzQIeBlL0PL1BQwNISVpzwF01/5CmI9PWg9ciQEHfIz9W7pWVCDjYpOS2QSObFeXDNr+Ob938RD2YdwySmXAKjsdewTljZOuZaw3whlHIutFvC57L7LkFnMINWSwu6W3fiTlj/BXGnOGcxacEj3WMNe/Le+Nhqqvu3wbTjzE2XZ/rr/fB0A4OKTL8Y1z7um7vdTdESVmSJ6za4xFovhoz/+KADgudc/N/SzN534Jjwh8QSnb/RM6GfQGdNUrLV0C+fN4j6a/66coshjyHyzXG41vSn72qz7vrLT2HNSC0wo4UqhJavs+H5aNETZ05ZBrHByoVBAW1tb6HBXm/v7znwfcGb5YtNTZD6rer+WZa31nNUrYuyVg1a6rdPK96K3TVSAZSDXmzsAPH3k6fjBy34Qej65XC4EwSiMquuw9ah1Xznsc9G0HeYu8/mxKpZlR1ebOwBcdNxFeNnelznvN5vNYmpqqiLNTOdIOJWKraWlJaT4uF6+NDbI81mNPMf5rTX3q591tfu+FvqYLEy6s6Kfr/PXUoj23NhKRfy9Wol+tcwdV18dOhuuWMz0dMigtCEgvjT0oXO3xilTpGpFfWqaO4DTh05H9rVZBEHg6hMQieBdVYKUzo+cFd1jom7NQL927jpvypN3n/ZuF7pgNa6xsTFkljIhcihlh8/4Vx4LeR31kEDrnTvHqf2nOhnv+3ueI87Nyktbm9uiGJolYctnamc03lsaJkyB9M09CAJM/8V0CNlkydepqSkcOXIkJK+t/FP9pP9W6NrnJVMWWYeN55KpT4xLr6uYgxrGgQMHAgCPmdeBAwc257459825PwZevw5zf6zNe3Puj/zcq41YsK7KXoWKDx8+jGQy+agkE3AEQYB8Po9t27aF8ss2576xY3Puj8zYnPsjM+zcHyvzBjbn/kgN33n3jZoU8ubYHJtjc2yOzbE5NnbUXq5oc2yOzbE5Nsfm2BwbNmoidT1WoIFfJxgM2Jz7wzE25/7IjF+nuT9W5g1szv2RGrVC1pukrkfZa3Pum3PfnPtj4/XrQC7anPsjM/dqoyYPOZlMAgAOHDiAVCoV+lngSbdhmkE+n3epUIcPH8bo6CgymUyo1RjfQ3N1tbg6qwNpNSOWwuvr63P9Vtn9adeuXW6+681d16AFRjQXTZtMkM7OVnTT09OuxaSmN7BCUzwe9zZt37p1K3bs2IEdO3ZgeHgY6XQay8vL2L17d9W59/T0hDodaepHJpPB+Pg4jhw5gomJCWQyGVcxinWVWSIOQCj/m116tmzZgm3btmF4eNhVvNJWdWvVVs7lctixY0fd++57DnqWtHVhNpvF2NgYDh06hIMHD2Jqagqzs7MuN51Vmbq7u9HX14eRkRFs27YN27ZtQ39/f6gyms6/nrlrI46ZmRkcPHgQ9957L37+85/j/vvvx/j4uGt1yXSWjo4OpNNpHHXUUTjmmGOwZ88eHHXUURgcHHR7W62703qjnrkHQbmPdC6Xw9jYGB566CHcf//9OHDgAKamply/Y9sbW1NXeE+531r+c2hoCEcffTSOPfZY7Nq1C8PDw0ilUq5ErKZCRXlmbKoLUyGnpqZw+PBhPPjgg9i/fz9mZmYQBAH6+/txzDHHYN++fdi9ezeGh4fR09PjaurXkm6mc19r3npnKROZYslyvKzwxruq8iWTybjObOl0GsPDw9i+fTu2bduGoaEhJwMHBgZcjXyuw3eOqs39oYcecs0RNBWJ+3jgwAF31h944AFMTEyEytpyaA6xr/Qlc9m7u7sxMDCAo48+Grt378bu3bsxMjLiyt2yRC/vBdOeat13K0tsimE2m8X4+Gqt8bGxsVCf8iAIXOpqZ2enq9G+bd
s2DA4Oun1mIRwtqlPt7vrOu2/UpJB5QCmg7cJ5IVj60NYHBVaLIbAEI/MtbalL+2BVSXIe+lC1ywaLEuh815s7P9MqYSaEa+Uu5mHaNoxsCck+zSw5yEIWthKSFlXQGtHa7cXOnb+jiopl6JibqFV9tEC9tsqjUOXB0SpGWjeXVdZYbpD761NovnNSy777Bveaub6sGsWew7aqjtaJ5UvriGunHy2z6Zt/LXOnQmYRm1wu52ojJ5NJzM7OumI2NH70knLuHLz0nCerXDVSFGG9uatCjsViWFhYcHXVaUCqINJWnPriYKU9bQigDWNYEjSZTK5Zsz2qM8O1UQYUi8XQ/JgfDcDNU89GKpWqWSHb+a41b55T5kVrvWStBa+5uTR86JTQCNIa57YWtO657nc9c1cZw0pWsVjMNUmwRiOruPHcUDb7CiDxRYeLCrm/vx/9/f2u9nxfX5/ryOVrCFPrvnNelOdUmDwb2jxC626r/FAj31cumPK/nlae652tpppLcNFUQHNzc8hmsyHPkS8qKE6ek+Ml0UvF4grcULXMuXHa2rGREmV2/tYb1sIjbCCuhUL4O6zSwgvOogMUcLz8XKdPkdQ6V7W2tdCILVxCpdDa2uosf0165+fSSGKrOs7PJsdzjvy7jYjXqFLVjlta/CGXy4W8C3rILEoBwCkT7WFdzz6vN3gZta62CnUaQNqwQe8HvR9tuKLV6jZqf+3ctZ42PaN4PI6Ojo7QeVLPkwapKm4KIe6vr7CCfQZRrk/PpjWs7YuohZ61qM7FenPUu6uemtY/1yY2LDOspUgpOwMpJtTR0REynqyB2uheq7wFyoalertaYYuKVyvlEVWjEldF19XV5dp1KsqpTVX0Mxvdd3sutCCPylAtPsS9pnNCfcVCRpS/asRGVf2t+X7Iv1osoSJWRRkfH0c+nw/V9bRdPrSUIwVWJpNxFiRhX20uwQ1ghx/7s0bmT2VMqIvCXi8LmxxoMwn+PRUbANecWksl0vBQa97n3VUb1nDQ6kOsxqWdqIhKdHR0hA6K/j0Pk9bN1f3wzXGjh14cKmNt1sBXJpNBJpMJdW5hmzM2mtBzZ5GYRocKJ55fIjR9fX3O6OR+LywsuDWxGtz09LQzIPQsN9Ifu5FhEQTCzfH4akcwW2XMKhA1dKzQVuXrU8pRr6uaMtZ6xLZaINELayj47mGzc9X3tEZ/Pp93Snh6ehpTU1PO8Od8Abia+Gq0qYfHM6g16Zs57z6Fa+tKq4eoIQztfMRQDENF9CC110E6ncbQ0BCGhobQ29vrbQZjveFG9t+G+bT0MZ0uGtC8h0QbaBAFQeBCf2wcRLTJ1kZvxhCKxEOmkC8UCpiensaRI0dw5MgR5HI5VzqQ1obWwCVcSqWbyWSwvLzsvrIkI5UHDwKVHqFxhbXrnbtCSvTwGYPN5XKYmZlxsRwKWHoSWqiexgZbvyn8Ta/CJ6ia9ZBtWUx+Fi8r95mHip4/jQt6Osp0taXk6plnI0MFl0J6bGnIZ2AVMpuJax30RCLhDDUVUFEZFSqMKHzS6bTrmsX1qIJQhdzR0RG6tLzcbAm4kfusQlZjwCxRyObt9l7QUGXvZG3PSKNC121fVuFFrZR9CtkiRnwOes7tvKzRUO1rI3NUY1i9Y8obnuv5+Xl3fxk6opKgnOE95nNkAxlbQ3mjhkUrtc0kERd2XNNWl3wpN6ivrw+Dg4Ou9WjUnbjs3mu5VyplytBYLObq/pP7w5LMfG4MF87OzqJYLDqDSD3pZkbDCtleBFrS9ALoIReLRXR3dztrh7EOJbPE43Hn0VA5UBhQ0GrzgM7OztBla1Tg2odFggUFPi8KW+kpjERhSquPD5MXX0lJ1nDwCap691uFj+4BlQU9RfamJRw6Pz8f6k1Kr5KHfi2PYSNGNYHqE1raZ5gXCYA7Q+zyY7u18HOiGvQaVKmRmKcGE0MaQRC4NeVyOSdoiPLMz8+HQi8bGRrQOB/v4srKiqvJq59L5K
tQKLg91oYNPHsWqrOeqz1HUZ8pn7Hqg9l1Xr67RBTLKmUqIP6/mTna5iOqFKgYgDIKwzg35SL3c2lpCbFYzCFjqpB57huZszVQbIiNg8pYnS16lexHzi5glPPaMU45NKqM6ZE2CwH71qA123k/1XAjSqjGUGtrawVKxP1vbW1FOp2uCBc0Mxruh+xTyFQ+2vSBLj1jfdrKT6FS2/2DgoOLVI9TYbVmveO1oCSN6RAKZlxDCSyMwfK9CLUDcAJLP7sRr9OSXxQtsG0faSSw3zQvNGEwesmWZKQxIQtdbcSwqIGFleiVMXasipjwPICQoOL+8v31axRDz6b2/06n0+6M0ppWRatWNp8Ve2tbo2ojhoUheX7Zx1XPB58/PXt2LtPGARaGZizZMk036uxw6H2yyBGFrRJLLcSuCpJr1/1SI2QtFu16c9SzaWWANdKV4UtiXHt7u1MEmjGhHYYUmm9EKVt5ri0iKXOt4U/lSsOfGSX9/f2uTza9fFXIPiJao72Q1xuqlG1IQ5tLKDFNOxVyP2mc8r16enrc/Y0KhWvKQ1brUhWmHg4eCFp8ynbUi83Lzr9pb293cDfhGiVk2A2oB/qt5t1rzJLpWYwT6GEiTEn6O2OCyjTXz9LLwZ/pz2sdFA5UuB0dHaH9D4LAeem08ohExONxJ7BoPPAgqWeg+11vv+Z6hjWIKADoMRQKBWccEaHIZrNOiREd4VmyexNVP147VEgTrkqlUm4vlU8xNzcXah2pcDwNi4cTauT81UPu6upy50bjwgBCBMVCoVAR0+Pe2zaqqpg30qjTu6yGkKYUKadC+Sv2eSQSCed1cn2UWbo2a8TWMqxypxzkOVUlReOGxjS9zLa2NofiqTKxCpPfpxNU696rcazd4riXjJ8StaSj1dnZ6ebJVEkypvv7+12fbJtpog7NRipjHxqiessSgtUp0dQx6igaQyRu8nypMd0MulW3QvbBAJbJyInpIVRGJgAngPUAFAoFl+LS2dnpPB8ATklovKpRWNVnCRJu11xiwketra2OnECrlVR7xgSXl5cdIYPekEJm6sU3c6F5kKnglSjR3d0dMgiUTKH7pMKLh4eGkKZR2DSHKC+JKmPbtoxM5KmpKUxNTWF6eto9E54R7qkKN8JiG3XBKYwpNLnnyrYE4ATn3Nycy1umsONZU3a8Db1Q4OnnRjFU2NCg4x4qUYvrpCFBAhEQhqEplKnceWaqnZsolbNVxipDeI+z2axDKbS+ARXrysqKU+Bq3HEoD4P714hCBsJkOioyGszqya+srKCtrc15mUNDQ0in00gkEi5EQ2XJZ2Rbqy4vL4fS52qJgfN80qkoFAqhWgvMkebdo0HGPN2hoSH09/e7egt9fX2OqKVnQg0RTQPdKDnjg659xDe7Vz4khXtNR9HHU2l2NKSQrXesSlkFlvVY+D2FXgjtKTzS0tKCVCoVymMG0PQD0wdkCSCMH5NAxB6bhGP40tzh7u5uF/Sn8OVFV8FrUy4aGXqhKfwpLBgr1lQsPXA0EDgXhVTpIdPqU6tVL1Gz5Aq7/6qMbZhDY/dUxkyjY8yMkJ0qRyXB0BipJ6+0lqFGJtPJgiBwZz0IAiwsLDiYXQWtetCqkHk++NWiEjae2ez89RzFYqtcCOspEKKjgiC6o/dd914NIStcox7W66GnaDkg+Xw+lBrHtTIeTmMpkUg4Ug/fm2vr6upyd4SveoYKekVVNDWPLxrG7e3trlgPFXI8HneFONgXWDMmtA80uTj1pA5xP9WwoUKemppy94/yhaghiVlkS1MRp9NpB2OTlEb5ry/VD1EqY9/arIeshqXPWKQMtY4b91efX1ThsboU8lqkCJ9nrA9AD4XCehSy9Cx5EchYs8SCZlND7EWmIUB4hpD17OysU3oM3jNfjonrTElQQUpBRoFrSSWNeAvqIVOw8HIXi0WXZmXZ3eqRaXxJ50e4EigXqqDlqhBksxfFd2bUGPLlZFKwksilFccUYtcYOo
0JMsujVgwqXOkZ6hoXFxeRzWZDJBXCfFxzLBarIJUo9AhUGp98/s14mWr9axEdwpsUQHrfeN40TqtkLt5HLarAvbAkRrLLmxlr8Q70DKlHFwTlrAMiKUS16HXqOtXg5d/SQK3HE1JERQ1qGtI2vQxYRQ7b29tdSKy3txepVMo9u3w+7/axWoqXxkXpAK03Z72bSnC1REobk9ewDT1jZVhrwRtfoRAbn98IZWzPosbY+Tv299VIsntr00L175odTcWQfQQFVcr8aiFipjlRGRLaVUiWVrkKASVV1SucfHFjfralwBNupPer0JVS3H3Ciu+p3g8vuM3pq1cp69/xUtqDpkJKSUUKC6sSAOCMKfV4ovR0rBFkY8Y0hChMqYQZBySJS8+BGjc2r91e/qitbn0GWvBA42O+ODz3Qdn3igx0dXU5z8OSpNSraJRcVG3++j0fx8KmElEoU5GrgNPfXVhYQGtra+izNWxV79DzrqQjrR9Ag65QKGB+fj4UiqHRph4yf0cNCe4HDe7Ozs6msjnUmKYcoXelDF5CzcpE5gtAyBmxRolVxJaxX+v+qmxUWajFR6iQ9e7R09Vzr+efv2OVspVrUQy73rXiyKrDKJuUGBiLxSrIjFEo3mqj6cIg1YYlCfAB8wFQsAIIWdc8sKqQNR7LQ+BTbLXMSRUnY5Z8KXuXBoLCjdbLobJdWFhwxUPoyZF8BCAkUPWg1nsI+buagM7DofPjvJUZSEuXuaSMNatA1kvjuzg6h1qHtbzVWOFFp0fD/aNCJqlOU5iqETB03tXyGDciLuUjNvoMU/4t16IkQsbLefkV4lNhp6hFM+GPanE1a9DRWCBBTfPeuVaukXeXXhU5ICsrKy7NkbFYwuWNzFk5J5oaxxAHPWMWMYnFYiHFZhUyvT5lI6vnr55QI/ttjR8tjEGEjlC/z/C191sVB2OZahCp0dDoHluipcamqeitscZnwsIZPnSNMkvvbTOITy1rsefcIoiKUgRBEKrOSFRHz7zKSR/U3uw6GlbI60FpasHRC0okEg5/59C0IeudMaZC5cVL00jiuI1b2rQaLVmnDF4eRl5+AKELXCwWnUKemJjA5OSkY2hrXNZHlmoEDtbf071WD4deO2OZSo5iLEgZ1jQUbHzHemLNhAh0DxnbU4iRxVjo7RAio9BXo8aeO51rNUMiyqEcBDXsqAh8pTut8o7FYk6RjI2NuYuvULcvRaSrqys0j3rnbVEiK3DVu52dnQ0VY+FzY+oH77N6Z3oOFxYWHHxJgaeCud55K+nIhjdY+4D3r1AoIAgCF7pgeg0VXxCUeSwcKmiDIHCpYM0oOSDsIVMZq5Gaz+edYUyDWtOxCLvbNCQffKrzrHePda9tuEGhWTUMNN5MDoVyIpaWlhyvgwaZypWNuJ92XXp2lAPBPeR5p7xgOIdhy5WVFXfmeUbWqrXdzGhIIatA9EHI3AhuAgB3CRYXF52VqsniFubg79MK1PewBeNrUWiqkOk1UhFQGVAhWxiDVhPJH4wV80Ixf3l6ehqTk5OO/ADAKWBlour3Gn2QVjHz4FEZs3La5OQkRkdHMT4+7tKHNLamiIMPedDP8nl96w2eA3pcymKnMUQSFwU+vWKy8DX2popDjUCfURG19W2FEb18XYfG2zRuT2HAeDKNNc6ZleroSTI0og1ULELUiHKwKJHCkRrC4ZmfmZnBxMREhVKmwaqd0Ajr6R1j5yh9jkwTrHW+akQQos5kMpiYmHBdqnj3qJCZrUH5wlxxJd4pVAyUG2ZoA5xmFDGHet02Vr+0tOQ8diozGtMqd4IgCKF4VMw+L97GROsdqpzte2iMVe80C2hYmJtV6Hp6ekK6YCO9YztXi/xoqJT7SD6Nevs05vS9tMoYlbKGEoDmjIy6FLIV0ja+5WPKcZEkTyhZpKWlxbGVtUIL/5aCHEDoIFdTxmtthD4cG3tSwoIKCoXI6B3TelbLilYihTIFlsZgtXCHtayi9OSsAqSHrJ
6DIgCqxPSZ+V6Nzsda01qOlIJeWZwaI+aeA+XOVjYWXy3Oqhe+2f21ypgKjeEAKmMqLcui1b+jhwyU7xHJYPSCeeH5okLj+akXQvXBjESJKOhtfV+ubXJyEjMzM6GYPtdAtqkqY/WOeH/JyOVdr2XuVhnrvaVCHh8fryirWigUUCyudnuip6sVr1Qgs94+ALevhN+bVcRAmEjH9+UoFouOq6IKGUAIdeF7MHSgKU7kudh5Nnru9W9sjJffA+DmyoYLPMM0JGzzBiIkPj2xEd6y9fQt+qNIA180+LSAj+ZOqzK2DPKoyKNNechWGFriBNm7HGrFEcql5WqtESpCAE6AKMnKwpK1wta+GAmFiEJq/Az9fE2M18PGQ0jIlTmCHHw/m060EakhNtZmy/NRIXBemsqhZBmmn/lgYo569tw3H+4Vv7KQA1BmcPJMMC6+uLgYKnpgFS/nyrVFFTdWyFmZ+eytqsVLNE7Pi67wohagoGHKtc3NzbncVJ5NoKzQ9KzWO39VyHwOVLr6LPiisUpjyULxsVgsxKfgueH7EqK1XIBaFJ3Pu7GpTVNTUyHkR5nVQBn+VbnEveYd5ro1PEZIuVllzGHvDVC+ayoLOGfOjZkf3A9mGRDd4nvbVyMyxSpiHxmL8+S9UoRQPXUiQypjeV74PqrYfZk4jQ4bO7akRJXZGq6xzt/KyoqT2ST30SCmQvZ5yI9oDNmniAmvpdNpAHCHhwLW9kplH052JqLlRWVMz0AVRxRwh7VcNcbL3sZK7gDWLnvpEzIKpdr8Xht7iEJx2D3RtWlIgJdByWWM9c/OzrqKQbom/rtRopQabzZHnWeHLFMAbs/oHS8tLbmQBy1uewb1faOCrNWAs2lyWmebMLz1kpUkqGQS7jkFhRqZXDuFmULf9aZZWGWsRhEVG71fKmFrLFHJabqLomKK/ihPgnngPiSoVjRL05po9NruX+q985xYPoENpdn8d85PiUsctaJwaw1VyvTaq2UEaHybucdBEDijmspNK33pe1ULI9Y6R2vEaKjNl/5F9EufH7/aM8/58U7ZkGUzStkqYkVq9FkTEVI0Dgh7/9Q7/JmWIvY1znjUxJBV4XDCQ0NDAOBaVPFvKHy7urqQTqddZReSLSikqMQ1zclaf/UKWv17HjYqXrWgaAiQuGVb4/EhWesRQAiyY7yBRggr8xByjLLghl0n583OK729vY5Mpy0keUFisZiDX6emptwFW1hYbWLPQ06o0ULctcxHGfQ9PT2hNDdCipbExTPBmB89Mo37WYRGBVyzcR293Na7J/Q+NTWFiYkJTExMOLhUezZrzE/Z4mrw+Aw5FYzVGOPrrakazE4PNpPJuLgr0R0qYgosX5yVd8Gm5jD2zSp2PHvpdNpVuuM6apk3PXn11kmeJFGRzGqdI+837xkVAGFLX+s9nkX1/puRN9WGPttq55ZyhEVN6CVzL0qlklsXc31pDPE+WVLmevNXhIlM8K6uLofWkOSnbQctiY+ym89Qw3pEitQg4txJVLRQeb3DKmNNrfS102WmiXW8+FxUhrIOxcDAgGucwdLJURK76lbINvbFi1ksFl11LR4UClglD2mZR3rIhKtJcqimjOsRRr6h3kdXV5cT/irsuB717G0FIrW2GUdhUwfOb2VlpQIN0PQPGzNv9kFag6OjowOpVAqDg4NuLiTkWBp/LLaabjM9PY3FxUXkcjmkUin09/c7IaekBj0H6wlXzkkNMWAVgk2lUiHhqHnbusfLy8vI5/NunoTu9FxpHqT1FJqxuK1ioOLKZrOYmprC2NgYRkdHnUJWyFcVsTJjNXOARpGt70thoG1K9eLXMvdq4QKFfEdHRzE2Nobp6elQzrcvJqyKg3eYZ5tMWo2v8ecsGEHhtZ5SVqHNeLFWb5uYmMCRI0cc2Yxd5ShYuXeMV9OQm5ubQ6lUCuW88zkRjtQcW18MtZlBxWXvqnqjJC6urKy4hircKyob3kPeITaQ0Sp1jRj8Ks
8Zc2fsl8+D3i5j8b67ynnyrnZ1dYUaqfAu0dOsFn6qZ1jv2BrQbKWrfAiecQ1P8pmo88jSoMPDw64aGT1lxpH1XjZzVpqGrKmEOOid9fX1OatIf6ZKmVY1YUmNGVc7SPq9eh6ezpfCBEBIGKq3TzhDL42FRXnBSMiJx+NOoRCOUWKOVviiNRulh8w1sk4xDSSS5+hlaK4mhS6FcC6Xc39LpWfjyLr360Gnaogx7YGsV0uqoCDS2A8tXOZtsqcwvRkKd4VQfbFvzqWeYRUyPWMqiMnJSYyPj2NsbAzj4+MhAa/NUhRytgJfYUF6mHpWSPKqVp97vWH3UqHfyclJjI2N4dChQ5iZmXHGmk1jstC0driip2BJaDznVIza8H09dEUVwPz8fIhRTYLixMREqE85ACcY+Zm8ZwDcs1hZWXEKWUlSZAkzJKBnKqqwkg3/8BzwfmgIiZwVevVA2PPnGrVMJati2XNSy1lR45AOSyqVcvvBc8w9amtrC5UFVgOQv8eypO3t7SGDm3fJeqjq2auir3VvrULmeVcuBAmllG3qHav8UMY9ezYPDw9jy5YtrkY37yqdqyjOSMOQtX7VyfABKPlD/44HRKu6qOWoXpsV9qqA6124hWQ6OztDwkYt9paWllC9ZCvw9fAAcPFmHkIeBgtXayqLegpRwWFcJ59BT0+PU6hdXV3I5/Oh/LnW1taKtBeyTAnZE2JWK5D7oZDYWvNRr51zI3NYPTCbK6jwKgAUCgVnwDF0YOG/akI0Ki9ZIWutDqVsfSWMcE1WCFP40XiiR0JPh//2KeRalbGP3KLEKypmjSMzRUmVMT9be9gy5DQ4OOhKynLehCE1rdEqivUUsjLz1cOhQCXsyIYuPFeWr8FzQpIcmdX6rBg+oayywrkaUlfvObKQKl9KJqUyYmYH7yawiiolk0lnaGhKke2oVK93rIinnknOT/kTzHawmQRqfCp51Nae51kkQVcNqEbIdNbQsfFjNaZpjNGY555p3X7uIc87S5hq0wzK0agJuk15yAC8go+XQi0n/TuNnSi7UNNBfEnp+vf1Lt7CRCoceQn4WVRIFKRcp4WY+D4UwPqASNxQ74evqJs2VFsj0zjoiWn1HF0TAHeRGB+iMqYnrXCjogXrXR41hNSD12fri/1Q+RE9sRWsEolEyNtcbw7NDB9zU7/q7+ga+NXnEalhSCIk465amJ8KWcu1NhNjAyr7w1JwacEEVRKxWMwZUfTIent7MTAw4CA8hmTU6NQaxtVQi7X2m0pZPXtf20oiarq/CqOqYUQY3L6HesZqdPuqMTWy73o2bMU6S/qza+adBOBypLWZivJSbPy4njmr7CDfg/fTl8amzXPoFXNtdKioqLnvigrSUUin0xWFdPS+NLPXlu3N/aVCVkOTMLXKS3WmFP0hJ6JexGq9EUnpTFUE6q3EYjHvBvssThVsNk1ELwuANT2hWuap7DkOVQ4tLS0O4uJn6xr1szXfWC1rEi80plUtHhiVQtb9VUuXa7aGjHoA9Ea0cAVhM5JfeJlUyNKYWW9OQDlWb5WYD26iobayslLBIrXP276HJUutB6uvN3erROn50RPs6elxubkcXIt+vhp1KgjYao+xKe0mpkrOFpOp58zruVVmu3qTetcIJVIopVIpDAwMhFrssQk9PQZ68tZzsPe01vPuM4LUOOd7WqHIM0ADk2dOU5204QqNVh9jPArjWWWa5uqSk2DL7WrurqYgEu3i81MD31fXoF5lbBWyOlTK8ifjWOdII5r1D5QXQieHBU+YW93a2uq4RnSA9Hk0qoz13NgX5QAdFsaICT3r3nGvKXe477YwVZQyvGmFbJWuWrbcXH6Pv6+xZ4WnbX6wWrAUaABCFrduRq1xEgon/R4HY+BUTNWEOQ+p/Zn9DL3gFODq7URF6uKcdB38fDWCuGcK7dGAUK8OKKczzM7OOviPHi4PpuUQVBu6Pu69KkwA7iLSKFIURQUNz4EPolL4zF7Ceu
NSejH1AlNxaToUz00ulwuhPvwdjVdZBvzAwACGh4ddQ3flGtiUk3pyHu38NQZMr7y3tzfE3G1tbXX3LhZbbbBAeHrLli0YHh4Oecma/qFtL33Fe+o949bQ15CRpsopAZReGRUBzwrPt8Yv+Rz1rtiYN9fV6F3VM0plpVwOMvU1j52hA001o+GhMV418KPw5K3cslA6FVRPT08oLKMessZsmY/P3yEUz7KmfC81kOjgkOjFeTUz7D3gHaIyHhgYQDKZDHGZGLJRJ9M6nI0gteuNphWyCjuFG5UAoJYWN4YPHICDH20wng9dBZpl09Zroai3poJC56VCV5UFUIbWOUfCYaoA1AOll6oK2Vc6MypYVYcaHlagqddOhcznxiYghPtmZ2cdlKOCq5GKUbpOhRV1HXrwrTC3f6OKUYVDIwU0fEONFxXgKgAZi6KQJALDedDYpCLRWCwV3datWzE4OIhUKhU6IwqdKvmwlti9ClgqZO4V2c8DAwOhGGxHR4cr5AEA3d3dzmA46qijMDIy4m1DalEM9dbsc19v6DNXRaD7x45JQJkMalMnrYenJEF6ZEC5CA3vBOOyWvyB4ZpGFLLGw5kuR2Kg1uBmGVDC8pQtqky06BKNNHsumjV+1Mi2RhB7OSv0q7Awm8VwbZrZwX3XSoHpdNplTagsjULZ+YwJyq7W1nJL3aGhIaRSKbS0tLisGUW8fPwP/jvqEQlkbeFGm3umUBMtPaW664Xh32tpM/1dFUyNWK3WG/Y9NPXsredPK5DKSmNxqgB88OBGVOryzdF6nnpJ7eHks9BKYyrIeMn4u7TOFbloVulR+dfqSfEsKRSlBp2GOtZCOWqZl0JXZIlTSHGeqrQTiYTLIc3n887aVy+bXhh7yA4ODrpm9CySw/3W2KsSEOvxkHlvgHKqHhVOb29vCFbk55G5zHnScKBC1iYY1ji2BlWjw55V7p0a+/F4vCInnedJHQI12FRZq5GvYYhqsfv1DCEdFnKnh8yUObLE+VVLk9pCLIpKcS+0qIZPljTjJfN8WzIrGfbkRuj9U4XMfdNsEpJeKdvb2tqcd6wpgb6wU6MGhvWMaViTIEeiViqVctwhABXGtEXbNkIZAxErZMtEZWxBN1aFAw84U1h40bTwN60UXnpuajNWoVo4Csvw4FshrpeKwmthYSF0EKkE+L5KfrIeTqPGhG/fdW7K2qx2cBSqB1ZhaY1RMuXLpjuokWQvzkYNNTDsM9H/KzObX3170chQAcXPVKGvsDVzLy2Tni8ttqAVfxiTZZEcFWIWqm/krGuoh566TbWiArAxvJ6eHjdXCq++vr5Q43kfiahZ1Ef3XVEmLb8Zi8VcupKFGIGy4abnVQ1JOge6H1rkxOd9cm71DKuQmQeuNdwpLy1vRWPwPsVsiVzN7LvKxbUMD87DIltUdLFYLBRTpk5gLJ+lY5XQZhVxo/OvFuZQpwiAIymq8cU1s5QqUA472dSujZJ7kUDW1ntkDEErXwFlT6Ktrc39nw+Jyk5JDRQSFE4KszZLtrBKmYLWt+G6Pl58NUC4TvWOKSy0OLkv9t2sd8x52cITNv6jh5V/Z2MknLeS3nwKv9mLs96arFWqKIsPubBxZBtCiOqCK5THfedztXCtehaMl2kVKyo25jRqupCFRy1MVi/8C5Q7pfE5q5DXF5UT/0Zjquo1avqZj8cRhTJWZKK7u9vFtZUUR+/Kd2dp4KtwpVFCAykW8/dLtrHwRu+rGo2Ey1lxjPFiGvdcW0dHB+LxuDvLatzrv6OWJ755631URr7lBOkz0/Nl31MNPjWadUQFUyvUzpBSd3d3KDyj/ADOUVEVPodmZUmtoymFbB+WdmLJZDKhrjBqqTB+zAdGRp/tOMNYpsJ9WgFIhUKzlqF6y1ybrhMoQ79qfGhshMqaudVKcqnmSTSz97rvumc25u77bAoJCi2bM677wn/bz496qBFgL4YqWv1sC5lZyDoKi1aFTRAEXoFjDRagbIDG43
F0dHQgnU670nuDg4MYGhpyRK5qRmazEKT+nb6fja0qhKt8AVuasZqiamZ+vvlqGICGDPeRDF+tKGbJozTwmcfOSm9kXbe0tDhlR4Vv12lj4Y2sz6JrZFZTRjJeH4/HHdNXDX1ttlLtFaVM8d09zbZQRjQNCMqXIAhCDpWmddk0Un3O1RCgWs+VNZwVotZ8YpLReH67urqckc3zT8eG+59IJCocnY1Uys15yB/+MOLvfS/aRkcRP/FEZC+/HPmREdeblPFIXnItCajCl1ANCy2wdi3p/mSXkgjDIgQKm9VzYfKLebzlO2/BDffcgPHZcZwycgo+8JwP4LSjTnO/YxUz/2+VIKv9kLXJtfIg2IIOjcJeAHDVrVfhqtuuwoOZBwEAJwyegDc+5Y142panuUo0tLiBcOyJhovuFSFWrQik61XvysKn9V7+m/ffjH/44T/g9sO340jhCG644Aacf9z5Fb9nFYV9+bxkW5hAIWs1MhqNIx/zgWOwP7u/4mcv3ftSXLrr0hBvwrJ3ie4kk0lX7WdoaAgDAwPYsmWLI3Ix5hY1e3PX+3d5537hsRfij0b+yAkfraqnnoUWnrDch6i9Mg5VyMyF/fi9H8f77nwfXrzzxfiz3X8WEvIWteEdZTxzeno6lF2g4QcS8rSdns13b4QpXiwV8bff+Vtcd+d1GC2MYkvXFjxn5Dl4RuwZrjlGJpNx8C0VBD0yZjewmp7udTWl3CiKoiMIArz9prfjHd9/R+j7xySPwb/95r+5zmBEHijXSahjtT+mc5HUpUiGhgoUWbIcBK6l5rkjwNu/93Z8+q5PY6wwhuHuYTx/1/Nx4bYLXby4WCw6BIKZIyS00inUErJUyFHB6uuNxhXyZz8LvP71KH34w1h8whOAD3wAR11yCf73n/8ZmV/14C0UClheXnawE1/qjTGmZRUylQQHGX5UyL6KV7WOV3z5Fbhr/C588vmfxLbkNnzqfz6Fsz55Fn72Zz/DUamjQr9LbxIIQ8S2e4iW2lS2tq/CUqOXZXtqO64860rs6duDYqmIa++4Fi+58SX4/875/9Bf7HeVjNgRRuNvlgVLCFPzMi2z3FrAGveuV2HMLs3i5OGTcckTLsELPveCNX/XhkFsKpN9HhZWX4sc0si45RW3YKVULun50yM/xfO+8Dw8e/uzQ0VMeIlp2DDM0tKyWrqUbOWRkREMDAy4vF6thNYs2mPHra+8FcWgXNT/jsN34Nzrz8XvbvtdLC8uV7SjC4LAnRsyjvUcV4Oooxw8X+Sa/GT8J/j8g5/H8X3Ho6urC1u3bq0ITdjwBbMDpqamnNDlMyFnhYqQKWha8KHZtK13/+Dd+MjtH8G/nPMv2N2zG9+///t4w/ffgHwqjx2FHU5hlUolZwAw9YYkSzWa7d2N2kO2EPXxA8fj+nOvL+cVF+aco5XL5ULNd8hETyQSWF5eDiEArGJH+aLInS9FzsLvta4nFovhPT98D67+8dX42Lkfw76+ffjRgR/h0m9cio5YB85KnoXe3l6336VSyRkRyj3RWv/0kJVEuNGwdeMK+f/9P+AVr0Dp4otRWlxE/sorMfj1r2Pkq1/FXU96klNWKysrDh7yxSGZd1YsFiuq8FBp09NjHIACohFy1PzyPL7wsy/gxgtvxNOPfjoA4G3PeBu+fO+XcdVtV+Edv/OOqn+rnptW0tEepWrZR82oPnffuW4epVIJb/utt+HqH1+N28Zuw9N7nh7qbUs4TmNivAQKMTEvUNMsuOcap/KltNRzYc7eezbO3nv2ur9nYTP7PR32s62yjuriDHUPhZT+u+9/N3alduGUvlMwNjZWEbqgIcnnTriaaUaDg4OuoAa94yhCL9XmDpRJN1/736+5uR84cKAixVDjyCR8aZpeowqqnkGFDACFpQL+9Bt/ig8960N4z3+/B+3t7ejv7w8JRz0nCg/n83kHo5I4pZ4dlTMNfV/N8EaMTwD44YEf4rx95+G5e56L+fl5JI9O4r
M/+yzuL9yP4aVhB40GQeD2lt4m5SMVsaYBqZfv8+CbHbw3iVgCA+0DmC/NoyW+mg40lhsL1YJmWI5GG9OGGCO3BU8Y0vOxxKMI6f3Xwf/CefvOwzmPOwfFYhHbe7bj+nuux8+yP8N5w+chmUw6faRFfKhraFgreZWOC/dmo0djCnlpCbj9duCyy8qbFo8j/5SnoO8Xv0DstFXoV70c/p4yYpeWlpww4oawQxShGWUVqnJplKm8UlpBMSiio6Uj9P3Olk58/6HvV/07G1vxwZSxWKwiR9AqZCAar2KluILP/exzmFuew+P7Ho/lufJ8dA/VurbwEAAH7RGGIrlEY/5Keoki53G9oVA5/2/JZ1Y48d/Wa7Dv2exYXFnEZ3/+WbzixFc470VhLvUyWS+d1biUyGUbAkSBoKw3lopL+Lef/RteedIr3b3UEo5qjCmBspnUn0YHn/lffuMv8dy9z8VzHvccvPeW96KlpQXd3d0VRpqFqylTCoVCyJggyY6KQRvAqKFvz1C9z+SMHWfg6tuvxr1T92JH1w78fObnuDNzJ17Q8wL3nOm5a3iMz0B5K9WUsVXKUXrJ92fuxxM++QS0xlpxUvokvHT4pViZXXFw+9zcHAA4Yp0qZPKIrHdMNJR3wyISjXrHds9/Of1L7Onbg7un7sYtR27B35z2NxXEQO1mppXGqJAtAXIj5Z2OxhTy5CRQLALDwwDKh6Y4NITOX/7SWdaMhWj8jxa55qiRBEUhQVYq48Z8cEp6aRTeS7Yn8dTtT8Xf3fx3OH7oeAx3D+Mzd30G/3Xwv7Cnf8+af6sXXj0LCjNefK1dXQ3qa3TcOXYnnvqxp2JhZQE9bT249uxrsadnDw7PHg4RMTR1Sa1OPWBAOSbOtWjsU72ktUo4RjksOUMLY2jMiUUueBZ4NqqlgvC9GxkKhX7xni8iu5jF8495PhbyCyHSEFu60evhnBg/ZtlJKmNCfRvlHds13PiLG5FdzOKFe16IlUz5DNO7Z0yVPAitlfxwGQ06Pnv3Z/GT0Z/gllfcgpZEOS+WUDbXxa/0mgn12lAWnyOVMYU0U7sYOrDPo5G1XvablyG7kMXjr348EvEEiqUiXrXnVXha8DTcO3kvksmku6uJRMIZxpSTKlcYOtKMDX1ZpnUzIwgCPGnrk/CPz/xHbO/YjodmHsI//s8/4vV3vR5/lfwr5/0WCgUn89hikQYGU7s0nQsod6siIqFdwqIoJXzZb16G3GIOJ/zTCW7P3/qbb8WLj3sxstksuru7naMSi8UcYuKrj67hDMuf2MjzH0npTB6Y+K8OBYuFt7S0uGR3em7K5vRZfFTUFATpdBqpVCokvFpaWpq6LJ98/idxyZcuwVH/7ygkYgmcuvVUvOSkl+D2I7dX/RtrgaswI3yjCllb0PEiRSHM9g3uw09e9RNMz03j+ruvx6u/9Wpc+4xr0R3rDjGobSqU5m7qHNQ6BxBSgjZnVuP3jVYuWmuoMrbnwb6okDVHVb2cqEIFFhn5xP98AmfuOBPpeBqj86Ougwy7D7FzkLYeTKfTrvhHf39/qNykD0HZqHHNT6/BWbvOwpaOLTiwfCBETJydnXUxQUKRVFbWc3w4lPGB7AG85j9eg2/8wTfQ2bpakCUWiwGxci49EFbIPsNLES0NyVDYMg1NW+qpcmh0fO7uz+HTd30a//q8f8Xe9F786KEf4a0/eCtaR1qxq3cXCoUC4vG4665FqJeGn1Y8VEPEVury1cVv9vmcdfRZjlS2LbEN207aht//4e/jR4UfYSg/hEwmg2w26+Bf3lHlpRA10vAN4W0N37BPdhQG3+fu/hyuu/M6fPqFn8bxA8fjx0d+jDd84w0Yah/COTvOcc4enT8qYRZEojFt0RMaDb6wTdSjMYU8OAgkEsDYWFiAZjJYGR5GX1+fg+3IaKQHpgnwXDwhMpbzo0LWogn2oTVzWY7tPxY3/Z+bMLs0i9xiDluTW3HB5y/A7r7da/6dKmRCfV
TImm+quZvVPMpGH2Zbog17+vegmC7i8YOPx62HbsWnfvkp/OmOPw15wjxwZIHbIi38fPUwtXay1lvWTkRqZETpJdOQsN6xr0wgLwv3m8Yb+QUWTYlCIReLRTww8wC+e+C7uOoZVzmBRY8hk8k4kg6VGD3ivr4+b2ekqGoQ1zL2Z/bj2w9+G5/8vU86HgcFJxWypoJoERPdz4cjhgwAtx+5HeOz4zj1o6e67xWDIm7efzM+fMuHsfjmRSTifsVsSX5UbkTqeOaZTkUZowq5WYjyjd94I/76aX+NC0+6EEtLSzim+xj87+T/4ob/vQFvH3q7IwyR3KUhD02xIUGUISct6GIVcrNhMcvfoKzrQAdGWkcwujiKzkKng61Z1Uo/W0l1DN+QPMVwQ39/v2vdyX0nqbGZu/rGb7wRlz3tMlx40oUIggAnbTkJ+zP78YEffwAv2vsit6cMjwIIscLpRHZ1dTl01ucQbiSpsTGF3NYGPPGJiH372wie97xVhRyPo+MHP8DMS1+KVCrlLFFCGAAco5fWCB86PQoWWhgYGHDWlCoBa6E3uxndbd3obuvGzPwMvn7f1/Ge333Pmr+vF51eMl+q4GzVn6hIXTr4PqWghKXSUkVaEr1kluqzSplDayvTyydkTRKdrzevPZhRDQpAzSVUeE6hYEJm1otWKK+Zy2OF/Cfv/CQGOwfx9JGnIzuTDfVZ5YtwKolc7IhE5WwbySspZyPHtT+9Flu6tuB3d/0ucplcqEytkhJpFKnw17luxDP3jWce80zc+ad3hr73hzf+IY4bPA5//bS/dsqYZx2onkerKXFUyJQ56gGp4dGsgTS3PId4rJxf29LSgtaWViAGpFIpR3IiD4EeMvkfQFmWcD4arvExlKMylKxSLiwVMLY8hmNXjnVOlfbO5hr5uUrapZJmzXxtqkIjPypSI/ecIxaLoSXRggCrISQ1Mthximz8bDbrUqJoHFjH6uG4A41D1q97HXDxxYg98YmIPfGJaH//+xGbm8PCS16Cjl/FCdgjmEJKSTA2eM6ya5qOQEswaljm6/d9HQEC7BvYh/um78Mbv/FGHDd4HP7wCX/o/X1LHrGXnXEgINxP1RYXaPbCXP7Ny3H23rOxI7UD2YUsrvvpdfjBoR/g47/zcXcpVCFrKVMlVzD9gNBMd/cq3N3d3R0idXEN1kNtJGZVWCrgvun73P8fmHkAd4zegf7OfuxM7wRQWcCCyspHYrEMcH2poovK8yyWirjuZ9fhRXtehFgQC/V/pUKj1Q2UIVFtpVitiXyUZD/fKAUlXPvTa/Hy33g5WuItoTOsXdXi8XioYpTuaaP5uI2OZHsSJ205KfS97tZuDHQOVHxfh03f0buqeek8Q5aAqc+kmTWe+7hz8c7vvRPbk9uxr28fbj1wK675+TU4Z/s5Ds0hrJtIJJxspOHM8IHKFd9dsOzkZkcQBHjL996CZ2x7BvpifXhg6gF86J4PIY449i3uw8GlgyHyqBbqodzR/1sineWl+NJC+beN7Pnff+/vsTO9EyduORE/PvJjfOCWD+Di37g4JJN59wCE0E7yDugkWiLxRvFmdDSukC+4AJiYAP72b5EYHUXw+Mdj9vOfR2xkBK35fIi0ZQlEWqtaFQAD6kA5lsmNjFK4ZhezuPxbl+Ng7iD6O/vxwuNfiL//nb9Ha6K14nd9VHe1Hm1uLFC+PFELsfHZcVx0w0U4UjiCdHsaJw2dhOvPux4nJ0/G6Oho6DN4MbSVpU8hM7RAax2o7JJi2Zx2TbWs67bDt+HMT5zp/v+6/3wdAODiky/Gtedf675v39eyTKu9LBuymsHQKJT37Qe/jYP5g7jgcReEUva4bywnqEQ6NWjUs/cZNBt5yb95/zfxUPYhXPz4i9169AxzLTScNWzwcDJMoxpWIVslbRm01fJgmxkfPPuDePN33oxXf+3VGJ8bx0j3CF523Mtw0dEXYWZyJoTiAKgwkGgIWflT7fw3e470cw4VDu
HSb1+K6YVp9Lb24vju4/HX/X+NmcxMhSOi6al2cF6af6zIl4982cy+f/DsD+It33kL/uyrf4bx2XFsS27DK095Jd70tDchHvjLetr1qP7hXH2ZG49oDJkPK5fLhX9w0UUI/uAPQikU87+KWTImonldhExt9aVEIhEqV6bxZlqtjIsS5/ddGs5PD5dv7s/Z8Rw85/88J7yWJSC3ZNaHcOoW2YP0NrXCFdegZC+1qqgcqxFjapn7+858H3BmeU6sjTszM+NixXa/tdmC/ptWIK1WPhtV3louj/XHtX4zD2k+n1937qf2n4rsa7MV+2vXrqQ5XzU0enOqDPX3uffs4cz0Fz039ew7ldVp/afhgYsfcLWIbUs57i0NSy3BNz8/H4Iei8ViJPBXref99KHTMfOaGSwvLzuGLOevOZc6Zz57wncAQvvYLMxe69x1fOkFX/L+XD1h7XmssVnmWesZ5/fn5uZCZ4Px9VrlTLV5X/HUK/D209/u5sRqelp4gnPSF2Fgq/hshUA2SmDnpFoQxLXmTjn3vjPe5+5QNpvF9PQ0Dhw4gCMLRypq5qsSs4OGvi1t67ur/F2iAj5vv5Z9v+KMK3DFGVe4nwdBgMW5xZAnzLCplZVK/vPpIp0nOQC13gHfefeOoIZx4MCBAMBj5nXgwIHNuW/OfXPuj4HXr8PcH2vz3pz7Iz/3aiMWrKuyVy3jw4cPI5lMPqohqyAIkM/nsW3bNmdhbc5948fm3B+ZsTn3R2bYuT9W5g1szv2RGr7z7hs1KeTNsTk2x+bYHJtjc2zs2PgaeJtjc2yOzbE5NsfmWHdsKuTNsTk2x+bYHJvjUTA2FfLm2BybY3Nsjs3xKBibCnlzbI7NsTk2x+Z4FIxNhbw5Nsfm2BybY3M8CsamQt4cm2NzbI7NsTkeBWNTIW+OzbE5Nsfm2ByPgrGpkDfH5tgcm2NzbI5Hwfj/AbyTpMEt+rhtAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(10, 10, figsize=(6, 6),\n", " subplot_kw={'xticks':[], 'yticks':[]},\n", " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n", "\n", "for i, ax in enumerate(axes.flat):\n", " ax.imshow(images_test[i], cmap='binary', interpolation='gaussian')\n", " color = 'green' if label_pred[i] == label_test[i] else 'red'\n", " ax.text(0.05, 0.05, str(label_pred[i]), transform=ax.transAxes, color=color)" ] }, { "cell_type": "markdown", "metadata": { "id": "x7IIiVymuTRa" }, "source": [ "In this tutorial, we have only scratched the surface of JAX, Flax NNX, and Optax. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are featured in the [Flax MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website." ] }, { "cell_type": "markdown", "metadata": { "id": "5ZfGvXAiy2yr" }, "source": [ "## Key JAX features\n", "\n", "The Flax NNX neural network API demonstrated above takes advantage of a number of [key JAX features](https://jax.readthedocs.io/en/latest/key-concepts.html), designed into the library from the ground up. In particular:\n", "\n", "- **JAX provides a familiar NumPy-like API for array computing.**\n", " This means that when processing data and outputs, we can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html).\n", "\n", "- **JAX provides just-in-time (JIT) compilation.**\n", " This means that we can implement our code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping the code with a simple [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) [transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).\n", "\n", "- **JAX provides automatic differentiation (autodiff).**\n", " This means that when fitting models, `optax` and `flax` can automatically derive gradient functions for fast model optimization, using the [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) [transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html).\n", "\n", "- **JAX provides automatic vectorization.**\n", " While we didn't use this directly in the code above, under the hood Flax takes advantage of [JAX's vectorized map](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) ([`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions into efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.\n", "\n", "We will learn more about these features through brief examples in the following sections."
] }, { "cell_type": "markdown", "metadata": { "id": "ZjneGfjy2Ef1" }, "source": [ "### JAX NumPy interface\n", "\n", "The foundational array computing package in Python is NumPy, and [JAX provides](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jax-vs-numpy) [a matching API](https://jax.readthedocs.io/en/latest/quickstart.html#jax-as-numpy) via the [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) subpackage.\n", "Additionally, [JAX arrays](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) ([`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)) behave much like NumPy arrays in their attributes, and in terms of [indexing](https://numpy.org/doc/stable/user/basics.indexing.html) and [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) semantics.\n", "\n", "In the previous example, we used Flax's built-in `flax.nnx.selu` implementation. We can also implement the SELU (scaled exponential linear unit) activation using JAX's NumPy API as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2u2femxe2EzA", "outputId": "89b9f9b0-5631-405c-f4d8-2198593d0d50" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0. 1.05 2.1 3.1499999 4.2 ]\n" ] } ], "source": [ "import jax.numpy as jnp\n", "\n", "# alpha and lam are rounded versions of the standard SELU constants\n", "# (approximately 1.6733 and 1.0507, respectively).\n", "def selu(x, alpha=1.67, lam=1.05):\n", "    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", "\n", "x = jnp.arange(5.0)\n", "print(selu(x))" ] }, { "cell_type": "markdown", "metadata": { "id": "H9o_a859JLY9" }, "source": [ "Despite the broad similarities, be aware that JAX does have some well-motivated differences from NumPy, which you can read about in [🔪 JAX – The Sharp Bits 🔪](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) on the JAX site." 
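] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the most common of these differences is that JAX arrays are immutable. NumPy-style in-place assignment such as `x[0] = 10.0` raises a `TypeError`. As a minimal sketch, the functional alternative is the `.at` property of `jax.Array`. It returns an updated copy and leaves the original array unchanged:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax.numpy as jnp\n", "\n", "x = jnp.arange(5.0)\n", "\n", "# x[0] = 10.0 would raise a TypeError, because JAX arrays are immutable.\n", "# Instead, .at[...].set(...) returns a new array with the update applied.\n", "y = x.at[0].set(10.0)\n", "\n", "print(x)  # the original array is unchanged\n", "print(y)  # the updated copy, with 10.0 in position 0" 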
] }, { "cell_type": "markdown", "metadata": { "id": "LnDgHRBsJrYL" }, "source": [ "### Just-in-time compilation\n", "\n", "As mentioned before, JAX is built on the [XLA](https://openxla.org/xla) compiler, and allows sequences of operations to be just-in-time (JIT) compiled using the [`jax.jit` transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).\n", "In the neural network example above, we used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which adds special handling for the Flax NNX objects used in neural network training.\n", "\n", "Returning to the previously defined `selu` function, we can create a `jax.jit`-compiled version this way:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-Chp8yCjQaFY" }, "outputs": [], "source": [ "import jax\n", "\n", "selu_jit = jax.jit(selu)" ] }, { "cell_type": "markdown", "metadata": { "id": "zAKNrQbxQgC7" }, "source": [ "`selu_jit` is now a JIT-compiled version of the original function. The actual compilation happens on its first call, for the shapes and dtypes of the inputs it receives. It returns the same result as `selu` up to typical floating-point precision:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uHeJXgKURL6q", "outputId": "dfc5a602-2b28-4863-a852-38f8fe6aaab4" }, "outputs": [ { "data": { "text/plain": [ "Array(True, dtype=bool)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = jnp.arange(1E6)\n", "jnp.allclose(selu(x), selu_jit(x))  # results match" ] }, { "cell_type": "markdown", "metadata": { "id": "WWwD0NmzRLP8" }, "source": [ "We can use IPython's `%timeit` magic to observe the speedup (note the use of [`jax.block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html#jax.block_until_ready), which is needed to account for JAX's [asynchronous 
dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SzU_0NU5Jq_W", "outputId": "dba1ee6b-32f8-4429-a147-b6d4f4e6f0ff" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8.32 ms ± 489 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "%timeit selu(x).block_until_ready()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QOu7wo7UQ07v", "outputId": "bd91aaa2-d367-47e0-eb17-a90658de2d14" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.38 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" ] } ], "source": [ "%timeit selu_jit(x).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "id": "1ST-uLL9JqzB" }, "source": [ "For this computation, running on CPU, `jax.jit` compilation gives roughly a 6x speedup: about 8.3 ms versus 1.4 ms per call in the run shown above. Exact timings depend on your hardware.\n", "JAX's documentation has more discussion of JIT compilation at [Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html)." ] }, { "cell_type": "markdown", "metadata": { "id": "XFWR0tYjLYcj" }, "source": [ "### Automatic differentiation (autodiff)\n", "\n", "For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes the gradient of a JAX function exactly (up to floating-point precision) rather than approximating it numerically. 
In the neural network example, we used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad) function, which has special handling for [`flax.nnx`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/index.html) objects.\n", "\n", "Here's how to compute the gradient of a function with `jax.grad`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JtsPYnKbOtZt", "outputId": "834c31f8-ed1f-46ae-a827-e0b7faa52181" }, "outputs": [ { "data": { "text/plain": [ "Array(0.6450766, dtype=float32)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = jnp.float32(-1.0)\n", "jax.grad(selu)(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "1P-UEh9VO94k" }, "source": [ "We can briefly check with a finite-difference approximation that this is giving the expected value:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1gOc4FyzPDUC", "outputId": "95053e89-048d-4331-b898-079818e23dae" }, "outputs": [ { "data": { "text/plain": [ "Array(0.64539903, dtype=float32)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eps = 1E-3\n", "(selu(x + eps) - selu(x)) / eps" ] }, { "cell_type": "markdown", "metadata": { "id": "pkQW2Hd_bPSd" }, "source": [ "Importantly, the automatic differentiation approach is both more accurate and efficient than computing numerical gradients. JAX's documentation has more discussion of autodiff at [Automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) and [Advanced automatic differentiation](https://jax.readthedocs.io/en/latest/advanced-autodiff.html)." 
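] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a further sanity check, we can compare against the derivative worked out by hand. For $x < 0$, the `selu` function above is $\\lambda (\\alpha e^x - \\alpha)$, so its derivative is $\\lambda \\alpha e^x$. At $x = -1$ this gives $1.05 \\times 1.67 \\times e^{-1} \\approx 0.6451$, in agreement with the `jax.grad` result above. The cell below is a minimal sketch of that comparison. `selu_grad_by_hand` is a hypothetical helper written only for this check:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "# Same definition of selu as earlier in this tutorial.\n", "def selu(x, alpha=1.67, lam=1.05):\n", "    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", "\n", "def selu_grad_by_hand(x, alpha=1.67, lam=1.05):\n", "    # Hand-derived derivative: lam for x > 0, lam * alpha * exp(x) otherwise.\n", "    return lam * jnp.where(x > 0, 1.0, alpha * jnp.exp(x))\n", "\n", "x = jnp.float32(-1.0)\n", "print(jax.grad(selu)(x))     # autodiff\n", "print(selu_grad_by_hand(x))  # hand-derived" 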
] }, { "cell_type": "markdown", "metadata": { "id": "xsKyfRDNbj2y" }, "source": [ "### Automatic vectorization\n", "\n", "In the training loop example earlier, we defined the loss function in terms of a single input data vector of shape `(n_features,)` but trained the model by passing batches of data of shape `(n_samples, n_features)`. Rather than relying on a naive and slow loop over the batch, Flax and Optax internals use JAX's [automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) via the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation to construct a batched version of the kernel automatically.\n", "\n", "Consider a simple loss function that looks like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OuSSCpxzdWw_" }, "outputs": [], "source": [ "def loss(x: jax.Array, x0: jax.Array):\n", "    return jnp.sum((x - x0) ** 2)" ] }, { "cell_type": "markdown", "metadata": { "id": "lOg9IWlPddfE" }, "source": [ "We can evaluate it on a single data vector this way:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sYlEtbxedngb", "outputId": "39030fb7-feee-4da1-ef5d-54cd86ad8dfb" }, "outputs": [ { "data": { "text/plain": [ "Array(2., dtype=float32)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = jnp.arange(3.)\n", "x0 = jnp.ones(3)\n", "loss(x, x0)" ] }, { "cell_type": "markdown", "metadata": { "id": "STit-syzk59F" }, "source": [ "But if we attempt to evaluate it on a batch of vectors, it does not return a batch of 4 losses. Instead, broadcasting subtracts `x0` from every row, and `jnp.sum` then reduces over all 12 elements, producing a single scalar:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1LFQX3zGlCil", "outputId": "a12c4d75-2d94-4341-e9ca-915a33f1278e" }, "outputs": [ { "data": { "text/plain": [ "Array(386., dtype=float32)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batched_x = 
jnp.arange(12).reshape(4, 3)  # batch of 4 vectors\n", "loss(batched_x, x0)  # wrong!" ] }, { "cell_type": "markdown", "metadata": { "id": "Qc3Kwe2HlhpA" }, "source": [ "The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways we can address this:\n", "\n", "1. Rewrite our loss function by hand to operate on batched data; however, as functions become more complicated, this becomes difficult and error-prone.\n", "2. Naively loop over unbatched calls to our original function; this is easy to code, but can be slow because it doesn't take advantage of vectorized compute.\n", "\n", "The [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) [transformation](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) offers a third way: it automatically transforms our original function into a batch-aware version, so we get the speed of option 1 with the ease of option 2:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Y2Sa458OoRVL", "outputId": "d1d8295b-40d3-477a-e5d8-b2d6f28ad803" }, "outputs": [ { "data": { "text/plain": [ "Array([ 2., 29., 110., 245.], dtype=float32)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss_batched = jax.vmap(loss, in_axes=(0, None))  # batch x over axis 0, do not batch x0\n", "loss_batched(batched_x, x0)" ] }, { "cell_type": "markdown", "metadata": { "id": "6A8L1QDFogKd" }, "source": [ "In the neural network example earlier, both Flax and Optax make use of JAX's `vmap` to allow for efficient batched computations over our unbatched loss function.\n", "\n", "JAX's documentation has more discussion of automatic vectorization at [Automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html)." 
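] }, { "cell_type": "markdown", "metadata": {}, "source": [ "JAX transformations also compose. The cell below is a minimal sketch that reuses the same `loss` function. Wrapping `jax.grad` in `jax.vmap` computes one gradient per example in the batch, which is a common building block when per-example gradients are needed. The batch is built from floats here because `jax.grad` requires floating-point inputs. The integer `batched_x` used above would be rejected:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "def loss(x, x0):\n", "    return jnp.sum((x - x0) ** 2)\n", "\n", "batched_x = jnp.arange(12.0).reshape(4, 3)  # float batch of 4 vectors\n", "x0 = jnp.ones(3)\n", "\n", "# jax.grad differentiates with respect to the first argument (x) by default;\n", "# jax.vmap then maps that gradient function over the rows of batched_x.\n", "per_example_grads = jax.vmap(jax.grad(loss), in_axes=(0, None))(batched_x, x0)\n", "print(per_example_grads.shape)  # (4, 3): one gradient per example\n", "\n", "# The analytic gradient of this loss is 2 * (x - x0), row by row.\n", "print(jnp.allclose(per_example_grads, 2 * (batched_x - x0)))" 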
] } ], "metadata": { "colab": { "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }