{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Text classification with a transformer model\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_transformer_text_classification.ipynb)\n", "\n", "In this short tutorial, we will perform sentiment analysis on movie reviews. We want to know\n", "whether a review is positive or negative overall, which makes this a binary classification task.\n", "\n", "For that purpose, we're going to build a model composed of a transformer block and dense layers.\n", "What sets transformers apart is their attention mechanism, which lets the model focus on\n", "the most relevant parts of the input sequence when processing information.\n", "\n", "We will follow the standard ML model-building process: split the data into training and test sets,\n", "construct the model, run training, and report metrics." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Required packages\n", "# !pip install -U jax flax optax\n", "# !pip install -U grain tqdm requests matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tools overview\n", "\n", "Here's a list of the key packages we will use in this example, most of which belong to the JAX AI stack:\n", "\n", "- [JAX](https://github.com/jax-ml/jax) for array computations.\n", "- [Flax](https://github.com/google/flax) for constructing neural networks.\n", "- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.\n", "- [Grain](https://github.com/google/grain/) for defining data sources and loading data.\n", "- [tqdm](https://tqdm.github.io/) for a progress bar to monitor training progress." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import io\n", "import json\n", "import textwrap\n", "import typing\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from flax import nnx\n", "import optax\n", "import grain.python as grain\n", "import tqdm\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import requests" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We're going to use the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).\n", "Each review is encoded as a list of word indices, where words are indexed by their overall\n", "frequency in the dataset (lower indices correspond to more frequent words). Each label is the\n", "review's sentiment: 0 for negative and 1 for positive." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def prepare_imdb_dataset(num_words: int, index_from: int, oov_char: int = 2) -> tuple:\n", "    response = requests.get(\"https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz\")\n", "    response.raise_for_status()\n", "    with np.load(io.BytesIO(response.content), allow_pickle=True) as f:\n", "        x_train, y_train = f[\"x_train\"], f[\"y_train\"]\n", "        x_test, y_test = f[\"x_test\"], f[\"y_test\"]\n", "\n", "    # Shuffle both splits with a fixed seed for reproducibility\n", "    rng = np.random.RandomState(113)\n", "    indices = np.arange(len(x_train))\n", "    rng.shuffle(indices)\n", "    x_train = x_train[indices]\n", "    y_train = y_train[indices]\n", "\n", "    indices = np.arange(len(x_test))\n", "    rng.shuffle(indices)\n", "    x_test = x_test[indices]\n", "    y_test = y_test[indices]\n", "\n", "    # Shift word indices by `index_from` so the lowest indices stay free for special\n", "    # tokens (0 for padding, `oov_char` for out-of-vocabulary words)\n", "    x_train = [[w + index_from for w in x] for x in x_train]\n", "    x_test = [[w + index_from for w in x] for x in x_test]\n", "\n", "    xs = x_train + x_test\n", "    labels = np.concatenate([y_train, y_test])\n", "    # Replace words outside the `num_words` vocabulary with `oov_char`\n", "    xs = [\n", "        [w if w < num_words else oov_char for w in x] for x in xs\n", "    ]\n", "\n", "    idx = len(x_train)\n", "    x_train, y_train = np.array(xs[:idx], dtype=\"object\"), labels[:idx]\n", "    x_test, y_test = np.array(xs[idx:], dtype=\"object\"), 
labels[idx:]\n", "\n", "    return (x_train, y_train), (x_test, y_test)\n", "\n", "\n", "def pad_sequences(arrs: typing.Iterable, max_len: int) -> np.ndarray:\n", "    # Make every sample exactly `max_len` tokens long: shorter sequences are\n", "    # left-padded with 0, longer ones keep only their last `max_len` tokens.\n", "    result = []\n", "    for arr in arrs:\n", "        arr_len = len(arr)\n", "        if arr_len < max_len:\n", "            padded_arr = np.pad(arr, (max_len - arr_len, 0), 'constant', constant_values=0)\n", "        else:\n", "            padded_arr = np.array(arr[arr_len - max_len:])\n", "        result.append(padded_arr)\n", "\n", "    return np.asarray(result)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "25000 Training sequences\n", "25000 Validation sequences\n" ] } ], "source": [ "index_from = 3  # Shift word indices so that 0 is free to encode the padding token\n", "vocab_size = 20000  # Only consider the top 20k words; rarer words become out-of-vocabulary\n", "maxlen = 200  # Only consider the last 200 words of each movie review\n", "(x_train, y_train), (x_test, y_test) = prepare_imdb_dataset(num_words=vocab_size, index_from=index_from)\n", "print(len(x_train), \"Training sequences\")\n", "print(len(x_test), \"Validation sequences\")\n", "x_train = pad_sequences(x_train, max_len=maxlen)\n", "x_test = pad_sequences(x_test, max_len=maxlen)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To handle input data, we're going to use Grain, a pure Python package developed for JAX and\n", "Flax models. Grain supports custom setups where data sources come in different forms, but\n", "they all need to implement the `grain.RandomAccessDataSource` interface. 
See\n", "[PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/source/data_sources.md)\n", "for more details.\n", "\n", "Our dataset consists of relatively small NumPy arrays, so our `DataSource` is simple:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class DataSource(grain.RandomAccessDataSource):\n", "    def __init__(self, x, y):\n", "        self._x = x\n", "        self._y = y\n", "\n", "    def __getitem__(self, idx):\n", "        return {\"encoded_indices\": self._x[idx], \"label\": self._y[idx]}\n", "\n", "    def __len__(self):\n", "        return len(self._x)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "train_source = DataSource(x_train, y_train)\n", "test_source = DataSource(x_test, y_test)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "seed = 12\n", "train_batch_size = 128\n", "test_batch_size = 2 * train_batch_size\n", "\n", "train_sampler = grain.IndexSampler(\n", "    len(train_source),\n", "    shuffle=True,\n", "    seed=seed,\n", "    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup\n", "    num_epochs=1,  # Iterate over the dataset for one epoch\n", ")\n", "\n", "test_sampler = grain.IndexSampler(\n", "    len(test_source),\n", "    shuffle=False,\n", "    seed=seed,\n", "    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup\n", "    num_epochs=1,  # Iterate over the dataset for one epoch\n", ")\n", "\n", "\n", "train_loader = grain.DataLoader(\n", "    data_source=train_source,\n", "    sampler=train_sampler,  # Sampler that determines the order in which records are read\n", "    worker_count=4,  # Number of child processes used to parallelize data loading\n", "    worker_buffer_size=2,  # Number of output batches each worker prepares in advance\n", "    operations=[\n", "        grain.Batch(train_batch_size, drop_remainder=True),\n", "    ]\n", ")\n", "\n", "test_loader = grain.DataLoader(\n", 
data_source=test_source,\n", "    sampler=test_sampler,  # Sampler that determines the order in which records are read\n", "    worker_count=4,  # Number of child processes used to parallelize data loading\n", "    worker_buffer_size=2,  # Number of output batches each worker prepares in advance\n", "    operations=[\n", "        grain.Batch(test_batch_size),\n", "    ]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model\n", "\n", "Here we construct the model: token and position embeddings, a transformer block, and dense layers on top:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class TransformerBlock(nnx.Module):\n", "    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, rngs: nnx.Rngs, rate: float = 0.1):\n", "        self.attention = nnx.MultiHeadAttention(\n", "            num_heads=num_heads, in_features=embed_dim, qkv_features=embed_dim, decode=False, rngs=rngs\n", "        )\n", "\n", "        # Feed-forward network: expand to `ff_dim`, then project back to `embed_dim`\n", "        # so that its output can be added to the residual connection.\n", "        self.dense_1 = nnx.Linear(in_features=embed_dim, out_features=ff_dim, rngs=rngs)\n", "        self.dense_2 = nnx.Linear(in_features=ff_dim, out_features=embed_dim, rngs=rngs)\n", "\n", "        self.layer_norm_1 = nnx.LayerNorm(num_features=embed_dim, epsilon=1e-6, rngs=rngs)\n", "        self.layer_norm_2 = nnx.LayerNorm(num_features=embed_dim, epsilon=1e-6, rngs=rngs)\n", "\n", "        self.dropout_1 = nnx.Dropout(rate, rngs=rngs)\n", "        self.dropout_2 = nnx.Dropout(rate, rngs=rngs)\n", "\n", "    def __call__(self, inputs: jax.Array):\n", "        # Self-attention sub-layer with a residual connection and layer normalization\n", "        x = self.attention(inputs, inputs)\n", "        x = self.dropout_1(x)\n", "        x_norm_1 = self.layer_norm_1(inputs + x)\n", "        # Feed-forward sub-layer with a residual connection and layer normalization\n", "        x = self.dense_1(x_norm_1)\n", "        x = jax.nn.relu(x)\n", "        x = self.dense_2(x)\n", "        x = self.dropout_2(x)\n", "        x = self.layer_norm_2(x_norm_1 + x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class TokenAndPositionEmbedding(nnx.Module):\n", "    def __init__(self, max_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs):\n", "        self.token_emb = nnx.Embed(num_embeddings=vocab_size, 
features=embed_dim, rngs=rngs)\n", "        self.pos_emb = nnx.Embed(num_embeddings=max_length, features=embed_dim, rngs=rngs)\n", "\n", "    def __call__(self, x: jax.Array):\n", "        maxlen = jnp.shape(x)[-1]\n", "        positions = jnp.arange(start=0, stop=maxlen, step=1)\n", "        positions = self.pos_emb(positions)\n", "        x = self.token_emb(x)\n", "        return x + positions" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "embed_dim = 32  # Embedding size for each token\n", "num_heads = 2  # Number of attention heads\n", "ff_dim = 32  # Hidden layer size in the feed-forward network inside the transformer block\n", "\n", "class MyModel(nnx.Module):\n", "    def __init__(self, rngs: nnx.Rngs):\n", "        self.embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim, rngs=rngs)\n", "        self.transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim, rngs=rngs)\n", "        self.dropout1 = nnx.Dropout(0.1, rngs=rngs)\n", "        self.dense1 = nnx.Linear(in_features=embed_dim, out_features=20, rngs=rngs)\n", "        self.dropout2 = nnx.Dropout(0.1, rngs=rngs)\n", "        self.dense2 = nnx.Linear(in_features=20, out_features=2, rngs=rngs)\n", "\n", "    def __call__(self, x: jax.Array):\n", "        x = self.embedding_layer(x)\n", "        x = self.transformer_block(x)\n", "        x = jnp.mean(x, axis=(1,))  # global average pooling over the sequence\n", "        x = self.dropout1(x)\n", "        x = self.dense1(x)\n", "        x = jax.nn.relu(x)\n", "        x = self.dropout2(x)\n", "        x = self.dense2(x)\n", "        # Return raw (unnormalized) logits: the cross-entropy loss applies softmax internally\n", "        return x" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MyModel(\n", " embedding_layer=TokenAndPositionEmbedding(\n", " token_emb=Embed(\n", " embedding=Param(\n", " value=Array(shape=(20000, 32), dtype=float32)\n", " ),\n", " num_embeddings=20000,\n", " features=32,\n", " dtype=dtype('float32'),\n", " param_dtype=,\n", " embedding_init=.init at 0x7f0cab141e10>\n", " ),\n", " pos_emb=Embed(\n", " embedding=Param(\n", " 
value=Array(shape=(200, 32), dtype=float32)\n", " ),\n", " num_embeddings=200,\n", " features=32,\n", " dtype=dtype('float32'),\n", " param_dtype=,\n", " embedding_init=.init at 0x7f0cab141e10>\n", " )\n", " ),\n", " transformer_block=TransformerBlock(\n", " attention=MultiHeadAttention(\n", " num_heads=2,\n", " in_features=32,\n", " qkv_features=32,\n", " out_features=32,\n", " dtype=None,\n", " param_dtype=,\n", " broadcast_dropout=True,\n", " dropout_rate=0.0,\n", " deterministic=None,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " out_kernel_init=None,\n", " bias_init=,\n", " out_bias_init=None,\n", " use_bias=True,\n", " attention_fn=,\n", " decode=False,\n", " normalize_qk=False,\n", " qkv_dot_general=None,\n", " out_dot_general=None,\n", " qkv_dot_general_cls=None,\n", " out_dot_general_cls=None,\n", " head_dim=16,\n", " query=LinearGeneral(\n", " in_features=(32,),\n", " out_features=(2, 16),\n", " axis=(-1,),\n", " batch_axis=FrozenDict({}),\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " precision=None,\n", " dot_general=None,\n", " dot_general_cls=None,\n", " kernel=Param(\n", " value=Array(shape=(32, 2, 16), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(2, 16), dtype=float32)\n", " )\n", " ),\n", " key=LinearGeneral(\n", " in_features=(32,),\n", " out_features=(2, 16),\n", " axis=(-1,),\n", " batch_axis=FrozenDict({}),\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " precision=None,\n", " dot_general=None,\n", " dot_general_cls=None,\n", " kernel=Param(\n", " value=Array(shape=(32, 2, 16), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(2, 16), dtype=float32)\n", " )\n", " ),\n", " value=LinearGeneral(\n", " in_features=(32,),\n", " out_features=(2, 16),\n", " axis=(-1,),\n", " batch_axis=FrozenDict({}),\n", " use_bias=True,\n", " 
dtype=None,\n", " param_dtype=,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " precision=None,\n", " dot_general=None,\n", " dot_general_cls=None,\n", " kernel=Param(\n", " value=Array(shape=(32, 2, 16), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(2, 16), dtype=float32)\n", " )\n", " ),\n", " query_ln=None,\n", " key_ln=None,\n", " out=LinearGeneral(\n", " in_features=(2, 16),\n", " out_features=(32,),\n", " axis=(-2, -1),\n", " batch_axis=FrozenDict({}),\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " precision=None,\n", " dot_general=None,\n", " dot_general_cls=None,\n", " kernel=Param(\n", " value=Array(shape=(2, 16, 32), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " )\n", " ),\n", " rngs=None,\n", " cached_key=None,\n", " cached_value=None,\n", " cache_index=None\n", " ),\n", " dense_1=Linear(\n", " kernel=Param(\n", " value=Array(shape=(32, 32), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " in_features=32,\n", " out_features=32,\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " dot_general=\n", " ),\n", " dense_2=Linear(\n", " kernel=Param(\n", " value=Array(shape=(32, 32), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " in_features=32,\n", " out_features=32,\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " dot_general=\n", " ),\n", " layer_norm_1=LayerNorm(\n", " scale=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " num_features=32,\n", " epsilon=1e-06,\n", " dtype=None,\n", " param_dtype=,\n", 
" use_bias=True,\n", " use_scale=True,\n", " bias_init=,\n", " scale_init=,\n", " reduction_axes=-1,\n", " feature_axes=-1,\n", " axis_name=None,\n", " axis_index_groups=None,\n", " use_fast_variance=True\n", " ),\n", " layer_norm_2=LayerNorm(\n", " scale=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(32,), dtype=float32)\n", " ),\n", " num_features=32,\n", " epsilon=1e-06,\n", " dtype=None,\n", " param_dtype=,\n", " use_bias=True,\n", " use_scale=True,\n", " bias_init=,\n", " scale_init=,\n", " reduction_axes=-1,\n", " feature_axes=-1,\n", " axis_name=None,\n", " axis_index_groups=None,\n", " use_fast_variance=True\n", " ),\n", " dropout_1=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(\n", " default=RngStream(\n", " key=RngKey(\n", " value=Array((), dtype=key) overlaying:\n", " [0 0],\n", " tag='default'\n", " ),\n", " count=RngCount(\n", " value=Array(22, dtype=uint32),\n", " tag='default'\n", " )\n", " )\n", " )),\n", " dropout_2=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(...))\n", " ),\n", " dropout1=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(...)),\n", " dense1=Linear(\n", " kernel=Param(\n", " value=Array(shape=(32, 20), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(20,), dtype=float32)\n", " ),\n", " in_features=32,\n", " out_features=20,\n", " use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " dot_general=\n", " ),\n", " dropout2=Dropout(rate=0.1, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(...)),\n", " dense2=Linear(\n", " kernel=Param(\n", " value=Array(shape=(20, 2), dtype=float32)\n", " ),\n", " bias=Param(\n", " value=Array(shape=(2,), dtype=float32)\n", " ),\n", " in_features=20,\n", " out_features=2,\n", 
" use_bias=True,\n", " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", " kernel_init=.init at 0x7f0cab141b40>,\n", " bias_init=,\n", " dot_general=\n", " )\n", ")\n" ] } ], "source": [ "model = MyModel(rngs=nnx.Rngs(0))\n", "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n", "\n", "To train our model, we construct an `nnx.Optimizer` object from our model and a selected\n", "optimization algorithm. We're going to use the Adam optimizer, a popular choice for\n", "deep learning models. Adam adapts the step size of each parameter using running averages of\n", "the gradients and of their squares; the decay rate of the first average (`b1`) acts as a\n", "momentum term that helps accelerate convergence." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "num_epochs = 10\n", "learning_rate = 0.0001\n", "momentum = 0.9\n", "\n", "optimizer = nnx.Optimizer(model, optax.adam(learning_rate, b1=momentum))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def compute_losses_and_logits(model: nnx.Module, batch_tokens: jax.Array, labels: jax.Array):\n", "    logits = model(batch_tokens)\n", "\n", "    loss = optax.softmax_cross_entropy_with_integer_labels(\n", "        logits=logits, labels=labels\n", "    ).mean()\n", "    return loss, logits" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "@nnx.jit\n", "def train_step(\n", "    model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, jax.Array]\n", "):\n", "    batch_tokens = jnp.array(batch[\"encoded_indices\"])\n", "    labels = jnp.array(batch[\"label\"], dtype=jnp.int32)\n", "\n", "    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)\n", "    (loss, logits), grads = grad_fn(model, batch_tokens, labels)\n", "\n", "    optimizer.update(grads)  # In-place updates.\n", "\n", "    return loss\n", "\n", "@nnx.jit\n", "def eval_step(\n", "    model: nnx.Module, batch: dict[str, jax.Array], eval_metrics: nnx.MultiMetric\n", "):\n", "    batch_tokens 
= jnp.array(batch[\"encoded_indices\"])\n", "    labels = jnp.array(batch[\"label\"], dtype=jnp.int32)\n", "    loss, logits = compute_losses_and_logits(model, batch_tokens, labels)\n", "\n", "    eval_metrics.update(\n", "        loss=loss,\n", "        logits=logits,\n", "        labels=labels,\n", "    )" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "eval_metrics = nnx.MultiMetric(\n", "    loss=nnx.metrics.Average('loss'),\n", "    accuracy=nnx.metrics.Accuracy(),\n", ")\n", "\n", "train_metrics_history = {\n", "    \"train_loss\": [],\n", "}\n", "\n", "eval_metrics_history = {\n", "    \"test_loss\": [],\n", "    \"test_accuracy\": [],\n", "}" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "bar_format = \"{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]\"\n", "train_total_steps = len(x_train) // train_batch_size\n", "\n", "def train_one_epoch(epoch):\n", "    model.train()  # Enable dropout\n", "    with tqdm.tqdm(\n", "        desc=f\"[train] epoch: {epoch}/{num_epochs}, \",\n", "        total=train_total_steps,\n", "        bar_format=bar_format,\n", "        leave=True,\n", "    ) as pbar:\n", "        for batch in train_loader:\n", "            loss = train_step(model, optimizer, batch)\n", "            train_metrics_history[\"train_loss\"].append(loss.item())\n", "            pbar.set_postfix({\"loss\": loss.item()})\n", "            pbar.update(1)\n", "\n", "\n", "def evaluate_model(epoch):\n", "    # Compute the loss and accuracy on the test set after each training epoch.\n", "    model.eval()  # Disable dropout\n", "\n", "    eval_metrics.reset()  # Reset the eval metrics\n", "    for test_batch in test_loader:\n", "        eval_step(model, test_batch, eval_metrics)\n", "\n", "    for metric, value in eval_metrics.compute().items():\n", "        eval_metrics_history[f'test_{metric}'].append(value)\n", "\n", "    print(f\"[test] epoch: {epoch + 1}/{num_epochs}\")\n", "    print(f\"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}\")\n", "    print(f\"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}\")" ] }, { "cell_type": "code", 
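"execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sanity check: a minimal sketch that is not part of the original tutorial.\n", "# `optax.softmax_cross_entropy_with_integer_labels` applies a log-softmax internally,\n", "# so it expects raw, unnormalized logits. If it receives probabilities that have already\n", "# gone through `jax.nn.softmax`, the loss is squashed: with 2 classes the inputs lie in\n", "# [0, 1], so the loss can never drop below log(1 + e^-1) ~= 0.313, even for a perfect\n", "# prediction. The toy logits and labels below are made up for illustration only.\n", "import jax\n", "import jax.numpy as jnp\n", "import optax\n", "\n", "toy_logits = jnp.array([[10.0, -10.0], [-10.0, 10.0]])  # confident, correct predictions\n", "toy_labels = jnp.array([0, 1])\n", "\n", "loss_on_logits = optax.softmax_cross_entropy_with_integer_labels(\n", "    logits=toy_logits, labels=toy_labels\n", ").mean()\n", "loss_on_probs = optax.softmax_cross_entropy_with_integer_labels(\n", "    logits=jax.nn.softmax(toy_logits), labels=toy_labels\n", ").mean()\n", "\n", "print(f\"loss on raw logits:     {float(loss_on_logits):.4f}\")  # close to 0\n", "print(f\"loss on softmax output: {float(loss_on_probs):.4f}\")  # close to log(1 + e^-1)" ] }, { "cell_type": "code", 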
"execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 0/10, [192/195], loss=0.693 [00:36<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 1/10\n", "- total loss: 0.6879\n", "- Accuracy: 0.5661\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 1/10, [192/195], loss=0.678 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 2/10\n", "- total loss: 0.6734\n", "- Accuracy: 0.6507\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 2/10, [192/195], loss=0.594 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 3/10\n", "- total loss: 0.6177\n", "- Accuracy: 0.7316\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 3/10, [192/195], loss=0.511 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 4/10\n", "- total loss: 0.5404\n", "- Accuracy: 0.7890\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 4/10, [192/195], loss=0.466 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 5/10\n", "- total loss: 0.4995\n", "- Accuracy: 0.8159\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 5/10, [192/195], loss=0.463 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 6/10\n", "- total loss: 0.4806\n", "- Accuracy: 0.8280\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 6/10, [192/195], loss=0.435 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 7/10\n", "- total loss: 0.4714\n", "- Accuracy: 0.8349\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 7/10, [192/195], loss=0.421 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", 
"text": [ "[test] epoch: 8/10\n", "- total loss: 0.4665\n", "- Accuracy: 0.8394\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 8/10, [192/195], loss=0.409 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 9/10\n", "- total loss: 0.4602\n", "- Accuracy: 0.8453\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 9/10, [192/195], loss=0.389 [00:35<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 10/10\n", "- total loss: 0.4560\n", "- Accuracy: 0.8486\n", "CPU times: user 27min 44s, sys: 12min 14s, total: 39min 59s\n", "Wall time: 8min 19s\n" ] } ], "source": [ "%%time\n", "for epoch in range(num_epochs):\n", " train_one_epoch(epoch)\n", " evaluate_model(epoch)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAjoAAAGdCAYAAAAbudkLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB4kElEQVR4nO3deVxUVf8H8M8wMGyyqOwKghu4oCgo4ZIbCbZp9pSaTyqPWZm2UWn+cskltc3MMjVza9Uy0xbTlCQ33PcNFRc0AVFkVbaZ+/sDGeYy+zADzPB5v15TcJdzz2Vk7pdzvucciSAIAoiIiIhskF1dV4CIiIjIUhjoEBERkc1ioENEREQ2i4EOERER2SwGOkRERGSzGOgQERGRzWKgQ0RERDaLgQ4RERHZLPu6roA5KBQK3LhxA25ubpBIJHVdHSIiIjKAIAgoKChAQEAA7Ows0/ZiE4HOjRs3EBgYWNfVICIiIhNcu3YNzZs3t0jZNhHouLm5Aaj4Qbm7u9dxbYiIiMgQ+fn5CAwMVD7HLcEmAp3K7ip3d3cGOkRERFbGkmknTEYmIiIim8VAh4iIiGwWAx0iIiKyWTaRo0NE9Z8gCCgvL4dcLq/rqhBRLZJKpbC3t6+z6V8Y6BCRxZWWliIjIwN3796t66oQUR1wcXGBv78/ZDJZrV+bgQ4RWZRCocDly5chlUoREBAAmUzGiT2JGghBEFBaWors7GxcvnwZbdq0sdjEgNow0CEiiyotLYVCoUBgYCBcXFzqujpEVMucnZ3h4OCAq1evorS0FE5OTrV6fSYjE1GtqO2/4oio/qjL339+8hAREZHNYqBDRNQAJCcnQyKRIDc3t9av3bdvX7z22ms1Lmf16tXw9PSscTmmkkgk2LhxY51d3xDvvvsuIiIijDonODgYCxcutEh96gMGOkREWowZMwZDhgyp62rQfcOGDcP58+ctfh1TggVTmTt4e/PNN5GUlGTUOQcPHsTzzz9vtjrUN0xGJiKieq+srAzOzs5wdnau66rUidLSUoOGZjdq1AiNGjUyqmxvb29Tq2UV2KJjoILiMiz7Jw3XcjgPCBFV+Oeff9C9e3c4OjrC398fb7/9NsrLy5X7169fj/DwcDg7O6Np06aIjY1FUVERgIqupO7du8PV1RWenp7o2bMnrl69qvE6PXr0wOTJk0XbsrOz4eDggJ07dwIAvvnmG0RFRcHNzQ1+fn545plncPPmTa1119RqsXDhQgQHB4u2ffXVV2jXrh2cnJwQFhaGL774QufPpKioCKNGjUKjRo3g7++Pjz/+WO0YTV1Anp6eWL16NQDgypUrkEgkWLduHfr06QMnJyd89913aq0flffwzTffIDg4GB4eHhg+fDgKCgqUxxQUFGDkyJFwdXWFv78/PvnkE51daatXr8bMmTNx/PhxSCQSSCQSZb0A4NatW3jiiSfg4uKCNm3a4NdffxWdf+rUKQwaNAiNGjWCr68vnn32Wdy6dUvjtZKTk5GQkIC8vDzltd59910AFd1Js2fPxqhRo+Du7q5scZk8eTLatm0LFxcXtGzZEtOmTUNZWZnaz6RSZavkRx99BH9/fzRt2hQTJkwQnVO960oikeCrr77SeZ+//vor2rRpAycnJ/Tr1w9r1qyps65RfRjoGOjdX89g3p/n8Pjnu+u6KkRWTxAE3C0tr/WXIAhmu4d///0XDz/8MLp164bjx49jyZIlWLFiBebMmQMAyMjIwIgRI/C///0PZ8+eRXJyMoYOHaqcIXrIkCHo06cPTpw4gZSUFDz//PNa5xcaOXIk1q5dK6r/unXrEBAQgN69ewOoaPGYPXs2jh8/jo0bN+LKlSsYM2ZMje7xu+++w/Tp0/Hee+/h7NmzmDt3LqZNm4Y1a9ZoPeett97CP//8g02bNuGvv/5CcnIyjhw5YtL13377bbz66qs4e/Ys4uLiNB6TlpaGjRs34vfff8fvv/+Of/75B/Pnz1fuT0xMxJ49e/Drr79
i27Zt2LVrl876DBs2DG+88QY6dOiAjIwMZGRkYNiwYcr9M2fOxNNPP40TJ07g4YcfxsiRI5GTkwMAyM3NRf/+/dGlSxccOnQIW7ZsQVZWFp5++mmN1+rRowcWLlwId3d35bXefPNN5f6PPvoInTt3xtGjRzFt2jQAgJubG1avXo0zZ87g008/xfLly/HJJ5/o/Dnu2LEDaWlp2LFjB9asWYPVq1eLgjdNdN3n5cuX8Z///AdDhgzB8ePH8cILL+Cdd97RWV5dYteVgXZfzAYA3LlbpudIItLnXpkc7advrfXrnpkVBxeZeT72vvjiCwQGBuLzzz+HRCJBWFgYbty4gcmTJ2P69OnIyMhAeXk5hg4dihYtWgAAwsPDAQA5OTnIy8vDo48+ilatWgEA2rVrp/VaTz/9NF577TXs3r1bGdh8//33GDFihDI4+t///qc8vmXLlli0aBG6deuGwsJCo7syKs2YMQMff/wxhg4dCgAICQnBmTNnsGzZMowePVrt+MLCQqxYsQLffvstBgwYAABYs2YNmjdvbtL1X3vtNeW1tVEoFFi9ejXc3NwAAM8++yySkpLw3nvvoaCgAGvWrMH333+vrM+qVasQEBCgtTxnZ2c0atQI9vb28PPzU9s/ZswYjBgxAgAwd+5cLFq0CAcOHEB8fDw+//xzdOnSBXPnzlUev3LlSgQGBuL8+fNo27atqCyZTAYPDw9IJBKN1+rfvz/eeOMN0bapU6cqvw4ODsabb76JtWvXYtKkSVrvqXHjxvj8888hlUoRFhaGRx55BElJSRg3bpzWc3Td57JlyxAaGooPP/wQABAaGopTp07hvffe01peXWKLjoHkirquARHVJ2fPnkVMTIyoFaZnz54oLCzE9evX0blzZwwYMADh4eF46qmnsHz5cty5cwcA0KRJE4wZMwZxcXF47LHH8OmnnyIjI0Prtby9vTFw4EB89913ACr+ok5JScHIkSOVxxw+fBiPPfYYgoKC4Obmhj59+gAA0tPTTbq/oqIipKWlYezYscq8j0aNGmHOnDlIS0vTeE5aWhpKS0sRHR2t3NakSROEhoaaVIeoqCi9xwQHByuDHADw9/dXdtldunQJZWVl6N69u3K/h4eHyfUBgE6dOim/dnV1hbu7u/J6x48fx44dO0Q/r7CwMADQ+jPTRdP9r1u3Dj179oSfnx8aNWqEqVOn6n2PO3ToAKlUqvxe9Wekja77TE1NRbdu3UTHq/6M6xu26BjInE3eRA2ds4MUZ2Zp7oqw9HVri1QqxbZt27B371789ddf+Oyzz/DOO+9g//79CAkJwapVq/DKK69gy5YtWLduHaZOnYpt27bhgQce0FjeyJEj8corr+Czzz7D999/j/DwcGULUVFREeLi4hAXF4fvvvsO3t7eSE9PR1xcHEpLSzWWZ2dnp/a5ppq3UVhYCABYvny5KHCpvLeakEgkOq9dydXVVW9ZDg4OamUrFJb7y1TX9QoLC/HYY4/h/fffVzvP39/f6GtVv//K4HbmzJmIi4uDh4cH1q5dqzEPytA6m/Oc+ootOgaSM9AhMhuJRAIXmX2tv8y5xla7du2QkpIiemDv2bMHbm5uyq4aiUSCnj17YubMmTh69ChkMhl++eUX5fFdunTBlClTsHfvXnTs2BHff/+91usNHjwYxcXF2LJlC77//ntRa865c+dw+/ZtzJ8/H71790ZYWJjev9i9vb2RmZkpqv+xY8eUX/v6+iIgIACXLl1C69atRa+QkBCNZbZq1QoODg7Yv3+/ctudO3fUhoR7e3uLWrAuXLhgkQVfW7ZsCQcHBxw8eFC5LS8vT+8QdZlMBrlcbvT1unbtitOnTyM4OFjtZ6YtaDPmWnv37kWLFi3wzjvvICoqCm3atNGawG5JoaGhOHTokGib6s+4vmGgYyC5goEOUUOUl5eHY8eOiV7Xrl3DSy+9hGvXruHll1/GuXPnsGnTJsyYMQOJiYmws7PD/v37MXfuXBw
6dAjp6enYsGEDsrOz0a5dO1y+fBlTpkxBSkoKrl69ir/++gsXLlzQmafj6uqKIUOGYNq0aTh79qwyfwIAgoKCIJPJ8Nlnn+HSpUv49ddfMXv2bJ331bdvX2RnZ+ODDz5AWloaFi9ejD///FN0zMyZMzFv3jwsWrQI58+fx8mTJ7Fq1SosWLBAY5mNGjXC2LFj8dZbb+Hvv//GqVOnMGbMGLXp//v374/PP/8cR48exaFDh/Diiy+qtSCYg5ubG0aPHo233noLO3bswOnTpzF27FjY2dnpDHqDg4Nx+fJlHDt2DLdu3UJJSYlB15swYQJycnIwYsQIHDx4EGlpadi6dSsSEhK0BjPBwcEoLCxEUlISbt26pTPga9OmDdLT07F27VqkpaVh0aJFosC5trzwwgs4d+4cJk+ejPPnz+PHH39UJjfXxwV7GegYSKES6By7lov8YiYlEzUEycnJ6NKli+g1c+ZMNGvWDJs3b8aBAwfQuXNnvPjiixg7dqwyWdTd3R07d+7Eww8/jLZt22Lq1Kn4+OOPMWjQILi4uODcuXN48skn0bZtWzz//POYMGECXnjhBZ11GTlyJI4fP47evXsjKChIud3b2xurV6/GTz/9hPbt22P+/Pn46KOPdJbVrl07fPHFF1i8eDE6d+6MAwcOiEb8AMBzzz2Hr776CqtWrUJ4eDj69OmD1atXa23RAYAPP/wQvXv3xmOPPYbY2Fj06tULkZGRomM+/vhjBAYGonfv3njmmWfw5ptvWmzB1wULFiAmJgaPPvooYmNj0bNnT+VweW2efPJJxMfHo1+/fvD29sYPP/xg0LUCAgKwZ88eyOVyDBw4EOHh4Xjttdfg6empda2nHj164MUXX8SwYcPg7e2NDz74QGv5jz/+OF5//XVMnDgRERER2Lt3r3I0Vm0KCQnB+vXrsWHDBnTq1AlLlixRjrpydHSs9froIxFsIPkkPz8fHh4eyMvLg7u7u0WuETbtTxSXVfVPBjZxxq5J/S1yLSJbUlxcjMuXLyMkJKTWVy0mqq6oqAjNmjXDxx9/jLFjx9Z1dWzGe++9h6VLl+LatWsa92v7HKiN5zeTkQ1UPQfrWs69uqkIEREZ7OjRozh37hy6d++OvLw8zJo1C0BFzhOZ7osvvkC3bt3QtGlT7NmzBx9++CEmTpxY19XSiIGOgZiMTERknT766COkpqZCJpMhMjISu3btgpeXV11Xy6pduHABc+bMQU5ODoKCgvDGG29gypQpdV0tjRjoGIjJyERE1qdLly44fPhwXVfD5nzyySd6Z2SuL5iMTERERDaLgQ4RERHZLAY6RFQrbGCAJxGZqC5//xnoEJFFVU4EZ4mZb4nIOlT+/ltiYkh9mIxMRBYllUrh6empXJLAxcWlXs6eSkTmJwgC7t69i5s3b8LT07PG66SZgoFODckVAlIzCxDm5wY7O354E2ni5+cHAHrXXyIi2+Tp6an8HKhtDHQMcKtQ+zon7/56Gt/su4oJ/VrhrbiwWqwVkfWQSCTw9/eHj4+PxlWqich2OTg41ElLTiUGOnoUFJchas52rfu/2VexcuziHWkMdIj0kEqldfqBR0QND5OR9Th2Lbeuq6DT9Tt38ehnu/DL0et1XRUiIqJ6h4GOHhl5xXVdBZ1mbDqNU//m4/V1x5XbOIyXiIiogkmBzuLFixEcHAwnJydER0fjwIEDWo/t27cvJBKJ2uuRRx5RHiMIAqZPnw5/f384OzsjNjYWFy5cMKVqZldYXG7wsdkF2nN5LKWgWv2mbzqFmHl/I/duaa3XhYiIqL4xOtBZt24dEhMTMWPGDBw5cgSdO3dGXFyc1tEUGzZsQEZGhvJ16tQpSKVSPPXUU8pjPvjgAyxatAhLly7F/v374erqiri4OBQX121rSk5RKWb9fsbg48d/W7vrqZSWKyBA3HrzdcpVZOYXY+3Ba7VaFyIiovrI6EBnwYIFGDduHBISEtC+fXssXboULi4uWLlypcb
jmzRpAj8/P+Vr27ZtcHFxUQY6giBg4cKFmDp1KgYPHoxOnTrh66+/xo0bN7Bx48Ya3VxN2Ut1DxdXVFvo89DVO5i+6RSeXLIX6w9fx6l/8/CfJXtx6EqO2eu27UwW2k79Ewev3NG4nwPdiYiIjAx0SktLcfjwYcTGxlYVYGeH2NhYpKSkGFTGihUrMHz4cLi6ugIALl++jMzMTFGZHh4eiI6O1lpmSUkJ8vPzRS9LcHdyQGufRlr3t/y/zWrbvk65isNX7+DNn47j0c9249DVO/jP0hST8mYOXM7BvD/PoqRcrrZPX+tRTeZjS799F8+tOYjDV80foBEREdUmowKdW7duQS6Xw9fXV7Td19cXmZmZes8/cOAATp06heeee065rfI8Y8qcN28ePDw8lK/AwEBjbsMo34ztbpZyur23HfELd+LRz3bhZr5hXXJPL0vBsn8uYcSX+4y+ngQS/HEiA5dvFRl97sQfjmD72Zt4colhwSsREVF9VaujrlasWIHw8HB0716z4GHKlCnIy8tTvq5ds1w+ir+Hs1nKuVVYinOZBTj1bz6W7bxk1LlH0nMx4st9WJRUlaCtqcWmTK5Qfv3+lnOY8P0RDPp0p9F1vXqbaxIREZFtMCrQ8fLyglQqRVZWlmh7VlaW3qmdi4qKsHbtWowdO1a0vfI8Y8p0dHSEu7u76GVN5ApBY1eWIAhIv31X476US7exYNt5tbwgVTN+Pa38uvz+ccVlVcGPXCHg2LVc5N0rQ0FxGdYdTMedIvXRWXn3OHMtERHZBqNmRpbJZIiMjERSUhKGDBkCAFAoFEhKSsLEiRN1nvvTTz+hpKQE//3vf0XbQ0JC4Ofnh6SkJERERAAA8vPzsX//fowfP96Y6lmN1Xuv4Jej/+Lr/3XH3VI5rt+5C38PZxy/nosPt6aivb87nusdovHcMasP4uv/aW4R+35/utZrnriei8c/36O2/Zej/2Lt8zHK70/fyNNaRkm5HKf+zUPn5p6wl3IKJiIiqv+MXgIiMTERo0ePRlRUFLp3746FCxeiqKgICQkJAIBRo0ahWbNmmDdvnui8FStWYMiQIWjatKlou0QiwWuvvYY5c+agTZs2CAkJwbRp0xAQEKAMpmxR3r0yDF6sHngAwJmMfCT+eFzjvp3nswFU5OAAhiU4Z+Td0xjkAMC+S+KE46PpuVrLefvnk/jl6L+Y2K813owLNejaREREdcnoQGfYsGHIzs7G9OnTkZmZiYiICGzZskWZTJyeng47O/Ff+6mpqdi9ezf++usvjWVOmjQJRUVFeP7555Gbm4tevXphy5YtcHJyMuGWbF9RieGTGAJAzLy/de5fsO084jv44cLNAkzdeErrcb8c/RcAsOSfNJ2BTmpmAbafzcLYXiFwcuC6RkREVHckgg2sF5Cfnw8PDw/k5eVZJF8n+O0/zF6mtbgyv2oG68qfg0xqh3Oz43HwSg5u5N3DkIhmkKhkR1ce90r/1kgcyJYfIiLSzNLPb4Crl5OBREnQEuDFbw/jrzMVCeSNHB3wUHtfCIIgCnhO/qs934eIiKg2MNAhg2QXVq3jVVquUAY5QEVAk5qZj+/3p+Pnl3ootzswYZmIiOoYAx3SqbRcAZm9HYrL1GdnrmQnAT766zwAYFHSReV2BjpERFTX+CQyg0+HR6Cdv/6+RU8Xh1qojXm1nfonrt+5iz4fJms9RqKyslZpedW8PfrWCjNEcZkcJ6/nmbSEBhEREQMdMwhq4oJxWua9URXVooky2GnR1MXS1TKbXu/v0Ln/k+3nlV+rrqZuf3/0XWFJOX44kI7bKt1fF28WoNCA0WNv/Hgcj32+m6uxExGRSRjoGGBE90A0cZVp3R/i5Ypyuf4Wh/efDMf6F2PwTHQQlj0bac4q1hv7VeblkdlLsOVUJjrO2IopG05i0Ke7IAgCjl/LReyCnej/UbLe8v44mQEA+GhrqqWqTERENoyBjgHmDe2EA/83QOO+XZP6wdNFhjK
FQuP+AI+quYCaNnJEax83zH0iHMFNXS1S17r2b+495df2dnZ4UWWV9ZsFJfgiOQ1/nclUfi8IAj7Zdh7bz2SplaXqtoalKoiIiPRhoGMgbUseBDap6ILS1qLj5+GELa/1xs63+om2qybqTujXSvm1p4sDRse0wPk5gzAqpkVNq12nvtl3VW3bh1tToZpus/V0Fj5NuoDnvj5UizUjIqKGgqOuTBDTsinG9gpBkEqezYNtvQEATV1leKxzAP48lYEyuYC5Q8MR5qeeqCy1kyC2nQ9yikqR+FAoerbywk+Hr2P6o+3R+H432azBHTG8WxCCmrqg44ytJtX14XA/DIlohuJyBV754ahJZZibakiYll2o/Hr5zkvo3dZL48+LiIjIFAx0TODubI/Y9r6ibSFertg1qR+auMrg6miPdx/voDaBXnVfje6mPKZHay/0aO2ldkz7AN0P/e7BTXDnbiku3KwIGN4c2BYPtvVWrm3VxFWGgR38cPjqHaPu8f8eDsPczeeMOsdQ36ZUtfSoJiS/t/kssFk8G7OqXRey0d7fHU0bOart0/ezJiKiholdV2YU2MQFro5VsaMhD96aPpyHdQvE8O5BAIAerZpiYv826NTcE0v/G4n+YT5IfKhiCQYXmXFrTj3U3k/j9ra+jWpUXwAoUAluSsrUc5vWHqhahV2m0sX37IoDiFu4S+34m/nF6Dn/b3y41TKBGRERWS+26Jigrqd08WrkiFuFJVj8TFc8HO4HhQC083NDp0BP5THxHf0Q37EqWHFWWVxz//8NQPTcJJ3X8HFzxIR+rbB4RxqmPtIOoX5uKCqRI76jH9pO/VM0X05NCBpWYH97w0m8veEkViV0g7NMitJ7Vde6pTJEvdIvR//FjbxiLN6RhrfiwsxSLyIisg0MdKzEnCEdMXXjKbw5sC1e6NMKd0vk8Lg/J49UAo3dXqpk9lUtI072UrzwYEss23lJue3NgW2VsxsDgKujPd4cGIrh3YLQvLGzqOXJnB1EX6eoJyxXSlh1EP4eTsi7V6a2T6EQYGdXURPViRiLy+RcMZ2IiJTYdWUl/vtAC1yZ/wgm9m8DB6mdMsgxVPWuq/8+0AJ2KhGLq6M9XottA6AiPweo6FYLbOKi1r1mr3LigDAfo+pRnVyhu3nMWUOX245zN9Fp5l/440SG2j5FXTe3ERFRvcJAxwTW+Cj1dJHhpb6t8MKDLeHh4oDAJi44NTNOud/fwwmvxbbFwXdi8fyDrXSUBCx7NgpODnZ4/8lwrBjTDS882NJi9b6UXaS2LWH1QRSWlGPC90cAAOUqwVKxhpyfe6Xa1+kiIiLbxq6rBmRSvDh/xUVmj69GReHotTsYeD/52NtNfURTdb3aeOHUu3HKuYVUu8XqgkIl0Bnx5T5sff1B5fezfz+DFbsv4+fxPRDZonFdVI+IiOoQW3QauNj2vngrLkyZ72Io1QkUVZOe64Jq91dqVoFo34rdlwFwCQkiooaKgQ7VWIcAD+x4sy+W/lf7+l0rRkfhgyc7oUuQp9mvb8AyY+AUO0REDRO7rsgsQrxc8e+dexr3ffifThjQrmKCxU3H/zXrdf88mYEjRk6GSEREDQcDHTIb1VaTj5/qjKFdm+FWYako78fX3UnDmaYb/90Ro+tGREQNBwMdE3AEs2YhXlUrsj8Z2RyAenKzt4blG2qDHSMdIqIGiYEOmU2ApzM2vNQDHs7a5/hxc6r6Jxcd0gT7L+fURtWIiKiBYjKyEbreT6Qd3i2wbitSj3UNaoxW3trXwwpsUrXiu2rQY2mqkx4KbJIjImow2KJjhB+efwDX79zT+SAn3R7tFIB9l3LQJcgTD4Q0xfU79xDbzhfnswqwI/UmygwZQmUg1fW4KsOc1Xsu4+O/zmP+k53wSCd/s12LiIjqJwY6RnC0lzLIqSGpnQTzhoYrv9/yWtXkfltOZeDFbw1LLjbE3+eylF9XNui8+9sZAMCs308z0CEiagDYdUX1RnxHf+x
8qx/G99W9BIU+5XIF5ApBtDRE9VTk24WlNboGERFZB7boUL0S1NQFvdt4YUlymslltH7nT7jKpIht76vcVn1hUmbpEBE1DGzRoXqnRyuvGpdRVCrHpmM3lN9rW9VcoRBEuTxERGRbGOhQvbTn7f748tmqJSVWJXSrUXnVg5nKkVdDl+xFt/e2c4VzIiIbxUCH6qVmns4Y2EFlsdAa9jWpBTr3/3/sWi7y7pXhaDqXkSAiskUMdMgqqK5QbopSefUWHaCopNyoMu6WlqNczm4uIiJrwkCHrIJqjs3rsW2NPt/HzQknrueKto1eeUD5tb4wKr+4DO2nb8XAhTuNvjYREdUdBjpkFVQbdNyc7OHrbtyaWcFNXbDlVKZo2yEjVj0/eH+pikvZRUZdl4iI6hYDHbIKTRvJRN//8lJPo86XCwK+0DFknatCEBHZJs6jQ/Xa4me64lJ2IaJaNBZt9/dwQnBTF1y5fdegchQ1zPFhIEREZJ1MatFZvHgxgoOD4eTkhOjoaBw4cEDn8bm5uZgwYQL8/f3h6OiItm3bYvPmzcr97777LiQSiegVFhZmStXIxjzSyR8vD2gDiUQCr0YV3VV9Q70hkUiw+dXeWDkmyqBybumZCVnQk6XDOIeIyDoZ3aKzbt06JCYmYunSpYiOjsbChQsRFxeH1NRU+Pj4qB1fWlqKhx56CD4+Pli/fj2aNWuGq1evwtPTU3Rchw4dsH379qqK2bOxicR2TeqHO3dLEeDpDABwkdmjpZdha4/9cTJD53622BAR2Sajo4kFCxZg3LhxSEhIAAAsXboUf/zxB1auXIm3335b7fiVK1ciJycHe/fuhYODAwAgODhYvSL29vDz81PbTlTJWSaFs8xZtK2xS1XuztNRzfHjoesmla0a58gVAnLvlqJpo6qEZ4GREBGRVTKq66q0tBSHDx9GbGxsVQF2doiNjUVKSorGc3799VfExMRgwoQJ8PX1RceOHTF37lzI5eKZaC9cuICAgAC0bNkSI0eORHp6utZ6lJSUID8/X/SihsnNqSpWfzMu1ORyVAOZyT+fQPTcJOy6kF213+SSiYioLhkV6Ny6dQtyuRy+vr6i7b6+vsjMzNR4zqVLl7B+/XrI5XJs3rwZ06ZNw8cff4w5c+Yoj4mOjsbq1auxZcsWLFmyBJcvX0bv3r1RUFCgscx58+bBw8ND+QoMDDTmNsiG2NlJ8M3Y7lgysit83JxMLmfMqoNIOpsFAFh/+DrKFQJm/34G//fLSfx06Bq7toiIrJTFh5crFAr4+Pjgyy+/RGRkJIYNG4Z33nkHS5cuVR4zaNAgPPXUU+jUqRPi4uKwefNm5Obm4scff9RY5pQpU5CXl6d8Xbt2zdK3QfVY7zbeGBTuX+Nyxq45JPr+fFYhvt+fjrfWn6hx2UREVDeMytHx8vKCVCpFVlaWaHtWVpbW/Bp/f384ODhAKpUqt7Vr1w6ZmZkoLS2FTCZTO8fT0xNt27bFxYsXNZbp6OgIR0fjJowjMsTtwhIte9ikQ0RkjYxq0ZHJZIiMjERSUpJym0KhQFJSEmJiYjSe07NnT1y8eBEKRdUaQefPn4e/v7/GIAcACgsLkZaWBn//mv+VTqSqf5j6yEBV2iYVLJUz0CEiskZGd10lJiZi+fLlWLNmDc6ePYvx48ejqKhIOQpr1KhRmDJlivL48ePHIycnB6+++irOnz+PP/74A3PnzsWECROUx7z55pv4559/cOXKFezduxdPPPEEpFIpRowYYYZbpIasc3MPSO0kyu8f1tPFdVXLBITTNp5Sfs0RWERE1sPo4eXDhg1DdnY2pk+fjszMTERERGDLli3KBOX09HTY2VXFT4GBgdi6dStef/11dOrUCc2aNcOrr76KyZMnK4+5fv06RowYgdu3b8Pb2xu9evXCvn374O3tbYZbpIYoqkVjvP5QW7T3d0e396rmZwpu6qLzvNtFmruu8u6VKb8WBEAi0XgYERHVMybNyjdx4kR
MnDhR477k5GS1bTExMdi3b5/W8tauXWtKNYi0spNI0LO1FwCgXGX5hza+bjrPO5qeq7dshSDADox0iIisARf1pAbFQVrzAIUdV0RE1oOBDtmUpq4VCe7926knHYd4ucLerub/5JmiQ0RkPbigFNmUP1/tjf2XcxDfUX26g6AmLmZp0VEw0iEishps0SGb4uPuhMc6B8BBqv5Pu3NzD0gslEVcLldgzKoD+HDrOYuUT0REpmGgQzbv++eiMTqmBV7q1xoAsGhElxqVp6lFZ+eFbCSnZmPxDs3z8BARUd1goEM2r0drL8wc3BFODhWzcz/eOQDdghubXJ4gAEuS0zDiy33IKSoFANwtles5i4iI6gIDHWqQJDUYHq4QBLy/5RxSLt3GzN9OAwDkCnErz4K/UrFm75WaVJGIiMyAycjUMNUgVUc1pDl05Q4AoFxliYiLNwuw6O+KddpG9wg2/UJERFRjbNEhMpJQtWwbissquqzkKnk7+cXltV0lIiLSgoEONUg1GXslqLTpVAY4l7KLqvarBD1cF4uIqG4x0CEykmo6juL+N0v/qRptJVdoPpaIiGofAx1qkGoync4TX+xRfq0pkFEdfs7JBYmI6hYDHWqQajLq6urtu8qvq4+2AhjoEBHVJwx0qEEy1wTJCkHArN/OiLadyyio2q+ofgYREdUmBjpENVBSrsDKPZdF22b9XhX46GvRUTCJh4jIohjoUINkoSWv1Mh1BDoTvjuC2AX/oKScsyoTEVkKAx1qkFxltTNXpqCj6+qPkxm4dKsIu87fqpW6EBE1RAx0qEGa/lh7tPN3x6zBHSx6HV0tOpWe+/qQRetARNSQMdChBql5Yxf8+WpvjOgeZNHrcNQVEVHdYqBDDZqD1A6rE7ph2bORFilfoRA0DkEnIqLawUCHGry+oT6I6+CHVQndzF729rM30XHGVvx+4obZyyYiIv0Y6BDd1y/UB6/0b23WMv/vl5O4VybHxO+PmrVcIiIyDAMdIiIislkMdIjqAFc1JyKqHQx0iFTUVvjB/GQiotrBQIfIALsm9atxGcVlVTMgc9g5EVHtYKBDpIevuyMCm7jUuJywaVuwcPt5AAx0iIhqCwMdIj183JzMVtbC7RcAAIxziIhqBwMdIhUyqfqvxCfDIsx+HbboEBHVDgY6RCpG9wwWBTsH34lFa59GZr8Ok5GJiGoHAx0iFe5ODtg1uSrxuKmrzCLX4bIQRES1w76uK0BU3/i6O2H+0HC4ONrDzk5ikWtwHh0iotrBFh0iDYZ3D8LjnQNE2+YNDYevu6NZymeDDhFR7WCgQ2SgEd2DsG/KALOUxWRkIqLawUCHyAgSiQQ/jHsA7f3da1QOAx0iotphUqCzePFiBAcHw8nJCdHR0Thw4IDO43NzczFhwgT4+/vD0dERbdu2xebNm2tUJlFdiWnVFH+80ku07fvnog0+P+lsFpLPZZu7WkREpIHRycjr1q1DYmIili5diujoaCxcuBBxcXFITU2Fj4+P2vGlpaV46KGH4OPjg/Xr16NZs2a4evUqPD09TS6TqK5JJOIkZUcHqcHnjl1zyNzVISIiLYxu0VmwYAHGjRuHhIQEtG/fHkuXLoWLiwtWrlyp8fiVK1ciJycHGzduRM+ePREcHIw+ffqgc+fOJpdJZGu+SbmCzLxiXMgqqOuqEBHZFKMCndLSUhw+fBixsbFVBdjZITY2FikpKRrP+fXXXxETE4MJEybA19cXHTt2xNy5cyGXy00us6SkBPn5+aIXUV2S1HAU+rRNp/HAvCQ89MlOZOYVm6dSRERkXKBz69YtyOVy+Pr6irb7+voiMzNT4zmXLl3C+vXrIZfLsXnzZkybNg0ff/wx5syZY3KZ8+bNg4eHh/IVGBhozG0Q1WtnMxm4ExGZi8VHXSkUCvj4+ODLL79EZGQkhg0bhnfeeQdLly41ucwpU6YgLy9P+bp27ZoZa0xUxzggi4jIbIxKRvby8oJUKkVWVpZoe1Z
WFvz8/DSe4+/vDwcHB0ilVcma7dq1Q2ZmJkpLS00q09HREY6O5pm4jchcPJwdkHevrMblcOg5EZH5GNWiI5PJEBkZiaSkJOU2hUKBpKQkxMTEaDynZ8+euHjxIhQKhXLb+fPn4e/vD5lMZlKZRPXRvikD8Fi12ZRNMeH7IzivkpScXVCCb/ZdRUFxzYMoIqKGxuiuq8TERCxfvhxr1qzB2bNnMX78eBQVFSEhIQEAMGrUKEyZMkV5/Pjx45GTk4NXX30V58+fxx9//IG5c+diwoQJBpdJZA2cZVJ4Nar5IqDFZQo8vawqEf/ZFfsxbeMpTNlwssZlExE1NEbPozNs2DBkZ2dj+vTpyMzMREREBLZs2aJMJk5PT4edXVX8FBgYiK1bt+L1119Hp06d0KxZM7z66quYPHmywWUSWYsnuzbHqj1XalxO7t2q1ptzmRWtO1tPa07OJyIi7SSCDSyjnJ+fDw8PD+Tl5cHdvWZT8xMZatOxf/Hq2mMAgA0v9UDXoMYAgOC3/wAAdAhwx+kbpo+gujL/EVF5UjsJ0uY+XIMaExHVL7Xx/OZaV0QmUl3dXNM0Og+0bGrW6zFJmYjIeAx0iExUfRmI6qR2NZxFsBrGOURExmOgQ2QGmmIQu5pOl6xB7t1SHL6aAxvocSYiqhVGJyMTUd3p91Ey7twtQyNHexyeFgtHe8MXEyUiaojYokNUAyFernBysEN7f81JdF//r7tZr3fn/miswpJy/HjoulnLJiKyRWzRIaqBba8/CLkgaGxZESDgwbbeeDqquUWCkry7pWYvk4jI1rBFh6gG7KV2eruPyhWm5dMUlZTr3M80HSIi/RjoEFmYwsRAZ9rGU2auCRFRw8NAh8hS7sc3ch1xTqfmHlr3bTj6r5krRETU8DDQIbIwucqCttU52pv+K8ieKyIi/RjoEFmYXEfXlZOD6cPDF2w7j9uFJTqP+XDrOcz786zJ1yAisnYMdIgspDK8Gd0jWOsxMmnNfgWnquTxVJ9EsKC4DIt3pGHZP5eQU8QRWkTUMDHQIbKwHq28sG/KAI37arp+1fmsipXNj1/LReSc7fjx0DXlvnKV5KByHd1nRES2jIEOkYWotrD4eThpPKZMV6YygAV/peq+xv3/v/zDUeQUlWLS+hMaj5NoXHaUiMj2MdAhqkNlct0tLYv+vqi7gMqRXRrygJisTETEQIeo1jRyrJiIvG+ot3Lb7RrmzlR2fWla5JMLfxIRMdAhspjqccbqhG54Oqo5Ph3WBV2CPAEAj4T71+wa1f6vjQUWUicisgpc64qolkQFN0FUcBMAwJr/dcfhq3fQMcADnyZdMLnMymBKNajKyi+GQ7XRXIxziKihYosOkYWE+rlp3efu5IB+oT41Hl6u7LpSadOJnpuErrO3MUeHiAhs0SEyu98m9sLBKzl4smtzvcc6y0yfMBCoasnRNCeh6tB1Bj1E1FAx0CEys/DmHgjXsYaVKlkNloBQpSnvWHUeHeYlE1FDxa4rIitWNbJKPZJRHXIuQMCByzno/cHf2HHuZi3Vjoio7jHQIbJiCg3JyJXKVfuzBGDYlym4lnMPCasP1k7liIjqAQY6RFasMglZ01IS5SqTEQrQ3n1VUi7H+sPXkZlXbIkqEhHVKQY6RFZM2aKjYV+ZjhydopJy3MyvCGw+//si3vzpOB5etMtCtSQiqjtMRiayYprm0alUPUdHVdfZ21BSrsCB/xuAv+/n7HCFcyKyRQx0iKya9iUg9qbdqjpKZbdMaoeS8opure5zk+DVSGbZKhIR1SF2XRFZsVuFFa0wmrqu5v15Tvm16v7q+TyVZRAR2SIGOkRWbtuZLL3z5Ki2+GhKXCYislUMdIis3Jq9V/SuVK66W9MsykREtoqBDlE9ZMyMyWVyhd4lHtiIQ0QNFQMdonqotFyh/6D7yhWC/q4rrnZFRA0UAx2iem5c7xCd+8vlCr2BDFt0iKihYqBDVM9Njg/Tuf/49Ty9eTflTMw
hogaKgQ5RPSeRSPQfpCeOiV3wj3kqQ0RkZUwKdBYvXozg4GA4OTkhOjoaBw4c0Hrs6tWrIZFIRC8nJyfRMWPGjFE7Jj4+3pSqEdkcO4PiHLbYEBFpYvTMyOvWrUNiYiKWLl2K6OhoLFy4EHFxcUhNTYWPj4/Gc9zd3ZGamqr8XtNfqPHx8Vi1apXye0dHR2OrRmSTDGnRUV3XioiIqhjdorNgwQKMGzcOCQkJaN++PZYuXQoXFxesXLlS6zkSiQR+fn7Kl6+vr9oxjo6OomMaN25sbNWIyIxu5N7DnN/P4FrO3bquChGRyYwKdEpLS3H48GHExsZWFWBnh9jYWKSkpGg9r7CwEC1atEBgYCAGDx6M06dPqx2TnJwMHx8fhIaGYvz48bh9+7bW8kpKSpCfny96EZF5PbfmEL7afRnPfLWvrqtCRGQyowKdW7duQS6Xq7XI+Pr6IjMzU+M5oaGhWLlyJTZt2oRvv/0WCoUCPXr0wPXr15XHxMfH4+uvv0ZSUhLef/99/PPPPxg0aBDkcrnGMufNmwcPDw/lKzAw0JjbICIDnMmo+APiWs495baC4jJcyCqoqyoRERnN4quXx8TEICYmRvl9jx490K5dOyxbtgyzZ88GAAwfPly5Pzw8HJ06dUKrVq2QnJyMAQMGqJU5ZcoUJCYmKr/Pz89nsEM2KaiJS11XQaTfR8m4VViKX17qgS5Bhncvl8sVyMwvRvPG9et+iMj2GdWi4+XlBalUiqysLNH2rKws+Pn5GVSGg4MDunTpgosXL2o9pmXLlvDy8tJ6jKOjI9zd3UUvIlvUqblHXVdBpHKl821nsvQcKZaw+iB6vb8Df58z7jwiopoyKtCRyWSIjIxEUlKScptCoUBSUpKo1UYXuVyOkydPwt/fX+sx169fx+3bt3UeQ2TLtrzWG8/1CsHswR1r9br6FgetZOz8g7su3AIArNl71dgqERHViNFdV4mJiRg9ejSioqLQvXt3LFy4EEVFRUhISAAAjBo1Cs2aNcO8efMAALNmzcIDDzyA1q1bIzc3Fx9++CGuXr2K5557DkBFovLMmTPx5JNPws/PD2lpaZg0aRJat26NuLg4M94qUf3W0tsVE/u1Rjt/d4T5uWPqo+1rvQ6vrj2GV2Pb4PmvD4m23yuV47xKbk5lQCRXCCgoLoOni6xW60lEZCijA51hw4YhOzsb06dPR2ZmJiIiIrBlyxZlgnJ6ejrs7Koaiu7cuYNx48YhMzMTjRs3RmRkJPbu3Yv27Ss+xKVSKU6cOIE1a9YgNzcXAQEBGDhwIGbPns25dKhB+G1iL3y+4wImx4ehpXejOq3Lr8dv4GZBMdKyi0Tbhy/fh+PXcpXfVzboPL0sBYev3sHfb/Sp87oTEWkiEQxtq67H8vPz4eHhgby8PObrkE0KfvuPWrtW5+YeOH49T+cxz/UKwdRH2yvr9Ur/1kgcGKr1+Mrj+rT1xpr/dTdfZYnIqtXG85trXRGRiL4gB1BfWsvQv5YMWbaLiMicGOgQkdEU1RqCrb9dmIhsFQMdIjJa9cCGi4oSUX3FQIeIjLZ67xXR92zRIaL6ioEOkZUZFlX/ZgE3OEfHorUgIlJn8SUgiKjmJseH4fDVO1j63664fuce1h26VtdVEmGLDhHVVwx0iKzA+L6tlF/bS+t/u4ggCLh8qwjBTV1hZ1f/60tEtotdV0RWxtFeWtdVUFM9GXn5rkvo//E/mLbpVB3ViIioAgMdIivj5FD/fm2rd119tPU8AOC7/emi7RJOpENEtaz+fWISkU5ODvWvRae6UrmirqtARASAgQ6R1XGQ1o9fW9XVY3StJDPzt9O1UR0iIo3qxycmEVkdhaD56+pW7blSo+sIgoCScnmNyiCihouBDhGZpPoyEIYwJUNn2LJ96DJrGwpLyk04m4gaOgY6RFZoyqAwxHfww/fPRddZHVTjHFPm0VEoBPx46BrSsgu1HpN3twwHruTgbqkcey/eMqGWRNTQcR4
dIiv0Qp9W+g+yMNUWHdXh5Qpd/VgqPt6WisU70tDapxG2J/bReMyMX6uGp6vOHyRXCPgm5Qq6hTRBhwAPY6tORA0IW3SIyCTaWnTkBjbv/Hr8BgDg4k1xi07e3TJcy7kLANh47IZyu71d1cfVz0eu493fzuCRRbuNrTYRNTAMdIjIJO2mb1F+vfZgOjLzigFUtLZoozqNjkQlY+f4tVzEL9yJ3RduofOsv9D7gx24kXtPdK69ygzLZ27k17T6RNRAMNAhsiGfDOuMF/q0rPXrFpcp8MQXewAYnqSsujLE6FUHcC6zAP9dsV+57fi1XNHxX+2+jEGf7kLu3VIcSb+jscy8u2WY8/sZnL6RBwBIyy5UBmBE1DAxR4fIyr3SvzUW/X0RADAkohkkEgna+bnDz8MJw7/cV2v1yLgfUJQbmKNjp9K8k3u3TG3/n6cyRd//fe4mAODLnZdw4nqexjJn/nYaG47+i692X8ahqbEY8PE/AIAr8x8xqE5EZHvYokNk5bqFNFF+XbnEwpAuzfBAy6Z1Uh9Dk5H1LfZZmcNTXWm59lmXz2RUdWldyi5Sft3nwx24cqtI0ylEZOMY6BBZOVOGdluSrhydcoWAk9fzoFAIMHVRc0OXy1Kdrfnq7btcYJSogWLXFZGVM2XiPkvSNeoqOTUbyanZeCsuVNR1VRvKuP4WUYPEFh0iK6crzPF0cai1elTS1aJTaUlymskrmXMFdCIyBgMdIivX3NNZ676/3+iLn16MqcXaGBboCIJg0nIQ+jAIIqLq2HVFZOXa+Lrh0+ER8HV3UtvXxFWGJq5NNJxlOQoDeogEAHYW+DNLNcypHm5JLBJaEVF9xxYdIhswOML4UVaT4kMtUpf1R64bdFxt5+gQUcPEQIeogXK0l2JAmI/Zy12UdEHvMYJgejeTrrMYOxFRdQx0iBowcwcGeRom/tNEIQiQmnrtWgpmbuYX45uUKygsKa+dCxKRRTDQIWqgBAsMS1ddbVwfS3RdmbPIEcv3Ydqm05j6y0nzFWqgPRdvYfDnu5VLWRCR6RjoEDVo5g02VFcb10WA6YGOoUnF1eM4Yy+Xdn9m5aSzN4070QxGfrUfx6/nIWHVwVq/NpGtYaBD1IAEeFSNzHo43L/ucloEWKQLSlcQVH2RUEPV5XSMd+6W1uHViWwDh5cTNSDe7k7YNbk/ikrL4e7kUGcDrgWYvgREdYeu5KCdvztcHe1FgdvoVQdExxWVypF3twwedTCJIhHVHbboEDUgCoUAqZ0E7k51+7AXBPPl6PxnaQr+u2K/2nZNi39mF5aY5ZpEZD0Y6BA1AK19GgEAHuvsL9peV11XNcrR0XDa0fTcin2mVwk3cu8hJe222nZLJG0TUe1h1xVRA7D+xRgcSb+DB9t413VVANxfAqJOgiztQUuP+X8DAH4eH4PIFppnky6XK7D/cg4iAj3h6mjej8/cu6X4NOkC/hPZvKq2jLGIasykFp3FixcjODgYTk5OiI6OxoEDB7Qeu3r1akgkEtHLyUk8Vb0gCJg+fTr8/f3h7OyM2NhYXLigf9IxIjKMp4sM/cN8YS8V/8o/E92iTuqjEICT/5o2dFpnfGSG6OnYNXG9VGONZTsvYeRX+5Gw2vyjoaZvOo1Ve67gkUW7zV42UUNmdKCzbt06JCYmYsaMGThy5Ag6d+6MuLg43LypfQimu7s7MjIylK+rV6+K9n/wwQdYtGgRli5div3798PV1RVxcXEoLi42/o6IyGB92npj51v9lN/3C/XGvKHh+OWlHlgxOgoye8v1bucaOLmgMczRSOTmpL2l5sdD1wAABy7nmOFKYmcy8s1eJhGZEOgsWLAA48aNQ0JCAtq3b4+lS5fCxcUFK1eu1HqORCKBn5+f8uXr66vcJwgCFi5ciKlTp2Lw4MHo1KkTvv76a9y4cQMbN2406aaIyHBBTV2UX0vtJBjRPQhdghpjQDvferkM5hfJaVr36WvQ0dYVVFwmV37triPQcXa
Q6r5ADTAXiMgyjAp0SktLcfjwYcTGxlYVYGeH2NhYpKSkaD2vsLAQLVq0QGBgIAYPHozTp08r912+fBmZmZmiMj08PBAdHa21zJKSEuTn54teRFRzCqH69w3j4XuvtCrQcajWvaf6I3CWWS7QISLLMCrQuXXrFuRyuahFBgB8fX2RmZmp8ZzQ0FCsXLkSmzZtwrfffguFQoEePXrg+vWKFY4rzzOmzHnz5sHDw0P5CgwMNOY2iEiL6q0KqoFPfwssAFrbtIVtqgHd2DWHqp1Ttc+iLToGbiMi41h8eHlMTAxGjRqFiIgI9OnTBxs2bIC3tzeWLVtmcplTpkxBXl6e8nXt2jUz1pio4areoiNX2fDJsAjMfSIcK8dE1XKtDBO74B/lMHNttDVQVb9vbecYGujsu3QbKWm3kXu3FMmpN0U/RyKqXUaNj/Ty8oJUKkVWVpZoe1ZWFvz8/Awqw8HBAV26dMHFixcBQHleVlYW/P2r5vjIyspCRESExjIcHR3h6OhoTNWJyAC6uqo8nB3wTHRQLdbGOBdvFpp8rlCt7URbvoy25Ozfjt/AX2ey8MGTnSCRAMO/3CfaP+3R9hjbK0RfJYjIAoxq0ZHJZIiMjERSUpJym0KhQFJSEmJiYgwqQy6X4+TJk8qgJiQkBH5+fqIy8/PzsX//foPLJCLzaCApOeqq3fenSZqnt9A2yeHLPxzFb8dvYE3KFY3rU/1+omKx0692XcKOc5pHqDaUfCii2mb0jFeJiYkYPXo0oqKi0L17dyxcuBBFRUVISEgAAIwaNQrNmjXDvHnzAACzZs3CAw88gNatWyM3Nxcffvghrl69iueeew5AxYis1157DXPmzEGbNm0QEhKCadOmISAgAEOGDDHfnRKRXtVbNhqK6j1LC7dXBTqqu/SN6iooLkPePfVh8xJUrMk154+zAIAr8x8xqF4ciUVUc0YHOsOGDUN2djamT5+OzMxMREREYMuWLcpk4vT0dNjZVTUU3blzB+PGjUNmZiYaN26MyMhI7N27F+3bt1ceM2nSJBQVFeH5559Hbm4uevXqhS1btqhNLEhElqVQXx7KpmgL5AwN8PQtW+FoL0WehvmBJBIJMvN1zwtmqWRkhULA1tOZiAjyhL+HsxlKJLIuJs1hPnHiREycOFHjvuTkZNH3n3zyCT755BOd5UkkEsyaNQuzZs0ypTpEZCa23n2i7fYMvW19K67L7O2QX1yucZ+kjmYlWn/4Oib9fAJSOwnS5j5cJ3Ugqktc1JOIlGw7zKkIaK7cKsLXKVdQUl41d47OAE9ll2qLTnGZHDertdI42ttp7LoCTJ/MsKZ2XbwFABz5RQ0WF/Ukoio2/ixUCAL6fpQMALhTVIZXY9sAMCLIUAlWOs38C6XlCvz+ci/lNicHqdYcHX2tQZZSV9clqi/YokNESrbedbXv0m3l1ymXbhl0TqlcgS2nMlFQXCZq0Sktr0hoevSzqkU4ZVI75Gtp0dG1EteZG/lIz7mrtl3f25FTVDFPj0JHa420bpaJJ6o3GOgQkZK5Ap2pj7QzSznmVjnqCRB35ei77xe/PYyXvjuit3XEXipBuYaMbolEd9fVI5/t0l2wFoM+3Ykxqw5i3SHtk6ZKGOhQA8dAh4jw3wcqJgJMfCjULOUN61b/l2UpVwl0DInvdl24hfx7mhONVWlqXJHoSUU2Nb7Myi8BAPx1WvNyOYDluq4EQcD5rALm/lC9x0CHiDB7cEccnzEQvdp4GXS8voenm5ODGWplWaoPaEMf1Vt0BBTKsrQUZsmWFamGN0QQBOTdK9M5JL5Mbvp8Ast2XsLAT3Zi0voTJpdBVBsY6BARJBIJPJwND05soTukTC7gVmFFi4g5WyU0TvInEQeH5p4IsPL9uH7nLiatP47zWQWY+dsZdJ75F3Zf1JyL9FnSBbR5508cTb9j0jUX3Z89+ucjFQs0yxUCDl/NEY1mM5UgCFi15zL2q+RUEZmKo66IqEE6m5GPqDn
bMW9oON7fcs4sZb669hja+DTSuE8iCnT0Dzc3RmXC8fNfH8aZjHz8fiIDd0srAo5/c+9pPOfjbecBAO/+dgabJvSscR0Wbj+Pz/6+iIfD/fDFyMgalZV8PhszfzsDwPBZpIm0YYsOERnNloYsT9lwErkaZjM21QUti4uqZunIVVp0LmQV1PialV1XZzLyAUAZ5FS36di/atuOX8vFxZvG16F6o9SXOy8BADaf1N+9p0/6bfURaESmYqBDREarq1l+rZVE+Z8KqqO8vtl3tcbln76RV3EdPW/Lq2uPadweu2BnjetgTjbQM0r1CAMdIjIeH0RGkUjEP7JyeVWgY29X84/hK/dbQKzhbblTVIpNx/7FPS2tToBl70MQBLy/5Rx+O37Dgleh+oSBDhFp9WBbbwDAsCjxcHFreKDWN6qjn8Lf3YpZ93NQHOxr96epb4JBQ5m60v2YVQfw6tpjmPPHGe0HWbBJJ/l8NpYkp+HlH45a7BpUvzDQISKtFj/TBUv/2xUzB3cQbWfXgnEkkOC8Si6OQgBW7rkMAHAwQ4uO8joGvDFjVh3E+vsjpWrC1IFjx69XdLPpalGxZA5YVp7uVeRrYuf5bBwxcRQbWQ4DHSLSys3JAfEd/eHkIBVtt0SOzpNdm5u9THNzkJp+36qzMovLNGOgY+BxW0/VPGG4pux0RDOWzAGTW2iZk5v5xRi18gCGfrHXIuWT6RjoEJHRnGVS/QepmDc0XO8x9lYwlMvR3rj7NkRtd10BFet3mdPGo/8a3cqnayJDS/5TsNRMzjcLSixSLtUcAx0iMtqK0VFo5umMIREBeo8d0T0QT0fpXxJCWoPWktpiaouOriDAvF1Xhh1XUmZ4oHMt567GuXhUw4XX1h0zuiurMpjJzCvGf7/aj21nspT7ato1uunYv/js/oSG1XHJioaHgQ4RGa1LUGPsebs/4jv66z1WECrmednzdn+dx1nDKttSE4MSXUst6Aue1uy9glErD+DUv3l6r2Nol0+JhvqczchH5v38lcqV2YvL5Oj9wQ70nP+33uUiSsqNayWqzCd699fT2H3xFsZ9fahqXw27rl5dewwfbzuPY9dy1fbVRqBj7pmvAUChECxSbkPAQIeITGZMbNLM01nnfk3rNVXXqbmH4Re0AFO7VA5e0Z6gaq8nR2fGr6ex83w2hi1L0XlccZnc4C6pUg1ByaBPd+GBeUlYdzAdbaf+iS2nMpGeUzVxn9okhDV85lb+LG8XqXf5iGeRrrrQ1dtFmPXbGWTkaZ7tubo7RaVq2ywV6Kh2xZk7HpErBMQt3Iknl+xlsGMCBjpEZDJjn/vVh6mrMiRHx5yJu6awRB6GrlwVVUWlcp0Pufl/Gr6Mxdn7MyhrMvnnkwCAF789jDd+PF61o9qldQVVhiwWWnnfmu5fdfTY898cVn791NIUrNxzGS+obDPG3dJyzDfTch/Vqd6GwszByJXbRbhwsxBH0nNRzq43ozHQISKT6RrOHNWiMQDg6W5Vwc37/+mk9XgXR/1L7xnS6mNtjGkVW7nnitZ9qjku5nJSpbtMdd6cS9mal7mo9HXKVWQXlOgMzHQGOipfq95XZaB54rr+bjxNPt1+weytLZVErVBmLputODXDQIeITKYr7lj7/APY/38D0DWosUFlPdTO10y1sh6Hr+Zg62nDh3rP/l3HJHu1SF+g8cm28+j23nZM3XhK6zGVgYGmtCdz5WdXn9TwjI6WLF3ulcrxx4kM5BdrXxNNNa/I3C06qsWZu+yGgIEOEZlMV2uEvdQOvu5OattberuqbZsyKAwyewM+jmzsM/7JJSlITs2u62poVP29VX2+6muFKiwpBwB8tz9d6zG6W3QMa+a6kXsPuy5kW7zFY/qmU5jw/RGM/1Z7l1n11enNSXXuH8Y5xmOgQ0QmM2Qm3up+fCFGbZsAwJD0G1OXHaCaU21JMDSvSBc7ScVIol0XbplcRo/5f+PZFQdwJD3XoONNDRJ+Olwxk/S
ei7e1HqP6EzF3MKJQSXkyd9mT1h/H6JUHzLIsSH3FQIeITGbK486rkaPatooPb/2l8a9Z7czdqlH93VAt3TyBjgQ/a1mKwpDii8uqRoGlZhZoPKb6j6R6oJyZV4ynl6aYZYFPSyYjq5Zn7rJ/PHQd/5zPNrlbzxow0CEik6m26Ay+P3ngYAMmEdRM/wc445wq1QMbc4/Gqd5aJ27RMUf5wO6Lmltzql/79XXHcENl0kKJRDxvj6ujaTNWz/r9NA5cyTHTAp+Wy9FRLc9SvwO2PJGi/mEORERaqD6O5g/thMc7B6BHKy+TytL3bJg/NBw/HrpmUtm2qPpzydwPV7UWHZXida1TZSg7iUTtPS+XK3AmI18tiPvl6L+4fqdqTh9BAH4+XNUaVCY37N6rXy/3rvbkYmNZctSV6nttqWRk2w1z2KJDRDWg2oXhLJNiQDtfo9fBAoAQLxfRh/kHKsPQJ8eH4cr8RzC8e5BNfxgby/ItOuLvv9+fjn2XKnJUzNV1Vd20Tafx+Od78OHWVLV9Z26Iu1ZmqYxAKymXVz8cgHpgszdNnGNjzsm4RTk65l1KTNTaYu6yleXacL8wAx0iMllNHxQOUgmmPtIOcR38RH+p/qdrcwzvFohBHf0wvm8r5XYb/iw2WvW4xtJdD58mXcDwL/cBMF/XVfUa/3CgYpTW9TvqMx/fK9MczADAjwc1t/Spln8zv1i9DtXarWrysFftbjP/8HLVriv+EhiLgQ4RmczUCfwm9KsIXn55qSee690SEokE/h5VQ9Ht7CSY/2QnLPlvpOg8fsRXOX49V/S9uUfN6BribY4WnXNaEoi10XV7xw2YQFDTWlyqt/Fv7j10e287Ptl23qh6KctS+VpTVU9cz8VJEyc6VC3PUvGsLf9uMUeHiEwW1aIxIls0RoiX+tw4urwVF4aX+7eBk0NVN5eniwx/vNILzg46ur40/KXsaG+HxIfa4uNt5zWu4WRr5AoBUjsJnloqXvvK7EsD6IhlrGD9VQDA/ku30djFAVHBTfTW+eO/UnGrsBSfJl3A6w+1rdF1VVt0copKsSjpAlbvvQIASJ0TD0d747p3VVvrzNlaJGopsuHmUrboEJHJ7KV2+Hl8D3z0VGejz3XSENB0CPBAS+9GWs/R9FEsAHihTyt8/1y00XWwRtrWeTJ29XB9dAWN1rIUx1e7L+M/S1NQWq7QOOfTkatVi60a+pw3JCBQDUYmrT+hDHIAoLjU+PdJlKNjRDxSrKO7rzobjnMY6BCR9ajJh7GmGZmt0faz5l/TyhiCIJil66qyLHPL0zCSqlSu0JhXVKSyIruhdcnUkOsDVAvCVb45dDXHoHJ1KVcY3/Jy+OodhE3bgvd1LGKqWpQNxzkMdIjIevRo1VR9o4Gf0OZ6ODd0coVg0kSRmlji4Trzt9Pq1xEEvctKGFoXQ+IM0XDw6l2KJvzw5CpTIxvaQzl381kAwJLkNK3HqBbFFh0ionrg9YfaIsBDvH5W5SgUfZ/TVtLbUu+ZMxfIEssOVE/SBir+begbrWTog756jkxRSTm+2XcVmXlVLT0bjlbN8WOOOyxXmSfI0FFXulp+BEHA2Yx80Ug2W87RYTIyEVkNJwcpRvUIxvw/tTfHa2PoQpGkW5lcYbaRP+YMmgRBgEQi0VimoNDfEmJqi87s389gbbXh7R9sScVLfVsbVXC5XIG8e2VoqmmJFJWvDf2R6Tps+9mbGPf1IbRS6c7VdPydolIUlpQjsImLYRetp0xq0Vm8eDGCg4Ph5OSE6OhoHDhwwKDz1q5dC4lEgiFDhoi2jxkzBhKJRPSKj483pWpE1MAY+oeoLfVcJaferLNrl8sFo+ZyKZdrT74159w/lf8OyjXMkiwXBL2tR4Ym7lb/96YtZ2rVnsu4VVhi8E/qiS/2InLOdlzIUh92r3pNc7SC/XK/xSktu0jncV1mb0PvD3YgS0tekrUwOtBZt24dEhMTMWPGDBw5cgSdO3dGXFwcbt7
U/Yt35coVvPnmm+jdu7fG/fHx8cjIyFC+fvjhB2OrRkQk0ta3agSXLeXojFl1sM6uPf/Pc3h2hWF/3ALA2xtOat1n1hYdZZnqgZUhAdW2M4YleQsQUC5XYMG280hJu6010J752xkkrDpocJfQyX8r5tjZdEzTAqPG/5x0XVbTPl3Hn75h2vw/9YXRgc6CBQswbtw4JCQkoH379li6dClcXFywcuVKrefI5XKMHDkSM2fORMuWLTUe4+joCD8/P+WrcePGxlaNiBogXY+AjgEeyq/tmJFoFuuMXG9s/eHrePvnEzhwOQeHVYZzA7pbe4xVGVBoCmoEQTDb/DMKAfj5yHUsSrqAEcv36Sz35L95al1NxWVy5N4t1XqOptYyXWtdJafexMWbhRrKMY4t5+gY9atfWlqKw4cPIzY2tqoAOzvExsYiJSVF63mzZs2Cj48Pxo4dq/WY5ORk+Pj4IDQ0FOPHj8ft27e1HltSUoL8/HzRi4gaNk3tNdMfa49xvUOw+ZXeNtWiY23WHryGp5el4Mkle0XbNXUzmaqqRUdL15WZLiUIAtJzVBYY1Vsv8RFPLU1BxKxtyC4o0VK++PvCknK89N0R5feq93H8Wi7GrDqI2AX/GFR3bdeoqKeY6jxK1p7fZlSgc+vWLcjlcvj6+oq2+/r6IjMzU+M5u3fvxooVK7B8+XKt5cbHx+Prr79GUlIS3n//ffzzzz8YNGgQ5HLNfabz5s2Dh4eH8hUYGGjMbRBRA+HpIsM7j7RH+wB3K/+otk1lGrqZTPX53xdxLecu5JpydBTma9ERAEhVmgf1FVs9wKoMkrq9t11jK0r1Ld+kXBV9X1hcjqFf7MGyf9JwJsO0P/I1tRqpVuVcZj7aTv3TpLLrI4uOuiooKMCzzz6L5cuXw8vLS+txw4cPV34dHh6OTp06oVWrVkhOTsaAAQPUjp8yZQoSExOV3+fn5zPYISKdNM2MS3XLnMnInyZdwPcH0jUGTwqF+eaJEQQB9ipzFegNoHTsTssuQmsf7TOBA+qLma7eewVH0nNxJD0Xc58I11VR3fXS4aPqq8db+a+OUYGOl5cXpFIpsrLESVtZWVnw8/NTOz4tLQ1XrlzBY489ptymuP+P0N7eHqmpqWjVqpXaeS1btoSXlxcuXryoMdBxdHSEo6P6EDwiangMzS2wlmULGhJzdl0BQHZBCRyk6u+zXBDMloMiV1T7t6Sn2FIdeUiaEqfVVKu3Mcs6GFhkxTaVG7HUwqF1xaiuK5lMhsjISCQlJSm3KRQKJCUlISYmRu34sLAwnDx5EseOHVO+Hn/8cfTr1w/Hjh3T2gpz/fp13L59G/7+/kbeDhHZuuqPscrPZH1zfVhjmNMlyLOuq2BR5mzRqaQpR0dhxhwduUIQBVPmXGQTUA9CFv19Ueuxuhopja2VroRnAMjIu2eR96s2GD0OITExEcuXL8eaNWtw9uxZjB8/HkVFRUhISAAAjBo1ClOmTAEAODk5oWPHjqKXp6cn3Nzc0LFjR8hkMhQWFuKtt97Cvn37cOXKFSQlJWHw4MFo3bo14uLizHu3RGT1tH24+7o7Yf2LMaJuBVXWmIxsfTU2jkEtGkbSFHcoFMbN/aOLQhDEOTo1KEtTkm9lPUvLFTiSfkdtv2oQYmyMVTkHj77Tqsczuy/cQsy8v/Hcmrqb1qAmjM7RGTZsGLKzszF9+nRkZmYiIiICW7ZsUSYop6enw86IcZxSqRQnTpzAmjVrkJubi4CAAAwcOBCzZ89m9xQR6aX6YR8V3AStvBshVcOka7Ud5/i4OeKmlpE1hrL1vCJzzqOjS35xOc5lGp64W1wmxys/HEVsO1883U3c8yBXiHN07pbWvCtJ5P6P5I2fjuO34+pz6vx5SvPAH7ViNAxrH/jJTrT3d9dyvPaFQ1fsvgwA2JGabdC16xuTkpEnTpyIiRMnatyXnJys89zVq1eLvnd2dsbWrVtNqQYRkRq5lj9zLdGi85/I5lh/+Lr
GfY2c7Gse6NTo7PrP3Dk62gxblmJUUPXOL6fw15ks/HUmSz3QEQSL5ntV1lJTkFOdtiHq1Z28noe8e2VIz7mL9Jy7iG3nq3aMasBm7u64usa1rojIpmjLI7DEhIG6HnfautCMYY3dbcYoM+OEgboY23L08xHNwSugPurK3IxJmv5k+3nt5ah0UF26VYiWXlWjuwpLytSOf+m7I+gQ4I6X+ra2uZXMOVcoEdkUbXkfmoIGr0aOePaBFiZfS1cgoi9IeaGP5lniVTnJpEbXyZrUtMWrLsgVgMzePI9OTf9EzDcMXvy1ap1vF2qemfn0jXxM+P6IzbXoMNAhIpuiacI4AHgttg0AYGD7qmb7317uiUZOpjdsdwr00Lrvya7NdZ47ZVA7veW72nigYw0KisWtH3KFYNGWNrkgoKTcvHk/AsTJ2MV6yrfSwVVaseuKiGyKtm6KyBZNcGpmHErLFfjrzDYAFX/p1uSRNbRLc7zzyym17b+/3Avt/N3R1s8Nvx+/gZ+05PHo4+rIj+i6tuHIv6LvFWZcN6uS6mr0q/Zcwff7081aviBUTJpYSdsfA5VOXM816/XrGlt0iMiqtWgqnj9H11wfjRztIVX5a1yA7tFYfUO9te4b2N5Xa1JqUFMXSO0k6NPWG25ODlrL0Jfq0YiBTp2b8etp0fcVy0mYp+x79xOAt54Wj6QqKTctd6m0XFE1hFw0L444Z0dbwn6l4rLayZ2qLQx0iMhqxXXwxcox3UTb9CWeqiYlC4Kgc8FCnVPsQ3ug0khmWIDSs7X2pXEAwNWRXVf1jdyMLTqDF++pWJ/LDJHT3dJyRM7ehv8srVg4VbVEhSCIAh9rnfjPVAx0iMhqLXs2Cq28xWsFPXQ/B6eVt6vGc1RbYXQ9r2Lb+cLfw0nrfgHaE47tDByV89FTnXXul1piqBjViEJhvuUkAOC7/ekwx+CzA5dzUFBSjiPpueo7BfG/9dqav6i+YLsoEdmUmY93QNegxoht76Nxv2pwohAEjV1Xob5u+Gp0lN5rqZ778/ge+OzvC2oTsumakbepq0zvNah+MWfXFQD8eOga+rTV3kVqqOpBt2owVj2vSF+Ojq1hoENENsXV0R7PRAdp3a/6PKiejDz3iXB89FcqPn5ad0tLVVlVZ9vbSbA6obvaMbr++Lf1eXJskUIw74R6OUWlZulK0vVvSYC4K4stOkRENsxepTvI08UBo3sE49v96Xgk3B/PRAdhRPdAk5Ze0PboaN7YWes5jHOsjzkXCK2kLznYELp6S6u36Nwzwwro1oQdwERkVXQlDxtCaifBjy/E4Jux3eHpIkPTRo449E4sZg/pWFG+luhjbK8Q0fcB9/N3+oV6I8zPDR0DNK8h9GxMC3TQsk8ikeDQ1FiM6RGstu+Blk0MvSWqRXIz5+gAwB8nMmpcRqlKoo+iWiQmCOabiNAaMdAhIqsS39EPABDm52ZyGd1DmqB3m6q8CEOTh1UlDgwFAKwc0w2bX+kNe6nmj1NHeynefbyD1nK8GjnCx128gHF7f3d899wDRtfJXIKauOg/qIFSCIJaIFEfjFlVtbK4vNooK0Ewf3BmTdh1RURWJbCJC45Nf6jW55hRzaOIbecDD+eK+XEkEkmNu6CqPzibNpJZdOFIfR5s64Vv95l30jpbIVcIqO+5vBV1VE1G1t612hCwRYeIrI6ni0xrC4qlqP5FbOyEbvpClvqWHCo1MXJzkUnRUsuwflthia4rc6sYGVZVR6GetkLVFgY6REQGUP0LucTMM8dqyqmoS6Z05Xk1csSZWfGYFBdqgRrVH4ev3kFqZkFdV0On6l1XbNEhIiK9VCd1M/eii2XVA537jyVDwo2l/+1qljqEN6taoFRmQmtZZSuHKSPWrMnag9dMXrustsjlAi7fKlJ+L8C8Q+KtDQMdIiIDCIKAqBaNAQBPRQUada6+Z395talxjXkmxXf0x/bEPkbVR5OX+7dWfv1Qe1+4GZkDVVll2w5zrEP
1tbMEwTxNOnWYNlYjDHSIiAygEAR8PbY7fh7fA8901z4hoSnKapjd6uRQ849y1ZaYoCYu2PL6g0adX9miw0kQ697bG06Kvr+Wc9csc/XYW+mSJBx1RURkALkCcJHZI/J+q445ldWgRQeoGMJeUxIAP4+PQUFxOXzcnXAj955R5ytbdBjn1DtrUq7i4JU7NS6nLkcC1gQDHSIiA5hzpE31OYCqBzqdAz2NKs/RDC06Mns7RLaomqTQ2JaZyh+PJQIdHzdH3CwoMX/BDciZjPwal2FvpYGOdbZDERHVspokc/q4Va2C/tFTnfHdc9Gi/eUqXVeT4kPxWmwbo8p3tK/6KNd3bpRKi1TP1k2VX7s5if/uNTZgsWQycqk5lvemGpNKrTPQYYsOEZEBapJGE9jEBYuf6QpPFwf0bO2ltl/1Qf5S39Zq+/VRHSXVMcBDx5HinNQerbyw5+JtABoCHSPrYMlk5IY8B0x9Yur8SnWNLTpERAao6cP2kU7+GoMcQL3rqpKhzxXVVhR9kw8KgoDxfVshzM8NT0U2V253vz/Tc1Whhl27quCK/1kiGbkBj4yuV5ijQ0Rkwyw5D4m2UVemXLJcob+bZ3J8GCbHhwEAhnerWK1dtXsN0L54apcgTxxNz1Wva+V5FngWMs6pH6w10GGLDhGRAeQW7D4Z17slACC+g59Bx8e289W6r62v7sVOq9/F/Cc7Yd7QcLXjNAUscR18sWF8D83lVubocCYdm+VQy8uumAtbdIiIDGDJNJGYVk1xaGosmrjI9B47onsghncLwvazWaLt2xP7IDOvWH+gY+B9aApXFj/TVWuysSV/PvV9bamGwkpTdBjoEBEZoluw+efPUeXVyFFtm6YHS5+2Pugc6Im34kLRvLGzcntrn0Zo7dPIbPXRFNBULqT6UHtfbDsjDrQEC3YwMcypH6w13mSgQ0SkQ/KbfbE37Taeimqu/+BaUNm6MaGf8aOzjKHrj/cvn43EP+ezMWbVQZV6Wa4u2sqW2duh1MiV5Ml0lgxmLck6O9yIiGpJsJcrnokOqjf5CTXtIjL0dF3dFBKJRO3nYclHoLYH7Jgewcqvn45qjk+HR1iwFmStLTr14zeXiIjUjI4Jhq+7o+iBXlurUOtLKq6+18/dSeNxlqRahzA/dzzaKUDjcRMt3PrVUDDQISIis2rsKsO+KQPw7uMdlNtqHOjUJBtZy/4wPzd8NTrK5Crpo1pl0fIYKnWwk2ivssyejzpzsNakcL77RET1WPWkYGMCncAmzmrbzNF1BQD3SuXKr395qafe0V41oVpnX7eqpG1Hle6zqOAmGuv8THRQnQ94rz7rtLWyzjCHgQ4RkVUxZuZhCST48tlIk66j7yp598qUXzuZYVFRnbQ8YZu4yrD37f74eXwMOjbz0DhSbO4T6nME1TZzrC5fH1hpgw4DHSIia/BCn5boEuSJ+I6GTSoIVLT+DOzgh5VjqrqVDO650hNQqQY6lljIU5W2ZGQ/D2cEeDqLVl3XxFrnf6lvrHXUlW20pxER2bgpg9oZfU7l+lz9w6pmUjb0YaUvNuigZ/FQYzwd1Rw/Hrqudb9qcCaRAC/3b42zGQWIbedjUPmtfSzXrWYIWwm0GlSLzuLFixEcHAwnJydER0fjwIEDBp23du1aSCQSDBkyRLRdEARMnz4d/v7+cHZ2RmxsLC5cuGBK1YiI6L6aDEXX93DuHtIEq8Z0w863+uktS3XUmCbGBk1vDAzFV6OjlBMY6hPXwRezBnfAh//pZNR1SMxaF5E3OtBZt24dEhMTMWPGDBw5cgSdO3dGXFwcbt68qfO8K1eu4M0330Tv3r3V9n3wwQdYtGgRli5div3798PV1RVxcXEoLi42tnpERHSfpsTl6ot3amPImlX9wnwQ1NTFoPKGRGge+g1oX729kupdmLKWlkQiwaiYYHQJ8tSwz+jiGjDrjHSMDnQWLFiAcePGISEhAe3bt8fSpUvh4uKClStXaj1HLpd
j5MiRmDlzJlq2bCnaJwgCFi5ciKlTp2Lw4MHo1KkTvv76a9y4cQMbN240+oaIiKiC6l/gqxO6oW+oN957oqNB56oGAE90aYatrz1ocj0EQcDUR9tr3a9t9XbV8y3l2PSBWJ3QzWLl25IG0XVVWlqKw4cPIzY2tqoAOzvExsYiJSVF63mzZs2Cj48Pxo4dq7bv8uXLyMzMFJXp4eGB6OhorWWWlJQgPz9f9CIiIjHVFp2+oT5YndAd/h7qQ871ebFPK4T6GZbnoqmFRCHozvkpN6ZFp0YtMOonezg7ILipa00KNfqqzz7QAg+HG55UXl9YaZxjXKBz69YtyOVy+Pr6irb7+voiMzNT4zm7d+/GihUrsHz5co37K88zpsx58+bBw8ND+QoMDDTmNoiIGoRGjqaPNzE1oND0V79CEHQOi6/edeUgFR8rCBWtSgDwQp9WplUM2u/JmCH75tDWtxHG9gqp1WuaAycM1KCgoADPPvssli9fDi8vL7OVO2XKFOTl5Slf165dM1vZRETWbtWYbmjv746l/zVtDh3AtFwYbQToDpzKqmW5ahquvuDpzjj57kBEqM6MbCRtVbB0nFO9/IjAxjq7gXzdHWFXD3OHrDPMMXJ4uZeXF6RSKbKyskTbs7Ky4Oen3gyXlpaGK1eu4LHHHlNuUygqInd7e3ukpqYqz8vKyoK/v7+ozIiICI31cHR0hKOjo8Z9REQNXb8wH/QLM2zotTamPvw1nRfTsqkocPrgP50waf0J5ffVu640XVoikcDNycG0SqmUUR+EN/fAoSs5Wvf/9XofnLmRjxHL99VirfRTWOmwK6NadGQyGSIjI5GUlKTcplAokJSUhJiYGLXjw8LCcPLkSRw7dkz5evzxx9GvXz8cO3YMgYGBCAkJgZ+fn6jM/Px87N+/X2OZRERkeeYICfa+3R9fPhuJRzv5iwrs2doL5+cMUn5fPRnZUvFIbbfofPCfTvBwdsAXI7sqt1Uuy6HrmvWxNQdoIC06AJCYmIjRo0cjKioK3bt3x8KFC1FUVISEhAQAwKhRo9CsWTPMmzcPTk5O6NhRnOHv6ekJAKLtr732GubMmYM2bdogJCQE06ZNQ0BAgNp8O0REVDvM0foR4Fkxc3FFeVXbBUGAzN4O0SFNsP9yDp6Kao7Ve68o99dWzkx4Mw+zXS/Awwk38sRTojwdFYinIpuLfpbl94M670bah/nbSST1cxbielglQxgd6AwbNgzZ2dmYPn06MjMzERERgS1btiiTidPT02FnZ1zqz6RJk1BUVITnn38eubm56NWrF7Zs2QInJ8PmeyAiIvMy9dGvLfdEUzDx/bgHcOduKbwaiVMRLBXmlKt0vUx9pB2Gdm1ecT0zXNDd2UEt0KkoW1x4ZetVUFMXLBwWgSauMoxaKZ50t7aTow1lpXGOaUtATJw4ERMnTtS4Lzk5Wee5q1evVtsmkUgwa9YszJo1y5TqEBGRmYlaYMzwiNP06JbaSdSCHMA8D/ohEQGYPChMtE11xfVRMcGQ2dvdr1vNr2doC5hcUZWPNOT+SDL1smpcHYvgqCsiIrIZpnZdaTvNqOLM8KCPbtlUbc4g1YCtMsgBTMuJCfV1EwVphpZRrmdyxIqyJHDQsbxFU1eZYRczM+sMcxjoEBGRHkFNDFvmwVzM0aChqfEhvJkHHuscgJf7tzbpggffqZrYViIRBzeGtkKVKXRPjlhRFhAZ1Bh9Q70xvJv6PHF/v9EX655/wKDrmZOmJUWsAVcvJyIijU7NjENZuQIuMsMfFZ4umoeAG/OMtLPQsCOJRILPRnRR364S6YQ388DJf/PUjpnQrxW83cTdbKrBjaEtVoa26NjZSbA6oTvulcqx9qB4rjgPFwd0bdHYsAuakZXGOWzRISIizRo52qOxkd0kHQI88FZcKBYOizD5urWdoqIaVy3SEAgBQFtf8RIYEolEdJ6hXX2qXWbaSCSav1ZVFwnLVhrnMNAhIiLzmtCvtVqireq
DWaqnxUY1aHB2kGLlmCjzVlDH9aQSCVr7NNJ67LMPtAAAvBXXVnSevkaoFaOj0LyxM9b8r7tR9dF6jN4jzOudh9vh7fgw/QfWQ+y6IiIii3OWSfFMdBBKyhR6FxZVfYifmhmnNzCqqeqlfzUqCh9uTcX4vq3w6Ge7RftmDe6AxIfaorGrDDPsTiu362thGdDOFwPa+eo8RmPdzJHcfV8zT2e09HbFrgu3jDpPJrXDuAdbGn/BeoKBDhER1Yq5T4QbdJzqQ9zSQU7160kkQHBTVyxWmc1YfKxE2Z0n1dGiM6ZHsHnqVi0Me/ex9sp6GKukXH8itJZKWDV2XRERUb3SLbgJAMDJoXYeUapBg1FJ06JkZHE0MON+QFJTqsXOfSIcY3oavur5Y50DMLRrVRdiSblcx9E66mDSWfUHAx0iIqpX5g0Nx8v9W+PPVx+sleuZYwHT6i065lpA1F6l4NuFJUad29jFAQuejlB+X1ImbtHpG+qtds7Q+7lVsSrdbDIdc/pYA+uuPRER2RxPFxneGBiKEC/XWrmeakhiTHxiJ+q6Mk9gU33Ul0QiQYBHxXJI0S2b1qjs0mqrxHcIcFc75r0nwrH0v5H4dHiEcpuDASPF6jPm6BARUYNmZ2LXlWr+kGoZP4+PMakeT3Rphsc7B6ht3/xqb1y9fRedmnsYVV5ljXq19sLui7cwOCIAOUWlKvvVgzNnmRTxHf1E2xyk1t15xUCHiIgaNNO7rjRPGBjZokkNayTm6SKDp4v2+Yxej22LguIyfLX7ssb9i0d2RdLZLAzs4Ifx3x5Wbjf0vg2Z+6c+s+7aExER1ZBqy4a2h7+mtadMWQLCEpo3dsbUR7UnP3s4O2Bo1+Zo5Chu2zC0xtaeo8MWHSIisjnGxB2ildqrdV2N79sKR67ewUPt1efAsTNiwsC64CSTqm2TiqdzFu1bldBNYzm6Fhi1Bgx0iIiozj0THYTv96cjtp1PjcoZ0yNYmY9iDpN1zAZcX1p0NOkc6InxfVqpbVed+6d6jfuFav7ZW3vXFQMdIiKqczMea4+H2vkiumXN8lvefbxDjc43Jl4J9XPD8et5Rp9naWF+btg0oafGfaoLphpa5zA/N/0H1WMMdIiIqM452kvRL6xmrTmmUu2aaWLEIqbvPNwezg5SPNG1ORbvuGiJqhnEmCBL3KKj+8QNL/XAL0f+xZtxoaZWrV5goENERA2a1E6CnW/1Q7lCAVdHwx+LHi4OmDm4I4D6maOjidSIFp2uQY3RNaixhWtkeQx0iIiowQtq6lKj840JkLRx0ZA8bIouQZ5a94kCHbNcrf5joENERFRDb8eH4XxWAUZGtzD63PlDw7Hu0DUkPtS2RnXYnvggfjueged6a18Py5gWHVvBQIeIiKiGfNyd8PvLvU06d3j3IAzvHlTjOrT2ccPrD+lOHNa2EGlL79pZbqMuWPeYMSIiIjKYtilxNr9iWpBmDRjoEBERWSF3p4pOmQeMWOxTW9eVk4N58oPqI3ZdERERWaH9/xeLguIy+Lg7GXyOnRHDy20FAx0iIiIr5CyTwtnIkVriFdfNXaP6iV1XREREDYSdlhXXbRkDHSIiogaiMq+nIWl4d0xERNRAjXuwJfZdysGjnf1RWq6o6+rUCgY6REREDYSbkwN+fDEGALDrQnYd16Z2MNAhIiJqgHq19sLHT3VGqJWvTq4PAx0iIqIGSCKR4MnI5nVdDYtjMjIRERHZLAY6REREZLMY6BAREZHNYqBDRERENsukQGfx4sUIDg6Gk5MToqOjceDAAa3HbtiwAVFRUfD09ISrqysiIiLwzTffiI4ZM2YMJBKJ6BUfH29K1YiIiIiUjB51tW7dOiQmJmLp0qWIjo7GwoULERcXh9TUVPj4+Kgd36RJE7zzzjsICwuDTCbD77//joSEBPj4+CAuLk55XHx8PFatWqX83tHR0cRbIiIiIqogEQRBMOaE6OhodOvWDZ9//jkAQKFQIDAwEC+//DL
efvttg8ro2rUrHnnkEcyePRtARYtObm4uNm7caFzt78vPz4eHhwfy8vLg7u5uUhlERERUu2rj+W1U11VpaSkOHz6M2NjYqgLs7BAbG4uUlBS95wuCgKSkJKSmpuLBBx8U7UtOToaPjw9CQ0Mxfvx43L59W2s5JSUlyM/PF72IiIiIqjOq6+rWrVuQy+Xw9fUVbff19cW5c+e0npeXl4dmzZqhpKQEUqkUX3zxBR566CHl/vj4eAwdOhQhISFIS0vD//3f/2HQoEFISUmBVKq+BP28efMwc+ZMY6pOREREDVCtzIzs5uaGY8eOobCwEElJSUhMTETLli3Rt29fAMDw4cOVx4aHh6NTp05o1aoVkpOTMWDAALXypkyZgsTEROX3+fn5CAwMtPh9EBERkXUxKtDx8vKCVCpFVlaWaHtWVhb8/Py0nmdnZ4fWrVsDACIiInD27FnMmzdPGehU17JlS3h5eeHixYsaAx1HR0cmKxMREZFeRuXoyGQyREZGIikpSblNoVAgKSkJMTExBpejUChQUlKidf/169dx+/Zt+Pv7G1M9IiIiIhGju64SExMxevRoREVFoXv37li4cCGKioqQkJAAABg1ahSaNWuGefPmAajIp4mKikKrVq1QUlKCzZs345tvvsGSJUsAAIWFhZg5cyaefPJJ+Pn5IS0tDZMmTULr1q1Fw8+JiIiIjGV0oDNs2DBkZ2dj+vTpyMzMREREBLZs2aJMUE5PT4edXVVDUVFREV566SVcv34dzs7OCAsLw7fffothw4YBAKRSKU6cOIE1a9YgNzcXAQEBGDhwIGbPnm1w91TlCHmOviIiIrIelc9tI2e6MYrR8+jUR9evX2cyMhERkZW6du0amjdvbpGybSLQUSgUuHHjBtzc3CCRSMxaduWIrmvXrtn8ZIS8V9vTUO4T4L3aooZyn0DDvVc3NzcUFBQgICBA1BtkTrUyvNzS7OzsLBYJVnJ3d7f5f3yVeK+2p6HcJ8B7tUUN5T6BhnmvHh4eFr0OVy8nIiIim8VAh4iIiGwWAx09HB0dMWPGjAYxQSHv1fY0lPsEeK+2qKHcJ8B7tSSbSEYmIiIi0oQtOkRERGSzGOgQERGRzWKgQ0RERDaLgQ4RERHZLAY6eixevBjBwcFwcnJCdHQ0Dhw4UNdVMsq8efPQrVs3uLm5wcfHB0OGDEFqaqromL59+0IikYheL774ouiY9PR0PPLII3BxcYGPjw/eeustlJeX1+at6PTuu++q3UNYWJhyf3FxMSZMmICmTZuiUaNGePLJJ5GVlSUqo77fY6Xg4GC1e5VIJJgwYQIA634/d+7cicceewwBAQGQSCTYuHGjaL8gCJg+fTr8/f3h7OyM2NhYXLhwQXRMTk4ORo4cCXd3d3h6emLs2LEoLCwUHXPixAn07t0bTk5OCAwMxAcffGDpW1Oj617LysowefJkhIeHw9XVFQEBARg1ahRu3LghKkPTv4X58+eLjqnre9X3no4ZM0btHuLj40XH2MJ7CkDj761EIsGHH36oPMYa3lNDnivm+sxNTk5G165d4ejoiNatW2P16tXGV1ggrdauXSvIZDJh5cqVwunTp4Vx48YJnp6eQlZWVl1XzWBxcXHCqlWrhFOnTgnHjh0THn74YSEoKEgoLCxUHtOnTx9h3LhxQkZGhvKVl5en3F9eXi507NhRiI2NFY4ePSps3rxZ8PLyEqZMmVIXt6TRjBkzhA4dOojuITs7W7n/xRdfFAIDA4WkpCTh0KFDwgMPPCD06NFDud8a7rHSzZs3Rfe5bds2AYCwY8cOQRCs+/3cvHmz8M477wgbNmwQAAi//PKLaP/8+fMFDw8PYePGjcLx48eFxx9/XAgJCRHu3bunPCY+Pl7o3LmzsG/fPmHXrl1C69athREjRij35+XlCb6+vsLIkSOFU6dOCT/88IPg7OwsLFu2rLZuUxAE3feam5srxMbGCuvWrRPOnTsnpKSkCN27dxciIyNFZbR
o0UKYNWuW6L1W/d2uD/eq7z0dPXq0EB8fL7qHnJwc0TG28J4KgiC6x4yMDGHlypWCRCIR0tLSlMdYw3tqyHPFHJ+5ly5dElxcXITExEThzJkzwmeffSZIpVJhy5YtRtWXgY4O3bt3FyZMmKD8Xi6XCwEBAcK8efPqsFY1c/PmTQGA8M8//yi39enTR3j11Ve1nrN582bBzs5OyMzMVG5bsmSJ4O7uLpSUlFiyugabMWOG0LlzZ437cnNzBQcHB+Gnn35Sbjt79qwAQEhJSREEwTruUZtXX31VaNWqlaBQKARBsI33UxAEtQeFQqEQ/Pz8hA8//FC5LTc3V3B0dBR++OEHQRAE4cyZMwIA4eDBg8pj/vzzT0EikQj//vuvIAiC8MUXXwiNGzcW3evkyZOF0NBQC9+RdpoeitUdOHBAACBcvXpVua1FixbCJ598ovWc+nav2gKdwYMHaz3Hlt/TwYMHC/379xdts7b3VBDUnyvm+sydNGmS0KFDB9G1hg0bJsTFxRlVP3ZdaVFaWorDhw8jNjZWuc3Ozg6xsbFISUmpw5rVTF5eHgCgSZMmou3fffcdvLy80LFjR0yZMgV3795V7ktJSUF4eDh8fX2V2+Li4pCfn4/Tp0/XTsUNcOHCBQQEBKBly5YYOXIk0tPTAQCHDx9GWVmZ6L0MCwtDUFCQ8r20lnusrrS0FN9++y3+97//iRa0tYX3s7rLly8jMzNT9D56eHggOjpa9D56enoiKipKeUxsbCzs7Oywf/9+5TEPPvggZDKZ8pi4uDikpqbizp07tXQ3xsvLy4NEIoGnp6do+/z589G0aVN06dIFH374oajp31ruNTk5GT4+PggNDcX48eNx+/Zt5T5bfU+zsrLwxx9/YOzYsWr7rO09rf5cMddnbkpKiqiMymOMfQbbxKKelnDr1i3I5XLRmwAAvr6+OHfuXB3VqmYUCgVee+019OzZEx07dlRuf+aZZ9CiRQsEBATgxIkTmDx5MlJTU7FhwwYAQGZmpsafQ+W++iA6OhqrV69GaGgoMjIyMHPmTPTu3RunTp1CZmYmZDKZ2gPC19dXWX9ruEdNNm7ciNzcXIwZM0a5zRbeT00q66ap7qrvo4+Pj2i/vb09mjRpIjomJCRErYzKfY0bN7ZI/WuiuLgYkydPxogRI0QLPr7yyivo2rUrmjRpgr1792LKlCnIyMjAggULAFjHvcbHx2Po0KEICQlBWloa/u///g+DBg1CSkoKpFKpzb6na9asgZubG4YOHSrabm3vqabnirk+c7Udk5+fj3v37sHZ2dmgOjLQaUAmTJiAU6dOYffu3aLtzz//vPLr8PBw+Pv7Y8CAAUhLS0OrVq1qu5omGTRokPLrTp06ITo6Gi1atMCPP/5o8C+DNVqxYgUGDRqEgIAA5TZbeD+pSllZGZ5++mkIgoAlS5aI9iUmJiq/7tSpE2QyGV544QXMmzfPapYSGD58uPLr8PBwdOrUCa1atUJycjIGDBhQhzWzrJUrV2LkyJFwcnISbbe291Tbc6U+YdeVFl5eXpBKpWpZ4llZWfDz86ujWplu4sSJ+P3337Fjxw40b95c57HR0dEAgIsXLwIA/Pz8NP4cKvfVR56enmjbti0uXrwIPz8/lJaWIjc3V3SM6ntpjfd49epVbN++Hc8995zO42zh/QSq6qbrd9LPzw83b94U7S8vL0dOTo5VvteVQc7Vq1exbds2UWuOJtHR0SgvL8eVK1cAWNe9VmrZsiW8vLxE/15t6T0FgF27diE1NVXv7y5Qv99Tbc8Vc33majvG3d3dqD9gGehoIZPJEBkZiaSkJOU2hUKBpKQkxMTE1GHNjCMIAiZOnIhffvkFf//9t1qTpybHjh0DAPj7+wMAYmJicPLkSdGHTeWHbvv27S1S75oqLCxEWloa/P39ERkZCQcHB9F7mZqaivT0dOV7aY33uGrVKvj4+OCRRx7ReZwtvJ8AEBISAj8/P9H7mJ+fj/3794v
ex9zcXBw+fFh5zN9//w2FQqEM+GJiYrBz506UlZUpj9m2bRtCQ0PrVRdHZZBz4cIFbN++HU2bNtV7zrFjx2BnZ6fs6rGWe1V1/fp13L59W/Tv1Vbe00orVqxAZGQkOnfurPfY+vie6nuumOszNyYmRlRG5TFGP4ONz69uONauXSs4OjoKq1evFs6cOSM8//zzgqenpyhLvL4bP3684OHhISQnJ4uGK969e1cQBEG4ePGiMGvWLOHQoUPC5cuXhU2bNgktW7YUHnzwQWUZlcMABw4cKBw7dkzYsmWL4O3tXS+GI1d64403hOTkZOHy5cvCnj17hNjYWMHLy0u4efOmIAgVQx2DgoKEv//+Wzh06JAQExMjxMTEKM+3hntUJZfLhaCgIGHy5Mmi7db+fhYUFAhHjx4Vjh49KgAQFixYIBw9elQ50mj+/PmCp6ensGnTJuHEiRPC4MGDNQ4v79Kli7B//35h9+7dQps2bURDkXNzcwVfX1/h2WefFU6dOiWsXbtWcHFxqfWhyLrutbS0VHj88ceF5s2bC8eOHRP97laOSNm7d6/wySefCMeOHRPS0tKEb7/9VvD29hZGjRpVr+5V130WFBQIb775ppCSkiJcvnxZ2L59u9C1a1ehTZs2QnFxsbIMW3hPK+Xl5QkuLi7CkiVL1M63lvdU33NFEMzzmVs5vPytt94Szp49KyxevJjDyy3hs88+E4KCggSZTCZ0795d2LdvX11XySgANL5WrVolCIIgpKenCw8++KDQpEkTwdHRUWjdurXw1ltvieZdEQRBuHLlijBo0CDB2dlZ8PLyEt544w2hrKysDu5Is2HDhgn+/v6CTCYTmjVrJgwbNky4ePGicv+9e/eEl156SWjcuLHg4uIiPPHEE0JGRoaojPp+j6q2bt0qABBSU1NF2639/dyxY4fGf6+jR48WBKFiiPm0adMEX19fwdHRURgwYIDaz+D27dvCiBEjhEaNGgnu7u5CQkKCUFBQIDrm+PHjQq9evQRHR0ehWbNmwvz582vrFpV03evly5e1/u5Wzpd0+PBhITo6WvDw8BCcnJyEdu3aCXPnzhUFCPXhXnXd5927d4WBAwcK3t7egoODg9CiRQth3Lhxan9M2sJ7WmnZsmWCs7OzkJubq3a+tbyn+p4rgmC+z9wdO3YIERERgkwmE1q2bCm6hqEk9ytNREREZHOYo0NEREQ2i4EOERER2SwGOkRERGSzGOgQERGRzWKgQ0RERDaLgQ4RERHZLAY6REREZLMY6BAREZHNYqBDRERENouBDhEREdksBjpERERksxjoEBERkc36fwSVBso19+9uAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(train_metrics_history[\"train_loss\"], label=\"Loss value during the training\")\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA0MAAAHDCAYAAADm78EeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB9c0lEQVR4nO3deVxU9f4/8NfMAMM+7DsCooC4oKIi7iWKpqblWhZKaf28WhatVm653frevHbLIr1607LUzMzULKW0TBQVd3EBZFV2YVhkmzm/P4CpCVBA4Awzr+fjcR7lmc858z6Dng/v+XzO+yMRBEEAERERERGRgZGKHQAREREREZEYmAwREREREZFBYjJEREREREQGickQEREREREZJCZDRERERERkkJgMERERERGRQWIyREREREREBonJEBERERERGSQmQ0REREREZJCYDBE1YtmyZZBIJGKHQURERERthMkQ3dfnn38OiUSC06dPix0KtaKysjIsW7YMR44cadP3OXDgAJYtW9am79GYTz75BJ9//rko701E4vrkk08gkUgQEhIidijUgOPHj2PZsmUoLCxs0/dZvXo19uzZ06bv0ZBbt25h2bJlOHfuXLu/NzUPkyEiA1VWVobly5e3SzK0fPnyNn2PxjAZIjJc27Ztg7e3N+Li4pCYmCh2OPQ3x48fx/Lly/U6GVq+fDmToQ6AyRARERHplZs3b+L48eNYu3YtHB0dsW3bNrFDalRpaanYIRAZNCZD1GrOnj2LsWPHwtraGpaWlhg5ciROnDih1aaqqgrLly9H165dYWpqCnt7ewwZMgSHDh3StMnKykJkZCQ8PDwgl8vh6uqKiRMnIiUlpdH3/te//gWJRILU1NR6ry1atAgmJia4c+cOAOD333/H1KlT0alTJ8jlcnh6euLll1/G3bt373l9KSkpkEgkDY40SCSSelPBMjMz8cwzz8DZ2RlyuRzdu3fH5s2b7/kedaqrq7FixQr4+vpCLpfD29sbb731FioqKrTaeXt7Y/z48Th27BgGDBgAU1NTdO7cGVu3br3vtTg6OgIAli9fDolEUu8arl69iilTpsDOzg6mpqbo168f9u7dq3We+/08Z8+ejfXr12s+o7rtXk6fPo3w8HA4ODjAzMwMPj4+eOaZZ7TaqNVqrFu3Dt27d4epqSmcnZ3x/PPPa37GdZ/N5cuXcfToUc37jhgx4p7vTUT6Ydu2bbC1tcW4ceMwZcqURpOhwsJCvPzyy/D29oZcLoeHhwciIiKQl5enaVNeXo5ly5bBz88PpqamcHV1xeOPP46kpCQAwJEjRyCRSOqNsjfUZ8yePRuWlpZISkrCI488AisrK8ycORNA8/qmq1evYtq0aXB0dISZmRn8/f3x9ttvAwB+/fVXSCQSfPfdd/WO++qrryCRSBAbG3vPzy85ORlTp06FnZ0dzM3NMXDgQOzfv1+rTd1179y5E6tWrYKHhwdMTU0xcuTI+47ELVu2DK+99hoAwMfHR3OP/ms//+WXXyI4OBhmZmaws7PDjBkzkJ6ernWeGzduYPLkyXBxcYGpqSk8PDwwY8YMFBUVAajpd0pLS7FlyxbNe8yePf
uesX300Ufo3r07zM3NYWtri379+uGrr77SanO//v3IkSPo378/ACAyMlLz3pypoJuMxA6A9MPly5cxdOhQWFtb4/XXX4exsTE+++wzjBgxAkePHtXM2V62bBnWrFmDOXPmYMCAAVAqlTh9+jTi4+MxatQoAMDkyZNx+fJlvPDCC/D29kZOTg4OHTqEtLQ0eHt7N/j+06ZNw+uvv46dO3dqbrB1du7cidGjR8PW1hYA8M0336CsrAzz5s2Dvb094uLi8NFHHyEjIwPffPNNq3we2dnZGDhwICQSCRYsWABHR0f8+OOPePbZZ6FUKvHSSy/d8/g5c+Zgy5YtmDJlCl555RWcPHkSa9asQUJCQr0OLjExEVOmTMGzzz6LWbNmYfPmzZg9ezaCg4PRvXv3Bs/v6OiITz/9FPPmzcNjjz2Gxx9/HADQq1cvADU/z8GDB8Pd3R1vvvkmLCwssHPnTkyaNAnffvstHnvsMQD3/3k+//zzuHXrFg4dOoQvvvjivp9bTk4ORo8eDUdHR7z55puwsbFBSkoKdu/erdXu+eefx+eff47IyEi8+OKLuHnzJj7++GOcPXsWf/zxB4yNjbFu3Tq88MILsLS01PyS4OzsfN8YiKjj27ZtGx5//HGYmJjgiSeewKeffopTp05pfkEFgJKSEgwdOhQJCQl45pln0LdvX+Tl5WHv3r3IyMiAg4MDVCoVxo8fj5iYGMyYMQMLFy5EcXExDh06hEuXLsHX17fZsVVXVyM8PBxDhgzBv/71L5ibmwNoet904cIFDB06FMbGxnjuuefg7e2NpKQk/PDDD1i1ahVGjBgBT09PbNu2TXOv/uvn4uvri9DQ0Ebjy87OxqBBg1BWVoYXX3wR9vb22LJlCx599FHs2rWr3jn/+c9/QiqV4tVXX0VRURHef/99zJw5EydPnmz0PR5//HFcv34dX3/9Nf7973/DwcEBADRf0q1atQqLFy/GtGnTMGfOHOTm5uKjjz7CsGHDcPbsWdjY2KCyshLh4eGoqKjACy+8ABcXF2RmZmLfvn0oLCyEQqHAF198oemfnnvuOQC4589s48aNePHFFzFlyhQsXLgQ5eXluHDhAk6ePIknn3xS8/ncr3/v1q0b3n33XSxZsgTPPfcchg4dCgAYNGhQo+9NIhKI7uN///ufAEA4depUo20mTZokmJiYCElJSZp9t27dEqysrIRhw4Zp9gUFBQnjxo1r9Dx37twRAAj/93//1+w4Q0NDheDgYK19cXFxAgBh69atmn1lZWX1jl2zZo0gkUiE1NRUzb6lS5cKf/0ncvPmTQGA8L///a/e8QCEpUuXav787LPPCq6urkJeXp5WuxkzZggKhaLBGOqcO3dOACDMmTNHa/+rr74qABB++eUXzT4vLy8BgPDbb79p9uXk5AhyuVx45ZVXGn0PQRCE3NzcenHXGTlypNCzZ0+hvLxcs0+tVguDBg0Sunbtqtl3v5+nIAjC/Pnzhabear777rv7/l37/fffBQDCtm3btPYfPHiw3v7u3bsLw4cPb9J7E5F+OH36tABAOHTokCAINfcuDw8PYeHChVrtlixZIgAQdu/eXe8carVaEARB2Lx5swBAWLt2baNtfv31VwGA8Ouvv2q93lCfMWvWLAGA8Oabb9Y7X1P7pmHDhglWVlZa+/4ajyAIwqJFiwS5XC4UFhZq9uXk5AhGRkYN3vP/6qWXXhIACL///rtmX3FxseDj4yN4e3sLKpVK67q7desmVFRUaNp++OGHAgDh4sWL93yf//u//xMACDdv3tTan5KSIshkMmHVqlVa+y9evCgYGRlp9p89e1YAIHzzzTf3fB8LCwth1qxZ92xTZ+LEiUL37t3v2aap/fupU6ca/Z2BdAunydEDU6lU+PnnnzFp0iR07txZs9/V1RVPPvkkjh07BqVSCQCwsbHB5cuXcePGjQbPZWZmBhMTExw5ckRrylNTTJ8+HWfOnNFMXQCAHTt2QC6XY+LEiV
rvUae0tBR5eXkYNGgQBEHA2bNnm/WeDREEAd9++y0mTJgAQRCQl5en2cLDw1FUVIT4+PhGjz9w4AAAICoqSmv/K6+8AgD1pioEBgZqvnUCar5Z8/f3R3JycoviLygowC+//IJp06ahuLhYE3t+fj7Cw8Nx48YNZGZmArj/z7O5bGxsAAD79u1DVVVVg22++eYbKBQKjBo1SuuzDQ4OhqWlJX799ddWiYWIOqZt27bB2dkZDz30EICaqVLTp0/H9u3boVKpNO2+/fZbBAUF1RvpqDumro2DgwNeeOGFRtu0xLx58+rta0rflJubi99++w3PPPMMOnXq1Gg8ERERqKiowK5duzT7duzYgerqajz11FP3jO3AgQMYMGAAhgwZotlnaWmJ5557DikpKbhy5YpW+8jISJiYmGj+XNcftbQP2r17N9RqNaZNm6Z1j3dxcUHXrl0193iFQgEA+Omnn1BWVtai9/o7GxsbZGRk4NSpUw2+/qD9O+kmJkP0wHJzc1FWVgZ/f/96r3Xr1g1qtVozz/fdd99FYWEh/Pz80LNnT7z22mu4cOGCpr1cLsd7772HH3/8Ec7Ozhg2bBjef/99ZGVl3TeOqVOnQiqVYseOHQBqblrffPON5jmmOmlpaZg9ezbs7OxgaWkJR0dHDB8+HAA084wfRG5uLgoLC7FhwwY4OjpqbZGRkQBqpoM1JjU1FVKpFF26dNHa7+LiAhsbm3rPRf29QwQAW1vbZieTdRITEyEIAhYvXlwv/qVLl2rFf7+fZ3MNHz4ckydPxvLly+Hg4ICJEyfif//7n9azUjdu3EBRURGcnJzqxVdSUnLPz5aI9JtKpcL27dvx0EMP4ebNm0hMTERiYiJCQkKQnZ2NmJgYTdukpCT06NHjnudLSkqCv78/jIxa76kCIyMjeHh41NvflL6pLsG4X9wBAQHo37+/1rNS27Ztw8CBA+v1LX+XmpraaH9e9/pf/b0PqpuS3tI+6MaNGxAEAV27dq13j09ISNDc4318fBAVFYX//ve/cHBwQHh4ONavX/9A/fgbb7wBS0tLDBgwAF27dsX8+fPxxx9/aF5/0P6ddBOfGaJ2NWzYMCQlJeH777/Hzz//jP/+97/497//jejoaMyZMwcA8NJLL2HChAnYs2cPfvrpJyxevBhr1qzBL7/8gj59+jR6bjc3NwwdOhQ7d+7EW2+9hRMnTiAtLQ3vvfeepo1KpcKoUaNQUFCAN954AwEBAbCwsEBmZiZmz54NtVrd6Pkb+xbwr980AtCc46mnnsKsWbMaPKbu2Zx7aeq3jjKZrMH9giA06fi/q4v/1VdfRXh4eINt6jrTpvw8m0MikWDXrl04ceIEfvjhB/z000945pln8MEHH+DEiROwtLSEWq2Gk5NTow9E1805JyLD88svv+D27dvYvn07tm/fXu/1bdu2YfTo0a36nk3tG+rI5XJIpdJ6bVvaNzUmIiICCxcuREZGBioqKnDixAl8/PHHzT7P/bRFHySRSPDjjz82eG5LS0vN/3/wwQeYPXu2pg968cUXsWbNGpw4caLBhPN+unXrhmvXrmHfvn04ePAgvv32W3zyySdYsmQJli9f3mr9O+kWJkP0wBwdHWFubo5r167Ve+3q1auQSqXw9PTU7LOzs0NkZCQiIyNRUlKCYcOGYdmyZVq/PPv6+uKVV17BK6+8ghs3bqB379744IMP8OWXX94zlunTp+Mf//gHrl27hh07dsDc3BwTJkzQvH7x4kVcv34dW7ZsQUREhGb/X6vZNabu266/r4nw92/JHB0dYWVlBZVKhbCwsPue9++8vLygVqtx48YNzTdxQM1Dm4WFhfDy8mr2ORvSWAdeN9XR2Ni4SfHf7+fZkqkkAwcOxMCBA7Fq1Sp89dVXmDlzJrZv3445c+bA19cXhw8fxuDBg7WmlTTkQaaxEFHHs23bNjg5OWmqWP7V7t278d133yE6OhpmZmbw9fXFpUuX7nk+X19fnDx5El
VVVTA2Nm6wTVP7hntpat9Ud3++X9wAMGPGDERFReHrr7/G3bt3YWxsjOnTp9/3OC8vr0b787rXW0Nj92dfX18IggAfHx/4+fnd9zw9e/ZEz5498c477+D48eMYPHgwoqOjsXLlynu+T2MsLCwwffp0TJ8+HZWVlXj88cexatUqLFq0qFn9O/ufjoPT5OiByWQyjB49Gt9//71WWczs7Gx89dVXGDJkiGaaWn5+vtaxlpaW6NKli2YaVFlZGcrLy7Xa+Pr6wsrKql5Z6YZMnjwZMpkMX3/9Nb755huMHz8eFhYWWrEC2t9YCYKADz/88L7ntra2hoODA3777Tet/Z988onWn2UyGSZPnoxvv/22wQ4rNzf3nu/zyCOPAADWrVuntX/t2rUAgHHjxt031qaoq2D09w7cyckJI0aMwGeffYbbt2/XO+6v8d/v5wlA8/k3ZWG9O3fu1Ps2sXfv3gCgOee0adOgUqmwYsWKesdXV1drvY+FhUWbL+hHRLrh7t272L17N8aPH48pU6bU2xYsWIDi4mLNEgGTJ0/G+fPnGyxBXXcfmjx5MvLy8hocUalr4+XlBZlMdt++4V6a2jc5Ojpi2LBh2Lx5M9LS0hqMp46DgwPGjh2LL7/8Etu2bcOYMWM0Vdvu5ZFHHkFcXJxW+e3S0lJs2LAB3t7eCAwMbPJ13UtjfcPjjz8OmUyG5cuX17smQRA0/Y5SqUR1dbXW6z179oRUKq3XBzW1H/h7n2ZiYoLAwEAIgoCqqqpm9e/N6ftIXBwZoibbvHkzDh48WG//woULsXLlShw6dAhDhgzBP/7xDxgZGeGzzz5DRUUF3n//fU3bwMBAjBgxAsHBwbCzs8Pp06exa9cuLFiwAABw/fp1jBw5EtOmTUNgYCCMjIzw3XffITs7GzNmzLhvjE5OTnjooYewdu1aFBcX1/sWLCAgAL6+vnj11VeRmZkJa2trfPvtt02e2zxnzhz885//xJw5c9CvXz/89ttvuH79er12//znP/Hrr78iJCQEc+fORWBgIAoKChAfH4/Dhw+joKCg0fcICgrCrFmzsGHDBhQWFmL48OGIi4vDli1bMGnSJM1DwQ/KzMwMgYGB2LFjB/z8/GBnZ4cePXqgR48eWL9+PYYMGYKePXti7ty56Ny5M7KzsxEbG4uMjAycP38ewP1/ngAQHBwMAHjxxRcRHh4OmUzW6M9yy5Yt+OSTT/DYY4/B19cXxcXF2LhxI6ytrTVJ4vDhw/H8889jzZo1OHfuHEaPHg1jY2PcuHED33zzDT788ENMmTJF896ffvopVq5ciS5dusDJyQkPP/xwq3x+RKRb9u7di+LiYjz66KMNvj5w4EDNAqzTp0/Ha6+9hl27dmHq1Kl45plnEBwcjIKCAuzduxfR0dEICgpCREQEtm7diqioKMTFxWHo0KEoLS3F4cOH8Y9//AMTJ06EQqHA1KlT8dFHH0EikcDX1xf79u1r1rMjzemb/vOf/2DIkCHo27cvnnvuOfj4+CAlJQX79+/HuXPntNpGRERo7ocNfYHUkDfffBNff/01xo4dixdffBF2dnbYsmULbt68iW+//bbeFL+Wqusb3n77bcyYMQPGxsaYMGECfH19sXLlSixatAgpKSmYNGkSrKyscPPmTXz33Xd47rnn8Oqrr+KXX37BggULMHXqVPj5+aG6uhpffPGFJmH56/scPnwYa9euhZubG3x8fDTLffzd6NGj4eLigsGDB8PZ2RkJCQn4+OOPMW7cOFhZWQFoev/u6+sLGxsbREdHw8rKChYWFggJCYGPj0+rfH7Uitqxch11UHWltRvb0tPTBUEQhPj4eCE8PFywtLQUzM3NhYceekg4fvy41rlWrlwpDBgwQLCxsRHMzMyEgIAAYdWqVUJlZaUgCIKQl5cnzJ8/XwgICBAsLCwEhUIhhISECDt37mxyvBs3bhQACFZWVsLdu3frvX7lyhUhLCxMsLS0FBwcHIS5c+cK58+fr1
cC8++ltQWhpvTps88+KygUCsHKykqYNm2akJOT02CJ6uzsbGH+/PmCp6enYGxsLLi4uAgjR44UNmzYcN9rqKqqEpYvXy74+PgIxsbGgqenp7Bo0SKtUteCUFNau6HS1sOHD29SSenjx48LwcHBgomJSb1rSEpKEiIiIgQXFxfB2NhYcHd3F8aPHy/s2rVL0+Z+P09BEITq6mrhhRdeEBwdHQWJRHLPMtvx8fHCE088IXTq1EmQy+WCk5OTMH78eOH06dP12m7YsEEIDg4WzMzMBCsrK6Fnz57C66+/Lty6dUvTJisrSxg3bpxgZWUlAGCZbSI9NmHCBMHU1FQoLS1ttM3s2bMFY2NjTVnk/Px8YcGCBYK7u7tgYmIieHh4CLNmzdIqm1xWVia8/fbbmvuxi4uLMGXKFK2lJHJzc4XJkycL5ubmgq2trfD8888Lly5darC0toWFRYOxNbVvEgRBuHTpkvDYY48JNjY2gqmpqeDv7y8sXry43jkrKioEW1tbQaFQNNgfNiYpKUmYMmWK5vwDBgwQ9u3bp9WmrrT230tb32sZir9bsWKF4O7uLkil0npltr/99lthyJAhgoWFhWBhYSEEBAQI8+fPF65duyYIgiAkJycLzzzzjODr6yuYmpoKdnZ2wkMPPSQcPnxY6z2uXr0qDBs2TDAzMxMA3LPM9meffSYMGzZMsLe3F+RyueDr6yu89tprQlFRkVa7pvbv33//vRAYGCgYGRmxzLYOkwhCC59wIyIiIiKdVV1dDTc3N0yYMAGbNm0SOxwincRnhoiIiIj00J49e5Cbm6tVlIGItHFkiIiIiEiPnDx5EhcuXMCKFSvg4ODAhUCJ7oEjQ0RERER65NNPP8W8efPg5OSErVu3ih0OkU7jyBARERERERkkjgwREREREZFBYjJEREREREQGSS8WXVWr1bh16xasrKwgkUjEDoeIyKAIgoDi4mK4ubm12oKM+oB9ExGROJrTL+lFMnTr1i14enqKHQYRkUFLT0+Hh4eH2GHoDPZNRETiakq/pBfJkJWVFYCaC7a2thY5GiIiw6JUKuHp6am5F1MN9k1EROJoTr+kF8lQ3fQDa2trdjhERCLhVDBt7JuIiMTVlH6Jk7uJiIiIiMggMRkiIiIiIiKDxGSIiIiIiIgMEpMhIiIiIiIySEyGiIiIiIjIIDEZIiIiIiIig8RkiIiIiIiIDBKTISIiIiIiMkhMhoiIiIiIyCAxGSIiIr2wfv16eHt7w9TUFCEhIYiLi7tn+3Xr1sHf3x9mZmbw9PTEyy+/jPLycs3ry5Ytg0Qi0doCAgLa+jKIiKgdGYkdABER0YPasWMHoqKiEB0djZCQEKxbtw7h4eG4du0anJyc6rX/6quv8Oabb2Lz5s0YNGgQrl+/jtmzZ0MikWDt2rWadt27d8fhw4c1fzYyYrdJRKRPODJEREQd3tq1azF37lxERkYiMDAQ0dHRMDc3x+bNmxtsf/z4cQwePBhPPvkkvL29MXr0aDzxxBP1RpOMjIzg4uKi2RwcHNrjcoiIqJ0wGSIiog6tsrISZ86cQVhYmGafVCpFWFgYYmNjGzxm0KBBOHPmjCb5SU5OxoEDB/DII49otbtx4wbc3NzQuXNnzJw5E2lpaY3GUVFRAaVSqbUREZFuYzIEQFlehS9iU1CtUosdChERNVNeXh5UKhWcnZ219js7OyMrK6vBY5588km8++67GDJkCIyNjeHr64sRI0bgrbfe0rQJCQnB559/joMHD+LTTz/FzZs3MXToUBQXFzd4zjVr1kChUGg2T0/P1rtIIiIDUVxehevZxThyLQc380rb/P04+RnA2p+v4/PjKfg6Lh2rHuuBPp1sxQ6JiIja0JEjR7B69Wp88sknCAkJQWJiIhYuXIgVK1Zg8eLFAICxY8dq2vfq1QshISHw8vLCzp078eyzz9Y756JFixAVFaX5s1KpZEJERPQXdytVuFV0F7cLy3G76C5uF9X891bdnw
vLUVxRrWn/Wrg/5j/UpU1jYjIEoJurFRRmxrhyW4nHPz2OJwZ0whvhAVCYG4sdGhER3YeDgwNkMhmys7O19mdnZ8PFxaXBYxYvXoynn34ac+bMAQD07NkTpaWleO655/D2229DKq0/ccLGxgZ+fn5ITExs8JxyuRxyufwBr4aIqGOqqFYhq6i8wQTnVu2+wrKqJp3L2tQIbjZmsJS3farCZAjA9P6dMLKbM9YcuIpv4zPw1ck0/HQpC2890g2P93WHRCIRO0QiImqEiYkJgoODERMTg0mTJgEA1Go1YmJisGDBggaPKSsrq5fwyGQyAIAgCA0eU1JSgqSkJDz99NOtFzwRUQdQpVIjW1mOrKLaxKawZlTnVuGfozt5JZVNOpeFiQyuNmZwVZjCTWEGVxtTuCpM4aowg5tNzX8t2iEJqsNkqJaDpRwfTAvC1H4eeGfPJSTmlOCVb87jmzPpWDmpB7o4WYkdIhERNSIqKgqzZs1Cv379MGDAAKxbtw6lpaWIjIwEAERERMDd3R1r1qwBAEyYMAFr165Fnz59NNPkFi9ejAkTJmiSoldffRUTJkyAl5cXbt26haVLl0Imk+GJJ54Q7TqJiNpClUqNa1nFSCso00pw6kZ3cosroG74eyItciMp3GoTHZe/JDt/Jj1msDY10qmBBiZDfzOwsz0OvDgU/z2WjP/E3MCJ5AKM/fB3zB3aGS883BVmJjKxQyQior+ZPn06cnNzsWTJEmRlZaF37944ePCgpqhCWlqa1kjQO++8A4lEgnfeeQeZmZlwdHTEhAkTsGrVKk2bjIwMPPHEE8jPz4ejoyOGDBmCEydOwNHRsd2vj4iotajVApLzSnA+vQgXMgpxPqMIV24rUVl970JixjIJXBSmcLX+M7GpG8lxVZjCzcYMtubGOpXoNIVEaGw+QAeiVCqhUChQVFQEa2vrVjtvekEZlu29jJirOQAAD1szvDuxOx4OcL7PkUREhqOt7sEdHT8XIhKbIAjILLyLCxlFOJ9RiAvpRbiYWYSSvxQpqGNtaoQuTpZwtTGD29+mrbnamMLBQg6ptGMkOs25/3Jk6B487czx31n98POVbCzbexkZd+7imc9PY0x3FyyZEAg3GzOxQyQiIiIiAgDklVTUjPbUjvpcyChCfmn9Z3lMjaXo4aZALw8bBHnW/Nfb3rzDjeq0BiZD9yGRSBDe3QVDujjgw5gb2HTsJg5ezsJvN3LxcpgfZg/2hrGMyzURERERUfspLq/CxcwiXMgo0iRAmYV367Uzkkrg72JVk/h4KBDkaYOuTpYw4u+vAJgMNZmF3EhTXe6d7y7hdOodrDqQgG/jM7DqsR4I9rITO0QiIiIi0kPlVSok3FbWTHdLL8T5jEIk55WioYddfB0tEORhg14eCvTytEGgqzVMjfnMe2OYDDVTgIs1dj4fim/OpGPNj1dxNasYkz+NxYz+nnhjTABsLUzEDpGIiIiIOqhqlRo3cko0xQ0uZBTi6u1iVDdQzs3dxqwm6akd9enhoYC1KdfJbA4mQy0glUowvX8njAp0wT9/TMDO0xnYfiodP1/JxqKxAZgS7GGQcy6JiIiIqOkEQUBqfllNcYPaUZ/Lt5S4W6Wq19bOwgRBHtrP+ThYcqHnB8Vk6AHYWZjg/SlBmNrPE29/dxHXs0vw2q4L+OZ0BlY+1gN+zlybiIiIiMjQqdQCMu/cRVJuSe1WiqTcElzLKkbR3ap67S3lRujhbl073a1mypuHrRm/bG8DTIZaQX9vO+x/cSg2H7uJdYdvIC6lAI98+DvmDO2MF0d2gbkJP2YiIiIifVdaUY2beTWJTlLOn0lPcl5po+v4mMik6OZmjd5/GfXp7GDZYcpYd3T8Lb2VGMukeH64L8YHuWH53sv4+Uo2oo8m4Yfzt7Ds0e4YFci1iYiIiIg6OkEQkFNcUZvs/JnwJOWU4FZReaPHmRhJ0dnBAp0dLeDraAlfR0t0cbKEn7MVTIxY2U0sTIZamb
uNGTZE9MPhK9lYuvcyMgvvYu7W0xgV6Ixlj3aHO9cmIiIiItJ5FdUqpOaXaZKe5LqkJ7e0wUVL69hbmNQkO05/Jj2+jpZwtzWDjKM9OofJUBsJC3TGoC72+E9MIv77ezIOXcnGsRt5WBjWFc8O8eHaREREREQ64E5ppfazPLXJT1pBGRoo4AYAkEkl6GRnDt+/jPL4Olmgs4MlKwt3MEyG2pC5iRHeHBugWZsoLqUA//zxKr6Lz8TKx3qgvzfXJiIiIiJqD0VlVYhPv4PE7BKt5KegtLLRYyzlRn8mPE6Wmv/vZG8OuRHX7tEHTIbagZ+zFXY8PxC7zmRgzY9XcS27GFOjYzE12AOLHukGO36DQERERNSqsorKEZdSgFM3C3AqpQDXsosbXKQUqHnMofPfkp4ujpZwtJKzgpueYzLUTiQSCab280RYN2e8/9NVfB2Xjm/OZOBQQs3aRFODPVk1hIiIiKgFBEFAcl4pTt0sqEmAUgqQXnC3XjsfBwsEulnXTm2rSX46O1qw8q8B40++ndlamGDN470wJdgDb393CVezivHGtxex83QGVj3WAwEu1mKHSERERKTTqlVqJNwu1oz8nE4tQF6J9nQ3qQTo5mqN/t52GOBjh37etnCyMhUpYtJVTIZEEuxlh30vDMHnx1Ow9tB1nEm9g3H/OYZnh/hg4ciusJDzR0NEREQEAOVVKpxLL9SM/MSn3kFppUqrjYmRFL09bNDfxxb9ve0Q7GULK1NjkSKmjoK/cYvISCbFnKGdMa6XK9794Qp+vJSFDb8lY9/5W1j5WA88HMC1iYiIiMjwFN2twpnUAsTdvINTKQW4kFGIKpX2Az9WciMEe9tqRn56eShY1ICajcmQDnBVmOHTp4Lx69UcLNl7CekFdzF36xkcjhoOHwcLscMjIiIialPZynLE1RY6iLvZcLEDRys5Bnjbob+3Lfr72CHAxZrr9tADYzKkQx4KcMLPnYfjmc9PITY5H1/EpmLJhECxwyIiIiJqNYIg4GZeaW3iUzPyk1ZQVq+dj4NFTeJTO/LTyc6cld2o1TEZ0jFmJjI8P7wzYpPz8c3pdLwy2o/PDxEREVGHVVfs4FRtlbdTKSx2QLqDv2XroGFdHeHjYIGbeaX47mwmnhroJXZIRERERE0mCALibhZg64lUHL2Wi5KKaq3X/17soK+XLaxZ7IBEwGRIB0mlEjw90Avv7ruCrbEpmBnSicPCREREpPPKKqux5+wtbI1NwdWsYs3+vxc76OmugKkxix2Q+JgM6agp/Tzwr5+v4Xp2CWKT8zHI10HskIiIiIgalJJXii9PpGLn6XQoy2tGgUyNpXisjwdm9PdED3cFix2QTmIypKOsTY3xeF93fHkiDVuOpzAZIiIiIp2iVgs4eiMXW4+n4Mj1XE31Ny97czw90AtTgz2hMOfUN9JtTIZ0WESoN748kYZDV7KRWXgX7jZmYodEREREBq6orArfnEnHFydSkZr/ZxW4Ef6OmBXqjeF+jpByFIg6CCZDOszP2QqDfO1xPCkf206k4vUxAWKHRERERAYq4bYSW2NTsedsJu5WqQAAVqZGmNbPE08P9II310akDojJkI6LCPXG8aR8bD+VjhdHduXDhkRERNRuqlRq/Hw5G1tiUxB3s0CzP8DFChGh3pjUxw3mJvx1kjou/u3VcWHdnOBuY4bMwrvYd+E2pgR7iB0SERER6bmc4nJsj0vHtpOpyFZWAABkUgnGdHdBRKgXBvjYsdIt6QUmQzrOSCbFzIGd8P7Ba9hyPAWT+7rz5kNEREStThAExKcVYmtsCg5cvI0qVU1FBAdLEzw5oBOeDPGCi4ILoZJ+YTLUAczo3wnrDt/AxcwinE0vRN9OtmKHRERERHqivEqFH87fwpbYFFzKVGr29+lkg9mDvDGmhwvkRpymT/qJyVAHYGdhgkeD3LDrTAa2Hk9hMkREREQPLONOGb48kYYdp9Jwp6wKAGBiJMWjQW6ICPVCLw
8bcQMkagdMhjqIWaHe2HUmA/sv3sZb47rByYrD1ERERNQ8giDgj8R8bIlNQUxCNtS1awO525jhqYFemN7fE3YWJuIGSdSOmAx1ED09FOjbyQbxaYXYHldTWY6IiIioKYrLq7A7PhNbY1OQlFuq2T+kiwMiQr0wspszZFwbiAwQk6EOZNYgb8SnncO2k6mYN8IXxjKp2CERERGRDkvMKcbW2FR8eyYDpZU1awNZmMgwJdgDT4d6oYuTlcgREomLyVAHMraHK1ZYJiBbWYGfLmdhfC83sUMiIiIiHaNSCzickI2tsSn4IzFfs9/X0QKzBnnjsT7usDI1FjFCIt3BZKgDMTGS4smQTvhPzA1sPZ7KZIiIiIi0HE/Kw2vfXEBm4V0AgFQChHVzxqxB3hjka8/lOYj+hslQBzMzpBM++TURcSkFuHJLiUA3a7FDIiIiIh3w8+UsLPjqLCpVatiaG2PGgE6YGdIJHrbmYodGpLP40EkH42xtijE9XAAAW2NTxA2GiIiIdMLu+AzM2xaPSpUa4d2dcfzNkXhjTAATIaL7YDLUAc0a5A0A2HMuE4VlleIGQ0RERKLacjwFUTvPQ6UWMLmvB9Y/2RdmJlwklagpmAx1QP28bBHoao3yKjV2nk4XOxwiIiISgSAI+CjmBpbuvQwAmD3IG/83pReMWG2WqMn4r6UDkkgkmDXICwCwNTYVqroV04iIiMggCIKA1QcS8MGh6wCAhSO7YumEQEi5VhBRszAZ6qAm9naHjbkxMu7cxa9Xc8QOh4iIiNqJSi3gzW8vYuPvNwEAi8cH4uVRfqwUR9QCTIY6KFNjGab38wQAbGEhBSIiIoNQWa3Gi1+fxY7T6ZBKgPen9MKzQ3zEDouow2Iy1IE9NdALEgnw+408JOWWiB0OERERtaGyymrM2Xoa+y/ehrFMgvVP9sW02i9GiahlmAx1YJ525hgZ4AwA+CI2VeRoiIiIqK0U3a1CxKY4/HY9F2bGMmya1R9je7qKHRZRh8dkqIOrK6Sw60wGSiqqRY6GiIiIWlteSQWe2HACp1PvwMrUCF/OGYBhfo5ih0WkF5gMdXBDujigs6MFSiqqsTs+Q+xwiIiIqBVlFt7FtOhYXLmthIOlCXY8F4pgLzuxwyLSG0yGOjiJRIJZod4AahZdEwSW2SYiItIHybklmPrpcSTnlcLdxgw7nw9FoJu12GER6RUmQ3rg8b7usDCRISm3FH8k5osdDhERET2gy7eKMDU6FreKytHZ0QLf/L9QdHa0FDssIr3TomRo/fr18Pb2hqmpKUJCQhAXF3fP9oWFhZg/fz5cXV0hl8vh5+eHAwcOaF5ftmwZJBKJ1hYQENCS0AySlakxpgR7AGCZbSIioo7udEoBZmw4gfzSSnR3s8bO50PhZmMmdlhEesmouQfs2LEDUVFRiI6ORkhICNatW4fw8HBcu3YNTk5O9dpXVlZi1KhRcHJywq5du+Du7o7U1FTY2NhotevevTsOHz78Z2BGzQ7NoD0d6o0tsamISchGekEZPO3MxQ6JiIiImuno9Vw8/8VplFep0d/bFptm94e1qbHYYRHprWaPDK1duxZz585FZGQkAgMDER0dDXNzc2zevLnB9ps3b0ZBQQH27NmDwYMHw9vbG8OHD0dQUJBWOyMjI7i4uGg2BweHll2RgeriZImhXR2gFoAvT7LMNhERUUdz4OJtzNlyCuVVagz3c8TWZ0KYCBG1sWYlQ5WVlThz5gzCwsL+PIFUirCwMMTGxjZ4zN69exEaGor58+fD2dkZPXr0wOrVq6FSqbTa3bhxA25ubujcuTNmzpyJtLS0RuOoqKiAUqnU2giIqC2ksONUOsqrVPduTESkZ5o7hXvdunXw9/eHmZkZPD098fLLL6O8vPyBzknUUjtPpWPBV/GoUgkY18sVGyP6wcxEJnZYRHqvWclQXl4eVCoVnJ2dtfY7OzsjKyurwWOSk5Oxa9cuqFQqHDhwAIsXL8YHH3yAlS
tXatqEhITg888/x8GDB/Hpp5/i5s2bGDp0KIqLixs855o1a6BQKDSbpydXXwaAhwOc4GFrhsKyKuw9d0vscIiI2k3dFO6lS5ciPj4eQUFBCA8PR05OToPtv/rqK7z55ptYunQpEhISsGnTJuzYsQNvvfVWi89J1FL//T0Zr397AWoBmNHfE/+Z0QcmRqxxRdQe2vxfmlqthpOTEzZs2IDg4GBMnz4db7/9NqKjozVtxo4di6lTp6JXr14IDw/HgQMHUFhYiJ07dzZ4zkWLFqGoqEizpaent/VldAgyqQRPD6xZhPVzltkmIgPS3Cncx48fx+DBg/Hkk0/C29sbo0ePxhNPPKE18tPccxI1lyAIWPvzNazcnwAAeG5YZ6x5vCdkUonIkREZjmYlQw4ODpDJZMjOztban52dDRcXlwaPcXV1hZ+fH2SyP4d6u3XrhqysLFRWVjZ4jI2NDfz8/JCYmNjg63K5HNbW1lob1ZjWzxNyIymu3FbiTOodscMhImpzLZnCPWjQIJw5c0aT/CQnJ+PAgQN45JFHWnxOTuGm5lCrBSz/4Qr+80vN7zqvhftj0dgASCRMhIjaU7OSIRMTEwQHByMmJkazT61WIyYmBqGhoQ0eM3jwYCQmJkKtVmv2Xb9+Ha6urjAxMWnwmJKSEiQlJcHV1bU54REAWwsTTOrtDgDYEstCCkSk/1oyhfvJJ5/Eu+++iyFDhsDY2Bi+vr4YMWKEZppcS87JKdzUVNUqNV7ddR6fH08BAKyY2B3zH+rCRIhIBM2eJhcVFYWNGzdiy5YtSEhIwLx581BaWorIyEgAQEREBBYtWqRpP2/ePBQUFGDhwoW4fv069u/fj9WrV2P+/PmaNq+++iqOHj2KlJQUHD9+HI899hhkMhmeeOKJVrhEwxMxqGaq3I8XbyNHWX6f1kREhufIkSNYvXo1PvnkE8THx2P37t3Yv38/VqxY0eJzcgo3NUV5lQr/2BaP3fGZkEkl+Pf0IDxdWwCJiNpfsxfzmT59OnJzc7FkyRJkZWWhd+/eOHjwoObbs7S0NEilf+ZYnp6e+Omnn/Dyyy+jV69ecHd3x8KFC/HGG29o2mRkZOCJJ55Afn4+HB0dMWTIEJw4cQKOjo6tcImGp7ubAv29bXEq5Q62nUzDy6P8xA6JiKjNtGQK9+LFi/H0009jzpw5AICePXuitLQUzz33HN5+++0WnVMul0Mul7fCFZG+Kq2oxnNfnMYfifkwMZJi/ZN9MSrQ+f4HElGbadHKpgsWLMCCBQsafO3IkSP19oWGhuLEiRONnm/79u0tCYPuISLUG6dS7uCruDTMf6gLq9IQkd766xTuSZMmAfhzCndjfVVZWZnWF3cANM+2CoLQonMS3UthWSVm/+8UzqUXwsJEho2z+mGQL9dUJBJbi5Ih0n1jerjAyUqOnOIK/HjpNibWPkdERKSPoqKiMGvWLPTr1w8DBgzAunXr6k3hdnd3x5o1awAAEyZMwNq1a9GnTx+EhIQgMTERixcvxoQJEzRJ0f3OSdRUOcpyPL0pDteyi2FjbozPIwegt6eN2GEREZgM6S1jmRQzQ7zw78PXsTU2lckQEem15k7hfueddyCRSPDOO+8gMzMTjo6OmDBhAlatWtXkcxI1RXpBGZ7adBKp+WVwspLji2dD4O9iJXZYRFRLIujBYjRKpRIKhQJFRUUss/0XOcXlGPzPX1ClErDvhSHo4a4QOyQi0kO8BzeMnwvdyC7GU5tOIltZAU87M2x7diA62ZuLHRaR3mvO/ZcPkugxJytTPNKzpjz5ltrynURERNT2LmQUYtpnschWVqCrkyV2/b9BTISIdBCTIT0XUVuu8/vzt3CntOFFbomIiKj1nEjOx5MbT+JOWRWCPBTY+XwonK1NxQ6LiBrAZEjP9e1kg57uClRWq7H9FNe8ICIiaku/XM3GrM1xKKmoRmhne2ybOxC2Fg0vMk9E4mMypOckEgkiQmsWYf3yRCpU6g7/iB
gREZFO+v5cJp7begYV1WqEdXPC/yL7w1LOWlVEuozJkAGYEOQGW3NjZBbexeGE7PsfQERERM3y5YlUvLTjHKrVAib1dsOnTwXD1FgmdlhEdB9MhgyAqbEMMwZ0AgBsjU0RNxgiIiI989nRJLyz5xIEAXh6oBfWTusNYxl/xSLqCPgv1UDMDOkEqQT4IzEfiTnFYodDRESkFy5lFmHNj1cBAPMf8sW7E7tDKpWIHBURNRWTIQPhYWuOUYE1CwVuOZ4qcjRERET64cOYGwBqpqS/Fh4AiYSJEFFHwmTIgMyqLbP9bXwGlOVV4gZDRETUwV3KLMKhK9mQSoCFI7uKHQ4RtQCTIQMS6muPrk6WKKtU4dszGWKHQ0RE1KHVjQo9GuSGLk6WIkdDRC3BZMiASCQSRAzyBgB8EZsKNctsExERtchfR4UWPMxRIaKOismQgXm8jzus5EZIzivFscQ8scMhIiLqkDgqRKQfmAwZGAu5Eab08wAAbDmeIm4wREREHRBHhYj0B5MhA/T0QC8AwC/XcpCWXyZyNERERB0LR4WI9AeTIQPU2dESw/0cIQjAFydSxA6HiIiow+CoEJF+YTJkoGYNqhkd2nEqHXcrVSJHQ0RE1DFwVIhIvzAZMlDD/ZzQyc4cyvJq7DmXKXY4REREOu+vo0IvcF0hIr3AZMhAyaQSRITWjA5tOZ4CQWCZbSIionupGxWa2Nsdvo4cFSLSB0yGDNjUYE+YGctwNasYp1LuiB0OERGRztJ+VqiL2OEQUSthMmTAFObGmNTHHQDLbBMREd3LusMcFSLSR0yGDFzdVLmDl7OQVVQucjRERES651JmEQ4ncFSISB8xGTJw3VytMcDHDiq1gG0nU8UOh4iISOdwVIhIfzEZIswe5A0A+DouDRXVLLNNRERUh6NCRPqNyRBhVKAzXKxNkVdSiR8vZokdDhERkc7gqBCRfmMyRDCWSfHUwE4AgM9ZSIGIiAgAR4WIDAGTIQIAzBjQCSYyKc6lF+J8eqHY4RAREYmOo0JE+o/JEAEAHCzlGN/LFQCwJTZF3GCIiIhExlEhIsPAZIg0ImoLKew7fxv5JRXiBkNERCQijgoRGQYmQ6TR29MGQR4KVKrU2H4qXexwiIiIRMFRISLDwWSItMyqHR3adiIV1Sq1uMEQERGJgKNCRIaDyRBpGdfLFfYWJrhVVI7DCdlih0NERNSuOCpEZFiYDJEWuZEMTwxgmW0iIjJM6w5fB8BRISJDwWSI6nkypBNkUglOJBfgWlax2OEQERG1i4sZRTickMNRISIDwmSI6nGzMcPoQGcAwFaW2SYiIgPxYQxHhYgMDZMhalBdIYXd8ZkoulslbjBERERt7K+jQi9wVIjIYDAZogaF+NjB39kKd6tU2HUmQ+xwiIiI2lTdqNCk3u7ozFEhIoPBZIgaJJFINKNDX8SmQK0WxA2IiIiojfBZISLDxWSIGjWpjxusTY2Qkl+GozdyxQ6HiIioTXBUiMhwMRmiRpmbGGFaP08AwNcn00SOhoiIqPVxVIjIsDEZont6vK8HAOC3G7kor1KJHA0REVHr4qgQkWFjMkT31M3VCu42ZiivUuOPxDyxwyEiImo1HBUiIiZDdE8SiQRh3ZwAAIcTskWOhoiIqPVwVIiImAzRfYXVLsB6OCGHVeWIiEgvcFSIiAAmQ9QEIT72sJQbIbe4Ahcyi8QOh4iI6IFxVIiIACZD1AQmRlIM93cEABy+wqlyRETUsV3IKOSoEBEBYDJETTSqW91UOSZDRETUsX14+AYAjgoREZMhaqIR/o6QSSW4mlWM9IIyscMhIiJqkQsZhYi5ylEhIqrBZIiaxMbcBP29bQFwdIiIiDoujgoR0V8xGaImC+NUOSIi6sA4KkREf8dkiJpsVG2J7ZPJBSi6WyVyNERERM3DUSEi+jsmQ9RkXvYW6OpkiWq1gKPXc8UOh4hIy/r16+Ht7Q1TU1OEhIQgLi6u0bYjRoyARCKpt40bN0
7TZvbs2fVeHzNmTHtcCrUBjgoRUUOYDFGzaBZgZYltItIhO3bsQFRUFJYuXYr4+HgEBQUhPDwcOTk5DbbfvXs3bt++rdkuXboEmUyGqVOnarUbM2aMVruvv/66PS6H2gBHhYioIUyGqFnqnhv69VoOqlRqkaMhIqqxdu1azJ07F5GRkQgMDER0dDTMzc2xefPmBtvb2dnBxcVFsx06dAjm5ub1kiG5XK7VztbWtj0uh1rZX0eFXhjZVexwiEiHMBmiZuntaQMHSxMUl1fj1M0CscMhIkJlZSXOnDmDsLAwzT6pVIqwsDDExsY26RybNm3CjBkzYGFhobX/yJEjcHJygr+/P+bNm4f8/PxWjZ3ah2ZUqI87fBws7tOaiAwJkyFqFplUgocDnAAAh1hVjoh0QF5eHlQqFZydnbX2Ozs7Iysr677Hx8XF4dKlS5gzZ47W/jFjxmDr1q2IiYnBe++9h6NHj2Ls2LFQqVQNnqeiogJKpVJrI/FpjQo9zFEhItLGZIiarW6q3KEr2RAEQeRoiIgezKZNm9CzZ08MGDBAa/+MGTPw6KOPomfPnpg0aRL27duHU6dO4ciRIw2eZ82aNVAoFJrN09OzHaKn++GoEBHdC5MharYhXR0gN5Ii485dXMsuFjscIjJwDg4OkMlkyM7WHq3Ozs6Gi4vLPY8tLS3F9u3b8eyzz973fTp37gwHBwckJiY2+PqiRYtQVFSk2dLT05t+EdQmOCpERPfDZIiazdzECEO6OABgVTkiEp+JiQmCg4MRExOj2adWqxETE4PQ0NB7HvvNN9+goqICTz311H3fJyMjA/n5+XB1dW3wdblcDmtra62NxLWOo0JEdB9MhqhF6kpsH0pouGwtEVF7ioqKwsaNG7FlyxYkJCRg3rx5KC0tRWRkJAAgIiICixYtqnfcpk2bMGnSJNjb22vtLykpwWuvvYYTJ04gJSUFMTExmDhxIrp06YLw8PB2uSZ6MOfTC/ELR4WI6D6MxA6AOqaRtUUUzqcXIkdZDidrU5EjIiJDNn36dOTm5mLJkiXIyspC7969cfDgQU1RhbS0NEil2t//Xbt2DceOHcPPP/9c73wymQwXLlzAli1bUFhYCDc3N4wePRorVqyAXC5vl2uiB/NhDEeFiOj+JIIePAGvVCqhUChQVFTEaQntaOL6P3A+vRBrHu+JJwZ0EjscIhIJ78EN4+cinvPphZi4/g9IJUDMKyOYDBEZmObcf1s0TW79+vXw9vaGqakpQkJCEBcXd8/2hYWFmD9/PlxdXSGXy+Hn54cDBw480DlJfKO61YwO8bkhIiLSJRwVIqKmanYytGPHDkRFRWHp0qWIj49HUFAQwsPDkZPT8LMjlZWVGDVqFFJSUrBr1y5cu3YNGzduhLu7e4vPSbqh7rmhY4l5KKusFjkaIiIiPitERM3T7GRo7dq1mDt3LiIjIxEYGIjo6GiYm5tj8+bNDbbfvHkzCgoKsGfPHgwePBje3t4YPnw4goKCWnxO0g3+zlbwsDVDRbUax27kiR0OERERR4WIqFmalQxVVlbizJkzCAsL+/MEUinCwsIQGxvb4DF79+5FaGgo5s+fD2dnZ/To0QOrV6/WrODdknNylW/dIJFINAuwHk7gVDkiIhIXR4WIqLmalQzl5eVBpVJpqvPUcXZ2RlZWVoPHJCcnY9euXVCpVDhw4AAWL16MDz74ACtXrmzxObnKt+4YVTtVLiYhByp1h6/FQUREHRhHhYioudp8nSG1Wg0nJyds2LABwcHBmD59Ot5++21ER0e3+Jxc5Vt3DPCxg5WpEfJLK3EuvVDscIiIyEBxVIiIWqJZyZCDgwNkMhmys7WnRGVnZ8PFxaXBY1xdXeHn5weZTKbZ161bN2RlZaGysrJF5+Qq37rDWCbFCP/aqnKcKkdERCLhqBARtUSzkiETExMEBwcjJiZGs0+tViMmJgahoaENHjN48GAkJiZCrVZr9l2/fh2urq4wMTFp0TlJt4
SxxDYREYmIo0JE1FLNniYXFRWFjRs3YsuWLUhISMC8efNQWlqKyMhIAEBERAQWLVqkaT9v3jwUFBRg4cKFuH79Ovbv34/Vq1dj/vz5TT4n6bYRfk4wkkpwI6cEKXmlYodDREQGhqNCRNRSRs09YPr06cjNzcWSJUuQlZWF3r174+DBg5oCCGlpaZBK/8yxPD098dNPP+Hll19Gr1694O7ujoULF+KNN95o8jlJtynMjTHAxw7Hk/JxOCEbc4Z2FjskIiIyEOdqR4VkUgle5KgQETWTRBCEDl8CTKlUQqFQoKioiM8PiWTzsZt4d98VDOxsh+3PcXojkSHhPbhh/FzaR+T/4vDrtVxM7uuBD6YF3f8AItJ7zbn/tnk1OTIMdesNnUq5g8KySpGjISIiQ3AuvRC/XsuFTCrBCw93ETscIuqAmAxRq+hkbw5/Zyuo1AKOXMsVOxwiIjIAHx6+DgCY1Nsd3nxWiIhagMkQtZqwwJqqcodYYpuIiNoYR4WIqDUwGaJWUzdV7ui1XFRWq+/TmoiIqOU4KkRErYHJELWaIA8bOFjKUVJRjZM388UOh4iI9BRHhYiotTAZolYjlUq4ACsREbW5Db8lAeCoEBE9OCZD1KrqpsodTsiBHlRtJyIiHVN0twqHE3IAAM8M8RY3GCLq8JgMUasa0tUBpsZSZBbeRcLtYrHDISIiPfPTpSxUVqvR1ckSga5cv4mIHgyTIWpVpsYyDO3qCAA4zKpyRETUyvacywQATOztBolEInI0RNTRMRmiVjdKM1WOyRAREbWerKJyxCbXFOiZ2Ntd5GiISB8wGaJW91CAEyQS4EJGEbKKysUOh4iI9MQP529BEIBgL1t42pmLHQ4R6QEmQ9TqHK3k6ONpAwCIucrRISIiah3fn6+ZIjept5vIkRCRvmAyRG0iLLB2qhxLbBMRUStIzCnBpUwljKQSjOvFZIiIWgeTIWoTdc8N/ZGUj9KKapGjISKiju772sIJQ7s6wM7CRORoiEhfMBmiNtHFyRJe9uaorFbj9xt5YodDREQdmCAI+P7cLQDApD4snEBErYfJELUJiUTylwVYOVWOiIha7mx6IdIKymBuIsOo2mnYREStgckQtZm6ZOiXqzlQqQWRoyEioo7q+7M1U+RGBzrD3MRI5GiISJ8wGaI208/bFgozYxSUVuJs2h2xwyEiog6oSqXGvgu3AQATOUWOiFoZkyFqM8YyKR7ydwQAHOJUOSIiaoFjiXnIL62EnYUJhnRxEDscItIzTIaoTbHENhERPYi6KXLje7nCWMZfW4iodfGuQm1qmJ8jjGUSJOWWIjm3ROxwiIioAymrrMbPtV+mTezNKXJE1PqYDFGbsjY1xsDO9gCAmIQckaMhIqKO5NCVbJRVqtDJzhx9O9mIHQ4R6SEmQ9Tm6qrK8bkhIiJqjrq1hSb2doNEIhE5GiLSR0yGqM2N7OYEADidUoA7pZUiR0NERB1BQWklfrueC6AmGSIiagtMhqjNediao5urNdQC8Os1TpUjIqL723/hFqrVArq7WaOLk5XY4RCRnmIyRO1iVO3o0GFOlSMioiaomyI3iYUTiKgNMRmidlFXYvvotVxUVKtEjoaIiHRZekEZTqfegUQCTAjiFDkiajtMhqhd9HBTwNlajtJKFU4kF4gdDhER6bC952tGhUI728NFYSpyNESkz5gMUbuQSiUY2Y0LsBIR0b0JgoA9tQutsnACEbU1JkPUbkbVJUMJ2RAEQeRoiIhIFyXcLsaNnBKYyKQY08NV7HCISM8xGaJ2E+prDzNjGW4XlePyLaXY4RARkQ76/lzNqNDDAU5QmBmLHA0R6TsmQ9RuTI1lGObnAIBV5YiIqD61WtA8LzSpD6fIEVHbYzJE7SrsL1PliIiI/urkzQLcLiqHlakRRvg7iR0OERkAJkPUrh4OcIJEAlzKVOJW4V2xwyEiIh1SN0XukR6uMDWWiRwNERkCJkPUruwt5QjuZAsAiOHoEB
ER1aqoVuHAxdsAWEWOiNoPkyFqd3ULsB5KyBE5EiIi0hVHruVCWV4NZ2s5Qjrbix0OERkIJkPU7uqeG4pNykNxeZXI0RARkS6omyL3aJAbZFKJyNEQkaFgMkTtztfRAj4OFqhSCfj9Rp7Y4RARkciU5VU4XDtbYGJvd5GjISJDwmSI2p1EIkFYt5oqQYev8LkhIiJDd/BSFiqr1ejiZInubtZih0NEBoTJEImibqrcL9dyUK1SixwNERGJae+5mrWFJga5QSLhFDkiaj9MhkgUwV62sDE3RmFZFc6k3hE7HCLSA+vXr4e3tzdMTU0REhKCuLi4RtuOGDECEomk3jZu3DhNG0EQsGTJEri6usLMzAxhYWG4ceNGe1yKQclRluN4Us2UaU6RI6L2xmSIRGEkk+Lh2gX1uAArET2oHTt2ICoqCkuXLkV8fDyCgoIQHh6OnJyGq1bu3r0bt2/f1myXLl2CTCbD1KlTNW3ef/99/Oc//0F0dDROnjwJCwsLhIeHo7y8vL0uyyDsPX8LagHo28kGnezNxQ6HiAwMkyESjabE9pVsCIIgcjRE1JGtXbsWc+fORWRkJAIDAxEdHQ1zc3Ns3ry5wfZ2dnZwcXHRbIcOHYK5ubkmGRIEAevWrcM777yDiRMnolevXti6dStu3bqFPXv2tOOV6b/va6fITerDUSEian9Mhkg0w/wcYSKTIiW/DEm5pWKHQ0QdVGVlJc6cOYOwsDDNPqlUirCwMMTGxjbpHJs2bcKMGTNgYWEBALh58yaysrK0zqlQKBASEtLkc9L9JeWW4GJmEWRSCcb1dBU7HCIyQEyGSDSWciMM9K1ZWI9T5YiopfLy8qBSqeDs7Ky139nZGVlZWfc9Pi4uDpcuXcKcOXM0++qOa845KyoqoFQqtTa6t7pRoaFdHWBvKRc5GiIyREyGSFSjWGKbiES2adMm9OzZEwMGDHig86xZswYKhUKzeXp6tlKE+kkQBM1Cq5NYOIGIRMJkiEQ1srbE9pm0O8gvqRA5GiLqiBwcHCCTyZCdrf2lSnZ2NlxcXO55bGlpKbZv345nn31Wa3/dcc0556JFi1BUVKTZ0tPTm3spBuVceiFS88tgZizDqEDn+x9ARNQGmAyRqNxszNDdzRqCAPxyteGqT0RE92JiYoLg4GDExMRo9qnVasTExCA0NPSex37zzTeoqKjAU089pbXfx8cHLi4uWudUKpU4efJko+eUy+WwtrbW2qhxdVPkRnd3hoXcSORoiMhQMRki0dUtwMrnhoiopaKiorBx40Zs2bIFCQkJmDdvHkpLSxEZGQkAiIiIwKJFi+odt2nTJkyaNAn29vZa+yUSCV566SWsXLkSe/fuxcWLFxEREQE3NzdMmjSpPS5Jr1Wr1Nh3obaKHKfIEZGI+FUMiW5UoDM+jLmB367nobxKBVNjmdghEVEHM336dOTm5mLJkiXIyspC7969cfDgQU0BhLS0NEil2t//Xbt2DceOHcPPP//c4Dlff/11lJaW4rnnnkNhYSGGDBmCgwcPwtTUtM2vR9/9kZSPvJJK2FmYYEhXB7HDISIDJhH0YIEXpVIJhUKBoqIiTkvogARBQOiaX5ClLMf/ZvfHQwFOYodERM3Ae3DD+Lk0LmrHOew+m4mnB3phxaQeYodDRHqmOfdfTpMj0UkkEoQF1iRAhzhVjohIr92tVOGnyzXlySf1cRM5GiIydEyGSCfUPTcUk5ANtbrDD1YSEVEjDiVko7RSBU87M/TtZCt2OERk4JgMkU4I9bWHhYkM2coKXLpVJHY4RETURvbWri00McgdEolE5GiIyNAxGSKdIDeSYbi/IwAuwEpEpK/ulFbiyLVcAMDE3pwiR0TiYzJEOqNuqtyhBK43RESkj/ZfvI1qtYBAV2t0dbYSOxwiIiZDpDse8neCVAIk3FYi406Z2OEQEVEr+752ihwLJxCRrmAyRDrD1sIE/bztAAAxHB0iItIrGXfKcCrlDiQS4NEgLr
RKRLqByRDplFG1U+UOs8Q2EZFe2Xv+FgBgoI89XBRcuJaIdAOTIdIpYYE1ydCJ5Hwoy6tEjoaIiFrL92drkiEWTiAiXcJkiHSKj4MFfB0tUKUS8Nv1XLHDISKiVpBwW4lr2cUwkUkxtqer2OEQEWkwGSKdUzc6xBLbRET6YU9t4YSHAhyhMDMWORoioj8xGSKdU/fc0C9Xc1ClUoscDRERPQi1WsAP52qmyE3qzcIJRKRbmAyRzunTyRZ2FiZQllfjdModscMhIqIHcCqlALeKymFlaoSHApzEDoeISEuLkqH169fD29sbpqamCAkJQVxcXKNtP//8c0gkEq3N1FS7iszs2bPrtRkzZkxLQiM9IJNK8HBth8mqckREHdue2lGhsT1cYGosEzkaIiJtzU6GduzYgaioKCxduhTx8fEICgpCeHg4cnIaXxfG2toat2/f1mypqan12owZM0arzddff93c0EiPhP2lxLYgCCJHQ0RELVFZrcaBi7cBABM5RY6IdFCzk6G1a9di7ty5iIyMRGBgIKKjo2Fubo7Nmzc3eoxEIoGLi4tmc3Z2rtdGLpdrtbG1tW1uaKRHhnZ1gImRFKn5ZUjMKRE7HCIiaoEj13JQdLcKTlZyDOxsL3Y4RET1NCsZqqysxJkzZxAWFvbnCaRShIWFITY2ttHjSkpK4OXlBU9PT0ycOBGXL1+u1+bIkSNwcnKCv78/5s2bh/z8/OaERnrGQm6Ewb41HechTpUjIuqQvq9daPXRIDfIpBKRoyEiqq9ZyVBeXh5UKlW9kR1nZ2dkZWU1eIy/vz82b96M77//Hl9++SXUajUGDRqEjIwMTZsxY8Zg69atiImJwXvvvYejR49i7NixUKlUDZ6zoqICSqVSayP9wxLbREQdV3F5leb+PakPp8gRkW4yaus3CA0NRWhoqObPgwYNQrdu3fDZZ59hxYoVAIAZM2ZoXu/Zsyd69eoFX19fHDlyBCNHjqx3zjVr1mD58uVtHTqJbGSAM97GJZxNL0RucQUcreRih0RERE300+VsVFSr4etoge5u1mKHQ0TUoGaNDDk4OEAmkyE7W/ub+uzsbLi4uDTpHMbGxujTpw8SExMbbdO5c2c4ODg02mbRokUoKirSbOnp6U2/COowXBSm6OWhgCAAv15tvEAHERHpnu9rF1qd2NsdEgmnyBGRbmpWMmRiYoLg4GDExMRo9qnVasTExGiN/tyLSqXCxYsX4erq2mibjIwM5OfnN9pGLpfD2tpaayP9VFdVjs8NERF1HDnF5fgjMQ8AMLG3m8jREBE1rtnV5KKiorBx40Zs2bIFCQkJmDdvHkpLSxEZGQkAiIiIwKJFizTt3333Xfz8889ITk5GfHw8nnrqKaSmpmLOnDkAaoorvPbaazhx4gRSUlIQExODiRMnokuXLggPD2+ly6SOqi4Z+v1GLsqrGn6GjIiIdMu+87ehFoA+nWzgZW8hdjhERI1q9jND06dPR25uLpYsWYKsrCz07t0bBw8e1BRVSEtLg1T6Z451584dzJ07F1lZWbC1tUVwcDCOHz+OwMBAAIBMJsOFCxewZcsWFBYWws3NDaNHj8aKFSsgl/MZEUPXzdUK7jZmyCy8iz8S8zCyW/2y7EREpFvqpshN4tpCRKTjJIIerGipVCqhUChQVFTEKXN6aOn3l7AlNhVPDPDEmsd7iR0OEf0N78ENM9TP5WZeKR761xHIpBKcfGskHCz5xSYRta/m3H+bPU2OqL1pSmwn5ECt7vC5OxGRXttztmZUaEgXByZCRKTzmAyRzgvxsYel3Ai5xRW4kFkkdjhERNQIQRD+nCLXh4UTiEj3MRkinWdiJMVwf0cAXICViEiXXcgoQkp+GcyMZRgd2LQlN4iIxMRkiDqEUd3qpsoxGSIi0lV7akeFRgU6w0Le5uu6ExE9MCZD1CGM8HeETCrB1axipBeUiR0OERH9TbVKjR/O3wbAKXJE1HEwGaIOwcbcBP29bQFwdI
iISBcdT8pHXkkFbM2NMbSro9jhEBE1CZMh6jDCOFWOiEhnfX/uFgBgXC9XGMv46wURdQy8W1GHMaq2xPbJ5AIU3a0SORoiIqpTXqXCT5ezAHChVSLqWJgMUYfhZW8BP2dLVKsFbD2eInY4RERU63BCNkoqquFha4ZgL1uxwyEiajImQ9ShzH+oCwDgkyNJyCoqFzkaIiICgD1na6bITeztBolEInI0RERNx2SIOpRHg9zQz8sWd6tUeP/gVbHDISIyeIVllTh6PQcAMJFT5Iiog2EyRB2KRCLBkgmBAIDdZzMRn3ZH5IiIiAzbgYtZqFIJ6OZqDT9nK7HDISJqFiZD1OH08rDB1GAPAMC7P1yBWi2IHBERkeGqW2h1Um+uLUREHQ+TIeqQXhvjDwsTGc6lF2o6YiIial+ZhXcRd7MAEgnwKJMhIuqAmAxRh+RkZYoFD3cFALx38CpKK6pFjoiIyPDsrV1bKMTHDq4KM5GjISJqPiZD1GE9M8QbnezMka2swKdHksQOh4jI4HxfOzLPwglE1FExGaIOS24kw1uPdAMAbPg9GekFZSJHRERkOK5mKXE1qxgmMike6eEqdjhERC3CZIg6tPDuzhjka4/KajXW/JggdjhERAbj+9opciP8HaEwNxY5GiKilmEyRB1aXaltqaSmvOuJ5HyxQyIi0ntqtaB5XmhSH06RI6KOi8kQdXgBLtZ4MqQTAGD5D1egYqltIqI2dTr1DjIL78JKboSHA5zEDoeIqMWYDJFeiBrlD2tTIyTcVmLn6XSxwyEi0mt1SxqE93CBqbFM5GiIiFqOyRDpBTsLE7wU5gcA+NdP16AsrxI5IiIi/VRZrcaBi7cBAJNYRY6IOjgmQ6Q3ng71gq+jBfJLK/FRzA2xwyEi0ku/Xc9FYVkVnKzkCPW1FzscIqIHwmSI9IaxTIrF4wMBAP/7IwXJuSUiR0REpH/qpshNCHKDTCoRORoiogfDZIj0ygh/Jzzk74hqtYBV+1lqm4ioNZVUVONwQjYATpEjIv3AZIj0zjvjA2EklSDmag6OXs8VOxwiIr3x8+UslFep0dnRAj3crcUOh4jogTEZIr3j62iJWYO8AQAr9l1BlUotbkBERHpiT+3aQhOD3CGRcIocEXV8TIZIL704sivsLEyQmFOCbSdSxQ6HiNrB+vXr4e3tDVNTU4SEhCAuLu6e7QsLCzF//ny4urpCLpfDz88PBw4c0Ly+bNkySCQSrS0gIKCtL0Nn5RZX4NiNmtH2ib3dRI6GiKh1MBkivaQwM8Yro2tKbf/78A3cKa0UOSIiaks7duxAVFQUli5divj4eAQFBSE8PBw5OTkNtq+srMSoUaOQkpKCXbt24dq1a9i4cSPc3bWfg+nevTtu376t2Y4dO9Yel6OT9l24BbUA9Pa0gbeDhdjhEBG1CiZDpLdm9O+EABcrFN2twr8PXxc7HCJqQ2vXrsXcuXMRGRmJwMBAREdHw9zcHJs3b26w/ebNm1FQUIA9e/Zg8ODB8Pb2xvDhwxEUFKTVzsjICC4uLprNwcGhPS5HJ9VNkZvEUSEi0iNMhkhvyaQSLJ3QHQDw5YlUXMsqFjkiImoLlZWVOHPmDMLCwjT7pFIpwsLCEBsb2+Axe/fuRWhoKObPnw9nZ2f06NEDq1evhkql0mp348YNuLm5oXPnzpg5cybS0tLa9Fp0VUpeKc6nF0ImlWBcLyZDRKQ/mAyRXgv1tceY7i5QC8C7+y5DEASxQyKiVpaXlweVSgVnZ2et/c7OzsjKymrwmOTkZOzatQsqlQoHDhzA4sWL8cEHH2DlypWaNiEhIfj8889x8OBBfPrpp7h58yaGDh2K4uKGv1ipqKiAUqnU2vTFvgs1o0KDfO3haCUXORoiotbDZIj03luPdIOJkRR/JObj0JVsscMhIh2gVqvh5OSEDRs2IDg4GNOnT8fbb7+N6OhoTZuxY8di6tSp6NWrF8LDw3HgwAEUFhZi586dDZ5zzZo1UC
gUms3T07O9LqfN7btwG0DNQqtERPqEyRDpvU725pgzxAcAsOpAAiqqVfc5gog6EgcHB8hkMmRna3/ZkZ2dDRcXlwaPcXV1hZ+fH2QymWZft27dkJWVhcrKhguu2NjYwM/PD4mJiQ2+vmjRIhQVFWm29PT0Fl6RbknMKcbVrGIYyyQID2z48yQi6qiYDJFB+MdDXeBkJUdqfhn+90eK2OEQUSsyMTFBcHAwYmJiNPvUajViYmIQGhra4DGDBw9GYmIi1Oo/1yG7fv06XF1dYWJi0uAxJSUlSEpKgqura4Ovy+VyWFtba2364IfzNaNCQ7s6QmFuLHI0RESti8kQGQRLuRFeH1OzPsjHvyQip7hc5IiIqDVFRUVh48aN2LJlCxISEjBv3jyUlpYiMjISABAREYFFixZp2s+bNw8FBQVYuHAhrl+/jv3792P16tWYP3++ps2rr76Ko0ePIiUlBcePH8djjz0GmUyGJ554ot2vTyyCIGieFxrfq+EkkIioIzMSOwCi9vJ4H3d8EZuC8xlF+NdP1/D+lKD7H0REHcL06dORm5uLJUuWICsrC71798bBgwc1RRXS0tIglf75/Z+npyd++uknvPzyy+jVqxfc3d2xcOFCvPHGG5o2GRkZeOKJJ5Cfnw9HR0cMGTIEJ06cgKOjY7tfn1iuZhUjKbcUJkZSjAp0vv8BREQdjETQg/JaSqUSCoUCRUVFejMtgdrGmdQ7mPzpcUgkwN75Q9DTQyF2SEQdHu/BDdOHz+X/frqK9b8mYXSgMzZE9BM7HCKiJmnO/ZfT5MigBHvZYlJvNwgCsPwHltomImpMzRS5mueFxrOKHBHpKSZDZHDeGBsAM2MZTqfe0XT0RESk7fItJVLzy2BqLMXIACexwyEiahNMhsjguCrMMG+ELwBgzYEE3K1kqW0ior/7obZwwsgAZ1jI+YgxEeknJkNkkJ4b1hnuNma4VVSODb8lix0OEZFOEQQB++umyLGKHBHpMSZDZJBMjWVY9EhNqe1PjybiVuFdkSMiItId59ILkXHnLixMZHiIU+SISI8xGSKDNa6nKwZ426G8So33Dl4VOxwiIp1R9zxlWKAzTI1lIkdDRNR2mAyRwZJIJFgyIRASCfD9uVs4k1ogdkhERKJTq/86RY5V5IhIvzEZIoPWw12BacGeAIDlP1yBWs1S20Rk2M6k3UGWshxWpkYY5ucgdjhERG2KyRAZvFfD/WEpN8KFjCLsPpspdjhERKLad76mitzoQBfIjThFjoj0G5MhMniOVnK88HAXAMB7B6+ipKJa5IiIiMShUgs4cCkLAKvIEZFhYDJEBGD2YG942Zsjt7gCn/yaKHY4RESiOHkzH7nFFVCYGWNwF06RIyL9x2SICIDcSIZ3xgUCAP77+02k5ZeJHBERUfurqyI3prsLTIz4KwIR6T/e6YhqhXVzwpAuDqhUqbH6QILY4RARtatqlRoH66bIBXGKHBEZBiZDRLUkEgkWjw+ETCrBwctZOJ6UJ3ZIRETt5nhSPgpKK2FvYYLQzvZih0NE1C6YDBH9hb+LFWaGdAIAvPvDFVSr1CJHRETUPvZdqKkiN6aHC4xk/PWAiAwD73ZEf/NymB8UZsa4mlWM7afSxQ6HiKjNVVb/ZYocF1olIgPCZIjob2wtTPByWFcAwAc/X0NRWZXIERERta0/EvOgLK+Go5UcA3zsxA6HiKjdMBkiasDMgV7o6mSJO2VV+DDmhtjhEBG1qR9qp8iN6+kKmVQicjRERO2HyRBRA4xlUiweX1Nqe2tsChJzSkSOiIiobZRXqXDocjYALrRKRIaHyRBRI4b5OSKsmxOq1QJW7r8idjhERG3it+u5KK6ohqvCFH072YodDhFRu2IyRHQPb48LhLFMgiPXcvHr1RyxwyEianV1C62O6+kKKafIEZGBYTJEdA8+DhaIHOwDAFix/wqqWGqbiPTI3UoVDifUTpELYhU5IjI8TIaI7mPBw11gb2GC5NxSbI
1NFTscIqJW8+u1HJRVquBha4YgD4XY4RARtTsmQ0T3YW1qjFfD/QEA6w5fR35JhcgRERG1jrqFVsf1coVEwilyRGR4mAwRNcG0fp4IdLVGcXk11h66LnY4REQPrLSiGr/UPgs5gQutEpGBYjJE1AQyqQRLJ9SU2v46Lg0Jt5UiR0RE9GAOJ2SjvEoNb3tzdHezFjscIiJRMBkiaqKQzvYY19MVagF494crEARB7JCIiFqsrorc+F5unCJHRAarRcnQ+vXr4e3tDVNTU4SEhCAuLq7Rtp9//jkkEonWZmpqqtVGEAQsWbIErq6uMDMzQ1hYGG7cuNGS0Ija1JtjA2BiJEVscj5+ql2kkIioo1GWV+HotVwAwPggLrRKRIar2cnQjh07EBUVhaVLlyI+Ph5BQUEIDw9HTk7ja7BYW1vj9u3bmi01Vbsi1/vvv4///Oc/iI6OxsmTJ2FhYYHw8HCUl5c3/4qI2pCnnTmeH9YZALDqwBWUV6lEjoiIqPkOXc5GpUqNLk6W8He2EjscIiLRNDsZWrt2LebOnYvIyEgEBgYiOjoa5ubm2Lx5c6PHSCQSuLi4aDZnZ2fNa4IgYN26dXjnnXcwceJE9OrVC1u3bsWtW7ewZ8+eFl0UUVv6f8N94WwtR3rBXWz+46bY4RARNVtdFbnxrCJHRAauWclQZWUlzpw5g7CwsD9PIJUiLCwMsbGxjR5XUlICLy8veHp6YuLEibh8+bLmtZs3byIrK0vrnAqFAiEhIY2es6KiAkqlUmsjai8WciO8OTYAAPDxL4nIUXIEk4g6jsKySvx+Iw9AzfNCRESGrFnJUF5eHlQqldbIDgA4OzsjKyurwWP8/f2xefNmfP/99/jyyy+hVqsxaNAgZGRkAIDmuOacc82aNVAoFJrN09OzOZdB9MAmBrmjt6cNyipVeGfPJVRUc7ocEXUMP1/ORrVaQICLFbo4WYodDhGRqNq8mlxoaCgiIiLQu3dvDB8+HLt374ajoyM+++yzFp9z0aJFKCoq0mzp6emtGDHR/UmlEix7tDtkUgl+vpKNJzacQDZHiIioA/ihdorchCCOChERNSsZcnBwgEwmQ3a2dhWt7OxsuLi4NOkcxsbG6NOnDxITEwFAc1xzzimXy2Ftba21EbW33p422Dy7P6xNjRCfVojxHx3DmdQ7YodFRNSo/JIKHE/KB1DzvBARkaFrVjJkYmKC4OBgxMTEaPap1WrExMQgNDS0SedQqVS4ePEiXF1rbsI+Pj5wcXHROqdSqcTJkyebfE4isQz3c8TeBUPg52yJ3OIKzNgQi6/j0sQOi4ioQQcvZ0GlFtDTXQEvewuxwyEiEl2zp8lFRUVh48aN2LJlCxISEjBv3jyUlpYiMjISABAREYFFixZp2r/77rv4+eefkZycjPj4eDz11FNITU3FnDlzANRUmnvppZewcuVK7N27FxcvXkRERATc3NwwadKk1rlKojbk7WCB7/4xGGN7uKBKJWDR7ot467uLqKxWix0aEZGWfefrFlrlqBAREQAYNfeA6dOnIzc3F0uWLEFWVhZ69+6NgwcPagogpKWlQSr9M8e6c+cO5s6di6ysLNja2iI4OBjHjx9HYGCgps3rr7+O0tJSPPfccygsLMSQIUNw8ODBeouzEukqC7kRPpnZF58cScK/fr6Gr06m4XpWMT55qi+crPj3mIjEl1NcjpM3a6bIjWMyREQEAJAIgiCIHcSDUiqVUCgUKCoq4vNDJLpfr+Xgxa/Pori8Gs7WckQ/FYw+nWzFDouozfAe3DBd+1y2HE/B0r2X0dvTBnvmDxY7HCKiNtOc+2+bV5MjMjQP+Tth74Ih6OpkiWxlBaZ/dgI7TvE5IiIS118XWiUiohpMhojagI+DBb6bPxjh3Z1RqVLjjW8vYvGeS3yOiIhEcbvoLk6l1FS75BQ5IqI/MRkiaiOWciN8OjMYr4zyg0QCfHEiFTP/ewK5xRVih0ZEBmb/hZrCCf29beGqMB
M5GiIi3cFkiKgNSaUSvDCyK/4b0Q9WciOcSrmDCR8dw7n0QrFDIyIDsu9CXRU5LrRKRPRXTIaI2sHIbs7Ys2AwfB0tkKUsx7TPYrHzdLrYYRGRAUgvKMO59EJIJcDYnk1bIJ2IyFAwGSJqJ76OltgzfzBGBTqjslqN13ddwNLvL6FKxeeIiKjt7L9YMyoU4mPPUv9ERH/DZIioHVmZGuOzp4LxUlhXAMCW2FTM/O9J5JXwOSIiahuaKnJBLJxARPR3TIaI2plUKsFLYX7YGNEPlnIjxN0swKMfHcPFjCKxQyMiPZOSV4pLmUrIpBKM7cFkiIjo75gMEYlkVKAz9swfjM6OFrhVVI7J0cfx7ZkMscMiIj1SNyo0yNcedhYmIkdDRKR7mAwRiaiLU81zRCMDnFBZrcYr35zH8h8u8zkiImoVdVXkJrCKHBFRg5gMEYnM2tQYGyP64cWRNc8R/e+PFDy96STy+RwRET2AxJwSXM0qhrFMgvDurCJHRNQQJkNEOkAqlSBqlB8+ezoYFiYynEguwKMf/4FLmXyOiIhapm6K3NCujlCYG4scDRGRbmIyRKRDwru7YM/8wfBxsEBm4V1M/vQ49pzNFDssIupgBEH4y0KrLJxARNQYJkNEOqarsxX2zB+MhwOcUFGtxks7zmHFviuo5nNERPe0fv16eHt7w9TUFCEhIYiLi7tn+8LCQsyfPx+urq6Qy+Xw8/PDgQMHHuicuuJadjESc0pgIpMiLNBZ7HCIiHQWkyEiHaQwM8Z/I/rhhYe7AAA2HbuJiM1xKCitFDkyIt20Y8cOREVFYenSpYiPj0dQUBDCw8ORk5PTYPvKykqMGjUKKSkp2LVrF65du4aNGzfC3d29xefUJfvO14wKDfd3hLUpp8gRETWGyRCRjpJKJXhltD+in+oLcxMZjiflY8JHx3D5Fp8jIvq7tWvXYu7cuYiMjERgYCCio6Nhbm6OzZs3N9h+8+bNKCgowJ49ezB48GB4e3tj+PDhCAoKavE5dUXNFLnahVY5RY6I6J6YDBHpuDE9XPHdPwbDy95c8xzR9+f4HBFRncrKSpw5cwZhYWGafVKpFGFhYYiNjW3wmL179yI0NBTz58+Hs7MzevTogdWrV0OlUrX4nLri8i0lUvLLYGosRVg3TpEjIroXJkNEHYC/ixX2zh+C4X6OKK9SY+H2c1h9IIHPEREByMvLg0qlgrOz9i/+zs7OyMrKavCY5ORk7Nq1CyqVCgcOHMDixYvxwQcfYOXKlS0+Z0VFBZRKpdYmhh9qR4UeDnCChdxIlBiIiDoKJkNEHYTC3BibZ/fHP0b4AgA2/JaM2f87hTt8joio2dRqNZycnLBhwwYEBwdj+vTpePvttxEdHd3ic65ZswYKhUKzeXp6tmLETSMIAvZrqshxoVUiovthMkTUgcikErw+JgDrn+wLM2MZjiXm4dH1x3DlljjfQBPpAgcHB8hkMmRnZ2vtz87OhotLw4uNurq6ws/PDzKZTLOvW7duyMrKQmVlZYvOuWjRIhQVFWm29PT0B7yy5juXXoiMO3dhbiLDQ/5O7f7+REQdDZMhog5oXC9XfDd/EDrZmSO9oOY5oroHpokMjYmJCYKDgxETE6PZp1arERMTg9DQ0AaPGTx4MBITE6FW/znV9Pr163B1dYWJiUmLzimXy2Ftba21tbe6tYXCujnDzER2n9ZERMRkiKiDCnCxxt4FgzG0qwPuVqmw4KuzWPNjAsqrVGKHRtTuoqKisHHjRmzZsgUJCQmYN28eSktLERkZCQCIiIjAokWLNO3nzZuHgoICLFy4ENevX8f+/fuxevVqzJ8/v8nn1DVq9V+nyLGKHBFRU/DJSqIOzMbcBJ9HDsD7P13FZ0eT8dnRZHx1Mg2P9XHHtH6e6OGuEDtEonYxffp05ObmYsmSJcjKykLv3r1x8OBBTQGEtLQ0SKV/fv/n6emJn376CS+//DJ69eoFd3d3LFy4EG+88UaTz6lrzq
TdQZayHFZyIwz3dxQ7HCKiDkEiCIIgdhAPSqlUQqFQoKioSJRpCUS6YP+F21h9IAGZhXc1+3q4W2N6/054NMgNCjMuvEhtg/fghrX357L0+0vYEpuKx/u6Y+203m3+fkREuqo591+ODBHpiXG9XDG2hwv+SMrD9lPpOHQ5G5cylbiUeQkr913BuJ6umN7fEwN87CCRSMQOl4hakUot4MClmpLfE1hFjoioyZgMEekRqVSCoV0dMbSrIwpKK/Hd2UzsOJWG69kl2H02E7vPZsLHwQLT+nlicrA7nKxMxQ6ZiFpB3M0C5BZXQGFmjMFdHMQOh4iow2AyRKSn7CxM8OwQHzwz2Btn0wux81Q69p6/hZt5pXjv4FX86+drGBnghBkDPDGsqyOMZKynQtRR1VWTHNPdBSZG/LdMRNRUTIaI9JxEIkHfTrbo28kW74wPxP4Lt7D9VDrOphXi5yvZ+PlKNpyt5Zga7Ilp/TzRyd5c7JCJqBmqVWocrJ0iN45V5IiImoXJEJEBsZQbYXr/TpjevxOuZxdjx6l07I7PQLayAh//moiPf03E4C72mNbPE+HdXWBqzHVKiHRdbHI+8ksrYWdhgkG+9mKHQ0TUoTAZIjJQfs5WWDw+EK+P8cfhKznYfioNxxLz8EdiPv5IzIfCzBiP9XHH9P6e6ObKCmFEumrf+Zq1hcb0cOF0VyKiZmIyRGTg5EYyjOvlinG9XJFeUIZvzmRg1+l03Coqx+fHU/D58RQEeSgwvX8nTAhyhZUpS3QT6YrKajUOXq6ZIseFVomImo/JEBFpeNqZI2qUHxaO7Irfb+Rix6l0HE7IxvmMIpzPuIgV+65gXC9XzOjviWAvW5boJhLZH4l5KLpbBUcrOUJ8OEWOiKi5mAwRUT0yqQQj/J0wwt8JeSUV+C4+E9tPpSEptxS7zmRg15kM+DpaYHp/Tzze1wMOlnKxQyYySD/UVpF7pIcLZFJ+OUFE1FxMhojonhws5Zg7rDPmDPVBfNodbI9Lx74Lt5GUW4rVB67i/YPXMCrQGdP7e2JoV0f+QkbUTsqrVDh0ORsAMD6IC60SEbUEkyEiahKJRIJgLzsEe9lhyYRA/HD+NnacTsf59EL8eCkLP17KgpvCFFP6eWJqsAc87Viim6gt/XY9F8UV1XCxNkVwJ1uxwyEi6pCYDBFRs1mZGuPJkE54MqQTEm4rseNUOvacy8StonL8J+YGPvrlBoZ0ccCEXm7o3ckGvo6WHDEiamX7LtRUkRvXyxVS/vsiImoRJkNE9EC6uVpj2aPd8ebYAPx8JRs7TqXhj8R8/H4jD7/fyAMAWJjI0MNdgSBPG/TyUCDIwwYetmYswEDUQncrVTicUDtFjlXkiIhajMkQEbUKU2MZHg1yw6NBbkjLL8Ou+AycSM7HpcwilFaqcPJmAU7eLNC0t7MwQc/aBCnIQ4FeHjZwtGIhBqKm+PVaDsoqVfCwNUNvTxuxwyEi6rCYDBFRq+tkX1OiGwBUagFJuSU4l16ICxmFuJBRhITbShSUVuLo9VwcvZ6rOc5NYYpeHjaaBKmHhwLWXNeIqJ59tVXkxvVy5QgrEdEDYDJERG1KJpXAz9kKfs5WmNbPEwBQUa3C1dvFOJ9RiPPpRbiQUYjE3BLcKirHraIszSKSANDZ0QK9PWqm1/XytEGgqzVMjWViXQ6R6EorqvHL1RwAwIRerCJHRPQgmAwRUbuTG8lqRn88bYDQmn0lFdW4lFmE8+k1o0fnMwqRcecuknNLkZxbit1nMwEARlIJ/F2stKbXdXWyhJFMKt4FEbWjmKs5KK9Sw9veHN3drMUOh4ioQ2MyREQ6wVJuhIGd7TGws71mX35JBS78JUG6kFGIvJJKXL6lxOVbSnx1sqadmbEM3d2stQo0eNmbc/oQ6aV95zlFjoiotTAZIiKdZW8px0P+TnjI3wkAIAgCbhWV40J6Ic5lFOJCehEuZhahpKIap1Pv4HTqHc2xCjNjTWLUy6OmUI
OztalYl0LUKorLq3Ck9jm78ZwiR0T0wJgMEVGHIZFI4G5jBncbM4ztWVNOWK0WkJxXigsZhTifXojzGUW4cluJortVWuW9AcDHwQKDu9hjSBcHhHZ2gMKcxRmoYzl0JRuV1Wr4OlogwMVK7HCIiDo8JkNE1KFJpRJ0cbJEFydLPN7XAwBQWa3G9exirQp217OLcTOvFDfzSvHliTRIJUBPDxsM6WKPwV0cEOxlC7kRCzOQbqtbaHV8LzdOkSMiagVMhohI75gYSdHDXYEe7goAXgBqphedSC7AH4l5OJaYh8SckpqRpPRCrP81CabGUvT3tsOQLg4Y3MUBga7WkEr5yybpjqKyKvx+o2aK3IQgLrRKRNQamAwRkUGwMjXGqEBnjAp0BgBkFZXjj8Q8TXKUU1yhNa3OzsIEg3ztNcmRp525mOET4afLWahSCQhwsUIXJ06RIyJqDUyGiMgguShMMTnYA5ODPSAIAm7klODYjZrk6ERyPgpKK7Hvwm3NtCQve3MM7uKAIV0cMMjXHjbmJiJfARmaH2oXWh3fi6NCRESthckQERk8ieTPhWGfGeKDKpUa59MLcax25OhsWiFS88uQmp+Gr06mQSIBerorNMlRsJctF4KlNpVfUoHjSfkAWEWOiKg1MRkiIvobY5kU/bzt0M/bDi+F+aGkohonk/M1ydH17JLadY+K8OmRJMiNap836lqTHPF5I2ptBy9nQaUW0MPdGt4OFmKHQ0SkN5gMERHdh6XcCCO7OWNkt5rnjbKV5TieVPN80R+JechWVuBY7bNHAGBrboxBvg6akaNO9nzeiB7MvvN/VpEjIqLWw2SIiKiZnK1N8VgfDzzWp+Z5o6TcmueNjiXm40RyPu6UVWH/xdvYf7HmF1hPOzMM6eKAIV0cEeprDzsLPm9ETZdTXI6TN2umyI3ryeeFiIhaE5MhIqIHIJFI0MWpprrX7ME1zxtdyCjEsRv5+CMxD/Fpd5BecBdfx6Xj67h0SCRAdzdrhPjYI9DVGv4uVujiZMlnjqhRP17MgloAenvasKohEVErYzJERNSKjGVSBHvZIdjLDgvDuqK0ohpxNws0zxtdzSrGpUwlLmUqNcfIpBL4OFggwMWqdqtJkjxszbiwJmEfq8gREbUZJkNERG3IQm6EhwKc8FCAE4CaKU/HE/NxLr0QV7OUuJpVjMKyKiTmlCAxp0RTyhsArORG8NMkSFYIcLWGn7MVFGbGYl0OtbOsonKcSrkDABjHZIiIqNUxGSIiakdOVqaY1Mcdk/q4AwAEQUC2skKTGF3LKkbCbSWScktQXFGNM6l3cCb1jtY53G3M4F+bIPm7WKGbqzV8HCxgLJOKcUnUhuqeO+vnZQtXhZnI0RAR6R8mQ0REIpJIJHBRmMJFYYoR/k6a/VUqNZJzS7WSpKu3lbhVVI7MwrvILLyLX67maNqbyKTwdbLUjCLVJUlOVnJOtevAOEWOiKhtMRkiItJBxjIp/GuTmol/2V90twrXsopxLUuJhNok6VpWMUoqqpFwW4mE20qt89iYG2ueQ6pLkvxdrGBuwtu/rsu4U4azaYWQSIBHWEWOiKhNsDckIupAFGbGGOBjhwE+dpp9giAg487d2hGkP5Ok5NwSFJZV4URyAU4kF2jaSyRAJzvz2uTIGt1qn0fy4WKeOmV/7fNjIT52cLI2FTkaIiL9xGSIiKiDk0gk8LQzh6edOUYFOmv2l1epkJhTokmSrmYV42pWMXKLK5CaX4bU/DL8dDkbAODraIGYV0aIdAXUkLpiGlxolYio7TAZIiLSU6bGMvRwV6CHu0Jrf35JRU2hhr8kSV2drESKkhpSWa1GFydLpBWUYWwPF7HDISLSW0yGiIgMjL2lHIO6yDGoi4PYoVAjTIyk+Pf03qhSqVklkIioDfEOS0REpKOYCBERtS3eZYmIiIiIyCC1KBlav349vL29YWpqipCQEMTFxTXpuO3bt0MikWDSpEla+2fPng2JRKK1jR
kzpiWhERERERERNUmzk6EdO3YgKioKS5cuRXx8PIKCghAeHo6cnJx7HpeSkoJXX30VQ4cObfD1MWPG4Pbt25rt66+/bm5oRERERERETdbsZGjt2rWYO3cuIiMjERgYiOjoaJibm2Pz5s2NHqNSqTBz5kwsX74cnTt3brCNXC6Hi4uLZrO1tW1uaERERERERE3WrGSosrISZ86cQVhY2J8nkEoRFhaG2NjYRo9799134eTkhGeffbbRNkeOHIGTkxP8/f0xb9485OfnNyc0IiIiIiKiZmlWae28vDyoVCo4Oztr7Xd2dsbVq1cbPObYsWPYtGkTzp071+h5x4wZg8cffxw+Pj5ISkrCW2+9hbFjxyI2NhYymaxe+4qKClRUVGj+rFQqm3MZREREREREbbvOUHFxMZ5++mls3LgRDg6Nr2cxY8YMzf/37NkTvXr1gq+vL44cOYKRI0fWa79mzRosX768TWImIiIiIiLD0Kxpcg4ODpDJZMjOztban52dDReX+itkJyUlISUlBRMmTICRkRGMjIywdetW7N27F0ZGRkhKSmrwfTp37gwHBwckJiY2+PqiRYtQVFSk2dLT05tzGURERERERM0bGTIxMUFwcDBiYmI05bHVajViYmKwYMGCeu0DAgJw8eJFrX3vvPMOiouL8eGHH8LT07PB98nIyEB+fj5cXV0bfF0ul0MulzcndCIiIiIiIi3NniYXFRWFWbNmoV+/fhgwYADWrVuH0tJSREZGAgAiIiLg7u6ONWvWwNTUFD169NA63sbGBgA0+0tKSrB8+XJMnjwZLi4uSEpKwuuvv44uXbogPDz8AS+PiIiIiIioYc1OhqZPn47c3FwsWbIEWVlZ6N27Nw4ePKgpqpCWlgaptOmz72QyGS5cuIAtW7agsLAQbm5uGD16NFasWMHRHyIiIiIiajMSQRAEsYN4UEqlEgqFAkVFRbC2thY7HCIig8J7cMP4uRARiaM5999mL7pKRERERESkD9q0tHZ7qRvc4npDRETtr+7eqwcTDVoV+yYiInE0p1/Si2SouLgYABqtTkdERG2vuLgYCoVC7DB0BvsmIiJxNaVf0otnhtRqNW7dugUrKytIJJJmH69UKuHp6Yn09HTO624AP5/G8bNpHD+bxunbZyMIAoqLi+Hm5tasAjr6jn1T2+Fnc2/8fBrHz6Zx+vTZNKdf0ouRIalUCg8Pjwc+j7W1dYf/4bclfj6N42fTOH42jdOnz4YjQvWxb2p7/GzujZ9P4/jZNE5fPpum9kv8Co+IiIiIiAwSkyEiIiIiIjJITIYAyOVyLF26lIu8NoKfT+P42TSOn03j+NlQU/DvSeP42dwbP5/G8bNpnKF+NnpRQIGIiIiIiKi5ODJEREREREQGickQEREREREZJCZDRERERERkkJgMERERERGRQWIyBGD9+vXw9vaGqakpQkJCEBcXJ3ZIoluzZg369+8PKysrODk5YdKkSbh27ZrYYemkf/7zn5BIJHjppZfEDkVnZGZm4qmnnoK9vT3MzMzQs2dPnD59WuywRKdSqbB48WL4+PjAzMwMvr6+WLFiBVjHhhrCvqk+9k1Nx75JG/ulhrFfYjKEHTt2ICoqCkuXLkV8fDyCgoIQHh6OnJwcsUMT1dGjRzF//nycOHEChw4dQlVVFUaPHo3S0lKxQ9Mpp06dwmeffYZevXqJHYrOuHPnDgYPHgxjY2P8+OOPuHLlCj744APY2tqKHZro3nvvPXz66af4+OOPkZCQgPfeew/vv/8+PvroI7FDIx3Dvqlh7Juahn2TNvZLjWO/xNLaCAkJQf/+/fHxxx8DANRqNTw9PfHCCy/gzTffFDk63ZGbmwsnJyccPXoUw4YNEzscnVBSUoK+ffvik08+wcqVK9G7d2+sW7dO7LBE9+abb+KPP/7A77//LnYoOmf8+PFwdnbGpk2bNPsmT54MMzMzfPnllyJGRrqGfVPTsG+qj31TfeyXGsd+ycBHhiorK3
HmzBmEhYVp9kmlUoSFhSE2NlbEyHRPUVERAMDOzk7kSHTH/PnzMW7cOK2/PwTs3bsX/fr1w9SpU+Hk5IQ+ffpg48aNYoelEwYNGoSYmBhcv34dAHD+/HkcO3YMY8eOFTky0iXsm5qOfVN97JvqY7/UOPZLgJHYAYgpLy8PKpUKzs7OWvudnZ1x9epVkaLSPWq1Gi+99BIGDx6MHj16iB2OTti+fTvi4+Nx6tQpsUPROcnJyfj0008RFRWFt956C6dOncKLL74IExMTzJo1S+zwRPXmm29CqVQiICAAMpkMKpUKq1atwsyZM8UOjXQI+6amYd9UH/umhrFfahz7JQNPhqhp5s+fj0uXLuHYsWNih6IT0tPTsXDhQhw6dAimpqZih6Nz1Go1+vXrh9WrVwMA+vTpg0uXLiE6OtrgO52dO3di27Zt+Oqrr9C9e3ecO3cOL730Etzc3Az+syFqLvZN2tg3NY79UuPYLxl4MuTg4ACZTIbs7Gyt/dnZ2XBxcREpKt2yYMEC7Nu3D7/99hs8PDzEDkcnnDlzBjk5Oejbt69mn0qlwm+//YaPP/4YFRUVkMlkIkYoLldXVwQGBmrt69atG7799luRItIdr732Gt58803MmDEDANCzZ0+kpqZizZo1BtPp0P2xb7o/9k31sW9qHPulxrFfMvBnhkxMTBAcHIyYmBjNPrVajZiYGISGhooYmfgEQcCCBQvw3Xff4ZdffoGPj4/YIemMkSNH4uLFizh37pxm69evH2bOnIlz584ZbGdTZ/DgwfVK3V6/fh1eXl4iRaQ7ysrKIJVq33ZlMhnUarVIEZEuYt/UOPZNjWPf1Dj2S41jv2TgI0MAEBUVhVmzZqFfv34YMGAA1q1bh9LSUkRGRoodmqjmz5+Pr776Ct9//z2srKyQlZUFAFAoFDAzMxM5OnFZWVnVm59uYWEBe3t7zlsH8PLLL2PQoEFYvXo1pk2bhri4OGzYsAEbNmwQOzTRTZgwAatWrUKnTp3QvXt3nD17FmvXrsUzzzwjdmikY9g3NYx9U+PYNzWO/VLj2C8BEEj46KOPhE6dOgkmJibCgAEDhBMnTogdkugANLj973//Ezs0nTR8+HBh4cKFYoehM3744QehR48eglwuFwICAoQNGzaIHZJOUCqVwsKFC4VOnToJpqamQufOnYW3335bqKioEDs00kHsm+pj39Q87Jv+xH6pYeyXBMHg1xkiIiIiIiLDZNDPDBERERERkeFiMkRERERERAaJyRARERERERkkJkNERERERGSQmAwREREREZFBYjJEREREREQGickQEREREREZJCZDRERERERkkJgMERERERGRQWIyREREREREBonJEBERERERGSQmQ0REREREZJD+P83Pi4wQVM5QAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n", "axs[0].set_title(\"Loss value on test set\")\n", "axs[0].plot(eval_metrics_history[\"test_loss\"])\n", "axs[1].set_title(\"Accuracy on test set\")\n", "axs[1].plot(eval_metrics_history[\"test_accuracy\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the model reaches over 85% accuracy on the sentiment classification task\n", "after 10 epochs. The loss plot shows that the test loss barely improved over the last\n", "few epochs, which suggests that the model has converged.\n", "\n", "The JAX AI stack allowed us to build, train, and evaluate the model with JAX ecosystem\n", "libraries that interoperate smoothly across a wide range of computing environments.\n", "\n", "Let's now inspect a few predictions ourselves. First, we download the IMDB word index\n", "(a word-to-index dictionary) and invert it so that we can decode our samples. Then we\n", "decode a few test reviews and print each one together with its predicted and actual label."
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "response = requests.get(\"https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json\")\n", "response.raise_for_status()\n", "word_map = {v: k for k, v in json.loads(response.content).items()}" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def _label_to_str(label: int) -> str:\n", "    return \"positive\" if label else \"negative\"\n", "\n", "def show_reviews(indices: list[int]) -> None:\n", "    for idx in indices:\n", "        # Drop padding tokens (index 0) before decoding.\n", "        x = x_test[idx][x_test[idx] != 0]\n", "        y = y_test[idx]\n", "        # Add a batch dimension, run the model and take the most likely class.\n", "        y_pred = model(x_test[idx][None, :]).argmax()\n", "        review = \"\"\n", "        for w_x in x:\n", "            # Undo the `index_from` offset applied in `prepare_imdb_dataset`.\n", "            # Reserved indices below `index_from` (such as the OOV token) are looked up\n", "            # unchanged, so OOV tokens show up as ordinary words like \"and\".\n", "            word_idx = w_x - index_from if w_x >= index_from else w_x\n", "            review += f\"{word_map[word_idx]} \"\n", "\n", "        print(\"Review:\")\n", "        for line in textwrap.wrap(review):\n", "            print(line)\n", "        print(\"Predicted sentiment: \", _label_to_str(y_pred))\n", "        print(\"Actual sentiment: \", _label_to_str(y), \"\\n\")" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Review:\n", "please give this one a miss br br kristy swanson and the rest of the\n", "cast rendered terrible performances the show is flat flat flat br br i\n", "don't know how michael madison could have allowed this one on his\n", "plate he almost seemed to know this wasn't going to work out and his\n", "performance was quite lacklustre so all you madison fans give this a\n", "miss\n", "Predicted sentiment: negative\n", "Actual sentiment: negative \n", "\n", "Review:\n", "this is a funny movie the bob eddie show feel of it could lead to a\n", "sequel but i doubt it will make enough money br br deniro proves he\n", "can be a great straight man again with some hilarious and spontaneous\n", "moments eddie was fun to watch working with people instead of cgi\n", "animals and and rene russo well she's just 
fun to watch anyway and\n", "she's played her part excellent br br some wild and unusual stunts\n", "especially the garbage truck scene this was worth seeing in the\n", "theater we needed a good laugh and got many from the movie and the\n", "great out takes at the end do not leave at the start of the credits br\n", "br at least a 7\n", "Predicted sentiment: positive\n", "Actual sentiment: positive \n", "\n", "Review:\n", "terrible absolutely terrible long confusing and and after about three\n", "hours of this painful mess the ending truly is the final nail in the\n", "coffin not even the magnificent sexy beautiful goddess and and can\n", "save this poor adaptation of agatha and work the plot drags and drags\n", "and time goes by slowly and suddenly you realize that you don't even\n", "have any idea of what's going on anymore by the end even with the\n", "usual explanation by the villain there's still a lot that's left\n", "unexplained and and it's over a complete waste of time and without a\n", "doubt one of the worst to bear the name of agatha christie\n", "Predicted sentiment: negative\n", "Actual sentiment: negative \n", "\n", "Review:\n", "whoa this is one of the worst movies i have ever seen the packaging\n", "for the film is better than the film itself my girlfriend and i\n", "watched it this past weekend and we only continued to watch it in the\n", "hopes that it would get better it didn't br br the picture quality is\n", "poor it looks like it was shot on video and transferred to film the\n", "lighting is not great which makes it harder to read the actors' facial\n", "expressions the acting itself was cheesy but i guess it's acceptable\n", "for yet another teenage horror flick the sound was a huge problem\n", "sometimes you have to rewind the video because the sound is unclear\n", "and or and br br it holds no real merit of it's own trying to ride on\n", "the and of sleepy hollow don't bother with this one\n", "Predicted sentiment: negative\n", "Actual 
sentiment: negative \n", "\n", "Review:\n", "i am a big gone with the wind nut but i was disappointed that both\n", "gone with the wind the movie and scarlett the mini series are so\n", "different from the books gone with the wind left so many things out in\n", "the movie that were in the book and they did the same with scarlett\n", "both were good movies but i really liked both books better there were\n", "so many characters left out of scarlett and the ages of some\n", "characters didn't seem to match up with the book the time lines don't\n", "match up either scarlett realizes she is pregnant on the ship to\n", "ireland in the book but she realizes it when she is throwing up while\n", "in also sally is made out to be an ugly monkey like woman in the book\n", "and the movie casted jean smart to play her who is obviously not an\n", "ugly woman over all scarlett is a good movie and it helps anyone who\n", "was disappointed in the way gone with the wind ended to see what might\n", "have happened if margaret mitchell had lived to write a sequel herself\n", "Predicted sentiment: negative\n", "Actual sentiment: positive \n", "\n", "Review:\n", "if you liked the richard chamberlain version of the bourne identity\n", "then you will like this too aiden quinn does this one brilliantly you\n", "can't help but wonder if he is really out there i reckon he and the\n", "other main cast members probably had nightmares for weeks after doing\n", "this movie as it's so intense when i first saw it i was just and\n", "channels on the remote late one evening i got hooked within minutes\n", "look up www answers com for and and who is the character that carlos\n", "the is based on for both i remember reading about and arrest in the\n", "paper in 1997 it was front page for weeks through the trial after his\n", "arrest\n", "Predicted sentiment: positive\n", "Actual sentiment: positive \n", "\n" ] } ], "source": [ "show_reviews([0, 500, 600, 1000, 1800, 2000])" ] } ], "metadata": { 
"jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "jax-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 }