{ "cells": [ { "cell_type": "markdown", "id": "b69996dc-49af-4a0e-a4e6-36d81b51f2b4", "metadata": {}, "source": [ "# Porting a PyTorch model to JAX\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)\n", "\n", "**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.**\n", "\n", "In this tutorial, we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to PyTorch's `torch.nn` module, so porting PyTorch models is relatively straightforward. To install Flax, run the following command: `pip install -U flax treescope`." ] }, { "cell_type": "code", "execution_count": 1, "id": "NHqB3sNbrygd", "metadata": {}, "outputs": [], "source": [ "!pip install -Uq flax treescope" ] }, { "cell_type": "markdown", "id": "ABCg5TvPr1pm", "metadata": {}, "source": 
[ "Say we have a trained PyTorch computer vision model for image classification that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697).\n", "\n", "First, we set up the model using TorchVision and briefly explore the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model." ] }, { "cell_type": "code", "execution_count": 2, "id": "38504f77-4150-47bd-9cf9-3116fe370746", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from flax import nnx" ] }, { "cell_type": "markdown", "id": "95a364c2-d34e-4820-8a86-f43f59c911bf", "metadata": {}, "source": [ "## MaxViT PyTorch model setup\n", "\n", "### Model's architecture\n", "\n", "The MaxVit model is [implemented in TorchVision](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L568). If we inspect the [forward pass](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L707-L712) of the model, we can see that it contains three high-level parts:\n", "- [stem](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L641-L655): a few convolutions, batch norms and GELU activations.\n", "- [blocks](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L672-L692): a list of MaxViT blocks.\n", "- [classifier](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L696-L703): adaptive average pooling, a few linear layers and a Tanh activation."
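, "\n", "\n", "To make the mapping to Flax more concrete, here is a minimal, hypothetical sketch of the classifier head written with Flax NNX. It is not the ported model built later in this tutorial. It assumes channel-last (NHWC) inputs and assumes the TorchVision head is `AdaptiveAvgPool2d(1)`, `Flatten`, `LayerNorm`, `Linear`, Tanh and a bias-free `Linear` (please check this against the linked source). The class name `ClassifierHeadSketch` and the sizes used below (512 channels, a 7x7 feature map, 1000 classes) are illustrative assumptions." ] }, { "cell_type": "code", "execution_count": null, "id": "classifier-head-sketch", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from flax import nnx\n", "\n", "\n", "class ClassifierHeadSketch(nnx.Module):\n", "    # Hypothetical sketch of a MaxViT-style classifier head, assuming the layout:\n", "    # pool -> flatten -> LayerNorm -> Linear -> Tanh -> Linear (no bias).\n", "    def __init__(self, in_features: int, num_classes: int, *, rngs: nnx.Rngs):\n", "        # epsilon=1e-5 matches the PyTorch LayerNorm default (Flax defaults to 1e-6).\n", "        self.norm = nnx.LayerNorm(in_features, epsilon=1e-5, rngs=rngs)\n", "        self.fc1 = nnx.Linear(in_features, in_features, rngs=rngs)\n", "        self.fc2 = nnx.Linear(in_features, num_classes, use_bias=False, rngs=rngs)\n", "\n", "    def __call__(self, x: jax.Array) -> jax.Array:\n", "        # AdaptiveAvgPool2d(1) + Flatten is a mean over the spatial axes (H, W) in NHWC.\n", "        x = jnp.mean(x, axis=(1, 2))\n", "        x = self.norm(x)\n", "        x = jnp.tanh(self.fc1(x))\n", "        return self.fc2(x)\n", "\n", "\n", "head_sketch = ClassifierHeadSketch(512, 1000, rngs=nnx.Rngs(0))\n", "logits_sketch = head_sketch(jnp.ones((2, 7, 7, 512)))\n", "print(logits_sketch.shape)  # (2, 1000)" ] }, { "cell_type": "markdown", "id": "classifier-head-sketch-note", "metadata": {}, "source": [ "Because the pooling reduces each channel to a single value, the sketch replaces it with a plain `jnp.mean`. Note that `torch.nn.Linear` stores its weight as `(out_features, in_features)`, while `nnx.Linear` stores its kernel as `(in_features, out_features)`, so linear weights need to be transposed when they are ported."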
] }, { "cell_type": "code", "execution_count": 3, "id": "9b1be406-d21c-410d-a2ac-9bd690e5ad60", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n", " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", "Downloading: \"https://download.pytorch.org/models/maxvit_t-bc5ab103.pth\" to /root/.cache/torch/hub/checkpoints/maxvit_t-bc5ab103.pth\n", "100%|██████████| 119M/119M [00:02<00:00, 53.9MB/s]\n" ] } ], "source": [ "from torchvision.models import maxvit_t, MaxVit_T_Weights\n", "\n", "torch_model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)" ] }, { "cell_type": "markdown", "id": "45635b2d-a77a-4368-9ecb-dbb440e647ee", "metadata": {}, "source": [ "We can use `flax.nnx.display` to display the model's architecture:" ] }, { "cell_type": "code", "execution_count": 4, "id": "sZ9x7NpHtBcx", "metadata": {}, "outputs": [], "source": [ "# nnx.display(torch_model)" ] }, { "cell_type": "markdown", "id": "0a36676a-1561-4de0-8e25-38bab90581d0", "metadata": {}, "source": [ "We can see that the model has four MaxViT blocks, and each block contains:\n", "- MaxViT layers: two layers in blocks 0, 1 and 3, and five layers in block 2 (blocks are indexed from 0)" ] }, { "cell_type": "code", "execution_count": 5, "id": "0d5bf6aa-c720-4400-a276-602fff53b413", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(4, [2, 2, 5, 2])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(torch_model.blocks), [len(b.layers) for b in torch_model.blocks]" ] }, { "cell_type": "markdown", "id": "a1d55688-5999-41de-a915-eae8b281eb18", "metadata": {}, "source": [ "A [MaxViT 
layer](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L386) is composed of an [`MBConv`](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L53) block, `window_attention` as a [`PartitionAttentionLayer`](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L282) and `grid_attention` as another `PartitionAttentionLayer`." ] }, { "cell_type": "code", "execution_count": 6, "id": "03ce0555-888a-4086-bb6c-64c36ae60b14", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer'],\n", " ['MBConv', 'PartitionAttentionLayer', 'PartitionAttentionLayer']]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[[mod.__class__.__name__ for mod in maxvit_layer.layers] for b in torch_model.blocks for maxvit_layer in b.layers]" ] }, { "cell_type": "markdown", "id": "d57f8545-43a4-423d-b701-c2e2ca0ebfc1", "metadata": {}, "source": [ "### Inference on data\n", "\n", "Let's check the model on a dummy input and on a real image:" ] }, { "cell_type": "code", "execution_count": 7, "id": "d6c95620-bf50-47e4-b8d6-3a85262941ed", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "torch.Size([2, 1000])\n" ] } ], "source": [ "import torch\n", "\n", "\n", "torch_model.eval()\n", "with torch.inference_mode():\n", "    x = torch.rand(2, 3, 224, 224)\n", "    output = torch_model(x)\n", "\n", "print(output.shape)  # (2, 1000)" ] }, { "cell_type": "markdown", "id": "133bcf21-8a9c-4c27-b551-39b7dfdcfe1c", "metadata": {}, "source": [ "We can download an image of a Pembroke Corgi dog from [TorchVision's gallery](https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true), together with the [ImageNet class index dictionary](https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json):" ] }, { "cell_type": "code", "execution_count": 8, "id": "qC9hpYfNtOEF", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-01-15 21:10:00 URL:https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/dog1.jpg [97422/97422] -> \"dog1.jpg\" [1]\n", "2025-01-15 21:10:01 URL:https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json [35364/35364] -> \"imagenet_class_index.json\" [1]\n" ] } ], "source": [ "%%bash\n", "if [ -f \"dog1.jpg\" ]; then\n", " echo \"dog1.jpg already exists.\"\n", "else\n", " wget -nv \"https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true\" -O dog1.jpg\n", "fi\n", "if [ -f \"imagenet_class_index.json\" ]; then\n", " echo \"imagenet_class_index.json already exists.\"\n", "else\n", " wget -nv \"https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json\" -O imagenet_class_index.json\n", "fi" ] }, { "cell_type": "code", "execution_count": 9, "id": "82be8baf-1292-4766-be34-28c510563d71", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction for the Dog: ['n02113023', 'Pembroke'], score: 0.7800846099853516\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 9, 
"metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAHICAYAAAD0hBWkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9eZxtVXH+/V3D3vuc7r6Xy6QIMgmIBhUcwCEGUBBFkTgAoiKiJoAiGv0YjSb+UOMMCiqKJM5KnCWJRjRGjfoaoxHHGHEgomgEAWW4t/ucvdda9f5RtXd3c+9lMDLFU/m04Z4+fc4e1l5V9dRTTzkREWY2s5nNbGYzuxWav6UPYGYzm9nMZjazzdnMSc1sZjOb2cxutTZzUjOb2cxmNrNbrc2c1MxmNrOZzexWazMnNbOZzWxmM7vV2sxJzWxmM5vZzG61NnNSM5vZzGY2s1utzZzUzGY2s5nN7FZrMyc1s5nNbGYzu9XazEn9nthxxx2Hcw7nHHe7291u6cOZ2e+p/eu//ivOOT7ykY/c5N910UUX4ZzjtNNOu0k+/8orrxyeqZvye37fbeakfo9sm2224b3vfS+vfvWrV72+yy678JKXvOS3/tzvf//7POxhD2NhYYGtttqKJz3pSVx22WWr3nPBBRfw/Oc/n3322Yc1a9ZwhzvcgUc84hF8/etf3+jzfvCDH/Cc5zyHBzzgAYxGI5xzXHTRRZv87g9+8IMcc8wx7LHHHjjnOPDAAzf5vu9973sceeSR3OlOd2Jubo5tttmG/fffn49//OM3yzmde+65PPShD2X77benaRrueMc7csQRR/Cf//mfG73XOce73vWuTR7X9dmBBx7IcccdN/y736j7nxACO+20E49+9KP51re+9Vt9x++rHXfccavW1/z8PO9973s5/fTTb7mD+j2weEsfwMxuPpufn+eYY475nX7mz3/+c/bff3+22GILXvnKV7J+/XpOO+00vvvd7/K1r32Nuq4BeNvb3sbb3/52HvvYx/KMZzyDq666irPPPpv73e9+fOpTn+Lggw8ePvMrX/kKb3zjG/mDP/gD7nrXu17nZnrWWWdx/vnns++++3LFFVds9n0//elPueaaa3jyk5/M9ttvz+LiIh/96Ec5/PDDOfvsszn++ONv0nP67ne/y5Zbbsmzn/1sttlmGy655BLe8Y53sN9++/GVr3yFvffe+7e9BTfIHv/4x/Pwhz+cnDPf//73OeusszjvvPP493//d/bZZ5+b9Lv/r1pVVRxzzDFcdNFFPOc5z7mlD+f/rsnMfi/syU9+suy8886b/N3OO+8sp5xyym/1uU9/+tNlPB7LT3/60+G1z3zmMwLI2WefPbz29a9/Xa655ppVf3v55ZfLtttuK3/4h3+46vUrrrhCrr76ahEROfXUUwWQn/zkJ5v8/p/97GeScxYRkb322ksOOOCAG3zsKSXZe++9Zc8997zJz2lTdskll0iMUU444YRVrwPyzne+8wafx0o74IAD5MlPfvLw75/85CcCyKmnnrrqff/4j/8ogBx//PG/1ff8tvb5z39eAPnwhz/8W/39hg0bbvB7N3fuv609+clP3uT6+l1/z8xW2wzum9lG9q53vQvnHF/+8pd57nOfy7bbbsv8/DyPfvSjN4K8PvrRj3LYYYex0047Da8dfPDB3PnOd+ZDH/rQ8Nq9731vFhYWVv3t1ltvzR/90R/x/e9/f9XrW221FWvWrLlBx7rjjjvi/W+3jEMI7Ljjjlx55ZWrXr8pzmlTdrvb3Y65ubmNvv/msAc/+MEA/OQnPxle++pXv8rDHvYwtthiC+bm5jjggAP48pe/vOrvXvKSl+Cc44c//CHHHHMMW2yxBdtuuy0vfvGLEREuvvhi/viP/5i1a9ey3Xbb8brXvW6T359z5kUvehH
bbbcd8/PzHH744Vx88cWr3nPggQdyt7vdjfPPP5/999+fubk5XvSiFwHwq1/9iqc97Wnc/va3ZzQasffee/Pud7/7es9bRDj++OOp65qPfexjw+vve9/7uPe97814PGarrbbi6KOP3uh4ZnbL2MxJzWyzdvLJJ/Ptb3+bU045hac//el8/OMf55nPfObw+1/84hf86le/4j73uc9Gf7vffvvxzW9+83q/45JLLmGbbbb5nR73ddmGDRu4/PLLufDCCzn99NM577zzOOigg4bf39TndOWVV3LZZZfx3e9+lz/5kz/h6quvXvX9N5ddeOGFgDpVgM997nPsv//+XH311Zxyyim88pWv5Morr+TBD34wX/va1zb6+8c97nGUUnj1q1/Nfe97X17+8pdzxhln8JCHPIQddtiB17zmNey+++4873nP44tf/OJGf/+KV7yCf/qnf+IFL3gBz3rWs/jMZz7DwQcfzNLS0qr3XXHFFRx66KHss88+nHHGGTzoQQ9iaWmJAw88kPe+97088YlP5NRTT2WLLbbguOOO4w1veMNmzznnzHHHHcd73vMezj33XB7zmMcMx3Lssceyxx578PrXv54/+7M/47Of/Sz777//LRJAzOxadkuncjO7eey64L5r2zvf+U4B5OCDD5ZSyvD6c57zHAkhyJVXXikiIv/xH/8hgLznPe/Z6DP+/M//XACZTCab/Z4vfvGL4pyTF7/4xZt9z/XBfSvthsB9J5xwggACiPdejjjiCPn1r389/P6mPqc999xz+P6FhQX5q7/6qwGuvCmsh6Je+tKXymWXXSaXXHKJ/Ou//qvc8573FEA++tGPSilF9thjD3noQx+66n4vLi7KrrvuKg95yEOG10455ZSNYMKUktzxjncU55y8+tWvHl7/zW9+I+PxeBX82MN9O+ywwwDpioh86EMfEkDe8IY3DK8dcMABAshb3/rWVed0xhlnCCDve9/7htfatpX73//+srCwMHzuShiu6zp53OMeJ+PxWD796U8Pf3fRRRdJCEFe8YpXrPqO7373uxJj3Oj167rGM7jvprFZJjWzzdrxxx+Pc2749x/90R+Rc+anP/0pwBD1Nk2z0d+ORqNV77m2/epXv+IJT3gCu+66K89//vN/14e+WfuzP/szPvOZz/Dud7+bQw89lJwzbdsOv7+pz+md73wnn/rUp3jLW97CXe96V5aWlsg5/29P63rtlFNOYdttt2W77bbjwAMP5MILL+Q1r3kNj3nMY/jWt77Fj370I57whCdwxRVXcPnll3P55ZezYcMGDjroIL74xS9SSln1eX/yJ38y/HcIgfvc5z6ICE972tOG19etW8eee+7Jf//3f290PMcee+wqSPeII47gDne4A5/85CdXva9pGp7ylKeseu2Tn/wk2223HY9//OOH16qq4lnPehbr16/nC1/4wqr3t23LkUceySc+8Qk++clPcsghhwy/+9jHPkYphaOOOmo478svv5ztttuOPfbYg89//vM35PLO7Ca0GbtvZpu1lTUZgC233BKA3/zmNwCMx2MAptPpRn87mUxWvWelbdiwgcMOO4xrrrmG/+//+/82quvclHaXu9yFu9zlLoBulIcccgiPfOQj+epXv4pz7iY/p/vf//7Dfx999NHc9a53BbjJe2yOP/54jjzySLz3rFu3jr322mtwxD/60Y8AePKTn7zZv7/qqquG+w8br40tttiC0Wi0Ecy5xRZbbJJ1uccee6z6t3OO3XfffaNWgx122GFgU/b205/+lD322GOjWmR/LfsgqrdXvepVrF+/nvPOO2+jFoUf/ehHiMhGx9NbVVWbfH1mN5/NnNTMNmshhE2+LiIA3OEOdwDgl7/85Ubv+eUvf8lWW221UUbSti2Pecxj+M53vsOnP/3pW7yx+IgjjuCEE07ghz/8IXvuuefNek5bbrklD37wgznnnHNucie1xx57rKLEr7Q+Szr11FM3S0e/ttPd1Nq4vvX
y29imAoIbaw996EP51Kc+xWtf+1oOPPDAISMGPXfnHOedd94mj//mDKBmtmmbOamZ/da2ww47sO22226yefVrX/vaRhteKYVjjz2Wz372s3zoQx/igAMOuJmOdPPWQ3dXXXUVcPOf09LS0vDdt5TttttuAKxdu3azjux3bX321puI8OMf/5h73OMe1/u3O++8M9/5zncopazKpi644ILh9yvtfve7HyeeeCKHHXYYRx55JOeeey4x6ta32267ISLsuuuu3PnOd/7fntbMbgKb1aRm9r+yxz72sXziE59YRdf97Gc/yw9/+EOOPPLIVe89+eST+eAHP8hb3vKWgVl1c9mvfvWrjV7ruo73vOc9jMdj/uAP/mB4/aY4p019/0UXXcRnP/vZTTIJb067973vzW677cZpp53G+vXrN/r9tdsOfhf2nve8h2uuuWb490c+8hF++ctfcuihh17v3z784Q/nkksu4YMf/ODwWkqJN73pTSwsLGwyUDj44IP5wAc+wKc+9Sme9KQnDdnjYx7zGEIIvPSlL90o4xOR62wQn9nNY7NMamb/K3vRi17Ehz/8YR70oAfx7Gc/m/Xr13Pqqady97vffVXB+4wzzuAtb3kL97///Zmbm+N973vfqs959KMfzfz8PKBZzZve9CaAoU/nzDPPZN26daxbt24VDf6LX/ziQHG+7LLL2LBhAy9/+csB2H///dl///0BOOGEE7j66qvZf//92WGHHbjkkks455xzuOCCC3jd6163Cta5Kc7p7ne/OwcddBD77LMPW265JT/60Y94+9vfTtd1G8lUbcqccxxwwAH867/+6/W+98aa9563ve1tHHrooey111485SlPYYcdduAXv/gFn//851m7du1m5aN+W9tqq6144AMfyFOe8hQuvfRSzjjjDHbffXf+9E//9Hr/9vjjj+fss8/muOOO4/zzz2eXXXbhIx/5CF/+8pc544wzNttj96hHPYp3vvOdHHvssaxdu5azzz6b3XbbjZe//OW88IUv5KKLLuJRj3oUa9as4Sc/+Qnnnnsuxx9/PM973vN+p+c+sxtptxyxcGY3p/02FPT/+I//WPV6Tx/+/Oc/v+r1//zP/5RDDjlE5ubmZN26dfLEJz5RLrnkko2+H6Neb+pnJcW8p/Ru6ufa59BTojf1s1JF4/3vf78cfPDBcvvb315ijLLlllvKwQcfLP/wD/+wyWvwuz6nU045Re5zn/vIlltuKTFG2X777eXoo4+W73znO9d9M0TkmmuuEUCOPvro633vte3G0KO/+c1vymMe8xjZeuutpWka2XnnneWoo46Sz372s6vOA5DLLrts1d8++clPlvn5+Y0+84ADDpC99tpr+He/ht7//vfLC1/4Qrnd7W4n4/FYHvGIR6xS+NjU3660Sy+9VJ7ylKfINttsI3Vdy93vfveNVDo2d+5vectbBJDnPe95w2sf/ehH5YEPfKDMz8/L/Py83OUud5GTTjpJfvCDH1z3RbuO75nZ78acyP+iqjmz24wdd9xxfO5zn+Mb3/gGMUbWrVt3Sx/SzG6gffKTn+Swww7j29/+Nne/+91v6cOZmZkYHHjxxRdzr3vdi1NPPXWWdd0ENoP7fo/s4osvZtttt2WvvfbapPr2zG6d9vnPf56jjz565qBuZXbVVVex7bbb3tKH8X/eZpnU74n913/9F//zP/8DKK32fve73y18RDOb2W3bUkqraoR3vvOdN+ofm9n/3mZOamYzm9nMZnartRkFfWYzm9nMZnartZmTmtnMZjazmd1qbeakZjazmc1sZrdamzmpmc1sZjOb2a3WZk7qNm7f/e53OeKII9h5550ZjUbssMMOPOQhDxkUG/6v2L/927/xwAc+kLm5ObbbbrthLMP1WT9leHM/55xzzqr3/8u//AsPetCD2GabbVi3bh377bcf733vezf52W9/+9u5613vymg0Yo899tjsNf/FL37BUUcdxbp161i7di1//Md
/vMnxFVdddRXPf/7z2WOPPRiPx+y888487WlP42c/+9l1nuNDHvIQnHOrlDhW2qWXXsoJJ5zADjvswGg0Ypdddlk1UuPGHuell17KU57yFG53u9sxHo+5173uxYc//OFNfvcHPvAB7nWvezEajdh222152tOexuWXX/5bn3s/GfjaPytFY0E1EZ/2tKdxt7vdjS222IKFhQX23ntv3vCGN9B13ar3XtcaueSSS1a994Mf/CDHHHMMe+yxB865jVTVr23f+MY3OPzww9lqq62Ym5vjbne7G2984xuv829mttpmfVK3Yfu3f/s3HvSgB7HTTjvxp3/6p2y33XZcfPHF/Pu//ztveMMbOPnkk2/pQ/yd2Le+9S0OOugg7nrXu/L617+en//855x22mn86Ec/4rzzzrvOv91///036WROP/10vv3tb6+aivuP//iPPOpRj+L+97//sBl+6EMf4thjj+Xyyy/nOc95zvDes88+mxNPPJHHPvaxPPe5z+VLX/oSz3rWs1hcXOQFL3jB8L7169fzoAc9iKuuuooXvehFVFXF6aefzgEHHMC3vvWtYTJuKYWHPOQh/Nd//RfPeMYzuPOd78yPf/xj3vKWt/DpT3+a73//+5uU+/nYxz7GV77ylc2e/8UXX8wf/uEfAnDiiSeyww478D//8z8bTdu9ocd59dVX88AHPpBLL72UZz/72Wy33XZ86EMf4qijjuKcc87hCU94wvCZZ511Fs94xjM46KCDhvv2hje8ga9//et89atfHRzLb3PuZ5111iopq2srmC8tLfG9732Phz/84eyyyy547/m3f/s3nvOc5/DVr36Vv/u7v9voWr3sZS9j1113XfXatZvezzrrLM4//3z23Xff69X1++d//mce+chHcs973pMXv/jFLCwscOGFF/Lzn//8Ov9uZteyW1DtYmb/S3v4wx8u2267rfzmN7/Z6HeXXnrpzXosGzZsuMk++9BDD5U73OEOctVVVw2v/e3f/q0Aq6as3lBbXFyUNWvWrJo4KyLykIc8RLbffvtVk3e7rpPddttN7nGPe6z6+6233loe8YhHrPr7Jz7xiTI/P79q0u9rXvMaAeRrX/va8Nr3v/99CSHIC1/4wuG1L3/5ywLImWeeueoz3/GOdwggH/vYxzY6j6WlJdlll13kZS97mQBy0kknbfSeQw89VHbddVe5/PLLr/Oa3NDjfO1rXyvAKqmknLPsu+++st1228l0OhURkel0KuvWrZP9999/1bTfj3/84wLIG9/4xt/q3Dcny3RD7ZnPfKYA8stf/nJ4bXMyYJuyn/3sZ8Mk5euaBH3VVVfJ7W9/e3n0ox99k05e/n2wGdx3G7YLL7yQvfbaa5MSR7e73e02eu1973sf++23H3Nzc2y55Zbsv//+/PM///Oq97zlLW8ZBuJtv/32nHTSSVx55ZWr3nPggQdyt7vdjfPPP5/999+fubk5XvSiFwE6LPCUU05h9913p2kadtxxR57//OdvNETw8ssv54ILLmBxcfE6z/Hqq6/mM5/5DMcccwxr164dXj/22GNZWFjgQx/60HX+/abs4x//ONdccw1PfOITN/quLbfcctW8qBgj22yzzaq5Rp///Oe54ooreMYznrHq70866SQ2bNjAP/3TPw2vfeQjH2Hfffdl3333HV67y13uwkEHHbTq2K+++moAbn/726/6zH6+1abmKr32ta+llLJZKZ4LLriA8847jz//8z9n6623ZjKZbAR13djj/NKXvsS2227Lgx/84OE17z1HHXUUl1xyyTAV9z//8z+58soredzjHrdquvNhhx3GwsICH/jAB/5X5y4iXH311Td6VtUuu+wCsNGa7u2aa665zknJO+6440bDFjdlf/d3f8ell17KK17xCrz3bNiwYaPpxjO7YTZzUrdh23nnnTn//PNvkMTRS1/6Up70pCdRVRUve9nLeOlLX8qOO+7I5z73ueE9L3nJSzjppJPYfvvted3
rXsdjH/tYzj77bA455JCNNrcrrriCQw89lH322YczzjiDBz3oQZRSOPzwwznttNN45CMfyZve9CYe9ahHcfrpp/O4xz1u1d+feeaZ3PWud90Idrq2ffe73yWltNE4i7qu2WefffjmN795ved+bTvnnHMYj8cbjdY48MAD+d73vseLX/xifvzjH3PhhRfy13/913z9619fNQ6+/85rH9O9731vvPfD70spfOc739nkKI799tuPCy+8cBhXcZ/73If5+Xle/OIX87nPfY5f/OIXfOELX+D5z38+++6770Zznn72s5/x6le/mte85jWbHQz4L//yL4Bu/gcddBDj8ZjxeMyhhx66agLujTnO6XS6ye+bm5sD4Pzzzx/eB5t2MOPxmG9+85vDpn1jzx3gTne6E1tssQVr1qzhmGOO4dJLL93kNWjblssvv5yLL76Yc889l9NOO42dd96Z3XfffaP3PuhBD2Lt2rXMzc1x+OGHbzTz6sbYv/zLv7B27Vp+8YtfsOeee7KwsMDatWt5+tOfPkx4ntkNtFs6lZvZb2///M//LCEECSHI/e9/f3n+858vn/70p6Vt21Xv+9GPfiTe+01CDz0U86tf/UrqupZDDjlk1XvOPPNMAeQd73jH8NoBBxwggLz1rW9d9Vnvfe97xXsvX/rSl1a9/ta3vlUA+fKXvzy81sM211ZUv7Z9+MMfFkC++MUvbvS7I488Urbbbrvr/Ptr2xVXXCF1XctRRx210e/Wr18vRx11lDjnBiXzubk5+fu///tV7zvppJMkhLDJz992220HtfLLLrtMAHnZy1620fve/OY3CyAXXHDB8NonPvEJucMd7rBKSf2hD32oXHPNNRv9/RFHHCEPeMADhn+zCbjvWc96lgCy9dZby8Me9jD54Ac/KKeeeqosLCzIbrvtNkC0N+Y4Tz75ZPHey0UXXbTqfUcffbQA8sxnPnP4TOecPO1pT1v1vgsuuGA4t5UQ5A099zPOOEOe+cxnyjnnnCMf+chH5NnPfrbEGGWPPfZYBQf39v73v3/VZ97nPvfZSHn+gx/8oBx33HHy7ne/W84991z5q7/6K5mbm5NtttlGfvazn230mb1dF9x3j3vcQ+bm5mRubk5OPvlk+ehHPyonn3zyb61m//tsMyd1G7evfe1r8uhHP1rm5uaGB3HbbbddNYLi1FNPFUC++c1vbvZz/u7v/k4A+eQnP7nq9el0KmvXrpXHPvaxw2sHHHCANE0z1B96O/zww2WvvfaSyy67bNXPD3/4QwHk5S9/+Y0+v/e85z0CyFe/+tWNfvekJz1Jtthiixv1eWeffbYAmxzR0XWd/NVf/ZUceeSR8v73v1/e9773yf777y8LCwvyla98ZXjfU5/6VBmPx5v8/B133FH++I//WES0fgHIa17zmo3e9/a3v32je/LVr35VHv7wh8srXvEK+fu//3t5yUteInNzc3LEEUes+tvPfe5z4pxbVT/alJN66lOfKoDstddeqwKPfuP+27/92xt9nN/+9relqirZb7/95Mtf/rL8+Mc/lle+8pXSNI0Aq5zS4x73OIkxymmnnSYXXnihfPGLX5S9995bqqoSQC6++OIbfe6bsnPOOUcAedWrXrXR7y655BL5zGc+Ix/+8IflxBNPlPvf//6r7uXm7Etf+pI45+SEE07Y7Huuy0nd6U53EkBOPPHEVa+fcMIJAsgPf/jD6z2GmanNnNT/EZtOp/K1r31NXvjCF8poNJKqquR73/ueiIiceOKJ4r3fyKmstFe96lUCyIUXXrjR7/bZZx+5z33uM/z7gAMOkDvd6U4bve+ud73rdc5XetaznnWjz+t3nUntv//+stVWW22UbYroBrL33nuv2tDbtpU99thD9ttvv+G1myKTuvDCC2Vubk4+8pGPrHrfu971rlXBQ9d1cre73U2OPfbYVe/blJM66aSTBJCXvvSlq15PKUmMUZ7ylKfc6OMU0Xuy9dZbD/d1u+22k7P
OOksAefaznz2878orr5TDDz981Ro45phj5DGPeYwAA+Hnhp77ddl2220nBx100PW+7xWveIUsLCysIk5szu53v/vJbrvtttnfX5eT2muvvQSQL3zhC6te/8IXviCAvPvd777e75+Z2qwm9X/E6rpm33335ZWvfCVnnXUWXddttnfld2GbqjWUUrj73e/OZz7zmU3+XJtocEOsL57/8pe/3Oh3v/zlL9l+++1v8Gf97Gc/40tf+hJHHnkkVVWt+l3btrz97W/nEY94xKrCeFVVHHrooXz961+nbdvhmHLOG42Eb9uWK664YjimrbbaiqZpNnvswPDed73rXUwmEw477LBV7zv88MOB5QnF73nPe/jBD37ACSecwEUXXTT8gBb9L7roooGM0n/2tQkJIQS23nprfvOb39zo4wQ44ogjBhr7V77yFX76059ypzvdCVAl8N622GIL/uEf/oGf/vSnfOELX+Ciiy7ive99L7/85S/ZdtttB8LPDT3367Idd9yRX//619f7viOOOIL169fzD//wD7+zz9yUbe7a94Sm/trP7Ppt5qT+D1pfAO83mN12241SCv/1X/+12b/ZeeedAfjBD36w6vW2bfnJT34y/P66bLfdduPXv/41Bx10EAcffPBGP3vuueeNPpe73e1uxBj5+te/vtFxfetb32Kfffa5wZ/1/ve/HxHZiNUHSgRJKW2S2dV1HaWU4Xf9d177mL7+9a9TShl+773n7ne/+0bvA/jqV7/Kne50p6H/59JLL0VENvr+nrCSUgLU0XZdxx/+4R+y6667Dj+gDmzXXXcdGJv3vve9AW3SXWk9maCfhXRjjrO3Pii63/3uR13XA0ljUySHnXbaif3335+dd96ZK6+8kvPPP3/V+27ouW/ORISLLrroBs12WlpaArR5+Prsv//7v3/reVGbu/b9uJzZHKobYbdsIjez/4197nOfW9WD0lvf8/L6179eRG4cceJhD3vYqs/sR21fmzixqbHePTxz9tlnb/S7xcVFWb9+/fDvyy67TL7//e/foP6qhz3sYXKHO9xBrr766uG1t73tbQLIeeedN7y2YcMG+f73v7/ZHpp73OMestNOO23ymqWUZN26dXLnO995FSx6zTXXyB3veEe5y13usupcttpqKznssMNWfcYxxxwjc3NzcsUVVwyvvfrVr96oB+eCCy6QEIK84AUvGF477bTTBNhoBPoZZ5whgHzgAx8QEe1dOvfcczf6AeThD3+4nHvuufI///M/IiIymUzkdre7ndzpTneSpaWl4TP7utyHPvShG32cm7If/vCHsmbNmo2ux6ash55X1tNu6LmL6Dq9tvWQZL/eRXR9beo+931SK/u8NvWZ//RP/3S9EPV1wX3f+MY3BJAnPOEJq15//OMfLzFG+cUvfrHZz53Zaps5qduw7bXXXrLrrrvKc5/7XPmbv/kbOfPMM+UJT3iChBBkl112WdXk++IXv1gAecADHiCnnXaavOlNb5Jjjz1W/uIv/mJ4T8+4O+SQQ+TMM8+Uk08+WUIIsu+++66q4WzOSeWc5eEPf7g45+Too4+WN73pTXLGGWfIiSeeKFtttdWqDfCGsvtERM4//3xpmkbuec97yllnnSV/+Zd/KaPRSA455JBV7/v85z8vgJxyyikbfcZ3v/tdAVad77Xt5S9/uQByz3veU04//XQ57bTThjrb+973vlXv7TfGI444Qv72b/9Wjj32WAHkFa94xar3XX311bLbbrvJ7W53O3nta18rp59+uuy4446y/fbbr9ocL7/8ctluu+2krmt51rOeJWeffbaccMIJEkKQvfba6zrriSKbrkmJiLz73e8WQPbdd1954xvfKM973vOkqir5oz/6I0kp3ejjFNHa4//7f/9P3va2t8lf/uVfylZbbSU777yz/PznP1/1vle96lXyxCc+Ud74xjfKW97yFjnkkEM2SaC5Mec+Ho/luOOOk9e97nXy5je/WR7/+MeLc07
22WefVQHP6aefLnvuuae84AUvkLPPPltOO+00echDHiKAPPKRj1z1/bvvvrsceeSR8prXvEbe+ta3yvHHHy8xRtlxxx3lkksuWfXeL3zhC/LXf/3X8td//ddyu9vdTnbZZZfh39euP/XElaOOOkre/OY3y5FHHinAqubomV2/zZzUbdjOO+88eepTnyp3uctdZGFhQeq6lt13311OPvnkTSpOvOMd75B73vOe0jSNbLnllnLAAQfIZz7zmVXvOfPMM+Uud7mLVFUlt7/97eXpT3/6RooWm3NSIko0eM1rXiN77bXX8D33vve95aUvfekqivCNcVIiyrZ6wAMeIKPRSLbddls56aSTVmVWItftpP7iL/5CgI3ox9e2c845R/bbbz9Zt26djMdjue9977tRQb+3v/mbv5E999xT6rqW3XbbTU4//fRNRu8XX3yxHHHEEbJ27VpZWFiQww47TH70ox9t9L6f//zn8tSnPlV23XVXqeta7nCHO8if/umf3iB1hc05KRFl8+29997SNI3c/va3l2c+85kbXbsbc5xHH3207LjjjlLXtWy//fZy4oknbnK9feITn5D99ttP1qxZI3Nzc3K/+91vVfb225z7n/zJn8gf/MEfyJo1a6SqKtl9993lBS94wUbn8x//8R9y5JFHyk477SRN08j8/Lzc6173kte//vXSdd2q9/7lX/6l7LPPPrLFFltIVVWy0047ydOf/vSNHJTI8rrd1M+1113btvKSl7xEdt555+FYTz/99E2e/8w2b7PJvDOb2cxmNrNbrc2IEzOb2cxmNrNbrc2c1MxmNrOZzexWazMnNbOZzWxmM7vV2i3mpN785jezyy67MBqNuO9973u9QqMzm9nMZjaz3z+7RZzUBz/4QZ773Odyyimn8I1vfIO9996bhz70oRt18M9sZjOb2cx+v+0WYffd9773Zd999+XMM88EVE5nxx135OSTT+Yv/uIvbu7DmdnMZjazmd1K7WYfH9+2Leeffz4vfOELh9e89xx88MGbHYM9nU5XDc0rpfDrX/+arbfeetVAtZnNbGYzm9ltw0SEa665hu233/46B0ne7E7q8ssvJ+e8kfDi7W9/ey644IJN/s2rXvUqXvrSl94chzezmc1sZjO7Ge3iiy/mjne842Z/f7M7qd/GXvjCF/Lc5z53+PdVV13FTjvtxF3vdW+apmI8rqibyNzCmG22WsO6LbdEgOm0I+WM8546VnQ58etLLufKS67g6iuuoesyo/maLbfdivEWDfNbzDOem6Oqa4LzZBGC9/gQmKaOdjpl/foNXHH5VbSTxHTSkrpE8B6HkNrEZKllsjihaxPjULFmbsRoLtLM1fhRhasjBEdVRaLzeO9pqpoqRDoRci50XUeXErkUwBE8NHVNVXvwjuIcuQheCtEHmqamjhV1XRG8Bx/wziGItcMLDoe4RJZEyR2C4EMg+ora17S50OVEzokiHa4UCuC8I0uhCgEvkLtEzoXSi34W8D5S+UhoRsQ64gI4L+TcgRO889RxhBeHOA8OhAIIzoP3DpFMzgkcODylAOL1v3MhlYxQECngBIfHOY8HnJ1jLpmU9fhyKkgRnF2HrmvJOVPHSOWhqSKjpiH4CCGQc9BrHSoIAZyqhXtfyF1mMulolzr9XAEvDhcC3gVKKuS2pXSdidFmveaVx48qQlMxGs9T17V+ZvAIgBRC8MQQqKqoZ+I9o9Faxs06xuMFvI9QEiV1pNyxtLREaTucc3TdhGm3RIw1zWgOEehSYanraNspk8UNIJkYPFXt8CHhfKGqhOAdPnhCjDgHKSdSK0wmE9YvLtFNOz1P5+04K5xz5Cyqr4AHPDEGkhRKyfgY6NpW160I49Ec42aME0cdK/CKgnQp0eWka9B7vK/sWntE9LUYgz5XnmFEfEoJKZnJZMLV1yyxuDhhcXFCkcLcqKFpGqpYgXN0qWNpMiGnhMfhBaQIXe5Ymk4pRaiaETFGxnNjXKwZj0e
sW7OO4AMISCl0paVLU6Cz6z2hyx12EaijPn9VVTMaNYTgmU5brlrawPoNU66+6hquuWYD0QWaumHNaMzCwhxzc3PUTUO0c05Zn4WcYdp1jOYi82sCzcjr/lYFfPQE53HOnpkCbZf1eURAPBFPKYKIEGMkVo4QHNHuc1c6Ui445/Bef0JwiAhtKnRtIqUpOP2MNie6LpE6Xfv0t5+ACKSUSUmf5RA9TR2pKo9QSNLSdUI7EaZLwuI1iZQKLnmC1+P8yNs/upF48bXtZndS22yzDSGEjcY9X3rppWy33Xab/Jum0QV4bavrmvFczcJCQ91UzC/MscW6daxdtwU5Z1xYok1JHYGPuM7rAszgsjqgqvKMRpF1W65lfu08zdyI0WiMw5NLweF085wIJWdCqKiqmtwBdOCc7pI4nHME7xhVFeOqZlxVjKrIeH7EaE1DnKshelJOukH4QPCBKlbEEJAi4HX7LgCpDIstRFukMeCCbqiUgkOIVSQET11XxBjJpeC9twc/4LzHeUcqLTl1pKIPmXceR0CyqNNwQqg80VW6+QePd56SRZ1eFtrS4n1BvAdRh1KHhqaeIzY1LnoIgpDIucJ73eiqUBNcoDhzmr6AK4hkoJCLJ0hECoh4XAGKw+EAT3TqyAQPFJzzOBxViLrpO31gQurIuSBeN07nwDlHqYI6DxEqHHOxonYREY8Uj/MR5ytcCAgOHz2xCgSBtnTE6CiVw3tR5yiOWEV1Um1Hcp6EHrcED06g8oSmITY1MUaqGCkidF1C0OMKwdE0DaPRmBAq6vEcCwvbsGZuHaN6AYCcO1KZ0rUTrr76KpbWbyA4iCNw0wLe4SuhS4kiGcmJUDvWNWuoq0hdB2INPiRymRJCh/cOt2KD6lrH0jRTQmDkAiECeJz3+KLXW4qj6wqlgHeBGBtECsECGRFBagii66V3UghUISJOKEVwIeGTXgPvPT7G5bWbdaR8DOaknEPIw3PQtlOcD4QQCDESYyBnff+4qfCx0sA0O127XjdEbyuJ7MgOcinEoM+VAONmxJbrtmRhbgGPR0qhlESboSrerkVFYUTKLblknPOM60gIFTFWNE2FD9B2Ab9UGC14mrnMaB5ciQQija913URPCLrOwOGzJ0smlQwu4GNFiJEQhaYZUTeR6NTZhBCIwYN4pl1Hl9RRSREokJOu/eAjo6amqitC0OCwkUIWDdxj8HgH4jTAq1IijTIw0memZKZdR9smDZzbCVLAe91XskDJ2YJpIZrDjlHXf6FlOk3kzjFZgvG4Y8P6KWlaCD5Qsrq76yvZ3OxOqq5r7n3ve/PZz36WRz3qUYBGV5/97Gd55jOfeaM+K1aB8bhh3DTUo4o1C/M0o1pvCOCDhyx4r7F2TonUZXLKSNaAua4jc/Mj1qxZw/zaeWJdU1UVKWXLQ4ACRTIlZyiavWgk4ikCJRfNYJyjCp7RuGLcNDQ+Erxnbs0c9XyNH0fEO7rkEdvgvQ9k9GaXohuGc043P4tmcsl02UHSzKaOUW+s8/Z+LBrLFAqlFEKIeHvIQ4/3Fl3YZI9IIYmdnBRd5BRitI3fg3pfh3hAPBnBuULbLlFKwokQXIOLNT40GvV7jwtCEYcvDofgxOHRhQ1CcYJzK/M83RCkuOEaSBF1HuLJWSgDv0f61Mn8tEaFzts2JOrYnFOHIBRzVBp95pTxeEJxhOzIYtmi101NEkhwiMC0TcQSSBlyglygiODVjeH98n0A0Y0tFXLJhOhwESQLqcsgLWQhlUIuRXPb4DWTk0B0ER9q6mqB0WiBZjRPU4+hCKkESpcpnZBLR1emZCdkaUmu1WtaWrJksgd8oapqFuoRMQScF+qRo2oiXdLrFrzTLNY5pBSQTJcyde0R0cjdOUURpBRKFlKCDOROyHYdcim6VrNQSgFxRK8ZGnhyEb0nHsSelZQLybJc/TEn4AIh6JoW2xcY7rp+frENUYoguQDe1qo+K5ITucuULiFdhlwoQe8XwRFCoKo
rYrHPy0I3Feq1nnHd4C3mdF7XvscRHICnriOxrhAqUk6IOCqvzisEdbT4RAJGLoLvECKeBpcqcudwRZ+9XNQhSQacQ3KmSx2TNtO2GXEZXER8pG6TvscQG4cjo891tGfOS6FIoSuJUiBnZ9mWPvMl6TXEOxxBnxEs0BTBWQZbOUeRgjh7hxd8KLisWT4OfIi69rPukaHSfaaKFVVVEYMnBE8qAScdJThK7shNoWsDuUskyZqZ3ZB9/oa7hN+dPfe5z+XJT34y97nPfdhvv/0444wz2LBhA095ylNu1Of0MEmIjioGRqOGqq6HTUsXtiBeKFJIKamjEXVgdRNZu8UCa9euYTw3Yn5ujmqkw/y862hpKcU2R0v/cza4ou0MDvMKQ0jBSyF4YRQj6xZGVL7CBcfCFvO42lOiwl2lCG3b2cYZ1CEKuvzsofVBoOhm2ZaC5GzRbcCeSXJZhotiCASLYDxOH1zbnEWFhCmlkFImJ4X6qkofrlISrusopSN4Rx08mrAL3gV8jJTsaEuiswwzpwxFcFFwrlI4yCuUhRckC4VEwBOcG6IvKOD0fIcNB0GKU9ii6INjv4Ji91BQmM8eHucdHg0M+vuDnaeIOsbgFUJyHroukVOidImMp5SgGQ8eV2HQoSfjBmc6nUzIpdYMrpcR7SHKoAFHTgpT9s4n5ULqOnCBmD1dm5CUKbHQ+o5SRIMSEWJVMTeew/tACDXBN1TVHM1ojmY0oo41JRdKynTTpFBTmVLIiGSSTBGfcAHEZ3wQAgoHVTEQAzR1wAdHqIS6hljXBvdotuNxej+Cp2kiPkBVBdug1El1KRms4whVYbKkayingoDCqzmDga9DkJULHQnnvKIDOZNyJhedG6UIgW5sfUDlXH/v1QEuB9nFAhddCyJFM1KxgMFgsJKLOSdBukwpSYMzLwQXCFWg9hFEn0H9As0CRDJSvKK9weHFg0RyEaoQCRFindVxI5QCJYk515oQPRlFMUZ1AImUEiELpQ10CLl1w7NYSgHnzCEKpeThGZUlQVyHrxt8CBQplKpCxJOyEKOnrmocmikG58guU4rCot65ATbPGo1SJOODZi4eyLIcKBTBrqsGUmJBZJFMkYSQ8FVEMggeKeBwChdXEeedZqbeE3AEF0iSDElxhJCoqohur56u61jKy2S469znb4xT+F3Z4x73OC677DL+3//7f1xyySXss88+fOpTn9qITHF9Vo80vfbRUdWB6B0haMSsDwVaYyAw7aakpSm0HUjCR6EZ1axds8B4bsTc/Byj8ZimGgEwQTfNrp0wXepoFzO5ha4ttF0iSUFyIRJwrpBTh6cwrmu2mBszPzfCE3B1oERPCUEzuqJ/h+giSKmj9M7Ja2TtnCPGAN6zOFlimgspd0TxjJxlGljGhcMjaPKokY14hfxcSfjiCL6mmybariOlTIyROmp9JARPdlmhkdRHZwGnELc6KdfQpUzpOugEXwJdCkTvqUJDiB4XnMKvozFFCq2fUtlm4HH4ECCIOrahrOFoW8tQceTiENENDsCh2aSPjtQVnESydHj7KMh4Kpx4JIOTSOUdiURXOkJ04BWuFRFKB36qf+uDOtgQFTZyPmp+ZHWT1CXyNNPmRVttoo5ZHME5QqP7W855cIApa02szZl2Wgi5QNCMaRK0BiMamRCrQBNGjOsRc/UammYLfNVQN2PG9TxNVRuKrBt81y0yba8ipUWEqUKlZUoVCr7K+Eqz8lKKOhlXaGJQWLCCqvHgMk5QOKYPXETrIEUiPjiFlCqtK4SgiEEqlTrXBNeEKQgsLiayZeCI1/vjNKv3RDL677wC0sk5D9lRDAFHIVSKjAYHSKbYBi79rmmm7C8PLiDiIDtSLhSBynmCeHz2Pfah8FTOZAsgq6bGBaGuAlXUjXRuvmbStfq5viPlCSEIxUW8c6SSSLklBE9VK/QWK08MmrHnDK52VpNxlpVAHSralHAIITrCKEBUdEIhaw2cLeSjCoHgHI2
vyLSasedEah3dpCXVFV3jFdFJhSABXzxdLloq8BEpijbEWOGo6LpEKULXGsTqNeHPJQFicKkFlWiQJlkD5pQdLkApmZSywnJF11SIEcRRkgawVdQaboga1rrgtT7mtQxR1RWpyzRNg6SOkrVWLThiFW7QPn+LESee+cxn3mh479pW13qhnfOWxtcIhVy02Kc1j0Jxji4lI1Jo4dt5T6xrYlNTNZqBxRipq4gArU1pXVqasH79Iu2SFVujZ25UU3JmYjUjDHoIMTBuxszPzTNqao1KXB+xQLb3aY3ID1masyiyjoEqRoUAszBNiSJCLomcMouu4J3gRlrwhT5LWsZ1cy7k0mrmJ6Kwphe6aUdKiRgD0QeqEKkt68ziEEkU7y3C1ppLpiASUEjFEWOl2VtqCWitKYQeIw/UldYVFFYoJK8QqxNRmGKIlBXuEMxh9ddB1KFWcYyrgjqpUrSALROKdLQpIVlhC+81oit9VGsRIEUIBu85y6ZzV+imHbEVkhQ6D0ECMWhGJqLrImVoc2HadkynU6uRlRVF5gARUjHHjn5eLpkkhZbCUu6QLITi8cHjYgDJA4xVNyPG4zVssWYtC+N5mmZMDBXOR2LU2ut0MlUYFGjThDYvMW03kPISWP3QB6GuPKFxuhkGrbHl7HDiiF4IXggRlD6i9T+kh1ShFDGoDkDrJCEErWPFSAjOon6hS4XUCV3nmHZCXuxISe9PajutswZHzm5wRj1i12e4oISUHoLW695D3Vg2oWmrw5mfKua8NAhIaZl0kVMZSC4hBL1XTvS5L30AhEVFmkFUVTAigadqKooYV8ZlskwhZ6KLQB5gfeethuT0OBDNuLU2mvQ4SxmyD637OGJVUYun+GD1n0jX5aGU7SxD18dXiM4RnEDQ7M45rSfr8+20ZmWWXcZnBz3qYIQWZ8iFWJ0wpU6RBy+aFUkmxkohPUMHihRySrrX5AJZwGqBPQpSEEQckvXZCs5bzduOTxwM5QqHC5U+1wV8CLiQEa+17eACTXXD9vnbBLtvczYaK2HCeSEVoUsdsVjBwhViFZh2HbnNbNiwxPoNi0wmU9qkxT4XlaHlQiB4LYJLEcN0M9Npx/r1EyaTFrGNL3pPFSNV8Ewka7E62WKxrw4E6lCTJNOhkYikggtOoTLbFER04QUMzw2OGLXAm3LWTSUokywXoU0dixOFB0bosWiEKeaUdKNuu6RYe5sItJp6J4U+3GhMUzfUVUVdVZa1jSiSSd2E1LWKSePVobpAVY1omkCMNWFpAzl1lLYjhsi4rhnVym6KMVp0rhF7YBl21ZoB9JBQyaKF4lSG7MmHSB3naOp5QBlWuVOih/hIEYd4wXnNzrRWiAaxdt/EWQ3Egbc6hcJlLe1SC1OofMTXXuGcotcpu0LOjmnOLHaJyTTRpaTfUbIyl4wYIq7fE8TK8VoPbEsmUchOnaULDhcEFwuhiQZVBkZNzdqFBRbmF6hireeZM85li147utLphuKhy0u06RpSWcS7hK81mCgFqsoRa0+Inhgr+yzNYCS3pAJitTdQZ4vXWl9P0VlZCwrBa10hBiWOBI8UZcmKZKqqUtiw0o1uOllkOmnJKVPXFQ6FeaCHj7D/lgHS894TrWaqm6jWd1zQGq/Yhqf+TW+wEz1a3QyXYd2UOkIQpNS6XoPW26JXVmwZalxCKYkknlEIhKgbqbLrHJDJZYqTStcNggueKqqDyiUpOzU7RIoShqoAYp8vWetGooGi8xB9VMKG81BHRtmTGmG61FKSnqN3EHpnIRpYe6+Z5tx8QzXyRAsYfHCKRDir8VFIJSnEKPpcSSla2wWF3bHnwxV1niXpGiuCT8kCZAscszq1QrZyQaJQBifknZKKxDkw2BALBvt7ab7U6sB+CEyc10caV3CgpJh4wwSPbuNOyiJ3Uacz6VpiCsRKay7OZ3LOLC62LK5fYsP6JSYbWrpWcVbntQyJ7muALrCUEl2baKcd00lL1yZ1CCHgsu2
MopispEzuMi5rUTm1Cv+0XaKTRA7KtsklUY8bslFwQR8SYCBLOCvA55LIpcMFoWmUdt4Zmw1XtAaSHU4CwRhqKSWKZWtZhNQluqKbbHRBYTcPJSWqEJgbjamqWmtSIpTcQrCo2mpkzjnqqmE8nhugQcjkbop0HdF56hgViw7L0akr4FAYqJCXsWvv1Bkmw847oW0TzmsW5n2Nc8qUiqHGO08XW1pnDEtRQoaz6NyJU5q5V5hQawoaReOdFotLobSFMi2UNpOmAnXQgriGz3Ql03aZNheWUmKpzXRdhyhgSynFaPJ9KV1dkytadJYiCn1REC9UjdVJm0iIChXVo8aOE6qgwVWMUTe2kpHSKUW6Xc+0FpwUClNKyky7Rbr2apBWoc4VpBYAV5yRU/SlQjF2WrYMiIFAE6toma3DuWAbnthGDSF4y3SCnr0AEo0UAlL07zye3GVSq1l+TxxJXSH3kHS/Q9oajzFSVdXgNHorpeB8wYmzGo1SUyQrScY7I6uI1qhCCMPf5ZwpLpCLbsAhBK0nVR6fHAFPKplCphOIxVNcM1Cvg9f6UpcySTq8aL2wmFPFiCdSMqUoBC9FHaBtv0NG71xGXMCLV0KOOagYHFWo8ESmnWazaVpIyZyt0zpWdB4XHSEKeKjnHKFxxAg+KkwbnNXfpAzXQAwe7ckQYmxEhssv6AGq83Q4UurUJ9q+IxbQOYw4IeB8X5dm+Ld3HoIfasTKCvYW5PSZsf44RfHtnmGlGWE0VyHi6Ka3YuLE78qqOgwPQy6F6bQ17LjCBbQGk4XJtGU6mdAuTZksTpCkGG0MASeOlBJt2xFiAHFMJxOuWb/I4uKE6VThJR8V1E1FSFmzIQrKELPMq5NCFkEcVreyYrLTTUSZZ8piKkVQ1Mqi/ZwpUXFnjdSUhCFOF050uqFpERK80whY61pGRLC0PZVM2yVdJLa4gg80VaVZT9BNOnotkOcsiChbLcYaROEL7z11PTL6v5CSU7JK0B6vYLBgX9sWxHrGQEqyhwdiFamrWlGJ1CHSQQlIDuTOWEjeE+KIKs5R1/PUldJgfZhSiiN0HbkryhZDi/5BdNPXswFKUUps6hAj3eXU0VmgkbtCyObE0Eg0S6bNmQ2dU7ptyZZ9GLTqtV7g+2CmCLmAK1lJDV0md0kzbe+oK5SlVxv1u/I0o5p63KD9SIIjavSPblCFpMFS6ZgstTi3RKw8zidKSXTTDVCWCF4o2VFyNpgp4AqU7PDBGVMMhX+sRqFrQ5Tt6QJkpVxrVqP3vGCwtaWIxViWfU0pW21jOs10bSG3QjfRIK4UIYa4fE0tci7FPpM+05ChxgVKcCjOyAN9L1Tp2Wd6bE5jCKIHpOBEnztn7Q855+Xvk541qp8XKk9dKlyVURqrsk4leAoOgtc+omBboOvIA1FAkQsJgcqo+s6zXHOz1ox+M/beI2454NW6pTpShcSFOlbae1mEpo54CuK0XYKgGZGyPT2h0fpVrBy+cvgoVg/TzM+LOkVQSL6UMmT23oGL6lVKUSfvKIjLOIQBG3Q9o5chw2XIIfsgRinqy783gp9zBsla6UJxF2XTGnXf4REn+OCpgqNET11BbjzJaVCJXw5Urstu005qNKpJSYwVptFv2yZi3RHEWHPoBW0nLe2kZTppFV6rHVVV0dQKjKbUMZ1qnWhpaYmlpSUmSxO6VnuKcg7DQ59SpnQZSUYKwCu0Z42u6pgw5pQt2Kp/8LTw205b7esJ1s+RNbIKoY9GBCcJVxJV1CjY4GNCQHFrKXRtJhkjTZzCTsoiykr7Lp4chGCQYIhBHWrbUcVaU/X+LEJlrCDFs2MINKOGqg5Ga/dDVEkpuqeVYpi6OgwBzSqsJ0kZmA3eol0p4AkonOcBzQShpqnn2WrL27Fm7TpEnNaz3BI5Q9u1ZAox1ebE3ZA1iShcooQHPbdcAFcoOWnxedKRuzw4X91DhJQTG6YTru702onTvp0QlBji0L3T4Sgp02J1HC9
IgjxNdG2LlKKQW4hUVaAZRao6qLNqKkJUSC4VyMWp80lTa1oFnzMlKVmj6ypiHahqLbR37RRXOiof6Lxm5g6IQ9OpINlRvLPivG442Vh5iLfMqMK5iHMatHhbi94IDZIZaPQaUGj7w6RTJl07TUwXE9PFzHSxwxlMrY47aYOwRyHw3kn16LtzLNdNM61hgYISn3AWajhlqumJqPP1HigKCXpJQ6N6XytLKSm5A9tmHfjoqVxNcIWQrEDvHaEKA8MW12fTEBG9/lmJBUU80WWrVSkUaRU1QozW5KybfbDMr2Snmb7zuKL7jjagQ8odSVD4PEZzHFCkKAJhDdBFdOMPlbIMqyYMwY438oPH2KW2LhGnbDyxvslgkLrVeoWeQYi1BGSjli+/p2fC6n6J7ks9POd7DE/XjWJ6erV1/7K6pTE3i2W8ONHyhTUpxwhV9OTUgSh56YbYbdpJBR8oXpi0CeesgC2ibBunGH3wCVy2lF8fAieCK0J0qjxQeY+kwqRMaUuhnRYmrTCdtJSkRc5p24LTBzDnQpu6oTm05GQBrCNNO/J0CqHBx/7BFOt4L6SsNaySFSpJQKmFUkHMLZ1HoUhXcF5oYqVNxJU1wgpWJxDIPZyjDZZYlOMsQ/LaCEMSIbqAeO356XJmqZ3gXGQ8mtcI1Ae8FyiJoDiWdqvX4ELBuQQ+03ZLpDyhtFPwkS4nGiCEiIvRoABl1VUxaCTpXR+84X20hy0rXFCEJJk5HxiN11CFBlccIdaAw4VFCp3h9BVVrBWKs74QrTv11HotwHcJo0RrkTsnoaQCWTcpHwMZR3aeEjxt6d+vBd4QotL6Y4Xkdujob7sOV6yfKATt+J+2dNMW5wq1c1pXyoIjWLN4RVOPCFWlEGSr9YAuZ3K3CLngF8ZGgc8EV5ClCal1pKjkCB/67Nvue4mUDAmhqow4r9wMvb9JaFuhZI8kpUdXfkzw1XA/vFHuxYM4j2Rj1mWU7ZCVdJI6zaZSJywtZpbWK3yeO13TzuCmqq83WQaUXUbFKazdoWmog9bMkj1DInodYlVp31rlqKJTbMs57beJXus9oixW8Y6qjgqLBYXHOydMu0Q9DTgfwItRogOEQGQ5E9B9Aw0agsdHh1BwSYg4UhGyWCYQIiUlvKuUUOQDla8JVCD2vuCGJmXTaIHikVIZVKtkBMx5jZqGUEVCJVRGpAg+Egzy15qUs4ZYhwsGLZZOG4hDQ4jV0FyP94ry9MSh0gN9Wl90zln9Xds/PI6KSCkQk1Asy1QSSH8MPWnDLdeUzOn0WVVx3sgx6qSUuZnNIVZaSXQajIqJDgTvqatIl4WSC9XvA3GibmoIgmt1IwkxLgsVivLze3jKeT8USn0p1DFAESbTqcoVdYlWCl2XmbaJydLELrQzqaCkG37WzdUZs0sLzyYx4gxCdIrrVr5CnEKRCm8J3TSx1CrNNBnDKlvPVUciBo+Yg/LCAMsFF1fg4JBzImscCnjrB1FIRZxmQcFgxWDsvFwUCoxZN/UuJmLu9GFGmYd9w5/3Th/gotF4LpmUpnRdS0qFVAoltficiXXD3HieejTCV5FpO6V14Io6b4UlsaZLrdeElFRyaJBwqQwbt/6pnAfGXQ+HSs5E7yFEcunwaAYKPTtqmT7epQ7FTvXvilGalaVV4X0kZaEr+j1V9PrAV1FrLlZQTgXapEFOKiY1E501YKvjaqdTHELJgSpa5h61zlLXDu+jOogQkEojzWJEj5QzXdviXSR6QURrb15UyqcUdJP1qsahcKaQu56pFzRZCWI1Akfp1CnnTgv8lY9EayyvYqUFeAegkJ4Yc07KCjjI6lr9itMeJa2tiu2EfXbUF84HJpxtet45mvEc44UFDWJyJqXEZLJEahNdFrAes6oItQ9UovCU1ti81lKkaE+X0+bUvn9L5aVEs5TkSDkQsj73ep81IxF7RnsWR4iRKlbLG3KIBBeYTjukEyUj5IJLFpyhEG0V/LC/CFYbcz3zsND
LEQmFNrekXJi0LYtLSzjx1NVoaDFxQQZHGUKgctrPJlY79r4H5/QeOAFn921QXnFuUFsZGK5iaiyKTw7Eif7eeIx9a/VTQg9JWrY5vHf1Xtv3dmF7Yt9Hp7/UZn3nlxuxBU0i+s3Ye21rcK6QJZBTYlrSDdrnb9NOqhmNCEWhoMlUI7aeXtJHdcEHKh9pmoZUt7imokxbKh9IbceGpSVcE7WXCUid1hko0j/HCH0zX0DylOw8MXpS8BSv9N/KeRyRUV1RGeOulEIRTxKNiMUw9Zx18xNTV8hZm02bKiqsUAo5GWxRNKMprbqkge6ZlcGzktrrBJwVJnrJk+h000AKOUHqEq3rm/F0UdXR5Iu8UFxBjJmQ6ciS9MFtWyaLSyxN1jOZTkmSFNrwkRBqvKuIvlYYorYIt1tSum7fgGnfV7JqhE27RJeykizwNHWjDDUHWRR6kaEwrHCC6vGJUYvFCAzaMd87WqQYdKQ30MWAL2LQhWImGa3tZKPrBgRn5A8xyZacOiat1is1IHEQrZlY0EJ+WdZbnCzpZjgaN5YNFOpayMlR1xVV0IJ9DEIdDZ7xkEumnbZkbywxKYRskI9EijOGWZco5njAUTK0lOVG1qQZnxRPJdrX4r3KVtWhobIWB7GgROsv6qS156sdWHgUfY6897ikWabSusuq+pMSamoMMND/cdrYPVqzlnVbb0toxqSs0PV0MqHgKbKoGXG2RmkSdaOiV87qjkjHsjCW6kXm1KompDkDUBTBzlxhXunpGgz7gDcmrLJZFZKtQqDq++R6GSjnccmcRV8/G+pryxs0mIKJQaWaGSojt6BM3q5NLE47liYKz6aUGNWJGFWOrYpes6oQaGK9/F2suL62juwFPV8pdCKm46cwf5ZMlqLfLRnvBIruFYL5er98/JpB63F6F5brha5vls5GRlpm/nmvz9vgqA3fz6WH+50SpaxVQ/cmrXtWMaiqRjfV4FkKcsN81G3bSfXReawirtUIv512BG9MM5YX6CDiWUVS6sgUJl1HnE6J0wbX1JqemlzKysZDbwu7qSpKl0h0BOeoQiCFgstlyNLqGIbFnUshI7QilInCkakrCjvS96ZgvRe6icZKZWly1rpX6jKh9NGNEgxwWigWDzlpET2EoGG30wK+WxGRDas0g6RC6Re1a0kOJOmD4ytnUIlYpqRNuCkZ8aRdInedspGCw9eBalRT1SNirPFepXBiDJQSmbaigqvJZICcNkhPJhOmXcu07QBntZFMjIGq0hpByi3TdsJ0uoHWxrToZbXI0mltyJXlSBbRDDRWQSFGDCoqpsFom3YSgS4rfBq9SVpBcT2NV0gmFjvptCFUobtAEyJ11KykOE+uVBsyZ+im1jSM1ju1XuNwRBwVwTdEr5lYFYUutXTdlJILbVKihndiRW+DjnqcFGWelr4R3BQLnAt0XbFapWaiyhb1UAI+NNT1PKPRnEE4Quk3owylmGqDQU0aZIuhCOasiik4FCOPGMTqrCai/TLGNDSIqG7GbLFuS+bGCxArRl6ZoTHWet8A7yZ0XbtMvCmqriHGRKyqqM9i6kiSkdxRcrJgJetlEXVqYuiAjxEXrGE3emOcyuCk+mPUPiftSURUrNW7oOQUCyh7dER78Atd6ghUxKgqJhqEFgv4sADInm+xf6PEJIdobdlFlVCqIrExJ+W1dxFU6FdKnzFZ9u+iBR/aBKzZmxvq3mK1rOIyxYllqHoM0WpI4jCnpefj8RCMwGCf4Xy/q8Iy4UUVK0rRxu/++jmcPYqC63vcREMFHxSmtEqW/W8fbGgJwLtiPVrXb7dpJ9V1HXg3pPNdl5hOp1RVICcrKme9GHWMdDEg0ZGjpy2JSdcyysUkglSCZtpqVJs6bX4Fy15Eo7wmVuSYKCGSncJ0yrqTwUnFEJAAxSubp6PQWS1Dck+Zhr7BE8PPY+VpGoXGuqkpHHeZ7JVNp9ghFqtpZK+SMP3mAy5g+mNiPRqoxl4ICim5QEB
19SgFSUmV2VdkJ7oNFbqiCy5npRmL6L+rOlAIuFAPPTqhqjT4NnpxMmhH4cE8LGyV0Em2KSqkWFVK1JhMF2maBueh7ZaYTDcwmWwgpRaKspN6ajWihW5yWfEwmzZbrBCnfWGlf+DRY5u0rXbL+6jQHTVEe1BhIMZ0XVLIWJRhGYPKRdXRD/UXbYDM1raQ6ZJeu+k0MZm01HVlhWklo0Tf4GNDEMi+DNBmyYW2nVCyylIF71X7z+oAzvfwm8FyFt3aKdF1CscpYxATAXWIizT1iDrOUYWxislSKGTd7PsmMzQjoKDrTQolqLKGNv2qwkBJRRHGohtvv+l7g9Z6R5BzoWka6zvUWl1V1yQ0YEijEZSsPULeWRai1yF3LZ5CsGZTbVloEcnaoI06pYgGiZX3JLIGZiEYvOXsRyH+Ho7sWw7UQerzIKKwLcLQUOx68oM9Rz1UBrYRlyHM1LoQK+BxAgmHy9on1VQ1Xa2Eg/nxWMWE65rKetC0kdxo+K5fqAZNWiajD5aeV9+83juxficYWmhKsuDN3uu1Thfpm/KFXvOzD4Sc5qlQPL0os64tlb7q1SsKQglAUXUfpP9hYDo6pw5ZD7zf54wv6DyxqjV7doLjViyL9LsyN0SOy02swXtSLuYQCm2r2UgpaONuXRERug6ooxa0QXXaspBzHmT9SynmAJR55Ew8s3KeaSmUriN1LZK0oCvO9MX06EgiTFPHVITirQnOO434sTi5KP6uSuioOrgYLFFFVYeIFUixmoDQmWOrvW4kzjrt+y5871RUdyAsOEdVV1r0DVHZdQ6NmIvqeYGYCsQyy6g4QTqFoKRos6I6O48EwFUWubW2kQg4R5cT7XRRNewQ23CVKReiKRmkRHGO1AZGjcIv7XSJyaSGANPphMnSItN2QkkJMaWB1Gl9sJRC22qTtbfirf4EotdeMRmUFXon0pHajql0OAJ13RBL1ozUaU9MlsJ0qjBT0zTUdSQb9V7HfHitDVqxuOs0oKEotl/E+n0MqFL22TLsUcdG648mo+REVRuWlqbkrqWprckaKAHVQOx7wUpR0o9zuOgHUVeN6AMuKTyKC9bkOyZUNeIDXRbEOfwAu2hfkcLQ2tA8nbakNiHiqStzUsXGfyx1LC1OmU61hzB1STXabMSGwkU2RiIoE7ZrW1V3d9oL5qPWUJqmxlOIHqJ3TCYdOXdDEEJWyrRYbURjH6vjmf4ipZhGnOrO9QmA81YPM6VwXa/L9bM+OxBzjin3WaG3/j3duIN3DJRz1weCShboSksRhZ1dr4pvzMaC9aOFQOcLdV2xbouaKtSMmor5uVoV3xutp2VRJCCElZmoG2SOVH9PBmIDgkFvToNSS7bVXZWh3mllJ1LpM69CXQUVTzbkJiWrGftlkoQ24lvVqvTroqMXH5DocJI1owNzpNYnZ6QPZ9l4ynqPBiak/XcMDomi+9oNsNu0k5p2nTZEioPiiHjaLuOWWrKlo6nNLE2T6mG5gKsVP69HgWpc640TE3ylv5i6cINzjGrrjM+6ieeUTbZeI+2SlV3nnacC2pxZ7Dq8eBIwzZnOFSpXEZ0jB6We5s5gKmPf4DRiz1nZRN4bzZkegUlaK6oqukk2RlVFFVVyXxAykeKcCmH2DL++T8macRX+ywptBpVEalNrUCIDvTcX0bBVDM+2b/AVuOLITjXCSpdYmlxjZI2KJInFdqIbalH2XRV08xIHMVYawS5OGdcFom7K47phXEcQbTIUEwLuI/iSrV4oqqjQZzuqBGKkmFARrJ9sSiaLp7OG06WliW2wfb3RUzeZOtcQjN0Ug7JApZjSvKpThFG9gpLu6DrtqWlbbWloJy10QuWU6tvEQBMrU4CIpCS0Wegy2qeFN0UIcxDmsEQgdVl7wGJPRIi2zjKd1YyCVyemuH/Sxt+BzaiKAVkKtW9UdT8XhKxEm6QEi67LpNyiI1U6pm3H4tLUZLQMWnbqBBcnifWLHYuTRG4TJRei09ExqrJ
ujsF7c5iZLne4NGVxskRCGM2NhzofqLhsqWq9t5XQSlJYsehIlbYrA5yliEDQ4ASDZUX1/irvKc4pI7TomBmdANDXDr3V/owYZJlKSUWb7bPpDzpn0KJKh/kQCQ6bxWTQcRYSGSlKbolR60He9dmy0xlvOKai2WxdK5xYx8B41BAbj486qqaIXqueBl9MLcKJ0sA1iNTP7bXxkiR1nDiKUwfUZ1RKAglaFxLt30ytoI31Tpus8SbjJZBF6f7m7LQepvtDLtkk1jQLylbDLkWfTZUB06qwk2zODlWF0e3YguVexUeWKfGGfPQQ8fXZbdpJLW1YpK5qJovapKtUYe0j6ZJqq0kW2qRReAYIWowOVUUzGoHXxj2HFo/b6YSuS+p0TDYIAAddykymLYuTjkmXaJPiv13RRsPKe7pUcJMWH72qTYgNv1OyDdGaCUtJxjpLAzqhgo/qXAzABbSprxDpkg4ey8bGKsnhQk1AszRNZLzx/UyNOCijq65rog8KoWSNuhScFouoK3tYxLADGVr7ND60arRBMM4pq8xJoV1aJHcdznvanFhqp4gITWwYNZHY04GtzuBcRVP3EEtRGZ5a4cJsVOicVLRVcrEaYWJQTRexRk6Fc7NR6F1Y7sWRpHpy06WW6WTK4tKUaZdo26wKDZbpinfEOlqhXCP5qq4YjcbMz80xGo2GInBOncHAmS7rSIWUijkblXppoqeuKiXauL6xWWdILS4tkQsGiamTyKlFsqlWWzOnRv2yvPbE2agY7B6XgQQQQgApBt9l7QVyKt2TckuXW3zxVtC3IYIu05UpbZpSSiLlTnvLcp8FqhWBznQM27a1vkMGVQccw3wj0Hun9w+tt2UNxLo8NQUVrb14p/W8Ho7WXhusV9C0Hr2npdX1YgV/rBHee8W0+zEpqgQehuGj0vd8BW8kDyW3DGtDrMhv6yilPPy3lIxIRYjWjmAPfy4CpaOX9VIgQhVnJGgwYXuvZr4GmgRrmq4qJWlUVd+CUXClZ0NqIIioMkzO2oOH6HOsyvnW/OyUk1PQcRriljd+Z2w/RYIcxZVBEaU4DVKkKCs5p4x3WpNzWWvcOJWOUgkqMbaiOqqcC5IdRE9nkF4e6m5Z75H0Djbjk9eG9xIHJZqBfILVRI2Ac312m3ZSi+snpFpIPaSXVEKzS0r/dCFAUfWHadvp2IgQtS8oOBOeFQj9GAwb51GyjgAxaMt5R+pU3HSpSyxO1FGlVjM0sYctxkrhQ7eswl5c0f6K2ijkQaPqMlW1A/Q5U7iuKCUdb6ysWhd4U+lYCJlkldG3BYQpU2iThAx+LbqgtSfvDS6MVqtRzL7FCAf0OlrKKtT5WAp5LBc1dUH2zK5e9TtiWYY1WraTljZ3TFMmAzFEnTNFD10qBVup5trnA85025TK7bNGZJIVishta4wuVZEgK14uSecEBVH3OSgUOGd9Y8qCa5daNqyfMp22TCadZRR9xIsNitOMuaorqroiS2Fufp65uXnm53Sabi6FpemEkpxdDYfO5uqLx31BWpUIKrvuwTZkzdSnpE4/Rwf66YydlKZax/TNMpTnMYq1bsiCoA25UdcL1i/leikbI9aE3lGbRI1LdGkR7zUDdz4YBDShSy25tCsci9b8nHdW+DaocRk5GhxUXNETFUzouFgGqrUO1ZmUboKIOmOkMBov4GvNciQrJOqteVfFeo3J6b3WrpLC18H1xX118P2ID0KrivReVcsRIaCbmhOQrP1aGlT1PT6ahYlnuG7AcOy9EGvfrCr9OkFhT4YRMKqCTlCosJRMzj2cuMw2VcUGnWoQgxsCRx89rq8DFtFBrBjc0F/03pkamiC+IEHhSPEFQh/QaM2zv1G9Eo0XRRjEaR2vaxPOGHVKzMuI9ya7pLqYSpQowxDLYgoSCk17colMO1nBMi46UYAelhXdi8QRS++0GaBgrDZVRKHOG2K3aSeVC3SdRZjF03Yt01bx0ioVmznkdOgcfaERBqqp0w7
rlAtdtjBVhMp6YTTVV3rm1Bdcq0s3iREDsmZroypQV4G50YgYnCo128NWnGZpykKE4oXohFDp9lqFEXWl82f02dA6mA5oi4b3a5oeKo/rdPEHp02JzgrBgkajKhljIInOcVenJolcljvPC2IabmhB1h4uRHsyPBppai1MIaEen05tsfpGwVMU8+4SKSlWH2JNUzeMm5EpTpu6uym8K76t0abOBVqGgKSYdmAuQ29UKfpv6Zl8KRsEKEPNzYmK1vaNvZIzbdtPFdWivBes+VVp/qo4rddtNLIx4i6wsHYN4/kF6qbBiWPSTiFpH5XgBn1HBKMzJxvgphNH+xHxuGW5rtIm8IHQVtpMqkEpIThGjcGCRWFc5/0A9w2UYVTwtJ+YunImUTE4q6cRO0vbxUMprUFnNao7otBfTzZSYVCrSxZb85VKb2UcoQSqusGHJYSJFcetrmiBQV6xwasSiIrRZtHrFFKwgaEJSgTfa0kqrEanBXRvrD5nCEIfFOk6Rjdilodc9o7LyXLrQXAQcaYjmG0MiY7NiFXA+6iwvu8zK2XB9qr5wekYlRjDKqp1Mvir1zjsHUHwKgFRULYkWG+fBS96PzWrwkgYfXbdU7g9FmwZccLyQcveV9xnp72QXtDjL9YvZl/Q/x9u+dl2PX3djqukYlmjBrgxgIj1aelOZc+8rt1htErUfUVMNMGjtVmspka2YNtIE32Wr2pAHYNOpHOIJFvDvwdOatImKB2pK7STThtwvUV2TouCweAA30dQWWcveduY2y6BzRuKQQVTa6tbxeCXO68N7tJagGLx4jLBBRbmGtaOG+oYbTqujm0oSftwJEFndO5YqbPxrqIkRx1qjbyDo2oqG0mumK6zaC96HU2uDtOcnjN5/ZIJoiQSZ1M6pWACl8pIFMvsskLY1pibTQNQ35+6fiyBKdIA4Cztz8bwU2mh0goSILuiD4FTTH9UjwmikrLRmjlXRqus+G+dTNqPBkgaLRftolcnpLUPMSYVpvPW/4ho5Jmt7qX1vDz8tAbttdbjE7yn8jCKthZcoGpqxgtj4qhiNKpZmJujqUc08zr8UpwpiJRClwrTlMhd0qGRztm47EpJHRhpwWPOzGSW2hbaTlXIoyfU6qRq75hrTD0+6MA4oVKlKLMY4ypYT4NssbqKBivOApMQvFGjnSlbW3NyEMSeESTomrLsRBs+lxtzxbKUqtaNL1f63xKEBVPk6JYmw30ciu1JEQFvLAdnPTUlMwR+Hl1rDlF2VzF4zZq2e0WCEL3Wbg0a6hm2wMAetNBqQBQkKyW7mPizR/sRU8ngtZ2hlGwBhVgvniqOB68O2aGSX8F6hnq2ovZh9dOUC85F6wNS4WBcXjFKQ4Zr410cWLZa6zbl/5wVsejHYDjLzl0farpl5p4RO4oxWpN0ysbNTpv9o2pAqsyXJ7hgfZE9mQEovRiBHr/WCxUtiQ6c+CHL7KFKTYaWnaMzB94HBUm0Po9TFRAdPoo1g/fXYIWWot2tZeKKfmZPMbs+u007qXbaMe06umkmtZnJZEpd2YYlWpsJTuGJGIJiwJYSl05HbYe5CJjGXFRGkO9lY2xWSl9Mxx6PKgS6EMkuKesrBEZ1xZr5eXJOXDPx5MmiyuM4p7TxlJHGU9U6EXfc1LgciL5SRQCLkn1PMvAKSYJh2ya/H0MvfIpuLgF8FYhVBBc1Isu26dgfe2saRDDVCcs2LA2P/UMvPZ3VNj/p+0NanUc1TaRJwpVg9bJ+s9D/9iFS4XBShoddocZlWZpSlNKackfXzwaSAj6obloBsTlAqUv22WIEjp6mnFRRIavCen8sxWo/bZeYtC3T1CEiVN4R6sgoBuZra01wgdjU1E1N3TSMxmPGIxXTdUH7c1IpNsQwKZTZtgozRp01RErar4PVs/zyQ6jXLdFiUWkuSPBUKaqadVOpXFE/zTQECzQ0gxV6ancwGnXPBstWrzK2m4mZhuC1d81Gd/cRjvO9Qrfe0EHWZoC
H1QFq34+nrhWydtQ4N0ZEN2KPowqRa9zVKgNl0lelKAl6mHtk/TvO60jz/js8AW8V9V55e9B7tL3NGVzab109lTlboKG7Va8RaU2q9qMTbZNqx5Vg0b3C5ykVurYlxWy9iUGHLTpsFLrXxY4zanqvoNHDVbru+nqglcKMXNJXqAbAyyD0iiIeCVqP9QbRttMO56IKxw4Bm/W+iTonq1JZs3UvGt2Pb9H+MBccoUCJWk7wxSsE6IqxTPveJrHhjQp7ptyZwnvQviocWSKkDm+klVISxdSpe+drxF29VsLAYvVKnVzO7EUJHM6cr9bb/XC3RFQwwMEQiFyf3aad1HSamUxVEUCyGG461U3EeeumtsWebLMQZbS4upca0oe6rtRBiVjU0dNpvcJ/4mralPCLCodFAsktQzveGoi99zQ5MW0rujTRKK8L5KQPZ2V9VLGqsI4lpYYbrNM31LmgRApBEJchBYLRfjX1dsRRoB7XVI2SHqSosKUPfnAfzthHwWjZOfcSP1aRSoXYtUQb9qiODWv+05rEtEukaUdZKuRJIriM+ICTSvXESiGJNjVX1YhxHGsPUqWjKvRQDJKTQs4dKXUmt6RQqAsOcqKsYBZpfVDxkNz1DD/baKRXBOgfBptNlAttm2hNLqmpKxrnaLxjro7MNWNVssBBXaljrRuaWqG+XKC0HRJ1QFtX1OG1bcuka3FJGFc13kWytEMzanB+uWCO4ESZWlI8WRRyllyoTKpHFQdMI7DuSSu2wdsDn0uxupRlEE6vQawUtu26vhfOD9G0jqDQGuYAV/ll9pkGwL1KSd8/5HA+IKTlficXqasxKQVGVUeadtSjEXNrYOoWjayjsK/rWx6C9t+lWFNEZal6Wa/ovcpzWV+feGUGZGsmZpDEsofbYDBxzuSZAFF2mc4i0w2uuEKiEPAa9Nja1gwv6TpOgpbFBOcyMSaKRMQ22WjTaVXJvg+o+kzT4cmEnlDiBO+LohBeUE1FLTf0Bz6MusBRx0b3BaNhOy+U4nDZ+qoM8nP4oTFYSqErnQXH3soHmV5KK6eCeFQZPXtCKITgyN4hUWfjlWI9iX3DcU87L9aTJoVSLLAweTHfX9MeQveaZUGPPGn2E5yOQAGtyw1qOqUP5PW9XZetxscyiUiAGzaQd7DbtJNaWprQ2UA7Q+N08q7L+OCJvrKGOaVi4zBRRpT2PDdHqP1AknA2oNAFj68jVRMZNTUxeppUIeJZXN+xKEsaQTlLpYv9vRV8RRxd0cbgYkrAKngbqQxODF7p0sFFAkGx4pQ0SvJoFifQY4zZIqKC9ps4PHVV04waHcbndAMqxdhFttB6mqcWgJfhMBWyjPjglcBRwIuOUO/lmxDVKkytqNp6m0ht0gXnM8SsYwOcLlAnUFXLsEz/UDjE5v70mVSnnfVZ9edijIyqSPaOtijVXaRvrE0ameHo5WnE2F/ONtqBzZaWx4UH0bHiMWoGNQoGr0VVh3A+IFVETOGiCEzaljSZqFRSXamOX1a4U6WWGDJYjx/qZF6sETplqBypazXL1bwSRFWwcypUtWPsVRm9quMgi+MsIu0j1yx9825BgsrIOLCMyiPS66254e+VCu505Ezv2MzhDTUQVojp+pUabcU0HD0OVa5v6oZ+VlKYRhpT7KfLZKsjBlVrHZxzwVE3KuXjphpI9GSQgcVXkvW+5SGN0j1crJm0hxIZjpmeGGK/q6oKH+3eZZ3l5YqKJ1crRniIbeqp0/5HXCZWiSo5vfdYP5JTeSzHisZayz4VOs3WexiG363MSvUQVTll0Ljrr2xROn1PrsGyM8nL8Hd/voP4dCokaxJPOakwdSkmf6TPq4inlGWoNwQlPnmyyqb1o3rse/smZFyvNepWrQcRvba5JIJbdtjL8HLfC9f3H8qwr2gG1tNresjT4bJmkB2tUupFcFEDfzdEJNdtt2kndfXV6/WGG0smxGg1n0BsNDPqJU/Imgb3GZJzVrQejfTmFi2yuhC
oRiOa+YZRXVHXAYdOAC028hznrCA+xTlo64ou66Jqu1aHGjpP3VRQOapRzdx8o/pt0VSxvY6q8ASwmVR6z4rCIs5YPBRyWF4Eztv466CZiiqD68hx8bohTae6gTiLzAGDK8RgyzJE0DE6lWJCx1aojAoY0oIkB8WTO0hTheBcqGm7gsRMkKBMMPqiuRtYjV3XqXyLB7LCPympKkjqWrT5WB+YvqCsTYYmEmuqH/3v+3MKIWixvC88W8G5x+G1mK3wYwyecYw0wTOudc5Y6Ee6VzUlBiQEliYKP7apQ6pAwxy1B58UCp6bG2tG3SXNXjJUVaRpKkrbqeZhUAZWdhlHIhCQko0uX4bRHTFav5qmYH3VXqNdy3YLxQrRNmuoJxUEh0r6aIuENwKBipbaxum9ZTgrNkBkcPLLm7BXFhxlSKFVCHREXY2pqoYQNZMZzVWk7Ck5mTSPMwHRXsvPIm7RtoXKe4iq6lBV2rsUgtZLetWRkhPWBa0wYb9nOTfAd32G2Ad3MSoU25XM4uLiAJNK6XUxRWWvRIlJTgBtjQObjZVy32Aty9fDCBBS+hEjRlaxe2CKWbpmDUrNouQFDZCDZRBOMz5JQw3RiyIZwXllHKItKwH7TtFaLMnGXWSFihW+VB3LIpniig1EzXYsOs+tFI+OvBFI4Mh4scnWNvizf6h7oocPbmg36K+DEho023TOkZLY/TVHbSxW8SvqyjnTy531hB7dWwOu2MDJ0mltuXfstrkkC3Suz27TTmqyNB3UBpyFw845xqOaufmRcoFSgQQ+9zRtvVH+WpteH6l572iaEaO6pqmjsfU08u8dVdtpv0wxKZelyYQNS0vE4JQFVzpCpcfhK08zP2J+YcTcXENto9tBCRVi2ZKxUDWTK4VaQx56UkLf9e5d1vH1VTU8tMvRjYnQOujLkv3mVEQU+rDzxDY1rd9ZNCvK4OnxeS1KA52DDqQr+tNvpnb9xGSecoFUitYiRKWi+s1Stcfy4HhS6lAugWLjbYFs9aS27Uhd0pqTwSYqeipWYGaI0PUe2uhzp0oBErxtOtYIatmWKph7bWr1AakqOqfNrpNWnVRXMtHX1FbTcKikVggNudLZUSVnbSAOmZSWmKLzcYILuApc5XQ8fbQ6nxEJqiqqrmBUeRoXtH6qY7XFYJ48MAdDn0GJp+qzJaze41co/mMZh9dsYGW2qfUVyEYs6AOdfo2UXoPPivUhVNRxbJJB6niqxtOUQFqakvIEH6Ae1SqiWjJ0LEtdeQhFcFEFl32oVPkiNoSoKuDFTSG3JnJqNSU7BwXalgVd1Sn0TrhvcNa6bt3UhLrCTTuk7Rm3iVJqdRxioVO/VizY0XqfXash37GMp8iQ9RTnNFt3CpuLE8R5q4FqMKW+K+BcP9XbZnNhvU3ST/jFvj/Txyd6jE5Fik12TIa68Yp2k75HSpZVJfose9i/sH5P57TtQ/XTFBosquepLSB6rIqlequnMWSEmlVZP5x18/esX43PxYJSfdb7Y8gD4UIDHhe87pE9ciOaPdVeR/n0meINsdu0k8pDmq2bUy+TUgXHwlxDEyPtUsd0/VQZbgZ7aE9Boe06YtcNxT9vKtgxVNS+IvZyJ0XViHv2UsodXUp4H8hFayCLi0tEJ1S1J1SCjwHfeEITWFgzx/yaeVNe1rEhrmhElY0g0E5t5Hzq8F4YS0XTaPTpxeOLUz2upA16ldPosncmmkxoU2cM/dRQffBV7gjdUAxr7wvtOK3j6IamjcauH1yVCtJZdJbAF2UQeW91rxAo3mDEUsAFQixEMTqy140F+kZJzT4nk4lGYOim5kTHDUynHZOpqh9I0kg4ROtJMVmoVWrU5rwd+lgFr8rWIJQQKS4tw10WxUWTmCqiePq0JBbblmmnvSEuer3/IVJ7rTdGqwvV0dPU0TL3jA81ba4ovoN6uceoGtWESpU1SnIErRKj440cVRUtowo28VnXQirJ5pVlHOp4vfcDxNTDt84ke0J
fS7EMuchydkC/wdJXo2w7tqi4hxX72iXOmkgFrQ+JavDphpwQlygyJdawsGaESzoPPCQtxBfFsFQnzuZfSXG4UDMez9M0Y6p6DbHSGU9dSbg00exPsO+yLEUrscM4ecAy40CMNU4gUWiaEVWtAyV1KizLMBxOoXeTlKpDgCoQapXg6gcX9hurE5UQcqLDPb2LWlvsSTmiGTRO66opm/isTgaEFc6wFH3mWiMpqCqOxn+IG4hKLrPM4kvFMqhMmzqlzxfNqJMsD5HMYiSuojWsfiRKSsXURFDVFbHSRrEp1l6P1xS7Buc0wKn0AbPtGXnZCfaBtDOGoFGHrXeuQHHDe5Vx26tVWJnB4PBUMm1O2k8aPNYZdL12m3ZSwQecVhqHjbepAwtzY8ZNTV+k0ShR34fVNtouIa0j5QafDI6pomnoBZqqphddzaWntPasq2AUWjF82fq1JBObQD0/pp6fI9YBFx3j+RHj8ZwdR9/1ImizYL84OyatagGGoP08I1+r1ppFTcojdwZx6uloU6/W4MCi4egpuYdVih530f6Gnl7b48q5FMSrBExfmO3HpaeckCQ6V8ded7Eniig8RVFJG4cjVgoj9H0mSiTpF23W6cdtq71WthEHL0qTdWJNty3ttNPAgF7DzIZFmmqAd8qA81gvENb35bW5tEjvMLyy/XKhySAlE3xDKZoBLE6nXDOdsqGdIl6hs9pp1hK9OvG+r0xKxpmDoQqkPKWJNfN+jGsKXVp2aMF6wpBASUJXOdqJDlxsrM5ZVRUu2Awh674Xc8Q9y22gvg0ZozH+omna2f4tIiYoulyLwS1vQMujJCyqNxhY3ZQYdLTMkOvaRKy0FujFU1zG+0IVHX6uxtWRbgKl6H2IIWot0GqOTaw1uw06vLJp5pibW8t4NG8080guia5bUtWNvLzB971CwSu6kEpSmrRXrcX50YISL5bWs9gsGYwYtWfD22DMqh8KmLVelZOyHutI1VTUTWVDBXVfKCjUlXqFeWcCqsazc05bPejrxFmdet9u0Ad+3vf13EybFPpPUogmUxW9p0YHblYETIRIyQ1ZtL6eOh0NUzQzLM7EXVc21yab/WUkBPXNilYooGSEo6LQt3jB576+xhDk6LpWVETXEUMGLpLp7HoMeIRldkqbxmSZbIn22TjL+0q2hvM+Yy+lDPJfUBll/frtNu2kojFTAo660od/vqmpg0ehe2/F1Q5XirGmhGmbdGxDHaxeoPBWMHWGGCKpqJBkl7TGtGFxicXFRSaTiUqhWNFcCsZWKVTBMZ6rWNhqLQtr1+qo9n5oSlbs3btAKcZsS54ezneAy72Mib6odE9n9SIlZAQXrMCKTkd1KLTiespMv2osJc+ZZCwfVad2uKAQX+gbe0WjIR1jr1kXRchtWiFP1KH0VgNkSqG0OplUi7ZKY47RUdeeWC1vpM5BMaiv7RKTaacq4yitWYfz6ayr3GWTVxGDwPQBz6aqPjR8itYPjd+hmVJQ7TYphWxIWNu1THOm8VC6pAMmnTZwT9uOaWpJItSVStaMRw11pXU2vOmLebEMBBXL9R5fFA4LzYjYZFJRCDZaduuMYaHj7ANpTo9vPN8wGqsMlK+80nFtw3Bem1EVsrVswvUb5TKRwJviAQborjQRg3klU4Y6CUBPDvB2vfqFZ8QBUYeVJZPy1DayCmOAEJyjqTzZQ3EVjkDqdCOqYq3rxHrwgo/K4rQMJYSa0WiBZjTSUmuBpp6jqjYQugkt0+HY+nDeGfKhcJmjqhtGzQKjegHnPNOUGY3mbZDgIoWkmY1JY3kdZ4vH03qD+ryjqozJGz0u2DURlJ3ay6w4bEPPw/gTzzITTrltwTKWPttbQacHG4dTlDDjMm1WGv9YlN2bSlT0xzqXk2gm3VPOi9WIBX12u5yGoEWy1dtcsCb2oYysn1VUg7DvfcM5chAqsUAThoZphAG+79ePA6OuG/U8a6+dL3ZpojqtYlC1FN2rstiEAguAe0Q
nmcZgP524D6DSDUylbttOKviB3lpXFU2tjKlYR2Klo6q7iQpXBq99U9rDghVatYgfY2Q8HtkGo2uh7Tq63LFhssRkMmHD+g1MJlOSCWz2ka9KrhituIk0o5EOY/SeGCJejBprNyRGw5ZTJreZklXs0gsmbCnEyFAPcx5KcsvgfS/B43uIw3ZpljeoNEADWjcaOrwMjy89Xt8H6fZ3PXas3fqJYgP/2q4j2cysECI6tA59EMQetGH89PLGKUbDLpYJ5qwPT9d1TCat6Yd5qhC15mJ1RS8yNAIrBNYxnbZ0rW5EMURtcA5KcXdeoUdvm1FJnonAJHVsmLa4LrHkAxviVPtDfDCRzQ7vHKMqUo91YOV4rmY8rvAV4AshLp8L3hvxQpUYcFCHeUZztY4Sd9aAa6B/7uWdjInlnA521F6kiI/9ID40He/rLgOMbaPehzEYSvPWOoTuMsqoKsN/90PzxDHA2IOad7A2B2c9Oa6nnutIGLwGEym3Ctf1BXPLVatY2xiamlJVTCeqZddLUZEgkwxelkF1AYMXXdFm2eQcVdVQNyP8NKrcjjXU+x7CdR5CIvqaphkzbhao/AiHBmnRBUZ1w9x4jg3xGvK0U/agX+6z8kGDAIUJZGhaXtm3N9Tv+msOQ59UsOstaK9iH6Apb0TX59Ds0WetyJDR+Lh8PKUkBE/bFopXoWpvTcEhhCGLxWkQRM7WrK9NyV3WvkYVeLHpB4bfOcvEHT18mIc6FGJ1Z1UuJWNtOU7P1jlFDLSvykSKDfYtzj7fxGjz4FOyslt7jpUI0Ct4OFYOxaTP/gcERzNPKYKk34NMqopRHZPNENLoSCNsvNdMIOuYh9JLmhSFqBCx+qeKKiohojV6NrgAXUk6HrxLpKzxVHQqAuq9wiCOQvBCXfUd6Dqqwnedfk0fqa58KEwROOWio75NzidER3QVdRNommYQ7vSiXd3BBUpQ56hq5mKy/Up112maUESVFpQa2zd2CiEGnFeaeZ+Ja5Su0VC/kEEhiJ4qbki0PawWmSXrlXK9A0oEG9bYyxxVVTTBXHSxYxh1ynRtp2raISA969I7hW4LxuiDhD6kKRXthxOdaltipEFVP1zJOOtiD96TS2axnbBhMmUyaXEpsyFMdfZQKrgqIja4bjyutU9pHGmamrlRRTUyGC70aI72Ow09RaDZRVQ4t/JBZWrs4R8K2bmQc9B151SFIEbtzeqbdIfoHZXnKX2nPnZvvFXcJGMdWIj4lSjgEMH3FG5cMbJeX0ORIdQWVOJGjEmKU1JFb5rxTigZQqjwBgU7r+xZfKCIjq+p64CIp52mQTQWdK6bOGXSFaZM2wltOyUK+Ko2QoMMa2XlkNGghTsNQEOFCMYyrHGuQgh4L8SqoqpqRT689vhlUSdBXzO14MsFtJnYmHp9Vtlf936cixM0SzZHpLfZrXi/oRQDLQKKqMo5/cQCUYKSC9pL6cry8FMnzth9+veu6N8Xu4HaF2e9SU6sZxBVncgMEJ7OptT5YT2AUqwhGFsLwa2oZVrWreK9WLAZbA2p0G0WXWNu2CKlP1VFWQDXe6UsJGOEaoKkgb8GRf0f2bUHG0Pirc6ViaYKtFLM+LrsNu2kmlFF0zTKdguqLlw1Kg00bRPdZEpOysQSj6pVD4s44GPUGoNAO201skSL9MoCU/236aRlstSq4kKXB3w/OO0mH9dRISLnifWI8fwCIUSrNbUqYhqVaaWjxpPO89GxOAPbyIlBEkHnSGkvhDMFB6zpMKjKQDT6OrZAk82CsZ6l0qnLqSrtw8pOF3GweocKe1oWYul5WTGNuMeQe9KBWEHYGYNIey70gaavmeGskKyblTrJMmwCWhY0JpNBBCq2apG2Rag+2AgBdJNwYqKnXv9+cJneNo/++Itme8XwcFWtVymZ6PRYpzlrXWJcU4/G1HONMtXGmok3tTLwXHQWPntrFzBo2ahNElSY04UMQdABqPqQ9tlR6XuTnNPNPmjzNqapCBY
wYfOvhlqlDEVxBvkr1f5DVkDUzg211uWx7uqQlDrd4//q7IIV+b3VDoplVV60n6wYnbxLConmLIRQlOVmsjg+KMSlcKuOlMFp/c2hzq8tClN1OeG6RKwqmrqmcmus1pTouiltOyG1nUpgJVVwBwbiQ3SebDUUMYabeMFFp43wtiZ0bSjcJGCZtfYslqxz0TTLxthxGTHFcbDykgVJK2P7HhFbQTuhZ9s5NKCClmWtTIe2d6imZeiPJTul5mf0XpnquDoEIWfN2BR21+m8/XoopZC7zrRF+ynIwdQeetFWfc4Ay1Sgb45GzDnCUFoo9KhJv2YhiRa5e4emREd1M04c0S3rJuJlYFKX/nrYevd+BZRo1qupa+Cpgb4TR+eFG2K3aScVq0g9apQtVammVlVHigvaBNdmpCuq6puUPZeTLnJXFM7oh7Yp/m2Y/8QhddB+gKIKxGnSMZ1OmbYtyWpGVfCMYsXa0Zi5eqT9IFUFoSIDk05hKqHQYBBV1pEHOZtSgtc6k3P6/zGRUluv9BJNqtvnyMXbTBlHsJqBc55UuoE+WrLNJPLR2IRxKIxHm7XjrYDtnWYeAN2KiLZXJ88l0w9y7HLSpt0QcT7ocVuQrim8Q8cNeCN2FO0XESHnzhybbnbOfMCQMdi/9VpYlic6fbVUFUUKsQqU7IZrIRi7KWidxfdUbpOT6SWgShY2dNrLVEtiRGZhFBhHrWVWwUOE2EN50eErveZg6uPirBahG3FxDh8yLgRCbRupLDeigh2PXRvta6vUJ9lZZ8mIS8ZazAM7T60vcls2ZVma8QqsAG3vlTLos7FiQ17ucdHgQSPa5Z4q12ceRTvIvTA4uywdOaG1Q6ekjLquqaNAyThXk50faM8hVKg81JQ2TxEybTtRaK7SLLXyjpRaUtuxOFnP4vqrSe2UdqpZta7lHnlQDT8930KbW6Z5ov6dSJKOVDoE0eb24DXCN82+aE6koA2w+lo9bOzZ5Kz6RHaA89xyE3TOxWqGuiB7QFCTKtXoLNY/5Is3ynbBlaStEKLfFdCgqRjEOihSWH0w23H0z1tn0KcOY1S0xvz3iudlOTDDnp8+aOkdqYgprodotUmDeQ3JcZbJCzLIKOmaysNiE9Gsskdqeqk2h66zkoshIWVwmMEvQ9G90wreG0M6avDQWf3qhuzzN+hdt1KrYk0dNfqt6np4MB3o2I6k6txt2yJZKaZJCkH8cOL9Q9zXTVR5OEHUlLg4k+a39FhxYcV7mxjYcn6e0aiijpEYa4o4lpamiMPgsmy1Jtt8i0WjwUHu5Sc90ZxNsabelAox6iKKQW+wIKpkXgwTtw5xCycpRecUCdqAOkwmtUUK1ghLT8fWXXSpLPdolCGiV68hTrv4uy7rQEmXqZ2qvscq4KsKF+MgDVWMqagNhsusnj6K0vlWnpI0O1FnqdR8bxCss81J6yeVbdCC5ET2ZZAT6ovtURWD1blbcqIabplgZA7vTKyzFUqAmBq6LFQ2giBKtAdTv7+KvdivKeYTbCSD6bUZFtLXJjQs7QkNVo8TjWy9i6g2t9YZ+w0Ig+6SLNc4ewxvaGB2mjWHvuBs2Y5CidlWj2I8uk95y5SWHdSQbdI3T/fwoAZhPTjjLGtxoM9L0tpJsgCj6ZSpiAQ83VA0L1kJCiKe6ALJnkKPCsXm1DKdbmB9URiq6zoWlzawtGE9XTelWC1mqA8Zs1Hww3UokmjzBJ8FmXqmaUpK3aDU0d83+npvdMrMhUFgGTt/Z5toTsviqcpEW55mXaym5rMYc3RZxw5bn3qftT5TTKGwdyIeT+z7DX3A9cNOi+BKpoffcT37stfXK0MA0tfoYtBp25IdxnoZashDc7uJFBSDFQ17ZFB+0NWqahW2vgbrYX+0FaBQljPLos9NLg4XDSXAsktDgcT1wZE6Ol9UL7S/tn3jujOYk5KMdt/eoH3+Nu2k+pHecSiaQ2d0667taJemlGlLO+l088LSXGucVKiKQZG5S8k
gI1B8EKOwi/5TzMl4R9VEFpqatWvGNFUkNjXj+TliPUK8I5GpG5Us8q6Xui/UvlZGGr2DCgQ8LtvGahpjRYDYL3gd8qYjFjyhKHOslGwSSh7nIlSFzutC0zElOtIjRq0NIb2T0kUWgs2dyUHxaqdYfingo5CK9kFNU2JpotBnVUdGY1UQr8YNsW4oeLsXwRhLCsPkvDy2oOuHF3YawQZdwaZMXdHESBCN6nvF5X4jiEnHkuSuI/usDjjGwdkGe6Ccc5o5WQ0oBBg1er36R7IXji1Om37F6cRSm2ZEFiEOhfN+XIM6wJKN+i9iTFEgGzMqCLEyWQJM3LPICvKCOrihC0j7B/Diljc2h6o5iBvqnnpP9LpmEdv8e3hnKEbZauo3J91csimL+NA3PVghO/R9MDIEJT2Q1WcLpWBzvCC1CjelBG0sOCLBaauEE6eTiXOmJN2w62iNwFFp1BToJlMWp6oeknLSXrnUIVaXxHuKd6oaExQOdSHgDXYVyaQ0obXsJesMDuqqpqprmqYipUxVafBS+qxKFJyLod8n9Dno74/5fgvL+im7Bsf185xK0VlmnsHp90MpsxTEByIecRF1T4GSOlvDNoRQUypwWZ9Fc5grBYmzCQUo+YUV68IySq/N9kr9t9pTWHlMyt7T+pQRFEQnWoc+sLJAXOea9t+ABgUihkTo6grGstUl1sP/eUAKVFFFIdQ+S9dzYngO+nqp5ewKMkvPXL2B+/wNe9ut1PoUl37j0Icr5ZbJZKpQRepnp1j3uShtO3qoYlimMzunN6rYFFqUHJCtaNRDYKnrcFIY1w3zc2OapqGJgdHcmLn5BcL8GNcEilcIJ9mIC70pojCZKI1YH5xKN8AuM1nqSK3WvECoIrheE9DOUWV2dFPRqahaiFcsXuHDqRQb/xCoKyUGOB8HxzRE/3ZeMcahnuFtA+jpv11KTFNmw5LO6loTKjpxzMVe7kYXZ4yepqqVXVnpsppaFJaSmMqEUcCzMdaCCa0GRwzK2OrrckM3meg1k1xoqhqJK7ID0AfezttHnYnT65LFEHAxWh1NN+xcHEStZ8Q6EEcRV7mhXgHaf+WL02bm2KthWyHY/EDOmSRJWwtCsWJ21IXV1+FE5wGJz6jOWs+A6hl5OlZiZSUkWOamupM6fys4o40XrZqvJOFoxqGhDCjxoRhrSvt/lqNyIpaVuIHqnEsx96wahH0mOkDR0ZNzIE9bptNFYIKXYFByi/cVqU0WWSs7rPK6nkqw7aUI7XRCEiWV5JLpUksR7cWKsc/GbWx7WM5gY6i0J2dY+4ALCpFS8MExGje0bYNvW/r5WF2nLQvFidHgg4q9+qg1JGM8Fhu0GfqRPKWYMgRWp1VJtP6693Bhn4k4cRSn6IV3Kjjr0AZWHUnRs+76+s+yMLJFFANJoecc9A3IfiU0qD0RfYJk8Kgo1FZkmNALPatSTK8PKNkySUx+TfSrLSvEGbnDmv71mTMJtSHrt/8/MIn1w0q25K5gtU0Vq00CsV4eGqnfJcPzlRC6Fev+uuw27aS6lJl2LZlANLJDSZm2TTporss4kxvpS58BqIMqB9R1pbUs54kEgnN0XUewGpV4jbiSTYfVzn5wRYjeM6pqmrqijl5VxJ3GopX3VoBOdLmjlDREQ9q1rnTzpq4IXkkSOTi6pOOhk1HBc67sQVWGWOkFJ5NuuFVVU9cNzkd1pp2Qsj7cg3pGZc2O3hvcFlc4ZIvCHQNVVlZg2QXNOnIW2i6TklAI+poUbZ6MCh/G6FUhI6qiRs7K4Gu7pJqIBksQyoD5x6A1oaauqEzCR+fbLEM/fd1FpwzbCIK+VibQC7v0025V5imbE1bnpFp5ugk11QhGgWZcUTWBahTxtbL5/IDD93IygFivjMFjit9DVzpy7hCvDkrH5ghSelFfzbREBJdZ1XzbY/d95O6N5a31PYV2q1jZJF6//FmWOeWiG2fqikKJLgxwoHY35yFTX76OeXD6Rex3pf+3Eae
N/KKwsMGUovpw7bTVPrd2SklKJW9ipq4a+v49R88wswzCe4IEkw/LCv2VrPpzA2qhwUQIER9terD9bYwVdT1SEeSSyKLyUz2ZZBke1OPNfoW6yYr145wjhkqnRWO1XwdQBqWKHr7KZK08rlBGVykpfWYKmun3gaIXrW0nstWUbe0NvXVG+8dqUGjLyrIZvFjQ4I0+e1dHohmPTilG9NnsnZAz7yfWvkDfZmb1rWSN4bou7fMKlKL1yH7MiIAFkjbAUMOwAarzdh5igb6IDM9BLqLZnRNczkj2SAx4X7QWHvzy2rfgpyuqrNHl34OaVC5ZZwZ5ZVdJUSikr230cEhfIHVSqJxn1G+Mlc6Q8r4nD/Qqxpm2FfBOM4lpR9u2WuNKmVAKlfc0tQ7Li8FD8CylFt+CNA4piWnXkopO0wzGYis9RFVVhFCZBpyOdS9OSJIpfpmRA330q9FTTuo0+uisH+DXp9Ylq7q5ww2FUB3SplRv14/xKOCC0+wh97RpG2USPEKwvhNBB1RlnNdR0aNRoxlkM2I0qjWOD73Ej6q1T7uOadeqgkWoBnKKhExVVRRxSjypakZVg5Kxlmm0sJwxKQwRKGgTbip5GNUQbHMI9D1gxdhXRZWcQzBSjUaC1TjQrJ1n7VZrWViY0+OvtPu/DFkzOgbDC94t9yCJQXzJ6jQpJ83MUfgnpwIlUMwJ6f0AJFlwgEF/yzR23WCtQE3EE6l8ZfThPjvq5XC0683Tw9Z9vakvbPfQSt9Yqr+TIhQv+FIoQesNeK2hSB+aO32/BjKBQIUjIuJxuaOpatpWh1G2E5UICo1TaNRFMHHjnkbeSy85r5B6rw/Zw799ZhGDBjZ1rHAhWnO2p6pHzI3nqOsROtOtpW03KC3bGrv7mkwfAFVVGH7XOycr0BnlReuK3tlYmtIN0GjqMr6HzlF9uUGRxUoCgmYYwYdBtV2FVfU5C14G1ENHf2rtugpxuaZY3BC8KDLiyNlRshKP+iGEvXKF1qELyYRiVaXBDdcXVnxW7oNMVOIK9PkVy7YKOlS0h5lN5q2IkLtCZ7DuMtlGM+m4or4kw/Np62t4NjTKKkXXVIyBrksECcPfD0GaGCh6bRrgZuw27aR6894PD3R0CudIXameVc5IqTTiMP00rWN5bFgyHmhtAqgDUyDu6apKmU4p6yyjXKi9Y1TXjGqlLRdX6CQzTYU8yUitjYApZaYmlKpEBoX4Qu3pUgGSaoo5x+J0wiRNySSVk3EO8UKSRJewupnJ2xR01lGXqWLBhWwPikbBUoRim87ASja6OfR4e1996Mc86whxR9DMJ2eqqB36VeUZjXQmVl0HRqOauTl1UCFGhRKCRlNlyLIcddMMdHcRoesECQ7ndIx38IG6aqjrWu+B4fGUotGebeIl///k/UuobWu254X+vlfvfYw551pr74g4cc7JB54LWbCkoJCIFhQTNK34qhywIAomCCmIBUFQxEQQ1IKmBQUrKmhVwUqCaMFKkqQpwgUt3MTjTfOcExEnYu/1mGOM3vv3aLfwb73PHb4yDph4gxibYMfaa605xxy996+19m//R3fosbNuqzowL7whebKNL8RjiszzTJ1nxr5jrSppNuohm14WLu8vvLy78vx0ZcoJC4F9dBoiMtRqtNAFmcbDL1DFqrvP2jAF0CXfnQV707EcEJtOpQOq/S5cdDyx+ixOAklI/uuINWV0aW0q/7beldGU/SBNKbvrgZ2H04H/t9YcvnSoJrmObARSmoh+UB4Bi5LTOQSdFSNjpp1DD/2UCogooUh36xXrRUaqqNFTdtFV7unBIceAQ7uzpsze3Xz0TYCqt6ADPJXMslyY5ysBhZJGFFs/enX6+n4ezimp0Pce6Hs/fy9nFfyzeMb0Ju0I3kBwkCg6oYga3bsJysLkZBLySWYIA0hCU8YwRmse1+P7P3/GWtBEVXKkjyHz4x5PgWxzHaN2lW/TnEg6PtFGkUzq3thXp6An1z/5s3u8zJu
APrqvKUwTaVBzpPfaFRbaPOcrCanRPW3UTUbLfQwsDPeZTIqU8Xy7JFxc1zQ4ecLvdRu46/mgtURsw+3RkiM6TnRCzezEr0CRSiW6TY0OimGDhMgCk+8qSIk4OQTiEp7gi9kRhI1GP1QPodvow80dPcfJl7Ti+keHCeXCnHNm7Tu1d9YR2GpjG/I568Mjx3snB2XgTEU+Wq0PplSoUyFn47G5ZxcAKjI9GHUMdXwEsWy6TxwMjrj1A6oxjyoY3aRFKNnp4gd8IJr9MRXogDjQ6gOaGM7WUbx2LokyR5Zrdvr3oQEa6rq8OMhk1pySLMhqmqRFOdwRhEl7OGMIlFQo80IpfrvGSNsrvTagqcs1U7Fvcpff2y7fv6QbPaM9xsGejClQSmB5njHb6Xtg9gct50S+LlyWC3OUCDRlCZDPpIiOtDiYe+EFKN+xIRr6HKUdOVKcg5NtwpvWLBwNp0MlpmaKEL9zuBwLbxWmNo7GX4fIMN0HhOYR6M3fh76OmhVfR5v+3rE7MAQDdWtODz6MZTOBA9ZRRAtBvLR47DHzTM4XsOySjRtrbBgJ6xKkjqSDtZ+EAnf9xsjxjZAxTA3F9frC8/zE4Wa/bw8ej1da3XV4xk4I+dyvxaDZ53BVEXQp0kWt9RSMH15/h4mqYdTmrvQe6RKPnaIh1OEkTETtuZAz/PDy8Kb78x1yTP7zuIPEEIXpmFhqVZFKOXuzMqhW9fcGjNSVyBIjllR0xyEhMDlJWDgol4MUvam0offSoTdveoIapsPd/khzOHZX3UQe8ytK7+38PA8mpg3JTYJ3sK1V350bVk0CayQs7lV71TJFylBSQUhi7x2u8hLlvjWozcwJIoGUBrk05mkwz9KmagnmkMUv8PqlLlLLpTh5ILwtPf1wCeYGpSHiW22OmdNSZERBdCTphUKPiis/YAkCu4+/1XdSOFvriEQXNT2TrNP2ndd1Zw1GXKWK76apbPQub8FloWYtMHKOHh0hr8DaKt0dFQ44aBwFMmpRfNjbnwp5h3VCHzpQxyB03BXjoL2/mdiaL17f/Ap1cwc7IssH4NlVXS4cduyeJjGLYsIX343UGha1MBXVXLg2IZGTnQLk7EvplBSPISF0ZCozucwoOr77AWXnDuN0lbZjqWvushNcdxF9ua+OL5dEzAahyAcvDWgXLt5UxBixkijTTEkzRqQ2Ff6OinskntT+Y4AyD5M83QpMfw4crz8mWAPrb8vlgXZiIQpmCiZH+9O9PQRCFMSnncQxk7x1yf2718SGspOCDFcVbqkiM45FOPadf6RTG5isfCxIEGtivx1FEsQ+zaUQw0zJV5b5BSjUrTPizrBI3Qd1bYza3UNSkHIMdjqLBIxhFevhnKxKXnj/4Xt8mOW7RzdqW/kY4cvnjzRHCWJynqHvm4Z1Yiz+cQasqWHZd+3HpBWMDoP5gefPXDAouaj4OTR35Lwd7xRzAkyXIDsbnoEl8fXpSjGCmheHurq95SZxwKr+NcfQLnqtuzR8REqM0uMVSSRSfGN0Bm9G5PLvzxtB8HcfTqLJtBTP3XhIQb6Dp4bu+BrmpBhfiaEJ8nq5cr1coRnrfePzly+01kgWT6bhMYb32tnW/YT9R3EqfJqIIvSfWVSW1BQK9uyKqXcItPsOVdNYZFwGiZloarAEL/4KFKkpJ8o0uWbGO5HmIj9nUg10EHMs/IKs8kcQ2Be9SGkpKspqtyMJV0VhdIf8amdC4WtlkjDQfYGxYbRtZxuD4fEUwOlJxmwwIjV1Z5qpE+dgCpnEn5q2ymmbM/rBkBtKxm0SyHn/jI0392lZNB2sLj/s3a3gVNnD264K3w2YbqijKLVW5VHoBTagyTCE7HCQJqR9VxRBTMHNNAe1DWIs5Di4TBNxSm8P+/n6+V83h39627Eh4eUBkeCU2t7bz+0hYopEx8vLlMglMk2RECemKbLMmX6dJQiO+WTKNdcixaSIbXMosbV67u9idAvRKFLIYXh75FbRg5t
q2Fuh8MPm2PGYa1VCTpQ4O0QXCETv+pMKjUU4Nmrm1jX4JOe7omYVM01RYnMdjNZIN+1JDv3Um07rjXQCblsjpE2/Pw5o8OjdEoxCChdKeqbkZ0YP7GOlWab1QKsdq508AlMQnVvTxCAc7gHDvNlTwqtFmKYL18t7Uii619JgyYn58YUvLoAVJbwTYkFSgkbtjeJ+b30Mtrqz7WLttibB+umrjAN34chaErzf6A4D67ONFog5a1ILkZQmNa7OUrAuU1bzCfdADYbvcg65ilbBbk9FwHxxoNzkCJbY1karis65LjMLE9FtuZKTFt7QAhe7h0FAMSd9GK2qYc4x0p0Sf5gM0E37+HAotMRU7h41zxBz8ml+4pIvxBRZ4kLbB31/FZqD7vXucopaq/Zs5jssO5jJgVFUlVNUSGvOiZElFB9jSK7gpIjWZJAbI3LjGYFomdFkPj2q8QtmHv5yF6mSCxePTj90HwfMMZoH1FVlKB0fpJksfiwFV0/rEK/eUbVh1N7dr8+7eaKLgzscmH3OWFBg3hHSZ6K7aIEexObBb+JeB1tfTweJGCN7RLoF1F0tl7eO/7AJGmYyeD12Yq2L9m1F1FL3lDvEpfKWS164ZBbaQ1ABIvqi2Isih37HvAPWg1h3Y18be3WGpIXz4DwYP30YbVOkQCqBZl17tiCz2JQSeZLtlPz9/GE/lfdGcljqWP53P5jNpyl8Xun+944jN6bgOLmKUzoYRL53KiUR48Qogl+iRRTAp4n0iCEH7WHoQ/57fp+c05QXlOh+caAGxiSWOpRV+BfitGZyfVWMkSnNLGXWnzd3FnAZi0SXh7ZK5AvQgvyADHPOxBGcM+aHlC/2A/FtKe4FTdwMTZnHAQq8ESyGlFbHDtNMBBF93SOxOeuQbHJK2Ftk2wb72ggDsiVC16FoBP1sTmHuo7uRbfbPQx6V3TRhSxQ/CKET80RIGTwA89Cfyeapi1bt7/to9g+HleDfF3DT4Uh3mzChEHjBzxS3ERqtC1I9dj8RpjIR0+UkKATGsS7kEMIe0Re9m3Yyjg2HEGAEhh1s1VkyiTHodmO9f+Jx2yXg7powLvOk598bWbOjacSp5FBPCF5kGn9YSUHU+j4kHFSji84xPy9kCBsYVYLZOS/KxqNgozOniQ8v72m7UqhjVNqyQkoPe6d46qlG19Q9mqlIMqgYuRk242xWJ180qHtjqyqSgyC4ug+iVeiJthkpO+ry+BVg983uqiDmm0cfo8kgDNQlD1kBcT6sWuLm4g7aIdC7uja5QkfBa9E5Yl2YeAjpZFwNx6HXfWcKWYWqNvZtA+96hDz6xJQVmNhNhay6hY086/zBzAmbvS00hNObRv7hi9kxmjOXAmMUmiVRYNGhf8Q3H/HOw6cb0cvVSdowrAdGE26dT4ZhhBEZtdH2zr421seu6OrzIRpU27lzp5QhfNnh5e4GqKlIfzYX+RnKGig4LRrfyRwTi0gLONllNC1ebYyTIXdk6YQYRPN3rHya8xlJfhhbvjksaAcWHJlXIRCzMZuo+b3LASCleC7VQS7vwcSZFZwa/YBSl6tWQFNC61UHBr5X8n3QNCWnPE9MZSLF4s7UHXOWobQM4DidDgnTTOWgtViLKWGWz0IwONJ1kzczyd/DmzgygDdCTqt2KDAGZYDJhgn/O34tkskzMR9s0yDYNs+k8szgI63qUDpssULMTs9X0WZ4HE43QtBU3q2x15XHemeaM1O6aKJtjd0PvuC7HCFOg+hNwCHWDiGRE1ymK2Pf6fvuRcv3g2jv1G1gxz65GfvYRUwlkpIHfqZAKHrGUsjM84VpnrEA+77LBcHvvegw4RiV3iV5EWwuRmevRkM+ezEV5umF58sz1ga1Z26fVsZ2Y3Mz2WBBQudhjMN5ZLh7yCbWoh2073g4hwf3rx2MbFjzRiMlrEffoekROqf4YYx66EMjiYM9qftwniau1wuLLQSgrusZPJlzITYlVB+sxhih5+MCvb3v6Pf
SYZlbe2XfJZPpnjWGGS1A36CtgXXqTHOHEWjbr0B8/BiNtkd37O4cKZXBnQxGymDd6c9vXmgpQgoHlq6OMyUfq0uhDKPJD0heY1ujVx3WbTS2feOxZlI02sjaW3U7tTrdYZ9mg9GQEWpUF13dHX2YQRyyPIliHJ4uB0Gdmkbo5vRpHXKCCkWV960DmEwpm2dB1a3rsPEm7Ei+PeMyquyBgkWsBOlTgvYWDOSrVfWgtyZnZE2OmmpCa8TYIWZODVqwN7cJOikVRY0Ax+l5KPVdLebQoliVRye7t3oedgBHGm8I0oUFt/XJORNdjK3rfkxHB4nmWDR7N6zvIDzcA3jMrwccRcidBBj0ZuxWGeM4NAK5xPMw6EP2NbW6JDF6xpBPJSlO5DSpEx7x7JwMF1T2rkP+3D+OU690FBTtWf3nGG8Q6RgdPPrjyB8Tcnz4uXW/8McU8p39Uzh2ZXCMKDF6RlVW2u4gkeNMdur+ci1cn75wuXzLvlc5N/g+NaVEa/Jm7MOd4M21dX2nsjNS4LZ+Zskz1qDkWT9OypT5yqgr5lRurGOWfSLUJHGkQU9pZkuK4NFE6NN9SrTggNdhunoc1j6VDt956afWP6VMLMuFXCbaGJKjhMSw79LYASTfqM33vrzRsYd1YlFe1lcffsDz8o7ttrHvlSXNRBMVe2QjhezPvtsP+T10BJj2LpJCCMN1joE5C7kZbdceN896Zrq9kX04yDmC6EYT1DmacRAFzYySMhYja91VmNF9GAGG8Xp7OOyJJj9vTmPT80ByVxNE2mlHwQxIxza0j40hC7L3axpsMBJCfrog1WAq8r/I65e6SNXWlQmU5d2XcmS73em1kU16BqEpkZwCPXbCFDH3pJM5qSAUPGsnh0Lug1CFZQ9zbU4f1F0RBLsn6NYWnX6ufVZ0m6EQBCkG0+EYo+xLYk4E115Zc7quO1BfLjPzZSJlHRynxOOwGfHreRyY3xUEHnuxgZbDrXcsKnIei+4RKCZQG04C8Io90PI45sRhyqqXb1Ci6Ns4tTYdO7RohChhYnSz2ewRFDmoWB3083NX418XvgM/uQEt4SAm6OdVJLlEurlMNCr04AeqPuuSivzF/GGN5kUVE8NoaJJNIfnPnzliL0LQdG1uoxX84ey1UT0QUQNuYZlnSp68W9+9U29yp66VZlA8IiaETLSCjYQh70ex7XQQjWDUJoJLRCQafB/Yh38W8QTpANki9SD2ZgzHnuzY4SR3G1CTBT7U+PfE0mn1c9jgJM+Kio5cBYfVgls19T7IMVHyTMTIYeO6XHh+9xWvZrA3ylSYpomDNSdKtRwk9B4GZo2AcsfW9cbnVFjaM88XCeeX+RleEo/7F/bHDbN6Wg8NJ6ckv+fCKTNJfu2iP1eFgGA7G4NSCiVn+r7poC+ZeZpEPW8Di4MepKskRmIspDxDFwSZQ5F3Z9gI2CkEb6Nxf6z0YSQ0UXaf9rgsXD78gA9f/RFeLu/YnnfuVuk/+Slr74zauIZF0HgqlPlCSIG+bkJKmk+2Q/T4rXdyVtQHcdL5FSR1IQjhMVxL2QU3HvtBsR+Hi40Vm7KtG3MqmD/jFtDEFDNt36hNavMxBvu606ru0WEyiy1+L5pD0Nr/H+QbpaMbgTGC77P8bO2AdZG5UvCGSmGqAUGgv8jrl7pIGabI96kQc2CaCpcSuZnRb7v0HVk0YzODcUQwRLGqUtaeCqdTmpg/R0DX4VIA3ukMwWpbU9T0GBCcmbMfcGOMhKzYguF00qmItk4MbLViq6vY0Xt+fnfh/btnYnFH727Ktqqy02eYbuRxhBLGswsL/r77GXPvJpVNGp6Yy+mqfrCTbBi1mS/4oUyzw1jjjdPgi/cY8ELgTKNUSLFwRFCHgJM8IiUnSplOxmKM2fdlg37ayuCf0yEdcIaiU/StDxrBvQQjIUvsmULA+k609AZRBafgugakdgUZjtGVtIqWtqTieyg7JyFNptq7tC5KfwAYkb16yFyIXJe
rEndTlvnD6LQQHOrrDosYfUSSDQby+Ot9ELu3sU5q6KO9scmCr9ldQCqDUcEfMX3Xfke7E9GkAzlqelRcun6veVc+juEpHrsOj6Xwv3t8zqC3FZL+vnKLZPi57yvwIKcXShKlPofIdbnw7sMHsMH2+sXdPXxSzRm6w+4utvURXtB6yqQoF/3LnFjmhalkpjJTpkX7OzPqdjudJJqHY85ZDjABTiZb8MlZ2rekyftolhzKXxaYysKUJ+Z5cQhehbt1ZZB1J+hkQ9BxzkrI7foMl3kSM+32Wc4wWffBFHQNuhnbvpOXZz58+CHf//5v8v76wt4bPXS++fyZv/57f91NdgfTtPDD7/0Gy9OVYcYt3bjbK2u9s/eKtUH3qSrnwhQm5rx4kq8MfTvSZdXWeeMIQgs+2e8yHgBd69vtQeYjoQ9Wn55iKry8+4opFdq+81i/aLpO2SNatGcnBDfBDcrk8+c1RnctGSqaYige6I9QicO66ozfCdrvlny48ePyib/x65e6SBECl6eF62UiRmUnPeJgnxN9BdxxoJRCs04aIkyo6svSo/eqLsLHcPOiI0qoLkA6TGBD9GmlcziuGdH3Ms668QfomKhKSSxzYZ7lfJ5yOGGEGOB6nXh+vvD0vEgn1Ad72/2AGT6213OsPlXizoKzjhbrZmdGzMkaDFCC/AFHMGkcesCa4LzgJp5nONpRkG34ctshuoR+rlTks4YW4ClrZxLP5blDU7zRUk/rnX6whfgO3q69ynfta2JKhIyz6QIhq+ueGMRDpMjhfzZ8lyOWXrfuk2Sltp2SItlmmDS1mWt0fMiiNWPbJdQ2G4KVwkETjsRcKMXZZl2kgTY6W62sdRdzs3eh/llsNxXwTDyEksMNiq2f5AJjnB6JjHAWqdqbEy7EnjymTRBzTa7gaoBC1P1vduxNnB4vto3/t6ZJ9ZBgIJiUHPV7oZ9xKZIpNLb9zmChTJWUOs3Usecor8YyFWpR9xzyRCmF6M9XHcpJGya4OSVR5SXXkHZunq8slyfFo4Rdz+CyMOoCfafWnYismHpXbAU+LfauZOZSMnsUGUiNjCaLPBVKnZjnmaUszHkhJb3vnDOjd3dBeWvmTsM8f15HN2LMXK8L1+tF+VcjsK4rU6lgimE/XEbCZWF++ooPH77P995/j+fpwj4a6/Z9vv+9H5LLzGafsRj4+vu/xm98/cdJOVF7Y45X5rTwOXymf/5CZ6dvGwakKZJ6ZBqJ6+XKHnY9ty6TmMtMCI5+dBdvi/ftz5xWG8MLzuOx0of28ZerPAxznpmy0JvHtp6SHFdQqOBnRdnE4F8Q3TBHg3nqzZqmqYT0nOY60xySUKSofWAqmTI5BP+roJPKJbFcCy/PFz10vWGIFXQQp1TgpRU5DqpBFIuvCSCRM7bvWwz2pliK3jt0Iw45/qaYzuylEDwlNYht1YNbhTqzqiRNDfNUWObMPGenu+tCF99FXZ9mplkR3tFhmOH7kZwURDf6sbfxsyaIKo/JbDaGoMZ12EnbLlnMt5yyv3dBmju7diYpnwy0I/U1BM6OXmQFgygWZchalmu1ryTkg9ZtLuY74LPejuiJDsfh3GW2SwinBdPpNOCFSo7diWQqwoFwumSEs4t7A8IOmOwocq2L8qu9X2eaI2lAcpp4GE27D4KHInrH3ofwwoQm6qifr5TpFEsaek+1VvZtZ91Wjs6mTDOXspDLQswLy3JhmS8Qg1P/u5MdIIzhOUliqJ07E8+TCuBpDPFt4o9Q8JgFi+diPYRAHd0/v7fo+OGTjPGmHZLEwHcE0WPCx+GmoWdD+8EdCyuP7QZhFgTlO77amqBqsY9I08S0LMSW2dpOaCujBWIpnvl1iLgDkJjmK8siq6PsDNGUmyYk1IBpoimkXE7NkSa2Q0zv8e8p0Ju0QSl6KnCUJOV6vZJjZoozpz1TECxMCycpQ/vRjtF8V6JnelkuXC5PXJaLIL80MZrEzK1VYusiMoRASzNPT+959/S
el+WZS54o1rguzzy/vOPp3Qsfv/kJ0/LEuw/f43n+ipBEsU95IYSJ1hKtJXa7s90OI2Y1yy11bIYSi7thDGd2+hrggIRVOQRN0tja7k+k4PbaO6NuLPHCNC3M5eLNqZGSPD7LXFzr3E+CWdJtSM46VwacLNrkKBYWIWkXO5zMod1aYEqKp8klkkpkvmbmWbKfXwm4L2U5IKQcSZbZh2CuvVa22gRlebEYrdNMWp8QInXA6nTu1gUVtSrT0rqL598bWOveLQ+mFCBlnqaJyZ20ByaKd0zsPilMiMVTcuRyWXh5mpgmLaHjap6/Y6QcuFzUjY4zwtk7uqGDM8ZILjNrX9n2jYFcy6cp0Kow/xSUOpuDRx3EyJyntzwklCbbMFLUVKc46u7Cu06yfN6YgJbTzmyUL2z0KWYwF01UxyREPFUaWtqG7qM/XhCGw5HdiQ/6Hm8eYW9FKufshAC9jey08eOhPCZKUUbcJQQp+PsmfVfdpCsK5nYuUR5mR0YQpo601eZGl3LvSFHfO4WotOdSmItg2j6Mdd+4P1Zut7sOqwCZTAkTU1pIcSaWhcvlWQ8hWpZLiaB7s3RPovVJ+ewmQ0RBeiYH6ZCI04U8qVBaqyTr6uCdWMFB/bXmjcZbhL32N+XUULVxOHgEDhsmEL3a/LkZIdBLovU7Zp+xnsnpQmiNdV/5cr/xuq26VqUwzwvzvDh54oLRziV9CnD4q6cg89hcFqb5yjRdKFkQm4XIbf2s/bKz5HKeyMWJG/mQU6iQCx2R9dVx31S3HuumGIynpydmnxIw7WXWutPRvk/PlMS6te2+zcvENDHNC3OZuE4Xni7PYgFa4t3LVwwC+7YyHpvs1gKUyxMvTx+4Lk9MZWLKmV4bc1lYlivX6wt1GHm6Ms/vCGnmcrmQp0KncbnfyNMTT8t7bp8/0Xdj/fgt61rJI1OKsa6VmLT3C8MoqUiWYd8hgjjpR4Vce8bh7Mq1bYQaeJqeWJ6vXJ6vsiwj0VvVNcqJ5+cLy2vmfo/EFggZYoEyBy5Lhpi0vzIIOVKKyDPGIJaMUh8MS1k6NlNA6ZQj0yI7uOmSuMyZUiZJen6B1y91kVo8Or615smncmTuXaKysytFE8EQrYreA233BiAn9t40Pe1VAsDaNL7WDm3I9TyL7p5T5uVp4bJMpJKpwcdtn8YseDxylgNCLr7fCoCzkRImPWjAM4uylurfmTiOKUNK7sDeB5vTX83jm7PvQ3KIlJiwKOhtKpPj58deTf5cwSIlFVIIdDqn0wqCTg6LHz34uzzqGPSxEVJ36EZsq2uIlOIr1XBAfHDEyAeSZ/m4a7nDKwcl4Pj5QggeFaJpKEVjxKHICi9irbXz57DxHXaa+UQV7W0BXTtj7xIjhw5ZjD5zptxeK91g74PHphiSPowYMnNRU6AClSlZsBYhM2zosFs3Ho8Hoct3MeVEOqQAJufuUmYZqgZjmjIxqXuPoUIu1CBYK9gAa04SEY+4TBPzRQf55fJELgXrnX19ELqYjwdicKRkvemq7DywjilZAuCONf38YkFK+jBGF5OzN01i1rBRMSol77S6ahe7VW6PL3y5f2FtlevlwtO797x/+YqcM9v+oPcdQ+QJubMrknwM5T9ZSO5qPlPmmSlpnzSNRkiRanIcKcV3kEdmWFbidh8QQiLnhXmWoHzbNvZ9kzYrJHptzlKUVm/KheA/794b0VTcUohuAODPZDCHoOW0X5Im6Slpl92TQ9whk+OExUNvCZYLFy9sJSemKbHZ4dqfPDTUQwNDoTw9cX154Xq9EiI81Z2n91+zfv7Ix5/9lG3d+fR6Y987exxOe9cU233H2Vul1u6eiIdsRQ1b73KLwSUuKUalGWcoS2FaRMDYa6WUSI9Dzu858vLuia+2d7Rup1h9umSWp4nrMoH7kfZhGMeKQPu7aS50IrFBnBO1Naah+6/kwOV5JpVAmiLLpM/pV8JxorgVS2uCq7pHUJ/dIrJmsWBI79K1l/F
/6qiMWJ340GlbVWHqnq2zN0atpDGYY2Qqiadl5vlpVrhajGwG+4BOpscmvUkx70L8JrVGcObLVgcHGNO6sdcmcRuirO9HknAbp5njXjUd9uGWJFHLS8FSEqMC2kP53iUhpqKZnblYKWRCHLRh7C5UJmQx4br/Pr4POg9AY2sdi4OUZ6bZdS3hsA+SYHnYW86R9D5GMjt3e7ROt/EdOEGMwBTzuQfKMWj3EI54FTgccs/dCkECyoNuPoaWus6Y7FujrzLZrFtny3LsjiYvxK11qnV2c4KGs9kWbyZK0UGaUyEQRIjpnb27Xcy2k5pBE7OKY1KN8SQ8gMuQAw6laDdFb/ItzNk9JgbWwPoOXU7pT9d3PD+9p5QLy/LEdLmw1w3jM9v9lTbujF41uY78c3CnnTCWdmtytQguQs2kIkbjMB14vXX2VbB2FA9D0FuLTHGmhCwEYpPJqTnr73l54mV5YS6LpvUUaXQebYNcmJ1xXwxi71jMBApTnJlzVsBlVMORQwaTw/3eKiknZp+QQnAizlTkAs8VGIz6QDKBSro+EzD2x4McFKHCOFIPxCqjZEqbNHm7RVEsknqkNFHyQkpy5t9qJZeh3THDzWM1hdbHBuZMWETasWlimi9M00ya5Qk65QvYZzpyYAktcDdYpmee3r3n5d07rsusM2zbuaSZPV8Ilvh0u/Ozz1/40e/9norMJmcVcAlMD6y79n4lRXKEVCadYbUSMMJ4k8KkKbBcE5fnzHxNjNh47DdyavQg+6iBbKEuT4V3+8xuM3HRznu5LDw9X3iaJyyYErr3LmmNuXvKkkgFckzMJJFwaobqxtolkguUKVFmmXIrrfxXAO47l8sooHBdN3Y3e8S7ChuKM649yCMs6s+P1mk+0SjDSVHzsWl/MUwWMKPqQsYkAsT1OjEvEzHqgNtq5bY1tgFkU+x4jqfrQiqCZfYhaKkNBc0p2iIxLPDYNvfDU+dpHhIme6M35+0+7KTULnNhmhdKStS9MtxjREK/N5bNyZ4jyEkhoEklJ9dz2Unj3k/bIYl7rZrrHtxCKsZzF+Ugg1hOftBH1yQFvzZevdyJ/ig6vLHWnJkXhjlDSJc1usO5ORvL3LvvdCDv3RNtA2F06TBao62Vbd2pVXDjvnfWpP1arJURoALDf17zyaPkiRLF3BILyd1Jhp14vj6bSkA4e2uN0bQz0LXR5C7mVdXUSTqvyfm5ZkGyLbU3v7kwsGBM04WX5/dclmdinHi6vGNaLqTwoM1V5JBe2doXbBglmlOSXS0zXAoRD9bXQaw4thaej9bVnbc6WNfOvm6EAMvkqb85EilEkrsROISXZPqTY+YyX5mXK2UqzFywFLg97sR1ZS5yeRjdYN8ZKRLRZ6vu3k154+Fav4jAM02QCyEXLtdn5uXCcr1AiKSi58WGitRrzrSY2bedfn9AkzNKJjHCm7UZfZzOMWKbakrKy0RJhefrEyUv2AjsrbHVxr7tPNaNEBTe2BnO3pc1VK07JWWmlLGYuF6eWKZFCApO2y4z3QJ7FSV8eXrH87v3XJYLl+WiImVGssAeImUEPrz7wFcfvubrr77hpz/+EY/1IYIKg5R0H+1Nz8BBZjqe1RgTISdiiVgQ9B1cWlLmzLQk5otIC2aDZg/3++w0HqRpEKxzfZf5erpwWYUwpZJ4WjLXWWSwYbITa1XGtyEEZ+lJ0qCgRbnuJHsT2mNd7zNESvQUil8F4sS2Vtb76tg0PO6rblq3zzd52Mh13MP6xvCuxCmk1YY7nTdGHcTuugwb0jI1mdWGeWIuE8syM82FYO6nNzwB1KBMhRLfYtRhYEFd1xEAuO2VUYdMakOg1hXgLAChJHrUgr4do343pxl3YpiYS+G6zMyXixhpBKpxOmw0UdecbupBgV3GpRYED4UEdas06xQSXtl0ILhTsXkmjdykNSGVM4bb7ZjEKfdlvdh5xw7BPLpDnnv+tTxmQBQMF/r5zSpM2x0DnASAQ1lndk/vOvx
MjhkMXadWO/u6yyXjwMNjIgWngxcxl0YKCjlED0kcKhoHlNi7nf6AYqm4X9+Qy0AKgSbjNqcwHx6JPjX2CttDcCuZ0IzQ8biW8nPQ5nBoNiWDGsScDJHecUiqkH3STKEQo1iaWxOr0LIXJKdiv2nnjgIV3NHgYBdqx9i6pvO6Nu6vG9tjExx9yUTr5KRIGusGQ2hEQNTlFAYpT8zLwuVyZZ4X8LDPy/KFddtl1GoAg5JnegzfEWW7hVOEgwpUppmUJ2qtlOVKmmfSPLM8vXC5XsBgvixMRRHx6zrTQ2BLhf1nP6PWT0p8DkGywGGMYJSsyb6HQEgK1pxS4TIJqs/TpIYgJLatMrqaTBu+ly5ife7uTjLQJC4z3MDWJToveSINsL3SUmDESMgFLJDzDBZZLk/63FJhzoWleA6bMzb3AfM2c5kX3j+/4/nyxKdvv6XGSq3ShenwP9LC5epBCIymYpJyIKTEINMZtB4oc5bxcgzImXz4tB/pY2Pdb4xRmZZExHjKE2Uxli2y+w55nhKXuQi2Jp4IUGvC2/uIDmdrLy1phyMNvCULRxe8n76DvwrECSxS9y7oK4jtM6o6RGXVCI7p7hRRhzEQEYKY6MGNUYdG5dGGaOpBHaV1hRy6Q5l2S06nHb5v6Q512UB6DZ/IdH20wB9Dlkl9DPZdWp/RISgoSsWtFN9daX/RnHCgpXD0g0YL8VIKyzRTUtGUkzIjdc+BkcFkr4Pam9vqqFjMUV2ymf5cGxL+lpQJIYmMMI7k1+FSSiiOzeecWUpRofIEXusqEiI6uP7LYb9jn3Ti6SESOh477UUsDPrYXd1uyqOx4yHUZ9uRRdSxJk6+XxLUO9jXyr7ubPeVx33VYRMVNKdzysXFZ4EUuUTXuLlPnzrE3gYjSd+TgnyfA8FNYaEFHf5id7q340GjH3J/qHWjtZ1YE3vNHniZSVF2npIfGCPs3/m59Owq16yS04RcONRwaafn06RDkOlIscWd1Q/OixljaGknbVzgcD7vvbPVwfrYWW8br5/ubI+HYONWCGzkqdJr9ynRpQ0nZThQ5oU8zSzLlXmenRqeSXlmXp6w9eHMs6NcAgN63RWs57lhtVf2qoiaEBLXp/d8/b1fk/i2TMzXC/NyUXxIKcxlIhDZlycu/Zl4eeLx2Ln97KfsrTmV39jazrxIf2UB9w3MLNOFyzSzlAlKdOf2dAYBBtQsDmeR5qxdZGxVO9o+WK5Xv//1TNwIbzloQcxDXNiOBa7zxf3+8HtA++Oj0QN3TncdVzBj8j3Xl3hoNr/DlMRDH+nOSzQKnSUVUknkCEssWITaBvOlkKdMKomQDAvVmZKSVfRR6daIA0qJzHlWHE+KzL5eKFOi+Oc1zEh9kOfkMh7ds1oRyLXGyciEHJ1izfEUcXhGjiHLrl/k9UtdpGrtdB85gwfGHWF61U0fx7F0DKItyGGhUUeVcG+IhFqbdlHJjj2D+1IN5U0d/783N3nsnXWvrJvMaMXfTtA7obmgNwT2+859W6lt12jsSa29Dd8DuE4mHgm+uhlzSrTQ2Lr2aDacht41Zh+JMceiPIXg7gbKuhnu03cu0ENkHxsRWaP03sDp0Qz9fL0N+i4YK6fsBrxQ3Jplviws86z8Ju/ajylHO5Y3vRP4jm00tlbpbegwj3JaliN3J6ZMt+FR0ucmzK+w9BY61LprP5QRZCYblt6Mbas6dNfGvrt9VIbWEn1kuRdY8MV7pqQsB/UQsVHlH4cEsprSRLSPIeh9mSbrQ+gt/0Z1jMdiXuaxQ/T5WiX2NSOnQnGqdggFs3RGy6SUGEOWW210el3J6wObAzEnHvuN1DPr/mCvd2rdJFa2w9vu2OmFc0rVQa37c29vrD9zgXetjb0OHreN+5cHr5/vrPcHl2km2EbKO+Wy8dhXRsjU3QhDLvXbtpEned3laWZeZnLKWB++AxMzb8QuT83Y2e9y275McoTY1pVtWsk5sTelN6+1slxf+LXvfY+Xl3f
kGFmWC2WeRGJJRfqoVFiWmYCxjmfKvLDvlc8ff8rj9TOYsbeNWjdSjGwhSR9oMM36u/O06D0nh9Kbdm10seKwTo6F7DY+Ef38o3eulyfevXuPWlgxKbfbZ3727c/43ldf8/VXL5RcwNcEJUU+vLxwKRN13RjtO1KSQzYx1Eze68belSqwrqv7e7rWMjtDVF2qGhf0nI5okDMFTdSpRKUwd93DhEGzTrei5zkOBruyo4bOmeAMJN0/wVm7YhDGrHieEM1hfqDoWT48TNuU3KrLHSWSIE8liPtE32E0MYarp0i3/uaa83/1+qUuUs3ZVqBC9bivrKu6NUUkA0j3MtzUtZksc46AvjaktTisitLQyFxSEuV6DJKpUNQqsWKIynWpu9g3+OFVnB2kaWjASB45727CY0B3ZpobR4J0RS06fRc7GXYpiCEU4tC0dPzcfVC7UbriCg6Y78iJOvzwVBTdmZvAiMHD1zKjNQ6n9daa9FVdzg2jizZsh8zFjjA8TSAHRf4oKREnMhhwiIJjxDxeo7aqCA83OE1VsfWx7nL9wJ3O0fL1+B6HNuZ4BXObpoAo9DYgaVKQC36QS3k33dlOBAhRouXsjhjTVIjOGKvuXnHsbg7n7u4RHt0/4xwjPTmVv4sEcvyZwz2e3rX7sq7oBTPiFAg2Afr6Ml9Vxk4IidpEEQ9RjMpHuNNGZ+87I6jI1bqzbXfW9U7bN6W8umPDVAq4Big7bb53ZS7t+0OHyJkNxknsOLykhsMuQ/Q59tZ4vd9Zbl8wlM5Lk7XOcJZrcLsrDug4hFPom7YE7lQRjryhvasxrJXb7dX3momt7tz3lWHwve99n+997wfMeXLx+0ye3oTiGExTZsqRyCC0xMv0xNdf/YC/9vzCfTTa6xe6NdkitV1JyzkSixeoyySSju9NddjjPnXdGbqdUgbbVsTqTQmru3w5L4XLclXTEgO17jwvC58/f+Sb22d+UD/IQLZDqztzLkwpyU1jfdCqEoP76PKoRPvo1/uN2+POtu/U1vj05TOPdXXrNjFGQ/TPOXvwai6a+NyjM4SBe22RYmQaRZNTVMMycFRiNBciB3KKHFl7IWn3e8heihc+CzBiFyEsiqV8NMVEWHKmWqd1rScOeU0MuMWa+/xtb7vTrSo1bfwqwH3rvgqDbjostnVjv8teJB55PgFww8vam3Q7JrFvinoAdIgrdEHaFbDQZXnf9UUUiKmdyL4HHq3yaIKlJo+GV0JsEWyzNUJoblWye5SHd+UhoTC8QI+wVSN0P/yjLHpklKtYbMuBghhwMtPnECH9nEHl24LeWWbR5ZwmjVQmMHpFZN+gGOcYsdFkk9QFn/XasWaE2p3KH8+iYEGHbI7Ra1KALFgPM1ntBLeZGkdXFs9JMpYoDQUQSvYD1Px9qOhbOhy+38ogZGJSd5h9IrbeaQFKSbrGdfcmIJFc73SIuKeDtVeKHB2CCmWIvpsiQG+EoALe0eFSgkyHezvcCKL2nATylPRwhy7wZUToleZFIqbE5MatKUi1H+NbHphskJpo4dZhdNb7Z/ZNDu91f2WeJ5+KqjdKGzkWShYlGI+NkQSj+HTdCdNEDBO9w9h3aqukNBNIlDiYYqUVaYkqqyytsqJiSlaDsW93lBjc2bY7yUTuSGUmLzPTLFHuWht1ixCzomq25tCVd9E2aNuDmgq1ZG43Uctv28rWKsvlia/ev+f99YmSMvM8U2bpzIbrgQ4D4mMvmHJkjpF3+cLy/IGP20798omny8ySE4EOabisQobSwYxo0jxKtrCzNu0S123VvWeDqz0z50InYD3pTOiN1eDDsvCD6weWlNnHzv/87e/xez/+EY+2KcLE3ddjibx7eeHbL1+UFN52vvn293n9zd/g0mZS1DOz1sp93bjfbtweNz5++Uhz540UE0/LIvFyiPS2ikVXAlZgKoE8RfKcCZMRl0QsmcggzRD3Sq1BuVd90HfF5Bxi3BGGr0qOplBRaUIFIjmrmA0
bBIePY/Z99KH/GAYpCx5N0ivW0BjJ2JsgQPM/GoKSDqxL81a9mf4bvX6pi1Sv+kD2XVYnrVaGKg1aHB+WPFr8mivuCfHNNy5EonWSL/5yikwExQUg2OVaxOsPPvXsfWcfsk7KSdMFSeFhNtxWaQywoajrfdch7pY+gYGlAEkPUA9NMGAU/htceBotMOUierIcMTX6I4uj7gtuBcbZWZRTEjYe82GhxAlZGbhrhrYFKUaq+/6NNs70T9rwvc/h0q1MqBCOVNx0WvGcosIQzqyc4Pj1yeJD8F5vHWvRl9yK5HCJj+/29LXS8XfxAS369w1BrKGsw0vWPl64j3EIY2+NUqVTCzGSD9qrKEg6PP09O3ELfQk32XUYOSVZFEl/op2Mpr+GbxSw0bCRGRFG0+c6zbOTC55coKydURKORG31NNfVezg+50HdNyo73XZqLRhQj9ToYZQo/8SUJnKeSPmw/inK5UiZ4MVv3ze2cJe4mUgMhakEWFC+z1OD1rWTQvDPUiaCGXV7kOMEriEqMXG5XHn3/j3v37/j5fpMjEZsO/k+0Xrn9rhBrUS15mIkVk1ya1yZpwspbYy4c9sekBJPlyvXWcSIKcsgVgxLTeIDZ1jaYFkmUpBZ7ERiiYVLmU/dV4zymytFB2c4hK2uHzKpOmAY2/rg9nhQW2XdNzdazpQ8acoyff/HuvLl9gWWC8vTlXm+MEfppt6PD3xYH7x7fibHSKuVECCXTGuV3iW4t94YvfL580eWUhgX7RcfdePL45VvP33L66dvWV+/kENiWmZaS+TLRSSMMYh5IuFkiCUxLYE8SygdcydPPunE5OQUOUngO8veBp3gDEAdATLuxYkVGrCPpGkdVyI+yCMzyOAgmcJjUyQkc4MBaSZ7az4hjROmDKgZ3xns4M2Ae/z9Aq9f6iK1bcpYGq1psc2xa9EBZ9/ZbYzueyVfipq9OV6P3sgYJQbmmCgBfGNOipFrScxFo3FtFfW/Il2E0L0zf2OByazW25IW6Nv4OXV1Sdo9pSBrJHEu5EAc48w8ZQ4WcSBSCu6LJ5eNswhF2Y+0dvjPuftwhOSptRbMqd2ejMkb7BQDbuMUYLigdBg00eB/bhj3kzy4C7c5w+jIltEU5Ya8x8Hv36fV4XR4Nf/43xmtO1XYgwhd/HiEAwZVJw5TS5NDj4gCJJpT7ONJbIg0U/RG67IMMiCkIGaS22TZdyC6w+lComMRYKzBsHjuDlJSvH1rjX1X0zFaUxdqilOxuJHIolKnwuWycH1+YZou4F1nSIGU/Vp2dfmB9HN7PFkXyYC2t8GuakmrvvtLnqMVCilkZHTo06tJ7pCy0nv7ALMIlomhUPKFkmesdko2bAm0vVNXiWJzPQoxJ0M2LFmeeEW2PJfLhXcvYp9NJRPoNMtMUyE4XJ2cno1BdR/KNyjV3UecoTl7YUpyuj1p6SfpA7AQ2Hqjto00FT5MixKra9UEtK9YUxHorTJsIpcLU5nO67ztq4SvZVZzEcJJwe6tMlqVqXFvKi7W6cHzxoJWAmp+KlzkKVn3lZQKP/z6+7y/XDVhx0wYnW3f+PLlCyEEpiIXin1dub++cptnTf5j8GVd+fbTRz5/+cT65ZWwN6YQuSwL1QbL8xMvTy+M1rjdv7CvlTxPXC+J5ZrIcxE1P1SH/dTkdm8gBc8lJwfhCzFv4MORvxa+83zb8bBzJHybCakJuiyCh9EEJvj4rUk9drRm3ckhLjcwRChjMNJxz/9i5/wvdZF6PDZNUa3qJreBDS3WB1JDH/QiMVnUuUvPA5gw5xICORpLjExAdNgKtAi/ZONaRMUcMfNo4xiKdDibnWFrDTmzWzesDkYd0JD40y/mwCBr93Tguxa9s0cY9FwOOvJB4Va8ecmJeZ7dFkawoVh1YuSFKBjzyHgK/vOGlMHE1xsDD98DxZhEWT+NiCXZ45xBhwGIChws0+TFxEkSOPSX3U3eYUZDAuvuU1l17do0TUxFjhdHyB3OiOu
mrtecjeep6voZAOzwfH4ji3yXUaeXgYlpGGNQvlXSfdCtE60R7DvO9o6vHwavmjI9ecpD22KQ0em2bvrftslEtVZKSZ5ZJHHjCEZJck2Y55lpmpmmSQXHPSW7VcZo/hBHLER69/dzTGw2RKAxUwFEZrghJZYyc11emMtFTg68aZ84iBQp68Dqw+NJshNvkuLShwfoDVNGVtSOsQ+RWmpthOweccGzrdDnNM8zT5cLU8oUv9Y5DnJWB5FjYlomppjp7hTCEFmJJ2U4zfOCBZhYWJ6u5OgC0KbJhRQ5T9Qg66L7487rvrL3zv70Qu6D9b7yzfqFT5+/Yb/fdAh2ReyUrClz+A6x9cq63x2WyowqaK/kSO+RHUHJwZm7674TY2FZZPtECOz7zlZ3SGK/WY70DZ6WC8lg3zfKogN43Xdujxu9V1KEXjfW+yvr/ZVPOWq/OIzHvnO/vVJSZJRED5HWKiFnfvD19/hbfvOP8nR5ovfO7//4d/n8e5/cKSczzZmyaM0gVGGoSUlJ0gqXVYAL7y058cdNDkyyFCWIvxUxc2F+8DNUjZPWJ7521tnaFeZIdA9IL4BB4+t5T59NitpxDiLPsF8BW6ROA3MLfpM5rLl2Buw07R1RWKohJk22QEqFSc+rvMBiYMmJScsdGbMGYwrwcp15uS5KdCUyHpXHKosaC0EYq5MJDtp6GEbbNh3UfvAdGT9hGJDeOpgciSmTSiHmTEqFNE16mPqg7B2zCRgSCpeZHBwHzkV00ubEiTEEeR6MNS8kwaAhx+4Y5QSfom5aYpJrR4UWk2jWwgcVe3ApXK6zRMyug5II+OAX4t2TiChwHLgqniG8Fd4pF7JrZCISt9KjzEv3qqV28AiGrLRczATLBsghuEiznfCh9n1yCckOq80lcCmBKTs1OwhKi164CdqZnWatI5yd+3CPQoDqRVNNQGPfKvvelPgaxc4Kx/5lDPLQe09usJqStGMxZkavtL5jQ5P/MGhNju0y9ZU3XTh2j/6w9+axLuXCZXnH89MHprxQe2OzSh8727bTUmeZLlhQgVXUSpYzOxUl1CZyzGCD4RKJ2iupJN+TuXTBBiFFCaB3UbDDNJ9NAv555BzcAkgxN5d5Io1EDirg2SHaZlDyxMvzOy5XHbp19JOAMfzXY1+hVyXoBsBlG9qpVD5VTRy0ytYbn26f+PZHv0t9/XLmbS3zlTIthHGEQhq1b17IoyLj3b8wxUSOgjjJUCZ3Tw+ipwcyU5J+yTAul4vuC3cIma9XxrZTYiTm5OhK8OfdWNcHve1gje1x437/QojGuk26KUNgydq52jzxmgObdaZ54jd++Ot8/fyelDQRfvXha37vJ/8r63qjWeaIfQwRQaAGB0CfYpLBcQ9uDiD23pF4gDWM4WY1dpIpCNohYrh85VgPyCkmjeTu9lG2Y4ivcT435o3kIYg3t9xykEaSk85hrPyLvH6pixTRMaGDAOR7DYYW4tEQ+yUkGYGmSG/hzW1ZrAothIGSIsmn3ZITU4SnKfPVy5XrVESHNmOKRnBx5GgStoWoG3SMQY6RHIPW+2JcyMECOVdYtMO4DwsSZKaSmeaJy+XK5emiztRgjCbyhAUU1ndQJ44x27vvw57fqeExyGfvWD6bKXZdA6Sd6v+SpZGiDzZfUB/9TgwwL5nLZXK/tEjIB4wafo6GPrqdnfBBex+9M5royVOWODUmiZ2xcZiI6+Hy/Zc565Ko+JPjwnafjCQe7pjnDY1x2AGpiyfLc+/5mlgWwbRlSsQc6R5jcUAaR6SBmc+4/u0ENUnMqoKqSau3t7RRhmG1k2qDSdZA2eAwMD72IQRNKIrv6jqwcG/C8fOszPPlMGkoOgyax4iUvFDKQk6zW0mJ6l77Tt02YlRQY3Tbo9HEGF2WCyBWmMIVj3tJRXJd76LDh0QcndwqsVVSyux1h6G9b06F2uQCP5au2BtzlKI26J25FKJlkY5idzalcoyyh5MeLjGHCH+vVSGLpUi
ugJGmzDwXSkgng/GxPmTuOwa9V25t40c/+et8+v3fI6xKBjaP2tCH7wJY1xvqzo602pmOjDIzSkrkZQGglCulLKd34lImesw0a568rUkp5czeG7W7PVFKlHkix8zj9c7j8XBSVneUx7h9+cL9dlNgZCnnc1KC3HBiSYQpMaKihHKMxB6Yy8Tem8x8c3YHlMreRc6Llo9WUZNL851RFGGn4swqsk9KxgEE4X3yETwawrEF5px0hh3WW8MbvHF+v2MlkWMSQaa2Aw7xaa05wqXkiSMv7Ofu97/B65e6SJU5uTJeB56FqHiJA1jLScvsnKTO74mpLc6Qi5QI70rkkiLRdUXVstvciPd/iYXZ2WykyL5uMLp+3Sq9BvbaMXaGw1UVRSvMOZGP/RHmEfOFVATvtSioJabEVCbKPDNfF6ZllreZm6YO9tOctQ+w3ikxEpqgotHV7fbWXB2+MOdCSRMWlYHUkV4kWXgzb/XdzujDIwB8V3QcmkkZUst8YcmTJoSU9PWGIdV6w3pVjIC7XQyfHps/TLbL1mWyImX+AEJWcXVXD89KAeuyzUlOzR+irocY9fBGL6SesVXbinkhbwWmEPhqKXx4uVDmQrpM5EXL/zoMs50RlCjbTF87RmVSHewuOPgXgdaht531/mBbV9re6PuOtUYoMFuCZoSp0NNMulwpT88wzXI1CWjv4ZZG0doZNHg0GWNUUdbHm52XepDo04A3WjFR8kQIg9F3cjTi3tgfK+u2Ean0KkdvsSrldjKCUZZZEoCYfIJSuvDhpNFqp0yDlHS4RT+G1u2GWaG2TmqNfVv59su3XOeZPL2wt8FeB/fXz8TRWaYJeoQOaZrpo5ImJ3KUAmh3jMO4B6Gk9Z3bHqijsbfK96evWUJmyplogTl31uXK7eNHPn/+ltY3bvdPfPzr/19unz7Rtp1OYCcS4syS3mE5Qd/ZhkyGJQcR0SWQZMeEvifJ6HSq7czlmaenJ54uYvm1feNLfZByZikTc4wsMTF64nm+sKTIh5d3XMtMiJHXaHz88onX2xcuJD6WJKhwdEZQA53ypOvR1UCWNLEUWXz1rsTeM4OsV1rb3VhWMpi1dtI+GLG5p2c6mXjpsBU7iEqHYjy7f+Y4RMfaQ8veSGsOBdS9FTw9D8f/BAWayxdSlBNNyK7dsoElHAL3BG2Te03HCKHpZE44zP+LnfO/1EUqliCILEZ6lVgshsA4AroihBKZp0TCBNPsgzC8iKTAy1xYkv6OzBNl8hrNmEsmRIlfRwi+s3A7pCbd1O6eeiEldfXBKZ1JDKOc3K0iBh0uoxFjOfUqKeXTcUIMK48F9+kgxihLp/42DSWPug5ogul+E06TAhSXi/QzqSTXHwlKOvdAHMF/dgo/u4emHZMJ3m0VNwPNTnXuaFGlaIfdD1YxFhtdn30MPgUOFxQH8iESDILw3rDqt6Wr9GF6Tsaw02GhG4TkIXpJOik8RuSw7xHUE5mXmZfnZ96/XJkvE2WZCZeFSmDdG5u/39ZFNknx7bBMvmPBP5u3nZkvkDu+Y5Mjw6VMLPMTl8tCvj6Rn595fn7PNE2eUCzG0zhiNOxgYY5jIYeZKP+9ihUV0M4k+dKZc//mzttDLv/BrX8EQ1a29c4R313yhRBkVxRTIliS36HDp7VVZ551F6u/EUTKNLPMF6Z5IUTF2NzvlcftTu+Dy+ePfPz2G65lIQz5sa37zuN+117POqGbJinfKTUzyJnJp6iUZXbbdNLp184yM5M0I7q1WPFY9Azn9JPMRJr4+JH142fGY6dXY2MwzU+8f/d93j99YITE1jf2uvN4rN8hDMhtZW8bKWYsGM2a9pa8kX/yPBFykV5pnwWB5UyKmf2xYwye3j3xtExe7MH64DovXC9XETa2jRGBlLg8PavwFTmxhxg5vMqCAc147Hde768MIrftwT5XSpjpAT6/fuF2fxVRqnZ6jfQU2a0TgymrLujwH11hld2DEg0jN12vlCLDIsOa35s
uWXEyVshJX8efzejZZofT/hgH6UcN3iHkPRIeLERS0oR1GBm32p3NJ3mNiWXzC71+qYvUG0VYGhv5qAUsDN0YQWsNuRMbIQdyHtCMPAZTCMwJX1pmWQ8NBeeNYVxLpsdEs6Cl7Bh8WSu3dWdz1wkdOuazmeCr4A+nOWyklFdh783kojBM01kpxaM1/EEdEtRa0LKxtc6jNurWEK0zki2jqAbOhWTKiZIjZZ4EUywzMSX2umNNEFnfNg6SBSMQrNN4W9i3fafXSjATVp/lil5rJe873ZmCtQ3WvdLjsdMRRFCHBMJHFlP0nwcXRR4Leu2dtB8JRyowQRHZw8XJjkWEoe85RiOM6Jqq4Jo107W3twC4nCJlmnhaFi7XC3mZGDkRHM4cTrRojonrYBTR5GB9ETxp2Qvp4Rc4PHOs7Y3WB9dUuF6fuFyvlOsT8/N7Upmla9qaM/E0kRzgrPz9+hmPYSZ6ed2b3n8ub/c1R2yKF5R4I+dFvodRkF3tu3Q1vYJpr5RyOokuOUfiHtg2wTlHonPvOqCikz1yTLIimhXaOC0Xwcexk2KX1VNvfPn8kafnd3x4/sDilj3bvtGb9HnbtpP6W0Do3itr3cjLfN4TMUVIiZSMVORlNzvBJG0ra9vpw416Tfsqou6LFOVivj0a++NOfazQBn0E4jzz1Vc/5OX5a5Z8EWyYEq958aj6AgF2LaLdYUbTc7MmP79yIZbCANI0cbk+8YSyEl/3jb1W7u3BouOYdV15vkxgEt93k1P8h3fvuV4vjNAJObA8XXh5/56Xq7vHpyIJhsmUoD5Wtv1GQIayNozb6yv7u+9RaDz2Oz/6ye9zv92JUjEg/UYijCPiPZx7+GPf1Gp3D9DhSbu6N1KK9KGk4RAiKR9QH4on8mc6+E4WM48bcvLEubHQ2gAOgoXvS4DRHXa25nuo4+87nP6rQJw4kigtA33o8K/6gC0GwW9RQ5UOiyg2CoNsiUygm5a6xIilgkXl33SM+9749FjpXhwGcKuNe2tU1wzlBMOCHH4Rcy/44dLNtFBN6tDT0a3HQCwSAAPuzZaILVHj4VNn1L2zbZVt19if3Hl9RtNJp9OrO0TkyDwrQDFPE9NyUZ5UTlAjoUnNHvC9jAtKezdq3djXjbbtjF3GvAEZV/baZB8UcPFdOJNf/SK4NiuR88WX0cI9U8S/hqj/Mo0Fhe7JOuXcX3k8h3nxORzcQUthQ7ZEAXWrDGPUwbrurFtVAGSEEBWoNj9duVwvwvC9gMYgaLV1d5Y/F7eyqAqY7if0/43gHXelukO2ouNlWWMhkKdJtkdFpqshZbatO/YuVtQYKkpjCFLunl12uEAcjLrskfOCwLRTPKbNWnfG4xPEQbCNaZ6RXGjXXqNMgJJp52XWbvOiPUvOOpS27ebMQnM23S7ftuTfL2cXPU9Ouz+0g4kSA31U6nZnfdzZ9pXers5AdPg7BO73O9lkltxH5fX2mXVd+fD8pMYnZSeBGWkqgtXcTy/GwLJNfF5vmkJaFexoqEHojWadR1/5fP/M7XHXbqY27uvOcnnmq69+jZIvxFBIUyKMyrI8QYyUOTMCPLYNG0PC4anIwmpdoRTs8kR+fsfTu6/58NVXXOKE7Z31+o7VvuXT62dC3Ol5opmx55mt7jw9PWsajUI9plx4fn6mLIWQA9frlWVeBMXPF1LJrG3jsW3cH3fa640+dqEWCJKnVb799DO+fPnE4/Hgy+dPRN8tBnCbISAevpvSiRFheKFrtbFtVXBtCkzTJBQl674MBGKCOLQ2EGKgpiClcE7znHtoToQnOgX9SNoeQQLhYwJrJlhxjEiPaHjw5/vw8PtFXr/URaptIkkYh82HoAECboUjaEp5OkGdx6liC3QCtx4lHEVTzjZMk5PBfa+Ux4OQApcwYQTW5usTs5OiHKPrKYCD924OGcSUZRPThMfGICjjCCvTYlfWRIaEtdolGNve2fbGXiVGnrJbPFlgIAHmcDZhcpeHGAM
hR7dLkeFkCYNqXb5iyBH76O9rq+z7LiF01Z5i+Hs4Iita78QhhuQRGV1MOy5DKaw5Ke205IIR6G46uW+7gvTGUDF3Huu5PA3HZzDO73UczObb3ZgE+5TzM5OepzWFVW7NvRBzYAhblfBxmtQRdpOIMyTfXw73RUN+ZgFX3wen5ovVNoL8B9daeWwrj23jsW5seyW5CWdMUbENxRfxQQ9mchik1p26K4coEBge5dGqdoh2XPsmKLdEHeLRffjsO7AgdDH5WiX6BD7NV0KciHEBIvN8oZSZaSq6Fk4Nlgfig+5Q37bv1H1n7zs9dN0z0U6fNqdpEWNk8STq19cbr58/8e6rhzfX2mdNKTMmBZCOAOt2Z981/d6+fGS0nez+hmq+diKFOV5YppmnRZT241ka0Xi9P9haJYVASYlt27mtDx5t4/P9lS+vn6h1UwHvg9bkaZnSpD1eiXIKGYH58sxyuXJ9vpLKxN4VS6/7qLKuD8L0pD/z9I6X5xfev/vAZbkwWWQQmVshPCJ736kxU0OiB/j29oWcAl89PZ3PdO+duu3EGE4S1HWaKIQTMrMAI0RGTDBN2NK5vT5YWyPPhRKy3EvWB3uVAfbiDNnm1k85JqKrk80b12ByUDnYnUdsRm92rhgO5EKySJdwdMXbpKQG4tR6xu9IeWxg8TuFZSB5jQuogkVfa+i3Dc5iJJ8/n+qq0asIQ7/I65e6SPXd5OOVRHnNU6SH3V0lwvkh9TqoeyfFIt1IlAt37bA12aTkCL0ajzYU8te6DqwDjyXQuqYjc/w8mVGSyA/CbNEM7Dg6SYtqAi50Pf4dvzOISIfzqA2jOuMmMEKkHh1/bziox1Qmff2g4hbcrTlG3thCk/ZRMSVSMKJFwu408eAaCItHrZbOadOh2VzTdFDKQxTuPM0T0+UiZqIZ2XDyhL15Fno2UDdjI7CFzTFwxaQcrEAjHJwBPdD1rVDWplysEONB3ATXkuUYPePJ1zpdlPc2Bs0G2VQkQnKGVEi+o+xvq9+o6nPAGMcjl7LbMMWkha5LdeR20LR32TYebjE0pQxO4oDhJI7dYThBz9TuDu8q2Adte69V0RJDDhYHTf9wjTB4g2QOSBmckZm1N7q8cFmu9GE8Hndi2pV8WxSrHpzJd0SjHH+/Vlnu1FapvZ7aQYsIEj8o+thJauijkXN0Kx8Jf9dd8grcsLdlOePP08T9cWPbH4zW2dabJudzshQkNGcJfqckcsSSJ/0ZCiXJtaFZ5749SCGyPh481rtr7zThh6HUgVDl07nvO/u6M5713s2ngOK5a5fLE/O80AYiuvTK4/5KmVfyNHF9fuYlXZnmwtPTs+DSlMGMpzR4Hg9+98efqK+v7JcX9jH4NAbJBu/mmfyVjFT3Wnk8Hjweq3sdIsp7G7QhoXCMs1ze3z2JkNUbz1++5X/9n/8/jG9/RiyFXBITYDmfhrUqAklGYUn0FhuH/ZMKy/DpdnyH+YuhZ344fGrBo29Me/IUmaxQPPLDhhIJwin41VkxHAY8dpjkQRpvy6V0WCzhIvlh7HuVpnVttDawEZRWsf9NKlL/7X/73/Jv/Vv/Fn/lr/wVfv/3f5///D//z/lH/pF/5Px9M+Nf/Vf/Vf7D//A/5OPHj/zdf/ffzb//7//7/Ik/8SfOP/PNN9/wz/1z/xz/5X/5XxJj5B//x/9x/t1/99/l+fn5D/Vexm6ybBnS0KRk7MdOIzpZLGjctKHOeCoeqRx1vNS+00agj+jO4XhXIQbdFGSTFHqj7dWXwxrr1UWLITYItCTz2uDhhzEmenAvuUlTDOAKcU1ibZOMsDWj9lXTUIgcoYVmRteQxB4Ga5ZnHA2SKd5cER2JhjFiFlRlQ7kvXQ7bMSfmKbHbwEzq/oAiIwKdOBp9ffB4bAQLpIviplMJpDnLiiUHSpnBPBbE8CJ2WKt4XPmw80DSdMG5KPU5E+vu04f2eKJpqyimEDyAMHi4oedPhcgIovgPbdOUq+QP0Jw
K17kQcyCETHVro90Caxt0D2ELoYox6EJeF+RokZ2yL40bfWj31HznUQ+Bcq2EuRAG7HVnZ4HRGLcvjO3Vi884hZMhGKErHHEMw3rARpQZZ4xe2CQCn6Z8HgB5yhiDUBQaSI7MyxPP1695Wp5ZloW9dcwi8+I2VlGBm9NUNEGFgUUx1xTxYbR1hbox6ibT0xAgaxKS2WyXq36302InmVEccRh1Z308qJeV53lmBGmrxjDmkFmjYW1ntI00Kp1MXSt1a9TNWbBknqOIOSW6b2VI2sf0RtvuPNYHELEOdd8FPQNP08KXXDAkZG2+N5wI3O5feN3eUUok1aDAQjNizsyLpsz312eeLgsxBz7fbrzebgpCTJkwItNcmKeZ58tF8SB9aJoblS9fPvO7/8vv8M2Pf59hnen6RE1/hG++XZhsIuXAo2/8wZdvuX/+wvb5VRM12s29bjfy+w/EMvP+6Yl3T5ryUlCRfXf5CpYrf/DXfkcizphO37zRqvyUzWBA3cwzxZqKIQHw+I0xnLRgh0scgw5R5AhNPo5EDTXfPUV5Gh73jU/xA29mOfbIgZEjREUb1VZ1z+V0mjWPENxPVMSj1gZ7NeoeaLvijNr+N8m773a78bf9bX8b//Q//U/zj/1j/9j/7vf/zX/z3+TP//k/z3/8H//H/NZv/Rb/yr/yr/AP/AP/AP/j//g/srgW4Z/4J/4Jfv/3f5//6r/6r6i18k/9U/8Uf+bP/Bn+s//sP/tDvZe672KKJF2gGHw8RYs9G3p4dvUbHlmtxeaRDkqQFf3ROZwiVRsUp5BHV93nnIkN6Dug5XX2CHvF1eMOwDp4cgznjkFntExjh2nSizE49FLpXexD/63vsN6cdZMDebid01D0fB+uQXD9QrdOHZX7vtGxM/DQhotD3dqmR3n/YXaSGfZ9F3a9V+/KCtOhVZkKpWSWZWaZl3NSkPmnitXwHVxvOqy6Q2ptdGrXn09+bQ5/vhDf7JUC6EHLLqH0IuWBQG64qwiV6Pbs5nhgAHKITO47d5kmQkreSdv5wB57pKM4HfR/TVCcSvnTTsvwnYemTDx9uKSkLjZII7TtG/LnNdahSfTA/3OWxVVwy6XhLLCUkqY34s/ppE6B4+EWEt2/MQSW5aoU2PlKyTO5THKBD4ne63cg7y5moRun7vuDVjd5ArqP5LbvbI9NAmo751sRK6oc3KsbDLdaVSwjGJ3WNnpdT0j0aKQPoW/051OkGUGN+7pxf/3CkjNxmYmX2bO3Kg+MeQJicMd87QBr3QmmoLwcM3mRSW+vG0/XFx7z55MFmUKgrRtfvvmG2+WZbINpXAgp+X4zkop2YNfrledlVpT5svD0dKV7IQwjMM8T81Q8okYElillSkg8TVe+/vA9bF9ZH3c+ryvffPszUpBr+DLPbKNy+/KJjx9/yu3zR2hVdkj3L7zUr7VXnmZenp54//zM09MLKcB6e/CD91/z//pjv8X6+SOff/JjSirkVFj3lbVuesZGl3i7NUrJMlj2e8e8KA/TPd+qpqZwnIsOn5iN8z4PB3O5Sztnwc4mXvB7gxjYTU1FTJFkSkmIo/s9p8ascLA1Aw1dl5IT8zTR9p1KpVXpJ1v9myTm/dN/+k/zp//0n/4//D0z49/5d/4d/uV/+V/mH/6H/2EA/pP/5D/hhz/8If/Ff/Ff8Nu//dv8T//T/8Rf+At/gb/8l/8yf+ff+XcC8O/9e/8e/9A/9A/xb//b/za/+Zu/+Qu/F0VgC0sNDj2F5LTOIDrraJooCEeRMnqMLsyLJxkgOovKugcRpsBlLkw5npMJQVNQbMKjOQpSknI9IRgsuiv1cQAC/p58/3HcRPLix2Ln4dHnbwe4mDMxRbBEDpHWE+veRA0uiTm9Xb4R5SW490ZYN9oY5CL2VcStf4ZcBOIwxoE5Bx2c+7azV1+wqsIzL1kL36se6JQTIUcSiTjkY2bVE3OPydCL1OoR3PfHRtuq4i48UE37KcFvSo4V0y2
nxFwm3F9d0IWzNqXv0N897JhCCOQQKa6Pm3ISUyyX00BYvmvDtV2eoOuFIPjSV/qS7vqPYxemBuYQ7x4hiyUmKBJrt1553G8MOqnMWIhUk9P9NGd6fHM/0eVPfi8Eh1UEC9fdTXqj7JMOC6KctZ+apgl2yLlQykKKE8rBKBJkh0KouzdUQd6C3ZDX5Ma23tm2B/u6ypi465B7PFa2dSckTVBHMzXGYNTKXnequ5gowjxQbbDvK9t2Z++V2tt5jx9mzrSmTrk6NNdgXx/cXj+Rg5HrBUrk9bYwpUi6XsXyjJFt133zWFe2fedpfuJpeWbK8ga832/s206ZZieqaB9io9P3lcenb7m/e6cYipKInjIt7z0xP5dl4npZJPsohZwSe+2srk2c50k7xngkR4lxW2Lh+frMfnnPyAuX6crnn/4u3378luvyxHOZsdF5Xe/cP3/km2/+gNuXL0rsXVfW9c5tr3wwOVe8e3rm6XLhaZmhd6blwlfXF765PEOMvO4b1n5KCom61TMJunVjnvy+HWomjv34AdnK2svOdAQ1hx7OifmKIsh6bAQ33xbRR+YBHo7aTa4U/fBdNNLk/qYJorvkl5EwS5hlJoTynDIO9DVzFmHqsDT7f8Rg9nd+53f40Y9+xJ/6U3/q/G/v37/nT/7JP8lf/It/kd/+7d/mL/7Fv8iHDx/OAgXwp/7UnyLGyF/6S3+Jf/Qf/Uf/d1932+SZdrw+f/4MIFp2hx6lGQJ8/xMhccZydD/4rQ9sVEaMWBxYziekFswzddwkYi6R61yYsmyAah+81s6tDvahxWfoHaYoxtLR5SNXiZxF4U45H9E9wojdOql31zmZDFirs3Bk8KlETCX1Ri3ja3IdV2WtM9M88TSbGHwBkQNah20Dz2ia4U2XEnVAhiDBbBxOIvCbfG+d6tqjkALZGZEpq/8aPmnG3kgxn07vdUj4bIj9t+87vQ/Wx8rr643H7YF1Y57Kz2mizDu9Y3kqoWqRNs2xdBuCEC1JpHjsGYHTmunIeorB4ziKpr6U0mlU6v2ESl946ySPaHsDppIhStPTqgcttgZD+qYjIuLNZVNNzaNujA1S6xAzzY6EVzun/GNHlNJhSxQpU6QUHaAh+oSVfbqLIqGknM505ezO4DEWUi7kPFPyTMyFvVZNCjEwqiaqtjf2fuPxuLE9XtnWB/vjRqurIMq9sm+dug3KwukeIu9F16cpUkB09brzqBv3NhjTjbreuN1emaKavb1XWt2w4cQbghfloCypx8qakopU37EpUqZMgVPoHlNiaxv37cFjX+ljMJWZp8uVKWWn+kQ+z58JsZCnGWI8D7scoN0fvH75xHS58GTvPAFa58AAUslcLguXyyxEpTu0PjZiTkqknUSEyX7NBBcmylRYysxcJjU008zLu3f87u/+Pvf7jW2+0lvj20/fsj5u3O937RotMUXtGnsMpGnWNPd05WmeWUqhB0Fh757f8e7xnpwmxjC2viv+pHXWdaN380h6SEGF+2Dbykoq0lEj0nrXNXRWMYhJ2q15IoAmdc/5VCEKg9Hl7Sc/SxfPj+boY6QcQEMwUpZuMRQNCcfOMVjVSsOdSHIwmXRfFJK5Wqe+Hen/l6//W4vUj370IwB++MMf/tx//+EPf3j+3o9+9CN+7dd+7effRM58/fXX55/5377+jX/j3+Bf+9f+tf/df29jSPkPEr/GANFIJGlnoijmis1w3dJAdOfkpofeaViXaWgmMuXMZU4scyHnSLPA2juvdXDbKs0cGiJqv4PTlwMOC0qPcMRlWBRzEKISMbsb4JpR90ZdN6jdpwrpNkKP9GhAZxBIeZDqYKudvcPcAqMNBcGVDGlgMZIGzL7oyfbmVKylfKZWaZFGenMCF420nFDiFIN0H1kj6UHXFgNySE/lO7N+TDpjuK/dTq+D7b6y3VbqVtXF5Xz+fCEczCM3uQw4Jf8UeRAQrfbowjDpcASrxjNWWzY/XcvkqJwhsZzMp6I
3T8MAJ4vwUHMcTs/d/Q6HvXn4mQF9nMJq4IRIRtPit+7NpRBGTAfmH09iSHTfwiOOvJRCSZFSjiKlr9ltOLIp6ysip+mq+XXsZoygGJDL5cKyXJzgInujcKAHXS4Fu+veeqtY35Wj5NBuq8dCPTgrS+hC9piLMKTVaz1gLSprzeHK0Ru9Vbb7F+6OT1eTvdK63WltFxSUVfiHdXkW9kJvmbFBvE+slwvb5aro91QZbeexrzy2lT7MiSATOURFxwTBpNM8E3OSomgoQmb4taqPVZ9ZikyLfPZ6MMqkyPvgmsXkMTnJBMe30ak2uE6FeVESQUmaZgHZFa1qcqZUsKRrdZku2sF2z7oicp2ulDQ5DBuIVRqmTuLp+R3Pzy+iopdJAv6oZ20LsNtg74OnyzO/+YNfp+87r58+0fabGkgdCTq7gmmH1N3gOInoYkHaJPNnVjfsAXmLxGHRhEAYJxV8dEG9B1w4/Ew1bxpBSEYMncNDNsXoz5yfFQZ795RzJ3JkT3zIqTBPBeJG9gSBX+T1S8Hu+5f+pX+Jf+Ff+BfOX3/+/Jk/9sf+mDBTt+E5GC6pOMsg+MiKHyx2wGyqTCMO39W4Qa0n8+ap8LRMPF8y8zx5V2+sFtiasXdjq007oJ7I0TArWlQn704cAjxEruYBjKqmQQ+vxxVEJKT0jFBFBLglSnBbcDN1mTFHwh7Zm7FVo9XE01XU41Qgo71NR0V72JtoNrk9fmv97aA1e6N8hze66BHzIccIjfi9y6NQwXOHgNZ816YbvDePtG6dulfaXrHWxWj0Y16CATuZhSfNcQy6SUgc/eA1k+Ep4A9COqdewOnHTZBt8t1TU6RG8eiSOrSs18QU5GDtjhimL/IGgXyH7ce5oJZyP3mx70HFLqcsG629YTGeSbsEoxS5Z4t1GU/Y9tCUaMoO5zVJDsX2/gZJzn06DXxjiLQhGvLeK/jhmXyK0L0omLF7QOFBNe+jMnqFIbupA9Y5IU3/OUW6cBZXFKxpCCYVVFvEjEsqmtu6Uh93mserrG3ny/0zr5++FWnCp+QjO0h5YGrmeqtsm9h6j8edkqUPa2Pw2FZ14ulwZHG/PG/ejv1rDNHvRzjw/THUPF6WK2VZyNN0/r1pmvyaePPoWiOc2DN8VTDPM5dlZj4mhKA/99gV0DjU1vh9CDFkpjxzmS/M08KUJ1Ke+Xr7Ppfnd2z7IG2DfeuQJ96/+8C753eULNbgMGO0HRudbTRet51tr1ynhfT8nmFwnS789Cc/om6eOI72ttFkt5ZSOd9TH01oCcfzJVgbz9TC97I0CKh4QPR7Z7CtmxpHn/6PgtX6INpwmy+PUI3KaAtJTWHvniDgdk9zSsQUmPLk0KT0biSYsmy+fpHX/61F6td//dcB+PGPf8xv/MZvnP/9xz/+MX/73/63n3/mJz/5yc/9vdYa33zzzfn3/7evI5rif/s6DF3xhwwCsbuTQPRDJkhie1gHifInKngPUE0XPgwjGWSMS07MpZDmQrTB1OSR1rqW6I9aiSnRDeJDOoprjIpbroLsDJOCPSo6vHVjd/p07crGZQSsOYn5oP+iQ7gNkRCchI3tu+xZcmL2nzGad5RpkC3yVGaCW8mI9JFUMEMS+cAgdC/SphykFBK17azbTq2NaIp0sFhoLdCG4Ji9NUqMDAukw8fr+Exrx5rYbCRNpaIAOwWf4FNT97X6IT4c50MnV4mBTejg8Al3tO7T4CTfM3/whikOYWtqGnrrbNvG+lh5vM7ENGExUkdnrUpI7nAy1vQ9xlu+VhxvES6jO9FBi/3WmutqNNmkGBlx6CGNiYSIETEO5uvC9TqfcO/hsHB0vTHp3xaUQebmQSoWHeq208tgnnalC6ODcu9GC1FRIyBNW2iCr5tPNqNS+8ZtuzFapdcH23bDRhVjMZim/BT8PhsnwSdh7rAfKQfsmEXE2GwnmYTKvW6YwXr7yOPpQoyVHAOvr19Yt419XwX1xSyBdUhMeWJKGau
dZg/CPBHGwl4br48H1+uT2Hutk2wwhajoFhQiuMfdSR1DxfDxYNt3QjPauhPNmEpiG513P/wB89MLT9OV5+tF1ki18+7lhXfPF5apELr2iTnq+/QI5Mhzmng3TywxSIQfhMREEnO+srcv3OqOebpvY7DWnTAVwjQTp4XnD++UQ1UCn8bgf/mrv8Pj44NYLlx/44/y/e//cX7w8hVPREIzVmuQEnUffHrsfPPxW+yxMR4b0WSuW14i99sr67Yx7g9q3YlTYS6e15QC276RSmH1lPLgmknzfZ0IRNIXpiGkJPYEBWDQzJxlG6lbk61RUFOQcvDd6kHcknnBmKInHQw1Rc76jcjTb+TCHDNpmK+/IiEYT2UmzQGu/w9MUr/1W7/Fr//6r/Nf/9f/9VmUPn/+zF/6S3+Jf/af/WcB+Lv+rr+Ljx8/8lf+yl/h7/g7/g4A/pv/5r9hjMGf/JN/8g/1/cxTL2WWmM79hD4MOU40e1v4HV3pweoTjVqODbRGDtIelKw8oGWeGa0qunsYuw22rkOP1inBF/lhYMGYsqCrOgKNSKydlBuxZHpQsN7mN5GMVfX+c8zkqDTTFA3SkC2Mk0F6VwieOUuMGB36iMQtEmMmdCOHxBwzcy4nDJjiMdK7T5+7HOgD5ISSDjuTmLI0Njh0NIRth9YxKimK/K0O0/xrqyk4oi/OIDW0p2qmyPSpKSn0jd4qoXXt8nuL3tUSghtvqgjmHE/GZohyJG+1sm2rvPSaWGhfbndh/POCzVqadzO9h9bopm7Z0PU/Q90IymHyhzAMBBV23TeHT2Ot3c08Bftob5TFsJompnnyePp46kpORf4ZEeKdrU/NhrG1ql3BGNTRSU2fCxZcioDu75hx/vBBeqRbYO+Nra3s+8q63bh9+ZbRKjYqra1E5EVp4xAFv8GWb1565bTCKiXrsAmRlAN5OUy/glsqDfq2cfv8kbE/SCnx6dMnah/ksjCXxaPoFxiRx+OO9hUip+jOUBBeHFBicRGzTwKm57TXxuPxkEA1Rra6crvf+fz5G7bHXaGDwUMeV5inC09PLyyXK0/Pz4xh5Cnz9Pyer7/6isuyaH8Zj8yjyOGjWELi4vCbHDh0Pw6HjFPQ+37cb4x9I3dN8a02yRlS4vndM9//+nu0robuj/adP/K3/Bb/75/8hPDyxG/90d/iq8sTOUZCTgwXktdaeWwbH7984v648/HTR+73B3OIFJcGzPPiU2UglMw0zZQy+fnVHGkRE1IsRTWjmobGGY2ioiUrntEbecjRPSHWXotqpre6ycuxZGfg6r06qudC6IhZch2wZszWGza6bJ9aI6B4nBSjMvLUkgpdOM6hv8HrD12kXl9f+at/9a+ev/6d3/kd/of/4X/g66+/5o//8T/OP//P//P86//6v86f+BN/4qSg/+Zv/uappfpb/9a/lX/wH/wH+Wf+mX+G/+A/+A+otfJn/+yf5bd/+7f/UMw+gBzVzavbdHaSHy4hZyX1WoduhO5+USG42NYH92NfYYOpFGf0pdOyh5yI2URGCJmdRjW3aQl2OG2x1cYyTcLz06AME7SQZQgbp6ycn9EPwEtT4NAhL3+7COh7pK6O90ifPZiBh9PxAaWUvdKzlvqXPPFUZqYcmWIS9tw6zW+NPt6iJsYYsvzfG/vW2HcVzxw82ypnQj5gUxOMN/Qzp+gEBDOxIQGCiwiPZFPXBTV3Fi+9O4mFN2jR4RkcKugYjKqd1VE0Y1SAn1/z0zrJXSnUJEYf6Ab3feO2raQtSdtmg63JpJagSHvrym0a5kveGKEnPYG+z6MP+q5DaK+dWvV5997lujFnlsvMcpnllTepQB0WMSGoww0+IYsIc5AoeMP6hxiZt8fDhduDZZKuT+enjNpynojl4p5rzWUMgqn2vrPWB4/HF26vH7l9/pbRD2Fxp0SZskrY2b8D4wq6DCm584Zy1mJK8pjE3E4onU77yXeGvTbW22esLYQAnz99gVx4zgspzJQ
8aS8zAIzRqxomRMiZp4nnyxMvl2emWE7iujR22mm0uvOlf2bdNizC6/2V++tn7p8+MvaNhLHME3mSj948zVyfXphnQeAlzzw/v+Pp5YmXpyeP5vBCnRRvYcO0x82Zxb0WY3pbGcjRprOtD6LJq+/100cmQ0GldfDy8o7n9y+8fPWel+uVtm7UqfLV8sJv/tE/xn//3/9lLl9/xdeXd7Du1MfKuMxCTyJYHey98/q4883Hb/jpNz9j3ysfPnygTBe29abiHjwLK0Rw2ywI7Hs/iRFwUMnlcB6jefyN24CZHGnMxfojyHYsl4kQMyMaNdSTOCRDaIf1nWih57PTW9CO0bWGIUVCLOf9mXTp3f3iWF8M2r7RGmx7/cXO+T9UVQD+u//uv+Pv+/v+vvPXx67on/wn/0n+o//oP+Jf/Bf/RW63G3/mz/wZPn78yN/z9/w9/IW/8BdOjRTAf/qf/qf82T/7Z/n7//6//xTz/vk//+f/sG+F5PZBLpeRNVEzQhvE7NHi3STOPTa/IRCKPxLDLUGGrH3mkrlMmYuL+Q4HApIo2w1oEbmiuwiOVvXB98heh7q6ALG4xqgkDw6cWS4TS4lv+UIpUrvw6BD08A5zZmiCiUwMsks6mG12dOchEEaDociRyzRxnWWsWhxmGkM3f22iCZsvzavnPNUqvczt9pDuYvg5HdDPfezranfLp87hQnDoLhhOaAgq3N31LXuVp9ph6VRdvJtaV9Bh8LYsHBEfXZ1VfzOyPFiFMXZIiezdLD4dY2J0Ta68P3RtWme4ettp0e3oBL3ois1hTqgI9L15jIWEhr126raxbVWR8f459t4ppgakTOWcDId1agOKaMshhXMXAEbyXCT5OIooAb5THca+7zweG4TIlGc5rld32MddRXJh1M7jdnPLGmOrO/f1lW19pW43+r4yepWPJd5McExkRwilv7eof5fJwzZz8V2Q53rZQUnOvnccnqmlxm573AmHaLxWSp5o5p/xgBgzy3xh3+/cb6tmN/eDm3JmLjNTzLR9177UOn00toPivu+eOfaFvW7cHzfq+sC2lbFvhGFy6XcR6TTPJ7U5pczT87OewZSwJhg7F5eGmBzea60QpMvTBvltL/kmFZHr+mN98O233/Dpp3/A83whTpmcEt/7tR/w6z/4Ie+eX+TwkTOXmHkuC++++orUtIdeXx+MZyOXxDwVz4wzelScRs6JgZOU8sTy9MI8KQusD0h5Zlhgfawss47uUrQG6Qy2trsnpDnrz+gmSchwlnPyfeaUZk+tDiQrhF60pqjGFCdqqGx1Y6fqeUWPTEwyE7BgsA9ibJTgO8MoMoeZy0sCUuGDN8eSRrQ+aDs89r9JcN/f+/f+vSf76f/oFULgz/25P8ef+3N/7v/0z3z99dd/aOHu/9Grn/qVY+XtorWolN7QB9RO6uaR8MLYDwscM9ibDF5lIQTZH6B0QDJjaD+EYnJGQLERhh+WPkY3o0Vz5kwgBaMFmMZgMbcwKZky+ZSWxUqqzTAGIYkGnSfZkoRVRAXtLXQIHzT2Ixp9SoGlJK7LxGUpzFNimrLrb94EfQMdTqPu7E3OxOYH4+Yi3m2t9DageBqnyWmBvUnxHsVckz2/sOUjJ8vo3+n8D0YCx+/IfJQggsO+U8D1KZERmrKygjkDk9MGSEnGA4uJ2CI5ZIcaROPHjJK1M8gp8nyZ+fB04f11YZ5mLEb2Ifhib46XNz1ERzjkMbF1U0hj7529OpzoBb02/bfafBLBiRw5uut09Pdrpxj3qMHSO8lYtTjbTUvucTZIjO7Xo+l9HYdK70rKHUZImWWaWcqswtBWRWhsK4/bR+rjFas7cUisbtETic3NgsdbDMsB+R2WWnkqZxTLUdCOZOXmUGr3vKut7qzrRgyBlhIj1LNb1k4v+OGn4nxoYlqrxKh7BsL5c/bW2LcH0TqdwX27cbt9Yd/vWKv0gQctrvRt1RTttlK2KwvJkGejYXz6/Il0uUpnmAV
95ZioKVHixRGSYyJ3PpObsrbeSM1oKZBMhbo6mWggXeBtffCTn/yE+vTCy1fvmK8vPC1XLrlQjny17NNYCXxZb5TWef3yhbs1Ue9nMRYnN7IupfC8wOs8k2KQRKCKQJGioNfL8wtfHnf2bSfPC5fLwjSVU7w9eiUQSKEwl4CNrkDMJof9Ix08lEiaEkuciCU5G7UIzh2VhKcLW4LuQaHHuXMQVDrnBBhoFAtq5MNgRNHSQwr0YVSTpOEkZ3UYm3G/Vb58u/9C5/wvBbvv/+zV1GZpx0SQ1sJthQ5WTrIAREY05rkwz5OmDGestCanhBQgeaedfXFsNthbZ9021k07g5hE57YoOm9viAQQBNu1ZpACDdn8ECYmJLqccuZSCiMYew/Urge3p+OhaXKinrJEsaYuNnrDcfjPhejiuABTCvL0cpy49x2CazLMPesMz6Dx/Vqzc6ratp26N/a9Yh2Yo3ZeICPbAWG4kJm3CeqA3eQs0TgEqMnzgaJ7FB5u8H6pML4jdg5QpoyN+Qw/jF6gFGkxTkuXQxU/hoLcQNKFYBBK4mmZ+fB85euXK++eLpR5oRrQKrl1NRVEWtsYh/uFT2ZyCFG0RPPE31pFkjmMbFsT+eLUjYB+/tHFwlMbfjIej1j6lHW/JI+Uj8mnx96E3/vhHxxmND8UTrTJ5N2YbRL1nsN6qsm9fn/Q60YYnRyCU/1nd6M+oGzRjMUxEi04oIjxaSos88KyiJ59pNvWIRPWbauC/TbtM9dt48vtlRgj1+WqzzLrfl23jfnasSBJwCHUrrXRWpdIO0B0iyNzAlDtRh6RfTReb5/59Plb2nYn4Pvk2hl1h6aD+HDBPxf53nRt7v4wPT3x6fMnfm3fsOE5Yzlhy3x6RWLDqdzmE5VEyIFFfn1R2WHNpRbNBrf7jcfj7onQep4WvwfZ5SM4EliOjBL40lf+2k9+D9t2/uCnf8Bt7EzPV5bLRUUqpjNOxKbA83Lh/cs7lmnmMb7w+npzGYU5BAuNwIeXF37w1VdEjG1fud9v1Kok5ZenCWMoCLOvtD4YlVMPmoN5mq+EuIRMjkVNfjRKmth71TPmco9x7Jy8KR90Yhf819ogtM5ymQklcKj0B0PO523QhsfCOEwTWqLeOvX1bxLc9/9Pr3i07EPdtqAd4a6N7hlGkWhGJpJDoIQunDbooOxtEHrWtBISsWSl+YaOdWh1sLbOw+1uppAZxRixY11OFWY6wW1oCT8csgpDrL4GjJSIUyJffBHZoYzMWhP0TkialAxzS53MEamQcA81teb0rmiEHNDPlI6oeQXiWZS49ugWD/3DfZUDgCI0OqN22lbpBiPooN1bYxmFEDMhgUUjlIMIYSTr5HiQHmSWamZkD208dkWCzY1pzoo+nzPTMjFPE5kg0ktQRLblLDr8oc84Bgwbooxn6Tq6x753d4sfAClSUuT5aebr98987/0Tl8uFEANbN+oGqQ7o/TTWJES6Q1oxQDkKlnGyolrv1CZ4S6oFLYBjglJE0c7OjJMe71i0d6dxayd1ZihlZBKRJHI9JozKJrZpEBszJbm9B5OUwDypctSNtj8Y81VEqwi9yfuxtyOrR4Lga15OI9uGG42iP2875J6YQyaEylxEsIkOdcrdA7a9c39UMcpaI3egdy84jSkWN1iWgbIxuD9uTNcr2zJgVpFc64Pb/abptBtTjOTaafvGXu+kKTHFi+zL1jv7/U6936h1VebWUEhpq1Xam5Il7ciy96qtQe3kEWnbzuvnz3zzsz/g61/7Pt98/IbrfBVEPRds37livBAYFmlmrG79tNXKsEEhMeUi708T2rCulU+3O+vjQTHj5elCLJEeAt0FrSlHaq+UWGh75Wdf7nz86SfuP/tIG40SE09pJociWUM2WtYuKY1En42cMlO58PTuHZ+//ZZcIk/P75hKphnMn77QLfCbv/5H+P7lPSVnvnn9lp99uZND4HmR+32vlXtP2D5Yh9iWjEOmEAgksEw
JMzEkciqCcNFz0DEe60ofnSXMrsPrpOkts0/NFkQztjYI1shzgGwiRhH9zsOHhtN8hdarG13/Yuf8L3WRwqGl4Iyh6EwWTREmtwD4zv5BBaXkQkiZvXWHB7UXOTQQ8JbdNLqx1kbtA1JykovnrCSDlKTMHnZqEVrvpDC8Y1aXZ0fzZ1J/D94iHY4lp6HJIbgFt7KNIsWX7iJ9BHrV1c05ezgZ2rVF11k1jV7am4WzcJSYCEHdtXkabN8bdd2xNrTnAFGkI4L3ovRJKUYSImjEEDV19HHmPh2WOskdMKLJ6ToF2ctMc1EUuBc7OyYrnzxSyiJSpAhhKFjQtT3BSRnB91BafqtBSVFygevlIl+2p2eWedI9UBuxjZN1cYRRHkvMw3oqpcjb8+L7qjd6iybyICJLLoFLkS5OOU2JHoILtE2RE+ktJyvnfELMByVdS3D3D3RXBuDUs+kjas7CHECnN2Nb76y5EEJXSrJ19qoiFQjO5MKZaFFIQ9fPKreKhBth+34hOWMsnvBn7yLQ7LvEwLJYarRdIu51fagpybrOakbEYKx7Y308WKdXAgvmh/y+y58yx0Bye6r7upLur4RSKNNCQPtMsTkzvUv4PtqbyF3ICN6Aua7MHUOIgce2sn78Gb1Enn7ygen5wlcf3vPEC/dY2ejw8kKOSUUoBj7VB5++vNJq5brMajrWeLL5aq3cbne2+w1a5zIt9MtVQuoiT8VcijN3d6EoY3C7feHzp28ZtRENPrx/TwqcO9CUEiVqrWAOWz8/X3h6umJBQuQPX3+fp+VKsMHL8wvLNJOd4NJGJ1liuVx59+492+0zyX3+Uki6/n1wu91JQVZndXR3UxnelOvZPvbXdkDDVQ1swLCmMzFqB6AzaiiZ213ETlZuQaLr1BN5Ckw5YN7YHf6KhOD33xvR42/0+qUuUr2Ptz2FL8iDHzg23II++v6JIap3FgvLCJi1E1fprarT1slJRFHv1Tp77xrLS8GPOr3MpK/qHv9d1XVruvJ9h+8pautOYxbFtblmQX/Q6cFoepAQ1zhj3I8YcI86jx7NrILgPlhD+5Q2+unDpryfpJsj5bMIttrY1539sVPXXfR1jn2CG0TmSJ6K4grK5NOaM+tMtPHT0ijnU5ORU2KeM3UvLNN0TpXFoyfw/YjXAf8ZJR0gB3bUOORQ5PR+GsBqivy5z3YcIlFBavMsWm5M2anlYhQaig1IJqZdGxI0qrF5g+BCCoR+8DkcEgtRLt1RsSdLSepYi0S5eBHFECmg5NOWKyT0++k7AtIQ367Ndxhah+j67bN0Fqh1bDTthwjUR9AsWYq0Ytud0dRUaY/qnn/D6AS/3lUsVJcUDMTqiiFqf/Wdwm8uexhNe7mDEWijUfsuRwzMJ/i3iBYRDzpt29i2V8w2xoB93dxFBcXeBHXhfTQe28rS6mm5M8xYloUQ5YQ+9l3IgR+g6SBJmJqa5EJicyYl22C93SF/w49/76+R8+Bx/5q8LFyf3vNH/8hv8W65sluVXc+U2OrOx9fP3L58Zk6Zx7sPPD89U3JkzhPWB/fHxu3LF7b7HWojE5hzYXl+oVyePI5dlHTb8KlwdcNpzlgNcKKWGXQkyo3a64ZobFOmzBOpzLz78DVff/0DJoPH7cbog8vlwu0+ROSJA2g0+unkr2clORkmsiydp+tKihPbJpvt3o19q3R3yMgxysB4GFsf3F5f2R4r5gkC1oc/Ywk7mcmRHgTp1WoQOrYGyhATOncYliQrifJ8DCW7gDp+ZzXwi41Sv9RFahg0r8tn4XAn34gvr4f2VAOjhMDejdg6RwIrPnHlKH3UMhWus8SYdZgn/AZGTJhHnUfvCIK/iXjQq1sVsSIEf7iVhLuVwlp2tq2wlYRVhxHqrvym9jYZ8J3ilpUTzpQj18tMDMLPd8zHdnX93TrbvrmdzOEdp4KQPQH3YMWJNj5o286+blrWm09IMXKZC9eLiAeHSr8URZofbul7bacL81lsnNB
hUR17cMistyFbFHdGEO3aNWtecKRdk5UVPqkek8ZB2z72UqAp4O335Y6RUnJzYV/KY6cLuwWly6YQaDYcHhvekXdsRGLR90/umq1YbBWp7Nc7J5iK7LIu15l5LrTg3pA+RX5XxHvgGUfA4+nQcTQgx2dnvHWyQVN6Ot/vWxPTa2WPd/ropFJOhwYMpmkm5eKFTg4OLuhic82RHP9VlE93flSgxhjEociL7iLmMQY56/PovhMMUTHpJSspYPRKG41W9f6sG3VZz0Zr33cOAhKNn9cremhkgHMvl5MCBY/PQ9ZXOihLKVyWC7gAvF0apRR3jFGuURrAVtk/f+bxsz/gU3tAKfTv/Qb7hx+w3m5crpnFvfHCVkm1YdvOa32l7RvrtnJdFpZpptXK50+f+MlPf8KnT9+yr3e5PITI8/WJkZIK8XQh+sqBPsg5c10uYHLEedwfknX44Hxcfj2bIvTM88zT05V5mckGl+XikpLCt/cvTMvCl/uNMs2UeVba8i5Yd2uVqWZCUGOZgGnuXK8vzHmhLpVSMrfHjXEkQwvnZ3QhPYxA3RrrY2ffmtAic3u0dMTx6DmL4ZiGwpt/4g7FomNC2qmHFDA7vBOFDvQaiRZEJPoFXr/URUoL+S6/tSjyhI1xwk4lqRD1ru489sHYK8MgF1FeGeoMl2lSEFtMLCWTp0wdEulacPcE7wAOyxDMFAIYE3mYL4Mb4FlLeHDBkGnp41Fxv1a2Ktua1mXPFLTUktVKVEicOu3EPGWergspyrrk/tjZnL4pFlY7Y86HdXmxpeRjuIx0I76JP9lWQ4w1p1uP4fHSRXEXoghPMtOE03Vhr/VkCHafBAOR4h1Sa419bzzWnfu2OX2qEIb8xaSv8AMbwOGCEZ0275PP4bH33cyn0/ncRYYyW7ITJktO+PCNnGImfHKITgm3EUjD46u/80/znyVGHIZ1OyN/nyH61Ja1X1sWxbSHMdw3Tw4jb7ZHmp40PJwt1DlFwTHtA77HHP07MKZfv+BNlHU3AO4Q404qheGiy5gKJWUis98vMykVUq7n/Tpqpa2bGHyxnhDoQYA5LKBarTSfqmTnlNQ7F6M1MRrzgBhkk2Oj0YZR9y5ha5RVVIoZXCRuJlF3dcp4r1WftZONat296Xp7P9/9NSBySopirjZBjNMsJ/SYtF+qOJRpuJxhh/UhlKBXHrcvfPz0LVMs5BywarTbSqqDHALb6DzuX8AGj1v2KPjK7fUzP/7ZT3jcb0xEnlLxvSG0dcMuV1JUgnFwRELu+50jEuZ2v/Hp9RP3bWVvlToWFnVkut5RsfE5Z1kslYVLmSEE5suFl3fv+enrJ0JOzE8X3n34gLXO6/7g/niwrTtzLIQ8ToPfQWCaFkIu1LoqZ8qai+EF5Sre3ZvPbXO2pYTrpSQO709Q0xkxkdGcaEaQZyj6v4zmZzLGcNbgMD8b1eXA0DN9eFf+jV6/1EWqNt0E3ezMC3KCK8Xeph5jEEOhHhTp0DX9tEbog2iB67Lwsixc5pl5mg87K/beaRzw0Nte4fS1Gl2U81EIWdYv42iTxnEQCWbZ9p1IF27tC+iD7hvOm9XzqdyhoOTAVDThTUUTSi4Tn28r+yZTUX6O+isCwDEpxBh9ZNehB8uT3gABAABJREFUl6I67MO3r7WupXEf5KTNZnRYIoLgSv9fCIrkTrnQSmZflUd0MB5TytRWeWw7t3Vla00uBsiSybpHp1SRGI6JKBAZSVBhD2+fxRG3ceyEpIwHy4UpFzpH7peMW7O/hxAie9tPVpk6cRVRDnqKG2iKD+A2VXyny/8OjOeyVu1fsgpVCNrpxAChJ2Ic52SXchQcld4amvCdIhvORamTKA6bpu/c24cThqyYOvumg7wGsQmLJ6+2tlMmFeCU0xlhkfNE2DRRgknTFIM3Af4cHFR+/87DAySJEtuG2Bk2uCwL/P/I+7cvSbIrvRP7nau5e0RmFQpAA2z2sIcShxpJD6OlJb3p/1/Skx6
kh+Fwhi1y+gagUFWZGeFudq56+PYxT3Akdj9Jqxa8GUxUVWakh5vZ2Xt/+7uc9x+M4wCDmLF05MWIjD4xp5qBZXa77v9h4u5aKnnljnVpzJSa7Sm1nGa5MYg1yNefzIJ++drNQxT07h19qKEMUc9RbxVXDh5vX/jxpz+SL69cXl7xh2yJyn7QS6UcB3/89AOUndfbi1i13lNq5e3tM//4u9/xuD/4xctHLq/fgIPaCqUNrlnuNDknsg8c3vP68ZV/+P3vbA+Nxbrc+fzlE2+vr7xcNi4jkcbAdYWttt5JIfIXv/gFsU8uQU4mXz5/IiZBk9fLldEG9/d3+e6NQSuV+/udXiqXUnn55gOvtxeu8YUtbfjZaWXTp+gm+7GfjFXx95x5LQ4lIlRNWS5rZxVtVVDqct6sgsajnS96OnBI9N/GJLhI90bUMOH8WlHMPk+G6T/n9bMuUt3i3Bc0AIuiPZXCOhHuxwqoU0JtH7ZTUT472Tuu0fNhy1yyGH6aFBzHcLy1TmlGLW7+KdiMYtW11imLnWYP1JyCBIJzpxdcKYXelPzamkUaeC0YYzC6ufwzmD7iENziwyQE2KK6kujFaKvL4dx7GNqLxW2T31qTE0T0EHBkHwmxE12mPoo0WnVSy2C3g2DDMZoElcNPGFocY1ZRTPO2o9Oap6Ug6jqTEMWIHFPL2VbN8r+Zj58DN7TDm2Mw+9JWq7g2Jxr/ghiYWjDP3oVnY+m80xGdLInG6EQCOW44n0RxMAJAm+b/5qKh6PYwzUoMntY1vbTaaMGTgu6j4LRbG1P4/cBRhq7tRiSFTE4b0ceziIO5UzuYYeKCrJzWPtQZQULrCJEcgmmZnB+aVqKH5m1qFfnGW/PVe9PuqQ+8z4QQBGN6aZqiE6MK58l5I8eLHEZwcgMwZ+/36E9CUfCOkQIzIsG5ebURA83gmZQjME0sPZg54Fth74NaKo/aoMrPsdXCMOudPhq9e3PukLRgmpBYmWedUCs33DkFa0+LiDx9p/VDqEBwZzKtcxhZQo7lpRdKr+sSiHQRokTT00GbjCCt1+Pxjv/pj4S8sV2vJPTfH7sc4emd8uUz+/tPzOOFLW+kmERDf3vD7Ts//eF3tMcdZmenkuoX2F74V2NS5yDkwMeQOG6Z7/iO62++8MUCFy954zVdJJQdYvz2Li1isNy0eznY73dS0156JiN0jc6XTz8xHjsftgv10zuf3yopR6LrxDFoxyHBsguk25XpPZd0gaT9YsobpQnlGb3zfr9Tm1EYfCT4SIwbfQSOJmcej9e+M1kQkkHE3nkjzDh8DBDcemRlaTU9s3tGGbjRTO9oydqjkUO2M/L/D959/79+nY7YX0EWStW13YUzbA1OOMctOMHwd0yo650j5UzK2RJtMYhPkMVxFKNiNlJPxBxxQ5y8elTKogKbGn8tyW0RYxY0U87gZq46B6Z3iriUzBXY2UM3TQwaDLt3ED1uTHwYRnHWoR6tuI2uzrhNxXioqk6RQ6wg+DWmmJj0zFSak54yIcp9oI/J3hpbzKK7x3BGuLvpaT4wUHbTcRR6a6Ly7mKFlVZovRpk5XB2q61D2wfTUoyOESNtV/OEQNZ+5ryWk/OGX110YC1gDSKcVqecMzakMe2CRuMxgjq5AaPJazEFFeBpjcbSjZXaOJrgTV0PZzsEu8/cYkjZ8tse4OCi4DcrmAwEC+sneUJ+Nvl7DO7xT4FrLYUYowS0lqfm8AQvjY6IHGIYOufN/V3O+lY27fDpDIYKUfQqRHMFP5q7vkWCa0/rTgg0GDvxEj3MznRa7teHY2+K/xYhw0n4PZS7NFqnzGLuI80c8QttqDlrJizFOTPxVSx8H43aCrUW5pRhcTWJgzfpxV40xdVeeZQH9/1OPScvQxrOIEsJmOkDjsL+9oX39AOfLzfSHESfNPlX+QE+3j9Rjwd7UKsYgPo4qPc7fX/geuPzD9/T3+58ernx8u03fPtX/5piDWuOmRQjLji+u77
yV9/+ivCLV8I18+uPv8BNEU6C0+RSasM7yTD22vjh7Y23x537/uAWs6YQL5befhTbl06O/WD4Bu7KnI1yFLMZUzMQQ+Tl8so3t1ccUI+dx/snnSfGMGy1cuzKfPvw4QUHHMdObYWvqLeM1mmtylllDkb3cvD3OpO8MZ5XQvcZaDrV4C12rge8TdPOfALTsqP4J14/6yL19Wsto/2fwCruyTg26OToldGDDluDeMTA068dM45sg8dRlBB6FJotGteYHJuSaidS5rvZCYs04ZzF0yv2HS8WXk6ZGccpGFCOi+ksQjQBMSxGoMTBlZSzFv52QLfZDCKR1X6OkVveyCGKXtvMcRjl5XiW43e3B16uGsk62GHvY9kF6QaTU7qKoScGxbNH73Fmrlks7LDNfkJr+1HZa6O0RoyOl+uV20XU2DFEZOm1Mf1keLkqLK80RcR/VUjhpHmHoMI2myjZ5nOl6+vcV4UJhnNm+bN+XRRru8a1Mbqc730MWhzbgTaBozYeR+V+VA6bauISiDsjdwSvSWiIpSWyju0gJxJGmxuGLKnWwe/MJ88K03qAUaEevVOPyu48MekAL0djdMjJSCIxkbek6WxanpBp4Uo5qK2SoyyY3PSm/4KYPDHAY2pXl7zX1GE7B+fN4LeDJpcp2n1QhzymdoExy327N+lmptGJ/dfowUkWGWfRbWNQigqDLJea2VuZkXEp1HKw0pdbb/ShAuQ89CGSRjfh7v1xp9YDnGkLp4xb9Tmp15kD6JNeC/Xxxh4DP2XPaAfeJ3pv3O9v/PjTH3jcPynVeTR6PWRfdhy43ohMXnLk8+e7Ii1GwbuGf/2Gn/Yv/Nr9lmxJ2W5MUh98u125/OobwpZ5iZE45aHoh6abfdehHkLg82PnDz9+4h9+/3vmvTDSpl2iu/A4du3tgP3Yz3Ou7Hfe7l943O+mtVTc/W27cdteuOarNU7w5cuPvN8fOs+MtSmI9qDXDAih2bKYvI9SpY1KkTnldl6raQU9IuWMoM2Xs3tMnT1+On3FAMM/d+xTESO0rsa1/xlMUl+/ThPWr/5ZOyOAqaj2abZFc9KGdcdzGGzyTLIsQ4fu+2Pny/vO+/2g1i5bkCmrFMZkRG+7hHlmWolIYOy7MFnDVAiBbduINv30oY69945PghhxoqYu6vToVlSCx3v5m7lglFwwWxtPDl7O5ynjfKC7LkNT59lSkKW/0Ul7kyP56X5g3c9iRDIH3g2RNV4vXK56WFLOiqVwDh8ntUyO+859P3iUdu542lQUiQOulwvffLhxzUnO6nZgVO8YVZle3Wu3sxhc3th8yg0y254Yzh3R2qec1j1mnbRIIcP+e5sLPvRyh/bedHP+zNBaE1trcqpftKtuXW6pjd4kLnbh6WSCV3JuMpd1b7ZbwQe5QjhlIAVnE1pbzD1Bjvpe8oTsQZCvDlTpUhqdFpo0eoeaJEF5MjCOMZJzYjhHKbrv5pD10P64ayF9uZ6fY++V0RtuaprvTZBZSmqgxpwwdH+roTPszOCYGEVRaVOOK85Hpne0Pk8XltUgTCa1yf1gGHml1CoxrrfojaoFvXaNRiyZEv+eO8MhdmwI/iu2rIpiqZX7452j7IBiOmaKjKCG77IFUnDncSAj6cYog+MOn2nU40EwUtBxPCjHO95Sq8do9CaxrkPPi2OINADk6GVuXQv3zz/x0+OL0rORLmiMrkZ3DnwTczhfNm43RYXEYJZLFjEz5uTz+51Pnz7x93//t3zcbsQ5eRkvjOPOnI2cI71Vfvrhj8Q+aVE7z8+ffhSEtomBG5OxS5GWSaimeXiOSRsisZgIRPKNKC/Dfr3yJd+JKcJDzWpIQlZqq3oO9BgJGl9tme3FAZgKfgSRKBRWqjWIChkEGsMPWm3/rLP9Z12kznXqwv1ZcNLXYkmNoqNbrARiBTorKCGYwabztN7YS2FEwRe1DY4q6Ke1oW66mc5jwuz+XEb31hXwZwfGcMO6z4Fz8aRSbzkRc1A
YnPO4PoS7B01lY8gMtdpSGefwtYOvDGOQ9S4iB3QdWj4I0him9rZdV3CQg8gXDmH+tTeG2Qqt/KQUVHxzVkG7bZkPLzdeXy7ElMxB3NGnPpMxOvfj4NP9wftjp5qrOFag2hikmHi5bny4XrhdL+A95TB35R6Ew/OM9ujILSGkaHR33fSil/+p8E8O8PXc6U0jA/SpXSNN1PNuhcr4soL70Hus3WBGY4ACLHPSo8gVvlYJlue0+8qp4Ok9isgR5yCYYHjLmRgi0TB+j5Nz+rCcMAuj834QrcHB21Tm3JnEuyaANjrHoY5WxJB4iohzTrbyEZOLKdfw6Sfh7um9EIMMS3svZ8MjUbmRGeAk0PgQTohmnv2e7MKkPfTMoGtbetdzUeV2nYMcK5QwG/WZOR3UR5O9VOudbFDpIvBEc+2QuF0EipOMNLQnXXu91WzO0amlUPYH9TiY3eygQmC4iU+Co7xpCPFoX2fJzsFpLznKO7M9eNTCfd9ZuWZ9VKb3Mso1SkCb3YxK1JwstMZHhx/KuJrr/rKBvQb43ZcfqX/8xNVHyEoVSMYADfZz9dY59kJ5f7C/3/nxhz9yDz+R/8Vf8SgPZlNBHa3w9vlHfvcPf0dog5wTt9uFUg9idEbqSbgIA+klm7F+Gw0XMy8fv2EC+/d/5HFUYtQkK4syR+6NlCwgFe3wS6uMY1KazH7j8rzkP1upzEWG0LrFA9OaGBwGRw7c8MTZDL79M5ikutAce6L0oXnniSmcD7P3zuxAtIAevYOTiM0DKQZj9CWjDQue6UPREfZsgLEHAelqnNwkliBLO7GuHZkxCnsHCkRbmmsC6vgQSZYNJJsRM3/EtF/m5jDHWXukXelDJo59mhhPNvtuQq+d5iuuB5bJq+jU4bTxX3qpZkLNKU44KQRGr0QngkRegWpuPoWWJkjej8LjsAlqL1RT/EurqGykOWHbMtdtIyZRtkNMdi9PwuaZvlNdxYsgdlLL1wOyJincc0flvdye+1AciJa6dgvb4Vaa/MDKsNyuOcwRXHuAVaBKU3x6SukpJ5hiqR3HwVEOdZzNWIhmfeWTx6d4utErsdZBCGyXTXlEeNsNyv9tOZ4LVhvgGjg1CX2sCWHoUKjV4FmDYPu0XCfBe2tPlIzcoylMbtezT1yDYnTjEII54i9afTynyvnVVHruh3gyDM9MMDtgnP1ax+SolcdReZTCRPDqdn3u/lwU267OrqysrkI9eycEPW/XbZP+bOi6Ha2y7w+O/UEtO7WKNeotokK3gMGyBm+ur8WEnXS7V8xd0XNGkDDVEHnLj6j1ocl2dNzsuKGpL+JowdFapZQdpuM4Cu/3d/b9QWuVYwyuMXC7XvCbBMajDcpx0LzDd3hvlf/H/+vf8cf/6W+41cJ72XG2YyvHQQqRlBRlsgrtbdv48PLC/csXaj349PknZq30srPf7+zvd+p+mP6w0Efhctm4bIk2OtfblXS9mhONnpU6JCm43G58++GV/fGg1Mbf/f3fEVJg27IK7zo/vabp1g0deZsQOO93b6iQH/N5DpkZL9oWqAkx1GnYtdCqBOXhOcWe9q/prP+F18+6SMmi/0nvHUMizO26kaK5GCDn3RwcdTrayQ7XZJVwXIPFeSdLLe3axxzD4rq79FRzgItJuw73ZPHNIYHs6F/FI0x3uoIfbvA4GnGLZBdIU4yuuGKxu/kLL9bfesjQnqP3QRhebVp3y9jO9gLan41oPma+PyE9fzltnzRJTpa5bjGTT+mxPK7b5xgDeE4DzeEAFymz8fa48+X9nff7wef3nT7kGn9J4SR96E1PQgqEHEjXSLxGXEhi95Uuw9zZzTUCFdQQT8FtnxOiFrMuYIalKhKDxnQiA6iBsH7XPO66NSHd0kgXjXARArrtnmSWYf59cxqr0ha/vZ+GnMrb0R3TEdtsghoM5KmYk7HUjEXppjMKtu6NRe2eE3ADNzo0XVvZ1BhVHsGVrQ8F2THxw7wbnQVLmrBZ+sB
B9CoSfTRV+xFl1ItbXp/6/XVqh9S76No2QfWhO82ZC4e0epbYG6V10mfVzWBXURJ9KqEVp6n16kVXTikRL5kxOr47o+Er9qWMyYfgSClounfT4Cg7yOi0Kb1fqYWIQv7MzZfZn3lYYy5haVOoXkJsRibDmQg/6M85JBuYJurHaY/SRrNJeVrTNughCJp0Jh1pcOzmYbjcIjArKLTje339gDsGXz79RCmQa+aHtzf+0x/+Z/qPfxAZ53GwHwetVPZ9P59DP9WQ9TGoe+ElXSCIrBERqWpMQcE0me3qfnSClZ0jpMg1Xfnw7Xe8fPtLPrx+y+V6Yy8HpejZSSHgu6zRfvvbv+Dzpx/kjQgcVU7xYxgE16XXbNWyt4Y5eoxJc5MYBr7DaJ7uF1nNnmPh7fqMdaPqXhvagTlgmEZKlLV/+vWzLlIpRbwtK5daPwaxvRQqN+xg0MPtvIPhTu2S5ykq20zn4Jy3NNdxRjY0g0oAglc4oWiY2mX19rRTsqZUN5VXgdnHwQR8hHyJhKglc0dssj7NH25O8IHpAn16Zq+kMImIhbZYjNp7CRLqXaF5s4DRFGl1kIIjWe5LCBjTLMoQdqwYiHZ2uMwpLcex875vpP2gJ6XbDir3Uvn82HnfD97uD+6HFrneYh5SkkmsL9qTpexJlyQoNVjGlr2nwWDOJ0FC1yGc0B3ONEh2HRfDb1onXY2G7PAywPVrHyVSSjeWXu/N9GISylaLkR9znA/Rk+w4jBqv0dUt5p33mo5Zol5bBBsZZvAUeC+LKjXzRshZBc+JWh7PXZicKkarC+TXXrN1fDJYbOH9c2qfNjnZmNrjWZAhQ4nKDqO4D5EMukHSzrHvO7UemmR9OKfaZf/UR4dynNOrZA4yu9Xnb6QR+9lEbOhnmOOydEqGYPTOn7iFzLn+vqVBW8zUYAXVpqo+aW0KDgeiUfQ1Fi3fQ7PLCpPuPc1xCrcZyxnDdtIopNN7ZUmdlHcPoxVKaTSb5pf/ZjVyjkgfg6PsZ7w9Q3EXg6no+Lc39i+f8b8ZPB4HpagB/PzT7/F7wTdopdLLwb5/4f39TZ8HKIUZL53T/Z3H4yGN5usHpT0nRXqUisgTozJno7ZBilfbEWtq316ufPzwDTlfiS6QYiaEZI7vji0EmIXRLdz1duHx/qYC3RqDzl4Kj/1BWxlxRqZy83n9ZoNWhlLEnce7wUzeJB4id/npULiUrklnSSkEM19y0nPv/ozYfV/Hla+oA2eEhnWgj685+VZM1v/2RkAIMUDwlrqLpQGY4tpcOZ9WPWahg3rsMYeQHDiZhiuvBrSsfHt/gIPH0Qg5MsPSQYjhs1lRa31YQJlRNXNg25Kpw+dq5phjcBQVizyirPOnuuPbJXGbimxwEfwYgCc4o+H3YYtvK9imkD965dEauTZ6UQTE0QZvj53WOs5H0nbhZp/5JQe2i6BSLUL1iUzMrSPomkw6IU7cJdIQDBd7kpaqduuqzFnCPkOHMxjKus06xR48FKq4pQ1OkGqAMY20uzAD0tGfD5hBvs5Ntk17pRj9+d/GkPh0DnW4wXlNT0a4UXEy6Mxg4edi1HD3KgGzdyvawCY8g7Um0yZVTc+t16fOz/RLK7hRVlXP6a5WOVtE4kkcqbUy8WxBkM2KfGGs/Zs+j+N4aHo2UkvycsWQQbHuKTfNZSJGLeDNBkxO7roG0ZrC1gSbRp7PnnMLf38+A8/nc+38+klc0VQk3dhojVY6tSi3qzbtn/yskNb39Of3Ws/0yv86aqGOQSKd8KXg1dWl2krAIw0gS37RKXXQm8bOmRwz6L8Hh02OVaLiaTo+LxNZsYAPfvjHv+en3/4rbr/4FcM59sc7n3/4A/moeGMEl/udt7efOI53hQYGR7Vpvx+Vty9fKK3gQuB6uXLdNm63G6XsfLl/4Y8/fs/nL5+orZB8sM+y2H7aMwi4uHG7vvLtN9/x7TffMj3cbhfFadTBLDtzdj7dPzGBbdu4XC4Sa4/JT58
/m4yhc4Zj2vPozkPH4aYQndkmK353hin0wEhYA5FxJmqkq7nUgPb9KfO/2DX/f3v97IvUST1f0N6QxY0PZk+CsXtsub5ea4s155AfWRQNek6oQ/EcpXfDWfVUuPVnT1qzO7tqhXqZ9np+VajsL+tM9l3/vNcuI8kcrQu2ohoMcrLJYkuOyyVaqOEFmKaJGMw2qa3oPYxOfTRcQFEGAaO1O0AHAYYZM2VnIz2OFuExelkN5QhBS+O9d/peOdpDD5MTW1B2U1f6lmBOopfOYyFuK123mwNB6115WlFiZZdgd1DGIAyIbZima8G3QXZUxtjTdGV6KhM/9qZrHELH+2Tsx6dfY8AxgiZd158TmDrwBWOZpilYN2pTyiIVLB2TsylXE4Nyp/TgddtdfnVPOUfwsivqQ0vj3qd99BaI2MF3ByhVunbZTa3wxRRXdIbdcF5d+2zNggN1/y1z368ZjzA0Ma+9zFhkBOy9dCua0hQli1AfQ3qpmJKmnGCkoyBkQlOu7KeisQtBE8eCuucQHLnu+zGeRV+QdCcEicNrkauERL8mAq6F3g6m7UIWiahWK95hdd16isc0csCQw0EbSxdoAYhfFUis2IzZ6cMTDPJd4EdtnV7t73CLrCHS1exNaMnUzxqcExFkDnzyzOD4/MM/8nf/+DdAxeE5Pr3xfv/ED3/3D5T7QQyBcn9j3z/z9v4jLtokVSzVuxTe7l+MnDC4xo3b7YWX11d4RyxTwEcZJ68gx7Fpx+VSxm2ZeLlxff3Ix2++JaVM7dVEyZPupM/crlf2Utlb4yVnc2vJ1tQNapM4X8+dRuC1E9T6X6lmblixmkoAOKH0OQC5zEx7HueUKW7vK85GWlMX/gwmKblpP/dRi5p8HGaOGeTwPPpQVbcdgZ59HcyX6yZ81NkHjQ7A+74/iQEGi/jgmdFB9Ir7QGI3HyTgnM2C3DDKJVOY3BR2PwfUo2qPMuUh56M84vS+bBHutXzcUuL1tnG7KHq+d1nK9KauWgcbJ0Qmx+5IusDLy8blokNnoh2W5DmmJWOhSOZonQL5upFyojnH/VDMxdEqPjherheuMUlf5gJzSmiK/axjCJ6KLkjH1YRv9y66fIzmZj5FjS0BmluR4XICmXYAe2zX6Ff6sjrvYmm5bYUhAtGiOi4ps2URA1of2rdZNz26JhjnsPDBDiwGqJqQxtSXHXx9LrOXRdNNZ7LumEZ26N2K0HMXFrNjdkf9utgBWOOxEDNnUxldDUyfnBlUMQUz5pwWyc2pa1lNUe+N1TbJHkeGngr91P3Q58oQ0jQmX8UiWn3rDCfoO8ZIum7kLZ7eiWN1HeE5KYESlfO2EWNiOvcVFD5No/YkiQQT1w6DlEDweG2Vo+zs+0Pi+RhYZsyaIAWra8h7TsL68EwusXagTnumbsxWuifZ9T7hTDthpwPcuiYS6vepZmI6Naopi3QVDC5UFE2gmt4rOie245SkY+w75cc/8Ie//xtGr1zCRtvv/O0//C1/9zd/QzkO/PXC7I1WHny5/0DM9iyQGEbQeHu886UchJj4eHvl9fWF28sNHwO//s2/4H68s9/vfPn+Rx77Oyk2Wiv4kLm+3Hj57pe8/vJXfPvrX/Ph5Rv6oXRf7wLdiWU5joM6Om/7g/fHwQczj44+0JAmcd3LejaCQdnOWNK2Q+zT9sgQkwl7XTC0AKahs3MqvLT2zjTpT22NlNyfTNr/1OvnXaRqJyTPMPqzdg+T7tQVxNDxztv+6Rl/obQVyDFySZHLJrfvmBJ9wP0ovO8H748HpTRWtITz3rJaHD5GnJPOys3JbPOEO1ZXzJzn3suH52K+lqLMqRbYrhs5bXgwxp5ym/BizG05c7tuXC6bDulUeeyaplY+U3CO7ZrJKYGbXHLguumfl2fdtO5Sh2s7ncCHHY4xRfJlI142XFDUu+9iayn5V3qKLSTrchWw2IcWq2u6m3XY9KJMGkF50z4DnRYhy6h07TXk3G5RIHPt3J6HxEq
YPfbC/f7g8diJU8vgS05cc+aaMxdzCym+s5tGZE1Tnkm3CcMxTwducLan4qvu+tmFrwiMGIMSh6269yE/QMGzBnE6R0SedouluB724JZ7iKaYGG0CEpZyTsExij14puQaoij6f9P1WOnSYx3EYPO+oDCnRieig6N9RTiotdGOxr7v4ORRma+Z6+uVlBVVvmjD3ohJ3iDB9ZmkGMlmotvbUDjjHCcBYIyvmHe26xODVVESTKilsD/uxJTIl832f5IcpOgZ48nqW/tfENy2Qga9X7vnZ6PkxlBzEs0V33ZsbiEBTtOjJsinNZePnpwD2dK7YwgSSDdld2lO1TO+bemceB+PB/f2I79P/5HjOLjEzOPtC//+f/wP/PiH7ymt4XsjbxlPJ7iOs69eG20vlMed9/2d3336I9/98i/IFyUlXy8bt9srMQaO8uDHP3xvB5/l1tXC5cONj7/8Jb/5F3/JX/zFb/juF9/xut047g9KK9BlZH1/vNP2g/14SJ/FZPoIZmLNnLzcbuasrgY4Z7F9vdkZTaxpdIZejcEF80f9Cn0YcwhhsOcXpqE6csXxlncV/J/BJIXpMfqC89YDig7N4RxO9BlNSUOTVPCOOTqX4HjJiS3ppvMhPSPE2ziXt2DItkEbtTZ8asSUxIYbw248M0sd6++0Fs78+RcBgBMTNwzXxmvDd3Sgp2SdTFB6alAcd4zCCoZWG1xifEZHXC/0ObjmxCU948Bb77jhtJAuKydI0yZrH3culZ0pLCHnZLshHbxtTqKfX+37vGkjJOyb6D2u2I/sg4Urdnu/+tlCjMQU2R2iaU+p8EU0aIb/P7vwOWF2Tcj3vXA8Dm4pE6aMUC9b4hIjOXim97jR5RNmO4thxJC+6MprYrMDuAO+L2iCkz25olDwclEP8RkdIoou1CpXEFlPqPAsUfQ0OHHtRWSrJMjRO/1c0XZDIXrG1KTlYjjNa3E2iYyBa43JJG/bV0QEm6ZitCbgGSKZQqC5wNx3QVZdVODjKOy1sr1eCbfE5eOV6+uF23aVzMAcCYLtXb13NqVJ+KuJZxIjlO4oBhO56iBlkY1KpTwe9CJJgJqUKONdp6Zh9MZxPHCuE6OePedg2wKTZILr517SPtbTscOAXMsa06gUY7SGUFNg8NFQCv2fn2pa2pDBtA7QYHvpSM75LFLlOPB+gtf0tuXEFj3XS7K9pAgctTe+/PRJZ0zv/Pjjj7z94SdNLrOTg8NlR/Yovt01JoU+OrUpf+v97Qu//+EfuX33HS7as+8GYcLmIx9uL3zzi28lfzD42ntHvG5srx/45Te/5Dev3/KSL7rnrhvu3VPfK25OtqznpZS7AlmHJBIuJC6XG6MWvvnwkdfXV5z7Xs1Gs/Rr7wV3TmklQ8RIQgM3hp159r9t6sdNgtOzEnyQdIZpBWoVp+eE/l96/ayLVDfK91zi3a86rjU1LJaBEmUx48RJNn3U6/XG6+XGddPo2wsn5LHlxNZgzGrdQVfHl4IW8kMOCS6Bs4K5WF06X55dnMOzLIkm7hkk9lVXnpLSabvvYs8kdYMKaswiUzgxC6MT3TfnzMV+jny70EZH3gkeprcI9CKzyL3xvh/sxWjrxmBb8RLedhEqvnovw6uwd7OD6lHWSMuvbmBMuqadkVwiwrl4907WQY7FyFLkwgyeOgd7KfIbSzqkuomYQ7B5dw5mn5Sjc+ydVqA3cEkdfQzakcQUZb3kJiuuehEu4MlvEI3a6MtWTGN29Aa1NoMh1tkoXVNwpk0y94tgvnylFHaDln1cSvuBc9Mm1wUrWnEOXmLJuN6fdkjBB3KKBkd6E3xG3FgxCbbrGeMUOX+994k54aO3r0DMCvb0zhOq7lWcozZRqd+PXY1L8txebnx4feV2u3C5XGl9EEvlsT/4E60a4IKcxRfMHmKAXXlRK47B10qhM8x6SpDwICaRf1JaEJJIDb036iH2JUj7l0JkZgnp9ec5UwfW/k3vS1fVx0DKCZEBMjl
Hco5crxdJGszkWJBxEbyoHgsXHMl7C//T9V0Bna0vqFM/XzAIc4xuRUJeiLpEk/vbnfe3d96+vPO4HxxHozc98ykGYnJ4ryI1RqH1QqkPHo83Pn/6kbfPn3m8v1Oq2MDBJYJzZncWuWzKrFo+oNttI78obv6bj78gu0TfOy5J3/j2/pnf/e7vLEbkwpYT3K784uNHLinRTefWhHvQeiPHREqZ/XhA79LTGZSta+GMdWuCZI+dX3qW5EwxwD/PQAdPQtiEOiQ7COPPQMw75pBDsluRCF+Zd8LZia3lspoveVRdYuS6ZS4GiyWbXAZiS4UoIWTeJE06jmJ5P4C5SwSjJ885DEqyxb3BL8NNU7wb20XQPQ6JNIl62MbCxVkO3KAuUT/b9HK46NMZC8sRfWRLiWveuH2FLXvvlZfUJz1I59Oq2F7zXnm7H9z3xmFwn19sKedOi5QUI0RPb9KduRTOCWIMWTNJEKqJ0HVZyRxVWhI96Cp6KSRTtK/ICos5MfJE6dWYaAPvk1h2IRhlXo3HWGLAqq/e9BkEvxJvNXGyDpS1U5prtyJK+zSR9HrJDFYEmHiJlFKNtbnSbXXP5BTYsgrviobHmHXH2hM5fUYpy3zYe1HEVy6Vc4vMsijdNnNNqaNk/xOI4XkNWu/yJLSJxpl90tqHCaacBER82LbM5XZlu2zKc5pQZzknBB8CrXf2VsEgz7wltpwtg0oGsyEMNWy9sT6EOdfcol990mGeLDcrON2bvXfRu+cSv5uv3ha4XhOXq6ydtOeQ/msM8EM/43DaQWrilVPEovM5N00MOk+Cg48SDyekX9wuiQ8fXvjw8YMJq6E7Ob/02ZlL0Gykqomev2ikoGW2G4KawBAXmWlF28C+D3PL0Oe5pEH1qNSHBOpLpgdYcXN4N3BBqMucD1rbKe3Oe/nCl7dP3D+98eWnH/n8/hN1/Bqcp7XKnJ0YPX0o3NQ5SBaEmlNUA2gI0oqof+x3Pn/+iUe5y4+vR7yPOB+4XC7g4OiNGhzueuXy8sKn/aC6iYvRztHn2USw+HcwlmzX/gnOSJoFnYb1/I4FA+v0nVMNw/QaBNZA8U+9ftZF6n/x+s8YfBPOg7jN5a2mXdTL5crLttlDJrZeZ9Cmvta0k4OnpygLmKaRltogBeie0WC0dhYu140yZB2oHsZnN/wnJqXuCU+sBSPGxNKSsVNb52jSM5XW6FPC15QUyiZsPJ9TS4yRMB1HlQYE7xQy2AaPt3c+fXnn7f7gKM2Wws/le06J6+VC3DYtzLPRZBfsY7+3Y4dvNBPbGAS57jvdDdyA6dTZOuu6TgbrtHC/dda4r/cdgHcyqvR2KDmjZDtwBtn1ZgvaMXEu4F3E2a5Qg6yo/N0OykUUOTO27EGKZly7wv1iCgZxRbzXdBvg7K6XZYx3ak5OOCoKwlwFKkQVTd9tl7l2cm7BVGuT307NXrColhyDmqaQmGWapZMVKfvsRn8mC4cQBFHlTEqCrreUwQXdl/bBat+VwMkYNmZvP1MyJxJNCd0gmxgVcfJ0Vbedru2qgu3Zgu0rsc9+dO3mmnlZat+pz3m7RC6XdE6Sp5nb+ly86cB6U1yM/d+ZasCyPLO9oXdPxxc/dXC/XLjeLs/4EYOT5UwBIXpyNsjcXGTGFEM1bYmUo5kHc7IZvdmG1dbZojFZ7WestTGdVzKBBgjd/zaFLpeXNVWEKAGw0WUYNEo7ePvyiePtnR+//z1//OHvuf/mL2i3G3M07scbfVZq2emtEALKmAuBNCe97JS+M+NHc8mZ1FEpo1J6Vzp5ymz5Ijo9MENkbImX777j17/5S02vMfE//e4fqUs7syZVMVWe7mJzMme3Z8+dzEvt8NrJkozRPDm76Q2N/awka382AP/U62ddpIKyKs5MKb8CA7HDz/79af44tY+6bhsv1wu3y2bRHOrg6oDDvOfAdjI+0I9GCJVGY0UpzBLOToMxoDdG7aeeILinh6BzT22T3rcWusFbTIK
96VILo+vhyUFkh9Iqj1LxvhkJAWHWIHeMBScazMiAaQ7dtcu6v9dJL537o/B+37nvlVK7MbukFbpsmet1Ex6fM86K99HaKe7zVmBDSmzXK9u2EXygt0bYNvNzk9fceZjlxPSKzZimFWpHox+VWTt+LLsamXPihtG9YY2ebkyc14Ox9kuKnP+aeba2QU+hqA+eOKUfEilC1yEYjfrryVFOIWouVqfYWjOH8qjDxXwE16JZabFY9Ig3saizL92b6B09ha386S6pjyZmmwuEaIv7nAg+0uY4iQu1DjljfAVng3ZrOenPRKPTe3Ni7zyhusXAc04MzM2HZw6Z6VcW1Pk1m269Z3Cn1tD7cBIlQnDM2enDKfCwD5pTbPwwyrEmHsibGgEVRxV12TTJsLZYcVKz2E44U1fXZCZf7+IM0sWeNZb0IjpSEOzcDC5e7vopydkEY7U22/uFFEmXxOV6OaNfqhncphyJOdP6Qz+/NWytFN6+3HExMAekmNUMtSpR9vrchpAADYVLQD7xThOjJqQd1wdfPv/I777/T/zhx78gDUfaIp8/f+KnTz/www9/gF7JwXPJiZsPbM5RH+/8/sffsV0yW8y03vnhy0/89PZGc5Nf/PKX/NVf/BU5bnz/Qyb+8I+MlPjw8Vv+xW/+it/+4rfSBcaNX/3l3/Hv/5//g1w8jCnqUlDSgDV8Pnhy9GyXC9sl2wSuArRif3zKmqbs9ulGQuljELo7kZB/zutnXaROTYbpKqap3Nc+wv6tddx2WLhAcoqdkBGV3I5rmzQGpTWO0WnOEXNixslhHeFwRp4Yk2ER2IKksGW7OhAX9MBgVOCvX0ry9YbhGlRjfGE5IkjDlK7Z9j6YRsMZ3KCHTmdUtMmkMy0scUwl3wJ4o3vP2mlmVFrr4Kid0oYxrwSZpE0CzhCfJqohiG5bu4xE1+4qXzZeXm9cL1ctmEthusn7+5X3+zszRLbbjdvrC5fLVY4UPioMbS1kqwq+2GKBnMVo8+YU7u1/g2WBMZ/GqLMrArtNm6rU8eOmoB2nyAbBgdrPiYOmnnbtB8dQsYshUPs4D8DFiOutqeNcfmPROuuoQxqDeyVKtCTllMkpE0I8JQLM5325uBhnJ9+k2XFeGpxkfnvBB6pNX5rcLCSTZye/9okhGAPR0mQX5Gm/0f5eW7Y7uKakfzcmR1F+Ux+dWWx67+Za7jjh4JVW7G13SbSDa93g9nM5z584ljtzd/BrooqemGRdlvO2bNBZxrJjrkYBE4Z+JQrWp6z/3+v5mgzbkXg7LJV75aN0XrMPptf3WLEv6z2fUS7OEXMkb9rvqvAOeqsiQeHYUqQfFnnTB9UYdilu9j3UwDSHkm3XZ880fZ/snrqlLbjzc52EICd3Fxyt7Ny/fObLp+/5KSTCI7DvD/74/d/zwx9/x5zD7i+PM1/OVu/8+P0/kIPndnlh3+/8/vvf8+XtJ7757jf88le/5cPrR0KIfDu+49vvfs314we++fZbXrYbocOsDdcnr6+vRNuPruvmDZ53tj8NXrrEtSZYKQWlSSc2+oRkKw7n/0Tg7RA0LrblnwHcd2qR3BqooXfRdM+luXOsyAHnJilFLilxvV7I180ODzNVHFOW9nNQp3YvwT0XpsFwbMZkNtOHrKtly+TTMsfLJTmGYAQOg0PGsEMYLV3dIOKVxlrkHD3nYMuCU6KZOXrvccPZj+LM4HbBI4XQA3Q5WWPQZsSBxTQ3xhmMdxwK9MMWotNNCCquMQViDmC2RN4FaPKGk3Gv2E+XS2bblBMFEOPOliO3lytzOG4vL7y+vHK9XLTnMQFgb+6EynTOO1LSwUVY5rj6eaNNxtPLiLc2LV+nOXy0NjkOuRMcbRLdENHD+xNX987jQrQuMOAFdpi+zXRTPlgkh9T7KzXVIdeGvp61RQ8Pjt51rRNauHsPOazgOyWPfq0ZsjNV9yE6gEcdZ2z3gipVQo0AalsTyQw8w40TntTvHycEhsfCOkXwEan
IfVWclCsWvSMHZ2ytac4RjVIPaJ6jVu0Wp3aLhNUsTNZg1t2UM4sP4CohykDYB0/wjtoFfzovNmq06XK7RLZrYLsm8jUTvRoT2WRNiAE/YBJO1KOZvnHBzeuR/qoF1Yc1JEIeTma7xyhEFwX58jQ61ZrXyaXBCa0IKRjl+8Jm7hrOI1f5qQITGWTvoEv/t6D6ZOmIMQXlJ/mACxF8E2HGCaZtJtourbKNzpazTcCRnLRLm1ENzzah7V9o8xVG4nF85qfPv2d//0JvjZGk0azOsXlHdMDx4O2HP3CELzAboT/45pb55be/4DXfCJZ5l0Pmlm/c8oVsxWWMbrvRTnSGMGvhIQ3nnCKcmWwjeI9H4u2UE9fLxqATDgmkTjH7XMa/GClL0HlOSQzPf16N+nkXKdvmWnlaDg9fiVXt9zy7vUkKXpTuFLia1ZAzkkLpjaNW0aLpNsV06qhgQlummCm0lfWEoCxzNFjwTIy2bI9aoE+DU1pbQjkV2DGbBKFdFi3NIp49k+ggorwiNwX3OOswtbThaSBrHRp2YzC1zMwp0iYU56TLKDI0LRY8lgxyXJ5mKWVCzHQHbbTTk80PsbnWEl47HRFO5oScpDX75sPETc+Hlxsfbi/kTaxE51aQ4zz9E9dCVQ/3iqhe4ktnnaherXeOUjhqxY1BaZP3o/LlcfDxaORNECPJOmObxNQNesZqHFZPMU0CYEa+S1S8eva+CAM2UYecCSkRk7wKa9sN+tMklqMW0jllFb0FkS6IajydUZafXe9i1fQuHZhi7QOjK+DRYdlT4TkNreNjQXN6t/qUhglaR60morYUYevyQxILslozsPZM0k+JKn7sD/a6i4XpEs5Fpk+2Z0W7SptAzmYiiPGoZ8w9IT5rQFI2csMWSJdIvka2TdOic9F2ybLlWhCfm53aYdT5LFLIn3F97nqk59n4STmgSbA0Y+BakV4ztHM6MOdEsTauk1Lm9rJxM7hbr6Gk7pQRYR6Cd4yu3XSI2oPiYdsu5EsWGcBNIw2pKOpeU9ZXn405OyE4QYjeUYog0HzLxC1wvSU+vCRGv1MeP1Km4+3TJ3p5JzhlUwVvO9febSXqGL1S9ztwaLIZg4AYsMEFZpN9G2MQnSPiqMfBUR9wvQmVK537Tz/R90O6yCBIvDnt9aDbfspIVimwZUG3zRrOlYisxsidSE8fa8IF5006NP4M8qQwQ0jtog0zjYuFNsxlWzRrZ9DJFiWkdMyn47UdRkdrHL1DcBJC9kanW6fXrTgIWJMxs8kIbZEb7MEMhtnmGEgxPRfP5nax/AWllWq2AxmMrjgJx7D3LTgwMA1O7IzeLOLZ0xMMLCyR58Po1+EkwJ7lfNENdhhdhAOc9nUzOlyOBPvycVHs11JfB3DM0Wjvm2jxMVqRcqS0kWOix04OiY8vr9w2CaRXN1xHObVmzixXpgNzALXD9gRjziWrvNnUiY6uWIXaGkepPGrj7bGb28SFMIPCDl2QRsYPBUE2ZU+lbEm045kM22z/JPcEPWAnwSJ6CaUvGzGrKCsKQYy2ZNT0FCPZJ8XGn7ClXisVGczpJCZbQA8FFE6nTrs6eovMrFgO5zExsQ7XPga+a+JrzSYJBR0xpjNRtGWooefBefeMX7lsxKxJ0E8TKJt9mKCYNXMIkq1Org/LaknMLU2TIRvxJAlyPN0DnCjdvSszKG4RZwfTcE+qsuyp7HMygWiwgt3HpHYnYftQbPxyVcH2jU+WpK6XJsh5Xr9uz9VCEsLa+tuzqryxps84ieV4vW2ktJ2TqPOCjUUEEWHHjWdW1ZyKetmuiXzbMHojvR28362kjkEfgWY7TMUHmctM4MwGyxd9H+8hZc81e5LX85ciXLLg99Eb3XWWX4AzSFGmAp3pNEmXdrCPSek7naFUATdpj8p+vFPLztv+zu/++Hs+bBlXG9//+D0/fP8Haik6R6zZcCfULt/PyHLg0fOxiE/TrXtR9mXNN2vMtZMTTDvImxieLvw
ZwH1zmNW7F0iyjrezuxrLCh8mQwdOClxyPPUaYdnMzEFplTo6w+sBqWPSHKawfuqnFhV4dWnMteJXlx6DE0srqNNwztHsYJ3Tnx5Xc05Gc9ql9MHsMoJ1ftWdp61On920BvJwg0DrnjY7Y4ZzAllwmfYS0i6VUqlF+PqyCFqRDNFNPrxe+dWvvuXXf/EdL6+v1Ol4HAfdqyuK5pqct2wPlSjvKWW8FxNOXnOKn85B4uPLdiFvmw6ELpsiuZe7k9SCc5pOUmSiwxbkv4dzTMRwLKVovzb0QCqfXUSKUgp7Odi6YjL6UFOSUyJgNPzmicEzSfYQ2SHkMDJHtbgVuY0nE9NechTV95LPn9+xik3k5Xohx0hOmxys0UI4+lVkh8EeVnoXG2quU8b2bq1Tjd3GuYae50G0kIJzajK919LZtW4ZVUjYq2j6SAyeHiKjD67XK5fLhT3uOK+k6JSSFarnsOQtxmJJLqId8DN4iYTj0uhpAo7BoGsfiEkSA1HK18SjwuH+pFDYjkLMFOSDaNDdtEKEiC/rc9Mz/7Xd1PzTXyfGpHWa1h1/ctiu9zOdJXM7S6g294ycla3kmjstr3oX+nG0zqiDzXtCyDZle7mnpECyou16px4710sip0jdZQCdkrwxT4s1bxIL/zTNzTlpUnKeLW2kIG/PNQ3nlAkx6nlYGrLedQ4OmTsPF+hDjMFHLby9/8C9/IpLDjg/+PL4kU9ffuDx/pk/fv7Ef/yPH8izQK389NNP3N8/wTREwZilTDWWfai59FO7tD4GbXayj8DSqGonVRk8JjBF+pGJiRx7pnEQx59D6OE0cZnsU6xw2EHtpi17pyYmvDrH6C1SPScTV8rBu9YqbH4Y5bt3Slcq71HaeUAuGArDbBe7bnV6LlhwYIykoJhpdX/SwSiaGptSdFCtZWOY82SGSYRoPoRY6GF5Zg+F4IzvarolmyAFM43nnq1WWjnotVAOFSo/J4FBTp7vPtz4b/7lv+S//pf/km9/8RGfIvtR+TzhUTXmx2hamsvFls8rUNIyaJoIH7Vr0ksmVI4mDJShp+xoWm3EEE8YL4RA2uQZOAaM+bSwATHUJMSUtsxP2FLUIZAdMUwCHTe6HWDabUyj1weDYjFhtw9OkGdrYBTZ2WXyuaI6vPdc8kYY8PqSeX298nqTs8WWxOpcRfD1KvLIJWXSdqUbKWa5T0iRr73aE+5Tvtci/jhj0En8qp2Z+2raxpmiZ3XOzhhjfKWZau0sEmvCvcQLMXpqLYzRuV4v3F5uHC+FGBy3l5tF1Ijs0XrDF4+b3poPzeQrNHM4WWTlFNly4npNtKOY5ZNsryYTlyZpBoPzTPzrnU06BjF+JbIea+Y01w6Hw8VAHM4gNP3MvTWWrVdrMqqtBm0CMjwdYo6JEPNsKJyzey1paV+HdsbOSDshepyfYv45qNWd93atlaMMeqn4lIm2b8lJrEQf9ZWiZ1wiKXtut43rNbM/NFGqqYv2uepLgmA9Rx7JQC7bhRgyKV3I8cYInRQPPcttcJTKKJPbNVNrt7NOTafEWXKUd63SHu98//f/M2kGHt/8khgmf/zj7ynvn0i90T9/4se//1t+5wfRwePtHe7vmpxslyR4fjJ7p1qmnp9Ca0qV3vKShAqs/epqtI5hht1FqJb3IgjZXoT55xDVschz0pQ+H3bvxEbpBtvglGC7+ahOyJhRIl3MMwahtMZeRPkexpBqtZ8Ph7MdkH1L1tCiFYpbkL35icFwsoxxDkFPzJNmPZdxZlf7O7CYBm9UZ9TBPSYcvcqJYkyOIrjPeWhDHfK09PZoWgZvN67+Do3e5VAMQrNJKjrYouebb1755bcf+bBduXrh7yNMctCklr0npswWEzllUtbv8ZZJ1FunHgflsVNK4XjsRJ6OD31264wX1OEM8om2fHanIFRFRFj/Geq+BMfDWVGfbDHwYUt8vCQ+XBI5OXwUUWGa0HNpzdRoP90Zptdit3VNdLVpGttLO7t
T77QIj85zuyWu16xJyg6zVhvM/lwC40xnFQkpnd53zqtJiikyzM0ipUSKie6a9TpmfhvMwsvN04U8mvO6TGYNCl2MKPdsdEfvuJkN1svkpMKzpY0YdC1KUeHato3tms3zcGMLieCT7V7MOsg7wY1nA4bt3iIpJi45c9sSt+vGYwyik0jbsEG5f/in0L0aG3WYg4EEuILSBCM1c1BHECWSabT5bAp98IzhT9i81yYD47qaR8+K/1hqyT+VpCxLK4PbQrNQRpuAMKTET1l9eQd0sQtx5raupOxqgtSJdFfRe0WfpET1hYAjzMkWdA/J7UbFPsd0us+fzvR90odjTrnEOIL+nj6o5qNHh9GnCWYtS2zqjFlSBzlrzNOb041JeXzh0x//AcqD6CfH/U4YlVsKZO9ox536/hmip+4PWjlUUKJsn7wXIiXSks6V6jo+KKW51EaNcjivrdN7telWP0MzM+wxBsGJMBZGsDP1zyD0EJ6F6mRQzGcB0W8AFxw5BG4xc82JFITPelQwRoPeGqXpYhy1qfPtU+mUzf4e521KMfjBeYPQ9Hcu0alKjtiCjK4d2Wxa1PKne6w22pNWHESZjgEcuoGPMRllUWEDpVRRvplcqhy/W+2M2MDSQv2CsOpBPSrHXnh/3yl7kfgQFbScxMZLKdhyX4et9xLiXUk6fDctkFNKpLzB8Gd0wRxyHN/f7xyPnaMcjO2q7sso48HrIGAqI8uZJVIdky3bwWvpny4oDRaL9PZL9zOGYA3v2HLg4+XCL15e+HC7sqVMDrJpWtCE5fae72EJYPuplleCazWyyn40HhbNLfEtbDmyXTNpsyJqcFW32Afng3X6qKPeIjMGwbZe9k/nPZozJw3aB1r1lEOixzabfT9JAbZrNiPfwbYFvMFqwSygllgS21eF4Ng22QBlg/BSyuQotGDSxbKzPxtjEHQZIilkok+EuGnn4AIhyKh4zim/RYsoYcqy6Joz5Xrl7bLRSzMGmOmHYiSkwJwHKx7mMKePq9tIMbJtWfuxmGh2P+oJXuQHMOddy+ZyFvIohGBMHY6zT4tOmaeps1sFwUxMVwOpCTawkBdv6EvAnaxLwZD9hJwXWUoNmTnaWDKzTENsd2nkEe+EiLgOfkgzGZwjpyhnE3s/GJmgmdVYrY3Ho1JKZwsG65fGERu9NuqhYtyr/jm5rLDTpkkm9i6DgSmB/bS731nT2uuDUTOEQGDgjbafUhRbsRzQYN8flKrnJMfAlvV+j6ZvX4pIXiIVRSN6Sfs0pu3ozVzYTYO7DZViihU6Bow27ff8OcB98ysNhWk49IAgTZNb7rxKqs1JVkIxhnNx7W1pW0eXfqhOWnMqKl1R7EvvsqYD5xDjKwZGk76JOS0PCZ5wjTky682e+LOKWgA/TfPz1E3ghD9PjEo8O7UXoylH6YaH4yiaYNq+0XKiWKZSCHpfzmjyda/s7zuP953740GpTcCoN3o2MPFGYVZTlpyKuiecnabo1mgfZJ3cyhJqvbOXg33fZTCL2deMzpxGMohy2I454UOijMkxFWctmrcKWIjmHj6MLFAX5CC4LnhHCpHLdePl5cbH25XLZcNlOTpj8tvpZFXVh8Lp6rDYdDsczl1f77RWKK1a4JsenBDkp7ddNlwI68w8l/SVRnaR5jrDe2KcuKiuP4RwGvaKDOUI6yC0BeactnQOTrqk0YjphdfXV14+vJK8x4XJpVx4OSrt0Sh9ytEiWEQG0gIlY+0toexySXDGxFuict3D+v05JZJ5/K3iNaOmsNG0R2KiKctn7W5cO+/lWiq365XZuv2zhK/SQgVq86crxqIkBzzJBRKRRCD5gIuCYAXjaU96eio6zFDZyOZd7vIDUbrbkGNErXpOg7naOyzx12HNHX+yR5RNlcJNqzNPSiNnzEXbMUTEGRs0GKyeYiAFyQFWYOmKoKjIqR7vCSEJKVhaplbZ9+OEZ73BvypSRQ3kXsgvr4Tpmd3Ry6DWQTka5VA2mTu
DH5XL1HuljyTC0Gj4Gc9zccwhcwAzo0050lu1aaezBU92k1oOGnDscqmZfRKSNHvjKySi9UE39ClFxxAN2hIVdF9r7zesGBs0uGrRnPJ0rEoTPtGSf+L1sy5SgA7b4AgpkDaJ0Fpb0RqmOplaDKdg0JLtBdxEbLk+uR8H74+dfa8K1pPvDrPrQAE0lWFu2NGZdsYR0Q2PG8ZUU/HxWLKpPbxjTkZaqZpaHzKk2aFPGct6Tlhw2J5lHWp2TttE52W/0q3Dip5yQIiT0T30TiuNWirHUXjshcOyqEIQ8815KEfh/e2Nx8uNy7YRbQrzo5MctshVx+kn0AaTijdRY6+VchyUUqit4oXraJFrD8fy9hK8lzRN+YCLUTqsoWnXO3XtzoOfHjemoretMDieppXeqfDly4XtdmGmQPOAETAkGjbqrHNfPdhSxasrV4kWIeYreBBvdi8WsWIQovMi0Yypwld6pxn86GKgI92b8+50pxDME05n9DG7/MvmsvWUp5zzsF0ztxcV3QD0bjEy18xji7S96d5LkWGefjGpMVvRFT4ITosrqsLuW526RtxY7DjzPXQE7XKwfaNNwaI3RytUEVygOR0wKSbyFmlF01DFPPemhEPexEbyj7AdXtxIPhJwBoN5hhkJd2O1dhNdr0nHVpNGEhiUUHXP4Ki1cxyagnvr5KR7SXow05LZdG3MJ034TCPi6borDmZBw8+CFoJZf2mQIkRH3iI5OIo9v30oHXsyCDnSTPjeTRyuXdd6dpfImqcUAc6zJvnA7XIlL/bnUOJCKZ39cTC73EUi4Yy9eaI6z92tmKLa1w0HpRzsxx3o7Med4zhgdrYYdJ/VdhI55li5dHIvCR4jma1dPOsp1G7W2tyFEHmjpy+IH++kTBy6Bt11RmmKPxh/BkVqJeOGuB5UkcOdmQs4/aZzcRmj0iA1qXgNvgOO1rgfB4/joJYmPz4ELzmjnw8zuvReBdEv+MdDDFHjq1M37aJj+G77EdFbb5cLk8nrzLQ+eN93vuwH4xgnY0vDh6Mr/YNlgbmo9LM3uv1s3qxdepV7Qw+2mO+D4sxjoXZa7ZSqiOzWpfHyc+H1jlIKP336xGXb8N6zaQvKmIPrFtXRO8GIDlvYT1kVlaNQHzvHY+dx33k8Dq7J4+i42S3VVkw82T01+ckN6cCSkxDRz6cb+4Kvshm8HkeBYOaUw5lxqWHxwTOiZ+aIz5EUPS5l2pykJrsZ79zpJ7aEiaUINw83KyQei39YBI3BMMx4euugcyZk5XMtJeJeD96PBz56rueuaBEcRN8O4XlgeidBQaXQKsbwFDU+hsTr7cp1kw+fd4oDWX6AK87DB93nww+ZIG9ZdjiXJO+5TQv6uExCx5NWPs3hQ0ihs8k1CY51gekn3hwF5nCnA7cgQU/v4OlEH8kh2oHkmF4RJCuKXlEeYvn11QxM7VIkqPZfOYqDm5qChp8MHxn+q7BI272tYlNq1XRCYLZBP9SMTYMIA54czGx5CUvnc+JZu8m2CEh8pTdzzwJFUPhhTFFuFEmf57aJfNWOQbWpT2GVT62YiCrjFLA+SRvpvDcWESaZO3zInryJOemdrNcIUbvyogaTObRS4Gklppe6Vznld1qrJjVpDO/ovVLqzpgKm6ztkO+i0/k37Yd3XpNQTiKTbDkpW47Blnceh7R0wXSgIeiad1Z46UrmXZfuHKGMmQpMGRcoC83xz3n9rIsUWHE/v/Rw9DDwTaaygtkUSz6co7XB46jUCXmK3VNK43EUSqm02lQoppaA3jzjtEQMZ6GSPYw/4QzC+vd2sa2zvWyCpa4pGQThaK2QE0bDNEfyNhhzOS403FDoWPSeMCG6wPCOY4jSGr3GcZ2lNt6bxiTmBFPL3tKsSNmiUhOZCSQX/BECbXT2Y4fZFSsQ5e2Wgj9JITb7G2vLM0Og4Bml8ng8qLVwSxtbEvV5js5oxWj9nl465XhQ9gejFuKcbM4Rpqi0jknyKlApJ2prNDr
V+usxoXUofVK7vBb70JTpg0VUbJnaBZPerlft9R47NRyQLxx1sMStt9sL19sGbnJ/NJzX4TKkTGT5/+Utc72Jvg2QeiOnQKtyaigtoViOeXabZ1tuN6mbZp3jAzMOK4jC74PnjFu5bRs5ZZyP5FYJMRN8FsNyG+RrVtBnD2yXjdvrVe/tepVMIF942tiYi/W0oMveDBYLRvs3uC9nYsz0Ktao74JLVyyM/+q9uxHoi8ywUAxzMFcBVCK2nLM9fmpnVHqzVOX+VQHS+/PTnZPFjO6E3p5ThyX6OtHqt5BOK7HZB7MN3etHMcshW9mZvGRBfgtqm1OZWqUUEumkha+mV4VUAvVtS2xb4nrdcLWStyjPPyJjaMe4bRuX2wUfA/v7Q7T9mEipAYUV7SHmm6B9NUeOnDPX64XXjy+0PvFhEpNWASKTiGGYgt7T7E1NsmmtnMkW+uhniOf6zIRKuNO2q7VCb2ZMgIhYK4trTfTZQicvOfJy3XDR01zlcknkXdcym9myC44ZpmVFDUt9cOCH7Z7aV1OX9vHLU14Q7j/v9bMuUt5EFgv+Gd3sh8Ya8d2JLQfUzbU+aX3K3NTrkHh/HDyOci7RezMLFdMpfR2QNzBDFofSP51IBH70sxPxQf8tRXmUbddMCtqHAcwySbNx65vG+V2doG/6+/qA7kUmCFGdixY3Dh8jdYxT3xWNGecNVphmAisKvej0tQmbbgOYk2xOzNuWeH258nK7cb1cTneJnG1vFxPep5OK7Kb5F06HDwnGpARFmM0uVtQlbiZodTjkfjCDIhLGHMzeGbXiW+fqHZsPxAmOSXRTu8MYScGJyIFYmo1JMcFm65OjNrPwKeRxITlPSjrMl3lot0ZhNiX4ingRcSFwu164vW7knCgjsV0u+FQYc5d3YkraK6VI2jKX243by4vef2sGyThzHjDYw4gkKxxw7RpnUPGSwNSdUe61dY4ii53NWH8+bMZ89PqMnWURbYnLhJcPKkK0yXZTNMf1euV6uYi+nLIVJ9uhDoNjnEEyo+NCJOWNGIJ8E5Pu1TaNedf9qZHyaKKKzjP9pC1W5xi6J0VNxQ9oveFyPP33fPT4qWm2Phr7flBKow1boA9j0rnFJvQnfKUJbMnmdX+4CfSOn5x6Io/HNT3P7dKpfYl6n7DeahzGWH6V+nXtvpIPbCETgznCO0efopR7NwnZc7kmCFiGWSA7SQa2HLikwCWK8ZSi9qr1EPtTDjHKdWuzM81+LOUo2NQO/Os1cxyZbJlb19uV6QMxYaxPQdfTBZr9TCEoLkQu/v6cVnzwYi56Pa9LpwZLUyftXu2dZEXcm7/h9A6fImmTeB3vSBkut0zerTAnQaDdDRoDgiUrJ03VfTqYSuftvcsAHE6Exs1gcOufCbtvBaEJw526SeuT8OC8eaGtziIEXNBE1OdkLwdf9p33o8jTrlQLF7SH3Jkpohe5YAkSJ8tyfthuSqw4RQTIMiQlRaU7uzFXUcsRmhvk0rlsmSM1Hma46geGwz9HxHyJ+Bm05sErhMy0Ufmysd0ubFcJDF2QZW2jyA19ddPK+CAytCMLjss18+HDC9e0sS323pZJxgqLMZ/CUHGdJytF2PtJBLKX8DXGwJY2rperTDfN7230BkOL5tY7tRR669CHLJ/Qni86R/QQncw8U7Dlqu0B6xjU5bbQtQd47A+OcuVqhp7TPyPZRcCQl5/znpQTZTQFNDJlz3NJLAsrn6zQO4/3+hnzJfP6+sLHDx+UWvpyU8ZSn7TpKLXoe4dMConglmfkZLQmSBJBa0MnO+poFT64HweP/bBIbdOU9Gk2UpoORjdpQvCES+L15YbzEIvnsm1cTjaf0qW17/O2hhkML3r3KUAf0xhpgg81Ncu2yHd/dt900yu55dVmh+Qctt6R0bEeQvA2oS/hLM7jLbsi5QTzEDO11jOCxllTtkJLv86CU5orpiczSNzQAv2z0dFbM9Zbl9u/TdYhmLOIO1TnnGC8nDQ
57bUQajgbMG8PtXdOz9CcZ/xNypG8qRm5bBveOcroBAe3FzmBpxzxK7crRVysT8KKs8nSzoloLvCdJvguytUhBK8UgiTSTsede80+xpk5N+ZXidVg5roia3jn6Mdx2nJdcuJiELBs3JYRtf5cTFFNMNp1q7mIhC3jY8R5xzY9t+uF601GxDknUvYKZY1esJ0z+NGZ76XBed2e2WEFdIDIH5MzguWfev2si9RyfBjGMAMnOnab5wcVfTyLRk7xZFSJAzA4eudukNuK6PDuBGr00JkqfY3GC4PGyUjGnwF8nJEOKUs4GFNQyJefdD9kBTOi2QrV59J6yMoGvyZDdcDWQ+Jj0kgduwgHIUKO+OtGvF6IF3U9IUWYOmhirYSUcf7Q2D2nJZcKwrjdbtyuFzazOZKD93am6i4TzEXzbV3wJFMTgxtdOtMFbHs71MzXz5uHXW+VZvul+/2hAEnTP2l3KNZedoHsA1uQO0SxyQtzzWh9yEl9mtdhbfIfO3byfiFe5KrdvU26PpiQW/srFwMFMZTWZzitY8/XxHaJ5E0LY+fh8nLhu2++5btvf8E3H77hcsmUcnCkg9dtcjjPXg4iuv4Ox2yK19BgJW0J3Qo7gq2OIqblfVehOpoJrc88MUEzb+93WpXzuTwWE7d8wQX9fZfLlWu+kAwSDD6dBcI5g0H9003+LKCm41osxGfQn38uvoNNz4FTc9iG/NZaq+bIYFR+/BnHgvPEnMzQd+ADbJdOzA9tWOeiKY/T1X5BcE/3hSVLkOYHNF1NoPbOblBdNycD8Orc66QXi+ZYaOsUEpFS5LptXC4XWqtPT8oVF/Gst5zO824RKILpA7tswIa8+y5B8HLMmrp9TLIWS0EJCv4Q2dSMbQXRqZEIIejeNreNaJ93XNOR0/4x50TeEjEaG3B0HJ5mZgPNtEbJfDUB4vBsMzMYamSj9pwT7bDzinXxKoAL/p9TBadNNUjTCEC+mz2bZZ7F8CQF2YxqNPyIJzD6YQ2Fzow+1j5UIujWh4aIPwfHCQe6P6eErwxnNz/n6Bu8YwuJzUbxVWyWncy9Vt73g/f9kLuDHZ7TOltnuU7TudNzL3iLQddJgJH0iE7fezGSFkGjghaa9SB22dlP5xSuBmCY+KLLS30Oy0ej14YjSLznHT7IY2/myIyeER0j2PI2JzlFNxNc2X5iAfVLdCqnAfkKSmIiRwzBhcrPUWe2iCbr/55O8KPP0/NODK5gBCoVK3VSS3zZOQ5jGJovWnRLUi0BbfKOaOSP2RujHIyjUB4Hda9y5QDzNISAZ7QuhmEpxLrjoha5vXe+uhBfHbSKux/O2ddkemhThJDgnHwXt8jLxxdeP77y+vrK7XI1VwztOFNM9NrxXW7byVt8vTldzGnXfAzmrDpknNduxWDK4zjY9yL4J0hLM3oz/Uxnf+xiYjUxQFNekd3OoLpov2ZSku5oOV2siWTO1czZ/9Yw96fUdPfUESmnC8E+6PcsC68+Om0YJF7reW2XUbAgBY+PIjBowT/NIdyfOV5Ka34SWeZX0OiiyovooCn6LGZ2jWrr7PtBb01w2pzMYVPi+WUQ15BR7HXbBIdG2fKsIhXtPZ3FfH1u9tmtne/KjWu9mX1Wx2f9uZgl9p5ObMCUA6PKgcMHGE5TWTC9mTOExBv7sg0xAmGe12Q1wjEmLlftvEISjNy6HFNKfUa1h6AwxzEmvsrtvjt/+gXmKBg2L/lBksHt7J3jUKFrdTFdYXrOxi4MFcEcRf8XM3ohKstHcZKanHKWZGe5T6x7ZMkX9P77GSn0T71+1kUKpy54slT3T1zbiwlOjJHblvlwufJy2bhkmRuCKJ7vZUeKcvGS1vg/nDpIjKU0nfvqmJ4Waqjk1jHsbnZSi0/WajIgrz5zv+gVB9pnTGh0Yg7k60atDxiDQCB5x+bt0GbK72ooMkJ2CJru4pZk7eOEBRM8mAbMd6NFW/cygJSSDRX+ZH/10flSDtgfjOAs7DCwIdh
jIAdyTZ1VzKXpCF5sw711yoQZolh8VrCnVb9hya4DbIk/mV7ddnS237DP1TvdkHEO5mxQC7NV+lFMQa9JNEevpsNpR5BCJHrwQ6yo1qvYUQQ7tAURlSaafLeO9OhynC9TGVu1aqcZvONy3bh9uJEvcm9YAtnaO85HSbIoRCIv241rvrFtWVPG6By1MtrySfQMyzobQ1qa1pqSmHGic4eoxIluAs5SOfadsh/MPi3PK+oAGU8t1yIceB/sWXDW9Dxp1crbUlRL752VkTAnpvXRATINgvFWuHDaO0RvruuWrDrnsCIwzRjZDiLb1bqgmJUomxTtpqKmi2wQ5Zbl+/iU0MxTFnAKckNgRCcNUq2ScDBF/19ejvP55zWJGHPQe3Bylj/TjlOU+3gKHK1TsabNYC2smTuznuzZX/ZIrndmSnoPo+ObWR5FHfx9ToNW5eZyySoOdTiuOcs6y1zyRWbQCqG2wvvjzkCEi2iO+9OubYhR7M0cwcteqg3F1biwPBJ18oyp2A2c2VHZhdZkanqpGNhSIudMrZVW1NDtRRltIURwiknx1qhv18ylZZYnqvOWL2fyGka3neFU3t06Px2oEeXcEXbEQG/r4v0Tr593kVqvOYUqe3+mUU8vS5ocApfVSeV8HmoOFQmGRGlLPR38V9isN088u4mnaVOCuWAveML71bl6enensK1MwWN+hicV1cmxIKZAvmZ6Gxx7ZX948lS8xy1HXvIzdDD4KOeJ+WQprdC9NSEMa5Nba6KF2r+vdhhqHfKMBvdB5IraKl/2B81NGrLwSTFSaiOmRJ8qUnU09lp4HAdtTKJPjFr58v7O58eDeymnKwdoggpOB2QdTUtUVKymX12EzD6nWx2kpt9orgK9Ffb7g+O+M6qEpCl4blvi9bpxSaIG502x91vetMev6nqXxm25kswxBBP1TqmVx7EzJ7w/Hnz+svP2tsPRGDGQL3Kmjjnb9K2mZOKoc3JYIRLpYeOS1am3IKufFBN7OKyogVogI/IskGROwWMhWkqpM1fwwTCngX0/JDKunSs3RatPFRofg+0M9eW92QytZcX0tqB++vu1Ju9EvSOjZM/xVUzJU4um4t6ffphO9OkV79FaM29IySCqFSWcLe/HgnR0eMckD8iLkTyInl7nc1qx77miYALWmds9r7WROym9Ho9z3SZASUCiF1TsnXYuCYuBD9oXXS4XxhhcauMYYpypKD2f5/U605N7l4GvRU7oU9LX2ls5JDdIIUoEGz3XTUzX2eCSMzlqWl7NxcqDG2PSxsDHdAaOxi2fPpaw7L7Od2bEGjubgonyp2QzfXRZRvWGS1+Jey3ZIPjnFL4/Hhx7xcUoYkt9TpnetHTDVh4xBubwOnvcNIKTptFenm7065zSe/fM2eRMYz+DC15szP5nENXhQxKstKC42dcGR2yxIE2DtB7OZhvBRaJ2ajfShhJq3dp3WFvppthHOEwvs6i2BpcYo4Y5tTS3A9nkIow+SUMU8hAcPkwIsDm7kTdHr0niS7NO2ZLjeolcbTeC84KvamOg1Nk5BiNOelNujk8wmyPOQbFOuR6NxyH2WK3Se3XTe00azMScMq0lKqsplYOjVi45U1uhFM8M+vv33nnfC8dRpA+h0Grlp/udL+/y7bvkLF1FbzxqxU1v8SCD3ejHa5paAX0gqcCYMtR1cdKqqLK9TI698nbfqa0RPdyi55qFsd9ul9NlIUc5r08n53plOg3oFdcrsyvocblNzDl4fzwE+d53vnx+Z3/fCQXmmlDckhTIrkk7ERHiS915K3f+4uMv+eu/+ld89/EDIWDFr3AvhZizCvQcRsdVkrAcPML5eeQUT9r/0hH12Rk0Whu8vR30WrU7DJ4+J5frjeg3trBxTRu3tBFCpM0ObpxMSrDGYYgtOaeowaNL9iCYDGaHYVO/ob+22ww4ZFrc5qCOfjq799Y5du3SWq1WpDLOKO9jNoPnGt1EoWHzuOzxm0IS0+w0rOt3wSB
6uWov9/gYIjMMe45FgomXCzEfbEOuDLs9kytML0TLn5pq2nwWxL3FjaMdTFvyBzT1pSjINn21v9MHo+K5741s0RtMTxsi+ei+i0Rv7L6czaezPq2oZjcCSjQdWsTjpYu054Ggc8IxZBhtO6rlHDKnrkFwpjsbykAbdKtdzog62tfW2qljkoI1q6PqZ/JOZgNO4aGtdo7aGKanHE2MxeTlbh+8/jxYNpYHXJBnQZT+NMiQDDe97Qg5Ib7W5fyic0v+cW6a59+fA9wnJuaTZi4Gk8Fxw51kibzsYqw7CMEbBVZpUf38PjbyGx5/qt4nz79j3QhfiRTlNuFOt4iBMwbTpLVA7hCzJzuHz8ZKmmL7xM2Tr4H0cBADOXpebhc+Xi8EoLbGo05c02K4oJ2Lm3Jzl2BSXX2KUVR8JvXo3N923h+FUgVBuTHPjmtRWOWIoZya+/s7WxSlNkfHl7cvdAcVy9oqVVCV+eAdtfBlf7DXKowaWaTcHw9Ne8FYSEMizD6HRNOlEqask2IUK24gtuViTZbSud8LX953K1Ldpii5kSfLdsoG3z4NTeXS4TGjzXrQi4gJx77L7cE69nIc1DF4ezx4f3+n3gupecYlmeCzajFtnXs1KLH1wqf3z3zIr/yf/u2/4ZWfiO+fmTMQXOWb118Rto+4/SDWA0an58lRd3qVYWiM0RJcnUVf2L4mJlYKrQ5IFepSdm6Ph+nvIi+3yBYyl7hxDRe2kPEh4sfAexWUPg+CFT0z5zJPNUGQc8Cw+woQW68jWx7kuzb6oM5DsGnvzCYK/ijyhDyOQztcm0pA0HIdA9dFkBAtHHPDUHPnzM0jBH9CdpoktYcBxzJ1DtHh3STHZJ6TEi2H5BgFYrT9juO5z128I3iiInYdmbY7C57sjZzwFZlg7V2nwdWyPqsED61X5hQhKDjtgpJdtxk1sTnvn16CBjtK7Do549dt46zEbsFgS9gbY2I5t+csWcVKxU3R42Y0sq2ZX/cmevtQhPvj2JW8vVh7Xqa/q0mRVMVMjcNiDqome+foc03r4FMQSzOIROMmhnw8U3qT93Rk1VVqoU9LtR5rl2gMbI23ImGNQW1/BkXqnPzDkzorboPGymTpu9dNup9l77EKUuud0jTKDzcUoe7Wn3cnrkvTobLsWITzy7beOa9pa+F5qACst9G7ozU9ut47YneMNMwFe+IS5Fvk2jI02JyipF9uG9kHaqnEfUrhXiqlmXeca9zjQeuD0kQJTSYYHrNTjs77fWc/noUlOWeL6AkeQk5y6liMRTdo9eDYH6TgwQX2XrkfB3trlvYqrLm1xqMW7kXu8S4E8/tbrvKF3nQzjgGlVUofPCxdV/uXTmiDS1IonDRM0gjVAY9Sue+V0gTbxeDZcuSSbb9gVHW/lu2WriuNT4dWGaXQjiLDzkfB9S440fRWpQ8ee6UfspeKCNqbk9O9eTLpvXBYNESrjevlwn/3V/+a+NM/MvmJmjbi9i0hFLh7Xm+/ZkbPvQX8lthbJyVosUq0atBIxyjYi5psjVRO6YQBx5Q9Th/KwYpRGqEcEpeQuCT9ig94OmFKF1h9oLpAcIvJZaxDg0HncOZgNZUZVSWmGyayci6YRU40Ap0Ys6M1ylHY953jKHrP9sh4ez76GNAkYK1NBrOCxp97tBgDvUsz42zhvhpP7F5dy/lp+yERBBS1k4JnBHM/iO50S5DmTG4WARWOYNPrevaDE7kgRjF/5eahPVT/eo9XO70Oy1DSQz2Q9VWIkrOEFAkxGqlJ50ApTUxjI1203qndJlhjt42mPWEt2q3l7SLj3RTMSmmerv5ufW/voM/T3UPeno39eGivtO+87w+YgqF1/mVyzraTVIEW3x57v/azWrM9cDQ6ja4pKYBbSQN2rZZNlL5EiulTGizn/clA9WPi7T4Qc6eD14691T8Hdp9/LngtoHRN6HiET+cY2GyiSqadET3YbpzR1YX4Yb2mMVvs+60dgn8CwloYN01gYhs
bbmzQoFvCb9MBtNGZ3eEbxO7o069UA3z2bLfInBt+eOKYsiS5blxCICUdaO+hEFxnmmtyb3pY2phcnAMXmbMz7MbpDVkhVUvF1NPPWAeId/gcCJdEumTLDBLU0ObgGLLrKb2zV+2jtJfQh9LmVJCjERDCXJwlTmujOVaA+WQ/Dt6Og6PIG9GvTpkKW6JctLydw9HRzX60wVG6rKCmXc/gDYowiDSoa+29Mprt4FrhKGJ/9eOgHo3jqDzuO+0o5OumImRwVdkbVGOEYS7SU5Hq0uBUqj9otdD2A9cG/+a3/5JfXyrxuOPjK4XICJHoL6Ta8G+/55vtxtwSnw4RJOY0u621W2TSprD5OZ8Guw5BQzmKsBHWoZ6ewuoFY28hsQVppHT4gB/WrLlAxJsHm5qslcslBpzBMvbVhiCi2joTuXhH7/BJwvkxxbCrXfDtcYjQENfhb9D32oHNqkms1PnMPnJP1GLJFQYQkb/d6NIdrs57HcbDyADawwXjCHlIUS4WTgiBiAwLKl82T0tUO03rxzNFe/lgWqOjKdMMb9vzC9a+dTEm9Xu9uXd4t7RLcvioVQVuhTL2oUnq1DdNjP0qd/FSm2nX4tlI463x6oKK9e/177yTo8xRK/d9FwnHIWu3USVLiKbxSkkOMAazq6jIMcIlR8zhvD7TTbyPxnhtuD6YLuDCZIYhv9AFrRrJ6mQkOk2hJ2ttSJ8KC6By54QI8M/U8v7Mi5Rb4JoO2HMUssM0WUREDlF8f+9wLuqGRpYdYsU5QtJENGZn2CQltpP2F711LWMDFpLX5Ijg1TnMiAR9RgjQiW3LXieIrvaJK50ZHZt3mn68YyaP3wKRSJxI25QjHSht4JLHJczcVLuAVtXrYKJTlpAUcC4y6LQ2BTkOwZPqEqfeX464HPGbYAqfkiaqEDlwjFLPzJhhzEE/LDrDq5B1M+eUaa1i7NXNJi4xWdpxo8yuK7QW7bWjhar5eeUkB4Bhi9/5JLJU62TdRNcyBa5b5HoKKAM4Gb72UdnrwaMcHPedXoo0Vo/Kj/edz++7JpE5jUbteLzttEdjFJn8EsfZZPTeqY+dUSWM9r0xSuHjduVff/sLxuf/xIieR/hAC4ktbBw9cZkPUt9xDlK4koLny14o/aD0Qp2mcentJCksOAcTkvYxBO+ERHSRl+3Ky3aVhxyBGPKppwlBxAlxxx3d20mCpsZoz4K3JsUvJhvTPCE7rVRqV3PQescziK4Tx4Q6YQamTdOldR6HJspaB9ekMWNpFmUv1pmtwIC6N45HtQnXIMGxDv4nshG9Y1ihXYw0+dOJOYgdjiuIL7pg0S5LMKpdpLz7IpNOWEXBmnZB+HrW3ZS5bnD6TOf0auKmGrteB6102mH7uyBCUDMXhdb1LIWg/RdGJW+jQ5MLR/KB4roaA3N6GFMyiNoHR2m83w8ej8OkGZbllBO+D/K2mZbNtEnBy3nc9GJv7zvRO169nCJiDDi36T1FMe/yRbD4RFrDmLO+7xZIl8h2ZHqdDNfoGEkiB+mzknbSNBEeBl2a1PG1y/0wZwkYDuoQSpR8MoKZhSU6u65eD1htfw6TlGG+Cx/1MSjPZS62ThD7xMgOIMHgdAurtRvS2Efey6es2wPUm1GSZ1C32YY91JNZZZ0/3cSnyPDqQN2cJKN4+9DPrmjOrvyiMaCo86t9Ej36vlN4/RzS8FR7GJuDHgbdTxrypDsjSkSyOTuUENTZYF1pt5C0k6ZrWLvLCZ9VnFyOhOsmjdPQPqGPQWmVMMZJJRVkovBI5z3F4E4VRU9IyWDXcLKHgrG7XJzcLhuPcvB4aJpx6P06MFdsxxbMlXnd9BZIuLRHOQZul43X243r9aKk06jJrttE8igHj2NntE4tFT/Uqd7vB6Xq56m10p0m38f7Tt0rrqsAhyAN3HRm22NfvRSl9vbJtx8/4B9v+LlxjE69aHpI+6CHgy/Dk+JHUndsfvIxdL5vd96PO/vjLsPQWhVdEiPXy4W
0bcggUoyoMadp1rTYv1wi16zDh+lIxgbsth8c577RWIRoWggLjj6p2toN6cAztwOnAuP6hDqgdXJ0fHu98HqJuBipffJTfedepyx3ZqD1eQp6YRmcKuVYOWwNhjM3/vYn7/W8h203urRWq8sW89Gec9MNPSnvmnaw4tPNCmlM5Ugt8ombFpi52IBOeIjiMbrdm0u4q92YNFlYyqxQi1qqkIuAnkNDURbtfn3OzZiCY0x6adCHAg+9N2HyMDlHx3fZej2Og/fHg9YGi/7uTVTdh/7sSk+eNj0ty6ps6wYfEt5IGtoFyQjWBS+0JAbr3cf5NdwUTJnDSWOPI+gMW+J8NAioYRA82OsQzDvgZcuUUglRz1QzFnG3M9A7QbpjqMkMi93q3Mk2/Oe8ft5FyqnDd0miM4K8tuacigVImqCC3Yh6gPUwdz95n419SCuzmPzeFqcTCGHih9TrbshdgjEUUz1VrPrUBz2A0eXGPqIKZIwB59eNa/ZEY+Kaje0NmpumzpeotTPIsVOCLvhMSVY+yQpKHoQljrsknIXiLQ+w4IM0X0PdcDAPQjAfNm/BaMkcxEPA5yw9SBVk4IO3YtXEKpqOHBM5pPOzzyEqzTUeuKQu73W78pIStxzZclTXFoMeijY48oVjq4zSxYYCW55LW9aZmiR6E1TXZG/lRiOAjC+T107KWQT91BK49srRC6UeEluae3mvhTKGkgGa+br1ga+eVotMSWvDW3BlCJ7oJ5FBDsCYhN6gV8ZwUCsft41e32jxQicyRwAXuBMIeCqNNh0DTxidWxz8OsHvfzj4sh/4ceh9DrhEr/1KCsSsKXWs3KQJfTacn1wvN7Z0Mbg6nNDenFOmjKGqyKGH39u0EdD+xRk92c3JJb/Ycn7inLp3R6ANR/Kdv/jo+O2LYxt3aYHiC/nDjc+XG/+uHHxvQXcT8AT89EascLTiqLsVTdH2LFJm6KBvVvj703HAm8h5QWHn840O1nNfZUWqVZFJerXIm9rxQzCnmjOl2jJF9uilMnPH9SVANynCMOq90+coSv4iTDTaaBxVpJveBj1O2wFPHq0R27KImqebvw4CL8s0szKqVfupXq3h6RXIjFKo9c6XvXK0blR22QmxrpvB5d4HWnNyz28Tn4LYvlNEo2jQb6lr/we4oXy94IyZ2XgcD452iC62kCgrQGrXu0laGmMEejeLrdZpe6XeK32Xl+bjkJTDd08bnmJGv97HUwYzArgUlu8sYYpFulw4/jmvn3eR8t7SbL+C2NDBl4KmqBSjVe9nfkvrjdoax+wn1i1diGHGQ1327Bbxbl3e6vT6wqu7QY1G9/egachk/Uq55aSProwi55y0Lw4ZWHoxXupRpNzukIYgGqKjlk73E1LQV1fQXNyCYMqlpzIXAtq07tBwdnvAnZBNYoiyronP4hC8J2Z3wnm9d3bThcwpNqG3Q3hZ1MDaHTzJCwrTi+Y/luWI3RXV8VaKRKu3VxyC3GKKiqdIWXsDm46n7byKTXfPHCQTrlqnNsyOZxhE1s2GpQ8d1q01StP1Lk2sN9ccfsrYtx6NVgahTZJ9hjkpIj2HoPdZRZ9/HP2EfXtIzHRluizG3IQ6HL17QrqoboxIi+DGzm8/vPLD58+8fXrnrRwS5PZCyhvbZbPPSoV37YfmxMgqky1nQojGCrX7uDVqKZSw6zCP8cnoAhWIPhRUOYQCBOfJrpL9IPqIcwoVdFR+6X/i9dvA5hwcD8R9nbjHF477Z7bXb/hv/6vf8PnzF374IROj/OdGGxSULhBmZLtFwJlOSq7kMZggvckmq7cmsoEb5vRgMD1r6pk2VRm8Z3uc3gZlr7xZkGcvjffHoeDRMf9EuDysqx9uFadJ8JJE1CrXg+ifFk+uW+Hvg94qtSqo8PFQwnUN2nW3Zo4mFvk+rFkdJmzWNXPayRizslUlaLdSqcdBCIljf/D+eOfz+126Qi9HCjER165naaA4z5zRBT065+WOH0WcSSHSF6p
kJLHl1uGj0IE6KkcV7GxPm0GpU96DQasLaUGFILU6OB6Vx1vh/a1Qy6C5yX2rHNcGHo5WeTwODkvvnc6D6ycHzS/tmskdHJzn8T/1+lkXKTBqaVg2RdCHoiZSFBMsp2hWLwt204NfezdfKYxd5Jc4wjBs62jm1E5pytEAOB+CZbD5taXJ6DJGHbZfUWKwwxGkuzrfw+oSoXvrZoAcAmU23saD7KOW62NyuMoIE58cwbqbkPS9cdN2aWZL1IYimtsQRdVo1zKxDEp3TcmcNPx5qMuPy58/Ux2N0quxrpzt+aJ2d2D7Pwdt0EbliDu1Xamt4ZxNCSlRnSNFwVWv1xs9TXotjKSQt198/Mi3Ly988/JCCI57K7TexAhsMoUNrML4bCb6eg9OEBneKxLE60AeEzpOuz3rajHNUJ+TWrXgHkXXcmnpUkxctqx7Zk5qKYDjse9yjL9eYSRmeGGQLJpebhtuBt1TPrJXFdBbuuDrO//m19/y+fOP/I/3nff7DkNecFvelKarrbjiRcxRQbT/cZoiO6cpt7bG/XEnTE0i2xzEmVjx8Mv7jiGyzbDCwNSEeEmRHDcigW1Wkj/4bruLzTUC02fKbNRWib0R3aS+/cg2J//6V9/yt3/7j+b7Jq3dtP1KcMoy8l55Z88GyZ9U9dFNnzX139XBwxKQ/OlrnjIP0eW7SDDvjbd3TcHvj0oxlwMRJOz3guBxgyV7a1TvLY5HrMPkg+I+TicFz9JGDbMfqq3TK1DRX1CtYg5OqFC7RVmLOcSa1Pmga3aUStkL7ai0UijhwXEcvD8OPr/vXCxDTHtGC408iSbeWJKKWml9Qmm8DiEal5jZzBT6MJaxM/usZWel52ece70nrBhOyFRGgkp1WH5706lIlb1zf2883ipH1RV7bIXHtZKvgVHHyVrsYwo+N4q6917G2c2yvOxcjO4/v9b/n18/7yLlnr9oL+NwYerCpShbEovHFqVTXdtERac1dfOjPUV7sl4WkKE7xD/hgKYjcvRxsoSYS80tjDoYdUcPlWMaGymGaC4IpvxGlvatdWMhoYiLGOlxsNNoRjH1MTAykBSqGJLsb1R4tbistcMolNlo+2C/y+9ODCPOQh2jDCIVBqcft42B603mrUh0V1qhGNwhXzvUWXlNklqaN9pRKPuB857DT97v72wRWsvMKAfvakQP7zxbzpQu/DqnwIfrhdfLhddt43a94jwcx91ozAbvtC7G0XCnvqJ1YefeIJFFQJgY4jLEfBzO042NOM0uR5PXc7/IEFMsxWTRFSocl5TIPuCmiC97KXz88EHX3Ce6j/RpfnlTyvqjPHAdtusLhEgZk0Akx1dyq/y3f/3X/PD4G37//U/M6U/lP+hwXS4gx3FIh/QoJywmWyBBl/d64O9vp47rNjvb2Mzl24xmzUmi90arhVELbg5IF1zcSM5x9ZWbf3DxO5NAnZHpsskMdhpGwR+dXh7MUvg2f8tvPrzy75MO5GaHTykd70yqMN3pcD/GFNHHdoyLFONKZ4xmBcz8hOyZXgwy+8EZze650tgflb00dotXPxqCV6c1Mguys+l6snZMHVfFSKxNU3GPKoJuNaHW6LYmP8hSKrVKd9ijbJxmWwSmyVw2KqbTmhPTC1Vm7WDwY3scvN93yrFTjwgBjqqIoPv9YLM03qee7UmDH1MQZatDLhV9yrYNp/27nX/Bh7NhDzmyXS7cLle2nDXCgPb0KeG9TbvmsjGGpCNtdgqaVmv0BAd179zvjfu98ng0Uce9o5dJL53qK7XIecQeK6YhLmJiOsGcQ89hnwaNhj+DSWoJa7tZ4IxuF8vBNSftLizGGwy2U4OifJfWTyt5t7owgxbG6H9y403D1tfNgwnfbPhiiYrX36Mvx+xYAqg7x3lni+bRHbPpL3bJoDRzsiBoApgu4DZPaJF4QD8mDDEUvXP4KeftulfabII1Hp3H24NjL4IknZOq3tzFcXKF2FslNyNxlEp1EhwzzMByNB7
1oNIpXm4NJW24aY4WZoDaS2XGSC+Kos9+kHOmhIPSKt55Y441Sqnsu1h2yXvTq6AFM4IiL1HQRSDY1CvMX3oli/wYgke2lKVxGeoAgw84P4zZ2GlzsLfKo4q55Bhm6zNPbzPvzMF+LZCDYwuR23bhmuW1VrtMX1NMjF5J28ZwK/ZAKboyOZamqtZG3m6kHDmMWPLBX/n1S+L/8Ne/5T/8h//AW+l2r8kGSYe49Ef7/cH9y526V0Wuh2C0Z127Xg6O3jR1zs7H2fgwb9SR2RDBwo3O0Q6OdlBbYfRmD3wizMHV73wMkH3Hu8CdD0yUBjt6w7tM78rmmkRuOdCPz6T44LuXeJqJmlE9rYEPyyxWvoxif9kzOmHlvXWbtvpscs52i63LCXNhEelytBDT7jgK7/uDt8fB21EZtRmRQd3/6NM6+n56NKrwYMW0chyF3diJl9gF17V2ZrTNbkWqVvajcRyDWgYlDRKaZIbXe+unobWKm0T8g/2oslHqcNTJMTvv+07pjdqL3BmaIkZqXdGAZuY75on4rCbtlJ0MS3zBf9UAmOh8yG/PxUC+XLi9vPDx40derxczCfb0Bq577v6Bn47ktBbxYgoxqg45Vwdzb7Q5eLxXvnx6cH87JNcYygqTuHvI5Lk1pu3oWlc6upIiMPEvp3tGP6aA5NWE/BOvn3WRUoCW0cSt4xgORpBwV6FlX9Fc7UNpXU7U1bBomW76s/hgHcqc7mQCwoISpsE70zBVwU8CnJ5OFauLnKuAeYkJw1TOTh+itTIcMSVyXO85EpODILjBIQW4FOI6SImCRaSsV9RHLx2mF0HhcQhaKO3M1cLLVXk5t4+hKXIvB617fOtkb0QKK/y1HBz1wWM0HJ4j7FzTJkuXPjlKtw6qWlAj7O7BF1vYJp/Y94MYI6V0Ph9ym3+87zjTUJUiGGrStaxt82RIOtQ1ngLRVaBaO33vgskMcJMaNTVHJj0OWmiMWniUymPfKYicsOK7YfnUzbMjjcH8yIKmvktKgjyORhuKbJ+tM3yB2PA+iyXqloOA0mu3bZPNVpUFzeE6OThyeeOvbo7//V//Jf/Xf/+3tleqYpHVCrPxeH9nv995vL3Tj0a+XNQl2+5tMGRKenTe94dlB4kR1poW39F2sY92cNSDoxy0Uhi9k2bn5hofErzkSJ2R9+bpI1LKYQXDUyu4eQH3YPpIocj9vR189/HydGfoyugSLNlMcGyfsEHkfrEMwRw/Oj6Y3GNOnB8nC0+MDAmdAZNcaEfUbL941MpRNK0s8kNfxIsp2YjQjSHR8tDk1FChq0UuLAt2FPNQkPwyUZVjva6LwjYHpWq3WYICVFdxYaoIViOVlD7wQ5Bysx3jUQ3CrgcAtRy69qUZsUufzdKSzOU+b8Xq1GEt6HMMi+yQAfQCxFPOCsG8XMhZQl43oV0mvcAoky1eiO6d4YfkCZMz5dgZpDmj9Hvt6NRDk2VvXasLE3W3etA8OmvWmgPOhgQ4m3ftztVETgcx/fPKz8+7SKGxcXUxIUrYmTxcvOdiZo/nzX8u5AdH65Q2rCvRgzZN9LqEdG6KjTOtY+mLksyz4NkpJ9qr0XqdM7aeQ/b9MxKmpgbCpA9P97ICkgBROVApWWS7uWw3Jzv+dkCvMH2gBZE1/HRyTE5iygwjDdQyOI7GfjR2cylwU55fgy6tFU/B4eNRRQprleygoL188IJoaBXXOgPPESajDiKeMQOPo1Luhdm6QgldoxHZZ6M7wAVyvJNTZuJ5fxTue+FxyCH88TiIIXBJmQ+XC2GoK1v+eq0/acsz2Oduu6ZpjDVmfe6lnCeGBHaQyBDaUYoIGD5AysksZKqxjBxzNMJKqk2BbOa+OYoS7gcKATSa/pxDWUbzoHs5Qnoc0U1mkEi21nY6V08nOjPxxlv/TKDyf/nf/Ne46vjvf/wjj+PB/chEFyil8OnxhT++v/HDp3fGUXm9ZgINkHt+6FBG5/PjDSbEiZzK3eB2udL9IPV
InI77/s693A3qg+oDtzD49cfE6yXAzEw26jioY3K0yZYi9SiUVsg5M7iSZqX6zOGvpOMTv/3FX/Cr5Pk9gdF25uw8SifHQJv+ZKw2s7waXo2VGyIRhAR+icxR9MrXJqpK10aQZTfnjw6MQD8c5b0yunaObUh74wwunE6dep+SfLgiSngNAi9LqYo/qQ33akXSPECZzpxAmlIHphNVvWvabbZzbHOcQmg/wY0OTQaurU+6Tbz4dbaYuPko9Eemx8peOo8HtKNCV8s0nQmPTXox2zIG7owq5Kdj8pI2mAymf9LgnXNctwu3y40tXUl+I8Ybc05SnORYyDGxbRe2tNGPYcQOC448qsypa8dvQjKmFehu+4E5pzL5pqc2zyNM3o7OvWjPr02Jt+ieKUPvacU8Qg3QqOTtz6BITae9hTDigXOeLUZyTCQfTgqnRzDOtHXT2Z2YuHQsNg6LyTdYrKJplNC15HWL1mYH44qgfjJVjDFjNzBGyvBObJs5xMNMOeMtrC0m0eTB9g7dvLGIREv0PUqlMXBRFz74YF2SXDSq+fO1IZuVNaGMISEt7umZMVXVJc7EUWeFVuV+4JCrQ7KfZ8gSZiKdWKUgZ7yo3Rymx/ETR8c32aU4z0n37aXQZ+D9Xnh/VPba9X2Z/H40wux8vCSuQVZJ7+9v/PTTJz59+sJ+HCx/M+8UZRDMaXzJfyxX8MyK6kNXsJtDgguBy6Y8nhwDySkZtLpBm0sjJzhWPnpeaaVfwbdiSs3znnNeE6lnWn5Vo9aD4/5Ozhei85Sy44Lj8nql1529VC4+06rjtm38H/+3/w2Pfzf4w/udd58pj4rzkx9++swPP3zm8+c7EZ5xGL0xhiAmQVEHDsfj2NkeTwslvGPLmTE7+14oj8Z7axzjwb/69sa//e0LL7HRp2OEjb0nik00MST2x8GcnRw2NQkukYMIREdttPs7Lr9zzRmmBLQKj9R1qm25S2jP5Jbg1TR8MrUdZ+7TMInEcpB5OpGLIt2Nri4XiH7utzQ1irru59pDatL2ze6Hbj5yTchEM1aptJBz/T+hHYbtz692oeUQ7BeGkT6G5CTdBP7VJjtnvp2Lcj6MkeXWWDE9tWjHW8tg9sL+ONjvhbob5Gm7u1obOX4N9XXKYU4z9jloJ6bCchI0bMeUUyKnjS1vpKx4EMlyKsFHlueoG6KFzyGHmlpEK+/rvsmBMa1R7P1s2O3SyMncUJC9Sig8DLKVbZyy6VJKgoyrGRr4IfTr5Aj/l18/6yK1orLVeniwaSWZFYr2U1qWi+w0T5roMJHeupmwJm4BQX7hw3M+/+WCE8AOsGeRCt6xaLLypeqMJUJc09rQchfnCRELFRRpYqA4C6rDuX4ujt3wUuJHmG0QboFSGt4Jd96yoK7cMt4XjtLOKIXVBblp0ChPaKK3ji+VMCagiakP0atnjkwvurPHjAwMilAAHOCXvYyD4Nmi45odLznysokKPQbUNjnaoNROPRr16FLxd/Hu3Kj85Dt/2Bz0bwjA/njw/n6nHIqe9z6cxIhlAiq7GO2mXJO9VemdYSaaOCOFCG+V39sWuSY5FLg5uDsdkthnFEwisIyIJ45msEudImR/vr/zm48vJnTuuDDPB07ekZ3j/oWZNuLlSm+dchyadIDuAt5tlNL47vWF//O//df83/77/8Cn+4OeN973N3746RM//vELjy87L9t2TvG+WfM05HrQmhqQUgqPfQfnaEN7hMt2YdJ43B8iGPTOX/76lf/uL3/DLzZP95MaNlq8UNpkNBSuaepvj6dVS4H2ScQjKxi+N9xobDmdxUkx75xT8PmIsNh97ul9ORbZSBDV1/KQBWnN80Gbp9XZIhKcRJLFyLPfuXaNrXd8gTLkMB7C0jUakcTeozPykQrTeLINFyphTMJWO36G8z3MAcNZwTSIz0dvkKCmiVGroTxrhYCkJG2cprv1aJRHoZvPX2/63kxnxrbdGLpTWq067Jyzb2i+i8O8LRf
E5hARJSxLqLVYX2Qr2z8x56lha31y1MFR1fBX83A8zQBWQ45+HCEdYqGOGThK47DmwcVwrkLOSXnlZ/VJipKkPA/T//LrZ12kUhYspoW64IVgF8hZF7PMKu3c0g29yA/WZTm7uU8Y1W7u82FZN8B/9vo6SXSaO4Pv3mi1QO3MEGwxa6wze6Q03UEMDufG6d4wjWZsNAqb/jopbvRQcT6TLzfAcbvcyEk7ljkmLuzc94IP6rCdDYXOL2qvyrDsa+rJVnROFPVZq2CZ4Jim58oxwISj6pB2KHRNECPGHBTEmYLnuiUu26LQOvMEq9ahyQ26DcWGeNfY0kby8Hj/xKfQcHPy9vbg+x++8OntLlumMRXJbRHZKSnyoI/GvezQo2LFu1n7zGGuzWL2dSCnyDcvFz5cMr1XooN9Oyjeq/MXtqNrkqQzIXrKVKbS3gsjwO++/wP/q3/x2wW3s04g5xwhJLbLxvunL9yPgxc3SSnR90nYAiN4pkuk7RVXvhDazn/17Tc8/tf/mv/7//A33I9CKY1Pn955+/Kg7Y1LjBbyV81wVI70pRYt+12APhSU6AvLUmnfD0KvvD3u3O/v/OXrxv/ut7/l4nbKcHR/o3Dl/X2nHxVXG9Ep3iSZ/nDfD3JIDJR91Jlkb4hAl/P96GZ5ZQc7Xlqs6BWDPjpqFG0ntSai0QWNrmZv2WSN/6xAYVZLzejnK0jyjD5fd/VcIY/GhlvMOOcIZpsymu00qyJHQtTZsN576AM3LRm5T0od7KWdiIiKFCczdKUHj644ChUgnRfOLVYbkJSQEPHMJtcGRud4VO7vhVYED3dzwdDOeGVC6dfjUWllPgvOHIL/dhUv7fj0zC13+7HguaHz8dlwGxvZUq17UeNYqpIbBvMrx5pxuvisIj5tDz9RoZrIxmw5YQTsTOB5GZ01gJj7BWP+edgipQwxy3ajVZUhHWDe9i6Lhtqt2+V58WuDaRMTjjn6SZIY9u9P2MF5o6U/X1/rOcaQ75vo6Z7ehrmyc47wA3DDn2NJDMPcxw1qcoEx3Ql5nEzA6Zgp0hjUoJvgclF20GWTY/KKUFjRG9OIBoz5JI1YhIc05Rq3YeKNROAcdIfglTnwY5C2JD+wKVODOdWtOtNTMOXqoPRQuTUsun2IRiRwThNAGLjoIekT98ORXeT1tvHdhwuvOeCq4JBPXz7z/Y+f+fS+U3UxiDFw2TLZHKdTTuDgqEVsOtuLNSZtNPpoOtybJostBb553fi4XTgKzNG4xshn5ziMNktwhBwJl4TLkbHgKWFahC1xvL1TemULAW9ZPRjpYh1aMXpqPejHjrtciOmqB210fIrMGfDoHgnO8de/+RXv73f+3X/6e75/FI6Hlunq7hV/8L7f8UHO3o9WuQ9BplvKbC4QB8xqAZelcoRAxjHKO//qm8y//eW3bK0o3DJfaVwoNVCOO3FMgksMBqUdchEIER+SBNemu1oxGj4Eem+UQ1qvc09h0HaYonOHKe+6LSSSk3D+qZHCUgTgdHZ2q/NWpz76wC24rz0TgN0KA5xfMeLs3mxdMSwAfXqGPYOjdTpeOVcm8o5ejeBKTphj0b1lm3SURimKwBiIcDXNvaRj76kNYc1Oa4fZxle2bI6cHBcXOWbnJXl8R16AvbI/Kvf7DiYPGaNTW6G3wgie3gutKg7l8VgEGd2RAafA1EelHIVauzw1q1H8q1iLaxrC9utLbCxR9TyTDTCoFocIP0b9GkZd/H+T92+xtq3ZXR/6a9+l93GZc6619r1cVTa2AZsKTrhGLjhCEYdjcsTTgfMKRMrLsQwSlweUiAcSBBa88AQ8ITgvCAkJhAQIcREhCTFBKQgEfMCAjct21b7vvdacc4zR+3dp56G1r4+5ubkcCSlbHtZyrb3WXHOOS+9fa+3f/pfh2GP3vxWd4DID02GaZVoORnSiNroAPVyDKN082mBhoa4/D4rUvItMc6J3paSONj/YYXvDbeIcOwXHrX3BKTzRV21
LV2z8958hPhFosD/f+oMxQanT1+1HOWvObGlUlLD6JFcqks1fUPxCwOQxV0zcYwnicD5I2YpXgt4W5mlHqxZdn9223zy+LFK7tOJ7NBgZNXTddjcS7N9YsfJdgLIVykEJtTnJ3oGmHUIkz8nCJWtDcdKH4AaSiZxtj5WiuU2kKY3tHJLEfk2RqEqKAIG7OfP68xte7DM3cwaUD+8LS6mcl8XgA1W0FboOH0bTvMUYIXSqWtdbe6eLschiCiZ4dsZXjJHdPnPc79ilRJBs7t3uW4dDqyFF0pxJU4YU6ILnHVl89kUgHY+cl5V9iEx5BrWwvhDEJ2X7nrspI71xun9FXBb2x725r3drYrSsLNLY5SM3+4nv+ra3eXV/5usffMrQSYRgV1trK+u6EFKiNuW8rlSBnCLHecdxtyOnxAgYrLXRQqDEyHe/eMF3HCGXVxQCvc20JJyXFVJiyhM0YTlDr68oy4LQWXsl4wa37hbe1XcSvmta19Wni6vOMGowiUVtFq+BOKoxoDT1Xc1TSNDvA3dHGNfuiO5g3B9jv/yEcWtR99HTnu2gX1cToPeYzPlg/JzxM2UkajscJrJp1DSI75qaB/MVgwlFqKYYoKpHWbTmItYh5O9oVbvnQiDEzn43m/VXXZlTsr3u2SQBp/OF02mhuWmA7XFcGN4rvVu6bltXg2JRC081pQNlLZweTzzeJ86nM6q2I8tLcWNmK1YtmojbJDeuDxtJzcWaB8Uaot6Ka8CNKGGMWjsrTCCsWyMx9neCmZ5fSbMGcW6rlQ0z9GPTP9tefx7AfbubxDRjS7w10UskI54L4xiy2/e3hlt82O5ipdNUfKQ1rLY6VDDooDCU2C56e2LjoU8mq97tUA9RQDux283Ua2VZlXVphCkx7TPSGiGYpuSxQW6RNJkIOUkghUQGJoQ5ZNBI0WoWNL0xvLZSDKQ0Wfx6MHPLpTZb4nsWjKppXIz1Y8SDJlDUuiAJZl5rBxD0kGxs7wEJmSUEj/DOECCGQgtY0bUBghADyW94jUJNnR4nu5F7YxGlTZFWuqe+2rQSAsyHmfxsR9rbRNhbY9dnQs70MDk1fPFsGi+gQSAqPRSGmDE9TbkL2LQiBRZBspAbTBOEbK4UQc3JQx3CkW7Bj7uc2U+TEWGSJ6JKp+cAvXBzCOg883K95/UM0jKkPT3Odm6fTqhEg+SWe+rpwej080zqbxD0GWtRbsKF/vAxctihKaERXp+V7/uO1/hnP/F1fmJRtAcihSAmQiplJYrBVKUWYkzcxZkXu8iLyXKNTt2myIxwt5/50rOJd45CuLwy+Csf0ZSNip+9uetmODvNKw8Pj4RqDtYxBWpdIASyTNQUCe2RLI3KhEpCwkTuK6tmOhHqCfRgeqRqE2jturkRhNYtfLKsTL2iOtEQJl/6qxcw2zUZiSZHm5BsbxhIYuLnEG2Xa1OcQ64NqEaxbi5kHo2eaXQM9eg0gnRz3Qjmvt56R2pFyegK7WyhjpfzSilKFGs6exTTSYkhI9qUQLLmo4lBZq3TTShktmw00i5znDKxQ1sKl9PK5b7y+LBsO02DKo2kYOw9KGu1mC8BTRaeuhS7dh8vjf7yQr458nhaUDWoO5TCoVuad11WFi/GvQW0h00XVi6LfTbFTQxU3KLKPqfzkukdLpfOunbE32v1z6qVTo1KipYe0XpBtaFq5C5pgdLM7d0sToWKoR/FIetv5fG5LlL744Hd3i7mEpQCTJLcsTu4d56NK8Obz6KUvUuq5hY83JjNEkjdUsZ3Ry7aG8rs61J3HHK+7VJ3l57s59MMNhCftFI0y6QRl2AQYadUSMXyXaYciAg9JIbosWul1mIjfKu+u0pYyzIc0Tu4gWivpgEbFjgDdkgx2sjttikEQUNwTYwxtPDvVVGkdwtoZOwOxPBkFC3qGg2T00sw37qlCUUDBYhdWLpybo3HYj6JGiHtzf4nhcC0m0nzTPagt3VZ6f6czGXab57gSbHduzKw5w9
Gd43BaPSof21nJzP5cTUq+WQRBhs7T02AWt0BQbDcpJyCmRLn4AehOHvMjXuCXUNFGw/LmeN0YNolGubiHWIm1eY6uUQvjfP9PQ++L3nzcDSYbH1A1oUeBd2ttPXMnBKv397wC7/8Dv/4X/40tVezo0yZFhOrBLIvq3vv7CZzhN9PE/O8o6oyx05dKm/cHfmOt1/nxS4Ql5d0EeZ5z1l2nJvQqxniakzUVnn5/vvkUpD1kfPpkZgCh5tnkM0+y8ICO4FGuTwQqQQyrx4e7flg7gWmrWpc1pVLTiQNljtWF5uGW99E0bVWQqoIkSTNDufevOseTiJOJOgGT09R2OfE3W7mMGcuaTEaelMGLcg0UTaxSXRShC+ULDTSggdjCCZidd/A3hQl0LrTvZtirHKbQMjRUBbftYgaCUCHbKJ16mWlLIVeTFYyTxN5ykQxt/4YIpfzQgqBh9OFT14+cD4v9t75fq8Vf29kNceLUji57lG6btEsTd0m6dR4fDjx+PCI6AzBiFVlWSmzU9ZLRaNDe90g0jFBbrtEf21+F1Fq43RZUe1cijEcR2Nux4dBrn27Lw23ba1aEzBWGK4+HubWAffmFKH+fGD37Q4z02zbzIgStBOqo6mqTqJwfYbDOo1hzWEXdnM7o6Gz8a+2JaQ/dIgQ/w32xGAtIRBzZN7PpMmtYi51Yx9FwcgIzWAJC08z5+fBVtKqRE1UGqtWgq60ZsXVGEMmZGy++wrifmNiOhKjs545Pz5S1xU3kkNwyx+HEMULFDE4RBJc+GnMsOa4pTVUat+/NrsgVQ0+iZhTuQseBeHSIImwaqJguUBFA0UCawwWPSL2XKKYl1/eTUhOkBMtRooUFoRV7TPqGykjOERgrxm86YhY1MJkyazJPwvVTlIrPDEaSJRyNBFsa/bL3UYCwn7K7HeZ/X5imrNlVDmlWGJAu1OKh4UMyqtSCGVlF4ysY84GzlBr1sX22j2ZtHF5+ZJyeYAJLusjc6n0HRAzMU3EPJMTvP3aazybAnUSAnYwR7FO3mCyzhSFXQpM0XZrtVUkZWIzge53vn7kCwdFtHBZL/Ra0DizNrislR5WDjdHFmB/mAnlyMM3vkG9vEKXhWWBtNuzO9x4NAhoLayP9+irj0i9cphmPn04mWYH2xUFzAn7tBQOuZC6C1RbA2eCLYtZYtVSiVO3naiEDeYaUGDw9zlgkPIuZ9o8UQ4z692Ru+OO+1ePFPWJhkFuMKgpBjGWmQgRm4LUpytBDLGIySNP2JCB7pZfy6VwuXj2mfh9MFhxhI2URPNsOfEwUddSRrH9uLFRMdlE7ZxPF1SFDz+958NPXnI+X5Bpsl3SYiSGVpoZwV4uPJ4WHu4fTRbgjjdDTGI7MmVZC8ulMEVBUmZaq2d6VUMwor03vXa3AXOkp4sRjVyAO8gtIkJpjfOymq7MWYxNHe8U3FO0e5Pi56gascedDG069Z81+TQ3RVsDmC3df4Sd1A//8A/zF/7CX+Cf/bN/xn6/59f8ml/DH/kjf4Tv+Z7v2b7mcrnwe3/v7+XP/bk/x7Is/Mbf+Bv5E3/iT/D2229vX/P1r3+dH/zBH+Rv/+2/zc3NDb/9t/92fviHf5iUfm41U3JCkxJaI6ZOzkJQSMTNBXzsmlrvG03ZIuNB1aMR1GCgp/j4cOIedWlg5dvf+9QFEFJk2k/sjhPiglek+fgOUhsaA1obJNND2CTmdOJq4lDt0Av0bJEHOVez/HH/u6GToHXrTKdqr61VltPC5eHM5dFGeHGqewhWoKYUnHQwzEexu3K8R9gFpmK7qYYdtt24Xb5Y9fcCgzyqG0qqmKo/BStQVcQiN4KiASNajG05VjxDMsf4JrCo0dUfa+dVUR4qrNXOg5GgOhblDJFnEEiBNEN0gscAxXvrrLUzz4lpitTarViJUurKebnweFnMaw7Y58xhntjN2SKy6TbJtMKUJp9onc4vQpfApXfOZeFZX5ny3jQhIdLFGEzkyayigkH
A6+OJ5eGe/es3xojcHZnu3oTdC5j21DQzzXu+93uf8f959jqtVE7nM7U1Xj3c8/LhnosHQpZSiMA+BXaitPWM1ALa+OIbz3nzJqMPH1G0U5ZXlKVQmyA3R477W6QKLz9+l1fnFQ1CeXiFvnppruwpkSfbcdnivNPKCpxo55eEcuFSCrJWHs4LkhJRI71XchRKU/rS6FPzdk9IPZAJ1HWlXCwxudbCVFd6jDQS0iv0yhDKB98TJgl+sAXifoVlgnrk2e2B9z74xJq2oZGQa5T8lCNxzqQ5kkRIyaD7GIyxagQfuz9sl9yoiFl/rVagLIOsOdSPwc8iLpG0fXCv3bwr5cpw662RhixlLGEw0e66dE5r5b2P7/n45SOlNHJo1Evl8lhYzpX1VCmh8Hj/yOnxwuPDmbpU85Ds3e2eupGZcD+/0qlLI3RjVFqxs3OlhoJGi2tv1VLFy+rOHdXgSZuErbhrj/QurG4kW8vIKxM02GtRZ0SXalOpmXWrv1e2GyRAkLzpVoMKCUNAhqPHt/L4OVWFv/N3/g4/9EM/xK/+1b+aWiv/7X/73/IDP/AD/OiP/ijH4xGA3/27fzd/5a/8Ff78n//zPHv2jN/xO34Hv/k3/2b+7t/9u4AlSv6m3/SbeOedd/hf/pf/hW9+85v8tt/228g584f/8B/+uTwd8jRB9NwhX+BNKTGHbJHWwDXLHXf2bqzdnLVt0W2H82CsgC3fh6nsZ2E9ezyNwY7R4jNkisQpop5CqR4hQbeLtlWhrRBitkgFBsPJ4Cmq8riuXFJjzco8dVIyp/M85U23oV3R0NFWTYzn7KWy2EWpzRXyMijyht3blKFOV+9G5JjMCSM4bjalCUVY3Q5G1EgDT8u3KuZjdllZF1vsFun0FNjPlvzbsP2XxhEnLbbf8diPIGJsRa2sKJdmDiAvLxc+frzw6mSC367dcmhEnzQdTn8dnXIO5Gmy6BEBtNNrpedm0dlTIkYzs80SqGJegKVYYc0hspsyUzLPPnOUr3RzKTW3AzU4akDAPQqrKGtbuJzuubk9WgRGynQJxGlmur1lks4jjcdWuX91z/7DT3h+dyBNe/LxgNy8BmlHSAfIe3qceHY78eIw09aF2kBits+kOYYvwlpXynJhOT+yXs60DsX6CL7wxh27/kBdhOWTj3n89EM6AfYg6UgvlcurEz/1L3+Stfp7qJV+fmQ3z+TDDft5T55m1t7IWLyFaCMlQXImxpkPPnq0Ih8MMhNgr2av82yXuJ0TB3eAaU1YRcyJvZh9Si8LvSbokZrdfdtztALR0Ac/NOdk8HZPwjolbm53HA4zeTLX/8UhJQkdiUoI6jBvYLebiEGZ5kipSurQRUhJ3JrTCANm3NpZlouFUi6Ny8mMU4NEUhSD42OgVKGoOT2oKmstgFqQZetEScRUr+cIbIf8ZV159bjw6f2J06Vsjdd6rlweV5ZTYT1UQvRE49PKejIBsKgZLA+JkUmd7Y3SZmavXczjcL1UyupGvhIhKdrF9koV2mqu9cs6wgr1eq5h7jttwHluwmzH44BaPB7H+Wlm2ybYqdtRrWb465OxdNOp9dquBItvbZD6uRWpv/bX/tpn/vvP/Jk/w1tvvcXXvvY1ft2v+3W8fPmSP/Wn/hR/9s/+WX79r//1APzpP/2n+SW/5Jfw9/7e3+P7v//7+et//a/zoz/6o/zNv/k3efvtt/llv+yX8Qf/4B/k9/2+38cf+AN/gGmavuXnk1Kyrt/S1balK2Jdb3NFO2riz+H7VpoHrGFQnhA3oVoQ87hrjMnJIaMn+6gxVQ0RYppm5sOO6TAjsaOTEnQxz6tNKV8pBU+4d/jQfboGfXbtDcHsW84xkCaY5sR+v98cnelWBCrQS9t2LctlpazNtRHmacjo+GSIFg2Pto7JOtTDtLMMGTzDCGhFqdVvIIyOvjnBdytSw3Os94aEThQPN5TuOxxzzwbdulCDNxsqia6FKdrBtAZjHL06nbl/XHg
4my9ba3WD+sA66yfqC0Ic2qls0gPj/NMQkIvBdUFJ2GuNAkWsUarexBmLMlnwYHIN0MgyEnxn1R02vjYsqzZqXTk/vGLOdzSZjXGZEi9fXWiPD+jlTLks5pG4LHz0wQe8/fbr3L79gunuOS3NlukUbAqrtTFJpVcrRhKjp65mqJU5JKYApUXanODuiPSGSoSYDSFYHmnnB8I0I63w8OEHlNbYvYDIjh5nHj6+591//XUeX154OJ95/todb7y4M2amwloqlEo87gyejVAuK6EVymUh7Z/zk9/4JjlPHtMukGDaz0ieeP028my/Y0JZlpXalfPa6GWl+X5Ke8Li1julBNuPtk6QbLsknBChNlWFIOQYDRHQRJoCaUp2SF+WDRIvrVJbQcleqCIxKnkytCQWQWq3EOSg9F7RagnYdItAKaWwXBZOZzNC3iWbsm8OmRATl6VTa0GSsHruVOudU1korRFTIkreBMrjzAjiwut15eziV1VjxvUG66Vwfrxw2p2RoJweTlxOC+vZIj6kR/fpa5vJsqpdz712egCCTYLLZaEsw2uvgVb3AGz++laWZWWtVxeL1oZYEwc8vCnrTxIf/K9xkplitPQQ7fPZvkIjXeK2v5piAoHV7WCGWfe3dM5/yxXh3/F4+fIlAK+99hoAX/va1yil8Bt+w2/YvuZ7v/d7+fZv/3Z+5Ed+hO///u/nR37kR/i+7/u+z8B/v/E3/kZ+8Ad/kH/6T/8pv/yX//J/6+csy8KyLNt/v3r1CoD7c0FpsCq6WFighAhYTlIVpWolqXv21c7D44XTebUbwmG3sX/qQHewOYjSXbsQQiSkaJhs6xYnr1jOVBRChrhPxEMixE6sQr9Uys4C5Uzz4iI+Gr0uxCibV5lIsOmreXBfqawiTCVajg2LQRrBNB0DUrAFq1GaL48rl1OlrQ4LtGqx7MnHb19GB1+a3u4mXsyJ13IgzZmm8OgX9aogTVFxh+jVTDNjMkZYvZiLghV7c4ffTcO81vKpkjZUA92X0q2BGPfSFuPYe7jqgq4ry7ny8tWFjz892efTTQgbvKjmaCysq2hTLcAxBVIWQnRyi4odIC66FLFd1S4Zo9BcCYTSYKmVjDKFhEg22nxoxJiIQMJHVW+Ha61WPLwLrD1Q9CUv3185Pv8CZ93TQ6TnQGkXWE6c7u8ttfhwQHql6WrsxdaZiEjeQZ798zFij0igxh26e45IZE13lMuZuO/UeiKf38dUTRjNO2WQwGRbcVYSVSbS3TvE3fucPn6f0BtzMi+297/5Iee18rhcmKeJ5XJhrQficibtdmiZ2cdMTDMx79DLK6ZuQt2WM0V3fPRJ53ZnQt+QzIk+x85+Dtwcd+z3O6IWzo9C7cLDpRO1mwZoXZh3M2m5EOZMqGbzpQq5CVMQ3zMHJlGSOfHRUDQG6qWRJXKzT5xD42GJhBZZmnApnQWodFI25HWaJlKIaOhMc0KjksV8NvFDNAksvdBrYV0vnJYzp2VBEeYM+0Nkv4/ELDBHLqdKwBJ8L8tCa8p5WSjrgkgj5uAaTusWc0zUUlirUotSV3fOUCuY0g06L5fCerkgKrz6+MLjqxOXc6VUNXd2NWPnVg0uSynSa2W9LAQ1Y2Sq7V0ryloKtv9WqCvl9Ei5XDhfLizNViDmRmOkoiJYwrlbvWlrRAnUJ7t9YZgF+y4/KtOcED/PtNu5p72y32eOhx15DlAVWeSJNu4/MnGi987v+l2/i1/7a38tv/SX/lIA3n33XaZp4vnz55/52rfffpt33313+5qnBWr8/fi7f9fjh3/4h/nv/rv/7t/681obtRf6ZSGUTgoT9vqNSZJS2ix1FMOk19Uiw20K6FejWGDbCvq+SISrJ9+YQqItg0Vtd5NyIs8T835m2s2YLr+7stoYSgFs3K2edVoDLeLwmzsjuFOz0jerGbNMCiznCzFHV5Sz2cIYiSLQamNZTIBYmjGhhqZk/J/tF2xqi+J7mJyZczZ6qtg+Ia+
FpBDVRNDmQK6etzNU5dcUXKUTh6VNH9oL9+6LyXJw2nCk9n2SPIFf3ej38rjw8uGB5XQmNGWXsxUeDILbTZE5G0sveoQKKDHIJspkaDaCw47ORsop2fcTqGU1Kn6prKtBmkWh6sj1Gum/fpMONpQK1R0TGlYkL812S701vvnue+yfvU0Mwm6347E2pHWmw561rKT9zIs3n3H3/Lk5QUvwSZ7rjhKbrgORno5ovoGQ0PkZeb4h1kf6aXHvxEgPwdKVjS5qRQuHmUMi74/cvv6WefyFTC2NTz5+xccffUIicnu8QRBunx25ubux7LPdkcPzN+lpIkgyQkY6Ml3uaZeOHJ/z0WNnnuD2sKM0JeTMbs7sZuG4SxwPMzlnYxEqzKsST4VGRUsnrI1cG5Oa92HoZr+lCrFXpNv1Gog2qdWV3swhHhpdC/MUubvZgyxMj6sJWVvl0jyHzQ9biXYtqlPIRbgKzcO1KW3d9pCXcuHhcuJxOVNqNZhvjuz3mcPBSDWtwtQMcq7NXNVbbZT1Qq9mLRViH4cGDipuOjPf6mzwmgmnozMNzbJMO6zLwnK5UMtqeqlq3pS1WwHSEJBo0+SlKkhllyI0tQa8NvpSzIotBNpqfoHn08LptLBcKmUxFqOqbPel+RoKOQgyChif1awZymj7wxwt3DQFo6NH19WpdOZpx91+YhfF2MBP2LJd/yMXqR/6oR/in/yTf8L//D//z/9nv8W3/Phv/pv/ht/ze37P9t+vXr3iy1/+suUxEUENOkDsDUsEf6O8amPFoHVbqK+lUGoFL2bAZ94w++2gWw76uccSaNjGXFIkTSZ0jdEWzW14YmX7u74EVAzzxaM5Wujuqu17Ixfnuck9Lpl0rUinFbuwg4/36qye1g1ybEWpqzrUN4qBbkGHhhmPEMEGLZLUnKdxUZ16QQhqPyf5PZZEjQEVLECwOzxgkQkdCZACBMwAtbbK2oPt4GrnVDuXarY2A4JWL5i1NZYi9FI5XWwPID1wSBnZmRv0WMTOEfZz3pzFozvOxzC8GZ8wvHr3z8r+/ZTMDT+oWTeNNNZ1LYgkzmXlGTO4SDiO90zN3QA1jJ5qU+5gNN5TqTfK7TzzcCncv/qUm8OBKcG8P/D4cE/IE8ebA9M88dY7rzHvd5TWSCERwkRME4xjwMWUtQktRQh7esr0uEP0AuWMlBOizZKoY/yM1512ICRi2gGdZVmZn73JYWncX1b6/Yn3vvE+58eFTz/8hN1suVOX85m750fm4zMOr72N7G4pMrHfH0k5M2cIjwvl8Uy4fYf78yfczHCeE+fSiDlyd5y5O07cHgw2lRgtR6kVwiRoUtbVSC1hreRSmHWycusyAEWQVhAZZkdCa76v6lYEWjWHkWmGu7uZ0grZBkmWUrg/n3gsM/uaTNwus+mbxLRNBCyDTWy327DPs3SDC0/1zMNy4uF8tigQUaYcmSdhv0/2utaOuVFVlnUlK1AbUTu7KOQpMsVoesiiLgq/NrgwGmkBdZmMQ/BJAkGV2uw+Xr2Jq06rbx5gYuiCxRGpdk7LYtq0oBzKyr5Vi9Koq+2hVSmXhcfTiYfTmcfLyro06uqWU103OLAjTBHmFAztQVlbcCeK8dEMKF+JWdjvEvsMcxaHiCPQySlze4hEN94tbd1WKW14EP4sj/9TRep3/I7fwV/+y3+Z//F//B/50pe+tP35O++8w7qufPrpp5+Zpt577z3eeeed7Wv+/t//+5/5fu+99972d/+uxzzPzPP8b/15WRoShUAiBizYL0SSeCwG17yXXqtnw7hHVaneLT+dlnxBRGBMs6NYRScBSO+QulnqJ7sg04jnUAiSIIHMiu5mdKlO/Sx2oHSuVixjUeniRQNqx15MrtOBBgxptPFau2zx1wL06gwcDyCT3vy5PN3gWKFKQZijBSCadVNwuxgoXbd4gogxD7PTnSUYfX/FdgR4EYjqTgCuZTqXQvHAyK6R1Wtz045X5e3QKBUE8w07LYX
Vd2pzjOxTZhoMqSBMSbjZz8y7iTglA/JVN+V8HKawvVHWlaXYQti6vdGItCdBbR6UGW0J3qONmCHIZ943xZqCViw3q66rwTbF2Gsvjze8dfOML759x4ePK+t64uUnD+SuzMcbe94vbrh7duTmMGFL5UAIGUIyK6wO0Gjd8q8U8z7UYOa6vRd0uac/fki8fIKWE00EmSbvagPdrXskZaTPRqzJO+TwnLsvTJy++S6ffPDSJu4KyRl82isBOJ3PpOfQCPQO+9sD85TZJSGcPuXy8iVT3vPJ/QWpndvdxMulsLRCjMpulzjsE8fDZPZNMbBQaFGpSekZSlEqbQtgpFdit3gaUZ8Cu0LVDVatXWnDnLdXC/YMnXkXORwz9+doXosxsJbKw/nMq/OOm9vM2lZKWwkJhEgLnR6cBBAs9bq2FRqU3ljWhdP5xMPpxOPjwuViDi4pmSfl7WG2n9NX5qCEXijrBRVlHywGpu92RmDARLjaG7gUBmcWzjGyy4FdTqxqjMM5JqYQycFcXcploawra/XQVAns52gIBfaZWZGKHPcz+9mKgnYLxOxlZV0vXBZbU/TeWc5nTucTp/OJy7ps5sqmMxsgh/ku3syZF7uJ0BunajvpeumOTlx/mR4scHszcbdP7CcjOsVgxtc5Z252Ca2VSwk06VfJz3+MnZSq8jt/5+/kL/7Fv8j/8D/8D3znd37nZ/7+V/7KX0nOmb/1t/4Wv+W3/BYA/vk//+d8/etf56tf/SoAX/3qV/lDf+gP8f777/PWW28B8Df+xt/g7u6Or3zlKz+Xp2MHTYVQO7GzdcF5CFd9igKj7Q5x3OpUbgF6Hzjrk/dMdesaFNv/dGfHoZ5dEwMhmwg3BjGvsj6cLowiKqKEHIk52h5pHU7LNoFY0J4VxQHROXFm8w4MIbprdHeNlYynaF/flb66VmNtUJ3151Ni8MIgYcAWE/vdRAiR0pTT2ijaubTCuVbOl8VEta2ZK3iIpHCl8Q9hZN+EtcPZHQiRpVoSbpeASrI4iAE9ioI0t2Syz2htjWUtPJwWYzuVbjfynLn1MEjTVwnH/UzOEYn2uYzlbW3mExYxj7bVG5IxSUmK2/5PGXH2w++tWaikMxBjDNv7FtxwtbjBaSm2lG6LhRTWELh/eIDXVm7vjuTjc04X5WHKnD75BA2Zw92RuxfPmCbMVBeLaYlpgmA7IglAb2hbzbg1BPs7tQaGviDLA+vDx6T6QAoCEmkN0JERFq1YaadroGlAYqbFTjgk7t4ofPrykf3xwOP6sTmca+e4m9gf9gQRHu9fodNH3H3hyJQi0ldC65w/+EmLNJmPrMticFnaGUkoWEMV/TqXiBWmoJQALQV0sF9boyY4S+WilaNeg/pQUxV0Nc2givm8qXvPme7O54gQyLvEvEQjRuSMcEHVdHCrG8t27V7YLMywqvmxjD1w75Y+LcBaK+fzmfPDicdXZ04PF0pxFEFgSoE5WprCIQVus93nUQuzRI45cdxb7GmNgXOFs9PDFTPYJXhUjFlYsotWnF1NQaQTekVqQLyI97qQpfNil7ibJ3Y5IMFMgEeRSlMmpjhMIzjmyDFGJqCXYuL41unLCutKaJVJLXdvwQp19RBFRYjamUR4NkWyCjl2zqXwsMgWqwJYYE8I3KTA8ynw3IuUQZjZc+kicxKW0gndbJrUPRC/xRr1cytSP/RDP8Sf/bN/lr/0l/4St7e32w7p2bNn7Pd7nj17xn/9X//X/J7f83t47bXXuLu743f+zt/JV7/6Vb7/+78fgB/4gR/gK1/5Cr/1t/5W/ugf/aO8++67/P7f//v5oR/6oX/ntPQfehijTNB2MRw3KWEOpvWIFrE+IuGHZ19t172KQWqWnGqu3aOkXSnmwmD2OXtQTeBJMEcLixuPJAQxZaExZVZjLoVoXnI9RXfbHlk7srm0jwLZHVbqHUi++3LatYiHKPp0F2MCLS5G7oZD65WVJsgWVTJcJlorQCcEOzDPS+d
cKpcGp/XCpVYTua4FQZmj0HpwKxnlMrRaXR2etEYgiulpzkshq4mEmxqUIsGgOWLC1Vh2M9TBwMQU7R5XQFN6HK4DMOWIBNN5zVMm+26uD0cR9zA0KLfSavVGpFiMeTSoaMXEokvrXC4W705XYjKIOGYzxx3OIuNzaQ6rNh0sze6O5HBZL3z00cd8cmPEhvzam8z5jvn113i2O1juUmyElJhyoK6PdIyV2DERsHWvsu1MtTWfjiJdzHuulwKXR2q5IGLXie0jrYlorbmHYoAhG4iZWhd2x1vWdeH5i9c4vXXm/uE97s8nvu3Nt5m00dYzIcD+sCcedhyPO57d3lCWE7XDy0/v0ZcfkF97i3sSazijUVklUNXyuCYv8EShh2G9BS1F4n7H3GC3mPdbikpNQomgwd3yRwMhDNzb7KiwQ9N0QbrtmrpgcGcKfi+ay4Q61D3yv8SLXusNpLvGqGNye93cEMAi2MvlTLmsLKcLy8kmLIOtsHurmd7uOGXKYYe2zs0+c7ufeDZHbg9m8LSmgJzxKc0OYwmBPIpUFHY50DSTgbybTZwdrEBIqy7ahts5Ird7ZoUX88zNbB6ZMSbS+MxjMBd2tYynOWeOKbET+0NZG6E3YmvMqtyEwCUleo4s4bqHat1kGRJgFuEmRXYxIkW5XyuRZXxUCEZIyiLcxsTzlHhtyuwmQy5insiTxTMHICyFpGKNvBftz+I8//7Hz6lI/ck/+ScB+C/+i//iM3/+p//0n+a/+q/+KwD+2B/7Y4QQ+C2/5bd8Rsw7HjFG/vJf/sv84A/+IF/96lc5Ho/89t/+2/nv//v//ufyVAB4vD+xyxOZwJysk5ZqDgQSzHonxmiCU8Q0UsOqSDul2RsgIWIrXN+bqDPDvJiYiasvdmNCG4SoG6QYAG1wOS3232IfxBQzPftuLApNhPViui5tyhwNzmrdYD7r7K14pWhQjs0dNk0Nyuam0coJsz7y0VuFXtl0K6omtBVJtrMJyUgMrXE5X+gCl955WDprbTQ1MaP2lZQCtZsgt5OQGLgU4VyFc1mNhRbwSS1QqzcDYm1h6VfosjYl0zfD0VYx5mAKBvU9rKyXTjk3ynkl18azGClT49wLE5G9zEZsSG5dBbQIsTf6ahTirnC5rDycTKxb10qOgSqBNSTKWvnkfuXVaWVdCqLu0p0SN/ts10uaiNm8+xShS2P4HkbMOkviZBDKsvDB5cQ/iu/ylS+/ybddhBaEtSV6nphfuyP3M7ughLIi3QLlDNe3PVoTpfRIEpBQaH1G2cO6IvsLa9sRL4/0h2/Sl3sW8UTmKRFDcwNuhWYMTG3m+B8QJM4gtu+RuOONt76Nb777Mb/i//af88V3vsSP/eN/TLvv7G93hCkz3b3g7u13qNJ49cknPJQL0he+7cU7XMLEq2XhpMpFK6dQWATWEAnNwgabCGeJaDTDY1IgyMTcO69ppezsdR6mTNhnc8QPNmFHu2CNdhQcDsJNez3VoKppd0LKaFm3bLEgHgMRLC02diXXSvL73GOtrdB0K5baylZ46J2pVHalsq8LRxGeT5m8Ew574c195HmCvWDXx7qQssKUeH5IPDskDjmzS9ZQiTbS6tR53BDAEZIpBZ4ddgTtnJYLKmYP9vrNxItD5jgncjB4/27OTEy0vWnFDjmzz3a9xggSzSlmxAyV3llqtXDP2IjtQpBM1EBvlSiVGFZiXOkZWhQeQwCykchq9/2usNtn9gkyjTkFjLBijaO6MVZSIaP2mQbhtSkxz5lhfC1eobqqOc50tbVFa9aIfItCqZ8z3PezPXa7HX/8j/9x/vgf/+P/3q/5ju/4Dv7qX/2rP5cf/e98nB4v9Klzk+0QpbuWSK4LcLCCs5TVXLX970NKBLdNsQLhBcm1ANG726e6KMAdBwAXzapb34daN7ff4GanEoSUMzElO8j7QpdAuxSLmx4Y/CA5xECIRt+M7kkXxu5IxHVZT9+B5nEHZueErSc2fZe3PIA
5Jqy9cWmVUylMYt38ZV05rcrFWUDqQrzYhBqHkFaIvluqCsXdJmq3i7CETkqKJLbptDWPBOnVJoXU0ZSIrkOD4MLgwnKpviuxfeGjKvfxYp9fgDDL0NEaucSn2sBQ4DeK+lS2FE5L5fE0NCLKuq50qbS18OrhxMPjmdY6KQbmbBEgOSdSsth45GqjFcQanRijXWPjuug2CTwuha+/9yHvPLvhS2+9bS2LGotwmszzrpYLtRs7LSWLia9OgglpB5IoxQ6ULlCxSUiw6Xe93KPrsun4DM83k9CuHSR9xr3f9viN5i4U4zLIc+a7f/EvZIlHQhe++7u+zKc/o+SU2N88h/mGh4czFz1xe5gItXJ3s+dUTzz0xqN2LgLn3jlrZ1Fl1U7A/B4L3hSF4AxEIwMcxHa6fW6E3tmnwGGezA1fmxdsY6pak9gxl/1wvY4d3jaNnzIJ1BjYRWEOkMN1hxu7Ensj98LEZI2ooyf4Dli6OQ4G38F2NdKCxIwcJ46vw3Jb2U2RF88PvDjumSaLKSlMzF51buaJXU7MKbreEEQ7tayusTNPz+FFmFNEdpkge3bZGMLzbub5ccfNHJkjti+eM1GUm2zaoylH5miT2DSSFByG773TOky9MTUzOE7B6DiileR6pyaQUyDlxGUyGC4Gv5+wprK5pKR3pQVztL/UzqlYyoAB/m6CWwBNBKwZEFUSQgyWKK4ucWytE7oRsszcwBqDb3GQ+nx791kYoVDWSkuRkEwPMSIOwC9sLHzv6Z4iBAu2C2N/N5yUxVg35lUHXYc9khWp7sLW5jHerTRKKE/U09YF9t4NDszBhbMWwEcs9BTRtZinmWuGRMSsWlKktkqIQpozcUrkObkzu/2Q4Pus0iq9FSSYhkzVWVJeo57CKBZx3mEt9LMwt0yQwONp4bHCWtumExNVYjRLo9TF/BAValOaCkpyE09zV2s9mONBNww6pEwQw+Nbs6InJBAjVBij0pl4RFSTZedU+/61CKdLJcUFyeKkFaViXmHJnQ1oHUkJxcTAa1XuTysv7y/cPywE6QQCQTsVKMvK6VI4X0x8OcdsGhnfRUmQDdffdoPjWgnBjTFxASVcauOxK/eXhR//mQ/5hV/+Bcyxs8sTU4qkYDTrKEopjV4KOc/GLCwrkioSZs/kyuYxp5EWZlbJxNahV2o9g0LePbNrC2EFoqgRdXxC7aaNAHwy8elyWQu9N3LOHG8O9JrRqqTdxPH5M8K8J92+zv7FG8TDjpACszSmHlkuL/kwBO5VOcdASYETnYcOj61xbo2giVUbqzYgepEytl4AZk2kCVQqUTuHkDjOmV2O5lBgEMLWWQ/5BC4GFTAYPQSkG0KiMUIQjgH2oiSDP2wa6J0dnYMoR8GgeIazjKMfEkhqTZiI0oKSY2I/77l7FvnCwZKqU4jsdjNpnvyaEHrIVNdWztm0XAGfDrpCqbRa7N5E3PTGd9Rie6vQA7Mk0hQ5HnccdxOHFIjaCK2TxGyvNM0WTZNkM2aO2XbYG+vYG2nTKBpz1vYcBWnBroUOGSAIKQvnZBB6CFZ01N8fC4+1Pe/aFcQMo2sDFXEAFpfuGHozx8gkEBXzLeyQMRlPB6RZqrRUI2rEEGxn/R+T3fd/mYca88ow32g5NjCaAvflcu1BN7y2Y5TQEMUZYVbOaq0m1BWxzKIomwB49fwVVd1CCYPIZj7Zu9JKM88xoIkRatMEOM085cTeGT4tVnMscFv7QafO2bqcRCTvJtJuNogsX6nydqxbsawKSrU9BDjhQ92UlW0aMMKAa5Ja5VwqU55IIbAsC5empilJkRAjEpL5kolQCRSiBZapohqwIKziXajSqlJiN6ZlUaJtXtylgaudSredhdt/efiaUqsd4tX1XyoOXZTGXJs7rhtUm3szecGmQ4PkYsPeKpfLhVePJx7OCzmoJRx3c3h/fFi4f7VwPq2200kQE6RslGnZSJ5OEPH3MDhsHPwXsdEpljTczdLpX/zUN/h
V33fh229uqb0Sg6LlwlJeMUebamophLBw6a+ItxOz+Hvauxc+NQr5vDOBrkLQQowd3R3QmGk69lhGWQ+9Im01o1bFXB3KhaCNgKUKp2RebEojT4G0NJaiHJ+9gYZEPh65eeNtjrc7D0FcyHWhlQuicA6BogZvaYsUES6YHdGiyk6gidJdIzjuQ0EJjp/nYKLxpMo+Rg4xkV0mMlpu7Qa6j/c9+v3S/QCMQY2ar7bnRYSdCHabDaasIg1ytz/fY+amDou4QNWEyeLF3PyebSek8Ug/mJO3qiMVzvatrdp1EQWVZJMigrRKb1CiiXd7q94r2M0XfC8cg7jxdWISEDU/zf1uJkXQVlhLp46daDDKO8He3yZqVkzYtCjaGKGlZmZteVSt9W2fjrNbZdNBmZN68GncQijVyVn2vcxs2rOpcrKCGYVjihRtjjYkcoBDjtymyCSR0JTQTN9pMKudk+qawWG+C75///lQpLRbuqQMaE/HG24vfgS11Vpdb2AHngzSw4hwaNegtWEgGazKgXeofUCJwJjP6MYw7EMP7xT2GMVdj817R0SJwez68+QEgubGpe6RJSmQDjP7nVn8TLsdDXMj1ilZSi72GqWbKFkXF8kOjYwNZkhQP1gNKlQ/0EQja6lUhbUtJrhrxpAbQYtpSj452PuQprwZvIq/Fjs8rQFW125IscJRq5vSRtmYha0pdakG+XXzEbPmwGDD4p+P00eswRDzrBtJrJsVlb+mkWZqAl4jjfS6sqwXlnXhvCwwBWrHjDKr8nheKedmuUQhbIy0mB1GFRNUN7ed2RQJ9iHa/iR4zlSCaYrkbnj7uTR++sP3ees2I/lAnHagqzH26IjHbC+9Qi1MlxPvf3oP+yOvv/VtpJxRsWusj4GoNmo5m7NImOlxQsQX+SjUxTKf2mKHLWqODnVBezWqd1dzTMEseVIOZAn0aWZ/+4wiVqipF86ffMI+Vtrjx0ZkmW84pR3VbZpCF+gd1e7WPN2yldSn4o2IYMQbrwsk7DDPEphCYBds75JCJMXIiKVR8SRttWvA10xXJEBsPzXQitIhd5tKjNFqf6XVA/WshzTdpK8DgggajQWJawWHplLjjJklKOLWaiMyvdSVtXi+nFrBQNhgq6YmWm+tO4PNngeYY4moO7KY3oAUkiXZJtf5BfMUbO6WE4Pp4NQb3q6+k9Zu5w6+m8aMrFutVqRcQxmCxWIEv27tOg6glbVULmvhslZK0atvn9r7LYjf28bIOyZ4+2biOEV6M+cWYiIF4fk+czslZv95Bjlb0bR52Kdkj8UZdnQ+jn1L5/znukjlnM04tTegk32ROqKnuzRq75Ri/lq1NBuFMWp3StHIB1RiC3TpXn6cfReCQ1rXvVRKCQnmohzc+aArxlZTF9lWIUQzX01Ttp0B1wIoYuK+OGV6MRFy3E3Mhz27XWaaMylnLmuxi/l4IOdszsHNYsWldaoWG8O7OKzshRauRSrY7qc6c6euptugWScWEesqayfNkLPtZxDzS0vZG4DeXfDcN4NIu0/VXRgqq0SCRLKod6rWPdEwP7UerEBHecLEsr1eb82Wu1iSbkpxo4JHCc4YSv5+ul1VN7jDQHTz0qvrQqkXWltRnRFMZ7KuNpldzovFogSLuR/03eZhgrUKa7HY+BTtRjPmU9uuBeumbW+YcNKCKv/sX/0LvuvFxJtvf5E5w1LMh8+E45lpdyQfD9Qq/NRP/AT/4sd/ku/83u/l2e0NSQ9or2brpdYVC0JbLmgtIAkRM1ptzcLxohu26nr2WG6gVXq50MqKBnMzGPvK/W5H74UUCz1OvLw8ElNkRyVdHphTIVzu4fwpnzye6K/PfLAko85LpNVGH4ehs2VbMVZma5VeBek2ITXwPYVNujkE5mgaoRmD+cIG69khrcHjWdQmyzj2URhkV2FrGFO1a7bX7j/friNR23nUZmbLNCM2pWA6IjuA++bzGBR/HkYv76n7jjdtK4OujbAYhl6lGqp
Cv04DajBYa51Su/v/eRyMWghr62oySDHigPZOjuZkov4eOA/QNUtGSBiaxDAWPHSa9K14BdTPHJui8IY7GryEE/w8RkM9w85o+kvprM32We7OZ82BWCr4fs4cJ2Mv3+xnc+mpi+2tg9HfD1lMZO/3K/geH4+XeQrntr5lwnU/D7+Vx+e6SKlrlmIMZBGmEMgCovahnX1KsF2GG5/qtVtr2olBSVMAzOBSECMvJJu2BAjNLJbMJFIsM2oyqFAwmKYWswgaOy3tmGu5YumdBoD7gru7+t26JcSorTEYe01SpNLBQxvH0tXYpskKU+o0WVldlEr1wDjsa3LyLBuxcENV4dGuYXStm/N7T9bRBQ1ItxuhizBwkNLFmYUGy1WHB0YstA0gts/R1qlrIcWIBEWCrc+bU/yHIa+o7VBMk2SHcRQxt2lRsnfaU4zMybvNGC2JGFhbJXWjWxv0Wam9mCamVXo1W6deFtbFKN1r6SzrahCtdgtdlMgcMjmZO/raGrl24lI2iGdQlWtrxlhDERpBvSvURk6gGM6+Ox7J66fUfk97LPTlwpo7u2liTgeDQFTZHw9M84Gf/vrXeeetZ0wvvoCGRiMT0jMEZS0n1ssDkxr8VF2i8Hg60WphPwkRyFGRVliXQpYG9WKfr3bW05k87yilGtyEEnolSuN2f2QVJQscdhNhfcn5pLTpBQ/xjvfvLyiCTNZwSK2E1olNiaUgraLFCCm0wqyRm266OlUlYHZVwW1y5q6kYD6MQSIpxC3DaPisGDk0WEfuhWdMM6E3v/YquhbapbKcGw9rY21mo6S9staFpVTWZaGvE2HfCTo6e5NO4PlfEq1hlSC0Xo3fjk8fGKzeFYRoJIPke0o3em3edYZuRbB1S9hysaPFxKjtYVprGFPfAjVzEnZTJudgInosZDOoGRSMvZ66YXRwGzVjDdokuzbfxfXu9liGLIja/WeUiDG9BLQW2lropdrOWIHgqQmYO444vHc7BV67mUkp0nqlrFCboVMCSLSE50NSpmS7PVGD9pzY74w+z75Tc23fWI8MmOI//PicF6mOegw6QbYX3bvZ2YfoztzdisLQWHTv0OwitIVpDB1Nk0c526EZJRgBXAIaobtwMedEmpLhEWB5UVo3zVJz6NGYgMra7Wa2goSTNGykNoKFOjRmjtJO+Nu4GKXYTizF4ZZt6aHL5UK5rGgpNpqLexY6tV0cwmLshKwHcs+9bV21MRuba8iiKqgH+eHu754BVIb+SAZMErYJccAONu1lO2h8IotecC1gETbihE9//oH6VsOeXXANWgyjS7NO23aMajZFIqy9cl4XTpeFy1oscdmh3lKNdVQWKMs1EyoEZ06OiVOGsNqvIfoV5/e9kQmudXN1XrSY8j/Ze/Xpq1f8i3/5db7ypTc4nT+hnVbmeSa8OFDDifVS+OmPGv/HT3zEBz/zITlEbt888Or+nuP+jphNsxfND4u2XpC2gi70poSpspxPfPDN92jrCWkLb75+w93tHTHvmfKeenmEUAhl2T7XMf04GZMeOxIbu1iQfuLxo/fJtzfE/cz07A0+erjn3Y/eZ5XIbg6EYonM0hq5dVLv5NahedFaO3lpTE3ZhcCBvE3iQWzJDyZYTSERJI7ZxeMs7H7BfSc3nFCdBTZw125Ta/fr3xxkKtra4AVyafCwdE6XSlmqHcZLJXTTlIHQavUmabBMjYAlBENDcPswvzf6EJ7KQEPM4igK5g0ow5YrkFsj5pUThVM7GeTv04Oo7cJCCKQUmadEygm4RusMqD26B57qOMxNF7etIrqnCfenB323OUzBsH9fS/iZZsxfgybXbvvp3p/srxiROGqNcxB2U2K/29k5uve9freGX3HbsWhhplvm23Y76/axqQ8H6oPA9U772R+f7yLlxUdisg8gXovUoOO23lhL4VILRRvqSa09iJmKFn8XwRaizcWARpbZdFKWIGr08hBM7xWSj+hi08S6rg4LYji7Ey+sg+g
mwB3dbLDuadDFjXjXIJhXOMEYNV0r3U1P99NESpGyNsuPWitazObHOijvWoNh2qOAWB+s2x6hedhYlEBTO2haV1YXBpZu3oNgO4Cg1r1Wp9kLkLIx44Jb2oxi11XNzcDJIhI9iM5TgQfcafRqY0eOaJFRrOzix7vuwBwz2Qv82L/Vas4BrVfWunJeFh6WhbUUgw5dcNwEgka0dNqiBnf2Dsng3JTEojqSxUCYm4Pd/NXEX7ZgHotlPzxDV5JYcq7GwKqdj+9P/L3/349xd0i8cYTzy4/QPDGlNwly4J//9Hv8f//K/8YHfWK+fc77P/5T/L//y19tjc3yAL2Tdgb3GUa7EnuhXe5Za2c5wauXn/LRu9/gZ37yX/Pma8/YyTskFY7HAym5w0iaodh+Lrix7uYg3zu9nIi5I/rA3C60uCDM9HTHw9p4977ycOnspoSez8y7vR2MKiQVKpGzCrl3+tpoWgjnxrQqU1f2xO3gG5OUJVTbbsZuWvug+1Y5fZfcxoVwvZefhk72ZkLwYVacRbjJkWe7iPRuwYUNlkujXFb7lVe02jQfxVAOiREwlKKrkakkiG0OfKG/eSI6GSl4kZMQzfcS25VJiPYrGus1TitnPfFQOrFcqGowVxS7j6cUmYL7SCZDcUr3yPVgNmuDUToa7M1KCHE9nLn5m3NOJ4m5PwSxz3nsaUU85UHMW7QGK+dFzQpqFKkYAk0CaPMzTiyXS8RsqxBaUKezRyczVQDXqfk97pDk+L6DNdi9+X7aHf+8mKQCpgQPincM105H1X3d3NutoWgIxMnjIhyrTRKo3aclx4PFWVcBNdhPk3V4jt221sldSSHSxExeg6hnsrRt4Wo0be/IAOhIMzqpjXH2p6qwLKuzajqpVCTaB2udj3thTZ2czb17Pa++GNbtgDfsXZjd3TxluyCbKqU5bq+OP+vTtFnveGSx1702gsOZPSayGKSm3S54c5d/4m/Xu3X6fh+pGltPJLrZpBFUBo3f/q1sWqFtrAVnP6mF26bAlIMnqdpktJZGodO6QBGKKKUWLsvK5VLpq1F4jeoKoVsi8nppLOdKr46+q/2/IGa5tJ+S6VBCMMaWmuXWKL60bsnKvVs33wqhN4JWRIvd0PPER2vhn3z9G3z1l3zJHKzPK/F4y5lP+JkPPuHlxw984ctv8JMf/CQSMu++9yGX05e55Mj+bna7pEROmSYgWrlcHr1HDpwfTvzUT36Tl5+uvPv+T/PT7z/wi7/zi/ziL7/O3TEz7fZUMhonYlyo60qP3SdY4fF8onz6sRNXPBo9zOS7W0498vV33+XVZWUfd8R1Zd5lZlsxmraoQ2mdXetMBZbzyv0Ky7mhxabbEQczXF1svxut8fOFpvqypEszX0fF7zH1jtwaHms4HKbvtru0XZJNs3fzxBfuDkTg1bRybsp+jiTtls/0cGaOGdntzOnfb5TQM6kbYUbVCUGiaKs24eu4N+1/xXPgwtZsWbeVg50bKgbdN9dp5Xlm3jXS2cgJtXnigkLOkV2SDca2o2AQNSziJgYjb4yj/JrmPaDn7nueRpRmkgc3UUadsKCD6SvXYbQZ2rFWM9mu3SLhY4xMk1D1uss2Xog5tYxJj41M4p9zjPbejfdKfYJ2yJcnxA+wgvatcfquj891kYqqZKKJ8vxDeSo4bs1gsjGuSxhUbkNERY1N5U2isciQq7ODmggxTe6aXGRD61rtrKUiKfhFZOw0636a1x8raGAfXCvNl7IByThEUK9R7E3JayBlY9lJCFYYXZ+z1DNFzNOtLoYrj2WxqDHmUnRoRmwhbF2YW8KMoqBDPYYdBpjtDLUD1Q6k1n0/5kw8NYp9ThNpTvY+qhWo2pr5JzrzbPMXBINSOmgfMIKd8/Rui/jW3dLGbsgpBo45cbObuNlP7KdMSqZl6thS+tJhKYqGbhHppbCsC+uyoKWROja9+X7yshTWS6MsrmcLA2q1QzQnYRdtB5aDkztcIDmmqlabQZm
DZTgmQm0GeylM0cgFP/nex3zhzWf8gjff4eN33+X5zQ2FlTfeumM3zRzzxK/6Rb+A994tvPfu+3z8/nu8PgfmQyZMB9L+dmskAkqvhZAzl8uZ+/sHTkvjH/zoT3B/qRxfPPA3/u4/4v/xq38hv+5XfDdvvPMmaT6AKFNM1NUCAcuykiIspxPr6ZFaTpTLiYZyfOsXcBNvWC8OG/dKat3Mg6sCC6iZOGVVZgJ7iaQqlKostfDqvHA5F9qA38bhO5bkYPtMrhMz3kzW1p78IZ8JCxzSj6e/rHeM7HPm+cGyuG5y5HFZuVSz/bqdE6EHauksS0HEDl9zSrGJ0BzIrcmE7nB1cQjNi5MY/NYGzOcf/PDRDNFIUSomxNegSG0OJV9hLXP9tpeZUmQ3BbJrn9bSzOXc34cUjH36NKq+NxPZDilJQemlElTZBWvACdhEB9Bd+6QG8veOyT3WYoQOb1jB0InmE2uK4UpoUaF1a3BTGmZSbAVnkMuCI07js1EfFlprdP8e1b+/+I5tKOK+lcfnukjRlJjUF7DRBa/uXwfbSDlMSC0oDER8FO1YKGJKVrR8odvcYqa3SpZkMArJRIEO+ZZSWHsjZqNsbxMUV5hABmvJIcDelBBM5Gp6GwC156VGK59SZrefDbeXtqXIineVIQRynu3QXAp9bUjpvnQ2jD8nI07EaLuS1m03s6URow7DOFHC4QUweJTiWiYRb6Ht9ylEck7EKRtbyi/MUBqaGlHte4Rof2fvI1Bd0+PK+9YatayW7rvYwdZ6JwvcHfa88ezIm89vuTnO7HzCscW2sq6Fs3bOVJp4V1lXo+DWSup2k+8n24nVUjg7C210yOKFWVDiUPhjTL2guiWSSrdO3+QHg9E19hNCTIHdNBEQcrKlsHQ7QDUfyc+PfOm1N9m//QVYFr79pvKLvnLhb/yDf8A0BfYl84u+dMNHH7zPV77r2zncvECmG9LuyFqLk0kswHK9PPLuex/yY//yp/mRr/0TLjVx98Y7/LOf+Elu9kf+5t/7UZaHj/h//oZfw4s3oJUVymKFFeFyudjhF4V3X35EonOYEy9u73j9S99On44sH39AK5XUKlO31y/RbI82Jmdv0JppYqodmI+t8dH9mU+PibubmeM0GzLh08+4H0czMphkY9FzZcpdO/INDen9ChHZX7oAO7KbZjgYoeiQE2uZTLKBOV3M04QIDsO3zfU/RHN6b1EJEt0hxq/9ziZhGTIMi9O+IgBBLYlgo5SLawyTJRvHaOy9UquRDBxOG7tZo5w7QVu7JzcXi9VQzAs0WZqD7XSjk6KM4r7UwtIbWiqHFDnsLHRySCRQzE7NmcAOMfmKz9iIMdi1e9gpy8V2fK01YogbKqXgBsvNsvO8ex9+hoOQ0Vqz6c0ntuqNxfhsm+fJqfKk0XhyTv8sj891kQqOe8cQiSkbpOCmo3ZzBHoLlGq5LKUq1ZfSZiUfCckA0hCTMXpqc+weW7p222nEGIhT2qYvs9C3w3XYFgUJzniTDTI0NpFd+HFgzJirNt3C64yYYIeC9kaTxtrXzeEA7+RrKdZVxUothXauxuorRnDwPA37rb9H1gl1Vsf5TcTnY0AYz9sLC2y2Q+ZcYZBgU5tkUrZcJsG6RknJF7vFYiVsTWRQj7/+1ptFibS+QYGtGo23XVbaZaWvHbow58CzQ+K1m8yLY+DmmMl5Qlt1h+ZIbY1zWbiomd62XhAae0yln6ZEmMWycErjVIVHbJdBAXXSSmB4LDojaYNtA3S7uehGiOl+QIcY3W09bXTbIEaN33dndLnG5dnNkdvXvkCeMi0lQrwh95f86u97k298+Dbf+HhlN1X+k+/5Dr7ru9/k+PobhMM76PwCTQGWlSQWNQJw/95P8e6/fp8f+d9+jJ9+1XjrtSNv3lRefOUt9tMd//Jf/Rg//t4nvP/eu9yGhZpnSyhuBW2QWuV8OqHamCRw/+lL6pRIeUakE3Lncr6Qm8Go2iuqnXBp9Bg
2wauFhhqJRqMJt7UKr87w6ty4nM8sy2E7/IdFGKhFo/h1OKjIImYA/JSObBOUTdjF/25EwYiaAFdVCDkyseMmBnazwfrFG56xExkLfV8lbmhL8xulYZE92t0lxhtOEfHknOb3he2h7PWM/ZRNUOqEgYEKrKvJXdZi0FvEpruQk8N7kdItVkPU9261UdbmmjChdghuMRUYZAU7v9Zm+tDhm2fu8c7K1b4V99q6U9PNAiyoUNUMDXYBns1CaZlLaSyirOLokYyirT4JdaQWZ2w4CiPXaReMUIHvwI3zC2Y3ZdO5dugqVJewpKtb9s/6+FwXKcEMQrNT0JNToYOYQaiNTfW6fPUdlfRmOp5gC1TTx/lENgU0eDpmG3b/AXVWzsBqhwbIFN7WqYQhNBx7lgEZgHUyghEnRFwXZActTvFWJzAsqtRqQlObdoyRVFdbVPbQnMRg2gsVI3bY0jU8WaBeMf0xhttTEWcY2Z5CB6Tg0OfGCPI8po7QXSezrkYmyGJRAYAr2ocQcHRg7hThIYFWuO2i3oITV2cmtkLQzpQCx93E7WHmbj+znzPESKPSSqetneJaqNorKkqgMwXYJ+GQogccJlKAFiqhdl6GSMBCLhXdGorBhhxO7imIHYAy2HwdUaN5BxFzp+iBYsR6oggtu0zBP+iYMze7I6+/9gbT/tYcJSSS5pn18RXf8cU3+H/9wPfzz3/8XYJWfvl/+ov54hffZn/3DuyekQ63LuKuJMXNZGeK7zpeHI785EeVjz498fDyQ3azIPohMQV+5a/4T5mj8PDRB7QYqNV2KNM028G1FMpyIXYhh0RbVk4vX1JPZ26fC6+9/pxPPvmIpo22rCSHc0p3+Dh0ajFY2hb5nSBGxjivK/fnzP3jymF/JoTENBkQL/7eDOh7MEwNcnLj0vDZQ6933aD6wfY0ZtvY1Y77zElMTi6K2aaOEbI5zJfH9x6QXWvXyI4tMBLZfuZIDuiq0JqFEYaxxwV3V4buu7O+Mvwo13XdWLDjZ0ZvoGPKm6N+bdU88JpNPSHAWi2pgBiparpFIzU2l88Y4zEK7GJk9qigbRfcr3urMR25KxMikFJi15W2zzSFtcJ9SrwK3pz7/b/JTFpH1SbP6x7XvRa9zlRn+4E1u+OztJ7NgjnXZhE+w0z75/L4XBcpVQ8a7J2gulkaWbfise2wMVFaa1Dc7iTgGUKAH9ZGJwXRAM1uDu1GvSV4wRq4q4tULcmyosITqrQd+lGMptoQZwEaHtvatTPCn2NvFlTYwU1rG7G5+8PotoaFCc6saaaPQNX3bcIUbZkbgzGGikdX1I0dhBXKaEw+cfx7y4fyPVM3SRCVaiLgauy+VCpxN9lBIc7caoNe6+STbkzC3n3JXRXD/O0GClt3LPRgNP2G7UASMMfAborspmh5VKpud2NxAxNKiE5/VWUXA4fJgtpMd5IM6z8vyNr4OESSujTGD177jEwOFoNl+Ug31l4P9udE3wmq+oEmW3CkyQ8m81kM9ly6KtO044tvfoHnL15H4878U+OEpETMM+3Tj/jON/d822vfSUwTL157jkzPWNML5sNrhGlvOxrUvrfA7njLG29/GzlM/OT79/zL9x5o4RbagcvDhS9/YeK//L//Cr7ynW9ybCfW+5ekNJF3iZgSaW9BhGm/Z3l45LG/x5Eb0JVIZ7n/mOflkbfffp1vfOOn+fTjj4k+GRACZAtrDMHcQ1BBxKykUkp0Cq/WykenyouHxn4+k/KEYjuOQZ4YUFCrfSNLDEGs9zub1s9+D8ONog9JRxgN4ljUP5n4e3WCgemcohehsaN9SvGmd0fbfU/CgJ+cbOQUdAm2RmhO6kCNrGXC/+CEO3Eylt1fg5W3xcOHaEbT2aj5vdmEGFzA3mqF3nxKtOfW/TW01sb2eHuuOQQywn6K7HK061E9UqapJ2cbTD3WAIPBmHMghESOJldel8p7am7wXTvVnZyTiBsam+4yPKHf2kqjb59R9QJKEHK
wdULz54rv4oafzGjSR1P8rTw+10VKuolFo5ieJvt+CK4XcPfla1lXg8Wahb3Z4twaIs3uNCFOL08B1BzohoK713Y95FWv+GqzhauK0OO1SRiGpCEEzzQaRUC3YrN96N2pwU8OQ1SNWh7CNXrBvto6NFfiaXPxHsGFgsMN2rqgPpanbjsz8pI26yPvQFUNkrCDwSfP8Vodsqm1EUoj+c3QuzJPky96k//7js/2aKlQq7MYvQPEqLgIdOl06WgQShDoFj63tk4jgETXJnVEK9IbqVeOYvZIKQUOcWKfI4c5mfJ9yqSY6KqcGzzGlWH0rJhF0Og6oxg93kIy8wbhoGw6GFv+XzviEE2iirO9Wk/WVfp1dzzc8vzmGdN0ROIOaWP5rIS8Q+isjx9z8+IdpuMNYdqTb95kevY27I4GHdUV2oq2ytoqKpnj3etEAr/il8HhzS/w0+9/SqDx9us3fN93f4Evv31D0gtoJD//IsfX3iTnQG3FDvgAcf+McDgzH/a8+vgDKCf2UUmxo+WR+fg6L54/42d+6mfsuiqNaTc5A6wj6texXeDEFNjnTI6NtTfu18r9uXG+FE6XgogxNAdNZ0zym7C+q08t13v6erj7vvS6zPBGTa7Ele1eHNPSlUk2JgI7KGS7r4Sr1m1AjAPatm/Xt12axeZYnHoIFgGCM9fiaAxz9gnJd7TucHElemwrXS8WtpNr1ZrOMM4xdy1J3ujROylmf3ZeYP1ciBKYXBQ9ssjGTm8dUpFSr1OcXPVVYdDbRekts8sJ8RiO8V7EoExJmHJkyokpWcruEFWrT1T2+Q3mnstzvNGnN4TI2izdoDhxy2y6/Fr4FpdSn+silQSmYHuS4VNnhzB+zcp2obfWN9sQ6+oNJtNSqZcVPAKe2ZzRJRrMprhtT8cFkd2noSvlfMy9Kv7LWg1U+galbYw3p3AOgaNlq+j1IMefV/eDr9thark7geDjfPAbckQ3jADGnCyzp3e3qml9OyRGNxqC7VFCTltBHPsqk0oY3NW7oNJpY6/QAiF6Um0xqCLsO3m3I4rt6wZNWFs1WG9dbU8l5q6cgsUIBDG35p4CYU6b7+Bp6bw6rzy/rEhMThzodK3QC4lOCjDFwJwiN/PEfk7svEDZoWG5WXWpRLmyksx+xrvhIEw5s5szOUVynkkpu9OGbk7oVrfEiB/BdjPWNycQs3DqLgBNKXNzuOW4uyHGGZ33RI3U0wWVRpXI7nBkfSiEvCPMt0i+RednlDwToxBqZXl8oD3eE2txpmkizTfs7iLf/YvveOvb7nk43ZOTssuVfZwI9UQtJ3Y3bzI9+xLMB3qOUFakrAhKOgbmfWd68Qbp9g3q48fE9ZXljNViCbDP7pjnmY8+/JgJ08pM02TXhxvvhhTIOXE3Rd48TqxNOfXG5G/f2mBZq0dxBIyWEq47Hz+8W7PdSXICgU1bV9+4QXRRp1ps5AoJXre6Tet+GBvBhet+S8IA+rbJfdCjbcJxb0jxW9gbswDgFPo+JsrWiH34fjpLUAw2TcFBTT9ruj7dsTlaMtYB6tBYs8Y3Yfsi6cGSa+2LbUrzJ9O98NI7CYPZg1gj2JqlEAOoNkopZqDrrzml6J6JEILdg5anZfBe35KKbWiektHfzQ3DBMfTlIzcsqFABo+Of5M9ZiimSEzjXApAZG1tk/0ZmGTFXVTc7uJbOOe/5Yrwf8FHkiF8s3dg0yXFsbizw7a5V9eAHVC5TiYd6lpQ6aCNqWV7E7l+TReQptseaRgybpRKtS79M0a10YicUd3jD7bd01OtgQ7Ybnxe/qVjVA4S2HpRv2GDQ21969R8NMcO/+AGmk1NuDsWxk9vGHOmsAO8rstmW4J3fvZbX3j6LsHglkAIHZppVZKa/1jYKRJMANubOcoHZ8SNmOkcg7kOeFfY1ZesTajB9BuX2jitymkxTVhyn78gsokfc4rM2SHBNDFP2Sa6OVsOUzB7lzDCHjV45zagKhMy5xyYcjZ9ycD
SnXnVUJo6SOE2Vx2HHQFxh11Djux9m6aJ5zfPuL19TkhHap5BImEp9F4h7ZgPz0w4fLzjInum/IzdzXNI2a/hvk2+wRfUISa0ZyQ3jvNEDsohwPn+JeFSeKz3dF05HA7E3RGZdqTdEQ2RFPYQrUiFlJlSgrpyM9+gp1s4f4rUgmKWVN/+pS9yWRr/+//+j1kfHxEJDtnZ4UK05qLtMm/c7YgCeUqcm3KbM8/31iyM5mJVI6aA+OHu+yyHcRmXpj82naMjD1uR8maqdz/kRJzQIh4b4e9VeAohjc96ONmbeHsEJdoP71f0whsYJPjPH2hC3/Y1w85nTDfUYvEtYozC5giGOhzZeofwRAyOetyMmlqfRpptWklELtUty0IwiL7ZfbwJycMwdwqb68pgJ2pvMETBMZmBb0okVRLdmq+h29TBWDU4chQqoRvpJQYmT8KOKVgBH4VJbFUyBoEg0WOGDDWyghYo1WFz+4kmYYmBKarl6sm3Vn4+10XK9j/dFfXmjGAjrpEi8K5KYnCaOrQg9G43SFC3SW4CBboUar0KVonRDtkARKVXg6y0VusUVHyBjOcM2S7HpicTpSpc9xneQakTObRVG/VVtiIhGrzadPe489BBt0gZRbN328soQuiQ1bDqEM3KSWK0WOnevXM1mIBuzD3RTugGvTX8+TmNHtyfDMeN/b4XQGszLz4EUqOujTVWhxBsghK1whSSd3kKKXSm6BHZ2NSqMVCbFbcmyhoitZsNVC3mEL8RhTBPuByUXY5Mk3VtU0zkkHySi5v5bdic0wU00KpNoEIgiBWmFCFFNnrweGwHohdn9eIhqk5XFzPP9Mk8to6EzLPdntcOt6RpT4s7c3JXuxZrbXTJxHxgClBV0PmW+OxtNCUXZUeqVsg7NB+pa0Hq2cgJat57BCHv9rRyQS+V9fwSWc/M+xum4+sw3yF5pmuAHugITczzMIXJWK0pEw4CWZB5hywnpF1gfSTEmbfeepNv++LbfPzN94hFiaE/cQyxvZTMmcCBGCK7nOkdckxMMTJlcXG7GKsTO9zaNi1hN8l4+E7vSlmXDS7b4LwnkOB4HilEgqSNXKHB7x9v5CQMhh4bvCciuJbfIXecNHElPQ3PPOg+adlz7b1teP7QS7W6YrTFhqRkhXkplKWaabFW0LRNQgYZ2Os0W7Fk9P7gvpVdqNYeoiilV4PJwNim1bLBUsr0aM81e0Nrjc24Pl3cK/batVuUzbBrM7DHmgW7ryJzTnZP5IkQTKcZ3FvR0giMzbrh5+OccJhvJCk0tTOlUo2MhJJz5GafeFEndrNdE601PvjgZz/nP9dFyoLlhv7ALrI0ft/dMskXlyEYe8uEq1cAiLGXam7x0kGjdQ4p2YUycmDsA/cbwAtG3A432x+ZHzt0xjhro7durIjrJHVN13N4dlzI7iMW3dUhjItFjGABQ4gqDEPbFKMZ7AYrAEZacEZit/TU8XxzCDbNqEOHYv87OkZxWEzE1OTbSxywCC4y7Z2yrghGpw1OOEhY9k8U61yD4qr4wOS5X0aMMjGzQf0GrdbaWIop6o1wZWa1URtIIwUrTFMyeCH5xr31Zga5wYptK42yFi6LRRJU322EYCzQGOJ20IQYjWEmEZtm2qaHMpq7QbwR2cwyxX3mRE3uNqXMcdoT4oz24J502DWZEqVEqGe0N1YSi+7Y375O2h+RmI1uXVZ/vwVJQl/wdFfvkLWiWAM1Hw70ckMIBr3E3S3T3esQd5Q2aPXdd3+ChAQxoyER07juAzHMxJiQCq1cIJ44zjNffPsNprJw/vTeXEhC8M8/WGq0CFkTSTKHqdK7vQ8BCNmFqH59gcdBdCcyiDBE7UM79JSC3n0a+Tc3Fttz8ElIFEbWV+/mOdn9swI241q47lHUz9cm2624ISHj4B4kJMTgbvw6MDQDZ/Q5pNebuWaESs/ZnNDX1Wno9fpcHF4UDealWI0oFTz9OrrPqF0yRs0HZVlXmvpnZSeWTSp
azYhWxTKanqwQwLR+Qa2B1QFD9nGfBjPF7ma0HWJgnpIjFkZ2Ads712beouNbj4aDMZ363t0mKoMT1e8ZuhIxofzdFOFm5pATrV42otE//Bc/+zn/uS5SBh0ZfJSiEIPHSIspwus2Wdkbm1wFTzAY51ooDM5TNSiPFF2/4KJPh6lkuBAnv1h1wDEOhYn58VVPpPTp3LrCOiLiI1mxgxSxOAFt1k06HdZgJNNAyAYxuROywxEpRqOt4lZIKTAlYU4mSu7VDocgxnIyUlZHungAmVviKCaE9avQlr62oAYc7/dO1GEVVY+3dxbi6lTaEG2c7wGy08JjMHf3JEIWmAN2uEugiwUrNjezxN/vwVLSJhs0FMZNEAzy2yWL2BCvnOPwMDxeaKWyLLbAP60miEZkC1wcvmLdYye0GTah2L5BZdz0wQ9881TMjrdPIZDGtiUkDje3PLt7zu5wYyePh/Jp69sNqU6CWHRCjm8Rds8s/dQXyuu6kkOg92LFvxlbbTQrzXegcTcx7Q5ofwGHPRL39LhD8h4NVojUHalDSEjMhDQT8+wEj+aXfURkQrTQ+wNZTC+42x94/Y03aOczn9TOcr5sS3i7zkFiok3WlaTmmIEz9syrz4g2237YITvx929klA0iz/hM/n2Pjbgyvlaf7HrFXfRVrUl8MoUZHj4A8avTSPAgRAtqdP3WQPscbmNEYqjlrYnD/9uE5wUxqFl8tbjSunK5LNZk+fQd0C0gVPXKgJXWKGM6EYPtJAakux6q2rOro5K6XCJ0I8J0hRDydi6IGMyfgiWBJ5EtuqQ7wWu8N611SlOqQkyR/TwRu507U4p+zcJSGkrB1mnNoW7ZPouxZjAAUjxmwZiTKXgky2QN/ZyE13YJ7dFd4X8esPvGjiLFcfELFhmOi1INfqrVRmwT07rTtl9ohsNiF9Dojt2UUZthxhL8OHJAdtiGqN8oKUdCNoy8dnMtaGrdTIrJoDrbGprhbLAk31qF3sPG+IkOM9kgI4QxYgcrVuqtX3D4cOzZcgjMyXc02Zws6lXYsBEXFLto6RUt0PowwfUIhN62m98Wn35RIt65Ok3XJ0sTu3ZKC05XhjUE5mx6FYIZtiYRcoQc1NJy/Q5rGBRZYfsclIHfDz37eBWjW7NpcJjWqj5hUgLDxbKtFl1xWStrUarDPuqToHWbtm/rtSFrscMsBkbK65icu+8wM9cAvylE/31g3u159uw5L54/95gVg6cUU/43urEK88y6CBzu2D9/i9qFfr5Q12quArVRpNKWR/RyJrlAOJiOwTVvlV4DTBN5dwTdQ5xpkkEyGqK9p05rjoh7saVrtlTDmGsSICSIk/2SgMZEJSLTzHxzy+7hYo4gPsGPPU8IgZ5whMI3Nd2biuayiNHYj73qgMocTBu/YMDhuv1ehkLVH1sTJb4vHjKCMATvfGa5tf3uKcyHETNsSjC4b1gVja9/gipuZAv/wfZ8dTwf/3y7k4yaiZqLO8espdm93tRzDn2fLKaZqhKpWkw6o+N+g+65Uh3ootZk0DdWpcrQNY737Fpg7RkOc19nOXtj05qvLdTZgB2q2s+LSdhNLuURNe1VtLNyrVbEc3QdpWDPd8Cz3hyk6N6cvrYweNXOshwDcYrsdFD1Z0SCGWp/C4/PdZGaUmCesqdbqhUsv2iTH6xD0KrY8p6gNOdaq3aid3gM1lYc8AOmocK1R8kPCgn0No5uCBKZpkTOvjQUc0ZW75ymZGJjUzcJS62URVkvYk7NopRiXx/AWIX4DeK92/j9WNCK+pguvhj2CXKKNk1p0M3LcOAb0TvgJKZFCr5zqb3Rqmx7MrAioKIgQ8/kHmfgaL3vr3zrpriBrcNp9Mw+2u4pCexSJAdrGpIabDYeYUCcvhDvAk2r6VICEMUahO66j9E2bsp/+15j52BCzkpZVitSS2UpnVqt+eieJG7ddyCpTZahdkJwrUr0KUhtu6HieirF8HUXj89i1jVTyEwqsBQajxB
3QEIl0Eqh1kKIgVoD7J9x8+aXibsbLpfVNXGFcz1Ry8punpCyEgfZJySTDnfotVqz0rs3XhEJZiZrew/7mtqsNAaxvw/RGI8mAKru4N3Ac6qiBGI60Mojw7V8v5/pdzes9w+U04l1XV3oauL1Ib8QNY3iYH7VpnjE5wbtjekb8fiIp5PTkx3P0yJlot3x305OsdvRIaeB5cvYHm3X1LVUjZ3JNWl6PJ/erZF0AxI2TNu/zSBs4AXx6QrtKfGpu2WUOjxd1eyQah07TT8n7F0gCFQ1U12rXZ6tplacqlq4Y8XOmjH6m1aMrejavzNyhSVZWeMnvdMbaG2k4PeW2m62O0SvYlPSWAIEgSlHIpEYYZ8mQgiUZteGUcZ9ByXBP4tmBs7BvPua6lZoR0T8aAijGy7AWB3IE+r6z/74XBepLMLOTUFtL2PaGUSfwCTqeiqDxYxQktzTzqaQ7aYJeGCfELKNwTFOTNF2PagRcmqpDP8xY5pZMF+es7l9+4UUxQ60FH0RLLDUylIyl1Pl4bRyOVcuF8zVYggIZRQnG+GD4A7K4q+t24LUoZQUgpnLJiFnsw7abkg7kTfIM/r0GaJYYmcPFuCHbj/DjxVQFySrO4OLFdumV1ExMmBRnH4vWAh2RLpBF6E3ggSiQtIBJ8DSKjWYE73h8s5bspHINHBg+7wBKASxqIRwTe69iiYN869d0VpZa+eyKksp9lwHDIt6tIFR2aMEEpDUJ1j1IL7RrRrhfIsZnyQ4CcSW3lobl09forWRnz0jHJ6hLaBiURBTMPeAUw/s7l4w377gdCnmDRkTeZr59Jsfs5SFEO7YpUySPbmZu3xrZXNLsV1JtLMHsd1St6ai1sJaCmMPFvJETHZNBsFZdt5I+HLbNH4gIaEKtS6Ekq2x8Ym1Y4fQcCL3n2yTN+rM1073g7nIcJB4Unhgg2YHdCwDX4Mrmci6u41EAde9kV0VwrisDXiwr0fdDWI0nf7QUSGcdj7+Sr2RHBofmt1MIvGKrGDXVBidIGxODCNddwjgUfX7ybVEPo0bKeL6use+pjja0rtCFWJsBBGKa+oqsn3GHWHpneKykCLCPgVUPC6HANqQ7i/Vs+0Mmnbk4gkP3ITPQ5cZiFk4ZCtG05TZx8COingYawiBERI5xNfalR6C5V45rG0WTU5KE64N9qD9+91nhbYj5Vs75z/fRSpG5phI3umKKOqdoc3l3hn3xuQXq4ZADXYxdNdwiNhNlyez1QkDBUkWo5yDWBct1nmsa9z6tjxorVGYd2bsisMh2S2KppxRjIFUWmFZG+ewcAjKOQVOSajNKKEGtwFc9ycxRe+qbXG8HRbtSiHdT5aHNLlJ7pwTU+qkFOhOeYspMU9G37YiFchFgYs5VKvBeEnsvVWH6kZeVAh24zbv9kZmk895BJQchF3CgtACzCGwT5FdSqZpE2WXjMhyFvVEU1jaQlGDt6YUEbEOPSG+XVa0J9s9xmx5WZ5fY24e9tnL2DW0yrIqp9JYWvHP2DJ1UoBdEnZJP+PQkT1g0Wj/9tiKINYURT/Eot+AQyogrcC6sJxOpGm2CwjPF0qRFDL7m+fkww2P5wv39yfuP33J7f7A4Xjg8f4VKnA+n9m99hpVO+dP30frQk5iGh0vmCrR9kwi1FrMfV4Cyopq86V7YHP3bivam/krSqRJs3lYkunXTH1DnPfmTZiUqImaJ+bD3g4nh7WMgTp2HNZUtdbdJHnsfLZV0NZAxGBsQwANGJKBXedmZqBbQfuM6znX72dZcPIE8rIZyr5uPJ/r9wC/l+IT2M7+0iHAqyMEGLlgTHbbDsyf3/hmXZ+41+ggYti/s4MfKyRDSqBCQX3z5U1PjGhK5uLSqnVl1aD1tRWbVA3joQehqLNBffqbRGkqrmPzAqL4rtRsvGYv9ElBUrBoGfenVHWnnK1mCPsps0uJ3X5iPwmpV3JvzCihd9RJRVfpzfVzqtUDV+Ng/g12pIuyN4DdPu/qKeLl54P
jhLFKjBnDONwd2rIOAuYoHKdo7Chxx3KF2pMt6NX2Drs5sZ8mdruJefJCFYZvGwQX9bbWWWOgtm5eeT7FjLTenLMVQ9+BiQhztArZtDPHiWPo1BBZUqbsGsu+mCea3w32P2P5C9lV7cPccksb1sKAVXIM7OfEPgV6UNIkSFSmnWwJu1M23UOOw7BTWKtyvyTqWumeYms7JMu9SmHsaKwbHEJMU5GPecf+LkdlniKHOXOzyxznzDEHbrIVzRyNeDB5ON8pLIR2AdQgue6Jypiew4IiscJDM9ZclC2JGaxLjWPCEkGtenJZOuelcPEYBLyby3GkjZqFUhBLFU2SzfE+JDcMHszOaE4detV3jeI09FtmjmAdpLlcJ59SDAa91ErO1g3HpfH1b/w4X//6NzifL3ZtBTM13e93vPXWm9zd3JmZqBMzaJXHl58SpJHnHQ0h5gkhDK6CwziRNM0ESaSUrQvWap40QUGyCbq7aWYCjdAq0lakV6JAbZVyuQDRWGqXixkfj+W7F6rmrhGb3unJYmeQE7YdlENsQ4bR2ygSYDo6m9Cf4nTjAOywaYwkuH7Kr0f1eJxNT7VNX0/gZLmGfw74u/drbM2AA5/+3OvjSu5oAx4vxXw9VbefMuyyVKH2xtLNNWWpzfKr1kpZK1NKxMkRgBghRhZtlFZZutC0UzwqZJgdS7A9d6nFiDchEpwwNNITVLtR2OlojuyCGckGf05d1eJ+vJDZaGrGAL12QhTmZO4Tc7ZmN3Yhd5BmXp0yJiBAnI1p10FziU6kusuEiaTtejS/Q6yI+j2KmyO0z2Co//7H57pIpSim9lcbJW3cVCcq2FRxs0u0m5m7Xd4gotqtYzFc1jqAm/3EnKMfYLMryL078w5NnJa8FnMstvRbcfsjc1Q3U9hBlAj02ojuVtARNNpkoCmic0abZxWV6nj46CLt93Cl2aYhCmbYFlXbnXhHZRYmlrZ7t8s8P84sjpUrgkpwLcOAscyTby2ZXjq9GC02OXQaozF0wvj3bERdajMzyy6CSLQLPQa/yCNTFOYcTbgah0reXUJyMuPWqpTJitGlNC41sDSF3syHzTOsYrSFtynoZaMtD8YVWKeuXalaLeNIA0sX1uYuAk4rjtI5TNkcK3JiNyXPkYrkmDdtCeB3uMMWcE0tgS0SQVHE3QkaOEEh0wl2cLgBbVNhvRQ+/uRn+Ft/+3/io49esrs5otEW2fM8cdhNPJwe2e0nvvSFt7l59hqnj77B6dVL+nIxw1YRCBMqk72u7F6JrdIIhJQRETsASkHSFRJtrWIkBjMoDVSkL0hdoC0mgFbA6fhzSkx5otfmRcq75AFTBd8zhIBGjICydc8wisUoWgOak8+cTXafXd9y3eA2+59+3Ud5kREVpHfEYfTNxmvQvX1xJX6oajeWnnymAF1/3tNdmP4bX7NppNqgcF+F7SLWxA7HkdaU0m3PW5o5s6y9o2uxmHufMEZGVQ+BRWFdq4evmnG0dCXHSI4G9ulo2vq4+4bp77VARzfX7sEoz+JFIboXX9e+JTi4gYYV0lapBg0gWZCgV2i4NSIQc2Q4eozd1GiWg58pIpZ5tWolkWgeS6TDOEGGGNhQD3tbfz5MUmJ7AgnWffdaiUkI7vIgU4J9RtoMOCUyJGqrrNWYYUhgypHDfja4LBk8F10rNXQe1wMazBXYl5S1MG6NEM0uJjqFM4RAqdVp0QEV9ZRcRV2lbVMfRu+suuHbvkqy7sWX+TkNvNzG++6XrAIqppWKKdpzUigdp+S6dmtcUPFKpW2tom01jnpl292NLiwQtiZZvfM1pqB9b0PjjGmXQ7hq11BnygkBU6MH17eYPYxFue/3jRo781rJFyucaSzbBLchMnpulOwEjkGBjU8IXYbhL0U5r5XHopyqsirgVNjs+6H9lH1ajMxp2qKvo+eSRcHGWd+7iTumC7odfMNHbcA9BqAFiLORayTYfiMmBKiXhVIaf/9//Rr/8Gv/kDAdOL5
4Rj7ujGDwcCJo5f13v0moC2/d3dATLLWxnB749IN3mXLmzW/7ErMrCLpClIhqo5MJKZJyIvi0JN5hD0KJTabZa8tKLxcoj9Aeib2wajcYsUVq7Zwfzzy+ure8rpERpOAsI7eq4jN7B6MHhX8DvrNrNWwCUIff7CLc7Lo+82+e3Ofj767Fx+FdtSI29pH6pDMf14gVNX0CL9r3Du5i/hlGochnnkfvjWrBAxt12wStT5ZbokiXbd8TgzXPKQZCNZyy1rK5ojfN112xQmlwvza0G/TXeyejzLEzxwRaHcmB2CFG0/kN5u3YcwVtxrrr3SyWojlrbF6FT4pv087aKktrXFpj6Y1pubDPnp7Qq+WR1eLwul1D0VyXt4lV1XWH0ZzOO2q65m4FOXLNntqaXBdHS5SNJPazPT7XRWoYSAnmLxbF5SnYG5R3mTkFe/PDUKgHSmluiQ9TiOx2E9OUmfLVpDbEeO0a1QSR6lMDfthqLdT17EK5EcYWNho3UegtXxlnrbGUCsV0GikKczJ/rCC2gFaceStPbhzwSUoQp30HNY3D9eYykoVBcgk/UjdhsfqBJsmKpfiFj3v8BffSGlTh5vsh6QPaUTSYzmw4WxhU9YQC7sp1u3G8qNkT+IweRoPrh+ZCLpmszabAnEE6os3dL4zNtpnigu1Duu+B3cUdtcluWQuvThc+fjzx/v09H58unN2kdwqRrKYni9jrCuri6KZYm+OXFUYYIbgVzhCP+rkUHL66HlS205p3ew53Lyg5EUhInOmeVqxLZV1OfO1/+xqtKUtZqff3zKWYO3ZTdL3wcj0xXx75z77nu3ntrbeorXJZLlQVltOZ4+nMrjQTOEdrWqobHKeQICSCNnKwTCJLdbZDJoz9SWj0ukB5RJd7+nrPWhdaiOz2t4hG1kvhcjpxvjxSWzUyUYi+azKBeG124MGY9h261HH4eyHBCoWhatZ8KG0rJKMJ+MzhJ3btBonbVDUg22GSum05fM+yETWcZr0Vz0HvfrrrknAtZFzvo/EY97TFuY9iq9uu6qoZsykwqcHd2mGKiUTdduVdG7Wu1F7IsjNyqlohVRWWpXEuhbUWpBs7chdg50SiFOA4GUoyByOLBTfu1WZBjTGYq84kkUg3iYRYgxeDQffiO/VhsrP25oSYTl0WdMp2zbvExDSOg/Dg5wvXwr4FpQ7o1wtRqQ5UD4ZwiAxvQztyPPrjWyw/n+siZReBjca7PJGiWOeBEqQxxYmUI4cpE7N1+KhZ1NRaCWJY7Dxn0piCcmY4h+NQe0M8nvnqShyjYeKRG2JMNnH54VpWY26pKiRbGpTFoAETjdpNkZPBi1M2RqE000uNSerq2H1dTIprHWIIJDePHSJAmzwCpcdNeDsOX1UrLpYP5ToVUegJNNnyXf1GHDf/Jsqs9n74RRlD8hHftWlBfI+T/CcaRj+sYAYUY5exXei1VWoxyxdkwJXRzGtrQ9Vgxhjw5GR8n4BDaRZyqRJ8kV9YauNUKvdL4ZNT4eGyUksjx+Tu8Pb+tN4dopMnuUbqjCX7OWNS83EbP1ktO8vJJKZjsXGvpZl4+xptOqIhENNMiNNYGNH0BDFQWuHTV5/Q0sT6SWOXE7vdzg4V7bTlzOWY0bZSy9n2F3kizDN5f8PNi9foKOv5kcPxOdO8Yz0/OpTXiFgQZGurT7kZjRbYaILpRi8VypnUFqRfqMsjrRTC8RkQWC6F+08fuH/1itPDvV1j0W3CfPkS3MYHHClQt+naDt+BzqnPmAaa9t4+A/eNiR29ehUOJizijh5paBS9MZINadyYtHapje7eJnjEJ6btB/YNErR915NAxXHIdpsWgzchm3AXNoKEIG7JZtd/C8M8V3wFAFPoXFCfyq/MQEsGirQq5qSuDg/W6snhxuLsInT3k5ykMwlMIZtWsdXNtGA0iUFgim5h5c9B/LQycmPYXjfqEhaUfQxMEphDIPSKLpUeILkR7Bj5eq8o0ckW3Qu6+iRs97o
h9dbQ2rU29oNDY2XnRIiTvQc/H8S8KZmqOichRtt12AGupChM2boLkl0YY5S3QDz7HlNOrq+yMV0cghg7GxUTQvYUkZAwZ3BIORJVmVJimswbrawn6nqxA7pGx/Zt6RwR5pyZpmwO3dU+oJySQQM+0wtO1lH18Xp0IN13CXYoDifh7aZ0QaUAk1/cvgHYbnwT6PXPQGSmfA4bXDP2CE/x+W1k905ywDW6TXLGGrOv0/GdweHG5oVv7NlaayzLwrIU1rVS3D4qBAUxvz+JmZxsX5RyNt2UQ07N9z+JJzRxhwYJYhNfiKQkHHcBlWRMwQA5Czn7XkzUD87OYMeOg6+2wfzCYD8/8OL4QWJFaghTmQ/I3Rv03Z2TTKJBfmrTX4gTacp87/f+In7sX/045wZ5mmgB7j+pFiIZAjf7HbvDjIbOcnlkXS6U1sjznt3hhrS74fF04tOXj7zzhZlpnni8/9R2gOlIqEotF0JopDhvjFR12EvbivSVenpFPX+CLvdQV4t7P59hPhIPe8JaWD6qsPrrC8FMcDu0UjcyTQjBU1mv92UYcGC/CrJDMIbZ2Mc8JVUYxORZSk86daF7YxBtigNwm65BIhqFCRnWaJ2naBxcNVa+vmI4SnTXStrfueNDt+8rBlp6fL2L+32atuuf63uqQzzvU6NY1EXEiqx0GKGroTWSmyOvJZIXR2xaB3eoAJuCL3RirUiCUqBFQXJCeiM2R4diMLWiVmMgG7hk3pfYNR90yA5Gc2qQ/iHb/rUJzBlyaExBvPm3/x2m1mZGO/LEhqntSGHw+2ZQ8v2+ZDQP4313qL7686mOMvxsj891kbL49oh2g8tyEqY8bZ1WioEoaSAM201FjGjqzkxRLKZZAKNljm6gd4GQkBxA4xanvHV/lrzneCsM/coQRPbuQsNx2E+Zm92OrnA+X1guF2opRJJFs6dsUFtoW2eGF5n2hBE1bs6hNwkeUzJeX291OzO6VwvFmDzjprwyrmyn9BTK+Dd//7QL0963YEd98mf1SRcLAy5xyAQ7tI2RZRqN2pUqZiezrGXDsFMI1DB887wgBHOuqH1AL4bz5xSIcbIJNSVSauzmxO1+5p3nC8fplrWZ24T27g4liX1Kzj40EXZy8Ta4tkWeMihxhw8j5ZgTujtJR6G1Spz2TIc70v6WnmZzdhCPpOiGd867Hfev7nnnnXd48fwZ4eFs11pprBUWbdzsAvt94pd85XsIMXE6L9TVrJRSjkzzDpHIh++/5IMPPuTZzR1ZVnT5FA3Cwmru1ymSUkZDpnvz0loh1AptgcsrwvpAuTzQl5M5U08z8TDT+4o2uy46AnlH0LLtPSXgGifQZkXh2tzYzwrubj2ibOw6uE4tMfob+gQ6I7KRlcD3zL07zJjc2f+zMNOAB+1nXHdL42cPmHh8lvb3fuWOHRdse5KrYvfa4YdBhmM0jf5nziwU4tbsBCcJRHHoLRoRSGU7IYjicTXTRO/C7mw5z90z5jbo3puf3RS5mRI3U2CXhNiLW6qZAaz9fHM+CfoUMrUGWeLVxm1YR80xcLffMadGcSf0GIUpBXYxuizBJkptFkNvQuDGMNx9Ggop8Xq2jr2g6PXcw6co8c9FvMH9zOLxP/D4XBepgGl6UrBfUzLGlo3AJlxNMROQzZ8ubvCALwLVKJjB38TRNSagBGFt3aIsioWfKdbB5BjRHAmabWeUEhKtOqpCuSymewqRXm0hOsdETNkulGkygea6Gnsq2o0uA9Z7MsnUOvQf6nY21xtVnuxGbPkfaE9w9LgN/WzBZqNTBN0U91d6+/VGBz6DwW9RFuPi8w5W/Ou25br/93hOtry/Lly7gniUdsqVqSsSgouxK4/Lym7Km7DabH0iPRZbCj9hlUWf6jKRnWa67hHguEssa2GtdpPhew5idAG2hSQe5h27ebKYAV/mG2wV6H4DD6PY5jusYdvTmvrOTokh0eOMpNnuPYeWoUO3dGQFvvsXfie/5Ht+IV/
72j8iRKgoLYCkTM7wq37F9/Fdv+BL9lkFKzakjPROTBMfffySf/QP/ynrcuI7v+MdzmGhXs6EeSKpkvNEmneAG/wGRbTR6wVdH2nLI/XVR7TlAV3O9LbQph273ZF1OqJV0PtGvV/RUmlaDeommB+kiKXxtmp7H2d/DjGRYDD1lQo+HAuurL6QPrv/kXGNAejVISKpF5qYGBNWV90gvnH9WWdnOU7jurjumvye3n6Gp90qg+Bvz2kQP7wBGbqq3o2QU50Wbk3dYLopuFbM9NHqe/HAHJR9Nlq2WVOZPi+Kkv3+bUHYR7hLkRJtpSAud8gS2OXEzRy5mRN3u0SmWzRIb6Ruk1rwXbh3yX5GWVFFDDEK6vv4KJZSgJKy7eo7yRqLIAYhjjPD7+faXZfmZr0D3rf4G5+UXQZh99eAel3g7/d8CEMkD1meNAvfwuNzXaRe3Oy4vdlx2E3spuQMMnGVtBgcGG2RHwb88IRGOe6TzVk8uCo9JiQkpHdqMx+u0s/WRfhCssdAzwnmCZ1t/2AU0mZaiVJYl3X7/gObV89wCjmx0x0Lavge4wO1gz0EtoCx1vrWYak2knfqSv+M2SPdxvIgYjDFuPHBls5q/ZkPY7Zg1rGojmZMOy44L4IblPhkurJJUreQSXCB8ZP2aHO3dhJKkOivxTrreXLNToDjfnZcu3O6rNyfF6Zp4mY/uSBaCCmgLT9RsxtEY3UnkOJEjIl5mjgedrRloRT77Hq/0pqrT44pWrR6jsnez3g1O7UO1Xdy+uRGeuIlhwi9WxSBxGgkiTh5Me5Gkxa7sbsqMWVUOlMWfu2v+ZXQznz0/oco8P4nrzhfLvzyX/qf8Kt+5fdh/lsuEG19pDuACv/qX/0EP/Ij/yvf/V3fhvYLjw8Lh+ORw+0du8MtcdoZg08VKY9oWen1Qrs80pd7qAutnmm1sJ9n5vnGJBldOcoN8bXX+HTpfHz6Juc2MfXCQ6sWftldkxhk0zqxFQwXOoh85lcMg9RwZWVKuBYSfXrNMCaVsFmYibMkh1OD+A5kfAamC+7bz7t+PJ8lYowm68rgs/u/u+zD+i//eU/JFj4FxuGF59dcSpEQEiMANQgUJ1rMKdAnNUeIGFk6hJw4pMgums0arZHayg2NN3eBuWUvqPYa5xg47Gb2k5ElDtmc+6lG9d72pWJY2jDDRsQc771uGSzu79Wo2GEcCdG/ziUJQSxZOAja/c0JXqS6S0EY54bJKkS8sOsToN/PCtr1zACPDcF288aKrP+ek/2zj891kXrzxZHbmwM5pY2ibZTiQIjjhnEsOZmR44gB2JajzdwJau/WxUSjIPcY0AAShaQBSmftxeCSEKAJ2gNFF+pyMkZTyqRpRmImTpO5nbsFSq/NIh98KRyQYZNBVzU7m2AHZsAmB1HTnaTE1tmgAQ0jx0q3CYauiHeoXb1IYVCgrZ7s9WwXENiFG+PVa6tfb9J/87BpPkGkGEGvcMigpw6blTFNbdOXyBanogF3dLAdYk4B1cmg0dZMLFknHpcZDYnjfjLhsXeeecp03C8uDgV8R9Xerymb7OCwm1j22RT6tTrz0tNU/XQa3WEQo/RqjFy95sBuOSviY1Icn1sQodF9HwA5T+z2B9zscfs+2ptHbcA075h3M5fTp9zdznz1P/8+PvnoU9795nv84u/+ErvdzJe++AU+fO+n4Y03kZjMj229wLrw+PDAj//rr/O1f/B/8OLFc/6zX/YVJDRub9/gePucPM/govYstodo1t4zKL/khKTAlAP1Mpn5sgTybmLazybunQ7cvv4F3n72C5D5x7j/mZ8gnz/icrqgnivU2uoTbt4+A8OVdZtAxkPVw0J7J6dECJHgGV4btdwbSG8NrHn05mjsG1Fx2y7dcuPsvR7F6docDcr9VijjKJT4772REPXlvfq0bobQ43nbrsvuE/FDeMTXbPEi2ump0zUytUBJQp0iEgu5JubcuZiAjjkZWcqgtMaklZuo6C7yWtz
Zvi1A0M4UhN2czRdUbD9Eb/Ro9lWCuObJXNJ7tKaRbjC0uoWZDfQ27UiEEBMhjkiQcCVYaTdGaw4GdnpyAVItpMGp6DY924qjFPOJ1BBdFeOITbPpuPv3FW86zBpJKOuZ/z95/xJr25ad5YJf648x5lxrP06ceDpshzHYCZjEmaSlm45SZoGHuEZCwpYoYRcoWYYClhCyRAUkHgKlEBWgSsUVI1EBWRZCAumClZAIX5FcgS9cZ9rYEXEi4pyzH2vNOUZ/tCy01vuY60Rgh3VF3nsUU7Fjn733WnONOUbvvbX2t7/9f+vK4/5tEKQ+8fyeZ8/OzoIzuCjm7D0EJxSMpr9k653s+2S4iDdllQEbDG24gRGYw6W2HXpFWnHNU7UMXASpNn8SQmZJC0tejdK9rCwh0vZCUCjgzD+rEkwtwCE8bzC3GrwpyWTPmdKCXcvoUIZgTfxJGtBj2n4og8twvdSRwbqjL3i2dQP5mS6NVUXiVtAxTrNEOIKRCcjqPMxHA3m4nd4GtJG9BpHZxyLiDCSDaFWT04+H++jKXe3Ujg8eJ2JM5vmUAn2wKmN0VW/7jBN6nJmw/biwZBtCFcvgog9Emkq+tcfN2sXZkTCDsN1xJ4M4hCF+/5KA6RTauyHuuNrabGKNJvqAgtZ15fnLl4hW6vbI48MDSuPhw1fI3ZmvSeflOy/o+zP2Nx+Qfebt6199n//4y7/Ce1//Ond3Z374v/tBvvCFz5k5XVpoGFRDq5gq/RXdHoDOaV0dWooUFO07tGLIQqsELDHK+Y74/DkqndgKn3r3k9z/n/87Xv3O/wNf/k//E+3XfoX+5n1qe0QEclhcvcLus3iWrqKu0iPzlw1vGlQ+2JGEhF+E3W/xA9BvvE1EuLeZjkesHPtg/FynRTtOG8JBkDpGM+xNZb55d4KE9a2RYKr6Mq7P3nuyz1Sow64lRx+/cJq7Cq3LdAiw6jGa0G5US5b3SgeSdLSadFXEHAHuEsQ10oIRwcDmyYIYIWzNJmUWYLYNBkpWWqO2TvluWroAAQAASURBVFJ141KH1WWSxg3REKzvjinTmGisE2q6D9lO5MNufh+C0dGct3OyvrndH9srJUW2bTNJNQEIdAJNbDxBu6m4Jy8gbPj+sB4JN8nub/b6WAep58/veHZ3JsY8oQezNe4z+x4VbhsNzmzYae+dXqtthODaZw49DNpm9gAXVNFSDNZRo1cGx1SDBs+cA4njxmsMdI00KdjcltBboW/u4RIjdFcE9/6FQQ82U3UbfObmhVFN2zCcxLkxxmyW6vDxGfCgvWfvMv/uVgrmtoc0WHyCDwnOzX38YmSXHBcknm3OQOcvm+2yPlm8gQxtRAC3eHB402E2UNbWnKTiizxlmvis2pLoPXjcsAvRboFWZbSbg48F2L+LD34a1ORiw5auz2BbZ/9IJ6XXAvtttn706wRzEJbeqX2nlc2qNG2G3/vPNsa2PRMkkNcTLz/xada88vzFSz756c/w5usfsD285f5u4fmLO85rIkpFyoWHV6/59f/yX3jz8IYvfOG7+N7f+d18/nPvujPxySr20x05L9TrlVaaVURaidopj7tl5dHgSW3m7CpaEIoJwnbI8YSme0LItHqllwvn+3fJz9/lxbuf5cPv/F6+8sv/Ix/+2v9MaLslZiI28ONnjfUgxpoDn5fw/aUOY0fvj4J0RzVnjhHmIPVkg+qhYNH1WK/dE4YBMU1fN2TugyFiO0cqJrwoI9+znxlxNZpoxJBZTctc7jEdxCTrYY85LUGaoSE9BJoYdL+kTJFmH1JMgk1Eoe6g0dGNTo4YlB0tEODJWnQGpFncuApKsDWJ98ly71TtU/yXkFwoeFSBhqREEbIoSzC5M9s5tveaiwGHYJZDIkLops1pxq525qUYLZC7JJOqmp+aRrZd6WGgQn5vRJE++uRDYs3OiK4G5R4KOr/562MdpFLK5JwddsDWXohoNwM+wVg4wQf3YjK18loipRSK6gx
mBjcZvKa90coGMVmwCoIsiTZZOgPi6HPQoxZzAl5QQk5GMIjNDsVgGYllXI2+7z7zYZj6GDAxccruc0lPIbPWqg9EHhtoBODx6jeV2S0+P1WIJ/wRn7Dwbn+HA6IJPmIi4xAQfEo8zd7SQRc2aFNgklOCvTFop/Xj6wckqM7AmnRmdMJIfcIFngk7dRmRJ+rTs8i00U+7L9h8kDjbyOCh8TlGAHcpF1etljbozjYCPYP5R/oswIR5rJqzmyR9J1PpPTvUZ0fYOCxSXli0odroUpG0sN7d85nTyuc//x3U7QK9oj7rdH37ltoK+RT47t/xOb7n+34nd/fPWNeMxIhGs+ioGmhdKMX6FCLiHkkJyiNJCvv1gcfH1+h2tUHTBLVeibHTQkLXE315QV2eEeOZJpnehdhMBfv+7h6+9/upy5kY73j40n+C/taIFG450ZtMZRPR/g33bFS6evP3pmAipv+onsgFq0q0H0r+o4ek6oaY/aNr3OYHh6rC6KEeMOAY4vUAetPTenKdDjMquDWGR6jg8CTMiGrAi6MRKgfVHPuMa3ZClZgzg8Fy9vyDX3sMnZii2d67+NjU7QtHFSSe1I315nrRqD/r6moaEoKTlI4esVr9asPsUQhDraY3T3bt/g3rdxGDB6Mnwdbf9z6hs2NDCN77tucaxK0/1NCFlpTcAlUH6UlsT/k2XdZ03N9v4fWxDlJjkc4D3R91G1WUgrqkzlAM6AwabKT3ZOeaQ2zBATETwWyT6WPwjnoWAa2VmcXPhx1sqFRdjiZFXFncOYEBz4A6Iu4MLMFhx6FsoYctgTKZNHhvyzat4euWiY6r4CZ9Pf5zNHl7H9DhN967UVWN6u2WFRVuNzGWIRmc9nRuahriqW30MUBpQfOYoZkZnj+vJ/NXraFiA7Q26R4tIdBB4IjzAOnjcOqDQdRRHUZRXkvZF+KhabIXw5hJ41g7Kn2S8UTVXU7t14ApP9qY7wohmYp4b516eUQub+E+E3zo2w4b6FN1PExIRkJE0oKoZb7hFKHtbPuF6+WBy+OFEGwNpjXbtENU8rqQTmeW05m8nInr2YKWdmzs3H5HOomd8vYDpD/C/sCr977K/rDx6c98mqYbcRHC3TsseSXllZ7u0HwmxjMh3yFxoSnEDKdw4nPf9btYVfn17RV8sNOiet+tUsUqSIOCXU3ihoAwNPgkjJnAUYmoiQILLso71l7wnMOo/K11D/xjJm48l1ENHyXcLTMWGXva66luCguI76cwfo7v53YkXsc+OZRYRqU1gkh1sdnbHhkilph64IvB3KpDUbMKGp/R7WBUo322rqjLGQ34fMgSxanwYsFSFUK04Be8H2zVZJwB93bHp5CIdGiOFPV0o8w+UKdxL7ufLQaJjzNi2CHZNSiEDjfmox2OJEKV4L5UltP5GIIy92wb5+tv8fpYB6kRbAYkZRG+u1LyOIAtQHmbHVGTCZFoh0N14UfB+0AjYMjI3AYG7NOeN1FEsT5Eaw59tMZ1u9hGFMsoazUViqBGg+8IIaotGPEmpdM7q7tgot2+R9QPSS/RB1SCHepRjt6QIW5+D2pziMOa/a7wNIP5E8bazfeN/7bfx3+LZ5xh/lk7iJM77L8PxuSQfFW/Tw2lDpruNzvwxb/+JtGwhq5DHBOms8/evD90yOL0eU1T00KCHULBrnPARePAGQeh0Z2PSklHv0I5KqjZizsav2DrokuEupMlsHdTE2+tWK/M+3u25xUC1BssbMgMhZApe6P2HdWIhBVJFfKCJEzAtHdiSKynO2Je0bAS8j1pOYMm+r6j/UKvF7RXlzx6i15esT28QutG1M7+cOFX/vP/wtfe+xKf+tS7vPPpT3J3l4kSCVrt+8GRh8WqyhQofSP1TIyR55/9HJ/6Xb+H1/9J6fsbtG7UshFCIXpZezuLd5tIDkh37K2RmEk6AG0R930L49Eam61j91DHGnm6gn2tHLN6s4a6gbONFt1HrW0Z/jhEvaqbvU2
ONda9N3YkbLa7Wh8zVME3oK/XbomPmWca+uEYg0FqN4h5cwSneuRKIc2gFGMghkPy6LinnoCOs6G1SZS63bNLzqzrauMxYDJu23USWaIM9p33qIZdjzOQabiMln/ukJz5fMyDhdQnatBUCc0k2xRhXQYJxQCV1hyq7VZx5fxtIIt0m70Y7u+NRWzwbER/na6xo+nnvRmHANGAdOsvWOVjjX3UFKQHFNdGWS0je7dANbKr1uoMWBKMTdZ9JTfM3mPUPoIxrmzy3TKQ4CQO21BOyPDPmRcTQlVfINbzGBYFjkp8JHuyjeileopP7tsBh3BjDX7cyzGEKz5nocEb33IojwcJpl0mfj2qDrEditSTOiyGSY9ME/85wvizVSajvzcOBTxrFYcUVNUckBnT/Mch5DQI+17FlaD9TozKaFZv473tC2prnnvYITjmTUNavaKNEzYFiNrRkGlbcDPNBcF6EykEmk292rpslbpvNmTdD8+jmWVKI6+ZbSuIQlpPnHmH1isipiDSa6dLssHh83Py+oyYFnor1IfXPL7/61w++ArP7xZiqDy+eo/68MBXv/Q1Hh8upleoneu1sp530vlEXJ8R8gmVTm2PhMtXCfkObRd6vSK9kO9eEMIZzZGunef3z7n7vv8Tv74+5/GX/1/0+kjbA60kH++wOb3j892gHUMte/SVONZJq4dmnyUOGINWlZzHIdeoQ15rwoBHn1Mtjh3MO5iSXOPr4rA5RzlgcJ2J0ljHhgD4OaE2KDu/5iZGxrxMGNEMFiMqzdyJF4e2/cyxPaeTIQpCKYVaq/mmqSfK+AjHSLDjSAg7IQZSSoCt19aVEDs4i3UE3xgTy7pyursjpmRC0r529+0KWF/IEJqAhoATHQ3adJ3TSIDutkNBpnqEnQ3REz3vybWGRJnEjBRcwJqRRNv9S45IfWu0iY95kEKGC6QrKozJVD/ktbtY4mzYgQ7hVKwHEmKwr+vjoMMXp0NSjpfXqnNIVhCDkhSCOjtQcdaZV3eqtsk44IEoEe0m2XSrEmGuuAHE9ay0eyAy4VAz+LMMa3xG6zkeZI/RHrYhUIep/LPHZPbhve4Tu+4+ywMcWWI4frdzxD6jLTIP8MoMpFOPa+D4ijN3jglzC6BDRdwP+xGo4EllYf230TsL/vUHDdlIl/LkesRhnu4HjHh2PGAPme93BNZJkLk9bIJtcIuXVnfbmRM8s03zOd421lu4QwKcCPTyQG/v2IEUOuavKvQ6DPKO3lirzSfxj6x4WRZagb53ltMZK8670f5Xp+Gv94R8QuJq621/hLfv8cv//v/Nl77yVe5PkXfOgTdf/zrvv/c+z+/OiBbu7ha+53u/h7Rk0qIs73yC9Pwl8e7e1ntvUDdns45sR5G+w+kz9PXsmpaV5f7MZ37H9/P1yxte/fr/wvMc2bY39BDZq/dlnXFrj/igc4/EMoxDrw2jwD4r3Anf9UO5ZKy35NDWQFCmtJHohN5H8AcLUnIzKzUCW/QK56iu/Rq7GRsa4/aongI2EC6+LY61e5hixmCVRcxG57be7pCG8iDlzD3rm1lArdUPcoeHmdWNP4NgM1Jj9jAMeFdBSpuBsA9FnBCISyYuNpYwoGYNCUJ1w8xmibTPYo7IOyrEGANpSQYl1zGP6XC77zf7vKaArj6XadtoDEZ3QgqepBsEHG8SF2kfrYi/+etjHaRigCAKarMoRh12RhFeY5q84lgmPrhrC683t033ABPGwQw+ADqqhIF9B5DmGaMdNkMPr6uSUp7ZIhjLNnmaFQdDaZiEed8lxkj1RWlZtw+fjoxOTN1iZHxxMKYE/2wOpYSjv9R7t+sIpjsYzUDKIRSrKLvfszGUO3Hsef5aphdGRjpx/VG5ePCNkYgSu2PStXnj/KYqG5Wbw7NmJCjHZh/YuDM1hr3EgHNHkBqfdUCN1nweh4r/PIdAblmGT0kkYR6Ut38fGfHy+Df7Fb2KuoUNxwFin1+BqFD7TtedvVc6CdF
K6EB1KHB7ZN+uaK8k6TTsutOg2RPQCmExJmoQ4TxxfYMjm0ROKdLqRqRyff11fu1/+g/86n/5Nd7Wxle+Xnh8/32enwKffHbi81/4PC+eZyQ0lvsTn/38d9JUycvK8vITxOWMhoyEbAQIbU5RL+j+QK0b1lD5FCGf6Jih3rKc+fTv+f2k08oHv/LvOT1fqPvGeXEiQEgzITDPtz7hNxGbRzrgeZ5UO7dV/uh5dnehNgdejoRpVj239hHHs78l96Bj9RoZKt4E0TEGMga/RcTYpDpmuHRkXEy3WVcb6V4hzdmvYVfiXyMcwTZ4QjIo7K02s+hoLgF1k1D5oeaJ4xGojggsQDG4oMU57DzWeUpmwioI2uy5xWR/RpspqGjBlN5tzw+1m9kHa5Ea3I0A8bPPTlNrORuzT/2zjwBs+9eShNYb0sbeccFo9Tm+b+H1sQ5SvVQzDwRmyj8O8AEpGAB6VDeMYNVptU4dstFrAEE9+xkZ0LjpKSW0+vxGuDnIumUQMcUpBXI7US9wBKnenTVn7xdiIjM2myuOj40CHkDGphuf1DZMCGl8JGBUNUK8mWeIKR8Hsw8xam/QBImGQ9sFHRi+8DRYWZCQJxbYjBgaDoKA+EIfbKQgQ2roGK6MM+AelRs3n1/VVNbnQo831duAdCSSHBqMyT/TCExOLBm37Gn/7TgEbqGiASuOg0RkBHwfQQg3A84c160T+MTlejraH+nbisY76J3WCrVcafsD2/ViPj3GZ6P0ghKJIfsIgWnoGQRo153zOlUbVALUDb18CLXwtQ++zq/9p//Il37113n3O76Dd6Txq//zr3N3vucz3/GM7/70He984sT9/Ym7Z2eqJMLpxCmfyOuJ5XyH0aOFUjshQas7SiMjiBRTnDfhNyR8CsLKMP+rpzPvfM/3czrf8fpL/x/y5TXnoLTkO0pGT6c5kzN5oBiBZAxLH9XLITd0wLgziKjOSnmoHgyqm/Y+pZomPVzEURL7mqF2b8+TWdGr9kPSywe/Q7A13ydcePTZxl7sXR0is+AVB3zLAcEPl+0+SVTMoIPeag8+JWwckOQgDt0kbDdfF3Ime5CvHswtYTYt0CjulAtoD0hKaAj0XjHh5EDolngjMl3AEbv23huhpWkJ0gx3nYiTwvSnO/rZAxURNCiBOM+CUZWOfu+38vpYB6lWCvWGaz/EK2ej3SnR4xeM2gO0V8yvyGCDOfirShObabKFHowuDoApJYjcClaaWyxD+ZkjSCHHwzO6+YFLMw7smFzMtjtei88nHJtZb5QQxgExGvMxjJmnp/IvKWWrAMZsig4mYqfWQhDLtCQ2D4Q4HXscBl71JPOm8v/57JIFjmFWZ1lhMHz6Rqooul6YfKRZNhfrvHaDvLpXviNIjftr96bhZ40/5zwlevCANynoIwh5MNebZ2ADhl5J9xGkPD0JpmYgHFWpDuRcRpPdn8M4hQBVe/ZBG+v+yN4XWhaaNqru7G2jl2qHEaMKsIogLZEcjSovIshi8LHIgGFNKV8QatnIsnP52n9hf3zDV37jS+j2wPf+3t/FZz//Bd68/1W2r3+VF59NfMfnPsXL53ecTtmuLd/xiZefJJ3uoOzzUE/J4eleANPk691cp1NUq3rbhXb9AEkrcvI5nFYtMQsLLz7zncQoXL72q8Ty6D0Hq0JHkOq9fYNW3ghStxUUvn8sKvONB7dlELYvZ1/WyVK1zSTjFva7rdJmUsRITpiwXHeY/TYQKaMSYu6DkTQ5IjlTx5vCnEHimFlkPz7fUaUzgyC902o79DV9f4H36dyYc+zBcc5E3zMmQl1nVWbMwKGKYfctpEzMZvdhDsCd2DGFczXVjWXJhGTsVPOaczjRB4elNh/WN8at+FzqE3LRrPS+cRRhBKkgAY0738rrYx2kaJXeDnNCDQE0uTDpoGoeJf6oNEyReUihuGSK4MwDV0gQt8hmHBaBmIAQjqxH1RqFIfqisqbqccCNjFIcZ/bfVc3mQSw
I4BJHxpLzykD7nDcZcJOt8UbwQ/xpdnVkYzEmUl7839xCI9gi7q0TY6L3ZsODgqm8g0k3+YxWcHHTlNLx2VTNA2cc8s1waGuqhjnvhR7zZwND/yhsiAw4chxcQ7/QaOjykSDVWrEAIoKpYqTZd7JA7IenGM6vN141E369gfrgmKsB6/1E4Rs21SDiWEwKx+cQh5KtLWH9M4TYdoI+cCm7DVoGpZVqFqzNvJckiLOrzPQyYCodIR4w0poXRJSyXxEttFYpl7dcH97nzVd/nce3b0hB+Pz3fDfh2TNiVJa68x3PI++chU++WMjrifP9PcSAxIwS0A7Jq7PeOk0aS7TgLZoxnTwj/HRpBK9yZHvDGAeQeCK0ndNu2XVthZfvvmAJn2Z/9VUSJrgrNy6ut0nUtIPHYPbbs/12Xu/2dQQXbO04DNt7o/WK+9schJ0b2O+2fzqIPcMGfkS0sS5GEjFEcYf6yOyV4gnK+HrTbLKAMs4hD4DaPxKovK0gMj7P+F4PeD5L6BfGqNZubXTsszSzA2oDmbG9pN0Crbro7G2wPX6386+UQhqzhh6kUkqkHB1BcNWJVs2rrKtDkzaeM5Ls4SAwgu643/N8lKNnfdtJCEFI8u0QpIJVEs1hmiHkOkvqm9K49z4hIlWjh0aHJVobj1Gs6YcNzQ0LgOjQj4RoCgN1nxl4dLHSITnUB4bsh+mclvdrG4sw+DUGH9DrlqrNjaHeMxrMHhkOr07LDRLcbdMefWtlftYQTDw1DtX1OCDRlVJMEofeXeMwotFKiVrNiHCyeZJXWwO2vPld5JA/OmYsDmZfQGfwkviNUMU4ZGaTfGS2N5ns7bMbcykjWItXgsoNMeKm2lEPet+Qxd0kGbc/J97YDdxmg7WWJ1n4fK8AqDANMv39VYREoT2+pddKipntWrhuG+KUXxUj/MSc6HthL4VaCnlJ1ohuHdmvPoO0U7Xw8PCW/fpIvzxQ90qVzPnunuXZc9pW6G++zMNX/jPL1kine1JeWU6rEW9S4nQ+oaIkaYR8b6w0u2g0JDvcamNZMylmVAIuek7XwIlKv76iS0ROz0EbWjeopqNdauUuR2RZ6fXqM4WJ1kbvWBBngt7CW12OTByYMma+G2/W9NGX6yKYRUUgkqwB332ujuOZjmpa59/7c0TmM5s/a6xpIhIDQdTHKPochQg362VWfd4DH5DgSGpmFebVf/ehdvTpz7td52Otf7R6/GjS1LoFCft3JrKhDnH2Pvwabj6nB5LBts21+gijuHq8e+6NPpFaEjqClOB0eQ/whqLYvs43+3hUhjY2cey14/kdn6ey8a28PtZByg7AOE0HW630UhgAnz0gx8ZzRKNpYBmDxm5sbUND7nhP05ITguu2mGeTVS/ajEI8HEZztINlPlxnKYU4ZFY8m+ueubZm8KEMqql9BmTM6XgGwpB2uq2kvA/ilNYxQ6FYNoTDQ01M1Tx5RTVo9M1LejBViOiHfrvxeirlgHtiOggKYzDYFtjorw0o7Zihsr+vjKAzGHpzZm1ksjc6brf9uym+C7OiDF09WMLIJkd/a4hdIgeTT1UJXgHCyNw8iAaTr/If4FCRBY4Q7YBi9BMwsdaRWNwGWuuzwFPzR3P6XaTzziq0HNk1UK+dt1/9Mq0UuFtppXJ6+YxrX8lvLrx5eI/L4yMv33mX2ipLTFw7fPjmQ95595MAPD6+MYhaOz0kKkpIkRyVfn3Fm69+hfL6Q7R3UnpJjpm8rEhKxJxoavYwOS2u8DAycDOhi3mh98LeOnE5kZeVsu1Irahu9JRQGrq9Qtt1Ji+t7dTi1R6NhWIzP1pM1FYCXSI5n9B2gW49xdCDkxQEpd6sn7FvPQt34HusI2QkeBEILs+UoR/CxlOCy3bENw1IHjueJJRg+1k8cI7AMujit4naOHi7V4hRleGwYInaGIY1Rp82+7cg4clBPdbhqDTHaxCO2keC1IT4nI1sKMXohTnBSA+C05MPLzL3a/drHSM
u3tV9kjQafb8ecOlHmJaW3NnZdCt7dvsZBnFisv/CML0Uavg26EnFtBAXs2hQBEIlLIsHDVvE0Q+fnLMdKHSDXLCDq/rDHg6cszztB14+elYxmodL2222AUwaJy0LKa+TJecQMIzsXgfu3cw7Cvu3oEYMMG8lK7E9JbKDlSPbDx7EdBzO4v0ebhr+YUBR3vwdC1idqq+dqKYPZpteTNcrLvZ9qsRoi3IMEx6Hb/UK9IADbHFbFTYhTlV6T3af1YkOIc3eWYjxBnpxGNXJIQZjeAU6+kmjH9V0XrPvNIcRjmcWBmsQnN59AEkTWhxWGtzAMBh1V+UgqsjNe3xU2V1EzBU5WH/NjyPrBWYhNHNs7r3z1TcX3r59zbou/Jv/4f9JujTWvBDeuac9v+O9977KKUHZGpVMuFs4vTixlSt3z++4NDWFdcW0Jq9Xoir3dyuffvclWSrxBP0usH76He7OK2nNpNWgtiVZEoU36lvr9HYhaIQQkZQIcTWVj3y2ybwY7X7AtMWonnihCm2nej9XpLPkwPXtlf3yliXazyl7J60ry/oMQqCWHcioNvBETwRiElobjs4eqDwZcUiE4fTr5YD/veltBvxzoBPivYX5CEeQOmLCYOLZfTEyTDzwKLC+bxsHdGd4SdlaEpcoC14dd4fy/L+7QtJJmjAK5+idiX+0A/5GDUm9hT1VrS0Q+qArHX8fQsSlKm5GLQbB6enXzvU/ArOK75MDehzJvLUqkl2nEz1UD3RqVISjYpxBSph6qK23+dlqq4zemnbvsUtwgd5Aub3Y3+T1sQ5Sd89fcHd3IsYEEnxodBwmcfZExoyLDcqNzMw2XQ8yK5xxEwVo5TDwO5QtoPVqA3ilgFpJmxZjS5mOoH1d78fQ4ei6DNLD+H0SO8IxM2UN/sFsO2RlbrH8STiQAfV5w9j7cEZv92l0x7vtBylDFkZ9MYmMbNPee2htjabzUQXdQg/HUPRY4OMeDWaTog4VqA0GuhWG2Ql4kiBPKymxAS8/ODwz7cNl9IAWx720e+Gb1KFFu1+Cyx88eR1sPA8qszdwwKj2WUwVQGVYkj+1flBgWGEPCEr986BKTskCa2gkgTfXR86feIf1Oz7LL//8/0Ak8hg793dnHi+P3L98ydsmvKLxye/9LC9Pwt1pJZ1PNDFttrdvX7ME5d3nL3hxTqwpcrcEUqts9cq7n3hOvz/Z81sW4unkVYs4W817B70R1Byoq8L5/hm1WVXVRQh58QpBCTnbjFcw2xCZbFIbkA+9EiKUckG0kVLgenkkaWUvjbevGmk58+z5O5zOz9BwjxLQlBCG06v1AnVWEp483FTsI8Gw6sjXX4gI8SaWjUpGZ3U7kaubIKXK/JpDFWSwUANaDQUYsL+tjzjX221vTIJpfqJj/wiiY0Skec4pFigdWpOPGD6O6xFJTxh/I4gHCUefzcOVCm41o0/IPNwQrPDtPjb/mHFS3zMqgsQ8Pzc3/zaey9GTG0FqMCQPFqYlugPTtKpxvqonpB+prkKInoV/G1h13D9/zv3d3ZShEcQXGd5eCm7zYBLywRln9u8yGUYz0nu52rVT05hdOOYNbKN0Sim0WkD9UE8rMa9ONLAH0Gqdk/XhZuFYkKj0Wmiloq1aLBErw4daRXI2T4iJoSo86KXjaywrEpJXLE8gM6v/mXJFElAx3ymTRRmNYfWp86MUvz2QRwDp/WjSjg10wG/Hx7NA5J83BBsgIkwWHp40jJAh4+tEMEsHZ27Rfe/I7Ac8wfDnPfXn1mfiDZjh5UEpH2y+Zu8rNyzP0Vto3bNus6kXdUkiOWDHW1hvBHYdA5ejN+GySqJtZqrtutMfd77r+7+P//I//jJf/dJ70AKPry9wt1JevOTl597h+7/wSZ7fB+q1UPdC1CuX1xvvvblwd7/yue/6PN/xqXe4C42ojTUqoQTIC3frSkkFiYHT3T2yLCaX5IdU60oMNj5BN0i3NSXnTCW62kciuFZz6wZNNQl
2+PhBiypCQ3qjbo9oEkQbbb+YfmGr1LrZOuiNy5sPubx5xf2zF+TnLznfPSfGldYDEozAQW8WpGxSa6wk7/fALUxsGzt4gHKvOIfGB1HhlhkKT/s8wffISEgQ8erbE5SUnNWmXnnMPoCvNl83ggv5jiDoCaKPDpjvmX2KDjY+MJERmQGzO1oTsflCVa+8/BqboyAjzxxReUDcygE3j71y7MXxs45xiq4gIc5h/8myxH9k8Hsw3mNA4gPibAeR6JBCG7p8ngAPuDz0w8/tJrkU9+yS+G0QpAD3TnH7h2DyPyNjvoVmYvYJ+Dmb41kUGJwQR1PSbjLBsmEJ0ambLkukQkhG3xXG1HUmxOwsM+/5+EF8LJoj0+tqFZyImL8zdtgFGSwYsWFcJz1MyA8Bd94VgdI7qjfzRNhGHaKoquq+OYpiChOHb5PNgzWfl7BrPqrA2w1+SzYY2e4IDArH5+o2RBwle9Y2+gg32R0yg/3YeSMA+F2C1mcWbFp6tnHHoTEDi6ofAV5JebaPmHPsUxaWX59vLB0VET65795Sw1oDr6RqK/NngUM9vvG7z5YcFZgF11r7wXqKpoj//gdf5fTiBb/vj/3f+fqXv8qv/PKvcP/yJZ/+3u/hnc99hmcvhM89g2f7I9dL5csfvM9yOvGV9z4khxf8ju/7Pr7zs58klwdifYtWYzsmUZ7dPyfnhZiLQWD5RMgrUQK1FgugEqjNEpScFyOKxMa1FEIStusVm9iDLoUeorn8pkT33ltrjRSEXjZCr+xvvo4E92Rqhb49ULaNvC5s1wsxZJYQ7Z7sD8ge2KSx5HvSeo8EU9PXIFZJTDTPD+JJwPPjOYxqwhEGCYc9SNejGh7qCcoMGuN9bmExvXnviWfhKvJYwiXJkZUR6GY15t/rEKgBFb7X9IDuDM47mH+jEu9j3wRB1Soo0ZHI+ayiOhHB/uABmSfB0yrSAXnz5Nw72JS+d/zagwvz2BvPbxoXbOdn8Pfxb1SXmfso6xaMJVsdGo2zMvT7H0xubt7mrlR35D1417/562MdpEwfSvxB+4HlLAgTfRwlvZEMrGIC3B00BBx796rBIhXSLQscEkc552kXDRa8NGVQJU4cWJ+IXA6ywVxMvrJtREoPCGM0fG/KkY4gsyc2JrTGoh2ZjNLchmJM34s4LVyPJi4zHwVEpuCteuleSrEhSD0YOPalx8E7g8oITtoIobtNxeEvM8r/Gey9tzRBNhFUutPZj3mWW/bc2ODG4NK5uS1IHdTWUSXNqs77XTG2WflN/6LeJlTa2jbv99ib9vhs4LGLV3jjHvZ6HGTzNhqcPA4ylXHi2bX0BhoTW++oJp6/+wle/dqvcPnSAzvCXjbOn3rO//W//8OkuHJGePeukfVD1tNzWn/NFz7/Gda7ez77ue/m9YNwevEJYlTCXlkSIIH9emW/XljyAr0ZKy3a2kZMlNUcWgPresf1ekViokkmpEgSnCQT6YgNGvsppinTa0VynGuvlkJeF8p2gboRtPD45q0lRtppZUcQLqrkfKKWxuXxEZHG6XxHbuaC3NpGvRSW85m0nGh95aPkG7sOCxrjuQePEKoe2PyrjWHZj8QnHKiIzrX2dLTAU31U9ObnHmcBcthlWELWXFniqOYtKbEk0xKdYy0HOeyDTIfTriMOqWsPQAdL8JaA4wd8bySts0I88t0RBG72hvqZgnzDZx0V/1zFDlYMkoTtX6NOaFfMmX7WbYY8eHDCE7MDzcGhWtsTt67eqBKjkziUuWcH/X8iLr/F62MepIxmrY4oCSbrbw/frSf6oSM1snjxhd5GRnVT7dhMjjNYVBFNdLGN0KVatu2L1QQXZc4l1FbAp+tH1oWOzPxo8COgrdHLjtbi/SHFmsFyZCxtpyenn4s1hrV3GjZZ3rpReEWiqRMjfu1jQdrOscPULdSrKbPXWqcacsOqpTz6SuqisN4gDW72iCoa1PtNndgjMYlTqv2ZVMybCQscHa/M8IXe481
mEW6XaR0wh+PbjMzRKe1zkePMp9aO/oP39dBtVm8hHjDHgCbqXkxphJGtmg38NCnUoyrsXSfTcEK3HrSHosWUyFHLoKMfpA8I+3rP25Q5feJzfD4tfPjBB7x++5p3Pv9ZPnP6AnF/5LQo2gsvlhNyuXJpcHrxkoVGjpHKwvNPviRHQXQjrmebVYvRrDSWxcRueyekBUkZjZkYF1o4m7dXTPT1DLqQKWzFfKVK66zrHa2DhIzuV/YAxMhJEvvbt8RnhaoLKQt1v7LVK6FtPLz6OtKuJIHL9QENpjtZa2NdBUmV+7uVdY18+PotH75+Tdkr+ol3ePbuJ4lLppdO6YV4WqznKULvAYh2qIdOVgiuL9fFkiK0uXRmc0jNeqzqc3u27m96OXqgA2n4Id2Qoo4g5U19ZK7bzoDAlpnsWcV3ywp8unaAKQ0G+Dqy/TjOD1WTSJpw4khGxfuy6jB8L/MsUl/7tsGPa5+wNqNHOqosS86tZzq0FMNE18fP7mr7LbjIAKOHPBNv2/PaFI1MpQ4brjekI4WDEHUrS3Xb6hjV48IdvXfKtxh+PtZBai87257AMV4Q7yv4Q7q5YXVIkowHKVZZ1Scw1hHtm8ujmFFimjRwa256FjKSvm6Y7JzTGget/37b30lpyNeryxM1y4IFAlYGSwhoMHfMEAKEgxY7qLGtVUppZiaoEEJ2AslR3tub6QzC1vA0vbC9mGJ7V4UonE9nQgree1If7hwwRp/ZWgjidNJg1VTp08GXuTgdkgnH5o03FVqcEinHhgX8etpBKeNmkbf2pDKtLmVjCYh6v3H0MMxnJ+fsdgCHQn3bC6XsHE1h0xKL4elWuK2A5yHBMa8zqmH76MeaSxLp2qgho+dMDwtEZb17zqfWE+/UT7HtV7btyvVypeyd0xLZ952lFrRUsxPvnaYZXV4QTp+wAKWRLBkpgmhlWQX6Pb1c0a6kvKKD4CDCsq5IFFpTgwBbN6PCkGkaIWQ0LNRa6HWnbRfYOsupoafA/via8ubC8uzTaMokxPpNfafvj5TrA0uKRBH2feN0uvP112h1583lkWVd+dxnPsmHr97wwdc+sJ5dSJyedeJy5nS+dw1DJQ4mogQI2deiaV0GUYTqUN5i1VQfkkcdG4wHxmgDmMiv6kzQbkcdopUHDJsKBgFnaANOyM/XbbR1xPj/uZ5vFw03AeKo0J6w7GaQal75jX7yAaNNJl7viNYbtOGWDCTzZ49/P4KiVytj/sxbFhAOpEKbqcNgrFp7LuOzuUeU+z2NgeshcfZk2Jnj2vR2SFvNXXi86RifGWxc7Z29H8HsN3t9rIOUYdI37DDPnm+D1PzVdcLT89C5yZr1I38WxFW3fQZJle6ubH30rcbXY7pZg5Y+4KIhtnirDjEGYMUzFBsbd1kS/1ijEpgVWzzs5OmWNbZWzc+mK6pCD8VJGUrxwV4PlTNLa1rpHUrr7LXRum2GnBJpgaQDZrSDorbCMHEcU+ki4qaE9hlxiwu79weMcQTnYyOL37MxizafRb/Vb3MrC9Wbn2e24trNtXgoe4z3nA6/4ahyuwilVAtAN7227bpRym5ZtmP5KXeC1G9cMw45HkFKzbYhJIM79Pj6cRhVMVVvTQliRmKyaoYCRcwxeMyvhUReTmi9cC0XQg9oKfT+lk6irffkdz6Fnl6SdSfsD/R+MaUM9yHRutHVHVWTHewmtdW9VxoobaerDap2FeL6nOZZtum9GZHiWh7R12+p+TV8srC/fcXly1/m+XcW+xx5obXqgcl6C9fLxrNn95QdD1RnStmp24ZIZ38spCic1sTzd17w8PotObzPkhZAeBQh5hPLaTUyRmuEtJCGqkRa6WJrI0j0UYwI3eH9fhze05BTLM8ZXF44zgNbdAPkkrkWxREO0T6b/3hiigw/qCMwzCNopoO+XkfmKp5YysHuHYEHbBzkgHBGP4q5WSzwdCPg+F9Y6odfh1fw441vr0iHCLaxnS3oJ4L4mIuA9gr
aXVHkRk7KmrGmgqI24jEDeLTkJ/hljj3c29izx1WMYDSey+2NE8yFIW/fBqaHYD0Y7QoxkqZL5Ijw/mUT9z3IC80XY9MDRpq25jr6KwdFOnh29lG2Sx/MJ9Wp1CxORRa6qaBzPD9VHIbyg9qz/t7N80oYzrBHpnZM7IM102zDdGeqWVWTjBfVlSDNFzloF3eyFVBBxQ5b0eSYsLgMWrSFfOOWi3az3sBo5DbcCNYnC6baPRKC+SFtEwwdwOGLMyFPQF3H7uix3cy2iN5Qh49ZExVxOMn6XYPxFUJAUpzZ57Sp77CpaRTmJVvfRTulNVoHc2wZM1HR0UU5NpRXyur02uCkiqEqfyuYeUtLRjDoI59oaSEEH+rWExIzMQRyq6adFozMINq5tJ0gZ2JWetnQGJB8T17ONImkuHiylSGvRqEW6NsjxJP9t1jFGFMiiR3E5kIbDLoOduinvHC5XhGgbY/U6yNBlLZdqA+v2atZ3F8fHqiPr3j/S53nL98hpMxWC9zdk3NCRDmfT1yuO2nJ7PtO2a+kENnKlVNOXLaNLSXW+xfcv3jB/f0zHt48cnl85D4nWrmCCPu1EkIi5RWhUstbiBFJZ2I4QVgRzrZ+pM81JzfzeIQB7970lhgu2wOCHpxS/yVPIWelWNYvRuYwf6inbMG5P0a/9wYSsz8KGnxk/GZ92D/2b0RaRhz1SzLnDAsW42pv/tmCjh70+yeHnXjyOHrjOENVxEhFLhjcNSIO8ckkFA2o3XUuxcZyZpskpQPkmF8PIseZOx5MkAD9SPwY/SqEEJLt4fxtIIs0GG211fnnUr9xMQnGQDGcV2djv6NOwzwycGAuittKaBIhGLj3uIYjO8sx+zCqv4/qMcR68z7W0MY3lNuJRxPItDf1JnIbfROh6xCTFUxeptMpDKMzdf+eGAXTTLNA0ZUju/NKhqaE6Mveq4WcV5blNGX0Sy1Qo8vx2UYaWr4igqTh0Klzg8zsySL1TK76XMyD2WSkii4Hsicza9UnG1L8GmvrdCIaAyE6FBt8ENgXvzWIG7ROD50wZseCqdPX3m1wekCEIZKjBbCQj6xvmCyGIMQ08mB7HxtpSPMZ3eLvqlbdxBhpIXMhE3sgE2iSEVnte1olJru22Cq0wPbYKFW4iwtL7pDPnF5+1mCaQaBJK12j2TKIGzOeIrFXVK3aF9ls3kmMDtxKQdRm1mRUDFrZH14hUdgfPuThw/ehFSQvBFGKFvT1K2rZyc9ONITrdSPmSlpWOspl31lOzwgxcn8K7NvGkjNt3+he+T4+vuXu2UtkOaHR+mXnZ/csz1+yXa48Xh54nl/67JahBGWvxLQQ0mLV+7aR0om0PId4b7JMckBq6lD0WF9jVmisG+0YNHtjoTO/dj64Pg/17hTrQbhSD3iHq4QyVNUtMMptvOE4QHwd6wQyQNX7ph5ExgYRW1+Tnjc/D5ZNjS+52Q/DwFBufuzt11lkGuMnMFwRxlcFnPDljOfRwhiEJjzRHgG/h2Cajxj8Z4hO9c90MJCfzpMNP/QwYUIYx4OjAd/C62MdpIYunPZGUzuU4g1MBgecNJr3oydlg71O67wJRoN+bCWtTPmXIYWvPhR5y1IZ0GAIPrchR+NR4KY3Y9etKEESxOTzTLZBiU6HlsFQ7MxlIpE2jm4VUwDIq1cnwYgTVhIRJBs80tVwfb9KVIh+mPd+YMghBvOdkYOlZ589EqNVMK01ApCy9b6GFMrQoxu9svEes2nt9z1GE1LtvaOuuddxJYuupvU2hn3FySgc1a9kJXZ3Ou3WqB1Q3pPzRu17WtmAp95CmZVBWadDcN3FKAHyYYYYQrC/T3EO8+IeViG5wknyQDYqLs/mrXcZqQitCmkvVBUTZQ0O9xJYklVCsQqaEl0StGfs+xvIgZ7OrOd3aJi7NDHPAdOQhh2JEDVA3em1EsUgpF4LTTsxCbVVI7x2JcZsZJntkbcfvEfLEd0eaQ+v0Lpz9+LT9ByIcWGNmRQ
iuggLZxAhL9kyFQl2DUTiYoE3NXXnXIN+BqS6LCvhfE++f2HP24eNT3mhXR65PDwgubCsmfP9PaUWI3SkFSQR+5X6+HXK5YH8/DOE9cWEl0caM1Kk3kc/eDwTmfNqB9lpJIojcRt9IR+0b82eIYGpvC+WWt0uMvEKrPfjnHl67jgbdu7ZAd+J79kBiftb+iwkHDmwITcfWdy3wXJWbx+p5nEoFJ13x1oQY9YP/zx+faMB59djMmHqxKYwmbMDYhz3fPZ11aqpcY7O+zDUYdSS5SOADfWNbwO4zx6iYc+9Na61sZ7uvKnX/CZY41S0Ma3P1WEQrEegCjllcjaMHtV5f+dcwA1sJZ6pImMSnCfq57dSP8wezZEkJRkzT+riq968HSwm9eCUI0NKSL066WJNx5BNysiwgaNisx+0zD7b8H8JIbA4ZDeo7bZYB51UiKETfCI+k1gbHqAqvRuMYn4+Qk5x6nXZJjc6b5iZpYMffnDM2a1+Mx8FtO7ae7OJbT/n1tIevCqsHRUbGkwpsSyZlI5JenGmld3tOzt0Bu7pm/tW+dyCkUPE/qwGOcYGwSM5L1OlGlcpuG20jyHlUYOLJCQGFlFqaIS2EMKKHaIdiWZiObNW+3RmmCmC8Cmzx0Ao3SHsvJBj8sw40rsRJ5II2hK0atRmCRCUhrFQewPCCqK07UqtD6QY2D/8Ovr4mtf7hbvFpMVqSOhqcG9Kia7CupzsPkk04tGy+loUlpypZafUQgrmENArEG1d966czveU3nl+PhNPd5DuDHItG+uaKDTefPgBJzq1VbbSWe/uqBla25EmSF7otRLYuX79P3P38jto6ycY6v6lbNMeZjQqO+N5GylgUrRvoLcjtI2+rp3eNobg/WjpBB9StYPWUJQQ4jREHDDXkDBT1AehvWrFh6CHLl4xmLc70jH3dxjQ4c380aDhjT0w9xoT8lcPLvb5fI86egQH5GexxIamLVF3G6Au86ywM0wmtfxgGyeGm/foI7RWLKB7gGttSEfdBGkPUEbJPypODcMt+NtBuy8Gk5aPRwUwy/CbPxsCN8gVPmRpZOEpzROTSfqEo66fC3u895N+V+/Qh7IvjpyNZuzQ1/NFOq9llOj9qO5EpjU84tVa82NXcAUDe5jHjMdBu3Y74YM8gmVlIyjEmwzPgpMNDx4Z1Q205WK5Q+FCu9F4W42TaAIWuGO4Ue+4gfvGfZuvYDp5ty6owfIEu6djANO1Fgfjb1DLR27YmnkbZTJBZDr8hniMEYxUwPp9I9NkQi/aOzkc2eAgs4QQ6GNM4eYgCzFCTE8ONBVDUsK458EhzHFvhBn0Q+gmAiwBiX74qAOgDj0PR1P73khOq/UXR1B1Ed3ua5WxgrrSaWjdjZ6v3Uks5pOWg1Cub5He2MtGrztvPvw659NKeXiNaiXESBNI5ztCB4mmrNI5hsMHecQg0obkhRQzrXdO5xPXywP7ZScHExeJBLb9wov7e+sRpcxWOndiGpVrWjmdT5R9I7TOs5fv0LYLEgMVC6Y53c+mfKs2UG/Vc2J7fG2al8sdpYpLNQm1uhDshAJhVj83z3Ssh2kRr/hsmR/IQ8/Pq4aunWEDMkgFgMsSHT9p0JTmmISqr4EhAWTfa0Sk4VIW5jrwaTBLgHUgQN+80nBUk8Hom9YnaoHiSXE0V/txrX4L/AwDGIGYA8kRS+i0B6DaDzwE3rxq8jU5A+mx780JYVyv3MB9YapvHGzg3/z1sQ5Sw7jv0LbyEt+zqmEkd8hTHUys8Q3Bdr+riR8+LMMi3qqRNlL1iQGLHzz2Pk+hxCCBFA7ZFhgHt2dXXRGfjQoyNoRYX0YPCGMMvd4Gy9tD1BrkfqCOB64K8hQjHotZp8bY00AikpGbA3vo9vXgw8JhkByOWS/7nPZZD3jP5qcsARwL1623/X6LKuad5c/QlbLHr/F+Isc9O2CUA5aN83APk+zRPfIayWVQ78cW9cPrJpgc/Sd//qMyBE8+3Evptlr
yt9LebUZtCLHiQ5E0qppfmTaIYgf/GBodGeZgeSJiJpTBDOcY7yX+ufy5jvcfh7doR/uOlgu9F7TvoBX6Do8PVDqPr94HrdRtI0Vhf/sBqZ2o+2bV8ukEIsT1TI6RXhTVSGuFvEa2/WrVpgRTSfd1lpeFfS9ct0diCNTeeHh8JAebtUsiXLcLd89fkM93kDOdyPTtUEHySg6muUmtEAWNYrNy+w5E8ik5qiGWeOVsdvbbB0Ah5XubUwyZFG5GOxREFGfHTDRkvCxZGDt2BKl0nP7+axJ3AkxYULzqZ5zNN2LIqjhv/khYB/zYDdGZMJnckm6+SXLnG+aWJXfsV0wke5wD3wT6fPrFzPdxbp4lyN2kmMZnm/dvfB8DnXAESpKrrHcGxX64MNxe+fwcar09p4Oh7hMnY8g/fhvMSUVsgl1H9SPQVI6sREf1hDdOD5XywfIzfHTMwMQ58yPzzD/YfsDEtNUDjMEFLkY7Vm6waxqzQOJluVXpwRZ9lwlDDGOzYcPupR5BTao/TDiJp0Ei5vl94t8zfsa4xqlJiLHI/PSfwQsRQlyeCNmOn5FcULTHm1kyrxjHqhTHsIfbsBEy7DVorTGMvpVVkdxclz8cP4ztnoEQ5uAiM4OdPbQQzDE42AZK3k+zjFL8/dvMlEe1OtbC04zSqeUc93B8LoNOIOjTg66L9fq0d9dclBlwxsCiHUhhKtwPjbaR7Iz5LuUI6L0rxqC+OZRGHtS6M1IVWkH3K7U8UB8/JLQL0q/Gzrs+ULdHyraDdu/5NYpgCv7bxQJ9ziznO0shJCBxMdXv1klpQftBCmkO2c5BdDXprbJfIZjqCylCqWz7xmldIGBVYrQ+HjEaUWMv5PXsVamw3D1nrzvbfmGJkdNiTsJlu9B7MzJGLYhYg5620eoFWoW0k9bnhPW5W5HHqXUXQzAh3XHo+r209WCDtCEdh7vEww26e9Z/JEejR/xk5fj6EQ9SoK3R6L7Gw7Fn5+8yk6/JmPPks7v4APJ0zpN+JHQzwcS73ZMG6/tQRko2EKQjQN0uqMkZnD/b+0r+zgfcmICOdB+/CWP0w2cob5Cdp2HU99UInq5YP/rm6gmkfsN3ffPXxzpIWdl5ZN2jHwRYBjYqJ7ENMwb6RoVkmY/R0SXGOWMhMIc4xwF0a5QXQrDeVbDMNqbEjGr+klsoqHsAwxq26ofNlLu/hXxiME6MQwJZxlzCgVWPuR/7DCM4+eRHCIQxowHeJxoK0n0u3GnOJuZyOzbLoePnDeYBz4144tmmJVRW6fWxbeSoRIZR4Sz18QpqOBCHkX3irCbH08fCd/bhrCAZcKUP7nrma+aHJv5r1F1/Vq04/OXPLB79whkU9BDtVYlPIa5ZuYbZd7IN7evCPcQ+6onVnZE5FEJqV0rrE34cScK4twNGbq1RayEGo6kPODS48oYxz6r1VssVLVfY3tCuH1Kur9H9gXJ9i5YdELSZrUrXRkr2/uu6znUDkVYhZFMhabrPfmEKmVIrOS+oug5ltF6d9k7ZrtDcysKKIPsZvSPEqXw/SMoiYk6wKZGW7MmLPY/elXC+R2qhX66wmMVHCLBfXxNK9OtS4mn1Qy6gdSdiKi+xXsjn57DcE0MGDR6cZCZ8eICw4OBqMva0Z7Ul0Zitozc0qgsTdxhfP3C0Q83C38UPe5ttsoH6bnsxeOJxsz+YSIHv9faNe3t60X0D+mG9bNCbf8MSuaGA8U0Qk7G2R39MvfdtM2V2UfblI5mzz9pGQtmP6h85Au4kh8z3H0iE9XptD419ciTsT+uv//rrYx2kbiX8YWTMfrhomBWFp/MM1Qgrd8cQKaRxKGG3LQYBP7wEbuAsJtuv+gEYvW8RQh5J/LwWsNJ/kDTAstDJtwvhaUbCcUC2cW0pPQlSh9SLK277z5TbIOnB5DbI2p+PoD0CQHDSwPCDOXTOgs8IARqm1tc4XETsIDSCwzBcPCbRjx5gsKA4IFFn9gQPPBYgbHBzZIbC8bP
mvVF1pQsm5bX5xLxl+K5yH83ZtgfoMXzDPUAiQ/lZfIaI3iCvlqiMAwtXAg83a0htpgRVxH15Bhtx2ouHQ5vQJ2t9vbX53FBnKapVha0V38BGbOmtISHQtDnV3yDiXgu0HfYNtgvUC9I2E3dtFQlCWhcE6+/VuhPFDr7aO1upxJSJCDFlmkIOiZgV7ZVGI8RErTYEbX3ARHRT0a6KNIOp63WDUgy2XLIN3+aMLAtaGzHn2UsbTC5bK9H7iWYgWr3PdV7PPL5+zfb6Nen+jvOzO2IvXC+7mUBuG1GeU0jk03Oj2Ner3ZutcK07ctdJp+doyFbFyk0i5CtrVIMDJu+eJI393wc8d7MfkeB6l16VjITVpa2HAvtQnRmsYcDV0G2TGvlyVDi+RhjLW2961LdSXsf1yFCPkZv3uXlZhdJQhHAD7R/nkVUvo6XR1Ktom4i3M7EZVG2HyLg+/94bQofEgPW2byombtoOQEjZEke4Gfyt8+yajhW/xetb61z56+/+3b/LD/7gD/LixQtevHjBF7/4RX7+539+/vv1euWnfuqn+OQnP8mzZ8/40R/9Ub7yla88eY9f/dVf5Ud+5Ee4u7vjM5/5DH/+z/95V2r47b9EG6KVgMmcpJgJHYaZu2p3KqVDgzESlwXJ2RriMRlLR8Sz+dFMNzfTlI0mLikTTnek0x0pLd4/SMZwk2AuqHllySs5JXIKxGAVyBjCNUkis2Iec1JjQQaHlWI8AkBOiTVnkjf6RdRMF6Mcf/bB1yjYpHyttH2n1WKsPodbDC5rqPfprLfk1YgelZVqAzlkl7Q1V86YyxSD0gK9qWVYWPUX0sD1fZ4o2SGlIvQoJlQ66MshMT1t8kpcz8iSCcls7yW5hXfy903RehlxOCDbZs4xk9ZkcG+ILm0kEMxpNuaVmFbTXfPfiT7cHAKS7Lmt6z1rzizLiSXfkdKJmE7kfDILluVMXu7t13rPst6RlhMhZVJebE1EU5ggn0gpIHlhCytNkg1MjiDJkUCUulPKjkggL2fScjbWZhQQk8qidWiVXi/s+5V9v9L2N+yX99mvD3YYp4ycnrO8/E6WT/4ueP559PQu4e5dlvuXxGWxzyVn1nQmnZ5BXljOJ7oIcV2J53vys0/Ql5UWobSd2gqtl0nnN/q0uPt1ReNiNvS90bTRY2BDCOczPWXS+QWkM5WI5OyVd6OUQqnm3BvocL0QUmB58QxdIvvjGx4/+Br75WoVVdvQvrM/vCFsj5Q3H9DLBbRSrhf26wPy+D76/q9RX72HlAeQnaJCl4QEn0UU8fm8QEXoIUKMT3oqM4nsLsTcO6U0evO+35hJ83nJKJ1edvpU9rafISGZukiIDnvaYLDEFcICwWbtxq/gIxUwKhebjfRhGCwJjI5WQK+NupcpKD3GjqSLBUMwcWHvizeM2zdaEXUO+4pfYzRoeslmJpsywWcvR8U0EvvWKr02elW7Ri/HTPLJxJ1VTaw7qPW+aNV0SltDyk4oBWn/DSjo3/Vd38Vf/+t/ne///u9HVfn7f//v88f/+B/n3/7bf8vv+32/jz/35/4c//gf/2N+7ud+jpcvX/Jn/syf4U/8iT/Bv/gX/8Jufmv8yI/8CJ/73Of4l//yX/KlL32JH//xHyfnzF/9q3/1t3MpdvE5e9BwzNMnmUUsS+4MYoLRaHvvqLOFYkpo6zPzv5VWmtmXuDeSeD/EK5aUA6EJtVZTNEiJlCOtVao2mIwgY/JEfFbKsdnBLBLxJRjFmWrxyNy5gSQ5SvRhB/EkI5vzWgM6qONPjCalraTmGDqWzdfuUEcyOEOH1fMBtX00W8M9f4ZJIw4f4rDj7df3roRkWZdNFVtF2V2Q1+zaExKzB9LBoAJ/aBPqMqjspkLzhqzMDcxAbry/BMMQcdxLm1HigOcQNNjmyy48OidLRP1eWZUUnNkZMHYkEw62zHww90IMJLE5MAlH1Twg5uCZa4wRwQ548yxbbG3csD5RKHWj1UJvO61
V6KbMENZ7ejXhXz29a4dvyIS0kHtF942+XdD2iLRHctrptSBBLbimDFjFPgbBe692f2MyiaumlGLBJA6Yzde2WcIkX+dQSmE9Z9aT0dZjXlDcdHTJNFWg0/Zq1vRuGpqDKVbW2oh55dmLT9CbXWtvFVOrUmjKdXtDzDtxOXOpG2k9saxnale2ekXYSCEgFMjPWE+fpEer5mxthLlHDtjO91kfpIYx3+d9JlWr+LR+0w5Kx6DDqUcZPSGJyXo6vkeCq+b3KYIrAzLgIFjc7Hlh7nVu9+HoUY0qR9Xp5epfd1N3jD7ovHBjaCI2zIvDtxasxszW2G+3M1XWK04xuj+Un5ui9twcETfX4cE81jlG8g3nme+Vfjt79pu8RL+RUvLber377rv8zb/5N/mxH/sxPv3pT/OzP/uz/NiP/RgA/+E//Ad+7+/9vfziL/4iP/zDP8zP//zP88f+2B/jN37jN/jsZz8LwN/7e3+Pv/AX/gJf/epXWZblW/qZr1+/5uXLl/zzv///4PndyXsSXvUEa/+1ZuKrY9FlF90MKc0eDF1NnmYerjd43Q1eagcH4MPCKQlKn4OZIdp8ibZKqaaJhXYzNKSjLq9CXMCdL7UNiAxXc4huJe/wlUNkcwBu9GdgLiD1zzeUMQZG3bpCGKK4A946+jqTXuu9ASUSgjokU4BAitmcOwd2PYkMOitA8XmHwahKLk/UtZvbcW/E7FXG0Ffr6nNXzRv4K3GIX+KsoduEgTE43SYUETxbNd1G/P4YJX5UKWNy/incaX0ina3kw+F4SBypm0Eao08tI55yUXaw2DM4IOZJrvGAlqRTCHzt0nizW79zwDkxBozQYDN8BmEKOZuWna1Hh2NK8Qy+UsvVXWA7OZobaqvFnn0+IWExFZVWCXWnXx9p1wekviVwIWqFXoBCDAun09l12AKlVa+uxwHdJlmnloIIRIdu6J1WNnotnJYFMKFSQiDGTMgLp7tnVg3EhXB6huQzcTm51mRnPd0Rg7BfHshBptK3qZl7b0YU6m5JZ+/UbaPtV3pv5HUhr2dqh/OzF5zunpleZavW4BchLM+Ip08hd88hL76+buBkmGreI2ipuhrKmPHxQKbOhGMEgpuzaBCtU7Sh6yEyrerJ2QgccsxVyUzAxlyf3rD19ImCxm0ffKyh3oeh5lGZz/NrEKDmGeakqWCJ7hhDCb7fxixp8LNjrMdbksZIRC351xn8VMR6uQzdPyv8wQLsIKGNkQwcLhWHFt+8fcP3/V++yKtXr3jx4sU3O+rt3v5X/+W3eLXW+Lmf+zkeHh744he/yL/5N/+GUgp/8A/+wfk1v+f3/B6+8IUvzCD1i7/4i/z+3//7Z4AC+CN/5I/wkz/5k/z7f//v+QN/4A9805+1bRvbts0/v379GoC4rHbwx4TkBQlpHsBIQ3qdB8iUBZFgwcVvenC/KDCW3m0vb7J9eqe3AhoYhODoQ6C9+SJCvDPjttjOqgoIIQuSMoRsVNqYCflYXFMg1Wc1gJlxhKBP/u62sBkqyiIu2+KQYiTa4nH1Y2H0BbqrrIwGqfWCWuu4GY7d13FP4kIQU4Xo0iZjDletxntto9d1wInNDszewXX6zGbaoEcTtBwZnYcLpybfZo63+PboafXm1O/oA7UhzKSgazPzQmVaJcjA3D1A2zO2boIEr5aCNfbtWVvGHJyJJNJo7A4Hp2OmhiMzfBqoQOlceqA2YWSmt8nG7JegSHLaPXLMaSFItz5RbwXV4fcltC70sJiTbrLKTcYgJQq1mSEiFQkbxJ1IIxMJZEI0QVHEpaFyJMdgSumuFj7XrwefVssc8BwHfJRgOoExUnpnXRabszqd0LTQazNJJ/9sXV2DMkQuD284r4upQbSO0GilPjmMa1eQbKaDIRJzJfZGplHLlRAid+cT122jq7DcvYCQiLLR2iP0B6RF+nVH6mlKSnXw8ZCxXz1IjEN5BIbxb+pEiqn5x3z2+L0YldF8/mqwpjgi0WU
kg7Y2ICA9uMeU9/r8HIgMp4FGrc3PtKP6a61NQdfbRG5U9TYgbn1RU5hwaTANIC482zFZJW8TGPmJOQR8uA0MdMRJMzcVqI1gAKjLPh1EoDHqMvdFEKIOlflxrrZ53vxWr992kPp3/+7f8cUvfpHr9cqzZ8/4h//wH/IDP/AD/NIv/RLLsvDOO+88+frPfvazfPnLXwbgy1/+8pMANf59/Nt/7fXX/tpf4y/9pb/0DX+flzMxZ4f6TF06jGxEbAYqhLEQhjqyHxjNmvAp5wO6ajelfx/QkmVcHRMtRV0MdqQMOhZ5n6CaiTm6zE8wa+6QMhIN0sHL6xhH5aFPDuXbJqrCk+B1+zUSfNDUQGqrXCRgwpgBYiKmxTdbo5X9UK9gBEkIDT+wG9q9AssZ0mLq4wVTR+8jcxuVp3rn2bKzVo/sL4xtrh3RgFkjWiN4SAkNtJ3ZhNZZsdmmPfDwibljmyZ6pde8l6Y+PDh+butl3tPbe2awoxIH7V38c3uGaMmIU2x9wNiYdl7yjt5AfAphjCxUupEBdo20nmw2BOazFtzjJwa7dl/LB7TryVJvdG2UuqOtepVokFJMmZDyZL3Z9VtQCcF1GbvaBuiNiEFzxjQLRIIFKj/0UkoYcmfwTEoLGhO1FcuwbywbutvLBP8R53Wl9hVSdPAmUlsnLyeaVy8xJZrbd6S8kAOU64P9vl1J4UCkWh1QsDH0CCsxrxAz2ishCjmvlFoJEllPkbJtXN6+Zjk9I8WIEmjlasOo+wVZ7yA/g3giSPYeq8HO8/67SoQG258Dppv7iyf5600ljaMmHTNJ6N57q9C9ehqV+oDYCC767MoOiK2xZj5ZdIPse2tIeurRJCh2iU+RBvtHCzKmTHH01MDYgEHmjjPVEg88zb9uzogOKSknNFmXQkkx0uvhSh4UT0jxZMRuSIwfZVA7geQGIbCE+b+Rffzv/t2/m1/6pV/i1atX/IN/8A/4iZ/4Cf75P//nv923+W29fuZnfoaf/umfnn9+/fo13/3d321WBCl5NmwGiD6CaSvKF+OtBtWk0e7FCBue/ZtgqcGEirq3kh2A5oogdjCpAmEy60LwvorYEGQEovp8gR9o1jDNPpwZPXu8cY71i+1jQ4z3HXNUYNPrrrIweiAhJM98Gr06SSMOk8RIyIsFKe830HVamQwV7947PTnNfDCw1DZBTMmIGRhZgnZDlcfhRqoxIRFv+JvmHyIYS/8YvB3uqSlaNYHDB9orveqsDOeJNQ8KL5b8MBDC7Bcclcno8dlGTCFbReAV3VSxH5CpJx0DWjT/HLXN6xnlOMyMLyBz5OBgCt7k1jfByrygRoVpd2vMnMyexGBO3iLMyIRyeivUfQNtpJSNWu2wSsxW4dZqX6Mw7SVSiEBGeyFJRjXSaKa0ESOtN4KL0cZg66I7IcKo8KMXEQEL9CkmLKAZlB2xxKmh1G5rLmaT4uoOH6sEswepjV4rp/M91+sjr9++5Z2XL7lsFwtEZaOIHcp5SFSBr1HPzlM2iDdAD5GwLsRYLOiJkqOwl9cUNuT0HFUICu3xQkibuTEvDVk6IZ9pRETDoeg9np9bwY8AdDBeveq9YZuOZAy3C2kjGPi193a1toNXJENuS0Ji1GkqY4B9yANZAOiq3g7oaGVWQmPIXsfpcou6+Gfo41oHC9WwC3ozH68gdn7cuv3KzeeptRLFrYm0WQbrLYLexIW6jQCmvq67vaEl1GFonbbjujigTlWlXi4AlK38Jqf+8fptB6llWfi+7/s+AH7oh36If/2v/zV/+2//bf7kn/yT7PvOhx9++KSa+spXvsLnPvc5AD73uc/xr/7Vv3ryfoP9N77mm73WdZ0zHrevGCPpdG+LUsSyT/GDUZIrkmOzR+HkcEmj6G4dGmdRKRUk2YMEl4bBZGd6s4fqvQmJAennI7scUJKIiY9i2UUtVzsICUgypplg1yjRspXaKwljAU3NijgWotqEfrV5FBEhYgZz3SsZbW1
mLsRAHb2UlBjZmev3ENM6+0CAz6h0YhZ63e0ehmDXGRMi7kjcmx1CEulERNLh9unDor3sSGs2+S7Bhoxz9oM+enXj2V2vGMxl1eqcsdIhOllvZGOEkIeRY2NwT6Cbd6tT0k0EdrWKelQsbSP0aNmourWHBOoIjMAwhosh2cEa7NpHhmNVsVHUzcdIUZrZZvi/t+AV0bDPjiYNRDVosan3erQbwcdXhZ00xkxVMT28MXFiZAuzg0/pdPQjsCqX3qllsyCGHUxdIS0nUMhu+d73NxAj5VK57hsxBNbznWfXhg7EFNFu9zBGU3Sw6rUSgjnlSoB924xF6xBr78ppXZAYyCezIdHazAIiZINGPenq1wdCMmmt0K7UB+WcMo8Pbzkl2wsDouq12uftAIGQodWNvKzWL2yNHoWYFjRmU4ORToqJy+MDnSvL6URc7umy2/3Zd2jvQ32k5ufE9SWks/Wv3fhzKM6oa3e23ql6VFFBnG6eMkGSVa0CyNC21AlnqSpNE9oKEXvuXsAbU9KHp5GASLaeZd+nNt4IGkb5ryimet9v+mTDNmegGkI0VZOoM4HreG9cPCls1gYZyFEfqNGN4kRw2L4DXVzTcoxtiGksUg8vqnH+TQPQWR1GwGB/VAmhG9qzb9NhfL8+/FfP/NvX/+o5qd4727bxQz/0Q+Sc+af/9J/yoz/6owD8x//4H/nVX/1VvvjFLwLwxS9+kb/yV/4K7733Hp/5zGcA+Cf/5J/w4sULfuAHfuC3/bPVsz9bS91LzgNP7YhnDI2UKs018WIIpNNC1+TYqR0S42CaE9cS51Alkol5MU+VDqHaMG+MYF9kYqWi0Ho1unOvtmCi+eQArgBQ0GbwWe124AZXAVcfALYVqEiA4OZjJuLYD4UNb9kYEcmULFTsfeON1MuAIEzj0BaI2FFq4rwDutOjYgsxk0SMadVsaDoEmVTVYeUmvdFCgFqMSh8GjGZzPylmy0DnfFi2PoszK6MYNt+dgiXBn5mI0buXTG0dJw3bNYv1o3BcHa9YxOEZEUXT4nNKNq80qqGkngAAXVzI1w9iUy2wjWvVbHA/LZkUf5VKEyXENMIrqF1TEIPRTK9PpiHjUNrvagKfk0kaTNKrVBM/TkGQYF4/Id2Rkx94ze5dip3ed/bdnHTLXohJ3FTR4FTtjcfHN8S+mSZeLQSFUis7UDrcnU/ea7SuzBAKHqSNMRidglu+qFuY1+K0focekyUjKonhBmxq7abYH4IRZnptPD5eOK8LD4/K9fGBvJ5Yl0zdH73vcQyrm+6bfx42EGFvzZJHF1QeWPEY5g/LmbuQKWWjl81mdCRaEhIjSkPrFZrSS0HW58T13vrEXWgKwSuR3p+y26z6NkULVavooiMyvVrAH22BAdUH6x6YVxzNK3Lra1qs8cSsFeul+nD7rMi7EY9wNwBLcL0v66an4IllSMRo81itGdQexSDh3rpBbaj1Of08s3mvNkkao69m54X7TgUjfA2NShwCDEH8ui3wqAvRGkTfHNkIc18gTEjenu8wZf1vQEH/mZ/5Gf7oH/2jfOELX+DNmzf87M/+LP/sn/0zfuEXfoGXL1/yp//0n+anf/qneffdd3nx4gV/9s/+Wb74xS/ywz/8wwD84T/8h/mBH/gB/tSf+lP8jb/xN/jyl7/MX/yLf5Gf+qmf+qaV0m/1GhWUBhOYJAgD6eO2F8GAYowgENQmn2NMpil2c0Crwx3GVIuubYX1aKKVtiEncjLIRn2mwn6uQ1jRN3oRo3mrBczozEJaQIPPGjRbKKLdFS8Olt4olof47cDBnvSsRoMTIYjRqNsUcQy+Ycbhj6sQYFVMa+56elCzGfh5DBbAmgUNM83D3lOiz4AEQm8eAIL1ABgJ26HwDKaCYQMcbYQIDtrEEPYVG2wNbgmSbM4kRdDU7JAUOyDq6MkN4gRMHTKtneb4e5CIJJno2rTtBq/sjHChMRJddNgvyXtMzuLUjvr82XW/QMrk5UR
Iq91/r5q0VYQ2LRkMkDSyTky34wV2T1u38zZKJ/RK3a9G0Q+mW2fqGbaZeyns26P5gYnQy0arBrXEvNBbo2wXUr1Q6yNBy/TUUh/Cbl3Zt511XewA6xDXNA+ukCwJCRpMzqkLgUaN5rZb2JFkJKXuFRPBmaCa0GAVd62Qop3yIUYjpnRTTK+tUMtGQ1lT4HK5TIbqrZ9ZTpFebFg85ROlNPKyYqDJ5smjUaN7CMTlDGpkk22rrHcvrSqqu4HCAuiOdKFeq1W46wuQ7FWFDeyiOFTPcfgOkpEHkBCwZGuGkAMmFF9bgto8ngwBgW4VpyqtVoMh1ZiBwWFOAWNQtgG9m0XO8MPr3faB9VyNWdyC2coENzVsKFW9zSFpEio6FXofwmRo2+06R0vC5ww1JJ8vcyKaONRO8u+0UZbWKqFhBKuBAehRbTN394DFrSqT4ChL/NbO/N9WkHrvvff48R//cb70pS/x8uVLfvAHf5Bf+IVf4A/9oT8EwN/6W3+LEAI/+qM/yrZt/JE/8kf4O3/n78zvjzHyj/7RP+Inf/In+eIXv8j9/T0/8RM/wV/+y3/5t3MZxyvMHIOhJGEOrAPvl5vGYPOMsVFaI3Yh5Uhc16MFgh1MtRZKdVl7sZkqw1Utgx7zDyMbUDpdK6W7F00cWcSCRHW32IaoCyzi0vexU+thH8/Q0HLYwHgaXsGIMO2xOVhlllmN2+FTQ54VimPm2hVtVsWNhnv3eSjxoHRYC5jsfqdYlVMtmCrWfNVgDLskpt/m0c1/HcroHWtM1yk4632kjsEHKc6+Q/QZkkGCGBp8MSajPnsQsH3aLGOfWa7OSnpQ+9VZh+JQpyWNHpDAoAo7SSyDpYH7IU3GJYM8g8EVMZKWlVp3yr4Zi1KG0oR9Fuk2x0RrRI2EHkz3zwfHR5bcXUlA3A5deqVeH9kf33K9PNDrdiOEK/5M7NAqpZIkkvJCdRX1vCwWCJ3injGN/16g+oxVoJFjpNTO1iqlFPKysuTEvhfu7u6dgWfr26TCMiEJZX+YzK3WO10hpwVxtqpKpBHRYJVDyoleK6U24++IVaLleiX1HXqxtRigNCzI9MMZQFWpxQ7UvGBwdzKy0r6pJZfdmWsENEbER5JmFeJD0CxnCIm6PyK9klOidVOKr1fLmyTfI660EcY+Q0a0mfCoBttL0+G57IhkM8ZUnYxISz53S9YkkJaVFF3n08+jsu9cLxfr0aY8vdFSHPNUTlLRCk38bFAbim3FNn1K9llptKKE4E4NArUp2pSYcfHdaJqTqiYs2w0ZwKs4iRgzVtQHnjui1Q/FMeqSPRE0sk4f0nBwrO1WDVIc/X8wdqbT8Q3eN0gwhm+tJ/W/ek7qf4vXmJP6Nz//D3h+d7K1FLxRrcNx17OT1ml+yCZ3xdRmytwxZ+J6xxiKHZG/1sK2m7WxldLBFraq9QhSMrjLrdhbM9O5rsa8smzQ5XF6pxXTRUvZIENuDiD1xSLd5EwcNwCXVGLYdgzZf+SwcfCqsQ+yx6C5ukRQCma8ODZ/DJat7qXY7G04sOnoDXIJATwj61rQWulmd4ykBDHa5YRhRWLyJibLE8nLAhJpigUocSgkDDjKP8+sFIUxRV9qRWn+O5zvnrOe7+xr5n2s9ox1DEV6BjsqJf9/cXNKw6s6rdpIQBervCzAwCBIDAuG0T47FOjtsE55MWJN9WpHGxIXYj4T04AJC1ptkn7XhVc98aZ2NCZCXCcRxxIPG3RFG/X6luvDKy4Pryl1mzDMrUTOvtvMkNHUEymfCGkhnU7k5USMGdXOEhSuHxDLA9QrbXtge3xNqzvr6UxMC61UlmXx8YvFiBk+9xZiJOVM60pebfbq8uF7PL7+kLZfbfRivSee7lnXe0sYklXWvZuD8no+I6pmx4ERmaIIvVy5vPo6z+/vqHiC1Cs2JHwwXM1qBZv58uovLiun++eUUj0vGqLMBsMhiZASKQCunde2Qnr2kryeaWWj1X0mTDORZSE
s94T8jF0t+7fYdCAZ9kre7/SWQivYKIa4NmKllc37tI2o0B0SXdczKVkg3ov1sGq1RMGCVGJZVlIcsm1jNrOi7g02BApo1SDkYEo4EqKRVYaqRrTzasCWIZqbtEkY2R7Q7n2lWua81fC+Agsq4sl+dxKPtkJazla9aaXV3fykYiTmk93TPjz8vHHivICRuBv83eZs6pu3D/wf/29/+L/dnNT/Hl7ibDxrfht+LnHAVdGHKI3+GGLwasQrFbUufHcV41qr903wqiF6yR/MkO36QN0LKV5Zzmeyy/40rTRnCcaYiZIZnlWqheGQ2VqZuKx5BLkobRCCWIOxtTazbLPB8PknsWsZMJ+6tUj3A3/MLmno3kOAVpQWxNlePjskcUKMku3wa707i8+hPq9MXUDCCyQxbDovVsn1erMg7RAKFokcsrDrHs69vRVqcy0HZSYEJoozoNiReXllo5jKtkDOy+z56Ag8PjA8o5Mzp5JvXLxy6a3Z85EI3annfUzYO1Tr/z1gkWMI3JKXwU5s4+AKEalmU0+wbBKtBp9UV1KImQWTrtpbtV4zEOKgXFhyU6+PlMsb9stby/QFyCcY1Zk/k6gZbZtBNhLNRmXNpLszgcUo32Wn10dCV/brRtTNhmK1U7aNVht3z15YgBqUZu8FtWZQGiITLkcCeck8qM2RRVVDFYLMtTcSDOe1IVhAXXOyfpF2et2J2VyBUwxcLg/ExeHGuhHjYvtXoJZKyhZwEB9uDrb+r5fNDvu6W2YfhBiy91rt71pO/vMjQQvbmw/odee0LoTo7F23k6crvTzYELAKku58P7VZ0U0mZ2+EnqwHh3rvUcg5zv09LFdiCvS9mbNzXrzya2zble7D4jknTqeT7zN7f5tOuXESF0jeH1SfXZNJlsrel3VILmZiMhkwwOTM6A42+SgFDs83gaHtJ9+EodqbD/9iavE0JHRaudDo7jPVnZjm+0ybK924o7ZzBKzP6707NWSi1kcA9vL/J+LE/5YvbY2QF3qv7PuV3gopLqScIVaahAnvtF0gipXm4vbRasy2ZACzMcjAEwrr7XTEqqR9g7Kx74/QHtHlZKrJ/u8CyNKRADGtli1bi8QOz2xf11o189LaXX7GILwQTHKk1d0OThF6TJbpxoSJPw67B5/xEAtU3Re6cw+sAdwrvUOhTAPETrIsShXdq+PDEaHSFHpwRXS/ZxRToCjlSkyZNRkUZNVIM3aaw3waoPkBLVJcHcGuTbGsrg05HRucAIJLQUWMnW+ilSlljz2dUK6IVnpIxg7T5hVG9h6ZExNiJuSVkFYbmGaIhlrlM1hbFnW9Ad42KBe07rMy7r0SglKbbdSgJv0jzYgsptvWSNnUvOlXtAzzOqt+Gx30So4raw1Gw+6FllZCyIgWWtkp24V6fc3l8QFtBvsQMopVNKPPYVWxoDtIC3a4t0LVR2Q9E08gVPr2lv74HvL2PdrDK2NToaRlIa0nrteNeL0YazQE8rqaSr5YdVH2B053z4yplhK9dpJE1vWeN920I5tCCpY0VAaJxchEoj4jA0aNd4Zra4Ug1vtrIaCloJvRkEMI7Ne3nE73qH9/rxvX/ZF1WYjLyt4Mhk+9jParaxoqNSom7iyoVmiKaGPbLGCEVqmPH/JYFtKyAgFpBlnHKATpBNlp24dEbfR0noephEDA2I9kIzUVrd5DrkhVejAJMkFNHzJYXzh58EpR0LbbPaDRajFCUM6umuIi2XUk0J5ECzbjptGYrMEDQsogkTLYeg5/q4Qpqj1n4PVwnB5oQ62Ndt2p28WqQXcsnw4AuLODiI+7BFSjVcTBBri1FoxKEpCmaNgsqHclMHQSo7cZNujFexB2D7NrAIb6bWAfT2/EoEQ1GmspmzWGfWhwXU/EECkYI0ZFTHEhGGUTgLJbV8kVImyA0CiiaVRi1foM0htRO21Xo13HZF5Gweef6PRecK1Jhh1GtE7sVOVGLdPprc65nKE2sOaVIURLd0Zf6FN
lu7lzb0yJeMqe1RjsOHiBrRbWfGdVRGtoCFaae+Oyt2LKCqERY6Jim1y6WkPWD3R1iKTVDekLPQiyNCdMFKeAh0kUsZK/U6vBGXM+aVpa3PSwus2BhOGULEwhVkWmPcA4/FV3g257sxGDvLBk674oxjpEk+Hqgw7r0F0IAZxx1UefTnxAW5vd79qMlSlWWabgh5MB7owqQcSq2NE7HPDUIB5oN8txpRO7smrAw5r3wUxWqO472/WRft0GPRMwONnMFy1oW1tNvKcnqCuIiEDtnceH16TtyhIX+uU1/eFD2puvIdcHexYKqS4s65mcnKkV7Z7Vaj0uWYZLNQ7bNqQHSjNIbTmdCHnhur+1Hozf/yczP936IbWa42/3nk0OirbCXnfrn1Yf5MRMOw11Cw4BOXysjVpcKknhfP+CvVT2sts6HaTOgSL0aoQkZPZPWmsGy4foB2tDQnHkPNggeB3fG6hls3N0NPYl0YNYhS+4nYX4KIKNHNjFO0En2jqfen1T7ZtZZQiw5uGW4PSFrtTqTOKu5jYezYGgoRYkCJjups6uRFCTYKvduHvkK2HPhG3hdH5moyQiqLR53mirtO3C9vCG/fpoYq8iNsztoxs5mWCzCdc6XD9QHA9qdFM1aeOsHOeHAlrp6r0mDYzB4NGyQAY5zXti38LrYx2k2n6lbkKtG223yfzL5ZEUIyyLZeHR1RdapRfsMBv4LdD2fWLc4wCarLNuGHXsu2WE6lROmmO1O61lSImwnIgket25lg3tFW16HMyoDbliU/XmMTWo2va7GRlYK6p3m3mqTmIAnJHnC0A6SU+EkNEYfMP6/EJQhkVEcBg0RJ/HaJW675RtM3+fmJwJyIROB8zYyoVaN6+KrHJK2wWNwdeb2IbW5IQOmRIy6oOZEpPh50PepTd63UF3q4pQYgDBxEcHocFo5dY/6tUVKVq1ahRhmAjG4BBkK7S+Qwv2mUI0hXtVm3CvGzSrmKY2YLdB01YrZbtQyk5Kgdwyp/WMwaQGoTaGAshHemAjSxWj30ta6c2gJ+mVFUElsnvTujer7MrjW8q2oXU3+FEw+nDEZuXccqZ2NWstAiGZn1KrO9ftkb34wb8IMQvZIV+kWyO8mZJ6LxuyrJyWEeAcxnb5qm27klNG60ZX5X5ZTD0kmHoEIpyfP+fhTUPSiaLKmoaIrrHKWq+eaRd7hq1QtosJrnKQl+wgTq4Y4eMHM0Y3SvGmfG/0a+ecTYw2hEg6Z/a92NKe4qWK1N2SQBRtYQavUis5JhtQrq5c7s8vL2dyXii1UctGConWL/S6EpeAxuyJrD3b0KvJHKnSS7WeWEw0bOTD+p1DrcFGJlBo3tdk3nfcpqLTulKdEJTTiZwSotZ7tScuprCulTFao8q0/gjeo7peLrx5fKBL4P7ZO7x4+WnW8wt3DvDjpzZ62ejbI317oF3e+hmTCJoJeuh0pnWoclhAH3T80TvHlfGbw4KB8xz4LpsNWdvnjZ68eTsALHlzF2X6t0El9fD2NVF3O2R2z7xSpGxXdL/Qs1kxhJQptTja4x43g0W020ZOOZmHjmf8TYeArJCieG+jTJwaHD9uDVoi4krWuNZdq/PrLdOKiKw+1+WlsVO6RaEHJy543yY49KYIey3UUiYc6e0fei2kJfrXARgkWGtl33ckBE7nO1RtCLv6vFCKkX6juB6dUNIV9zlqNEwJQsChJ8scW6+0eqOYEe3zOxLtzfdEDnY4S8q0aEr1IUS0N4qxPYwNRydI995I8UpJXLTXqzNrq/hPsIN+vzT27UrOeaoUNO8j9RgdGnFbhO7kkt2GO0WGGK2tIxETHk6Yft14mfmgHli99yjxw9W/2wY1Bft7scFq0UYSzHq9C9u1U/pOa0rdr+zbhbqbxYM4w8tILoXQdUrYDMaa9SvtZ6S8cr+snIyySV5OJFXC5cLWjgoXEc7nk808RSGmyLaZq292Vmst+8yWI0Itlce3b1jPz4wA4OaFebGB3fO
zF5RuDFCimjdVEJoYcUS00radKJ1WrtTmBpqqXLcr2X27cA26MCoFgbIP1l8gp8h127gXgWYHuuSASqc3I8MkV65vpXij3uj+IcVpJNlD96Fkg2LFB/of377m7u7e177R+9OSafXRXbVdrLi7gk03RGTq8o15IRGqmnrN6GEbscEJGGNYViwQ1zFX5dW3+JB3oNvhPhEWS4iphVoPKxBj+CbrMwWrHKVXYjeTz/JWuQpQroRlNasbCdAarVyh7iQqPTgVXQt9r8ZWzckG1ZvaPcZQgFqNBOY6+PRWXHHC0tJeqlWaHuRsdsuuc/hzhTCEp6ufMUOl47d+fayDlLbCtsHbh0dEEi9efoJ8/5zr21fU61u0NKruhG5T/WAjrKqN4lYAqlbLO0OaAt4LMJqo9maVGWLUbGe2pRiNOroYgUKCmBOsRIOg1CnTvTsF88BqJ+lgYNBO+0W9UuhWaUmvU9BUevMZIazvgzoV9VBIELysx/B/XFk5iKmad93pEoyI4Krf4lCY0bcL0osN6CnEsBCTKVd0px4bQaXSfJgvqYITNsQDwMDNVYJlXqmbb49YxhmbefL01lExQcxgukxWuQz4UgKSjKat8fC26aLsTgApVFNF8EMdhdYi2q+EpRPX6H22bKSDOsgePs8hRgSIKRubqhcEmbp6NhMCI9EYEOJoPOuAAt1ufcrOjB6VdhKBlcqbUqz3V73BnALSI1p9vqvvRvfviQGEDManORAnuy9RWNazrxElJojlSoqm25dzpGqnBFujIS/WD+iuA1B2WhQgUHshp0zdlXBa6a2zO8kiLyf2rsh68jGHSIgLSSzox4gxO0UI6iKxvbI/vuXutCAuudRqNQakq2oMM0ljh+EjHoppGflcoZM5Lo8X7p4vUw1dvB86AjdALRuqbQ6v52CsXA2BUgspWCDV1giYB5xKY7teTP9vPdtaKg+sayPmhOjZdE16N9ZmxzyYutHMJWSaYPe17tanEa/Su++HoRTiFjFV1fpm2HhFcJkig/q6K6i7CKyzS42EYLT3MRiLSxCpCxiveWF5/oKhOUp9pD5ekX0lxLPvreZoUjES1xgUdthaOmgxXzBVn20cZ5azYXtv7vLQ/PNFhGQIT7V2iilvWHIa556NrnDT3JvMEm4djNrf4vWxDlKoiWGezgY75dM9ujxneZ5gPdPbRoqJ5XSi+uJMMaG9EK+PnrkYj3vcUO22EFt1WRirr+1QTxkJC+L04RgsO40ydMoKksZhki0AYT2OdTn5sJxVYtXf14JDnT0OU5gAL9J9LsvhXHXYq/ugYDdiQ6zRJEx86PS0LHRV9r1YSR2Usl8NHq1Gl7d+isGK2kxiR9tuc/VOd+8E0/6LaR4IOSXL7vYLjgkxJJa0Vh/kNTr90BcLIdlcVLdrj978HhT9uvt8jzPyzFStEGsjLs1klmL2ys8qjnQD+dHVE8vDFK62K71Gg8hi8gHDRKe5jFBnNK5EDwiw+T0usyfiqhQhEnojhOTV4jIPj2HtYH1FRUUd7q1Mb67W6HujeRBLKVH2xnW/UvfdhVyt15hStsMrZkIeyY3SHI4jJGqyROa0rCyyo/VC317RL69JWqyftZ4gJGPWdVMeyDmjtdhYgfcIegWVSqkGNwVsmBZV9m0jd/tsZtehc21GwfqPmGRV3S5EVYOUxskiQik7ios5B7OZz8GSPm3qw787ySvn1oxdl1Lienm0sY7T2VuZnSUm9mZ9q+50Zu2d0kxAGVVXBhNq2dEgJAm0UtGciTm5i0Fm3w11WNYFaUK5vmEvO+kMIT130dhiyWW1IWQUYnJTSmy8YYjzqgj0bvY8bmI4gwo6KxAkkINBd711aq9WcWLBKQYhpdXckRUC0XufpodYSzXTTjfeDHS7962y1wtdC6EHYlgIGvwMSXRRahVYFpQ2yVy975bsYcxnVXU5pTE/ePQLDWGxfda80mx1fEYx08+QaGrc5hjN9DK4fVJEkR5Jy/4tHfMf6yC
1np9zf/8up3vHPZNAbJAyuryw6X81OvniLpkx2RyBpIW2X9FeSGmxBmfrBFeQEDlBswrDzPAsEBkDzqw9ajXttBiCs4rUmubVPHZUbH5JW6VcH43ZZUM50wtoDoGqUr3Et1kgkJgprYDYVL1JjijZe1tVd/quFpzcery1alVLN0jk+vY1y7JQekGvj9SuNDksTYJXZm3fjHUXEw1rHENFe3DmlGPMYTHFASK1XJ0UYZWjJEfx1XvKqjMQUfskR5hobzC5nmLQhsZIXJMFrxv8uu4XYtzR6CZ9foDJaNTChKsMuvRBa+7sflzeIPXCoFr3MYQoVkmAJQetGkNp6PqFqfxh0kH4vVesbxmDbXhTc94B691oSA7xBoeMugm4ogQt0GycsetB3Rf/XKPBHAM0MQV0dqvyxnyOQShWAdzd3xNDR6+PSL2wX17Ttrc0LfTlnhQD2edTICApucKFUkthGcmH2CySlAutNjvAJXB9fOR0OlG3R2ycQkjLibKbCnat2+x3CEp5fEtaM7F3tseLE5BsTmdAtSEF6r67XmGnNmU5RS6XRk6rsaqjKY9YhaSU/YGYglUuIhZAHXaqvZPXE1mclKGN3iN7NQr4IrBdLmZrL4Euw7stG4stGfMSyeCO0PX6QNKA5EoLKzVlYh9K85WkAUqh0Ei+inIwNYrarOo4KvBDXCB4xW6Q7YKqmVr2XqldCLt5hYmqVTPRkpPeqrlwg6nJDxp6CK5Lal83iF8h35mslCMPaDfCWM42bpK9quqBNpNVg/T2UkhdLZlQUJdk016JNHDfruHMoOOZlGItDAlUtb6zcX0CYVnRvpqoQStESeQs5mDxLbw+1kHK7JCx5mg35sn17Yec7p6x3j2HrtTrA3XfSCi1KLRMyInImEswmMosLtQbrD7g5o1/s3rOJqfEIFgoiUrdHtnL5odyIKbiWTlErInaaqd2K9dTNhaS+LxUrXVCS6PPMUp+cYoqCjUZnNW6HWwxRswcwO3YY0JGdSFxDhRv2+52DMboKsYVNtxHZbILTTjT+k/EbJVh62a8t/swbzB4KC4nQhSiponDR2d8WHFilFgvcnD8i+GhNHphAeY9AJtdCTFY89vJBKYp19H9kcEyMn8q+9oQnR6r5kuEWjYYfSZm3y5cL916TYaREJyV2XTMmonDfEx5nqF4r8GCoxEm7PPUpqgUC9AyVCmM2IJapqwusR0kmrBrFnLY6aUy7MWVYb8RkB5IizeffQ327gr4tUyoR0ImVCHlleCHQy8VKa5IHuxApVbicjL4xbAZO8zU/Lgiav0JgboZSyzFe8q+mzeb90Nr70jdJsSjXuH00p0AowYPiwXdfbvSutK0cnr2zKuaRC2FGBOtVO+9GjGgavf5qTjh51IKS14QoOw7ZSvEuJEthaSLWiVWd7vPKVJ7sLGB3knOlap19HxsaL72Rl4Wm2dydwCr7Kqx+0KktxOndWW7fEjiiqR3OOdPU/IZVfPoCo5+WO9pKJgbvBfFDAWJru33EXJAijabaTJqns21PlGSXux7xtxYqxXqRhWM6i/R9+qA6i3IypA2Q01zc4x8VOvDxpRc9Ct4IC8EAiEsM5iFriZI4AGwY8lLa44u3fjA2fjM4ufE+BnDJcD2+rbtiNjwf/XBZaSj+c49774N4L7mfjvZRGDorbKo8fjrXpwSKqTQ2R9fU2ujpkTPGZsat75IiI0gkVrNBTVHE2pUxTyVUkbjinal7RutXEmhQdvZHt6gNHB8ORbXt1JjARosp08o2WlkIaPKEKsCus83KMbwMx6AwSUxRh+2jGwuZxQGvVQC8XZAE6g1IDGx5Exe3ZpcQLQQ80JaV5+9aohL1CSxCfUuQtvr7LUFdfZZr1wuV2LZJ2Eh4HTxWolUmzCPaWoexuD+Pi5XJJjWoonpmrCrqV5b37A5qy84U7fVYnBRa5NRiEMi3ixzSNbZelJvgmQjYBty+PMY3BrRtpj6h0S/l8eGufWImvqPXan
d2GRKp+8brqtp1Za6FYgNgFFdAzLEaNR6bawpkHdTuSGuHliVGJKTbFwaqFUfIxs0Z3uowQc3l2x09tevPqRrR7SR44l4967Bftc2bWfsMw6qvPXsQgek07arDcYS2Eohp2AkBTFoTiRRWif3Togmr6PqCYJYWgVKaUZ4wSXH8vnsvZShmGGqb9GDVXAIKUhiXZKRJUaVjCvgx0jomfWU2Pcr2/VKCHmqv1zc66rT2bcLKZ+YvkqYbuBh/KekZOSdWopT8BuaZEp77ZupxKf83J5JysSgbJcPyQjp5fdQuafuiS6bTVAOuGuOIAgSzMOti1CuV9pevGKz5LQUt0b1BEdroe073ahKRuvuVt0WV2eoPtdGFtJ5JZ/uiSE7S9BUTnorFpB8L2gfUlo+YLyuuBaPwZO1eI/QBHMZnnrI1OU06xrrQQWfIUQUkWSOD2lxS5A0lXlApzrPMuxbWnVYc6eUjXxvkl3l8e23dM5/rIOUtkbdL7QdO7S6cfiD0/JjSvS2U7crbd/MBqAng1HEpopyWlE1COF6tQG3moz5VVsnpUbqQo82Md3LRrk+0MRmtBChVkXrbtI0UaakCd3sK0KyYVe6wSytmk+PqrqnlLueipiyQzjcX4NEzuezT7JncLVpRdC6H2rI0VlnXb3n44zClE2UMyV6zSCbHXRLpqmwbcV6Ih2DQrCG9bbthOjYfTKx0eAHd60bpVfEmVr2eW2jWD/OxVbVsitD/gxGE+zwsqink/1IcO+djqlwqH3tgGgVwJlDcy7JD/Cpnu2WCfZ+0fuNboJiWKPPpZjEUFY16Fassf5kbfmvAQ02txdgEmAAhw6NYuyD0Z7B9ubSTXVkl3BOgfsc2TrGhpPRswtIXqhD6y6vs982Gurdoee8JOidy8MbkMB6dyYuZ1dZEaiP0DZqubKkMPun3e9nDIGmNli+Pb6lbhe6QukmLrusZ3o1eJnkUkdlIwnERbx/KlOpA9dya27VIiHSnJyz74W8uK1FxlmQ4jCXseGWJY9C25VB/Gd7JQq4bFM6WLmuAq5YgjOMTMee6mN+aEj+AJVKWlxotfoAc8cIOR6oWi9Q3wAnUnxu8GF74OHV/5clLYT1UzasGgIQoRuzMAbc2qNbAHS/tOEXpdqnwryqkuLxmXs3tYphCuiTcUaecimsAReKmvBwyishngja6WGngyvaWGshiLkUxGCO4dober2aPJzgvcBBAPEE15l5rRZKq4gLJ9vYgPdnUVOwicmgSzE6vc2B2p5VYIibS8QUXtTmVFszR+92eWNJZLl8S+f8xzpIvX31irM3G1tr5JTZS6U9fA3lPZbFpsxFxHBwtZ6TugmhRGOhdVX22qmtUbaNfYNlMaM1CQ0tV4JTJ8v1kVKuVLrJt8RMXk1WKCWTQGnqB1U1SwpJtlF6qFCKC4H6DEGKRjAIQs4mjS9hcRjB+x8xuECjDw4my461Vqv+1LyHRpCqV5MdMaZPJ2pn8colhMC2bXSUuJ4NwtNGUFt0basTxtHBZsMDsQ5N745WZfPqRjGtuahMKK+W6hWOUWWtWkkeecQ/y82MkYhXQAc9tzXrPYQQrAmtB1TYvH/UdRx+FvyG+VwSY5NNuNEZlNobaJkBaIwUTAX7MXCIBT9xSrQplus0kzNVjGPMwNC46MzAMZxtcI5264vmAM+WRCiNSymTnGA9vNF7kgmvzErOh9UlmNrEpTwiIbo3lBLSMhl8i+4EaVzfvG9Dlc30ClWV7pV3H+mCel+nmpnE9vhIiong0LKq9Q2kN1rdiBlX7gDqhvq1SghOocflscwmInlQijlSq51gccmmrpEXY9iKeiVte8KgvuywkhENojMUezCJHXFbndoqaQS9Uiyx60ZkCbMKFbf4aA4rZvZmpCiCWP3icGhrldo37zkHqvvLaavsX/9PxGcbLb2DRjA41hQZRjyNMUJe0JBIaSWK0LRzfXzrfTQTAw7euDXV9ea+b9kYvh4YTHqs02OY4yh
DQb3XgvRgUdZ19XrdLSi724GRtmyQuLVqyhbewkCwSs4ZfCpibFLfTKXtiBopbQAMQieJ9cC6k7hUd7Ru7Nc3mKmi9R9NQzCiLXpyYTBjSAtLzEQxZfpt/29kevi/p9f28EB7fkeIkX3fSMHkLFurhL7T+m6SOK7ppqqUboN9EispVrTDVoySLpJAFrb9Qu/NmYAN2a/ODoMUbNhNsYw35pPrytmsDYj3o9SdgvHM2kpsrcWYSzkzLOhFMDWAaOQMje5x063pXmsnhzRnpERNnR1J1luS5PCEPc6aTL/MmvfGBtr61a0XOpfLA/ub16x3z8jribicDIcW6MVo91kgrCdyXshpoZRC2TcEIceIJHf8UR9BDtagD65pKD7LYvMVTKO1EaBGsOEGBhSYcCE6hGsLzWRCAO8JYjTmQa/HtdhijO4XhcFWrTEtvO1Ok5aFJZnIr/iApD0jh12VqVTdWiODe3257qEWatkpKRK0W/ZNJKVACE6RlkAIx8zVqAmCKmswWHfvZoinMdlB0psz04w4kMQp2sHsRkLorhKv5GzVXy07aV0B6yWkfCLoS2K9cKqF/fLaPrXTl2OKlL0g/h5yd0fdLoTYTa3dMCJT7i/2ucKystw95/LQ5roWNXUOiWmiAcRoNGnEWIXOBqyuatC79zvV/d2WxSSzwtCQgyUmaqkTQibYQLs6UGWDoJEuhdZNGqu17vNxnV52qxw8SA3B1mmV4USeASObX5vN9i3rSk+JXsxKpkkjroF2jSzrM+r2IduHvwHnQosrGjMxrkT/ebaHLSlWSUZUcrhuLzu1FFLK1uPTbv0hHU4LBrGlvNJVjZHYO+KzbSnaDGdXG1+4Prwl8NaCF5bQRN8D1cdOTArNhpuHCr+khZRWkONnkDJxWW0QuxS6BqLiRozqqv9XaI01JVO4lzYdA66XB9588Bve51s53z1jOT0j6mJSaU2RmMjryrIaWaTuVyQt5Ou3wTBvCLBtV2S5ozt+LuKZoovMhpC9PM4s5zskRR7evKJtDxCsnL7slb0L93fPiWGBEMnBBDeHCnBvO1KVtFhgissdOa+oGhVXBcLiixAQFdJ6tjK4N+p2peujlf97QUlm6Oc0cun2gTrCMtxhuyuD98q2XwnarcEcgkF4IaEhI3E1ncJB74yLMQvV5Ev2zZrCOZ9QaSynlcdXr+iPb1nzAumMihKXTF7vqfsjrV5Jy5kYTdFhXQIiZi0SOTFttH2eZRwGGhI9RUQWpO42QxNtuK91M0QjmL2DkVrNVkEcIu0E0uq+VzBlbzoGmwRxmSCTz5zN1+bDltYDElcccGq9DbdZII9n6zHO5rdPJMXRm8KqD5oFrJiARF5WgljSs18fCZ5YGIZvmmSoM9i8clTF5uZcA01UkVIIdJYAfcBG2bLZ5lXAbHz7AW4IqAdvnPXYjFiRyk7brga15BMtnulysoZ4WFE2MMN0q+qD0ttmay4lTi/e4XK9EhEb2JWI1kbdNu7unxOXhMR7kjPbLE1PNAm+7nXCcMOA1BwcAjaj69T6kD377rZxQzJ1DXG6uZhVRgw+vO0DviJC2TdW16+U2FGC26aYpFRK1muiq6tu20jE3itLPqFY87A3NdWYMOTJrKfZ6rDZMUHd2oUQVkNKNLJfHljuP8f+6uvk+j6SPs1jTURRd6+2wNs6xNFzcqadWZecAGFdEunuzirfmNBypbhCx1B+H7C/Ovkq+ABs0U4bdjIhEONiPmujP+uzSUnsDBx6mTIcwSWS1jtSWqxP1RWiaTRO+yGNTnu3e7Jf33C5vOXtmzf03nn+7J7z+WwolFiQ2q4baCCETF5P5pQeT/Rkn7l5y0MDVHVVlflsh8jgb/76WAepu7t71mfvwukZtexsl7dIueAda0K0gCJpIa93nO+fk5aVnO548+EHPD68IoVGPt2R8sL57pkxiq7muZRyRltjvz7SW2QV4bycyeuZtK40jNRwuV7Y9yvr+Y60LCgm9EgvBnl
pt6rOe1FEx3VzRlqgOvOI4n5UU8fMNkCM1pjUutNKsSx/31lO94QcCMkoorVsVG0mzCnK3hqtqVVc3aqk892J5XTyua5GL5WgFyOhdKuaYohOxFAXxzSaqohP0YM3o5PPlzlF3ZvGBj02Z1WZ/03r3jvQasWfV1m9FWviepWgwXTf6EbnNaWJyFBlNj5GADEqr/UlrCdZykasdtCY5XVCxLXEvJdRy+4SLkJ3y3gYvRLrK3TUoNEsSF6IebUgFoVSNi61Qquc1xPruhoc5YeKmTDiUI6xnCREwnomEij6aAmMBpIUE2UVqB2UYhW0B6CRCAy7DhCuV6Ms5yXTRdkuF1ThPmY0g8aFcPcJCELShnQ78LYKZXsk6o4E4fF6JYrw7Nk96/lEbUpMJ6v6VSHbobesz1AycveMuj1627c7GrBaH0xcK9FhO4O+kw3Lix2QIp1adm+8H2oaiMGuXczvbLBFffbeenQxsm0b5/tn5qcVotG2SyVHg0AR62+N+aoYI5eLISK17p70KLVAWpb5zMc9LsX6qTEtGKHWGZHLQqYhqjx/8YK3D4+EvJGBvhdCPtGq7cmeO8nJTpazBFLO5HUl8MxW3LL4ujXJKx3yX5iPnfiozIDAa7k6ejJGL0w9JIj1n2otdq3BzogULJir968QqN1nydRsiErZfQ7UYDs0IMnfszdiEmoplLpRys5eroDJtm0dsiedvZsV/fnZJw3GzBbgm5p9SnRptIB5ifWw295wKPJW3eU3e32sg1SMgdOzF6wvP00tO6++WinX116EBAiZtN5BSDSUrdh8Rl5P3L/zjrt5Zk7PXtA849VaiCtIK9Rq1gIpRWozrp5pjVXevr1QtPH69Su++t5XQJXPfPZzfOITn4SQ2Tts+9WzT4OTWjeq9+K9IE2BvC5oXtg2o0CnnKwh302vK6XkxADLVkiC9s2CVbxabydYD6WW3Su+3Q7a6pTSZM1yMD2zGBPP757RWqVsO1KvZru9Z3o+Wy+FQsoGs23b8MOybK8201IbeLW5drreoaq5iOI9n96c9mzQpg54JiY7JLTZ1zs9HTFlgDI2pfGobFGHaDRXd4St3eCi7sSElIw8Myw1Qlo5xUSrG3spBiF1u8fRxVuHN1erBZdst76bC1jEfCKmbNVizORlYT3dEzA6sSR3ppU4laRVoaqLk3bQUoz56XBlEoP8ZEnsHdpudGZxQoFBaMPYzyFD9eFshHVdqbVyvV4cQg3cPzPl9h4zy+klkldSWqjl0TTlSgV9D93eJ2pkzfcEbWgFrZ270xkJavbv60o4vwPLM7a20suVWjcjAblCiUigdEBMnHSo8uuQuBosMyxhwTUTW6/E6PNCAVqxg7cNH6IYrPoW17l0odZWTDleWzuSvd6oY45RFSJTDmlU9zoOfzvpKftmzy0mI7Z0W7/DO6nW6i7NbhIqmG9XytR94dlyz+O1cHe647Ib21ddgzL4AL8FYPsMKa8QrR1weXwgbVeGzmXv3Rm0q7sd2/zWks2Go7dKe3xlMFzbZ0shiiUAxrgzGDgkY292SbbXJCDZ7V6qnQ9DjsgqUmNq9qrAgNOHWICt/ZQz9/fPWJYTiiFD5kHmvdOysa4nmkZCznQ617LRpbNG+7skAW07tVwNZRAh5pXWuxHYvoXXxzpIbZcHtodXkEz6PueILon9WijbTlgapyCkJbHXjceHV1wf35CXhUanaCGEO9a7lyynO0opvP3w62yXjdQ3E4vdr4g2Xr36kFevXhND5HRaiDFQWuHxcuWD9z8gD2Xp0jjd3fuDsLCG2KZc71+y3L2wXk4rhGg2DEsK5LNJ6iTplMtb3r5+dOp5IopZQrdaLaNqRsootSB5N+jxdCaKyfyU7epDdgHB4C+bebpSeiPhjDM6KQrShRZcvy6aavRedpONCREjtSVwX6HB4uve3B+YOhrobbNMSYcYbvN+3qgGmgeExnAZxlurinlIldqpTucdfQWtxRQJiIQIy7K
SJVLEFLRVu/fJnEFZNuuLiKBhQaMRRCQmuttbi4hJ5HQXZg0BSG510p1yZpTa2s2JeE0n8n06Gs9iNg2G/5u+nkpE8mkKqV4uj+xvv0boyuIafdBJIdKCsORE3QtVbYjSlAdgH2r02PWsSybEZG67vZNSpOw7tRWjq6N0bFg7pHvK3UptOyEKq2AW7K8T7fo+8ZRMhajuFHaCCks4GZnm/ALu3kXjPdcC1w/fp5QHXrx4RnJzRYbHmJNKSmssMRrRYkDPEgjRpXpcWLb13ZMAu9eVKyFlanEiTO+0bTO7HSfg4BI+2+VCzpl926YnVtl3TqcTrVmAyTnPoBVCcKUUq2gn869s5JvAOL62Fnt+1lcSaq8Tuu3VYPQlR0K4UjwI06s9d1GKmgvDoLt37dRmQtO9Nq7XC2uyQd4YbWopxmzkhWQH914qa8zk7Crt6sPvWJImwVi5Ka82MhC97+c/Ew8wprxivUB1NuroA9vgbsc4MJbI1mZ9PbPkiATJmOtw4nRenalo84LBGb2tm6v2cjrPhCqk8P8j7z96LNmydUtsLGlqCxchT6qrXj2yUQ8ogAAB8t+zw1Y9UDxZdUWKIyLC5RamlmRjmkdegASYjQKJxA0gkUgcnIwId9+2bM35fWNsIYqK0Yo3W/YbGq6W7Wuy9a/+kl9/1YcUJVHnE9ew4ruekldCSmAU3jYoDSkFUUtvyb44z5xfn8BoTNeRauI6ntkpLXeBEjEkeVgXxbhO5JwY54k/ffnC6Xzh158/8eHdPaDo2h3qzoq2vMA0z0yrLLSNMTROQhaqaJpjT7t/B6WILbUs257Ayr4IQ8mBUhXLukinxAi9Qh54cjMpWZxTGmFx2aK46QaatiPFQAgLYQ1SWjYSRVXWMM2rlIeV3D7EJaSIa2IJK8pUvN8YXVF+D+9bnHV432xlx0Ip6xbM4PshZUzGOhkxlK2dbxCIZdFv3DbkwbbtJmrZqNClknIS+kSOpFxIVW1IH493Xg6OmiSmXjXaxG3vxsZKq9QtHmyMhuLIyM+Cs14+TNsI5I0j9rbfkqvu1iVC9ivw5+hvrtKFiSVLMMC6bWewcQVTJMZVyCFKUbTD9QPd/oDNlbU+y03GIB/0Kg+RAlglN6pUJLAhJexKCGJ5NUZ/D9poLSBdYxzGbP0T/dbPUdsNZQukVBn9GWNFCqg0HH4A16HWeyxChq9xppHNBqtu8f0R5QeUFYni6fxIPT1jHKjak4sgebSubELiLZ34llJkOzDlNvVG7sg1omGDAcvYUhtBZ2ltqEp2XjW9kdI1aSMY5I3lR0rSMauVdVnw3uO9/35ryjl/H9sB30visrcRViaqEMKKPMn1lvwTdFlK24ETxaigvt+ULSrOJKBGKfGr4qFIAhizRb23cbjc5goxRfI8y4M5RawGZYdtZr1RUrZdaVEQwkpYZknMzZZ1DZg8C/JICyJN2YaiNc4YiYFrK5OjDYtW8iy7yi0o9kZhUYpNgMr3NKPVjUwONs7em+lAuJwyDvTWo60cjG8vlkpbUslb0AxqlbFmLQmvNJ2Rn1W97f+yc9SsMU2LQhPnK9oK9Pov+fVXfUj1+z1D67nOKzXJW13f9TICU4qUM8s6U5TCu4Z1HhmvF+ZlJFMY8hHjGp4eLjx8/ZmhaVA1UmIQdBAJ7ztCDDT9jt/+7b9DGc+ub7BKWH3OOm7uHfM88fL0xPJ6Ync48O7TZ1zTYZwj50pYpSchqBwZC8UQoCa0SuKgMo6Shc/WtANdr7/DZaFiY2SeJtZVditDu+1Wctyu556UCwFI24jH1s3GubldTqcT83jFN43scrRiWSLX65VcwLtmG3FC6zxYeQOdpnn7gZSlfc5iO9ZKCPHrGnAu0bTN1rcQT5d01SSGLgtaD1tktm4Nd+l5VMIaqSXQtg2H/gDKskShQOuNPaaQA2ZdZsK2w3DWYo2GEKlKdlLayYdYa7MFZvz3aHgp5V8tnOV7YrfxpTin5G5nlMH0u+/
pwHWZoWaU1Rj11t3aFsJVOOIpBtYUcUWhfSM70LZF6zusVsIvi2FLV6UtZSija4+MFfO2j7LbG6tMkGRBU7a3TwkLsKXGZFlvrZfCd6mbDyijdaEUKRCrZodrB1T8JMnEuEBcN/hrQcisTuoGeSXFkZoD83yhN/33gEBMgcZb0UhoLTQDJWXqFAOqaVFVXjzsW9WiSrhFG08pVfh4Rb6XaetapShcx1wSrIvsOLcAVElst+2K9fY7Kw74c5jAue/kCkEHyb8TQ/zzfmQbI8a4fWbebgVb1UAQQNutoNYNMSal8IphuZ5pWTFuoPO3RJCQytvNyLVoJ8Bapd/o6NBYTeMdqSK3o230hkLGc7Vi8srgFdrI18DrjW1YkoQmVMBsRoGiJcmXEfWQjO4X8nyVw/dfkdrJsi+txm4g3SjEmA3RJEivDbWWBTRdN9hyDIkaF0Cj6yZ2tHLTl27hBsHNGW8lyJU2DbU2wgxNKeG87LpTrgLl3tLPf8mvv+pDyg53BA3RKIZ2h7Ke+fWZWhKNN5vquqHf32y6jgzLKIr5Il/YOI44ZVjXhaqkvxFiJhPx1tB1PW3X0/YHMA3N7oZSAnGdmKYLRsHhcE+zJqqyvDw/CppGN7TDLabbywL4/MwyviKbLQjzxDxfKCXSNANdL/HtuqnZm65HG8HTKCo5LuQysoaVaRoJIVKS53DYU+LK67dfiOuCcQ1JN3THg3iUSqVuaP9sHOdxJoUXfOOwxtG0sh/LOROWlen8gjWK/W7A9HtKrbi2JabIuswSgfY93dDj2oYcM8RISishRcq8qT5iwjmL1xat5U217fcYa4nLxDpm5uVCzAlrHc44+rYjBFG8t7bBNR0+rkzLSIwrKYNTG9XcGLwzgMMqIAXCfCWlLHNyrbeyrPiyjJXDNcUgO7ftwVDzpmjJklyz35NfcsA66yWUEVbkTUL+fDXL9ymGyBITGSmwusbj9EparpweA7v9Ab3FbtckrDdqhZLQG8ePDQ5at07Xmw68Fvn/1tZsfR4pqOdNoMfWw7JG3si1aWFLXhlVyVlvFQHBQgktQ6OsBPmxPbbr5euxgUJFShiI68g6TWjbYpwllUrMFWNknGp03W4xglZ/SyemFAVii5Rz7bbcM9ZtQQpN1RVtC5BlkrCN3rSTW4F3nhyDHAzKCZy0ZNlfbi4447fbW07YTeYolA7zPRb/JhKtpRDWQNO02w1WRtFaK9ZFErchrLTb2Mpaj91oCTkmlNFENK7rKNOFME8wzaijpfEHVjRLqHiV0SqyFnG1eSuCyJQCbA/ot51r2QgOtWRUVBtoV3bWbI4upxuyqpRVOmOCUZK967pmqNsenUwKi4hY0aAMa5SenjVa9llAbZrv40Oy9KFCTGjnRemTZBTvtMJgSBmu44Wc47/CVjm069DaUcpKLZHivIx5q4z8JR2rWIMIPnNKqBxl4oGm1rSNHf8tRNB376i64nzC729ksxEScXxBhQVnLa5v2e+PBAx2HmmbjqHfYdsd47KSpxPGe+6OR5zv/jxvXUeJtW6hANd0HN+9p90dgcwyjzSXZktXadp2YL87okolpcjp+Zm2v2V3GEglMOfIOj6RwyuH/kgNgTBdQSsa17KMV5TZ6NxI50JbI3bRaslJs4bIsizSzC+FFFc0O3zbsC6B8fxK2+2wfsAbUal/v/FYS9nfcvNeJHcliQNG+Z674w2HfWAer5xfHkhhkh8+osR2142ALYNLum5gtzvQdp38maaJUg3abmy9sMLWuk8pyqg1eLJv5W04F1IqhJS5XEdqKeyHHb13WCsPjMvlxCCTvW38GCm6oHSDkWncdojsIGVCCkzLwjKf0SO4ZofZkFZaFZIW+V0Oy7afkH6WsQ6cJyeJ94a3SHuWHZVaoryF5rS9bcseQlVNWBdCWKR2YB3KVLrhSNdYri8ry+mVkALee1wMLOtKKBXnPd46stkcWGXGZvk5wziSqqicCHFD2myJLauFv1dz2kI
gMq61vpGaxeZXkqzC9paeN/VJrRtup8qDTPPdjyQuJlHS57RKoX1ZxMGW8/dQS0oBkK5LTlBcEYv7un73pJUqaB+txa1VlTiFjNYbazB9H0FR2Qj7wvV7s2Mrpej7Hcs0o438HZy1lCwF04oWbmWtKGTkpt9oIjFue0z1PXTSti3XcURrK161EOUCgZBmtBJO3rrObOMCZEQptyrnvLisSqHrOtKsiGGinL7huoXqbpimwlwV1Iu46drue0iB+gZX3qy734t7b4BYieMbCmsI6E2TUqve+mV+AwfYzWQi/TGlKyUtoqRXlbbvoO5Yo3zP241Us4wnSpwp+W1srEXzk1bBRqVEDtLjkr9nYg0r8xqZg+iDvBcqe80VVWeMXbFKff++5yTBr1r/TP+AjeylJZIfllkSjbngnKX+W9hJBWTRjNXb2MGjXIdxE4SVZboSIqC/gfes1xMqrVjb0zYdFcN0fSSMK513ZKUweqDxnjUtxJxwrqUb9qS8IXJywjrFfujpvON6vjBNC/N4oubArveEJVPChev5gXa/p9REWGZICRDoY9t4tP/w/Q28lkKNM0YpUq6klHFeQ4ZKlgdcjDRNy37Xyagqit7AGEPbt+RcieuIU5ZLXOS2gOQBvHOYZsf9h0+sIRKjJJ6stTIesB7tLKkETq+BcVnQvkMZRVozXWewtqPxnqax5LhwXuVGV6uUB5u2QxtLWKWflMLMPM/ydfQzzXXCaIPVAvM1rqFpK3FdCOuCqZL8qrUSU2CcLrhtt9e3Hq00zksfK6RIKmmTJ8tbpm/7790cba2MAa1EYEteUQUpaFMJUUSZOhZi2Lotb8XUurEWa8WaSOVt/4OQHbLZlCeBmoOksOLMOl8Iyyjsx3WWHswYqfNWdi5yc1KKDduU0VsPzNbNRbYR9quuhI16TpWxlLNmu/0lsbdWSX+6N7q1lqa/XLDK2woG2IIsVdQqqYJmGwG98QJTEqV4kVFsCqvAS1NgCSvt5vtapvk7ASFvqdOcEt65LXhQCBtrUmgmsohPZdNibgeANkZGdspgzPdlFnqzUSuV0c6xhvW73lwrRa5sY9wtSVcrKsv4FupWrN/CNEUkn2/24DdHlVIIDSImGUdX2SfDNlIseqM+CFHvDeZaw4pRGjfsoWbRwUwZ3Sr2zZFxlVi7JVGXSaLo22FZs+xnoxIKinZO4u5bKd9Y+ZmqCHBapbzdfEURJB6yQk4Z9EaY2bh4NUuoqCjFmqSv5bueru/IJWMXRU7A1jMEAU1L1SORcyGoP6PZckykmFHWMewPgBB4unZPyZGYJqHBVIUxzWbqlolASiJ8tdYKn1JrjKvb5wsp7xPJVGL9N3CTevn2E33TYIyllpX9/kDrDdex8vz6yng60XQ7zperlA7jjMoZYxv8uFKVJowXqIXVGaqZmZwoBowS1cKyrDgnoydKZDo9gRKagHeexmqqV2KnRKCXjVeUDMv0zOmlQVsvbzvGoyusS6Dbd+z3d2RliesqOoRS2Q07Kpp5jaL/oLCmhZzC9hY04IwEQaoXzblxnrBGUAlnLbEESpUEz7qKAt5qTX/4wLA/CHTWN98XlzHOgqoxDbubWwpJ4rK2Q7sGi6brh+/7p3G6Ukvken5lnmcOh1t2uwMlKHy/w7QtUWvi5uGyedM/kIlhJStF23aiLG/AKiOhlxioa8b7Brt1ktaNs2i1MM8E0ps39XsmMG9U8wbvdrhmkDDG5lQy/Fmgp1Eo18k8XenvhHPZyQeMNjhjsRp54y0ZpS0hLoQwSYnaOdnhKIVSnlwiOayM88y6Bpr2Qt8PtL6RUZqWMZixBquFLl+UYomRUtL3gL2wGDOw4rXBtS21KlKBtN2iSpSu3fd9FvX7TUtvYjmjt4OsVGLe4tclfz+Q3tBagLzNZ2FJ5ijab1UFnBzXKyqt5Bi+j4Nll1bohp0kygC7LfGM0qStJ/YGU3aukfGf18QYcFufLW/hhlq2u/mGbdLbiM4bJyM
iJWBW5xwpRVE71C3ApmTfqBCVBbylM4Unt8Z1U8dnYkh//lnfvFWlFJZloe/7rUIBRgl667uUtMIbISMnKcSbridRqTGSry/omplev9LdeZwbaJpBDqlYeL2cWbP83lYJsBWjZLztHErLQR9CICVDTCtKb4R7K2imZbpwvY5oo+mblq4VTqP8vbb4ey3khKg1NkQVZeH6Okq6OAZKWDZ1kby8pBiEWJ+zfMZr3V58hA7j/EZl2eodb/s6NNvhugVatPQiZT9bQBUqkpIsRdLLSkuBXwN2Q03JN+PfQE/q9ac/koYB37TEq2d5ecC7RkyZylGNI8aVuM647bpclSWXynp+kYdEmnDGMI8nijL4bk9/vKNpByiJ08sTT0/f6LoB74RU8Jbvb5oGa5zEN3NiHUcZTTgjbzAms84T7U56BSkldn2PorBMI9V2+H5PCQuX52/ijFFI4MJI+TJub6ZrWLeSn9koFBmjDFlZrGtpbbfRMRK9lsBFqTBZy/WSWeaJkh+wqtIMe/mARHlbVmWR0mxVtM5jjvfs+j3GOJR2YiTViC/KGEryrMuMKoVWK+o6cVlkVKKaDtMMgnIxGu8alMrbw3VmvIpUrh92HG/u8b7ZiqsCv3kb7aHlB/86r1QU/e5A61vQwn6T/2yMM1uJMr+SB7WxGNOglYwuYlzIWQIWAu9M4v2xTm6xFXLp5cAoiZzWjShdKPqNcL2Fw7QkvpzTlGyFY2YdTb8nl0LjDMo2tP0erS0pJ5bxTIkXSFleqHKWW0cpgsNpelCVkmZUXFDWgW42GOyWhKzClcxlw2u9Ld63KLs1Tm5HW92M+ue3+LIFNAQFJActeQvApJW69Y/kmZxZ55Hz01d0WlEl0nWt3Bxi+h7pDjHilDAmS5E36JwyxurNRCuHh3AqIeUVsv4emkklCZzUSrDIKL0Ze60s9rUhlSTVju0QESRUu0WhJfSilabaSimgNhKHfiP+2zeKg9ywKoWSymaQlbJ5CKuYqo38c62V/LmVtIZqFR6jqqBaqMahh3tJqK6zdBbXC3l+RbcNsVSM84zXV57OJ7JWfNjv0MYRJhmzu6b9ngCUbIbeKnoatMH3O5zvWJeFl9OJh4cHPrx/j9/2ek7JTnYtEvawrkFjtlu1yByXZWQZL5KFUcjIVsnPt1IKZxR6K+Ci7TYNeRsl/9mAoNR2u8qJlOU5YYzHNXu89VRjhcySBbdlXUtY520yErebm5FErgJN+T7+dOr/7ZH+//HXX/UhZYDWColhnK7MCpxvafueZndP8ntUmGB6JqWA8w22GWiadvPprLRuwGlY1sCaJDXVNQ1LliViUpp5DcQkLD+j5YcqhJXT6+s2PpKR27wshCDz+a7taUxLLYp1XriOZ8bTE8pUbnZHlhA5PX+jW0bSOkOYyM4wThfqeGE3DPKWtQZyLqzzQlGKYTcI4w/N0O9RTjBEVnu8EjJ6Wlbe4KB71+B9xzxeiWEkJ0kvliRaBGstFIUuRUCVWdI61lrmeaRoJ9iodaXrDHfHGxrnyMaimp6kLMuyMF8vEqGtevuQeaw1W4elyg+tkYfDunXPhqZl1/Yo31CxaN0Ql4XxemWcp21MA13X0x/vaIfj9gGcCfNFYLx5K4xiBGArMUAhOSN7BLSh6kxSBVsTOQVKWlFFMtQS024AwcHMy0KOK84Zmqahcy2ukX/+FtVNWRiHWiu6rsO3HapqcsysOWFsSzvckkrh+Xzi649/YBlH9kPPzeGWxksKtUjrWCLkKaJy2AjV24I9Ce0gLkJCV2oruCIaiu80iipCRxFubjT4bRdUsAIlVTIqFRvr26gobUbnjfZRK8s8yY40zTgDvmlofEOpbMicDZy8kTsqkgpjqxNsfyCxDmjFMo3by034VzxD2dfUTf4ZQpC0mKAIyCVt40slXLhtfCgUfPPnWDl8339U5CYHbIecAJHfbm1v+vqY0uaZ0lt8PW+dnrrZBDYiPxsYuhbpVGXZq7h+T7O
7IcVIun6l7QqqJpQqjPNKLg1rrdx8+IHd7R3DsKfmwOnpkeXyIvfmummCrKNpW+aQNn0Jf/5+5ojRhXf399zdf2CZrlwuJ7w1xFxBVfphkGmKEQ1KioEcZcfktMWSvxu7hcwBKCFGYKRnVpWgwdJWZZBepFBe5O9fN45oIqwBYwqDbmga6e3pWog1yX5U2Y0BWKnrIqizKmPc1ls0GVUl4Wz/LejjbdOAc/LmZiWrH3OiUYbDux9odEe8fKM+S9rE+I5UhaCgSiTOF7JXHA87hsMNgz9QjZhny3TlPI6kqsi+I286867xMv4zjmrCFmt1FAxd09FsHaah6eh3N7iu53x9RVG4vb2naVua3Z5OG+ZpJISZsIxSzut7waukRFwCylmqsaCh0Ub0GVq0Em3T4GxD1hXjW9p2v9mBBbIJbBRzjW97eZNbZd+hFNK7UJa23+F8T1pGrk/fhEunZeSR04rrWm5vb5inkbicKXGhGtmtKetx2pFL5ZIvzEskorBNJuXMGhfWMEtsPWa8a+iHDmtFvnc5v4jbyjaknPBe2GFNp3BIhLnkIqrsXEjLylpXwmZE7voWq4eN5OHxTUdFdhrzckVtS9q6LfNTzrJ/KhlTwdRMjmIlzowYY7HW0w1Hck44b7FuhzaC8hFuX2CtCzFIJF9rtY2bMyKRTKxpIeRlS0Y2aFspxjKtMnKxxlNwGOdx+s+3IWUtKW+ac11g21dJcGJz8mwsSeusFEwLMoYrEV09ptbvE72yQWvlrVhDSluEXcqxbNFv6esI3NZILA9jNLVIhyemgHMbJaGKE0kCF5KgrIptf+vIUd6ewwZ7dc4RlpmmabYx30adUHbr3clBL+XV9D3hiFZgQFkrBlott6aYE37bH0rTRMIPb4Xct5uQetuxbNUAZRQ5vRHv33Yk0icSM3PGOdGBFNRWSZDdmzbCCpGv3UyaKk3X0xzvUWUWh6i1KGdQ7Y6KwZsKtuVwfEfT9OS8sEwj5CT7KN/gHSjboFxDnhfW6Sy7nSJhJ1USd4cj3e5G1Bwo1vGVZRkpFZq2w1snOKgqSpWa5QU1bnbwtE1XqJVSM9qIqsdaEZbmID6xqtT2ImbQRkHNlJi/h3V840lZE5ZIyivzcmGNCxUtdRxdUFrhkZuV71pKVIRURDBpZI9sN/Cyc1bYpn/Jc/5/y0Pj/9e/XNOgmxatHftuT9t1PH77E5dxRL8+09/co20hdy37/o7d4Uiqhcv1xOn1iXmuVGWpymPaI7sPv2Utmtenr6S0UGpGGYc3DQmFLgmjZKbc+B0tCu0tu8M9ze4W0MyXM9eXr+T5hFKJxkDvG0qM7Pqe/W6/PRzK9gErFN3Q7Y4cDkfG84liIusyQgk473HDEWMd67qiKIT5yvUyws7SNAeMalB0aFNkrJciISVSFbSJsxKNVnogrBPn8xNaadr+huo6VKvQvqe5eY8aHXl6RdWVrmmpwO1hz/u7G8Zzy/X0QtXCllvWQAqRsAZCzoSqUKah6IbGD1gq6zKiKrx//5G2bTcS98IaA2vKnJfA3d2BodkzXc7EVdh6u777Drr01mIpTOcHzuOJogr7wy26HuStrAqcNpdIVQ0Zx7D7gCIT5yuX8zNKGazp0KaS40pcJ2plIzNXDIv0sZRiGI5bxLYSw5WwJHIOmxZloyAIshUr+WnKKiDbqjYFfYpcryea0tF3A3//N/+e+PFXEhOuoDaqgDUWve04tdPUJKZYpcAbTSmaXAvOGLKqrOFNfCCJqVwqeds31Jq/93ByLltRVgmpozqyk4RijRG2W18peXtpAJYrp8cvPP38J2yNdJ0k1FStxBBwb5TuXKSrs431ahUayp/FhYUYA955YgxSjN88alV8fqwx4l0jwSelvyseShUIsWQi5O9nNoTWW68obJ2mFJLcdrthI/RvezdEymislcT2W+OgFFQRgFdF4bx875XRxCgoJasttVQh+ss1Qh7aFHJaoUihOpnNc3W4k5szlbx
e6W/vWWpDPJ84Pz+zLFeG/UFG58qi97ekCut263e2RVtP6ydy44hB8E85aXIKmHbAlYozRXZpzuOsxjUtznpKrTw9vzBNE42TgFEB8WmlJEAApWUHbAqqCI0lLavUUgpYpzdJqIxDG2tlBwqAJBFVldFf0wpRJKbAskzytdaeZujkxQMjPSjVsTJhHfi2xxjDMl5YYqDxMqnJ6t9AcMLv37O/PUgWPwfiUlCqkMLM5duPlOVETDNxvWzsshXfNjgKu6Zj93GPcg0hJ06nidiNDMMOqBi/Y+8l0WRqZs2BuERUFWJ11RZlLVVrijaSmFIa42QXolRG5YXT00Kqlc51NK5Fo1k2J1VMIii7ub2jGY5QMm3rub5e0bpIbDZZ7o7vabsdrumoOUoHJkZeT4/YZaLfv0MdPIoiWKGwsMwLMReatodeKAXC8bMs64xWmmE4QgqMz79QTYPvBqy7Z6GwXgUDNC4zP//4e9m/acU8z+SUthtYw+U6cjmdCFFAvf3uKPujbodSitl6agnc3r2n63rWsICxuEGYd77tGPYHWu8osTJdfmLJAaVvUEqzzjNJaxpvuV5eeXz+RmID72rH0HpKiSzLxOXyQtMe6IdblPKkGJjXC9fxlf1wZNf3aKNZZ82cwhZDl91DTlJqXuYLayybmRbW8UwpSczIzmGtZ1V6G8tk/KaZ30KG5CK0E98PVOMoypC1Q2tHbTriOksFoBS0MZuCQfY3ioJ3Fqs3RiQKaywxizdLwhaWXJIEC6jkHJinCxQZmxU295TayBNaRmBGawyyU8163cC+gVoglcg6L5y+/sjztx95+fojt8cdu30vHiLSlp2rvDmD0mZ9laF7+b6XkFHin3swOQaJxVcJAYnSfKPpa403Bm0lAq6MhEtykRSaqnLzK1n2clLItd/7T4LsUlhrxIab06aoSHJrQqONjPQ2voMs9LOk/lzjt1Sfxmy3zzdlzNsI8c1pVmtBpyjkFWPI60JNiaoNth1YLy8YvWBZoTpyNaRSuF7PZAVdf4N1Pc0W2njDE6GQFJ/RW+k8gWnkEi2iKeZxZK7n7wGXt5tTSImUxBdljMb4hrzpepqhkd3jlqrz1qDSQpovG5c0kWuiqD9bwlPKW4/NopxDZaFw5KJpdQNZ4LayX19xxjH0A6XI17fGRMgzOXtMu8O1PRVoW5meeLO5spT49+L49Bc95/+qD6nD/UeGTrGcZ+bxDMbQOUN73FPDis0zqmZMO2C7ljUuLMsoUEvtOH74hD58IK5XpodHLpcLaR2Zzy/cvvtMPww8fvmFFBa8zmAqrpEZMlsCJ6bEfDmxjKPICFNkPD2Qrg90jcfYFut7Uli55kwMnhBm5vlKzYlu2NE3FqUq1/GFEhcpyeUk6bZlxHQvUvBTiuvllbRIDBhVmcZXtLXshh25JEKYcdbiG4+riqbrxN1TK6pqjPLshhvCujBOIyllvC1gIlMMOO+oSnMeV8gLqcISIt57dsNAyoWXlxfq65mm6zmfzpxPJ0rJtG1PN+yoNbOuExRB0EgPR6OM4XQd+fLtC11/4Hh8h2sQ1uA6M19ema+vQl2uhbbfoXRlDiNhEZHk4XhHtQ7jd8SiCRtJel0jl/MLpTzS+p7qWpTKlLwSl5lVaxbnKMA6TcR12pBMYlO1Zk/bQ6VsD/7rBusFasF5WUprDc6qjb4A6Y3cbSyVwpKvjOdn6nOl7Ru6boe1HRlLzZVc5MEjIkuPdg0G6feoTdVRikS5heWIwIJrxSgRY+bimJcJaiHFhRwTZbuxoC0bC0Nepqrc7NjSfXUb2eVUyFVwSSlnlpgJWcjcSmtiihJC8O7PlAs2T5ZWlDVtolG7EVT+3I95O6DCGreEmAB8NXW7jQpg7u0t32kLquJ8swVFtvEZCqU3QkgpmxBTAgdaKazzgBRM1dt1SQkcuGhRiJjt762MRldDKQLDRcuuRSmNs46okrzIlbolIN/i8jI2/TPzTsaktcj+rVpN0wz
M1zOWQprPUmVxLff3H7FOdrTG78jask4X1hBwzskBrarYxbfOmNOGojy6sJmOQZWIUQUcXGOmVved5KGV8DedlmeIQHEVvnV0Q0/VhhTka5/jJu40Hu1BIeQU6+T2VbbPqGkHQimsIWK0o+l7ut2BZb5SpglqxRtP1VWkqqQNl6TJVZxrpULb70gpyG47SifLao3yRzSV8kZ3/v/y66/6kHp+/IWRlTxf0GQOuz373XtSE4jLFV0Slo7d7Sf2H39NDROvDz+yzBcqlUSkUXA4HPG+Eyhrmlkvz4ynB4ypRApPlwuqrLRO4xSicwiBRhlsrcyXCyhFsTIXjnGGXAnLjG2EohxzJYaJ5/MiPpWcqHEmLiPeWIzvSWGRzozSKN9x/27HkjJGQ2Mrxhou50IqkdY77vxHlJUuR1yutG1Psz8yT5vsDunH5Cw7Iq3F6ns4vGeaR+blwhRH5tUCEnN3EpKj0ZnTIhBbMd/KgeybhmEYeDldqGrBN553Hz/K32dL203XEzHK25qkpgz/+E//hPXSe4k54pvEdH1FM7GqRJ4FwbNrLXb/jlQNyxxo2o1pqAvWdfT9AeUHsrb4bqBYTy0JWzR90azjK9P1kVyNVARcw77ZUUJmvJy2Xc4qBAZkdKerwvZeFBApM68zr69nxvFETbLD0cBuW1K/9ZJq3Q5hpXGuYVkWvj1/YZkSqkRubnruj/f0wy2uO9B1Lc4ZllVJEqodcN1OkoSM1JjIQegBSgkZgpxZpplQCnVLhEm2QFhxtSRi2mzKIA8/rbYe3BvAV7gWKQVSKts+wqM0FOepvuC7nXT3DBiVCeOr7HeNAgRj4xpPDhvSKUcZ923dmjUE+XNpLaPXlL4fUBojSKoq6B7vnWCdwoLVhmoldZhi2ArJciCI0lwJYWIjHqSc0daIe0pp2SNViUXLyLdgtcXZTTZaM7lE6VBai9UG7QQZVYuUYgv/iulIQW+Uklqz1A22G6Q1b9/3KnlvVTG1yj76eEe4nCAtglLDoKuCqKihUG0WKG7bAwbjLNaZTSAZsSVgVMZYS8hlGyFmSImHrz+yLheG4YD2vexeixRi4yIvpo2zlCRj1rAu5FqxTUtWButajrs9enuhMq4Ry68W43ZRkioERYiZUCYJSsSA05ZcBI5bssANGu/wdkv8hQWswpvdVtBPtMc9w+0HqWCUzHh65hoCulTWuLKOXyVZGf8NjPtOL1/p6kprNc2wI1RFnlZSCZQs9kuj5c0px0hKkafTCWpg6BvCdEYVQ/GeoORtWOeIUoXl+o0UZw73v6I/3vL8/I20jCwZ1uuFWiqL9eJv2ZbG4DG2Q3dVXE1hpJbKsD8wuJY5SOqlhJX18soa5I0+hQltLOMyU4pCm2ZDuFRar6hGkUKAnGiMwe0OtI3HK0utkWm5cjm/ME8Lh5s7KoaYReugokTyp2Vk6Hu8FgzQzfGGYb8Ho6nVsE5n1vGZ5fJMDSuN93RNA6Uwr4HrOMH5St96unZg2N0w7Hr5cNSNPL7MzOvKdVq4xhnjRXHufCPjoRLoe0833HN3+44YM+tyxulE02hS1KiqBaWjLLFKdF2jQBuMa3DNgGl3RCXASrQhJ0OxBdtVGmsobcM0XQkhCkx362oYrXGqoC0YbQXa+3a7SRNLFEeRqoVd1+OUwti38YYs1hvnaXzzfRdjjdC6G9/jXE/T71HKQVoxdZU34yw3Oq0bqEK9CMuIi6vckDcKfJFWyxaP12At0pcVU6/W2y1IGVTTEbNhGScZ4eUqIYTNsYUWyaRWZhuBSeJNayU/G7XKQ21TxKwpUlB0uxtqXjlRiLmyM4a6Uc3RIqLMMW2JvrKNrJVwKClyK0nyMBRVS0HlDZT7XaHRYp0Ty3EK2CDKnBATndthN9irekP4vPWY3vTpuW40djYqDNhNeV9KAWtlvJolRIHSaN5U6RXj5ACrRW5eaWMlKv4slpRStNpKx3E7jGWvVbablNZaQMNUvOtQdgUFgys
0tuX55ZXrPJJyoT3c8f5Xv6M1hur0Vm1QGO2INRDnM+s6oak8XSb+9PURZRx77zC60LiWtt/RdDtxdBUZaaI1bTfI1z0GnDLyZyvyPdBKRsYpJSzLFt/3FAuNt5icSBtt3RqDcoY1B7wzNN1+20dGwjpRKbRNg9aauCVElRFli9datHVZNDolB0rUqFrwTlK+NUaRiKqZXCqH2+Nf9Jz/qz6krGq4GfZ4ozjef0RZTwgr85IxtoUikcm4nAkneDm98PDlFw77gX3Xk8aZOE4k7WiOd+KdSoGqzffkVesUZthJ+W/qhLeVZ8gr8+WMUZqu30uJr9vRNC06R55r5vo8M14uGPfA/v6j0Lh9h+kip7CwjAplHEkZclaMoeDanrt3nzDG8fr8lbic0bYVSOOGvtdGZsZzGik5MU1yyHVDK2+iqZHFsLX0fQdUUjgTxldcWdBvTiQjHK5YFcn3mByxMYG6Ukpm8J7OWZYMynpqWmlUwbUd7f4Gby1esxEFFOfLGdMGrJ/p+hbn5WG0HwbpdcRISJk1VJZxwjjLYX/EESnrTKUHIJQMquCd2gjgGx5KK8J8YR2vhCyHv2sHQirSramZpggHTRcJTSStMGqPteK7meck4kWjadoWY1u0tqgSKWERud1ujzvuiSEyzlcEX2M2+K6kr5QWcoUxeuPJWeo0MTQDRcnPUV4njNbUJGXi8/nEPE+MlxM1BfyyEFLaypcQ1oXz5ZVlnri5uaU7dniT2BnR01vrSakyx0RvNK+nF7R5wdTNHGxk7FZ5K1+y7YP47ljKWWCqJURKCKJc3+glcTqRwkxVFu06lHp7GMveIuVKLmw7VzZQbMQZuW2UnMkI2ssZT1US8045i4vMOoa9J4SAtmBdIwimMqFri+97ci2UJHr7silJhGSQt/6b9Ka0NiICLZmYBMos9LtKKnm7LQr30uKpVW3g3W0PaWSqYLX4p8oWBjHGyAjXGLEJUzf81EbtKEJTkLFwJq1FisK6wfmGuM7U6YlkFcoILeb19RE7j6wh0jtLO+yxrRh6m92erumYY+bx4Zm0Lsyl8jrO4AvWW371/jO3h1v582/R+Ko0pWqc7+QWpDXWtlLKb3bUIkLEkBJpCcyXF1K8oHXDD5//TtJ3KUglJK7oCu12QDvv8Y37rhjRzmGqONKUazbBYSTFLQ2rpCaQUqCmRJpGslJcH+ZtDygv7VYblmVmHJ8oaPb8Gxj33R8G7nadeI+QNwHbNRidmS4nVM1Yiiy/14ld1/E3P3wmxERaIsYoYlgwfcPtzT3YhpTkQxyXC+eHP/LHf/mvdMf33H34LcebO67jledvP1FTxXmHU9A0ht3QkOPM9elJmIGtZ3dzS1pGSph4+eX36GaH6Y6iEKnghiNt09Ae7vDdAd/vWdaFXFZ8Z9nf3DCNDSksvL6e6LzluO/JYeQ0vtIOB1CVFBOmKqzWOA1Re4zrKSUxTResBm81KU/kNbJGTY4LynbYNhFVR1EW1d3S2xYbZ8L1RZb4zpGwxKpYpgs6zfLnOV3JWeE1HIae4XikO94T1wnnW/ZUKU9voxjftrRVCs3TtBLCFV2l7V4xxKTx/YBpe0KYSdOFuAjA1zmPfuvOUHEFnh8f+Pn3/43j7Xu63QHfCt0hpgVnNfthJ6GKnPFenDveFlx7RPYnadsvgLGCI2oa2fu1Tb/dngrGebmlFFEpFK3xViQk6zIzjiMo2O8OUGG9TvjhyG7YURrPGhNJSTctrkKV1sZhnZMbQ4zEmMWq6lv2N/fsDjfsdjvBTMUkqhIKVRt0Y1FFYbxnWRdc24FqMG0vHaiSsEYeDOpfPQTe5IgSrSvERYC5JQeeH76Q5onr5VFKllZuD9Y51lBoWgmFlLr93G7QY2ALOaitcwQpbfRxy3fKdclq065sSghlUVpvXwOJqOeUaeS+Q8wR5zymaSlJ0pu6ruS8MeuM0P5Rsj9TRgIh0tkSPuAblsko2UOmlGToqdVGxchUI84
jax1FZdlNWrftteT2raoEXChvwsDyPe6eUsKTRbeyxbprKKzXE+7DHbfdDco0HPojzlkSyAtRzZR1Ii4Tab7Ivs54djefQFk+Dnv+d7sd2jlCmChJ/m6qVtlRlo2ZGGVaksuCMQ7rDHGZKDkR1oVlkbHdOl/59sufmMLKbn9LKYYP7+8puXIar7jGs+t3XOYF3XX07UFCU+uFnBPWevrjO1x/IIeZGq/fSSWg5XkZV9Z1RinY9Q15uTKfnr/v+Jxv0V2H0XDc7wSoML3+Rc/5v+pDqnUrMQfmMHL++kzFst+/w1m9LWVFhxzXmXGdafuedtizHzrpTW2FNazlOo40g2PY36C0ZZ17wuWBMJ+lTxVXvO+52R1xVTHPV+F/lfS95LaMF8brWZQBvsO0PY1zgplJmYCm63Y0uz21FOIyYmpCVZjOz6iqWMcLT08/kSu03ZG724/c3t2yXHsuzw+cX55wqjLPE9fxgrVCol7Dhev5wv39O2p/j212pLhwHk/omtntenx3Q0ky9kxLwnXglaVrHVL705BaTGioMZDjTFUZ5T1GeYwNWynUUaqQua/Thcv5Ef/aYmzD0Hm6xpOr9LP8/oasReWQw0o6v9LYjMUTqmEJSsZSWuNbh9vtaepAbhzrWeLCfluo5yKsxlohrgOmRmyZOTYHbu525FIYLytWOzrdsswT0/xKKlJGXHOltwONdzhryRtkl5IoyuGsQ9Ui3DolNGnT7WWEUivrNDOtKyFErM6cT6/8/vf/xDhe+fTpM5/ff2DwljkszKetQGos2jhqkq+dqRKb7toG27QULCEGSgHrHY3pyTmSS2Gel43a4QjzyHp9BufQw5EaNd72NN0Nh+M9bb/DWCMOryQ3vbcjSr3hkrSME1VKrDEQ0kpYJ15fvzG/PLGMJ4ZhQPc93siuJ+VKowwhJfxGPCnKYJwoHnKpVERwiBLRX61SUVBvRHmtpSisgKqwzfbGbxtSjriul5RhyhiTaZxlmSba3WELR0hknJwFUGsR55SVv1tNUkC1xuGdJYRAyRIyUZrvHadcBM9lnUdtCpo3dY6uG52j1m0U+WcFe82VlAMWvtM13nZkFbX5rsBoh2kGfC2sywW8o+v27O46cpwIcSFWSRyWUsghEEPdUF8V0zT4dod1DUPfk3LidV7p2p6+H8hhZR4vLLMEp1ovbMqc33xiUHUm14yzUL1MH0qa8f2e2x9+x/54Q9N1pFLIudAPO1wjn6+m73G7gabZbbs7IazPS0B5gdzO15nrw1fyepVpU9OB7VCuofctMql3pJTZHW5IKRPL9nV2LUY5TO3wCtS6/kXP+b/qQ+p8XTjsdzTDHfEysswzlRf6rpfZeZEZec0BVRWpylKckKjW0g97knZ8/fqAOS98+FzRzuJ9x+n1lXWN7HdHMprr6zfCdMG7ln53Q9d+kDedGAjTmVJX0SS4hpASuS74tsW5Ft10NNridcNw+2H78BlUyiyXF9brC+SVZVlQJTC0biNbdxhTeX3dKMZF+FrawNC3zMvK15//hNKOu7t7VE3Mp0cOuyNdt0epjqGTsIJ33fZhWKlqJSwrrInFrKiUCZsVtOQIcaWsgZoirkI3NBjXQloJk94YepDizPn8xDieadqOxnnmxnF3e0e3P4pnyDfgu82Qa0jnV+JyxeiO9vgesz/w/PyNp6efeHp9ZuifOB4ODM7iml4euiVL0RM24Z7m5jCw7z2GIl246YWMwpAoOXONinVdWGMmKYuyLViHyrDOAVIgLhO1SLijGw4S/c6imXfdQNPvcdYyTgslQzcc6LrKMp7J0xONMdwfDtiSKcvEdHrBHY4ob4lJelPOe2wvL0PaKMomSqRutw4t6TIxwGopMKeEdQZn5C31dDkTlpF0vRBqpflo8N7i/I67zwPD7a04tJSggWpOoL3sGMR/LreqrNAmk73b9OFKUEsxsKxX4royKwldDPs9a0zolL6TMXKumz5GahhC0ZPStbKGrAr
aGkqW+HzrG0IQbXiKgabdFv6NZ02Cu1Layr6xyhh1XVYBo6ZAmEdsI7oaY7xoQlIUIgQSgTd6A/KCjLy2AnvdEqVo9WfkE8hiSVvMdiMRrq3abANbEvB7QrHgjRGzwb8ynb/dppxzYDxVScKxKk1xLWhLOL8SgsUPPS/XkbiMOG9Yw0LftVJhyAGtDW3XsMSAUKwj4/nK9fzAvMozZd/1NI3jGmZyKbStp208KSyMlzOg6HYHKppYCqkout2RBljDzKHpOXz8HU07UGqkaxyNdnKrt4qiYJkWSomoXCjztO1g5efsennmyy9fOO52OKNER+MPOOskEel3WO+EKJM3nYpk/RGZjpKysG8xRbICaV1JS/yLnvN/1YdU0x/Fu6M1u4PFNQ3WKt5iqSmC9R7IG6VXo12PbwdirUyxUnGYdmBZF/74h3/m67df+NXnX0kPIkNIiqor8fpK9ReybUjhijJeFtI1obXF9gdMs6P6gfPzI88vD1gNHz99xjZ7lqwxpuF6vXK9vlJKpeukk1WwVNdDlmRXihXtLbZpMc7QO88yFRwZv2uwCCG5bz3zdCVGxe1uh3OasI6U6y/fWYSut5Ie1I45SOLRkMlxRNWANRVKTwojYbpsFuJFFNhU1gI6Qzd07G4t1moupxPj5QVVEkPbcLz5Ow7vfwCjOX/7kdM4oZ0j5YqeJ/qbW2x3xBpLd7hhCTNVO3zX0ByOvGsc3ijyMsJ6QU2v5G4PyFtvChNxPmOUjHKEhm4o68I0jRhnKVrstmxU+LcSom8ssVS0cbh2h20HatqwLUptO8bMWjNZi3DO+g5rFFVlTM7svEbbhv3uiAbOGua80nUw9Humd2dCWIWhdvOB4/1Hckqy36lVdCApSCxio46nKkAJsakKEkjQQZW0LkxneSoqpST1RsHs7+iP93C4wdoeEwrGa4x/syFX9BZdZ9sjwVZi1Ua0ENaxzoW4BvSaia8j4TKyjBOnlxNdJ4lKbT3TOqONJpYqMOY10XYdxjYsUaDF1ognSWPJqmCcdGCqlj0iQYzBbDDdkAK6cTStI9W0MSa3mwAyllvmmcYZag5QLMYb1hTE9Ku3G1wOxDWh+0Ei8JVNAyJswjUE4QYquRHWjWSuNt2FfGn09v+nvsODa5LJyBteKee8AVfd9/Ri3RKzb/+t8oytGd00VNug+yNDiqSoCTVJYbbd4dqWzJllDZS4UlJEW0u2Du93aLvBfeMswsyuoZTM9fVBdl2lYK34tmLOrDFQlMK1Pabfk9GENbPkRCiedV1Zl0zXdXRuQKFRcSLkmahajGnwRoI1TSOAXWMbMsJndMahimLXtnid8EZuSc3xFte2lBiFUAPUHJkWmQDsdnuUddRiZH8JhCjmhVwyOkdiWLie/w30pF6eH7jZd/StxzlNNxhs05FSphRYtMBeh/4GW40QpTHs796RMVynkZJm7oeO8/WCQpIs08tXbqxiuD9SrWOeJlgnSliIccUriPHMCvIm3AwY3xFz4XK58nI+oVDE5cx6MnhtqWvleX4AXTgcb1DKEVTB+0Zm4FYwMU27Qxdomp55nDg9PvDp829QtcibaJB+TIwr8xTp2j2tR5xOXpJ255++4PsebR3zWbG7fY9ud/StI05BFrq1EFehWphmwnuP6hpcc2AIEes8WjvBIa0zdRT7Z+cdx8NOoKOlbvsajVGKpu3xH3/FeHomxEC8vsqBUQKhD2BaebM3nlRhvp6pudIfjhjX4lKCdCUuZ/I60rY9IRlijIQcZR6/QUvjUrHeY31DAaZ5IeZIPzTsD3us8ry+nKkbOsjkBRsVdn9DMwys1uKsYb4UrpdnPh7uaFtBANXK1gkZZW9mHda1PJ5fKEV9H2Fpo8g5SirUJpqupWkadoc9VWnWeWQ+n4Wx13ao7XZQckaYEQoyxFy5zDNGW/qmpypLyoE1RA77nr5tWFNCHd9jb99TtaMqI7QFI2DWrLaxWmZ7szfbg7huo6vyneqNssR
1Ik/PjC8PPH97ZJpnlliwvmJrZZpn1pRoMywhobKk+ozP6Ao5VTkckYh4BdF35IpzIteLScy8b4y6uDH4pjWw3x+IabMQv+GMasQ4K4qQWjG2kR0ORl5QlgWlpNcGBmPUduBs/qoiiUVtNRRxdWmVUVr4jtpoGecpIa5bIyNRY+RnrNS6pRDt971UTKsUYd0bmUSsx6qKQLJxYjgI60RnheBgXINtGppaBG+1vwHbop3FkQhTJq6FcVloGk81I1ZbwnQhzqOgjrSjKEU0CeWE2OC8FXu3bWXHWBzaZ2zXoUxLTRnjW0gjKVdCrKxBDLl9Y1Cq4tuOxmm0dhhtCTFSqTRtJzfRphEkWYyoknh6eMQYxbDbU3JmXguJxN4Z5mXEAH3XsC6rBIqAeZpxzuO933ZSSlKmRrqdYs4qpH8LxIm7Dz/QNZq4zig0MWe6Cq5pCeu6FSMTcU3kJB8U6zs0CessK5npckZbw40Vlh7rynh9JZhCN+z5/MNvyFRelpU1SE+oMx7vLCkXrPfS+p5Hxmnkej7z+cM7wjKi9IGC4fF1ZL878PH9Pa+nB56+/kzf7bjRN6zrGY0SvE4IdL6jhkiKkcYoTtcTLw8O7R1rDtRlxpaIsbBqhR461mXl8fUR6z1d3zEcd1xrxhqFnhPpl6+Ur1+xw462kb1RaoVeEcKCcg3jOOGaDq09uRQevjxRamXfN2K1XSTq/PzwzPTyyO5wS3+8QRlINVDylTmKeG3Q75len4GMd451vBBComjPbn8gTiOX0xN7Z+j2R84vjlgT+fTMOk8Ya7HGcD29sNvtWGJgN+xlj+EF7lu1ZndzQymZl9dXaoW+67C6skxXGjdwPr+ScqVUzU4JtmYaLxxv7tDaSQmyadirG1IpXKYFbQRO27Y9cVqI05mcC7vD8TtgdX+4kWCABqzCdJ7Ga5zTqPXM0x/+u5DMgXWecM5D6WkMGER1UrB0/Y6qKm3T0vU91jUYZZhjhOGAWVa8VSgy1jR0XU9aZpIKVN8yV3DG4BsB+uaaifOyxY75fngIsUGKnjkKHWCaLyznB3788V+4XkcRNxoxBHjfkgvMS+JGeyn8pkRYE7qt6Fxkz7MBXCWiLjqSWqSoq5D0X+Mb4spWTlZY74khyWepbYV3mDZWoEJ4kU7I6MqIhTll4QGWIl0sxabp0HJIGmUkrKElwJBL2eL3koATnJKIPxVIcu1f+fZEwrhp7rVGW3FSoWEOixSglQSThOIdUcpvLir5PeOWmqwKcphRrkEvI2m5ECMo19LsB2GBhkLTHOgPn9BWc319IZ9PXMczMSV2xxaUJuQiu203kKjihWsP2N0txhh2h3tOp0dCWhjPD5BkvLzzUpS+vTuSc8/l/IpjRVUpzuegMTZhndyEnevEV1UghiRg7bgQVrFiG+to9ke8r4zTK9fnKy8PX7leX9n1Ld43tG1LPwwbIUWcedMs6xGjjdQkkrwE1OgZ3IHYp7/oOf9XfUgV4+mOtzQloVQlRnEVzetKWtO2yK2EacTVwLDbM12urOOrQA+tZS0LOmussoQSiDXS7nru7+8p15F8ObNvW1Z1w83+Fms6oV+XQji/QkqcX54pMQrnbj/QGgE2WtdifcPz8sjL8y80zmCN4tA1OK94ffyF1jt2bcd//c//hVw1N+/eMxxuhMLcOt69u0MrxRInAXF6i86ALqSXV6ZxpusGTIZ1XFDVwroylUK1AVMVnoqzmo4WVeD58ZW4rpQi5cGSFLrCfD4zXq644QZsg8orOUfG8YJ3crjF64m6zvz8hyd2Nzf0uz3aGnb2Fl93qFjBtQwffiDkQDg/YucLJgaWaSFfelKu7BvD+PyNl68/YvodTTdAmAhLkAOhKEpRjNOCUopvP39D1cplPNN0HqUNc9zkfMqQY2LNkWxhjCuuWeiHjqYbuFxGcpYeT7088fj6Bdf0sqtRiuenZ3Y3t5K2y5l+OPAaZtkZbWSOEBZ2w9bhmU+8ns8Yo/H
e0rcttSTm64yulXWZRUdyuKHpd9SsuDz/wiknmqbFeidv1o1F10xaZmoohHlCuQY7HOiPd5ASp28/s1wutN2O9fkJrMXtb6hFeltay6glrZWi2IDHWiy81shhrCHGKNimWpgur8zXM+fnFwn0pMQSFoyzsvQuinGexP6L3EKu00rNlZ1rKdoyrxMmqe+9mQyQ1UZvj7TtphopchMtSaoTRjuKUd9vI0op2MazzllJ27YedCblhCWTQhRv1Rs5I0e5XfkWbc1GisjYba9XSoYsAZu60XYFXmw3PJMROsVWcpYxmpXQB3JoiaZCMGLSS8oo+2b+EueStTKd0VpjnJOD2ClyDhjT4Z0TpFGcwVguD5E1RSwIbcXvQBmhvFxnzq9nfnl54frPf2S4e89uf0SlzK9+/Vt++PxJCBW+FbtwSizTxH/5z/9PHp5+pm8Mf/vDr3l/cy/UizngdKHRlXF6QKkJ4xrpAfoelGHNCzROzAhZvi5GGyqF88szISZc29O0O+KaqXWhppnOC12x1T2Na/n52xdSjJitf/n+3Xten18I8crxcNwYj4h9gIzVPbkqfvn2/Bc95/+qDymTFlSWBnxcZyEsuw7re4bhwHI9sVzPLGukHu/p97cERtgYXp327HxhnCf2uyNeWXJMnM9XHr48cHcceJ0vNKoyXWY+/ervKbalbmMbaiLNCyqv/Lf//J8YhoF/+Id/4OV6FlNqkib53c4zLoW2G4hB5GxFwW6/YxmvlOLY7Xa8vJ6xGlpvOI9Xnh9fOB4PmFJ5eHpg2O/RqdA4DQpM66kxUJyhu7mhcy3tsBMn0suD8MSMiP6qkt5YWhNxXWW2rjTD/oDtd1xfHqEGTq8X9Lygmp4YVuK84o0h5sTz+QlvFfbmwPv7W0xVXM4nUsjkanCmo2sqQ9sRl5m2ZqZ14eHbV/ZdS+s9hswcVi4hYdwOpRLolj/99ITVcOwt0zTRDTtc15FTYrxeUNrQ9T2tFiyNLoWyBE7XC9oYjLIbSkpGKNOyUKvCNyNN0xLWhVoLrXWbkC3ROMvj1y/M15GsKu/evWOZRr59e+QyrvS7HS6fxYhbMut8FQJ924CyjFeBEBtt0IDTIh4My8jT4wMffkgcjOP19I3z01eUNuwPt1jfsSTAKLrGYnPBAWtKZKXpDzOcTsQYeHn8mRoXPv/mt6AsFUXdHsaEglaJZXwWz1Hf45odVVsiFlWqiAZhe7hnUgioHAnTlWWeWGPYSOZyGwipEHLlOgu6J6ZMjJnX1wvdMBCKwmlhEoZlpqI2r5o8zHNJkpJDAgrLstI4h7Z6CyQYtC5Yq4lhFUir9aS4kVi2MIVWmrgp7NmgsUqzpTLFOWacx2i/6epl55a30rHxnrQG3mzL1loRXdaMeJyLgImrVKiNFxdXznXzcGXUZs0um6aiFPk3axXiB8pgvPt+Eyyqfg/FqBxonKJpPartUb7bglEjdZ1ZlhNrWuSlomZWo7h5/4H27j3PlwnXdvRDT51HDq4yWFjDyjSOLNcz1MxyuXDTNvSffsXQenprWU7POFPYDQ0qLtQSsGnm8cevuHbHh0+/pmk92rVo5VnXmcvpmbDOxLCiKdy/e49JM62ydF5jbSWEiVACyzSidMYb0QP1bcNvfvjMdbxCFQLK9fUZqypZG86Xk3RFtZZgW8mYwaLdnuzcX/Sc/6s+pGyNlHXi9fEbcZ0Y+pb+5iNLSERtmK4XyGKrzadvTOFCmuctueLJW8ekcV6KfFY6MV0nEUltLX7fUlLl/PDA+e4rwTqapsEZRds4vvzyBxRwf39HToFpPOO9p2t3vL6+cppXjh9+Rdnf8eN5BmV51znG12+0Xqyqr5eJogy//t3vaFrpfVmlaL0Xf0vJfHr/gao1D1++8DxPfPz8kRClCDpPK77f8zd/8/dyc3i6kF4n0DM0hjFlOm15enom58zxeCNg3JS4jiPj4zdqXOlaT983VFNJaWLoOohAyez2B54eZl6vI8O
794zzTBpnxsuVvCaKbWm1Z9WvvCwzzw9f2A89pWbGeWZZVlBi5EUpxuvIbujxzlDTzOdjz/PrI2HRNM2RoqHpWuoyoVRhCjN+aDGN5/RyYp1GDt1AqJW4zsR5RZVM2zpOpxO7w44QIun5TOM9xmq8d6xFHpbj/Mzz4yM5Rt7d3RFjQJVMDgtOZR6//AG05W9+eIdREOrCPE04Z1AcaduOGNbNKCeF13ldqdowz4HueM/t+88kpThfXoHC9XLl28MD1nfkDMN+4N39LTWsHPqBoW1JStGrwHy9oGqht5nn04nx8optOrQv2BTRJqNz5vrwla8//8j7X/+a3n+kaodxnayfciYRtoBGQpVMCjPT+YVGC7x3DZElxG236qnaMoXE+TJxf3fHugRCeuL56Zlf9XtiTELYroVUKtOykqva4L1WSBAZUlZoVcgbRUMp0b945Uj5rV8jtyv9neQgxPNcKsoaYU6m/N2ZpZRBG0+pCYrcDLXW1I3lp7b2lrEGa80GlWZTeLxRN2SRn2MURmJKKMf3BKCQJ4qU+uV3lduXNpJGxIA1VG2oxuG740avyGJL1hVVqvilSmEYOhQDqjlw/+lIWCbi5bT1Oi3LdCWsV/ofPvLBNSxz5OMomKNaIq/5wvztX/jp/EA2Lf2wx3iHUQm7rhwby7QWWKWoHMKVGGf6dsAoR0oL03WCIrfcn3/5Qq5fOBxv8UoxX6+EuKANzMuFeZ54+vYLSlv6wy37/Y6vP/4LX375Qn97T9s2tI0llYpWadOuSD0A5PBelxFrrbwchpXj8VZ24KnSNh2L0gyHD9z0d3/Zc/5/m+Pi/z+/lgINBu0bvJay5evv/xfa/sDx9h1d16GKJy0zyrfMIbBkzfH2lhgip8tIuH6j2+9Y1omuFxHd/mYQKOk0U4NlXVf+/t/9DZnIMkeeX57lA0dhjZmhb/n4+RN3NwfOp1eq0lzGK1lBP+yYY6U6x27nSWnl+fGBuo7Ml0DjRDy23x9phoF5nqFWpmn6HoWdX09kBbZtsdpye39PLBmNYdfvOP5wz7xmfvrDz/iux7y/pf+H/z177ZiuLzz+p//I9Y//yGw1v/3db1lSIsZE13WklCCsOOtJ2WyYm0zX9cQUePz6hbZtKUooyyYGTg9PWK0wqtD2HnY9vrXE1weUFcHessx0O0EiHbuen788oY1DNYq+sfRDi3KarBMUcNpx//5GOHxWqAIxJ8bxSoorKkbCNJIxuK4j1MJLXNkfjricSWbh+eELWhe0Ufz80y/03UDbtPz+n/6FDx/ec3PcYxU8vjwTSuFwc+D9h8+cXs9YLTeDMF5Y15nP90e0NUzTyDwrbu9ucI3sItterKmHmwPaORSKMK+cThe02+zPyvHl8ZG2sex6T7KKYX/gcrkyjhNWVQ6t4/r8TAgLD99+xhpL2/V0rUSSpyXQOMf55cT1vGCU4nh3x/u/ddgh0DpFnJ8Y+pZhOEhYIify+Ipxhuw64va9FuldICwTYbyikBvF/f07fv7ylaZtWEKgEEE7Unnj4UGeF15fXnj34TMgeKPr+YxtGoG9xsQaJcXWeI/WHmPNlpJTxJLwRkkYJCrR1OPkgbYuUvK1lpqE5rDGgNctfuMhqiJE/roBpLR1tNvOr2TBJuW0PSSN3GZ1RTQoVgtiaovKW7OVbrPAalMWGntKiVQL3jmxD2cZ36dU8I3UF6TYareDVErFRXtcIwZwXaXYG+MqVPtSKTWQ0ygesusr6zhDXPDO4bqBXDLKeZZxYdUJg6LRlTKfKGnGxAnChHYW0wz0XYcugeV85XK+kBRUo3l+eWbfdVgyy7ISk8boRIgrtRqM8SwRtNPc3d8SQuAff/8vTNcrx5s77t695/7jB+nGzSeUqijnyUiy99MPH1DdkePNPaRETRHITOtECYWaYV4mcsk4b2m0otPikPr28MjTywtDf8DeddzdvcM6x5evf/iLnvN/1YdUd3iP2w34nCjLQqyGkq94Z7b
rpixC/WEHMYMSh4zRhmIKznuiDd8b5WmZcd4TxgBGEar4npY18PzjT6zTTAqF7t17fN9JkXdZyDnRH6Rng9Zo15AinF9eCKlw+NRiiybPC5fpTL5eaGomLCNm14sSWxvW6cL1fCZVmJaV9x8+cDpdiAbGacTnwK7tWOYLw24gVhjXgOl2nJaMbjru3t+jVaE4T7O7o+l7zP8wcf3SiklVKYxtCEWRlKMZ9pSmJ5ZM6xrSOEoKp20ILyfOr2f6Tx05rKKPdxY3j2RnWXOUPZtSQviwClVavPWscyBHscU2jWOdZkq5iialv8X2HY3vZA/jK2tc6FyPQTos63zFtS1aGZaU8G3HeBnphz3DcY/ZdUynC7rmTQ+/sDt0tI3nfD6BM4QK0+sFawzjdMY0ht639IcbBgONNZAWWcz7lj/++BOX10c+vrvDNw3TEtjfHjFKFuneNUzrTHx54dAMxFgoNdBthH1thSP4v/wv/0jTd3z8/IGhtUynF7z3HIYdKiR2bUcoiWmeaLSl7VtidagCX375yq5raIeWdYnElDlPCzFMuFI4PTwSl8jdhw/kzsMaaLtbuVWHhO0Fd7O8PlNtg93dsQbpv9VN6xHDzDrP3Bz2xBR5vbQUpZjPAesl/LI7dEJGyZnrUhlDZlxXfmgbQsrEDA5NjKt8nqjMIVK1oXEebzTECKWIXwqD2hQcOWeck25YjEHCHFqDceStUFtSIgDWit4hxvhd0ujbhpKrkNpTJIdF+lZK0zi5ySmF9MSshCooBavlgLRKUbegh/WySzLaUFOmUkgxiu8o5W0PVkjrhDF+82UZjLXkCuT4fQ9YqVQlQOmaI7kKiFdnTdf0LNtBWkiM40STM2uYsU4RLldA4bXhen7FailJP5xPKKW578CHkfEkwOHx9czpckY7z/uPH7jd7zBKuIbdcCDGTNi4fFW1+GGHrpmiIBRY1oWub8klk7Qmljdy/YyzDqzm27cvfPnyE0YbPv/61yiTmccLcxESiQoTOly5jvN3RJbRYHXBWkPzydD6ht3Q89NPP9G5jjDPLNOJOhVev/70Fz3n/6oPKbXOPP58ogA3N0eG/UH4bcYTpomKItVC27XSg/DSHYgp45uGmAv3v/m1jBpSwljHtCzc3t7gjCPPi5QHrWd/vEfZE49/+gP3/hO/+tUPXJ8fsDcHGisoly+Pj+impfENaR2ptfL85Ru73ZGXl585nU58+vwRNXQs80zXD8xLIFwm/LxwOZ9JpfDx17/hsL/h+O4DsWhSaDje3JNS5PXlmePNEdV4msr3dv7hcE8z7Pj60x+Jpwe0bfn0w29FoHZ5wRnNSmCeVz7c7Om7TkCQquAsOGVpLdy0B5TSvJxfeXr4StN4KAmKofU7Us60724oa6aeIiUnys4TTKLGwnW60DYtH//mdxxu9qyL2Gz7vsVqwzB0UrJWWlQeVPqUcanw/PDI+TpyeX7mh1//iozh9XzFIqgev9vx+PRIm1a63Z6aEkqDqYUQI20rpeHW97wfjlzmFXd/Swx34uPyPdpklhipaQUS8zLTDwPnq5DPj/0Oi8YoQ14C/lZ4cW4DySprsM5xCRldA5fzK/uuQxXpP+Wa+eHzR2ItzOuCwgsVe5kY0Tw8PtJ0HdoZ2mFH7jopTSrNfLqS5sLz9ZV2bji++8Dd+w807cDz0zMhJH768sR/++f/K631aJv5m3/4Df/h//B/5Pr0M3v9+fvNLq0L08sTzbKwpso8XlinM6UU5tMj0+kZa0Ri95sfPtANB/4f//m/8/xyRetIxbMmRRonpjkBkuAL64qzhvEyYbQlxkTfD98liI1vNoKejMnqNsIzxooBORXqprYvpVBKJaVE2/WsYTNPGyMakpIF0FsKKWdqrtiKlFardIzYSujGVDAGX7x0xmQJJyoKY8l5lv+97ZtyKdSQt/HepmwBwmYH1oivSRuh15QCbL+n0W/qkIpJKxlRoOSSqUqJ2bqqrQQdWaaJ6+vCkgq
qwK4X4vj1fKHWRIqF4+Ed1+uFSGUYWtZlxVjHp3cDSygslxlMIHBmjZE1BfqhZ42J6TrSNZ6m7Xk9j5zGUaYuzuOHG4Z+B6Xw84+/Z9gNvN8fGc+v7G/ucP2Rm7sPtG3Hy/M3wnTFKkMshXWVF6rpfOLxxx+5++Fv8fs9xg+0GebTMw/fnphfv2CsphrHlAuXy4Vh13O+itYjl4RtG7KG03jmMp1lpBr+DRAn/sv/7X/m82/+jl/97d9RdeHx8StxWaBUaqlczhfu7u+wiH765fUV4xx9L2y2w/FAKJmSCtYIhNJYz7wmeXtvLNZmvO3Qw55+d+DQNyhtuDw9EOYrY5jRfmCKmbbbs4REqTOqrgy3Lbt9y/j8Ddt42puW0+mJw+EWjKSPXNdhvGcYBrqbO5phj+8GmqYjhJWsDMt1lrSUtdzc3KGsYU2F6fWVzjtq1dy6hmW68PTjHyFcaLznDy9fKKnQdh3DbocLC8s8c/rpR7GrWs88zzReszsehL69fW2L0ex3DVa3pBiZJuHste2OnCNRTWQnSJm4RNISuJ4nai7ENnJ/e0+4zqBlBPPph8/kKMZZsdAKDigvK9c0s+8a+rbF+IHd4Y7huGMpFdM46hwxuaA7z+7ugMkwPjxyfH/Pej7z5ZcvKKW51Z7LPAGyr+z3Hd3NAdodtjnStTcs41earTTtapKXc9ei7YXOaV4fvvDw+MzxeOTdu3d463BtS5LNCr219MOAB2qIqBAwRTiFTavR1tB5qNpiuo5dv+P6/MrheGAthcM2elrHkfkyQtF0hz0VaA9H/uZ/uuPLn/7A9PrKHmEKHo9HDkNHTIV3P/wNU+n5449fePnyBy5T4OXpifO08IMx7LTBWU/NhVYrynLhn/7Lf+Pl4QtOZYzzeK9RZeVynailkGMljQud9ZQQWaeVNQvL7nI9sxv2gqPaHEopFVIuKOvle2ocuUhhttQityKlpFyNFIiVVmjXoLSVQ9rYbSclHaWqZnwj5dUYpXtGRQDRMW7wV0tWmmUrdCulCCmgjOB/UlzJpUH2SIIa0qpScpBxnJJkHuYNFis3OKWF9CEa+fpdEVLR1KJIUaYuuRQRHVZQW+crZtF+KCXTmFgSYgDTWG1RqtB1mqrF0HA+PXN+XnHW0rU7jFXUVDm9PJFyxjUN1Rhs35EKzGvCNB37/Z40iYDzOl7EPGwUN8MRay3j5cx5nGh3t/z9r36DaRq+PT1KuKNp+en3/8wf/vB7GcPGSKMyy3gm5UiwisHeocIVFReKdhzffWT/4TNpWmltQ12vLK/fyHFEW8c4zvz4xz9igM5psoJxmVlLxbQe1TR0N3fcHA4Yo1imiWmchK6iwGnF+fT6Fz3n/6oPqf3dez58+kRjLafLM0NrGFdPLJm+77hrPYf9gRQCl/MZTWHXtdLsLwVKpjGVx/FC43oEKpP56Y+/ZzyN/PZ3v8IgptH2NmLbTnArJXE9jdSaKFozroFv357I5YTSmn/3t79j2A2gC5fnk6TMbg90xqIC0sEp4ts5v75wOj3Qti3t4Y7d7Xusa0gx0rUd7t095yoBheP+lutlZblc8E3D3c2tSBbHiX/8b/8VVJXY6fEo1PZxZF0i364zlz/8zMf7gbu7W0GUWHlI5LVy3O3IKZNSYp4WlNGsVajTFYUymnmc8V3c7JyWYT+wu7/h27cnfK64UtGD5jpOQs2ulf1uoKq6WVyLoJhUYfCNAD11Zdc15Gqg9ex1Q6d7TOMJYaRvPBZ4/eVHxuuV1krs15XK18dnVjI1Z24/fMBZz7oGmv2Bm+MNtbWwjBzagVV5fD9QbCVdy/fi6cvjk8BwhxuUn0mzwu12pMWgmwbt3aZET5SN+j00LV4JRSEshVAzbdeylsSPv/zIu3d3OGOpGnkRaXrc3lDSSogr5tBys+uZnp9wVt5Yw6aC8E4zzSM1Vz7+8Bv29/ecpwVNpTUabzJ3twO/+7v/wD/8n/7PTI/fqM9f8E5j6it
LWDHXE+c18PTTH0ml8P7TZx6+/MTrw8+0RmHbns8/fORwOOJ8y+l04cuPf+L58YnzrMFWfvd3nzncHvnpxweWy0jbtAIRTZmqLKfLFd/vyEbR7AdU03ANK13ToLyDLISKjMbaBuM8Maw0vpHROnZj6Ql5PJeCSpnqK+u6YI1iXVcxHhtJ4hol/L83M7CiklLYKPkbgLekLWlnhR9YCymu1FJIKQg/Mhd8s0FotZTQU0obMkxtXbgN6KrVtuuq3029BUQsqDW2OnKVQ8k1HqUrqkhUJJSK1QbftKQUWOYL82XGGNCmox921FJkB10EZVWAYyP739fLmX635/DhE/2wp6wz364vhLDihw6nFJfrxMPTEx/ff+RwODAtCyWP5NDQNWDXkfk6cW068nLh3VEgxK4mGgW2azldVl4ev/Dy8JUQVu7vbhhjZHCW4/tP5DXyv/7HH8nTC7/ZHwQCXDL9rke9f08IiaAtc4gMreamdWinuT0e+PTh1+QkjEt8xWJwVuOHPYfdAD/96S96zv9VH1Kua/nxx99TYuDTx3fUkujbjmm6klJgv9/z7dsv6JoxpXB3d09MgXVdOF8v7PcH3n+4Z9kXIgIcPT08chg6Ptzf0+8Gri/fMCWyfPsJZVv8MFB13WbSjoTjch4JsWBM4eP7D4yXE9/+9MyP//IHpssVbSz//j/8Bz789u9o+z2qb+kbT1wDy5pJ+ZFliVy/fWO/v2EaR7yzrOtInC80FlTvWMcTKmdsLYTzhafHR+z29jperigDf/ubH7j5239P5y2//OmPUCquaXk5XyjTib7fSxJKyVz60A88Pr1wupylVNrIHqgaRwgrNzc3rCnxfn/DdVrIVR7av3YePQcG3aBbjb3puTEtBemprRpSmFmvZ3yFXBX7/Q5qwWvFdHrFtz3Fe+bnC+vLhL19h7+VZXycV6bLhVoK1+sVlTNNyvz8409YZSlOs9vtySnT9z2+aVnDyrqsvI4Xmrzj63//R9bDHUG33P8mE9KFyzKjtcErxXK9EJrE67SSc2C/2/H+3TteXx5prKdpPSkEdFX0jWeeRqbzK3G+0nYD3jv6/UCm0riOj58/47ynrJFxjeyOHecVhttPKGAohV3fUMJE63vC9czRW1CWaVnp+h61TBhraNuO9nDHTdsT14XL4zdCXjg0lfzwI013kLf0336GsEKJkBYII533HA8DY0jsbw78T//T/0gcf8uf/vB7LtPK3YdP3H/8AeVbTueJH/7umZ//9Ce+nVb6vePf/fvf0DWGf/j7v+fp6yvXyytffoF9p8jrhbCMdI3DqoL3QjqwxtF0PRXRuaMVtu3wWmGdTA2MaHK3/pVBOy3jtxi4XC4oo6mlkmshrqsIEauUq7X1GGXk1gSEdZaXTOdEEFlBay+je7XdxjbQbbNp4pckXZ5cFCEnum0XXFP+c/pPQarCCLRab6K/yhKDAHytA7Q4D7WMCGNcyTVvVI+CMwqtLSkrfNOQy5VluuJtixuOxCJE9Mv0jK6F6+uVvmuIOZNfLygzsabIOK28y5Zwnnn59gvh+rppawxRV/phYH+4oeTM6XRGKbk9/vQvj1wuV1rXsusG5kluSD+8f8e0ytd6UVV8VXc/4EshBAHo1qbneNA8n175+eXMdL6iciKkwv/lP/7fKUqz2w18/PSJnDKqVtZ55Ne/+S2ff/VrtNXy/I2RP/zLP5KiJC+NknTt+XyBaYWcpC7wF/z6qz6kGr3Stpp1ztSwMM8BaqBSsb0iTiPyQr1SUmK6CjdrWldubm5J68qPP34ldQdUv6P1nvFyodEV1TSsykDT8/jHf4F5Yg3w6e//nuINoRa6Ycfx5h03x/ekJLJFrSrjaWZcI88zmOMnbKOYUZQSucYLZckMxhMTFAzWNrw8n2g6x/NPf+Tdh0+UdcLmhen6zDxOvP/wkVVHztdXHp+fSbngtMF17TYSyey6gVwqaTyTYkNjLON0Yr2+cGgb2o/v8c7QNQ3
rOpOryOjcYYdrnCjFU2bYDfz49Rd2vqEoLeLCpmN5PaOt5sP7d4wvP1KKohluaZ2hhollvjDNgXkNMj4pldYbEQdaS4krNWfmkmi7hnGdsUqz1Eqyht436FzR68L5l69cT2f2t0da79GdZh0X6hSJrebj3/4OowxDI0K36XLFeUfbNNLbyQX328/0hzvM8RPzNFHGhRpWljWCMyJspNDtB6ZzJs4rwTiGboczhrb1lH6Q/WbMzKczpQaa+3cU37EahTl6akioVLjZ9yStyGrBs2Jch2oORD/g+gM6JiqJOUz0vqfZa3IILLVihiOqa/Ebjd15y3Bzy5LAKsfte0ecBryGRq1cH/9ZgMPjDuKKvjzju54yV7I6orzn/rin1sDNsYcOKJ9Rfkd/+wE7iFrF38FwP7J794mPFd6//4TDoNPE2n7j/u7I8tXyu3c9fthD1/PoKo9fH2nMjqEbKFVJKtG3GLel50pmaA+oHFHIyLmkTSdfCsbJNGNdZpbpKuqGdQUqYZ3QWajgqchNvOkUbNBX4fatWFUpSm5tNQmfs6BQxhBDwFlFygKDZQPHlipyyVSRgydpYehVwR0hv42AfqXpC1SM0sQ1ClFDO1Ip+IqUfstm6C2b8ddq6ScXmFLBOcfheGBeNM+nKx8+/x0379/z/PoLpmTCXSDFyLouHI4HoNK2DSknTstCXAO2a2ibWygF37U0XUvb9ljrmMaREhOqJK6nM43tGD7d4ruOECM6Zg5aXsi/Pp94ej2LEUJLqT7lwus8yXgzKrTr2PkGQiCpSnc4oo9Hbu8Wnp5fCblQs4GUoCRuTCA+/8Q/nx7I2vPr3/0d61qYl5HdsMe5lm/fHkinyMvLEx/f3TO7yjqe/qLn/F/1IaVjwLcHhtu9HEwpUeLGViuFOM+EdWXf77gsI19fnml9s/G2CglFKpE1vHLTdBBWbu7f0zQNxfeQE/1uxxIyOa7suj1qf8ux9UynZ6iF3bt3VOtkz1LkFrG4J479O/7H4w80d+9pP3/GN5brw8/sVSWNM9n1nJ4eef7l99z0Fv/hI941rNPIennCqiI+mJQ5vLtnRtPfvKMPmn5Nm7RNc+g7hnVh2Hu6vuXu0PPwh9/zT//rPzL0A23X0g4daVnQJaJ9w/mceX58kpm80szLws39PSEFqrY8PD7S+hbXOWqKLJcL0V647R3TfOH1j2dyrVyuV1p/4vbmSFpFF39dF1rfYHNiWRfGa8Ld3uI7j24aXr59o9dKrLdVc31+ATSHmz0xLJxOI7ZIoGE4tOxvd7ihZ3p55bJMRKXprKaczxTg8XyBqvHW0/YdOBgOA1YZ0u6O0u/IZUKZyP7jZw5euhtWVUqYWa4X2kbhDnuq9qK8MJoSV7S31GH3nWqNsdjuiHv/O5q+xzrPNE5M568QFo77QVT0TQem4XI6UexKsVf8sKCNYVGVthsIaaHZ9eRUcErRNS3eWnIXmc7feHl6oBnec/jwiaUUWBbq10Rer1zClWWd0Wh+efqRJQZu7+8Zp4S+fONwu4JqsbklTQvOSuPn9v1n9P4e094RlaG6A9Y4WnvAHz9y7zzGtcQ1wHKhbxznr3+UGoFveTpf6VzDD+/ec9MPzC8vuDBifIM3Hmszysr3Rysj0sSqNqMtkrorCZUil8uJhGIOK5ZKrxteH18xnaX3BnKhmoYcZny7Y80KpxN5WmislVSg06xrJI0rrbVkKqFkjHW0vmEaF7SqTOdF6iVVwg2+DUyXE13rmWdhYQ5N8z1p2DStjAq325XzDtO24mBLFaXFmTXPI943KCMj5JzTxv9M9O2OUgKJTOMa2qbHdAdWe8EbeP7pT7w8P6CdYV4jN+9+TZNhOY8Md3cMhztSjLS/vkc3PevTzzz/038ShUm/w/uGh29fcMZQYqaWSus913mk24mwcjccGVWD8p5weeXns0wRPn58T98P7NoWHa/883/5nykh8nSJ3P7wO379698xHPa8jlf0bxtiTIz
jTLGe29+Kh+zx8RHXWPJyZRxXrueVm9t7boYGg3Asj3c3lCVyeX2m61t0e4MdBs6vX3nv7/C+/Yue83/Vh9TryzMPj89gLLd3t9weDhSnZJE4jXhvMXgeT6+czxdsY4kh0biWOYHvGubxQtUWte4EjGjg9uZXBDIPzw9M88j9p1/h2p5qHN1ux/j0RJgXdt7iwoJteiqWtE4o4NOvfiCtgcsaufn4a/bNDXpd+Sl/o7+7hfMzLDMuZ26antfHR975G7p9x/n1ma/fRvrWczlfaLuWrt/Ttj3KeKJ64f0Pn9iQ0Tx9+SJmVSrreGUqmX6/5+/+h7/DO8/5csa1HuMsxrdUa3l+PfOnr9+w2vD540eOty2H/Z6Pux1PL2d805IrXOeRH3/6+v8i709iLF/Tu1z0+dp/u7poMjJzd7Wr3JQ7MJx7rlxXdwQWHniGhwgsxMgyyMIMLEsM6I2YwAAjIYSYIUtMAQkMAqELRvjYB47BprBdzW5yZ0ZGs7p/97V38EXVPT5wzy1zjq5UOkvag1yxMyNirYived/f+zwEV2LWbdfSdh2iVlT9Bt7eEuYzWENEELXC1hbn3ZM5VKMbS9KC7ALMAS01sm/IxuBOZwKZylr8PFO1HS8+/y4pBg5dcSMJIQlLpFYa16353HvfxbIsjPevsNGhIty8+w5YTcgRkQLuNBH6msvrG45v94RlZn15QX+xg5i4Oz6CyGgSMQRitGwvdghtsG1HiIJpdsi6RsVIEA2vHl4hs+CiX8PTEGZdd9Rac3PzjOP+HkdAVhVh8Zwe73m832Nsg2pWLN2Zqm2RWlFviqzSx4Q0FUqXeZJMQCmB0Yqmthz2d6iuRzUr2suXKNsTxiPL8ZbdytE1NedhKODfbsX5dMaPR7S2mH6Lrmp8iiRZghzWrmF1jeivUFnhtSUkSLYQOLRYEGnB6kw0kphr2uv30OtNoX6MHz3Nv3iqGrpnu4LTGkbUcKQiMzvPKSa0lOTo2KzKyEQmMXiPG09k57BVjbAVNpfS2V4Hqn7HZr1GakEWGTfPLGmmyYXE7+cJSWQOgXkcSATGYUBGECkxuuWJ1NChpWA4HVCi3KJWbUfOiTdvb7m+uWE4HrFPeomcEmcBMmfa1YoplvRfVVVkMvMy06JKleKpD9b0XYEouycFCQXaDLkMUfuSVPShlCURK1IK1BK8n3F+ISlD022QdUCle7Z1QIuI1CfU7MBnzFEyMZCGkXW3xkjNEgLu/gGZMufjA2RY9T3rdcdq23EezyyT5/bVK/TFDavNlqbtMXHiePspaRl5PDzwtf2eZ5c7TH/B4TRz/blnXL54iV1vmZzj8WHPPM988tkrfIi8fO99dAo0VvLeswuarue/fPk3aC/eKZWeaSb6hdPhkWwb+qoh+ETfNfgYESLw/HLLqe+YTiMfff3/AhH04B2b3RXK1qy3G2JKhOARugx8hhB5PBx4+/oNla6omuJUsm1Lu94QyZz2D2wuNnS1pRKBuIw8fvZV3t4fyEIhbMN4HtARRMyk/R4hBEYpluh5fX/LRb+l6ho+++RrTMd7nl1eIpLAPTyQnGNQFVpmVNhzOjl0VU5wSUWqzYpw2uNlJiePkuXUm1LEVAaE4vz6nsW/obY1b7/6ES+eXyEUzNlBLDp3bTVxOiFSou5bkCUckVWhUldNQxRlbkR3K77j+36A2laFKhlnxmVh9r4Qq6lZnCd6WO0un+r1+Qn7YzBtQ911dG7N4zIwzxNaGuLi6OuaKQke90eS9/Q314ScEN5jfaDu14ja4slgKiplqKu60JJTYJlOJYHZ97hxREiF0QUGWxuoNjvU4jg/3CJUZnuxRleWbAUiZqaTZ9X0yG2LXyaCd5xOE2YVaRLcff1r7O/f0BhFRDJlzfbFS+TlC0x2mDThznf40wmbNjSrZ5wHT3/1DkY6XFzQwwGWhelxjzaW9XaD0IqHu3ukL1Dj4BxNZxAqE8RMbTOVKg6kaUhPttKqkE2yYB5
OJCVoKoMRiWq7QTY1YjmzqluyyNi+wUiHmBVqiVTCQyNJeoO2HcYYRi1ZvCfNEyFHgtIIWSOMJZk1wmzJMQMJlcsga/FWCTIGqzIiOLIRLNmSdEVUilpZPqwMzAesjJAkh4eR4B3KlUOAqDIogVYdIkce7w/Mc6AyME8j8Vw2ss1uyxJSMQy4iKlXfOf3fx+m7rBZYitLIvDw6iNul5HzcWDxiXkYScnj/Uw4ntFaMi/zE84HlFToynM8nPFuQsZCYSclwsqzuJFlmTncveVwOFFVFev1qsTkRS7JPaXK0HZVlR6nc0VuKA3miYTuQ6ARhV6R3IJbMi44jK2QUkAWLNOMtoLgFqSGpB3TvGerBcfhyKpb0dRNIa7ExDp6wv4t8+mAsDWmXxNcZBJvoNqhTc20BFa7CzCG5XDg9PhAdI66qTgc9pznkfV2g5KKVbfm4eFTmrbm+KAxukaHkbs3rzEiEX3x1h3PZ2R/wfV3fA9GWc7DA6/ffIKbZzSCrmsZzyea7QYtMo9vPmXAkRF0u2uW4Ux79YybZy9p6ppPP/4K9w93qJjxukcYzcPrV099OoG+kOQYkSo9iTP/fz++rTeptl+xWvVPJshADIHTwwOr7Y6IRGnFut8ydyPW1uiqJsbEeZyx6y2eTLu55Pl7H7JerzkfHvj0a4/cvn7FeZj54MMvUHct03ymMgo3HIneIZpVmWdYbwgIzvcHTrdHhv3jN3UMUUpUU3OYHFIm8jIhvOP2k7dcXl1zPj4S/IK1hqt1T8NMWuB6t2EcB3IKzOPC6eGBrl+TleSrX/+Ir/3Wb5Hzh1zdXLO+ukQbi1IWUxvMukFHD2R2mzXOOc4po6REpoyqS6LPhwBKMjtYlhk/njk/AVM/+OBDmqYpA4laMS0jzpUJ+YJ/yezfHnm8v39KnUElBZFA1Vn6rqFa14hWswwjUkgqW2G7FlHXCCTRJ5TSdJsdbp5L01ZJovOc7h5o+57aahYhsE1Du17hxolw94g5vUUnT2sA1TMRYB6ooynkt5zJBtz9UBJITcPli/fZbLbE6URUis3lJX6Zqaqe3dU72N1zVHtBw8zd1z7lzVd/Eykyy/HI8R1Ls7tgs15zfPMxj7evWCqNypmqaVACxuFE2zbk62vu9yfCPNO3a/JTD67d7sj9Gq0V3i9loNwWOy3LjJSCKgVUTqQw43yku9xgakk8ndi/2nP5/udwORHnkdZKcjZMAczqgm57Awnc8Z7xfCoA1DiRjhO56qm2K5I0LCFjhCgCyZSwOaKf+GkCQFhECsgQETEQoiejkUYjk8bkCqV73LIQZaR63tClTH3c0xhN8I6+6ZibHucWrrr3CMvIMB5ZfKIWNUoJ2n5DbSxBWxppEVim/Ql9oTlHT7NI6hTpUiQOZxaXUOsrttfvMp0P7L/2FT766Lb4zsLCZrMqN7xKczoOZeA2OC5WZa5vmiekLiDVnGFaHNJYstIEUSovzi1lnioLsi+LaEwJIRTLspDzmbptn8jniePxiBACTSCFXEqkqXD/tFIs80QMCkkmJEfODrLlP/z6b1H1Dc9fFuuCH2f8NHHVNPRa8vWvvUbICvQ93W4N24jwifN54Xg+0WxWZCUQLjCfT+y2W1wM3B5P9Js1YlrQWTAND+zPR77y5jU377zLul+TXVHLn5cFaSy77SWnceQ7vvh99LtnzPsHfvU3fol+1XBxsUFhOE2O5+9/wMU773GxfcbxeGJ8fM1u3TGdD3RWcF3Bp7/5H2nXW07HA9Mwcb2+pL35LjZtx7OXB9xy5nw4cjiesLqsP5/7ru/9ltb5b+tNyidw3lEpxTicnthVC2/v7rFtX8CSZJq2K5gXU+HijLGWu/tHVpcX6KYHU3N7GDCqpb16h3qOXL1sqJoG7yPWVkgiVVsRhMGhadoGkxL+8cRhfmCRmZvnN8zHI/vbB8bpzPXVDclkxNUlsRGE/YTSkUTi8XjAakWz3aHbCqkitlqVUuVcQJKEgCU
jYyBnyXsvnnOz3eKdx/vA7f0DIng2q56uqnBLKev4ceboT/SrFU3bPdlPi3mUnNmsVhyOB87DUCjQMdM0LS+eXRckzzIzLI7D6YHNqsegIBct9eID1mhU3ZUNWwmatmOYRtIyMiyBarfDtmsENeE8gUw0tkIaQ4qZuqqKx0cLVFcgnSJnjLFoWTw/UigEkvPpVER6wLPdmhgWwLO0hmQ6qsaQk0DICiEETbMgdUYlj1scF9c3rN75fAkonB9ZbTZYozkcDpwnh82JGBf0w1e4Pz4iU6C/fEbIkjlkmuhxYcEvjpw8bVMTg8Pn4l3KZM7HI4fDHqk023bFMUIQivX1cyCTlWWzuyQtA6fxgJYZQWKeF7SQpb8YPcGXUI+pLSInVlVPloZXx4/49Lf/I9Xmina1Remex7moT2q1LU3s5cQ8nPAZbL9Gq0wcPV4XoaR3M8JPiEEh6x05g5unIiNUFUIraj8S51MZuwgzYfbougOj8d4TxjMVAWUsKE2rFHGeUFWDNBpdZUKGHoHXtnD56pZBGKi2uHHm1Wcf02sLypabkzEs54G4PzPGGdk1ZFvhXcb74ol68eIKuh316grSFbuu58WzlxyOB7KEd14849//u1/m1f0Dx2Ggriu6qsIsGe9LanMKEF3RasgnkWJWkiQVQYiC/BIClManRJwWcs5U2iCFZjwPkEFXFikl53HCGsMcF5qm+2ZQKHn3FMiJpGQKj9F7pMjsthf0/48/yJvXr1B1x+ObA8NpRGvNcYkY57Hvf56mWeGE4lxpqqojjWfGMDDMZx6OD6zWK9q2HN5U03L57DnNy8Cb21tuH49sjOY8DmRruVpfsd1smE73yBTLmtPtULbGNB3jwwO//mv/M7vtBZ1WdLVhOu45H44kqbh6+T67Zs3lRfnZ+9wXvx9/fokigp/AjQz3t7x+OFL1O4SUvPrkE4bjmf7ulsvtjq5t6Dc93bZCrWvuhxO2rvj6J598S+v8t/UmVbc9TdMwTRNGSw7DiSQFVd+i6gbIdE1D1dY8POyZvafre+qqQVU1PmdynHnz5jXS1FR1g0PRbnb0tcLWHcPsSty0qnn9ta9Robn63DOESjy+/ohaQ28D26Yvt7ksMbZhWxvMusXLlrQosvMYWVFdbMkKbGNZ9xu2N+8hbIuVgaZq0UqCrlitTzx+9ikaxfrqGVErhFLstGGaFl7fvWU4HlmpjK41y6NnXAIpSWzdYKqKwTu61YYlBJaYUDGVkEFTs5YKY0xxxjzcUxtdhnaHAf+EfWnbjvk00tqK3eUFwzKTpMB2K6Rt8c7jxpFoLcZUmNCSXEa3ZcBw9CdoDUl5puNIXUVMXaP6iri4AvWUJaZLzrh5IVtDlGVORqSM9JHjqzdEN1CvN+huQ5CCrAOkmSo3LLYhb67BWtowko93TFWilgnrHflwIMYAfkCmSI6SpmkL8fu0ZzntERpAsF73WBI6JfI0k88DAoMfT7jpXMjUoiM/oXZCDOQY2D88lJ8ta1l1DYtPuBDRWtG2LZUxLFNA+Ynx9EDddHT9GmEskBjPUxEuShDBcXrzmvHNPfVuy6IM4XiG+EA4H5AiMg4jw+Rp7t/SXV4iwkyaz6iqR7bbok0XRzKCJS6EeUA5T5hPiO4RY55CA1WNaVdkqVmObwrc2FaoqkaaBhcXgmoxpsL7SGMFfp4IUhOkJftIU/XFopsTbVsjkmeJAaEEqIpue8k4DAhT8bKvWW/X+Hkh+0AYR9aVYGkE2XvUVLEYi31xjdjvuZSBTd+SheF8uCemzMtnF1xdbKhXHbYuapr9wx2DDxgELkTG/YnFJSqdUdLjlxOV0WzXK477ia63JJHo2xZCwnnH5B2mbYr/i6fbZY5YbQjLiVlKhFtASty8lDkrmfEu4rynaxqW4cwxBepVA0jEeoNEk/zEGO4xds+aSO0lLzYrPnMTpmkRsUB2/TTh88IYMy++8ztpd++Rh1tqLdBWsV4V9YWyFoFiXCKf3u2
5+fx38cV3PuT+a79FPN9z8eKK9uo5m917zMPA6fiaw+0r9JNnrqok03hG5oROE/effcRnvhy6Lncv8M5j+wbb1Nw93lErRSXLHKZsO0KMICTTcCQIS7264PLmBTIHCI7D4xHTSB72b5m6nmF3gdAV+/sj/cU1YHnx7L1vaZ3/tt6kluGEqzWmaXHB0Wwv6S5fkGMqjpTgGMaRtq7YbLcoY7m4vObu9i3D6S05xRIflSeS1jwKwTh5NrZjoUVdXmD0jIye/eMdx/2em3ffL7c33RCwVF3NuNzBdGZ6fKRqV5j1Bts0NC/fR/aX5KCZ3ER0B+T5gfDwGVsl0W5mfv0JqqpweUa//IB6u2N1uWEeJE3XcLh/JGuFTIH8dNIeFleGBdsOmSLLnIhZkW2LtDV5c0FjKz797a8g4hljNbYyxXmzzGQ30uiKrm5BKabgmIJHaMPheCbnxHq9RijDEkd0TpznBVW1aK3QUhGCI+YIViGtRuRETgZxeYHYPWddNXTb+enrHjnf37IsE/MwUJExShGWGb+MNH2HUob5tAchSjgDcG5CCYEwAmk3xJRR0wEtMhYQpsHUPSkJ3HGPqQxKlzmXvm5JKTLnjDu+fRrWtChThjKtLgbY03AsHqEo+Nznv4CyFT4qcvSwBNz5wLh/C6nM0qjKUHcdQtdMwwgxcNyfQCiafo1s15jomY+fMe5ntDbYsHCeToyHR07DhLI9bb0jJgjzgrEKnyOyLv4xkRL7wyMheGohWW82sCon/tNxz5uPvkrX9qQkefXVT7B1xXqzpWobNpc14fiA0aV0p6QmCjBNh5IzKiyEOGArhawUwZ9RgSLpk5E5LgX4KyzRD8zjhNAVfaVpLrdM5z0JSdX06K5jOJYSqs2Z4/0tIUzlhC8TOgWIkigt/bN3mIE8TZzPR7p6gzaScR5JIoEPBB9RROR4wt1GelOxWl1yns7IrqLe7QjjAMueVmtUkIjFEuaRL374ARebDV4KTsOJcTjjTie0EsV6Oy9M55EpZI7jwjgFbt67Rncdbz95w/3toZSCt5GL7Y55mEqvqK0ZssdPc0kS5sw0L2Xjmh3CKtqqws0jY19TG8V5/4gdOta7S1xQDEvk9v4jqnbNxcUVyije3L3lYrXmvXfewXQtOSmm8cR0/ynnx89wQTIebkjhFWIZECgu1mvkMJHGmcUKRNtSCYuazsjDLe2zK3xvmKOla2pwM9P+M9a7F7Td57nYbnjz6W8hs2J183k2OSGM5njYc3F5wf5hT1dbtDGlL7a6JJoKZWJ5z85bTucRbQTRjUih6V98N/v6Eff2Fl23rGrNsszcHQ4Yobg/PPD6/sALoWn6NfeHeyZ3JvgyHP6tPL6tN6lTSOw/+YyL3SVKKq6ur7F1gwROx32ZzbAG7x1LzhihaC8uaVPGPQjSMhHGESOK7dPNjqvdJavNFfXNC6IS+LtSlz7uZ5puS325wVGIzvV2g21NUXScBiYyc450XYWrK8I00pwXxrtbQpzp1j1WG04pYbsa23SYdoWxFrxEDScO+3Kql0oRXACfORwen8R2M12/IseMFIpWGUSKHNyINopWaObHM1e6ZTqcsAJOwxmTDJVo4DyicxnAy2LCnR5ZzkdQGmlrtNLopsbPC+dpot1cINYwHg6sVqoI/c4DIUSmZcaHwHazhWFknh3d7pKmafBuJIcFiSQmSUgdUV9wHvZPPLeuNLtDIsbIOEx4fyIsE1VdM48jSiliCpyGEWsMXVNTNxVSKoK0rJ/dUK2fIZoGmxI5RnSK5OlQYMPzQBKlXKO1BBQxRvLsn3pgCqs0wRTcVJYalC6aca0wErbbLaxWJVwSAuM0gVJIXRHcTKVgcR5jG5rdNWZ7Qd004CbWu0uUpJQttSEDtlvTVDuiXeP6LTo5/PGWaZpRSlBpSwoJoxSX2yt8LOlGlZ7oDJWl7nu67SXDMNL1He9cXUKOVMYWO6rRnMeReSnkj4hAy6cyVaWJOWDrlm6
1KfNojxMpeuZloRWwWvWF+WZtwRwFT44zfnDUVnJzuUWZmiQVU8zcSzBWMy0z1WaNNZqcclG5E9BohMioFFA5YxsDpkcCOUeatkFKSS1VGWzPkKVhjJHjcI9FoI3FoDApgpGc5yc00DRTNS1xGemNRK9WxJyJbYt9+W4BympFFILz/ohfHClE9o+P7B/29BcbbL/CrDydM5zGM0toeLPP6HrD3eGWdzdbpDDIumc5jtx+/IZXr96wvnnG7enIVsG2a8jBsdl2XG7XkCo++fSRd9QabxKr3TMum2fUbc96tabpV0zTwP7tW9x+QA4zFy/fIyOJw8CH3/2S3/jyb/Nrv/w/sbvs0Eg+/O7vIrdbWGvOt28xaeb+1S3d7oJqveb2049wj/fMw5m+bgiHicNwh5Ka5oOZKQZev3mLyInrqzUt4G4uAAEAAElEQVTCn9CmQQlD361YbdYoUzFNR8YUEE3NMA/oHLl9vOXF9Ybf/u3fYnSBtquxBq6unrO+esnFy/fxw5Gvf/nXuf3snrbS3NxcU3dbzOTZ9D273a6UPPsKkiDFwLcWm/g236SO54lnF1vGYaK2FYeHPU2KxfImBWRJygIfMsLB/cM9q+6SbrVhtg7d9Li6hBSMkjR1Q/CO49tXvPrkq1RNTWdqRpGwFzUMmcOrz4pBd7tBpMA4lkn5arPl+U7jfMRmzXyY8Ey8vX3NcnggRcfF9TM2F9dFabDaUK/XmH7D6XAqwIBwYJ5GdFWBNiA1C6k4enKp98cQef7iJSsfmW/fIIWiqovtNAZPnOdCGZaSftXQNi1hdhxuH3HBo0WkampMt0JbSyQzTSOVlgitWW13eOeISNpnH+Lu7xjPiViXk/oU73HzHXVXI+bE8XiHSJluvUXmxHD3KUiNrFo8lln3xGpNXl8RVIOUCt+0uMNn6OmARKGeME1WdRhrkUrhnCPGyHq9xhhNWkaiz4i6p+m3NJcvoNmRrKUSCZUTYhlxbsQYQ4oayE8g0aWo1DNkoTBVDU/BB1M3pJQxmy1ojQ+RlCIyJypry/CnL76lCoEwBts0T0QDR1NpVtstanWJE4YcHbpqEN4Rc0JXNeqpaR/DSEQz+cx0njF5Ic8L1pZmvMiAL3I9pW25RUoYzydCcLRakVXD7p0P6ZcZoQS7y0tspRAxkiZHmJfSK9QKULh5RMiEEQ3CVnityN4hpzNNVZNqg0yOVmuCc8TwRFEQkHPZ4HIqZudlmbl7AvJ2/RrGhTieSTEy+kDTdwhrwQXmYUSGSGdAK0kOC2KekU2JfYcYyKGo6rMuv6NZG7Qy6LpjmQdEctTtqiCRwsy8vwMh0E1PVbckIYp7SkAUCdVIrDA4H5BK43PEh4Qylpyhriq89Lx8/z0+9x3fge0avIRnz79QBJqPt6Ruw+ryiuM48oUQsTkwH48IP1K905LUijdLJlZb/N6xjwumsczTQjAB00l6s6JqNecx8uF3v+DN3QFhK27euaLZ9ERpUE1LLQx2GXHLiWl/S1PXWC3Z7/c8v9jQtw33syvEh8OItBpZV3SrFj/C1dUNtq1otz2ibcEv5Epw+/jA2XmQAjeOvP3sY5rNFc9/4PcTlOb28Y7x629IzqNEZrvtCplSwv50JCfBdbtjGUbceeDFqsPd3/PrX/4atl/z4Rc+T8gFMeWCZ8rQVh3vfO5Dzm8td599TGU0n729w7YtQmnuHx7QTzolmVUxM0v1v7O6/38e39abVBoXfD1S1RXKSrLOhGUmA1XbkjLMi2e/P7Hq1rz/4Uts2+O+ocxePFfbC7JM+NOJ8XBgDp4QA73VxNOR+/mO+aKDdY9VmWqakUIxHfZYrVBKUkmNHxaWcGJcJjbXz6jqGgV024762RXORSKJZCRtvUNIwe3dgcpLmrojusTxfOJ4PHI+lyTRenfBtDjS4tmuOrRRNJ1Fa6ilJLWCPDnGhyMeaDcrVC3wYWa1XrM4x6uPPia5SFgcQkW6zYr1esPq+buYbs3
Dwz2ruShFoiggXqk1AsHjp7/N/v6etdWcbz8mNVWRIm43tHXNaQ8n50DAfn+HHQcqJTF1Q0yZIQeCLeVBoS11UyGUAVkQnEJKurqUEAGW8UxMqSgQhCivoRRFWGhWTONIWBZs78nTHk3Chg5Bie8v0xE/H0HEp9mV0jMKwZeF2rZIa0nRMw7nsomYMtydpEZIhdaCSkvSNwyzwbNMI9F7pJJU0qCUor+6IZweSeMBmT0Wh5ICLRUpZqqqeVJBVKQYyTHgpoEUJTmemBaH1kU0aFVhvK23a/w0Mk4TUSaUMmTKSIGiZl4S7e6CutuSwwxhZLO7RPcVisR4t2d/d0+zKiVYGQUKSWUNOYFMZZ4mOcfh7jVLXRGmE8kvmLqjatYk4Uk5450rYwFClI0WQUARg2OOEKeJaSr4HiENdd3SNivkk5Z8c/0So8CHQEQSQ0B6z+xGUvbYqqZtViil8NEjUsYqibIaZRQ6W3ysCSHjZWKeRubxTNOuaOsNiojICb84ZCVRKrG4CaUtF1dXTMvC/eER5yKdMqXcaw27zQXaVk8MP1sCO31HpTWXFx35fOb49hUPtw88e/E+m9UVXm0Zzh+hK8v/80s/yP/9f/zBgmAKmWN03L56hSKxXreIHAjjRKMlMSRUdOg085Xf+M+k4ZGLi2eYfoOyLY02PDzsiWlhCAJ91YJd8+rhzG7T064ku/UF8/nMw2efse17VMpsNytoGka/MB2OfPQff4PL62tUY6j7HjkGNqsNu+sr3nz8ESMCv32GWCJrJZHVipsf+j7yPHD89GusaokIkdM4UaO4fdhzfPwvNE1NZTWyq2itpZKRzz76Cu88v2JwCxLBMHt2L97Fa8kyzTw+PhJTLmzIqmKePePiuH7+AmvKwHSWirapub97+y2t89/Wm5SSFKillgQSu80WkWCeZ06nEaUtbb/infWWMXpy21BvVxyPB1aXa4b9nvvHI6o2TMeBvHiu33mJr8rgnpgdy+09dvsMLw3DMiObBl01LFlwmBfcONMkQW01da2oG4XWGU3Au8DqYoerOtbK4k9HpunANC2c9nuGcWLrMnO9QI4cDjNxEcQgObx95HT/SGUq5pxIbuTmYkNnJGl4LDVxKwlBc3F1TVaaOXtijsiQ2L95S0TgcuDyvRfUXct8/0BcBlIOLNOA0Iabyx3uVDhhS8yFjq4UWQjqRvPeBx9QGc3p8FBU5AnCPPPm9hG/lNtOVdcl4t/UaEkR4aWATBl5nBCHO4LWSFOhq6ZMx+No+wYdF+ZxJD3xBGNMZTOoaoSAFMqcSpIGzNM82NMsm1SavtnQ1A0xJ5ybgIzSBWy6LDNh8UzzDEKxsi3LvODdCKkov908FehmXbQP0RfleIyBEBPaKqqmYf6GWiKX2H5MCypllLIoQPkZ/IJfPCElmrZHKllcSM4xHI+EeaKpay62Gyq7IyOYk8BUFVLJUk4MBY8jtMboUg693G0hRYSItG2PrFqCm/DnI14L6vaaqqpRuy0xZ+Z5IYRIzpF6VWPqisXNuOjZKInuGvwTr67bbEghELLANE0B8T6x3JQuGKLwhBLSdY8REP3C/ngi+KUMeK92nIeJw9tbqqoCJZFGk5QFZZC6IslCY2AekEmglSUKxTQ7jC5BFoTEx0iYTsXjhCAEj88Z260QOSGlKcPUc4mMK1Pj3cxyPhFDgl7jXMSoinefvYNPiUxi1RjE0/eyhABCosLC5bpHVkXjM0bF//Kr/5EUM6JqOboj2+fvY+qa7V1mftgTD+fCy9QTyUhevnjJu88uOO0PIEsvfPW8RqbINI4oJfn8+y/58P0XjNNElhJYePPxK+7fPjB6x3d+7xdpRIbxwHbdIZ/t8G4mLgPD0bG6vkH1n8ePnqg1jyGwsqXcurt5ByUaPrl9w7MX1yghcJPj+fVz3OQRsuPzz29o1h3YwO3bT9itd6i7z7Btx6fHR/7Tr/4GfVVTmxYh4b0P3+c8zXzt61/jdDzy/T/
wfUDm+uVLvu/3/Y8M88xqtUJpzTKPnG5fI5sa/MIyjOQMowusuh6pHF/+zV/DVi3vvHwXowyibUv4o+6+pXX+23qT6i8vOc0TvV5xtXuGNS3eO5St0UKWQbV5pm1bWlEEce5+TxU9bjwzvb2ne/EOQWe8kvSbLXa7JRLI5yKHaxqFSwvBL0zngfPiaVfQry+QtkPKwDDNxMXTr1tEipzu9mhjcB7WsgJ3X3TwWTFpy3E8QHJcdBU6Og4PA6vNmssvfIEUI2++/tuoUaFjQIhMq2q0ECzHA/n8SEqR9XbLan1JqDW6X+MRxPv7khKygm3VImNCaoO0Fi8y9eU1ea6QORDO95BnVGzLIiuAGPE+lpNmChAzPkB0ovRXJKScWKSEtkNVDUKU0lDVrUqvwi9kNyFyoJWBpDPD6JCyAiKVLqdqERzezbhlJISIrgqPTEqDi082WZGRslCpMwpTN0BJzUEZ5t4vnxH6HqV0Ke1oSwjlJhRCQEpQUjJMC+lwLOm+4Mq9QBvi4jnPAZ0kSlCAqFpTPoUiCo3RBmEiWkh00+G94/DqK5yOJ14+f0F/dUlMiel84nR/zxICVzfPiUKSMk+ai8IrNFVN1a2p6oYwjVTTRJhKA9u7GZEix/0jWYCSBiUVtZAEt2DqihBmGrliGA+Iec/dJ28RIpBWa6QySJWJ0REWjzAVuaqRVYuuW/J0wIWlbPZ1W/QSIuPCWMyz+UlxIWQpjeZMVTf4VCj23yDii1gGM4e7O94uC8/f+4AoSlk0ijLztsywSE3WFe3aIIwh1WtCCkiXcKk4m7zUCKFRWRU4bCzUBtPUqKamTZlhnMrBqVohMmhycTYlT/ARKcEj8Sj0akUQZWBUa4si4OaRxRfTdlUXzJU0irgMnO9OVF2H3OxQEtY373L3+i3D2yP26FHCsH35nEp1iNZg+wTREdzE3WefEpPiarNDpkQC+mZFW9WMPmFMS/KOeRnQUlI1DVW3Yr3Z0lSW9959jxHB7uW7SOeY7z8mLBPExMXmgmG/p22vMfUGIWZsGlE5kK1G1ivefvVT5p2i3l1x3V1ADkw2cfHeO3gjuXz+Dpefb8DPVFah2oaq7Xl49QnHu0dev35LXRtWdakm2KbGtDXLNKKWwIc3L9Dvf4DWlsl5rm7eZXfzAnE48e9/9X/i89/1ndy8fIE/PPKVr/wmlRblANo2rLYXdF1P2J/4ri9+HyIlPv3qb1EZzeqdD6jrlq7bfkvr/Lf1JvUdv/f38fbNLZUUyJwIh7eYaovd7ghK0CpF9oFxGQnnA8e7h0JgNoZV3eImx3x6xKw2cHlDc3lN1a+4/+pvkoYjeRmR08RGCZZs8BkeHu45fPwp73z+A158/nMY03D30R1+GJgqwcXlBXa9wepi9k0A48Q8T+BCYQzGDLbC+RJi6OoKIzKcD8jguelbFiORlQVd8fHHb5jGEVFFks6YpsVsruh2V2RtiULiTiesFGAESQjIAVtZEHB4vEcojaostuowWuNOe073dxw/m3DREKqWUVkwFQ0SvCO7kWm/p7NFI602O0Tdstp0dOtL3r5+xXC4Q9ct1fqCnD0ye4SXT/DJVHotaWJ+eMBWNZW4xoVE8o7D8YGmbxFGUTUViEKYFlJgnvxZKSVyTszDGakydddRqlARlSMplP7G4n1B/CwOrTXaGkJMNO0KIR1SDOTkEXWNm2O5EWaIlDKRfPWIrBvYXqKbuhiMcym5aQm1tSzzRJhOhGUmeofuW9rnzzkME8P+gJYCLzReeE7DgYzA6AZtarSpkFlikKTxBEoSvCPnQA4RHwTWVgw+FDyTFrR1h1Qa068x2pQh3GXGvf0aeT7hAyxz4rOPPqZdrUofyxhSjEQiajgi5oHctiRgWSZCa3FxQKKIYSaG4moSSpFTkQJCojTISrihqi1uLrfmrCuksVTtmo4V+XDHMM+k7PEuEiMIW7HZbFimmcPdHWmeyUJQrbagGqK15JyoU0a5mdHNOPG
ktYiRRPm9qfp1cbopjV9mjNUoQKdSMfDRFy1K3VPrGvfwSFgyxpbecDCSiAJVk1NFCo7j4Ui/WyORLEliukuGEPDHieA9m+tniKpi/WwipEy3W5ORHKYjclmYHu6J0bParllvL6is4XDao6RiHgbG44GpbdGXL1hffQBIpsdbtD8zHh/QVcPrhwOnOSBdCfiouliJbYLD4YTuWkRbU6srcBPz3Ynn1xckW/Hq0weEj3T2yLNWoKpIYzLDOOKXAXdy9Jtr2u1z0CsqYbCbLbLKuMfP2JhEfb3iVz/6Kq0ppfQX3/EhaYm404z3iWlJbN75Dp5/7gvcffnXePXr/57q8gKtNV/7L1/GNg26Kvbt5Xgi+MDq+bucTkf2h0/Z4El3icvjiuRHGh1JVc/6YsvxzWccjrcMg8bP/xdI9+3vH7l+9pyq0oyHe4bTRDUcmKcDQinG/QGRE0lmbN0SY+JwOrO6vGR9fY3MEJRkjuVUnY+PPNx+xnL3hkqBnx2n44l+e8Wzd97nxec+4JOvrIsKXoDRhuuLa2zMhOlUkkoiMy6eutmgUyaHhGk3BG1ZlgltDDpnHh/3BBIXmxVCSWJOHPaPrNqGi6sd52EoUEtluJ4Dt1/dQ99iWsv68pp6fcE56qIESIklwBISx/OEn86s+o7FaDLgQmlSpnnEVjXOVszjyOnxnvF4wCdFtb3ErHfEceYYPaumZrXeUguNFODmgYfHe65uXtAYiR8OxPEAS1n8T+OJNC/UtaUyqgzgCkNaPLoqCSVtKqYQiShCLDc2iST5SAwRqTTTcsZohbIVlZZlsUkRYRWqsujKlhve4Mh+IcmMF7F4fhJPDLVyKvdLUZtXVYWi0ACSVDSmJM5ySpiqpgMsppQf7FPfLDuSX0r5E4EUhakXYmCaZnSzot2s0VWDwjIOCyEnZNvS5QpjNCEk0LYkTJ2nEpLjMBGGgfU39BA5E5/mXnLKROcQUj2V2GC33bDerlFVXRQT08RweGRcHKlky8FYpiCeUqqUeHirmE57fM6ElAiTI8eEEpHJH8lIKmWwgtJL0xLvPSkXQHOOxbE0B490I0YZgveEGZS1WGPpLtastjUyzPhlwfuMkIaoEv6J93g6HLh985qX770PdYPWBqltITEsU5EGhhMxe4IQRASqrrFGk+aRIUZkLhgnYsC5hRAcEIkIslEIadC1QTcVs3ckAT548gIqLJAT1W6Dvb4gRwERIFGZtmDCcmQYzrhlRmSJDIl3n92QclGtqyc/3eQ9vrJYXTPHiBGGRMbnRBaKOXgAwjijxVtyBmsM7vyAjx6li0pkvdlQd2tef/wK2605LR5pBat+RS8ypm3KzbEWyFbCMvP61WuqrDAjSFWzj7B78R627YqodfYoN3G+vUXMnkzCLGeOSSJtRW8l97/xH2g2GrPbcnlxhTrN6JSpnES2HUvVEzOEbqHuKsLhns2mZXy2Y7W7oOnWnL7+dd6+PrCczrQvNdNhz+u3r2hXW55dPUNttsgws6nWDMExDxPaWF6/ueV7fu8PIOeR+e2e3HZsnr34ltb5b+tNSkhJ23eQE9rW1Kst/nwkzAutrWgkhOBKTNUaqr5naxquX7yg3e2IIdJLxeQcQmQeb19xfPsGKyWqbdDKcHV1Qwie4+ERISSNLcBTUqZSiuNhX7ApIRKmkZgjqmoxdc/x9i3j/pG2sSgFkYwxmhg8VW2LwZSE1QbvE9u2ISVP8jMyR3LwGG14+eKabaNZjo/ItmaIEE4jFy9ukMYQgud8PJaNb1lQQj71JEQpMVjNsiwENzHPM6Zu6FcrNlVNd3HNtEQ+u7tDZ8HNs2eoIDGAUpLV82e0TZEwfvL1jzk/BTtiTMRlLE13BclF/PGMSC1y05PRWGtpdIVAoo3BVFVJ00nNef+UDBqX4qWRFrvSCBGRGcbTA/OgSpIuxKLwSJHJDdimgdoyxzIQLHNJwyFKmk9IiZbFa5RjuR0EUcqG1hRLbEwJnzK2arB
Ni718iWl7gpCl3LVMzKc9/nxA5kgMlJtMDOX1yA39eoMRAl1X+PW62GxlAOfJbiGKjM8BIytsXVNJjW7bEuQgIVMCo9FPKad5ntFScHV5iX96fd14JtUaJUrJ1GhB01Tk1OKcK54lVXQlKStEFuQkCnbKtJzOJ4zV1Kuayc94kUhBkmNm8jNJCPpVVX4+/Dd03uKbenfnFrybSzLOeULdYGxF3bRkacgUZXoSmqTKIDnhzPF0Ioearm3wy8w8nNnWK+pOM57P+JjI0ZeDXIrF4ZQCSSikgBQSyXuUKqXH9GRGlsYgtSTFQPCl76lJSGXZ7i5xzpXxBlXi82EYyDlRGQNo5JPxVypR7L9kKlsTvGOZZ1ASURvm4HDBs95eII3Bn4vjjc2WoCRpnjHSsNpumMaBnDJV02NM+b1WCuI8cD47VHR0XY9PknmciD4yLYlhdqxevEA3ls2mR84L5/OZ8XCi6RLjeSC7mZwCSinmnFhSYlxGbj54n6prcdNIDInrm2v2OeLbFUZJWM64vWNwkW27I5CYjo8sDzPi1QP9yy/As47j/R3VONK1FUlnWjQ+ZU6vvkpVr1jSxNXLa1SSpByRArarjuw9fjgzOU+92fLhF78PISSzGxgeB8w8Y0Kg2z0jNg3Pbyw+K7SouHl5xcN54O74/4fgxF/9q3+Vn/3Zn+Wnfuqn+Bt/42988xftz/yZP8Mv/MIvsCwLP/IjP8Lf+lt/i5ubm2/+vY8++oif+Imf4F/8i39B3/f8+I//OD/3cz+H1r+7L6dqK6bzCTdOjPs9wc3U24bNrmM4nai3a2JsqCqDrjouL64xtsXUNVnJUuJwkf1wJgWHtYbdxRZyJKRM3bb42XHeHwlzIKfiT5HaoDKc79/QX16wP97jzyeuLzZYU06c0ziRk8foyKrTrLuWZZo4n44sc0kImgx+HDBkiBERAtPpRJ4njNGILJDSkVTpHWwurwHwKZKVZB6OzPPMsoxEN3G57umudyAti/dkBEjJ4hxTmLFC0nYN2lSF8lDXaGNoI9j1BhCsu5YwnsluphUz26ZjPxw5DxNSW0QCazXzcCpKiqanbipSCqiq9BF01TANEzl6rJLEXOjOh4dDmdHZbFmveu7GkWmZyia1LFTrDq0UOQREzCzLhM8KUTeILIhIxtFhQqapdTH+el9O5zIiEmXxchkRwBhLTJEUE1obUgxlRixnUizPSyPLLROJtDVG6DLIi6CzFb5pGI77chDSGhcDdd0gbUWlFfP+ET8tdJsdm36Lm8+clgfO41j6NzITl5GYFPVqR9OvCH4mnB7RSoDRiCTxIRIThXVoNArH4/nA4/Eef3pL36+YA9RNR1VXVIpSDpaanEFLUVxGLvLZ7RtSyuy6BheK40xXFSYsIALJ1jifIQeWsKDmhSU5xJO4L8ZyyyteqKJMd7kk/VL0TMOCm0a6pkVISRISqQ26rsjGIAfQCnIKWKPYbTf4ecbdf4J/1Nzt99i6oe37J4NuEfCZqkJai9Blodd1i34q+froy/v2ZMy1SiCmkWUaON+/IaGRtim943mmqiratqPeXpC9x/oE05kkIcqMatdorZEil4XeB8zT5+rrmloXHbzMIGLEKFX6um2HbFpUBu0jLnqqukGSaBrLPM0YY0kIslRoqzGiJaKYQmI5j5hW8OzmJZeX1/SrNZNf8MvItD8yDDOIhFWOVhtO5wFtC9SZquby6hlK14hl4eHTT9AxcjwdSFrRrNas3/8cPjiCd9TC0CnBOA8MStC+9z4qejIGdXmJF55eZZb9ibptsH3N69/+OjI6Dm/fcHX1jMFNBFoqDMIovvO7vpPbz16zqlu22y0Pd/e8fpj5tf/X/8xq3dHqiHaBh8/eoADjM40yrDYtWmaun93w8atX5GZHt734ltb5/+5N6pd/+Zf523/7b/N7fs/v+R3P/+k//af5R//oH/EP/sE/YLPZ8Cf/5J/kD//hP8y//tf/Giin2R/90R/l+fP
n/Jt/82/47LPP+GN/7I9hjOGv/JW/8rv6Gj7+rS+z6dYcHh759OsfsV333Lz/nNz2TyeiCpZMmDxhPtBe14jKlKZwiJASQsC6a5inRNIZR8V4OpWkV4iEEBFCEP2CeML2+zyjpSbMnpQTtRY0qw5pDFXb4VIix4l1a9DNFikygkj0M+fDHh9jESZuNsQsmJcZIUUhMLcNIXpS8GghSVNGWYNWEqEVy+IRQjAcDizHM1JK/DcWIlkWO59GQsooZQi+RLAVAqEt2thS6vJFTBicQjQrLi4vWMYS81U5lAFNt+Bu3/LwcGBZHHXdlpLUksh+pl5d0F09RxmF9xPJNsSYMKZGtRIRPCmHsng+zR5l7zjcvsVWNVIqmvWGTLF3Eos64/j4yDwu2Kajaiuq7gJZt0glacYJP5xYTifEPKGUYBlHpnN5LYQQBCFIWlE1EBIsekap8qM+TgtGq2LnrXQ5TWcId2/Ka2RqHm9veXz1dXAj7fUOv8ykWF6vGBN91WC7FXEZOT7ec3o88aJqUF1HcAFTNayuioBRCMFwOhGfsEGqqug2PUOcyG4BkclKoHVDVzW4aWAYJ4KbylBsiLz65DUhfML+cOTd997n+uYGn0HY0ghXSpFiLOW4UDxedVXjti/pr56hyLAsqOFM8CMuzMyLY93UVCYjswepWEJESvXUB8xYa2nblty1pJy+KR3MMRVf2zIjBIxzQftUmw1+GZHBY7TGKI13roQ+rGFOJc6/pAWjK4SFFAV1vUUZg1AKqRRSaqSUhUOZElJKrG5IT3DkkDJCKrKpSYsn5AWlZEkrrlb0mw0pC6q6wxiJH8/gXKGcJE8WGSV4AscmcgKhKnISxOHA/v5MbQ1932OyROsG0bXc3x3IIVOjSsxVSGpTsbhAiq5AqOPCPJ7AtshmRbfeIlIk+ciuroirBl1XdN2K+XhiPO6LSkcKolGsdxtOw4nJebxbMOsWoSRL8NS2RsZM0xgO+0ekgvNwJuaMQBBiRvd9GUIPHjHOJD8i65pxWuiNYVkWJudgGJHzRFgm4rzw+qtfBwXtquX8ONOuex7PB9Cavl7hvCPkTExFY5NDQBz36FrTVplf+/J/5vt/4Pt5/u4LSBsmF3nz+hVaK9JhT5wPvPULwbaI45nFe2Rw39I6/9+1SZ3PZ/7IH/kj/J2/83f4S3/pL33z+cPhwN/9u3+Xv//3/z5/4A/8AQD+3t/7e3zP93wP//bf/lt+6Id+iH/6T/8pv/7rv84/+2f/jJubG37wB3+Qv/gX/yI/8zM/w5/7c38Oa+23/HWEw4Gvf/opn71+g1KKly+u6FIiDWdEXTPPAzWSab8nCsnuvffIShBcoJIKN474+UyMgXH/iA8BIYFQLKyvvvoJTb+iamvWm75ANbc7nPc8fvKGw5tPqbu2DOm1FcpqogQhBArPdDghU8KFhaqxCAGyaahSgcxObi4uK6MwlUGaBuVq3DygZCq9lxSQlEVoCjNuKaUiJUq50+WENpq66XHOczidMMKX5JapEQhaXRqkyZiCHIoeUixlBCmYvUOSicGRoy9fi9YE07CMM7XpMFnh3IxbiiiwrQxS1/RComyDVBJra+bzwHgsyUhrZXk/s2X/9hbvPDHOvH37wDhNXFxec/PBe3R9i58n/FMgwVQ1w+BYoqCpVoh6g1Clnt/0FVoJFmakMrSmYhpHzuPpm9QFgShlrxhRykAG5xwpRTISpVQhHqREjJ7sAyJH9g8JZVsOj294+/oj0nymWk7UlcVag7EGiCwhIkTRkGtTc3FVU7UrlgBki+06tASjFTrDMkwYGWiewiCmrpFaF5eZAEyDqVraumFUhYfWdprGVoRpImXD6XxG24CxFQ+PB5I0PHv3mrrflaj7NKKqRBoH2tWapm4wKVK7mRAWVHTASGUjOSmCAyMSvVVoBOdpwYmKplsRvGcYB0JMKC1LCjIG5qWAj1dtYUy6FCAEUlgI01PqLWfqJ8u1WxaM0QynE8P
5jKl66mZN225JAmrTYlY11lQIJZnmCQCjC4sg4MtmkErJL/oAMSB5OtDVHVlV6G6NkhQUGpCFAGWhbojLhNSKZFu8lkynAZNKeVJoWdKTdcM4TaRlBhHR65plHpkrTbdbF6i0O2O7Fjdnjg9H/DKyaUuK+Bshnvl0wuSEMZIpLBi1LqU/IbCAUgLRX6GNxk8jbhowSmGF5uwn2lXDuEiuNj3OObLzLMNAWBxNW6O1LodZkTF1hQst264pqCLdkJ6CNT6VmUCVIDMh0kLdWGZtkUhaF1mWE6prqK43iPMRPS3IeSH5QGUb0ApNpLUdbgqI2hCjoOk6NtsdkpK81cYglgP/t+99l83WFOgsEiMsz9Y7qr4j5BkZCiFE+AwW3Ljw1V//T9/SOv/ftUn95E/+JD/6oz/KD//wD/+OTepXfuVX8N7zwz/8w9987otf/CLvv/8+v/RLv8QP/dAP8Uu/9Ev8wA/8wO8o//3Ij/wIP/ETP8F/+k//id/3+37ff/X5lmV5UkuXx/F4BMDWlg9efB7bN2z6lk3X4BB0TaEyJ+dxLvL27QPX775HPI4Myz3D+cx0GpjPZ3JwzM7z7OqybBw50109Z3d5w9tf/V9Yb3ckLQhJ0rQbmstr/DLzm//lI77+2S0r03P36p7tpmVzs+HFO5esthtUtUUpw3h+xBM5Tmc26x1GNchKMi8L3nmMymT5lMhzA0pXhCCJyWNkRpsKbSzL6Y60OE6HgZjhYrchGvU0dKp4OD7iQ3oaos1sdxfltBk9Uki0SKx0wmiPl5noDUobkkxoIm5O5Ej5AU2eMJ0Jw4hUBtu2JGNockeOjlNtWeaFw+EB99Vf5+r6GmzFaveMVCWO7i2KSIgSVOHyST0Rzo8IY+hunjN89oqPPv4aQcD3/MAPEmVg2d9Tr3rW19ckW7EEEE++HZkjbplLozwXpbpIiqA1uutpcgk3yLZBWoPMEp+LfM8/zT4550piTkoCGURGBI8bRhY/k497dNchrWL38prkNmVmKmeyVmXgVymm88Qc7sutwrSs12vq9aZEndNEmkZ8COSqYxIW2hv0RYe9fonQmpAWqqvvYEyf4U9vMf5MXBa86zEZhCoHBqsVendJt95RnQc2OQBQW1O8XlWD1TXJj/TW4LxnyYFKJvzpjsPtpxxeGcbDESng2ee+E3P5LtJ0dM2RdHpNPo/4GFFtTd9tkHVDmgP6vGBTJrozLpyQSLSHPB45u4GqNlht8T4TRZmFyjmjvUPMZ7I0RB9JGVTdIqRCaOi2a2xV4/03btgJqrr01lBP/b8FiGSlEClBygQkHoOwTSllZ00tLNKIgpfycyFmDBMKQVSSvF2oTIVIghwyImZq24KUSKWpbJnDkrqUwYUCu32/yDCnASsyJgW8n4nLgkJT1YbaJ7wyxFhCGTFnlMooIVDWkFEoHwnLCELRbzboqiZFj5gH4ujJy4wIZ3zINNWaWklScKjoEQLauiXbInqsTE+WkrrfkBFPZeoKazvu3rwm+4WQD8TpTJ4X9O0nDA5Gt7DbbXDzTNWtqFc71rsLsjEsd6+R08h2s+L6xQ1vb18zZY+bPUJKRBasnoaz7bph9o7BeU5LZPP57yKGQOMGHr72W2Rd8cH3fg+23/L48IAVEpslsztSCYmYM6puEBc1Wq/59OufsN3sWG2uvqX95ne9Sf3CL/wCv/qrv8ov//Iv/1cfe/36Ndbact38Xz1ubm54/fr1N/+f//UG9Y2Pf+Nj/63Hz/3cz/Hn//yf/6+et+sV25sbsDUieuq6QvoJckRmSZaSLDOPhz1vH/Z88vHHBFJRfp/OKOBys8G0LaNzbC43XN3cUF88w8+Bq/feR3ctsjIYrQvU8wl187nv/h6WJfL1L/8mr71HPirUVx3/w/d9B//D7/le8AfctKCbBk1EOUFtDG5xDI8nzueBqjH0m56QIIwLAWjbFX4aOL+9pe8b6s2W6CPL7GjqinZd4sJKS0IKWGMgZ2x
VYyqFUBofHFqZQjlPnnmZkcDiHGkGoS1VVRFSJuSAaVfl5Ck1bhpYTgfcXLTaUZbGchaZJTrqtsKwAmnJy4LQliUkghsQ4gGBKEOyk0cYgUaQlca0Fb1coa2hkwKjYNysadoW70bG4QwhsbYFU2StRRoNOZGCw8vAvIyoKJDJ484n4nhgpUQpE2n91GtKSB/xoZSAokwYY5FP/U4hYZlHrLFoI4kKbF+DLwPcyQeWxYGU2K5Bq/K9ZzLBlxLMMi/oDNpU1O0KW3cs88jx/g5JYlkmbGXxwZMECKswmx7ZN09KiBrZ1ly1HafXlulwXwgcWaBE6aXVVcEihZgw2rDTm6fogEDKjJsn4umOFCdSzKAtxlqMrZhCINuKWjUYW6OqnsXNTMqQbYU1NUJ6zksknkf6ui43leGMkpIcA1LDNE8kERnGEZES1ih8DITF43xiu9Hl504ZtDGEZWGYPdIXooS2NUvwxJTZbC/wQZJlg9I15IXk5tJrDIm4nPHzUKDPUoBQCAc5ebTRJKVJqgLTEVUmBVfKWykXxmWWmKqDrCAXD5lte7SUSB3xPpABU1UgSnAmpYQQAT+eiW5ARo9IDVoJrAgIIYkYUr1CVIL0tFnlFMpMkNFI05XxDhnhaUg8RJB4QpZoa5HKkAWEmBB+InpHCp5pcUWVcx4gOAQZKRXLMOHjgbppUUqBkMzOMd/eUjctu2tNFArbrdheZxSJ4Eac80ipWeaJF3VVWgEp8PA4EkTL7e2Bx/sD0o/YNDH6wOGxuNCM0VS2BxGx1jBPI5kybD2HofSwU+T8+MDzrid6z/TkoKu94HD7lpWusF3F5cWOnGBe7mmbGjGOmCBZDkdMrfnOD97j/uGe/eH+W9pzfleb1Mcff8xP/dRP8Yu/+IvU9bfmp/8/4/GzP/uz/PRP//Q3/3w8HnnvvfdobEdrWpqrhvl8xkhBVpLgI1lIzvNCUzdcvfs+h9dv6LuWZtUxO8cyTmxWK6ytcBkWvxTjpfe0ooQTNjclpq6MRj5ds0Wcse2KDz/8HCJ6OgvnYWB0HpvLgnJ43LO5vKDedCwx4U5nWttCjNw93DEdp9LnUoF4jhhjCT4QyTycR3LKyEpjrC323Cwxqw3tuqcVCucCuAWdFpq6Ivjw1KsyCKlprKWqC59OeA0iEsJCTgl3DrQrTa4kIXiMUsTJ0W47ms2GY07F6vqU1kopcPf2MxA8yfpKH8dUlACB0oDAL46jf1titpTFQNky3BvdgvcTPntyAq1rNhcXXD97hiIxjEemeWLTrhBKE0Ms/QKZUCkisieERIgT3hV7cjwfUdHhphmUIktBzAk/OGSCnBxVXYFUyG5VGvw54+JSTtrzVER+bUWQko1eM50HUiwoJZ8TSQg8grqqESIXZ5AUNHVV2JBS0/QbYo4Mh0fCMmHqDqEq4lNCrqoMQiTk+YjQNabrIXqiD2XDqTvUNJKiL4m6sGBkppKKREQryi0kZ7xPpJxBgsywLCPJLU+LqaTqV/Rtz2q1JoZIzBKkYppH4umIqSqSd8RwJrkzORa6/TRNtHmNShPRL5y9I8ViYvbzQhhnhEyQNUsM+Fgh1BZnWux6jRACGTw4TxSi9JOEJAmwdU1X1VR1jVsmQp6YZI29eIbUNcl7/P6WPJ8JbiCR0e0KIQ0qS0KKJB9KSEYIcq4JSZDHM0sccNGR3FI2SSHhKUCgbEtVr1DiKSSjyvxV1qaEQ7IDUgEZB0f0ES1KiTil0nNFSqK1qKahbfoyI5cfCjHDViCeNouqRlL6UkmUVOyyP+JTRipN9J66qsrQsVvIMTylTPvSExWSFD0hRGpt0TIRoyvhkKYu/WVrCCEW1NRUVClClAMaGULKyKqhXW/p3YCKjmEciFmwur7AdDsuYgC/MDy8IZwXNt2WLDT7xwOH45HVdkvTrOiFZJwctlkRSPgYODzscd6xu9Cc7m4ZjkeiG9CijPhEEWlFZDwdEOcjrWmobUXIidvzgf6i5/K
7v8A4LCz7E/50wj3uv6X1/3e1Sf3Kr/wKt7e3/P7f//u/+VyMkX/1r/4Vf/Nv/k3+yT/5Jzjn2O/3v+M29ebNG54/fw7A8+fP+Xf/7t/9jn/3zZs33/zYf+tRVVXBrfxvHqurHdWqYFxkZUr0eh453t3TNj2rdktV1TT1mpurS5q2IcuinF7cgtamIE/6DV/+z7+Ov7tHPx5Yzp56tabSujiPvEMajVvc0y+5xxrLqoIvvP+MZZlp+5KOkzmgO8toLSFltCqE7SAUPkRs02KeTiLEBSUEMmbSEomTw2tNd3OJ0msIiWUuTW1jFYfjGVV3DNOElWXyPnhPjqncmkKhBEipGMaRZZ5RQtA3FZvNink846aBaZzQ3RahJMPpSPSRcTxRn9YE58DNQCSp0kQfhzNWW5qqRWVDThPLsuBDKD2WtkNKQEDTb6hsWQwXv+B9IC4Lh9tbpmVmvbui7kqMGDL9dotyAV2N4BbOw0Db9wSfEDoiKUioaZmRShDjwng+EKeBWkpyyihdflkNmvP5gJ9m2r5+SqOlQp7QBqWeQiYUX5V3klqDap5I+fOMEaLM1sXyuqcMh/mMfoLEppQwUuH8Qs4C5xxGUugGbYtLgiQkdd1gtC4L0zyyf7inGc5stjuCm4kx4YUonEECKZVwi19OJFKZGUqehMCjCGiENPDUbxG2wmhLXZXww/l8wrkyk1JRI5MAA6IyNLJm3N9x/OhrmPqOkIr3qTWaqVsxD2eaXErH/nigaSqmHIkxIzLI6PHecTp5plwh6w11dYG5fI5cNcTxTJjviQmEMU9NdofJgro1gGCZZxprCVHghoGYFbopqpJlOJGWM2464VNEzg5pC2orxRkRQ5Egxj2qeixYr/0bhDsioi9SQV0ORcJa5tWGutsgVgsISRQSaRuUMYRYSowy+SJ2zIKMJtkNylpCDkQ3kv2EyI7sj4iQUH4huYkcHUlKkjQkaTFpQcSZeTzjxwPmSQjaNRXjNCHChBARVCpD6qaoU4RSRcCoShLQNJtvjn1I41HLjCwNchBQ2Yralvk1dz4grUVpW3rEy4KfJ+qmJg8JN5xoK8uqaQhzAD8THj5mu91iO0OvLrmXGaM04zDStpZpUUgt6bY92lSoCEcXUbZht93STiMJz3F/QAhYrVpSKuM0BsmcBPP+TJoLOu7N3Wu0VExhocqKZe8YxZmYAs3zG64ud4jb/3bl7H/7+F1tUn/wD/5Bfu3Xfu13PPfH//gf54tf/CI/8zM/w3vvvYcxhn/+z/85P/ZjPwbAl7/8ZT766CO+9KUvAfClL32Jv/yX/zK3t7c8e/YMgF/8xV9kvV7zvd/7remEv/HIh2PBAaVMCE+m01VL069IuVAC2qbiPI1IJfDBY6uK/qnxmwCywLlAv9oyHR4AOD8ekMLycNqTvEcmqNoa07WQEtP9J6jokTk9Jbha2s0FmAqtIfuJZZ7JItF0Fbnvi7IhR/q6ZnYLUo50usJKxeQSslkTfEbYirpvyXFB64ibA7KuWJaJx7t7ds80xtYsfmTYH/gGR3h3cYFWosya5IL20f0KQQkHOB9IQiCF4nQ64+0jVpoimjMG7xfm4wFbV5i2RhrFOE244OnaFm1bZL3BoTju35CDRytJ3fX02y3TshCcB9uSbYVzHh9BNj2iWaMf9pzevCX6zPoyUTU185io2hXd7gbdzIx3n+KngcN+zzTPKG1AHdCmhCZAEoNDkpDWsur7Ur4BwuyQQhCcQ0iJTyBjCZeM00zMYG0FKbPEQr0PMbK/f8Q2FW2zRuSACyXqLHJCZk3yM8swIFcb1pstWtsSzFCSlBVSl9da2g3TWcEyoSuLtlUZPHWOqrKESnKYT4xvTvjphK0qqrpD5cRy3JNCRCrJeX8HyRdMlfcgNWq1pb56gbGW+DQUGkJEEJA5weKQIaCEJvgZmUuUXYaMMYZV15Lahq987StIoVld7Ghag9Y17J6h+x1KG6bjLcP+TLMsOBJJGfp
mxVxVyKxRwiOWzDKdGU/3+LGjbjU5l3J0TpJqfYE/3JNzOVwKIcutUUg8muhmhsMDIn9C0zQkYArFXBtS4fXhy8xWHPblfY+JlMD7CAKCn3DjGZk9koxzMzEJbNORYuJ4f0u/uyJeDAhpkLYttAuRySlD9GQ/k59KkVLrQlRHoJWBuiGTETkAmeBHwn4h5RIkEU9RfJUzKYxEL1jGAZVz8YEZS9P11F0HiCdvUknbCiGQstx+pCwhj5JmfLqB5lTwZMFzPh0RxlJ1PVJIBLEkf92EyAmlNVZLtLA0SqCNLl+3Usw+okSi0hqi57w/MCwzseuo2p7VaocPjjpnxCz44IPPgbYsWbLd7ri4ecniIk3TYJRgGo+Mj/e4caRtKkxtGAeHrVosmsbUZCAYQ7vpYbOl3axZhj1hnIi+hF90djSmY3GR7PK3tM7/rjap1WrF93//9/+O57qu4/Ly8pvP/4k/8Sf46Z/+aS4uLliv1/ypP/Wn+NKXvsQP/dAPAfCH/tAf4nu/93v5o3/0j/LX/tpf4/Xr1/zZP/tn+cmf/Mn/5m3pf++RUkYoQSYjrUJhyWQqa0os+fiAP8sy06BUITFb+5SaGem6juAi0zKwantU8ty/fcvjw2fs/8OX+ejNK4zSPNvuePbOCz7//d/HerdlOg9EN2ErU8Cq1Yr+5kNUUzPs3zIfjoz3D7gQSdeJqltRdT3jcCJnUZTS0ROGM0tK6O0VzdU1saqpVEV8OPLqqx8T40x/dYUiFWJChjA5qroheklla2QuP9A5R6RQjMMRVXVsNhvK+JVDSonznvNwxoRI161o+p5lnIkpgTb0/QotdFkERELoTIt4YqAZTL/C7K6LiuS0L8mkpkLXHVnXaGlRNpGkBVXhsmTxgWbbo7st9v4R/1u/zfHhLeO0p2l7rm7eRdcdUVpiXkgpYbXmcBpomhZtDdMyE2MoPQ4hUNbSXFyQvKNv+6KgR7CcS4Q9OA+q3GSa1YrKVpyOR5bFIYREa0kIgfA0Q6aywApDJSVBwHkaSurPR6ZhZJkLFNRai9VXNF1PyqIkITOQAtHHspi0PT4WWnpIhaChDQihqKoWkiDOI2lZyo0uRt7ePzAdjqxWG7pVj48QsyiopOOIloZVe0XQDUlIQgy0laRSArk4luMtD3flFtPsLlHGEnQopaMgGaeFM0CKXFzuitDQFiPzlDVme4moNSFm6l0smo39PbkyCC1YnMPunqGVxAwjw6vPWA53rAzowTKLM0Jq6romIlhUKQN77xFQwgRaE1NkCXu8n8DGEnnXhbYhskfriqgM5EgKDvxC9DPTEssGVFkikeAXknMIbYjJsPiFGEtFJ4wDxIRSmhMCaTRt01MpTfQzUkmM0sTgWOYJKQRCCLSA6CZ8GEvoRmswFVnVRSnvHdkHcgSkghhJy4AmM4WE0AahNE3doZuGKZZIu1KSGCNalpItupT+jC7qHQm4WOLwImXmcYQcMSLhxjOPj/esr27olEIIEBRCSU6ZnMtMqtQabSoQhdKyKEG/ucRNI8RIJRIhKVYvXxbklQCvFVXfIYcJCdRVw+ICStest1tkVRGAqusIKSGBkBLTcWDYn2jalmmasLZhOo2k5xvqpicczqSYiS5xfXXN+e09cTijlWbV7xjnufTcvCOkSNs139I6/386ceKv//W/jpSSH/uxH/sdw7zfeCil+If/8B/yEz/xE3zpS1+i6zp+/Md/nL/wF/7C7/pzbb/w/TRNjV5mci70CBcnwnQmjSUe63xkHEdSnBBZkuU1SSmyVgitCBJIZUd3LhSCgIiEOEPSeBQvvvO7+Z7f/3u5vHlG9uXNlFoiheP40W8z3b0pg8ExcTqUBTy6mbptGc5nUo4kNyBSRiMYjyemuzvm4xlnDDfPPqCqrqiaCq0FUlm6uOCHQnZgGPHzTNc3zPOR4c1Ev7vEdFu0CGwqjapajocjbnZFy0Gie8I0IRJSw2bbYWSNljVECbZm7CMxeHx
Q6EYTloUcPOfTCZEVum7xKXE6nehNX5IHRtKYFU1lkapo5yOCGDL9ymCaisq1zOcD85tPUeo1aX9H3zakaBAKUDXt9Qui0qTxRDw+4IYBnT2qqpCbS3QMNMvMHBaSBKSlblpSFtSbC2qlENHhgkO1Ddl5UIq6rTFNgzCa+DSHlHP5L8ZI9IGsDPVmS0IwuMA8zvhpxo0L0zCShSpSSqGJOReSQI4kP6KkKJbWOeJdQhvFjEcqyC6RpUQoyEJS1XVJFy4zOSRQFJpJzhzvHnn75h6lJVVM6CwRFy9QbYvKmfz6U6oU0UaShj1JaeZxAFfRPEXtQ9WRa09aHHWzQiSBTSBCwLmJcTwXP5WxdJfPiEKjnoaXfUzIp9cnakGIWxZ1LHK9TU/VttglEDI0tianzHq7prOWxgjydICcEKsrxGaL8iNy/5qUEqRIcBOJTN10rFbXjMcjwSXarsf5CDkivCdmxXh6JMxHhKAwHJUmmw5r5VMgxhMnR3RlAy6mZRBE/DSUgXClqGyFlGUgezqcsLYtc2LBwZyRxiJSQjzF5ZWtCSmRY8AvI8vBE6NDG4VpWup+B6rDzUf8OOKmE1VlqbuOWVq8n7BCoOuWICRpCagYQBa4MFIjnmg0ImiEkJBD8WsBIiaSUvh5wC9PBH6jqfqem9WKuu2ehtUVOWdMdcnqUkFMLPNEyp6YHO16TVaq9Fu9L7019zSbpwv/MPilpIFx5HGksjXNqofk2K4spm5xyREXx+wcUyol+ZACbjgixML7X3inDFyriiVqounxosOfHdPxxLarGKeR5bNPWO5uGX3k4sULBncmupnOGGCBHEol61t4/B/epP7lv/yXv+PPdV3z8z//8/z8z//8/9e/88EHH/CP//E//j/6qRFhQXpBRSA+TVkzz8ynA+PpyOpih9l0nB8fkUNCxIxbZlZXVyy6YoyCaCpShkprTD2Rg+edruHqZeL9KfLq09e0wpHuXjMLRXf5ku1Vi2fm/PCWecnIrNi/+Yy3d/e0bcOzywum4emGst7ip5G3d7cEv7DdbUEr1HpH9ol+vaZpGnJYyI8DIQS0Fly3DSOR/eGBrAzbmw06R8x0ZhjL3I1AIYRCaYNzpfHa9StyKCmkUHXUdUX2E11tqDsFaI6HicfbR7QAaw2T96QYcfOCczNKQoiBpDK1bQrJ4TxyvP0M8XQ769sWZYEYyE6UVFaKzPvMeFAEt7CMJ/yylMV61XL1znNS8JAjtl4hcyY/hQicX5BWk4REeIEexzLsaQRqTqiUsdpgpSKljAyOaVjwKSNs4cF5l5BVS9TlpuC9K9/X03CrNj0hLozjRN0IUoooZTmdB47nfTkxx9Kzquqatl8RZQ/e0RpFlTNpdshmhW3XZOtJxFKKmlwhYYczwshCMFei/ILFQJpOLKOj2WwxtuXxbs+nn7zGu4GqrXl89IQUaC4usBJIArtaoZ/SbVKU/ldyDhcD2Ttqo7BGY55dkFMROvplxsiM8u6byboQIy47jFhQVuDbDcJYcogsWeLOMwKJqTJZSEy3pr58jq4b8jiz3H6MI1IZxcVmi7cDtdZUlQEC+BNqFvhlJowj0lis1VR1hVCWSimMkkgK+Tw/9QmjLyGCmDLD+QRxwWiNNobKViVo8jRvdT6VmHl8KseqLEq0PEuitggjkJRBbnJG58SylE3FGo0gkKaMA1LwuODQtli8l5gYhxPJlznBcTiiBPTrDVJodCuIGeakEHbFEjzWJVI4lQOoLMPhSZbWQ4656DximWHMBpSQuORgCsx+LvOItpQZRdI4N+O9R0mJ4Ol9tVWhiiAQUqGUxpiCwSInpMz4JeP9ghtPpUwtBDklhPfInPHB4x2lvxc8MXoQxXMmlCUpSRgXhFuKHy9G3FKqJ97NGGOQyaNlplYC4ReUiPjxTAqCRlck6VHKUO96VIrsHx6/OdgslGZ/f0e/WnN/d0vYrKh1wVFJIb6ldf7bmt3nzwem6cBwOjCdT6Q
Q6NuexmhOKeFj5tnFNdXmCn+4Zz4cWK06uosraizLkqAyZeAzhmL3vRUM+7d0leLi+pqmUegUON6+Ynt5RVgOpP0deToSzo8kf6ZpapS01O/ePCFewDY1HsE0njAh4g576sbghgPdzQ3d5Q6XPNO8cH54i+kWmF15A7VEaUEjwW7W1HVTknnLTN+vaVdbEpCyJIxn7u8f8AnatsPWFVLvyLZHNWsaq0iHt6jlgHtwBKVBamyjidOCHwNJZKQ2JTKGKN+DlFS2QtsKpS22qplPR1KIYMrisaSEURKtKEGQ4BmmheN55ng6EsLCuu9RSrDebqi7hrw4luOxlHXOe4KK5JzBDeTgUFVN01gIgfk8g8zUlSUJwTIvT5xRz/58hFh6YKrtMLZmnhdQiiSLHmMeB6ap6OftyiK1ILqSHBynieM4khE83N3zeNyzalt2mzVVU0pX+0MBFQs3IcYzbhgwqwv6ekvWVYkMh0hWGrvuyFKTwoYsFVELhMjMsZQxfRZMy0KTUtG5HPY8PtzTtLacgnPpT0zTyGk448aJdVvRNZZlGhHjgtKKziiMMUBh24XoUU1LkqLQ2d1MWEpkXFSW2Rd3kpSa/emE84+o5pH1ek0OieNxYJ4mKlvR7lZYEanrhro26KZlmDzZLzweHzC1oWsbtFYoI5FkjIjk5YRbZmJI6JiILJiqUBRSjrj5TEoBNz/1mlyJzDdNS4oO5xy2qZFZIeCJyi5LgpIS/xcioaUgiIT3ESsVAkkSCt1uMVqVYdwUyOn/Td6f9dhyZdma2Ld6a3bj7qchGRGZN29XKEHS//8PAtSgBAHSVeXNJpLB5pzj7rszs9XrYRpZb6p4URPI/UKAIHmc283WXHPOMb7RiMuC85UcV/LqaNudtK20Lcr36Rynlw9QE1sqxN1InOPK4/0V0xoqJbQ2DHT08MzgRlpr1OVG7RJJo5XCuILuXn7W3mlKDPItF6zZ0E6UtOv9Jr+Xkmh5pQWHn0aUakzTCENAN4TDGKP4J7VCaQFJK8V+GY94JyDnnmGNG9vjATsb1GoJLQ0hYIwjlYrRXYJhu6fULnlVSmGd2439lcf9IWuRlAha0VSVy3vwJF0w48hjXahxo6eMKh0/FFrf0DagtBTQp5cjNS8Cwb4snA8nSs4c5wEXNNsjCizg9436//vP33SR+rd//UdGK8F1aV0AiFUAq601dG2Uh6R1tvFIwHJ4eSFrzWNJTONMR0CO3758QZUNYsIoQ94yZir8h3/4B4bguL1dab0RjOLWMuv9QnCWTx8/YLwVIkTMLNtG7oXxMDIqxXq9c3/7Rnq805Lhw3efGLXCe8vnH77jL3/+kW8//cjxZeX503d8+PiRmjP3yzsGCMYytcQWE6N1qDDjDye21LhdvpCLqA2VsrRSyBWOn/5EOL2g/YhOC7evP5KuF9SWyM7iX57x08itFLbtwRgG0BrnBmwYxB9VMh5L70oiJag00/cZ+ChR67VgvCMDXy8Xtvsq47MlscaEsZLxszxWPpTO8/GJtb1zXzfZsalM7LLo77WStpWpa47fvVCNprzecChsEEq8t55aK9u6klLEG0Urie0SaSi220OSktWImg5MhyNDCBJNHwK1NbQ1DOPI9XqlpY37feHt61cwGrcTOYZxYF1XHrc76b7Q4sq7AX964ofjR5w19LrQbhfi+xvKDYRPn9HTE15/J7suGlZ1el4pOtC6YrYDTTVub79ideTzdy/YeeLz3/0duTRUV7uRsqDGQIoLP/3lR+iaDy8fZTSnNKU3fAg4b0g5sd0XnPdYBThLV4pUCpoOxkj0R5ioSvP19Sf85YK5L6SauVzvIvu2FvI7w/lIrivl58p8/kC+3oTLqGXaMHi5fTtjSNtGZccl7Vgi6wz3bd3pHtAxtL4BN4leKYlUmphsjaWmVW73MvwSZWjrlJTQFtJOR3HGkJSSnUcuJFWlsFuPGie0tizLBdUKXsCBGCvqxC1u5G0hryuqCXq
Lmsm/ZswexgmKdV1ZlzuPywVS5vL6xny7cb5eGD79ien0LMSSYNkeUvR0k0tWUCJL18pgjSJR6HklPSpLKYRxxqqZ5DJGKR6XNx7Xd87PT9ggCr9gHbU1Ykxoo9FaCrMbrIhzrN0z3zZy2aTjd5Y+DrRtIW5S2JTSohxUMDgn9hmawJi1ROFsMVMeG1hBUQ3DyDge2FJBlUapmSE4epWLpzMK4wcGc6T3zu39wpYWvOsctBT3SsQMnjAZagt07fjj00dU6ay3O6iKpuPnWUDC7q875/+mi1RJkW48p/OJfjigAHcYSdsmD3Xa+OnP/4QdZsJh5Hh4oinNuq7c377RH56UEl++vBO3xDx6JqexxwNunNFhRBshWPgwMgxHTO8wTwTzHWrbKNsGFYx1HP1AzpnH/U7ZF7MaRVWd6fmJbV1QWnP5+sp93QjzzDSfGUd4+fQd0/kJO0wYVwm1ofd47MfyTmkN50em+YAJE7pFnLd0b1FqoNRO3BasUbT7haYVdiqslzf++b/9z8TXXwjaMH945uAC3TucG6hDp2tN05aqNN55ucA6GVus24NUVsZpZJpmcql0xGUfpllC8baEMYHhqCmpkQtoY1DO8PTpO7S1DPOMs5pvy8p9eUhUw60wHmeCdbRSuT0WejUczk8wn9CHjxKXUN4xRjGdDpQoeCPnLCmt9Fy5X28sW4RSORwEyTMeTzitKOtCTkL3QFVyjbvHT5Fqxp4MBx8ovXI4HsVbtcd6jIPHdtgUrFvE+QkbJomlv79TlodQ67eCuVms7piDkRFUq9AbPS9AJUwHfCm8//ojX778gtWGp5cT4XgmTCMqCVR3GAaGYBit4nG9sN7vkoGmIddC3wuZ9Y6uHLlUSmkYralN0oQPzy8M2pDThvEjXWmM8wyHAzFnbq83mg2cnj/w9L1DOfHchLZBSyK3753l/ZsQtafA0P0eLQJqD0W8PTZaLfjpwDSLkjQ/skCNS6Wukp7rnMMYI/Efte7Jx5VbjPtlzzCPI62IuCltmyjBXKTzmwq3obXBGkfqiVSzAHwVqNbIvZCrWBY0WsQJRrM87uhNaA4lRWE71k6N4lPrXWJstDECa14jtCZepxRJv/7K4/2K//LK+cNHfBjFdmE1SntauUuRbRnE6kbXsN5eid9urLEz/93fEz79AaM8ui20+zvpl1+4/PpKzYXz8xl2dmdO0lkqK6b8k3Z4PwoouXS06uguwo1cO84HnDOE4LjfLsS4MjqPtwZdNxkVGwNG7xaKjjaO0RliLPRaUdqQYqLzQCmD9aJgLSWLxNw6eq0sy4PUGrUBxjOdB8bzCd0aefctxtudbVU8fXjBjgfSllkfG04rjNIM1qD8QFHSCf81n7/pIjWMErt9fwhB4Hg8ELzHa8W2bfz0689c7yvKev63/5v/gfnjBwpNUmNto9y/UWLl8fqVlBvn8RPTNKBMx5+PEsP9iBQN0zxT64MY37ExEr+9k+5X+qCoPWCiZzzMeGMp2nAIA34Y2XIhNXnBzLqi/Mw8TRw+vNBdYJ7hMB7Q4yBR41Gw/DZMKNWpxmOdY/SOZix6CKSyolXFO40bBzICyu10clz58v/8nxgPJ8LhmVwbcY3kYmijIRjL5XKja8P56YVxOKCckWV6hy0mai0465jPJwbTqO8LbYs4L5LaZX1QVOYYJuga3RTzOFHI6EnxfD5xu9+pND58/sjh5TOuZLb3bxTADJ5WmxQla6m2U3IlV8O3y0WycaYPmB/+nun5xLRGJqfQxtB6x4dA1oqqYBgl/XXe/WK1FmIuDE3m86XLwVFro5QkBmGjORwPrJsQ2B9JWIcdxbqspLSi6BhrGV+OKO9Z3+7MTx9w1tOXSH2/UnXHPYkJ+P71V8bLO+b4RldGYiR6lZFQcKhhRueN3jvTfMKFUeJEDOS0YrQlWIs3YKiyY0yZ0+HImhKPTUL5jtPI+XwSjqAS1VihsT0eaC3BiUZZtPGYQVF6o3aFMGErz09
PfPr4LF3jNMv3UhvBD6AdbS+urXR0byiVKcES40aOknXltKeWSleW7gfcWQJDqVFgwcmJeKKrPeFX/mwJ/OqM40SYD+KzUx1VBfRacqakREpy2vfe8OMkmWOt/e4tqrVR6MzzzOnpGaWFiXh+OqJ75XG7iA+ud1RNGCy9RuL2oCtFV4ZWOjVXUpI9WBgGtnWhxIJ1jm5kFKVro8aN+PUXvr59w3iPcoHv/vgnDuczxha8d2g6OYrAqdfKdv1GfRTseMI/v2A+fqZVjV01cVloGLY1clNXjNZoq/FeYnc6Qh4ppdJ6pdNFiFKLFOuSUTtM+DffV66VWmXPV3rBdkttkbbJaNX6QMtFCqCyGDcQXCD1TleGWjK324VpDPjpCWUsqUgo67pu0GTkWnsnl442gdIqb+8Xqja43pmNIl1u3G5XyrLx/CdPGCYZV7ZOTRv36x1/bGADTf115edvukhZNMSKLoUPH57QDrZ1g17BBJSf+e6Hj3g6Vlti2uimMwwTy31DTwE/VP7u9IltufPhPKBa5nA6U7UmXh/kZeHfvn7heBjx04wNE3Zf9K7rwnE4UbZIsw2bNdYButFVZzodsa3vceiZ4/MHxsOZp4+fGaaJ3jsxZ3AObR01lV3dJJ1Z74r5cEAphCbcqvgjWkG1glOaRy40YxmPZybrWB93vsZ/5vLjLzh/gZcz9oeP6Kczz59eGKwQlK9vb5hh4PP332HCzPLtIgOXwVLWileasm0yGumKkgutL7tKTosxNq6olsiPGylGqraoLkqpVjKHw4mQIub9K7kXHvd3Ytrww0jOiTANaDfg5iM9Z9SyoHFc75Ff//n/zun1nb//L/8R6wx2srjQ0VoCJyud8/MzJgw8a0OOK+/fvvL16zdSjLQ1sqnGstxwRhNckBA63WjOyM5giSzXO91pjv/wXxmtZf36M+nbLzjv8W7EhImhaj66genpTEGzvt9Jt1ea80xOkDfr9iA+3vn2+hdazmyPx+7ZUgzzgfFw4vD0hD3IJUVpyzgNmGAxThBIrnfYVvL1TomJ1MBPR7LeaHtEDFpTa9+7BoVVDq8UbppQXuO1SKylvTeMLggZvRdqbwzW4o0ROsv9yrasNG1hrFg/oJSlN8jbBnWjdyn+KUW0dXTricpQescdPtLCgPvwHX0cUEkTu4HWcG7EGPmzW4vEtMlFZ5pQexGYDk+o8SwG2eWduCwobZlfTjQty33vB1KM6FrR40xTr4S0oWvH2ZEhzDLicpb5dMT2Rm+d29ufZWd2OoJTgCalKkIAI4SHmrKoWVslTCPWeWpvApnuHe88DOCDYzLh9y5jGEeCbdi24sNI751lXaglQxNRSNOe7hK5bqT7xvH+ID5e2X79Cy0tBBKfPr1Aq6zLQtOSE6aNonZFwKP9kW4E7qtVF3UqGo2nW0VTv+3gqmRhYUTN6DRFdaGT1E6NK+qxCv2ETsmb0E6M2HNaEl7f6XgQMkvJ6CqJBb13Uk6SMh0G2VeiaA1Qitv1yu3Lzzjv0KcTxc9UtbFcr8R//u98+sMf5VLQukx3xoncNYPz+H8POynjFKf5zNvrG//25185Pp04fP7McrvweNz54w8/MIdAi6t4Uv78Z+x4xI8TlrpTBRrTYeL0NFDjHas9p6cT6/XGv3z5hZoSx3kmxUiMCaNvWGOppTBNM2GciKVinMeGCTeIBzOlzLIljHf4YaAgkFLj3T6q2VleSgmfi45G/U7sRinJrtEGrZSYEJHdDNpQm6JVTWsKrRWH8wk9zbjziev9wq8/v0JcJYjt6Ug3HZRh3SI1Fz4+Pe8gSAhOEZVIjafxyDiM6NrJ8UqJmzzse/XU1uKtkN5piZJkVGhHh9OW9frO8nhggLzc+frjP2F0YzycxMszBK6lys98DIwfnvDjyBgT6i+/kNJCrAutPnhcfuX9F4N7PhPsibQpUYXVtudBVabe0KXhFDxaR+XK6XhmOD+Lhykm6vaglEZXipwVbYu8f/3
K+5dvNA0//Of/wPcfPuKN5cf3b3RtRfxAx5SI6QVDx6hKVZmFuCesizHTGIt1nm3bePv2yv12hVqxSsZya67c4sr12zem8Ui1mvHTByZ/wIdZ4ilAmISlEvewP6NF8HBwB0pvxBRRXdHCAaZZuoIesX3FeIMbPLp34rLSisYdDrRSqVuit0JpldwajxR/z0u63m4YP/ESZtK2kFPdyd2gWpECUixWDxjn6caTjaWNDvMs3D4/HaBWaqr0tlM+emSYHX4eaFlGR2qaaBoeywXfM8ooxulI1zNbLzQ3oVrBDDPeOqzTGC3E8FZFYTZajT5MdC2RHi2t5FoIPhC0wVnDIXji6SSj4VJZ20pJG0opvPfEXIh5wzjNYTqitGI+jHz69MzyECis95Ku3FNi8BbvAqCx1uOCBw2tF4w70Von145xBnpH+063Gq0L98sCP/8rY91YXn9hu74xToHn80ScDNf1QUqZdF/QqhPGQUacvaMUlBgpCtpSMM6BMlRl6GhKbvK99vZ7fpVyFu00mkaNieX+YFsWjIbz0xN+GHDeoq3I3kvJhHHEmoBWjd4qa83SJRsNShiooNDzCeXEO9aKFMfD8UxTmr4j6YbDmcPhSI0SQ7MtGxUtHEIXsFqLFKZUYlz+qnP+b7pIqcfGlqFow+oH5ucP6DDD8iB4zzw40v1G2R60BrGIZ2o6nwnDwHJ7x6SIzZJ4qWridD6yfPnC8rjhaBI5XjLruqL3+Xfa2YDzWSIkxsFijcFouVm0tkvBAWMsx9OZbhzaOWyQ21pLSbKXrMMq/bv5MUXJLTJOJKA+BFoTf09H01ujtc4SE32LBG/JaeH2/hWfE84FKYRD4PJ2o94XjDbgHIdgGMeBlzEwWkuJi/hDHu84D1RFqwXjBroGsiOT6UEAotpYzG581PJaUOmE+Sgx6bnA8hDjrYYcN5alME0TR+938aDm/rjTdWf0A8fBi2eiJbKVufvp5czp82fx9swTg9XkvKGjFvGDliX9rz//zE3DGIKE49XCPE3ENZJ++YnTPFG3yLcvryg/cPzwAescj9dXLu/iydFGkbaV7esvFOdYrm/0lmilUbJFWwm+q1vi8uNfcIeHIGm0RtEgR3rcaDlyeXvl57/8yrosDINEK+jU6KyMx5Fr6WzxC2YI/Mk5jPG4F4f2jsdyJy0PXE0yHhlnehUKg3Me7wxme9BywQ8jNozUIt4s5yUnKcdKMBaF5EFZb8iPyOPyTqmVYoWU0Wsl9YbTMI8TbpjRSnN9+8bb+xVtLS/PT4zOSDjnJLvWmiOpFJIb6K7hB9ApUZaV5fYumKG0okrBDLKg7yI1YNxHxXG5otPGdn2QXm88f/8n+uEgnfounsi50BqUChbZJdW0YayS7+FwwHo5zLuCtjwYrOYQrFymDgemaYLeWbaFx132JOMwShhig+l4omvLMM0Mw8hhDIxexEfGeew44KwlXi7E65WqxWeHNjQxaNF6lWgZ76hGsrZijCitaMphxxMHNUAvbG+/UO7fCN4QnIZWUcpwOJxpaPqxUluRy592+HEGbcnbgioJoxTpfqf2RjcON53pIEKH3gijkCToHdUrcX3w2ERpq1rHOc0jVRgM4zBJkU2RDhKcqSy5ivK1KSPvvxvQJuz8SVAuoKyhlyrdemsEawiDhKxqrVDIMzMMgzAXG/gwEMaZnLPEpii55Nf4/wHixP+/fb5eLqT4Kx8+/8DL6YTKjbwunOeRw4cjtTTee6bnRDgeqTELwVojpONg6b1QqTREeNCaBLbdL+/kXYK6rLIPSDHSaiEMI34c5dabMoOx3K7vpO1ByZlxnjkcZ4z3dDS5dY7nwHQ4yBITRcuytERBzV1Am1k8MV4bSok03VCmAwbZHjfhjhmDc55tuQjhQDvifeXt51dGY7AezueZXgqoynZ95/ThhWEUFzy980gbunVu68pQDXqnfayPDR1kidv1QB2FEWatoZdEjg9y3DBaRqi5d6oVdVZOkdQ
rYZ4YvQPONO2wfmIpibjccKqjW0aZjmEjfXlQUxKkSi0M5zPHT99jhpHeOh5FTBtxeRC3d9CGcRrRHfyefJtqJlOJTajjj2XF9J/R28B6F6ntcDhjjx8xoTPrjnNA2liWO+vlys/5/yFiC9UYzgcZrTrNMJ/opXK7Pvjlf/5HunIcX14o+Y42mmkasfsLv9yve1ifZkkJohxg1lhC1kwfnjmfDnjnMd3wy7/8yPXXLwzTJB6/VjlNAT+f0Majrca2DsbSjMG4DDmRH+9QVzGzFjnYem/UXASui/49syqmK4/lgUClHE57UZppoOR9jzXQeifmhrIBN45gB6oWX9NWV4zqbPcb7/eF4kZ0mAmtSopxztzuV5QFGzyme8baCN5TrUN1I9EdzmKcQZWZvi1sy8r766+EvNBLlWemGxlDOU3t7BOLCbxBtULrHaMcYZqkqO2pvsFKKGhtDeU0ox/RxjCVmXkcuV89OWdiTHgfmKYTyk8cnz9xmGdMk467qii7psHvAZ+Zx30Ba3B+wBovVJZaEF7rgtMjTnVSq6he8cZjlYM+Mh5GtnUhbjf0PNNaZ+sKXRB6vdJ452hjwHWE7hJGcJ64rZhaaVVxOBx5LAsGJTDlIgIVlAalhG5i9S5TL6gy0GxGDQKGDqOsK/Q4QbDoXhjsQOtKio+xgomqim150HXHmhlvR7oyxLhRHu8Eq1A10fdgy+o8YVcQohqtFMlyaxq0o7SC6praBDuFUvRhwqDw9v9HxIn/b37G5xf09cK23CTAb31gpsAwatADVjvmKdCr3G5yEu8CpdJTJseEG4RMnbZIzJnH487peBYFixbPUo4bpRas9/jhxDQfcMNITAIArbXz9n5huQpYdzqeMOOI1oacZBFqrcNZD70R40bvDas0mkJNeZfzglIiFaU3bIP1umDdgHUjuUlbb5TCWnHX3+4JO04kDbd443L5leE0EULATYHlvhKsQ28L969fsE7krHHdMErzdDgyHj/AEKiPhce3GzXeiErTXMDagMOQ102oB/GOUoVmHNVAKgmvFS2uLI8FqzWHw8w4zoQw0o1FYfj281/oteEHyzgMXK8XsJ1GFzSMH3DDieOHz5y//55iDCpV2ipqQMHBNDE4Kk1KG9Yo5nEmlyL/nbKypYwxitE75sOEHwdCQZiCrVHWyGgNh5cntvjADJblcmd93PHjzHw6YSmUlGjKMtgJfXBsufH+duOnH3/hH//tX/HeYo1iDJ7jYcJbze32YN12J72VkVEIAWcdRUlUylMYeDqduX57Z1kW8Tn5gfn4xDRNDMMOke1CFNBI1phybh8Pa3p6sNzfUEYzTzPNDVLU9rGM0h07TCjELGkGL/6X3mkpU0zYFXyFVju2K8GDZUUxgVINj8uC7QVHZ/aZ4zjSSyalDTcc8PNEu1zldmxgPB0l/t06Ui7CNiwZhcIMgXA4YY7SAZVaqXFle/vC9csv9Ms7oGV0ao08o86iawOtCcaQapVY9SDSZmUs3lkGH/DGkOLCGiNKa2jQSqWViqIzjCPOaHJKIuhpMHiD9gavRS1X48r99VfW25WeVsb5sPveCnaQUEatNLVXKS6IAKfGxLKu8t0rLcXAe7Y1UrpGo8EYjAsYM+CsQbXKer9Si4wwc69Y4yUt2hqs0XQlisVSxXBu/EBQhjVtKKUpOdHUzujU8ryUKCGRTWncMHJyM8MzdDdgpyPTdBSgddsEfOtkeoNxYhdQCmUdOI8bT5jpGTseUV34iaZKooFu8o51ugQsDh4bPCUlSut4F8i5kAE3TIRxpv+2A7OWgiQJ/3VW3r/xIpWWyPF4ZN0eFFUww0C8v3PPhuX6zpayLF+d59dffmU+nOlVcX99wwcvah7l2a4FRWf5dmW5XdAfG9sePvb5u+84zAdyLkzzERuCqG2MxebCcn+w3FesHfiH//hfCEPAzjNuPJDihjKG0Y3EXHjc7hhkDIOSMeIWEy1naq1o74UCrToGWG8rOUU
aGj8eaMoShhGMeCBKg1Q1W9HoMHD84QeiN6i0oV1gfnph237BW4ttcPn5V0qrjPOEDwNVaZIf2VLDO0ix8P72xpYzxRrs8YVx1tScybc31HrFm8rxfCJMR9CO9/dXbpd3tmWh18b3P3zPcZqxbqRjKI8FSyM/7pQUKdagjcOqgPMjapoYhpF5OFIfcsiUWsg54pXs6tQeXGi1IfgBHwIlbdxud+J9RTuZsaumeD6csMbgDyMfv/+MCZ5UO6+vN778+o3yfiHphgsKvCJMgbRs9D7J71ZpttzpTQ6Y0Y+oaeT4SfEfvef8+SM/f/nKOEy0nFmXOxjNmgtb7YynA59OEmE+TBO1yO5MG4Xrmu3rN3768pVUC8M48OHjdzx//MR8OmGsZdtWSnzQtpUUV4KT/KNeM01ZepVFd40bympWZXDd4AbJjhKQDbK7yBVtPfPzC8YJz7KuGy4E/DhRtGCwci58e73y5e1KU4aqDbVWWpLMpz88j5QniKmjrKdbzWVdGHzg4x//wOnpxPJY2JaI0YZJF0wrkAuxVibvJewQR8Vh3QBY/Jgo9kKKmwhBdAdlCcbiwoCuRS6VWuOHgZolrFDeQUUt8t+nZow2Is3XRobQpcqtXkMYBpSVkXwI4sNa1zt1S0TV2O6Ktt5p25WSNr7++EApixtn3DxhvJVJhsrC+tNajMoVktaSdrALPZwbUApCrRjtKbXilJausnXyeifFDWdBBbNH3h9QVSj8pRa2+5WulaChULhhpGmD8Zr4WCh5Ja6r0EGCFD6nwWslQcYGrA2M04jXA304Yucz1lgomb4pat6IMdHp8i4ojUIEWmE+YKYTRVuyVhgM/nDAe9juRfat1rDmSm2Z42Cw1mGVRVkJU9UBrDS6okLdvwM0hC4KzVTLX3XO/00XKUbPPa3cX984ffyIHWcZfSlFeqy8/fyNaQq44LDKoIzDjRbVMq0maim8ffmV2JtwuqaZ6eUjj5TEld5Afeq4eWT69MLx6SM1SuT8FqO8MEaxpQfDEDC6sS03kYhqR9MaP4wobalblByk9OD88h1GG1K6kdYHKQk+xNSCNYau2MnuG4/bjZoyGIdygZfPP0jAH51EJT1utOsrZhqgddlJnI6Mp4/0+8ZTk5uWcwE7HShlY5q8hJh1EWB8/cuPaDrL487b2xuxdrr12GWlHE4cxgldK24YJXrB7hk4u8lya6Bqh9pQNbOsCwGL6ZI0KjfMGTNO+OMM20q3QTqD4YT2M+sinZqZA/ah6SmxaYGEjmPA1BlqR3cJrMulkGOk9IbKlo5hnGeGaaDVgt5uhHoiWM/5MNNT5su/PXi8/oyfPaZZhhoIWlz/06wJwbDFu9zCk4xt7vnBuMGAxk4z/lNnDlo6hdZYlhE6LMuCHz0vHz5wPp2l6ARPK4319iAMA2EI4tnaNnLcaDlRbq9cykp5POGHIESDWvYIkhEdAq1U0vKQFx7AWYw//T5aVHuWVq+FXgvBOWG41S4jKjQxRslMChYfRnATzo2g3kmPhfv1jeXtIgGVwwhKkVKktcq/3hNfVWb2A8+TQ8WV++0VdX5GlROqC+2kNcQca53EvteGrQXvPXVdiEU61JKTZFuhaT7glEWlFa8yTics/nc4rd7Tk5U2uMOZbiTc0jRR0uW47VR6KVKCFhfWn2oN0xt1W2WspTWqdVoU+Cp0etoI4wSHEzkE9Jj2AEJQbkDbgO5dYlJ6Fb/XzsHbLm8AhOkg5AZl9nE+ND2RkqhLVbwLDaN30uMVVMecBDrtwghNU0qj5owyYhlJDQ5PT4QwYb2n1kLKGQOk2rDeg9J7VE/BjUE6IudQbpAgUCUQ5dbA94barmy3V7bbN1SJYsmolWYNXQeasfRW6KmwLl9obiAczvhxQvVGSRJJQi8Ua6mtY1wQpNlOkGhGC9ldidiqxUROC8ZoVBtlIlAqFdkv/zWfv+ki5YJDGYm3uCwr/XpjnA4YY6m+4Y8HlGrcH3eCDWQuKK3
FWOYc3VjGD2cOYYQwEEahZl++fRHpZulsuXI6nzl8/Ix2E11Bun3ldvnKtmx4oxi9hAPWmrgvD3raMPOE8hPBCNDTWoOqmhIbabvLKHATo2ltHWNlHNOLeCFyEvxJ7dB2b5dq4thvrbCsCz2tEG+0baNnhwsyFrTnHzh8/hPzH/yuzmto7VB5JaWF4DRWK1nMp8TXn/6Vn3/6Ed0r0+gZlZWOYntQWiKmkcl7/OgxtN8DBI1SeGs4zxOTM9yvN+JdDvk8RmY/4ZyB4Dh//EBWSqCebkAZR7peCC1iu+X2uJHvC6PqXNaNsq74IeCniXGc6S7ggyanjWWNuyppxgfPul8Y5uORjsiBH49XjCqcP77w4fNnpr7i8w1FwdiRME54N1KVIxxndKvUWqFC30P20Jp+uVC2KNJlpQnWYs9P5N6opRL8iDGGZVlprTONE0MYOM1n3DzRUsEqixsDh9OB3ivb48FyE6HE+v7O/f0reVvQYcCNR8J43gntCXKi1cSaFnTrEh3vHSaIHaIri1Eihti2hZ43mgsYf0D5E8FDWRdSkjBIP4xoFyR8sTdKjugs5BDVG/FxodeVHgJZG/Q00+zErQju62QDwVqetMMEz/vbN3LeGMaJXir0QFYOqy3Oqb1gwbolTF94//EvLF9eOT8/45+fCN7gj09sd1Cx4I0UjlYka2pbF5y3jIcZE5xkPKWNmiKqVsn/QtNLpiXxC9HECN9LJtYCOz0/DKMckLVj/YB3QToR76l1F+QYj6ZijUO7QZRrOWK1Ybu9k+9JomNo5G2hISDh807tb62ilGLZbvzbz69Y42mPN9Ryg5LYonjJauu8OIdRijVeaaVjrMYEjx0UFejW0JUm5UoxGm0H/CFAmFC7abg1eQ+tkSIMCtU6qmVUKzJS65V2XYnLnfX+Tk0L3mpUzfSSyUVRdaM7+X+vy53HsmHnE7YVHpevtLTRy0LcBCXm/IByDqsbfqf+b7niwoBxjm1bSXmVtcrjjq6V4AeUUmQt6eDbX9dI/W0XqRg3Rmc4fvyADyOnl4/gR5SGtXWWJuqeFFfUAKrL2OC+JabTmfH0gW5lflsAgyauK6VVDh9OjC7QWub29QvOBYaxcP/1Z779+K/SdVzvDMPEdx8/oGqmJsvH52e2XKnbilGOuKxkrchpZVtulBzh3tBKEkGVFhEELmCUEtbYJgINo52Y4ewT1jpQWkYdXWGdiERayWi1J9jWzvB0JLx8jz8+Y8OM3rsyjEPtAFWrCj1HAU6mSK8rtUS2xxVvNWGYaE1hzQtmGMhFiMrBe4ZppCqDahLQ5p0XiryxuNKI6c563VD3B+blA5qB9XGlGfHYpK6EbrFV1G1Fm8T5j57w/YH3UBn9AElEJCK9r+S47YidRt4etJLQxjIcjvhhJMySxaToPB43Ui3UlPj555+53C5cLhdoHdMLP/zwPQwBO05YO5Jrx88HWnxw/fIFckRrMNaQabhexZSbd2/SnvszjrN4SFIWxqGVxbwxWuTeJaKTlkNkF+pYrYix7jHfhulwxhrPsN5FLWYG3OEZf/4st/LlK/X+lbLesbXQuiYmUUDaAic7EIKW52EYGAzUrZNSlYvP+QONTukF3Qb85LHjgeHwjAozOSeUH3A58d3njyirRUVnFNU6LgkuqVMVaNWxVmOsJ9jGcWwyGmwNVSspFZSbUOGAKlkO8FaxqoNWlFroZSHlzJI3yts37PLADBPzqeG0iIlyinStMUFxW+5cXr9x2FFVPUV62dBlk1k3Gr0nQeeSSDtZRLWGswZaFzyaMfRSJAMueHzwu3enUWujR/GzGWPQLghmTBvseMTshPK03am90HsVMLHTTIdJqAnaYbSWw7pWSkrU2xXbCvPxiUIit1VSvucTOgTM4UAcj7TgqSbhasU4ETAdjGZUCqX2YEQvuDIfxF9XqmCdhASvMMZKmjCN1itl2WBbhO3YEphAVZZeIsEbcpdjX2uH8ZbSOsoGcAM5Lih
jhTqhFLfbhZ43bM/oImZ06TD3jlkp6JWaN7bHQskRM8+UuJJjlMtCb6Sdil9qpdRKGA/Ev3Ir9TddpFQpzKcj5uwlrCwliS/vCKojZdIS0UaxpoTBYENgjRsmz5xfjuhhlgXg7UJKC9dff8G6hn6a8U5T10x6+5U/f/0Gxu/kbHj68D3jsyZuwiKzyhFjYxw1pjcer+/YubHZhaY6tWXut3eCs7iDKIRqqwA4Lw5w1TtKiUfKOSf5WtpQduS/1kbm385zOI00a6BX1vc3Wu344YA+PhP27CVdC85qrBbza9Maoz26QaorvRRqyeI7ag2tJNm394YyjsPTM9P5JEiU6w0BWHqctrtPCVIzFH8Aa7DaUB+dQUHritwrr9++crtc6cZQtaNZGTvW1rCq4FXD1ohVlnq9U2cIwWOnAeWMJK2uCaM1uTRQmvl4ohlL8ANGC7wzxkXgl60RhoGYz6zbne2Ruadv4qOZjigMzTohhu/w1XkcSD1TW6PXxuiDQDV7wxrLECRksJa2s+j0HjPeCftNHKUZJkB38hJZHne2tAqxujRi3f1KdFTreC0vuXUjq5XE2GE6CXPxcMB4j9aJ++2VtEYRjriRiqPVRtsyUV8x2yLPiB8EdmoDwVnsMKPwWNVI1UBR6K4I2uF8oGi9kwgG6nrjNI+Y8AGrFUZLWu37Uvj524WlFp6PJ86ng7Di4oP0eBcMmR8FzxVmzPGZbgNh+crteuF+fccghAltLYrO8cML04cXeu6s7w+W64MS3/j04YRWWoINvcFNgcFoiu4S3KcQSnnO8n4MM1VZmjLkWikNzDxgxkpOSQ5a62i9ynNdoqgVnQCWVYW2q9S0Vhit6aqCqWIn2fPnasnY3tEa3ODQY8DaAWMVZdM4FMaJh6yUTCuFkiKThj98eEENgc2OPNQR5xzD9EQPE9WPNBuo1tNKxGQRBxE8QxAWYSmiPNbWobRD2YBxHt0qukmRUkpYfWgB8nY63UWWklnvd7HNICpcpTTGDcSUUUg3WUuh9YK2gfuysF3f8VaJaGqa8EBQE6asbJdCLR3tBvHL1Y4qkYe6Ms4HnNWUtHLPGzlHalEoDV2BGr0kGreG2io6BFT+dyBBP/tA6AAdVRL5+oCSaM5iTef88kQKbv9FW5T1ot9XTW6NvaBrIpXM4/HYjRmBMFq815I8Ow30oni/rtyXSNGOD9/9gfnj94xorl/+gqqF4/mZtC3cYyJtC//2L/9EV1ZUVk7jxyCYl6cnMBbtHKrKTa40heoaZ2URmqNAN7V1sowtwjzTre64/Yryg9z2QsB6j2mKYTrAMEn65/6dUJvwA3sSxWpvtBIhR/mu0sp2f7DcH4zeME6TdCjjzOF0IkwDqhR6XNmWOy2BDYZWG7Vb1HTEzyfJbFKKXBKTVYCht851vbCVSi2VJS3Y6cjpcAajqEWxxIy5R9RgqbGytTvrrdHShh08p48fGYLD0Ilbxg0zw/lEp0HKMvqJAhguUW7hpSuUH8SXtRdnO57w40xOCT+MGO9Zrw8gk5crLWVC8OAsfvA4H7CtyUisNZEcA1pbjA+UWii5Mg6WVpt4aJRGu4ANdjePaoYwgK6kUojbhnVO6CGlEtcNZS0oi7IW4xzOdKgbKncxcocZO260GuWZ8YPsPXql1sxjveOUkFcYTthhxE0nynCUMR4d1TXGBNK20FMWiogZMEYW3Tkn2vYgmIbCQheV3cvRMQfHY4t4J/Hi1QTUYLC905Y7PWe0Syi74prQF3RwjENAJy9kFK9lnNgkNXb0A8yOMJxwxw3jDH4eUMmgFkVqsG0F1eF4fMIqRelVumsUzs50L8IPjMMZQ08i4e+14IKMnXpX1CjfHZQdZNx30rqjGclSU63Tat8TrTVNK7ABMxjh++UFlaLsdscB7UZ6afS2EvOGrmDqTnhAokC0NwSvaUSCBnN6QrkB6weSMmLEzvL+EVdyWnDBo4MFDRiNUo5eFLl2mu44doGDEtp6ycJ
GND7QlaI1MdXqoFHjBGlBG0l56J3d+CvpBrWV36cTy/og32+8fX2lPC48PT/jTs/o3jHe02KirgvGj+KZ8hO5w7beCAbc8QR2wCqHtoUcRZRB16QcqaoLBiwErBkZXkZcGCj39a865/+mi1RaxUdgBge9E+8PBlPJWZMadKtQ4wC7twkUaEWYJoEc5sR2eeWyRnSYOL28YLXdieULBChNo/2R8eVAMANJe8LpGTOfuH/7StweHE4nGGeG6YDuBXu/4l+/8cuPP5G/veMnz3w6cno6M04zxlrJX5kMKSVizpR2oxotwXq9k0shxijz9FYkqbQ1FJpqE61sEpV9eaemlUEPwuUqmXXdl+w1o2tB4om0jHH2kV+vWW7MOeKMZvQBazUujEyHE9YIPSIu4o9arztFwWge1xtbAXt4YTh9QE9H0npH10ZLlVTAWY0PgePHTxzM96ScsderONadk7hzLNvWUGtmnhynH/6ArYXH7Z01ZyiNtka07xLTroW4sTxWWk3kdaHnSO8VaxTTMFB2/Io3ThBKteyJvGFf9HrCILgdXRopriy3C1ZZDvOEMkKQ0FoO4toLNRfavoTGGJwP9CzjGPit++ysywpFEbRF22HvYi04JR6eHWYKitorXcE0TzyNE3EfdRHvjEbUYxhNnWa0qhILrzs+eGwYqK2zrRs1WWotGD/gTx9R4wEzHTDDQG9F8o5GuSjFtBCvr8xO9o6tQ1nurLcb8fINR8GOE246wiAHojOKl2mU6HXnyWGS0ZJG3p91IV3e0LcrfrkwnV9wpw+44wk1OiFWKEtBItd1qaiYwGncYWY8HiTKhIZ3E2OwqC1zfyyonGX/4yx1D/lTRuJC+nCg6h0n1jq1AL2gjaXVSIqVphQpZXqK2FrQVqF6wjjZwaR1paZFKOs+oLWRUMXe6DpjSxd5d16pKWIHMd73DiVlepGRZ8kJ3TrGup0SoqneUltG9U7QMJpAVoZlubGVRimdklYhmrQsBfL5jJkHess069A+EMJIbUKpd96htKIUYSA6KxjxkpKAApwT43+v2GlmRHh5eYsCfO6/IZGE0F9rwzuN1Z31sTB5A23AAD1vbBehtJiaOA0OO4xU7VBegMglR8bBYfxM047WFaWK2Xg6PUHr1LuM7K12TFZUxdPTE9pYUv3rzvm/6SK1rgttWfbDQ0k3UBPH8zNUmc0v94VaGqvLGAU5R8bTiY/PnzDG8Pr1K6YrgjcY1bEGxpcXTJvY1gdmOPDy+Q+YmHh//Upab5JI2xr1yy/oVAhhZhiPTMejGFJPT/yDd5KW+/qO9YYtbQKELJkaNb3sbbwx6JykG2g7XslZarXE5UFbHxjTcdaSUv6d2ybL3k6KGZpm65WaVsr7N5K5clNK9kvG0EoWX0VKgjMpma4Vp+cP2BDwwXE+P1HpKO9pDbb1wbef/ywpn/Q9J8syZcS5HmbO40yqoN9e2V7/Qr78AqWgnMQVhDkwTS+YYSSud6zT5FjwRtP9AF7xSHf6+qC9aiqGJVVqGKlh4PHtle0vv3J8mpkPIxrF+lhYtiQCGB9wg5dcp31/1UvCqA4VdO8SCaAMOUWM0Vhr6TlTmgB5jTXoBlqJl60IlEw6Aq3QBKz2O4OuiYBCK5z30t1aI51q3dCqU7YHdZ/nm9bZ1oTqDaXMPvKCnLOASI3BjjPD8Uh73KhpxdbOQMGpDBaB74Ko7nTH9QYp0Xcsk7FeRkZhQFmDtZ6CQpUiERS9ovQiNoLHg/X2RspdwhNTYbl8Y32/UB53etoYjpWDdXvchSzmnRPpf2sF0o2upPh3JURtpWEePcYogoZJNYx3VJWJXZRyRo80d6DFRXapLVKrwkwHFB1nDVZLorKdCt0o8pqEwABoNBYD2oOfYJjR1uNdwNSGMpa4Lmy3G8v7hRoXSk3E9c48D5yPR5pWEliohARDK2jVgIY1CmMdOWc6VpSHtcBuGEZDTgV1e4Ba2GKUn8uPeOdEabjH2htlqHv
CtUzqNSndZMJRG+V643FfsFrhp0DwjpYjuURyzdAh+AHtZAdlWhXA7G7Ol0NfY4yl1ypQ2lZQNe8q2yijvWEW8VWJlMzO/6u0fWfaOxjjZGQdRpQyogItCa2QsXKVIMYwHgjTQNdaiBfe4YMkDTtjqDnCjmgyTug65I5tCt27GIrnI1qLItGBJCb/FZ+/6SJVVWcrlVKqmCaDRytHr537643lkViXTOvQvWb8De9RG2wL3VkOn54w1zu//Pd/JBjPcT7w/X/4e9zTxLbeBfWhFMEZdI6U1y/8+uVnpjBi/cD58yeG8xk/T5TeWOOGVY35eOLjp494oximgb/8/BdQjev9jXV1TOOE80JvHrwltUJsckPqKdGzRBaklPBhoARQ3tGsYkuLqJ9KEy9FFUbWlio2ZYwT7MhxmqhD4HG/i5Q9rqjeuL2/E4aB/6Ab5w8fQXX84KmtUbWi7e75DS3R4WGgpMy6JJorFA01Pgi3V+rlQt9W2nYhpwfHMeAHS1eV3jKqJogigZ285h4LpiWMC0TVGMeRx+srad3IFZKyzJ8/MR9m6mMlbglbwO1BbVtKgvYB0Ahc1QoUNeaGahLtoIz83gzyV9m/dazutLxJ3hAKraUAa2MkU6nJQr7WSsmSWfT7Qan1Tnfo6C7hfLVUlLP4IZBLYV0exPWB9w41BjF/loz3nqAHoFJjJD7ulNYYB4+pkbiukNNOZNjQpgmGSouvLtidbFHLbuoFa+VWb2qlVIk9V6rRaoPWccZS84oqD/p2QVfBPcXbgzQ/sTXDXQXUfMKPjr48wI80N6HdiKmZXlccK7Zm5Fx2GO3o64aOidkHwuiZB0PrFfoG9U5FxphaWdqW6ayU45OkwdYH6+XCFl/lVn1wjNOAokpumPechpFHrdTSUUaTSkF5j5kmoT8YhzIBZwPKQ3AeNx5ZnGOND779/C+8/uVHjFX86e/+iD4doYuIqHbQSkvXbEes1igthalbGbMaLUW6tEIvBaskvj1ukdbl33U+YMOA1jKhUL2jaqG3Qt4WGcvTIRsZtQFYje6FXhJ2GPDGkrdIbQmdZT/stUYru6fsCqy55kKRn1AuHmanwnQl4o8GvYrgqrYsbEMlrMv1cYFS8EHguqUkck50NASN90GeNe+YjzOtFFrOpJTorQrRHiTxWImHzu6himkfFSvYd3daWJTG0VXDJ7sbsQUb1VvbZfMyav1rPn/TRaqkzv1yJ9bCd99NHKcj0zzgvce7B1/iA8JE15a0LbReeDpPHMYAaaGUTjg+g1EMMdLrwhJX3qxiukzYXkn5xo9/+UZaH5S0Ual4ZyDD/PLC6Yc/Ubwn5ZUUkyTUGlCtYpXicbnzuN7pseNGQ3lECJVsNdYo2Xcag7ei2ikxUltFK0XsDaxhOD5RVEc7g3OW0BXrbZWYCwzb9iDd7kL53jb8HNBKsdREvCvu9zv3+wNvpXh9fb8wDBvH93cqEKZJbsZaY53DBU9cN8LhgC1BElCrAqtoyrKuG8vjATowzAecU0xuhFixCjSduK7cb1ecHxjnCWONhK5RSesVkxxGCzljcxIZPo0Dxg4UrVBp5WnyVPtEVZ31sQnDrTfC4FAg5tRlkYN3k9Ho4XwS9/7etZi+K4haR0lmh3RAuqKMiD2U0SgbRB5bBL1Ta6W2Sk6ik9VaujCQ8R77C9nhf0FVefHS9FqouRH7HuNgDFi7L44RVprz5OXO6y8/sT1uYkQFSaONC3kQykAtTWLDQ8BoGS0ZlSWevCtSqRLPkDImbHgquhehWy8FXR7Eb79w/fnPfP31K9aPTJ//xPG7P3IcjzzlQr9/o96+0e5fqV3hhoHBGVpZSetdQvgGaMrTrEIpSLVSEDRVLYn3rzfSeqc2hfIjw+EDL9//IPsxFtJWaKVhlEJ5T5gbRW/UulFvC4/tTi8JeiWEAYCaK8p58TdZA12hrCTUtpyhSiaRCQ7nB1Ge9Zl6PvArDa868zwRvNC/7S6m+C26BbU
TKLX8jkqpNC0eOI3eRQeGri1lRzKhjcj4rSPlTMsZ76x08rWwxUwtiZSlC9RGoYzCKos2WogSWuOshdYp28b1+o4Jlvn0jLN278yLKAebjHpLSr8XqJojS69YFzA2COhYITxJ1TC9y5gX2Xv1VqFXSlU44zHOg7GorrBKsby/kUvh4+fvsNMkkfJVkg5ckF16rZWSCrU06h702HfVpEa+M6UNKEfvTd6z3uTnjyulZPFxhoBTWjx95d+BmTetmfsmCP/h+QfGeRKDbIz4yfH8/RPjy2ewA9vrO9++fWEpUdJ7q8Eqi22dhuLpdGCLkawU365vvN3eGEcvD6SyrPHOGALnp48c5pFaKsM0U2phuW3oWqhZhAqxFtb7jdvXV+7frrxdrtRWub698/HjC8eXmfv7O9HemQ9HQhhpHWraF7laiRrPao7Pz0xPn1mXB4/bG+V9w3R5obZcsc5QkSC3ksSDgRKW3nq/o5C5/LqutOkg7f3xicF7yQ8Ceqty+Pkgfi6lSUUQK5JyLA9XrYq39yvX6wOnO7oVpjGgLbBFZmswRvN4PCipcL3eeX39M8ZZjh+eOD2d5dBuhZpldBlTou5iBLwnjAGbpNjXlukUaJ3H7Y53XsLgcqbVTo6JtG08bncey8LpdOBwPuG8paqG1lbylYyj7abH2uU+WlqHVmSHQKekBWOM3NiVEbKAteh97/RbN1WKxK5YI/+MsZaSC2qPzhjHEVqj1iwHjR7w44gJI1pbtNIMp8D0EeLjRrpffv+ut7jR4sqwI52GqYH1EoaXM1Xv9Pciu46utRAGmtx81fZAb3eMFiWdAfHGXV7J60KOGyFMzM8v+OMB40fm2sllI8WVOozEuFKWK7dWJDByW0SW7QLNzzQf0FqTQ0N5C8NIyhuP24PHt2/i9wkz84tmPJ7RwYH3KByuV+E7tsJwnHl6OpJToW2LWA2S0NpbbTRkJBd2s7J1VozQ60O8g65jg3iLVBP7gW6VvN1py43D4Bi++yReOy+Jzkpp6I2URUo+z6I67F26K6X072ZcAGMU1nq6NZRW8Pu+UmmR3lO6mF8LAmpOkVrkOUu9MgZBJFlj9r2pkBmSL2AscV2pORGCZ/wditvEYpE32cGWTfau+3i1FvlzFF1GyMrsXqhCV8g4r1ZS6VSglQ3nLNUYuQAZeQY7lZ4lDXt9fSfnxA3N+dNnmtHEuNFqxTuPc2Iyb7VSSxH6SS203rBWAi01Bq207C57oTVFoRNMp1otwNpWUN0KgLfVPZH5f/3zt12kUHz/n/8Hnv/uH/juj39kDAOXrz/y4z/+N9oaOUwT8zQRjs+kacKfJvGs0LjfL+hYWMs7KkWWlmkafJgw00zzRpD304w1hnl0BA0ujGSl2FLm/uULZs+C6bWiWqPEyO3tjfXxQBvLp+8/YgbPly9f2NKG8kqCxB4P3teV+3Tj+dN3uGEkl0LZNjk0a8NYjymQl5X4duHtxx95+/qV8XxmOJ0IhzPayIw4jBNWaRqNlkUJVWqX5XhtlNLI9zvOC+Ntmid6q6S40qtI9c1QyApqaaR1Q+WE9x7jHcYNoO7clpVhHgg9s759gRKZppGgKlVVlPZUpQQQuxZuS6K2FcaJcDbUAr10CbJTnWoDWg+sKYKp3PtGXjZ0ztALaR9vXS5vzMcTzk00JXJbaxwFjSmNUWvGw4wNBqXl4NZozB6zkHdJcc+J0pSYdQGNKCeNQg6z3XtirZWXWySC9L57brSkiRprf6ex1yJKP5DoDnbgqzaGYZokCTUElLI0IMwTZhgZD0fifGSNGyVH4v0qt1ZjyDrgtMB9tZZRU9sZMwrorUAXDBDBYltBxQf1/VeagtwLY3A4o5mPA9Z9JpxOWBcYtKbfLzTeSSkSrxcoWb6XXMjLTTKErEefPtEOT5ThJJBYBbYX/ABGNbqVuPrqJ5IdRSpeC/f3rzjXUe2zcOC6J6Uby7rKYdsMg5sFSns6QyuUEMglYY2IOlAKZay
kAcPv0SutVMzJYMyIJtPXQm+Z9fbGl3/7F25vXxm9I/zwg9zmfwvuY0N1ZGxld+SWpOOIJUIbvBfahYQ/9t9/BqMVdh+h5ZyoJRP2Z6Fl2eP8ZuRVxuAweD8IoaJX6DJSrMpiBkcwDmWEzOC8Zj6e0QbickfpfdSXI+m3OAtB9v/eEf4WD6MV1LhS0wN6IeckWWRVkTtYDc6JerT3DrmCrvQssTA5RxD+A8v9gvIOP83QO8YY6VqbWFTqXqh6ybS8yXvizf7PSZSNUhqMBQOxVrwBgmCuWgOjFVpJnAv1r1NO/E0Xqfn5A//pf/e/5/ynf5CHuVS6n9DG8/Z+58vPb+h//cZ3f/gTxx9e+Pzd99Ta2R4PbpeFt29fGaaNT09n3OHEuiZi6/ja8O7ElhdcN2gz4I4GR6WVyP2+QFPkmPnxn/+JdUscDkdMb6iUULVQeuH44QPnlyNmsITJkndPxv2+oCqEENDWUvesppSLuPZBcC7Lyi9v/0RpipYSeVlwKA6nI89/+gNhHLm9vXPLSZRTRovqaIdgtobIp3tDW3lZakpYZ3BOggLzepebnjZMw0BOmVYfIjRohZZB9SqRI/OICQPNedp6Yfn6DVsKx3FmnANtu9FaYzqdUX5mtRMvW6XGjcPxaVcqNbqB3BqZih+CRAkozcePnynGsA0L9XYhbwtOW4bRUVrHBKEcGC3AztYkKfRkDUafCN6itHSBpRYsirJt9N8kxq3TgFylCy1VVE5aIT4hJzH2tTYRO/QuSJpSxJ/Tu0ShW0sIg9wIa6FqTdqiAIx7FwVlkz0pDVJaoRSsD1QFuWSGnNDW4aYjfjqS4gNvDW2U8Y3zYT9YlHRPrYNuwqizHuj7xUgwSKpDydDjuneLGd0c+AET/I4oaqQceXz5ib5c6b3KmDpn6SCa0BPUMGHHGTMcseORGkb8dMS6EWqBuFL7nV43SpZdX7MefTyhukbnFd07t8srcb0xTCf8eKIamVocj2es0oSuAIvS7GGEFpUcWvF7fpm4CETValCkkihxQ1sBJeflRnw82K5vvP36E3G5c5gPzMcz7nDcM5k2Ho+7jH/3OHX3Gz6plN2eojHGSLcO1JJlvLyrY7XqUMsurOkM1gmaqFS2tlH3Z6O2TuuNYORojXEjp4y1Wgqjkz/HNES5GMSLFeNKTJFuLcYJWSbFBPRdiOCpAmVEGbtTNsSPlLeHoJ66WFpqkWlBa1CNyOqFMNN+H0XX3vbphaX3St80XWmWdaFow3w4yHO1U9VrbTQ6qE6tIsTyXtSF+rexpupo3VFaImG0lTGr21l/uXYRYdVMqzKC/Gs+f9NF6nA6MZ9OOK1YHg/u7+/88m//Qn1cmaylD5ptyfz8T//E9DLQhpE1Naobmb/7O5IbKesV7Ig/HVEHMFbx/v7GcrtgtCDvtdVoa0hxoy4b67rRlEZrR+6Gy2XhdlvRvfE8jsxGzIjH8wm9J26Owydu1zvNgDsETG9MwWFGT7cDcets71dxaDtR5C2PO+v9ISZd73BzYDBWUi2NRZeETRtDL3xdblgt8+zL+1WC28LAfDyinCGWSM8Fq/RuHhaVW42JbjTucKTvZruaV/GcADFFQhdyeRgmJjsSTWBzkG53IX0/vWDmGfMIEBeM9bjTB/r0hFKG9fJGKpllWTgcTygXuL6/UdJGMIoxBIwRw6zxI8dTYK2iQmy90ZRiOBz2W7FgcLxx3LYHl7dXJuc4vZyESt2qCEBawzkvGUgxooyTMZ5RxJR4u9yJqTANA/Pg2JJQBlQH9gWvUgrtHaUUepe/Z2wgDAFjndwqO78DcC+Xd6AR/D4CUZ20PsgpkZRmOB4x3pK2lbY8cOOMmWZ075A2rIJm/e7D75TduPrbqFEZuRHTdqlx3dAti6BDW7pp5D2uHQylGVTXqK6gCVHFasfj8U6pqxQCpeTZzhlNR4cBG54lE8wFNAplBWOltYyCUdKJ9mVDdzFkDz7ArCTGHMAotGHH5WS
6lhwu7Qehp7RKqQnQu59Nfq/GWFQrcjCyq/A60PrvS/dtW9lSYrleaF3G3HnbKCUzPj0Tjmf6cMCGA0p1Ste4fW/j3J6ErEVarQC0E8Vi2827WosMW3VaSewWWdm7NLmwOe9EwLOT1rUCP4T9UthpezbTti2s6yo+Mw1edXq11FIp20baNtZlodYiu8BpJi8PCg9ulwutVtwwMM4Hape9k/OBhIaaUTvDkCaYrtqadHLW4RB15m+Zpb01ShL4ANbg51G+nzHt4qOG0g7rPN573L5vdEaEJr13EZ4gwh1ap5SCNgazC5WUUnSEMqFQOCWEllYlRkaG4h2tGtbov+qc/5suUmnbKPc735aNZgTxoXNkixvPH5/44Xji7fWdv/zLn/nxv//E4QfF8z/8J84fPwNwvF65//xnVFygdqbzE8+fvyP+0z+yvv1Cy4W6ruRSyRLFirYe7RrfvnwlxYxqMEwDbvDMxxFTG0MIPH944vTyQRhc60ZOCTt6pnHi6YfPpOUO64Jpna8//sjjEilxBaNwQxAVkPecPg4MT5/lBrM+aMud9PoL9/JgCIHJWOx5ptQk4obcSSGRe0V5JHbdethJMsFYaqvcb3eJeY+J4+nA0/MzzRg6iiVuxGVlu1+JMTKEkd4M4wRdb1RlSPd32rJQ0JQkcSHK7uyvpjAlobdVDlFrMEoR/MgQBLOkWqNvC1gJ7cvLg2+//oR2M/N8oqQF7yzdwH3bqFvEOI+bR+bTGYPh15++cPn6jjod4HxEWXkRUX2HpBrWmMi5MIxizHZGcSsP4u3OljK+V7o5oNVAiQnVRDTRSsUowe7UKh1wr7JAhkYviVLq7wvhnGUHFZyhtUqukUZjnGbcMJBSAa1E+dQ6pVXqcoX1to9WNNrIjRdlaF3Jn7MLNKyzGGMkBkQblBvkNoqIMozuDE7RWiduK61kUjWE8bT7awa2ZcUCQ40iud/3L9u2sS4PhvmMH44MhxMmDDupf1/At0aJK86oXWpc6E3k9dZahjAI2ToX3DQwTSPDOOy7B9mfqA4o8QbSO2lLWC1kkq40KINxVpJfWxPJ+654rSWLoMjI3kj8hUUQVcOIf5qZXxwheHIuxJIZ6PgdxlxLgQY1NYxXmCmQq6bVjrIarRquRVSJWC9ZXtp0ShOgMUqKklMaQ0enFa0VpXR0TXjdGYOl9k7MlUdKtJxI68ZyuxONBjRTyTtNovK4XknbSi0J6wMaTdkij02sIpf3C8tjw4awJ/Y2vHMMIcg7hUB4aZ0YI6lkXAiS0q0sxlqZDPSCxlKKAGJ/YxeW+5Wmoe87VhsCGE/whl4iZfeJWWvBaErvpCZqQGcMOafdyymXC6tF+dr27rN3kZ/33lBaoYOjdtmf71PXv+rzN12k/vWf/4VlzRw+fuK7f/gHGLywstKGmgNRKXrw2MPIv/34M3///JE/Hs7M05G0RJSSKOx1W+WF8IHSOs8fPzHpztef/pVffvpGXBac1/yX//G/MhxODHNlWQut3SgUji8fePnhO+Z5pOSN5+MZauf6WKg5ER8L93Wl9cZgPW3LuC4jomVdeLy/sd4jzTmenj/x9Pkz2ntSTuQccaePBGdJlzfuKfJ4e+Nyecfoxt/93d/jpplhHlHakx6JEDe8VejB05Aoe6rsUayxtNJFyo36fXzkhsBjWfHeU2JkXR4SodDkjvvYeYS9ZdbHHuftLLlXbvcLyllaTgxO0lLf3t+4fvvGttzpSoy2s9PUVegWNW74MDLMZ7CWGkXyWlOXMQSiNFQ0qCugmZ+eOH7+SJgPlDUzzUf8Htv+5etXQnBM5xk/yo20lYoqhWEYGKcJ65zsWlRnmEQeDo1xGkRCnpLsZGKi9SaemlwkINMaIc/HSMkZbx0VqB2cNYR5lO/XWlovKLuPTAZR6JWYZH+okKyk3zq2LrQKa53sdpDOTJu9KLVdWdWhZvHtmOBR1lCUwSAsRKXB7yOmHBMxR7Squy9mL4DO0I0maCe
ZZUDJmdqKMAeVoLHojZpk7yKqWLWr36TbynGFkjBFRswqDPJzGcu4++5CCAzjCFrvqweFyokUF2LsYkpf7tie2UoB7YSyP8+y59qFH7KIl2bKGEuYT9jx9L/QHVrfyeteRCQlidhDQXJIPtK20rdF/GbW04oir3Kwt6oYJstgKtRIqRGKwzmP8R4THLUZKl5yp/LGutygFrxzKOP3QtooSnbJORfKshEfC/H+oKyRqBWlw1Yb03xCmcDWPaklrJXvqZQKZaNXUZXeLjfuS8KPlZgjtRXGcYA+E4KjKc1SRHVXilDdtZW1gjJ6H1VKkWkxyph05/vl9UHKSTpeid2lYzCG3ztZFCIsUaC0wfSOocnf0+p3Kbk8w1KU+m+Xi5wE5aSU5FFpxZ63ilNG1LHq34FP6pdffmVZNv7TMFCWhdv1Qi4wfvoBZeRl7aGh5xl9LYzjyNAbXN7plxs8bqS3r+iycTwccJOnLxdGJVDNoB2/fPmR6/uVv/uPf/odP5Jq5/D0gfOHz6zrhgqB+XSEHEVi6yQXaouFdLnQUuFwlKC4yQ+Y0kUMcbuD6oQxoP3Ayx/+xPHTZ+woqaM5rry/fkV3JPpadXQt2DCyxEh6vPPtemMoVdzp0yQjmlLoPmCOR7SztG2jLAvYTlkjWmuOxycRCqTMME2kVFlXMRiq3nBGo6eBw/TE+PRCCE7o7q1zOM3UlEURFQaMqqTbu3iS3ISbZnJTMloJQaCS68ovr6/cHwvWOob5yHR+IRxfpOvrlhFNM47h5RMxRq7LSk8FZyfGExyen5mOR4xzlJgYZs/TxxfStpBro+XCwQZcmABR5P3WLSjkRm68p+8L39D2sZyTwLraGrkWsFoi28MgN1WQm3uMKKUouRKCl3FdGNDBE4zB7uFxzgwMk4xanffkLRGRkU6ue3egJdJAaSeeHGNotVNaE+6iUvQuPrXcBY6sShS5O5GmFDFlbK8YduEWHVon5UhMEa8a2+MmhJMOav9zawFqk6K7beSYZEFeMzlv9FXGMCWuAgeuRUaaSS5NNYvZU3fww0RHlvrGORHmIIT9HDPWOrQSP09KkbgudO3ABrbHio13mtEYL3T21pqMP0G6V6Xg99gJg1EarOW3QdEWo+xy94W+LgXfC7p3dF6pccW2htWNrVQxos4DGs3961celztP5yfMeYQeUUo66U7CKnBhN4v3CrWQt5XH7UpNEWctLswCx+2VHLWodDvUGNnud+JjpStJPw7jjD08c/j4PW6YILzyuLxB3uh5Zd2TDXpt+1itYwbPfD4wztKVemdl99UFMWX9QDh7emvkbUM1GUFrJZeq3iVCJ+3xQ8aYXdEKxpmdEamFsdcFVK1VRyF+u1TrHhvjqLVQatlVfMIXxZh9F5alqGH26ULDGBkLdqUkubcZDJrWlXSj/x6KlDGax/bg559+5Hq7cTg/8cf/9J8J55E1Jdb4xrIkSbYNDx6Xb1x//jcGrdiu76Tlzte//MhhHlEYgva4tNHTSr3d0duDdr8ze8dp75IwllIbbpqYjicOWomf5n5DPx60nIilSRKngvX+IK+RH77/zPzpE7YqLpdXLr/8BbWtfPz4gQ+HE9oOHD9/RIdBbtOtUteFfnmj3q7cHyM2eMZpxI4TB+u5fv2Zn//5n8k/vfL88QPDn2astbz88B3VHWA+yw4j3lFJCOzv9RtGazlAh0DaIlpbXt/faSXy6I24LWilmOYZezgyv3yitcoaV3qr1LYHMuaMP2RMlcjojCLGkfGjpqJYU+HxfgWgpsIvP/1CzIk//f3fMY6zFH0nOJqmxUhrfKANM9WOZB0E2xNXbCsca6cuiaYi+X7D6850PjCdZ2oVIoQfZcHfStpNi1EgtDGSY8SMM1YIdXiloTYe71fSYyWXQm+NcZo4TDMuBOqOKxITbRZxQev05rFOEme998TlTolRZLYYdNP0uBK3hXXNpJR3Ca+k1VprQPF7B6UaeGvxytBap8bCUjN
dW4wbwTqsNpTtwf1+2zuriO1l30tBjSLKiOuCdV4IGrlirJABqhJqO0XIAzknlsdDft9GMw6BhnQnSqnfc5mW+3WfzXS5fe9QU9MKpjmRPu9ih96VjO6kdP3+V6UVyga03Sh2ptoZRlGO9R1nhNbU1lG/mWC7/BzGe3RHvv99NCpTx0hNm/ipFLRcCAomF/ht+6GMPItdW5QxVCxde5QCb0C5xmwbkzcM/og2v+1QqvgW11XGv1WiPx7X9z1TaY8feUScNfRWMUaUbdo4Yoysy8K2boIomiaOHz7w/Ie/5/j0kW4DNswcTkfS/UZ6XLi/NlKKdNXx84A/ntDaMB4mwhAoRawZLdedVD7hxpkwDJQ9UYGUBUZtHUPw0vmvG5SK2y9DxgmZxIUB7WQX1VujloyuWeJKfhstd5nASMhkkl0a8h313iUDq0JWInL5TaaPFlrFb+NqZTRkgzaWUld6aah/D2bew3nm2/udX758w7/f+K/jhA3yv6RK4/b2ztuXV+bRMxwPrCnyf/uf/i+EJg9la5lteaBBUPbGM376QCoJS2HwivNp4pEy99sb5/uMUo7aLPN0ZJxPrMuNvm3U6zsmbdhaMcHScOTeebvfeXx9Z3p+YXz6gDoc6FvAn85Ua4jAbD3X643aq+yjWud+u5EeEjCmwkh6CKG7mkpG4SeNVp1gHZVNDL3bwjQeKCWR6kaYnoSYYQ3ODNS0SY5VLdyv70xtBmWwRvwV0zzTcoKOqHV6Iz6u4u8xirw86K3ghwNKG7btxpcvX+i58YfPnzl/OJFvleJHug3crg9RNo0Dmc6WoXdLaopvlzf67cpnFM8fP5Fz4/p+J/Ur/pEww4y2ko769rixfPnCcrlzCB7VK71ntJXbnPPC4qtKcTg9Y7Vluf5KzXskxjBSSuHydsEMgrMJVlA2qsnh9v5+AQXOe+ZJuHclddlLlrKP8rxEQeyCg1oSPUUBgpZCqxlKZFkSy1WJk3/3sP0eKZEzzoiPqe3L7t/uk9a7PeoD8u6DM7uMWTV2c6hBxY0ao0ifaUJFoPGoZd+NwXweCeMoJlgtF6va9hFZb/RS9oTfRKvit8N4lDKyQ+h9J1nLTVpu2xqnnYwjlcLoShimvQia3xmG2lmMkn2J0QpUB6NoSdG6o3RD0RY1TFBmdI+4ccI6D0hZU9rATi+w1uzjJDlI87ZQWpFDM270hgT+9U6qomRVStFyw1hNdwNqesYbj6oNXRr0zKePz5z++Jn5cBBjb5P48y46esGV5QSqo1Tfjb8GE0aMtuJNfNwpa95TgCWk04aAmybmekbtVAofAtM0MQ0epzuNyjx5Rv/MFjx3GnF5YOUWJEnLg+xQQWGc2YVDAjkOLtAQooQqHaU9dphByZj39wBIZJ9nlcGPE03JczSEEeudmNer5E5p1XcAdd5N1CIvt064l0pb0BL8/tuzG7cNoxG8284SdHvhqwjUttVCjtBVBmWoytCUov976KQ+fHzm1/c7l/crRik+ffpAul8xm+VxvfL+5Sdu9wsfPvwD4/RCzJmf//u/0C83jlNgKyujM+hsOLoz0+gYxoFcItY7bAicPr2Qbw/u68pPf/6R43nDH1+YhgNV3yGu6FrwzmK6Jm130rWjT46SO7f7yvZYufz0K09Pz8zzzHyY0O3EZkRSm68Xvv7ylfkxEJzDOSMqphhl2X56kigPbSkpQZE8KFcSh8Ez2ReUB9MLqmTS9cpPrz8xXm58/uNnrJJUT1rFOEsuEWJCOcN8foJW8dZxPj2xPh6yy2lym+05ke7f8CGge2PdMjFdhfxwOJFK5+vrF776G+Y0UhX014v8c98uHEbP4AN5bbTacd5xWa6Uq/DnrDMMGuKyCUWARlpu9C2ijWOeArO14Afev73xy/2G6oWn5yOH84R2AWMC0+kDw9ML8+FEWSNp+SYEAm3EvFsqRiuJsXCO2hsuieRcvDMy/rLO4pxBtSpFp++UBBEiU3eVmamF9P7Oer8
zjIN0Ha0KzLdD3CLbKrHmaI0b1U6ekL1Q14KOUftSvzW5uf+WVyUO/yyUk7TibKB5J3w8q6ipspVIrgVVDE11CgqakqyyBpWGV06gqTlLV6kg0/euTpR4Joxi7tWGUiKqrdQsF7jfDJ1C4ZDoErP7iazb49qLgGy1MbjgCSFI57ptNGNkOa+07LlyItfdVKob2nlRkw3DHv3eBO47BBEp0cQ4WtvvnVVrjZrE0OyMlS4ShXaGVhq5FenEmhRmP56Ynr8DP9Af78QvP6J6YTyNHE4TYT6QSqVHMN5JVFWrGOvBRAE1t4o2CV8qDekIBmO5xUgpslfUSmOcww4D4XREH0bsfcXuCCZUodyvoojzHr8XOkPFOy0hrvYk2WO7MV4rxNpQZFcZhpHeG60JhFqbEXc4o62hrg9AYbqIebYtSrG1GmUkgbppRS+SMKz2TrmWgrXiBWvtt+cREansqj5plC3WysWh5kwrCWPNbhFQKCREszfBhQkKqe9MzULrSuABVqwsGPdXnfN/00UqpoRRu4TZar69fuMf/0//ZzmEdMdLdAyxR472IKKAzx+4tYLzFlUaYwh89/0nDqcjraxcfv6R+/sVXTbWLGKG+emJ7XEnxkz65QtPMaLXG2me+fT9H1CHA+0YWL9+od0fPG53dFGkXHk6nTGHme2x8H/9P/4f+M/3N86fntm2lVrTDn4t+MFK7k2rrI/MY12JtfDy4YXuLdoZ8pbIj7vw6nD4eaalSlw2StrwDSZrWHphub+xpMj5ZDGjEyyQH5iHmTGfsG2DVpmHgbgmxiAQ0VoK1pxJKVJyITiRYtt9kdyV4v5+IynD4fkjbj4LWdw57OkZ/MD19cG//bd/Jd0fzNOI/ukrMUdK23g6fUAPgXUtjIPDqkZcbwAcTyPKBoq21B32GpcH8e0NjRSRSmOeR04fXvCDHMDTeOT48j3jp+8ZwsDj7RU3nhmaEMdbbxQifpLsI++dHGTa00phCAPnzx9JMZHiSvvd5FwFX4Om1I1auxzkrbNtScZORigGvymVUhFFWCwi97VWYu1xFu3le2pNUEfBD9QkhPtWpZDIPkRGZr8ZhfNW6SZRi8F68URJCKEh98pWhZJhvMM6QQWhDXHdqFsSf1htAsm1mm4saAdOEfyEm0dc8PTtTryv+0hTuqm+g01rqVIUgnRUGiF7S65YofZGtbLfK1HiKkoRYUJXsn/rKaFVRbeEoTIEj7Wyu7RGyZ6sC2nfGYOyWsCtuZCL0OyVlv2VUoZeE+hMz5nWRIUInVji7xePWjuzHzn4GzotxPsbbbtQaazXDV0jLhW6ETqENh5o1AaNjrLCaNQlge/yu9RazPKtYqZRzKu1iZx/8OhhommL84ERjWmCKapRPGlZSbGxLkgWVs7Y3jg/PWHDJGq6GlnvF7bblZqyXJq924u2EDm6cejpiDs8CRA2RVBN4MAKYioCfFYaHTxuGtBak9aVnCRUVGu9Q5M1RmkhqsRIKStaGVoYpEtUYu9w3u+jXxFgGANGBxnxIZ2TFLl9p7azLkup+1dk0MHSjds5mf/rn7/pInW/3QjOMEyB89MTrTf+8uWLHAzjwPHpRHCe6/uF1hTBeYL3xMMsMM4ui9iSIroOlC1yuy68f7tgrcYfBuanE8cwsl1GlvcLOW5Mo0PVhfvbneeXM4fJkUvj9es7r9/eWWKiXq/UrjgejkxTIK8Plus7tkXqciPFjTAEwjiytQWapaXG22Ph67dXYqkM80TnzlTgfDrijSHMAWc00xBoRrHVA6k1cnwIo0uDG0dePnwA5xnHgDYS+2DGs6TBWii3V3raUNZjbGX0AkhNKeGtRguxD2fEV+WtJZbCVrOM0caB58/f4bThfBgxrWPOZx7W0R+N83d/ZPhPE9Z77pc3XF0wZcPQRGY9DDydnzg/vzDNM13LiMfYkW4H/OlMKYUv//LP3LfI/X4j5UgYA8fnJw7nMy546uOGN6C7RIbn+KCnBRNmjsN
RZuslE8aJHBdAxBM1J2iZpjraB4bjkTAUlrsmr3faHmBnlKLxG2bGyLx+v+Vas4OHlfjJcpHsqVYrOVfZ3+iAtRarDBY53H8zkRaERO2co+0HX60dtfMcg9fgHP0331DOaC3Ed3bptyzHdyqCc+gm9GqlFXnHQFkte4hKI7eG7hnnAm6aMUMgjIFcC9tVwLf8BuXtsK4rylmGMWD3eHK0Rg0D3Y/0FFGtoUqjlUpunazlu6F1Gg1UQ1svoy8HdIMLwtjsmb3IS4H5baSnYsfUIubTGKUIGCNx7vvP0KqkEpdSxB+kxIPTd/FA743l8aDlTFvukuPVE1aLBHtLkeX1nSEr3KTwpaH68vvYsvcmy/1W9u4FjHWME7BpYkz4caZ7GcXn1qgoYu2oJWNcpZVG2URsoqyiWKGGDGnEGJGiozThcGY+vuDPHwXjtd1JtZHe3yVqSHlqyayrRNnbIIGVKCd09pop20bdVgoNH2Tyolyn7xegvMl+Nm/7ftZY7G7IFaJJI+csqt6Ufied5CRePHZfU21y4dm2DWomDAZVCzSJGTK7TUCAzLukTwsx3/pAcwMo6cD+ms/fdJGapxFrDKkr5nlAOScPybLg/l/k/VmPZNmVZgmuM99BRHSwwd3JCGZGZaGqHvv//5AEsoEMRlSQPpibmaqKyB3O3A/7mkWh0Y3kK0EDHO4AQXczVdF7z9n7+9Zq8Nsvv7PtKy4E7i83jNZM5zOP755RtbJ8+cK2b2x/eUPVwjzP0lcYBuzgef/DO4qRWb09zVAybZCRYC2J7X7n9csnes6s95U//88/8+XtlR/+6Q/4wbPcF5zXoCph8kzDE/PkGCfPfJpQ2rBtO/fbwnK7kZdKLnC9JlKrdO0p5ZV4W1HPOw8PM+//+JFhFOnebdkJ00hB4Q0M84AOhpN7ZEkdZS0PlxPjNNGxqPGCcaOkuG5vKO2IeyIEx+nhzHrfqCWz7pm8b9Qc0cFxcjOWwh5XeaEHjx09CsEvfTPalvtCNgE/nPgv/6//wod//mcZt6XE7fUTy29/Jb/8TtkXlBNoKggiRcCrXbiDykqyynsG7xnGkd4KF3vidJoZp4FhmnDBU1qhpJ3l669oJWOzfL/BMd6wzmM06DmQlivdapx13G83XvPbMbtXtEP9Pg4B0woZ2HNhj1F+WJUBzDGiEBtwyRGJgfqD6ScPHDEdyw1Ea03co/SvrEVZGRy2WtlQ2OCZpumIiFuMNVjnjq9Lkey1duQqJGxnDCYElA8ScDBWngHO4ayhxUKssqPy8yS3rmN0W3MhlkIwlo6Mj0BTYybHlfv9fnAR5WFlDn7h8HDmFAZayqTaMePM+PyeMEzUdaXcXiFLarHGTOxVHoIahnFAK4NuGRsGMB3XBE1GjdS6S2BGuUObYgGpDlAbNSdKOtQTSkSlpR9Fa21Ed3Lge3o/XtBdKAk5Jsq2sa4LZbtzfnwknC5oH1DTCRV3+rbQTZDe0u0z6f6GcYYQRiGKlCSMRzugVMNbjXeWeZKAQEyFphvdgmlN6gZVxuQUKf+m+0bLCRcc5jIKtguJcJcsclMJznj8IM+g7j1zyVIGv11RRkv3rWTQoHXDToFug6TlmpD5YyqCXzMePwVUN+R0o24rJcrNPh/1AmUstkpSVWu5KaWUyCWTa5W0bpURNk2hqnTGWoda5WDQ2gG0TUccvTW0d/jgsC58Txh2benGUrWVVPAR1f9bfv1dv6T+5V/+C7VV3u4L4+WBiqbnxqfbDd0ap0FeJjkn0cmXxsPzO9799AdaLHjjKesbpiZ8GNHDxOnDA8Mf/xlTMycvjL5cO83ANDpq6nKaMRbVNH/987+zPN9Ytsjb/Y42no8fPjKMjq+64ykMZoDa8NYyzwPDGOjK8Xq9s9xXShLY40GCYTxNTFZzPk+UvB9L7sTr55XTaDCPZ1JM3F7vmIcnqI3nyyPPDxNmHoXBFiXurOrO4GacD+QaSXuRk1FObLcXequc/+lHwuj
RWrHcDK/XV+EQls5adpwSLD9oTg9PnKcTjc7y9pkSC7100h7Zto2H5z/wp//2f3H+5//K/PiMOawxj08fuI8PbPN/cH/9K6UmtnXj9nUjDYO4crSSlJdylPsVHwJtu+JsowcpjJ7PJ/w4YI2cIPcuhIL901/5/dd/h9oIxvHw8R2lb3Q7oK2l5Y2+vqHHk1A9tpX97at0kZyczksX31jvHb5pyXMWu6lVNANdmaMbU0hbpJeO0YikMheCC3KjBeiNHDe2RdFLwHrx8ECnlyoUAxJFI0y9IDDcrjgoFx26whiNV5DaKgoObcFJGMJ2KN+I01VI7+IaMjg/CmFgWyUFVzLzMDBePmDCSNdCrC77Rk8dP0yEYRTw6fFwOs8T1jhaTKzXO3Y8cz49M85P2GFC+4m1N0oTTbtSRnxJZcVaw6Ak9GCUxmtoWl7gan2V4rrpAnr+hnwyCqqRh2CXm6OM3xqqyc1R14QyDmsdFYNyntokHl2LRPJrruRlgwLNWJK2bG5EnT4QTu+ow4TJG9y+QKvUElmvX1luV+l7zdITKinivcVo0WRE1QjO4UOQPc1xq+iS/j9Gq53cG844ChDROD+ivEc7Iby3LuGYlCoqd7pP+PXOMF5xnORAPZ+p7z6wKUWJK0PwGBfo/QiDXN5jn36Q+P52RaUVow1VG7qxdGspVbphwmaUlGJqmUrBdKGkgyCdFMKm1EeSVBuhc8iOVBQ2rQubsPeO9QGlvIwCDxxTr1XCF62go9zGlHF0H6jtkEWqhlaaPf4DpPt+/OMfaTTc1xecH4h74ksWQOuy3nl+94g2ipKEINBLZXQO3TrKGowfIEce373jfB7BWdz5jHGBcrvRtg2HxQ6OXCsYi7OV2/3G9XXn9esdYyp+nNlj5nw+UXvDtoKOnccwQKucvGBKxsFxPk2klPn69Su320LDMowjTRnefn9lva+cTyeqhlLEcmub3IrolZQz27KxLhvbsqPqVxmBPT4zDh4fPDU37POF2huJTtwSy1bZr68oP4oHx1lqawQrvQtRoHfCMKK1eGcIFkujGU1zA34YqDF915KXmNDa0kdHUzBZy8OPf+DdT/8s8NumZCZNJ5wuuB9+xKhEaQtvn35mfXujlUoagpxau5AbjHG0bWGcR4wxzPMsc+1ayCnhvKMI6Yd+cO72nNn3iOkdpTv5dpUxzVDYb5H48kLfV8w04bwX0GlJGDPSk9hj6/FSyiXLv1sbMF36U/tOJaNzwYUJ1RJbjGQlfD9tHN4pUlwlDWc0bhzE86O0PDytoTahVButCEEU7vLnTuhaUFZsqL0jhddWD0+UEiL/8dlX9dgZobFKkWOUnos6uGkK6r4R7wvpLl8L4wPDODE+PGKGM9p4WmuU7c5uDF5NtBJpvRNKQvVCLpF+r+zrSkyVh/kBZzQmRxxFhHxe0y8nSQcaj3YTWAstyaGjVZxVTEEU9HHLQuQoss/QR8tTNTlZ9yLxZaWRkZWxUDI5LfKC6+C8jPN6ybRW5AGrZOxbWqb1DKaD1kKgPz2g5vf06Zk+PWD8hDUDKibK/YvIRRGJJEp9Fx4qZQVka5T0pJKM7jpNUo1Go3NnX1ZSrvRvavgwidpEd5qX76kOHoyjuYFkLN1tKG60FMlJbn5523DGSLCkyWhcGdGRGOfRXtN1oM6P2NMT9nSRBOkuRAjZZxrxT6WE9Y4wnSjmQBNpRdcaVQWRRpdbfQGUOrpXSjxXxhi5rYLgarTGWi/jvCxA3Y5C90rpcqMUyUmnV9F0dOvpNoCxNKzgtLr0p/I/AgXdhSCRSn1jXTZurzfinlBKM0yBYRx40JpS31iuV1QppOuN7ctXKvDy+SuDavR5ZJpOWG/JtfDy9Qv7deP221dybfzp//hvDMNEuq68vC789vl3llvn9083/umPDwe5oPD8OJFzYl+vKOcxRhbbSjXmaWQInhwjXz994fffPlPkD8H88UemcSJWSKVQqeLrWQvrfWMy0N498Pj
0xHQ6k2plq006L3nn+fSe0yxLaF1FCeBPg4wTa2dTji8vb1x//QUzzAKApZNKIXgrp5vaaEWSVX4YqBXG+cw0Brw3DCHIyfXrCy+//cZyu9F7ZzpfsEFhFAyXR04//ISdL8ciVdGVcL+0RvYfDw9MyzPL6xtVv5B6lfDLMLJtKzFlBm8PlYdmOl84DzNh9OzbJmOzwxZceiNuG0nJPiOMkzy0akMNnvnxgrGOr59Wvnz5yvb6Rng4CarIe06nC2GUkWA70nnQyaXRFEekVskuCJG/tdSpFPk6OYe1hnpAYDsHOVqLM2kKIgg0tZKWO+0Yc5Ra8W7gdDphXEB1KQunlKi5YJzDuIB1jq1k0h4JYWAcT8JFy1mEkiBf45JpaZOXh3WHCgG0kVJmbY2cK2GwqPEksr0Q6NpCA0elH4v6EhOqFNRxM2o5s+w7rRZsCCjd2ZY3yvpG9yIHVH7AzxdsOKG1xw1nnFGk2wslJ3qVScbgB7y12OZJxpLrQYRQcnNqRdQpqgvEtWNk9OtkllfSRi87dC2l6C4P8NYlSWaUPhb7TSwCHUBjpwvh4QPj5R1+mMXL1Aq1VUqv7HFD1YSfpJjdDxrCN2ULh/pEqE5eHuZuRA8BWyq9a9IqRIemi/zsa0dHzNvGCwOvNE3UHnt6J/LG7UrOhe2+MQ9i3O0tQrXkVMlpp+0rrXxjOBrxO5lBbvRd0aOMQ9PtxrYs8nXRx++2JEAi+c5ZmpHgiTIWU4qYeovQ3uWQKAdVYx3W+u/E/5wzRitUGHDWHeR44TS2fqQtXaN5oVKotEJPQnn3ge4Hqg7fR+ZdaRSIzeBv+PV3/ZL69OtnTk8PfPjwE6+vb/z5X/+DddnZt8y7x0ceziee33mM8fy27VhtISXS/UbXGl0zw+ShF0arOY2B23LnX//jr/zrn//C509vTKcLuIB3QmtYt8Yvvy/U6gjTKLswVbicB2naB2mDx16w2uLcjA0zpShe7q+UbeftttKN/DCkUnkOATdPfPjpA96Jj+npJGXSL79+pt6v7HEjZs8SpbTX/YjiRkkrtlWIUagETYMTXYLuDWc9e85s1xdevl5RdpO+jIIWV5zu+LvoAHqTF983/w054c8Tp1k8N2VfMTTOf/iJaX9ifXmll8r2+say75zeOfkatKMzpCwdOfFjNEZrxnFGPb4nx42oRFlyngbZBb29EsYzYzjhnEabjHWKefJiL44DaY8Sa1WQt8i+rdQiJ0SnZW5vFDT/Djd6eqqUGDE+ML57hx4GlA9SELUG4zzjPNGqEKS7anR9ADidp6x3GdV6J3iZg3zQtUOrkcOeIQk17yhJFN7KWfw0SVN/32g1UdqhXzEWZzxaGTkpa03rXl52vVP7fybdWhfUzXi6MM9nUXrUm0gq+0GS/iZgVIrSO7UJZgmjqdqwY2jWMZ+eIVygCxm8Iw//Hhfqdie+faFECbcYpTHWYKynl4YWqBDbeieuC6YVLvOAGx8Y3z9gp2fsMKO6ImjxTNXllZhWWlxoONrmyNYLeNQNaCu9Iw4D9DeeoFJgO/J1UOrg5TUcFRMEFbXt+UBIKXSXqUBTGucHjHEok7FDpQ0X7Okddjih/UBLERU3Cm+UvFOXV3LexGDQkXBFa+JKckb2Q0rRUpW/A8p59OkRHUZUjoBnbJqmpXekrPAUlTakCnsSx5SxljA+oIYTFQPdoLSXknGNtLZT4p29R/ZlYV/u4jOrRZBZh78q5ht7uZL1K6lrSk64lrFtFwwRSPWhVfK60b4xI61FdS1UGjtIVtZkWk7y2e9yNBDqRqEKdp0c03dCPE0+o+2wAQMkNMkN6DCRi3TwTBPUlTUDDSGgt+NmrDmMDPUfYCe1LhE/CZLj7fXG/b6xLjstJXnA9kLck1TxNQTnKGnj5//4N9798APvPzxyOY08PZyYJ0/wlt5H/viHH/nlr5+OzkbnP/78r3jTJKoZJk6XC7f7zsf3Tzw9j7JXaJ19jbJDcI7xPMq
H2ToahjVGrl9v5GVFX85MjxO3t1dsraTbwvX1Sl53TM48PTxw/vgeOwacs+y/O6yGZVnoRvHhpz/ibaAvV/bbjm6VeLthT5owe4wSyYEG8rZx//pKfP16fNUkHdaQwl9tB0etJHKMLMsCSsu8PL9h8456k8V+D5bhdOJ0eqTHyM1YtmVhS19Y9xW7r+TtTtmuGEYZnaDQraGKQquK7U00DW7gPJ3xyjAER6+FaZhRzzNhOmNaI28v1C1S1Z1urAQ+UmSPiVI7tTdAyocpJ+5H+s1qzfW3L7jU6bmy3zfC6YSfZ3wYaaWyx11U3UoUBUr1oyxrsV5hQsAPA1sR+6o1hnka6VpROxREYAeAkuJuCEEW0g1K79KJahIuQR8LfmPQWDqaddnwteKcl+W3k5d8a5Vcsgj16LgQ0G6gakfXoGxG904ru9hslRIwq7G0rsilkbco8rrc6DYwDBPDeKaUxn6/HToMg1Gi92hxPUqgXVJ31uCGET9MTF2T1ivr7ZV1ucmNu2YpLM+Jj6cPnLTcKC2SPDzidQedO5FiZV1WwqjQXuC/sn9UEg6xFqO1HJCauLP68f+lV3qJqJIorZBbhyOO3lHyMtVHIlVriaMrg26KNj9hL+/QfqDXQnz9Srp+odZMqwlKwqpG1537umERJJU6RH7Ocfy8yM3auyCUh9MDxgXypmgpgw+4MdOjGKIbEPeNLTdyyTijscHJaDIf48gjHWeVAoyMzmsn3hZur19Zb1fCMMrtSRvU8XXZ7gs/f/rCNXYylt46H55O/PA8Y+1A600K1a0ScyLVgvNB/FdWoMm9d3npHaGghkTtnTNoK2O+kpO8cMwhVsxRNBvqG4BEo7rUHaQ0bihdrL3OSmG31EatidiPG6YfQCEuqX8EM+///Ne/MP3+9djTrATr6KMnpQ2bN86mMSoI788EozBZbhfkxmn2BK+J25Xpx2d5CGtQ1vD89I4fP34g550PT0+o0sg18/UmJ9jzNDAFzYdnGX/4cOF62/n9deX53UeG03w8oBo5FV5iJu87r683eqm8/zCixxFud3SF10+yn+olo4zjNEzimDqWlcoYHh8vdHXMhV1A+QHjJ86XzmmaJNaNjJxKzvQmO5zldqPEjT9+eObx2RJL/a6fUM4fu6lOLRVNExdWk1iz0ZqXr6/8tv6ODoHxh3c8TWfm2PDKMI4T6/2Gtprn5wes0+T1lfvvhXTMs1sTptw8jdjgyDlx/fwbL58+cb/dBM+zChdumib8sbj1RtPLRtt2asyoQR+3B0mpbTEdMeZA3HasddJp+tYfuu/89eXfqKUxXM48PH1kfHgk+IHlvkA6qOrDIBy8bRcsUgdrxP3jrKHPs9DGW5Y0nVYHULOjlOzzrAuEcAQ0kkbVhqFTVRfYZ++yaFcHURtF7YpeGmShVCv9bVSn5UGcM71x1BQmmrKkdqjuR4lCkzQaGR/mVnAaZLl9mI+LxjjPZAO9Nta3r1it2dqOsUcM3FlUPZA4WqOtpihQ3mPHGTedoCl6TZg9oG2j1ETKlfuSsBHGd19wpwnCJGK/uFIPSrgfRhFxWsgqYE1Au5Ha+vFCPsq6rVGrplQRB2qUHLQ0xy1X1BclZ5R1uHHEaHmhOWskjn+UTqtSGCO3d6xo29GdvEduX37l7S//iulVelBaM44y/kxZRp3OB6bTgHUDvamj12ak0G4Myllh25UkZdnDo2S9PXBXimVPxG2ja8vDwwlnJZbf9ztYLy+kEql5Byrz/MQ0ndAocizkVNl3OdwM84mmrXjQ4sbr2xtvb2+sWaH8CL3RasAYCDagtBardG+4miEJnGAYJ1yQ/XeNSUIQ2qGdAl2lZWYUVToQ5NopueCPSLlRSriX1khHq2pqQ3b8rVK3GzVG6YUZC10CFaV3mtFoE+h2wCh5Lihl/v8/3P8fv/6uX1Jv1zt2GEgxEfeND8/PXKaBL+uN0AsmLUynCaMU+2DwVtFz4fnxPTlnPv/lLzycArY
1gpGTzG3f+Osvv+Ks5qcfHvjpcSIox6eXO9dtI8fCOMK7d0901fn8emecNMo53v3znxhPF6Che0XXRto3Uq4oa7CnB8xRDl1fb/TSub3eebvepRkfPB8/fCA8PuGGE+v1xsuvX0jLjaeHC8E59pJ5fXmloLm/3phNI6aCM+qwYmq0c6Rc2LdIbBWsJkwDYXzgdl+4rQu1d0KQU9oWd/J6ZzAKA6RUJMI6ntiDoTw7hssDdj5BmNhaIuaddL1RY0KlTMuRojS3r79y/dwo6yLw0iT9iYfLhfPzE34c2JeV9eWF9X6jdjF4dq1JNHyOPMSKPV+kNDpO+NNJRm13RSvgfad1wbAYZ0gp461hHkfm00huhe1toSXNaRgE/gvsy4pBk9KO8052QtYekNJMOSgM5YCvtpLBGBEf7pWUC14H+eHqctOpzaC7qLp7raRdor7aGtEbJLkJOOMkrNM5knOSwFNGVORa9eOgcdwgaqUVIWBID8rRnZSpUdIPUq3KDqBXSt9l+Q3HA79QtSPYgAXuyxv3+xtegxs9JnhyiexKH/vIgtfIQyVW3DBitMO7gVbbEY+32DBQcdSmUSrRW+P2+88426inBzlZtwpxgdZoNtB9pztDdRPZjEJOKQldIyXL3uTIVQMH+81IYMKqTqkZaz1Ne1QWB5R2/gjOyMFCxHxFOmQoejtGVnGjmCvNaMq6UrYbPd4ptaCsRflA7XKj8PNZHE3GiOdrmI64dpKCbGtUhISe9xV1YImsUThnMLhjT6jJNTNPMnnxwyA3qFblplkKymoyTVxu04g2MmqrrVFoVCX69YaSG0lvbHtkWzbStmBNI3RDrjv0SkuK7QbufGY8z9IZU41gLriSsdqgjpDONwaiNUZGd0gNQ1upzNMVRltsGFHHzdhZh/Jevi9BbkStQM8dnTOqREra0bVh3EBUTUI+SqOsI5wecadn8CNOaVrcyE39Tc/5v+uXlHGabb3z9O7E+WLxaFqqOO+475n/++cXjF+IR9Hu47tHSV6pRryu3F5feTz/hO6VvC8sr6/8/OkrtyWiz08Ml5G97CwvN5brwuQcRsHlPDCMntt9JW0J3TfCacRPBm0aXmtYd8q+YHXHXGbGywd6M+T9jf3lE+ttBWV5uW5s644dA+M08P5f/jfmH/6Z5hyx/JnOX3DBYwZPrpleOmXL3HPkr7/8zmzBtsLTZeJJWfQwYbVj3YXx16wk93Kr9O0ms3HjYJiw4wmlFSUlSoeSK9u+s+6V87tH3OU9w/mCOl8I80lsprVS1ivpfqNsC23f8MbQVGDZN/bffqbVxna/E/edvCdSlBjvw9MDz+/eMUynA/9jyLkKINR28rZSt06PO7onrDecZs80zzirySmy3hvaaKaTxOrFP6RpOTNNJx6f3wvc9rRQ836wCDvL7YpShrTv1FoJw4BqjbRtpHjovw/jaOmgqkIZhe2Chino72ZUZzTGeLSXNJvSlpY7Ma7s+0ZOEeusFBeVph1RXmOd3FpqlAehtd8fqnKM7fTWBXWnNajCdnsDZfD+wqCFJFC1ln0MYLTFuUA3jpSKjMpioZSMMsAh9mtpI95vNKWoema0lni9UYqMIztwmie8NYd11csLjUomfS8Pq57ohzpEKQHVLvc31K+NfH7DO09w44GZkhemOkZK6rhltCKF51oTvRahoah6xJ0tGH8kBQ9EbevSvTEBnWRE9i0i7759/2uh9f6dSqHIYp6OC317xYwndGmMtmHPIznuKO/x0/mwF49Y71G1klujaSkOxxTp+y7A1GzF4t4VKow0aw+Ds5bbnzWH6gSmacA4C8agyNSaJRjhB7zK0LLcoK3FNCE9bMsN0+XnUdWG6ZqWEtv1RnWe7geCn5hNwM0bKUmPT+GYQ6D3TiwFX5sk8LRGO4/qcmvtpZHSRkPJzKV1apXAkD1i9aUVWpFFq7UamqFmSF2CKTJ9mFEuQJfDqTYWZxw6XNCtQUnUJABqO4zY+ZH53Y/
40wPaBin9lkTkHyA48U//8k/kbeX9uyfSvrG+LZSW+fBPP7C+Xfm87rgqJkvnwU8Tzp1kKbmvWG1wypJj4fPnX/jy9sY1ZobLMwwjbphIL698uf5GXDeePr7HO0uYvIwlQIgDvaFSZigZvcP9fmd5eYWY+OkPP3A6zQzzQFwjKd1ZXl7IWbo4qoMLJ8bHJ04/fOD84Z84vf8DzTlZwMeNvrwwzhM57ZQCy7IRS+Th8QFdI7d9577d+LquPC2Ry+WBWCp7KVQ0rXVJkHpoaKbLI+7xIy5MtJIk5pwKy+2VLVaU9YT5wvDDB/x0pimHagrVKzUnGa9eX0lvX3AKTo8PzD4Q1sivv/7K9fUq40OtZXRgHN06WSxjKXvh9eWNv/7lZ5ZNkmvv3j9KL6VkSocpJx7PT7gwUZombzsxJRqddpharRHn1MPjI/148QzzCRMGpjjy+uUT903wRq1W6dJ8G7eVQtr37zZR8VeIm8goI0w56/BOxohGy4K6tCqnSyMPoKaPkZVqlFqke2vtd7eURG+14Ha0g1oxGFreiHE7VCAK571IA1MW5fhBWchxR5sFM+/UQXZUNWdqitRtReUdtMKFkTEMoi/PlcE5oLHHlT0nShVW4x4Tad/p2pBihH4kSlujqs7DPENv7PvC7e3rYc+VcSCtUZMs9VPK6MMdVKoWOn6OzPOImfg+StNKMXgv4/TWxNmllMTnuyhSehO6R2gN3SXA03SRsA2K3kQoSe+4g7iuDlCqtZKmbErMt+pIjCU6Je2HGkQznHbCMDI9TDB9JC0rtWu0HyhdiwE67ejWMFpi8y1HyroQb1dmpyCLW8lpT7Ae5azc1orcvGWaYWi90EvFqo7i+OylCE3SnyatNDo17pRS8MbRmxRxB6exVjNMIykn9m2jZ8Xp+Rl7fsTZgfP5xn77yr7vx3NIo+23UaMw+7yXCoNRHa2MlMCNkZ+hWsg1U1KiVUF0KWcFBntUO1rOQjGvwvkz7oAPN6SE26DkQikJZwfG+UKYzqjWWF++kEqimYYbJsbzE/PDM26cj4OOpGHbP8JNipoJRvHLX37l7fWKU5bzU2A8n0ArwvlER2Os5RwM2juGeeDtfue6b4zDQO6NP//7z/z7X36mW40/BfzloJ1nGT/404UwBz78+P4AWipiTkzTzPO7D8RtJ20LJoop+O33TyzLncF49uvC44dEvX8h3xfaIpr3PUtMcwyO6fyB5//2f3J69w5zfkczRvxNjw/cHt9zvb+wbTvwbSzhOfsTTw9nWt7Ybm/8/usvfHq58RY70+c3ujGEeWK+PGJ9ODhhmZgrs3UM44gJAZomLhOrsuSuKU3cL0ob+mHazHkTU22H3jLp/kra79DrcSPQGO8pFTF8Wo05CMpFFZo2zJczl6dn5suFZdl4u628XDdSLexZKNzT9A7nB1yQkZAxsi97u91Ju5z6lQ0inytFTrFGoZGbBYe6uhbRE1xf3khxF++PC6gqfDcFpBjZt40QJPAg8xYtdAojqgNtJIJsvcM5RzwsqkprupYuVDla+rQKpQpdQXVqabSeUKajnAdnpHSJoeuF1iJ930S3fqSqTOvsKVNbFxVHEdJ0u98w8w3tRxn/xI2yL+TlTlqvWGu5PBmskkjG4I4RXk7kEjG9Mp0Gwkn2caUIsikeD6lW2+HJkpescYZWihD+exH3Ve9smyTO1jeJTlsnjw9jPM4FvA8YLeEklEI7d3z2jq5N54D+Whm1KU1rhc4OfaXUTN8LaEPTHqwnOCs3p7TJ7sd6QEvMXyuJUfcqKCk4EpGN1jNKSUGY3gm9MJnCME8YO1POmVKk4H+7Xtlvt6OjpxjGCUOlVkfNGzmt3Lci05LgJT1YE3XLpCjJWK013Qeali6jrnIjLi2z3hfBSwG8vQlwzDjSEWJAK2xt0CuD8/jRyxi7Z5T3mPGR6fFH7DxjasHVRN8DYOhaf1exKAW5yL5QdTBdQiDOezEUH3G81pogq2qh5YL
VmtYaKRbwcuOnNaFE2C7wWKPpTVNTpZWVbiT44qxQ34d5xo8jdUt0tPT3/IwfZ4Zxkn6V0JEAmSI4N/5Nj/m/65fU7effMKry+brxtkQeH594aIptkQ/Ew8MD+57Z4073ssRT1jE/XHiXK+M0UqfAby8Ln64rl8uZgBUlxrISYyZeV0wpPJ9nxnE4LKWKh/MJ5yVOfH155esvheV6QzX4+O4B+9M7BmNlObzdqQrqmsj3yDCeuC4LwWnsMHH5r3/i+f/4v8Q6agWv37/xwoBtj4eJdsJNE85adInMk2caT8SnExX45Zffua47P//8hraKf/qXf2Z66AyjQytYbwvbdpco7DTLKKg3dJW+lBkH9vuN6/2K+/qZx/mMlnWOYGaazPjTl8/Et1dsTuSs2OyCq4oUE6p3pmkiJynlBmfJMcqyVcG+rcRtwRl4ejxjnHRQglWYXqWPpTvkyPb6gppHjDXklOkYwjhigiItd3JOqP1Y8JZMz4VNaxqN29ev3K83xmkUNFPvogao9T9VE9YSvMc7Jzw6RB4oqCT7/VTcWsU5WTyXbGhdiuGtR1GVb5skKrWiKis6iZzp6ltKcMIqi9FBtBW9Y7XElOs3q+mhrRd/UqAZS9WJ3itb2lD3rzJeUkb4bPudvC3EdWEYAmlbKVG08C0n0r6JQE+DHwLKeyERWM3yeiPuIkakg0bGVUYJyVsrsF5o5hxx/1oKqWS0UXgnBO3aJPE6DtL5GscBYzQtbTLqU/2wDXda64x+xDqDGyaUH2jaSfm0JNhvtPWFkgUhhTFYJzqJXnYUsnuTXZ1QEjDHDahX6Wv1frSZRBMyjQEfAsYo5uDl4NQ6VnnsOFBrpd6vbPcb16+foYMLFqsRkWHzWC0/PzY84IYBa8z3kFXKu9BFjp7btq40ZK9H7Uf3LbNtu0wWjv0jXVN9IDy/I/gTW+2c8krXcugyzqC1IlgD/nBilYRNlhwX8v0qjL1SiaUxTBNDcNRWUdpinf1P1YniGAMmcspH0lThtJOu2JGizCmCG9HaU7SnG3OMTbuMKouQ4NvRL8NqdBjFh3ZMJtZlZb8t7DEJq9FYlHaiQSmZXCrdGVSYJFpv/gFuUsuy83TxIqHrirfrG2evmKcBqNQoCKB927E5UGsWyRdafgCVQofAPb2ivadWxcvnO/uacc5yfX1DU/nh3RNOa9J9heDRTXM5n5nPZ1rvXL80rtcrn375lXka+dOf/sCHj+/oOfHy5crry5XHpwfoFWMde4oYC+6gVr9/fmQaLcqBlQo499c33n77le3LZ5x2lGVDz5rSGsv1FZ8i96CYwg+EMBKmma6+sG8buXSmMUCvMpe3hn0RE3BvieWrnNiVdQe2XzF6S/ee4XKh+MBaCvrLF/K+iuOITuuF7XYjfv6ErVImtVZuGTlntnUl7htKGXIRU+c8n5geHmnakhq8ff1KvL7ScmH2itNlYJonvDP0WtC9kJaNe1pZNfTnZy5PjwRrqdodJlFHzYm6b+zrgkLTW2fZI/e3F0BAr8MYmC8PxJTI60bO6bs23RhDCEHSeVai0EYLV08oD5lGRWt1aAnqcauMxBiFPJ8r6VDNn08TYRhIbaOUbw8DzWA0rQqDrayiUhi9xaqGMgZnRzpKMDm1UHOhKkVVB3XAB3reifuKun7BKHW8kDb6Eb3vveKcxXvpILUUyXGj7HJTNXkkzDM6OLxSbIdKpAPjNEKFdVnIuaBbZxoHHp8eGcaJWiqpNWrLGBe4PA/Mp0f2dWO5vmGtZR4C4zgwDAHoKBVIOctfDUzv0MA2hJ6gFSYMGD/LGLQUcJZYk6gKrcO4AR9GrNMUGgZoNQruKIsaBKSGYIzGONlxmm8OKiU3y4KoKpwfaE1RsehusUaL/qMUVC0MwTKOE2EcGAYhgWh9vCzGEfv4I24a5VC3r5S4kGqhZnlxruvKcr8LR9Eaem3sW6TVJj6276PfLkVZPzBNJ8I0sd5uLC+vOF3RmyWP0uX
bj89cvb+xLAs6eFqKlJjYa2eJmVwVj8+gtOhllEJsBR1i3GXv18QZpgFrLM576UNpA0oT942mZIqkuggi8UG+N73T40rL5Ts+qvZKr5ByRtmAdZ2UK7krShQzhQ3ioGpdRJ81CV1fZylrK2vpafubnvN/1y+pz/vK08eZ6TSzp4WUN76+vrFuG/Mc0MiJ1xtLMI7b2xUXAtM00mrj9npl3Arp6xu2w9vrG9fbgg+eh3lkcJ2nd2fm80gtjfW2YGOi1IptYLti33dun7+Slp3eO8MYcFZTtoXb9c7//R+/07XH2MB8GWhqp1wjtsN6uxFy5/b7rwwPz9iHQjdOrtulUO6it399+SqKjuAZhoFhHFBp5+XTJ5yxYDX3Y4xjjeXDjwPPzw/MpxHTFb/8+y8s15V3TxOPjw+8bhv/+j/+B+CYxpHx5Hk4P+DCwPz+R+YwoI0l7wvbeuP+8gI5oTvSA+qFMHjmccAEh50msX4qzcvrK7UeN49hYLhcePeHP2HDiZQy199/5WuttJSoHMRvqxjmiRTTUXrspFSwRhNzZllXgvMYrb4z9rq2uDBgdKcUcVWlHNnWO9ZK58gPE9oZVBFoq9YH2iYLuLU3ietrJW37eiT6Uko0ZN/lvDALY9xYt5WUZDe275kcxbbr3LEPMBLRFUOhjCT3qEi1YoyVUaICxYAeJrTzOC2x9LRH4sE/LK1LB6cZVBPzqUND2miqQ9noRQjqyhhqb+xxo/cmMfxegEauldt9RcWI3WQEB4o9VnptTKeJ+XSSyHPOtJhoueAfpFfV+dbrkgPN4D3OyYvQXhe8ltuUNtBapjbZOznk96W6FJ+NdUcvJ1GKx9Hk6+UcTRlsyUc0HxkZ+oCxHlmxNDSNdMS9jZMHfU1RvF7KMIaBeRolbILsqvJB4bYUWoeiHM1JQrBoxVAyZd+gVOZpYBic0CqswRgnNzMlliTjPWZ6xs4TLe/E+05eEzlmlBkxBmzT2CydTa0UzVh6rjQq1gX8OOKHQcIe2tK0xlDQeWWgkFSjt8K2J3LLuAZNO5Sx1D1yX17JStiRwQfyIbEM00k+58oeCvgqqcre2G83+n7HWZExDgeHr3MkHxHQrTpu0HFb2aqiuw07SV9Rg9gS9pVcpJRunNDNc6rorjGqEMudrVQJVwyD7LZaE7N13AX6mxJRdfy2YJ0jLf8AL6lSCl1VTk8TsRamZHBOQ5fwZkmNwTmc6SjTUHTinvAucD5daKVC68zTTHq7U0oiOEMw8P5p4HIamUfP5B25J7Z9xdZIaZ3fPyXuy8IWd14+v/Hydsd7x+U0oenc36788vvCL79fuZwnWnvAWc988fRu+fL5znorxP2K+/RXzu+emaeAGQa0cYRhYp4nfiezbjumClbm/ccfMN5xtT9z++u/8fX1DTuMVBTTeWSaO2EMDFPAGsvLlzd+/uuv5FQZgpaC6pagWymuKkPaCne9cQ4DTw8PuHE6OH0XXr58Zvn6ypffv5CWhYeHM+Psqb1ReoWqUC0zjhO9Gx4eHqi9M8xnwnxiOD3w7uMfcacHWuucxgGbEzWubHElxh2apOqMdSjj0K1SXSOMHjee0DagNNSaUWmXGCwNZTXaKGzQNDTu6LlpmXGglYNuhCzeHUZPaKXYlhu6Hw+T1g7mWj9SSYmcDwaeuBnAKHKRm0EpjbTnwyckNtZGJ9WKSgVtATr6ELzVVOWlRRFhoNVi/Y0r3lq0c0JLKJkcd/ZNfnBr2mlNDgVWK2w3UBTOO9wwAEq6Rp3j6xcxvRHMTEPTlKOPZ7QW5lur5QDNRujyYPfDBNbhlObMiT435mHgMo/SI6rCjLR2FLV7idjeZBGvFdFpelU0beiI0Xdf7+TWZc9mBYiKFldaq1k8QlYUE7p3VI0Qb6T1Kgp2rdC9ow9DbW98v/0qY0FLj80YSysJTWcIGmeVvAyVEWZjzejgMWqgA3s34ALJTMSu6UUqILUXKflnoFR6TWSVaMaikNI2ytN
zxucCWW7pGMc4X9DzEz4Ehn0jnK6UuAjBgU4YT8RUsX5gOJ0kxt+6UEJapbWIWjcG3QnngVqs8PCcR1mPtVYOZNbgxkESDcaKILJ3fNNMpwfmaYKSqUUmQK0XekpQ8+GNCtgQMG7AHGqOb0Eo4weUCaRU2JaF+Hqlqs5wfqA9vhPlx75AvH0fR2vER2a9yDubquSYqKUJYSV3eqzULp+DDOT1jbIv9FxFXWMta4x/03P+7/olZazh3ccfGOcz6N9Y3+6MzmKMIsZErI3rtvD+WYIU0zSzrRvaWB4uZ9yBxsEZmoGHpzPBWdK+8vRwxluJXO/3RUaDWuOcJHtSEWVAKsKQ014zjg7nLd0YduvYlcKMA8NpJIRAq4ppPtFyx3jDmgVrc6qFmnfqcqcbR7PyMDjPgfPlxKe//kZvHWM883zBeIN+fMLEVSLYfuBpmnHBs6/SBeq18+X1C8ttYZg8ymbeljtb3NlKR7kTw/mCcZZaCtP5gYfHJxl/KY1TBu0Dp9OF5/fv0b2yvr0xBIv3R6KqyYOkpEzS8t99fjiTa5Xb1TiKBoBCjyu9FOp6hxqxujMHz3kSmK12Hm0HIZvvOzY4TqeZeZJYdG9dovL7LlHhXkkpUrWkNoP3MASqlReBMRqFwRg5zY/niTA80JpGu8/SNdGHCcM4tNNQHNEYigOUoiKOJKMNfhwwzlFywVhPfXvDOkfwQUCcRpNyoezpICfIgxStMUeJN8ZESY3YO6M16HFi9A6nNap3lAJ3OKJKEStwa5lYGv0qoYZpPjFMM84PUryuVfA17TDXIkDbpi1+mlFzww1B4vbLjbQuUDs4SzVCCZjmE/YCLScs8hIoWUZvRoH2M1ZrKbAeo79SK+gBf6jRjbOUEln3Ss8SsVcNrHayZFealgptj7g94r1EydO2sH79RN4XDA3v3JFMU6JsP6SK5lBzYERFVMnkVmgpkjYx0eoAxgaohZazXByP3YiuopMoahc1R1fEnMn7HZUzJYN2hmGYMV54jp2OPgrefflMTPIi7WlDdVGue6MkSVkLyTuUGlHBk2vFj4bx6PP5IaC1oqWEVkZI4VleFFqJ8aV1uUGHccR5qTbo3rHOo2M+dupWXGE5y8jagi6RtN7loDKNeO+Ayj1GSqmEQdJ/Lnisc6KHV0KYaAdgds+F6+3OstwwVijnvVbKMODo2K4wLnzHK2krt82D4IU15tjrdvK+SD9QdVG3bAu1RtK2QpGwlQ+BWOrf9Jz/u35JKePwYZJToTXcq8xFzeBJsZByI+0JrGWcJlKFf//5/+ahNs6XB5rS7HHndLlwebyw75uc+KLHoliuV1pOjMNACNI+32IRtL4S9XspFesd4zTQW+N621HTGfv+mffhkfHyilcdbR2vr1fm2jG6M55Gkn5jjRX3ckf/+d95+P3K08f3jJeJeR5xujFPE+NgqfYbFxDIGd0Ll4eLzMUR7phR4kpyXXF9u3K/3ZjPJ57fP6M0fH298/r7F1LRzA9n3v3wR8I8oHrlcj7x+PAgu5kjLVQPUsLju2fGwbI9nqHXo9hYMV3I0/IA3g9eWCPlQk8FP2uU1ex5h2UhrxvX338h7XeskRGZdVIg1Nbih4BxThhfCpyT5FMqRWK2xsmOZN+gVVLcoYnQLUxZyoUdWtekmDBdgVUYp/H2hPUBYzw5RVLc6aofL39JG+UiJ/3pLGDafb3BHuXGYjQmWNnBeU8YRrZt+04jcE4oHi0ic/iGdGtKPcq6oo1HS3ZCtW8PYkU9pHQK8N4L6FPIevJQOeLXKH3cJvT36LzRBoVCVxmPai0tfmW1PPhA9B7OQQioWkRzPoySIp1PXC4Xecje3mSMdthUUbK761roG85oWpb9xqwt0ylAkCV7zen4Wg6kjnSEtBxmoFMP8G9dFuA32nZD00nbyr7dpPBsNb3JSMoaLWGXA9yrjXTEvruiUpbd4P2G6plhntFjw1hBKPX9jgaGUXaFvYh2I9UN7CD0cis
lXG8VbXTo8YQfT0K5N16qG/udvLxRly/kLvUCpw2dArUT375S1zupVYwW8r02Gt+FqtAP51KrQhBxqskDHQ3GHP+7dMW0k2CQCAsN6MMiUJuQL4xQMnIpUKukaEsUo/T9JoBpJ64pfbz8tLVU5HtpncMaQ64JehEyUYVUKqV1tLH4MKC1wtBReUXpivIjZhi/p11R6sBWlcNJZvDeE7RQYVKvNHNE8Y/QTD32a9+CQW6c/lNp87/49Xf9kkpb5t/+5185j57tvkKuLDnz8voKXbNtEaO7NNOVnHbH04ncOk1rzk/PxF3oA9Ya+rqCUgxhhFoYhoHrtgnRwA+sW+S2SOR0nOfjmyUnIIxnWRde76+00zPv3p8Yz170C6XS0fz26TPjsvHjjx+ZLheUf+H69sbyb5/49OmF5/PMj3/4yMO7C+eHM6fTmVwb5/dPGO9wDxOxZkzOpLRjaxJf0HEyaqUxGCFHOx/wQ+bp+Yk//ekPOK85fb5BaWxJ8/TjT/z4X/6F8TJjW4FeCc5+XzzXKm6ZvRaGecSohg8G3QVzpA9NQk6JnrKcsJwQoq2D+ek9l+f3uEmIBbHcvnt3hvlCGAchKadIroXRjYwPz3Lz2GTsVZUhd0PrCtuRnVSTsYbREvstubEuG6U1vAmHWE9TtsxyX6X3YjvjKXK6dMI0oA59ed131m2TnZITsWLXBqstw3TCGk3mRs7xGHVUFPJCkpdLFjrH8cLQ2jLNZ3qrtF4F6dNFve1NENSSgV4Kh19WRIRxI6VIRySDDSNqFvj+fa2lHGJEL4JII2I+KQgX+r6iqogbndUY7ch0ao7EZZeRXM1C8zYWN0iBdTqdmeeZfVvRx8Gk9f79Z6zXBjVLCq9WSj4suFrJfs0qmtHk3GgUlNM4N8pLDnDWMDhL1ZXWHSkl1tfPpDdkpNck2GCdldBHK9SaKOVAIx3k7nrAcKmNlqXM+82gy0HoOH4YhRqy75QUyTnR2gWMF/p5zpS4k5TB6cY0jnhlaMpjwoQJw4GvqkJiMBat5EbbSkV1h/aBWmWPl9P1YNc5Tk/PhPmECZ6OZ1tX0r4dtAkxFLdaqb1irUM7J6qU2OnKMJ5mea4gL2ilZce3JynEay1v6VYSCkUwGtsrNW9405nHQBis7KPSjrfyvUALrb4fKpwYZQzY0NSmybXRUbI3GwfRyFgtBmMnKK7aFQ0lHj0lab5OFzmpdWgjEfdWAWOxwdCdx4cTxgZyXKhxkVj8OBHGEf2P4JN6u2389//+Z97NjvM8gnbE0vn9ywuPl2dSykwD0BvrnvDDzH/73/9Pvl7fyK3xeHlmHCdSjlijuZwu3N9uxJo4neREAcKv2mLi5e3KfVkJ4yDdgFEWofctgdIwjILOL5X4emdf3mj7yuQHmpZTy54KsXaMc1weR9ZYUdrxww/PfPj4xPkyo41mT41+j1RlCQ8PuGCoJJa3z6ha2dc7/b4ICsYovHOUXGWefBl4eH7H/PiIM7DvO0OYeJgGfvzpA8mcefzT/8bjDz8SBk9LK6plTJPlLV2gkl1mAXICboI7sVpBzqjeZH+g9H+aa43ndD7x7uGZ8/NH3DAJrTxG6VfYG3oa8WrAhUBpnYzoRYyfCMOEHQbC3Lh+fWG9L3xTDbReyPtO3Des1fjhjOmNYbyQj9GLak0K0EZhkHm9s4GSN9K28+X6b7hJnFumdxoSZ97TDioKO9F6rJcgiGpZRpFVkfMuqeejlGuMYhg8ALVUKanWyulywbhZHib9KHH2JglEuiy8U6IpcFmcSSlG2sFMtF2UJsMoD3oxzcp4RxvpKyljCOPEMJ8lNbZv5FYpe8FoGYM36wlGdPXrugmxWjuGeWa6PGPDSCmyd0rbTt4iNWVaa/Suv7uEaI10//qdhpFyoSnROaSy0O8aZWSEZKgYq9AN9k3IHk519HHFrbXKiyPuxCagYWO1oKa0FHtb61IybvJ9tMqjlQUaVGglU4sICZUxuCFgu3TygAP
704m5sa0b27YTc8EOM005Uaq0jgsjl4cLwT9ilJAVSi1SqF6vlNKo/kFebv6M0gFyIZdCyTJObDnSDueTNQqjOkZrhjDRu6bEndiFiWmsojdNzBulCNrKKPm+1lox1jNOM94F2Xkd54SSdtZ1IcfIMI3HaNXLSN4ZWtyhZ6wF6+TrkPfEvm4UFMoL33FfN0pKaKOPG7zUZ3pt6Ao2iKLHeysxe7oUwJWSn9O4Ylo4ILQS1mq0w3MnL63aGxWF9pP8HM0PhOGMNZa0vVFXh2kNFQas9+T+//1E///96+/6JbWkRneyO8FYtlx5u65YJ9fl02nm4WzQuvH15RVtNx7eCYGia0MqmRpXco4o58lxI60r8+nEniO9FObLmeV6J8XMdJrhSA0BxBhRWhEGOV2FVikpQ4usn3+Wh3POJFZCGJlPZ/QQRHntOh8+zJIAy5bnDw88/vSecZxxdsAaL7c7Gttyo9Q7LW2QCtsSue8bNRU5nWtN3MQ9RVfMvXC5PAGKLUaub1d6iTg053minp6Znx4x1sniXSt0P5r/rVKqnNrjupOT0A9aB+88IhxXpH0RWoY7xiKloi2M48Dp/EAYR5rWMtZqAJ3Si4zYeielKA9qJ214VQtxEaaaNQ5KYr9fqc4yjAF6Ix/l1NoNtlW09QynM4ORoMH2+kbaNhwwzAPz5SPKjsTtitrfiF+/kmqhN4v1no5iPouaoeQCTlEbLLdXakkM3qKc2GvzLqM9748lvu5yK2nlOK13IUNrmE+zjA9zYlMI+qcrkfh1qN8Yfb3LQ87IIjlHEcl5YxnnWcqgWQ4MJR+uq5xoSklJ1nuJVZdK2nfSutCKoRpH0jvaB7SyuOGMPuLPbhwZTw/HZzgJxLV1+Qwd45fe20EDUORUKPEqcF317TBiUHRSyvTcsUqkg82I5qSlnV6yGHLTzppFyIiStFdrUqk+PhZywFNSDamtiGzx+FrmlDHOYbWGAjknqhLhYGtCNTHCpaCUxBaFl9lrlYRqSgJv3oVHGaaBy8OF6eHC/PgARtO7eJD22yvptpHjRi0N/CPdn1A2oFoh9EaOK2V5FYdX7+QjIeqUxsVENyutH/vIJh4yFALArY1qvahIjKXkQtzlhiQGhiaQWy07vFyK3KB7ZzhK58ZZDEKbN6oRu9xqlVHU3ogpiS2hNVwYqEo4hrlEWrWEQQ6IWlug01TDorFa472kh63qpJLJ+8HyqwVVMnSpZhgrRV6lLL0LJaZWSceqI16v/YAdRjBC6bBGYazCVMglgmr0/A8QnLBaoqxjGPHesNYrW06cpol12zlNA9Z6gvFc3+4s+ytvt1c+/vCB2b+HPVJSYXm787pG1vt6zI4b65aYvME9atwYqBreXWbOKbFtO3vMhGFCGVDWMo4zed9FIbHt3FPEGc++biit8KNnfhhk1t0rTlvO05ny1Pn6FqEq1reFYTrRnZWCXCuEYLGXiRw1a+68vPzO29sV6x2pFB6GM01pXt9ubNtGr53SEkY5QQilRF4sb5+vfPz4gfPHj/THj1jngSYUhwOT0sthpo07OUdakcVrP5b6ORecER14U5ZKPRbaitYyBhlhKCtYGZVXiAm1buTrV1rakWm8khuPUkJ2CE7CBSmSt/XwBUmnqdIgCMlaaYMLJ9Tx0NY1UvyId9/o1IbSdlS8ERgx+x1tIy3v5JppXmHaNySPbH20liWztdC7lq6tlvY+Wstpuci+hd7JtaKplCShhdJEWZ62nTAMaK0JQZJZUCnBoXFYJeXImgvOyEOPozRsjCIoQzcSq/bWELwXSn2u9JSOvlYkdwHPpmXBH8bf7eV3Xn/5GUWhziNLrdSu0OFEmB8lIemDFJm7JqMwvWFVwx1c19SFWUcttCal597lcdjViPaWboT0YK3FOUNogV1n8I7eKqrXgxWYURSMUcehR14mPnjGMBzjQ9AYYiqH2UOSdvWgxqMz1mphDaZEPsZVKefv5W9aprbC2jq1NnKFYhz
aODCe3hR6HGhWRv1WG8Zx5PL+ienhmTDO4gZrFZ0zOg2025XUOl1ZcSK1JNTw77F0Q8/iYOomwrHja7mw3W/y83fswbUGK4BF+TwZjRsc7dDMiA4lHCgtgUIr3cWlRqWmnRgXXJA/k1aakrIwNDXk3inIjsdoqF0KxN1YpsskxAglKdAtJklZ+gnrvETQm/AOrTX4MAiFRsvurtYm/CkOUroX6Kz3Mg5tNLQRlFJphT3vKOtQ1aG3jKtI9WZoaFVpbYcDFty7MATj+g/wkvJG8TiP9Fp4u20sOTPNgvCPKfF2vTL5E+H9iaezp8TIL//+hbIqbDWEHy3BOW618cvPv6CUwTjF8vsnelVsWtF64fx0kd1Aq0caTDMMnqd3j2jvWJeNvO3o0mh7kn/ukPOO0Vr2S+cTyg8oLy+p3qrsKkLAusJ9WbktN5Zl4eHxgSkM0BXD+cTjD+9x05nSDfvXN2614psRsCuIs6gkVK+EYHHWELcNd9wCrre7jDjOD7iuGK2jl0IpBaMVrQrPi9ZoKEqDPRbKtkoJUGvoMvJLWnYEOSWcanRn+Wb+U0oJ1r8W9revLNevrK9f2ZeV19uVGCPTNKGHAW1k2Wp9QB2jmp7l9yAvD4334t/J+05WoI1hCAPOWmoVgGh6rTQnp1NdCipXUsns1zuf//rLwfZ7wFhNTRm6ktJsk7FRrzJSwxjZuxzF01Yy7ZvWHI33A7U3WleUIg/F1qSYWV2HvR+oHk3vjZKFdResOzBT/nhQa4le94ruBd0l1mud2IhL7zhr8VZjB0fWhj0l+b0ZD0aYdPREvF9pObLdXugtEcYJM87YDqpJPFh7i7LQVEWpTs47fjki0HGXdKQfoDUGIJZCLInUGtpLudU+/CCfUw26FUxNtBzZdaLrTtED2oKlYJvFO0PMI/W4GfleGb1lDOJFU10ejjl1St1oIOT5wxrbW6N2CSZYF7DO0hpsMRKLJP2UkrEiTcmhCQnpOO8xPsgDonUp0jqH0kDJOCsOLUtD10Srsmtr3/xdVf5Oa7T9Ri+J3JHovPPH6LbJ99RYwnmSF9ARkIglY7UmOCuhj9Ywzsqet0uxuffjOaI0HE4yycXI577XLB6odUXVgndeAhM5iYbHeax3smvyGpqnV3lhd2txwdG7ElU7nW4NRks3NNeKo4HppFTpXTGMXlQywLZsbPv6PeRQq/xejTUE7449q6hq6GC1dLAanlQr+7YiN3Tw652npwvzacBZkXHmosgxH7f/5W96zv99v6SswavGtq+k2rntiWkwTKeRXkWKqJRcnZ9mw6Cfsc3zcr3zr//vfyPFzD//0w9QK+fLmdwUVTVOpwHVFGVbiaUylkqNmZw3tpwI3uOcpddCq8KFo8B6f2NfVqiCQNFaMYwD82k+0kcnmrXUElFNFu5bFBX6si3EbWO9v2HySrEeHU7Y+ULEMQ0j4UlzXnfygTsZjCI4h+6N91ajqpxm7RGXP51OoDXblkR7cHlPcYE9JcouSu9pCILwb+1YzIp2RPsAKVHrERroVR6i6tCq94ayBqUN5egQdXb2+w11BCK2188s1xf2bWO/3Yl7wvSKdxrtJ4ZB+huVYyRxHN60aaKgyIm07Sxxp9KY5xl96uhu5YcybsTtBWWk2Is2pG0XmkKrbOsmi14tltltW9ljxHnHOI1Ya4QwUTIaf3zPkOV9bmJXNlLmdW7AKsilkksR4jhicnXOUgfPMAZ88DLHP2ysmoM2rRJa2QOIKjsf26GmLG4vlPT7WkFRUaodD1RLXIUrqJ3D6ZHar7heaGmlx5XJNubnC34+k/2IMp5cYF9XVO+0EqlZTv29VPqucNYcJlqHRsto1nrUYQ8OzuNPD4T5Af3wE4P3mLKjtxt67+S043qEUolNCODGNILq8s92oGsPtRFU42GyOMqhvTeklCg1ko/uDb1hlZN4Mx3d7HGjMmjjjqSYfK6tPbpvqIM63g8Mk5JxvJUbhFYKrQ7ag2qUVul
pJ69XslGoEmRiUSstF+K20pQSG3FtQmqISTBXtVNcEHL9cfvQZsL4cNQHOv3wMBnrUF1cYq1XvHYSDGmdXASjVGujHyx7bQxYLd8LdaQQlyv7csdbC84Jk/BIkhrnvpf+FdBKEmdTF2mnNlbcW7XJ6FYJ81PMx6LZSkm6idaN8jPSIa0ry+2NWCo2BLTqoippjTAKT7MrGQF2gCZBjt46qjV6LchZt8ozYSsEnfD6LJOBLvtu2XPLX3/Lr7/rl9TJO5wx+DEQjMNsid4LzsoP/OAd3mp6yTjvefxx5KcfH/jv/+Pf+MvPV/7j33/j8RSwRmGd4b7s6BAYzme0MmRrGJ3F2sCyvhLvK8ZavAuo2lmuC+f3Zy7vn2m9kWqlLXfpDqEYncN5h/deFpqtse/iF7IKYbtZzek84kuHXlAUoRQD09OIe3hE6YGmAtopHp8/EowVzIhSTPOE0dDiju2NuN5p+4J3nmEcacpipwfcdGJ4+oi2sotL+04vBdpIGGaclxGVkMKr4IemiTAEkel1+SFOMWK0wQ2eaZAHd+8d7aQ78/r1i6CKFJQokX7rLN55epX5dWsVa0RqJ7MeRUPTtcRlFXIjoQrWadtWuhLifNkVZZOAR61VuHreoYeAtoGiDBqF1w43XSSF1TW5KPaq2TJUa/AmMEwzVkNbNzQKqwFE69JbpzeNdhIMMc7SQV7Ee2S5bygNYZTxRxgGhmlkGALGSEpO/niFVoWo0A7SQO5N2HBGQ2vkKrfaVitWy2jKHH+pfiTCUGhrCZM4ltp2Z79/pW8L82gIw4gKA10bqtKUmIjXN8r9JpgeJM4dfKA/f8A/PhKGE8YPGOsFiTS9QdzwvaGMw/gJYwd0zQylo2pC9YgxBT90AlKc/bzciWslWgiTQ4cB7Uasmagl403HBwN5Obp1SDDhiP9rY/HWYb1HGxkfm97Qx231GxjVWoNWQqAwymKswx2l21xlVF33hDZH8MdompKbT1cN1YscXm53eq3fbw81fxN9drABO2t0kdsVtWB7oxfZB7ciJmdt5L+vlfpOZHdDwPqBXBstJgD5sxgjTMPWjp0ulNaOtKuEqKyX2xZN/hwp7eSScIf+QxkF3Yrt2zjZkxkjz5qGmJ21QzXNHjf2fQfEXZZrwzrPeZ6ZjKHGSFw3kXEa0L2Qt0jcNlFvaDFq1yr1iaY0TVmaO0mfUf0n9b7EHVJiWUUDY7XkAZzt6Cql9fW1UIOQ8GuVn3mtJHj1t/z6u35JTcHw8acfePzxI7frlbpuDHMgLhv7TRw0g5PkTaxg4oZSmQ/vZzAjX1933q6vPD5fCNOASZUwT4RplsBDbczzQcg+TvtDCJzmEZyj2JHnn/7E+O4jxmthcrXO7csX0Sz0Rts2tPe0LXLbM2acOD0+0g1QRfF+HjwxF+gDe9xIteGGgXB6oGiF0hYwpCI7JEOnH94jMwWstqw5kVOCIotUYw+YpbMMlyfc6RE7nVC9UWlYqylNwh9aG5wXHIwA/sVVpIzQq6vqOOMJk+CGeqlyytcG6Fjvmc0jpTbivrPtu+xQGmg3MIYJMz5Iq329U3JlXzb2PX0nZWMCygSsG6BJIRTV0VodtO2OMep4uWlKOZa388BwvhBOZ7S2+H1j8w6nnURxo6TMaoWmLcYNKO3pRsyzQxhQbkeVhO5HqboWAXV6YdGBxLLbEXWuJaN0xxyBB2ctxhjGYcRqTS9FFvlVuly9fStPVkH0pMhSMt45xmlCKS1Ju9YJgyMEuWHCIc2MO1ULZqgeQNZ9vbMum+Cq3EhvGlUhlsptX1i+Xvn66yfZITmLcXLj9w+eMJ6YHz8wnJ8wfkBrRy+N/c3Sr19J+4ZqlbYvlLJg8opyjuANwTUssnfSFCYiLt4pudO8o44eY4cDhVSF/qAUnSApOVZKWkjHrtNbi/MWbyzmoO9jxNlktD6+d4eS3GjRndDRLqB8AOflMxojOYp+Qvd+uJSsYJzsN/VKQCP7nL1UYTm
WRM0Fffi/UBI500YDVl6C1mG0IuVESTImTl0KrMMYMMcO02iN8h5nNMUZvHb01uRFmKSuEeMO6IOmUWkIr/NbGEv2nHKbHWchxDSl6IDzA2E+UbrchFpJtG9w2JJRVhiWMUUxQiv5+vVaZe9Ep8Z4jI9FHNkVyN3d0G1AI8+NUquAkP2AHmbceMae32G99EFz3On5CA2tK/u6Sp/OG/l+Woe2ClphXzbyvktPy1m8HWm6Ed0/wE1qHkdsGBkeznQK2nbG04R5PNNLYVtXpsEQvEYPMypn3r4uoD1a7wyTOqCc4L3jdD4zPz4xn86UnOi5cbmc6SVxuTwwuYAfPDp4unfMj+95+PGPmPkBpToPTwtv08+sLy/0KoQBVGc8NRSGuK7YDlsYGQaPV3KSNlSCARU8vRY6GqUDSllZymoFztCjYs+RfL9Rbq+0nFD5zjhPrNcr2+2KRTGdJug7xnmG8QE/nXHzBT9M9JqoLYI1gId+XM9Twhy4f2eOnk6t8jAphXrEzY12NOHtk2uhFfnLHjoMjn+H+lbGPPYCwziJffjtlf3lC+uyUnvFhYHprLCDRGKxhpJWKf/RcU6YZNCxWn8veqKElqDCjB3OKCPYf60y1Ma231G6izajZ2KSMmkvhaahFeH3fdNy9FbEKos6kkqZqrTQIIyD7y+pgui+J5z3h4bAYY0IDHtpNAXGaHCW3ix0ednrgwfXWifGRO8ZbZKATI0lDPL30iBm+T0v1xv7tso+s8r3QjWpRRg/4IdAdY6lOyiaW4p8/fyF5fcX9vudMHhqLzgVGOYzxgd0qdhSsaWgSGgljEu17eTbjbheUaoeyKiCWq5wJMxmb1EtQckkYN0ieYm4YcaEgDkkgiiFTglTI8YI6gulaCodRlYtfTMUTiuhTWiNC14etlXUMDHucutvh3RRi7q8u4HqRhg8rXcB0x7JwLrvok+pmdA1ftASVrEBrb10rtKdljcoCrqYjnOuqH5AYWUAi9FiCE7FyqhUyeeylUIvRaYM1qAUbK1Se2M+X+RWVQ0lRXot5CxKlxzlYW2OVKFSCu+dROi1ois5lPrjlq+MpE+NNsdhzhLXjbTd5ZCn1Hfrb2sSNFHWYJWXvdIx0nZGE/OOrgbt5TZWOyQlmCoTPKY24raSV5Gjai29KDfM2FFuUcY6OTRl+fqqvMvY0EjqUyuBm2sEYKvRqCZuqnboW7zz1FqPUvP/+tff9UvKj4ESF+6//Mo8BMw4MY0D0zwzTgMvv/1G3xZaV6hhgGnm4fxAWlZC+8TpaWTUBmcFyzN7eaENNaM1DE8zwzhAtrRtwSqw2pCzZjw/8fzhj1gv3SFVJfG1lc6tdO7LRsuFx8uZcD6jvWHQnTVJMmtZd0ZjULESBs94PjGHE1PKvFyv6PMZPQwYNNROLYVSIqRI2zdBjLTMdhOpWXADUS8SyNAWPT7AdMG9/4kWRlD6OIkG6Beiimh7qA96/p6astbKjWQcZETUO+kgdKtjTCW/pyY7qrSTdoHbGiddD+v9oQmQ3ks78DhWQ/WBOgzsbzIKVKqzbivTMKBNp8WN7fWV5e0LFLHq2kMi6K3iNDmc0Sx7oSpHdxrrjxRTlYf/si6UFDHHn8c4x+AH9p6Fh9c79+uCQZFXGUEZ1VCto3unN0VKu3ytpzNWqf9EAvUqEV7l8MPAMIwHdPUovh6EaJwHp7A9yLJcaSn2agh4qOMhYkRutk7wPdL4t9Q1U7pmjU0IDiVhu7wE0Qo7TqhxptlAPPou9XZj+Xrj/vsLKW24yRIGmQyEhxOnx0eMHlm3lfrXPzOGwHTEkVOuXL/8xtv1lVwSJkjVQh9j2XVZ+PTpV+rhzhqDp2uhiSsjY3c7DEc4KBwJtoRBDj3SpWtkDLEaqhnQvYiXqXT8OBMuF4lZ644+hIOdTtHy0m51lyCFtRQXqHakKy9jclWx2gvVOyfSGrFhpAakcO3AuEB1gVIbuBO9arrRWCXyvt4jrWyY1mVKYC2
1yfdTH/0C7Ue0DSQWVEmUrlANWk1HYrTjBs/oHgGNsYpYI40iigwr/EitDcqVg6xi5ICUK741QQ11UDbQmji03OE3qyVR9o2aMrofN69aD7q8PdKiFneUnBUIhzHt5FJxw4jxHmu8iD6dofsTDGdMSegY6ZILlodsr+he8Qps2qm72LjTeiPvi4y0D6GiMQJkdvYQULYqo3wtTEWlNSLulZ/rltLf9Jz/u35JzdPA5XxiXe58fn3l3dMjD3/4yDBPXC4XvNG8/v4JrJWRmdZCrbaeHz/+JEm+1ys2TEyP75i1w7TMSMEZS9OBkgtpX1E5y4x6dFzev+f04Qfm8wlVM21L5LRxu76y7pm3tfPzb6sUF7vjn7vCGcU8Drg58OXlytvbG6Mx+F646DPju5n5/Q+MvWMuj2gfsOcTeE+phbSKsbUkMauWktGqS9KoZfw4cX56h9KW09M7xssTejjh5jO1itpAirAyojLGfO/p9Cqzd+lBdBmzIemwrg3aHvHrlClRJHp0MKoLPbx1SinHjsGitZabmJKX67ZurOuKDYN4cZxlPp/ovUqEel9IpeHGjbiupPtVZt2tosxxMiwFO3gU8sPog6FgUcOED1Lk3teVFDeg453cbMQA2vBuoHVIOcnpjsa2rNSS8M4R5pGYdnLc2PeNWhvhcA5pLaXYWmWHoA9GGa0cPqF6pKkKOe1o5XBOHuJNW3qK9Lpj9ZHuSoked0qTxGQDzOEF+maajSVDNwTnUb5/jwUrN2IfnvDThaI9TVlMjbTtRkwFZyTmPY4e7TRhOAno9zwznU5oPGVfeH39yqd1gyIR+loKOSb2GCklMc6DpEynQUZTR6Krg0gQv71QFcARse5VOjW944OjN0cvGz1vpE2SfjUlak60mqFJutE4g9UapyUdSi+UmqU07C2zP+FyOKSTlRJ3KitURY0breyoWrA9o5oEW3JvWOswLmCthy4jVaOEHVfRNO3AQNFV+mW10HrDUnFdyXi1NcoRkjFGXgCqC82+JoHqlpLl93aUfcXuLCXXfnyPtXb44IlB0Wqm7KuwEl0A6zFWbvdxX4lppyFpVoWRW7rWh7lZunTdiGJGoVCtoI8ib+9ykJaJTJPRXO9oPxP8YQ7wgWaMcAmRaU7PEUoS7uKxA6TJHrPGSDZ3mhGbc4oRqW1oGoIKQ2lcGPHDJGXf3lBHivHbBEO3ArWRyi7utX8Edt/gLYM36BpYcuHt6ytPPzyhrcQ93Thx+fADJSdCGIkx8ctff6WuO++fn3DB0lAYPzM9fgClUOsrLkZ6OrAne6KmQl53cikMj0+cP3zAPzxQSiG9/EYtmdQKL1++cr3vXPfG1j0Jxc9f7/z06+/8+PHCME24cOK+ZW7Xhfu2o0tCOc8ZSzEC4zyFga40bhgFzVILKR/a7rgJSihlvLfywTPCWLs8vGc4v8ddHmW0Z6ykjuJ2XLmlFa6U+p4w89ahuuwPevvPxI1CRmpKG6BK4VdJAfPbg9o5h9GSdEwxyiKYRmsVcwD4VD96RbWjusSTMQPaKvZtlUTR2409/g4YtIEQDPYIH/Rj/NZLI9K53zfy4NFByNxKG5SS2HfNovoYjt5OKXLKNErhB8U8jYL2OZQmwxBED++sEAy0Yo2Rfd/xh2/s2x7qP19OMh5VLVNyI2spknKYbUtOWCtoKeP9EftdqduNnDKkxLZtxFIx4/y9tNn3HXrHW4Otjto6KEPvmtYU2s2404g5PeHf/wE1P8lLsBTq8kbRFr2uqHnG8Cgx5CFghwsuDGirBaaMBApk7Je5L3e26x2jIFU5dCitZN9mDFYblLEMo9y6y6HdUMfJ3llLOTT3ZbkfJJaALp0SF/J6pStNL+m7PLHGRdBWx3+rarjfr5RWeXp8ZBxkHEbwGLQASWtFWcO+7cRYKPuNHnc5ODbRyGirsQqGIeDGE+PlmTCd6NpQmiIl4TlaL0iupj3dOdk36iBQX6WhROgV2bh2tFKUKmw9pe0
RFYesFen+RsqJbVlopVCKMBjjng6lu6RnlTHS+QLytrC9fMYoxXx5pimRGtYiGCeJuR8kfyMH63ogxUpKgjiyXgINXdG7l0KWUqL4KfVIJ0pQQ1kvt6cg6Ull5QWnnSPFyrYtsK1QMy1tUo3WSl7GIM+OuCD1uYxBbL+tN9acJVrvHDqMmPGE8f54gjQZ/7dGjxttu9HTTlMVq8To/bf8+rt+SV0eL6RVdBkxV3Jp/Pb7F+7bjh8CP/34I/PlwuvLKy1lci6kmPjy2++stzthGghjYHxux4Ouk/aNeL9S98K+bZTc5ASTK3YY8af32NN7sjKisPjXP3N9vbJ3WPed27rTNYQ5kBbFb18/8R8/O+bBM+nAsl55u95Yt0TdE6rsuCnztu30bRfVs3MSD0YwL60mKXLGnV4qwxBQLeGsxgWZtdv5wuXDPzE9/0APwhUUrNERd8/y8hDsjXQgevvG2zuCIQg4spYC9KOgKSk1jUI5OXmaJmm00qTQZ/yAVxp7tPC/AUU7Atj03hGMPNS0MbRa6F0i+r1V6dzcbqITmQNjeMBaJyghJQkg3RDXTn4hDIHLs8aEEZUyZl1JObOvdzRNEmLeQypoKwXMmgvWdh4uMzEmVNcYLbuP0io9i44DLXHs2uW4rPU3RNDx0tWK1iTdpFH0WigJQMsJtxawB0XCe2rcWVPi9vUrZV0xTQqp2g9MVvYQ1ko3iN6I20ZeV2quNG1odqD6ieHxA+7hA+HyDnd+QocRbzSmJlajWGKkGov3FtVn2c0+vIPpAaUNdV+pST6bZphRTdFLx18M4+mZITgy8pnvrTIEL3BSZzEN3Hji7Eahwx8du14SschLvOmO2nfichP4sDa07U7f7uRjtxSGAaX6cWhJxFTYm6jfOzDPG4PVBH0W/uq3XY0yGDrNOkyI6PvKcl8paZOCuAZnLBo5VIVhQIcZN0/4MJBLY0+RghR6W5E4e1dy8+hITLspERS2ninHrUpr2S8a9EFZADRoGyQYUHZibxgnDiyt5GW454y2gqEK41kqFqnQYmG9vRC3hfF0BmPJvVFi+i4UdN6DNod8cydn4T865+RlpTQNKWw34+Qz2RGuZRIFjbMWqrykrLfHAUXKuRLS6OQsxoG4xWP/JlUN+218fIhNW0nUkuBI3Sqt5bBcKikXUAbnR2yYwQ3ggoQkrCQWde/U9UaqRQ4KWqO1xdh/gJuUnwZqL6Qt8hJ3rB3IGF6ud4YU+fDuHcMwMo8Tt/TCHjeMc4zzfMx9O6U20r4T72+4EMSyuezUPbPe79AVfprx04Xw+J7pD/+V8OGfaGhysez8G6/3O3uRh12tFachGLh1+UBdl8xvn++MeyMCL7eV23UT3Ik1qIMuIAW5Tk6CUzJ0nA9ycqdh6DIjPwCfPjiJEE9nwuUD7vQENhwPVIPSoif4NrsWIZt8EMULg0TAD8SNADPlA9lrppYk+y3rcMFjTBCqQNpZ15WcE+YgI0u/SWG1BRQxRmqRNJvzDm+UFCaPkZk1hikE1OmE6R3vHCkV4em1Ts6VVhHV+1GSrVU4YqZBzhV0QZWVPSdyltGbjED5XpC1VszNcV0oUQqvYwi00oi7WHWtUXQjNwLrPLaWI/Bhvt8Yv41GjbGEMAjk9SA09JS/7+96q/ScicuNGDeW6yvr25VUOt04VLAMl0A4dgOlCa0aLZ6pfc3YY21fVUcPjvH5B8L7P+HmR/w040IA1bBKoLvOD2hnjzRkP5iTCt3kwNB7hz3S7jdwiu5led2nM8YOzOPI08MFZRq365W4rYD8e5yz9ILsaA5eY9oj67pQsiVuG10pfBjoqpO3lbTcAAU5YVsm+CC68NgY55HTaSJ5y778J1tPUpCV+zQK6NiLC0wpJfxEbY/Pu0gRtdKk+41yjJ2s0bKfKYluPKo3SkrUWomlsMZM06KhMdijmyWjqFoyvey0vGNJ+FbQqh8HFI1Byc3u0NN0+eaL4C8MeG1Rdjg6cVVKwa2
JYbn2Y0wsjMZ8u4vQc3pkeHwUI3gpqCpFehMGEXQax7qu8r3oEJyXF1FtNHV0rVohpnrE5ytGN4yG4GSkj64CMN53ugLjB8IoL/OSM8uyErcd1SrOinZe0p5GDqol0ouoUnqtVAXODBitpS+Yi1RAxhk3zJK2PFYExnkhwh+pVNUbZVsgbbIGUApxBf+vf/1dv6Te7jfonR/++Y/0YebldaWh8X5kW+/8/Je/Mo2TnFD3jXUTu+rD0yNDcCxxJzhP3O5cv/7O4/M7Ooo9i/piGGdq6wyXRx5++iPj+58YfvwT7vEDDYUvhTbNVKfwznAeZtY10cpC6qKymOaBqgxfrztTgdQLX9/u3JbIu8sD58uZMI5gNLkWvJJ5bimRHA32iJX2VnHGkIGUpNFtrGE6PzA+/UB4eI8JI9paSddoAEVqTYgKtUGtoLt8aNBwiOVKqeQc2WOWUeK+UdJOPwSDbpjQpxljREutlaQinXcMQ5CXwNbppaOUppVGjkm6Pyhy7bQq8M16jARDCGJMvlyYx4Hz5SJEgePBkuKRpKuicC+Hk2kaAtY6SmlAQqlOiaIXmYIXCjMGlBIHkna0LIzDmuSmY5w52lh8pxcorZmGwJoitjrctyRZ78S40zv4QzdvjMXYSfY5cSVXuX2oI4FotZFEZEmonAjWEp4+UlWnqMY4nXDGk7b1kB0KQ88cTLfutBBpDgKDcxN+eERb0UAY1Y5gvCEpRybROYqeIUDLAgHJOyY5Opa+3unXr3QLaj6BcbQko99hmgjTJLqPi2UPo4ytVAcF3XugU6I87JStGOcOM6wo243WrPtKSYmWjrixkpj2NwfRMAxM84QPnpwKWhkRNG77MS5u3+3EQcneJ6WG0gXjR4GWGiOjyNJIdLJ3x/hNkeJGo9C1khF8vhNzZi+Z0hXdDVgMXgmGiyYpvbKv1HijxTuGggmeaR5xXpJ9Sr7YRy9PeHqtNZlWKDDWo4w/6C2y29VN6A79SHQaY5hGRToAvH6YCUOAEqkx0lrHBXvcNuWlnJIoXMZhkLBIFj287p2SNra987ZWYs44ZzidR8JlkhtXK5I27Y112+S5iEcNFqMk4Vi6PlKnChRHMlBu0NrIi7/bgqpSJWgHxksyJF2+x90Kxu3AL/UmRfx+qOxF7aFRzmOGAVcmVMnSf9N/2+vn7/olldfEdDmTW2U4TVyw3LfCPMpp57dPnznPE0pDvO2k+8KoLX6eGM4XXO+yjLeB2BpLqnhr8aeJYDzzcKGaAff0nvMPPzHOF7p2lBjpdPK64tDo5shxw5uCN5XHB0dJCx+fJmL1kDKdwtv9xsvW+O3rhm2R//rxkfNlZHx8xoZJEmJNEmJGAQhhORXheLXtTlmuUpY7nQlP75k+/BOXdz9ghjMujJhDI6EV9JxINVLySm+Z752fWlFIsqj2KjufKubi2iqpVWoTb028XaF+ojw88fjweIBKK6o2tLMYP8qc3Pj/x+y8SCiiVenINMhFoYyj5kKtsug2RhGcZzgKpU0blBtorWNcpN3vh6LiW8/H4Z35frNJacWZQK2NSWt6TuQotAkXgszGexdXjgtCfQZKLXQqYfayL1EyIsqlMJkjsutkWd2zvMC/veCxQjO3ZhcaR9pkQVyqBEN8wEwzPozykjUaPYwyduydWDNYS29KXjJ7Iq53SunM8wmvG7Z3ujHE2uj7Sot3SlrwzlByBzOgUPSSqTXD9QWXIsla+jhjVBWygNG4Wqk1EalkK0gmExO1J3qWEEHbRmpwMsoqcriS0bBEqlFRbiCtipG4VkIYJHKuZW9VS6KpzrK8gVGEEBjCEN9bCgAAD6dJREFUKPQTwA8eFzxOa4KVB0+0nVKjyA1pxJqJvZB6occNnVes7jhn0L1g9eXoDHVssMRNU2LHNAkotK5AW3lhWEfOja5lj7LvBRsUxkZUEGtBzyLRLOuNut/RFIzVEkj4pp+oDaVB4i1yqGud/0975xYbVdX28f/ax5n
pdGZaSlv4pEKEaAhCFARHL7ygEZV4ileEC6JGg5YEEmOCGvUSEhMTNYYbI97ZRCNoFIwEsErCsbZy0moiSqO0vBzaOe3TWuv5Ltbu6Civb79DO1Ncv2QSOnulWfthOs/aaz3P/w9p2upLHRSfHaktNgYXUiiPKKWRrxZAjuMALoOLNFKGDRCDwSuolIoI/EBZfFguGJmgiKMclCGFKjISjosomYbJDMjAA/d8hBWO8asFlAKVlBOJZjQ1JZBOpWExVt2e44ZqRpcS4DDAYcGLoER1nSYwMITlAgwQTGIwpdraV8Uc6rsEQiKKVd9V/5g6hzVMdWZsGo5aTFCEpOPCcmzVoxUBZmyQSoYJI9kMi1kQfkEVTYh/wJNUpeypfXvTQBARLl0qgJFE83/NQnOmGTISSKdSACRGr4yDRxyts1uRam2Hm2tFoikDw02AWTaiMFBW1mEJRlI9BQjXRVNbJ6xsC5jrqi3dKIQYC5Tw6NVLSMgACSuCXyzh6qUCuCA4bgKtLVk0pVNKSLFURMIklD0fnuchZYRozjahpa0VufYOJHI55egbq6ibkLAMpZXFAAgeolwcR1QswACQbWlFqrUFbXPmITerA4l0BoiFNVWhg1RirX4FoV8BDyPVcxGf7ygfHrXNR7FEixRK7MSKV7xkWvBBiAIfhbEiIJRrqptIxEUKERAaIDOW+bFUTxDFgptObAIohIRlWLBjkU0hla10yEMgJPAwUhphgGpoJQBQT2rcdRHGq1An3uNWGopSPc3ZNpKJVCzbJOH7gfrjJIAZUdV4MBISjm3FquGkBGalgGlZasuBGfG2oipiANSq2bZt1QAbmcqqPBY7Df0AHhUR+h5EFMG2baURRxIGATLkEKaE6bhI2qo6U3LlIguuqt8Mpp7+DMeFJQSYY8BOpSDBweMFAw8D+GUfvrSRTbaCGbHzKwuAIEBQLiL0KghLV2ExrsrobQvE7bigIFKFMZDgYYDAq8A2mapGhOoBMy0GJgKIwIOIk1AYBqqYA4hFgE3VqxafZRqG8n+yGYPrODANEzw0VWNsaIMsE67jIuG6qhcmPqMyBYF7PiqeAOccQdlDUKogiBtHDQICz0PJUiK9tmuBlNif0hmEDWarMu+gUkFpwobDcuMnYhumbSihU8sCTAKYRBjasAxVxcr9ipJEYgwiVOdqJHnspmwjlXBj1fkIhqE8rYRQfYFgRlwOrrT8iDHwwFNzj88vDWaAQYBIxg2+v8tgTdismKaLMAyVs7SQMG3VzmA6ljpbLVditXcol2C3GWYyG589mQh8Dh+Ak06jtdmA7djIZFJoSqdgGfF5sFCLu0RCLWg8X+1QRIEHKdTfHCOhqgOZctWd6AVkIm4+h/JsIxAMyOqTkTqVoNgbC6pS01RNzIYkmESwY+83SFM5EZsGDMeGzSSE0QQzDJVE1CSYkUlqwlDtwshl5DhhVudsRFzgX/+6gqRtgs/Owky5ECJS7qlCnW/YtgtYCVjpHJyWOWCpZtU7IQmcyuBcgnEDkbQQlCoQjoRl2uAEiGIRFZRAXMCMIpiRD3gFNLmEjvY0bEvgt9+uolz04fshsq0ZOKYyAGzOJZBJ2PCCJLhUH9xsSwtSmQyMRBM4Y+CVCmxTfSBCIdSjtiSUyx6KpTJKxQJ4qQgHDIKZsBJNMN1mRDBBoQAYgSFS+71RhKBSRKU4Bq8wDkiuJGDM31cuJNUfkmEYVTFKdcyiyo1ZLNdi2g4iLlHxA5TjqrQw9BH5ATgRUn6AZCKBhG0j4iG4F8BihCiI4JV9eIEP07Ih0s3KJyu2N6FIgHseeHyeJCQBsfqEZdnK6RZMORGTcg/1wxBEEqZhwLEcOAkbRhRVTdi8ckWdSUhZLYuNnxmRbEojmVTl61JKcMGr1WWGycAjVcHnB8rIUAqhlEZsVrXykEIiDDmCIEDIQ0ShEixNJhNKLibkkKIIWDZS2RY0tcyCm26C4ZpgYYigEqJYKoMgYJs2iICIDEh
LNUFz04EIBYIogAmJgAsUCiXwYojATKPJE7CbmmCC4F+9jOJvwyhfHoXtArPaZyHZlIBhELgfqe2zMIBlFSGEgFeuwCsVYDDVlKsEflUhg1dW26AkOAIeKYsIqZqPLcsCExJwlCCqkISQcwgiuLYBbjBwROCxzQgzLVUmLoGIUzXWxH2lLSk5pAhVIZoAgiBCEPF4wUColFWZue3acFMJpJsSsIUFGXhgnvJeIyFR9jxUKj7AbKV+IiQYh9rOlkpoWEipzrsiDs4lQlEB81W7iGXbqrBISLW4gHLbjoSECNSXdBg39QrBlZ17/HlSPW9KCFl1IsSLCi6qpouS1Bew8rliMA21pWfYJpgjEAUBorKnGuGdBEpBCAiBoFJGZbygkoWTgC1dJF0GW1lowQ8EKpFUIrMJWy1mrbja0K8glDJ2LmDKat5UJplcErzIgxFxmLYN22SxrikhiDikNCAkgyGgtkylUM39E8mLc3BSDsmSGIKQI+JqxwVS2d47sTg05wFcy4Jh2DAdJ5b3InV+G7cehGGIilep+T7/dzD6TyMakJ9++gk33XRTvaeh0Wg0mv8jw8PDuOGGG/7t9Rn5JNXa2goAOH/+PLLZbJ1n07gUCgXMmzcPw8PDyGQy9Z5Ow6LjNDl0nCaHjtPkICIUi0XMnTv3b8fNyCQ1IcaYzWb1h2ASZDIZHadJoOM0OXScJoeO039mMg8Zkyuv0Gg0Go2mDugkpdFoNJqGZUYmKdd18eqrr8J13XpPpaHRcZocOk6TQ8dpcug4/f8yI6v7NBqNRvPPYEY+SWk0Go3mn4FOUhqNRqNpWHSS0mg0Gk3DopOURqPRaBqWGZmk3n77bcyfPx+JRAKrVq3CsWPH6j2laeWrr77Cgw8+iLlz54Ixht27d9dcJyK88sormDNnDpLJJLq7u/Hjjz/WjLly5QrWr1+PTCaDXC6HJ598EqVSaRrvYmrZtm0b7rjjDjQ3N6O9vR2PPPIIhoaGasb4vo+enh7MmjUL6XQajz32GEZHR2vGnD9/HmvXrkUqlUJ7ezuef/558FiE9npgx44dWLp0abXxNJ/PY+/evdXrOkbXZvv27WCMYcuWLdX3dKymCJph9Pb2kuM49O6779KZM2foqaeeolwuR6Ojo/We2rSxZ88eeumll+ijjz4iALRr166a69u3b6dsNku7d++mb7/9lh566CFasGABeZ5XHXPffffRsmXL6MiRI/T111/TwoULad26ddN8J1PHmjVraOfOnXT69GkaHBykBx54gLq6uqhUKlXHbNy4kebNm0f79++nEydO0J133kl33XVX9TrnnJYsWULd3d00MDBAe/bsoba2NnrhhRfqcUtTwieffEKfffYZ/fDDDzQ0NEQvvvgi2bZNp0+fJiIdo2tx7Ngxmj9/Pi1dupQ2b95cfV/HamqYcUlq5cqV1NPTU/1ZCEFz586lbdu21XFW9ePPSUpKSZ2dnfTaa69V3xsbGyPXden9998nIqKzZ88SADp+/Hh1zN69e4kxRr/++uu0zX06uXjxIgGgvr4+IlIxsW2bPvjgg+qY7777jgDQ4cOHiUgtBgzDoJGRkeqYHTt2UCaToSAIpvcGppGWlhZ65513dIyuQbFYpEWLFtG+ffvonnvuqSYpHaupY0Zt94VhiP7+fnR3d1ffMwwD3d3dOHz4cB1n1jicO3cOIyMjNTHKZrNYtWpVNUaHDx9GLpfDihUrqmO6u7thGAaOHj067XOeDsbHxwH8Lk7c39+PKIpq4nTLLbegq6urJk633norOjo6qmPWrFmDQqGAM2fOTOPspwchBHp7e1Eul5HP53WMrkFPTw/Wrl1bExNAf56mkhklMHvp0iUIIWr+kwGgo6MD33//fZ1m1ViMjIwAwDVjNHFtZGQE7e3tNdcty0Jra2t1zPWElBJbtmzB3XffjSVLlgBQMXAcB7lcrmbsn+N0rThOXLteOHXqFPL5PHzfRzqdxq5du7B48WIMDg7
qGP2B3t5efPPNNzh+/PhfrunP09Qxo5KURvO/oaenB6dPn8ahQ4fqPZWG5Oabb8bg4CDGx8fx4YcfYsOGDejr66v3tBqK4eFhbN68Gfv27UMikaj3dP5RzKjtvra2Npim+ZeKmdHRUXR2dtZpVo3FRBz+LkadnZ24ePFizXXOOa5cuXLdxXHTpk349NNPcfDgwRpjtc7OToRhiLGxsZrxf47TteI4ce16wXEcLFy4EMuXL8e2bduwbNkyvPHGGzpGf6C/vx8XL17E7bffDsuyYFkW+vr68Oabb8KyLHR0dOhYTREzKkk5joPly5dj//791feklNi/fz/y+XwdZ9Y4LFiwAJ2dnTUxKhQKOHr0aDVG+XweY2Nj6O/vr445cOAApJRYtWrVtM95KiAibNq0Cbt27cKBAwewYMGCmuvLly+Hbds1cRoaGsL58+dr4nTq1KmahL5v3z5kMhksXrx4em6kDkgpEQSBjtEfWL16NU6dOoXBwcHqa8WKFVi/fn313zpWU0S9Kzf+p/T29pLruvTee+/R2bNn6emnn6ZcLldTMXO9UywWaWBggAYGBggAvf766zQwMEC//PILEakS9FwuRx9//DGdPHmSHn744WuWoN9222109OhROnToEC1atOi6KkF/5plnKJvN0pdffkkXLlyoviqVSnXMxo0bqauriw4cOEAnTpygfD5P+Xy+en2iZPjee++lwcFB+vzzz2n27NnXVcnw1q1bqa+vj86dO0cnT56krVu3EmOMvvjiCyLSMfo7/ljdR6RjNVXMuCRFRPTWW29RV1cXOY5DK1eupCNHjtR7StPKwYMHCcBfXhs2bCAiVYb+8ssvU0dHB7muS6tXr6ahoaGa33H58mVat24dpdNpymQy9Pjjj1OxWKzD3UwN14oPANq5c2d1jOd59Oyzz1JLSwulUil69NFH6cKFCzW/5+eff6b777+fkskktbW10XPPPUdRFE3z3UwdTzzxBN14443kOA7Nnj2bVq9eXU1QRDpGf8efk5SO1dSgrTo0Go1G07DMqDMpjUaj0fyz0ElKo9FoNA2LTlIajUajaVh0ktJoNBpNw6KTlEaj0WgaFp2kNBqNRtOw6CSl0Wg0moZFJymNRqPRNCw6SWk0Go2mYdFJSqPRaDQNi05SGo1Go2lYdJLSaDQaTcPy3yDSbHpiXpFyAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "import json\n", "from torchvision.io import read_image\n", "\n", "\n", "preprocess = MaxVit_T_Weights.IMAGENET1K_V1.transforms()\n", "\n", "with open(\"imagenet_class_index.json\") as labels_file:\n", " labels = json.load(labels_file)\n", "\n", "\n", "dog1 = read_image(\"dog1.jpg\")\n", "tensor = preprocess(dog1)\n", "\n", "torch_model.eval()\n", "with torch.inference_mode():\n", " output = torch_model(tensor.unsqueeze(dim=0))\n", "\n", "class_id = output.argmax(dim=1).item()\n", "\n", "print(f\"Prediction for the Dog: {labels[str(class_id)]}, score: {output.softmax(dim=-1)[0, class_id]}\")\n", "\n", "plt.title(f\"{labels[str(class_id)]}\\nScore: {output.softmax(dim=-1)[0, class_id]}\")\n", "plt.imshow(dog1.permute(1, 2, 0))" ] }, { "cell_type": "markdown", "id": "8cbe4ccc-224b-4e8a-a2a9-e2c756c9b207", "metadata": {}, "source": [ "## Port MaxViT model to JAX\n", "\n", "To port the [PyTorch implementation of the MaxVit model](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L568) to JAX using Flax NNX, we will implement the following modules:\n", "\n", "- `MaxViT`\n", " - `MaxVitBlock`\n", " - `MaxVitLayer`\n", " - `MBConv`\n", " - `Conv2dNormActivation`\n", " - `SqueezeExcitation`\n", " - `PartitionAttentionLayer`\n", " - `RelativePositionalMultiHeadAttention`\n", " - `WindowDepartition`\n", " - `WindowPartition`\n", " - `SwapAxes`\n", " - `StochasticDepth`\n", "\n", "The Flax NNX API is very similar to the PyTorch `torch.nn` module, so we can map the following modules between PyTorch and Flax:\n", "- `nn.Sequential` and `nn.ModuleList` -> `nnx.Sequential`\n", "- `nn.Linear` -> `nnx.Linear`\n", "- `nn.Conv2d` -> `nnx.Conv`\n", "- `nn.BatchNorm2d` -> `nnx.BatchNorm`\n", "- Activations like `nn.ReLU` -> `nnx.relu`\n", "- Pooling layers like `nn.AvgPool2d(...)` -> `lambda x: 
nnx.avg_pool(x, ...)`\n", "- `nn.AdaptiveAvgPool2d(1)` -> `lambda x: nnx.avg_pool(x, (x.shape[1], x.shape[2]))`, where x is in NHWC format\n", "- `nn.Flatten()` -> `lambda x: x.reshape(x.shape[0], -1)`\n", "\n", "\n", "If the PyTorch model defines a learnable parameter and a buffer:\n", "```python\n", "class Model(nn.Module):\n", " def __init__(self, ...):\n", " ...\n", " self.p = nn.Parameter(torch.ones(10))\n", " self.register_buffer(\"b\", torch.ones(5))\n", "```\n", "the equivalent Flax code would be:\n", "```python\n", "class Buffer(nnx.Variable):\n", " pass\n", "\n", "\n", "class Model(nnx.Module):\n", " def __init__(self, ...):\n", " ...\n", " self.p = nnx.Param(jnp.ones((10,)))\n", " self.b = Buffer(jnp.ones(5))\n", "```\n", "\n", "To inspect an NNX module's learnable parameters and buffers, we can use `nnx.state`:\n", "```python\n", "nnx_module = ...\n", "for k, v in nnx.state(nnx_module, nnx.Param).flat_state():\n", " print(\n", " k,\n", " v.value.mean() if v.value is not None else None\n", " )\n", "\n", "for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state():\n", " print(\n", " k,\n", " v.value.mean() if v.value.dtype == \"float32\" else v.value.sum()\n", " )\n", "```\n", "The equivalent PyTorch code is:\n", "```python\n", "torch_module = ...\n", "\n", "for m, p in torch_module.named_parameters():\n", " print(m, p.detach().mean())\n", "\n", "for m, b in torch_module.named_buffers():\n", " print(\n", " m,\n", " b.mean() if b.dtype == torch.float32 else b.sum()\n", " )\n", "```" ] }, { "cell_type": "markdown", "id": "305ac55b-62ed-4f4d-902c-c6f3082afb02", "metadata": {}, "source": [ "Please note some differences between `torch.nn` and Flax when porting models:\n", "- We should pass `rngs` to all NNX modules with parameters, e.g. `nnx.Linear(..., rngs=nnx.Rngs(0))`\n", "- For a 2D convolution:\n", " - In Flax, we need to explicitly define `kernel_size` and `strides` as tuples of two ints, e.g. 
`(3, 3)`\n", " - If the PyTorch code defines `padding` as an integer, e.g. 2, in Flax it should be explicitly defined as a `(low, high)` pair of ints per spatial dimension, i.e. `((2, 2), (2, 2))`.\n", "- For batch normalization, a `momentum` value `m` in `torch.nn` should be passed as `momentum=1.0 - m` in Flax (e.g. PyTorch's default of 0.1 corresponds to 0.9 in Flax).\n", "- 4D input arrays in Flax should be in NHWC format, i.e. of shape (N, H, W, C), whereas PyTorch uses the NCHW format, i.e. shape (N, C, H, W)." ] }, { "cell_type": "markdown", "id": "8d7e3479-bffe-4cb6-81e1-ed8f972c5bf0", "metadata": {}, "source": [ "Below we implement, one by one, all the modules from the list above and add simple forward-pass checks.\n", "Let's first implement an equivalent of `nn.Identity`." ] }, { "cell_type": "code", "execution_count": 10, "id": "54ece7f1-14c1-41ef-980a-fc279d1702f2", "metadata": {}, "outputs": [], "source": [ "class Identity(nnx.Module):\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " return x" ] }, { "cell_type": "markdown", "id": "dd87b2aa-0285-4995-a9aa-ebd58ae00de6", "metadata": {}, "source": [ "### `Conv2dNormActivation` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/ops/misc.py#L125)." 
] }, { "cell_type": "code", "execution_count": 11, "id": "69d71163-676e-4ad3-8d8c-45efaafd76e7", "metadata": {}, "outputs": [], "source": [ "from typing import Callable, List, Optional, Tuple\n", "from flax import nnx\n", "\n", "\n", "class Conv2dNormActivation(nnx.Sequential):\n", " def __init__(\n", " self,\n", " in_channels: int,\n", " out_channels: int,\n", " kernel_size: int = 3,\n", " stride: int = 1,\n", " padding: Optional[int] = None,\n", " groups: int = 1,\n", " norm_layer: Callable[..., nnx.Module] = nnx.BatchNorm,\n", " activation_layer: Callable = nnx.relu,\n", " dilation: int = 1,\n", " bias: Optional[bool] = None,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " self.out_channels = out_channels\n", "\n", " if padding is None:\n", " padding = (kernel_size - 1) // 2 * dilation\n", " if bias is None:\n", " bias = norm_layer is None\n", "\n", " # a sequence of (low, high) integer pairs giving the padding to apply\n", " # before and after each spatial dimension\n", " padding = ((padding, padding), (padding, padding))\n", "\n", " layers = [\n", " nnx.Conv(\n", " in_channels,\n", " out_channels,\n", " kernel_size=(kernel_size, kernel_size),\n", " strides=(stride, stride),\n", " padding=padding,\n", " kernel_dilation=(dilation, dilation),\n", " feature_group_count=groups,\n", " use_bias=bias,\n", " rngs=rngs,\n", " )\n", " ]\n", "\n", " if norm_layer is not None:\n", " layers.append(norm_layer(out_channels, rngs=rngs))\n", "\n", " if activation_layer is not None:\n", " layers.append(activation_layer)\n", "\n", " super().__init__(*layers)" ] }, { "cell_type": "code", "execution_count": 12, "id": "e5269a0a-f43f-4fdf-9955-aa3fcde60c01", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 14, 14, 64)\n" ] } ], "source": [ "x = jnp.ones((4, 28, 28, 32))\n", "mod = Conv2dNormActivation(32, 64, 3, 2, 1)\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "2d0cd827-ad40-4cd3-9560-6565a3df10bc", "metadata": {}, 
"source": [ "### `SqueezeExcitation` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/ops/misc.py#L224)." ] }, { "cell_type": "code", "execution_count": 13, "id": "4232689e-e6cc-4ffd-8a2a-41fbc34e57c2", "metadata": {}, "outputs": [], "source": [ "class SqueezeExcitation(nnx.Module):\n", " def __init__(\n", " self,\n", " input_channels: int,\n", " squeeze_channels: int,\n", " activation: Callable = nnx.relu,\n", " scale_activation: Callable = nnx.sigmoid,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " self.avgpool = nnx.avg_pool\n", " self.fc1 = nnx.Conv(input_channels, squeeze_channels, (1, 1), rngs=rngs)\n", " self.fc2 = nnx.Conv(squeeze_channels, input_channels, (1, 1), rngs=rngs)\n", " self.activation = activation\n", " self.scale_activation = scale_activation\n", "\n", " def _scale(self, x: jax.Array) -> jax.Array:\n", " scale = self.avgpool(x, (x.shape[1], x.shape[2]))\n", " scale = self.fc1(scale)\n", " scale = self.activation(scale)\n", " scale = self.fc2(scale)\n", " return self.scale_activation(scale)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " scale = self._scale(x)\n", " return scale * x" ] }, { "cell_type": "code", "execution_count": 14, "id": "83c55286-b92e-49aa-bd5f-c2448a787673", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 28, 28, 32)\n" ] } ], "source": [ "x = jnp.ones((4, 28, 28, 32))\n", "mod = SqueezeExcitation(32, 4)\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "7935790a-4cb1-46dc-ab73-12d3cb8fc636", "metadata": {}, "source": [ "### `StochasticDepth` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/ops/stochastic_depth.py#L50)." 
] }, { "cell_type": "code", "execution_count": 15, "id": "96834419-eec1-4690-8bb0-447524f6bdde", "metadata": {}, "outputs": [], "source": [ "def stochastic_depth(\n", " x: jax.Array,\n", " p: float,\n", " mode: str,\n", " deterministic: bool = False,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", ") -> jax.Array:\n", " if p < 0.0 or p > 1.0:\n", " raise ValueError(f\"drop probability has to be between 0 and 1, but got {p}\")\n", " if mode not in [\"batch\", \"row\"]:\n", " raise ValueError(f\"mode has to be either 'batch' or 'row', but got {mode}\")\n", " if deterministic or p == 0.0:\n", " return x\n", "\n", " survival_rate = 1.0 - p\n", " if mode == \"row\":\n", " size = [x.shape[0]] + [1] * (x.ndim - 1)\n", " else:\n", " size = [1] * x.ndim\n", "\n", " noise = jax.random.bernoulli(\n", " rngs.dropout(), p=survival_rate, shape=size\n", " ).astype(dtype=x.dtype)\n", "\n", " if survival_rate > 0.0:\n", " noise = noise / survival_rate\n", "\n", " return x * noise\n", "\n", "\n", "class StochasticDepth(nnx.Module):\n", " def __init__(\n", " self,\n", " p: float,\n", " mode: str,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " self.p = p\n", " self.mode = mode\n", " self.deterministic = False\n", " self.rngs = rngs\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " return stochastic_depth(\n", " x, self.p, self.mode, self.deterministic, rngs=self.rngs\n", " )" ] }, { "cell_type": "code", "execution_count": 16, "id": "fd95babb-95b4-4015-957d-11b9c7b9957d", "metadata": {}, "outputs": [], "source": [ "x = jnp.ones((4, 28, 28, 32))\n", "mod = StochasticDepth(0.5, \"row\")\n", "\n", "mod.eval()\n", "y = mod(x)\n", "assert (y == x).all()\n", "\n", "mod.train()\n", "y = mod(x)\n", "assert (y != x).any()" ] }, { "cell_type": "markdown", "id": "0ce251eb-a8dc-4415-9856-d16421c1d646", "metadata": {}, "source": [ "### `MBConv` implementation\n", "\n", "[PyTorch source 
implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L53)" ] }, { "cell_type": "code", "execution_count": 17, "id": "636c713c-4a21-439a-b220-2b9407a06dfc", "metadata": {}, "outputs": [], "source": [ "class MBConv(nnx.Module):\n", " def __init__(\n", " self,\n", " in_channels: int,\n", " out_channels: int,\n", " expansion_ratio: float,\n", " squeeze_ratio: float,\n", " stride: int,\n", " activation_layer: Callable,\n", " norm_layer: Callable[..., nnx.Module],\n", " p_stochastic_dropout: float = 0.0,\n", " rngs = nnx.Rngs(0),\n", " ):\n", " should_proj = stride != 1 or in_channels != out_channels\n", " if should_proj:\n", " proj = [nnx.Conv(\n", " in_channels, out_channels, kernel_size=(1, 1), strides=(1, 1), use_bias=True, rngs=rngs\n", " )]\n", " if stride == 2:\n", " padding = ((1, 1), (1, 1))\n", " proj = [\n", " lambda x: nnx.avg_pool(\n", " x, window_shape=(3, 3), strides=(stride, stride), padding=padding\n", " )\n", " ] + proj\n", " self.proj = nnx.Sequential(*proj)\n", " else:\n", " self.proj = Identity()\n", "\n", " mid_channels = int(out_channels * expansion_ratio)\n", " sqz_channels = int(out_channels * squeeze_ratio)\n", "\n", " if p_stochastic_dropout:\n", " self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode=\"row\", rngs=rngs)\n", " else:\n", " self.stochastic_depth = Identity()\n", "\n", " _layers = [\n", " norm_layer(in_channels, rngs=rngs), # pre_norm\n", " Conv2dNormActivation( # conv_a\n", " in_channels,\n", " mid_channels,\n", " kernel_size=1,\n", " stride=1,\n", " padding=0,\n", " activation_layer=activation_layer,\n", " norm_layer=norm_layer,\n", " rngs=rngs,\n", " ),\n", " Conv2dNormActivation( # conv_b\n", " mid_channels,\n", " mid_channels,\n", " kernel_size=3,\n", " stride=stride,\n", " padding=1,\n", " activation_layer=activation_layer,\n", " norm_layer=norm_layer,\n", " groups=mid_channels,\n", " rngs=rngs,\n", " ),\n", " SqueezeExcitation( # 
squeeze_excitation\n", " mid_channels, sqz_channels, activation=nnx.silu, rngs=rngs\n", " ),\n", " nnx.Conv( # conv_c\n", " mid_channels, out_channels, kernel_size=(1, 1), use_bias=True, rngs=rngs\n", " )\n", " ]\n", " self.layers = nnx.Sequential(*_layers)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " res = self.proj(x)\n", " x = self.stochastic_depth(self.layers(x))\n", " return res + x" ] }, { "cell_type": "code", "execution_count": 18, "id": "5cd24b07-f160-422c-bea3-2baf5ebca5b0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(4, 14, 14, 64)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from functools import partial\n", "\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "x = jnp.ones((4, 28, 28, 32))\n", "mod = MBConv(32, 64, 4, 0.25, 2, activation_layer=nnx.gelu, norm_layer=norm_layer)\n", "y = mod(x)\n", "y.shape" ] }, { "cell_type": "markdown", "id": "3a8d9cb4-795b-4cb2-a014-bb440acc800b", "metadata": {}, "source": [ "### `RelativePositionalMultiHeadAttention` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L140). 
First we reimplement a helper function `_get_relative_position_index`:" ] }, { "cell_type": "code", "execution_count": 19, "id": "df647057-8c6f-4c6b-84f9-d6f78e649343", "metadata": {}, "outputs": [], "source": [ "def _get_relative_position_index(height: int, width: int) -> jax.Array:\n", " # PyTorch code:\n", " # coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))\n", "\n", " coords = jnp.stack(\n", " jnp.meshgrid(*[jnp.arange(height), jnp.arange(width)], indexing=\"ij\")\n", " )\n", " # PyTorch code: coords_flat = torch.flatten(coords, 1)\n", " coords_flat = coords.reshape(coords.shape[0], -1)\n", "\n", " relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]\n", " relative_coords = jnp.permute_dims(relative_coords, (1, 2, 0))\n", "\n", " # PyTorch code:\n", " # relative_coords[:, :, 0] += height - 1\n", " # relative_coords[:, :, 1] += width - 1\n", " # relative_coords[:, :, 0] *= 2 * width - 1\n", " relative_coords = relative_coords + jnp.array((height - 1, width - 1))\n", " relative_coords = relative_coords * jnp.array((2 * width - 1, 1))\n", "\n", " return relative_coords.sum(-1)" ] }, { "cell_type": "markdown", "id": "2670d86b", "metadata": {}, "source": [ "Let us check our implementation against the PyTorch implementation:" ] }, { "cell_type": "code", "execution_count": 20, "id": "5ce55b8b-5305-4a57-a413-8df43392ec3a", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import _get_relative_position_index as pytorch_get_relative_position_index\n", "\n", "\n", "output = _get_relative_position_index(13, 12)\n", "expected = pytorch_get_relative_position_index(13, 12)\n", "assert (output == jnp.asarray(expected)).all()" ] }, { "cell_type": "markdown", "id": "5518bfc4", "metadata": {}, "source": [ "Next, we can port the `RelativePositionalMultiHeadAttention` module, which has a learnable parameter and a buffer:" ] }, { "cell_type": "code", "execution_count": 21, "id": "1f46b3e4-fd69-42c2-8ca7-d242a20d13de", 
"metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "\n", "class Buffer(nnx.Variable):\n", " pass\n", "\n", "\n", "class RelativePositionalMultiHeadAttention(nnx.Module):\n", " def __init__(\n", " self,\n", " feat_dim: int,\n", " head_dim: int,\n", " max_seq_len: int,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " if feat_dim % head_dim != 0:\n", " raise ValueError(f\"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}\")\n", "\n", " self.n_heads = feat_dim // head_dim\n", " self.head_dim = head_dim\n", " self.size = int(math.sqrt(max_seq_len))\n", " self.max_seq_len = max_seq_len\n", "\n", " self.to_qkv = nnx.Linear(feat_dim, self.n_heads * self.head_dim * 3, rngs=rngs)\n", " self.scale_factor = feat_dim**-0.5\n", "\n", " self.merge = nnx.Linear(self.head_dim * self.n_heads, feat_dim, rngs=rngs)\n", "\n", " self.relative_position_index = Buffer(_get_relative_position_index(self.size, self.size))\n", "\n", " # initialize with truncated normal bias\n", " initializer = jax.nn.initializers.truncated_normal(stddev=0.02)\n", " shape = ((2 * self.size - 1) * (2 * self.size - 1), self.n_heads)\n", " self.relative_position_bias_table = nnx.Param(initializer(rngs.params(), shape, jnp.float32))\n", "\n", " def get_relative_positional_bias(self) -> jax.Array:\n", " bias_index = self.relative_position_index.value.ravel()\n", " relative_bias = self.relative_position_bias_table[bias_index].reshape((self.max_seq_len, self.max_seq_len, -1))\n", " relative_bias = jnp.permute_dims(relative_bias, (2, 0, 1))\n", " return jnp.expand_dims(relative_bias, axis=0)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " B, G, P, D = x.shape\n", " H, DH = self.n_heads, self.head_dim\n", "\n", " qkv = self.to_qkv(x)\n", "\n", " q, k, v = jnp.split(qkv, 3, axis=-1)\n", " q = jnp.permute_dims(q.reshape((B, G, P, H, DH)), (0, 1, 3, 2, 4))\n", " k = jnp.permute_dims(k.reshape((B, G, P, H, DH)), (0, 1, 3, 2, 4))\n", " v = jnp.permute_dims(v.reshape((B, G, P, 
H, DH)), (0, 1, 3, 2, 4))\n", "\n", " k = k * self.scale_factor\n", "\n", " dot_prod = jnp.einsum(\"B G H I D, B G H J D -> B G H I J\", q, k)\n", " pos_bias = self.get_relative_positional_bias()\n", "\n", " dot_prod = jax.nn.softmax(dot_prod + pos_bias, axis=-1)\n", "\n", " out = jnp.einsum(\"B G H I J, B G H J D -> B G H I D\", dot_prod, v)\n", " out = jnp.permute_dims(out, (0, 1, 3, 2, 4)).reshape((B, G, P, D))\n", "\n", " out = self.merge(out)\n", " return out" ] }, { "cell_type": "code", "execution_count": 22, "id": "18d0c993", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 32, 49, 64)\n" ] } ], "source": [ "x = jnp.ones((4, 32, 49, 64))\n", "\n", "mod = RelativePositionalMultiHeadAttention(64, 16, 49)\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "875aba65-53d0-4241-bdd7-36384054ca59", "metadata": {}, "source": [ "### `SwapAxes`, `WindowPartition`, `WindowDepartition` implementations\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L213)." 
] }, { "cell_type": "code", "execution_count": 23, "id": "d8a19362-733a-4359-9658-53dcffa25220", "metadata": {}, "outputs": [], "source": [ "class SwapAxes(nnx.Module):\n", " def __init__(self, a: int, b: int):\n", " self.a = a\n", " self.b = b\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " res = jnp.swapaxes(x, self.a, self.b)\n", " return res\n", "\n", "\n", "class WindowPartition(nnx.Module):\n", " def __call__(self, x: jax.Array, p: int) -> jax.Array:\n", " # Output array with expected layout of [B, (H/P) * (W/P), P*P, C].\n", " B, H, W, C = x.shape\n", " P = p\n", " # chunk up H and W dimensions\n", " x = x.reshape((B, H // P, P, W // P, P, C))\n", " x = jnp.permute_dims(x, (0, 1, 3, 2, 4, 5))\n", " # merge the window grid and collapse each P x P window into one dimension\n", " x = x.reshape((B, (H // P) * (W // P), P * P, C))\n", " return x\n", "\n", "\n", "class WindowDepartition(nnx.Module):\n", " def __call__(self, x: jax.Array, p: int, h_partitions: int, w_partitions: int) -> jax.Array:\n", " # Output array with expected layout of [B, H, W, C].\n", " B, G, PP, C = x.shape\n", " P = p\n", " HP, WP = h_partitions, w_partitions\n", " # split P * P dimension into 2 P tile dimensions\n", " x = x.reshape((B, HP, WP, P, P, C))\n", " # permute into B, HP, P, WP, P, C\n", " x = jnp.permute_dims(x, (0, 1, 3, 2, 4, 5))\n", " # reshape into B, H, W, C\n", " x = x.reshape((B, HP * P, WP * P, C))\n", " return x" ] }, { "cell_type": "code", "execution_count": 24, "id": "daee5b6b-595f-4344-af93-6e4bd44c217f", "metadata": {}, "outputs": [], "source": [ "x = jnp.ones((3, 4, 5, 6))\n", "mod = SwapAxes(1, 2)\n", "y = mod(x)\n", "assert y.shape == (3, 5, 4, 6)\n", "\n", "x = jnp.ones((4, 128, 128, 3))\n", "mod = WindowPartition()\n", "y = mod(x, p=16)\n", "assert y.shape == (4, (128 // 16) * (128 // 16), 16 * 16, 3)\n", "\n", "x = jnp.ones((4, (128 // 16) * (128 // 16), 16 * 16, 3))\n", "mod = WindowDepartition()\n", "y = mod(x, p=16, h_partitions=128 // 16, w_partitions=128 // 16)\n", "assert y.shape == (4, 
128, 128, 3)" ] }, { "cell_type": "markdown", "id": "fe9643dd-b328-43c9-a82f-7180ee2b9a00", "metadata": {}, "source": [ "### `PartitionAttentionLayer` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L282)." ] }, { "cell_type": "code", "execution_count": 25, "id": "dfb3c640-4b51-4ca5-a7ba-2ad5f9907c57", "metadata": {}, "outputs": [], "source": [ "class PartitionAttentionLayer(nnx.Module):\n", " \"\"\"\n", " Layer for partitioning the input tensor into non-overlapping windows and\n", " applying attention to each window.\n", " \"\"\"\n", " def __init__(\n", " self,\n", " in_channels: int,\n", " head_dim: int,\n", " # partitioning parameters\n", " partition_size: int,\n", " partition_type: str,\n", " # grid size needs to be known at initialization time\n", " # because we need to know hamy relative offsets there are in the grid\n", " grid_size: Tuple[int, int],\n", " mlp_ratio: int,\n", " activation_layer: Callable,\n", " norm_layer: Callable[..., nnx.Module],\n", " attention_dropout: float,\n", " mlp_dropout: float,\n", " p_stochastic_dropout: float,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " self.n_heads = in_channels // head_dim\n", " self.head_dim = head_dim\n", " self.n_partitions = grid_size[0] // partition_size\n", " self.partition_type = partition_type\n", " self.grid_size = grid_size\n", "\n", " if partition_type not in [\"grid\", \"window\"]:\n", " raise ValueError(\"partition_type must be either 'grid' or 'window'\")\n", "\n", " if partition_type == \"window\":\n", " self.p, self.g = partition_size, self.n_partitions\n", " else:\n", " self.p, self.g = self.n_partitions, partition_size\n", "\n", " self.partition_op = WindowPartition()\n", " self.departition_op = WindowDepartition()\n", " self.partition_swap = SwapAxes(-2, -3) if partition_type == \"grid\" else Identity()\n", " self.departition_swap = SwapAxes(-2, -3) if partition_type 
== \"grid\" else Identity()\n", "\n", " self.attn_layer = nnx.Sequential(\n", " norm_layer(in_channels, rngs=rngs),\n", " # it's always going to be partition_size ** 2 because\n", " # of the axis swap in the case of grid partitioning\n", " RelativePositionalMultiHeadAttention(\n", " in_channels, head_dim, partition_size**2, rngs=rngs\n", " ),\n", " nnx.Dropout(attention_dropout, rngs=rngs),\n", " )\n", "\n", " # pre-normalization similar to transformer layers\n", " self.mlp_layer = nnx.Sequential(\n", " nnx.LayerNorm(in_channels, rngs=rngs),\n", " nnx.Linear(in_channels, in_channels * mlp_ratio, rngs=rngs),\n", " activation_layer,\n", " nnx.Linear(in_channels * mlp_ratio, in_channels, rngs=rngs),\n", " nnx.Dropout(mlp_dropout, rngs=rngs),\n", " )\n", "\n", " # layer scale factors\n", " self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode=\"row\", rngs=rngs)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " # Undefined behavior if H or W are not divisible by p\n", " # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766\n", " gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p\n", " torch._assert(\n", " self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0,\n", " \"Grid size must be divisible by partition size. 
Got grid size of {} and partition size of {}\".format(\n", " self.grid_size, self.p\n", " ),\n", " )\n", " x = self.partition_op(x, self.p) # (B, H, W, C) -> (B, H/P, W/P, P*P, C)\n", " x = self.partition_swap(x) # -> grid: (B, H/P, P*P, W/P, C)\n", " x = x + self.stochastic_dropout(self.attn_layer(x))\n", " x = x + self.stochastic_dropout(self.mlp_layer(x))\n", " x = self.departition_swap(x) # grid: (B, H/P, P*P, W/P, C) -> (B, H/P, W/P, P*P, C)\n", " x = self.departition_op(x, self.p, gh, gw) # -> (B, H, W, C)\n", "\n", " return x" ] }, { "cell_type": "code", "execution_count": 26, "id": "d6feac34-35be-420b-a7cb-78995aed4c7a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 224, 224, 36)\n", "(4, 224, 224, 36)\n" ] } ], "source": [ "x = jnp.ones((4, 224, 224, 36))\n", "\n", "grid_size = (224, 224)\n", "mod = PartitionAttentionLayer(\n", " 36, 6, 7, \"window\", grid_size=grid_size, mlp_ratio=4,\n", " activation_layer=nnx.gelu, norm_layer=nnx.LayerNorm,\n", " attention_dropout=0.4, mlp_dropout=0.3, p_stochastic_dropout=0.2,\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)\n", "\n", "mod = PartitionAttentionLayer(\n", " 36, 6, 7, \"grid\", grid_size=grid_size, mlp_ratio=4,\n", " activation_layer=nnx.gelu, norm_layer=nnx.LayerNorm,\n", " attention_dropout=0.4, mlp_dropout=0.3, p_stochastic_dropout=0.2,\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "b89b4ca6-c17a-4c0f-859a-de7134348818", "metadata": {}, "source": [ "### `MaxVitLayer` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L386)." 
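, "\n",
 "\n",
 "Each `MaxVitLayer` applies an `MBConv` layer followed by a window-partitioned and a grid-partitioned `PartitionAttentionLayer`. Both attention layers use the same `grid_size`, which is the spatial size after the `MBConv` (the `MBConv` may downsample with `stride=2`). As a quick sanity check, the illustrative helper `attention_groups` below (not part of this notebook) counts attention groups and tokens per group for both partition types. It follows the choice of `p` and the `SwapAxes` step in `PartitionAttentionLayer`. In both cases attention operates on `partition_size**2` tokens, which matches the sequence length passed to `RelativePositionalMultiHeadAttention`.\n",
 "\n",
 "```python\n",
 "def conv_output_shape(input_size, kernel_size, stride, padding):\n",
 "    # Same formula as _get_conv_output_shape\n",
 "    return tuple((s - kernel_size + 2 * padding) // stride + 1 for s in input_size)\n",
 "\n",
 "\n",
 "def attention_groups(grid_size, partition_size, partition_type):\n",
 "    # Returns (number of attention groups, tokens per group) after partitioning,\n",
 "    # mirroring the choice of p and the SwapAxes step in PartitionAttentionLayer\n",
 "    h, w = grid_size\n",
 "    n_partitions = h // partition_size\n",
 "    p = partition_size if partition_type == 'window' else n_partitions\n",
 "    n_groups, n_tokens = (h // p) * (w // p), p * p\n",
 "    if partition_type == 'grid':\n",
 "        # SwapAxes(-2, -3) exchanges the group and token axes\n",
 "        n_groups, n_tokens = n_tokens, n_groups\n",
 "    return n_groups, n_tokens\n",
 "\n",
 "\n",
 "grid_size = conv_output_shape((224, 224), kernel_size=3, stride=2, padding=1)\n",
 "print(grid_size)  # (112, 112)\n",
 "print(attention_groups(grid_size, 7, 'window'))  # (256, 49): local 7x7 windows\n",
 "print(attention_groups(grid_size, 7, 'grid'))  # (256, 49): 7x7 grids with stride 16\n",
 "```"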
] }, { "cell_type": "code", "execution_count": 27, "id": "45b3199e-711d-4125-86b9-22e90fafa28c", "metadata": {}, "outputs": [], "source": [ "class MaxVitLayer(nnx.Module):\n", " \"\"\"\n", " MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window`\n", " and a PartitionAttentionLayer with `grid`.\n", " \"\"\"\n", " def __init__(\n", " self,\n", " # conv parameters\n", " in_channels: int,\n", " out_channels: int,\n", " squeeze_ratio: float,\n", " expansion_ratio: float,\n", " stride: int,\n", " # conv + transformer parameters\n", " norm_layer: Callable[..., nnx.Module],\n", " activation_layer: Callable,\n", " # transformer parameters\n", " head_dim: int,\n", " mlp_ratio: int,\n", " mlp_dropout: float,\n", " attention_dropout: float,\n", " p_stochastic_dropout: float,\n", " # partitioning parameters\n", " partition_size: int,\n", " grid_size: Tuple[int, int],\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " layers = [\n", " # convolutional layer\n", " MBConv(\n", " in_channels=in_channels,\n", " out_channels=out_channels,\n", " expansion_ratio=expansion_ratio,\n", " squeeze_ratio=squeeze_ratio,\n", " stride=stride,\n", " activation_layer=activation_layer,\n", " norm_layer=norm_layer,\n", " p_stochastic_dropout=p_stochastic_dropout,\n", " rngs=rngs,\n", " ),\n", " # window_attention\n", " PartitionAttentionLayer(\n", " in_channels=out_channels,\n", " head_dim=head_dim,\n", " partition_size=partition_size,\n", " partition_type=\"window\",\n", " grid_size=grid_size,\n", " mlp_ratio=mlp_ratio,\n", " activation_layer=activation_layer,\n", " norm_layer=nnx.LayerNorm,\n", " attention_dropout=attention_dropout,\n", " mlp_dropout=mlp_dropout,\n", " p_stochastic_dropout=p_stochastic_dropout,\n", " rngs=rngs,\n", " ),\n", " # grid_attention\n", " PartitionAttentionLayer(\n", " in_channels=out_channels,\n", " head_dim=head_dim,\n", " partition_size=partition_size,\n", " partition_type=\"grid\",\n", " grid_size=grid_size,\n", " 
mlp_ratio=mlp_ratio,\n", " activation_layer=activation_layer,\n", " norm_layer=nnx.LayerNorm,\n", " attention_dropout=attention_dropout,\n", " mlp_dropout=mlp_dropout,\n", " p_stochastic_dropout=p_stochastic_dropout,\n", " rngs=rngs,\n", " )\n", " ]\n", " self.layers = nnx.Sequential(*layers)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " return self.layers(x)\n", "\n", "\n", "def _get_conv_output_shape(\n", " input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int\n", ") -> Tuple[int, int]:\n", " return (\n", " (input_size[0] - kernel_size + 2 * padding) // stride + 1,\n", " (input_size[1] - kernel_size + 2 * padding) // stride + 1,\n", " )" ] }, { "cell_type": "code", "execution_count": 28, "id": "6a130b58-95cf-4ad7-8a42-5044a37c7c09", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 112, 112, 36)\n" ] } ], "source": [ "x = jnp.ones((4, 224, 224, 3))\n", "\n", "grid_size = _get_conv_output_shape((224, 224), kernel_size=3, stride=2, padding=1)\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "\n", "mod = MaxVitLayer(\n", " 3, 36, squeeze_ratio=0.25, expansion_ratio=4,\n", " stride=2, norm_layer=norm_layer, activation_layer=nnx.gelu,\n", " head_dim=6, mlp_ratio=4, mlp_dropout=0.5,\n", " attention_dropout=0.4, p_stochastic_dropout=0.3,\n", " partition_size=7, grid_size=grid_size,\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "21460039-0ed8-4c37-8382-7d91655f1086", "metadata": {}, "source": [ "### `MaxVitBlock` implementation\n", "\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L483)." 
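, "\n",
 "\n",
 "A `MaxVitBlock` downsamples only once, in its first `MaxVitLayer` (`stride=2`), and all of its layers then share the same `grid_size`. Inside `MaxVit` (defined below), the blocks are chained so that each block's `input_grid_size` is the previous block's `grid_size`. Each block also receives its slice of a linearly increasing list of stochastic depth probabilities. The NumPy sketch below reproduces this bookkeeping for the MaxViT-T configuration (224x224 inputs, `partition_size=7`, `block_layers=[2, 2, 5, 2]`, `stochastic_depth_prob=0.2`). `conv_output_shape` is a stand-in that repeats the formula of `_get_conv_output_shape`.\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "\n",
 "def conv_output_shape(input_size, kernel_size=3, stride=2, padding=1):\n",
 "    # Same formula as _get_conv_output_shape\n",
 "    return tuple((s - kernel_size + 2 * padding) // stride + 1 for s in input_size)\n",
 "\n",
 "\n",
 "input_size, partition_size = (224, 224), 7\n",
 "block_layers = [2, 2, 5, 2]\n",
 "stochastic_depth_prob = 0.2\n",
 "\n",
 "# The stem has stride 2, then the first MaxVitLayer of every block has stride 2\n",
 "size = conv_output_shape(input_size)\n",
 "block_grid_sizes = []\n",
 "for _ in block_layers:\n",
 "    size = conv_output_shape(size)\n",
 "    block_grid_sizes.append(size)\n",
 "print(block_grid_sizes)  # [(56, 56), (28, 28), (14, 14), (7, 7)]\n",
 "# MaxVit checks that every block grid is divisible by partition_size\n",
 "print(all(h % partition_size == 0 and w % partition_size == 0 for h, w in block_grid_sizes))  # True\n",
 "\n",
 "# Linearly increasing stochastic depth probabilities over all layers, sliced per block\n",
 "p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()\n",
 "per_block, p_idx = [], 0\n",
 "for num_layers in block_layers:\n",
 "    per_block.append(p_stochastic[p_idx : p_idx + num_layers])\n",
 "    p_idx += num_layers\n",
 "print([len(p) for p in per_block])  # [2, 2, 5, 2]\n",
 "print(per_block[0][0], per_block[-1][-1])  # 0.0 0.2\n",
 "```"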
] }, { "cell_type": "code", "execution_count": 29, "id": "e4fd31d2-4354-4694-87b4-3d0644388d3d", "metadata": {}, "outputs": [], "source": [ "class MaxVitBlock(nnx.Module):\n", " \"\"\"\n", " A MaxVit block consisting of `n_layers` MaxVit layers.\n", " \"\"\"\n", " def __init__(\n", " self,\n", " # conv parameters\n", " in_channels: int,\n", " out_channels: int,\n", " squeeze_ratio: float,\n", " expansion_ratio: float,\n", " # conv + transformer parameters\n", " norm_layer: Callable[..., nnx.Module],\n", " activation_layer: Callable,\n", " # transformer parameters\n", " head_dim: int,\n", " mlp_ratio: int,\n", " mlp_dropout: float,\n", " attention_dropout: float,\n", " # partitioning parameters\n", " partition_size: int,\n", " input_grid_size: Tuple[int, int],\n", " # number of layers\n", " n_layers: int,\n", " p_stochastic: List[float],\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " if not len(p_stochastic) == n_layers:\n", " raise ValueError(f\"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.\")\n", "\n", " # account for the first stride of the first layer\n", " self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1)\n", "\n", " layers = []\n", " for idx, p in enumerate(p_stochastic):\n", " stride = 2 if idx == 0 else 1\n", " layers.append(\n", " MaxVitLayer(\n", " in_channels=in_channels if idx == 0 else out_channels,\n", " out_channels=out_channels,\n", " squeeze_ratio=squeeze_ratio,\n", " expansion_ratio=expansion_ratio,\n", " stride=stride,\n", " norm_layer=norm_layer,\n", " activation_layer=activation_layer,\n", " head_dim=head_dim,\n", " mlp_ratio=mlp_ratio,\n", " mlp_dropout=mlp_dropout,\n", " attention_dropout=attention_dropout,\n", " partition_size=partition_size,\n", " grid_size=self.grid_size,\n", " p_stochastic_dropout=p,\n", " rngs=rngs,\n", " ),\n", " )\n", " self.layers = nnx.Sequential(*layers)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " return self.layers(x)" ] }, { 
"cell_type": "code", "execution_count": 30, "id": "e168c27f-98db-4831-9723-dffac88f3226", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 112, 112, 36)\n" ] } ], "source": [ "x = jnp.ones((4, 224, 224, 3))\n", "\n", "input_grid_size = (224, 224)\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "\n", "mod = MaxVitBlock(\n", " 3, 36, squeeze_ratio=0.25, expansion_ratio=4,\n", " norm_layer=norm_layer, activation_layer=nnx.gelu,\n", " head_dim=6, mlp_ratio=4, mlp_dropout=0.5, attention_dropout=0.4,\n", " partition_size=7, input_grid_size=input_grid_size,\n", " n_layers=2,\n", " p_stochastic=[0.0, 0.2],\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "markdown", "id": "cef5687d-e390-438b-95b3-e66406e2c000", "metadata": {}, "source": [ "### `MaxVit` implementation\n", "\n", "Finally, we can assemble everything together and define the MaxVit model.\n", "[PyTorch source implementation](https://github.com/pytorch/vision/blob/945bdad7523806b15d3740ce6ace2fced9ef9d3b/torchvision/models/maxvit.py#L568)." 
] }, { "cell_type": "code", "execution_count": 31, "id": "0e874e63-0eb7-40ea-82f3-bf10ac33d7a6", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]:\n", " \"\"\"Util function to check that the input size is correct for a MaxVit configuration.\"\"\"\n", " shapes = []\n", " block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1)\n", " for _ in range(n_blocks):\n", " block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1)\n", " shapes.append(block_input_shape)\n", " return shapes\n", "\n", "\n", "class MaxVit(nnx.Module):\n", " \"\"\"\n", " Implements MaxVit Transformer from the \"MaxViT: Multi-Axis Vision Transformer\" paper.\n", " \"\"\"\n", " def __init__(\n", " self,\n", " # input size parameters\n", " input_size: Tuple[int, int],\n", " # stem and task parameters\n", " stem_channels: int,\n", " # partitioning parameters\n", " partition_size: int,\n", " # block parameters\n", " block_channels: List[int],\n", " block_layers: List[int],\n", " # attention head dimensions\n", " head_dim: int,\n", " stochastic_depth_prob: float,\n", " # conv + transformer parameters\n", " # norm_layer is applied only to the conv layers\n", " # activation_layer is applied both to conv and transformer layers\n", " norm_layer: Optional[Callable[..., nnx.Module]] = None,\n", " activation_layer: Callable = nnx.gelu,\n", " # conv parameters\n", " squeeze_ratio: float = 0.25,\n", " expansion_ratio: float = 4,\n", " # transformer parameters\n", " mlp_ratio: int = 4,\n", " mlp_dropout: float = 0.0,\n", " attention_dropout: float = 0.0,\n", " # task parameters\n", " num_classes: int = 1000,\n", " rngs: nnx.Rngs = nnx.Rngs(0),\n", " ):\n", " input_channels = 3\n", "\n", " if norm_layer is None:\n", " norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "\n", " # Make sure input size will be divisible by the partition size in all 
blocks\n", " # Undefined behavior if H or W are not divisible by p\n", " block_input_sizes = _make_block_input_shapes(input_size, len(block_channels))\n", " for idx, block_input_size in enumerate(block_input_sizes):\n", " if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0:\n", " raise ValueError(\n", " f\"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. \"\n", " f\"Consider changing the partition size or the input size.\\n\"\n", " f\"Current configuration yields the following block input sizes: {block_input_sizes}.\"\n", " )\n", "\n", " # stem\n", " self.stem = nnx.Sequential(\n", " Conv2dNormActivation(\n", " input_channels,\n", " stem_channels,\n", " 3,\n", " stride=2,\n", " norm_layer=norm_layer,\n", " activation_layer=activation_layer,\n", " bias=False,\n", " rngs=rngs,\n", " ),\n", " Conv2dNormActivation(\n", " stem_channels,\n", " stem_channels,\n", " 3,\n", " stride=1,\n", " norm_layer=None,\n", " activation_layer=None,\n", " bias=True,\n", " rngs=rngs,\n", " ),\n", " )\n", "\n", " # account for stem stride\n", " input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1)\n", " self.partition_size = partition_size\n", "\n", " # blocks\n", " blocks = []\n", " in_channels = [stem_channels] + block_channels[:-1]\n", " out_channels = block_channels\n", "\n", " # precompute the stochastic depth probabilities from 0 to stochastic_depth_prob\n", " # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed\n", " # over the range [0, stochastic_depth_prob]\n", " p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()\n", "\n", " p_idx = 0\n", " for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers):\n", " blocks.append(\n", " MaxVitBlock(\n", " in_channels=in_channel,\n", " out_channels=out_channel,\n", " squeeze_ratio=squeeze_ratio,\n", " 
expansion_ratio=expansion_ratio,\n", " norm_layer=norm_layer,\n", " activation_layer=activation_layer,\n", " head_dim=head_dim,\n", " mlp_ratio=mlp_ratio,\n", " mlp_dropout=mlp_dropout,\n", " attention_dropout=attention_dropout,\n", " partition_size=partition_size,\n", " input_grid_size=input_size,\n", " n_layers=num_layers,\n", " p_stochastic=p_stochastic[p_idx : p_idx + num_layers],\n", " rngs=rngs,\n", " ),\n", " )\n", " input_size = blocks[-1].grid_size\n", " p_idx += num_layers\n", " self.blocks = nnx.Sequential(*blocks)\n", "\n", " self.classifier = nnx.Sequential(\n", " lambda x: nnx.avg_pool(x, (x.shape[1], x.shape[2])), # nn.AdaptiveAvgPool2d(1)\n", " lambda x: x.reshape(x.shape[0], -1), # nn.Flatten()\n", " nnx.LayerNorm(block_channels[-1], rngs=rngs),\n", " nnx.Linear(block_channels[-1], block_channels[-1], rngs=rngs),\n", " nnx.tanh,\n", " nnx.Linear(block_channels[-1], num_classes, use_bias=False, rngs=rngs),\n", " )\n", "\n", " self._init_weights(rngs)\n", "\n", " def __call__(self, x: jax.Array) -> jax.Array:\n", " x = self.stem(x)\n", " x = self.blocks(x)\n", " x = self.classifier(x)\n", " return x\n", "\n", " def _init_weights(self, rngs):\n", " normal_initializer = nnx.initializers.normal(stddev=0.02)\n", " for name, module in self.iter_modules():\n", " if isinstance(module, (nnx.Conv, nnx.Linear)):\n", " module.kernel.value = normal_initializer(\n", " rngs(), module.kernel.value.shape, module.kernel.value.dtype\n", " )\n", " if module.bias.value is not None:\n", " module.bias.value = jnp.zeros(\n", " module.bias.value.shape, dtype=module.bias.value.dtype\n", " )\n", " elif isinstance(module, nnx.BatchNorm):\n", " module.scale.value = jnp.ones(module.scale.value.shape, module.scale.value.dtype)\n", " module.bias.value = jnp.zeros(module.bias.value.shape, module.bias.value.dtype)" ] }, { "cell_type": "code", "execution_count": 32, "id": "7e0f08b8-03a8-4941-8ca3-10d960783486", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(4, 1000)\n" ] } ],
"source": [ "x = jnp.ones((4, 224, 224, 3))\n", "\n", "mod = MaxVit(\n", " input_size=(224, 224),\n", " stem_channels=64,\n", " block_channels=[64, 128, 256, 512],\n", " block_layers=[2, 2, 5, 2],\n", " head_dim=32,\n", " stochastic_depth_prob=0.2,\n", " partition_size=7,\n", " num_classes=1000,\n", ")\n", "\n", "y = mod(x)\n", "print(y.shape)" ] }, { "cell_type": "code", "execution_count": 33, "id": "fa2a4a47-b6c9-43ba-822b-002e0c03e85c", "metadata": {}, "outputs": [], "source": [ "def maxvit_t(\n", " input_size=(224, 224),\n", " stem_channels=64,\n", " block_channels=[64, 128, 256, 512],\n", " block_layers=[2, 2, 5, 2],\n", " head_dim=32,\n", " stochastic_depth_prob=0.2,\n", " partition_size=7,\n", " num_classes=1000,\n", "):\n", " model = MaxVit(\n", " input_size=input_size,\n", " stem_channels=stem_channels,\n", " block_channels=block_channels,\n", " block_layers=block_layers,\n", " head_dim=head_dim,\n", " stochastic_depth_prob=stochastic_depth_prob,\n", " partition_size=partition_size,\n", " num_classes=num_classes,\n", " )\n", " return model" ] }, { "cell_type": "markdown", "id": "25ff32f7-e4a1-4029-b114-8ecafb4378fd", "metadata": {}, "source": [ "### Test JAX implementation" ] }, { "cell_type": "markdown", "id": "b3e02373-c3b6-4ffd-a98e-e425824f2f88", "metadata": {}, "source": [ "Let us import the equivalent PyTorch modules and check our implementations against them. Note that the PyTorch modules are initialized with random parameters and buffers, which we need to copy into our Flax implementations.\n", "\n", "Below we define a helper class `Torch2Flax` that copies parameters and buffers from a PyTorch module into an equivalent Flax module."
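, "\n",
 "\n",
 "The parameter transforms in `Torch2Flax` follow from different layout conventions. PyTorch stores `nn.Linear` weights as `(out_features, in_features)` and computes `x @ W.T`, whereas `nnx.Linear` stores its kernel as `(in_features, out_features)`. PyTorch `nn.Conv2d` weights are `(O, I, kH, kW)` and are applied to NCHW inputs, whereas `nnx.Conv` kernels are `(kH, kW, I, O)` and are applied to NHWC inputs. The NumPy sketch below checks that `permute(1, 0)` and `permute(2, 3, 1, 0)` are the right conversions, together with the `(0, 2, 3, 1)` input permutation used in `test_modules`. The naive convolution helpers are illustrative only (stride 1, no padding) and are not part of either library.\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "rng = np.random.default_rng(0)\n",
 "\n",
 "# Linear: PyTorch weight is (out_features, in_features) and computes x @ W.T,\n",
 "# nnx.Linear kernel is (in_features, out_features) and computes x @ kernel\n",
 "W = rng.normal(size=(5, 3))\n",
 "x = rng.normal(size=(2, 3))\n",
 "kernel = W.transpose(1, 0)  # same permutation as linear_params_permute\n",
 "print(np.allclose(x @ W.T, x @ kernel))  # True\n",
 "\n",
 "\n",
 "def conv_nchw_oihw(x, w):\n",
 "    # Naive PyTorch-style conv (cross-correlation, stride 1, no padding)\n",
 "    n, c, h, wd = x.shape\n",
 "    o, _, kh, kw = w.shape\n",
 "    out = np.zeros((n, o, h - kh + 1, wd - kw + 1))\n",
 "    for i in range(kh):\n",
 "        for j in range(kw):\n",
 "            patch = x[:, :, i : i + h - kh + 1, j : j + wd - kw + 1]\n",
 "            out += np.einsum('nchw,oc->nohw', patch, w[:, :, i, j])\n",
 "    return out\n",
 "\n",
 "\n",
 "def conv_nhwc_hwio(x, k):\n",
 "    # Naive Flax-style conv with NHWC inputs and an HWIO kernel\n",
 "    n, h, wd, c = x.shape\n",
 "    kh, kw, _, o = k.shape\n",
 "    out = np.zeros((n, h - kh + 1, wd - kw + 1, o))\n",
 "    for i in range(kh):\n",
 "        for j in range(kw):\n",
 "            patch = x[:, i : i + h - kh + 1, j : j + wd - kw + 1, :]\n",
 "            out += np.einsum('nhwc,co->nhwo', patch, k[i, j])\n",
 "    return out\n",
 "\n",
 "\n",
 "w = rng.normal(size=(4, 3, 3, 3))  # (O, I, kH, kW)\n",
 "x_nchw = rng.normal(size=(2, 3, 8, 8))\n",
 "y_torch = conv_nchw_oihw(x_nchw, w)\n",
 "# Same permutations as conv_params_permute and test_modules\n",
 "y_flax = conv_nhwc_hwio(x_nchw.transpose(0, 2, 3, 1), w.transpose(2, 3, 1, 0))\n",
 "print(np.allclose(y_torch.transpose(0, 2, 3, 1), y_flax))  # True\n",
 "```"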
] }, { "cell_type": "code", "execution_count": 34, "id": "22f49ecd-4999-4c1c-b1d8-d16faeb60389", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "\n", "class Torch2Flax:\n", " @staticmethod\n", " def conv_params_permute(name, torch_param):\n", " if name == \"weight\":\n", " return torch_param.permute(2, 3, 1, 0)\n", " return torch_param\n", "\n", " @staticmethod\n", " def linear_params_permute(name, torch_param):\n", " if name == \"weight\":\n", " return torch_param.permute(1, 0)\n", " return torch_param\n", "\n", " @staticmethod\n", " def default_params_transform(name, param):\n", " return param\n", "\n", " modules_mapping_info = {\n", " nn.Conv2d: {\n", " \"type\": nnx.Conv,\n", " \"params_mapping\": {\n", " \"weight\": \"kernel\",\n", " \"bias\": \"bias\",\n", " },\n", " \"params_transform\": conv_params_permute,\n", " },\n", " nn.BatchNorm2d: {\n", " \"type\": nnx.BatchNorm,\n", " \"params_mapping\": {\n", " \"weight\": \"scale\",\n", " \"bias\": \"bias\",\n", " \"running_mean\": \"mean\",\n", " \"running_var\": \"var\",\n", " },\n", " },\n", " nn.Linear: {\n", " \"type\": nnx.Linear,\n", " \"params_mapping\": {\n", " \"weight\": \"kernel\",\n", " \"bias\": \"bias\",\n", " },\n", " \"params_transform\": linear_params_permute,\n", " },\n", " nn.LayerNorm: {\n", " \"type\": nnx.LayerNorm,\n", " \"params_mapping\": {\n", " \"weight\": \"scale\",\n", " \"bias\": \"bias\",\n", " },\n", " }\n", " } | {\n", " torch_mod: {\n", " \"type\": nnx_fn_type,\n", " \"params_mapping\": {},\n", " } for torch_mod, nnx_fn_type in [\n", " (nn.Identity, Identity),\n", " (nn.Flatten, type(lambda x: x)),\n", " (nn.ReLU, type(nnx.relu)),\n", " (nn.GELU, type(nnx.gelu)),\n", " (nn.SELU, type(nnx.selu)),\n", " (nn.SiLU, type(nnx.silu)),\n", " (nn.Tanh, type(nnx.tanh)),\n", " (nn.Dropout, nnx.Dropout),\n", " (nn.Sigmoid, type(nnx.sigmoid)),\n", " (nn.AvgPool2d, type(lambda x: nnx.avg_pool(x, (2, 2)))),\n", " (nn.AdaptiveAvgPool2d, type(lambda x: nnx.avg_pool(x, 
(x.shape[1], x.shape[2])))),\n", " ]\n", " }\n", "\n", " def _copy_params_buffers(self, torch_nn_module, nnx_module):\n", " torch_module_type = type(torch_nn_module)\n", " assert torch_module_type in self.modules_mapping_info, torch_module_type\n", " module_mapping_info = self.modules_mapping_info[torch_module_type]\n", " assert isinstance(nnx_module, module_mapping_info[\"type\"]), (\n", " nnx_module, type(nnx_module), module_mapping_info[\"type\"]\n", " )\n", "\n", " for torch_key, nnx_key in module_mapping_info[\"params_mapping\"].items():\n", "\n", " torch_value = getattr(torch_nn_module, torch_key)\n", " nnx_param = getattr(nnx_module, nnx_key)\n", " assert nnx_param is not None, (torch_key, nnx_key, nnx_module)\n", "\n", " if torch_value is None:\n", " assert nnx_param.value is None, nnx_param\n", " continue\n", "\n", " params_transform = module_mapping_info.get(\"params_transform\", Torch2Flax.default_params_transform)\n", " torch_value = params_transform(torch_key, torch_value)\n", "\n", " assert nnx_param.value.shape == torch_value.data.shape, (\n", " nnx_key, nnx_param.value.shape, torch_key, torch_value.data.shape\n", " )\n", " nnx_param.value = jnp.asarray(torch_value.data)\n", "\n", " def _copy_sequential(self, torch_nn_seq, nnx_seq, skip_modules=None):\n", " assert isinstance(torch_nn_seq, (nn.Sequential, nn.ModuleList)), type(torch_nn_seq)\n", " assert isinstance(nnx_seq, nnx.Sequential), type(nnx_seq)\n", " for i, torch_module in enumerate(torch_nn_seq):\n", " nnx_module = nnx_seq.layers[i]\n", " self.copy_module(torch_module, nnx_module, skip_modules=skip_modules)\n", "\n", " def copy_module(self, torch_module, nnx_module, skip_modules=None):\n", " if skip_modules is None:\n", " skip_modules = []\n", "\n", " if isinstance(torch_module, (nn.Sequential, nn.ModuleList)):\n", " self._copy_sequential(torch_module, nnx_module, skip_modules=skip_modules)\n", " elif type(torch_module) in self.modules_mapping_info:\n", " 
self._copy_params_buffers(torch_module, nnx_module)\n", " else:\n", " if skip_modules is not None:\n", " if torch_module.__class__.__name__ in skip_modules:\n", " return\n", "\n", " named_children = list(torch_module.named_children())\n", " assert len(named_children) > 0, type(torch_module)\n", " for name, torch_child in named_children:\n", " nnx_child = getattr(nnx_module, name, None)\n", " assert nnx_child is not None, (name, nnx_module)\n", " self.copy_module(torch_child, nnx_child, skip_modules=skip_modules)\n", " # Copy buffers and params of the module itself (not its children)\n", " for name, torch_buffer in torch_module.named_buffers():\n", " if \".\" in name:\n", " # This is child's buffer\n", " continue\n", " nnx_buffer = getattr(nnx_module, name)\n", " assert isinstance(nnx_buffer, nnx.Variable), (name, nnx_buffer, nnx_module)\n", "\n", " assert nnx_buffer.value.shape == torch_buffer.shape, (\n", " name, nnx_buffer.value.shape, torch_buffer.shape\n", " )\n", " nnx_buffer.value = jnp.asarray(torch_buffer)\n", "\n", " for name, torch_param in torch_module.named_parameters():\n", " if \".\" in name:\n", " # This is child's parameter\n", " continue\n", " nnx_param = getattr(nnx_module, name)\n", " assert isinstance(nnx_param, nnx.Param), (name, nnx_param, nnx_module)\n", "\n", " assert nnx_param.value.shape == torch_param.data.shape, (\n", " name, nnx_param.value.shape, torch_param.data.shape\n", " )\n", " nnx_param.value = jnp.asarray(torch_param.data)\n", "\n", "\n", "def test_modules(\n", " nnx_module, torch_module, torch_input, atol=1e-3, mode=\"eval\", permute_torch_input=True, device=\"cuda\"\n", "):\n", " assert torch_input.ndim == 4\n", " assert mode in (\"eval\", \"train\")\n", "\n", " torch_input = torch_input.to(device)\n", " torch_module = torch_module.to(device)\n", "\n", " if mode == \"eval\":\n", " torch_module.eval()\n", " nnx_module.eval()\n", " else:\n", " torch_module.train()\n", " nnx_module.train()\n", "\n", " with 
torch.inference_mode(mode=mode==\"eval\"):\n", " torch_output = torch_module(torch_input)\n", "\n", " if permute_torch_input:\n", " torch_input = torch_input.permute(0, 2, 3, 1)\n", "\n", " jax_input = jnp.asarray(torch_input, device=jax.devices(device)[0])\n", " jax_output = nnx_module(jax_input)\n", " assert jax_output.device == jax.devices(device)[0]\n", "\n", " torch_output = torch_output.detach()\n", " if permute_torch_input and torch_output.ndim == 4:\n", " torch_output = torch_output.permute(0, 2, 3, 1)\n", " jax_expected = jnp.asarray(torch_output)\n", "\n", " assert jnp.allclose(jax_output, jax_expected, atol=atol), (\n", " jnp.abs(jax_output - jax_expected).max(),\n", " jnp.abs(jax_output - jax_expected).mean(),\n", " )\n", "\n", "\n", "t2f = Torch2Flax()" ] }, { "cell_type": "markdown", "id": "a323a19e-fc64-4f8d-8be2-c7886b6191b9", "metadata": {}, "source": [ "Let us now test our JAX modules. We only compare the outputs of the forward pass in inference mode, to avoid discrepancies caused by random layers such as `Dropout` and `StochasticDepth`.\n", "By default, we use an absolute error tolerance of `1e-3` when comparing the JAX output against the expected PyTorch result.\n", "For larger modules, we pass `device=\"cpu\"` so that both the PyTorch and the JAX models run on CPU, which reduces the numerical differences between CPU and CUDA execution."
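, "\n",
 "\n",
 "Why is a tolerance needed at all? The two frameworks compute the same operations but may sum terms in a different order, and float32 addition is not associative. The small NumPy sketch below (independent of PyTorch and JAX) illustrates the effect. For a real model, the size of the discrepancy presumably depends on depth, activation magnitudes and the kernels each backend picks, which is consistent with the looser tolerance used later for the full `MaxVit` model.\n",
 "\n",
 "```python\n",
 "import numpy as np\n",
 "\n",
 "# Floating-point addition is not associative: in float32, 1e8 + 1 rounds back to 1e8\n",
 "a, b, c = np.float32(1e8), np.float32(1.0), np.float32(-1e8)\n",
 "print((a + b) + c)  # 0.0\n",
 "print((a + c) + b)  # 1.0\n",
 "\n",
 "# The same effect accumulates in long reductions (dot products, convolutions),\n",
 "# whose summation order can differ between backends and devices\n",
 "rng = np.random.default_rng(0)\n",
 "x = rng.normal(size=4096).astype(np.float32)\n",
 "w = rng.normal(size=4096).astype(np.float32)\n",
 "forward = np.float32(0.0)\n",
 "for xi, wi in zip(x, w):\n",
 "    forward += xi * wi\n",
 "backward = np.float32(0.0)\n",
 "for xi, wi in zip(x[::-1], w[::-1]):\n",
 "    backward += xi * wi\n",
 "diff = abs(float(forward) - float(backward))\n",
 "print(diff)  # small, but typically not exactly zero\n",
 "```"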
] }, { "cell_type": "code", "execution_count": 35, "id": "2e10fb43-6ae6-47c7-81fe-5027b115b25f", "metadata": {}, "outputs": [], "source": [ "from torchvision.ops.misc import Conv2dNormActivation as PyTorchConv2dNormActivation\n", "\n", "\n", "torch_module = PyTorchConv2dNormActivation(32, 64, 3, 2, 1)\n", "nnx_module = Conv2dNormActivation(32, 64, 3, 2, 1)\n", "\n", "t2f.copy_module(torch_module, nnx_module)\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))" ] }, { "cell_type": "code", "execution_count": 36, "id": "7ac4b49a-712f-4725-a8d8-33e72b8d0b66", "metadata": {}, "outputs": [], "source": [ "from torchvision.ops.misc import SqueezeExcitation as PyTorchSqueezeExcitation\n", "\n", "\n", "torch_module = PyTorchSqueezeExcitation(32, 4)\n", "nnx_module = SqueezeExcitation(32, 4)\n", "\n", "t2f.copy_module(torch_module, nnx_module)\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))" ] }, { "cell_type": "code", "execution_count": 37, "id": "746c8882-0001-4c97-b5cf-576dc5c87c02", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "from functools import partial\n", "from torchvision.models.maxvit import MBConv as PyTorchMBConv\n", "\n", "\n", "norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)\n", "torch_module = PyTorchMBConv(32, 64, 4, 0.25, 2, activation_layer=nn.GELU, norm_layer=norm_layer)\n", "\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "nnx_module = MBConv(32, 64, 4, 0.25, 2, activation_layer=nnx.gelu, norm_layer=norm_layer)\n", "\n", "\n", "t2f.copy_module(torch_module, nnx_module)\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))" ] }, { "cell_type": "code", "execution_count": 38, "id": "249f6d28-57b6-4d36-9079-cd60964e6afc", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import RelativePositionalMultiHeadAttention as PyTorchRelativePositionalMultiHeadAttention\n", "\n", "\n", "torch_module 
= PyTorchRelativePositionalMultiHeadAttention(64, 16, 49)\n", "nnx_module = RelativePositionalMultiHeadAttention(64, 16, 49)\n", "\n", "t2f.copy_module(torch_module, nnx_module)\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 32, 49, 64), permute_torch_input=False)" ] }, { "cell_type": "code", "execution_count": 39, "id": "f48fc475-c556-4101-ad2b-19480a73c6ba", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import PartitionAttentionLayer as PyTorchPartitionAttentionLayer\n", "\n", "\n", "grid_size = (224, 224)\n", "for partition_type in [\"window\", \"grid\"]:\n", "\n", " torch_module = PyTorchPartitionAttentionLayer(\n", " 36, 6, 7, partition_type, grid_size=grid_size, mlp_ratio=4,\n", " activation_layer=nn.GELU, norm_layer=nn.LayerNorm,\n", " attention_dropout=0.4, mlp_dropout=0.3, p_stochastic_dropout=0.2,\n", " )\n", "\n", " nnx_module = PartitionAttentionLayer(\n", " 36, 6, 7, partition_type, grid_size=grid_size, mlp_ratio=4,\n", " activation_layer=nnx.gelu, norm_layer=nnx.LayerNorm,\n", " attention_dropout=0.4, mlp_dropout=0.3, p_stochastic_dropout=0.2,\n", " )\n", "\n", " t2f.copy_module(torch_module, nnx_module, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", " ])\n", "\n", " test_modules(nnx_module, torch_module, torch.randn(4, 36, 224, 224))" ] }, { "cell_type": "code", "execution_count": 40, "id": "7ab2de6a-9790-444b-b175-535cdb05f5d8", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import MaxVitLayer as PyTorchMaxVitLayer\n", "\n", "\n", "stride = 2\n", "\n", "grid_size = _get_conv_output_shape((224, 224), kernel_size=3, stride=2, padding=1)\n", "norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)\n", "\n", "torch_module = PyTorchMaxVitLayer(\n", " 36, 36, squeeze_ratio=0.25, expansion_ratio=4,\n", " stride=stride, norm_layer=norm_layer, activation_layer=nn.GELU,\n", " head_dim=6, mlp_ratio=4, mlp_dropout=0.5,\n", " 
attention_dropout=0.4, p_stochastic_dropout=0.3,\n", " partition_size=7, grid_size=grid_size,\n", ")\n", "\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "nnx_module = MaxVitLayer(\n", " 36, 36, squeeze_ratio=0.25, expansion_ratio=4,\n", " stride=stride, norm_layer=norm_layer, activation_layer=nnx.gelu,\n", " head_dim=6, mlp_ratio=4, mlp_dropout=0.5,\n", " attention_dropout=0.4, p_stochastic_dropout=0.3,\n", " partition_size=7, grid_size=grid_size,\n", ")\n", "\n", "t2f.copy_module(torch_module, nnx_module, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", "])\n", "\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 36, 224, 224), device=\"cpu\")" ] }, { "cell_type": "code", "execution_count": 41, "id": "e8e8f997-0184-4af2-82b0-ceaa580645c8", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import MaxVitBlock as PyTorchMaxVitBlock\n", "\n", "\n", "norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)\n", "torch_module = PyTorchMaxVitBlock(\n", " 64, 128, squeeze_ratio=0.25, expansion_ratio=4,\n", " norm_layer=norm_layer, activation_layer=nn.GELU,\n", " head_dim=32, mlp_ratio=4, mlp_dropout=0.0, attention_dropout=0.0,\n", " partition_size=7, input_grid_size=(56, 56),\n", " n_layers=2,\n", " p_stochastic=[0.13333333333333333, 0.2],\n", ")\n", "\n", "norm_layer = partial(nnx.BatchNorm, epsilon=1e-3, momentum=0.99)\n", "nnx_module = MaxVitBlock(\n", " 64, 128, squeeze_ratio=0.25, expansion_ratio=4,\n", " norm_layer=norm_layer, activation_layer=nnx.gelu,\n", " head_dim=32, mlp_ratio=4, mlp_dropout=0.0, attention_dropout=0.0,\n", " partition_size=7, input_grid_size=(56, 56),\n", " n_layers=2,\n", " p_stochastic=[0.13333333333333333, 0.2],\n", ")\n", "\n", "t2f.copy_module(torch_module, nnx_module, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", "])\n", "\n", "test_modules(nnx_module, 
torch_module, torch.randn(4, 64, 56, 56), device=\"cpu\")" ] }, { "cell_type": "markdown", "id": "e313819a-e93a-4201-806d-783bd1336c78", "metadata": {}, "source": [ "Finally, we can check the full MaxVit model. Note that we raised the absolute tolerance to `1e-1` when comparing the JAX output logits against the expected PyTorch logits." ] }, { "cell_type": "code", "execution_count": 42, "id": "e2af63a4-b16b-40a3-ac00-bfc23d532c82", "metadata": {}, "outputs": [], "source": [ "from torchvision.models.maxvit import MaxVit as PyTorchMaxVit\n", "\n", "\n", "torch.manual_seed(77)\n", "\n", "\n", "torch_module = PyTorchMaxVit(\n", " input_size=(224, 224),\n", " stem_channels=64,\n", " block_channels=[64, 128, 256, 512],\n", " block_layers=[2, 2, 5, 2],\n", " head_dim=32,\n", " stochastic_depth_prob=0.2,\n", " partition_size=7,\n", " num_classes=1000,\n", ")\n", "\n", "nnx_module = MaxVit(\n", " input_size=(224, 224),\n", " stem_channels=64,\n", " block_channels=[64, 128, 256, 512],\n", " block_layers=[2, 2, 5, 2],\n", " head_dim=32,\n", " stochastic_depth_prob=0.2,\n", " partition_size=7,\n", " num_classes=1000,\n", ")\n", "\n", "t2f.copy_module(torch_module, nnx_module, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", "])\n", "\n", "\n", "test_modules(nnx_module, torch_module, torch.randn(4, 3, 224, 224), device=\"cpu\", atol=1e-1)" ] }, { "cell_type": "markdown", "id": "0d3c4f4d-2a50-46f4-814c-42c1a423cfd0", "metadata": {}, "source": [ "### Check Flax model\n", "Let us now load the pretrained weights of TorchVision's MaxViT model into our Flax model and compare the output logits and predictions on our example image:" ] }, { "cell_type": "code", "execution_count": 43, "id": "7975f311-7a02-4c82-99db-b0b50fb37528", "metadata": {}, "outputs": [], "source": [ "from torchvision.models import maxvit_t as pytorch_maxvit_t, MaxVit_T_Weights\n", "\n", "torch_model = pytorch_maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)\n", "flax_model = maxvit_t()\n", 
"\n", "t2f = Torch2Flax()\n", "t2f.copy_module(torch_model, flax_model, skip_modules=[\n", " \"WindowPartition\", \"WindowDepartition\", \"SwapAxes\", \"StochasticDepth\",\n", "])" ] }, { "cell_type": "code", "execution_count": 44, "id": "922cc4b5-f181-4865-9043-fd3b56bafe43", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction for the Dog:\n", "- PyTorch model result: ['n02113023', 'Pembroke'], score: 0.7800846099853516\n", "- Flax model result: ['n02113023', 'Pembroke'], score: 0.7799879908561707\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAE4CAYAAABxMwiDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9ebxlV1nm/13THs6599aQVEZIgBBIiAjIoNB2EiQMQUTAgCjIIBpURAVtFPqHAXFgkkRBFBVFgaYFFNsJlEZsbGUQlEEFjQiIdsZKqu5wzt57De/vj3fXhaISUokJVUX2k8/9QO2zzx7OWfs5a73D8xgRESZMmDBhwoQJE44i2CN9ARMmTJgwYcKECV+KaYIyYcKECRMmTDjqME1QJkyYMGHChAlHHaYJyoQJEyZMmDDhqMM0QZkwYcKECRMmHHWYJigTJkyYMGHChKMO0wRlwoQJEyZMmHDUYZqgTJgwYcKECROOOkwTlAkTJkyYMGHCUYevmgnK0572NIwxGGP4mq/5miN9ORNup/iLv/gLjDG8/e1vv83P9dnPfhZjDK985Stvk+Pv27dv+5m6Lc9zJDHxxoSjARNv3DC+aiYoAMcffzxvfOMbeelLX3rQ9jvd6U686EUvusXH/eQnP8kjHvEIVlZW2L17N9/1Xd/FNddcc9A+n/rUp3je857Hve99b1ZXVzn55JP55m/+Zj784Q8fcrx/+qd/4jnPeQ4PetCDaJoGYwyf/exnb/Dcv/M7v8OTn/xkzjzzTIwxnH/++Te43z/8wz/w+Mc/nrvc5S7MZjOOP/54zj33XP7wD//wK3JP73jHO3j4wx/OKaecQl3X3OEOd+Ciiy7i7//+7w/Z1xjDG97whhu8rpvC+eefz9Oe9rTtfx942A78Oec47bTTeOxjH8tHP/rRW3SO2yue9rSnHTS+5vM5b3zjG7n00kuP3EV9BTDxxsQbE2/cctyWvOH/00c4ijCfz3nyk598qx7z3//93zn33HPZsWMHP/uzP8vm5iavfOUr+cQnPsGHPvQhqqoC4Nd//dd5/etfz7d927fxAz/wA+zfv5/Xve51fMM3fAPvete7uOCCC7aP+f73v59f/MVf5B73uAdnn332l30gfvmXf5mPfOQj3P/+92fv3r03ut/nPvc5NjY2eOpTn8opp5zCYrHgd3/3d3n0ox/N6173Oi6++OLb9J4+8YlPsGvXLn74h3+Y448/niuvvJLf+I3f4AEPeADvf//7ude97nVLv4LDwnd8x3fwyEc+kpwzn/zkJ/nlX/5l3vnOd/KBD3yAe9/7
3rfpub9aEULgyU9+Mp/97Gd5znOec6Qv5zbDxBsTb0y8cevhVuUN+SrBU5/6VDn99NNv8LXTTz9dLrnkklt03O///u+Xtm3lc5/73Pa2d7/73QLI6173uu1tH/7wh2VjY+Og91577bWyZ88e+S//5b8ctH3v3r2yvr4uIiKveMUrBJDPfOYzN3j+f/u3f5Ocs4iInHPOOXLeeecd9rWnlORe97qX3P3ud7/N7+mGcOWVV4r3Xp75zGcetB2Q3/zN3zzs+/hinHfeefLUpz51+9+f+cxnBJBXvOIVB+33B3/wBwLIxRdffIvOc0vx3ve+VwB529vedovev7W1ddj73ti931I89alPvcHxdWuf52jCxBuHYuKNiTduDm5L3viqSvEcLt7whjdgjOGv/uqveO5zn8uePXuYz+c89rGPPSRc+bu/+7s86lGP4rTTTtvedsEFF3C3u92Nt771rdvb7nvf+7KysnLQe4877jj+63/9r3zyk588aPvu3btZXV09rGu94x3viLW37GtyznHHO96Rffv2HbT9trinG8IJJ5zAbDY75PxfCXzTN30TAJ/5zGe2t33wgx/kEY94BDt27GA2m3HeeefxV3/1Vwe970UvehHGGP75n/+ZJz/5yezYsYM9e/bwwhe+EBHh85//PN/6rd/K2toaJ510Ej//8z9/g+fPOfOCF7yAk046ifl8zqMf/Wg+//nPH7TP+eefz9d8zdfwkY98hHPPPZfZbMYLXvACAK6++mqe8YxncOKJJ9I0Dfe61734rd/6rZu8bxHh4osvpqoqfu/3fm97+5ve9Cbue9/70rYtu3fv5olPfOIh1zPhy2PijYk3Jt74yvLG7XKCcgDPfvaz+djHPsYll1zC93//9/OHf/iH/OAP/uD26//xH//B1Vdfzf3ud79D3vuABzyAv/u7v7vJc1x55ZUcf/zxt+p1fzlsbW1x7bXX8ulPf5pLL72Ud77znTzkIQ/Zfv22vqd9+/ZxzTXX8IlPfILv+Z7vYX19/aDzf6Xw6U9/GlBiBPjzP/9zzj33XNbX17nkkkv42Z/9Wfbt28c3fdM38aEPfeiQ93/7t387pRRe+tKX8vVf//X89E//NJdddhkPfehDOfXUU3nZy17GXe96V37sx36M973vfYe8/2d+5mf44z/+Y378x3+cH/qhH+Ld7343F1xwAcvl8qD99u7dy4UXXsi9731vLrvsMh784AezXC45//zzeeMb38iTnvQkXvGKV7Bjxw6e9rSn8Qu/8As3es85Z572tKfx27/927zjHe/gcY973Pa1POUpT+HMM8/kVa96FT/yIz/Ce97zHs4999wj8iNwrGPijYMx8cYXMPHGrYxbHHs5yvDlQrVfit/8zd8UQC644AIppWxvf85zniPOOdm3b5+IiPzN3/yNAPLbv/3bhxzjv/23/yaAdF13o+d53/veJ8YYeeELX3ij+9xUqPaLcTih2mc+85kCCCDWWrnooovkuuuu2379tr6nu9/97tvnX1lZkf/v//v/tkPNtwUOhBFf/OIXyzXXXCNXXnml/MVf/IXc5z73EUB+93d/V0opcuaZZ8rDH/7wg77vxWIhd77zneWhD33o9rZLLrnkkBBvSknucIc7iDFGXvrSl25vv/7666Vt24NCxwdCtaeeeup2OF5E5K1vfasA8gu/8Avb28477zwB5Fd+5VcOuqfLLrtMAHnTm960vW0YBnngAx8oKysr28f94hBqjFG+/du/Xdq2lT/90z/dft9nP/tZcc7Jz/zMzxx0jk984hPivT9k+5f7jG9vKZ4vxcQbE2+ITLzxleSN23UE5eKLL8YYs/3v//pf/ys5Zz73uc8BbM9a67o+5L1N0xy0z5fi6quv5ju/8zu5853vzPOe97xb+9JvFD/yIz/Cu9/9bn7rt36LCy+8kJwzwzBsv35b39Nv/uZv8q53vYvXvva1nH322SyXS3LO/9nbuklccskl7Nmzh5NOOonzzz+f
T3/607zsZS/jcY97HB/96Ee5/PLL+c7v/E727t3Ltddey7XXXsvW1hYPechDeN/73kcp5aDjfc/3fM/2/3fOcb/73Q8R4RnPeMb29p07d3L3u9+df/3Xfz3kep7ylKccFI6/6KKLOPnkk/mTP/mTg/ar65qnP/3pB237kz/5E0466SS+4zu+Y3tbCIEf+qEfYnNzk//zf/7PQfsPw8DjH/94/uiP/og/+ZM/4WEPe9j2a7/3e79HKYUnPOEJ2/d97bXXctJJJ3HmmWfy3ve+93A+3glfhIk3DsbEG1/AxBu3Lr6qunhuLr44lwqwa9cuAK6//noA2rYFoO/7Q97bdd1B+3wxtra2eNSjHsXGxgb/9//+30PysbclzjrrLM466yxAB/vDHvYwvuVbvoUPfvCDGGNu83t64AMfuP3/n/jEJ3L22WcD3OYaGhdffDGPf/zjsdayc+dOzjnnnG0yvfzyywF46lOfeqPv379///b3D4eOjR07dtA0zSEh6h07dtxgl8SZZ5550L+NMdz1rnc9pC301FNP3e5+OIDPfe5znHnmmYfUEBz4LA/8EB7Az/3cz7G5uck73/nOQ9pJL7/8ckTkkOs5gBDCDW6fcOOYeONgTLwx8cZthdv1BMU5d4PbRQSAk08+GYArrrjikH2uuOIKdu/efciKYhgGHve4x/Hxj3+cP/3TPz3i4k8XXXQRz3zmM/nnf/5n7n73u39F72nXrl180zd9E29+85tvc6I588wzD2pf/GIcWOW84hWvuNHWwS8lzhsaGzc1Xm4JbojUby4e/vCH8653vYuXv/zlnH/++dsrWtB7N8bwzne+8wav/yv5I/jVgok3DsbEG1/AxBu3Lm7XE5SbwqmnnsqePXtuUGDoQx/60CGDtpTCU57yFN7znvfw1re+lfPOO+8rdKU3jgNh1/379wNf+XtaLpfb5z5SOOOMMwBYW1u7UTK6tXFg9XUAIsK//Mu/8LVf+7U3+d7TTz+dj3/845RSDloNfepTn9p+/YvxDd/wDXzf930fj3rUo3j84x/PO97xDrzXR/uMM85ARLjzne/M3e52t//sbU04DEy8MfHGLcXEGwfjdl2Dcjj4tm/7Nv7oj/7ooNaq97znPfzzP/8zj3/84w/a99nPfja/8zu/w2tf+9rtSuivFK6++upDtsUY+e3f/m3atuUe97jH9vbb4p5u6Pyf/exnec973nODlf9fSdz3vvfljDPO4JWvfCWbm5uHvP6lLaK3Bn77t3+bjY2N7X+//e1v54orruDCCy+8yfc+8pGP5Morr+R3fud3trellHj1q1/NysrKDZL9BRdcwP/8n/+Td73rXXzXd33X9urvcY97HM45XvziFx+yYhORLyviNeGWY+KNiTduCSbeOBhTBOUm8IIXvIC3ve1tPPjBD+aHf/iH2dzc5BWveAX3vOc9DypSuuyyy3jta1/LAx/4QGazGW9605sOOs5jH/tY5vM5oKuSV7/61QDb/fSvec1r2LlzJzt37jyoZfF973vfdjvaNddcw9bWFj/90z8NwLnnnsu5554LwDOf+UzW19c599xzOfXUU7nyyit585vfzKc+9Sl+/ud//qCQ3G1xT/e85z15yEMewr3vfW927drF5Zdfzutf/3pijIdIiN8QjDGcd955/MVf/MVN7ntzYa3l13/917nwwgs555xzePrTn86pp57Kf/zHf/De976XtbW1G5X2vqXYvXs33/iN38jTn/50rrrqKi677DLuete78r3f+703+d6LL76Y173udTztaU/jIx/5CHe60514+9vfzl/91V9x2WWX3agWxmMe8xh+8zd/k6c85Smsra3xute9jjPOOIOf/umf5vnPfz6f/exnecxjHsPq6iqf+cxneMc73sHFF1/Mj/3Yj92q9z5h4o2JN24ZJt74Etzi/p+jDLekXfBv/uZvDtp+oNXrve9970Hb//7v/14e9rCHyWw2k507d8qTnvQkufLKKw85
P2Ob3A39fXE74IH2qxv6+9J7ONC+dkN/X6xy+Za3vEUuuOACOfHEE8V7L7t27ZILLrhA/tf/+l83+Bnc2vd0ySWXyP3udz/ZtWuXeO/llFNOkSc+8Yny8Y9//Mt/GSKysbEhgDzxiU+8yX2/FDenle3v/u7v5HGPe5wcd9xxUte1nH766fKEJzxB3vOe9xx0H4Bcc801B733qU99qszn80OOed5558k555yz/e8DY+gtb3mLPP/5z5cTTjhB2raVb/7mbz5IgfOG3vvFuOqqq+TpT3+6HH/88VJVldzznvc8REXzxu79ta99rQDyYz/2Y9vbfvd3f1e+8Ru/UebzuczncznrrLPkWc96lvzTP/3Tl//Qvsx5vhow8cbEGzeFiTeOHG8Ykf9Epc5RhKc97Wn8+Z//OX/7t3+L956dO3ce6UuacJj4kz/5Ex71qEfxsY99jHve855H+nImjJAxlPv5z3+er/u6r+MVr3jFV120ZeKNYxcTbxyduDV546sqxfP5z3+ePXv2cM4559ygG+aEoxPvfe97eeITnziRzFGG/fv3s2fPniN9Gbc5Jt44NjHxxtGJW5M3vmoiKP/4j//I//t//w/QFqhv+IZvOMJXNGHCsY2U0kG5/bvd7W6H6Dwc65h4Y8KEWxe3Jm981UxQJkyYMGHChAlfPZjajCdMmDBhwoQJRx2mCcqECRMmTJgw4ajDNEGZMGHChAkTJhx1mCYoEyZMmDBhwoSjDtME5SjGJz7xCS666CJOP/10mqbh1FNP5aEPfei2muRXC/76r/+ab/zGb2Q2m3HSSSdt24PfFN7whjdgjLnRvze/+c0H7f+///f/5sEPfjDHH388O3fu5AEPeABvfOMbb/DYr3/96zn77LNpmoYzzzzzBj/zf/qnf+I5z3kOD3rQg2iaBmPMIa6jEyYcCUzc8eVxpLkD4D/+4z94whOewM6dO1lbW+Nbv/Vb+dd//deb/yF8FWPq4jlK8dd//dc8+MEP5rTTTuOpT30qJ510Ep///Of5wAc+wKc//Wn+5V/+5Uhf4q2Cj370ozzwgQ/k7LPP5uKLL+bf//3feeUrX8mDH/xg3vnOd37Z9/7rv/4rf/3Xf33I9ksvvZSPfexj/Pu//zsnnXQSAH/wB3/AYx7zGB74wAfyHd/xHRhjeOtb38r73vc+XvWqV/Gc5zxn+/2ve93r+L7v+z6+7du+jYc//OH85V/+JW984xt56Utfyo//+I9v7/eGN7yBZzzjGdzjHvfAe89HP/pRPvOZz3CnO93p1vlwJky4BZi44+jnjs3NTb7u676O/fv386M/+qOEELj00ksRET760Y9y3HHH3Uqf0jGOW6xBO+E2xSMf+UjZs2ePXH/99Ye8dtVVV31Fr2Vra+s2O/aFF14oJ598suzfv39726/92q8JIH/6p396s4+3WCxkdXVVHvrQhx60/aEPfaiccsop0nXd9rYYo5xxxhnytV/7tQe9/7jjjpNv/uZvPuj9T3rSk2Q+n8t11123vW3v3r2yvr4uIiKveMUrDpHxnjDhSGDijqOfO172spcJIB/60Ie2t33yk58U55w8//nPv9nX/tWKKcVzlOLTn/4055xzzg1Kb59wwgmHbHvTm97EAx7wAGazGbt27eLcc8/lz/7szw7a57WvfS3nnHMOdV1zyimn8KxnPYt9+/YdtM/555/P13zN1/CRj3yEc889l9lsxgte8AIA+r7nkksu4a53vSt1XXPHO96R5z3vefR9f9Axrr32Wj71qU+xWCy+7D2ur6/z7ne/myc/+cmsra1tb3/KU57CysoKb33rW7/s+28If/iHf8jGxgZPetKTDjnXrl27qOt6e5v3nuOPP562bbe3vfe972Xv3r38wA/8wEHvf9aznsXW1hZ//Md/vL1t9+7dN2rANWHCkcLEHUc/d7z97W/n/ve/P/e///23t5111lk85CEPuUXX/tWKaYJylOL000/nIx/5yGFJb7/4xS/m
u77ruwgh8FM/9VO8+MUv5o53vCN//ud/vr3Pi170Ip71rGdxyimn8PM///N827d9G6973et42MMeRozxoOPt3buXCy+8kHvf+95cdtllPPjBD6aUwqMf/Whe+cpX8i3f8i28+tWv5jGPeQyXXnop3/7t337Q+1/zmtdw9tln86EPfejLXvcnPvEJUkqH2KpXVcW9731v/u7v/u4m7/1L8eY3v5m2bQ+xeD///PP5h3/4B174whfyL//yL3z605/mJS95CR/+8Id53vOet73fgXN+6TXd9773xVp7i65pwoSvJCbuOLq5o5TCxz/+8UP2A3jAAx7Apz/9aTY2Nm729X9V4kiHcCbcMP7sz/5MnHPinJMHPvCB8rznPU/+9E//VIZhOGi/yy+/XKy18tjHPlZyzge9VkoREZGrr75aqqqShz3sYQft85rXvEYA+Y3f+I3tbeedd54A8iu/8isHHeuNb3yjWGvlL//yLw/a/iu/8isCyF/91V9tbzvg7Pml7q5fire97W0CyPve975DXnv84x8vJ5100pd9/5di7969UlWVPOEJTzjktc3NTXnCE54gxphtV9XZbCa///u/f9B+z3rWs8Q5d4PH37Nnz406p04pnglHCybuOLq545prrhFAfuqnfuqQ/X7pl35JAPnUpz51s67/qxVTBOUoxUMf+lDe//738+hHP5qPfexjvPzlL+fhD384p556Kn/wB3+wvd/v//7vU0rhJ3/yJ7H24K/TGANoBfowDPzIj/zIQft87/d+L2traweFHgHquubpT3/6Qdve9ra3cfbZZ3PWWWdx7bXXbv990zd9E6DhzQN40YtehIhw/vnnf9l7XC6X2+f7UjRNs/364eLtb387wzAcEqI9cI673e1uXHTRRbzlLW/hTW96E/e73/148pOfzAc+8IGDrqmqqhs8/i25pgkTvtKYuOPo5o6buvYv3uf2jq8qN+OvNtz//vfn937v9xiGgY997GO84x3v4NJLL+Wiiy7iox/9KPe4xz349Kc/jbWWe9zjHjd6nM997nMA3P3udz9oe1VV3OUud9l+/QBOPfXUQx60yy+/nE9+8pM36lJ59dVX3+z7O5C//dI8NEDXdQfldw8Hb37zm9m9ezcXXnjhIa/94A/+IB/4wAf427/9222ifcITnsA555zDD//wD/PBD35w+5qGYbjB49+Sa5ow4Uhg4o6jlztu6tq/eJ/bO6YJyjGAqqq2C6rudre78fSnP523ve1tXHLJJbfJ+W7o4SilcM973pNXvepVN/ieO97xjjf7PCeffDIAV1xxxSGvXXHFFZxyyimHfax/+7d/4y//8i+5+OKLCSEc9NowDLz+9a/nec973kGrwBACF154Ia95zWsYhoGqqjj55JPJOXP11VcfVFA4DAN79+69Wdc0YcKRxsQdN42vNHfs3r2buq5v9NqBiWdGTBOUYwwHCqsODOQzzjiDUgr/+I//yL3vfe8bfM/pp58OqLDYXe5yl+3twzDwmc98hgsuuOAmz3vGGWfwsY99jIc85CHb4d//LL7ma74G7z0f/vCHecITnnDQdX30ox89aNtN4S1veQsicoMh2r1795JSIud8yGsxRkop268d+Aw//OEP88hHPnJ7vw9/+MOUUm70M54w4WjHxB03jK80d1hruec978mHP/zhQ475wQ9+kLvc5S5Td+ABHNkSmAk3hj//8z/fLlT7Yhzon3/Vq14lIjev0O0Rj3jEQcd87Wtfe4OFbuecc84h533DG94ggLzuda875LXFYiGbm5vb/77mmmvkk5/85GFpIDziEY+Qk08+eVtPRETk13/91wWQd77zndvbtra25JOf/KRcc801N3icr/3ar5XTTjvtBj+zlJLs3LlT7na3u0nf99vbNzY25A53uIOcddZZB93L7t275VGPetRBx3jyk58ss9lM9u7de4Pnn4pkJxwtmLjj6OeOl770pQLI3/zN32xv+9SnPiXOOfnxH//x
m7z32wumCcpRinPOOUfufOc7y3Of+1z51V/9VXnNa14j3/md3ynOObnTne50kAjTC1/4QgHkQQ96kLzyla+UV7/61fKUpzxFfuInfmJ7nwPV8Q972MPkNa95jTz72c8W55zc//73P6i6/8ZIJucsj3zkI8UYI0984hPl1a9+tVx22WXyfd/3fbJ79+6DHrTDrcQXEfnIRz4idV3Lfe5zH/nlX/5l+e///b9L0zTysIc97KD93vve9wogl1xyySHH+MQnPiHAQff7pfjpn/5pAeQ+97mPXHrppfLKV75Szj77bAHkTW9600H7Hqikv+iii+TXfu3X5ClPeYoA8jM/8zMH7bdv3z55yUteIi95yUvkEY94hADyoz/6o/KSl7xEXv3qV9/kvU+YcFtg4o6jnzvW19fljDPOkBNOOEFe/vKXy6WXXip3vOMd5ZRTTpGrr776Ju/99oJpgnKU4p3vfKd893d/t5x11lmysrIiVVXJXe96V3n2s599g2qQv/EbvyH3uc99pK5r2bVrl5x33nny7ne/+6B9XvOa18hZZ50lIQQ58cQT5fu///sPUZu8MZIRERmGQV72spfJOeecs32e+973vvLiF7/4IDXHm0MyIiJ/+Zd/KQ960IOkaRrZs2ePPOtZzzpoVSTy5UnmJ37iJwSQj3/841/2PG9+85vlAQ94gOzcuVPatpWv//qvl7e//e03uO+v/uqvyt3vfnepqkrOOOMMufTSSw9ZYX3mM5/Zbjv80r/TTz/9sO59woRbGxN3HP3cISLy+c9/Xi666CJZW1uTlZUVedSjHiWXX375Yd337QWTF8+ECRMmTJgw4ajDpIMyYcKECRMmTDjqME1QJkyYMGHChAlHHaYJyoQJEyZMmDDhqMMRnaD80i/9Ene6051omoav//qvv0mDqAkTJkyYeGPChNsHjtgE5Xd+53d47nOfyyWXXMLf/u3fcq973YuHP/zht0j2eMKECbcPTLwxYcLtB0esi+frv/7ruf/9789rXvMaQOWQ73jHO/LsZz+bn/iJnzgSlzRhwoSjHBNvTJhw+8ERkbofhoGPfOQjPP/5z9/eZq3lggsu4P3vf/8h+/d9f5CxUimF6667juOOO+5Wk06eMGHCzYOIsLGxwSmnnHKIG+5tgZvLGzBxx4QJRxtuDm8ckQnKtddeS86ZE0888aDtJ554Ip/61KcO2f/nfu7nePGLX/yVurwJEybcDHz+85/nDne4w21+npvLGzBxx4QJRysOhzeOCbPA5z//+Tz3uc/d/vf+/fs57bTTeMD5/4Xdu1c5fs8uTjxxD0NKbC07gvPsv34/n/+Hz7J1/Qa7T93JaWeezo49O6mbhpwLJRe2Fguuv36DvdfsY/++TYYh4oDcDbgizELNfEfLbPcMKkfKGe88wVcYDEMq9MNAignnLU0TqJuAdQ6RQuU8bdNgrME5T11VWOeIaWBIPSICYtSMikzwYK0BLGRDHBJ9v0RKpnYt89lOfO3BF3KJWKAONRiD2AImUSSSUqJkQykGiiVnkJIQMsYYnHUE50mSiUMkx0wpGWOEUiK2CGuhxVpHtjUm1Ii1WG+xYlj2A/2yx4gh+BoHxK0ly80FOQ/Y2uFWGqq6xThHygXrHbt2rLK6ska7uos9x92BebODlAcW/Tp7r7mSuFzS5wWFiGDoY0bEMK8aqsoS2kIpC5wXnDWknNjc7Njc3GLoM9Y6KEKMMHQZZyswFkFIOWMQ6rqlCS3eOYoIfRwoUvA+0DTN9ozeWQMIRQrdcsG+fRvs27fJEAd2rM4IvmLZ9SyXHRaDs4ZYMkNKGGuZtXNOPfU05m2LFGFIS2JZ4kNGzEAqicoFQqiom4ohLdm3WOf66/bTbxVsDjT1jNm8xYeAJFj2kVgKO3dX7NrdsLrSUPtAVVVYY+mGSEqZPBRiTNR1xWyu4w9TwEJwFjGFmCKpFDDCECOLbkm37DAYrA2UFCkIoapo6oAxhW7oWW4V9l3Xsbm/
I8bCO379fx3VxmYTd0zcMXHH0cUd11+zwTte/0eHxRtHZIJy/PHH45zjqquuOmj7VVddxUknnXTI/nVdU9f1IdtXV1rW1ubs3r2T+eoc1w/EUnAYTAZnDLNZy0kn7eGkU05g5/HHYaxlsVzSdz19P2CMoQhYH7BFkJxpm8BxKzPmdUuzc4aZBYqBxbLDGIsYQ8Hgq4D1jq7rGSSDs0ok1mIMNHVFqDwYg3OOuqowOHIRvBV8CHjnSamnlIE6OIyR8Uu2bLKkH3pAqOo5s/kq9axBbGaIS7wx1FWFmEIhk0pHKgXBYIzFZANGr6UUi7UFaw3eWrzz+CJYY8muYExBJDEsMz4XGrE4KkpokapGvCPmgRQFaxzeB72nuiX3EeMcBkvJgsOSh8JQBoqBDLSzGbN2jR1rx7Oy6wT2nHgKTZjR9Qu66zdwVSFLIUgBLxgLTXZY07BSzQm1pWohSwUSsVhSHKgqz8pqjTEW5yxDTHSdsH59pFtmShZyLnjvCF6vG2vJxqgLqTHUoaGdzaiqgLGGUgqUMo4yoeRMCPrd11VDE2qMMXgMtgAUXAjU8wqRQkyZugm0rdcfHWsIRTAmUDcQ6oKIULJusx6GDLbJOJ/pNoVhc/wOvcEGg3HgsqHvCovlwDxXRIFgLcVAqAKrVYNk6PqBoU847/C+BlMwtmCdxXsHBlyIpBRJFMQKFR7jWxCHLQbvGmzl8d7jrcVYoUmRqooYAj4E1vdvAXzFUiU3lzdg4o6JOybuONq4oxQtez0c3jgiE5Sqqrjvfe/Le97zHh7zmMcAmht+z3veww/+4A8e9nF87QnBUVUeIZNLRoyh7xP9YolIoW5rVnesMl9dYd7OAMNi0bG+vmBzfQEZvHcYBCeGkjKrq3N2r+3AOkv2FqxFcgYMOReMMwTvqapAKdDFSNd1mD5TV466rnEYrAEoxCw6oMVgsmWIA1VVMatnOOcolaUkizMWER14xVq8iXg8BUsVKqoqMG/nFFPwTq/JOINIoeRMKUIaCqUYnK2pfAPFUnIiDh2xdFh0YOQEOYMRi7EZ6wzDMhMXCTsI2UfCvMHUliSFrk9sLTpSTBjAOourHIVCzIkhDXQlsjl02BJxlcMETwHmK6vsWF1jbb6TptlB065iMfRRVz2L5V5i2gCbCCYRZhAqRykWiqGyECowJmGNQYolZyFng3WB2ThT90FXqotFJKcFy26LfojEIeOcgSLkVCil3x5zIQSCcSAFKYVchFwKRuCAtU7KhZKEIuCNxRsPRm3WU4ogQqgcdW2p6poCYDxFelIBbx1CIlSBqrI4WwALzpJSAUnYcbzUswDW4LwhdgUxBSHhnaOqLEMUjOh4yhS61FMoZBFqV4OAcx4fHDlnui6ii8MB5yyhrsBAyokYE2L0ubE4grdIsYAhhIAdf/SwBusdtXOkZAh1wi2gqcOtQwiHiVuLN2Dijok7Ju44YtxxM8rVjliK57nPfS5PfepTud/97scDHvAALrvsMra2tnj6059+2Mdom4YC9HGglhrrIaXI+v5N9u3bpOsj8xWP9QHnAgZDSpnFomd9/yaS9GGtvAfJxKHHRoEMznqKRQdBTORSyAUNY2Gogic4Q1cS1gveQ8yJzW6JsVD7gIjOwIcYiTHjSocVgzEwb1pW2jkhBERm9P0WwzAgAsG31PMW5zxp6MkxMm9amqbBOYcpQrCOJIIxBmscZYC+T+QCwdW01RpNvYIUoVssIGvI0QIl6uNTilBEsBZIhbSMdBsdPjnsqoECQ9+zHIT1bmDRRSi6QpmttIgxHGgC63NiICMebA2+NlSzgDGWHasrrK2s6YorR1LqWSz2k+np0iZdvB7r4jhwDU3lqOqACMQhk/ImeTBYCz54EA/ois4SqKpAXQe8t6Q0kLyhbjLe9XTLLWI/UNcByRkRQURn7t57DXFaS04JjP6gYBxZlGAMBeM01JtzApS0fbB4b7DOkIsQS6SYQKhrvPMUgVQWeClYU+Er
BzaTxnEUvMeaACSQrN+DtzS2pg6O0jq2NjrSIFijl+W90LaWdqWibj2hcnjrsBaETJc6TDEgliIWDBSEIplcIn0qmKHHGKs/yCJgM5iEcQZnHFiHGAMWvNMVtDCOG/R+jRd8sDR1desRwmHi1uANmLhj4o6JO44kdxwujtgE5du//du55ppr+Mmf/EmuvPJK7n3ve/Oud73rkAK4LwdfOYrA1qKjamsyQj9ENvZvsH79BkOXWFtziEA/DGx1HVubW6zv32C5HGhCIJdCShmTwEQYhkSMhSFmJAtiHCajuVox+qBa0RVHKSCRqhJWrKNkh3OAJEoxDH2PGOhypu8irhhqF1iZN9TeU/sK7ypiilShRYoDYD5fJYRAzj1tVROz4I3VHLBzGAopCt56QlUxxIFSBtJgyMlSz2e0s12szndoyBHHkCJ56BDAisFai2QhDpEimVIS3WZPWkbEgjWGVBLrWz0bMTEUgSJYq+FYyVBSIUlP7AYkF0Kl4dh2FmjainalQTBY68EKRSI5b7K5vqSUFucTy+UGVpYIghSDsx5JyiMGzX8PQ0EE6qrBGovzFu8ChQIj8Us2xCx0g7BcJGIn9IuIFHSlWQqlZEopB5FMKbqascViSsZYi3EerMU7MJLQxwtSzrqCkDHcXTsaaoY4YJyGfo1z1HWFAEOOpNJjSqGqKoyzgMW6AAZd3XhLFostFicaQg0hULwBqVgsIiIFQQjBUK1UGrJuDKE21MFjxSBFCbkUwVo9Rymi+WOTMTZTcibFgrEWa/W5sALG2e3UgrWWUgRjldwQJdci+rlXlaFpDGXN0/t8KzLC4eHW4A2YuGPijok7jhR3tKuHP+04okWyP/iDP3izQ7NfDB8CWYQ+Rbo+YsccWT8MDEPSge8sbQikIbK/bLK11bG5sSD1kU6UYIY+UlLBFNGipCEiOWsezmg+r5eM5EzfR6QKGt50BiHibGHmHFLpF2UQ0jDQLZRkxOggAE82IMYRU2K5WDKbBax1WFsjpeCcpZnpLF9I5NSTh4FcCtZ5xGqI0lmLDwHEYfBYU4HU5Byp61V27ziRdrZC1y9ZdgtCVRNSJKdELhmy6L13mX7oKSURtwZKFOzcUcaHaX1rwSDgqppQVxjJZBG6riPnjMPSbS0pKRIC1MEyb2tWVmfUbUMsQj/AYrFJcAZjHPSJ1BlcpaPcYchYStRCP7GCOEMG4mBIveBdjTct3vpxdm7JNpMGfXBECjlFFl1kfV9k4/qOEgveGLAO6wylJAQDzlBVNc55/a6LEEqhEkNwAW8d3nmcVWL1xWG9BRGGkhhiwtcGGyyzqqFGScN7r0WNwQFCLpY0RDAF5wRnK6z1VK5GOEAAgkmAeEq0dLGnKwnnA1UIyEwQY/U+ELw3OG8JtcOajCCaGy+CxVKSKLFaA95QxGCc0VWWs7goIBbnHBjBedExa4z+21nygRx6KaSsqY9SIGDxBurKEaMlh6/8BAX+87wBE3dM3DFxx5HijnAsRFBuDbSzls3FAmM192WNpfJBc7zBU/qELbDoltiuYShLFouBlDJFCjEWSikYC2J0wFfBUzvNK1rjSblgsjD0A8s+klIh5p5cCsGP+UFvaYLFiAVjxgcZUi5IydjgqUKFKRbrLCll+iHiXIfz1Vh8lHTmGYSYl6QcWSzX6YeOfhhYs555u0azMqMbKmJvoWi41bkKayNlrLqvqkYfwqz3WYoSZGU9g0kUAxRBSiHGyNAP6NJDZ9DOV2QsQ1LSa73HhgowdL12H4ixIJZMYrnsiH1H8IacK0KVWVnxVH6GNw4kkzMMQ481FkPCRsElS6gDgicPiTQIiYLFko0lI5QeTNGCutrXVFU1hh8LkChFr1uKRcb6tDgkUkpgzFikpas+xNHMKtaO24OvKlLfs7G+Tt93DGlJFgh1gxXBStHPAxA0hGqcYegGhhgJyeIqS105rEPJxY41BN4Sgsdg6bpOV4spYa2hCS3WWMSIdgiMpF8k
06fEvvUlQ59Znc9YW7PUtcNXgdpXGAOa4B6XLxSgUCRjjcN5SyljXtgqEZsxr+ysxY7dCNa4kWT0/SIyrpYFMeBwIEIpBoPBiGCdAwzOBeoacnakYXFEnvtbAxN3TNwxcceR4Y6mSYf9nB7TE5SqDoQUSCmTUiJ4ncHWdU3TVAz9QB8HNrYW+JUZBUseIiUXDIYmBErODNZSBQ8+4pxhVldYAzFn+gwxLRliJosOYosgUgh1TU6Q+kTMorNmaw4Mf4wxGGtxOvHGe4uRQkmJ2Cd6ljigqgK2AlwhJ8GIoVssWC43SCXhakvdtoTQ4mxFVQv9sKAfltpqmAvLZUcf41j1b/HBk0pPt1xn6Dst5CJryK9kJBcQLQCrg0eE7dBhBhZ9ogRH5T3Z6aDr+p5FNyBFaFtP4z1GhM4a1rvI/j6yFgtV3WjI2ra09ZwqGIahox+WbG11OKO567ro8CuStV0xCyJCt4SSdXVu8HhbMW92MGtm4IRiCjktEYEiCTJYLEYcpEEJVMYVjPcE7/W7MIYdx53A6nEnYEMgDkswjs391xNjJKeMpAwp4o3WJMQ8kGOPpIQZWSyXTMrgjMN5rysT57Bjh0PwNdYYgjeUIBSyFuiJkk2wWixZshbXjalrsIaYhCEKFIv3nnYWCHUgOL8dahYRMIIISMkkjObgxZAkkdDP04pVcrAWC2MHhQBKvKCrGyVmDQVXGJyxUISSoUgBZ/BeVz3WBqwBZwVbjl2v0Yk7Ju6YuONIccfhF9cf0xOULBkQun6g6QesccQh4dDZdW4CA2UsyII4ROIQSTHinNWCNR8YnGdRCmQt4vJWv5wILFNiAG3rQqBkjDVUtaVuLLF3xFQAj5Sxatto7tI6SybjvPa6W8A7Rx0agnGYLEhOCI6Ysq4Wkobechyw1tDOx1CsBzGFIfX0qWPoO1Ie0KAwhMrrYLSWGDv6fpMhD2xs7iMOS3JKpCFSUqbves0fGkPlK6xAjBp+HWKiX64TXE01bynekm1hiEpo3lmsgzYEam/HSveBFJUg+z4RB131GCqaapUgIGLo+p6N9Y5goR6J3HgNCUrRHK0UoeszKWVEwPuK2coM7+cYWyFk4qB555wyQxdJsVB5LfDqu0y3jGNOtRCcH1sy3dhCafBAFQLOCjtWV/BktjYXpJzIcUCsGQsBIyUPSMrYIlTWE82AMRaM5mKVVJRkNKPtyJntz0NkJCAjOKt1AjFGSIwtfBrWzcbinWM+m7HSOnauzlmZtfhKQ6zbjhRiRwLVlZcYGVeEGdDVVcwRyYLJmsOvxGuVPZBTwlpPKQYRJfI4ZJbLQeskkqX2ZpvAxBSctUguFKtFms56Kmfw7gsKrccaJu6YuGPijiPEHfbQtv8bwzE9QYkxk2MhD4X1jSV9EmKX6GJGQoVbmdHMAt47DZcVnV17a6hCpXnBPtJ3Aylm+iGRjKNrEnY5EB30FELwlMqQkraMGSlYB0jGWrDVmNcsA0M05JTZ4Wsqb0kYjLcE53BiaauaKugM0lghScbkiLU6I80iiGh41XgItUNyRbfc4Lrrr6BY2Op09dQEhw9Wc45VjcESvGdW16S4JMZIGqKu0oakXQHDQN/3OBxVqHC2opdI3yW2NpdsbC5JQ8abjtkQcZXHeM0rr7QzZrMZ3jksMAyR5XLJcquDIdE6yyxUBFeREnSx0MSC8IWiuhjTWMzl9WEtaBV7yTinoc2Ue+1KAGIuVCuFLiVyr59ZN/TEtGAYOra2OswYTM0xsX+9Y3OzJ8esffghbBNALpmt5RZ2MzAzSmo+eNqmJadMt0iUlBgMgGj4FS38EoFgdYXgxOKt1/y/MVAMghbQJdG8dkz6eSOF+bzGOQ3jplRI0mGNwXrVIlBqMjhrmc8rKlezujoj1AashosFg5TxL7Fdx5BLwnntstDgdSZ4RxEthoxDwgjIKBbmxIBX4SWRQimJIWZiLkgyOBORlMmwPc6LsRSTCcZhi9NCOGQUpDo2
MXHHxB0Tdxw57jhcHNMTlM39Cw13LiPLZSR0Wnm96LUdytUVvq6IWSBmze+mRF1VOGeJMbHR9ezf6lh0iSRgjRYW9TGRMpjGUjUOcZAWhULBG60SxxisN9TOUVWOfsgsYiLnjCkZi9XBYxzBBELw1HVFHTwxFzLaq64dapaSC9YYCqIFZbkQY6a2gsSB9euvZhEHMobV2SomaEjah4qcdADmnIk5kba2SDGSelWrzHFASqHErOlDXROQcyF1kfX9CzY2Fiy7CKKDLTSRqtVWvKpt2L37ONqmpRt6usWC1BfSIEq8GCrnaYKGFCmwtbnJEDOV95Q8YCmszBpdDXo35iY1/26t2a6QN9ZwQC8BhCFt0UeHmAoxiS4uiLFjGAaKZILTAsdYCkPStruq0gLC4D1FNF8bU2ar20cuiVwi8/kazjiMsXhrsGMNgDOGQXqqRnPndlxJWevw1qOlhWBKIUdDsRlrtSBSRoPwFBMpJ7wzFMlI0Yc8xYjB4b3DGbDOYbLRB9cI3lrq2lM5R+UdxWQkq7qnEYskwWQoGfKQKV5DvXj9DI0oaRmr40ikkLpEjpFShGwrQtF6hZzT+FdIKVLGlZkbWwStE4xAzIIVQy5ei+o0nk+RI1Mke2tg4o6JOybuODLcoROiw8MxPUHZ2FxSEmxtdmTA9xE3trFZGQt2jKWPmT4v8QaaKhAqR1VVyLJQELKowM6sadi10lI5S5cHYgYTPNZ6XK1FbG3V0IZAEzyV9dpm5SzOGorxWBdVoAYtHLKjSiJjaK9QSKIaArkUrIdcBFu0NatgkJI01LrsSUvBhoBrMkXAiFCFiraqCT4QQsB5T0oJoTAMHTlp0VceInlIlKx/OWVKKkjW6uycMsOgq8DFoifFQmMNtXcY65g3FfN5w86du1hZW6NZXSWnwiINxAIpaX60aSq6rPlQ5z0FYWu5ZKPrcXVF5Q2zyjOftfi2Io86BN6PBVhmFG2KuhoSA6EOKo9tHcZGun5ds5wWDAVnDd4LzliqWnO42Iq2y3RLDUPXIYwtiKqESBq042DoKTFhDlScy1jQxRhet9oWaGPSYjGjn0fBqCZEzjgRbBK6PJDLgHWGKjSEWtv/fAAXHME5QnBgVJY85awhe+uxXovoUtZVlhKE/tBonFSL4SyGPF6hRQsoYyzkHMkxYpNgg+a0nXN4rFbgU5AiZMnkGImpEKwgTUWFhmpzFlJOmssOuorLIgSDdo9I0Xy8s0RRbQ9jrHY+lMMvdjvaMHHHxB0TdxwZ7uj67rCf02N6gnLd/i3SUBi6Huc9ITmCtQTroFgwQhkypS54YwnBaVWycxqmsoI1Qu0D0Wp4c/fqKsFb9m6uM/QDEjMxRebNjNV5ixVPbWuCdzhnwQnGGaAgRgiuI3n9wnzjMS4gY4U+DmxwFCMMQyLFjLEO78f+eqdz6DQkukXPsDlgigWvhITz1K7CjuTivUeLlQrD0NPHYVT2E0oupGEgDknDcaloqHZsocQ6cha6bmCr6ykizCrPauWVZJynXp2zsrrKjtVVQt1QcqGPA4u+Z9H1eOtwvmCNENwY7pfCEAe6Usip4JvAfB5YbXdQB9VeMM6Qx6r/4D3WWWKOGFNU3MqosqMPTmWwVU0IkbF4SwwYi68rMLoStHaGtC1QYTFs7tsYj5vGSnyDSQaXtX4gGDdqOmgLXillWysC0dl/iuqN4ZylCAhFOzhS1JVuhpwzfYzkkmnnBhdUobLy+mg563WcoPoXAhgn+HDg3BqNMBi8rTHG4/DEqAqdNozUIqpeKQi5RFLWmoJYIkTBB48PBe8t4gOm6Ap3SJGYIjH1OgY92KQrpFIyBQ1X27FQzmC3BakMaKgZbaHNB7QspFCM/lgeq5i4Y+KOiTuODHfEdPiR12N6grK1udTCsKKzRe91NmwF1WK2hqZuaZuGqvIYSZhgcU1gNqvxtWdzo2NDtrR9i/HLcg5rPDEuME6wtqWtKppmTjAVzngkat7OWy1o
SyYiXdHZcB2Yrc5oZ+0YOtVBa61KEqeSiDlSxoGTsxqNMa6OhphZLgaGLc33Op8Qq+JPdeupQo2xliwqFFRSoh+WpBTxY866Gx+Eruu0+j4XcspfJNhkSDkTY0RKYVbVrHjHjrbSwibnoK5wvmLZD2wMEdvUlPE9qSSCq3RA5owZWy5zGjCDRbyhG3pqW6h2NtRNRagrnHdgzXYbI0aLp7zXVaKxTvOzXkOZdtQHyFln9Gkc3M4ZrFOidzbQtnOQGUNfyLt2QoY09IgziColUBuVc1Y/EdUHKEW0iGu8llwKdlwRSdGOAWt1hWGsJZaMJNXPwEJJmdRnhpQxLtLMvHpkjBoYmmcG4wx2XDH7MaTqnSMn9LxZJTqdDTCGWksRtK9TSSilRBoSfYwMKTNENQ7LQ8aliI8q3Z59xIhTXQNlMMqoblmkEHMkj2FW43SFI6I/dNZY7XDIhShjKyZgitGVurI9OPkiz5FjDxN3TNwxcceR4Y4U42E/p8f0BGVzfROLJYRANQ+0swpXwAwqr+uMIXhH27Zja5VltjJjvtpSBYijG2XKma7vQTLLfiA6Q6bQrjZUKxWrO+a07YymbnV1lSxDzkgpCOCkkK32hIfKEqynrhvqqkGkqEBTLhgz+iCMX7gLHhcsZdxujdH2rMGQlsKwjHhvMIPKQDtjyEVnpjENiHUQDXH081BdBn3wY8yj8dOgOUFzQPFPB6012mJoraFxHlMZ5iGw0ja6qgueFALLPrLc2KIEx4pdIxjL2sqcNngkZZItDK2jyz04IbuEtYWSB4wp1I0nBKutcMGoiFLJpJKw1pJKgnGl4I2u7oxesCoWGu1iQNDVR8ljC6bBSsEYR13NqaoZ1lTMV2siA1vBUFFhSKSkJB+cQF0RQqviU6Gi63tKieivANrXj6jLbPA0VUXTqDR31dRYr10TMWUNwYtRZcwspFTIY8g1yxh2z9q5YeFA7gCt7xO9LpzmhDPk0RdFnGDw5KgRCusskgolZYYY6dNAKplEJIq2OKZiyNmN333C4pCkip3GqraCdRoKFvmChoIPQtFaOi23sxo6juP4zjkjFFXOjEnNz5zqJsQ4HJkH/1bAxB0Td0zccWS4I6XbiQ5KHDK1t+o82lQcv2uFuIgMaUlJGeOtWkOnhA+eqm5o65Y21OojgEEQ+qFHUCfH6/evs2Otpl1xzGc1850r7Ny5RhU05xmT0A8Di2XPEAfms8DMVRhxmGIxhbGQzuuXh6i89bjiGfJAlKLh2aB55VIcViymWCRlpAczGGxRYsBbEoWcwKRCY4x6PBjDkCLLxYJuuQTUiXUQw+bWkqHr1UjKqvASMGorjC1zxVEHj/FeB9WBVWAI9CIsu8h6N5AQ5tUKlXVUY9udaSpyTtoqSYPUWYv+2gbnA6kXiEJVOepWc/fGaa99HwdKGhUtRw0A5wzBHSh407D0gQTC+PxjDvzZL4TFcy5qkJYBl7Ahg+1ZXWuQ5OiGga2FQEnUrqKpPKGe08530bar4DdZxC1MMpBUGVGEUVoarHdUdY1xlnY+w9f7NWRrDM55IOOcpXKBqlaXXxlXmDllrAFvdVUERc26xBJTonJaZAdqRNbFniSFpqpBhIBqYZRSGPrMMCT94UiJJHHUVEjkLNp6KOr1wUjAkgRnDHU1+mCMPzAqrqSh6TiMipSofgEmIU5r+nWhpCvDkr5QENcbldSWfOymeCbumLhj4o4jwx03hzeO6QlKcIbgLfOmYm3WUAdPtglTNJQ0DAO2dzQxgoGmrvE+qIhS17G16Oh6FecpqegAy5nZamDXyXto5rUaH1VqXJVSpGRDToU4RIahp6kN1rX6UGSDK54cISdRAZzxgTZWRq+OTEFo6ma7UhoZnU4Fcp8oQ4SsKpMYq5LBKRF8hQ9QN2p9LUbo+56tRcfGhmoNNFWBXOiWAyQ9ZsqqYGgxOGPAecSAcdouWFzHVoqQIsfNGsR6ljGyb7lgMIambWnrgBsNrjAZ
7x11U5Pbgm3XmA9uNM/yGBx9nxh6WJ03zOcttvKIzVAKFhmr7QXQz8ceELwyRR8AGYvegFLsqFaohWhilWCKFIoI/bDEWD+27xVWWkeyDUOn+gVxGBhiIfiaIhCqhh1re2ibGd4Hlt0GXbelIX5jtnPW2UDdrLBjdQ997FmZ78f7veALwXvqOlCiEj3WMJtVVJWqLGoho0o9O/RHoaAGber66UZzMYtVYQn9wUg9fY5U0dHWFVXWdsRcCn2KDEm/y1QSMSXtrkiCNZ4shmLGFsAcMZjtHPgBAgcN2R74r1BIWTTMnkcCDyBOScSMIV57QOCJTEyCsRaJx+4EZeKOiTsm7jgy3HEg9XQ4OKYnKG3bMK8b2llN1VSEKmBZag/2mHfLMdH3PU1bI6jQTYqR9c119q9vEHttoRPRdq+qsbSzlrqpqHxFKpk8qNRwjomhE/VpcAbX1rRtrR4EWVUcrXEYYyipULyAUfXAmDNJBF2LWJXGEVRwqKhKYBZhWPZ0XU9MWR8oDKlPQFK/BiNjzk/FiLSdMLJYdMQ+0oWKygdMLttSy8u+Z2tzgQHqUFFXFcZbqrrCekuXEutbS0Ip7KsCs1QYjOC9pW4rVtZaVtcaXCU4VU7COs39BiralZ1kVrBOc9QlCzFliliq0NA0lRatGV13amHgKJtsGXO1ZmwNNNtaDjknLUb0Hoxg/QEbegsWfYgRhtgBDh+Dhoarlso1BO8RcfR93K6CL6KdDyL6+XsXCFUDqH25d2MxZLDUoWbWrOFsTXCwMlulCTWLIeoqyRpcsJg0ihA5qyQMIEXHodXVs3PaNmrG710/C13xFRGK0eJIZywpJ2RQ1c5h0BDrgZUhRu+7HzJDSkhWRVfvdPkmkrelyL11WpcgcVSMHCWsrSFL1KK5saiNDCYbzFjwhgdjy/aqVEQ7Lbx3WkOKjvFjFRN3TNwxcceR4Y7c306KZFdW58xnrRaxec+ii6SYtGp5DEseaOMbup6hj5o397BYdCy2evKghXLeGma1p6krfNPgqpou5e3isZyTKh3GUXjHafFTHSqMWJwIlffa6jWK+6SYKSRENP/nXEUIX6jQtsaqQqBoRbeIMKSoxIahZG05FDEQhJIyIlo4pxXSaJFWLqQh6gA26n/gLDjMdqhYW9I0Byho4jDlBLmwiAMbiyU2Fby1zPphDE+vUK+2tCs1beuxlVbGW5ySqVis89gQsAH1A7EHKrnB+aCFW6MKIkULynTuLSowZQpZElYMQ0yYAxXiOSNknYUTVaXT6WeGCBpB9GTJDMNSV52uxllDVTfjCiDgfYMxatDWDT0xR4YszGcrVBZSGuiWWwxdT4oREcF5ryqZIVCMlg742uPG7oAs2ghqvdPrHImyMFqyaw0kfvzBMbqAOFB0zwEvkKGMuVsDxmSCN2A9adAQa0rq95KKtkYqKWj+N6ZIHBJWhbphJCwpBdA8Owb9wcvjcVCZ9gPH0QK4AzWvTr1OBEaxC9XZMILzqL4FRv/tLN5Ysjl2dVAm7pi4Y+KOI8MdjCmww8ExPUFp24aVtRWct2Ad3XIgdpEYI/0QMd7R8AXL7JSTtgY2AZMhDZm+HxAR5k3F8Wtz5k2LdRV9FrquRyTTOs2ZJgTjrVbHG6cPRFavjOAdztbEbFBZ5AqMoRsGck5UTk2/7ChW453XVj+DEktWT5AiasUex4r3qm5p2hpntZ+eYlUKOQ+jt0LBGq0sP+BEGbxjtGeAArPSbAtNacGTwQctdByymmMdsA6/frHFkoTf0bDqNL/rvcV5R9NUOB+weK30zgYxGRc8NmgOViEY4/C+0gIpgVISabRmlwNhWmtwzozFVSosVYrmWkspWK/hbUGtzK3Tsb2tjIjBiiWjhX/LpYoH1fWMYA+0HjK6mFb0ZUmKAykn1jeugTgw9B37r7+WfrnUULDRh9QZFXDq0oKQ9McjlajX5DRPrDnczKidOLqe5lE4CVWOtOrwKVEVGI0FfboLziac
OBwWy1h8ai1JCtlomNc4XWEdyBkPaVTOpBCsGcOl2tngnLqOqn6Drqy885BV2l3GdkgNIUPKacwt6w+XykEYTDGYJBrKt/pjoMt1wQclI7KOwWMVE3dM3DFxx5Hhjj7eTopk61BpKM5pWHC5tSQuOtKyJxVRYymAwjhbHggOpGiYUIwgCargOG7WcvyOVeY7V2lnK4gzzOYzRDLOgveGUKlIjjee3GcNpw6F4MFbh1hGB1DVSLDeAw19LDir5wlNu+0IaTFkUUJRJ1BHcpahJK7f7CmlsGceaFZn+KrBe0/lNbSZSibGrKqDManzprM0TUXrA3YkHwxqb10SMaZt7QDn9Rp6dAC1jcOOMpepFLI1FGcoVrStzBi8C7oSsZ6UDJKFkoUyaEiyqtTqPeeiq6TtprsxfywaprUOHFrk5oMfVzpl+71OGAuvDEUSxRRMsBRRpUIpsr2aGtdD28qZy2XH1lbE2SXBdRgJxC5CFhpfjeqGicXGOmmxpV0Myy0tkq8qXAi6QnUe6x059yy7/ap7ADSzGc2ywzunn39JCJbaB5zxo4FWwXldbxSRcVUOOarr7AG1xQLaTaElbYiIroyNwXgZVx3oCjwluqEnHVjdGy0MFCwpj+ccC+eMyaMWgSENB7Qt9AeJUa2ylDy2D4IVg+RMlqwtgdlAUq+OECzFWZwAYlTTwKSxYO/YjaBM3DFxx8QdR4Y7bo4C9TE9QRlSpI8DTixdF1kue2RQLV8nUDvPrFF5aAHN7ZlRojlpoVBKicoY1uZz5rOGtm2p6grXerphSUxRQ3LWUXtHW7d461m6nsWi09BpafBVICYhDhlrPe18jnEe03fa9++NSktXlcofWxX1IaXRMl1liDsDfUpsLaP6L4C2rNVaBd+2ta6u+o6h11yyFJVUroLV+x39LkAHn8FpOBmVGc5jbM57NYhyRosGVULa4tqWZqWmWanxrWoJIJo7VAMrNFxcVH1SUsZmo/bw7kAPPuRk6I3mTC0Fa7Vlz+OofIP34zWhJJXT2NZnglagJy0Uy0YQSdtiQzBWuksCKSDq3mlNIPWR/etbDP06tWtoqjkWC1JUNtx5zT333WjkFsklUoWaqm6xXr+fqmqomoYkB1abGu6uqkBdBUAYhp4iBWs152yNxxlLEXVtxWrhpNrXF1KOSi5FScaIJUrBGW1/lFEJ0poxxDp+iSJG5aqLw6PXkUTUUKzIaEanIXsZ31uKtqMimoYoCV0xGaXlbohfVDxocYwFr2PR5YH2wZyglIAUIeBVlMloOHjIx+4EZeKOiTsm7jgy3JHy7UTqvk+JphS8GIJRQyYxIMZiEOZ1RV156uBJRY2Q4jAQk9ANSY2+hsw8BOazlnbWEMkslxvM/IxF3xFjJHivEsSVNqIbZxgkEc34UKGV1ypeVPBBH3xjreaA5YBZ0yiHPKo4UkZHSucQU8BYQrA4D94L3hvmKzPm8xl1U2OdpZiiXhylEKowJiYz1tbUPjBvZljUebIUXY1kNCy37AeSFEII1KPokRm7GZo6UAWL85bZ8WuccMJudqzOVbqhFGIEpBDIOKehyBgTMfVYB2IhDn4sttL7d27Q8Kqx4yrKYY2uFrxxo8JgJpYBRAWZra30O5Si4d1cyALZahg7p4yzYWzXA2McoQp4U1OSJVWJdRb0yyW4TMAixpNLxvmgBWoFxKoBlwHUadVT+0AxjlC3rK7uwrlAnzbp+026oSMlNWarKs9yudQWVO+xxuBGFccDieciaj1vvXpgFISSdaw41R2npLGd0OpqtZhC7R2VU2GnkiEnfeBLMhhx2HFVFGOmXx74zFSgSkrRc0gmJTVas0ZFl3JMSNbCvpyzetD0PRhR+fZac+TOWrCaCy9FtEU1F3J2pJyoaq9CV2Lwxh3Bp/8/h4k7Ju6YuOPIcIe5vXTx+Eorn0UKwcLMO3JV6cNdBBcCYnUWrdFIQ4qJISf6pBX2BmjrmlnbYL1jMfSsDz2hW5KLFqhV
wdPWDd0Q6ZpEFQIbi6XmD51nkAxDTxr0y3VOpZpFkvpXpNHYyY0hyaL5UhnzpVqVroPMmExdO1Z2BCwVzmqrVhJtASyxYIynqlu809CeOkB5mqqlaWZQMv1ySU49oAVkXepZ9EswRn0qRkJpGs/arhUqL7R1oGoqmrUV5k1DFRr6rLnfVMZQX9Gwr5CRnEcnTbBYlQYvZcyja9X3Ad0EXRJ4ilgEw5APFLxFhtSploGvx2Kw0SRLiqod5kFNtcSCGBIRLeaCqnG01Zy6WiP1Qu8Tki15KEilK6viNDcPWk9QTGHHzt2stauUlFnffy3dcqnFgw6cUyLUKKZjiAOLxRYARjSvmnKi6zvmdoaxajGuIgCiCqCibRYmC2RtEXQEpCRdQUomxcwwqNR6lkySzOADlVftDYyhpAwFimiHgrcafi0Z4qA/UsEpgWN1ZRqLbLeeitWceEwwDAmxRovk+ky/OdDnSKg87SzQNB47qA+I04S9Vv33kcXC4INhNguszGsq57Hp2G0znrhj4o6JO44Md3Azmv+O6QnKvK2pqoApQo4dB+ZlWXSgZmspzuGCB1ExpSxCPyS6IZKGhMHQNI0K5xhLv+zZ3/Xk/Ro2yznRtg1dFcd8sooDmTFkVwUNwXWpMHSRlAp10MRh0amk6geMFdB97BDR3nw1VJKRbBIxDqShp0imqh3OaiFc1/fksXe/jxlnK2Z1TQi1Vt07j3EqnlUwGn4uGoLWPnh1xxSEEDw+qOtl0wbqxlDVBtm9Onp1eMQarK3pB7X+Fin64FmHFYf2OxrIbBdMljFcKKIKBS54AhUWizMW6wIQyEkY4oAlafsbUQvZJFCM1SCslTH8qaFGDS1rWyDFUMbiLucMdT2n9ruY1TvZih2p7KdfJmwUbOXJpehqxxhiGvThqzy7dp3Iqp9pmLck+mU3rrC0EyBJwvtKrdT7nr4fVK2zFP0my+hP4vWHahgiDocPAfWksOSkkQbntYMgxkF/dBj7FIuSr8FScmbZRTbiknlTM5upxLpDCUzN0BLi3VgMqDLdKUYKGnY3ow6DKRriNiLs3LHGrJqz8Euuuvqa0W/FQFK9iyFHUhwZQ7TYLdQ6tjFKTH1UW3Vjhdg12GLJlWG5OHbbjCfumLhj4o4jwx1GbicRlHnTYJ0jFZWDzv1AjBFndXCJs9igrXExqZZAn4vOwItRq28B67wWd/WJoYuUbiBhxupvS+qS6gnIgQ4pYdZW1FXAoG2JcdAugJwybe3JOY1GT1nDqxT62KmgUgbJQDZjy6BWQvfLxGJjSe6TChKhssLLZT+2DxbEeOraUTdhlBtPXyimGguzkxGGcX9bNOcqgBu1Ceo64MNY2e5URIksGga0YRQB0vt31uHx6imh5Wk6oAHJkVTi2HamLYIhVFShoa1mGNEVnxijxxyLrzRXGbFj22GWMoaxtZBQcta8utEZfLaQUUvxXDIlq0qDym+DDzWlOIp4usEzLDNePBbHgZVPymNrHCBJCyOlDtqt4CoVoBIND0vJiJExxKzeKoyvOcO2/0TJmX45qNDRkEhEvPM01YzVZhcxRUTS2AoIOQ8Mg7bfWaOS6CKeWbuLVePolldy/fVXk5qoXh6rc3CWnIQ0ZHIp5Ix+J9ZRu0zuElkKyRYMGuo32SC9rkYb19K4hmrm6VcHhjiQh0HHnXWUQQWboi84CiKFOAil1dBxHzOLbiDmcYz1Hbmz1E0mLg/fU+Now8QdE3dM3HFkuOOA6Nvh4JieoKQYR9vrUX7ZOvJoZqWW2lqUZZ3O1F0VCLWK5fR9T0oFm4TlYsn6ZsB4w5AzlVNtgiRl7NtXz4qYM95ZmuDG1kCrngMxIyWTS0Rn7Wm7rz1mfZ9FrdydVVGfoc+UwVCHhqaq1KCqy8RlJg0F470O4r5DjFNnSqvOn5U3VMFy4Hs24ygu4/myqHKgKi9qqrmqa5WTPiBKNeotjAfQ1j1jwJSx6OqAKZcSClkYkooaheDwQUOR
fYz0QwdWi/eq4KlDiyNQRmXHnBKVEbwy8vgwaJ2+GEcuSTsYch6ryDUUKmNh27YewNiLaIwSjnWCCxasx/oZbTtnvrKkaVfALPFVjQ+eYVgqcSAMWa3D929dRxVqmqpFXEWoZ+RhUFnvot97sAEJM2rfsKmVkjhvxxZAvb6csua6U0aCVtLPZ6tUbauus7knpgFrBPB0g34vumApzHYex0kn3AWJwv7rF1zdX00kIysGEcNwwEk2iupomExTO5rQIG6g9vpDJMmQxrGQUyHHwgEn0+A9xWhnSZ1r4mLJYmsBGB27WLwvRCv6IwD0qI9GzBlJlhQ15527QokDTSfIcOxGUCbu0M9h4o6JO77S3OFvLxMUKcJspSYEx+YwINZoxbl1JAo4j3MabjLe4gmEsVI+C6oJMGSWfccQVREwi2CMykEbwAd1GO2TpWx1WGNYXZ2za9cqNjhSlzUvnAoUHfwiauCFFGLJY1hSML6CrFPxkiHmgrWJSsJoAoautEZLd+M8zjusUxEkHzx1XVN7j7MeYwyWPOZs+aK+f13ZxSRgLaH1RLPEC/jgCU4rq2OM9P1ALJngLFWo8K5CSEix5KiCTEYMMfYMKbMyXyX4GmsKfdY86LLv8VZwVUCyI2WLyk5nYk7EnJAcKU6ls7WVzn5hlWlUhdGOlFhQASk7hmtVBFFXUdY4NdDyWiyWk+oR1N7hrGN1Nue4k05m49prxtZFIQvbGgBCxllPP2yx1S9Yma2xa8cerAmsX3cVKXbqH5EzVmQs/FKXUW/1ewhVRVXVVHVFE1pm7UxDuKKFYWIszjU0M2GIFSYONLUjlczmcjE6exqiFHYcfxqnnXYOBrhuuaD/l8sJxbIy38lxu/fQ9Uv2D/vUI2aALBbvG0I104LJYCgpIUlX14VC3/cMva7C9123Dz9KY89XduCxbNjr4brrGfpEP2RcUMJQktcitpTYzlenWFTptOgK0Amj5MOxW4MyccfEHRN3HEHuOEwc0xOUlZ0tx+1aQSSTlp64CQ6PdU4rn40lZqDXsGvJQhyLzJwYPI5EJmahqMEEAwlxDucd3sFKW9POAl1UtcUQLDuPW2G+0tIPiVIGtZdOA4K6P8acII15x1x01mwswQa8DQxEShxdSj3aildEbbPRPnh1K61xNhBCpcQjZsxRqtIk6P/KmEtWOetxRVhV5KKhTGctNVFzqsZo+1hGVzB9R8oDa7NaVz9mu4ORZZcYYsKPHhl1VVOHoC6uKbLsOpb9ErJQh0DwDd43WBfIkimSyTlj0JWQZCFl9TZRXw710QiiOXNnrBYmpkQuKmSkhmgRMRlBq8WlFG0vRBiGjs3Ffnw4Hist3gaqusbVAes9ddsSjbDYv6SYQqg81nu8C7TtKisrO0gxKhku11mWhDdONQ6i5vTd2EZZsq5OnVS0sxnGOBrfEkIAoI/D6OCZgQyi+fOdKztp6gophq6LkJTAUqg58cTTOHXPqYgU7nDqnVXrwsOdTr0bO1d3sRi2COYarknXMGzu1wLO6JjPW3IVGJYDXqU4VQAqRg1DD4lsLVuLJcGvM5uvsGtlN1WosRZcfRXiNPTsvcpdW2PULMxpEWQRg5WMJZOzdnCEYAmNp12pjuX5ycQdE3dM3HGEuCPUtxcl2aZiNqtZLBb0KTGUQuMcmUwZq78XQyJg6VIkDpm+G4h9RmImGIOtKmZ1hXGWhCEaRxLLLATaNrBzbUbbehbdEiNCVTvquiLlghTZDilaF9jqF0hKZGsYOsEbSxhdPGdNM5KAIxtDXVncaLMuZZRlHoWaAAyenISqCdRVg7UHfDqEWCJkQy6iuUojagAmgvfqvVDVFaBlVDIWQYmMFk+iK8jYDfRbPZIj3tYELxSjA2qImWWXtLOg9oTgaNuRZAxsbnasr28gqacNM+ZhBd+ssrKyE+MsxgkiLSH2pGEglzySBlpt7wK+XdE2zDgAReW1jRDHVkF1ygTvAwWrRlXFqDS0CHkYWEbI
eQNj9tJWu1j2C/ZtbtAD89mc+cqq1vvnjmHotYjOGGxomM13MJuvUXJiMSyIueB8oKrbsWhQFSCdMzRtQ9+rlXgqQtO2rM5XqH3DEBOLvsMUXT2XEumHBc41tLNVdq/tpq4bbNF6gWFzE0kJVlY5bu141mZzYo6sreykbud4a5mv7GH3zhPY7Q3N/Dhm9S7+PX+Wq666iuVWYlYnsAXJaq9esoZlrbFYU0AKKWW61NPlgeNW5qqvgYXesrpjTt163MLia0M7d4TgSVlFwqq6BqNtmrYHHzO187QzTz2vWFlptkO6xyIm7pi4Y+KOI8MddTh8eYJjeoISvCOmDMZTNP1HNhkxlpgKMkT6DcAbljEyLAckZugjDIk6OOZ1za61FUIVWORMX0C8YIMl1BXWa9ua6gLoeUU0F9wNPTFFTDF0MbFMCectfsy95lIIo8BPcAFvPCkXvAlU3iF5wFghlYwzIEUHRsyJRZ9VqrqB4APWqrPkAQMEIauEcUrgGPOtluDDuKJQvQCk0A+95pnz6KwtKgmd+kjaiuQYiSGSrCMNeh/LnFkMOpNfmVmapqEKgVIyQyksF0viZo8VGV1I/Vg0V4GBpq20AHB8zebIQAcls7Kyg7Ude1hZ2UmRwub69cRugyEPGGORPH7eTjUGnHfkEjWnbVQSOi4jKUU1sSJBmxjKFptb62x1C1ZnK+zefQJt04AVurjAuwAi9KXgbE1bz2ibGaVo1f1QCsE5vA9UoSa0NZmaZt5Q1579+/erDkC3wFihcp4m1Fjr6XMhjKFb6xzBe4xV0asqVHhjVeK8WIKrERyumbMyX6WqAyaq6FEuhsE42tVdHHfSyTR1za7dJ3DCzhMx4rj2+v1sdEuaZYO1WvjWdwPWWdW2MNrGaEcH0qq1NKsBUwnLtIX3DcVmdu6ec/KpuzC14CvHzp2rVJWl7xIpGaw1hFYl0tvsiUOiMpam9dTzwKyt0J6LYxMTd0zcMXHHkeEOb28nEZSSCylmtjYX6sWQC11KiLH0ZdQAGAYiwpAz0icVuhkiDJF5W+sqZ14RS2Gz6xmy0LZqfe2DpVjY6nuWQ0SALIaNzS3K2C6mVunQD9py1bYNq/MZbdMyLDsYV0slC9mo8ZIVjzPqzJlEMKP8soZiwYi2mqlhVcCO/8U46ENltRWRlJCUR80ANfhitIQwqE235KKV1qNAj8mCFSBF+q2OxdYSU2BjoyeljHhL8gYtdjdUPoxW5F6t4iXTDQMlJoLxqt8QC3mUPh+GHl9VlFLIUnAhUDtHihGcYLNlx9pxrLQ72bGyhyQqu71YbrLslgQXRulq7eJTAw0VHrLGEofMMCS2NiL9csm8tVTe4KmQqJ+zd562nrE238lsPsdWgX2bGwRbY4vBpYQzQQnFWbJYqrrFBK05CLM5K7uOo24aqqZGpLC5cS3Jerb2XosrBnDkDNmp70QINZWvaFcaVuZr1KFl2UfVkeg7Sqi0Yj5r62HlAyHUzJs5FjUZwziGIbPjuN3sWNvNzrWdtFWgdhWVWPaddCq7d32Oz3/mXwnOUFeOISaGVKid5uTrxmOCkIw63DYrgXbF4apMYoshbpHMFvVK4YQ7zKl3CRlhdV5TV56UjRZhitlWCkUgD6q3gRG1i1dFsSP16P+nMXHHxB0TdxwZ7gjmdjJBWS56EEu/jMReLbEpuqoYipBjplhLXwpdjBCzOhekDEPEzhrqWu28uy4So/bXOwmYnBmGjhiFrlvS95HKBuwwEIJn1rZY7+hyZrHsGQYNb1Y+sDqbU1cNJgux74klURYLnFcS8dZTrCFJoosDzs2gjLlsMfpQ+8CsqqmrarQhz0pYqt2jugZ9T0yFYIK21xW95pxFi8tkbP3KSd1Ni7YoppxZbnZs7d9kubmksp7KapjXzWq899TWITljCeQEOWast4BRq3GrIe0iWrVfREgpsrXYwA6OflAr8lk7xzoVlXIljCHTrHLHBu3tL0Is6hdBpTlbZ722
YRYztuXp99oPiY2NJfuuXmdYLolrluB7hp3ayqmOnJa6mbGytoOV+SoZYTbfyVA2kJS1ij4LKQ7EOGybaYVqzu7du1hZWWVtbSezpqVtGrz1bK3tIoWG9WuvI/aRRMY4R11DEmjbVVbbOaEJNFWrOX9Rwy0RdHVSdIU+W9uJd14z4VlD9dZpS2ITWrxT2/smVMyaBoOl7zoq59mxssZ/GIuMnQm5FPo4qGqjFWZVRTP3FNsQS6FqRrdYHzFS6LqOKFp4t6tepZ55CkLdOKoqkEoh5UJMqmNhEc1LBwEZf8RwlIyuyI9RTNwxccfEHUeGO6TcXqTuu4wxka6L2w9RTjBk1S0YiuaWu1wYUoJYaK3HSMGmTImZOESMNSy6gW5IqpJYVHCg9IbNbqlhziJaMe518Ax2UOElMxpWZxXdyaPgEKJmYyJCTImhROxgtFXQOVwVKDlpSDIlypAoUahcDUDbtqzMZgSrXiBStP99HF2kUlgMPTFmKpPxC6MkaS2p5NGO3SKSx576MR+bE7kUumVka5Hou4ipNaRcMNSuoqlaJUR9lDQEngWckEVnw9459GP6wvXYlOnTFikn6mbGfNXiTEOoNL/aD4khDWxsrlMwmKBW9IvlflLsscYSfI0bzba8c/R9Tz9sja6pmb7PLDd6NvZtMSw7vGtY39xkZWs/zrb0Xa/GZnVDVTdqKuYCVWgovgPrWPYdSGGxuclGs04qma2u5+RT7shJe05g3ra0sxZvnVrItw27ywrGB6644t/Ye83/I8aOOlQsAVNVzFda6qrFWLR1LxdE8miKNiARyImdO3Yxb1cpkrluscEV11/DyScej8+CA3atrLBYdpQxVB2T5qjXl1sMMbK+sY4YqKqgZOwSYgpREs4YxBhc5amwSIIkauleGGB0rBWrIdxgA36u4lq4gveORmvdSLmoEFNKpKHg3fgDFyF1hb6XY7gCZeKOiTsm7jiS3HG4OKYnKOsbm2xsbrFc9KRFxBYlzZILfUwkAGu+QBpZw6SOsUOyFFJMRBG2YtrOIToXiEMmSqTvOoZeScZZhziDDFktw2ttDQzekWuPzwaHISd9b86ybUfuvaO4Uao6F3JUb0hvDHHoKUMhDQmSbBd0YVX7wIoajmE1j2qd04fewHLowUHXaSGY+nqy3cOuojgG5zzOgjGWhLZ85ajko96Zgg+Bpm0JVYVgcXZccuWsZlzWqM278yTjRvVAi5hCLpEhGrX5tqqKaK1RRcfx9VEqga3N6+mWmywW12OdVV2J2NPULb5qqJuWuqoxxeBtBLEsl1frPdmaJiTVIBiS5tgNdItNvM8s+wWVD9SzFdr5nKYKmL5GgDioM6mI0C022dzYh3OORYqYEDjthBM5bnUn7axV5cOcSKLdC947djcrzNaO4+p917Fae5rgwCS8q1RGWnSlutXrD9Ny6KmrhiYEiLV+z6HmxD0n03rPcPVn2RgWSCpEMvNZS98PmDRw/fo1bHXHM8SBfZsbXHf9dezddy1dv1QyqBusEUJlqfGEmaWaVfiZxdcO11rskEAMaSgkp+2YzpVtkapswOLwo9y2EfV28cYQnBp85Vzo7YB3qjeRPfQ4cq/h8mMVE3dM3DFxx5HhjnQzeOOYnqAsN1XEpqSsbXfGUhjJJCeMdZAESZlKhCY4asCI0NSeWaP5zT5lkhSqA14cuZALSMwMm4MOTqOigdY48pBYbnU416qFubOYRp04q6BS1lI0x60mX5aqdmoONioiFgGTwXs/qvxlVeZLBZwZC9fGVZboasqAtv45hy0aulTZ64KkTLbqe2GdUyXknPGixUrOGbxorrkMOmO2Ggelj5FUCq5yOG/VBfSAiNVomR6jfgYVUEphGCIpZX2wSiTFDjGFUDes7dhN08y1UK0yFLXDHC29jRbAxYEivXZAxELwFcG1eNfgXIVzNcY5fAk4O2CocN4TbAUzR9pRKDFRcsFbx9B1UHuV4/YNJ+w5gV07duANNN1S
r19UkEtyYUgDQz/QLTuGkti1skYTKuoQcE4L61IpbPYdeMeKC5RBBQNS7JEAVR0IwSMl0/VbmFptPBdbG2wstlh2Hatra8TV3Vjv2eqXeg9NTfCenTt306yuYNFsyaJbAAVJA+v7rmPv3muo6ppr9u3jqquvYHHdPmZ1zbBzB2vHH09JmX49sVLPWN0ZCG2F8ao+Kka0HVO02E6SIFaQUXzMWEa9Ua1lEAGbAQQZZdmNNQTr1BumaEGjOAMp0UUl62MVE3dM3DFxx5Hhjv72UiS7ubWpwi+oy6RxgphCIRMlY1Fp39o7amtZsVol7wzMguW4lRpXVeztEl0WTClKLKUQnCV1PXnIlCQIohX6xsIoFORs0EIqY4l9REqibiqaqlHLbAZyFowuaHBWw7TgNJTr1Oo8i7AYILKkIDjvqGc1odKqdkENw0a/anWIjIkUM9Y4mqpWY60CpghDiohRF0/x/ovaQbVdsORMGfvpHZnKWoI3YIQkCYdHUJ0BEUPOiSElcskgQhwim5tbdMseO+oVDENPcJbVumY2m1E3DcUISQZE1MlTXUZVOtogSuapUIxl1u5g5/w4koFcEn0caKoZ1qkVupEKi6OpZpgGtqoNZAwn9sOArRsqU6AYfPCstDMq63AWKu8JzuFmc0wW3KajpELbzliZr9BQaOuGnDLLoWcwgjF63H0b+7l+a5NVH9jY2uSqK/+NtOyQtqGuZ3hXMaTE5mJdf4gweGcJ1iLe461niIm6NgQfGEpGrEGcpWpa5r4iFfV12VxskdJAGTrW9+3lyiv/HRcqumEgWP2xKgin3+ku3Pnk0zRs+8l9SOmpG0/dBsQYitGxVllPKWDF4Vwgl4SoiPn2mDjww5VzHlU5HdlZ1E9VxnZY3bOUMhrUlbET5NiNoEzcMXHHxB1HijtuJzUoxRa2dW1GG2pbWVxbq1snUNtAZaFBqIzBlkLrLCeszti10jKUwkavq6Z+kSijql+wRj+cUTTogBGWeIsNnnreMl9bo521+sAWh4iGepVSCiJqk26weFvhnEo8yyjV7LyjqmpizCyNoaCh15V5w+pKS1WrboAKLBVKKvT9AKjBFGJpQk3tK+pgyUXzvBKjroaMQVzWEJwziHWYGEfzqKIPQ1Oze61mZR6oKkemYETttmW8HmE03UpKmLFXu/lhSBjJ+CFR1RXBelyoRiGloj9jcYlBFR3VcCsBgmSwtaekQh1mzGc7qauWHBcsuy2t+m/y6O5ZaNs5IlmFmqx6U/RDhzjVqfA5seiWpCjMQkM3WprjNRxde4dpW3KfaJqGWCJtO1P10KKKnsuhI5GRzrI6n+ONZe5rrrzmaq7vF1y7/xquvPxTlC5ijCP4FsSRUqSLHSXDrGrw1jGrG2Z1Q9WusTJbZaVdxWIZ6rAtvuV8oKlq6rZhY32TruuQnMgpsu/669i1czdzY2iswYWAbQPJCKvtnJlrcMYTqpr1jb3MksEWrcZXUzJwNiBjfQOCFhYawUjBFDOuhEaDs/F7KWPXhxmVN8uBH4KiDruSUW+Xko/pCMrEHRN3TNxx5LjjcHFMT1BC4yjRMEYBta2stoS5pekMEqGxhpPaQMiZRZdYRKGmUPtRNTJlSs6kIdLFQrEq+FOcU2Ms70Y7c4sLjuI8dd3QNJpvtc5BgQHNGxdTVLmwaPW6s47VlRXmswaxhi73pAwOoFitMI9RVyY5gzFUdUNbNTjnGHImjjNUispf55zphwHJmToccK20qrgYBRPBZjXhGkqkFFT22jsoSdv2csYHy875jDscv5N2xxwzm7GVMmk8J6gyoDECoiJNKRa65ZKh6xmWHdYbbZFrVpntOI5qvooJFRkhp56cEoYMFOLQEwctuFKPDKcW5FgVoUoJSqbb2tLuhjpTN3MN7waHM56U1blVUsZk1TtoqoYq6Gpia9GBC1x59RWszeZUVWB93z4s4zmdwwaHrSucV6VNS8EET6GwNSwJdaBxO6icBylcnQY2
r7uaa//tcjav2Uu/jPh6lV2rJ5GNoctXUPpOx5LN2Er1JJIkjBXa+Qqz1TUVkbKGxgdsKuxaWWPXypzaeRbBq7JmKWSE0NSqgTC281ZVRSmRRbfJclgSi6pVbmxu0nWRxRIwqiJqra6CKJkcBWMiPqRR0VLN37y3akjntBjSWq2hlKLdLM45ZPTUMOrGjhRLTOq0a5yMK/1jExN3TNwxcceR447DxTE9QfFVRXGWfqkhpeSgriz1LOCCRfpMI6jXgrPQZboh4SrPIhdKP7Cx6Flf9vpQlaI25aBtZwZVAnSeNLZmBe8JIeCsPtRdKQxDZmurp5RMUznNzxWtgJ7VDStra6ytrpJKgn7BYrGkRF1tDENk6HvissOM8tKUohXQRuiHTFcKzjuMFVxlaVyFd5Zl0Yp41VPQIiX1QMikHnJS/QUZQ2y5lLG4rxCjrtKatmXXrp2srM3praXfXNIXVYO01uIxatRlHUmEnBMpZYY+slj2zNfmrK3tYmXncbSru9QAq4gWuI0h4RR7UkrEGBm6SFVpO6K1lpQS/bCfULXIyiqp9MQ8kEvCeEvTNjRNQ987un6LOCRiGgD9rOumZWVllbqdM8TEelxn//XXcs21V3HCzuNp65rY9wxxQHpdEWx2C2zQ79EGj228Kkc2DX3sWaSBPvZIUkK21pBzz9a+fcTlQDYVJ5xwOqvz4yhGWCwXOOfJkkkYkrUkKYgP+PkKYdayc9du6mbG1Vv72b+5ztxWDF7HlMPQhprdu3YhAWbVKsfv2sO8XcV4w6Lv6DY2iGkJObK5uc7+Hetcd9117N93PbYCSQZJTonCqW5GSkW7VEqhriuq2lMk6RivPOJFzc2sAUYVSbSDJDiHsZbiVKFUawcSebsFVVSx8xjFxB0Td0zcceS443Bx+NUqI973vvfxLd/yLZxyyikYY/j93//9g14XEX7yJ3+Sk08+mbZtueCCC7j88ssP2ue6667jSU96Emtra+zcuZNnPOMZbG5u3txLIUcBo0ZYNlh18RQDWfO91jnEetYjbGRYYhjEsr4cuHZzwUaf2BgKfdYPrAqOOjiaELTi3DqMD2ANoQpjXlfdSbtlz/71Lfbu3+Sq69a5dmOTza6jT1GL7VAdAR88oQ5UbUMzm1HXDWGc3Xpf6Ypm2RGXPTkmEBk9GVSAp53VzGcNs9qzY9Zywu5d7FxbZT5vCZXO3HNRER/QIrSctT0w5UROGSvQOI8VQ4mFrhvY6iNDLhgfqGZzQt1qIZNVfYJyQNqajDgBq21ofYpsLBbs31qw2fcUC6GqaKpqJA61Q18uN+m7ha4ulz3LrQXdVockIdhKlSOdQySTUsei20+Xe2yomM93sra2m/l8lfl8zqydMZ+t4p1nGDqW3ZIu9RDAB0sI2vrmnaetPMNina3N/aSUMEVoKm2/3Njax/7917Cxf6/6igAxRupQsXvHDo5b28HxO3Yya2oWQ8eQIxvLLTb6BdeOVfDDkEgJ6moFZwNV27Jjxx5Ou+OZ3PWuX8sd7nQPdp1wGvNdJ3LciXfihBPvxAknnsLa6hqr8xVwhn39ftbjFtesX8cyHlDqzDhjaGc1q21LsOrpUqwjWUfvHYMITV1j88B1V1/Jxr7rqL1Kols8ZIMkgWQoUbU8+m5g6JKGWlMhD0LsM7Ev419WzY9cKEk07VG0TN/h8MZTB08VRov1IvRjmF49XY493pi4Y+KOiTuOHHcM3XDYz+nNjqBsbW1xr3vdi+/+7u/mcY973CGvv/zlL+cXf/EX+a3f+i3ufOc788IXvpCHP/zh/OM//iNN0wDwpCc9iSuuuIJ3v/vdxBh5+tOfzsUXX8z/+B//42ZdS+6FprYEZxiKKhKmQVv1glMVw0Jhf1TRmNgXhhipURLKudAn7f+vvaPyhjy20hnvNXRnnSr1WYOx+p6UhI3NdcSorXXMGlprKkvTWIYcsWLUUMlZnPeaxy2qABmCI43aC84arBT6xZIYs1aS///s
/UuMZXt21ov+/u8511rxyMfO3HuXq8rGx/diMOeYayzbAiEeFmDcwOCOW9ACyXIhGUuAjGhg83CHBoIOHYSFBB0a0ADJAowAAYWB0uX6gF/4QZXLVbkzd2ZErMec8/8+jTEjyoUPV3v72M5KVUwrrIqItTMi1przW3OM8Y3f5wzDZmDYboFGrX29iGQFL5XCSQmuO63iIgY0qTSWmKi9yq47t2FjAlOqvVNKpbSKt1qESlsKhtgbFY1Wa0aINtKis1rmyjWTamWKidM8CxwJKE0qk3S6ofVKyopbsETJdY38Fme/c45hCGgrrn+0QlnDMOw43z2W9rebZM7tHM47SXlVUrHHeSbFWdbujMYYMdCxGja1AtsraZ5YlokH4walJW21t8R8uqYuEy1n4ViYxIU1jM4xOPlZqjdOy4GpKuZJMlJU67IRkRstd06nE+lhwaCxYeD8/JLt7pzNdkvMC4fTiXEYGe3AxfkF59sN4zhwVU78zM/8JO/FBMPAxhhsVZzizKurK3otdA1LnEgKzi8ueff84yit+Mxnn3LMGRMLTils7zilaapKvkquaGeQqHMhUdbSJVAOefOquVDWlVkfjADBUpNMFiVBX7UBqtLbKuCYu5yVeYocD5m0NNLpgwvNl5Ju3GvHvXbca8fr047T4dfxBuXbvu3b+LZv+7b/2+/13vkbf+Nv8Bf/4l/kj/yRPwLA3/t7f4+nT5/yj//xP+a7vuu7+Mmf/El+5Ed+hP/0n/4Tv+N3/A4A/tbf+lv84T/8h/nrf/2v8+67737g3yVNCV1vkxTFprNkCZXytkhKpjFyY9cgr+mYozcEozCqy5pfqTir8U5TJSIL49c2nlIoIymfbTV9nZbI6SRJpCCtVeUsWnmWXDBRwr+1gtQKp7jQT7czWUUIHt01NRWKhjjPTKeJ3jqPHp6x3Y0Mm4Gz3UZc/qWQa6X1TiyFnAu5FJkl946/ddj3lbpoNN4NWCWmpV7XE+12fgsYpdk4z24c0cZQgFQkzKqDJJw6i1r5CfKXSupmzhXVOt4YFJ3TcqLtFXEvhrhhlFU4TaM3Jbvxw0BtDa202Km0BEwNw4DtnfOzh4zDGX4cwQz0VkSAFeS8sCwHlvlEipH5dGKeZpTzKKUoOQlAKhdayxjdSfFETDN1rUZ1A907vWR6qUz7I1cvX7B79ICYHzIvM0YbYs3M88Q8nQhmZDNs4KyyP3/Iq+HzKA0tRl6993lebrds2wVKG1wYuHzwgEcX5yijOUwTy5Kw2rAdBsYQcGguhh2Pzh8x7694cXPFMxcIaOYcef7sFzm+ekkNA1enG54ozaMHj3jn0VsEa9kox6v9K37yU/8BWuO4HCmrx2A6LoTgqN6gVJdtlCywJ6WQDw3StBcglFIdbaDUQqkZhcJ5i2+e0hvJVcbmGIIThgaK4B0nVUixSrrqG6gb99pxrx332vH6tCPH17TF8wu/8As8e/aMb/3Wb7372sXFBd/0Td/EJz/5Sb7ru76LT37yk1xeXt6JDMC3fuu3orXmx37sx/ijf/SP/op/N8ZIjPHu8/1+D8AyJ9n9p6OsphtFrp3aOi0Vqi0E52V+uq7XWQVnG3HqL6UylUppoFvDOS/AJaT16LxHGUNV4j5uSTIKUqmcppmcClqLo14lMQHV1lh2GzaDwzuDzhk7L1Sl2IyB4KwQFh1CCmyNJRVSrlgFIWjC6FFGiStei0jmUimtscTIMkeurm6YjjNjCOgh0HunZHGUe+sYg5No8dpvm8Z3znoNBKMZvWcTwl2VV5us8kmLtq8hZwJrKllO2lZEYKw2OOQCvzncsJRMU+IEN6rD2ur2zhGcmAZLkeev645xBuOs8CNqx7oR50aGYYe24kZvOZLTwmk+cNpfMR1uSHHhdFqYT4nNmZdtAxo5R4k+L5GYI+2053i8Yb/ZUZrMsltOqwGxMB/3XL98Ru6JsAkMRlOaEBBvjkdyrTy9OGf0A0c/8OzZZ/HDIImxNXN6
+T77J28RtluUUzSl2GxGLs626xqecAKCNQzBYY2iasUYBi42F4xYbuLM1dUVD8dz9sc9r16+oC4ZrQOxVuww8PjyAZe7Dap1Hp5dcL67JJbK+6f3SUuklIayBq0cOcvGgzFQWyWXQkPYBTElGgZtBKrUaqMANTZZEWxy7rrSqGo9Z1qVTZS2vvWWitWK7Wgo0ZCmL23duNeOe+24144vPe2Yx9fEQXn27BkAT58+/aKvP3369O57z54948mTJ1/8S1jLw4cP7x7zPx8/9EM/xA/8wA/8iq/HlIkpCa1vsGANBYmlbq1TUqG7FcKUM6Y2zjYbzrej3KnOhespEUtlbJYhVILWOGsI1mKtASvUw0IXscmNFuVOOuVMa5IE2lFYnznNmVPsXJwFzs82aA+5Q6qNrdJ478EJ/rkYaZkpZahVQE1ay/dyaehcSHmdE/dOipllWphPM9P+JB4A58V8pKC1glZrRlbrgLisldbCJahVtg/WjI3bHTJJAJVthN66rCwCtYuB7ha73VuDItHmMoZVlFSZp0htGuMdm82I0Qa7GgJDCIzBSlZDySwlkWvG90CtFTpipNKKzXbLGAZaX6hFEkeXRdqtJc3oXshJvA+9Shy8tVJNGRqNQlUGpS2lZKb9FXtjia1wdfNCgreARkPVTM0zp+MNN/srznZnaGOZ4kwqCWcdwa4CuLbH5Q0M+UgZ4zzDZivu+s0W66XFLiZJ+Zu2uy27MWCNJpYsyHNjQFmCHVhKxGnL+faCcXNOSopiFePukocPHsncWMsbztQqSlk+/u7HmPZ7nn/+c9Ajqmp0kVKnlUo1VUheXarOWgp1yVQswRk6mpIr0zHJOaC+kLrbSkXpjrUa7cza0s3CvdCa7SYQnMf7Rd5Mfg2OXy/dgHvtuNeOe+34UtMOPoRsvBFbPN///d/P933f9919vt/v+ehHP0peW0tdKUzVsn+nVhjRXTKnQYIFCjutOR8Hxu2G3iplSswpcUyZWOSpuDgHWyodhVfSao0NUu2UBq10tDLS8lqzMkqRrIE+gw2R0hpKdUzQXNgtzliCW5M9u4Q80RttTQuNOcmdppFEyFIbS5KqwsjCuWCWW1sjwrtQpJs446VmERplKyIYzQskpxXJ1vDBo7TEaU8xMS2Rk504HCb8MFCVYo6R0ipyq6xptaH02l6uVZJNs2Qs1Cr78soIuMcoGIJhu1tD1KxDG4W28hi0tHprbSxLxBiL08J9SOWW+9CpOVNy5DQfOZ2umY83pGVPrRFUF+Q2HaMURstMPngtnAalMEXRFcRSONy8IJjKMs/sTyd57pVEuzttqCmigmOeZ+Z5ZvQjNWecMlA7MUnGx2E6Mc2zVCgdrDOcv/UW52eXbLc7jB94dHlBsE42MNBUrdhtBs6DY3SW3hUay5IlU0T1RmkNP244e/CQYQj8lp75P//Lj9O9552P/794evEQ0+C0ZOaYeP7qinQ4EnRg9+gpy+nEPC/03gnrimClEVuk9SKsgt6oVV4vugDFtDHUDjHKeqoxQv2kQ64d4xplaFRbqCXSGlhlGQaH1YrgDHoz0M4/uEn2dR332sG9dtxrx5eUdpyN4QNfv7+mNyhvv/02AO+99x7vvPPO3dffe+89vv7rv/7uMc+fP/+i/66UwqtXr+7++//5CCEQwq/8ozRdXMgrCdEoQ+3SSqStyZ+tCnCmVDZnWzaDJ4QgLSw9MzU4xsKcpA0aS2EIARsrIRbc4NDBrzyBtkJpwGpJcKy1U7p8D6XIJTMvC342XKTA1gV2ITC6gFpbqWUVplIlzyTGCmi0sWC0tEtjoRaF1useeRXhY2UZlNKYl4j3hqE0utK0rqRVlwsuCzxHobBhQKxOwk0oqYhZb154dThgR0czmiUXKhp0pyv5PYx3qNqhNFoqxJiZl0hMhSFYhs3AZjvih4EwBKm4jBZU9S3K2wgwaSlJWqCl4sNIq4rWNc5vJF2zS7U3p4nrw/tcv/w8y7RH9YrT
nVYKvcndOlphvbSCnZfXx3qP2ujVWFc5Xb9ElYXDaUa7kQdnj9ltLKopluWI6gKc2oaRjR0wGGgdVRvzPPNeeU7ulatXLzhevUTT8IPFlyBI7nHLdrPj4VuPOdtuBP9tDD03nNIMzuG8YMpr7VLZpcSLl+9ja6PXztnjBzx9+ymjNqSUefsrv5KUM+/sHqFqoxmgV/bTxHsvn/Pe8xds/UYqMysbEEoZgh8oNa2hcLJh0JUIR6oZkCRZWiMMAujKqnGY9li/vk5aAF01y/OXrAYj53S3GpMKVTWadWKWLL82Nyi/XroB99pxrx332vGlph25vKYsnq/6qq/i7bff5kd/9EfvhGW/3/NjP/ZjfPd3fzcA3/It38L19TWf+tSn+IZv+AYA/uW//Je01vimb/qmD/Xz9G10tBZokS0NXRsmV5QCZzS6d3KpOAXb4NgOg6wQtkbVhqTFLZ9r47gkSq3YJaGcYRgD4+C5uBBxykpRdKf2jHOaruxahRh8F8QzWuEU2N7YBcfFZuRsGMQkFiMzUFNiSYm0JI6nieNxolShRHYUKRd6W1A6o43c7eu1eVpzIaVEzLLql0plThkfAlqbNZujQa2rsABqbTnTaVUCrDbBsRsco7c4rchdKpRMl+pRaVrrVBSpNnLMxGlhniJzTMSUcIMljBI532mUVgmqY6xCG8Ro6Cx3xMFS5G8tjcvdpbTRG4RhgNqYTjdE77g5vE88XVPiTGsFozqtKUoRsx9KSS5FCBjnQWlqy+Qq9WApmSVG8vpz51TYuFGySJRlHEfm6UYEzlpGN9ByYZkn5jxxOF4Rl4mrUpmmI2k6kU9HTANtLdrJRocxhnEzCjirVOwQ5AIvRVqxdEltbRJDXmphfzzw6V/8NBfDhs3lJR975yPsxg2mVXZhYBwlVTXljPeOYC21N5yV5ziVwsXuXPDbfgQ/oDvCh2iGOS/U0qEbasqy2lgrIXjGcYN1FqM8ZLDNYrojzRltG0opfFCgG/Mxoxs00+gaTG9UBbkLDyQvjeub5ddANX7jdQPuteNeO+6143Vpx+kU/1eX5a84PvQNyvF45Gd/9mfvPv+FX/gF/st/+S88fPiQj33sY3zv934vf+Wv/BW+5mu+5m5d8N133+U7vuM7APjar/1a/tAf+kP8yT/5J/nbf/tvk3PmE5/4BN/1Xd/1oZ34pTasltUmY5AshSqiM44e5y2lwbFUTO9Yo+SC6F1alSmhjSEMQcKjSia3TsoZ1St5petZrdkGT+6VlCM5d3TvDM7S+upWr4hIKLBWsXGaMVhxRfdKo1NaJ+fMPE3knMlLYpoWljlDY608xBGvVMbQZE2rdclc0PqO0aBWtkBDuANaa6yGYfRkhVQt68ltqKtLXyqI4C27wfHOwws+8viSYbPhkCtTlbZzLhKWplF3WQsxJaZlIWWZ3ad1Dq1uZ9EIR0EpyX0IQZ7/EDwdSEliwFUTg5/uSubZpTAqyX6oNTEfD8TTHlWL8BGMsBF6bpQCVCUbBoOVLBNtSKUSc+E4LfQkv+vVzTXBenpHDGwprQjtKiK9RHqpDDHRciblSFGNq5v3efnqGS1HaqnUZYHW6aXQKoLRjon33vs879x8nHk6oTtsx1GMdkZL3HhrzCXjjKJbQ8yVmCuH44HT8UBAsTWPcF3RayEbxaFGnr94j+PNkTZYhnHAGwPacrHZcrk7Q2nN8XTi4uwMP25pNvDRt56y8Y733n+GVoZHF4+I88xNumFeMjkXggk4LN6MWBOorZBwlCR0yNpWc2PvmAalNHRpmFGjnCJTVnS4VFRtguPLD36D8qWkG/faca8d99rx+rQjz7+OHZT//J//M7/39/7eu89v57t/4k/8CX74h3+YP/fn/hyn04k/9af+FNfX1/yu3/W7+JEf+ZE7lgHA3//7f59PfOIT/P7f//vRWvOd3/md/M2/+Tc/7K+C6tKGun2ScqlopTDaYJAE0oYmxQwpSxDDiuieU2GqBW0M
VmuwhpbNHTa65Qa6yJ21NlgvM2DnDSzQVEPRsc7gasdahbZGkkZZKzAlMeYxZ3kBqwBwapZgsbIk0iRuat0lRklZg7JamAbIqlzKwigIQ8BYhdEweFmD3GwGgneoJqtb1hi0c+ReKS2Jm99oCjLrLlXmiGfbgadvPeTx4wcobYinGTUvUMQg1zoyB9ZaHNlqhUC1hu6K0Rk2g5MET2vIHYwCYwXr7VeREVFX1GoB4SMEL7NwAVJl0nRgCVtQG2JeJI5bKUIIuGo5zdBUpneFagqnjMCvrJXnN2aOU+RwPNFjZjlO5JTZWC9vQr1zOh04TkcG7zlOB07zQg2eMUZujtf4zQaTGnlZaDmzTMI7qFli2buGXAu9NnJKfO7zn+XiMz/LeLHh4aPHZK94xzxGKc0hJ96/vl7PAVmvKynz6uqKPC9c7s7wQ5DX2mhiSZTcee+958z7A5txxCsrs3InKap1dJyfXzDszrgwjoe7C6YY2c8Tl+eP2IRBOAfHa7ZjINuRMhcmJdComho0i9cBax1LKhIdP0dp0SpAN3LtpKSwvtFzZ2gWE9SKRpVzQilF6W19g3nzdONeO+614147Xp92fBiS7Ie+Qfk9v+f3/P/9AUopfvAHf5Af/MEf/F8+5uHDh78quNL/fPQmAUZ1Re62UtZWXAftMMbKapdiXXcCp2VeGWulKoUdwzpt7fTSSNNMi4leZJaWc2FZEkv0giPORWa6XaLEBc2scFYTgqUWgdRorSm1EHNcd8kVGoPRCoum5MJymEhTxHaJqA5DwA2e7bjBOSeR2CnKHfzKKzBGM3jP6BzNNAmyWlvRTdjUGOMoVlgNVBEKpTVUwQxrpfBO2nbeDxLMhfAOnDV01agrOlsbhe6SvaAVWKUJxuCt4XI7MgxOrP8NjLeS52CkUhM6pFRxxlgUwnKQr98GSGVSPnAyhpy3xJLJOa1VVBAnOZqpHaXSUh2lNVZr2Rqoja4KaVmrlpqJeUHrTrDqDhy1LJHpcEWyjnmaKLVR1zeoOS3UlqE32RwYRvJ8oikFWkyCVQk0q9OhVo5XV3zm538G7yq7B5d89df871yGLb4r5pp4/v5z8jxzePiYi90OoxQvXjzndHONAy525+w2W2qTyPqUk/gdWsMqBarTG2g0wUqux9n5ju3ZBU/PHxB6xzsv8+PWybXJa6w0SjmGjefsPFOLwpsDc1rIaYVDaUuvnePNkRorynas07KYoQylQk6d2hOld3zWQpr04JxBawNNr+F2H+z4UtINuNeOe+24147XqR0f9Hgjtnj+V0dqkrbZlbiznbXY3ikadK70KaGMglrZeDGcnY0jUylUDVVLlLQ28oR1J4KVcxFx6pWSCsfTQrCK0gpLjIL17QrdFVp3xkGz3YxsN56UCodJXOOpJKZlJueM0QavHc46QNbjliUS50yvDe8du83IZhgZwyDBVrUQcyVlmQmH4MXgFhOHZcZqgyoakhjfbsOzvHE0hEmglFpNTBptLcF5GgrvBabUlaK2TioFtNyx11QoLdNXA59SUl3dch6MBucNm11gu9sQEZHz3uO8E8qj6vQuoJ/e+/p8ypaE8B/kLr/XSk6ZfbnCnCaJUzeGIfjV4LjBOI8C0nES7oJcg9J2roVaO/RO8J6qOilY6AVNo+QorIi5soSFMEAp0spVyGqd1kgwmRag0G3EjFLCqRhCoCjNdrdFrbkivXVYIvHqBa4VDo9e8PLiLQZrKKcJEzOvjjekeOJ62NBq5ZeefY6rly85115c+6kISjxIl0CvVdPheOAwHYklUfsWFBit2Ywjbz16zINhiwY252foV57dgwucMhw/NzGdJrTxDEMgjFsePzKEYHl1cyW5LSnTO6Q5Mq8jgnHjsU7m9EYprFNC30RTc6P0RlEK0x09VzoVqsE79xt+zf9aHffaca8d99rxerSjpNcEavuNPmIROFBnjQdfPwyapTRQhVYK5Mpb5yOX2xHvA8dcOdVKbh1rkLms
1WKkMpamJG5d9duLIHJ1LUauWgrOromO2oKW+fQQDOebgRI6ucq+vFFAF2NZLYVqNUP3aGeJ2pJzY4qJWApjd9ArqndqKWijGcJAcJboLSknnLfU1jieFpZYsa5DUqjaqUlSUI0xVNfvMM5GqbsAJ41eLxhJ8TTWoYxZzVCdFW6AWcO6Ui5YK/NCpcWE1zrk3tkZASVZY8itoXVHWWljW2vkOVWatprmtLIoY1FGr+1i2QBJKXE8nDDaY32gK8W42aCNYRzPGP2Aap263bD3hroKZ1Py2qlWUUYqFVTHbTwbp9m/vJLqsRTJf0iFWs9Xv4DkjCQlWxNnVSplZdW6JbHIdsc6wmh0Sr9Nm4WKQmmLNYaeC3E68fy9X2J79pDgA3mK5CXx6sUzep7ZbnbQOi+ePeOXnr3Hu289RW8CGw3OOc42gbPdwKcfPyR1eb1KjByWhUelkoumlkZJmbe2ZwzeU2IizTOjthxfXjMET0+Jm1dX1NZ58s47XDx4RF8Nc61X9vs9tXQxaXbNEit1rbSctuQqWG7TwRpNXwmiLUNdOrlL2JdWDoN870097rXjXjvuteP1aEdLH3w2/EbfoPTW7zDNyhpunVe1yVpWr1V2t2vDasUQgsxUO8TcmKYZpRNhHLBBVidTirRaBVpkZD++1cZcFnoXOJDTFmfsmslRJXvCKJSVdq82t+01h2odqjAQlFbohjjJpTyglkZMmY5Ae2IpeIWELBmDbo4YPMd5ErPZHJkXaSsOOmC0QqHQtktmRW9rGqq03bQWnLdeBU8bqYrMus/egKYlStw0EReTC9S+So6itjWBdEmcUiZXAQYZLS07jaCgtZIWoNMWvbIVFA2txDRn1vZqK5U4LTTnmKeFGBeslXm3cUEYEbeVqFJY6/De4b2BXqm9YZ1QLI3VVHkY4+AJBjyFuDdc70/k0iiVteIpzEUC1uKy0HojrSyIUoTsGeNMzvEuv6MDp2Ui58L+tBexNxrtrKTS1iaV8ssXvDf8HF6JC/76+gU3L59hVcf1Ri+NejxQTjd85uaam1fP+ehv/Xqy6owhYI3hK996B/PonMcZAgbVOnOMGA3HOfK5F8+ZTkeCviTVwrLMgt6eF2p0d/NzKuw25zy5eESME8vpBtVXQT8ecc7QapHrpzUJ/CpCDtW6iw/DWhHU2qAqmmqgBQama0ErhesfwoTyJXbca8e9dtxrx+vRDsqXyQ3K7XGbEaGULMfVBrRGXuEyji5rcLUw58TNceJqfyKeohi5csEGB0rRS5HcBdTKHFIEH9gMjq4kjVEbhXaWXhulZmotWK+xi7QQu+p449j6wNaNgMIYxWgM3mpyFWqjd25tY4JeQULDxjOMa4vSWHqtHA8HTlHWx5ZUiKVwtht56/JcwDm5kadINeKqBtBO4tOtFVOY0YqyIoh7l7t5tBJksQyfMXSo4o6vVXI5ShbT2RILxykSs7QWtTZoa4QnkJFZuLYyK1eWXqF1EQ7rDc5qrJY9+5IK87SQdGaaFqz1bDcDYRgE5d1E7A/HG9mwUAiRUMuFX2pZTYXtLo5bqYYxCh8MrRi0s5TayKmAsWLsrBWlFDFn8rLQacQURXzptFaoNdF7lVyOtSqKKXM4HJnmE85qtoPDBrca9iRNNE97rp9/hl5mlHUcjzfkeMC4QI4TuikMjaARA+J84OX7z5iqtPB7LWjknNg6yyZ4rII4J9KSebXf899//ufwTRGcI6cFaLx6+T6uNE4apumAdwYfHIMbUAgIY0qFOSXBqdeKVYrNMDKMgZubvSDRVw4GgK1NsmaUhqbQTaOKpsYu4WutQ89rnsybfdxrx7123GvHb6x2tC+XEU+T23y5w77FFhtDy5LbUFvHKPDe452ltErMiSUXlpgpMaFZDWClSqZELdDaejJ3WndoLRkXxlmMK1QaKCg1E2sTaE8sKLtgrKW2jgtGdsFLlTt5rQluDRVrUEteW7jgjWLwhvPtwMXZiHUCOrqJE6d5
4mp/5LQsKNWJWdb4LndbLi93glg+zLiuKCpRmty5Oy85ItaKSQ0lF31KEbWuhZVahUtQ5G+Sdb9OLJUlS5XlUNRameeFeYmUXAlGBEZ7h7EWrzTawRhG3CowKWegY21HG4NWWirULnPwwXoqmVoawzgwhMAwBlKp1Ao5zZxOmt4zzllaLdI2B1l77DLzV72LYHap+JTWFBRzypyWSEmNYeswzlLXv3mKCzkX/OCl2tRmDRZL5LRIFdQrzliqktdqmU/kuGA1eKsxTmEMkjSrhQOgeiaeXlJ65zQvgiTvnakJkXPOC7VVoOOsQfdKRUSz687Pv/w80+ffo33kIxgrK610hG1wmpiPBw7LwnbwlHlmOtzw7Jc+g8qZzWbAWnAh4EbxESxlYamFR0/fodF58eoGpSB4T2/SIi6tcpxmprXys1a2R2R1VVY6DRrdFC2tiba54Sgfyo3/pXbca8e9dtxrx+vRDsoHDxl9o29QjHd34KRxDARvoCtUN8RW6a0Cio21bFa2Qa+NuSZKyZJOaSyCtG60LHkCvQmKuZTONGcOQ2I882y0w45OIrS7rK/V1jAgGR5V2rQ0WQtMpcIyrwmphouzrWwKrO29uCzo3rFKycwVOdELnaV23r+55uWrPccls/WW7eDpdKzXhF0g7Dxoj00F2zuxJpkrW2khKg3Wa4QSpMRZ3TKmrwCkLqKbW6X3KjPzKumnpTS662gpjMgpy5YBAI2+RqmjFV47gtYE41Bd00qnpkJdYUPKmLs2YV4pmLVI+Jqhr+IgSqFVp+pOLYlWHCkqehUnei0ihKV3ShMnve6yfuCtzK1L7dJWroWlNFqtDOsbjRsdLUbBdnepcMdBeAm9d3nTyJFpmTFGmAmqN2qt68aFRpsORtYcm5aVUa2MVJ260/s6sy6VmrNgqctCyY3D4USKCYOY2gZnMFXx/P336V3zM5/5Oabnn2e6vGCeJw5rsJ2qQvHcuoFpjiyno6wsxkSJCVUbqsv66NnlAx6981HOzs9ZUqSjOBs3mMdPeP+t91C9Sa5HKZSU6LUz14wu8maUdMEmRbYNbyyqs7bsV0NkE9hXNWvk/Rt63GvHvXbca8fr0Q7167lm/KV0iKFKkNTKKIlNX9vOSskcs68gomHwWGeJXS7+2xml1kpmosaQaoLWZPaLzFhTzlzvD2inGHOjG0nu9Ebml6ULrdE7aUWK41wuqJvphM6a0ivn25EnqmB9wDctbeRcyEWCmZYSuZlnzDShvGM/LywxU1E4pxlGhw9WqgsDxmuU1aAaw8aResUlRyv5zqCltEZpAwqpZGIipswmSB5Ip64X15oAusKgeq/4YLBW03sTTHStGNSdcc4YhTFG2tx9hTKVumaECPcgtyp4b9VRXaqJ2za41R2r5HmoOZNSwlhLqZVSG8NgoRVqYRWcwjyfUAgqXOvb3y1jvcNbh1srYWeFm5Cy/CxjjFAQV1+AUtIubk1SV0G4FD0XUizEVNE0TDOwPt5oeTNOpbKUjLOrMIp14c4s2ZW0fpdYaAU6Ba1E7JY4UUrGeM9SMlfXr7h69YzjdmTeT6T3X0KqxOOBm5vnnO+25N6gNW6ur1Bac35+yXZ7xqurF7x//T4pzThtoVdaNQzjGZdnD3ny8AmFxul0QsVEigvOOZlrG8sxn0SEVuMgXd5Neta0pQnzw8ubRF83D2qTkLucK6pLvPqbetxrx7123GvH69GOD1PWvLkKg9AHUQIRSqmglIUuc8i27ntJOqbCWEVTMKfKKaY78RFzlcEaQ2uaWjWtSyXU1zvsuHRevdwzLAU7OqzXmMHd/ftnO8/lxRZrDcfDiZIy2IaKGlM1YXSEQaiRrRUUBbOyE1pvhMFjB8fSO1fHmVgPeGfZOs/2MojxTkNvDdsVS67EmGi9ibnLidDWmFFLRluN9pbSxVxXWyWnxLJkUi5sBwhOeAQYTe5tBeiIwc86g7HgrEGhJedi5Tfo3nBGciw6Ikqty7zVWo/O
4tyvta9shbqayGQdsXfw3uGckfOaTowR52Qm3VqXdmyXZEylmsw/1wTYkotUJyiMsfgxoDRoK63623yN4D21dWFPtEbX8obTu3gKSimknFiWiRgn5uVISgtUaZ/Xnllix1oLNFBSeeVVPDVrcquSN7dGpfYKTUBVMWWsESOetZa0RHKtxCJzeb3MtNMVv/Df/79QF+q08Is/81PEKRJPR65unnF+tiXlRM2Jq/01U8m88+ApDx894rSccMMIShHnhTRY7HYkPHjEo3c+wsX2nGk+ARBbB6M5LAvvPLhkGAe4uSFnYX8oJZU5CuiGVoRdoCqU1ewnVbpsZ6RUaK3hxzd3zfheO+614147Xo926A+Rgv5m36CURmsyC8+5yy621qjW1vbpmp64GRnGga40x3nm5nAirxeDNkZc6N6iekGvQCKpZtraVpUci+k441vlzG7orWOQlcHNGLg83wEwnWZqbQJE8o5hM+AHw/lugzH6blZbS6ZUAST54Bh3G5pWzHNGWSEAboNHK4m3LrVRWoHSUaVR42pGW9uTNhhONFKteCxqTUWtFVotzHNkfziiamf0nrNRcj6aVjKbVgCdu9RPbdYE00ZpcmGCoKhdkJyM1iUwLq+vg3aZQhXUtTbYlVlgV7YBGpTTGOfQzkrompIWpEuJsW/lrhxZ09Mrh4EOJSYBVM2RNEfOH52xe7Bjsx2Fw6BYnwupWkLweG+ZcqW29aPWtQ1fxfSGppbM8bAXj4DWWKsJXtPaGiSnumwvWHk+GqwehbZWmfLztDYYpQXpXAUIFQbH9mzEW8sBudAbDR8MITjatPDZn/1J0nTktN/z0//n/485RrRT9DJR60RcFMv+yIur95k1fPXHvpqLsx3hN30Nr67e53/85E8xHWdqzVy89RZf+Zu+mo985CNYNHOaKTVznCdO80wsFWU93gbOtjtZs+1rmqlRgvVuDVVg0yWSntapIOdSW+flSiPgrA8OavtSO+6141477rXjdWnHl8kNSkkJZWU/XikB+dC6JHh2SSzdeMeD3Y7tZou2VtqiSGs2lULLkh+BtyjnIVdUYjVQCaRIKZnJ3gKEel/vjmvHOYexjiEEER5tMcqwCZ7Lsy3byzMBQRkv7eGSyafINC9362kuOKyTEw9ruY25Lq3hnV5NT0KSVl2CtII29NJFYLzFeEdaE1Vtc7QcZV3QGJmHnzKnU2KjNYPzbIaBwXtSb+t6pVR+tVUxPBkt6GmtmdWyXvxiJnPOCKzJGGJKzEsC3fF0nLeovgKVjMU7J+3r3nDOMgxeno/g6VrRuwRkxZRoa+WpnVSmxhnCOEBVpJhIOXOcJhHx8y0PH12y3W6ZU2ZeZplHKzG79d5wTlPp5JJJMUJTpFlyPZy3+LDirFthmY845zGqsx0DUwTodK3WNU95/bUz2OAYxoAPlmETYIU9lVwFDaI6ITiG4BnGAWc106QAeQ6cM2jV6aozn0784i/8HK9evOJ4tSe3hguGYEEpmbufpitevPgccRBypzWG0Souzy/ww8CxXxG2gYu33uLdh28zYum6czjuOR33DCGwG4e70YINgWEY2W62vHx5QOeCqmIg1MaA03L+99VPUcvKo1ipqa1T9dqFeEOPe+2414577XhN2vEh6ARv9A1KW01mah2c9taoIG0wYLCGi82G3WYg+ECi0ZSWqmMM5JO068iZns0X3N3cwpvEDNU6Qn5s69cRRzLrjK2hqF2t/60muMBu3LDbbBmclzW10qi9UlNivjpxtZ+Jqaz5H5rg5cLvWsmJ5Ky0x7TGrnfgy7KQWyZYJyuA1t2JYNeKuBIkWdcAldbCMkBRUiUtjc1gsMZirUNrIzv2rQnESIsLu9ZOCAZnLNZbpkkExVmHUorNEPDeYbQlpkwsmWFw+PWDDnZlJlhn0VrRa19x0BZr5O9tqqOMwnRDb8IT6ErhjGEcR3bn5zgXKCnhg8dYw5wS1ho225HNuCGEsK4XNgkGU9ytQXpv
8dbczbiLytQmLe/t1nN2PuC8lTlvK/S+UiypKL2a25Ts96O6tL+dZjSBi8sdl5fneGeopVBaRVmFN0aqxiZCY53Ap1xwdCSTIy4Lej2HFI68LNSl0kpHo2R7wnW0TsQ4cVyuef/Fe0TjeLV/zkffekIpC0p3YlzwVnG2HTEatPTzmeaJlzfvU0rmcnOJH4T9UJxh8/gJ2Viqd7C24QXUJc3n1uQcssaSW4Lba6x14JYVIryMN/W414577bjXjtekHf2DFzZv9A2KWk8u1fRqOhLqohjYNLtx4HK3xYdAV7DkQmoV7y3bMXBaEjUV8pJoSkuFkzJkqU70epEK0VCMXdZalNaknKm54ZQmlcwhLrTS6NrgvMBz7LrGqIwll0ReEnkuXF8fuNpPxFzwRjMOnrPdhnGzwXhLqoWG7JLbIXC2O5PwJ+85TCdq77KF4MTQVXMlzxlV5A5ZGxFIbUEraFqgR7VU6B4hP+j1ahQh1VbMb22WO2FnHcFbyb4w6q5Scii8NxLmZQypVpQ1WOck5MsZjDLQV1jV+koJWCpjrWY7BoYwkGuViqUVSf1c78I348AYAsHKmmZR8vOV1hil8caiEZEHEXqlJfejdcF0y88WUmdvld46OSdiTCitGEbHuPH44PDBYbyhaZmLp5opTWBPXcof+RM0YLRkdYyeYZA5cen5LoXVdEU2wloYtiObzShrh87ig+N4nMUcWCrHKWJDwxtHL7JaqeCOL6JVR6lCJTGdDhxz4dO/9FM8Ob+gtcr7L5+RliOb4Dg3ljodeG//eZp+zKubK67nI+8+/Qhf8fSj/OLnFGy3PPmKj/NVH/lq9g+e8Mkf/3H0z//iFzYJ1uvGe0cIHmst7W6tUs6D3hq1ypup0W+ufNxrx712wL12vA7taF8uI56O0OxYd8Vv6X2tV4yxbH1g3EjyY64iMnOtKO/w60XQ6fQse9laK1RtgEIbjXX6l2GLm6CpdcPSSZOcsLvtIITDKqts9TbRs0l7UDdHz1F4C1Xi2OOSOc2JJRV0MNhgcMHiB4/xBtstuWb8ENjtdlxeXKCVQJq25zuMsjx4+JDtsMF2SymJljK9dakyvKEiVYgzll4qqcjstLRCjI2UG0tpFKXRzsnevlEoJXTHWitWD+tstJPbLTHR0DW4waO9RdUi6ZlaMXgJy+pNZsFGy4rgLU67lgbr7xSs4aTBKC2PWee/WgkgyjppV5ciaa6KjqYzOoNRStrDcabSOcWZ2updtaiNAqtoxsrn2ogXoHW4azMbho1jczYwjCPKCD9CZslOwEtKPu9thQzRUEaol01ViiqAoaqGXVdAQdONknyUsy277UhrhetX18Kt0AKb6sbcGQb1SuZECcEzxcSSIrmVu7n2OFr2cWa+esHVzaeptfHZX/wZcoq4cUu3BlUW3vvszzFfveI07dkOjnfeepezccvjh0+4uHzAg7NzbOvYDptxQCkRYrRGW4O3FmMVPni2o2dKkWlOUvVrvVbX4I3kobypx7123GvHvXa8Ju1QH/y2442+Qbm9De6tgRVDFOvqk9WKMTiGwaNXyM8pRbLqVFUpPaOUxH/XVulkutHQO9ZYvDNsRr+uwkEuVdbreqFmRY6VlDJqdHjVMbXRq4heb1Jg3F4kyiictnhryUqRkwhNKp1hNGhvccHjQiD3IiFTXhJKx3EkhIA1ls048ujykmADjy4uMc7SWqPUtLZBZc55C6DSSlq4rQv2eEmJJTqujgvXp4gbHMoLutpYQUxbs1YStawtOyitruyDLifiOOCHgHUOkxJbZxhDYDOMWG2l/QrQ1vRQY8naiAO+ZnL2dG9RGsyK+Abk+ZO+O611Uq7rNoRUoBLD7mCd5ZdSaO3ENB/FRKcGtLEoo9Y4ckk/NdbJvN1ouu+E0eKCJmwt49YzbjwaadOXZrBJgaqknGRDozasNUAXmqZBnPp5wbQVioRGir9OV5VxG7g43xJCIOeE9Z7WwQCtZIxWOO8ZdyOt
VkKw0g4Gcsm0XrFW461jCI7xfGBTFjYB8vKSeYm0fET3Ji3q3oT4OZ9YioRyufFsfZMRKJXtnf3xhkQixhPT1RVWGZy1JAWZijMG5wzj4HBewvBak5GI6rICqgxYr0C9uR6Ue+2414577XhN2mE/eGHzRt+g9FLRTi4kWJ8I6akyOCNYYS1ZEqVWppzAKFrtZKqsFNKkxdv7F9pjTjF6yzZ4nLPk2liSQIlaLeQo2RWqVnQTJDZrGFiKC61B7pauAsDKRmjSEk6Jkoq8WHSCtzx5+oi3334LrOe0THQF27Mtw2ZkHDdrYqXGGMcYBnbDyPn2DOMcKUdyTDLTbn1dZXS0LhVc77DExLJEem20lkklcTydODsLeOux2jL6gdSkndiVmAhba+SSqblAB2ct201gd75lGAdZy3OO82FgM4zybzlHV7I6WHvBaI1zjpwzSiniEiklgJI7cKOVLN11SQrFiBilUnBrReKNwyg4Oz9jGAcBCw2SJbIO88UTUSvGWaw1+EHeYFSskhFhVvaDlSoRXUWYtNzcG6PQHaia0hSuOcFtI8KxvoNJpV3FSNlqQ6GwytxVYLlktNP4wbLZjvK3l7KuDxZiamydB62xg2WzCxilyPPE9SvLdIzodb1Q69vVR2klh+AYg2P0jpoT3lkUihhX7HwrqJYpaJa0EMvCfrnCnz/g1f4Z++v3eXX9gq2pHK+vOe1fyfmpNPRKTEWqWaeoqoEGlNBL5zmjUeQsGyjGSuv6TT3uteNeO+614/VohzZfJjcorHdmSslMWaMRcI5isJbRW5yVE2DKiSknDvNCLpllSWtlsz5ZXfbSFes6nLMYI6tizmpSEVGrDboR4bhtGU5LorRKzo05Rpm/9bZWABJSpXqTOea8sCwJ1RujVXz87cd83ce/kkePHjKVzMtWaVoxbkY2u53wAYyVkKoYmWNkF0asc7g1EXSZlzsA0TAOYiJjXV1bL4RWO04rLraB3WCwutF6oSlPa9KipqyBYb2RcpaPOVJKwSjF+WbDgwdbLi42bDaerkQkz7cbxmHEbrZS/WkteOiU72bvsoYIIBRKpRVGA+udv7qd47fVfNjBe6kCN2FkWSy73Y6zy3Oc0mzHDeMwUGqVufXajrXa4K1jGwbOz0YOKUm72ChphzdNb1Xu7Hu/M8VJsq0wD7SVVvktDyOlRE6JuCyULO1gGmil8cavrWWL945KlYrXO7QB66Taq6WwpMIcK041jNWMzuC8VCFDsOw2gb1ZZHvBu9VYaLid+XsbsGZA64BWmVYVKUuyay7yOtMKNWlqjBznE7/wsz/O/vIB1++/IJTE9fsv+PzWUWOkL4uUZarRsoS6qa7wVnDf3jpKLpLU2xS1NFIU0FTfyLz+jT3uteNeO+61A/iN1w7rv1xMsr+sLauVujUKE6zlPAwEa1ZDlLRUp6VwWjKtVkqUdFExtIE2llIkQMqu7bZU20oTVHTV7lqVrcp80TuNpjEvieOUaU0Rl4x3lrQMlJhhdc3XLBHXh5tJeAe9s7GG8/Mt3losFm9hG4KsyHkJwFLKUZrEZZ9uDpymE4+357Q1ctx5OSGLUmS1AqS0xlnPbRtbo+il4qzh8dmWdy7POd+sVUTr0pZdL+5SKzlnwWmnzGmOTHMEOuPo2J0PuCDtu1QK6E7T0K3GjZ4KYgZUSIWoRPqX5Ujuma46wyaw3Y00Cm7jMHu5KJWWu2trNdvtyG47EsLI4Ad6L3gvlc1gHIMb8W5E6SxCTMMog7UO1RXn2y0Pzs+ocxJBcA4bDC1FchIRtcZKG3wIpJLoDVjTS5U2OO9lXU5rYsq00gXmVUUvjTZoZVYTpJMPndGtorrAmESoJKVW+tL6TmyVQubU66zddjl3jVLoLq9NaZkYI9MpUVKj5sZ8SiynRF4yNVe6MSwlk0pBx4zz3G2UxNMrsgNVI86CVZ142rPMkZgy3moGrzksnXnOlNrxTpNiITup1nuvtGLIKUul2qGkCrX8
Bl/xv3bHvXbca8e9drwe7XD5g2//vdE3KCgwTlpytffVNf+FGbJzFkWn1cxxmTmcInGu1CoYZIXMe603wgPITVpvBrpqKCWZAmK6kjvnVAu1FKEBdkWnklcYUq2K2oSpUJZInJcVZe3pObGcFvY3R06nmdrBBCeVzbJwVjNaNUajcNaKcap2mqqQC2lZOE4TOUZ0y2LuMxpjLM4HlLY0beioL/ANlIByapfqxmhFCIGz8x3nF2c0b2nO0UqjZLkouoKU8hqqVcn1l51YVi486dN1Si/EVsm60b0CLShkHxxWG6xS8m/EePd44zQPHl5wdrGjqcr5xY75ZpGVPacx3hDW7BMXHC64OxT5LYPCh4DzHqMdYPBupLWKN45gR9AFamW727AcJ+ZpwWiwRlOVJnYFGAbjCVjCaoirtZIpgMy+lZKZ7mKXuzbrNEWcldwWry0a2QQQi5uid0VOwscQwx/r6p/GWgjeMG48uVdykSC1W4gTKKwz1FrXGXlb1xwLJWdGF7AYSq7Mc2I5ZZyW1cu+VnEgPIpcEilHYnLkmohpprfMqJWQKVOhl04wFu8s1khInqpNQGMNCb9DuBaqKZqSNcHeOnXJMp54U4977bjXjnvteG3a8UGPN/oGRWslUB5voFSqkl1u5zzaOAmqWqO7b05SfZS0GsN6AdUEVe0txhuUNfQms2lcR7nGdhs4G0dQWw7TxPOrPacMtXSSBm9kh10h4VOlVqwN9NqoKRO1IhehTC6LVGKptBVeY0gx8erq6g5gNAbDYPQXklIVtFyIp5nD/kArCas7qhVajZSaKTHSSyZojW0doxTBaobgmVioWjqwrUn+RzMWFQLDdgDvaItkQ7TaCNZxVRrWGIYQmN1CR2a8io5yimEz4oaAaZWcFo7LxNn5GV11lO5oDdZpLBatHa0WWhfE8Xa35WK3ZbPZkXImDEHIjcGjnWFztmF7tmG7Gxk3A8Za6IKDbq2iUDjrcGGQSq9mnLE0pQjW442nofBG8jWs0zirYJ1Zq5UYKShtvRInDbaB05Zgg+SCdIF25ZKY/YLXlhoLac6oATQabwxKGWIVMmfOmWWR6mLt+mO0kQ2FMTBsAqY0xo2jxi7bHt7BOnc3RqM1K59CqnNrNeM4sNkN6GbQWuBHyyKwrlvxNKvQtN7otawiVaktc5r3TMsRTYMuRFJUJ3hD0I6z3YapKl7uT2uyrwEDZU1MVQo6Vd6UVyJkzWKCfFOPe+2414577Xh92vFBjzf6BgUlDIPeGq32X7aiJqt6S6p0nWmtcnOaWJZISuIwV7rjnKw+dY0glBX0rnFO47xhGBy78w27MAoG20EsmThlchJ8dNKK4BxjMHSjMV3ajN45AQ1ZyexIuTBlCZOqFUanCMFweXHGOI6CWF5nw84GjNIS42A01jkGoOXK4ALeOuiFPE8obchxpsXIVis8YOkMzrDxlrhA7pXYGq00ppiY4sJZa+zCgLaGXLPsqedMHAaUtZyfb9ldbJlLRtkTtITzjmEcuby85PziDIBUilwUILP43qml0Iyh60atnVoLy5JorbMdRrzbYKxDW49znvFsZLfdUnvj/GLH+fk5Z2dnhGGkoX7ZVoCQPMMQGNYtC7K8KfTUsGisEiYBStEbwutoitQrXsmbknWaZZ+Y5oVcG3XN8JD2sqYZ+RpdEmp16zil0V3RYyU21i2B29AvJVVIa/Qi7XOvLd54rDX0Jpsd291A6TB6TzWa7SZwthmoHW5WQ5vqSvDk6xvg4D3j6Dm/GOkJxk3ADV9w7bcObRVKs4rTrUA5KyuSsq1RJF+kr+MGOtY7hiBivmuwPRvWzoKmmg7OopySSrquuSYdqbY6d+F6b+Rxrx332nGvHa9FO1z/4BEZb/QNiqy0dWmLpUrNlWDlBbXeoayhtMZ+mbmZF+I6HxX3tUI7QV1ra2jINeKcwwfNOFiG0YFTdCePHdgwHCPOTLJquP4fKMEvO01WStgAu5HxYocJFmMs
JzWh7EnSO41mDJ7dduTy4pxNGBnGkbBChoYwoJQSsmNrmNYlXtwYhjDgnACTSkrkXDkdjpSUsawUTK0ZjWHjnOQ41CbchAoxJqZpYplnUlvDs9a7b+0MOEOi4TZB1gJXwmPzlrAJPH3yFk/fesK4GSRKu2tKrXjr0Q16aeSaZD69xtkfjkf2pyPLvNBKpdZOTJlpXjArXfLB+bkkt+52jGHE+wFrvVRg9DtDYmsNbYzQKJ3BdEmT7aahjEbW+aQKKGvAWF/XJa33WNWoNPT1fBcgdlv1KLVWIavhsTZhY+ScJQY9C7GR1ihZotIVYJRwHLbjyJwWShKqI8hc3BgJlAujx1dF0IakFduzDWEM1A5u9Cg337VsjZXVSGuNhIYZhQ4O79eP4EBDaRJEVqoMt0NwoDSpGAYT2I4DwxCIc1jD78y6vCBtV6xcK2NobDYBlMIFgzYalLTFa213IlNaEy5EhV7eXJPsvXbca8e9drwm7fgQdc0bfYOiFKtre40qB7w1nA0DmzFgvaP3zillDnOk5MotdddoTTeCeTbWoM26k6402jjQGqUdGdjnBVsiBuEECFRH4D9By5NYa6WhwFvMJtBHT/UGN8j6nCkZrKUpOYFC8IwbSZTMvdOQn+f8iNaW2hupdTG01UbMha4NtSs6a5pnaSwxMS1RTgjnaDSchtDB5ESbJ+JxpqTKRhucEqBUzZm4LFTVBWi0EiRzr6QV+53p3E4LB+84f3jGo7cecXl+iXWWgzsRbMbUxMZtcM5SayXmTCyRtu7/T/PM8TTTKjhlqSWzzJrD/kCak1AvjUGjGXzA+xHvBiFFIu7FviKg5SXSKx1SYbTFaiUsi7W1mFshl0TNhVJlDREtoWhBW0qv2ODwIRCccCLy2tqUFT1hQbSVgJlr5TRNYirrXcS1cefBGENgt90wDJ4xDjQjK47qliSJCFdfcei5d/pahYzbDTFn/GDXmHpFcJbgPEYZjDXUXsmlMLovrGfKpsZAU319vhPKQPCW1lnj7rtQPZ1jDJ5hCGiOzFMilUbriqYV3Wp8sGxGv4qIMC28s8JCqF18Ek1mzPIGKKuFb+pxrx332nGvHa9POz7o8UbfoAB8IcFSoWxnDIHL7ZbdMBCCl0CzpshZ3OYyqxMCojIGrMwWjTGw1jS1anJSLLrRqGgLTiuGoAg7z7AbiDGzcYbzwa3R4oapNpoWPkQ3itIbeTWZVQWxFHKVuZx1BqXgFBee31yDM3jvGYY1rjtnDvOJVBq9Nt4/nJhyYeODtNuKxFfn2igAVl7Ktp7XtndMK6RpYjoc0a2yGQKX24FN8GzGwOi9zHiLbBGo3kgpknLmFBeKary6OXI8LOzaete+GbDrc1eQv2kXNjw6f4C1mpgzp2UmtwLISdl7o7SGsw67IqzLkphPM/vDAWc9KWcwGus93nmscVS+AK8qpZFSFmc4crGW2tbqRdqSTcncHdWlNZkLcRHWQ7PS0rdWo3THOM24GdnttoRhoMeF1uQcsau50XRFjAt9ZULcEi5b6zKjNhptlFz0w8B2tyW2SptnqV7MLU1UiKIxJVrJWCvbFdZY3Eok9dawHTyDlzXH4MIqVFBbJZbM2dbihoD1XjwUVq8AMOTNz0qQWC2VkrOwN9a/SWvNEByKzum0sBTJC7FWZsbGG4bRrz4ICN6vK5Fi/pT4GJm/99rIurN8CKH5UjzuteNeO+61g99w7ZBe3Qc73ugbFG0MaNCq0pXCYNh4AdJYY3BGMedMaZXWV5SwWimJXaGMBGuBOPJ7axg0JRda1eRSGUojDAYzSgvPb6xQBBfPWbA8Ot+gu+IYC2XOJBRWW6ZpAdMJtaCUYj4l9icxJ6lWaV3aaq024rJws99zvh05nqBqzSlnTtMsc+uUeH6zp2TJXNifTkwlUWtljok5JcpKSFRai6mtN5apcDwsXN9MaAWXm8B2s2GzGRiCJ2hNahXbIS0TdVmY50X27peFU1x4/+U1082JcbOB
1amvjCGVxJJnait87cff4Z0zS6MRzy64mrfEvJBrZp6PGGNpgPeW4B3OeHJd5/s3E84uOGc4P3vA1m4591tGN5J7I7VCyQnVDTQx47XcqLnTsowaaOtqXe10nSm1kHMmRhGyWiodQ2uN5RbdDbjRogeHC57a61oly7qhpM56aZs7S9iOhMHgF0PMFadkTotSKKPxw4AzDlbk8+ADwXksmrqul6Yl03MiK02hS7Ks82hrscah9br1oaS6c0b8AlVgEwRnMN7h1lRUZTreQK+K0sXYVnonVWF1NC3hcRgZDXQl1dccEzELAtwYcFZTumygoBXGyetklf5l1U8jV3neVFfo2lliek1X/v/z41477rXjXjtej3bwIeqaN/oGRQHaGrSV4C9vNLsxMA4e7+16wVWWkmmqritnskPeWsMiiYytNlItEnxk5e6uAq1AXBqojguOograwebSo3VnpwwPLs7wSmOvI8fjgeMcOa6d75gzwxgwRnM6RI7HhZwzfs3+sMHKnNYbWk2c5hMoxTFFbub5bo55XBYOMaKV3HvmFEk1kUrlsCzsp4WSC701tsHx6FyDcsSauT4sAvjRmt3oGQfPEDxGKWot5JwoaSHPM8tpYTpMElpVKkupzIeFHBN9M5JSopRMzgvzPFNL4bd99CNczs9J9SV23DKGjPI7btTAUh05J6kslcCWrHeEENBZHPWSO1LIuWCVYesHdn4k2CACqDRNO5y2aGVWdLJaRaYLACtXegdjOtqtDICUOB1PTPNCMAbVxVNQcmWJhVwqxujVeW6pVRJaWxOHfF8ZFnr9/jh6gjMEb8i9EbzDWxFP65zkgqCwWuOdwYdbWJdcqCVWWmlrEqnMvI13eO+lCkVm/K11Wc+sRQiPvZNipXXY7LZC69QCAdNGo+2aItsrc5pRRzjNC0tOnJ+fSRpv8AJSMsiWieqS3qoVqReSyjQL2oNuYLysLFptpdJUCu0Mpkgei+qA6uT05npQ7rXjXjvuteN1accHv07f6BsU6RTJ7BelsNoweskfkDtUIf4VJSefNbc74wgMp0giZs2VUgS41L1cCFpp0I2qIZaGjgU9aIZgMBvLqDUBg9sNqFpxg0XpTk6VeUkSirXxWC8nQykQU6PUjrFQlYLBobcBPw5YY5lbp6ZErHI3qxGXd+kSFCaAH8XovbzItZNT5HQ6UrJsVjgUvTZ6raRUmKNsDTij2Q6O7SbggqX2ymk5cVikjVuXyM1h5vr6SAOWmDkeFsokO/loWYPM80JZZuo88/bmnLdDo2XPbM5xbBiWBe/3OD1wvUSWspCqRIobQFuLdoZeFN4FgvXiql//d9COYAPeOBQG1RVBWZy2GH2bXQG0Tk2FVCsxVwwZ18Vod0tenKZEjpkwGHptxCXRcmaZCvMkF/S6N4dxBm0lur33tbW6nl/KCjDKW4e1mTonvLF468itoboWcBdCJDXKrpwFiSOvuZKWQo4VrxCnPVLFW2PpOZNLpaWGXVc2y7ohUGvjOEWmKcrc3Ml/o62Y0bTRlFSY5sjNcQKtqbQ1RM4wjOEOyOXHAbsRw13MjaLVuiIrdFNcp8RCbTJqEJy7oiBrk9oadKmYDlopyhu9Zwz32nGvHffa8Xq044Meb/QNijYa4+UJpzSc83jnZGddAVoIialXaclaK2apKi1NigQ3tdIpsaJNQ3URosEP+MEJwKlnYqlwSgg3SELBkobFgEGRvaJahTLrNNoorFuxxWvYVa9S1XQUJngYPHoc0JsNtTaW3lnyIgYlZ/Hayk79skiSp7MMwTMGj1IaZwwxJ/aHE7Vk4ThojdMyDy85r1AdCZS62G3YbTbyPNBYYuQ0n8SRHhPHU2SJlaI7+ThxvJnIseCUxlqNpqOKzEl7TDw829ILLGakdqkmZwJjs5ybyvNy4sXNFaflRG/CMQhjkPm9UnQlYKnL7RnBBamKeqeVQiuCuDYdnNIYkK0EIwhyYwElVdCZrjzdaM6cQofAMzrP
CrTaaRVYTYFxSvRSKEsmR3HW91qlPb+ulzZ1y9CUoyOikFOjF3kT6U2hkLax6lr2+1ecc1xuo88lKI7aKDkzL4twEKwmq0Luq+HSGCGZ1k5JVc6/3Cipkoq8LjeHg2xSaBEmrWUTwTpPygL7qlWsl857dErkWtBeOB+tN1KJNFUx3mCcxlhNaoXOGg7SJVF3PkWSykxnkaAsSyrULtWT7grXLLYBXVZ039TjXjvuteNeO16PdtzmVn2Q482+QbEW4ywocMYw+kDw7m6nu5TCqSZKr+KEpq9rUp1WO1WJqNTcKKWiq7QSlVFCT7QWrRuqy2phyYILNkaRaqJ6xTUaby2zLvRg0IPFtIYdhW4oREqzorVl390oJfTD4GhK0Y2EYmmlmEsklshgvMwREYCUpIzCYB3eWwY/oJUil8bVeGI0gVYrF7stwQWUtdQmDv7WxeDnrZygaE1VkFoh1yqu/1KYSyFlEZleBIvcloJ3jsE7Nt5Bq8TTkWnO+AcX5DDQ1UhpnZS1tAGLxdjKx7Yjn38Wee80QS3sNiMheDSSclqq7NWHYZSdfxRxWZinE603tBVTnCoNXUUQBqPWHBOD7Ym3/ZHdpmNqhFTQcc/Hd5dcPTzjp51DNUFNzwv4bcIqaKVhlFS/NWVakfCv3oG+LoAq4TK02sipsr+eubqZ2R9nOoa6rqiuGibo9y7hWyDsgFKkpZlT5DRNzKcFta7yVQkeodVKL41Woda+/huFvGTSNBGXhZc3N+TSsNrIut9qymwVUmySpaEMwXq2LjCXuvI6oMokmTkvpJZWcFJDWYFhqVVc45Q5vJo4XC8oazieLVTnmeZIyoWmVgCUN+jSKbl/qEroS+2414577bjXjtekHV8uNyhKawlzWsEyt3NSs679yYkvs0PW2OdeVkexkvUzXdfv10ajU4288K1UWq24IG5puZst5CJR3i5onHHs+yJ3sbbTQscMBtXkLr3WIi3bBMsxkmOW1FCj8d5LcifiZtdGdueXHDmmicUUOgJcyrlSlshcG4eN52zQUDtD8FitOd9uOdUTfgi88+Ahb10+xDnN8y4ik+rqZq9iVpJUcklyrUqRWyd3iF2Yf0ppci6UJALrrESib4eAQ8kGQC74YaS7LaUHWk4yM+6N5h2qei7thv/9Kz7CZ9+7prfOEEQYa6lM88x0nCm50pG5/THOPLt6nw5cVIFQOWtpNZPTQo0LxmiCH9jpxrme8CrTuycqTa4TtiXCq2d8fLvhYvC8qI15hlOt7FKRFnGX9qc41ou8LrC25sUZr7Ss9pVYWU6Rm/3Cq0PiZiqMg5E2cW5kGijhHZRWiTGLQNWGWqvenJLgpZdGaB2tG013ehOhQRlKacRYKLWznBZOhxPz6ci8LLzan6ilY5QRB7zoHzEVcm7ELAIXtMUbS7IWo2AYBknUzVmi44180BWtVHKrxCUzGsV0SFxfRQ43EecdyyFjh06OlVI62olR0QK1FXKraPXmyse9dtxrx712vC7t+ODetTdXYZBqppaCUhCCYbfxBCe76CAX1ZIyWZjVUDq9ra57Bap2WlnnycjsTNBJCnqnV0Wta0KltpLUmSrKaJyWaHKcrK2pAHbQhGAl0wNDWQrHeSJPhePNkRwLRhu8lTXBOWeOMZJyIimNpTPnxD6eaEqxhAWL5TRFltPM0BX762schaM/Mg4bplg4nibmacZuRwwNo8BrQzDubs2utn4nOFsrsJ5Cw9rI0hOpdUlrbQ2rZYaoV+aDs4bgLLthZOsDc5WLUjWxYytlqE1ajS1HqJqI5YTjrXHk/3j3Ie+9/5JSCjkmpn7icH3D4eqAxaCVpjW4WSYOaWFpmXepnNXMEAIxT0zLRI4Ltmu2OvMwTCilmeo5qVR6q9AG9JpxEerMg22gVvn7Y63kXAhWzGxtrThKqaSYAIXWFqUk4Iv1zSYtQo28OU3sTwspVaxfKZelyCqoUgJjao1pmvHOUrPMs1GKlCPTvDDPhaFp
tFZ074Rg2iVVdEmZU0zk2Jlr57BMLPFEWhLLlDHGymihyepgzoWcMrV0cpEguriuB1Y643bL5YNLHl08ZEmRFDvT9YxTB7y29NxoucBcyL2zv5rYX5/IS8UoRYmJRKOmKt6Huib/alDWUHuhf4jY9C+141477rXjXjtel3Z88Ov0jb5Bab1TSpEUSwUb63BrixatybUyx0yplVqEHIlS0jKtiNHp9tlSYirTWhC/+vZzxGSldKdpA6rinMCSgpWZZiqJlBNNa7rv2G7xwQONnAqn08JpiuRSCUqgRq13ci4cjhO6FebeCVZW/EqMpA5lylhlWWIjzwmjGvGYuS6RK20JPlCq4tV+ZpplLe58HHlyfkYsiZIlzVIMXYLzboCiobRCGStgp6aoFeZcZTffG8ExA1qBd5Zx8ARvpUWYMl2pVeAj3XoMUHpBMEUdow25Nqaq+Np3H3Jz/BifPp242Z/QnPj8++/z/vtXPDo/Q9NoTRNT4mY60kvBaEVumW3dkJYTx+lAKZkz73l6pvDGc6yOKVdavj0PLF5pZg2b9D7vnm9kLpsyc62UCroLHRIt66I5VbKXdr01IjCqCxq71pUWmTrLMRFTJZXG2KH3NVirNYidFBOxyHzar5VUXx8nq4sCuMrIudlTEVR8rbQiGStxpVLW2kjzQjxFTsvCdIicnTlaE45A642c8lo1NVKRSr72Kv4CaznfnbMdzhiGHVp7tn5hE7YEE1BNUWIhLYmyKTQFORZyriJ8FZYkbfQpFRoNccOJN8K0TrRN3mTf0ONeO+614147Xo92WPtlEhaIUrQmoUqjdfg1U8AY2R+vRe5iS1rd5OsLLzHhSEXU5C5WKTEdmXWuLvPDQhgkoaJRccHjncGskeolNTHbKUewmmoLbgsGy24zgmoc9wvl5USKWXbq19kovVFjogC9Rmot9GAx1mB6RzeoJdF6pTWN1ZpgYGM1Z87QemeZJ45z43ATmZbCrCqDKlx4ydN48eIFx+OCQoTXey8i1iQnIuZKabLGl3vHeU8wELym58rUGxqDM4YQnKxeAqU3lpxJteJMvbsoS5ppWYLBanAYp8jKo3B802/+aurP/Dzvvboh1shnPvseh1cnzsdRzHhaSzT4PHN0jvevX1F65axE0jxxnBOPH3j+P+8+YvSKxWyYU6dXWcFTXdO7RRmpbOJ8YjDr3j2scCahgaIU2gggS8xwHdU7Xd/CneRcqeubU1lR1aUIDbK1TqqVJSVia1jrZCsgJ2pZg7GETkRrlZQr8xQpqdCMpfc1xC1lqWZKlRTYIj+7o1jmRJwLp+PCtF/YDJvVxyDBZykW5lOkFvmdayzULJWd1obgAoOXULRWO0Zb6B29Vv7TUjmeMrttkje31kSAuzxX0xJRGpaUJfZdLBYE70B1FpeRd+s39LjXjnvtuNeO16Idk/syAbUNg6XT0WisNrCSIWWdTESmlSrGqlucNNKqZRWZ20Mp+V6tFZ0NxTaUkVlqMo1uFHptBxvTKK0LFEdZjNI00ykuS3XkB87GEUUFdeTV8wOrBV/Igus+e81JEjx7oeVENR1jNSF4airk1iTQqbNSBzXbMXC+lTyN61PkZpJWYaXiTWMwjf2r51z3ymd+6SX70yLPVXAE7zHWcEoLtSWmnIk101eR0Ubz9OGGwSpUruy1pqWONuBGh/KWSCOSOeaJKSUuwwZ64RaXPR0PqFbw3oH2dOcwJrAzlW/6f381/+bHf4affXXk1Ys9dS7EFDmeTlTglBYoFVUq8TRzVSvT8URPC+eu8lufPsVaiHrHfsr0paCaJHBKdaslrwSN0YpestAua6OUhlUKpwxegVNShbT1Q9/O1ZsQJiVErlKSQJi0ku0GmrTla62kXChdSb5EqqSYyUlWOnttKysis8TCNCdJhK0daLRcxWC5mid7E9KkswqnNaZCnjOnw8J0zORLITy2lT8Rl8h0XOTib52yFNIkuR/oTslinmsV2Txp8nNKKuLy
T5VUoVRWoRJh7OtFUUpZ34g7pivUKprGKAwaYxQ1vbkk2XvtuNeOe+14Tdqhv0zWjM8uLb0pehE0cAPE2y1PakpiLBMMziowXd21ZJW6ZcbI/+9NWpmlVlhzRiY1s+SCCQZvFc4otAVjpdryxjK4gaIbrRh615LlMTiJZ29Vfq/WEd6QAq0oSLtNKYk4r1pRumCxtdH4EGg9U6pwGox3+KBwweEHMcnp2mComNrwg+LxLvDu43MuguXzr17xan/kFBOli5veOYv1ltIKS67kWwKghUwjDJYHF1u86hwPE1prquq4weM3ge4N3SjM4BjZMKfIRZPep9aahvAiSCeWq4bpl/TdQC2R4gpvXTzmG776K/iFzz2n5rbisRcO055UO3OtXIwbHgwjWmmWaaHFzEcuAl+zU6i2ENsZSwStDN1Y5umKkk+gLX70WG0RRKgVBHatgltuDVU6KjfZxe8SA99WE6AygqRuTVDXra4rdE3i0UVk1N2WR06ZZU5059Y3rEbvDUXHGVlzrOusd14S0xSpK3K7d1BB01KTddXY6BWU0ZxvJSNjQDPvZ/ZXJ6Y5QW+UkuQjJ+Iyk1LCGoVRMJ0Wrq4OPD6c0FazWZK0XmOWCi5VcsqkJa90x0qrhVrEDFqLdAP6nYdCzlfd5U27rW/YUkTKdZLe4DTje+2414577XhN2nHLo/kAxwfvtQA/9EM/xDd+4zdydnbGkydP+I7v+A5++qd/+osesywL3/M938OjR4/Y7XZ853d+J++9994XPeYzn/kM3/7t385ms+HJkyf82T/7ZyW460Me5w8CFw8C252szxlt1mpGUUpjyYVYCtKJXS1sWq3/W1bwtBEoU7+ds/aOMlpWvKbEtJ/JU6QtkTzP3NwcORxnYhLXuV0BO711uuoiGmumRkVW8HLJ1FbQipVAqKlK042iGmjWUo0jdsXSFblrlHFo52hagFhdQcOQNGQ0U22caCSnMDvH5sHI5tEZw8Mz7PkGOwaqkvRJrWUPrivoptIVwlkYHMPGMW4dxmuGwaKtAqNX5kPDG8V2CJIAag1NNZSFcWuZ6kRKE8p6lPEoZdDA8uo5L3/uJzi++Cz59JJ6eEE7XtHmI1/5cMNveutSVjO1sAOWWohlxtF4vAk8Gge23rEz8FvfecDXPR4YegYdKN0Dhlwah8Mr0uEV+bBH1Uyn0a1Gt5nWIZZOr5mGltXHJZOWRJoXesmkKCjwUotsaazt5tYqK3YSA2ycZzM4nF3fmDq0/IX1QL1+sdSC1ZrByGpozY2yVPKUmedMzZIvUjv0qmil0XInLtKutc4wBM/FbsR0mPczL18cmacIvVNLlnn1kgQkVQvWadCKq8PCi1dHrq/3nE4L85LEJLdkcl6roba2nIvkYSgU8xw5TJnjlOhFzhGarC3WJrPyUiutQi4NKqiuya2SPsQ1e68d99pxrx332nGrHR/0+FAdlH/9r/813/M938M3fuM3UkrhL/yFv8Af+AN/gJ/4iZ9gu90C8Gf+zJ/hn/7Tf8o//If/kIuLCz7xiU/wx/7YH+Pf/bt/B0hy57d/+7fz9ttv8+///b/n85//PH/8j/9xnHP8tb/21z7Mr8N4vkFXReoVWzUGJS5qoPZObHXNloB+21pSAIq+3t190d2cgrANjGcDNVWmGNG94bIBr2SlLRfZf984+Zl1IfhG7Y0YI72BUQ2XHDFGTjd7luNEvw1dcgZv3R2fAGMxxhBLoSqp3kqpKGH50LQiiZ0coxpnzbDRhqxgsZbiC9YbgjX47UD3gaU3Dk2xrEAtoyRLoq+5DMpr/Mbi7Zr4WSvBa5pV1NWcNy8Z3RXnm8B2FzBe01Wjm07pla5h7oV9WnhqNblJm7M1aKVT54Xj8/c4O3fENLMbH6DdBj9s+Lqv+RqevzpRckY7Q9WdZBTOWEYDtWaMVvxvjzb8b2eNMh9ZcqJoEb9aokTcH69I18+pKNxuJyuiObFcv0eerrg+nmROjySInqaF
kxMnPrkwTRMpRnKWaqJVRa919RJ0rDZcjgPvXG559/E5z5+/oqyzVmBd+bQYA60Xem8E5/D2CxsQMSZOx8iyJFBdTGPdSLx8KszTQkkZSiN4LxAla5iOMy/TgU9/7jlxjuQ5EY+JeFo4TidevdqznCJrwhlLyhxPC9N+gq5ZpkiaM3ko0CV7hKpIpZJrW1NXNUuulCkyxzW3AyQXJYuI1AYoJd/rnaANVpk7DPe9dtxrx7123GvHh9WOD3p8qBuUH/mRH/miz3/4h3+YJ0+e8KlPfYrf/bt/Nzc3N/ydv/N3+Af/4B/w+37f7wPg7/7dv8vXfu3X8h/+w3/gm7/5m/ln/+yf8RM/8RP8i3/xL3j69Clf//Vfz1/+y3+ZP//n/zx/6S/9Jbz3H/j3UdbTW8EazWDcWgUJhS+2wpQzpYnzWqnb5EqZFYpLut99KKWxwRPOBtzW0Hq9E5WyRIoDa2UNsKVGWiKLa8xjxQ8aHxwpCVmw58zgPSkWlmOkpoJWInLeaqxd53KlY9EYpfE+MFdxe9dW0EYyIuJcWaZIrQmaIfVAVlI5CaFH0WqlakVWlWOt5Bh5tl/Yz3KCWbNWfh2JDh8DYRckXKqIaI4bTysVqxT7nJmXgjeGszEQvKFSKC2jGrRW6HSyhmOaeLQc8faSYh19GFEXF+T5wNXzFzx59zH+wUPMg3dQwyV9OOfrvvYRX/tVXynudKWZ4kJtlWWZyPORXBq1w7uXAbX/LPv3/gepdNQjafn+j5/6Hyy5YWpkcJbt4yfrKl1DtYTrmdI1L65OWGcZdeMCx4W37Jyh5ErskKeJkhZqnMgG6NKGNM6hOmxdYGyVfDbw+NEZYXDMpdGVIMe90+y2nnGwaKMJRVJCjZbnqMTKNJ04HBbmOWGt5HG0rlFWUWrldJpYloRWspLZkXHBclr43NWB51cHqeAOicP1wuls4Xg8cbiStb5cG6XLR6udslSSK8RTJi2FNGdZxSydtGROSyLlInkhSlFLp3WZMd9y11qH3NpKn5Q35tYLqgccGlWlfd0/xL7gvXbca8e9dtxrx612fNDj/5EH5ebmBoCHDx8C8KlPfYqcM9/6rd9695jf/Jt/Mx/72Mf45Cc/yTd/8zfzyU9+kt/2234bT58+vXvMH/yDf5Dv/u7v5r/9t//Gb//tv/1X/JwYIzHGu8/3+z0AuWkoYJu0XyvSWtIN5iUJ6KivuRTr/E9rqWago7SltkIvyKx2cAwXI+OZ3MnmKbOsmQzKQPW37D9kDfGUmA4Tw+jYno3CR2id1BqLXUixMh0jvXRp+Vl7N490RnPuPW+NI8YaDjGSJmnX9SoJkCjFMkk6aG2VwQ134U5drSdkF2aD7pBK4hBlJ/3ZyyPTkljX8rG3xiTVcd4Sgsc5R12rRq1Z7+DFnZ5Kx2jN4APOG7RRaNWhi3EwCR+Z2BT7l7/E9kJSNA9xIR8PpBhJMTIdDrzzsXfBbVBuoKPQvWB8QG3PwO7wKWLqCRVf0fsFGEfrMN28oLiRvCT2r14yhkdcXUU+/VP/neN+4uGjc5688xiXC14ZrHPocqTWRKye0zFzPg5oa3jHKx5djpwFy81+4pQSJU2QZkwOkDqtSQqoMetrpBrVIM+Xt2x3nmPKpFKZc6aqwhA0m42n0/HFiMGOSq+FVArLMrM/Hkkp83Cz4cFlIKFZuhaT32Q4nCaMUXjtxHPQu+SZTJF5FuNcy43TzcR+d+D6es/xemKZMylXYiqS0VEKaU4Y75jnRa6bJWGUIk0zp+PEFJOszrZGU5IBorpet0O+sMbakDj1TdB3vowhGMbg6KXLPPlDzJLvteNeO+614147brXjgx6/6huU1hrf+73fy+/8nb+Tr/u6rwPg2bNneO+5vLz8osc+ffqUZ8+e3T3mlwvM7fdvv/d/d/zQD/0QP/ADP/Arvn48RXQsnOHEHWwNWgurYFkSx9NEyQW9XmCt93U/X63uGwXaoMVijNt4wm7EDmJCcs5Q
lKLnRjpFSsygkZCwJneG9a7zu8avd5m/zijSXFjmRK0N1cCsra3WwKK4HDwXzoKz5FLxDclBqGAQzELJMuuUzmGjlMKSIr1ZObkK0DUdRa6Vq9ORly/3TPuJjfXsdh5tFLtRxEJphTEKZwT4pJWS3Ael2LpAsBqrlURuF9k40FZhrcFqma9Li7GSW2ZShR7OePHsc7jtQ4x3nErBDIGH5xsePnkkVai2Ioy9Ulul2R1t85Q+PMDmE/36M6AMzTppK7aG0hbcht2Tj3FcpHX82U9/jpobm3Fkd7bDbM4I54/ofosdNgzphmI9r1rnchvQprPZjjw491ycDSiZNmOOiZYrfklsasHVjEbMd0Y1DBZUo6ZIawnn4PGDkf0UmfYLh5yIqqGsQltFbV2eX6XpulF7JpbIzWnP9fGA0rDbWS4uA0vX5LmxpMixNvI8Y2mC7F63NWpfZ7+9Y4xGa41qkJfEcT9xOk7ElIi5EqsAwFpX7KeI9hZypcREnidSbRyu91zfnJgnMbd1oPSO17IIq1fGhlJ6vSwU4+jYjoZgpYLfjCPnG8thv1C7YM5/Nce9dtxrx712fHlrxwc9ftU3KN/zPd/Df/2v/5V/+2//7a/2n/jAx/d///fzfd/3fXef7/d7PvrRj9KywXaZPwZrsEojq4CdZXVAl1zwfv0zuywWrteUgGN6o+qO8oYwSBaHXoOThk2gzJk0x3UO19FG0dVqiNNdsjeqxHffmn9KLmIUWvoaFFVx3BqiACUvrtOahqZmMRapCq7CYCWyesqVmybplM5qIUimRDt2qrLMVZDcKEVVAsjppXI4RVTTPN5uGI3EwD863zBuA9oqFNKaNihSKcScZdXPGgxSRdZYSaUSKZK9sFZLNCipkOdIWiLH1nl394BHFw+4XmbiPGPHLRePH/D46QXbbQBl0Gagd0VthZYX0Dswo4jjdIWb3qe1RFeOjuRNaDeC3WAu3mL7Dnz605/jZr37H70hpYT2I81v8dsdo1P06z29GFKKXOwCsS5cnAXefrxjGDxTyfRJUX2X2f0ScTUzNEE8996gS4W3REVulVwjfoDLBwPhpaZcV14eD+zjjlgjAUvTHUwH1cgtklNiWiZeHq65vjlRW2O7cTw8H9mnymGZSPMJPQQej4ZiRkprzElmvNZogtWMzlJ0xxtD0JqeEikmllQxSnM2GHaDZLqcbxyWRo4LJS3EZWKyihQTVzc37I8TNVVpjzcZW+yC5dLLlklsRc4nOsbAxZnn6YOB0asVE+7ZWM3NURJm26+yg3KvHffaca8dX87a8cELm1/VDconPvEJ/sk/+Sf8m3/zb/iKr/iKu6+//fbbpJS4vr7+okrovffe4+233757zH/8j//xi/69W6f+7WP+5yOEQAjhV3y9p04rDTtqgnU4I6FYJWfmORKjgGJqVRhj5L9pHZSm02hV5qHaWkxweGewteOKopcm/IDBUnKh5CJzRqVRWqOb3EkqLXAbmhiEYJ3ZrvvlLWZUaxJQpQ1Ka5zTbAZPqY1Xp4VjyhziwjIvWBob51b+gjjES+uYJsJwSoVjKRRlQa10RpoknnZDnBf2h4WWCo+2Gx5vNzhneHC+JQwOtPzeuRRyK8zLwhKT3HmrTmqVJWZyqtDFdW+cWZMzxRiZcyEuieU4s5TC9dUr3j674MmjxyzjwGIN55cjm52ht4T1I10ZOTFLpuaIGSW0ijRRbp7R4hFlrCDFKaAcDUNTHjWcsXtiuTwuZPUcO44Y1QjeEecj5xp0jaSrF9Tphmg9yiSwjmYWbDAYrykOklKwDfjdQE2Jk87U3tCqoxAXW0uNqmSWHelk1bGDZdhYrHd3b2IVKC2TahJTIhWDJqaZ1hrHmz37lzccbmZUNzijGI1CB0sZHFvbeDQqLreBnCyvTpGSM0ZpBu8YreZ8sCStGK3G9YpKEdMyZ65z+XjHo40neE2w6+9mJCPl0hlcKdRppsWITZEdnRvdyVUi2lWHc6f5
yC7gjeLVaaGVglaKwRieDpaPX4wEr7He44yhxCwrqqVLzPq9dtxrx7123GvHh9SOD3p8qBuU3jt/+k//af7RP/pH/Kt/9a/4qq/6qi/6/jd8wzfgnONHf/RH+c7v/E4Afvqnf5rPfOYzfMu3fAsA3/It38Jf/at/lefPn/PkyRMA/vk//+ecn5/zW37Lb/kwvw5lzmycQTWh3f1f7P1ZrGVddtcL/saYc6619t6niYivzdbGzsQ4E7hlTIEThC6qArsQlFBhJFRXAiOherAMD5gHhISQQKIpeLB4APNUXOrBQgKJunVpykIIqFvge2kK3+trbDAydtqZ+fURcc7Ze6+1ZjPqYcy940uaIj4aR4byLOnLiIw4sZu11vyvOcb4NykGVJWluVuhmRvntGbn+XHDd7oqQmnWpYMQuplOWyrz7P92iBG5mGgG9XallUIScUKViLPcNfrritsfuxOyQqosLKjhbGkFOtkum5FL5ma/Z18qT4+ec9Hq4lkbaowtcrcYx9yzK0RYs2FLw4JSpKHRGASiKK1USivcPp057jPrfuaogXlKxKFHnkcHwCbGkldyM27u9tzuZ1rJzCKUOfPekz1lLWzTwMV2ZEiJmBwgWy1nD4hqSl4aP/bTP8/1lPj4JxMat1y/+RobZrSsVB/S01AaCaVSbaA1Ra1i+1vq3bsUWxlUkVZ7q7K4/4AERBMxNt74+Kf4zOcrKez4yr/4cXbXF+ze+BhhSjx9721WOTJuLnh/XnhK4U4qR4RsjTmoJ3SmwHS1483uoDhNIzJGpM/ba5d8NqH7SzSqunvkkCLhJDWtwlALoVWkZFSck6Gt+iLMlc2yclkrb04JmSJvbCKXIbALsLkceLAbee1iw24c2OvKzdEJkCIwROXVqy1RKhYCb15PXG8CU4BXtgObjz1gEyPX08A0BEJ0l8nSGnMtpEEY2kLIlVgzEjNlUm5T4B1TrDqxbTtGLpJwk31eLCK0piSMq6C8MY2Mk6sDTGBfKlJxq+2PsD+5x4577LjHjnvsOGPHcx4faYPyfd/3ffzQD/0Q/91/999xeXl5nvteX1+z2Wy4vr7m9/ye38P3f//38+jRI66urvh9v+/38YUvfIHv+I7vAOA7v/M7+dznPsfv/J2/kz/9p/80b731Fn/4D/9hvu/7vu/fWen8/zvmY2bXQJMz7IO6PntZM3POXnFEN2JyE2DO7cbaTXWcUWyEHJzQhs+hRdxAaNptsRCxMFMP7laoNE94jAGJStpMpCn1uay/3vG4eqolvrsWrPsbGHc58+7+yKaO3B0XbuZCcd4/4+C22as1lirkquRSiQJzNU/NlEgu0MpKG4wxedJkLY1lgXmurHPlMUfGIZCGQMMorRG7Z0Npmf1ceO/JgQ+eHJiiL5B5v/D+kz2YW15Po3sYIPhCDIEYI6IecT5X4539nn/xc+/xiY9/M0MMPodtC3ndM8RErY3SQNKIBiAKVUdSLeT5CUEjqomCEprP6d0yy1umUgu1FIZp5Bd/7hezv8vo/JQ4XbB55Q0YIg9e3dFq5cu58FSFJ1Z50hp3rXKkMlshhIGGMOnA6w+vkDVznRKbwbkEYoZa/1WjEwQbjFq4iIGroFwGiCJoqWxa5ZLKFs7dBDVIeKrrLiR2D6/4VEqICFeXO+IQPXsuRqZpYBN9jktZKXlGBTT47Pj6YmQbK+M08Oh6yy7AEIxwMfFoN5KSMqbA0OfM4FLc3TmQbSVaBoPNoJQpMEagx6iX1si1sYhyKO6x4FJGI5gxipCasSEwmFJxAqQU50V8lA7KPXbcY8c9dtxjxwk7nvf4SBuUH/zBHwTg1//6X/9Vf/4X/+Jf5Hf/7t8NwA/8wA+gqnz3d383y7LwXd/1Xfz5P//nzz8bQuCv//W/zvd+7/fyhS98gd1ux/d8z/fwx/7YH/soHwWAVgrE6F+4OSgIwlqap1SqMoyDOyDmDLUxpADiCZJz8SpJBY/GlpmyVhS/yGkckCBM
2xETYcFYDguqkXGT2GwHhs0EKcLQW5k9jOl4rOe8DjWvwiQITRrLnDmWwjCsLkVUJU4DMQ0QhKzJrZeh6949eXJdjDgYYXBrY6se212lIdFbqOvaXAoG5Gocl+L5EiVTMMYYCCYe155XHj+95fHNnutNgKI8fnxkf7OCiMfCDxEUqp3cK8UJhSlgQWniyoOf/Pm3+FW/9IbdZUAoWLnD6sqxVJbbG6QlhqvIEBTTBGnESgEKbB5RNdEwWj4i5Q7FKMsByoy1TFSl1MwgkIcNV5/6JoaLLduLgbA+RdeFDxAOGrAhsqjHpB9ro2A0q4SepxGbMQVl2k5chsg2ueU4zT0wTNwyGhUGjE2KbFpgEWFjOBBVIxZji3CtsVsXgbWKCkSFi6vEK9MV7XWv6Fozz9xoRusjBVsXZtbuZ+CSVRXPbEkCu2HHZkpMSSglgzU0RiwEihh0Z8ZQe45JKbSSuxQwYxoQBO1eBKXU3tL3h26ubi51mYSP7QZqUzRE3tiNXKdIqKDFkFYIgCzFPa6Fj+QIeY8d99hxjx332HHCjuc9PvKI5z90TNPEn/tzf44/9+f+3L/3Z77hG76Bv/k3/+ZHeet/5xFUUMxviG6pW5sxd7222zxH0pTAPExJBEIMTuhZs7szpuh2181Yl+yyviDUWmk2EMdECBDHSKm+Px8uN1w9uCCNA0tphN2GGILnHdiBXA+0KrRiBHAme3Cjn9IaazaOnawUkxBjZNpu0KDeEnVBOfRKtZmHPIXYmNQ/nxQjt4wCKh6pXUul1kqS5hWJKkGDz8qH6MFOrQEe0DXPe9Z1pgwjh2Lc3s2UpTLEyDCNEGAtmXlZ0KDdoMuDr0ybf3ZtPN0f+fmv/Cy/7GLwqPHDHRonxosdN3dP+Ykf/XE+/yu+nVdffQMT9XyH+UBdZkQTQqCUFckZjrfewi2zh4hhvvAEYPXgNduwoZBun9AO7/JuhffGR36PrgUrPvvPS6aWjJRCqh6UFoBJI9sY2WoiaXCba4UWjGiGmgDGAGQDKYYujby6lTO1kteK5UYSz3ORnrGCGUkUGcAGQ9kA1h8+4jP8UnolbuRSWJaVnAu1DmgSTF2JESUQQgRRMkLDAVDM36fWRlbPQ/HMjBWqyzlNAybObmglc5gzx7WT3PDRxBgj11Pkld3AGxeJWlYIgYsxcT0lomp3yWwuJSzuDlmtnUr+5zruseMeO+6x4x47ztjxnMdLncUjeABRVIHmznsVZ8RXjAqdMe85FKU5eS2kQKyNIUV3gNyM6Bix1s7kOFBsyS7zmzOouJGS+BvHGN3WGSA6eDRtIIFswvG40tZudpTieQZpIpRWOwhCihHEKxonmwlLbkgtLOvqc1twp0tzqdg4RNx6G6I6eLn6XsE8A0FEiSEwpeQkqBDd9rpWai2sbWG/HJkXz2pY5oW14RkRrTEMA2NMiNIjvzNrjoiIh0G1SlRzt8y+ON9+/ynf9OhLrPs7wpTYXgpvv7Xyz3/2Ce+/85hHr/00V2OAYUO4KOTlFlkPFFEOd3e8/+47XI7wyoMtbjpUoPbQLg3+oLDGEDKECvMdcTPwND3gizcfYMsdMSVizqRcCesCSyYcVy4bPJCEqhDEK5BRAzFEokRUg8+tzWWa1qrLFWtFSqHMK8th5bh423POhdvDwnqYsYsVSSOYPxBUA2jArPnriI8EDHzeC3hAmSESiAaFhXjcu423eSt0iMo4BIYhOSFTBfV4FXpCFyr+oF1z8cyW6p+5tzgcYESp2RUXa6m4GtC5DVMUXrmYuNyOTmKs/r1DDGySm5RRqzuVngPSHKTsozgufY0d99hxjx332PHisON5j5d6g+ItMQH1i1jMKK1yLJmioEOk4G3LIOKg0BVdGgIhqjv7BWHsIIS4tKqZUbNRSsOkeHXSKyVRYV4KeljRVN3XQBY204hKYNmv1N7OUhGGGJw016PcS98FC0D16uwQfJerKTCGgNVMLYUQI6nPIWlega1zdhCZ
Iil5hoiI0ErtQU3OKHcDJWUzDIhCLsZxybRWOeQjj28P5GV14AHIwnJXabkho5BiYDsmphhQgVrKeedNa0hrRGsMAWYa/+hf/ivevEpcHT5g3OzIC/zf/p//lJ+djXwofPZbP8Wy/4Apvua79LJQDk9YSuUrX/oyP/FjP8ErDy/51l/8jbz+6ALiiC1HqNmvszWW/S1wQ6qFJpEjW774dGU5NLbjwtZ7yuQmbEvjrbsFu1kZV+NSuhS0g4yaE9Y4GVi1htuHejJnrbXHphesVALCq5vEmxeefbIcC/PNkXlzpA3uJKkhEWOk4mFurZ5ezyV4qoGUlDQITcQ9HkJAx4UnRXlvuaPUSkzKNCY2gzIMibU6Qz4GzywxOnCFwJKbJ53WwhSUqBFpDcX6fDlwMJir5+d4qz2CVK/uA0wpYomu/HCyYYxKCME3I11m2zh1Q06mZS/ncY8d99hxjx0vDjue93ipNyjJxL9AR8rWGmvOVKtoNweC5rt/vwWcRW8QhkgUn8O24r9K9wvQEKi5ntMpRQSrho5+YnOpPL3Zs6wr42boFtlC2S+oBpbDguWKNCOqeMszRpoIuTgD3sznjrUaRsX2C7kaOiYYEkEacUxeoYlbUrfcvNZRz60IpggRETwWfO0Jm2ZsYuDBbuT6YmIanKh2d1zY50aRymGZOewPhFy5UCWYcXO3sBxWrIm/vsJuCGxjICG04jbHNRdqzlh10l9omaSwVOGf/cxX+HWffoUaEyFUYhx5bbNytya++NM/zWfevOTijW8iDhO2N+pyYF1XvvKlt/mf/uef5r195tHf/5/5b37D5/mlv/RbEBWolXndoxg3771LvvsAE+PRN30bd3PjuN+zrcZQjGQrwRqXDXZZuNmvvPf0yLJf4EEjaPBziXjyKF4ReFS6K1NEPGul9orAWiOgPNxs+CWvP+DRmDjkyuUUWefC/nZPnQomkIaRVN20yKxSSwbz+XAI/p8GheAVmGlAQmRp9Jm9nNu3mzGxnTy0ba2NViopeHUr2guh3JiLP5A2CtuYiMHHElR3Ly21sK6ZnLMbjvXqKYRIQL3VWysxhPMoJnbjMgcX/zfFjNz8HDWzjzRL/lo77rHjHjvusePFYMdH4a691BuUKJB6eJYG7bkAwtwNZYoVVANCI4y95ZiNljMTiTREYvCbaX84PmOcW5d9iZx3fbUYor7ySq7YccXWilSn37fa2K+ZqIGcK+3gbopDB7sYArNV5uIGV4JLCqWnojaMkjNDUFpQ3wWngTSOhJT8hl8LoTkI5lY96rplpM8ol/2RMmeCwJsPd3zD65e8+vCKqMLaGof9kad1YbWC1ZWhFrZDYgzKclh4umZa8QpAMKIaUTrzvHlLvzZv0XoquZKGxMXFlmksGMLu8pLrT36W3dUFmPCrfqXwf/9b/4irGHl4tWX74HXCxZtexYlR88KXfuon+eG//6O8u89cbRPzUvl//J3/kY892rC72tHWI3UtHJYjy/GWuyePee3117h68JB3vnRDWhtaKtoaFjw4rS4ZqY1cCm89nvng8S2vXF/AtEHFq4haffHU9mwhgceml1rPLf5T1P3lxZZPiPDoYmJZfcEOKbCu2efzIVAb5FD9QVC9GgV33UzDgAm9Tey5Lz20Amuwrm7yFKN6pocGB7/amI8Lc84u21Rn3puZ26bnQmzVja36KKA2I+dCq6UDSWUIyqPdyO3amJfsoAKU5sZTxukhFrpao3VVhK+B1iq1M/bptvAv63GPHffYcY8dLwg7vl44KFGDzwLVTYdMAswrpVTymik1E8KA6IB10ltQqEsml0LURAjqBKV19YscBBXtc1kHGm+wnXb/zq63UsmlcVcqcUhuJbxmciebWfG2XQyhW0MLeS2eAGtu0hRS9HYZdCKRk/WKird715UalSlGB6MYUMRvgFrJy+pFrDVqKZRlhbwyKLxyMfHa9Yar3UheV+Ylsy4zpSwEqewiPBgHdtNAFOFxM96R
oxPrxEmEUYWoRuqzSjPP7Rg0EAd335xCpNbRJW4x8uk3PsnFwzcxDcSg/JJf9Cr/l//Tr6OWyq/433wr0yvfjF68QskLKpDGDVIdPJ/cFZbDketd4Fd82y+lzbfs56cgiWGzwWrD1oo2Yf/4CWE58PD6gi//XKbNC4MIFgOCsa6NJkag8fS48vbjI68+uCM8iqTosrpSGrV5Amxt7bxw7cRY77PfkwW6NfPo+CCMU/IWPoaIYR2IWsmIGU0UTM5uqyLiDHbtowEDmtFywXJlWVbWNft9HQNpGHzxL4WaV1rxefpijRICQeO5pSrN2A09YdZgKZWcCzkXpBlDMIYh8srFyHGpvH+78oGtGOIkyj4qKKV2kOlA26qTRaOrV1BlbcZSqo8GXuIZzz123GPHPXa8GOz4uhnxRIMkwpiiVxKdnLSsC3XJ0Aq1mx/JZvK5MaGnhjZnrYMTeIobICFCTPE8a/OdrTsg1ly8zdXnjSZQS6UspTeB+8XJGSkNBVIQRlWkNUqtlFw9aCkGbyX3DI6cG1TPYVibdWBYGdZKrbCZ3LEPcJvxNSO5uHSMShKhaXP77Oya9pPtcaEhZWVjhTEJ22Hkekw8uJjYTKMrFI6VxA1m/XwBQwoMfW6JOWAiSmuNpMoQIgI0cwneOG751JufQncPKVVpZNQyn3ljYPfKJ5gefAK5/iQWB9rhhppX4nTJN37mW/jOJfHqv3qbpJVf/flP8s0f3xGbMF69SohKk8ZmzUxX77N/9+dJ9YAtT3j9tc/y4wrzvKApnStMA0KKbMeB9w+Ft28zH3tyZBgHphT9EWL4NakNaL77t55W2xz4mxmlf//TfNm6bNLEfRcQziQ0FXFxS58bN/NANhOlSaEAGhqxGaHHo3sIHC7FBGLwh0otxloyUgtSMhF/wPnn7O1UPOV2k2I3EGusuZLXTljsn28cA0FHjsdMoPZKtzFGZUyRMaX+3azPz52xHwZvCdfmD9Wls/BFxdNKX9LjHjvuseMeO14cdjz3Ov1PXukv8JhiIEZvdQIg/cao1Wdoze+WNq/Md3tEJzQoQwg+Ky2NXDK1uN3zKQW04Lu/hssKAU/paobVSivV7aoV38k2nEClToRTEzAH8BiUFJTSCqUU301jhOA7XVWh5nxuHbYu9QOQVVnXSsuVeH3BuN1QrdDWjK0rVlaGIAxBSEEwE+IUOazGk0Ph8X4lDWsnaRWuIkxD4mo7crXZsNltiOPAshai3rkUDHfHTDGwm0bGYSSmkRbEPRPEEEo3lvJqDbxF/vDqIdeXr6Dba2xt1OWW6eIB0jKyeYRevkkdRyiZvL8hlIqGkc31a/xvf+UF3/LN79GWG0JbafMd08c+S7r+hO/E1xm9DIyvfJqL1z6BPf0SaGA3Jr7x05/mf/ngx1hLZgz+gIhD5NFu5DOvXXCxX7ncRKrBcfb5tzt3+jy+tgYYIfTWo9l5Hg/Wsydc3og5J6HhnAPE2eohBFL0ObH1OWs9ZWL1tmld3erZpaoeT0+MHmne6ldxO0S6z8BaGKSxjYGlNpp4KFitPt9OoiQN/nAtBibUsrijZG/BD2JARbUrb8QBMQWYxsB2MzBOkVIKtfR7T5RBlWFM7mXRwbh1dXGKQm0v7wblHjvuseMeO14cdjzv8VJvUIJ4FLi0hqgRqrfDYojEIGTzv7NaaAdYmxGHzmw/8YqLt8kCgtppoWfMt7I9JhykNDd3KsUBpjWkCQScpa8ex04DEwP1HeqowbVnKNUaZoWAuw8GvF3XzMELc7Cq1W96moNf1cAhzD7bpqJWGaNf7UGNTVQi0JKiq1JUmHPlcKyUy4YipJC4jMJuk5jGyBQHkiaUgJiT+koRzCpBBq+AUmdii1dGpTkQnxZlEiWoV29jHHnz6hXG7TUtbqhtZpkhTlcEK7SLN9GLR2BKsYqlHXW+QfpMethMXNUdh7ffpRxvSa9+AqaHNHUiYUmJmAZCDEzjhFxcEI5P0HLgM9/8Ddw+
fcL+nfcJAVRduhkeXJJC4NNzYQqRzRShuZLBF3JfzabuDdEJbieQOEtDT231kDprvxF6m/JUwZwIbCrq7d3mDz4vVhpm3Yb7Q9oX4witEmNiPi7e5m/ucYD5w1LNCOqhdkhhbcZq7cxHqGZgA5ZcJirNEIQUlHGIJKDV3KWP/kBMITINg8tXQ3KTMLwV7d4XzX8NvUpXmBcPqZuS8trlRExCrSM/8v4v3Hr/z3ncY8c9dtxjx4vBDrTRI7T+g8dLvUFJQRmiEgMMwRfxqWU2Bg/hMvOdoZVGOa5QmyeLqhFFGBQ0CGLOcJcAxYQMYOLSMVxyFUWpEkCqbwWb55tG6CAl6LOGLVOMbJMyRijZjZU8nstto0OrTrCqBelmUIi/poh0a2qwWjje3VHL6vr2KGyjkKIydpDRBlWUghtKrdmJcK0JIoGg3pLbjQMpeXvVasOo5OPC7d3CYSmAm/ConAhOFaz6bJzWu4+NCKRWGS0QiFzsrri8fATissao6lkmJqzxIdur16kaWZfsmRmt0GpFavUMChrTxSVqHweJyOYBFidyrahUJG3QcYuE4Cx2HQhSsbZwsZ34pl/0Kb60ZpZ57mAhhI0SQyL38ytd0WDtBCBO6tIgHUwdOE5x4IYDUejcgqDq38189tyqAwdCr0z84dKaZ7WYPXvomZ7yKvwOsVIpOVMPK6LK4TCzrG70pXR1QAVyc3Pp6HH3Zo1cKvPqkj9VoVrBTFEL7sMAJBWSde+KUmhWKcVYWmMcI1fVZaQisKy1PyAbqg40UbtqALc3DxhjiLyyS4xhR60jrTV+5Cd+gRb7f+bjHjvuseMeO14MdnzyKvA/PiduvNQblO3gAJOCEKQ9q2BaY+j5EsX6RbPmwWBN0VpQjCFFWlDPP0BIKZLGQAWy+W5ySBEw1lxZDsZ8FPLqs0oJICihXwRv4Tl4iSpTVHZTZDcqc/GWX8BQINRKWxdyM7fWrg5CIn7BgwZEOuusFEqGsmaOKlxuBy4vItvk+QiTusPhSmM2f/1iXZN+2u0DGiAlN15S7XbJpbAcZm73C4elO/7ZM2KWLNkj04P6blwEkciAMaiwCZEpjjx6+CqbMVEpNIvklgkhktMl29e+idqUu5tbZ6zT0OOtA61GbKmUtjJstwyXr0EYqZKoJuS8EOLAEAdCGFBzmR8hIWkDFSwOXD58xNWjxzx5970u8XNzopS8OskFNyKiVzue0IWKg+qpKjKz3m53omJrjaDRAaZbWDsVnXM949eqk+eqg00Df+CdflXts9dOLmv1LP8sZh1kCjRBT+Q4FVaDQwO1QAMKQkWQ6DPwdJ7nilttN3/Qtepg2cozeWCpRsa9Ea4IjL3KnXNGxLsCqoKKULsrZlCFk3ukCpebgcspIqLM+dSHfvmOe+y4x4577Hgx2LGU4bnX6Uu9QdmkweeoSUlBoOImM9aYoqIk1iqeNipeNYSkxCEwDpFRFWvizowI0xDZbBLDOLiDZBDGGDAx5py5vRt48nTm9k57hHpv04k4o1/E59N+zzImZUrKEPSceKqintEx+K+tmxZBNwBCiKLnz1ysO/D1XbEqSFPUlAFlEmUjxjgEDgWOIRCDK/SRShB64JknioboygDDq7+yVPK8cnsszLnQDJr0llwIxAZD5RymBoqKMfQ54zYmAkp+/IS9RsJDw1Lx76KRcPUQ3V4z3x3ZjCNfefdddhcbLuKIlgO1GjVngjqnwfwMeAbIuhBiJKTBSXa9WhIBkwZhwMoRqyvTEJk2nnvSWjsvYqk+h6/ZCWqqzpAXOSktvGIV626fzatRJ7k5FUE6togKJ0dKeqWKgETnEIh4kFzts+BO48eaA5eqM9przT6z7aZVpTZn9beGiLdFVWBujcVc0qjVyXVLMYoJFWUu2SuxqJgIq7iqJGLksjKInRN0EfM5sAnDmHhtG9ltBrZSEe2SyZNs0hopJqoZSUHx76PituonbYqW/CKW/X+W4x477rHjHjteDHZU
nh83XuoNyth9AlSa51KIMQa4jMJliFQia0008wV/uZvYTIkQlaRCxAlN8xLPIDSOiWkzgTpJbYwJEzdsWsaJ/Tiw362sp+rRh4PEFIndirhVb4vFELjeDkxJqWpkbeSaGIbEZgxoCBzXxu1hpubGIEJK4tHhjkiUaiylUfGW7TQoDy8nXtmNvLJJXIyRjfq5uA0zUjPHFdbOtA5iDClBNFLf9Z4WhKhQcfLXzewmVSEkphS43ETGKIwhsgnJKycB6dLKQdV3zHjGBqVQlgVpLm3LtSFxQmTgn/+v/5z33/2AIUVyznzs4x9j+sTrHA9HUtnTjjPjZoI4geETfvFgqxhHrzDzDHGgdnCIuBmUpkS13IG+h6D1qrJWt3Cupbn3UHOlhEhXFZi59XInLRp0spkz8N0IC5f+mSHm4GXtWcv69IDhPIPuLfYPyQ7BHx6Ck+ryCRzMMIFMY2mVpTZWKmuubDagKVLWyDEvLPNKKW6frhopGMdSWM1oFlhsdRWFKLsY2AYDM6/cDALqzpTVZ8xXFzsuLwYmCmNd0VLgRObT0MmWggbcz8C86m/m1WUpjaXav70oX5LjHjvuseMeO14MduSPgBsv9QYF6S0/GkmMkIRXLwZGrojq0dRLySDKxW7ickpsN1PX5VefkWKsuSG1V4UhkMaBEH2a12pDNNKAmhLsJqiNlosz8+3ZDnFI7szoGQpe2QzdP+GNdcM3lWsqgqm6gkCEXDJlWbG1MajfsKkb3LTe+s1mmARiDGyHxJSUKQXGGBljJAVxXwe7JZcjN/PKzWK0WkCMGAXFVQsxery6CNBgzQu3GfbZQB18r6bEbojsxoExDYxx9JAp/MY92z2j50vRMEwTEhIiEQ3Kcan81E/+OP/9/+vvkXY7dEyMQ+LNn/8iv/G//rVcTztuf/7nqIcnPBw/RpFuFlQrlUScJk+brRk0UPLq14eM5FtYb2mWoSasBloulJxpxQGg9XUQQqC20lupdmayn9qSDhC+qJx34JVAa9aJcA4wXhZ1h0j7UOu1NS9PT+3VU1tc9ZyjAe7YWavfN6JKwChWfZzQqs+Ha2OeFy4uNj7LV+VxrtRjYS2ZaMYmVgQHPJVARR1IzYgBNuKR9kH9s58ANmfjkDNLUDRCUINSkNaTegEzV5P4aym5eogeQMRJj4bzHdCXd4Nyjx332HGPHS8QO57zeKk3KC5ZCg4wNKYhMr5yxWvXG1JIlFzJJTOlgV2vgGLyqXEzDwRDA6IBK5maZ5pBOs1Y1b0KljVzWBaKKeM0sJ0ioV88w6sJ/ywRteos6vNNqL5jBipCMydXuakTUAvUilQQ6a6EzRdBU5eBheiEJm+xRqJ4K5jW243iJKppyWzXxmZaWKtBK0AlKATxS23mGQ4YHNeFd28OfPnpntu1EmJkI4EpJiI+jzxVOadEUKN6JdXnruIDWEQj09VDLE5o3CISqfMN/+D/8yN88UtfJu52bC8u0Nb4yk/9FJ959ZrP/bJfwv6w5+n7H5AuH6KXjTiM3hqNnqKaWsZEvVcKYIWa77Dj+5TjB6CBNFxyd7vw9OkTn1On6LLJqpScCcFIRHe0PFUq+H9Ib7d2LkHrXAAVJQTDkO5T4H3XEPzh468jPXBLOvA4KKlovy4AnuCas9tWtw9XSeLqkSEExhhImtFaWfNMbc3JdybMa+P92yNrzow0LnoVfzkGpjGx7XkntEpSZRKXj0Yxot+gVOs8g1pYilGXGQtOOBRxMqioA430c9KaezHkUjxSXkBDRMwwC6g8/yz5a+24x4577LjHjheDHc3Cc6/Tl3qDklTYDJHdNBCkMQ7KsNm5h4DqOT58GkemITKMgxOQmtAk0GLCJCFBiCJMQ8Jo5HnPfDj6BWiNHFe3dM6NFP3CBvFtdjMnOkE733A+bmyEGJ093Ylvfltrzyk46ebj+d8Z9Ivbnu20g/rfq1cw2m+E1lpPBvWFsC6Fec7UZgzJSVwmgSEo4zh4u7JVmniG
A9ZYa+UuF57MHsZ9kYauTPDKodUeWNV6LpZ2hn4DNPjvpTtm7h7Ag48jwwYJA6EBOrO72vLuu1+hvueeDpMql+PQpYR3VA1sH73OePmA9957l9fe/BTzvDANgVYWFltIcSKbMARF1iPt8AG2f5d2vCNur6lJuLvdM9/eoSH6dS3NE2DNTbVCAA3hPCNWEczck8H/C9DMWfMYhUbCW+lghCidBOi7/1qtt7sdruiVFSI0aldS+L0WOGV39Cqym5y5CtWv7RBgVOOAnKutzWZgiV5lLWvubdpKFWUKMJHQlBhqZYhKikKQyiBKsODmTeoEtVbctfJqCGyB0TJhrUQFjUozn7X7uXk2Bs/FE0hJglRjWSshiNtpD9Mv5HL/z3rcY8c9dtxjx4vBDpPn33a81BsUs0JQYzslxkGJwTX70nlGmqK7IgquubfqC6ZBVfFqA59BpxTR0DM31pmoQjWvLmIaeLjbsSyZw/4OK5VhSmgSj1QXwUzB/OY9tXs16JkY5IrFvntu9TxvlE6A61vm8w75NIfU/m/ooBKC3+StOZnNXTA9iCkDay2+mIfg1UqI/jCpFVt9tqo6ohoZx4Hryw2f+dg1rz3YUIpr2DfDwKOLid0m9V2/L4zWmkfRqxCsoWaYevT3dvcQ3T5A0kgzcQlgDHz2s5/havM/cLs/cDw21qD84m98k9def43DcaERmTYTd3crP/kT/5LtMNDyLXHakLY70pCQGIFKXe7g8AH15h3K3ftsNxs0BWr1BNmmCY3uvKjB5aGiwWehKCKhezaUTnqL3cOA83VQ86p26NdQxLNOvioELLvT5ul6WTNf0AJ9EO6/iiLiDo+oV2DhXFH5Q0Zwq4ukyiYG9mtF6FV9VK53WzbxKdpan507cfJ6E7neRkbJkBtikRADsQPrWpo7oQ6CxkBS4+F29ORRzEmYwav9Viq5+EMIg9qvt6h0gqN3oEsHIOtyx5eYgnKPHffYcY8dLwg7Poq/40u9QXl4MXGxSd6eHQIxKUMaoNk53ls0uHa+VhqKhgGisj/OLHcHGkJSoU0DbDdoHMilMa+Z+XDAqrO3dykRh8RUJ+q6dpBQUkqUWruLnrf9PmxQdCI4uaVz6DNLOYNKqw3pOnrroAN9FtnbidBnkuY3M+Dz4BhRCeRSKbVyebElDZFXa+PV/Uwx2I6JIQZaUFpnsau6U2DsFdIbDy6Y58W1+BgW/DyOw8CQBjS4aiB8KLGyNW/t1VZd7pe2tDjh7lOe4IoG3nzzFX7Df/2r+Ml//hM8vtkzbkb+97/+VxODV4piSmvCP/gH/xPvvv0VvuUzb7C7uGC4uGLYXhIRpOyx9Y66/4B6eELZP2WKCdNAGh4xDw+4awUsdDZ5c/WBCNoD4aSf01qrx5R3ANdwWvT01q0THH2RBZrRTbJ6ixcIXep3ehAgXVIH7pPQr1mMoIQzQJd6usZ97Ny6mZco2yHyYDcgWmgpsovCFECs8OYgcDEAjYhxtRm52kQuk6KtQLXeai1QwNSw02fHpZJB3GL8Ykr+OtqrMHNZrD8YvaL2yto/n5za16cqDq/6FUPay6viuceOe+y4x44XhB1fLxyUT735iMvdhiG5s50Gn3MhhsZAWdeuC1fqKXExCE3E25/LjOXsEdYszPmIhYFhd0kanNC2znMnIWVSTKBKtcqyOoANQ0JNXeJe3dnQDQQc6DR6VVObS640qMvHehUUO0FKRBxsTjer9TagtTNgGXTLYF9EKXTpXzecmgZfLFb9887FmMaRFALEQO3taz+coT8NA+tu7DbfnajVZ+yCEjR+KDuh9RvzWbuxEVANpHFyvlekkwgb4zSxuRj5r375Z7jeKl/++S/x6MEVdb7l7sn7jMPI7Qfv81P/6qd5573H/Npf818xbEYurl8hDZO7fNYFoRAUJAg6DUR2SG2Mu2vSZkN49A18+tG38NYr/4J3/9WPUZ9+QCmeuCl2Wtin+XHf1Tefm8cUunzQiW0uCTxp
A739quYLzx8iDkXag9rcf8LPjorbdtP8msUUCD37opRCR7qzc6RYo7RAKoEhejLtOFbW5gATamYi87GN8FA3BDGCGJshst2MriQpbl9damWtPVNUAk31TMBUYEypW6IHsOpSVO3sfAwpEAYfAzTgOM+U7LPyKpDLyWUykoJS14V5Wf+Lru//ksc9dtxjxz12vBjsOK7Pjxsv9Qbl1YdXXGx3PqOV3gLrJ9JiI4t48JEJmk6mMUZUZTMkZA3UVondUVqbgUIyzzzQMWF5hVqoxwM6TGitxD5LXdeCSOBkcXxaiiEoopxbe6UnOLZ+84mcZsJ6Jkid3AglOOlK1Qls2CnXwaup2EEyBk/8DAJNICqQIkG9Pfth7XqMiRqUljzsjL7TbYoztmN0q+4+2na/gOA/Z/iOH2/hqjzzbxA6Ia8ZVmZofa7dFzXA7uIK/eQ3MG02vP7669TjngcPtoyyUvc3vPXlnyPEwK/7dd/O668/JI47d300yMuKlVuSZcQysh7QMmNt8XMet9jmIdIKF1cXfNOv/t/xyi/6Vn7mn/x9Dl/6Ka8Kmu/0K6fv5wDinhKRRp/vWyd7ifhuvzUnEJr7E3hbHMwUDckNmjjh0IkTANL8YRGDkqK36luvnlQN1W5trb39Xeu5gto1CMElg7QV1kakcjkqF2nofhh+n2+SW7KrJGgee742N6KSkDjJEKMqm+Cupzj1ATMhJa/O6urXNioM00iMnnSaaMxaaAil+kNZVP2BplCbt4tf1uMeO+6x4x47Xgx25Pr8uPFSb1BCSOfFap0lba0SVBBNZ4OjWpqz0lvD6oo1NwPaDAELo+cN9F2wmZHXQBy8nSXSGII4ac7mLpPzAKfcDZdckoXfxeLtMW0A3rJr/fen4yQTPPkKnFqfIp7p4Ztwv7CIBzGVUtzOGq+IaEbLmQZnoOoccT8HyWWE3beSEBKNPgc1J3J57xFUwlkdIOL6dZWe69GJbmqCigO0A2CnQmkgmvmMNy9I2PqnUCWkgRBHNA5srx6wvbzCciavdxyffkCrC69//BHj5pLLq4ek7RU6XtBM8RzQ7imwf588f8DdO2+jDaaNEq+uCZsr6vgAHa9BR1SVhx/7BOHX/B/42X9orO/9DK0Ucs7uYUA9kwZPbpHajadqBwJDCGaIOrB7hoifBnDFQfywvFYceKy30N0CW5DeJq7tQ/4GQdx/QQXVrjBEsG6FPUR3lRxNyLWgUtEAw5Ro1eWtUbUDmINYCB4Rjyi5+pw/9ETTkz32oIpUTyg9qR81BH9sNVcphOrvP44+dlBGhiG5xLBXV6FzIlqtfq4+Qqv2a+24x4577LjHjheDHV83HJRTSNPJ/Q9TVHF3xU4qSjFhlhG0k3TaGVAa7SwNa9bI2RnUy7r4bNEaLVc38DEjWoUQ3Zmx+QKwWjzsS9TDkcRbqUFOMzc3u2kNRP2if/g4zY3PvxeBLo3zn3WZntBo5ouE5qFT2TrjO7jT4Jkwp4IQUPGZroiCBI9OF0E6WFlzDwLRhqg+Izapt2hjUESN1gRplaB2Dojym9uzRygritHKQrAdTf3zWqu9mhBiGj3VsoHpiE4X0AYGM1IcSdtL0vYRMW0odzesT79MkkzZv8vx3bc4PH3KcnvL3d0TPvOLv5nXrl5HabA88bZqGlGZQCPXr7zCN3zbr+XtH1Pyzdvked8DatuZXAiucmj4n2uJDjx4penXzFBxUuIp4Au8SnTyXzds6soJNfdACEG9rQzn6jj64Lw3ibv3gzgYiUIwr45Dj1efiATx2jevmVydZBZD96QIwdvQwrM5eG3niltDYLPbMk4TVgv5cOdyRWdTeDUEyOhMB5WGxgQavfWsgWSQa3M79uDJq64AAZDzXPllPO6x4x477rHjxWBH/Qi48VJvUAyfWdbmRjilL/jSL471BTUE9Z1q7W1SdaZRLkaj9heCIC6vUrFzpQOCRl/YMXrl4ko5j6k+VRIhBmJMwJnDhnWCFCrE
GJCgxDRA80j22pyl7S3D013ZiXDdzuh0hBA6Cam58Y+Iz6v7rFe7FJFeDWrwyuYEQCa4L8EJbE+t5RA64/xkuSx9rmru0WBg6ox0AUJMpNRTLFUYEFqbMHX7ZFMwKdRcqCVjtbgFdHVzqzQkaAMql8/IgKIM22visEPKwgf/+if4Z//ff8IolXb3GFkzu03g4598k0996uNsXn2FePkQoxDWJ6garRWkrdjlx0CFVz/5jagG3vvxH4HDs4DO/AAA5a9JREFUO+S+qlQ9H8UfLpy/b22e/uqjYuuKjeqJsLX2iyq0VjHqec6P9Xj1TpIL4lWxe1J4BdWaX68TvJzgqo/jnWjZ74VTVLtZO1uWr2t3gQxdLdCrUL/nXGramqFrRoJbjseUGLY7t50umTq4H4GY9bh1nxHHpKSQPEZePDDOTAkpue+C1p486w/dUgpRvHpv8vx+Bl9rxz123GPHPXa8GOwY+DrxQaEUrPiFBvFUzwqE4JWO2ZkhL0bPTPAgJBElaKBoheZMaj29VPDFlvrbRPF5ryqklDxd0qCU3NnUDj5+nFquroH3lpyDSIyDt89ag+ZkMKC39Po/F/pcXHrCqlc1ISbUOpO8zxLpc+oQo1clKh1mnNAVwrM/NzoY9ta2O/oF0E7UOm1qRZ6xyu1kutN6G9wfTG621MlanUjVGtAO5GXxz77O5MNTlsMt0grWcpfmRWTcItMGVWVIY792mbZ/jy9/6Yv8qx/7X2FK/PRPvc1rV4Fv/eaP88br1wwXOx68+ibT9oK0vaJZ9Op3PbhfwO3iX+HqYxSE3cc+RQhw8zM/zm7/GI3dUMiMZgV6TeIqBzrZz9vYH65uHIy87Qnusnk6N6efs/5zJylgaw5OYv76tc/j6WDkskM7jxfo1wfDpaTNnSoFObfpT9fHHwCcq1HMr1UcCqlUb/FqYBxGUnRraxknYozkUqH4eCGlhAbFrFFzxqr7PiBGE1dfxBTPrW3Mszo8nRXi8vKqeO6x4x477rHjxWBHyl8nJNk8z2RraFBSTD07wJB+A1g7ud35Dl81eruuAuK77zSMqBaMznzvTTjp7U03zZHz7jjE5HPH1s6tN8FZ+UgHAQnnKsHMCVPjODmhTN1foZS1/zsnt7XiNsixVxlwWuR6fg1363OyU4pOUDvPc8VbcyrS5YTPWPTWGrV5qFmKg3+OEM6t4BMhj05io7/LKcnSPwdeTTmVilMmh51+WmHKd7S7x8wSOLaFfDxSSvHvKcKYImZ4oJo4KTGqUMtKefoW73zlX7O/ueGbP/9ZpMLD+pRv+PirPHj4gDSNbB+8yjBuvW1tjSAFqlLajCKE1GiHtyGOMFyhJbO7esjms5/n8JWfInagM+tVU4fkkxQO7MwlELMz7hogIQJKaxkr5dnsvj80Tk6Q1qvHE0Cd2sJmrb+fnK+tY1nj/E7mYEX/89O/bb2KPAGN8w76vLg/jFozpk7Uy3n1MUWfMWuMjNstOWdSNyBTdWdTkUAtmZIzJZfeyu1Ve69Sz7JV8Yrd7xOI68u7QbnHjnvsuMeOF4MdcVmee52+1BsU/+6Nsla0t6bMDFNvmxrKoNFbtAaoEnVE1NtsMSaCQS6+oxtSdLkX3r6THoXdaj0zsb2qCOcFiFUngMUE+Puo+qzP+o5ZNBDTwJBGJAZvk1X/dyFGTH1W2Ep16d/Q6y/jfMN61REw8+yHKIKm2KuvTlDT2G/MXhWKW2V76692QFH/rL017Gz/ToeTE7CebmRoZuf3FoHaF9gJZJ5xETxOvZQnLMcFqcJ8e+uVngg6JPI+U8tKHBPaBKywtJX9zWPy3WOO+wOvvPY6WzUef/mneSVGri93bHYXpGnswJoI44RJAHF3x5gSzYRWjZBnePrzsHuVYBXWFaVwdX1BOd4Rwkir4i1o7FytOOj8G4d5S9T6HF4kUOoKPU79DCxnkLBekPeWrCimbiVtrToRsSOz+yL4A6H1qsdoDm69VXu6JnACpQ8r
Ak7R7u3stXCy3UZwnwiePYCMTszsqoKTXBGDkrOHo9Xaz0UjxuihaHQiZes95dNnUkHm5wear7XjHjvuseMeO14MdsjXywYlDiOb3Y5ccq8mHCTikIgxEGrpu0j6TNhQCW6WY+ZVDKBrxFpzA6JxPF/YE1u5dTMlQwiihBSRnmgpGKHvFuk3o7fPnskATRycYkyICdUqrVV/rRCQc2VU/HOHU+XTKw/re/VesdTWjaTkFODloKdnApp0oKTPMl0yd3Y8DN6ePUV4fxhkNESfv5vPyF1d12fM5kQrBxlvQ7e+0Og7dqnGJiS+9P4dx5//Im/9q59l2GyoVxccbh5TckW3G+IYGDYjl9cPaHVPXVbSMHCxTejtewxlz+Z6YhwHhmEkjgMhRLeKqB6xjg4QAi2NiESsFJCAyoLdvYOGBK1Q5xtk2SOt0aiE8QJk8BwTHLBd0u9geZovSwdvUfGWtgTXZPaZr8jJCutDh7ikzs+t+rlp7sL5DMTk/IuZEykdPDpHoNauupAPVZrP/rc1byWfgM68gHvWkv8QMPn7nMDHZ8Hn+0ldudL6rB9r57m5qrurNrM+c25dFupJvN5un/+T1/CLOu6x4x477rHjxWCHfYTR8Eu9Qbl69BpXlxfntpVKIKZ0np+Cs61NhFJLn9G5CZLrsj0/Y1lmWimkYWTcXDCMg8/5Wju3aFurtJJpxYFLgrd5YwzEkzTL3CfgBDinC8KH5nCnOa+dd95y3m2H7kIYeoaDWyr3QDL7kLRQoJRCNenAEt3ICTl7IlgnQ1lrXomZubnOaYGodhfKwNmI6EMtWpHuoGjQxG9INVBJ3nr1FyH01/I9tzCpy+EOxz288oh//Tf/33zwlXcJKaAXO/LVBY++6TUubOIywfr4A2KofMPrj7geAxfJqKtw+emPE2JguLgkqCeuCoblI3WpyLTBBiPurjthUzGNVIJnPZjR5j1BKi0fqfMeW48sa0XTxOX1q8TNFRL8WksdeZZSy4fAQjvoREQjKr2Fezpf4vfUqeCgnys5PXRaxeqHKpl+7/osuYH5+XdeAM51ACxYb5n3FvLpAdO9D/TUGjaw0K/beazwIbWHCCKhg0pAg3wV6JzMoQwni9YPV/zioHJqXdec/fNbv/YfRS/4NXbcY8c9dtxjx4vBjszzFzYv9QZl3GwYNhtiZxj7fDMiyeekop6hIKd2WG3kWiAUYoxMmw2qgU0pYB7QFYeJEHp6ZzuZ6JzaeStlmcFql+N5FROHkXCaJ3aWvdtgD32h24daZ97uqmZIU0rNTrQ7A8vJm6BvbVFU/M9r95/QGJGeGgonwOvtP3EQ8zuVZ60+s86G79kPFdQiIdg55wN5xvRXkX7jAapUL7wonaB3aukCns3SK4LTxhuE42HPZ37jr+Ff/i8/ybjd8dpnvolPfGrLmynz9ruPKQaPn2Y++5lfzGtTI+YbYlvRiytfMCESpg0tz1RTr+SAoJGcC2JHiBNS8STUNIJI92BYaPsnrHVGa2a5uwGcQ1CWhadlZvfwVYbpmjhcOKmC4bz+Tm1Pn9obIgE00Mzlj84zdGBop2wM/ZDl+KkyVK8aJcTexj1VIfTKyhxYrCH4nL511UPrQNRo3dVU+gPHp/ku83x2zq27Wap2d9H+ZazfPycLbr9W/cGhTlQ8zcJDv7ane7mW2t8bNJTzA720Ci+xk+w9dtxjxz12vBjskFyfe52+1BuU2PX0JjxrTakRaVg1xAJEdwJExKMegIAQY9ebiyIx0e8XqIXWSt+hnmbFTgATq4hVrKz9RlLMAqbQNJxbXq26K59q7JXHaQbXQ6GMM6koZ3c2jCnSp488y38AKD47DU6wEzP3ahB5thDsFCol58/sZK1nrVqa725r/1XEZX8+D/bZ8gmAyrr01/QZsqvsfL55avMhnbHfCYFm3uA0E+Y4Mj54hXh7R9E73vzsp/jGz30rl9OGT1wsWN6z207ENPE0X7LZDIzlA1IYyMeChYCE
QBgmJA4gARm3fn3riqTUd/yRvBzBDK1+YiVMlLJA3lNuP6Asd7S8+ox+GBjShmU9cny6kAZnp1dthOESCQMQekVnqDnxrX97f7iYV5Xg1Z+HZDkAuAsnncWu58pT43But6qfqA4CXqmf7L9PTd9m7dSV7WRN42Sl5cTFSqz5PAbwpqs9I9/113XXTn8IgBJit+PqpYz1MYAIrgZpDbXOnTBXJZjULuc8PUT8nMSWWdu/1aR+aY577LjHjnvseDHYMX4E3HipNyi5nIg51neq4oY3tXrQlghtMXJPAPVKxlupw5DIy+w7eb9HcGvpTKv1vHBFvH0aU+qVRIFazqQlU6XOPn8Fd8orJbPmgpkS4nAO6jpdlpOx07y4S+G03bBLk3/GWpDeMpNexqg2Vwv0meZZR99bsyIewnRmbZtRT2UL3RHhBDKdtS1wbumiyjCOCN7+XQ7HPvc0UhrO4AcOcrWD3GlueboB1T855WpLG6945RPfyMVrb3A87juTfIW80JY9IWyosmF65eOkNhMbaBnYbK5pywENEVE3/xmHkWq+M2fdUyQQQyKXQj7eooc9w7Tj8M4N4/WrvsMvM4eb9wm4H0GMA60U1nrLbrNhyY13v/JlXkfYPoiEoSJSnQwZRhpetSg9yK236sVaZ8CrL952Ml7Sc9Xh7dBwVoaY0cG4z5bh1CX1h+S5dDxVLv0tzRxkrOGJqtIfLvV8PU/t4pOa4vRSbj52SmQdcHDpIHmaWfcq7WyK1XkF2mWRp+Tcvjx8rciz2XkNm/9cS/kX/LjHjnvsuMeOF4MdJQ7PvU5f6g2K5xxAGk7ksP4X5uScWivFGqU5GJ1mZKpKrdkXiVkng/XF2ReiN2HplsCRWCohesR0bdZJUnQdOsTY+g3SerXgsiprjdJKrz4EQylWWXNjrYJJIFigiuvKqy2s6wrNP5e7/4m/X7+R/bN6CBZw/tnWgVRFMTm1hCEGXwxrdnKSOz66mdQJpGV/ZBgGWmssxwWrzQl0J6llbwWLgqkT5YgJ1R6YhVejGiI6bIlxZMBImy3j9hJDsbpwWwqTRdq6opcPmbaPiLZgZUK4BitY3GNdZhnUq0wrlSZKGHaUUmjzkTLfMR9uKE+ekq4vmd99j82j9wnDyDAOBKm9danU4uQ4ywvHPLO5fhUJiePtHWmcMFHSNHnFQvOFmXaIbjnJMQU/zyanh0pf9B/2NuhAcWbQ19x/fsBUT6WtPyhwsG9dCmpyUlKcKiNn5vuv9PY//QHYzlWup8H08hr89x2c3HGyg6JVtMsRvV3slfLJTCrE9IzkSCe49dfsn/qsThFRhpdXZXyPHffYcY8dLwg7hq+XLB4EcinnTATwiyOtoX0e26wh+izu+1QdxT4rbrTe5vVdXuz/ibprYjiTloSKoCmhIflM90R3b41mbiEcVICCaQUJaIgkILTq0RSGA9hgpF6RjePAOHo1Wq0RzB0Gg+pZsuUtUL/40jpByYwgCmrU1qu5WpGeMdJPEZh1kya3Ik4pnf9eW6WW7DtyDUBl3G3B3PfgZMGt6nbcsWc+aBiIaUBiOO+aQdEYubPEIYPGCxAhdQ1+LQv7vVB1wzAMjFev4/T+LaREsUKURkgbrBUszzSrtLL01rTS6srdB28T1gO3773FMG1oUrG7W8J2pJSCxkC1iMaBKY1e/exv3QsgL1w8eoMwXXAxThjKui7IevCZesv4ha1ovkPHh8j0sPe7DaQ8u/mkmyThM2QnB8qzCqJLACUkqtENv9q5OkVOfhH+cibtXGWcxwZYb7WGcxteTPw/8c9h+m8uePGqq3U7c+3tWRFoPTeF/n0MN9zS4HN4M6z0zJnulWH21QB7/nDCy3vcY8c9dtxjx4vDjuc8XuoNis9bjXXN7tLYo7qVZydD8BamhkBKnXXdF40b4vRZXu2kNvF8ihC6BXTrN4H5xh9cihhao5UOYsmrAROPlyYkondY+2zRZ8n13FY7zR853wQhOqClOFFK
8FwEgRAjKcYeVe5MrzORrVWCBkcuzB0HDdBwNhU6/VeLS8BSiIxj6jfK6SnTW3fntq7v6GNM7rXQz3WMkRQHX1aqXvEEt+H2hRHQGCnZ5XdIoLXKGAdfcDEwTVv3RtBIiwPDMJBChJagLWgrIAOo0MSrR5pgtbIenyLrnv07XyL1DBBNA2kcCDr0GHMhjCOqiRD8PKsrOKE0n50PI5urR0ja0FqmKhwPt1jOTLsLWkiENtFaodx+mVQXmB46qNRCt2dwB0/DW/dwrpZOaxcDU/XzWVfItZPTXDdR++zdn1Pt3DKng7ZXSJ7bIcHrcvl3mDbRW6nW0cq5BX360Kt8o/X795m7p7fXe8Bb5ylgrd+bDWlOjvMWb18E/f2aPH8V9LV43GPHPXbcY8cLwo6vFw7KMCSmcei7SKGHOZzzIU4zYsXTFE28Ajm1MwXxNiAg4dmd4Ys+9urqRCjrum/ztp/EQKX2CuhZNoPP4YQYnbgmemrdaV/wpxaZV2WIEjSdLaeRRq3Fq6x+gwXtu+vzDNAPofmM2HyLHc3OJkQiHUR7yzrGZ0AR1GeczThXe+dZptl5Ru7R7+EcCqeq0B0Iz2Dnp6z/vmG1oBgpbjxzpOdotArV3EZ5SBsnF56kj3hLu9VKzQfIe6gLttzR1iPL3WPqMrMc74hizjRPA5uLDUL/ruKL/ZRpMkwbSl6oJZP3t0QMamG4uECHiSbuLRGDJ6b6tYC5rAxlQKPbWQdV6vEDgmUYrjCZzrN8h5TanSLpi89HASGm7hkR/Jo0dxw9gQJyEld22Wg9uZb2+bM4wNCZ9842bB+qhvlQF8N6sulXrw/p4CCn+74a0grSQcR6FXUiS6Kxg4/1Vv/pBY2zm6UoaAQNEF5eH5R77LjHjnvs6PfsLzR2xMLzHh9pg/KDP/iD/OAP/iA/8zM/A8DnP/95/sgf+SP8pt/0mwCY55k/8Af+AH/5L/9llmXhu77ru/jzf/7P88Ybb5xf44tf/CLf+73fy9/9u3+Xi4sLvud7voc/+Sf/JDF+9L1SEmXoxkGOD70N2wOvsL7jaz6nC91cKfQWbq0udzotGug7x/4aEgLSHR59fitQeqCSOrNZ+o1zkuOp9rCtkHrkufisVfSc00CXltUesqQh+f0iQggDFp+5/olK33w+UwUYfV7cQ5iEk2qg+XZfn7XX+mk5VzgxJVQjMXk1Y+AeDa16xsqJoAUf0rTreTZ6qtoMXwS1v7Z2V0qvxHoUuXZ/h7NfhpznlyeCHkCpBa0Zlj317m3a/l3q8Qn5cENZlrNsLfRqaxgnNE2em1IN5QTYzyRwrVVaXrC6EhTKvJJCwIIiKUGIFGuoRSRNDLtr5rsnTBHIC2vNBJEuCTTqfkFzJm4eeKy7DL3KeEZgBJDWCP0zaEh9IStYgC4H9erXORBijVYznDwU6PdUvyFP/gStlTO58Xyznua8Z8kFvd3rlexXfa6TrJPihD2/wIAhGvu1Ll0Cy/nXZ6/hnAdC7HNpObUFnuu4x4577LjHjnvskN45e97jI63sT37yk/ypP/Wn+OxnP4uZ8Zf+0l/it/7W38o/+2f/jM9//vP8/t//+/kbf+Nv8Ff+yl/h+vqa3/t7fy+/7bf9Nv7BP/gHgC+q3/ybfzNvvvkm//Af/kO+8pWv8Lt+1+8ipcSf+BN/4qN8FAA0CEGlp3lKb6U6OIRuyewpk+Y7Vzwx8kRs84rFL1xplZgGNI48O7VuzHSqJE69Wt8N9irBvO0W5MSE7g6QrbnhER20NBA4jwIBIcZOHOotY7ETqDgZzyVavoBKyf3vvZqK4DvSM8hUaC5ZFA39Mz6THZ4+v0YHA6/yEgjU4tWKnMLJOuO61UYTIaThPCs+ZZM0a0gzFEVjr2Y6eLfSqPWZKuDfzJZo5oK8VgwN0PJKWe5g/4RwfEw+PKYtB6/+xg21LGBKLZ43IiEhGglAVaNSUTyrJAwjzaAsB8p8gJoZYkA3k5+/mDCEUjIxuvpA
zKghkTSyf+89ht1EiEKtfh+kzQVx3GDzY0rew/SQcPEqLYxehcG5+tXYVS4n1nrz1rr0hFm/P7x6dgKDO4Oe7hcfNeAVvTgoYydvAScfOgQ9q1rx5u/plu2z6Q+BDv7+Im6cZXLy3egPx1J7m9evUYgBMz37MvhmJJ4fXphhrUB5fpbsPXbcY8c9dtxjB2ZYfX7cEHumU/qPOh49esSf+TN/ht/+2387r732Gj/0Qz/Eb//tvx2An/zJn+Rbv/Vb+ZEf+RG+4zu+g7/1t/4Wv+W3/Ba+/OUvnyujv/AX/gJ/8A/+Qd59912G4fnkRzc3N1xfX/M//Lf/Vy53l0jwmWntYBI0INF3utacdHTewQNGozYotZEQRJvfdGkiDpNHfTevDGp1y2Y7tbxw3nPsbVz5ELicNOCni3cCBEQJaXSJGc1vgF4FnOVXVmi1nXe9/fIg4nI8D/F65vTXWoPoSgLqSRamaBr77XdqrXZb6VJRaahGT1zVQNTBd9A8a/225u27IB7OJarENPqN2eVlIQw0OpD2ys96m0+s8P4CT7MzzK0/zEpewLx9aieSlhmtrOT1CHlGl8dovgNroIkYBtpyoM03SLmlLpmYfHYMXmlVM2onwlmDYZoQTR6fXgvaCtYKadoSYiJOO+LlqzBeeKWqkRgUyTOWZ8p8S1tnvF0PbZmJcSRMW9K09SrTKmH3Gnr5Opoe0PTZtfQ2Z2/J8+yhQzNayQgfMilyVD637kMczqSz0z2C0Elrp2t6Wvh2vveeVfpdMsjpoWudwCluIGaCtbW7g56ujVfmDjIuVQ0x0iQQBP8svao7vbd/n8rNzQ3f9Ct+FU+fPuXq6ur5QaMf99hxjx332PH1hx23T5/yi77tVz4XbvxHc1BqrfyVv/JX2O/3fOELX+Cf/tN/Ss6Z3/AbfsP5Z37JL/klfPrTnz6DzI/8yI/wy37ZL/uqtu13fdd38b3f+738+I//ON/2bd/273yvZVlYPhQwdHNz478JA2F0eReiSMlO0hHtZC83kjE9pYMCdMc8A5FCMA8HC332rFapLWMlQ80EDcS0AU29lekkonOqJ64bP7V3FetVk/+sip0DvnzGiy8i8F22+A3aakZE/L1E3bSpFUQqQfwGEfHKr9ZKs0ogdga4UPLq8+3QDZTooKMfss02oXUilRrUlh3k7JTd4VLIVhuF3FvdsK7HM5i49t2rN3dQfAYm0jrj/tSNxuPDjUgafMGEEKnVfSJKXqlloeYVDML2VcweOoiZ+vvWhGkhpkYKBdVGGiZEQ39IeDiVWKPkDDio1XV1V0t1U6kQI8P2kqaJhhLUyYQpBfJ8JFqhWUOHLTpskJp7pyBS64qWlXUJTLsrTAfq8QktH4mXK0yXNIm9ypazu6b1e6GXNWcDqxMANIQQB4Ke5HruXnpi+LtXhuepNPHKR/qs2qtRd748mT+d1BXeCu5+Bf1BUdbSKzGvwOz0kPDbAroDZTzfL50ZUDO1cb7Xn3k2eDbJf8xxjx332HGPHV+/2GEf2tT+h46PvEH5sR/7Mb7whS8wzzMXFxf8tb/21/jc5z7Hj/7ojzIMAw8ePPiqn3/jjTd46623AHjrrbe+CmBOf3/6u3/f8Sf/5J/kj/7RP/pv/bmOWyRtCGn0XaEqg3iVYaJOFquNNA6kdJL3daKa4LK+ajTL5+yC02HgrcA4EIYJQkLsZCvNV81apfks+8MzVhFxElzNRDMIyV/HoHYHSI9B6OZJp3yOcYOIUPNKLRl6C1ab0VDqqeFlhpWMhpNzZL8J6VbTvfpz1r5r1p2ANUKItJqxVgimNKt9Jt2VB631m8qorVBr7iSyAZVA1RUJI2EYQTi3Y6mVUgtrHYB4viFFA9W8vV37bLqVQikLJS9ONpSIjBs4BVCZh3GpFiQsRIykW1StGwkZ45BozRnlHoLVOmlsdWmhGWHaOB8gDjRN7vio3pamVcpxIVil1MXPkzg50USI4wVxuCBJo5aVEJT1
uEfHC6/Qyh3t5ucgX2PDAySM/tqnmX8HnBOh7dz+t2fcA2vVPTlqhVO2a2+VmwSvPq3fU1YRGrW4jM96a/z0Ws6LEES6TLEWoIezdRmjqFDLKTAOWvOQNIleMVutrjagUkTOjpEnIPP8D0Fa8QfxRzjuseMeO+6x4x47PsqI5yNvUL7lW76FH/3RH+Xp06f81b/6V/me7/ke/v7f//sf9WU+0vGH/tAf4vu///vP///m5oZPfepTpHFLSNN5139OgQzJd8klsy4Lpx2liuu4m1Var3AkBmjxvMtHlCEO1Lr21qeDjOA3hAaXyLXTLBlDSk8YjZEQk89Y8eqCfpEkROKwQUSpIfTq58Sur95ulYCkbvajStCIiC/6mt0ASU8t395K9b/PTpBrOOPbmnMEDFroXgzmoV+KgdFnzV4pBivU4qmUqurW2+Y7aW8tg8mpqlTERfkEEToTjNa85dcM/z3tLFcEn9WeDmuN2laaVVLqFawGX/TL3CvZQIojoW1gTpT1iKYVOJG9HDxD/6wnFnurBU5VQvCqME4bn7NqAg2UdWWMA/vbJ2zHgbzMBHECWAzBHzgaaJo6ITASphGrBUmZfHyCjTuCBtpypOUFmWZ0ekhNF+eFr90N68M8BBHx+b2B4HkVp3PS6uIPCo2c9Y0hIWH0n60ZadWrPKtYtxun52Wczy+99dtqb6fjxlwq1GKcEkbp0toqXsFXKVCzu51qwjT0+XO3rz61lGuhLgvz3f4jreN77LjHjnvsuMeO493dc6/fj7xBGYaBz3zmMwB8+7d/O//4H/9j/uyf/bP8jt/xO1jXlSdPnnxVJfT222/z5ptvAvDmm2/yj/7RP/qq13v77bfPf/fvO8ZxZBzHf/vDpxFJkdbtqV1W5a0p68mNm+2m7z2f8ZuDRiqCxEQYxp4C2T0OUJRArSutLJgE4jj5DrPMlFoJKCG5u5/RNd/2jA3vO2Cf+3K6WCH2dpghKWDVW3bSd8qqShg3hDBQ8+wLZ0ioJvcNyAuyHH2OHCJxGFzmRsMo1Lwi5m6BnWANKsRxogFBs8sKNXrburflRMBkdDVA8yAr7QmUrTU04Autz+hrc5JTW/aUVghpAOmySHE/Bqz1iHp/fQleYdZa/fWpiEbi6OmmRqXmI3ldyHn1xM04sB5n6nIDeUFzIS+L3wtD8rRY9cXs6gEPz5rXmVYLadygaUSHCUkTrQpVks/HW2M9HkhiLPunpBB9zi1Q7EMmVii06qZD1tUSITFsB0qenYQYR7StsDzpYPMQ3TygErqM9KSq6PNfXL4axNvhnKolc6DwULjuDmmGExhnb3d3A6eSC6IV7WOHmlcMXPoYE83oLfvRH3TWU0YNrGUHJA3OwCc6cTB0El71Nre1FeLgVe4phZde8dZCqyvVPlpY4D123GPHPXbcY0erz48b/8k+KK01lmXh27/920kp8Xf+zt/hu7/7uwH4F//iX/DFL36RL3zhCwB84Qtf4I//8T/OO++8w+uvvw7A3/7bf5urqys+97nPffQ3l1M8uc/lpGcg1FpQgVarn/jutBiazxqbGRonN+fR1A2EjNYyrWRMFBl8Rtuq502GYUSCUEs9J3v2D4FhLnUU6TvG2i9o36WL7ybppjd01n0pxXFIpF9kB6lWq+dZhEBFSCEhGNq8BVvFPRsMIERooNFbzLEnQKzr6hLFced6/HVBrDj2hAC9pVlLcWKbukStFm93ap9FtuySPxGhlWftxYCx7p8Qhg1xuvDvW1e0VrRGBGdzA9TiTo5ilXK84XD7+Nl4wIxSCnldiTFRTYjjSKiGlgOhHGil0EohBmVZVlqtpGGkNYgpOEEsdIKj3FBLoY5CiBNVB9ABQ2kSKaUSVCjrgpaZVhbWsvZ4cENiPJ+TU5HXLIFkJ6N1O24VyOvMuHtIs4iVI4FMPb7vD6zhmmLm3IITcdHAVQCFatmvm0r/fplaM2gkRcVMyLUwH24wM4ZhOMtpS86IGNpbqjkv
SIy0WsglE6K7TyJKjAPWhNZWpFQsH70CjoPf3+pcAfLqFY+JB9ZZRSXiEfJdCio+f7cQIW6I4T+Og3I67rHjHjvusePrDztCfP4Mr4+0QflDf+gP8Zt+02/i05/+NLe3t/zQD/0Qf+/v/T1++Id/mOvra37P7/k9fP/3fz+PHj3i6uqK3/f7fh9f+MIX+I7v+A4AvvM7v5PPfe5z/M7f+Tv503/6T/PWW2/xh//wH+b7vu/7/p1Vzn/osNagnqoNgRh9F9fnZdKazxCDuxGKCiQ7tzvneY9KYBi3nmVRV1ptng4qCQJYzeR57eQqRYcRyys1r7TcKyczap+Xejz6aYbsIV0NEFGib9sJqh5C1o2YvN0riBmlzJ66aYJodOJdmbv1MIRhIKhXAa1mN9A5VV74Tezk6sK8LmiKpGFyb4FS+g4291m36/o1RGiNnLO7SeJyNBUn7klrZxMhBAdtVUJ1HoLl1ee75UioxiQTx1JcDceJRAd5f8P+6Tvk5YC3W0FUWJaF1gq5TYy7a9J0gaoSybAYSRr7eU8Gpu2pDdplic16EFoghh6lLkI1dzsNEmgG1YzQW891zQQa62FPUOuBbM5DyK25s2irLMvs0svawLqXQBtptaHRgXe+eZ/h4trb770tX/fvonnB0tUzxQJ4Wzy402VtTkwspZKX2asVHGBBWBevuE/cBBU+NH923oCVFcpK1IiEsW8cBkgbnAVg/hAy3PdhXXxODH1eXBEpNKtYXfus2R8WoslfoRsz+UMvQfBqqraV2p6/ErrHjnvsuMeOe+w4YcfzHh9pg/LOO+/wu37X7+IrX/kK19fX/PJf/sv54R/+YX7jb/yNAPzAD/wAqsp3f/d3f5XZ0ukIIfDX//pf53u/93v5whe+wG6343u+53v4Y3/sj32Uj/HsMKMsR0peSMPoMjlRrDZanyM2FMurz1db7fPESK6VfLyjrQtlnIjD6OAkigXfxfvoLbheIS9ETVjwnW1rjZL95ohxIA0TrXaWdAg0EZoIzjDz1mY2+pxW3WCnM+ctQLaKmtsyl8XbctM4OkjlxQFBxWWMtvY46z6j7NIvQQkhEVMiBiUBsh4xqz3IqXYtvIDouYXpEduVshy9Ujtp3Zc7z1agAhXEUIlOkOyujJgg9ehp4iVjNEZVhibsa4NxROrCfHvH8fYDWimErmwA8/ajJsrq5D9rBVHD1iPzB1+kPfkyra5ojKy5EOKCRSMO3jouZSaOA7TYHw6eDCvqngXOZPfF02qBPLvTpEClkeeZGKRXWQOqypJn0pCwNJHXTIqAFZeMDsGNjIoTxgLG8vQ94jQhBFC82lyfuFV03HZAD2BGkwrNkLMKwtA4oOrz+hDU2/RlJYWE9jAymoO1j/D97mxaIQ00CV2d0X08NKLxJDekt3Yr6/6AlcUfHurrIwbrVXjCzDNPqK3PuI/Q/SpAsXKk4Q9HqRVbj/fYcY8d99hxjx0fDTvK8ztQ/yf7oLyI4+Rl8E/+2n/LFGGejzRgO21Iw0Q5GdyEiBBpeUWku/TFCCF5S2t/QysziBvLhDQSpp3PRj80/9OuG4d4rrBaq0grlLwi6g55mFFLJg0jaXOBidKsYOXUNj45PCr5uKeWlXHaQhycZCaClZl12RNEGMetzzAxb6sFnxGWniJKHNAQHTyzZ1GklJzJTZ9LipP3rHkrMAwTw7BFNFHBw8FComHUdfZF2JUKdd6zHu8oefaArsEDslBB1AHOW42dTFcLtSwUg0MNvF9HchzItx9w3N/QFn/tNO2Q5ItK8e9jBiXPrKtXnBsr1Mc/Rbt5i5ozOm5I44Zh2EAIxDQwDiNKYdpdonGDxoHj7XvcPH6XtHvA5uKBz1JPbfxWWfdPiFKoOTvrv7dereTzDNydFwPb60duMmSucBAMTSMhDn5+aiVqcLCJLkuNaexkNaXKgG5exdLOge7Efu+W3qLRqzW6muJEUrTmJEKBJkqulRgG4slNshmtuZTUnHNI
aYX9PLNfVjaXD3nw8GPo4GZati7Mt084PH2fVo1hHEgxkYZEiN56xfxzWc2U9YgJxHHrRlSrVzyOFA52INze7fkV/8f/5j/aB+VFHPfYcY8d99jxYrHjdn/kV/yW//N/WR+Ur4Xj6ZP3OKiSxg1mlbun7zOOUye8+Xy3lkzJjZSiL0CBnDMnXXgp3n7TfqI1JpoopWTy6v4JwzgyDCP1LM1zkx7X9HfWf63uJ4ARireIwzhRKuS2sKyZzW5HiiONhlX/DBq6S1/LtLKeiVsYlJKp60wQIcQ+3xbtu+YAMSFxgNbnzm0FK0jDjX00kIaNV0sNzCrL4ZZ1OfSAtN6qDQkZLoFGzTPSOnlMhJC8WhEAA5HqRDp8Nt4w17pjIAMkJVrjIgmHQ2GeC8vxjrLOtNpnxXlGib212bleIbK9fMTUKlGVdHzCQqFoYBiVMHmgV15nwjhScgNrDDFw3N8xjI00QUgjaXNJSF6VukN5o5WFuhyx9UA5OTDiPgdOwHO2eVlnQgis68pUKz2pAjpLvyyzM/3V02ZLLah4JYrCcb8yDgOqCdEjkp+iYaBa8CwLM6QWzxOhUdalh3P1AK7Odi/WTixBgggimVyWrjRotLxQ6uKERfWoc62ZmA/Up5lDXYjTznNk8grLgdAWpBk2ZxgHVEcwpUkgl56x0efZGhN1Wf0zmZFi8vAyCdSaPauk/qdxUF7kcY8d99hxjx0vBjusPb/676XeoIhENI2M16+CBurxBlMhDBtC8Iu4HG4IY/GdHuIMYnNZmGpAhwmNiRCEMSVOzn0hRCx54uc4bhBVZye32tuvwUlA1sB6tLniBj/iDoWEALUSVbDovzdZqet8BoFWK3U+QHEGuoSBMG49k0ID1Exdjy6Qq57QaXn1amw+uMtf6xr3WslLY4guIax5df+CNKEhERkwVmortGJ+jsytnaNEYj8XdV373N2dIFUVq5llncHcFEhjIsQR1UhppZP6cDOgmv0c5kJduj+AGGteWOcDGoL7IoToEfQaaFZow4aYNkyyQH5KKntCcpvkNE3U6jN4K6tfg1oRJjfZWlfyuqLRI+1VvHqseSaI0eYD9XgHGLkWwLqs09D+nWr171xLpdbKcX/HdHlNbxmwzDMhCLmDTBzGrlgoiAlxSMQYyLkQozrY371LbBF0Q6FiJdOyV6MaApaz8wwMnwNzyoBxEByGCRrkZfFqvBkhRoolUpdS5FyQOLDZXLCbtlhZWcsdMlckjER15r/LHTOlZCwXqlSfDfcHJlYpZXF1SHMzrtwMDdGVCRK8bVwCUYSQPpoPytfScY8d99hxjx0vBjvC8PycsZd6g7K7eMDFxY5qBQuB6dEbWLcaFg2k1KV0rSASfWHHgIyT7+vD0Bez0sqRUlasgdbqoWFWaGtjnveEmM5kMmcqu17f8xLiM1UA3b55ORKOdyAe9KUaOBz3pJRo65Fi7lgpAtoyUcCimwe1dXXHwggS/Ub2CsjZ6Ianifp9URBxd8UYlJyb51WECNSebnrXP7e3iTX4jFyEHqKVWPZPWPdOBHNlQZ/DmpsS+ezRnQmtnwPVximDwzrQWKd0aVBSaFDmnmzplkPVDGol2+wApJ4nImFAdGK6GsiHA3XeQ14xDYw9qE1jwqi0PDu50Dwoy4AY3BOBvEBylUXLC1jxxbMcaOuCjKNXPnn1Vn4novm3wn0ExMmDy/HOK+eYUBplPUAafP5L7t4PDnaESCAhmjqrP/d7Rsi3PwfjI2R8naqCBiepWVNXVvS5/tmh0ejptG6yVPJMrYUgXcbIhK2rV9Pi2SLaW/NCgyAMunECX/GqMY4Dms0rmW7z3UxIuPyzlQJtxWompBHUg9gArBnLMfv9kN2DwmW26Rdimf8XOe6x4x477rHjxWDH8/vIvuQblDUvHlZVVrRViANCJR9uKaVSU6Lz8oHCuhyJwU+ajjtER68Ujndo80huRIlhwD0RCgLU4r+602OfKavTy5uJS+KsEYcR
NFIAWqbo0duc3g/seR0jRZXYII4jec1IDWhwxn6xSsBJZOu6x1ZlSE7Ca6USpCLBjYclRG83mksNpRs0gc8fNUagkZfFnSO7gyUCVoxSC1UWn5CWlVpWl0oatGFyeeSJAQ6cTI3MzD0NSkFbA+q5Ld6a+xioKkmMjQJEZJjAOtu/VqpVUOvBYXSy2crTp4+JQRi2r8LhfWS9c1tn9eh4bY05z+TDHbk20u4SSSPj5YZSG0EaYRh6ImfFamZdZzfK6i382FvjZl2VcTJwMqPiD4SYJmpZWeeFNBq5LIg1N2aKnkaKGaUDTcAlfIpXxLVVbCmkYQtaKcf3iTqg2zcxiWg5UJuHfDWUOGwptVDywdvB9aTycPfQUlb3imwNWxdy8Qdrutyw2T3wa7YeqMcDrbjsr9SC0AhtouKW5mJAit0IyhDFFRrm3QEN/t2Nbsuu6lVzlz/k9UjJMxIj+fjMQv5lO+6x4x477rHjxWDHuvwC+qC8yGO+e0qoM6pKyU/RD94lDgOtVpfPlYSmkRgm5nVl3h8Qq8QYGKshayavB29TJm/ZBvHWZG2VVrLvlGOirsszB8gQPMMhJlDfFacYCJ1Ep6LkdcG6qVDNq5OERInDgMbA4W7vUjdR90uovbWrEQ0jRp/Z5Zm1FkSsp2u6KVJpRkju4yCqnVTX9/L9/4uAldp16KdqxWVrpdtGl5a7AZH5zDIquWRy9pm2dK+Ik/xOmlsi19YgAQq1rogkLCRKrr0ycR+HB5sEc6Za7UqBgRDFHTTNfSFqcwvm4/6WcXfBuHuEbHYkO5Lf/yJG8dwIt5qEZsz7G3KplNbYXj/y+XoaacuBaA3LM1RfHBKiz32TG2tpMJp4oipUcllJQcHUgVNdfRFigmGgmHRjJP/uDZ/l1l5FxTQgWqH0PBExUKXkBctHYhixujB/8C/RNdPSZQdvdVZ+VCwMniOyzOz3ezREpmnTz091dURIWMmsrUCMnXDXX0M4k/TspMoQzyxZzTy+veeqCG5LXdeVkty+nVaIgj+sqNSysuyfcLJsl5D6awghTYgV8vr8bPyvteMeO+6x4x47Xgx2hK+bDcpxT4yBIUbWdUEp1DoS0kQctyx5QfMNa8wc1uY7zXlPQFjWGV3u0BCYthekcYdqJASlBSUixDTSrJHnPXU9+gru82ONbottRHIprGshlkIM2RM8w+BtcPWLMi9HmlXubm6IMbHkhcc3T7m8fpUw7UhB0XmPpMQ07MAay3JHHCZf4LV4WzekrktfMRViCC4DDO5WWHP2CGz1mz/XGYoDZOhzR8RboyFEBg207DPY1ly2Nuyuve1aT94G7kSY19V3+cNIrYV5fkIMXmPGUQmafHftzDUQYTRjQyELtDS4fM58Tl9KwYIRemvSaOR18Ups3CKb1xg2N+T5Camb/2TLhBgYt5fU40wM0QGrFMaLkRqjk/06WaxhZ5CVMCD43BcLiDYs9nPYvPqkdtfEfs5jnxmbRscYq9RlpovwiCGRhkithiZ3vJTiMsCURiBAigTZIhjzky/CxRuUcEXSgFEZgkFvxZbqM+w4DFivYmWFGAckDD7XFUGDJ7Lm456yLKTozpWezusPBk/J3aBpcoVKnjGMNG5oDTRkam3Mhxvubm/YjAPTZouEwvF4x93TtximCzZXrzoprlSGacMwjLS8Eg4vLwflHjvuseMeO14Mdszl+Yc8L/UGZbp8xPjodebHbztxLQ3E6ZI47ri8fkgphcfvvY2oMl5fMU4b6rxDBW9Lznu2w8S42VJFub15QqmZqwePQCJVPTY9l0wjEGMipMS8rqylkkTcrKcZoNRSyMcjIRwYt1dICE6OEoNWWIrP/YrBxeWVE83KSq0LJfl8MwShlAWsuaFOnymKqt8YaXSpnDj5rJVKNYjJDYOcpFYcAA1v95m6tE/ETXTE5X6lFELwWS6itLbQqhBs6uxun99aq24wlDy0TOOAqbDf3zIEZbvZOmlNpTsPeks7TBfY
upJKYyJANZZaWUt2K2xVTu6S65pRFfIyc7i74TJtqNMDwoNPkW6FZsa8vyW0ypoLm4tLpt0lEjsLfRxJ0xVaC23dU5qbLfUkK4DexlZq9qyPWme0SwFbzd7WFPe6UBrU6i1PDed2vYq3Lk9hZqVUcs6E6FVPTKOrKHqiaRwiSCCOmz7DnZnLAbFIlUSuXuGFWBCENG0YNzuXJsbooW9yyjFR0rTFrNHWo1uUi6AmXtGUjMaBFEY/vzF2hUAkl0apJxJmdSJnDOSyMC8H5uXopEMTIqAauH74KWRI5BZIOhACrqhYZ3fe/EjT5K+t4x477rHjHjteDHZY/jrpoNAq0zRSx8RxNtbS2I0TFoynT98nxoSMA9PFIy4fvsbxsCcfj7TlQFmPPH38Pj9/mLm42JBr4d333iNq4JOf/AY2u8uzZXEad0zXryB4vsRme4GZB2XdPXnPteyq59CrvBxYSmUojRi7eY45IU5VyfOBQZVNSrTic87WszCO+wMiAZE+Wy4urwNDw0DLHufupjlgIWC1UKpHnPurGHmZWQ9HD7wKSmiNNESGcUeVSGmNkMCsktcF7eqFtTRaSKQY3da5eMiURm/ziSRqhRQGHjx4DVSIEhwQa6VJgGHiuK7M77/LIG5TvYmeG7KWigApRta8du2+Z6asqyeUrutCo5F1wq4+gU2PUDWG23eoj/81owZvw1plHBPh8hFcvMndIhyfvMvFbvIqpVeNpXoY2ylvRUPsM+audugKjVZdSUFvZ5ZSaHklxu5t0S3ET8AI3W66ZbRF8lIJQSnZ29S5t1ejKKa9Rb2d2LXCscLxuGBWqSJuttTlqrlm6urZIYgQhw0WB0o1xs2GVgpWFqiVEEfcadQlrhI9bt5UQAN0oyQPGhvcrKt60qyKoZrYbq+4vHwVggev1VoJA4Q4sDa/twb1B2nrmSulVNb88nJQ7rHjHjvuseNFYcfXyQZF1xsO734REyONA0hgXg+M44b57pbbuxt0GqkhsNZCtEpbj0gtLPOBu/nIO09ueCNGhhh49dWPMR8PvPf++1wbXFw9cDmfTmwfvEE+3rHMTwiqxGGirkefb4LPFWvlOC/U9UjMK7s4sbt8nUPOzPORkGCMMB/uuFu64yDG7e0ekz0xBIIUNtsLNCWsGuu6snSzG5WFcdq6ssAKLbp1tcQB6tJnrY3jcabmhTCMVITWhFGhirAyU/EqBVFnfMcEJKbLLZO7EtBqcYmZOrkN89m1ywO8kojj5Mz1ZqzzAWsrRQauXnuVNlZWc4+DljO0whgj6wBa1fM++gw8hOiVXIgs3WwoxsAwRFo1dAyIAsMF4eI1NB+oyx1BGiUOtPESq4WnT38Gu33MdvMapbjttGdC4PNzdVWE4FVVCLFr+I1q6h3avKK5oHHAK9uMSsaaZ6Z4CJq3fZ3816jNQ+Baa8zHA0NvrWsQb4E3z/8QVQ+iq0eCXhB1IvaU1JQSuRTm/R1WM+OQHOzEmfPWKm05sJaZXArkY3d0TL2VDzFtyLWx7m+666fzAERPHgnmM+CUUBXqeqS11WWN6oZUQYNXsuKz5zEmUpw8fbWHn8VxS8sLw/jRwgK/lo577LjHjnvseDHYEdbn9096qTcobB8g4wVjGjnU9xiSkoYNm6tXOBx9ZtZKId88JXPDdrtjmWeiNHa7K6bdQ775869iLXN78z6bYcOyFN7+ys8R4sSD174BnTbcvf8lbt79WSwXbm8/YJy2TNuF1irj7oKYfNd9uHnMe++9w/Fw4JUHV8x3T3giAnFgunqFVivExILywVe+RIyR7e6CWir7m3fBMg8fPPSbV/bM85Ht7prd1SV5ySzzkXK8o5VMCoFBA8M4MW22HG+ecPvkPVozxjggquymLcO04+Zw04OgBhruLBkM5tsP3JlwnHzGGlKXja2YVJoGAuK7/37KY/KddOmx2nVdOM4z2YTNtMPWmcdv/Sy7i0tiy9RSUfNZMVYYxHyBhoDVTBXzQLUQ
aMFnnyEoMW0Iw0QoHhWOFSREdPsIuCaY9Vaqez0c7p5gCKVVcukzbxWq+SIt64oCtTVSiGiI1CoQGi3PbiCEeUS5eHUoKSFWPW8kBQ/BQs9GS1hzg6acKeJ/Vktm6eCQ0oBKIY0TZV0gDYQ4MD/5CiZ3jNef4C6PaM2YroQYSUExutlWXju5zmfbPj9WxpRYW6Bh5zGCiANgWWcwWHIjWMPqQgjenqZlcnHfgyG5r4E04/bpU2peiTGiceMEP1uxNpDShFQ3hSp58fsgO0mxfYRZ8tfccY8d99hxjx0vBDvW5fkjMl7qDUrbPiReXHl8+eGOfHwMROK2IK3y6MErMF5w9+7Pc3F5zZgGxCrL4Y4wTYzDlsvrK1JSri+33Dy5peTGqw8fUI63LPtbdkNi3T+lHQrbrc+iYxopy8HntZpIw4bWjOOyMAwDQxSGqAxDYj3eEccLsgE0tMxcP3qNtL0k58w0jiRVttsPeO/tLzKvR4oIadyQhg3bzQY1Tyodxg1pHFmWmbwcOTx5zHEpTJtdZ+cLx8MBSysxKHd3N0zmbPEgICrU2lBp5CYUhFxmohqDGGVutAqtNdb16KoDdQ8HFSGXStgfaeYzX6kjeV1Z5ztya1C3DCrk2xvW+RYRIdfGtLnAIztmhlohbmi1srZGCKlLGo25ZIIIQz/H2lNejYoVzyTxdFRF1LDSWGdXUqzznrLO1FZYlwUz8aqmFSK4pLJkVCMqfXbsNkdOEmwOpM1Kr1YarTYHCo3ewiQivZrwaHPz8xncYCnh57pZw8xVDyFEZ8mLdj+IyLDZsdx8QHn8RWR8k30OpPng5lzQPRIU6Bbqw0hKkWU9UpaZEDxlNajQivsmwEpdF4YUKGkiqj/A5tsPaHkhpIHaCtS1z5N9Nr4uFXQgTont5pLalrNrJxKodXWX0+AcCjNYS6GWTK4vL0n2HjvuseMeO14MdizL8+PGS71Beffnf4b66CFXVw/AKm+98y7C+6Sv/By05q6AaaIeHiP5yPGw98TLcWI+HtiGwN27X6Y2X+yDGEvdo3Ulhcbtky9RpYC5vKrWxoNXPs48zyw377HbXVAatHlm7otyd3GFWGEcEmnYUkvuN+/K4bjneLjj4uoRF9evolfXQKWuC7vra5p9DKtGGLakcUQFDoc77p6+S6uVi4trqFfEOKFp6C6NhSfvv00Ige3FlWvY1z15PjJNG25un1DN2O52vmvOK3neozGRNjuG7SWqnjyal8Xn12lC8RZqrc2dBGluPoVxnPekYWSzuUDViXmH/R15fYJcXjE9eIVaVqRlYkisZfXXCQFM0XxkmwaswVINqvtMeLAZDHHw4LIQaLWxVp+dnxw8Q4zUulDnmZwX6nzH/vHbaDkiBstxz7jZsWZ/f+kSypIzMcIKZ9voVt10yclk3lp1cysjBm/rEqKbSgWFXDA8nbXWepZfavBq8RR7H2Jy34dWqMXn1SUvxHGDXr2BLTPr3WNogeniU4wp8fj99znWwuX2gryuTNttVyxUSoVqynZ3SW3Gl770RaZx5NHVNUGEeT4gNFZLpNF9Cg4ffAltBSuVSnMPiVLRFB2wm5FSZBi88hMRgg3EYQNhwF0hC1SvHjUFRCMRIyYjt5eXg3KPHffYcY8dLwY7xik89zp9qTcobd5z++7C4b132FxcsHntGyn7x5TDY9LmijRukbZy9eACQagY4zgw58LNceb2cGCMERXjvXVFNdFa47C/ZRwGLneJdZ55eveUh5cu37t9/yvUdWVd9uQ8M44TpVbmeWUuhe1mQiURpwvCZssogTo7AE3bS+7SU/J6dCZ19aAoNbeW3k0b5mVmbYXbJweuL0bGkBjThqUcePuLX2QulfHiimEcCDGS1/VMbhtjZJouYAjk+cihB6GNmx3T5eto8Ique1a6k2BUirlErhIQzKWFJTOmREoDAPP+BqsL08UjLlLoLcpKtcK4mfj49QNoQjZje/0GN7dP+Kl//g8Jrf7/
2Pu3UN3XPb8L/DzH//E9jOM8rbnWPtTZiialF1VeNI0J5qK8sqRpaFRCGpp0JagBKQLSoFRUvFEaCdq3DSJ4q4gEbwSNneqKSSpJVe3zOs015zi9p//xOfbFf+yt1RFdOztx7oXjgQVrvu8cY8zxjvF83+fw+30+XJ4/W46yzQYpQc3HhYWQEjkmgg8L0CcFfliaRU4kv2CbpRBEsdyFxrwwH1KYF0gQgqE/cbh7S6mhXa2IMS0rdyPJSiz3rmIJlhQ8xIXkqFgMqBmWu+XHr5Xywn9IZISSaG0edyWJqBQ5JhCSLBYuhHpUrIe4HEnHlJEqLXwIubQPZgQyLNI1WTQU56+Wo+Hg8CHQZwNlw/XFM6xRPHzx2UKazBFrK1xIRL8U2qUwYbWiaVb03QkhlnbI1XqLFIoc3XIs7f1iPk3LLlJJhVGS+Mh4QIiF7SAF8bF7wPvManOBLmpIAbFo8ui7/eIBkZrSLNX/xaP2/qs4nrLjKTuesuP9ZIdZ+s+/1PhKL1BUWVOuNmhtmLoj24srYmVRc0O92nLqT+zv95hSc3b1Cnv+Ad1pRzceSLYkzjM6Q9ms0a2krFtsc4YbB+bDLaumXaBE7YbNZoO1hq7rmNOyG3LztGjGlaLZXmDDUoCW/YRUBltuMUozx4Bzi/a83Zzh5mKxjpqSpn2GUIYu3hDnHklk1Wy4vLikP9wS/EhaegAYXGB0EVVLooukYWC1WlNVJcE7dscjF0WDkposDat1vVRVS0UYjhz7E2Vd0ay2pAy6rEiiIKdMnHumaVow19GTYyBJyHGELFmtzxBCkVLA+6U9bhomcvJLH7335JzxRAICYRSX1y/x/XEpqNMaJZfistktBWQ6ZzQJRF7uucmEGPHzSIgOlRdVOFJiq5po7AKuio9FW1Lhpp7+4RYjl12b9w60wflEoR/5CCGhlXq8j37UtMdAXPaFj6hrgdaWlCMITRagtF7aAXNCxeVzLW4QBSSkWERxIi/33xJACIz5oRfjhzpytbhbhCS6iSz2FFVDfXZNOO1RVhBsQwwzIgWk0JSb84VRYSymqAnHHWE8cbhbCgEvr56hdcEQHCl7yrIi+MA8dxQKXIgE5yAlrFGkeVxcMHnRyEu9HI9rISBmtC4e32RnTsd70mG3mHWtxSqDbbf4EDBFtWC8VUYU5fsNgJ9gPGXHU3Y8Zcf7yY4gxP/y5PyfjK/0AmV7/RKdPSkGUpzov/g2WSyypcJKCjLPrl8h6w13fc+ZC/iQOTt/hht7YpqXX1ZbLsdnRYkqaiyK6eFz5u6OhKEqG0SC3d0tMUW2589RuiSFCRdnjqeeq/U51WrL3B/ox56b28+4EIaiKJmnnmkY0bakrBuktrhuj4ppOapTmnK1YYwed7iju7tBmeVuc+g7TFHz9u6eaXa0mwvazQVSSNzc8fLVR0xuph9HqvaMer1FRM/9zRcUhcWYAj97wmnm7rhjdXZNYRpyjtw/vKOstphyxTjsmIeOZ1evCF4yzMMjMVLiXcTFA0JCGEekhKquFwBPVmiRUUISYiJnwwKpVDy/+gC/GpinxYCq1RJShTEoBTYLTEhMajGEDtNIijPjaceFeHRtPIqmtBQYWxCcxo+JkEfGYeD20+/z9vvf5vyspaw3C2ES8UgIXcBM5ETwj8bYxV2PYLmXTimSbaRqVgil0NIu9+aPR7E5s9g4cyZniTaa5aD3h96S5ShZCPEIi8zLcaZYIFhk8SMOwrJT8lRS4slIXS5cinjCsyWmTHfagX6GrbcLQ+ORximtoWpXRGGJPuCmCU+P90t3Rsrg3VJ8FqRBVyXt+dXyZhI97niLn3pSlmiliYCSGmEswQeErlBkiqyY55mqsIuNNPqFGNqcoQqBVhKjBEoVTKfje5v7P+l4yo6n7HjKjveTHejqS8/Tr/QC5fhwQ8VMu9pwfn5JnDqUqjl/9TP46cjx4S1FbVidX1OsHKRAnDqE3dLPM7NzrOuKBBgE8+HAeNgR
osN1BwKeen1B1x/YHe4WhoAbKbQGaYhuwmpL01qCG6mqgmHsKMqKSpZk34ESGGMI2mONfuwhl5ydvaAbjpzGEyJD3L1DxkA/9HTTiDEFVVkxz45u8tiyWN4YZOK4e7do35Xm+x9/jI9hUb0nQXZ78tRhCRA1LgdSDkhhePn6l8i2IZVrYvCkNPDw9mNCTBS2xGrNcXcLOZJjwMWAqdcUq4KhO3D37g3ROySJ9WpNWVQgMsF7jC7YnR44HgbWreby7BnV6pKiKEkpYqt2KRAcDvh56dVX2uLHkZATkrywGaJn6g/InKmqGh7NnMtE9cBje1zZgDCsr17SH25w7kijFSSFMma5042RuqqWdkCW6no3+cWnkTN+HtFaLVXmpsDYBcgEPAKWFlHZDzHg6rH4S6pEjAGjLSkHYg4Ya9FSLTZWJEKpRwvuQpYUj4Y1pfUCS/KLhM40W3zfoUJPW9WLqE5ZVFGQUsZYQ8oeER2VVUw+IgiM+x1uPFG2W2RVY6TAhRktBa474b1nLGukqWisAaEXlbxQZCRIicuCcfbIlDBiXoiYAkqjl91aUQIKYRX12TWCxOHuLXnuCbFncl/dItmn7HjKjqfseD/ZcTgNX3qefqUXKL4/cXFxjrYrpnkCqSEJ4tRxf3/LNPQoKdl/9i1ks1k8E8wM+7c8f/ENdt2aMPV0/Z68v8WagmZ9Tt2scDlxuvmMlBLrs2dkqRl37zj0e+Z5xBFIytLUW/LwgHcT43FPqR89CMnTdSeEMNiiJOQThdC4x4KkVV1zeXZB1iXjNNI9vMGPRxprUULikCQ38ezZc9rVmtJonA90pyP74wFETbNaI6TAhYQpGowxGCawlqgsWSh8ThihULZECIUPAZ8iPklkuWZNZDjc4oYDpmnBSwQBoyVSV5Ad0+mI1YaXV89wfiYLgTX2Uc3+Q+iQpWxWPH+hSfPpsdVtJPjMcXeLGY5U7ZYcAqdhJMXIpl5TpEwlBTEKqrrhJgbmpbZw6cFH/qhNLrIo40OMxMktNERgdf6M40N8BBbJhd8QPVJIpvkRf53CI3NC4IKnNAthE6mQWuOGDqixxiz8ihRIcbFwSmNQYgmd6B3iUVSmtEEkkCkjUaAFwixArYV2ubw+UiuUhOwfq/1jeAynRFlWi+6cjlOsePP2Hentp1w//5CirEhKUtiCrjvQPbzjze7Ifo68vrrg+vx64SawEEnLerOwC5xD6WkRiM0zu+GBQlesNheo7JmmCY0iS4m1anmTc4GQEmVdE6Nfjp+lwkePDAJ3vCGnjAienCXzuAP15bXpP23jKTuesuMpO95Pdrjxy5+8fqUXKOfbNUoplBQ0ZcHYjeQw8/D596iKAlNUEDPe96zPn6GrNfXqnN3n3+bdp9/i6vUvwuaC+7efoABtJLUVDN0tSgma7TlxHjjcfIpZX4OtaS5fYVdbrLT000AQDl1tiMFx3O+ojeDY7dBFyTQMWFuRqZG6ZpwGtAwQJqbTgC5aivYCYUqqi9fUZUMaDqw3BQ6D63eknHg4ThQy0a43lOstz5sGZSzKVhhjcdPM2HckIqJYgV3TNBtcf0B0DxS2IAtBiiPHu3vefTxz8ezVEoYktus1buqRMqLwFGULxIWGqQxVISmK5fiSeWkXyzE8thRO3N7esF5vaIsCU68RVYsLEecm5mlECoUSkrHvl1a5eo1SYMqabVGSoicKTaULunFEywJtLaSAhIUwiVgw2xHcOOHGjvvPvk+/v8FYhUDiXMIUmhzzI25b4EJAIYgxg1gcIzlKspKUZUkMEWPtYv2MGZuWFkJdVJB+KHqTJCTmkdaYkiflpThMKbsEalhonEs4LkwCIUArtVTxA0mmHwnXhBDI5IASaxdlfbVa8ezqFUpkYhxJc8LPIzOQssKsXvILr/4xTFUwTSPlowDMh0CKGUR8ZE94vJ8Zhp7uuOfd7Vsur1/ykZKEkFFWk3LGNC3WFAzTCd1UrNoz5uMdcTqBXORt3o00
dc20u2GaJ4TSWFtRGM30FZYFPmXHU3Y8Zcf7yQ5rvvyy4yu9QAk5MvqJ+7cH2qYlPfoO3HjAxwJTtghdMs+Jjz/7lBevv0HTrAGFITPs31HXG148+xCUYpo6nB+YppkQA0VZUTYtUVia85eU7Yo0D4z7W6Zuz9yfmCfNav0MXdbEsWPoDmiRefvuDUIuhsnNVWK9vmJ23YI6ti1x6vFjQFpPlo+iKx8JLlJosKsV0zwSXGYOEze7LzC3lrYsuLi8oN6co8+ukUIyvP0MKRLF5hUnN9EfblgB67KklGeQIyEEZM68OF/hnaY0DqRkFoGQYZgDaLMUxQXAe6I7YYoCoxX9/NhmWbdLEZnUzKcdvj/R379jfLhBPXuGnx0hRoq6QeVMKRcHiLIlWahHvLMgpkD0jjBPTKc9QSqKiw9Yn39AcX6GMQYRHahiOVaVGpRaPBX9iWHouL/9lH63tGyasmIOCWkEgqVoLUsIKZIEIA05B6wyhORI4oeq8gVjrauFb+DdopYPCFRZoK1adlApkuRjcZdQFIV59GAsd8wheJTSGG2WYjqlkHnZySEW+mbyy711SkvbpFTlAoPSFlnDqT+gqpdkP0OOSBJhGhZ+g5IU2hDmnr7fP3ZXbDnt71ESRPLLvbldMxPRZUVrC3RVc/7BNylKQ2nsglSXAu8WsqOPmbpec3d3xzRNCD9BChhlMM2aar15bIfMmDYRskAbi40bwn73/ib/TziesuMpO56y4/1kR4350vP0K71AmXzkYltRNz88glzANkW53H+Vq3OmlFF6Rk4DX3z+fS7WK6yWRFHjTveE/p6iXIGQaGNQq0uYI4d3n2JODzx7+TVcgPnuLeZwR1mvkKZEFCXKe5SpmN1EYyx122IrRQ4T3jlyVqwqi9+/4dDt0WVJub7EpwAi4YY9Eo+0BfiZ4AOyLJnTorGutmeMpx2mD1xePKe5fEnwA32/x/lPaOaBcv0MXbX084TUivPNc6xI0N3ifIe1FcnPzMOB5ANKLD33x7FjDIG6rbDKIEUkZE1R1uhmQ+j2uOFAihO6qhBK4f2ACCOlLhd1e1GhOefr3yiYxg6zuebq9Tdw44CbRpTIJD8RYiIm8PNEyhGjLEN/IsSENRLKBvvsI1S15gz5SIUMi9hUasSjDn5BbCv8ODLc3LO/ued4POB8om4TUchl9xIi2tb4FNBKL2RHJdHSgpLIKBA5QhIoswjBIgIjl8IuUkbkgBI1cwyIlEg5MAePNZaUQKnlvjnGTGapts9CAmrBPwuBfhRu/TDstLWLtvwR2CRY2l1V3WCaDTZ1zCIh6y2ayHS8Zxg6iqLE1mvC0EHOVEVB8CPdcYexBmEsw+mIrkpEtWb2gskH+mGgMC3nZYvO3YLo1haUoLAaacwiqZscdaGWn5fdUDUtYeqRUjCOI9pYjK2QQqIe7/Vjcpz6r+4C5Sk7nrLjKTveT3aE9L8TF0/VrpDGIGRGS4UPAwKBygGhBFpnbITp1FFLwel04DQd2K5WbDdn3O0dQShqUz7KzyLu7i3DOLBdt0DD6dhjbYH3A0kahO8wUjKflgmchcANJ+YcCTHg3AB4otKMw8ThbqRtSnzq0OOIOZ4o2xXWLA6DcThiYoMuSmQ2nLqOslBMxwfC1DPcv6PenKOshHiiWZ8xa8N82iH2dwynE7pomI47wv0b5maN9xPudMDa5UhVao3ksapaCsqqZur6R27BSIgeN88kZqYwwrhGP3oXog+MIS/Fb75jPu7RxqBNgbElxkhsodGqQueJ4+ffJku5TFRbkv2Enz2mbMEYkAXGlOR6TXQOEUZE2Sz2T+9Aa6xsyUaRXEBLuRTxxaWdMIbAfveWzz/5Nvt9R8iA1Jz6CdOeMbtE9BFVZ3J4bNUTYik6E0sgaFOQ4mJtdQHqtlp8GRLCD9sBtSWmhEiB4GaE1o87K4WUkBHL82L5f6VLtNHL7gaxVOjnvNw1S4mbB5KUKKMQ0T96MjTBTajgEcpj
jGLq9zgTF8NpiKzOXhBTYjjsuL2/oVidoZKmaFoKW2PbNaY9Q5R73r37BNF/RlM1xNDzcrOi7w/orBZqpExI5TCyJD3eHQ/dga7vsGVDXRV0hwd2tzeItIC1VqtF7T7s7yitJQOIhKKgNM17nf8/yXjKjqfseMqO95Qd4cvP06/0AsUWNVlLssgEabGNod/dMaZEu265vb+nKEqUzBRlgcxrnFuKl9zcL/eYL75Ojp7j7RdMh3s++cHHvH79AWRNYQxoGKcjUhqESMzdnpAzn3z/Ey6fv0QVBUgYux0hJ3SMKAVDd0SZClu2YAum3VtUjkRbYTT0fcKHSNWumJ3jeHjAy4oAhPmEYo+MjiwTw+kBNznWz14hVElhLXK15vjFD9A5YooaNzmGeSafRsZpojYgtEFKxXA6IaVhmgaUFHA4LG19MXLKAW0UShdIpRj7nvl0jzIl/anDZ1ifnXM8jUw+UYiJwiyCtLIscWEhIBqlUEB/3NNsNtj2jIfbLyiKGlk0iGGgVCBjopOK6uwZwQ0c7t5wfvUccsYUFTIG5t07RFVjyjUJgUAiJaQ4MQ8dp4c7Dqcjsw8EEnMQ9OPI5jKz23fYoqKRlmHsiTkvkyNnsvBoYx5V8QVam+XYNHiULsg5AAJdlPjgHzFk6VFj71FFidAGJcWjrTYTY8aYYin6I0P2i7lVgdaP98pkjLHE4FA/rOpX+vFYXDyGcsBqicKRs2ccRkgT0S/9AJBJquCz2we254J1DDw72+LmiZTuON6+491nP+B8VdOkmbXIWJ+Yjm84HTLr8+fL3fw8s7+/w80j1mh8TFTVeiGPHvfE+UihFHW14th3HO9viHVN8J4+B4JzTGOPLCxDsO9p5v/k4yk7nrLjKTveT3ZM6cs7vL7SCxQ/9QyHbumdNzXGlo/FUI7j/Q1CCKbptBzHSYnUmkYpMpGYM8JNnB5uCDnjxp7htOdsuyKFiawr9qeZ6vw5rqkJISKHO2ScECny/IPXKGOY53kh63mPsJpD19GuKsqqRemaanuJqlqkrpnGPfdfvGG/P3FxeYmQiq474h57/WUcKW1F3y0QpVN/AGOJ08Q8jKT7O4ahZzodf6heYH/oWK/XiJwwRqJEpnpkAySREVrg3Ii0kWw097s91trFnnrYo1ReVutSk1Jms27xLnDovljcIGVFdgue+/Zhx3bVEKuCLgYuLy4WI4XWzG5mHGdSEpTKcux2ODdy//CA85HNZosWGaMkZV3jugeGeX78/kdWF89Ylw1xHNjd3HD5jZ9bkMlK4lMmxkhyE+PhARE9zi19/lJpujngQmYYJu7u7nn14TfwPjL5gI8JIS2FLRBqmVSPB6yEBDlG5jlSKoNIacFKI1DKLDRHlnteJZZ2Q5kTZAVSUBQlfpzQevmzlILk81LBrwSZjJSSGNzSTSAVSQiUKUFZTHNGJpPTQm2UcaZqSiZRUzZn+NOewlrcPGCvrtg8e0kYR5IbON2/4f7jiWpzTRYRTkcum4Y4eQ5DBzmhpWXoexCK0d9QVB3Huzu0URRVwXEakAIuXxTcvbkhKUvbriFHYgoYKfEpMvRH6qphnhNd1wOatrxA2/a9zPt/GOMpO56y4yk73k92mOrL58ZXeoEi6w2Ftfhuj3Ij2hpkUSBtgXczQgrmcakG/yFTQFmNT4L+0OO84/jxZ2xevECnAM5zcX5GUIKQFcfDDtlsSNExzj3pdEBljy0KQoL5uCcKSVnWHE5HUk6URjO5xRDZrmr6mGkJrJ5/yEZ+BM4RoiMrTZIaUxQoWy1He2SGrqfbH7HWYMqSuesJMVDWBfPcEfxI9IlsLcponE/MzrPZNEhlSH5CWwmJhcRYr5Zq9ATCFKyeXeG7ASkyqV6Mm4fTADES5gFjNEpK1uuWsig4zYnvff97rArFxapCVRZjDElrTuNEWdUotfAAZp948+aGwzhTm+X7KYRCFwsFURaGwXvefOsNVW3AlpxOI+PHb7Hq2zx//Rmr
VUVz/oLkAzk4xrfvyLZlTAI3dfT7O5KbePHymruHA29vd2QVsEXJ7CL3+57Ll4E5BOYpoLTAx0jyAW0qBBNhXo7xF+Pn0ubn3YyxBSknRAggBYJEDAFtCmxVkGIkjB3KlhSqRmaw1iKVRKRAjjyqy/9HXkFiCbUQ03KPjCJLDcoslfzagpSIpIhhIrueECUnt0eGgbGXxOQR0wCJx5bGwGEYaYVBynt2p479/kC73aCkICaBiywWV7tdWkhLSwozLnmyKKntiloX+Oj5+JPvU5UVZ9dneNPSzYGHt2/pjntCTEDgbLOiNJZ3X3xB256hbcnu+Mn7DYCfYDxlx1N2PGXH+8kOE7/8PP1KL1DuP/ucD7/5DY7ewTRwfNhR1hVVVTGGSFGW1Os1xXpN9gu9L2HQSkIpKLZXWHNHU5f4qUO1BZ2byaZECo/UMNy+ZYoRrRRFWRI8zCEihcJsLjg/u8QNA+Xk0FoRiQSWfvLhdMTUie999/eorEVLjVECEUbm5EEZplOirCxaLsVRfjxS1xY3jdTrDWKt8PsDLmaG2TGeRi4uL2k2K0KIXF5doORCP4wxEN2Eypl97yhKi8wDImWqtuTYD0jv0Upwe3dHU9e4OXLZrig2a2S9pa7PcPOBOA+InKg1TN2JaexpmxVNVaOK5T5R25I5RPr7PUYKqkrzwasrRFFgpSa5kS6OxBCQCkq7RgvJ6spw+/mnvPr6NevtJUPXcXsI/I0/+Jw87/in/+l/iilGLl7/DL7vCIcHvvO9j5m7HVJkpm6PDI7zpuK479nvJ4pGcRrv8WGxhAYXmWaPZZnYUmqyVKQgFuW5XrDcUrHYSHxYjmqFWPwTadnN/JDmmBNIIQkxk9yEsZacQUvx6OlwS7V9CEBeOAw5k0NEqUUWhw8I+XhwmxILarIgpciCklz2Z27oOBw6oh+xukBrgdGemDO6sMRsqFWDzIl+v+PYnbBl+dhJEBC2oT1f0U0zVVnznb/3exRW82y7oioKNmcbovf0xx223dK0l8TjA+PDHWaTmY8dD28/J8eZMWSmGJkznG82nD9/TvCRYTiC/+q2GT9lx1N2PGXHe8qO8fSl5+lXeoFibYHvT1iZmUSibCu0lAzHPVVRoOVy/1cWhuMQ8N7z9rNP2KxWiDRh6g3KaKb+yDyPDMPM/f2ett3w4mKN8iNTN3HxzZ/FlmtUVTJ0J9599gnnm3POnn0AGXQZSLWCmJj6mZQySWTe3X3MFBzPXr0EBPv9gcEnmkKwPavJORHcSD919HHxRmQpMNbgQqIfZrTRXD274mHfsTEWIy26sGilkTmhCvNo20gIkUFrjC159uySGGeESBzdwDj2pBDojgfspuXi+QtSyug6I7RCBY+RBUklfFrwzjEGMCXN+Rmu05Rtu9x/shxXliJTGoUpDI7I6EdsYZBti6nOic6zkomCQBxOSCHJCcq65urVa6rNFSkF6hz5uVdX/PL/8WeZjzvE3DFNA8fbN+zevcHWDd39G073b1mfX/Dig9fsjwPf+f0/5Hia+NrPvODsbMsPvv+OZCZyThz7AV1VFE3FlDNFYUApUjYURiGIS+FajuQYyTkxzdNiHY0ekdPS71+UoBZjZ3AOYwtEXsRdgsVbIQULJVNpUk5YYwjeIwTElMlEUlz+k0oRgkdrQwz+8e46P6aYxtgCP98jckCacoF+HTpkzpiyRNpEkpKL6w843d/QTyPCKMZ5WiBYWpNiR+g96dRzypJaRVTOxLFHKsn9uy8IMdI0DaKqUKrg5s2n2Gmk7DuqqqauCqJouD6/QGioq4rkwU2Ll0OXJSls3uf0/4nGU3Y8ZcdTdryf7Ji/vCvwq71AidM902EixwzJIZVgGibiNCFjxcl5irJmDjV5tUWOA9uLjKlr+ts3PHz796kvX2BWFaooWV1sKNozhEi8vb3jB9/5ArNpaUaPKiPSlGTpSUlyuLujtAV1WTIfb5BhQumCm+5E14/UhWUYBkxpaYoS
hCIUI3UFTVPT1hX96QRS4GLGh4QsNPM0cdGuMI3h4dCzNZpht2PVbJBxRiIYhp533ZFCQr1qkY9Kc49gnDwDBZVPjO9umIeOclUQ5sD9Fzecv35FXdaL5tsqInkxV9qCJA3juzcM04QRicTi2zLa0l49pyrM0lY39HSHB5SU2PU5xliUWzwR0ziRbUMot5iLc8rkycM9ZdngfUDZkipFVq6nuXhBiOAODxgxk97+PjYLCA4ZHWmUNKuasqn5E3/8l3i4Ocesr1hdv+Z1tebFz/8K+9lzdnaOmo/8/Dc+5/DmU2Sz4uQi3bFjpVqMLanLGms0pirQOSCzJ3u3bD6UYOiOyxuG1SQ3URpLICMLSUoZPw0o4sJUSHHZDeWMyBEpBYKMEooQFghVDn558cgLjjunRS/vFraDD3HBmD8q1o0CJQsQkrK09E5xdvUBPs34vsfPASkyVV0jtOTYnQjeY5WmalvKqibFRJhnhtOBcXKU9YrgAx+9/pB9N5CATV3g5oFTP2BsBTGw3x9YXVzTdx2f355YrUF4RyV7Yq/YXL2gO/Wk6JlGT9d3aJlp6q8uSfYpO56y4yk73k92+B8D8PiVXqAMpw6jDU21tNkF70kx0s+OwzBQ2YqkAnN/x0oqZE6cvXrNaZyorl5SX7+m3JzjD3uy69k+fw7CMM4j9uwZm6/9Imevvok/7VDuROp7utu3FHgurl/T7W+JRjEOHZvnz9FFSzsl6otIGAdsIbAK3OGedzf3KK2oVw3ZKu7e7nnz6WdsNxtWmzXFqsRFELXi7t0NWkrqtsYPA8Pkse4WITVz8IgYmKcJXRQIBPM8o0rLPM4QwQCGSCg0VbVGFSXd7sDm1dfI08Bh9wXVaoMqDFJkbN0gpCWngWcffY1xHMFPDKcDZdUgqxZlFi+GUJrjcSCXa/TlB9iyRDlPN9+itcIWlts3n5HtHrV5QVGVrJqCWJaPx9QSReL0biDGTHX1gnJ1wfTwBsYdKiZGQNmG0O+xmyukkAhrufzgZxAXH5HtGkzFpnrOVjz2rI1HvBRUxnJ7e8Pryy1zbUkhIFKPdQoxswi5VML7aSFoziNJaXzKCFVQGkPRlPgYIWeci8TpgJ9GXByJw0yOEYylLAvc2KNEpjCG2XvIEa0W1bkGbFEipERpteyoUMgwIYxmnieKugEyPidKm/DJYe2GdrO0WR53e4Q2VKWhcj1No9GmYt0+o7/7nNPbT8hI9vd3xHlmvVpTNxUuBFAlzfOXZDdQ5RtEiqQcGbsHvFOsz6+p6wqtdry7u2P7/DVXLyJCwO7+FiEiSiuSG6iainjyKCNYP79iGEZOh8P7mvo/8XjKjqfseMqO95Mddw/7Lz1Pv9ILFKUlRVkQcwIBWWh2px2n3YHVdk3WFtM0+Bg4W9UMhzu++NbfZI6Ccn2ObTb42xtiCoTkOfUDWRY8fPpdzrdb5nHCv/keyQ3MMpMBowVeaUT2KJmZ3cjsA7s3t8zuDf39jssXl49ExBUqTdRNyYffeM04Tthy2XHlLPng534RmRNTf0QDq3bLaXLoun3UcWtMVSHLgvH4gJWCVVFy9CeIEWk1MSaKdkM0GiXM8gPNGSU824sz3DhhyoqcJW17we6zjymVptmsQGayC9jSMPWOKBZa4O6zHyBzQjTn1B/8MkY4Dm++Q3CO6uJD7Pkr5tMN/cMtQ0rU7RpTFTzcLb/oKgdIHYWckMEzj8vxpjWBseuprGS9ajE2UAxH5Nklmi15P2GMRNRbckjs797ho2cKDlGtkfYMYbeLllxJtBSAxKaJaCT57AWmrPha06D8iVFlxlOPMRkpHEkXzGNPnGd8v4dpYrPe4GWBo+D5z/0cVdlQGsXdZ9/j008+YYyKeex5+Pxzkptw3lFZjdSanAOGjFFisXTKtPAHViuMUcgUqTeJoixxQ4/SE1XdgoBClgsiO3hMUSKB7B0yB4TRxL4j5simaRApsolH+psf0N/9AF21zMU5IRtMVZNz
xvU9OQfe3vacXV5ilOWzLz6lSYnaGrq728d2RM3qg1/istry9vNvkYLHasuxO1I2DbvPvkNhDQnFs4++wWq94ZNPvkdhK/T2En9/Q3IzdVnQder9Tf6fcDxlx1N2PGXH+8mOqvjyeIKv9ALFFgUxOIbuhPMB26yoqoYcBapYqIoiJNrLF4h6gwiZ7nbP2faMhECnGZcGKBqKpDh+/gXZWqwy3Hz2Ke16y8E7UvCkw4Fp9igJ26tLjNGUxSUPt58T+tNyjAuYQjLOIyhJWxu0siQfEVIuASM1AoG1huPxSPSeOA48u7pcCqLSjFUZQSBETxojCEVT1aQUl53c82ccDz3EiN2sQWq00JTrEudniBHhI6pYvAlSS6qyILsdTSUR5TlZabSyCOtIo2d1cUH74uuM7z7l4tVrjscj5fYZhJGhe2AaZyYf8X5E1iV+F5i706MgSxBSompXDNLS1DXKGNqLa8bDPUpGghsRHpSf6I8j64stbWXY3XyMHG6oV2dMyuJUQ9tcEvZvwRRUSuB9QkrJNB5RZUMUGhEdyhToGBDuBG5cXmu17HaVranqNbYZwY1IKVFFRWiqhV2gSk7xAbM6p2xWZCTu9ga2K5QuEC7QNhUXlx+SMhjTsNs9cL2q+P2//bc5dg+0Vcm6NBSFQpOwVqGLkqAW6JJQmjlk/DCjhGAaB2LIKGsYhxEtE0pbXD8s7a4KmrJEZI9dXXPsB/pDz2F3jzUalSxGt2RdYxQ8vPmYU9fTrlqqoqQ5e4YPmZvdHuFnxizR/YnpocOQKFdbBJlp6LA5Eo+3IBV5dc5HP/PzXL/8iJuyQMeJNOzZffKH3Jma+4cd6zly/fwFr16+IsfMNM9s5Vc3Pp6y4yk7nrLj/WTH/W7/pefpVzdh+KGxMoG1tNtLspBoIqWQ2LpBC4ULnn53zzTN+AR1e059foULgdNhjzvMXH/tJd3uLbVyBB9wWbE5u6J6+QFZb5j7A9JYuH1LVa4oyi1lW1GvWmxdcSzfIgpLJhNC4vObW8w04Q97Jm1QzWbpY08JqRVCCEpb0jxbQXAL5CZlTt0JYy3zqePi/IykFaZucGlBTJsQkWjkxXNWa4+Yj/jhhFKZXBuCe1SAp0Qgkt1EmmdOw4DSElMtk9+aglSdIZuWYj4SxgM6R/zuFsKEloqmbtndfc7+7nMur68pt+fkocff3RJiwIhIff2cEBJ+Gjnud1y8eElbl8xjT61WMHW4wz1CaurNFqRiig5pDYfdA4eHA6ZtEPsT+3dvmfoT05QpVxuqQlOfXYEtEWomxBk7D/DuiKlWhBiw7Zr5tGOMnqLdMs8DRmmslDg3ElWBlwVVsXgoLAsm2tYtqt5ir15g0wJVynHGy0gaHO58zerZa5rtiigN0zDxx37x6xTrf4LkJiSZd7sdx/2BlAKohdw4dwE1zjSrNYd9zzhObC82kBYNefYzYzcsd8cxISUYrSiMIs4D1WZDOruim+6wbaQwBqsEcrtF6YJxHzgd96zOntFszhBh5nKzIudEtiUuZD765s/S3XyByjPN5Uuq5pzj7gu6u89oVxd0pxND98Cb+4QtNyhbYDfnjMc9rm0pm5YUDCjD2CeeXV1hlGJ/95ZP3n3G6vo5SVp8CEj11Y2Pp+x4yo6n7Hg/2XHo+i89T7+6CQN8+sUtzy4uOb9+RrlaLVXCOeJ8wtZrbFnS37wlzjM6RZJQtBcviPWKsD/SD45q09K5CdO06KCZTh3RGPJ2SwqC6e33EHjKuoaqoFxXSOlI3cDpeIe0GmUqjqcjfhwo6oaqqNFKccoBKyxiiiSZiHFRdxMiKkeUtoSpJyuDKGvCMFIUK+Y547JGZcl4e8fkHHVZkoSmunhGSom6LJnsijHtCW6m9Ao17JEi4J1bKrKVYnIjWkoqXaCVoH71s5jzlyQhkGEmPEzELAkZZJpBW9I4IhEYpRHaYGxNSJFGaagDaZ6ZnMPaEsXEpBTbl1+j2p4h3EChFdpaohvJ
5YZQXTHJiO/u0UpjtEaZkpgzKguyNuRc0e1OtJsNttAURYnwE+PYkRCYFnRdkAO0q5b+tCfNHUZlSltgjQA/IuaJxsLV9pKoC253BwIgjCLkTHI9xkaKosIWFpmXqnshG2yWzBG8T8xzT6ktVkq0FgzDAZc9pMA/8c2PmNyrBVwgwQtJv9szDRPDMNBeb0nvHhjznl5uSFoy6YTRFfNx4Pf/7nfwWiP8yLY0nJ81rEpL6o+o1UekqmF7+YKiaTnc3ZEYaM7PkVqyak784e//He7bmpcffgBFQzjuGe4faC+v6D/5NmkciQh2N3+H68tLRgTImiwVtmiQUiKFYre/Q1YV+3HGDUcOf/D7CFOw3TS8+tovcP7hz3P7vT8gCmjOtpy6CeRCrHRTz2H88k6Nn7bxlB1P2fGUHe8nO/Lcfel5+pVeoFhlyCSmqcf7ebFSukBKihgVmQJtWzbrLWEaGA9H7j/5DuaNYlIKv2ngbofveiqrEUZhdEHyHuNm5r6jdxM5sSi+2zNGoTC65nT/BaeHd4v/oWqJs6fQis2mpRGS8fYdfRdIYiYIR9XUxNkzHk6k5Li4vmT78kOG4YSfOrIuUUVJxBKywjlPoy2qtigFx/09xpZkP2KqFb1eE5pLWG0RQuNcTzneURiNUXqhDGYoy5LCaqL3KCEpywKtFEpJvO9JMlHWNSF6xmnE2ApVVfixp91sUGULWqNDBqFIAsI8UdYV5WqNHxWFgmJ7TpaGrASxKDDaMpwOeDcyzjcMJMpS09QVVmScD2htgAARVpcfsb3+GoWB9dWWcOq5/+IdxhiyC+gYFhdFYej2t8jkSDGgTLWEx9AjH1HT/TQQco+UE27syUVD1Z6RQliIoUqTQiBPHVlFyrJFSs2cMkVdI7TFE8jKcjjcoID6/DkxBoKbSHKmaWqkNjwcT4gkqJuGs4sLTFGiq4Kf+YUad9zjhpHbw8CzZ68YD/dkZl589E1CkPTzyHjas1mVFDmh5bI7/vjzzyjMYguVuiKMPQ83t7R1xdHDN3/2Z5gTDPueVdtQtGuatqXf70lFgTSSh9Oy2/zO52949TP/GG3dMkvF7G6xNpAnz3AaOB4dkDlfFfy973yb9vwKzTXv3n3BxfUL6tWa7u4NZWFIQoLS1HWNqVbYcXzPCfAPPp6y4yk7nrLj/WSH/9+LzTjGgCkKqmaNc45hnKmaFeXKUDbrhdcgM52LzJOjXW2pNg3TcWmz8tkyJ0FVVZySII+JdaVoCsiuR6DYPrtGZ8HcH3n7xRdIoVDNBqEMs1nR3b7BiAeUMTzbNMTDDUXTINctzdkV3dARxp4wObQxrL5xhe87pIik+cT1WUt3yhxHT5aSlALXV+eEeYCcGPuZeZwQpoayQRhFig6VDuRuT9AFVbOCFKgLhR86pNEIoRZKoXgsZCwaUk7s33wXrT9l1bTMMSLUsmOZ+4m+H7C1JM4jkri4InRFzhmVE94HpFUIbUAq5mli2O/Z1CXK9fSnnpShbFd4H5iHjkolnl9uFhBQzKAk89hx7I5obdi0K6TS1O2KNI/43RvUpqJat8yzY+hHRGkxZYVzPW1pAIkWiuACumoQQuBm96g6lxTtGcFPHO/fIaTCFJFxd4spLCiN0AVYTRYZ5gGPhiioqpqQEjnOiKJkdh6z2kJMpKUZkIheigfPapS2vHi+wrmR+STI0pCAIk9UxjI3hu/93b9NzJrx6oL1Rx8QHh6obc0YRqr1SwiRmCONsczzTI6eX/75rxOFxCjP529vSGXFs+0ZJnSIdjHMXjx7Qe8CxIQuLLuHHXbznOMw0FjN5rxFNxteFCXNtuGwu8UmzTB2fO/T79AWDWVTk7Ngt9/hRM2H3/xZ6tWWw2GPnyfuf/Bd4nAgC8nm/AoOJx52R1ZlzWp7iWl+DOvXT9l4yo6n7HjKjveTHe3m4kvPU/mTTPJ/99/9dxFC8K/+q//qjx6bponf/M3f5OLigrZt+Y3f+A3evXv3Rz7u
k08+4dd//dep65rr62v+9X/9XyeEHz/sNh98nWZ1hnAzRd1imoZx6Oh3t3z/b/x1fvD3fo/+7oExetL2gvajb+AmhzveU44H1t4zHjoOX7zjfNOyKS25H7BliyordNMipo64ewtDR2EkSgmMCFgjONusuH7xgvOPvs6YC46HI/3xhC1XbJ+9pry4oKgqVGGo1w1tUzIdjghtMfWa0+0XfPp3/hbv3tywHzzjHGAeGG8/I3Y7kJLth1+nffkaUTfozRZsgURg0wz+SDh+jp3vUeMDp90tMUeMKdGmQBmLKVuyLhZ/iC1/pD7vhj39cKLve4axR9uCpm6Q0S9FWghiCJxuPmc87BY5lS1QqqAoKgSZ+XSgG06IZsXDYeDY9RyOO3b37xi7DqUMhTIk78khkKaePI84H3DzUvxlyhrdbInDgWn/ju544tNvf5svPv4YN4+kHIjTslP008hpHIjZ451DKL0QF1MipYiSAmMEiYxu1lQXH+BRzMPiLHHeE7zn4YvPmU97oqkJ1RlBKMbhxOnhnul4xE8zpqpYr1oqY6m0wsYZGSdM3SDKliQMcxQEFAnJHAQuRFCKMUiOp4FhdFx+9A3Wl+ekeWLsRsYQ2Q/L0fM8Dnjv6I8njv2AufqA8tXPotoziqomS83Z2TnXTUVZaKIPlHXD5uoFyjtetpY6TIj9A0VakOal1ozdEdus2Tx7xerqJVoInp017O4+wz+849Wrl6y3W4S21Jev+drP/jLucKIqKg77Pbf3e8bjifv7ew6Dx/nAOAzo5NhsG0IOHHbvcNOXv0v+acqNp+x4yo6n7Hh/2XHc33zpefoPfILyO7/zO/zH//F/zD/+j//jf+Txf+1f+9f4L/6L/4L/7D/7z9hsNvz5P//n+ef/+X+e//a//W8BiDHy67/+6zx//pz/7r/77/jiiy/4l/6lfwljDP/2v/1v/1j/hstnLyg1jPtb3MMNOQTCPKJXK1CK1dk528sLcsrMIfDuD/42aVqqnpGKzdkGK2HuZ2qjQbdEGZYq72rLOEyoCMIWmEZiOshJslq3jGOHlVBvWigqjCpwu3fYzQpvW1S1JQ49x25g2h8I5YQtS0KCcXfPYAuUUgxBsD/usKtMoyV2s+H8ww8JbuKwOxLzLcPpSPIjw3GHAJrtGlOU1KZA+4CyBV540iAY+wld1AsECBDGEFNGSUVOArLEuQgigVSkFNFZEr2nqkpSjCiliTFQNlCsL7CbS5JQ+O5AGI7k5CAGEHD18jVV3RLmjCgLhJvwPhCNQoslBKZxwCqFUJKUEiJnVttzdI60lcbWFfMcmaRAVC1JG/o5YwuFrWrm04GcARTTOJG8pxSSstX4eYScSTFxHDqUBDJUdYspay6eXROdIyZNzDPBeaKf8ePAZnWGkxCjR8aIDyOiqKhUTew6EpmcI0S3FA8mUFLRbrbMzpGTI/YDMgaa62eQl3CGiEyeU3eHVJKrq3OkKVFG4KyG7ZboI0VTE4NHpUQQkv5wh8iJQgBCI42lm0+cnZ+hmwqCJ4qEC365djgd8UNCrlZsL7eL6l0IDt+/xd+/5WHu0Lom3H1OdAfq89eklcagMOcrNs4z+x6hJFcvzhGqoO9OrLQinI6I1Yq2adl9dmQYHUoXeFNRNTX+5gvm8GMgIX+KcuMpO56y4yk73l92sPvy/CSRc/6xU6brOn7lV36Fv/JX/gq//du/zR//43+c/+A/+A84HA5cXV3xn/wn/wn/wr/wLwDwB3/wB/ziL/4if+2v/TV+9Vd/lf/yv/wv+ef+uX+ON2/e8OzZMwD+o//oP+K3fuu3uL29xdr/9R7p4/HIZrPh//3//G2m/QHvBq6fXy4vcg5YW1BuLimqFcIo/DRz3N8TxhNjf6QqKuZ+ImUIbiRLRdmswFpi33G2WSPbDdFNVCpRG83tuxsCYE2N1JoYPPPUY43CVi1JawiJlBK6WHYh09ghw4xRGp8gC0nXn5DRYcuSsq5Aaaa4mDF1
cFQmU7Ytdw9HUlgojzkv1ekxOWYfMLpA+oAUEeeWQrWyWXPY75nHic35ZkF4nzqyLjBnV0hdME8TFk8YT1i7KNJTzpTWILSlrJulHVEsYYC2mNUFq+vXuHlmvntLCgNjt8PNI2XdYIuC480N7eYCU1n2dzfLrkRKyJL15Quy62EeSKZk9ongZqSKnG7fYWUiC0m7OcdWFVFahFxW/OM44Z1ndXFFWZXgZsY043yiiAGtFT7mZQekFFIIYnDL8XQW2MeWvaQNolkR+2WHQ0oE51ACumlmfXaGTBFlDKZu0LZCKYsQgpjiIutSgjB29Mc9wtRkACFo2xb8jPaeKCKyWpEQ+KlnGnriNFEaSVG16LKiH2fUaotFQAx41xODw7mAsBXGGkQWzJNDl5a6rFHaEN3I/t075rnHKk1KHlVaVlcvMdpyfPM53X6HqmuKZoUQGRUiAUHOgewm5OqcMPfEvqfetgy7HeNhsd5WmzVSlYyjww0jOWW+88kdL1+9YF0muu6ELGs2F+cYrRi7Aa8t/6f/+29xOBxYr9c/9bnxlB1P2fGUHe8/O+5vH/i//NZf/lK58Q90gvKbv/mb/Pqv/zp/6k/9KX77t3/7R4//7u/+Lt57/tSf+lM/euwXfuEX+PDDD38UNH/tr/01/tgf+2M/ChmAP/2n/zR/7s/9Of7u3/27/Ik/8Sf+vq83zzPzPP+RkAH49t/6XU6HPd/86EPcXqFWDQbB6fae9uo5wc/k3jEPJ9zpgA+O6CKff/EZZdNQNxXFxTWnw4mP//YfcHV9yeX1BlkYJA6cY/AjvtAkISAnhMxEIrZtyNaQwsjsJxSGMMcFEzyNPzJT+pTpuz3GFiilqaRAFjVlYYhuJquItSVuXuBNThaMu5E8eYKb6Lo9zeacavURQkMxjHS7B1AZpKHrBubJc/Ot77Nab3j+0WsIE252nCaHWW0o1s9RuqBuJ9zDZxSFYZ5GgvdUdUUQGS2We+EQI9YWaK1IMTLdvyXPI6dTz7vv/gHFuqWslh1cFAaf9eLcNCXeZ5qz62XyjxPJzRTNCt2U7N90pOioV1u8s8zDgfXmnNP+wNubd1x50FXD1QffwBQF8ziQ55nysSAzRo+WS5FhSANbqyB4TqZEqAo/LcV3OWn6YUTmRJCZqe/JOaPHE7Yo0HKxg47T8rmr1bKjrJoVyiyOCyMFkUhOIDKIFIjCoNoLKl1CjksIK4Msa5CCXFfEaUZITc5QrS+JqgS1Y/YTq3VLlhqVwA2O2c9URkGc0XGCGPFBYTdbDLDenqELw7TboYMHEVlfrPGpxQ8TY3+i1AVaSpSUyNKi6uVeOHpHsb0kDycMAVFuUGHD0O2o2opoJO5hh5SCYr1Ba4OUBVEb6k2DLgp0CnwzjawuCqy22LJe6JB+x3DyTGPgeBh+qnPjKTuesuMpO376smM8/SNsM/5P/9P/lL/xN/4Gv/M7v/P3Pff27VustWy32z/y+LNnz3j79u2P/s7/NGR++PwPn/ufG//Ov/Pv8G/+m//m3/f465/7JmEa2TYFGpBkuv2R23f3DC6wPxzo9weasmB9tmVzfcHmmz+H42PKssCWJfXZJVE33P7Nb/Hd//5v06wsv/wLH/JLv/gL9N7jYySMESUNUzeQ44AsDC54pLKc7vY0bU1OEKeBefCYukAq8GHGx0wECm0gR2QONCIjlUbLmigF2AKtCny3xx2PmKKibFuqXKKsYuhHdm8/oVytKMuG5CdiMtj1GmkHShNppeT0sOdZgrEfODtbI6sGLytSCITZIwj4OSBIqKoCbZBlSQJijogUIaVFgKUEpIibJoapJynF5uWz5bhXCmKM7B8e0FVLffYMXRSMh3u8W46Js65pP/wZ7MUzZJioKOnefh/Xd1hjMIVFlSXNakN7/fxHTgqTPdotQKLZKNxwYv95jxtH6qsXNK9+nsLu6D7/Q+qyoGwtSRekYWQ47XBzRwqJaXK0bbN0FpiCOI1M84iyFT5l
bLuibDe0TU2YR6S2S6dBmnExkoVE5kzQJbE8Yw6JSiiyGMnRE8YZaSGVoHVJSpmibRajafTIHKgNlNfXKD8jtWFMGmNAp4lkM0KCNA06SoRJBA85ZXTdEOeB0O+Y+hPykXwpU0ImKApDihUhS+bZo2yFabaUMROmjjCciPOEUJJumEliWmBbw4lxGlitV4i6Qca8+FqFwufE2E+cv3pOvd7Sf/odNpszmutXaC3RfYceBuLkUUrQrBty+vKHr+8jN56y4yk7nrLjpy87WH35ZcePtUD59NNP+Vf+lX+Fv/pX/yplWf44H/oTjb/0l/4Sf/Ev/sUf/fl4PPL69WvOzq4ojUbFwNAdCEJRnF9RdxOFMZydbdlu1rSrNcexJ6qlJvj8xXMgo7VChJFn1xf8U7/6J/j0u9/Ch4QQiyNDNCtynyik4H7/wDw6dIAyFEyzJ0mNUqCEQqGozi6xZ5LsRqyG6CMiSqTWKK0Ic0IrTT9NFFJh6hqcQxuJWdUc+46MouuOpOhZbbaYomIlNUpphtMRPzvKuiVmiMOJ2fVoa7m8uOTVy5f44ElKk1F478gmkv3A5Gb81DM/3NBajalrsoDj/Q6jwDxO+CwFIUxklyk3Dat6y3gaSSkyaUlWFq0VY3dEybyQEquW7nhH9ImEQkuBUYm0v0PbghQ9RheYek12I9mPVCKRRQAklZJkBDlHxuMOrRRJGc625+TNOVEZTqcjPibmwx0iDEyTZ3YBO3u8WlTmYZqYhwOBzJRabPGM+nKLmEfmaVrQ3jlTtQ1FWRFzxtsSff6SeX+Pf3iDdyOq3ixkyTTiJ0+yGj+e8OM9yQ1LEWFVU9YNZdGQcwIUSEnyExCYZ4dKmTDPZNeTDOizK4Qb8XHGmhakQSuIQyArcKc94rRHpoDMgRw8pqxAGpLrkVIho0MJRS4LIgLCzNifEEKhi4JMor18hg0jg3PIZoWtW8Q8cpgTJMVnn77BxcjV81ekEFBS4VxgciOHd58jYkSQCSmTuwMmeGTIjNPA7Bzbq0v8NCO/pPTrfeUGPGXHU3Y8ZcdPW3bMP0YNyo+1QPnd3/1dbm5u+JVf+ZUfPRZj5L/5b/4b/sP/8D/kv/qv/iucc+z3+z+yG3r37h3Pnz8H4Pnz5/z1v/7X/8jn/WG1/g//zv//KIqCovj7zakqBpRSzCEhyhWF0lRVic0ZoyVSKSY3o4qGPsH+9oHhGAhSEqaJZtVi6xL/9g0mer7x4Quai5dIDfPckX1E1y0JaC4UZuqpWCYSRqHKGikj0QeS0hxu76m253gfGI/LXa2yFSknSHJpy1OKQKQ7dRS6gGFkHEdMX2KsJESFnwPGlKBa+v0bRI5smxWysOhyjdWWoTsgtOX4/e/h55mz62tYrTh/9TWGwwPd8Z5x9kibiDmRBUTXISTYqkYKQQqReRihKZEJ+n6gKCtcdMiUmB5m6nZNdBMhekiZ2R2ZtWG9vUApTRYaJTJ2taU/HVFSIJQkuJm+7+j7HcmNlFVN6E4QAn3/QBc9BI9cnVFevkAIEKJApYiOM6OfCNP0CGeSnLcV3/1bfxNtSs6un1FfPCenzNg9EI8HvJLLa6YN4zAxBM86K3JZM596RNHCeCAjIENwEZkF3f33KcovmL0jhcWXIaYJQV52jz4T8/dxwx7cQHATKIvUlvb8ms3lC3S1otxcQMgw94TgQRgoahARHxMiBaa7T0mAEiBzRviesZ+RyWGaNednWzICkfzCcJByscJKgzEWkTNwoj/t0XVLZSviPOC7HbZukDlQKonwI8HPWF2Ac8T7z6nXG9T1S6ahQ2gBuqDcbqnqFVIatMzs3n3GNA5UbUUOhtXjfblpSpIxJFcjR49Pmeb8mrXzP9W58ZQdT9nxlB0/fdlRn119qdyAH3OB8if/5J/k937v9/7IY3/mz/wZfuEXfoHf+q3f4vXr1xhj+K//6/+a3/iN3wDgD//wD/nkk0/4tV/7NQB+7dd+jb/8l/8yNzc3
XF9fA/BX/+pfZb1e80u/9Es/zj8H025RWlGSiX5iPOw4dTuCm4hKY4oShcDPjqosGXaJb33/9/nDH7xle3bOP/3P/B+4fPbhUnAlHK7rsO2W/rjj+O4BISXt2RllVaKERE4z+/2e9us/x/rV1zBK8e7b3yL6iEkJUmLqOnRVYctEUWhciByOB84uLtBaEJMn5kS7OkNqS69myqKisAXeT8gQl/a6okLUK1TfYJVElg0ojSpaUBp32KELjdIF+5t3+DDzjT/2T5KEIrrFAaKqFpUjbjxRVTXlao3entEYQwiO+dQTc8K2a+q2oTt2SB9Y6FILCTIGmOeJ03636M3dzPr8gqYqUMowj57YnchaogQLOjmB1AqTIsmP+Kln2O3Y3d5zcXlF0A3RgL+9YdMo5gTJO9q6JIeZuy/eLoWHpcdFRzre42Pk/PICpCEkELYGW6GsZepHKEqktqzXZ3B4wL+9QZ8+xd06TH2Od2DCuJAfhcCUJf1ph7Awp5ngJ4gO/EiaMyFGfABIzFNPjGERa40jMfYYYwkxIpVmpRR+tFhtcGMPUiJlZj4eECIhlaSoW/CJMByJOSCMw8eEVgZTbUhJoIualAU5RlLi8Wg2EsJE8A43T0gSIUVqaxapmCgxaWkFNUZj25aoNcyONI+UVUWuCiIJZS0rvWHVLEApVdf4FMk5E6eZOA0UpUYVBb0QsNoSDjtkAu0Cw/7E6vyK6B9BV+rLyQJ/2nLjKTuesuMpO95fdszzlwc8/lgLlNVqxS//8i//kceapuHi4uJHj//ZP/tn+Yt/8S9yfn7Oer3mL/yFv8Cv/dqv8au/+qsA/LP/7D/LL/3SL/Ev/ov/Iv/ev/fv8fbtW/6Nf+Pf4Dd/8zf/Z3c6/0vDKI3OkTCP+L7DO0d9ccV0f0v2juaiZY5yWfEJQdP3fOMbgouXz+mOPWuRqIqW6oOS/d1nzLf3vPnW7xFToiktSWvCOPPm889p1g3aFDipkaYgHXtSCjy7POdwfEBKSdnWjH2HNhKla4QSxDDTlgXeJYr1Ght6ynXFMCcOd7coIRi8ZxCSHD0xL/eR8+CYP/8BgoxdtWTvFg7ANDL5zGl3R76/5exiQ1V+DYlYdoTdnuAGstaYnMkyoX1EkSm0JsdAd5oQZYXXBbJul17121vImbqqOO6PaFtiKs2nn31C9hNWCap2jV2tqI0iDz2ivURUJTEGnE+kGNFaYI2C6JkPd0hdAoaPv/stlBGwE6yePUPZkvLqGqsVue8IbmJwA4XVrF8+JyeBHwd0TgzzwOz8EixFgWzWJBeW1b8RmPU5zasPiP3E/PCGtiqpXlzRFAbleoTWzN0JkQJlVaCVpChKxpMikQnOEVOiP+6QeMpqRd2ek3Nmd/cO72ZILJX+0iAky8f5mf74QF1XzNOJ2S+StnK1ZXCRaTwxdQdUTly8+gjdnjMlDT4goyNEj24MLuVFwy4E2mii8Mz9CNOAVhlhLMEHUojownJ1eYEsSpS2GNOilWbqT/TdAdcfkBIICXIixISSihQ8QkdkuyFNjkp6fH+PyIqcQSTPqhCgE37cU6gSnKCsDONxTw7j8jM47ZinntWPcVXz05YbT9nxlB1P2fH+skP+GH3D/9BJsv/+v//vI6XkN37jN5jnmT/9p/80f+Wv/JUfPa+U4j//z/9z/tyf+3P82q/9Gk3T8C//y/8y/9a/9W/92F/rzff+HswjkKnqhiw1VdFiXjbIELCrNTIqkILd3QOmKMnziWeXGz762itkTIgwMNx8hhoPbM5rALphxBgNGfrbW+b9HW0l2Lz4GpnE/ec/4OL5KzYXF8z9ibOqopQZnzLrVx8xRzjevaE/ddiipr16SXX1GpUj/Xf/Jt1DB5sNstQMh46qabFFRcqGIrFAcFrFPPQY2+CUoFCC/e7AOM7sDgNZQFtXXF1f01jL8eYd/uENyRjCOFC0W8pNS3/sKZsCKRXzMNIf7hdqYt8xDRNKS5RuyXGZgFlm
EIL7+zvG2XF/e8f1sys2Zxf0Xc9wOHKfPMcPAs9/4RlpmgjHA/riGeb8OVIqlIik6QTTRB57xuMD2gguX3+ANiVKadLYMU49+3Fiu95SFBorQMRICGk56o2eKSVcBoqK2Qfuv/dtNm3LseuJKXFx1rBpV4jhSJ5mpv5A3dRUVYHIiTwPxBBRCSY/o5TAhRPjOBHcIoEr65ppcJiiwIilXTUFR87LTiTFRHB+IWxWa4RzxHmAOJHLie7wgOuW4rIsoGzaRZg1Thwf7kjjRHc8cf61n6E6fwFaMxzfgRsxxmCkAClwQ8ewc4vTJCfSPDMf79ieb5HGUCiJG44okSi1RhcV5IQbjxiZKQQMh3ukUlhdUNUlXoulE0Fp+mkk7ieM1iRTk3JGZY+RkpwjZdvgYmTuDkidsMmRpKbdGLKQbGVFngJGBbJQxFT9w4gM4H/b3ICn7HjKjqfseF/Z0U9fHq74D8RBed/jhyyD/9f/4//Gqq6pqhKhYOhH7g4jly+e8/KDr9PPE9P+geOh4/7+xLOLFVWjUNVy9Nms1ghp6U4Hwu6BpBK2qDh1HfPQUxQF3anDzRNtXVOfnaHKFo1hdXlBWVW4aSA4B+MRVdaU6zPGaWZ8+JzQ90yzRytJe/4c5zzf/f/+fygLzeb5NbpuFpmVsRhbIQQYYxm7HVJkEIbZe+xqS2ELDg93eOcZx5msJC++/rOsjOLt977F3Ref4pxj1bZMo8NUNevnr0nrZ5TpwLoucNPMNPTM08j+2OFmR7tqufro6wjv8NNASAmZMzFEpmliHEaaVYsQGe8d0Sd2uxMf/vwvcfnsGfPhnhgDQkiai+cIVSCihzCCEBy/+ISH+1vquqG5vERKQ10UVAb2d3fc3twhjaWyhvVmDVIwjDOITFHWlO0aH2YiigTs7++oynoRsCmNERmRIynKhe0wDZAzVmv6fmAOie3z16QwMnQHcgKjDcYYgptBFySpcMMBKzNu7JnHaWEKKEMIntP+yPF4oFy1mHqzQI/mgRQXboY1mmnsiDkTfCL6gNYa5zxudmgpyHH53WrPLrB1y9nFGUYsPy+JIviZ0+0t0a45++V/EoGif/NdPvmb/z0X24Z6tbwBhhgp25ar5x+hmxWQEeOR4D0IycPtDcPpQFuUlKVZTLpC4H0EZZljRhlDJmOLcvke53H5mcfAFBM5G0RZkmNAzkf6/sjFq9fIoiVMHn86YMqSQVr+mf/zb35pDspPw3jKjqfseMqO95sdLsGf/L/+pX90HJSflqGzxABWK3yMP3JjFDnh5x4tDShLs9lg25Z1KWjXK/b3D+xvdkwPt9hmTRpnTocHbLvsCItCk7xme3EJxuBcYHN+zfmz58QMSmmUEKSUUKbElCvU2SUCyNFjCDghUe2K+qrl9rNPuPlb/wPy7Az96gPazYp5HBlOPa+/9jW8i6ScSTlDTkgh8NOE1JkUEmEewY0QHVJkrAxUVUObZuIU8SmgqxovFKJaQx44DRM/+J2/ycXXv87r55fkCoSUCGO4PD/n+pXg/uYdu/2eMHvmeTkSXSCRktlNzMeO5uUrqrrm7nt/gC0ryvWaM1NRrNccH3b48UCxXjPc3bJ/9ylT3+OHEV0Y2vML6mZFe/2cdrPG1hVWKuQ4MDzsIUvqzRmTn0lSkqRCCoWWGUqD1hpioCyqxc8RPZdnWwolmMYT0jboqmWp4htJ3kF8PNa0FVnU2LLFXF7h9jeoacSYmpwj8yOm3RSWZn2B21zgdu/I40R1dokuK8ig3IwPkTl42tWW8rGKX8TAZ9/7DsHNGC1w88IxCLNboFZViTIajcIqi1KCuiyxMlGqhMmOqq7x3hP9gq2OyeHDjJhnpoe3xMMtm03D7Gamh5myKhC6oFQVIQWYexIKgSEkT/ITzgeksiSRGccRfzxijUFIiVQBZQpycCitlp1eSAQf8N5jynppKcyCvu85nY7UTcnkEjeffsr2+culgE8tjA4h03ub+z/peMqO
p+x4yo73lB3hyxXXw1d9gSIl96eJvD5HqsCqVVgEYRp5+73vUm235JSQKdJWDZaEDo5xd0f2gW7sYbcnhkS1XlO2G4yxODdjihofE816A+OMaWqSEOSUlla24Jfdj5RUhUUKSUxL69w4R8gCIRO2rqgvLvj4e5/CYcCcn5OUoWwUpcjo6BAyE9BLK9004CKIskVqQxEcMnuStJSbC/rDnpwS42nP5/2B86trqqpgHAyXVxuqVUP5kNj5kfPLmoKRNB8Zu0wWkGLgeHdDW1h0ChRCE1JG+sxut6PenhH7mdtPPiWLxPmHH2IEpBSXe88wIoOnv79BaoOICREiKUY+/cGn7O53FNaijKY6duQEzeUlH9U1tdBMw0AeOpQuUEaxataUKRB9QLbnEBxFTqBA5UDwHqkV426HWwDSzGRyDFSy5Hh/yzg5tqsWRUIpS0gwRRDtilIXjHd3TN2eODmkLSmLEpEh+Jlxd0sYJ9rnr5lNSdQlxlZIpQjTDH5iVVma8nqRvLmJpiqxZYF6/SHOzZy6I0VRYm3BaC3r62vatmXbNqSYKNoG5okwjiA0pijIIi/UTSmXo3ENarXB9zPdd/8Orrunrgu25xvGEPGzQ0iBsjUIyXQ8oKUk5gy2IeZEmP1SULe0GrC/u2PsOpqmYntxibQW158o6hYpFNM0EFWBMBW6bJe7aaVJ08SqbUlSIhScv/gAnRLBJWxtEU2BkIqp+/FAbT9N4yk7nrLjKTveT3YY5b78PP1HFwH/6MfbdzdcPX9OnGbWrWFVl+zvdyS1VBYTHVIK8uMR4+n4QKkNfpoJMTANIzlnmvUGnyFnwc2bN5AD169eU7YrXMzUmwtsURD8jMiQXGKeJ5RSZB+J03IEllmKhpSSzDFAktz9vT9EycjZ+YrDaUR0B8SqAC2XVr5xptms8aNnPPYEpUntNUYK3LAnTgNaa2ISCAQ+BdaXlxRFCbqiP+wQKaLxiPnIeLwjCsnl6w9prp4jQmLoDpwOe3RZogQQA/0UGb1jDp589ylSKHJRwuqCoh55oV7TH3fcfO/brC/OOd+uFrR1u2Y3P/CHv/O7JKWQKrFuazSR4/HE4DxjTFRUyCA4/+AlTdXw8R98h4dNi7WWdVNSti1CSBAKiSb7A/PDF2gtsUrjHxXttq0JMdANHcJobFGijYEQkNrS7XuC0Jw8WDKFmjne39NhMaZmP3ULAMoalLIYrQlZouo1uizQbmLY3XH44mMEgrIoMGq5ey2tBQKIEvFY9R/mCS2X1sv1+RlSCjb9mv50ZBgnNlcvaM6fs16t8YcbJNBst4xDzzGL5TheKmJ0uHnAao3IkUJrgiwo60uGfo9arQkiE6YZIw3FusGYAlu3ODcR/ETdrhncSPIzUUiE0lglESLjvEKuLihXzyi356i6RuOomjPQxXIkOz8Qg6Bst2RgOB3QeUbliLElbVMRg1s6M4RBCIXQBUmAKGu0+odXg/K/9XjKjqfseMqO95MdufpyWgr4ii9Q6lWDm0/oGTqX6e4CCMUcIloWzIcTIQVSlKA0x/tbirIgk7m4PKdZn1GvNwgpcCFxPHZIadhuzilXW7IUWK3xIeLCgDKK5Cb607T0pGuNItN1D6QskUVDUTcIKZmiJNqSvNrgDndUmzO6yVOkxN3Hn2BKQ1FUaHFGIyzdac/dw57UrFltzsn9Ce32rNcrVFHT7+4ZTm9Zr1es19uFQOgmTtNIac3Sxy81cvuCSllsVWCsQcrIwILarsqaaeyWdkZjiSmxbVtWZyu2z67o58SnH79l7o8UJbTbNd3DEWUsQmhUSNTVGv3hioTi88/foJXk1I/kFMha8cE3PmK1PSOnhJISFRPzYYfRkvXmjPOra2KK+OFIdB6lF9x19InQ7dBGQ7vFNiskAqlLRJ5YXz/D+xmRBLpaEcaB4+HE29s9QRlCeED4gVdXW5IHWUq8EDz76OtIo5lHh5QZGT1BSi4vztFFQ8hQ25KHzz7BhbBI
xsrljVemSJY1Ummk0ZDSgulOC2AreU+KkbIosPqMVeOZ/UweHuhchzvc4ceBw7sK267QRiNYQEZGZHyMOKEwVYtWCuECIWZKY4khE+YeoTO20GgVia5nnAeyVpiyQpYVwXuG7kTMLLtPLRFCooqaiw+ekcottt0i3EjsbhbtPZDl8jl0sSKXNUJKajy+W+7/xxBo6u0jJVMsb55CIOUCdZIIfIzvdf7/JOMpO56y4yk73k92hP4fUZvxT9vwIjP2E/bikkEZhts7sjuihMS2GiUhhsjxtEPWK9rnrxgOO+Iwosuaiw+/iRCS/dsfMA9HhJ8x1tAPR+JOU20vqaqCadozTyeadsN42jO7GWUKlFsEU/PY4+cF6nR+/Zx5nnHHB9w0YsqK5uyK5tlHrF6MCCC5ES0DRmvmfuT7f/gHjP2JOWZif4L+yLosqJoaaQuUFFRGk5QiTSPd7p6mWn5pt9fPSUpR1S1D11EUDaeHB0o3IvYPmMJS2QJLS5on/DyTUmaaZ9r1BiXA7W/Q24Jn7YbPuluGvkPqDZUpWV9dIXJc/BYJ5GlHWdY8uzxn2xT4FAk+MgxLVfnFxRXt+RlTP6ELg7Wa7rBn3O85vv2EMBwxVYsyFVoaUnQMfY/VGrs9QxcVQhomv5hhV0WNKTYQB8I4U65W2M0lQh04HD6n292RpSQai1Oa/hi4bmu2ORL9xDz0tJst1hpk2aIFKDJuGji9/YR5f1ygQ5sNZjhRMqP8TPBx8Y2sF89GdhPRTeSUUQhyCOA9hBk3CUxRIlJEIVBSYK1Bn10SGodWGmFKRE7EaZnEbl4KKFfn1xhjAMHD4bAcYx/vCGMHInH98jnSpEUApiRzCBjdIk3BHBJKl9hKLJhsMtIasjJIoVA5EqNDnO443b9ZqJRhBiTJVOQwM97tKbYD2RimfsfUHxefRtkQYiSgMNbiHy2uaZooi5IoJOM/gM34p2U8ZcdTdjxlx/vJjn6c/1dm5/84vtILlKIs2Z6f0zYbMJbxcGTY3S1V4ygQgtXlcy6vPyAJCOOIKjRn60tcf6C//YLgI7ubNxy6gbZukCmx3m7w84gbemIIjP2BMA+kEEjBoWyFNCVu6AjOYcqGot1gbIFQGqUSkoQMM1o0rK9f0zz/iHNjFzlWCMg0kaaeudszDXsEjvOyQpuKom3x00xdl2AsMguKsiUrw3C4ZXf/DtYecsZLjVOWmCX5dGJTK1gbVJbkOUAMBAdSgB9OaKVYXT0nkYnRMQwD/nig+/2O9WZDZaC6ukK3a4QssIXl/rOPUUBSguRnkoD0KAYrHyfIap1IKSDJJO/RigUSJQRaG9bnV1hbEAG9OsduniMOnzPc3JJjZIqgsuCsVCitKYA4B6K06LMX5PgZlhXV+TWi3hBT4ux8yzd/5jXkwITmzWlpbSubmtomVPQQZ+ZpQrcX5OAYuwNaZCaZCWPP0B/p+h5dtKwqQ5iXNsxjPxPnmaqyMHeLI0NospBM07QYTeGRYgkielRVA5IUEzEEtNKoukBpi7IVYTriJ49MkrK0WC0pjMZImIYepo6yqhlHgTVrZLNhXp8TjcWYCWUN50VBQqKMQdVr1kVNio4cI/IRnBZTIBz2hOFAjh6fBIpEFAKty+XuWWqcH0BkxuGEixN5OJDEQjmNKRH8yDB7ylAxTdNSO+EXUqf3iaT+t8XW/8McT9nxlB1P2fF+suNw/PK1a1/pBUpjNOumJJ4OiKqksAZXNwgpMFVFiIGkl1/abn8gDCOlVZSVwgTP/Sffpk+Gi1cf0ryqGB9uKOqayc/cv/uC0/ADVFViSsNmu6FuDD5GMhqlS6SZEWGmqltySoiUiG6mKCvCZkNyE0W7RdXtgmGeB5JQZCAHT3ITcR6Y+hFrC9rNlmZ7Trtu6e7uCG5GSckcJGxfYdUC+Snycpy3u7llDIFUNJRnF5gcGIeRFAR+ciQ/UW1arDXkEChX56jSwDwzjz3j0JGlRlQbej8R+8h6fQlK
IbRd1OZGYqylsHppO2PpGIgIlLY4H2jqBikTWbdL//00UhrL2PUIJRdJGBm7OiMiaNoGWShy2SJ0gRIeUTVICcNpj0Ug6nNs+xzaC0QUiGwRcVwgP1KjbUmMnuvNQoHMuuR844g5UTTrpejwcIOODkvEGAlK0PmelANF1WDbhs35JfOUGaYZ21qCnPFCUbYG2ayYvUPmRFFvSaYCqVDljIwBKSUpQ5o6YpxQUqGEIT5q0CMCbSSZhgQIN1GWFtOsSJMjDieIETcORDdxvqmIQlBeXSOLGm8qnDAQZmQIiLqiaNYkJCEmktAoW6GiIroRYyxZGbSQEDPzbkRlSDmjtSWTySkznI7sdzfkuae5eo6tKpQLkGuSrnFJIKeB/197fxJraZqe5cLX237dancXEdlGVuvj4+aYMjYJAwYuYcCiEyPLAwsQyFCWbAkhGRAwtCUkJEDIE4SZURIIGwQ2OpZtCnzktqiiWleblV00u13t17ztP/ii8jjBcNK/XY6IqnVJIWXGWpl69tr7u/fbPM9928kCU5b0XUu/3+FCQBiNKRuaacPwFJ+gHLTjoB0H7Xg82lEvm3f8nD7VC5T28hq6Fu8GqsUCHyFKg5QS5wO6tGijWb3xGrlumN++jYo9OfXI6YzZXEPXUzQzdpfXaGupz54fbaWLkvNPfALfttSzmpPTE4QQ2Kp5ZNTToaTAx0S721CoMX8hJhi0YbPZEENEuh55fY7brTEkiGF839Ch7RjkNW0akshUzYz2+or1/TfwQ48tKvpNjzl+nqqeE2/uk7oBtGQym9Gc3mK327DbdRhbEF3JfttSTpdkI+kEuJsd85gorGa/aYnrcUckSBTG4mPCmPEz08YSc6Y0Bm0sQ/SEoWO+mINQ4x1qzgTv0IhxFl6PD5v3gXa7H2f7lUYIKKwBpXCDJwnByZ07dO0e7VompSWVmjhfEONAWRiS0HS9xiWwkyPy4hayqslSkOuGdn1Oe+91pqeRbnXN5v59cr9lenaGnSRqJbHakrTES4139ehkuLlCi8js+A7z27dwviOKYtzVDh5VTannM4SITCYF+92efrdDKIUfHElpRHOEKBs0Erff4rdbhs0G5zp09kxKA12HUgHf75BEbFHje49zicIO4FuElGQ1MLQdWShSTBATUipCcIicqaXAuZb28hzX95RGgMjk6jkSC0w9QwmJkoJMIqeEMZYYA0pKRAZVlhSzI/r1Gsjk5ElxgCzQMjOfNng7Boi5mwsqLZktjkimpPcJyChd4IND2ILaVsTdDqkUs+mSsq5/V0e1TxoH7Thox0E7Ho92/G46157qBcp6t2UIjqooIcDm/IbN1hGAutIcnc5RYWC6aHjwymukhw+58667ZCGJQkJI9A/eZHjwBtPFCdMX34VtGpzrWZzd4vm7L9Ltt2zaLavVJSEM1HVDaRR9N+CjIPnA5mZNVVWousTFQOg6fAA3BNr799i2O7SUTOqa9fU1vtuzW684OT3luZdeomoqfB6j1Tddz9C1BB9hyASlqPQl7X4H7TWWgNKWlByVrlBNQer2aOGJtaW9WLHpeoIpKY+W7LqBdTtQuIF+cEilqKzFKIX7atibFCipISZi35IU465FCYSQSGUIKeO8J8bR+2J0y8wYqcg5Y0vLbrdj226ZNBWBiLV2FKvdjt456sIQg8cogRoESmlsYdCmJgePjxlri3G2PjokGenGI9+0u8KtrxkimMkpOwrE8hSx10RTo5SFboVB4PweIwribse0rGgqSYp74v4CVUxQKILPYGr89pLd5ZvYxYRyUhAJTAqDcJYkJDFqTD1B2wpTTNDWoOspG3Gf1770GW7uvc7dl15AVMdjzocCZYvRhhsFOqOkJsYIMUMI+LBHaoMpSiQZ3+1JweO6PRIgJ7IUCLdHBY9AM2RPCB4hNUKNHg+uG/A5InMgaUPOY4gbOZEZv0/t9hqZI9ooBj8gpcEaw2Q2e2sc0nuHyKClAjLaKmJMpG6D
RUBRInSB84HwyHNj6DriOwwLfBI5aMdBOw7a8Zi0o/8G6UG52gyczc44u3VG9Htsozk6PcEPgd36CqkysV9TFlPmWtL1O+698kWq0mKKErfdUjU1TVNRzmdkJdndPCS5gdC37K/XvPHqGySRMUkggif1Hc1kRo4JNzgSiTypsCd32G1uCENk6ANRSdrW4fc7KteilGSDZLfbo7Wl84lq8Gx327HDuawQWiK0wpQ1QiU2+5711YrjJDm6fZui1pRxNNF5cO8hvfMszo4oyxLXbiFkkjIUs5JF3eCHHlVrnHN0mwFhNcEHLldbtus1xmiev/siQkusKcbu8jAQ0/gwaCnHMKhHx8KI//ezDyEQwxgYp7RGIimsQeaE6zsm0ym6niKVZlYvMbs1q9UNpMR2aOknNdV0jtAFWSqSSEjhx51lGIj7NbqZIYVEJsf+3ldY3X+To5e+icWzL3CsDO58Sn/zBn5oGXYrht0G38ygqUEbvHXYwnB5cZ/t+UNUMeXOe97P9PSY3Cdi8FCOI6bBDexXN0TXobVGSoOpKpSxYzPjfgNotJpghaAQkdpK7OkRRWHJOZMzDH1PWZXjzi2PNtRCCIxSpEe21PpRWFgIHkMef6m4nqQE9tE4pzYGW1fstzt2qxtsVY4202Egu5Y4dHjvyNET3IAtKozSqByBSAiRwQekyCQEKIMRijQ41udvUjUT5rduE2PAGjuOo8ZIHPbk4B5ZXhuUtuQciElgRSSGgHcduiiJ/uk9QTlox0E7DtrxeLRD/C7M65/qBcqL3/oB7n77d1BoxZf/+6/TDxc0FRw/e4fwzC1C3/Lw/D5VFchGklOBaWbEqsBOpzTzKYXMDL1jc/9NNl/5CtYYUtfSbtZUkxm3njnj4cVDkoi06x2b6xuObwuEMsS2hZgQ2nLT3+fm/j1cDJy86yW0NnR6jLr2vaMPERcSPkaEchTWYoxmt7pBCImdJJyPxK6ltAY9b8a72sKQ9mvCjaZpDMiMrhvoEtfnb6LnSyjGnUgUCVFMyEIR0IRhj+97Qoj0wVE0oz2xQmFjoi4NQkFMDqNKtC3wTuJCwHXjkaGRgrKwxAy2KDBWEUMcRwG1xlpLTml0Q0QilRzHyUwxrqqlophOqecLXN/T7le49TVBlwRpEUIRQhz1K0VkzhSlhr4nXL5GwNNUJZPFBNRz6EIRL99k8B3DfkcKo0W1Txl1/DxhfoIpJuOdrdDEnOhNTW8bVHCcv/YFfL/G1AtaN2CLkrqsSdaiaBj6FikkWUgQoIJj2K7JKWOLgvZyy/7qPldvvkpTlpQvvEh0jt1+P06AWMO4h5RorTG2IAVPzhFlDFJ81TEUpIK+64E87orseNcbU8ZoS1HWZCkROaCNpd9vUabA7zekHBHKkGEURQGuXZP7Dd4NOBfxjN87rQ34sRnOpQQi0+7WYC1l3aCUJOVETGn0RRj2NE2DVRolI0JlsoZUSOqyGh0oc4bfhSPkk8ZBOw7acdCOx6MdbfsN4iR7+4XnKWzB9cXDMZ46CR68+gYhBpbP34VqQSEr3O6a5miJNZbN+ga6lrosUNYwtDs22y2Dz7z6ypvkFFkUlmldsDhekKWinpSEnCkKQ2UlQRq2lxuy60FE9pstqqgRAqaTmmlTkrqWTejoo2foemLOTBdLCgXROQprSDmxub5herQkkt9KVQ19y5HW3Dk9ppeGmze/Qj2ZUZ2dIdo1dVGjluNxoI8BYyuG0OG6DYvZBKMtppyi8zg+JlIiW4MyFlvUrC5X1IVhPp/gw3jH67uWoevRtsR5x5sPbyis4Xg+wVhIIZJ0QhclwTmEENSTZhQJP+A2A7vNGmMURVHQ7zb0+z3VbMLQbinKCcpaCqWwszniUbOWSB6pFNIWY2bGsEcrQ7QJHwM5BaTJKFOgdaBvt3jfgxTElMe7/dNnqeopShmEqZGmRNgCFwdkv6EqK/zyhDgMZCu42W4pXUQXJYXRZO+RSiCUwpgxRyLkhERC
SvihY7fdsL25JsVAv2+xZYVZHFNP53TtDlyHLcZsDZHCI/OtiCaBBN8PQERKiQCKuiG6QBAZ09QMMeKGgfVmdPs8EhljC3zfj30Lmw3VZIpIkd16Q8qZajIddyu2wiWPiKPtdEwJXZaUQgAZmSPdfk+UCjObE1OAEMhSUpQFWmTIEBUMJHJOhOCx1o7ppiIhBTRaEVLGkwnZI8XTa3V/0I6Ddhy04/FoR8c3yALly5/+NM+6iK0KdFkyP5qir6959ZU30Lfu8sI3v5/ZdsfqK5/B7TYsn3kGVU24eeOLPHz1DYpSc/u55zAerq4fMF3MWdw+o9CC5WQMUxp2e9ph4OTomGbe4K8vuL53j7Z1TM9OqeYz7GJJMTsmtTs2b7zKvU9+jJOjJUdHU3Zdze56gxcRVWlElGg7Pvyb3R6VMremM9CK7XbH6uICFzPIgmrwdG1LcgNtt6Po51il6Z1jd31NDo6mrNAp0w4dWowW3rvrc8LlNXXdYMuKYbehMCWLZ5+FJHj1868Q4oBWUM1qynqOGwakFBhb0PYDyffEHDB2gRscrh8YhoEyRrSEGBLdfjc6G+aM1IpmMgWVKZuaoqwJg0cZ/SjRc48fdmhtEFIhpEbJcfUtpCSnjDAVwQeUUVRWEbxns16TpKWoK0xSKJkhjfbYm80GXTYU1RShLa7vyc6T0jXZ9eRuhxJjE1hhPLIqmc2nICQ5CZIbH2CGgRwGktKUkylKjZHvSUhSShTVBGHiGGdfTqjnx6QQcH1HKDTZd4SuQ4QMZcEweMpSU+U9cdiMR6+FIiQ12mS3mzEPRBp0DGg53q9vt1vamzW99/iUmMwW7Hct/a6nsArXtvSrNdfXa3oXKOoCW2jm0ylaKbwPuMFRNTUzo5FaITPEocf3O1ISeDcgJdiyxBaG2O9QSmJtgVICrEFR4YcBX5SoDNkPyBCRSJSUSGuRcYyQf1o5aMdBOw7a8Xi0wwzqHT+nT/UC5XOf+hSmLKnmSybPvkgInth6crFCp0g+v0+4vKB7+CbT2QTRr6mzY+Mjr3zxNe6+70WENph6yt3/4xSfE0YKYr8HW7G7OKe9umb5/LOUVUN7eU3cbmlmDScv3GV25xmyVKwuHpBzpFnMMeIuu3bP+fU52jjmZ7coTUmyNaqZkPZrXL8jDI758QJSYgiJ7PZE11PVJUfHdygnU2K7QhtFDONR3XD9gFxV6OmCGALCe9arNV9ZrTk+u8XRs8+T6waNHQ18ipqb9RppGqbLOdPZjG634/jWMfvNmi5EZuUUIdXoOSCBHBFaUtd2tOaOEeccQkDZVAgybnAMw8B+t6OZLagWy0cr9m78oV8sAEnLbvxvYcyC0BahNJHRWdBnSUAjI2TX4aQfMyeixmrJ0A84P6DaLYKAyJkUxiyNrm3xfYe0hqHbk9sdu5sLvOvx3uH9gMrQTGZUkxllXSMRJJ8xVpOFYHCOICxuP0C3wtYTdFkhhECK8Shel+Ns/0QpvHPsu57kHSp4dHKkNiFDJCaPsKMT6NWD+xwvl1QTSZYC79wotikTh4H1xUOic1hboqVASEFEsr9Z0a63FIsF1fyEoxfeTdV5Nhf38dtrtjcX5OBBa6bTKc20QZDIQpBsSb1oMM7B0IOQ43Gz8wxdi5YClMJUFqkUOSd08kgJ3o27yRQCMY47Na0U0Tu81gghiMkBAmUtMilSlijeeWz6k8ZBOw7acdCOx6MdY6/LO+OpXqDsXMcnPv4JvvX/+naO7pxyeX7B9eWKejbj9a98kfMvfAYjIm6/J8VTdD0F31LhmNUG3+/othtktWB2dErXrvEX99FDD0g653nzjXtMb91BnE7pt9dkpZnUDWHo6a/Oafe7MXypLOnLGpRhslywu7lme71iMp1hixKnK8rJHFVI9jeRrfek4CirGu8dRgliDBRWg+9JvaRv90ih2ay3vPnaPW4fH1MdLZABbq63pODYbHrWm55iNuDPH/LsS++h95l2t8VMNaKc
sNrvuPny61zde4AkoLXg9M4dTNUwmy9pVw/JKSK15ubiAqEttS0IznPz8JwUI4vjJVpADu6tpi5SJn01Yj1GgmsZ9o5+s4Y8WoSHELFGI6UE4RBSoo2mj4neJ4p6jjKWLKDb3BCHDqEluzyKW9nMMEVJRhGiG3+ROEe72eBzwEpNTIkcHN4N+KEnC4E2JVoJTGnHXU1mNCgSGSkyUUqSNDhZkZoCRcRUNRlBejR9oI1Bpox3A7HfE/qONPixByMlrNSkCLmcYmcNMiZMGnj/u59jPp8DkZAhxoAfOmJwuKEnIMBWbPd7khuQSmLqBlOXVDFRz6acnp2ymE2YTCTTuuDmvib7ASUyRd1gtEQoxTD0SKFAaYSdYIsMajfWbzV911PYkqwE1pYo/dU00YwIDmwx3nebgj5EchbEGBm6DhM9yljKqiLFSPCO3Eey80RpyOnpXaActOOgHQfteDzaEfpvkAUKQrLabLn32qvI6Nj1Ld2w45kXX6Lveq6//BAjI9OmpJ5YRPzqDLlgevuE/W7HG1/4Mse3bmODp5lPCfMJ63tr+usrhv2AVJrXf+uzaJmxRtNpydX6mjR42vUVw+BIUjA7WpKyJO521FaxnDYUZGyO5HbHq29+heff9x6aQo2eCdMpJjsmhcELSd00KCEIbiBGB35cxW43W4r5MS4I9ikjTMHFl15ne++CurZsd1uK0rLa3lC5lvmkgpgplCL4nuxAO8dmu+PijWtu3z6mmtRUzYJb7/tWjNKEoSVmQYgRYxRlXVFWdnQ8VIbBO7QShG6HFIKMJmcIMbHbbMbYcGNIMdL3nlW3wxYF1WKGqUuQEmMKgnOPmrB6hBx9AsIwGk3ZQmNFpI0DexdBG5S0JATBOVIYdzYoSUxA0dBMaqSAbnNFdI7o/ZhNlTO2NEzqGoUgdh2qLDFS4LuOICAByY3OlkVdU5hjysIgchpHHMuS5AdC9CAEMQuklFgt8SkTU6T1AVKiLmdMJzP663ukboVZzshGI0SJlJrgHVIahBlGz4KUCSESYiJmQdYKPZlQNg123iPJhN0NbRiwVYNwgcoK5DPPYbUiDHvcfgdkyqIY73aFol6cIB75FmiRGboBoSTFdIrUijR4chrzPMZjXcew22JtiWAMObPW4IaI1AIpNTkF+q4lpbHfJIaBmARUU1DvPPTrieOgHQftOGjHY9EOh3nHj+lTvUCpreLs7AifA5/7/OeZTBp89Fzcu0dVVjQnS3K3fzRxMLA7f8DVxYrJcsLxM3doti3tzTU6tqwffoX55L3st3u+8sZDkpZkobh991lityVtLxGVpS4s67bFDY7ze9fc7HbMlwtiPuf4aM5yVlMUGkrL4D1KgJpULJczmspQVhW6OcKWJe7mASJHqlIzOIeSipASk0JTlZob15Ni4uT2HU7OboF39M0U01uef/bdpBAoNg9hd0MmsVguqZsGYRTKTjDzY9rtllfvvYH3PcvTE07u3MEoKJREk4j9lnJyRHN0m+wHht01ymjatiUimC6XTFKiXV3j+xYpBCE5kiyQpkSKjFYC73oGHxkGT9f3KGtRWaKAHAM+j+NvkBn6HqU0VSUhetywRVIgBRRFSUgZbQsI40x/H9w45iYywTnKsqaaLVFWsX74Jt1+N6aEx0g1adBSInSBni1hGGC/xW93uK4dMy+URGpLUZYoWVBUJfgeIQQCyN6Rc8L3HTlBVuMuI0dNiv3oFZAC3g8MXUe33TJbnaOsxBQlFzcbyqDRpoQ47vzIEZEDAkkzndJ2/Ri6VU/wMTFEidsHCJF+vyf0LZ1V4+cRM+X8hNPn3oOxBTf3vsJ+s6GQgnbbYrShmM2ISSFcy7DfE43C2JKUIq4b03OHdo+y4xF+ypm+7+i7Dm8dQ98jtUYqzTD0BNcjpCHuO4Qa+wukUiSpwBQIW5Hc09ske9COg3YctOMxaUf6BlmgfNP7XwJjidLy+pdeJQ6edr+l2+yo6gnv+45vw293DDcP
CT7RnJ5y6+QWUytwMaFSQqeGlAXdds8rn/0sN5s91+st7/8/30e7XVMZMKbmaFljyoar9Y5227K93tAPCWMqcorsLq6xYeBk8ixu3+GjZL5YcHvRIKqKlCXTSnN0ekTA0rtAGgZy9hw9ewex3rI6XxGcRziBBYzWnD53hzhs6fvAsG/R+8j7v+0DHL/wLqRQ7K4ecvPFT7C9fp2+79lcX5AAqQ31fo+SiqY2VHbJ8viYqqnxIbDbrth9/P9BJpidLNFyTinB6IwuDN1qANeT+g6BQCvFkAVt2yFNQbYKpUq6/Q05AFIQ4njvHINnv92gcsJWxTgNQKaeTDBFzXQ2e+Qq6RF6FFaZM7Zs0CYTYkQISHF8XdoKt10Rhp5ysqCcHoHU+HaDkAqlDd1uy3Q6oSwsw35gdnLG9OwZxNCxPn8DkRLtds8wtEyahkKCsSBjh2gdSoIxJUopcggk35F8T0IikgQyxhhEYfHekXwiDx637UZjpOmM2dFd8uwItb0iuY6rh2/S7bbUkykpeqwxKCEojMBaSwgRTyIj8T5BioSQyVFgVEXWmt5H3NCTC8ckOGwzYXF2m9jv8O2WUhqSMOjlLfTiFDYX9DHjJeiiGr0HuhYhMoMf0DFSpoSWY8qo0Aa0JglBDGMOTBICaQt8jETnQIApPUoZsBU+C1I/0A1P7xXPQTsO2nHQjsejHfvhG8So7YX3vpvttuXm8gaRElJDVVesrlcUUiMGR4qZsppy8twddF3SDY71/Ut0UaKLivbmhnsPznlwb8OLd4+RUnK8rNGuY241tREczRcYLbj/+mts2oFmsaRPivb+BdYY9n1AaIkymn5wbNc72mHghRfuMJ9UWFswf/6MPgvW6x035w+p50uc95RGkGJEKYW2NaoyVGWBbSpEyqzPL8bAqnqCKUsWz9/l7Nm7GFUgpaQ8u42Me1y/4uLBAzqtMUbRtR3T1ZrF8Yzj0wXb9WYcgXMOF8Kje8EeGSJNb6HU3L//gN35OdV8RhaSqqwQMeK9H8fVEGSlcX1P7AOmGgjBQ5YUVUFOA0oIJsvlI+Om8d5XAPPF/NFR7UAWHmEsUij6fk9R1CgpEDEiAZUCaWix2pCwuP2W/c0KU1VMjk+pF7cBRb+xSCUR+YLsO6LrWO83CFMjokesHmKVZHZ0hCimVPuW/cWb5OgodORoUtF3HUNwSF2RY4AUyDGSpUCagtjv8a4fHR2FJOeAkqAlDCIhK4OqF9jjF5HzW8hyAj7Qra/QUmDKCkhjk5+2xOhpu4HGWLQt6Lct7bYl2QI1WZJMTcoF0RSI+SmCgD9/AzEMuP2W0mhw/WhVXRZYUxGKBXJ+hlSK3nsCoFImurFPIerR5RFfIHMixYTPIK2lqacopUdh13ZMH3WOkDIxBqIs0NGRQyArQ1IFQRhChiE+vQuUg3YctOOgHY9HO/w3SpOsd1BVUz57/wu43vPs7RNMXeLbHhMT6/v3kMYwmZbMq4LCaj766c/yiY9/nuOzM5aLmqEdePXBhrIoMEZQlZYcNb0fKKoGXdR07cDVZsOujww+sqxKjm9ZRI7ousS5wHB1jU+RLkIuKuL6miI6ZPAobRAyI6Xm3uUF99+8x6x3CCI0JdubFa4fkEoQYqBuluOdpvdM7txBXF7S7Vo2+47jl8a7vZwCMWW0gOlsgbv9Aru+R4uMJFFMM3VZo61nuahpKku77wjBsd/uGPoOSUakyJE+QsFokd1MiabGGE3R1FRlyWa9IUuDKRsEO4SU2MKSlIJmBoAtLHIyIceMritEhu3NJVIpjJQYZSisBaEIjzIkyAKhNNPFCTl4ht0KwoBIASkE443naC3eS4udnIAqiUMHGRj2DKtLhO8x2qCLChU8wXvO3/wKWwWz42eY3H0fRT1Hmy1pc8n28pIUIZUKawqk1WShSTGipQQtCAmUHA2SEKOzYkjjWCMIqmaKUBaTNfroeXQ9Jbc72s01oV0Rhz1GQFFYEFDWYxNd9BJR
NZjFKSl6yqDw+XoUMKkZXMKHTGVqdL0guZYsDTE6gtuyu9yzvbkGIEtNEpn1+pztvWtczNQyYIXAaEUeWvouYIoCZUpKW5GiIww96VHHPUIQgyd6T/HIBjvn0Sxq0DW5mRC3N9i8B13hkyDqcdQzpe5xPfq/Zw7acdCOg3Y8Ju1QxTt+Tp/qBcoXv/AaIXq8CwjvqGNPrRS3by2QvaPf3XDnhedYzCum0wqtNc8/8wz3Xn+TqUl015fsQuToqOFo3lBUBTe7AakK5vWULCXbzrNfr3B9x9GzzzJcXnL9+kP2+xZdFMxPTvDe06eE0oYkFM1yAfs1Omeyc4iiRMRMu7lBu45bZ6eP8g8SqAI/eIb9jqEb8Dnz5he2SFOwuPscs/mSEsm94TW0hmF7xfbNz2O0hZyZVCUpB/r1CuUjKWcSmenpCU0zxa8fEPctSQrc0LPb9witiSGQc0IquH79PjfpPuV8ztndu5AF280G+chWOWUw1o5ZI1IQ9luqssADQmtsUVLVk9Hp0Ac8ib7tKMqKLARkaDtHyBKpNcoW9O2enKCaTgiyRBQlKkFubwhpbHQLWdK3HVFoimZGDoHVgzfQWlIUBbgeISLSaMpmQjM/wXctVw/fZLvbcd3uWXWJl07uIFHkfo8gIbVhyIEuQGFHR8qcIjxKoRAZlMjE6AmuHwXZluNYqVJoNe4qMgqhauRsQQiBqy9/lmF1DkBTV3Q5ochM5wvIEgnYukEVJboo8UNGFoZq2pASrPY7YlY00wpJIPd70rBHJEdVT9DSsluvuDq/ZLaco6yh2+9587UHnG89Qhleev6MxWKCMRoVHK4bKJWmqGqGkEg5IzQgE8oohJS4/Q4tFTIFyKCtIUeJzBLneogRpTUxRHz0SDNBa4vI8vE8+L8PHLTjoB0H7Xg82iHUO9eNp3qBcnF5TV0pjhdTdkNL3G+RBdQyUk0s2lpu7r3Jwj6Lzon1esvNessLd59hLiIXF9Cvd5wdTckic33TUs7nlMsjZIqkfk+bEnp+gmwc26sV+23P+cNLyknDu77p/6Q8OuX6s58hDo7FYkq767ha7XDXO46bgrKZUChF6wNdipRNSWkbLjc7hJkirWW3vkR6TxgcxfwW8dYCvTyBqmY/7HCbDfQDMgWu77/CxWufp91s6fct8/mU49u3CENgd7Om9R5VWnrXI09vU5QF9WJK9J5+76hrhS4KlFTMJjWT+ZTddkuO426t22xBKeqqRCjFfrfFuR7nA1lkpNJkqfAho6xGSIGQipwFXbfHdS3KWlIMGFsgpCKEgNLm0W0yyJzHY2IXMFU9/rLThhQTInqssfSDI7qA7ztUqdApcvXqq5RWYWcTspAE56msJqfMpGgoi7GxzBQlcsgkD+12w8UXP4lbHo9R4sGTi5qkJcE0CDRpaEnBIYQEW6BMNdZrLVoVZNuScyILUAhC8EjEaNPttsTrN8kpk9sbcrtC1w1JNRT1dBzBbKa4lMF1CBLB94S+BcAaiWkqcsokkcimgAwie7TvCSmgigKhJN47Qhx3v95HnPR0Q8AYwbQUhDiQhjUxnlA2FbooMNZg1KNfKikh8jgKaaqCmDNaKlRRkUOAokY2DcrW+M6jnR99GMp6PMpWBeXRLcrlHZTSJFM/XgH4PXDQjoN2HLTj8WhH/40SFrhYNswmJd16w/xsydVmw/VFS1UYju6MxkL9fg8h86XPf5kHmzVqukTdeoarN96g856XnruNLQ2bbYsVEeMdarfh6vICHSLPv//d6HrBzf3XuLm4YvAZVc6YPX+XW+/7duzymOgS7uJVlMkM24EYxnvSVx9ccL7ZcbYbCAiCUKOlcnY082OqsxdIfqBvOzarDQhL8+xzFMtTEBqRA7t1R3d9TlkqyqNbvHnvIZfnlwD4weMFqGZK2zq++KXXyEJy+/YRfugpqpJn5rfJQjO4HYlxnn8ymTCZTCisZXFyRjPdcXH/dfwwMPg9Uhn6R6t9QSY/SrTU
WmNtgRaS4DqklEhTEYWiH3pcP5ByRgmBridIU0HWiG6L991o0ywqOhcYBofvemJMqOkZMoMbeuJ+hxJQVjVSBiQQfM++a8kyse0dhRCjQ2XwVNOGSWHpdmusVcQYRjOmzZphGBA5cvngPnFomU/nCCUxSo/x8DGSkxuDuRJkPx5d5yyJKSJyJEeHNYoUMlIKlJKEkBGA845+dUFaXTBZHHHnmSPc3JKEwSUxChejiVFod/j1DbHSWFtSzJeEFAk5YrUmiURlFAnP0HVkJCIlQvSkBIMLaFqMkhRVSTdEZqenTO7MmczP2W83pMz4dSHo+0BdjKFdVpvx++MG3NCjtSUZDVLiYyQB2hYkaSBLUojEFCincxbTI7rVNTdXD7DTJfPbd2nmR0jxuzuqfdI4aMdBOw7a8Xi0o92/86vhp3qB4lY3fO7VlklTc/tszvz0hBAyhQFVlKiUaI4jl67ntd/6Cs28YVmWhJuB4BInt05ZHi9RWnH27LMM3cDD117Dra6pZaSui/EB6bek3R4hoCoEt176Fk6+5Q9jZlOsUZSzOdevDtSlxE4bllZxdvoubi7P+fKXXuf1i89T1DV33/sSVVWPR6BZM/lq05S1CGXY7TombYsqW3wIRN/TPniD2O0xosTkjJaCqi4QSJIx1IslQhvabgVCorUkx0BVNmM6aNuyCY6u6zHNDPqOFBwg8Xk8Ol5dXdK3w5gsKoGUcN6jqgpblgjsmAOhC4QU1JMa7xQxelL09G1LCn6c0w+BvneYusFOK6QtELEHvyN4D0ISsiAqg5eBbr9BXN1H2xq3vaZfX2G1YrZIiJzJ/Z487CkLg7l9zHq9Zb9r8f1XH+CIqhWQ6HZbIpm+2xOGjtA7rLU0zYSqqom+g6jQZUUKAZQZw9PMaCaU91e4boMICV0vMFIgo0ekNGaSZIGLcXSVzBkhxp8RKSXTAop6Bosjun3La6+8Qrff0UwarAhk78mPdjnKmtGYaujIKdHlMeQrukDbdrRt92jHqBDzBeV0AV1PhaKZ1kzShCFqimqGIqFgvNO2BUVhxwY1IUhivKcPKRIe3RvLnEjBM7iArGdkbaCEFPyYMrvZgx2bLrUpSGkMolO2xlZTtNJk56CoMbZ8nI//74mDdhy046Adj0k7fnu09f8HT/UCJUZPDImbmzWFhLJQJO/JVnMjoZ42aGPGVXdKtOcbNusWLRLLWY2qpgxtR1mVzI8nXPUD15dXKAnve++L+K7n4uKG+fEMpQRCJKSULOcN09qgZWJ/s2L95ut02z06lySlyfs1RxNLUTUkJINzTBaGMPRcbvcQHHG35vziCq00k0lJOZuT6intbsvQ7klhIAw9fnXNxEi0tY/iqgcEIJXg5JkXkUXD5vqK2O84XdbMlzOKwpBiYHd1Qe53nNw6xeoSVdXkDH5oyQn2g+Pq/D5CKybzJcPgiM6RQkArhTEGJSVKMLowxjGFtBvG2ff9fke76/DOMZtNsFXB4B0xRgoBPkbgkmlVYG1BWRT0LuC6Hi8kSWpSzuzXF2gpxshz53BSoZRAC3Dtjt1mgygLdF1DiLh+oJ40uM6xv9lQG8vi1tkoYK6nnM54rp6xubpGpMikqSgKi4iCbnB459AqoeSYI6HnJ8isSBKGnFC2pCxrSB6+epfsHClFjNGU5ThSaLWmUBKhJFlXBFFiBYgYsTJTLUebbGs12Rqas2coFmek/Yru5gGDT+N453qFeLSz8y5iyzGq3WiLrRq01nTnK8TWE9qGqAtCdNz74qdJZEJIbIbIdH7EUum3fj52mzVtHFBSjE2GchSdvu0QWSNNQjQF2hSE3QrhdsTUk7wkDwHtoc8ChUDacdfo2h1xv8PWFXF4etOMD9px0I6Ddjwe7XD5nfsnPdULFFFqTm7X7FdbBj+Qk8SQMbWh3e0piwotFKF1aGDftSgsd24fcbRoyCGyX92w3yi6fcvV5Q1d2/PcM8e4ds+XX72k7Rzz4znT4xk3qx3X
FyvK177I/M4d5MkdCmOQ0bG6vmFa3eH283fZ3nuNB/fPkVVNOamopzWL2Yz19Yr7b5zz7O1jglacbzY09egtcPbssxwfnSKV5MGrr3DvS19GeMdsXiFs/ShQa8KyW9CkRD0/4vS5d1HNT7i59zrnWtJ1O4QCW5SklAkxoOsJGY3Iieh6EBlTFaA0qVXkAZTWKG2xQlDWE1zXkX0/HmMKQfYDfTdgpKRoGnof2W63bLc72q5DSYUeHEFmRE6klOnbFh0CpVXE3iGLEiUNru/YrddkRnMkmRNCJlRh0VqRRUEOjm6/YVI3RKGJ1dh0uF9tCb2jLCrKZkJZjtbJ8+kECWSlsdUEHXqUUAQ34N0Y3NX3e9IwEKXGSA1K4VMGZcbQu26La7ej8CiNzP4ta25dFAiliENPVSqapgapiUNHaQwZSStLWl3j2h1+6JlU1RhmlhI5BmKWkDU5K5zPZGGZnhwzERLbXOCHdoyq91BP51hrkCmQokf5NbkxxJBxSLS25ByIsSdJhalKFtWYYyKDQyqIrme3XmOVZDKfoWxBCAGhLLqQrLc79pfXlIsF0+Upqt+Shw5lNOJR3ogb2jFYzI5NmTlEkuvx3Q4tFC6+83HBJ42Ddhy046Adj0c7hv4bJM346OSMsih50wUqa9lt9/RDz9nZAtc7rlYblospVklemD1DCgERAioG2vUGawtSlkih2LYDQ0rYaryHc/WU4lRQDAMpCYwwlHXB9RsefbNm9sqXOR56mknFYlFjzXj8V9UT9NERw16TjWZ5fEToBtY3K0KIPPviGftdy27lmT/zLqbLObNJzdHJ6ZgbIQQnZ7eR0TNs1xRGYkQmx4Rr90yrgi7BZDZFyoRbX9DfnCMJzKc1UgtMVRMiuOCYzqZoPYZN+a4jp8jgOspJgzEGbaYgQJCoJyX17DbdvmN3/RAB2KrAD4KimkIG5wZ0UbAwlqKs2Wy3CCEwZryr5NGdMykRnWeIHgaopMQoRY4Bay0+eFSGFBJd2yGUwRQVOUYieTx6lAqsoqjnWGtJQ4vb75FFgW4mzCYTwn6LYYz4VsZiywnZSaJ3VM2C+qhCSMGwuqIbWpSREAMxJVxyqKtzqv2G/c0lcdhRWItEIPx43K20QihBjonBOzZthwsJZStyt8UoiS0bYt/hcxjvl+OAVBlVTKimS5StCUOHb1cM/RoAlQPCDwilKZuaalITkiAnQEoUCS0MJDO6aUpQ1FTNBKk1JoEyHTFnfAgUSiL7DW63pZ5WSKlIMZBUgSlKjDXE4EhIXBLsu552v0OQMHEUbls2SK1AKaSUxKyYVAo/9Pihww+efc7jAsVYsn565eOgHQftOGjH49EOF/M7fk6fXoUBvvBbr3E6q/Bd4PpyRUpgVSJLRdnMaJ2jmi+x0aOVpNtsUdqSfc9+3ZGz5OJyhbKGqq7Hph9peeN8y92zF5kdW2hbVpst4WZLM5vj1UO++Mo5XR94/uacxfECLQ0nL9zBnkzZ7taEbkfqO2IXyGkcR4tlzdGk4u7dO5yf33DvvOWZb/42ju/cQsaeQo15DSEmUl1xdHZKPpqjU0TER81PQo7jiZMFy9vPIbRic/6QECPN8pgsBH3fM52fUDQN2812/JqSQKVIcANKSXIW7HcthS0wZU27b+mcRylwu4gpDETHbrtBbS05Z5rjW5STCf3mBu97BILJtAEBXdcjhKSqJwgpgERKHqssxgiEd6SU6LoW54bRd6AsQEAOiZQi2hiULZnXzfj+/WbcCeiC1nmG3QaZE2XTUM6WFM2E0pqxAS4mpBIQI9H1JB/wIaCMJmuJjwHUKL4xhvF+WMAwDFy99gU0GaUERVUSQsYNO7IxSD3aa2fnCW4gxET2AeccZM2w2RJdx3xxTFQFod+AMkxnMwq5JOsSITU+gbAVsh2b77TSBN/T7dYoWzA9OaNaHDEMkd3mhuTH+/kkJbYo8EikEUxnc5CKDHT7Pd57tFEoAZUWxG5LoxONEWz3
ewprkGbswu9zxLlAzOPur5zMaGYzCi3QSpClJAiF1ZaUxrtyaywpJ2IuKKopRbMkp0jYrzBljf9d3CU/aRy046AdB+14PNrh0jdID8pr964oOGIIiW3naMqK6bRgtd5g6hn1dA5SIhPkEAi9Az2u8KaLBbvtHmEUCMbOay0olzPU4GgvHxKHHmJCaouwBboyvPD8ktU2cfeb3s3y5BSrK7SWiKokDFu2D+6z2W0hRIyWtLuWoi5YLJeE5Lm5umZSWO5+03uZPfccVWEhCHA93vUE73F9R0qZoqzBD0CmqmoG76mWxyyfezeqbgjdwEYKRKHHpjApKGSDyHn8IVGKm5sVdVPig2MYOurJjHp5ijaKfr0m9gPT2ZSyOcbtVww3F4TBYIylKCv8ox/43erqUTYEhH5PzpmimmC0oCPSdz3VZMrx6RkxBrp2i8gAkZgS3geUkghGl8HZ0RL9aDwwPXKcNHVD2czx62varoWoGHRAmpKinoJUzBcLrC3HRjUfxuj4HJFS4LqWsL4iI0laE5G41RYlBFJGsu9J3tOlACKTUiD6QERSlCVIRQiJrm1xWqGMQQuFHzxZZpCjqVQOns61OBfw3UDvLxBFQz2bcnR2RL2YIVC47ZrNxTnOJfLsWcTyBUzwuN01oRubGbUQhKEn9AMyZ2QeQ9cCiZwSwQeGYaAqCxACISXJe0Lfo5WiKCsqIVE50g0JoeQ4IdD3SF2QU2Kz3lA3DcY2kDJWaOr5grocnR97P4zCFgaMNShTkBEMw0AARDnBThYU9RwROrJwj4yjnt4rnoN2HLTjoB2PRzt2bf+On9OneoFSFRopJH3okI8asLQcTXte/9IrLI6PmReanBNvfuV13DCgrcIgOL21pGhKTuuS/b6lmk7JSkLIuMHTXV8/itlWTI8WyLIECSdHS2Lcsb7eomzBbAJgWB6fsr1RfPG13yLGSFOVDJ1ns9lROYcUkhQDm8sVz7z4Iscv3BmTJ11HevRN7vZ7+t2K6B0AXQgIEiJECgnkiLUGUqS7eIP95SU3lw9JMY47C23QVYkk095cMfRuvMN0Y5hVVc8gZXy3R1EjRMa3e2K/J2632MKMD1MMBDEme2opSaYgkUkxIGJCIIg547zH+fGu1ff9uKouDCGAlDVaWcgCGo8moqUgCUU/OJqqxE6m9LsW7x3Oe4bdHuUD64dv4Hcb5LShSw7TSKrZEokkhkjMe1T0IOV47Ny3Y0hYOcPOn0FrhRKRoff0OSCMRqcBTUZITUwRozXTskbmQAjgXCTEPBpgBQ9RkmNkSNAPA0aBSh6XHr3PVoiiwuoCYQ1KCmbzKcvTE3RRQQiooWUX3Jg3Mmzw0iCNBVuTdYEREgF02x0hjiZHWoDUCiEFIQSykJiqREiF8wGdEsN+h5QJrQvC4DBGE4RAlQ05jk2dxWw+pswicD6DLkEKfBgoy4q6acZfaiGQkxiPyY1BKgkiAYKQPH3K5G2LdZFCRKzJBJlxXUu32T22Z//3ykE7Dtpx0I7Hox3Rte/4OX2qFyiNlez7jn6ILGYTdusdOWVOppZa3eKV1y/5rfbzvOd9zyOVIBuLntXkwTH4AN7Tdx1FWWCURNdT+t2ebrcfH+zCsDhe0hwdE4wlh54iRITecXH+kHZzxdlywezZd1GdPUd11rDYD7h+R2MNIgUWyyn4gaooMGWFiwLmp3Tesz2/R1OWKK3HO1NjSdIScoDkCa5FKkWpFZHM4ALdZgU5s705Z3V9yepqhSlKyrNTmkmN0AYfEhlQGfJuy+WDK2xhWSxmyBzZ7/f0yqDLmv2uJQyOGBOzxYyu64g5MpvP0EoSQ0JrBSngujHgrCgn6JTYdx3DENBmvGOvHpkG5RDGPIfk0dpgrcEISRh6JAmRA1pBXRX4zhGEppAVcb/Cry9odMDeOaUvp0gPrh9w7Wa8Y24ldV1jTYk2FlNOiFKhyoZycYtycYb1LXJ3SeeuWEdHFxPWZIpmgpAVpMRRJdGp
JyXHer0n5HHkzhQlirFhT2lL9B4hBVJrrCwRAmLOJDk6W0qpUMaSXEfsdgyrC3JRjLuXrkUaiwoDYX/F4AaEGqcbqrJA6wYpRkdNqS3ROTIJaw3ej2ZUY2S7hNJgjGLYbWhXV1RVzdgAoMhqzObAjxH0UkhCCONUBAKlDFpmQnCkmLBaM+z3bDcrsAUheBJQ1NNxkoQ8RqvnTHQDzgfa7TU6bbFHM0SOo+21f+eGS08aB+04aMdBOx6fdrxTnuoFysnZMcuTY4b9HlIku5aqkAihOD1VZDHltfsdu82Koi7o+8Dk6Ih+s8XWFbsH91ERZlUFpub07v9BIuIHz/biggh4H7i8uCSakmZak4HltEJJgQueKBV6Pgc9Hh3OZ0tcDhRNjVbgt5o8KOrSIosJ0+M7FPNjQhjwrmMfHLPFHGtKglfIssJohZFjZDcxUpcFxECFpG07+qEnxkA5mbE0NV3bEpxjv1mjyxpMhbTjvH5Onhj9GP3tB3wMSFtQLY+w1RRRVOAjzjt8jAxZj21m1Yx6Oht3TNExDO145C3FaMsdRqfAEDymaJgtjmjqGsnYse0eWUNHqWiHHqsV1hYMfY81BluOeSRu6PBolNT4GGh3GyaVQZgC7z2b82vWl1eY0mJtydnzLzJ/7j1U8xMEmvbmAawvx2Pq/Z7cfglVCAo6yiLSsufBusdNamx9gtIl+DA+lFHi946UM5W1lNaiiwIpQcU0ulhKQWUNQmlUNUEqxbDd4va70VSqGH0elB2bFNfrHSqticGjbIEwJRaFjJEsYeg6Bu8JSKqqQGhJWRVIa3A8yhhJieD68c7WOxCSsi7H3WbwKDPeDwutKespIUt8CHjnEGm8l9+3e1JKxBCZT6cwjI6eqigIKeL96OUQYyILhamnqNltlC1x3Y7+5gHdaoUnY4sZpigIfcfqYYcpLFoXFOadx6Y/aRy046AdB+14TNohv0Gs7quqYj6fYCaGZtIwbyxlHChmE1zMFKVgeZwxWlJUFXpasphOccbQFBbtHDIGVNWwfPHdLJ95YRwHmy65enjBvu8p53NiCOz7HSFLipSwUnJyfMymGxDVDKns2ChEot2t6S4f0q8k9aRhe3NFZS3SFDSTJcX8hGYype92kDMxRNwwoAQURpOTpQ2e4DNKWsgO7/woODmjC4s2mizGhrXJ2YTd6obN+X3arqOaRuqjGpEF/X6PyJm6qSmMoqkMPo4W5XZyhBCanNd07QalxLiDIRJDYmh7YjVamecMMSTc0JEqkHl8ALWW1JMlzWSKUWZ0joxj5LrI5ejHI8RolT14YhZjEqwtcC7R9hu22x1Ca0SIxBDR8zNaY9n5xOX5A67vvYkSIKxiMp2MgWGb9Wh2lBX9+QM2V/dBRlIMxN0eokclj5SCwSd00WCa24hyAs6hc0DpGZ7MkMZdjCZRaM1kMkEqyF3HxvUk79DGQjkl1AuQAmKGoWW/2aJjTSonVM0SYUqCa2HYIGPERwdhQAqJVBqlDFkFRByvSBSJpATeDczPblE3Nb7vGboe3w9IqcZRVK3QSkNKaFNSlhVCyTHlVGn69Ya+64kpjGZZxqCripwSZc6k6HBRoZqagEJVC1RtCFcPCP0WUxQUzRxjC1IM+HZHGjpyBmsshZZoJRBZEKNHhTzaYP8uhOZJ46AdB+04aMdj0g75zjc2T/UCZVJq9g8fcnx8xMnpCfPZlOv796CoEQhslDw/bQjeMzl7BqEUdRoQ05KhGxhiQBnD4plnWRydkPdb9ptrrjc9X77o6ds9R7dvM19UpKB58OACGyPL+YQ773qJiVQEJKpucP1A227G4K6+Q5YGLRoWp3coZ0smx3cw0wU5J4J3FEWJtWMjUg6jkY9SAmUKpAmEmBmGDhk9SUAMjr7tsGWJVFOQMHQ9/eCATDmp8UPPZn3Drh9Nk9x+C9GD0RgpsFqjrSLYirKq6bqeoe/H4zgf0bagqiRt29Pv96xzpKlLhr5nt2sRSiCM
wQhBTGMIlBR5PGoU4PoeqSyyKsjKw7AfJwmGjhAzPiZMgkUzxTtPTAKjC6IPiGpBdXyHqCqC7xjO30TFh+MOsa5plsdMlgtEErzx6udwuz2hH+jajsH1TOcNy+WcmCLr1YoUI5O6Rhs9ehWQUEqijIJhh2u3uOBwLkCIKDm+XhR2vFtWMJ1N6PqWoR+IQ08Se3JwqDAgHzksFrZCiPHuPCtBkiXRZKQs8K5DBEchMyl5ggtobTG2witF9D3Dfk8MAakUzWw5Cr1UqMmU0I9Hv6WaIlPGdXt83yKqGqsLpFAMfU9KGVU2pJSQViEQyJzJIY55HkZh6xpRlOiYSEhICaIfE07iGC425DhOfOQAUqKLiqKZUtUNgozImeQdIgwMbUsYnt4rnoN2HLTjoB2PRzu69hukB8VqwdAnzi+uMZOS2WLO6Yt3aXd7Lh9ecu+LrzCfT5ksZ5xMjyB0hKtLum3L0Dv6zrF87oz67DluVhsuvvQKq67n/HJFj+ZyEPzWF1/nO77lXSSt6Vxgtd6SlWYaoaobaq3RWuOGlna7JnmPEFCW4wz54rn3U996AWEKcnSkfkN61PwklRpd+rQeY6uDJ8eIYOy41sYgRUaITFHM0MaSHhnw5ChQIoFMaGPRsmHje/ara1bXX6GsC2bzGUlJUsq0bce9h5HZ8TFaM6629xuMAlEVODcOjU7KAqsUgtF0KMWEi5HWO2pdvfXZG63HaO0U8H1LVBo/DJS1oShrht2W1cN70HYMPlAuj0FD6lrajUIiiEIT7JTq1ovUt17AHt3GSEl3fcl6v4flAl0co45uI01Bcj2JhKgm+NYh7OhyqHxJWZdIU1Api9SWECKkMeRqcAN5fTWu5L0jbW/w+1GA0tDi+w5Ppu16lEgsF1NMU1Iog40Nu+2ezXrHcLFCJIdUGqMUy9MzqsUpEUU39GihRxMr3ZBEScwGlTLZd2iVxztnZRBaofWMbuNxPpJCYHV5yWazpZrOx6NSF9lfn0MOY39D3+Jdj9KKlDN93wHjzhVlSGL0JUkyk12PFiCQmMKMR8lFgTKWEHramwskoBUoYSEnYrcl9DuElAgh8SEiTImZLFHNDGstSmkYWvrzN8mhQxn7uB793zMH7Thox0E7Ho92kL9Bsniy1ThrGHzmcrUj5sytkxNiCHRDh60rXEq4IbC/uaAoDOv1FrdpEdowfe5FFt/0f1E/9x66N1/j/s1/Y7/fk3OgNmPY1dVm4MuvXuKk4N7lhsZoKEp8ShTB03f78UPMGZUTpIgxhmoyozp9gerkGUzZIJUmDokQM8E7tNGIDDkGnPd0XcfQ7ghDi1Cjt4Kt56Ru92jsy6BtRGSJd5Gu6xBKUxSK0HcIIZg1E9StM5qqYvCOFDMxZqSS+Ay9T9jeY/Oe0I0/JEYJhNTYsiI4Rxx66rLC9+OKWBkNUmHKamzIenSsr5SmKCtIkIZ+7MzOieQG2tU17eqGiIFGM22m4w+uHyAn2t0GIzXZ1pS371A+836q+QxtFEYqYlWjtMRajVQaqxQpJsL1FcmO8/b18oTT4xMKK9lt1sToUEohsqJKkTAM7Lcbhr7Dak0YBq7ffHX0QxAZENhywvJ4xvZGstlu6bcr6tJSWo0qx055bUumJzVGF+zzBaAgJwYfESLT7jZ03tNnhc1qzKAIkdDvid2KUnimhR69Gswo+CE4EGDKmubYQErjL6CYCDFCCPiuJUnDZHFCzmO4mC5G2+/9vsX3PYUpIWWyCHR+YOcySsFkYimMhfTIyCpEJrbBh8x6u0dEjy5LTDUf76qHftyJZ9BK41wgZzD1FExJlnrMRykq0lDi2w2py6T0zscFnzQO2nHQjoN2PB7tkO4bZIESI0xPjnBXe3yCy8srXNvh2o5htWNxckoxnaHLki4pVBbIuuF4cQs9O2V5992UsxOyGCO/hdTInGhKRag1+WxGdD2b7Y6H24H1puO5b7nL7OSULAQpeXIOCMI4wtbvyMFTLU9Y
PPcejp99N3ayQBuDyIl235PCADmS43hSFlIc/Qv8QDv0uN2ebrdhNp1ytDxC5EhhC4qmHo9Th5YUx/fzaMeUQ0Dk0VZ4uVhgq5pdN+AGx3azBSEoioKqLMg54vsOKRVlYQjeo6wZ3f+UHh0pMygrUEWNtJacgUfHvKRE6DusEsgckWSGMNaiqwYzmSCkpFksaRYLXAxIYxg2O/bbK0AyrwuEFmOo2dCRcyLGgBSCITrCdoVWgmI6JUeHCQMhRPywRzhBkgqJQqaAzpbaFAw5EVNCiDS6IOZENZ3SLBaIFFndXOJDoJnMKJTCFpa6rqgKhes62I7fi/3Qs2132NBjraCcLFHFlLIpcfuS6AOkgESRlaZ3PdvtBo8Zx/FSIHQdbnONCHtEXaKbEi3HX1opRWJMBCHJQqMKQ2EUOY6fAWI0WCqrmiAMInr2NxfookLrAt973DAgbQHNjOwdbr3l5nJDFxMnt46YNEtkGpNGY85jT0LUSCTKVoTWj7HschQ/U5TkR1ktSInEYAykJMYd7ERBGndV0ljs4gSvFKTtY1aA//85aMdBOw7a8Xi0Qw1f51M8OY9WueeXK+ZJ8vpr93jp+TNO5xUCiesGpLI0t16geeZFdDlhGHr67QVO7kAWHN+6wz5mhqsL5DAQbx5y1Aj21y3nD3dUTcP7XjglDVusyMTXB6bllKM7tzGTKdE59n7ASkG727PabNhe3GdiS47uvMjs1gvIoiHGSAiebnPD+vIhIrpxZwHjHZ8U+JAIKYGQqKKgu3K0u/uQx6AtHwLT5ZJCK/zgyD7S73tiBibgvSPlRFqtxlCwDMIUaG1AqfGHLWX2XYfVhrquyEJws1oT/ICSmqw0KEMzmQKClAJSGmR+NHY2+NF1sqiQwrFb77i6uKbvW5SUFGWN6gO6i0xOzjBHt4hdS3f1kNz3o6mSKZHK0KWI2/Xs+w0PVp/kjCnNyRmp3bF+9UvkuGO+qBEpMux79vvRhGp1fYWRAmOKcZLi+oJOCtq+J8SEMhprDBKBCwFrNNZoos9jVkcq3mpK7DuHbztIjq53rFZbELBZbUgx0MxqJk3JtosI2zH0PZvVDmVLVBYIoRExkzOEmOn7PT7cw5YV+ADBYZUkZVjv2lGgxWjBHYUiKUP0wzhBksaGx5gCSmlCGsaLeasY1htCiCQL3WbL9mZFVppqMSPqKd7v2cYtdj5nWhgmlSG0O0KICG0R2tB3Dnd5jjEaKSC4CDhc2oKIKCFJaWxeRCq6PuBDBLGjjBGtJGkYGPoOJQXZB5JQ+JTe9jw+DRy046AdB+14/Nrx25/F/x0iP03q8ogvf/nLvPvd737cZRw4cAB4/fXXee655x53Ge+Ig3YcOPBk8E5046k8QTk6OgLgtddeYz6fP+Zq3hmbzYbnn3+e119/ndls9rjLecc8jXUfav6DIefMdrvlmWeeedylvGMO2vEHw6HmPxiexpp/N7rxVC5QvtpsNZ/Pn5pvyleZzWZPXc3wdNZ9qPlrz9PyS/6rHLTjD5ZDzX8wPG01v1PdeHqdlg4cOHDgwIEDX7ccFigHDhw4cODAgSeOp3KBUhQF/+Af/AOKonjcpbxjnsaa4ems+1Dzgf8VT+PnfKj5D4ZDzU8eT+UUz4EDBw4cOHDg65un8gTlwIEDBw4cOPD1zWGBcuDAgQMHDhx44jgsUA4cOHDgwIEDTxyHBcqBAwcOHDhw4InjqVyg/LN/9s+4e/cuZVny3d/93fz6r//6Y6vlv/yX/8Kf+TN/hmeeeQYhBD/zMz/zttdzzvz9v//3uXPnDlVV8cEPfpAvfOELb3vP9fU1P/ADP8BsNmOxWPBX/spfYbfbfU3q/fEf/3H+8B/+w0ynU87Ozvjzf/7P87nPfe5t7+n7ng996EMcHx8zmUz4i3/xL/Lw4cO3vee1117j+77v+6jrmrOzM/7W3/pbYx7D14if/Mmf5Nu+7dveMiR6+eWX+bmf+7knuubfzk/8xE8ghOBH
f/RHn5qav9446MbvjadRO5523YBvcO3ITxkf/vCHs7U2/4t/8S/ypz/96fxX/+pfzYvFIj98+PCx1POzP/uz+e/+3b+b/+2//bcZyD/90z/9ttd/4id+Is/n8/wzP/Mz+b//9/+e/+yf/bP5pZdeyl3XvfWeP/kn/2T+9m//9vyrv/qr+b/+1/+a3/Oe9+Tv//7v/5rU+73f+735p37qp/KnPvWp/PGPfzz/6T/9p/MLL7yQd7vdW+/5oR/6ofz888/nX/iFX8i/+Zu/mf/IH/kj+Y/+0T/61ushhPwt3/It+YMf/GD+2Mc+ln/2Z382n5yc5L/9t//216TmnHP+9//+3+f/+B//Y/785z+fP/e5z+W/83f+TjbG5E996lNPbM1f5dd//dfz3bt387d927flH/mRH3nr75/kmr/eOOjG752nUTueZt3I+aAdT90C5bu+67vyhz70obf+PcaYn3nmmfzjP/7jj7Gqkf9RaFJK+fbt2/kf/sN/+NbfrVarXBRF/lf/6l/lnHP+zGc+k4H8G7/xG2+95+d+7ueyECK/+eabX/Oaz8/PM5A/8pGPvFWfMSb/63/9r996z2c/+9kM5F/5lV/JOY/iKqXMDx48eOs9P/mTP5lns1kehuFrXvNXWS6X+Z//83/+RNe83W7ze9/73vzzP//z+Y//8T/+lsg8yTV/PXLQjd9/nlbteBp0I+eDduSc81N1xeOc46Mf/Sgf/OAH3/o7KSUf/OAH+ZVf+ZXHWNnvzCuvvMKDBw/eVu98Pue7v/u736r3V37lV1gsFnznd37nW+/54Ac/iJSSX/u1X/ua17her4H/N0Ttox/9KN77t9X8Td/0Tbzwwgtvq/lbv/VbuXXr1lvv+d7v/V42mw2f/vSnv+Y1xxj58Ic/zH6/5+WXX36ia/7Qhz7E933f972tNng6PuevFw668bXhadOOp0k34KAd8JSFBV5eXhJjfNuHDnDr1i1+67d+6zFV9b/mwYMHAL9jvV997cGDB5ydnb3tda01R0dHb73na0VKiR/90R/lj/2xP8a3fMu3vFWPtZbFYvG/rfl3+pq++trXik9+8pO8/PLL9H3PZDLhp3/6p/nmb/5mPv7xjz+RNX/4wx/mv/23/8Zv/MZv/E+vPcmf89cbB934/edp0o6nTTfgoB1f5alaoBz4/eVDH/oQn/rUp/jlX/7lx13KO+L9738/H//4x1mv1/ybf/Nv+MEf/EE+8pGPPO6yfkdef/11fuRHfoSf//mfpyzLx13OgQO/rzxN2vE06QYctOO381Rd8ZycnKCU+p+6lR8+fMjt27cfU1X/a75a0/+u3tu3b3N+fv6210MIXF9ff02/ph/+4R/mP/yH/8Av/dIv8dxzz72tZuccq9Xqf1vz7/Q1ffW1rxXWWt7znvfwgQ98gB//8R/n27/92/nH//gfP5E1f/SjH+X8/Jw/9If+EFprtNZ85CMf4Z/8k3+C1ppbt249cTV/vXLQjd9fnjbteJp0Aw7a8dt5qhYo1lo+8IEP8Au/8Atv/V1KiV/4hV/g5ZdffoyV/c689NJL3L59+231bjYbfu3Xfu2tel9++WVWqxUf/ehH33rPL/7iL5JS4ru/+7t/32vKOfPDP/zD/PRP/zS/+Iu/yEsvvfS21z/wgQ9gjHlbzZ/73Od47bXX3lbzJz/5ybcJ5M///M8zm8345m/+5t/3mv9XpJQYhuGJrPl7vud7+OQnP8nHP/7xt/5853d+Jz/wAz/w1j8/aTV/vXLQjd8fvl6040nWDThox9t43F26v1s+/OEP56Io8r/8l/8yf+Yzn8l/7a/9tbxYLN7WrfwHyXa7zR/72Mfyxz72sQzkf/SP/lH+2Mc+ll999dWc8zguuFgs8r/7d/8uf+ITn8h/7s/9ud9xXPA7vuM78q/92q/lX/7lX87vfe97v2bjgn/9r//1PJ/P83/+z/85379//60/bdu+9Z4f+qEfyi+88EL+xV/8xfybv/mb
+eWXX84vv/zyW69/dYTtT/yJP5E//vGP5//0n/5TPj09/ZqOsP3Yj/1Y/shHPpJfeeWV/IlPfCL/2I/9WBZC5P/7//6/n9ia/0d+eyf+01Lz1wsH3fi98zRqx9eDbuT8jasdT90CJeec/+k//af5hRdeyNba/F3f9V35V3/1Vx9bLb/0S7+Ugf/pzw/+4A/mnMeRwb/39/5evnXrVi6KIn/P93xP/tznPve2/8fV1VX+/u///jyZTPJsNst/6S/9pbzdbr8m9f5OtQL5p37qp956T9d1+W/8jb+Rl8tlrus6/4W/8Bfy/fv33/b/+cpXvpL/1J/6U7mqqnxycpL/5t/8m9l7/zWpOeec//Jf/sv5xRdfzNbafHp6mr/ne77nLZF5Umv+H/kfReZpqPnriYNu/N54GrXj60E3cv7G1Q6Rc85/cOc1Bw4cOHDgwIED/988VT0oBw4cOHDgwIFvDA4LlAMHDhw4cODAE8dhgXLgwIEDBw4ceOI4LFAOHDhw4MCBA08chwXKgQMHDhw4cOCJ47BAOXDgwIEDBw48cRwWKAcOHDhw4MCBJ47DAuXAgQMHDhw48MRxWKAcOHDgwIEDB544DguUAwcOHDhw4MATx2GBcuDAgQMHDhx44jgsUA4cOHDgwIEDTxz/P1J8kN/NrWHCAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import json\n", "from torchvision.io import read_image\n", "\n", "\n", "preprocess = MaxVit_T_Weights.IMAGENET1K_V1.transforms()\n", "\n", "with open(\"imagenet_class_index.json\") as labels_file:\n", "    labels = json.load(labels_file)\n", "\n", "\n", "dog1 = read_image(\"dog1.jpg\")\n", "tensor = preprocess(dog1).unsqueeze(dim=0)\n", "\n", "# Run the PyTorch model in inference mode\n", "torch_model.eval()\n", "with torch.inference_mode():\n", "    torch_output = torch_model(tensor)\n", "\n", "torch_class_id = torch_output.argmax(dim=1).item()\n", "\n", "# Flax convolutions expect channels-last (NHWC) inputs, so permute from PyTorch's NCHW layout\n", "jax_array = jnp.asarray(tensor.permute(0, 2, 3, 1), device=jax.devices(\"cpu\")[0])\n", "flax_model.eval()\n", "flax_output = flax_model(jax_array)\n", "\n", "flax_class_id = flax_output.argmax(axis=1).item()\n", "\n", "print(\"Prediction for the Dog:\")\n", "print(f\"- PyTorch model result: {labels[str(torch_class_id)]}, score: {torch_output.softmax(dim=1)[0, torch_class_id]}\")\n", "print(f\"- Flax model result: {labels[str(flax_class_id)]}, score: {jax.nn.softmax(flax_output, axis=1)[0, flax_class_id]}\")\n", "\n", "\n", "# Show the image with each model's top-1 prediction side by side\n", "plt.subplot(121)\n", "plt.title(f\"{labels[str(torch_class_id)]}\\nScore: {torch_output.softmax(dim=1)[0, torch_class_id]:.4f}\")\n", "plt.imshow(dog1.permute(1, 2, 0))\n", "\n", "plt.subplot(122)\n", "plt.title(f\"{labels[str(flax_class_id)]}\\nScore: {jax.nn.softmax(flax_output, axis=1)[0, flax_class_id]:.4f}\")\n", "plt.imshow(dog1.permute(1, 2, 0))" ] }, { "cell_type": "markdown", "id": "c77f3244", "metadata": {}, "source": [ "Let's compute the cosine similarity between the PyTorch and Flax logits. A value close to 1 means the two models produce nearly identical outputs:" ] }, { "cell_type": "code", "execution_count": 45, "id": "36801241-11cc-4850-8ea2-f0a306eaba2a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array(0.99999857, dtype=float32)" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "expected = jnp.asarray(torch_output)\n", "\n", "cosine_sim = (expected * flax_output).sum() / 
(jnp.linalg.norm(flax_output) * jnp.linalg.norm(expected))\n", "cosine_sim" ] }, { "cell_type": "markdown", "id": "65e57aa6-1572-4805-9207-bc8a5f9f3ab1", "metadata": {}, "source": [ "## Further reading\n", "\n", "- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)\n", "- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/getting_started.html)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }