{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "cib2jXJDQ4Nz" }, "source": [ "# Adversarial training\n", "\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/optax/blob/main/examples/adversarial_training.ipynb)\n", "\n", "\n", "The following code trains a convolutional neural network (CNN) to be robust\n", "to adversarial examples generated with the projected gradient descent (PGD)\n", "method.\n", "\n", "Projected gradient descent (PGD) is a simple yet effective method to\n", "generate adversarial images. At each iteration, it adds a small perturbation\n", "in the direction of the sign of the gradient of the loss with respect to the\n", "input, followed by a projection onto the L-infinity ball of radius epsilon\n", "centered at the original image. The gradient sign makes each step increase\n", "the loss to first order, while the projection ensures the accumulated\n", "perturbation stays inside the L-infinity ball, i.e. no pixel changes by more\n", "than epsilon.\n", "\n", "## References\n", "\n", " Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. \"Explaining\n", " and harnessing adversarial examples.\", https://arxiv.org/abs/1412.6572\n", "\n", " Madry, Aleksander, et al. 
\"Towards deep learning models resistant to\n", " adversarial attacks.\", https://arxiv.org/abs/1706.06083" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "executionInfo": { "elapsed": 3773, "status": "ok", "timestamp": 1707150623004, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "diO66z2ZQ4N3", "outputId": "bb764d5c-9684-4a74-9906-5e37e8fa23a4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX running on GPU\n" ] } ], "source": [
"import datetime\n",
"\n",
"import jax\n",
"from jax import numpy as jnp\n",
"from flax import linen as nn\n",
"\n",
"import optax\n",
"from optax.losses import softmax_cross_entropy_with_integer_labels\n",
"from optax.tree_utils import tree_l2_norm\n",
"\n",
"from matplotlib import pyplot as plt\n",
"plt.rcParams.update({\"font.size\": 22})\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"\n",
"# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make\n",
"# it unavailable to JAX. (`tf.config.set_visible_devices` is the stable\n",
"# replacement for the deprecated `tf.config.experimental` alias.)\n",
"tf.config.set_visible_devices([], \"GPU\")\n",
"\n",
"# Show on which platform JAX is running.\n",
"print(\"JAX running on\", jax.devices()[0].platform.upper())" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "executionInfo": { "elapsed": 52, "status": "ok", "timestamp": 1707150623148, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "sWqYXjZFpUXe" }, "outputs": [], "source": [
"# @markdown Total number of epochs to train for:\n",
"EPOCHS = 10 # @param{type:\"integer\"}\n",
"# @markdown Number of samples for each batch in the training set:\n",
"TRAIN_BATCH_SIZE = 128 # @param{type:\"integer\"}\n",
"# @markdown Number of samples for each batch in the test set:\n",
"TEST_BATCH_SIZE = 128 # @param{type:\"integer\"}\n",
"# @markdown Learning rate for the optimizer:\n",
"LEARNING_RATE = 0.001 # @param{type:\"number\"}\n",
"# @markdown The dataset to use.\n",
"DATASET = \"mnist\" # 
@param{type:\"string\"}\n",
"# @markdown The amount of L2 regularization to use:\n",
"L2_REG = 0.0001 # @param{type:\"number\"}\n",
"# @markdown Adversarial perturbations lie within the infinity-ball of radius epsilon.\n",
"EPSILON = 0.01 # @param{type:\"number\"}" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "executionInfo": { "elapsed": 54, "status": "ok", "timestamp": 1707150623297, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "kwr2gSOZQ4N7" }, "outputs": [], "source": [
"class CNN(nn.Module):\n",
"  \"\"\"A small CNN classifier.\n",
"\n",
"  Two convolution -> ReLU -> 2x2 average-pooling stages, followed by a dense\n",
"  layer with 256 ReLU units and a final dense layer that returns\n",
"  `num_classes` unnormalized logits. Inputs are batch-first images (channels\n",
"  presumably last, following the Flax `nn.Conv` convention).\n",
"  \"\"\"\n",
"  num_classes: int\n",
"\n",
"  @nn.compact\n",
"  def __call__(self, x):\n",
"    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
"    x = nn.relu(x)\n",
"    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
"    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
"    x = nn.relu(x)\n",
"    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
"    x = x.reshape((x.shape[0], -1))  # flatten all but the batch axis\n",
"    x = nn.Dense(features=256)(x)\n",
"    x = nn.relu(x)\n",
"    x = nn.Dense(features=self.num_classes)(x)\n",
"    return x" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "executionInfo": { "elapsed": 6789, "status": "ok", "timestamp": 1707150630219, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "NUBxQ4c5LPwp" }, "outputs": [], "source": [
"# Load the dataset selected by the `DATASET` parameter in the config cell.\n",
"# (`mnist_info` holds its TFDS metadata; the name is kept for compatibility.)\n",
"(train_loader, test_loader), mnist_info = tfds.load(\n",
"    DATASET, split=[\"train\", \"test\"], as_supervised=True, with_info=True\n",
")\n",
"\n",
"train_loader_batched = train_loader.shuffle(\n",
"    10 * TRAIN_BATCH_SIZE, seed=0\n",
").batch(TRAIN_BATCH_SIZE, drop_remainder=True)\n",
"test_loader_batched = test_loader.batch(TEST_BATCH_SIZE, drop_remainder=True)\n",
"\n",
"input_shape = (1,) + mnist_info.features[\"image\"].shape\n",
"num_classes = mnist_info.features[\"label\"].num_classes\n",
"# With `drop_remainder=True` every batch is full, so these floor divisions are\n",
"# the number of batches per pass (used to average accuracies in `dataset_stats`).\n",
"iter_per_epoch_train = (\n",
"    mnist_info.splits[\"train\"].num_examples // TRAIN_BATCH_SIZE\n",
")\n",
"iter_per_epoch_test = mnist_info.splits[\"test\"].num_examples // TEST_BATCH_SIZE" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "executionInfo": { "elapsed": 53, "status": "ok", "timestamp": 1707150630371, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "trBo7PpcQ4N8" }, "outputs": [], "source": [
"net = CNN(num_classes)\n",
"\n",
"@jax.jit\n",
"def accuracy(params, data):\n",
"  \"\"\"Fraction of examples in `data = (inputs, labels)` classified correctly.\"\"\"\n",
"  inputs, labels = data\n",
"  logits = net.apply({\"params\": params}, inputs)\n",
"  return jnp.mean(jnp.argmax(logits, axis=-1) == labels)\n",
"\n",
"\n",
"@jax.jit\n",
"def loss_fun(params, l2reg, data):\n",
"  \"\"\"Mean softmax cross-entropy plus the penalty `0.5 * l2reg * ||params||^2`.\"\"\"\n",
"  inputs, labels = data\n",
"  x = inputs.astype(jnp.float32)\n",
"  logits = net.apply({\"params\": params}, x)\n",
"  sqnorm = tree_l2_norm(params, squared=True)\n",
"  loss_value = jnp.mean(softmax_cross_entropy_with_integer_labels(logits, labels))\n",
"  return loss_value + 0.5 * l2reg * sqnorm\n",
"\n",
"\n",
"def pgd_attack(image, label, params, epsilon=0.1, maxiter=10):\n",
"  \"\"\"PGD attack on the L-infinity ball with radius epsilon.\n",
"\n",
"  Args:\n",
"    image: array-like, batch of input images for the CNN, with pixel values\n",
"      in [0, 1].\n",
"    label: integer array, class labels corresponding to `image`.\n",
"    params: tree, parameters of the model to attack.\n",
"    epsilon: float, radius of the L-infinity ball.\n",
"    maxiter: int, number of iterations of this algorithm. It is a static\n",
"      argument under `jax.jit` (each distinct value triggers a recompile).\n",
"\n",
"  Returns:\n",
"    perturbed_image: Adversarial image inside the L-infinity ball of radius\n",
"      epsilon centered at `image`, with pixel values clipped to [0, 1].\n",
"\n",
"  Notes:\n",
"    PGD attack is described in (Madry et al. 2017),\n",
"    https://arxiv.org/pdf/1706.06083.pdf\n",
"  \"\"\"\n",
"  image_perturbation = jnp.zeros_like(image)\n",
"\n",
"  def adversarial_loss(perturbation):\n",
"    # l2reg=0: the weight penalty does not depend on the input image.\n",
"    return loss_fun(params, 0, (image + perturbation, label))\n",
"\n",
"  grad_adversarial = jax.grad(adversarial_loss)\n",
"  for _ in range(maxiter):\n",
"    # compute gradient of the loss wrt to the image\n",
"    sign_grad = jnp.sign(grad_adversarial(image_perturbation))\n",
"\n",
"    # heuristic step-size 2 eps / maxiter\n",
"    image_perturbation += (2 * epsilon / maxiter) * sign_grad\n",
"    # projection step onto the L-infinity ball centered at image\n",
"    image_perturbation = jnp.clip(image_perturbation, -epsilon, epsilon)\n",
"\n",
"  # clip the image to ensure pixels are between 0 and 1\n",
"  return jnp.clip(image + image_perturbation, 0, 1)\n",
"\n",
"\n",
"# `maxiter` sets the number of Python loop iterations, so it must be a static\n",
"# (compile-time) argument: under a plain `@jax.jit`, passing `maxiter=`\n",
"# explicitly would turn it into a tracer and `range(maxiter)` would fail.\n",
"pgd_attack = jax.jit(pgd_attack, static_argnames=(\"maxiter\",))" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "executionInfo": { "elapsed": 53, "status": "ok", "timestamp": 1707150630531, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "QgQcIgg2Z2XB" }, "outputs": [], "source": [
"def dataset_stats(params, data_loader, iter_per_epoch):\n",
"  \"\"\"Computes accuracy on clean and adversarial images.\n",
"\n",
"  Args:\n",
"    params: model parameters.\n",
"    data_loader: batched `tf.data.Dataset` of `(images, labels)` pairs, with\n",
"      pixel values assumed to be in [0, 255].\n",
"    iter_per_epoch: number of batches in `data_loader`; per-batch accuracies\n",
"      are divided by it to obtain a dataset-level average.\n",
"\n",
"  Returns:\n",
"    Dict with keys \"accuracy\" (clean images) and \"adversarial accuracy\"\n",
"    (images perturbed by `pgd_attack` with radius `EPSILON`).\n",
"  \"\"\"\n",
"  adversarial_accuracy = 0.\n",
"  clean_accuracy = 0.\n",
"  for images, labels in data_loader.as_numpy_iterator():\n",
"    images = images.astype(jnp.float32) / 255\n",
"    # `accuracy` already averages over the batch.\n",
"    clean_accuracy += accuracy(params, (images, labels)) / iter_per_epoch\n",
"    adversarial_images = pgd_attack(images, labels, params, epsilon=EPSILON)\n",
"    adversarial_accuracy += (\n",
"        accuracy(params, (adversarial_images, labels)) / iter_per_epoch\n",
"    )\n",
"  return {\"adversarial accuracy\": adversarial_accuracy, \"accuracy\": clean_accuracy}\n",
"\n",
"\n",
"@jax.jit\n",
"def train_step(params, opt_state, batch):\n",
"  \"\"\"Performs one optimizer step on PGD-perturbed versions of `batch`.\n",
"\n",
"  Reads the module-level `optimizer` (created in the training cell below) as\n",
"  well as `EPSILON` and `L2_REG`; under `jax.jit` these are captured when the\n",
"  function is first traced.\n",
"  \"\"\"\n",
"  images, labels = batch\n",
"  # Convert images to float in [0, 1]; the attack differentiates wrt them.\n",
"  images = images.astype(jnp.float32) / 255\n",
"  adversarial_images_train = pgd_attack(images, labels, params, epsilon=EPSILON)\n",
"  # train on adversarial images\n",
"  loss_grad_fun = jax.grad(loss_fun)\n",
"  grads = loss_grad_fun(params, L2_REG, (adversarial_images_train, labels))\n",
"  updates, opt_state = optimizer.update(grads, opt_state)\n",
"  params = optax.apply_updates(params, updates)\n",
"  return params, opt_state" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "executionInfo": { "elapsed": 62308, "status": "ok", "timestamp": 1707150692933, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "nwt6HuxhCd1a", "outputId": "71338c4a-356d-4749-e5ad-d186ccf5d82e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"Epoch 0 out of 10\n", "Accuracy on train set: 0.982\n", "Accuracy on test set: 0.982\n", "Adversarial accuracy on train set: 0.979\n", "Adversarial accuracy on test set: 0.977\n", "Time elapsed: 0:00:10\n", "\n",
"Epoch 1 out of 10\n", "Accuracy on train set: 0.989\n", "Accuracy on test set: 0.987\n", "Adversarial accuracy on train set: 0.986\n", "Adversarial accuracy on test set: 0.984\n", "Time elapsed: 0:00:15\n", "\n",
"Epoch 2 out of 10\n", "Accuracy on train set: 0.991\n", "Accuracy on test set: 0.988\n", "Adversarial accuracy on train set: 0.988\n", "Adversarial accuracy on test set: 0.986\n", "Time elapsed: 0:00:21\n", "\n",
"Epoch 3 out of 10\n", "Accuracy on train set: 0.992\n", "Accuracy on test set: 0.989\n", "Adversarial accuracy on train set: 0.990\n", "Adversarial accuracy on test set: 0.986\n", "Time elapsed: 0:00:26\n", "\n",
"Epoch 4 out of 10\n", "Accuracy on train set: 0.992\n", "Accuracy on test set: 0.988\n", "Adversarial accuracy on train set: 0.990\n", "Adversarial accuracy on test set: 0.985\n", "Time elapsed: 0:00:32\n", "\n",
"Epoch 5 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.991\n", "Adversarial accuracy on train set: 0.994\n", "Adversarial accuracy on test set: 0.989\n", "Time elapsed: 0:00:37\n", "\n", "Epoch 6 
out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.993\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:00:43\n", "\n",
"Epoch 7 out of 10\n", "Accuracy on train set: 0.996\n", "Accuracy on test set: 0.992\n", "Adversarial accuracy on train set: 0.995\n", "Adversarial accuracy on test set: 0.990\n", "Time elapsed: 0:00:48\n", "\n",
"Epoch 8 out of 10\n", "Accuracy on train set: 0.994\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.992\n", "Adversarial accuracy on test set: 0.987\n", "Time elapsed: 0:00:54\n", "\n",
"Epoch 9 out of 10\n", "Accuracy on train set: 0.997\n", "Accuracy on test set: 0.992\n", "Adversarial accuracy on train set: 0.995\n", "Adversarial accuracy on test set: 0.991\n", "Time elapsed: 0:00:59\n", "\n" ] } ], "source": [
"# Model parameters, initialized from a fixed PRNG seed.\n",
"key = jax.random.PRNGKey(0)\n",
"var_params = net.init(key, jnp.zeros(input_shape))[\"params\"]\n",
"\n",
"# Adam optimizer; `train_step` (defined above) reads this global `optimizer`.\n",
"optimizer = optax.adam(LEARNING_RATE)\n",
"opt_state = optimizer.init(var_params)\n",
"\n",
"start = datetime.datetime.now().replace(microsecond=0)\n",
"\n",
"# Per-epoch histories of clean and adversarial accuracy.\n",
"accuracy_train, accuracy_test = [], []\n",
"adversarial_accuracy_train, adversarial_accuracy_test = [], []\n",
"for epoch in range(EPOCHS):\n",
"  # One pass over the shuffled training set; each step trains on PGD images.\n",
"  for train_batch in train_loader_batched.as_numpy_iterator():\n",
"    var_params, opt_state = train_step(var_params, opt_state, train_batch)\n",
"\n",
"  # Training-set accuracy on clean and on adversarial images.\n",
"  train_stats = dataset_stats(\n",
"      var_params, train_loader_batched, iter_per_epoch_train\n",
"  )\n",
"  accuracy_train.append(train_stats[\"accuracy\"])\n",
"  adversarial_accuracy_train.append(train_stats[\"adversarial accuracy\"])\n",
"\n",
"  # Test-set accuracy on clean and on adversarial images.\n",
"  test_stats = dataset_stats(\n",
"      var_params, test_loader_batched, iter_per_epoch_test\n",
"  )\n", " 
accuracy_test.append(test_stats[\"accuracy\"])\n", " adversarial_accuracy_test.append(test_stats[\"adversarial accuracy\"])\n", "\n", " time_elapsed = (datetime.datetime.now().replace(microsecond=0) - start)\n", " print(f\"Epoch {epoch} out of {EPOCHS}\")\n", " print(f\"Accuracy on train set: {accuracy_train[-1]:.3f}\")\n", " print(f\"Accuracy on test set: {accuracy_test[-1]:.3f}\")\n", " print(f\"Adversarial accuracy on train set: {adversarial_accuracy_train[-1]:.3f}\")\n", " print(f\"Adversarial accuracy on test set: {adversarial_accuracy_test[-1]:.3f}\")\n", " print(f\"Time elapsed: {time_elapsed}\\n\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "height": 483 }, "executionInfo": { "elapsed": 419, "status": "ok", "timestamp": 1707150693452, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "vQnv1S84Q4N-", "outputId": "bd622b38-d865-4fcd-e746-2621eff0b57c" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAB/EAAAOkCAYAAABeW8dTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAABYl\nAAAWJQFJUiTwAAEAAElEQVR4nOzdd5xU5fXH8c/ZXTpLbyIoiFQLCIpdsfcSiTWJoumxJDHGX0xi\nNNYYo0YxiSnWxB4LltgRe0WxAoJ0EZDey+6e3x93Roa7d3anz+zu9/16zWt3nnvvc8/OzM7u3PM8\n5zF3R0RERERERERERERERERERIqvrNgBiIiIiIiIiIiIiIiIiIiISEBJfBERERERERERERERERER\nkRKhJL6IiIiIiIiIiIiIiIiIiEiJUBJfRERERERERERERERERESkRCiJLyIiIiIiIiIiIiIiIiIi\nUiKUxBcRERERERERERERERERESkRSuKLiIiIiIiIiIiIiIiIiIiUCCXxRURERERERERERERERERE\nSoSS+CIiIiIiIiIiIiIiIiIiIiVCSXwREREREREREREREREREZESoSS+iIiIiIiIiIiIiIiIiIhI\niVASX0REREREREREREREREREpEQoiS8iIiIiIiIiIiIiIiIiIlIilMQXEREREWnkzGyWmXnCbVax\nYxLJhpldGnpNu5mNKnZcqWjIsTclZjYm4nkaU+y4UmVmoyLiv7TYcYmIiIiIiEhqKoodgIiIiIiI\niIiIiIg0DmbWCRgB9AA6AO2ATcAaYAUwG5gFfOHuNcWJUkRERKS0KYkvIiIiIpIGM+sBzAPKIzbv\n4u6TChuRiIiISOMWqyRxSR27/NDd/5Gjc40CXqxjlzvdfUwK/fQBZibZvBEY6O6z0ouu1jk83Obu\nluKxE4D9Q82/d/dLM4xle+CHwHFA/xQPW21mE4F3gJeA8e6+Nkn/s4BtM4ktR1J63kVERERyReX0\nRURERETS8x2iE/gAYwoYh4iI
iIgEzijRvpJpDvy+AOfJOzPrbmb/BT4DLiD1BD5AW4KBBBcAjwNL\nzezB3EcpIiIi0vAoiS8iIiIikp4xdWw7zcyaFSoQEZF0mFlvM1scuo0tdlwiIjmwl5mlkzyOZGZt\ngG/mIJ5UfNvMhhToXHlhZkcAnwKjgZQqANSjBbBvDvoRERERafBUTl9EREREJEVmNhKo62JrV+Ao\n4NGCBCQikp5yoHOorbIYgYiI5MHpwMVZ9jGaYHZ4IZQBVwAnFOh8OWVmhwKPECTeo2wEpgDTgJVA\nFdCJ4O/QjkCXAoQpIiIi0mApiS8iIiIikrozU9zn0TzHIdKkxdbrvbTIYUgj5u53AHcUOYyMufsE\ncjMrVqSU1bBlldHvmNnv3L3WOvFpCJfSD58j175hZiPd/e08niPnzKw38DDRCfy3gOuBx919XR19\n9AX2Iah8cFiSvhLtQvIlrZLZm+j/y7um2Q/A+gyOEREREcmYkvgiIiIiIikws5bAyaHmjcAaoGNC\n25Fm1s3dFxUsOBEREZGm53ng0IT72xKsrz4hk87MbBvggFDzcwQJ5ny6Cjg4z+fItT8BbUJtDvwC\n+HMqAyncfSYwE/i3mbUDvgX8FOiQZP9l6QZpZiuS9LU43b5ERERECi2fI0lFRERERBqT49kyWQ/w\nBHBvqK2C4CKkiIiIiOTPG8BnobbwTPp0nM6WFSxWkPvqSsuBtaG2g8zswByfJ29igx1Oitj0W3e/\nIZNKCO6+0t3/BgxO0reIiIhIk6MkvoiIiIhIasZEtP07dgtLpey+iIiIiGTnrtD9b5pZeIZ4qk4P\n3X+A3JdQXwHcHNF+VY7Pk0/HRbTNAf6YbcceeDnbfkREREQaAyXxRURERETqYWZbA4eEmpcC/3P3\nN4FpoW07mdnwggQnIiIi0nT9m6CMe1xb4IR0OzGzvYH+oeY7s4irLtcQJPMT7W5mUcnxUrRfRNsj\n7l5V8EhEREREGrGKYgcgIiIiItIAnE7tAbD3ufvG2Pf/AX4f2j4GeC/XgZhZM2B3YAegM1BNMKBg\nMvCOu2/I9TlLkZkZMJDgcehFcNF+HbAQ+BJ4093D5WrzFUsFMBwYBHQDWgFrgBnu/mgKx5cB2wM7\nA12A9gSf1dYBq4EvgNnAtFw9v2bWg6Bkbd/Y+doAKwleSwsIXkuR68jmi5n1BHaJxdSOICmzFHja\n3Wfn+dxlsfMOBnrGzt+coOzxUoLHf2LC77xkwMz6ErzOtwUqCd6/FgOPprI+cez3vjfB79o2BM9T\nK4Jk2FKC35V3CvW7n63YzzOM4DHpBpQDXxH8HK+6++riRZcaMxtC8Hu7FcHvzGKC9+DX3H1pns7Z\nDdiN4He1K8FM6S+ASe4+NR/nzBcz6wqMYPN7cTmwCpgLfOjunxcwlgpgV4K/q10I3oMXEcywfq1U\n/79w9zlm9iKQWI7+dKIrJdUlXIZ/uru/ZmbhxH7W3H2pmf0JuDy06Qoze9zda3J9zhzbKqKtYK9V\nERERkaZCSXwRERERkfqNiWhLvDj8H+BStlxH9TQzuyBXSb9YNYBfA98mSFxFWW1mDwBXu/v0DM9z\nG7WXA/ihu/8jk/4S+i0H5gE9Epod6OPuc9LoZ2fgJwSz7LrWset6M3sJ+Ie7P5xByJhZeE3Xl9x9\nVML2HYBfAKOJfk5mU8dauma2B/D92PHtUwhpo5m9D7wI3O/uk1I4Jn6ubsDxBEmOUUD3eg6pMbMP\nCF7bf3f3NameK3TeS4FLQs0HuPuE2PZmwFnAjwiSmVHOBO5Ip98UYxsMHAscAOxDMJChLuvM7A3g\nb8DDDSDJAoCZ3UHda0SfYWaprCF9prvfkWL/fd19Vmx7W+DHwPeAAUn6ng5MiNpgZtsSvHYPAPYH\nOtQT5yYzexe4FbjL3TfVs38kMxsD3B5qjnwMIo6dRTBQIW62u/dJ2N4ROJ/gMelBtI1mNh642N
3f\nTTnwzecYRfBekej37n5pCsdOIHisv+bulrC9FXA2wXtx3yTd1JjZ67FzPp9q3PXEdULsvPsTJLuj\n9vkc+AvwN3dfH2sbRYaPRT6YWWuC5/4MggFgde07G7gHGOvuX2Z4vvr+lm0NXAh8B+iYpJu1ZvYE\nwZrn4epDpeBOtkziH2hmvd19bioHx17T4XXYw2X6c+0G4FyCATxxOwLfIv0BCIUW9f9XRu+1IiIi\nIpKcyumLiIiIiNTBzPaiduJpeqyMPgDuPgN4LbRPZ+CYHMUwBphKkDBJlsCHYDb6WcBHZnZuhqeL\numj9nQz7SnQItZNVL6WawDez7mZ2L/AB8EPqTuADtAQOAx4ys1fMbMd0A64jljIzuxKYRJBgrus5\niTq+o5n9G3iD4PlKJYEPwSzX3YFfAe+b2fEpnKuPmT0LzAf+DpxM/Ql8CD4r7gJcB8w2s/A6wVkz\ns50IHsNbSJ7AzzkzO8zMJgGfAn8geJ2ksn5yK4Ik0YPAx2a2a96CbCTM7ACCx/mPJE/gJzt2hJm9\nBswE/kywBnOHFA5tBuwJ/Av43MwOT+e8+WZmhxE8Jr8leQIfgt/3w4G3zSzrdaZzJfa6/xC4luQJ\nfAjeQ/YBnjOz/8QG7GR6zq3N7HHgIYLfwcgEfkw/4HqC98idMj1nvsQGIkwHbqSeBH7MtsBFwAwz\nuyg2IC6X8XyH4PV4HskT+ACtCZLcn5jZebmMIUceIqhcE1dGMOgxVcez5d9iJ89J/NjguCsjNv0+\nm9+XAlkf0VbX+4GIiIiIZEBJfBERERGRuoVnpUP0DKmotjHZnjw24/h2UksyxrUEbjKzyzI45UsE\nM8gT7W1m22XQV6KogQApXSCPrVP7AXBKhufeB3g1NhszK7Gy6/cTVEVIu7JZrHTyBNJLLiSTyvn7\nEAygyCbx0xm4M1b6Nydiz+nrwJBc9ZmGPYGhWfYxGHjZzMIzNyXGzEYDzxCUv8/ETsBebFnhJF29\ngSfM7GdZ9JEzsYoHT1J38r7WYcAvzewv+YkqjUDMDiX4G7F9mod+C3g49v6Z7jl7Ay8DR6d56CDg\nFTMbke4588XM/kCQbI4qRV6flsBVBK/ndP4fqCue3xL8HU5nIFoz4EYz+2UuYsiVWEL8v6HmVCqM\nJNt3Qr6XcYm5hWC5gkR9Car0lLKo5U+Oy+R3XERERESSUzl9EREREZEkkpRXhaDEeNgDwE1Ai4S2\nw82sh7svyPD8P6J2yfC4jQQzuecSrCvdiyA52Tphn4vN7JN0zunubmb/AX6TGApB0jmTQQHxctrH\nh5rXUvuCe9Sx+wFPseXPlWgG8BGwhCBRvRXB41AZ2q898LSZHeLur6QcfG1/AL4ZalsKvE2wdrAD\nW5N8ZvkdBOtfR5kHTCZYT3otwczvdgQzMYcQJHFyZTnB4/ZV7PsaglnOAwjWQ45K+v/CzOa4+01Z\nnrs3wczqtgltDrxPMOt6KcHAgT4E1QAKZT7wCbCM4DFpDnQiSCZHzTBsBfzHzGZkUu68kRsO3E2Q\n8IurBt4hSFitIKgI0Y/g9ZaOr4CPCX7nlxNMTugQ62cAtZP+5cANsefpsTTPlTOxigC3suXv1krg\nLWABwXt6T2BvopOqPzGzp9z9iXzHGsXMhgEPs+V78TqC+OcTvGd1Ixh40SWii6MJSuGPTeOcnQnK\n4CcbRDabYIDXIoLHrA/Bmu7xRGJ7YBzw01TPmS+x6i3/l2TzJoLHcQ6wgeBvyO5EV2k5HHjMzA51\n9+os4vk+tddjX0zwOxr/W9ab4PlsFdHFVWb2vLu/n2kMeXAnWw6eHGhmu7v7W3UdZGY9gYMj+so7\nd98YG6x5W2jTxWZ2h7uvLUQcGXiPLZcvABhIUCnoqsKHIyIiItI4KYkvIiIiIpJc1Frnr8XK52/B\n3ZfH1osdndBcQZD8TnsGs5nFSwKHbQAuBW5x9+WhY9oApx
MkmuNx30z6FbjuYsskPgQz6TNK4hM8\nJuEk/KPuvqqug8ysB0Hp8vCxawnWJR8bNVMuVob2mwQXkvskbGpBkHTd2d1XpPUTBAYD+ybcf5/g\ngvUL4WSKmTUnWHs+se1Q4MiIfu8GrnH3j5KdOFZCeRjBGu4nxmJJRzXBDNqHgKfcfWYd52pP8Hz/\nltql9/9oZuPd/eM0z79FHwTJcQhezzcAN0YNdjGz7cmuikBdVhPMin4YGO/uUTML43H0JVjX/acE\nyf24ZsB9ZrZjfP3tEnQOcEHs+94EyZdE9xGsy1yfOn9fQ/7C5gFNK4ErgH+5+7LwjrHk8Fd19LUB\neI7gtfusu89PtqOZdQe+S/B7GR7Ic7uZDXH3han+EDnUjqBaS/y1/CnBe+yT7r7FOtKx968fEbyP\nh9/7xprZ/9y9Js/xRrmfzRVh5gC/Ax5w93WJO8Vm4p5KMFAnnMy/yszudPeVKZ7zzwQDPcJeBy5w\n9zfCG8xsK+AXwM8IHu+tgatTPF9emNlBBCXxwzYQ/G78zd2XhI5pQfC383pqvw8fSPD6yfRvcl+C\nQYdxb8T6eyn82or9X/FLgr8Hie/FFQRLAuyXYQz5EK8ktG1C2+kEAyTq8h22/NnWELzfFMpdwIUE\n1SPiehC8L19TwDjS8Syb/64kujL2/+sl7j6vwDGJiIiINDoqcyQiIiIiktyYiLaosvlxUeXho/pI\nxd+pPfttGbC7u/8hnMCHoJysu/+NoFT43FhzFzYnTFPi7p9R+6L39ma2Zzr9JIhaT73OUvpmZgSP\ndbfQpsnAru5+QbJSt+6+yd3vJZgN/HJo8zZkMKgiphubP0PdGovj2ajZkO6+0d2fDTWfFtHnr939\n23Ul8GP9Vbv7RHe/xN2HEMzGnJZCzBuAfwKD3P0gd/9rXQn82LlWuPvNBDObw49fC4KETjbipcSX\nAHu5+0XJqlW4+3R3n5rl+cK+IkhI9Xb3U9z9gboS+LE4Zrr7hQSzfL8Mbe5HkLQsSe6+2t0Xx37G\nWkl0YEN8ez23DWmcNv4czwCGuvu1UQn8WHyT3P2LiE2riK277u7HuPsddSXwY30tdPergB0JEuWJ\nOgE/SeNnyKWObE5o3wvs4u6PhhP48PX711jgKKAqtLkPwfIYxTAg9vVFYGd3vzOcwAdw9xp3v5tg\nEFM4Wd+WFH9XYoOeopYd+SuwT1QCP3b+L939AuBQgkoBAP1TOWc+mFk7ggos4QoRS4CR7n5FOIEP\n4O4b3P0egvfhqNnuF2exVMA2bK7sch2wt7u/GDU4JPZ/xaVE/y+zr5kNzDCGnHP3qHXsT4kNqqtL\nuJT+Q+6+OneR1S32P8TFEZv+LzaorhQ9T1C5JspZwEwze9zMvm9mRfv9ExEREWnolMQXEREREYlg\nZttQu1ToRoKy+ck8Re11Qncws13TPPd+wEGh5hrgWHf/oL7j3X0WcBjBjPVMRQ1WiFrXvk5m1ovQ\njHSCJOjz9Rw6mtrlbWcDB7n75FTOHUsaHkuQSEx0emzGbqaeAL6fwWzYvUP355DhLDt3fybF18Ib\n7v4Dd5+ewTmWEJTA/jy06WQz65pufyHVwFHuHp4Vnnfu/hd3vzJqIEwKx34EHEHt5GoqM9mbmtXA\nwbH3o7S5+0PufqG7hwdNpHLsHIJk9/LQph+mkNDLp+eBb7n7xvp2dPcJBBUNwoo5YOQTgt/beiuZ\nuPsnRC8Hk2r8v4poeww4J5asre/848ngb1Ye/IBguZtEmwgexw/rOzj2Pnw4tddNryA68ZuOW2MD\n4lJ5PP9DsDRBWKkNYAon8TsBxyTb2cxGUruyTUFK6Yc8BEwMtXUk+0FzeRF7zfyC4H/TKBUE/z/8\nA/jMzBab2f/M7PdmdpSZpTW4VERERKSpUhJfRERERCTaGdSeOfdEstmkEMygJCg5HHZmmuf+fkTb\nv9z91VQ7iCW6/5DmeR
PdSzBoIdHJGSTAvkXtzx13R81eD/lFRNuZ6Sb0Ysmms0PNzQlKo2diLfCj\nVJIeEXqE7r+dwUCAgvJgyYPw0gotCGa5ZuMvXs86xaUqNnjiX6HmXcxs62LEU8J+V1/Vh3yKzdoP\nvwd2B3YrQjgA64Ez0nzviFo/vljxA5weNfu+Dv+i9mCyEbGS+0mZ2QBqD/5aS4oJ/Dh3fwh4PNX9\ncy32c0b9rbkxnfc/d18E/Dxi0zFmtm1Eeyq+JFgeJB03RbQV8/VYS2zA2muh5vBM+7q2zSGoNlFQ\nsdd1+G8twM+yHHSYN+7+DKkPMuhMMADudwQDIReb2SQzu87MdslXjCIiIiINnZL4IiIiIiLRoi76\n1lVKv659To2tb1svM2tLsJ57ohoyW/v2T6S3jvXX3H0p8L9QcyeCEs/piJoJWV8p/d2BPULNz7l7\nRhfW3f1papd9TTozrx4PJin9nYqK0P1mGfZTaI8TJCAT7Z5Ff06w1nVDFlWRI5vHpLFZTbCMQ7GV\n0vP0n/qWAwhz98+BKaHmgWYWXmqlEJ5Pt3JGrCR5eEmOtsD29Rx6CrUH0d3j7nOjdq7HHzM4Jlf2\nB7YLta0Drki3I3d/GHg31FxG5kv2/MXd16R5zASC9eITlWICNjyT/oio6jGx/8tOCTX/O8NBelmL\nJcVfCjW3ITq5XxLc/XrgJNL/X9MIln46H3gvltAfnev4RERERBo6JfFFREREREJi5ez7hZqjktq1\nxGbXfRZq7khQ1j0Vu7F5rdq4lzJJHMdmTD6c7nEJopLtKZcnjs2u2iHUPMnrWf+doHRw2H9SPW8S\n4fXph5pZmwz6uSeLGBaF7h9gZp2z6K8g3H0ttZckSGuJiJA3izlDO0c+jmjL5jFpbB4v5JrSycRe\nZ+GZ4MV6njJ9Lw6XXC8nqChQaLmKH6BnPceEB3EB3J3h+V8DZmV4bLb2jWh7LJXlCJKIKvMedY5U\npP18xirHhAfE9aivskIRPMCWA88qgNMi9juGYHBiojoHGRbARRFtP8yi4kLeufuDBANzbgY2ZNjN\nUOC/ZvZybDkrEREREUFJfBERERGRKFHl7+9PZR3jmKjZ+KmW1B8Z0fZUisdGyebYJ4ElobZ01jI9\nPaItlQvkUUmJ51I8ZzLvh+6Xk1ky750sYngzdL8d8LiZ9cmiz5wws2Zm1tHMukTdgJWhQ7plcbq3\nszi2YMyspZl1TvJ4hGcJQ3aPSWNTsOfYzJqbWac6XrvhJVCK9Ty9keFxCyLa2mcTSIYKGX/472BV\npuePzaoOl1cvlKjBCE9k0d9jEW0jM0iiL6d2hYdUhZ9PAyoz7CsvYoMkHg01R1VXCre94e7hQZgF\n5e5vUPs10hy4tPDRpM7dF7n7uQQDdM4GXgE2ZdDVvsDbZhb1v7CIiIhIk6MkvoiIiIhIgtjs7HA5\ne0itlH7cfwhKhic61My2SuHYYRFtH6Rx7pwdGxu0cH+ouTlwcn3Hmlk5cGqouYp6ZrLHkhHhxMcK\nd/+yvnPWIzwYASCV5yPRPHcPJwTTETWAYU9gqpnda2bHmVnrLPqvVyxZf3RsHdpnzewLM1sHbCSo\nNvFVklv4OemQRRj1VWIoKDPrYGZjzOwWM3vdzBab2UaCsteLiX48FkZ01aFQMTcAOX+Ozay1mZ1k\nZjeZ2QQzW2hmGwhmfi4h+Wt361BXHXIdWwrWuPvyDI+NKlPdLotYMjUvw+PSij9WnSRcoWSKu2c6\nwxdgUhbHZiOq1HxaSxIkcvc51P5bVkntykH1+SKLkvGl8nqsT7hqwS5mtlP8Tmyd+XDVn6hKB8Xw\nG2r/D/kdMxtcjGDS4e5L3f2v7r4fQRWqw4HLCSpZRQ3oidIdeMLMeucpTBEREZEGQ0l8EREREZEt\nnUiwZm+iz2Ozo1Li7rOAV0PN5aRWij5qlmg2M8OmUfticDqiBi+k8nMcSu2Sz8+6e1Ty
M1FHgjVg\nE7U3M8/mRlBVICzVigJxS9Pcfwvu/iS1y/pDMDDiFIKZg8vM7FUz+0Msqd8xm3PGmVmlmV1DcBH9\ncYJ1aA8hmDUXXr4hFdnMBM7qccwVM9vOzO4neExuB35IMKiiM9Asgy6LMTu6VOXsOTazbmb2d4KB\nE/cD5xKsNd6N4HcnXcV4npZncWxVRFt5Fv1lanmGx6Ubf9R7XtrLyeT4+EyFByM4MDXLPqNm0HdJ\ns4/lWZy/VF6P9XkOCA/+S5x5/y2CMvtxG6g9aLEo3P1D4N5QczlBMrzBcPc17v6Mu//O3Y9y962A\nbQiqNN1L7aVOEnUlvcGzIiIiIo2SkvgiIiIiIlsaE9GWyYXEqGOi+g7rENGW6fq5uHs1kPHa1O7+\nJrUHEexpZtvXc2hUoj+VUvrpJtazkW6CPFxSPhOnAq/Xsb05sDfwfwRJ/cVm9p6ZXWxm22VyQjPb\nnyBxdCG5e3wr6t8lqVw8jlkxs3MJ1nY+CWiRo24zSfw3Vjl5js3sRIL3nx9Qe3BVporxPEUlPhsU\ndy/Uz9Ahoi3b11PB33NiVX3Cr7XVsb/J2YiqBpPu37IG/3qsT+xx/k+o+VuxKkFQu5T+uCyqZeTD\n76j9PI02s0yWASoZ7j7X3f/t7qcBPYCLgTVJdt/fzA4oXHQiIiIipUdJfBERERGRGDPrC+wXsenp\nZOst17EO83hqrwc62Mx2ryeMqLK0UeVr05FtAiOt2fhmVgkcH2peAYxL4VyFTOKnm8zLNvmCuy8F\nRgG/JbXnpYygJPNlwOdm9rCZ7ZDq+WIXwP9H+ksH5FPWj2M2zOz/gJvIrAKBpCbr59jMTiOYrakK\nB01L1KCajVn2mU0p/kxFvW5zMZggqo8OOei3MQqXx+9BsLTRLsDO9exbVO7+OXBrxKYrCx1Lvrj7\nKne/AhhJ8moZPyhgSCIiIiIlR0l8EREREZHNxgAW0f4myddbTnabTnSSeEw9MUQlKzIpGZ0o25nG\n/6Z2Sf5v17H/N4FWobYH3X19lnE0Cu6+yd2vJCgr+2PgZVJPUn0DmGhmP6pvx9hgiruB1hGb5wC3\nEMxG3BfoSzCAojVQ5u6WeANeSjG+kmZmI4Grk2x+D7iG4PU7kmAt9Q5Ay/DjEXtMJE/MbFvgn0SX\n6Z4C3EhQ1WJPgt+jjkCrJM/T7ELFLTkRlaSuzLLPYqzZHvWenosqEFF9FGOQQslz90+AiaHmM6g9\nC38B8ExBgkrPZUD4/6ZDzWxU4UPJH3f/FDiB6KWfNBNfREREmrRsSiCKiIiIiDQaZmbUvrCbD6eY\n2c/rSGgvj2irpO61Q+uTVQLD3Web2csE61DHbWdm+7j7qxGHZFpKH2BJRNtbwNEpHp+ObB7TrLn7\nCoJE+i2x0st7xW77EiQno5LvEAzK+JuZrXP3umYP/oraM/DXAOcA/06zrHNjmbV+A7UH6nwOnO7u\ndS1zsAUzayyPR6m6itqv/6+As9z9iTT70nPVsESVi++QZZ/FqOawPKItF3FE9RH1mEngTmBEwv3j\nqL3E0N05WOYg59x9vpndDFwQ2nQVwf8KjYa7v21mTwDHhDZ1N7Ot3P3LYsQlIiIiUmyaiS8iIiIi\nEjgA2LYA5+lA7VLziZZHtPXI9GRm1pHsZ/JDdBK+VrLezHqzZbIfYCYQleyPsjiibWt3X5yHW1GT\n+IncfY27P+fuv3f3gwleJwcBY0meoLkx9vwmc3JE24nufkcGCYvOae5fcmKvzXDiYymwfzoJ/JgG\n/3iUqtgAieNCzRuAQzJI4ENhl+iQ7C0GakJtg7Psc0iWx6fN3auonSxuYWbZJvKj/h9QEj+5e9hy\naaOWQJfQPiVVSj/kD9SuTrGnmYWT3Y3B00nauxY0ChEREZESoiS+iIiIiEjgzAKea0wd26ZHtIXX\nbk1HNscm+i+wLtR2kpmFS/V/i9qfM/7t7lFlUqOs
oPYF663NrEOKxzcKsZL74939PKAXwYX8sPbA\n6VHHm9n2QL9Q88vu/lS6sZhZM4LS8g3dYRFtN7t7srV469Iny1gkuX2ANqG2+939g3Q7MrOtyU0J\ncymQWJWaT0LNXc2sVxbdDs/i2GzMiWgbmmlnZtac6AENczPts7Fz9yXAk3Xs8r67f1SoeNIVi/+6\niE1Xmllju6abbOmTtgWNQkRERKSENLZ/+ERERERE0hZbO/yEiE17R62xnM6NYNZceNbzIbHkUpR3\nItp2z+LHy+bYr7n7SmBcqLkDtUufRpXS/3ca53Fqr79uRCdgmwR3X+vuFwE3RWw+MMlh20S0PZ9h\nCLsArTI8tpTk8jHZO5tApE56nuStiLbjM+nIzDoTLFFSDG9GtI3Mor9dqD0oZbq7R1Wwkc3qmmlf\nyrPw464nWE4k0U7AKUWIJZ/CS93E6fUtIiIiTZaS+CIiIiIiQdnx8PrLs4E3su3Y3RcC40PNZSSZ\nQU108mJ0FjOuokqqZ6rOkvpmNoLaZYtfd/eo6gJ1CT9eAKem2Udj9JeItqiEJ0C3iLaFGZ73+AyP\nKzV6TKAqoq3UrgvoeZKoQRtnmVmyJF9dvk3xqjFEJfGzSbxG/R2MOods6UmiE8GbCMrtlzR3Xw1c\nFbHpslilnMaid5L2TN//RURERBq8UvuwLiIiIiJSDGMi2u5NowR8faIuEp8RtaO7T6N2KeEewLHp\nntTMdiW3ZYSfBRaE2o4ws/j6slGz8KMS//V5lNrJxmPNrFglkUvFrIi2ZDPkN0W0VaZ7QjNrA/wg\n3eNKVK4ekz2APbMPpyhWRbSVWpWFXD1PvYETsw9HiuBRYEmobReSD36LZGadgItzFFMmnqF2JZ4R\nsQFvaYlVDPpWxKb/ZRJYU+Lum4BzgCtDt3PdPTzDvVT9jdrLJvQDvluEWPLlqIi2Oe6+ouCRiIiI\niJQIJfFFREREpEkzs/5El1y+N4eneRhYH2obaGbJEoH/imj7Y2w93HTckOb+dXL3amoPSGgGnGJm\nFdSeJbgBuD+D88wC7g41G/AvMwtXTGhKeka0hQdVxC2KaNsng3P+CeicwXGlKOvHJPY7eEtuwimK\n1UB4cNJWxQikDrl67f4dqMgyFikCd98A3B6x6QYz2ymVPmK/q/+hiO9f7j4HeCJi040ZdHcJ0CXU\ntgB4KIO+mhx3v9/dfxu6/b3YcaUq9jvx+4hNRRmkYmY/MbOBOexvN+CIiE1Rvz8iIiIiTYaS+CIi\nIiLS1I2JaPvE3T/M1Qli68lHzZY7M8khdwErQ239gX+kWk7YzK4ms8RXfZKV1D+U2mWwH3f35Rme\n5ypqz8bfBbg/Njs8Y2a2f2w2dcGYWd/YRe9sZj1HzYj/KMm+k6g9o/kYM9sx1ZOZ2feAH6W6fwPw\nTkTbT1N9TsysHPgHMDSnURVQbCDOzFDzjrFBOKUi6nk608y6p9qBmV1OdEJIGo6rqD1IqSPwnJkd\nVteBZtYTGMfm10BN7sNL2diItr3NLKo8eiQz+wbws4hNt7j7xkwDkwbnDuCzUFvU4L5COBb4xMzu\niA2GzZiZ9QHuI/oadS4H1IqIiIg0OErii4iIiEiTFVtnPqo8bz7WSI26EHlSVALR3ZcCv47Y/wzg\nLjNrn+wkZtbSzG4AfpXQnLMEhrt/AIQHOIwkmCUYlkkp/fh5PgMuiNh0NPCOme2bTn9m1tHMxpjZ\nW8AEIOVkdo60J1jTfraZXW5mO6RzsJn9GLgwYtN9Ufu7+yrglVBzOfCYmQ2o51wtzewK4J+JXaYR\nbql6iWAmeqK+wENm1q6uA81sK+ARNi+D0ZAfjw9C99sDJxcjkCjuPgWYEWpuDzwZex6SMrMOZvYv\n4LeJXeY4RCkAd18GnB2xqTvwtJk9HXtPH2ZmPc1soJkdbmZ/BSYDhyccE5VILwh3f4HoijQXmdl1\nZtayruPN7LsE
/z+UhzZ9BlybmyilIYgNwirm8hBh5QR/E6ea2Ytmdno6gyzNrMzMxgBvAdtF7PKg\nu7+am1BFREREGqZSGm0vIiIiIlJoBwO9ItojE6NZeoJgdn1isrA98A2iBw38DTgN2CvU/m3gYDO7\nk2B2/zyCNXd7Efw8Y4A+Cfs/AgwHts32B0jwb2onD0aG7n8FPJXNSdz9xtjawd8JbRoMvBxLyD9C\nkKyeDSwlGLDQgWDG5iCC2ft7AqMISv8XW1eCBONvzexT4HlgIkFidRGwjKACQTuC9W73JhhosktE\nX0+4+2t1nOsa4MBQW1/g/Vii8yHgY4LXZcfYtiOAs4BtEo55l2BphKhlJxoMd19nZmOBi0KbjgA+\nNbObgaeBzwl+3q7AEIIZh2OAtgnH3AT8NN8x58kTBO87iW4zs12AJwlm6q+hdvJ7VaykcyFcQ1AO\nP9EIgpmffwUeB6YAawnKpfcnGOBzFluWHB8HDCO3739SIO7+sJldQnQZ8cNit/r8BXiU2r+v4Uov\n+fQTYF9qz5o+HzjOzP5J8Pd8LsF7T09gf4LXc9T7bjVwuruvzVvEUqoeJBikGfU/QbEYwf9Yo4C/\nm9lE4DXgbYJqGouB5UBrgv81BgO7A6NJXklgPsHvh4iIiEiTpiS+iIiIiDRlUeXs33T38CzQrLn7\nejNLnMmbGEOtJL6715jZicCbQO/Q5h7A/8VudZlBUIL93YyCTu5u4A/UnhmY6F53z0WS5HtAc6Jn\nCu8euzVUQ2K3TMwkurz+19z9WTO7Fzg1tKk1cF7sVp+FBI/9bZkEWYL+AJwAhNfy3Rq4OnarzysE\nv3sNNYn/APAngmRKXHPgF7FbMmcSlHMuhNsIBiyFK250BH4Tu9VnCkES9L3chiaF5O6XmVkNcBlB\nsjAdtxH8noYHM0HtJWvyxt2XmtmxwAsEg/cS9SN4X/pDqt0B33f3t3IYojQQ7u5m9huil2gqBS0J\nBp5kM+hvPnCAu8/LTUgiIiIiDZfK6YuIiIhIk2RmHYDjIzblc/3NqL4PNLNwkh4Ad59PkMSanMG5\nPgcOdffFGRxbJ3f/kmAGeV0yLqUfOtdGgiT0BcD6XPQZ05BnML4H7B17HupzFjA+w/PMAQ7Ox6CW\nYnH3lQQztudm2MV44JgCzkjPOXdfDXy/2HHUJTYA6ATgowy7+BA4JLY0iTRw7n4FQVWaVF8Pi4Cz\n3P27sRLkHSP2WZGr+FLh7hOBfQj+NmdqJTDa3W/PTVTSELn7U9ReLqfQPiSoGpFr/wV2iy2pJCIi\nItLkKYkvIiIiIk3VKQQzhhJVE712ba48T5BcSFRGUC49krvPBnYDbgQ2pXCOaoLZh7u6ezbJgvrU\nlaT/NJawyAkPXEdQHv82Mk/mzwL+CAxx96glDPJpGsHM9+dJ7XmMsiDWx8gUE/i4+3qCktNXk/oF\n9xqCGdfD3f3jDOIsae4+nWCJiUfTOGw18GvgMHcvaPIvH9z9IeAYIKXXUTHEBiDtAfyL4H0tFRuB\n64C9NIuzcXH3NwmWRhgF/JVgGZIvCcrirwamEvz9Ph3YNpTo7hrR5fL8RRst9n66M8H78eo0Dq0m\nqNizg7s/ko/YpMEJLwtTUO5+IcHv1SkEr80vsuhuA8HyPge7+4mxAay59AXBshrhm4iIiEjJM/fw\nMnciIiIiIlKKzGxrggTFIQRl2DsRJFyXEszWfxG4pxAzp82sFXAh0eWN33D3Z/J47o4ECcgDCZKx\nfdlyvfJqgsfkM4LHZSLwgrtPy1dM6TCztgTLAOxFkJTanuBnqEzYrYYgyfQJ8D5B6dznY7NKMz1v\nD4LlCQ4mWE+3XcLmxbFzPQ/c7e4zQ8fuSbD2+Nfc/YlMYykVZjaMYImLAwhK7McH9jjBbP1JwFME\ny0OsCB17dKi7Je7+Rj7jzTUza0ZQmeBQNq8dXwm0ofbv9pnufkch44szs37Adw
l+53ckiC/uS4JZ\noc8SvHYXho49kGAJibi17p5pdQppgMzsdmBMqHmwu08pQjgAmFl7gjXBDwdGEPzuxZeocYJBWx8Q\nVP+4393nFCNOkVSZ2XYEZfR3AAYQ/G/TheBvSiuCCkirCKpgzCB4fb9P8L/N8iKELCIiIlLylMQX\nEREREZEGz8xaElwk3uDuDbJUvpmVEyQnHVjtef6wFkvgtgHWuHum1QEaFTNrDVQQPCYZD5iQ/Ir9\nvrcgeJ6qih2PlC4zM2A6sF1C8yqgfb7fY9NhZmUEg9HKCN7/9boWEREREWniKoodgIiIiIiISLZi\nZeMzLbNfEmJJ45UFPN8milBSupQ11AEgTU1j+H2XgjmMLRP4ABNLKYEP4O41FPD9X0RERERESl9Z\nsQMQERERERERERHJJTNrDlwZsemBQsciIiIiIiKSLiXxRURERERERESk5JhZxwyPKwf+CgwPbVoN\n/CfbuERERERERPJNSXwRERERERERESlFj5vZvWa2Z2x9+3qZWX/gKeC7EZv/6u6rchqhiIiIiIhI\nHliJLQMmIiIiIiIiIiKCmb0J7B67Owd4AngX+BhYCqwBKoEuwG7AwcBRRE9a+RjY1d035DlsERER\nERGRrCmJLyIiIiIiIiIiJSeUxM/GAuBgd/8kB32JiIiIiIjkncrpi4iIiIiIiIhIY/URMFIJfBER\nERERaUiUxBcRERERERERkVJ0DzAtw2NnAz8hKKE/N3chiYiIiIiI5J/K6YuIiIiIiIiISMkys8HA\nvsBIYHtgG6Aj0BowYBmwFPgCeBWYALzm7puKEa+IiIiIiEi2lMQXEREREREREREREREREREpESqn\nLyIiIiIiIiIiIiIiIiIiUiKUxBcRERERERERERERERERESkRSuKLiIiIiIiIiIiIiIiIiIiUCCXx\nRURERERERERERERERERESoSS+CIiIiIiIiIiIiIiIiIiIiVCSXwREREREREREREREREREZESoSS+\niIiIiIiIiIiIiIiIiIhIiVASX0REREREREREREREREREpEQoiS8iIiIiIiIiIiIiIiIiIlIilMQX\nEREREREREREREREREREpEUrii4iIiIiIiIiIiIiIiIiIlAgl8UVEREREREREREREREREREpERbED\nEBERERERERGR+pnZAOAUYCegDbAAmAA85O5rihiaiIiIiIiI5JC5e7FjEBERERERERFpMsxsG+Db\nCU0vuvsb9RxzBXAhUB6x+QtgjLuPz12UIiIiIiIiUixK4ouIiIiIiIiIFJCZXQL8LqFpB3efUs/+\nlyQ0JV7MsdjX9cBh7v5KzgIVERERERGRolASX0RERERERESkgMzsXWB47O7r7r5PHfv2Bz4FyhKb\nE773hLa5wCB3X5fDcEVERERERKTAyurfRUREREREREREcsHMWgE7EyTfHXiinkPCJfRnAT8GDgfO\nBxYmbOsFnJ2rWEVERERERKQ4NBNfRERERERERKRAzGwE8E7srgP7u/urSfZtQZCkrySYaf8VsKO7\nf5WwT19gItA+ts/H7r5z/n4CERERERERyTfNxBcRERERERERKZy+ofsf17HvKKBd7HsHbkhM4AO4\n+0zgRjaX2N/BzHrkIE4REREREREpEiXxRUREREREREQKJzHBvtHdl9ex76jY13iC/u4k+z0cuj8s\n7ahERERERESkZCiJLyIiIiIiIiJSOG0Svl9Vz757J3w/2d3nJtlvMlBNMFsfYJsMYxMREREREZES\noCS+iIiIiIiIiEjhlCd8X5FsJzNrBuxKkJh34JVk+7p7FbCSzTP222cfpoiIiIiIiBSLkvgiIiIi\nIiIiIoWzMuH7SjNLlsgfCbRkc2L+1RT6js/Eb5ZhbCIiIiIiIlIClMQXERERERERESmcRQnflwFD\nkux3ROh+0pn4MYmz79ekG5SIiIiIiIiUDi
XxRUREREREREQK56PY1/is+eOS7Hdqwj5z3H1Osg7N\nbCu2vMazOKsIRUREREREpKiUxBcRERERERERKZypwMLY9wb8zMz6JO5gZj8E+sbuOvBUPX0OS+gP\nYEbWUYqIiIiIiEjRJFt3TUREREREREREcszda8zsXuBnBAn6jsA7ZvYPYCYwAvhubJvFvt5VT7d7\nhO5PzmXMIiIiIiIiUljm7vXvJSIiIiIiIiIiOWFmXYHPgHbxJjaXzg/ff9HdD66nvw+BHWN3p7r7\n4ByGKyIiIiIiIgWmcvoiIiIiIiIiIgXk7l8BpwPV8abYV0u4b8Ay4Pt19WVmgwkS+B67vZzreEVE\nRERERKSwlMQXERERERERESkwd38cOBSYRpCwjyfw49+/C+zr7jPr6eq8hOMAnsxxqCIiIiIiIlJg\nKqcvIiIiIiIiIlJEZrY7sAvQEVgOvO3uE1M4rhy4C2gda3LgW+6+Lk+hioiIiIiISAEoiS8iIiIi\nIiIiIiIiIiIiIlIiKoodgIgkZ2YzgXbArCKHIiIiIiJSavoAK929b7EDERFp6nT9QkREREQkUh8y\nvHahJL5IaWvXqlWrToMHD+5UrABWrVoFQGVlZbFCkAZMrx/Jhl4/ki29hiQbev2UvsmTJ7NunSqG\ni4iUiKJev9DfbcmWXkOSDb1+JBt6/Ug29Popfdlcu1ASX6S0zRo8eHCniRPrXQoxbyZMmADAqFGj\nihaDNFx6/Ug29PqRbOk1JNnQ66f0jRgxgvfee29WseMQERGgyNcv9HdbsqXXkGRDrx/Jhl4/kg29\nfkpfNtculMQXERERERERESkRZtYJGAlsBXQCWgPm7pcVNTAREREREREpGCXxRURERERERESKyMxa\nA98Fvg/skGS3yCS+mV0FtI3dneLuf819hCIiIiIiIlJISuKLiIiIiIiIiBSJmR0D3Ap0BizJbl5H\nF5XAT2Lfrzez/7j7yhyGKCIiIiIiIgVWVuwARERERERERESaIjP7HfAo0IUggR9O1teVvI+7KXas\nAS2BE3MYooiIiIiIiBSBkvgiIiIiIiIiIgVmZj8CLmXL5H0N8BbwN+Bpks/M/5q7TwMmJTQdmcs4\nRUREREREpPCUxBcRERERERERKSAz6wVcT5C8d4Jk/aPAQHff093PBsal0eWj8a6BA3IXqYiIiIiI\niBSDkvgiIiIiIiIiIoV1KUHp+7gb3f0Ed5+RYX9vJXzf3sz6ZRyZiIiIiIiIFJ2S+CIiIiIiIiIi\nBWJm5cBoNs/An+juP8+y2w9jX+Nl+Qdl2Z+IiIiIiIgUkZL4IiIiIiIiIiKFszvQniCB78A12Xbo\n7l+yOYEPsHW2fYqIiIiIiEjxKIkvIiIiIiIiIlI44VL3z+Wo35UJ31fmqE8REREREREpAiXxRURE\nREREREQKp1vC96vcfWXSPdOTOBO/WY76FBERERERkSJQEl9EREREREREpHAs5x2aGUGJ/rhluT6H\niIiIiIiIFI6S+CIiIiIiIiIihbMo4ftKM2uegz4HsuXggCU56FNERERERESKREl8EREREREREZHC\n+TJ0f7cc9HlA7Gs8kT8lB32KiIiIiIhIkSiJLyIiIiIiIiJSOK8DG9m8hv0pOejzRwn9LXb3j3PQ\np4iIiIiIiBSJkvgiIiIiIiKSF1XVNTw0cV6xwxApKe6+BniVYNa8AWeaWb9M+zOznwA7xbsHns46\nSBERERERESkqJfFFREREREQkp+LJ+4Ovf4lfPPhBscMRKUU3xr460Br4r5l1TrcTMzseuC7Wj8W+\n/ilHMYqIiIiIiDRYVdU13PrqTG59dSZV1TXFDidtFcUOQERERERERBqHquoaxk2az9jx05i1ZG2x\nwxEpWe7+uJm9DuxFkHgfCrxjZj9z98fqO97MugMXE5TRj0/QcOBhd/8oT2GLiIiIiIg0GOMmzefy\nJz4FoE
OrZowe0avIEaVHSXwRERERERHJipL3Ihn5NvAG0I0gAd8HeMTM5gHPA50Sdzaz04HtgP2A\nvQmu6cRn3xswE/hBgWIXEREREREpWVXVNYwdP+3r+2PHT+O4YT2pKG84ReqVxBcREREREZGMKHkv\nkjl3n2VmxwJPAR3ZnIzvDYwJ7W7A7aH7JByzADje3ZfnMWQREREREZEGYdyk+Vtcp5i1ZC3jJs1v\nULPxlcQXERERERGRtCh5L5Ib7v6OmQ0H7gP2IEjKb7FLwveJiXtPaHsX+Ia7f5HPWEVERERERBqC\n8Cz8uIY2G79hRCkiIiIiIiJFV1Vdw0MT53Hw9S/xiwc/UAJfJAfcfQ6wD3Aa8B5BYj7qFhe/P51g\nxv4eSuCLiIiIiIgEwrPw4+Kz8RsKzcQXERERERGROmnmvUh+uXsNwWz8+8xse4J173cHugOdgObA\nUuAr4APgeXf/sEjhioiIiIiIlKRks/DjGtJsfCXxRUREREREJJKS9yKF5+7TCWbZ31bsWNJhZuXA\nbsCOQBeCagFfAR8B77p7dRHDA8DMtiWIsTvQAVgLzATedvecT8kp9PlERERERJq6ZLPw4+Kz8UeP\n6FXAqDKjJL6IiIiIiIhsQcl7EUmVmbUD/g/4AUHyPspiM/sHcI27ryxYcICZNSeI7RxgYB37vQ78\n0d3Hlfr5zGwU8GKGIca95O6jsuxDRERERKRk1DcLP66hzMZXEl9EREREREQAJe9FJD1mtgfwENCz\nnl27AL8GzjCzb7r7m3kPDjCznYBHgH4p7L4X8KiZPQSc6e6rSv18IiIiIiKyWX2z8OMaymz80h5i\nICIiIiIiIgVz5xuz+cWDHyiBLyL1MrO9gReITuCvBzZEtG8NvGBme+UzNvh6tvprRCfUq4FlQE3E\nttHAc2bWtpTPJyIiIiIim6U6Cz9u7PhpVFVH/XteOjQTX0REREREREREUmZm3Qhm4LdOaN4E3Ajc\nAsyItfUDfgScBzSLtbUGHjKzoe6+KE/x9QYeBCpDm+4FbgbecvdqM6sA9gZ+DhyXsN/uwK3AyaV4\nvgjTgOvTPGZ+hucSERERESk5qc7Cj2sIs/GVxBcREREREREAzthzWzq0aqZy+iJSn8uB7gn31wHH\nu/uzof2mAxeY2fPAw0CrWHsP4DKCBH8+/JOghH9cNUHJ+n8n7uTuVcBLwEtmdh7BIIS4k8zs1oif\nqRTOFzbf3W/J4DgRERERkQYv3Vn4cWPHT+O4YT2pKC/NwvVK4ouIiIiIiAgAFeVljB7Ri+OG9WTc\npPlK5ovkiZmNz3GXTlDCfnnsNhl4B3jf3Tfm8kRmth1wZqj5N3Uln939aTO7GPhTQvNZZnaNu8/M\ncXy7AoeFmq8OJ9QjYrzJzAaz5cCCq83sOXf3UjmfiIiIiIhsKd1Z+HGlPhtfSXwRERERERHZQjyZ\nf/TOW3HEja8wY/GaYock0tiMIki859tyM7sDGOvus3LU5/lsLo0PwWz7m1I47s8ECevtY/ebEZSV\nPy9HccX9IHR/MXBFisdeBJzO5mUChhMk6J8uofOJiIiIiEhMprPw40p5Nn7pRSQiIiIiIiIl4a8T\nPlcCX6QwLHTLdv94e0fgZ8BHZva9nES65VruALe6e3V9B8X2uS3UfHyOYkp0YOj+ve6+IZUD3X05\n8Eio+YQSO5+IiIiIiMRkOgs/Lj4bvxQpiS8iIiIiIiK1vPH5kqxGs4tIvRKT8B66hZP04VvivuG+\n4kn9xL7aAH83s19lFbDZCCBca/KBNLoI79vbzIZnE1MiM+sE9As1v5ZmN+H9jzOzyOtnhT6fiIiI\niIhslu0s/Lix46dRVV2Tg4hySx8KREREREQkqarqGh6aOK/YYUiBLVm9gZ/d/z41SYp9n39wf647\ncSh9OreO3kFE6nNA7HYk8DabE+8GzAH+CZwDfBM4FDga+BZwOUGp9aqE
Y9YBvyGYEX488H3g78A8\ntkz4G3CFmR2SZdyJFrj7jFQPdvfPgYWh5vBM9mx0i2ibnmYf4auA3YCdSuR8IiIiIiISk+0s/LhS\nnY1fUewARERERESk9FRV1zBu0nzGjp/GrCVrGT0iPPFSGquaGueCBz9g4croatB79evM2Qf2p7zM\nOG5Yzy1eJyKSGnd/ycy6Av8jWAcdYDLwC3d/pr7jzawL8H8E68m3BH4PnOPu/4jtcmtsNvfpwI1A\nW4JEfhlwNfBchqEPCd1/O4M+3gKOTbg/OMNYonSKaFuRZh9R+w8BPiiB84mIiIiICLmbhR83dvw0\njhvWk4ry0pn/riS+iIiIiIh8LZy8l6bnttdm8uLUryK3dW7TnD+fPIzysmACcEV5GaNH9FIyXyRN\nZlYBPAmMIEiuPwac5O4bUzne3RcDvzSzp4BxBOXy/2pmi9394dg+NcAdZvYe8GpsH4BdzGyUu0/I\nIPRBofspz8JPMLOePrMRNfqoRZp9tIxoSzbQoNDnS8rMmgHDgN5AB2AlsBT4zN1VUkdEREREGpVc\nzcKPi8/GL6VJLErii4iIiIiIkvcCwAdzl3PN01OSbr/upKF0a1c736Rkvkjazgd2JUjgTyONBH4i\ndx9vZucCtxHMsv+7mT3r7qsT9vnQzH4F3Bw7H8DBwIQM4h4Quj83gz7Cx4T7zMayiLauafYRtf/A\nEjlfMrsRzOhvFbXRzGYATwB/dvfwIAoRERERkQYl17Pw40ptNn5pRCEiIiIiIkURX/P+4Otf4hcP\nfqDEaxO2cv0mzr33fTZVe+T2H+6/HaMGRi3/vFk8mf/8+ftz3YlD6dO5dT5CFWnQzMyAnyU0XZFJ\nAj/O3e8APo/d7QScEbHbvwhmZcftm+HpOobuL8igjy9D9ztkFkqk+cCmUNsuafYRtX9U2fxinC+Z\n1iRJ4MdsR7D0wjQzG2tm6VYLEBEREREpGbmehR8Xn41fKjQTX0RERESkCdLMe0nk7vz64Y+YszT6\ntTCsdwcuODT1iaHhmfkisoXdgB4J95/IQZ9PEiRpAY4G/pK40d03mtlLwDdiTb3TPYGZtaL2ZJBM\n/oCsC92vMLOW7r4+g7624O7rzWwisEdC8zHADWl0c0xEW9tSOF8OlAPnAHuZ2RHuvijdDmI/b5RB\nq1atYsKECdnEl7FVq1YBFO380vDpNSTZ0OtHsqHXj2Sjqb5+OgMH9K7gxblVkdu/M6Q5B23TLLPO\nV01nwoTpmQcX7i72HGVCM/FFRERERJoQzbyXKPe/M5cnPgxPjg1Utqxg7Km70CyDcnLxZL6IbGHn\nhO9Xu/vyHPQ5O/bVgJ2S7PNRwvfpzvSG6MRyJon3qGPaZNBPMs+E7o8ys+GpHGhmhwBDIzbVlVQv\n9PkSLQbuIqi+MBzoAjQnqG4wCDgLeDbiuOHAo2ZWe30UERERkRJRXeM8M2sTz8zaRHVNdMU4aZpm\nrqhmQpIE/rbtyjigd+OYw944fgoREREREamTZt5LMp8tXMWlj3+SdPs1o3emdyeVxRfJoS4J35fn\nqM/Efrok2SexnH4mSfOohG8mywBsiGirqxR8um4BLiJIZkMwsOF2M9vX3VcmO8jMugJ/T7K5rvgK\nfT4IyvifCjzk7uFy/gArYrepsVj2A+4Btk7YZ0/gMuDCes61BXcfEdVuZhMrKyuHjxo1Kp3uciY+\n+6xY55eGT68hyYZeP5INvX6Se2jiPO6d8gEAu+40WAPEIzTF109NjXPDX1/Dk4wnvuHbezB8m/Aq\nYMVTWVmZ8bElOxPfzLYxs6PM7CQz2y9Wtk1ERERERNKgmfdSl3Ubqzn77vdYv6kmcvu3dt+GI3fa\nqsBRiTR6iVebWplZtxz0uV2S/hNZCvvUJeqY5hFt9Ylajz3rUvpx7r4AuDnUvDPwgpkNijrGzHYB\nJgB9k3S7ulTOFzvnZ+5+X5IEftT+
LwN7AQtCm84xs56p9CEiIiJSSFXVNYwdP+3r+2PHT6OqOvpz\nqzQt9787lw/mrYjcdvKuvUsqgZ+tvM7EN7NyoHtC0yp3r7P4v5kNJfjws1do01oz+xfwa3cPr58m\nIiIiIiIJNPNeUnHZE58wbVF0rmhQj0ouPnpIgSMSaRLmh+4fD/wj087MrIxgXXUnSNRHr40BiVez\nlmdwqqg3i0zKsUcdU2fSOgMXAfsDibPGdwU+MrMXgDcIKhN0BvYFRrF5oks1QYL9oIRjl5fY+dLm\n7nPM7CzgfwnNrQhK8V+d6/OJiIiIZGPcpPlbXMuYtWQt4ybN12z8Jm7Zmo1c8/SUyG3tWzXjwsMH\nFjii/Mp3Of1vA7cl3B8NPJpsZzPbi2AtsdZsOUIcglJv5xGsLbZ/XSXJRERERESaKiXvJVWPfzCf\ne9+eG7mtVbNybj5tF1o2y1WlbxFJ8E7sazzpfrGZ3VvfpIc6/IygTLrHbm8n2S9+RcuBOemexN3X\nmVkNW1Z1zGStjXClxWp3z9lMfAB332hmRwKPAbsnbKoADovdIg8FziF4rFJOqhf6fJly96fMbCJb\nDjY4FCXxRUREpISEZ+HHjR0/jeOG9aSivGSLjEue/fGZqSxfG12M6peHDaRz26iiXw1Xvl/powk+\nkBrwBTAu2Y6xcvn3ESTrjeCDTFz8g6gRlCT7d57iFRERERFp0O58Y7bK5ku95ixZy0UPf5R0+++P\n24Htu2W+bpuIJOfuM4H34neBnsBTZtY+3b7M7GTgD2y+ZgLwQJLdRyZ8Hz19pX7LQ/d7ZNBHeI2O\nZZmFUjd3XwQcAFwDpPJHcR5wiLvfQu0Y55Xa+bLweOj+bnk8l4iIiEjawrPw4+Kz8aVpmjR3Ofe9\nEz0Weaet23PqyG0KHFH+5TuJvw+bE/CPubvXse+PgF5smbB/iyCx/yGbE/sGHG1mh+YxbhERERER\nkUZpY1UN59z7Hqs3VEVuP25YT05UiUKRfLuEzUl3I1hScLKZnRmb5FAnMxtiZncD9xDM9o5fM3nH\n3Z+M2H9HgvXX49dlXs8w7s9C93tn0Ef4mNrTrHLE3de5+6+AfgTVHZ8AZgCrgHXAdIKk9reAAe7+\nQuzQwaGuJpbi+TIUHsDRJpXXnIiIiEghJJuFHzd2/DSqqmsKGJGUguoa53fjPiYqy2wGlx+/I+Vl\n4QLvDV/eyumbWX+gQ+yuA8/Vc8gP2JykrwFOdvf/JvT3Y+AvbP7A+WPg2RyGLCIiIiLS4J2x57Z0\naNWMa56ewqJVG4odjpSga5+ZwofzVkRu69O5NVd+YyfMGt+HX5FS4u5Pmtm/ge+w+VpID+BfwI1m\n9howCVhAkABuDrQnKLs+ks1J38QJD2sJrq1EGRPaP9PrKVOAPRLub5dBH31D9ydnGEvK3H0BMDZ2\nq5OZtQF2CDW/W8rnS9PSiLaOBIMMRERERIoq2Sz8uPhs/NEaeN6k3PfOnKTXMU7ZbRuG9e5Q2IAK\nJG9JfIJRx4k+TLajme1E8EE0Pgv/3sQEPoC7/83MDiQo0Q9wuJm1cnd9yBARERERiVm0agPPfbpQ\nCXyJNH7KQv75yszIbc3KjbGnDqdti3x+TBSRBGcRLCl4ApsnLBjQlmCd8mQVCBNH2SQm8I9291rX\nXsysGUHJ9A9iTVPcfW6GMX8auj8ycq+67R66n/ckfpr2BsoT7i+n9s/dkM8XtWxD9BVRERERkQKq\nbxZ+3Njx0zhuWE8qyvNdbFxKwdI1G/nj01Mjt3Vo3YwLDxtY4IgKJ5+v8MTFB2pia74lc2Dsa/yD\n6N+T7PfPhO+bAztlGJuIiIiISKOysaqGv034nIOue4mnP1lQ7HCkBC1YsZ4LHkw6tpqLjhjMTr3S\nXpJbRDLk7tXAScBFwAY2z5JPTOiHb0Ts8w6wu7u/lOQ8m9x9f3ffJXY7NYuwXwzd72FmKc/Gj+3b\n
vZ4+i+300P27Y89VYzlf/9D9de6+Jo/nExEREUlJfbPw4+Kz8aVpuOapKaxYtyly2/8dPoiObZoX\nOKLCyWcSvzLh+9X17LtfwvdLgdeS7Pdm7Gv8w2p4zTARERERkSbnjc+XcORNr3DN01NYt6nu6/4q\nkt40Vdc4P7v/fZau2Ri5/eDB3Thz7z6FDUpEcPcad78G2JGg9PpytkzYhyUm9N8AzgD2dPdP8h8t\nEKzV/kWo7aQ0jj85dH+eu+dz/fe0mFk34Buh5lsby/lijgjdTz66S0RERKRANlVVc83TU1Lef+z4\naVRV1+QxIikF781Zxv3vRhcRG9q7Ayfv2rvAERVWPusktkz4PnqIxGZ7sjkx/4q7e9RO7r7SzFYT\nlJsD6JBVhCIiIiIiDdiileu58n+TUx6BfuCgbvz2qMG8P2c5Y8dPS2mEuzQON4+fzpszopZBhh7t\nWnLtN4dipiEeIsXi7jOAn5rZhcCuBCXnBxGsVd6eYKb+cmABwXrpb7j7nCLE6WY2DvhJQvN3zeza\n+maPm1k5wRICicblOsYsXQu0Trj/iru/31jOZ2Z7A/uGmp/J1/lEREREUvHhvOX8/P5JaS0LGJ+N\nP3pErzxGJsVUXeNc/OjHkdvM4PLjdqCsrHFfx8hnEj9xrfrKZDuZWT+gB5uT+K/W0+96gvXhnM3J\nfBERERGRJqOquoa73pjNDc99xqoNVfXuv3WHVlxyzBAOGdIdM2O7rm05blhPxk2azxVPfsqytfWN\nuZWG7M0ZS7jxhc8it5UZ3HTqLo26/JxIQ+LuGwiqEyarUFgKbgB+wOZrStsD58Xa6/Kz2L5xVfUd\nY2ajqF1u/wB3n5BSpGkwszPYsrT9JrYcrFAS5zMzSzb5pZ7jugB3hpqrgXvS7UtEREQkF+YuXcu1\nz0zlsQ8yK40/dvw0jhvWk4ryfBYdl2K5563ZfDJ/ZeS200Zuw869OhQ2oCLI5yt7ecL3zc2sZ5L9\nDo59jQ+XeKWefivZnPBPfViOiIiIiEgjMHH2Mo65+TUue+LTehP4zcqNn4zqx3Pn78ehO/TYYqZ1\nRXkZo0f04p3fHMz+A7rkO2wpkqVrNvLT+96nJkm652cHD2Bk306FDUpEGjR3nw7cEWq+wswOSXaM\nmR0GXB5qvt3dP89xePHz9YgNAEhl33IzOx+4LbTpWnePnvpT3PPdaGZXm1n3VM4XO+cQ4CWgX2jT\nbe4+NdV+RERERHJh+dqNXPHEpxx03UsZJ/Bh82x8aXwWr97Atc9E/5vasXUzfnnYwAJHVBz5TOKH\nF6/YP8l+oxO+XwO8l6xDM2sDtEhoih6CISIiIiLSyCxds5EL//sBo//2OpO/rP/f4L36deapn+7H\nhYcPonXz5AW4KsrLuH3MSI4ZulWd/d3x2sy0Y5bicncuePADFq6MHvu853adOfuA7SO3iYjU42Jg\nUcL91sATZvZHM+trm/Uzs2uBx4FWCfsvAn6Xx/h6AC+a2WQzu9LMDjCzjvGNsUR6HzM7G5gIXMeW\n18ieBy4t0fO1A34FzDOzZ83sbDPb08y2qIJpZu3M7HAzuwOYBAwJ9TMZ+L/Uf0QRERGR7KzfVM0/\nXv6c/f74Iv96dSYbc7Cm/djx06jKQT9SWq55agor10dPXPnVEYPo0LppVBPMZzn9DwlKgVUQzLL/\nuZnd7+5f/zaZ2Y7AQWyeWf9KPWuo7RA/NHbM7JxHLSIiIiJSQmpqnPvemcsfn5nC8hTK3nerbMFv\njx7CMTtvlfIa52VlxnUnDmPZmk28On1x5D6/f+JTula25Kid6072S+m49dWZjJ+yKHJb5zbN+fMp\nwyhv5OvHiUh+uPsCM/smwXrq8eR8c+CXsdv6WFvLiMPXAaPdfUHeA4VBwK9jN8xsXez8HUg+seU5\n4BvunslaM4U8XwVwSOxG7HwbgVUEgypaJTkO4HPgcHdfluY5RU
RERNJWU+M89sF8rn1mKl8sX1f/\nAWmIz8YfPaJXTvuV4pk4eykPTpwXuW1Y7w6cOKJ3gSMqnrzNxHf31QQf5uJXhUYA95lZfzOrMLPd\ngQdj2+P73FdPtyNC96flKl4RERERkVLz0bwVfONvr/PrRz6qN4FfXmZ8d5++vPCL/Tl2aM+UE/hx\nzSvKuOU7I9hx63aR293h5/dP4o3Pl6TVrxTHh/OWc83T4eJom/3ppKF0bxeVWxMRSY27v0KQQP4y\nYnNLohP4XwIHu/ur+YytDq2ATkRfD9sIXAYc4e5rGuj5mgOdSZ7Ad+B2YJi7z8nROUVERESSen36\nYo79y6v87P5JOU/gx93w3Geajd9IVFXXcPGjn0RuM4Mrjt+RsiY0GSGf5fQBro99dYJE/WiCMvsb\ngNeBgWyehb8AeKCe/o5M+H6pu6ump4iIiIg0OivWbuLiRz/m2L+8ygdzl9e7/67bduSJc/fh4qOH\nUNmyWcbnbduigtvHjGSbTq0jt2+sruEHd72bUjl/KZ5V6zdx7r3vs6naI7f/YL/tOGBgtwJHJSKN\nkbu/BgwG/gAsrWPXpbF9Brv76wUIbSZBgvxdoK6KjxDE9heC2C6pp0JkKZzvBuBy4GVgdYrHzI+d\nc4i7nxWbeCMiIiKSN1MXrGLM7W9z2r/e4uMv8nsNYd7ydTz03hd5PYcUxt1vzeHTJNecvr37tuy4\ndfsCR1Rc+Synj7tPMLN/Ad9jc7I+cYiEJ3y9wN2jF2sEYmt7HZJwzBs5DldEREREpKjcnYff+4Kr\n/jeZJWs21rt/5zbN+dURgxg9vFfORiJ3rWzBnWeN5Jt/ez0yhlUbqjjjtrd5+Cd70atjdLJfisfd\n+fUjHzN7ydrI7UN7d+CCQwcWOCoRSYWZ7QAcAIwEehOUYK8k/QkY7u79chtdnSdbAVxkZhcDuwE7\nAV1im78CPgbecffoRS3r7nsCW15HSiemS4BLzKwNMBToB3QjKDW/gWAyySfA+4lLP2aikOdz9w+A\nDwAsKLuzPbAd0AvoSFABYQOwDFgMvKdZ9yIiIlIoC1as5/rnpvLfifOoiR5XnhdXPvkpo4dvTUV5\nvucuS758tWoDf3p2auS2Tm2aN8lrGXlN4sf8iOAD13cT2hIT+g78xt3vraefMQRlweLHPpPDGEVE\nREREimrqglVc/OjHvD2rromMATP41u7b8MtDB9G+deYz75Pp26UNt43ZjVP/+SZrN9aeILho1QZO\nv+1tHvrRXnRs0zzn55fMPfDuXB7/YH7ktsoWFYw9ZReaV+iihkgpMbNRwO+BfcKbMuyygJdLE04a\nJOnfoMQmXcRK1b8euzWq87m7Eyw1qeUmRUREpKhWrd/E31+awb9encH6TYUvbb9yfRVXPzWZi4/e\noeDnltz4w1NTWLU+etzvr47Iz/WvUpf3JH5sdPH3zew24HRgF4KRwcuBt4F/uPtHdfURG1n80/jd\n2NfH8hKwiIiIiEgBrd5QxY3Pf8Ztr82iOoVh6jv3as/lx+3I0N4d8hrX0N4d+Ou3hvO9O9+lKiKu\nGV+t4aw73+Ge7+1Bq+bleY1FUjNt4SoueSx67TiAP4zemW06q3qCSCkxs8uAXxNc64hf70h80003\nId90FogUERERkaLbVF3DvW/P4cbnp6VUURBg+65tWLm+ikWrkhbnzsjtr81izN596N2xTU77lfx7\nZ9ZSHnpvXuS2Edt25JvDexU4otJQiJn4ALh7xqOx3d3NbESoSQtxioiIiEiD5e48+dGXXP7Epyxc\nWf8H1/atmvHLwwZy6shtKM9R6fz6jBrYjT9+c2fOf+CDyO3vz1nOufe+xy3fHqGSdUW2flM159zz\nftIZD6eO3Iajdt6qwFGJSF3M7BfAb2N3PXZLlsyPeuOvb7uIiIiISF64O898soBrnp7KzMVrUjqm\nW2ULzj9kAOVlxi//+2HOY6
pxOOPWd3j25/vpGkUDUlVdw8WPfhy5rczgsuN2yNkSkg1NwZL42Yqt\nLyYiIiIi0uB9/tVqLn3sE16Ztjil/U8c0YtfHTGIzm1b5Dmy2k4Y3otFqzbwh6emRG5/fvIifvvo\nx1x9wk4EBbSkGH7/+KdMXbgqctvA7pVccsyQAkckInUxswHAH9hyucEpwI3AW8CRwBWxbQ70BdoS\nrDW/G3AUsH/C8ROA84DoNwIRERERkRyZOHspV/1vChNnL0tp/zbNy/nh/v343r59ad08SEueuGvv\njM+/bM1GjrzpFb5csb7WthmL13DTC9M4vwmun95Q3fXGbKYsiP4Yc/qefdihZ/sCR1Q6GkwSX0RE\nRESkoVu3sZq/vDidv7/8OZuq66+QPKhHJVccvyO79ulUgOiS++F+27Fw5Xpuf21W5Pb73plLt3Yt\nOf+QAYUNTAB44sP53Pv2nMhtLZuVcfNpu9CymZY8ECkx/weUszkJ/yhwsrtvAjCz3RN3dvfZCXdf\nAv5kZrsBtwNDCBL6/wEOdvfURoiJiIiIiKRh5uI1/PHpKTz18YKU9i8vM04d2ZufHjSArpW5m5TQ\nsU1zbjp1F075x5uRyxKOfXE6u2/Xmb2375Kzc0p+LFq5nhue+yxyW5e2zfl5E7/OpCS+iIiIiEgB\nPPfpQi597BO+WL6u3n3btqjg54cM4Iw9ty2JEnBmxsVHDWHRqg08+eGXkfvc9MI0ulW24Nt7bFvg\n6Jq2OUvWctFDHyXd/vtjd6B/98oCRiQi9TGzMuCbbC6fPw/4TjyBnyp3f8fMdgWeBfYBdgLGmdm+\n7h69toaIiIiISJqWrN7ATS9M4+635lAVkTSPcuiQ7lx4+CC279Y2LzHt1qcT5x8ygGufmVprmzv8\n9L5JPPXTfXM6eEBy7+qnprBqQ1XktouOGEz7Vs0KHFFpKXoSP/bhtSPQGjB3j55CIiIiIiLSAM1d\nupZLH/uEF6YsSmn/Y4f25DdHDaZ7u5Z5jiw9ZWXG9ScNZenqjbwxY0nkPr8b9zFd2rbg8B17FDi6\npmljVQ3n3vd+0g+8xw7tyUlZlCgUkbwZClQSJPEduNndU1tINMTd15vZ8QSl+LsAewA/Bv6Sm1BF\nREREpKlat7Ga216byd8mfM7qJJ87w4b17sCvjxzMyL75ryj44/378eaMJZFLFS5evYHzH5jEnWeO\nbLLrqZe6t2Ys4ZH3v4jctlufjpwwfOsCR1R6Cj6tx8y6mtn5ZjbOzBYAm4BFwCxgRh3H7WJmw2O3\nPoWJVkREREQkMxuqqhn7wjQOvv6llBL4/bq24Z7v7c5Np+5Scgn8uBYV5fz99BEM3qpd5PYah/Pu\ne593Zi0tcGRN05+encoHc5dHbtu2c2uu/MaOmOlihUgJGhL7Gv8FfaK+A6yOX2Z3Xwr8KaHPn2cV\nnYiIiIg0adU1zgPvzuWAP03g2memppTA37Zza/5y2nAe+cleBUngQ3yywTC6tI2ebf/KtMX87aXP\nCxKLpGdTdQ2/G/dJ5LbyMuOy43Q9AwqYxDezSjP7KzAbuBY4GuhG8AEz8ZbMdcA7sdv4/EYrIiIi\nIpK5lz/7isP//ArXPfcZG6rqrmjcqlk5Fx4+kKd+uh97NYD12tq1bMadZ+7G1h1aRW7fWFXDd+94\nh88WripwZE3Li1MX8Y+Xo8dANys3xp66C5Utm3bZOZES1jHh+xqCWfRh4Tql0W+6mz2U8H1fMxuc\nSWAiIiIiuVRVXcOtr87k1ldnUlWt1X5KnbszYeoijrrpFS7874csWLm+3mM6tm7GJccM4bmf789R\nO29V8MRr18oW/PnkYSQ77fXPfca7mmhQcu58fRZTk1w3On3PbZNOHmlqCpLEN7NdgEnAD4GWbE7W\ne8KtPjewOdG/rZkdkPtIRUREREQy9+WKdZx993ucftvbzFxcf2XkQ4d057nz9+Mno7aneUXB
i2Rl\nrFu7ltz13ZF0bB2dJF65voozbnub+cvXFTiypmHhyvX84oEPkm7/v8MHsXOvDoULSETSlXhFamWS\n9evDf0TqXEzU3T8nqHQYv74yLOPoRERERHJk3KT5XP7Ep1z+xKeMmzS/2OFIHT7+YgXfvvUtxtz+\nDlMW1D8ov0VFGT8e1Y+XLjyAM/fuW9RrGvv078I5B2wfua26xjnv3vdZtmZjgaOSZBauXM+fn58W\nua1L2xb8/JABBY6odOX9t8rMBgDPAn0JEvDxD5QGrAdWUvcM/Lgnga8Sjj8ht5GKiIiIiGRmU3UN\n/3j5cw667iWe/OjLevffplNrbh+zG/84fVd6dWxdgAhzr1/Xttw2ZjdaNSuP3P7livWccdvbrFi7\nqcCRNW7VNc7P7pvE0iQXIA4a1I3v7tO3wFGJSJoSRzglux4SvnLaK4V+E6+v9Ew3KBEREZFcqqqu\nYez4zYm6seOnaTZ+CZq3bC0/v38SR499ldemL6l3fzMYPbwXL14wiv87fBDtSqQC3E8P6s/IPtFl\n/OevWM8v//sh7qnMJ5Z8u+p/k5Mu0fCbo0rnNVUK8prEN7MK4DGgM5uT7yuAS4HB7t4G+FUqfcVG\npj/G5tn4h+Q6XhERERGRdL01YwlH3fQKV/1vCms3Vte5b/OKMn56UH+e/fl+HDCoW4EizJ9dtunI\nX761C+Vl0TmoaYtW87273mH9profF0ndX16czhszoi+s9GjXkmtPHKp140RKX2I9z8ok+ywI3R9U\nV4ex6y/t2XztJXphUBEREZECGTdpPrOWrP36/qwlazUbv4SsWLeJq/83mQOve4lH3v8ipWP27d+F\nJ8/dl+tOGkrPJEvsFUtFeRk3njosacXA5ycv5PbXZhU2KKnljc+XJH0fGNmnE8cP27rAEZW2fM/E\n/yEwgM0fIj8BdnH3y9x9agb9jU/4vr+Zdc42QBEREZHGrqq6hocmzit2GI3OV6s2cP79kzj5H2/y\n2cLV9e6//4CuPPuz/fj5IQNomWT2ekN04KDuXH3CTkm3vzNrGefd+z7VNRrxnq23Zy7lz89/Frmt\nzODPpwyjU5vmBY5KRDIwPeH7MjOLulL1CVsuP7h/PX2OBCoS7i/PODoRERGRLIVn4cdpNn7xbaiq\n5l+vzGD/a1/k7y/PYGNV/c/H4K3acddZI/n3d3dnSM/SXat8q/at+NOJQ5Nuv/qpyXw4b3nhApIt\nbKqu4XfjPo7cVl5mXHb8DpqUEJLvJP65BB84jeAD5BHuPjuL/sILPw7Ooi8RERGRRi2evD/4+pf4\nxYPJ18+W9FTXOHe+PosDr5vAwymMVu/ZviW3fHs4d5y5G326tClAhIV30q69+eVhA5Nuf/bThVw8\n7mOVrsvCsjUb+el975NsLMR5B/Vnj+00xlmkgfiULRP0tUZCufsaYHLsrgEnm1mHOvr8RcK+ADOy\nD1NEREQkM+FZ+HGajV887s5jH8zn4Otf4oonJ7M8haXvtmrfkutOHMoT5+7DfgO6FiDK7B00uDvf\nS7LE3KZq55x73mflei37Vwx3vDaLaYuiJ8GM2asPg3qU7gCRYqmof5fMmFlfNs/Cd+A6d892Clh8\n6Fb8g+52wKtZ9ikiIiLSqFRV1zBu0nzGjp8W+aFZMvfenGVc/OjHfDJ/Zb37VpQZ39t3O847aHta\nN8/bv90l4yej+rFw5XrueiN6zO49b82hR7uWnHdQ/wJH1vC5Oxc8+AFfrlgfuX2P7Tpx7oF6XEUa\nCndfZmYfEyTvnWCW/dMRuz4IXBLbpxK4z8xOcvev/whZMFXlUuAbbJ5EUQW8ls+fQURERCSZZLPw\n48aOn8Zxw3pSUZ7vOaYS9+aMJVz9v8l8MG9FSvtXtqjgJwdsz5l792mQlQQvPHwQ78xaGvnzzlm6\nlose/oibT91Fs74LaMGK9UkrC3arbMHPDtY1jSj5vJq4
a+yrEXyQfCDbDt19k5mtA1rGmjpk26eI\niIhIY6Hkff4sW7ORPz4zhXvfnpvS/nts14nLj9uR/t2TLXXc+JgZlxyzA1+t2sBTH4eXcg5c/9xn\ndKtswSkjtylwdA3b7a/N4oUpiyK3dWrTnBtP2YXyMl18EGlgnmfzDPyjgYsi9rkNuJDN10AOAWaY\n2ZPAXIJrIocRTHCAzddf7nX3VfkJW0RERKRuyWbhx8Vn448e0auAUTVN0xau4pqnp/D85OjPk2HN\nyo1v77Et5x7Yv0Ev1da8ooyxpw7nqJteYdWGqlrbn/zwS/bu14XTdte1iUK54slPWbOxOnLbb44a\nTGXLZgWOqGHIZxK/W8L3G919etI907MWaEXwwbRtjvoUERERabCUvM+fmhrngXfncs3TU1iWQqm5\nLm1bcPHRgzl2aM8mOaK7vMy44eRhLFnzNm/PXBq5z68f+YjObVtwyJDuBY6uYfpo3gqufmpy0u3X\nnTiU7u1aJt0uIiXrQeDnBIn3IWa2m7u/k7iDu881s6uBy9g8y74T8O2E3eJ/bOLblwEX5zl2ERER\nkUj1zcKPu0mz8fNq0cr13PD8NO5/Z07SJdnCjtppKy48fCDbdm4cywBu07k1fxi9M2ff817k9t8/\n/gnDt+2gEu4F8Nr0xTzx4ZeR2/bYrhPHDu1Z4Igajnwm8ROnHUUvcpB5v/G3nTU57FdERESkQVHy\nPr8+mb+C3z76Me/PWV7vvmUGp+/Zh/MPHUC7Jj56uGWzcv55+q6cdMsbTF1YeyJojcM597zHPd/f\nnRHbdipChA3HqvWbOOfe99hUHX3V5fv79uWAQd0it4lIaXP3N83sAaB9rGkk8E7EfleY2XbAGDZf\nC9lil9hXA1YA33T31MrGiIiIiORYfbPw42YvWcs+14xnRJ9ODOxeyYDulQzo3pZtO7dRlbEsrNlQ\nxT9ensE/X5nB2iSznsN269ORXx85mF226Zjn6ArvqJ234o0Z2/CfN+fU2rahqoaz736Px8/dp0ks\ngVgsG6tq+N24jyO3VZQZlx23Y5OcBJOqfL4yE6fe5GQoi5l1AZqz+UPq4lz0KyIiItKQKHmfXyvX\nb+L6Zz/jrjdmpTRiffg2Hbj8+B3ZoWf7+nduItq3asadZ43khL++xvyIddw3VNXw3Tvf5b8/2pPt\nuzWdJQfS4e789tGPmZ3kd3xor/b88rBBBY5KRHLJ3U9Jcb+zzOwlghn5vaN2AcYBF7p7/VPfRERE\nRPIg1Vn4cQtWbuDJD7/kSTbP0G1RUcb23drGkvqVDOzRlv7dKtm6QyvKlNxPqqq6hvvfncsNz01j\n8eoNKR2zXdc2/OrwQRwypHujTqL+9qghvDtrGVMW1J5k8PlXa/jduE/404lDixBZ03DbazP5/Kvo\n+dhn7dOXAU1oGcpM5DOJ/1Xiecysfw4+TO4V+xpf5+2LLPsTERERaTCUvM+P+ON6wvCteXTSF1z5\n5JSUPvR2bN2Mi44YzDdH9NLFhAg92rfkru+OZPTf3mDFutpLESxfu4kzbnuHh368Fz3aqxx82IMT\n5zFu0vzIbZUtKhh76nCaV6j0pEhT4e53Anea2QhgZ4IlDDcB84CX3H1hMeMTERERSXUWfl02VNXw\nyfyVfDJ/5RbtbZqX0z82Wz9I7gdJ/m6VLRp1Aro+7s7zkxfxh6cmJ02UhnVp25yfHTyAU3br3SSW\nM2jZrJybTxvOsTe/Glmd4L8T57FXv86cMLxXEaJr3OYvX8dNL0Snhbu3a8F5B/UvcEQNTz6T+O/H\nvsbnLx0KZJvEPzXh+03Am1n2JyIiIlLylLzPj/Dj+sC7c3kryTruiczglN224cLDBtKxTfMCRNpw\nbd+tktvG7Mpp/3yLDVU1tbZ/sXwdY25/m/t/uCftWzXtZQgSTV+0ikvGfZJ0+1Un7MQ2nVsXMCIR\nKRXuPhGYWOw4RERE
RBKlOws/XWs2VjNp7nImzV2+RXv7Vs0Y2L2S/t3bfp3YH9C9kk5N4LP6pLnL\nuep/k3k7hesYAK2alfP9/bbjB/ttR9sWTat8/Pbd2nLF8Tty/gMfRG7/7aMfM7R3B/p1bVvgyBq3\nK5+cnHRZh98eNaTJvQ4zkbdHyN1nmdlnQH+CmfPnmdnf3b0qk/7MbEfgm2weFPCWu9euzZlnZlYO\n7AbsCHQh+Nm+Aj4C3nX31BYaySMz25Ygxu5AB2AtMBN4292jp/M0oPMVWkN4zkVEpHFS8j4/kj2u\nqSTwd9y6HVccvxPDenfIY4SNy4htO3HzacP54b/fjVyeYMqCVfzgrne586yRtGxWXvgAS8z6TdWc\nc8/7rNsU/S/mqSN7c8zQngWOSkREREREJLlczMLPxIp1m3h71lLenrXl5/kubVt8XYo/ntzv370t\n7VqW1uDxquoanpkVVK7bp7ompZnxs5es4Y/PTOXJD7+sd1+AMoOTd+vNzw8eQLd2TbcK3gnDe/Ha\n9CU89N68WtvWbqzm7Lvf49Gz99Z1iRx5ZdpXPPlR9Gt0r36dOXrnrQocUcOU72EO9wKXECTetwf+\nCJyfbidm1hF4AIj/9jjwjxzFmGoM7YD/A35AkMiNstjM/gFc4+4rk+yTF2bWnCC2c4CBdez3OvBH\ndx9X6uczs1HAixmGGPeSu4/K5MBSf85FRKTxylfy/pJxH9OlbQu6VLYIvrZtTtfY903hQ0o2j2tl\nywouPGwgp+2+LeUqnZ+2Q4Z058pv7MRFD38Uuf2tmUs5/4FJjD11eJN/fC9/4tPItfoABnRvy++O\n3qHAEYmIiIiIiCSX71n4mVi8egOLp2/gtelLtmjv2b4lAxJm7A/sXsn23drSqnlxromMmzSfe6ds\nBGDXSfMZPSJ5SfdlazZy0/hp/OfN2WyqjhghH+GgQd341RGD6K91xwG47LgdmDR3WeTSA1MWrOLK\nJydz+fE7FiGyxmVDVXXS6oIVZcZlx+3QpJfBSEe+k/jXAWcDnQlmL/80lvz9pbuvS6UDM9sNuJtg\nEED8nelz4J7ch5s0hj2Ah4D6prx0AX4NnGFm33T3gpT7N7OdgEeAfinsvhfwqJk9BJzp7tFXCEvo\nfMVQ6s+5iIg0bne+MZvLn/g0L/0mU9miIpbc35zY33xrTpfKFnRt24KulQ0v4Z/toIgThm/NRUcM\npmtlizxE13ScOnIbFq3cwA3Pfxa5/X8fLaBL20/4/bFN98Pckx9+yd1vzYnc1rJZGTefNrxoF5dE\nRERERESi5GIWfp/OrVmyZiOr1mdUyDll81esZ/6K9UyY+tXXbWawTafWsVn7bb9O8G/XtQ0tKvL3\n+Ss8+GHs+GkcN6xnrdn46zdVc/trs/jrhOkpPz47bd2eXx85mD37dc5pzA1dmxYV3HzacI77y2ts\njFjy799vzmbPfp05cifNEs/Gra/OZMbi2gMlAL67b1+276ZBJanKaxLf3Veb2c+A/xAk4A34MXC8\nmd0GPEdQfv1rFlyx2xbYDzgRODJ2XPz4KuB77p7aUKMsmdnewLNA1KKT62Mxha/obg28YGaHuPvr\neY5vFPAYEPWqrwZWAu2BcB2W0UAvMzvY3VeX6vmKodSfcxERkXxYtaGKVRuqmJnkn+xEbVtUBIn9\nWJL/66R/ZUJb7H7r5sVb36q6xnnzyyouvf6ljC4oDOxeyeXH78jIvp3yEF3TdN5B27Nw1XruSZKo\nvuuN2XRv15KzD9i+wJEV39yla/nVwx8m3X7pMTswQLMnRERERESkhORyFv57vz2YxWs2MnXBKj5b\nuIrPFq6OfV3F+k21E6654g6zl6xl9pK1PD954dft5WVG3y5tGNC97dez9vt3r6RP59Yplb2vT3jw\nw6wlaxmXMBu/psZ55P0vuO7ZqcxfkdrK0r06tuLCwwdx9E5bUdbEq9wlM3irdlxyzB
B+88jHkdv/\n778fstPW7endKSo9JPX5Yvk6xr4wPXJbj3YtOe/A/gWOqGHL+1VVd7/HzLYHLmVzIr4n8JvYjYR2\nCJKkiXHFE/hxv3T3l/MZ89cnNutGMBs78bd1E3AjcAswI9bWD/gRcB4QX1SlNfCQmQ1190V5iq83\n8CC1E+r3AjcDb7l7tZlVAHsDPweOS9hvd+BW4ORSPF+EacD1aR4zP52dS/05FxGRpuGMPbelQ6tm\nOS+nnyurN1SxekNVSrG1aV6+Rfn+LZL+bVvQNSHx36ZFbv41jc+8/+Or61i41oGNaR3fpnk5Pzt4\nAGP27kOzHHwwl83MjMuP25HFqzbw7KcLI/e59pmpdKtswYm79i5wdMWzqbqGc+99P+msiqN33oqT\nd2s6j4eIfL28WyWblxVMi7tHj5YSERERyaFczMKHIIH92AdfMnpEL7Zq34pRA7t9va2mxpm3bB1T\nF676Oqk/dcEqZny1ho3V+UvuV9c40xetZvqi1fzvowVftzcvL6Nft7ZbJPcHdK+kV8dWKSfOkw1+\niM/Gf2PGEq7+3xQ+/TK1FXzbt2rGuQduz3f23Dav1QMai9NGbsPrny/hyQ9rr9m+akMV59z7Pg/+\ncE+aV+iaULqueOJT1m2qjtx28dFDcnbtr6koyKPl7peZ2VLgWoIZzIlJ+0TG5oQosf3i+24Eznb3\nW/McbqLLge4J99cBx7v7s6H9pgMXmNnzwMNAq1h7D+AygmRvPvyTLddqryYoWf/vxJ3cvQp4CXjJ\nzM4jSEjHnWRmt0b8TKVwvrD57n5LBselo9SfcxERaQIqyssYPaIXxw3rmVUZ+FKwZmM1a2Ij2uvT\nunn5Fsn+eAn/4OvmZH+Xyha0aV5eq+R6tmXzAY7aeSsuPmoIPdq3zOh4qV95mXHTqbvw7X+9xbuz\nl0Xu86uHP6JLZQsOSLhw05j96dmpTJq7PHLbNp1ac9UJOzXZJQZEmoLYQPijgW8Cw4H+1K5ulw6n\nQNd7REREpOnK5Sx8SF5OvqzM2KZza7bp3JpDhmy+dF9VXcOsJWu/TupPWxR8nbVkLdU1+SvkvLG6\nhslfrmRyKMHeqlk5A7q3pX88sd+jkgHd29KjXctan+eSDX6YtWQtR419lakLUluVuHlFGWfu1Yef\njNqe9q2b1X+AAMEEg6tP2ImP5q1gztLaz8MHc5dz7TNT+M1RQ4oQXcP10mdf8dTHCyK37bN9F47c\nqUeBI2r4Cvahzt1vNrNXgT8AhyZuquOw+DvbBOBCd383T+HVPrHZdsCZoebf1JV8dvenzexi4E8J\nzWeZ2TXuPjPH8e0KHBZqvjqcUI+I8SYzG8yWSearzey5upYoKPT5iqHUn3MREWl64sn8w3fszqhr\nJ/DV6vRmlDc0azdWM2fp2sgPUGGtmpV/Xb6/U5vmrNlQxeQvV7Fi3aaMz//v745k3/5dMz5eUtey\nWTn/OmNXTrzlDaYtqr3SUnWN85P/vMe9P9iDYb07FD7AApowdRF/f2lG5LZm5cbYU3ehXUtdjBFp\nrMzseIKB773iTcWLRkRERCR1uZqFHxcuJ1+fivIytu/Wlu27td1iDfMNVdXM+GrN18n9eFn+VK41\nZGPdpmo+mLeCD+at2KK9smXF16X4B3ZvS79ubfnz858l7SfVBP7xw3pywWED6dVRZd8z0a5lM8ae\nugvfvOV1NlXXTlX985WZ7NmvMwcO6h5xtIRtqKrm0sc+idzWrNy49NgdNDkhAwUdme3uk4DDzWxH\n4BSCde93o/b64g58AjwP/LdIa4yfz5ZVAaYDN6Vw3J8JEtbxhTybEZSVPy+XwQE/CN1fDFyR4rEX\nAaezuWT8cIIE/dMldL5iKPXnXEREmqhbX52VkwT+r48cxOLVG1m8agNfrd7AV6s2sHj1Rpau2UAe\nB6nnxbpN1cxduo65S9flrE8l8AurQ+vm3HnWSE
746+ssWFl7fb91m6o56453+O+P9mS7rm2LEGH+\nLVq5nl888EHS7f93+CCGNvJBDCJNmZldRPC5On41K16NMKtuszxeREREpF65noUfl2w2fjpaVJQz\neKt2DN6q3RbtazdWMX3R6lhif3Ny/8sU15vP1Kr1Vbw7e1nSSnTp2qtfZ3595GB23Lp9Tvpryob2\n7sCvjhjM5U98Grn9Fw98wP9+ui9btW8VuV02++fLM5i5eE3ktu/tux3bd2uc13XyrSjl1dz9Y+C3\n8ftm1gHoBDQHlgJL3D160YTCOS50/9ZUYoqtCX8bcFVC8/HkPqF7YOj+ve6+IZUD3X25mT0CfCuh\n+QTqTqoX+nzFUOrPuYiINEGzFq/h5henR24rL7O0SsT9YL9+ke3VNc6ytRtZ/HVifwOLV8Xurw4S\n/fH2pWs25rUsnTQtPTu04s6zRnLiLa+zMmI9+KVrNnL6bW/z8E/2oltl41rioLrG+dn9k1iyJnqA\nzgEDu3LW3n0LHJWIFIqZHQtcGbsb/8OamIBfBqwhWMZOREREpKTkehZ+XLqz8dPRunkFO/fqwM69\nOmzRvmLdJqYvWsXUBatjyf3gtrjEqiEO7F7Jr44cxKgBXTWjOYfO2rsPb3y+mOcnL6q1bdnaTfz0\n3knc8/3dsxpY0tjNXbo26bXLnu1bcu6B20duk/qVxBpp7r4cWF7kML5mZiPYXMou7oE0uniALRO6\nvc1suLu/l3VwgJl1AsJX4V9Ls5vX2DKpfpyZ/cjda4p9vmIo9edcRESaJnfn4nEfs7Eq+s/ln08e\nxsaqmqzWgIdgMEB8vflB9SxPVfN1wj9I8scT/18lJP4335Twl/oN7FHJP0/fle/c9nbka33esnWM\nue0d7v/hHlQ2orLyf5swndc/XxK5rXu7FvzpxKGUlenCjEgjdn3sqxMk76uA24G7gYnuHj2NRURE\nRKQEjB7Ri517teeIG1+hKuJz/5Ct2vH4uftQ3gA+07Rv1YwR23ZixLadtmhfsnrD17P147epC1ZF\nDkDPp+7tWvCLQwYyekSvBvF4NjRmxrXfHMqRN70SWZXh7VlLuemFaZx/6MAiRNcwXP7Ep6zfFH3t\n8nfHDKF185JIRTdIeuSiHRC6v8DdoxeqjODun5vZQiBxsYwDgVwldLtFtEUPc0kuXOumG7ATEFXP\ns9DnK4ZSf85FRKQJeuLDL3ll2uLIbaMGduXonbfCzDhuWE/GTZqfdTI/FWVlRue2LejctgUDqaxz\n35oaZ/m6TbGZ/bVn9X99iyX/oz74S9Ow+3aduemUYfz47vfwiJfBp1+u5Ef/mchtY3ajRUV54QPM\nsXdmLeX656LXQCwz+PPJu9C5bXjFMRFpLMxsOLAdmxP4y4Ej3P2tYsYlIiIikip359LHP0n6Of73\nx+3Q4BPOndu2YM+2LdizX+ev29ydRas2JJTkX8XUhauZtnAVazfmtoCSGfz8oP58b7/tlATNs45t\nmnPTqbtwyj/ejJyMMvbF6ey+XWf23r5LEaIrbS9OWcSzny6M3LbfgK4ctkM9s4WkTvrNjzYkdP/t\nDPp4Czg24f7gzMOppVNE24o0+4jafwjRSfVCn68YSv05FxGRJmbl+k1clmRNrhYVZVx27I5fl0+r\nKC9j9IheBU3mp6KszOjUpjmd2jRnQPe6E/7uzop1mzbP6l+9kcUJyf4g8R8v+b+eJMUJpAE7fMet\nuOy4Hbn40Y8jt782fQkXPPghN548rEHPUF++diPn3fs+ycasnHtg/y0uEolIozQs9tUIEvm/VQJf\nREREGpJnPlnAa9OjK4sdN6wnu/WJSik0fGZG93Yt6d6uJfsN6Pp1e02N88XydbGk/iqmLVzN1AWr\nmP7V6qTVFevjDlt3bK0EfoHs1qcT5x8ygGufmVprmzv89L5JPPXTfelaqQH3ces3VXPp459Ebmte\nXsbvj91BSz
9kSb/90QaF7qc8IzvBzHr6zEbUWvTpvnNELSqaLOlc6PMlZWbNCC549AY6ACuBpcBn\n7j4v3f4SlPpzLiIiTcx1z0zlq1VRf4LhvIP6s03n1rXaSzWZnwozo0Pr5nRo3Zz+KST8V66rYsHK\n9Tz6/hf8d+I8vlod/VhJw/KdPbZl0cr1jB0fXfTp8Q/m07VtCy4+enCD/CDo7lzw4IeRJfoARvbt\npLXiRJqGcLW7u4sShYiIiEgG1m2s5vInJkdua928nIuOaHpz28rKjN6dWtO7U2sOGry5WG9VdQ1z\nlq5l8pcr+c0jH7N83aa0+h07fhrHDeup9dgL5Mf79+PNGUsiq2IuXr2B8x+YxJ1njmzQEwty6R8v\nz2B2kuuOP9hvO/p2aVPgiBof/eZHGxC6PzeDPsLHhPvMxrKItq4RbXWJ2j/Zoh6FPl8yuxHM6H8b\neAi4FXgQeAGYa2afm9mNZtY3zX6h9J9zERFpQj6ct5y73pwduW37bm35/r7b1Xl8PJn//Pn7c92J\nQ+kTkfBvyMyM9q2bMbBHJf93xCDeuOjARvlzNlXnHzKAk3btlXT7ba/N5J+vZDLesvjueH0Wz0+O\nLjPXsXUzbjplF12cEWkaEkeerXD3dCvdiYiIiBTNLS99zhfL10VuO/fA/vRoHzWfr2mqKC9ju65t\nWb+pJu0EPsCsJWsZN2l+HiKTKGVlxvUnDaNLkuXtXpm2mL+99HmBoypNc5eu5S8vRk/A2LpDK84+\nQBMUciGvM/HN7PQcd+nAeoL14pYDU919ZY7PAdAxdH9BBn18GbrfIbNQIs0HNgHNEtp2ASak0ccu\nEW3JatwU+nzJ1HdlfjvgPOBsM/sbcIG7pzotr9SfcxERaSKqa5xfP/JR5LrgAFccvyPNK1JL8oVn\n5jdWDbkCgdRmZlz1jZ1YsnojL0xZFLnPVf+bQtfKFnxjl+TJ/lLz8RcruPp/U5Juv+6kobrYJdJ0\nJI7UUz1OERERaTDmLl3LLUmSmH27tOGsffoUNqAGoKq6hrHjp2V8vGbjF1bXyhb8+eRhfOe2tyKv\nzV3/3Gfs3rcTuzbSJSNS9fvHP2VDkqUifnfMEFo1Ly9wRI1Tvsvp30GQeM8XN7NpwHjg7+7+YbYd\nmlkralcoyOQqcHgoWoWZtXT36NqZaXD39WY2EdgjofkY4IY0ujkmoq1tKZwvB8qBc4C9zOwId4++\n+htTCs957PGNMmjVqlVMmDAhg3ByY9WqVQBFjUEaLr1+JBtN9fXz3OxNfPzFxshte/esYP2cj5gw\nJ/1+OwMTJkSPkG1MOgMXj4A3v2zOo9M28NX61EucNbXXWqk7qbcz68syPl8R/aHwggc+4IvPp7Bj\nl/x8pMnle9C6KufS19exsTr6o9Fh21ZQtmAyExZEl6SUaPHnSKQBeoPgWokBLc2sV5bLw4mIiIgU\nxJVPTk6euDt6CC0qlLgLGzdpflYTDeKz8UePaDiD2Bu6ffp34ZwDto9c6q+6xjnv3vd58rx96dim\neRGiK74XJi9MWmVw1MCuHDqke+Q2SV+hhu5Ynm5lBCXZfwS8b2ZPmFmPLGONSixnkniPOiaXC0A8\nE7o/ysyGp3KgmR0CDI3YVFdSvdDnS7QYuAs4AxgOdAGaE8x0HwScBTwbcdxw4FEzq29KU0N5zkVE\npJFbtr6Ghz6LTuC3aQYnD2qaHw7SVV5m7L11M34z3Pl2/xq6t9ZaZQ1Ri3LjZyNa0qNN9PNX7TD2\n/Q3MXFFd4MjS4+7c9ckGFq6NTuD3aVfGiQP1uy3SlLj7lwTLwsV9o1ixiIiIiKTq1WmLefqT6AK2\nBw3qxgGDuhU4otKX7Sz8uLHjp1FVHT14QvLjpwf1Z2SS2fbzV6znl//9EE9WRrMRW7+pmksf/yRy\nW/PyMi49ZgfMdB0uV/I9Ex+CZHtc1Cs66tlMd7/4tiOBD83sUHeflHKEW4pK
+EZfTa9bVBn3Vhn0\nk8wtwEUEyWwIHoPbzWzfupYYMLOuwN+TbK4rvkKfD4Iy/qcCD7l71IIxK2K3qbFY9gPuAbZO2GdP\n4DLgwjrOU/Tn3N1HRLWb2cTKysrho0aNyiCc3IjPPitmDNJw6fUj2WiKr5+z73mP9dXR68pdfMxO\nHDtymwJH1LBNmDCBg9vBpWP2S6nMflN6rTUkw3Zdy+i/vc6iVbX/1dpQDTd/WMNDP96DPl1yO3Yy\nV+9BD747lze+jC4Y1rZFBXf+cB+27axxn5morKwsdggi2fgtcADB5IRfm9m/3X15cUMSERERibap\nuqbOxN3FRw8pcEQNQ7az8OM0G7/wKsrLuPHUYRx54yssW1s7PfX85IXc/tosztqnbxGiK55bXvqc\nuUujr13+aP/tcn5tpqnL90z8u4A7Y7f32Jxsj8+kXw+8CzxBkHz9L/AcMC9hn7h1sX3uAh4FXiMo\neR7fJ57Q7wI8ZmZbZRhz1GzqTKbGRK1rl3Up/Th3XwDcHGreGXjBzAZFHWNm8XXsk72rrC6V88XO\n+Zm735ckgR+1/8vAXtRez/4cM+tZx6EN4jkXEZHGbcLURTz54ZeR20Zs25GTdu1d4Igaj4ryMkaP\n6MXz5+/PdScOpU/n1sUOSdLQu1Nr7jhzJJUtoscfL1mzkTNuf5uvIpL8xTZ90Wp+Ny76QhfAVSfs\npAS+SBPl7m8Dvya4ptENeMrMOhc3KhEREZFod74+i+mLoi/nf2/fvkrcRcjVLPw4zcYvvK3at+K6\nk6KKTAeufmoyH85bXriAimzOkrX8dcLnkdt6dWzFj0dtX+CIGr+8JvHdfYy7nwlMJkj4xhPt9wMH\nA5XuPtLdj3X3b7v7Se5+mLtvC/QAzgfmEFsnDtgRuNTdT3D3fYH2wOEEiWJL6H9r4KoMw476S1Rf\nOfYoUcfUmbTOwEVAeC31XYGPzOxpM7vEzM41s0vN7AWCARPxIXHVbFm+D2B5iZ0vbe4+h6C8fqJW\nBKX4k2lIz7mIiDRC6zdVJ030lZcZV35jR8rKVIoqW0rmN1xDerbj76ePoHl59MeX2UvWctYd77B6\nQ1WBI0tu/aZqzrnnPdZtii73f/KuvTl2aF3jTEWksXP3a4FfEVzL2B342Mx+bGapLj0nIiIikndf\nrdrAjc9HJ6N7tGvJ2QcocRclV7Pw4+Kz8aWwDhzUne8lmW2/qdo55573Wbk+pXmoDZq7c+njn7Cx\nKnogySXH7ECr5uUFjqrxy3s5fTP7PUGZOIClwAnu/kp9x7n7IuDPZvZ3YCxBYnYn4GUz28vd57l7\nDcFa6M+a2aXA7wg+/BpwmpldEkvqpszd15lZDVsOcMjkCm+4jHq1u+d0Vra7bzSzI4HHCD7wx1UA\nh8VukYcC5wADgYMS2peX0vky5e5PmdlEILFE/aHA1Un2bzDPuYiINE5/eXE6c5ZGf7D73j59GdSj\nXYEjatziyfzjhvVMqcy+lIa9+nXh+pOHcu697xO17NxHX6zgx/+ZyK1n7EbzinwXHKvflU9OZsqC\nVZHbtu/WlkuP3aHAEYlIKXL3P5rZmwTL0A0kqID3ZzP7CJgNrCR6ycF6uvXv5jZSERERaaqueXoK\nq5IMmP71UYNpk6RqWlOW61n4cWPHT+O4YT2pSDLAXfLjwsMH8c6spXwwb0WtbXOWruWihz/i5lN3\nadRrwT8/eRHjpyyK3HbgoG4cPLhbgSNqGvL6m25mowgS+EawXvgBqSTwE7n7Onf/HnB3rJ9ewH8i\n9rsUuJfN5fUrgOMyDH156H6PDPoIl/NfllkodYsNdjgAuIZgeYH6zAMOcfdbqB3jvFI7XxYeD93f\nrZ79l4ful+xzLiIijcv0Rau45aXoUlRbd2jFTw/uX+CImo7wzHwpfUfv3JNL6lhr8ZVpi7nwvx9Q\nU5Nuviu3nvroS/795uzIbS0qyvjLacM1
Ql1EADCzMoIKdvHrMwY0A4YDxwOnE1SWS/U2JnYTERER\nydr7c5bx34nRl/FH9u3EMTtnuqpx45brWfhxmo1fHM0ryhh76vCky/w9+eGX3Pv23AJHVTjrNlZz\n6WPRFUSbV5Rx6TE7NOoBDMWU7yFSl7O5zP0N7v5xFn2dR5CUbwvsa2aHuvuzoX1+DZzE5g+/+xPM\n4k/XZ8AeCfczWYQ2fEzuh13FuPs64Fdm9mfgRIJZ50OArgTP8RcESxrcBzwS2x9gcKircKn8kjhf\nhqaE7rcxs1YJsYQ1qOdcREQaB3fnN498zKbq6ITjpcfuQOvmGtGeb/FkvjQMY/buy8JVG/hbknXY\nHp00n+7tWnLRkeF/PQtj7tK1XPjQh0m3X3LMDgzsUVnAiESkVJnZ1sCjBAl7SH/GvYiIiEje1NR4\n0sRdmaHEXR1Gj+iV8nWGCRMmADBq1Kj8BSRZ26Zza/4wemfOvue9yO2/f/wThm/boVFW0/zbhOl8\nsTw6tfaTUf3YRstV5k3eZuKb2TbA3glN/86mP3dfBjyZ0HRqxD6zgbcIBg4Ym9djT1c4AbxdBn2E\nF8mYnGEsKXP3Be4+1t2Pcfd+7t7O3Vu7e393P9bd74knsc2sDRCu4fluKZ8vTUsj2jrWsX+DfM5F\nRKRhe/i9L3hrZtSfLDhkSHcOGdK9wBGJNAwXHjaQE4ZvnXT731+ewb9emVHAiAKbqms47773WbU+\nutTkUTtvxakjMxkrKiKNjZl1AMYTJPATr35bljcRERGRnHhw4tzI8uEA395jW4b0bHzJSpG6HLXz\nVnx7j20it22oquHsu99j7cbo6wEN1azFa7jlpejrK9t0as2P9u9X4IialnxO7UpcM73K3cNJ0kx8\nDJwc0X+iN4C9Yt93zvA8n4buj8ygj3B8pZbQ3RtIrOG5nNo/d0M+X/uItuj/OAJN4TkXEZESsnzt\nRq78X/SfilbNyrVetkgdzIxrRu/MktUbeemzryL3ueLJyXRr15Jjh/YsWFzXPfsZ789ZHrmtd6dW\nXH3CTpqpIiJxfwD6E8y+d4IEfDXwMvAeMAdYE2sTERERKagV6zbxx6enRm7r2LoZ5x8yoMARiZSG\n3x41hHdnLWPKglW1tn3+1Rp+N+4T/tRIlmx0dy59/BM2VtdEbr/02CG0bKalAvMpn0n8xOEoq3PU\n58rYVyN5ufOFCd93yPA8L4bu9zCz7dw9pek8ZrYdEJ46F+6z2E4P3b/b3fN5caDQ5wsvILzO3dfU\nsX9TeM5FRKSEXPP0FJau2Ri57eeH9GfrDq0KHJFIw9KsvIy/fms4p/3zzaSzQ37xwCQ6t2nO3tt3\nyXs8L3/2Fbe8FF3iv6LMGHvqcNq1bJb3OESk9JlZJ4J16+PJe4BHgHPdXYucioiISNH9+fnPWJLk\nmsUvDxtEh9bNCxyRSGlo2aycm08bzrE3v8rajbVTXP+dOI+9+nXmhOENf9nGZz9dyISp0RMnDh7c\nnQMHqYJovuWtnD6QeIWqg5nl4opV4isi2QCExIUZMk0STyRY1z3RSWkcf3Lo/jx3z+f672kxs27A\nN0LNtzaW88UcEbqffGHSQKN+zkVEpLRMnL2Ue9+eG7ltUI9Kztw7vEKLiERp06KC28bsRp8k669t\nqnZ++O+JfPxFXQWZsrdo1XrOf2BS0u0XHj6QYb075DUGEWlQ9gfiV74deOr/2bvv+Ciq9Y/jn5NG\nDb0XRXqvgiiioNgL9t7Adu3da8eu13Lt115R8IcNFAuoiA2VjlTpSO+Emn5+f8yuWZaZZLPZnWyS\n7/v12ld255yZecKGZGeec55jrT1dCXwRERFJBH+t28F7v61wbevctAZn99YSYVKxtW5QnYdP6ezZ\nfs/oOSzZGKu5zaVjT3YeD37hXky7UkoSw06KdjVzKY54JvHXhb0eEINjDgx5vt6jT3rI88JmXnuy\n1lpg
TNjmS40xRdaFCPQZGrY5/Fil7Ukg9E7nz9baGeXlfMaYfkD/sM3jCtunArznIiKSIHLy8rn7\nszme7Y+c2pnU5Hh+RBMpX+pWr8R7Qw+iXvVKru07s3K55O0prNyyOy7nz8+33PR/M9m0032WyoB2\n9bns0JZxObeIlFnB0XrBWfjDSisQERERkVDWWu7/fC55+da1/YGTO5GcpCXCRE7r2YzTPWbb787O\n45oPppOZU3ZXxnrph8Ws3rbHte2aga1pXsd9MoXEVjzvEC8OfA3+tr+xJAczxhwE9KVgvbjFHl2D\nw8AsUJJR7M8AuSGvWwPXR7DfjYG+QbmBY3kyxgwwxtiwx4DihRsZY8zF7F3aPge4Oh7nKsn5TJSL\nhRpj6gHvhm3OA0ZEsLtv77mIiFRcb/+6zHXdLIBz+zSn1/51fI5IpOzbr25V3hnSm+qV3It1bdqZ\nxUVvTWbzzqyYn/vlH5fw6+LNrm0N0ivx9JndSNJNLhHZW+iooxxr7dRSi0REREQkxFez1/HbUvfr\nm9N6NNU9C5EQDw7uRKv61VzbFqzbwSNfzvc5othYunEnr/3kvtL0/nWrcsVhmqjgl3gm8ScBGwLP\nDXCsMebWaA5kjGmIk5g1FIxUH+3RvUfI84jWM3djrV0MvBO2+WFjzFFe+xhjjgEeCtv8trXWfXHM\nEjLGNIo02W+MSTbG3Ay8Fdb0pLXWezpg6Z3vOWPMY4H3PiLGmI7Aj0CrsKa3rLV/FbV/WXjPRUSk\nbFu9bQ/PfLvIta1OtTT+fWx7nyMSKT86N63JKxf0IjXZPWG+bNMuhr4zhd3Zua7t0Zi6fAv//Xah\na5sx8Ow53anrUSFARCq0LSHPy3adTRERESk39mTn8ciX7uWzq6Ulc8dxumchEqpapRRePK8nlVLc\nU63Df1/BV7PX+hxVyVhrGfb5XLLz8l3b7z+5E5VTiyxgLTEStyS+tTYfeAMn6W4DX/9jjPmfMaZW\npMcJJFAnAW0pmNWfAYx06VsbODCk35Ro4w+4l4KBCOCUhB9rjHnCGHOAKdDKGPMk8AVQJaT/BuC+\nEsZQmEbAD8aY+caYR4wxAwP/BsA/ifQWxphrcNZ8f5q93/PvgPsT9Hw1gDuAVcaY8caYa4wxBxtj\nQpdLwBhTwxhzrDHmHWAmEL4Qx3zg35F/iwn/nouISBl2/+dz2eNRSuvu4ztQq2qaa5uIRObQNvV4\n6sxunu2zVmVwzQfTyfG4GC2ObbuzuX7kDM8yk9cNbM0hreqV+DwiUi7NDXleK5Jl3ERERETi7eWJ\ni1mTkenadv2RbWhQo7LPEYkkvg6Na3BfIevD//vjP+O2vF88jJu7jp8XbXJtO7pjQwa2a+BzRBVb\nvBdcfRhYFngeTORfCSw3xrxljDnbGNPOGFPTGJNkjKlsjGlojDnMGHOrMWYK8A3OenHB/S1wp7XW\nrabL6UBoDc2JJQneWrsOOAMIXfghDbgNZ5b/7sBjMXArkBrSbw9weuAY8dYeuAuYAGwxxuw2xmwG\nsnH+/V8Ewu9mfgucYq3NSfDzpQBHBY45CdhujMkyxmwyxuzGGdDxNXAxe//7AywBjrXWbo30ZGXo\nPRcRkTJm/Nx1fDtvvWvbQQfU4bSeTX2OSKR8Gty9Kfec0MGz/Ye/NnLnp7Ox1j35HglrLbd//Kfn\nDa4+Lepw/ZFtoj6+iJR7vwPB61QD9C/FWERERET4e/NuXvEon92yfjWG9DvA54hEyo7z+uzHCV0b\nu7btyMrl2pEzyM4t+WSCeNudncuDX7hX46icmsS9J3oPVpD4iGsS31qbCZwIBO9YBxPxNXCSriOA\neTil5HKAXTjr2P8A/AfoRUHiPugla+2rHqe8MfDVAGustZNi8D38jJNEdqt5UTnwCLcWGGSt/aWk\n549SFaAO7u9vNvAgcJy1dlcZPV8aUJe9Z8CHssDbQHdr7d/FPXgZfc
9FRCSB7crK5f7P57q2pSYb\nHjm1M8ZozWyRWLmsf8tC12j7eNoqnhpf5GpLnt77bQXjPQbl1KqaynPndiclOd7jpUWkrLLW5gL/\nC9l0bWnFIiIiIgLw0JfzPJOM953YkTSPcuEiAsYYHjutC/vVqeraPmvlNp4ct8DnqIrvhQne1Tiu\nHdia5h7fn8RP3H/zWmvnAwOBORQk5INJeVPEI9jXAHnAg9ba693OY4xJAc4HegQeMRvJbq39FegA\nPM7ea9eF2xLo0yEWAwgisAwnQT4V59+nMFuAl3BiG2atLap/aZ/vGZy15n8i8jUC1wTO2dFaO9Ra\nG/Xaggn8nouISBn0/PeLPD8EX3lYK1o3SHdtE5Ho3XFse07p3sSz/aUflvDupOXFPu6c1Rk88uV8\nz/anzuhG45peY01FRP7xCLAA537HqcaYS0o3HBEREamofly40bNy4KAODRmg8tkiRapROZUXzu1B\narL7JJ3Xf17GhAXu/88SweINO3njZ/dqHC3qVuXyQiZKSPykFN2l5Ky1fxljegG3AzcA9YNNhewW\nTOSDMzP/FmvtzELOkQvMKnm0nsfPAO40xtwL9Aa6AMFFLjfiDFKYEoijuMeeSMH3WtyYhgHDjDHV\ncErYtwIa4KzlngWsw1lvb4a1tkT1Ovw8n7V2FoH30zhTE1sDLYFmQG2c2fBZOCUINwHTo5l1X0QM\ncXvPRUSk4liwbjtv/LLMtW2/OlW59ojWPkckUjEkJRmeOKMbm3dle67ndv8Xc6mfXonju7iXvQu3\nMyuX60bOIDvP/WPu0H4HMKhjw6hjFpGKw1qbaYw5Dvge51r3dWPM/sAjusYUERERv2Tn5vPAF+6V\nA9NSkrhP5bNFItateS3uOK4DD411L0l/y6hZfHVD/4Qb+G+t5f7P55KT556yfWBwZyqlJPsclYBP\nSXz4J8n+qDHmSeBUnHLlBwHt2Hct8804s71/Az601i70K86iBL6P3wKPhBEoVT8p8ChX57POgqWL\nAg/fJep7LiIiiS8/33L3Z3PIy3f/EPzg4E5UTtWHYJF4SUtJ4uULenHOa78xZ/X2fdqthRs/nEnt\nqmkc3Kpukce7b/Qclm1yXyGqc9Ma/Pu4diWOWUQqBmPMfjgTG84CXsYZOH4fcKUxZjjwC7ACyKDw\nCRCuYj3IXURERMqndyYtY+lG92ucKw9ryX51VT5bpDiG9mvBb0s28d38Dfu0bd2dww0jZzLi8oMS\nagm+r2av45fF7pMfjuvciMPb1ndtk/jzLYkfZK3NAUYFHgAEZnXXxJlZvS3KUu8iIiIiCWXU1JVM\nW7HVte2Ero1Vkk7EB9UrpfD2JX04/eVJ/L1l9z7t2Xn5XPHeVEb962A6NK7heZxPpq3i0xmrPdtf\nPLenRqaLSHEsZ+/kfHApwUbArYFHtCylcL9HREREypYN2zN57jv3eXNNalbm6gGqHChSXMYYnjyj\nG8c//zNrXZbWnLx8C89/v4ibj06MSQC7snI9KwdUSU3mHlXjKFUJMdTDWrvLWrvGWrtZCXwREREp\nDzbvzOKxrxe4tlWvlKKSdCI+qp9eifeG9qFutTTX9h1ZuVzy9mRWbd03yQ+wZONO7hk9u9BztKhX\nrcRxikiFY9h7KUFLQTK/pA8RERGRQj3+zQJ2ZbunY+46oQNV0jRIWSQataul8fy5PUhOcv9Y/sIP\ni/nVY+a7356fsIh12/cdbABw3ZGtaVorsUr/VzQJkcQXERERKW8e/WoBGXtyXNtuPbotDWtU9jki\nkYqtRb1qvD2kN1U9bkSt357FxW9NZuuu7L2278zM4YI3/mBPTr4fYYpIxWJDHl7bi/sQERERKdK0\nFVv4dLp7pbGDW9blhC6NfY5IpHzp3aIONx/V1rXNWrjhw5ls3JHlc1R7W7xhB2/+vMy1rWX9alx2\naEufI5JwSuKLiIiIxNhvSzbzyf
RVrm1dmtbkwoNb+BuQiADQtVktXr6gFykeo+GXbNzFpe9OYU92\nHnn5ll9X53Dw4xNcS+CJiJTQ33F6rAh8FREREXGVl28Z9vlc17bkJMP9J3fCGBX2ESmpqw5vRf82\n9VzbNu3M4uZRM8nPL51xuNZa7hszl1yP8z9wcifSUpRCLm1aI01EREQkhrJz8z3LbhsDj5za2bOc\nlojE3+Ft6/PEGV25edQs1/bpf2/j9Jd/ZVPGHjbs1qRWEYkPa22L0o5BREREKqZRU1cyZ/V217YL\n++5Pu0bpPkckUj4lJRn+e1Z3jnvuZzbt3HfW/c+LNvHyj0u4ZmBr32Mb++daJi3Z7Np2QpfG9G9T\n3+eIxI2GUYiIiIjE0Os/L2XJxl2ubRf13Z+uzWr5G5CI7OO0ns2487j2nu3z1u5QAl9ERERERMqd\njN05PDnuL9e2OtXSuGmQe/lvEYlO/fRKPHt2d7yKW/z324VMXb7F15h2ZuXy8JfzXNuqpiVzz4kd\nfI1HvJXKTHxjTDrQG2gO1ALSiWJAgbX2wdhGJiIiIhK9vzfv5vnvF7m21U+vxC3HtPM5IhHxcsVh\nLVm/PYu3fnVf/01ERERERKS8+e+3f7FlV7Zr2+3HtKNm1VSfIxIp/w5tU49rB7bmhQmL92nLy7dc\nP3IGX17fn9rV0nyJ5/nvF7F++76VAQCuP7INjWtW8SUOKZpvSXxjTBpwPnA10AOIRR1ZJfFFREQk\nIVhruXfMHLJy813b7zuxIzUq62JYJFHk5Vs6Nk6naloyu7PzSjscERERERGRuJq/djvDf1/h2ta1\nWU3OOrC5zxGJVBw3HNmGP5ZuYbLLrPs1GZnc9vGfvH5RL4zXlP0YWbh+B2/94j6ZoVX9agztd0Bc\nzy/F40s5fWNMZ2Am8AbQM3BeE+UDYjMAQERERCRmvp6zjh8XbnRt69+mHid2bexzRCLiJjcvn0+m\nrWLQf3/k1o//VAJfRERERETKPWst938+l3yPVcPuP7kTSUlKu4jES0pyEs+d253aHtUuvpu/nrd/\nXR7XGKy13DdmDrkevwgeHNyZtBStwp5I4v5uGGO6Ab8B7XBPvtuQh9f20Hb9JREREZGEsiMzhwe+\nmOvalpaSxEODO8d9JK2IRObd31Zwy0ezWL55d2mHIiIiIiIi4ouxf67lj2Xu626f3rMZPfer7XNE\nIhVP45pVePqsbp7tj309nz9XbYvb+T+ftYbfl7r/Hjixa2P6ta4Xt3NLdOJaTt8YUxUYDVSjIBFv\ngD8Cj1bACYHuFngAqA7UA3oDHQL9gwn8n4Af4hmziIiISHH999uFnmtJXTuwNS3qVfM5IhERERER\nERER2J2dy6NfzXdtq14phX8f187niEQqriPaN+SyQw/gDZeS9jl5lmtHzGDs9YfGfEnOHZk5PPyl\n+++BqmnJ3HNCx5ieT2Ij3jPxLwf2pyB5vxE43Fp7sLX2RmBsaGdr7QPW2tustUOstZ2B1sDLFCTx\nD3W62QestQ/EOXYRERGRIs1ZncG7k5a7trWsV40rD2/pb0AiUqiLD96fp8/sRou6VUs7FBERERER\nkbj73w9LWJuR6dp246A2NEiv7HNEIhXb7ce2p1uzmq5tf2/ZzZ2fzsZaj7UvovTsd4vYuMN9AtKN\ng9rQqKZ+DySieCfxr6EggZ8LHGet/TnSna21y6y11wCDgB048Q4zxtwbj2BFREREiiMv33L3Z7M9\n15R7+JTOVEpJ9jcoESlUSnISp/dqxnc3H65kvojEnDEmL+yRG2G/WD5czykiIiIVz4rNu3jtp6Wu\nba3qV+Oig1v4G5CIkJaSxIvn9SS9snux9C//XMvIyStjdr4F67bzjscEpDYNqjOk3wExO5fEVtyS\n+MaYJjgz6cFJ5L9nrZ0RzbGstROB0ygYEDDMGNMnFnGKiIiIRGvEHyuYtSrDte3UHk05RGtJiSQs
\nJfNFJE6MyyPSfrF8iIiIiPDQ2Hlk5+W7tt1/cifSUuI9z1NE3DSvU5X/nN7Vs/2BL+ayYN32Ep/H\nWst9o+eS5zED6cHBnUlN1u+BRBXPdyaYZA9ePH5QkoNZayeEHMMA95fkeCIiIiIlsWFHJk9885dr\nW43KKdx1fAefIxKRaCiZLyJxYClYFjCSfrF8iIiIiADww18b+G7+Bte2Yzo1pH+b+j5HJCKhju/S\nmAv67ufalpWbzzUfTGd3dsmKbI2euZrJy7e4tg3u3oSDW9Ut0fElvuKZxG8Q9npKUTsYYyoV0eXN\nYFfgKGNM7WgCExERESmph8bOZ0eW+wfpfx/XnvrpRX2sEZFEEp7Mb1hVE1lFJGqRzobXDHwRERGJ\ni+zcfB78Yp5rW6WUJO45oaPPEYmIm3tO6Ej7RumubUs27uK+MXOjPvb2zBwe+XKBa1v1SpqAVBbE\nM4kfmmDfZa3d6dInK+x15SKOOYmC0eVJQN/owxMRERGJzk8LN/LFrDWubT32q8W5vd1H0YpI4gsm\n8x89tAqXd0nTzHwRKRZrbVLYIznCfrF8uJ5TREREKo63fl3Gsk27XNuuPLwVzevoOkckEVROTebF\n83pSNc39I/zH01bx6fRVUR37mW8XsmlneBrWceOgNjSsUVRKVkpbPJP4oWXcMj367Ah73bjQA1qb\nA2yjYHR5q6giExEREYlSZk4e946Z49qWnGR45JQuJCVpIpxIWZecZOjXNFVl9kVEREREpExZvz2T\nF75f5NrWtFYVrjpcaRWRRNK6QXUePqWzZ/s9o+ewZKPbPGlv89Zs591Jy13b2jVM5+JDWhTreFI6\n4pnE3x7y3L0WBGwNe71/BMdNo2CAQI3iBiUiIiJSEv+buIQVm3e7tg05pAUdm+jjiUh5El5mX8l8\nERERERFJZI99NZ9d2Xmubfec0IEqHjN+RaT0nNazGaf3bObatjs7j2s+mE5mjvv/63DWWu4bM4d8\n697+4OBOpCbHMz0ssRLPd+nvkOdpxpjqLn2CizEEf5QOKuyAxpj9gGohm/ZEH56IiIhI8SzZuJNX\nJi5xbWtcszI3HtXW54hExC9K5ouIiIiISKKbunwLo2e6L//Xr3Vdju3cyOeIRCRSDw7uRKv61Vzb\nFqzbwSNfzo/oOJ9OX83UFeFzqB2n9mjKQS3rRh2j+CueSfzwn6Z9akFYa9cAWwIvDXBmEcc8P6Qv\nwPqooxMREREpBmst946eQ3Zevmv7sJM6Ub1Sis9RiYjflMwXEREREZFElJdvuW/MXNe25CTDsJM6\nYYyW/xNJVNUqpfDieT2plOKeuh3++wq+mr220GNk7Mnhsa/dk/3plVK48/j2JY5T/BO3JL61dhmw\nIWTTgR5dx1KQlO9ojLnarZMxphtwFwWz9gF+L2mcIiIiIpEYM3MNk5Zsdm07sn0DjunU0OeIRKQ0\nhSfzRUREREREStPIyX8zb+1217aLD25B24Zeqx6LSKLo0LgG953U0bP93x//ycotzjKfuXn5jFue\nw7jlOeQGJh39d/xfbNqZ7brvTUe1pUF65dgHLXET70UPJoY8P8Gjz3uBrxYnmf+CMeYNY8xAY0xr\nY8yBxphhwE84pfRNoO9Ma+3SOMUtIiIi8o+M3Tk8/OU817bKqUncf7JGs4tUVMFkvoiIiIiISGnZ\nuiubp8b/5dpWr3oaNx7VxueIRCRa5/XZjxO6NnZt25GVy7UjZ5Cdm8+YmWsYuSCbkQuyGTNzDXNW\nZzD89xWu+7VvlM5FB+8fz7AlDuJd83UMcFbg+QBjTF1r7V5T2Ky1E4wx44BjKEjkDwk8QpmQdoBh\ncYtaREREJMQT4xZ4jmK94ci2NK+jctoiIiIiIiIiUjr+++1Ctu3OcW27/Zj21Kic6nNEIhItYwyP\nndaF2asy+Dsw6z7UrJXb+M838/l+fkEx9OcnLKJO1VTy7T7d
AXhwcGdSkuM9r1tiLd7v2BhgD07i\nvRJwlUe/y4EVFCTqCTwPfYT+6D1rrR0bj4BFREREQk3/eysjJv/t2ta2YXUu63+AzxGJiIiIiIiI\niDjmrdnOB3+4z77t1rwWZ6hymEiZU6NyKi+e14PUZPfKn2/+spzlmwsS/Cs272bGygzXvqf1bEqf\nA+rEJU6Jr7gm8a21u4HmQOPA40WPfquAw3HK7weT9uAk7kOT+pnAHdbaW+IXtYiIiIgjNy+fuz+b\ng/UYxfrwKV1I1ShWERERERERESkF1lru/3yu5+zbB07uRFKSlv8TKYu6NqvFHcd1KNEx0iuncGcJ\njyGlJ97l9LHWbomw39/AEcaYI4BTga5AAyAHWIWT4H/PWrsuTqGKiIiI7OWdScuZv3a7a9tZBzbT\nKFYRERERERERKTWfz1rD5OXuKZizDmxG9+a1/A1IRGJqaL8W/LZkE9+FlM4vjluPbkf99Eoxjkr8\nEvckfnFZaycAE0o7DhEREanY1mzbw3+/XejaVrtqqkaxioiIiIiIiEip2ZWVy6NfzXdtS6+Uwm3H\ntPc5IhGJNWMMT57RjeOf/5m1GZnF2rdj4xqcf9B+cYpM/KD6ryIiIiIuHvxiHruz81zb7jq+A7Wr\npfkckYiIiIiIiIiI48UfFrN+e5Zr241HtdXsW5Fyona1NJ4/twfJxVwa46FTOpGiZUDLNL17IiIi\nImG+n7+eb+a6r+DT54A6nNGrmc8RiYiIiIiIiIg4lm3axRs/L3Vta9OgOhcdvL/PEYlIPPVuUYcb\nj2wTcf/Tezal1/5aBrSsUxJfREREJMSe7DzuGzPXtS0lyfDIKZ0xpngjX0VEREREREREYuWhsfPI\nybOubfef3IlUzb4VKXca16wccd8uzWrGMRLxi36Ti4iIiIR4fsIiVm/b49p2xWEtadMw3eeIRERE\nREREREQcExasZ8KCDa5tx3VuRL/W9XyOSETiLTcvnxd/WBxx/3d+XU5uXn4cIxI/KIkvIiIiErBw\n/Q5e/8m9HF2z2lW47ojIy1aJiIiIiIiIiMRSVm4eD34xz7WtcmoSd5/QweeIRMQPY2auYfnm3RH3\nX755N2NmroljROKHFD9PZoxJAQYAPYH2QC0gHUiO4nDWWntkzIITERGRCi0/33L3Z7PJzXcvR/fQ\n4M5USYvmI4uIiIiIiIiISMm98fMyz0TeVYe3plntqj5HJCLxlpuXzwsTFhV7vxcmLGJw9yakaHmN\nMsuXJL4xphpwD3ApUDcWhwTc77CLiIiIROHj6auYsnyra9txnRsxsH0DnyMSEREREREREXGszdjD\nixPcy2k3q12FKw9v6XNEIuKH4s7CDwrOxj+9V7M4RCV+iPvwC2NMZ2A+cDtQDycBb+J9XhEREZFI\nbdmVzWNfzXdtq5aWzH0ndfQ5IhERERERERGRAo99tYA9OXmubfec0JHKqaoeKFLeRDsLP+iFCYvI\nzcuPYUTip7gm8Y0x+wE/AM0omD0fnEFvSvAQERERiZnHv57P1t05rm03H92OxjWr+ByRiIiIiIiI\niIjjj6Wb+XyW+/rW/dvU45hODX2OSET8EO0s/KDgbHwpm+JdTv9ZnPL5oYn7LcDnwDTgb2AX4D58\nTERERCTOJi/bwqipq1zbOjauwcUH7+9zRCIiIiIiIiIijty8fIZ9Pte1LSXJMOykjhijuY8i5U1J\nZ+EHvTBhEYO7NyElOe7F2SXG4pbEN8Y0AE7GSeAHZ+G/CNxprd0Vr/OKiIiIRCo7N597Rs92bTMG\nHjm1sz7gioiISMSMMRNKOwbAWmuPLO0gREREJDZGTv6bBet2uLZdckgLWjdI9zkiEfFDSWfhBwVn\n45/eq1kMohI/xXMm/kCccv3BEvrvW2uvj+P5RERERIrlzV+WsXD9Tte28w/ajx771fY5IhERESnj\nBlBQjbA0BCdRiIiISDmw
ZVc2T41f6NpWr3olbhjUxueIRMQPsZqFH6TZ+GVTPN+tJoGvwTouj8fx\nXCIiIiLFsnLLbp773utCOI3bjmnvc0QiIiIiIiIiIgWeGv8XGXtyXNvuOK496ZVTfY5IRPwQq1n4\nQcHZ+FK2xDOJnxbyPNNaOz+O5xIRERGJmLWWYZ/PJTMn37X93hM7UrOKLoRFREQkKqYUHyIiIlJO\nzFmdwcjJf7u29divFqf1aOpzRCLih1jPwg96YcIicvPc74WWSTNHQl5uaUcRV/Esp78x5HlWHM8j\nIiIiUizj5q5nwoINrm39Wtfl5G5NXNtEREREijCwtAMQERGRss9ay/2fz8W6LJJjDDxwcieSkjR+\nT6Q8ivUs/KDgbPzTezWL+bFLxeh/wU9PwGG3Q5czITmeKe/SEc/vaEbI85rGmMrW2sw4nk9ERESk\nSDuzcnngi7mubWnJSTw0uDPG6EJYREREis9a+2NpxyAiIiJl3+iZq5m6Yqtr29kHNqdrs1r+BiQi\nvjm9V7OIE+0TJ04EYMCAAfELKJFtWVquk/lxK6dvrZ0BrAjZpNHoIiIiUuqe/XYhazPcxxVeNaAV\nLetX9zkiERERERERERHHzqxcHvtqgWtbjcop3HZMO58jEhFJcMFk/ku9y1WZ/bgl8QOeCnl+a5zP\nJSIiIlKouWsyeHvScte2FnWrctWAVv4GJCIiIiIiIiIS4oUJi9iww32F4puPakvd6pV8jkhEpIwo\nZ8n8eCfxXwZ+AgwwwBhzV5zPJyIiIuIqL99y12dzyMt3WVAOeOiUzlROTfY5KhERERERERERx5KN\nO3nrl2Wube0apnNB3/19jkhEpAwqJ8n8uCbxrbX5wCnAdJxE/kPGmDeMMXXieV4RERGRcCMn/82s\nldtc207u1oT+ber7G5CIiIiIiIiISIC1lge/mEdOnvvkg/tP7kRKcrznZYqIlCNlPJmfEu8TWGu3\nGWMOBV4ChgQe5xpjvgB+AVYA2wH3v0yFH/unWMYqIiIi5dPGHVn85xv39eTSK6dwz4kdfI5IRERE\nRERERKTAd/M38OPCja5tJ3RtzMGt6vockYhIORFM5v/0BBx2O3Q5E5LjniIvMV8itNZmGmPuAOoD\nJwJVgDMDj6gPi0/xi4iISNn2yJfz2JHpPtLy9mPa0SC9ss8RiYiIiIiIiIg4MnPyeGjsPNe2yqlJ\n3H28Jh+IiJRYGUvm+xKZMeYG4EGgOgUz7o0f5xYREZGK7dfFmxg9c41rW7dmNTnvIK0nJyIiIonB\nGNMZ6AM0B2oB6RR/KURrrb00xqGJiIhIHL3x81L+3rLbte2aAa1pUquKzxGJiJRjZSSZH/eIjDEv\nAf+iIGlvw74qmS8iIiJxkZmTxz2j57i2JRl45NQuJCfpo4iIiIiUHmNMPeA64EqcCoYlOhzO/RYl\n8UVERMqINdv28NIPS1zb9qtTlcsPa+lzRCIiFUSCJ/PjGokx5mLgqsDL8KT9JmAlsAvIi2ccIiIi\nUjG9+uNSlm3a5dp28SEt6Ny0ps8RiYiIiBQwxhwNvIeTvC9sZKENee7Vz3psFxERkQT26Ffz2ZPj\nniK598SOVE5N9jkiEZEKJpjM37MVDr66tKP5R9yS+MaYJOChwEuLc5G5FXgCGGGtXRmvc4uIiIgs\n27SLlyYudm1rVKMytxzdzueIRERERAoYY44EvgSCd+aD904Ief1P97DdwxP2xqWPiIiIJLjflmxm\n7J9rXdsOa1ufQR0a+ByRiIgkinjOxD8YaEbBRehq4DBr7bI4nlNEREQEay33jZlDdm6+a/uwkzpS\nvVLilEYSERGRisUYUwf4P5wEfvC+STYwGvgD6A5cFOhugSFAdaAe0BvoD6RTkMyfCLzrR+wiIiIS\nG7l5+TzwxVzXttRkw7CTOmKMxuiJiMRdnZYF5fQTSDzvXncLfA2ux3aHEvgiIiLihy/+XM
vPiza5\ntg1sV59jOzfyOSIRERGRvVwH1KEgCb8QONFauxjAGHMlBUl8rLV7JeiNMdVxli8cBlQBDgeWA5dZ\na91HMYqIiEhCef/3FSxYt8O1bWi/A2hVv7rPEYmIVDChyfvkxJvwFc+IaoU8t8BncTyXiIiICAAZ\ne3J4aOw817ZKKUk8OLizRrKLiIhIabucghn4u4HjijPxwVq7E3jSGPMF8B3QBLgYZzb/v2IfroiI\niMTS5p1Z/Pfbha5t9dMrcd2RbXyOSESkAknw5H1QPCPbE/I8w1q7O47nEhEREQHg6fF/sXFHlmvb\n9Ue2oXmdqj5HJCIiIlLAGHMATtLdBh6vRVu50Fq7wBhzPDAFSAUuN8Z8Zq0dF7OARUREJOaeGv8X\n2zNzXdvuPK69lgAUEYmHMpK8D0qK47GXhjzX3XIRERGJu1krtzH89xWuba0bVOfy/i19jkhERERk\nH70CX4OlgUaV5GDW2j+BV0M23VOS44mIiEh8/blqGx9OWena1mv/2pzao6nPEYmIlHN1WsIpr8A1\nU6D7uWUigQ/xTeL/CgTXYUszxrSK47lERESkgsvNy+euz2ZjrXv7I6d0Ji0lnh99RERERCJSL+z1\nTJc+e32iMcZULuKYw4NdgUOMMY2iC01ERGRvuXn5vPnLMt78ZRm5eflF7yCFys+3DPt8ruu9C2Pg\ngZM7aQlAEZFYOmJYmUveB8XtTra1dhPwVcimc+J1LhEREZHhv69g7prtrm1n9GrGQS3r+hyRiIiI\niKvaIc93WmszXfqEb6tSxDGn4UykCKYEDooyNhERkb2MmbmGh8bO46Gx8xgzc01ph1PmfTZjNTP+\n3ubadm6f/ejctKa/AYmIlGe9L4fDbi5zyfugeEd9J3A0zrpstxlj3rfWute4LSOMMclAb6Azzuh5\nA2wEZgNTrbV5pRgeAMaY/XFibAjUAnYDy4DJ1tqYf9Ly63yBmQcdAo96QDqwE9gCLABmWGvdFxIS\nEZFybV1GJk+PX+jaVqtqKnce197niEREREQ85YQ8z/bosyPsdRNgq9cBrbX5xphtQB2cRH6LEsQn\nIiICOLPwX5iw6J/XL0xYxODuTUhJVpW7aOzIzOGxrxe4ttWsksqtR7fzOSIRkXKsSm0YeFdpR1Ei\ncU3iW2vnGmMuA94BagDfGWNOsta6/6VKYMaYGsC/gSvYt/Rd0CZjzGvAf6y17lMB48QYk4YT27WA\n5197Y8wk4Alr7ZiycD5jzIHAScCRODMJCvuZ3WOM+Rx4zlr7WxTnGgD8EEWYoX601g4o4TFERKSY\nHho7j51Z7uO47jyuPXWrV/I5IhERERFPofcLanj02RL2uiUwt4jjVqVgJn61KOISERHZy5iZa1i+\nefc/r5dv3s2YmWs4vVezUoyq7Hr++0Vs2pnl2nbL0W2pUy3N54hERMqxgXdD1TqlHUWJxH3InLX2\nfeAMIANoBUwzxjxpjOkQ73PHijGmLzAfuAvvBD6BtruAeYF9fGGM6QLMA16gkIR6wCHAaGPMx8aY\n9EQ9nzHmNGPMEmAKcB/Qj6IHnVQBzgYmGWOGG2NUe0hEpAL44a8NfDl7rWvbgfvX5sxezX2OSERE\nRKRQy0Kepxhjarn0mRf4GkzKH1zYAY0xbYHKIZt2RR2diIgI+87CD3phwiJy8/JLIaKybfGGHbz9\n63LXtvaN0jmvz37+BiQiUi4Y9831O0CvIf6GEgdxnYlvjJkQ8nI9UBMn0XozcLMxZivwN06C3+57\nhEJZa+2RMQm0EMaYfsB4nBHt4TJxfkLCp/c1Bb43xhxlrZ0U5/gGAJ/jlJYPl4czwr8m+w7YOB1o\nZowZZK3dmYDn64Mz08CLBbbhzC5wG6J4AXCgMaa/tXZTBOcTEZEyKDMnj/vGzHFtS0kyPHxqZ5KS\nPD7MiYiIiJSO+WGvuwI/hW6w1m40xqwFGuHcdzjXGH
O3tdbr3smlga8G53pZixaLiEiJhM/CD9Js\n/OKz1vLAF/PIzXf/M/7AyZ20RIGISHE17QWrp7m3HfsYJMd7Rfn4i/d3MIC9k/PB58G76XUoWK+t\nOIIXpXFljGkAfMLeCfwc4DngFWBpYFsr4F/A9UBqYFtV4BNjTDdr7YY4xdcc+Ih9E+ojgReBP6y1\necaYFJyZ7DcBg0P6HQS8iTN7PeHOFyYD+BT4HvgFWG2tzQ3E1Qo4DbgFaBiyT3tgrDGmn7U2L4pz\nLgL+W8x9dKNERMRHL05YzMote1zbLu1/AO0beVWoFRERESkd1tpVxpiVQLBcUB/CkvgBY3DuNQDs\nBzwK3BneyRhzFHAjzn2S4P2WX2MYsoiIVDBes/CDXpiwiMHdmyjxHKHx89bz8yL3eWYndWvCQS3r\n+hyRiEgZt2KSdwK/3fHQaqC/8cRJaQ1DiHsCPkYeYu+k8B7gFGvt+LB+i4FbjTHf4SSaqwS2NwIe\npOCiO9ZeZ+/y/nnAEGvt8NBOgWT3j8CPxpjrcQYhBJ1ljHnT5XtKhPOBU0LwaeBDa+2+Qz+d8y0B\nnjTGvIPz739oSPNBwBXAyxGeL9Qaa+0rUewnIhJTuXn5GuXuYvGGHbz60xLXtqa1qnDDkW18jkhE\nREQkYhOAiwPPTwSecunzFs79hGBy/nZjTG/gA2AlUAs4ATifgvs7FvjRWqsB5iIiEjWvWfhBmo0f\nucycPB4aO8+1rUpqMncd397niEREyrj8PPj63+5tSalw9MP+xhNHfgyVM3F4xD9oY1oC4Qsm3F1Y\n8tla+w1wb9jmocaYA+IQ34HAMWGbHwtPqIez1j6PU0Vgr/2MMYX+u/p9PpxlFoYAXa21b3kl8MPO\ntRE4CVge1uTxv1lEJLHl5uXzybRVDPrvj9zy0azSDiehWGu5+7M55OR5l6Krmlb2SyaJiIhIufVZ\nyPNDApXv9mKtnQqMoKAaoQEGAm8A44D/Ay7CSeAHPxTls+99CRERkYgVNQs/6IUJi8jNy/chorLt\ntZ+WsmqrewXBa49oTeOaVVzbRETEw8wPYN2f7m19r4K6rfyNJ47ifXe7LNcruJmC0vjgzLZ/PoL9\nnsUZKd868DoVp6z89bEMDmd2eahNQKTDS+7EudAPLhPQEydB/02inM9a+78Ijx2+3zZjzAPA2yGb\n9zfGdLHWzo7mmCIifgvOvH9hwqJCR75XZJ9OX80fy7a4th3dsSGDOjZ0bRMRERFJEF8DG4CagddX\n4J58vwboAPRg3yUKCWwLLaN/m7W2VErpG2OSgd5AZ5wqfgbYCMwGpka5zF1MGWP2x4mxIU4lg93A\nMmByPKoXlML5Ev49EJHEV9Qs/KDlm3dzxfBpnNGrGV2a1qRZ7SoUPW+rYlm1dTf/m7jYtW3/ulW5\nrH/M5/6JiJRvmRnw/YPubdXqw2G3+RtPnMU1iW+t/TGex4+zwWGv34zkYiewJvxbOGvVBZ1C7JP4\nR4S9HmmtzYpkx0Ci+zOckntBp1F4Et/v85XEaJyyg6GfGrvhXLSKiCQsJe8js213No98Nd+1rWpa\nMsNO7uRzRCIiIiLFY63NwVmCr6h+GcaYI4FncAbHh2dHgq/XAjdZa0fFNNAIGGNq4FTAu4K9l+AL\ntckY8xrwH2vtdt+CA4wxaTixXQu0K6TfJOAJa+2YsnS+wLES+j0QkbIj0ln4QRMWbGDCgg0A1Kqa\nSpemNenctCZdAo+Knth/9Kv5ZOa4Vyu478SOVEpJ9jkiEZEy7qcnYddG97Yj74PKNfyNJ85UZ9aF\nMaYXEL6gT3EuhEexdxK/uTGmp7V2eomDA4wxdYDwehDFHWn/K3sn1QcbY/5lrd3nU4Xf5yupwKCB\nzex94VrkzRERkdKi5H3x/OebBWzZle3adtOgtjStpVJ0IiIiUn5Ya7cBQ4wxj+BMOOgKNABygFXA\nROBza22m37EZY/
oCnwBNiuhaD7gLuNgYc4a19ve4BwcYY7rgLF0QSU3NQ4DRxphPgCHW2h2Jfr7A\nORP6PRCRsiXSWfhutu3O4edFm/h50aZ/toUm9rsGvlaUxP6vizfx1ex1rm0D29XnyA6qICgiUiyb\nl8Dv4at3BzTuBt3Pd28rw5TEdxe+DMA6a+3SSHe21i4xxqzHKZcWdAQQkyQ+zsV6OPe6PN7Ch1Q2\nALoAbosu+32+WEgNe+2+aLKISClS8r74pq3YwsjJK13b2jdK55J+LfwNSESkIsrNhl+fc573uwFS\n0ko3HpEKwlq7GHi6tOMIMsb0A8ZTsHReqEycKgGVwrY3Bb43xhxlrZ0U5/gGAJ8D6S7NecB2nOUM\nksLaTgeaGWMGWWt3Jur5AudM6PdARMqW3Lx8nvl2YUyPWVET+zl5+TzwxVzXttRkw30nqYKgiEix\njbsb8nPc2459HJLKX3UTJfHddQx7PTmKY/wBnBzyukP04eyjjsu2jGIew61/R9yT6n6fr0SMMfUo\nWFcwyH3Yo4hIKVDyPjo5efnc/dkcz/ZHTu1CanL4PVEREYm5KW/ADw87zytVh75XlW48IuI7Y0wD\nnNnfocnjHOA54BUgOBGiFfAvnCUGg4PtqwKfGGO6WWs3xCm+5sBH7JtQHwm8CPwRWA4xBegH3MTe\nyyoeBLwJnJ2I5wucM6HfAxEpe177aSmrtu2J+3kqQmJ/+G8rWLjefVzWpYe25IB61XyOSESkjFv8\nPSz82r2t02mw/yH+xuMTJfHdtQ97HfEs/BDLijhmSbitRR8+sroolV22eQ008Pt8JXWqy7Zp0RzI\nGJMKdAeaA7VwRs5vARZaa1dFGZ+IVFBK3pfM278uY8E67yqjvfav7WM0IiIV1K7N8OPjBa8nPgZd\nzoJqdUsvJhEpDQ+xd/XBPcAp1trxYf0WA7caY74DPgWC6x41Ah7ESS7Hw+vsvcReHk7J+uGhnay1\nucCPwI/GmOtxEuBBZxlj3nT5nhLhfJD474GIlCHTVmzlqfF/ldr53RL7taum0rlpTboEHmUlsb9p\nZxbPfOde0aBhjUpcd0RrnyMSESnj8nJh3F3ubSmV4agH/Y3HR0riu2sb9tq9bm/hwvcJP2ZJbHXZ\nVr+Yx3Dr3y5BzldSV4S9XmKtXRDFcXrjVBBwXVzZGLMUGAs8a60NH7QhIvIPJe9LbvW2PTzzbfjK\nLCIi4ruJj0JmSJGtzAwnkX/CU6UXk4j4yhjTEhgStvnuwpLP1tpvjDH3AqG/LIYaY/4T6+tpY8yB\nwDFhmx8LT6i7xPi8MaYDeye1HzPGfGut9Vyiz+/zBc6Z0O+BiJQtvyzaxNB3JpOfYIuRbi0jif3c\nvHzGLXfKOx+al88T3yxgR2aua9+7ju9AtUpKyYiIFMvUt2CjR4qv3w1Qq7m/8fhIfzHchU/li6YU\n+9qw17WiC8XVGpwSaaHrvvcAJhbjGD1ctrmVzS+N80XNGHMRcGDY5lejPJzbmnKhWuKUo7vGGPMy\ncKu11q1qgYhUUErex86wMXPZk5NX2mGIlB9a01yisX6ec/Ecbupb0PtSaBCvQlsikmBuZu/7A4uB\n5yPY71mchHVwCmIqTln562MZHPsO7N8EPBzhvncCF1FwP6AnToL+mwQ6HyT+eyAiZcQ3c9Zy3YgZ\n5MQgg9+4ZmWuGdiaeWu3M2d1BgvW7iA7Lz8GURZIxMT+mJlrGLkgG4AGExYzaqp78dbeLWpzcrcm\nvsQkIlJu7N4CPzzi3pbexLmnVY5FlcQ3xoTfubHW2ksj6BdLrucsKWNMFSB8Qd1oMi/hCwilGGMq\nW2szo4usgLU20xgzDegbsvkk4JliHOYkl23VE+F80TLGNGPfmFYCL8XyPC6SgWuBQ4wxx2k9ORGJ\nV/L+vNd/p061NOpWS6NOtUrUqZYa+JpG3epp1KmWRu2qaSQnJXZpueL6evZavpu/
vrTDEClftKa5\nFJe1Tvk663Ij1uY5bRd8Cgle3lREYmJw2Os3rbVFjrYMrAn/FvBoyOZTiH0C+Yiw1yMjHXBvrd1m\njPkMOD9k82kUnlT3+3yQ+O+BiJQBo6as5I5P/4zZDPy1GZlUSU3m0VO7AJCdm8/C9TuYvTrDeazK\n4K915Suxn5uXzwsTCqoGvvzjEtd+SQbuP7lTwi8FICKScCY+Bpnb3NuOehDSqvkajt+inYl/CRD8\n824Cz90S6qH9Yqmwc5aUW2I5msS72z7VojyWm3HsnVQfYIzpaa2dXtSOxpijgG4uTYUl1f0+X7EE\n1q7/P/ad3X+Vtba4GbRNwFfA98Bs4G9gO87I+EbAIcA5wNFh+/UERhtjjijuYI3AIAk37Xfs2MHE\niROLc7iY2rHDWX+6NGOQsqui/vyMW57zzyjsWJq0ZHORfQxQNRXS0ww10gzpaYb0VOdr9cDrGmlO\ne3qaoXqqIS05MS8it23fwe/rYdS4Iv/UABXv50yKVlF/BxUlNXs7fSY//M/0vZzvHmLy9qbkpNUo\n1bgSjX5+9lZ30xS6LP0Bi/O3JpQFzJIJ/Pnp02ypG14UK36C75FIIjHG3Be+zVq7zyKNbv1iye2c\nsWCM6QU0C9s8qhiHGMXeCeTmkd5biIQxpg7QKmzzr8U8zK/snVQfbIz5l7X7jmLy+3yBcyb0eyAi\nZcNrPy3h0a+iWX20cC9MWMTg7k1ISU4iLSWJzoEE+rmB9vKW2B8zc81ekzeyc92/j/MO2o9OTWpG\nfR4RkQppw3yY8qZ7W7M+0OUMf+MpBSqnv6/KLtuiycS4jbp2XVs9Sq/glF0L1j41wNvGmP7W2u1e\nOxlj6uNdXr6w+Pw+X3H9Dye5Hupla+2XxTjGGuBc4BNrbY5Le0bg8RfO934YMAJoGtLnYOBB4PZi\nnFdEJGYssCsHduVY1u2KbBxd5eSCpH5ocj802R/6qJxMXEeP5+Vbfl+by+hFho2ZiTnAQKQsa7F8\nBKm5u/55nZq7ixbLR7Ko7ZWlGJUkMpOfQ6slTpE1t9/KwW2tF7/JlNrdsUm6zJQK7X72nczgllB3\n6xdLcUniAwPDXq+z1i6NdGdr7RJjzHqgYcjmI4BYJZAbuGxbXMxjLAp73QDoAsxKgPNB4r8HIpLA\nrLU8Oe4v/jfRfcZ4SS3fvJsxM9dweq/wsUaO8pTYD5+F76VW1VRuOapdiWIXEalwrIVv7nQq/7k5\n7vEKUQmwJHdXIv3XKWv/im4zqKNZJLRShMeOirV2nTHmRZx10IK6At8bYy601u4zlNIY0wN4HzjA\n47A7E+V8xWGMuRu4LGzzrzjrukXMWrsQWFiM/j8ZYw4B/sCZoR90rTHmWWvtmmIcq5fbdmPMtPT0\n9J4DBgyI9FAxF5x9VpoxSNlVUX9+Ds3L58A4lNOPl8w8yNxj2bgnsvvIaclJ1KmW9k8Z/9pV0wrK\n/Ffft9x/rSqpJEVQ4n/vZQiyKe5HiIr2cyZFq6i/gwq1fh78OG6fzU3XjqPpKfdpTfMQ+vkJ8dtL\nsKfoj7ZV96zh8CoL4eCrfQgK0tPTfTmPSAkEqwgW1SdWgsUy4jk4oGPY68lRHOMP4OSQ17H84xNe\nnQ+cwfjF4da/I+5Jdb/PF2wLlWjvgYgkqLx8y71j5jDij7/jep7Q2fiRKCqx/+eqDOasTrzEfvgs\nfC+3HN2O2tWiSS+IiFRgf30NS39wb+t2HjR1TamVO9Em8YfEuF8icUssu83OL4rbPjFJWoe4Ezgc\nCP1pPRCYbYz5HvgN2ALUBfoDA4Dgp6c8YCJwZMi+2xLsfEUyxlwGPBy2eS5wUqRr0JWEtfZvY8xQ\nnPL7QVWAi4HH4n1+EUlMKclJnN6rGYO7N2HMzDU8891CVm3dU9phxUx2Xj7rtmeybntkY9OSDP8k\n+msHk/0hX2tWTWPB2u18MWsNazJiNt5NRMJp
TXOJxq7NMPHxyPv/+Dh0PRuq1Y1fTCKJr7QmPfjx\nC7x92OuIZ4CHWFbEMUvC7T6A2ySLwrjdz/FKcvt9Pkj890BEElB2bj43jZrJl3+ujfu5ipqNH4m9\nEvt9nG1uif0F67aTkxfbsWuRJPY7NEmPaBZ+x8Y1OK/PfjGNT0Sk3MvNgvF3u7elVYdBw/yNpxRF\nlcS31r4by36JxFq7xxiTT0HyGZy10IsrvFR8XnHXSS+KtTbbGHM88DlwUEhTCnBM4OG6K3At0I5i\nJNX9Pl9RjDGn45T5D7UcONpau7Ukxy4Oa+3XgTXtQwc3HI2S+CIVXkpyEv1a1+PZ7yIu8lEu5VvY\nvCubzbuiWZ1GRGJm4TjvUcwASybAovHQ1usjnVRIEx+FLM/Vs/aVmQETH4MTnopfTCKJ7YEY90s0\nbcNer4ziGOH7hB+zJNzuBdQv5jHc+nvVQfb7fJD474GIJJjd2blcOXzaXknpcA3SKzH80oNo1yhx\nKx0Vltj/c5VTit/PxH4kHhjcieQIKhOKiEiIP16BLR7jVPvfDOmN3NvKIS1W6G4be5dEi+YnonHY\n67gkla21G4wxA4FhwHUUPeBgFXCJtfZ7Y8yHLm0JdT4vxphBOOvRJ4dsXg8cVZwy9jH0BXsn8XuX\nQgwikmA2bM/kvNd/Z2WMZuG/PaQ3W3Zms2VXNlt2Z7Nlp5MY37Iriy2BJPmOzNyYnEtEypncbO9R\nzKHG3QWtjoDk1PjHJIlv/TyY+lbx95v6FvS+VMszSIVkrY0oOR9pvwRUO+z1uiiOET4NtFZ0obha\nA+QAoX/IeuBUBoxUD5dtbmXzS+N8kPjvgYgkkG27sxn6zhSm/73Ns8/+davy/qUH0bxONPPYSldo\nYj/Ir8R+UaqlJdOjeS1fzykiUubt3AA/PuneVmt/6HuNv/GUMiXx3S0E+oa8bh7FMcL3Kbq+TpSs\ntXuAO4wxzwJn4swC74gzmjsFWA3MBz4EPgv0h33Ls01LxPOFM8YcDIwGQhcT2oqTwF8czTFjYEHY\n62rGmCoh37uIVDAbd2Rx3ht/sHTTLs8+taqmsm13TsTHHNiuQZF9snPz2bY7mNwPfN2ZxZbdOQXJ\n/sBAgK27na/5/l7HikhpmPI6bI7gY9LmxTD5dd/WNJcEZi2Mu9N9+YUi982Db+6ECz/T8gwi5Ygx\npgp7Vy0EKHox4H2FXyenGGMqx6J6obU2M1ApL/SezknAM8U4zEku26onwvkS4T0IfL9u2u/YsYOJ\nEydGEU7J7dixA6DUzi9lX3n8GdqWmc9TUzNZtdP7or95ehI3dbUs+XMyS3yMzQ9NgCZ14Jg6kNup\nCqt25LN8ez7LM5yvK3fkE8+8/q7sPP7z4ff0a6oB0lK48vj7R/xT3n5+2i14gcbZO1zb5jQ9l02/\n/u5zRCUXfI+ioSS+uwXsfQHWMopjHBD2en704UTGWrsOeCHwKJQxphrQKWzz1EQ+X+A4XXHWn68W\nsnkncJy1dnZxjxdDW1y21WbfC2MRqQC27Mrmgjf+YPGGnZ59bhzUhmsHtmbMzDW8MGERyzdHc+9t\nX2kpSTSoUZkGNdyW1txXfr4lY0/OP0l/J9HvJPwLthU8Nu/KJjs3ioSOiJQerWku0Vg4DpZOjH7/\npT9oeQaR8sctsRxN4t1tn2pRHsvNOPa+pzPAGNPTWju9qB2NMUcB3VyaXJPqpXC+svIeiEgp27A7\nnyenZLJxj3eWuk2tJG7sVZlqqeV/0GVKkqFFzWRa1Ez+Z+pdTr5ldZwT+58vyaFv4xSV1BcRiUD1\nHUtotO5717attbqwqV5f17byTEl8d/PCXveJ4hgHhb2OexK/mPqxdyn6bez7fSfU+YwxbYDx7F3m\nLQsYbK39o4TxlVRNl20ZvkchIqVu224ngf/Xeu8RdtcMbMUNR7bBGMPpvZoxuHuTmCfzI5WUZKhd\nLY3a1dKK
7gxYa9mVnRco5Z/F1t0FM/v/mfkf9tiZpRL/IqVKa5pLcUW6/EJRxt0FLQdCSmR/Y0Qq\nCmNMI/a+z7CslAelR8ptlGh2FMfJctlWJYrjeHkFuJOC6n0GeNsY099a6/kH0RhTH3jVo7mw+Pw8\nX6m/B9baXm7bjTHT0tPTew4YMCCKcEouOPustM4vZV95+hlasG47t705udAE/oB29Xn5/F5USUv2\n7FMRZeXmsXDdTmavjk0p/vW7LdtqtuH0Xs1iHKmUJ+Xp94/4r9z8/FgLbz8OuPy+NUnUPucVBjTq\n7HtYsZCenh71vkriu/sh7HUjY0xLa+3SSHY2xrQEGhZxzNJ2UdjrD6y1eYl6PmNMc+A79v53zQXO\nstZOiEF8JdUm7PUea613DW0RKZcy9uRw4ZuTmbfWO1l25WEtufXodpiQ8sIpyUmlnsyPlDGG6pVS\nqF4phf3qRrZeXmZO3j+l+/9J9u/MZtOOLKav3MqfqzLYnR3PP0EiFZjWNBcRSUSnAi+GvD4HKAtJ\nfLdZ2tGM0qkU4bGjYq1dZ4x5Ebg5ZHNX4HtjzIXW2vDl8DDG9ADeZ9+qikGeJbZ8Pl+ZeA9EpPRM\nW7GFIW9PYXum92D6k7o14ekzu5GWEr46h1RKSaZLs5p0aVYwXyuY2J+5chv/+WZBsScqvDBhEYO7\nNyElWf/eInuZORK6nAnJSlMKMPcz+Ps397Zel0AZTeCXlP53uJuGs65705BtZwGR1iE9O+z1Kmtt\nVOu/x4MxpgHOTYNQbybq+QL7fwfsF7I5H7jYWvt5ySOMiePCXv9ZKlGISKnZkZnDxW9NZvZq7yIc\nQ/q14I7j2u+VwA9VlpL5xVE5NZnGNavQuKb75J7cvPxy9f2KJAytaS7RSkmDox+BkeGXNcV0zKOa\nhS/irjbObG1wppp8XYqxFIdbYjmyNZyK3sd7Haro3AkcDoTOGj8QmG2M+R74DWdZvLpAf2AABWvN\n5wETgSND9t2WIOcrS++BiPjsx4Ub+dfwaezJ8R4kf0Hf/Xjg5M4q714MwcT+wvU7oqo0uHzzbsbM\nXKPZ+CLhRv8LfnoCDrtdyfyKLmcPfHufe1ulmjAwBpUCyygN/3JhrbXAmLDNlxpjiqwvFOgzNGxz\n+LFK25NA6PTJn621MxLxfMaYmjjry7UNa7raWjsiRvGViDGmH85FeKhxpRGLiJSOXVm5DHl7CjNX\nbvPsc2Hf/bnvxI6eCfxQwWT+dzcfztNndqNFhDPey6qK9v2K+CZWa5pLxdT2GKcUfrRaHQFtjo5d\nPCLlS2j58x3W2jKRPLXW7sEZUB8qmg9u4SM786y1MZ0Fbq3NBo4HwpfeSwGOAe4HngeGAUdQcH/M\nAteyb2WEbYlwvrL0HoiIv8b+uYbL3p1SaAL/uiNa89BgJfCjkZuXzwsTFkW9/wsTFpGbF8XgapHy\nbstSJ5n/Um9nZn6eluSskCa9ABkr3dsG3AHV6vkbTwJREt/bMzjl2oNaA9dHsN+Ngb5BuYFjeTLG\nDDDG2LDHgOKFGxljzMXsXdo+B7g6Hucq6fmMMVWBsUD3sKbbrbVe68ZFxUSSVXPfrx7wbtjmPCAh\nBhiISPztzs5l6DtTmLpiq2efc/s054GTO0WUwA8Vntwu78K/34ZVdWNBJGqxXNM8N5qldqXMMwaO\neYSCycLF2TfZmYWvKg4iXtaGPC9rU462hb1uFMUxGoe99v4gXQLW2g3AQOA/QCTlnlYBR1lrX2Hf\nGFcl0Pm2hb1O2PdARPwx4o+/uW7kjELXbb/nhA7cEra0n0RuzMw1JaocGJyNLyIelMyvuDJWwy8e\nKdS6baDP5f7Gk2CUxPdgrV0MvBO2+WFjzFFe+xhjjgEeCtv8trV2SYzDC56vUaTJfmNMsjHmZiB8\nUdQnrbVzEu18xphU4GPg0LCmh621T0YSQzE9Z4x5zBjTMNIdjDEdgR+BVm
FNb1lr/4ppdCKSkDJz\n8rj8van8sWyLZ58zejXjkVO6kFSCke7B5HZFEfx+Hz20Cpd3SdPMfBGR0rJwHM4k0WI6cCg06BDz\ncETKkVkhz6sEBoeXFQvDXjeP4hjh+0Q/tbEI1to91to7cK7br8eZKLAU2AHsARYDXwDnA22ttd8H\ndg3/JRbREok+na9MvQciEj/WWl76YTF3fTYb6/GRLcnAk2d05bL+Lf0Nrhwp6Sz8IM3GF4mAkvkV\nz3f3Q47HIKljH4PkVF/DSTRlbcS33+4FTgYaBF5XBcYaY54DXgaWB7a3BP4F3ACE/kRtADwWcoiJ\nRsAPxpgFwKc468bPtNZuhX9K+zcHTgAuB8KncX6HU9ItEc/3HPuuM78EWG2M+VcxYg76y1r7QyHt\nNYDrgFuNMT/gLIEwHZhjrd0R7GSMqQEcApwDnMfe7zfAfODfUcQnImVMZk4eVwyfxq+LN3v2OaV7\nE/5zetcSJfArsuQkQ7+mqfz7nMMYM3MNL0xYVKKR7yIVitY0l5L68yP4/oHi71e5Jgy4M/bxiJQj\n1to5xpgVwP6BTccC75diSMWxAOgb8jqarNABYa/nRx9OZKy164AXAo9CGWOqAZ3CNk9NoPOVyfdA\nRGLLWsujX83n9Z+XefZJS0nihXN7cEynaAp2SFBJZ+EHBWfjV6RJGiJRCybzf3oCDrsdupwJyUpn\nljsrJ8PsUe5trY+CNp5zqisM/dQXwlq7zhhzBs765sH1wtKA2wKP4HphlV123wOcHrhwi7f2wF2B\nB8aYPYHz18K72sK3wKnW2pwEPV9Hl22tcAZPRONdoLAkflAKcFTgAYAxJhtn1HxV9l03LtQS4Njg\noAYRKb+yc/O5+oPp/LRwo2efE7o25qkzu2mtuRgIzswf3L2JkvkixRFc03xpJB+BXGhN84pr2U8w\n+qro9j38DqhWN7bxiJRPzwH/DTy/0xgzwlpbFqbnzQt73SeKYxwU9jrREsj9gOSQ19vY9/suzfNV\nhPdARAqRm5fPXZ/NZtRU75U3qqUl8/rFB3JIq7JU7CXxxGoWftALExYxuHsTUpJVIFkkIkrml1/5\n+fC1x3zYpBRnUomonH5RrLU/4yR017o0V8Y9gb8WGGSt/SWesRWiClAH9/c3G3gQOM5au6uMns9v\naUBdvBP4Fngb6G6t/du3qESkVOTk5XPtiOlMWLDBs88xnRry7NnddVEWY8Fk/nc3H87TZ3ZTmX2R\nohgD3c6Jfv/+t2pN84pow3z48ALIj2Ksr9arEymOF3BKphucgfJvlG44EQsfGdbIGBPxTPBA3/Bl\n7KIcbRY3F4W9/sBam5dA56sI74GIeMjMyeOaEdMLTeDXrprKyCv6KoEfA7GahR8UnI0vIsWkMvvl\nz58fwprp7m19roD6bf2NJ0EpuxABa+2vOOuTPQ54L3zstD0OdLDWTvIhtGU4CfKpQFEXlFuAl3Bi\nGxblBajf5/PTM8BDwE/Azgj3WYPzPXa01g611ka6n4iUUbl5+dz44UzGz1vv2WdQhwa8cG5PUpXA\njxsl80UitH4ufH179Pt//W/IzIhdPJL4tq+F98+ArCjf92MeqfDr1YlEKnCNfBLOrGoDXGyM+cEY\n07l0IyvSNGB12LazirF/+Dovq6y1Ea037wdjTAPg1LDNbybY+cr1eyAi3nZm5TL0nSmMm+t9T6Jx\nzcp89K+D6dqsln+BlVOxnoUf9MKEReTmlYXiOyIJSMn88iFrB3znsXxf1bpweAnuZZUzqjsRIWtt\nBk6Ju3uB3kAXIDiccSMwB5hirS32bw1r7USci/ZoYhoGDAusodYNp+R8A5zS71nAOmAuMKOkpfn8\nPJ+1dkBJYo3ifLOAWQDGGAO0xllXrhlQG6fiQhawFdgETNese5GKJS/fcstHs/hytlthFsfhbevz\n0vk9SUtRAt8P4WX2RSTE1uUw/LSSJe
HXz4aR58IFn0BqYSsKSbmQtQNGnAnbPWZ1pVaHhh1g1RT3\ndi2/IFIsxpjDAk/vBB7DWVLuMGCWMWYqzszo2TjXoMUeMG6t/SlGoYYf1xpjxgBXh2y+1BjzZFGD\n940xycDQsM1jYh1jCT2Jc38j6Gdr7YxEOl8FeA9ExMXWXdlc8vZkZq3y/nzfsl41hl92EE1r6bN7\nLMR6Fn5QcDb+6b2axfzYvps5UuXNpXSozH7Z9vN/YafHSuQD74Yqtf2NJ4Hpp7qYAkn63wKPhBEo\nVT8p8Ch35/OTtdYCiwIPERHy8y23f/xnoYniQ1vX49ULe1EpJdmzj8RHMJkvIgE7N8B7p3hfEBXH\nil/hoyFw9vu6IC7P8nJg1MWwbrZ7e1IKnDMc0hvDy4dA+Fhdk+ysV6flF0SKYyLO0mxBloLB/b2B\nA0twbEt87/c8A1wRco7WwPWB7YW5MdA3KLeofYwxA9i31PvAwGSImDLGXMzepe1z2DtRnkjn8+09\nEJHStzZjDxe+OZnFG7zHdHVqUoN3h/ahXvVKPkZWvp3eq1nE9xomTpwIwIABA+IXUCJSElVKm5L5\nZc+WZfDbS+5tDTtDr0t8DSfRaaqgiIhIIfLzLXd9NptPpnuvN9e3ZR1ev+hAKqcqgS8ipSwzA94/\nDbYui90xF34Nn18H+Sr5WC5ZC2NvgiXfe/c56Xlnpn2DDnBg+AROnG0NOsQvRpHyzVCQvLchD1PC\nR9xYaxcD74RtftgYc5TXPsaYY3CWsAv1trV2SYzDC56vUWAAQCR9k40xNwNvhTU9aa2dk4jnKwvv\ngYjExrJNuzjj5d8KTeD3OaAOI6/oqwS+lA6VN5dEEPw5nPxaaUciRfn2XsjLcm879jFI0v31UEri\ni4iIeLDWct/nc/hwykrPPr1b1ObNi3tTJU0fMESklOXsccrfe82mLolZI5wLLWuL7itly09PwYzh\n3u0D7oQe54e8vgsq1yx4Xbmm00dESsq3JHyM3AtsCHldFRhrjHnCGHOAKdDKGPMk8AUQWt95A3Bf\nHONrBPxgjJlvjHnEGDPQGPNPXc5AIr2FMeYanDXmn2bve2TfAfcn8Pkg8d8DESmhuWsyOPOVSaze\ntsezz5HtG/De0D7UqJzqY2QiLpTMF5GiLPsJ5n/h3tbhJDjgMPe2Ckx1JURERFxYa3ngi3m8//vf\nnn167FeLt4f0oVol/TkVkVKWl+uUvV/xq3efox6CJRNgaXhV4oCWAyGtGiwY697+24tQtS70v7nk\n8UpimDkSfnjYu737BXD4v/feVq0uHH4HjAsk7gfc6WwTkeL6ib3L6Zcp1tp1xpgzgHEUJIbTgNsC\nj8zAtsouu+8BTrfWxmDdlyK1B+4KPDDG7AmcvxbeE1u+BU611uYk8vnK0HsgIlGYvGwLl74zhR1Z\n3onQU3s05YkzupKarHl6kkBU3lxKQ52WBT9vkpjy8+AbjwkAyWnOPSvZh357ioiIhLHW8uhX83ln\n0nLPPl2b1eTdoX2orgS+iJS2/Hyn3P3Cr7379LsR+l0PbY7yXtP82Meg9gEw4kxndLSb7x+AqnW0\nRll5sOQH+Pxa7/ZWR8BJz7qvc9/7MsjeVfBcRIrNWjugtGMoKWvtz4Hy7R8BjcOa3RLHAGuBM6y1\nk+IanLcq7D0bPVQ28DjwoLU2ryycr4y+ByJShAkL1nPV+9PJyvVezuqSQ1pw34kdSUoqC8VbpEJS\nMl+85OfD7I9ic6zUKnDcU9DtHP18Jbrp78J6j5WjDr4W6hzgbzxlhIbpiYiIhLDW8uS4v3j9Z+/1\npDs2rsHwoQepXJ2IlD5rnTL3s0Z49+lxIQy633le1JrmqZXhnBHQpIf38cbeBHNHlyRqKW3r5sD/\nXQj5HjO7GnaBM9+FZI+/cylpcPhtzsOrj4hUCNbaX4EOOMnoLYV03RLo08Gn5PEy4EFgKlBUgnwL\n8B
JObMOiTKj7fb5/JPB7ICJRGD1jNVe8N63QBP5Ng9oy7CQl8KWMUJl9CbIWFo6HV/vDZ1fE5pg5\ne5yKhFpHPbHt2QYTPKoAVm+oio+F8HVoijEmBRgA9ALa4ZQTSwei+R9mrbVHxiw4ERER4NnvFvG/\niUs829s3Suf9yw6iZlUlLUQkAfzyjFPm3kuHk+DEZ/eeTT3gLmfUe2aG8zp8TfNK6XD+x/DWsbB5\n0b7HtPnw6eVQpRa0HBCDb0J8lbEaPjgTsne4t9doCuePgso1/I1LRMosa20GcKcx5l6gN9AFqBdo\n3gjMAaZYa4t9195aOxEodpYqENMwYJgxphrQDWgFNMBZOz4LWAfMBWZYG16iJrHP53H+uLwHIuKf\ndyctZ9jncwvt88DJnbj4kBb+BCQSS5qZX7GtnAzfDoO/4zCOcNYISG8Eg4bF/tgSGz8+Abs3u7cN\nut+5DyWufPktaYypAtwLXAbEYsFEQxleO05ERBLTixMW8dz3LgmrgDYNqvP+ZQdRp1qaj1GJiHiY\n9o5T3t7LAYfBaW/se2MkkjXNq9WDCz+Dt46B7av3PXZeNnx4Plz8OTTtVaJvQ3yUmeEk8HescW+v\nVMMZwFGjib9xiUi5EEgQ/xZ4JAxr7S5gUuBR7s4Xdu6EfA9EpHDWWp7/fjHPfLfQs09ykuHpM7tx\nSo+mPkYmEgdK5lcsGxbA9w/CX1+W7Dh1WkLVerBqsnv7L/91EvkHXVmy80jsbVwIk191b2vSE7qe\n4288ZUzcfzsaY9oCXwMt2Hv0tJLwIiKSMF79cQlPjfe+YG5ZrxofXH4Q9apX8jEqEREP88Y4Ze29\nNO7ulMVP9VgON5I1zWs1DyTyj4U9LtV5s3fC+2fA0HFQv22xwpdSkJsNoy6CDR6zu5JS4ez3oWFH\nf+MSERERqcDy8y0Pjp3HO5OWe/aplJLE/87vyZEdGvoXmEi8KZlfvm1bCRMfg1kjnWp+0arTsuDn\nw+bDiLNg6Q/ufb/+N1SrD51Pi/58Envj7/Zeyu+4/0CSVn0vTFz/dYwx9YDvgAMomD0fTN6bEjxE\nRERi5s1flvHY1ws82/evW5URl/elQbpHMkxExE9LJ8Inl3lfCNdtDRd8Ung5skjXNK/fzpmZnVrN\nvX3PFhh+inOBLonLWvjiBudnx8vgl6Dl4b6FJCIiIlLR5eTlc+tHswpN4KdXSuG9oX2UwJfyK5jM\nf6k3zBwJeVr5pUzbtRnG3Q0v9IKZH0SfwK/TEk55Ba6ZAt3PdQZ4pKTB2cOhcTePnSx8diUs+ynq\n8CXGFn0Li8a7t3U5C5r38TeeMijeQ5seAJqxd+I+ExiLU9prKbADyItzHCIiIq7e+205D42d59ne\nrHYVRlzel0Y1lcAXkQSweppTxj4v2729RlO4cLRTDj9WmvWCcz5wyrDn5+zbvn01DD/VmZEfXpZf\nEsPEx5x1Ar0ccQ90O9u/eEREREQquMycPK4dMZ3v5m/w7FO3WhrvDu1D56Y1fYxMpJQEk/l7tsLB\nV5d2NFJcWTvh95dh0vOQtT3644TOvHerzFAp3Zlo8OZRsHX5vu3Bpf+GfAWNukQfh5RcXg58c6d7\nW2pVGHS/r+GUVXFL4htjKgNDcBL4wVn4Y4HLrbXr43VeERGRSI2c/Df3jfEoKww0qVmZkZf3pWmt\nKj5GJSLiYeNCp3x99k739iq1nfL3tZrH/tytBsLpr8NHwY/3YTYvgg/OgIs/L7wCgPhv+nD48T/e\n7T0vhv63+hePiBTJGFMJaATUAtKJooqitVZTkEREEtSOzBwue3cqfyxzWbIqoGmtKgy/tA8t61f3\nMTIRkWLKzYbp78KPT8Au70FJe6lS27kGHX93wbaikvehqjeACz6FN4+G3Zv2bc/a7tw7uXQ81N4/\n8u9FYmvy6869IjeH3gQ1m/obTxkVz5n4/YHKFJTQ/wU4xdqSLIAh
IiISGx9NXcldn832bG9YoxIj\nr+hL8zpVfYxKRMRDxipntrvb2vTglLs//2On/H28dDrVmRUx9ib39jXTnRHv538EKZXiF4dEbvF3\nThl9L62PghP+C0YrlomUNmNMK+BfwECgCyW7X2NLuL+IiMTJ5p1ZXPz2ZOas9p6p2rpBdYZf2ofG\nNTWhQCqY5DRYPxtWTYOmPXWdksjy82HupzDhIfcZ8W5Sq8LB18Ah10Hlmk4SvzjJ+1B1Wzn3Ht45\nEXJ27du+cx28f7oqBpaWXZtg4uPubTWbOz8DEpF4XtQdEPganIU/TAl8ERFJBJ/NWMXtn/yJdZlM\nClA/vRIjL+/L/nU91oAWEfHTrs1OAn/7Kvf2pFQ4531odmD8YzlwKOze4lyou1n2I3xyGZz5DiQl\nxz8e8bb2Txh1MViPlcsad3Pep+LcKBGRmDPGVAVeBC7CuX+iu9UiIuXU6m17uPCNP1i6ySXhFNCt\nWU3eHtKHOtXSfIxMJEHkZcPMEc6jURfodYmzbnblGqUdmQRZC4u/h+/vh3Xek6P2kpTivJeH3Q7p\nDQu2n/JK8ZP3oZr2hLPfgxFnQ37uvu2bF8GIs5yKgWm6x+urHx6BrAz3tqMehFQNUotUsUuyFUPo\n8JZ8QKXcRESk1I39cw23jJrlmcCvWy2NkZcfpJJ1IpIYsnY4Zeo3LfToYJwy962O8C+m/rdA30LW\nKJz/uTNb3+sXrcTftpXwwZneSy/U3A/OGwWV9LdOpDQZY+oAU4CLKbg/E6xmGPrc7XX4Q0REEtji\nDTs54+VJhSbwD2lVlw8u76sEvgg4CeIvb4Gn28GYa2H1NF1jlrZVU+Hdk+CD0yNP4Hc+A66dAic8\nvXcCH6D7uSUfVN56EAx+ybt99VRnWcA8lyS/xMe6OTDtHfe2/Q5xqjxKxOI57SIz5HmGZuGLiEhp\n+2bOWm74cCb5Hp/5a1dN5YPLD6J1A63nLCIJIDcL/u8Cp0y9lxP/6/8FkDFw9CPOjPw/P3TvM/1d\nqFYPjrzP39gE9mxzEvg717m3V67plB1Mb+RrWCKyN2OMAUYBHQKbLM4s/GxgAVAFaBPS9iNQHagH\nhC7uGfxkOxdwWRRURERK25+rtnHJ21PYsivbs88xnRry3Dk9qJyqalZSzqVVh5w93hXDwuXshhnD\nnUejLtBriDN7W7Pz/bPxL/j+QVgwNvJ9Wg9y7gc07ha/uIK6nQM718O3HvcfFo2DsTfAyS9qiYZ4\nsxa+uQNc08EGjn1M70ExxTOJvyzkubIhIiJSqr6bt55rR8wgzyODX7NKKu9fdhDtG+kiQEQSQH4e\nfHo5LJ3o3eeIe5zy9qUhKQkGvwiZ22DhN+59fn4aqtSBQ671NbQKLTjwY+N89/bkNDhnBDRo729c\nIuLmFOAICpL3+cADwHPW2h3GmCuBl4OdrbUDg8+NMfWAE4Ebga6BYzQAbrLWfudT/CIiEoFJSzZx\n+btT2ZXtnbA8s1czHjutCynJ8SyaK1LKUqs5s7G7nOlcR84c4czW3bIk8mOsmw1f3gzj74UupzsJ\n/SY9lBSMl4xVzrrmMz/wSMq6aNoLBj0AB/SPb2zhDrkedqyH3z1m5c94H6o3giPv9TeuimbBWFj+\ns3tbjwugSXdfwykP4vnJ4Beci1CAFGNMh8I6i4iIxMsPf23g6g+mk+uRwE+vnML7lx5EpyY1fY5M\nRMSFtc6NiXljvPv0vRr63+pfTG6SU5011fc7xLvP+Lth5kjfQqrQrHXKTHpdMAOc8jK0ONS/mESk\nMLcEvhqcJPxl1tqHrbU7itrRWrvJWvsO0BO4P7C5HvClMeb4OMQqIiJRGD93HZe8PaXQBP7l/Q/g\niTO6KoEv5ZyBS8cXlE+vVg/6XQ/XTYOLv4DOp0NSauSHy9kF09+D1wfCq4fB1LecpegkNnZvgfH3\nwPM9nQoIkSTw67aBs4bDZd/7
n8CHQMXAh52fJS8/PwWTX/cvpoomJxPG3e3elpauSo1RitunA2vt\nJiC0vsYZ8TqXiIiIl58XbeTK4dPIznP/wFm9UgrvDe1Dl2ZK4ItIgpjwkPf6YQBdz3bK2SfCbIPU\nKnDuSGjYxbvPmGvgr6/9i6mimvAQzB7l3T7ofuiiSzKRRGCMSQf6UrCe/ffW2neLexxrbb619kHg\n3ziDAVKBEcaYprGMV0REiu/jaau46oPpZOd6J79uO6Yddx3fAZMIn+tF4qnbudCo877bjYEDDoMz\n3oJbFsBRD0GdVsU79ro/YexN8FQ7+Px6WF3IcnRSuOxd8NNT8Fx3mPQC5GUVvU96Ezj5Bbj6d+h4\ncunep0hKcgauH3C4d5+vboO5o30LqUL5/X+wbYV72+G3QfUG/sZTTsR7iN9dOOu5AdysC0kREfHT\npCWbuOzdqZ4XzVXTknlnSG967Ffb58hERDz89pJTht5Lm2Ng8EvOxWmiqFILLvgEah/g3m7z4KNL\nYPmvfkZVsUx9u/CfmwMvhX43+haOiBSpL879mOBdztdKcjBr7VNA8JdsOjCsJMcTEZGSefOXZdz6\n0SzP5fyMgUdO7cw1A1srgS/lX1IqDPh30f3CZ+d3Oi2K2fnvanZ+NPJyYMqb8HwPZ3B4VkbR+1Su\n5Qy6uH469LzIqbCQCFIqwdnvQyOviQbWWbpw+S++hlXu7VjnDABxU6clHPQvf+MpR+J6989aOw+4\nGufCtAbwtTGmSTzPKSIiAjB52RYufWcqWR4J/MqpSbx1SW8ObFHH58hERDzMHAnj7vJu3+9gp3x9\ncjFuZPglvSFc+BlUb+jenpsJI8+BtX/6G1dFsHA8fHmLd3vbY+G4JxKjcoOIBDULe13kXURjTFG/\n/J8NdgXOM8akRRGXiIiUgLWWp8f/xUNj53n2SU02PH9OD84/aH8fIxMpRb0ugdotIu8fnJ1/5ttw\n83w46kEnCVgca2c5s/Ofbg9f3ABrZhRv/4oiPx/mfAIv9XGW9Nu5vuh9UqrAoTfDDbOcQRepVeIf\nZ3FVrgHnfwK1PH7P5mXDyPNg/Vx/4yrPvn/QGUjj5uhHnMEVEpW4T+Gx1r4FXArkAJ2AP40xNxpj\nVLdYRETiYtqKrQx5ezJ7ctzXnauUksSbF/emb8u6PkcmIuLhr6+dsvNeGnaGcz+EtKr+xVRcdQ6A\nCz6Fyh4f87O2w/unweYl/sZVnq2Z4VQ5sB7rrDbp4ZSmTJRZESISFDqKNMtau86lT27Y66LufH0b\n8rwKzmx/ERHxSX6+5b4xc3lhwmLPPlVSk3nj4t6c1E1z3KQcqdMSGnR0b0upAofdGv2xq9eHfjfA\ntdPgos+h06nFm52fvdNZqu61AfDq4U4FM83OdyyZAK8PgI+HwpalRfc3yXDgULh+Bgwa5lTkS2Tp\nDZ37E1U97v1mZcD7p8O2lf7GVR6tngYzP3BvazkQ2h3nbzzlTFzv5hhj3gp5OQfoiXOx+jTwpDFm\nNrAcyMBZB644rLX20ljEKSIi5cesldu45K3J7Mp2T2ikJSfx2kUH0q91PZ8jExHxsGJS4YnY2i2c\ni89Ev0gGZ53D80bBe6dA7p5923dthOGnwNDxUKOx39GVL1tXwAdneY92r7W/816kVfM3LhGJROgs\n+Z0efcLvMNcvpC/W2u3GmB045fQB2gM/RR2hiIhELDs3n1s/msXns9Z49qlROYW3h/Sm1/6qBijl\nyCmvQI2m8N5J7u0HXQHpjUp+nqQkaHm489i50UkYTnsHti6L/BhrZ8LYG2H8PdDlDOg1BJp0L3ls\nZc3qafDdA7Dsx8j36XQqHHEv1G0Vv7jioV5rOO8jePdEyNm9b/uOtc5Eg6HjoKp+N0fFWvj6Dvc2\nkwzHPqaqgCUU7ykZl7B3cj743ADJQHegWxTHNYFjKYkvIiL/mLM6gwvf/IMdWeETlxypyYaXL+
jJ\n4W3r+xyZiIiHtX/CiLOdcvNuqjeEC0c7o8jLiv36wlnvwYfnQr7L7+NtfzsXykO+giq1/Y+vPNiz\nFT44A3ZtcG+vUhsu+ASqN/A3LhGJVGiC3qvEyvaw182Aou5U51Nw30V3IkVEfLAnO4+rPpjGxL82\nevapn16J94b2oUPjGj5GJuKDbufA2x6zbCvVgH43xv6c1evDoTfCIdfD8p+c2fULxrpfe7oJzs6f\n9o5TuazXJdD5DKhUPfaxJpJNi5z17ueNiXyflgOdWfdNesQvrnhr1su5PzHyHPefkU0LnXsyF41J\n7MqHiWr2x7Bqsntb70uhQQd/4ymH4l5O34MNeYiIiJTY/LXbueDNP9ie6f6hPSXJ8OJ5PTmyQxlK\nhIlI+bZ5iVO+LSs8TxNQqaYzA7/OAf7GFQttj4ZTXvZu3zDPuVDO9phFLt5ys+DD852bDW6SK8E5\nI6FeG3/jEpHiWBvyvIrHevfBuqbB+ya9CjugMaYWELqeSX7U0YmISEQy9uRw4Zt/FJrAb16nCh//\n62Al8KV8Wvw9/P2be9sh18V3dnNSErQcAGe9CzfPh0H3O1XsimPNDPjiBni6HXxxI6ydFfs4S9v2\nNfD59fDSQZEn8Jv0cJLaF40u2wn8oDZHwckverevmuwsK5AX4UAQcWTvgu+GubdVrgUD7vQ1nPLK\njyS+icNDRETkHwvX7+D8N/5g2+4c1/YkA8+d04NjOsWghJeISCzsWAfDT/WeSZ1SGc77P6c8fVnV\n9Sw49j/e7Sv/gFEXQ577725xkZ8Po6+CFb96dDBw2quw/8G+hiUixfZX2Ov2Ln0WA6HrkhxfxDFP\nDHwN3jPZFEVcIiISoQ07Mjnntd+ZumKrZ592DdP5+F+HsH9dLW8k5VB+Pnz/gHtb1brQ9yr/Yqne\nAA69Ca6b4VSy63gKJBWjCHX2Tpj2Nrx6GLw2AKa9C1meqxiVDXu2wrfD4PkeMP1d7+X7QtVtDWe+\nC5f/4AyQKE+6n+sM9PCy8Gv48ianPLxE5tfnYPtq97aBd2uJghiJdzn9IXE+voiIVHBLNu7kvNf/\nYMuubNf2JAPPnN2dE7pq7WUpZ3KznQ/MAP1ugJS0wvtL4tizFYafBttWuLebZKfcW3lIxPb9F+zZ\nAj96JPMXf+skpU99zZlJIYX7/n6Y84l3+9EPO+sVikiiWwDspqCUfjdgdmgHa22+MeZH4NjApiON\nMYdaa38JP5gxpibwAM6s/WASf2Yc4hYREWDllt1c8OYfrNjsssZyQI/9avH2Jb2pVVXXaVJOzf8c\n1v3p3nbozVAp3d94wLmmbDXQeezcADPedxLYW5dHfow1M5zHuLuh65nQawg07hq3kGMuezdMfhV+\neQYyMyLbp3ojGHAH9LgAkt0KRJUT/W50JlT88Yp7+/T3IL0xDLzL17DKpG0rC+5JhqvfHg4c6m88\n5Vhck/jW2nfjeXwREanYlm/axXmv/86mnVmu7cbAE2d0Y3D3pj5HJuKDKW/ADw87zytV93eUu0Qv\ne7dTRn7DXO8+p7wMbY/xL6Z4G3An7N7s/My6mf2Rs4b7cU84v7jF3eTXvS+SAfpcCQdf4188IhI1\na22uMeYX4OjApmOB9126jgy0BZPzo40xNwKjrLXZAMaYw4HngQMoKL2/FpgRt29ARKQCW7h+Bxe+\n+Qfrt7vfhwDo36Yer17Yi6pp8Z4/J1JK8nLhh0fc29KbOGthl7bqDaD/zU7idtlEmPYOLPjSfV10\nN9k7YOpbzqNJTzhwCHQ6zbn/kojycmHGcGcA/Y61RfcHqFzTqWDQ58qKsR68MXDMY7BzPcz9zL3P\nj/+B6g0T42c4kX17H+Rmurcd+xgk6+9frGi6i4iIlEkrt+zm3Nd/L/TC+bFTu3BGr2Y+RiXik12b\n4cfHC15PfMzZJoktLwdGXeSUkfdy7H+g29n+xeQHY5wEfa
fTvPtMfg1+fMK/mMqaBV/B17d7t7c/\n0blQ1iAIkbLk65Dnxxpj3KY9jaSg9L4F6gDvAtuNMauMMduBCUAXChL9FnjaWtUCFRGJtRl/b+Ws\nV38r9D7ECV0b8+bFvZXAl/Ltz/+DTQvd2w6/HVKr+BtPYZKSoNURTrW7m+bBkcOg1v7FO8aa6fD5\ndfB0exh7M6z1qEAQbubI+K+zbi3MHQ3/OwjG3hhZAj+lslPR8YZZThK/IiTwg5KS4NRXoUV/7z5f\n3Qrzv/AvprJmxSSY+6l7W9vjnP9vEjNK4ouISJmzetseznntd9ZmeIz4Ax46pTPn9NnPx6hEfDTx\n0b3LomVmOIl8SVz5+TD6aqd8vJfDbnPKz5dHScnOhXJhF3MTH3Vmm8veVk+Dj4eCzXdvb9YbTnvd\n+TcWkbJkFJCPk3ivA5wf3sFamwtcDgSzRcFEfRrQBKhOQeI+6Efg2XgFLSJSUf2yaBPnv/EH23bn\nePY5t89+PH9OD9JSdMtdyrHcLJj4uHtb7QOckuyJKr2hMzv/+plw4WfQ4WRIKsaAm+wdMPVNeLU/\nvH6EU349e5d3/9H/gpd6xy+Zv3QivD4QProYNi8uur9Jhp4Xw3XT4agHnYp4FVFKJTjnA2jY2b3d\n5sPHl8KK3/yNqyzIz4dv7nBvS0qFYzwqdEjU9IlCRETKlLUZezj3td9ZvW2PZ59hJ3Xkwr7FHFUr\nUlasn+eUcws39S3YMN//eKRo1joXObNHefc5cCgMvNu/mEpDShqc/b6TdPby1W0w+2P/Ykp0W5Y5\nyy/kevzNq30AnPthxZo5IVJOWGvXAicD5wYertPZrLW/AKcCWyhI2Ic/TODxNTBYs/BFRIonNy+f\ncctzGLc8h9y8fQdOfj17LUPfmcLu7DzPY1w9oBWPntqZ5CRVRpJybvp7kPG3e9vAu8rGmurB2fln\nDw/Mzr+v+LPzV08rmJ3/5S2wbrZ7vy1LY5/MXzMD3jsF3hvsPI9Eh5Ph6t/h5OehppYdpXJNOP9j\nqOkxASwvC0aerfts4WZ+AGtnubf1/RfUbeVvPBWAkvgiIlJmbNieyXmv/8HfW3Z79rn7+A4M6XeA\nj1GJ+MhaGHeX+4xcmxdo0337hPPTkzD5Ve/2TqfC8U9VjFLoadXgvFFQv71HBwufXQmLvvM1rIS0\newt8cAbs2ujeXrUuXPAJVKvnb1wiEjPW2q+stf8XeEwqpN84oC3wCDAvsDmYuM8EvgFOtdaeYK3d\nEe+4RUTKmzEz1zByQTYjF2QzZuaavdo+nPw314yYTrZLcj/oruPbc/ux7TEV4fO8VGzZu7yXQWvQ\nETqf7m88sZDeEPrf4szOv+BT6HCSM2M9UlnbYcob8Mqh8PqRMH24++z8WCTzNy+Bjy6B1wbA0h8i\n2+eAw+DyCc6Ahfpti3/O8qxGY7jwU6hSx709MwPePx0yVvkbV6LK3A7fP+DeVq2+U11SYk5JfBER\nKRM27sji3Nd/Z9km7zJVtx3TjssPa+ljVCI+Wziu8Au1JRNg0Xj/4pGiTXkDfiiknFjLgXDqaxWr\nFHrVOk7pQq8R7/m5MOpCWDnZ37gSSc4eGHmud0nElMrODHyNchepMKy1W62191pru+CU028M1LPW\nVrPWHm+tHVPKIYqIlEm5efm8MGHRP69fmLDon9n4r/y4hDs+nU2+xzjpJAP/Ob0LVxymz2RSQUx+\nDXZtcG874p6yfV2blAStj3Sqx908D464F2oVc5nO1VPh82sLZue7iSaZv2MdjL0JXuoDcz+LLJZG\nXZ1BCRd9Dk17Rf49VDT12sD5H0GqR3W77ath+GnOIPuK7uenvCcZHHGvU91AYk5JfBERSXhbdmVz\nwRt/sGSjdwL/pkFtuWZgax+jEvFZbjaMj6Dc+ri7IM97nUbx0ZxP4Mtbvdub9nJuEKSk+RdToqjR\nxEnkV/WYRZ6zGz44s2
KWrsvPd6oRrPzdo4OB09+A5n18DUtEEoe1Ns9au95aq7uJIiIlNGbmGpZv\nLqj2t3zzbkbPXM3jXy/g8a8XeO6XlpzE/87vydm9i5nkEymrMjPgl2fd25r2gnbH+xpOXKU3gsNu\nhetnOdXP2p8Y3ez8woQk8xuum4DJd1muY882+O4BeK67s4RifgQJ/zot4Yy34IofnUEJqhBStGYH\nwpnveL/Hm/5yBtnneC/tWu5tXgK//c+9rVFX6HGBv/FUICmlcVJjTGegD9AcqAWkU/wBBdZae2mM\nQxMRkQSzbbeTwP9rvXdl0GsHtub6I5XAl3Juyuves3JDbV4Mk1+Hg6+Of0zibfF38OmVOEsVu6jX\nzll/rVJ1X8NKKPVaOzdE3jkRsl1+x2dug+GnwtBxULuY6xOWZd/eC/MKmVB77ONOiUcRERERKZHw\nWfhB938+j51Z3smyqmnJvHbhgRzaRssaSQUy6UXnGs3NkfeVz2RxUhK0HuQ8dqyDGcNh2nuQ8Xfs\nzrFlKR22PMf+K0ZBnWHQ5UzIz3Hu6/z8tPe/ebjqDeHw26HnxZCcGrv4Koq2x8DJz8OYa9zbV/4O\nH18KZ70HyaWSVi1d4+91fi7dHPefsl2FI8H59tNmjKkHXAdcCdQv6eFw7ogqiS8iUo5l7Mnhwjcn\nM2/tds8+Vx7WkluObqu156R827UZJj4eef8fH4euZ0O1uvGLSbytnAL/d6H3BU7N5oFZ6B7rrlUk\nTbrDuSOddebysvZt37EWhp/iJPKrN/A7Ov/9/gr89qJ3e99roO+//ItHREREpBwLn4UfVFgCv1bV\nVN6+pDc99qsdz9BEEsvOjfDbS+5tLfpDywG+hlMq0hs5a34fejMs+QGmvQ1/fQ3WZQZ9FKruWevM\nzP/2XqfEfubWyHasVAP63QB9r4K0ajGJpcLqcYEzWGPCQ+7tf30JX90CJz5bPgeteFnyg/O9u+l0\nKux/iL/xVDC+lNM3xhwNzAHuARrgJOHdHnvtFkEfEREpp3Zk5nDxW5OZvTrDs8/Qfgdwx3HtlcCX\n8m/io045tkhlZsDEx+IXj3jbMB9GnOmUg3dTtS5cOBpqNvU1rIR2QH+n3J/xuDTZstRJ8md6/z0o\nF+Z/Ad/c4d3e4WQ4+mH/4hGRhGaMqWaMaWD0QVhEJCpes/AL07BGJUZdebAS+FLx/PIM5HgscXnk\nff7GUtqSkqHNIDjnA7hpLgy8xxmoHyu7NkaWwE+uBAdfCzfMckr/K4EfG/1vgd6Xe7dPewd+fMK3\ncEpdXi58c6d7W0plOOpBf+OpgOKexDfGHAl8SUHyPrymqA15hCfrbdgDlMgXESn3dmXlMuTtKcxc\nuc2zz4V99+feEzsogS/l3/p5ztpnxTX1rYq5nnhp2rrCKf++x+OCO626Uz6+npb/2EeHE+HkF7zb\n1/0JI8+DnEz/YvLTysnwyWV4Lr/Q/CA47TWnnKOIVFjGmOrGmIeMMUuA7cBaIMsYM8EYM7iUwxMR\nKVO8ZuF7aVG3Kh//6xDaNkyPY1QiCShjtff67m2Pg+Z9/I0nkdRoDIff5iTSz/8Y2p/ova56rJgk\nZ8b49dPhmEdU4S/WjHHKw3c42bvPxEedZH5FMO1t2Ohxb/GQ66HWfv7GUwHFtZy+MaYO8H9AMgVJ\n+mxgNPAH0B24KNDdAkOA6kA9oDfQH0in4G7WRODdeMYsIiKla3d2LkPemcLUFd6jTs/t05wHTu6k\nBL6Uf9bCuDvB5kexb54zWvbCzypWma/SsnOjk8Dfsda9PTnNKRvfpIe/cZUlPS6A3Vuc8oFuVvwC\nHw8tf2vQbV4CI86GXI8BCnVawbkfQmoVf+MSkbgyxgwC/hey6W5r7UeF9G8CfAu0Z+/JDSnAAOBw\nY8yb1tor4hCuiEi5UtxZ+B0b1+DdoX2on14pjlGJJKifnnBf+gzgiHv8jSVRJSVDm6Oc
x/a1MON9\nmP4uZKyM7Xkad4fBL0GjzrE9ruwtKRlOex3e3+Lch3Az9iao1gDaH+9vbH7avQV+eMS9Lb0JHHqj\nr+FUVPGeynEdUIeCJPxCoIu19lxr7bPAb6GdrbXvWmtfstY+YK09EWgK/BvYE+hyeOAx3FqrZL6I\nSDmTmZPH5e9NZfKyLZ59zujVjEdO6UJSkpKSUgEsHAdLJ0a//9IfYNH4mIUjHjK3wwenw5Yl7u0m\nySkXf8Bh/sZVFvW73lnPz8tfX8IX1zsDXMqDXZucpQL2ePzdq1oPLvhYsytEyqfzgNaBRwPgqyL6\nfwh0oKDCYfjDAJcaYzzutImISFBxZuEfUK8qI6/oqwS+VEybl8D04e5tnc9QMtlN6Oz8WFs7E5b9\nFPvjyr5SKztLJjTo5N5u8+HjIfD37/7G5aeJj3tXmjzqAS3h4JN4J/Evp+Bicg9wnLV2caQ7W2t3\nWmufBA7EKRNngIvZe7S6iIiUA5k5eVwxfBq/Lt7s2eeU7k34z+ldlcCXiiE3G8bfXfLjjLvLOZbE\nR04mfHgerC3kAv2k56DDSf7FVNYNegB6XOjdPvMDGH9P2U/kZ+92ZuBvXebenlIFzhsFdVr6G5eI\n+GUQBQn4r621HgvNgjHmNOBQ9k7Y5wLrgXwKEvsGuNUY0z6+oYuIlF3FnYWflw/V0uJcHlskUU18\nzKnyF84kw8C7/I+nLEnS740yr0otZ1B9zebu7bmZzjX9hgW+huWLDQu8l9Fo1hu6nOlvPBVY3JL4\nxpgDgCaBlxZ4zVrrcYeqcNbaBcDxQA7ORenlxphjYhKoiIiUuuzcfK7+YDo/Ldzo2efEro156sxu\nJCuBL1I8eTmlHUH5lZcLn1wKy3/27jPoAeh5kXe77MsYOPFZZz1BL7+9CL8+61dEsZefB59eDqun\nurcHqzc06+VvXCLiC2NMY6BZyKavi9glWKIk+EH4JaCutbYJ0AgYSUEiPwW4KXbRioiUL6Nnro54\nFj7A31t2M2bmmjhGJJKg1s+F2R+7t/U4H+q28jeeiq5OSzjlFeijlZN8VaMJXPAJVKnt3p65zamu\nl7Ha17Di6p+lPV0G8AAc+x8t2+mjeM7ED95xCr6bo0pyMGvtn8CrIZu04IqISDmQk5fPtSOmM2HB\nBs8+x3ZqxDNndyclOd4FZEQSSEoaHB2DirjbVsD/XQBLJpT9mcuJxFoYewMsGOvd55DrtEZYtJJT\n4PQ3oUV/7z7f3Q/TyuAKW9bCN3cW/rNz3BPle209EWkX+Bq8XzLdq6Mxpil7z8KfZK29zlq7E8Ba\nuxm4EJgWOJ4BTjfG6IOziEiYPdm5PPDFvGLv98KEReTm5cchIpEENuERClZJDpGcBof/2/dwKqxg\n8v6aKdD9XOdaWfxVv51TJS+linv79lXwwRnepefLmoXjnHuIbrqdq8kGPovnRV29sNczXfrs9VfA\nGFO5iGMGF2AxwCHGmEbRhSYiIokgNy+fGz+cyfh56z37DOrQgOfP7UGqEvhSEbU9BloOLPlxFo2D\n4afC/w6GqW87ZbylZL4bBjPe927vfgEc9ZB/8ZRHqZXhnBHQuJt3n7E3wrzPfQspJn57CSa/6t1+\nyPXQ53L/4hGR0tAi7HVhNTgHUZCcB3g2vIO1Nj9se21AJfVFREJk7Mlh8Eu/siMzt9j7Lt+s2fhS\nwayaCn996d7W+zKo2cy9TWJHyfvE0rwPnPm2UzXPzYZ5MPI8Z8nFsiw321mW001qNThymL/xSFyT\n+KH1JXZaa91+esO3eQxl+cc0nPXegsn/g6KMTURESllevuXmUbP4cvZazz4D2tXnpfN7kpaiBL5U\nUMbAMY96XyQU18b5TtLzmY7w7TDIWBWb41Y0vzwLvz7n3d7uBDjpOZUXi4XKNeD8T6Bua/d2m+8s\nabD0R3/jitbcz2D83d7tnU5zlmAQkfKuVsjz3dba
wjJKh4c8zwK+8ugXnC4TvF/SJbrQRETKn783\n7+a0//3KwvU7oz6GZuNLhfL9g+7bU6vBoTf7G0tFo+R94mp3nLP0n5e/J8GnlznL55VVk1+FLUvc\n2/rfDDUa+xuPxDWJH7oAa7ZHnx1hr5sUdsDA6PJtFIxAbxFNYCIiUrry8y23fTyLz2d5j2Tv36Ye\nr1zQi0opyT5GJpKA6reDquEFjkpoz1ZnPfFnu8JHl8Dff6jUfqSmD3dm4XvZ/1BnLXNdaMdO9fpw\n4WeQ7nGpkJcNH54Hqz2rUSeGFb/Bp1d6t+/fD055GZI0cE2kAqga8ryo8jiHBL4GS+nvcetkrV3L\n3hMl6kcfnohI+TFl+RZO+d+vLNm4q0TH0Wx8qTCWToRlHoOk+17lXJ9JzO2u0ljJ+7Kg18UwsJCB\n+fO/gK9vL5v32HZuhB+fcG+rtR8cfK2/8QgQ3yT+9pDnNTz6bAl73TKC41alYGR5teIGJSIipSs/\n3/LvT/7k0+mrPfv0bVmH1y48kMqpSuCL8PvLsGtDfI5t85yZwW8dDa8PhD9HOaWzxN38sfDF9d7t\njbrCuSOdMvASW7X2cxL5VWq7t2fvdNag27TI37gitWkRfHgu5GW5t9drC2e/r58dkYojdNJDVa9O\nxph6QFsK7oH8WsRxd1Iw6SE96uhERMqJT6ev4vzX/2DLrthc42g2vpR71sL3HsvCVa4Jh1znbzwV\nQZ2WzG9/A1N6v6TkfVlx2G1w4FDv9ilvwM9P+RdPrEx4CLK2u7cd/bDuV5SSeCbxl4U8TzHG1HLp\nMy/wqurhFQABAABJREFUNXhBenBhBzTGtAVCf1JKNoRSRER8lZObx/lv/M5H07xLePduUZs3L+5N\nlTQl8EXYvMT5EF1ch93urKtduWbk+6yZAZ9eDs92cUbe7txY/POWZ8t+go+HOuXb3dRpBRd86pR/\nl/ho0B7O/9gp4ehm92Z475TEWyZi5wZ4/3SnAoabag2c76tqHX/jEpHSFHp3rKoxxusP9oDA12Bi\n/pcijluFgvsrZbiOp4hIyeTnW54a9xc3j5pFdgyT7pqNL+Xewm9g9VT3tn43QpVafkZTvoWUzV/f\n6Ahsku6DlhnGwPFPQYeTvPtMeBimv+dfTCW1dpZ3vC36Q4eT/Y1H/hHPJP78sNddwztYazcCwcWQ\nDXCuMYUuHnppSF8AfWoSESkDcvPy+XjaSno+/B2/LQ0vwlKg5361eHtIH6pV0qhTEfLz4fPrIDez\n6L6h6raBw2+Hox+Cm+fDCU87s3wjtXMd/PAIPNMJRl8Na/8s3vnLozUzYOR53rOo0xs7s8RVVjD+\nmh0I57wPSanu7dtXwfBTYddmf+Pykr0LRpwN21a4t6dWhfNHQe39/Y1LRErbsrDXfT36DQ55ngdM\n8jqgMSaVvasVhi9fKCJSIWTm5HHdyBm8+MPiuBxfs/Gl3MrP956FX60BHFTI0mASOa15Xz4kJcNp\nb8B+h3j3+eIG+Otr/2KKlrXwzZ0UjAUOYZLg2MecgQtSKuKWxLfWrgJWhmzq49F1DAVJ+f2AR906\nGWOOAm5k75+kokrJiYhIKcrLt/y6OodB//2RWz/6kx2ZuYX2f2doH6orgS/imPomrIjio84xj0By\nIMGZVg16XwZX/wEXfAKtj4r8OHlZMPMDeLU/vH2Cs65XfgWc1LdpMbx/BmR75EIq13IS+ErC+qfV\nEXDaaxRcQoTZtBBGnAlZO30Nax95ufDxpbBmunu7SYIz34EmPXwNS0QSwqzA1+D9jX3qcRpj6gKn\nBfpYYJq1trBqhG2Cuwa+atKDiFQ4G3ZkcvZrv/Pl7LVFd46SZuNLuTX3U9gw173tsFud+wsSPSXv\ny5/UynDuCKjfwb3d5sNHQ2DlZH/jKq55o73vP/a8GBp18TUc2Vs8Z+IDTAh5fqJHn7cCXy3Oxebt\nxpjvjDFDjDGD
jDFnGGPeBr4EUgN9LPCjtVafmEREElBuXj6fTFvFXb/s4fXZ2SzfvDui/WpU9phZ\nKVLRbF0B3w5zbzOFfHxrdQS0OXrf7UlJ0HoQXPAxXDvVSex7lSR3s+IX+L8L4PnuMOlF2LMt8n3L\nsozVMPwU2L3JvT21Kpz/ETTwuGCT+Ol8mlNlwsvqac7PbK5H9YR4sxa+vh0WFjLq/oT/Qttj/ItJ\nRBKGtXY9ELybZ4AzjDHXB9uNMZWBN3HK4weT8p8WcdjwEUHxmYIqIpKg5q3Zzikv/sqsldvifi7N\nxpdyJy/HqcjnpmZz6HWJr+GUK0rel29VajuTZmo0dW/P3QMjzoKNC/2NK1I5e2D8fe5tlWrCEff4\nG4/sI95J/M9Cnh9ijGke3sFaOxUYQUFy3gADgTeAccD/ARcBKRSMUs8H7o1f2CIiEo1g8n7Qf3/k\nlo9msX63SxkeESmctfDF9ZDjMdnu4GvcE/kmGY55tOgSV/XaOMnPm+fB0Q9Dzf0ij23b3zD+bvhv\nR/jyVmeWenm1ewu8fxpkrHRvT0qFs4dDc69iUxJ3vS8t/IJy6Q/w6RWlU0Hi1+ecahpeDr0ZDhzi\nXzwikoheZe/7IM8YY1YaY34BVgEnUXAPJBN4t4jjHRnyPBNYENtwRUQS1/fz13PmK5NYk1HMpcii\npNn4Uu7MHAFblrq3DbgDUir5G095oeR9xVCzKVzwqVOp0c2erc79pe3xqxITtUkvQsbf7m0D/g3V\n6vkbj+wj3r85vgY2ADUDr6/APfl+DdABZ+R48CI19A50sHxccNtt1lqV0hcRSRC5efmMmbmGFyYs\ninjWvYh4mP4eLJ3o3ta0Fwx6wBkpO+WNvdsOHFq8GeFVasEh18FBV8FfX8Efr0Revj9nF0x53Xm0\nORoO+pdTBaC8rJGVtRM+OBM2euU/DJz6ilPdQEpX/1th12b442X39nmj4cvacOIz/v18zv4YvvOo\npAHQ5Uw40mOku4hUJO/g3CM5iIL7HU2BJhQk9wl8/a+1doPXgYwxycDJIftMtdZWwDVwRKSisdby\n1q/LeeTLeeQXMoegx361eO3CA6mf7p2InDhxIgADBgyIbZAiiS4nE378j3tb3TbQ9Rx/4ylPup9b\n2hGIXxq0h/P+D94bDLkuA8oyVsL7p8OQr5z7cYlg+xr45b/ubXVbQ+/L/Y1HXMV1Jr61Nsda28ha\nWyXwcJ09b63NwBk1HhxZHn6HzQQea4FzrLXPxitmERGJXPjMeyXwRUooYzWM95hZnJwGg/8HSckw\n4C6oXLOgrXJNGHBndOdMToGOJzsXElf+BN3Pd84VqUXjnRHFLx0EU96E7MKW6y0DcrNh1IWweqp3\nn+OfhC5n+BeTeDPGqUDR9WzvPtPehgkP+xPP8l9g9FXe7S36w+CXys+AFxGJmrXW4sy2n0tB0j40\ncR+8D/IVUMjIIAgcp07I6+9jGqyISALKycvn7tFzeGhs4Qn8k7o1YeTlfQtN4ItUaFPfgu2r3duO\nuFszyEUitV9fOOMt72UwN8yFD893Bs4kgu/uhxyPe/nHPAYpxbg3KHET73L6EbPWbrPWDgHaAbcB\nw3HK6Y8FXgHOAVpZa0eVXpQiIgJK3ovEhbUw9kbI2u7efvi/nZG9ANXqwuF3FLQNuNPZVlKNu8Ep\n/4Ob5sHAu6F6w8j33fQXfHmzU2r/2/tgm0cZ+kSWnwefXQlLJnj3GXAX9NFo5ISSlOQkxtsUsr78\nz0/Bb/+Lbxwb/4IPz4O8bPf2+u3h7PdVilJE/mGt3QT0Au4BFlKQuDc4yf3rgMHW2qIWXr458DU4\nQuiL2EcrIpI4MvbkMOTtKYz4w6MEcMCNg9rw/DndqZya7FNkImVM1g74+Wn3tkZdoMNgf+MRKeva\nnwAneMxuB1jxC3xWSsv+hVo5Bf78P/e21oOg7dH+xiOeEm4YlbV2MeDxl0NERB
LBu7+t4KGx80o7\nDJHy5c//c2a1u2nUFfrdsPe23pcVzHrvfVlsY6leHw6/HfrdCHM/c0qVr5kR2b6Z25z1wCe9CB1O\ndMr179c38WceWwtf3QpzP/Xu0+dK599FEk9yKpz5jlMV4u/f3PuMuxOq1oFucSgHuWMdvH8GZGa4\nt1dvBOd/nDhl80QkYVhrs4FHgUeNMZVwZtRnWGsjGiUbKKX/FAX3Uay1NsI/2iIiZc+KzbsY+s4U\nlmz0rgCWlpLEk2d0ZXD3pj5GJlIG/f4K7N7k3nbEfc6AaREpngOHwM71MPEx9/Z5Y+CbO+C4J0rn\nXll+Pnzzb/e2pBSn2qEkjIRL4ouIiIhUODvWw9eFfIA+5X9OkjJUShocflt840pJg25nQ9ezYOUf\n8PvLMP8LiGSZXZvnXJjMGwONu0Pfq6DTqYk7C/mHR50ygl66nAnHPp74gxEqsrSqcO6H8M4JsH6O\ne5/RVzvLT7Q7LnbnzdoJI86CDI+ZYGnV4fxRUKt57M4pIuWStTYLZxnB4uyTB3wen4hERBLL5GVb\nuHL4VLbuzvHsU7daGq9d1Ite+9fx7CMiwO4tMOkF97bmfaHNUf7GI1KeHP5v2LEWpr3j3j75NUhv\nBP1v8TUsAGaPgtXT3Nt6Xw712/kbjxRKQ6lERKTYLj54f54+sxst6lYt7VBEyj5rnTL0mdvc2/vf\n4pSxK03GOLPpz3oXbpjlzNCvXCvy/dfOdMrUP9MZJj4OOzfEKdAo/f4K/PSEd3vro2Dw/zQLoSyo\nUgsu+ARqt3Bvt3nw0SWwYlJszpeXCx8PgbWz3NtNsvP/pnG32JxPRMoFY0y6MaZryENTRUVEivDJ\ntFWc/8bvhSbw2zaszuhr+imBLxKJSc9DlkclsSPv0wB2kZIwBo5/Gtqd4N3n+wdhxgf+xQTOJIRv\nh7m3VakDAzwmGEmp0Z1IEREptpTkJE7v1YzPr+1Hj+a1SjsckbJt7mewYKx7W4OO0P9Wf+MpSq3m\ncNQDcPN8OPFZZ53vSO3a4JQTe6YTfHaVd+IzFnKz2X/5KPZfPgpyPdYoB/hzlHcZMYDmB8FZ7zlV\nCaRsSG8EF46G6g3d23MzYcQ5sG629zEi+fkJDsDxWgYD4KRnnfXkRET2dg4wI+RxTOmGIyKSuPLz\nLU+OW8AtH80iJ8969ju8bX0+ueoQmtfRZAORIu1Y7wxmd9PqSGjRz994RMqj5BQ4403nvpKXz6+D\nhYXcU4i1X56Bnevc2464G6rU9i8WiUipJvGNMVWNMQcYY3oZY/obYw4rzXhERCRyyzft4sxXfmfG\nym2lHYpI2bVrE3zlURLfJMHglxI3eZxW1Vnn6+rf4cLPoE0x8g952TBrBLx6GLx1nFNyPy83tvFN\neYMDln/AAcs/gKlvuvdZOB5GX+V9jAYd4bz/c75XKVvqHAAXfAqVarq3Z2XA8NNgy1L39kh+fn5+\nGqa/6x3DYbdDz4uKF7eIVBT1ARN4AHxWirGIiCSsPdl5XDtyOi/9sKTQfhcfvD9vXnwg6ZVTC+0n\nIgE/PwW5e9zbjrjH31hEyrPUKs6yf/U8StTbPPjoYlg1Nf6xbF3uvYRGg07Q85L4xyDF5nsS3xjT\nxRjzvDFmFrAdWAxMBiYCEwrZb7Ax5rTAo5ChKyIiEm//z959x0lVXn8c/5zdpfcOikovKoogigIK\nWLD3Hkuixm40JhrLz5IYWzRRo0bs0dhiiTUqFkAFLIBiR7qKCNJ7293z++POZofZO7MzszOzs7vf\n9+s1r91773Ofe5YZ3b33PM953vlmEYfePYFvF62u7lBEarbXL4N1S8KP7fkb2HpAbuNJhxl0Hxms\n+X3BVNjtLKjXJPnzv58Ez5wKf98FJv4d1i+vekxrl8K7N5dvj78p2LfFdT8MrlsaZ/BAy22DJLBG\nIddcHXcMBmEUNQw/vvZneOwIWB0zCj2Zz8
9n/4ax18e/9s4nwogr0wpbROqE6F8+69w9A7/8RERq\nl59XbeD4+z/gtS/izBgECgz+eNgO/PHwHSkqVMFZkaQs/w6mPBJ+rO+hNeM5hEhN0rh1sOxfs63C\nj29eB08cC0tmZTeOt66Bko3hxw64KagcIHknZ3/dmNk2ZvYaMA04H+gXub7FvOL5FfBs5PWimRVm\nNWAREamgtNS5/a0ZnPHoFFZviD9rtkWjIn49rBsdGmv9LJG4vnkVvnw+/FibnjD8itzGkwlte8BB\nt8LvvoFRN0LL7ZI/d+X38NbV8Lft4dVLYPGM9OMYfyNsiFrbb8PKIBFbZuGX8ORx8WceNGkXlGNv\n3in9GCQ/bLdHsBxCvFuHFd8FM/KjB49U9vmZ8y68dH78a3YbDof+XWtIikgiP0d9H782tIhIHfXV\ngpUcfs9EPp8fZ71uoFmDIh7+5SBO27NL7gITqQ3e/QuUbg45YDBCs/BFsqLlNkEiv2GcaoHrl8Hj\nR1acZJAp8yYEVTDD9DkEuu2dnetKleUkiW9mowiS96MoT9R7zKsyt1Oe6G8PHJTxQEVEJK6V6zZz\nxqOTufOdmQnb7bBVc169cBhXHdyXG4c24tf96tOljUpRi2xh/fJgLe1QFpTRrxdn9nBN0LAF7HE+\n/OZTOOFJ6DIs+XM3rwvKl98zCB4/Gma+DaWlyZ+/6GuY8nDF/VMehp+/gWVz4fGjtkzSRmvQPJiB\n36Z78teU/NZrFBxxb/zjP38FT54Am9ZV/vlZ9DX8+5Q4D70IStAd91j+LoMhIvniy6jvm5hZs2qL\nREQkz7z19SKOHf0BP63cELdN51aNeP68PRneu30OIxOpBRbPCJa2C7PzCdC+T27jEalLOmwflNYv\nbBB+fMX38Pgx8Z9Xpau0BF6/PPxYYX3Y/8+ZvZ5kVNbrI5jZYOA/QCPKE/ZGUD5uJtAA6FZZP+7+\nrpnNB7aO7DoceCUbMScSqQAwCNgRaEvwsywGvgCmuHtJrmOKZWbbEcTYAWgJrAPmAh+7+4JacL2c\nvgc14T0XybZvflrF2f+ayvfL1iVsd/SAztxw5I40rBfMeCwsMIZsXY8/nLAXL01bwF1jZzJvaeI+\nROqEN66ENYvCjw0+F7atJSsHFRRCn4OD18Iv4KPR8Pmz8ct3xZr1dvBq2yso07/zidCgafz27jDm\nSvCQpL+XwH9/Dyvnx/+3L2oY3FB12im5+KTm2Pn4YPDMG38IP/5D2fIKmxN8fn4Hy+bBxjg31M22\ngl88G39kvYhIuakEs/HLsk/7Ai9UXzgiItXP3Xnw/bnc+Po3eILpXgO2bcn9p+5K26ZxkiAiEt+4\nG8LvdwqKYO8490oikjnb7QnHPBQ8fwj7b3HRF/Dvk+EXz0FRhn7PffJY0G+YPc6H1l0zcx3Jiqwm\n8c2sMfA85Ql8A+YB1wHPu/taMzsbSDA1ZgsvABdGvt83o8FWwsyaA38AziJI5IZZYmb3A7e4+6qc\nBQeYWX2C2C4AeidoNwn4i7vHqZ2Rn9eL9JXT9yDf33ORXHlp2o/84fnP2bA5/kzYeoXGNYfuwMm7\nb4uFlA8uKizg6IGdObz/Vkrmi8x8K/7I91ZdYeTVuY0nVzr2CyoM7PvHYP29yQ/CmiTLhC2ZAa/9\nPliDfMCpQUK/5bYV280YA3PGxe/nuwnxj1khHPMIdBmSXExS8ww+B9YtgfduDT8+663E5383Mf6x\n+s2CBH6LreO3ERGJcHc3s9HANZFdv0dJfBGpwzaXlHLNS1/x1MffJ2x3eP+tuOXonf43cUBEUvDT\nZ/D1i+HHBpymRJ5IrvQ9FA66LX6FzrnvwQtnw9EPQ0EVi6mvXwFj48y0b9oBhv2uav1L1mW7nP4l\nQCfKy+W/Cwxw98fcfW0a/b0X9f02ZtaxqgEmI1JN4BvgSuInc4kcuxL4OnJOTphZP+Br4C4SJNQj\n9gReNL
Pn0i3Zl+vrRa6Z0/cg399zkVzYXFLKH1/5iouenpYwgd++WQOePmswpwzeLjSBH60smf/2\nJXvz12N3Vpl9qXs2rIJXLop//LC7oH4t/++iSVvY+1K4+As46kHYakDy525YCZPugjt3DkYmfzeJ\n/03TKd4Eb16VflyH3wN9tFpTrTfiKtj19Mz2WVAExz8GHXfMbL8iUtv9BZhDMNlhsJn9qZrjERGp\nFivXbea0hz+uNIH/2317ccfx/ZXAF0lXvEReUUPY69LcxiJS1w06A/a6LP7xr16AMVeQsDRNMt67\nNZjMEGafa6GBVvXKd9lO4v+a8hn4C4Aj3X1FFfr7LGa7bxX6SoqZDQHeAbYKObwBCKsHuzXwjpnt\nmc3YAMxsODARCFu4tQRYDoRl344G3jKzBDVpq/96kWvm9D3I9/dcJBd+Xr2BXzzwEY9MnJew3W5d\nWvPqb4YycLvWKfWvZL7UWW9dDat+DD+26xnQNYW142u6ovqw07Hw67Fwxluww1HBbPhkeCl88wo8\nciDctxdMexI+uheWzkovllE3Qv8T0ztXahazYMT7Dkdmrs/D7oLuIzPXn4jUCe6+DjgE+IngmclV\nZvaomSUaRC4iUqvMW7KWI++dyKTZS+O2qV9UwN9P3IWL9u1Z6cQBEYnjuw9g5pvhx3Y7C5p3ym08\nIgIjroRdTol//KPRMPHO9PtfMjPoI8xWuwRLVkrey1oS38y2B7aJbDpBufEVVex2blR/AF2q2F9C\nZtaeYDmA6OzSZuA2oEdkfyOgJ/DXyLEyjYHnI31kK75tgGeB2OEyTwFDgAbu3hpoAAwHYkva7w48\nlK/Xi1wzp+9Bvr/nIrkw9btlHHrXBD6etyxhu18N6cITv96d9s0apn2t2GS+SK02ZzxM/Wf4sRbb\nwH5/zGU0+cMMttkNjn0kmJ0/9BJo1Cr58xd+Di+eC29dm971h14SrAEmdUdBIRx5f2YS78OvhP4n\nVb0fEalzzGxbYB1wHEGlOwNOBn4ws3+b2TlmNsTMtjezbVN9VefPJiKSjI/mLOWIf0xkzuL4xVrb\nNq3P02cN5rCdw+bZiEhS3OGdOAV/6jeDob/NbTwiEjCDQ+6AXgfEb/P2tTDtqfT6H3MVlBaHHzvg\nlqqX6pecyOa7tFPka9kQyRer2qG7lwDRf9m1qGqflbge6BC1vR44xN0vdffZXm6Wu/8eOCzSpkxH\nIJsl8R5gy1LvJcCp7n6Su0+K/Hvh7sXu/q67HwHE1vA9zsz2z9PrQe7fg3x/z0Wyxt351wfzOOH+\nD1m0KqzgRKBhvQLuPKE/1x66A/UKM/NrpCyZL1JrbVwDL18Y//ihd6qEFQTrie97Lfz26+DfpF0q\nRZfSKDE24DTY55rK20ntU1QfjvsXbL1r+n3scjLsnaD8nYhIYvMIJiq8R1BlsKyKYQPgGOCeyLEv\nIu1Sec3J3Y8hIpK6Z6f8wMkPfcSKdZvjtundoRkvnj+EAdumMMBXRCqa/Q58Pyn82J4XQOPUqmuK\nSAYVFsExj0DnQfHbvHwBzHw7tX5nvg0zx4Qf63csbLt7av1JtclmEj96NvIGd5+foX43UD4wIGv1\nl82sG/CrmN1XuXucujPg7m8AV8fsPt3MumYhvl2BUTG7b3L3fyU6z93/DsTW0LjJKqlHlevrRa6Z\n0/cg399zkWzasLmE3z37GVe/9BWbS+InwrZt3ZgXzhvC4f23zmF0IrXAO3+CFXHWeNzlZOixT27j\nyXf1G8PAX8J5H8CpL0VGJWe6dKYFZQNVkrPuatAU9o+zLmQydjtHnx8RqSqLekGQyPeY/em+RETy\nTmmpc8sb07n0uc8TPnsY3rsdz527B51baek9kSpxh3euDz/WqDUMPi+38YhIRfUbw0nPQNte4cdL\ni+GZU+HHqcn1V7IZxlwRfqyoEex7XVphSvXIZhI/+q+s9XFbpa4F5VOt
VmWw31iXAPWitmcBf0/i\nvDsibcvUA7JRk+asmO0lQLJPIa8gKNtXZgAVE/TVfT3I/XuQ7++5SFb8sGwdR987if98Emed7ogR\nvdvxygVD6dupeY4iE6klvpsEH98XfqxZJ9j/htzGU5OYQbfhcNK/4cKpsPs5Qbm/jHB48/+ChxpS\nN7nDuzenf/5bV+vzIyJV5VGvePtTfYmI5KX1m0o474lPuHf87ITtfrlnFx48dVeaNayXsJ2IJOGb\nV+CnaeHHhl0CDfWMTyQvNG4NJz8fPCcMs3ktPHEcLE38OxSAyQ/Bkhnhx4b+FlqoGm5Nks0k/pKo\n7zNS9t7MtgKKonYtzUS/cRwes/1QWbn4RCJtHo7ZfUSmgooSu4jnU+4ev/51FHdfAbwQs/uoPLse\n5P49yPf3XCTj3puxmEPvnsBXCxKPibp43548dNogWjTWTbRISjatg5cuiH/8kDugUctcRVOztekO\nB94Cl3wNB9wMrTJQ9GbOOJgZt+CO1HYzxsCc8emfr8+PiFTN91l6fRf5KiKSNxat2sBx933AG18t\njNumsMC4/vAduO6wHSjK0NJ9InVaaQmMjTMHr1knGHRmbuMRkcRabhsk8hvESaeuWwL/OhJWL4rf\nx9qlMP7G8GPNO8OeCZb6lLyUzb+Ioj9JBWa2Ywb6HBr5WlYa7rsM9FmBmQ0EYoejPJNCF7FttzGz\nAVWLqpyZtQa6x+yemGI3se0PN7PQz0Ourxe5Zk7fg3x/z0UyrbTUuWfcLE575OOEa9A1a1jEQ6ft\nysX79qKgQFU5RVI2/kZYFmeUbL/joPcBuY2nNmjYHAafG5TajzdCORVjroTiTVXvR2qW4k3w5lVV\n70efHxFJk7t3cfeu2XpV988nIlLmyx9XcvjdE/nix5Vx2zRrUMTDvxzEKXt0yV1gIrXd58/Akm/D\nj+19GdRrlNt4RKRyHXaAE5+Ewvrhx1d8B08cAxviTMgbfyNsiPP7dv8/BaX7pUbJZhJ/cuRrWTm3\ngzPQ52lR368HPs5An2FGxGwvdPc5yZ7s7rPZchADVJzJXhXtQ/bNCtmXyMyQPvvlyfUg9+9Bvr/n\nIhmzesNmznl8KreO+TZhFeA+HZvxygVD2advh9wFJ1KbzJ8CH9wTfqxJu2BWuaTPCqF+k+qOQkRE\nRERE4njzq4UcO/oDFq7aELdN51aNeP68Pdm7V7scRiZSyxVvij8bt1UX2OWUnIYjIinoMhSOeiD+\n8YWfw9O/qDihYNFXMCW2YHTEtnvADskUx5Z8k7UkvrsvBD6LbBpwoZk1Tbc/M9sLOIDydd7ed/fi\nKgcabvuY7XQGC3wUs903zVjCtA7ZF384a7iw9rE/d3VdL+xYtt+DfH/PRTJi5qLVHH73RN78OkHZ\nHeCwnbfiP+ftSZe2SpCJpKV4I7x0Pnhp+PGD/xqsdyXpK6oP+99Q9X5G3Rj0JXWLPj8iIiIiWePu\n3P/ebM5+fCrrN8dfqXLgdq146fwh9OrQLIfRidQBnz4GK+KsrjP8SijUcpkieW2HIxIfn/cePHJQ\neSLfHd64PM5zSIMDbgJTld2aKNsLDP2TIIHvQCfgfrPUPylm1hV4smwz8vXuTAQYR5+Y7aRnZEeZ\nW0mfVRG2Fn2DFPtoGLIvXtI519eD3L8H+f6ei1TZfz//icPvmcicJWvjtiksMK4+ZHvuPKE/jesX\n5TA6kVrm3b/A4unhx7Y/ArY/PKfh1Fq9RkG32GI6Keg+Enrun7l4pGbR50dEREQk4zYVl3LFf77g\nxtemJ6z+d+QuW/PEmbvTpmmqjxhFJKFN6+DdW8OPtesL/Y7JbTwikh0/TobbetBh4VjaLv4A5r4X\n3m6XX8BWu+Q2NsmYbCfx7wXKhnwZcDzwopklvYCpmR1DMMN5q8guB6a5+38zGWiMXjHbP6TRR+w5\nsX1WxfKQfanWnApr3ztPrge5fw/y
/T0XSVtxSSk3vvYN5z/5Ces2xR8B37ZpfZ44c3fOGNqVNMZb\niUiZnz6DCbeHH2vUGg6KczMtqTMLZkJbGn/SWmHkXP3/rs7S50dEREQko1as28RpD3/M05MTP1a7\nZL9e/O24nWlYrzBHkYnUIZMfgDULw4+N/D8o0H93IrXGhpX0nX4n239zW/jx+s1g5DW5jUkyKqtJ\nfHffBJwJlBAk3w04BJhhZo+Z2SlA9+hzzGwbMxtuZteY2efAv4G2UeevA07LZtxAq5jtOL/1Evop\nZrtleqGEWgBsjtmX6lCasPbx6vrm+nqQ+/cg399zkbQsXbORUx76mPvfS1xcYpdtW/LqhcMY3K1N\njiITqaWKN8GL54PHGTBz4F+gafvcxlTbddgedj099fN2PR3aa+WbOk+fHxEREZGMmLtkLUf9YxIf\nzFkat02DogLuOnEXfrNPT00eEMmGDSvjTyrYagD0OTi38YhIThTEew651++hWYfcBiMZlfVaye7+\ntpmdAzxAkIgHaAL8IvKKZsC8mG0oT+CXAKe7+5fZitfMGlFxcMO6NLpaH7NdZGYN3X1DepGVc/cN\nZjYVGBy1+1Agzm/oUIeG7GuaD9fL9XtQE95zkXR89sMKzn18KgtWJv4Injx4W64+ZHsaFGkkrkiV\nTbwDFn0Rfqz3QSpbly3Dr4TPn4GNq5Jr37AFDL8iuzFJzaHPj4jkETNrBwwFdidYlrA10BjA3fep\nxtBEROL6cM5Sznl8KivWxc4BKte2aQMeOHUgu2wbO49GRDLmg3tgfVhRXWCfq1VJTKQuadwWdvt1\ndUchVZSTBY/d/WEz+xn4J8ENaFkyPzpJX7Yd/Zskev8K4AR3fzOrwYYnltNJwoad0yTNvsKMYcuk\n+nAzG+Dun1R2opntB+wccig0qV4N18v1e1Dt73lkkESYPqtXr2b8+PFphJMZq1evBqjWGCR17/6w\nmX99vYniBOvPFRXAadvXZ1jLpXww4f2sxKHPj1RFTfv8NFkzj4FTbwktc1Rc2ISPWx/DpnffzXlc\ndUXnzsfQY/bDSbWdtfUxzJ8cZ7CF1En6/NRMZb8nRGoDM9sduAw4jIqDzI3y5yNh574ObB3ZnObu\np2YlSBGREM9M+YGrXviCzSXxH0D06diMB0/blc6tGucwMpE6Zu3SIIkfpssw6DYit/GISPVatwRG\nD4W9LoN+x0JhTtLBkmFZLacfzd1fJUjk3g9spGICv+x7Z8vkfQlB8r9/DhL4AA1D9m1Ko5+NIfsa\npdFPPKPZMi4DHjGz5olOiozqvy/O4UTx5fJ6uX4Pasp7LlKpzaXOI19u5JGvEifw2zQ0/m/3hgzr\nXC93wYnUYlZaQu9v76LAi0OPz+pxOpsaaLmKbPpx64NY12irStuta7Q1P259UA4ikppEnx8RqS5m\nVmRmtwETgSOAQrac3JCMd4EdI6+TzGy7jAYpIhKitNS56fVvuOy5zxMm8Ef0bsez5+yhBL5Itk34\nG2xaE35spGbhi9RJy+bAi+fAPYNg2lNQEv7cUvJXTodeuPuPwDlm9n8EN6d7EZSI6wA0I7hRXQ8s\nAT4D3gZecvfvchhm2Kzp+mn00yDJvtPi7gvN7G7gkqjdOwHvmNkp7j499hwz2wV4HOgap9s4v+Vz\nfr1cvwfV/p67+8Cw/WY2tVmzZgOGDx+eRjiZUTYDtjpjkOQsWLGecx+fymfzE68GMaxnW/5+wi60\napLOxzw1+vxIVdSoz8+EO2D1rPBj3fehzwnX00c3zNm39e3w1PEJmzQ+8nb27rVvjgKSGkWfnxqn\nWbNm1R2CSJWYWRHwCrA/5bPty5YTDJv4EM/9wJ8oHwBwAnBLpuMVESmzblMxv/33NMZ8tShhu9OH\ndOWqg/tSWKB7IZGsWvkjfPxA+LFeB8C2u+c2HhHJL2XJ/PXLYY/zqjsaSUG11E9w9yXAg5EXAGZW\n
ANTPg/XDwxLLYTO1KxN2TtwkeZquAPYGohPAuwJfmNk7wAfAMqANMAwYTnn1hRJgPBC9pt6KPLle\nrt+DmvSei4SaNHsJFz75KUvXJi4icd7w7vxu/966gRbJpMUzYNyN4cfqN4ND79SI91zpNSooEThn\nXPjx7iOh5/65jUlqDn1+RCT37gVGRb4vS97PBp4DPgGGAL+prBN3X2Zm44D9Iv2MQkl8EcmShSs3\ncOZjk/nyx1Vx2xQWGH88bAdOHqzCICI58d6tUBJWJBYYcVVuYxERkYzJm0UQ3L2UDM5Ur0Ic682s\nlC2XGkin3lNsGfWSTA9QcPdNZnYQ8DJBRYMyRQQ37aNCTwxu6i8AepNCEj9X18v1e1CT3nORWO7O\nA+/P4ebXp1OaYI5O0wZF3HbszhywY8fcBSdSF5SWwEvnx79Z3v9P0HKb3MZUl5nBATfBvXuCl8Yc\nK4RRN2pAhcSnz4+I5JCZ7Q6cQXnyvhi4DLgr8nwEM2udQpf/JUjiG7CHmdV393SWiRMRievLH1dy\nxqOTWbQqzv0P0KxhEf/4xQCG9WyXw8hE6rBlc+DTf4Uf2+Eo6LRTbuMRkfzTuhvsdRn0O7a6I5EU\nFVTepE5aEbOdTtarU8z28vRCSczdfwZGEIyyT1xDOzAf2M/dR1Mxxvl5dL0VMdvZfg9yfT2RKluz\nsZgLnvyUG19LnMDv3q4JL54/RAl8kWz4+H6Y/3H4sS7DYMAvcxqOAO37wq6nV9y/6+nBMZFE9PkR\nkdy5LvK1rIz+qe5+Z1kCPw2fRH1fn2AQvYhIxoz5aiHHjv4gYQJ/m9aN+M+5eyqBL5JL42+G0pB1\nrq1Qs/BF6rrW3eCI0XD+ZOh/IhTmzbxuSZKS+OFmxGynM4Uu9pyZacZSKXdf7+6XA90JSu29CswB\nVgPrgVkE6+z9Aujl7u9ETo19Ejk1j66X6/egRr3nIrMXr+HIeyby3y9+StjuwB078tIFQ+nRvmmO\nIhOpQ5bNgbf/GH6sXmM47C4o0J9a1WL4ldCwRfl2wxYw/Irqi0dqFn1+RCTLzKwJweB4j7yed/d/\nV7HbLyJfy4b39qpifyIiQFABcPS7sznn8ams31wSt92gLq148bwh9OzQLIfRidRxi76Gz58JP9b/\nJGjbI7fxiEh+UPK+1tA7F246MDhqu1safXSN2f4m/XCS4+4Lgbsir4QiDw12iNk9JY+ul+v3oEa+\n51I3jflqIb975jPWbAwZZRtRYHDZAX04e69umEr/imReaSm8/BsoXh9+fJ9roXXsrwXJmSZtYO/L\nYUwk8Tr8imCfSDL0+RGR7BtKMFsegqT73VXt0N1XmlkxUBjZ1b6qfYqIbCou5f9e/IJnpiQu3nnk\nLltz89H9aFBUmLCdiGTYuBsoH78XpbA+7P2HnIcjItUsumy+Eve1gt7FcF/HbO+WRh+7x2znW0J3\nCOU39xCUk4/9uavzerl+D+rCey41XEmpc/tbM7h73KyE7Vo3qc9dJ+7CkB5tcxSZSB009RGY9374\nsW0Gw25n5TYeqWjQmcz99ksAug46s5qDkRpHnx8Rya7OUd87MClD/a4CWkf61FRYEamSFes2cc7j\nU/lwzrKE7X6/fy/OH9FDEwhEcm3+VJj+avixXU+HlukUmhWRmmhdo040HnWtkve1UFrvppmdmulA\n0uHuj2Wp63Ex2x3NrJu7z0nmZDPrBnSopM/qFvsePuHu8Wti5f56uX4P6sJ7LjXY8rWbuOjf03hv\nxuKE7Xbq3IJ7Tx7I1i0b5SgykTpoxffw1jXhx4oawuF3q4x+Piiqz3ddjgOga2G9ag5Gahx9fkQk\nu6JH265w9/gltlJTn9DpeCIiqZmzeA1nPDqFuUvWxm3ToKiAvx3Xn4N36pTDyETkf8ZeH76/XmMY\n9rvcxiIi1aN1N75pfyg/t9+bvfvvU93RSBakOyTjn+THjWG2kv
hTgR+BraP2HQfcnOT5x8dsz3f3\npNabzwUzaw8cGbP7oTy7Xq7fg1r9nkvN9uWPKznn8anMXx6nbHfE8btuwx8P34GG9VS+TiRr3OGV\ni2DTmvDjI66Etj1zG5OIiIjUNNF/2DfMRIdmVh9oSvmzmqWZ6FdE6p5Js5dw7uOfsHL95rht2jZt\nwIOn7Ur/bVrmLjARKTf3PZgTZ/7Y4HOhqVbVEanVosrmL3p/QnVHI1lU1WliVo2vrHF3B16K2X2G\nmVWaGYu0OT1md2xf1e1WoHHU9vvu/mk+XS/X70EdeM+lhnp+6nyOvndSwgR+/cICbjqqH7ccs5MS\n+CLZ9unjMHts+LGtBsDg83Mbj4iIiNRE0eW1GplZywz0uUvka9nzksQlvEREQvx78vec+tDHCRP4\nfTo246ULhiiBL1Jd3OGdOLPwG7aAPS/MbTwikjutu8ERo+H8ydD/RJXOrwOqksSv7Qsd3Q5El7Tr\nAfwmifMujrQtUxzpKy4zG25mHvManlq4yTGz09iytP1m4LxsXCsD18vZe1BN1xOJa1NxKde89CW/\ne/YzNhaXxm3XsXlD/n32YE7cbdscRidSR61aAGOuCj9WUA8Ov0d/PIuIiEgy5sZsD81AnwfEbH+S\ngT5FpBYoLinloQlzeWjCXIpLwp8vlJQ6N732DX94/guKS+MXX92nT3ueO3dPLeEnUp1mjIH5H4cf\n2/M30KhVbuMRkexT8r7OSvedfjSjUeQhd59lZv8Ezoza/Wcz+9Ld3wo7x8xGAbHD4B5x99nZiNHM\nOgJ93H18Em0LgYsIZsVHu9Xdv8zH6+X6PagJ77nUDYtWbeC8Jz5h6nfLE7Yb3K01d580gLZNG+Qo\nMpE6zB1e/S1sXBl+fO8/QIftcxuTiIiI1FSTgZVA88j26cCr6XZmZg2BcwhK6Rsw293nVzVIEakd\nXpq2gOtf/RqAlo3qcfTAzlscX7epmIuensZbXy9K2M8ZQ7ty5UF9KSyo7fO6RPJYaSmM/XP4sSbt\nYPdzchuPiGRXVNl8Je7rprTedXf/VaYDyVNXA4cBZYvINAZeNbM7gXuBeZH93QhumC8C6kWd/zNw\nTRbj6wiMM7PpwH+At4Fp7r4c/pdI3wY4GPg1sHPM+W8D1+Xx9SD370G+v+dSy308dxnnP/kJi1dv\nTNju18O68ocD+lBUWNVVUaTaFG+CiXcG3w+5CIrqV288ktgXz8KMN8KPdewHQy/OaTgiIiJSc7l7\niZm9ARwf2XW4me0Xb/B4Ev4OdCjrHni+qjGKSO1QXFLKXWNn/m/7rrEzObz/Vv97lvDTyvWc+egU\nvlqwKm4fhQXGnw7fgV/svl3W45VqMO0pJYdqkq9fgEVfhB8b9nto0DS38YhI9hwxWv9/lrRn4tcJ\n7r7QzI4BxgBldaLqA5dGXhsi+xqGnL4eONrdF2Y9UOgDXBl5YWbrI9dvSfwlE94CjnT3+Itc5cH1\ncv0e1KD3XGoZd+efk+Zxw3+/SVi6rnH9Qm45eicO3XmrHEYnWTH5QRgXGT3doCkMPrd645H41vwM\nr18WfqygCA7/BxTWCz8uIiIiEu5G4FiCmfMGPG1mB7v7h8l2YGYG/JWgmlzZLPwNwN8yH66I1EQv\nTVvAvKXr/rc9b+k6Xpq2gKMHduaL+Ss587HJLFoVfxJBs4ZF3PuLgQzt2TYX4Up1ePEceO8vmulZ\nE5QUw9gbwo817wy71pV5lyJ1RP8TqzsCyQOawlkJd38f2A/4KeRwQ8KTuT8B+7r7hGzGlkAjoDXh\n7+8m4E/Age6+tiZcL9fvQQ19z6UGW7+phIv/PY0/vvJ1wgR+17ZNePH8IUrg1wZrl8K7N5dvj78p\n2Cf56bXfw/o4y1sM/S102im38YiIiEiN5+5fECxVaAQJ+FbAu2b2NzPrnuhcM2tgZkcD0wiqwxHV\nz23uvjhrgYtIjRE7C7/MXW
Nn8t/PF3DsfZMSJvC3bd2YF84bogR+XbBsTpDMv2dQMDO/pLi6I5Iw\nnz0Jy+Ks4Dr8cijScpsiIrWNkvhJcPeJQF/gZmBZgqbLIm36uvukHIQ2lyBBPgUoqaTtMuAegtiu\ndffK2ufD9f4n1+9BHr/nUst8t3QtR/5jIi9NW5Cw3b59O/DSBUPo1aFZjiKTrBp/I2yIWlt9w8og\nkS/556sX4euXwo+16wt7XZrTcERERKRWOR/4mPIEfD2CpPwMM5sN/Ca6sZk9bGbjgaXAM0C/qHMh\nvSXsRKSWip2FX2be0nWc/+SnbNhcGvfc3bq05sXzh9CjvUpz1ylK5uev4o0w/pbwY216wM6asSsi\nUhupPk6S3H0lcIWZXQ0MIrhZLhuKuhj4Epjs7in/dePu4wluvNOJ6VrgWjNrQrAGfXeC9dwbAxuB\nhcBXwKfuHv+v8zy8XpzrZ+U9yIfrSd0zbvrPXPT0p6zaEP8jZAa/268X5w3vQUFByv+bkHy06GuY\n8nDF/VMehkFnQPu+uY9Jwq1dGszCD2MFcPg9GukuIiIiaXP3DWZ2CPAyMJjykvgAXSlPzhPZf1rU\n90QdN2A8cJy7xy/tJSJ1RrxZ+Mk4asDW3HRUPxoUFWY4KqkxypL5KrOfP6Y8Aqvmhx8bcaXeHxGR\nWkr/d09RJGH7QeSVNyKl6idFXrXuejHXzul7kK/vudRcpaXOXWNnccc7M0j0iK1Fo3rceUJ/hvdu\nn7vgJLvcYcyVEDbGyUuCYyf/Jxi9IdXvjcthbZxqtHtcAJ0H5jYeERERqXXcfYmZDQNuIpiFX48t\nk/dhdwzRyfti4G7g0qpWwBOR2iPeLPzKXDqqN+cN747pnlRAyfx8sXENvH9b+LEO/WD7I3Mbj4iI\n5ExWf+uaWWPgespHiU9w9/+k2dfRwJDIZilwhbtvrnqUIiK5s3L9Zi759zTemf5zwnbbd2rO6JMH\nsm2bxjmKTHJixhiYMy7+8dljYeab0GtU7mKScN++Dl88E36sTY9gpLuIiIhIBkSS75eZ2Z3AJcBx\nwNaVnLYUeAG42d3nZDlEEalB0pmF37BeAX87rj8H9euUpaikRlMyv3p9NDr+BIN9roYCrZgsIlJb\nZfu37YnAbykfJf50Ffr6Abg4qq+pwFNV6E9EJKemL1zFOf+aWulo+KN22ZobjuxHo/oqXVerFG+C\nN6+qvN2YK6H7SCisl/2YJNz6FfDKxXEOWlBGv16jHAYkIiIidYG7/wj8DvidmXUDdgc6AK2B+sAy\ngqXdPiNYwk6l80WkglRn4bdr1oAHT92Vnbdpmb2gpHZQMj/31i+HiX8PP7bN7tBz/9zGIyIiOZXt\n37InRb4aMMXdP063I3f/2MymAgMJEvmnoiS+iNQQL3z6I5c99xmbS+I/ZysqMK4+ZHtO3WM7la6r\njSY/AEtnVd5u6Sz4+AHY47zsxyThxlwFaxaGH9v9bNh2cG7jERERkTonMrteM+xFJCWpzsKvV1jA\n8+fuwbatm2QxKql1lMzPnYl/h40rw4+NvFrLMYqI1HJZq7ViZo0Iyt975PVsBrot68OAvcysQQb6\nFBHJmvWbijn1oY/47b+nJUzgt2vWgKfPGsxpe3ZRAr82WrsUxt+cfPt3bw7Okdyb9TZMezz8WMvt\nYJ9rchuPiIiIiIhIklKdhb+5pJTJc5dnMaI8MO0pKCmu7ijy0+b1VTu/LJl/zyD9O2fDmp+DUvph\nuo2ArsNyG4+IiORcNofI7URQ7g2CJP57Gejz3ajvGwL9gCkZ6FdEJKOKS0p5/MPvuOn16WwsLk3Y\ndlCXVtxz0gDaN2+Yo+gk58bfCBtXJd9+w0oYfxMcfFv2YpKKNqyCly+Kf/ywu6C+ZqiIiIiIiEj+\nSXUWfpm7xs7k8P5bUVRYS9fVrsszxt1hzSJYPq/ia9nc+BXoUqWZ+dnx/l9hc5xBOftcndtY
RESk\nWmTzt2mfmO3PMtBnWR9l01l7oyS+iOSR4pJSXpq2gNve/JafVm6otP0v9+zCVQf3pV5tvVkWWPQ1\nTHk49fOmPAyDzoD2fTMfk4R7+1pYNT/82MBfQbe9cxuPiIiI1FpmdgvwiLtPr+5YRKR2SHUWfpl5\nS9fx0rQFHD2wcxaiyhO1Ocm8eT2s+D5Iyocl64urONs+FWX/zuuXa4nAqlrxffxnSX0Oga0H5jYe\nERGpFtn8a6VN1Pdr3b3ybFYl3H2Dma0ByqbBta1qnyIimVCWvL9r7Mykbpob1ivgpqP6ceQutfgm\nWYJR72OuAE9cjSH83BJ44wo45QWtcZYLc9+Lf4PcvDPs96fcxiMiIiK13aXA781sMvAI8LS7x1n0\nVkQksXRn4Zep9bPxy9TEZL57UFb9f4n5mGT96p+qNTzJkndvgZJNIQcMRv5fzsMREZHqkc2/UqLX\nq9+cwX6j+1JNWxGpVqkm7wG2ad2I+07ele23ap7l6KTazRgDc8anf/6ccTDzTeg1KmMhSYhNa+Hl\nC+MfP/ROaKj/XkVERCTjDBgUed1uZi8CjwJvursnOlFEJFq6s/DL1InZ+NHCkvnVqWw2fWy5++qY\nTV8VDVsGA+D7/6K6I6nZlsyEaU+FH9vpOFVsFBGpQ7KZxF8e9X0LMytwT2cqYjkzKwRaRu1KYYFh\nEZHMSSd5D7B3r3bceUJ/Wjaun8XoJC8Ub4I3r6p6P2OuhG4joEifmax55/rgwUiY/r+AnvvmNBwR\nERGpM8oS9QY0BI6PvH4ys0eBR919RnUFJyI1Q1Vn4ZepM7Pxo0Ul8zu0P5Sf22dpCbUKs+nnbTmr\nvrbMpt+wAib8DVp0hh77VHc0Nde4G4PqjLEKimD45bmPR0REqk02k/hLor43YEfg8yr2uUOkr7Ib\n3cVV7E9EJCXpJu/LPPzLQRQWqDS6SN74/kP4aHT4saYdYdQNuY1HRERE6oongSOAxpHt6IT+VsDl\nwOVm9hHwMPCMu2sig4hUUNVZ+GXq3Gz8aMvm0HfZnWz33TPQ+tr0yuxv3hCZTR9nbfrNVX+PaoTl\n8+Dxo2DHY2DUjdCsQ3VHVLP89Dl89Z/wYwNOhdbdchuPiIhUq2wm8b+OfC27ET2IqifxD4l8LUvk\nz6pifyIiSalq8r6MEvh1SEER9NgXllbxV1VRw+AhQLvemYlLym1eDy+dT/mfKjEOuR0atcppSCIi\nIlI3uPvJZtYUOAE4FRhadijytezGYffI604zewH4p7u/ndNgRSRvZWoWfpk6ORs/SuP1P1Uss1+W\nzHeHtYvDy90vnwerF1Rb3AnVawytusS8ugZfW24LN6SRZC+sH2e99ihfPgcz34J9r4WBv4KCuvmZ\nStnYP4fvL2oIe12a21hERKTaZS2J7+5fm9lCoAPBzefFZnaXu69Np7/Ize1FBDe0Bix396kZC1hE\nJESmkvdSxyyYBq/9HuZPrnpfi76Ee/eEwefB3pdBg2ZV71MC42+KP8hix2Ogz0G5jUdERETqFHdf\nAzwIPGhm3YBfAScD25U1iXw1oBFwInCimc0HHiMot6/JDSJ1WKZm4Zep07Pxo5WV2X/jD9ByOyjZ\nDCu+y9/Z9M06lSfmo1+tu0KTdmAZmlDSulswuGGHI2HqI0HCedOa+O03roT/XgKfPQ2H3gEddshM\nHLXV9x/BzDHhxwadCc23ym08IiJS7bI5Ex/gVeBMghvPdgQ3pyem2deDkT488vpvJgIUEQmj5L2k\nZd2y4CZ2ysPEnd2djtJimPR3+OJZ2P/PsOPRmbsJr6t+nAqT7go/1rgtHPiX3MYjIiIidZq7zwGu\nBq42sxEECf0jgSZlTSJfDdgGuBK40swmAY8QlNtPkEkRkdom07Pwy9T12fhb2LASFla1sGwGFDXa\nMjEfnahvuS3Ua5Td65cl76MrEww+F/oeBq9fBtNfTXz+
/I9h9DDY8wLY+w9Qv0ni9nWRO7zzp/Bj\n9ZvC0EtyG4+IiOSFbCfxbwZ+CRQS3GgeZ2YOnJPsWm5m1gy4HziO8ln4JcBN2QhYRATg0Q++4/pX\nv668oQhAaSl8+i94+zpYvyx711n9Ezx/Bkz9Jxx0K7Tvm71r1WbFG+HF88FLw48ffBs0aZPbmERE\nREQi3H0cMM7MmhA8CzkNGEb50oJQXm5/z8jr72b2PMHs/LE5DllEqkGmZ+GX0Wz8atKsU/yy903b\nV89A/rDkfbQWW8MJT8D01+C1S2HV/Ph9eQlMvBO+fAEO/iv02j97cddEc8bBdxPCj+1xgZ5RiIjU\nUVlN4rv7HDMbDVxAeQL+eGBvM7sbeNLdvws718y2A34BnA90LNsd6edBd5+ezdhFRESS8uMnQen8\nH3O4wsu89+HeIcHI973/AA2b5+7atcF7t8Hib8KP9T0sKA0oIiIiUs0iyxE+AjxiZl0JJkmcDHQt\naxL5akDjyLGTzex7d++KiNRqRw/sXCHRvnztJgb++S1KQwrDvXLBUPp1bpGj6KSC6Nn0FV7bZX82\nfSoqS97H6nMQdN0rWLLuw3uDhH08K7+HJ4+F7Q+HA26B5p0yF3dNlWgWfqNWsMf5uY1HRETyRrZn\n4gP8DtgFGEJ5Ir8T8Gfgz2a2CJgFrIgcbwV0Z8vEPVHnTgJ+k4O4RaQOO22P7WjZqJ7K6Ut865bB\nO3+EqY+SVOn8pp1gzU/J9d1rFMydAJvXxm/jJfDB3UGJ/f2uh52OU4n9ZPz0OUz4W/ixRq3goNty\nG4+IiIhIEtx9LnAtcK2Z7U1Qbv8ooGlZE8qfn2yb+whFJB+8N3NxaAK/XbMG7LCVBn9nXdOO8cve\nN+2Q//fsqSbvozVoCqNugJ2Oh1cvrnyiw9cvwayxsM81MOgMKChMO+wab/qrsODT8GNDL9HEDRGR\nOizrSXx332xmRwAvAEOpWPqtI9Ah5rTov2ii278PHOXum7MTrYhIoKiwgKMHdubw/lvx0rQFSuZL\nudKSoJz92Oth/fLK2zdoASP/D3Y5Ge4bBktnJW7fpicc/wSs+RnevAq+eiFx+zWL4IWzykvsd9wx\n2Z+k7inZDC+dB6XF4ccPuAWaxf5JIiIiIpJf3P1d4F0zOw84hmCG/t5smcgXkTpo7PSfQ/eP6N2O\nggL976HqDJp1hI79goR3dMn7lttC/cbVHWB6qpK8j9VpJzjjLZjycDC7fGOCFXU3rYbXL4XPnoJD\n74BOO1ft2jVRaQmMvSH8WNOOsNuvcxuPiIjklVzMxMfdl5rZCIJ17H8D1CfxtMXoYwZsBP4GXOOe\nqB6PiEhmxSbz73hnBj8sW1/dYUl1+WFyUDr/p2nJtd/lZNjnOmjaLtje/wZ46vjE54y6AQrrBWvL\nHftPGPhLeO0yWPJt4vO+nwT37RXc4A2/Ahq1TC7GumTiHbDwi/BjvQ4IqhmIiIiI1ByFQIPIq2z5\nQRGpo4pLShn/7eLQYyP7aLBylTTbCob8BnY9A4rqV3c0mXXE6Mwk76MVFAbPJvocAm9cDl+/mLj9\ngk/g/uGw+7kw4spgVn9d8cVz8Zf72/vS/FpmQUREcq4gVxdy9xJ3v4xg7ba/AfMIbjITveYCfwG6\nuvtVSuCLSHUpS+ZfNqpPdYci1WHtEnjpfHho3+QS+J12hjPehsPvKU/gQ1Amv9uI+Od1Hwk9999y\nX7fhcM6EoGR+/UpuZL0EPhoNd+8K056E0tLKY60rfv4G3v1L+LEGzeGQ2/O/tKGIiIjUeRbY38ye\nABYCo4HBKIEvUud9+sMKVq6vWLy0XqExtGfbaoioFmjdLUhyX/wFDD639iXwAfqfmNkEfrTmneC4\nR+GkZ4NKBYl4KXx4D9yzO0z/b3biyTfFm2D8jeHHWm4Hu5ya23hERCTv5GQmfjR3/wn4PfB7M+sM\nDATaAa0jTZYCi4Gp
7v5jruMTEUlk0uyl1R2C5FJpSVACbuz1sGFl5e0btgzWcxv4y/D13MzggJvg\n3j2DG9QtjhXCqBvDE8lF9YNR//2OgTevhi+fSxzH2sXw4rmREvu3BeXs6rKS4mAQRsmm8OOjboDm\nW+U2JhEREZEUmFlvgrL5JwNlf7iU/eEYu2xhJes3iUhtFK+U/u5d29C0Qc4fAeeHjavTOy+T5eUF\neu0PXT6Cd2+BD+6Ov8QdwKr58PRJwSz+A2+BFp1zF2euffovWD4v/NiIK2vnoBEREUlJtf4V4u7z\ngfnVGYOISComzloSur9t0/osWRMnQSg10/cfwWu/i19+fQsGA06Ffa6FJm0SN23fF3Y9HSY/uOX+\nXU8PjiXSfCs45qFIif1L45dcK/PDR3D/3kHJv5FXQaNWlf4ktdKH/4Afp4Yf6zYCdjklt/GIiIiI\nJMHMWgInECTvB5Xtjnz1yKuskuFq4Fngn+4+IaeBikheGBcniT+iT/scR5InlswKksEpWNeoE41H\nXavkfTbUbwz7/TFYxu6Vi2H+x4nbT38V5oyHEVfBbmfVvvdj83p479bwY+36BJ9BERGp83JWTl9E\npKb7fuk6vl+2LvTYq78Zyl+P3ZkubRrnOCrJuDU/wwvnwsP7J5fA32oA/PodOOzvlSfwywy/Ehq2\nKN9u2CJYxz5ZXYfBOe8HM/frN0vc1kth8gNw10D45F91r8T+klkw7obwY/WbBu+byuiLiIhInjCz\nAjM7yMz+DSwA7iFI4Jcl66OT9wBjgVOBju5+phL4InXTjyvWM31h+KzzkXUxiT9jDDwwEpZ8m1z7\n1t34ps9FTB50T3bLywt02AFOHxMsaRf9XCTMpjUw5gp4YAT8+Elu4suVyQ/C6p/Cj424Kry6o4iI\n1DlK4ouIJOn9WYtD9/ft1JyOzRtx9MDOvH3J3krm11QlxfDh6CDZ/dmTlbdv1BoOvRPOfAe2Hpja\ntZq0gb0vL98efkXyAwDKFNaDPc6HC6fCTidU3n7dUnj5AnhoP1jwaWrXqqlKS4OfuXhD+PH9/lj5\nunwiIiIiOWBmO5jZrQTVCl8BjgEaEl4yfw5wLdDV3fd198fdfX2uYxaR/BFvFn63tk3o2rZJjqOp\nRqWl8O6t8OTxsDGJJfHK1rw/fzKLOo7ElTjNjYKCoBrhBVOSm3G+8HN4cB947TLYsCr78WXbhlXw\n/t/Cj3XqD30PzWk4IiKSvzSsUEQkSRNmhpfSH9az7f++Lyos4OiBnTm8/1a8NG0Bd42dybyl4bP3\nJY/MmxiUp//5qyQaW3CzOfL/oHHr9K856EzYtLb8+3Q16wBH3Rcpsf97WPRl4vY/ToH7RwTt97mm\naj9Dvpv8AHz/Qfix7YbCwNNzG4+IiIhIFDNrDZxEUC5/l7Ldka+x5fLXUF4u//3cRioi+W6sSunD\nxtXw4rnwzSuVt225XTCYXmXzq1fT9nD0g7DzifDf38HyufHbeil8fB988zIceAv0PazmVtX78B+w\nfln4sX2uqbk/l4iIZJz+ShERSUJJqTNp9tLQY0N7tK2wT8n8GmL1QnjzavjimeTadx4EB90GW/Wv\n+rWL6sPel1a9nzLb7QFnvQtTHoKxN1Qy68Bh6iPw9UvBDeKAU2tfqbZlc+Ht68KPFTUKyugXqCCR\niIiIVA8zex44GKhHeeIetpxx78B44J/Ac+6uGwoRqWDD5hImzQ6fdFBnSukvnQ1PnwSLpydu16BZ\nsCzdzicpeZ9PeuwD530A790GE++E0s3x267+CZ45FXqOgoNuhVbb5S7OTFi3DCbdHX5su6HQfWRu\n4xERkbymp9ciIkn44seVrFxf8SaiflEBu3WNP5O5LJlfVmZf8kTJ5uCm6a5dk0vgN24Lh98Dp7+Z\nmQR+thQWwe5nw4VToP8vKm+/fhm8enFQlm7+1KyHlzPu8PKFsDnOc+59roE23XMbk4
iIiMiWjgTq\ns+U690S25wLXAd3cfaS7P6YEvojE88HspWzYXFphf9MGRQzqUosrr5WZ+VZQba6yBP7OJ8Jlc4NB\n7Erg5596jWCfq+GcCbDtnpW3nzkG/jE4SPqXJEj655sJt8Om1eHH9rlas/BFRGQLSuKLiCRhwszF\nofsHdWlFw3qVz2AuS+ZLHpj7PoweBm9eFf/GqYwVwG5nBUnxXU6uOTO3m7aHI/4RDDrouFPl7Rd8\nGiTyX74Q1oZXnKhRpv4T5sWpMtt5t2Cgg4iIiEj1i07crwUeBYa7ew93/5O7f1d9oYlITRGvlP7Q\nHm2pX1RD7mHT4R7M3H7i2MSV6Oo3g1++BkeOhsJ6uYtP0tO+D/zyv3DY3dCoVeK2m9fBW9fA/cPh\nh8k5Ca9KVv0EH98ffqzn/rDt4NzGIyIiea8W/yUnIpI5788ML003tEe7HEciaVu1AJ47HR49BBZ/\nU3n7bQYH5ekPurXyG8d8te3ucNb4YAmAhi0qaezwyWNw1wCY/CCUluQiwsxb8UOwREKYwgZBRYXa\ntnSAiIiI1GTvAr8COrr7r9z9veoOSERqDnePm8Qf2bcWl9LfuCYoqT72esrHQ4XYele44GPoMiRn\noUkGFBTAgFPggilBBYXKLPoSHtoPXr0E1q/Ienhpe+9WKN4Qfmzk/+U2FhERqRGUxBcRqcTajcV8\n8v3y0GPDerbNcTSSsuJNQXm1uwfBl89X3r5JezhiNJz+BnRKYhZ7visohN1+DRd+ArucUnn7DSvg\nv7+DB0bADx9nPbyMcg+WB4hXYWHEFdCuV05DEhEREYnjj0B3dx/h7o+qXL6IpGPGojX8uGJ96LHh\nvWvppIOls+HBfeGblxO32+UU+NVr0Hyr3MQlmdekbVBB4dSXoU2PSho7THkI7tktePbjCQZ3VIdl\nc+GTR8OP7XAkdNISnCIiUpGS+CIilfh47jI2l1T84791k/ps36l5NUQkSZszHkYPCcqrbVqTuK0V\nwuDzIuvJn1j71iFr0hYOvxvOfAc69a+8/U+fBSPZXzwf1oQvJ5F3PnsKZr0dfqxTf9jjwpyGIyIi\nIhKPu//R3edVdxwiUrPFm4W/U+cWtG/WMMfR5MDMt4MB54mq6xUUwcF/hcPugqIGuYtNsqfb3nDO\nRBh+BRTWT9x2zaKgCuPjRweJ83wx/mYoLa643wpgxFW5j0dERGoEJfFFRCoRr5T+nt3bUFBQyxK9\ntcXK+fDMafDY4bBkRuXttxsC57wPB9yURNn5Gq7zrvDrsXDI7cktEzDtcbhrIHx0H5SE3HDmi1U/\nwRuXhx8rqAdH/AMKi3Ibk4iIiIiISBaNi5PEH9G7lpXSd4f3/wpPHAMbVsZv16Q9nPYqDDqz9g3M\nr+vqNYThl8O5k6DLsMrbz34H/jEY3rstqNBYnX7+Bj7/d/ixnU+Ctj1zG4+IiNQYSuKLiFRiwqzw\nWcgqpZ+HijcGN/Z3D4KvX6y8fdOOcNSD8Mv/Qocdsh5e3igohF1PD0rsD/wVUMnDjY0r4fXL4P7h\n8N0HuYgwNe7w30viP8zZ69K69f6KiIiIiEitt3LdZqbGWfpvZJ9alMTfuAaePQ3e+ROQoET61gPh\nrPGw3R65ikyqQ9uecNorwTKIjdskblu8AcZeD/cNq95nGeNuIPSzW1APhv8h5+GIiEjNoSS+iEgC\ni1ZtYMai8DLsQ3ooiZ9XZr0N/9gjuLHfXMmSogVFsMcFcMFk2OnYujtCv3FrOPSOYGb+1gMrb7/o\nC3jkAPjP2bB6UdbDS9qXz8O3r4Uf67AjDP1tbuMRERERERHJsndnLqaktGJisG3TBvTbupZUmFs2\nJ1jm7euXErfb5WT45WvQYuvcxCXVyyxYBvGCKcF7X5nF04NnGS9fCOuWZT++aD9OhW9eCT+26+nQ\nctvcxiMiIjWK6sqKiCQwIU4p/a5tm9C5VeMcRy
OhVnwPb1wB019Nrn2XYXDQbdC+T3bjqkm2HgBn\nvB2Uzn/7Oli3NHH7z58OkubDr4DdzqreMvVrFsNrl4Yfs0I4/B4oqmTNPBERERGpMjMrBAYBOwJt\nCco9LQa+AKa4e0k1hgeAmW1HEGMHoCWwDpgLfOzuC7JwvQ6R620duR7ACuDHyDXDa6GLJGHsN+ED\nq0f0blc7lv6b9Xawtnmi8vkFRXDAzSqfX1c1bh3c8+98Erz6W1jybeL2nzwG01+DUTfCTsfl5jMz\n9s/h++s1hmG/y/71RUSkRlMSX0QkgQmzwpP4QzULv/pt3gCT7grK5xevr7x9s61g1J9hh6N0cx+m\noAAGnAp9DglKvU15GLw0fvuNq2DMFfDpv4JBEV2G5C7WaK9fCuvjjKQfejFs1T+X0YiIiEgdZ2bV\nnqgG3N1z9rzHzJoDfwDOIkjeh1liZvcDt7j7qlzFBmBm9QliuwDonaDdJOAv7l7JlN9Kr9cIOAM4\nm2BAQ6K2XwCjgYfdfUMK1xgOjEs/SgDedffhVexDqklJqfPujPCl/2p8KX13mHhHUGUv0T1pk3Zw\n3GOw3Z45C03yVJchcM4EmHQnvHsrlGyM33bdEnjhLJj2BBz8N2jbI3txzZsAs8eGH9v9HGjWIXvX\nFhGRWkHl9EVE4nD3+En8nkriV6sZb8I/BsO4P1eewC+oB0MuCkrn73i0EviVadwaDv5rsJZg590q\nb//z1/DPg+D5M2HVT1kPbwtfvwxfvRB+rG1v2Ouy3MYjIiIiEsw+z4dXTpjZYOAb4EriJ/CJHLsS\n+DpyTk6YWT/ga+AuEiTwI/YEXjSz58ysWZrX2wP4LHK9hAn8iH7APcCnZpbE+lYigWk/LGf5us0V\n9tcrtJr9vGLjGnj2l0GFuEQJ/K0GBPesSuBLmaL6sNelcN4H0G1E5e3nvgv37gnjb4HiBEn/dLnD\nO9eHH2vQAob8JvPXFBGRWkdJfBGROL5dtJrFqyv+IV9YYOzRvU01RCQsmwtPngBPHgvL51bevttw\nOHcS7PcnaNA06+HVKp12htPHwOH/gMZJPAT64lm4e9egOkJJxYdJGbduGfw3Tuk5K4Aj/gH1GmY/\nDhEREZGKvBpfOWNmQ4B3gK1CDm8AwrIiWwPvmFnWM2+R2eoTge4hh0uA5UBYlvBo4C0zS+kGwsz2\nJvj36BmnyWpgJeHvUx9gvJntnso1pe4aOz18JYZBXVrTrGG9HEeTIcvmwkP7w9cvJm7X/xfwq9eh\nReechCU1TJvucMoLcPRDQbWGREo2wvgb4d4hMPf9zMYx8y344cPwY0MuhEatMns9ERGplVROX0Qk\njgkzw2fh79y5Bc1r6k1xLhVvYrt5z0S+37Nq65JvXg8T7oAJtycui1ameWc44Eboe5hm3ldFQQHs\n8gvoczCMvwk+vj/xbIhNa+DN/4NPH4eDboWue6V/7co+P29cAWvjLCE6+DzovGv61xYRERGpmnT/\nAC1L7iZ7fqrtM8bM2gPPA42jdm8G7iQoDz8nsq87cA7wG6DsJqox8LyZ7ZytNeHNbBvgWSB2Rv1T\nwN3AR+5eYmZFwBDgt8DhUe12Bx4Cjk/yeq0J/j0axRz6EPgLMNbdV0baNgWGAb8D9olq2xR4ycx6\npbHkwEzgbymesyDF9pJHxk6vZaX0Z70Dz50OG1bEb1NQBAfcDIPO1H2+JGYG/Y6BHvvA23+EqY8k\nbr90Jjx6COx8Euz/Z2hSxYk7paUw9k/hxxq3hd3PrVr/IiJSZyiJLyISx/txkvhDe1YyklcCkx+k\n67wngu+n9IPBadykuMO3r8Mbl8OK7ypvX1gf9rwQhv0O6jdJ/XoSrlFLOPAW2OVkeO1S+P6DxO0X\nT4dHD4UdjgpugFtsnfo1E31+ZoyBz58OP691NxhxVerXExEREckAd0+p4qGZ9QIeAwYRJOOLgTeB\nl4BPgO8JZm
8DtAC2BQYQJJz3J3iu48DHwKnuPrPqP0VSrgeiF/NdDxzh7m/GtJsF/N7M3gb+Q3mS\nuyPwJ4IEfzY8wJbl/UuAX7n7v6IbuXsx8C7wrpn9hmAQQpnjzOyhkJ8pzBVAbNbnXuAC9y1Hwbr7\nGuB14HUz+yNwTdThDsAfgFT/oF3g7qNTPEdqqJ9Wruebn8LHedS4JL47TLwT3vlj4gHjjdvCcY8F\na5+LJKtRKzj0Dtj5RHj14mA5wEQ+exJmvAH7Xx9UfDCDaU9Bv2OhMIU0ytcvwsIvwo/t9XtVihQR\nkaRVSxLfzDoQrEXWkmBUdGE6/bj7YxkMS0TkfzYWl/DR3KWhx4bV5PXlcmXtUnj35vLt8TdBv+NS\nG828dHaQvJ+ZzDMzoMe+cOBfgtJpkh0d+wVlCz9/Bt66GtYsStz+q/8ECfe9LwtmxydbjSHR52f9\nCnjlovjnHn4P1G8c/7iIiIhInoiUTn+dIDlvwGvAhe4eb92oxZHXVOABM+tGsPb6gcBuwIdmdoC7\nT85y3N2AX8XsvipRstvd3zCzq4Hbonafbma3JPh5041vV2BUzO6bYhP4ITH+3cz6suXAgpvM7C13\nr2ypgtgZ+58TvJcJspLg7tea2WCCARllTiD1JL7UIfFK6Xdp05hu7WpQcnDTWnjpguC+MZGtdoHj\nH1f5fEnftrvD2e/BB3fD+FugeH38tuuXwUvnB8n7Q26HF8+B9/4Ce12WXDK/pBjG3Rh+rPnWMDD2\n16eIiEh8OUviR26EzgaOATplqFsl8UUkK6Z+t5wNmys+b2naoIj+27TMfUA1zfgbYcPK8u0NK4NE\n7MG3xT+nzKZ1MOFvwWj8kk2Vt2+xLRxwU1DyXSX1ss8Mdj4eeh8I42+Gj0aDl8Rvv3ktvH1tpMT+\nX6D7yMqvkejz89bVsPqn8PN2Owu2y/ryqiIiIiJVFilH/wrB5AYH/u7uF6fSh7vPAQ42szuBC4FW\nwKtmtqO7h9fazoxLKC+ND8Fs+78ncd4dBAnyHpHtegRl7H+TyeCAs2K2lwB/TvLcK4BTKV8mYADB\ngIA34p0QKd2/Tczu0e6J/kjewt/ZMonfzcy2cneVu5dQ4+Ik8UfUpFn4y+bCv0+GRV8mbrfzSXDI\n36Be7EoVIikqrAdDfws7HAn//T3Meitx++8mwL2R5wvL5iSfzP/86aA8f5i9/wD1Gqb/M4iISJ2T\nUpm3dJhZgZndAnxGcFO5FcEI86q+RESyZkKcUvqDu7WmXmHW/9dZsy36GqY8XHH/lIfh52/in+cO\n37wC9+wO791aeQK/sEFw83T+R9D3ECXwc61hczjgRjhnAmw3tPL2S2fCv46Ef58CK36I3y7R52fq\nP+GTOOP3Wm4L+1ybVOgiIiIieeB6gnLvDnyYagI/xsVA2XpHbQnK1GfT4THbDyWTsI60if1D74hM\nBRUldtToU+6+MZkT3X0F8ELM7qMqOa1jyL6Pk7legrZhfYqwYXMJE2eFVw2sMaX0Z4+FB0YkTuBb\nYVBp74h/KIEvmdWqC/ziWTj2n9C0Q+K2pZu33C5L5t8zKJipX1K8xWEr3RxMdgjTujv0PyntsEVE\npG7KaibKzIxgtvzvCWb9G8ENatlLRCQvTZgVnsQf2kOl9BNyhzFXhq9l5yWRYyH/+18yCx4/OhiJ\nv/L7yq/TcxSc/yGMvEql06tbh+3hl6/C0Q9B0ySeNX7zMtyzG7x3GxTHPEut7PPz+h/i93vYXVpX\nTkRERGoEM2sMnBK165aq9Bcp9V7WhwGnmllWsl5mNhCIrWn9TApdxLbdxswGVC2qcmbWGohdX2ti\nit3Etj/czBI9Pws7tjJkXzwrQvbVC9knwodzlrJ+c8UxM43rF7Jb19bVEFEK3IOKe48fDeuXx2/X\nuC2c9jLsfrYG60t2mAUz8i+YDIPOJOX5gnGS+VstGAMr40xaGHFlUA1AREQk
Bdkup38acBJbJu0N\n2AxMB+YAq4FkS4yJiGTd8rWb+OLH8GcuQ3u2y3E0NcyMMTBnXPzjs8cGa9z3iixRuWltMOt+0t0V\nRziHadUFDrgFeh+QkXAlQ8yg3zHB+/ruLfDhvVBaHL/95nUw9nqY9mQwu6LnvsH+yj4/xRvC9w84\nDboNTzt8ERERkRwbAkTX0427lnwKyvrwSN9DgUpqBadlRMz2wkhZ/6S4+2wzWwRET38cCXySieCA\nsKnIs1LsI7YOcnugH0GFyTA/huxLJZvaJmRfnPWjpK6LV0p/aI+2NCgqzHE0Kdi0Fl6+EL58PnG7\nTv3h+MehZewKFSJZ0LAFHPxX2PlEeOViWPRFaudHldnv2PYAtpv37/B2HXaEHSor6iIiIlJRtpP4\n0XVtjWB08XXAo+6eyqhkEZGcmTR7aehk8U4tGtK9XZPcB1RTFG+CN6+qvN2YK6HbCPj2vzDmKlgV\n9swrRlFDGHoJDLlI64flswbNYP8/Q/+T4fVLYe57idsvmw1PHA19DoF9/5jc5ydW861h/+vTi1dE\nRESkevSO+n6Fu8cZqZg8d99gZiuAlgSJ/F5kJ4m/fcx2KmXjy3wEHBa13Tf9cCoIS56n+vwprP32\nxEniu/t8M5sDdIvavSfJ/9vsGbO9wN2TKE8mdY27806cJP4+ffO4lP7yefD0yZUnSHc+EQ65XeXz\nJfc67wpnjYeP7oVxNwYTD1KxbA59lv0j/vGR/wcFWppTRERSl7Ukvpn1A7YjuHk0YAmwl7tPz9Y1\nRUQyYcKsxaH7h/Zoi6mUW3yTH4ClSUxyWToL7h0MS2cn12/vg4O111t1qVJ4kkPt+8CpL8PXL8Ib\nV8LqBYnbT38VZryRePZ+PIfcEYyeFxEREak5mkd9n8n1gJpQXgWxeaKGVdAnZjvpWfhR5lbSZ1Vs\nDNnXIMU+wkYNVzbQ4H4geiHk88zsH+6+KdFJkWUofxez+5HKQwztqx7QH9iGYDDHKmAZMMPd56fT\np+SXWT+vYf7y9aHHRvTO0yT+7HHw3K8Sl8+3Qhh1o8rnS/UqLII9L4TtD4fXLoMZr2em386DoJeq\nSYqISHqyORN/l8hXI7iJvEYJfBHJd+7O+zOXhB4b2rNtjqOpQdYuhfE3V96uTDIJ/NbdIqXW90s/\nLqk+ZWvM9dgP3r+t8iUT0kng73wi9No//RhFREREqseyqO+LzGwHd/+qKh2a2fYE66iXJfFXVKW/\nBHrFbMdZ/Deh2HNi+6yKsExhqmuihbXvHbIv2h3AKcAOke2ewKNm9kt3DxtYgJkVRs6Lnok/H/hL\nKsFGDCKoIBA6hTlSKeBV4A53jx1EITXE2Diz8Hfcujntm+dZxTp3+OBueOsa8NL47Rq3gWMfha7D\nchebSCItt4UTnwomG7x2WeWTEirTZSiUlgSDBERERFKUzTousUNAn8nitUREMuK7pevijmwf0kNJ\n/LjG3wgbV2Wmr6JGQamxcz9QAr82aNAU9r0OzvsAuo/MXL/1GgWzNURERERqnrIkdlnC/fQM9Hlm\n5GvZNNZ0kuvJaBWzvTCNPmLXe2+ZXiihFgCxI0d3CWuYQFj7hGvcRxL1BwIzonafAEwzs7PNrKeZ\nNTKzBmbW1cx+CUwGLohqvxQ42N3TubFqTJwEfkQ34DfATDO7y8xSrU4geSBeEn9kvs3C37QOnj8T\n3vy/xAn8TjsHJcyVwJd8YwZ9D4UL0lkxJsaE2+GeQTDtKShJY/KCiIjUadlM4hdGfb/K3ZfFbSki\nkifenxU+C3/7Ts1p21TPOUIt+hqmPJyZvvoeFtwk7XUp1MuzmQRSNW17wsn/geP+BS22qXp/xRth\nzaKq9yMiIiKSe+8BGyLfG3Chme2RbmdmtidBMrhsUMBG4N0qRRh+nUZUfI6U4sLBAMSOmi4ys4z8\n8e/uG4CpMbsPTbGbsPaVLnvg7j8AA4Hb
Kf8Z+wCjCZL76wje9zkEJfOjBwu8Dgxw989TjDVVhQSf\nlUlmlmeZX0lk5frNTPkuvCT9iD559FYu/w4e3h++fC5xu52Oh9PHBLOeRfJVg2aZ6WfZHHjxHCXz\nRUQkZdms4xL9ZF31YkSkRpgwc3Ho/mEqpR/OHcZckXh0fTLa9AhK5/fYJzNxSX4yg+0Pgx77wvt/\nhYl3Ji6xn4iXwhtXwCkvaN1EERERqVHcfbWZvQwcR5B4LwLGmNkp7v5SKn2Z2RHAowTJ2bLlDF92\n99WZjRoIT2RvCNlXmbBzmqTZV5gxwOCo7eFmNsDdP6nsRDPbD9g55FClSXwAd18DXGJmbwF3Ad0r\nOeUH4Gx3T3fx5SXAa8A7wBfA98Aqgpn5HQlK9Z8AxK5BNQB40cxGRgY+JM3MYgdJlOmzevVqxo8f\nn0p3GbN6dfCRr67rZ9tHPxVTUuoV9jerD8tnT2P8nOq/J2q5/DN2+OpW6hXH/9+PU8CsHqfzY6tD\nYOJHOYyucrX9MyTpGZ7JziLJ/FlfTGb+Nodlsmep4fT/H6kKfX7yX9l7lI5szsSfEvV9YzOLLbkm\nIpJXiktKmTR7aeixoUrih5sxBuaMr1of/U8KSucrgV931G8M+1wNB91atX7mjIOZb2YmJhEREZHc\nuhRYE/neCZLE/zGz183scDOrH+9EM6tvZkeY2RvA80D0VMG1wGVZijlstvymNPoJWyM+USn4VI1m\ny7gMeMTMmic6yczaAffFOZxUfGY21MwmEyTWK0vgA2wD/NfMnjOzXslcI2IBcCKwlbuf5u6Pufun\n7r7U3Te7+0p3/9bdH3H3UcDewI8xfewB/CmFa0o1+mxxSej+ndoWUVDdg5rd6fzDS+z82XUJE/ib\n6jXns53/yI+dD9VAbBEREZEkZG2GvLt/aWYzgLKbkP2Bf2freiIiVfX5jytZvaFiSav6RQUM6pJw\nCcS6qXgTvHlV1fv5IQNrjEnNU7wJPri76v2MuRK6jYCiuM+5RURERPKOu/9gZmcB/yKYYOEEyeb9\nI69iM5sOfEcwsxqgObAdQYn2suc5FnVuKcGs7u+zFHbYjO10/ggLW6csU7PwcfeFZnY3cEnU7p2A\ndyLVDqbHnmNmuwCPA13jdLsmzv7oPi4A7mDL5SUBPgAmAD8RvE+dgCEESXQi+44GDjSz09y9kjrk\n4O4zCEr0J8Xd34ssu/ARwQz9MheY2R3uviCFvgaG7Tezqc2aNRswfPjwZLvKqLLZZ9V1/WwqKXUu\nef/t0GMnDd+J4Tt1ynFEUTatg1cugtnPJG7XcSfqn/AE/fO4fH5t/gxJFYzPYF+tu8Fel9Gj37H0\nKFThYimn//9IVejzk/+aNUt/eZZs/7a4Efhn5Ps/oCS+iOSxCTOXhO7frUtrGtaLfQ4jIiIiIiKS\nPnd/2sxKCZ6bNKB8TXsD6gH9gB1jTouevuqUJ/A3Ar9y96ezGHJYIjudtezDzqk0SZ6iKwhmn0cn\nnHcFvjCzdwgS68uANsAwgorJZdUqSwjSNtGlwlYkupiZHU9QPj/aJ8Dp7v5ZnHP6Aw8RlLaHoAT+\nU2a20t3fSnS9dLj792Z2OkGVgDKNgNOAmzJ9Pcmcz+avYNnaikUvigqMYb2qsWrgiu/h6ZNg4ReJ\n2/U7Dg69M6jIJlIXRZL39DsWlLwXEZEUZLOcPu7+GPAywQ3lzmb2t2xeT0SkKuIl8VVKP46i+rD/\nDVXvZ9SNmkVdF+nzIyIiIoK7PwP0ByYSPDspm1lf9qpwClsm7y1ybv8sJ/Bx9/UEs/2jpZOViy1N\nX5LquuyVcfdNwEEEM8+jFQGjgOuAvwPXAiMpfz7mwAUE68tHWxHvWmbWEvhHzO6PgKHxEviRGKcB\nQ4Ho0mRFwMNmlpVsp7u/DsSuab9/Nq4lmTNu+s+h+3ft0ormDevlOJqIOe/CfXsnTuBbYXC/dtT9\nSuBL
3dS6GxwxGs6fDP1PVAJfRERSltUkfsRJwDsEN5YXmdnTZtYhB9cVEUnamo3FfPL98tBjQ3so\niR9XvcZQWIUEaveR0FPPjOqsXqOCUvjp0udHREREagF3n+HuexHMHH8SWEl5gj7eayXwBDDc3Ye5\n+7c5CndFzHbHsEaViK39HX4jVkXu/jMwArgFWJfEKfOB/dx9NBVjnJ/gvNOB6PXXSoBTIoMeKotx\nPXBy5JwynYFTk4g3Xa/EbA/K4rUkA8bGSeLv06caHq+6wwf/gH8dCeuXxW/XqDWc8gLscT6YxW8n\nUhspeS8iIhmS1d8gZrZX5NtbCEZn7wEcCxxhZq8C7wHzCG4+w0aYJ+Tu72UmUhGp6z6as5Ti0or/\nG2rTpD7bd2peDRHludJSmPA3GHcDeOxknCSVjcrXDX3dZRZ8BkYPSf1zpM+PiIiI1DLu/j7wPoCZ\n9SJYx70N0DLSZAWwFPg8siZ6dZgBDI7a3iaNPmLPmZl+OIlFkuSXm9kdBM+j9ge2B9oRPBP7EfgG\neBp4ISrx3jemq9jZ69EOjtl+092T/pncfaaZvQUcELX7cGB0sn2kaHrMdhMza5TMoAPJvYUrN/DV\nglWhx0b0aZ/bYDavh1cugs8rWS21Yz84/glotV1u4hLJE+sadaLxqGtVNl9ERDIm279NxrNlcr6s\n3Ft94MjIK11O9uMXkTri/Til9Pfs0ZaCAiUJt7BuGbxwNsx8s2r97Ho6tI99NiZ1Toftg8/C5AdT\nO0+fHxEREanFIkn66krUJzKdLZP43dLoo2vM9jfph5Mcd19IsGZ97Lr1FZhZE2CHmN1TEpzSL2Z7\nUmrR/e+c6CR+/zT6SFbY9OlWgJL4eWjct+Gz8Ldt3Zju7ZrkLpAV38PTv4CFnydu1+9YOPTvKp8v\ndUvrbnzT/lB+br83e/ffp7qjERGRWiQX5fShvNwbVFy7rSovEZGMmDArPIk/TKX0tzR/Cty3V9UT\n+A1bwPArMhOT1HzDr4QGKVS80OdHREREpLp8HbO9Wxp97B6znfUkfoqGAIVR2yuo+HNHaxWzHZ51\nTWxRzHabNPpIVouQfSuzeD2pgnil9Ef2aY/lqirZ3Pfg/uGJE/hWAPvfAEc9oAS+1B1RZfMXdRyJ\nFxRWfo6IiEgKcj2TXYl3Eck7P61cz6yf14QeG9pTSXwgWPfuo9Hw5tVQurnq/e19OTTJ5nMpqVGa\ntIHhl8OYK5Nrr8+PiIiISHUZF7Pd0cy6ufucZE42s25A7ELesX1Wt9j16J9w95LQloE1lC95ANAo\njWvGZj3XptFHsnrGbK9392xeT9K0sbiEiXEmHOSklH7Zc4AxV0Gi/wQatYZjH4Fuw7Mfk0g+aN0N\n9rpMZfNFRCTrsv1b5j3SWOteRCSXJsQppd+tXRO2apnO85daZsNKeOkC+OblxO0atYaihrB6QeJ2\nbXrCbr/OXHxSOwz6NUx5GJbOStxOnx8RERGR6jSVYB35raP2HQfcnOT5x8dsz3f3ROvN55SZtafi\n0o8PVXLaYrZM4scmyZPRK2Y7/CY1Mw6M2a6kPrpUl4/mLGPdporJ80b1Ctm9a+vsXnzzenjlYvj8\n6cTtOvSDE56AVttlNx6RfKDkvYiI5FhWf9u4+/Bs9i8ikgkqpZ/AT5/Ds6fBskom1mwzGI55GBZ+\nAU/FPpeLMeoGKKyXuRildiiqH5Rf1OdHREREJG+5u5vZS8B5UbvPMLNbK5mtjpkVAqfH7H4p0zFW\n0a1sOSv+fXf/tJJzPmXLxP3BZnZRZf8eZcysCDg4ZvdnyZybKjMbAgyL2T0mG9eSqotXSn9oz7Y0\nrJfFst0rfoB/nww/TUvcbsdj4LC7VD5faj8l70VEpJoUVHcAIiLVqbTU45anG9qzXY6jySPuMPWf\n8OC+lSfw97wQfvkqtNgaeo2CbiPit+0+Enrun9FQpRbR50dERETkf8
yshZl1NrNtU31lObTbgeKo\n7R7Ab5I47+JI2zLFkb7iMrPhZuYxr+GphZscMzuNLUvpb2bLwQrxvBmz3QU4N4VLXwjEvmdxE+uW\n5kLoZtYWeDRmdwnwZDr9SXa5e9wk/sjoUvrTnoKS4tB2aZk3Ae4fnjiBbwWw/5/h6AeVwJfaL7Lm\nPf1PVAJfRERyTr95UhQZOT4I2BFoCxhB6bQvgCnJjrTOJjPbjiDGDgQl3dYBc4GP3b2SOtdpXa9D\n5HpbU15CbgVBib2P3T38rkMkD0xfuJolazZV2F9YYAzuluXydPlq01p49ZLKy+Y1aAFH3gt9oiaN\nmMEBN8G9e4KXbtneCmHUjUEbkTD6/IiIiEgdZWbNCBLII4DdgK0Injekw8ni8x53n2Vm/wTOjNr9\nZzP70t3fCjvHzEYB18fsfsTdZ2cjRjPrCPRx9/FJtC0ELiKYhR/tVnf/MonLPQ3cCEQvUv5XM/vJ\n3Z+v5NonAn+J2b0k0mc8d5rZWuAOd1+URHyY2fbAs0D3mEMPu/u3yfQhuTV78Vq+X7Yu9NiI3lEf\ntRfPgff+UvVZwu7w0X0w5kpI9GizUaugCl/3keldR6Sm6X9idUcgIiJ1mJL4STKz5sAfgLMIkvdh\nlpjZ/cAt7r4qZ8EBZlafILYLgN4J2k0C/uLuVSpZZ2aNgDOAswkGNCRq+wUwmuDmcEOS/f8SeKQq\nMYYYkegGPjKaf1wVr/GulpGoWSbMWhy6f5dtWtKsYR0s2b34W3jmVFg8PXG7Tv3h2H9C664Vj7Xv\nC7ueDpMf3HL/rqcHx0QS0edHRERE6pDIrOqrgd8BTct2V19ESbsaOIzyxHVj4FUzuxO4F5gX2d8N\nOIcgSR59g/UzcE0W4+sIjDOz6cB/gLeBae6+HP6XuN+GoIz9r4GdY85/G7gumQu5+1ozuxq4L2p3\nfeA5M3ue4HnIJHdfF7l2E2AIwSz/w0O6vMbdVye4ZHOC2fu/N7NxBEsSfAJ8GX1e5DnWnsAJwEls\n+e8P8A3Bcy7JQ+PizMLfvlNzOrZouOXOZXOqlszfvB5e/S189lTidh12hBOegFZdku9bRERERNKm\nJH4SzGww8DzBSPhE2gJXAqeZ2THu/mHWgwPMrB/wAhVHVIfZE3gxciP5q0puDONdbw+CEmw9K2sb\n0Q+4B7jQzE5296mpXjNDMlhfTGqL92fGK6Ufb6xOLfb5s/DKRbB5beJ2g84M1i6v1zB+m+FXwhfP\nwoaVwXbDFjD8iszFKrWbPj8iIiJSB0QGx78G7EV54t4jr5S7i5yXkwEA7r7QzI4hKPveKLK7PnBp\n5FU2gD/spmE9cLS7L8x6oNCH4DnNlQBmtj5y/ZbEX2LyLeBId9+c7EXc/X4z25EguR7t6MgLM1tJ\n8P40T9DVve5+b5KXLQL2i7yIXGMTsJpgUEWjOOcBzAYOKBvUIPknqVL6sdJJ5q+cD0//InH5fIAd\njoLD74b6TRK3ExEREZGMiXfDIhFmNgR4h/AE/gZgY8j+rYF3zGzPbMYG/5s9PpHwBH4JsBwoDTl2\nNPCWmTUNOZboensT/HvES+CvBlYS/tChDzDezHZP5ZoZshyYXA3XlTy2YXMJH89dFnpsWF1K4m/e\nAK9cDP85M3ECv14TOPohOPiviRP4AE3awN6Xl28PvyLYJ5IMfX5ERESkbngI2JvyBDyR71cT3MMa\n5Un574FlBPf30fvLLIy0+S7yNevc/X2CBPJPIYcbEp7A/wnY190nZDO2BBoBrQl/HrYJ+BNwoLtX\nMrI51EXAxZQPYIjVgvgJ/I3AJcD5aVw3Wn2gDfET+E5Q9bC/u+fkcyKpW7VhM5PnhT+rGNk3QRK/\nTFky/55BMO0pKIkzp2XeBLhv78QJfCuA/f4UlNBXAl9EREQkp6plJr6ZFQEDCWaF9yO4gWoVObyc\n4Mb0c+ADYKq7V8sMajNrTzADv3
HU7s3AnQTl0OZE9nUnKA/3G8rLkzUGnjeznbO1JryZbUOwplmz\nmENPAXcDH7l7SeTfewjwW7Ys1bY7wUOD45O8XmuCf4/Ym8EPCdZwG+vuKyNtmwLDCEoC7hPVtinw\nkpn1qmTJgQ+Ac5OJK0RTKq5l95S7hw24SGQm8LcUz1mQYnupRlO/W87G4opjXJo1KGLnzi1zH1B1\nWDYXnj0Nfvoscbt2feG4x6Bdr+T7HnQmc78NlpDsOujMShqLxNDnR0RERGoxMxtJUOa8bPa8AQ8T\nrMP+rZmdTVCWHgB37xo5rxDYBTiE4DlE+0gfK4Hj3P2LXP4c7j7RzPoClxMs8dc6TtNlwP3AzWXP\nDbJsLkFC/iCCf6/CBG2XETxH+Zu7z0nQLiF3d4L16l8keJ5xKtCpktMWAo8B/3D375K81O0EAzX2\nBgZQvgxDIgsIKjje7e6VrJ0m1e39GUsoLq04N6Z1k/qpPauINzPfHT5+AMZcAaUJHrk2bAnHPgLd\nR6b8M4iIiIhI1eU0iR9Jil9AcGPXLsnTFpvZfcA92UqGJ3A90CFqez1whLu/GdNuFsFaZG8TrLVW\nluTuSHDTeE6W4nuAoIR/mRKCEvn/im4UGQTxLvCumf2GYBBCmePM7KGQnynMFQQjuqPdC1zg7ltk\nQt19DfA68LqZ/ZEt17rrQLDu2lXxLuTu3wLfJhFTBWb2y5Ddj6bR1QJ3H51ODFIzxCulP7h7G4oK\n60Chkm9egRfPh42VPEPb+cRg9n2qo+6L6vNdl+MA6FoYu/yiSCX0+REREZHa7bLI17JZ+P/n7jdV\ndpK7lwBTgClmditwF/BLoBfwnpnt5+5TshNy3JhWAldE1oUfRDBZo+xZxWLgS2ByOhM03H08aSwR\nEInpWuDayBr0OxNMwGhPMOliI0EC/Svg09hnGlURScZfDlxuZl0JEu3tCUr4lw24+Bn4xN3nptH/\nZ8BnAGZmQA+gG9CZYIJMQ4KfbzmwJHIdzbqvQeKV0h/eqx2FBWmsmBGdzB/yW/h+Enz2VOJz2u8A\nJzwBrbumfj0RERERyYicJfHN7DiChG9LtrwBi7fWW1mb9sD/AReY2Tnu/mzWgoy+uFk34Fcxu69K\nlOx29zciN623Re0+3cxuSefGrJL4dgVGxey+KTaBHxLj3yOj5KMHFtxkZm9FRo0nEjtj/3Pgwspu\ndt39WjMbDOwftfsEEiTxq+i0mO3p7v5xlq4lNdiEWYtD99f6Uvolm+Ht6+CDuxO3K2oIB90Gu5wM\nlpOlNUVEREREaj0zawSMoHwW/pRkEvixIiXfTzez1QRrsbcAXjSzHd19RQZDTjaeYoKqeh/k+tqJ\nRP6dJkVeub72XIKqANnq3wmqCM7M1jUkt0pLnXdnhCfxR/RJopR+IsvmwCsXVt5uhyPh8HtUPl9E\nRESkmuVkqqmZXUtQmqwV5aPMo9d7C3sR1c4i5z5tZtEzurPpEspL40Mw2/7vSZx3R6RtmXoEZewz\n7ayY7SXAn5M89wpgXdT2ACoOCNhCpHT/NjG7R0dmASQj9t+um5ltleS5STOzbQlKykVLZxa+1HLL\n1m7iqwXhKzoM7VGLk/gr58MjB1WewG/dHc58GwacogS+iIiIiEhm7UbwrKDs+cg9VezvtwQzyiEo\n356tAfMikmWf/7iSJWs2VdhfWGDs1SvZoqZpsgLY949wzCNK4IuIiIjkgawn8c3s1wQlzKKT92UZ\noSnAfQRlxs6NvC6P7Jtc1kXMedea2RnZjpst144HeCiZhHWkzcMxu4/IVFBRYhekSnrN98iI/Bdi\ndh9VyWkdQ/alMrs9rG1Yn1V1KltWeigFElYnkLpp4qwlhNWe2LplI7q2raU3q7PehtHDYH4l/+lu\nfzicNR469stJWCIiIiIidUyXmO3xlZ1gZnErKUaq4/2lrClwhpklWgNeRPJUvFL6A7drRYtGWVxm
\nrF5jOPHfMPRiDeQXERERyRNZLadvZh2Av7LlrPvNBGuy313ZmlyR2d/nAxcTjFIvS+TfbmavuHv4\nX7ZVj3sgwVpi0Z5JoYtngBujtrcxswHu/kmVgwPMrDXBWm7RJqbYzUTgF1Hbh0eWK4hXGj9swEcl\nC2lvYUXIvmzcfZwas/22u/+YhetIDTdh5pLQ/UN7tMVq2w1raQmMvwneu434K5gABfVg1A2w21m6\naRcRERERyZ7WUd9vjvNsJHYSQQMg0Zryr0d93wLYFfgovfBEpLqMnb4odP8+VS2lX5nN6+CNP8C6\npdDvWCjM2QqsIiIiIhJHtmfi/x/QNPK9AQuAge5+WWUJfAB3/8HdLwcGAtGJ2CZktzzciJjthe4+\nJ9mT3X02EPtXd+zM+aoI+8t9Vsi+RGLXS2sPJJp2G5YIbx2yL542Ift+SuH8SpnZnkDPmN0qpS8V\nuDsTZsVJ4vesZaX0Vy+Cxw6H924lYQK/xTZw+huw+9lK4IuIiIiIZFejqO/D1/iCNTHbCe+/3X0J\nwbJ5ZX/075BeaCJSXX5etYEvfwz/X8LIbCfxAZbNgRfPgY/vz/61RERERKRS2U7iH0v57Pl1wAh3\n/zLVTtz9K2AfYH1Uf8dlMM5Y28dsp1I2vkzsiPe+acYSJuzmPZVZ8fHax/7c/+Pu84HYgQx7pnC9\n2LYLkhnIkaLYWfirqLhsgAhzl6zlxxXrK+w3gyE9alESf94EuG8YzHs/cbueo+Ds96DzrrmJS0RE\nRESkblsd9X3DOG1iM3mx1QLDbKR8ebladGMjUjeM+za84GjnVo3o0b5p6DERERERqb2ylsQ3swGU\nzxh34FZ3j539nbTIubdRfkPaPnKNbOgTs530LPwocyvpsyo2huxrkGIfYQ8KKhtoEDsU9zwzq1/Z\nhSyoTf67mN2PVHZeKsysAXB8zO5n3L1ipja1fuuZ2SAzO8rMTjezY8xspJkl8wBF8lS8Wfg7bNWc\n1k0q/Ujnv9JSeP9v8OihsCa8FB8AVgD7XAsnPg2NUymsISIiIiIiVbA46vsmZhb2bOa7yNeymfX9\nE3VoZk0IyuhHL2coIjXI2OnhSfyRfdrnZtm/1t3giNHBEnsiIiIiUu2yucBR78hXI7iJfCIDff4L\nuIbym9LeQEbWmY/RK2b7hzT6iD0nts+qWB6yr12KfYS17x2yL9odwCmUl+XrCTxqZr9097CBBZhZ\nYeS86Jn484G/pBJsEg4HWsbsq2op/UEEFQsahR00sznAq8Ad7h47aEPy2ISZcUrp90j1P6M8tG4Z\nvHAOzByTuF3TjnDMQ9BlaG7iEhERERGRMt/GbPcM2fctsJny5zb7Afcm6HM/yidqOLCsijGKSA5t\nLC6J+6xiRLZL6bfuBntdBv2OhcJsPioWERERkVRks5x+dDZsk7unumZ7BZG15jfFuUYmtYrZXphG\nH7HrvbdML5RQCwhu5qPtkmIfYe0rW2NvI3AgMCNq9wnANDM728x6mlkjM2tgZl3N7JfAZOCCqPZL\ngYPdPd66f+k6LWZ7trtPqGKfjYmTwI/oBvwGmGlmd0WqAUieKy4p5YPZS0OPDetZwytOzp8C9+1V\neQK/615wzvtK4IuIiIiIVI9vCJ5tlE1Q2Cm2gbsXAx8STIww4BAzC10Cz8zqAddG9QfwVSYDFpHs\n+njuMtZuKqmwv1G9Qvbo1iY7Fy2beX/+ZOh/ohL4IiIiInkmm0n86ORnlUqax1gX5xoZYWaNqPjv\nsi6sbSVif+YiM4u31l1K3H0DMDVm96EpdhPWvtIFttz9B2AgcDvlP2MfYDRBcn8dsIFgCYJH2HKw\nwOvAAHf/PMVYEzKzDsD+Mbsfy+Q1KlFIMFBhkplleXi0VNVn81eyemNxhf0NigoYuF3s+J0awh0+\nHA0PHwArExUOsWB0/SkvQlN9VEVEREREqkPknr4sQQ/BLPow
T5edQjAj/1Uz22Ikrpl1Af4L7By1\newXwcYbCFZEciFdKf0iPNjSsV5jZi7XcVsl7ERERkRogm3+lRdeAamFmjTKwPnkjghntZaPLw+tM\nVU1YIntDGv2EndMkzb7CjAEGR20PN7MB7l7p8gJmth9b3uCXqTSJD+Dua4BLzOwt4C6geyWn/ACc\n7e6vJ9N/Gn7Blp9lp2pJ/CXAa8A7wBfA98Aqgpn5HQmWBjiBigMHBgAvmtnIyEOZpJlZ7KCMMn1W\nr17N+PHjU+kuo1avXg1QrTFk0kuzNoXu79nC+HDi+zmOpuoKi9fR+9u7ab94YsJ2m4ua8fX2l7C8\nYAC8l7ufs7Z9fiS39PmRqtJnSKpCn5/8V/YeidRQY4C9It8fbGbm7h7T5jHgSmArgvvcLsC7Zjaf\n4D67JcGg+rLBAGXLGd4dmckvIjXEuDhJ/IyX0m++NVz4qRL3IiIiIjVANmfix5aTH5WBPsuSpmU3\nqLHXyISw2fLhWb/EwtaIz2TlgNFsGZcBj5hZ80QnmVk74L44h5OKz8yGmtlkgkR3ZQl8gG2A/5rZ\nc2bWK5lrpCi2lP677j4vjX4WACcCW7n7ae7+mLt/6u5L3X2zu69092/d/RF3HwXsDfwY08cewJ/S\nuLbkyFdLK5anA9ihbYZHtudAkzVzGTj1kkoT+Cub92HKrrezvPWAHEUmIiIiIiKV+HfkqxEMFj8y\ntoG7ryWo+laW3PdI+20I7j23p/y5TlmbL4GbshOyiGTDnMVrmLc0vAjoiN4ZTuL3P0kJfBEREZEa\nIpt/tU0CSii/obzczF4KGVmeFDMz4A9Ru0oi18i0sBnU9dPoJ2x99EzNwsfdF5rZ3cAlUbt3At4x\ns1PcfXrsOWa2C/A40DVOt2squ66ZXQDcQVBCPtoHwASCgRUGdAKGEDxYILLvaOBAMzvN3Z+r7FrJ\nMLP+VFw/8NF0+nL3GQRLAiTb/j0z2xP4iOChS5kLzOwOd1+QQl8Dw/ab2dRmzZoNGD58eLJdZVzZ\n7LPqjCFT1mwsZs6bb4YeO23U7my/VcIxMPnDHT79F0y4HIor+d/KHhfQYt/r2KOwXm5ii1GbPj+S\ne/r8SFXpMyRVoc9P/mvWrFl1hyCSNnefY2bnAS0iu8ImAuDuL5nZr4F7CZ5NRD9Tif7egGnAoalW\nhhOR6hWvlH6fjs3YqmWGVxLtfWBm+xMRERGRrMlaEt/dV5jZB0DZem2DgL+yZdI5FbcQlI8vu0n9\n0N1XVCnIcGGJ7HTWsg87p9IkeYquIJgNHp0A3hX4wszeIUisLwPaAMOA4ZQPqigBxgP7RJ27ItHF\nzOx4gvL50T4BTnf3z+Kc0x94iKDUPAQl6Z8ys5Xu/lai6yXp1JjttUBGBggkw92/N7PTCaoSlGlE\nUB1Asx/yzIezl1JcWnEcUdum9enTsYY8BN60Fv77O/jsqcTtGrSAI/4BfQ/JTVwiIiIiIpISdx+d\nZLtHzOx94FrgYIIy+tGmE1TcG+3uoYMBRCR/jfs2PIm/T980Z+E3aQdrF1fc37QjdNolvT5FRERE\nJOeyXT/pNoIkflnJt4vMbFvgQndPqhS+mXUgmPl9XFQ/DtyajYDdfb2ZlbLlUgON0+gqdqhsSaZH\nw7v7JjM7CHgZ2D3qUBHB8gXxljBwgpJ8vUkyiW9mLYF/xOz+CBjh7usTxDjNzIYSDBjYLSq+h82s\nt7uH1wtLgpkVASfF7H7e3TM9WCIhd389sqZ99GCK/VESP+9MmLUkdP+QHm0pKLDQY3ll8Qx45lRY\n/E3idh13guMeg9bxim6IiIiIiEhN4u6zgFPMrBDYFmgPbAbmu3t4BlBE8t7qDZv5aM6y0GMj+6SY\nxG/dDfa6DKa/Grxi9RoFBdlcWVVEREREMimrf7m5+8sEyduyxLsRrPM208z+ZWZHmdl2seeZ2XZm\ndqSZPQrMIkjgl2XYHBjn
7q9kMfQVMdsdwxpVolPM9vL0QkkscrM+gqBSQTIJ8fnAfpER/7Exzk9w\n3ulA66jtEuCURAn8qBjXAydHzinTmYqz6FN1ANAhZl9apfQzIPbzOKhaopCE3p8ZMhIdGNqjbY4j\nScMXz8H9wytP4O96OpzxlhL4IiIiIiK1kLuXuPtcd//I3T9RAl+kZpswc0loxcBWjevRf5tWyXXS\nuhscMRrOnww7HAmzx4W3631QFSIVERERkVzL9kx8CBLwk4AelCfyGxPMoD4JwMycoAy6A00pT9jD\nlsl7I1iz/PgsxzyDoHR/mW3S6CP2nJnph5NYJEl+uZndARxLMAt8e6AdwXv8I/AN8DTwQlTivW9M\nV1MTXObgmO033T3pn8ndZ5rZWwSJ9zKHA0mVD4zjtJjt74E4dypZNz1mu4mZNUpmkIPkxk8r1zN7\n8drQY8N6tstxNCnYvAHGXAlTHkrcrl4TOPRO2OnY3MQlIiIiIiIiIlUydnr4OJy9e7WjMJmKgUeM\nhn7HQmHkEe+c8bA55NlHUSPotnf6gYqIiIhIzmU9ie/uS8xsX+B5gnLjZcNLYxP18Rakjm4/BTjG\n3ZdmI9Yo09kyid8tjT5ip8FWMn226tx9IcGa9bHr1ldgZk2AHWJ2T0lwSr+Y7UmpRfe/c6KT+P3T\n6AP4X3n/Q2N2P+buFYcv50ZY7bNWgJL4eeL9meGl9Hu0b0rHFg1zHE2Sls2FZ0+Dnz5L3K5dXzju\nUWjXOzdxiYiIiIiIiEiVlJY6474Nrxg4ItlS+v1P3HL729fC23UfCfViV/4UERERkXyWk4WQ3P17\nYA/gT8BKtpxdn+hFpO1K4Dpgj0hf2fZ1zPZuoa0S2z1mO+tJ/BQNAQqjtldQ8eeOFlvDK52SfYti\nttuk0UeZE4AGMfseq0J/VdUiZN/KnEchcU2Ik8TP21L60/8L9+1deQJ/pxPg1+8ogS8iIiIiIiJS\ng3y5YCVL1myssL/Agpn4KXOHGW+EH+t9QPh+EREREclbuSinD4C7FwPXmdlfgF8ABxIk9mPXNC+z\nCPgAeB14wt2TWe89U2JLsnc0s27uPieZk82sGxV/ruoq8x5P7Hr0T7h7SWjLwBqgZdR2OsN3G8ds\nh9c2T05sKf1JqZT3z4KeMdvr3b0qP59kUGmpM3FWDUnil2yGt6+DD+5O3K6wARx0Kww4FSyJEnsi\nIiIiIiIikjfe+SZ8fsyu27WmZeP6qXe48HNY9WPIAYNeSuKLiIiI1DQ5S+KXiSTjH4i8MLO2QGvK\nZ3ovB5a5e3jGLTemEqwjv3XUvuOAm5M8//iY7fnunmi9+Zwys/bAkTG7K1lwm8VsmcSPTVono1fM\ndlrvsZn1ZMvlDgAeTaevDDowZvvzaolCQn2zcBVL126qsL+owBjcvSoFITJs5Y/w3K/gh48St2vd\nDY59FDrtlJu4RERERERERCSjxn0bnsRPupR+rG9fD9/feVdommafIiIiIlJtclJOPxF3X+LuM9z9\no8hrRjUn8Imsq/5SzO4zzKwwrH20SJvTY3bH9lXdbmXLWfHvu/unlZwTe/zgZP49yphZEXBwzO5K\n6oTHFTsLfwPw7zT7qjIzGwIMi9k9pjpikXDxSunvsm1LmjbI+VimcLPegfuGVZ7A73sYnDVeCXwR\nERERERGRGurn1Rv4fH74KowjM53E7x0770REREREaoJqT+LnsduB4qjtHsBvkjjv4kjbMsWRvuIy\ns+Fm5jGv4amFmxwzO40tS+lvBs5L4tQ3Y7a7AOemcOkLgW1j9qWc6DYzA06J2f2iu1dp/flIv+mc\n15aKVQBKgCerEo9k1oS4pfTTWGMu00pLYNyN8PjRsG5p/HYFRXDALXDcY9CwRe7iExEREREREZGM\nGv/t4tD9W7dsRK8OTVPvcOWP8NO08GO9lMQXERERqYmUxI/D3WcB/4zZ/Wcz2y/eOWY2Cr
g+Zvcj\n7j47w+GVXa9jssl+Mys0s0uAh2MO3eruXybRxdNAbJ2vv5rZ0Ulc+0TgLzG7l0T6TNUIKg4GyEQp\n/TvN7CYz65DsCWa2PfAu0D3m0MPu/m0GYpIM2LC5hI/nLgs9NrRn2xxHE2PNz/CvI+HdWwCP3655\nZ/jVGzD4HEhvvImIiIiIiIiI5Ilx0+OV0m9HWvNMZrwRvr/ldtC+b+r9iYiIiEi1UxI/savZMnHd\nGHjVzP5iZl2tXHczuxV4BWgU1f5n4JosxtcRGGdm35jZDWY2wsxalR2MJO67mNn5wFTgr2z5nr8N\nXJfMhdx9LcG/R7T6wHNm9pyZ7Wtm/yvRb2ZNzGx/M3uRYFZ6bM3ya9x9dXI/5hZiS+kvAN5Ko59Y\nzYHLgflm9qaZnW9me5hZs+hGZtbczA4ws38C04DtY/r5BvhDBuKRDJkybzkbi0sr7G/WsIidO1dx\nRnvxJnj31uBVvCm1c+dNhNHDYO67idv13B/OeR+2GZR+nCIiIiIiIiKSFzYVl/J+nGX/9umT9NyS\nLcUtpX+QJgOIiIiI1FB5shh0fnL3hWZ2DEHZ97LkfH3g0shrQ2Rfw5DT1wNHu/vCrAcKfYArIy/M\nbH3k+i2JP1DjLeBId9+c7EXc/X4z25GgNH60oyMvzGwlYARJ8Xjudfd7k71uGTNrAhwVs/txdy9J\nta8EioD9Iq+y624CVhMM4mgU5zyA2cAB7r48g/FIFb0/K7xE3R7d2lBUWMVxTJMfhHF/Dr5v0BQG\nJ7HCRGkpTLoT3rkeEn10rQBG/h8M+S0UaLyViIiIiIiISG0wed4y1mwsrrC/Yb0C9ujeJvUON66B\nue+FH+utUvoiIiIiNZUyQ5Vw9/cJEro/hRxuSHgC/ydgX3efkM3YEmgEtCb8/d0E/Ak4MDK7PlUX\nARdTPoAhVgviJ/A3ApcA56dxXQgGCsQuDJaJUvqVqQ+0IX4C34FHgP7u/n0O4pEUTIgzun1YVUvp\nr10K795cvj3+pmBfIuuWwdMnwtvXJU7gN+0Ap70Cw36nBL6IiIiIiIhILTI2Tin9Pbu3pWG9wtQ7\nnDMOSjZW3N+gBWy3Z+r9iYiIiEheUHYoCe4+EegL3AyEL64dWBZp09fdJ+UgtLkECfkpQGWz0ZcB\n9xDEdm26s9c9cCfB7P9bCB/cEGsh8Begt7vf7u4JFv9OKLaU/hR3/zrNvmLdDlwPvAesSfKcBQT/\nptu7++nunux5kiNL12zkqwWrQo8N7dmuap2PvxE2rCzf3rAySOTHM38q3Ld3/HXqynQZBme/D12G\nVi0+EREREREREck74+Ik8Uf0aZ9eh/FK6ffcFwrrpdeniIiIiFQ7ldNPkruvBK4ws6uBQUA/oGwq\n72LgS2Cyu1esh1V53+MJStCnE9O1wLWRUvM7A92B9gSl3zcSJNC/Aj5194oLg6fJ3b8jWEP+cjPr\nCgyIXLclwcz0lcDPwCfuPjdD19wnE/3E6fsz4DMAMzOgB9AN6Ay0Iqi4sBFYDiwh+Lk06z7PTZwd\nPjN+65aN6NKmcfodL/oapjxccf+Uh2HQGdC+b/k+d/j4fhhzFZRWsnrFXpfC8CugII2R9yIiIiIi\nIiKS1+YtWcucJeGFMUemk8QvLYk/WaD3Qan3JyIiIiJ5Q0n8FEWS9B9EXnkjUhp/UuSV62vPJagK\nUCtEKgXMjLykBpswc3Ho/mE92xKM1UiDO4y5EsLGxHhJcOzk/4AZbFgFL18IX7+YuM9GreGoB4JR\n8iIiIiIiIiJSK8Urpd+7QzO2bhlvFccE5k+BdSETGAqKoEfW5sKIiIiISA4oiS8itZK7M2HmktBj\nQ3u2Dd2flBljgvXm4pk9Fma+Cc23hmdOhWWzE/fXeTc49hFo0Tn9mEREREREREQk78VL4o/sm24p\n/dfC92+7BzRqlV6fIiIiIpIXlMQXkVppzpK1LFi5oc
J+MxjSPc0kfvEmePOqytu9dAFsWAklGxO3\nG3w+7HsdFNVPLx4RERERERERqRHWbCzmo7nhy/6lVUof4NvXw/erlL6IiIhIjackvojUSvFm4e+4\nVQtaNUkzaT75AVg6q/J2a8NH1v9Pg+ZwxD+g76HpxSEiIiIiIiIiNcqEmUvYXOIV9rdoVI9dtmmZ\neodLZ8OSb8OP9T4g9f5EREREJK8oiS8itdL7mS6lv3YpjL+5ChFFdNwJjnsUWnerel8iIiIiIiIi\nUiOMi1NKf+9e7SgqLEi9wxlvhO9v10fPHERERERqASXxRaTW2VxSyodzwkvUDeuRZhJ//I2wcVUV\nogIG/goOuBnqNaxaPyIiIiIiUiuZWQegN9ASaAYUptOPuz+WwbBEpIpKS51x34Yn8TNfSv/A9PoT\nERERkbyiJL6I1Dqf/bCCNRuLK+xvWK+AgV1apd7hoq9hysPpB1SvMRxyB+x8fPp9iIiIiIhIrWRm\nfYGzgWOAThnqVkl8kTzy1YJV/Lx6Y4X9BRbMxE/Z+uXw3aTwY70PSr0/EREREck7SuKLSK0Tr5T+\nbl3b0KAoxYks7jDmCvDS9IKp1xjOHAsd+qZ3voiIiIiI1EpmVgDcBPyWYMa9Zajriotui0i1Ghun\nlP6AbVvRqkn91Duc+TZ4ScX9TdrB1gNT709ERERE8k4aCy4lL3JDKiKSUxNmhSfx0yqlP2MMzBmf\nfjCb18HK79M/X0REREREah0zM4LZ8r8nmGBhBMn3speI1CJj45TSH5F2Kf3Xwvf3HAUFaa3CISIi\nIiJ5Jtsz8X8wsweBB939hyxfS0SEVRs2M+2HFaHHhvZMMYlfvAnevKrqQY25ErqNgKI0RteLiIiI\niEhtdBpwElsm7Q3YDEwH5gCrgZCptiJSkyxZs5HP568IPTYynSR+8SaY9Xb4sd4Hpt6fiIiIiOSl\nbCfxOwH/B1xpZq8Do4HX3V2jykUkKz6cvZSS0or/i2nbtAF9OjarhohEREREREQquDbqewNWANcB\nj7r7yuoISESyY/y3iwl7EtqpRcP0nlN8NxE2rqq4v7ABdB+Ren8iIiIikpdyVe6+EDgYeAWYa2ZX\nmVnHHF1bROqQeKX0h/ZoQ1CxMgVF9WH/G6oe1KgbNQtfREREREQAMLN+wHYEM/ANWALs6e5/VwJf\npPYZNz28lP7IPu1Tf04BMOON8P3dhkP9Jqn3JyIiIiJ5KVdJ/OjScNsCfwK+M7NnzWy/HMUgInXA\nhJlxkvg926XXYa9RQSn8dHUfCT33T/98ERERERGpbXaJfDWC5yXXuPv0aoxHRLJkc0kp781YHHos\nrVL67vDta+HHeh+Qen8iIiIikreyncTfD/gPUEz5zWnZSPN6wFHAG2Y208x+b2YpLlgtIlLuxxXr\nmbNkbeixoT3S/N+LWTCT3tL436UVRs5NY2S9iIiIiIjUVrGZu2eqJQoRybrJ85axemNxhf0NigrY\ns3sazyl+/hpWfB9+rJeS+CIiIiK1SVaT+O7+jrsfC2wD/B/wHUECH7acnd8duAX4wcyeMLO9shmX\niNROE2aGj27v2b4pHVs0TL/jDtvDrqenft6up0P7vulfV0REREREaqPCqO9XufuyaotERLIqXin9\nPbq3oVH9wtBjCcWbhb/VLtB8q9T7ExEREZG8lZNy+u7+s7vfSJCsPwh4GSil4uz8BsAJwDgz+9rM\nfmNmLXMRo4jUfO/HLaWfgSIfw69MbW25hi1g+BVVv66IiIiIiNQ2i6K+L6q2KEQk68bGSeKnVUof\n4Ns3wvf3Pii9/kREREQkb+UkiV/GA2+4+xHAtsB1wHzCZ+f3AW4HfjSzR8xscC5jFZGapbTUmTR7\naeixYZlI4jdpA512qbxdmb0vD84RERERERHZ0pSo7xubWatqi0REsua7pWuZvTh8yb8RvdNI4q9e\nBD9OCT+mUvoiIi
IitU5Ok/jR3P0nd/8T0BU4HHiN8hn50bPzGwGnAhPNbJqZnWNmTaspbBHJU1//\ntIplazdV2F+v0Ni9awaS6aUlsHRWcm3b9ITdfl31a4qIiIiISK3j7l8CM6J27V9dsYhI9sSbhd+r\nQ1O2ad049Q5nxJmF37wzdOyXen8iIiIikteqLYlfxt1L3f0Vdz+EIKF/A/ATW87Ot8hrJ+AeYIGZ\njTazAdURs4jkn3il9HfZthVNGmSgQuV3E2HNwuTajroBCutV/ZoiIiIiIlJb3Rj1/R+qLQoRyZp4\nSfwRaZfSfz18f+8DwSz8mIiIiIjUWNWexI/m7j+4+9XAdsDRwFvRhyNfDWgK/BqYbGYfm9lpZlY/\nt9GKSD6ZMGtx6P5hPTJQSh/g82eSa9d9JPTURBoREREREYnP3R8DXiZ4xrGzmf2tmkMSkQxau7GY\nj+YsCz02Mp1S+pvWwZzx4cd6H5h6fyIiIiKS9/IqiV/G3Uvc/QV3HwXsCExnyzL70bPzdwUeBuab\n2ZVm1qSawhaRarJhcwmT5y0PPTa0ZwaS+Js3wNcvV97OCmHUjRoBLyIiIiIiyTgJeIfg2cZFZva0\nmXWo5phEJAMmzlrCppLSCvubNyxi4HatUu9w7rtQvL7i/vpNocvQNCIUERERkXyXgRrT2WFmvYCz\ngdOAVpTPxI8WPTu/LXA9cKGZnePuL+UkUBGpdh/PXcam4vCb4506t6z6BWa9BRtXVt5u19Ohfd+q\nX09ERERERGo1M9sr8u0tQGNgD+BY4AgzexV4D5gHrCT8eUhC7v5eZiIVkXSM+za8lP5evdpRVJjG\nnKpvXwvf32MfKGqQen8iIiIikvfyKolv9v/s3Xm8XHV5+PHPk5UkBEJIWGRNSAiLEQUXZJGAC4sr\nomirtdUu+muttmpbl7a21taqtbVVK9qqtRutiopaAZVNAorKomzZ2HeSQEJC9nuf3x9nxjuZe+be\nmbkzd/28X6953Tnf7znP+d47I+ac53yfb0yhKKP/VuC0ajO7J+sB1gJXAWcCe9X17w98PSL+KDMt\nRydNAMvXrCttP+mIeUye1IFZ8Y1K6U+aAr27ivd77A3L3jf0c0mSJEmaCK5i9+R8teLgNODcyqtd\nySi73yNNJJnJFSvKk/gvPLqNUvq9vbDy0vK+Jee0Hk+SJEljwqgopx8RCyLib4EHgP+mSOBXM2+1\npfN/DLwJOCQzXwc8Dfht4GZ2L7cfwEcj4oRh/DUkjZBrVpcn8TtSSn/bRlh1WXnf0tf0vV/2Ppi1\n79DPJ0mSJGkiqd7vgP7LBw7lJWmE3PbQkzz65PZ+7RFw2pFtJPEfugmeKnkoICbB4pe0MUJJkiSN\nBSP2ZHZETAJeCbwNeCGNL1y3ABcCn8nMm2tjZOYW4AvAFyLiFcCngIMr3ZOAtwNv7uovImlErd20\nnTsefrK079ROJPHv+Db09L/4ZtqecNZHYe6iYvs5vzX0c0mSJEmaqEy8S+PElQ1m4T/rkDnMnTWt\n9YCNSukfciLMnNt6PEmSJI0Jw57Ej4hDKGbP/yZwQLW58rP2ifNVwAXAlzJz0MWoM/NbEXEjcAfF\nenIBvGDgoySNddfdWT4L/5C5Mzhs31lDP0GjUvpHvQxmzIHT/mjo55AkSZI00fyQNta6lzT6XbGy\nPIl/xlFtzMIHWHlJefuSs9uLJ0mSpDFhWJL4ERHASynWuj+LYpZ82az7XuA7wD9n5vdbPU9mPhAR\n/03xkADAgUMcuqRRrmEp/UXzhx78yYfh7h+W9z3jtUOPL0mSJGlCysxlIz0GSZ23fvN2br5/Q2nf\n6e0k8Z+4Fx67rbxvyTmtx5MkSdKY0dUkfkQcCPxW5VUtc1826/5RirL4F2TmA0M87Yqa99OHGEvS\nKJaZLG+QxO9IKf3bvk7p5JhZ82HBsqHHlyRJkiRJ48bVq9aSJbcRDthrD445cK/W
A666tLx938Uw\nb1Hr8SRJkjRmdHsm/n00nnUfwLXAPwNfy8ydHTrnlppzSRrH7ly7mUee3NavPQJOOmLfoZ+gUSn9\nY18Nk4d9NRJJkiRJkjSKXb6ivJT+6UftR1GotEUrv1vevuSs1mNJkiRpTOl2FmoyfYl7KBL3TwH/\nBXwmM2/p4rkDE/nSuNaolP4zDtqbOTOnDS34utXw8M3lfc84f2ixJUmSJEnSuLKzp5cfrlpb2ndG\nO6X0t22Ee5aX91lKX5IkadwbrqmkQVHm/rPAlzPzyS6e67+ABrWmJI0njUrpn9KJUvq3fLW8fZ8F\ncNAJQ48vSZIkSZLGjRvufYJN23b1a582ZRInL2qjWuCay6G3fzxmzIWDn9vGCCVJkjSWdDuJvwv4\nFsWs+yu7fC4AMnMzsHk4ziVp5Ozs6eXHd60v7Ttl0fyhBc9sXEp/6WuLev2SJEmSJEkVVzYopX/i\nwn2ZOa2NW7ArLylvX/wSl/iTJEmaALr9L77DM/OhLp9D0gR0030beGpHT7/2GVMnc/xhc4YW/MEb\n4Ym7y/uWvnZosSVJkiSpgYiYApwAnAQsBeYC+1S6nwAeB34B/Ai4ITNLpulKGglXNEjin7GkjYkG\nPTth9WXlfUvObj2eJEmSxpyuJvFN4EvqluVrykvpP2/hXKZPmTy04Lc0mIV/4HEw/8ihxZYkSZKk\nOhGxH/B24HeAZjN+ayPicxTVD8uzh5KGxf2Pb2H1Y+WFQc84av/WA973Y9i2sX/75Gmw6IWtx5Mk\nSdKYM2mkByBJ7Vi+em1p+ymL5g0tcM8uuPWi8r6l5w8ttiRJkiTViYjzgTuADwD7AVF5NTyk8toP\n+FPgjoiwZJg0ghrNwl+0354cuu/M1gOuurS8/fBTYfrs1uNJkiRpzDGJL2nMeXLbTn7+QMkT6cCp\ni9soU1fr7qvhqbIHBAKeft7QYkuSJElSjYj4IHAhRcn8ALLygr5kff2Lmv2icuz/RMSfD9/IJdVq\nWEr/qP1aD5YJK/6vvM9S+pIkSRNGV8vpR8Q+wL/Sd5F5cWZ+uc1YvwG8orLZA/xGZj415EFKGnN+\ndOd6enqzX/t+s6dz5P57Di34LV8tb19wKux14NBiS5IkSVJFRPw28MHKZm3iPoGfATcAdwPVJ5j3\nBhYAxwPPrtm3etwHI+LBzPxC90cvqWrLjl386K71pX2nL2kjib9uFTxxd3nfkWe1Hk+SJEljUleT\n+MCvA+fS94T4+4cQ68dA7YXo94HPDyGepDFq+ep1pe2nLJpHxEBVJwexYwvc8e3yPkvpS5IkSeqQ\niNgf+AS7J+F3Av8IfDoz7xvk+EOA3wP+AJhK36z8f4iIb2dm+bRgSR133Zr17NjV26999h5TePbh\n+7QecOV3y9sPWApzDmk9niRJksakbpfTr816Lc/MFe0Gqhx7DX2z+t8wlIFJGruWr2mQxF88b2iB\nV10KOzb3b588HY55Rf92SZIkSWrPnwLVMmIBPASckJl/PFgCHyAz78/M9wInAA/WdM0CPtDpwUpq\n7IqV5c/MvODI+Uyd3Mat15WXlrcvOaf1WJIkSRqzupbEj4i9gOfQ91R5gxrVLflaNTxwYkTM6kBM\nSWPIA09s4e515StpnLJoiEn8RqX0j3wJ7LH30GJLkiRJUp/X0jd7fgtwembe2mqQzLwNeCGwtSae\nZcSkYZKZXLmiPIl/Rjul9J9aB/dfX95nKX1JkqQJpZsz8Z8BTKZv5vy1HYi5vOb9lMo5JE0gjUrp\nL9l/NvvttUf7gbc8Dqu/X95nKX1JkiRJHRIRxwPV7F4CH8/M1e3Gqxz7d/Tdf9mvcg5JXXbHw5t4\neOO2fu0RsGzJ/NYDrrqMvvlQNWYfCAc+s/V4kiRJGrO6mcRfUrfd8hPlJW6r/Kz+a/bIDsSUNIZc\n061S+rdfDL07+7dP3xsWv2RosSVJkiSpT/V+
STXp/l8diPkflZ/V+yX192QkdcGVDUrpH3fwHPbd\nc3rrAVddUt5+5FkwqdurokqSJGk06ea//vapeb8pM3cNNWBm7gSerGmaO9SYksaO3t7kum4l8RuV\n0j/m5TB1CDP8JUmSJGl3tdNzd2TmmqEGzMw7gR0NziGpS65oVEr/qDZK6e/cBmuuKO9bck7r8SRJ\nkjSmdTOJP6XmfUkdqLbVxprRwbiSRrnbHnqSJ7b0ny0/bfIknrdgCM/0bLgf7m2w4oel9CVJkiR1\nVu29jK0djLulwTkkdcHjT+3gpvueKO1rK4l/zzWw86n+7VNnwoIXtB5PkiRJY1o3k/jra97vFRFT\nGu7ZpIiYCuxd07RhqDEljR3XrFlb2n78YXOYOW0I/4m59aLy9tkHwuGntB9XkiRJkvqrLS+2d0QM\nOeFeiTGnwTkkdcHVqx6jt2Ta0v57TefYp+3VesCV3y1vP+IMKwRKkiRNQN1M4tdfMJ7QgZjHV35W\n140rz+hJGpeWr25QSn9Rl0rpP/08mDR5aLElSZIkaXcP122f2YGYL6n8rN4vqT+HpA67YkX5bcnT\nl+xHRJT2NZQJKy8t71tydosjkyRJ0njQzST+zys/q8+kvqIDMV9Zt317B2JKGgO27ujhZ/eUl6k7\nZfEQlnt89HZ49NbyvqWvbT+uJEmSJJW7Duih737Je6PljF+fyrF/UtPUUzmHpC7Z1dPL1SsfK+07\nvZ1S+g//HDY9VNIRsLgTz/lIkiRprOlaEj8z7wLuqWwG8PaI2LfdeBExH/g9+i5yH8nM24Y0SElj\nxk/ueZwdPb392veeMZWlB+1dckSTGs3Cn3ckHHhc+3ElSZIkqURmbgB+RHGvJIDnAJ8YQsiPAidS\n3C9J4MeVc0jqkhvv28CT23b1a582eVJ71QJXXlLefvBzYM8hTFyQJEnSmNXNmfgAX6e4IE1gNvDV\nyrr2Lakc85VKjGq8b3RwnJJGueWry8vUnXTEvkye1Oakld5euOVr5X1LXwvtT4aRJEmSpIH8XeVn\nUtzneGdEfC0iDmw2QETsHxEXAu+uiQPw8Y6OVFI/V6won4X/vIVzmTV9SusBV363vN1S+pIkSRNW\nt5P4HwO21myfBvwgIg5pNkBEHApcXjm2Ogt/O/CRTg1S0uh3zep1pe2nLG7jCfeq+6+HjfeV9y19\nTftxJUmSJGkAmfkt4Cr6JioEcC6wOiL+IyJeHRGH1R8XEYdFxLkR8WVgDXA+fcn7BK7MzG8Px+8g\nTWRXrHi0tP2Mdkrpb3wQHvlFed+Sc1qPJ0mSpHGhjUdDm5eZj0XER4AP0XdReipwR0T8F/DfwE8y\nszbRT0TMAJ4LvAH4VWBGtasS5+OZ+WA3xy5p9Fi7aTsrHtlU2nfqoiGUlWtUSv/g58Dche3HlSRJ\nkqTBnU+xdv0i+u6ZzKS4D/KrABGRwFOV/j3pS9jD7sn7AFYBrxuOgUsT2QNPbGHVo5tL+9pK4q9q\nUEp/nwUwf0nr8SRJkjQudDWJD5CZH46IZ1E8UV57UfpblVdPRDwCbKj07wMcAEyuhKgm7qvHXpyZ\nH+z2uCWNHteuKZ+Ff+jcmRy678z2gvbshNsarMqx9LXtxZQkSZKkJmXmuoh4EXARcAJ91QfrE/Wz\nG4Wo2ednwGsyc303xiqpz5UNSukvnD+Lw/ad1XrAlQ2S+EvOdpk/SZKkCazb5fSr3gj8D/0T8kHx\nIMHBwNOBpZX3U2r6ay9KL6SYnS9pAulKKf01l8PWx/u3x2Q49tz240qSJElSkzLzPuD5FBUMN7L7\n7PqBXlT23Qj8BfD8SixJXXZFgyT+GUvamIW/fRPc/cPyviVntx5PkiRJ48awJPEzc2tm/irwdmAd\nrV+UrgXemplvqC+9L2l8y0yWr1lb2nfqoiEk8RuV0l+4DPZs48JbkiRJktqQmbsy8y8oJjW8Ffgm\n8Bh9kxvq
X49V9nkrcHBmfigze4Z94NIEtHVHD9fdWV7woq1S+ndeCT07+rfvsTcc+vzW40mSJGnc\n6Ho5/VqZ+c8R8UWKMvrnAifSt959va3Aj4CvAV/KzO3DM0pJo8maxzbz6JP9/+c/KeCkI9pM4m/f\nDCu/W973jPPbiylJkiRJQ5CZW4B/qbyIiHnAXIplBwGeAB7PzPJSZZK67kd3rWP7rt5+7bOnT+HZ\nh89tPWCjUvqLXwKTp7YeT5IkSePGsCbxATJzG/Bp4NMRMRU4BphPcWEKsJ5i5v0dmblzuMcnaXRp\nVEp/6cFz2Htmmxe0K/4Pdm7p3z5lBhz10vZiSpIkSVIHVZL1JuylUeTyO8pL6Z965DymTWmx4Glv\nD6y6tLzvyLNaHJkkSZLGm2FP4teqJOl/PpJjkDS6LV9Tfs+qK6X0l5wN02e3H1eSJEmSJI1LmcmV\nK8qT+KcvaaOU/v0/ga2P92+fNAUWvaj1eJIkSRpXWnxEVJKGz45dvfz4rvK15k5Z3GYSf/NauPOK\n8j5L6UuSJEmSpBIrH93EQxu3lfYtayeJv6pBKf3DToYZc1qPJ0mSpHHFJL6kUeum+55gy46efu0z\np03m+EP3KTmiCbd9A7J/TGbsA0e8sL2YkiRJkiRpXLuiwSz84w7em/mzp7cecGWDJP6Ss1uPJUmS\npHHHJL6kUatRKf3nLZjb+lpzVY1K6R/zKpgyrb2YkiRJkiRpXGtYSv+oNmbhr78T1q0q7zvyrNbj\nSZIkadwxiS9p1LpmdXkS/5TF89sL+Pjd8MBPyvsspS9JkiRJkko88dQObrj3idK+Fx61f+sBG83C\n3+8YmLug9XiSJEkad6aM1IkjYjpwADAHmE0bDxRk5g87PCxJo8TGLTv5xQMbSvtOXTyvvaC3fK28\nfe9D4JAT24spSZIkSZLGtR+uXktv9m+fP3s6xz5tr9YDWkpfkiRJgxjWJH5EHAG8DTgdWDrE8+cQ\nj5c0iv3ornWlF8j77zWdxfvt2XrATLjlK+V9Tz8PJlmYRJIkSZIk9XdFo1L6S+YzaVK0FmzL43Df\nj8r7jjSJL0mSpMKwJMEjYibwaeBNQFReY1JETAaeAzwdmEfxu6wFbgF+lpk9Izg8ACLiMIox7k9R\n6WALcDfwk8x8qAvn279yvoMq5wPYADxYOWf5lc4YMRY+8/GoUSn9kxfNI6KN/4Q88ovG681ZSl+S\nJEmSJJXo6U2uXrW2tO+Mo/ZrPeDq70PZraRZ8+GgE1qPJ0mSpHGp60n8iJgLXAMcRZH8zMqLmm0a\nbPcL140xNiMi9gL+BPgdikRumXUR8Xngo5n55LANDoiIaRRjezuwZID9rgM+lpkXD/F8M4DfBN5K\nkdweaN9bgAuAL2bmthbOsQy4sv1RAnB1Zi5r58DR/pmPd8vXlCfx2y6l/4sGs/D3Oxb2P7a9mJIk\nSZIkaVy76b4n2LBlZ7/2qZODUxbPbz3gqgal9I88yyqBkiRJ+qWu/sswiumyXwGOpi9BH8BOilnM\nq9l9Zv7VwI3AfTXttYn72yr7XA38sJtjrxURJwJ3AO+ncTKXSt/7gdsrxwyLiFgK3A58igES+BUn\nAd+MiK9FxOw2z/d84OeV8w2YwK9YCnwGuCkixsQjxaP9Mx/v7n98C/eu31Lad/KiNpL4vT1w60Xl\nfUtf03o8SZIkSZI0ITQqpf+8Bfuy5/QW50ft2gGrf1Det+ScFkcmSZKk8azbj3e+CjiDvtn1vcCf\nA/Mz85nA39funJmnZ+ZzMnMBsB/wFuAX9CXy9wP+urLf6V0eOwARcTJwOfC0ku5twPaS9oOAyyPi\npG6ODX45W/1a4IiS7h7gCYq/e73zgO9HREuLi0fEaRR/j8UNdtkEbKS8osJRwFUR8bxWzjncRvtn\nPhE0KqV/1AGz2W/2Hq0HvPda2PRweZ9JfEmSJEmS1ECjJP7p7ZTSv3c57N
jUv33KHrBwWevxJEmS\nNG51u5z+uys/q7Pwfyszv9zMgZm5Dvi3iPh34E+BD1LMev6/iDg3M7/bjQHXioj9gIuAmTXNO4F/\npCgPf1el7QjgbcA7gKmVtpnARRFxXLfWhI+IQ4CvAvUz6i8EPg1cn5k9ETEFOBn4Q+CVNfs9D/gC\n8LomzzeX4u8xo67rx8DHgCsyc2Nl3z2BUym+Ay+s2XdP4OKIOLKN8vOrqXvwowkPtbLzaP/MJ4rl\na8rXmjulnVn40LiU/qEnwZxD24spSZIkSZLGtQc3bGXFIyVJd+CMdpL4KxuU0l+4DKbNLO+TJEnS\nhNS1JH6lVPuJ9M3IvrzZBH6tzOwFPhQRWygSxVOB/46IYzPzwY4NuNxfAfvXbG8FXpWZ36vbbw3w\nnoj4AfB1+pLcBwAfokj2dsO/sHup9x7gzZn5H7U7ZeYuKssQRMQ7KBLSVedHxBdKfqcy7wP2rWv7\nLPD2yudUe87NwCXAJRHxlxQVGKr2p1hr/gNNnLPWQ5l5QYvHtGq0f+bjXk9vcu2a9aV9pyxuI4m/\nazvc/q3yPmfhS5IkSZKkBq5sMAt/wbxZLJg3q7VgmbDy0vK+JWe3ODJJkiSNd90sp39iJX61FP7n\nhxIsM/+Oomw8FDPPPziUeIOJiIXAm+uaPzBQsjszLwX+rK75LRGxoAvjezZwZl3zR+oT+PUy858o\nZpTvdlxERNn+depn7P8C+P36BH7JOT8I1P/dXt/E+YbVaP/MJ4pbH9zIxq07+7VPmzyJ5y2of4ak\nCau/B9s39m+fNAWOPbeNEUqSJEmSpImgURL/9CVtzMJ/9DbYeF9535FntR5PkiRJ41o3k/gH120v\nH+yAiJg6yC6frO4K/GpETGtjXM16F31l0qGYef1PTRz3ycq+VVMpyth32u/Uba8DPtzkse8DttRs\nH0//BwJ2Uyndf0hd8wWZ2dPkOev/dgsjomzN+ZE02j/zCWH5mnWl7Scctg8zpk1uPWCjUvqLXgwz\n57YeT5IkSZIkjXvbdvZw7Z3l9yg6Wkr/acfD7ANajydJkqRxrZtJ/Nrs2PbMfKRkn11129MHifn9\nmvczKGb7d8sr67a/0EzCurLPF+uaX9WpQdU4o277wszc3syBmbkB+EZd86sHOazsauInzZxvgH1H\n2xXKaP/MJ4RrVq8tbW+rlP62jbDqsvI+S+lLkiRJGiER0c37MZI64Ed3rmfbzv7FJ2dNm8xzF7Qx\nKWDld8vbl5zTeixJkiSNe928aKydJb+5wT6b6rbnDxQwM5+sO+aoNsY1qIg4gf6VBBpM5y1Vv+8h\nEXH80EbVJyLmAkfUNV9btu8A6vd/5SA3Ecr6SmqUN7ShpG2wygvDZrR/5hPFlh27uOHeJ0r7Tm0n\niX/Ht6Gn5NmWaXt6kSxJkiRpJN0fEX9ZqXonaRS6okEp/VMXz2falBZvqW56BB66sbxvydktjkyS\nJEkTQTeT+LXJ9pkN9nmybrs+iVqmF8jK+27Vwj69bvuRzLyr2YMz807g0brm+pnzQ1FWs2tNSdtA\nVpfEXDrA/g+WtLXy9y9bzPzhFo7vttH+mU8I19/9ODt7sl/7nJlTOfZpe7cesFEp/aNeBtMa/WdJ\nkiRJkrruQOBPgbsi4lsRcU5ExEgPSlIhMxsm8dsqpb/q0vL2vQ+F/Y9tPZ4kSZLGvW4m8WsTtDMa\nrHdfTZJWs3YnDBQwIuYAtZm8/jWtOuOYuu1WysZXXV+3fXSbYylTljxvZVZ8o/3rf+9fyswH6Pu8\nqk5q4Xz1+z6Umfe1cHy3jfbPfEJYvrp8rbmTj5jH5Ekt3s/a9Ajc/cPyvqWvbXFkkiRJktQVk4GX\nAt8G7o6ID0TEaFt6TppwVj+2mQc3bC3tW3bUgIVEy628pLx9yVng8zuSJEkq0c0k/sq67bLS92uA\n2n8RD1bf+mWVn9V/3ZZn/IaufqxNz8
iucfcgMYeipD4401uMsUdJ22BJ58/Xbf9uREwr3bNGZTbB\nu+uavzTYcQ1iTY2I50TEqyPiLRHxmog4IyKaqeIwkNH+mU8IjZL4p7RTSv/Wi+h7PqjGrPmwcFnr\n8SRJkiSp86oXLQEcCnwIuDcivhoRLx65YUkTW6NZ+EsP2pv9ZpfdUhvAji1w11XlfZbSlyRJUgPd\nTOKvALbUbB9Xv0Nm9gJXU1ysBvDCiDilLFhE7A38Jbtn5W7u1GDrHFm3fX8bMeqPqY85FGWLhrf6\nGHDZ/ksGOeaTwG0124uBL0dEwwcIImIy8E/sPhP/AeBjzQ1zN8+hqCDwE+Ai4AvAV4HLKdYTvDMi\n/jEiFrQRe7R/5uPeY09uY+Wjm0r7TlnURhK/USn9Y18Nk6e0Hk+SJEmSOufFwNeBXRT3Q7LyCmAq\n8Grg0ohYHRHviYg2LookteuKOzpYSv+uq2DXtv7t02bDYaW3QSVJkqTuJfEzcxewvKbprAa7Xlg9\nhOJi9ZsR8cbaGd4RcRrwQ6A2OfswcFPnRrybfeq2H2kjRv1673PaG0qph4CddW3PajFG2f4DrnGf\nmduBs4FVNc2vB26OiLdGxOKImBER0yNiQUT8BvBT4O01+68HXpqZT7Y4XoCZwIwB+hcC7wBWR8Sn\nBnq4oMRo/8zHveVrymfhH77vTA6Z2+L69etWw8M3l/dZSl+SJEnSCMvMyzPztcAhwJ8C99JXdbB2\ndv4RwEcpHlz/r4h4wbAPVppgNm7ZyQ33lc2faTOJv/K75e2LXwRTBi1wKUmSpAmq29NRLwFeUnl/\nVkRMzcz65POFwPspZi0nRSL5y8C/RsQ6YC9gFn1Ppld/fiIzS2plD01EzKD/ww1byvYdRP3CWVMi\nYo/MLHn0tjWZuS0ibgBOrGl+OfAPLYR5eUnbnk2c+/6IOIGixN/bKJLqRwEXNHHOS4C3ZeZ9LYyz\nHZMpHhw4KSLOzszyx6crRsNnXvk8yxy1adMmrrrqqjaG0xmbNhWz47s9hq/9omyVCFg4c0fL5z78\n7v/m8JL2rXscwPVrNsOdrcVT+4br+6Pxye+PhsrvkIbC78/oV/2MpLGscr36NxHxEeBMiuvsl1Jc\n19Ym86dTPET/+ohYSXEN/u+ZuWHYBy2Nc1evXktPb/9bjvP2nM7Sg/ZuLVhvL6y6rLzvSEvpS5Ik\nqbFultMH+ArQS3HBORd4Q/0OlRn7v03fOu/VRP004GkUieVq4r7qaorS7t1QlshuJ/FedsysNuI0\nUn8FsCwijm/mwMq6ev2WN6CJJD5AZm7OzHcB5wF3NnHI/cA5mXlOmwn8dcC/A78OHA/Mo/h+zKF4\ngOAtwPdKjjueorLDYIuVjZXPfNzKTG5f31Pad+y8ya0GY/9Hry7tenT/F0BEaZ8kSZIkjZQsXJqZ\nrwIOBf6CYim6stn5R1E8xP9gRHwpIk5EUsdcuaJ8LsiyJfOZNKnFewoP3QhPlcSLybD4xW2MTpIk\nSRNFV2fiZ+bDEfEKYHalqXSd8cxcHhHnAv8J7MvuCfuq6r+SLwFe341Z+BVlCd8dbcQpm1Y8UCn4\nVl0AvI8imQ3F3+dLEXHqQKXqI2I+8LkG3U2NLyJOobhh8Owmx3oI8H8R8XXg/Zm5arADKh4CfgW4\nqKSCA8DGymslxe/+AuC/gYNq9nk+RdWAPx7gPCP+mWfmCWXtEXHD7Nmzj1+2bFkbw+mM6uyzbo5h\n5SOb2HDZD/u1Twr4rVecxt4zpjYf7IEb4Ory1RAOf9kfcfj8I9sdptowHN8fjV9+fzRUfoc0FH5/\nRr/Zs2cPvpM0BmXmw8CHIuLDFLPy30qxROEkdk/mzwDeBLwpIm6huE/wn5m5efhHLY0PPb3JVSvL\nk/gdLaV/6PNh5oCrWkqSJGmC6/ZMfDLzu5n5v5XXdQPsdxlFSf2/Bm6vNEfltQ24FDg3M1+amd2s\nm1
g2m7qdBarK1mMfcin9qsx8BPh0XfMzgMsj4qiyYyLiWcBVwIIGYQe90I+It1di1CfwfwR8HHgX\n8G7g7yptvzyUYub+TRHxmsHOA5CZqzLzfxok8Mv2/yFwEv3Xs397RDxtgEPHxGc+nl2zem1p+3GH\nzGktgQ9wy1fK2w88DkzgS5IkSRojMrM3M7+dmS+juI7/a+Bhdp+dX71v8gzgM8BDEXFBs5X6JO3u\n5vs38MSW/rehpkwKTl08r/WAKy8pb19yVuuxJEmSNKF0dSZ+qzLzCeDPgD+LiMkUZdN3ZubjwziM\nskT2YOXYy5Qd0+mn4d8HnAbUzuJ+NnBLRFxOkUR/nKK6wanAMvoe3OihSMa/sObYDQOdLCJeB3yq\nrvlG4C2Z+fMGxzwT+AJFaXuAmcCFEbExM78/0PnakZn3RcRbgNpHnWdQlOL/SIPDxtJnPi4tX7Ou\ntP3URS1eIPfsglsvKu9ben6Lo5IkSZKk0SEz76e4V/IXwCuAtwEvqnbTl8zfk2LJwt+OiBsoEvsX\nZmY71eakCeeKFY+Wtj93wVxm79HiJIMn7oHHbi/vW3JOa7EkSZI04XR9Jn67MrMnMx8d5gQ+mbkV\n6K1rntlGqPoy6j2Z2dFZ2ZWL8HOA6+u6pgBnUqyh90/AB4Ez6Pu8E3g7cEvdcRsanSsi5gD/XNd8\nPXBKowR+ZYw3A6cAP6kb3xcjop2/66Ay8xLghrrmlwyw/5j5zMej7bt6uP6u8v+Zn7J4fmvB7r4a\nniqb1R/w9Fe3PjhJkiRJGkUq90q+kZlnAk8HVlAk77PmVU3oPxv4IvBARLw/ImaN0LClMeOKFeWV\nAtsrpX9pefu8I2HfI1qPJ0mSpAmla0n8iJgdEc+oeR00+FGjxoa67QPaiHFg3fYT7Q1lYJn5GHA6\n8FFgSxOHPAC8ODMvoP8YHxjguLcAtYt19QC/VkmADzbGrcAbK8dUHUyxdl+3fLtu+zmD7L+hbnvU\nfubjzY33bmDrzp5+7bOmTeZZh85pLdgtXy1vX3Aq7DXQigqSJEmSNDZExJER8QngGmAJReK+Xm1C\nfx7wV8CaiHjlsA1UGmMe3riVOx5+srTv9LaS+N8tb19yduuxJEmSNOF0cyb+64Gbal5ndvFcnbaq\nbvuQNmLUH7O6zbEMKjO3ZuZ7gSOAdwDfAe4CNgFbgTUUSe03AEdm5uWVQ4+uC1U/e73WS+u2v5eZ\nTf9OlX3ry+d38+bBirrtWRFRP1O+1pj6zMeT5WvKn3I/ceG+TJ3cwn+idmyBO+qf3aiwlL4kSZKk\nMSwipkTE6yLiCuAO4A/Y/UH76uz7dcBXKe4H1M7QD2B/4OsR8a5hHLo0ZlzZYBb+YfvOZOG8FgtZ\nbN0A915b3nekSXxJkiQNbkoXY8+nuEiE4oLxG108V6etAE6s2V7YRowFddt3tD+c5mTmIxRr1tev\nW99PpYzesXXNPxvgkKV129e1NrpfHnNWzfYz24jRrLL67PtQPNRQZkx+5uPB8tXrSttPWTyvtUCr\nLoUdm/u3T54GR7+8jZFJkiRJ0siKiAXAW4HfoLjPAv2T8wA/plgC7yuZuaOyfN2vAL8LPIu+2foB\nfDQirs7MgR7klyacK1Y8Vtp++pL9iIjSvobW/AB6d/VvnzEXDnluG6OTJEnSRNPNmfi1/1Ldkplj\nqbT47XXb7fzr+nl126MtoXsyMLlmewP9f+9a+9Rtl1/ZDOzRuu1924jRrL1L2jYOsP9E+MxHnQ1b\ndvCLB8s/llNbTeI3KqV/5JkwY05rsSRJkiRphETEpIg4NyIuo6jw9kfAfuw+USIoHlL/AnB8Zp6U\nmf+ZmTsAMnNLZn4hM08AXsXuy+dNAt4+PL+NNDZs29nDtWvKJxm88Og2SumvurS8/cizYNLk8j5J\nkiSpRjdn4tcmecvWZxvNrqzbPiAiFmbmXc0cHBELKcrUDRRzpNWv
R/9fmdl/YfI+m4E5NdsDlaZv\nZGbd9lNtxGjW4rrtrZk50Pkmwmc+6vzozvVkyX8dDthrD46Yv2fzgbY8DqvrV2uosJS+JEmSpDEg\nIg4Bfhv4TeCAanPlZzVxHxTLwV0AfCkzB3pYvTgw81sRcSPFg+YzKzFe0NnRS2Pbj+9az9ad/W+L\nzZw2mecumFtyxAB6dsLq75X3LbGUviRJkprTzZn4t9a8nxURs7t4rk67AXiwrq2VTODr6rYfGE1l\n6iJiP+DcuuYvDHJY/cJg9UnyZhxZt13+iHNn1F8V/WKQ/cf1Zz5aXdPgKfdTFs9rrVTd7RdD787+\n7dP3hsUvaXN0kiRJktRdUXhZRHwbuAv4AHAguyfvAXqBi4EzM/OozPxkMwn8qsx8APjvmrgHduQX\nkMaJKxuU0j9l0TymT2lx5vx9P4JtJf/znDwNjjijjdFJkiRpIupmEv8Gdp+N/6IunqujMjMpLo5r\n/WZEDPqv9so+b6lrro810j7O7rPir8nMmwY5pr7/pc38PaoiYgrw0rrmnzd7fCsi4mTg1LrmywY6\nZgJ85qPS8tXlSfyOldI/5uUwdY8WRyVJkiRJ3RURB0bEnwH3UFw/nkOx5F11vXsq7x8D/gZYkJnn\nZmaDEmRNWVHzfvoQ4kjjSmZyxcryJP4ZR7VRSn/lJeXtC14A01uoOihJkqQJrWtJ/EpS9IKapvd0\n61xd8g/ArprtRcA7mjjuDyr7Vu2qxGooIpZFRNa9lrU23OZExK+zeyn9ncDvNnFofR2ww4H/18Kp\nfx84tK6tYWI9WpqGvdtx84Av1zX3UMw4GMywfeaC+9Zv4b7Ht5T2nbyohST+hvvh3mvL+yylL0mS\nJGl0ug/4C+AQ+s+6D+Ba4A3AoZn5p5WZ9ENVvQAba0seSl1159rN3P/41tK+01tN4mc2TuJbSl+S\nJEkt6OZMfICPUZSDC+DEiPhQl8/XMZm5Bvi3uuYPR8SLGx0TEWcCf1XX/KXMvLPDw6ue74Bmk/0R\nMTki3gV8sa7r45l5a9kxdf6H3SsrAHwiIs5r4ty/QvFdqLWuErORf4yIj0RE/TrzA53nGOBq4Ii6\nri9m5srBjh8Ln/l4cs2a+hUaCkcfuBfz9mxhUsitF5W373kAHH5KGyOTJEmSpK6rVn1L+ta73wJ8\nHjguM0/NzAszs2TdsCFr66F5aby6/I7yWfhPP2gv9t+rxep+a1fCE3eX9x1pEl+SJEnN62oSPzO3\nAC8DHqa4SPxARHy5Mlt6LPgzdk9czwS+ExEfi4gF0eeIiPg48G1gRs3+jwF/3sXxHQBcGRF3RMRf\nR8TpEbFPtbOSuD88In6PYnmDT7D7Z/4Diif/B5WZT1H8PWpNA74WEV+LiBdFxC9L9EfErIh4SUR8\nk2IW/JS6Y/88MzcNcMq9gPcCD0TE9yLi9yLi+RExu3aniNgrIs6KiH8DbgaOqYtzB/AnzfyOFaP9\nMx83ul5Kf+lrYFKL69ZJkiRJ0vAKYCXwTuCgzHxbZt7SpXP9F7Cg8lrYpXNIY84VKxqU0l/STin9\n75a3H/AM2Pug1uNJkiRpwqpPrHZURBxK8ST5+cDnKBKsbwTOj4hvAVcCtwBPAJtbjZ+Z93VutKXx\nH4mI11CUfa8maqcBf1R5bau0lT2WuxU4LzMf6eYYK44C3l95ERFbK+efQ+MHNb4PnNvKU/2Z+fmI\neDpFafxa51VeRMRGipsQew0Q6rOZ+dkmTzsFeHHlReUcO4BNFAn2GQ2OA7gTOCszn2jyXGPpMx/T\nenqT6+5cX9p3Siul9B+9HR5tUEhi6WvbGJkkSZIkDYtdwLeAz2TmlcNxwszcTBv3XqTxbOPWnfzs\n3vLbRi2X0gdYdWl5+5JzWo8lSZKkCa2rSXzgHnZfa61aIm468JrKq11J98dPZl5TKaf+VeDAuu5G\nNbUeBl6Tmdd1dXCNzaBxcnsH
8LfAhzKzp43Y76RIjv8t5b//3gMcux14H/DJNs5baxqw7wD9SVEW\n/x2VmxQtGaOf+Zhyy4Mb2bi1//Mj06ZM4rkL5rYQqMEs/H0Xw4HHtTk6SZIkSeq6wzPzoZEehDTR\nXbN6LT292a9931nTOO7gOa0F27wW7v9Jed8SS+lLkiSpNV1PgrP7WmvVtd7q20e1zLw2Io6mKO/+\nO0CjLOPjFOvX/W1mbhyGod0NfAg4B3gWfWvqNRrbhcDfZ+Zd7Z4wM5NivfpvAv8PeBP9E931HgH+\nHfjnzLy3yVP9A3AfcBpwPLBnE8c8BHwD+HRmrmjyPKVG8Wc+Lixfvba0/TmH78MeU5ssgd/bC7d8\nrbzvGedDjJn/xEiSJEmaYMZjAj8iJgPPAZ4OzKO477OWogLjz9qcSNBREXEYxRj3p6heuIXi3spP\nuvGZRMT+lfMdVDkfwAbgwco5y+u4t3++Uf8ZjDaNSumftmQ+kya1eF9h9WXsPpepYvbTnGggSZKk\nlg1HEr/kX68Dtjdj2LNzlQTt+yLizyguiJZSXBBBcUF0K/DTzNzVRuyraON3qozpg8AHI2IWcBxw\nBLAfRan57RQJ9NuAmzKzt9VzDHDueykS3O+NiAUUifb9KC5KE9hIsT78jZl5dxvxfw78HCAiAlhE\nsWbfwcA+FDPit1MsxbCucp6OLq/Qzc98ortm9brS9lMWzW8+yP3Xw8YGH/nSoRT5kCRJkiQ1KyL2\nAv6E4gH4RuujrYuIzwMfzcwnh21wQERMoxjb24ElA+x3HfCxzLx4iOebAfwm8FaKZPpA+94CXAB8\nMTO3DbTvIHFG9WcwWvX0JlevLJ9kcEY7pfRXXlLevuQsJxpIkiSpZd1O4t/H0JL1o04lYfujymvU\nyMyngOsqr+E+990UT653K34CqyuvYTdaP/Ox6qntu7jxvvL15k5d3OheQ4lGpfQPejbMXdjGyCRJ\nkiRJrYiIE4GLgKcNsus84P3Ar0fEazLzx10fHBARSykq9h3RxO4nAd+MiIuAN2fmpjbO93zgy8Di\nJg9ZCnwG+P2IeGNm3tDGOUf1ZzCa/fyBDax/ake/9imTglMXtzDJAGDnNrjzivK+Jee0MTpJkiRN\ndF1N4mfm4d2ML2ns+cndj7Ozp/+zPfvMnMoxB+7VXJCenXDbN8r7nnH+EEYnSZIkSd0XEfsA/0pf\nVb6LM/PLbcb6DeAVlc0e4DcqD/p3VUScDHyPohJgvW0Uv9v0uvaDgMsj4sWZ2dVJCBGxDPgWMLuk\nuwd4EtgbmFTXdx5wcES8KDM3t3C+04BLgBkNdtkE9AJ70b8a41HAVZVzXt/COUf1ZzDaXdmglP6z\nD9+HvWdMbS3Y3T+EnVv6t0+dBYef2sboJEmSNNHVX6hIUlc1KqV/0qJ5za83t+Zy2Pp4//aYDMee\nO4TRSZIkSdKw+HXgXOCVFAn4phO3JX5cifNK4NXAG4Y8ukFExH4Us79rk8c7gb+jWA5vJkUyezHw\niUpf1UzgokqMbo3vEOCr9E/gXwicDEzPzLkUCe5lQH0J/ecBX2jhfHMp/h71CfwfU3wmczJzr8yc\nQ5HEPwe4vG7fPYGLK6XxmznnqP4MxoIrGiTx2yul/93y9kVnwNQ9Wo8nSZKkCc8kvqRhtXxN+Xpz\npy7qQCn9hctgzwl9D0KSJEnS2FBbQmx5Zq5oN1Dl2Gvom93d9SQ+8FfA/jXbW4GXZeYfZead2WdN\nZr6H4kGFrTX7HwB8qIvj+xd2Xxu+B3hTZv5qZl6XmT1QLJ+XmVdn5quAd9bFOD8iXtLk+d4H7FvX\n9lng5Mz8RmZurDZm5ubMvCQzX0T/v8H+FGvbN2O0fwaj2qNPbuO2h54s7Ws5iZ8Jqy4t7zvy7BZH\nJkmSJBVM4ksaNo8+uY1Vj5ZXIzxlcZNJ/O2bGz/hbil9SZIkSaNcZab1c4DqOmMNnlJuydeq4Y
ET\nI2JWB2KWioiFwJvrmj+Qmd9rdExmXgr8WV3zWyJiQRfG92zgzLrmj2Tmfwx0XGb+E3BB/XER0UzJ\nuNfVbf8C+P3M7B3knB+kKIdf6/WDnWy0fwZjQaNS+ofMncER8/dsLdjDN8Omh0s6Ao6s/ypKkiRJ\nzTGJL2nYLG9QSn/BvFkcvE/ZEn4lVvxf+TpzU2bAUS8dwugkSZIkaVg8A5hM38z5azsQc3nN+ymV\nc3TLu4DaBcPXAP/UxHGfrOxbNRX4w84N65d+p257HfDhJo99H1B7wXk8/R8I2E2ldP8hdc0XVGf7\nN6H+b7cwIp42yDGj/TMY9S5vkMR/4VH709xzGzVWXlLefsjzYFYLVQclSZKkGibxJQ2b5WvKk/in\ndKKU/pKzYXr9coeSJEmSNOosqdu+tQMxb6v8rM7uP7IDMRt5Zd32F5pJWFf2+WJd86s6NagaZ9Rt\nX5iZ25s5MDM3AN+oa371IIcdUNL2k2bON8C+ZTFrjfbPYFTbvquHaxvcnzi91VL60Lha4BJL6UuS\nJKl9JvElDYvMbJzEb7aU/ua1cOcV5X2W0pckSZI0NuxT835TZu4aasDM3AnULvA9d6gxy0TECcDB\ndc1faSFE/b6HRMTxQxtVn4iYCxxR19xqpYP6/V8ZEQPdPyvr29jC+TaUtE0taQNG/2cwFlx/1+Ns\n2dH/mYcZUyfzvAUt/k9nw/3wyC3lfSbxJUmSNARTuhk8Ig7tZvzMvK+b8SV1zspHN7F2U//JD5Mn\nBc8/Yt/mgtz2DSibXDBjHzjihUMcoSRJkiQNi9p7Mdlwr9bVxprRwbi1Tq/bfiQz72r24My8MyIe\nBfavaT4DuLETgwPKplGvKWkbyOqSmEuBnzfY/8GStlYywWUXxGULrFeN9s9g1LuiQSn9kxfNY4+p\nk1sLturS8va5C2FeNwtiSJIkabzrahIfuIfOXpDWSro/fkkdsnx1+Sz84w7em732aDjJYHeNSukf\n8yqYMq29gUmSJEnS8Fpf836viJgy1Nn4ETEV2Ju+ezAbhhJvAMfUbbdSNr7qeuAVNdtHtz+cfsqS\n563Mim+0/zE0SOJn5gMRcRewsKb5JJr/25xUt/3QIJNWRvtnMKplJleuLE/in9FWKf1LytuXnAMR\nrceTJEmSKoajnH508SVpjLimQRL/lMXzmwvw+N3wQIN7E5bSlyRJkjR21F8cndCBmNVy6NV7JWs7\nELPMUXXbTc8Ar3H3IDGHon/5N5jeYow9StoGS3J/vm77dyNi0CfNIyKAd9c1f2mQw0b7ZzCq3bXu\nKe5dv6W0r+Uk/vZNcM815X1HntXiyCRJkqTdDUcSPzv4gu7N7JfUJdt39XD93etL+05dPK+5ILd8\nrbx970PgkBPbHJkkSZIkDbvqjO7q/Y1XNNqxBa+s2769AzHL1NcHv7+NGPXHdLLm+BMlbU0+OT7g\n/ksGOeaTwG0124uBL0dEwwcIImIy8E/sPhP/AeBjg5xrtH8Go9oVd5TPwj/mwL04YO+y5zcGcOcV\n0LOjf/sec+BQ71NIkiRpaLpdjv4+Wk+6TwLmALNr2qoxHgO2Dn1YkobTDfc+wbadvf3a95w+hWce\nMmfwAJlwy1fK+55+HkwajueRJEmSJGnoMvOuiLgHOIxi5vzbI+LvM7P8yedBRMR84Pco7p0ExRrp\ntw18VNv2qdt+pI0Y9eu9z2lvKKUeAnYCtWu2PQu4qoUYzyppG3CN+8zcHhFnAz+gLyH+euCZEfFJ\n4AqKBH0v8DTgNOAddedaD7w0M58cZHyj/TMY1a5YMQyl9Be/BCY3uWygJEmS1EBXk/iZeXi7x0bE\nHOBU4K3AORQXo08C52XmLZ0Yn6ThsbxBKf0TF85l6uQmEvCP/ALWrSrvs5S+JEmSpLHn68C7KO51\nzAa+GhFnZubOVoJExFTgK/RNhEjgG50caM25ZtC/omN5Xf
KB1U/OmBIRe2TmtvZG1iczt0XEDUDt\nNOiXA//QQpiXl7Tt2cS574+IE4APAW8DZlCUqb+giXNeArwtM+8baKex8BmMZk9u28lP73m8tO/0\nVpP4Pbtg1WXlfUvObnFkkiRJUn/dnonftszcAHwb+HZEvBz4d+AI4OqIeHFm3jCS45PUvOVrypP4\npyxqspT+LxrMwt/vWNj/2DZHJUmSJEkj5mPA/6Nv/fXTgB9ExBszs6ny6BFxKPCfwCn0zcLfDnyk\n88MFyhPZ7SR9y46Z1WasMpexexJ/WUQcn5k3DnZgRLwYOK6ka9AkPkBmbgbeFRHfBz5FcR9rIPcD\nb83MBlO6mxrHsH4GlYckyhy1adMmrrrqqjaGM3SbNm0CGPD8P31kF7t6+xcMnT0VNt51M1fdHU2f\nb+8Nt/Gsrf0fCOiNyVz7yHR61jUeh0anZr5DUiN+fzQUfn80FH5/Rr/qZ9SOMVGDOjO/DbyaouzY\nHOBbETFgKTNJo8MTT+3glgc3lvadsriJpQl7e+DWi8r7lr5mCCOTJEmSpJGRmY9RJNuDvgT8qcAd\nEfG5iDitMut6NxExo9L3eYp170+udlXifDwzH+zSsMsWDC9ZEHxQ20va+v2uQ3ABu48rgC9FxF4D\nHVRZluBzDbqbGl9EnBIRPwW+y+AJfIBDgP+LiK9FRDPr0o+Vz2BU+vnantL2p8+fzKRoPoEPsO/6\nn5a2b5jzdHqmzGp5bJIkSVK9UTsTv15mXhkR/wH8BnAA8JfA74/ooCQN6ro715P9H3TnwL334Ij5\nTVzY3nstbKpfrq/CJL4kSZKkMSozPxwRzwLOpS+RPxP4rcqrJyIeATZU+vehuB8yuRKimrivHntx\nZn6wi0Mum6U9rY0405uM3ZbMfCQiPk2xXEHVM4DLI+LXMnNF/TGVz+E/gQUNwm4e7LwR8Xbgk/R9\nPlU/ApZTrEMfwIEUD188v3oocB5wdkT8emZ+bYDTjPhnkJknlLVHxA2zZ88+ftmyZW0MZ+iqs88a\nnb+3N3nP8h+U9v3qac9g2XFPa+2Et7yntHnuiW9g2fPKx6DRbbDvkDQQvz8aCr8/Ggq/P6Pf7Nmz\nB9+pgTGTxK/4V4okfgBvjog/zsz6dbwkjSLL16wtbT9l0TyimSfdG5XSP/QkmHPoEEYmSZIkSSPu\njcAXgNdTJOOhuOcBxT2bgyuvMrX7X0iR+O+mskR22czwwZQdM2iSvEXvo1iioDbh/Gzgloi4nCKx\n/jiwL0UFhGX0VavsAa4CXlhz7IaBThYRr6Mon1/rRuAtmfnzBsc8k+KzP77SNBO4MCI2Zub3G5xq\nLH0Go8ovHtzIus39ixZMnhS84MgmqgTWWrca1q8u7zvyrDZGJ0mSJPU3Jsrp17ievqfMZ1Cs+yZp\nlMpMrlm9rrTvlMXzBg+wazvc/q3yPmfhS5IkSRrjMnNrZv4q8HZgHX0J/BzkRWXftRRrqr+h25Mc\nKvF765pnthGqvmx7T2Z2bCY+QGbuAM6huI9UawpwJvAXwD8BHwTOoO/+WFJ8FrfUHbeh0bkiYg7w\nz3XN1wOnNErgV8Z4M8V9rZ/Uje+LEVH6dx1Ln8Foc8WKx0rbTzhsH/aeMbW1YCsvKW/f71jY57AW\nRyZJkiSVG1NJ/Mzsobhwql7UHjVyo5E0mHvXb+GBJ8rvI528qIkk/urvwfaN/dsnTYFjzx3i6CRJ\nkiRpdMjMfwYOA94BXElR2jwavLYBVwC/CxyWmf8yjEPdULd9QBsxDqzbfqK9oQwsMx8DTgc+Cmxp\n4pAHgBdn5gX0H+MDAxz3FmBuzXYP8GvNPFRR2eeNlWOqDgbeNMBhG+q2R+1nMJpc2SCJf8ZR+7Ue\nbNWl5e1Lzm49liRJktTAWCunD8UTxtWnzttfSEBS112zpnwW/jEH7sW8PcuW4KvTqJT+ohfBzLnl\nfZIkSZI0BlVmQn8a+H
RETAWOAebTlyBeTzHz/o7M3Dkyo2QVcGLN9iFtxKg/pkFd8qGrJMnfGxGf\nBF4LvIS+v+sU4EHgDuB/gG/UJN6Prgt1wwCneWnd9vcys+nfKTNXR8T3gdo67K8ELmhwyJj6DEaD\nx57cxi0PlkwQoI0k/pbH4b4flfctOafFkUmSJEmNjakkfkQcA0ynL4n/5AgOR9Iglq9eW9p+ajOl\n9LdthFWXlfctfe0QRiVJkiRJo1slSd+wFPsIWsHuCeSFbcRYULd9R/vDaU5mPkKxZn39uvX9RMQs\n4Ni65p8NcMjSuu3rWhvdL4+pTeI/c4B9x+RnMJKuXFk+C//gfWaweL89Wwu2+nuQ9SsaAHvuD097\nVhujkyRJksqNqXL6FGXloK+c/oMjNRBJA9vV08t1d64v7TulmST+Hd+Gnu3926fO8ul2SZIkSRoZ\nt9dtP7eNGM+r2x5tCeSTgck12xvo/3vX2qduuzxjPLBH67b3HWDfifAZdNQVA5TSj4jSvoZWfre8\n/cgzYdJYu80qSZKk0WzM/OsyIt4J/DZ9s/B7gatHbkSSBvKLBzeyaduufu3TpkziOYc3UQq/USn9\no18G02YOcXSSJEmSpDZcWbd9QEQ0PRO8su/+g8QcafXr0f9XZvaU7lnYXLc9o41z1l/kPjXAvhPh\nM+iY7bt6WL66fKm/01stpb9rO6y5orzPyQaSJEnqsFGZxI/C7Ih4RkS8NSJ+DPw9xQz8oEjkX5KZ\nj4/oQCU11Ogi+bmHz2WPqZNL+35p0yNw9w/L+5aeP8SRSZIkSZLadAP9qyK2cpH2urrtBzJzoPXm\nh1VE7AecW9f8hUEOq19HbnEbpz6ybrv8growrj+DTvvp3U/w1I7+z2DsMXUSz184UMGDEvcshx2b\n+rdPmQELTmtzhJIkSVK5ribxI6KnnRewi6Jc2U3AP1OUBqsm7wG2A+/t5tglDU2jJH5TpfRvvYi+\n/7nXmDkPFi4b0rgkSZIkSe3JzAQurmv+zYgY5EltqOzzlrrm+lgj7ePsPiv+msy8aZBj6vtf2szf\noyoipgAvrWv+eaP9J8Bn0FGNSumffMS8wScY1Ft5SXn7EadbMVCSJEkdN6XL8VtcWKqhajYvgJ3A\n6zNzoPXIJI2gzdt3ceN9T5T2nbKoiSR+o1L6T381TO72f7YkSZIkaeRExHTgAGAOMJs2JmBkZoPS\nZh3xD8Dv0HdPaRHwjkr7QP6gsm/VrsGOiYhl9C/1fnpmXtXUSFsQEb/O7qX0dwK/28Sh32P3mfCH\nA/8P+HSTp/594NC6tssGOWbYPoOx7sqV5Un8lkvpZzZO4h95VoujkiRJkgY3HNmwkum0LamW0Af4\nEfC2zLxliDElddH1d61nV2///+nvO2saxxy418AHr1sND99c3mcpfUmSJEnjUEQcAbwNOB1YytDu\n1+QQjx84eOaaiPg34Ldqmj8cEbdm5vfLjomIM4G/qmv+Umbe2Y0xRsQBwFHNJPsrs9PfSTELv9bH\nM/PWJk73P8DfALVZ4U9ExMOZedEg5/4V4GN1zesqMRsaC5/BaHDX2s3cve6p0r4zWk3iP3orPPlA\neZ9JfEmSJHVBV8vpV0QbrwQ2AfdTPNH8N8DxmXmyCXxp9LumQSn9kxbNY9KkQQp03PLV8vZ9DoeD\nnz20gUmSJEnSKBIRMyPii8BK4F3A8cBU2ruXUvvqtj8Daqc4zwS+ExEfi4gF0eeIiPg48G1gRs3+\njwF/3sXxHQBcGRF3RMRfR8TpEbFPtTMiJkfE4RHxexRrzH+C3e+R/QD4i2ZOlJlPUfw9ak0DvhYR\nX4uIF0XEL2utR8SsiHhJRHwT+G/6P3Dx55lZsvB6P6P9MxhxjUrpH3XAbJ42Z0ZpX0ONZuEf9GyY\nvX+LI5MkSZIG19WZ+Jk5HA8JSBpllq8pT+KfOlgp/czGpfSXvhZiOO5FSZIkSVL3RcRc
4BrgKPom\nNNQuJ1hb3qx+u1+4boyxkcx8JCJeQ1H2vZoNnQb8UeW1rdK2R8nhW4HzMvORrg+0+Nu+v/IiIrZW\nzj+HxhNbvg+cm5k7mz1JZn4+Ip5OURq/1nmVFxGxkeJzGqg83Wcz87NNnnOsfAYjplEp/ZZn4QOs\n/G55+5KzW48lSZIkNcEku6SOenjjVtY8trm075TFgyTxH7wRnri7vM9S+pIkSZLGiYgI4CvA0fQl\n6INiHfZbgNXsPqv+auBG4D7KZ9zfVtnnauCH3f8NIDOvAV4MPFzSvQflyeOHgRdl5vJujm0AM4C5\nlN8P2wF8CDi7Mru+Ve+kWHN+W4P+vWmcwN9OUYnh91o54Rj9DIbF5u27+Mndj5f2tZzEf/JheOim\n8j6T+JIkSeoSk/iSOmp5g1L6C+fPGrxc3S0NZuEfeBzMP3KII5MkSZKkUeNVwBn0za7vpShtPj8z\nnwn8fe3OmXl6Zj4nMxdQrL3+FuAX9CXy9wP+urLf6d0f/i/HdS3Fgwh/C5RnTAuPV/Y5OjOvG4ah\n3U2RkP8Z0DPIvo8Dn6EY2wczc7D9S2XhHylm/3+U8sR6vUeAjwFLMvMfMnOgaguNzjtaP4MRtXz1\nWnb29P9zzpk5lWcduk/JEQNYdWl5+5xDYb9j2hidJEmSNLiultOXNPG0XUq/ZxfcelF539LXDnFU\nkiRJkjSqvLvyszoL/7cy88vNHJiZ64B/i4h/B/4U+CAwD/i/iDg3MxvU/e6OzNwIvC8i/gx4DrC0\nMh6AtcCtwE8zc1cbsa+ijaUCKmP6IPDBiJgFHAccQfGww0yKme+PUFQwuCkze1s9xwDnvhd4L/De\niFgAHF857xyKz3ojxXr0N2Zmg1J0LZ+za5/BWHX5HeWl9JcdOZ/Jk1r8Sq28pLx9yTku+ydJkqSu\nMYkvqWN6e5NrGyTxT1k8f+CD774anlpb0hHw9POGPjhJkiRJGgUiYjZwIn2z8C9vNoFfq5J4/lBE\nbKGYzT0V+O+IODYzH+zYgJsfzy7gR5XXqFEpjX9d5TXc576boirAcJ1vVH4Gw623N7lyZdn9BTi9\n1VL6O56Cu64q77OUviRJkrrIcvqSOmbFI5tYt3lHv/bJk4ITF84d+OBbvlrevuBU2OtpHRidJEmS\nJI0KJ1Lcj6lO4f38UIJl5t8B11Y2Z1PMQJcmrFsf2si6zdv7tU8KOO3IQSYY1LvrKujpH4vpe8Gh\nJ7U3QEmSJKkJXZ2JHxEzgb+i78J0eWZ+vc1Y5wEnVzZ7gfdl5s6hj1JSpyxfU/6k+7MOmcPsPaY2\nPnDHFrjj2+V9ltKXJEmSNL4cXLe9fLADImLqIPdAPklxzySAX42It2dm/yespQngihXlpfRPOGwf\n5syc1lqwlQ1Wp1j0IpjSYixJkiSpBd0up/8rwB/SVyLuf4YQ637gD2pi3QBcOIR4kjrsmtWNSunP\nK23/pVWXwo7N/dsnT4OjX9GBkUmSJEnSqFFbpmx7Zj5Ssk/92uXTgYGS+N+veT+DYrb/D9sbnjS2\nXdkgid9yKf3eHlh5aXnfknNaHJUkSZLUmm6X0//Vys8AbsjMn7QbqHLsDfTN6n/TEMcmqYO27ezh\nJ3c/Xtp36mBJ/Eal9I88E2bMGdrAJEmSJGl0qZ2+W/I0MwCb6rYHrAGemU/WHXNUG+OSxry1m7bz\n8wc2lva98Kj9Wwv24A2wpWSyQkyGxS9qY3SSJElS87qWxI+IGRSl3LLyapCla0k1RgAviIjpHYgp\nqQNuuPcJtu/q7dc+e/oUjjt4TuMDtzwOq79f3mcpfUmSJEnjT22yfWaDfZ6s264vwV+ml77qhXMH\n2lEar65cWT4L/6A5Mzhy/z1bC7bykvL2w06CGfu0ODJJkiSpNd2cif8MiqfLqzPnO1HG7eqa93sA\nSzsQU1IHNCqlf+IR+zJl8gD/qbn9YugtqQo5fW9Y
fGaHRidJkiRJo8bDNe9nRMTUkn3uqvysJuVP\nGChgRMwB9q5p6v+EtTQBNC6lP5+IKO1rqFESf8nZLY5KkiRJal03k/j1pdt+3oGY1RjVi9glHYgp\nqQOWr1lb2t52Kf1jXg5T9xjiqCRJkiRp1FlZt11W+n4NsLVme7AFuF9W+VnNUpY/ZS2NY7t6s+EE\ngzOO2q+1YI/fBWvvKO8ziS9JkqRh0M0k/r4175/KzG1DDViJUbte3CDZQUnD4fGndnDbQ/XVHgun\nLBrgf6Yb7od7ry3vs5S+JEmSpPFpBbClZvu4+h0ys5eiGmFUXi+MiFPKgkXE3sBf0jfhAeDmTg1W\nGitWPdHL5u27+rVPnzKJ5y9s8RbiykvL2+ctgbkL2xidJEmS1JpuJvFr16svqZXdttpYszoYV1Kb\nrl2zjsz+7QfNmcGCeQP8z/TWi8rb9zwADj+1M4OTJEmSpFEkM3cBy2uazmqw64XVQygS+d+MiDdG\nxLTqDhFxGsXyhQtqjnsYuKlzI5bGhp+v7Z/ABzjpiH2ZMW1ya8FWWUpfkiRJI6ubSfwnat7vHRFD\nPldETAbm1DSVT/2VNKyWNyhXd8qieQOvOdeolP7S18CkFi+wJUmSJGnsqM0QnhURU0v2uZC+0vsJ\nzAW+DDwZEQ9ExJPAFcBS+hL9CXwis+wxa2l82tXTy2X37OS6B8uT+GccvX9rAbc+Afc0qBq4ZLCV\nLSRJkqTO6GYSvzarF8DTOxDzWPrWdwMoX4Rb0rDJTJavaZDEXzxAubpHb4dHby3vW/qaDoxMkiRJ\nkkatrwC9FPc45gJvqN+hMmP/t4Ht1abK/tOApwF70pe4r7oa+GS3Bi2NRhff/BAXrtjBpgZ1QM84\nar/WAq65HLKnf/vMeXDws1sfoCRJktSGKV2MfXvlZ/Vi8hzgF0OM+bLKz+pF6pohxpM0RPes38KD\nG7b2a4+AkxcNkMRvNAt/38Vw4DM7MzhJkiRJGoUy8+GIeAUwu9J0f4P9lkfEucB/Avuye8K+qjrZ\n4RLg9c7C10Syq6eXT12xumH/kv1nc9CcGa0FXdmglP6RZ1o1UJIkScOma0n8zLw9Ih4B9qe4oPyD\niPhUZj7VTryI2BN4J31Pnj+RmTd0bMCS2rJ8dXlBjGOfthdzZ00r7aO3F275WnnfM84vngCQJEmS\npHEsM7/b5H6XRcSRwLuAVwHH0Je430ox+/5zmXlxN8YpjWYX3/wQ96zf0rD/9FZn4ffshNXfL+9b\ncnZrsSRJkqQh6GY5fYDv0Ddrfj7wr0OI9a+VGFTi/d/QhiapE65Z3aCU/qL5pe0A3H89bLyvvO/p\n53VgVJIkSZI0fmTmE5n5Z5m5lKKc/oHAvMyclZnnmMDXRDTYLHxoo5T+vdfB9o392ydPh4WntxZL\nkiRJGoJuJ/H/FthVeR/A+RHx3xGxV7MBImJ2RFwInE/fLPxe4COdHqyk1uzq6eVHd64v7Tt1cRul\n9A96Nux7RAdGJkmSJEnjU2b2ZOajmfn4SI9FGkmDzcLfe8ZUjj90TmtBG5XSX3gaTN+ztViSJEnS\nEHQ1iZ+ZdwEX0DcbP4DXAXdExPsi4rBGx0bEYRHxfmAFRQKfmjj/mpkrujl2SYP7+QMb2bR9V7/2\n6VMmccJh+5Qf1LMTbvtGed8zzi9vlyRJkqRxojJZ4Rk1r4NGekzSWNPMLPxTF89jyuQWbn1mwsoG\nq1wceVYLo5MkSZKGbsownOPdwLOAk+lL5B8IfBj4cEQ8CqwBNlT69wGOAA6oHF9d56167HXAO4Zh\n3JIGsbxBKf3nLpjLHlMnlx+05nLYWjJhJCbDsed2cHSSJEmSNCq9nmLCQ9VvA18cobFIY9Jgs/AB\n9tpjamtB166ADfeW95nElyRJ0jDrehI/M3dGxKuAbwCnUCTjoS85fwCwf91hUfO+dv9rgFdn5s7u\njFZSK5avWVva
3lYp/YXLYM8W16qTJEmSpLFnPrtPWGhQqkxSmWZm4QNcs2Ytu3p6m5+N32gW/oHP\nhL0tmCFJkqTh1dVy+lWZuR44Hfg7YCd9ZfGrr36H1LwC2AF8BDijEkvSCNu8fRc33behtO+URfPL\nD9q+ufFFsaX0JUmSJE0MtWuSbcnMJ0ZsJNIY1MwsfID7H9/KxTc/1HzglZeUty85p/kYkiRJUocM\nSxIfIDN7MvOPgQXA3wP3UCToB3rdDXwMWJCZH8jMnuEar6SB/fjO9ezq7f8Mzrw9p3HUAbPLD1rx\nf7Cz5EJ7ygw46qUdHqEkSZIkjUqP1bwvm9ggqYFmZ+FXfeqK1ezq6R18x82PwQM/K+9bYil9SZIk\nDb+ul9Ovl5kPA+8B3hMRBwMnUJSSm1vZZT2wFrghMx8c7vFJas7yNetK209eNI9Jk6K0r2Ep/SVn\nw/QGiX9JkiRJGl9urXk/KyJmZ+amERuNNIY0Owu/6p71W7j45oc474SDB95x1WWUPlOz10FwwDNa\nG6QkSZLUAcOexK+VmQ8AD4zkGCS155rVa0vbT1k0r/yAzWvhzivK+yylL0mSJGniuIFiNv5+le0X\nAd8YueFIY0Ors/CrPnXFal75zKcxZfIABUkbltI/G6LBRAVJkiSpi4atnL6k8ePhjVu5c+1TpX2n\nLp5fftDt34SyFTFm7ANHvLBzg5MkSZKkUSwzE7igpuk9IzUWaSxpdRZ+VXU2fkM7tzaedLDk7JbP\nJ0mSJHWCSXxJLbtmdXkp/UX77ckBe+9RftAvvlLefsyrYMq0zgxMkiRJksaGjwF3AQGcGBEfGuHx\nSKNau7Pwqz51xWp29fSWd979Q9i1tX/7tD3h8FPbPqckSZI0FCbxJbVseYMkfsNS+o/fDQ/8pLzP\nUvqSJEmSJpjM3AK8DHiYIpH/gYj4ckQ0uKiSJrZ2Z+FXDTgbf+V3y9uPOAOmTG/7nJIkSdJQTOlm\n8IiYCryN4oIU4LbMvLzNWC8Ejq1s9gKfqZSgkzSMenuTa9e0mMS/9Wvl7XsdDIec2KGRSZIkSdLY\nEBGHAluA84HPAccAbwTOj4hvAVcCtwBPAJtbjZ+Z93VutNLIGuos/KpPXbGaVz7zaUyZXDOnqbcX\nVl5afsCSc4Z8TkmSJKldXU3iA+cC/whUk+1nDSHWJOCTNbHuAb4zhHiS2nDHI0+y/qkd/dqnTApO\nPGLf/gdkwi++Wh5s6WtgkgVBJEmSJE0499B3f4PK+wCmA6+pvNqVdP9+jzRshjoLv6o6G/+8Ew7u\na3z4Jtj8SP+dYxIsfsmQzylJkiS1q9vZszdVfgawIjO/326gyrEr6JvV/+Yhjk1SGxqV0n/WoXPY\nc3rJfaJHfgHrVpYHs5S+JEmSpIkral5QJN+zrr3dlzQudGoWftWnrljNrp7evoZGs/APeR7MKpmo\nIEmSJA2TriXxK6X0l9F3EfqVDoStxgjghRExuQMxJbVgecNS+vPLD7ilwSz8/Y6B/Y8t75MkSZKk\n8S9rXo3aW31J40qnZuFXVWfj/9LKS8p3XHJ2x84pSZIktaOb5dWeDsysvE/gig7EvBz488r72ZVz\n/LwDcZtWeXDgOZVzz6N4oGAtxVp1P8vMnuEcT5mIOIxijPsDcyjW2bsb+ElmPjTAoe2eb//K+Q6q\nnA9gA/Bg5ZyPdeg8ewBHV17zKL4Dm4HHKao03JSZuzpxLpXbtrOHn9z9eGnfKYvn9W/s7YFbLioP\ntvS1HRyZJEmSJI0p92HSXRpQp2fhV33qitW88plPY8qmB+DRW8p3WnJOx88rSZIktaKbSfxj6rZv\n6kDMmys/qxe6RzFMSfyI2Av4E+B3KBLIZdZFxOeBj2bmk8MxrqqImEYxtrcDSwbY7zrgY5l58RDP\nNwP4TeCtFA80DLTvLcAFwBczc1uL53k28HLghcDzGPg7uzUivgX8Y2b+qJXzVM
61DLiy1ePqXJ2Z\ny4YYY9T62T1PsH1Xb7/22XtM4biD9+5/wL3XwqYGz40sHcoSj5IkSZI0dmXm4SM9Bmm06/Qs/Krq\nbPzzei4r32HuETBvccfPK0mSJLWim0n82traWzNz81ADZuamiNgCzKg07T/UmM2IiBOBi4CnDbLr\nPOD9wK9HxGsy88ddHxwQEUuBbwBHNLH7ScA3I+Ii4M2ZuamN8z0f+DLQ7BXNUuAzwO9HxBsz84Ym\nzvFq4OPAwhaGNgN4HfC6iPhP4O2ZubGF4zWIa9asLW1//sJ9mTK5ZHWORqX0D30+zDm0gyOTJEmS\nJEnjyXknHMx5Jxzc1L5XXXUVAMuWLWv+BP/x3fJ2S+lLkiRpFCjJunXMzJr3WzsYt3Ym954djFsq\nIk6mKONflsDfBmwvaT8IuDwiTurm2OCXs8evpTyB3wM8AfSfOg3nAd+PiJb+hhFxGsXfo1ECfxOw\nkfKygEcBV0XE85o41XMZOIGfFL/bjgb9bwR+HBGNqiaoDctXryttP7WslP6u7XB7g4IPltKXJEmS\nJEkjZduTcPc15X2W0pckSdIo0M2Z+Btq3s+JiMjMIa33FhFB35rrAE8NJV4T59uPYgZ+7QMJO4F/\npCgPf1el7QjgbcA7gKmVtpnARRFxXKfWhC8Z3yHAVynWhq91IfBp4PrM7ImIKcDJwB8Cr6zZ73nA\nFyhmrzdzvrkUf48ZdV0/Bj4GXFGd+V55OOBU4N0UpfCr9gQujogjW1xyYCPwdYoHCJYDD2bmrsq5\njgBeXTlXbXWGo4DvRMTJmdnTwrmqVgN/3+IxDWrHj33rN2/ntofKP7JTFs/v37j6e7CtpBDCpClw\n7LkdHp0kSZIkSVKT7rwcenf2b5+xDxzSzNwTSZIkqbu6mcSvnbI7CVhEkRQdikWVWNWHAcpre3fO\nX7F7Ungr8KrM/F7dfmuA90TEDygSzdUk9wHAhygS/N3wLxQl/Kt6KErk/0ftTpVk99XA1RHxDoqH\nEKrOj4gvlPxOZd4H7FvX9lmKsvW7zfavLJ9wCXBJRPwl8Oc13fsDfwJ8oIlz3g58AvifzCxdCC0z\n7wQ+HhH/RvH3P6Wm+3nA71TG2aqHMvOCNo4bl669c31p+0FzZnD4vjP7dzQqpb/oRTBzbgdHJkmS\nJEmS1IKVl5a3L34JTO7m7VJJkiSpOd0sp7+m8rOacD+rAzGri1JF5ec9HYhZKiIWAm+ua/7AQMnu\nzLwU+LO65rdExIIujO/ZwJl1zR+pT+DXy8x/oqgisNtxlSoHg6mfsf8L4PfrE/gl5/wgUP93e/0g\n57qP4u//jMz8YqMEft151gIvp//34k8GO1aDW766/JmZUxfPo9/XZ9vGxhfEltKXJEmSJEkjpWcX\nrL6svG/J2eXtkiRJ0jDrZhL/JuDxyvsA/qBS1r0tETEV+AP6HgrYDFw/lAEO4l30lcaH4qGEf2ri\nuE/S9wADlRh/2Llh/dLv1G2vAz7c5LHvA2qT4sfT/4GA3VRK9x9S13xBC2Xq6/92CyPiaY12zsx/\nzsx/a7UMfmZuAP6yrvmwiFjaShztLjNZvnpdad8pi+f1b7zj29CzvX/71FmuLSdJkiRJkkbO/dfD\n1if6t0+aCke8sH+7JEmSNAK6lsTPzKSYfV2dons48NEhhPxYJQYUifwftLnOebNeWbf9hWbOV9nn\ni3XNr+rUoGqcUbd9YWaWZE37qyS6v1HX/OpBDjugpO0nzZxvgH3LYnbCN+l72KPquC6da0J45Knk\noY3b+rVHwMlHlCTxG5XSP/plMK2k9L4kSZIkTSARcWg3XyP9+0mj2srvlrcvOBX22Gt4xyJJkiQ1\n0O1Fnv6WvhLs1dn4OynK0jeVgI+IycDfAO+kSMxG5edQHggY7JwnAAfXNX+lhRBfoRhz1SERcXxm\n3jjkwQERMRc4oq752hbDXAu8oWb7lRHxtg
FK45c98LGxhfNtKGmbWtI2ZJm5ISLWA7XZ5W49MDAh\n3La+/H+uT3/a3uwza9rujZsegbt/WB5o6fkdHpkkSZIkjUn30P/h805Jun+/Rxq7VjVY/u9IS+lL\nkiRp9OhmOX0y8xfA/9KXeA/gj4CfRcSvRMS0RsdGxLSI+FXgZ8B7qs2VOF/PzFZmgbfq9LrtRzLz\nrmYPzsw7gUfrmutnzg/FfiVta0raBrK6JOZAJecfLGmb28L59i1pe7iF41tV/4BAt26OTAiNkvil\npfRvvQjKngWZOQ8WLuvswCRJkiRp7IouviSVWbca1je4hbbkrOEdiyRJkjSA4Xgy+60UpcyPoi+R\nfxzwn8AXI+JG4E6KmdoJ7EMxy/x4YBp9F5/VY1cAv9nlMR9Tt93OAwPXA6+o2T66/eH0U5Y8b2VW\nfKP9jwF+XrZzZj4QEXcBC2uaT6L5v81JddsPZeZ9TR7bkoiYB+xd1/xIN841EezqTe5okMQ/dVEL\npfSf/mqY7GQQSZIkSaro5MPmtZMnJDXSqJT+/kthjitRSJIkafToekYtMzdFxMuBS4DF9F2kBjAd\nOLHyqlebvK9urwJelplPdm/EQPHAQa2mZ+HXuHuQmEOxvaRteosx9ihpG+xBg89TLJFQ9bsR8c+Z\nuWOggyIigHfXNX9p8CG27dySthvaCRQRU4FnAocAc4AngceBVZn5QJvjG1Pu3tjLtpIc/h5TJ3HC\n4fv8cntXTy+XL7+OMx+6qTyQpfQlSZIkqeo+Wk/iT6K4Lp1d01aN8RiwdejDksa5lZeUty+xlL4k\nSZJGl2GZFpuZd1XWmf8i8Jpqc91u9Un7rGu/EHhrZm7u2kD7HFm3fX8bMeqPqY85FE+UtM1vMUbZ\n/ksGOeaTwK8Bx1a2FwNfjojfyMyyBwuIiMmV42pn4j8AfKyVwbbod+q278zMFW3EeQ5FxYIZZZ2V\nygTfAT6ZmfUPbYwbt64rn4X/3AX7Mn3KZHb19HLxzQ/xqStWc+7Gf+fMsv+q7HM4HPzsro5TkiRJ\nksaKzDy83WMjYg5wKkXlw3Mo7p88CZyXmbd0YnzSuPTUerj/+vI+S+lLkiRplJk0XCfKzM2ZeT7w\nXODrwDbK12urbdsGfAU4ITPfMEwJfChK+tdqpxR7/Xrvc9obSqmHgJ11bc9qMUbZ/gOucV9J1J9N\nURGh6vXAzRHx1ohYHBEzImJ6RCyIiN8Afgq8vWb/9cBLu1VNISLeBNRniz/XZriZNEjgVywE3gGs\njohPRUSr1RDGhNsalNI/eeG+XHTDA7zo76/m3V/9Ofesf4pXTLq2PMjS10JY1VGSJEmShiozN2Tm\ntzPzZcArKRL4RwBXVyZQSCqz+nuQvf3b9zwADmz1tpokSZLUXcO+QHVm/gx4TaVM+YnA8RSzwqsJ\n5PXAWory5z/JzPpkdVdFxAz6P9ywpY1Q9WXspkTEHpm5rb2R9cnMbRFxA7svQ/By4B9aCPPykrY9\nmzj3/ZWbAh8C3kaR5D4KuKCJc14CvC0z72thnE2LiIPp/ze4H/hMN85XYzLFgwonRcTZmflYqwEq\nn2eZozZt2sRVV101lPG1bcvO5K4NPZQtq/i5q1by+La+ghrHxZ0smPRoaZyfbD2cLSP0O2hkbdq0\nCWDEvsMa2/z+aKj8Dmko/P6MftXPSJrIMvPbEfFq4HsUkwe+FRFLM/PxkR2ZNAqt/G55+5KzYNKw\nzXOSJEmSmjLsSfyqSnL+msprNClLZLeTeC87Zlabscpcxu5J/GURcXxm3jjYgRHxYuC4kq5Bk/hQ\nVFUA3hUR3wc+RfHE/0Dup1gKocHCY0NXeSjkf+lfTeD/ZWarD2GsA74LXA7cQrFW4ZMUM/MPoFga\n4PXAS+qOOx74ZkSc0YmHNUaDFY/30FuSwAd2S+ADvGpy+Sz8TXsewZZZB3d8bJIkSZIkyMwrI+I/\ngN+guG
b9S+D3R3RQ0mizazvceUV535JzhncskiRJUhNGLIk/iu1R0rajjThla8QPVJq9VRcA7wOm\nVbYD+FJEnDpQqfqImE/j8vJNjS8iTqGY8d7sIueHAP8XEV8H3p+ZqwY7oA3/TJFcr/XZzPy/FmI8\nBPwKcFGDChAbK6+VFH/rFwD/DRxUs8/zKaoU/HEL5yUzS0seRsQNs2fPPn7ZsmWthOuYKy6+Fbh3\n0P0m08PLJv+4tG/2SW9m2UnLOjswjRnV2Ysj9R3W2Ob3R0Pld0hD4fdn9Js9e/ZID0EaTf6VIokf\nwJsj4o8zs75CoDRx3XMN7ChZpXPKDFjwguEfjyRJkjQIa0X1VzaDelpJ22DK1kfv2OzszHwE+HRd\n8zOAyyPiqLJjIuJZwFXAggZhS65m+sV4eyVGfQL/R8DHgXcB7wb+rtL2y0OB84CbIuI1g52nFRHx\nAeC36pqvBf6wlTiZuSoz/6fZJRwy84cUDw48Utf19oh4WivnHo129fRyyS0PN7XvSZNuY35sLOkJ\nePp5nR2YJEmSJKne9UBWXjOAU0Z2ONIos7JBccgjzoCpnZxzI0mSJHWGM/H7K0tkl83OH0zZMYMm\nyVv0PuA0oHYW97OBWyLicook+uPAvsCpwDL6HtzooUjGv7Dm2A0DnSwiXkdRPr/WjcBbMvPnDY55\nJvAFilLzUJSkvzAiNmbm9wc6XzMi4reAD9c13wa8PDPLqiF0VGbeFxFvoSi/XzUD+HXgI90+fzfs\n6unl4psf4h9+sIq1m5srQtGolD4LToW9xvzzDJIkSZI0qmVmT0RsAPahSOQfBQz5mlsaFzJh5aXl\nfUvOHt6xSJIkSU0yiV8nM7dGRC+7VymY2Uao+sd4ezq9Tnpm7oiIc4BvAc+r6ZoCnFl5lR4KvB1Y\nQpNJ/IiYQ1Gyvtb1wOkDlejLzJsr5fevAp5bM74vRsSSNtarrx3TeRTLCtS6B3hJZj7RbtxWZeYl\nEXEDuz9M8RLGWBK/mrz/1BWruWd98x/LdHZw5qSflncufW2HRidJkiRJGsRMiut9ANebkKoeuQWe\nfKCkI+DIRrfOJEmSpJFlOf1yG+q2D2gjxoF1211JKmfmY8DpwEeBZjKvDwAvzswL6D/GsiuaqrcA\nc2u2e4Bfa2aNvco+b6wcU3Uw8KYmxlsqIl5EsR795JrmRyl+t4fajTsE367bfs4IjKEtu3p6ueiG\nB3jR31/Nu7/685YS+AAvmnQje0bJ8ymTp8HRr+jQKCVJkiRJjUTEMey+rN+TIzUWadRpVEr/4GfD\nnvsN71gkSZKkJg3rTPyICOAciqTzc4FDgDkUT4hHi+EyM7s1/lXAiTXbh7QRo/6Y1e0PZ2CVJPl7\nI+KTwGspZoEfA8yn+IwfBO4A/gf4Rk3i/ei6UDcMcJqX1m1/LzOb/p0yc3VEfB84q6b5lfSfST+o\niHg+8E1gWk3zExQJ/DWtxuuQFXXbsyJiRjMPOYyUdmfe12tYSn/xS2DGnLbjSpIkSZKa9o7Kz6CY\njf/gCI5FGl1Wfre83VL6kiRJGsWGLYkfEW8C/gI4rLZ5uM7fohXsnsRf2EaMBXXbd7Q/nOZk5iMU\na9bXr1vfT0TMAo6ta/7ZAIcsrdu+rrXR/fKY2iT+M1sNEBHPoFh/flZN82bg7My8pY0xdcrjJW37\nAKM2if/lH93LX33n9iHF2JvNnDbp5vLOZ5w/pNiSJEmSpMFFxDuB36ZI3gfQC1w9ooOSRosnH4KH\nby7vW3LOsA5FkiRJakXXk/gRMQn4N+AN9CXts+5ny2GHOKzB1Gc2n1u618CeV7fd9SR+i05m91L0\nG+j/e9fap277sTbO+Wjd9r6tHBwRi4HvUVRvqNoOvDIzr29jPJ20d0nbxmEfxTA7Z/L1TIue/h3T\n94LFrisnSZIkSZ1UqXC4J8XEgecDb6ZYzq32fsslmVn2oLk08ay6tLx9
n8Nh/lHDOhRJkiSpFcMx\nE//jFOuhQ99T4bVJ+Goiv1FifrD+briybvuAiFiYmXc1c3BELAT2HyTmSKtfj/6/MrMkG/tLm9k9\neT6jjXPOrNt+qtkDI+IQ4Afs/nfdBZyfmVe0MZZOW1y3vTUzm/79RsKvP/8w5syYOqRy+g1L6R/9\nCpi6xxBGJ0mSJEnjV0TZ09Dthar8rN5v2Q68t0OxpbFv5SXl7UeeDTFaC4RKkiRJMKmbwSPiBOAP\nKS4mqxeUlwNnA3OB/0dfUj8zcxIwm+KJ8vOBLwM76FvT7WrgkMyclJmT6Z4b6L9+XCu1wV9Xt/1A\nZg603vywioj9gHPrmr8wyGFr67brk9bNOLJue10zB1XG+wPg0JrmXuDXM/NbbYyjG+oXUvvFiIyi\nBVMmT+K8Ew7mB+86jU+89jgO37f+GYuBHch6njdpRXnnM17bgRFKkiRJ0rgVHXrV3m/ZCbw+M4e2\nbpo0XmzfDHc1WFliSf1tHEmSJGl06WoSn/5Pf382M1+cmZdl5oayAzLzqcy8NzO/lplvpkgWX0Fx\nQfoC4LrKTPeuycwELq5r/s2IGPTBgco+b6lrro810j7O7rPir8nMmwY5pr7/pc38PaoiYgrw0rrm\nnzdx3N7AZfR/AOB3M/O/mz1/N0XEycCpdc2XjcRY2tFuMv8Vk68r79jzADi8/s8hSZIkSaqTQ3xB\nXzL/R8CzR9GD7tLIu+tK6Nnev3363nDYScM/HkmSJKkFXUviR8RU4Bz6nghfCbyz1TiZ+QDwYuCi\nSpxDgO9ERLdrdf8DRbn2qkXAO5o47g8q+1btqsRqKCKWRUTWvZa1NtzmRMSvs3sp/Z3A7zZx6Pfq\ntg+nqKTQrN9n95n0MEiiOyJmAt8BnlnX9ceZ+bkWzj2oyrqC7Rw3j6JiRK0eYFQ8YNCKVpP5DUvp\nP/08mNTNQhmSJEmSNC60O/N+E3A/xXX63wDHZ+bJmXnLcP8C0qi28tLy9sUvgslTh3cskiRJUou6\nORP/BPrWTU/gU5m5a4D9G6rMjP814L5K0xLg3UMe4cDnXAP8W13zhyPixY2OiYgzgb+qa/5SZt7Z\n4eFVz3dAs8n+iJgcEe8CvljX9fHMvLWJEP8DPFbX9omIOK+Jc/8K8LG65nWVmI2OmQp8DTilruvD\nmfnxwYfbsn+MiI9ExP7NHhARx1As8XBEXdcXM3NlR0c3jOqT+fvP7P98w5FxP0dPuq/kaCylL0mS\nJEmDqCwT2M5rSmbOyczDM/OszPzTzLx5pH8fadTp7YFVDZL4S84Z3rFIkiRJbZjSxdjVNdOrT4p/\nf7ADImJyZvaU9WXmtoj4KPCZSszfA/66Q2Nt5M+AVwD7VbZnUlQB+Efgs8A9lfaFwNsoKg3UPsr7\nGPDnXRzfAcCVEbEC+DrFuvE3Z+YT8MvS/odQlLH/beC4uuN/APxFMyfKzKci4s+A2hnw04CvRcRF\nwAXAdZm5pXLuWcDJFLP8X1kS8s8zc9MAp/xH+q8zfyfwYES8rZkx11mZmVcO0L8XRbWA90TElRRL\nINwI3Fo7zojYCzgJeD3wq+z+eQPcAfxJG+MbdarJ/DkbV/Pjh3fx/YemcM/6LQC8ssEs/Dt7D+SI\nA585jKOUJEmSJEmq88DPYMu6/u2TpsCiFw7/eCRJkqQWdTOJP7fm/a7KzPZ6vXXbewBPDRDzWxRJ\nfID9I+KZ3XziPDMfiYjXUJR9r1YVmAb8UeW1rdJWVtp/K3BeZj7SrfHVOAp4f+VFRGytnH8Ojast\nfB84NzN3NnuSzPx8RDydItld67zKi4jYSPGQxV4DhPpsZn52kNMdU9J2BMXDE+34MjBQEr9qCsXy\nDb+suBAROyjKFc6k73tQ5k7grOpDFOPF5EnByQdN5U9e/wIuvvkhPn35Sl751HWl+17cczLvam9l\nAkmSJEmSpM5Y+d3y9sNOghn7DO9Y
JEmSpDZ0s5z+rJr3TzbYZ3Pd9kCJXzLzQWA7xcx+gGe0N7Tm\nZeY1FAndh0u696A8gf8w8KLMXN7NsQ1gBsVDFGWf7w7gQ8DZmTnQAxONvBP4A/oeYKi3N40/x+3A\nuyiqKIwl04B9aZzAT+BLwDMzs0GN+bGvOjP/++fP4OAoeZoduLj3pGEelSRJkiRJUp1GpfSPrC/6\nKEmSJI1O3Uzi1yaI60uOV9WXUz+4ybjVqb5Nr18+FJl5LXA08LfA4wPs+nhln6Mzs3yqcmfdTZGQ\n/xlQugxBjccpqhgcnZkfbLRswWCy8I8Us/8/SvnDDfUeAT4GLMnMf8jMHOyAEfAPwF8BP6T/wyWN\nPETxNz0mM9+Smc0eN6ZNufVrpe039S7i3jxgmEcjSZIkSZJUY/2dsHZFed+Ss4Z3LJIkSVKbullO\nf33N+z0b7PNg5Wc1qbsU+GmjgBExnWKmd3X/bo5/N5m5EXhfZV3451CMdV6ley1wK/DTzNzVRuyr\n6NLoMxEAAEeYSURBVHswodUxfRD4YGUN+uMoSs7vR1H6fTtFAv024KbMrF++oG2ZeS/wXuC9EbEA\nOL5y3jkUn89G4DHgxsy8u434yzo11ibP93Pg5wAREcAiYCHFgyX7UFRc2A48Aayj+L3G7az7hnp2\nwm3fKO26Zo9lRZ0HSZIkSZKkkdJoFv78o2HuwuEdiyRJktSmbibBV9W8j4hYUJLMvZ1iBnm1IsDp\nwBcHiHkqMJm+JP6wrz1eSdL/qPIaNSql8a+rvIb73HdTVAUYFyqVAlZXXqp15xWwtaQYRUzmd3/v\njzhoTcvPsEiSJEnShBMRMymqwVUf6F+emV9vM9Z5wMmVzV7gfZm5c+ijlMaolZeUty+xlL4kSZLG\njm4m8esT9EupS/Rm5vaIuJliFncA50XEezPzQcp9oPIzKBL5Kzs9aEkD+MVXytsXLmPK3gdw3gnD\nOxxJkiRJGqN+BfhD+iYp/M8QYt0P/EFNrBuAC4cQTxq7tj4B9zaY32ISX5IkSWPIpMF3aU9lffAb\n6Xuq/IwGu/4vfUn5PYCLI+Kw2h0iYs+I+BJwGn0XpVsZZbPhpXFt+2ZY+d3yvqWvHd6xSJIkSdLY\n9quVnwHckJk/aTdQ5dgb6Lv/8qYhjk0au1b/ALKnf/us+XCQMw8kSZI0dnR7TfnvUawfD/ByiifD\n630JeB99a6kfD6yOiB9RPE0+BzgFmF3Zv5rw/0JmbuvSuCUB7NrBYfdUZt/PeQh2bum/z5QZcPTL\nhndckiRJkjRGRcQMivL31UkKX+1A2K8CJ1DcM3lBREzPzO0diCuNLY0mHxx5JkyaPLxjkSRJkoag\n20n8/6UogR/A4RHxosz8Qe0Ombk+It4LfI7iAjYr4zqlZrfq0+RZeX8/8BfdHbokfvqvLLjnv4r3\nm39Wvs+Ss2H67PI+SZIkSVK9ZwDTKu8T+GEHYl5d834PiiUNG1zESeNT9O6ENZeXdy45Z3gHI0mS\nJA1RV5P4mXlrRHyUYjY9wAEN9vuXiHga8EH6Evm77VL5WU3gvywzn+j8iCX90lPr4eq/7dtet7J8\nP0vpS5IkSVIrjqrb/nkHYlZjVO+fLMEkviaYvTfeDts39u+YPB0WLhv28UiSJElD0e2Z+GTm+5rc\n7y8j4irgw8DzgUl1uzxJUXr/w5m5vqODlNTfVX8D20oufmvN2AcWvWh4xiNJkiRJ48O+Ne+f6sRS\ngZm5LSI2A7MqTfOGGlMaFW6+sJg8MHnwW5jz1v2kvGPhMpg2q7xPkiRJGqW6nsRvRWZeDZwaEfOB\npwP7ATuBB4AbM3PXSI5PmjAevR1+9sXB9zvmVTBl2qC7SZIkSZJ+aXrN+50djFsby4ylxodvvg1+\n+DF4wR8PnMzPZN/1DZL4S87u3vgkSZKkLhlVSfyqzFwLXDnS45AmpEy47P2QvYPvayl9SZIkSWpV\n
7fKAe0fEpMxmLsAai4jJ9C1lCEU1Q2l8ePyuQZP5s566jxnbHis//sizhmGQkiRJUmfVl6yXNNGt\nugzuauIZmr0OhkOf3/3xSJIkSdL4sq7mfVBUIhyqYyuxqtZ2IKY0ulST+Z95TlFmv6evYGfDWfhP\nexbsdeAwDVCSJEnqHJP4kvrs2gHf+0Bz+x57LkzyPyGSJEmS1KLbKz+z8vOcDsR8WeVnNZG/pgMx\npdGpJJk/b12jUvqd+J+XJEmSNPzMwEnq89N/gfVN3uuJGHwfSZIkSdJuMvN24JHKZgB/EBFtr2Ef\nEXsC76TvoYAnMvOGoY1SGgOqyfxPHc9em1aV77Pk7OEdkyRJktQhJvElFZ5aD1f9bfP73/jl4hhJ\nkiRJUqu+Q5HAT2A+8K9DiPWvlRhU4v3f0IYmjTEb7i1v3/sQ2L8Tq1VIkiRJw88kvqTCVX8D259s\nfv9tG+Gqj3RvPJIkSZI0fv0tUF3QO4DzI+K/I2KvZgNExOyIuBA4nyJ5H0Av4IWaBLD4JVYRlCRJ\n0phlEl8SPHo7/OyLrR/3sy/CY3d0fjySJEmSNI5l5l3ABfTNxg/gdcAdEfG+iDis0bERcVhEvB9Y\nQZHApybOv2bmiq4OXhorVn4Xbr4QenYNvq8kSZI0ypjElya6TLjsfZC9bRzbA5e+r4ghSZIkSWrF\nu4Fr2T2RfyDwYeCuiHgoIn4YEd+KiIsr7x8E7gL+qrJv9ViA64B3DPcvIY1amx6Gb74NPvMck/mS\nJEkac0ziSxPdqsvgrqvaP/6uK2H19zo2HEmSJEmaCDJzJ/AqYDl9yfhqMj+AA4CTgZcCL6u8P7Cm\nv3b/a4BXVWJKqvX4XSbzJUmSNOaYxJcmsl074HsfGHqcy95fxJIkSZIkNS0z1wOnA38H7GT35HxZ\nybPavgB2AB8BzqjEktRINZn/k8+P9EgkSZKkQZnElyRJkiRJGiGZ2ZOZfwwsAP4euIe+2faNXncD\nHwMWZOYHMrNnBIYuSZIkSeqSKSM9AEkjaMo0eMlfw4WvG1qcM/+miCVJkiRJaktmPgy8B3hPRBwM\nnADMB+ZWdlkPrAVuyMwHR2aU0hg2dyG84I9h6WtHeiSSJEnSoEziSxPdkWfCwtOLte3bccQZsPgl\nnR2TJEmSJE1gmfkA8MBIj0MaF2qT95O9FSpJkqSxwX+5ShNdRDGT/oKTIXtbPHZycWxEd8YmSZIk\nSZLUDpP3kiRJGsP8F6wk2P8YePZb4Kf/2tpxz34L7Hd0d8YkSZIkSZLUKpP3kiRJGgf8l6ykwrL3\nwy++AtufbG7/PfaGZe/r7pgkSZIkSZKasGXGgcw884Mm7yVJkjQu+C9aSYVZ+8Ky98Jl729u/9Pe\nWxwjSZIkSWpJREwF3gZU1ya7LTMvbzPWC4FjK5u9wGcyM4c+SmmMmLuQO/Z7OY/tdxqnPfOFIz0a\nSZIkqSNM4kvq85zfhp99EdavGXi/fRfDc397eMYkSZIkSePPucA/AtVk+1lDiDUJ+GRNrHuA7wwh\nnjQ21JTNf/Sa5SM9GkmSJKmjJo30ACSNIlOmwUv+evD9zvxrmDy1++ORJEmSpPHpTZWfAazIzO+3\nG6hy7Ar6ZvW/eYhjk0a3uQvhVRfA7/0Unvkrls6XJEnSuGQSX9LujjwTFp7euP+IM2DxS4ZvPJIk\nSZI0jlRK6S+jmDmfwFc6ELYaI4AXRsTkDsSURheT95IkSZpATOJL2l0EnPURiJL/PMRkOPNvin0k\nSZIkSe14OjCTvpnzV3Qg5uU172dXziGNDybvJUmSNAGZxJfU335Hw7Pf0r/92W8p+iRJkiRJ7Tqm\nbvumDsS8ufIzKz+P6kBMaeSZvJckSdIEZRJfUrll74c99u7b3mNvWPa+kRuPJEmSJI0P82veb83M\nzUMNmJmbgC01TfsPNaY0Kpi8lyRJ0gRlEl9SuVn7wmnv7dte9r
6iTZIkSZI0FDNr3m/tYNxtNe/3\n7GBcSZIkSdIw81FWSY0957e4e+WtACx4zm+N8GAkSZIkaVzYUPN+TkREZmajnZsREQHMqWl6aijx\nJEmSJEkjyyS+pMamTOPew88HYMHkqSM8GEmSJEkaF9bVvJ8ELAJWDzHmokqs6sMAa4cYT5IkSZI0\ngiynL0mSJEmSNHzWVH5WE+5ndSDm2ZWfUfl5TwdiSpIkSZJGiEl8SZIkSZKk4XMT8HjlfQB/EBFt\nV0qMiKnAH9D3UMBm4PqhDFCSJEmSNLJM4kuSJEmSJA2TzEzge/TNmj8c+OgQQn6sEgOKRP4PMrNn\nCPEkSZIkSSPMJL4kSZIkSdLw+luKhHvSNxv/byNicrMBImJyRHwUeGdNHBjaAwGSJEmSpFHAJL4k\nSZIkSdIwysxfAP9LkXivJuD/CPhZRPxKRExrdGxETIuIXwV+Bryn2lyJ8/XM/ElXBy9JkiRJ6rq2\n11yTJEmSJElS294KHAccRV8i/zjgP4EvRsSNwJ3Ahkr/PsARwPHANPpm3lePXQH85vANX5IkSZLU\nLSbxJUmSJEmShllmboqIlwOXAIspkvFQJOSnAydWXvVqk/fV7VXAyzLzye6NWJIkSZI0XCynL0mS\nJEmSNAIy8y7gBOBrFMn4aln8rNmt2v7Lw+ibfR/AhcCzM/Pu4RizJEmSJKn7TOJLkiRJkiSNkMzc\nnJnnA88Fvg5soy9BX5u8r23bBnwFOCEz35CZm4d31JIkSZKkbrKcviRJkiRJ0gjLzJ8Br4mIqRRl\n9I8H5gNzK7usB9YCNwA/ycydIzJQSZIkSVLXmcSXJEmSJEkaJSrJ+WsqL0mSJEnSBGQ5fUmSJEmS\nJEmSJEmSRonIzJEeg6QGImL9jBkz5h599NEjNoZNmzYBMHv27BEbg8Yuvz8aCr8/Giq/QxoKvz+j\n3x133MHWrVsfz8x9R3oskjTRjfT9C/9/W0Pld0hD4fdHQ+H3R0Ph92f0G8q9C5P40igWEXcDewH3\njOAwjqr8XDGCY9DY5fdHQ+H3R0Pld0hD4fdn9DsceDIzF4z0QCRpohsF9y/8/20Nld8hDYXfHw2F\n3x8Nhd+f0e9w2rx3YRJf0oAi4gaAzDxhpMeiscfvj4bC74+Gyu+QhsLvjyRJY4f/v62h8jukofD7\no6Hw+6Oh8Pszvk0a6QFIkiRJkiRJkiRJkqTClJEegCRJkiRJ0kQWEQGcA5wOPBc4BJgDzAaixXCZ\nmd7vkSRJkqQxzIs6SZIkSZKkERIRbwL+AjistnlkRiNJkiRJGg1M4kuSJEmSJA2ziJgE/BvwBvqS\n9ln3s+WwQxyWJEmSJGkUMIkvSZIkSZI0/D4OvLHyPikS8LVJ+Goiv1FifrB+SZIkSdIYFZntPtwt\nSZIkSZKkVkXECcBP2T0Rfznwd8D1wOuAz1b6MjMnR8QsYB7wHOClwOuB6ZUYPwTemJkPDtsvIUmS\nJEnqmkkjPQBJkiRJkqQJ5r1125/NzBdn5mWZuaHsgMx8KjPvzcyvZeabgcXAFRQPALwAuC4iFnZ1\n1JIkSZKkYWESX5IkSZIkaZhExFTgHPpK6K8E3tlqnMx8AHgxcFElziHAdyJij86NVpIkSZI0Ekzi\nS5IkSZIkDZ8TgBmV9wl8KjN3tRMoizUSfw24r9K0BHj3kEcoSZIkSRpRJvElSZIkSZKGz+LKz6j8\n/P5gB0TE5EZ9mbkN+GhNzN8b0ugkSZIkSSPOJL4kSZIkSdLwmVvzfldmrinZp7due7AS+d+qeb9/\nRDyznYFJkiRJkkYHk/iSJEmSJEnDZ1bN+ycb7LO5bnuvgQJm5oPAdory/ADPaG9okiRJkqTRwCS+\nJEmSJEnS8Hmq5v3UBvtsqts+uMm41RL9+7c6KEmSJEnS6DFlpAcgaXSqrLn4HODpwDyKm0FrgVuA\nn2VmzwgOT6NQROwBHF15zQ
NmU8wgehxYAdyUmbtGboSSJpKImAQ8EzgGOICiDPEW4DHgTuDnmbll\nxAaoUSki9qf4989BwJxK8wbgQeAnmfnYyIxM0jizvub9ng32ebDyszqzfinw00YBI2I6sHfN/t7v\n0YTgvQu1w/sXkkYL712oHd67mDi8qJO0m4jYC/gT4HcoLmTKrIuIzwMfzcxG5R81AUTEs4GXAy8E\nnsfA/7+yNSK+BfxjZv5oOMan8SEipgA/A46r78vM6H+EJrKIOAL4Q+BX2H3N4Xq7IuJG4P+Av8/M\n+rLFmiAiYgbwm8BbKRIAA+17C3AB8MXM3DYMw5M0Pq2qeR8RsSAz767b53agh74KiqcDXxwg5qnA\nZPqS+E90YqDSaOW9C7XK+xfqNu9dqBXeu1CrvHcxMUVmDr6XpAkhIk4ELgKe1uQhDwKvycwfd29U\nGo0i4tXAx4GFbYb4T+Dtmbmxc6PSeBUR76H4vvXjhbCqKjdMPgj8MTCtxcOPzswVnR+VRruIeD7w\nZWBxi4euAN6YmTd0flSSxruI2JMiyV5N0J+bmd8q2e+nwPEUM4u3AYsz88H6/Sr7XgmcVtlM4EWZ\neWWnxy6NBt67UCu8f6Hh4r0LNcN7F2qH9y4mrkmD7yJpIoiIk4HLKb8I3gZsL2k/CLg8Ik7q5tg0\nKj2XgS+Ak+LG5I4G/W8EfhwRjWZMSABExGHAX4z0ODS6RcRs4DLgTym/CO6lKI1Z9v9lmqAi4jSK\nf/s0ugjeBGykb1ZrraOAqyLieV0anqRxrDKD6kb61q8/o8Gu/1vZJylKq15c+bfRL0XEnhHxJYoE\nfvW/V1sBZ45qXPLehdrg/Qt1nfcu1AzvXagd3ruY2EziSyIi9qN4in1mTfNO4O+ARZX2GRT/R/GJ\nSl/VTOCiSgxNXBuBL1Fc3B4OTMvMuZk5neI79MfAo3XHHAV8p7KGodTIp4FZIz0IjV6VNYC/Q/8E\nyO3Auyj+GzQ9M/fNzD2A/YBzKL5bjwznWDV6RMRcin/7zKjr+jHwamBOZu6VmXOAvSi+M5fX7bsn\nRUJtry4PV9L49L2a9y9vsM+X6CuLnxSz8ldHxNUR8Z8R8R3gAeBNlX2qCf8vWDZT45H3LtQh3r9Q\nN3jvQgPy3oXa4b0LWU5fEhHxOYp15Kq2Aq/KzO812P8s4Ovs/n8en8vMt3VvlBpNIuJvKdYfvJ3i\n5sj/ZOaWQY6ZT/G9OaWu63cz87NdGajGtIh4DfDVmqZvAq+q3ceSdIqITwLvrGv+C+BvMnNnvwN2\nP3Yqxfpz38/Mh7syQI1KEfFx4D11zZ+lKJXaO8Bxfwn8eV3z32TmBzo8REnjXEQ8HfhFZTOBMzPz\nByX7/TbwOfpm1gS7z7Kp/lsoK+/vA56VmU8gjTPeu1A7vH+hbvPehZrhvQu1w3sXMokvTXARsZBi\nbZSpNc3vysx/GOS4d1M87V61E1iSmXd3fpQabSLid4EtwH9kZk8Lx80BbqJ42r3q3sw8vGx/TVyV\np0PvoK9M5nconjz9Uu1+XghPbBHxHIqnj2urS70nMz8xQkPSGPH/27vvMMmqMvHj35cJZGaIgoIM\nkgYFyaKICyYMqCAqYsawisIqQZHVVZpdA2ai4K4uoOsKioqsrAll+C0CohJElCDYICqZGTIDw/v7\n41Yz1bdvVVdVV3dVd30/z1MPfU6de85bVUOF8957TkTcDGxUV/U7YIdWPtMi4ifAnnVVN2bmpl0O\nUdIAiIhPA/NrxV9m5n81aHcUxd6pzSZwAvgLsFdm/r6bcUr9wLkLdcr5C00m5y7UCucu1CnnLmQS\nXxpwEXEicFBd1Z+AheN9ENSWELuGYqmfESdk5vu7H6Vmkog4gNKPGeCZmXlVD8JRnyq9Nz0EPINi\nr1d/COsJEXE5sF1d1Q8yc5/eRKPpIiI2orhStV7LV1VFxF4Uk3P1npKZf+tGfJJUpbYX5ieA
5zB2\na8R7Kb4jfSIz75rq2KSp4NyFesH5C43HuQu1wrkLdcK5C8HYH36SBs/epfLXWjmTq9bmP0vV+3Qr\nKM1oZzP2KqJtexCH+lTtDOX31lV90itlVBYRezD6R/CjwOG9iEXTzvoVdZe2cXxV26o+JalrMvOC\nzHwesAHwQoolVV8LPBtYJzMPNYGvGc65C/XC2Th/oQacu1ArnLvQBDh3IZP40iCLiB2BDUvV326j\ni3LbjSJih4lFpZkuMxcD5QlGv0AIeOJKmX9n+XeU64DP9S4i9bF3lso/zswbehKJppuq30BL2jh+\ncUXdnIo6Seq6zLwjM8/PzDMz83uZeWlmPtbruKTJ5NyFesX5CzXi3IXa4NyFOuXchUziSwPu+aXy\nrZl5Y6sH175w3FaqfsGEo9IgKH9hcG8XjTiE0WcoH5SZS3sTivpVRKwCvKZU/c1exKJp6a8VdWu1\ncfzaFXV/7zAWSZI0Pucu1EvOX6jKITh3oXE4d6EJcu5CJvGlAff0Urmd5VhG/KpU3qrDWDQgImId\nYF6p+tZexKL+EhFPBY6uqzozM8/rVTzqazsAK5fq/LeilmTmLUB54n/XNroot/1bZpb3qZMkSd3j\n3IV6wvkLVXHuQm1w7kIdc+5CYBJfGnQLS+WWz2SvU97rqdynVPbqirrfTnkU6kcnAqvW/r4POKyH\nsai/7Vwq/7m8D3BEPCUiXhARb4mI/SPihRHxlCmMUf3t30vl90XE3PEOiohg7P6Fp3YtKkmSVMW5\nC/WK8xeq4tyFWuXchSbKuYsBZxJfGmxblMp/6aCP8jHlPqWyd5fKN2TmNT2JRH0jIl4DvLKu6uOZ\n+bdexaO+t1OpfPnIHxHx2ohYRPH59HPg68C3KM52vyUi/hARR0bEqmiQHQtcXVfeHDg9IlZsdEBt\n38vjGX02+y3AZycjQEmS9ATnLtQrzl9oFOcu1CbnLjRRx+LcxUAziS8NtjVL5U6WBCvvozK/s1A0\nCCLirYz9AvuVXsSi/hERqwPH1VVdCZzQo3A0PWxUKv81ItaMiLOA7wC7A9Hg2K2ATwM3RMQLJzFG\n9bHMfAR4GXBdXfX+wBUR8Z6I2DwiVo6IFSNik4g4APg1cHBd+7uAvTLz3ikLXJKkweTchaac8xcq\nc+5CHXDuQhPi3IVM4ksDKiJWZux7wIMddPVQqTw7IlbqLCrNZBGxIfClUvVfgJN6EI76y6eAkaXC\nEnhfZi7rYTzqf+V9Ke8BzgFe00YfTwJ+HBFv6lpUmlYy8y/AjhSfTSPfZxYCp1D8QH4QeJhiyd5T\nge3rDv8RsENm/m7KApYkaQA5d6FecP5CDTh3oXY5d6EJc+5isJnElwbXahV1D3fQT9UxLvOjUSJi\nDnAmsFbprvdmZicTMJohImJn4H11Vadm5kW9ikfTxvxS+SBgt7ryIuB1wJOBFYH1gX2An5SOmw18\nLSJ2mIwg1f8y8/7MPIxiEuWGFg75C/DyzHx5Zt48udFJkiScu9AUc/5CVZy7UIfml8rOXagjzl0M\nLpP40uCqOuN8aQf9PFJRt3IH/Whm+zKj9+EBODkzz+1FMOoPtT2avsLy7yN3Ax/uXUSaRsoTrmvX\n/X1EZj4/M8/KzL9n5tLMvC0zf5CZL6X40Zx17Vek2E+s0RJ2msEiYreI+DXwv8CmLRyyEXBuRJwV\nEe6lK0nS5HPuQlPN+QuN4tyFJsC5C3WFcxeDa3avA5DUM1Vnoc/toJ8VW+xbAyoiPgq8q1T9S+DQ\nHoSj/vIBRi/xdGRm3tmrYDStVE3CApyQmZ9rdmBmfrm2POY/11VvDbwC+J8uxadpICIOBo4FZpXu\nuhi4kGLv3AA2AJ4LPGfkUIqz318WEW/LzLOmJGBJkgaTcxeaMs5fqAHnLtQp5y40Yc5dDDaT+NLg\nur+irpP94KqOqepbAygi3gV8olR9NfDKzGz0RVYDICI2
Ao6uq7oE+GqPwtH0U/U5cy/wkRaPPxp4\nJ7BeXd1b8IfwwIiI1wMnlKovA96RmVc2OGY74GvAyBKGqwDfioglmfmzyYpVkqQB59yFpoTzF6ri\n3IUmyLkLTYhzF3I5fWlAZeZDwOOl6lU66Kq8/NyyzPRsdhERrwFOKVUPA3tm5j1TH5H6zIks399y\nGcX+gtmkvVTvvoq672RmSxOxtUm4b5aqd59wVJoWImI+xTKp9X4F7NboRzBAZl5BsX/hpXXVs4H/\njIhOvkNJkqRxOHehqeD8hZpw7kIT4dyFOubchcAkvjToFpfK63fQxwalsj9uRES8CPhvRi/zcxvw\n4sz8W2+iUr+IiH2BV9VVnVT7gim1quqz5sI2+yi3Xy8iOvkc1PTzDmCtuvIy4C21JEFTtTZvrh0z\nYkPgrV2NUJIk1VtcKjt3oa5x/kKNOHehLnDuQhPh3IVM4ksD7rpSeaMO+igfc32HsWiGiIjnAGcz\nep/Ceyh+AP+pJ0Gp39QvRXcr8LFeBaJp69qKuuE2+6hqv3bbkWg62qtU/mlmtvz9pda2vATd3hOO\nSpIkNeLchSaF8xcah3MXmijnLjQRzl3IJL404K4plZ/WQR+blMp/7DAWzQAR8Uzgf4FV66rvB16W\nmVf1Jir1odXr/l4fWBIROd4NOLXcUUW7oal6EOqpqs+aqmXqmrm3om5eB7Fo+tmmVL6ogz7Kx2zX\nWSiSJKkFzl2o65y/UAucu9BEOXehiXDuQibxpQH3h1L5WR30sUup7A/hARURmwM/BebXVT8C7J2Z\nv+pJUJJmqqsr6larqGumqn3Vj2PNPGuWyrd30MdtpbJXQkiSNHmcu1BXOX8haYo4d6GJcO5CJvGl\nAXd+qbx+RLR8Rnut7ZPG6VMDICI2As5j9L+Hx4D9MvMXvYlK0gx2CfBAqa7dZVWr2t/ZWTiaZu4v\nlVfuoI9VSuXyv0dJktQ9zl2oa5y/kDSFnLvQRDh3IWb3OgBJPfVb4K/AU+rq9gOOafH415fKt2Tm\nb7sRmKaPiFiP4gfwU+uqHwfelpnn9CYq9bmPMnpZulY9B3hrqe69pfJvOopI00pmPhwRPwH2rave\nFfivNrrZtVS+LTNvnXBwmg7uYPRVV5t30McWpbKTKJIkTR7nLtQVzl+oTc5daEKcu9AEOXchIjN7\nHYOkHoqIk4D31VX9CViYmcvGOW4Wxb50m9VVn5SZB3c/SvWriJgHLGLsfjoHZuZXpjwgzWgRcQCl\nveUyM3oTjXotIt4IfLOu6i5gw8x8uIVjZwHDwIZ11d/KzDd2NUj1pYg4k2Lif8QwsNl4333qjp8N\n3MDoyd/vZuZruxakJEkaxbkLTZTzF5oqzl2onnMX6pRzFwKX05cEX6JYNmzEZsD7WzjuEEb/CH6s\n1pcGRESsAvyQsT+Aj/AHsKQp8B2KHyMj1gY+0uKxH2T0j2CA07sRlKaFn5bKCxh7ZUwz/8ToH8EA\nP5lIQJIkaVzOXahjzl9I6iHnLtQp5y5kEl8adJn5J+C0UvUnIuLFjY6JiJcA/1aqPjUzb6hqr5kn\nIuYAZwG7le76RGZ+rgchSRowmfko8PFS9Uci4i3NjouIfRn7GXZpZvpDZnCcAdxeqvtCRLxmvAMj\n4g3AZ0vVd9b6lCRJk8S5C3XK+QtJveTchSbAuQu5nL4kiIj1gSuB9eqqlwLHASdTLNUC8DTgQOAD\nwJy6trcD27ofz+CIiC8z9sy/G4DPd9jltZl5/sSi0kznknQqi4igOIu4PHl7BnAC8KvMXBYRKwA7\nUbxvvQ2o/3dzP7BLZv5hCkJWn4iIdwNVV119FzgFuCgzH6y1XRV4LsUSvntXHPO+zDx5smKVJEkF\n5y7UCecvNNWcu1CZcxfqlHMXMokvCYCIeB7Fl4mVK+4e2aNnpYr7HgL2zMwLJys29Z+IWATs3sUu\nT8/MA7rYn2Ygfwir
SkSsTbG35dYVdy8DlgDzgFkV9z8C7JeZ50xagOpbEXE8xfJyjSyhmDRZo0mb\nkzPzfU3ulyRJXeTchdrl/IWmmnMXquLchTrl3MVgczl9SQBk5v9RnA3494q7V6L6R/DfgRf5I1iS\n1CuZeRfFpFzVknKzgLWo/hH8V2B3fwQPtA9Q7JP7cIP759H4R/AjwGHAQd0PS5IkNeLchSRpOnLu\nQhPg3MUAM4kv6QmZ+UtgK+AY4O4mTe+utdkqMy+aitgkSWokM+/OzJcCrwd+M07zG4APAZtn5q8m\nPTj1rSwcBywEPkN1MqDsVop95bbMzC+ly5pJkjTlnLuQJE1Hzl2oE85dDDaX05dUKSJmAzsD2wDr\n1KrvAH4P/DozH+tVbJIkNRMRGwHPAjYGVqFYWux24DeZeUMvY1N/i4hNgB0o9tqdDyTL//1clpl/\n7l10kiSpzLkLSdJ05dyFOuXcxeAwiS9JkiRJkiRJkiRJUp9wOX1JkiRJkiRJkiRJkvqESXxJkiRJ\nkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqESXxJ\nkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqE\nSXxJkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJ\nkvqESXxJkiRJkiRJkiRJkvqESXxJkiRJkiRJkiRJkvqESXxJkqRpICKGIyLrbgt6HZMkSZIkSZIk\nqftM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmS\nJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmS\nJEmSJEmS1Cdm9zoASZLUvyJiNrAzsAmwLrAycCdwC3BhZt7f5fHmAs8Bng6sCdwH/LU21u1dHGcB\nsBPFY1oTWAzcDlyWmTd2a5zaWCsA2wObAesA84EHgTuAPwC/z8ylXRpnJ2Bbisf1MMVj+lVmXj/B\nvgPYtNb3+sA8ICkex+3AMPCHzFwykXEkSZIkSZIkSSbxJUlShYh4JvDPwMsoErZVlkbE+cBQZl7S\nYr/DwMZ1VZtk5nBErA78C3AgsEbFocsi4ufAEZl5ZYsPozz2XOAg4B+BrZq0uw74KnB8Zj7SyVi1\nfrYFPgy8BFirSdMHI2IR8HXg++0m9GsnWhwMfBB4SoM2VwMfzcwftNn36sBhwAHAgnGaZ0T8Afhf\n4LTM/EM7Y0mSJEmSJEmSCpGZvY5BkiT1iYhYGTiRImnbzrY7XwPem5mPjtP/MKUkPjAH+DHwtBbG\neYwiGf3ZNmIjInYGzqyN16qbgDdk5sVtjrUGcAqwPxDtHAsckpnHNeh3mLHP3QPADyhWL2jF5zLz\niFYaRsQuwPeBDVrsu97XMvNdHRwnSZIkSZIkSQOvncl5SZI0g0XEusD5wDuo/o7wEHAPxTLqZe8E\nfli72r0d6wHnMTaBfy9QdUX6bOAzEfGRVgeIiBcDv6A6gZ8Uj+nxivs2Bn4eEXu1MdbTgIuBN1Cd\nwH8cuJviuazsotWxKFZI+DljE/j3UST3q3woIg4ar+OIWEjxulQl8LM2xp1AxysVSJIkSZIkSZKq\nmcSXJEkjS7L/ANildNc5wN7A2pm5SmauBawEPB/4bqntnsAxbQ59AvDU2t9XA28E5mXmvNo4mwNH\nMzbp/cmI2HO8ziNiY+A7wGqlu84A/gGYW3tMKwK7At8otVsZOCMiNmthrNWBHwFPL901DBwKLKyN\nt3ZmrkKRhN8D+FfgmvH6r3AysE3t7/8H7AOskZlrZOZqwIbARyj2ra93TO2EjWaOZ/
Rz9jDwReDZ\nwGq1MdbNzJUotj/YtfYYzwOWdfBYJEmSJEmSJEk1LqcvSZKIiM8CH6qrWkyxlPyPxznuDcDpFEvi\nQ3GV9nMbLUFfsST8iG8Db2m0H3xEbEWxSsCT6qpvAp6RmY2uOicizqdIlI9YBrw5M89ocsy+FEvv\nz66rviQzmy5ZHxHfBl5Xqj4OOKKVfe4j4iXA45n5swb3DzP2ucta/59v0u/zKK7Yn1NXfXhmfrFB\n+w2Av7J8VYClwO6Zecl4j6F2/FOBhZn501baS5IkSZIkSZJG80p8SZIGXERsCBxSV/UY8IrxEvgA\nmfktoH6P9SiVW/FbmiTwa+P8EXglo5e93xh4c6NjIuK5jE7gA3ywWQK/Ntb3gH8qVT87Il7UZKyd\nGJvA/3xmHtJKAr827k8aJfCb+EKzBH6t3/8DTipVv7bJIdsxeln/77WawK+Nd7MJfEmSJEmSJEnq\nnEl8SZJ0KKOv0v5yZv6yjeNPAG6oK78qIp7UqHGF97eS6M7MXwOnlqrf0+SQ8t7vV1AsE9+KrwCX\nluoObtL+yFL5dxRL2U+mO4GPtdj2q6XythExq0HbNUvlP7cVlSRJkiRJkiRpQkziS5Kk15fKJ7Zz\ncGYuA86qq1oB2K3Fw6/KzIvaGO6UUnn7iHhyuVFEBPDSUvWXM/PxctsqWew3VH4eXhwRs8ttI2Iu\n8PJS9ecy89FWxpqAb2Tmw600zMyrgSV1VasAGzVovrhU3rH90CRJkiRJkiRJnTKJL0nSAIuITYGn\n1FUNZ+b1HXR1Wam8S4vHndPOIJn5G+BvLYy1kLFXlJ/dzli19vVJ/1WAbSvaPRtYua68lNEnNUyW\n/9dm++FSeX6Ddr8Gsq68Z0QMRcScBu0lSZIkSZIkSV005moySZI0UMoJ8Icj4sAO+tmqVN6gxeMu\n72Csy4H6q++3Bb5farN1qXxzZt7RziCZeV9EXA9sWVe9DfDbUtMdSuXLWr1CfoKG22x/X6m8elWj\nzLwjIr4HvKau+ijgPRFxBnAucFFmPtjm+JIkSZIkSZKkFpjElyRpsJX3rl8InNyFftdqsd1NHfR9\nc6m8dgvjD3cwDhT7wdcn8ase17ql8g0djtWuJeM3GWVZqTyrSduDgZ2Ajevq1gcOqd0ejYgrgQuB\nC4BfZOa9bcYjSZIkSZIkSargcvqSJA22VpPt7VqlxXadJH7Lyev5FW3KS+l3mmAuj1XuF8Y+h4s7\nHKtdOX6TDjvOvBV4FvDdBk3mUCT5D6FYBeG2iDgzInacrJgkSZIkSZIkaVCYxJckabDNnaR+Y5L6\n7bTvSUt493isSZOZt2fma4FnAl8Crm/SfCVgP+A3EXFyRKw4FTFKkiRJkiRJ0kxkEl+SpMF2d6l8\nemZGF257tDj+Gh3EXD5mcUWbe0rleR2MU3VcuV+Au0rlqqv1p63MvCozD8vMLYAnA68HTgAup/qE\nhQOBU6cwREmSJEmSJEmaUUziS5I02O4olZ82xeNvPH6TMZ5aKpeT6DD25IQFHYwDsMk4/cLY53DT\nDsfqe5n598z8dma+PzN3ADYA3g/cVGr6hoh48dRHKEmSJEmSJEnTn0l8SZIG25Wl8k4RsfIUjr99\nF44pPwaAq0rlp0bEOu0MEhGrAZuXqn9X0fS3pfIOEbFSO2NNV5l5W2aeADwDuLR095t6EJIkSZIk\nSZIkTXsm8SVJGmyXM/pK9pWBV0zh+K9qp3FE7ESxpHu9cvIY4FrGLn2/dztjAfsw+rvSg1Qn8S8F\nHqorzwVe2+ZY01pmPgAcVarephexSJIkSZIkSdJ0ZxJfkqQBlpmPA98vVX88ImZNUQjbRMSubbR/\nT6l8RWb+tdwoMxP4can6vRERrQxSa3dQqfpnmflYxVhLgf8pVX8oIua0MtYM8udSedWeRCFJkiRJ\nkiRJ05xJfEmSdAxQn5zeGvhCp521miivc3wrCe
+I2Bl4R6n6K00OObFU3hF4b4sxvQt49jj91ftM\nqfxM4FMtjtVXImJ2h4duVSrfOtFYJEmSJEmSJGkQmcSXJGnAZeYNjE1QfyAi/iMiVm61n4hYOyKO\nBH7UZgg7At+IiLlN+l4InMPo7y43A99odExmXgQsKlUfFxH7NgsmIl4FnFSqviQzz2sy1mXAt0rV\nH4yIY5s9rtK4e0bEi1ppO8k+HRH/HRG7tHpARKwDfLJUvairUUmSJEmSJEnSgDCJL0mSAD7E2KTr\nu4A/R8RQRDw7IlasvzMi5kfE8yLiAxHxU4orrz8NbNjGuCP72b8euCwi9o+I1evG2DQihoDLgPVL\nxx5Y24u9mQOAJXXl2cB3I+K/IuK5I1edR8Ss2mM8DfgBUL8ywP3AW1p4LO8BrivVfQC4tvYcbRER\nT3z3iog1IuIfIuLoiPgj8BOKVRB6bQ7wBuCSiLghIj4XEftExMal+GdFxJYRcThwJfD0uj4eAP5z\nasOWJEmSJEmSpJkhii1jJUnSoIuINYFvA82uBr8fWAqsQZEQr3J1ZlYmoyNiGNi4rupZwHdKdVAk\n3lcCVqTaxzPz35rEWT/mi4HvU71H++O1seZRfXLjQ8DrMvPcFsd6GnAusLBBk8eBxRSPbZWK+w/N\nzGMb9D3M6Odpk8wcbiWu2vGLgN3rqp6fmYsq2h1LcfJBlQTupdh+YQ1Gn+xQ3+btmXl6q7FJkiRJ\nkiRJkpbzSnxJkgRAZt4DvBT4BPBwg2arAWvROIGfFFdlt+oO4MXAjaX6eVQn8B8DPtpqAh8gM38G\nPB/4c8XdKwBrUv2d6Gbgha0m8Gtj3Qg8B/hegyYrUDx/VQl8KJL8vdbsDM+geG3WpjqBfw+wvwl8\nSZIkSZIkSeqcSXxJkvSEzFyWmR8DFgCfBW5o4bClwAXAR4BNM/NNbY55PbAd8DmKq7yrLAN+Bjwr\nMz/VTv+1MX5NcXX8YcA14zS/HjgC2DIzL+5grMWZ+RqKZP7ZFKsXNHNfrd0+wEntjjcJPgy8BDgB\nuIrWTiz4M/BJYPPM/PYkxiZJkiRJkiRJM57L6UuSpKYi4qnADsC6FFdgQ5F4vp1iD/hrM7PRlfvl\nvoZpsiR8RMwFdqXYX33N2jh/BS7MzNsm9EBGx7EJsCOwHjCfYkn924DLM7OVExfaGWsOsAvFiRHr\nUizrf39tvGuA32fmo90cs5siYnVgK2BTiudrNYqTKu4DbgF+l5lVqxxIkiRJkiRJkjpgEl+SJE2Z\nie7rLkmSJEmSJEnSTOdy+pIkSZIkSZIkSZIk9QmT+JIkSZIkSZIkSZIk9QmT+JIkSZIkSZIkSZIk\n9QmT+JIkSZIkSZIkSZIk9QmT+JIkSQMiIk6LiKy7HdDrmCRJkiRJkiRJo83udQCSJEmSJM0UETFU\nX87MoeqWM1/tZLEFdVWnZeZwT4KZpiJiD2CPuqpFmbmoF7FIkiRJkqaOSXxJkjRwTLCoH5mokWaM\no0rloV4E0ScOAHavKy8ChnsRyDS2B2P/TS2a+jAkSZIkSVPJJL4kSZoymbmg1zHUmGBRP9oDEzWa\nakPz5gJH1EqfZWjJ0l6GI0mSJEmSJJP4kiRJkjTI3gf8W+3v+4DjehiLJEmSJEmSgBV6HYAkSZKm\nRmYekJlRdzut1zFJ6qGheeswevWHoVqdJEmSJEmSesgkviRJM8iCI8+dveDIc9/a6zgkSdPC0cD8\nuvJ83F5EkiRJkiSp50ziS5I0A9Ql7/8InN7reCRJfW5o3tbAgRX3HMjQvGdMdTiSJEmSJElabnav\nA5AkSZ1bcOS5s4E3Ah8DNutxOJKk6WBoXgBfpPqk7lnAFxma91KGluTUBqbxREQA2wObA+sBqwN3\nAX8HfpmZd3VhjPWBbYFNgDWAOcCDwGJgGLg2M/820XGmWkSsRfG4NgPmASsBDwFLgJuA6zNzuAvj\nbELxGq0LrA
XcB9wO/CYzb5xo/5IkSZKkwWASX5Kkacjkff+LiBWAHYGtKRIts4E7gDMzc0mT4+YA\nC4EtgQ0oEjRLgbuBW4BLMvPeyY2+dRGxEfAcYGOKx3gncB1wUWY+2qOYTNRIze0FvLjJ/XsCLwfO\nnZpwpr+IGKZ4H6y6b7yTITYZ7z2p9n7zEeBVFJ8pVR6PiIuBT2bmj8YZs9z/CsDbKFZneFYL7W8B\nzgP+OzN/VrrvtFpfVc4vzkNo6O2ZeVoLIbcsIvYB/gnYg3FWI4yI24ELgG9n5lltjLE6cBjFd7Mt\nmrS7DjgB+Eqzz8iIOAA4tcHdR0XEUU3COT0zDxgvZkmSJElSf4tML66QJGm6aDV5P3zMXk1nyAdR\nswRLC8YkWCr62yQzhyNiVeBDwHupTrRsn5lXlPp6MrAf8DJgN2CVJrE8DlwCfAE4OzMfb/VBVCRW\nmiZLImIIqE8UHJ2ZQ7X7dgU+SZEUqXIfRaLimMy8r9UYJ6KdRA1Fsn3SEjUUJzJMNFEzHhM1at/Q\nvLnAVTT/9wvFv+GtGVrSk5Nxpptuf8bU9TuL4r32UGBuG33+EHhjK++/EbEucA7w7Db6H3FDZo76\nPjJOEn88XUviR8QqwBnAKzs4fFlmtnTRQ0S8ETgOWKeN/q8DXpWZ1zbo8wD8bJAkSZKkgeaV+JIk\nTQNeeT89RMRCisTJpm0csxdF8mS8pPOIFYBda7dfRMR+3Vg+uR0R8THgaKDZySKrU1w1+sqI2DMz\nb53EeDpJ1KwHvA7Ylxa/E7eZqNmCIon/TxHRMFEj9cBBjJ/Ap9bmIODYSY1GDdVOCjsDeEWDJo8A\nDwDzGfsZ8grggoh4YWbe02SMFYFfUKwaU+XB2m0Vmp9g1ldq2w6cTeMVJx4G7qdYqWVVmn+eNRvn\n4xSfh1UeA+4FVmPsCRhbAL+MiJdk5m87GVuSJEmSNLOZxJckqY+ZvJ9WNgTOBJ5cqn+AIjnQKPmx\nOo0T+A9RJE9WA1asuP8FwIURsXNm3t92xB2IiKOBj5eql7I8kVROhGwDnB0Ru7azakAb8ZiokVo1\nNG8dRq+uMZ6jGJr3XwwtuXOyQppBPkrxfg5wcum+945zbKMTsU5jbAL/AuAUYNHIyVG1q/V3Bt4B\nvJ3lv/O3B74KvKbJ2IczNoH/XeBrFNu3PHECQETMpdjuZQeKlWNe1qDP0ylWjIFi5ZLN6+77EsVV\n6I1c3OS+dryBsZ8LvwBOAi7MzNtHKmvP3+bAdsBLKZ7z+eMNEBEHM/Zz4U8UJ3D9pP7krYjYgmLF\nncPr+l4bOCsitqvYZudilv+7eQXFFhgjzqU4YbARTxqTJEmSpBnAJL4kSX3I5P2kmIwES72vsDyB\n/yvg88DPRibmI2Jt4NVN+roH+HHtdiVwTWY+MnJnbcn93YB3Ay+sO24hcCJwQAsxTtRLWL7c8u3A\nZyiW9L+xFuNKFHtpf5LRSaFdKOI+ZRJiMlEjte5oYF4b7ecDQ8DBkxHMTJKZ3xz5OyJOLt3X9ntf\nRLwPeG1d1VLgHzPz6xVjL6NIml8SEWcA3wfWqN29b0Tsn5lnNBjqraXyYZn5paqGmbkU+F3tdlpt\ne5GXV7Q7Hzi/9jj2Z3QS/5zMXNQglm4qP67jMvOQqoa15++a2u2M2uoE+zXrPCK2o/icr3c88KHa\n81Qe4zrgExHxdeBHwNNrdy2g+Cw9sNT+Wmrv8RGxPqM/G37Tyb8pSZIkSdL0EpnZ6xgkSVJNt5L3\nw8fs1dHVxoMiIkZ9AcrMtp+vJvsffwr4l2zxS1ZEPBfYEvhmfdJ+nGPeDvwHMKtWlcCWmXn9OMed\nxuh9ipvuPRwRQ1Rftft/wN6NlmiOiNUoEjg71VVflZnPbBZfJyLixxQnF4xomKipOHZFYL/M/EaT\nNttRJMfqV0JomKipO+6pjE7UAHwlMw9scEjV8310Zg41eQhS64bmbU1xglCr
W3eMWAZsy9CSq7sf\n1Mw00c+Y2hYhN1OcADRi/8w8s8Xj96FI5I+4IjO3r2i3EsWKLyP+BmzUzVVTImIRsHtd1fOnIokf\nEbdRbJsC8Ciwdmbe18X+y589p2TmeCcEjhy7ALic5Sd6PQJsnJm3NWg/hJ8NkiRJkjRwvBJfkjRQ\nFhx57kCcvdZvj3OATir4ZmZ+tJ0DMvOXwC/bPObUiNgQ+NdaVVBcxXd4O/106EbgFZl5b5P47o+I\nf6RIUozYJiK2qF2N2E31ialHKU6AaUntpImGCfyaYxidwD8lMz/QQt83R8RejE7UHBARRzVK1GgG\nGJrXV++9XTIL+D1D7VzAP8WGlsy0z5h3MjqBf06rCXyAzDy7ljzfo1a1XUTslJm/KTVdq1S+aTK2\nPemRNev+vrPLCfxtGJ3A/zvwwVaPz8zhiDiWYpULKD5j3sLYK/slSZIkSQOs3aswJEmSVG0pxd6/\nU+W42pgjnj9F4x7RLIE/IjOvAMoJox0nIZ6+TtQAx9ZVjSRqJKmZ/UvlEzvoo5z0372iTXl7j4W1\nVQBmgsV1f69XO/GtW8qvz6mZ+UCbfbTy+kiSJEmSBphJfEmSpO44p37/9clWS6RfU1e1TUSsPMnD\n3gGc3Ub7i0vlLbsXyhMW1/1tokbStFZ7H6/fiuRRYFEHXV1WKu9SblB7P6vfJmFN4FsRsU4H4/Wb\nS+v+ngV8OyKqtsDpxD+Uyj9tt4PMvIbRWxmMeX0kSZIkSYPN5fQlSZK64/xudRQRs4DNgKcBawCr\nU/29rT5pPxtYH/hzt+KocFFmLmuj/Y2l8vwuxjLiUmCv2t8jiZo3ZOZNXei7K4maiHiI5a+ViRpJ\nzWwPzK0rLwbeGdH2jgHrlcobNGh3CnBCXflVwM0R8T3gB8CizLyj3cH7wCks/2wAeA7wp4j4IcXJ\naD/PzFva7TSKF2LnUvVzImKrDmJcyvLPhnUjYnZmPtZBP5IkSZKkGcgkviRJUndcNZGDI2IOxZXf\nb6S4WruTq+rnTySGFgy32b68tP3qXYqjnokaSTPJk0rldYGTu9DvWg3qTwZeBry8rm5l4E21GxFx\nHXARcAFwXifvqVMtM38YEf8OvLuuejawT+1GRNxEsWLMyOP6Uwtdz6fYGqXepycY7og1KVa8kSRJ\nkiTJJL4kSVKX3NPpgRGxG/AfwMIJxjAZSfJ65f2Tx1O+an9WtwIZYaJG0gzTKNk+UZV73WfmsojY\nGzgaOAxYqaLZFrXbAUBGxIXAccD3MjMnJ9yJy8z3RMQ1wFHAvIomG9du+wNExGXAl4HTm5xoNVmv\nDzR4jSRJkiRJg8kkviRpoAwfs1fb69FOpQVHnjub4krsj1Esp96Rfn+cM9T9nRwUES+huGK8KnHS\nrhW60EczfZmsMVGjvjG0pH/fe4fmbQ1cSfvvE8uAbRlacvW4LdUNc8dv0pGG/zZr74MfjYgTgLcB\n+wI7Un3iVQDPq90ujIj9M/OvkxBvV2TmlyLiVIpVBV4HPJuxJ2iN2AH4KnBoROyXmX+oaDNZrw80\neY0kSZIkSYNnsid6JUlSG4aP2eux4WP2+jqwFcVEeitXDGuaioi1gG8yNoH/c+BwYA+KkznmAStl\nZtTfKK4sF0WiBlgAHEzxvDzSpPlIouaKiHh6gzYmajSzDC35PcX2E+06xQT+lLq7VL6g/N7f4W3B\neANn5q2Z+ZnM3IVixZCXAp8Azqf6PXU34LyIWGOCj3lSZebizDwpM/egWGVlD+BfgB9TfQLeM4Bf\nRMRGFfeVX5/HgDldeo2Gu/WYJUmSJEnTn0l8SZL6kMn8gXEosHZd+R5gj8x8UWZ+MTMvyMwbMvPe\nzKxKoKw2NWFODyZqpHEdRXvbYiwGhiYlEjVS3mrjab0IIjPvy8yfZObHMvMFFJ9V+wEXlZouBD48\n5QF2KDMfrn22fjIzX0bxuF4O/G+p6ZOA
Yyq6uIvRq9LMBp46KcFKkiRJkgaaSXxJkvqYyfwZ79Wl\n8qGZ2c7V9et0M5iZxESNVGFoyZ0Ue5+36ujaMZo6v2P0e89GEbGgR7E8ITMfyMzvZOZzgZNKd7+x\nFzF1Q2YuzcwfZeZewBGlu/eNiJVK7R8DyitT/MNkxihJkiRJGkwm8SVJmgZM5s88ETEbqF/K/VHg\nzDaOX5din3e1wESN9ISTgOtaaHctY5O1as2y+kJEVO0tXykz7wSuLFW/rhtBddE/U3xmjVgQEas3\naPtYqdzyc9EDnwf+XldeCdi8ot15pfJkvz7T6TmUJEmSJHWJSXxJkqYRk/ld03GCpYvWYfTe6Hdm\n5sNtHL9Xl+MZJCZqNLiGliwFDm+h5eEMLXl0/GaqcF+p3O6e8WeVyh+MiL7ZPiUz76NYraTeqg2a\nT/S5mDKZmcBNpeqqx1V+fV4eEbtMTlTANHoOJUmSJEndYxJfkqRpyGT+hPXDhHg5ObZGRLT03ax2\n0sGh3Q9pMJiokTgX+FmT+3/K2K0n1LrbSuWt2jz+RGBxXXk94OsREdXNm2t0XG1FmE76WwdYt65q\nGXBHg+YTfS7a1umJeRExB9isVH1ruV1m/hJYVKr+ZkSs3cm4tbGbvbZdeQ4j4rSIyLrbcCf9SJIk\nSZKmhkl8SZKmMZP5HZvypEKFu4GH6sqrAru3eOxHgWd2PaJpxkRNwzFM1Ki5oSVJcSLQ4xX3LgMO\nq7VRZy4vlQ9sJwGfmUuAoVL1q4EfRMRarfYTEatGxHuB3zZosldE/F9EvLrVhH7tffc4Rq8UcmFm\nLmtwSPm5eGtENLpqv1u2iYgrIuJtEbFyG8f9G8UqOSOGGXvC14gjgKV15U2BCyNim1YHi4hZEfHK\niDgf2KFJ0/JzuHtEPL2ypSRJkiRpxujozHtJktRfho/Z6zHg6wuOPPe/gTf2Op5p4HJgy7rygRFx\nce0K7SmRmRkRFwAvras+PiKel5mLGx0XER9kbHJnUG0TEacBXwK+nZkPjdN+RLuJmguBubXySKJm\nv8y8qpXBakmvlwOHAR+kcUKtMlGTmX9oZRypLUNLrmZo3inA+0r3nMLQkqt7EdIM8j/A/nXltwBb\nR8RPKLbyWFpq/83aEvVPyMzjImJn4E111a8E/hwRX62N8evMfGDkztq+9E8HtgNeBuwJrAw8QLUA\ndqvd7o6Ic4CfU7wXXZ+ZT8QZEU8Gnk9x8seOpX6Ob9A/FCs6PM7yCwi2BK6JiB8ANzD6ZDaA8zPz\n2ib9tWpb4DTgpIg4l2J1icuAP9ZvXVNbVeB5wEHAC0t9nNDoe0Fm/joiDgL+o656IXB5RJwFfAu4\nKDOfWKEgIuZSbN2yLfAC4FUsX9Gg4UkemXltRFwHbFGrmgv8KiLOBn4P3AvUx3ltZp7fqD9JkiRJ\n0vRgEl+SpBlkJJnf6zimgQknWLrkK4xO4m8NXBYRRwE/zMx7ACJiHvAiiuTJc2tt7wRup0jYDDIT\nNVLnjqI48Wt+rbwYTxLqhu8CH2f0yWLb125VfszY7TQA3lX7b30ifw2KE4IOA4iIBykS4auz/GSj\nTqwFHFC7Uev7gVrfqwErNTjuPzPze406zcy/RMTX6/sFNqR4L67ydqAbSfwRqwL71W4ARMRDFCc2\nrFK7VfkZxYoDDWXmV2snTnyO5SsTzAJeX7sREUspXttVKE6o6NSnKD7rRqwGvLlB29MBPxskSZIk\naZoziS9JkgZRtxIsE5KZZ0fEjyiumByxCbUTMSJiCcXVi6uXDn2MIqnzkW7HNI2ZqJHaNbTkTobm\nHU2xmgXAEENL7uxlSDNBZj4SEXtTfNY8YwL9PAy8OSIuAT4BzKto1uz9bcQVjYYY57hVa7cqjwFf\nAP55nD4ADgbWBPZuoW03jPe4Vqb5+/RpwIFNtghYPlDmlyLicuCrFCu1lM0FxtuC5WaKLXaajXN6\nRCwA
PsborQwkSZIkSTPUCuM3kSRJmlky8xGKZEI/LBm9P3BBg/vmMTaBfx+wb2b+dFKjmh5aSdSs\nQ+ME12nAK1tN1FCshnBDgyYjiZpmiaGWEjUUV0KPG5PUJV+mSAx+rPa3uqC2JPwOwOuAbwC/o/j/\nv7zSSyt9nQhsDPwLrX1uLQMupUj8b52ZuzVodw6wE3A0xbYhDzdoV+8e4GvAtpl5ZCvb0GTmA5m5\nD8VqKCcCF1OsJNPqFihtycwrKVap+TBwHq2dhPcgcAawa2a+vfY9odXxFlGcFPhWiufx0RYOu57i\n/7cXAZtk5o0tjHM0xUowR1M8rluA+xn/s1CSJEmSNA3FFG79KkmS1Fdqy56/qnbblmKJ39UYuyTx\nJpk5XDp2mCKp0rBNG3HMAg4BDgc2aNDsIYql24cy8y+14xYBu9e1eX4tmdBonNOAt9VVvT0zT2vS\nfohiue0RR2fmUKP2FccfAJxaV3V6Zh7Q6vEtjrEVxT7RLwZ2YexJD2UPUiSujs/MizsYbxbF8uPv\nro03Z5xDrqe42v97FHs9P97iOJtRXIH/XIqkzXyKK2Lrl+OvfD4rXuebMnNBK+NK6n8RsR7wLGA9\nihOVZlMkqu8CrgOuycz7O+h3DsX7zWbAkyneT2fV+r4DuKrW92NdeBhTpva+vTnF49qIYkuCORQJ\n8LsoToy4up3E/TjjrQI8G3gKxeuzam2sxcCNFNu93NGwA0mSJEmSMIkvSZLUN2qJhu2B7Sgm/leg\nuHLzGuCS+j3eNZaJGkmSJEmSJEkzgUl8SZIkSZIkSZIkSZL6xAq9DkCSJEmSJEmSJEmSJBVM4kuS\nJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM\n4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS\n1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmS\nJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmS\nJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuS\nJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM\n4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS\n1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmS\nJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmS\nJEmSJEmS1CdM4kuSJEmSJEmSJEmS1CdM4kuSJEmSJEmSJEmS1Cf+P9TjdkzHEYOEAAAAAElFTkSu\nQmCC\n", "text/plain": [ "\u003cFigure size 1600x600 with 2 Axes\u003e" ] }, "metadata": { "image/png": { "height": 466, "width": 1016 } }, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "\n", "plt.suptitle(\"Adversarial training on \" + f\"{DATASET}\".upper())\n", "axes[0].plot(\n", " 
accuracy_train, lw=3, label=\"train set.\", marker=\"\u003c\", markersize=10\n", ")\n", "axes[0].plot(accuracy_test, lw=3, label=\"test set.\", marker=\"d\", markersize=10)\n", "axes[0].grid()\n", "axes[0].set_ylabel(\"accuracy on clean images\")\n", "\n", "axes[1].plot(\n", " adversarial_accuracy_train,\n", " lw=3,\n", " label=\"adversarial accuracy on train set.\",\n", " marker=\"^\",\n", " markersize=10,\n", ")\n", "axes[1].plot(\n", " adversarial_accuracy_test,\n", " lw=3,\n", " label=\"adversarial accuracy on test set.\",\n", " marker=\"\u003e\",\n", " markersize=10,\n", ")\n", "axes[1].grid()\n", "axes[0].legend(\n", " frameon=False, ncol=2, loc=\"upper center\", bbox_to_anchor=(0.8, -0.1)\n", ")\n", "axes[0].set_xlabel(\"epochs\")\n", "axes[1].set_ylabel(\"accuracy on adversarial images\")\n", "plt.subplots_adjust(wspace=0.5)\n", "\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "GAPfMtFKQ4N-" }, "source": [ "Find a test set image that is correctly classified but not its adversarial perturbation" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "executionInfo": { "elapsed": 1666, "status": "ok", "timestamp": 1707150695251, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "5o-Qf1qEQ4N_" }, "outputs": [], "source": [ "def find_adversarial_imgs(params, loader_batched):\n", " \"\"\"Finds a test set image that is correctly classified but not its adversarial perturbation.\"\"\"\n", " for batch in loader_batched.as_numpy_iterator():\n", " images, labels = batch\n", " images = images.astype(jnp.float32) / 255\n", " logits = net.apply({\"params\": params}, images)\n", " labels_clean = jnp.argmax(logits, axis=-1)\n", "\n", " adversarial_images = pgd_attack(images, labels, params, epsilon=EPSILON)\n", " labels_adversarial = jnp.argmax(\n", " net.apply({\"params\": params}, adversarial_images), axis=-1\n", " )\n", " idx_misclassified = jnp.where(labels_clean != labels_adversarial)[0]\n", " for j in 
idx_misclassified:\n", " clean_image = images[j]\n", " prediction_clean = labels_clean[j]\n", " if prediction_clean != labels[j]:\n", " # the clean image predicts the wrong label, skip\n", " continue\n", " adversarial_image = adversarial_images[j]\n", " adversarial_prediction = labels_adversarial[j]\n", " # we found our image\n", " return (\n", " clean_image,\n", " prediction_clean,\n", " adversarial_image,\n", " adversarial_prediction,\n", " )\n", "\n", " raise ValueError(\"No mismatch between clean and adversarial prediction found\")\n", "\n", "\n", "img_clean, pred_clean, img_adversarial, prediction_adversarial = (\n", " find_adversarial_imgs(var_params, test_loader_batched)\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "height": 386 }, "executionInfo": { "elapsed": 487, "status": "ok", "timestamp": 1707150695836, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "aF2-kSuQQ4N_", "outputId": "25b60801-8954-4c42-a5d6-92ab4d0d7d0d" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAACDUAAALjCAYAAAAI3awuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90\nbGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAABYl\nAAAWJQFJUiTwAACUnklEQVR4nOzdd7gsVZWw8XfBJecoiujFgGIEAyomUBEVdcwYUNAZ46hjdswY\nUHTMjjOGUZFg+gwYMKFyjahERcSAkgwoSFJyWN8fu670rVN9Tofq09XnvL/nqee5d3fX3vtUV1fv\nqlq1dmQmkiRJkiRJkiRJkiRJXbPWtDsgSZIkSZIkSZIkSZLUxKAGSZIkSZIkSZIkSZLUSQY1SJIk\nSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIkqZMMapAkSZIkSZIkSZIkSZ1k\nUIMkSZIkSZIkSZIkSeokgxokSZIkSZIkSZIkSVInGdQgSZIkSZIkSZIkSZI6yaAGSZIkSZIkSZIk\nSZLUSQY1SJIkSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIkqZMMatCSFhEr\nIyJryyHT7pcWV0Qc0rAfrJx2vyRJUjdExJm1ccKZ0+6TNI6IOLBh/LvHtPs1iGn3vaHtVYvVtiRJ\nkiRJarZi2h3Q8hYROwA7ASuBzYANgCuBC6vlHOBnmXnVtPooSZIkSZIkSZIkSZoOgxq0qCJiXeDB\nwGOBPYEbDbDaVRHxc+C7wOGZefLkeihJkqRxRMR2wB+AtRte3tWxnCRJkiRJkqRhOP2EFkVErBcR\nLwZ+DxwJPInBAhoA1gXuArwYOCkifh4Rz4oIg3IkSZK658k0BzQAHLCI/ZAkSZIkSZK0BBjUoImL\niPsCvwDeDmzfQpW3B/4XODUiHtVCfZIkSWrPAfO89sSIWGexOiJJkiRJkiRp9vmkuyYqIl4CHEz/\np/UA/g6cCpxV/ftaYBPgJsDOwFZ91tsJ+FxEbJGZF7XVZ0mSJI0mInYDbjPPW7YB9qFk7pIkSZIk\nSZKkBRnUoImJiIOBl/d5+WLgI8CngBMy87o+dQRwW+ARwFOAW7bfUy11mXkApruWJGkxPHXA9xw5\n4X5Iy1pmHggcOOVuzKTMjGn3QZIkSZIkrcnpJzQREfFimgMargPeC9w0M1+cmcf1C2gAyOIXmfkm\n4NbAY4DfTqTTkiRJGllErA/sWyu+CriwVvaQiNh2cXolSZIkSZIkadYZ1KDWRcSewNsaXvoH8C+Z\n+R+ZefGw9WbmdZn5OeB2wBspARKSJEnqhkcAW9TKvgJ8sla2AnjSYnRIkiRJkiRJ0uwzqEGtioiN\ngI8yd9+6Cnh4Zn5l3DYy86rMfC2wF3Of/JMkSdJ0HNBQdli11A0yTYUkSZIkSZIkGdSg1r0YWNlQ\n/qbMPKbNhjLzO8C9gCvbrFeSJEnDiYjtKQGnvS4AvpqZP2bu9GG3j4g7LUrnJEmSJEmSJM20FdPu\ngJaOiNgUeFHDSycDB0+izcz85STqHURErAPsAuwEbAdsAFwE/BU4Ezg+M1ufIiMi1qvavDWwDbBp\n9dLfKDcPfgX8MjOz7bbn6dMK4C7AbYGtgaRsh7OBH2bmkg88iYhNgN0on83mwDWUbXB8Zp46ZF3b\nAncFbgZsAlwC/JmyLc9tsdur29sK2Bm4OaXvGwOXUfan1X/DeW23W+vD6r/5RpT9+grgj8DJmfnr\nSbbd0JetgDtRtv/mlADA84C/AD/LzLMXsz+SNCOewtyA6U9l5lXVvw8HXl97/QDgxLY7Uo3R7kYZ\nl2wFXEv5TTsNOG45jEsAIiKAW1G2w40pv++XU37P/gz8ODMvW6S+rKD8tt4a2JYybr4U+H1mHjnA\n+msBtwDuQBlrbkY5l72cMsXdH4GzgN+29flGxHaU8dGOVXsbU
cZkFwDnUvaloafUG7NPNwJ2rfq0\nKWXMfQHw9cw8a8Jtr1W1uzNlvLYpsC7l/OcCyvY/oec7r3lExM0o34kdKPvW34EzgB9l5vlD1BOU\n7/gdKeek61LG72cAP2j78+jKfhARd6Scd92Qcmy7GDid8r28YJJt1/rRmeOsJEmSJC11BjWoTU+i\nXPCre2VmXr3YnZmUiNgH+FfgQZQLsv1cEBFfB96emSeN0d5awL2r9vakBA+svcBq50fEt4F3ZOZx\nY7RdD4z4bmbu0fP69sDLgCczdw7t1S6LiK8Ar87M+lOaiyIiDgH2rxXvmJlnLrDeSsoFwV4fz8wD\net5zN+AVwIMpF/Sa6vkl8ObMPGKB9vYBXgLch+ZMOtdFxA+Bl2fmsfPVtUA7mwEPB+5P2aduMsA6\npwH/D3hvZv5t1LYb6n0U8O/AfemzX0fE74D3A/+bmVdUZXsA9ewvr8/MA0fsx2bA04EnUm4WzPfe\nU4HPAu/MzEtGaU+SlqADGsp6p504HDgQiJ6yJ0bES9q6+VWNS14J7Mf1QZ91/4iIzwBvyczTR2zn\no8ydPuOZmfmhUerrqXdt4A+UG5OrJbBymIC6iLgD8BzgUZRAwX6uiIjvAh/KzM+P0OVBxoq3pWRy\nezTNn8lZwJHz1H93yu/zo2k+z6i7KiJOoowRPp2ZJw+wzuq2tgUeAdwP2AO4wQKrXBcRP6Ps2x/M\nzEsHbavW7oHA62rFe2bmqur1dYCnAc+iBFQ3eSpwyDD1Dti3nSljxj0pGfI2WmCVyyPiWOB/gc9P\nIsB7MSy0X8+z3oHM/1muoJxHPh+4TZ9qromIL1POoX81T1sbA/9B+X7ctM/bLomIIyjnYSPf6O/K\nflAFkb+Uct3hZn3edm1EfAt4a2+myFHPB+fpy6IdZyVJkiRJhdNPqE31iwQA5wDfWOyOTEJE3Dki\njgO+AjyS+QMaALak3CA9PiIOiYh+N/3na/NFlGwHq4D/pDx1uFBAA5Qn2PYFfhoRR0XEfBdaRhIR\nTwZ+SbkoN9/ftiHwOODUiHh+2/2YlohYOyLeBRwL/At9AhoqtwEOj4jPR8Sc/SYitoqIL1H2rT3o\nf2xeHeDyw4h44wh93iwivkB5cuhQynd2wYCGys7Aa4GzIuJlw7bd0Jftqwu2n6PcPJhvv7458E7g\npIi4/bht1/oREfFcyk2V/2KBgIbKbSkXrE+PiKe32R9JmkURsTvlidlep1fTTgCQmb8Hflh7z1bA\nw1rqwwHAryk3mfoFNEB5ivZpwCkR8bwRmzu0oezJI9bVay/WDGiAcjN1oICGiLhBRHwS+BnwTOa/\n0QawPrA38LmI+H5E3G7YDs/Tl7Ui4iBKxranMv9n0rT+FhFxGGWc9TQGC2iAMh67G2XcfFJEPGKA\ntlZGxDeBPwEfpIyhFwpogDIu2xV4B2V89JQB+ziwatxzMvAB+gc0tC4i9o6Ikylj/YMp+8lCN7Kh\nnB/djxII+4uIuMvEOjljqswMP6V8lv0CGqA8ePJI4Gf99qkquPfXwJvoH9AA5Xv3bOBXEbHbCH3u\nzH5QBX//EngN/QMaoJxT7A18JyI+EiXLYWu6dJyVJEmSpOXGoAa1orppfteGlz46q0/o9IqIZ1Eu\nqo5yQWYtys3jH0TEjYdc9+HA9iO02eshwHEtX6h+NeWC/jAXqNcB3hMRL22rH9NSPUn5OeAFrPnE\n6UIeCXy2yr6xuq4bAj9guJs6Abw6Il47xDpQgk8eAYxzcW8j4K0R8YmImC+Qo6+I2AH4HvDQIVe9\nNfD9iLjzKO029GNTSiDJ+xj8ZkmvbYAPRcS7ImKY/UCSlpp61gJYM0vDfGUHjNt49XT0xxjsZttq\n6wPvjYg3jNDkdynBcL3uWd20HEdTYERTAMUcEXFPyk22x4/Y9r0oY9U9Rly/ty9rAZ+mZM0YOjNg\ndV6xipJxY1yDtL+SElAyS
OBwP1sBH4+It49Rxxqqz/RHzH8DfFLuQZnOYBw7A9+LiMe10J+ZVmU6\n+DGDBc+uti5ln1rjOx0RjwWOpkz9MKhtgG9W2QWG0Yn9oAqm/xJleodhPA346qjnLA396MxxVpIk\nSZKWI6efUFv2pDlI5luL3ZG2RcRLKE9wN7kW+DnwO8ocohtTnpa5G3O/X7ehPGF/52HmSW1wHWW+\n0N9S5g79O+Ui/raUeVm3bljnpsCXI2KXcef+rZ5Mr2cJOB84jjJ/a1Lmht2d5mwWb46Ib40zJUcH\nvIuSnaHXryhPD/2NcoN8F+Y+tQolyOSFwDuqrA1fpdysX+064HjKtBeXUC5C7k75fOteFxFHZeYJ\nI/8lxaXAKZT5oS8CrgI2p8yXe0eas1A8oXr/i4ZpKCK2oqSE7nfj5yzKxcK/UoJmVlKCiVYfXzYD\nvkhJtzuyKmXvN4C793nLJZR9+lzgMsr36g6UrBF1L6Bso38fp0+SNIuq37Kmm1WHN5R9BngvawbX\nPSgitsvMc0ds/1nMTfe+2lWUoNRzKGO2G1Nu0m3Y857XRJlWaGCZmRFxOPCq3q5QbsKPEiSx+nfp\nEbXiyyjTHS207n2Ar7Hm39Xr95Tf+b9RbtzfkLIdNqm9bzPg6xGxV2Z+f+DOz3Uw8Jha2QWUp9RX\njxW3p3/mgUMov7lN/gCcRpmr/jLKWHNTylj3NpRglbZcRNlu51X/vo4yPtqJkrWpKQjixRFxdma+\nd8y2dwDeTTm3WC2BkyhjxAsogRQrGe5G+bj+BJwKXEjZJutSstPdnjJurNuAkq3s95l5/GJ1smO2\noOzTvU/0X0UJcjgHuJzynbwXzUG2H4mIH2Xm2RFxb+ATrHmeeRHwE8qYNSn7xD0pAeW9NgM+FhG7\nZea14/1Ji7cfRJmq7hD6TM1HGa+fSTknvgHl3KU3E939KL87Y+ngcVaSJEmSlp/MdHEZewHeSrmI\n0rtcC2w05X6tbOjXIUOsf//q76jXcQ5l2oXN+qy3GfBqysWV+rpHDtH+qmqdPwL/Tbkos+EC69yD\nkuqz3m4Chw+5/errn0W58Lb6/z+iCmhpWHcjytzZ1zTU871F3g8OaejDyhH3n9Nr//8UsFOf9e9J\nufBer+NiykXxD/SUXUq5EbJVQz2rs31cPM62rP09x1LmpL110+fXs84GlKdof9vQ9nXAXkN+Fof1\n2Td/CNyjzzo3BN5e25d+01DHgUP049A+/fhm9T1r3CaUi7Vf7LPuIxdzv3ZxcXHpwkK5kV8/Hv5g\nnvd/tuH9Lxmx7ZtTbmzX67uCMv3A5g3rbERJx977m3oe5UZUbx1nLtD2Tg3t/naM7bh/Q31HDLDe\ndpRpperrXlr9dt60z3rrUAIUz2hY9yz6jHEb6qmv+xfWHDufCDwQWLth3XWBB9bKHtjnN/Zw4PYL\n9GVt4M7A6ymBpgk8ZoC/YY/qvdcA36ZMYbLjAutsBjyX628k1/e/2w3x2R/YUMefa/W9Bdiuz/q3\nAG41YL17DNmnv1PGuo8Dtl5gnR2BtwFXNrR7OrD+mNtkoL63sTS0vWqMfveOoS+kBDdv0uc7+aI+\n2+9QSnBE735xGiXIekVDXVtRstc0fZeeNsLnMK39YFtK8Hy9nmsoDx30+07sTjnX6fc5rF5WDtiP\nqR5nXVxcXFxcXFxcXFxcXMoy9Q64LI2FksK9fqJ+agf6tbKhX4cMuO62tQtHq5cv0HCRvE8dd6A8\nUVav40kDrv9+SiriORerBlj38cDVtXavBW4+RB1NF8JWL28HYoA6mm52JA0XXye4HxwyykWsPvtP\n78W0fx1wPzqrYf1PUIICknJBfNcB6roPzUEiA21LylN/nwZ2GWEbblitW2/7u0PU0e9GxfsH3Jfu\nR/PNq9XLgQP248kN614JPGOIv+XFDXWcD2y6WPu1i4uLSxcWSlau+vHwmfO8/+EN7/9Fi21
fANxx\ngHVXAmfP85ty5gB1/LhhvcYAvQHq+nZDXXsvsE5Q0tDX1/slsPOA7W5BmU6jXseHB1x/vrHi/zFP\n4GSf+g5pqOcVI27TvQfcF+4BfAi4xQhtbNVn+318iDoOnGcbng/cacS/v6nePQZc998pmUg2H6Hd\n21Oe5K+3/dTF6HsbS0Pbq1r4LH8H3GyAOh7fsO6VrBm0fiQDBAdU+3W9rh8NsR2mvR98omH9ywbZ\nFyhB4f8zz+eRDHY+OPXjrIuLi4uLi4uLi4uLi0tZmlL4SaO4SUPZ6Yvei3a9lvJURq+vAY/NzIsG\nqSAzf055guaq2ksvHXD9f8/MwzLzmkHeX1v3U8xN0b8W7aTI/0hmviQzc4B+HE55sr3uCS30Y5pe\nkZkfWehNmflX4GUNLz2BcpHsauChOcB0HJn5PeCjDS8NNK9rZp6Tmftm5smDvL+27mXAkyjZOXrd\nJyLuOGA1/9lQ9iXguQPuS9+heb7xgVVp0t9Zr5ry1NqHBq0nM9/RUM9WwL+O0z9JmiURcRNKwFmv\nqyjTTPTzNcqN2l63jYi7DNn2fSgZtXpdBzw8M3+20PqZeSblpvdlw7Rbc1hD2dC/UxFxY0q2gF5/\nZuFp3B4NPKBWdhZw/8w8bZC2M/NCSqDJ72svPSUibjBIHX18BXh6Zl435Hr3rP3/bEpGuKFl5jcG\n3BeOzcxnZObQ5y6Z+TfgoZQb1r32jYhtGlYZxrXAPpl54pj1DC0z35+ZBw16zlNb9xTgwZRA3F7P\na6NvM+rvlKwk9e/ZHNU53Ndrxety/ZQux1EykFwxQLsvYe7x9h4RsXKAdae6H1R93LfhpSdn5qoB\n2r+Oct77pUHam0eXj7OSJEmStKwY1KC21G/+Q5lfcyZFxJbAU2vFFwL7DRtgkJknAO+qFd8xIvYY\nuYOD+wDlKZJeDxuzzj8zN1hiIU3zmN51zH5M0/HAO4Z4/+cp+0+Td+Vwc8t+uKFstyHWH1m177+o\n4aWHLrRuROzE3Bs2lzFgQENPHz4HfHnQ9zfYH9i6VnZYZh4xQl3/SZmKptfzIsLfVknLxf6UAL1e\nX6lu4DTKzKspmX/q6uOuhTy9oez/MvMHg1ZQ3ZA6eMh2e32SuYGr+0bEukPW8yTmnpcdkQvPe//i\nhrKnZuafh2k8My9mbtDrupRpOkZxGfCsYX7fe9TPKX46QmDEosrMv1OeZu+1HiVD1Tjen5k/GbOO\nqaiCSf6vVrxrRGw/jf50wKsysx74Mp/6tlstgQMGPR/NzEtoDjJblPOwMfeDpzP3uPjl6lxg0PaT\nMk3MOMFrXT3OSpIkSdKy440XtWWjhrKLF70X7Xk6Jd1+r7dn5gUj1vduytODvcYNLlhQdRH4s7Xi\nW0TEVmNU+/7MvHTIdVZR5hzttesYfZi2dw5zgb26gfPdhpeuoTngY766jmNuwNAuw9Qxjuri+tm1\n4rsNsOrjmXvj6xOZWQ8KGMTbRlhntXpAznXAq0epqPpc31cr3hG47Sj1SdIM2r+hrCl7wSDveUJE\nrDdIoxGxMdc/tbzadcAbBlm/5u2Up6iHVo0Lv1or3hLYZ8iqmrI7HDrfChFxN+DuteKjM/OYIdsG\nIDO/DpxaKx51rPr/MvOPI667ovb/dUasZ7F9Gag/OT/I+KifpJw/zLKmm+njbJNZdRHNmdbm0y9L\ny9czsx6wvpCjG8p2GbKOcYy6HzRl9Rs6a0t1rvHJYdeDzh9nJUmSJGnZMahBbVm/oeySRe9Fex7U\nUDbKk9wAZOa5QD39bT297qT8oqHszmPU9/lhV6gCAOoXcLab0Sfar2W0TAFNFyCPH/Gif72uGy7y\ntqzvU4OkDK9fEITRv1M/BM4cdqWIuClw61rxd0cMrFjtGw1li/XdlqSpqaZ/uHmtuOkm/xxVgNxv\nasVbUNJzD+KuzB17fneU39TMvJwRxjY9moIPBp6CIiJ
2ZW4w3MlV6vb5NI1VDx+03T6+Wfv/HSOi\nKXB5IZ8Yow9/rf1/zzGDcRdFNU1XPbX8UFOq1Pw4M88YY/0uaDoHGWebzKpvDBsQXj3V/6eGl5qm\n9FtI/RwMYDEzZgy9H0TEtpRA4V5nZuYPR+zDqMfGLh9nJUmSJGnZmcUbiuqmpvS4w6be7YSIWIe5\nT4/8JjPPGrPqk2r/37VqayxRbBQRW0XE1vWFufOYAmw7YnMXAb8acd1za/8PYJMR65qmUzLzHyOs\nd15D2bEj9qFeV6vbMiLWjojNm/anap+q//2DzBldnyLjGkb8+6tUsqNc1Lx3Q1nT02vDOIW5x7/l\n+BSipOWnabqIT2dmfTqGfpqyNQw6BUXTtEtfG3DdJuOsexTwt1rZPtVUZoN4SkPZvFkaKpP4TauP\nVddmtJvQx43Rhx/X/r8p8OWIWDlGna2IiHUiYot5xkf1gO5Rx9sAPx1j3UUTEevPcw5Sz9AF422T\nWdXWeH/Uuprq2WyEevqawH7QNJYeeGqhBj+m+Zx4IV0+zkqSJEnSslNP7ymN6nLmpodt9WLJItoV\n2KBWdloL9dYveK8LbMXcm/19VRd0HwXcCbg9sBLYmOEDlDYf8v2r/XHE+ZGhObXzpszeNCV/GHG9\npr9/1NTMrWzLKrvDfSlzPt8BuB0lQKG+/y9kRURs3C/Yo3rCsv6U5a8y88oh2+l1MmUO8mHs3lA2\n1nc7MzMiLmDNwI4bjlOnJHVd9VRpffoHGGzqidUOp0wX0XvD64ERccMB5irfpaGsnhFrGCOvm5lX\nRcSngef0FK8L7Av873zrRsTazE2xfg0LZDqofr/rGZAuHnaO9wb1sSoM/5v2h8y8cIw+HAo8rlZ2\nD+DXEfF54FOU9O+XjdHGvKqg472BPSnj7dtSphVpykw3n83H6MZCmToWVURsDjyCst/dAdiJMvYc\nNkB78zb7NSOmfe7Q77xhaIu4H+zUUHbykG38U2ZeERG/Zogp4jp+nJUkSZKkZcmgBrXlAuZeHNl8\nCv1oww4NZf8SEaPezJ/PlgwQ1BARewBvotyUbXraZVijBpxcNEabTU/HrD1GfdNy0YjrNWUzabOu\ngbdlRKwLvAh4HnCjEftQtxlzMzistkVD2agBHeOs3/Td/kJEG1+pNQz6dK4kzarHUoIqe/0uMwd+\nijgzz4yIH7Dmk7BrU6ZueNsCqzc95VufzmIYvwWS0cdYh7FmUAOUv2PeoAZKUOENamXfzMy/LLDe\nFkA9XflmExyrDuOCcRrLzKMi4puUbdNrXeDx1XJVRBxHeXL7WOB7YwZSABARmwCvBv6Ndn7Lxwnw\nHms7tiUibga8BfgXYL0WqpzVoPdxXDTieq2cO2TmtQ1j3aHOwaawH0zq3GHgoAa6fZyVJEmSpGXJ\n6SfUlqaLDIOkpO+ixbyo0HTB5p+qNLcfA44B7kk7AQ0w/NM0q42StnOpabrA2IW6BhIRt6XMbfsW\n2gtogPn3qc0byuopmoc1yvqL9d2e93stSUvAAQ1lw2RpmG+dprrrNm8oGznzU2ZeS//AvEHW/zFz\ngyruERG3WGDVJzeUDTL1RGfGqg3G/X2Hkr3iR/O8vi5lXPxy4Ejg/Ig4MSJeU918HVpE3Bf4NfAy\n2tu+4zxA0MZ2HEtEPA84lZI5o40b2TD6Ocgsa228Xx2rFtWU9oPNG8oW+9yhy8dZSZIkSVqWDGpQ\nW05vKLvjoveiHYt5AaPvBZ0q9e3nGOzivrSgiLgdJUDmlovcdNMF0EHnXO9nlKkrFuu7vRwv2Eta\nJiJiR+A+DS99vWk+9fkW4DvA1bV6do6IpvnUezWlTm9KsT6McW+YNQVoNAUtAP/MCvCIWvHFwBcH\naKsTY9U+xr7pmpkXAHtQsiYM8rmsRZk+7g3A7yLi81UQ50AiYk/gq3QrBfyi37zuFREvB97L8FNu\naAmZ4n7QhXOHLh9
nJUmSJGlZMqhBbTm5oWzLUZ+WEgAvAB7WUH4t8HXgVcBDKMEj21Iu8K+bmdG7\nUOYD1jJXzdt9BM0ZVM4DPgo8nXIT4ZaUC3kbAWs37FMfH7L5phsSmwxZR91IcwFLksZ2AM2Zo35M\n+T0ZZjmd5ps5ByzQh6abW+su2PP5jfsE8mGUKSx67TfP+x8DbFAr+3+ZecWY/VgSMvPqzDwIuAnw\nbOB7DH5T85HACRHxrIXeWAWXHAFs2PDy2cAHgP0p06TsSBkfbQis1TA++u6A/eu0iNiNktGryYnA\nWyn7727A9pSn6tevb49qm2hGTXk/8NxBkiRJkjTHOCkxpV795lDeDfj9YnakBX9rKHsn/S/qjKMx\nVXJEbAG8tuGlbwDPyMyzh2jDJ6wE8DTgDrWya4BXAO/LzGGeXhp2n2qa53rzIeuoG2VO5qbv9l2B\nM8fryhxTfbpSkiYlysTs+y9CU4+PiBfOc4P/ooayTYDLxmhzrBtemXlWRHwPuG9P8c0i4l6Z+YOG\nVUadegKaf89+Ajx0wPWHMc42HVtmXkwJLPhARGwE7F4t9wbuQXMwApQglf+NiMszc75gzP9kboaG\nS4HnAocNme5/qYy538XcwKXfAU/JzPmmBVlDRCyV7bFcTXM/6MK5w7I5zkqSJEnSrDCoQW05Djgf\n2LpW/gTgU4vfnbGc31C2bWY2lU/KQ4GNa2U/BPYZYS7Vrdrpkmbc4xvKnpuZHxyhrmH3qfOB61gz\nO9DOI7Tb6zYjrNP0Hd5skb/bkjTL9gRuugjtbE6ZmqHfGPKihrLtgL+M0lgVTDpupgcoQQn3rZU9\nGVgjqCEidmh43xn1982j6Xdr+6X+e5aZlwJHV8vqqdruTdlX9qN5Xvr3RMSXMrPpJinAvg1lj83M\nr43QxZkfc1f75u614guA+2bmH4esbua3x3LVgf2g6Vg+7rnDsOsvy+OsJEmSJHWZ00+oFZl5HfCF\nhpf2iYgbLXZ/xtSUBWHgeXlbsndD2WtGCGgAWDlmXzTjImIDykX/XmcBHxqxypXDvLl60vbUWvE2\nEXHjEdsHuNMI63Thuy1Js+ypi9jWAfO8dnpDWT0b0TDGWbfXZ4HLa2WPi4j61BZPYu552GGZWZ++\nop+LmZueffuI2HzA9ZeEaoqK72Tm84EbAwc3vG0z4ClN60fELYCb14q/N0pAQxVgsf2w63VQ0znI\nf49wIxs8B5ll094PTmgoG2XsD/wzSKNpCr75eJyVJEmSpI4xqEFt+r+GsrUp6VtnycnMfQJwl4io\np6adpJvU/n81o8/Te88x+6LZd0Pmzln+7SFunvxTRGwL3GKEPvykoewRI9RDRGzF3CCNQaxqKHvw\nKH2QpOUmIjYBHtXw0j2b5lAfZqFkWagHbu4VEf1uEh/XUHa3Mf68cdb9p8y8BPhirXhz4GG1sqap\nJw4bop1k7rgwaL4RuSxk5mWZ+QrgvQ0v36/PavXxNsC3RuzCrsAGI67bJW1uE89BZte094OfMzdA\n7N4RseWIfXjEsCt4nJUkSZKk7jGoQa3JzJ/SfOP9JRHR1hNwc1TzO7emyjrRdAGjKT3tpGxb+//5\nVb+GUqVTrqc31vJT359gxBTdjBiIQPOF0KeN+P3dj7lBGoM4hjINRq89I+IGI9QlScvNvsCGtbKz\ngGPHrTgz/wJ8p1a8Fn2esKc5UO7RETHquU2bY7xDG8r+GcQQEXdm7hRKP8rMpuwT86lvLyjTvi13\n728oa7pBC90YH3WN20Qw5f0gM69h7vn4ujQHhM2rOtcYNcuQx1lJkiRJ6hCDGtS2VwD1p7/XAT4W\nEW3MVfxPEbF2RHyAkla2bZ9tKHt5lcZ/MVxd+/8mI9bzTObegNDyU9+fYIR9qrpZ9IIR+3Ak8Lda\n2a70v2HVrw9bAq8ZpQPVfNrH1IrXA145Sn2StMwc0FD2yVGy/vTxiYay/ZvemJm/Z
e60RtsBDx+2\n0Yi4C2OkNW/wTeDcWtmDI2Lr6t9NN+WaAiEWciRwTa3s4RHR5t8yi85sKOs3fm9rfLQR8Ixh1+uo\ntrbJ3YF7jN8dTUkX9oOmLJCvGSFbw/6Uc45RHInHWUmSJEnqDIMa1KrMPBb4QMNLdwI+Vc03O7aI\n2A74BuWm/SR8EvhdrWw74N0Taq/ur7X/bxwRdxymgoi4OfCq9rqkGVbfnwDuNUI9LwV2HqUDmXkl\n8LGGl94VEbcfpI4qMOpwYKtR+lB5U0PZsyPCjCaS1EdE3JLmFOKfbLGZzwNX1MpuFRH9bog13fB6\n2whBtO8a8v3zysxrmRugsQ7w+IhYwdynfK8EPj1CO2cCR9SKA/i/iFjOAa03aiirB5ms1tb46O2M\nNzbpkrG3SfUdbDof1Ozown7wReDPtbKtgMMGvaZQZYt856gd8DgrSZIkSd1iUIMm4aWUeTDrHgl8\nOyJuPGrFUTwVOAW4/6j1LKS6IH1Qw0vPiIjXjlN3RKyIiP0ioimt52pNc0W/Yog2tqZkm9h42P5p\n6cnMc5h7QX+XiHjQoHVExEOAN4zZlTc39GML4OiImHd+2oi4EeXi5oOroqGnYwHIzFXA92rF6wCf\nq1KCjywibhgRQ6fFlaQZcEBD2amZ2TTeG0lmXgJ8teGlfmnDDwUuqZXdEvjQoFMbRcRbGO0m9kL6\nTUHxQOamdf9yZl40YjtvZu5TxLsCn66yB4wsIu5bPWW9aCJix4h4zpiZ0ZoyJpzS570nM/eJ9IdF\nxO0GbSwi/g141qDvnwFN5yD/MehnEhFrAx8ChgrGVudMfT+opqB4ccNLDwG+GBE3XKAPDwKOppxr\nwIjnDiyx46wkSZIkzTKDGtS6zLyUkv73Tw0v3xv4ZUS8dpjUkRGxQUTsRwmW+Ciw9QKrtOEQ4AsN\n5a+PiC9GxM2Gqay6UPtySgaIw4BN53n7UQ1l+0bEGxeaL7pKo/x9YJeqqK200JptTTeKjljoQlo1\nzcvzKOlXVz/9OtI+VU3/8O8NL90A+HpEfD0iDoiIXSLiRhFxq4h4UET8D3Aa0BuE8b5R+lA5ALig\nVrYV8MOIeNEwN1Oq7bNHRHyUkvJ6UtljJGkqqnFH01RBTdNFjKsp88Pjmo7LmXkBzdMH7Q8cGhF9\npyeLiPUj4l3Af/YUj3rDa47M/BlzA3x3A17X8PZRpp5Y3c5vgJc0vPRQ4LiIuPcw9UXEFtXv8E+A\nVcDAN/dbshnwfuCsasx722FWjohnAy9reOlTTe/PzL9Txsy91ga+FBE7LdDW+hHxJuDDvVUO0d2u\n+i7wj1rZjpTgz/nOXahuMn+B66eNWQrbY7nqxH6QmZ8EvtTw0oOB0yLivyNi7+qcYfuI2DUinhoR\n3wC+xvVBZH+gnMuM0oeldpyVJEmSpJm1Ytod0NKUmWdFxH2AbwM3rb28CfB64FURcTRljvufA2cD\nf6dcVN4E2AG4LXAfYG9GmMdzHJmZEfEU4CfAbWovPxzYJyK+RJk7+UeUJ9AvpNz43ZwSeHF7ypMc\nezDEfM2Z+f2I+BGwe+2lV1PmZX4/5WnzP1KCk24A3A14DPAoSlrM1d4L/MegbWvJ+i/KxcW1e8q2\nBL4fEYdTbiSdRNmHN6V8/x5ICQDo3f/PBk4EHjFKJzLz8xHxOsoxoG7valnI+ykXJuv7df0pqn59\nOCMiHk+52Nm7PdYD3gH8Z0QcQbmg+3Pgb8CllGPQZsBNKN/rO1ECLebLuiJJs+4BQFOWrcYbxWP6\nCiX7Qu9Ns80o2b6agij+F3gic8dL+wEPiIiPU4L6/gBcS/k7HkD5bVvZ8/4vUI7p9THrOA6j/Pb2\n2q32//Mov0Ujy8z3VJmG6pmCdga+V904+wLl5v1ZlKC+6yhj1S2AW1N+0+5BGa+2MlXcmLahjHlf\nHRG/BL4FnAD8jJIW/0LKb/6mwM0pU6M8hfJ31
H0lM384T1tvBe5XK9sROCki/g/4HPALyn65RfXa\ng4GnUcYDqx1PmUqkaZqWmZGZl0fE+5ibIe7BlMD0/wa+TgnSvpLyWd2Gcm50AGtmifMcZEZ1bD/Y\nn3IMqGdU24wSLN0UMN3rKsq0P//W8Nqg5w5L8TgrSZIkSTPHoAZNTGb+LiLuSrkI/YCGt6wL7FMt\nozoWuHyM9eeVmf+IiAdSLtrUn6JYm3KR/ZETav65wA+A+nydd6ZkqxjEf1Eu5ntBcZnLzF9FxNuY\ne3FyBeXi4wEDVHMpJXBmoYuHC/XlDRFxHWU6i4FShPf4KGV/rt+AgLlpyOfrw9ER8UTKTaf6/Ovb\nAC+oFkla7pqmf/hxZv6+7YYy84qI6H3Ct7cPc4IaMvO6iHgs8GNKMF6v7YCXV8t8fk+ZsuD4kTrd\n3xHAwawZPFf3ySrF+rj+jfJbtm/Da3erlll1G+YGFw/qDJqno/inzPxmRHySctOz14bA86tlIX+h\nbPtBx+dddzAlSPpWtfLtgbdUy0K+T/nueQ4yuzqxH2TmRRGxFyUAbNhj2WXAvpn5gyqTS93A5w4s\n7eOsJEmSJM0Ep5/QRGXmeZSnvZ9OeRqtLb8FHpuZu2fmlS3WO0dm/pHyBODHW6z2GspTI/O1exLl\n6cMrRmzjHSx8IV/Ly6soN1lGcQGwT2Y2zbE7tMx8E+V71W+e67q/Ak/LzH/NzGu5fn7cXhcP2YfP\nUJ6o/PUw6y3g0hbrkqSpiojNac7M0zRNRFua6r5fRNSDFgDIzD9Rpjc7bYS2fgc8MDPPH2HdeWXm\nnylPF89n5Kknam2tfhL5JYw+bmxyWYt1LbYTgXtWn8NCngZ8Z8R2zgYeMIkgn2nJzEsoqfXPGbGK\n7wAPm/Q5miarS/tBNYXdPSnHuEHH2t8D7pKZX6n+Xz93uI6SJXLQPniclSRJkqQpM6hBE5fF/1HS\n/D6XuXMMD+pK4LOUiys7Z+Zn2+nhwjLz75l5AOWi+dGMNjdoUlLnvgjYPjPPHqDdL1Ju/A6zzc4C\nHpmZL8lM57LVP1Xfxf2AFzLck0lfBu6Umd9tuT8/BnahpGH9H8r348+UoJ9/UIINPk1JKX3TzPxY\nz+rbNFR50Qh9OJ4yTcwLKN+dUVxKSU/9cMbLPCNJXfN4YP1a2bWUY/OkfIsSyNZrLcpvQaPMPAu4\nK/Ae4OoB2riW8lT9XTLzdyP2cxDzBS38MjNPaKuh6jf+HZQ05x9l9JtuZwJvA26TmU1TfkzSbymZ\nEb7FYJ9jk3OrOnYbMKCBzLyCMv3VWyjnG4O4DjiEMj76xQj97LTMPJ0yJcuRQ6z2D+CVwN6ZOVSg\nqbqpS/tBZl5bHeO2p2SY+yIlmO1iyrnDBcBPgXdTAprum5m9wW71c4dLhj1XXiLHWUmSJEmaWeE9\nT01DROxESR9/N2AnyhzGm1MunF9BuTl5IeUm43GUCxQ/6soFsojYkfLk4r0pN0Rvwpop7K+ipKL9\nFfBLSlrkb1eZK0ZpLyg3Sx8D3IeSYnn19DFXU540/AnwJeBLvamMI2IryvydvX6Tmb8ZpS9aGqqn\nb59GyaRyV8rTS6ungriYst8eQ0mN/Yvaundkbprvb2fmxKaCqYuIjzF3yoydM/NXY9S5FuX79RDK\nselWwLZcv12SEsBwFuUi6inAKuDYzBz15oskqUURsT0lAGIvyrQFW1JuQF9AOXYfA3xiMZ6sj4gN\ngJfRPNXSsZn5jQm2vQXwMMp4+07Ajqw5z/21lG3yG8p2OYHyW/7bSfVpGBGxMeW3eHdKAOQtKH/D\nJj1vu45yznAqcBJlyrVvVRmdRm13O0qa+QcAuwKb9rx8ftXWt4AjMvOM2rr3ALbqLet5SnxmRcQu\nlClh9qSMjVYHOiXlKf6TKVMDfLJ+rhYRD61V97fMPHaS/dVkzPJ+EBHrUgK61+sp/klm3n3Memf6\nOCtJkiRJs
8agBqkFVdDBhpTAhkur9JSTbnNTysXcS83IoHFFxArKRbjLu54uuPq+nQ7crKf478Bm\nbX8Xqu2yYfXff2TmdW3WL0nSYomI9YENgCszcyZTnkfE2sBGlBup/5j0GDgi1qnau9QAxiIiNqQE\nV186TgCJZtss7QcRcV9KIHKv92fmcyfQ1swfZyVJkiSpq1Ys/BZJC6kuqF7K4HN8ttHmMNMHSPOq\nsntcNO1+DGhv1gxoADhhEjc2qu3id02SNPOqaRbanAt+0VU3Txftd7kKZLhosdqbBd6oFczcfvDs\nhrLjJtHQUjjOSpIkSVJXrTXtDkiSNKgqfexBDS99ZrH7IkmSJKm7IuLuwKNrxVcAX55CdyRJkiRJ\nYzCoQZK0qKr5Z0dZb23gfyhz1vb6B3D4uP2SJEmS1B0RsX41pcMo694U+BRzM5R+JjMvGLtzkiRJ\nkqRFZVCDJGmxfTkiPhkR94iIGGSFiLgl8DXgXxte/p/M/HurPZQkSZI0bSuBsyLiVRFxo0FWiOLR\nwPHATWsvXw28o90uSpIkSZIWQ0xgCnJJkvqKiB8Dd6v+ezbwFcpFx18AFwCXApsAWwN3BR4A7ENz\nIN4vgLtk5pUT7rYkSZKkRRQRtwZOq/57HfBD4DvAicAZwEXAtcCWwA7AvYBHAjv3qfJVmfnmCXZZ\nkiRJkjQhBjVIkhZVLahhHOcCD8jMU1uoS5IkSVKH1IIaxvUF4LGZeW1L9UmSJEmSFpHTT0iSZtEp\nwG4GNEiSJElawDuBxxjQIEmSJEmzy6AGSdJi+wTw2xHXPQt4DmXKiXPa65IkSZKkjvkL5dzh7yOu\n/w1g98x8cWZe1163JEmSJEmLzeknJElTERE7A/cGdgNuAdwE2ALYEAjgQuAC4I/AD4BVwA8z8+pp\n9FeSJEnS4ouI9YDdgXsCdwRWAjcGNgY2AK6knDdcAPwS+C7w7cwcNZBakiRJktQxBjVIkiRJkiRJ\nkiRJkqROcvoJSZIkSZIkSZIkSZLUSQY1SJIkSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIk\nSZIkSZIkSZIkqZMMapAkSZIkSZIkSZIkSZ1kUIMkSZIkSZIkSZIkSeokgxokSZIkSZIkqUMiYo+I\nyNqyctr9mraIWNmwXQ6Zdr+kcTXs16um3adBzXLflwK3/9Lg7760sBXT7oAkSZIkSZIkSZIkSVpT\nRNwQuBNwM2BT4Grgb8CvgOMz88rl0A+DGiS1JiIOBF5XK94zM1cNsG7Wir6bmXu007Puq54q2L9W\nvGNmnrn4vZEkSdKoHBOPzjGxJC2uiNgO+AOwdsPLu2bmyYvbI0mSpOmLiK2BuwB37Vm2a3jrxM5X\nI2Jt4MnAc6q+RJ+3XhoRXwTekZknLtV+gEENkiQNJCJuCdwW2AbYGriGEoX4a+CEzLxiit2TJEmS\nJq66oHUn4DaUcfF6wD+As4CTMvOsKXZP0vCeTHNAA8ABwAsWrSeSJElTEBHrAXdnzQCGHafcpzsA\nhwJ3HODtGwFPBJ4QER8EXtjWvYqu9GM1gxo0c6p5hM4YYpWkXGS5GLgA+AVwHLDKiHNJ84mIHYCX\nAg8DVs7z1ssi4qvAezLzB4vRN0nS8uaYWNJiqgJ8Xww8Dthinvf9Evgg8OHMvHyRuidpdAfM89oT\nI+KlmXn1YnVGkiRpCu4BHDPtTqwWEXsDn6MECQy1KvAs4E4RsXdmXrQU+tFrrbYqkjosgE2AGwN3\noEQKvQs4KSKOi4j9ptk5Sd0TEetFxDuA3wLPY/6ABoANgccA34+I/xcRW064i5IkDcsxsaShRcRa\nEfF6SiDUM5knoKFyG+A9wKkRcd9J90/S6CJiN8p3tp9tgH0WqTuSJEnLXjU++wLNgQSXAt8DDqEE\nG5zSp5rdgKMiYt1Z70edQQ1a7u4CHBYRX4+IG027M5oNEfH4iDi/tjx+2v1
SOyJiW+A7wIso6XSH\n9RjghIi4RasdkyRpchwTa2iOiZe+Kg3rl4DXAsNeiNoR+FZEPKX1jklqy1Nbeo+kFmRm1JY9pt0n\nSdI//WPSDUTEpsCngQ1qL10LvAG4SWbeNzOfmpmPycw7UKbK+E5DdbsDb5nlfjQxqEEq9ga+V93M\nlBayPrBVbVl/qj1SK6of7O9RfmybnEu5sPt/wKeAnwLXNbxvJXCMN4YkSTPGMbGG4Zh46Tuc/k9p\n/wb4IvAxSqrWixreswL4aEQ8fCK9kzSyiFgf2LdWfBVwYa3sIY4LJEnSMnMF8BPgvylTdd0O2GwR\n2j2QuRmjrwIem5mvy8wL6itk5vHAg4BPNNT3gojYZYb7MceKNiqROuJOwDkN5UE54NwIuCflILRT\nw/tuDnw5InbPzGsn1Uk1y8yYdh+mKTMPYP65LLU4DgFu1VD+C0rmhm9lZva+EBHbAa+hzBPVGyx4\nY+BTEbFHZjYFPkiSNAmOiWeYY2LHxF0REc+nZCCrOwl4SWZ+p/b+jYCnA29mzSd61gY+HhF3zMyz\nJ9VfSUN7BHOnk/kKJZD/OT1lK4AnUaaskiRJWoouAD4EHA8cB/wiM6+pvylicqfrEbE98NyGl16f\nmV+Yb93MvDoingrswppTi60FvAl46Kz1ox8zNWgpuTAzz29YzsvM0zPze5n5FmBn4PnA1Q117AaY\nHlNahqp0yY9seOko4C6ZeXQ9oAEgM8/NzH8HHsfc48q9gWe03llJkvpzTCxpLFW2sYMaXvoOcJ96\nQANAZl6ame8G9mJu1obN8Yao1DUHNJQdVi11TkEhSdIyFRF3iohDI2KdCdX/mmlPWZeZP8/MZ2bm\nhzPz5KaAhkXwfKC+jU8F/muQlTPzKkqQed1DIuLWM9iPRgY1aNnJzOsy833Ak/u85bUxyZArSZ1T\nfedf1fDSacBjMvPKherIzM8BL2146fURUZ9/SpKkqXJMLGkerwY2rpX9hZJudN65ZDPzhzQ/2fOo\niLhLS/2TNIbqCby9asUXAF/NzB8Dv629dvuIuNOidE6SJHVGRNwW+CblusGnIqLV7P8R8UrgDZQp\n6+rTYi0bEbEWsF/DS2/PzKYHURpl5o+A79erZ8CHVrrSj/k4/YSWrcz8dEQ8ivJ0da+VwO2Bny96\npyRNywMoc2PVPSszrxiinvcC+wO79pRtS3kK5n9H7p0kSRPimFhSr4jYkuYnuF/YNHdqk8w8IiL2\nZ+5N0xcDTxivh5Ja8BTmPuj2qerJOoDDgdfXXj8AOHESnYmIzYDdgVsCm1LmsT4XOAX4eVPGxKUq\nInYE7gDcBNiEMn/1Xynb46eZedEi9mUT4C6Uz2ULylObF1f9+MkC664F3ILyt2xNmQJtBXA58A/g\nj8BZwG8HeYhkwP6uR5la7dbANpR9CeBvlKCdXwG/XMz9aZxt2FL7ndsmgoi4ASU1+00p2azWpXwv\nLgR+TflMLplW/+qmcVzq6r4bEZsCd6v6thlwGXBe1Z8Tl9p3KSJuAXwL2KoqehRwREQ8sY2pKiPi\nZVyfHW5t4PCIuGqhKQ6WqN0pU4X2uhz47Ah1HUbJHt3rMcArZ6gffRnUoOXuXcy9gAvlBqcXcKXl\n44ENZT/PzO8NU0lmZkT8N/CR2ktPxaAGSVJ3OSaWtNrjgHqWsTOATw9Zz8HMDWp4RERs2qUL9dIy\ndUBDWe+0E4cDB1KeqFvtiRHxkp7Ah7FFxG6UC9v70P8a9Z8i4sPAu0e9cRYR9wG+Wys+NjN3H6W+\nWt3/BbykVvyazHzTEHXcDHgO5fi7wzxvvSYifgR8HPj4KDeUImIP4Jha8esz88Ce9zwIeAFwP+am\nn6Zqv/GGfETcnZJy+tGUG34LuSoiTqr69OnMPHmAdVa3tRblZsmDgD0pwQNrL7Da+RHxbeAdmXnc\noG3V2t2DCW3DiKjfEP1uZu4xRN+msk2
0sIjYAXgm5YberRZ4+3URcSLwdeDQzKxnz5m4xTwuVe1N\nbd8d8Dt9F67/vVq3T1V/jYhDgYMWMwBtwh4KbFcrexxwbUQ8eZzAhoh4EfDWWvEK4GnAcgxqeEBD\n2fcWypLXx1ENZbeMiJtm5lkz0o++DGrQcncccAnXR/utttO4FfdEMd6UEsV4LXA+cGRmnj9EPVsB\ndwJuRoneXIsSAfgX4GeZefa4fZ2n7VtSBhE3pvxg/x34PXBCZv55Uu22rfo7bkMZhG1COTG/hLId\nfwn8us0T8y6pnni4K2X/WR0Vvjoq/lTgtMWKIq1SWO9C+V5sSxmcnlf15Qcj/ji2Zc+Gsi+NWNeX\nG8ruGhE7ZuYZI9YpSdIkOSaev23HxDPOMfFQHtNQdkRmXjdkPccAf6B8b1Zbn3Jx9BMj9k3SmCJi\nd+b+vp9eTTsBQGb+PiJ+CNyr5z1bAQ8DPtdCH9ahzMv8fNYMnGhyI+B1wNMj4kmZuWqEJr9PCc7a\nsafsHhFxi8w8fYT6gH/ehHtirTgpQSGDrL8p5QnVZ7PwzTso1/HvUy0vjIjnZGY9tfPIImILyo3J\nh4247ntpTlk9n3UpTz3fDfjPiHhkZh45QHsvAl4EbD9ke1sD+wL7RsRXgQMy87wh65ivXyNvwxba\n7uQ2We6q7/kbKQECg96LW4ty7nEX4NUR8a7MfNGEuriGaRyXurzvVlMtvJUSpFTPcFS3LSXI7anV\nsay14/O0ZOa7q/PwV9deegIlsGH/Ec4RiIj/AN7R8NK3aX7YYjm4Z0PZUA9brpaZf4qI3wE3b2hj\noWCCrvSjL4MatKxl5rURcTZz085v3fT+iDiEklq+146ZeWb1+saUH/1/o/9F4NOBVfP1q7ro9nTK\nydGuC7z3VEr6l3e28cRL9WP9r8B/ADv3eVtGxLGUuXRaiZwbNxq5ob67As8A/oWSpmo+V0TE9yjR\nY0dk5t8a6juTcjG+n49FxMcG6No/95da/Ycwz741jOoz3I/yOe7O/IOuv0bEZ4D3jHoy37BtzsrM\nlT2vb0EZnP4bc6M7V7sqIr5DeaLh+FH6MaZbN5T9dJSKMvO8apusrL20N/CBUeqUJGmSHBM31ueY\n2DHxsO2dyYyPiSNifZovZDU9ZTOvKoPZ1yjf4V57YVCDNE1PbSg7rE/ZvWplBzBmUEN1nDmScn48\njBsBX4+IRwKnDbNidTw6HHhN7aUnUwImRnV/5qZo/v4gv1kRcWvKgxS3HLHt2wHfioinZOawmXSa\n+rMt5abFQk+QN627DSU9+R3G7QeD36t4OMPfAK17CHBcRDw0M38xZl1jbcOWdG6bLHcRcXvKE+f1\nG3rDulkL3VnQFI9Lndx3I2Jdyuf3kCFX3Qr4RkQ8ODPrWYJmTma+pgpGfHntpf0omTr+dZjAhoh4\nDvDuhpe+Czw8My8fubOzrel6x0lj1Hcic489u7LweVhX+tGXQQ1SeTqpbpAUaWuIiD0p0bjzpWRa\nqI4A/h140xB9uG21PCciXpWZHx6j/VsCn6E8OTTvWykXBj9fRUI+qStplaLM9fRe4MFDrLY+ZfqB\nBwJvi4j9MnOUeYKmLiLuC3yYwQeg2wLPBZ4dEe8FXtXm4CEi9gYOof+F29XWpaQY2zsi3p6ZL2ur\nDwuJMk/bRg0v/WmMav/E3KCGe2NQgySpuxwTX9++Y2LHxMtuTFy5M2U/6HU5MGqAxfeYG9TQFDQh\naRFExAY0PwHZlFngM5TfkfV6yh4UEdtl5rljdOMQ+gc0XEB5uOBPwIaUc+rduD4obT3KVDiPGqHd\nQ5kb1LAf4wU1PLlPO/OKiNtQAjv7BRv+ETiZktnqWuAGwN25fk7z1dYFPhkRjBnYsDYl0KR+M/73\nlGxGf6VcM9mekvWo7hD6BzT8gRKE8mfK3PMbUDKD3ZSSPar+mzOO6yhBs78FLqZk1tqI8ht/J5qD\ndW8
KfDkidsnMi8doe9xtOCnT3CbLWkTsSnnqfIs+b7kKOIFyvPsb5XPZmvJduuFi9LFXB49LXdh3\nP8rcgIbTKd/p8yi/U7egZNSoB09vAHw8Im7XgSxsY8vM/6wCG+oZQw6gZGx4+iBZ7yLimcB/N7z0\nA2CfzLxs7M7OoCoorv5dAvj1GNX+pqGs6aHOzvVjIQY1SHPT7EL5sRxYRDwa+CTN86QNWsemVR3D\nRv+ttg3woWoQ8qJh06dGxJ2Bb9B84JrPQ4DvRZmjcKoi4nHAR4CNx6hmPdZMkTozIuK5lEjHQdKD\n1a0NvBC4dxXV+pcW+rM/5fMYpj8BvDQiNsrMfx+3DwPask/5OAPgixrK7jRGfZIkTZpjYhwT93BM\nvPzGxNA8Xj0lM68Zsb4TG8puEREbL4ULvNIMejRzf+9/mJm/r78xMy+KiK9U66y2OgPO20dpPCKe\nSEkVXvcnyrH3C5l5dW2dGwEvpUxVsRZl+qSmGyLzyszTq+xK9+gpvllE3DMzfzhsfRGxEXODKy4H\n/t8C621IySxVv3F4DfAx4N2Z+cuG9daiBL0dDNy+9yXggxHxk1GyGlXqGYQ+BxzY9LRzRGxJz437\niHggzWO2I4C3ZuYp/RqNiLUpAaQPBx5L/8xY8/kT5UnqzwM/nu+GWETcg3JDrj7N0krg/Qw/dUav\nkbfhBHRlmyxbVXauz9Mc0PAbSjDVV/qNhSLiJsAjKCn+7z6hbva215XjUpf23SdwfcbBpDxV/sbM\nnHNzNyJ2AN7Z0JebAi8DXjtmXzohM19cBTY8r/bSv1IyNjx7vnPviPhX4H+ZO+3UscBDMvPSVjs8\nW3ZsKLsOGGeKzaYpsJva6WI/5mVQg5a1agDdlD514Pl9KRd+jmDNi7fXUuYmPptyMfgGlDQrt+3T\nj40pF0/7DVQuqeo7lxLZvDpysyl91QsoUZEDX/yKiJsCX6X/xdtfAr+iRM1vW7W9suf121N+3EdK\n19+GiHga8H/0n4/xYkoE7F8p23NzygnHrpST4pkWEc8A3tfn5esoT1edSYlsvRHlKaxtG957F0qq\nsN0z8+9j9OdBzL14ewnwE8p+fFXVj3vSfBPlORHxtcz8yqh9GEK/uaPX61M+iKanDW4ZESvGuCgs\nSdJEOCb+Z/uOiWecY+KxNd1QGufJnNMp27336bWg3Mw5YYx6JY3mgIaypqknVjuUNYMaVtcxdFBD\ndSP33Q0vnQDslZkXNq2XmX+izNP+TcqT8Osy+g3hQ1kzqAHgKcDQQQ2UgIZ6xscvDjAF1vuYe6z9\nE/DYzPxRv5Wq1N5fjTJF0aGUIIDVNqNkJ9prkI43WH0z/jrg6Zn50Xn6cQHlBtRqT2x42ysz8y0L\nNZqZ11I+/xOA11VZjQbNAnIq5ff1k4NeY8nMY4HHRsTjKft9732RJ0TE6zLzdwO2XzfONmxL17bJ\ncnYIc7O3Qjl2vmKhzyczz6ZkynlvRNyLyWf3mPZxqYv77uqAhiuBJ2dm34C1zDyn6suHKQFOvZ4W\nEa+vjnczLzOfXwU2PKv20jMp5/6N594RcQBlf6ifo/4UeNA451tLRH0qK4C/jXkPoSk4f6FpXrrS\nj/llpovLTC2UQUE2LCtHqOsefep6cZ/3H9Lw3j/3/PtiSgT5Fn3W3wXYvqH80D79+CZwP2CtPvXd\nHvhin3UfOcR2+EafOo4CbtdnnXtSLir3vv9XDXXsMWAf6uutGqL/u1EGGf224X2BFX3WXQu4IyW9\n8e+rdV7Q8L4tKBfOt6akpq2389ye1+db+n2WTfvWQPs0JWXf5Q3rX0uJFG3a59YG9qGk8Wrabh8d\nYvufWVv3AkoartX/P5USYbxOw7qrIzwvbejDGf22V5tLtQ9c19D+vcao84Q+2/Vmk/57XFxcXFyW\nx4Jj4t71HBOnY2IcE7dxXPlaQ/tvGLPOcxrqfMxi/D0uLrO+AHuMe
kxsqOsmzD3vvZI+v9XVOuvU\njmOrl7uM0P5rG+o5B9h2iDqe0udYncAhA6y/BXBFbb0LgfVG+Hu+2dCHBy+wzl0a1rkQuM2Qba9D\nuSler2vXEfep1ctLRtgO9d/PsxbrN2ucBXhOw9//zgHXbXUb1uoeeRw4zW0y7b53aaGcszTtG6+b\ncLsjbf8uHJda+NvH3Xfn+07vN0Q9G1Gm3KnXsceYfVk5rf25Tx+DEkTftL3e0/D+/SjnY/X3Hg9s\nPu2/Z4S/v+nvHuszogTD1Ov81Zh17t5Q53XA2l3vx0JLfa4Xabl5YZ/ybw1Rx+po3N8Dd8zM/8r+\nEeYnZ+Yfe8si4snMnYfvKuCZmfnAzPxOlsjHpvpOycx/AV7S8PKHq/S984qI/Sjz5ta9NjP3yYZU\naVXbP6RcAD+kp3iS6dMaRcTmlLke1629dCXwhGobfjf7RJRl5nWZ+bPMfDVlzt39aEiLk5kXZub5\nmXk+0JQe7B+rX19gafwsRxURKyjzX9YzA1wBPDAzX1Tf56q/59rMPIpy8frrDVU/NSL+ZcRurb7Y\nDSV99K6ZeWTW0khW/bg6M99HuZhc/4xWMvqTBgOrPpMLGl663Sj1VU+79psbarxIREmSJsMxsWNi\nx8TLfExcaXo6Z9CnZvtp/+kcSaPYn7lPSH6l3281lGMT0DQn+lOHabhKUf60hpdelpl/HbSezDwU\n+M4wbdfWv5ASqNhrc+Chw9RTTYlx/1rxuZRAh/k0jVNelA1p3edTfS5Po9wY6PUfw9RTcxLwrhHW\n2672/5+2/Rs/IR+gZODq9bAx6xx1G3bFJLbJcvTyhrJjgDcsdkcG1OXj0qAmte9+OTMPH/TNWaZP\naMrSMulMG4sqy53qZ1AeSKh7fkS8Y/V/IuIJlPPk+n3okynnaBdNppczZ7OGsnGzVzStHzRnB+xa\nP+ZlUIOWrSo90WMbXjoL+PmQ1f0DeEAOOX9dRGxAeWqoVwJPy8wPDVpPZr6joZ6tKHMazdd+AAc2\nvPThzHzjAO1eQ4ngOnqwnk7Ei5ibLvkaYJ/M/NQwFVUXNY/IzC+21rvJewQlXXDdkzLz2wutnGV+\nskfTPN/t68frGt+q+tFveofefqyizHlW94Qx+zCo4xvK7jViXbsCG/Z5bdj5uSVJmijHxI6J6xwT\nz7GcxsRNY9Wm4N9hNK3vmFhafPs3lM039cR873lCRAwzXeOezP2N+nVmfnKIOlYbd27yppswTxmy\njicy95r6J3Ke9OIRsT1zp/L4FfDxIdsGIDNPo2TX6fXQakwzivfO1/951Ke2XqfxXR1TBV58tlZ8\ni4gY5/dp1G3YCRPaJstKRNyK5iDp51Y3gjtlBo5Lg7Y7qX33rSOsUw+cg+ZzlJlWbfOnUqZerHtR\nRLw1Ih5HGUOsXXv9FMq0U+OeYywl9QcToP902YO6sk/5fOO3rvRjXgY1aNmJiLUi4nk0n8gAvHGE\ngcZrM3POk0wD2J/rn95Z7bDMPGKEuv6Tkrqv1/OqiPh+7sfcOYjPo6QLHkg1YH8O4x/ghhYRG9E8\nV9NBg1y8XCKa/v7PZubnB62guoj7TErqn153jIh7j9ivK4D9h/wuNc1/vFjRrN9tKHtkRDRFKC6k\n6QmU1bYYoT5JklrnmHgNjolnn2PidizW0zmjjLEljSgi7sPc37kLgK8utG5m/gT4Ta14C+DhQ3Th\nPg1lTTdCBvEjynQ/o/oqcH6t7MERUR+HzKeeWQr6j6dWewBzAwA+OWZWg2/U/r8Vo2WLugboO1/8\nAuqZNvacoZvgTVm47jxiXeNswy5pc5ssRw9oKFs1bNaDRdTl49Kw2t53/1pl5BulH/Vxf1MmtJlX\n7SdPofnY9zLgU8wNaPglcP8q65+u1xQQ2JhhcQhzsgPO01bX+jEvgxq0lGwREVs3LNtExM0j4j4R\n8QrKwfO9NH9xTmDN1LGD+Afw4
RH7XE/BdB3w6lEqqtI81S+A7Qjcdp7VDmgoe2dmXjxk26ez8Anc\nJOwHbFkr+yNw0BT6sugiYkfKXFt1rxq2rsw8Hmi66DvfDfr5HJ6ZfxqyD7+jRAT3ulX19Oakfamh\nbEPKIGxg1Wcy31MeI0chSpI0IMfEjonBMTE4Jh7FYj2d45hYWlxN00V8epAMMpWmbA3DTEGxW0NZ\n/WnegVRBYk3TBQ26/tWUGy291gH2HWT9iLgDcIda8c8z82cLrNoUHDduhqeTGsruNkI9p1Zp00fx\n49r/NwW+HBErR6yvNVFsFBFbNY2Nab5Rs+2IzY2zDRfNIm+T5agpgKvLwS5dPi6tYQr77rGjrFQd\nB+oBvUs2mLcK7H8i8IWGl+sZOn4N3C8zz5t4x2ZPUyBRPeBoWE3ndf3a6lo/5jVuh6QuaUoVOowz\nKOlZh00V9uXMbJpPdl4RcVPg1rXi72Zm/cmyYXwDeFut7J6UtD5N6oOtpMz3OoojKGl3F1NTBOyH\nm+apXaKaBp8/ycz6UxSD+jjwmAHaGMTAT8XV/Jw1vxdrAzdgvCcxFpSZp0bE14AH1156eUR8OzMX\nnLczIjakXOzZaJ63zUQqRknSTHNM7JgYHBM7Jh7NYj2d45hYWiRVNp/6MQ0Gm3pitcMp88H33qB4\nYETcMDP/PMD6u9T+fy3NT9UOaqEAgoUcCjy3VvYUmqf/qWt6iGGQgMbdG8pOG2C9+fytoeyGI9TT\nb3w0iEOBx9XK7gH8OiI+TwkgObrKhjQxVRDFo4A7AbcHVgIbM/wDnZuP2IVxtuFEdGCbTFQ1Bc4m\nI6x62QT3x6YArp9MqK02dPK41JF99w9jrPt3SoDXapv2e+NSkJnXRMS+wOeAh/V5228pAQ1/Wbye\nzZSm86VxA9rX71M+X0BrV/oxLzM1SMW3gPuMeGD96YhtTiIa8hTKyWGvxmjIiNgWuEmt+LTMPGvE\ntr8HLHZUctM27HIEbNvu3lD2lTHq+wZzn6S6+ZBpGFcbKaIVOLehbLEiWt/I3CjBtYGvVPON9xUR\nO1CeFrnnAm3M7PyKkqRlwTFx4Zh4tjgmbs9iPZ0zTlpjScN5LOVmUK/fZebAx6fMPBP4Qa14bZqn\nYWhSP37+ITMvH7T9Br8eY10y8zjmZsTZLSJ2mm+9iFib8kRqr2spAY0L2aGh7MKIyFEX4NSGOuuZ\nmwYx8rzmmXkU8M2Gl9YFHg8cSfk7fxARB0fEv0REa9NyRsQeEfED4PfAO4AnUTJpbMpo9z1G/a3t\nzNzwHdomk/YEynRxwy5DZWQd0g0ayjoX8NKjU8elju27F42xbj0guD4Fw5JTBdN/dJ63fGnY7HXL\nTNP1g37BAIPqF4wwX1BXV/oxL4MatNydREmZ98DMHDUCb9TBSevRkFUavvpAul805K4NZSNHu1fz\nKI0TaT+U6qm++mDxEuaemC5lTZ/hyE9nVgOQpsFnUzvzuTQzLxqxG01z7i5KRGt1Uef1DS9tAHwy\nIr4fEc+MiF0i4kYRsVNEPCgi3k/Z73pvKFwJNN0QuqL9nkuSNDbHxGtyTDxbHBO3Z7Gezhl3SgtJ\ngzugoWyYLA3zrdNU9xoiYhPmBkcNNb1Tg3HXh+a/Z6EgjfszdzxxdGY2BaL9U0Ssw9zAkkkZJWDg\nkjHbfALwo3leX5fyAMjLKUEO50fEiRHxmoi42SgNRsQ6EfEx4Jiq7nqa81GNmklo3G04tg5uk2Ul\nItZn7pjp0iGm+VlUXToudXTfHTdT2bISEQ9k7tROvV4cEUNPDbiMXNhQNl8m6EE0rX/VAlMldaUf\n83L6CS0Xl1Ii7C6kXKA6DjgmM8dNzwujR+M2RUN+IaKt3+1/6hcN2TSv1KgpWlf7NS3MkzWgpujX\nX1QXkpeLpqfFxr2AfRolvddC7cznojHabxo0LmZE65uA29A8n+a9qmUhCTwLeE3Da02DA0mSFot
j\n4rkcE88+x8TtuZS5QQiTeDpnomnIJRURsSPNc7x/fYTsM9+hBD713iTaOSLulpnzpVdvCshqCtwa\nRhs3kA+jnP/3Djj2i4jXVsGRTZqCHgaZemKU7AmjGuUm3lgZJTPzgojYg/IE/MtYOAhvLUqg4K7A\nGyLiC8BrMrMpoHCO6mbsfGnOp2GqWTk7uk2Wm6bveRsBWJPSieOS++7si4j7UwLW1lvgrW+KiGsy\n862T79XM+WtD2XYREfOMSRZyowHb6WI/5mVQg5aSHauUeItt1JOpxRo89IuG3LyhrAvR8oPaqqFs\nud0w3ryh7KIx62zahsNG+s9sNGtmXhcRT6CkOnvFCFVcBTw7Mw+JiP9peH3eJzgkSWqBY+JmjomX\nrs0byi4as87lOia+kLn71CSezulMmm5piTuA5idef9xyG/MFNTQ9pdw0Lc0wFrpxsqDMPCciVgF7\n9hSvpGRg/F79/RGxMfDIWvEllBs5y16V5eigiPhvSuaGJ1Cmhxrks34k8JCIeEFmfmCA97+A5hug\n11KmMPs+JQvZH4E/UzJmXlH18Z+qQIxjBmhvFrwAt8m01acug/GPdcvBC3DfnVkRcV/gS8wNYv4t\ncAbwwFr5wVVgwzsWo38zpGnqy3WB7Sj7/SiaHh5ZaIrNrvRjXgY1SOMbNRp3sS7g9ouG7Gq0/KCa\nLip2OQJ2EprmBhv3M2haf/Mx65wpVeThKyPiKOAg4L4DrvpD4HmZeVJEbE7zU2lnttJJSZK6xzHx\n9RwTLy7HxO35K3CLWlnT0zXDaP3pHEkLi5LyaP9FaOrxEfHCzOw31WJTkNgmY7bZ1nQ8h7JmUAPA\nU2gIagAexdwgrc9m5uUDtPO3hrI/AXccYN1hTXXKy8y8GPgA8IGI2IgyzdjulGCRewAb9ll1PeB/\nI+LyzPx4v/ojYgvgtQ0vfQN4RmaePUR3x81E1AnLdZtk5iHAIVPuRq8LKZlbewPJmsaoXTH149Jy\n3XeXioi4N3AUc4/rv6P8tp5PCXioBza8PSKuzcx3T7yTs+MsyvWUema+m9BuMMHvZqQf8zKoQVq+\nOhktPwQjYMtnWL9Avw7jPRXWdMG/aVsveZn5Q2CPiNgFeBBl/sybUlIPb0q54fE74AfAZzKzdw7J\n2zdUeWFmjjoAkCRJk+GYePY5Jm7PWZSbT71uMmpl1U3V7fu0I2my9qScv07a5sAj6DOXdmZeExGX\nseZNj+3GbHPc9Vf7LPB+1uzbYyLiuQ1BGqNOPbF6G1zMmjc4twMuycymcciSUM2XfXS1rE4zf2/K\n/rIfzYGZ74mIL2Vmv6xTDwU2rpX9ENgnM4cNsG3KdjWL3CYdUGV+vZg1g2DXiYitM/P8KXWrr44c\nl9x3Z1RE3AP4KnOD/c4A9szMP1bv+xfgK5Rr6r3eVWVs+O+Jd3YGZOaVEfEr4La1l3Zh/mxY89m1\noexns9CPhaw1zsqSxtIUEXlXYJuWl6YDBzSnZO1KtPwgmlKWbr6I7XfBRQ1l40YBN62/3FIYryEz\nT87MgzNzr8zcKTO3zMwVmblFZt4lM19QC2gA2K2hqjbmK5ckaalxTDwex8SOidvU9NTMyEENwLY0\nB/mM9XSOpIE8dRHbOmCB10+v/X/ziGh6cm9Qdxhj3X/KzH8wd/qIzYCH9xZExI2A+9XedxbNGR36\nqT9xvBZw6yHWn3mZeXVmficznw/cGDi44W2bUbJl9LN3Q9lrRrgBCmW6kaXAbdIdv20oa7o+2BXT\nPi65786giLgb8HXmBqScRQloOGd1QRUg+DCapwZ5X0Q8e2IdnT3HNZTVg80HEhHr03z944QZ6kdf\nBjVI09MUpblZZp7f8tLv4ttFDWVdiZYfRNMF3BsuYvtd0PTZ3mDMOps+w+VwAbdtTSct3130XkiS\n1H2OicfjmNgxcZuanprZZYz6mi5iXYJTskkTFRGbUKZLqLt
nZsY4C+X4WL/htFdENGVlWa3pAvnd\nRvzzxl23rinbQv2m+pOYew39sGrqykGtaih78BDrLymZeVlmvgJ4b8PL9QCSXvVAu6sZ/VrLPUdc\nr2vcJt3x/Yayey96Lwa3qqFsMY9L7rszJiLuQpkepB7Efw4loGFONrZqmqaH0RwI+P6IeHrrHZ1N\nTYEf9xqxrrsxN3vjZcCPZ6gffRnUIE1P07xQ9dQuk1SPlIfxo91biZYf0O+ZeyK9c0TUowSXsqZ9\naNy5z3YZsB31EREb0HwS8NXF7oskSTPAMfF4HBM7Jm5T043H7SJixxHra3qq58QhbwRKGt6+zJ3j\n+izg2HErzsy/AN+pFa/F/E/XN6Usftwo7UfElsADRlm3j29R5pHvtXdEbNvz/5GnnuhR32YAjx+y\njqXo/Q1l82UI2rb2//Mz87phG42ILYD7DrteR7lNuqPppu1TIqKrU8BP+7jkvjtDImJX4JvMzWj3\nR0pAwxn91q2mI9qHMr3IGtUCH4yIp7XZ1xn1Deae19+smupjWE9qKPt2Zg4ynWJX+tGXQQ3S9Kxq\nKFvMaMhTgMtrZXcadaBVnfCtHLdTg8rMvzP3SaK1WF6RmU1RbSOnNYuI7Zh78ng58PNR61ymHsXc\ntNWnZ+ZYqZUkSVqiVjWUOSYekGNiwDFxa6qnq5ouSI76lGHTet8esS5JgzugoeyTLQYUfaKhbP95\n3v8V4Jpa2UMjYpTMQk8D1hlhvUZVmvP637OC6sZeROwC3L72+o8zsynN/Hy+xdzsULtExCOHrGep\nObOhbIN53n917f+jThn2TOYG/swqt0l3fIu5WdRuBDxxCn0ZxLSPS+67MyIi7gAcDWxRe+nPwP0y\nc8Gp5aopnx7M3HO3AD4cEfONI5a8PkGjAP82TD0RsRHNwUlNY7fO9mM+BjVI03MMUI8+3DMixk2V\nOpDMvAY4qVa8OfDAEavcd6wOjaYpAvZfF6Hd+sk4TOd42nQB9zERsfaI9TX90JyQmfVBpvqIiAD+\nveGlDyx2XyRJmhGOicfnmHgux8Sja8ou1vSUzbwiYgfgPgPWL6klEXFLmgPbPtliM58HrqiV3arf\nU3yZ+WfgqFrxBsDBwzQaEdsArxpmnQHNNwVFG1kaVt/MeU/DS+9ZrDFPR92ooezced7/19r/N46I\nobIzRcTNmcx+NC1uk46onkb/34aX3hURTfv6VHXguOS+OwMi4naUoOStai/9hRLQ8JtB66oC8vcG\nflp7aS3goxGx3zh9XQL+p6HsyRFRD66cz2uZGyB0LvCFGexHI4MapCmp5vWtz1GzHvDKRezGkQ1l\nzxi2kohYiyGjtVrymYayR0TELSbc7t8byuaLJJ+UHwIX18puADxi2Iqqi75Nc1h50XE4TwbqF3Iu\nBD48hb5IktR5jolb4ZjYMXGbmp6eeUBE3HTIeg5g7jWnX2fmiSP1StKgDmgoOzUzW8s2k5mX0Hxc\nfOo8q/13Q9lTImKgILyIWI/ye7f5IO8fRmaeApxcK75zdSPnCbXyq4BPj9jUe5j7e7UD8KUqYGNk\nEbFLRIwakDlqmztGxHOqKThH1TTeOmWe9zdNk/SKQRuLiK2BzwJLaZout0m3vA+4tFa2JXBUbVqb\ngUXEpmP3qr9pHpfcdzsuInamBDRsXXvpr5SAhl8NW2c1htgbqGc0Xgs4JCLqv7vLyZeAU2tl6wAf\nGeS3NiLuDryo4aV3DDnlQ1f60cigBmm63tRQ9uyIWKx5oD7O3FRP/xIRew5Zz7+yuHMHA5CZxzJ3\nLqZ1gI9XF5Un5ZKGslHSJo4lMy8DDml46W0Rsf6Q1T0buE2t7Erg/0bo2rIUETcD/qvhpTdUAzZJ\nktTMMfEYHBM7Jm5TZv6IuTf31gLeOmgdVUr5pgtZTU/9SGpJdcx/SsNLY6f6bdCU+eFx/S52Z+a3\nKDef6j4YES+Z7/cqIm5
MyfSwR1U09JzrAzisoewjzP1d+0pm1tPLD6QK5DwAqE8DshtwUkQ8vMr+\nOJCI2CgiHhsR36RknWrKjjNJmwHvB86KiDdGxG2HWTking28rOGlT82zWj3jB8C+Vfvzjnki4i7A\n94FdqqK2pmOZNrdJh1Sp25/T8NIuwLERsfegdUXEXSLiU8DHWureHFM+Lrnvdt+eQD0Y53zg/pn5\ny1ErzcyLgL2Ye86xNvCYUesdV0RsGBFbz7f0WXWLBdYbKPAmM68D/qPhpbsC34yILefp+/2Bb1Km\n0Op1OvDeQdrvWj/6MahBmqLMXMXcdLHrAJ+LiDuPU3dE3DAimtLk9bb/V5qf7Dq0Shc6SDt3Bt45\nQhfb8paGst2Bj8XocyEvFAF7ekPZLqO01YL/Ye4J/c2A/xt00FmliGy6SPnJzDxvzP7NlFEv/Fff\nl28zd6B3POUkX5Ik9eGYuBWOiR0Tt6lpf9o3Iha8yFiNpz/A3Kep/0q5QShpch4A3LihfL6bxKP6\nCnOD2zYD5puL/XmUTIa91qY8HHByRLwwInaLiBtHxE4R8cCI+B/K04L371mnKV36uI4Arq2V7dbw\nvqGnnuiVmUfSHMy5PfBF4JSIeF1E7BkRO1Q3CNeJiG0i4pYR8ZCIeFVEfBE4jzJ+2WucPrVgG+DV\nwC8i4tSIeE9EPCUi7liNw9aPiBURsWVE3DUiXhARJ1J+u+tTRX0lM+uBmv+Umd8HftTw0quBn0bE\nUyPi5lWbG0bJJvH4iPgsJd35rXvWaeXmyrS5TbonMw+lOeD2ZsDXI+J71fFul4i4QfX92Kz6nB4V\nEW+NiN9QMhnsy9zvSdv9PZIpHJfcd7svM/8HOLCn6ALgAZn5ixbqvpAybunNJPU15mZIWkwvo+zD\n8y1NTlxgnaZsVY0y89uUjC919wJ+FxEHRcQeEXHTiLhtRDwyIr4MHM3c6R6uBp6UmVcN2n7X+tGv\ncy4uM7UAKynRd/Vl5SK0fUjb7QI7An9rqPcKytMtGwxR19qUyPWPUp4o+sEA69wIuKih/dOB3RZY\n96GUA/Pqda5tqGePAfteX2/VEH/3R/vsE98CbjVgHSuAhwPHAi8Y4P31z+wK4ObT2Lco81A2/f1H\nAJsvsO7DKQOS+rrnAdsN2P6ZtXXPHGM7HDjqPtTGAnyQchH2DgO+PyjpEpu24UXATovVdxcXFxeX\n5bXgmHi+uhwTOyZ2TDzmAny3oQ9XUS5I9VtnfeBzfT6Hpy5m/11clsJS/ZYNfEykZE+ov//YCfav\n6Zh99ALr3J/y29x0nBhkOZJyY7BefkgLf89XF2j7PGCdFtoJ4B1jbIN+y5tG3KcOHPHv2KXl/v8e\nuOEA7e5KSe8/TltvG3VbtLkNG+qu17tqwPWmuk3G6ftSXYB1KQEBbXw3jpz09mdKx6WO7LutfqcZ\n45ygT19WdmB/fhvlXGnXCdS9DfALyjnr+lP+Ow+cwHcgGXKMQnnA46gx27wO2G/M7dGJftQXMzVI\nU5aZZwCPZ25E+HqUwcRZEfGuiHhERNysitxcERFbRMTKiLhPRPxHRHwc+BNlTuKnUgZPg7T/J+Dl\nDS/dHPhRRHwmIh4dEbeLiBtVUaT7R8Q3gC9z/ZxKf2eC6bAW8FygKeXR/SmR4p+posRvVW23FVFS\n/9yu+ls+BPyZMti8+4Bt1lNkrQf8IEraxNVPFjSlG5rEcfe1wM8ayp8I/DIi3hARd46IrSJivWq/\n2TcijqL8zVs0rPuszDx3An3tuo2AZwI/i4jfVN+9/artt7LnO/CQiHgn5aT7g8zdhlcCj8rM3yxy\n/yVJmkmOiVvhmNgxcZsOYO4cy+sAh0fEqupJuTtGeWLvXhHxCuAM4FENdX0xM6f1vZCWhYjYHHhE\nw0tN00S0panu+8U8WY6yPPn3UOYeXwbxVcoxfRLTT8DCWRg+lZn16aqGlsWLgScz2nbo5
9IW61ps\nJwL3zMw/L/TGzDyJsh9cMWJb76B5zDez3Cbdk+WJ5EdRsl9N6pjVmmkdl9x3Z0NmvowS0HDSBOo+\njzLNxcMzc9T9YEmpxhqPBg4fsYrLgMdl5qjrd6ofdSOloZTUrsw8OiKeSJnDr37hdRvgBdUyqfY/\nGBF3pcwD3Gtt4LHVMm8VwNOA202gewvKzMsi4kGU9P+3rL28gsH+hmF9kDLQ67UdJW3ifHakRGy2\nJjOviohHUtI211NN3hB4TbUM6sDM/Fxb/Ztht2S0793fgUdm5nfa7Y4kSUubY+LxOCZ2TNymzDwj\nIh5NCVxZr/byfatlEMcDT2mzb5IaPZ6SLaXXtcCnJ9jmtyhTy/ROw7gW5Tt/UL+Vqt/7XSnpmB8y\nQDuXUFKjvzMzr43Bp3cf1pFVW/2mXxpr6om6zDw8Ir4OvJIy9lho2qcm5wKfpzwFelyb/RvAb4Hn\nU7Id3ZcS+Dasc4E3A/+TmfXA1r4y84sRsTslW8gdBlztLEoWqiMBJrgfTYXbpHuqffqVUaZLOAh4\n0JBVXEuZ6uftbfetn2kcl9x3Z0NmnjXBupfbVH8LqgI8nhxlSoe3ULJULbga5YGLF2Xm75ZSP3oZ\n1CB1RGZ+JiJ+T4l8ulVL1Q4Tpf0MSlTkvw/ZxrXAMzPzsxExlQu4AJl5TkTcHfg4Jep/0u39MCLe\nR5kTcuqqi467UwaNdxmxmqsoA8L/ba9ny85xlJS8v512RyRJmkWOicfjmNgxcZsy89sRsTfwWa7P\nRjKMbwL7ZuYl7fZMUl1mfoAyleJitnktcIMR1z0D2KcKbngScB9KQN4mlMyH5wKnAF8HPpmZF/es\neyYlVXqrMvOKiHgWzeOPyycRNJCZ5wMviojXAA+mzEN/Z0qmqM173notJeDidEpWppOBb2fmKUO2\nt4qWtl1mXkqZb/t9EbExcDdgd8q0FLegBDD2zqt9HWWqr1OBkyiZN741TDBDrf2TImIXYB/gMZR9\naAeuv99xNfA74CfAl4AvZeY1PVWcAjysVu2C2Tbb3IYNdY9V77S2SdW2d5X7yMwTgQdHxC0on8t9\ngdtTpr/r3W4XAb+mpOL/DvDN6hgxSButbf/FPi5VbU5z311Fi9/pzFzZVl1aXJl5IGUKis6oro98\nlvJdfBhwV67/fb2GMh3kryhTB34mM3+9lPsBBjVoNl0CvL9P+UzLzOMj4vbAc4AXAjcdoZpLKSd9\nHwe+NkTb1wHPjYhVwDspA4eF/Jxy8fbHI/SzdZl5AfCwiHgsJf3ssBeUL6HMuXvkgO9/AeVE+9XA\nBkO21brqIvY9KP16OcNdePwq8JLMPG0SfZshx1JSXt1oyPV+TZnH+dDquyRJ0qQ5Jp6fY2LHxC/A\nMfHYMvO7VaDOWyk3Hge5jvQX4PXABx0bS5pPlcq69XTWo8rMSU7XMV+7l1ICyD67uiwi1gU2BK7J\nzH9Mo1+Dqvr37Wr5p4hYmzLNZwL/yCyTdLfYblKeZP9KT5ubUgIoLp2vvcz8W+96S4XbpLsy83TK\ntcOD4Z/fj425/vvRqTHTYh+X3HelZtWx4SjmTv24LPsRLY8lJLWkmmf2PpRUfHejRIpvy/WRg0m5\nWHsWcBolInEVcGyOOcdfNUB5FPAvlCectqekAP47ZZ7Un1Kefjq67ROSNlVPaT2CEil+a2Crnpev\no6RJPI0SSfpN4LuZefkI7WwO7AvcC7gj5UmFTWi+qLtj9VTBREXEBpTPbx+uj5zrTeN8PiXydxXw\n6cz81aT7NCui5Cy7I3BPyrbbiXIzZXNKOs/LKdvvN5QI4a9l5o+m0llJkpY4x8Tjc0zsmLgtEXET\n4HHA/YDbUqaFWRf4B3A2ZU70rwJHZplLWtIYImIP4Jha8aIcPyVJ0uLyd19amEEN0gyJiBWUaEjo\nYARn1/VEk15dRZsuG9WN+o0ocxxe6kVGSZI0qxwTj
8cxsWNiSbPBmxuSJC0f/u5LC3P6CWmGVHNF\nzXxK4WmpLlouywuX1dODnU5XKEmSNAjHxONxTOyYWJIkSZKkWbPWtDsgSZIkSZIkSZIkSZLUxKAG\nSZIkSZIkSZIkSZLUSQY1SJIkSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIk\nqZMMapAkSZIkSZIkSZIkSZ1kUIMkSZIkSZIkSZIkSeqkFW1XGBFnAJsCZ7ZdtyRJ0hK0ErgkM3ec\ndkfUHsfEkiRJQ1mJY+K6M4HX18ouWvxujM4xsSRJA1sX+HOt7MiIuHYandHUrMQxcV+Rme1WGPG3\nDTbYYMudd9651XolSZKWotNOO43LL7/8gszcatp9UXscE0uSJA3OMfHS5JhYkiRpcI6J59d6pgbg\nzJ133nnLE044YQJVS5IkLS13vvOdOfHEE8+cdj/UOsfEkiRJA3JMvGQ5JpYkSRqQY+L5rTXtDkiS\nJEmSJEmSJEmSJDUxqEGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmSJEmSJKmTDGqQJEmSJEmSJEmSJEmd\nZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmSJEmSOsmgBkmSJEmSJEmS\nJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmSJEmSJKmTDGqQJEmS\nJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmSJEmSOsmg\nBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmSJEmS\nJKmTVky7A5IkSZIkSZIkSZK02jnnnDPtLgxkhx12mHYXpGXBTA2SJEmSJEmSJEmSJKmTDGqQJEmS\nJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmSJEmSOsmg\nBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmSJEmS\nJKmTDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmS\nJEmSJEmSOsmgBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2S\nJEmSJEmSJEmSJKmTDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElS\nJ62YdgckSZIkSZIkgMycdheWnIiYdhckSZI6YYcddph2FwYyifHbHnvs0Xqdk7L77ru3XudBBx3U\nep1aXGZqkCRJkiRJkiRJkiRJnWRQgyRJkiRJkiRJkiRJ6iSDGiRJkiRJkiRJkiRJUicZ1CBJkiRJ\nkiRJkiRJkjrJoAZJkiRJkiRJkiRJktRJBjVIkiRJkiRJkiRJkqROMqhBkiRJkiRJkiRJkiR1kkEN\nkiRJkiRJkiRJkiSpkwxqkCRJkiRJkiRJkiRJnWRQgyRJkiRJkiRJkiRJ6iSDGiRJkiRJkiRJkiRJ\nUicZ1CBJkiRJkiRJkiRJkjrJoAZJkiRJkiRJkiRJktRJBjVIkiRJkiRJkiRJkqROMqhBkiRJkiRJ\nkiRJkiR1kkENkiRJkiRJkiRJkiSpkwxqkCRJkiRJkiRJkiRJnWRQgyRJkiRJkiRJkiRJ6iSDGiRJ\nkiRJkiRJkiRJUicZ1CBJkiRJkiRJkiRJkjppxbQ7IEmSJEmSpMnKzGl3YWoiYiL1TmKbTqqvkiRJ\ns2aHHXZovc7lPNY65phjJlLvnnvu2Xqdz3rWs1qv85xzzmm9zknso+rPTA2SJEmSJEmSJEmSJKmT\nDGqQJEmSJEmSJ
EmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmS\nJEmSOsmgBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmS\nJEmSJEmSJKmTDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnU\nIEmSJEmSJEmSJEmSOsmgBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmS\nJHWSQQ2SJEmSJEmSJEmSJKmTDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmS\nJEmSJElSJ62YdgckSZIkSdLiyMxpd0FLSERMuwtTtdz/fkmSZtWee+45kXpXrVo1kXrVbcv9HOuY\nY46Zdhe0TJipQZIkSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIkqZMMapAkSZIkSZIkSZIkSZ1kUIMk\nSZIkSZIkSZIkSeokgxokSZIkSZIkSZIkSVInGdQgSZIkSZIkSZIkSZI6yaAGSZIkSZIkSZIkSZLU\nSQY1SJIkSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIkqZMMapAkSZIkSZIk\nSZIkSZ1kUIMkSZIkSZIkSZIkSeokgxokSZIkSZIkSZIkSVInGdQgSZIkSZIkSZIkSZI6yaAGSZIk\nSZIkSZIkSZLUSQY1SJIkSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIkqZMM\napAkSZIkSZIkSZIkSZ1kUIMkSZIkSZIkSZIkSeqkFdPugCRJkiRJsy4zp90FadHNyn4fEdPugiRJ\nYznnnHNar3OHHXZovc5J/Oa+8pWvbL3OVatWtV6nlq9ZGWvOyth9UmblOKr+zNQgSZIkSZIkSZIk\nSZI6yaAGSZIkSZIkSZIkSZLUSQY1SJIkSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIkSZIk\nSZIkSZIkqZMMapAkSZIkSZIkSZIkSZ1kUIMkSZIkSZIkSZIkSeokgxokSZIkSZIkSZIkSVInGdQg\nSZIkSZIkSZIkSZI6yaAGSZIkSZIkSZIkSZLUSQY1SJIkSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIk\ndZJBDZIkSZIkSZIkSZIkqZMMapAkSZIkSZIkSZIkSZ1kUIMkSZIkSZIkSZIkSeokgxokSZIkSZIk\nSZIkSVInGdQgSZIkSZIkSZIkSZI6yaAGSZIkSZIkSZIkSZLUSQY1SJIkSZIkSZIkSZKkTjKoQZIk\nSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIkqZNWTLsDkjTrMrP1Oq+66qrW67ziiitar/NjH/tY63V+\n6EMfar3OX/3qV63Xedhhh7Ve57777tt6nStW+FMvSVo8kxgXRUTrdU6in8vdrIyJr7766tbr/PjH\nP956nR/84Adbr/O0005rvc6PfOQjrdf5xCc+sfU611lnndbrnCWzcmyWpKXiVa96Vet1HnTQQa3X\nOSvH8je/+c3T7oKkJeCUU05pvc4ddtih9TrVn5kaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmSJEmS\nOsmgBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmS\nJEmSJKmTDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmS\nJEmSJEmSJEmSOsmgBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWS\nQQ2SJEmSJEmSJEmSJKmTDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIM
aJEmSJEmSJEmS\nJElSJxnUIEmSJEmSJEmSJEmSOsmgBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE5aMe0OSNJi\n+cc//jGRet/61re2Xueb3/zm1uvcYostWq9zo402ar3OZz3rWa3XudNOO7Ve50EHHdR6ne9+97tb\nr/MHP/hB63UCrLfeehOpV5IWW2ZOuwtLSkS0Xudy/owmsT0vu+yy1uuE2RkTb7nllq3XueGGG7Ze\n57Of/ezW67zVrW7Vep2vec1rWq9zEmPiY489tvU6AdZff/2J1Nu2SRxLJC0tr3rVq1qv80c/+lHr\nda5atar1Oo866qjW65zENaM999yz9TpnxSTOB/xt1HI0qf1+Vs7ZH/KQh0y7CxqTmRokSZIkSZIk\nSZIkSVInGdQgSZIkSZIkSZIkSZI6yaAGSZIkSZIkSZIkSZLUSQY1SJIkSZIkSZIkSZKkTjKoQZIk\nSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIkqZMMapAkSZIkSZIkSZIkSZ1kUIMkSZIkSZIkSZIkSeok\ngxokSZIkSZIkSZIkSVInGdQgSZIkSZIkSZIkSZI6yaAGSZIkSZIkSZIkSZLUSQY1SJIkSZIkSZIk\nSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZIkqZMMapAkSZIkSZIkSZIkSZ1kUIMkSZIk\nSZIkSZIkSeokgxokSZIkSZIkSZIkSVInGdQgSZIkSZIkSZIkSZI6yaAGSZIkSZIkSZIkSZLUSQY1\nSJIkSZIkSZIkSZKkTlox7Q5Imn1XXnll63W+9a1vbb3Ogw46qPU6Aa655prW67zXve7Vep2f//zn\nW69zq622ar3OWXHDG96w9Trvfe97t17n1Vdf3XqdAOutt95E6pWkxRYR0+7CkpKZ0+7C1EziN/dt\nb3tb63W++c1vbr1OmMzfP4mx0ec+97nW69xiiy1ar3NWvkuTOB+YpTHx+uuvP5F6JWmxTeqa2XLl\nOUa73J6zYRLj1+X82e+1116t13n00Ue3Xqe0mMzUIEmSJEmSJEmSJEmSOsmgBkmSJEmSJEmSJEmS\n1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmSJEmSJKmTDGqQJEmSJEmS\nJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmSJEmSOsmgBkmS\nJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmSJEmSJKmT\nDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmS\nJEmSOsmgBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHXSiml3QNLi\n+v3vf996nU9/+tNbr3PVqlWt17n99tu3XifA8573vNbrfOlLX9p6neq+bbbZpvU611577dbrlKSl\nJDOn3QUtICJar/Oss85qvc5JjImPOeaY1uucxHgD4PnPf37rdb785S9vvc5JmMRxZBL7/awc77ba\naqvW63RMLGmpOOeccyZS7wc+8IHW63zzm9/cep3L2ayMN7R8zcr4dVb2+6OPPnraXZA6x0wNkiRJ\nkiRJkiRJkiSpkwxqkCRJkiRJkiRJkiRJnWRQgyRJkiRJkiRJkiRJ6iSDGiRJkiRJkiRJkiRJUicZ\n1CBJkiRJkiRJkiRJkjrJoAZJkiRJkiRJkiRJktRJBjVIkiRJkiRJkiRJkqROMqhBkiRJkiRJkiRJ\nkiR1kkENkiRJkiRJkiRJkiSpkwxqkCRJkiRJkiRJkiRJnWRQgyRJkiRJkiRJkiRJ6iSDGiRJkiRJ\nkiRJkiRJUicZ1CBJkiRJkiRJkiRJkjrJoAZ
JkiRJkiRJkiRJktRJBjVIkiRJkiRJkiRJkqROMqhB\nkiRJkiRJkiRJkiR1kkENkiRJkiRJkiRJkiSpkwxqkCRJkiRJkiRJkiRJnWRQgyRJkiRJkiRJkiRJ\n6iSDGiRJkiRJkiRJkiRJUietmHYHJPX3pz/9qfU6X/Oa17Re56pVq1qvcxJe8pKXTKTe5z//+ROp\nV912yCGHtF7n/vvv33qdG2ywQet1SpLUT0S0Xue5557bep2vfvWrW6/zmGOOab3Oa665pvU6JzUm\nfuELXziRetVtH/vYx1qv84ADDmi9TsfEkpaKHXbYYSL1vvnNb55IvctVZrZe5yTG2bNir732ar3O\no48+uvU6l7ujjjpq2l2QtMSZqUGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmSJEmSJKmTDGqQJEmSJEmS\nJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmSJEmSOsmgBkmS\nJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmSJEmSJEmSJKmT\nDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqJIMaJEmSJEmSJEmSJElSJxnUIEmSJEmSJEmS\nJEmSOsmgBkmSJEmSJEmSJEmS1EkGNUiSJEmSJEmSJEmSpE4yqEGSJEmSJEmSJEmSJHWSQQ2SJEmS\nJEmSJEmSJKmTDGqQJEmSJEmSJEmSJEmdZFCDJEmSJEmSJEmSJEnqpBXT7oA0DVdeeWXrdX77299u\nvc4nP/nJrde58cYbt17n29/+9tbrPPjgg1uvU8vXGWec0Xqdn/nMZ1qv8yc/+UnrdUrSQjKz9Toj\novU6J9HPWTKJbXrVVVe1XucxxxzTep377rtv63VutNFGrdf5lre8ZSbqnJRJ7KOTMCvHvEk4++yz\nW6/z05/+dOt1Hn/88a3XOSufkaTp+epXv9p6nQ95yENar3O5H89m5Xd8uX9ObTv66KOn3YWBzcp5\n65577tl6nZM45qlds7J/Tso555wz7S4saBLXaZYSMzVIkiRJkiRJkiRJkqROMqhBkiRJkiRJkiRJ\nkiR1kkENkiRJkiRJkiRJkiSpkwxqkCRJkiRJkiRJkiRJnWRQgyRJkiRJkiRJkiRJ6iSDGiRJkiRJ\nkiRJkiRJUicZ1CBJkiRJkiRJkiRJkjrJoAZJkiRJkiRJkiRJktRJBjVIkiRJkiRJkiRJkqROMqhB\nkiRJkiRJkiRJkiR1kkENkiRJkiRJkiRJkiSpkwxqkCRJkiRJkiRJkiRJnWRQgyRJkiRJkiRJkiRJ\n6iSDGiRJkiRJkiRJkiRJUicZ1CBJkiRJkiRJkiRJkjrJoAZJkiRJkiRJkiRJktRJBjVIkiRJkiRJ\nkiRJkqROMqhBkiRJkiRJkiRJkiR1kkENkiRJkiRJkiRJkiSpkwxqkCRJkiRJkiRJkiRJnbRi2h2Q\npuG8885rvc6HPexhrdf56Ec/uvU63/a2t7Ve5+abb956nQcddFDrdWr5Ovjgg1uvc6uttmq9zu23\n3771OiVpqYiIaXdhyTn//PNbr/NBD3pQ63U+6lGPar3O//qv/2q9zlkZE6+33nqt16nZ8MY3vrH1\nOrfeeuvW61zuY+LMbL1Of0OlpWMSxwi1b6+99mq9zqOPPrr1OiWpqyYxfj377LNbr3OHHXZotb51\n11231fqWGjM1SJIkSZIkSZIkSZKkTjKoQZIkSZIkSZIkSZIkdZJBDZIkSZIkSZIkSZL0/9u5vxer\n6n+P42vVIDWOYudEhTilWGF4kWBdWBFjMBeOVCBF0IViFHQdGTJJEbHF7DqQAkkkrSgIQqN2w0gU\ndVH+oCAQ4ntQiG5OaZm/ytb34ntxoKM5X3p/vuu9Zj8ef8CLz8xes/fHerKBlEQNAAAAAAAAAEBK\nogYAAAA
AAAAAICVRAwAAAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApCRqAAAAAAAAAABSEjUA\nAAAAAAAAACmJGgAAAAAAAACAlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgBAAAA\nAAAAAEhJ1AAAAAAAAAAApCRqAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAAAACAlEQNAAAAAAAA\nAEBKQ20fAC7n3Llz4ZubNm0K3yxh69at4ZsnTpwI37z77rvDN5cuXRq++dhjj4VvEu/DDz8M33zj\njTfCNzds2BC+OTIyEr4JcDl1XYdvNk0TvllKiZ+/hPPnz4dvduVOvG3btvDNkydPhm/edddd4ZvL\nli0L39y4cWP4Zpd05W9+amoqfLPEnXj9+vXhm/PmzQvfHHQlPpe78rcEMzUxMRG+2aW/k67c37v0\nO4XMpqen2z7CjAzy33xX3pe7ZMeOHeGbvV4vfJNL800NAAAAAAAAAEBKogYAAAAAAAAAICVRAwAA\nAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApCRqAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAA\nAACAlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAA\npCRqAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAAAACAlEQNAAAAAAAAAEBKogYAAAAAAAAAICVR\nAwAAAAAAAACQkqgBAAAAAAAAAEhpqO0DwOW89NJL4Ztvv/12+Obzzz8fvnnNNdeEbz788MPhm4sW\nLQrfnJqaCt8cGRkJ3xx0p0+fDt/csmVL+OaZM2fCN5966qnwTQAura7rto/Qqu3bt4dv7tmzJ3zz\nueeeC99csGBB+Oa6devCNxcuXBi++fHHH4dvDg8Ph28OuhJ3zWeffTZ88+zZs+GbmzdvDt/skqZp\n2j7CjAz6Zyi0Zd++feGbExMT4ZuDbnx8PHyz3++HbxKrK5/hXdKV+4a/+Xhdee17vV7bR+Bv8k0N\nAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApCRqAAAA\nAAAAAABSEjUAAAAAAAAAACmJGgAAAAAAAACAlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAA\nAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApCRqAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAAAACA\nlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApDTU\n9gGYXU6dOhW+2ev1wjdXrVoVvjk2Nha+uXr16vDNOXPmhG9OTU2Fb46MjIRvEu/HH38M3zx48GD4\n5iuvvBK+uWTJkvBNAGaH06dPh29u3bo1fPPee+8N3xwfHw/fvO+++8I3h4eHwzenp6fDN0ucs5Sm\nacI367oO3yzhp59+Ct/86quvwjdfe+218M2u3IlLPJ+ldOW5By7v5ZdfDt+cmJgI3yylK+9n/X6/\n7SPMyL59+8I3165dG75ZwqB/jnfp5++CrvzNl9CV9+Wq8txzcb6pAQAAAAAAAABISdQAAAAAAAAA\nAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACAl\nUQMAAAAAAAAAkJKoAQAAAAAAAABISdQAAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApiRoA\nAAAAAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACAlUQMAAAAAAAAAkJKoAQAAAAAAAABISdQAAAAA\nAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApDbV9ALic3377LXzzs88+C
98cGxsL3xweHg7f3Lt3\nb/jmyMhI+CbxTp06Fb75wAMPhG8uX748fPPRRx8N37ziCl0gABfXNE345vnz58M3P/300/DNrtyJ\n33rrrfDNEufskrqu2z7CjPz666/hmw8++GD45i233BK++cgjj4Rvlni/A5gtDhw40PYRmEXWrl3b\n9hFmZNDvBl35+btydx9kXiP4//wfGQAAAAAAAAAgJVEDAAAAAAAAAJCSqAEAAAAAAAAASEnUAAAA\nAAAAAACkJGoAAAAAAAAAAFISNQAAAAAAAAAAKYkaAAAAAAAAAICURA0AAAAAAAAAQEqiBgAAAAAA\nAAAgJVEDAAAAAAAAAJCSqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAAAAAAAFISNQAAAAAAAAAA\nKYkaAAAAAAAAAICURA0AAAAAAAAAQEqiBgAAAAAAAAAgJVEDAAAAAAAAAJCSqAEAAAAAAAAASEnU\nAAAAAAAAAACkNNT2AZhd6roO35w7d2745vnz58M3b7jhhvDNTz75JHzzpptuCt8k3rlz58I3N2/e\nHL555MiR8M1Dhw6Fb86fPz98E2C2aJomfLPEnbCUEj9/ic05c+aEb5Zw/fXXh2+WuBOPjo6GbxLv\n7Nmz4ZtPP/10+OaXX34ZvvnNN9+Eb86bNy98s8T7XZd06fMO+Gv79+8P3+zSe6T3s8HUpWd0kJV4\nfyphfHw8fLPf74dvDvJ/A5mcnCyy2+v1iuzCn/mmBgAAAAAAAAAgJVEDAAAAAAAAAJCSqAEAAAAA\nAAAASEnUAAAAAAAAAACkJGoAAAAAAAAAAFISNQAAAAAAAAAAKYkaAAAAAAAAAICURA0AAAAAAAAA\nQEqiBgAAAAAAAAAgJVEDAAAAAAAAAJCSqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAAAAAAAFIS\nNQAAAAAAAAAAKYkaAAAAAAAAAICURA0AAAAAAAAAQEqiBgAAAAAAAAAgJVEDAAAAAAAAAJCSqAEA\nAAAAAAAASEnUAAAAAAAAAACkNNT2AZhd5s6dG7556NCh8M2TJ0+Gb65cuTJ8k274/fffwzfvueee\n8M2DBw+Gby5dujR8c/HixeGbAFxaXddtH2FGmqZp+wgzVuJOfOTIkfDNEydOhG/eeeed4Ztdeu0H\n2YULF8I3V61aFb55+PDh8M0lS5aEby5atCh8c5D/lrryWQe0Z2Jiou0jzIj3s24o8Znrte+GrrxO\n4+Pj4Zv9fj98s4SuvEYl9Hq9to8Af4tvagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJRE\nDQAAAAAAAABASqIGAAAAAAAAACAlUQMAAAAAAAAAkJKoAQAAAAAAAABISdQAAAAAAAAAAKQkagAA\nAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACAlUQMAAAAA\nAAAAkJKoAQAAAAAAAABISdQAAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAA\ngJREDQAAAAAAAABASkNtHwAu5+abb277CPCX9u7dG7558ODB8M2rrroqfHPbtm3hm/Pnzw/fBICu\n68qduGmato8wIyXOWdd1+GZXfp9VVVW7d+8O3zx8+HD45pw5c8I3t2/fHr7ZlTtxieceALquK5+P\nXbm/duX3Wcr4+Hj4Zr/fD98cZF36dxt0mW9qAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAAAACA\nlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApCRq\nAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAAAACAlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAA\nAAAAAACQkqgBAAAAAAAAAEhJ1AAAA
AAAAAAApCRqAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAA\nAACAlEQNAAAAAAAAAEBKQ20fAOA/Zffu3UV2N2zYEL65fPny8M0XXnghfHPdunXhmwBwMXVdF9lt\nmqbILoOnK8/Srl27iuxu3LgxfHPZsmXhmy+++GL45kMPPRS+WUJXnlGA2eL48ePhm8eOHQvfrKqq\nuvHGG4vsMnhK/butKz744IPwzTVr1oRvltCVu+agP6PQZb6pAQAAAAAAAABISdQAAAAAAAAAAKQk\nagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACAlUQMA\nAAAAAAAAkJKoAQAAAAAAAABISdQAAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAA\nAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACAlUQMAAAAAAAAAkJKoAQAAAAAAAABISdQAAAAAAAAA\nAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApDbV9AICL2bVrV/jm448/Hr5ZVVV15ZVXhm++/vrr4Zsr\nV64M3wSArqvrOnyzaZrwzRIG+Wfvip07d4ZvPvHEE+GbVVXmedq7d2/45ooVK8I3u6LEawTApY2O\njoZvHj9+PHyzqqrq2LFj4Zs7duwI39y6dWv4Zon7q8/cwbVmzZq2jzAjg/zvtkH+2Ut9hpT4vIOL\n8U0NAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApCRq\nAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAAAACAlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAA\nAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApCRqAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAA\nAACAlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgBAAAAAAAAAEhpqO0DAN23cePG\n8M09e/aEb/7xxx/hm1VVVe+991745sqVK8M3AYDuqus6fLNpmvDNrijx+1y/fn345jvvvBO+Wep1\nL3EnXrFiRfhmV557f/OxSvw+AdowOjpaZPf48eNFdqOV+Cwb5M+IQf59jo+PF9nt9/vhm4P8OnXl\nZ5+cnAzfLKHX67V9BPhbfFMDAAAAAAAAAJCSqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAAAAAA\nAFISNQAAAAAAAAAAKYkaAAAAAAAAAICURA0AAAAAAAAAQEqiBgAAAAAAAAAgJVEDAAAAAAAAAJCS\nqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAAAAAAAFISNQAAAAAAAAAAKYkaAAAAAAAAAICURA0A\nAAAAAAAAQEqiBgAAAAAAAAAgJVEDAAAAAAAAAJCSqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAA\nAAAAAFIaavsAwKX9/PPP4Zuvvvpq+Oabb74ZvrlixYrwzffffz98s6qq6rrrriuyCwB0U13X4ZtN\n04RvdkWJO/HOnTvDN999993wzdtvvz18s9Sd+Nprrw3fHOTnvis/u/c7gNlhdHQ0fPPJJ58M3yzx\nuTPIBvn32e/3i+yWuMesXr06fLMruvKMlni/27FjR/gmdJ1vagAAAAAAAAAAUhI1AAAAAAAAAAAp\niRoAAAAAAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACAlUQMAAAAAAAAAkJKoAQAAAAAAAABISdQA\nAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJREDQAAAAAAAABASqIGAAAA\nAAAAACAlUQMAAAAAAAAAkJKoAQAAAAAAAABISdQAAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAA\nA
AApiRoAAAAAAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACClobYPALPFL7/8Er65Zs2a8M0vvvgi\nfPOOO+4I35yamgrfHBkZCd8EAPizpmnaPkJrStyJ77///vDNzz//PHyzxJ34o48+Ct+cN29e+Oag\nq+u67SPMKn6fALPD119/3fYRmCXGxsbCN6enp8M3Szlw4EDbR5iRQf53cAm9Xq/tI0A6vqkBAAAA\nAAAAAEhJ1AAAAAAAAAAApCRqAAAAAAAAAABSEjUAAAAAAAAAACmJGgAAAAAAAACAlEQNAAAAAAAA\nAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgBAAAAAAAAAEhJ1AAAAAAAAAAApCRqAAAAAAAAAABS\nEjUAAAAAAAAAACmJGgAAAAAAAACAlEQNAAAAAAAAAEBKogYAAAAAAAAAICVRAwAAAAAAAACQkqgB\nAAAAAAAAAEhJ1AAAAAAAAAAApCRqAAAAAAAAAABSEjUAAAAAAAAAACkNtX0AmC0uXLgQvvnDDz+E\nb27atCl8c8uWLeGbIyMj4ZsAAP8JdV2HbzZNE75Z4pwllLgTP/PMM+Gbk5OT4ZvDw8PhmyWepS7p\nynMPAF03MTERvjk2Nha+OT09Hb45yPeNffv2hW+uXbs2fHOQX6Oq8m8CoJt8UwMAAAAAAAAAkJKo\nAQAAAAAAAABISdQAAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJREDQAA\nAAAAAABASqIGAAAAAAAAACAlUQMAAAAAAAAAkJKoAQAAAAAAAABISdQAAAAAAAAAAKQkagAAAAAA\nAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACAlUQMAAAAAAAAA\nkJKoAQAAAAAAAABISdQAAAAAAAAAAKQkagAAAAAAAAAAUhpq+wAwWyxYsCB887vvvgvfBADg/zRN\nE75Z13UnNkuYP39++ObRo0fDN7uixPPZJV157gGg6/bv3x++OTExEb45PT0dvlnCoN/hiOV5AvgX\n39QAAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJREDQAAAAAAAABASqIG\nAAAAAAAAACAlUQMAAAAAAAAAkJKoAQAAAAAAAABISdQAAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAA\nAAAAAAApiRoAAAAAAAAAgJREDQAAAAAAAABASqIGAAAAAAAAACAlUQMAAAAAAAAAkJKoAQAAAAAA\nAABISdQAAAAAAAAAAKQkagAAAAAAAAAAUhI1AAAAAAAAAAApiRoAAAAAAAAAgJSG2j4AAABAW+q6\nDt9smiZ8s8Q5GUyeJQDgzyYmJsI3V69eHb45PT0dvkl+k5OT4Zu9Xi98E4CyfFMDAAAAAAAAAJCS\nqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAAAAAAAFISNQAAAAAAAAAAKYkaAAAAAAAAAICURA0A\nAAAAAAAAQEqiBgAAAAAAAAAgJVEDAAAAAAAAAJCSqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAA\nAAAAAFISNQAAAAAAAAAAKYkaAAAAAAAAAICURA0AAAAAAAAAQEqiBgAAAAAAAAAgJVEDAAAAAAAA\nAJCSqAEAAAAAAAAASEnUAAAAAAAAAACkVDdNEztY1/979dVX/9dtt90WugsAMBt9++231ZkzZ35s\nmua/2z4LcdyJAQBmzp14dnInHmxHjx4N37z11lvDN8nv+++/D99cuHBh+CbA3+VO/NdKRA3/qKpq\nflVV/xM6DAAwOy2uqurnpmmWtH0Q4rgTAwD8WxZX7sSzjjsxAMC/ZXHlTnxJ4VEDAAAAAAAAAECE\nK9o+AAAAAAAAAADAxYgaAAAAAAAAAICURA0AAAAAAAAAQEqiBgA
AAAAAAAAgJVEDAAAAAAAAAJCS\nqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAAAAAAAFISNQAAAAAAAAAAKYkaAAAAAAAAAICURA0A\nAAAAAAAAQEqiBgAAAAAAAAAgJVEDAAAAAAAAAJCSqAEAAAAAAAAASEnUAAAAAAAAAACkJGoAAAAA\nAAAAAFISNQAAAAAAAAAAKYkaAAAAAAAAAICURA0AAAAAAAAAQEr/BPnWXU3HzsaeAAAAAElFTkSu\nQmCC\n", "text/plain": [ "\u003cFigure size 1800x600 with 3 Axes\u003e" ] }, "metadata": { "image/png": { "height": 369, "width": 1050 } }, "output_type": "display_data" } ], "source": [ "_, axes = plt.subplots(nrows=1, ncols=3, figsize=(6 * 3, 6))\n", "\n", "axes[0].set_title(\"Clean image \\n Prediction %s\" % int(pred_clean))\n", "axes[0].imshow(img_clean, cmap=plt.cm.get_cmap(\"Greys\"), vmax=1, vmin=0)\n", "axes[1].set_title(\"Adversarial image \\n Prediction %s\" % prediction_adversarial)\n", "axes[1].imshow(img_adversarial, cmap=plt.cm.get_cmap(\"Greys\"), vmax=1, vmin=0)\n", "axes[2].set_title(r\"|Adversarial - clean| $\\times$ %.0f\" % (1 / EPSILON))\n", "axes[2].imshow(\n", " jnp.abs(img_clean - img_adversarial) / EPSILON,\n", " cmap=plt.cm.get_cmap(\"Greys\"),\n", " vmax=1,\n", " vmin=0,\n", ")\n", "for i in range(3):\n", " axes[i].set_xticks(())\n", " axes[i].set_yticks(())\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "executionInfo": { "elapsed": 2, "status": "ok", "timestamp": 1707150695968, "user": { "displayName": "", "userId": "" }, "user_tz": 0 }, "id": "pG7U7jzxQ4OA" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "last_runtime": { "build_target": "//learning/grp/tools/ml_python:ml_notebook", "kind": "private" }, "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }, "vscode": { "interpreter": { 
"hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } } }, "nbformat": 4, "nbformat_minor": 0 }