{ "cells": [ { "cell_type": "markdown", "id": "d36b21f0-6445-4b28-bc68-b04895b68f61", "metadata": {}, "source": [ "# Latent ODE" ] }, { "cell_type": "markdown", "id": "829f5f6b-5b03-414a-b37a-ccaca860224b", "metadata": {}, "source": [ "This example trains a [Latent ODE](https://arxiv.org/abs/1810.01367).\n", "\n", "In this case, it's on a simple dataset of decaying oscillators. That is, 2-dimensional time series that look like:\n", "\n", "```\n", "xx ***\n", " ** *\n", " x* **\n", " *x\n", " x *\n", " * * xxxxx\n", "* x * xx xx *******\n", " x x **\n", " x * x * x * xxxxxxxx ******\n", " x * x * x * xxx *xx *\n", " x * xx ** x ** xx\n", " x * x * x * xx ** xx\n", " * x * x ** x * xxx\n", " x * * x * xx **\n", " x * x * xx xx* ***\n", " x *x * xxx xxx *****\n", " x x* * xx\n", " x xx ******\n", " xxxxx\n", "```\n", " \n", "The model is trained to generate samples that look like this.\n", "\n", "What's really nice about this example is that we will take the underlying data to be irregularly sampled. We will have different observation times for different batch elements.\n", "\n", "Most differential equation libraries will struggle with this, as they usually mandate that the differential equation be solved over the same timespan for all batch elements. Working around this can involve programming complexity like outputting at lots and lots of times (the union of all the observation times in the batch), or mathematical complexity like reparameterising the differential equation.\n", "\n", "However, Diffrax is capable of handling this without such issues! You can `vmap` over\n", "different integration times for different batch elements.\n", "\n", "**Reference:**\n", "\n", "```bibtex\n", "@incollection{rubanova2019latent,\n", " title={{L}atent {O}rdinary {D}ifferential {E}quations for {I}rregularly-{S}ampled\n", " {T}ime {S}eries},\n", " author={Rubanova, Yulia and Chen, Ricky T. Q. 
and Duvenaud, David K.},\n", " booktitle={Advances in Neural Information Processing Systems},\n", " publisher={Curran Associates, Inc.},\n", " year={2019},\n", "}\n", "```\n", "\n", "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/latent_ode.ipynb)." ] }, { "cell_type": "code", "execution_count": 1, "id": "80a7f1a8-c288-4ea8-b751-622a6d76f4ce", "metadata": {}, "outputs": [], "source": [ "import time\n", "\n", "import diffrax\n", "import equinox as eqx\n", "import jax\n", "import jax.nn as jnn\n", "import jax.numpy as jnp\n", "import jax.random as jr\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import optax\n", "\n", "\n", "matplotlib.rcParams.update({\"font.size\": 30})" ] }, { "cell_type": "markdown", "id": "c1a86a99-786d-4f4a-8376-71955dcff0f6", "metadata": {}, "source": [ "The vector field. Note its overall structure of `scalar * tanh(mlp(y))`, which is a good structure for Latent ODEs. (Here the tanh is part of `self.mlp`.)" ] }, { "cell_type": "code", "execution_count": 2, "id": "9dc61000-f4b6-4693-92eb-1d93dfbddf02", "metadata": {}, "outputs": [], "source": [ "class Func(eqx.Module):\n", " scale: jnp.ndarray\n", " mlp: eqx.nn.MLP\n", "\n", " def __call__(self, t, y, args):\n", " return self.scale * self.mlp(y)" ] }, { "cell_type": "markdown", "id": "b7004ef1-e03b-4c20-9c16-53266d4b3c15", "metadata": {}, "source": [ "Wrap up the differential equation solve into a model." 
] }, { "cell_type": "code", "execution_count": 3, "id": "5466e77f-c7f2-4b32-8145-a92d2ca7911c", "metadata": {}, "outputs": [], "source": [ "class LatentODE(eqx.Module):\n", " func: Func\n", " rnn_cell: eqx.nn.GRUCell\n", "\n", " hidden_to_latent: eqx.nn.Linear\n", " latent_to_hidden: eqx.nn.MLP\n", " hidden_to_data: eqx.nn.Linear\n", "\n", " hidden_size: int\n", " latent_size: int\n", "\n", " def __init__(\n", " self, *, data_size, hidden_size, latent_size, width_size, depth, key, **kwargs\n", " ):\n", " super().__init__(**kwargs)\n", "\n", " mkey, gkey, hlkey, lhkey, hdkey = jr.split(key, 5)\n", "\n", " scale = jnp.ones(())\n", " mlp = eqx.nn.MLP(\n", " in_size=hidden_size,\n", " out_size=hidden_size,\n", " width_size=width_size,\n", " depth=depth,\n", " activation=jnn.softplus,\n", " final_activation=jnn.tanh,\n", " key=mkey,\n", " )\n", " self.func = Func(scale, mlp)\n", " self.rnn_cell = eqx.nn.GRUCell(data_size + 1, hidden_size, key=gkey)\n", "\n", " self.hidden_to_latent = eqx.nn.Linear(hidden_size, 2 * latent_size, key=hlkey)\n", " self.latent_to_hidden = eqx.nn.MLP(\n", " latent_size, hidden_size, width_size=width_size, depth=depth, key=lhkey\n", " )\n", " self.hidden_to_data = eqx.nn.Linear(hidden_size, data_size, key=hdkey)\n", "\n", " self.hidden_size = hidden_size\n", " self.latent_size = latent_size\n", "\n", " # Encoder of the VAE\n", " def _latent(self, ts, ys, key):\n", " data = jnp.concatenate([ts[:, None], ys], axis=1)\n", " hidden = jnp.zeros((self.hidden_size,))\n", " for data_i in reversed(data):\n", " hidden = self.rnn_cell(data_i, hidden)\n", " context = self.hidden_to_latent(hidden)\n", " mean, logstd = context[: self.latent_size], context[self.latent_size :]\n", " std = jnp.exp(logstd)\n", " latent = mean + jr.normal(key, (self.latent_size,)) * std\n", " return latent, mean, std\n", "\n", " # Decoder of the VAE\n", " def _sample(self, ts, latent):\n", " dt0 = 0.4 # selected as a reasonable choice for this problem\n", " y0 = 
self.latent_to_hidden(latent)\n", " sol = diffrax.diffeqsolve(\n", " diffrax.ODETerm(self.func),\n", " diffrax.Tsit5(),\n", " ts[0],\n", " ts[-1],\n", " dt0,\n", " y0,\n", " saveat=diffrax.SaveAt(ts=ts),\n", " )\n", " return jax.vmap(self.hidden_to_data)(sol.ys)\n", "\n", " @staticmethod\n", " def _loss(ys, pred_ys, mean, std):\n", " # -log p_θ with Gaussian p_θ\n", " reconstruction_loss = 0.5 * jnp.sum((ys - pred_ys) ** 2)\n", " # KL(N(mean, std^2) || N(0, 1))\n", " variational_loss = 0.5 * jnp.sum(mean**2 + std**2 - 2 * jnp.log(std) - 1)\n", " return reconstruction_loss + variational_loss\n", "\n", " # Run both encoder and decoder during training.\n", " def train(self, ts, ys, *, key):\n", " latent, mean, std = self._latent(ts, ys, key)\n", " pred_ys = self._sample(ts, latent)\n", " return self._loss(ys, pred_ys, mean, std)\n", "\n", " # Run just the decoder during inference.\n", " def sample(self, ts, *, key):\n", " latent = jr.normal(key, (self.latent_size,))\n", " return self._sample(ts, latent)" ] }, { "cell_type": "markdown", "id": "d45e7035-eaaa-4a1f-b2fc-c69b3f5ebb73", "metadata": {}, "source": [ "Toy dataset of decaying oscillators.\n", "\n", "By way of illustration we set this up as a differential equation and solve this using Diffrax as well. 
(Despite this being an autonomous linear ODE, for which a closed-form solution is actually available.)" ] }, { "cell_type": "code", "execution_count": 4, "id": "7d21ca14-d588-45d0-b512-653114a00b46", "metadata": {}, "outputs": [], "source": [ "def get_data(dataset_size, *, key):\n", " ykey, tkey1, tkey2 = jr.split(key, 3)\n", "\n", " y0 = jr.normal(ykey, (dataset_size, 2))\n", "\n", " t0 = 0\n", " t1 = 2 + jr.uniform(tkey1, (dataset_size,))\n", " ts = jr.uniform(tkey2, (dataset_size, 20)) * (t1[:, None] - t0) + t0\n", " ts = jnp.sort(ts)\n", " dt0 = 0.1\n", "\n", " def func(t, y, args):\n", " return jnp.array([[-0.1, 1.3], [-1, -0.1]]) @ y\n", "\n", " def solve(ts, y0):\n", " sol = diffrax.diffeqsolve(\n", " diffrax.ODETerm(func),\n", " diffrax.Tsit5(),\n", " ts[0],\n", " ts[-1],\n", " dt0,\n", " y0,\n", " saveat=diffrax.SaveAt(ts=ts),\n", " )\n", " return sol.ys\n", "\n", " ys = jax.vmap(solve)(ts, y0)\n", "\n", " return ts, ys" ] }, { "cell_type": "code", "execution_count": 5, "id": "90d6a004-0930-4a04-9139-677be3f6787b", "metadata": {}, "outputs": [], "source": [ "def dataloader(arrays, batch_size, *, key):\n", " dataset_size = arrays[0].shape[0]\n", " assert all(array.shape[0] == dataset_size for array in arrays)\n", " indices = jnp.arange(dataset_size)\n", " while True:\n", " perm = jr.permutation(key, indices)\n", " (key,) = jr.split(key, 1)\n", " start = 0\n", " end = batch_size\n", " while start < dataset_size:\n", " batch_perm = perm[start:end]\n", " yield tuple(array[batch_perm] for array in arrays)\n", " start = end\n", " end = start + batch_size" ] }, { "cell_type": "markdown", "id": "cec3fca6-c5cc-4c8a-950b-9ccabaa34efb", "metadata": {}, "source": [ "The main entry point. Try running `main()` to train a model." 
] }, { "cell_type": "code", "execution_count": 6, "id": "bb6c4793-6699-48e6-8a88-4874c4b77592", "metadata": {}, "outputs": [], "source": [ "def main(\n", " dataset_size=10000,\n", " batch_size=256,\n", " lr=1e-2,\n", " steps=250,\n", " save_every=50,\n", " hidden_size=16,\n", " latent_size=16,\n", " width_size=16,\n", " depth=2,\n", " seed=5678,\n", "):\n", " key = jr.PRNGKey(seed)\n", " data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)\n", "\n", " ts, ys = get_data(dataset_size, key=data_key)\n", "\n", " model = LatentODE(\n", " data_size=ys.shape[-1],\n", " hidden_size=hidden_size,\n", " latent_size=latent_size,\n", " width_size=width_size,\n", " depth=depth,\n", " key=model_key,\n", " )\n", "\n", " @eqx.filter_value_and_grad\n", " def loss(model, ts_i, ys_i, key_i):\n", " batch_size, _ = ts_i.shape\n", " key_i = jr.split(key_i, batch_size)\n", " loss = jax.vmap(model.train)(ts_i, ys_i, key=key_i)\n", " return jnp.mean(loss)\n", "\n", " @eqx.filter_jit\n", " def make_step(model, opt_state, ts_i, ys_i, key_i):\n", " value, grads = loss(model, ts_i, ys_i, key_i)\n", " key_i = jr.split(key_i, 1)[0]\n", " updates, opt_state = optim.update(grads, opt_state)\n", " model = eqx.apply_updates(model, updates)\n", " return value, model, opt_state, key_i\n", "\n", " optim = optax.adam(lr)\n", " opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))\n", "\n", " # Plot results\n", " num_plots = 1 + (steps - 1) // save_every\n", " if ((steps - 1) % save_every) != 0:\n", " num_plots += 1\n", " fig, axs = plt.subplots(1, num_plots, figsize=(num_plots * 8, 8))\n", " axs[0].set_ylabel(\"x\")\n", " axs = iter(axs)\n", " for step, (ts_i, ys_i) in zip(\n", " range(steps), dataloader((ts, ys), batch_size, key=loader_key)\n", " ):\n", " start = time.time()\n", " value, model, opt_state, train_key = make_step(\n", " model, opt_state, ts_i, ys_i, train_key\n", " )\n", " end = time.time()\n", " print(f\"Step: {step}, Loss: {value}, Computation time: {end - 
start}\")\n", "\n", " if (step % save_every) == 0 or step == steps - 1:\n", " ax = next(axs)\n", " # Sample over a longer time interval than we trained on. The model will be\n", " # sufficiently good that it will correctly extrapolate!\n", " sample_t = jnp.linspace(0, 12, 300)\n", " sample_y = model.sample(sample_t, key=sample_key)\n", " sample_t = np.asarray(sample_t)\n", " sample_y = np.asarray(sample_y)\n", " ax.plot(sample_t, sample_y[:, 0])\n", " ax.plot(sample_t, sample_y[:, 1])\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " ax.set_xlabel(\"t\")\n", "\n", " plt.savefig(\"latent_ode.png\")\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 7, "id": "4788d4cf-4150-409b-8aa4-273038408591", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step: 0, Loss: 19.934764862060547, Computation time: 27.07537531852722\n", "Step: 1, Loss: 17.945302963256836, Computation time: 0.1743943691253662\n", "Step: 2, Loss: 16.862319946289062, Computation time: 0.16676902770996094\n", "Step: 3, Loss: 17.838266372680664, Computation time: 0.1676805019378662\n", "Step: 4, Loss: 15.913865089416504, Computation time: 0.16959643363952637\n", "Step: 5, Loss: 15.387907028198242, Computation time: 0.16565966606140137\n", "Step: 6, Loss: 16.50263214111328, Computation time: 0.16969871520996094\n", "Step: 7, Loss: 17.307086944580078, Computation time: 0.17042207717895508\n", "Step: 8, Loss: 15.414609909057617, Computation time: 0.16952204704284668\n", "Step: 9, Loss: 16.912670135498047, Computation time: 0.16579079627990723\n", "Step: 10, Loss: 17.230003356933594, Computation time: 0.16723251342773438\n", "Step: 11, Loss: 18.290681838989258, Computation time: 0.16434955596923828\n", "Step: 12, Loss: 15.541263580322266, Computation time: 0.16330623626708984\n", "Step: 13, Loss: 15.520601272583008, Computation time: 0.16518783569335938\n", "Step: 14, Loss: 14.719974517822266, Computation time: 0.16350150108337402\n", "Step: 15, Loss: 
15.513769149780273, Computation time: 0.16359448432922363\n", "Step: 16, Loss: 16.30827522277832, Computation time: 0.1634058952331543\n", "Step: 17, Loss: 14.704435348510742, Computation time: 0.16392016410827637\n", "Step: 18, Loss: 14.534599304199219, Computation time: 0.16302919387817383\n", "Step: 19, Loss: 14.99282455444336, Computation time: 0.1640028953552246\n", "Step: 20, Loss: 15.04023551940918, Computation time: 0.16433429718017578\n", "Step: 21, Loss: 15.750327110290527, Computation time: 0.16364169120788574\n", "Step: 22, Loss: 14.745054244995117, Computation time: 0.163421630859375\n", "Step: 23, Loss: 15.654170989990234, Computation time: 0.16426348686218262\n", "Step: 24, Loss: 14.102017402648926, Computation time: 0.16342639923095703\n", "Step: 25, Loss: 13.730924606323242, Computation time: 0.16349434852600098\n", "Step: 26, Loss: 14.454326629638672, Computation time: 0.162459135055542\n", "Step: 27, Loss: 16.074562072753906, Computation time: 0.16372108459472656\n", "Step: 28, Loss: 14.457178115844727, Computation time: 0.16365718841552734\n", "Step: 29, Loss: 14.899832725524902, Computation time: 0.16407418251037598\n", "Step: 30, Loss: 14.21741771697998, Computation time: 0.16400694847106934\n", "Step: 31, Loss: 12.896212577819824, Computation time: 0.16325831413269043\n", "Step: 32, Loss: 13.572277069091797, Computation time: 0.16397356986999512\n", "Step: 33, Loss: 14.58654499053955, Computation time: 0.1686105728149414\n", "Step: 34, Loss: 14.236112594604492, Computation time: 0.1673274040222168\n", "Step: 35, Loss: 13.96904182434082, Computation time: 0.16666364669799805\n", "Step: 36, Loss: 13.717779159545898, Computation time: 0.16426467895507812\n", "Step: 37, Loss: 13.212942123413086, Computation time: 0.16362261772155762\n", "Step: 38, Loss: 13.356792449951172, Computation time: 0.16526198387145996\n", "Step: 39, Loss: 13.750845909118652, Computation time: 26.91799235343933\n", "Step: 40, Loss: 15.398611068725586, Computation time: 
0.1675868034362793\n", "Step: 41, Loss: 11.830371856689453, Computation time: 0.16466093063354492\n", "Step: 42, Loss: 12.59495735168457, Computation time: 0.16176891326904297\n", "Step: 43, Loss: 13.213092803955078, Computation time: 0.16349530220031738\n", "Step: 44, Loss: 12.40422534942627, Computation time: 0.16125273704528809\n", "Step: 45, Loss: 13.30964469909668, Computation time: 0.16145730018615723\n", "Step: 46, Loss: 12.55689811706543, Computation time: 0.16156625747680664\n", "Step: 47, Loss: 11.785927772521973, Computation time: 0.1622486114501953\n", "Step: 48, Loss: 11.325067520141602, Computation time: 0.16244864463806152\n", "Step: 49, Loss: 11.61506462097168, Computation time: 0.1624457836151123\n", "Step: 50, Loss: 10.890422821044922, Computation time: 0.16366934776306152\n", "Step: 51, Loss: 13.305912017822266, Computation time: 0.16304707527160645\n", "Step: 52, Loss: 11.54366397857666, Computation time: 0.16243696212768555\n", "Step: 53, Loss: 11.796025276184082, Computation time: 0.16330742835998535\n", "Step: 54, Loss: 12.504520416259766, Computation time: 0.16342830657958984\n", "Step: 55, Loss: 11.736138343811035, Computation time: 0.16159415245056152\n", "Step: 56, Loss: 11.351236343383789, Computation time: 0.16047382354736328\n", "Step: 57, Loss: 11.916851997375488, Computation time: 0.16179728507995605\n", "Step: 58, Loss: 11.83980655670166, Computation time: 0.16157770156860352\n", "Step: 59, Loss: 11.1612548828125, Computation time: 0.16280055046081543\n", "Step: 60, Loss: 11.311992645263672, Computation time: 0.1631929874420166\n", "Step: 61, Loss: 11.657142639160156, Computation time: 0.16200017929077148\n", "Step: 62, Loss: 10.814916610717773, Computation time: 0.16182494163513184\n", "Step: 63, Loss: 10.638484001159668, Computation time: 0.16114020347595215\n", "Step: 64, Loss: 9.871231079101562, Computation time: 0.16211938858032227\n", "Step: 65, Loss: 10.842245101928711, Computation time: 0.16185402870178223\n", "Step: 66, 
Loss: 11.241954803466797, Computation time: 0.16134214401245117\n", "Step: 67, Loss: 10.528236389160156, Computation time: 0.16387319564819336\n", "Step: 68, Loss: 10.252235412597656, Computation time: 0.16159725189208984\n", "Step: 69, Loss: 10.343666076660156, Computation time: 0.16295313835144043\n", "Step: 70, Loss: 9.838155746459961, Computation time: 0.16141152381896973\n", "Step: 71, Loss: 10.129756927490234, Computation time: 0.16135191917419434\n", "Step: 72, Loss: 10.172172546386719, Computation time: 0.16157221794128418\n", "Step: 73, Loss: 9.98276424407959, Computation time: 0.16115164756774902\n", "Step: 74, Loss: 9.925966262817383, Computation time: 0.16163945198059082\n", "Step: 75, Loss: 9.98451042175293, Computation time: 0.16181254386901855\n", "Step: 76, Loss: 10.033723831176758, Computation time: 0.1613597869873047\n", "Step: 77, Loss: 9.620193481445312, Computation time: 0.1607823371887207\n", "Step: 78, Loss: 9.448945045471191, Computation time: 0.1607818603515625\n", "Step: 79, Loss: 7.9748687744140625, Computation time: 0.1488492488861084\n", "Step: 80, Loss: 9.215356826782227, Computation time: 0.16275405883789062\n", "Step: 81, Loss: 9.691690444946289, Computation time: 0.1624891757965088\n", "Step: 82, Loss: 8.748353958129883, Computation time: 0.16045212745666504\n", "Step: 83, Loss: 8.528343200683594, Computation time: 0.16178536415100098\n", "Step: 84, Loss: 8.34644889831543, Computation time: 0.16109156608581543\n", "Step: 85, Loss: 9.200542449951172, Computation time: 0.16094589233398438\n", "Step: 86, Loss: 8.57141399383545, Computation time: 0.1619279384613037\n", "Step: 87, Loss: 7.508444786071777, Computation time: 0.1600663661956787\n", "Step: 88, Loss: 7.279205322265625, Computation time: 0.16137266159057617\n", "Step: 89, Loss: 7.090503215789795, Computation time: 0.16118311882019043\n", "Step: 90, Loss: 7.453930377960205, Computation time: 0.16112112998962402\n", "Step: 91, Loss: 7.0916032791137695, Computation time: 
0.16120529174804688\n", "Step: 92, Loss: 7.136333465576172, Computation time: 0.16111302375793457\n", "Step: 93, Loss: 7.14594841003418, Computation time: 0.16206598281860352\n", "Step: 94, Loss: 6.871617317199707, Computation time: 0.19673919677734375\n", "Step: 95, Loss: 7.352797031402588, Computation time: 0.16296100616455078\n", "Step: 96, Loss: 6.726633548736572, Computation time: 0.16156458854675293\n", "Step: 97, Loss: 6.9557905197143555, Computation time: 0.16250896453857422\n", "Step: 98, Loss: 7.102599143981934, Computation time: 0.1620466709136963\n", "Step: 99, Loss: 7.049860954284668, Computation time: 0.16131353378295898\n", "Step: 100, Loss: 6.750383377075195, Computation time: 0.16186952590942383\n", "Step: 101, Loss: 7.038060188293457, Computation time: 0.16181278228759766\n", "Step: 102, Loss: 7.034355640411377, Computation time: 0.16237926483154297\n", "Step: 103, Loss: 6.82716178894043, Computation time: 0.16185402870178223\n", "Step: 104, Loss: 6.787952423095703, Computation time: 0.16224908828735352\n", "Step: 105, Loss: 6.880023002624512, Computation time: 0.16243886947631836\n", "Step: 106, Loss: 6.616780757904053, Computation time: 0.1620333194732666\n", "Step: 107, Loss: 6.402748107910156, Computation time: 0.16213607788085938\n", "Step: 108, Loss: 6.7207746505737305, Computation time: 0.16174864768981934\n", "Step: 109, Loss: 5.961440563201904, Computation time: 0.16174983978271484\n", "Step: 110, Loss: 6.086441993713379, Computation time: 0.16232728958129883\n", "Step: 111, Loss: 5.67965030670166, Computation time: 0.1625194549560547\n", "Step: 112, Loss: 5.820930480957031, Computation time: 0.1604611873626709\n", "Step: 113, Loss: 6.119414329528809, Computation time: 0.16963505744934082\n", "Step: 114, Loss: 6.096449851989746, Computation time: 0.16268205642700195\n", "Step: 115, Loss: 5.988513469696045, Computation time: 0.1606006622314453\n", "Step: 116, Loss: 6.118512153625488, Computation time: 0.16241216659545898\n", "Step: 117, 
Loss: 5.241769790649414, Computation time: 0.16131067276000977\n", "Step: 118, Loss: 6.166355609893799, Computation time: 0.16092491149902344\n", "Step: 119, Loss: 6.842771530151367, Computation time: 0.1441802978515625\n", "Step: 120, Loss: 6.375185489654541, Computation time: 0.16277027130126953\n", "Step: 121, Loss: 5.80587100982666, Computation time: 0.1614992618560791\n", "Step: 122, Loss: 5.733676433563232, Computation time: 0.16245174407958984\n", "Step: 123, Loss: 5.918340682983398, Computation time: 0.16118144989013672\n", "Step: 124, Loss: 5.5885467529296875, Computation time: 0.16121363639831543\n", "Step: 125, Loss: 5.8133063316345215, Computation time: 0.16047954559326172\n", "Step: 126, Loss: 5.448032379150391, Computation time: 0.1612851619720459\n", "Step: 127, Loss: 5.919766902923584, Computation time: 0.16178321838378906\n", "Step: 128, Loss: 5.811756610870361, Computation time: 0.16073966026306152\n", "Step: 129, Loss: 5.2886857986450195, Computation time: 0.16239547729492188\n", "Step: 130, Loss: 5.062446594238281, Computation time: 0.1623084545135498\n", "Step: 131, Loss: 5.370600700378418, Computation time: 0.16302895545959473\n", "Step: 132, Loss: 5.032846450805664, Computation time: 0.16185617446899414\n", "Step: 133, Loss: 5.3186492919921875, Computation time: 0.16357207298278809\n", "Step: 134, Loss: 4.988264083862305, Computation time: 0.16092920303344727\n", "Step: 135, Loss: 5.364264488220215, Computation time: 0.16193294525146484\n", "Step: 136, Loss: 5.038562774658203, Computation time: 0.16143488883972168\n", "Step: 137, Loss: 5.195552825927734, Computation time: 0.16141676902770996\n", "Step: 138, Loss: 4.877957344055176, Computation time: 0.16106271743774414\n", "Step: 139, Loss: 4.971206188201904, Computation time: 0.15976953506469727\n", "Step: 140, Loss: 4.850249767303467, Computation time: 0.16672515869140625\n", "Step: 141, Loss: 5.053151607513428, Computation time: 0.16182613372802734\n", "Step: 142, Loss: 4.553808212280273, 
Computation time: 0.16060352325439453\n", "Step: 143, Loss: 4.6004109382629395, Computation time: 0.16107678413391113\n", "Step: 144, Loss: 4.889383316040039, Computation time: 0.1608583927154541\n", "Step: 145, Loss: 4.736492156982422, Computation time: 0.16157317161560059\n", "Step: 146, Loss: 4.708489894866943, Computation time: 0.16304683685302734\n", "Step: 147, Loss: 4.679104804992676, Computation time: 0.1609785556793213\n", "Step: 148, Loss: 4.689470291137695, Computation time: 0.16070127487182617\n", "Step: 149, Loss: 4.528751850128174, Computation time: 0.16136622428894043\n", "Step: 150, Loss: 4.48677396774292, Computation time: 0.1604769229888916\n", "Step: 151, Loss: 4.637646675109863, Computation time: 0.16101288795471191\n", "Step: 152, Loss: 4.762913703918457, Computation time: 0.16133403778076172\n", "Step: 153, Loss: 4.44551944732666, Computation time: 0.1619107723236084\n", "Step: 154, Loss: 4.5776472091674805, Computation time: 0.1616075038909912\n", "Step: 155, Loss: 4.562440395355225, Computation time: 0.16150236129760742\n", "Step: 156, Loss: 4.409887313842773, Computation time: 0.16173315048217773\n", "Step: 157, Loss: 4.46767520904541, Computation time: 0.16112399101257324\n", "Step: 158, Loss: 4.25125789642334, Computation time: 0.16138744354248047\n", "Step: 159, Loss: 4.785336971282959, Computation time: 0.1468524932861328\n", "Step: 160, Loss: 5.054254055023193, Computation time: 0.16128849983215332\n", "Step: 161, Loss: 4.8799567222595215, Computation time: 0.1611628532409668\n", "Step: 162, Loss: 4.688265800476074, Computation time: 0.16042160987854004\n", "Step: 163, Loss: 4.51352596282959, Computation time: 0.1602628231048584\n", "Step: 164, Loss: 4.331615447998047, Computation time: 0.1609640121459961\n", "Step: 165, Loss: 4.137004852294922, Computation time: 0.16290068626403809\n", "Step: 166, Loss: 4.654952049255371, Computation time: 0.16114187240600586\n", "Step: 167, Loss: 4.4677629470825195, Computation time: 
0.16231393814086914\n", "Step: 168, Loss: 4.510952949523926, Computation time: 0.16344356536865234\n", "Step: 169, Loss: 4.258943557739258, Computation time: 0.16016602516174316\n", "Step: 170, Loss: 4.283701419830322, Computation time: 0.1614704132080078\n", "Step: 171, Loss: 4.368310451507568, Computation time: 0.1617722511291504\n", "Step: 172, Loss: 4.095067024230957, Computation time: 0.16355204582214355\n", "Step: 173, Loss: 4.290921211242676, Computation time: 0.16144156455993652\n", "Step: 174, Loss: 4.135052680969238, Computation time: 0.16065239906311035\n", "Step: 175, Loss: 4.188730239868164, Computation time: 0.16092491149902344\n", "Step: 176, Loss: 3.9966931343078613, Computation time: 0.16103458404541016\n", "Step: 177, Loss: 4.127541542053223, Computation time: 0.16103053092956543\n", "Step: 178, Loss: 4.2538557052612305, Computation time: 0.1615607738494873\n", "Step: 179, Loss: 4.453568458557129, Computation time: 0.1603102684020996\n", "Step: 180, Loss: 4.0408525466918945, Computation time: 0.16083049774169922\n", "Step: 181, Loss: 4.516185760498047, Computation time: 0.1609797477722168\n", "Step: 182, Loss: 4.250395774841309, Computation time: 0.1612706184387207\n", "Step: 183, Loss: 4.046529769897461, Computation time: 0.16176581382751465\n", "Step: 184, Loss: 4.198785781860352, Computation time: 0.16283583641052246\n", "Step: 185, Loss: 3.9407706260681152, Computation time: 0.16234254837036133\n", "Step: 186, Loss: 4.026411056518555, Computation time: 0.1624460220336914\n", "Step: 187, Loss: 4.224530220031738, Computation time: 0.16072320938110352\n", "Step: 188, Loss: 4.028736591339111, Computation time: 0.16074919700622559\n", "Step: 189, Loss: 3.837322950363159, Computation time: 0.16036534309387207\n", "Step: 190, Loss: 4.123674392700195, Computation time: 0.16191387176513672\n", "Step: 191, Loss: 3.9622178077697754, Computation time: 0.16129708290100098\n", "Step: 192, Loss: 3.969315528869629, Computation time: 0.16092944145202637\n", 
"Step: 193, Loss: 3.7825825214385986, Computation time: 0.16073131561279297\n", "Step: 194, Loss: 3.9199018478393555, Computation time: 0.16074514389038086\n", "Step: 195, Loss: 4.052471160888672, Computation time: 0.16427040100097656\n", "Step: 196, Loss: 3.7691221237182617, Computation time: 0.16066265106201172\n", "Step: 197, Loss: 3.937032699584961, Computation time: 0.16099143028259277\n", "Step: 198, Loss: 4.042672634124756, Computation time: 0.16167831420898438\n", "Step: 199, Loss: 3.7281570434570312, Computation time: 0.14007043838500977\n", "Step: 200, Loss: 4.159261226654053, Computation time: 0.16143798828125\n", "Step: 201, Loss: 4.408998489379883, Computation time: 0.16060853004455566\n", "Step: 202, Loss: 4.1045427322387695, Computation time: 0.16067767143249512\n", "Step: 203, Loss: 4.352884292602539, Computation time: 0.1615588665008545\n", "Step: 204, Loss: 4.170437335968018, Computation time: 0.16057705879211426\n", "Step: 205, Loss: 3.970756769180298, Computation time: 0.1603851318359375\n", "Step: 206, Loss: 4.299739837646484, Computation time: 0.16051793098449707\n", "Step: 207, Loss: 4.127477645874023, Computation time: 0.16169023513793945\n", "Step: 208, Loss: 4.360357761383057, Computation time: 0.1614537239074707\n", "Step: 209, Loss: 3.9281232357025146, Computation time: 0.16314291954040527\n", "Step: 210, Loss: 3.9255576133728027, Computation time: 0.16143369674682617\n", "Step: 211, Loss: 4.089841842651367, Computation time: 0.162628173828125\n", "Step: 212, Loss: 4.131923675537109, Computation time: 0.1637284755706787\n", "Step: 213, Loss: 4.047548294067383, Computation time: 0.16175484657287598\n", "Step: 214, Loss: 4.078159809112549, Computation time: 0.1614534854888916\n", "Step: 215, Loss: 4.092671871185303, Computation time: 0.16064238548278809\n", "Step: 216, Loss: 4.069928169250488, Computation time: 0.16089081764221191\n", "Step: 217, Loss: 3.7901744842529297, Computation time: 0.16229534149169922\n", "Step: 218, Loss: 
4.05171012878418, Computation time: 0.16241717338562012\n", "Step: 219, Loss: 4.072657585144043, Computation time: 0.16231489181518555\n", "Step: 220, Loss: 4.119385719299316, Computation time: 0.16376709938049316\n", "Step: 221, Loss: 3.946767568588257, Computation time: 0.16153383255004883\n", "Step: 222, Loss: 3.8579845428466797, Computation time: 0.16051745414733887\n", "Step: 223, Loss: 3.955892324447632, Computation time: 0.16411495208740234\n", "Step: 224, Loss: 4.090612411499023, Computation time: 0.16119980812072754\n", "Step: 225, Loss: 3.871494770050049, Computation time: 0.1633768081665039\n", "Step: 226, Loss: 4.001490116119385, Computation time: 0.1612398624420166\n", "Step: 227, Loss: 3.856689453125, Computation time: 0.16136479377746582\n", "Step: 228, Loss: 3.854506254196167, Computation time: 0.16175079345703125\n", "Step: 229, Loss: 3.920146942138672, Computation time: 0.16027593612670898\n", "Step: 230, Loss: 3.8486571311950684, Computation time: 0.16107869148254395\n", "Step: 231, Loss: 4.150424003601074, Computation time: 0.161329984664917\n", "Step: 232, Loss: 4.034335613250732, Computation time: 0.16145658493041992\n", "Step: 233, Loss: 3.862642288208008, Computation time: 0.16074752807617188\n", "Step: 234, Loss: 3.879786491394043, Computation time: 0.16097068786621094\n", "Step: 235, Loss: 3.9150876998901367, Computation time: 0.1610715389251709\n", "Step: 236, Loss: 3.6582045555114746, Computation time: 0.16137981414794922\n", "Step: 237, Loss: 4.022642612457275, Computation time: 0.16101980209350586\n", "Step: 238, Loss: 3.920273780822754, Computation time: 0.16168999671936035\n", "Step: 239, Loss: 4.942720890045166, Computation time: 0.139939546585083\n", "Step: 240, Loss: 3.820035457611084, Computation time: 0.16096997261047363\n", "Step: 241, Loss: 4.027595520019531, Computation time: 0.1608715057373047\n", "Step: 242, Loss: 3.9767158031463623, Computation time: 0.16132664680480957\n", "Step: 243, Loss: 3.927661895751953, Computation 
time: 0.16009283065795898\n", "Step: 244, Loss: 4.054908275604248, Computation time: 0.16004633903503418\n", "Step: 245, Loss: 4.072584629058838, Computation time: 0.1604931354522705\n", "Step: 246, Loss: 4.165594100952148, Computation time: 0.16080093383789062\n", "Step: 247, Loss: 3.9277215003967285, Computation time: 0.16055607795715332\n", "Step: 248, Loss: 4.001946449279785, Computation time: 0.1610417366027832\n", "Step: 249, Loss: 3.9720990657806396, Computation time: 0.16184639930725098\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAACqUAAAHiCAYAAABrrlQ+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzddXRc573u8e+MRszMstCyzI45zEmbNNCgk7Rp2qRN6RRO6RZP6RROGVNMw9hw0jDHkJjZsizJYmaGmfvHtgKOQTAz75bm+ax11l6RZvZ+7lq33prZz/t7HR6PBxERERERERERERERERERERERERERkalwmg4gIiIiIiIiIiIiIiIiIiIiIiIiIiLTn0qpIiIiIiIiIiIiIiIiIiIiIiIiIiIyZSqlioiIiIiIiIiIiIiIiIiIiIiIiIjIlKmUKiIiIiIiIiIiIiIiIiIiIiIiIiIiU6ZSqoiIiIiIiIiIiIiIiIiIiIiIiIiITJlKqSIiIiIiIiIiIiIiIiIiIiIiIiIiMmUu0wGOJykpyZObm2s6hoiI+NGmTZtaPB5Psukc04XulSIigUf3yonT/VJEJLDoXjlxuleKiAQW3SsnTvdKEZHAonvlxOleKSISeI52v7R9KTU3N5eNGzeajiEiIn7kcDgOms4wneheKSISeHSvnDjdL0VEAovulROne6WISGDRvXLidK8UEQksuldOnO6VIiKB52j3S6e/g4iIiIiIiIiIiIiIiIiIiIiIiIiIyMyjUqqIiIiIiIiIiIiIiIiIiIiIiIiIiEyZSqkiIiIiIiIiIiIiIiIiIiIiIiIiIjJlKqWKiIiIiIiIiIiIiIiIiIiIiIiIiMiUqZQqIiIiIiIiIiIiIiIiIiIiIiIiIiJTplKqiIiIiIiIiIiIiIiIiIiIiIiIiIhMmUqpIiIiIiIiIiIiIiIiIiIiIiIiIiIyZSqlioiIiIiIiIiIiIiIiIiIiIiIiIjIlKmUKiIiIiIiIiIiIiIiIiIiIiIiIiIiU6ZSqoiIiIiIiIiIiIiIiIiIiIiIiIiITJlKqSIiIiIiIiIiIiIiIiIiIiIiIiIiMmUqpYqIiIiIiIiIiIiIiIiIiIiIiIiIyJSplCoiIiIiIiIiIiIiIiIiIiIiIiIiIlOmUqqIiIiIiIiIiIiIiIiIiIiIiIiIiEyZSqkiIiIiIiIiIiIiIiIiIiIiIiIiIjJlKqWKiIiIiIiIiIiIiIiIiIiIiIiIiMiUqZQqIiLe4fFAW7npFCIiIvbWesB0AgkUbjcMdsNQn+kkIiIiEzM8YDqBiIiITNDwqJtRt8d0DBERCWQejz5PiojYiMt0ABERmSF2PQQPfRKufwJmrTadRkRExH4qX4fbL4aLfg+LrzGdRmaigU7YejfseRzqtsDwoUJqZArMOhEWXAGzz4cgfRUgIiI209sCr/4Ctt8L/e0QnwsrPmn9X1Cw6XQiIi
JyFKWN3fzoyT2sL28FYGVeAl87bw4LsmINJxMRkYDR3w4v/9T6XnSwCzJOgDO/BYVnm04mIhLQ9CRKRESmrrcVnvoapC2ErOWm04iIiNhP+0G4/6OQkA9zLjCdRmYa9yi8+Vd46Scw2AmpC+CEj0JMJowOQct+KH8Jdj8CiUVw7o+g+HzTqUVERCzN++DOy6GrFuZdCklF1mKeZ74J+/4Da+6B0GjTKUVEROQwz+9u5NN3bSI6LJiPrpoFwENbarn0T2/wyysXcfHiTMMJRURkxutpgts+BC2lMP9yiMuB3Y/CnZdZ34Ge+HnTCUVEApZKqSIiMnVPfx0GOuCiRzV5S0RE5HBDvXDvNeAegTX3QpimhYgX9TTBgx+Hyteg4Cw46zuQseT9rxsdgX1Pwos/hnuugkXXwAW/hJAI/2cWEREZ091gTZJ3j8KNz0HmUuvnHg9suxce/SzcfRV89FFNTBUREbGRTQfb+Mxdm5mbHsM/P7acxKhQAD5/ZhE33bGRL963ldjwYE4vTjGcVEREZqyRIbjrcuiosj4z5p1q/fzUr8DDN8Oz37YW7c//sNmcIiIBymk6gIiITHN7n4IdD8ApX4G0+abTiIiI2IvHA498Gpp2w+X/hMQC04lkJmk9AH8/G2o2wsV/guv+feRCKlgLh+ZeDDe/Dqd+DbbdA/88F3qa/ZtZRERkjHsU7r8eBjrhIw+/U0gFcDhg8Rq45M9w8A144fvmcgKjbg97G7rYU9/FwPCo0SwiIhKgPB7rM2DDDmvRoUG9gyN86b5tpMaGctvHV7xdSAWIjQjmthtWUJwazRfv20pdR7/BpCIiMqO9/L9Qvw0+/Nd3CqkAweHWz7JXwuNfgK56cxlFRAKYxtmJiMjk9bfDE1+C1Plwyn+bTiMiImI/r/7C2i7o3B9B4dmm08hM0l4J/7oQRgfhY09C1tLjvgUAVwic+S3IWg73fxT+dQFc/xhEp/ksakPnAC/ubWJzVTsHW3vpHhghNDiI3MQI5qbHsCIvgcXZcTgcDp9lEBERG3rrH1C9Hi655eiLXBddZb1m7R9g7qXjv9950aNba/nxk3to6h4EID4imM+eUcjHT8rD6dS9S0RE/ODgWvjP16xCKkBkMpz5bTjhemshh5/9+rlSqtv7uO+Tq4mLCHnf78NDgvjzdUv5wG9f5fuP7+IvH1nm94wiIjLDtZTBG7+DxddByYfe/3tXqLXI8c8nWvfQq+7wf0YRkQCnSakiIjJ5z3wLepvh4j9aBQcRERF5x94n4aUfwcKrYPXnTKeRmWSgE+68DIb74PrHJ1fQmX0uXPcgdNbArR+0tk/2su01Hdx420ZO/OkLfPPhHby8rwmnw0FWfAQxYS7eqmjjJ//Zy6V/WstZv3yFP798gN5BsxN/RETET3pb4IUfQMGZsOjqY7/27O9DVCo8+SVruqof/fq5Ur5w71bSY8P45RWL+N2aJczPjOVHT+7hv+7dwvCo2695REQkAO1+FG67CAa74YO/gEv/CknF1uS3p79hTVD1o5r2Pm5fd5DLT8hiRV7CUV+XlxTJ588s4pldjbxSqh06RETEy174vjUR9ez/OfprEgusoUp7HoPaTX6LJiIiFk1KFRGRydn/HGy9y/pjPmOx6TQiIiL20rQHHvokZJwAH/qtkcklMkO53fDIZ6CtAj72BKTOm/y5ck+2tku+41K4Zw3c8JT1Ze4UdQ8M88MndnP/xhriIoL59OkFXLI4k8KUqPdNQ23vHeK53Y08uLmGnz29l3+8XsEXzy7i6uXZuIK0jlZEZMZ6/dcw3Avn/+z4fyeFxcC5P4SHboLdj8D8y/wS8a4NB/ntC/u5YmkWP71sIUGHpqJ+aGE6t7xSzs+e3ktseDA/umS+pn2LiIhv1G+Hf98EGUvg2vshPN76+YIr4Jn/BxtugbhZsPozfov0+xfKwAFfOmf2cV970yn53PdWNb94Zh+nFiXpfikiIt7RuMsqmp72DYhKPv
ZrV30a1v8ZXvqJtUBfRET8Rk94RERk4gY6rZXYyXPgtK+bTiMiImIvfW1WwS8kEq6+yyslP5G3vfEb2PsEnPsjmHXi1M+XsxI+/Feo2wyPfX7KU3bKmrq5+I9v8O/NtXzq1Hxe+9oZfPW8ORSlRh/xAWR8ZAhXLs/m/k+t5qHPnEh+UiTffmQnV/5lHVWtfVPKIiIiNtXdCG/93Zomn3z8QgtgFVGT58DLP/PLtNT9jd384PHdnFKUxM/eVUgFcDgcfPr0Am4+rYC7NlTx8JZan+cREZEANDIID94AEQlw9d3vFFIBnE447ydQfAE8/z1o2uuXSC09gzy8pZYrl2WREXf87zpCXE4+d2YhO2o7eWFPkx8SiohIQFj/Z3CFw8pPHf+1odFw4ueg7Dlo2On7bCIi8jaVUkVEZOKe/Q5018PFfwRXqOk0IiIi9jE6Ag9+HLpq4ao7ISbDdCKZSWo2wYs/hHkftlb5e0vJhXDmd2DHA/DGbyd9mpf2NnHxH96gq3+Yu25cyf/7YAnRYcHjfv8JOfHc96lV/Pbqxexv6uEDv32VR7eq6CMiMuNs/IdVtDn1q+N/jzMITvsatOyDvU/6Lhvg8Xj47qO7iAgJ4pdXLsLpPPJUt6+eV8zSWfF8//HdNHUP+DSTiIgEoHV/hNYy6zv4I02BczqtnVlCouDx/5ryAsPxuGt9FUOjbm44KW/c7/nwkkyyE8K55ZUDPkwmIiIBo7cFtt8Pi9dYCzfGY+kN4AqDt/7m22wiIvIeKqWKiMjEHHgJNt8Gqz8LWctMpxEREbGX574L5S/Bhb+G7BWm08hMMjpsPWiMSrUePHp728NT/hvmXmyVXms3T/jtr5Q286k7NpGXHMnjnz+ZVfmJk4rhcDi4eHEmT3/xVOZlxPKFe7fyq+dK8fjhAauIiPjByCBs/CfMPg8SCyb23pKLITYH3vyrb7Id8sKeJtaVt/Klc2aTEh121NcFOR387LKF9A+N8vOn9/k0k4iIBJjeFnj1F9Yk1MKzjv66qGQ4+3+gegPse8qnkUbdHu5+8yCnFydTkBw17ve5gpx87MQ8Nh5sZ0dNpw8TiohIQNh2L4wOwopPjv89EQmw4AqrzNrf4bNoR7KjppPfvbCfP71cxq463QdFJLColCoiIuM32AOP/RckFsIZ3zKdRkRExF623g3r/wgrb4Yl15lOIzPNuj9A40744C8gLMb753c4rLJrVCo89EkY6hv3W9ceaOGTt2+kMCWKuz6xivTY42/jeDyZceHceeNKrlyWxe9e2M8X79vK8Kh7yucVERHDdj0Cvc3j22bxcEEuWHEjVL7ms20XPR4Pv3qulPykSNasyDnu6wtTovjo6lk8tLmGfQ3dPskkIiIBaMMtMNwHZ3/v+K9dfC0kFMCLPwK37z4zvV7WQmPXIFcty57we69YlkVESBD/Wlvp/WAiIhJYtt0DGSdASsnE3rfs49a9ddfDvsl1mMGRUb758A4+9IfX+fXzpfz86X1c8LvX+X8P7WBE33GKSIBQKVVERMbv+f+Bzmpry6DgqZcNREREZoyajfD4FyHvVDj3x6bTyEzTVg4v/xTmXAglF/ruOuHxcMmfoHW/NfV3HA409/DJ2zcxKzGCO29cSWxEsNfihLic/OyyhXz1vGIe3VrH5+7ezNCIvrQVEZnWtt4J8bmQf8bk3r/kIxAUCptv92qsMa+XtbC7voubTysgOGh8X51/9oxCIkNd/Ob5Up9kEhGRADPYbU0FL7kQkouP//ogF5z+DWjaDWXP+yzWQ5triA0P5sySlAm/NyYsmEuXZPLkjjq6BoZ9kE5ERAJC/XZr0f7iayb+3owlkFQM2+/zfq7DeDwevvLAdu7eUMUnT81ny3fOYet3z+HGk/O4580qPn/PFtxu7QolIjOfSqkiIjI+la/DW3+zpr/lrDKdRkRExD666uHeayEmHa64zXogJOJNz/8POILgg//n+2vlnw6rPmv93Xdw3TFf2j
s4ws13bCLE5eTWG1aQEBni9TgOh4PPnlHI9z40l2d2NaqYKiIynbUfhIpXrYluDsfkzhGRAHMugB33w8igd/MBf3utgpToUC5ekjHu98RHhnD96lye3tVARUuv1zOJiEiA2X4fDHTCSV8c/3vmXQrR6bDhzz6J1Ds4wjO7GvjQonRCXUGTOscVy7IZGHbz5PZ6L6cTEZGAsfPf4HTBvA9P/L0OByy6CqrWQXul16O9299eK+fxbXV8/fw5fPODJcRFhBAXEcK3L5zLtz5Ywn92NnDLqwd8mkFExA5UShURkeMb6oNHP2dNMznrO6bTiIiI2MfwANx3rTXJ5Op7rKKEiDfVbYHdj8KJn4OY8RdkpuTMb0FsDjz5ZRg98hQbj8fD1/+9nQPNPfx+zRIy43w7Rf+Gk/L4/kXzeHZ3I199cJumCYiITEfb7wMcsOjqqZ1n8bXQ3w6lz3gl1pjqtj5eLW3m2pWzJly4uf7EXIKDnPz11XKvZhIRkQDj8cCmf0HaAshcOv73BQXDsk/AgReh2fuTu18pbWZg2M2FCyf/mXRRViyFKVE8sLHai8lERCRgeDyw5zHIPQUiEyd3jgVXWMddD3sv12Gq2/r41XOlnDs3lZtPy3/f7288JY8LFqTzy2dL2dvQ5bMcIiJ2oFKqiIgc34s/hPYKuOgPEBJpOo2IiIg9eDzwxJegdhN8+C+QOtd0IpmJXvgBhCfA6s/575ohkfCBn1nbP67/0xFf8sCmGp7YXs9XzivmpMIkv8S6/sRcvnpeMY9ureOnT+/1yzVFRMSLdj4EOashLmdq5yk4A6LSvL7t4gMbq3E44IplWRN+b3J0KB9ekskjW2rp7Ne2xCIiMkl1W6BhByz92MSnii+93tphY+tdXo/19M4GEiNDWJ47+YW4DoeDS5dksrmqg9qOfi+mExGRgNC0B9rKoeRDkz9HXA5kLIE9T3gv12H+96k9BDkcfP/ieTiOcC93OBz8+NL5RIW6+OETu/F4tPBeRGYulVJFROTYqjbA+j/D8hsh7xTTaUREROxj/Z9g291w+jen9mWYyNFUvGpNujnlvyEsxr/XnvNBKP4gvPxT6HjvJJumrgF+9MRuVuQlcPOpBX6N9ZnTC/jo6ln89dVybn2jwq/XFhGRKWjeB817YN4lUz+XMwjmXgxlz1vT6r3A7fbw4KYaTi1KJmOS07+vXTmL/uFRHt5c45VMIiISgLbfB0Gh70xym4ioFCg6xzqHe9RrkQZHRnlxbxPnzE0lyDnBouxhPjA/DYBndjZ4I5qIiASSPY8DDphzwdTOM+dCqN0IXXVeifVu+xq6+c/OBj5xSj7psUf/XBkXEcKXz5nNG2WtvFLa7PUcIiJ2oVKqiIgc3VAfPPoZiM2Gs79vOo2IiIh9HHgRnv02lFwEp37VdBqZqV78McRkWouDTPjAz8DjtqbmH+LxePj2IzsZHHHzs8sW4pziQ8mJcjgcfO9D8zhvXio/fGI3r+3XF7ciItPCrkcAh/W3kzfMuwRGBqD0Ga+cbuPBduo6B/jwCZmTPseCrFgWZcVy14YqTbsREZGJc49a98uicyAsdnLnWHwNdNdD+Utei/VmRRs9gyOcOy91yufKT46iODWap3eplCoiIhO0/1nIXArRaVM7z9hwib1PTj3TYf74UhmRIUF8/KTc4772mpU5ZMSG8aeXDng9h4iIXaiUKiIiR/fiD6G1DC7+A4RGmU4jIiJiD60H4IEbILkELvkzOPWxSnygagNUr4eTvgjBYWYyxOXAqk9bk3bqtgDwn50NPLu7kS+fM5u8pEgjsYKcDn515WJmp0bzubu3UNnSaySHiIhMwN4nIHsFxKR753zZqyA6HXY97JXTPbm9jlCXk7NLpla4uXJ5NvubethV1+WVXCIiEkAOroWeBph/2eTPMft8CI3x2v0R4OV9zYS4nKzOT/LK+c6fn8ZblW00dw965XwiIhIAeluhdpO1cGOqkoshocBrCxzHNH
UN8NSOetasyCEuIuS4rw8OcnLTqfm8WdnGxso2r2YREbELPT0VEZEjq3jN2pZ4xSch/zTTaUREROxhoAvuWQMOB6y5W4s2xHfW/g7C42HJtWZznPwliEiEZ7/DwNAIP35yD3PTY/jEyXlGY0WGuvjrR5bhcMBNt2+kb2jEaB4RETmG7kZo2A5F53rvnE4nFH8Ayl+GkamVWkbdHv6zs4EzilOIDHVN6VwfmJ+Oy+ng8e3e3wpSRERmuN2PQHAEzD5v8udwhVr3x71PwuiwV2K9UtrMyrwEwkOCvHK+8+en4fHAc7sbvXI+EREJAOUvAR4oPNs75ys6Bypfg+F+75wPeGBTDSNuD9eszBn3e65ank1MmIvb1h30Wg4RETtRKVVERN5vsBse/Qwk5MPZ/2M6jYiIiD243fDQJ60p4lfcBvG5phPJTNVSZj1EXH4jhJiZRvq2sFg47RtQ+RovPHY7tR39fPvCElxB5r9OyEmM4A9rTqCsuYfvPbrLdBwRETmaAy9YR29MtXm3onNhqAeq1k3pNFur22nqHuQDC6a4DSSQEBnCyUVJPLGtHrfbM+XziYhIgPB4YN/TUHDm1D8Dzr0Y+tutss0U1bT3UdbUw+nFKVM+15g5adHMSozg6V0NXjuniIjMcGXPQ3gCZCzxzvkKz4GRAah8wyunc7s93PtWFavzE8lPHv8Qi4gQFx8+IYund9bT2qMJ4iIy85h/iiQiIvbz7LehswYuucV8EUJERMQuXv5fKP0PnP9TTREX31r3BwgKsSbW28GyGxiNz6dwx284Z04SJxZ4Z9tGbzi5KInPnVHIA5tqeGRLrek4IiJyJGXPQ2QKpC7w7nnzTrXul/ufm9JpXtjTRJDT4bXCzUWLMqjt6GdLdbtXziciIgGgYQd01VhTTqeq4EwIjoTdj035VK/tbwHgtNne+wzocDg4f34aa8ta6Oz3zjRXERGZwTweqHjV+vzn9M7UbnJPAlcYlE3ts+SYzVXtVLf1c8WyrAm/95qVOQyPenhwU41XsoiI2IlKqSIi8l5lz8Omf8Hqz0HOStNpRERE7GHXw/Dq/8GSj8CKm0ynkZmstxW23QOLroYo702jmZKgYB6Kvo5ix0F+UFxhOs37fOGsIpbnxvOth3dQ0dJrOo6IiLybexQOvGhts+j08lfRIZGQezLsf3ZKp3lxbxPLc+OJDQ/2Sqxz5qYS6nLy2NY6r5xPREQCwL7/AA4oOm/q5woOh4IzrPujZ2pTu9cdaCU5OpSCCUx9G4/z5qUx4vbw8r4mr55XRERmoI4q6Kq1Pvt5S3A4zDoJyl/xyume2F5PiMvJOXNTJ/ze2anRLMmJ42EttheRGUilVBEReUd/Ozz6eUieA2d8y3QaERERe6jfDo98BrJXwgW/BIfDdCKZybbdbW0ftfJm00neVtPex7fKimkKnUX65t+A22060nu4gpz89uoluIKcfP6ezQyOjJqOJCIiY2o3W981FJ3tm/MXnQstpdA2uUUTtR397G3o5qw5E394eDTRYcGcOSeFJ3fUMzJqr3umiIjY1P5nIXMpRCV753yzz7MKPI27Jn0Kj8fD+vJWVucn4vDy9yCLsuKIDQ/m9UOTWEVERI7q4FrrOOtE75439yRo3gO9U7sXjbo9PLmjnjOKk4kOm9xCx4sXZbC3oZt9Dd1TyiIiYjcqpYqIyDv+8w3oaYRLb4HgMNNpREREzOttgXuvhfB4uPIOcIWaTiQzmcdjTazPXgWpc02nedstrxwARxDBZ/0/68va3Q+bjvQ+GXHh/OKKReys7eIXz+wzHUdERMaUPQcOJ+Sf4ZvzF5176DrPT+rtr5Y2A3DGHC+VgA65aFEGLT1DbKho8+p5RURkBupvh7rNUHiW9845dn8sfXrSpyhv6aWpe5BV+YleCvWOIKeDkwoTeb2sBc8Up7mKiMgMd/ANCIuD5BLvnjf3lHfOPwWbq9pp7h7kgoUZkz7HBQszcDrgsW2alioiM4
tKqSIiYtnzBGy/F079KmQsMZ1GRETEvNFhuP+j0NsEV90J0d6boCVyRJWvQWsZLP2Y6SRva+gc4P63arh8WRbxy660vgB++ae2m5YK1nbJ167M4e+vV7CxUiUgERFbKHseMpdBRIJvzp9YAAn51oS5SVhf7pttiU8rTibE5eSFPdqWWEREjqPiVfC4vbuAIzoN0hdNetEGwLoDrQCsLvB+KRXglKJk6jsHONDc45Pzi4jIDHFwrTUl1enlalPGEgiOgMqplVJf2NOEy+ng9OLJL3RMjg7lpMIkntxer8UaIjKjqJQqIiLWFLgnvghpC+HUr5hOIyIiYg//+bq1UvqiP0DmCabTSCDY9C9r5f+8SwwHecctrxzA7fHw6dMKwBlk/a3YUgr7njId7Yi++cESMuPC+coD2+gbGjEdR0QksPW1Qe1mKDzbt9cpOs8q9Az3T+htY9sSr/LBtsQRIS5W5Sfy0j6VUkVE5DgOvAQh0ZC1zLvnzT8dajbCUO+k3r75YDtJUaHkJkZ4N9chJxcmAfBq6dS2TRYRkRmsuwHaDlilVG8LCobsFVOelPri3kZW5CUQExY8pfOcNy+NytY+9jdpsYaIzBwqpYqIBDqPB574Egx0wqW3WH+Ei4iIBLqN/4SN/4CTvgALrzCdRgJBbwvsfgwWrYHgcNNpAGjuHuSeN6v48AmZZCccehA59xKImwVv/Mb6O9JmIkNd/N/li6hs7ePnT+8zHUdEJLBVbwA8kHuyb6+TfzqMDFjFmwmoaOmlsWuQ1T7YlhjgzOJkKlp6qWiZXBlIREQCRMUr1r3S29/L554K7mGoWj+pt2+p7uCEnDivL9wYk50QQX5SJK/tb/bJ+UVEZAY4uNY6+qKUCjDrZGjcZS2onITqtj5KG3s4c07KlKOcM9fape3ZXQ1TPpeIiF2olCoiEuh2/hv2PAZnfBNS55lOIyIiYt7BtfDUV62pXmd9z3QaCRTb7rUeGC693nSSt925/iCDI25uPq3gnR8GueDEz0PNW+98MWwzqwsS+diJufxrbeXbW06KiIgBB9dCUAhkLvXtdXJWAY4JT7hZX249eFyVn+CDUHDmHOuh4kt7NS1VRESOorsB2sp9s4AjZxU4XVD52oTf2tY7REVLL0ty4r2f611OLkpifXkbgyOjPr2OiIhMUwfXQnAkpC3yzflzTwY8k/6O85VSa2HFGV4opabGhLE4O45ndjVO+VwiInahUqqISCDrqocn/xuylsOJ/2U6jYiIiHkd1XDfRyA+Fy77h7VduYg/bL8PMk6AlBLTSQAYHBnlrg0HOXNOCvnJUe/95ZLrICIJ3vitmXDj8LXzi8lNjOCrD26jb2jEdBwRkcBUtc66twWH+fY64XGQtgAqX5/Q29aXt5ISHUpeUqRPYuUkRlCQHMlL+1RKFRGRoxibYpqz2vvnDo2CzGVQ8eqE37q1uh2AJTlxXg71XqcUJdM/PMqmg+0+vY6IiExTB9dCzkprkbwvZJ4ArrAJL3Ac80ZZCxmxYeR76TPl2SUp7KjtpLl70CvnExExTaVUEZFA5fHA41+AkUG45BaVbkRERIb64N5rYHQIrr7HKjiI+EPTXmjYDguvNJ3kbU9sq6elZ4gbTsp9/y+Dw2Hlp2D/M9C42+/ZxiMixMX/XbGImvZ+fv1cqek4IiKBZ6gP6rbALB+UbI4k9xRrivfI+B7eeTwe1pW3sio/0WfbEgOcOSeFDeVt9A5qgYSIiBxB1XpwhUP6Qt+cP+8U63480Dmht22p6iDI6WBhVqxvch2yKj8Bl9PB6/tbfHodERGZhvo7oGkX5Jzou2u4Qq3BTRNc4AjgdlufKU8sTPLaZ8pTipIBWHtA90URmRlUShURCVSb/mUVCc7+HiQVmk4jIiJilscDj34WGnZYE1KTZ5tOJIFkx/3gCIL5l5lOAlhFnX++UUFRShQnFyYd+UXLb7S2z7LxtNTluQmsWZHDP16vYGftxB7CiojIFNVuBP
eIbya/HUnuSTAyALWbxvXy8pZemrsHWV2Q6NNYZ8xJYWjUzetleqgoIiJHULUOspZBULBvzp93KnjccHDdhN62paqDOWnRRIT4aDLdIdFhwSzMimVDRZtPryMiItNQ/VbrmLXUt9fJPdl6JjDBBRy767vo6BvmpELvfaacnxlLXEQwr2mxhojMECqliogEotYD8Mw3If90WPEp02lERETMe/1XsOsha7HG7HNNp5FA4nbD9gesv8uiUkynAeCtynZ21XVxw0l5R1/pH5EAS6+HnQ9CR7V/A07AN86fQ0JkKP/voR2MjLpNxxERCRwH1wEOyF7pn+vlrLauN84JN+sOtAKwKt+3pdTluQlEh7p4eV+TT68jIiLT0GC3tWOGLxdwZK2AoFCoeHXcbxl1e9ha3cGSnDjf5XqX5XkJbK/pYGB41C/XExGRaaJui3VMX+zb62QtBzzvXG+c3ji08PDEgqMs6J+EIKeDkwqSeG1/Mx6Px2vnFRExRaVUEZFAMzoCD30SgkLgkj+DU7cCEREJcPuehhd+CPMvh5O+aDqNBJrq9dBZBQuvMp3kbXesP0hseDCXLsk89gtXfcaaMrzhFv8Em4TYiGD+56K57Kjt5F9rK03HEREJHFVrIXUehMf553oRCZA6f9yl1Dcr2kiJDiU3McKnsYKDnKwqSGTtoRKsiIjI22o2WlNMc1b57hrBYZC9AirHX0o90NxDz+AIS7LjfZfrXVbkJjA86mFLVYdfriciItNE3RaIz7U+6/lS5gnWcZy7box5q7KN/KRIUmPCvBrnlKIkGrsGKWvq8ep5RURMUBNJRCTQvPYLaxu9C38NMRmm04iIiJjVvA/+fSOkL4SLfg9Hmwop4ivb74fgCJhzgekkALT1DvHMzgYuXZJJeEjQsV8clw1zL4Itd8BQr38CTsIFC9I5oziZXz1XSk17n+k4IiIz3+gIVL/l28lvR5J7ElS/CSNDx33p1uoOls6KP/pEcC86sSCRg619ugeJiMh7Va0Hh/PQhDYfyjvV2pa4v31cL99SZb3OX5NSl81KwOGwyj0iIiJvq9sCGUt8f53weEgshJrxl1Ldbg9vVbazLNf7CzhOLrImr766v8Xr5xYR8TeVUkVEAknNRnjl59YkrvkfNp1GRETErP52uGeNNTnk6rshxLeTskTeZ3QE9jwGxR+A0CjTaQB4eEstQ6NurlqePb43rLwZBjph+32+DTYFDoeDH14yH48Hvv/4btNxRERmvoZtMNwLs/xcSp11Eoz0H3fbxbbeIara+liUHeeXWKsLEgFYp2mpIiLyblXrrKniYTG+vU72Sus4zrLN5oMdxEUEk5cU6cNQ74iNCKY4NVqlVBEReUdvK3RU+aeUCpC51Bro5PGM6+VlzT109g+zPNf7U1yz4iPIT4rk9f3NXj+3iIi/qZQqIhIohnrhoU9a01E/+H+m04iIiJjlHoUHP2F9uXXlHRCbZTqRBKKqtdDXCiUXmU4CgMfj4d43q1iUHUdJ+jgfjGavhPRFsOEv4/7i1oSs+Ag+f1Yhz+1u5KV9TabjiIjMbNVvWcdsH25HfCSzTrSOVWuP+bLtNR0ALMyK9XEgy+yUaBIjQ1RKFRGRd4yOWAMk/DFVPPMEayJrzZvjevnW6g4WZ8f5ZZr4mBV5CWw62M7IqNtv1xQRERurP7TQ0G+l1GXQ0whdteN6+dhCCl+UUsGalrq+vI3BkVGfnF9ExF9UShURCRTPfAvayuHSWyDMPw9eREREbOv578GBF+CCX/h/ipfImN2PgSscis4xnQSAzVUd7G/qYc14p6QCOByw4lPQvBcqXvFdOC+48eR88pMi+f5ju/SlroiIL9Vthqg0iM3073UjkyAh3yr5HMO26k4cDliQ6Z/vRpxOB6sKEll7oBWPjRdwiIiIHzXutKaKj00x9aXQaEiZCzVvHfelA8OjlDX3+O0eOWZ5bgJ9Q6Psquvy63VFRMSmxna/SF/kn+tlLrWOte
ObKr6xsp2kqFBmJfpm57VTipLpHx5lS1WHT84vIuIvKqWKiASCfU/DplvhxM9D7smm04iIiJi17V5Y+3tYfhMs/ZjpNBKo3G7Y8zgUngUh/tkW8XjufbOKiJAgLlyUMbE3zr8MIhKtaak2FuJy8j8XzaOytY+/vVpuOo6IyMxVu9maymZC1nKrdHOM8ue2mg4Kk6OIDgv2W6wTCxJp6BqgoqXXb9cUEREbq9tsHcdKML6WtRxqNlmfQ49hf2MPo27P+HfO8JIVedakubHJcyIiEuDqtkJiof+GLKXNh6CQ4y5wHLOtuoMlOb6bKr7i0ATWjbovisg0p1KqiMhM19MMj30OUhfAmd82nUZERMSs2k3w2H/BrJPh/J+YTiOBrOYt6GmAuRebTgJA39AIT+2o58KF6USFuib25uAwWHoD7PsPtFf6JJ+3nDo7mfPnpfGHl8qoae8zHUdEZOYZ6ILW/ZBhqJQ6tu1iZ80Rf+3xeNhW3cGi7Di/xjqxIAmAtQda/XpdERGxqbqtEBYH8bn+uV7WchjshJbSY75sd30nAHP9XEpNjQkjJyGCNytUvhEREaxJqRlL/Hc9VyikLbAWWB5H18Aw5S29LMryXWE2NiKY4tRo3qps99k1RET8QaVUEZGZzOOBx//Leij04b9af1SLiIgEqu4GuPdaiEqFK2+DIP9NxxJ5nz2PgTMYZp9nOgkAz+1upHdolA+fkDW5Eyz7ODic8ObfvBvMB77zobkA/OiJPYaTiIjMQPVbraM/HyC+W9Yy61h75Ak3Ne39tPYO+b2UmpsYQXpsGOtUShUREbDul+mLwEcT1t4ne4V1rHnzmC/bU99NZEgQOQm+2Y74WJbnJvBWZRueY0w7FxGRANDdCF21/v9MmbnUKsO6R4/5sp011gKOBVlxPo2zNDeezQfbGXXrvigi05dKqSIiM9nm22DfU3D29yB1ruk0IiIi5owMwn3XwUAnrLkbIpNMJ5JA5vHA7seg4Az/bUN1HI9sqSUjNuzt7aEmLDYT5l4EW+6AIXtPIM2MC+fzZxbx9K4GXi1tNh1HRGRmGZssY6qUmjofXGFH3XZxW00HAIt9/ADxcA6Hg9UFiawrb1XZRkQk0I0MQuNuyFjsv2smFFiTWWveOubLdtd1UZIeg9Ppp7Lsu5wwK472vmGq2uz9eVJERHzs7YWOft59I3MZDPdC895jvmx7rVVKXZjp2+90l+fG0z04Qmljt0+vIyLiSyqliojMVK0H4OlvQt5psPLTptOIiIiY4/HAE1+yHr5c8mdrKx4Rkxp3QmcVzLnQdBIAWnsGeXV/Cxctzpzaw8dln7CK37sf8Vo2X7nxlDzykiL5n8d2MTTiNh1HRGTmqNsCcbMgMtHM9V0hkL74qKWb7TWdhLicFKdF+zcXsDIvgbbeIcpbev1+bRERsZGm3eAetu5X/uJ0QtZyqD56KdXt9rC7vou5GTH+y/UuS7LjAdhS1WHk+iIiYhOjw5Ay1//f4WcutY5jCy2PYntNB9kJ4cRHhvg0zrJZ1uCAjZVtPr2OiIgvqZQqIjITjY7AQ5+EIJdVvnHqn3sREQlg6/8MW++C074O8y4xnUYESp+2jrPPN5vjkCe21zPq9nDpksypnSj3ZEgsgo23eieYD4W6gvjuh+ZS3tLLbWsrTccREZk56jabm5I6JmsZ1G2FkaH3/WprdQfzMmIIcfn/e5KleqgoIiJg3aPAv5NSAbJXWNPfBjqP+Oua9n56BkcoSTdTSp2dGkV4cBBbqzuMXF9ERGyi5EL4zDoIjfLvdRPyISQKGrYf82XbqjtZ6IedN7Liw0mLCeOtynafX0tExFfUUhIRmYle/T+o3QgX/traSlVERCRQHXgRnv2WNZHytG+YTiNiKX3WKuxEp5pOAsDDW2opSY+Z+tQ4hwOWfgxq3oTGXV7J5ktnFKdwRnEyv3thP609g6bjiIhMf72t0FEFmX7eZvFwWcthdBAad7znx6NuDztrO1
nkhweIR1KQHEl8RLAeKoqIBLr6rRAWC/F5/r1u1jLAA7Wbjvjr3fVWWXWuoVKqK8jJgqxYtqiUKiIiJjidkDofGnYc9SWtPYPUdvSzKCvW53EcDgdLc+O1qFFEpjWVUkVEZpqD6+DVn8PCq2H+ZabTiIiImNN6AB64AZLnwKV/0eRwsYfeFmtLYZtMST3Y2svW6g4uWZzhnRMuWgNBIbDpX945n49964K59A+P8svnSk1HERGZ/uq2WMcMG5RSAWo2vufHB1t76RsaNbYtscPhYOmsBDYdVClVRCSg1W2F9EXWoj5/Grs/j92vD7O7rgung6kvVpyCJTlx7KnrYnBk1FgGEREJYGkLoGEnuN1H/PX2WmsBx4LMOL/EWT4rnrrOAWo7+v1yPRERb9NTWRGRmaS/HR66CeJmwQW/MJ1GRETEnIEuuGeN9ZDn6rv9v92PyNHsfw7wwOzzTCcB4Mkd9QBcuMhLpdTIRJh7MWy7D4b6vHNOHypMieIjq2dx75tV7KnvMh1HRGR6q9sMOKyijUmxmRCdbi0CeZe9Dd0AlKSZKaUCLM+Np6Kll+ZuTegWEQlII0PQtBvSF/v/2uFx1nTW+m1H/PXu+m4KkqMICw7yb653WZIdx9Com111+mwmIiIGpC2AoW5orzjir3fUdOJwwPxM/3ymXJabAKBpqSIybamUKiIyU3g88PgXobseLv8HhJpb0SwiImKUe9RapNFaBlfcBgl+3hJP5FhKn4aoNEgzXNg55Kkd9SzOjiMzLtx7J116Awx2wq6HvXdOH/rCWUXEhAfzoyd34/F4TMcREZm+6rZAUhGEmSt9vi1zqTWJ7l321HcR5HRQlGpusdLYQ8VNB/VQUUQkIDXthtEhyFhs5vrpi953fxyzp77L2DTxMYuz4wHYWtVhNIeIiASo9IXWsWHHEX+9vaaD/KRIosOC/RKnJD2GyJAgNlZqtw0RmZ5UShURmSm23AG7H4Ezv209fBEREQlUL/7IKv6d/1PIP810GpF3jAzBgRdh9rngNP9xvKq1j521XVywIN27J551IiTNhk23eve8PhIXEcKXz5nNG2WtPLe70XQcEZHp66Qvwjk/NJ3Ckr7IWqA02P32j/bUd5OfFGl0Atz8zBhCXE49VBQRCVRjU0ozlpi5fvoi6Dho7bj2Lh19Q9R29FOSbraUmhYbRnpsGFurO4zmEBGRAJVcAo6gY5RSO1mUFee3OEFOBwuyYtle0+G3a4qIeJP5p2AiIjJ1zaXwn69D3mlw4hdMpxERETFnx4Pw+q/ghOthxU2m04i8V9U6GOyC2eebTgLAkzvqATh/fpp3T+xwwNKPWdsmN+z07rl95JoVORSlRPHjp/YwODJqOo6IyPSUsxKK7XGPI30R4HnPw8Q99V3MMVy2CXUFsTgrjrcOqpQqIhKQGndCWCzEG9rRZWxCa/329/y4tLEHgOI087uvLc6OY0u17pMiImJAcBgkF0PD9vf9qrl7kKbuQeZlxvo10qKsOPbUdzM04vbrdUVEvEGlVBGR6W5kEP79CXCFwaV/scXULRERESPqtsKjn4Oc1fDBX1jFOBE7KXsenMHWQiIbeGpHPYuyYslOiPD+yRetgaBQ2PQv75/bB1xBTr594VwOtvZx29pK03FERGSq0hdbx0NbFHcNDB+aAGe+bLMsN55dtZ30D2kRhIhIwDn/Z/Dptea+r0hbZB3HJrYesr/Jmiw+O9X8fXJxdhzVbf209gyajiIiIoEobeERJ6Xua7DulXP8vIBjYVYcQ6Nu9jZ0+fW6IiLeoOaSiMh098IPrBVbl/wJYry89aqIiMh00dME914DEYlw5R3gCjGdSOT9DrwEOasgNMp0Eqpa+9hR28kHF/jo78eIBJh3CWy/D4Z6fXMNLzttdjJnzknh9y+U0aIHoCIi01t0KkSlvV262VtvPUAsSTM7KRWsUuqI26OtiUVEApHTCbFZ5q4fmQix2e8vpTb2EBkSREZsmKFg71
icHQfANm1VLCIiJqQtgO566Gl+z4/HSqH+niq+MMuazLqtptOv1xUR8QaVUkVEprP9z8O6P8Dym6D4A6bTiIiImDEyCPd9BPraYM3dEJVsOpHI+3U3QuMOKDjTdBIA/rOzHoAPzPfhoqalN8BgF+x8yHfX8LJvXVBC//Aov3y21HQUERGZqozFUL8VeOcBYkm6+VLq4ux4AJVSRUTEjPRF7yulljX1UJgShcMGO87My4zF4YAdNZoIJyIiBqQtsI4N29/z430N3SRFhZAUFerXOFnx4SREhrBNnx9FZBpSKVVEZLrqaYJHboaUuXDuD02nERERMcPjgae+AtXr4ZI/Wg9XROyo/CXraJNS6vN7GpmbHkNOYoTvLpKzCpKKYcsdvruGlxUkR/HR1bnc91YVu+v0EFREZFpLXwQtpTDUy576LuIigkmN8e8DxCNJiAxhVmIEW6vbTUcREZFAlL4IWstgsPvtH+1v6qYwxb+T344mKtRFXlIkO2o1EU5ERAx4u5S64z0/3tfYzRwDO284HA4WZsWyXRPERWQaUilVRGQ6crvhkU9bXxxd/k8IDjedSERExIw3/wabb4dT/hvmX2Y6jcjRHXgRIhIhbaHpJLT2DLLpYDtnz0317YUcDlhyHVRvgObpM3n0C2cVERsezA+f2I3H4zEdR0REJit9MXjc0LiLPfXdlKTF2GICHFhbE2tSqoiIGJG+CPBAw04AOvuHaewapCg1ymyud1mQGctOlVJFRMSEiASIzX7PpNRRt4fSxm6K08ws4FiUFUdZUw+9gyNGri8iMlkqpYqITEcb/gxlz8N5P4aUEtNpREREzCh/GZ7+Bsz+AJzxbdNpRI7O7YYDL0H+GeA0/zH8pX3NuD1wTomPS6kAi64GRxBsvcv31/KS2IhgvnzObNaVt/Ls7kbTcUREZLIOTdB3125hX0M3c9LtMQEOrFJqY9cg9Z39pqOIiEigGdthpn4rAGVNPQAUpdirlNrQNUBz96DpKCIiEohS50HTnrf/82BrLwPDbnOl1OxY3B60YENEph3zT8NERGRi6rfBc9+D4gtg2SdMpxERETGjrQIe+BgkFcGH/2qLop/IUTXtgt4mKDzLdBIAnt/dSGpMKPMz/bDlVFQKzD4Ptt0Do9NnNf+aFTkUpUTxk6f2MDTiNh1HREQmIyYDIpPpPbiJ/uFRStL9v9Xi0SzOjgNga1WH0RwiIhKAotMgKs16zgCUNXUDUJRin8Ub8zNjAZVvRETEkJQSaCmFkSEA9jVY98o5hkqpC7PiANheo/uiiEwvenIrIjKdDHbDgx+HyCS46PfWlqgiIiKBZrAb7lkDHg+suQfC7FMwEDmiAy9ax/wzzOYABoZHeXV/M2eXpPpvC+Ml10FPozXpf5pwBTn59oVzqWzt4/Z1labjiIjIZDgckL4IT91WwNwDxCOZmxFDSJCTrdUdpqOIiEggSlsADTsB2N/YQ1iwk8z4cMOh3jEvw/qeZ4dKqSIiYkLKXHCPQNsBAPY2dONwmFvAkRQVSmZcONtqOoxcX0RkslRKFRGZLjweeOLL0FYOl/0dIhNNJxIREfE/txse+pS1UvmKf0FCvulEIsd34CVILoGYdNNJWF/eSt/QKGfPTfXfRYvOhchk2Hqn/67pBafNTuaM4mR++8J+Wnu0baSIyLSUvpiozjJCGaLQRtsSh7qCKMmIYYtKqSIiYkLqXGjZB6PD7G/qoSA5iiCnfQZgRIcFk58UqVKqiIiYkVJiHZt2A9ak1NzESMJDgoxFWpgVq0mpIjLtqJQqIjJdbLkTdtwPp/8/yD3ZdBoREREzXv4J7HsSzvsxFJifOilyXCNDULUe8k8znQSA5/c0EhESxOp8Py5wCgqGhVfBvv9Ab4v/rusF37qghL6hUX79fKnpKCIiMhlp83EyyokxLUSEuEyneY8l2XHsqOlkZNRtOoqIiASa1PkwOgStZZQ19VBko4UbY+ZnxrJLpVQRETEhsQgcQd
C0B4B9jd3Gd96YnxlLVVsfnf3DRnOIiEyESqkiItNB0x546quQdxqc8t+m04iIiJix62F49eew+DpYebPpNCLjU7cZRvptsajI4/Hw/O4mTilKIizYzyv7l1xnbXu1/T7/XneKClOi+ciqWdy9oYp9Dd2m44iIyESlLgDgxKgGw0Heb3F2HP3Do5Q29piOIiIigSZlLgADtTuo7einKNVs0eZIFmTGUtc5oF0rRETE/4LDILEAmvbQPzRKZWsvxYZLqXMzYgDYU99lNIeIyESolCoiYndDvXD/9RAaDZf9HZzmtgYQERExpn47PPIZyFoBF/4KHPbZVk7kmCpeAxww6yTTSdhV10VD1wBnl6T6/+IpJZC5FLbcBR6P/68/BV84q4josGB+9ORuPNMsu4hIoHPH5dLvCWGhq9p0lPdZnB0HwNbqDqM5REQkACXNBqeLzsqtABTadFIqwA5NSxURERNSSqBpN6WN3Xg8GJ+UOu9QKXV3nUqpIjJ9qJQqImJ3T30NWkrhsr9BVIrpNCIiIv7X0wz3XgPh8XDVneAKNZ1IZPwqX7O2RoxIMJ2E5/c04nDAmXMM/U255Dpo2gV1W8xcf5LiI0P4wllFvLa/hZf3NZuOIyIiE1DbNcQ+Tza5o5Wmo7zPrMQI4iKC2aZSqoiI+JsrBJJmM1q/A4AiG5ZS52Va5ZudKqWKiIgJKXOhrYLyOuu7wNmGp4qnRIeRFBXKbk1KFZFpRKVUERE723YvbL0TTv0q5J9uOo2IiIj/jQzB/R+F3ma4+i6INjDhUWSyRgahegPknWI6CWCVUpfmxJMYZajYPf8ycIXBljvNXH8KPrJ6FvlJkfzwyd0Mj7pNxxERkXEqa+phjzuHhJ5S203qdjgcLMiM1QQ4ERExI3UekR2lBAc5yEmIMJ3mfWLCgslNjGCXJsKJiIgJKSWAh87qnba5V87NiNF9UUSmFZVSRUTsqrkUnvgyzDoZTv+G6TQiIiJm/OdrULUWLv4jZCwxnUZkYmo3wcgA5J5sOglNXQPsrO3izBKDk/fDYqHkItj5IAz3m8sxCcFBTr51QQnlzb3ctf6g6TgiIjJOZU097PHkEDzYDt0NpuO8z/zMWEobuxkcGTUdRUREAk3qPGKHGiiJd+MKsufj4pL0GPZoIpyIiJiQMhcAT+NuZiVG2uJeOS8jhrKmboZGtGBeRKYH8/9yiojI+w33wwMfg+AwuOzv4AwynUhERMT/3vo7bLoVTvoiLLjcdBqRiat8HXDArBNNJ+GVUmurqdNnGyylAiy5DgY6Ye+TZnNMwplzUjilKIlfP7+fjr4h03FERGQc9jd1Ux9WYP1H406zYY5gQWYsI24P+xq6TUcREZFAkzIPgBOjmwwHObqS9BgOtvXROzhiOoqIiASa+DwICiWio5SC5EjTaQCYmx7D8KiH/U36/Cgi04NKqSIidvT0N6BpF1z6V4hJN51GRETE/ypfh/98HYrOg7O+azqNyORUvAppCyA83nQSXt3fQnJ0KCXp0WaD5J4CcTmw5Q6zOSbB4XDw7Qvm0j0wzG9f2G86joiIjENZUw+jydaEG7uWUgF21HYaTiIiIoHGfWgC3OKQWsNJjq4kPQaPB/Zq8YaIiPhbkAtP0mxSByspSI4ynQaAuRkxAOyq0xRxEZkeVEoVEbGbHQ/Cpn/ByV+CorNNpxEREfG/9kq47yOQkA+X/U0Tw2V6Gh6AmresEqZho24Pr+1v5tSiZBwOh9kwTicsvhbKX4GOKrNZJqE4LZo1K3K4Y91Bypp6TMcREZFj8Hg8lDX1kJGWBrHZ0GC/UmpWfDix4cHsVClVRET8rN6TQKcngnz3QdNRjmpsUeXuepVvRETE/3piiyhyVNumlJqbGElESBC7VUoVkWlCpVQRETtpLoXHvwDZq+CMb5tOIyIi4n+DPXDPNeAZhavvgbBY04lEJqd2I4wMQJ75Uur2mg46+oY5rTjZdBTL4msAD2y9x3SSSfnyOb
MJDw7if5/aYzqKiIgcQ3PPIF0DIxQmR0HqfGjcZTrS+zgcDhZkxmpSqoiI+F1FSx97PTmkDRwwHeWoMuPCiQlzsUelVBERMaAhNI9MRyuFsaOmowAQ5HQwJy1aizVEZNpQKVVExC6GeuH+j4IrFC7/JwS5TCcSERHxL7cbHrkZmvfA5bdCUqHpRCKTV/k64ICc1aaT8GppCw4HnFKYZDqKJS4H8k6FbXdb/7ufZhKjQvn8WYW8uLeJV0ubTccREZGjGJtoXZgSDanzoKXUmmRuM/MzY9nX0M3giD0edIqISGCoaOlhjzuHqM5S8HhMxzkih8NBSXqMSqkiImLEAU8mAAWOesNJ3jE3I4Y9dV243fa8d4uIvJtKqSIiduDxwONfhOa9cNk/IDbTdCIRERH/e/XnsOdxOOeHUHiW6TQiU1P5OqQvhPA400l4pbSJhVlxxEeGmI7yjsXXQXslVK0znWRSrj8xl1mJEfzoyd2MjE6/Yq2ISCB4p5QaBWnzrUn8LfsMp3q/+ZkxDI96KG3oMR1FREQCSEVLHxXOHJxDPdBZYzrOUZWkx7CvoVvlGxER8bttA6kARHWXG07yjnkZsXQPjlDT3m86iojIcamUKiJiBxv/CTvuhzO+CQVnmE4jIiLif7sfg5d/AovWwOrPmk4jMjXDA1D9JuSeYjoJnX3DbK3u4LTZyaajvFfJhRASDVvvMp1kUkJdQfy/D5RQ2tjDvW9Vm44jIiJHUNbUQ3Soi9SYUEidb/2wYafZUEewIDMWgB21nYaTiIhIIKlo6aE35tAONc32W7QxZm56DH1Doxxs6zMdRUREAsxbnTGM4LLVfbIkPQaA3ZoiLiLTgEqpIiKm1W6Gp78BhefAKV8xnUZERMT/GnfBwzdD5jK48DfgcJhOJDI1tRthdBByTzadhNfLWnB74LTZSaajvFdIJMy7BHY9AoPTczLcefNSWZWfwK+eK6Wzf9h0HBEROUxFSy/5yZE4HA6Iz4OgUGjeYzrW++QkRBAT5mJnnUqpIiLiP5WtfThSiq3/aN5rNswxjJVv9qh8IyIifuTxeNjfMkBraBa0lJqO87bZqVE4HLCvodt0FBGR41IpVUTEpL42uP96iEqFD/8VnPpnWUREAkxvK9xzNYRGw1V3QnCY6UQiUze2JX32SrM5gFdKm4gJc7EoK850lPdbch0M98LuR00nmRSHw8F3LpxLe98Qf3ypzHQcERE5TEVLL7lJkdZ/BLkgaTY02a9043A4mJ8Zy05NShURET8ZHnVT1dZHSmoGRCbbupRalBpFkNOhUqqIiPhVW+8Qnf3D9MUW2mpSakSIi5yECPY16r4oIvan9pOIiCluNzz8Keiuhytug4gE04lERET8a3QYHrgeuhvh6rshJt10IhHvqNoAyXOM/33n8Xh4pbSZk4uScAXZ8ON/9kpIKICtd5tOMmnzMmK5cmk2t75RQWVLr+k4IiJyyMDwKLUd/eQmRr7zw5Q5ti3dLMiMZW99N0MjbtNRREQkAFS39THq9pCXFGV9drXp/REgLDiI/KRIdtepfCMiIv5zoNn6ns+RPBvaK2Fk0GygdylOjWavJqWKyDRgw6dSIiIB4vVfwv5n4fyfQNZS02lERET87+lvQOVrcNHvdS+UmcPthpo3bTEltbSxh8auQU6bnWw6ypE5HLD4Gjj4OrRVmE4zaf993mxCgpz8+Cn7bQktIhKoqtv68HggP/ldpdTkOdBZDYP2e3g3PzOWoVE3pY32yyYiIjNPZatVtMlLijhUSt0HHo/hVEdXkh6jSakiIuJXB5p7AIjOmgeeUWgrN5zoHcVp0VS29DIwPGo6iojIMamUKiJiQvnL8NL/wvzLYfmNptOIiIj438Zb4a2/w4mfh0VXmU4j4j3Ne2GgE3JWmU7Cq6XNAJxq11IqwKKrAQdsu8d0kklLiQ7jc2cW8dzuRl7e12Q6joiIABWHple/d1JqiXW00daLYxZkxgKws7bTcBIREQ
kE5c1jpdQoSC6GwS5rRzebmpsRQ13nAB19Q6ajiIhIgDjQ1EOoy0l8zjzrBzb6HFmcFo3bA2VNPaajiIgck0qpIiL+1lUHD34CEovgQ7+1JkSJiIgEkoNr4amvQOHZcPb3TacR8a7q9dbRBpNSXy9roSA5kvTYcNNRji42CwrOgK33WFNmp6mPn5xLflIk3398N4MjmlIgImLa2AS43KTDJqUCNNlvsvWsxAiiw1zsUClVRET8oLK1l5gwF/ERwe/cH5v3mg11DCXpMQDsqddEcRER8Y+Kll7ykiJxJs+2ftBSajbQu8xJiwZgX4PuiyJibyqlioj408gQPHADDPfDVXdAaJTpRCIiIv7VUQX3fQTic+Gyf4AzyHQiEe+q2gCRyZCQbzTG0IibNyvaOKkwyWiOcVl8LXRWQeVrppNMWqgriO9dNI+Kll7+8XqF6TgiIgGvoqWXhMgQYsOD3/lhfC64wmxZunE4HMzLiNGkVBER8YuKll7ykqNwOBzvKqXaZwLc4UrSrfLNnvouw0lERCRQVLb2WjtvhERCbI6tSqm5iZGEuJzsa1QpVUTsTaVUERF/euab1vSsi35nbYsjIiISSIZ64d5rYHQI1twL4XGmE4l4X/V6a0qq4Wn422o66B8e5cSCRKM5xmXOBRAaC1vvNp1kSk6bncx581L5/Qtl1Hf2m44jIhLQxqbavIczCJJm23JSKsCCzFj2NHQzPDp9J4eLiMj0UNnSR25ihPUfUckQkWjb+yNASnQYSVEhKqWKiIhfjLo9VLf1Myvp0L0yebatFm+4gpwUJkexV5NSRcTmVEoVEfGXrXfDW3+D1Z+DBZebTiMiIuJfHg888hlo2AmX/xOSikwnEvG+7kZor4ScVaaT8EZZCw4HrMqfBqXU4HCY/2HY/SgMTO+HjN++YC5uj4cfPWnfB7oiIoHAKttEvv8XKSW2nJQKMD8zlqERNweae0xHERGRGWxoxE19Zz+zEiLe+WHyHFuVbY6kJD2GPQ3T+/OiiIhMDw1dAwyNupmVcOgzZVIxtOwHt30WEM5Ji2af7osiYnMqpYqI+EPdFnj8i5B7Cpz9fdNpRERE/O+1X8DuR+Cc70PROabTiPhG9XrrmG2+lLq2rJX5GbHERYSYjjI+i6+FkX7r34lpLDshgs+cXsiT2+tZW9ZiOo6ISEDqGxqhoWuAvKSI9/8yeQ501cJAp/+DHcfc9BgAdtfpwaKIiPhOXUc/bo/12eVtycXWog2Px1yw4yhJj6G0oUcTxUVExOcOtvQCvDNVPKnI+t6ys9pgqvcqToumsWuQjr4h01FERI5KpVQREV/rbYH7PgKRyXDFvyDIZTqRiIiIf+19El78ESy8Ck78L9NpRHynagO4wiB9kdEYfUMjbKlu58TCaTAldUzWMmtL5a13m04yZZ86LZ+chAi+99guPTAVETGgsqUPgNyko0xKBVtOg8tLiiTU5VQpVUREfOpgm3WfzDl8UupAB/Q0mQk1DiXp0QyNujnY2ms6ioiIzHCVrda9ctbYZ8rkYuvYUmoo0fsVp0UDsLeh23ASEZGjUylVRMSXRkfgwRusL3OuugMik0wnEhER8a/G3fDQJyFjCXzot+BwmE4k4jvV6yHjBHCZnU76VmU7w6MeTiyYRn97Ohyw+BqoWgetB0ynmZKw4CC+e+Fc9jf1cNvaStNxREQCTuWhskrekUqpyXOsY9MePyYaH1eQk+K0aG1NLCIiPlV1qJQ6K/Fd98mxsk2z/e6PY2anqnwjIiL+cbC1lxCXk/SYMOsHSbOtY2uZuVCHGSul7tN9UURsTKVUERFfeuF/oOJVuPBXkHmC6TQiIiL+1dcG966BkEi4+m4IDjedSMR3hvqgfhvkrDKdhLVlLQQHOVieG286ysQsvBoczhkxLfWskhTOKE7mN8/vp6lrwHQcEZGAUvH2VotHKKXGzYLgCGuLYhuamx7D7rouPDbePllERKa36rY+QlxOUqJD3/lh0t
gEuP1mQo1DQXIUTgeUqnwjIiI+VtnaS05CBE7noQEbEYkQFmurUmpaTBgxYS72Neq+KCL2pVKqiIiv7Pw3rP09LL8RllxnOo2IiIh/jY7AAx+Drjq46i6IyTCdSMS3ajeBe8QepdQDrSzJiScixGU6ysTEpEPBWbDtHnCPmk4zJQ6Hg+99aB5DI25+9KR9pw2JiMxElS29pESHEhl6hPug02lNubHhpFSAkvQY2vuGaewaNB1FRERmqIOHF20AotMgJNpW2xIfLiw4iNykSJVvRETE5w629pGbGPHODxwOSCy0VSnV4XAwOzWasqYe01FERI5KpVQREV9o3AWPfg6yV8F5PzGdRkRExP+e/RZUvAIf+i1kLzedRsT3qtdbxyyz//+9o2+InXWdnFiQaDTHpC2+BrpqrX8/prncpEg+fXoBj22r49XSZtNxREQCRkVLL7lJR5iSOialxL6TUjNiANhd32k4iYiIzFRVbf3kJES894cOByQV2npSKkBxajSljSrfiIiI73g8Hg629pGTcNhnysRCaD1gJtRRFKVGqZQqIramUqqIiLf1t8O910JoDFx5G7hCTCcSERHxr813wIZbYNVnrYKZSCCoftPa8jAiwWiM9eWteDxwUmGS0RyTVvxBCIuDrXebTuIVnz69gPykSL79yE4Ghqf39FcRkemisrWXvMRjlFKT50B3PfR3+C3TeM1JiwZgd12X4SQiIjITeTweqtv63l9KBUgsstUEuCOZnRpNZWuvPluJiIjPNHcP0j88Sm7SYffKxELorIHhfjPBjqAgOYq23iFae7TThojYk0qpIiLe5B6Fhz5p/VF65e3WtjciIiKBpGo9PPElyD8DzvmB6TQi/uHxQM1GW0wFXnuglYiQIBZlxZmOMjnBYbDgctjzuC3LQhMVFhzEjy6dT1VbH79/0d5Th0REZoKewRFaeoaYdfgDxHdLKbGONpyWGh0WTE5CBHvqtTWxiIh4X1vvED2DI0cupSYVQWc1DPX5P9g4FadF4/GgqXAiIuIzla3WfXDW4QsdEwsAD7RV+D/UURSlWosa9+u+KCI2pVKqiIg3vfgj2P8sfOCnkLPSdBoRERH/6qyB+66DuGy44lYIcplOJOIfbeXQ3waZy0wn4Y2yFpbnJhDimsYf9xdfAyMDsOth00m84sSCJC47IYu/vFJOaaNKRiIivlTddugB4uFbLb5b8hzr2LTHD4kmbm56DLvrNSlVRES8r+rQffLIk1ILraONp6XOPlS+2degz1UiIuIbla29AOQmHmFSKtjqPlmUEgWolCoi9jWNn1KJiNjMjgfh9V/B0htg2SdMpxEREfGvoT649xoYHoA190J4vOlEIv5Tu8k6ZpmdlNrYNcCB5l5OKkw0mmPKMk6A5BLYepfpJF7zrQtKiA5z8c2HduB2e0zHERGZsY5ZthkTmw3BkbaclApQkh5DZWsvvYMjpqOIiMgM8/Z98vCiDUDSbOvYat8dHnITIwhxObXYT0REfOZgay8up4PMuPD3/iIh3zraqJSaHhtGZEgQZboviohNqZQqIuINtZvh0c/CrJPgAz8Hh8N0IhEREf/xeOCxz0H9drjs75BcbDqRiH/VvGWVW8a2AzZk7YEWwJrMOa05HNa01Jq3oLnUdBqvSIgM4ZsfLGHjwXbu21htOo6IyIxV1TqOUqrTaf29atdJqRkxeDywV1PgRETEy8buk9nxR5qUWgA4oMU+ZZvDuYKcFCZHsU/lGxER8ZHK1j4y48NxBR1WpQqNhqg0aD1gJtgROBwOClOjKWvWpFQRsSeVUkVEpqq7Ae69FiJT4MrbwRViOpGIiIh/vf5r2PlvOOu7UHy+6TQi/lezETKWgDPIaIw3ylqJiwhmbnqM0RxesfAqcATBtrtNJ/Gay5dmsTIvgZ88tYem7gHTcUREZqSqtj5iwlzERgQf+4UpJbadlDo3w7qP767vMpxERERmmqq2PlKiQwkPOcJn1+Bwa5q4jS
elAhSnRbNPCzdERMRHDrb2Misx8si/TCy01aRUgKKUKPY3qpQqIvakUqqIyFQMD1iF1IEOWHM3RE7zqVQiIiITte9peOEHMP9yOPlLptOI+N/wADTsgKylRmN4PB7WHWhldX4iTucMmNofnQpF58C2e8E9ajqNVzgcDv73wwsYGHHz7Yd34vF4TEcSEZlxqtr6jv4A8d2SiqCnEQY6fR9qgjJiw4gJc7FHpVQREfGyqra+Y08TTyqEFnvvVjE7NZr6zgE6+4dNRxERkRnG4/FwsLWPWUe7VyYW2K6UWpgSRVP3IJ19ui+KiP2olCoiMlkeDzzxJajdCJfeAmkLTCcSERHxr+Z98O8bIX0hXPR7a8ttkUDTsAPcw5C13GiMqrY+ajv6ObEg0WgOr1p8DXTXw4GXTCfxmoLkKL58zmye3d3I49vrTccREZlxqo9XthmTVGwdW+w3Dc7hcDA3I4bddSqlioiId1W19ZGTeKxS6mxrW2IbL6ArTosCYH+jpqWKiIh3dfYP0z0wwqyj3SsTC6GvBfrb/RvsGIpSrPtiWbPuiyJiPyqliohM1ro/WtuJnv7/YO7FptOIiIj4V18b3HO1tb3b1XdDyDge/ovMRDVvWcfMZUZjrC9vBWD1TCqlzv4AhCfA1jtNJ/GqG0/OY1F2HN97dCctPYOm44iIzBijbg/V7X1kj6uUOts62nQa3Nz0WPY2dDHqtm8pSEREppeB4VEaugaOvXgjsRCGeqzFgTY1OzUagH0qpYqIiJdVt/UDkBV/jFIqQGu5nxIdX1GKdV/c39hjOImIyPuplCoiMhn7n4fnvgMlF8GpXzOdRkRExL9GR+DBG6CjGq66A2KzTCcSMad2I8RkQky60RgbyttIigqhIDnKaA6vcoXAgitg75O2mkAwVa4gJ7+4fCG9g6N899GdpuOIiMwYDV0DDI96xjcpNX4WOINtW0otSY9mYNhNZWuv6SgiIjJD1Hb04/Fw7PtkUpF1tOEk8TGZceFEhgRR2qBSqoiIeFd1ex8A2QnhR37B26XUMj8lOr7M+HDCgp3sb1IpVUTsR6VUEZGJatkPD34cUubBpbeAU/+UiohIgHn2W1D+MnzoN5CzynQaEbNqNkLmUtMp2FDRxoq8BBwOh+ko3rXkWhgdgp3/Np3Eq4pSo/niOUU8taOBJ7fbdwqRiMh0UtVqPUAcVyk1KBgS8qHZnqXUuRkxAOyu6zKcREREZoqa9uNMfwNIPFRKbbVvKdXhcDA7LVqTUkVExOuq28ZKqUe5V8bngsNpq1JqkNNBQXKUSqkiYktqUomITER/h7VVcVAwrLkbQiJNJxIREfGvTbfBhltg1WdhyXWm04iY1dMMHQcha7nRGNVtfdR29LMyL9FoDp9IWwip82HLXaaTeN0nT8lnYVYs33l0J609g6bjiIhMe2MPEGcljqOUCpA827aTUotSonE5HeyuVylVRES8o/btUupRpr8BxGRAcKStJ6UCFKdGs6+hG4/HYzqKiIjMINXtfcSGBxMTFnzkF7hCIG6WrUqpAEUpUZRpsYaI2JBKqSIi4zU6DPd/FNoPwpW3Q1yO6UQiIiL+dXAtPPnfUHAWnPMD02lEzKvdaB2zlhmNsaGiDYCV+QlGc/iEwwGLr4G6zdC0x3Qar3IFOfm/yxfRPTDMdx/dpQeqIiJTVNXWR5DTQXps2PjekDQb2ius73tsJsTlpDAlin3amlhERLykpr0Pl9NBaswx7pMOByQV2r6UOjs1mva+YVp6hkxHERGRGaS6rf/YizcAEgvtV0pNjaauc4CewRHTUURE3kOlVBGR8fB44KmvQMUr8KHfQu5JphOJiIj4V0cV3PcRiJ8Fl/8TglymE4mYV7MRHEGQvthojA3lrcRFBDM7JdpoDp9ZcCU4XbB15k1LLU6L5otnz+bJHfU8srXWdBwRkWntYFsfmXHhuILG+ZV30mxwj0BbuW+DTVJxWjR7NSlVRES8pKa9n4y4cIKcjm
O/MLEIWu1dSi1Osz77lmoqnIiIeFF1ex/Z8cfZeSOxwPoMaaPF5QXJUQAcaOoxnERE5L1UShURGY91f4RN/4KTvwxLrjWdRkRExL8Ge+CeNdYUqTX3Qnic6UQi9lC7EVLnQcg4twn2kQ0VbSzPTcB5vIeL01VUMhSdB9vug9GZt+L/5tMKWJ4bz3ce2fX21tMiIjJxVW195CRM4J6cNNs6tpT6JtAUFadZ0246++03yVVERKafmnZr8cZxJRVBRzUM9/s+1CSNlVI1UVxERLzF4/FQ295PdsI4JqUO9UBPo3+CjUNRqlVK3a9SqojYjEqpIiLHs/dJePbbMPdiOPM7ptOIiIj4l9sND38KmnbDFf+0Hk6IiPW/jdrNkLXMaIz6zn6q2vpYmZdgNIfPLbkWepvgwAumk3hdkNPBr69ajAP40n1bGRl1m44kIjItVbf1kT2hUuqhv2ttWkotSYsBNAVORES8o6Z9HFsSg1W2wQOtB3yeabKSokJJiAxRKVVERLymuXuQwRH38T9TJhZYx9Yy34cap1kJEQQHOdjfpPuiiNiLSqkiIsdStxX+fSNkLIFLbgGn/tkUEZEA88pPYe8TcO6PoPBs02lE7KOlFAa7INNsKXVDeRsAq/ITjebwuaJzISIJttxpOolPZMVH8KNL57PxYDt/etm+D39FROyqe2CYtt4hZiVOoJQaGg3RGdBsz1Lq2BS4vfVdhpOIiMh0NzgySlP3IFnH25IY3pkk3rrft6GmqDAlirJmTYQTERHvqG63di/KPt69MrHQOtqolOoKcpKfFEVZo+6LImIvaleJiBxNVx3cczWEJ1hbFRvellVERMTvdj0Mr/wMFl8Hqz5jOo2IvdRtsY6ZS43G2FDRRnSYi5L0GKM5fC4oGBZeBfv+A31tptP4xMWLM7lkcQa/fWE/W6raTccREZlWqtqsB4g5E5mUCpA827aTUtNjw4gOc7FXU+BERGSK6joGAMY5KfXQBLgW+5RtjqQoJYr9jd14PB7TUUREZAaobusHIDvhOPfKmCwICrVVKRWgMDWK/U0qpYqIvaiUKiJyJEO9cPdVMNgN19wH0ammE4mIiPhX/TZ4+NOQvRIu/BU4HKYTidhL3WYIjnxn619DNlS0sjw3gSBnAPxvdPE14B6GHQ+YTuIzP7hkPmkxYXzxvq30DI6YjiMiMm28/QBxPBPg3i1pNrTsBxsWWhwOByVpMdqaWEREpqzm0PS3zPGUUkMircKNzSelFqVE0TUwQnP3oOkoIiIyA1QfWuh43KniTqe1gKPVXjsdFaVEUd3eR//QqOkoIiJvUylVRORw7lH4903QuBMuvxXS5ptOJCIi4l89TXDPNRCRCFfdCa5Q04lE7KduC2QsBmeQsQhN3QOUN/eyMi/BWAa/SpsPaQth612mk/hMTFgwv75qMdVtfXz3kZ2a+iMiMk5jZZtxTYB7t6TZMNQN3Q0+SDV1xWnR7GvQFDgREZmamnZr8ca475NJhbadJD6mKDUaQFPhRETEK6rb+0iODiUseBzf9SYW2G5SakFyFB4PVLT0mo4iIvI2lVJFRA73/Pdg35Nw/k9h9rmm04iIiPjXyCDcey30t8GauyEqxXQiEfsZHYaGHZCxxGiMNyusbexX5icazeFXS66zJjk37DSdxGdW5CXwhbNm89CWWu57q9p0HBGRaaG2o5/IkCDiIoIn9saxiect+7wfyguK06LpHhyhtqPfdBQREZnGatr7CHI6SIsJG98bEguhtdyWk8THFKVEAbC/URPFRURk6qrb+ske7+KN+Dxor7QGXdlEfnIkAOUtWqwhIvahUqqIyLttvBXW/h6W3wQrP2U6jYiIiH95PPDEl6DmTbjkT5C+yHQiEXtq2gMjA8ZLqRvK24gMCWJ+RozRHH41/3JwBsPWu00n8anPnVnIKUVJfPexXeyq6zQdR0TE9mra+8mMD8fhcEzsjUnF1rHFnlsUl6RbU+D2NahwIyIik1fb3k
I+XHEAq01aU0XDGXZL4pOJTIX8vapTNEi3NiFsP1bc/NpQKo/DN1eDXzhc/g14B/ztL7fZ7Hy+5hC924XSuXVw487l4Q0TP4XEYXpjasZvTY5tdLcPT6ZTqyAe+2E7BWXNb9hICHeWe2JLYkO1v9ULPbE1sQEHbk4mNSYITYPd2SWqo6ix7Rv4coL+PPTm5XDWLWA+9aKgVfsLOFJURdqQCXDzCuh9Lax+Cz6/SL8eNyORgd7cOTKZxel5LN5jUR1HCNFIlpJqwgO8MZsM9iwpNB6KDqpOcVKJEf5k5rtpO/SxjVCepy8gVKSq1sq3G48wpkusY77P8PCCPtfDvgWQn3Hmx/uD0Z1j6NI6mNcWZFBdZ3XosYX4IxlKFS1XTQWsegte7wq/Pg7RneD632Dyl/qXyUIIIYQwhry98NUV+s36pM/1h4BCCCFES6FpMOopGPEEbP8GvpkCtVWqUwnxtzzMJm4Yksiv9wyhb0IY/5qziwv/s5Idx5ppQ6NwuJySKvcZSo1oD2W5bjGI0aNtCJW1VvbmuumDxpOx2/UFjCXHYOIn4B9x2rcszcjjUEEFVw+Ib9o5Pbzg0s/175O/uRqObWracQzOy8PEy5d2o6Syjidm7cBul8UFQhhFfVOqIa+VYYn6a2Gm2hwNFODtQbswv5Y5lLr+A/j+BmjbH679+b//7v7GjHWHCfHz5NxOMfoikLGvwiUfwpF18PEYKG5ejWvXDEggIcKff83dRU1dM1vUI0QzZymtMubijbAEfSjVgJ+tEyP9OZDvpk2p6fNAM0PKKGURft6RTUlVHZf1jXPcQXtdA2YvWPue444JaJrGfee059jxSr5ef8Shxxbij2QoVbQ8tZWw+m14ozv8+hhEdYRr58PVsyCur+p0QgghhPijMou+Wt/sCVfMBN9Q1YmEEEII19M0GHI/nPcSpM+F6ZdCdTMaKBLNVlyYHx9f04c3L+tBdnEV495awdNzdlFeXac6mjC43JIqooMM+ADxZCLa668Obi5xhu5xIQBsOXJcaQ6HWvse7JkDZ/+rwd/tfr76EJGB3ozuFNP08/oEwRXfgl+EvoiyrHk2qHWICeKes9szb3sOP23LVh1HCHGCpdTATanBbcDk6TZDqQBpsUEtbyh1zzyY9wC0Hw1Xfge+Iad9S35ZNb/uyuGSnm3w8TT/9y90maAf4/gRfTC1NMd5uV3My8PEE2PTyMwr57PVB1XHEUI0gqW02pjXydB4qKs05M/KhAh/8stqKK50w12a0n+GdgOUPkOcse4I8eF+9E8Md9xBA6Kg8wTYMt3hC2GHto+kb3wYby7aR1WttKUK55ChVNFy1FbCmnfg9W7wyyN6G+o182DKbGjXX3U6IYQQQvxZTQXMmKw/3Lv8a/1mXQghhGjJ+t4AF70HB1fApxdAWZ7qREKclqZpXNCtFQvvHcqkPm35YMUBznl1mWyBKU7JbrdjKakmOtiA7W8nU7/jUn662hwN0DbMj1A/T7Y2l6HUwkxY8CSknKNvN9wAWccrWZJuYVLvOLw8zvDxSGC0vutWZZHemFpXc2bHM6gbhyTSo20IT/y4A0uJtLULYQS5JVX4e5nx9/ZQHeWvTGb9Ozw3G0o9VFhBWUtZOHVsI3x7HcR2hwkfNXhXqu82HqXWaj95A1ziUL38pzwfvrjELRrkG2p4ahRD20fy+sIM8suqVccRQjRQXmk1UYEGvKcMTdBfiw4qjXEyCRH+AO7XllqYCXm7IfU8ZRH255Wx7kAhk/q0RdM0xx68301QWw5bvnToYTVN456z25NXWi1tqcJpZChVNH9VJbDydX0Ydf7DenvBNXNhyk8QP1B1OiGEEEKcjM2qbx91bBNM+BBa91KdSAghhDCGbpNh0hdg2Q0fjHSLZj4hAIL9PHn24i7MvLk/vl5mrv1kPbdN3/T79rNC1CuqqKXGaiPaiA8QTyaknb6dXp7xh1I1TaNbXEjzaEq12WD2nfquGmNf01vFG+
CbDUew2WFSHwdtqRjbFca/BYdXwy+POuaYBmM2abw8sRvVdVYe/n47dgNuNSpES2MprSY6yMDXybAEKDqgOkWDpcUGYbdDek4LaEstL9AbvgOi9BIAL/8Gvc1utzNj3WH6xIeSHBV48l/UphdM/kL/TDRjMtQ2j8/5mqbxxNg0KmusvPzrXtVxhBANYLXZyS+rJsqIu2+E1Q+lGu86mRgZAMCBfDfbnSl9vv6aOlpZhK/WHcbDpDGhVxvHH7xVd2jbX98lxObYRtOzEsPo3S6Ud5fup6bO5tBjCwEylCqas7I8WPhPeLUz/PZ/EJUGU+bANXMgfpDqdEIIIYT4O78+oW/BOPo56HC+6jRCCCGEsXQ4T7+3rSmHD8+Gw2tUJxKiwfrEhzH3zkHce3Z7ftuZy6iXlzJ97WFsNhlyErqcYn2AIcZdmlLNHhCeDPnuMaTQPS6EvZZS92+D2/QJHFwO5zwNwa0b9Barzc43648wOCWCuDA/x2XpMgH63w7r34fdPznuuAaSGBnAQ6M7sGiPhZkbjqqOI0SLZympItKIWxLXC0uEwgPgJkPsHVsFAbAru1RxEiez22H2HXqb6aQv9MHUBlqdWcDBggou69v2739h0gi4+D19scbPD55hYONIjgrk6v7xfLX+MDuOFauOI4Q4jYLyamx2iDLitTI4DjSTIZtS24b5YTZpZOa5WVNqxi8Qkap//lCgps7Gd5uOMSot2nmfz/rdDMcPwd5fHHpYTdO4fUQy2cVVfL9J7vOE48lQqmh+ig7B3Pvhtc6w/BVIGgY3LNa3jUgYrDqdEEIIIU5n7Xuw5j9w1q1w1s2q0wghhBDG1KY3TP0NfMPg03Gw8wfViYRoMG8PM3eOTOHnuweTFhvEoz9sZ9K01eyzNPNBANEguSfac6ON2GpzKhHt3aIpFaBbXAh2O2w7elx1lKYrz4ffnoSEodDz6ga/bdnePLKKq04/UNMUI/+hb4M8+w4oPub44xvAlP7xnJUYxj/n7OJoUYXqOEK0aMZvSk2EmjIoz1OdpEFaBfsQ5OPB7uxm3pS66VNInwuj/qE3fTfCV+uOEOTjwXldYk//iztfAoPu1c+38ZOmZTWgu0alEOrnxT9/2iWt4UIYnKWkGsCYCzg8vCCojb54w2C8PEy0C/Njn8WNmlJryuHQKkg5W1mEJekWCstruLSPE1pS63UYq/++WfuOww89tH0kXVoH8/aS/dRZpS3VqWw2qCiEgv2QtQWyNuuvhZn6DuDN8POFh+oAQjhM7i5Y+Rps/1ZfXdJtMgy8CyJSVCcTQgghREPtmQs/P6TfYJ3ztOo0QgghhLGFJcLUBfrWiDOv0b/AGnRvg7cwFkK1pMgAvrrxLGZuOMoz83Yz5vXl3Dw0iduGJ+PjaVYdTyiSW1w/lGrgYZs/i0yF3bP1bWo9jZ27e5sQALYeKWZAUoTaME216Gl92Om8Fxt1zZux7jDh/l6MSot2fCYPL5jwEbw7GL6/EabMBlPz+jlmMmm8OKEbo19bxoPfbuOL6/thMslnDiFczW63k1tSZcz2t3qhJ7YmLjzQqDZOVTRNIy02qHkPpRYegPmPQOIwOOu2Rr31eEUN83fkcHm/tg3/jD7iccjeAvMegJiu0LpnoyMbTbCvJ/efk8qjP2xn7vZsxnZtpTqSEOIU8krrh1INem8WFg9FxhtKBUiKCnCvodSDK8BaA8mjlEX4ftMxIgK8GJIS6byTmD2g71RY8CTk7oToTg47dH1b6k2fb+SnbVlc1MOJw7UtibUOjm2Eo+vg6AZ9d53CA1BXeer3mL0gMEZv/o1MhcgO+meoyDQwuWfnqAylCvdms8G+BbDmbchcDJ7+cNYt0P82CJKbASGEEMKtHN0I303VP2Bf/H6ze4AnhBBCOIVfmL4zyKzbYeE/9S8mx70FXg7cllgIJ9I0jUv7xDEiLYpn5u7mzUX7qKix8sTYjqqjCUVyT7TaRBn1AeLJRLQHuw0K9kFMZ9
Vp/laovxftwv3YcqRIdZSmydmuN6/1vUl/SNNAlpIqFu6xMHVQAl4eTnqYE56kD8rOuhVWvAJDHnDOeRSKC/PjibEdefj77Xyx9hBX949XHUmIFqe0uo6qWpuxF2/Ub59bmAlt+6nN0kBpsUF8vf4IVpsdc3MbuLfb9RIAzQTj3270UMNPW7OosdqY0KsRQyomM1zyob5Y47upcPNy8PJvZHDjmdQnjs/XHOLZeXsY2SEaXy/5/loII7Kc2H3DsAs4QhMgfZ7qFCeVHBXA4j0Waq02PM1uMAS3bwF4+kHb/kpOf7yihkV7LFx5Vjs8nP3Pq+cUWPI8rH0Xxr3p0EOfnRZNanQg/1m8n/HdWsviw6ay1kLGb7Dze/216rj+50PaQnRnSBoBwW3ANxS8AwFN/y6pulTfYaAiX995JT8dDiwDq/79GD7BENcP2g2A5LP1oWQ3KaWQoVThnmrKYct0/QduwT4IjIWR/we9rtUfyAkhhBDCvRTsh+kT9faEy76SQRohhBCiMTx94ZIP9C+kFv5Tv0+ePF3/kksINxER4M2rk7ozsVcbUqIDVccRCuWUVBHu7+W8wUFnqB+OzN9r+KFUgO5xIazJLFAdo/Hsdr3lzScEhj3UqLfO3HgUq83OpD5xzslWr/vlsH8hLH5Wf+DUupdzz6fApD5xzN+Zw7/n7WZgcgRJkQGqIwnRolhKTgzaBBl00Ab0B++aSR9KdRMdWwVRWWvlUEE5ic3t51r6z5Dxi74rVXDrRr/9203H6BATSKdWQY17o18YXPQufHoB/PIYXPBao89tNGaTxj8u6MjkaWuYtiyTu0bJbp1CGJGlpL4p1aDXytB4fQCtuvTEYJpxpEQFUGezc6igguQoN7ge7lsA8YOV7VgyZ1s2NVYbF/ds/PW10fzCoNsk2PoVjHwS/MMddmiTSeO2EcncOWMzv+zMYUyXWIcdu0Uos+iza5u/gLJc8A2D1POg/bn6IGlTdg6wWfVm1aPr4fBqOLwGMn7V23KD2kDK2frxE4Ya+pm6G32zJwRw/Aj89n/wShrMux+8g/SVdndvh8H3yUCqEEII4Y7K8uCLS/QHjFd85xbbegkhhBCGo2kw+F59cUdBJkwbpn9ZJYSbGZAcYdwHR8IlLCVVRBm5/e1kwpMBTR9KdQPd40LILakmp7hKdZTG2fsLHFwOwx/Vm0UayGaz89X6w5yVGOb8QSNNg7GvQkA0/Hgb1FU793wKaJrG85d0xdfTzO3TN1NVa1UdSYgWxeIOjeIeXvoCOYNuTXwyHWP1gcvd2aWKkzhYTYXekhqZBv1ubvTb91nK2HrkOJf0bIPWlEauhMEw4A7Y+LE+HNsMnJUYzvldYnln6T6OHf+bLXiFEMpYSqsJ9vXEx9OgbcZhCfpr0UGlMU6mfhB1n6VMcZIGKNivL4BJHqUswg+bj9E+OqDxCzeaqt/NUFcFmz5x+KHP7xJLQoQ/by7ah91ud/jxm6WSbJh7H7zWBZa/Aq166kUR9++Fi96BThc2/Zm3yQwRydD9Mhj3Bty+Du7do7fktuoO22fCjMnwYhLMvBZ2zYZa430ukaFUYXx2OxxcAd9Mgde7wao3IXE4XPcr3LAIukwAs6fqlEIIIYRoippymH4plObA5d/oH7CFEEII0XSpo+GGhXrTwidjYdNnqhMJIUSj5JRUEWPk9reT8fTVW+Hy0lUnaZBucSEAbDlSpDZIY9hssPhpfavLXtc06q2r9hdwpLCSy/q2dU62P/MJ1tvg8nbDspdcc04Xiw7y4ZVJ3dmdXcI/5+xSHUeIFsVSemIo1ejXyrBEt2pKTY4KwGzS2J1dojqKY616E4oPw/kvNelZ6vebjmI2aYzv0arpGUY8DtFd4Ke7oNKNPnv8jYfHdMBuh+d+3qM6ihDiJPJKq4ky8mLX0Hj91YBDqfW7IOyzuMEijf2L9NfkkUpOfzC/nI2HirioRxMXbjRFVBokDoN1H+hbxTuQ2aRx67Akdm
WXsDjd4tBjNzs15bDkOXizJ2z8FLpMhNs3wOVfQYfznTe/FhQLPa+GyV/Cg5lw1Y/QbTIcWAbfXAUvJsO318PuOVBrjEXIMpQqjKuyCNa8A//pC5+cD5mLof+tcNdWuPRTaNtPX3kuhBBCCPdkrYNvr4PsLTDhI4jrozqREEII0TxEpuqLOOMHwew7YPadhvkiSgghTie3pJpod2tKBYjs4DZDqR1jg/A0a2w+clx1lIbbPQtytsOwRxr9gGfG+sOE+HlybqcYJ4U7ifbnQtfJsOIVyN7muvO60PDUKG4emsT0tYf5aWuW6jhCtBi5JfrnesNfK91sKNXH00xSpH/zGkotL9CHUjuM1e8NG8lqs/PD5mMMSYk4s2ZeD28Y/6a+VfVv/2j6cQwkLsyPm4Yk8tPWLNYdKFQdRwjxJ5bSKmMv3gg90ZRaaLxGcX9vD1qH+LpHU+q+Bfo/y/AkJaeffeIeaHz3M1i40RT9boHSLNg92+GHvrBHa1qH+Epb6t/Zvxje7g9LnoWUs/UG0/Fvub50ycMbkobrO7Xclw5Xz9ILHfcvgq+v0AdUv5uqfEBVhlKFsdjtcHQj/HgrvNwB5j8M3kEw/m29ivicp/XWASGEEEK4N7sd5t0He+fDeS9Bh/NUJxJCCCGaF99QuOJbGHQPbPoUPjrHkA0MQgjxR7VWGwXlbjqUGtUBCjIc3lbiDD6eZjrGBrHVXYZSbVZY/G998LfLhEa9taCsml935nBxjzau375z9LPgGwazbnOL3xdNcd857enVLpRHvt/Owfxy1XGEaBEspdX4eZkJ8PZQHeXvhSXq5TNu1IyZFhvEruY0lLriFagthxFPNOntq/cXkF1cxcU925x5llY9oP9t+r3pwRVnfjwDuHlYErHBPjz1006sNhncEcJILKXVZzZM72y+Ifr3dgb9ni4pKoB9eQYfSq2r1tshk0cpizB3WzZ94kNpFeLr2hOnnKN/zlrzrsMP7Wk2cdPQRDYfPs5aWXTxv6rL9Hv7zy/UF8peMxcu/Uz/d6Ga2UNv0L3gdbg/A676ATpfBPsWnhhQTfpDg2qlS6PJUKowhqoS2PAxvDcEPhgBO3/Ua4ZvWqZvO9jjCvDyU51SCCGEEI6y7CXY+AkMvg/6XK86jRBCCNE8mT1g1JNw2Vf6F93vDYH0n1WnEkKIU8orrcZud4P2t5OJ6gjWGrdphesWF8K2o8XUWW2qo5ze9pmQvxeGPwqmxg2WfrfpKLVWO5f1jXNSuL/hFwZjX4GcbbDyNdef3wU8zSbevKwHHmaNm7/YSEVNnepIQjR7uSVVxt6SuJ6BW+BOpWNsENnFVRyvqFEd5cwVH4N17+ut3VEdmnSI7zYdJdDHg7M7Rjsm07BHIaQd/HSXPkzk5vy8PHh4TAd2ZpXw7cYjquMIIU6w2+1YSquJNPq1MjQeiox5jUyODGCfpQybkQfuD6+G2gplQ6kZuaWk55ZyfpdY15/cZIK+N8HRdXBso8MPf2nvOCICvHh7yX6HH9ttZW3Rv1ffMl0vgLh5ZZNa6F3C7AFJI2Dcm3D/3hMDqhf/b4Pqt9fDzh/0OT0nk6FUoY7NBplL4fub4KX2MOdufdX7+S/DfXv0Ke7YbqpTCiGEEMLRNn8Ji5/WvxRt4kp9IYQQQjRC6hi4can+hfeMybDgSbDK0IgQwnhyTmxJHBNs8AeIJxOVpr9adqnN0UC92oVSUWNlT06p6ih/z2aD5a9AdGdIG9eot9rtdr5ad4Te7UJJiQ50UsDTSLsAOl4IS190m4HlxmoV4svrk3uQnlvKQ99tl20ehXAyS2k1Ue6weKO+NcqNfvalxQYBNI+21GUvgN0Gwx5u0tvLquuYvyOHsV1bOa5p3MtPX6xRsA/WvO2YYyo2rlsrercL5cVf0impap6t6EK4m5LKOmrqbMZfwBGaYNiFG8lRAVTV2jh23LWNio2ybwGYvZQNBs7dno2mwRgVQ6kA3S8Hr0CntKX6eJq5dm
ACy/bmseNYscOP73Y2fQ4fnq03jE75SS+A8HSDz+KgN7r+ZUD1EshcDDOvgRcS4bPxTm1PlaFU4XpFB2Hxs/BGN/hsHKTP01tRr18At6yEPlPBJ0h1SiGEEEI4w74F8NOd+jYC494ETVOdSAghhGgZwhLgul+h5xRY8aq+1VBprupUQgjxPywnhlLdsik1oj1oJrDsUZ2kQXrHhwGw4aDBt+RLnwv56XobSSPvH9ceKCQzv5zJfds6KVwDjX5Of2A670FopgObQ9tH8sC5qfy0NYv3l7vPAJoQ7sjiNk2p8fqrQQduTqZ+KHV3tsEXbJxO8TG9FKDXFAht16RD/Lw9m8paKxN6tXZstuRRkHqevotWaY5jj62Apmn844JOFJTX8ObCDNVxhBCApVS/pzR8U2pYAhQfMeSi8ZToAAD25ZUpTvI39i2Etv3BO0DJ6eduy6ZvfJi67y58gqDHlXrbZUm2ww9/Vf92BHp78E5Lbku1WeGXx2D27dBuoD7LZtR21Ib4fUD1DbhvL1w7H/rfCl4B4OnrtNPKUKpwjeoy2DIDPhkLr3eDpc9DWBJc8qE+kX3BaxDXRwZThBBCiOYsawt8fbXeIHTp5+DhpTqREEII0bJ4+uhfPF34DhzdAO8N1ncwEUIIg8gpduOhVE9fve3GTZpSW4f4Ehvsw4ZDRaqjnJrdrrekhsbrbaON9NW6wwT6eKjZUvGPgmJh+KOw7zfY/ZPaLE50y9Akzu8Sy3M/72F5Rp7qOEI0S/VbErvFddLLDwJjDbs18clEBnoTEeDNriw3b0pd87bekjrgziYfYtaWLNqG+dGzbagDg51wztNgrYEFTzn+2Ap0aRPMxF5t+HjlQfYbeYBLiBbCUloNQFSgwa+VofFgq4OSo6qT/EVy5Imh1FyD/kwrydLvu5NHKjn93txSMixljO2q+D6z343676ENHzr80EE+nlzZvx3zdmST2RKvbVXFMP1SWP0W9LsZrvgW/MJUp3Icswe06w9n/xMmf+nUU8lQqnCeumrYPUev/X0xGX68GUqOwYjH4e7tcPWP0GWCU6euhRBCCGEQRQfhy4n6h/bLZ0oruhBCCKFS98th6gLwDtK36Fn0jCGbGYQQLU9uaTWeZo0wPzddwBaVBpbdqlM0WO/4MNYfLDTudusHlkLWJhh4l/7QpBGOV9Qwb0cOF3Zvja+Xg7YdPhN9b4ToLjD/Yb3AoRnSNI0XJnSlfXQgt0/fzIH8ctWRhGh2yqrrqKixukdTKkBYIhS6V3ty59ZB7Mxy461qKwph4yf61qxNbEm1lFaxan8+47u3QnNGmVB4Epx1K2ydri+WbAYeOLcDPp5mnpnrPp8DhWiu6ptSo4IMfq0MTdBfDdgoHurvRbi/F/ssBr1vObBMf00cruT0c7ZmYdLg3M4xSs7/u7BESB0DGz6G2iqHH/7agfF4mk28t9S9PsudsdJc+GgMZC6Bsa/BmOcb/X2E+C8ZShWOZbPC/sUw6zZ4MQW+vkK/KPS4Qq//vWMTDHkAQuJUJxVCCCGEq1QUwhcT9BXwV36nt8QIIYQQQq2YznDjEn1AddkL8OkF+jaPQgihUG5xFVGBPphMbrqbUlRHKNzvlAdCztAnPpTckmqOFlWqjnJyK16DgGjodnmj3/r9pmPU1Nm4rG9bx+dqCrMHjH1FL21Y+pzqNE7j7+3Be1f1wqTBtR+vo6CsWnUkIZqV+vY3t2hKBX1rYjcbSu3SOpi9uaVU1lhVR2ma9R9CTRkMurvJh5i7LRubHcZ1a+W4XH825H79Gv/zQ2CzOe88LhIZ6M2dI5NZtMfC4nSL6jhCtGh5vzelGnwoNSxRfy005vboyVEB7DNqQ2bmUvANg+jOLj+13W5nzvZs+iWEG6ONt9/NUJEPO751+KGjAn24tHcbvt98lOxig35n4GjHD8PHo/WipSu+hd7Xqk7k9mQoVZw5mw0Or9FvHF7uAJ9fCDtnQYfz4Irv4L50OP9lvf7XGS
vqhBBCCGFctZUwfZL+Qf6yryAyVXUiIYQQQtTzDoAL34aLpkH2Vnh3IKT/rDqVEKIFyy2tItrojTZ/JypN3y43f6/qJA3Sq52+Je+GQ4WKk5yEZQ9kLtYbRj0b97DPbrfz1frDdGsTTMdWBtqlI64v9LwaVr8NuTtVp3GaduH+fDClD9nFVUz9bANVtW462CWEAeWWnGh/M/qgTb3QBCjLhRr3aU7u0joYmx12ZZeojtJ4tZWw9l1IOReiOzX5MLO2ZJEWG0RKdKADw/2JdyCMehKObYBtXzvvPC50zYAEEiL8+decXdRa3X/QVgh3ZSmpxtfTTIC3wZsNA2PB0w8KjLl4IzkqgIzcUuPtqmG36ztqJAwGk+vH3fbklJKZV875XQ1SvpMwRF8cu+Zd/Z+Ng900JAmbHT5YbrxGX4fLz4CPRkNFgb7rd5KaJt7mRoZSRdNYa2H/IphzD7zSAT46V6+FbtsPLv0MHsiAi96FlFFg9lSdVgghhBAq2Kzw3VQ4uh4unqYvUBFCCCGE8XSbBDctg+A4mDEZfn4Y6qTZTAjhejnFVe7T/nYyUR3117w9anM0UIeYIAK8PdhwsEh1lL9a9x6YvaFX45tJNh0+zt7cMiYbpSX1j0Y9BT7BMOfeZtEMdyq92oXy+uTubDlynLu+2ozVZrCH2UK4qd/b39zlWvl7C5z7DDJ0aRMMwI5jxYqTNMH2mXpb2sA7m3yIwwUVbDly3LktqfW6ToZWPWDR027TMv93vDxMPH5+Gpl55Xy2+pDqOEK0WJbSaqKCvNGMXpZmMunXyYJ9qpOcVHJUACVVdeQZbeeDwkx994mEoUpOP3dbNiYNRneOUXL+v9A0vS01dzscWunww8eF+XFB11hmrDtMUXmNw49vGJbd+kCqtQaumasvKBUOIUOpouFqK2HPXPjhZngxGT6/CLZ+DW37wyUfwgP7YNIX0HE8ePqqTiuEEEIIlex2vUV9zxwY/Sx0ulB1IiGEEEL8nYhkmLoA+t0Ca9+BD8+GAmNuISaEaL4sJdXuPZQangQmT7DsUp2kQcwmjZ7tQo03lFp5HLZ+BV0mgn94o98+Y91h/L3MXOCKgZrG8guDs/8JR9bAtq9Up3Gq0Z1jeeL8jvyyM5d//rTTeC1LQrghS0n9UKqbNKX+PpRqzBa4k4kJ8iEiwIttR91sKNVuh3XT9K2M2w1s8mFmbz0GwAXdXNAAZzLpbaklR2H9+84/nwuM6BDFkPaRvLZgLwVGG+QSooWwlFYRGeAm18nwJCg05ndvyVEBAOyzlClO8ieZS/RXRUOp87Zn0z8pnAgj/R7rein4hsGad5xy+FuGJVNRY+XT1Qedcnzl8vfBp+PA5AHXzoeYLqoTNSsylCr+XmkubPocvr4SXkiCry7Xt/JLPQ8mz4AH98Oln0KXCeBjoK2QhBBCCKHWytf1LxP73w5n3aI6jRBCCCEawsMbxjyn3+8fPwzvDYFt36hOJYRoIcqr6yitrnPvoVSzJ0Sk6C0bbqJ3u1DSc0sprqhVHeW/Nn8BtRXQ78ZGv7WkqpY527IY172Vcbfs7H4FtO4NC56E6lLVaZzqukEJ3DA4gU9XH+LZn/fIYKoQZyi3pAofTxOBRv359mdhCfqrQQduTkbTNLq0Dna/ptTDayBnO/S9QW9NawK73c6sLVn0bhdKm1A/Bwc8hcRhkDQClr8MVW72z/wkNE3j/8amUVlj5aVf96qOI0SLVN+U6hbCkqDoIFjrVCf5i5SoQAD2G20o9cBSCGqtD/S62P68MjLzyzm3k0FaUut5+kKvayB9nlMWAqXGBDIqLYpPVh2kvNp4v1fPSNFB+Gwc2G0wZbZe2iAcSoZSxf+y2eDYRlj8LEwbBi+3h9m3w9GN+nZ+V/2gN6Je9A50OE8aUYUQQgjxV1tmwIJ/QOdL4Ox/qU4jhBBCiMbqcB7cvBJiusL3N8CPt0JNuepUQohmLrdE37Y1JthNHiCeSlSa2zSlAvSODw
Vgw6FCxUlOsFn1BY5tB0Bst0a/fdaWLKpqbUzu09YJ4RzEZIIxz0NZLix7SXUap3v0vDSuOqsd05Zl8tKv6TKY2gT5ZdUs3ZvHRysO8PiP27n8/TW8/Gu66lhCAUup3ihu+C2J6/kEg3+UYbcmPpUurYPJsJRSUeNGgxfrpun/vLtMbPIh9uSUkmEpY3x3FzeNj3oSKov0koNmIDkqkKv7x/PV+sNsd7fGXSGagbySaqIC3WShY3gS2Org+CHVSf4iOsibAG8PMow0lGqzwYHlekuqgs9CC3fnAnortuH0u0lv+lz9H6cc/pZhyRyvqGXGusNOOb4Sxcf0htSacrj6R4hMVZ2oWXKTpXTCqSoK9RUFGb/pf5RbAA3a9IERT0D7c/XtHtzlJlcIIYQQ6uz9FWbdBglD4MJ39IdtQgghhHA/wa1hyk+w7AVY+gIcXQ8TPoaYzqqTCSGaqZwTQ6nR7vIA8VSi0mDHd1BdBt4BqtOcVs+2oXiZTazJLGBkWrTqOJDxq95WMurJRr/VbrczY+1h0mKD6Nom2OHRHKpNb+h2Oax5G3peraTpx1U0TeOpcZ2os9n4z+L9eJpN3D2qvepYhlVdZ2XL4eNsPnKcbUePs/VIMceOV/7+14N8PEiMDMDfXZoyhUPlllQRFehmizfCk6HAfZpSAbq0CcFmh93ZJfRqF6Y6zumVZMPu2dDvZvDyb/JhZm/NwmzSOK9LrAPDNUBsN73cYPXb0PdGCDRYA10T3DUqhdlbs3jkh238eOtAPMzyHbkQrlBZY6W0uo5Id7lWhp9oZSzMNNz9gKZpJEUFsM9IQ6m5O6CyUH/+qMCC3RY6xAS6rk28MQJjoOskfdeRYY+Af4RDD9+rXSj9EsL4YPkBrurfDm8Ps0OP73IVhfD5RfrrlFkQ00V1omZL7lpbotoqOLIWMhdD5hLI2gLY9RV0yaMg5Vz91T9ccVAhhBBCuJUj62HmFIjuBJO+1LcAFkIIIYT7MnvA8EchfhB8dwO8PwLGPAe9rpWFq0IIh7OUVAMQHezmQ6mRafprXjq06aU2SwP4eJrp3jaE1ZkFqqPo1r6rb8fYYWyj37r9WDG7skv41/hO7tEiOOof+hDRr0/AZdNVp3Eqk0njmQu7UGu189qCDCprrDw0ugMmkxv8e3Iym83OruwSVu7LZ8W+fNYfLKSq1gZAXJgv3duGcM2AeDq3DiYlOoBwfy/3+P0tnCKvtJq0VkGqYzRORDLsmac6RaN0aa0vbNh+tNg9hlI3faY3jfe5vsmHsNvtzN6SxaDkCMIDFHynO/wx2DULlj4PY191/fkdLNjXkyfHdeT26Zv5ZNVBpg5OVB1JiBbBUqovdHSbBRxhJwZRC/ZDytlqs5xESlQAy/bmqY7xXweW6q+JQ11+6uMVNWw8VMQtQ401PPw/BtwJmz+Hde/D8Eccfvhbhycz5aN1/Lj5GJOMvDPJ6dRWwozLoOgAXPk9tDb+9zbuTIZSWwKbVV81kLlE/+PQaqir1Oub2/SBYQ9D4nD9Pzaz/JYQQgghRBPkpcP0iRAQDVd+Bz5u9gW5EEIIIU4tYQjcshJ+uAnm3AMHlsEFr+uLW4UQwkF+b0oNcvOh1KgTQ6mWXW4xlArQPzGcNxZlUFxRS7Cfp7ogeen699cj/w/Mjc8xY90RfDxNjO/R2vHZnCEwBgbfBwufgv2LIGmE6kROZTJpPH9JV3w8Tby3LJO80mqen9AVzxbYHldVa2XV/nx+3ZnLgt0W8sv0ofyUqAAm92nLgKRweseHEebvpTipMJrckiqGpkaqjtE44SlQka9vz+4bqjpNg0QHeRMR4M22Y26w9brNpreiJQ6DsKYPPm46XMSx45Xce7aiJuvwJOh1DWz4GAbccUZ/L0ZxfpdYvks9yiu/7WV05xhjNusJ0czkleqfqaLc5Z7SPwK8g6Bgn+okJ5USFcC3G49yvKKGED8DfC49sEz/XBHUyuWnXpKeh9VmZ2RalMvP3W
CR7SH1PFg3DQbeeUbt6SczJCWCTq2CeHdpJhN6xWF2xwWGNit8N1UvcZzwESQMVp2o2ZMJxOaorgayt8ChlfoA6uE1UH3ixikiFXpN0YdQ4weCd6DSqEIIIYRoBoqPwecX6wtervoeAgx8UyaEEEKIpvGPgMtnwqo3YOE/IWszTPgYWvdUnUwI0UzkllQR4O1BgLtvSR0aDx6+YNmtOkmD9U8K5/WFGaw9UMA5nRRumbvxUzB5Qs8pjX5reXUds7cc4/wurQjyUThY21j9b9Mb7uY/AjevaNIwrjsxmzT+Nb4z0YE+vPzbXvLLa3jnip4tYiv6qlorC3dbmLMti6V786iosRLg7cHQ1EhGdohiYHKE+w/lC6cqq66jvMbqfr9P6rcmLtgPbXqrzdJAmqbRtU0wO9xhKPXAEig+DGc/eUaHmb0lC28PE+d0inZIrCYZfL8+YLvsJbjwbXU5HETTNP51YWfOfmUZ/zdrJx9O6S1N10I4maV+KNVdmlI1TR/KL9yvOslJtY/RZ4nSc0rpl6h4l2VrLRxapW9Rr8CC3blEBHjTrU2IkvM32MC7IH0ebP4S+t3o0ENrmsYtw5K4ffpm5u/I4fyusQ49vtPZ7fDzQ7BnDox+DjpfrDpRi9D87/RbgpoKOLZB/yF8aKW+dW5dpf7XItpDpwuh3QC92UTBqgEhhBBCNGOVRfDFJVBVDNfObRar2IUQQghxCiYTDLob2vaHb6+DD8+Bc56GfjfpX6QLIcQZyC2pIirITR4e/h2TGSJTIc99hlJ7tA3B28PE6kyFQ6l11bB1BnQ4X18I0Uizt2ZRXmPl8n5xTgjnRB7ecO4z8NXlsOEj/ZrazGmaxh0jU4gM9ObRH7Yz8d3VvHdVL+LCml+DnNVmZ/X+An7ccoxfduRQWl1HZKA3F/Vozdkdo+mfFI63h1l1TOEmLCVutiVxvYgU/TU/w22GUgE6tw5mSbqFipo6/LwM/Dh902d6A22HsU0+RJ3Vxtzt2YxMiyJQ5cKOoFjofR2sfU9vEg838BbJDdQm1I/7zmnP03N3M2+7Gw7wCOFm3PJaGZYER9erTnFSHU4Mpe7NNcBQ6rGNUFMGiUNdfupaq42le/MY0zkGk9HbQdueBXH9YPWb+jXVwTtlj+kcS0LEXt5eso/zusS412KLte/C+veh/+1w1i2q07QYBv4ULU7KbtdXEx7bAEc36K85O8BWC2gQ00VvQm03QH9IJE1lQgghhHCW2kqYPllfxXnFtxDbTXUiIYQQQrhC235w83KYdRvMf0jfPmv8W+AXpjqZEMKN5ZZUE+Nu7W+nEtURMherTtFg3h5meseHsnp/gboQe+ZCZSH0vKpJb/9y7SFSowPp2dY9tob+H6nn6TubLX4GOk8Af8UPfF1kct+2RAf7cNeMzVzw1gremNyDIe3dbFvyk7Db7ezMKuHHzceYvTULS2k1Ad4ejO4cw4XdW9M/Kdw9t7oUytW3v7ldU2poPGhmw25NfCpdWwdjs8OurBJ6xxv0PqeiUL9+9r5OX+TQRKv2F5BfVsO4bq0dGK6JBt4NGz6GZS/CRe+qTuMQ1wyI58ctx/jH7B30TwonzN8AW2AL0UxZSqvxMGmEGmGr+YYKT4Kd3+uL9M7gZ7kzxAT5EOjjQXpuqeoo+nePaBDv+u3W1x8opLSqjpFpCtvEG2PAnfD1FbDrR+gywaGHNps0bhqSyMPfb2d5Rr773L/tWwC/PKov4jn7X6rTtCgm1QHEaZQXwN5fYNEz+ra4z8fDW73gh5v0leNeATDgdn0LvYcO6g+FxjwPHcfLQKoQQgghnMdapzekHVkLF09TsjpRCCGEEAr5hcHk6XDus5DxK7w3BI6sU51KCOHGcoqr3G/Q5lSiOkBptj4s4ib6J4azJ6eUwvIaNQE2fw7BcfpwZiNtO3qcHcdKuOKstu7V1FJP02D0s1Bdpg+mtiDDU6OYffsgYoJ8mPLxOp6dt5vqOqvqWE1ypLCCtx
ZlcParyxj75go+XX2Qrm1CeOvyHmx4fBQvTezGoJQIGUgVTeZ2WxLXM3vqg6kFGaqTNEqXNsEAbD9WrDjJ39j2NVhroEfTFnTUm7Uli0BvD4alGmCwJDAa+lyv/73lu9fvmVPxMJt4aWI3iitreWLWDtVxhGjWLKXVRAR4G7/N8o/Ck8Fug6KDqpP8haZpdIgJJD3HAEOpB5dDdGclC+IX7Lbg5WFicErjd/RQIvU8iEiF5S+Dzebww1/UszXRQd68vcRNFhzlZ8DM6/TFwxe9p+8EJlxGmlKNwm6H4iOQs/2/f2Rvg+LD+l/XTPp/JB3H69tbtO6tbwNlkq1dhBBCCOFidjvMuRvS58F5L0Gni1QnEkIIIYQKmgb9b9WbU2deCx+NhpH/p6/Ily/4hBCNYLfbsZQ2p6HUjvqrZTfED1SbpYH6J+ntnGszCxjTxcVbyxYdgv2LYdjDTfq+e/raw/h6mrmwhwEa3poqKg36TNW3E+x9HcR0Vp3IZeIj/Pn+1gE8PXc37y3LZFlGPi9O6Ern1sGqo51WUXkNc7ZnM2vzMTYcKgKgT3woT1/YmfO7xBIqbXjCgf67JbEbXisjUiDfTQYXTogO8iEy0JvtRw08lLr5S2jV44yuGVW1Vn7ZmcOYzjH4eBrkmfPAu2D9h7D0BbjkfdVpHKJDTBB3j2rPi7+kM6ZzFmO7tlIdSYhmyVJaTVSQmy3eCEvSXwv26/M/BtM+OpCftmZht9vVLQCsq4Ej6/Udo13MbrezcE8uA5LC8fNyk/E6kwmGPADfT4U9P+kzZg7k7WHmhsGJPD13N5sOFxl7t5LKIpgxGcweermCd4DqRC2Om/xX08zUVukrAnN2nBhA3aa/Vh0/8Qs0/QYtro++GqxNb4jtLv+BCCGEEMIYFj2tt9gMeRD63qA6jRBCCCFUa91L37ll9h2w4B9wcIW+1aK/mzQICCGUKyyvodZqJ9rdHiCeSnQn/TV3p9sMpXZtE4Kfl5lV+xUMpW75Un/tfkWj31pSVcusLVmM69aKIB9PBwdzseGPwPaZMP9hmPKTvvijhfDz8uDfF3VhRGoUD3+/jXFvreCaAQncc3YKgQb791peXcdvu3KZteUYyzPyqbPZSYkK4IFzUxnXrRVxYX6qI4pmylJajbeHiSBfN3y0G54MmUv0ti43WrzWtXWwcZtSc3dC7nYY8+IZHWbxHgtl1XWM726ghR0BUfp3zqvf0odqIturTuQQNw1J5NdduTzx4w76JYQT6W6tx0K4AUtJFW1CfVXHaJzwRP21wJiLN1JjAvlybR05JVXEBiv6Z5u9FeoqoW1/l596f14ZhwoqmDo40eXnPiOdL4alz+kLPDpc4PDPX5f1bctbi/fx9uL9fDClt0OP7TDWOr1EoegQTJkNoe1UJ2qR3PDOxY3UlEP+XshLh7w9/30tOqhXcAN4+OpfUna6CGK6QExXiO4IXv5KowshhBBCnNTa92D5S9BzCgx/VHUaIYQQQhiFTzBM/BQ2fAjzH4V3B8GEj6Gd678wFkK4n9wSfUvimObSlBoYC75h+rCIm/A0mzgrMZzlGXmuPbHNCpu/gOSREBLX6LfP2nyMylorl/dr64RwLuYbCiMeg7n3wa5Z0OlC1YlcblTHaBbGD+OFX/bw8aoDzNpyjNuGJ3PFWW3x9lDX4FddZ2VJeh6zt2axcHcuVbU2YoN9uG5QAuO7t6JjbJC65ijRYuSWVBEV5O2ev9fCk6GuCkqOQoj7/Lzu3DqYRekWyqvr8Pc22CP1bV+DyUMfOjkDs7dmERHg/XtjumH83pb6PEz4UHUah/Awm3h5YlfOe2MFj/6wnWlX9XLP/56FMLC80mp6GLm18WR8Q8EvHAr3q05yUqnRgQCk55SqG0o9tFJ/bTfA5adesNsCwMgOUS4/9xkxmfWFHT/cpO96mTbWoYf39/ZgSv94Xl+YQXpOKakxgQ49vkP8+jhkLoZxbyr5vSN0BvsE7YbqauD4YS
jM1P8oOqBXa+elQ/Hh//46k6d+0xXTFbpcqq/qiu6s13Gb5V+DEEIIIdzAju/g54egw1g4/5UW1doihBBCiAbQNH3r4TZ9YeY18OlYOPtfcNYt8rlBCPG3cuu3JG4uQ6mapm+lm7NDdZJGGdo+kkV7LBzMLyc+wkWlCfsXQckxOPffjX6r3W7ny7WH6dQqiK5tjL/Ve4P0vAY2fAy/PgHtzwVPN2t6coBgP0+euagLl/aO4/n5e/jnnF28vzyTKQPiuaxPW4L9XNOceryihsXpFhbssrB0bx5l1XWE+3sxsVcc47q3olfbUEwm+XwjXMdSUk10oJteJyNS9Nf8DLcaSu0WF4zdDtuPFXNWooGGNm022DYTkked0e4UJVW1LNxj4fK+bTEb7eeZf4TelrrydX2oJqqD6kQOkRwVyP3ntOff8/bwzYYjTOrjPv89CGF0tVYbhRU1RLljC3FYkj5jZED1w4bpOaUMS1U0mHl4NYSn6E3aLrZwdy4dY4NoFeKG92WdJ+iLO5Y+Dx3Od/h3s9cMiOf95Zm8u3Q/r07q7tBjn7GNn8Lad6DfLdDzatVpWjSZhjwdmw3K86D4qL6Cr+iQPnhaP4RafPS/racAXoEQlgBt+0Hk1RDZQf8jNB7MxtpmRgghhBCiwfYvhu9v0rfHuOQDWVQjhBBCiFOL7Qo3LoYfboFfHoGj6/VV6d4BqpMJIQyqfig1JthNh21OJqYrrP9A3zLOTe6fhraPBGBZRp7rhlI3fao3A6We1+i3bj5ynD05pfz7oi7Np2nM7AGjn9MXdqx8A4Y9pDqRMt3iQph+w1msyMjnP4v38dzPe3h9QQZjOscwvkdrBiaF42F23DaUtVYb244WsyazgGV789hwqAirzU5koDcXdItldOdYh59TiMbILa2igxFbqBoi/MRQasF+vRnbTXSP09v2Nh8+bqyh1EMroDQLzn36jA7z685caupsjOveykHBHGzAnfpnqaXPw8SPVadxmOsHJbJ0bx7/mL2Tnm1DSYl20/+uhTCYgrIa7HaICnLDodTwJMhcqjrFSYX4eREd5E16bqmaADYrHFoNnca7/NRF5TVsPFTE7cOTXX5uhzB7wOD7YdatsPcXSB3t0MOH+ntxWd+2fLLqIPee3Z64MD+HHr/JDq3Sdx9JGgHnnNlnJXHm3OPbMGexWaE8H8pyoMyir8guPvq/f5QcA2vN/77PNxTCEiGuH3S7TP/foQn6q3+EtH8IIYQQonnJ2gJfXwkR7eGyGS2yqUUIIYQQjeQTDJO+gJWvwaJ/gWWX/v/rW5KEEOIPck4MpUYGuOEDxFOJ7qxvVVy4HyJTVadpkPgIf9qF+7E0PY+r+8c7/4QVhZA+X29C8/Bq9Nu/XHMYfy+zcYdpmiphMHS8EFa8Ct0vh5A41YmUGpQSwaCUCHZmFfP56kPM3Z7N95uPEejjwcCkCAYmh9OlTQgdYgLx8TQ36Jg1dTYOF5azM6uEXVkl7MgqZtOh41TWWgHoEBPIrcOSGJUWTZfWwdKIKgwhr6SaISmRqmM0TUCUXupTkKE6SaOE+XsRH+7H5sNFqqP8r21f6/882485o8PM3ppFXJgvPeJCHJPL0fzDoe+N+vVw6EPNpi3VbNJ49dLujHl9ObdP38ys2wc2+PolhDg1S+mJ3TfcsVU8PAm2zoCaCvAyyGDfH7SPDiQ9R9FQqmUXVBdDu4EuP/XidAs2O4xMi3b5uR2m66Un2lKf03ficPAs29TBCXy2+iDTlmXyrws7O/TYTVJ0SH+eHdoOJnzsNguEm7Pm/2+gPB+2f/vfwdOyXP2P0lyoyP/fllMAzQSBrSC4DbTuCR3HQXCc/v/r//ANVfP3IoQQQgjhavn74ItLwDcMrvwOfENUJxJCCCGEuzCZYPC9+vcr314H04bDhW/r37UIIcQf5JZUExHghZdHM2ogjDnxQCZnu9sMpYLeljpzw1Gq66x4ezh5QGLnD2Crha6TGv3W4opa5m
zL4pJebQjwboaPOc75F+ydD7/9X7NqhzsTnf6fvfsOk6q82zj+ndnee18WWFjK0nvv3QrYsGEv0VhiNImJiW+iJqapzrFwywAAn8dJREFU0VhiV7ChgKIiVXqV3mFh6bA7W4BtbJ15/zhgSShbZufszNyf69rr5JWZ57n3hd0z55zf83uSI3j2qs783xUdWLw7j0W7bCzLymPO9hwArBZIiggiJTKI6BB/QgJ8CfCzUlltp6LaTnlVDbbiCo6fPE1eSQUOhzGun4+FNglhXNszlb7pMfRuGU2MJxXIi0coq6ymuKLaPbu/gVEAEdsa8t2rKBWgW1oUy/fm43A4mkZX7qrTsGOWcU3VgMKl/JIKVuzN557B6U3j+zqffj+HNf+BpX+Hq98yO43TxIcH8s9ru3DrO9/x1Fc7eGZCJ7Mjibg9W1EFAHFhbniujG5lHAuzf7iObELaJYbx3qqD1Ngd+Lh6sdbBlcaxeX/Xzgss3GkjLiyATikRLp/baXz8YMiv4Iv7YecsyHRux9mkiCAmdEth2rrDPDgiw9yfv4oS+Oh6sFfD9Z/oeXYT4YF3a/7L6ZMw59dg9YXQBGM1XngqJHf/4f8OS4SQeAhPhrAkVUuLiIiIgNE1fsp443/fPBPCk0yNIyIiIm4qfSjcsxSmTYZpNxvbMI54UvdfROR7uUXl7tnR5kJi24LVzyhK7XS12WlqbUibON5fdZB1B04woHVs4062+WOIawdJXer81hkbj1BRbeeG3mmNEKwJiEyDAQ8bHW163QktXN8ZqKkK9PNhbMdExnZMxOFwcOTEaaPj6fEijhSWceTkafbllVBaUU1ljZ0AXx/8fa34+1iJCwugbds4kiKCSIsOJjM5nFZxoZ5VEC8e6WyhTYI7nytjWsOhNWanqLPuaZHM3HiUoydPkxrVBLrX7V0AFUUN/mwxe+txauwOruya4qRgjSQkBnrfCSteNLqlxrUxO5HTDG0bzz2D0/nP0mz6t4rl0s669+5MDoeDssoaQjxx8ZKck63YOFfGu2NRasyZ7eEL9jbJotQ2CWFUVts5UFBKq7hQ105+cKXRRDDStdd9ldV2luzJ47LOSe6/a0LnSbDyJVjwR2h7iVGo6kT3DGnFp+uP8M6K/fxqrEldze12mHE35O2EGz8zFkNJk+D5Z+HolvBYttHd1KobCyIiIiK1UpoPUyZA+Sm49St9gBcREZGGiUiF276BOY/Dyhfh2EZjG6VQN92CVEScKreonMQINy60ORdff6PgMneb2UnqpF+rGPx9rCzebWvcotSCfXBkrbFIoY4d2hwOB1NXH6RLs0g6unPXmosZ8BBsnArf/BruWQJWbe373ywWC82ig2kWHczYjolmxxFpNN8X2rhrp1SAmAxjZ8uq0+AXZHaaWuuWZuyeueHQyaZRlLr9cwiOgRaDGzTMrE3HaJsQRtvEMOfkakz9H4S1bxjdUq96w+w0TvXL0W1Zs7+QX322mdbxoe7x99EEVVbb2X7sFBsPnSTLVkxWbglZthJaxYUw4z4t7PEWtuJyAGLdseN9dLpxLNxnbo7zaJcYDsCenGLXFqU6HEZRavpQ1815xtr9hZRUVDOifYLL53Y6H18Y+X/w0STY8J6x6NGJWsWFMq5jIlNWHeTeoa0ID3Ru0WutLHoGdn8NY5+F1iNcP7+cl+dXaVp9jFVUKkgVERERqZ3yIph6FZw8BDd8Uq+uNSIiIiL/wzcALnsOxr8GR76D14fC8c1mpxKRJiC3qJwEdy60OZ/EjpDjXkWpwf6+9G4ZzeLdeY070ZZpgAU6X1vnty7LymdfXim39Gvu/FxNiX8wjH4KcrcaDw9FxGvlFhmFNm7dVTy2NeAwtiZ2I20Twwj0s7Lx0AmzoxgFvbu/gfaXN2jXiaMnT7Pu4Amu6JrsxHCNKCQWet0B2z6D/L1mp3Eqf18rr93Ug+AAX+58/ztOlFaaHckt1NgdrD94gufm7eba11bR6f/mMuGVlf
zpqx3M2ZaD1Wrhss5JXNermdlRxYVsxRVEh/i7Zwf8gFAITYSCpnmObB0fisUCu3KKXTtxwT4otUHz/q6dF1iwM5cAXysDG3v3EFdpMxaaD4DFfzW2uXey+4a2priimqmrDzp97Iva+hks+wd0nwx97nX9/HJBbvgbWUREREQaTdVp+Oh6o5vPte+bcrEnIiIiHq7r9XD7XMABb42B7TPNTiQiJqqqsZNfUklCuBsX2pxPYicoyYGSRi7wdLKhbePIspVwqKCscSZwOGDLJ9BioNFJu47eWbGf2NAA79jmtsMEaD4QFj4Fp5tAQZSImOJsp1S3XsBxdmvi/Cxzc9SRn4+VzimRbDx00uwosHcBVJVC5vgGDfPl5mMAXN7ZTYpSweiW6hNgdEv1MIkRgfzn5h7knqrgvg82UFVjNztSk1RRXcP8Hbk8/PFGejw9n6teXcm/F+2losbOzX2b8+qN3Vnz2xFs/MNopt3Tj2cmdOK6Xq7d7lvMZSuqID7Mzc+TBU2z8D7I34cWMSHsyXVxUeqhlcaxuWs7HjscDhbuymVA61iC/D1ktwqLBUb9ySjyXfVvpw/fMSWCQRmxvL18P+VVNU4f/7yOrocv7oe0/nDJP+u8C4s0PhWlioiIiIihpgo+vQ0OrjA6mLUZY3YiERER8VTJXeHuxZDUGT69Fb59Bux6+CbijX4otPHAotSEjsYxd6u5OepoTAdjG/S523MaZ4LDa+HEfugyqc5v3Z9fyqLdedzYJ40AXw95QHghFguMexbKT8LiZ81OIyImsRWV4+9rJSLIhO1QneVsUWqBexWlAnRLi2THsSIqql1YZHEu22dCcAy0GNSgYWZtOkbXZpGkxQQ7KZgLhMYb3VK3TjM653mY7mlR/HliJ1ZlF/D0VzvMjtNk1NgdrNybz2+mb6HX0wu46/11LNmTx/B28bx0fTc2/H4UX9w/gCcuy2RcpyTPvJ6QWssrqSDOnYtSY1tD/h6zU5xXm4RQdru6U+rBlRAcC7EZLp02y1bC4cLTjGgf79J5G11qT8i8Ela8CCU2pw9/39DW5JdU8um6w04f+5yKjsHHN0JIPFw3BXz9XTOv1ImKUkVERETEKAL54uew5xu45O/Q+RqzE4mIiIinC42HW76ErjfB0r/BtJsbZQspEWnazm5JnOiJD5ETOxnHHPcqSm0WHUxmUnjjFaVu+Rh8A6H9FXV+63srD+DnY+HGvl7U+SqxE/S4Dda+AbadZqcRERPYio3ubxZ37v7kHwLhKW5ZUNgtLYrKGjvbjhaZF6LqNOyeA+0vBx/feg+z11bMjuNFXNHFjbqkntX/QfDxh2X/NDtJo7i6Ryp3DWrJe6sO8tby/WbHMZWtqJwXF2Yx6K/fcsOba5i1+Rgj2ifwzm29WPu7kTx3bVcu75JMZLAKkOQHeUXlxIe58TVlbBs4XQilBWYnOae2ieEcKCh1bRfMgyugeT+Xd79csDMXgBHtElw6r0uMeBJqKuDbp5w+dN/0aLqnRfLq4n2Nv5CnshQ+mgQVxXDDxxAS27jzSb2pKFVERETE2zkcMOc3xoPBYU9A77vMTiQiIiLewjcArvw3jH0Wds+Gt0bDiQNmpxIRF8o9ZRSlxrvzlsTnExxtFODkbDM7SZ2N6ZDI+kMnsBWXO3fg6grYNgPaXQaB4XV6a1F5FZ+uO8zlnZPd+4FzfQx/AgLCjGt3h8PsNCLiYrlF5e69JfFZMa0g3/06pXZPiwRg46ET5oXYuwCqSqHDhAYNM2vTMawWuKxzkpOCuVBYAvS8HTZ/DIXZZqdpFL8Z156xHRJ56qsdfLHpqNlxXMrhcLBibz4/m7qe/s9+y3Pz99AqPpSXru/G+idG8fx1XRnWNh4/H5W3yP9yOBzklVS49zVlbBvj2ES7pbZNCMPugL02Fy0mP3UETh6C5gNcM9+PLNxpo2NKOIkRHnjNGdMK+v4MNkyBI+ucOrTFYuGRUW05dqqcD1YfcurYP2G3w8
x7jMW/V78NCR0aby5pMJ21RURERLzdkr/C2v9A3/th8KNmpxERERFvY7EYN0Rvmg5FR+D1YXBghdmpRMRFPLpTKkBCR8h1v6LUsR0TcThg/o5c5w68d4GxFX3n6+r81k/XHaG0soZbB7RwbiZ3EBwNw34H2Yth55dmpxERF7MVV3jGttQxGVCQ5XbF9fHhgTSLDuK7A4Xmhdg+E4JjoPnAeg/hcDiYtfkYfdNjiHfXf08DHgKrr8d2S/WxWnhhUlf6pkfzy2mbWbTb+dsrNzWV1XY+W3+EsS8s48Y317Aqu4DbBrRg0aNDmXJHHy7vkkyQv4/ZMaWJO1FWRVWNg7hQdy5KPbNFfVMtSk0MBWB3TrFrJjy4yjg27++a+c4oKKlgw6ETntkl9awhv4awRPj6EbA7t6PpwIxY+reK4eVFeymtqHbq2N9b9LRxTTz6aWgzpnHmEKdRUaqIiIiIN1v9Giz+C3S90fgA787bgImIiIh7azUc7lpkbLn0/pWw+ROzE4mIC+QUVeDnYyHKU7ffTOwIebuhyskdRxtZm4RQWsQEM2dbjnMH3jYDgqKh1bA6va3G7uC9lQfo0TyKzqmRzs3kLnreDgmdjG6pFS7qUCQiTYLHdEqNawvlp6DE/QrterWI5rsDJ3CYUVBbdRp2z4H2l4OPb72H2Xr0FAcKyriiS7ITw7lYWCL0uNXoluqhO2wE+vnw+uSetE0M454p61nsoYWpReVV/GfJPgb/bRGPfroZgL9f3ZnVj4/gd5dm0jI2xOSE4k7O7u7g1p1SI9LAN7DJFqW2iAnB39fKrpwi10x4cAUEhBuLPF1o0e48HA4Y2d6Di1IDwmDMM3B8M6x72+nDPzqmLQWllby9fL/Tx2bTR8bClB63Qt/7nD++OJ2KUkVERES81eaPYc6vjW0TL38RrPpoKCIiIiaLaQV3zIO0vjDzblj8rNt1UhKRujEKbQKxWj10gVxiJ3DUQN4us5PUicViYUzHRFbtK+DU6SrnDFpZBru/OVNU41enty7cmcuhwjJu7d/COVnckY8vXPYcFB01FpeKiFcor6qhuLzafTtb/tj3WxPvNjdHPfRpGU1haSX78kxYFJA1H6pKocOEBg0za9Mx/HwsjOuY5KRgJhn4MFisHtstFSA80I8P7uxDRnwod7+/nkW7PKcwNedUOX+ZvZMBf/mWv3yzi5axIbxzWy/mPDyIa3o2I9BPXVGl7mxFFQDEh7nxudJqNTqK52eZneScfH2stEsMY8dxFxWlHloFzfqA1bW/ExbuzCUhPICOKeEundflOkyElkNg4Z/g1FGnDt09LYpRmQm8vjSbk2WVzhv44Cr48kFoORgu+YeaLLkJj688qK6x88l3hzhUUGbO6jURERGRpmjXbPj8PuPD+1VvNWiVvYiIiIhTBUXBTTOMTu6L/wIz74XqCrNTiUgjyS0qJ8GdO9pcTEIn45i7zdwc9TC2QyLVdgfztjupW2rWPKOopuPEOr3N4XDw6pJ9pEYFMa5jonOyuKtmvaH7LbD6Vchxv39TIlJ3PxTaeMC5Mq6tccxzv6LUXi2iAVi7/4TrJ985C4JjoPnAeg9RY3fw5ZZjDGkTT0Rw3RaGNDnhyca5cNOHcOKg2WkaTWSwPx/c2Yc2iaHc9f46Zm48YnakBtlrK+GxTzcz6G/f8saybIa0jWPWzwfw0d19GdY2HouKi6QBbMUecq6MzWiynVIBMpPC2XGsqPHrrkrzjUWdzfs37jz/paK6hqV78hjeLsHzfydZLHD5C2CvNgo9nfx3+ujotpRUVvPqkn3OGdC2Ez66DiKbwzXv1XmRq5jH44tSdxwv4tfTtzL474sY+Fej/fuMDUc4fuq02dFEREREzLF/KXx6KyR1gUkfgp8brx4VERERz+TrD1e+DMOegC0fw5SJUFZodioRaQQ5ReUkRnjwNUl0S/ALhpytZieps67NImkeE8zMjU7qnLJ9BoTE1bmoZu3+QjYeOsndg9
Px9fH4RxoXN/L/ICgSvn4E7Haz04hII8v9fktiDzhXhiUZW/G6YVFqy9gQYkMD+O6Ai69JqithzzxoO65BTQXW7i8kt6iCK7omOzGciQb+wuiWuvx5s5M0qshgfz68qy89W0Txi0828+rifW7XhGvz4ZPcO2U9o55fwqzNx7i+dxqLHx3Gv2/oTufUSLPjiYewfX+udPei1DZw8iBUlZud5Jwyk8M5UVZFTlEj5zu0yji6uCh1TXYhpZU1jGwf79J5TROdblxb7l0Amz5w6tBtE8MY3zWF91YeILeh/15OHYGpV4FvENw0HYKjnRNSXMLj7+B0SolgwSODeerKDnROjWDBzlwembaZfn/5lmH/WMxvZ27ly83HyC9Rxw0RERHxAofXwoeTjAejN02HgDCzE4mIiIicm8UCQx6DiW/CkbXw1igozDY7lYg4ma2owr23WbwYqw8kdHDLolSLxcLEbqmsyi7g6MkGNnmoKDGKajKvrHNRzatL9hET4s81PZo1LIOnCI6GUU/B4TWwaarZaUSkkZ3tlOoRXcUtFqPgJt/9ilItFgu9W0axdr+Li1IPLoeKU9DusgYNM2vzMYL8fDyn0CYiBbrdDBunwsnDZqdpVOGBfrx3e28u65zEX+fs4pfTNnO6ssbsWBfkcDhYnpXPDW+s5sqXV7ByXz73D23Nit8M509XdiQtJtjsiOJh8oorCA3wJdjfzXcEjM0Ah73J3vvKTDK2tN9xrKhxJzq4EnwDIblb487zXxbuzCXQz8qA1rEunddUve6C5gNgzuNw8pBTh/7FyDbY7fD3uQ343FdWaDQqqCg2nmlHNXdeQHEJN/+tfHEWi4XW8WG0jg/j5n4tsNsd7MopZuW+fFbtK+DLTcf4cI3xw9U2IYx+rWLo1yqGvi1j3H/7AhEREZEfO74Zpl4NofEw+QutJhMRERH30Pka46HjxzfAmyPh+o+N7YtFxO2VVFRTUlHt2Z1SwdilYvMnRldLq3v1iZjQLYXnF+zh841HuX9Y6/oPtGcOVJ+GDhPq9Ladx4tYvDuPR0e3Icjfp/7ze5quNxiFOPP/AG0vhZAYsxOJSCM5213KYxZwxLU1OnK5od4topm9NYejJ0+TEhnkmkl3zTY6rqcPrfcQldV2vtl2nFGZCe5fsPVjA38BG943uqVe9pzZaRpVgK8PL07qRkZ8GC8s3MPOnGJevbE7LWJDzI72EzV2B3O35/Dq4n1sPXqK+LAAfntJO67vnUZYoOoupPHYiiuID/OAxRuxbYxj/m5IyDQ3yzm0SwrHYoHtx4oY0T6h8SY6uBJSe4Gv6/5OHQ4HC3fZGNg6lkA/L7rutFqNnar+Mxg+vQ1u+8bYvcoJ0mKCuW1gC/6zJJvJ/ZrXvTt2RQl8NAlO7IebZkBiR6fkclcOh4MTZVUcLiyjsLSSE2WVVNXYsTvAz8dKeKAvUSH+pEQGkRAeiI/VYnZkwAuKUv+b1WohMzmczORw7hyUTnWNnW3Hir4vUv34u0O8u/IAFgt0SA6nf6tY+rWKoVeLaEIDvO7/XSIiIuIpbLtgygSjM+otsyAs0exEIiIiIrXXvD/cuRA+uBreuwKueRfajjU7lYg00NlCG4/o/nYhSV3huzeNjjexDSjsNEFaTDC9W0QzY8MR7hvaCoulng82ts+E0ERI61ent728aC8h/j7c3LdF/eb1VBaLUYDz2kCjMHX8y2YnEpFGYiuuwM/HQpSnNNKJa2tsEXv6JARFmp2mTnq1NBb4f7e/kJRuKY0/ocMBu76GVsPBr/5FsMuy8jhZVsUVXZKdGK4JiGwG3W6CjVNg0C+NhYwezGq18NDIDDo3i+Dhjzcx7l/L+O0l7bixT3OsJheenDpdxafrDvP+qoMcKiyjZWwIz07sxITuKQT4elFxl5gmr6iCOE8oSo05c62Yn2VujvMIDfClRUxI43ZKLS+CnC0w6NHGm+Mc9uSWcOTE6YYtxHRX0S3hihfh01th4R9hzD
NOG/rnw1ozff0R/vTlDj69t1/t7ydUlMAH18CRdXDNO9BykNMyuYvSimpWZxew7uAJNh46wc7jxZw6XVWr9/r7WGmTGErn1Ei6pkbSvXkUreJC6n8/pwG8vsrS18dK12aRdG0WyX1DW1NRXcPmw6e+L1J9d8UBXl+ajY/VQpfUiO+LVHs0j/KuCnkRERFxX4XZ8P6VYPU1ClIj08xOJCIiIlJ3Ma3g9nlGYerHNxg3TLvdZHYqEWmA3FNni1I9pPvb+SR3NY7HNrpdUSrAxO4p/GbGVrYcOUWXZpF1H6C8CLLmQ8/bwFr7e+q7c4r5eutxfjaklXY1O5f49tDvfljxL+hyHbQcbHYiEWkEtqJy4sMCTXmI3Chi2xrH/D1ut/tBu8RwwgJ8WbO/kPGuKEo9thGKj0G73zdomBkbjxIV7MfgNnFOCtaEDHrE6By+/Hm49B9mp3GJYW3jmfPwIH49fSu//2I7X289zpOXd6D9mW21XWmvrYT3Vh5g+oYjlFXW0KtFFL8Z144xHRKbTIc28Q624nI6pkSYHaPh/IMhIs04RzZRmUnhbD16qvEmOLwWHHZjcboLLdiZC8CIdvEunbfJ6DABDqyAVf+GtL7Q/nKnDBsW6Mejo9vymxlb+WrLcS6vzQKZylL48Fo4vBquehMyr3RKFndQXF7FN1tz+HLLMdZkF1JZY8fPx0JmUjiXdk4iPTaE5jEhxIT6ExXsT4CvFYvF6IpfXF5NQWklR06UcaigjO3Hin6yc3xadDDD28Uzon08fdNj8PNxzS4+Xl+U+t8CfH3o3TKa3i2jeXgknK6sYf3BE6zKzmflvgJeXbKPfy/ai7+Ple7NI+mXHkv/1jF0SY3E39e9tl4SERERL3DyMLx3JdRUwq1fG8UcIiIiIu4qNA5u/Qo+uRm+uB9KcmHgI0bHOBFxO7nFXlKUGtcOfAPh+CbofI3Zaersks5J/GHWdj5df7h+Ram7Z0NNBXSYWKe3PT9/DyH+vtw9OL3uc3qLIb+BnV/CrAfgZyvBv2lt4ysiDWcrriDekzqKx50pSs3b5XZFqT5WCz1bRLFmf4FrJtw9GyxWaFP/HSJOna5i/o5cru/VzDOfY0emQdcbYMN7RoFquId1gz2PpIgg3rutFx9/d5i/ztnFpS8u47peafx8eGtSIuvfVbc2Siqq+XrLMT5bf4TvDpzA38fKFV2TubV/C88oChS3ZCuuID7MQ64pYzOadlFqcjhfbz1OUXkV4YGNsHDw0EqjwY6LPyMs3JlL59QI4j393sSFjHkGjm2AGXfDbbMhuZtThr2mZzPeW3WQZ7/ZxajMhAs3f6wshQ+vg0OrYOIb0PEqp2RoyhwOB2v2F/Lx2kPM2Z5DeZWdlrEh3DqgBUPbxNG9AQ0z7XYH+wtKWbWvgG932fhorbFzfGyoP+O7pnB1z1TaJTbuohYVpV5EkL8PAzNiGZgRCxiVyesOnGDlPqNI9YWFe3h+AQT5+dCzRRT9W8XSv1UMHZLD8XVRZbGIiIjIORXnGh1Sy08aHVITMs1OJCIiItJwAWFwwzT44j5Y+CcoscGYv4BV92FE3E3OqQrAC4pSffwgoSMc22R2knoJD/Tj8s7JzNhwlMfGtCMiqI4PH7fNgPBUSO1V+7ccPcWc7Tk8OCKDyGD/Oib2Iv7BcMW/4d1L4NunYexfzE4kIk6WW1ROepwHFZxHphkLNfJ2m52kXga0jmXR1zs5fuo0SRGNW/zHrq8hrT8ER9d7iG+2Hqey2s7E7qlODNbEDHoENn0Ay1+AS/5mdhqXsVgsXN87jXEdE/nXwiymrDrIp+sOc0XXZCb3a0GX1AindVguKq9i8e485u/IZcGOXE5X1ZAeF8Kvx7bjmp6pxIZ6UOG8uJ2SimrKKms8ZwFHbBvYsArs9iZ5nyvzTFfmXceL6d2y/uen8zq4EpK6unSxXX5JBRsPn+ThEW1cNmeT5BsA138Mb4wwCkPvXAiRzRo8rI/VwpOXZz
Lp9dW8vjSbB0dknPuFZYVGh9Sj62HC69Dp6gbP3ZTZ7Q7m7cjh1SXZbD58krBAXyZ2T+XqHql0axbplHO41WqhVVworeJCualvc05X1rA0K4+ZG47y3qoDvLl8P51SIvj47r6EBDRO+aiKUusoLNCPYe3iGXambfPJskpWZxeyal8+q7IL+OucXcbrAnzpkx5Nv1ax9EuPoV1iGFa1qRcRERFXKSuEKeOhOAdunum0FW0iIiIiTYKvv3GDMiQeVr9sFKZOeM24gSoibiO3qJzQAF9CG+nmd5OS3BU2f9JkHy5ezG0DWjB9wxGmfXeYu+rSufT0Cdj3LfS5p07f9wsL9hAe6MsdA1vWI62XaTEAet4Bq181tl10s86DInJhtuIK+rWKMTuG81h9IKZpd4G7kAGtjSZGy7PyuaZnwws1zqswG2w7YMyfGzTMjA1HSY8LoXOqB3ewjGoBXSbB+ndh4C8gPMnsRC4VGezPk5d34M5B6by1bD8frT30/d/7ZZ2TGZwRS5dmkXXapressprNh0+x/mAha/YXsjq7gKoaBzEh/kzonuLUghmRhrIVGbtvxId5yP2g2AyoKoPiYxDR9BYUZCYbRak7jp1yflFqVblRkNjnHueOexGLdtlwOGBE+3iXztskhcbDjZ/CW6Nh6kS45UsIS2zwsH3TY7ikUyIvL9rLlV2TaR7zX0XHJw/BlInG8Zp3IfPKBs/ZVDkcDr7eepzn5u8hO6+UtOhgnh7fkat7pNa7I2ptBfn7MKZDImM6JFJYWskXm46y41hRoxWkgopSGywy2J+xHRMZ29H4QbQVl/9QpLqvgAU7bQBEBfvRr1XM90WqreJC9EFNREREGkf5KZgyAQr2GRcPaX3MTiQiIiLifFarsbVUWALM/wOUFcB1UyGwcbcdEhHnyS0qJ8FTOtpcTFJX+O5NKNxnPGh0Mx1TIujTMpp3Vx7gtgEtar9L2M6vwF4FHSfWeq6z99UfG9O27l1ZvdWoP0LWPPjifrhnGfh5ePdhES9RXlXDqdNVnlNoc1ZcWziy1uwU9dI2IYzYUH9W7G3kotRds89MeEm9hzhcWMbaA4U8Nqat5z+THvQobPoIVr7otV3DUyKD+MPlmTw8KoNvth5n+oajvPRtFi8uzCLIz4c2iWG0TQglKSKI2FB/gv19sVigusbBydOVFJRUcqiwjOy8UvbllVBtdwCQER/K7QNaMiozgW5pUfioCZc0MbZiY/eN+DAP+fwbe6ZbZ/6eJlmUGh8WQEyIP9uPFTl/8KProaYSmg9w/tgXsHCnjaSIQDok634iAPHt4IaPYerV8O6lRmFqeHKDh/3DZR1YuiefJz7fxvu39/7hs8mBFfDpLcbf/eTPoXn/Bs/VVK3OLuAvs3ey+cgp2iWG8dL13RjXMdGUXdijQ/y5bUDjLwJWUaqTxYcFckWXZK7oYvxQHj15mlX7Cs585TN7a86Z1wXQv1UM/VrF0L9VLM2ig82MLSIiIp6ishQ+uBZyt8OkDyB9iNmJRERERBqPxQIDHjI6pn5xP7x/Bdw0o0FbXIqI6xhFqR7y8PBizu5ecWyTWxalAtw+sCX3TFnP/B25jOtUyy5k22caHcySu9fq5TV2B3/6agcpkUHqkloXAWFw+Qsw9SpY/GcY9SezE4mIE+SdLbTxtHNlXFvYNh0qy8DfvZ6PWq0W+reKZfneAhwOR+MVe+6eDfEdILr+58KZG48CML5birNSNV3RLY1uqevehgEPGwsXvVR4oB/X9Urjul5pnCyrZNW+AtbsL2R3TjHf7rKRX1J5zvf5+1hJjQ4iPTaEEe3j6dkiiu5pUUQG+7v4OxCpmx/OlR6ygOP7otQsaDXc3CznYLFYyEwOb5yi1IMrAQuk9XX+2OdRUV3Dsqw8xndL8fwFHHXRvD/cPMMoTH3nErhhGsS1adCQiRGBPDamLU/O2s6szce4sksyrH0d5v7WuGcw6UPjM6IHOnryNH+ctZ15O3JJig
jkH9d0YUK3FK9Y6KGi1EaWEhnE1T1SubpHKg6Hg4MFZazKLmDlvgKW783n803HAEiNCqJfegz9W8fQLz2WxAgPu8AUERGRxldVDh/fYHQauPodaDPG7EQiIiIirtH1egiKgmmTjZulkz93yvZSItK4cosq6OPsLf+aqrh24BsIxzdB52vMTlMvI9sn0Cw6iLdX7K9dUWppAWQvhgEPGosIauHTdYfZebyIl67v1uhb13mc1iOh+y2w4kXIGA0tBpqdSEQayFbsYVsSnxXbBnBAQRYkdTE7TZ0NbB3LrM3H2JNbQtvEMOdPUFoAh1bBoF/WewiHw8GMDUfolx5DSmSQE8M1YYN+CZs/NrqljnnG7DRNQmSwP+M6Jf3kc1tVjZ0TZZWcrqwBwGqxEBXiT4i/jwqyxC390CnVQ86VofEQEGF0Sm2iOqVE8PrSbMqrapx7zXZwBcRnGvf3XGR1diGllTWMbO+9ixnOK62vcX/1w+vgjeEw8XVoV/8O7gA39W3OjI1Hee3LZYzbPA3//QuNrvATXoPACOfkbkKqauy8tXw//1qQBcBjY9pyx8CWXnWvQ0WpLmSxWGgRG0KL2BCu752Gw+Fgr62ElfsKWLkvn3k7cvl0/REA0mNDvu+i2jc9mphQDzmJioiISOOoroRPbzUe+I1/DTqMNzmQiIiIiIu1HQs3fQYfToK3x8LkLyCqudmpROQ87HYHtuJyz+v+dj4+vpDQ0eiU6qZ8rBZuH9CSP365g1X7CujXKubCb9g5Cxw10GFircYvLq/iH/N207N5FJd1rmUnVvmpMX+GA8tg5r1w73IIijQ7kYg0QG6Rh21JfFZcO+OYt9sti1IHZMQCsHxvfuMUpe5bCA47tB1X7yE2HDrJgYIy7h/W2onBmriYVtD5WvjuLWM3jdB4sxM1SX4+Vs/7nSJezVZcjr+PlYggP7OjOIfFYuyskbfb7CTn1Tk1gmq7g105xXRtFumcQWuq4fBa6HqDc8arpYU7cwny87n4ta23Su0Jdy+GT26Cj6+H3vfA8CcgMLxew/k4qvlPm/UELX8Wx4EaHOP+hqXXXWB1/fb1jW3z4ZM89tlm9uSWMDozgSev6OA9C4V+xPP+Zt2IxWIhIyGMW/q34D8392TD70fx1QMD+d0l7WkeE8znG49y/4cb6PH0Asa+sJSnvtrBol02SiuqzY4uIiIiTUlNFXx2G+z5Bi59zugUJiIiIuKNWg42ilFPF8I744ztzkSkSSosq6SqxkGip2yzWBvJXeH4ZrDbzU5Sb9f3TiMhPIDn5+/B4XBc+MXbZ0B0K0jsVKux//LNLgpLK/nD5Znq1FVfAaEw8Q0oOgazHzM7jYg0kK3I6JSa4Gnnyuh0sPg06YKbC0mJDKJlbAgr9uY3zgRZ8yA4FpK61XuIGRuOEOhnrV1nc08y+DGoqTC6pYqIV8grqiAuLMCzrh9i2zTp+1mdUiMB2HrkpPMGzdkMVaXGtvEu4nA4WLjTxsCMWK/qXFlnkc3g9rnQ+25Y+zr8uxds/ACqK2o/RnUFbJkGL/chccXvKY5sx5jyZ5gVcJnHFaRW1dh5fv4eJr66kuLyat6c3JPXJ/f0yoJUUFFqk+JjtdAxJYK7Bqfzzm292fTkaGbc15/HxrQlJtSfKasPctu739Hlj/O49rVVvLBgD+sOFFJV4743MUVERKSBaqph+p2w6ysY9zfodYfZiURERETM1awX3DobaiqNjqnHt5idSETOIff7Qhsv6tSU3A0qi43tit1UoJ8PPx/WmrUHClm403b+F5bY4MBy6DjR6PZzEcuz8vlwzSHuHJRO5zMPOaWeUnvCkF/D1mmw9TOz04hIA+QWV+BrtRAV7G92FOfy9TcKU/PdsygVYGDrWFZnF1BZ7eRntPYa2LsAWo+sd5HG6coaZm0+xpgOiYQGeNmmqTGtoNM1RrfUkjyz04iIC9iKjaJUjxKbASU5UH7K7CTnlBwRSEyIP1uOODHfwZXG0Y
VFqbtyijl68jQj26uz9kX5BcIlf4e7FkJYInxxHzzfERY+BYdWG7t5/rfKMtj3Lcz9HTzXHmbcBT5+cMM0Eh+YT1Sz9vzhi+3knCp3/ffTSPbaipn4ykr+tTCLK7skM+fhwYzMTDA7lqm87JOoe/HzsdI9LYruaVHcP6w15VU1rD94guV781mxN59/LczihQVZhPj70Dc9hgGtYxmYEUtGfKhnrQQRERGRc7PXwMx7YMfnxhZ9fe4xO5GIiIhI05DYEW6bA+9fCe9eBjd9Bs16m51KRH7k+6LUCC8qSk3paRyPrIO4tuZmaYBJvdN4d+UBnpm9k8Ft4vD3PUfRzI4vjK2HO0y86HglFdX8evoW0mNDeGRUm0ZI7IUG/dIoavrqEUjpAdEtzU4kIvVgK6ogPiwAq9UDn/nFtXXbTqkAAzNimbL6IOsOFtK/VazzBj66AU6fgIxR9R5i9tbjFJdXM6lXmvNyuZNBjxrd2Fa+CKOfMjuNiDQyW3E5LWJCzI7hXLFnronys4wFZ02MxWKhc2qE84tSo9ONgkcXWbgzF4Bh7VSUWmspPeCuRZC9CNb8B5b9E5b9A3yDICIFgqIBB5TkGrt32KvB6gttxkLP2yF9GFit+ALPXduVS/61jIc+3sgHd/bB18d9e2ra7Q7eXXmAv87ZRbC/D6/e2N37utWfh4pS3Uignw8DWscyoLVxcXOyrJJV+wq+L1JduMtYmR4XFsDAM68b0DqGpAjvbAMsIiLi0ew18Pl9sO0zGPlH6He/2YlEREREmpbY1nD7HHj/Cnh/PFz/EaQPMTuViJyRW2Rs9eZVnVJj20BAOBxdB91uNDtNvfn5WHniskxue+c7XluyjwdHZPzvi7bNgLh2kJB5wbEcDgdPfrGdY6dO89m9/bRtorP4+MJVb8B/hsC0yXDHfKO7jYi4FVtxOXGeep6Mawt75kBNldE1y80MbB2Lv4+VhTttzi1KzZoHFiu0Gl7vIT7+7hAtY0Pomx7tvFzuJK4NdL7W2GK4788gPNnsRCLSiGzFFfRu6WG/784uYMzb1SSLUgE6pUayZE8WZZXVBPs3sOzMbodDq6Ddpc4JV0sLdtro0iyS+DAP/azVWKxWaD3C+CotgIMrjL+/4uNQVmjslJLW3yhSTesPaX0hIPR/hmkZG8IzEzryyLTNPL9gD4+NaWfCN9NwR0+e5tFpm1mVXcCIdvH85apO+jf1IypKdWORwf6M65T0fYX1kRNlrNxrFKku3ZPHzI1HAWgVF/J9kWrfVjGEB7rfxZ2IiIj8iN0Osx6ALR/D8Cdg4MNmJxIRERFpmiKbGR1Tp4yHD681ClMb8IBXRJwn51Q5FgvEe9pWixditUJyN6NTqpsb1jaey7sk89K3WYzukEC7xPAf/rDomPFQaujjFx3nnRUHmL7hCA+NyKBHcw97mGy2qBYw8XXj/PfNY3DFS2YnEpE6shVV0Dwm2OwYjSO2rdE9qzDbLbuHhwT40q9VDAt35vLEpe2dt4Nl1jxI7Q3B9Tsn7rUV892BE/xmXDvv3lVz2G+NBTKLn4UrXjQ7jYg0korqGk6WVXleAVhUS/AJANtOs5OcV+eUCOwO2HGsiJ4tGngdl7fL6BLefIBzwtVmyuIKNh85ySMjtVNHg4TEQOYVxlc9TOyeyprsQl5etI+eLaIZ1tZ9utY6HA6mbzjKH2dtx+5w8NerOnFtz2be/fnrHNy3/638j9SoYK7t1YwXr+/Gd78byTcPDeKJS9vTLDqYaeuOcPeU9XT94zzGv7yCf8zdzap9BVRU15gdW0REROrCboevHoJNHxgP+AY/ZnYiERERkaYtLAFu+RJiWsOHk4ztjEXEdLbicmJCAvBz4y3a6iW1J+Ruh8oys5M02P9dnklEkB8//3AjpRXVP/zBji8AB3SceMH3L8/K55nZOxmdmcBD5+q2Kg3XZgwM+iVseB82fmB2GhGpo9zicuLDPXTxxo+7wLmpke3jOVBQxr68Uu
cMWJwLxzdBxqh6D/Hx2sP4Wi1c1T3VOZncVVQL6HUnbJwCeXvMTiMijSS/pBLwwIWOPr7GLhtN+BzZKTUCgC1HTjV8sIMrjGNav4aPVUuLdtlwOGBE+wSXzSnn9scrO9AuMYwHP9rIXluJ2XFqJb+kgnumrOfRTzfTPimcOQ8P5rpeaSpIPQcvu+PnPaxWC+2TwrlzUDrv3tabzU+O5pO7+/LzYa3xsVp4dck+rn9jNV3+OI/Jb6/l9aX72H7sFHa7w+zoIiIicj4OB8w+8zBp0KMw5NdmJxIRERFxDyGxMHmWsZXjRzfAnnlmJxLxejmnyknw1EKbC0ntBY4aOL7Z7CQNFhMawIvXdyM7r4RffLKJ6hq78QfbZkBCJ4g9f6Hp1iOnuO+D9bSKC+G567piterhTaMZ9jtoORi+fgSObzE7jXdxOKC6EqorjEXGInVwtvtbgqd1fzsrtg1gAVvTLbi5mOFnClm+3ZXrnAHPLp7LGF2vt1dU1zB9wxFGZSYQ52kFWvUx+FHwC4Fv/2R2EhFpJLaicgDPXMAR365JnyMTwgNJCA9g61EnFKUeWgVhycaCAhdZsDOX5IhA2ieFuWxOObdAPx/emNyTAF8rt7/7HYWllWZHuqB523MY+8JSFu/O43eXtOeju/vSLNpDdzZwAl+zA4hr+Pta6ZMeQ5/0GB4BisurWJNdyPK9+azYm8+fZxsntJgQfwZmxDI4I45BGbHEh3voxa6IiIi7cTjgm1/BurdhwMMw/AnQiisRERGR2guJMQpTp4yHT26Ea6dA27FmpxLxWrlFFSRFeOG9x5SexvHId9DcdZ1gGkv/VrH84bJM/u/LHfzqsy38dVQUfkfWwog/nPc9q/YVcM+UdYQF+vH2rb0IDdBjikZl9YGr3obXh8BHk+DOhRCeZHYq91dWCPl7oGAv5GdB0VEoyYUSG5TmQ9VpqD4Njh8Vo1p8wDcAgqIhOAqCYyAsCSKbQ1Rz40F8fCYERZr1XUkTkldcAXhooQ2AfzBEtwTbdrOT1FtKZBDtk8JZsNPG3YNbNXzArHkQmgiJner19nnbczlRVsWk3mkNz+IJQmKh/wOw+M9w+Dto1svsRCLiZLaz50pPXMAR1w62fgrlRRAYbnaac+qcGsmWIycbNojDAQdXQvMBLnvmWV5Vw7KsfK7ukarOlk1Es+hgXp/ck0mvr+bu99cx5Y4+BPn7mB3rJ06druLpr3bw6fojZCaF88GdXWmbqKLmi9HdHi8VFujHyMwERmYaq/hyi8pZnpXP8r35LMvK44tNxwBolxjG4DZxDM6Io2eLKAL9mtYPvoiIiFdwOGDub2Ht69Dv5zDy/1SQKiIiIlIfwdEw+QuYMgE+uQmufQ/aXWp2KhGvlFtUTpdmkWbHcL3QOIhMg6PrzE7iNLcOaElxeTX/nL+HnsemcgNAh4n/87qqGjuvL83m+fl7aBEbwju39iI1Sh1FXCI0Dm74BN4eCx9eC7d9AwGhZqdyH9UVcGQdHF4DxzbA0Y1QdOSHP7f6QXgyhCZATGtj61H/EPANBL9AwAL2aqg50zX19AkoKzC+9i+FomPAj3axi2gGCR0hsSMkd4NmfY3FNeJVcos8uNDmrPhMyN1hdooGGdk+nlcW7+NkWSWRwf71H6imCvYtgswr6n3f9+PvDpESGcSg1rH1z+Fp+t0P370BC56EW7/WPXURD3O2KNUju0PHZxrHvN1Ntqi+c0oEC3bmUlxeRVigX/0GObEfio+7dMHmir35nK6qYdSZWilpGrqnRfH8tV35+UcbuHvKOt6Y3LPJ1Kct2mXj8RlbySup4P5hrXhoRBv8fbUxfW2oKFUAo732VT1SuapHKna7g505RSzLymfpnjzeXXGA15dmE+hnpU/LGAZlxDKkTRyt40O1ckBERKSxORww/w+w+hXocy+Mflo3z0REREQaIigKbv4cpl4F0ybD1e8YD39FxGUqq+0UlFaS6K27NKX0hMNrzU7hVA
+MyCAhIpD2Xz7BNloxa3U5ozILSY4M4kRpJav2FTB1zUEOFpRxaeck/jKxE+H1fXAp9ZPYyTjnfXQdTL8TJn1gdFGV/+VwGF1Qd38D+5fAwVVG11Mwupk26w3JdxsdrGJaG51OfRrwuK26Ak4dgYJ9kLvN+MrZBllzf+i0GtcOmveHtP7QcjCE6SG6p8sr9uAtic9K6AC7Zxudhf2CzE5TL8PbxfPSt3tZtNvGhG6p9R/o8FqoOAUZo+v19oMFpazYW8AvR7XBatW94+8FhMKQX8PsR43f6e0uMTuR/FhFMZw4CCcPGt3GK4qNr8pSsFjB6gs+fsZCj5A4CIk3FtpEtVRXcQEgr6gci8XYDdjjxLczjrYdTbcotVkkDgdsOXKKAfVdEHFwlXFsPsB5wS5iwc5cQgN86ZMe7bI5pXYu7ZxEWWVnHvtsCz+bup7Xbu5BgK9516ynTlfx1Fc7+Gz9EdokhPL65B50To00LY87UlGq/A+r1UKH5Ag6JEdw75BWlFVWsya7kKVZeSzdk8fTX+/k6a93khQRyKCMWAZlxDGwdSxRnniyFxERMZPDAQv+D1a+CD3vgLHPqiBVRERExBmCIuHmGTD1avj0Vrj6LegwwexUIl7DdqbQJsGTC20uJLUnbJ8BRcc9ahv1a9OrwJLNzNh7eHNZNq8vzf7Jn3dpFsmTl2cyvJ2K6UzTZjSM+5tRnDPncRj3V91nOMtuh6PrYddXxlfBXuO/x2dCj1uMQtC0fkbXdWfzDYCYVsZXmx8VpFWdhmOb4NBKY1vTLZ/CureNP+t6E4x/2flZpMnwmk6pDrvRBS65q9lp6qVLaiSJ4YF8vSWnYUWpWfOMArz0ofV6+/urDuJrtXBNz2b1z+CpetwKa9+Aeb+D1iOM37nieqdPGsXXR9cbXcePbYJS2zleaDGKUB0OsFcZXYR/3E38rJA4Y2FIXDtI6W50Fo9r37BFIuJ2bMUVxIQE4OvjgR0LI1uAbxDk7TI7yXl1bRaJxQIbDp5oQFHqSgiKhti2zg13Hna7gwU7bQxpG2dqsaOc3zU9m1FV4+C3M7dy69vf8drNPYgIcu2CVofDwbwdufzhi23kl1Ty82GteWBEa/2bqQedleWigv19GdYunmHt4gE4evI0y/bksTQrjznbcpi27ggWi9Gee3CbOAZlxNEtLRI/Tzz5i4iIuIrDAfN/Dytfgh63wSX/0IMiEREREWcKjPihMPWzO4wH4h2vMjuViFc4W2iTEOHBhTYXktLTOB5dB+GXm5vFmbbPBGDCTT9nqF8i6w6eoKCkgoggPzokR5AWE2xyQAGg911QuB9Wv2x0Dx/2uNmJzJW3GzZ/BFumQdFRoyisxSBjt5p2l0J4snnZ/IKMrUyb94NBvwR7DeRsgezFEN6A4jdxC7bicnysFs/s/nZWQgfjaNvhtkWpVquFyzon8d6qA5wqqyIiuJ5FE3sXGIXvgeF1fmtJRTXTvjvMuE5JJHrrZ6sL8fGDMX+GD66CNf+BAQ+ancg7OBzGz/aeuZA1Hw6vAUcNYDEKSTNGGUWlUc2NLuRhSRAQBn4hYP2vGoPKUiixQWk+lOQYncUL9hpf26bD+neM1/kGGd3MWw6CFoONYlUfdeb3ZLbiCuLDPLTQ3GqFuDZg22l2kvOKCPIjIz6U9YdO1H+QgyuM3QD+++e+kWw6cpK84gpGZ2qhZFN2Q580gvyt/OqzLVzz2krevrUXqVGuuZ+wP7+U/5u1nSV78miXGMYbk3uqO2oDqChV6iwlMohJvdOY1DuNGruDzUdOsnRPHsuy8nl50V5e+nYvoQG+9GsVw+A2cQzJiNMNRxERkbpwOGDub2H1K9DrLrjk7ypIFREREWkMAWFw03T44BqYfhdY/SDzCrNTiXi83KIznVI9ufvbhSR1AR9/o1tUew8qSt02E1J7QWQaUcAoPehrukY/DeUnYcmz4B8MAx4yO5FrlRXC1s+MYtRjG8DiA61HwognjU6lQV
FmJzw3q4/RCS65m9lJxAVyiyqICw3w7K3Yo9PBNxByt5udpEGu6JrMm8v3M2f7ca7rlVb3AU4dhdxtMOqpes0/ff0RiiuquW1Ai3q93ytkjISM0bD079DlemMLeGkcRcdh6zTY/AnYzvxsJ3aGgQ9D+jCjAD0grG5j+odAdEvj67/Z7VCYDcc2Ggu+DiyHb582/iyuHdy/piHfjTRxecUVxHvy7hvxmcZipCase1oU32zLwW531P0zS9FxOLEfet3ZOOHOYcGOXHysFoa2iXfZnFI/E7qlkhAeyD1T1nPZS8v561WdGdMhsdHmO1VWxWtL9/HWsv0E+Fr5/WWZTO7XXM0YG0hFqdIgPlYL3dOi6J4WxcMj23DqdBWr9uWzZE8+S/fkMX9HLgDpcSEMaxvP0LZx9G4ZrbbGIiIi5+NwwDe/hrX/MbpyjH1WBakiIiIijSkgFG6cBlMmwme3w3VToO04s1OJeLSzRale283LL9AoKju02uwkzpOfBblbYcxfzE4itWG1whUvGdvDz/8DYPGOznHHt8Da12Hrp1BdDgmdjO55Ha+GMBVRS9NiK64gwZMLbcAotI5ra3RTdGOdUiJoHhPMrM3H6leUune+ccwYXee32u0O3l15gK7NIume1kQL6puKMX+GV/rCoqfh8n+ZncazOBxwYBmsegWy5hq7kKT0NHZ/a3cZhCc13txWK8S2Nr46X2P8t9ICOLgcqsobb15pEmzF5bRPqmORszuJa2csojp9oskumurePIqPvztMdn4preND6/bmQyuNY/P+zg92HvN35NKnZXT9O5uLS/VvFcusnw/kwY82cs+U9Vzfuxm/HtuOyGDn7SRQVF7F28v389by/RSXVzOxewq/GdeOeG9dRO1kKkoVp4oI8mNsxyTGdkzC4XCwP7+UJXvyWLQ7jymrD/LW8v0E+/vQv1UMQ88UqbqqzbKIiEiTZ7fD7Edh3VvQ934Y84wKUkVERERcISAMbvoM3h8P0ybDpI+Mbjoi0ihyisrx97ES5c0PgtL6waqXjaJAvyCz0zTc9pmABTqMNzuJ1JbVBya+Djhg/u+hohiG/dbz7kPUVMHOWbD2DTi0ytjat/N10OsOo2uxSBNlKyr3judn8R1g37dmp2gQi8XCFV2SeXnRXmzF5XUvYsiaDxHNjALdOlqyJ4/9+aX8a1LXOr/X68RmQO+7Yc1rRle+xE5mJ3J/NdXGQo9VLxuLk4JjYMDD0PVGo0jULCExkHmlefOLS9TYHeSXVHp24Vh8e+No2wXN+5mb5TzOLojYcPBE3YtSD64E/1Cjm7ILHMgvJctWwg196rGAREzTMjaE6T/rzz/n7eaNZdnM3prDwyMzuKFPWoOaIR4uLOOjtYeYuvogReXVjO2QyEMjM2ifFO7E9KKiVGk0FouF9LhQ0uNCuW1AS8oqq1mdXcCiXXks2m1jwU4bABnxoQxrF8/QNnH0bBGNv6/aH4uIiBey2+HrX8D6d6H/gzDqT573IEhERESkKQuMgJtnwHuXw8c3GN1T04eanUrEI9mKjG0WLd58zZPWD1a8AEc3QIsBZqdpGIcDtk03vqfwZLPTSF34+MFVbxnb4i79m9GFaeyz4OMBj46Kc417LOvehpIciGoBo5+Bbjc22U5TIj9mK66gR3Mv+LeakAmbP4SyQgiONjtNvV3RJZmXvt3L11uOc9uAc2wxfj7VFcbWzJ2vrde94NeXZpMQHsAlnRqxE6UnGfIr2PIJzH4Mbp1tdNmUurPXwNbPYMmzUJgNce3h8heNf8eesNhK3EJhaSU1dgfxntxV/GxRat7OJluUmh4bQkSQHxsOneDaXs3q9uaDK6FZb5ddeyzYaezyPLK9dkhwN/6+Vh6/pD0Tuqfw1Fc7+OOXO/j3t3uZ1LsZV3VPpWVsSK3uLxWUVDB/Ry6zt+WwLCsPCzAqM4EHhmfQMSWi8b8RL+QBdxbEXQT7+zK8XQLD2yXgcDjYl1fK4t02Fu/O45
0V+3l9aTYh/j4MzIj9votqUoQ+uIqIiBew2+HLB2HjFBj4CIz4gwpSRURERMwQFAU3f2EUpn44yeie2mKg2alEPE7OqXISwj24o01tNOttHA+tcv+iVNsOyNtlbNEq7sfqA5e/ZJwDV74EJw7A1W9DoBt2iHE44PBaWPs67PgC7FXQeiT0ftE4WuvfSUfElSqr7RSWenj3t7PiM41j7nZoOcjcLA2QkRBGu8QwPt94tG5FqYdWQWUJZIyu85zrD55gVXYBT1zaHj8fFVfWSlCU0Qzii/uNe/E9bjE7kXtxOGDX17DwT5C/GxI6GbuMtB2n5xnicrbicgDiQj24KDWimdFJ1LbT7CTnZbVa6JYWyfqDJ+r2xrJC4zqy48TGCXYO83bk0i4xjGbRXtCJ3kO1Swxn6h19WLmvgHdXHuDVxft4edE+UqOC6JceQ3pcKGnRwQT6WbFaLRSdriLnVDnZeaVsOHSCLFsJAM2ig/j5sNZM6p1GSqRq0hqTilLFFBaLhdbxobSOD+XOQemUVlSzcl8Bi3bbWLzLxtztxiqFdolhDG0bz7C2cXRvHqWLKhER8Tz2Gpj1AGz6AAb/yjO3yhMRERFxJyExMPkLePdS+OBao3tqWl+zU4l4lNzictonumHBmzMFRxtdpQ6tMjtJw22bDhYfyBxvdhKpL6sVRj8N0a3g61/CW6Pg2vfrtZW0KapOG/8O174OxzdDQLixNXPvuyCmldnpROosr6QCwLO7v52V0ME42na4dVEqwLU9m/Gnr3aw9cgpOqXWsttW1nzw8YeWg+s83yuL9hIV7Mf1vbUNcZ10vRE2fQTzf28UU4bGm53IPdh2wZxfG519Y9vCNe9B+yvUbVZMYyv2gnOlxWJ0S83dYXaSC+qRFsXi3XmcOl1FRJBf7d509jo4rX/jBfuRwtJK1h0o5P5hrV0ynzQei8XCgNaxDGgdy7GTp1m4y8aS3Xl8u8vGp+uPnPM9UcF+dG0WyRVdkhnWLp4OyeHevXOPC6koVZqEkABfRmUmMCrT6KKaZSth0S6ji+qby7J5bck+wgJ9GdwmjpHt4xnWNp7IYH+zY4uIiDSMvQY+vw+2fAxDH4ehvzE7kYiIiIgAhMbBLbPgnUtg6tVGkWpqD7NTiXiM3FPlDGkTZ3YM86X1NQrp7DXu28HR4TC+h/Qhxu9OcW89b4PolvDZHfD6ULj0n9D1BrNTnV9hNqx7GzZOhdMnjELvS5+DztdBQKjZ6UTqzVZkdH9L8ORCm7NCEyAoGnK3mZ2kwa7qkcrf5u5i6uqD/PXqzrV7U9Y8Y2cG/5A6zbXjWBELd9l4ZFQbQgL0uL9OLBa47Hl4bQDM/R1c9YbZiZq20ydg8bOw9g3j3Dru79Dzdpdtty1yPnlFZ4pSPb2reEJH2D7TuO5qokV03ZtHAbDp8MnaX+cfXAk+AZDimnttC3fmYncYW7WL50iODOLmvs25uW9zAIrLqzhy4jQV1XZq7A7CA31JjAgkLLCWxdLidPq0IE2OxWKhTUIYbRLCuGdIK4rLq1ixt4BFu2ws3GXj6y3HsVqgZ/NoRrSPZ2RmAq3idINJRETcTE01fH4vbP0Uhj0BQx4zO5GIiIiI/FhYItzyJbx7CUydAJNnQXJXs1OJuL2SimpKK2tIDPfwh4e10bw/rH/H6A6X2MnsNPVzdIOx3fvgX5mdRJwlfSjcuxxm3AWf/8zYpveSf0B4ktnJDPYa2DMXvnsT9i00uvS2uwR632MUdjXRh/UidZHrLYU2YPzMJnRo8l3gaiMiyI/xXVP4fNNRfntp+4t3iztxAPL3GAV+dfTy4r2EBvhyS78W9crq9eLawMBHYMmz0PV6aDXc7ERNj8MBW6bB3MeNwtQetxrPMUJizE4mAoCt2FjAERfm4Qs4EjoY14xFRyEi1ew059SlWSRWC6w7UFj7otQDyyG1F/i55rPON9tySIkMolNKLTuZi1sKC/SjfZIKUJsS9VOXJi8s0I+xHR
P569WdWfvbEXx+/wDuH9aa4opq/vLNLkb8cwnD/rGYp7/awap9BVTV2M2OLCIicmE1VcbDna2fwognVZAqIiIi0lRFpBiFqQERMGWCsWWhiDRIzqmz3d+8oNDmYtL6GsdDq83N0RDbPjO2HW53qdlJxJnCk4wu4SP/CHsXwMt9YM3rUF1pXqaTh2HJ3+GFzvDx9UYx99DH4Rfb4LqpxrbfKkgVD5F3ptDGo7ck/rGEDmDbCXb3f753U9/mlFfZmX6e7WN/Imu+ccwYXac5th87xeytx5ncrzkRwSq8qLeBv4CY1vDVL6CyzOw0TcvJw/DBNTDzbohOh7uXGN1lVZAqTUhecQXhgb4E+rnpjhO1dXbxYu52c3NcQGiAL51SIliTXVi7N5Sfgpwt0GJA4wY7o6i8imVZeVzSKVFbtou4mIpSxa1YrRa6Novkl6Pb8s1Dg1jxm+E8dWUH0qKDeX/VQa5/YzU9nprPgx9t5ItNRzlVVmV2ZBERkZ+qKodpk2H7DBj1Jxj0iNmJRERERORCItPgli+Moqv3r4TC/WYnEnFrP2xJrKJUIppBeIqxdaE7stfAthnQehQERZqdRpzN6gMDH4afrYTkLvDNY/BKH9g23fi7d4XTJ2HD+/DOpfBCR1j0NMS0gmunwMNbYehvIDzZNVlEXCi3qAKrBWJCvKQoNT4Tqkrh5EGzkzRYx5QIujaLZOqagzgcjgu/OGu+UfAX06pOczz7zS4igvy4Z0jd3if/xS8QLnvB6Fg7/w9mp2ka7HajE/krfeHgChj7V7h9LiR1NjuZyP+wFVcQ7w3XlPGZxjFnq7k5LqJvegybDp/kdGUtrhMOrQGHHZq7pih14c5cqmocjOvURHZ+EPEivmYHEGmIlMggbu7Xgpv7taCkoprlWXks2Glj0S4bszYfw8dqoVeLKEa2T2BE+wRaxoaYHVlERLxZZSl8fCNkLzK2vut9l9mJRERERKQ2otNh8ufwzjh4/wrjwZyKYETqJef7olQvKbS5EIsFmveH/UuNLVLdrWvLoVVQkgOdrjI7iTSmmFYweRZkzYP5T8Jnt0PkH6HPvdD5Oud3TSvJg6y5sPsbo2CrpsLoZDfsd9DpGohu6dz5RJogW3E5cWEB+Fjd7LxQXwkdjKNth0f8jE/u15xHpm1m0W4bw9slnPtFVaeN83+PW+o09tI9eSzLyueJS9sTEaQuqQ3WchD0vR9WvwxtxkLGSLMTmefEQfj8Pji4HNKHweUvQFQLs1OJnJetuIL4MC+4pgwMh8jmTbpTKhhFqf9Zms2GQycY0Dr2wi8+uAKsfpDayyXZZm/NISkikK6pkS6ZT0R+oKJU8RihAb6M7ZjE2I5J1NgdbDp8koU7c1m408bTX+/k6a930iouhJGZCYzOTKRbs0is3nJBLyIi5isvgg+vhcNr4MpXoNuNZicSERERkbqIbw83zYD3rjA6pt72DYRc5Ea7iPyP3KIKQJ1Sv9dyMGz9FPJ2Q3w7s9PUzdbPwC/YKOIQz2axQJsx0Hok7PoaVr0Mcx+H+b83/lvbSyB9SP2KV0rzjXslh9fAgRVwdD3gMLoI97jVKHxN6e5+RdsiDZBbVEF8mBedJ+POnP9yd0C7S83N4gSXd0nmn/P28NK3exnWNv7cWwUfWAHVpyFjVK3Htdsd/OWbXaRGBXFzv+ZOTOzlRvwB9n0LX9wP962C4GizE7mWwwFbpsHsR43/fcW/odtNOu9Kk2crLqdHWpTZMVwjoWOTL0rt2SIKH6uF1dkFtStKTekO/sGNnqukopole/K4sU+aaoNETKCiVPFIPlYLPZpH0aN5FL8a247DhWUs3JnLgp023lq2n/8sySY+LIBRmQmM6ZBI3/QY/H2tZscWERFPVVYIUyca22tc9RZ0nGh2IhERERGpj5TucOM0mDIRpkyAW77UltUidZRbVE5YgC8hAbo1DRhFqWB0S3OnotSaKtjxBbQdB/7ancprWH0g8wrjK2erUc
SybQbsmWP8eViy0fEwri2EJUFoPPgFgcUH7NVQfgpOF0LhfijcBwXZUHTkzNh+kNwNhv3WKHRO7KSCGPFatuIKUiK9qCg1INQoarc17YKb2vLzsfKzoa144vNtLMvKZ3CbuP99UdY88A2C5gNrPe7H3x1m5/Ei/jWpKwG+Pk5M7OX8AmHi6/DGcPjqYbjmPe85/5w+AV89AttnQFo/mPAfiFLBszR9DocDW1EF8d6y0DGhA+z5xuiy7RdkdppzCgv0o2NKBKuzCy78wspSOLYR+j/oklwLd+ZSWW3nkk5JLplPRH5Kd/7EKzSLDubWAS25dUBLTp2uYtEuG3O35zBjw1E+WHOIsEBfRrSLZ3SHRIa0idNNcRERcZ4SG7w/Hgr2wnUfQFt1kBERERFxa837w3VT4aNJRif8m2eqIEukDnKLykmI8JKHh7UR1cLYjnH/Euhzt9lpai97sVFc2PFqs5OIWRI7GV+j/gT5eyB7CRxdZ2y/vX8p1FSc/71BURDdCloMgPhMSOsLSV2NwiARIa+4nK7NIs2O4VrxHYxOqR7imp6pvLZkH3+ds4uBrWN/2p3N4YCsuUaH6Vr+3jt+6jR/mb2TfukxXNEluZFSe7GkzjD8d7Dg/2DzR9D1BrMTNb7sxTDzZ1BqM7rFDnjYWHwi4gaKyqupqLYTFxpgdhTXSOwIDjvk7TIWcTVRfdOjeXv5fk5X1hDkf57fJ4fXGIvVWgxwSaZvtuYQHxbgPV11RZoYVd6J14kI8mN8txTGd0uhvKqG5Vn5zN2ew4KduXy+6RgBvlYGZcQyukMiI9snEB3ib3ZkERFxV6eOGFu7Fh0zOmqlDzU7kYiIiIg4Q8ZIuPot+PRW+PgGuP4TFdKI1FJOUTkJ4V7y8LC2Wg6GnbPAXuM+xQDbpkNgBLQeYXYSMZvFYnRGjWsLnCmsdjig/CSU5EF1OThqjG6pgRFGh/HACBMDizRtVTV28ksqiQ/zsnPl2S5wlWUu2c63sQX4+vDYmLY89PEmZmw8ytU9Un/4w4J9cOIA9H+gVmM5HA6emLmNKrudZ6/qhMVbuni6Wv8HIWsBfP1LY6FEQqbZiRpHVTl8+xSs+jfEZMD1C5p0kZvIueQVlwMQ7y3XlQkdjWPOtib989o3PYb/LMlmw6ETDGgde+4XHVhhXBc069PoeUorqlm028akXs1+ujhERFxGRani1QL9fBiZmcDIzASqa+x8d+AE83bkMG97Lgt22rBaoFeLaMZ0SGR0hwRSo9z/QlhERFykMBveu9J4CHPzTKPrh4iIiIh4jswr4cqX4fOfwWe3w7XvgY+f2alEmjxbUQV90qPNjtG0tBwCG6dAzpYm/ZDxe1XlsPMr6HAl+HrJg2CpG4vF6IYapI5EInWVX2J0GU7wli2Jz0rqbHSBs+2A1J5mp3GKyzsn8+7KA/x59k5GtIsn6mwTnKx5xrH1qFqNM2vzMRbusvHEpe1pHqMdGhqN1cdYePifwTDtZrhrEQSGm53KuXK3w/S7wLYdet0Jo57yiCJw8T62IuNcGectCziiWoJfsPEz3IT1bB6Fj9XC6uyC8xelHlwJSV0gIKzR8yzenUdFtZ1xnZIafS4ROTer2QFEmgpfHyv9WsXw5OUdWP7rYXz1wEDuH9aak2VV/OmrHQz86yIuf2k5ryzey8GCUrPjiohIU5a3G965BCqL4ZZZKkgVERER8VRdb4Bxf4fdX8Pn94HdbnYikSbNbneQW1TufYU2F9NykHHcv9TcHLWVNc+43u14ldlJREQ8Tu6ZQhuv65Sa1MU4Ht9sbg4nslot/HlCJ06dNp4zfi9rHsS2hajmFx0jO6+E383cRve0SG4b0LIR0woAYYlw9TtQuN9YfOgp13d2O6z8N7w+FEptcMOncOk/VZAqbstW7GULOKxWiM+E3G1mJ7mgsEA/OqZEsHJfwblfUHUajq6DFgNckmf21uPEhvrTq4UWxYqYRUWpIudgsVjomBLBL0e3Ze
4vBrPo0aE8Pq4dPlYLf5uzmyF/X8wl/1rGy4v2sj9fBaoiIvIjxzcbBan2Grh1tnt0uRERERGR+utzN4z4A2ydBrN/aWxZLCLnVFhWSbXdQaK3PDysrbBEozgle4nZSWpn88cQmgAtBpudRETE49iKjC2JvabQ5qyIZhAYaXQN9yDtk8J5YHhrZm48yvT1R6CiBA6ugIyLd0k9VVbF3VPW4+dj4aUbuuOjrYddo8UAGP0U7PrK2Obe3Z06ClOuhHm/g9Yj4WeroM1os1OJNIit2DhXetUCjsSOkLO1yd9zGpwRy8ZDJzhVVvW/f3hkHdRUQvOBjZ6jpKKahbtyGdsxUedPEROpKFWkFlrGhnDPkFZ8fv8AVvxmOE9c2p5APyt/n7ubYf9YzLh/LePf32axL6/E7KgiImKmA8vhnUvBLwhu+wYSMs1OJCIiIiKuMOiXMOBhWPc2fPu02WlEmqycU2cLbbzo4WFttRwMh1ZBdaXZSS6sNB+y5kKna8DH1+w0IiIeJ/dM97d4bztXWiyQ1NmjOqWe9cDwDPq0jObxmVvZu3a2UZCTceGiwNKKau56fx0HC0p5+cbupEQGuSitAND3PuhxKyx/DjZONTtN/W2bAa/2MwrBLn8RJn0IoXFmpxJpMFtRBUF+PoQGeNH1SGJnKD8JJw+ZneSChraNw+6ApVl5//uHB1cAFpfsLjl3Ww7lVXYmdEtp9LlE5PxUlCpSRymRQdw5KJ0Z9w1g5W+G8/vLMgn29+Ef8/Yw4p9LGPvCUl5cmMVemwpURUS8yq7ZMGUihCfD7XMhtrXZiURERETElUb+H3S/BZb9A1a9YnYakSbpbEcbr+v+VhuthkFVmVGY2pRtmw72auh6g9lJREQ8Ul5ROVYLxIT4mx3F9RI7Q+4OqDlHdzU35mO18OpNPUiJDGLDwmnU+IZAWr/zvj6vuIKb31rDuoOFPHdtV/q3inVhWgGMIulL/gGthsOXD0HWfLMT1U35KZhxD3x2G8RkwL3Locctxvcl4gFsxRXEhwdg8aZ/08ldjePxTWamuKiuzaKIDPZj8e7zFKUmdoSgyEbP8fmmozSLDqJ7WlSjzyUi56eiVJEGSI4M4o6BLZn+s/6sfnwET16eSVigL88v2MPI55Yw+vklvLBgD1m5xWZHFRGRxrTpQ/jkJuNi6vY5EKGVdyIiIiJex2KBy56H9lfA3MeN7a1F5CdyThnd31SUeg4th4CPP2TNMzvJhW36EBI7QUIHs5OIiHik3KIKYkID8PXxwke4SV2gpgLy95idxOmiQ/yZcnsvhlg2sbAykzdXHaGqxv6T1zgcDuZsy+Gyl5ax/VgRL9/Qncu7JJuUWPDxg2vehfhM497//qVmJ6qdgyvh1YGw9VMY8hujgUZMK7NTiTiVrbic+DAv6yge3wGsvk2+o7iP1cKgjDiW7MnDbnf88AfVlXD4O2g+sNEz5BaVs2JvPhO6pnhX4bJIE+RF/axFGldiRCC3DWjJbQNakltUzjdbjzN7Ww7/WpjFCwuyaJMQyuWdk7m8SzItYkPMjisiIs6y8t8w73eQPhSu+wACQs1OJCIiIiJmsfrAVW/CB9fA5/dBYCS0HWt2KpEmI7eoHIsF4rztAWJtBIRC8wFGJ64xz5id5txsO43OPGP+YnYSERGPZSsuJyHcS8+TSV2M4/HNHrn4IbX6EDjyOB5/M09/vZO3l+9ndIdEUiKDyC+p4NtdNrJsJbRNCOOdW3uTmRxudmQJjICbP4d3L4EPJ8HNM1yy7XS9VJTAwj/B2tchqoVRjNqsl9mpRBqFraiC9kle9jvSLxDi2sOxTWYnuahhbeP4cvMxth8rolNqhPEfj22A6tPQYkCjz//l5mPYHXBlNzUQEjGbFy6zE2l8CeGB3DqgJdPu6ceax0fwpys7EBnszz/n72HoPxZz5b+X8+aybHJOlZsdVURE6svhMG7yzPsdZF4JN0xTQaqIiIiIgG8ATPoAkjrDp7cYnW
pEBDCKUmNCAvDzxu5vtZExGvJ3w4kDZic5t80fgcUHOl1tdhIREY9lK64gPsxLO4rHtAa/YDi+xewkjeNMN/TJN9/J27f2pHVCGB9/d4hnZu/knRUHiAjy429XdebrBweqILUpCYmByV9AWCJMmWAsIGpqshbAK32NgtTed8O9y1WQKh7NVlzhnQsdk7oYCzccjou/1kSD28QBsHi37Yf/eGC5cUzr3+jzf77pKF1SI2gVp2e2ImZTp1SRRhYfHsjkfi2Y3K8Fx0+d5qvNx5m1+RhPf72TZ2bvpHeLaK7omswlHZOICvE3O66IiNSGvQa+fgTWvws9boVLnzO6YomIiIiIAASEwY3T4e0xRked2742trsW8XK5ReUkRnjhw8PayhgNcx83ih1632V2mp+y18CWaZAxCkLjzU4jIuKxcosq6Hy2q5i3sfoYHVJzPLUodT4kdMISkcLwCBjeLoEau4OyymoC/Xy0aKcpC0uE2+fA1Kvgo0lw5SvQ5TqzU0FZIcz9rbFwKLaN0R01rY/ZqUQaVVllNSUV1cR7Y1fx5K6waSoUHYWIVLPTnFdsaABdUiNYtNvGAyMyjP94cCXEZxqF/o1oT24x244W8YfLMht1HhGpHX26FXGhpIgg7hqczpcPDOTbXw7h4RFtyC+p4Hczt9HrmQXc9s5aZm48QklFtdlRRUTkfKor4LPbjYLUQb+Ey15QQaqIiIiI/K+QGLh5ptFNf+pVUJhtdiIR0+UUVZDgrd3faiOmFUS1/L6TWpOSvRiKj0OXSWYnERHxWNU1dgpKK4jz5nNlYmfI2Qp2u9lJnKv8FBxaZSzu+BEfq4WwQD8VpLqD0Hi49WtI6wcz74Z5T0CNSc9za6pgzX/gxW6w9VMY/CujO6oKUsUL2IoqALyzq3hSF+N4fLO5OWphSNt4Nh0+SWFpJVRXwqHV0Lzxu6R+tPYQ/j5WxndLafS5ROTi9AlXxCTpcaE8NDKDBY8M4esHB3LHoJbsyS3hF59spsdT87nvg/XM2Xac8qoas6OKiMhZFSXw4XWw43MY/QyM+ANYLGanEhEREZGmKrKZUZhaU2Vs9VicY3YiEVPZispJiPDCh4e1ZbEY3VL3L4Wq02an+anNH0NgBLQZZ3YSERGPlV9SicMB8d64JfFZSZ2hoghO7Dc7iXNlLwZ7tXGeF/cVGA43zYBed8HKl2DKeDh11HXzOxywdwG8OgC++ZVRoHbPMhj+O/D14t8b4lVsxWeLUr3w33xCR7BY4dgms5Nc1OjMBOwOmLc9B46uh6pSaDmkUecsr6ph5sajjO6QQLR2KBZpElSUKmIyi8VCh+QIHh/XnmW/GsZn9/bjul7NWJNdyL1TN9Dr6QX8ZvoW1mQXYLc7zI4rIuK9SvLgvcuMh4NXvgL9f252IhERERFxB3Ft4cbPjM+TU6+C0yfNTiRiiorqGgpKK9Up9WLajIbqcti/zOwkPzh9AnbOgg4TwU9/fyIijcVWXA5AQrgX/679vgvcJlNjOF3WPGNxR2ovs5NIQ/n6w6X/gPGvwZF18EpfWP+eUTDaWBwO2PctvDPOuKasqYRJH8LkLyBBW1SLdzl7rowP98KiVP9giG3rFp1SOySH0yImmK+2HIf9SwALtBjYqHPO3Z7DybIqJvVKa9R5RKT2VJQq0oRYrRZ6tojmT1d2ZM1vR/D+7b0Z1SGBWZuPcd3rqxn0t0X8Y+5u9uWVmB1VRMS7FOyDt0aBbRdM+gC63Wh2IhERERFxJ6k9jM+Rebvho0lQWWZ2IhGXO7vNYmKEFz48rIvmA8E/FHZ9ZXaSH2z51CiU7XGr2UlERDxabpEXd387K74D+ATA0Q1mJ3Eeux2y5kOr4eDja3YacZau18N9K41C6i8fhLdGw4Hlzp2jphp2fmU8m5gyAU4chHF/g/vXQLtLtYubeKWz15Veu9gxuatbLNywWCxc2jmJlfvyqdq7yPhdGRzdqHN+vPYwzaKD6N8qplHnEZHaU1
GqSBPl62NlcJs4nru2K+ueGMkL13UlPS6EVxbvZcQ/l3Dlyyt4b+UBCksrzY4qIuLZjqw3biiVn4Jbv4K22qpQREREROqh1TC46g04tBo+uw1qqsxOJOJSP3S08dKHh7XlFwhtxhpFqTXVZqcxOnOtfxeSuxkPQEVEpNF4dfe3s3z9IbGTZxWl5myGklzj/C6eJTodJs+CK16CU4fh3UvhvStg55cNu97L2w2L/gL/6gyf3AjFOXDpc/DQJuhzD/h68e8I8Xq24gr8faxEBvuZHcUcSV2Mc0pxjtlJLuqyzskEOMqxHl0H6UMbda4D+aWsyi5gUq80rFYV7Is0FVqOJeIGgv19Gd8thfHdUsgtKmfWpmNM33CEJ2dt56mvdjC0bTwTu6cwvF08gX4+ZscVEfEce+bBp7dASBzcNANiW5udSERERETcWYcJxjbYX/0Cvvg5jH8VrFozLt4h59SZTqkqSr24zCth22dwcHmjP7y7qCPrwLYdLv+XuTlERLxAblEFFgvEhXp5wVlKD9g4Few1YPWAZ1575gIWaD3S7CTSGKxW6D4ZOl0Da9+ANa/BJzdBaAK0GWP8vaf2grCkc3c2dTig6Bgc2wgHV8C+byFvF2AxPgeO+yu0GacuuyJn2IrLiQsLwOKtnYKTuxnHoxug3SXmZrmIdolhXBF1EJ/T1ZA+pFHn+mDNQXysFq7ukdqo84hI3ejTi4ibSQgP5K7B6dw1OJ2dx4uYufEon288yoKduYQH+nJp52Qmdk+hZ/Mo7/0wJiLiDBumwJcPQWJHuPEzCI03O5GIiIiIeIKet0NZAXz7NATHwJhntO2ieIXcIqP7W4KKUi8uYxT4hcD2z80vSl3/LviHQserzM0hIuIFbEXlxIYG4Ovj5YuWUrrD2v8Y3SITMs1O03B75hpFiSGxZieRxuQXBAMehL73wd75sOlD47PchveNPw8Ih8jmEBBqvLayDCqK4OQhqCwxXuMbCM16Q887oP3lEJ5k2rcj0lTlFVcQF+bFizcSO4PVF46ua/JFqRaLhWui9lFR5supqK401lPWkopqPl57mEs6Jel+g0gTo6JUETfWPimc9knh/HpsO1buy2fGBqNA9aO1h0iLDubqHqlc1SOVlMggs6OKiLgPhwOW/A0W/xlajYBr34OAMLNTiYiIiIgnGfQolBbA6peNxU8DHzY7kUijyy0qx9/HSpS3brNYF35B0Ga0sfXrpf80r0tc+SnYNh26XKfrYhERF8gtKich3IsLbc5K7m4cj21w/6LUEpvxfQx/wuwk4io+vtB2nPFVUwVH10POVqPI+tQRowC1vAj8gyEkHVoOhtgMiO9gFGT76neAyIXYiipoHhNsdgzz+AdDQkc48p3ZSWqlY+VGNtjbsHvnKW4dEN0oc0z77jDFFdXcMbBlo4wvIvWnolQRD+BjtTAoI45BGXE8Pb6aOdtymL7hCM/N38PzC/YwsHUs1/ZsxqjMBAL9PGCrExGRxlJTDV8/Ahvegy43wBUvgo8emIqIiIiIk1ksMObPUGqDBU8aWzt2vd7sVCKNKqeonIQIL95msa4yr4TtM+HgSmg5yJwMG6dC9WnocZs584uIeJncogqSItThi5jWRlfJo+uh201mp2mYrHnGMWOMuTnEHD5+kNbX+BIRp8gtLqdXyyizY5grtSds/hjsNeYtYKyN0gIC87eTFXoT09Yd4dYBzi8arbE7eHvFfnq1iKJrs0injy8iDePl+z+IeJ6QAF+u6pHKh3f1ZdmvhvHg8Ayy80p54KON9PnzQp78Yhvbjp4yO6aISNNTWQaf3GQUpA76JYx/RQWpIiIiItJ4rFYY/yq0HAJf3A975pmdSKRR5ZwqJ1Fb6dVexmjwDYIdn5szv70G1rwGaf0guas5GUREvIytuJx4nSuNz8nJXeHoBrOTNNyeuRCWDImdzE4iIuL2KqprOFlWRXyYl58rU3sZXZfzdpmd5MIOLAUgrvNodhwvapQalX
nbczhy4jR3DEx3+tgi0nAqShXxYM2ig/nFqDYs+9Uwpt7RhyFt4vjou8Nc9tJyxv1rGe+s2M+J0kqzY4qImK8kD967HLLmGlsjjviD0b1KRERERKQx+QbAdVMhsSN8egscWWd2IpFGY2xJ7OUPD+vCPwTajoVt06G6wvXz7/4GTh6CPve6fm4RES9UVWMnv6SShHBt3Q1AcnfI3Q5V5WYnqb/qSti3CNqM1r1mEREnyCs2roviw7z8XJnayzg29XtI2UvAP4z+g0cT4GvlgzWHnDq8w+HgtaXZpEUHMyozwalji4hzqChVxAtYrRYGZsTy4vXd+O63I3lqfEf8fCz88csd9PnzQu7/YAOLd9uosTvMjioi4nq2XfDmcOMm57VToNedZicSEREREW8SGA43fgahCfDBNZCfZXYiEadzOBzkFKlTap11uwlOn4Dds10/95rXIKIZtLvM9XOLiHihs4U2WsBxRkoPsFdB7jazk9TfoZVQWQxtxpqdRETEI9jOFqV6+wKO6HQIioIj35md5ML2L4EWA4gICWJ81xRmbjzCyTLnNUxbuNPG5sMn+dnQVvhYtfhDpClSUaqIl4kI9uPmvs2Z9fOBfPPQIG7q25xV2QXc+s53DHj2W/4xdzeHC8vMjiki4hrZi+Gt0UbXmdtmQ3s9bBMRERERE4TGw80zwOoDUyZA0TGzE4k4VdHpasqr7CRGqNCmTtKHQXgqbJzq2nlztsKBZdD7LvDxde3cIiJeKrfI6AiqTqlnpHQ3jkc3mJujIfbMA58AaDnY7CQiIh7BVnS2U6qXX1daLJDSs2l3Sj15GAqzoeUQAG4b2ILyKjsfrT3slOHtdgf/mLeblrEhXN0j1SljiojzqShVxIu1TwrnD5dnsvrxEbx2U3faJ4XxyuK9DP77Iia/vZY523KoqrGbHVNEpHFsmAJTr4KIFLhzwQ83OkVEREREzBCdbnRMPX3C+Jx6+qTZiUScJuf7Qhsvf3hYV1Yf6HoD7F0Ip464bt6V/wa/YOg+2XVzioh4uVwV2vxUeIqxk0BT7wJ3IXvmQMtB4B9idhIREY+QV2xcV8aHaQEHqb0gbxeUF5md5NyyFxvHdKMotV1iOANax/DOiv2UV9U0ePgvtxxjV04xD4/MwM9HZW8iTZV+OkUEf18rYzsm8c5tvVn+6+E8NCKDrNxi7p26nv7Pfsvf5+5S91QR8Rx2Oyz4I8z6ubFK/fY5EJlmdioREREREUjuCtdNhfws+PgGqCo3O5GIU5wtSlWn1HroegPggE0fuWa+wmzY+in0vN3YElJERFzi+0IbdUo1WCzQrDccXmN2kvop2AeF+6DNWLOTiIh4DFtxBVYLxITqXElqT8ABx5poR/F9CyE0EeIzv/9P9w9rja24gk++a1i31KoaOy8syKJdYhiXd05uaFIRaUQqShWRn0iODOLhkW1Y9qthvHVLTzqnRPDq4n0/6p56XN1TRcR9VZ2Gz26D5c9Bj9vghmkQGGF2KhERERGRH7QaBhNeg4MrYPodYG94BwkRs+WeOlOUqk6pdRfd0lhQuXGKsciysS1/Hqy+0P+Bxp9LRES+l1tUgY/VQkyICm2+16wvnDwIxTlmJ6m7PXONY8Zoc3OIiHgQW1EFsaEB+FgtZkcxX0oP43i4CXYUr6mGfYug9UhjkckZ/dJj6NUiilcW7+V0Zf3vdb2zYj/780v51di2WPVvQaRJU1GqiJyTr4+VEe0TeOvWXiz/9XAeHJ7Bnpxi7p26Qd1TRcQ9leTBe5fDji9g1FNw2fPg42d2KhERERGR/9Xpahj7V9j1FXz9S3A4zE4k0iBnO6Wq+1s9db/FKMrZM6dx5zl52OjI2n0yhCU27lwiIvITuUXlxKnQ5qfS+hrHQ6vNzVEfe+ZAXDuIam52EhERj2ErLtc15VlBkcZ55tAqs5P8r6ProfwktB7xk/9ssVh4bEw7cosqeH1pdv2GPnmaFxZkMbJ9PMPbJTghrIg0Jh
WlishFJUcG8YtRbVj+62G8Oflc3VNzqFb3VBFpyvJ2w5sjIGcbXPs+DHjwJ6vzRERERESanL73wsBfwPp3YMlfzU4j0iA5ReVEh/gT4OtjdhT3lDkeItJgxb8ad56z4w94qHHnERGR/5FbXEGCCm1+KrEz+AbC4bVmJ6mb8iI4uBLajDE7iYiIR7EVVxAfpt03vtd8ABxeY3QmbUr2LgCL1dgJ6L/0bhnNpZ2SeHXJXg4V1K0Bmt3u4NFpm7EAT17ewUlhRaQxqShVRGrN18fKyMxzdU9dz+C/LeLf32aRV1xhdkwRkZ/a9y28NQqqTsOtX0PmFWYnEhERERGpnRFPQtcbYfFf4Lu3zE4jUm+5p8pJCNfDw3rz8YX+P4fDqxuvW1xhNqx/F7reAJHNGmcOERE5L1tROfE6V/6Urz8kdzfOf+4kexHYqyBDRakiIs5kFKVqAcf3mveHyhLI2Wx2kp/auwBSe0FQ1Dn/+HeXtsfPauWxzzZjt9d+Z6DXlu5jVXYBv78sk2bRwc5KKyKNSEWpIlIvP+6e+tpNPUiPC+Uf8/bQ/9mFPPTxRtYdKMSh7QVFxEwOB6x+DaZeDeGpcOcCSO1hdioRERERkdqzWODyfxkPtGc/Cju/NDuRSL3kFJWTqO5vDdPtJuOh3ooXG2f8+U+Cjz8MfbxxxhcRkQvKLSpXp9RzSesDxzdDZd26qZlq19fGObtZH7OTiIh4jOoaO/klKkr9ieYDjOPBlebm+LHSfDi2EVqPPO9LkiOD+P3lmazZX8gLC7NqNey3u3L5+9zdXNo5iet6aRGliLtQUaqINIivj5WxHROZemcfFjwyhBv7NOfbnTaufm0Vl764nI/WHqKssom1jBcRz1ddCV8+BHN+DW3Gwh1zIaq52alEREREROrOxw+ueRdSesBnd8CBFWYnEqmz3KJyEiOCzI7h3vxDoPfdsPtryNvt3LEProKds2DAQxCe5NyxRUTkoiqqazhRVkWCtiT+X836gr3aKHBxBzVVsGcOtBlndDoXERGnKCitxOGAOHUV/0F4EkSnN637RPu+BRwXLEoFuKZHKtf0SOXFhVlMW3f4gq9dsTefn03dQGZSOH+/ujMWi8WJgUWkMakoVUScpnV8KP93RQdW/3YEf57QCbvDweMzttLnzwv505c72J9fanZEEfEGpfkwZTxseA8G/RKumwoBYWanEhERERGpP/9guGGasdDqo+shd7vZiURqraK6hvySShL18LDhet8NfiGw6M/OG9Nuh7m/hbBk6P9z540rIiK1ZiuqACBB58r/1ay3cTy82twctXVgOZSfgvaXmZ1ERMSjnD1XqlPqf2k+AA6tNK7rmoK9CyA4BpK6XvBlFouFpyd0ZFBGLL/6bAsvLsyiuuan30ON3cE7K/Zzy9traRETwpQ7+hDsrwUfIu5EP7Ei4nQhAb7c0CeN63s3Y93BE7y/6iDvrzrA2yv2Mygjlsn9WjC8XTw+Vq1iEREny90OH02CEhtMfBM6X2N2IhERERER5wiOhptmwFujYepVcMc8iEwzO5XIRZ19eJgYoYeHDRYSaxSOLvkrHFkPqT0aPub6t+HYBhj/mtGNVUREXM5WXA5AfLjOlf8jOBpi28ChNWYnqZ1dX4NvEKQPMzuJiIhH+f5cqaLUn2oxEDZOAdt2SOxkbpaaasiaBxljwHrx/ogBvj68eUtPfv3ZFp6bv4cvNh1lYvdU0qKDOXLiNJ9vPMru3GKGt4vn+eu6EhHk54JvQkScSZ1SRaTRWCwWerWI5qXru7Hy8eE8MqoNWbkl3PX+Oob8fRFvLsumqLzK7Jgi4il2zTYe0FdXwm2zVZAqIiIiIp4nshncNB2qymDKRCgtMDuRyEXlFhkPD9X9zUn6PwAh8fDNY2CvadhYhdkw7w/Qajh0meScfCIiUme56pR6YWl9jU6pDT3vNTa73ShKbT3C2OlAREScxlZ8plOqzpU/1by/cTy40twcYJyrT5+Adp
fU+i0Bvj68MKkbr93UnZAAX/4+dzcPfLSRv87ZhZ+vhRev78Zbt/RUQaqIm1KnVBFxifiwQB4ckcF9Q1sxf0cu76w8wNNf7+S5+Xu4pkcqt/RvQXpcqNkxRcQdORyw/DlY+BQkd4NJH0J4ktmpREREREQaR0ImXP8xvD8ePrwGbvlS3Q2lScs5U5SaGKGHh04REAZjnoEZd8G6t6H3XfUbx26Hz+8Hqw9c8RJYtKORiIhZtIDjIloMgg3vQ85WSO5qdprzO74Rio9B+yfNTiIi4nHO7sARF6pOqT8RmQYRaXBgOfS5x9wsu78BH39oNaLObx3bMYmxHZM4UVpJbnE5caEBxOjvWsTtqVOqiLiUr4+VcZ2SmHZPP756YCDjOibx0drDDP/nEm57Zy1L9+ThcDjMjiki7qLqtPEgbuGfoONVRodUFaSKiIiIiKdr3h+ufhuObYRpt0CNdiGRpivn1JmiVBXaOE+nayB9KMx/Egr21W+MFS/AoZUw9lmISHVmOhERqaPcogr8fCxEBasL2Dm1GGQcDywzN8fF7PwKLD6QMdrsJCIiHsdWXE5UsB/+vipx+h8tBsDBFcbCQ7M4HEa38JZDIKD+jciiQvxplxiuglQRD6Hf2CJimo4pEfzz2i6s+M1wHh6ZwdajRUx+ey2jn1/KB2sOcrqyiW/FIiLmOnkI3h4LWz+FEX+Aq94EvyCzU4mIiIiIuEb7y+DS52DvfJj1gPEAQKQJyi0qJ8DXqu32nMligStfAR8/+PRWqCyr2/v3zDMWd3aYAF1vaJSIIiJSe7aicuLDArGoa/W5hSdBTAbsb+JFqbu+ghYDITja7CQiIh7HVlxBfJgWOp5T+jAoK4Djm8zLkLcbTuyHtuPMyyAiTY6KUkXEdHFhATw8sg0rfjOMf17TBX9fK7+buY1+zy7k2W92cezkabMjikhTk70EXh8KhdnG1qWDfqmtBkVERETE+/S8DYb+FjZ/BAu0Tag0TTlFFSRGqNDG6SJSYMJ/jK2MZ94NNdW1e9+h1fDpLZDYEa58WdfSIiJNQG5xOQnh6gh2QS0HwcGVtT/fuVreHsjfA+0uMzuJiIhHshVXEK9z5bm1Gm4c9y40L8Pur42jilJF5EdUlCoiTUaArw9X9UjlqwcGMu2efvRLj+H1pfsY9LdF3P/hBjYeOmF2RBExm8MBK1+CKeMhOBbuWqQLHBERERHxbkN+BT1vhxX/glUvm51G5H/knionIVwdbRpF27Ew5s+w80uYfgdUlV/49bvnwJQJEJ4MN80A/xDX5BQRkQvKLarQufJiWgyCymJzu8BdyK6vjGO7S8zNISLiofLOdBWXcwiNg6SusM/MotRvILmbca0pInKGr9kBRET+m8VioXfLaHq3jObIiTLeX3WQj9Ye4ustx+nVIoq7BqUzsn0CVqs6OYh4lcpSY1vSbdOh/RUw/hUICDM7lYiIiIiIuSwWuOQfUJoHc38LIfHQ+RqzU4l8L6eonK7NIs2O4bn63QeOGpj3hLGbyBUvQXLXn76mrBCW/A3WvGo8rLxhGoTGm5FWRETOIbeonIGtY82O0bS1GGQc9y+B1J7mZjmXXV8ZxTgRqWYnERHxOHa7g7wSdUq9oNYjYPkLcPokBEW6du6i43BkHQz7rWvnFZEmT0WpItKkpUYF89tL2vPgiAymfXeYt5bv5+4p60mPDeHOQelM7J5CoJ+P2TFFpLEVZsPHN4FtB4x4Egb+QlsMioiIiIicZfWBiW/C1Kvg859BcLTxQELEZA6Hg5yichIj1NGmUfV/AGJawxf3w+tDoPlASOsLvoGQtxP2zIXKEuh1F4z6E/gHm51YRETOKKuspri8WoU2FxMaB/GZsH8ZDPql2Wl+6uQhOLreuG8tIiJOd6KskqoaB/FhOleeV+uRsOyfxuKNzCtdO/eOzwEHZI537bwi0uRZzQ4gIlIboQG+3D6wJUseG8pL13cjJMCX387cyoBnv+VfC7IoLK00O6
KINJasBfD6UCg6CjdNh0GPqCBVREREROS/+QXCpA8gri18crPxYFzEZCfLqqistmtLYldoOw4e3AjDnoCyAlj+HCx6Gg4sNx4O/mwlXPoPFaSKiDQxtqIKABK0JfHFtRwMh1ZDdRN7HrR9pnHsMMHcHCIiHspWbJwr43WuPL/UXhAQDnsXun7ubTMgoSPEtXH93CLSpKlTqoi4FV8fK5d3Seayzkms2V/IG0uzeX7BHl5dspere6Ryx8B0WsaGmB1TRJzBbofl/4RvnzEuZq6bAtEtzU4lIiIiItJ0BUUaC7neGgUfXAO3z4XYDLNTiRfLKSoHIEmdUl0jMAKGPGZ81VSDvdooWBcRkSYr98y5Ugs4aqHFIFjzGhxZCy0Gmp3mB9tmQHJ33bsWEWkk3xelqqv4+fn4QfoQoyjV4XBdc5+Th43z8vDfu2Y+EXEr6pQqIm7JYrHQNz2Gt27txYJHBjO+awrTvjvC8H8u5p4p61h/sNDsiCLSEKdPwrSb4dunodPVcMc83dQTEREREamNsES4+XPAAlMmQtFxsxOJF8tRoY15fHxVkCoi4gZyzxTaJKjQ5uJaDgarH2TNNzvJDwr2wfFN0HGi2UlERDyW7cx1ZXyYzpUX1GoEFB0B207Xzbnjc+Oo86CInIOKUkXE7bWOD+PZqzqz/DfDuH9oa1ZnF3LVq6uY+MoK5mw7To3dYXZEEamL45vh9SGwZw6M+QtMfEPbC4qIiIiI1EVMK7jpMzhdCFOvMhZ9iZgg95Tx8DBRnVJFRETO6ftCGy3guLjAcGjer2kVpW6faRw7TDA3h4iIB/u+U2qYzpUX1Gascdz1levm3DYDkrpCdLrr5hQRt6GiVBHxGPFhgTw6pi2rHh/OH6/oQF5JBfdO3cCo55cw7bvDVFbbzY4oIhficMC6d+DNUVBTBbfOhn73uW6LCRERERERT5LcDa6bCvl74KProeq02YnEC+UUlWOxqKONiIjI+eQWlRPgayU80NfsKO4hYzTYthvbBTcF22dCsz4QkWp2EhERj5VXXEFYgC9B/j5mR2nawpOgWV/Y8YVr5ivcD8c2qEuqiJyXilJFxOME+/tyS/8WLH50GP++oRtBfj78avoWhvx9EW8t309ZZbXZEUXkv1WWwsx74auHocUAuGcppPUxO5WIiIiIiHtrNQwm/gcOrYLpd0KNrofFtXKLyokJCcDPR7ehRUREziW3qIKE8EAsWphfOxmjjePeJtAtNW835G6DDirGERFpTLbicuLCtdCxVjKvNM5N+Xsbf66tnxpHdQsXkfPQ3UAR8Vg+VguXdU7mqwcG8t7tvUmLDuapr3bQ/9lveWHBHk6WVZodUUQA8vbAGyNgyycw9Ldw42cQEmt2KhERERERz9DxKhj3V2P7tq9/YexQIOIix0+Vkxihh4ciIiLnYysuJ0GFNrUX2wYi0yCrCRSlbv4YLD4qxhERaWS2ogrtvlFbmVcYx52N3C3VboeNU6DlEOO8LCJyDipKFRGPZ7FYGNImjk/u6cf0n/WnZ/MoXliQRf9nv+Xpr3aQc6rc7Igi3mvrZ/DGMCi1wc0zYOivwartN0REREREnKrPPTDoUdjwPix6xuw04kVyTpWTGB5odgwREZEmy1ZUQbzOlbVnsUDGGMheDNUV5uWw240mC61HQFiCeTlERLyArdjoKi61EJEKqb1g++eNO8+BpXDyEHSf3LjziIhbU1GqiHiVHs2jePOWXsx9eDBjOiTyzsoDDPrbt/xm+hb255eaHU/Ee1RXwOzHYPodkNAB7lkGrYabnUpERERExHMNf8J4WLD077DmdbPTiJfILSrXw0MREZELyC0qJyFM58o6yRgNVWVwYLl5GQ4shaKj0GWSeRlERLyAw+HAVlyuTql1kXkl5GyBwuzGm2PDFAiMhHaXNd4cIuL2VJQqIl6pbWIYz1/XlcWPDmVSrzRmbDzK8H8u5v4PNrDt6Cmz44l4tsL98PZYWPs69Ps53P
o1RKSYnUpERERExLNZLHDp89D2UvjmV7BthtmJxMOVV9VwoqxKnVJFRETOo6SimtLKGhLCVWhTJy0Ggm8g7JljXobNH0NABLS9xLwMIiJeoLiimvIqO/FawFF7mVcax8bqlnr6BOz8EjpfC376exGR81NRqoh4tWbRwTw1viMrfj2ce4e0YumePC57aTmT317LdwcKzY4n4nm2TYf/DIaCfXDdVBjzDPj4mZ1KRERERMQ7+PjC1W9BWj+YcTfsW2R2IvFgtiJjS92ECD2kEhEROZfconIAdRWvK/9gyBgFO2aBvcb181eUGHN3GA9+Qa6fX0TEi5y9rozXAo7ai0yDZn2MBRQOh/PH3/Ip1FRAt5udP7aIeBQVpYqIAHFhAfx6bDtWPD6cX41ty45jp7jmtVVMen0VK/fl42iMD2wi3qSyDGY9AJ/dDnHt4N5l0P5ys1OJiIiIiHgfvyC4/iOIbQOf3ATHNpqdSDxUzplCG3VKFRERObezRakqtKmHDhOgJAcOrXb93Du/hKpS6HK96+cWEfEytjPnyrgwnSvrpOuNkL8bjqxz7rh2O3z3JiR1gaTOzh1bRDyOilJFRH4kPNCP+4a2ZtmvhvP7yzLJzivlhjfWcM1rq1i6J0/FqSL1kbsd3hgGG6bAwEfgttkQ1dzsVCIiIiIi3isoEm6aDkHRMPVqYycDESf7vihVnVJFRETO6fuu4lrAUXcZY8A3CLbPdP3c69+F6HRI6+v6uUVEvIyt+Eyn1DCdK+ukwwTwC4aNU5w7btZco9i13wPOHVdEPJKKUkVEziHI34c7BrZk6a+G8acrO3D05Gkmv72W8a+sZOHOXBWnitSGwwHr3oY3hkNZIdw8E0Y+CT5+ZicTEREREZHwJLh5BuCAKROgONfsROJhck9pS2IREZELOdspVefKeggIhTajYccXYK9x3by52+Hwauh5O1gsrptXRMRL2YrVVbxeAsOh40TY+imcPuG8cVe8CBHNoMN4540pIh5LRakiIhcQ6OfD5H4tWPzYUP48oRMFJRXc8d46LntpOXO25WC3qzhV5JxOn4Bpk+GrX0DzAfCzFdBqmNmpRERERETkx2Iz4IZPoTQfpl4F5afMTiQeJKeonGB/H8IDfc2OIiIi0iTlFlUQ4u9DaIDOlfXSYQKU2uDgStfN+d1b4BNgbIssIiKNzlZUQaCflTCdK+uu991QVQYbP3DOeIe/g0Mroe99akAkIrWiolQRkVoI8PXhhj5pLHp0KH+/ujOlFdXcO3U94/61jC83H6NGxakiPzi8Fl4bDLtnw6g/wY2fQWi82alERERERORcUnvAdVMgbyd8fCNUlZudSDxETlE5ieGBWNRFTERE5Jxyi8vVJbUhMkYbWxNvn+ma+SqKYcsnRue54GjXzCki4uVsxRXEh+m6sl6SukBaP1j7OtRUN3y8lf+CwAjoPrnhY4mIV1BRqohIHfj5WLmmZzMWPDKEF67rSrXdzgMfbWT080uYufEI1TV2syOKmKemGpb8Hd4eCxbg9rkw4CGw6uOGiIiIiEiT1noEjH8NDiyDGXe6dgtU8Vi5p1RoIyIiciG2onJtR9wQ/iHQZoxRlFpd0fjzbZkGlSXQ847Gn0tERACwFZcTH6ZzZb31fwBOHoRt0xs2Tu522PkV9LoTAkKdk01EPJ6qRERE6sHXx8r4bin8f3t3Hid3Xd8P/DWzu7P3bpLNyZXEcF8CAoqgWLEEFREEqSIIeNSjFbWl2l/766/tr/Wota235VcFFa1KUakniBTlKFqL2iqCHAJVyMEmZDe7yR7Jzu+P2U2C7C4Js2F2N8/n47GPme/szHffgTzy3pnv6/P+fPvtJ+cj5x2dhrpi3v7F/8oL/v57+dLtwqnsgR59IPnUi5Mb/zo57Mzkjbck+xxb66oAAICddeTLk5XvSe78WvKNP0zKdgShOqt7B7K4UygVACaypnfQAo5qHX1+snl95XfY3alcTn74yWTRET73Bn
gKrd04aAFHNQ58YbLw0OTm91e3APk7f5k0dSQn/P7U1QbMekKpAFWoKxZy+pF75ZuXPCf/eP4z0lKqzx/+y3/l1H+4Kf/6k4eydcRFPGa5cjn5yT8nHz8pWfvz5GWfSM65vLJ9AwAAMLOc8ObkpLcnt1+RfPe9ta6GGaxcLmetoA0ATKhcLmdNr6niVXva85M5+yW3f2r3/px7v5OsvSN51hsTW0gDPGXW9g5mYbte+aQVi8nJ70i6705+8rknd45ffje557rkxLclLfOmsjpglhNKBZgCxWIhpx2+OF9/y0n5x/OfkVJ9MW/9wk+y8gM35Wv/9XBGhFOZjTatT/7lwuSaNyVLjkzedGtluhIAADBznfLnyVHnJ997b/LDT9S6Gmao9f1DGdo6ksUm2gDAuHo2D2dwy4gtiatVLCbHXJg8cHPSfe/u+zm3fCDp2Ds54tzd9zMAeIy+wS3pG9xiB45qHXpmss9xyb/9dTK4cddeu2WwspvO3OXJs968W8oDZi+hVIApNBZO/eYlz8lHzzsmhSRv+fyP88IP3pxv/XSVcCqzx303Jh9/dnLXNyoXrS/8WmVFOgAAMLMVCslLPpgceFryjUuTO66pdUXMQKt7B5LExUMAmIBeOYWOviAp1lem/e8Ov/qP5MFbKlsW15d2z88A4HHWjvbKRRY7VqdQSFa+J+lbm3znL3fttTe+O1l3b/Ki9ycNfmcBdo1QKsBuUCwW8uIjl+Tatz03H3rl0RkeGcmbPvejvPjDt+S6O1anXBZOZXpZu3EgA8Nbn/iJwwPJdX+aXHlmUmpLXndD8pw/SIp1u71GAADgKVJXn5xzRbLvM5Mvvz65/6ZaV8QMs2rDWNCmucaVAMD0tLpntFd2CHhUrX1RctCLkp/8c+Xz66l2yweSpjnJMa+e+nMDMKGxBRyL2vXKqu17XPLMNyQ//KfK4KGd8cvvJrd+sDKR/IAX7NbygNmpvtYFAMxmdcVCznj6XnnxEUvytf96OB+84Z684crbc/jeHXnbKQfmlEMWplAo1LpM9nAjI+U88903pFxOSvXFdDY3pKOpPp3NDZX7o7crRh7M6ff+ebr678mDK87LquP/JO3pSOejm9LR3JC2Un2KRX+fAQBgVii1JOd9Ibn8hcnnz0su/Gqy9zG1rooZYtXoxcO9TH8DgHFtC6XqlVPjuNcmd341+elVUxseXXNH8otvJCe/M2lsm7rzAvCE1vYOJkkW6ZVT45T/UwmkXn1x8vobk3nLJ37uI3cnV706WXBQctp7nroagVlFKBXgKVBXLOTMo/fO6UcuyTU/eTgfuuGevO4z/5kj9+nM23/7wDzvwAXCqdTMSLmcv3rp4enZPJzegeH0bh6u3N+8Jd19Q3lgbW/OHPhSXjFyVXrTkouH/yg33nF0csd/P+Y8xULS0dyQjqaGHQKt9aMh1+3h1rHbHcOvHc0NaagzwB0AAKaV5rnJBV9JLj81+dw5ycXXJgsOrHVVzACrNmxOfbGQrjbbLALAeMamvy00/W1qLD85WXxkZaLbUa+aup29vv1nlSmpz3zj1JwPgJ22ZmxSqqniU6PUmrzy88k/PT/59EuSV/9r0rXi8c9b9d+Vz4DqSsl5V1VeB/AkCKUCPIXq64o55xn75KVH7ZWv/OihfOjf7snFV/wwR+83J5eeelBO3H9+rUtkD1RfV8z5z1o6/je7702ueWPy6x+mfOgZafrt9+WvCp3p2SG4ui3EOjD22PC276/uHdh2f2jLyKR1tJbqMqellDktDaNfpcxpbsjcbY+NHrc2pLO5lLktlWBrvTArAADsPh1LkguuSS4/LbnyzOQ11yVz9q11VUxzq3sGsqijKXV20wCAca3pHcj8tlJK9T7bnBKFQnLS2yvT335+TXL42dWf897vJPfdkKx8d9Iyr/rzAbBLVvcOpLVUl7ZGsaYp07WiEka98qzk/z2vMj31qPMqwdNN65P//GRy0/uTlq7k/C8ncye4fg
ywE/zrDVADDXXFnHvcvjnz6L1z9e2/zof/7Z686hM/yIn7d+XSUw/K0fvNrXWJ7OlGRpL/uCz5zl8k9U3J2Z9M4fCz014opD3JPk/ir+jA8NbHBVjHgq09m4ezYdNwNmweSs+m4Ty6aSirenorj20aykh54vO2N9VnTkslvNo5QYh1TvP2x+a2VCa3Fl0cBQCAndO1Irngy8kVL65cuHjNtUmrRZVM7OGezVlii0UAmNCqnoEs1iun1qEvTRYcktz47uSQlyZ1VVwGH9lamZI6d3ly3OunrkYAdtra3sEs0iun3l5HJb97Y3LN7yXfvDS59o8rIdT+R5LySHLw6cmL/z5pX1TrSoEZTigVoIZK9cWc98z98rJj9s4//+B/8tEb781ZH/v3vOCQRbl05YE5eHFHrUtkT/ToA5U3Ig/ekhywMnnJByvTkarU1FCXpoa6LNzFbTZGRsrZOLhlW1h1w+ZKUHXD2PFocHXD5uE8umk4v1q/KY9uqgRfyxOEWQuFpLO5IXOaK0HVea3bv+a2lNLVWsrcHR6b11JKR3N9CgVBVgAA9lCLj0jO+2IllPrZlyUXfj1p8p6V8a3uGcgR+8ypdRkAMG2t7hnIPnOba13G7FKsS075s+QL5yU//kxy7Gue/Ll+9Olk7c+Tl386qS9NXY0A7LQ1vQNZ1C6UulvMXZZc9PXkf76f3Ht90rcm6dgnOeT0yuc/AFNAKBVgGmhqqMtrTlqe3zlu31xx6/257KZf5oUfvDlnPH2vvP0FB2bZ/NZal8ieoFxObv9U8u3/naSQnPGR5OjzKwnOGioWC+lsbkhnc0P262rZ6ddtHSmnd/PwaFh1+wTWsRDro5uGtwVc1/QO5K5VvVnXP5TBLSPjnq++WMicbYHVhscEVuf9Zoh1NNza1FA3Vf8ZAACg9paekJz7meQLr6xc7H/V1UmDC0Q8VrlczqqegZx6mL8bADCR1b0DOXaZHdOm3EEvSpaemNzwfyvTUlu7dv0cPQ8l1/95suw5lemrANTE6t6BHLtUr9xtCoXK5zxLT6h1JcAsJZQKMI20Ntbn959/QM5/1tJcdtMvc8Wt9+fr/70q5x67by45Zf8s6bRymt2k56Hkq29J7rshWX5y8tKPJHP2q3VVVakrFjJ3NCy6PDsX7C6Xy9k8vDXr+oby6KahrO9/7NeOj/1i9cas769MaJ1oImtrqS5zW3eYvLpDgLVrNLza1daYBW2N6WorpaVUZxorAADT24GnJmddlnzpdcnVFyfnXlnd1qjMOo9uGs7glpEssc0iAIxrYHhrNmwazuJd3FGKnVAoJC96f3LZc5Lr/iR52WW79vqRkeRffy8Z2ZKc8aGaD2wA2FOVy+Ws7R3MIu8rAWYsnxgDTENzWkp552kH5+ITl+VjN96Xz/3gwXzpR7/Oq5+1NG963op0tTXWukRmk42rk4+dkIwMVz6wO/a1SbFY66pqolAopKVUn5Z59dl33s5NZd06Uk7P5uGs7x/M+v7tt49uGtoWbl3XX7l/z5q+rO8fyubhreOeq6mhmK7WxsxvK2X+aFC1q60xXa2V4+2PVQKu9XV75v8nAABq7Ihzks2PJt+8NPnq7ycv/dge+x6Cx3t4w+YkEUoFgAms6R1IkiwSSt09Fh2anPQHyU3vqyyoOvzsnX/tze9PfnljcvoHknlP220lAjC5DZuGM7R1JIva9UqAmUooFWAaW9jelL8447C89qTl+dAN9+TyW+/P5//jf/Lak5bndc99WjqaGmpdIrNB++LkuX+YHHx60rWi1tXMOHXFQuaNTj3dWZuHtmb9pqGs7xtKd/9gujcOjgZXB7Oubyjd/UNZ1TOQnz3ck3V9Q9ky8vhRrIVCMqe54THh1fmjk1fHHpvfVkpXa+V+W2O9KawAAEyd419fCabe+K6keW6y8t0mSZEkWd1TCdrY7QUAxrdKr9z9Tn5HJVz61UuSrgOSJUc+8Wt+enXld9sjzk2ecdFuLxGAia
22gANgxhNKBZgB9p3Xkr99+dPzhpNX5B+uvzsf+rd78+nbHsybnrciF56wLM2lulqXyEx34ltrXcEepblUl71Lzdl7zhN/8Fwul9O7ecvjwqvdfUNZ1z+Y7o2V2zsf7k1332B6B7aMe57G+uL2AOsO4dWxqazz2xqzoL1yPLellGJRoAAAgCfw3D9KNq1Pvv+xpHlecvIf1boipoFVPSalAsBkxialLu60I9puU9eQnHtl8olTks+9PHn1NcnCQyZ+/k+vTr7yhmTpickZH7bYCqDG9EqAmU8oFWAG2X9hWz76qmPypod68nff/kXe+6278slb7s9bTzkgv3PcvmmwlTfMOoVCIZ0tDelsaciKBW1P+PzBLVuzvn+oMnF1bPJqXyXMOnb8SN9g7ly1Mev6BzO89fFTWOuKhXS1lkZDqo2/cVt5fMHocWdzgwmsAAB7qkKhMiF1YENy418nLXOT415X66qosVU9A6kvFjK/zcVDABjP2KTUxSal7l4dS5Lzv5R85szk8pXJ6f+QHPayxwZOh/qTG9+d3PaRSiD1lZ9PGiysAai1tb2DSSq7igIwMwmlAsxAh+/dmSsuPj4/fGB93nftXfnf1/wsl99yfy5deVBeePhiATHYgzXW12VJZ/NObf81NoX1kb7BdPcN5pGNj73t7hvKIxsHc/eajenuGz/A2lBXeNyk1fHDrI3paKr37xMAwGxTLCZnfCQZ6Em+cWnSNCc54pxaV0UNreoZyKKOJrsvAMAEVvcMpK2xPm2NLtPudgsPSV777eRfLkqufk1y6weTg15UmfLffXdyx5eTTeuSY1+brHxX0iAoDDAdrB6dlLqww2JHgJnKux2AGey4ZfNy1RtOyA13rs3fXHtX3vy5H+Xp+87JH592cE5Y0VXr8oBpbscprPsvnHwKa7lcTs/m4TyycTCPbAuuDj0myLqmdyA/e6gn6/qHsnXk8QHWUn0xC9oaM7+9MQsmCa8uaG9Ma6lOgBUAYKaoq0/OuSL57NmVbU8bO5IDT611VdTIwxs2Z685ptkAwETW9A5kkZDNU2fu0uS11yc/vjL5z08m331vknLS0JI87beSE9+a7PfMWlcJwA7W9A5kXmspjfV1tS4FgCdJKBVghisUCnnBoYvyWwcvzJd+9Ov8w/V355X/9P0876AFeedpB+eQJR21LhGYBQqFQua0lDKnpZQDFrVP+tyRkXIe3TT0uNDq2O0jfYN5aMNAfvKrnqzvH8w4+dU0NRS3h1S3BVm33y5oL2VBW1Pmt5fSUvIrLQBAzTU0VbY7/fTpyVWvTi74SrL0hFpXRQ2s7h3IkfvMqXUZADBtreoZ2KldjphCdfXJsRdXvob6k+HNSVNnUtdQ68oAGMea3sEsbLeAA2AmcwUfYJaoKxZy7rH75oyn75XP3PZAPnrjfXnRh27OWUftnbf/9oHZd15LrUsE9hDFYiFdbY3pamvMQYsnD7BuHSlnff/Q44Kr3TtMY31w3abc/uCjWb9pKOVxAqytpbrtodW2xszfIbA6Nnl17HvNJatqAQB2m6aO5FVfSq44Lfnn30ku/kay+IhaV8VTqFwuZ1XPQE47zKRUAJjImt6BrFgxv9Zl7LlKrZUvAKatylRx7ysBZjKhVIBZpqmhLr/73BX5nWP3y8e/d1+uuPX+fP2/V+WCE5bm935r/8xrLdW6RIBt6oqFSmi0vTGHLJn8uVu2jmR9/1DWjoZWx5vEet8jffnB/YN5dNPwuOdoLdVtm8A6FlgdC7IKsAIATIG2BckF1ySXr0yufFnymmuTrhW1roqnyPr+oQxtGcniThcPAWA8W0fKWbtxMEv0SgCY0JregRxqN1CAGU0oFWCW6mxpyB+/8OBc+Oyl+cD19+SKW+/PVT/8Vd74vBW5+MRltrsGZpz6umIWdjRl4U6sjh3eOpJ1fdsnsD7SNxpk3ThUub9xLMC6bsIAa1tjfea3lQRYAQB21Zx9K8HUK05LrjwzufjapH
PvWlfFU2BVz0CS2JIYACbQ3TeYrSPlLBJKBYBxbdk6ku6+wSzqaKx1KQBUQSIJYJZb0tmcvznnyLz2Ocvzvmt/kb+97hf59L8/kLe94MCce+w+qa8r1rpEgCnXUFfM4s6mnZrQNLSlMoH1iQKs379/XTY8iQDrgrbGzBdgBQD2NAsOTM7/cvLpl1SCqRd9szJFlVlteyhV0AYAxrN6tFcutiUxAIyru28oI+Xs1IASAKYvoVSAPcSBi9rziQuPzQ8fWJ/3fuuu/MlXfppP3PLLvGPlwVl52KIUCoValwhQE6X66gOsj2wcTHff0C4HWLeFVwVYAYDZaK+jkvOuSq48K/nsWcmFX0+a59S6KnajVT2bkyRL5rh4CADjsYADACa3ptcCDoDZQCgVYA9z3LJ5ufqNJ+T6n6/J+677Rd742dtzzH5z8qcvPiTPWDqv1uUBTGtPJsBaCayOH2C9d21fbvvlrgdYK/dLAqwAwPS39ITkFZ9LPv+K5HMvTy74StLYVuuq2E1W9Qykoa6Q+a22WQSA8YwFbRYJ2gDAuFbrlQCzglAqwB6oUCjk1MMW5/kHL8zVt/86f3/93Tn747flRUcszjtPOzhLu1prXSLAjFerAOtjpq+OTmDdflxKV1tjWkt1JmQDAE+d/U9Jzrk8uerC5AvnVaanNri4NBut2rA5izqaUiz6XRMAxrO6t7KAo6u1VOtSAGBaWrstlGqxI8BMJpQKsAerryvmFcfvlzOO2iv/dNP9ueym+3L9z9fkgmcty1uev3/m+mAM4Ckx1QHWe54gwNrUUExX6/aQ6thtV+v2AGtXWyldbaXMaymlvq441X9kAGBPc8hLkjM/lnzlDcnVFyfnfiapa6h1VUyxVT0DtiMGgEms7hnIwnYLOABgImt6B1NXLKSrTSgVYCYTSgUgLaX6vPUFB+SVx++bv7/+7nzq3+/P1bf/Km95/gF59bOXprHeltAA08WuBljX9Q9mXd9Quvsqt+v6R8Oro8drNw7k5w/3Zl3/YIa3lh93jkIhmdtSSldrJaS6LbTaWsr89sbRxyvB1vltjWkxhRUAmMjTX5EMbky+eWlyzZuSsy5Lit5vziaregZy1L5zal0GAExbq3sGduozHQDYU63uHciCtsbUWcABMKMJpQKwzcKOprz37CNz8YnL855v3Zl3ffPOfPq2B/KO0w7OS45cImQEMMOU6otZ0tmcJZ3NT/jccrmc3oEt28OroxNYu0eDrGPB1p8/3JvuvsH0DmwZ9zzbprC2N2b+DkHWbRNZWxszv71yO7elwRRWANjTHP/6ZKgv+c5fJKXW5PQPVFbBMOOVy+Ws7hnIksMFbQBgIqt7B3LoXh21LgMApq01vQNZZAEHwIwnlArA4xy0uD2fuvj43HzPI3nXN+7MJZ//cT55y/350xcdkuOXz6t1eQDsBoVCIZ3NDelsbsiKBU/8/MEtW7O+f2hbWLV7NMi6rn8o3RsH090/lNW9A7njCaawzmupBFe7Wht3mMQ6Nn119LHR77U2evsCALPCSW+vTEy9+e+SUlty6l8Lps4C6/qHMrR1JEtcPASAcY0t4Hj+wQtrXQoATFtreweztKul1mUAUCVXdQGY0HMOWJBvXDI/X/7Rr/P+b/8i5152W1YetijvPO3gPG1BW63LA6CGGuvrdm0K6+Yt6e4fTPfGSnB1Xd9gHhkLso5OY71jdArrxgmmsDY31FUCrG2NWfAbQdbfvJ3bUrK9DwBMZ8//s0ow9baPJE2dycnvqHVFVGnVhoEkyeKd+P0QAPZEvQNbsnl4axZ3WMABABNZ3TtgSBLALCCUCsCk6oqFvPzYfXP6kXvlEzf/Mv/4vftyw5035fxnLc0lpxyQea2lWpcIwDRXKBTS2dKQzpaGrNiJRQ2DW7ZWgqp9Q+nuH9w2jXUswPpI32Ae3jCQnz7Uk3V9Q9kyMvEU1rGgaldbY7paS1nQXrnt+o0prC2luhRMaAOAp06hkJz2N8
lgX3LjuyoTU094c62rogqrejYnSfaaI2gDAONZ3TO2gEOvBIDxDAxvTc/mYb0SYBYQSgVgpzSX6vKWUw7IK47fLx/4zt35zG0P5Eu3/zpv/q39c/GJy9LUUFfrEgGYJRrr67LXnObsNWfnprD2bB5O9+jU1e7RqavbjytB1p891JPujYPZODj+FNamhmK6Whszf4cAa1fb2PFjp7LObSmlVF+c6j82AOx5isXkjA8nQ33Jdf8raWxLjnl1raviSVolaAMAk1rdq1cCwGTW9g4mSRa2N9a4EgCqJZQKwC5Z0N6Yd511RC569rK851t35W+uvSuf/f6D+aOVB+WMp++Voq2SAXgKFQqFzGkpZU5LKfsvfOIprAPDW7O+/7FTWNf1DWZd/9C2AOvajQO5c1Vv1vUNZWjryLjn6WxueMyk1bHg6nih1o6mBv0RACZSV5+c/YnkC+clX70kKbUmh59d66p4Elb1DKShrpD5rS4eAsB4Vo9OFV/cIZQKAOMZW8CxSK8EmPGEUgF4Ug5Y1J7LLzou/35vd971zTvzti/+JFfcen/+7tynZ/+F7bUuDwDG1dSwa1NYNw5u2RZcHZvCuu24v3J779q+/OD+oTy6aSjl8uPPU18sZN4OIdU3P2//nLCiazf86QBghqpvTM69Mvns2cmXfzcptSUHrqx1VeyiVT2bs6ijyWIcAJjA6p7K9DdBGwAY3xpTxQFmDaFUAKry7P3n52u/f1K+8uOH8vHv3ZfO5lKtSwKAKVEoFNLR1JCOpoYsn9/6hM/fsnUkj24a3hZcHZu8ui3IOhpiLY+XXAWAPV2pJTnvi8lnX5YM9NS6Gp6Ehe2NOX75vFqXAQDTVmdzfZ65fF5K9cValwIA01KpvpjD9+7IonahVICZTigVgKoVi4Wc/Yx98rJj9k6hYCIKAHum+rpiFrQ3ZkG7LWsB4Elp6khec11SrKt1JTwJf/riQ2tdAgBMaxeduDwXnbi81mUAwLS18rDFWXnY4lqXAcAUsBQPgCkjkAoAAEBVBFIBAAAAAGY0oVQAAAAAAAAAAAAAqiaUCgAAAAAAAAAAAEDVhFIBAAAAAAAAAAAAqJpQKgAAAAAAAAAAAABVE0oFAAAAAAAAAAAAoGpCqQAAAAAAAAAAAABUTSgVAAAAAAAAAAAAgKoJpQIAAAAAAAAAAABQNaFUAAAAAAAAAAAAAKomlAoAAAAAAAAAAABA1YRSAQAAAAAAAAAAAKiaUCoAAAAAAAAAAAAAVRNKBQAAAAAAAAAAAKBqQqkAAAAAAAAAAAAAVE0oFQAAAAAAAAAAAICqCaUCAAAAAAAAAAAAUDWhVAAAAAAAAAAAAACqViiXy7WuYVKFQuGRJA/Wug4AnlJLy+XygloXMVPolQB7JL1yF+mXAHscvXIX6ZUAexy9chfplQB7HL1yF+mVAHukcfvltA+lAgAAAAAAAAAAADD9FWtdAAAAAAAAAAAAAAAzn1AqAAAAAAAAAAAAAFUTSgUAAAAAAAAAAACgavW1LgDYrlAoLEty0ejhd8vl8ndrVgwATFP6JQBMTq8EgMnplQAwOb0SACanV8LkhFJhelmW5M93OP5ubcoAgGltWfRLAJjMsuiVADCZZdErAWAyy6JXAsBklkWvhAkVa10AAAAAAAAAAAAAADOfUCoAAAAAAAAAAAAAVRNKBQAAAAAAAAAAAKBqhXK5XOsaYI9XKBSel+TGnXluuVwu7NZiAGCa0i8BYHJ6JQBMTq8EgMnplQAwOb0Sdo5JqQAAAAAAAAAAAABUrb7WBQBJkp8lOSvJ4Un+avSxLyb5Qs0qAoDpR78EgMnplQAwOb0SACanVwLA5PRK2AlCqTANlMvl7iTXFAqFDTs8fFe5XL6mNhUBwPSjXwLA5PRKAJicXgkAk9MrAWByeiXsnGKtCwAAAAAAAAAAAABg5hNKBQAAAAAAAAAAAKBqQqkAAAAAAAAAAAAAVE0oFQAAAAAAAAAAAICqCaUCAA
AAAAAAAAAAUDWhVAAAAAAAAAAAAACqJpQKAAAAAAAAAAAAQNWEUmF6GdnhfqFmVQDA9KZfAsDk9EoAmJxeCQCT0ysBYHJ6JUxCKBWml74d7rfWrAoAmN70SwCYnF4JAJPTKwFgcnolAExOr4RJCKXC9HL/DvePqVkVADC96ZcAMDm9EgAmp1cCwOT0SgCYnF4JkyiUy+Va1wDsoFAo/CjJ0aOHlyW5IcnGse+Xy+Vra1EXAEwn+iUATE6vBIDJ6ZUAMDm9EgAmp1fCxIRSYZopFAovTPK1JHXjfb9cLhee2ooAYPrRLwFgcnolAExOrwSAyemVADA5vRImJpQK01ChUHhmkkuSnJBkcZLmse9pWgBQoV8CwOT0SgCYnF4JAJPTKwFgcnoljE8oFQAAAAAAAAAAAICqFWtdAAAAAAAAAAAAAAAzn1AqAAAAAAAAAAAAAFUTSgUAAAAAAAAAAACgakKpAAAAAAAAAAAAAFRNKBUAAAAAAAAAAACAqgmlAgAAAAAAAAAAAFA1oVQAAAAAAAAAAAAAqiaUCgAAAAAAAAAAAEDVhFIBAAAAAAAAAAAAqJpQKgAAAAAAAAAAAABVE0oFAAAAAAAAAAAAoGpCqQAAAAAAAAAAAABU7f8DL5isCGjhdh0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "main()" ] } ], "metadata": { "kernelspec": { "display_name": "jax0227", "language": "python", "name": "jax0227" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }