{ "cells": [ { "cell_type": "markdown", "id": "877f5751-82f9-4f09-b2eb-8c6042a3f140", "metadata": {}, "source": [ "# Neural CDE" ] }, { "cell_type": "markdown", "id": "162e9c19-28b5-44c9-9a01-0ec508f2ac6c", "metadata": {}, "source": [ "This example trains a [Neural CDE](https://arxiv.org/abs/2005.08926) (a \"continuous time RNN\") to distinguish clockwise from counter-clockwise spirals.\n", "\n", "A neural CDE looks like\n", "\n", "$y(t) = y(0) + \\int_0^t f_\\theta(y(s)) \\mathrm{d}x(s)$\n", "\n", "where $f_\\theta$ is a neural network, and $x$ is your data. The right-hand side is a matrix-vector product between them. The integral is a Riemann--Stieltjes integral.\n", "\n", "!!! info\n", "\n", " Provided the path $x$ is differentiable then the Riemann--Stieltjes integral can be converted into a normal integral:\n", " \n", " $y(t) = y(0) + \\int_0^t f_\\theta(y(s)) \\frac{\\mathrm{d}x}{\\mathrm{d}s}(s) \\mathrm{d}s$\n", " \n", " and in this case you can actually solve the CDE as an ODE. Indeed this is what we do below.\n", " \n", " Typically the path $x$ is constructed as a continuous interpolation of your input data. This is an approach that often makes a lot of sense when dealing with irregular data, densely sampled data etc. (i.e. the things that an RNN or Transformer might not work so well on.)\n", "\n", "**Reference:**\n", "\n", "```bibtex\n", "@incollection{kidger2020neuralcde,\n", " title={Neural Controlled Differential Equations for Irregular Time Series},\n", " author={Kidger, Patrick and Morrill, James and Foster, James and Lyons, Terry},\n", " booktitle={Advances in Neural Information Processing Systems},\n", " publisher={Curran Associates, Inc.},\n", " year={2020},\n", "}\n", "```\n", "\n", "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/neural_cde.ipynb)." 
] }, { "cell_type": "code", "execution_count": 1, "id": "16ab19e7-1e40-4b9e-99ab-3aa74f6e5ed8", "metadata": {}, "outputs": [], "source": [ "import math\n", "import time\n", "\n", "import diffrax\n", "import equinox as eqx # https://github.com/patrick-kidger/equinox\n", "import jax\n", "import jax.nn as jnn\n", "import jax.numpy as jnp\n", "import jax.random as jr\n", "import jax.scipy as jsp\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import optax # https://github.com/deepmind/optax\n", "\n", "\n", "matplotlib.rcParams.update({\"font.size\": 30})" ] }, { "cell_type": "markdown", "id": "23dedc5d-dc07-4dd1-845f-1251bab4b32a", "metadata": {}, "source": [ "First let's define the vector field for the CDE." ] }, { "cell_type": "code", "execution_count": 2, "id": "9e916e0a-df54-4045-9b74-16e536e58000", "metadata": {}, "outputs": [], "source": [ "class Func(eqx.Module):\n", " mlp: eqx.nn.MLP\n", " data_size: int\n", " hidden_size: int\n", "\n", " def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):\n", " super().__init__(**kwargs)\n", " self.data_size = data_size\n", " self.hidden_size = hidden_size\n", " self.mlp = eqx.nn.MLP(\n", " in_size=hidden_size,\n", " out_size=hidden_size * data_size,\n", " width_size=width_size,\n", " depth=depth,\n", " activation=jnn.softplus,\n", " # Note the use of a tanh final activation function. This is important to\n", " # stop the model blowing up. (Just like how GRUs and LSTMs constrain the\n", " # rate of change of their hidden states.)\n", " final_activation=jnn.tanh,\n", " key=key,\n", " )\n", "\n", " def __call__(self, t, y, args):\n", " return self.mlp(y).reshape(self.hidden_size, self.data_size)" ] }, { "cell_type": "markdown", "id": "09b8851b-e418-47e8-856e-622266c78612", "metadata": {}, "source": [ "Now wrap up the whole CDE solve into a model.\n", "\n", "In this case we cap the neural CDE with a linear layer and sigmoid, to perform binary classification." 
] }, { "cell_type": "code", "execution_count": 3, "id": "d1e9dc97-de5c-4905-9e59-3ca474aa9a52", "metadata": {}, "outputs": [], "source": [ "class NeuralCDE(eqx.Module):\n", " initial: eqx.nn.MLP\n", " func: Func\n", " linear: eqx.nn.Linear\n", "\n", " def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):\n", " super().__init__(**kwargs)\n", " ikey, fkey, lkey = jr.split(key, 3)\n", " self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)\n", " self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)\n", " self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)\n", "\n", " def __call__(self, ts, coeffs, evolving_out=False):\n", " # Each sample of data consists of some timestamps `ts`, and some `coeffs`\n", " # parameterising a control path. These are used to produce a continuous-time\n", " # input path `control`.\n", " control = diffrax.CubicInterpolation(ts, coeffs)\n", " term = diffrax.ControlTerm(self.func, control).to_ode()\n", " solver = diffrax.Tsit5()\n", " dt0 = None\n", " y0 = self.initial(control.evaluate(ts[0]))\n", " if evolving_out:\n", " saveat = diffrax.SaveAt(ts=ts)\n", " else:\n", " saveat = diffrax.SaveAt(t1=True)\n", " solution = diffrax.diffeqsolve(\n", " term,\n", " solver,\n", " ts[0],\n", " ts[-1],\n", " dt0,\n", " y0,\n", " stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),\n", " saveat=saveat,\n", " )\n", " if evolving_out:\n", " prediction = jax.vmap(lambda y: jnn.sigmoid(self.linear(y))[0])(solution.ys)\n", " else:\n", " (prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))\n", " return prediction" ] }, { "cell_type": "markdown", "id": "4d96c981-86ed-4594-afae-6c087523a47c", "metadata": {}, "source": [ "Toy dataset of spirals.\n", "\n", "We interpolate the samples with Hermite cubic splines with backward differences, which were introduced in [https://arxiv.org/abs/2106.11028](https://arxiv.org/abs/2106.11028). 
(These splines produce better results than the natural cubic splines used in the original neural CDE paper.)\n", "\n", "!!! danger \"Time is a channel\"\n", "\n", " Note the inclusion of time as a channel of the data! This is a subtle point that is often accidentally missed. If you include it then the model has enough information so that in theory it's actually a universal approximator. If you forget it then the model probably won't work very well...\n", " \n", " If a CDE ever isn't training very well, make sure to ask yourself \"did I include time as a channel?\"" ] }, { "cell_type": "code", "execution_count": 4, "id": "d647fc5f-4eaf-4a06-83e6-034508fbbfca", "metadata": {}, "outputs": [], "source": [ "def get_data(dataset_size, add_noise, *, key):\n", " theta_key, noise_key = jr.split(key, 2)\n", " length = 100\n", " theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)\n", " y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)\n", " ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))\n", " matrix = jnp.array([[-0.3, 2], [-2, -0.3]])\n", " ys = jax.vmap(\n", " lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)\n", " )(y0, ts)\n", " ys = jnp.concatenate([ts[:, :, None], ys], axis=-1) # time is a channel\n", " ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)\n", " if add_noise:\n", " ys = ys + jr.normal(noise_key, ys.shape) * 0.1\n", " coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)\n", " labels = jnp.zeros((dataset_size,))\n", " labels = labels.at[: dataset_size // 2].set(1.0)\n", " _, _, data_size = ys.shape\n", " return ts, coeffs, labels, data_size" ] }, { "cell_type": "code", "execution_count": 5, "id": "0657dffa-da73-4c3c-a722-14eff8ca7811", "metadata": {}, "outputs": [], "source": [ "def dataloader(arrays, batch_size, *, key):\n", " dataset_size = arrays[0].shape[0]\n", " assert all(array.shape[0] == dataset_size for array in arrays)\n", " indices = 
jnp.arange(dataset_size)\n", " while True:\n", " perm = jr.permutation(key, indices)\n", " (key,) = jr.split(key, 1)\n", " start = 0\n", " end = batch_size\n", " while end < dataset_size:\n", " batch_perm = perm[start:end]\n", " yield tuple(array[batch_perm] for array in arrays)\n", " start = end\n", " end = start + batch_size" ] }, { "cell_type": "markdown", "id": "407f9c67-8797-480b-a56c-c30af3c7bd26", "metadata": {}, "source": [ "The main entry point. Try running `main()` to train the neural CDE." ] }, { "cell_type": "code", "execution_count": 6, "id": "248b14eb-2486-4c1a-ac2a-ade54d6d8880", "metadata": {}, "outputs": [], "source": [ "def main(\n", " dataset_size=256,\n", " add_noise=False,\n", " batch_size=32,\n", " lr=1e-2,\n", " steps=20,\n", " hidden_size=8,\n", " width_size=128,\n", " depth=1,\n", " seed=5678,\n", "):\n", " key = jr.PRNGKey(seed)\n", " train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)\n", "\n", " ts, coeffs, labels, data_size = get_data(\n", " dataset_size, add_noise, key=train_data_key\n", " )\n", "\n", " model = NeuralCDE(data_size, hidden_size, width_size, depth, key=model_key)\n", "\n", " # Training loop like normal.\n", "\n", " @eqx.filter_jit\n", " def loss(model, ti, label_i, coeff_i):\n", " pred = jax.vmap(model)(ti, coeff_i)\n", " # Binary cross-entropy\n", " bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)\n", " bxe = -jnp.mean(bxe)\n", " acc = jnp.mean((pred > 0.5) == (label_i == 1))\n", " return bxe, acc\n", "\n", " grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)\n", "\n", " @eqx.filter_jit\n", " def make_step(model, data_i, opt_state):\n", " ti, label_i, *coeff_i = data_i\n", " (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)\n", " updates, opt_state = optim.update(grads, opt_state)\n", " model = eqx.apply_updates(model, updates)\n", " return bxe, acc, model, opt_state\n", "\n", " optim = optax.adam(lr)\n", " opt_state = optim.init(eqx.filter(model, 
eqx.is_inexact_array))\n", " for step, data_i in zip(\n", " range(steps), dataloader((ts, labels) + coeffs, batch_size, key=loader_key)\n", " ):\n", " start = time.time()\n", " bxe, acc, model, opt_state = make_step(model, data_i, opt_state)\n", " end = time.time()\n", " print(\n", " f\"Step: {step}, Loss: {bxe}, Accuracy: {acc}, Computation time: \"\n", " f\"{end - start}\"\n", " )\n", "\n", " ts, coeffs, labels, _ = get_data(dataset_size, add_noise, key=test_data_key)\n", " bxe, acc = loss(model, ts, labels, coeffs)\n", " print(f\"Test loss: {bxe}, Test Accuracy: {acc}\")\n", "\n", " # Plot results\n", " sample_ts = ts[-1]\n", " sample_coeffs = tuple(c[-1] for c in coeffs)\n", " pred = model(sample_ts, sample_coeffs, evolving_out=True)\n", " interp = diffrax.CubicInterpolation(sample_ts, sample_coeffs)\n", " values = jax.vmap(interp.evaluate)(sample_ts)\n", " fig = plt.figure(figsize=(16, 8))\n", " ax1 = fig.add_subplot(1, 2, 1)\n", " ax2 = fig.add_subplot(1, 2, 2, projection=\"3d\")\n", " ax1.plot(sample_ts, values[:, 1], c=\"dodgerblue\")\n", " ax1.plot(sample_ts, values[:, 2], c=\"dodgerblue\", label=\"Data\")\n", " ax1.plot(sample_ts, pred, c=\"crimson\", label=\"Classification\")\n", " ax1.set_xticks([])\n", " ax1.set_yticks([])\n", " ax1.set_xlabel(\"t\")\n", " ax1.legend()\n", " ax2.plot(values[:, 1], values[:, 2], c=\"dodgerblue\", label=\"Data\")\n", " ax2.plot(values[:, 1], values[:, 2], pred, c=\"crimson\", label=\"Classification\")\n", " ax2.set_xticks([])\n", " ax2.set_yticks([])\n", " ax2.set_zticks([])\n", " ax2.set_xlabel(\"x\")\n", " ax2.set_ylabel(\"y\")\n", " ax2.set_zlabel(\"Classification\")\n", " plt.tight_layout()\n", " plt.savefig(\"neural_cde.png\")\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 7, "id": "00a26c12-7038-4294-9e47-98965d119742", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step: 0, Loss: 2.5234897136688232, Accuracy: 0.5, Computation time: 27.177752256393433\n", 
"Step: 1, Loss: 4.682699203491211, Accuracy: 0.5, Computation time: 0.5112535953521729\n", "Step: 2, Loss: 1.9817578792572021, Accuracy: 0.46875, Computation time: 0.4303276538848877\n", "Step: 3, Loss: 0.909335732460022, Accuracy: 0.375, Computation time: 0.42275118827819824\n", "Step: 4, Loss: 0.5238552093505859, Accuracy: 0.96875, Computation time: 0.3412055969238281\n", "Step: 5, Loss: 0.5987676382064819, Accuracy: 0.5625, Computation time: 0.4041574001312256\n", "Step: 6, Loss: 0.5615957975387573, Accuracy: 0.5625, Computation time: 0.3387322425842285\n", "Step: 7, Loss: 0.5031553506851196, Accuracy: 0.625, Computation time: 0.4076976776123047\n", "Step: 8, Loss: 0.3657313883304596, Accuracy: 0.84375, Computation time: 0.35105156898498535\n", "Step: 9, Loss: 0.34929466247558594, Accuracy: 0.9375, Computation time: 0.42032384872436523\n", "Step: 10, Loss: 0.2539682686328888, Accuracy: 1.0, Computation time: 0.3486146926879883\n", "Step: 11, Loss: 0.2294737994670868, Accuracy: 1.0, Computation time: 0.3518819808959961\n", "Step: 12, Loss: 0.2001168429851532, Accuracy: 1.0, Computation time: 0.4245719909667969\n", "Step: 13, Loss: 0.18462520837783813, Accuracy: 1.0, Computation time: 0.4051353931427002\n", "Step: 14, Loss: 0.19849714636802673, Accuracy: 1.0, Computation time: 0.4198932647705078\n", "Step: 15, Loss: 0.21601906418800354, Accuracy: 1.0, Computation time: 0.344160795211792\n", "Step: 16, Loss: 0.1362144500017166, Accuracy: 1.0, Computation time: 0.42815589904785156\n", "Step: 17, Loss: 0.12172335386276245, Accuracy: 1.0, Computation time: 0.3531978130340576\n", "Step: 18, Loss: 0.13752871751785278, Accuracy: 1.0, Computation time: 0.4119846820831299\n", "Step: 19, Loss: 0.10557006299495697, Accuracy: 1.0, Computation time: 0.33621764183044434\n", "Test loss: 0.1057077944278717, Test Accuracy: 1.0\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABFcAAAINCAYAAAD7m5BWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd3iUVfrG8e+Zlt6AhN6RjgVBFFFw7WJde+8V67ruz7WsXdfey4oFe++9i2CnWQCVoohICyQkmZRp5/fHOxMDBAhJpiS5P9c11wwz78x7EkLI3HnO8xhrLSIiIiIiIiIi0jiuZC9ARERERERERKQlU7giIiIiIiIiItIECldERERERERERJpA4YqIiIiIiIiISBMoXBERERERERERaQKFKyIiIiIiIiIiTeDZnIM7dOhge/XqFaeliIiISGsxffr0YmttYbLXEUc22QsQERFpY0yyF7AxmxWu9OrVi2nTpsVrLSIiItJKGGMWJXsNIiIiIomibUEiIiIiIiIiIk2gcEVEREREREREpAkUroiIiIiIiIiINIHCFRERERERERGRJlC4IiIiIiIiIiLSBApXRERERERERESaQOGKiIiIiIiIiEgTKFwREREREREREWkChSsiIiIiIiIiIk2gcEVEREREREREpAkUroiIiIiIiIiINIHCFRERERERERGRJlC4IiIiIiIiIiLSBJ5kL0CkNQuHw5SVlVFeXk5VVRWRSCTZSxIRqZfL5SIjI4OcnBxyc3Nxu93JXpKIiIhIi6FwRSROAoEAixYtIjMzk/z8fLp27YrL5cIYk+yliYisxVpLJBLB7/dTXl5OcXExPXv2xOfzJXtpIiIiIi2CwhWROAiHwyxatIgOHTpQUFCQ7OWIiGyUMQa3201ubi65ubmUlJSwaNEi+vTpowoWERERkQZQzxWROCgrKyMzM1PBioi0SAUFBWRmZlJWVpbspYiIiIi0CApXROKgvLycnJycZC9DRKTRcnJyKC8vT/YyRERERFqElAhXQhG4fxpUBJK9EpHmUVVVRVZWVrKXISLSaFlZWVRVVSV7GSIiIiItQkqEK3NWws1fwP99CNYmezUiTReJRHC5UuKfl4hIo7hcLk04ExEREWmglHj3t2VH+NdoeHMeTPou2asRaR6aCiQiLZm+h4mIiIg0XEqEKwCnbQu794Frp8D0pclejYiIiIiIiIhIw6RMuOIycOvu0DkbJrwNq7XNW0RERERERERagJQJVwDy0uH+8U6wct67ENZWbxERERERERFJcSkVrgAMK4KrxsFnv8Pd3yR7NSIiIiIiIiIiG5dy4QrAEUPg4EFwx9fwW2myVyMiIiIiIiIismGeZC+gPsbAhJHw0lz48g/olZ/sFYlIS7WhiSc+n4/c3Fzy8vLo2bMnw4cPZ9SoUYwfP56MjIyErnHSpEn89ttvAFx55ZUJPbeIiIiIiDRdSoYrAH3yoV0GTPsTjhya7NWISGsTCAQoLi6muLiYBQsW8PHHHwOQn5/P8ccfz1VXXUVeXl5C1jJp0iQmT54MKFwREREREWmJUjZcMQZGdHbCFRGR5vDKK6/U3rbWsmbNGkpKSpg1axafffYZv/32G6Wlpdx555289NJLPPPMM4wZMyaJKxYRERERkZYgZcMVgBFd4P2FsNIPhVnJXo2ItHQHHnjgBh+z1vLOO+9w/vnnM2/ePP744w/23XdfPv/8c4YMGZK4RYqIiIiISIuTkg1tY0Z0ca6nLU3uOkSk9TPGsM8++zBt2rTaapU1a9Zw6KGHEoloLryIiIiIiGxYSocrQwshzQ3famuQiCRIbm4uzz//PPn5+QDMnTuX5557rt5jq6qqeOWVV5gwYQKjRo2iffv2eL1e8vLyGDJkCGeeeSbffffdBs81btw4jDG1/VbACXnWvazbhyUUCvHee+9x4YUXMmbMGIqKivD5fOTk5NC/f39OOOEEPvvssyZ/LkREREREpGFSeltQmge26gjTFa6
ISAJ17tyZ0047jZtuugmARx55hCOPPHK94wYPHlw75aeusrIy5syZw5w5c3jggQf497//zfXXX99s69t999359NNP17s/GAwyb9485s2bx2OPPcbxxx/Pgw8+iM/na7Zzi4iIiIjI+lI6XAFna9CDM6AqCBneZK9GRNqKo446qjZc+eKLLwgGg3i9a38Tqqqqol27duy+++5ss802dO3aFa/Xy5IlS5gxYwbPP/88wWCQG264gaKiIs4///y1nn/ttddSXFzMZZddxuzZs4G1m+7GDBw4cL3zZmdns+uuu7LtttvSq1cv0tPTWbp0KbNnz+app57C7/fz2GOPkZ+fzx133NF8nxgREREREVmPsdY2+OARI0bYadOmxXE56/voVzjpdXj2YNihW0JPLdJoc+fOZdCgQclehuBss4nZnO934XCYvLw8/H4/ALNmzWKrrbZa65h3332X3XbbDY+n/px60aJF7LXXXvz000/k5OSwZMkScnJy1jtu3LhxtVuDGrLGjz76iNGjR5ORkVHv46tWreLAAw9k6tSpuFwu5s+fT+/evTf5uiLrasr3MmPMdGvtiGZeUipp+DcUkVbCWks4HKaqqgqPx4PH48HtduNypXSnARFpPcymD0melP9OuG1n51p9V0QkkdxuN926/ZXorly5cr1j9tprrw0GKwA9e/bkvvvuA6C8vJzXXnutWda26667bjBYAWjfvj2PPfYYAJFIhKeeeqpZzisiIm2XtZZgMEggECASiVBTU4Pf76esrIyysjKqq6sJBoOb9YsMEZHWJOW3BeWnQ//2ME3hirRSV02GOeu/b2/VBhfCFWOTvYpNKygoqL29atWqRr3G6NGja29//fXXHHPMMU1eV0P06dOHTp06sWzZMr7++uuEnFNERFqnSCRCIBDAWovL5cIYU1utYq3FWkt1dXXt8W63G6/XW1vZUreKVESktUr5cAVgRGd48xcIR8Cd8rU2ItJa1B3BvKEfDFesWMHjjz/O+++/z5w5cygpKaGysrLeY//4449mW1tZWRlPPfUUb7/9Nj/88APFxcW1W5jieV4REWk7rLWEQiFCoVBtoLJuZUpssl3d50QiEYUtItLmtIhwZWQXePpH+GUVDCpM9mpEmldLqOBoq0pLS2tvt2vXbr3Hn3vuOU4//XTWrFnToNcrKytrlnV98sknHHXUUSxbtiyh5xURkbbDWlu7BWjdAGVjFLaISFvVIsKVEV2c62lLFa6ISGKEw+G1Kj4KC9f+5vPZZ59x1FFH1Va3DB8+nN12242+ffuSl5dHWlpa7bEHHXRQ7Ws21bx58xg/fjxVVVUADBgwgL333pstttiCdu3akZ6eXnvsaaedxsqVK5vlvCIi0naEw+Ha/imbE6zUZ0NhS1VVVe39CltEpDVoEeFK91woynKa2h67ZbJXIyJtwQ8//FC7vScrK4shQ4as9fiVV15ZG6w8+OCDnHrqqfW+zoa26jTWDTfcUBusXHrppVxzzTUb/CF0Q2sSERGpT33bgJpbLGyp27NFYYuItAYtIlwxxum7Ml1NbUUkQZ5++una26NHj15rKlAgEGDKlCkAjBgxYqMhxqJFi5p1XR9++CEARUVFXH311Rv8gbO8vJzVq1c367lFRKT1ikQiBIPBzd4G1FQNCVtiY58VtohIKmsR4Qo4fVfeng9Ly6FzTrJXIyKt2dKlS5k4cWLtn08++eS1Hl+1ahWhUAiAvn37bvS13nvvvU2er+5vBmMl2BuyfPlyAHr37r3R3yh++OGHazXkFRERqY+1tnYbEKy/jSfR6gtbwuFw7f+74IQtscqW2PQiEZFkazHhSt2+K/spXBGROCkvL+ewww6rbWY7aNAgDj300LWOyczMrL29YMGCjb7W7bffvslzZmdn1972+/1r/XldmZmZBAIBFi5cuMEgJhwOc/3112/yvCIi0ratuw0oFUOK+nq21A1bjDFrbSNS2CIiydJiBhsPLoRMr9N3RUS
kuVlreeeddxgxYgRTp04FIDc3lxdeeGG9CpG8vDy22GILAKZNm8Yrr7yy3utVVFRw6KGHsnjx4k2eu3fv3rW3Z8yYsdFjR44cCcDKlSu544471ns8GAxy6qmnMm3atE2eV0RE2q5IJEIgEGhUsJLsyhaXy4Xb7a7dIhQOh6mqqqK8vJyysjL8fj81NTWEw+H1RkeLiMRLi6lc8bhg644wTeGKiDTSq6++WnvbWlvbl2TWrFl89tln/Prrr7WPd+vWjWeeeWa9RrYx55xzDueeey4AhxxyCEcffTRjxowhJyeHH3/8kUmTJvHnn39y3HHH8fjjj290Xbvuuit33XUX4GxBuuCCC+jZsydutxuAfv360a9fv9rzfvDBBwD84x//4NNPP2XPPfekffv2zJs3j8cff5x58+axyy67MG/evLUmHomIiNTdBhSvprWJVF9lSygUWmubU92eLapsEZF4MZuT5o4YMcIm87ehN38B902DuWdBeouJhaQtmjt3LoMGDUr2MoTN/+1afn4+xx13HFdddRX5+fkbPM5ay7HHHstTTz21wWMOOOAAnnnmmdptRGPHjuXTTz9d77hwOMy4ceNqK2bWdcUVV3DllVfW/vmSSy7hhhtu2OB5d9xxR1555RVGjhzJokWL6NmzJ7/99tsGjxfZkKZ8LzPGTLfWjmjmJaUS/TpcWhxrLYFAoElNa621tcFFSxBrkBsTC1u8Xi9ut1thi0jLktL/WFtURNGvHUQs/L4G+rdP9mpEpCXzer3k5uaSm5tLr169GD58OKNGjWLfffclIyNjk883xvDkk08yfvx4Jk6cyMyZM6msrKSoqIitt96aY489lsMOO6xBa3G73XzwwQfceeedvPbaa/z000+UlZURDofrPf76669n55135p577uHrr79mzZo1dOjQgUGDBnHkkUdywgknrDXdSEREJLYNKNavq60ECrGeLDGxcKhuZUvdsc8KW0SksVpU5cp3y2D/5+DBfWHPjQ/oEEkqVa6ISGugypWNUuWKtAjN3bQ2Vv3SWgKIWGVL7D2Ry+XC6/XWVra0pSBKpAVI6X+MLepXm70KnOtfS5O6DBERERGRlNcc24Bau/oqWwKBADU1NbU9aepuI9LnUUQ2pEWFK3lp0D4Dfi1J9kpERERERFLXuk1dFQg0TN2wJVbNEggECAQCAOuFLS29IbCINJ8WFa4A9M6H30qTvQoRERERkdTT3NuANnae1h7YxD4+hS0i0hAtMlyZ/HuyVyEiIiIikloikQjBYDCu24Bir7s5fRtbi/rCltg2orphy7oNckWkbWh54UoBvDAX/AHI8iV7NSIiIiIiyWWtJRwO124D0hv6xFg3wIqFLTU1NdTU1AAKW0TakhYXrvTKd65/LYWhRclciYiIiIhIcsVGC4fDYfVWSbJNhS3W2rW2EHk8Hv19ibQiLS5c6ZPvXP9WqnBFRERERNquSCRCIBCo7X+iN+qppb6wJRKJUF1dXXuf2+1eq7JFf4ciLVeLC1dilSsLS5O5ChERERGR5Fi3aa22mrQMCltEWrcWF65keKFztsYxi4iIiEjbE2ugGs+mtZIYCltEWpcWF66AMzHo19Jkr0JEREREJHFi1SraBtQ6bShsqaqqWmtSkcIWkdTUYsOVt+YnexUiIiIiIvGnbUBtUyxsif191xe2eDye2ovCFpHkapnhSgGUVkNJFRRkJHs1IiIiIiLxEYlECAaD2gYk9YYtsRHcdcOWWGWLy+XS14tIArXIcCU2MejXUoUrIiIiItL61H3jDKhaRdazobCl7tYxhS0iidMiw5VeBc71b6UwvHNSlyIiIiIi0qystQSDQcLhsKpVpMHq69kSC1tij9fdRqSwRaR5tchwpXsuuI3GMYuIiIhI6xKJRAgEAindtNYYg7U22cuQTagvbAmFQrXVUApbRJpXiwxXfG7olqtxzCIiIiLSOqhprcRbfWFLMBhcL2zxer243W6FLSKbqUWGK6B
xzCIiIiLSOlhrCQQCalorCWWMwe121/65vrCl7thnhS0iG9diw5U+BfDNn2At6N+4iIiIiLREsaa1qbwNSNqG+sKWQCBATU0N4DRVjoUtHo9HX68i62ix4UqvfKgMwopK6JiV7NWIiIiIiDSctgFJqttY2BILVrxeb+02IoUt0ta12HCldhxzicIVEREREWk5IpEIwWBQ24CkRakbtsQaGgcCAQKBAOBUtqzbs0WkLWmxX/GxcczquyIiifTpp5/W/iB85ZVXJns5cfPbb7/VfpwnnHDCRo+dM2cOJ598Mv369SMzM7P2eQceeGDtMbH7xo0bF9d1J0OvXr0wxtCrV69kL0VEUlysWqWmpkbBirRosa9dt9u9VpASCATw+/2UlZVRVlZGZWVlbZAo0tq12MqVLtmQ5la4IiKbb8mSJbz00kt89NFHzJkzh+LiYvx+P3l5eXTr1o2RI0ey9957M378eHw+X7KXm9ImT57MXnvtRXV1dbKX0mwmTZrEb7/9BtCqAzQRSax1twEpVJHWJPb1XLeyJbaNqG5ly7oNckVakxYbrrhd0DNf45hFpOHWrFnDZZddxsSJE2ubs9VVXFxMcXExs2bNYuLEiRQWFnLZZZdx5pln4vV6k7Di1HfOOefUBivHHXcc48aNo6DAKS3s3LlzMpfWaJMmTWLy5MmAwhURaR7aBiRtTX1jn6211NTU1NsgV2GLtAYtNlwBZxzzQoUrItIA8+fPZ7/99uOnn36qvW+77bZj9913p1evXuTl5bFq1SoWLFjAu+++y48//sjKlSs577zz2HLLLVvldpaN6dWrV+1+6g35448/+OGHHwDYc889eeyxxzZ47KZeqyWLVbmIiKzLWls7DUhNa6Uta0jY4na7a/u1xKYRibQkLT5c+eQ3CEecShYRkfqsWrWKXXfdld9//x2ALbfckgceeIAddtih3uNvvvlmvvnmGy699FI+/PDDRC61RVm8eHHt7W222SaJKxERST2xLRGqVhFZX31hSyQSobq6unYseSxsiVW26N+QpLoWHa70yodAGP6sgO65yV6NiKSq448/vjZY2WGHHXj33XfJzd34N43tttuODz74gNtvv11bgjag7taqtLS0JK5ERCS1RCIRAoFA7ZvE1vamsLV9PJJ8GwtbYhS2SKpr0fUefWITg7Q1SEQ24Msvv+Stt94CICcnh2eeeWaTwUpdF1xwATvuuONmndNay5QpU7j00kv529/+RpcuXUhLSyMrK4vevXtzxBFH8MYbbzTotUpLS7nxxhsZO3YsRUVF+Hw+cnNz6dOnDzvssAMTJkzgnXfe2eC2mxkzZnDGGWcwbNgwcnNz8Xq9FBUVMXjwYPbaay+uueYa5s2bt97zNjYtaNy4cRhj2GWXXWrvu+qqq2qPr++NxOZMC/rll1/417/+xciRIyksLMTr9ZKXl8fw4cOZMGECH330Ub0fb1VVFa+88goTJkxg1KhRtG/fvva5Q4YM4cwzz+S7777b4HljH1es30rddde9rNuHZXOmBb3//vsce+yx9OnTh8zMTHJychg4cCBnnHEG06dP3+hz6/s7KS4u5sorr2TYsGHk5OSQk5PD8OHDueGGG6isrNzkekSkeVlrCQaDa/WU0BtAkc0X20ZXdxpRLGypqKhgzZo1lJeXU11dTSgUatXbj6XlaNGVK73znetfS2HnnslciYikqjvuuKP29oknnkjPnvH/ZnHSSScxadKk9e4PBAL89ttv/Pbbbzz33HPstddePPfccxsMe7799lv23XdfVqxYsdb9wWCQ8vJyfv31V7766ivuu+8+SkpKyM/PX+u4K6+8kquvvnq9HzhWrlzJypUrmTt3Lu+99x7Tp0/n1VdfbcqH3CxCoRAXXXQRd999N+FweK3HysrKmDlzJjNnzuS+++7j008/ZezYsWsdM3jw4Hr7n5SVlTFnzhzmzJnDAw88wL///W+uv/76eH4o66moqOCoo46qN1T7+ee
f+fnnn3nwwQc555xzuP322xvUl2HatGkceOCBLFmyZK37Y5+n559/no8++oh27do128chIhumbUAi8bOhypaqqqq1JhWpskWSqUWHK4WZkOXVOGYRqZ+1lo8++qj2z8cee2xCzltVVUVaWhpjx45lu+22o2/fvmRlZbFy5Up++eUXnnjiCVavXs27777LcccdV2+wUVlZyUEHHVQbrOy8887su+++9OjRA5fLRXFxMT/++CMfffQRP//883rPf+2117jqqqsAyMjI4Mgjj2T77benXbt2VFdX88cffzBt2jQ++OCDzf74rr322trzX3755QAcfvjhHHHEEZv9WjHWWg4++GBef/11wPkB6cADD2SXXXahqKiIysrK2jBo1qxZG6xcadeuHbvvvjvbbLMNXbt2xev1smTJEmbMmMHzzz9PMBjkhhtuoKioiPPPP7/ej+uyyy5j9uzZALzyyivrnWfgwIGb9bGFw2H23ntvpk6dCkB+fj4nnXQSw4cPJxQKMXXqVB5//HECgQB33XUXVVVVPPjggxt9zcWLFzN+/HhWr17N0UcfzS677EJ2djZz5szh3nvvZdWqVcyaNYvzzz+fxx9/fLPWKyKbLxQKEQwGgfXfBIpI84v9O4v9MqK+sMXj8dReFLZIIrTocMUY6F2giUEiUr+ffvqJVatWAU7AsPXWWyfkvBMmTOCBBx5Yr5Ik5rrrruPEE0/khRde4LXXXmPy5MnrVWG8/fbbtRUJZ555Jvfdd98Gz/fVV1+RkZGx1n0TJ04EnJDiww8/ZPTo0fU+t7q6mu+//76hHxoAY8aMAVjr4xs4cCAHHnjgZr1OXTfffHNtsNKjRw/efPNNhg0btt5x//3vf5k+fTodOnRY77FJkyax22674fHU/1/bddddx1577cVPP/3Ef/7zH04++WRycnLW+7jqVjs15WOKueWWW2qDlQEDBvDxxx/TpUuX2sePP/54zjrrLHbbbTdWr17NxIkTOeCAAxg/fvwGX/Pjjz8mPz+fqVOnMmrUqLUeO+GEExg+fDilpaU8/fTT/Pe//13rfCLSfKy1hEIhQqFQmwtVYv1kRFJBfWFL3Uld4IQtsWlEClskHlp0zxWAXnnw+5pkr0JEUlHd7RI9e/bc4Jvu5rbTTjttMFgByMrK4uGHHyYrKwuAJ554Yr1j5s+fX3v71FNP3ej5tt9++/UaysaeP2TIkA0GKwDp6elst912G339eKuoqODGG28EwOfzbTBYidl2223r3d611157bfTvuGfPnrUhVXl5Oa+99loTV75pgUCA22+/HXB+qHvhhRfqDTq22WYb/ve//9X++YYbbtjka991113rBSsAvXv3ZsKECYBTNVO3ektEmk+saW1bDFZEUl0saIlVrbhcLsLhMFVVVbU9WyoqKqipqSEcDqtnizSLFl25AtAlBz5YCNY6lSwiLU3xpXdR8+P6DUVbs7ShW9DhunPjfp5Y1Qqw0bAjGXJychg2bBhfffUVX3/99XqPZ2Zm1t6ePXv2Zo86jj3/jz/+YM2aNeTl5TVtwXH0zjvvsHr1agCOOuqojQYrTVU3aPr666855phj4nYugC+++ILly5cDsPfee2/0YzvkkEPo168f8+fP5/PPP2fFihUUFRXVe2xhYSFHHXXUBl/rb3/7G9dddx0Ac+bMacJHICLrqvsbcaBBPZJEJLnq69kSDocJhUK1j9fdRqRm1NIYLf5/gy45UBOG1VXJXomIyF9qamp44oknOOSQQ9hiiy3Izc2t/Y86dvnqq68AJwBZ12677Vb7n/oZZ5zBVVddVe9Unw3ZfffdAVi9ejVjx47lmWeeoaysrBk+suYX2zIDsP/++zfptVasWMEtt9zCHnvsQbdu3cjKylrrc56enl57bH2f9+b2zTff1N7eY489Nnl87O8NqDd0ixkxYgRut3uDj3ft2rX2dkmJ9s6KNJfYNKDYVgMFKyIt07rTiIwxhEIhqqqqKC8vp6ysDL/
fr8oW2SytonIF4M8KaJ+58WNFUlEiKjjaqvbt29feLi0tTdh5f/jhBw4++OAGhyH1hR6DBw/m4osv5oYbbsDv93PllVdy5ZVX0r17d0aPHs3OO+/M+PHjNzj96OKLL+bNN99kzpw5fPfddxx11FG43W623nprdtxxR3bZZRf23HPP9Xq1JEPdkGPQoEGNfp3nnnuO008/nTVrGrZXNBFh09KlS2tv9+/ff5PH1z2m7nPXVV/PmbrqbhOrrq7e5HlFZNNi24BivUb0W22R1qO+ypZYkBp7XJUtsiktP1zJdq6XlsOw+qunRaSNqtvbYtGiRYRCobj3XVm9ejW77bZb7ZSf7t27s++++zJw4EAKCwtJT0+v/c84NpUmEonU+1rXX389I0eO5MYbb6ytYli8eDHPPfcczz33HGeffTZ77bUXd9xxx3pv3AsKCvjqq6+48cYbeeihh1i+fDnhcJjp06czffp07rrrLnJycjj//PO57LLL8Pl8cfysbFzdkCM7O7tRr/HZZ59x1FFH1X4uhw8fzm677Ubfvn3Jy8tbK2w46KCDANYb9xwP5eXltbdjPXY2pu7HX/e569Jvy0USZ92mtfr3J9L6GWPWqhCtL2ypO/ZZYYtAKwhXOkcrV5Zs+GdQEWmjBg0aRLt27Vi9ejVVVVXMmjWLESNGxPWc99xzT22wcvzxx/PQQw9tdHrNphx00EEcdNBB/Pnnn0yZMoUvvviCTz/9lO+//x5rLe+88w5ffPEFX3755XpVHzk5OVx77bVcffXVfPfdd3z++edMnTqVjz76iOLiYsrLy7nmmmv45ptveOedd5L2Q0Fubm7t7YqKika9xpVXXlkbrDz44IMbbALs9/sb9fqNVXcaUUPOXffjr/tcEUkOay2BQIBIJKJqFZE2rL6wJRAIUFNTU/u9IRa2eDwefb9oo1p89N4+A9LcsLRxP4+LSCtmjGG33Xar/XN9U3ma24cffgg4k2HuuOOOjVbKLFq0qMGv26VLFw4//HDuvPNOvvvuO3755Zfaj23NmjVcfvnlG3yuy+Vim2224eyzz+bZZ59l+fLlvPLKK7Rr1w6A9957j7feeqvBa2lu3bp1q709d+7czX5+IBBgypQpgNOLZGPTlTbnc94cOnfuXHu7IdvE6h6j8ckiyRUOh6mpqVGwUg99LqSti4UtscoVYwyBQAC/38+aNWsoKyujsrKSYDBIJBJRz5Y2osWHK8ZA52z4U5UrIlKP8847r/b2o48+Gvc317HJMO3bt9/ohKKZM2eycuXKRp9niy224MUXX6z9LUrdprCb4nK5OPDAA7n66qtr79uc5ze3nXbaqfb266+/vtnPX7VqVW23/759+2702Pfee2+Tr1e35L+pPwzVHXP9wQcfbPL4uscke0S2SFsVK/8PBAIAKvcXkU3aWNhSVlZGeXn5WmGLtE4tPlwBp6mtwhURqc/o0aPZZ599AKeHxZFHHrnRXhbruuOOO/jiiy8afHxsBPKKFSs2ep66wUZj5eXlUVBQAFAbLmyOXr161d5uzPOby957711bRfP000/zww8/bNbz646tXrBgwQaPKy8v5/bbb9/k69Xte9LUbUSjR4+mU6dOALz11lsbHYv88ssv11aujBkzZoNjmEUkfmJNa2P9VRSqiMjmin3viE0iiv3Spm7Ysm5li7QOCldEpNV77LHHareefPnll4wZM6Z2DPKGfPPNN+yxxx5ccMEFtb+9bIiRI0cCzm8+L7vssvUet9Zy+eWX8+qrr270de666y5eeuml2sZp9XnhhRcoLi4GYKuttlrrsdNOO40ff/xxg88NhUJMnDix9s/rPj+RsrKyuPjiiwHnB4/99ttvowHLrFmz1qpAysvLY4sttgBg2rRpvPLKK+s9p6KigkMPPZTFixdvcj29e/euvT1jxowGfxz18fl8XHDBBYDzOT/00EPrnQL0/fffc/rpp9f+Ofb5EJHEiDWtjW0DUrWKiDQXhS1tR4t
vaAtOU9vlfghFwNMq4iIRaU4dOnTgo48+Yr/99uOXX37h+++/Z4cddmDUqFHsvvvu9OrVi9zcXFavXs2CBQt49913N7t6Iuass87ikUceIRwOc9dddzFr1iz+/ve/06lTJxYvXszTTz/NzJkzGTx4MBkZGUyfPr3e15kxYwbnnXceBQUF7LHHHmy77bZ07doVl8vFsmXLeP/992u3uBhj+Pe//73W8ydOnMjEiRMZMmQIu+yyC0OHDqVdu3b4/X4WLlzIs88+W1sl0b9/fw455JBGfbzN5Z///CdTp07l9ddfZ9GiRWyzzTYcdNBBjBs3jqKiIqqqqvj55595//33mTZtGp988slaY6jPOecczj3XGWt+yCGHcPTRRzNmzBhycnL48ccfmTRpEn/++SfHHXccjz/++EbXsuuuu3LXXXcBcPLJJ3PBBRfQs2fP2i1Y/fr1o1+/fg3+2C688ELeeOMNpk6dypw5cxgyZAgnnXQSw4cPJxQK8fnnn/PYY49RU1MDwKmnnsr48eM36/MnIo0X2wYUDodVrSIicRf7HhP7ucJaW9sgt+52xHWnEUnqaxXhStcciFhY4XeqWERE1tW/f3++/vprLrnkEh5++GECgQBff/117Yjj+nTq1InLL7+cMWPGNPg8W2+9NXfffTdnn302kUiEzz77jM8++2ytYwYNGsRrr73GKaecssHXif3HW1JSUjt6uT5ZWVncf//9azXurWv27NnMnj17g+fZcsstee2118jIyNjUhxZXxhhefPFFzj//fB544AHC4TAvvvgiL774Yr3Hr/tDxtlnn83XX3/NU089RSQS4YknnlivgfEBBxzAAw88sMlwZfz48YwZM4apU6cyf/58JkyYsNbjV1xxBVdeeWWDPza3280777zDkUceyZtvvklJSQm33nrrescZYzj77LO54447GvzaItI0sW1A1loFKyKSFOt+74mFLTU1NbW/eHG73Xi93tq+LvpelZpaRbjSObo9fkm5whUR2bD8/Hzuu+8+LrnkEl588UU++ugj5syZQ3FxMZWVleTl5dGjRw9GjhzJ+PHj2WeffTY67WdDzjzzTLbZZhtuu+02pkyZwqpVqygoKKBfv34ccsghnH766Wv1CanP/fffz+GHH84nn3zCt99+yy+//EJxcTHhcJj8/HwGDhzI7rvvzimnnFLvVJklS5bw7rvvMmXKFL7//nt+/fVXysrK8Pl8dOzYkW222YZDDjmEww8/fK3Rgsnk9Xq59957OfPMM3nooYf4+OOPWbx4MeXl5eTk5NC3b19Gjx7NoYceulYTXHB+MHnyyScZP348EydOZObMmVRWVlJUVMTWW2/Nsccey2GHHdagdbjdbj744APuvPNOXnvtNX766SfKysoIh8ON/tiys7N54403eO+993j88cf5/PPPWb58OW63m65duzJu3DhOO+00tt1220afQ0QaLrYNKNZbRb8VFpFUUV/YEolEqK6urr0vFrbUbaAryWc2ZxLCiBEj7LRp0+K4nMb5ZRXs/iTcvRfsPyDZqxFxxskOGjQo2csQEWmSpnwvM8ZMt9aOaOYlpRLN1WyhYuX3GrHceOFwmGAwqFBKJAlilS1138e3obAlpT+wVlG5EqtWUVNbEREREdmQWCigbUAi0lKpsiV1tYpwJdsHuT74syLZKxERERGRVKNtQCLSWm0obKmqqlqrea7ClvhrFeEKONUrS1W5IiIiIiJ1RCKR2vGmqlZpHvociqSu2Pe5WIissCVxWk240jnHaWgrIiIiImKtrd0GBOv/dldEpC1oSNji8XhqLwpbGq/VhCtdc+C75clehYiIiIgk27rbgPRGQUTEUV/YEg6HCYVCtcd4PJ7ayhaXy6XvoQ3UasKVztmwugqqgpDhTfZqRERERCQZtA1IRKTh6uvZUjdsMcasVdmisGXDWk24EpsYtLQC+hQkdy0iIiIiklh1twGpaa2ISOPUF7aEQqG1tlgqbKlfqwt
X/ixXuCIiIiLSllhrCQQCqlYREWlm9YUtwWBQYUs9WmW4IiIiIiJtQyQSIRAIYK1VsCIiEmfGGNxud+2f1w1bXn/9dbp3784uu+ySrCUmTasJVzpmOddLK5K7DhERERGJv3Wb1mobkIhI4q0btsydO5fMzMwkrugvxhiXtTbSiOcZa63d3Oe1mnAlzQOFmRrHLCIiItLaaRuQiEhqqqioICcnJ9nLACAWrBhjOgLZgAGCQA0QiF6CQLBuCNOYYAVaUbgCztagpQpXJEXEypNFRFqiRv5cIRJ36zZW1P+1yaHPu4jUp7Kykuzs7GQvAwBjzAHAdkBvoBBIiz4U5K9wJQDUGGMCQHX0sWJr7TWbe75WF678sirZqxABl8tFJBJZq0RORKQliUQi2mYhKWXdbUB6cy8iknr8fj9ZWVlJXYMxJh24GjgRaN/Ap1kgBHiBeUAbD1eyYfIisBb0/60kU0ZGBn6/n9zc3GQvRUSkUfx+PxkZGclehgjghH3BYFDbgEREUpzf70/qtiBjjAcnVPln9C4LzALWAC4gs84lI3pJB3w4wQo4W4g2W6sKVzrnQGUQymogLz3Zq5G2LCcnh/LycoUrItJilZeXp8yeaWm7rLWEw+HabUCqphIRSW1+vz8p24LqNKHtAZwRvXspcA/wHrAAKMcJWzx1Lr7oJS16ycbZGrTZWlW4EhvHvKRc4YokV25uLsXFxZSUlFBQUJDs5YiIbJaSkhIqKyvp1KlTspcibVhsvGc4HFa1SgpSbzkRqU8SK1cMTnDSBRgGVAAPWmtvqOfYII0MUDamVYYrS8thcGFy1yJtm9vtpmfPnixatIjKykpycnLIysrC5XLpBxERSTnWWiKRCH6/n/LyciorK+nZs6f6RknSRCIRAoFA7Rt4/d+ZevR3IiL1qa6uJi0tbdMHxk9e9Ho+8DqAMcZtrQ3H+8StMlzROGZJBT6fjz59+lBWVkZpaSlLly4lEtnsMesiIgnhcrnIyMggJyeHTp06KViRpFi3aa22AaUmBSsisjFJ/t4d+wHGD5QBJCJYgVYWrhRmgtcFSyuSvRIRh9vtpqCgQFuDRERENsFaSyAQUNNaEZEWKsnbBW30ehWwDOiAMylogTHGZa2N+2+5W9WvA1wGOmarckVERESkJQmHw9TU1ChYERFp4ZIVsESb2QJ8C3wADABGRh+LGGNcJs4La1XhCkDXHKfnioiIiIiktljT2kAgAKDeZCIiLVgsIE+WaIVKALgKmA2cY4zZDZyApU4AExetalsQQOdsmLY02asQERERkY2JRCIEg0FVq7RQmhYkIuuqrKwkKysrKeeOjmKOGGNG4Yxjfh+4ALjfGHMP8A1QgtOLJQDU8NfUoGBzBC+tLlzpmgNvzoNwBNytri5HREREpGWz1hIOhwkGnSmYClZaJv2dici6/H5/0sIVnEa2IeBo4ERgIbAC6AvcDqwGfgUqccKVAE6wUgMEjDE1OMFLJnC/tXbG5i6g1YUrnXMgFIHiSqf/ioiIiIikhtg2oHA4rFBFRKSVSXK4ErMNkAUMW+f+dtHLhkSAapxw5U1A4UqXaKCypFzhioiIiEiqiEQiBAKB2u0kClZERFqXiooKsrOT8ybcWhuK3rwDmIxTyZKHE7RkR68zotdZOCFKevS+dJxsJDP6GrHX2iytL1zJca41jllEREQk+epuAzLG4HJp37aISGuUzJ4rMdbal4CXNvS4McaFk4N4AG/0klbnkgf82Jhzt9pw5U9NDBIRERFJKmstgUBATWtFRNqAZFaubEqdMcw2OlEo0NznaHXhSm4aZHkVroiIiIgkU6xaRduARETahhTpubKWaKUK1tpIvM/V6sIVY5ymtgpXRERERBLPWksoFCIUCmkbkIhIG+L3+1OmcsUY47LWRmKhijEmHWfbjyE6itlaG27Oc7a6cAWgYxasqEz2KkRERETalkgkQjAY1DYgEZE2KFXClViwYozJAXYEegM9gRycDGQ
1sMgY8zkwz1pb3RznbZXhSlEWfLsk2asQERERaRvqNq0FFKyIiLRBfr+foqKiZC+DaLCyG3ACsCfQfiOHzzTGXGOtfbWp522V4UqscsVaZ5uQiIiIiMTHutuAFKqIiLRNKVS5si9wD9AjelcEWArU4Ixojo1kzgS2AV42xtxjrT23KedtteFKIAyl1VCQkezViIiIiLRO2gYkIiIxqRCuGGN6A/8DOkfveg/4FpgNrMAJWtoBQ4F9gS1xerGcboxZbq29rrHnbp3hSvTvc7lf4YqIiIhIc6u7DUhNa0VEBJIfrhhjfMBJOMFKNXAHcIu1dnU9h78CXGOMOQJ4AMgFzjHGPGKtXdqY87fK/wmLotOflvuTuw4RERGR1sZaSyAQqA1WVK0iIiKQEqOY2wFHABZ4x1p7ibV2tTHGY6JiB8ZuW2ufBQ6L3p0FHNXYk7fKcKVjLFypSO46RERERFqTSCRCTU2NtgGJiMh6/H4/OTk5CT9vndAkD+iLs/3nyehjLmttyEbFnlP3NjAdeB8nXNm+setoleGKKldEREREmo+1lmAwSE1NDQAul0vBioiIrKWysjLZPVdyo9drgN/AmRzUgOdVATOit/Mae/JW2XMl3QN5abBC4YqIiIhIk8S2AalaRURENiZZlSuAwdkKlBb9swvIAaeqZZ0qlfq4+CsbqW7sIlpl5Qo4TW1VuSIiIiLSeKFQiOrqagUrsh59LYjIumpqakhLS9v0gc0vFp4EgeVAH2B09D632cA3rDr3FwJbRW//2dhFtN5wJUuVKyIiIiKNEdsGVHcakN5Mi4hIKqpTmfI78CVOznGkMWb7WL+VTTxv9+ilmL+2B222Vh2uqKGtiIiIyOaJRCIEAgFCoZCqVUREpEGstUn9/yK6/Wcp8EH0rmHA/caY44wxA40xucYYd3RokNcYk2mMKTDGXALcFn3Od8A7jV1Dq+y5Ak5T2xWVELHg0s8EIiIiIhtlrSUcDhMMBgGnaa2IiEhDJTNgqVOF8hSwA3AMzlafR4ApwExgMVCJMxVoCLAX0Dn6vFLgSWvt4gb2aVlPqw5XQhEoqYL2mclejYiIiEjqim0DCofDqlaRBkn2b6lFJLWEw+Gkh/LRUKQsWo1SARyJM/1nbPSyIWuAi6y1j8F6Y5obrNWGKx3rjGNWuCIiIiJSv9g2oNibZb1hFhGRzZUCY5ix1lpjjMta+4cx5iLgG2APoD/QHsgE3HWeEgY+BK621v7c1PO33nAl+ve63A+DC5O7FhEREZFUY60lFArV9lZJ9m8cRUSk5aqoqCArKyvZy8BaG4kGLH5gEjDJGLM1zgShIsCLszVoIfC1tbayuc7desOVWOWKmtqKiIiIrMVaSyAQ0IhlaRJtDRKRGL/fnxLhCjgBC4Axxm2tDVtrZwGz4n3eVhuuFEa3Ai3XOGYRERGRWrGmtdoGJI0V+7ppZFsCEWmFKioqErotyDj/eZlYkBK9z4UzEdnWuc8N1P2PzsYuje2tsiGtNlxJ80BBOqxQuCIiIiKibUAiIhI3ia5ciQYjdp37IkCk/mdsWmOnBMW02nAFnK1BCldERESkrYtEIgSDQW0DEhFJktDCP/Df8hi+3bYn4++7Jns5zc7v9yescsUY8yp/jVA+yFr7Z/T+f+NMB1oN+KOXiugl9ufK6KUKqAaqrbVBaPyUoJjWHa5ka1tQXaur4IvFMHWxM6L61j0g25fsVYmIiEi8WGtrtwEBClZERBIsvHwV/ruepvr59zDpaXh32CrZS4qLRIYrwAigS/R23f/Uzuav0KU+4eglFL0EgYAxpgYneMkEdrTWLmvMolp1uFKUBT+vSvYqkqsqCPd8C58ugtkrnLqpHB/4g/Dvj+CuvUA/Y4mIiLQ+1lqCwSDhcFihiohIgkXK/FQ++CKVj7wK4TAZx+5L1llH4OqQn+ylxUVlZWUitwVV4lSjZOOEIjHpm3ieO3pZt8TA4mwncuOEL43S6sOVlX4IR8DdRrc
VPzDdCVe27wr/2AHGdIctO8L/psNNX8DILnBc6wxPRURE2qxIJEIgEFDTWhGRBLM1AaqeeBP/fc9hS8tJO2AXsi84BnePjRVUtHwVFRXk5uYm6nTbABlAprW2tM79h+NsC8oCcqLX2dFLVp3rrNjzo9dpOIFLPk5w0yitOlzpmAVh62yHKUyNqVAJVVoND8+EPfvCg/uu/diZI+DbP+GaKbB1JydwERERkZat7jYgNa0VEUkcGw5T/eon+G9/gsifK/HtvC1ZF52Ad0jfZC8tIfx+P126dNn0gc3AWhvrn7Lu/R829jWNMT4gPfrajdLqwxVw+q60xXDlwelQEYB/bL/+Yy4Dt+8B+zwNZ74Nbx8JeZsqohIREZGUZa0lEAioaa2ISAJZawl88i0VNz1K+JdFeIZtQe5NF+AbvXWyl5ZQCe65sknREcwbUjtpKNbE1lobAAJNOWer/nVGx+jfbVtsaruqEh79DvbtDwM71H9MQQbcuw8sr4ALP4DmnfItIiIiiRKJRKipqVGwIiKSQMEZcyk94l+sOeVKqAmSe8+/KXj1jjYXrIDTcyWZ4Yoxxm2M2dsYM8YYk2OtDW/kErFRzbmGNlG50hbHMT8wHapDcP6ojR83vDNcMgau+gwenAGnb5uY9YmIiEjTWWsJhUKEQiFtAxIRSZDQgsX4b36Mmve/wNWhgOyrJ5Bx+J4Yb6t+e71RKVC50g54EvgZmADMbOgTjTFDgYOApdbahxq7gFb9t1+Y6Vwvr0juOhJthR8e/x4OGgj92m36+BO3hi/+gNu/gmOGQZbGM4uIiKS8SCRCMBhUtYqISIKElxXjv/Mpql/4AJOZRtY/jiXjxANxZWUke2lJV1FRkchpQfXJBQqALXAa1DaIMcYFbAdcBcwFFK7Ux+uG9hltr3LlvmkQDMN52zXseGPg9OHwwUJ4Zz4cMji+6xMREZHGq9u0FlCwIklhjKGZK+pFUlakrILKB16g8tHXIBIh47j9yJpwBK72ecleWsrw+/3k5OQk/LzGGBPd3hNLuEqAqs15Cf4a4RxqylpadbgCztagttRzZWk5PPUDHDoYeuY3/HkjukDPPHhxrsIVERGRVLXuNiCFKiIi8WNrAlQ9/oYzVrnMT9r+48j+x7G4u3dK9tJSTqJ7rhjnP0AX4MZpRFsQfShY5xgvECHavJY6TWzrhDIunBHN4AQzjdbqw5WiNhau3POt05j2nAZWrcQYAwcPgtu+gj/KoFvCRpSLiIhIQ2gbkIhIYthwmOpXPsZ/+5NElq7ENzY6Vnlw2xir3BiJ7rkSDUbC0Qv8NZq5DFgUPSZYz1PrPh+c7USxv9gVTVlTqw9XOmbDnOJkryIxiivhudlwxNDGhSN/j4Yrr/y0+eGMiIiIxEfdbUBqWisiEj/WWgIff0PFzZP+Gqt8c9sbq9wYoVAIny8xzTuNMT2Aa4GdgOVAKU5DW4CewA3GmEU4gUsZUF7nUhG9vwbwAocBx+BUuPzSlHWlbLgSLi2n8uOvSRvSF2//Xo3+7UzHLCd0CEXA08p/FpnyOwQjcHgjt/V0z4Xtu8JLc+HskU41i4iIiCSPtZZAIKBqFRGROAtOn0PFjY8SnDYbd68u5N7zb9L2HqPvu6kpA6fapGf0EhMCOgGnN+I1fwE+acqiUi5csdbif+NTii++g/DK1QC4O7YnY6fhZIzZlsxdR+Hp1KHBr1eUBRELqyqdKpbWbMrvUJAOQ4oa/xoHD4KLPoQZy2Dbzs23NhEREdl81loFKyIicRSa9zsVt0wi8MFXuDoUkHPNBNIPa9tjlTdXbIdNAv+fCgELgC44/VbcOD1XCoBqoDJ6nBunYa1rA9cWp3olADxkrf2oTi+WzZZSXzGhZcWs/NetVL4zFd+W/Sm6/3JCi5dRNWU6VZOnUfHiB5h0H0UPXEH2+J0b9Jodo9OgVvhbd7hiLXy2CHbqAa4mfE3vswVc/im8NEfhioiISLJpG5CISHyElxbjv/N
Jql/80BmrfOFxZJ54ICYzfdNPlvUkeHrYQuAcnEa0WTjbgq4EzgCmA4/gTAzKBTJxKl0y61wycIIXP8745bettfNgrV4smy0lwhUbiVD+5JusuvI+bDBIu/+cQf6Zh2M8zvJyj9nXKYudu5CVF97M8hMvI3TtueSfdsgmXzsWriz3w7B4fhBJ9lMxrKx0wpWmyPbBXn3hjXnwn7GQnhJfISIiIiIiIk0XWVNO5f0vUPnY685Y5eP3J+uswzVWuQlCoRAeT+LeOEYDkDXRCwDGmF+jN6cDk5oSkjRWSrx1rpk5l5UX3kz6mOEU3XoR3j7d1jvGGEPa4L50eekOVpx5NasuvZPQ4qW0v2oCZiO/0YlVq7T2iUGf/e5c79xz48c1xCGD4NWf4cOFsG//pr+eiIiIiIhIMtnqGiofe53K+1/AlvtJO2Ac2RdorHJz8Pv9ZGVlJeXcdbbxPA98DywBfMaYAM72nxi77u3mDmBSIlxJ33YIXV69i/TRW29yn5YrM52Oj1zDqsvuZs0DzxNasoKi+y7DlZ5W7/EdMp3NVMsr4rDwFDLld+jfHjo1w9an0d2d13lprsIVERGRZFKfFRGRprGhMNUvfYj/zieJLFuFb+wIsv51At5BfZK9tFajoqIioWOY64oFJNbaRURHMMcYYyIbOj4eUiJcAcjYcZsGH2vcbtpffx6eHp1Y9Z97WbamnM7P3VK7jaguj8sJWFpz5Up1CL5ZAsds2Tyv53bB3wfC/6bDSj8UJieEFBERERERaRRrLYEPvqTilscIz1+MZ6v+5N52Eb7tm+lNk9RKZuVKXcaYWKWKsdaGE701KGXClc1ljCH/zCNw5eaw8vz/svq/D9P+svonLhVlOQ1tW6uvl0BNGHZuYr+Vug4eBPdNc7YHnTq8+V5XRERENo8xJtGNAkU2SVVVksoC3/xIxU2PEpoxF3fvruTedylpe47W122c+P3+pFWurMtaGwEwxuQAvYFYh+JqnCa31dFLTfQSjD2nqVpsuBKTe/R4qqfPpvTOJ0kftSVZu++w3jEds1p35cpniyDNDaO6Nt9r9msHWxbBW/MUroiIiIiISOoL/fQrFTdPIvDJt7iK2pFz/bmkH7I7xuNO9tJatYqKiqRXrhhjXNbaiDGmO3AoMBzoCRThTAiKAEGcMc6h6O0AUGOM8eNMHjrBWrugsWto8eEKQIfrzqNmxlxWTLiWbh8/grdbx7Ue75gNP6xI0uISYMrvMLILZHib93V37gn3T4PyGsipv6WNiIiIiIhIUoWXLMd/x1NUv/wRJjuTrItOIPOE/TEZGqucCKmwLSgarGwDXAPss5lPDwC+6KXRNjxmpwVxZaTR8eGrscEQy0+9AhsIrvV4xyworoRgOEkLjKPlFfDzKtipGaYErWvH7hC2zrYjERERSQ6VsYuI1C9SUkb5dRNZteupVL8xmcxT/k77Tx8m68zDFKwkUCpsCzLG+IB7+CtYeRO4HSiP/nkO8A0wD1hd56kR/gpVqpqyhlYRrgD4+nan6I6LqZk2m1XXPLDWYx2znFlLxZXJWVs8TYmNYG7Gfisxwzs7240+X9z8ry0iIiIiLZt6AUmy2Mpq/Pc+x6qxJ1H16Guk778L7T+aSPa/T8ZVkJvs5bU5lZWVSQ9XgOOA7aK3rwFOtdZeCMRSgEustdsD2wAnAp9H738DKMTZFrTWtKHN1Sq2BcVkH7AL1V8dzJoHnid9+63IHr8z4DS0BaepbeecJC4wDiYvgsJMGNSh+V873QPbdVW4IiIiIiIiyWeDIapfeB//nU8RWVmCb7ftyf7n8Xj6x6GMXxqsoqKC9u3bJ3sZRwIG+BG42lob27cSSwCyjTFua20lTqDyhjHmCeBowG2t3b+pC2g1lSsx7a88i7StBrDywpsIF5cATuUKtL6mthELUxfDTj0gXhXDO3Z3th01dtqStU4AdNTLMOphZ2S0iIiINJy2BYlIW2etpfrtqaze8wzKL7sHd4/
O5D93M/kP/kfBSgpIhZ4rwFCccOWVOsEKQEb0unKd+wFOAH4H9jXGHNjUBbS6cMWk+Si651IiZX6KL70LgKJohVJrC1dmr4TVVfHptxKzY3fn+ovNrF4JhuHlubD303DcqzB/tbPF6MiX4akfmn2ZIiIiIiLSCgW+/I6Sgy6g7Ozrwesh78H/kP/8zfhGDkn20iQq2T1Xov1WCnEmAM1Z534Tvb++kVER4Jno7cObuo5WtS0oxjewNwX/OI6SGx8h+++70WH3HXGZ1heuTInuCNupe/zOMaQQctOcrUEHDmzYc+athuNfhSXl0L893LI7HDAAqkJw7jtwyccwZyVcMRZ8moomIiIiIiLrCM5ZgP+mSQQ+m46rcwdybjyf9L/vinHrDUSqSYGeK7HuxQGgtM79aTgBSgjYUGlNbPTy4KYuolWGKwAF5x6D/41PWfnPW+g+dUvaZ+SwspWFK5N/h8GFUBjHCiy3C3bo5oQr1m56+5G1cNVkqAjAI/vDLr3AFX2Oz+3cd/MXcP90J4S5bx/okBm/9YuIiLR02hYkIm1J+PelVNz2BDWvf4rJyyb73yeTcey+mPS0ZC9NNiDZlSs4035W4+QbWevcvwrIBbYAMMa4rLWR6ONunIoX+Ks3S6O1um1BMcbnpfDOfxNesZpVV9xLYVbj+4akomAYZi6F0d3if64x3Z0qlEVrNn3s5EXOBKPzRsGuvf8KVmLcLrh4DNy5J8xaBke/DOFI/a8lIiIiIiJtQ6S4lPKrHmDV7qdT8/6XZJ5xKO0nP0LmqQcrWElxFRUVye65YoGfccKUruvc/z1OBcsYgDrBCkA+sEv0dklTF9FqwxWA9K0Hkj/hSMqfeotRv01jZSsaxTx/NdSEYVhR/M+1Y3TM86amBoUicN1U6JkHx2658WMPHAi37gE/rYLXfmmedYqIiIhIYqiiSppLpKIS/51PsWqXk6l68k3SD96N9h9PJPtfJ+LKTfp4X2mAyspKcnKSOpa3GvgWZ3vQ1gDG+Sa1Bng/eswwY8x5xphCY0yeMaYIOBPYDWc70YymLqJVhysABRediLdvdw548ibWrG496crslc71kASEK33yoVP2psOVF+bAL6vg4h0b1ktl/BYwuAPc8ZVTiSMiIiLr05tYEWmNbCBI5eNvsGqXU/Df+RS+MdvQ7t37yb3+XNydOiR7ebIZ/H5/UsMVa60f+Cz6x5HGmBzrCAPvAMuBdsDtwCfAK8As4Kroc5YCk5q6jlYfrrgy0ii842Jyipex3zsPtZotKLNXQobHCT7izRhnatAXi53xz/XxB+DWL2FEZ9i7X8Ne12Xgwh2c7UYvzW2+9YqIiIiISGqykQjVb0xm1R6nU3Hl/Xj6dafgpdvIu/8yPH3jOKlD4iZFRjF/AhwP/AdnGxAA1tofgQtxApYITuPacUCn6CGlwJ3W2s9NE3+b0Wob2taVsf2WLN33AA588yWKv92TjqMGJHtJTTZ7JQzs4PQwSYQduzsByJyVMLSeapkHpsPKSpi436ab3ta1a2/YuiPc9Q0cNBDS2sRXpIiIiIhI2xOYMoOKmx4lNHsB7gG9yHvkKnxjR6hCr4ULh8N4PMl9I2etLQGe2MBjTxtj/gDOwwlXsnEKTb4BbrXWTo0et4FSgoZp9ZUrMRVnn0ZpVj5l/3czNtyy96BErBNyDCnc9LHNZcdoiFzf1qBlFfDgDNivP2zTaf3HN8ZEq1eWlMNzs5u+ThERERERSS3B736h5JhLKD3+MiKl5eTceiHt3rybtHEjFay0cNbalPk7NMZ4jDH1Nqiw1n5mrT3YWjsI6G2t7WqtPSgWrDSHNhOudOiUw727n4tr9s+sefiVZC+nSRavgfJA/RUk8dIpG/oW1B+u3PKFE/j8a3TjXnunHrBdF7j7W6gONW2dIiIirU2q/NAqIrK5gjN/ovSkKyg56HxCcxeSfdlptP9wIhkH7YpxN6BJo7QYqfB/lbU2FO2zsh7jcBtjjLU
2Lu8620y40jELPh38N9ZsN4rV1z9I6M8VyV5So/0YbWY7NIGVK+BUr3yzBAJ1vlx/XwMvzoXjt4IeeY173Vj1ygo/PPl986xVRERERESSI/DtbEqPv4ySg/9B8LufybrweNp/+giZJx2ISfMme3nSChlj0owx7Y0xucaY9b7IYg1um7r1Z2PaTLhSmAkYw6wz/gGRCMWX3JnsJTXa7JXgcUH/9ok9747doSoEs5b9dd/b853h4cdv1bTX3r4bjOkO901zmuOKiIiIiEjLYa0l8NX3lBx9MaWHX0RwzgKyLj6J9p89StaEw3HlZCZ7iRIHgUAArze5gVk0TPk/YCpOU9tNNoAxxuQYY0YZY/YyxnRsjnW0mXAlwws5Pvg9twsF/zwR/1uf4X9nSrKX1SizV8AW7RLf/HWHbs6En7pbg96d72xP6p7b9Ne/cAdYVQWPfdf01xIREWktUqHUWkRkQ6y1BKbOpPSIf1F61MWE5y8m+9JT6fDZo2SddgiurIxkL1HiKJmTgupM9xkFHAMMALDWVjXg6ZXARcDbwAXGmCa/u24z4QpAUZaz9ST/zMPxDerDyovvIFJRmexlbbbZCW5mG5OXDgPaw4ylzp+XVcDMZbBX3+Z5/eGdneqYZ2ZD/Iq1REREWh4FLCKSaqy11Hz6LSWHXEjpcZcSXryM7CvPpP3kR8g8+SBMRnqylygJUFFRQXZ2drJOH8szRgO9gB+AjwCMMRvMOowx7mhvlm+jd+0E9GnqYtrU4NvCTFjpB+P1UHjLP1my7wRWXXUfhTf/M9lLa7DlfmfkcTLCFXCmAb05z2lg+94C5769+jXf6x80EP75AcxavvmTh0RERETassiaCkK/LyW0aCmR0jKMxw1eL8brxng84PNG7/NgvB7wONfG69nwfbHjXWu/T1Hg13ZZawl8/A3+u54m9MM8XF0KyblmAumH7KF+Km1QMitX6hiMk21MA1YAWGsjGzk+9thsYAmwFZDf1EW0qXClKAu+X+7cTt9uGHlnHc6ae58lc88xZO22fXIX10Czo314hyRwUlBdW3eCp3+EhSXOlqC+Bc4WpeayZ1+49GN4/WeFKyIiIiJ12UCQ8JIVToDy+1JCvy8jXHt7KXZNRfxO7nY5wYvPg/F4cLXPx3TviKdnZ9w9YpdOuLt3wqT54rcOSRobiVDz/pdU3vMMoTkLcXXvRM4N55F+0N8wPoUqbVVlZWUqhCtdo9dLAP+mDq7T1HZZ9PiuQJMT4zYVrhRmwoo6u4DaXXwKlR99zcrz/0v6lMdxFzRD45A4mx2dFDS4Q3LOHws8pv4OXy+BM0Y07+vnpsEuveCNX+CynZz/x0VERETaAmstkVWlhBYtdUKTxcv+uv37UsJLiyFS55exPi+ebh1x9+hM5vBBeHp0xtO9E+6enXG3y8eGwxAMYUMhbDAUvR2GQDB6XxiCQed6rWNC2EDIuS8UgkBoneNDhJevIvjbEqq//gFbWf3XmozB1an92oFL7HbPzrjycxL/iZUmseEwNe9+jv/uZwj/sgh3ry7k3PwP0vcf51Q4SZuW5G1BMbF0b3ObS1ggNhO8pqmLaFP/GoqyoDLoTKPJ8oErPY2O917GH3ueRvH/3UbHB69M9hI3afZK6JkHOWnJOX+/dk5j4HfmQ9jC3s3Ub6Wu/QfAuwvgyz9gTI/mf30REZGWxhhDHKdHSpIE5iyg8uWPCC1c7IQoi5etHVQAro7t8XTvRNqoLXH36IynZzRA6dEZd6f2623XSZRIJEIgEHC+NletIfy7EwKFo2FQ+PdlBD79lsjKkrWeZ3Kzo4FLp9rAxR0NhFydOmDc7g2cURLNhsLUvPkZ/vueJTx/Me6+3cm97SLS9t3Z2S4mghOuJLFyJfYfY3n0egv+Cks2yBhjotUrXYFY4lu+kac0SJsLV8Bpats7Wq2YtmV/2l10EqtvmEjm3mPIOWi35C2wAWavhKFJ6rcCzrSgrTo6PVG65TiTgprbrr0
h2wev/6JwRURERFoXWx2g8s1P8T/+BoEZc8HnxdunG56enUnfaTieaODg6d4Zd/eOuFK8KagxBtMhH1eHfLzDB633uK2s/it4iW5jCv++jNDshdS8/yWEwn8d7PPg7toRd49OePr3wjtiMN4RQ3C1gOry1sSGwlS/9gmV9z5L+Lc/cffvSe7dF5O2144Kv2Q9lZWVSatcqdNX5WdgV2BfIBeoMMa4NtF3BWAPoACnT4vClc1RGB2tvrISehf8dX/+uUfhf/9ziv91Gxnbb4WncxLTi40oq4Hf18Dhg5O7jiFFMHUx/H0gxKOXWbrH6b3yzjy4ZlziR06LiIiINLfgwj/wP/UWlc+9S6S0HE+fbuRdcQZZh+zRqsMDk5mOZ2BvPAN7r/eYDYWJLCv+q9qltuplKZVffgcTXwLA3b8nvpFD8Y4cgnfkUNydk7Q/vpWzgSDVr3yM//7niPy+DM/gPuTedylpe+yQtAopSX1+vz8VtgU9B5wDpAPXGGMmWGurN3SwtdYaY4YAR+JsKXofKGvqItrU29a6lSt1GY+Honsv5Y9dTmLFef+l83O3pGQH9DnRfivJamYbE/vM9CvY6GFNsn9/eGkuTF4Ee8Rh65GIiEhLkoo/l8im2VCY6ve/oOKJN6iZMgM8bjL23JGs4/YjbfTWbf7v1XjcuLt1xN2tI+y49VqP2ZoAwe9+IfjtjwS/nU31qx9T9dRbALi6d8IXDVq8I4fg7t21zX8um8LWBKl+8X389z9P5M+VeIZtQc7E0/H9bTt9XmWTKioq6NixY1LXYK391hjzGU71yolA2BjzFPCdtbY0dpwxxgtkA0OAp/hrQtB9GwtjGqpNhSuFGwhXAHx9e9D+ygkU/99tlNwyiXYXnZjYxTVArJltssYwxyxY7Vz7g/E7x47doV0GvPazwhURERFpWcJLi6l4+i38T79NZPkq3F0Kyb3oBLKO2Bt3x/bJXl6LYNJ8+LYbim+7oYATVIXmLiT4jRO21Hz6LdUvfwSAq0NBtKplCN7thuIZ0EvbVxrA1gSoeu49Kh94nsiyVXi2GUjOtWfjGztCoYo0WIqMYgY4H3gH6AacAuwIfG6M+RlYiTN+uSswFtgr+hwL3G+t/ao5FtCmwpWCdPC6nG1B9ck98UBqZsyh5KZH8PXvRfYBuyR2gZswe6VTfVOUxK/d6hB8/ofTE2XW8vidx+uG8VvAC3P+akAsIiIikqpsJELNlBlUPP4G1R9+CRFL2tgRZN9wHul/G6UGoE1kPG68w7bAO2wLOPkgrLWEFywm+O1sAt/8SPDbH6l5Z6pzbHYm3hFD/gpchvXHpGlUcIytqqbqmXepfPBFIitW4x0xhNyb/oF3R1VTyeZLZs+Vuqy1s40xJwP/wQlWBkUvG1IFvGytndBca2hT4Yox0XHMG5h8bYyh8NaLCP66hBXnXIe3VxfSthqQ2EVuxOwVya9amfq7M3FpdDeYuQysjU/fFYAD+sMT38P7C+GggfE5h4iISEugNzypK1yyhsrn3sP/5FuEfluCq10eOacfStbR4/H07JLs5cVVMr8ujTF4+vXA068HGUfuDUB4yXInbPl2NsFvfiTw6bfOwWk+vFsPwDtyCL6RQ/EMH4QrKyNpa0+WiL+KqqfeonLiy9hVpXh32JLcO/6Fd9QwfY+RRkuRnisAWGs/MMYsBw4BRgO9cRrcpgHh6CUALANusdY+3Zznb1PhCjhbgzZUuQJOCWLHSdexZM/TWHrMxXT7YCKeTslvmlUdgnmrYbc+yV3Hu/MhN81ZxxefwdIK6JKz6ec1xrZdoGuOszVI4YqIiIikksCsn6h45FUq35oMNUF82w2l3YXHkbHPTpg0ldwmg7trR9xdO5J+4N8AiBSXEpg2m+C3swl++yOV9z1PZeRZ8HrwjR1B+n5jSdt1FCYztScyNVWkvJKqJ96g8uFXsCVl+HYaTubZR+IbOSTZS5NWIJXCleiI5e+B740x/XB6q3QD2gE
hoBj4xlr7XTzO3+bClaIs+GMTfYA9hQV0euIGluxzFsuOu4Qur92NKyMtMQvcgF9WQdgmt3IlYuHDX51RySM6O/fNXBa/cMVlYL/+8NBMKKmCgrb3CwYRERFJQRWPvELpFfdhsjLIOmIfso/dF28903AkuVwd8knfa0fS99oRcEKG4My5BD6bTs3bUyj78CvISCNtt+1J338cvp2GY3ytZ/tQpKyCqsdep/KRV7FrKvCNG0nWOUfi3Ua/tZTmk0I9V2JTgEz09nxgfn3HRUMY29znb3PhSmEmzFi66ePShvSj4//+w7LjLmHFhGvp+MB/kvrNNtbMdmgSJwXNWwUl1bBTDxhUCGluJ1wZv0X8zrn/AHhgOrw9H44eFr/ziIiIpDKV7KcGG4mw5rqJVPzvBdL3HE27Oy/GlZ2Z7GVJA7lyMknbeVvSdt6W7EtOcaYQvf4pNe9MpeaNyZi8bNL22pH0/cY6W2VaaFPcSGk5lY++StWjr2ErKvHtvj1ZE47Au2X/ZC9NWiG/309OTpx+216POuGJrXOfG6c5rYldR4+r+5+njV4i8QhWoA2GK0VZsLoKgmGnaerGZO01hvZXncWq/9zLn6vX0OnRa3EX5CZmoev4cQXk+qB7ck4PwLRoKDWiM/jcTtAzswFBVVMM7gA98uDjXxWuiIiISPLY6gCr/3ETVa9/StYJB5B/1Vkt9s13c7HWttjgz7hc+EYNwzdqGPbKMwlMnUn1G59S88Zkqp97D1dRO9L22Yn0/cfi2WpAi/g4I6vXUPnwK1Q9/gbWX0XaXjuSefYReAdr9KbET6LDlfqCEWttOGEL2Ig2Ga5YYFUVdGrA1rD8M4/A3aGAFeffyJI9T6fT0zfi69cj7utc19xip1okmd/Xp/3pVP70yHP+vHUnePJ7CISdsCUejIFxPeHFuVATgrQ29xUrIiIiyRYpKaP4lCsJfP09eZedRvbph7aIN9vx1lo+B8brIW2XkaTtMhJbVU3NJ99S8/pkqp5+i6pJr+Hq0Yn0fceSvt9YPAN6JXu564msLKHyoZepeuotbFUNaeN3ImvCESm5Vml9KisrycyMfwXfF198wY477vglUAJ8aK29DcAYMxBn9PKfQDXOFKBKwF/ndlX0Ul3nEgACzRnMtLm3qoXRv/eV/oaFKwA5h+6Jt2cXlh5/CUv2Op2OD19D5tgR8VtkPRaUwPh+CT3leqYthW07/xXwDO8ED8+En4phy47xO++4nvD49/DtnzAm8bmWiIhI0rWWN7EtUeiP5RQf+29Ci5bS7r5Lydx/l2QvKSW01q9Jk5FO+j47kb7PTkTK/NS8/wU1b0ym8oEXqLzvOdz9e5K+3zjS99sZd4/OSV1raMFiqia9TtVLH0IgSNr+Y8k663A8SfhFsLRd1lrcCaji+/XXXwFGRf+4vM5Dw4B/AKuB2PagCM5koNh1GKehbQgI1rlUGWMycUYy39DUNba5cKUo2mtnQ+OYNyR9u2F0e+9Blh1zMUsP/yftLjqR3BMOwN0+v9nXuK7VVVBaDX0K4n6qDVrhh9/XwHFb/nXfNp2c65nL4huu7NDdqYz5dJHCFREREUmcwA/zKD7+UmxNgMKnbyRt+y03/SRpNVy5WWQcsjsZh+xOpLiU6renUPPGp/hvfQz/rY/h2XqAM3Fo/M64i9olZE3WWgJTZlD16KsEJk8Hn5f0A3Yh84xD8fTumpA1iMTEqXVJvcrLy+v+sW662z163Zh/hGGcQGZmI5e1ljYXrsQqV1ZsZBzzhnh7dKbrW/ex4uzrWP3fh1h922Nk77MTOcfuR8aY4RiXq3kXG7WgxLlOZrgy7U/nekSXv+7rkuN8Pmcug+O3it+5M72wXReYvAgu2yl+5xERERGJqf7kG1adfjWuglwKn70Zb/+eyV6SJJGrQz6Zx+1H5nH7EV6ynOo3PqPmjclUXPMgFdc9hHf7Lck85SB8Y0fEparHVlVT/eonVD76KuH5i3EVFpB1wTF
kHLkPrg75zX4+kc2RiEq2M844gzPPPDMbyMTZ0hPzPPAbTraRB+QAudFLTp3rHCAbyIpeMqPPycLZJtRkbS9caWTlSowrJ4tOj11PzdyFlD/5JuXPv0vFqx/j6dWFnCP2JufgPfD26rLpF9oMC6PhSt9khitLnelAdUdBGwPDO8e/qS3A2F5w3RT4szx+o59FREREAPzPvEPJxbfjHdibDo9dh7tTh2QvSVKIu2tHss44lKwzDiU0/3eq3/iM6pc/ZM1JV+DddjBZFx2Pb7vmmcQQXlpM1ZNvUvXMO9jScjxD+pJz64Wk77MzJq31jI2WlimRlSvR81Xi9FCpe98fwB+Nfc3oVKFm2dcUn1KLFOZzQ0F648OVmLRBfehw3bn0/OEViv53BZ5unSj578P8PvJwluxzJmsefYXw6jXNsuaFJc66uyVxUtD0P50Gtus2rt2mE/y2Bkqq4nv+cdFfFk1eFN/ziIiIpKLW2t8i1VhrWXPLY5RcdCtpY4ZT+NLtClY2ItFvrFKRp18Psi84hvYfTST76gmEFy+j9Ij/o/T4ywh+/0ujXzf43c+sOe9GVo09kcr/vYhv1DDyn72JgtfvIuOgXRWsSEoIBAKkpaUl5FyTJk3CGPOJMeakaJ8UjDFr5RnG4THGuKPXdS/u6MUVvdSOdLbWhppjjW0uXAGnemVlI7YF1ceVnkbO33ej6yt30mPmi7S7/AzC5X6K/3Ubvw09kJX/upXwqtImnWNhCfTMA3eS/raqgvDjSmcE87rq9l2Jpy3aQZds+PS3+J5HREQkVSlgiS8bDFHyz1sov+MJMg/fkw6TrsWVHf8JGC2Zvib/YnxeMo8ZT/tPHyb73ycT/HE+JQeez5ozriX0S8N+O2hDYarf+ozVh1xIyUEXEPjkWzKO35/2nzxE3v2X4dtuqD7nklIqKirIyspKyLk+//xzgLHA1kR34FhrI8aYC40x3xhjjgfSrLUha204el33Eo5eItFLs6fDbW5bEEBRZtMrV+rj7daRgnOPJv+cowjMXkDZY69S9vgbVLz8IQX/PJG8k/+O8W7+p3xBCfRLTI+sen23HEIR2Lae3U7DipxuQj+sgL/1jt8ajHG2Br35CwTD4I1/Q2oRERFpIyIVlaw6/SpqJk8n9x/HkXPBsXoTK41i0tPIPPVg0o/Ym6pHX6XyoZeo+eBL0g4YR9Z5x+Dpuf5vKyOl5VQ99y5Vj79JZOlK3D07k33FGaQfvJsCPklpFRUVZGc3cARvE5WVlcVu1gBhY4yJBiQ7AyOAHYCXaab+KY3RditX4hCuxBhjSBvaj8Kb/0n3Tx8lbZtBrLr8bhbvfDz+97/YrNcKRZwpPUnttxJtZrttPZUrWT6n0e4PK+K/jnE9oTwAM+JcJSMiIiJtS9lNj1IzdSYFt1xI7j+OU7AiTebKySTr3KNoP/lRMk89mJp3v2D17qdRdundhJcVA84o5fLL76V4x+Pw3/go7l5dyJt4Be0+mkjm8fsrWJGUV1lZmbDKlUCgtoetC6dIJJZlxL5hr8QZvUxs249J8Dfztlm5Et0WZK1TERFPvoG96fz8rVR+8AWrLr+HZUf/HznH7keH68/Dlb7p/WmL10AwkuRJQUudbTn56fU/PqwIvloS/3WM7g4eF0z+DUZp0pyIiLQxxhj1uIiDSFU1/hffJ2O/cWQdsXeylyOtjKsgl+yLTyLjpAOpvPdZqp55h+pn3vnrgNgo5RMOwDMojmXgInHg9/sTFq5kZtaGjVsCWGvD0T/Hxp2UW2v90cciCVnUOtpk5UpRJtSEoawmMeczxpC1x450n/I4+ecdQ/kTb7Bk/FkEF/25yecuTPIY5oiF6UvXHsG8rqFFsKwivtVAALlpTvXMJ2pqKyIiIs2k6vVPsWV+so/ZN9lLkVbMlZOJZ2BvWKdFQPrBu5F96akKVqRFSuS2oKFDhwL4gV2BR40xZxtj9gZ6RA/pY4zZxRgzwhiztTFmkDGmnzGmhzG
mszGmvTEmzxiTZYxJM8Y0e6OJNlm5UjuOuRLyNlCNEQ/G56X9ZaeTPnIoK866lj92O4Wi+y4na/cdNvicBaXOdbK2Bc1f7YRQ9TWzjRlW5Fz/uBJ2iXNwOa4n3PgFLPdDx8SEpCIiItKK+Z98E88WPfCNap7RuSJ1hZcVU/XE2qOUM046EM/A3lTe+yzVz7xDzdtTyDz9UDKP2w+TmcA3JyJNlMiGtqeccgqXXHJJKZAF7IPT3LYKKIwecgBO35UaILDOpabO/bHbNcaYKsAL3GytLW3qGttkuFIUC1f8znaXRMvac0e6ffQwy068jGVH/YuCfxxPwf+dhHGtX0i0sATaZWx4S068xfqtbKxyZUj0y/mHFbBLr/iuZ2wvJ1z5bBEcOji+5xIREZHWLTBnAYGZP5F35ZnqsyLNKvjdz1Q+8io170yFiCVt9+3JOPFAvCOH1H6t5d1zCcHZC/Df+jj+mx6l6pFXyTz7CDKO3LtRQzBEEq2ysjJhlSuFhYUAhwH/A4YCPiBWghABOkYvmyOCs5vnXqC0qWtsk/9q64YryeLt1YWub99P8cW3U3LbY4RXl9Lhxn+sF7AsLIE++clZI8C3f0KHDGcU9IbkpDlr/DEBTW0Hd4DCTJiscEVERNoYvflvfv4n34Q0L1kH757spUgrYENhat77nMpHXyM0Yy4mO5OM4/cn87j9cHfvVO9zvEP6kv/IVQSmzcZ/y2NUXHk/Ne99Tt49l+AqyE3wRyCyeRK5LQjAWvulMWYMMBAYAGQCt0av3wIWAek41S2ZQEad64zoY7GLD6dqxYVTAdNkbTJcKYz2wllZmdx1uDLSKLzj/3B3yKf0rqcgYulw84VrBSwLS+JfDbIx05Y6I5g39fPc0KK/qlziyRhna9D7CyEcAXeb7BokIiIiTRXxV1H58kdk7jtWb2KlSSJryql6tmmjlH0jhuB95kaqX/mY8kvuYvWB55M/8Qo8/XvGefUijef3++naNbGTRqy1ZcA30QvGmLujDz0DvIIzPci9zsWDE6R4cEKVWLCShhO+lDbH2tpkuJLjg3RPcitXYowxtLvsdDCG0jufxFpL4S3/xLhclNU4AVCy+q2s8DtjoI/dctPHDi2C13+BVZXQPs5T48b2ghfmwqzl9Y+HFhEREdmUqtc+wVZUkqVGttJIoQWLqZr0OlUvfwhVNXh32Iqcq8/CN24Exr35vTKNMWT8fVc8fbqy5vRrKDn4H+Te8S/Sdh0Vh9WLNF0iRzFvxErAAiustc1SgdJYbTJcMcapXkmFcAWiAculp4HLRentj0MkQuFt/2JhiVOWkaxJQbX9VhoQYNRtajs2zgH7mO5OHPn5YoUrIiLSdmhbUPOqeOpNPAN64RsxJNlLkRbEWktg6kyqHnmVwORpcRml7N16IAWv3sma069mzWlXk3XRCWSefoi+B0jK8fv9Cd0WtC7j/KMYj1OFMj9pC4lqk+EKOH1X4j06eHMYY2j371MwLkPJrY+BhYWn/QtwJS9cWQppbqcqZVNix/ywPP7hSkEGDOoAXy6Gc7eL77lERESk9Qn8MI/gd7+Qf/UEvWGVBrFV1VS/+gmVk14jPO93XB0KyLrgGDKO3AdXh/xmP5+7cwcKnruJsv+7E/9NjxL6+TdybzgXk57W7OcSaaxkhyvWWgvMiv3ZGFO3aYRd57i4a9PhyrzVyV7F2owxFPzfyQCU3PoYOas9uEdcSI+85PynP2MpbNkRfA2oasxNg155zsSgRNihGzz5A9SEIK3NfhWLiIhIY/iffBOTnkamGtnKJtQ3Sjnn1gtJ32dnTJo3ruc2Genk3vkvKgf0xH/r45T8uoS8/12Ou2P7uJ5XpKGSHa6sy1obSeb52+zb0sJM+GJxslexvljAYgMh+tz9FBf5fXjPPQdnI0zihCMwtxiOHtrw5wwtglnL4remunboDg/PgpnLYPtuiTlnKqoMwn3T4Ks/IMvrTG7K8UG2z/m87No81akiIpI
CVGHRPCIVlVS++jEZ+4/DlZc6bwokdVhrCX33M5WTXqfm7SkbHKWcCMYYsiYcgad/T8ouuJmSA88n74HL8W7VP2FrENmQRI1iXrJkCRdddBHPPPPMfcACa+2tAMaYdsCJQDFQA1TXc4ndH4jeDgBBIGitDTXnOttsuFKUBWtqoDrkNLdNJcYY2l1+Oi9/F2CPz15g9bU+2l12ekK/kf9a6nxuBhc2/DnDiuDNeVBS5WzdiaftuoLLwJd/tM1wxVp4ez5c8xksrYBtOsHqKqcBcXkAymrgwRkwfgu4ehx0iHOTYRERkZai8tWPsf4qso4en+ylSIoJL1lO9aufUv3ax4TnL27QKOVESdt9BwpevJXS066m5Ih/kfvf80g/YJekrkkkUaOYlyxZwrPPPgtwBjAVZ/wyQE/gZuCP6J/DQKjOJbjOJcBf4Up1dBvRj9baa5tjnSkWKyRO3XHM3VNw+p7FcPO4c+ieHmCLu57CZKTR7p8nJuz8c1Y615sTrsT6rvy4AnaKc9+VvDQYUuiEKxfE91QpZ95quOJTp6Hv4EK4e28Y2WXtY4Jh+N90uPMbp0LryrFwwIBNj9QWERFpzay1+J98E++gPviGD0r2ciQFRNaUU/PO51S/+jHBb34EwDtyCJnXnk3a/uMaPEo5ETwDe9PulTtYM+E6yi64mdAvi8i68DiMy7XpJ4vEgd/vJycnJ+7nqaioqPvHuk0rOkavm/Lr9o6AwpWmKIpOjFrpT81wZUk51EQMJf/3D3LaBSi58RGMz0fBuUcn5Pxzi8Hrgn7tGv6c2MSgHxIQroDTd2XSd6lZfRQvr/wE//wAMr1wzTg4ehi46/n/1OuGs7eDPfrCvz6E896DN36BG3b962tfRERaFm0LarpIcSnBH+eTe/FJ+nw2g5b6ObQ1QQKTv6X61U+o+fhrCIRw9+lG1oXHkb7/uKRXqWyMq30e+Y9fR/lVD1B5//OE5v1O7m3/TKkQSNqOqqoqMjLivGUBGD16ND/99BMDBgzYGqg7lmYWcCZO89pcIAvIjl7WvZ1Z55KBM2EoB2fLULNIibekC0rghNfgqrHwtwT1iIi9wUyVcczrWljiXPdu56Lwjv/DBoKsvuYBCIcpuOC4uJ9/9krYol3DmtnG5KVDjwQ3tX1whjMyekyPxJwzmX5YDv/3oTN++v59oH0D/g/t3x5eOhQemQW3fAnHvOL8OUeN5kVEpA1ytcuFNB+R1WXJXookmLWW4PQ51Lz6CdVvfYZdU4Fpn0/GUeNJP3AXPMO2aDFhkfF5ybn2bDwDelFxzf8oOeRC8idekdKhkLRergRUTqWnp9O/f3+std/Xvd9auwz4X0Newxjjxsk/vNFrH07QUtVc60yJcCXD4/SKWJ7AoKOlhCt9CsC43RTdeym4DKuvn4gNh+O+RWjOysaNVB5amLhwZbuu4I72XWnt4cqqSjj9Lad3SkODlRi3C04dDgPbw/GvwYR34JH9waMKUhERaWOM2423f0+CP/+a7KVIgoQW/kH1a59Q/eonRBYvg/Q00vbYgfSD/oZvx20wns34TWIKMcaQedx+ePp2Z83Z17P6wPPJf/w6vEP6Jntp0kYkaLpxc4oAAWttTbxOkBLhSkG6c13SbJnRprXPcBqirqhM3Dk3x4ISZ+pLrDeM8XgouudSjMdDyY2PQDBMwcUnxyVhX+F3etFsTr+VmGFFTqPV0mrIT2/2pa0l2+eMiv7yj00f25KFInD2O1Bc6VSdbE6wUtdOPeG6v8HFH8GVk51tRS3kFzQiIkLL3YKRarwDelE9ZUaylyFxFCkupfrNyVS/+gmh738Blwvv6K3IOv9o0nbfoVVtofHtuDUFr9xB6dEXU3bODRS8cTeurPhv0xCJSfb/TcZZgAuw9Y1iNsYYGxXvtaREuJLhdapXVjfbbqdNc7ucKoBUrlzpU7D2m1/jdlN458XgcVNy22PYYIh2lzf/FKG5xc51Y8MVcJraJqKaJLY1yB+
ALF/8z5cM//0cvvgDbt0dhnXc9PEbc+RQZxLU/6ZDn3w4aZvmWKGIiEjL4R3Qi8oXPyBSUoarIAUb77Ug1tqkv7GKsVXV1HzwFdWvfkJgynQIR/AM7kP2JaeQtt9Y3B3bJ3uJcePp1YXc2y6i9KiLqbjmf+T+9/xkL0nagFSpXImGJuH6HosFK9HbuTjNcKuttXEp60iJcAWgXUZiK1fA2RqUyuFKfSOGjctF4a0XYbweSu9+ioi/ig7Xn4txN19J49zYpKAOm//coUkIV+6bBt/+CeN6xf98ifb6zzBxBhy/FRwyuHle8+IdYVEpXP2Z0yNntz7N87oiIhJ/xpiU+YG2pfIMcBr8BX9ZRNqoYUlejTSFDYcJfvk91a99Qs27n2P9Vbg6dyDzlIOdPioDeiV7iQnjGzWMzDMPo/K+5/CNHUH63mOSvSRp5aqrqxPSzHZTjDHjga2AL4EvrbXV0fuNtdYaY/KBvYBxQHtgsTHmRWvtF829lpQJVwoyYHWiw5XMxPZ5aSh/AJZWOJUr9TEuFx1u/AeurAxK73mG0JLldHzgP81W4jh7JXTNcRrUbq6CDOiWk7i+KyO6OFONvvyj9YUrC0vgog9huy5w+U7N97ouA3fsCYe9COe8C68cBgMbEaStKxB2/i5S5BdYIiIi9fIO7AVA8OdfFa60UKG5v1L96sdUv/4pkeWrMNmZpI3fifQD/4Z3u6FtdjRx1nlHE5gyg/JL7sK79UDcnZvhBzyRDaioqCArKzljSOsEJ1sBFwE7A2dbaz+JHRN93AdchzNRqK7zjTHnW2vvas51pUy40i4dShK4LQicypVEhQCbY2Gpc913A+EKOL+5an/FWXh6dKb44jtYcsA5dH7qRjydmv5NdE5x47YExQwtStznNdMLW7XSvis3f+E07L1nH2e0cnPK8MLD+8PeTzmjnV87vP6RzvUpqYK7v4Gpi50gsCLoXAcj0DPPqbI5bLAmEomISGpydy7E5GQS/Pm3ZC+lxQuHw7Vbg1wuV1ynhoSXFlP9utOYNvzzb+Bx4xs3kvTLTyPtb9th0vWDh/F6yL3jX6ze92zK/nkr+U9c12aDJom/iooKsrOzk3V6gzN+eUdgW5zxzDMBjDEuwFhrw8BZwLHR56zCmQyUjzMl6GZjzBxr7Yd1tw81Rcr8a0tK5UqW0yQ0tF7bm+SqnRSUv+lj8048iE5P/pfg/MUs2et0auYsaNK5q4LO+RuzJShmWBEsWgNr4taHeW07dHfCnLIEnS8Rvl/uNAY+dTh0jFMgXJQFV45zPnePztr08cGwc9zYx+DR76BztjOxab/+cMo28I/tnT5GV38Gox6GKz6FX0vis3YRkbYqVfpbtGTGGLwDehP86bdkL6VFstYSDocJhUJ4PB7c0a3p4XCYYDBIMBgkHA4TiTT9B+xIeSVVL7xPydH/ZtWY4/Hf+CgmK4Psq86iw1dPkv/gf0jfZycFK3V4encl5z9nEPzyOyofejnZy5FWrLKykszMpDWHjv1n2B/IAj4Bfo/eZ6PBCsBBQDawGLgW2AO4AliOM5J5vDEmu7ma3aZO5UqSeq5YnDG3HZMWuq1vYYnz1dJ7I5UrdWXtvgNd37iHpUf/H0vGn0XR3ZeQve/YRp3751UQsTCoCZUrsaarP66AHbs3/nUaaoduTiXFt0tg11bSP+SmL5x/E6fEueHsvlvAy3Phli9hz37QvZ6+ftbCJ7/BtVOcKVZjusPlO9e/lei8UU4w9OgseOoHeOw7p2nupWMaXhkjIiISb94Bvah667OUasjaElhriUQiRCKR2mqVWLgSiURqg5dYsBIOhze7qsUGQwQ+m+70UfngK6gJ4O7ZmaxzjyLtgF3w9OoSt4+vtUg/bA8Cn36L/9bH8Y3eGu/QfslekrRCfr8/mZUrMbEupb8Aa6K3DWCNMbvXefwKa+2k6O2fjDEDgVOBXYAioKI5FpMyb3cK0qEs4Px2PFGKohUBqdbU9rdSpyo
gfTOir7Qt+9Ptvf/h69eD5SdexopzridSvvkfWGxS0JAmhCux585e2fjX2Bzbdgaf25mo0xp8vhim/A4TRsR/a40xcO0uzvWlHztBSl3WwnVT4cTXnduP7AdPHrTxHi1bdoTb94QvToKjh8HDM+G0N52tQyIiIqnAO6AXkdJyIitWJ3spLca6wcq6oVQsaPH5fKSnp+Pz+XC73RhjCIfDBAIBAoFAvVUt1lqCs36i/Mr7Kd7hGNacehWBz2eRcdgeFLx0G+0+fois845WsNJAxhhybjgPV/s8ys6/CVuV4N4L0iYkeVtQTCw8WQ6s+25jDNAJqAQmAxhjYu+wPwaKgaE0Y8FJyoQr7aKNhhPZdyVVw5Xfy6BH/uY/z9OliK5v30/BhcdT/vx7LB57AlVfzNqs15i9EnJ80K0Jkwk7ZDrh0OwE9V1J98DwTq2j74q1cNPn0CUbjtkyMefsmgsX7QCTF8Hrv6z92H3TnGlFx24J7x3jVAY19Bd8RVlw3d/gmnHw8W9w6IuwrFkyYRGRtktVFs2jdmKQ+q40SN2KlPqClfq4XC68Xi8+n6/24na7iUQitWFLzcLFVNz5FKt3PZWSv/+DqmffxbfDVuRNvIIOXz5BzlVn4d1moL7uG8GVn0PuLRcS/nUJFdc9lOzlSCtUWVmZtIa2OBtQwNkSBFA3sY19wxgLZADvALFmBbFSjuLobRcQbK5FpV64ksCtQUXRLWIrKhN3zoZYvAZ6NDLcMF4P7S4+ha5v3gtuN38eeC6rrr6fSFXDGpLMWQmDOjgTZZpiaBH8mKDKFXC2Bs1ZCaUtPJh/bwHMWg4XbL95lUtNdfxWsHVHuGryX/8Gn/nR2Z504AC4epxTHdQYx20Fj+zvVGQd8FziKppEREQ2pO7EINk4ay2hUAhrLS6Xq1FBR6yqxev14l60jOATb1Fx5MWs2e10Ku96GtOxPZnXnUPBV0+Qc+f/kbbrKIzPG4ePpm3xjd6azFP/TtXTb1PzwZfJXo60MkmuXImFK+XR615AxES/QRljCoHB0ccmx46r01slgtNzBZzKlmaRMuFKQXTs7+oEvjkuTMHKlaogrKyEHnlNe530kUPp/skj5B67H6V3P83inY7F/+5UNtarJ2KdbUFN6bcSM7QQFqyGymbLATdu+27Ov7BvliTmfPEQjsDNXzpTov4+KLHndrvgv7s6TYivmwrvzIdLPoZxPeGW3Zsetu3SC1461PmGc8gLLfvvSUREWj53+3xcHfIJzlmY7KWkrLqNa2N9UxojvHwVlS99QOkFN7Fiu6Mo3uN0Kq59ECqrybn4JDpMfYx2z9xIxmF7QFZGbWPc5mqK29ZlXXAcniF9Kbv4TsLaBifNKJk9V+qEJPNwKk8OBDpYRxg4F6eRLcAndRrcxnTBqWoJ0YzhSko1tIXETgzyuZ3zplK4srjMuW5quALgys6k8NaLyD5oV1ZefDvLjv03GX8bRYfrz8XXt8d6xy8qdcKQpkwKihla5IQdc1bCiARsj926E6S54aslsEff+J8vHl7+CeavhgfGgycJseegQjhtuLMV6JWfnM/p/eObbwz0oEJ49Qg48iWnB8trh0PP/OZ5bRERkc2Vtv2WVL74AZ6+3ciZcKS2ntSxqf4qGxPxVxH4+gcCU2dQM2UGoV8WAeBql4dvx61JGzOctJ2G4+5atNbz6jbGjQUrsdt1H4/nuOfWyKR5yb39Ilbvfx7lF91G3qNXazyzNAu/30/79u2TvYyncUYtdwUmGmM+wNkqdBqQDnxvrZ1dz/NGAmnAHJzxzM0iZcKVgiSEK+BsDUqlcOX3aI/j5ghXYjLGDKf7J4+y5uGXKbnpERbvdDz5Zx5OwQXH4cr+a3zWnGZoZhszNPr/5Y8JClfSPbBNJ/i6hVZEBMJw+1ewVUfYK4nh0D5bwP3TnHDnkf0gs5krcjtmOVuEDngOTnoDXjkMcjU9UUSkwRQANJ+C2y4Ct5uy/z5CYNYvtLv9Ilw
5SesfkDI2N1ixoTDBH+ZRM2UGgakzCMyYC8EQpHnxjRxKzsG7kTZmOJ7BfRr0pr7uZKHYOuqGLOFwuHZ7koKWhvH060HOpadQfvm9VD36GpknH5TsJUkrkORRzDGTgbeBfYDxwN78tTsngjN2uT574fRm+dZaG2quxaROuBLbFpSEccwpGa40oaFsfYzXQ/4Zh5F90K6suvoBSu96ivJn3qbdxaeQc9Q+GI+HOSvBbWCLZgggO2ZB+wxnHHOijOoKd38LZTUt7w37ewtgSflfk3uSIWLhik8hwwOVIfjmT9gzDkFPr3ynOueYV2DC2/DoAcmp1BERkbbNlZlBu3suoWLrAay59kFW7HcO7R+6Em+/9at724rYVqDYiOr6ghVrLeFFf1IzZQY1U2YS+HIWtsz5YdozpB9ZJx9E2k7D8Y0Ygklv2g9k6wYtseCnbtCyuaOe26r0o/ah5sOv8d/9DBnH7queNtJkFRUV5OTkJHUN1tpqY8w5wBKcgKUrTiXKfOB5a+1r6z7HGLMH0B9nO9GHzbmelAlXfG5nSk0ipwWBE67MT6Hth7+XOdUCsW1Szc3TsT0d772UvJMPYtV/7mXlhTdTOvFFOlw1gTnlo+jXrnkaqRrjVK8kamIQwKhucOc38O2fsGvvxJ23OTz5vVOtNK5X8tbw1A8wfSncvBs8MB1u/Nz5PMYj+NihG1y3C/zfR3D1Z07DXBERkUQzxpBz6iF4h/Rj9ZnXsmL8BNrd8X9k7D0m2UtLuFjjWlh/602kpIyaz2dGq1NmEv5jOQDurkWk770TaTttg2/01rjb58dtfbE1xRrjxkKWWCAUC1xU1VI/YwwZx4wnMHkaga9/IG2n4clekrRwyey5Upe19ldjzMXA80BnwA3MtdZ+Y4wxdfqzYIxJBw7DCWBKgdebcy0pE66AszUoKZUrlc5v7ZvatLM5xCYFxbt6IX34YLq8cQ/+Nyez6uoHWHr4P9mv/3bMPu5MoF+znGNoIfxvMdSEIC0BX2nDO4HX5WwNaknhyi+rnF4x/94xeV+DyyqcMGXH7nDoYMhLd/qivDAHjhwan3MeMRTmrYaHZkK/AmeqkIiIbJy2BcVH+uitKXrnfladdhWrTr2SnLOPJPeiEzDuZmo8luLWHbNsqwMEps2m5vOZBKbMIPjjfLAWk5OJb/TWZJ12iNM3pXfXpH1N1lfVUrcJrqpa1ucbsw0mK4Oadz9XuCJNlirhCoC1tgT4qO59dYOV2BQha201cEr00uxSKlxpl57YUczghCuhiHPe9knfMuZUrvRqxn4rG2OMIXu/cWTtMZqlD75C75seY8vLT2L5d7vT7uJT8Pbo3KTXH1LkfG5/WQXDOjbTojciw+v0LPn6j/ifqzk9+YNTuXXo4E0fGy9XfOr0fbn+b06wt0cf2LYz3PYVHDCg+XuvxFwyBn4thSsnw+DCxPTnERERqY+nSyFFL91G6X/upfyeZwh8/wvt7r0Ed0GCfjBLgtptNqEQ4Z9+I/D5TGqmziTwzY9QXQMeN77hg8i+4BjSxgzHu9UAjCf1Aqe6VS1AvVUt1lrcbnebrmoxaT58u4yk5v0vsVef1WbCQ4mPysrKlAhXjDEunP4pts61rVuxUvd29Dlm3fuaQ0qFKwUZzhjiRCqqM4452eGKtU7PlZ0SvNXXpPlYeMDhnFa5D8+sfArXcy9Q8don5J10EAXnH9voEs9hdZraJiJcAWdr0APToCIA2b7EnLMp/AF4eS7s0y95X3/vLYB3F8DFOzr9UMAJWP49xhmb/MhMOHu7+Jzb7YI794S9noYL3od3jmoZf28iItI6mTQfBTdegG+rAZRcdjcr9j6L9hOvxDdsi2QvrdkFlyyn+rPpVH82neAXs4ischr/efr1IPPIvZ2tPqO2XGv4QUuxqaqWUChUe0xbC1rS9hxNzZufEZw+F992cSpPljYhVSpXrLUbndlujPHy19aMCqDEWlsRj7Wk1HeTdhlJqFyJ/n+
xIsGhTn1WVkJ1qHknBTXUnGLwp+fQ7eoz6PH1M+QcugdrHnyRRdsexqrrJxIuKdvs1+yeC7m+xDa13b4rhK3TO6QleP0XKA/AMVsm5/zlNfCfT2FQBzhlm7UfG9nFqWB5YHp8t+vlpMEde8AfZXDV5PidR0SkNdC2oMTIOmofil66HcIRVhx4Hv4X30/2kposUu6n6v0vKLnsbpaNPZHlo45mzUW3Efzye3w7bUverf+k6OunKPxoInlXnkn6rtu3yGBlXS6XC7fbjc/nIz09HZ/Ph8fj/H45HA4TDAYJBoNrhS+tmW/sCPB5qXn382QvRVq4ZDe0Nca4jTFHGWNmGGNuMca0q+eYjsBVwAfAbGAe8KQxJi774lKrciU9OT1XIDUmBsVrUlBDzF3pfC7aZwKZRRTdcTH5Zx7O6psnUXr745Q99BJ5px9K3hmH4c5r2D8iY5ytQYkMV7bt7Ew8+voPGNszcedtDGvhie9hYHsY0bQdWI120xewvAL+Nx689VSGXjQa9nwK7vkG/jM2fusY2RXOHAH3fgu79YnPlCIREZHN4dtmIEXv3Mfqs66j5PybqHz1E3xD++Hp3RVPn254+nTD1S4vJQIvGw4TWVFC6M8VhJcWE166gvCfK6OXFYT+XElkxWqnb0p6Gt5Rw8g+Yi/SdxqOZ2DvlPgYEmVDo55jFS6xbUOtdfuQKzsT387bUvPeF2Rfflqb+ruX5lVTU0NaWuJHtNbZ0tMNOB7YGvjdWru67uPGmBzgMmBCnaf7gP2BocaYw6y1M5pzbSkVrrTLgKoQVAWd/hmJkFLhSrQ4JBmVK/NWw4B1RjD7BvSm00NXUXPBsZTcPImSWyax5sEXyT357+Sd/Hc8HTc9s3loETz+HQTD9b95b25ZPtiyo9MgNtXNWg6zVyZv/PLclU64c/xWsHWn+o/p3x4OGwyPfw8nbwNd4xj8nT8KJi+Ciz+CbTr99W9TREQkWdwdCujw9I2U3fYYVW9NoXzqDAiFax83uVl4enfDGw1baoOX3l1x5TTPf2Q2EiGyag3hP53AJPTnSsJLo5fofeHlq9ZaF4DJSMfdpRB3l0LSx47A060jnpFD8GwzEFd6mt5Us37QAn8196078tntdreqoCVtz9EEPvyK0Pfz8G7VP9nLkRYsSf8uYn1VegHDgdXAy+BUswCxErTdgTOBMLAc+BXIBfoCfYD/Aw5vzv4rKReugDOOOVHhSobXGQGdCuHK4mjlSrcEV65ErDOOekNTYdKG9KPTpGup+WEeJbdOovSOJyi99xly/r4beWceTtrgDZcZDC2EmjAsKIGBHeL0AaxjVFd4eGZiQ7rNZSMRXviigr7l5exbU07lJ2VE1lRgawLYUBgbCjk/JEUiuPKycRXk4W6Xi7tdPu7CgmYp073zG6e/yT+23/hx542Cl39yqkqu37XJp90gnxvu2BPGPw0XfQCTDkhO6CQiksr0hjjxjMdN3r9OIu9fJ2FDYcJ/LCe48A9CC/8g9OsfhH5dQs23P1L56sdOWWqUq7AgGrQ4YUttANOzCybdaTBmrSVSWrZWlUn4z5WElxZHq1Cc2wSCay8qzYu7UyGeLoWkbb8l7i5FTpDS2QlTPJ0LMfk5tV8vsca1dScCydpibxLXrWqJNcRtTaOe03YdRbnHTc17nytckUaJQy/YxigC2gMzgTnR+2y0aiUNGIvTBmUecJ619l0AY8ydwDnAIGPMjtbaZtsjl5Lhyuoq6JLA7VtFWakTrnTMgvQE/60sKXcqhrZYb5fa2tKGbUGnSdcRWLCYNQ++SPmzb1P+7DtkjBtJ7gkHkrX7Dhjf2mnG0FhT2xWJDVcemO70XRmT4ObAddlQiOBvfxKct4jgwj8ILvqT4G9/Elr0J8HFyzg5GOJkYPWdm//a7sJ2ePt1x9u3O75+PfD26+H8uUcXjHfTX0BzV8I78+G87ZyxyxvTJQcOHwLP/ghnjYxv+LdFO7hkJ2d60ZM/wLFJ6kW
TCiLW6UMzbxXMK3Ei+vYZzvfJ9hnQIdP5u9HPxyIiiWM8bjy9uuDp1QX+tna3d1tVQ2jRn4R+XbJW+FL90VdEVpbUeRGDu2sRxucl/OdKbHXN2ifxuHF36oC7SyG+rQfi3qcIT7QCxd2lCHfnDrja5zc4IFGw0jixqhaPx1O7ZahuNUtLHvXsys/Bu/2W1Lz7OVkXnaCvCWm0JH3txE4afafJUpzKlLqPDQVGRm+/aK191xiTYa2tAl4DDgUGAP2BVhquRN/kJWMccyqEK7+XJWlL0CrnelPhSoyvb3cKb7yAdhefTNljr7Hm4ZdZfsKluNrnkXPIHuQcNb62mqV3vjPG98eVcEh8lr+ekV3AZeDrJYkLV0LLV1Hz3c/UfP8zgR8XEJjvBCoEQ7XHuPJz8Pbsgm/oFszfdmfeLG3P6X/LpVuPHFx5ObjzczBpaeBxY6IXgPCaCiIlZYRXr3Euy1c5Yc383/G/M4XyaHd/ADxuvL264t2iB74teuIb2BvfgN54t+iJK+OvPZF3fuNUbJ28ThPbDZkwAp6bDfd8C/+NY/UKwPFbwse/wrVTnL45yfg3kSy/r3G2an31B8wvgcrgxo/vluP0qNm9jxMqJmLrnYgknzEmVX5rKHWYjDS8A3vjHdibjHUei5RVEPp1STRwWUJwwWIIh0nfbfvaihNPtPrEVViAaaY367EpObFeInoT3Th1Rz17vd5WMeo5fc/RlF9+L+FfFuEZ0CvZy5EWJhbWJknsxIXR63Kgep3HBuH0YlkDTIveF/vJegnwEzAOaNZxRykVrhTEKleqN35ccyvKgpnLEnvO+vy+BkZ3T/x55612rvs1MFyJcRfkUnD+seSffSSVn3xL+dNvseaRV1jzvxdI22oA2YfuSfZBuzKoQztmJ7CpbU4aDCl0wpV4sMEQNd//TNUXs6j+5kdqZv1EeFmx86AxeHt3xTugF1l77oh3i574+vfE26c77nynHMtaOPJxpwJhwGGbPp+7Q8FGHw+XlBFcsJjA/N8JRi+B+b9T+eFXf4U7LhfeXl3wDexNWffe1Czvzfm79CLX9MDp67RxnXPgiCHw9I9w1oj4Bh7GwI27wm5Pwr8/gicPat3VGRELU3+HSd85oZLbBdt1caqF+reDLdo718Y4VX2rqpzrpeVOj5pnfnSem5sGu/Z2pj7FKsZERCQ1uHKz8W01AN9WAxJ2TgUr8bOpUc8toarFt8cO8J/7qHnvC4UrstmqqqrIzEzaNLHYbxdivzn21Lkd+0a3NZAOfAH8vM7zqoFA9PYmfpW5eVIqXKm7LSiRYpUr1ibvTVx1CJZVJGdS0LzVUJj5V7i1uYzHQ9buO5C1+w6EV5VS/tKHlD/7Nqsuu4tVV9zLecNG8GzfPQjtvROe7EaeZDON6upUAFSHmr7NKlJZTc3MuVR/9T1VX31H9Tc/YiudL1Jv3+5kjNmGtK0GkrbVANKGbbHJfijfL4eFpXD6tk1bV4y7IBf3iCGkjxiy1v02GCK4cDGBn34j8NNCAj/9SuDn3zDvfsFlkTC8Cgv/4cbbqwve/j3x9evpXPfvibdXV1wFuWv9IDZh5F/VKzft1jxr35DOOfDvHeHST+DFuXDo4PieLxmshbfmwW1fOT2JOmTAOdvBMcOg4wYy9Nw06JX/15+P28qpbpnyO7y/wLm88pMzQvu8UQpZRETaqtgb/ZZSRdGS1a1qAeqtaokdl0p/H+7Cdni3HexsDTr3qGQvR1qYiooKsrKSPn2iIno9DMgCsNYGo+OXt44+NgNYGL1dN5SJLb6sOReUUuFKXpoTNSVjW1B1CMoCzhqSYUm587edjC0Q81dvftXKhrjb55N/2iHkn3YIgZ9/pfzFD6h55gP+OesaFr3hI2P7YWTsPIKMsSNIG7YFxh2ffQw7dIOHZsKsZbB9t4Y/z1pL+M8VVM/6mepvfqD6mx+o+e7n2goQ35C+5By5NxmjtyZ9+63wFG3+J+61X5zmrXtvsdl
P3SzG68E3wNkWxAG7ADBnJez/eID/Z++8w5yo9jf+mUndXtll6b13EFCQDtKboIJdvHq59n7tvf3svXutWABFEFAREBALvUnvsPTtLXXm98dJYEXKlmSS7J7P8+SZLcmcs9lkMued9/t+76m1j0vsu3Bv3o1r625c2/b83emCuMpmrp+BpX4tLA1qEVU7nbscNflqQTp7mqVTr25sUK+CTWwL322BxxdDn/pQw+Djt+504c3JR8srQCsoRisoQssvEttiB3pxCXqJE63EgV7iQHe60d1udLcX3G50lxtd00HThJLiv6kqbl1hW45KXonCzXYTdVPMZCRbMK8zg8XMUasVJcoqujn4bmq0HSUmSmyjfdvYaCyx0QxIjmZQvWjye5v5aI147f/0RflEFl2HY8WwJUsIrtuywOUFm1m8Xm0mERDdJFl0c8qIrdqOIokknJFlQZLTIfNVQs/ZXC0ejydsyodsg3tQ+MS7eHYfEDlCEkkZKS4uDqW44u8GtAVR4tMMuEZRlGcQJUK3AP7wg2W6rjsUcTD0f3DW4URey7FATiysxBWTCon2EJQF+YwGR4pCJ674OwXVNVhc0XWxkBrbIvD7tjZvSMr913H42mu55f/W85BjERkbVpL9xDvwxDuoiXHYz22PvUNLbB1bYOvQAlNSYKw759QWQt2fmacXV3SvVwTObtmFc91WkZmydgtef+ic1YK9Y0sSJ1+MvVs77Oe0qfT8vBp8v1WIBaF4rb3yJ0RFW7loTGPi7H/v8qS7PSJ0d9ueE+G7uw/g2ryLop9+A5ebfkA/wPMe7IqJEm0da6Vhrp124mv/97XSUKPPkpZ7BlQFnhkAgz+HRxbBG0Mr/nfrXi/e7Hy07Fy8x3LxZpW6HctFy87zfZ8nBJXcfPTiMhyIFAUlyo4SZUWx2VCsZhSLRWzNZnFQU1VxP98Jbk6Jzr5ccfLbMlonFS8c9uDZ58bt9oiOUW4PusOJXuI8ywROmo7Nysi4aEbGxJBjjmavFssqSwwH0qNp0ygGW0I0anwsalwMalw0WnQ0qwtj+PlIDCvzozisR1NijcJhsRMfbSLaLAQWp+/mKtXls0a0EFk6ZcCwpqHNxtE1TYhbDqd4/kp8W6cLzeEUQpfTDS6XuJ+r1M3tRnd50N1u8G11twfd7QGvF93jBY8X3esVNVxe799Fs1Ohiv+7oiq+14AJxWISrwl/ppLVevz1gs2CYrGgRtlQbFZxs1tPCGtRNpQon9AWZT/xfZDEaYlEEplUd2FFK3ag5+aj2G1gt4njaIjFi9O5Wk7V6rn0/Y3CNrA7hU+8i2vJKimuSMpFYWEhsbEBjSspM6XaJn8L3AnUBq4GWiPElb5AEvAnopOQH7/A0gkhsBwAAhoOElbiCojSoFA4V0CIK2UNdQ00e33iitFlQYcKodAFzVKCN0bTVJUtDdvzc4f23NdThL+W/LqKkl+W41i+geK5vx6/r7lBbazNG4gFep304wt2U1IcSkw0arQdNSYKrJYznjQk2KBlDVix04kr9QieA0fwZB7BkynaJ7o278K9dTe6w1dup6pYmzcgul83bB1EiY+1bVNUe2AVkD8z4XARjDKu5Po4G4/CDzvg1m6n7hCkWMxYm9TD2uSfKcC6puE9losn8zBf/HiIjesOM7nWEazHxHNavGHbCVGqFGpiHKaaqZjTUzClp4htWjJqYjymhFix0E+IQ42LBpNJOCF8YgRAPbeH/9Z28/FvHhapXrqliYWzVlTKNVJYLBwlhSecJUIgKTgulGh5haddCKvxsZhSElFTEzHXSRelXUnxmBLjUZPjRSvsuFjUuGifKBGDGhstFrc2a5lPXvOd8OBCmLFFZAK9NAian6WDlq7rQiRwuNBLHGjFJeiFJWjFDvEcFJWgFRajFRaL56HI57ApKMZeUERqXhGHDx3Csb6YIyuLiHUWoni9fxujLuLT6OqTxlaiSi/k7ShRNrDbKVYt5OtWcjQrxzwWcjQ
LM1QzKXEmGqSaaZhqwmYz+f6HilDJFN/tuIsHQEf3aieEC7+Q4RM2hPDhE5tcbp9Y4j4umugOp3hefN8HBEURHc/MJhSLWXTeMvnCpU0qimoSf4+qiu8V5URl7/F/mni/nBBiNPBqPpHGc0K48XgrPe/jIkyUHdUv8kXZhSDjX1jYfa9Tuw3FZhF/l816fBs7dgDmmga1cpNIJEGjuuWraNl5uP/ajvuvHcdv3p37//lZb7X4hGrr8a1aMxVbj47YenfG3KKhoc/Vya6W0iKL/39nZPmQkuBbHJ/c6lsiOQtFRUUhLwvSdd2rKMp9wGdATaD0pdhC4Gld13f47quXeq8PBqKAuZzoMhQQwk5cSbKHJnMFQtsxaG++sN0bXfrgD7MNpqhkMUGLFNGOGcCcnkLchQOJu3AgAN68AuEaWbMF5+pNuHdl4vhzHVpuwel3ajahxkSJRa5/a7OKBadvsflyXjFmp4N9Jz3UlFEDa/MGRF09RpTNtBSlM2pM8PNgZm6FGIsIHjWaV/6EeCtcU8YOQaVRVBVzWjLmtGSGNmvJ0/8DrRm8MOjEfTSHE+/BY0LI8otZB47gPZyF53AW7h378BzJLvcHeB/fjXeE7++0c7RZUeKiMcXFoibFoSbFY2lURwg8SfFCQElJwJSSeOKWnPCP9uHB4HAhXDFDdAC6rZvIrylLZx9FUY5fgSOxYv3p6wErDsAd82F7lk5NiwtnbhEJ3mL61yhiYFoxre2FKA6HEGuKitF8oo1e7BSiTonY6iUOYp1uYpxF1PQJHh6HmxKnF5dTCAZZmher7sGkgIKvFErzl0SdEFoURbg6FJPpuJsDVUWxCgEA39YvAqhxMSg1bOL3dl/JlF80iDrxtWq3/l1MsFpOuEGsFt/PhGsEv3vEvzXYCaLruhBZ/OKRyy0EI58DRytx+kQkf/mZcDMd/77YKQQ3R6mfF4v7ebNyT+zDL0b53Dqly//s53WQ4koEUtUXzpLyUdWFFd3pwvnrKtxrtuDeuBP3X9vRDp5w8isZNdBaNMHZvzdFyTVwFrtwFrlwFznxlLjwljjRSoSTEaeLGlv2UWfx+xQ8/T558cnsbtWFQx26UNCpE1E1k0i26yTZIcmuk2TXSY7SibUEvhT2VOVDJwsuQQ/FdfsuuFjCbkkoCXNC6Vwpja7rCxRFGYdoTNsckadyDJgOzDrpvrqiKH0A/0psIZAVyPmE3TspOQr2BTRW5uyEhbiSJ0qCVIM/D7eWsw1zRWmTJsI7TxUabEqII7pXF6J7dfnbz7XCYrFI339YOBKKSoR4Uuxb+Pmv2vsXgk4X5rRk1MZ1UWKjyfLGMP1AHBf1T6N5mzTMtdMxZaQG3I1SVlxemLNNtM2NCv56/m/syhGulZu6Vr4cKT0GLmsHH60R+/MHrKp2G2rD2lga1j7tY3VdR8srRMstQMsrwJtfeNxtgv9qv66dqIj0OQj2FFl4ZpmZ8xpZuPpc+wlXhd2GEhuNKT4GxXb2rkehYHu2EFZyHfDxKOPag5embjy0rQHbsxUOeWzEJNm4d1Ayg5sEbgxdh7WH4atNIlS30AWDG4vXiAzWPTWKooDfIWOAuOtH13XwOYMUe3i+byQSSdnQdR2PxxMW+R2BQtNhV67C7uW7if3+B9IX/IytIB9NUTlWsx77a3dge+embEprxvrkJmTbE/++g3jfDbCoOgk2iLfpJNh8X1t1UguOkLpmFRlrV9Bo/Z+0/+MnALalN+X3pufx7TkXkhdzolujRdVJ9IkuyXbx9ckiTI1onTY1NKIqsLo6U/lQMENxdY8Q2xWzLDOVlI8QZ678DV3Xf1MUZQ2QAViAHbqun+5qbiLwNZAOfKvruuc096sQYSeuJEWJE3QjibOKjjKhFFf25YWuU1ByFKQEuZNW2zTRxndfftmzGdTYaKzNGmBt1qBCY9ocMOUdSOsAHbpXaBcBZfEeyHOGpiToo7VgUeGKdoHZ3+TO8Pl6eHUZvDjo7Pf
3oygKpsS4422py0oboFkjeHY1nNNAZOpEAisPwjUzwazAV+PE+8BIHB54dxW8uRy8uvi/DWgE/50P/54t2mrffi6YA3COpijQoaa43XkufLga/rdGiHoDGopytLbplR9HUnkURQGfm0cikUQmVSlf5XCRwrojKuuPqmzdW0yNhQvpt3I2rQ5uxq2aWdr8fH7pPJR9zdoRFW8n3qqTaNNJt+k0s0GCzeUTTyDRph//Ot6mE20+neMkEUaKJDld0/D8tQPHopW0WLSSJks/4fIVX5EzdiS7xlzEMXsS2Q6FHIdCjgNyHAqbs1RyHAp5zr/v3KLqtE/T6F7bS/daGm1raFgroFucztXiF1kC5mrx+Jwr5rBbEkrCnHBxrgAoiqLoul4M7DjpZ//IBNB1fQYwI1hzCbt3UrIdchzGtkVWlBPtmEOBrouyoK4hWDBuyzYmZ8a/qFp32LjgywQ7tE6D3/fDrcYMeUa+2yLK3s432LlQ4ISpG2FEsxMurcpSIwYubwcfrIYbz4FGSWd/TGW5ozv8uB3+uwDmTBBdbMKZ+TvhP3OhZgx8MhrqJxo7/vZsuHEubDoGQ5rAfT1PvPdmXQIPL4I3VsAfmfDaYKgdQHE30S5Em0md4OO18P4qGP4lXNgS7jnv9K2mJRJJ2YjkhbSk8kSysFLggg1HhZDiF1SOFCm037OGYWu/58LNv2B1uyhu2BDHXZNJGNuX8TUTmHBcPwh81wtFVbG0bYqlbVPibrwEz/a9FLw6BeWLqaR8+x3RV44k5rpxmFIS//FYjyYunOU4FPYXKCw/aOLPAypvrLTw+kqFKLNO55oa3WoJsaVlioapnFpIaVeLxWL5R6tnj899UpEORLpfXClLrbJEUoqioqKwEVd85T4KYOJEaK0/Y8UvsPi3iu8hQWm5F3bLk6QoUT5R5IZYAy+qpcfAkWLjxitNrkNY6I3utuHvFDSyWfDHap4i2rmuOwLDDRjPz3l1hGvD4RHupFBR7IZ5O0VXJqM/v6ZuFO+nqzsEdr/Xd4JP1wn3yssXBHbfpyLGCo/3hatnwtsrRYvhcGXudrhhDrSqAR+NgtQgO8NOZtpGeGChKD/730jod1LGT5QF/m8A9KgL9y2AIVNEfs7ARoGdR4INbu4KV7eHN5bDB2vEc/OfLvCvTqF9T4aKAqdw8OU7ocQDJW6xdXkhyiyclLE2sU20i8+m8p6ISySSqkskCSsuL2zJPiGirD+qsjP3xAGtQbyXK/b8RJ+Z/yP64AGU2GiiLh5I1EUXULNds5D9beYm9Uh69b+4b5pI4WtTKHpnGsWfzCL6ypHEXjcONfnECbtZhZQoSInSaZKk06ee6BCb5+S40PLHARMvLBOLmjirTtcML91qaXSv5aVJkl7ui8lna/VcLlfL8bKgaviBLKkURUVFpKeHhyVZURRV13UNKEuJT1BEFT9h905K9pWdZ5cYK66kxcDmgHa5LjvHOwUZLK4cKRYn+E0McK5YTdAyVThXjOTcOqIsYsWB0GRd+Jm3UyygRgWh5fWZ8GpCXOqcAe0CfPyrESPKjN5fLXI1GhvgXunXUIiBry8X7X+NeO2Wl1/3ws0/QPua8OloY49jRS548BeYvgm614ZXBkPNM1xUGNUc2qfDDXPh2llwQxe449zAL+bjbPDfnjChDTz1Kzz/O3y5Ae47H4Y2Mc6laCROjxCTVxwQIvbuXNiTC8fKGdhuNYnMnAaJ0CBBvObb14RmyfJCo0RS3Qjn4FpNhz15CuuOqqz3iSmbslTcmphjapROuzQvI5p4aFtDo5XjAPpjr+JctBJL26ZE330XUUN6okSdop1hiLA0PUlkeXsqJTMWkPzRE1hanLkzQYINBjTwMqCBF3BztBiWHTDxxwEhuMzfI5ZgKVE63Wp5jztb6saVT2ypdFaLP+BcZq5IykmYZa5oiqKkAj0RXYOsgAvIR7RmLvTdCoBioAgoARw+USZghJ+44jum5pQYKzakxYhMjFC
w1xfgW9fgzJVtBoXZ+mmXDjM2iw9go4J7u9YGkyJKg0IprszcAhmxcE4tY8f9ZTfsyYO7zwvO/q/vLNwrr/wJrw4Ozhgn81AvWLQH7p0vckyMDoE+E2sOwb++h0aJ8NFIY4WV7dlw3fewM0fkm9zctWwiSYNEmD4eHvGVCa05LMqEgpHDVD8R3hkOv+2DxxbDf+ZA7/rweB/jy6YCjVcTGTtL9sKyTFh9CJw+t3VGrPj7BjYWAkndBOFKibYIt0qUWQglJR4odArbfIFLXGTYm+cTZvLE81biOw+2m6FNDSGOda4lXHpJvosTxW7xuH35sD9fbDPzxT6L3cLJVuyCYg98PkY4rCSRRTgtqiXG4A+uBcIiuNblhd8zVVYfNrHuqMqGoyoFLvG6jLbotEnVuKKNh3ZpInekZowQDXSPl6IPv6XwxU9AVYl/9D9EXz7c8G5t5eG4yPKvC8m+5iGyxt1O0jsPY+vRocz7qBENw5p4GdZEfDBkFijHXS1/HlCZs0MsyTJiNXrV8XJ1Ow/1E8p/gb0srpbSQou/LEgG2krKSziVBSmKcg+ivXIjoG45HuoAAnrGG3biiv/kMMvodszRJ048ow3u5OJ3rhgurvjbMKcYM167NLEQ35VrjMsBxOK2XboQV0JFrkOIAVd1MF4I+HCNcC5c0Dg4+0+NhivbwzsrxWLeCCdJjRi4/3y4+2f46i/hhggHtmbBld+J5+STMSLzxyiWZ8KkWcKe/PlYUe5THuxmeKY/dKopyomGfQFvDYOONYMz3/PqwvcT4JO18MIfMPAz0Z76+s6RVSrk1WDZAdEJ7YftcLRYiLmta4iOWl1rCUE1UEKVrguhZO1hIeStPQyfbxDlVgCJNkARx5zSRJmhVpwQdOKsotQoxio+64wUACUSSfkJpzIgXYd1R1W+22Zi7g4zuU4Fs6LTLEVnaGMP7WoIIaVRon5Kcd+9fhu5/30Zz4bt2AZ0J+HxGzDVipx2cpa2TUmd8QrZVz1A9pX3k/B/txE9dkCF9lU7Tmdscy9jm3vRddidpxwXWr7bZmbaFjOjmnqZ3NFNnfiKVTGcydXi/9rjEB8Yuqw/lZSTcAi0VRTFDLwM/AvRJQiEK8UOCNuY+Ln/bMdb6haFcLIElLA7jfWXBeUEPq/qjJRux+xvLWsUe/MgNUqc7BrJ9myItwlhyQjalQq1NUpcAXFF951VomTC6OcYRMaEWzO+S9DWLPh1n3CtBLN84LpO8InPvfLakOCNU5qLWsE3m0WJSf+GgQvqrSj78+HyGaIj0+djxOLVKL7fCrf/BLXj4OPRlXP8XdQaWtYQnYTGTxUuocvbBadsx6zCNR1haFN4Ygm8+Ido4fx4Hzi/fuDHCyTbs2HKehFSfaxECEL9Goi/pU99UQYVDBRFiPDFbjhQIMbVSplZi9wiXBHEa7FdusjRGd5UuGUkEklkES7CSmaBwsxtJmZuN7M7T8Vm0ulX38vIph6619LOKoprxQ4KX/qUog++QU1OJPHNB7AP7RmRDixT7TRSpr1IzvWPkXfbc3gzjxB744RK/S2KAg0TdRomepjQSgj176+18OUmMzO3mRjb3MP1HT3Uiq1cVMTJrhYhsIgPDU1VcbvdAW/1LKm6hElZ0HDgP4AG/A68CSQBryJElk8R5UAtgNYIZ4vJd18XkBXoCYWduJJUKnPFSEItroTixNffKcioz7YmyeIK6rrDMMbA7JFz64hyh+UHoE8D48b1M3OLKBNpY7D1/qM1YDMF39mREg1XthMhszd1hWYGOKEUBZ7uB4M/h4d+gbeHBX/M05FTApd9Kxa8U8cZV86o6yJP6KlfhTviveEnjp+VoW0azJ4At/4o8ltWHBTPdbCEyZqx8PoQIZg9uBAumyE6Wz14fnh1FXJ64McdogX5H5lCvBjQSMy1b4PgOh7dXvhtP8zZBvN3iRNvEMfvy9tBtzqiPCg9RoR3LzsgSgLn74JnlopbsxQY2BD6NxKOpHA
qp5NIJP8k1MJKvhN+3GXiu21mVh4SV2jOyfAyqZ2TCxp5iSvjZ4Lz11Xk/fcVvPsOETVhCPH3TkJNiAvizIOPmhBL8idPknf3SxQ+/zF6XiHxD1wXsP3XiIZ7z3VzTTsP764xM3WzmW+3mhnXwsP1HTykx1Q+j9MvtHgRryuz3YbJZAp8q2dJlaWoqIi4uNC9lxVFqQFc7ft2G3CFrus7FEWpjRBXjgEv67q+w3f/JsAjwMUIR8tFuq7nBXpeYSeuxFuFrTonhOKK0ezNF4GjRrMtGwYFuDvImTCrwi5vdKhtl1piIfT7fuPFlewS+DMTJncxNrQz1wHTN8PoFifcYMHkus7CvfLyn/Dm0OCPB6L9863d4NnfhHvDyC5Uflxe+Pcc4SL4fCy0SDVmXK8Gjy4WbY6HNxWdfgJZTpNohw9Hig4/L/4Bfx0VAlYw85l61YcfLxMlZm8sh4W7RfvtK9qLY0eoyCoWz/Nn60W5at14+G8PGN8quF2g3F6R3zJnO/y0Q3SeiLUKZ0zv+qKle8YpzmmiLOL3vesL59HOXFiwSwgtb68UQnNqlAiHHthI7CfK4FJYSeWJxCv+krITquBatwa/7lOZuc3Mgr0mXF6Fhgkat3RxMaKJl9px5VvUF0+bR95dL2JqWIvkr5/H1q1tkGZuPIrVQsJLd6FE2yh6bzr2Qedi7RrYvy89RufBHm4mtffw7mozUzeZmb7FzMUtPPyrg5sagfgM8mWumKxWLBbL31o9lzkUV1ItCVXmiqIoiq+NcjrQHxFQO8UvonAic8UNxPkeY9F1fTtwmaIoOcANwAfAuFKdhgJC2IkriiIWg9khLAsyErcXDhZAPYO7yGQVi4W/UWG2ftqlwxcbhHXdqAVTlEVcqf1tnzHjlebnneDVYXATY8f9coO4gn1NB2PGS44SJR6vLYP1h6GtQZ3ZrusMP+wQDovudYxteazr8PAv8Md+0YraqLBijwZ3/AQztoh2xvf1DI4LQVWEE6lTBtw8F0Z+Cc/2h5FBLG+zm0WL7dHNxf/00cWilfijfUQ4tZHszoX3VonxnV4Y0FAIPefXC57rQ9dhwxEhjH63RRyj46xCBBnaVIxdHhFNUUQJZuMk8VrJc8Ave8Rx6Yft8PVGmDbe+KBtiURyeowOrtV1+OuYyFGZs8NMtkMhya5zUQsPI5t6aZOqVejiUNEns8h/8HWsPTuS9N4jqNHh0wUoUCiKQtwD1+NctJK8+14ldc6bKNbAq9W1YnUeOd/Nvzp4eGu1hSkbhZtlQisPk9q7SanERTR/oG3p+vGAtnqWVFlC6FxREO2UkxFhtNuB30r93v+OcCJKf9B13a0oiknXdS/C1TIR6KYoykBd1+cFcnJhJ66AsLYbXRaUZBfuBqPFlYOFYvFdr4qH2fpply5CVrdnG3eVH0Rp0GvLxdXfhCDlIZyKH3dAnThjS4I0HT5dL8QGI5/j6zvBZ+tECcLnY40Z06zC8wNFAOsDC+GtocY5hD5aC1M2iPbFRpW5ubyizfPc7SJL54Zzgj9mj7oweyLcOBdu+gGW7hOOiGDmF9VPhI9Hib/zscUwfppovf3fHsEvu1p/BN5cLsa2mMT/9l+dgitEHy4UGULTN4ljs80kBJUxLYSgYgvQJ3WCXWQ/jWouhP1lB0SIsUQiCQ/8C1gj3CoHChVm+XJUduaqWE06/ep5GdnUS8+6XiyVWDMXvv01BU9/gG1gd5Jevx/FXnWTs9VoO/GP/Yecax6m6N1pxN44IWhj1Y7TeaKXi3+1V3hrtYWPN5j5cpOZS1t7uKadm6SK6Fe+VsyK+dQfNKcLxfULLv6vTSaTdLVUM1wuFzabgYuqf+I/I8wBDpb6uT8IxsuJkFt8wgpAHvADMAE4D6j64kqy3fiyIEURNY5Hio0d93inIIMzV46LKyFwroAoDTJy4X9eXXhlmWiROtCgUqhCl7D1X9bW2JKgxXtEwOq9PYwbE0SI501dxWJ4yR7jQkm
bpcBt/vKgbSIDI9gs2iP+zgsaw51BanN9Mg4PTJ4NC3bDw72EU8goasbCF2PhpT/gzRXCrfPKYOgQxIW5ogi3Rt8GIlvmrRXCcTGpoxCVAt3lZuVB4bxauFu4RSZ3ER2+ghVO7PaKUp2vN4oxNR26ZIh8m2FNg99tymIqf1cpSfggy4KqFkblqxS64Cdfjsqyg2Kx3Lmml0fPdzK4oZf4Sq6TdF2n8IVPKHxtCvaRfUh88S4US1guNQKKvX937EN6UvDqFOwjemOuH1w7YP0EnWf6uLi+g8Ibqyx8sNbMlI1mrm3n5vqOnnK5K13L1oPFjKlu2SzHZ3O1eDye4/eRQoskyPiPWC7fzY9fULEC8fC3UiIQ5UKFvq8DvgIPyyNeUtSJxb+RpMUY71wJZRvmGAtkGFwq1zBRLFzWHRGdSYyiY01xNfj3fcaJK7/sFuUERpcETdkAKVEwKEjtl8/EZW3hw9XwzG/QI4jlEydzXWfhEnpwIXSvLdo1B4vt2XDjHGiRAi8NMuZvLHbDtbNEadvT/WBiCMrWLSa4uwf0bgC3/Qhjv4ZbuwvnTjA7OEZZRKnQRa3g/34T4s7UjfDvznBp28rlhei6CKd99U8RGJtkh7vOFeU/lV1knI6tWUJQ+XaT6DSUFiP+lotaQUMDu6hJJJLwINjCikeD3zJFe9/5u004vQr14zVu6uxiZBNvhdv8noyu6xQ8/g5FH3xL1CWDSXjqZhRTEFsVhhnxj0zGuWQV+Q+9SfLHTxgyZsNEnef7CZHl1ZUWXl1pZX+BymPnu8r0uax7vZTMWoSt7zkVChk+U6tnmdVStTmhU4QUv6ASgy9bxYcHcAD1AH9Rudmnr3gQ3YSCdnkpLMWVUDhXQHSm2J1r7Jj780Vpg9Eix7YsYzsF+VEVaJNmfKitzSxCg3/fb9yYP+wQwZFGhhUfLhRX9//VCawhOKexmeH2c0Vr4NkGuUhAvIeeK1Ue9Paw4Ly28xwwaaZ4bt8fYUxr70IXXPWdcFa8OAjGtgz+mGeiW22Ye6kQsl74HRbthmcGBN8FlxEHL10AV7YXpWePLxEBrdd3FqJeeUQWXRdukdeXi+e1RjQ8cD5MbBOc/2muQ2SoTNsohGWzKlqIX9xKiFWhDOyVRC6KooTLCa6kggQzuNatwbTNZt5abeZosUqCTWdscw+jmnppV6NiOSqnQ/d6ybvvVUq+/IHoa0YT/9C/q527ylQzlZirR1P42hS0vELUBONO7Jsm67w6wMXrq3TeXGXB4YVn+rjOWtrl+mM92uEsokb1Dcg8ypLVIoWWqkMI28P7P/iOATuADKB0AMNBYBPQEegLfKnrurvU7/sAPX1fB3xFGpbiSlIU5DiETdrIdpE1okXZiJHsLxDCSjCv/J6Kbdmi60QoaJcO/1sj8iOMFADOrSsWg9klwe+g4/CIDh0jmxv7v/16o8jwCXb75TMxujm8uxKe+02UzRj1P26WArd3FwvvGVsCn4Pi1UTeSWYBfHEh1DbAbZbvhCu/g7WH4LXBoemIdCoSbPDqYFGy89AvMORzIejd1DW4LYlBlCJ9eaE4Vr/8JzyxBN5eAdd2gotbn/m9rekiS+WN5aIDUu04eLyPcNEFstsSiGPA/F0wa6vYurzQMlXk1YxuLlqYSySS6kuwgmt1HX7Za+L5ZRZ25qqck+Hl4R5Ozq/rDcrnsa7r5N35AiXfzCf2ponE3nFFtRNW/Fi7t4PXpuBesxlb7y6Gjq0ocFNnN9FmneeXWXF44MV+rjNmdpXMWIASG419QPeAz6e0q6W0uOJ3afldLSaTSQotEYg/1DgUlCrv2QZsBYYAjQF8obUrFEVZhxBXxiqKkgtMQ3QV6g3cCcQCu4G1gZ5fWIoryVHiJDjfKVqCGkVajBB1nJ7ABQiejf35xizSSpPrgKPF0MTgMFs/7dLEQmNLFrRNM27c8+rAC4i8iKFNgzvW0n1Q5IbBBpbmeDXRJahHXWiQaNy4J2NSRfn
INTPFfK5ob9zY/+oE83bC/QuEiNc4gGUWL/whOq081c+Y7ir5TrhihghYfWMoDDG4vKws+ANXn/lVlOt8twUe7i1avAf7M7drbZgyFpZnijylZ5aKttGDG8MlbUSItV+cL3GLsNgP18COHGiUKIKQRzf/W4OESuP0wOK9QlD5eac4BqRGifKl8a1EK3qJRCIJVnDtX0cV/u9PK8sOmmiYoPHmICd96nmDejwu+XaBEFZuu5y4Wy8L3kARgKVDc1BVXCs3Gi6u+JnU3kOUBR5fauU/Pym8NtB5yoseusOFY+4S7IN7oNiDG0rqF09Ku1pKlw7J8qHIo7i4mJiYINbgnwVfhspRRVEWAt2BJF+7Zb9D5VNgIFALuMt3O5lpwM+BnltYiitJviuP2SXGiysghIc6BgkemQXQ0+BQwVCF2fopHWprpLjSPl1cVf/dAHHlh+0iW+Y8A/+3S/YKJ9S9Pc9+32DTrwF0rQWvLoMLWxpTPgOivOL1ITB0CtwwB2ZcHBhHwpxtwu0woY1YKAebPAdcPgM2HhUdkEKRn1NWUqPh+UHCNXL/Qrjue+hdX7iIghl46+ec2vDZGJFl8sUGIaLM3AoNEmBgY1Fi+vMuISq3S4M3hgihKlCOsj25QnRbtEdk4pR4xOfWyOYwvKno2iXLfiTBQJYFRR7Bylc5UKjw8nILs7abSbLrPNjDxfgWnkp1/CkLWnYeBY+/g6VTS2JvnhjcwSqA7vWCb8FuBGpsNOYWDXCt2mTIeKdjYisP0Wad+xdbuW6ujbcHO/8RAu9YuAy9oJio0YEpCSoP/vIhs9l8vHyotJtFtnoOfwoLC0MqrpRyr3wAzEW4UhQARVFUXdcXKIpyF3A70IgTwbUOhP7xEfC8L4MloISluJJSSlxpZGDAX01feeThImPEFZdXZGQYJeT42R5icaVuvFh8rDtszELVj8UkHAdL9wV3HI8m3BP9Gxpb9hTKINuTURT4b08RevrOSpHDYhS1fNkcV30HjyyCZ/pXbn9bs+COeSIU+dHegZnjmch1wKXfinHfGQb9DQpgrizn1IbZE0SL6teWwaivhIvqhnOEayzY57bNUoRr5o5z4ZU/YdomeG+V+J1FFSL22BYi86miy9FiN2w6JtxEG44I18zuUqHk41qJ933PuoF1xEgkksin9AJSDdCCv9AF76218PF6cTp/XQc317Z3E2fQBY38J99DKygi+ZlbUUK8CPbsPYjz5z9wr9uGZ/8hvJlH0A5lAaDERqPGRaOmJWPt1hbbeR2wnNMGNTrwV3CtnVpRMmMButcb0kDf0c282Ewu7l5o5erZNt4b4vzbBWvHjAWoNZKwntshZHOEv5cPWSwW2eo5Qgi1c8WPrus5iFbMpX+m+bZfKIqyAhgFNEeUAh0FvtZ1/ddgzSksxZVk35s/x2HsuDV9r5FDhWe+X6A4WCBO8muXP6C7UmzLFp1zjBZ1/CiKuIK87ojxY/euL1ro7ssPXoemZZnitWtkl6BQB9meis4ZItD27ZUihNXIUqW+DeA/XUSpSrfaFc9fyXPAv2aJzlrvDAt+ueCxYrj8W1G68s4w6NcwuOMFGotJvAYntIHP1wtxY+I3Qpi6tqMQHirT3edMbM0Sgso3m4T7MNEuAmobJ4l8lcV74PZ5vnmq4vXYKEn8Pt4m3jc2k9iaVSHuHysW+zpaLFyGO3NEySoIIbN9umjX3Ke+2F81jRmQSCRnoXRwbSCEFbcGUzebeWOlhWyHwsgmHm45x02tWOOcTM5fV1EybR6xN03E0ryBYeOWxnvgCEWfzcb58x94tuwGQM1IxVwvA1v3dphqpYGioBUUoRcW49lzgKIPvqXo7ako0XZiJo0l5l8XBjR81tyqEfpn3+M9eAxznbK1Nw4WQxp7sZud3DrfxpWz7XwwxEFqNGh5hTgWLCP6smEo5jA5afRRllBc6WoJPYWFhcTGGtyN5TSc1Ga59M9VXde3Ac8bOZ+wFFdKlwUZid+
5YpS4sr9AbI0WOXbmiEWFkWHBJ9MuHd5aIUIfAx0keSZ6+0J8F+2Gy9oFZ4wfd4hFmn8sI/gqDIJsT8WD54uuLA8uhE9GG7v4vONcWH4A7lsgys+alNOp5dXgJl+A7ZcXim5iweRggXCsZBaITkS9QhQ4HQhiraKLz5XtRYect1bCDXPFe71vA1GW068BxFWizLvAKdooL9kLS/bAzlwhivRtIErR+jX4uxim6aLMauNRcQzckSNcfPN3CbfZqbCaRGZKarTIaRneVDhf2tQQnxdSTJGEkuoaGhppBLIjkK7Dwr0mnv/Twq48la4ZXu7u5qJ1DWPLw/QSB3n3voqpYW1ib5xg6NgA3qM5FL75JcWfzQavF+s5bYh78HrsA7phblD7jI/Vih24V/xF8Vc/UPjaFIo+mUnC4zcGrGOOliXsjKbUxIDsr7L0ra/x9gVObvjJxuXf2/lwqJOEH5aCy03U6H6hnt4ZOV2rZ+lqCT2hLgtSTjqQnvy9D11RFBO+ciEfmt/ZEizCUlzxd3swuh1zol0sig0TV/LF1mhxZUeOcI6EknbpQgzYeBQ6GdiquHGSeL4X7QmOuKLpogVznwbB75rix6vBV2EQZHsq0mPhznNFec6c7TAsyFk3pfHnrwyZAv+ZA99dXD7XxDNLxevk6X7QJcgBtnvzYMI3oiTo09EirLUqYDeL99klbYSja852kUc0d7s41rZLh2bJ0DRFlCk2SRZZRaoihAsF8Z7aXwC7c0/cNh+DNYeFKBJlFu6kK9oLp1Tqabrw+NvAtznp2OfVwOkVZZr+rVeDJLtwtMj1q0QiqSilg2sru/ALRVjt6Sh45XO8ew+S/OX/odgNqkECdE2j6K2vKXxtCrrLTdS4QcTePLFcDhE12o6tV2dsvTrjnnwxeQ+/Se4tz6IVFBFz2fBKz9G7/zBqjeSgh8SWh3Nra7w/xMn1P9i4bJaNz75dgKlBLSztwqQFYRk5m6vF4/Ecv48UWoJLcXGxYc4VTdP+8f88lVPlNHgDP6MzE5biSpRZnHhnG1wWpCjiaqRR4kpmvjjhr2mgq8rhEaLOmObGjXkq/OLOuiPGiiuKIiz8324OTivodYfF6+ee8wK73zOxOIyCbE/F5e1g6kZ4dBH0qlc5t0J5qRkLL18AV84QAsu7w8uWhTFtI7y7SjgvJgY5F2hrlnCsuLzwxdgTgc9VCbMqwp3PqwuP9YGVB2DuDvF++X4b5G0o+74S7cJ5d30nOL8+dKpZuXItkwrRqnFiqEQiqfoEMri2dFhtsl3noR4uxhkQVns63Bt3UPTuNKIuHoztXOPaAWqFxeTe/jzOH5diH9yDuHuuwdyoTqX2aWnThJTPnyZn8hPk3/8apvQU7AMrFxLn3XcIU4jLgU5Fp5oaHw1zcNPnBejL1hJ1y6UR7X47nauldBci//2kqyXwFBUVGSauPPfcc8fLkKxWK7fffvtkwHmam6vU1n9zn3TzAh7AUw6RpsyEpbiiKMK9YnRZEBgrruwvgPQYYzMy9uSKK8FGBgWfipqxUCNaLK6Mpnd9+Gw9rDgQ+G4+P+4QC8n+BmZlfBFGQbanwqzCk/1gzFfw0h/wkAGhsKXpXR8e7wsPLIS7f4YXBp25JG7VQbh3gXACPdQruHNbfxiu+A5MCnx9ITRPDe544YCqiPDbc3zuHF0X2SbbskWZTrEH0MVxyv+JVysOGiYKZ5aRHeQkknAmkhdGVZlACSsFLnhvjYWPN5hRgOt9YbUnd30xEl3Xybv3FdSkeOLvm2TYuJ69h8i59mE82/YS/9D1RF8zJmCvf8VuI+mdhzg6eDIFz36IrV/XSgXRevcfFi2Zw5DWNXQmH/4ZRddxDu6HwZGPQeVkV0vpm78kTwotgaOoqMiwsqCnn36a/Px8LBaLX0x7BiGOlBZMXKW2pW/Ok7723xyAqijKg7quFwdyvmEproDIXTG6LAiE2LHGoAV/Zn5oSoIAGoeoU5AfRRFX6NeGQFw
5r64ItFy0JzjiSvfakGDQAvBwkQiyvTaMgmxPRceaIlz0f2vhwlbQuoax41/eToi1L/4hji0Pnn/qco+DBaKVcEasaNkbzBa6P+6AW34QQvLnY6BhiAXPUKEoUCNG3IxsXS6RSCSBJhD5Km4Nvt5k5o1VFnJCFFZ7Ojzb9uBes4X4x29ETTTmBFYrLCb7yvvRsnJJ/vRJbD07nfUxuqbh2boHz6ZdaMUlmNJTsLRvjqnGqT9oFauFuDuuJPc/T+BcsKzC7hWt2IH3wBHsI4J8ZaYSnLt6HptrtWRjSX1uIOBdaMOCU5UPlRZbPB4PiqJgMpmk0FJBCgsLiY835hjgP4663W7cbjcQUF3w4QDuCwhjcSXZbnxZEEBGnFj06Hrwa+335wc/y+Fk/OJKo0Rjxz0VnWqKMMlch7FXo2Ot4nn/ZXdgS2m2Z4vn90rjXLJM/Utk11zS2rgxK8o9PUQezf0L4JuLjA9UvrmrEFg+WC2cPjec8/ffl7jhX99DiQemjD0RrB1odF20p35mqeg2894ISAt9NzuJRCKRVIJACCsL9oQ+rPZMOH9ZAYB9QDdDxvM7Zby7D5D8xbPYup85LE93uij69HuK3p+OdvDY339pNhE1qi/xD08+ZXcgW58uAHi27oEKiislX8wBjxd7P2Oen/LiWr4B06bt7B5/A19stHBtO0/QuyCGmlOVD5UuHZLlQxWjuLiY2rWNCQj84YcfcDqdOJ1OHA4HI0eOvBiIAqJ9W//Xp7pFnXSz+242IAYIuJUjbN9SyVGQGYJWvTVjRKhhjuNEsG4w8GhwsDA0zpWMWIgJoa3UT2efsLTyoLFlNCByV55eKkrAApV5M2+n2A5sFJj9nQ1Nhy//gu51Ql/mVRYS7HD/+XD7T/DhauG2MRJFgYd7i/f2//0m3t/+7kpeDW79ETYcgQ9GQLOU4MzB5RXdi6ZuFJ1nXhhkbLcsiURSdZBlQeFD6VDNiizOPBo887uFzzdaQh5Weyaci1diblpPtDg2gOIpc3DM/IW4u646q7DizTxCzr8fx71uK9bu7Yi640osHVqgxkbhPXAUx+wlFH38He6NO0j99mWUqL9f1VNjolAS4/BmVsxSrTtdFL47HWv3dli7hN8VL39mjaluTVpdN4isRQozt5sY38LwvM+Q4ne1mM3mv7la/CKLbPVcNowsC+rW7R9i5TROdABSfDe11NZUamv2bS2+mxmw+m4WXdcD/gYI29P6pBBlrvjbrR4qDK64cqhQOA5C0Ya5cZgsxDuki7KLlQdCIK40EOLKoj1wcYA+A3/cIYJ6axlUxPrbPtiXL7rxRApjW4jn6eml0DEDOhsYZgzCLfP8QOGWune+KD38d2eRx/LDDni4F/QPkjiWVSxCdf/IhFu6wq3dQ9sOXSKRSCSVIxD5KvlOuH2+jaWZJq5q6+b2ru6QhdWeCb3EgWvZemIuH2HMeG4PhS98jLV7O2L+c/EZ76vlFZJ1+X1oR7JIeuch7IN7/O33powaWDu3wtqzIzlXP0jRx7OI/ff4f+xDzy1ArVmxuuWSb+ajHTpG7HO3V+jxwSb/sXfw7j9MytfP07WpnZYbND5eb+HC5t5qey5S2tVisVhkq+dyYGSg7cn4QmjDx9J3EmH7Kkm2Q55TqPlGklFKXAkmmQViW8fANCldF86VcHE5RFlE9saKg8aP3TxFOFZ+2R2Y/R0ugtWHjA2VnbJBlFMNbmLcmJVFUeC5geJ9dsOc0AioVhO8MwyGN4Nnf4OBn4vn8oYucE3H4Iw5ZxsM/Ey8Rl65AG4/VworEolEEskEQljZk6cwYaadPw+oPH6+k3u6h6ewAuD8Yz043cfLZ4I+3qIVaFl5xFw7FuUsi9rC17/AuyuTpPce+YewUhp7v66YWzXCuWTlP37nWrYeAGunFuWeq+71Uvj211jaNsV6vsG23DLg+Ok3Sr76gZjJ47Ge0xpFgavautmRq7JkX5i+4EKAqqpYLBasVitWq/V
4gKpfdHG73Xg8nuMutepMKMQVp9PJihUrUBSltaIo5UqNVgQ23zaoZ+Bh+47y5x3kGpy7UtMocSVfbGsb6Fw5UgyFrtCH2ZamSwasOQRug12JiiK6yPy6NzAC3s++kqBBBpUEZRXDTzvgwhaRV1aSYIO3hkJWiSjF0UKgPUdZ4LXB4v+1PVvMaXyrwI+TUwI3zYXJc4SjadYlMLr8520SiUTyD2RZUOgIhLCy7KDKxd/ZySpR+GCYk3FhXp7hXLQC7Das57Q1ZLySafNQUxKw9TnnjPfTdZ3iL+diH3Z+GVtDK/8IVdS9Xgpe/BRT7TSsXdqUf65f/Yh39wFibrgk7N6X3qM55P33ZcytmxB36+XHfz6ksZf0GI2P1ltCOLvwRVVVTCYTVqsVu92O1Wo9ntviF1rcbvffSgKrE0aWBQGUlJTw5JNP0rVrV4BZQHlVTBswCrgGGBZMgSVsl2X+kpzsEkiNNm7cGtGiYCvY4so+n7hiVAkJwI5ssW2caNyYZ6NzLfhgDfx1FDrUNHbs3vXhq7+Em+CcSgYL/7QD6icEL6vjZKZtEh0FLin/OUBY0DZdlODcvxDeWA43dTV+DrO3iZyczhlCYBn5Fbw8KDBlQboOP+0U4b25DrjjXJjcGSxh3NFJIpFIJGcnEMG10zabePRXK/USdN66wEm9+LB1uB/HuWgFtu7tUOzGhPY5f1lO1Nj+KJYzL1W8uQXo+UXsqteSKavMHChQcXjB6QGHR8HpFVmKDo9Cwy2r+O/GHXwx5N+snmWjVqxO7Tidc77/igYbd+B59l7cFivl+QtdyzeQ9/CbWM/rgP2C8yr3RwcYXdfJu+cltMJikl++G8V6QkixqHBZaw8vLLOyKUuhZUr4vwZDyak6EJUWVrxeb7UKxS0qKiIuzrhF7NatW3niiSdQFAVd13/Xdf2Lcu5CB1oDDwJFwFXA9ABPEwhjcSXFJ64cKzZuwQpi8VMjBg4VBXeczAIh5BjpOtgZJm2YS9PFl7mx4oDx4krPemBSRGlQZcSVAif8tl90CTLigoWuw5cbxHNn5Hsj0FzaFpYdEO2RO2cY24b3553CNXNOLfh0DBwpguu/h2tmiZ/deI4Q38r7/3R7hWjzwWpYdwRa1YBPRoutRCKRSCIbXdfxeET72oosoLwaPL/MwkfrLfSo7eWlAU7iwqDBwNnw7D2Ed+d+Yq4wKG9F19FLnKjJiX/7eZ4TtmarbM1W2ZKtsjVbYefRFL6wRrH894O8kWQhLUYnygw2E9jMOjYTJNrAFqPR0Z1JVu0GHB0zGt0Fyw+qmL/4hgY/fsiCVv15In8IyodQN16nbz0vAxt66ZCmYTrNv9qz7xA51z2GqXYaSW/df9byJaMp+WIuzvl/Ev/wv7E0q/+P31/UwsNbqy38b52F/+vrCsEMI5NTdSDyCy2lw3H9v6+KYouRZUElJSXMmDEDgDZt2rBu3br7QZT6+PJXzojvfk5FUV4B+gI9gUnAdEVRVF3XA2o9Cltxxe9WyQpBJkPNmOA7V/bnh6ZTULQlcN1xAkF6rHgeVhyEaw0eO8EGnTJEqO1dlbjYsGiP6AJzgUElQX9mws7cf7YSjjQUBZ7uJzr03DQXpo2HhgbkAX2xQXTsaVMD3h8pBM56CfDtxeJ3766EK78TeUA3nAODG3PaEys/uQ6Ysh4+XieOHY0S4Ym+IizZKt0qEolEEtEEogyo0AV3LLCxeJ+Jy1q7uae7G3OErLlcS1cBYOvV2ZgB/WUWZhPrj6r8sNPEz7tN7M0/8YQl2HRaJGuMbgVFPc5lxNI5THxsMHFtz3AyNnAg+gP9eNqsoGUfIe/hN3H8+At6/x40feROnnK4OFCgsO6oiSkbzXy8wUJKlE7/+h4GNPDSrZZ2/DNdKygi55qH0D1eUj58DDXR4JP6s+DZnUn+Y29j7dmR6KtGnfI+8TYY2sjLnB3yRKUynOxqOVlkqYquluLiYsPElf379/Pjjz8CcP7557N
27drdiqKYytrpxy/A6LqerSjK+whxpbaiKC11Xd8U6PmGrbhS2rliNDVjYU9ecMfIzBelEUayIwcaJoZfkGaXDFi6TzgyjC5V7V0fnv9dvM4qWn724w7xeu1kUOebKRsg3grDmhozXjCJscLbw+Di6TB+WnBdHroOry4TTpne9UXuS+mW5HYzXN1BOGq+3QxvrRDdfeJtosNWoyQhmjRMghI3bMsW5UTbskWZn6ZDj7rwVD/o2yD83mcSiaRqEW7ZDlWVQAgr+/MVJv9kY1euwkM9XExo5QnCTIOHViQCENXU4F8B0XVYfwhqAJ+sN/P6DDtmRefc2hrjW7honqzRPFmnRrR+/JzR2/Q6jg1ZTdHEOzE9+h+iRvQ5bTmRdjiLos9nUzxlDnphMbG3X0Hsfy6mlkUF/Gs1D4UuWLzPxLzdJr7fbubrzRbirDpjm3m4ob0D903P4Nmxj+RPnsLcqE7Qn5fyoHu85N76HFgsJD5/xxkdNbFWWQ4USE5VPnSy4FIVWj273W6s1uDa7vyll5mZmaxYsYLatWvTr1+/478uz75KOVQ2I8qCagHnANVHXEm0i8VJSJwrscIdECw0HQ4UGt/lZUcOdDK49KYsdM6AGVvEArVegrFj9/GJK4v3wNiW5X+8ywsLd8OQJmd3NwSCXAf8sF04IqKqSAZZsxT4ehxc/q0QWf43ErpUMgPnZLwaPLgQPt8AF7aEZ/ufPv/EahLP77iWoj3zb/vEe+fXvTB909/v1zAR2qTBmBbC4dJSlv9IJBID8dWfh3oaVZZACCsrD6ncNM+GV4N3hzg5r3bkhV+qcSK4UisoQk0I/NVqXYe1R1R+3GXix10mDhaqfJDWkHN2LuOpey6lX30vCbbTP95UI4mUb18m94YnybvtOQqe+gBbz46Ym9RFibajO1x49x3CtXoTns27QVGw9e9K3B1XYml5aqdLrBWGNvYytLEXpwd+y1SZu9PMjGVFdHjgUdpvX0H8Ezdh6xmkNoOVoPCNL3Gv3kTi6/diyjjziYlHI2IcVJHGmcqH/K2e/ferSq6WQHP48GE8Hg+1atWicePjbVnL9cFXqvQnG1gO9AGCooqGrbhiUkU75qwQOVfynOLqdDAWsEeLxKLcyLIgh0e4ZcZXQEAINv6F9IqDxosrrdNE9s38XRUTV/7cDwUuuMCgFszTNopgtgkRGmR7Opomi7Kgy74Vt3eGC3dJIMhzwp0/iYDZG7qIErCynB+bVOEOKu0QKnTB7lyIsUDdBHlCIpFIJFWVQATXzthq4qElVmrFiuDahomRKYQpccLaqxcENpDQ5YXpW8y8v9bMgUIVi6rTo47GzZ2dNHP2w/3cB3Sy7sVsq33WfZnr1iTl25dx/rKC4q9+xPnnOkq+nX/ib4iPwdqxBfYh5xN14QDMdcpuH7eZoW99jZ5Fmzj81WNoh7J4ftjd5KQM5cFcV1j9X11rt1D4ymfYR/UlakSfs97frcmwfaMoSyhuJLha/IK+Ue7JY8eOAZCamkpKSop/DhV90xUDWb6voyo9uVMQtuIKQEp06JwrILITgpEB4e8UVNvATkG7coTEF05htn6ap0CcFVYegLEGt6lVFdGO95vNFRPTftwJUWYRjhtsNB0+XSfKj6piQGqdeJg6Dq6YAZNmwksXwIhmFd+frsN3W+CJJeI48khvUfZTGWKtwqkikUgkkqpLZYNrNR1eWm7h/bUWutfy8lJ/J4n2QM/SONQAiyseDWZuM/HmKguZhSqd0r3c3MVJv/re4wG/3tg+HHn+Q4q/mEv8vWVL5VNMJuz9u2Hv3w0AragEPB4UmxVs1kotBku+W0ju3S9hSYwjYerzdLK34ZUVKqOm27mnu5tLW4e+1EsvcZB72/+hpiWT8PiNZXqM26tgVsNHHKounM7V4hdc/F+bTKawdLUYIaz4hW2n03n8e/9xuRJ4Ab+LJShv2vAWV6JCk7mS7mvbfagoOOJKZoHYGulc2eHvFGRAYGh
5ManQsaboGBQKhjcT5SILdpcvx0TTRSvf3vWN6fq0ZC/szoPbugd/rFBRIwa+HAfXfAc3zoWpG4XTpG05BY1dOfDAQvh1H7RLg/+NKv8+JBKJJNyRZUGBJRBlQEVuuGehlfl7zFzcws39PdxYwmtdVG6U42VBlTsp13SYu8PEayst7MlXaZPq5ZGeDnrU0f7hKDXVSsM+ojdF70zD2rUN9v7lP/lRYyp/YVp3uih47iOK3puO5ZzWJL35AKa0ZC7Fw6CGHh5aYuOJ36zUiNYZ1LBM+ZpBI//J9/Hu2E/ylGfLXL7l0cAs45tCztlcLR6P5/h9Qi20+OdiFP7jcF5eHkVFRf6flalTUKl9+O+fAPh7rWad4SEVJqwP96nRoSkLyvA5SoLVMWi/z7kSCnGlYaJxY5aHzhmwJUuUcBhNt9qQGgXfby3f49YfFq8Ro0qCPlkr5jnE4Kweo0mwwedj4b6esPYwDP9CBMv6X8OnQ9Nh/RF4dilc8Ll47ON9YMbFUliRSCQSyZkJhLByoFDhspl2Fu41cf+5Lh7uGfnCCpzIXKmoc0XXYd4uE6On27lzoQ2bGV4f6OTr0U561v2nsOIn4dnbMLduTO6NT+PeuKOi06/YnDWNkm8XcLTftRS9N53oK0aQMuVZTGknLOA1ouHl/k46pHm5Z6GVDUdD88/WNY38J9+j+NNZxEwai61HhzI/VpYFhR+qqmIymbBardjtdqxWK2azuIrr9Xpxu9243e6/iS9GUlRURExMTNDH8R+Dk5KSiI+PZ926dRw6dOj4r8u7O9+2HlDX9/XhSk/yFIS3cyVUZUF+50qQxJXMfEiOEm2RjWJHDtSJC98Q1C61RNnSmkOBy9ooKyYVhjQVLoki1987yJyJH3aASYF+DYM7P4C9eSIX5oZzRO1vVcduhus7i2yZ91bB+6tFkO/59UQuT604ccvwdfZavEe4VLJ9x4uRzeCBXidcaBKJRCKRnI5ACCtrDovgWocH3r7Ayfl1Iy+49nQopQJty8uqQypP/W7hr2MmGiZovNDPyeBG3jJ11FOj7SR/+BjHRt5M9hUPkPDsrcdLfoKJc/FK8p/+AM/GHZhbNSb501tO24baZobXBjq5+Ds7N/xk5evRTtJjjHOT6Q4XuXc8h+P7xURfMYK4+8tWQuVHBtqGP2XJajEyFLeoqMiwNswArVq1on79+qxfv57Vq1fTv3//0gG1ZUVFlAP1BJoCu4F9AZ4qEObiSmqUCAt1eIwpu/ATYxUZIEFzrhQYm7cCsDNHtJINVzrUFPknKw4YL66AyPb4dJ0oDSpLzoeuw5xtcF5dDKmj/ny9eH4ubRv8scKJeBvccS5c2R7eXAFL98KqQ5B/ksOpRrTo/HR+fehZF9KkqCKRSCSSMhCI4NrF+4Swkh6t879hTpokVa1SLTXe71wpu51c1+Gj9WZeWGYhPUbn6d5Ohjfxlnshb0pPIfmTJ8n9z5PkXPMQ9iE9iX9kMqaaqeXb0Vnnq+Nes4WC5z/C9etqTHXSSXz5Huyj+pyxlTEIp/1bFziZONPOPQutfDTcGBu2lpNP9r8ewb38L+Luu5aY68aV+/Xr0agS7qrqQumsltLiSulWz/7fV/R4djaMdK7ouk7Xrl3JyMhg/fr1vPTSS9x1110ddV1frSiKWdf1s+amKIpi0nXdoyhKU2C478eLgJ3BmHdYiyspIj+L7BJxldpIasYGtyyoWcrZ7xcodF2IKxe1Nm7M8hJrhZapoctd6ZIhFuSztpZNXNl8TOSfXHfqCxkBxeGBr/6CgY2Mfx+EC6nR8FCvE98XuuBAARwsFMJKy9SydQCSSCSSqoRR3RqqKpUNrgU4VKhw90IbjRJ0PhzmICmCg2tPi80KNgveA0fLdPciF9y/2MqPu8wMbODhqd4uYsvoCj4VluYNSJ37JkXvTqPg1Sk4l6wi+vLh2Af3wNKu2VnFj9OhaxrutVtxzP0Vx9xf8e49iJIUT/xD1xN92XA
RhFtGmiXrTGrn5tWVVo4Wi3OTYOLZe5Dsqx7Au+8wiW/cR9Tw3hXaj1tTsMhA24jEf8wq7Wo5uc2z//eBdLUUFhYa6lwBGDt2LMuXL+fgwYMA7yuKcpmu65tAiCel7up/MSu6rnsBdF33KoqSArwE+Hunz9R1/Ugw5hrW4kqq78CUVVx1xBVdF4G2fRsEft+n43CRCFgLxzDb0nTOgGmbQmNR9LfdnbJeLNzPdhIwd/uJTkPBZtZWyHEI94ZEEGsVAqWRIqVEIpFIqg7+q72Vubrr1eCeX6y4vfBif2fVFFYQIp590HmUzJhP3D1XH89gORU7chRu/tnG7jyFO7u6uKadJyAXPxSrhdgbJ2Af0Zv8x9+h6N1pFL31NWp6CvaB52Lr3w1zg1p4kpM4oMewJ9/EnnyFEre4+KJoGnEl+TRyH6N+yVGi/1yO48elaAePgcWM7bwOxP7nIuzDeh936pSXPvW9vLoSluwzMbZ58MJtXWu3kHPNQ+huDymfP421a8VtzbIsqOrgLx8ym83Hy4dKu1kC1eq5sLDQEOcKnHCvXHfddcycOZM5c+YAtAdmK4ryEfCJruu7z/D42kAz4DHgXET2ylRgabDmHNbiSoov5PtoCEJta8bAtuzA7zerRDgRjAyz3e77O8K5LAhE7son62DTUWibbvz4w5rC/9bAzzth9FlaQs/ZDl1rie42weaTtdA0Gc6tE/yxJBKJRCKpygQiX8XP+2vNLDto4sleThomVu2r/zHXjcMxaxHFX8wl9rpxp7zPDztN3L/YSpQZPhzqpFutwOfOmOvXIvn9R9Fy83EsWIbjx98omvYzxZ99f+JOFjvxMcnUiU3G5PWQXJhNclEWZu2E4JFjtnKowzmkTr6GRqO7lbm7zplokayTHqOxKIjiimPe7+Tc+DSmGkmkfP0E5sZ1z/6g0+3LA1uyVbplhLbLkSTwlC4fslgsAW31XFxcbKhzxX+M/vDDDxkwYADr169XgQbAHcBYRVF2AFuBg0ABIlvFAqQC3YERnAi03Q48ECzXCoS5uHLcuRKKUNs4OFokrkqYAqjoZoawU1CTcBdXMsR2xcHQiCudM4Rj6fttZxZXtmcL4e2yPsGf05pDsO4IPNZHlr1IJBKJ5O/IsqDyUfpKrt8qX1FWH1Z5baWFoY09jGlW9Ren1nbNsHZvR9GH3xJz9WgUy4klhFuDF5dZ+Gi9hQ5pXl4e4Ap6qGtBVDxTGgzmi17DyWvvplXmX3RWDtOaLGo6skkqyKZObhYmqw01rS5qjWTcyckcjk5hrzWVX2Ka8cOhWNwuhf5/eni8l6vSziNFgd51vczeYQ74+gGg6JNZ5D/8JpY2TUj68DFMNSp3Yj9zm4kch8Ilrc4aWyGJcMoSiltWV4tRmSsnk5aWxvz580lLS5sLDAHigHa+G0AJ4AWsvltpPMBG4BJd17cFc55hLa74nSuhaMdcMwa8OhwrhvQAinP7C8TWyEDbnTkQYwn/kM/a8aL7y/IDcHUH48dXFeFe+XSdCEyNt536fnO3i+1gA1owf7xWlMBc2DL4Y0kkEolEUlUpHVxbWWGlwAV3LbBSM0bnkZ6uanPxI+b68eRc/SAls34heuwAADQd7phvZd5uM5e1dnNXNzfWILb2zSqBD9Za+GqzmWK3wvl1vAzrCufXaUFy1Flsx0ANoA0wFLjPUcJXm8y8scrC6Ol2PhzqpHElw4hjLKLUJpDomkbBMx9S9M5UbAO6k/javajRlVOCNB0+3mChZYpG14yq09lKcnZKu1qAcrtajO4WVJoaNWoATAQGATcApRIZiTrFQ/IQXYF+Ae7Wdd0R5CmGt7gSbRFdgo6Fwrnie80cLAywuOJzrtQ22LnSOCkynA896sK8nYF3DJWV4U3hg9ViDqcTNOZsP+FyCSZZxcJFc0nrs2fASCQSiUQiOTUnCyuV2xc88quVQ0UKn45wEleNPp9tfbpgblqPonenETWmP4qi8O4aM/N2m7m
rm8hXCSYL9ph4YLGVPCcMbeRlUns3LVIqLoYk2uH6jh7Or+vl2rl27vnFyhejnJXqnrMjV6Vhgh6wc1jd4SL3zudxzFpE9OUjiH90Moqp8urV4n0qO3NV/q+vMyLWB5LgcTZXi8fjOX4fVVUpLCwkJSV0oYe6rucBUxVF+QOoB3RClAmlAdGIUFsPkAVsAabpum5Yy5SwFlcURbRjDoVzxS+oBDrUdn8+xFsh4TSuiGCwMwe61jZuvMrQq74ItV13BDrWNH78jjWFq+j7racWV/bkwsaj8MD5wZ/L/9aC2yuDbCUSiURyamRZ0NkpHVwbiG4ZM7aZmLPDzC1dXHRMr15X/BVVJea6ceTd9SKuJav4s+E5vLrCwvDGHq5uGzxhpcQDz/5h4atNwmnxyfAzt7vOd8KygyY2HFU5UqygAOkxOt1qeemaof1DTGiVqvNQDxe3zbfx0y4TwxpXvMxrV65CmxqBeV1ouflkX+trtXzvtcRcX/5Wy6fjo3UW0mM0Bjeq+iVtkrJzOldL6S5EBw4cID09BPkNJ6Hr+j6EK2Wpr2OQFVARpUFOXdePHyQURVFKfx9MwlpcAdGOORSZKxl+caUosPvNLDA2b6XYLcYM9zBbP+fXE4lDi/eERlxRlBPBtnkOSDjJdekvCRrSJLjzyHPCR2tgcBNokhzcsSQSiUQiqWoEMrjWz65chSeWWuma4eVf7atnTkXUqL4UPPcRx96Yyl1De9I8WeexXsErjXJ54YafbPyeaeKadm5u6XL6sqPDRQqvrbTw/XYTTq+CWdFJjRbrqSPFCm+tttAp3cvrg/7Z2al/Ay9mRWdLllphccXhgf0FCiObVn4N5/x9LXn3vYp3/2ESX7+XqBF9Kr1PPxuPKfx50MQdXV2VculIqj4nu1r27t3LzJkz6dmzZ4hn9nd8bZdPqxgYJaxAJIgrUaKVsNEkR4FFhUMFgd3v/nyolxDYfZ6JXbliG+5tmP0kR0G7dPhlD9zSLTRzGN4U3l0lhJRL2vz9d3O3Q7u04Atkn6wVNd03nhPccSQSiUQiqWoEQ1hxeeHOhTYsJni2jyskpcvhgGKzYrliNNrzH3JOk9+4555ORAVpNaHr8MBiK79nio5MZ+rAs3CPyt0Lbbg1GN3Uw/AmXtqlaceFmGI3fL/dxBO/WXlhmZUnern+MZZJFf/nijJ1sxkdhY7pFd+JZ+9BCp58D8cPSzHVTqt0q+VT8dF6C9EWnYtaVE+BUFIxDh06xKWXXsqnn34aduJKOBH2Hw2hcq6oCqTHBNa5ouvGO1d2+NowR4q4AqI0aM0h4RwJBe3SoXmKCJMtrXNm5sOawzC0aXDHL3aL3Jd+DaBNWnDHkkgkEknkIsuC/knpvIBACSsALy+3sPGYypO9XNSMrdptl8+ErsNzjS5ka81m3PfVg9RYtSxoY72/1sys7aIE60zCyurDKrf8bKN+gsbMcQ4eOd9Nlwztbw6XaAtc1NJLixSN7Tn/fE0s3ifcLt1rV0wYyS6B11daOK+2l/Nql78sSCssJv/ZDzna/184F60g9o4rqbHg/YALK4cKFebuMHFhc89pGzdIJCdz6NAhLr74Yl588UUprJyFsBdXakSLzBXjzDwnqBkb2MyVfCcUuoztFLQjR5TZNEg0bszK0rueSDH/dV9oxlcUmNQRNh6D3/af+PkPO8Q22CVBn6+HHAfc2DW440gkEolEUpUoHVwbSGHl130q/1tv4ZKWbvo3qN4ZFV9sMjPjQDw7nnsWS7P65Fz3KM5FKwI+TpEb3l9roW89D9d3OLPD4sVlFpKjdD4c6qRe/OkXDN9sMbH+qIm+9f7+P8x1wBO/WWiYoHFurYrlpby6wkKRG+49t3wlUrqmUfzVjxztcw1Fb35F1Ije1Fj0P+JunohiD7z68dlfZjTgitbStSIpG0ePHmX8+PE8/fTT9O3bN9TTCXvCXlxJiQK3Bvmus9830ARaXNnn7xRksLhSN0F0XYoUOtS
EOKvIXQkVo5qLMOX3V5342Zxt0Co1uEKVwwPvrITz6oiORBKJRCKRSM6O1+vF4xELxsq2Wi7NsWK4d5GNJkka93R3B2SfkYqmw4drzXSp6eXq86NJ+fxpzI3rkv2vR3AuXhnQsb7dYibfpfCvDp4zihW6DuuPqgxu6D2rE6NVqsbYZh4mlcrLySxQuPx7O9klCs/1dWKrwPnyX8cUvt5sZmJrzxmDdk/GtWw9x0bcRN7dL2KqW5OU714h8cW7MKUHpxNLkRu+3mxmUAMvdc4gQkkkfrKyshg/fjyPP/44gwYNCvV0IoLwF1eixTZUHYMOFQbONZPpy2+pG4I2zJGExSRaMi/aExrHEggx6or2sGA3bM8WuT8rD4qA2WDy9V9wtBhukq4ViUQikZwFWRYUvDIgEGLCfYtsFLjghX7OiLpQFQxWHVLJLFQZ30IIHmpSPCmfP4O5UR2yr30E56+rzr6TMjJru4k2qd6zdmRSFEiL1ll/VMV7FtNJixSdJ3u7MKvg0eCLjWbGfWvnSJHCO4OdtK5R/pPOrdkK1/9gJyUKbuhUNvHNs+8QOf95gqzxd6Jl5ZL4yj2kfPMS1g4tyj1+efhovZkCl8KVQezsJKk65ObmMn78eO6//36GDh0a6ulEDGEvrqRGiW0oxJWMWNH+LVCumf1+54pB4oqmizbMkSauAPSuDwcLYVt26OZwWVuwmUT+ycwtomn68GbBG8/thbdXCsfKuXWCN45EIpFIJFWBYATXluaTDWaW7Ddxd3c3zZLllf5Z201Em3UGlCqNUpMTSJnyLOaGtcme9AjOpWsCMlaBSymzu+La9m5WHTZx0zwre/PP/BrILFB4d42Z4VPtPLbUSpMkjS9HOTi3Ajkpfx1TuOJ7OyZF5+PhDhLO4pzRikooeO4jjva/Fsf8ZcTedjlpCz8ganS/oAuls3eYeH2llcGNPHSoZi3EJeUnPz+f8ePHc+eddzJq1KhQTyeiCHsN3u9cORaCUNuavnbMhws56wGzLGTmi0Ctk9u/BYsDBaLMJBJb+faqL7aL90Cz4Lgjz0pKNFzYEqZvEh2eOtYMrlD17WbhbnqiL0FraSiRSCSSqoWiKBjYZTJsCFa+ip+/jim8uMxC//oeJrSUV/qdHvhhp5n+DbxEW/7+OzU5geQpz5A94R6yr3mI5I8ex3Zu+0qN59bAWsZLwBe19FLscfHqCgsXfGWmZYpGq1SN9BgdFZ0it8KBQoUNR4XzBqBTupe7ujnpV99boXOu1YdVrv/BRpxV53/Dzpz1omsaJd/Mp+DZD9GOZGMf1Zf4/16DqZYxXQv+PKBy7y9WOtf08kzvEOQsSCKKwsJCLr74Ym688UbGjRsX6ulEHJEjroTAuVIzRmwPFgZmgb+/QOStGLVw3pEjto0i0LlSJ14IGYv3wrWdQjePazrClA3CQfN4n+CN49HgzRXQugb0bRC8cSQSiUQiiXSCLawUueHOBTaSo3Qe71W+gNKqyqJ9JvJdCqOanlpoMqUkkjzlWbIvuYvsqx4k9t/jibl2LGpcTIXGS7TpbMlW0fWynTdf1dbDkEZeZm03sXifiUV7TWSVgI5ClFmnRrRO2xoaE1t7KpU5ouvw3TYTjy+1khqt89EwJxmn6R6lFTtw/vwHRe9Px712K5YOzUl6+0GsnVtVaOyKsCVL4cafbNSL13ljUMUyZSTVh+LiYi655BKuueYaJkyYEOrpRCRh/xZL9rk8QlEWVNMXPBuoUNvMAoPDbCOwDXNpetcXnXMcntAF8jZNFkLP/nwY1Dh443ywGnblwvsjpGtFIpFIJJLToes6Ho8HRVFQ1eBUt3+4zsKePIX/DXMa5jYOd2ZuM5EapdP9DN10TKmJJH/xf+Q/8BqFL39G0UffETv5YmKuHIESVb4n8pJWHh5YbOP3A2qZWxunx+hc297Dtb7AWo/vYeYAvUx25So8+quVPw+a6JDm5ZWBTtKi/34f3eHEsXA5ju8X4Zz/J3q
JE7VWDRJeukuU/wTpNXsqDhQqXPeDjRirzntDnAFx4UuqLiUlJUyYMIEJEyZw5ZVXhno6EUvYiysWEyTaQ1MWlO4T2wMlruzPhw7pgdlXWdiRI8qZUqKMGzOQ9KoPH66BPzOF0BIKPJpooQ2iROmi1oEfY18+vPQHDGgobhKJRCKRlJXqUhYU7HyV0vy0y0TXDI1uFWzLW9UodMHifSYube3BdBZtwFQjiaR3HsK9fhsFz31EwdPvU/T+N8TeNIHoCUNQrJYz78DHiCZeXl6u89IyC+2GOYm1ln/egRJVnB54d62F99aYiTLDIz1djG/hQfW9BHWXG+eSlThmLcYx73f0wmLUlASixg3CPqI31nNaGyqqgGgvfd1cGyUehc9GOE7rrpFIAJxOJ5dffjljxozh2muvDfV0IpqwF1dAiAOhcK5YTWLsQIgrRS5xoKtjdKeg5Mh1QnSvLQJlF+0JnbiyaI8QV+rEwfurYXyrwD6fug73LxD7fKxP5P6vJBKJRCIJFkYKK/vzFbbnqFzYXWZT+Cn2gFtTzpgrcjKWtk1J/uRJXMvWU/DcR+Q/9AZF704j9pZLiRo7AMVsOuPjrSZ4sIeLO+ZbuXq2jXeHGO8iKnKLINgP11rYk68yvLGHe7q7SI0G3ePF+dsaSmYtwvHDUvT8QpSEWOzDexE1ojfW7u3P+jcGC4cHbvjJxt58hfeHOGUYs+SMuFwurrrqKgYNGsTkyZNlF7pKEhHiSmo0ZIXAuQIn2jFXFn8bZkPLgnJCJ0oEgigLdK0tHCOh4ptNkBwlWiPfMx/mbIdhTQO3/1lbhYDzcC/jukhJJBKJRBIpGCmsAPyyVyyI+9TznuWe1YcEn2vE7+QtD9aubUn++nlci1dS8PzH5N31IoVvf03cbZdjH9brjI6OQQ29vDbIyS0/27h8lp0He7jomqEF/ULUliyFrzabmbnNTJFboXmyxvtDHJyX4ca1bAN53y/CMfdXtKw8lNho7IPOxT6iD7aeHcvszAkWXg3uXmhl9WGVF/q56CrdV5Iz4Ha7mTRpEueddx633HKLFFYCQESIKylRsCUrNGNnxIpA28pidBvmfCccKYrcvBU/verDk0tE56NaBgpTAHkOmLcTJrQRXYM+Wy9cJufUgrSK5bP9Y/+PLoJ2aXBl5UL1JRKJRFJNqconw8EOrj0VC/eaaJig0SBBXu33YzNDlFknz1mx519RFGy9u2Dt1Rnnj79R8MLH5N74NOZXp2Dr2xVrl1ZYOrfClJL4j8f2qafx3mAndy60ctVsO51rermxk5tutQInsug67M1X+POgyrdbzKw5YsJq0hnSwM2E1EM0z9mD8/1lHJm9GO1INkqUDduA7kSN6IOtdxcUewVqloKArsNTv1uYt9vMvee6GNJYCoSS0+PxeLj++uvp0KEDd999d5X+LDGSiBBXUqPht/2hGbtmDKw+VPn9+MWVOgYJBDt9nYIiXVzpXQ+eRLg7JrQxduzZ28DphXEtRfbPSxfAsClw1zz4aFTlS3ieXgo5Dvh4NGetYZZIJBKJpDrhD64FghZcezKFLlh+UOXyNrL18skk2CourvhRFAX74B7YBnbHMXMRRZ/MpOjDbyl6ZyoApkZ1sHZuibVzayxdWmNuXAdFVelaS2PexQ6mbjbz3lozV8+x0zhR45wML+dkaHTJ8P4jWPZMaDpsy1ZYccjEikMqf+1xYN+fSd2svfQu3MNdzj3UOrYPdu9HL3GSA2CzYOvTlagRvbD1744aHX5Jx++vNTNlo4Vr2rm5Qr6GJWfA6/Vy44030qRJEx544AEprASQiBFXch3g9opFrpHUioPsksp3rMksEPWjNQLgeCgL/jbMjZONGS9YNEuBegnw3RbjxZVpm0S3oDZp4vumyXBvT3hkkehidFm7iu97WSZ8sQH+1enE/iUSiUQiqe4YXQZUmt8yTbg1RZYEnYIEG+RVoCzoVCgmE1Fj+hE1ph+6w4V7/VZcK/7CtWI
jjp//pGTqPHG/xDisnVpi7SLElkvbNWV8CzvfbjWzYI+JmdvNfLlJvD7qxWukRevE23TirBBv1YmzicYEx4oVsos09MzDRO3fR/zBfdQ6uo+6Wfu4NmcvyfnHTkxOVTHVScfcqA7m89phalQHc6M6WNo2rXBb6WDj8sLrKy28t9bC8MYe7ujqDvWUJGGMpmncdttt1KxZk8cee0wKKwEmIsQVf7ebbMeJDj5G4Q+g3Z8PTSohVGQWQK1YjieLB5sdOSIlvV6E53goinCOvPiH6KpT16C/Z3curDwI/+3xd4fKle1h/i54YgmcVxcaVcAZdLhIuF/qxMHt3QM2ZYlEIpFIIppQCisAC/eYSLDpdEyXORUnEx8A58qpUOxWrOe0wXqOuIKm6zrenftxrdgoBJeVG3EuWCbubDahpibRV9fpq+ugg8er4faC16uLrlm6Lqwpvq8VdFRdx+J1Y9JOiGae2FhMjeoQ1aUjlsZ1hIjSuA7merXCpsynLOzIUbhroY1NWSrjmnt4sIfLsLWGJPLQNI277rqLmJgYnnnmGcNcgdWJyBBXfFa/rOLIFVf25xsbWLojG+onGO/0CQYX+sSVbzbBLd2MGfObTaAAY1r8/eeqAs8PhEGfwW0/wrTx5XuOM/Nh4jdwpBg+GQ3Roc09k0gkEkmEU1WuOoZaWPFqsGififPreAPWwrcqkWjT2ZSlounBvVCoKArmxnUxN65L9MUXAKBl5+FatQnXyo1oR3PEVS9VAUXB7v+aEz/z/15HQVF9ryWrGXPdDMyN62BqVBc1JSGi3zu6DlM2mnnuTwvRFnh9oJP+DaTjSnJ6NE3j/vvvB+Cll16SwkqQiAhxJdXnXAlFO+a6pcSVyrA/H/o2qPR0ysyOnMjPW/FTJx7OqwNTN4quPcFW5DUdpm+GnvWgZuw/f18zFp7qBzfMhTeWw61ldJ/syYUJ30CBEz4fA50yAjptiUQikUgiklAE157MuqMqOQ6FPvXlAvVU9KvvZd5uM3N3mBjWxNjnSE1OwD6gO/YB0u4LcLQY7l9sY8k+E+fX9fJkLyc1ypE5I6l+aJrGo48+SmFhIe+9954UVoJIRDyzfufKsRC0Y64RI7JSKiOuODziQFjHIOeKRxNlLVVFXAEY31qUBS3LDP5Y83aK//dFrU5/n+HNYHRzePlPeHCh6PxzJrZlw7hpUOyGKRdKYUUikUgkEjgRXKvrOqqqhsxN8MteE2ZF5/w6Ulw5FSObemmZovHCMgsOmZUaMn7ebWLU9CiWHVB5sIeLdy6QworkzOi6ztNPP82hQ4d49913pbASZCLi2S1dFmQ0qgK142B/QcX3cdD3WKPElX154NYiP8y2NEMaQ6xVuFeCia7DWyuEY2lo0zPf96l+cHk70aK57yfw1V/C9VIapwd+3QsXTxP7/upCaCsDbCUSiUQSICK5tMHr9eLxeFAUJeQn/Av3mOhcUyPeFtJphC2qAv/t7uJgkcrH6yPC+F6lKHLDg4ut3DTPRkaszvQxDia28gSsHbWkaqLrOs8//zw7d+7kf//7HyZTFciLCHMi4ugYbwWLCsdCIK6AEEUq41zxP7a2bMNcYaIsMLwpzNwKj/WBmCBljf2ZKVpvP9GXs9Zcx1jh8b5wSRvhXrn7Z9EBaEIb2HIMVh2Cv46KFPeMWJgytmIBuBKJRCKRVCVCna9yMpkFCttyVO7p5grpPMKdrrU0+tf38O5aC2Obe6RjwiDWHFa55xcr+/IV/tXezY2d3VjlGllyFnRd57XXXmPdunV8+eWXmM0RseyPeCLCuaIowr0SirIg8DlXKiOu+JwrRgXabq+C4grAuFairGbO9uCN8eYKkfEz/gwlQSfTugZMHw8vDhKvk7t/hs83CEHw6g7wzjD48TIprEgkEolEEm7CCoiSIEDmrZSBO7u5cXngtZWR01EnUvFo8MZKM5fNsuHW4OPhTm7vKoUVydnRdZ133nmH3377jS+++AKLRXbQMIq
IkbBSokJTFgTCuXK0WGSn2CvwjGUWgEkR7gUj2JEjBIIEuzHjGUWXDGiYKEqDyiN+lJW/jsKiPXDXueX/PyuK6Gp0QWMhsDROqhqdmiQSiUQS3oSDOFFWwiG49lQs3GOiYYJGgwT97Heu5jRI0JnY2sNnf5mZ0EqhZYp8zoLBqkMq//enhbVHTIxoIlosx0k9S1IGdF3nww8/5Oeff+abb77BapUvHCOJCOcKQGo0ZIXIuVKnkh2DMvNFhxmjWvvtyK5aeSt+FAXGtRSlO3tyA7//t1eIXJfL21d8H7FWaJEqhRWJRCKRSEoTLsG1J1PkgmUHVfpK10qZmdzRTaINrplt588DEbOUCHs0XQTWTpxp49JZdvbkqTzfz8n/9ZXCiqTsfPrpp8ycOZPp06djt1exK+0RQMQcEUPtXIGKiyv7843LW4Gq1Yb5ZMa2BAWYtimw+92bB99vg0vbQoIMs5NIJBKJJGCEU3DtySzNNOHWFHrXk+JKWUm0w+cjHaRE6UyaY+Pzv8zo0sBSYZwemLrZxPCpdm6aZ+NIkcID57mYP6GEYY3l61JSdr744gu+/vprZsyYQVRUVKinUy2JmLKgVF/miq5jeDJ23co6Vwqga+3AzedMZJdAjqPqiiu14uD8ejB9E9zWXaTXB4J3Vgpn0aQOgdmfRCKRSCRGEC4OkFMRjvkqJ/PLXhMJNp1O6VqopxJRNEjQ+XKUg7sW2njiNytbslUeOM8l80DKQZ4Tvtpk5tMNFo6VKLRK1Xihn5NBDb2Gud0lVYfp06fzySef8P333xMTExPq6VRbIkZcSYkWmSfF7uB1ijkdaTEinLQi4opHg0OFxrVh3uEPs62CZUF+xrWCm3+Ahbugf6PK7+9okchxGdsC0g3KxZFIJBKJpCrjF1a8Xm9YlQGdzMpDKudkyMVsRYi1wusDnby60sK7ayzszFF4ZaCTFHnB/IwcKFT4ZL2ZqVvMFLsVetbxck07N91rabK1sqRCzJw5k3feeYfvv/+euDgDyyUk/yBixJVU34E6q8R4cUVVfB2DCsr/2EOF4NWNKwvakS22VdW5AjC4MTRIgMcWQ496FQsZLs3/1op2ydd3Dsz8JBKJRCIxEkVR0MOoLqN0cG04CysAiXadQlf4zi/cMalw2zlumiVp3L/Yyvhv7bw+yEmr1PB5PYYLW7IUPlhnYe4OEzowtLEQVVrIUOAqh67D4n0qveoGXzCbO3cur7zyCrNnzyYxMTG4g0nOSsTo9CnRYnsshLkrFXGu+B9TxyhxJQdsJmMzXozGZoYn+sLuPHhrReX2tSMHPlwNQ5vKVskSiUQikVSWcO0IdDrqxensyw/vOUYCw5p4+XyEAx24dKadV1dYOBKic/ZwQtfhj0yV6+baGP1NFPN3m5jY2sOPFzv4v74uKaxUQVxeeHCJlX//aOenXcGtk/v555959tlnmTVrFsnJVbhsIYKISOdKKKgTD/N3lf9xmT63S20Dy4IaJoorCVWZ8+vDyGbw5goY1bxiwojTAzfOFc6Xh3sFfo4SiUQikVQnvF7v8XyVcAuuPR114nXm7FRweZF5IZWkdQ2dqaMdPPKrlbdXm3lvjZnBjbxc1sZD+7Tqk2lT6ILfM038ul9lyX4TBwtVUqN0bu3i4pJWHtk4oQqT44Cb59lYccjEvzu6GdgweIHEixYt4rHHHmP27NmkpqYGbRxJ+YgYcSUcnCtHi0XuS3nKUPzOlVoGlgW1TjNmrFDzYC/4ZTfcvwCmjC1/0PGzS2HjUfhghMxakUgkEknkEuqyoEgIrj0d9eI1NF3hQKFCgwTpIqgsqdHw+iAXe/MVpvxlZvoWM9/vMNO2hpfLWnsY3Mhb5UQsXYfN2Qq/7jOxeJ+JNYdVPLpCjEXn3NpebursZmgjL7aIWXVJKsK2bIX//GTjSLHC832dDGsSPGHl119/5f7772f27Nmkp6cHbRxJ+YmYt3my37kS4nbMmQXlyzPJLIAa0ZXPBSk
LTg/sy4eRzYM/VjiQFgN394AHFsKMLTCmRdkfu2AXfLAGrmoPAwIQiiuRSCQSSXUkkoUVgLpxQlDZny/FlUBSL17nv+e6uamLm++2mflsg5l7frHx3J86F7f0cFFLN2nRoZ5lxcl1wG+ZJpbsM/HrfhPHSsTrvmWKxtXtPPSs66VDmlblhCTJqfllr8qdC2xEmeHT4U7aBdGp9eeff3L33Xcza9YsMjIygjaOpGJEjLhiN0OcNbRlQSCcKOURV/bnG9cpaHeuCM+tymG2JzOxDUzbCE8shn4NIMF+9sccLoI750HLVLi3Z9CnKJFIJBJJlSTS8lVORb14IajsLVCB6lO6YhQxFpjYysMlLT38nqny6QYLb6yy8O4aM4MaeulRx0vrVI1GiXpYd2zyavDXMZUl+0Spz/qjKpqukGDTOa+2l/Prir8lkgUjSfnRdfjfejPP/2mhZarOGwOd1IwNnki7cuVKbrvtNmbMmEHdunWDNo6k4kSMuAKQEhW6sqC6pcSV8pCZD20MKtPZ6usU1CzFmPHCAZMKT/aDEV/C//0mvj4Tmg53/ARFbnh1sDGOIolEIpFIgkkoRI2qIKwA1IjWsZlkqG2wURXoUUejRx0nu/MUpmw0M2Ormdk7xImYzaTTPEWjVYpG61SNVqkaTZJ0w50fJR7hYtpfoJBZoLK/QHy98pCJXKeCgk7bGhr/7uihZx0v7WpoVT7nUHJqXF549Fcr32w1c0FDD0/1dhFtCd54a9eu5cYbb2T69Ok0aNAgeANJKkVELS1TokMnrqTFgEUtn7ii6XCgEAY3Cd68SrM1S3x4VbeuN23S4Or2osyn2A13nXfqjJvsEnjpD1iyF57pX71EKIlEIpFIAoU/uBaImODa06EoUEd2DDKUBgk6953r5p5ubvbkK/x1TGXjMZW/jql8v93Ml5vE/8Ki6jRL9ostOq1SNerGiVIbiwnMSvnz9twaHCxUjosmpQWUzAKVrJK/79Bu0qkdp3N+XS+9fO6UpDK4pCVVm+wSEVy78rCJ/3Ryc0MnN2oQDyF//fUX119/PVOnTqVJE4MWlpIKEVHiSmq0KH0JBaoiFuzlEVeOFglV06i2yFuzoEFC9XRj3N1DpPx/uAZmb4N/dYLJXSDWCnvz4P3V8NVfIpD4srZwSetQz1gikUgkksgi0vNVTke9eI29+ZEtEkUiJhUaJeo0SvQywhf+qemwN19hYynB5cddZr7efOrXmkXVsZjEBVCLChaTfvxrq8n3e9+/9kCRwuEiBU0/sS+TopMRq1MnTqdPPS914jTqxAlBpU6cRmpU+QUcSdVma7bC5B9tZJUovNDPydDGwQuuBdi8eTOTJk3iiy++oHnzahKsGcFE1DI8NQpWHgjd+HXiRWBsWdlvcBvmbdnQtJq6Mexm+G9PuLQdPPcbvL4cvtwAnWvBzzuFODa6BVzfGZrKNvASiUQiqUIYIXJUVWEFoG68zh8HFHRdLqRDjaoIZ0uDBO/xRauuQ2aBcLgcKlLwaMKB4tYU3F7f194TP3P97WfiPhrQpaZ2XDSpEycElfSY8M56kYQXC/eYuHOhlViLzmcjnLSpEdycpm3btnHVVVfx6aef0rq1vDIcCUSUuJISDdkOESoVivrGOvGiy0xZyfQJMXUMcK44PcLVM7Rp8McKZ+rGiyyVazrAU7/C7/thUkfxfYZBDiKJRCKRSKoSVSVf5XTUi9cp8SgcKxEdHiXhhaJAnXidOvHBdQhIJKdD1+GDdWZeXGahdarG64NcpMcEt7vY7t27ueKKK/joo49o3759UMeSBI6IElfSYoRd8FgJpMcYP36deDhaLEpLylJ6k2mgc2VnjugUJF0Zgg414etxoZ6FRCKRSCSRTVUXVgDqxourz/vyVWpEy45BEonkBC4vPLzEyoxtZgY3EsG1UUFeQe/bt4+JEyfy3nvv0alTp+AOJgkoEWWEqxkrtocLQzO+v6WyXzQ5G/vzIdEucj+CzfFOQVJ
ckUgkEomkWhEswUPXdTweD7quo6pqlRRWAOrG+doxy1BbiURSiqwSuGq2jRnbzNzY2cWL/YIvrBw4cIBLLrmE119/na5duwZ3MEnAiSjnit+tcrgoNOPX9ZWV7M+HxmXoyLM/39gwW1M17BQkkUgkEokksFTlfJVTUTtOR0FnX4EKyNITiUQCW7IUJv9kI6dE4aX+TgY3Cv6x4dChQ1x00UW89NJL9OzZM+jjSQJPZIorIXaulLVjUGaBcWLHtmxokAi2iPqPSiQSiUQiCSeqm7ACoqtMzVjZjlkikQjm7zZx90IrcTadz0Y4aF0juPkqAEeOHGH8+PE8++yz9OnTJ+jjSYJDRJUFpUSLFPFQOVfSYkRrt7KIKyLZ3JgwWxDOFZm3IpFIJBKJpKJUR2HFT704Ka5IJNUdXYf31pi5aZ6VxkkaX482RljJyspi/PjxPP744wwcODDo40mCR0SJK2ZVpLgfCpFzxaRCrbiyiSs5Dih2GxNm6/DAnjxoVk3bMEskEolEUp0JhAjiD66tjsIKiHbMoixIIpFUR5we+O8vVl5cbmVIYy+fDHeSZkD3sNzcXMaPH88DDzzA0KFDgz+gJKhEXBFJekzonCsgSoP2lUFc2ZcntnUNEFd25IguSlJckUgkEolEUl78wbUAqlo9BYb6CRpZJWZ25io0Sgz+lWqJRBI+HC2Gm+bZWHvExM2dXfy7owcj9OX8/HzGjx/PnXfeyahRo4I/oCToRNwnaHosHAmxuFKWbkF+AcYIcWVrltjKTkESiUQikUjKit+t4vF4UBSl2gorAKOaekiw6Ty42IomtRWJpNqwKUvh4hl2tmarvDLAyeROxggrhYWFXHTRRdx4442MGzcu+ANKDCHiPkXDwblypEiU4pyJvX7nSkLw57QtS5RMNZSdgiQSiUQiqXZUpISnOuernIoa0fDf7i5WHTbxxcaIM3ZLJJIKMG+XiUtn2tGBz0Y4GNTQmG5hRUVFXHLJJVx77bVMmDDBkDElxhB54kosZJeIurhQ4O8YdOAs7pV9+ZAcBbHW4M9pq69TkNUU/LEkEolEIpFENlJYOTWjmnrpWcfLi8ssZBbI50QiqaroOryz2szNP9tomiyCa1ulGmNZKykpYeLEiUycOJErrrjCkDElxhF54oqvHfOR4tCMX9fX/edsobb78o0pCQJRFiRLgiQSiUQiqb6UVSCp7sG1Z0JR4JGeLgAeXmJFl+VBEkmVw+GBuxdaeXmFlRFNPHwyzEkNA4JrAZxOJ5dddhljxoxh0qRJxgwqMZSIFVcOh6hjkL/7z1nFlTyoZ0BJkMMjSpBkmK1EIpFIJJIz4Q+u1XUdVVWlsHIKasfp3N7VzdJME99tk5ZgiaQqcbQYrpxt4/sdZm47x8WzfVzYDKoCdLlcXHnllVxwwQVMnjxZHn+rKJErroQodyU9RuSbnElc8Woi9NYI58r2bNCR4opEIpFIJJLTI4Nry86EVh46pXt55g8rR0PklJZIJIFl4zGFi2bY2Z6t8tpAJ9d1MCa4FsDtdjNp0iR69uzJLbfcIoWVKkzEfbqmx4ptqJwrJhVqxZ25HfOhQnBrBnUKyhbbprIsSCKRSCSSasvpTtZlGVD5URV4opeLEg888ZsB4XkSiSRoFLtFvspls+wowOcjHQxoYExwLYDH4+G6666jQ4cO3HXXXfIYXMWJOHElyS6CW0PaMegs4oqRbZi3ZYFFhYaJwR9LIpFIJBJJ5OAPrvV6vVJYKScNE3Vu6OTmp11mftoly4MkkkjDrcGXG81c8FUUL6+wcl5tL1+PdtAixbgwJa/Xyw033ECzZs144IEH5DG4GhBxveYUBdJC3I65URLM3CqSpk/1HvGLK0ZkrmzNEi2YLfJzXyKRSCQSiQ+/Y0Xmq1Scq9t5+GGnmceXWulWq4QEW6hnJJFIzoauw4+7TLy83MKefJXO6V5eGeCmU03N0Hlomsatt95KRkYGjz76qDwGVxMizrkCIvc
kVGVBAE2SId/Jaetw9+YJS2mtuODPZWu27BQkkUgkEkl1p/SJuxRWAoNFhSd6OclxwLN/yPIgiSTc+SNT5aIZNm6bb8NigjcHOfl0hDMkwsqdd95JXFwczzzzjMy5qkZE5H86PcTOFX++yfbsU/9+Xz5kxAbfTVLiFl2JZJitRCKRSCQSOBFcC8gT+gDQKlVnUnsP3241s3S/fD4lknBk4zGFa+fYuHqOnSyHwtO9ncwY66Bvfa9hobV+NE3jvvvuQ1VVXnzxRXkcrmZE5H+7ZmxoxZUmZxNX8oztFCTDbCUSiUQiqd7I4Nrg8Z+ObhomaDy0xEqRO9SzkUgkfvbmK9yxwMqF30ax4ZjKPd1czB3vYHQzL6YQrHI1TePRRx+luLiY119/XQor1ZCI/I+nx0ChS9xCNX6sFbbnnPr3+/KhrhF5Kz5xRzpXJBKJRCKp3rz99tssWbLkuLgiCRw2s+gedLBQ4eXlllBPRyKp9hwrhseXWhj2tZ0Fu01c38HNvEtKuKqdB1uIEkV1Xefpp5/m8OHDvPPOO1JYqaZE5H89zdeO+UiI3CuKAk2STu1ccXiEq8aQNsxZonNSg8TgjyWRSCQSiSR8adKkCVOmTOG8887j1ltv5ZdffsHtljaLQNGppsbEVh4+/8vMqkMRefoskUQ8hS54baWFC76K4qtNZi5s4eHHS0q49Rw3cSGMRdJ1neeff56dO3fy4YcfYjLJTiPVlYjrFgTCOQJwqFB07gkFjZPh173//Pl+fxtmgzoFNUoEs/yMl0gkEomkWjNkyBCGDBmCy+Vi4cKFTJs2jbvuuouuXbsyevRoevfujdUqQ1krw21d3SzYa+LBJVa+GeMI2RVyiaS64fLCl5vMvL3aQo5DYXAjD7d0cdMgwbi2yqdD13VeffVV1q9fzxdffIHZLA8M1ZmIXJb7xZWQ5q4kifHznX//ub8NsyHOlWxZEiSRSCQSieQEVquVCy64gPfee4+1a9dy+eWX8+OPP9KzZ0+uu+465syZg8PhCPU0I5IYCzzW08XOXJW3VsvyIIkk2Hg1+G6biaFf23n6dyvNkzWmjnbwUn9X2Agrb7/9Nr///jtTpkzBYpHHhepOREprNX1lQeEQarsjBzrWPPHzfXliWy/IzpU8h3DJTGwT3HEkEolEIpFEJmazmX79+tGvXz+8Xi9Lly5l+vTpPProo7Rq1YrRo0czcOBAoqOjQz3ViKFnXY3RTT28u8aMSYXJHd3SQSyRBBhdh8X7VF5abmVLtkrLFI1Hz3dwXm3N8O4/p0PXdT744APmz5/PN998I52BEiBCxZVYq7h6cKQwdHMo3TGotLiyNx9sJkgL8nnKhqNi2zYtuONIJBKJRCKJfEwmE7169aJXr15omsby5cuZOnUqTz/9NE2bNmXUqFFccMEFxMXFhXqqYc+DPVyAlTdXWfjzgMpzfV1kxIb+KrpEUhVYc1jlxeUWlh80US9e4/l+ToY08qKGiaji59NPP+X777/nu+++w263h3o6kjAhIsUVEKVBh0LoXKmXIMJkTw613ZcHdeIJf9hHMQAAMKlJREFUuqq6/ojYSnFFIpFIJBJJeVBVlW7dutGtWzc0TWPNmjVMnTqVF198kXr16jFq1CiGDh1KQoIBAXIRSLQFnu7j4tzaXh5damXMN3ae7OWifwNvqKcmkUQsO3MVXlpu4efdZlKidB44z8X4Fh6sYZgNO2XKFL7++mtmzZpFVFRUqKcjCSMiVlxJi4XDIXSumFXRpecf4kq+MXkrG45AnThIku9niUQikUgkFURVVTp16kSnTp146qmn2LBhA1OnTmXEiBHUqFGDUaNGMWzYMFJSZMjbyYxs6qVdmoM7F1i5cZ6Nia3c3N3NLYNuJZJycLhI4fWVFr7ZaiLKDDd1dnFlWw8xYRpfMm3aND799FNmz55NTExMqKcjCTMi9vBfMwZWHgztHJokwcajf//ZvnzonBH8sdcfhtbStSKRSCQSiSR
AKIpC27Ztadu2LY8++iibN29m2rRpjBs3jvj4eEaOHHlcdFHCJfggxDRI0Jky0slLyy18tN7CykMmXujnpHGSLBOSSM5EnhPeW2Phs7/MaDpc2trDvzu4SQ7jC8ffffcd7777LrNnzyY2NjbU05GEIREbwZUeC0eKROBRqGiSLDJWHB7xfZ5DdA8KdhvmfCfszpMlQRKJRCKRSIKDoii0bNmSBx98kD/++IO33nqL4uJiJk6cyLBhw3j77bc5ePAgeihPxMIEqwnu6e7m7QscHClWGD/DzvQtppCeo0ok4YrDA++vNTPoyyg+XGfmgoZe5l7k4L5zw1tYmTNnDq+++iozZ86UJZOS0xK54koMOL2QG8Jugk2SQdNhd674fq9BbZj/kmG2EolEIpFIDEJRFJo0acI999zD0qVL+fjjjwG46qqrGDx4MK+99hr79u2r9kJL73oaM8Y6aJ+m8cBiG3cutFLgCvWsJJLwwKPB1M0mBn9t54VlVjqka3w71sGzfV3UjgvvY8e8efP4v//7P2bNmkVycnKopyMJYyJaXIHwaMfsz13xt2EOtrgiw2wlEolEIpGEAkVRqF+/PrfffjuLFy/mq6++Ijo6msmTJzNgwABeeukldu7cWW2FlrQYnfeHOLm1i4sfd5oY+42ddUci9nRbIqk0ug7zdpkYNd3OQ0tspMfofDzcwTuDnTRPCf/jxC+//MLjjz/O999/T2pqaqinIwlzIvZon+4rcwuluNI4CRRKiSs+50q9IDvF1h+GWrGQEuR2zxKJRCKRSCSnQ1EUatWqxU033cT8+fOZMWMGKSkp3H777fTp04dnn32WLVu2VDuhxaTC9R09fDrCiVeHS2fa+GCtyJWQSKoTyw6qTJhp4+afbQC8OsDJlyOddM3QQjyzsvHrr7/ywAMPMGvWLNLS5FVtydmJ2EDb486VEHYMsptF2+XtOeL7vXmQYIN4W3DH3XAE2sj3t0QikUgkkjBBURTS09P597//zb///W+ysrKYMWMGDzzwAIcOHeKCCy5gzJgxtGzZElWN2Gt75aKjr+zhoSVWnl9m5Y8DJp7u7SRVXhyTVGEOFir8uMvE3J0m1h0xkR6j8dj5TsY082KOoLf+H3/8wT333MOsWbPIyDCgW4mkShCx4kqaT1w5FELnCojSoNLOlWCH2RY4YWcujG4R3HEkEolEIpFIKkpKSgqTJk1i0qRJ5ObmMnPmTJ588kl2797NwIEDGTNmDO3atavyQkuCDV7u7+LrzRpP/25hzDdRPNPHSY86kXHlXiIpC4eLhKDyw04Tqw+bAGiZonFPdxeXtPRgj7AV58qVK7ntttuYOXMmderUCfV0JBFEhL3UT2A3Q6I9tM4VEOLKb/vAq8H+fGiWEtzxNh4TW5m3IpFIJBKJJBJITEzkiiuu4IorrqCgoIDZs2fz0ksvsWXLFvr378/o0aPp3LlzlRVaFAUubumhY7qX2+fbuHaunWvbu7m5ixtL1fyTJdWAI0UKP/kElZU+QaVFssatXVxc0MhLg4TIrINbu3YtN954I9OnT6d+/fqhno4kwohYcQVEadCRUDtXkkTXor15QlwZ0Ci4460/LLZSXJFIJBKJRBJpxMXFcckll3DJJZdQXFzM3Llzeeedd1i/fj29e/dm9OjRdOvWDZPJFOqpBpxmyTpTxzh45ncL76+1sOygygt9XdSJj8xFqKT6cbQYftplFoLKIRUdhWbJGjd3djG4kZeGiZH9Wt6wYQPXX389U6dOpUmTJqGejiQCiWhxpWZsaANt4UTHoJUHhMhiRKegmrFQIya440gkEolEIpEEk+joaC688EIuvPBCHA4H8+bN49NPP+XWW2+lR48ejB49mh49emA2R/Tp6t+IMsOj57s5t7bGQ0usjPnGzqPnuxja2BvqqUkkp+RYMczbLQSV5QeFoNIkSeOGzm4GN/TSOCmyBRU/mzZt4tprr+WLL76gefPmoZ6OJEKJ6E+rtBjYfCy
0c2jqE1dW+xwlRogrbWoEdwyJRCKRSCQSI7Hb7YwYMYIRI0bgcrlYsGAB06dP56677qJr166MHj2aXr16YbVaQz3VgDC4kZe2NRzcucDKHQtszN3p4eIWHs6trWGSpUKSEJNdckJQWXZQRdMVGiVq/KeThwsaemiaXDUEFT/btm3j6quv5rPPPqN169ahno4kgolocSU9RtjTvBoh+yBKsEONaNiaJb4PZqBtkQt25sDIZsEbQyKRSCQSiSSUWK1WBg8ezODBg/F4PCxevJipU6dy33330bFjR0aPHk3fvn2x2+2hnmqlqB2n88kIJ++stvDZX2Z+3m2mZozGqKZeRjX1RHyJhSSyyHHAz7tNzN1pZtkBFa+u0CBB4/oOHgY38tA0SUdRQj3LwLNr1y4uv/xyPvroI9q1axfq6UginIgXVzQdjpWcaM0cChonw+5cUIDaccEbZ+NR0JF5KxKJRCKRSKoHZrOZfv360a9fP7xeL0uXLmXatGk8/PDDtGnThtGjRzNgwACioyOzv7FFhRs7u7mug5sFe0x8u9XMe2vNvLPGQqd0L2OaeRjcyEts1TDsSMKMXAfM32Pih51mfs8Ugkq9eI1r23sY0shDs+SqKaj42bt3LxMnTuT999+nU6dOoZ6OpAoQ0eJKzVixPVIYWnGlSZLIXEmPIaitxtYfEds26cEbQyKRSCQSiSQcMZlM9OrVi169eqFpGsuWLWPq1Kk8/fTTNGnShDFjxjBo0CBiY2NDPdVyYzWJUqHBjbwcKVL4bpuJGdvMPLjExlO/6wxsIISWrrU01Cq82JUEnzwnLNgjHCq/71fx6Ap14zQmtRcOlRZVXFDxc+DAASZMmMCbb75J165dQz0dSRUhosUVv6ByqAjahnAeTZLBrZ0Qe4LF+iMiZyaUQpJEIpFIJBJJqFFVle7du9O9e3c0TWP16tVMnTqVF154gXr16jFy5EiGDh1KQkIQ67WDRFqMzr86eLi2vYe1R1RmbDUxe4eZmdvN1I7VGN3My+imHtllSFJm8n2Cyg87zfyWqeLWFOrEaVzVTggqrVKqh6Di59ChQ1x00UW89NJL9OjRI9TTkVQhIltc8YkZhwtDOw9/qG2cLbjjyDBbiUQikUgkkr+jqiqdO3emc+fOPPXUU2zYsIFp06YxfPhw0tPTGTlyJMOHDyc5OTnUUy0XigId0jU6pGv891w3P+8WZUNvrjLzxioLXTOEm2VQQy/RllDPVhJuFLr8goqJX/ebcGsKtWI1Lm/jYUgjL61TtWolqPg5cuQI48eP59lnn6VPnz6hno6kihHR4kpqtMg5CXU7Zn/OijWIobrFbtiRA0Nly3WJRCKRSCSSU6KqKu3ataNdu3Y8+uijbN68mWnTpnHhhReSkJDAyJEjGTFiBKmpqSgRtLK0m2F4Ey/Dm3g5UKgwc5sQWu5dZOPx33QGN/QytrmHTunVc8EsERS5YOFeIags2W/C5VXIiNG4rLXH16Gqer8+srKyGD9+PI8//jgDBw4M9XQkVZCIFlfMqhBYQi2uFLvF1q0Fb4yNR0V4b1uZtyKRSCQSiURyVhRFoWXLljz44IM88MAD7Nixg2nTpjFhwgRsNhsjR45k1KhRpKenR5TQUitW598dPVzfwcPKQyrfbjXzwy4T32w1Uy9eY0wzD6OaesmIlWVD1YEiN/ziE1QW7xOCSnqMxoSWQlBplyZzegBycnIYP348Dz74IEOHDg31dCRVlIgWV0DknBwKcVnQrlyxzSwI3hj+MFvZKUgikUgkEomkfCiKQpMmTfjvf//LPffcw549e5g+fTpXXnklACNGjGD06NHUrl07YoQWRYEuGRpdMlzcdx7M2yXcLK+ssPLqCp1zawuhZUADb1AbLkiMQ9fhYJHCtmyFbTkqa4+oLNlnwulVSIvWuLilKPlpLwWVv5GXl8dFF13EXXfdxciRI0M9HUkVJuIPtekxsC8/tHPYmevb5kCBMzjZKxuOQI1oGWY
rkUgkEolEUhkURaFBgwbccccd3H777Rw4cIDp06dz/fXX43Q6GT58OKNGjaJBgwYRI7TEWBBBt8287MtX+G6bmW+3mrhroY04q87QRh5GNxOL7gj5k6o92SWwLUdlW7bKthzl+NeF7hP/wIwYjXHNPQxp7KVjuhRUTkVBQQEXX3wxN954IxdeeGGopyOp4ii6XnbLYJcuXfQVK1YEcTrl55FF8PVf8NdkQvZhceuPsHgPZJXAJ6Ohd/3Aj9HnY2iUBB9KsVUikUgkEYCiKCt1Xe8S6nkEEVlzUcXQdZ0jR47wzTff8M0335CXl8fQoUMZPXo0TZs2jRihxY+mw7IDomzop10mHF6Fxokao5t5OLe2lwYJOjEyCDfkFLp8IkqOwvZs1fe1SlbJiddbgk2nWbJG0ySNpkk6TZM1miRpJAS5mUakU1RUxMUXX8xVV13FFVdcEerpSAJDWB+II965Ui9e1Bpml0BKdGjmsCsHmqXAskxYfiDw4sr+fFF6dEW7wO5XIpFIJBKJRCJQFIX09HQmT57M5MmTOXbsGN999x333XcfR44c4YILLmDMmDG0bNkyIoQWVYHutTW613bxYA/4YacoG3phmfX4fTJiNBom6jRMENtGiRqNEnXSoqtXa14jcHpgZ57CtmyVrcdFFIWDhSc6YkSbdZokafSu66WpT0xplqyRGhW6i8iRSklJCRMmTGDixIlSWJEYRuSLKwliuycvNOKKrotyoNEtREL38szAj7Fkr9j2rBf4fUskEolEIpFI/klqaiqTJk1i0qRJ5ObmMnPmTB5//HH27t3LwIEDGTNmDG3btkVVg9guMkDEWmFcCy/jWnjZn6+wMUtlV67CzlyVXXkKM7aZKSpVbhJtEYJLo8RS20SN+vE6tohfPQQXjwZ78xVfOY+vpCdbZU++gqaL59ii6jRK1OmcrtG0pUe4UZI0asXpsrQnADgcDi677DLGjh3LpEmTQj0dSTUi4g+PfnFlbx50yjB+/GPFkO8SJTsWE3y2TijTgfzgWbJXZK00TQ7cPiUSiUQikUgkZSMxMZErrriCK664gvz8fGbPns0LL7zAtm3b6N+/P6NHj6ZTp04RIbTUidepE+/92890HY4WK+zK8wkuPuFlxSGVWdtPnNSqik7tWP24y6VhokbDBPF1sr16uSt0HQ4UKmzNVtjuK+XZlq2yI1fBrYknQkGnXrwo4xnc6ERpT70EHUv4v1QiEpfLxZVXXsngwYOZPHlyRLjMJFWHiBdX6saL7d4Qhdr6OwU1ShKdiz5YDRuOQucACT1eDZbugwENq9cHlkQikUgkEkk4Eh8fz4QJE5gwYQLFxcXMmTOHt956iw0bNtCnTx9GjRpFt27dMJlMoZ5qmVEUSIvRSYvR6VZL+9vvit2wJ98vuqjszFXYlauy7IAZh/fvuSClXS5+0aVOfGQLCbouchW3lRJQ/AGzxSeFyzZJ1jmvjldkoyRrNE7UZacmA3G73VxzzTX06tWLm2++OSyFFa/XS/369cnMzCQ1NZXMzEysVusZH7Nq1So6d+4MwEUXXcRXX31lxFQlFSDi3+5RFkiLgb25oRl/R47YNkyEaF8o2PLMwIkrfx2FXIcsCZJIJBKJRCIJN6Kjoxk3bhzjxo3D4XDw008/8cknn3DrrbfSs2dPRo8ezXnnnYfZHLmn3NEWaJmi0zLFC5xwvGg6HCxUhMslT4guu3NVluxX+Wbrib/XrOjUS/h7rkvDBCHAnCqQVdNFaY1XFxcZPZr/e+UfP/fq4NFO+vnx35/m577HeLUz/FwXotLOXCGo5DhOLNITfeGyY5p6hBMlWeSkxJ15fSwJMh6Ph+uuu45OnTpx5513hqWwAmAymZg0aRKPPfYYx44dY8aMGVx00UVnfMx77713/Ovrrrsu2FOUVIKI7xYEMG6qCO36epzxYz/1K3y0Bjb9B0wq9PsEGiQGrqvPG8vh/36DFddCDdmGWSKRSCQRguwWJKnOuFw
u5s+fz/Tp0/njjz/o1q0bo0ePplevXlgsVb9FT4GLEy6XvBNlRnvzT5TMAMRaxNvIq58QUfQQNwOxqDomRZT4N0g40Z1HdOrRSJHhsmGH1+tl8uTJNGzYkMceeyxshRU/+/fvp0GDBni9XgYMGMC8efNOe9/i4mIyMjLIz8+nUaNGbN++Pez/viAT1n985MropaiXIEpnQsHOHKifKIQVgHNqwdztQnUPRCDVkr3QMlUKKxKJRCKRSCSRgtVqZciQIQwZMgS3283ixYuZOnUq9957L506dWL06NH07dsXm61q9tKNs0K7NI12/9/evUfXdOf/H3/tnESESFCKujVapSimRevaIHTilviaqMZdQ9FWZw0zE12jinYYWj+qqjquk1IkNJKgtCOdstwaQd3aBlEtJah7XZJz9u+PI0qLRk5yds7J87GWdZKdnX1elpzglc9+f+6Xbl7tkuOQfrjgvK0o85yhHy8aMgzJ15B8fZz/nvY1TNl8JNvNx253/Pqj8xzz3o5fv9avjzNM1vM4HA698sorqlq1qsaNG+cRxUO1atXUqVMnJScn67///a8yMzMVEhJy23OXLl2q8+ed8y9iYmI84vdXnHlFuVIzWFq+X7qSI7ff13jozK2DZps+IC3ZK2WclupUcO3al7Ol7T9K/Ru5dh0AAABYw8/PT+3bt1f79u1lt9u1ceNGJSQkaOzYsWrQoIEiIyMVFhamgIAAq6MWOl8f6cFgUw8G29XW6jDweA6HQyNHjlRQUJAmTpzoEQOlcw0dOlTJyckyTVNz587VG2+8cdvz5syZI0ny9fXVwIED3RkR+eA5X4F3Uf36jkE/uHmobY7DuUtRrXK/HGtW1fm47Zjr1996VLpml1ozbwUAAMDj2Ww2Pf3005oxY4Z27dqll19+WVu3blXbtm3Vr18/ffzxx7p48aLVMYEiz+FwaPTo0bLZbJo6dapHFSuS9Mc//lE1a9aUJM2fP192u/035+zbt0+bNm2SJHXt2lWVK1d2a0bcO8/6KryDGtd3DPrunHuf9/tzUrbj1nKlepBz2+RtR12//oYjUgmb1OwB168FAACAosPHx0ctWrTQ1KlTtXPnTsXGxmr37t3q2LGjoqOjtWTJkhu3AwD4hcPh0Ouvv67Lly/r3Xff9bhiRXK+/gcPHixJOnbsmFatWvWbc24eZJt7Loo2z/tKvI2a11euHHFzuXLorPPx5nLFMJyrV9IKYOXKxiNSkyrOHZEAAADgnXx8fNSkSRNNmjRJ6enpGj9+vDIzM9W5c2dFRUUpLi5OZ86csTomYDnTNPXPf/5TWVlZmj17tkcWK7mef/75GzuJ5d7+k+vq1auKi4uTJNWoUUPPPPOM2/Ph3nnuV+NNKpSSAnydK0nc6dD1v+Nqlb31eJMHpGMXXbtNKeuS9PVpbgkCAAAoTnx8fNSwYUNNmDBBaWlpmjJlik6cOKHu3bsrMjJS8+fP18mTJ62OCbidaZqaMmWKDh8+rHnz5slms1kdySWVK1dWRESEJGn16tU6evSXWx8+/vhjnT59WpI0aNAgjy6RihOv+FMyDOeOQe6+LSjzjFSupFTuV/PHcm/jcWXuysbrux9RrgAAABRPhmGoXr16eu2117R161bNnDlT58+fV69evdSlSxd98MEHOn78uEyTncHh3UzT1PTp07V3714tXLjwxooPTzd06FBJzu2k58+ff+N47i1BNptNgwYNsiQb7p1XlCuS89agI26+LfXgWSmk3G+P17lPCiohfenC3JWNR5zFTf37838NAAAAeAfDMFS7dm2NHj1amzZt0ty5c5WTk6N+/fopPDxcM2fO1NGjRyla4HVM09SsWbO0detWLVq0SH5+3jMzoX379nr44YclSfPmzZNpmjp06JBSU1MlOQffVq9e3cqIuAdeU65UD3bOXHHn3yeZZ6SHyv72uM1HeuIB6ct8rlwxTecw21Y1JB+2MgcAAMBNDMNQSEiIRo0apQ0bNui
jjz6Sv7+/Bg8erA4dOmj69Ok6fPgwRQs8Xu5WxevXr9eyZctUokQJqyMVKMMwNGTIEElSZmamPvvsM82ZM+fGa5dBtp7Fa8qVmsHSlRwp62f3PN/Fa9KJS7dfuSJJTR+QMn6Szly+92tn/OScudKKkhIAAAB3YRiGqlatqhEjRig1NVUrVqxQ2bJl9corr6hdu3aaMmWKMjIyKFrgkf7zn/8oJSVFy5cvl7+/v9VxCsXAgQNv/N5mzZqlBQsWSJKqVKmizp07W5gM98prypUabt4xKDN3mO0dypXcuSv5Wb2y4YjzkXkrAAAAyCvDMFS5cmUNGzZMn376qVatWqWqVasqNjZWoaGhmjhxovbt20fRAo+wePFixcfHKzExUQEBAb//CR6qQoUK6tGjhyTnINsff/xRkrN08ZbZMsUF5Uo+5W7D/NAdypWGlaSyJaWle+/tunaHtGi3VK+iVDXIpYgAAAAoxipUqKCYmBitWbNG69at08MPP6wJEyaodevWGjdunL766is5HA6rYwK/kZCQoLi4OCUlJalUqVJWxyl0L7zwwi3vG4ahmJgYi9Igv7ymXKlWRjLkxnLljPP5ckudX/P3lQY1lj7LlPbdw255KRnSwTPSS00LIiUAAAAglStXTv3799fKlSv1+eefq1GjRnrrrbfUqlUrjRkzRtu3b6doQZGwcuVKffDBB0pKSlJgYKDVcdyiTZs2evTRR2+8HxYWppCQEAsTIT+8plzx95WqBLq3XKkWJJW8y0qtAY2lMiWkmV/m7Zp2h/TONuduQ+EPF0hMAAAA4BZBQUGKjo5WQkKCNm7cqKeeekozZ85U8+bNFRsbq82bN8tut1sdE8XQ6tWr9c477ygpKUnBwXf4KbaXCgsLu/E2g2w9k9eUK5JzFcl3brwt6E7zVnIF+0v9GkmrMqQDP/3+NVcfcJ43ohm7BAEAAKDwBQYGKioqSkuWLNGXX36psLAwLViwQM2bN9fIkSP1xRdfKCcnx+qYKAY+/fRTTZ48WcnJySpfvrzVcdzK4XAoMTFRklSxYkVFRERYGwj54nXlijtWrpimc6Dt75UrkvR8Y+fqlvfS7n6ew3SuWqldXupUu0BiAgAAAHlWsmRJdevWTXFxcUpPT1e3bt20bNkytWjRQiNGjND69euVnZ1tdUx4odTUVE2YMEEpKSmqUKGC1XHcbtWqVfr+++8lOQfZetuW08WF15UrJ3+WLhfy9/ysS9KlbCmk7O+fe18pKfoxKfHruxc/aw5I355m1QoAAACsV6JECYWHh2vevHnasWOHevXqpZSUFLVs2VLDhg3T2rVrdfXqVatjwgts2LBBY8aMUXJysu6//36r47id3W7X+PHjJUm+vr4aPny4xYmQX15VrtR0045Bv7dT0K+98Lhk85Hev8PqFYcpTd/qvF5nVq0AAACgCPHz81NYWJjef/997dq1SwMHDtT69evVpk0bxcTEKCUlRZcvX7Y6JjzQ5s2bFRsbq6SkJFWpUsXqOG6ze/duffLJJ1q8eLHCw8OVlub8j+KAAQNUs2ZNi9Mhv7xq4+wb2zGfl+oU4mqyQ2ecj3m5LUiSKgVKPetJy/ZJLzeTqpS59eOfHJC+OS1Ne8ZZwgAAAABFkc1mU2hoqEJDQ2W327VlyxYtX75cb7zxhurUqaPIyEh17NhRpUuXtjoqiri0tDT95S9/UVJSkqpVq2Z1HLd6++23tXDhwluOPfjgg/rXv/5lUSIUBK/6r3xuufLd2cJ9noNnnHNUKt/DzmBDmzh3A/og/dbjubNWapWVuj1SoDEBAACAQmOz2dSyZUtNnTpVO3fu1N///nd99dVX6tChg3r37q2lS5fq/PnzVsdEEbRz50699NJLWrFiRbFeqWGz2RQSEqJhw4Zpy5YtxW6Qr7cxTNPM88lNmjQxc5csFUWmKT32vvR/j0rjQwvvefonOueurOl9b583ap2UnCE9XVO6eM3569wV6fA56f91dOYGAMA
bGIax3TTNJlbnKER5/wcUUMw4HA7t3r1b8fHxWrNmjapUqaJu3bqpc+fOKlcuj0u/4bX27NmjmJgYxcfHq06dOlbHgWcp0tNJveq2IMOQqrthx6C9J6U2+ShYX24m7c5yrqwJLCGVKylVD3KWKt34vgIAAAAv4OPjo0aNGqlRo0aaMGGC9u3bp4SEBHXv3l3lypVTRESEunTpUix3hSnu9u/fr5iYGC1ZsoRiBV7Hq8oVSaoRJGX8VHjXP3HRuSPRY/kYZF2zrLS2T4FHAgAAAIokwzBUv3591a9fX6+99poyMjKUkJCgZ599VgEBAYqIiFDXrl1VqVIlGUaR/qE0XPTtt99q4MCB+vDDD1WvXj2r4wAFzqtmrkjOAuP7885ZJoVhT5bzsUHFwrk+AAAA4I0Mw9AjjzyiV199VZs2bdLcuXOVnZ2tfv36qVOnTnrvvfd07Ngx3cvYAniGzMxM9evXTwsWLFDDhg2tjgMUCq8rV2oESdfszhUmhWF3lvNGr3qUKwAAAEC+GIahkJAQjRo1Shs2bNCiRYvk5+enmJgYdezYUe+8844OHz5M0eIFjhw5oujoaM2ZM0ePP/641XGAQuN15UrN3B2DCmnuyu4s6aHyUukShXN9AAAAoDgxDEPVqlXTK6+8otTUVC1fvlxBQUEaMWKE2rVrp7feeksZGRkULR7o6NGj6tWrl9577z01a9bM6jhAofK6ciV3O+bCGmq7Oyt/81YAAAAA3J1hGKpcubKGDx+uzz77TCkpKapSpYpiY2MVGhqqSZMmaf/+/RQtHuD48eN69tlnNW3aNLVs2dLqOECh87py5YEyks2Qjpwv+GtnXZJOXGLeCgAAAOAOFStW1ODBg7VmzRqtXbtWtWrV0rhx49S6dWuNHz9eu3fvlsPhsDomfiUrK0tRUVGaPHmyQkNDrY4DuIXX7RbkZ5OqBkmZZwr+2rnDbB+rVPDXBgAAAHBn5cuX14ABAzRgwACdO3dOKSkpmjx5sg4ePKiwsDBFRkaqcePG8vHxup8fe5RTp04pKipKb7zxhsLCwqyOA7iN15UrklSvgvP2nYKWe836rFwBAAAALBMcHKzevXurd+/eunjxotasWaMZM2Zo//79atu2rSIiItS0aVPZbDaroxYrZ86cUVRUlMaMGaPw8HCr4wBu5ZW1bqNKzoG2Z68U7HX3ZEm1ykqBDLMFAAAAioTAwEBFRUVp6dKl2rp1q9q1a6f58+erRYsWN3YjysnJsTqm1zt37pyioqL0t7/9Td26dbM6DuB2XlmuNLx+285XJwr2unuypAYMswUAAACKpICAAEVERCguLk5paWnq0qWLli5dqhYtWtzYjSg7O9vqmF7nwoUL6tmzp0aMGKEePXpYHQewhFeWK7kzUXYVYLly+mfp2EXmrQAAAACewN/fX506ddK8efO0Y8cO9ezZU8nJyWrZsqWGDx+udevW6dq1a1bH9HiXLl1Sr169NHjwYPXq1cvqOIBlvHLmSrC/8/adgly5kjtvhW2YAQAAAM/i5+enDh06qEOHDsrJydHGjRuVkJCgMWPGqGHDhoqMjFT79u1VsmRJq6N6lMuXL+u5555Tnz591K9fP6vjAJbyynJFct4atPmHgrsew2wBAAAAz+fr66vQ0FCFhobKbrdr8+bNWr58uSZMmKC6desqMjJSHTp0UOnSpa2OWqRduXJFvXv3Vo8ePTRo0CCr4wCW89pypVElKfEb6cRFqVKg69fbnSU9GCwF+bt+LQAAAADWs9lsatWqlVq1aiWHw6Ht27crPj5ekydPVq1atdStWzeFh4erTJkyVkctUq5du6b+/fsrPDxcQ4cOlWEYVkcCLOeVM1ekX4baFtTclb1ZzFsBAACA6/73v//JZrPJMAzVqFFDZ8+eveO5mZmZCg4OlmEYCgwM1DfffOO+oMWMj4+PmjZtqsmTJys9PV2vvfaaMjIy1KlTJz377LNatGjRXf+siovs7GwNGjRIbdq
00YgRIyhWgOu8tlypX1GyGQVTrpy5LP1wQWrALUEAAABw0dNPP63Y2FhJ0vfff68hQ4bc9rycnBxFR0fr/PnzkqRp06apTp06bstZnPn4+Khx48Z68803lZaWpkmTJunYsWOKiIhQ9+7dtXDhQp0+fdrqmG6Xk5OjwYMH64knntCoUaMoVoCbeG25EuAn1bmvYIba5s5bYRtmAAAAFIRx48bpySeflCTFx8dr3rx5tz1ny5YtkqQePXooJibGrRnhZBiG6tevr7Fjx2rbtm2aMWOGzpw5o549e6pr166aM2eOTpw4IdM0rY5aqOx2u4YPH666devq1VdfpVgBfsW4l28CTZo0MdPS0goxTsH622fS2oPSziGSK6/9mV9KkzdJX70gBTNAHACA32UYxnbTNJtYnaMQeff/ouAWhw4dUuPGjXXhwgWVLl1aO3bsUO3atSVJGzZsUNu2bWW321W9enXt2rVL5cqVszgxbmaapjIzM7V8+XIlJibKz89PXbt2VUREhKpUqeJV5YPdbteIESN0//33a+LEifLx8dqf0aNoK9IvKq9+VTSqJJ29Ih0559p19mRJNYIpVgAAAFBwatWqpZkzZ0qSLl26pOjoaGVnZ+vs2bPq06eP7Ha7fHx89OGHH1KsFEGGYahWrVr661//qo0bNyouLk42m03PP/+8nnnmGc2YMUNHjhzx+BUtDodDI0eOVNmyZSlWgLvw6ldGQQ213Z3FvBUAAAAUvL59+yo6OlqSlJaWpjFjxuiFF17QkSNHJEmjR49WmzZtrIyIPDAMQ9WrV9ef//xnff7554qPj1dgYKBeeukltW/fXm+//bYOHDjgcUWLw+FQbGys/Pz89Pbbb1OsAHfh1bcFZdul+rOkvg2lMfn8O+nsFanRbOnvLaThTQs2HwAA3orbgoC8O3/+vBo3bqzMzMxbjj/11FPasGGDfH19LUqGgnDy5El9/PHHWrFihU6fPq1OnTopIiJCderUKdK3DjkcDo0dO1bnz5/X7NmzKVZQFBTdF4y8fOWKn02qV9G1obZ7rg+zfYxhtgAAACgEQUFBWrRo0S0lyu2OwTNVrFhRQ4YM0SeffKK1a9cqJCREY8eOVevWrTVhwgTt2bNHDofD6pi3ME1Tb775pk6dOqX333+fYgXIA69/lTSsJO05Kdnz+f2KnYIAAABQ2KpVq6bSpUvfeP+JJ55QrVq1LEyEwlC+fHkNGDBAycnJSk1NVYMGDTRp0iS1bt1aY8eO1Y4dOywvWkzT1OTJk/Xdd99p7ty5stlsluYBPIXXlyuNKkk/Z0sHfsrf56celh65TyoXUKCxAAAAAEnO2y/69Omjc+d+2YUhNTVVs2bNsjAVCltwcLB69+6tFStWaMOGDWratKmmT5+uli1b6tVXX9XWrVvdXrSYpqnp06dr7969WrhwISungHvg9eWKK0NtT1yUth2VutQu2EwAAABArokTJ+qLL76QJLVv317BwcGSpJEjR2r//v1WRoObBAYGqmfPnlq2bJm2bNmi0NBQzZs3T82bN7+xG5Hdbi/UDKZpatasWdq2bZsWL14sPz+/Qn0+wNt4fbnyUDkpsET+ypXVB5zT6jpRrgAAAKAQbN26Va+//rok6YEHHtDSpUtvrFi5fPmynnvuOV29etXChHC3gIAARUZGKi4uTmlpaerUqZM++ugjNW/e/MZuRNnZ2QX6nKZpau7cuUpNTdXSpUtVokSJAr0+UBx4fbniYzjnpeRnqO2qDKnOfVLt8gWfCwAAAMXbhQsX1Lt3b+Xk5MgwDC1cuFD33XefnnvuOfXt21eStGvXLsXGxlqcFFbx9/dX586dNX/+fO3YsUN/+tOftHLlSrVs2VIvvviiPv30U127ds3l51m4cKFSUlKUkJAgf3//AkgOFD9eX65Izrkr+09JV3Py/jnHL0pfHpM6s2oFAAAAheDFF1/UwYMHJTlvAQoLC7vxsZkzZ94YaDt9+nStXbvWkowoOvz
8/NSxY0fNnj1bO3fuVP/+/bVu3Tq1bt1aQ4YM0erVq3XlypV7vu6iRYuUkJCgxMREBQQwaBLIr2JTrmQ7pK9P5f1z1hxwPlKuAAAAoKB99NFHiouLkyT94Q9/0JtvvnnLx8uUKaPFixfL19dXpmlqwIABOnnypBVRUQT5+voqNDRUM2fO1M6dOzV06FBt3LhRoaGhGjhwoFauXKmff/75d68THx+vDz/8UElJSSpVqpQbkgPeq1iUK7lDbdOP5/1zUr6VHq0gPcwtQQAAAChAhw8f1rBhwyRJpUqV0uLFi2874+LJJ5+8MY/l+PHjGjRokDtjwkPYbDa1atVK06ZN086dOzVy5Eilp6crLCxMffv2VXx8vC5cuPCbz0tMTNS///1vJSUlKTAw0ILkgHcpFuVKtTJSvQpS3FeSw/z983+8IKX9yCBbAAAAFCy73X7LtstTp05V3bp173j+6NGj1aZNG0lSSkqK3n33XbfkhGfy8fFRs2bNNGXKFKWnp+sf//iHvv32W4WHh6tXr15avHixzp49q1WrVmnGjBlKTk6+sTsVANcYppmHtuG6Jk2amGlpaYUYp/AkfSO9/In0fmcp/OG7nzt3hzT+C2l9P+duQwAA4N4YhrHdNM0mVucoRHn/BxQAWMw0Te3du1cJCQlavny5zp07p/T0dFWoUMHqaMC9MKwOcDfFplyxO6R2/5HK+EvJvSTjLn8s3ZdJV7KlNb3dlw8AAG9CuQIARZNpmjp79qzKleOnyPA4RbpcKRa3BUmSzUca2kTanSVtPHLn845dkNJ/ZJAtAAAAAO9jGAbFClAIik25Ikn/V1eqVFp69y6Lb1ZnOB8pVwAAAAAAQF4Uq3LF31ca/Li05Qdp+4+3PyclQ6pXUQqhzAUAAAAAAHlQrMoVSYpuIJUtKb335W8/duSctOO41IVVKwAAAAAAII+KXblSuoQ0sJH0Wab09SnnsWt2af5OqdsSqYRN6vqIpREBAAAAAIAHKXbliiQNaCyV8pPeS5PWHJA6xEmv/096tIK0oqdUg63eAQAAAABAHvlaHcAKZUtKfR6TPkiXVn4j1S4vze8mtX3w7ls0AwAAAAAA/FqxLFckacjjUsZP0jMPSVH1JN9iuYYHAAAAAAC4qtiWKxVLSwsirE4BAAAAAAA8Hes1AAAAAAAAXEC5AgAAAAAA4ALKFQAAAAAAABdQrgAAAAAAALiAcgUAAAAAAMAFlCsAAAAAAAAuoFwBAAAAAABwAeUKAAAAAACACyhXAAAAAAAAXEC5AgAAAAAA4ALKFQAAAAAAABdQrgAAAAAAALiAcgUAAAAAAMAFlCsAAAAAAAAuoFwBAAAAAABwAeUKAAAAAACACyhXAAAAAAAAXEC5AgAAAAAA4ALKFQAAAAAAABdQrgAAAAAAALjAME0z7ycbxklJ3xVeHAAA4CVqmqZZ0eoQAAAA7nBP5QoAAAAAAABuxW1BAAAAAAAALqBcAQAAAAAAcAHlCgAAAAAAgAt8rQ4AwPMYhvGgpAHX3/3cNM3PLQsDAAAAABajXAGQHw9KGnvT+59bEwMAAAAArMdtQQAAAAAAAC6gXAEAAAAAAHAB5QoAAAAAAIALDNM0rc4AwEMYhhEqKTUv55qmaRRqGAAAAAAoIli5AgAAAAAA4AJ2CwJwL/ZI6i6pgaQJ148tlbTEskQAAAAAYDHKFQB5ZprmKUmJhmGcvenw16ZpJlqTCAAAAACsx21BAAAAAAAALqBcAQAAAAAAcAHlCgAAAAAAgAsoVwAAAAAAAFxAuQIAAAAAAOACyhUAAAAAAAAXUK4AAAAAAAC4gHIFQH44bnrbsCwFAAAAABQBlCsA8uPiTW+XtiwFAAAAABQBlCsA8iPzprcftywFAAAAABQBhmmaVmcA4IEMw0iX9Ifr786W9F9JF3I/bprmJ1bkAgAAAAB3o1wBkC+GYYRLSpZku93HTdNkFgsAAACAYoFyBUC+GYbxpKQRkppLqiwpIPd
jlCsAAAAAigvKFQAAAAAAABcw0BYAAAAAAMAFlCsAAAAAAAAuoFwBAAAAAABwAeUKAAAAAACACyhXAAAAAAAAXEC5AgAAAAAA4ALKFQAAAAAAABdQrgAAAAAAALiAcgUAAAAAAMAFlCsAAAAAAAAuoFwBAAAAAABwAeUKAAAAAACAC/4/lEgijx8Wq1YAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "main()" ] } ], "metadata": { "kernelspec": { "display_name": "jax0227", "language": "python", "name": "jax0227" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }