{
"cells": [
{
"cell_type": "markdown",
"id": "072f8ce2-7014-4f04-83a6-96953e9c8a79",
"metadata": {},
"source": [
"# Basic text classification with 1D CNN\n",
"\n",
"[](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_basic_text_classification.ipynb)\n",
"\n",
"In this tutorial we learn how to perform text classification from raw text data and train a basic 1D Convolutional Neural Network to perform sentiment analysis using JAX. This tutorial is originally inspired by [\"Text classification from scratch with Keras\"](https://keras.io/examples/nlp/text_classification_from_scratch/#build-a-model).\n",
"\n",
"We will use the IMDB movie review dataset to classify the review to \"positive\" and \"negative\" classes. We implement from scratch a simple model using Flax, train it and compute metrics on the test set."
]
},
{
"cell_type": "markdown",
"id": "ef7f5048-87d4-4578-a8ef-6fd8a9bad28e",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"We will be using the following packages in this tutorial:\n",
"- [Tiktoken](https://github.com/openai/tiktoken) to tokenize the raw text\n",
"- [Grain](https://github.com/google/grain) for efficient data loading and batching\n",
"- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "03e6ca8e-7a5e-4451-a1d6-699ddb1496eb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: grain in /opt/conda/lib/python3.11/site-packages (0.2.2)\n",
"Requirement already satisfied: tiktoken in /opt/conda/lib/python3.11/site-packages (0.8.0)\n",
"Requirement already satisfied: tqdm in /opt/conda/lib/python3.11/site-packages (4.66.4)\n",
"Requirement already satisfied: absl-py in /opt/conda/lib/python3.11/site-packages (from grain) (2.1.0)\n",
"Requirement already satisfied: array-record in /opt/conda/lib/python3.11/site-packages (from grain) (0.5.1)\n",
"Requirement already satisfied: cloudpickle in /opt/conda/lib/python3.11/site-packages (from grain) (3.1.0)\n",
"Requirement already satisfied: dm-tree in /opt/conda/lib/python3.11/site-packages (from grain) (0.1.8)\n",
"Requirement already satisfied: etils[epath,epy] in /opt/conda/lib/python3.11/site-packages (from grain) (1.9.4)\n",
"Requirement already satisfied: jaxtyping in /opt/conda/lib/python3.11/site-packages (from grain) (0.2.34)\n",
"Requirement already satisfied: more-itertools>=9.1.0 in /opt/conda/lib/python3.11/site-packages (from grain) (10.1.0)\n",
"Requirement already satisfied: numpy in /opt/conda/lib/python3.11/site-packages (from grain) (1.26.4)\n",
"Requirement already satisfied: regex>=2022.1.18 in /opt/conda/lib/python3.11/site-packages (from tiktoken) (2024.11.6)\n",
"Requirement already satisfied: requests>=2.26.0 in /opt/conda/lib/python3.11/site-packages (from tiktoken) (2.32.3)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2.0.4)\n",
"Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (3.7)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2.2.2)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.11/site-packages (from requests>=2.26.0->tiktoken) (2024.7.4)\n",
"Requirement already satisfied: fsspec in /opt/conda/lib/python3.11/site-packages (from etils[epath,epy]->grain) (2024.9.0)\n",
"Requirement already satisfied: importlib_resources in /opt/conda/lib/python3.11/site-packages (from etils[epath,epy]->grain) (6.4.5)\n",
"Requirement already satisfied: typing_extensions in /opt/conda/lib/python3.11/site-packages (from etils[epath,epy]->grain) (4.11.0)\n",
"Requirement already satisfied: zipp in /opt/conda/lib/python3.11/site-packages (from etils[epath,epy]->grain) (3.20.2)\n",
"Requirement already satisfied: typeguard==2.13.3 in /opt/conda/lib/python3.11/site-packages (from jaxtyping->grain) (2.13.3)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"!pip install grain tiktoken tqdm"
]
},
{
"cell_type": "markdown",
"id": "86b04d7d-2011-4c57-8976-c0a3746c9374",
"metadata": {},
"source": [
"### Load the data: IMDB movie review sentiment classification\n",
"\n",
"Let us download the dataset and briefly inspect the structure. We will be using only two classes: \"positive\" and \"negative\" for the sentiment analysis."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f211f467-9c07-45f6-89aa-6ebef42df27a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2024-11-18 16:58:00-- https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
"Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10\n",
"Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 84125825 (80M) [application/x-gzip]\n",
"Saving to: ‘/tmp/data/imdb/aclImdb_v1.tar.gz’\n",
"\n",
"/tmp/data/imdb/aclI 100%[===================>] 80.23M 17.8MB/s in 8.8s \n",
"\n",
"2024-11-18 16:58:09 (9.13 MB/s) - ‘/tmp/data/imdb/aclImdb_v1.tar.gz’ saved [84125825/84125825]\n",
"\n"
]
}
],
"source": [
"!rm -rf /tmp/data/imdb\n",
"!mkdir -p /tmp/data/imdb\n",
"!wget https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz -O /tmp/data/imdb/aclImdb_v1.tar.gz\n",
"!cd /tmp/data/imdb/ && tar -xf aclImdb_v1.tar.gz"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b0f91f70-9f10-43f0-a289-f17a66ba9906",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of positive samples in train set:\n",
"12500\n",
"Number of negative samples in train set:\n",
"12500\n",
"Number of positive samples in test set:\n",
"12500\n",
"Number of negative samples in test set:\n",
"12500\n",
"First 10 files with positive samples in train/test sets:\n",
"0_9.txt\n",
"10000_8.txt\n",
"10001_10.txt\n",
"10002_7.txt\n",
"10003_8.txt\n",
"10004_8.txt\n",
"10005_7.txt\n",
"10006_7.txt\n",
"10007_7.txt\n",
"10008_7.txt\n",
"ls: write error: Broken pipe\n",
"0_10.txt\n",
"10000_7.txt\n",
"10001_9.txt\n",
"10002_8.txt\n",
"10003_8.txt\n",
"10004_9.txt\n",
"10005_8.txt\n",
"10006_7.txt\n",
"10007_10.txt\n",
"10008_8.txt\n",
"ls: write error: Broken pipe\n",
"Display a single positive sample:\n",
"Being an Austrian myself this has been a straight knock in my face. Fortunately I don't live nowhere near the place where this movie takes place but unfortunately it portrays everything that the rest of Austria hates about Viennese people (or people close to that region). And it is very easy to read that this is exactly the directors intention: to let your head sink into your hands and say \"Oh my god, how can THAT be possible!\". No, not with me, the (in my opinion) totally exaggerated uncensored swinger club scene is not necessary, I watch porn, sure, but in this context I was rather disgusted than put in the right context.
This movie tells a story about how misled people who suffer from lack of education or bad company try to survive and live in a world of redundancy and boring horizons. A girl who is treated like a whore by her super-jealous boyfriend (and still keeps coming back), a female teacher who discovers her masochism by putting the life of her super-cruel \"lover\" on the line, an old couple who has an almost mathematical daily cycle (she is the \"official replacement\" of his ex wife), a couple that has just divorced and has the ex husband suffer under the acts of his former wife obviously having a relationship with her masseuse and finally a crazy hitchhiker who asks her drivers the most unusual questions and stretches their nerves by just being super-annoying.
After having seen it you feel almost nothing. You're not even shocked, sad, depressed or feel like doing anything... Maybe that's why I gave it 7 points, it made me react in a way I never reacted before. If that's good or bad is up to you!"
]
}
],
"source": [
"!echo \"Number of positive samples in train set:\"\n",
"!ls /tmp/data/imdb/aclImdb/train/pos | wc -l\n",
"!echo \"Number of negative samples in train set:\"\n",
"!ls /tmp/data/imdb/aclImdb/train/neg | wc -l\n",
"!echo \"Number of positive samples in test set:\"\n",
"!ls /tmp/data/imdb/aclImdb/test/pos | wc -l\n",
"!echo \"Number of negative samples in test set:\"\n",
"!ls /tmp/data/imdb/aclImdb/test/neg | wc -l\n",
"!echo \"First 10 files with positive samples in train/test sets:\"\n",
"!ls /tmp/data/imdb/aclImdb/train/pos | head\n",
"!ls /tmp/data/imdb/aclImdb/test/pos | head\n",
"!echo \"Display a single positive sample:\"\n",
"!cat /tmp/data/imdb/aclImdb/train/pos/6248_7.txt"
]
},
{
"cell_type": "markdown",
"id": "830d0e88-9d28-4c26-8cf1-842b65c8c85c",
"metadata": {},
"source": [
"Next, we will:\n",
"- create the dataset Python class to read samples from the disk\n",
"- use [Tiktoken](https://github.com/openai/tiktoken) to encode raw text into tokens and\n",
"- use [Grain](https://github.com/google/grain) for efficient data loading and batching."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4a1eceb5-4719-40da-ba81-9060920a7ef1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"- Number of samples in train and test sets: 25000 25000\n",
"- First train sample: {'text': \"Preston Waters, a 11 years old boy,has problems with his parents and brothers specially because of money issues. He is crazy to have his own house and his own rules,since his brothers always stole his saved money and his parents neglect his wishes. One awful day, Preston was riding his bicycle; It was the same day that the villain of the story,Quigley, was trying to scape from the Police and accidentally ran the car over Preston's bike. Needing to be far away from the police, Quigley gives in a hurry, a check to cover the damages of Preston's bike. The problem was: It was a blank check! Preston is a clever boy and decides to have a high price on that check: 1 million dollars! All that money gives Preston things that he always wished for, like a mansion with pool,lots of toys, and even a limousine! The problems start to begin when the FBI and Quigley wants to know where the money is, making Preston in a hard situation and facing many problems.
This movie was one of my favorites during my childhood. :)\", 'label': 0}\n",
"- First test sample: {'text': \"I think I was recommended this film by the lady in the shop I was hiring it from! For once she was bang on! What a superb film! First of all I was convinced James McAvoy & Romola Garai were Irish so convincing were their accents; and by half way through the film I was utterly convinced Steven Robertson was a disabled actor and pretty sure James McAvoy was also! When I watched the special features on the DVD and saw both actors in their 'normal' guise, to say I was blown away would be an understatement!!! I can remember all the acclaim Dustin Hoffmann got back in the 80's for his portrayal of autism in the film 'Rain Man' - quite frankly (in my opinion of course!)Steven Robertson's performance/portrayal blows Dustin Hoffmann's right out of the water - and he deserves recognition as such!! All in all one of the greatest portrayals of human friendship/love/relationships ever - and it was made in Britain/Ireland with home grown actors - stick that in yer pipe and smoke it Hollywood!\", 'label': 0}\n"
]
}
],
"source": [
"from pathlib import Path\n",
"\n",
"\n",
"class SentimentAnalysisDataset:\n",
" def __init__(self, path: str | Path):\n",
" self.path = Path(path)\n",
" assert self.path.exists()\n",
"\n",
" pos_texts = list((self.path / \"pos\").glob(\"*.txt\"))\n",
" neg_texts = list((self.path / \"neg\").glob(\"*.txt\"))\n",
" self.text_files = pos_texts + neg_texts\n",
" assert len(self.text_files) > 0\n",
" # Label 0 for Positive comments\n",
" # Label 1 for Negative comments\n",
" self.labels = [0] * len(pos_texts) + [1] * len(neg_texts)\n",
"\n",
" def __len__(self) -> int:\n",
" return len(self.text_files)\n",
"\n",
" def read_text_file(self, path: str | Path) -> str:\n",
" with open(path, \"r\") as handler:\n",
" lines = handler.readlines()\n",
" return \"\\n\".join(lines)\n",
"\n",
" def __getitem__(self, index: int) -> tuple[str, int]:\n",
" label = self.labels[index]\n",
" text = self.read_text_file(self.text_files[index])\n",
" return {\"text\": text, \"label\": label}\n",
"\n",
"\n",
"root_path = Path(\"/tmp/data/imdb/aclImdb/\")\n",
"train_dataset = SentimentAnalysisDataset(root_path / \"train\")\n",
"test_dataset = SentimentAnalysisDataset(root_path / \"test\")\n",
"\n",
"print(\"- Number of samples in train and test sets:\", len(train_dataset), len(test_dataset))\n",
"print(\"- First train sample:\", train_dataset[0])\n",
"print(\"- First test sample:\", test_dataset[0])"
]
},
{
"cell_type": "markdown",
"id": "a82f87d5-9bbf-4097-9baf-001e1f368561",
"metadata": {},
"source": [
"Now, we can create a string-to-tokens preprocessing transformation and set up data loaders. We are going to use the GPT-2 tokenizer via [Tiktoken](https://github.com/openai/tiktoken)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "75f77d72-b73f-41f0-a24c-2e37f3e463a8",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import tiktoken\n",
"import grain.python as grain\n",
"\n",
"\n",
"seed = 12\n",
"train_batch_size = 128\n",
"test_batch_size = 2 * train_batch_size\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"# max length of tokenized text\n",
"max_length = 500\n",
"vocab_size = tokenizer.n_vocab\n",
"\n",
"\n",
"class TextPreprocessing(grain.MapTransform):\n",
" def __init__(self, tokenizer, max_length: int = 256):\n",
" self.tokenizer = tokenizer\n",
" self.max_length = max_length\n",
"\n",
" def map(self, data):\n",
" text = data[\"text\"]\n",
" encoded = self.tokenizer.encode(text)\n",
" # Cut to max length\n",
" encoded = encoded[:self.max_length]\n",
" # Pad with zeros if needed\n",
" encoded = np.array(encoded + [0] * (self.max_length - len(encoded)))\n",
" return {\n",
" \"text\": encoded,\n",
" \"label\": data[\"label\"],\n",
" }\n",
"\n",
"\n",
"train_sampler = grain.IndexSampler(\n",
" len(train_dataset),\n",
" shuffle=True,\n",
" seed=seed,\n",
" shard_options=grain.NoSharding(), # No sharding since this is a single-device setup\n",
" num_epochs=1, # Iterate over the dataset for one epoch\n",
")\n",
"\n",
"test_sampler = grain.IndexSampler(\n",
" len(test_dataset),\n",
" shuffle=False,\n",
" seed=seed,\n",
" shard_options=grain.NoSharding(), # No sharding since this is a single-device setup\n",
" num_epochs=1, # Iterate over the dataset for one epoch\n",
")\n",
"\n",
"\n",
"train_loader = grain.DataLoader(\n",
" data_source=train_dataset,\n",
" sampler=train_sampler, # Sampler to determine how to access the data\n",
" worker_count=4, # Number of child processes launched to parallelize the transformations among\n",
" worker_buffer_size=2, # Count of output batches to produce in advance per worker\n",
" operations=[\n",
" TextPreprocessing(tokenizer, max_length=max_length),\n",
" grain.Batch(train_batch_size, drop_remainder=True),\n",
" ]\n",
")\n",
"\n",
"test_loader = grain.DataLoader(\n",
" data_source=test_dataset,\n",
" sampler=test_sampler, # Sampler to determine how to access the data\n",
" worker_count=4, # Number of child processes launched to parallelize the transformations among\n",
" worker_buffer_size=2, # Count of output batches to produce in advance per worker\n",
" operations=[\n",
" TextPreprocessing(tokenizer, max_length=max_length),\n",
" grain.Batch(test_batch_size),\n",
" ]\n",
")"
]
},
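{
"cell_type": "markdown",
"id": "sketch-pad-or-truncate",
"metadata": {},
"source": [
"The truncate-and-pad logic in `TextPreprocessing.map` can be checked in isolation. Below is a minimal sketch with a hypothetical helper, `pad_or_truncate` (not part of Grain or Tiktoken), that operates on a plain list of token ids, so it runs without downloading the tokenizer. Keep in mind that id 0 is also an ordinary token in the GPT-2 vocabulary, so after zero-padding the model cannot tell padding apart from a genuine occurrence of that token.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def pad_or_truncate(token_ids: list[int], max_length: int, pad_id: int = 0) -> np.ndarray:\n",
"    # Keep at most max_length tokens\n",
"    token_ids = token_ids[:max_length]\n",
"    # Right-pad with pad_id up to exactly max_length\n",
"    padding = [pad_id] * (max_length - len(token_ids))\n",
"    return np.array(token_ids + padding)\n",
"\n",
"\n",
"padded = pad_or_truncate([464, 8258, 2128], max_length=5)\n",
"truncated = pad_or_truncate(list(range(10)), max_length=5)\n",
"print(padded.tolist())  # [464, 8258, 2128, 0, 0]\n",
"print(truncated.tolist())  # [0, 1, 2, 3, 4]\n",
"```"
]
},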
{
"cell_type": "code",
"execution_count": 6,
"id": "7382398d-2304-436c-9e79-2487f6a4d21a",
"metadata": {},
"outputs": [],
"source": [
"train_batch = next(iter(train_loader))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0c3da4d2-70f5-45a7-a72c-ceca0eca65ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train encoded text batch info: (128, 500) int64\n",
"Train labels batch info: (128,) int64\n"
]
}
],
"source": [
"print(\"Train encoded text batch info:\", type(train_batch[\"text\"]), train_batch[\"text\"].shape, train_batch[\"text\"].dtype)\n",
"print(\"Train labels batch info:\", type(train_batch[\"label\"]), train_batch[\"label\"].shape, train_batch[\"label\"].dtype)"
]
},
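{
"cell_type": "markdown",
"id": "sketch-batch-counts",
"metadata": {},
"source": [
"With `drop_remainder=True` the train loader yields only full batches, whereas the test loader keeps the final, smaller batch (`grain.Batch` does not drop it by default). The plain-Python arithmetic below, whose variable names mirror the notebook's, shows how many batches each pass produces; the train count equals `train_total_steps`, which the progress bar uses later.\n",
"\n",
"```python\n",
"num_train, num_test = 25000, 25000\n",
"train_batch_size = 128\n",
"test_batch_size = 2 * train_batch_size\n",
"\n",
"# drop_remainder=True: the incomplete last train batch is discarded\n",
"train_batches = num_train // train_batch_size\n",
"dropped = num_train % train_batch_size\n",
"# The last test batch is kept, so round up (ceiling division)\n",
"test_batches = -(-num_test // test_batch_size)\n",
"last_test_batch = num_test - (test_batches - 1) * test_batch_size\n",
"\n",
"print(train_batches, dropped)  # 195 40\n",
"print(test_batches, last_test_batch)  # 98 168\n",
"```"
]
},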
{
"cell_type": "markdown",
"id": "0129084e-a28a-4612-bc54-28c8a4a84c9b",
"metadata": {},
"source": [
"Let's check few samples of the training batch. We expect to see integer tokens for the input text and integer value for the labels:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cd6606ff-eb64-4dc6-b7bc-21f6eb540aaf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train batch data: [[ 464 8258 2128 326 345 743 3285 618 345 16067 439 428]\n",
" [ 5297 11 428 3180 257 9961 43469 2646 290 340 373 257]] [1 0]\n"
]
}
],
"source": [
"print(\"Train batch data:\", train_batch[\"text\"][:2, :12], train_batch[\"label\"][:2])"
]
},
{
"cell_type": "markdown",
"id": "7e5c502c-7fd0-4f10-a0d5-007f9bc139a4",
"metadata": {},
"source": [
"## Model for text classification\n",
"\n",
"We choose a simple 1D convnet to classify the text. The first layer of the model transforms input tokens into float features using an embedding layer (`nnx.Embed`), then they are encoded further with convolutions. Finally, we classify encoded features using fully-connected layers."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1546b8d0-9c0c-4970-a8b6-67276fb08e2a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction shape (N, num_classes): (4, 2)\n"
]
}
],
"source": [
"from typing import Callable\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import nnx\n",
"\n",
"\n",
"class TextConvNet(nnx.Module):\n",
" def __init__(\n",
" self,\n",
" vocab_size: int,\n",
" num_classes: int = 2,\n",
" embed_dim: int = 256,\n",
" hidden_dim: int = 320,\n",
" dropout_rate: float = 0.5,\n",
" conv_ksize: int = 12,\n",
" activation_layer: Callable = nnx.relu,\n",
" rngs: nnx.Rngs = nnx.Rngs(0),\n",
" ):\n",
" self.activation_layer = activation_layer\n",
" self.token_embedding = nnx.Embed(\n",
" num_embeddings=vocab_size,\n",
" features=embed_dim,\n",
" rngs=rngs,\n",
" )\n",
" self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)\n",
" self.conv1 = nnx.Conv(\n",
" in_features=embed_dim,\n",
" out_features=hidden_dim,\n",
" kernel_size=conv_ksize,\n",
" strides=conv_ksize // 2,\n",
" rngs=rngs,\n",
" )\n",
" self.lnorm1 = nnx.LayerNorm(hidden_dim, rngs=rngs)\n",
" self.conv2 = nnx.Conv(\n",
" in_features=hidden_dim,\n",
" out_features=hidden_dim,\n",
" kernel_size=conv_ksize,\n",
" strides=conv_ksize // 2,\n",
" rngs=rngs,\n",
" )\n",
" self.lnorm2 = nnx.LayerNorm(hidden_dim, rngs=rngs)\n",
"\n",
" self.fc1 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)\n",
" self.fc2 = nnx.Linear(hidden_dim, num_classes, rngs=rngs)\n",
"\n",
" def __call__(self, x: jax.Array) -> jax.Array:\n",
" # x.shape: (N, max_length)\n",
" x = self.token_embedding(x)\n",
" x = self.dropout(x) # x.shape: (N, max_length, embed_dim)\n",
"\n",
" x = self.conv1(x)\n",
" x = self.lnorm1(x)\n",
" x = self.activation_layer(x)\n",
" x = self.conv2(x)\n",
" x = self.lnorm2(x)\n",
" x = self.activation_layer(x) # x.shape: (N, K, hidden_dim)\n",
"\n",
" x = nnx.max_pool(x, window_shape=(x.shape[1], )) # x.shape: (N, 1, hidden_dim)\n",
" x = x.reshape((-1, x.shape[-1])) # x.shape: (N, hidden_dim)\n",
"\n",
" x = self.fc1(x) # x.shape: (N, hidden_dim)\n",
" x = self.activation_layer(x)\n",
" x = self.dropout(x)\n",
" x = self.fc2(x) # x.shape: (N, 2)\n",
"\n",
" return x\n",
"\n",
"\n",
"# Let's check the model on a dummy input\n",
"x = jnp.ones((4, max_length), dtype=\"int32\")\n",
"module = TextConvNet(vocab_size)\n",
"y = module(x)\n",
"print(\"Prediction shape (N, num_classes): \", y.shape)"
]
},
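{
"cell_type": "markdown",
"id": "sketch-conv-output-length",
"metadata": {},
"source": [
"The `K` in the shape comments is the sequence length left after the two strided convolutions. The sketch below computes it under the assumption that `nnx.Conv` uses its default `padding='SAME'`, for which a convolution with stride `s` maps a length `L` to `ceil(L / s)` independently of the kernel size. It also shows, with NumPy, that max pooling with a window spanning the whole sequence amounts to taking the maximum over the sequence axis. The helper names are illustrative only.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def same_conv_length(length: int, stride: int) -> int:\n",
"    # With 'SAME' padding the output length depends only on the stride\n",
"    return math.ceil(length / stride)\n",
"\n",
"\n",
"def encoded_length(max_length: int, conv_ksize: int) -> int:\n",
"    # TextConvNet uses strides=conv_ksize // 2 in both convolutions\n",
"    stride = conv_ksize // 2\n",
"    return same_conv_length(same_conv_length(max_length, stride), stride)\n",
"\n",
"\n",
"print(encoded_length(500, conv_ksize=12))  # 14, the dummy model above\n",
"print(encoded_length(500, conv_ksize=7))  # 56, the model trained below\n",
"\n",
"# Global max pooling over the sequence axis: (N, K, C) -> (N, C)\n",
"x = np.random.default_rng(0).normal(size=(4, 56, 128))\n",
"pooled = x.max(axis=1)\n",
"print(pooled.shape)  # (4, 128)\n",
"```"
]
},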
{
"cell_type": "code",
"execution_count": 30,
"id": "39ff08da-dbcf-408d-a103-12b06229e834",
"metadata": {},
"outputs": [],
"source": [
"model = TextConvNet(\n",
" vocab_size,\n",
" num_classes=2,\n",
" embed_dim=128,\n",
" hidden_dim=128,\n",
" conv_ksize=7,\n",
" activation_layer=nnx.relu,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "00215ea2-4fa3-43e5-b64c-0bfb4b667d13",
"metadata": {},
"source": [
"## Train the model\n",
"\n",
"We can now train the model using training data loader and compute metrics: accuracy and loss on test data loader.\n",
"Below we set up the optimizer and define the loss function as Cross-Entropy.\n",
"Next, we define the train step where we compute the loss value and update the model parameters.\n",
"In the eval step we use the model to compute the metrics: accuracy and loss value."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "6d9f4756-4e64-49d1-81dd-20c0e0480dd0",
"metadata": {},
"outputs": [],
"source": [
"import optax\n",
"\n",
"\n",
"num_epochs = 10\n",
"learning_rate = 0.0005\n",
"momentum = 0.9\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))"
]
},
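{
"cell_type": "markdown",
"id": "sketch-adam-step",
"metadata": {},
"source": [
"`optax.adam` applies the Adam update rule. As an illustration of the textbook algorithm (not optax's source code), the NumPy sketch below performs one bias-corrected Adam step with the notebook's learning rate, `b1=0.9`, and optax's defaults `b2=0.999` and `eps=1e-8`; the function name `adam_step` is hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def adam_step(param, grad, m, v, t, lr=0.0005, b1=0.9, b2=0.999, eps=1e-8):\n",
"    # Exponential moving averages of the gradient and the squared gradient\n",
"    m = b1 * m + (1 - b1) * grad\n",
"    v = b2 * v + (1 - b2) * grad**2\n",
"    # Bias correction for the zero-initialized moments (t counts from 1)\n",
"    m_hat = m / (1 - b1**t)\n",
"    v_hat = v / (1 - b2**t)\n",
"    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)\n",
"    return param, m, v\n",
"\n",
"\n",
"param = np.array([1.0, -2.0])\n",
"m = np.zeros_like(param)\n",
"v = np.zeros_like(param)\n",
"grad = np.array([0.5, -0.1])\n",
"param, m, v = adam_step(param, grad, m, v, t=1)\n",
"print(param)\n",
"```\n",
"\n",
"On the first step the bias-corrected moments reduce to `grad` and `grad**2`, so each entry moves by roughly `learning_rate` in the direction opposite to the sign of its gradient."
]
},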
{
"cell_type": "code",
"execution_count": 32,
"id": "3fb1599e-c4cc-4d52-bf0a-1a9e7be043ee",
"metadata": {},
"outputs": [],
"source": [
"def compute_losses_and_logits(model: nnx.Module, batch_tokens: jax.Array, labels: jax.Array):\n",
" logits = model(batch_tokens)\n",
"\n",
" loss = optax.softmax_cross_entropy_with_integer_labels(\n",
" logits=logits, labels=labels\n",
" ).mean()\n",
" return loss, logits"
]
},
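{
"cell_type": "markdown",
"id": "sketch-cross-entropy",
"metadata": {},
"source": [
"For each example, `optax.softmax_cross_entropy_with_integer_labels` returns the negative log-probability that the softmax assigns to the true class, that is `logsumexp(logits) - logits[label]`. The NumPy sketch below reproduces this formula on a tiny batch; the helper name is hypothetical and the code is not optax's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def cross_entropy_integer_labels(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:\n",
"    # Numerically stable log-sum-exp over the class axis\n",
"    max_logits = logits.max(axis=-1, keepdims=True)\n",
"    log_norm = np.log(np.exp(logits - max_logits).sum(axis=-1)) + max_logits[:, 0]\n",
"    # Logit of the true class for every example\n",
"    true_logits = np.take_along_axis(logits, labels[:, None], axis=-1)[:, 0]\n",
"    return log_norm - true_logits\n",
"\n",
"\n",
"logits = np.array([[2.0, 0.0], [0.0, 0.0]])\n",
"labels = np.array([0, 1])\n",
"losses = cross_entropy_integer_labels(logits, labels)\n",
"print(losses.mean())\n",
"```\n",
"\n",
"Two equal logits give a loss of `log(2) ≈ 0.693`, which is roughly the test loss recorded after the first epoch below, while the model is still close to guessing."
]
},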
{
"cell_type": "code",
"execution_count": 33,
"id": "f526aaf5-80c3-4b6a-b82a-cdd0d8c180c7",
"metadata": {},
"outputs": [],
"source": [
"@nnx.jit\n",
"def train_step(\n",
" model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, jax.Array]\n",
"):\n",
" # Convert numpy arrays to jax.Array on GPU\n",
" batch_tokens = jnp.array(batch[\"text\"])\n",
" labels = jnp.array(batch[\"label\"], dtype=jnp.int32)\n",
"\n",
" grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)\n",
" (loss, logits), grads = grad_fn(model, batch_tokens, labels)\n",
"\n",
" optimizer.update(grads) # In-place updates.\n",
"\n",
" return loss\n",
"\n",
"\n",
"@nnx.jit\n",
"def eval_step(\n",
" model: nnx.Module, batch: dict[str, jax.Array], eval_metrics: nnx.MultiMetric\n",
"):\n",
" # Convert numpy arrays to jax.Array on GPU\n",
" batch_tokens = jnp.array(batch[\"text\"])\n",
" labels = jnp.array(batch[\"label\"], dtype=jnp.int32)\n",
" loss, logits = compute_losses_and_logits(model, batch_tokens, labels)\n",
"\n",
" eval_metrics.update(\n",
" loss=loss,\n",
" logits=logits,\n",
" labels=labels,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "cf8c8601-5261-4c9a-8d74-192504bd3836",
"metadata": {},
"outputs": [],
"source": [
"eval_metrics = nnx.MultiMetric(\n",
" loss=nnx.metrics.Average('loss'),\n",
" accuracy=nnx.metrics.Accuracy(),\n",
")\n",
"\n",
"\n",
"train_metrics_history = {\n",
" \"train_loss\": [],\n",
"}\n",
"\n",
"eval_metrics_history = {\n",
" \"test_loss\": [],\n",
" \"test_accuracy\": [],\n",
"}"
]
},
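{
"cell_type": "markdown",
"id": "sketch-running-accuracy",
"metadata": {},
"source": [
"The metrics object accumulates statistics across all `eval_step` calls until `reset()` is called, and `compute()` returns the result over everything seen since. The NumPy sketch below illustrates that accumulate-then-compute pattern for accuracy: the prediction is the `argmax` of the logits, and the running totals weight every example equally regardless of batch size. The class `RunningAccuracy` is hypothetical and is not Flax's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"class RunningAccuracy:\n",
"    def __init__(self):\n",
"        self.reset()\n",
"\n",
"    def reset(self):\n",
"        self.correct = 0\n",
"        self.count = 0\n",
"\n",
"    def update(self, logits: np.ndarray, labels: np.ndarray):\n",
"        # The predicted class is the index of the largest logit\n",
"        predictions = logits.argmax(axis=-1)\n",
"        self.correct += int((predictions == labels).sum())\n",
"        self.count += labels.size\n",
"\n",
"    def compute(self) -> float:\n",
"        return self.correct / self.count\n",
"\n",
"\n",
"metric = RunningAccuracy()\n",
"# Two batches of different sizes, like the smaller final test batch\n",
"metric.update(np.array([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0]]), np.array([0, 1, 1]))\n",
"metric.update(np.array([[0.5, 0.2]]), np.array([0]))\n",
"print(metric.compute())  # 0.75\n",
"```"
]
},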
{
"cell_type": "code",
"execution_count": 35,
"id": "4fe942f3-13e6-4e8c-87f8-06c62ef25c82",
"metadata": {},
"outputs": [],
"source": [
"import tqdm\n",
"\n",
"\n",
"bar_format = \"{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]\"\n",
"train_total_steps = len(train_dataset) // train_batch_size\n",
"\n",
"\n",
"def train_one_epoch(epoch):\n",
" model.train() # Set model to the training mode: e.g. update batch statistics\n",
" with tqdm.tqdm(\n",
" desc=f\"[train] epoch: {epoch}/{num_epochs}, \",\n",
" total=train_total_steps,\n",
" bar_format=bar_format,\n",
" leave=True,\n",
" ) as pbar:\n",
" for batch in train_loader:\n",
" loss = train_step(model, optimizer, batch)\n",
" train_metrics_history[\"train_loss\"].append(loss.item())\n",
" pbar.set_postfix({\"loss\": loss.item()})\n",
" pbar.update(1)\n",
"\n",
"\n",
"def evaluate_model(epoch):\n",
" # Compute the metrics on the train and val sets after each training epoch.\n",
" model.eval() # Set model to evaluation model: e.g. use stored batch statistics\n",
"\n",
" eval_metrics.reset() # Reset the eval metrics\n",
" for test_batch in test_loader:\n",
" eval_step(model, test_batch, eval_metrics)\n",
"\n",
" for metric, value in eval_metrics.compute().items():\n",
" eval_metrics_history[f'test_{metric}'].append(value)\n",
"\n",
" print(f\"[test] epoch: {epoch + 1}/{num_epochs}\")\n",
" print(f\"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}\")\n",
" print(f\"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "21d76e88-a037-431a-80aa-fc43f79768c7",
"metadata": {},
"source": [
"Now, we can start the training."
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "7e37b8b4-9e11-4f10-874c-da66723b5ef3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 0/10, [192/195], loss=0.697 [00:05<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 1/10\n",
"- total loss: 0.6923\n",
"- Accuracy: 0.5106\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 1/10, [192/195], loss=0.691 [00:03<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 2/10\n",
"- total loss: 0.6922\n",
"- Accuracy: 0.5422\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 2/10, [192/195], loss=0.678 [00:03<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 3/10\n",
"- total loss: 0.6754\n",
"- Accuracy: 0.6263\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 3/10, [192/195], loss=0.339 [00:03<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 4/10\n",
"- total loss: 0.4050\n",
"- Accuracy: 0.8267\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 4/10, [192/195], loss=0.215 [00:03<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 5/10\n",
"- total loss: 0.3307\n",
"- Accuracy: 0.8664\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 5/10, [192/195], loss=0.167 [00:03<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 6/10\n",
"- total loss: 0.3100\n",
"- Accuracy: 0.8764\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 6/10, [192/195], loss=0.112 [00:03<00:00] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 7/10\n",
"- total loss: 0.3434\n",
"- Accuracy: 0.8692\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 7/10, [192/195], loss=0.0814 [00:03<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 8/10\n",
"- total loss: 0.3653\n",
"- Accuracy: 0.8760\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 8/10, [192/195], loss=0.0982 [00:03<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 9/10\n",
"- total loss: 0.4136\n",
"- Accuracy: 0.8664\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[train] epoch: 9/10, [192/195], loss=0.0731 [00:03<00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[test] epoch: 10/10\n",
"- total loss: 0.4443\n",
"- Accuracy: 0.8664\n",
"CPU times: user 25.8 s, sys: 3.42 s, total: 29.3 s\n",
"Wall time: 1min 17s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"for epoch in range(num_epochs):\n",
" train_one_epoch(epoch)\n",
" evaluate_model(epoch)"
]
},
{
"cell_type": "markdown",
"id": "5f18cd48-fbc2-4ba2-80ba-3d124579844c",
"metadata": {},
"source": [
"Let's visualize the collected metrics:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "2e9bce0d-406f-47dc-9963-ce09f93c6290",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjEAAAGdCAYAAADjWSL8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABlWElEQVR4nO3dd1zU9eMH8NdxbFkqyFAU90TciCNLSTQzbZr5zZFZllZmw8yVlWKl5tdylLm+3zStvmn9ykxDcZIDxS0qojiYIlvW3ef3B3Le5/aCu+Nezx48uvvcZ7w/HHIv3lMiCIIAIiIiIjvjZO0CEBEREZmCIYaIiIjsEkMMERER2SWGGCIiIrJLDDFERERklxhiiIiIyC4xxBAREZFdYoghIiIiu+Rs7QIYQi6X4/bt2/D29oZEIrF2cYiIiMgAgiCgsLAQISEhcHKyfL2JXYSY27dvIzQ01NrFICIiIhPcuHEDTZo0sfh57SLEeHt7A6j6Jvj4+Fi5NERERGSIgoIChIaGKj7HLc0uQkx1E5KPjw9DDBERkZ2pqa4g7NhLREREdokhhoiIiOwSQwwRERHZJbvoE0NEtk0QBFRWVkImk1m7KERUi6RSKZydna02/QlDDBGZpby8HOnp6SgpKbF2UYjICjw9PREcHAxXV9davzZDDBGZTC6XIzU1FVKpFCEhIXB1deWElEQOQhAElJeXIzs7G6mpqWjdunWNTGinC0MMEZmsvLwccrkcoaGh8PT0tHZxiKiWeXh4wMXFBdevX0d5eTnc3d1r9frs2EtEZqvtv76IyHZY898/f/MQERGRXWKIISKyc/Hx8ZBIJMjLy6v1az/88MOYNm2a2efZsGED/Pz8zD6PqSQSCbZv32616xvio48+QpcuXYw6JiwsDMuWLauR8tgChhgickjjx4/HyJEjrV0Mum/UqFG4dOlSjV/HlCBgKksHs3fffRdxcXFGHXPs2DG88sorFiuDrWHHXiIisqqKigp4eHjAw8PD2kWxivLycoOGJ3t5ecHLy8uocwcEBJhaLLvg0DUxlTI51h5MxfnbBdYuChHZmH379qFXr15wc3NDcHAwPvjgA1RWVipe//nnnxEeHg4PDw80bNgQ0dHRKC4uBlDVvNOrVy/Uq1cPfn5+6Nu3L65fv67xOn369MGMGTNE27Kzs+Hi4oL9+/cDAP773/+iR48e8Pb2RlBQEF544QVkZWVpLbum2oZly5YhLCxMtO27775D+/bt4e7ujnbt2mHlypU6vyfFxcUYO3YsvLy8EBwcjCVLlqjto6lZxs/PDxs2bAAAXLt2DRKJBFu3bsWAAQPg7u6OTZs2qdVaVN/Df//7X4SFhcHX1xfPP/88CgsLFfsUFhZizJgxqFevHoKDg/Hll1/qbN7asGED5s+fj1OnTkEikUAikSjKBQA5OTl48skn4enpidatW+O3334THX/27FkMHToUXl5eCAwMxIsvvoicnByN14qPj8eECROQn5+vuNZHH30EoKqJ55NPPsHYsWPh4+OjqCmZMWMG2rRpA09PT7Ro0QJz5sxBRUWF2vekWnVt4uLFixEcHIyGDRtiypQpomNUm5MkEgm+++47nff522+/oXXr1nB3d8cjjzyCjRs3Wq25Uh+HDjGbjqThk9/P47HlB6xdFKI6QxAElJRXWuVLEASL3MOtW7fw2GOPoWfPnjh16hRWrVqFtWvX4tNPPwUApKenY/To0XjppZdw4cIFxMfH46mnnlLMXDxy5EgMGDAAp0+fRkJCAl555RWt8+eMGTMGW7ZsEZV969atCAkJQf/+/QFU1VR88sknOHXqFLZv345r165h/PjxZt3jpk2bMHfuXCxYsAAXLlzAwoULMWfOHGzcuFHrMe+99x727duHX3/9Fbt27UJ8fDxOnDhh0vU/+OADvPXWW7hw4QJiYmI07pOSkoLt27fj999/x++//459+/Zh0aJFitenT5+OQ4cO4bfffs
Pu3btx4MABneUZNWoU3nnnHXTs2BHp6elIT0/HqFGjFK/Pnz8fzz33HE6fPo3HHnsMY8aMQW5uLgAgLy8PAwcORNeuXXH8+HHs3LkTmZmZeO655zReq0+fPli2bBl8fHwU13r33XcVry9evBgRERE4efIk5syZAwDw9vbGhg0bcP78efz73//GmjVr8OWXX+r8Pu7duxcpKSnYu3cvNm7ciA0bNoiCmSa67jM1NRXPPPMMRo4ciVOnTuHVV1/FrFmzdJ7Pmhy6OenMrXxrF4GozrlXIUOHuX9Z5drnP46Bp6v5v9ZWrlyJ0NBQfP3115BIJGjXrh1u376NGTNmYO7cuUhPT0dlZSWeeuopNGvWDAAQHh4OAMjNzUV+fj4ef/xxtGzZEgDQvn17rdd67rnnMG3aNBw8eFARWjZv3ozRo0crgs9LL72k2L9FixZYvnw5evbsiaKiIqObF6rNmzcPS5YswVNPPQUAaN68Oc6fP49vvvkG48aNU9u/qKgIa9euxffff49BgwYBADZu3IgmTZqYdP1p06Yprq2NXC7Hhg0b4O3tDQB48cUXERcXhwULFqCwsBAbN27E5s2bFeVZv349QkJCtJ7Pw8MDXl5ecHZ2RlBQkNrr48ePx+jRowEACxcuxPLly3H06FEMGTIEX3/9Nbp27YqFCxcq9l+3bh1CQ0Nx6dIltGnTRnQuV1dX+Pr6QiKRaLzWwIED8c4774i2zZ49W/E4LCwM7777LrZs2YL3339f6z3Vr18fX3/9NaRSKdq1a4dhw4YhLi4OkyZN0nqMrvv85ptv0LZtW3zxxRcAgLZt2+Ls2bNYsGCB1vNZk0PXxFjojzYiqmMuXLiAqKgoUe1J3759UVRUhJs3byIiIgKDBg1CeHg4nn32WaxZswZ3794FADRo0ADjx49HTEwMhg8fjn//+99IT0/Xeq2AgAAMHjwYmzZtAlD1l3BCQgLGjBmj2CcxMRHDhw9H06ZN4e3tjQEDBgAA0tLSTLq/4uJipKSkYOLEiYp+Fl5eXvj000+RkpKi8ZiUlBSUl5cjMjJSsa1BgwZo27atSWXo0aOH3n3CwsIUAQYAgoODFc1oV69eRUVFBXr16qV43dfX1+TyAEDnzp0Vj+vVqwcfHx/F9U6dOoW9e/eKvl/t2rUDAK3fM1003f/WrVvRt29fBAUFwcvLC7Nnz9b7Hnfs2BFSqVTxXPl7pI2u+0xOTkbPnj1F+yt/j22NQ9fECGCKIbI0Dxcpzn+suXmgNq5dG6RSKXbv3o3Dhw9j165d+OqrrzBr1iwcOXIEzZs3x/r16/Hmm29i586d2Lp1K2bPno3du3ejd+/eGs83ZswYvPnmm/jqq6+wefNmhIeHK2p2iouLERMTg5iYGGzatAkBAQFIS0tDTEwMysvLNZ7PyclJrWlNuZ9EUVERAGDNmjWiUFJ9b+aQSCQ6r12tXr16es/l4uKidm65XG5W+Uy9XlFREYYPH47PPvtM7bjg4GCjr6V6/9XBdf78+YiJiYGvry+2bNmisd+RoWW25DG2ijUxRGRREokEnq7OVvmy1LpN7du3R0JCgujD+NChQ/D29lY0n0gkEvTt2xfz58/HyZMn4erqim3btin279q1K2bOnInDhw+jU6dO2Lx5s9brjRgxAqWlpdi5cyc2b94sqoW5ePEi7ty5g0WLFqF///5o166d3r+0AwICkJGRISp/UlKS4nFgYCBCQkJw9epVtGrVSvTVvHlzjeds2bIlXFxccOTIEcW2u3fvqg2LDggIENU8Xb58uUYWB23RogVcXFxw7Ngxxbb8/Hy9w7RdXV1NWm29W7duOHfuHMLCwtS+Z9oCmTHXOnz4MJo1a4ZZs2ahR48eaN26tdbO4DWpbdu2OH78uGib8vfY1jh4iGGKIXJk+fn5SEpKEn3duHEDr7/+Om7cuIE33ngDFy9exK+//op58+Zh+vTpcHJywpEjR7Bw4UIcP34caWlp+OWXX5CdnY327d
sjNTUVM2fOREJCAq5fv45du3bh8uXLOvvF1KtXDyNHjsScOXNw4cIFRX8FAGjatClcXV3x1Vdf4erVq/jtt9/wySef6Lyvhx9+GNnZ2fj888+RkpKCFStW4M8//xTtM3/+fMTGxmL58uW4dOkSzpw5g/Xr12Pp0qUaz+nl5YWJEyfivffew549e3D27FmMHz9ebcr5gQMH4uuvv8bJkydx/PhxTJ48We0vf0vw9vbGuHHj8N5772Hv3r04d+4cJk6cCCcnJ51hNiwsDKmpqUhKSkJOTg7KysoMut6UKVOQm5uL0aNH49ixY0hJScFff/2FCRMmaA0qYWFhKCoqQlxcHHJycnSGudatWyMtLQ1btmxBSkoKli9fLgrFteXVV1/FxYsXMWPGDFy6dAk//vijoqOwLS7u6tghxtoFICKrio+PR9euXUVf8+fPR+PGjbFjxw4cPXoUERERmDx5MiZOnKjoeOnj44P9+/fjscceQ5s2bTB79mwsWbIEQ4cOhaenJy5evIinn34abdq0wSuvvIIpU6bg1Vdf1VmWMWPG4NSpU+jfvz+aNm2q2B4QEIANGzbgp59+QocOHbBo0SIsXrxY57nat2+PlStXYsWKFYiIiMDRo0dFI2MA4OWXX8Z3332H9evXIzw8HAMGDMCGDRu01sQAwBdffIH+/ftj+PDhiI6ORr9+/dC9e3fRPkuWLEFoaCj69++PF154Ae+++26NLQ66dOlSREVF4fHHH0d0dDT69u2rGDKuzdNPP40hQ4bgkUceQUBAAH744QeDrhUSEoJDhw5BJpNh8ODBCA8Px7Rp0+Dn56d17aA+ffpg8uTJGDVqFAICAvD5559rPf8TTzyBt99+G1OnTkWXLl1w+PBhxail2tS8eXP8/PPP+OWXX9C5c2esWrVKMTrJzc2t1sujj0Swg+qIgoIC+Pr6Ij8/Hz4+PhY77xs/nMT/nboNALi2aJjFzkvkKEpLS5GamormzZvX+uq1RKqKi4vRuHFjLFmyBBMnTrR2ceqMBQsWYPXq1bhx44bG13X9Hqipz+9qjt2x1/bzGxERaXHy5ElcvHgRvXr1Qn5+Pj7++GMAVX2MyHQrV65Ez5490bBhQxw6dAhffPEFpk6dau1iaeTYIcbaBSAiIrMsXrwYycnJcHV1Rffu3XHgwAH4+/tbu1h27fLly/j000+Rm5uLpk2b4p133sHMmTOtXSyNHDvEsCaGiMhude3aFYmJidYuRp3z5Zdf6p0p2FY4dsdeZhgiIiK7xRBDREREdsmhQ4ycKYbIItg0S+S4rPnv36FDDBGZp3oSs5qYkZWI7EP1v/+amNRQH6M79u7fvx9ffPEFEhMTkZ6ejm3btmHkyJE6j4mPj8f06dNx7tw5hIaGYvbs2WYvI28J/NuRyDxSqRR+fn6KafA9PT1tclZPIrI8QRBQUlKCrKws+Pn5mb3ulimMDjHFxcWIiIjASy+9pHcZdaBqRdZhw4Zh8uTJ2LRpE+Li4vDyyy8jODgYMTHWWSSuGqvAicwXFBQEAHrX8yGiusnPz0/xe6C2GR1ihg4diqFDhxq8/+rVq9G8eXPFSpzt27fHwYMH8eWXX9pAiLHq5YnqBIlEguDgYDRq1EjjasVEVHe5uLhYpQamWo3PE5OQkIDo6GjRtpiYGEybNq2mL60XMwyR5UilUqv+MiMix1PjISYjIwOBgYGibYGBgSgoKMC9e/fg4eGhdkxZWZloZdGCgoIaKRtHJxEREdkvmxydFBsbC19fX8VXaGhojVyHGYaIiMh+1XiICQoKQmZmpmhbZmYmfHx8NNbCAMDMmTORn5+v+NK2cqa5mGGIiIjsV403J0VFRWHHjh2ibbt370ZUVJTWY9zc3ODm5lbTRePoJCIiIjtmdE1MUVERkpKSkJSUBKBqCHVSUhLS0tIAVNWijB07VrH/5MmTcfXqVbz//vu4ePEiVq5ciR9//BFvv/22Ze7ADMwwRERE9svoEHP8+HF07d
oVXbt2BQBMnz4dXbt2xdy5cwEA6enpikADAM2bN8cff/yB3bt3IyIiAkuWLMF3331n9eHVACCwQYmIiMhuGd2c9PDDD+tshtmwYYPGY06ePGnspWqcXG7tEhAREZGpbHJ0Um3hEGsiIiL7xRBDREREdsmhQ4xMzhBDRERkrxw6xDDDEBER2S8HDzFMMURERPbKoUMMm5OIiIjsF0MMERER2SWHDjFsTiIiIrJfDh1iWBNDRERkvxw6xDDDEBER2S+HDjGsiSEiIrJfDh1i2CeGiIjIfjl2iGFNDBERkd1y6BAjY00MERGR3XLsECO3dgmIiIjIVA4dYtgnhoiIyH45dIjp3aKBtYtAREREJnLoEPPBkPYAAE9XqZVLQkRERMZy6BAjkVT9n61KRERE9sehQ0w1AUwxRERE9sahQwxrYoiIiOyXg4eYqhTDDENERGR/HDvEVD9giiEiIrI7jh1iqpuTmGKIiIjsjmOHmPt1MewTQ0REZH8cO8QoamKIiIjI3jh2iLn/f4FVMURERHbHoUMMWBNDRERktxw6xLBPDBERkf1y7BAj0b8PERER2SbHDjFKj9kvhoiIyL44dohRqophhiEiIrIvjh1ilB7LmWKIiIjsimOHGKUUwwhDRERkXxw7xIDNSURERPbKoUMMRDUxTDFERET2xKFDjKg5iRmGiIjIrjh2iLF2AYiIiMhkjh1iOMSaiIjIbjl2iFF6zD4xRERE9sWxQwz7xBAREdktxw4xykOsrVgOIiIiMp5jhxhRTQxjDBERkT1x6BCjjBGGiIjIvjh0iGGfGCIiIvvl2CEGXDyJiIjIXjl0iHHisgNERER2y6FDDCe7IyIisl+OHWKUHjPDEBER2RfHDjEcYk1ERGS3HDzEcLI7IiIie+XQIUYZK2KIiIjsi8OHmOrKGI5OIiIisi8MMdUPmGGIiIjsikkhZsWKFQgLC4O7uzsiIyNx9OhRnfsvW7YMbdu2hYeHB0JDQ/H222+jtLTUpAJbWnW/GGYYIiIi+2J0iNm6dSumT5+OefPm4cSJE4iIiEBMTAyysrI07r9582Z88MEHmDdvHi5cuIC1a9di69at+PDDD80uvCVU18TI2SmGiIjIrhgdYpYuXYpJkyZhwoQJ6NChA1avXg1PT0+sW7dO4/6HDx9G37598cILLyAsLAyDBw/G6NGj9dbe1JZKeVV4Sc0utnJJiIiIyBhGhZjy8nIkJiYiOjr6wQmcnBAdHY2EhASNx/Tp0weJiYmK0HL16lXs2LEDjz32mNbrlJWVoaCgQPRV017aeKzGr0FERESW42zMzjk5OZDJZAgMDBRtDwwMxMWLFzUe88ILLyAnJwf9+vWDIAiorKzE5MmTdTYnxcbGYv78+cYUzWylFfJavR4RERGZp8ZHJ8XHx2PhwoVYuXIlTpw4gV9++QV//PEHPvnkE63HzJw5E/n5+YqvGzdu1HQxiYiIyM4YVRPj7+8PqVSKzMxM0fbMzEwEBQVpPGbOnDl48cUX8fLLLwMAwsPDUVxcjFdeeQWzZs2Ck5N6jnJzc4Obm5sxRSMiIiIHY1RNjKurK7p37464uDjFNrlcjri4OERFRWk8pqSkRC2oSKVSALa3XtGPx2+gQsZmJSIiIntgVE0MAEyfPh3jxo1Djx490KtXLyxbtgzFxcWYMGECAGDs2LFo3LgxYmNjAQDDhw/H0qVL0bVrV0RGRuLKlSuYM2cOhg8frggztuL9n08jt7gckwe0tHZRiIiISA+jQ8yoUaOQnZ2NuXPnIiMjA126dMHOnTsVnX3T0tJENS+zZ8+GRCLB7NmzcevWLQQEBGD48OFYsGCB5e7Cgg5dyWGIISIisgMSwdbadDQoKCiAr68v8vPz4ePjY9Fzh33wh+j5Q20C8J+Xeln0GkRERI6oJj+/Aa6dRERERHaKIUaFRP8uREREZAMYYlRImGKIiIjsAkOMCmYYIiIi+8AQo2JvcjaW7kq2djGIiIhID4YYDZ
bvuQKZ3OYHbRERETk0hhgt5LY/8pyIiMihMcRowRBDRERk2xhitJBzCSUiIiKbxhCjhYw1MURERDaNIUYLNicRERHZNoYYLeQcnURERGTTGGK0UM4wO8+mI/F6rvUKQ0RERGqcrV0AW1U9T8yVrEJM/v4EAODaomHWLBIREREpYU2MFsL9PjHXckqsXBIiIiLShCFGi+rRSZXsG0NERGSTGGK0qM4ulZwwhoiIyCYxxGgxbt1RVMjkojWUBA67tnuCIOBW3j2+l0REdQBDjBZXsorQetafOJmWp9jGliX7t2pfCvou2oOv91yxdlGIiMhMDDF6bDh8TfGYTUv27/OdyQCAJbsvWbkkRERkLoaY+1oE1NO7j0wuoKS8El/FXcalzMJaKJV9yC4sQ0Z+qbWLQUREDoYh5r7I5g307nPmZj6Wx13Bkt2XMPjL/QCA/Zey8dTKQ7icWYiM/FKH+zCXyQX0XPA3esfGobRCZu3iEBGRA+Fkd/c5O+nPc9O2JiG0gafi+Ytrj+DA5RwAwNTNJ5F8v3bm8oKhcJFaJh+WVsiQnFGI8Ma+cHKSWOScuhSUVsDH3cXg/a/fKVY8joqNw+xhHdA+2AcdQnxqonh6lZRXwtOVP9ZERI6ANTH3GRI60vNLcTT1wfID1QEGgCLAAEDrWX8i7IM/kHg9F7+duo3MAvXamcyCUqTdKcG9chkEQUCFTI4vd1/C8Wvi5Q2mbUnCiBWHsPlomt7yXckqwie/n0dWoWm1QcvjLqPzR7uw+3ymwcfcvHtP8fhuSQXe+ekUHlt+wKTrmyt2xwV0mPuX2veQiIjqJoaY+1yklq/leHpVAt784SQiF8YBAMor5ZDLBczadgaRC+Pw0Bd70X7uTrz/82m8sfkk/h13Gc+sTsDJtLsAqppqdp7LAACs3pcCAPj+n+sI++APJKTcAQBkFZbiSlYhrmYXIXrpPqw9mIrPdyajvNL4TshL73d2/XDbGZ37CYKAPRczce52PvLvVWjc5+DlHMT+eQE5RWVaz1NUVonDV3IMWmzz7K189F20B/9LvKnYll1YhqvZRfjx2A3I5AK+2X8VAPDZzouiY++Vy3Cv3LymLkEQ8PnOi9h5Nt2s8xARkeVIBDuYMKOgoAC+vr7Iz8+Hj49lmynCPvgDALD6X90UayTZAl8PF4T4eeBCeoFim7e7MwpLKxXPPxnZCXO2n9V6jon9mkMuCNh1LhMrx3TDn2cz0CKgHo6m5uLn+2Hgr2kPIbe4HMv+voQjSrVM85/oiM5NfLH5SBp+ur/v4Q8G4pcTN7F4l+Eje5rU98CB9x9Bwb1KSJwACYCtx24gu7BMETrei2mLf0U2w/7L2RjUvhHcnKU4fTMP+y5lY8ojreAidVK8T0DVGlYHLmfjxbVHNV6zV1gD/Dg5CkBVc1y/z/bAy80Z0e0D8d3BVMV+mpr95HIBMkFQ2777fCYm/ee44vpERKRfTX5+AwwxSM4oxLnb+Xiya2N0mvcXis38i51sw/g+YTh2LRcRoX7YfERzU5yfpwtaBXjhxahmeCIiBNfulOCplYfQ0MsNm1+OxLK4y3ihV1N0auyL/yZcw5xfzwFgiCEiMhRDDGr+m1DtXrkM7efurLHzk+2KatEQCVfvaHzt2qJh2Hj4Gub9xhBDRGSMmv78Zp8YJR6uUoP39fdyxdioZhjUrhGaKo1YMtfqf3XDGwNbaXytbaC3xa5TUzoEW2dUkrm0BRgAeGRxPDYduV6LpSEiIkNwLKoKF6kEFbKqyqlZj7XHgh0X1PZ5IiIEy0d3FW1T7rOhyeQBLTFjSFs0n7lD535DOgWjX+sAJN3IQ7OGnvj+nzS4OTth8oCWmBbdGkVllXCROuG3U7fx/T/XMW94BzT398K2k7dwI7cE/1y9g4sZhWgX5I1ne4SiaQNPRLVsiDtFZZBAgh+P30DP5g0woE0ASitk+HL3JUXfFGWerlKsH98To779R9+3TGTLq73h4+6CtQdT8c
nv5wEAP74aBW93Z7QM8MLpm3l4ZnUCAKCFfz1MHtAS7//vtMZz/TQ5ChUyOV5Yc8SoMlhaak6x/p2IiKjWsTlJxZWsQkQvrZrI7tqiYaJw8v3ESOTfq8Aj7QLU5iJZsisZX+25gnquUkW/mon9mmPt/Y6k7w5ug6kDW+PMzXzE/nkBh1PuoLGfBxY9HY6uTevjnR+T0L91AP7Vu1mN3p8m+SUViPh4FwBgyyu90btFQ8Vr+y5lY9y6o3h1QAu8MbA1TqbdRWFpJd7emoQyDSOgLn4yBO4uVTVae5OzcCmjEK881AISyYPRXxUyOc7dLkCnEB84S52QeP0utp+8hae7N8Ezqw6jUi5g/hMdMa5PGAAgI78Ufp4uOJ9egK1Hb2DW4+3h4+6ieG+6N6uPz57ujOil+2rqWySybnwPDGwXiE1HriOroAxvP9qmVq5LRGRv2CcGtRtiAODPM+kI8HZDj7AGig/KT0Z0xItRYVqPkckFnLudj/bBPoi7kIXUnGK89nBLxfGbX45En1b+iv3zSsrh5eYMZwtNimeu0goZMgtK0ayh/uUXAKBSJkd8cjZe/s9xjI1qhtIKGTxcpJg/opNZ5cgvqUBppQyBPu569/1mXwqWx13G1lej0KmxL07fzMNbW5IwLbo1buSW6BxFtXNaf3i5OWNVfAr6tvKHm7MTJm48bnA5z3w0GOEfVQW/5aO74omIEIOPJSJyFAwxqP0Qo+yn4zdw4HIOFj8bAVdn4wPH+dsFuJxViBFdGtdA6awvq7AUAV5uopqW2lQpk+sNgtVB0svNGXMf74Di8kpM6Ntcbb/DKTkGN12pDslnZ18iInUMMbBuiCH71yc2DrfzSzGoXSOsHd9T576CIOC3U7ex+3wmSitk+PtClkHXYIghIlLH0UlEZvrhld6YPKAlYp8O17uvRCLBiC6N8fUL3dC5iV/NF46IiEzG0UlU5zVrWA8fDG1n9HGVMsOXbqiQyS226CcRERmGv3WJtHi6exOD99W2hhQREdUchhgiLVRHai1+NkLrvgwxRES1jyGGSId692dx/vqFrnhGR83MPzpm/CUioprBPjFEOux972FcyihC31YNde43a9tZjIms/YkKiYgcGWtiiHRo5O2Ofq39DZoHp09sHK7f4RIFRES1hSGGyEJu55diwBfxKCit6h/z+c6LePenU7CDqZiIiOwSQwyRhd3MvQcAWBmfgp8Tb+JSZpGVS0REVDcxxBAZYee0/nr3cXWWiGpfKoyYb4aIiAzHEENkhHZBPvC4v0q3NuWVAuRsQSIiqnEMMURGWju+h87Xy2VyyCyYYtinhohIM4YYIiP1aemv8/Vd5zIw/KuDiufmLPCdkl2Engv+xpr9V00/CRFRHcUQQ2RhK+NTkJxZqHgugekp5uP/O4+conIs2HHBEkUjIqpTGGKIbBgbkoiItGOIIaph5jQnmXEoEVGdxxBDZIJFT4Uj2Ncd/36+i959ncxJMUREpBXXTiIywfO9muL5Xk1xSanvS01g/iEi0o41MURmMCRj5JWUAwAuZRbi1I28Gi0PEZEjMSnErFixAmFhYXB3d0dkZCSOHj2qc/+8vDxMmTIFwcHBcHNzQ5s2bbBjxw6TCkxkSwyZDubV7xMBAIO/3I8RKw4ht7jc4POzIoaISDujm5O2bt2K6dOnY/Xq1YiMjMSyZcsQExOD5ORkNGrUSG3/8vJyPProo2jUqBF+/vlnNG7cGNevX4efn58lyk9kVXIDJqLLK6mAXCntZBaUokE9V4POb8jq2UREjsroELN06VJMmjQJEyZMAACsXr0af/zxB9atW4cPPvhAbf9169YhNzcXhw8fhouLCwAgLCzMvFIT2QhDJ9OVKe1oTC5hhCEi0s6o5qTy8nIkJiYiOjr6wQmcnBAdHY2EhASNx/z222+IiorClClTEBgYiE6dOmHhwoWQyWRar1NWVoaCggLRF5Et8nY37O8A5WUIjBmtxIoYIiLtjAoxOT
k5kMlkCAwMFG0PDAxERkaGxmOuXr2Kn3/+GTKZDDt27MCcOXOwZMkSfPrpp1qvExsbC19fX8VXaGioMcUkqjWhDTwx9/EOevdTDjHMJUREllHjo5PkcjkaNWqEb7/9Ft27d8eoUaMwa9YsrF69WusxM2fORH5+vuLrxo0bNV1MIpO91K+53n0qlUOMUdUrjDxERNoY1SfG398fUqkUmZmZou2ZmZkICgrSeExwcDBcXFwglUoV29q3b4+MjAyUl5fD1VW9g6Obmxvc3NyMKRqRTVu9L0Xx2Im5hIjIIoyqiXF1dUX37t0RFxen2CaXyxEXF4eoqCiNx/Tt2xdXrlyBXC5XbLt06RKCg4M1BhiiumhVfIr+nTRgnxgiIu2Mbk6aPn061qxZg40bN+LChQt47bXXUFxcrBitNHbsWMycOVOx/2uvvYbc3Fy89dZbuHTpEv744w8sXLgQU6ZMsdxdENkRQ+aWqcYMQ0SkndFDrEeNGoXs7GzMnTsXGRkZ6NKlC3bu3Kno7JuWlgYnpwfZKDQ0FH/99RfefvttdO7cGY0bN8Zbb72FGTNmWO4uiKzsi2c6Y+YvZ0R9X7QRDB2XDdbEEBHpIhGM+Y1qJQUFBfD19UV+fj58fHysXRwijdLz7yEqdo/e/f6a9hDaBnkbdM7J/03EznNVI/+uLRpmVvmIiGpbTX9+c+0kIguRGlhtsu9SllFLDxARkWYMMUQWIjVw2NHCHRcxZNl+g/ZlcxIRkXYMMUQWYsz8L1mFZQae88Hjc7fzjS0SEVGdxhBDZCE10b1MojQ+adjygxY/PxGRPWOIIbIQD1ep/p2IiMhiGGKILMTT1Rmelg4y7BNDRKQVQwyRBT3StpG1i0BE5DAYYogsSIDh/WKyDejcy4oYIiLtGGKIrKTngr+xNzlL5z7GrXhNRORYGGKIrGjl3is6X2eEISLSjiGGyIaxIoaISDuGGKIacvGTIXr3OXbtLp5edRiFpRW1UCIiorqFIYbIgnzcXRSP3V0MG26deP0u1h+6pvE1VsQQEWnnbO0CENUl78a0RUp2EZ7v2dSo40rKZTVUIiKiuoshhsiC/L3c8NPkPkYfp21oNkcnERFpx+YkIhugbdml0zfzVPaz/PpMRET2iiGGyAbI5ZrDSUp2seg5MwwR0QMMMUQ24LuDqSgqq9S7n5wphohIgSGGqAatGtPN4H23n7yld5/beaUYu+4o/jqXYU6xiIjqBHbsJapBQ8ODLXq+h77YCwDYfykb1xYNs+i5iYjsDWtiiGyEpRuKBEHAuz+dwoI/zlv4zEREtoEhhshGzNl+1qLnu5pTjJ8Tb2LNgVSLnpeIyFYwxBDVUeWVcsVjDs0morqIIYaojlIeycQMQ0R1EUMMUR2lHFyYYYioLmKIIaphwyNCjNpfJhcs3vzD5iQiqosYYohq2PLnuxi8b3mlHI9+uQ/j1x8z+7qsiSGiuo7zxBDVMGMWcTxzKw9Xs4txNbvY7NoT5UUlOdMvEdVFrIkhsiEu0gf/JEsr5Dr2fKBSpnk/UU0MMwwR1UEMMUQ2xNnpwT/J4nL9ayldSC9Ah3l/4bF/H8Dd4nLRa8wtRFTXMcQQ2ZBK+YNalWI9C0JmF5Zh8V/JKK+U43x6AZ7/9h/R6wKHWBNRHccQQ1QLBrQJAABMi26tc7/pP55SPNa3qvWsbWfg7f6gW1tyZqHodeXcwj4xRFQXsWMvUS34blwPpOWWwEkiwbK/L2vd70pWkeKx8oy7mqRkF6FfK3+tr3N0EhHVdayJIaoFLlIntAzwguHjlIBKue7oIZFI4OvpqvV1cXMSYwwR1T0MMUS1yJgo8e5Pp/Tu4+fhYvHrEhHZC4YYolqkbTi0JtfvlOjdx81F+z9h5eAiGH5ZIiK7wRBDVIvK9PRzMYYEgK4WJ7nSiwLrYoioDmKIIapFypPZmUsigc6x06KaGGYYIqqDGGKIalGbQC+Mi2qGtwbpHmptKD19fxWYYYioLmKIIa
pFEokE80d0wpRHWpl/Lkh0jjpSnhuG88QQUV3EEENkBVInYwZba6crmnDtJCKq6xhiiKzAEhlGItHTsVdgx14iqtsYYoisQCKRoE/LhmafR1dzkkzOKXuJqG5jiCGykk0vR2LWY+1r7PzK+cbQDsBERPaEIYbISiQSiVl9YyQSic4Ou7banFQhk2P+/51D3IVMaxeFiOwcQwyRFbk4m/dPsKRcJnqu3IQkt9GOvT8ev4H1h65h4sbj1i4KEdk5hhgiK/J0kZp8bIVMrrYi9i8nbioey0Qz9tqO9LxSaxeBiOoIhhgiK/J0NT3EpOWqr6303s+nFY9tdRVrC40uJyJiiCGyJg8zQoy+6hVbbU6qWi+BiMh8ztYuAJEja+znYfKx2jrrXs0uQkZBKWSimhiTL2NxjDBEZCmsiSGyotaB3vh4REeTjq2QaU4mA5fswwtrjiA5o0Cx7bdTt1BuwRW0zeHEmhgishCGGCIrGxsVViPnvZheqHi8eNclrNh7pUauYyxmGCKyFIYYIgex+7xtzMvCjr1EZCkMMURUqySsiiEiCzEpxKxYsQJhYWFwd3dHZGQkjh49atBxW7ZsgUQiwciRI025LBEZQTUr2Ep2sJVyEJH9MzrEbN26FdOnT8e8efNw4sQJREREICYmBllZWTqPu3btGt59913079/f5MISkenO3S7AV3GXsXRXMuRyAXkl5Zj0n+PYeTajVssh4fgkIrIQo0PM0qVLMWnSJEyYMAEdOnTA6tWr4enpiXXr1mk9RiaTYcyYMZg/fz5atGhhVoGJyHRLdl/C8j1XsONsOpbuvoTd5zMx+ftEtf0++u0c3vjhZI1MkseaGCKyFKNCTHl5ORITExEdHf3gBE5OiI6ORkJCgtbjPv74YzRq1AgTJ0406DplZWUoKCgQfRGR5dzOu4ecojKtr284fA3/d+o2LmcVWfza7NhLRJZiVIjJycmBTCZDYGCgaHtgYCAyMjRXSR88eBBr167FmjVrDL5ObGwsfH19FV+hoaHGFJPIbg3uEIivRne1yLmU105SpauCRbn2pUJm+bll2JxERJZSo6OTCgsL8eKLL2LNmjXw9/c3+LiZM2ciPz9f8XXjxo0aLCWR9a3+V3dEtWiIj0d0wvCIELi7mP9P80q2abUoNT27r6Wbk3afz8SUzSdQUFph2RMTkc0zatkBf39/SKVSZGaK55vIzMxEUFCQ2v4pKSm4du0ahg8frtgml1f9Zefs7Izk5GS0bNlS7Tg3Nze4ubkZUzQiuzakUxCGdFL/N2SOG7n3tL6mK6fIbXS5Am0m/ec4ACDIxx39Wvvjdt49jIlsZuVSEVFtMOrPPVdXV3Tv3h1xcXGKbXK5HHFxcYiKilLbv127djhz5gySkpIUX0888QQeeeQRJCUlsZmIyEak5z8IPLpyy4X0AvReGIetx9JMvlZNLTuQWVCKCeuPYda2szh3O79GrkFEtsXoBSCnT5+OcePGoUePHujVqxeWLVuG4uJiTJgwAQAwduxYNG7cGLGxsXB3d0enTp1Ex/v5+QGA2nYiqj2qNSw7zmRgYr/mAMQ1Mare/ekUMgpKMeN/ZzCqZ1OTrl0bHXuzCsrQMaTmr0NE1mV0iBk1ahSys7Mxd+5cZGRkoEuXLti5c6eis29aWhqcnDgRMJE9EbQ0IanmmTILLCJZGzP2chg3kWMwOsQAwNSpUzF16lSNr8XHx+s8dsOGDaZcksih+Hq4oLRC+xBocwkqjUbKtS+iPjEQj1S6YoEh1zVVE6McjrhSNpFjYJUJkQ1aO66n2raX7zf3WMLGw9dEz5VrXLSNzD51I88yF6+FgMEQQ+QYGGKIbFCnxr5YN76H4vnS5yLw3pC2Fjt/ZkEZTt140PlVLmpCqtnRScrxoiZmBAY4oR6RozCpOYmIap6fp6vi8eOdQ+Dq/OBvDonE/ICRVViqeCxuTlLfd8/FTNwpKjfvgvcp15IIguUqZkSBiCGGyCEwxBDZqN
D6norHqjULEugeCm2ICplyjYvmxwBw7FouXtpwXOe5zt7KR2mFDD3CGqCwtALJGYXo3qy+xk68ypvkggAnCyUO5WKzOYnIMTDEENmoAG83fPNid7hIJXCWilt+JZaoilFSXC7DfxOuoU2gN0Z9+49iu4CquWE0Sc4oRNsgb8jlAh7/6iAA4NTcwXhm9WFczirC5890xnM91OeCUo4XOlZGMJryMgsMMUSOgSGGyIbFdNQ8i2/XUD8cv37XYtdZFZ+icbtcEODlpvnXxJztZ/Hj5ChUKoWH3JJyxaKR7/98Gs92b6JWG6NaE2Mpyudinxgix8COvUR2ZN97D2PN2B7o19rwtcjMIQjA3F/P6dxHV3hIyy1R21+i0ifGUpTLkVNUjiNX71ju5ERkkxhiiOxIs4b18GiHwFpbCVouCCgqq9T42s27JdibnCWqiVFtxtEUUpT3sWxNzIPHk79PxKhv/8GhKzkWOz8R2R6GGCI7pDpZXU2R6+i0cju/aq2i+OQsxTYnJ+1NR4ptyue3YIiRaSjrgcsMMUR1GUMMEWllSMfbf5SabVSbk8Z8dwR/nkkXbVNelaSmOvYSkWNgiCGyQ4ZWYLzyUAuzrmNITYmu7HDz7j0s3pWM/JIKnEi7C0EQRE1hlpzszpK1OkRkHxhiiOqwEV3MW8rZkGAg6JkoLyW7GAOXxOOplYcRn5ytMjqp6v9ZhaXILChVP9jCZSWiuoUhhqgOk5o51tiQJhq50sLW2mpW7hRXzfa763yGaPvwrw7ilxM30WtBHCIXxqG0QmZyWdmaROR4GGKI6jBzJ30zpHJDZsRaS6qv38q7h+k/nlI8zyupMKZ4Ipo6IddWB2hzbDt5E1M2nzArwBE5KoYYojrM3BBjWJ8Y40KMrn3MKa5My4nLK+VYsisZx67lmn7yGvT21lP443Q6vv/nurWLQmR3GGKI7JCh9QvaZts1lCHNScrZQV/oEe7/p405kUvbcPD/JFzDV3uu4NnVCUafs6ZW2dbkbollFtgkciQMMUR1SL9W/vD3clM89/dy1bG3foZ8sCoHHX0f+foywYHLOfhvwjX9BdNUDi0nT8kuMul8m4+kocenf+PsrXyTjjdWbU1gSFSXMMQQ1SEuUgmclTrzqi4caawZ/zujdx+5aHSScTU3qt756RTm/HoOideNb/pR7mBs6PV0+XDbGdwpLsf0H5NMOwER1TiGGCI7t/WV3nipb3MAwPRH25rVr8QURvWJMWAfoGp+GXPKYUmcRI/IdjHEENmhYeHBAIAW/vUQ2aIh5g7vgIufDEF4E1+1fUMbeNRoWXaceTBsWl8fEkNzhil5RFuIMTfbCACKyiprtX8MERmGIYbIDrUN8kbCzIH4c1p/xTZ3F6nGfbU1s9QEfR/zO86ko7RS/1BiU2pVaqrG5Gp2MTrN+0s0FJyIbANDDJGdCvb1gJuz5uBiLfrCx70KGeb/33kDzmP8tTWGGMFyc8VsO3nLIufRprabAYnqAoYYojquNptBDLlUeaX+qiFTamLYdYXI8TDEENUxqn/QvzO4ba1d22Kday3YJ4aI6i6GGKI67unuTeDmXDv/1E9cv2uR85gSSLQdYkq2MaS2yNLYmkRkPIYYIgcQ7OteK9eZ8+s5i5xHX9PQ13su4+s9l0XbLNWx98NtZ9Bm9p8WORcR1Szz5iQnIrtwz84WF9TVGTf/XgUW77oEABjbJ0yxXdOMvQKMb5nafCTNyCOIyFpYE0NUx0g0DHMpKbevEKOrUkW5xiWroOzBMfbes5fDk4iMxhBD5ADuqYSYfz/fxToFMZCuEVVKqypgy9EHtSba1k4ytE/MuoOpePOHk4btTEQ2gSGGqI5pG+Sttm1adGvR8xFdGtdWcUyirVZFEARsPHxd8bxSaT9DamKUQ4+qj38/j99O3TailERkbQwxRHXMZ093xuheTfH7G/0U215/uFWNLz9gSeduF2jcvjc5C1/+fUnxXLlpSeNcd4Ig6l8ze/
tZyxXSwtiYRGQ8duwlqmMCvN0Q+1S4aJuTkwQdg31xI/fBwoquzk5WGUpsiJ8Sb0LqJMGipzuLtl/LKRE9r1RaU0Hr6CQ76SpjJ8UksimsiSFyEE4q/9p/m9rXOgUx0JZjN3C3uFy0zUmluqJC9uCjX1ufGGUMCkR1C0MMkYOY2K85AODRDoEAgHZBPohs3sCaRdIr9U6x6LmTSoqplD2oidFWq2TJ4JJfUmHBs4mxOYnIeAwxRA6ie7MGSJwdjW/+1d3aRTHYxA3HRM9Vh49X6unMa+mVCCI+3oWcojLRtvJKOVbvS8F5Lf14NDl3Ox8fbjuDrMJSyxZQD0EQUGpncwYR6cIQQ+RAGnq5iWozNH3GN6jnWnsF0uOuSs2HanNSpUx3SpFIgLJKzR/aucXlmPfrWZy9lW9UmQ6n3BE9X3swFYv+vIjHlh8w+BzDlh/E5iNpeO+n00Zd21xvb01Cuzk7cS2nWP/ORHaAIYaIFHw9XJA4O9raxdDKSaUmpkKmu2Ny3MUs7DiToXgukws4mXYXgiBgzq9nsTHhOh7/6qBRZZDeL8P/Em8idscFnLmVZ9TxyvZdylY8NnSuu9M38/D9P9dNWp18e1LVEPINh68ZfSyRLeLoJCJHpuFzUNOMv7ZCtWQn0nQvOHk1W73G4cmVh7H42QhcTDe8+UdZdW3QOz+dAgDU93Qx6TymeuLrQwAAfy9XDOkUXKvXJrI1rIkhcmC61iiyJSv2XsGjS/ch7564eUm1uclQv5yoGsJtCtWQV1haadJ5zHUps8jkY204pxIZhTUxRKRgShNFTeu9MA4ZBVUdYL87cNUi51Tt12IMJ4m4GctF6oRKufmdZSVGjk8yJ4cYey0iW8WaGCIHZoOZRU11gAGAnKJyHXvWjgqZIFpQ09XZOr9GWZtCxBBD5NBUM4w9hBprm7L5BErKrdOEZCkMQFRXMMQQkUKZntE+VEV5VfBKC33PjA0Wmjpg5xSVYcSKQ9h8pGqhy73JWegTG4fDV3LEx5pcSiLbwhBD5MBU+8BUz3rbIdjHGsWxG89/+4/icYWeuWpq07K/L+HUjTx8uO0MAGDC+mO4nV+KF747YuWSaXe3uByxOy7gcmahtYtCdoghhsiBafv4/W5cD0SE+imej+oRWivlsRZjOzRnFT6YtbdCbju1VyVlhnUwtqXmpA+3ncE3+6/i0S/3W7soZIcYYohITYifB+Y+3l7x3MNVasXS1LxtJ2+ZfKwx+edGbgl+TTL9WnrZUDgx1Ombxs2YbEkyuYAjV++guMy++zg5MoYYIgdm6Aew6ky5dc30H09Z9Hy7z2diyLL9uKAyoV7/z/firS1JFrmGprfEnKHTcj3rUNVF6w+lYtS3/2D8+qPWLgqZiCGGyIGpfmxNi26tcb86nmEsbtJ/juNiRiHe3ppk8jnySsp1NnNpCiyGzt+n2il47cFUdJ6/y+h1pOzdlmM3AADHrume+ZlsF0MMkSNT+ZB8a5DmEEOmMWbFaOVYEXchE10+3o2Pfjtn1PVyiw2bR0c163zy+3kUlVVi1v0OwY7C2cRZmwHj3luqOQwxRKQg/gtdabVrx2tpsAh3F9P6En2+MxkAsDHhOgpKK7DjTLrah6Zq7diWo2mIu5hl0vWqOVqLkrPUtBCTnFGIdnN2YuYvtbsKOaljiCEivWpjjSVv97q3CoqpHaKdlGoIXt54HK9vOoH5/3detI/qx++s7WdFz5fuvqR4rLZOlJbPbrmDpVVnJ9M+Ar/eewUA8MPRG5YsDpmAIYbIgfVu2VDrax2CfeAkARr7eWitifH3crNYWcyp2rcVSTfyRM/dnQ0PMco1K8rfi6OpuQCA/yXeFPWRUa2JkapsWB53WfFYPcNo/l5boybGmut1uZhYE2P/P6l1B0MMkQN7O7oNGnlrDiIerlKcmz8E8e89rPH1ecM7YObQdornIb7uZpVFauJfxbZk5IpDoufKt6Q6UkkXbSts6woZur59ho4uMz
ZQlJRXYv+lbMUkifbG9JXMLVwQMplJvzVWrFiBsLAwuLu7IzIyEkePah+etmbNGvTv3x/169dH/fr1ER0drXN/Iqo97i5STOjbXOvrHq5SuEidNH64tQjwEn1whvnXM6ssyrUPXZQm2jPEv3o3NevaNUUQqoJBaYUM/9Iza265TMCvSbeQX1Kh9cNVpiPF6Aoqqi9p29WY5qTXvk9Eh7l/Yey6o/h850WDj1Mvm+5EUCmTY+Yvp/G/xJuY9+tZfKVUw2SOHWfS8c/VXJOOZYaxHUaHmK1bt2L69OmYN28eTpw4gYiICMTExCArS3OHsvj4eIwePRp79+5FQkICQkNDMXjwYNy6VYMTPhGRwQz5Y1T5szP4fo2LsUFDH+UP7k6NjVv24MmuTSxaFnNM3XxC8VgQquagaTdnJ+7oGTm0PO4y3tqShAU7zqs1DVVTDhmqTULajgGgVlOibU9Dm5MEQcCfZzMUzzfdX6vJFPpqf34/nY4fjt7AOz+dwsaE61ii1NfHHK9vOqF/Jy30BS+qPUaHmKVLl2LSpEmYMGECOnTogNWrV8PT0xPr1q3TuP+mTZvw+uuvo0uXLmjXrh2+++47yOVyxMXFmV14IjKfIU0NQzsFAQCCfNyx771HcOajwfD1cFHbr7Gfh+nlUPptZGzfjHputjOj8O+n0xWPE67eMXo24B+P39TaNKRaU3I1u0gxSZ2TjjRqTDgxhOr5XJ0fFLi0QobNR9KQkV9q2EX1MHTYODkmo4YDlJeXIzExETNnzlRsc3JyQnR0NBISEgw6R0lJCSoqKtCgQQOt+5SVlaGs7MHaJAUFhrclE5Fxggzoy9KnlT92vNkfoQ084OrsJPrQUrbt9T7YdykbXZv64UJ6Ib74KxlpuSUGlUN5pIixfT2bNTCvKcuWSCRaRs1IxM1J3x28igU7LuBfvZvi05HhJvfvUK6lMTXsuEgflPfznclYdygVAd5uODYrWu+57LFWw/5KXHcZVROTk5MDmUyGwMBA0fbAwEBkZGRoOUpsxowZCAkJQXS09h/u2NhY+Pr6Kr5CQ+v24nNE1jQsPBivPtQC37zYXed+HUJ84O0urn0J8X1Q8yIIQCMfdzzbIxStGnljeESIUR0g3UTByLgUU5fWdhIE4OCVHI2vKa81mVlQ9Yfe9/+k4dxt/TPtKq/ZpPy+rIy/8uD8BqZH1b2U37v45KquBdlKi2TqPJcJo5OsOaIJAFOMDanV4QCLFi3Cli1bsG3bNri7a//rb+bMmcjPz1d83bjBsfhENcXJSYKZj7VHTMcgo4+NbKF9iLaxFj4Vjgb1XDH38Q6woYWhbYpMy4f3sOUH9Ta7aFuz6Q+l5i9dISa7sAxfxV1GZkGp2n7KIaY24oW1MowgCDh7Kx+VMseaT8eWGRVi/P39IZVKkZmZKdqemZmJoCDdvwAXL16MRYsWYdeuXejcubPOfd3c3ODj4yP6IiLbpnlBQsN1CPZB4uxovNSvucNNumYoXaOTdOnVXNx8r9wpuLTywUzAusLj65sSsWT3JYxbd1QtRGhrXqz2c+JNPP7VAdzOu6f2minNSXJBQP69CqOPM9f6Q9fw+FcH8dup27V+bdLMqBDj6uqK7t27izrlVnfSjYqK0nrc559/jk8++QQ7d+5Ejx49TC8tEdksTbnDmI9cF6mT4gONEUadBKbPqBvoI675Vs4N98ofJBddzTTViyRezChUe005xGg6x7s/ncLZWwVYsOMCAOB23j38mnQLlTK53qYhTRnnjR9OImL+LrXJBWva2oOptXo90s/oeb6nT5+OcePGoUePHujVqxeWLVuG4uJiTJgwAQAwduxYNG7cGLGxsQCAzz77DHPnzsXmzZsRFham6Dvj5eUFLy8vC94KEdkz5X6prInRzNSamP9TqTlQzgXKazIZenbV90e5Y6+uc5SUVQIABi3Zh3sVMtzV0wQmkwtqyy0AUAzvXnPgKla80M3AUlNdZHSfmFGjRmHx4s
WYO3cuunTpgqSkJOzcuVPR2TctLQ3p6Q/aWFetWoXy8nI888wzCA4OVnwtXrzYcndBRDZp+qNtAADPdNc/j4uoWYEZRiNTQ4wu95RCjKbwmFlQilf+c1y0TXU3Q0dGVQ/nr77mwSs5ove9rFK8yOWBy9kGndeSrmYXYcOhVLWyADbQodhMt/PuoaS80trFsCiTVlybOnUqpk6dqvG1+Ph40fNr166ZcgkiqgNGdGmMXs0bIMjHHT8n3jT4OPv+qKg5FquhUgoOysFIU0b68JczaqtjF5eJPwiV1yDSVUS19Z6cJKJgMGz5QbzYuxnO3y5A7FPhKK2ovR7egiBAIpFg4JJ9AIDC0kq8Mah1rV1fn/x7FXBzdjJ5ZfTUnGI8sjge9T1dcHLuYAuXznrsf7ESIrIJ2la6Dvb1gEQiweZJkXh3cBvRa29q+ZBgc5Jmey5qnhndWBJUfSiqrvWkqabhlobOuL0WiicrNXTdK9VOvKrz4VzJKsK8385h6/Eb2HcpG6bE2Z1nM/D8twlIz1cvty6qAS4x7a7R164pBaUViJi/C5ELTZ8kdt/9oe93S2q/Q3RNYogholrRp6U/pg5sjfF9wgAAj3YI1Lr4ZG2tphzW0LN2LmQBZZVyjf1DTPX9P9fVOsaa2ly1/1I2ztysmqtGW5gF1Je4cNaxinTevXKThlJP/j4R/1zNxZztZ406rtKGx/Wfvf+9tcaILFvHEENEterDx9rjPy/1wvLnu2rdR7VG4L2YtvjxVe0jIIGqIdrGimrpjyEmzI9j78oq5bicqT7KSFOGMTRIvLnlpNq2WdvOiJ6rrvf0a9Jt3NayPEGlTFcc0q96MkBDqWYYTfHKavWDSoXZfvIWzt7KR1aBZZZ1sHcMMURUq1ydnfBQmwB4uEq1fig8EREiej7lkVZqc52o2jwp0uiyOEmAZ3vYzuKRtWX1vhRsT1Kf60QmF1Apq/o0FwRBsS6TIaqXL1AOPZuOpImaoyQS4Owt/bMLV5fFnFbFCpn2mpWJG46pbVOtibGl5RCUw9+0rUl4/KuD6LUwDley1IOo1nMo3Y8x76utM6ljLxGRKkMWklTVW0swebRDIHw9XIyqPvfzdDX6+hIJ1JZScGRFZZVoNetPfDe2B9YdSkVOUZnBTUy38u5h36Vs3Lwr7osiU5rd9s+zGaLVr3WpkAs6m6b0KdcSYiplcrWOykDNjPyqaf93Kh0SSTp6NW+APi39de6rXLv5vxM38WyPurGcD2tiiMgs84Z3QIC3Gz4e0cnoY1sHemPntP5InC1eS00ikaBjSM3M1P1yv+YPrgMJfDzU/5bT1uHYUbz8n+M4nHIHlzKLkJJdbPBx49YdVdtmaoWGTCY3qyZGeWFLZdpOWakSYmynHkb79/C3U7ex7O/LeGHNEaPOt+t8pv6d7ARDDBGZZULf5jj64SC0amTa5JXtgnzQ0Eu9g29NDVAa07uZ6HnrRt5o4S9eBXv6o21wcs6jNVMAMsjCHRcNroe5lXcPa/ZfxZ6LDz6ctTUnaRv5ploTYyutSfklFdh0JE3jazfvGrZCfF3G5iQiMpst9R/QRzSnCQRInSSIe2cAms/cIdqvfj3jm6dInak/GuUyOb47cNWgffsu2qO2rULLIo3awrFqTYyt1MV8/Pt5tRmXqzk7OaFCpj4pnyb29G/UGKyJISKbZN7YFO2UJwur/tzS9gu+SX0Pred5f0hbi5arrjLnw/P0Tf2dgLXNoqutOSmnSPOoJZlMf02MNaYvUq5dUqVriLqjYIghIptkygfGT5Oj8HzPUES1aKh1H+UQo/rBpWrHW/0Vj/9UegyY1pGZLE9bf1xt875M33pK43bV0Lz7fCYybWAYc30dHdaV16xyVPwOEJFNMuWP3p5hDbDo6c7w9dA+4shdacVl9SYEMQ+lwBPaQDwxHiOMYSp1DHW2BG2jirS9tUev5Wrcnpqj3oH5te8TTS5XtZt3S2psvSIX1sQwxBBR3aP6V7
XyiCRnpb9eZXpmaXWROmHntP74/Y1+8HJzhrPSlLOsiTHMgC/ia/T82jrqapoL5V659v4j49cfwz9X74i2nUjLEz3X1MR56EoO9lzMRO+FcTh8JUf0Wkp2Efp9thf9P9ur9bp66fgxU122wRHxO0BEtsmM/geqn2uD2gdq3E+5JsbNWfOvw3ZBPujU2BcA8J+JvdCwnitWjelmM6NXHNkfp9Mx85czGl+rlAvIKSrD+z+fwon76yB9tvOizvNtO3FL9NyQxbnHfHcEL204joyCUrzwnXio897789HcKS7Hmv1XkZxh+OR0hnAWLbxpf/PcWAJDDBHVOaq/zptqWSNJ+a/4F+8PvY4I9dN63j4t/XF8djSGhgfX2dEe9mbbyVtaX+vx6d/48fhNPLXyMICqNZ502Xr8hui5h4krRmuyYMcFxCzbb7HzARDVDGobjVXXcYg1Edkkc0Ynqf5R2tjPA5tfjoSvp7ivTKXSL/73hrRFp8a+6Nda98yn1eHFkL/SDfXGwFZIyS7CjjOGzWZLpjF2dXQPV+NDzNqDqZh4v/nSEkFX1xmUO/ZmF5WhsZ/20XR1NXOzJoaIbJKhnze6OvEq69PKHx1DfEXblD/U3JylGNm1Mfw1TLyniSX7xLwzuK1a2cjyjI3FprzHn/z+YKXxms4NyiGm76I9+OXETYOOq0stTwwxRGTXdk7rr38nLfSNTtLl4bYBOl8f3yfMoPM0vD+p3rDwYJPLQoYxtibG1J+OixkFtbIWk2pt4OztZ2v8mraGIYaIbJKhHwHBvpqq0A072pwPmmYN6+l8/aMnOsJVzzwej7QNwPcvV62+HeZfD+882sbk8pB+egajqVHNPIZmoCHLDohqZHQ5dSMPy/6+hLJKzSOndDVJmfrTm3QjDwWlhi+uassYYojIJpkz2kL50FcfaqH2ev/7/V5e6NXU5GuoahGgHmr0/eX/2TOd0T74wUKXITr6NJA1mP4zuOHwNY39UCplciRez1XMKDxixSEs+/syvjuQanzpTAxZOUVlGLb8gNHXs0UMMURkk8ypjFc+duZj7dVeXzGmG3a9/RCGWqgJp1lDT/zf1H6K59UjnfSFGNV5Ply1DPMmyzA2GN8pLkfi9bs6p/7XRdMEegt3XMTTqxIwa5t4aPhFLcOvdfWrOXNL/7IM2tzIvYf0/HsmH28r+C+GiGySrpEW+uj7sPJxd0GbQG+Tz19t9b+6Y2C7Rtj2el/Uc3NGauxj+GvaQ/joiY5V5dBzvOraN9rmqiHLMLb1UBCAp1cdxksbjuNGbonRx/8n4bratnWHqmpcfkoUd8LV9jNrTN9iY0f0RcXuQUa+9ZdWMAeHWBORTZo3vCMEiJt8PFykuFdR1XfglYdawN9L87oynq6186ttSKcgDOkUpHgukUjQNuhBONL3h7+zSs/MIF93tX0e7xyM30+nm1dQAmDesP2DV3K0Lh5pCdUlyy+pwG+nb+Px8GCTV1KvkMnhJJFAqvTzpS0LJV6/i2Gd7bdTOUMMEdmkAG83rHihm2hb7xYNsDc5G15uzvhQQzNRtVnD2iMluwjjDBwhZC1SlRDTPtgHLfzrwcNVinO3CwAAwRqCDZnGnAFD8349Z7mCaHK/bDO3ncaOMxn4v1O38eOrUZAYMVC7tEKO0d/+g6PXctHCvx52Tx+geno1Mjsfb80QQ0R2Y/GzEfhm/1U81yNU534hfh7YOe2hWiqV6VxU+sS4SJ2we/oACIKAVrP+BCBedZtMV1haoXE9JUOZU4tjzPl3nq2a8PBoai7Grz+K5EzjlipIuL/+0+WsIlTK5KK1wjSJ3XEBT0SEmFBi28AGWCKyGw293PDhY+3RqpGXtYtisg+GtlM8dtIw7a/USSL64GlSX3/fIEP2cXThH+3CneJyk4+v6Wn9qytEGipNthifrHuZBH3KKh+MKddWn5POPjFERGQoXSspa1LPTfuv6X3vPYw7xeXILSrHy/85bm7RyIqqR7JZcr
2m8ko56t3PRDfv2v9IJE1YE0NEVIsCvA1b1qCaaudfZe4uUnRrWt+um5xeHaA+j48jqomuKaWVMvxwNA3nbufjm/1XLX8BG8CaGCKiWvRM9ya4W1yOR9o1Mmh/qZP2vzWrOwarDtW2NSG+7ritpdmigadpI3DqmuoMk5ZbYrFzrj2Qiu8OGj+Jnj1hTQwRUS1yd5HijUGt0amxYQs+6uqXWV1L46InxIQ19AQAqw2l7RDio/U1U4cR1wXFZZWKx4IAszoea1LXAwzAEENEVGOe7tbE7HM4Oznhmxe7I6KJL/6ePkA07FxRE6OjtgYA/vdaH6wc0w1THm5ldnlMoatfTz0D5/SZf38CwbqkRNQ/SrD74c7WwBBDRFRDPh3ZCav/1Q0Lnuxk9LHPdG+CiCa+iGrZEDEdg/Dr1H5o1chL1HRUHV46NfbFU90aaz1XQy83PBYebFCz00ADm7mM4e4sxep/dVeb9weoqmna9HIkXurbXOc5jO1LZA8e/mKv4vHfF7KQrGXpgZp2/Y768gj2giGGiKiGeLhKMaRTMLx01ERos/jZCPw6tR9cVNqTpErz0FfXxEidJFj6XBe959S3lhMARDTxM6qchpBKJRjSKUhjc5ZEIkHfVv6YO7wDhhm4llVdWZ6hWGWk2lOrDlulHAO+iLfKdS2hbvwkEBE5COW1dHSNXNJELte/T00sQpmUlqd4/Oag1qLXlO9AX8hKnB2NP97sh73vPmy5wtmQ8koD3iASYYghIqphTsas4mfEuTRNlqeLIbPOagsxb0e3Mepayq4pNVeoFlm5RLpCjARVzWIdQ3xtfjSWPer/+R6cuWn6qtjWwhBDRFTDLJhhdJ7rs6fDEdWioWjbv3o/WEDTkH6j9Vw1zzlTv54LWgTUU9se3tgXn4wQd7ptrTKjcn2lYdRSlRvwVLqersE5nUP9FI+Vg9yut21/eQl7cCP3nkHNjbaGIYaIqIb1at4AAODjbv7UXLpGIo3q2RQ/vNJb8bxFQD18OjJc8bxlgO7lGtoEeqFjiOah352b+CGsoXqI+b83+uFZpbWsVv+ru2jhQUA8/b1q7ZFyfyFBy4foaw+3RGO/B0srKPeJCfHjkguWotr/yh5wsjsiohrWyNsdx2ZFm9TBV1VkiwboGOKjVtuhrFlDT1y/U4LhncUL+3m4SnH6o8EQBCBi/i4AwMR+zbH2YCqe7NoYS56NwJlb6k0KK17ohi6hfni5f3PsuZiFAW0C4OwkwcNtAwCIa4cCfapGEQX7uivW5SmrfNCBVbUmSfl7ItNSFdOsgafoube7C74cFQEJJBb5nlIVV2f7a6bju09EVAssNUTYReqE39/oB4mOdqX/vdYH/1y9g8EdgtRe83F3QaXsQc3IiC4hmPN4B8Xz5gH14O3uDBepE3LvL5jYr7U/AKBPS38cnTUI/vXcRDUqys071X/N//hqFPp/XjWE+I2BrTTuCwCNlRav1NacVKHhhSe7mj8HD4mxJoaIiGqcrgADAP5ebnhcpRZGma6Oxj7uLjj8wUC4SJ0QFRsHqZMTvJVqOxp5u+s8X/U6TqENPHFlwVBcyS5C20BvpX0fHHfg/UfgqTTZnbY+GTIZR+3UBnsMMfZXYiIiMotyhtGUG7zdXeDuIsWRD6Nx+IOBekdBSZ0keLJrYwxq1wgtlTr/Okud0C7IRxS6lANPqEozkbZ+pd7uLjqvb6ohHR/UVHVr6mex83qr9H3aMKEn3h/SVrStJiYVNBdDDBER2Tx9NTnVXJ2dDJ435stRXbB2fE+959ZVC6SpJmZElxA80UV7rZI5Vv3rwQzClhwG36yhOJw93LYRmtQXb/PzrJlgZg5XhhgiIiLtdFXqaAox/36+a43VEKjWEFWPgFIe9i1VKfDkAS3VzuPm7ITHwh/U6kggUYxIq6baEdvPw7yFL8dENtW/k5Fc7LBjL0MMEZEDeqhNwP0h1dpXmK4JupqmLLyIs9YFOJ/u1gQrx4jXcZJIgG
1T+mDpcxE4MedRvDWoNTqG+OCfmYMU+/Rv7Y8PhrZTm4vn3PwYtY7GlSr9eNoH+2D56K6K5/XNqIl5sXczLHgyXOvrVxYMNem89lgTw469REQOaOOEnhAE42f9NZeu5iZt88To4+PujILSSvXzKc0H/MpDLfDt/qsAgEVPh6vV7jhJJGjk7Y6n7geftx9tg7cfrZql+NsXuyOnqBzP96yaD+e/E3uhy8e7UVRWiUfaBsBZ6iQqe8uAevBwleKE0nILADC4Q+CDMnuYHmLeGax79mRnE8OIaq2TPWCIISJyQBKJxKIzCRvqofvDtb01zO/ydLcmOHbtrtHn3DypNz794zxefaglissrMXXzSQDijsLKt6ppzakuOjr2Du4oHqruLHXCH2/2w/8Sb+KlflWrb3dWWjhz7vCOkDpJ4OHijCe7PlhdXCoalm7AjWnhez8APd8zFFuO3TD9RCoM7StlSxhiiIio1jRrWA8HZzwiWoqg2qieoWgd6I2U7CK8//Npg8/ZqbEvtrwSpXheHWJElD6flT+s/5r2EP6+kImJ98OIoZo1rIfpgx+MOArydce+9x6Gj7sL6terure5wzuIjjG28/Cbg1pjedxlte3V5f/oiY4I8nXHsr8f7OPvVXXtFyKbYvORNI3nlUgMW4LCHthfAxgREdm1JvU9UU9DTYxEIkH3ZvXxbPcmWPJsBHabuS6SchNP60beGvdpG+SNKY+0UsxvY45mDespAowmyrUvhjTjPd45WPE49in1PjDuLlKM6PKgpmfHm/1x4P2BAMQTDKrSVAtmr+rOnRARUZ0gkUjwdHfzZ+QVAPRt1RDJGUUYFh6Mrk39UM/Veh97EknVqKXMglI81ikYR1Nz8WvSbY37JswciGBfD5yaOxiVcrli9mS1cyo9bhFQTxHGPF2032eInwcKMgpNvg9bwhBDRER1ynM9muDH4zcxeUBLtAvyRqVcgIvUSe8CmLVh6yu9FeX5aHhHXEgvwKXMIrX9gn2rhnv73h/F1NDLDb+83geNVJavaNbQE9HtA+Hr4SKqTXJ31d7QojxHzbDwYPxxJt2se7ImhhgiIqpTPnu6M+Y/0Qke9+d7cZHaTodViUSiKE/9eq54pG0jtRBT3a9FVbem9TWe77txPdS26xouLVca/a1t0U17wT4xRERUp0gkEkWAsXWVGkLE/17rY/Z5JRIJpj6iuV9MpVKKeb5X1bDxrhZcdqE2sSaGiIjISlRrQsZGNUOzhvW07G2cd2PaYlTPUPwn4RrWHEgFUNV81MjHTTGHzcNtG+Hv6QPQRGk1cXvCEENERGQl4Y19FY8XPhkumlfGEkIbeGLWsA6YNawDUnOK0ayBJ+5VyODr4YLHwqtGP7VqZP2+QqZiiCEiIrKSJ7s2RoVMju7N6qN1oOZh4JbS3L+qhqeemzOmReue9ddemNQnZsWKFQgLC4O7uzsiIyNx9OhRnfv/9NNPaNeuHdzd3REeHo4dO3aYVFgiIqK6xMlJgud7Na3xAFNXGR1itm7diunTp2PevHk4ceIEIiIiEBMTg6ysLI37Hz58GKNHj8bEiRNx8uRJjBw5EiNHjsTZs2fNLjwRERE5Lolg5IpbkZGR6NmzJ77++msAgFwuR2hoKN544w188MEHavuPGjUKxcXF+P333xXbevfujS5dumD16tUGXbOgoAC+vr7Iz8+Hj0/trrhKREREpqnpz2+jamLKy8uRmJiI6OjoBydwckJ0dDQSEhI0HpOQkCDaHwBiYmK07g8AZWVlKCgoEH0RERERKTMqxOTk5EAmkyEwMFC0PTAwEBkZGRqPycjIMGp/AIiNjYWvr6/iKzQ01JhiEhERkQOwycnuZs6cifz8fMXXjRuWW2qciIiI6gajhlj7+/tDKpUiMzNTtD0zMxNBQUEajwkKCjJqfwBwc3ODm5ub1teJiIiIjKqJcXV1Rffu3REXF6fYJpfLERcXh6ioKI3HREVFifYHgN27d2vdn4iIiM
gQRk92N336dIwbNw49evRAr169sGzZMhQXF2PChAkAgLFjx6Jx48aIjY0FALz11lsYMGAAlixZgmHDhmHLli04fvw4vv32W8veCRERETkUo0PMqFGjkJ2djblz5yIjIwNdunTBzp07FZ1309LS4OT0oIKnT58+2Lx5M2bPno0PP/wQrVu3xvbt29GpUyfL3QURERE5HKPnibEGzhNDRERkf2xqnhgiIiIiW8EQQ0RERHaJIYaIiIjsktEde62hutsOlx8gIiKyH9Wf2zXV/dYuQkxhYSEAcPkBIiIiO1RYWAhfX1+Ln9cuRifJ5XLcvn0b3t7ekEgkFjtvQUEBQkNDcePGjTo/6on3Wvc4yn0CvNe6ylHu1VHuE1C/V0EQUFhYiJCQENH0K5ZiFzUxTk5OaNKkSY2d38fHp87/YFXjvdY9jnKfAO+1rnKUe3WU+wTE91oTNTDV2LGXiIiI7BJDDBEREdklhw4xbm5umDdvnkOsmM17rXsc5T4B3mtd5Sj36ij3CdT+vdpFx14iIiIiVQ5dE0NERET2iyGGiIiI7BJDDBEREdklhhgiIiKySw4dYlasWIGwsDC4u7sjMjISR48etXaRjBIbG4uePXvC29sbjRo1wsiRI5GcnCza5+GHH4ZEIhF9TZ48WbRPWloahg0bBk9PTzRq1AjvvfceKisra/NW9Proo4/U7qNdu3aK10tLSzFlyhQ0bNgQXl5eePrpp5GZmSk6hz3cZ1hYmNp9SiQSTJkyBYB9v5/79+/H8OHDERISAolEgu3bt4teFwQBc+fORXBwMDw8PBAdHY3Lly+L9snNzcWYMWPg4+MDPz8/TJw4EUVFRaJ9Tp8+jf79+8Pd3R2hoaH4/PPPa/rW1Oi614qKCsyYMQPh4eGoV68eQkJCMHbsWNy+fVt0Dk0/C4sWLRLtY+v3CgDjx49Xu48hQ4aI9rGH91XffWr6dyuRSPDFF18o9rGX99SQzxZL/c6Nj49Ht27d4ObmhlatWmHDhg3GFVZwUFu2bBFcXV2FdevWCefOnRMmTZok+Pn5CZmZmdYumsFiYmKE9evXC2fPnhWSkpKExx57TGjatKlQVFSk2GfAgAHCpEmThPT0dMVXfn6+4vXKykqhU6dOQnR0tHDy5Elhx44dgr+/vzBz5kxr3JJW8+bNEzp27Ci6j+zsbMXrkydPFkJDQ4W4uDjh+PHjQu/evYU+ffooXreX+8zKyhLd4+7duwUAwt69ewVBsO/3c8eOHcKsWbOEX375RQAgbNu2TfT6okWLBF9fX2H79u3CqVOnhCeeeEJo3ry5cO/ePcU+Q4YMESIiIoR//vlHOHDggNCqVSth9OjRitfz8/OFwMBAYcyYMcLZs2eFH374QfDw8BC++eab2rpNQRB032teXp4QHR0tbN26Vbh48aKQkJAg9OrVS+jevbvoHM2aNRM+/vhj0Xut/G/bHu5VEARh3LhxwpAhQ0T3kZubK9rHHt5XffepfH/p6enCunXrBIlEIqSkpCj2sZf31JDPFkv8zr169arg6ekpTJ8+XTh//rzw1VdfCVKpVNi5c6fBZXXYENOrVy9hypQpiucymUwICQkRYmNjrVgq82RlZQkAhH379im2DRgwQHjrrbe0HrNjxw7ByclJyMjIUGxbtWqV4OPjI5SVldVkcY0yb948ISIiQuNreXl5gouLi/DTTz8ptl24cEEAICQkJAiCYD/3qeqtt94SWrZsKcjlckEQ6s77qfohIJfLhaCgIOGLL75QbMvLyxPc3NyEH374QRAEQTh//rwAQDh27Jhinz///FOQSCTCrVu3BEEQhJUrVwr169cX3euMGTOEtm3b1vAdaafpA0/V0aNHBQDC9evXFduaNWsmfPnll1qPsZd7HTdunDBixAitx9jj+2rIezpixAhh4MCBom32+J4Kgvpni6V+577//vtCx44dRdcaNWqUEBMTY3DZHLI5qby8HImJiYiOjlZsc3JyQnR0NBISEq
xYMvPk5+cDABo0aCDavmnTJvj7+6NTp06YOXMmSkpKFK8lJCQgPDwcgYGBim0xMTEoKCjAuXPnaqfgBrp8+TJCQkLQokULjBkzBmlpaQCAxMREVFRUiN7Pdu3aoWnTpor3057us1p5eTm+//57vPTSS6KFT+vK+6ksNTUVGRkZovfQ19cXkZGRovfQz88PPXr0UOwTHR0NJycnHDlyRLHPQw89BFdXV8U+MTExSE5Oxt27d2vpboyXn58PiUQCPz8/0fZFixahYcOG6Nq1K7744gtRVbw93Wt8fDwaNWqEtm3b4rXXXsOdO3cUr9XF9zUzMxN//PEHJk6cqPaaPb6nqp8tlvqdm5CQIDpH9T7GfA7bxQKQlpaTkwOZTCb65gJAYGAgLl68aKVSmUcul2PatGno27cvOnXqpNj+wgsvoFmzZggJCcHp06cxY8YMJCcn45dffgEAZGRkaPw+VL9mKyIjI7Fhwwa0bdsW6enpmD9/Pvr374+zZ88iIyMDrq6uah8AgYGBinuwl/tUtn37duTl5WH8+PGKbXXl/VRVXTZNZVd+Dxs1aiR63dnZGQ0aNBDt07x5c7VzVL9Wv379Gim/OUpLSzFjxgyMHj1atDjgm2++iW7duqFBgwY4fPgwZs6cifT0dCxduhSA/dzrkCFD8NRTT6F58+ZISUnBhx9+iKFDhyIhIQFSqbROvq8bN26Et7c3nnrqKdF2e3xPNX22WOp3rrZ9CgoKcO/ePXh4eOgtn0OGmLpoypQpOHv2LA4ePCja/sorrygeh4eHIzg4GIMGDUJKSgpatmxZ28U02dChQxWPO3fujMjISDRr1gw//vijQT/o9mjt2rUYOnQoQkJCFNvqyvtJVSoqKvDcc89BEASsWrVK9Nr06dMVjzt37gxXV1e8+uqriI2Ntavp659//nnF4/DwcHTu3BktW7ZEfHw8Bg0aZMWS1Zx169ZhzJgxcHd3F223x/dU22eLrXDI5iR/f39IpVK1ntSZmZkICgqyUqlMN3XqVPz+++/Yu3cvmjRponPfyMhIAMCVK1cAAEFBQRq/D9Wv2So/Pz+0adMGV65cQVBQEMrLy5GXlyfaR/n9tLf7vH79Ov7++2+8/PLLOverK+9nddl0/ZsMCgpCVlaW6PXKykrk5uba5ftcHWCuX7+O3bt3i2phNImMjERlZSWuXbsGwL7uVVmLFi3g7+8v+pmtS+/rgQMHkJycrPffLmD776m2zxZL/c7Vto+Pj4/Bf5w6ZIhxdXVF9+7dERcXp9gml8sRFxeHqKgoK5bMOIIgYOrUqdi2bRv27NmjVg2pSVJSEgAgODgYABAVFYUzZ86IfolU/0Lt0KFDjZTbEoqKipCSkoLg4GB0794dLi4uovczOTkZaWlpivfT3u5z/fr1aNSoEYYNG6Zzv7ryfjZv3hxBQUGi97CgoABHjhwRvYd5eXlITExU7LNnzx7I5XJFmIuKisL+/ftRUVGh2Gf37t1o27atTTU5VAeYy5cv4++//0bDhg31HpOUlAQnJydF04u93Kuqmzdv4s6dO6Kf2bryvgJVNajdu3dHRESE3n1t9T3V99liqd+5UVFRonNU72PU57BpfZXt35YtWwQ3Nzdhw4YNwvnz54VXXnlF8PPzE/WktnWvvfaa4OvrK8THx4uG7JWUlAiCIAhXrlwRPv74Y+H48eNCamqq8OuvvwotWrQQHnroIcU5qofBDR48WEhKShJ27twpBAQE2MSQXGXvvPOOEB8fL6SmpgqHDh0SoqOjBX9/fyErK0sQhKrhfk2bNhX27NkjHD9+XIiKihKioqIUx9vLfQpC1Ui5pk2bCjNmzBBtt/f3s7CwUDh58qRw8uRJAYCwdOlS4eTJk4oROYsWLRL8/PyEX3/9VTh9+rQwYsQIjUOsu3btKhw5ckQ4ePCg0Lp1a9FQ3Ly8PCEwMFB48cUXhbNnzwpbtmwRPD09a32Iqq57LS8vF5544gmhSZMmQlJSkujfbv
WojcOHDwtffvmlkJSUJKSkpAjff/+9EBAQIIwdO9au7rWwsFB49913hYSEBCE1NVX4+++/hW7dugmtW7cWSktLFeewh/dV38+vIFQNkfb09BRWrVqldrw9vaf6PlsEwTK/c6uHWL/33nvChQsXhBUrVnCItTG++uoroWnTpoKrq6vQq1cv4Z9//rF2kYwCQOPX+vXrBUEQhLS0NOGhhx4SGjRoILi5uQmtWrUS3nvvPdG8IoIgCNeuXROGDh0qeHh4CP7+/sI777wjVFRUWOGOtBs1apQQHBwsuLq6Co0bNxZGjRolXLlyRfH6vXv3hNdff12oX7++4OnpKTz55JNCenq66Bz2cJ+CIAh//fWXAEBITk4Wbbf393Pv3r0af17HjRsnCELVMOs5c+YIgYGBgpubmzBo0CC178GdO3eE0aNHC15eXoKPj48wYcIEobCwULTPqVOnhH79+glubm5C48aNhUWLFtXWLSroutfU1FSt/3ar5wNKTEwUIiMjBV9fX8Hd3V1o3769sHDhQtEHvz3ca0lJiTB48GAhICBAcHFxEZo1ayZMmjRJ7Y9Fe3hf9f38CoIgfPPNN4KHh4eQl5endrw9vaf6PlsEwXK/c/fu3St06dJFcHV1FVq0aCG6hiEk9wtMREREZFccsk8MERER2T+GGCIiIrJLDDFERERklxhiiIiIyC4xxBAREZFdYoghIiIiu8QQQ0RERHaJIYaIiIjsEkMMERER2SWGGCIiIrJLDDFERERklxhiiIiIyC79P6NM7KEyuJlaAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(train_metrics_history[\"train_loss\"], label=\"Training loss\")\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "f3c2b7fd-965e-4440-8ad6-882e7d4ae104",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0MAAANECAYAAAByxfRXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACy0UlEQVR4nOzdd3yV5f3/8fc5J3ueQCYhEPaQKUhkOX5Go1UKTlBbBAVbCq1KrZVWUFy09ltKhzZqQbF14EBrK8URi4uloCKWvVdCAmSTdc79+yM5B2IC5GTdZ7yej55HzX3uc+dzApyT97mu63NZDMMwBAAAAAABxmp2AQAAAABgBsIQAAAAgIBEGAIAAAAQkAhDAAAAAAISYQgAAABAQCIMAQAAAAhIhCEAAAAAAYkwBAAAACAgEYYAAAAABCTCEHAGDz30kCwWi9llAAAAoI0QhnBOzz//vCwWi7744guzS0ErKi8v10MPPaRVq1a16fdZsWKFHnrooTb9Hmfy1FNP6fnnnzflewMw11NPPSWLxaKMjAyzS0EjVq9erYceekiFhYVt+n0ef/xxvfXWW236PRpz+PBhPfTQQ/rqq6/a/XvDM4QhIECVl5dr/vz57RKG5s+f36bf40wIQ0DgevHFF5Wenq7169dr586dZpeD71i9erXmz5/v12Fo/vz5hCEfQBgCAAB+Zc+ePVq9erUWLlyohIQEvfjii2aXdEZlZWVmlwAENMIQWs2XX36pq666SjExMYqKitJll12mtWvX1junurpa8+fPV69evRQWFqaOHTtqzJgxev/9993n5ObmaurUqercubNCQ0OVkpKi8ePHa+/evWf83v/3f/8ni8Wiffv2Nbhvzpw5CgkJ0YkTJyRJn3zyiW688UZ16dJFoaGhSktL0z333KOTJ0+e9fnt3btXFoul0ZEGi8XSYCrYoUOHdPvttyspKUmhoaE677zztGTJkrN+D5eamho98sgj6tGjh0JDQ5Wenq5f/epXqqysrHdeenq6rrnmGn366acaMWKEwsLC1L17d73wwgvnfC4JCQmSpPnz58tisTR4Dlu3btUNN9ygDh06KCwsTMOHD9fbb79d7zrn+vOcMmWKnnzySffPyHU7my+++EJZWVmKj49XeHi4unXrpttvv73eOU6nU4sWLdJ5552nsLAwJSUl6Uc/+pH7z9j1s/n222/10Ucfub/vJZdcctbvDcA/vPjii4qLi9PVV1+tG2644YxhqLCwUPfcc4/S09MVGhqqzp07a/LkySooKHCfU1FRoYceeki9e/dWWFiYUlJSdN1112nXrl2SpFWrVslisTQYZW/sPWPKlCmKiorSrl279L3vfU/R0dG69dZbJXn23rR161bddNNNSkhIUHh4uPr06aNf//rXkqT//ve/slgsevPNNxs87qWXXpLFYtGaNWvO+vPbvXu3brzxRnXo0EERERG68MIL9c4779Q7x/W8X331VT322GPq3LmzwsLCdNlll51zJO6hhx7SL37xC0lSt27d3K/Rp7/P/+Mf/9CwYcMUHh6uDh06aNKkSTpw4EC96+zYsUPXX3+9kpOTFRYWps6dO2vSpEkqKiqSVPu+U1ZWpqVLl7q/x5QpU85a25///Gedd955ioiIUFxcnIYPH66XXnqp3jnnen9ftWqVLrjgAknS1KlT3d+bmQreKcjsAuAfvv32W40dO1YxMTG67777FBwcrKefflqXXHKJPvroI/ec7YceekgLFizQtGnTNGLECBUXF+uLL77Qxo0bdfnll0uSrr/+en377bf66U9/qvT0dB09elTvv/++9u/fr/T09Ea//0033aT77rtPr776qvsF1uXVV1/VFVdcobi4OEnSa6+9pvLycs2YMUMdO3bU+vXr9ec//1kHDx7Ua6+91io/j7y8PF144YWyWCyaNWuWEhIS9J///Ed33HGHiouLdffdd5/18dOmTdPSpUt1ww036Oc//7nWrV
unBQsWaMuWLQ3e4Hbu3KkbbrhBd9xxh2677TYtWbJEU6ZM0bBhw3Teeec1ev2EhAT99a9/1YwZM3TttdfquuuukyQNGjRIUu2f5+jRo5Wamqr7779fkZGRevXVVzVhwgS98cYbuvbaayWd+8/zRz/6kQ4fPqz3339ff//738/5czt69KiuuOIKJSQk6P7775fdbtfevXu1fPnyeuf96Ec/0vPPP6+pU6fqZz/7mfbs2aO//OUv+vLLL/XZZ58pODhYixYt0k9/+lNFRUW5f0lISko6Zw0AfN+LL76o6667TiEhIbr55pv117/+VZ9//rn7F1RJKi0t1dixY7VlyxbdfvvtOv/881VQUKC3335bBw8eVHx8vBwOh6655hrl5ORo0qRJuuuuu1RSUqL3339fmzdvVo8ePTyuraamRllZWRozZoz+7//+TxEREZKa/t60adMmjR07VsHBwbrzzjuVnp6uXbt26V//+pcee+wxXXLJJUpLS9OLL77ofq0+/efSo0cPjRw58oz15eXladSoUSovL9fPfvYzdezYUUuXLtX3v/99vf766w2u+Zvf/EZWq1X33nuvioqK9MQTT+jWW2/VunXrzvg9rrvuOm3fvl0vv/yy/vCHPyg+Pl6S3B/SPfbYY5o7d65uuukmTZs2Tfn5+frzn/+siy66SF9++aXsdruqqqqUlZWlyspK/fSnP1VycrIOHTqkf//73yosLFRsbKz+/ve/u9+f7rzzTkk665/Zs88+q5/97Ge64YYbdNddd6miokKbNm3SunXrdMstt7h/Pud6f+/Xr58efvhhzZs3T3feeafGjh0rSRo1atQZvzdMZADn8NxzzxmSjM8///yM50yYMMEICQkxdu3a5T52+PBhIzo62rjooovcxwYPHmxcffXVZ7zOiRMnDEnG7373O4/rHDlypDFs2LB6x9avX29IMl544QX3sfLy8gaPXbBggWGxWIx9+/a5jz344IPG6f9E9uzZY0gynnvuuQaPl2Q8+OCD7q/vuOMOIyUlxSgoKKh33qRJk4zY2NhGa3D56quvDEnGtGnT6h2/9957DUnGhx9+6D7WtWtXQ5Lx8ccfu48dPXrUCA0NNX7+85+f8XsYhmHk5+c3qNvlsssuMwYOHGhUVFS4jzmdTmPUqFFGr1693MfO9edpGIYxc+ZMo6kvNW+++eY5/6598sknhiTjxRdfrHd85cqVDY6fd955xsUXX9yk7w3AP3zxxReGJOP99983DKP2tatz587GXXfdVe+8efPmGZKM5cuXN7iG0+k0DMMwlixZYkgyFi5ceMZz/vvf/xqSjP/+97/17m/sPeO2224zJBn3339/g+s19b3poosuMqKjo+sdO70ewzCMOXPmGKGhoUZhYaH72NGjR42goKBGX/NPd/fddxuSjE8++cR9rKSkxOjWrZuRnp5uOByOes+7X79+RmVlpfvcP/7xj4Yk45tvvjnr9/nd735nSDL27NlT7/jevXsNm81mPPbYY/WOf/PNN0ZQUJD7+JdffmlIMl577bWzfp/IyEjjtttuO+s5LuPHjzfOO++8s57T1Pf3zz///Iy/M8C7ME0OLeZwOPTee+9pwoQJ6t69u/t4SkqKbrnlFn366acqLi6WJNntdn377bfasWNHo9cKDw9XSEiIVq1aVW/KU1NMnDhRGzZscE9dkKRly5YpNDRU48ePr/c9XMrKylRQUKBRo0bJMAx9+eWXHn3PxhiGoTfeeEPjxo2TYRgqKChw37KyslRUVKSNGzee8fErVqyQJM2ePbve8Z///OeS1GCqQv/+/d2fOkm1n6z16dNHu3fvblb9x48f14cffqibbrpJJSUl7tqPHTumrKws7dixQ4cOHZJ07j9PT9ntdknSv//9b1VXVzd6zmuvvabY2Fhdfvnl9X62w4YNU1RUlP773/+2Si0AfNOLL76opKQkXXrppZJqp0pNnDhRr7zyihwOh/u8N954Q4MHD24w0uF6jO
uc+Ph4/fSnPz3jOc0xY8aMBsea8t6Un5+vjz/+WLfffru6dOlyxnomT56syspKvf766+5jy5YtU01NjX7wgx+ctbYVK1ZoxIgRGjNmjPtYVFSU7rzzTu3du1f/+9//6p0/depUhYSEuL92vR819z1o+fLlcjqduummm+q9xicnJ6tXr17u1/jY2FhJ0rvvvqvy8vJmfa/vstvtOnjwoD7//PNG72/p+zu8E2EILZafn6/y8nL16dOnwX39+vWT0+l0z/N9+OGHVVhYqN69e2vgwIH6xS9+oU2bNrnPDw0N1W9/+1v95z//UVJSki666CI98cQTys3NPWcdN954o6xWq5YtWyap9kXrtddec69jctm/f7+mTJmiDh06KCoqSgkJCbr44oslyT3PuCXy8/NVWFioZ555RgkJCfVuU6dOlVQ7HexM9u3bJ6vVqp49e9Y7npycLLvd3mBd1HffECUpLi7O4zDpsnPnThmGoblz5zao/8EHH6xX/7n+PD118cUX6/rrr9f8+fMVHx+v8ePH67nnnqu3VmrHjh0qKipSYmJig/pKS0vP+rMF4N8cDodeeeUVXXrppdqzZ4927typnTt3KiMjQ3l5ecrJyXGfu2vXLg0YMOCs19u1a5f69OmjoKDWW1UQFBSkzp07NzjelPcmV8A4V919+/bVBRdcUG+t1IsvvqgLL7ywwXvLd+3bt++M7+eu+0/33fcg15T05r4H7dixQ4ZhqFevXg1e47ds2eJ+je/WrZtmz56tv/3tb4qPj1dWVpaefPLJFr2P//KXv1RUVJRGjBihXr16aebMmfrss8/c97f0/R3eiTVDaFcXXXSRdu3apX/+859677339Le//U1/+MMflJ2drWnTpkmS7r77bo0bN05vvfWW3n33Xc2dO1cLFizQhx9+qKFDh57x2p06ddLYsWP16quv6le/+pXWrl2r/fv367e//a37HIfDocsvv1zHjx/XL3/5S/Xt21eRkZE6dOiQpkyZIqfTecbrn+lTwNM/aZTkvsYPfvAD3XbbbY0+xrU252ya+qmjzWZr9LhhGE16/He56r/33nuVlZXV6DmuN9Om/Hl6wmKx6PXXX9fatWv1r3/9S++++65uv/12/f73v9fatWsVFRUlp9OpxMTEMy6Ids05BxB4PvzwQx05ckSvvPKKXnnllQb3v/jii7riiita9Xs29b3BJTQ0VFartcG5zX1vOpPJkyfrrrvu0sGDB1VZWam1a9fqL3/5i8fXOZe2eA+yWCz6z3/+0+i1o6Ki3P/9+9//XlOmTHG/B/3sZz/TggULtHbt2kYD57n069dP27Zt07///W+tXLlSb7zxhp566inNmzdP8+fPb7X3d3gXwhBaLCEhQREREdq2bVuD+7Zu3Sqr1aq0tDT3sQ4dOmjq1KmaOnWqSktLddFFF+mhhx6q98tzjx499POf/1w///nPtWPHDg0ZMkS///3v9Y9//OOstUycOFE/+clPtG3bNi1btkwREREaN26c+/5vvvlG27dv19KlSzV58mT38dO72Z2J69Ou7+6J8N1PyRISEhQdHS2Hw6HMzMxzXve7unbtKqfTqR07drg/iZNqF20WFhaqa9euHl+zMWd6A3dNdQwODm5S/ef682zOVJILL7xQF154oR577DG99NJLuvXWW/XKK69o2rRp6tGjhz744AONHj263rSSxrRkGgsA3/Piiy8qMTHR3cXydMuXL9ebb76p7OxshYeHq0ePHtq8efNZr9ejRw+tW7dO1dXVCg4ObvScpr43nE1T35tcr8/nqluSJk2apNmzZ+vll1/WyZMnFRwcrIkTJ57zcV27dj3j+7nr/tZwptfnHj16yDAMdevWTb179z7ndQYOHKiBAwfqgQce0OrVqzV69GhlZ2fr0UcfPev3OZPIyEhNnDhREydOVFVVla677jo99thjmjNnjkfv77z/+A6myaHFbDabrrjiCv3zn/+s1xYzLy9PL730ksaMGe
Oepnbs2LF6j42KilLPnj3d06DKy8tVUVFR75wePXooOjq6QVvpxlx//fWy2Wx6+eWX9dprr+maa65RZGRkvVql+p9YGYahP/7xj+e8dkxMjOLj4/Xxxx/XO/7UU0/V+9pms+n666/XG2+80egbVn5+/lm/z/e+9z1J0qJFi+odX7hwoSTp6quvPmetTeHqYPTdN/DExERdcsklevrpp3XkyJEGjzu9/nP9eUpy//ybsrHeiRMnGnyaOGTIEElyX/Omm26Sw+HQI4880uDxNTU19b5PZGRkm2/oB8A7nDx5UsuXL9c111yjG264ocFt1qxZKikpcW8RcP311+vrr79utAW163Xo+uuvV0FBQaMjKq5zunbtKpvNds73hrNp6ntTQkKCLrroIi1ZskT79+9vtB6X+Ph4XXXVVfrHP/6hF198UVdeeaW7a9vZfO9739P69evrtd8uKyvTM888o/T0dPXv37/Jz+tszvTecN1118lms2n+/PkNnpNhGO73neLiYtXU1NS7f+DAgbJarQ3eg5r6PvDd97SQkBD1799fhmGourrao/d3T977YC5GhtBkS5Ys0cqVKxscv+uuu/Too4/q/fff15gxY/STn/xEQUFBevrpp1VZWaknnnjCfW7//v11ySWXaNiwYerQoYO++OILvf7665o1a5Ykafv27brssst00003qX///goKCtKbb76pvLw8TZo06Zw1JiYm6tJLL9XChQtVUlLS4FOwvn37qkePHrr33nt16NAhxcTE6I033mjy3OZp06bpN7/5jaZNm6bhw4fr448/1vbt2xuc95vf/Eb//e9/lZGRoenTp6t///46fvy4Nm7cqA8++EDHjx8/4/cYPHiwbrvtNj3zzDMqLCzUxRdfrPXr12vp0qWaMGGCe1FwS4WHh6t///5atmyZevfurQ4dOmjAgAEaMGCAnnzySY0ZM0YDBw7U9OnT1b17d+Xl5WnNmjU6ePCgvv76a0nn/vOUpGHDhkmSfvaznykrK0s2m+2Mf5ZLly7VU089pWuvvVY9evRQSUmJnn32WcXExLhD4sUXX6wf/ehHWrBggb766itdccUVCg4O1o4dO/Taa6/pj3/8o2644Qb39/7rX/+qRx99VD179lRiYqL+3//7f63y8wPgXd5++22VlJTo+9//fqP3X3jhhe4NWCdOnKhf/OIXev3113XjjTfq9ttv17Bhw3T8+HG9/fbbys7O1uDBgzV58mS98MILmj17ttavX6+xY8eqrKxMH3zwgX7yk59o/Pjxio2N1Y033qg///nPslgs6tGjh/797397tHbEk/emP/3pTxozZozOP/983XnnnerWrZv27t2rd955R1999VW9cydPnux+PWzsA6TG3H///Xr55Zd11VVX6Wc/+5k6dOigpUuXas+ePXrjjTcaTPFrLtd7w69//WtNmjRJwcHBGjdunHr06KFHH31Uc+bM0d69ezVhwgRFR0drz549evPNN3XnnXfq3nvv1YcffqhZs2bpxhtvVO/evVVTU6O///3v7sBy+vf54IMPtHDhQnXq1EndunVzb/fxXVdccYWSk5M1evRoJSUlacuWLfrLX/6iq6++WtHR0ZKa/v7eo0cP2e12ZWdnKzo6WpGRkcrIyFC3bt1a5eeHVtSOnevgo1yttc90O3DggGEYhrFx40YjKyvLiIqKMiIiIoxLL73UWL16db1rPfroo8aIESMMu91uhIeHG3379jUee+wxo6qqyjAMwygoKDBmzpxp9O3b14iMjDRiY2ONjIwM49VXX21yvc8++6whyYiOjjZOnjzZ4P7//e9/RmZmphEVFWXEx8cb06dPN77++usGLTC/21rbMGpbn95xxx1GbGysER0dbdx0003G0aNHG21RnZeXZ8ycOdNIS0szgoODjeTkZOOyyy4znnnmmXM+h+rqamP+/PlGt27djODgYCMtLc2YM2dOvVbXhlHbWrux1tYXX3xxk1pKr1692hg2bJ
gREhLS4Dns2rXLmDx5spGcnGwEBwcbqampxjXXXGO8/vrr7nPO9edpGIZRU1Nj/PSnPzUSEhIMi8Vy1jbbGzduNG6++WajS5cuRmhoqJGYmGhcc801xhdffNHg3GeeecYYNmyYER4ebkRHRxsDBw407rvvPuPw4cPuc3Jzc42rr77aiI6ONiTRZhvwY+PGjTPCwsKMsrKyM54zZcoUIzg42N0W+dixY8asWbOM1NRUIyQkxOjcubNx22231WubXF5ebvz61792vx4nJycbN9xwQ72tJPLz843rr7/eiIiIMOLi4owf/ehHxubNmxttrR0ZGdlobU19bzIMw9i8ebNx7bXXGna73QgLCzP69OljzJ07t8E1Kysrjbi4OCM2NrbR98Mz2bVrl3HDDTe4rz9ixAjj3//+d71zXK21v9va+mzbUHzXI488YqSmphpWq7VBm+033njDGDNmjBEZGWlERkYaffv2NWbOnGls27bNMAzD2L17t3H77bcbPXr0MMLCwowOHToYl156qfHBBx/U+x5bt241LrroIiM8PNyQdNY2208//bRx0UUXGR07djRCQ0ONHj16GL/4xS+MoqKieuc19f39n//8p9G/f38jKCiINttezGIYzVzhBgAAAK9VU1OjTp06ady4cVq8eLHZ5QBeiTVDAAAAfuitt95Sfn5+vaYMAOpjZAgAAMCPrFu3Tps2bdIjjzyi+Ph4NgIFzoKRIQAAAD/y17/+VTNmzFBiYqJeeOEFs8sBvBojQwAAAAACEiNDAAAAAAISYQgAAABAQPKLTVedTqcOHz6s6OhoWSwWs8sBgIBiGIZKSkrUqVOnVtuQ0R/w3gQA5vDkfckvwtDhw4eVlpZmdhkAENAOHDigzp07m12G1+C9CQDM1ZT3Jb8IQ9HR0ZJqn3BMTIzJ1QBAYCkuLlZaWpr7tRi1eG8CAHN48r7kF2HINf0gJiaGNxwAMAlTwerjvQkAzNWU9yUmdwMAAAAISIQhAAAAAAGJMAQAAAAgIBGGAAAAAAQkwhAAAACAgEQYAgAAABCQCEMAAAAAAhJhCAAAAEBAIgwBAAAACEiEIQAAAAABiTAEAAAAICARhgAAAAAEJMIQAAAAgIBEGAIAAAAQkJoVhp588kmlp6crLCxMGRkZWr9+/RnPveSSS2SxWBrcrr76avc5hmFo3rx5SklJUXh4uDIzM7Vjx47mlAYAAAAATeJxGFq2bJlmz56tBx98UBs3btTgwYOVlZWlo0ePNnr+8uXLdeTIEfdt8+bNstlsuvHGG93nPPHEE/rTn/6k7OxsrVu3TpGRkcrKylJFRUXznxkAAAAAnIXHYWjhwoWaPn26pk6dqv79+ys7O1sRERFasmRJo+d36NBBycnJ7tv777+viIgIdxgyDEOLFi3SAw88oPHjx2vQoEF64YUXdPjwYb311lstenIAAAAAcCYehaGqqipt2LBBmZmZpy5gtSozM1Nr1qxp0jUWL16sSZMmKTIyUpK0Z88e5ebm1rtmbGysMjIyznjNyspKFRcX17sBAAAAgCc8CkMFBQVyOBxKSkqqdzwpKUm5ubnnfPz69eu1efNmTZs2zX3M9ThPrrlgwQLFxsa6b2lpaZ48DQAAAABo325yixcv1sCBAzVixIgWXWfOnDkqKipy3w4cONBKFQIAAAAIFB6Fofj4eNlsNuXl5dU7npeXp+Tk5LM+tqysTK+88oruuOOOesddj/PkmqGhoYqJial3AwAAAABPeBSGQkJCNGzYMOXk5LiPOZ1O5eTkaOTIkWd97GuvvabKykr94Ac/qHe8W7duSk5OrnfN4uJirVu37pzXBAAAAIDmCvL0AbNnz9Ztt92m4cOHa8SIEVq0aJHKyso0depUSdLkyZOVmpqqBQsW1Hvc4sWLNWHCBHXs2LHecYvForvvvluPPvqoevXqpW7dumnu3Lnq1KmTJkyY0PxnBgAAAABn4XEYmjhxovLz8zVv3jzl5uZqyJAhWr
lypbsBwv79+2W11h9w2rZtmz799FO99957jV7zvvvuU1lZme68804VFhZqzJgxWrlypcLCwprxlAAAAADg3CyGYRhmF9FSxcXFio2NVVFREeuHAKCd8RrcOH4uAGAOT15/27WbHAAAAAB4C8IQAAAAgIBEGAIAAAAQkAhDAAAAAAISYQgAAABAQCIMAQAAAAhIHu8z5I+e+2yP/r3piCySrBaLVPs/1f6npfb/LXX3qXaj2FP3135ttUhynXvaY63W2v/Xaee67rfW/bdO/z519/VLidHkkemy1V4YAADApxSUVurZj3crr7hCVw/qpEv7JCjIxufw8C6EIUn7j5drw74TZpfRwOpdx/THSUMUEcIfEwAA8A2llTX62ye79ezHu1VW5ZAkvfXVYSVEh+qGYZ110/A0dYuPNLlK8xSWV+lfm45o+caDOlJYoZuGd9a0i7orJizY7NICEpuuStpypFj7jpXJMCRDktMw3P/t+vHUfl173GnUHjck6bTj332s6s4xTjvf9djvXtN1ntMwVFZZo799ukdVNU4N6hyrv902XInRYa3wkwKA1sfmoo3j54JAU1Xj1Mvr9+tPOTt0rKxKkjSoc6yGd+2gt78+pILSKve5Gd06aNKINF01IEVhwTazSm43VTVOfbQ9X8s3HlTOlqOqcjjr3R8bHqw7L+quqaPT+RC8FXjy+ksY8lIb9h3XtKVf6ER5tVLt4Xp+6gXqlRRtdlkA0IA/vga3Bn4uCBROp6F/bTqs37+3XfuPl0uS0jtG6BdZffW9gcmyWCyqdjiVs+Woln2+Xx9tz5ez7rfP6LAgTRiSqokXpGlAaqyJz6L1GYahbw4VafnGQ3r768M6XnYqDPZPidF156cqITpUf/lwp3YcLZUkxUeF6CeX9NQtGV0CIiS2FcKQn9hbUKYpz63X3mPlig4L0tM/HKZRPeLNLgsA6vHX1+CW4ueCQPDJjnz95j9b9e3hYklSfFSo7srspUkXpCn4DOuDjhSd1OtfHNSyLw7o4ImT7uPndYrRpAvS9P0hqYoN990pY0eKTurNLw9p+cZD2lkXciQpITpUE4Z00nXnd1a/lFOvCQ6nobe/PqQ/vL/DHSZTYsP00//XSzcO73zGnyPOjDDkR46XVenOF77QF/tOKNhm0W+uG6Trh3U2uywAcPPn1+CW4OcCf7bpYKF+u3KrPtt5TJIUFRqkH13UXbeP6abI0KZN83I6Da3ZfUyvfH5A727OdU8dCw2y6nsDUzTxgjRldOsgi8X7m0mVVdbo3W9ztXzjIX22q0Cu365Dg6zKOi9Z152fqjE948/aQKLa4dTrGw7qTzk7dKSoQpLUpUOE7s7spfFDUmmq5QHCkJ+pqHbo3te+1r83HZEk3ZPZWz+7rKdPvDgA8H/+/hrcXPxc4I/2FpTp/97b5v6dJNhm0Q8vTNfMS3uoY1Ros697oqxKb355SMs+P6BteSXu493iI3Xj8M664fzOSozxrvXTDqehtbuP6Y2NB7Vyc67K65pFSNKIbh10/fmpumpgiseNESqqHXpp3X49tWqne51Vr8Qozb68t7LOS5aVUHROhCE/5HQa+t172/TXVbskSdef31kLrhuokCCGTgGYKxBeg5uDnwv8ydGSCv05Z6deXr9fNU5DFos0YUiqZl/eW2kdIlrt+xiGoa8PFmnZ5/v19leH3d3obFaLLu2TqEkXpOkSk1t07zxaouUbD+nNLw+5R3Ck2nVS153fWdcOTW2Vn0l5VY2eX71XT3+0W0UnqyXVTiW894o+uqRPAh+KnwVhyI+9tG6/5v5zsxxOQyO7d1T2D4f59LxaAL4vkF6DPcHPBf6gpKJaz368W3/7dI975OOSPgm6L6uv+ndq27/XZZU1euebI1r2+YF6W6AkRofqxuG1Lbq7dmyfFt3Hy6r0r68Pa/nGg/r6YJH7eExYkMYNrl0HdH4Xe5sElKKT1Vr8yW4t/nSPOxwO7xqnn1/RRyN7dGz17+
cPCEN+btW2o5r54kaVVTnUKzFKS6Zc0KqfygCAJwLtNbip+LnAl1XW1E7V+vOHO91d0Aan2XX/lX1N+QV859ESLfv8gN7YeKheV7aR3Ttq4gVpunJAcqt3X6uscei/W/P1xsaD+u/Wo6qpa4EXZLXokj4Juv78zrq0b2K7dX07Vlqppz/eraWr96qypnZ91Zie8fr5Fb01tEtcu9TgKwhDAeB/h4t1+/OfK7e4QvFRoVoyZbgGdbabXRaAABSIr8FNwc8FvsjpNPT214f1f+9tc3d66x4fqfuu7KOs85JNn5pVVeNUzpY8vfL5AX28I9/dqCAmLEjXDk3VTRek6bxOzW/RbRiGvjpQqOUbD+lfmw6rsLzafd/A1Fhdd36qxg3upPgWrI9qqbziCv3lw5165fP9qnbU/gAy+yVq9uV92ny0zlcQhgLEkaKTmvrc59qaW6LwYJv+fPNQZfZPMrssAAEmUF+Dz4WfC3yJYRj6aHu+frtym7YcqW2TnRgdqrsze+um4Z1NXaNzJocKT+q1Lw7otS8O6lDhqRbdA1NjddMFaRo/pFOTmxccPFGut+raYe8uKHMfT4oJ1bVDO+u681PV28v2ezxwvFx/ytmhNzYedO/bdPWgFN2T2Vs9E6PMLc5khKEAUlJRrZkvfamPt+fLapHmXdNfU0Z3M7ssAAEkkF+Dz4afC3zFVwcK9Zv/bNHa3cclSdGhQfrxJT10++huCg/x/o0/HU5Dn+0s0LLPD+i9/+W6R0vCgmtbdE+6oIsuSI9rMKpVWlmj/3xzRMs3HtKa3cfcx8ODbbpyQG077FE94r2+pfWu/FIt+mCH/vX1YUmS1SJdd35n3XVZr4BdRkEYCjDVDqfm/XOzXl5/QJJ0++hu+vXV/bz+Hy8A/xDor8Fnws8F3m53fqn+771tWvFNriQpxGbV5JFdNfPSnoqLDDG5uuY5VlrpbtG947QNT7vHR+qmC9J07dBUbcst0fKNB7Xy21xVVNeuvbFYatcfXXd+Z105IFlRTdwryZtsOVKs37+3XR9syZNU2/Z84gVp+un/66UkL2tL3tYIQwHIMAxlf7Rbv125VZJ0Rf8k/XHSUJ/4RAeAb+M1uHH8XOCtjhZXaFHODi37/IAcdW2yrxvaWfdc3kud4/xjJMEwDH15oFDL1h/QvzYdrrcH0Om6J0Tq+vM7a8LQVKXaw9u5yrbx1YFC/f69bfpkR4Gk2o1fJ4/sqh9f3LK9oHwJYSiA/evrw/r5a1+rqsapwWl2/W3ycCVEB8ZffADm4DW4cfxc4G2KK6r1zEe1LZpPVteGg8v6JuoXV/ZR32T//TtaWlmjdzYd1iufH9CX+wtljwjW9+vaYQ/uHGt6U4i2snb3Mf3fu9v0RV1b8sgQm24f003Txnb3+21ZCEMB7vO9xzX9hS9UWF6tznHhen7qBeqZ6F2L/gD4D16DG8fPBd6issahv6/Zpyf/u1Mn6rqjnd/Frvuv6qcR3TqYXF37OlZaqeiw4IDZtN7VGOP3723XN4dq90eKCQvSjy7uoSmj0hXpg9MBm4IwBO0pKNOU59Zr37FyxYQF6ekfDmdjLgBtgtfgxvFzgdkcTkNvfXlIC9/f7u621iMhUvdd2VdX9E/y2xERNGQYht79Nk8L39+m7Xm1a6nio0I045KeujWjS7vtldReCEOQVPvpx/QXvtDG/YUKtln0xA2DdO3QzmaXBcDP8BrcOH4uMIthGFq1LV+/XblVW3NLJEnJMWG65/Jeuv5872yTjfbhcBr619eH9YcPtmvfsXJJtX83fnpZT900PE3BfvJ3gzAEt4pqh37+6td655sjkqTZl/fWT/9fTz4NAtBqeA1uHD8XmOFocYV++vKXWrentk12TFiQZlzSU1NGpdNUCW7VDqfe2HBQf8rZocNFFZKkLh0idFm/RFm96HfEi3on6OLeCR4/zpPXX/+cKAi3sLrNWDt3CNfTH+3Wwve368Dxcj1+3UC/Sf8AAKDWo+
9s0bo9xxUSZNXUUemacUkP2SN8s0022k6wzapJI7ro2vNT9fK6/frLf3dp//FyPffZXrNLqycqNKhZYcgThKEAYLVaNOeqfkqLi9C8f27WaxsO6nDRST116zC/7yYCAECg2H+sXP/eVLvx5us/HqlBne3mFgSvFxpk05TR3XTTBWl6Y8NB9yiRt7ggve0bfBCGAsgPLuyqVHu4Zr60UZ/tPKYbs1dryZQL/GZPAQAAAtkzn+yS05Au7p1AEIJHIkKC9MOR6WaXYQrmSQWYS/sm6tUfjVRSTKi255Xq2qdW65uDRWaXBQAAWiC/pFKvfXFQkjTjkh4mVwP4DsJQABqQGqs3fzJafZOjlV9SqZueXqOcLXlmlwUAAJrp+dV7VFnj1NAudmUE2N5BQEsQhgJUJ3u4XvvxSI3tFa+T1Q5Nf+ELvbBmr9llAQAAD5VUVOuFNfskSTMu7kHHWMADhKEAFh0WrCVTLtCkC9LkNKR5//xWj/77f3I6fb7bOgAAAeOldftVUlGjnolRyuyXZHY5gE8hDAW4YJtVC64bqF9k9ZEk/e3TPfrJixt1ssphcmUAAOBcKmscWvzpHknSjy7qLquVUSHAE4QhyGKxaOalPfXHSUMUYrNq5be5uvnZtSoorTS7NAAAcBZvbjykoyWVSokN0/ghqWaXA/gcwhDcxg9J1T+mZcgeEayvDhTq2qc+0678UrPLAgAAjXA4DT398W5J0rSx3RUSxK91gKf4V4N6RnTroDdmjFKXDhE6cPykrntqtdbtPmZ2WQAA4Dve/TZXewrKFBserEkXpJldDuCTCENooEdClN78ySgN7WJX0clq/XDxen28Pd/ssgAAQB3DMJT90S5J0m2j0hUZGmRyRYBvIgyhUR2jQvXy9AuV2S9JVQ6nnl+91+ySAABAndW7jmnTwSKFBVs1ZVS62eUAPoswhDMKC7bppuGdJUnHy6pMrgYAALj8dVXtqNCkC7qoQ2SIydUAvoswhLOKq3uBLSwnDAEA4A02HSzUpzsLZLNaNG1sN7PLAXwaYQhnZQ8PliQVnqw2uRIAACDJvVZo/OBO6hwXYXI1gG8jDOGsYiNqw1DRyWo5nYbJ1QAAENj2FJTpP5tzJUk/uriHydUAvo8whLOyh9dOkzMMqbiC0SEAAMz0zMe7ZBhSZr9E9UmONrscwOcRhnBWIUFWRYbYJEmF5YQhAADMkldcoTc2HJIk/ZhRIaBVEIZwTvaI2tGhEzRRAADANEs+3aMqh1MXpMdpeHoHs8sB/AJhCOdkj6CJAgAAZio6Wa0X1+2XJM24hFEhoLUQhnBOrjBUxDQ5AABM8Y+1+1RaWaM+SdG6tE+i2eUAfoMwhHNimhwAAOapqHbouc/2SJJ+fEl3WSwWkysC/AdhCOfk3muIkSEAANrdaxsOqqC0Sqn2cF0zqJPZ5QB+hTCEc4qrGxkqZGQIAIB2VeNw6tmPd0uS7ryou4Jt/OoGtCb+ReGcaKAAAIA5VmzO1f7j5eoQGaKbhqeZXQ7gdwhDOKdYpskBANDuDMPQX1ftkiRNGZWu8Lp9/wC0HsIQzolpcgB8wZNPPqn09HSFhYUpIyND69evP+v5ixYtUp8+fRQeHq60tDTdc889qqiocN//0EMPyWKx1Lv17du3rZ8G4PbR9nxtOVKsiBCbJo/sanY5gF8KMrsAeD+myQHwdsuWLdPs2bOVnZ2tjIwMLVq0SFlZWdq2bZsSExu2IX7ppZd0//33a8mSJRo1apS2b9+uKVOmyGKxaOHChe7zzjvvPH3wwQfur4OCeNtE+8n+qHZU6JYRXdydXQG0LkaGcE7u1tpljAwB8E4LFy7U9OnTNXXqVPXv31/Z2dmKiIjQkiVLGj1/9erVGj16tG655Ralp6friiuu0M0339xgNCkoKEjJycnuW3x8fHs8HUAb95/Q2t3HFWyz6I6x3cwuB/BbhCGck2tkqLiiRg6nYXI1AFBfVVWVNmzYoMzMTPcxq9WqzM
xMrVmzptHHjBo1Shs2bHCHn927d2vFihX63ve+V++8HTt2qFOnTurevbtuvfVW7d+/v+2eCHCa7Lq1QhOGpColNtzkagD/xXg/zsnVQEGSik9WKy6SoXoA3qOgoEAOh0NJSUn1jiclJWnr1q2NPuaWW25RQUGBxowZI8MwVFNTox//+Mf61a9+5T4nIyNDzz//vPr06aMjR45o/vz5Gjt2rDZv3qzo6OgG16ysrFRlZaX76+Li4lZ6hgg0O4+W6L3/5clikX50cXezywH8GiNDOKdgm1XRobW5+QRNFAD4gVWrVunxxx/XU089pY0bN2r58uV655139Mgjj7jPueqqq3TjjTdq0KBBysrK0ooVK1RYWKhXX3210WsuWLBAsbGx7ltaGm2Q0TxPf1S7r9AV/ZPUM7Fh8AbQeghDaJJYmigA8FLx8fGy2WzKy8urdzwvL0/JycmNPmbu3Ln64Q9/qGnTpmngwIG69tpr9fjjj2vBggVyOp2NPsZut6t3797auXNno/fPmTNHRUVF7tuBAwda9sQQkA4XntRbXx2SJP344h4mVwP4P8IQmoT22gC8VUhIiIYNG6acnBz3MafTqZycHI0cObLRx5SXl8tqrf8WaLPV7uFiGI2vjSwtLdWuXbuUkpLS6P2hoaGKiYmpdwM8tfjTPap2GLqwewcN7RJndjmA32PNEJrE3V6bjVcBeKHZs2frtttu0/DhwzVixAgtWrRIZWVlmjp1qiRp8uTJSk1N1YIFCyRJ48aN08KFCzV06FBlZGRo586dmjt3rsaNG+cORffee6/GjRunrl276vDhw3rwwQdls9l08803m/Y84d8Ky6v08vraJh0zLulpcjVAYCAMoUlcTRQIQwC80cSJE5Wfn6958+YpNzdXQ4YM0cqVK91NFfbv319vJOiBBx6QxWLRAw88oEOHDikhIUHjxo3TY4895j7n4MGDuvnmm3Xs2DElJCRozJgxWrt2rRISEtr9+SEwvLBmn8qrHOqfEqOLetHGHWgPFuNM8wF8SHFxsWJjY1VUVMS0hDYy963N+vvaffrZ/+up2Vf0MbscAF6E1+DG8XOBJ8qrajT6Nx/qRHm1/nzzUI0b3MnskgCf5cnrL2uG0CR2GigAANBmXv38gE6UV6tLhwhdNaDxxh8AWh9hCE1ir2ugcIJpcgAAtKpqh1PPfrJHknTnRd0VZOPXM6C98K8NTWJ3rxmimxwAAK3p35sO61DhScVHheqGYZ3NLgcIKIQhNIlrmlwR0+QAAGg1Tqehv67aJUm6fUy6woJtJlcEBBbCEJrk1DQ5RoYAAGgt/912VNvzShUVGqRbM7qaXQ4QcAhDaBL2GQIAoPW5RoVuvbCLexsLAO2HMIQmiasbGSqpqFGNw2lyNQAA+L7P9x7XF/tOKMRm1R2ju5ldDhCQCENokpiwU/vzsm4IAICWy64bFbp+WGclxoSZXA0QmAhDaJIgm1XRdYGIvYYAAGiZrbnFytl6VBZLbTttAOYgDKHJXFPlaK8NAEDLPP3RbknS9wakqFt8pMnVAIGLMIQmo4kCAAAtd/BEud7++rAk6ccX9zC5GiCwEYbQZKfaaxOGAABorr99skcOp6GxveI1sHOs2eUAAY0whCazh7tGhpgmBwBAcxwrrdQrn++XxKgQ4A0IQ2gy1zQ5uskBANA8S1fvVUW1U4M6x2pUj45mlwMEPMIQmuzUNDlGhgAA8FRZZY2WrtknSZpxcQ9ZLBaTKwJAGEKTnZomx8gQAACeenn9fhWdrFb3+EhdcV6y2eUAUDPD0JNPPqn09HSFhYUpIyND69evP+v5hYWFmjlzplJSUhQaGqrevXtrxYoV7vsfeughWSyWere+ffs2pzS0obhIwhAAAM1RVePU3z7ZI6l2XyGblVEhwBsEefqAZcuWafbs2crOzlZGRoYWLVqkrKwsbdu2TYmJiQ3Or6qq0uWXX67ExES9/vrrSk1N1b59+2S32+udd9555+mDDz
44VViQx6WhjdnD6/YZOsk0OQAAPPHWV4eUW1yhxOhQXXt+qtnlAKjjceJYuHChpk+frqlTp0qSsrOz9c4772jJkiW6//77G5y/ZMkSHT9+XKtXr1ZwcO3IQnp6esNCgoKUnMyQsTeLZZ8hAAA85nQaevqjXZKkaWO7KTTIZnJFAFw8miZXVVWlDRs2KDMz89QFrFZlZmZqzZo1jT7m7bff1siRIzVz5kwlJSVpwIABevzxx+VwOOqdt2PHDnXq1Endu3fXrbfeqv379zfj6aAtxdU1UCAMAQDQdO9vydOu/DLFhAXp5hFdzC4HwGk8GhkqKCiQw+FQUlJSveNJSUnaunVro4/ZvXu3PvzwQ916661asWKFdu7cqZ/85Ceqrq7Wgw8+KEnKyMjQ888/rz59+ujIkSOaP3++xo4dq82bNys6OrrBNSsrK1VZWen+uri42JOngWZyNVAoraxRtcOpYBv9NwAAOBvDMPTUqtpRoR+O7KrosGCTKwJwujZfmON0OpWYmKhnnnlGNptNw4YN06FDh/S73/3OHYauuuoq9/mDBg1SRkaGunbtqldffVV33HFHg2suWLBA8+fPb+vS8R0x4cGyWCTDqB0dSogONbskAAC82trdx/X1gUKFBlk1ZVQ3s8sB8B0efbQfHx8vm82mvLy8esfz8vLOuN4nJSVFvXv3ls12an5sv379lJubq6qqxhfi2+129e7dWzt37mz0/jlz5qioqMh9O3DggCdPA81ks1oUE+baeJUmCgAAnEt23Vqhm4an8SEi4IU8CkMhISEaNmyYcnJy3MecTqdycnI0cuTIRh8zevRo7dy5U06n031s+/btSklJUUhISKOPKS0t1a5du5SSktLo/aGhoYqJial3Q/uw00QBAIAm+fZwkT7ani+b1aI7L+pudjkAGuHxoo/Zs2fr2Wef1dKlS7VlyxbNmDFDZWVl7u5ykydP1pw5c9znz5gxQ8ePH9ddd92l7du365133tHjjz+umTNnus+599579dFHH2nv3r1avXq1rr32WtlsNt18882t8BTRmux1TRROEIYAADir7I92S5KuHpiitA4RJlcDoDEerxmaOHGi8vPzNW/ePOXm5mrIkCFauXKlu6nC/v37ZbWeylhpaWl69913dc8992jQoEFKTU3VXXfdpV/+8pfucw4ePKibb75Zx44dU0JCgsaMGaO1a9cqISGhFZ4iWpOriUJhOdPkAAA4k33HyvTOpsOSpB9f3MPkagCcSbMaKMyaNUuzZs1q9L5Vq1Y1ODZy5EitXbv2jNd75ZVXmlMGTBDHNDkAAM7p2U92y2lIl/RJUP9OTOcHvBW9keER1zS5QhooAADQqPySSr36xUFJ0gxGhQCvRhiCR2LDGRkCAOBsnvtsj6pqnBraxa4R3TqYXQ6AsyAMwSNMkwMA4MxKKqr197X7JNWOClksFpMrAnA2hCF4hGlyAACc2Uvr9qukokY9E6OU2S/J7HIAnANhCB5x7TN0ooyRIQAATldR7dDfPt0jqbaDnNXKqBDg7QhD8IhrZKjoJGEIAIDTvfnlIeWXVColNkzfH9zJ7HIANAFhCB5hnyEAABpyOA09/dEuSdK0sd0VEsSvWIAv4F8qPBJXNzJUVuVQVY3T5GoAAPAO736bq73HymWPCNakC9LMLgdAExGG4JHosCC5pkDTRAEAAMkwDP11Ve2o0G0j0xUZ2qw97QGYgDAEj1itFvYaAgDgNJ/tPKZvDhUpLNiq20alm10OAA8QhuAxd3ttwhAAAPrrRzslSZMu6KIOkSEmVwPAE4QheCyWJgoAAEiSTpRV6bOdx2SxSNPGdjO7HAAeIgzBY3ERTJMDAECScosrJEkdI0PUOS7C5GoAeIowBI+5p8nRQAEAEODySyolSfFRoSZXAqA5CEPwmL1uZOgEI0MAgADnCkMJ0YQhwBcRhuAxezgNFAAAkKSC0rowxMgQ4JMIQ/CYa2SoiGlyAIAAx8gQ4NsIQ/
CYe5pcGSNDAIDAll9KGAJ8GWEIHjvVQIEwBAAIbIwMAb6NMASPnWqtzTQ5AEBgc60Zopsc4JsIQ/AYDRQAAKjFyBDg2whD8Fhs3cjQyWqHKqodJlcDAIA5qmqc7m0m6CYH+CbCEDwWExYkm9UiSSpi3RAAIEAdK6sdFQq2WRQbHmxyNQCagzAEj1ksp170mSoHAAhUrilyHSNDZa37kBCAbyEMoVnc7bVpogAACFAFtNUGfB5hCM1iZ2QIABDgaJ4A+D7CEJrFtddQ0UlGhgAAgckdhmieAPgswhCa5dQ0OUaGAACByRWG4qNDTK4EQHMRhtAs7DUEAAh0BaW1syMYGQJ8F2EIzRIX4VozxDQ5AEBgOrVmKMzkSgA0F2EIzWKPoIECACCw5dNNDvB5hCE0S2xdA4VCGigAAAIU3eQA30cYQrPEMTIEAAhgJ6scKq2skSTFR9FAAfBVhCE0Cw0UAACBzLXhaliwVVGhQSZXA6C5CENollOttZkmBwAIPEdPmyJnsVhMrgZAcxGG0CyuMFRZ41RFtcPkagAAaF9suAr4B8IQmiUqNEhB1tpPwpgqBwAINK5pcvGEIcCnEYbQLBaLhalyAICARSc5wD8QhtBsseF0lAMABCb2GAL8A2EIzRbn2muIkSEAQIBhZAjwD4QhNJtrmlzhSUaGAACBhTVDgH8gDKHZYtlrCAAQoBgZAvwDYQjNFucaGWKaHAAggBiGQWttwE8QhtBs7mlyjAwBAAJISWWNKmuckhgZAnwdYQjNZq9roEBrbQBAICmoGxWKDg1SWLDN5GoAtARhCM1GAwUAQCBivRDgPwhDaDZ7XQOFIqbJAQACiGuPoXjCEODzCENoNtfIENPkAACBhJEhwH8QhtBsp0+TMwzD5GoAAGgfdJID/AdhCM0WV9dAoarGqZPVDpOrAQCgfbg2XGVkCPB9hCE0W0SITcE2iyTaawMAAgcjQ4D/IAyh2SwWi2LrmigQhgAAgSKfkSHAbxCG0CJx7o1XaaIAAAgMrpGheEaGAJ9HGEKLsNcQACCQOJ2GjpXWfgDIyBDg+whDaBHXNDnaawMAAkHhyWrVOGs7qHaMCjG5GgAtRRhCi5yaJsfIEADA/7mmyHWIDFGwjV+jAF/Hv2K0iGuaXBHT5AAAAeDUeiFGhQB/QBhCi9jr9ho6UcY0OQCA/2OPIcC/EIbQIjRQAOAtnnzySaWnpyssLEwZGRlav379Wc9ftGiR+vTpo/DwcKWlpemee+5RRUVFi64J/8ceQ4B/IQyhRezufYYYGQJgnmXLlmn27Nl68MEHtXHjRg0ePFhZWVk6evRoo+e/9NJLuv/++/Xggw9qy5YtWrx4sZYtW6Zf/epXzb4mAgN7DAH+hTCEFqGBAgBvsHDhQk2fPl1Tp05V//79lZ2drYiICC1ZsqTR81evXq3Ro0frlltuUXp6uq644grdfPPN9UZ+PL0mAoN7ZIgwBPgFwhBaJJZpcgBMVlVVpQ0bNigzM9N9zGq1KjMzU2vWrGn0MaNGjdKGDRvc4Wf37t1asWKFvve97zX7mggMrjVDbLgK+IcgswuAb4uLODVNzjAMWSwWkysCEGgKCgrkcDiUlJRU73hSUpK2bt3a6GNuueUWFRQUaMyYMTIMQzU1Nfrxj3/snibXnGtWVlaqsrLS/XVxcXFLnha8FCNDgH9hZAgt4mqgUO0wVF7lMLkaAGiaVatW6fHHH9dTTz2ljRs3avny5XrnnXf0yCOPNPuaCxYsUGxsrPuWlpbWihXDWxCGAP9CGEKLhAfbFFK36dwJmigAMEF8fLxsNpvy8vLqHc/Ly1NycnKjj5k7d65++MMfatq0aRo4cKCuvfZaPf7441qwYIGcTmezrjlnzhwVFRW5bwcOHGidJwivUeNw6njdex3d5AD/QBhCi1gsllPttWmiAMAEISEhGjZsmHJyctzHnE6ncnJyNHLkyEYfU15eLqu1/lugzWaTJBmG0axrhoaGKiYmpt
4N/uV4WZUMQ7JZLe5p4gB8G2uG0GL2iGAdLalUEU0UAJhk9uzZuu222zR8+HCNGDFCixYtUllZmaZOnSpJmjx5slJTU7VgwQJJ0rhx47Rw4UINHTpUGRkZ2rlzp+bOnatx48a5Q9G5ronAc7RuilzHyBBZrayRBfwBYQgtZq/7dIxpcgDMMnHiROXn52vevHnKzc3VkCFDtHLlSncDhP3799cbCXrggQdksVj0wAMP6NChQ0pISNC4ceP02GOPNfmaCDzsMQT4H4thGIbZRbRUcXGxYmNjVVRUxLQEE9z5whd67395enTCAP3gwq5mlwOgnfEa3Dh+Lv7n1S8O6L7XN+mSPgl6fuoIs8sBcAaevP6yZggtdmrNECNDAAD/xR5DgP8hDKHFTu01xJohAID/oq024H8IQ2ixWNfIEA0UAAB+zB2GGBkC/AZhCC12amSIaXIAAP/FyBDgfwhDaDF7OPsMAQD8H2uGAP9DGEKLuabJ0VobAODPGBkC/A9hCC3mmibHpqsAAH9VUe1QcUWNJMIQ4E8IQ2ixU621q+UH21YBANCAa4pcSJBVMWHsWQ/4C8IQWsw1MlTjNFRaWWNyNQAAtL7TO8lZLBaTqwHQWghDaLGwYJtCg2r/KtFEAQDgjwpKa9fFxjNFDvArhCG0itOnygEA4G/YYwjwT4QhtAr3XkMn6SgHAPA/dJID/FOzwtCTTz6p9PR0hYWFKSMjQ+vXrz/r+YWFhZo5c6ZSUlIUGhqq3r17a8WKFS26JrxLLHsNAQD8WH5phSQpISrE5EoAtCaPw9CyZcs0e/ZsPfjgg9q4caMGDx6srKwsHT16tNHzq6qqdPnll2vv3r16/fXXtW3bNj377LNKTU1t9jXhfdwjQ+w1BADwQwUlte9vjAwB/sXjMLRw4UJNnz5dU6dOVf/+/ZWdna2IiAgtWbKk0fOXLFmi48eP66233tLo0aOVnp6uiy++WIMHD272NeF9WDMEAPBn+aVMkwP8kUdhqKqqShs2bFBmZuapC1ityszM1Jo1axp9zNtvv62RI0dq5syZSkpK0oABA/T444/L4XA0+5qVlZUqLi6ud4O5YuvC0AnCEADAD7FmCPBPHoWhgoICORwOJSUl1TuelJSk3NzcRh+ze/duvf7663I4HFqxYoXmzp2r3//+93r00Uebfc0FCxYoNjbWfUtLS/PkaaAN0EABAODPXGEonm5ygF9p825yTqdTiYmJeuaZZzRs2DBNnDhRv/71r5Wdnd3sa86ZM0dFRUXu24EDB1qxYjSHva6BQhEjQwAAP1NWWaOT1bUzWghDgH8J8uTk+Ph42Ww25eXl1Tuel5en5OTkRh+TkpKi4OBg2Ww297F+/fopNzdXVVVVzbpmaGioQkN5MfIm9rqRoRM0UAAA+BnXqFBkiE2RoR796gTAy3k0MhQSEqJhw4YpJyfHfczpdConJ0cjR45s9DGjR4/Wzp075XQ63ce2b9+ulJQUhYSENOua8D7uBgonGRkCAPgXmicA/svjaXKzZ8/Ws88+q6VLl2rLli2aMWOGysrKNHXqVEnS5MmTNWfOHPf5M2bM0PHjx3XXXXdp+/bteuedd/T4449r5syZTb4mvB/d5AAA/ormCYD/8nisd+LEicrPz9e8efOUm5urIUOGaOXKle4GCPv375fVeipjpaWl6d1339U999yjQYMGKTU1VXfddZd++ctfNvma8H6n7zPkdBqyWi0mVwQAQOsoKKV5AuCvmjXxddasWZo1a1aj961atarBsZEjR2rt2rXNvia8X2xdAwWnIZVW1SgmLNjkigAAaB2MDAH+q827ySEwhAXbFB5c2ySjsIypcgAA/+EOQ4wMAX6HMIRWc6qJAh3lAAD+g5EhwH8RhtBqXFPlTtBEAQDgR1gzBPgvwhBazelNFAAA8BeMDAH+izCEVuOaJlfEXkMAAD9hGAb7DAF+jDCEVmOvGxk6QQMFAICfKDpZrWqHIUnqGBVicjUAWh
thCK2GBgoAAH/jWi8UGx6s0CCbydUAaG2EIbQae10DhUIaKAAA/MRR1gsBfo0whFZDAwUAgL9hjyHAvxGG0Gpi3dPkGBkCAPgHOskB/o0whFZzamSIMAQA8A/57DEE+DXCEFqNu4EC0+QAAH6ioKT2PY2RIcA/EYbQalwNFIpOVsvpNEyuBgCAlmOPIcC/EYbQalxrhpyGVFJRY3I1AAC0HGuGAP9GGEKrCQ2yKSKkdg8G9hoCAPgDVxiKZ8NVwC8RhtCqXE0UTtBEAQDg4xxOQ8fLGBkC/BlhCK0qNpwmCgAA/3C8rEpOQ7JapI6RhCHAHxGG0KpOdZRjZAgA4NtcU+Q6RIbKZrWYXA2AtkAYQqs6tdcQI0MAAN92ao8h1gsB/oowhFbl6ihXeJKRIQCAbyugkxzg9whDaFVxTJMDAPgJ9hgC/B9hCK3KHs40OQCAf2CPIcD/EYbQqlzT5GitDQDwde4wFEUYAvwVYQityt1AgTVDAAAfV8A0OcDvEYbQqlyttYuYJgcA8HGMDAH+jzCEVhXHNDkAgJ+ggQLg/whDaFWxdQ0Uiiuq5XAaJlcDAEDzVNU43Z1RCUOA/yIMoVXFhteODBmGVMy6IQCAjzpWVjsqFGyzuN/bAPgfwhBaVUiQVVGhQZJoogAA8F2u9ULxUaGyWCwmVwOgrRCG0Opcn6Cx1xAAwFexxxAQGAhDaHVxka4wxMgQAMA30UkOCAyEIbQ6e7hrryFGhgAAvsm1x1A8YQjwa4QhtLrYCEaGAAC+jWlyQGAgDKHVsdcQAMDXsccQEBgIQ2h1rmlyRTRQAAD4KEaGgMBAGEKrszMyBADwcae31gbgvwhDaHX2CFcDBcIQAMA3FZTWzm5gZAjwb4QhtDp73T5DTJMDAPii8qoalVbWSCIMAf6OMIRW59pniGlyAABfVFBS+2FeeLBNkSE2k6sB0JYIQ2h1sa59hhgZAgD4oPzSCklSfHSILBaLydUAaEuEIbQ6V2vt4ooa1TicJlcDAIBn8utGhhJongD4PcIQWl1s3ZohqTYQAQDgS9hjCAgchCG0uiCbVdGhQZKYKgcA8D3sMQQEDsIQ2oSdJgoAAB/FHkNA4CAMoU3Y65ooFJ1kZAgA4FsKmCYHBAzCENqEva6JwokyRoYAAL7FPU2OkSHA7xGG0CbsEXXttU8ShgAAvoU1Q0DgIAyhTdjrOsoV0UABAOBDDMNwd5NjzRDg/whDaBOuvYZooAAA8CUllTWqqqndI4+RIcD/EYbQJmKZJgcA8EGuKXLRYUEKC7aZXA2AtkYYQptwjQyxzxAAwJewXggILIQhtAm7OwwxMgQA8B10kgMCC2EIbSI23DVNjpEhAIDvcO0xFM/IEBAQCENoE+5pcuwzBADwIYwMAYGFMIQ24dpnqKSyRtUOp8nVAADQNKwZAgILYQhtIrZunyFJKqKjHADAR7j2GCIMAYGBMIQ2YbNaFBMWJIkmCgAA3+FaM8Q0OSAwEIbQZlxT5YpoogAA8BFMkwMCC2EIbcbVROEETRQAAD7A6TRUUFr7AR5hCAgMhCG0mdgIV3ttwhAAwPudKK+Sw2nIYpE6RIaYXQ6AdkAYQptxt9cuZ5ocAMD7uZonxEWEKNjGr0hAIOBfOtqMPdwVhhgZAgB4v4KSuilyNE8AAgZhCG3m1DQ5RoYAtL0nn3xS6enpCgsLU0ZGhtavX3/Gcy+55BJZLJYGt6uvvtp9zpQpUxrcf+WVV7bHU4FJ8ksrJLFeCAgkQWYXAP/lbqDAyBCANrZs2TLNnj1b2dnZysjI0KJFi5SVlaVt27YpMTGxwfnLly9XVdWpD2qOHTumwYMH68Ybb6x33pVXXqnnnnvO/XVoKL8k+zM6yQGBh5EhtBl7XRgqIgwBaGMLFy7U9OnTNXXqVPXv31/Z2dmKiIjQkiVLGj2/Q4cOSk5Odt/ef/99RURENAhDoaGh9c6Li4trj6cDk7jCUHwUzROAQEEYQptx7TN0ggYKAN
pQVVWVNmzYoMzMTPcxq9WqzMxMrVmzpknXWLx4sSZNmqTIyMh6x1etWqXExET16dNHM2bM0LFjx854jcrKShUXF9e7wbfQVhsIPIQhtBkaKABoDwUFBXI4HEpKSqp3PCkpSbm5ued8/Pr167V582ZNmzat3vErr7xSL7zwgnJycvTb3/5WH330ka666io5HI5Gr7NgwQLFxsa6b2lpac1/UjAF0+SAwMOaIbQZ18hQEfsMAfBiixcv1sCBAzVixIh6xydNmuT+74EDB2rQoEHq0aOHVq1apcsuu6zBdebMmaPZs2e7vy4uLiYQ+Rh3GIoKM7kSAO2FkSG0GVcDhdLKGlXVOE2uBoC/io+Pl81mU15eXr3jeXl5Sk5OPutjy8rK9Morr+iOO+445/fp3r274uPjtXPnzkbvDw0NVUxMTL0bfItrn6H4aNYMAYGCMIQ2Ex0WLIul9r8ZHQLQVkJCQjRs2DDl5OS4jzmdTuXk5GjkyJFnfexrr72myspK/eAHPzjn9zl48KCOHTumlJSUFtcM71PtcLrXuLLPEBA4CENoMzarRbHudUM0UQDQdmbPnq1nn31WS5cu1ZYtWzRjxgyVlZVp6tSpkqTJkydrzpw5DR63ePFiTZgwQR07dqx3vLS0VL/4xS+0du1a7d27Vzk5ORo/frx69uyprKysdnlOaF/Hy6pkGLXvXXERjAwBgYI1Q2hT9vBgFZZXq5CRIQBtaOLEicrPz9e8efOUm5urIUOGaOXKle6mCvv375fVWv/zv23btunTTz/Ve++91+B6NptNmzZt0tKlS1VYWKhOnTrpiiuu0COPPMJeQ37q9LbaVqvF5GoAtBfCENpUbESIdKycjnIA2tysWbM0a9asRu9btWpVg2N9+vSRYRiNnh8eHq533323NcuDlzsVhgi7QCBhmhzalKuJAnsNAQC8mat5Am21gcBCGEKbcu01VMTIEADAi51qq00YAgIJYQhtyrXXECNDAABvxoarQGAiDKFN2eumydFAAQDgzdx7DDEyBAQUwhDaFNPkAAC+oICRISAgEYbQpuIimSYHAPB+NFAAAhNhCG3q1KarjAwBALwXa4aAwNSsMPTkk08qPT1dYWFhysjI0Pr168947vPPPy+LxVLvFhYWVu+cKVOmNDjnyiuvbE5p8DKuXbwLGRkCAHipimqHSipqJBGGgEDj8aary5Yt0+zZs5Wdna2MjAwtWrRIWVlZ2rZtmxITExt9TExMjLZt2+b+2mJpuLPzlVdeqeeee879NTt8+wcaKAAAvF1B3RS5kCCrokPZjx4IJB6PDC1cuFDTp0/X1KlT1b9/f2VnZysiIkJLliw542MsFouSk5Pdt6SkpAbnhIaG1jsnLi7O09LghezhtSND5VUOVdY4TK4GAICGTt9jqLEPbAH4L4/CUFVVlTZs2KDMzMxTF7BalZmZqTVr1pzxcaWlperatavS0tI0fvx4ffvttw3OWbVqlRITE9WnTx/NmDFDx44d86Q0eKnosCBZ695X6CgHAPBGrBcCApdHYaigoEAOh6PByE5SUpJyc3MbfUyfPn20ZMkS/fOf/9Q//vEPOZ1OjRo1SgcPHnSfc+WVV+qFF15QTk6Ofvvb3+qjjz7SVVddJYej8ZGEyspKFRcX17vBO1mtllNNFJgqBwDwQnSSAwJXm0+MHTlypEaOHOn+etSoUerXr5+efvppPfLII5KkSZMmue8fOHCgBg0apB49emjVqlW67LLLGlxzwYIFmj9/fluXjlYSFxGiE+XVOlFGEwUAgPcpKKl9f2LDVSDweDQyFB8fL5vNpry8vHrH8/LylJyc3KRrBAcHa+jQodq5c+cZz+nevbvi4+PPeM6cOXNUVFTkvh04cKDpTwLtLpYmCgAAL5ZfWiGJkSEgEHkUhkJCQjRs2DDl5OS4jzmdTuXk5NQb/Tkbh8Ohb775RikpKWc85+DBgzp27NgZzwkNDVVMTEy9G7yXvW6aHGuGAADeiD
VDQODyuJvc7Nmz9eyzz2rp0qXasmWLZsyYobKyMk2dOlWSNHnyZM2ZM8d9/sMPP6z33ntPu3fv1saNG/WDH/xA+/bt07Rp0yTVNlf4xS9+obVr12rv3r3KycnR+PHj1bNnT2VlZbXS04SZXHsNnWCvIQCAFzq9mxyAwOLxmqGJEycqPz9f8+bNU25uroYMGaKVK1e6myrs379fVuupjHXixAlNnz5dubm5iouL07Bhw7R69Wr1799fkmSz2bRp0yYtXbpUhYWF6tSpk6644go98sgj7DXkJ5gmBwDwZqcaKISYXAmA9tasBgqzZs3SrFmzGr1v1apV9b7+wx/+oD/84Q9nvFZ4eLjefffd5pQBH+EaGSpkZAgA4GUMw3A3UEiICjO5GgDtzeNpcoCn7K6RIdYMAQC8TFmVQyera7fyiGdkCAg4hCG0Ofc+Q4QhAICXca0XigoNUkRIm+84AsDLEIbQ5migAADwVq4wFB/FqBAQiAhDaHOuaXJFNFAAAHiZglLaagOBjDCENsfIEADAW7HHEBDYCENoc67W2hXVTlXULVIFAMAbsMcQENgIQ2hz0aFBslktkpgqBwDwLqfWDBGGgEBEGEKbs1gsstd1lGOqHADAm7BmCAhshCG0i1j2GgIAeKF8whAQ0AhDaBeuJgqFjAwBALwIDRSAwEYYQruws/EqAMDLGIbhnibHmiEgMBGG0C7c0+RooAAA8BJFJ6tV7TAkSR3ZdBUISIQhtAv2GgIAeBvXFDl7RLBCg2wmVwPADIQhtAvXNLkipskBALwEewwBIAyhXdgjGRkCAHiXfNYLAQGPMIR2QQMFAIC3oZMcAMIQ2oW9roFCEQ0UAABegj2GABCG0C5ooAAA8DaMDAEgDKFdxDJNDgDgZWigAIAwhHYRV9dAobLGqZNVDpOrAQBAKiitna0Qz8gQELAIQ2gXkSE2BVktkqTCk0yVAwCYj5EhAIQhtAuLxeJuosBUOQCA2RxOQ8fLWDMEBDrCENqNnSYKAAAvcaysUk5DslqkDnVTuQEEHsIQ2o1rr6EiRoYAACZzTZHrEBkqW900bgCBhzCEdnNqZIgwBAAwl6t5AlPkgMBGGEK7ca8ZooECAMBk7DEEQCIMoR0xTQ4A4C3oJAdAIgyhHbn2GqKBAgDAbK4wFB9N8wQgkBGG0G5iw2mtDQDwDgWljAwBIAyhHcXVNVAgDAEAzMaaIQASYQjtiAYKAABvkV9KGAJAGEI7YpocAMBb0EABgEQYQjtyNVAoLK+WYRgmVwMACFSVNQ4Vnaz9YI6RISCwEYbQblyttascTp2sdphcDQAgUB2r23A12GZxz1oAEJgIQ2g3ESE2hdhq/8qdYKocAMAkp0+Rs1gsJlcDwEyEIbQbi8WiWFcTBfYaAgCY5NQeQ0yRAwIdYQjtyjVVroiRIQCASdhjCIALYQjtyrXXENPkAABmYY8hAC6EIbSrWPYaAgCYjD2GALgQhtCu4iLYawgAYC73miGmyQEBjzCEdmWPcO01xMgQAMAcBYwMAahDGEK7cu3nwMgQAMAsrBkC4EIYQruigQIAwGyn7zMEILARhtCu7HVrhopooAAAMEF5VY3KqhyS2GcIAGEI7cwVhhgZAgCYoaCk9sO48GCbIkNsJlcDwGyEIbQre7irgQJhCADQ/vJLKyTVrheyWCwmVwPAbIQhtKvTp8kZhmFyNQCAQEPzBACnIwyhXbkaKFQ7DPecbQAA2gvNEwCcjjCEdhUWbFVIUO1fO/YaAgC0N/eGq9EhJlcCwBsQhtCuLBaL4iLYawgAYI780toP4hKiwkyuBIA3IAyh3dFEAQBgFtYMATgdYQjtLtY1MsReQwCAdpZfShgCcAphCO0ujr2GAAAmKXCtGYpizRAAwhBM4JomV0QDBQBAOzIMg5EhAPUQhtDu7JGMDAEA2l9xRY2qapySpHhaawMQYQgmoIECAMAMruYJMWFBCgu2mVwNAG9AGEK7s9etGSqigQIAoB
2d2mOIUSEAtQhDaHc0UAAAmKHAtV6IKXIA6hCG0O5i3dPkGBkCALQf9hgC8F2EIbS7uLoGCqwZAgC0JzrJAfguwhDanbuBwslqGYZhcjUAgEDhXjPENDkAdQhDaHeuBgoOp6HSyhqTqwEABIoCRoYAfAdhCO0uLNimsODav3pMlQMAtBfWDAH4LsIQTMFeQwCA9uYOQ0yTA1CHMART2N3ttekoBwBoe06noWNlte85jAwBcCEMwRSuMFR4kpEhAEDbO1FeJYfTkMUidYgMMbscAF6CMARTuKbJFTEyBKCVPPnkk0pPT1dYWJgyMjK0fv36M557ySWXyGKxNLhdffXV7nMMw9C8efOUkpKi8PBwZWZmaseOHe3xVNAGXG21O0SEKNjGrz8AavFqAFO49ho6wZohAK1g2bJlmj17th588EFt3LhRgwcPVlZWlo4ePdro+cuXL9eRI0fct82bN8tms+nGG290n/PEE0/oT3/6k7Kzs7Vu3TpFRkYqKytLFRUV7fW00IpongCgMYQhmCKWBgoAWtHChQs1ffp0TZ06Vf3791d2drYiIiK0ZMmSRs/v0KGDkpOT3bf3339fERER7jBkGIYWLVqkBx54QOPHj9egQYP0wgsv6PDhw3rrrbfa8ZmhtbDHEIDGEIZgijjXmiGmyQFooaqqKm3YsEGZmZnuY1arVZmZmVqzZk2TrrF48WJNmjRJkZGRkqQ9e/YoNze33jVjY2OVkZHR5GvCu7DHEIDGBJldAAITDRQAtJaCggI5HA4lJSXVO56UlKStW7ee8/Hr16/X5s2btXjxYvex3Nxc9zW+e03Xfd9VWVmpyspK99fFxcVNfg5oe0yTA9AYRoZgilPT5BgZAmCuxYsXa+DAgRoxYkSLrrNgwQLFxsa6b2lpaa1UIVoDewwBaAxhCKY4NU2OkSEALRMfHy+bzaa8vLx6x/Py8pScnHzWx5aVlemVV17RHXfcUe+463GeXHPOnDkqKipy3w4cOODpU0EbcnWTi4+mrTaAUwhDMIU9om5kiGlyAFooJCREw4YNU05OjvuY0+lUTk6ORo4cedbHvvbaa6qsrNQPfvCDese7deum5OTketcsLi7WunXrznjN0NBQxcTE1LvBe5waGQozuRIA3oQ1QzDF6Q0UnE5DVqvF5IoA+LLZs2frtttu0/DhwzVixAgtWrRIZWVlmjp1qiRp8uTJSk1N1YIFC+o9bvHixZowYYI6duxY77jFYtHdd9+tRx99VL169VK3bt00d+5cderUSRMmTGivp4VWVFBaOy2bNUMATkcYgiliwmvDkNOQSiprFFv3NQA0x8SJE5Wfn6958+YpNzdXQ4YM0cqVK90NEPbv3y+rtf5kiG3btunTTz/Ve++91+g177vvPpWVlenOO+9UYWGhxowZo5UrVyosjJEFX1PtcOp4GWEIQEMWwzAMs4toqeLiYsXGxqqoqIhpCT6k39yVOlnt0Me/uFRdOkaYXQ6AZuI1uHH8XLxHblGFLlyQoyCrRdsfvYrZCICf8+T1lzVDMI1rqtwJOsoBANqQa71Qx6gQghCAeghDME0sTRQAAO2ADVcBnAlhCKY5vYkCAABthT2GAJwJYQimsbPXEACgHeQzMgTgDJoVhp588kmlp6crLCxMGRkZWr9+/RnPff7552WxWOrdvtuJxzAMzZs3TykpKQoPD1dmZqZ27NjRnNLgQ2LD66bJEYYAAG3INTIUz8gQgO/wOAwtW7ZMs2fP1oMPPqiNGzdq8ODBysrK0tGjR8/4mJiYGB05csR927dvX737n3jiCf3pT39Sdna21q1bp8jISGVlZamiosLzZwSfQQMFAEB7YGQIwJl4HIYWLlyo6dOna+rUqerfv7+ys7MVERGhJUuWnPExFotFycnJ7ptr3wepdlRo0aJFeuCBBzR+/HgNGjRIL7zwgg4fPqy33nqrWU8KvsE1Ta6IBgoAgDbkXjNEGALwHR6FoaqqKm3YsEGZmZmnLm
C1KjMzU2vWrDnj40pLS9W1a1elpaVp/Pjx+vbbb9337dmzR7m5ufWuGRsbq4yMjDNes7KyUsXFxfVu8D32um5yjAwBANpSAQ0UAJyBR2GooKBADoej3siOJCUlJSk3N7fRx/Tp00dLlizRP//5T/3jH/+Q0+nUqFGjdPDgQUlyP86Tay5YsECxsbHuW1pamidPA17CHk4DBQBA23OvGWJkCMB3tHk3uZEjR2ry5MkaMmSILr74Yi1fvlwJCQl6+umnm33NOXPmqKioyH07cOBAK1aM9uIaGWKaHACgrVRUO1RSWSOJaXIAGvIoDMXHx8tmsykvL6/e8by8PCUnJzfpGsHBwRo6dKh27twpSe7HeXLN0NBQxcTE1LvB99BAAQDQ1lyjQqFBVkWHBplcDQBv41EYCgkJ0bBhw5STk+M+5nQ6lZOTo5EjRzbpGg6HQ998841SUlIkSd26dVNycnK9axYXF2vdunVNviZ8U+xpDRScTsPkagAA/uj0TnIWi8XkagB4G48/Ipk9e7Zuu+02DR8+XCNGjNCiRYtUVlamqVOnSpImT56s1NRULViwQJL08MMP68ILL1TPnj1VWFio3/3ud9q3b5+mTZsmqbbT3N13361HH31UvXr1Urdu3TR37lx16tRJEyZMaL1nCq9jr9tnyDCk4opq97Q5AABaC3sMATgbj8PQxIkTlZ+fr3nz5ik3N1dDhgzRypUr3Q0Q9u/fL6v11IDTiRMnNH36dOXm5iouLk7Dhg3T6tWr1b9/f/c59913n8rKynTnnXeqsLBQY8aM0cqVKxtszgr/EhJkVWSITWVVDhWWE4YAAK2vgD2GAJyFxTAMn5+fVFxcrNjYWBUVFbF+yMeM/s2HOlR4Um/NHK0haXazywHQDLwGN46fi3dY9MF2Lfpgh27J6KLHrx1odjkA2oEnr79t3k0OOBs7TRQAAG0onz2GAJwFYQimcoWhIvYaAgC0AfYYAnA2hCGYyrVOiJEhAEBbcK8ZYmQIQCMIQzCVPbx2ZKiQkSEAQBvIp4ECgLMgDMFU9tP2GgIAoDUZhuGeJpdIGALQCMIQTBXHNDkAQBspraxRRbVTEvsMAWgcYQimimWaHACgjRSU1n7QFhUapPAQm8nVAPBGhCGYyjUyVMjIEACglbnbajNFDsAZEIZgKteaoULWDAEAWhl7DAE4F8IQTOUOQ0yTAwC0svySCklSfHSIyZUA8FaEIZjKtc9QcUW1HE7D5GoAAP4knz2GAJwDYQimcjVQMAypmKlyAIBWVFBSux6VNUMAzoQwBFMF26yKDg2SRHttAEDrYsNVAOdCGILpYmmiAABoA3STA3AuhCGYztVEoYgmCgCAVuQKQ2y4CuBMCEMwnWuvIabJAQBai9Np6FgZI0MAzo4wBNO5mijQXhsA0FqKTlar2lHbpbRjJGEIQOMIQzCda2SokJEhAEArcTVPiIsIVkgQv+4AaByvDjCdnQYKAIBWxnohAE1BGILpmCYHAGhtBbTVBtAEhCGYjgYKAIDWRlttAE1BGILp3K21mSYHAGgl7jDENDkAZ0EYgunsjAwBAFqZe80QI0MAzoIwBNO5GyiwZggA0Epc3eQYGQJwNoQhmM5e10ChpKJGNQ6nydUAAPwBa4YANAVhCKZzdZOTWDcEAGgddJMD0BSEIZguyGZVdFiQJPYaAgC0XI3DqWNltetQ2WcIwNkQhuAVXO21C2miAABooePlVTIMyWqROkSGmF0OAC9GGIJXoIkCAKC1uNYLdYwKlc1qMbkaAN6MMASv4Fo3RBgCALQUewwBaCrCELxCHHsNAQBaCXsMAWgqwhC8gmuaHN3kAAAtVVBa+8EaI0MAzoUwBK9gZ2QIANBK2GMIQFMRhuAV7KwZAgC0knz2GALQRIQheAWmyQEAWkt+SYUkKT6KttoAzo4wBK9AAwUAQGthmhyApiIMwSvEss8QAKCVuBooJBKGAJwDYQhewTUyRBgCALREZY3DPeU6IS
rM5GoAeDvCELyCq4FCaWWNqh1Ok6sBAPgq16hQiM2qmPAgk6sB4O0IQ/AKMeHBslhq/5smCgCA5nJvuBoVIovrjQUAzoAwBK9gs1oUE+ZaN0QTBQBA8xTQPAGABwhD8Bp2migAAFqIPYYAeIIwBK9hd7fXJgwBAJqHttoAPEEYgtdwNVFgmhwAoLlOrRkiDAE4N8IQvIZrmhwNFAAAzVXANDkAHiAMwWvEuafJMTIEAGge9zQ5RoYANAFhCF4jNpwGCgCAlqGBAgBPEIbgNeJc3eSYJgcAaCbWDAHwBGEIXsPVTY4GCgCA5iirrFF5lUMSI0MAmoYwBK8Ryz5DAIAWcDVPiAixKTI0yORqAPgCwhC8Rpx7ZIgwBADwHHsMAfAUYQheg32GAAAtwXohAJ4iDMFruEaGyqocqqpxmlwNAMDXuPcYIgwBaCLCELxGdFiQLJba/y48yegQAMAzTJMD4CnCELyG1Wpx7zVUxLohAICH2GMIgKcIQ/AqrqlyJwhDAAAPsWYIgKcIQ/AqsTRRAAA0U35p7XsHI0MAmoowBK8S59pr6CQjQwAAzxSwZgiAhwhD8Cp2915DjAwBAJrOMAwaKADwGGEIXuXUNDlGhgAATVd8skZVjtptGTpGhphcDQBfQRiCV6GBAgCgOfJLKyRJMWFBCgu2mVwNAF9BGIJXsdetGSpinyEAgAfyS2ieAMBzhCF4FVcYYpocAMAT7DEEoDkIQ/AqdqbJAQCagT2GADQHYQhexV7XQKGIbnIAAA/QSQ5AcxCG4FVooAAAaI4CpskBaAbCELxKbN2aoZPVDlVUO0yuBgDgK9wjQ0yTA+ABwhC8SkxYkGxWiySp+CSjQwCApmGaHIDmIAzBq1gsFvfGq0yVAwA0laubHA0UAHiCMASv42qiUEgTBQBAEzicho6X1b5nJDIyBMADhCF4HddeQ4wMAfDEk08+qfT0dIWFhSkjI0Pr168/6/mFhYWaOXOmUlJSFBoaqt69e2vFihXu+x966CFZLJZ6t759+7b100AznCivksNpyGKROkSGmF0OAB8SZHYBwHe59hoqOsnIEICmWbZsmWbPnq3s7GxlZGRo0aJFysrK0rZt25SYmNjg/KqqKl1++eVKTEzU66+/rtTUVO3bt092u73eeeedd54++OAD99dBQbxteiPXeqGOkSEKsvE5L4Cm41UdXsc1MlTIyBCAJlq4cKGmT5+uqVOnSpKys7P1zjvvaMmSJbr//vsbnL9kyRIdP35cq1evVnBw7WtOenp6g/OCgoKUnJzcprWj5dhwFUBz8fEJvI49nL2GADRdVVWVNmzYoMzMTPcxq9WqzMxMrVmzptHHvP322xo5cqRmzpyppKQkDRgwQI8//rgcjvot/Xfs2KFOnTqpe/fuuvXWW7V///42fS5oHvYYAtBcjAzB67hGhpgmB6ApCgoK5HA4lJSUVO94UlKStm7d2uhjdu/erQ8//FC33nqrVqxYoZ07d+onP/mJqqur9eCDD0qSMjIy9Pzzz6tPnz46cuSI5s+fr7Fjx2rz5s2Kjo5ucM3KykpVVla6vy4uLm7FZ4mzYY8hAM1FGILXiXM1UChjZAhA23A6nUpMTNQzzzwjm82mYcOG6dChQ/rd737nDkNXXXWV+/xBgwYpIyNDXbt21auvvqo77rijwTUXLFig+fPnt9tzwCnsMQSguZgmB68TW9dAoZCRIQBNEB8fL5vNpry8vHrH8/LyzrjeJyUlRb1795bNZnMf69evn3Jzc1VV1fhrj91uV+/evbVz585G758zZ46KiorctwMHDjTzGcFT7DEEoLkIQ/A6cTRQAOCBkJAQDRs2TDk5Oe5jTqdTOTk5GjlyZKOPGT16tHbu3Cmn0+k+tn37dqWkpCgkpPHWzKWlpdq1a5dSUlIavT80NFQxMTH1bmgfrBkC0FyEIXgdVwMFwhCAppo9e7aeffZZLV26VFu2bNGMGTNUVlbm7i43efJkzZkzx33+jB
kzdPz4cd11113avn273nnnHT3++OOaOXOm+5x7771XH330kfbu3avVq1fr2muvlc1m080339zuzw9nxzQ5AM3FmiF4HXdrbabJAWiiiRMnKj8/X/PmzVNubq6GDBmilStXupsq7N+/X1brqc//0tLS9O677+qee+7RoEGDlJqaqrvuuku//OUv3eccPHhQN998s44dO6aEhASNGTNGa9euVUJCQrs/P5wdYQhAczVrZMjTXb5dXnnlFVksFk2YMKHe8SlTpjTY5fvKK69sTmnwA64wVFHtVEW14xxnA0CtWbNmad++faqsrNS6deuUkZHhvm/VqlV6/vnn650/cuRIrV27VhUVFdq1a5d+9atf1VtD9Morr+jw4cOqrKzUwYMH9corr6hHjx7t9XTQRNUOp3srBtYMAfCUx2HItcv3gw8+qI0bN2rw4MHKysrS0aNHz/q4vXv36t5779XYsWMbvf/KK6/UkSNH3LeXX37Z09LgJ6JCgxRktUhiqhwA4OyOldbOIgiyWmQPDza5GgC+xuMwdPou3/3791d2drYiIiK0ZMmSMz7G4XDo1ltv1fz589W9e/dGzwkNDVVycrL7FhcX52lp8BMWi4WpcgCAJnFNkYuPCpW17oM0AGgqj8JQc3b5lqSHH35YiYmJje7L4LJq1SolJiaqT58+mjFjho4dO+ZJafAzseHsNQQAOLf80gpJrBcC0DweNVBozi7fn376qRYvXqyvvvrqjNe98sordd1116lbt27uedtXXXWV1qxZU2/+tgu7fPs/e0SIpDIVMTIEADiLUyNDjbdEB4CzadNuciUlJfrhD3+oZ599VvHx8Wc8b9KkSe7/HjhwoAYNGqQePXpo1apVuuyyyxqczy7f/s+119AJ1gwBAM6CTnIAWsKjaXKe7vK9a9cu7d27V+PGjVNQUJCCgoL0wgsv6O2331ZQUJB27drV6Pfp3r274uPj2eU7gMWy1xAAoAkK6hooEIYANIdHI0On7/Ltao/t2uV71qxZDc7v27evvvnmm3rHHnjgAZWUlOiPf/yj0tLSGv0+Bw8e1LFjx866y3doKC96/iyOBgoAgCZwjwzRVhtAM3g8TW727Nm67bbbNHz4cI0YMUKLFi1qsMt3amqqFixYoLCwMA0YMKDe4+12uyS5j5eWlmr+/Pm6/vrrlZycrF27dum+++5Tz549lZWV1cKnB1/l7iZHAwUAwFm41wwxMgSgGTwOQ57u8n0uNptNmzZt0tKlS1VYWKhOnTrpiiuu0COPPMLoTwCLjaibJsfIEADgLPJLGRkC0HzNaqAwa9asRqfFSbUtss/muzuAh4eH6913321OGfBjNFAAADRFAQ0UALSAx5uuAu3BXtdAoYgwBAA4g5NVDpVU1kgiDAFoHsIQvJKdBgoAgHMoqJsiFxpkVVRom+4WAsBPEYbgleynTZMzDMPkagD/5nDybwy+6ehpU+QsFovJ1QDwRYQheCV7XQOFqhqnKqqdJlcD+K+i8mpdvvAj/X3tPjkJRfAxrpEhpsgBaC7CELxSZIhNwbbaT/lOlDNVDmgrv1m5RbsLyrR09V5VO/ngAb6FPYYAtBRhCF7JYrEotq6JQiFNFIA2sW73Mb28/oAkacF1AxUaZDO5IsAz+XSSA9BChCF4LZooAG2nssahOW9+I0m6eUQXXZDeweSKAM+59hiKZ2QIQDMRhuC1XHsNMTIEtL6/rtql3fllSogO1f1X9TW7HKBZ2GMIQEsRhuC1mCYHtI2dR0v01H93SZIeGneeYsODTa4IaJ58GigAaCHCELxWnLu9NtPkgNbidBr61fLNqnI4dVnfRH1vYLLZJQHNxpohAC1FGILXcq0ZKjrJyBDQWpZ9cUDr9x5XRIhND08YwN4s8FmGYdBNDkCLEYbgtVx7DRUyMgS0iqMlFXp8xRZJ0s+v6KNUe7jJFQHNV1pZo8qa2nbwNFAA0FyEIXgtu3uaHCNDQGt4+F//U0lFjQZ1jtWUUe
lmlwO0iGtUKDo0SOEhtIUH0DyEIXgte10DhSLCENBiH27N0783HZHNatHj1w6Uzcr0OPg21gsBaA2EIXgtGigAraOsskZz3/pWknTHmG4akBprckVAy7HHEIDWQBiC14p1b7rKyBDQEgvf365DhSfVOS5cd2f2MrscoFWwxxCA1kAYgtdyNVAoKq+WYRgmVwP4pm8OFum5z/ZIkh6dMEARIUEmVwS0DvYYAtAaCEPwWq5pclUOp8qrHCZXA/ieGodT9y/fJKchjR/SSZf0STS7JKDVsGYIQGsgDMFrhQfbFGKr/SvKVDnAc899tlffHi5WbHiw5l7T3+xygFblCkPxUSEmVwLAlxGG4LUsFsup9tplNFEAPHHgeLkWvr9dkvTr7/VjkTn8TkFp7fsCI0MAWoIwBK/mCkNFjAwBTWYYhh54a7NOVjt0YfcOunF4Z7NLAlqde5pcVJjJlQDwZYQheDXXXkOF7DUENNm/Nh3RR9vzFRJk1ePXDpTFwp5C8C9Op6ECGigAaAWEIXg1O3sNAR4pLK/Sw/+q3VNo1qU91T0hyuSKgNZXeLJaNc7aLqMdWTMEoAUIQ/BqTJMDPLNgxVYVlFapV2KUfnxxD7PLAdqEa4pcXESwgm38KgOg+XgFgVeLq9triAYKwLmt2XVMy744IElacN1AhQTxEg//xBQ5AK2Fd0p4tdi6kSFaawNnV1Ht0K/f/EaSdGtGFw1P72ByRUDbYY8hAK2FMASvRgMFoGmeWrVLuwvKlBAdqvuu7Gt2OUCbOrXHEGEIQMsQhuDV4lwjQzRQAM5oR16J/rpqpyRp/vfPU2x4sMkVAW0r3zVNjjAEoIUIQ/BqTJMDzs7pNDRn+TeqdhjK7JeoqwYkm10S0OaOFldIkuKZJgeghQhD8GquBgqMDAGNe/nz/fpi3wlFhtj08PgB7CmEgPC/I8WSpO7xkSZXAsDXEYbg1ezuaXLVMgzD5GoA73K0uEK/+c9WSdLPr+ijTvZwkysC2l5xRbV2HC2VJJ3fNc7kagD4OsIQvJqrgUKN01BZlcPkagDvMv9f/1NJRY0GdY7VbaPSzS4HaBdfHyiUYUhdOkTQQAFAixGG4NXCQ2wKrdsrhb2GgFM++F+e3vnmiGxWixZcN1A2K9PjEBg27iuUJA3tYje1DgD+gTAEr+eaKldEEwVAklRaWaN5/9wsSZo2tpvO6xRrckVA+9m4/4Qk6fwuTJED0HKEIXg9VxOFEzRRACRJv39vmw4XVSitQ7juvqy32eUA7cbpNPTVgUJJhCEArYMwBK/n2jOFjVeB2vUSS1fvlSQ9NmGgwkNs5hYEtKPdBWUqOlmtsGCr+qZEm10OAD9AGILXs7PXECBJqnY4df/yb+Q0pAlDOumi3glmlwS0qy/rpsgNSrUr2MavMABajlcSeD33XkM0UECAW/LpHm05Uix7RLAeuKa/2eUA7W7j/kJJ0tCudlPrAOA/CEPwerGMDAHaf6xcf/hguyTp19/rR0thBCTXyNDQNNYLAWgdhCF4PRooINAZhqFfv/WNKqqdGtm9o24Y1tnskoB2V1pZo215JZKk82mrDaCVEIbg9ex1DRSKaKCAAPX214f1yY4ChQRZ9di1A2SxsKcQAo9rs9VUe7gSY8LMLgeAnyAMwevRQAGB7ERZlR7+1/8kST/7fz3VPSHK5IoAc2zcV7e/UFemyAFoPYQheD070+QQwB5fsUXHyqrUOylKd17Uw+xyANN86d5fyG5qHQD8C2EIXs81MsQ0OQSa1bsK9NqGg5KkBdcNVEgQL9kITIZhnGqewGarAFoR76zweu7W2ierZRiGydUA7aOi2qFfv7lZkvSDC7toWNcOJlcEmGdPQZlOlFcrNMiq/ikxZpcDwI8QhuD1YusaKDichkoqa0yuBmgfT/53p/YUlCkxOlT3XdnX7HIAU31Zt7/QwNRYRkgBtCpeUeD1woJtCguu/avKVDkEgm25Jfrrql2SpIfHn6eYsGCTKw
LMtdE9Rc5ubiEA/A5hCD6BvYYQKJxOQ3OWb1KN01BmvyRlnZdsdkmA6TbWjQydz3ohAK2MMASf4JoqV8jIEPzci+v3a+P+QkWG2PTw+PPYUwgBr6yyRttyiyXRVhtA6yMMwScwMoRAkFdcoSf+s1WS9IusPupkDze5IsB8Xx8slNOQOsWGKYnNVgG0MsIQfIK7vTYbr8KPPfT2tyqprNHgNLt+ODLd7HIAr+BqnjCUUSEAbYAwBJ/gCkNMk4O/eu/bXP1nc65sVot+c91A2axMjwMkufcXYr0QgLZAGIJPsDNNDn6stLJGD779rSRp+tju6sc+KoCk2s1WXc0T6CQHoC0QhuAT7HUNFGitDX/0f+9u05GiCnXpEKG7LutldjmA19h3rFzHy6oUYrPqvE58SACg9RGG4BNooAB/9dWBQi1ds1eS9Ni1AxQeYjO3IMCLfHmgdorcgNQYhQbxbwNA6yMMwSfEutYM0UABfqTa4dT9b2ySYUjXDk3V2F4JZpcEeJWN+wolSUNZLwSgjRCG4BOYJgd/9LdP9mhrbonsEcF64Op+ZpcDeJ2NNE8A0MYIQ/AJcZFMk4N/2XesTIs+2C5JeuDq/uoYFWpyRYB3Ka+q0dbcEknS+V3t5hYDwG8RhuAT3CNDJ6vldBomVwO0jGEYeuCtzaqscWpUj466/vxUs0sCvM6mg0VyOA0lx4QpJZYNiAG0DcIQfIJrzZDTkEoqakyuBmiZt746pE92FCg0yKrHrx0oi4U9hYDvck+RY1QIQBsiDMEnhAbZFFHXZavwJFPl4LuqHU795j9bJUk/u6yX0uMjTa4I8E5f1u0vxHohAG2JMASf4ZoqV0gTBfiwnC15yiuuVHxUiKaP7W52OYBXMgxDX9aNDLHZKoC2RBiCz7Cz1xD8wIvr9kuSbhyeppAgXoKBxhw4flIFpVUKtll0XqdYs8sB4Md4J4bPsEecaqIA+KL9x8r1yY4CSdLNF3QxuRrAe7k2Wz2vU6zCgtlsFUDbIQzBZ8S5RobKGBmCb3r589pRobG94tWlY4TJ1QDea+M+psgBaB+EIfgMV0e5QkaG4IOqapx67YsDkqRbMxgVAs5mI80TALQTwhB8Bg0U4Mve/1+eCkqrlBAdqsv6JZldDuC1TlY5tOVIsSRGhgC0PcIQfIZrmlwhDRTgg15av0+SNHF4moJtvPQCZ/LNoSLVOA0lRocq1c5mqwDaFu/I8BlMk4Ov2ltQps92HpPFIk0akWZ2OYBXc2+22iWODYkBtDnCEHyGu4EC0+TgY15eX9s44eLeCeocR+ME4GzYXwhAeyIMwWe4W2szTQ4+pLLGodc2HJQk3TKCxgnA2RiGcap5QleaJwBoe4Qh+Ax3AwWmycGHvPttno6XVSk5Jkz/r2+i2eUAXu3giZPKL6lUkNWigalstgqg7RGG4DPsddPkik5Wy+E0TK4GaJoX19Y2TrjpgjQF0TgBOKsvDxRKkvp3imGzVQDtgndm+IzYupEhw5BKKhgdgvfbebRU6/Ycl9UiTbqAxgnAubg2W2V/IQDthTAEnxESZFVUaJAkmijAN7gaJ1zaJ1GdaBEMnBPNEwC0N8IQfEqse+NVmijAu1VUO/TGxtrGCbdeSOOE9vDkk08qPT1dYWFhysjI0Pr16896fmFhoWbOnKmUlBSFhoaqd+/eWrFiRYuuiearqHbo28O1m60yMgSgvRCG4FMSokMlSfuPl5tcCXB2/9l8RIXl1Uq1h+vi3jROaGvLli3T7Nmz9eCDD2rjxo0aPHiwsrKydPTo0UbPr6qq0uWXX669e/fq9ddf17Zt2/Tss88qNTW12ddEy2yu22w1PipUneMYSQXQPghD8CkXpNd+Wrh65zGTKwHO7qV1tVPkJl6QJpuVjSPb2sKFCzV9+nRNnTpV/fv3V3Z2tiIiIrRkyZJGz1+yZImOHz+ut956S6NHj1Z6erouvvhiDR48uNnXRMt86Wqp3cXOZqsA2g1hCD5lVM
94SdJnuwpMrgQ4s+15Jfp87wnZrBZNpHFCm6uqqtKGDRuUmZnpPma1WpWZmak1a9Y0+pi3335bI0eO1MyZM5WUlKQBAwbo8ccfl8PhaPY1KysrVVxcXO+GptvoXi/EFDkA7YcwBJ8yIr2DgqwWHTxxUvuPMVUO3sk1KnRZ30QlxYSZXI3/KygokMPhUFJSUr3jSUlJys3NbfQxu3fv1uuvvy6Hw6EVK1Zo7ty5+v3vf69HH3202ddcsGCBYmNj3be0NIJwU9VuturqJGc3txgAAYUwBJ8SGRqkIWl2SdJqRofghSqqHVpe1zjhlgwaJ3grp9OpxMREPfPMMxo2bJgmTpyoX//618rOzm72NefMmaOioiL37cCBA61YsX87UlShvOJK2awWDepsN7scAAGkWWGoud11XnnlFVksFk2YMKHeccMwNG/ePKWkpCg8PFyZmZnasWNHc0pDADg1VY51Q/A+/950RMUVNeocF66LeiWYXU5AiI+Pl81mU15eXr3jeXl5Sk5ObvQxKSkp6t27t2y2Uxt79uvXT7m5uaqqqmrWNUNDQxUTE1PvhqZxjQr1S4lWeAibrQJoPx6HoeZ219m7d6/uvfdejR07tsF9TzzxhP70pz8pOztb69atU2RkpLKyslRRUeFpeQgAo3t0lCSt2VUgwzBMrgao76V1+yRJN4/oIiuNE9pFSEiIhg0bppycHPcxp9OpnJwcjRw5stHHjB49Wjt37pTT6XQf2759u1JSUhQSEtKsa6L5Nu4rlERLbQDtz+Mw1JzuOg6HQ7feeqvmz5+v7t2717vPMAwtWrRIDzzwgMaPH69BgwbphRde0OHDh/XWW295/ITg/4Z2iVN4sE0FpVXalldidjmA29bcYm3cX6ggq0U3Du9sdjkBZfbs2Xr22We1dOlSbdmyRTNmzFBZWZmmTp0qSZo8ebLmzJnjPn/GjBk6fvy47rrrLm3fvl3vvPOOHn/8cc2cObPJ10Tr+fKAa70QYQhA+wry5GRXd53T31DO1V1Hkh5++GElJibqjjvu0CeffFLvvj179ig3N7dex57Y2FhlZGRozZo1mjRpUoPrVVZWqrKy0v01HXsCS0iQVRd066CPt+frs53H1DeZqSjwDq7GCZf3T1JiNI0T2tPEiROVn5+vefPmKTc3V0OGDNHKlSvdDRD2798vq/XU539paWl69913dc8992jQoEFKTU3VXXfdpV/+8pdNviZaR2WNQ98eqn0fH0rzBADtzKMwdLbuOlu3bm30MZ9++qkWL16sr776qtH7XV15PO3YM3/+fE9Kh58Z3aOjPt6er9U7C3THmG5mlwOovKpGb248JInGCWaZNWuWZs2a1eh9q1atanBs5MiRWrt2bbOvidax+VCxqhxOdYwMUZcOEWaXAyDAtGk3uZKSEv3whz/Us88+q/j4+Fa7Lh17MLquicK6PcdV43Ce42yg7f376yMqqaxRlw4RGt2j9V7vAH/35Wn7C7HZKoD25tHIkKfddXbt2qW9e/dq3Lhx7mOuxapBQUHatm2b+3F5eXlKSUmpd80hQ4Y0WkdoaKhCQ0M9KR1+pn9KjOwRwSosr9bXB4s0rCvzzGGuF9fXTpGjcQLgmS/3F0piihwAc3g0MuRpd52+ffvqm2++0VdffeW+ff/739ell16qr776SmlpaerWrZuSk5PrXbO4uFjr1q2jYw/OyGq1aGT32q5yq3ey3xDM9e3hIn19oFDBNhonAJ46tdkqH2oBaH8ejQxJtd11brvtNg0fPlwjRozQokWLGnTsSU1N1YIFCxQWFqYBAwbUe7zdbpekesfvvvtuPfroo+rVq5e6deumuXPnqlOnTg32IwJON6pnvP6zOVef7SrQTy/rZXY5CGCuxglXnJes+ChGrYGmOlJ0UkeKKmS1SIM6x5pdDoAA5HEY8rRjT1Pcd999Kisr05133qnCwkKNGTNGK1euVFgY3ZhwZq79hjbuK9TJKgcb9cEUZZU1+u
dXhyVJt46gcQLgCdcUub7JMYoM9fhXEgBosWa98njased0zz//fINjFotFDz/8sB5++OHmlIMA1S0+UimxYTpSVKEv9h3X2F4JZpeEAPT214dVWlmjbvGRGlkX0AE0zcZ9dVPkutrNLQRAwGrTbnJAW7JYLBpV17Xrs53HTK4Ggco1Re6WEV3ohAV46MsDhZKkoWmsFwJgDsIQfNronnVNFHbRRAHtb9PBQn1zqEghNquuH0bjBMATVTVOfXOoSJJ0Ph1BAZiEMASf5tpv6JtDRSoqrza5GgQa16jQVQOT1SEyxORqAN/y7eEiVdU4FRcRrPSObLYKwByEIfi0pJgw9UiIlGFIa3YzVQ7tp6SiWm9/Xds44RYaJwAeO7W/EJutAjAPYQg+zzU6xFQ5tKe3vjqs8iqHeiZGaUS3DmaXA/icU/sL2c0tBEBAIwzB551qokAYQvswDMM9Re5mGicAzeIaGWKzVQBmIgzB543s3lFWi7Qrv0y5RRVml4MA8NWBQm05UqyQIKuuPz/V7HIAn5NXXKFDhSdrN1tNs5tdDoAARhiCz4uNCNaA1Nqdy5kqh/bgGhW6ZmCK7BE0TgA89WXdFLneSdGKYrNVACYiDMEvsN8Q2kvRyWr9a1Nd44QMGicAzbHRNUWOltoATEYYgl8Y1ePUfkOGYZhcDfzZW18eUkW1U72TojSMX+SAZnGNDA1lihwAkxGG4BcuSO+gEJtVR4oqtPdYudnlwE+d3jjhFhonAM1SVePUpoNstgrAOxCG4BfCQ2waWteela5yaCsb95/QtrwShQVbde35nc0uB/BJW44Uq7LGqdjwYHWPjzS7HAABjjAEv8F+Q2hrL7oaJwzqpNjwYJOrAXyTe4pcFzujqwBMRxiC3xjds3bd0Jpdx+R0sm4IrauovFrvbDoiicYJQEtsZH8hAF6EMAS/MaizXZEhNp0or9b/jhSbXQ78zBsbD6qyxqm+ydEs+gZaYGPdyBBhCIA3IAzBbwTbrMrofqqrHNBaDMPQS+trp8jdmkHjBKC5jpZU6OCJk7JYpMFpsWaXAwCEIfgXV4tt9htCa/p87wntPFqq8GCbxg9NNbscwGd9WTdFrnditKLDWHcHwHyEIfgVVxOF9XuOq6rGaXI18BcvrdsnSRo/pJNi+AUOaDb3FLmudnMLAYA6hCH4lT5J0eoYGaKT1Q59daDQ7HLgB06UVWnF5lxJNE4AWso1MjQ0jfVCALwDYQh+xWq1aKR7qhzrhtByb2w8qKoapwakxmhQZ7vZ5QA+q9rh1KaDhZIYGQLgPQhD8DvsN4TWcnrjhFtGdDW5GsC3bT1Soopqp2LCgtQ9PsrscgBAEmEIfmh0j9ow9OX+QpVV1phcDXzZ2t3HtTu/TJEhNn1/SCezywF82pcHatcLDekSJ6uVjowAvANhCH6nS8cIdY4LV43T0Pq9x80uBz7sRVfjhKGpigoNMrkawLdt3OfaX8hubiEAcBrCEPySa3RoNeuG0EwFpZV699u6xgkjaJwAtNSXdU1thrLZKgAvQhiCXxrVk/2G0DKvbzioaoehwZ1jNSCVzSGBligordS+Y+WSpCFpdnOLAYDTEIbgl0bVjQz970ixjpdVmVwNfI3TaehlV+ME2mkDLeZqqd0rMUqx4ezVBcB7EIbglxKiQ9UnKVqStGYXo0PwzOpdx7TvWLmiQ4M0bjCNE4CW+rJus9WhrBcC4GUIQ/Bb7qlytNiGh15aX9s4YcLQVEWE0DgBaKmN+13NE1gvBMC7EIbgt2iigOY4WlKh977Nk8QUOaA11Dic+vpAkSTp/K6EIQDehTAEv5XRvYNsVov2HivXocKTZpcDH/HaFwdV4zQ0tItd/VJizC4H8Hnb8kp0stqh6NAg9Uxgs1UA3oUwBL8VHRasgXVdwD5jdAhN4HQaeuXzusYJtNMGWsXGuuYJQ7rY2WwVgNchDMGvja5bN8RUOTTFJzsLdOD4SUWHBemaQTROAFrDl/tczR
OYIgfA+xCG4Nfc64Z2HZNhGCZXA2/30rraxgnXn99Z4SE2k6sB/MOpzVbtptYBAI0hDMGvnd81TqFBVh0tqdSu/FKzy4EXyyuu0AdbjkqicQLQWo6XVWlPQZkk6fw0RoYAeB/CEPxaWLBNw9Nr34A/28l+QzizVz8/IIfT0AXpcepdt0cVgJZx7S/UIyFSsRFstgrA+xCG4PdG1U2Vo4kCzsThNPTK5wckMSoEtKYv65onsF4IgLciDMHvje5ZG4bW7j4mh5N1Q2jo4+35OlR4UvaIYF01IMXscgC/wWarALwdYQh+b2BqrKLDglRcUaPNh4rMLgde6MV1te20rz+/s8KCaZwAtAaH09DXdc0Tzu9qN7UWADgTwhD8ns1q0YXda1tsf7aLqXKo70jRSX24NU+SdDN7CwGtZnteicqqHIoKDVKvRNbhAfBOhCEEhNE9XPsN0UQB9S37/ICchpTRrYN6JkaZXQ7gN1xT5AanxcrGZqsAvBRhCAHBtW7o873HVVHtMLkaeIsah1PLaJwAtImN+wolsV4IgHcjDCEg9EyMUmJ0qCprnO5PK4FV2/J1pKhCHSJDdOWAZLPLAfzKlwdqX2vZbBWANyMMISBYLBaNYqocvuPFdfskSTcM66zQIBonAK2lsLxKu/NrN1sdymarALwYYQgBY1TdVDmaKECSDp4o16rt+ZJonAC0Ntf+Qt3jIxUXGWJuMQBwFoQhBAzXuqFNB4tUUlFtcjUw27LPD8j4/+3deXxU5dnG8WtmkslGFkJIQkKAkCAhCAYIQaAurVFUiti6YOsCWLFVabFp9cUVXOMuLS7ggmzVYlur4lqMRUWQsIMsYV8CJBAgK2SbOe8fWTSFaPYzy+/7+cwfmZw5c82BmSf3nOfcjyGNSOii+Iggs+MAHmVd7XTkFKbIAXBxFEPwGrFhAerVJVAOp6GVu4+bHQcmqqJxAtCu1taeGaJ5AgBXRzEEr8JUOUhS1tYjOlJSoYhOdl2STOMEoC05nIbW1y22SjEEwMVRDMGrjEyoKYZoouDd3szeL0m6ekic7D58DAJtaeeRUpVWVCvQbtNZUazdBcC18VcAvMrw2o5yOfklOlpSYXIamOHA8ZP6akdd44Q4k9MAnqd+sdXuYfKx8WcGANfGpxS8SniQXcndQiRJy5kq55Xeyt4vw5DO6xOhnl1onAC0tbX7WF8IgPugGILXGZnIekPeqsrh1NurcyVJv6adNtAu1nG9EAA3QjEEr0MTBe+1ZEu+Ckor1DXYT+nJUWbHATxO0ckq7TxSKokzQwDcA8UQvE5ar3D5WC3KPXFK+4+dNDsOOtCbK2saJ4xLjZMv1zIAbW7dgZopcj27BKpLJz+T0wDAj+OvAXidID8fpcSFSeLskDfZW1CmZTsLZLFI19E4AWgX61hfCICboRiCV6qbKrd8F9cNeYu3VtWcFbrgrK7q3jnQ5DSAZ6rrJDeYKXIA3ATFELzSyNoW2yt2FcgwDJPToL1VVjv1TxonAO3K+b3FVgdxZgiAm6AYglca1KOzAnxtKiitVE5+idlx0M7+seaAjpVVKjrEXz9LijQ7DuCRdh0tVUl5tQJ8bUqKDjY7DgA0CcUQvJLdx6qh8eGSpK9pse3Rvt5ZoOnvb5YkTRjZi0UggXZSN0VuYPdQ3mcA3AafVvBadVPllu+kiYKn2nq4WL9bsEZVDkM/H9hNt57X2+xIgMdau69QElPkALgXiiF4rZG1TRRW7jmuaofT5DRoawcLT2nCG9kqqajWsPhwPXvtObJaLWbHAjxWXVttmicAcCcUQ/Bayd1CFBboq9KKam3ILTI7DtpQ0ckqjZ+TrfziCp0V1Umv3JQqPx+b2bEAj1VcXqUd9YutcmYIgPugGILXslotGt6bqXKeprzKoUnzV2vnkVJFh/hr7sQ0hQb4mh0L8Gjr9xfKMKS48AB1DWaxVQDug2IIXq1uvSEWX/UMTqehjLfXK3vvcQX7+WjuzUMVExZgdi
zA47HYKgB3RTEEr1bXRGHtvkKdqnSYnAatYRiGHvlwiz7alCe7zarZNw1RUnSI2bEAr/DdYqsUQwDcC8UQvFp8RJC6hfqr0uHU6n3HzY6DVnjtqz164+u9kqRnrj1HIxIizA0EeImGi62GmZoFAJqLYghezWKx1P/RzHpD7uv9DYf02EdbJUn3Xd5PV5wTY3IiwHvsLihT0akq+fta1a8bZ2MBuBeKIXi9kYm1TRS4bsgtLd9VoD+9vV6SNHFkL91yXry5gQAvU7/YamyYfFlsFYCb4VMLXq9uvaFNB4tUdLLK5DRojm15xfrt/JpFVS8fEK0HRifLYmEtIaAj1TVPYIocAHdEMQSvFxXir4SuQTIMacVupsq5i0OFpzRhziqVVFQrrVe4nrs2hUVVAROsqz0zxPpCANwRxRCg784OMVXOPRSdqtKEN7KVV1yuPpGd9OpNqfL3ZVFVoKOVlFcpJ79EkjSYM0MA3BDFECB9r4kCxZCrq6h26Nb5q7U9v1RRIX6ae3OaQgNZVBUww8bcIhmGFBsWoMgQf7PjAECzUQwBkob37iKrRdp1tEx5ReVmx0EjahZV3aCVe46rk5+P3piQplgWVQVMs3Zf7fpCPZkiB8A9UQwBkkIDfXV2bKgkpsq5ssc/2qoPNx6Wr82i2TcOUXIMbXwBM9V1khsUF2ZuEABoIYohoNbwhJoW26w35Jpe+2q3Xlu2R5L0zDXn1F/nBcAchmFoXe1iq5wZAuCuKIaAWiMTvmuiYBiGyWnwfR9sPKRHP6xZVHXqZUkamxJrciIAewrKVHiySnYfq5JZbBWAm6IYAmoN7RUuu82qw0Xl2lNQZnYc1Ppm9zFlLNogSRo/vKd+e35vkxMBkKS1tesLDYgNld2HPycAuCc+vYBaAXZb/aKBy3cxVc4V5OSVaNL81ap0OHVp/2g9OKY/i6oCLqJufSFaagNwZxRDwPew3pDrOFx0ShPeyFZJebVSe3bWjOtSZGNRVcBl1J0ZGsxiqwDcGMUQ8D0jE2uaKKzYdUxOJ9cNmaXoVJUmzFmlw0XlSugapNfGs6gqftyLL76oXr16yd/fX8OGDVN2dnaj286dO1cWi6XBzd+/4To5EyZMOG2bSy+9tL1fhlsorahWTl6xJGkQxRAAN+ZjdgDAlQzsHqYgu00nTlZpy+Hi+nbb6DgV1Q79dsFq5eSXqGuwn+ZOTFNYoN3sWHBxixYtUkZGhmbNmqVhw4ZpxowZGjVqlHJychQZGXnGx4SEhCgnJ6f+5zNNwbz00kv1xhtv1P/s5+fX9uHd0MbcQjkNKSbUX9GhLLYKwH1xZgj4Hl+bVcN615wdYqpcx3M6Df35Hxv1ze6aRVXnThyquPBAs2PBDTz33HOaNGmSJk6cqOTkZM2aNUuBgYGaM2dOo4+xWCyKjo6uv0VFRZ22jZ+fX4NtOnfmLIgkraudIjeIltoA3FyLiqHmTEV45513lJqaqrCwMAUFBSklJUULFixosA1TEeBKRrDekGme+GSbFm84JB+rRS/fMFj9Yzgzhx9XWVmpNWvWKD09vf4+q9Wq9PR0rVixotHHlZaWqmfPnoqLi9PYsWO1efPm07ZZunSpIiMj1bdvX9122206dozPBUlau4/FVgF4hmZPk2vuVITw8HDdd999SkpKkt1u1wcffKCJEycqMjJSo0aNqt+OqQhwFXVNFLL3HFdltZOWsR1kzrI9euXL3ZKkp64eqPP6dDU5EdxFQUGBHA7HaWd2oqKitG3btjM+pm/fvpozZ44GDhyooqIiPfPMMxoxYoQ2b96s7t27S6oZl375y18qPj5eu3bt0r333qvLLrtMK1askM12+jVsFRUVqqioqP+5uLi4DV+l62CxVQCepNnF0PenIkjSrFmz9OGHH2rOnDmaOnXqadtfeOGFDX6eMmWK5s2bp2XLljUohuqmIgBm6xsVrC5Bdh0rq9T6A4VKiw83O5LH+2jTYT3y4RZJ0l2j+uqXg7ubnA
iebvjw4Ro+fHj9zyNGjFC/fv00e/ZsPfLII5Kk6667rv73AwYM0MCBA5WQkKClS5fqoosuOm2fmZmZeuihh9o/vMn2HTup42WVstus6h/DYqsA3FuzvvJu6VSEOoZhKCsrSzk5OTr//PMb/K45UxEqKipUXFzc4Aa0FavVouH1U+W4bqi9Ze85rjsXrZdhSDee21O3X5hgdiS4mYiICNlsNuXn5ze4Pz8/v8lfsvn6+mrQoEHauXNno9v07t1bERERjW5zzz33qKioqP524MCBpr8IN7K2dn2h/rEh8vOhyyMA99asYuiHpiLk5eU1+riioiJ16tRJdrtdo0eP1syZM3XxxRfX//7SSy/V/PnzlZWVpSeffFJffPGFLrvsMjkcjjPuLzMzU6GhofW3uLi45rwM4Eex3lDH2JFfolvmrVJltVOXJEdp+hUsqorms9vtGjJkiLKysurvczqdysrKanD254c4HA5t2rRJ3bp1a3Sb3NxcHTt2rNFt/Pz8FBIS0uDmidaxvhAAD9IhrbWDg4O1fv16lZaWKisrSxkZGerdu3f9FLrmTkW45557lJGRUf9zcXExBRHa1MiEmmJo3f5ClVVUK8iPLvRtLa+oXOPnZKu4vFqDe4Tpr78axKKqaLGMjAyNHz9eqampSktL04wZM1RWVlY/pfumm25SbGysMjMzJUkPP/ywzj33XCUmJqqwsFBPP/209u3bp1tuuUVSTXOFhx56SFdddZWio6O1a9cu3X333UpMTGwwxdsb1Z0ZohgC4Ama9RdeS6ciWK1WJSYmSpJSUlK0detWZWZmnnY9UZ3vT0U4UzHk5+dHgwW0qx5dAtW9c4ByT5xS9t7j+mnfM69TgpYpLq/ShDeydaioXL0jgvT6+KEsqopWGTdunI4ePaoHH3xQeXl5SklJ0SeffFI/k2H//v2yWr+bDHHixAlNmjRJeXl56ty5s4YMGaLly5crOTlZkmSz2bRx40bNmzdPhYWFiomJ0SWXXKJHHnnEq8efk5XV2pZXIkka1CPM3DAA0AaaVQx9fyrClVdeKem7qQiTJ09u8n6cTmeDjjv/68emIgAdYWRChBatPqDlOwsohtpQZbVTv1uwRtvyShTRyU/zbk5T5yAWVUXrTZ48udGxaOnSpQ1+fv755/X88883uq+AgAB9+umnbRnPI2zMLZLDaSg6xF8xYQFmxwGAVmt2z+CMjAy9+uqrmjdvnrZu3arbbrvttKkI99xzT/32mZmZWrJkiXbv3q2tW7fq2Wef1YIFC3TDDTdIqpmKcNddd+mbb77R3r17lZWVpbFjxzIVAaYbkch6Q23N6TR01z83aPmuYwqy21hUFXAz9VPkeoaZGwQA2kizL4Ro7lSEsrIy3X777crNzVVAQICSkpK0cOFCjRs3ThJTEeC6RtReN7TlcLGOl1UqnLMXrfbkp9v03vqaRVVfumGIzo5lUVXAnazdVyhJGhTH9UIAPIPFMAzD7BCtVVxcrNDQUBUVFXls9x6YY9TzXyonv0Qv/nqwRg9k2mZrzP16j6YvrllL6OmrB+qaVJqeeAo+g8/M046LYRga+thnKiit1L9uG64hPVmDDYBras7nb7OnyQHepH6qHC22W+XjTYf10Ac1hdCfLzmLQghwQweOn1JBaaV8bRb1j+GsLgDPQDEE/IC6FtvLWXy1xVbtPa4ptYuq/npYD93x00SzIwFogbrrhZJjQun+CMBjUAwBPyCtd7isFmnvsZM6WHjK7DhuZ+eREt0yb7Uqq51K7xeph1lUFXBb6+rXFwozNwgAtCGKIeAHhPj7amD3MEnS15wdapb84nKNn7NKRaeqlBIXppm/GiwfGx85gLtau79QEoutAvAs/GUC/IiRtdcNMVWu6Sqrnbp57iodLDyl+IggvT4+VQF2ptUA7upUpUNbDxdLYrFVAJ6FYgj4EXXXDX2965g8oPlih5i/Yq82HypW50BfzZuYpi6daJMPuLOsbfmqdhqKDP
ZTLIutAvAgFEPAjxjcs7P8fKw6WlKhnUdKzY7j8o6WVOgvn+2QJP3fpUnq0YVFVQF3Vl7lUOZH2yRJ44bGcd0fAI9CMQT8CH9fm1J71cyRX77rmMlpXN8zn+aopKJaZ8eG0EIb8ACzv9itg4WnFBPqr9suTDA7DgC0KYohoAlG1E2V47qhH7Qpt0hvrzkgSZo+pr9sVr5BBtxZ7omTemnpTknSvaP7KdDuY3IiAGhbFENAE4xMrCmGvtl9TA4n1w2diWEYmr54swxDGpsSo9RerE4PuLvHP9qqimqnhsWHa/SAbmbHAYA2RzEENMGA2FAF+/uouLxa3x4sMjuOS3pv/SGt2XdCAb42Tb0syew4AFpp+c4CfbQpT1aLNJ01wgB4KIohoAlsVovO7V3TYvvrXUyV+19lFdXK/HirJOmOnyaoWyjdpgB3Vu1wavrizZKkG8/tqX7dQkxOBADtg2IIaKKRCXXrDdFE4X+9tHSn8osrFBceoFvO6212HACttOCbfdqeX6rOgb7648VnmR0HANoNxRDQRHXXDa3ae1zlVQ6T07iO/cdO6tWv9kiS7rs8Wf6+LK4KuLNjpRV6bsl2SdKfR/VVWKDd5EQA0H4ohoAmSozspMhgP1VUO7V2/wmz47iMRz/cospqp0YmdtGo/lFmxwHQSs/8J0cl5dXqHxOi64b2MDsOALQriiGgiSwWi0YwVa6BZTsK9J8t+bJZLZo2hgusAXe3KbdIf19V0x7/oStojw/A81EMAc0wonaqHE0UpCqHUw997wLrs6KCTU4EoDWcTkPT3v9WhiFdSXt8AF6CYghohrrrhjbmFqmkvMrkNOZa+M0+7ThSe4F1OhdYA+7u3fUHtXZ/oQLtNk29rJ/ZcQCgQ1AMAc0QGxagXl0C5XAaWrn7uNlxTHOstELP115g/adL+io00NfkRABao7SiWpkfb5Mk/f5nfRQd6m9yIgDoGBRDQDMxVU56dsl2FZdXq1+3EP0qjQusAXc3M2uHjpZUqFeXQN38k15mxwGADkMxBDTTyISaYshbmyhsPlSkt7L3S5Kmj0nmAmvAze06Wqo5X9e0x39wTLL8fGiPD8B7UAwBzTS8tqNcTn6JjpZUmJymYxmGoYcWb5FhSKMHdtOw3l3MjgSgFQzD0MOLt6jKYeinfbvqZ0m0xwfgXSiGgGYKD7IruVuIJGm5l02V+3DTYWXvOS5/X6vuvZwLrAF39/m2I/pi+1H52ix6cEx/s+MAQIejGAJawBvXGzpV6dDjH26VJP3uggTFhgWYnAhAa5RXOfTwB1skSb/5SW/FRwSZnAgAOh7FENACI72wicKsL3bpUFG5YsMC9NvzE8yOA6CVXl+2R/uOnVRksJ8m/yzR7DgAYAqKIaAF0uLD5WO1KPfEKe0/dtLsOO0u98RJzfpilyTp3sv7KcDOBdaAOztcdEovfL5TUs17upOfj8mJAMAcFENACwT5+SglLkySd5wdyvxomyqqnRoWH67LB0SbHQdAK2V+tE2nqhwa0rOzxqbEmB0HAExDMQS0UN16Q8t2enYxtGLXMX246bCsFmn6Ff1lsdBKG3Bn2XuO6/0Nh2SxSA/xngbg5SiGgBa64KyaYuijTYe1eMMhk9O0j2qHUw8t3ixJ+vWwHupX20UPgHtyOA1Ne7/mPf2rtB46OzbU5EQAYC6KIaCFBvforF8P6yHDkP64aL3+u+2I2ZHa3FurDmhbXolCA3z1p4v7mh0HQCu9mb1fWw8XK8TfR3++hPc0AFAMAS1ksVj0yNizNTYlRtVOQ79buEbf7PacVtuFJyv17H9yJEkZF5+lzkF2kxMBaI0TZd+9p/90SV+F854GAIohoDVsVoueueYcXZQUqYpqp26Zt1obcwvNjtUmnl+yXYUnq9Q3KljXD+thdhwArfRc7Xs6KZr3NADUoRgCWsnXZtWL1w/W8N5dVFpRrfFzsrUjv8TsWK2Sk1eihSv3S5IeHJMsHxsfFYA723KoWH
9buU+SNG1Mf97TAFCLT0OgDfj72vTq+FSd0z1UJ05W6frXVrrt+kOGYeihxZvlcBoa1T+qfoFZAO7JMAxNf3+znIY0emA3DU/oYnYkAHAZFENAG+nk56O5E9PUNypYR0oqdMPrK5VfXG52rGb7dHOelu86JruPVfePTjY7DoBWWrzxsLL3Hpe/r1X3Xt7P7DgA4FIohoA21DnIrgW/SVPPLoHaf/ykbnhtpY6XVZodq8nKqxx69MOtkqRbz+utuPBAkxMBaI2TldV6vPY9fceFiYoNCzA5EQC4FoohoI1Fhvhr4W+GKTrEXzuOlGrCG9kqKa8yO1aTvPrlbuWeOKXoEH/d/tMEs+MAaKUX/7tTecXligsP0KTze5sdBwBcDsUQ0A7iwgO18JY0hQfZtTG3SL+Zt1rlVQ6zY/2gw0Wn9NLSXZKkey5PUqDdx+REAFpj37EyvfrlHknS/aOT5e9rMzkRALgeiiGgnSRGBmv+zWkK9vNR9p7jum3hGlVWO82O1agnPt6mU1UOpfbsrCvOiTE7DoBWeuSDrap0OHVenwhdkhxldhwAcEkUQ0A7Ojs2VK9PGCp/X6v+m3NUGW+vl8NpmB3rNKv3Htd76w/JYpGmX9FfFovF7EgAWmFpzhF9tjVfPlaLpo1J5j0NAI2gGALaWVp8uGbdMES+Nos+2HhY97+7SYbhOgWRw2lo+uLNkqRxqXE6OzbU5EQAWqOy2qmHF2+RJE0Y0UuJkcEmJwIA10UxBHSAC/tGasa4QbJapLeyDyjz420uUxD9Y/UBfXuwWMH+PvrzqL5mxwHQSnOX79HugjJFdPLTH9L7mB0HAFwaxRDQQUYP7KYnfjlQkvTKl7v1wuc7TU4kFZ2q0tOf5kiSplzURxGd/ExOBKA1jhSX6y+f7ZAk/d+lfRXi72tyIgBwbRRDQAe6dmicHvh5zUKmzy7Zrrlf7zE1z1+zduhYWaUSugZp/IhepmYB0HpPfLJNZZUOnRMXpqsGdzc7DgC4PIohoIP95ifxmnJRzdSV6Yu36J9rck3JsfNIieYt3ytJenBMf/na+DgA3NmafSf0ztqDkqSHrugvq5WmCQDwY/jrBzDBnel9dPPIeEnS3f/coE++Pdyhz28Yhh7+YKuqnYbS+0XqgrO6dujzA2hbTqeh6e/XNEK5Zkh3pcSFmRsIANwExRBgAovFovtH99M1Q7rLaUh/eGu9vtpxtMOeP2vrEX25/ajsNqvuH53cYc8LoH38Y80BbTpYpGA/H919aZLZcQDAbVAMASaxWi164qqBunxAtCodTt06f41W7z3e7s9bUe3QIx/WtN29+Sfx6hUR1O7PCaD9FJ2q0lOf1DZCSe+jrsE0QgGApqIYAkxks1r0/LgUnX9WV52qcmji3FXafKioXZ9zzrK92nfspCKD/TT5Z4nt+lwA2t+Mz7brWFmlEiM70QgFAJqJYggwmZ+PTbNvGKKhvTqrpLxaN72erV1HS9vluY4Ul+uFz+va7iapk59PuzwPgI6xPb9E81fskyRNG5NMIxQAaCY+NQEXEGC36fUJQ9U/JkTHyip142srdbDwVJs/T13b3ZS4MP1iUGyb7x9AxzGMmqYJDqehUf2jdF4fGqEAQHNRDAEuIsTfV/NvTlNC1yAdKirXDa+t1NGSijbb/7r937XdnU7bXcDtffJtnpbvOia7D41QAKClKIYAF9Klk58W3jJMsWEB2lNQphtfX6mik1Wt3u/32+5eTdtdwO2dqnTo0Q+3SpJ+d35vxYUHmpwIANwTxRDgYrqFBuhvtwxT12A/bcsr0YS52SqrqG7VPv+1NlcbcovUyc9Hd1/at42SAjDLrC926WDhKcWE+uu2C2mEAgAtRTEEuKBeEUFa8Js0hQb4at3+Qt26YLXKqxwt2ldJeZWerG27+/ufJSoy2L8towLoYAeOn9SsL3ZJku4bnawAu83kRADgviiGABeVFB2iuROHKtBu09c7j+n3b61TtcPZ7P288P
lOFZRWKD4iSBNHxrdDUgAd6fGPtqqi2qnhvbvo8gHRZscBALdGMQS4sEE9Ouu18amy+1i1ZEu+7v7nRjmdRpMfv6egTHO+3iNJeuDn/WT34S0PuLOvdxbo42/zZLNaNO2KZFksNEIBgNbgLyPAxY1IiNBLvx4sm9Wid9Yd1PTFm2UYTSuIHv1gi6ochi7s21U/S4pq56QA2lOVw1nfCOXGc3sqKTrE5EQA4P4ohgA3kJ4cpeeuPUcWizR/xT4985+cH33M0pwjytp2RD5Wix74OW13AXe3YMU+7ThSqvAgu/6YfpbZcQDAI1AMAW5ibEqsHhl7tiTpxf/uqr+A+kwqq516+IMtkqQJI3opoWunDskIoH0UlFbo+c+2S5LuGtVXoYG+JicCAM9AMQS4kRvO7an/uzRJkvTEx9v0t5X7zrjd/BV7tftomSI62fWH9D4dGRFAO3j6kxyVlFfr7NgQXZsaZ3YcAPAYFEOAm7ntwgTdfmGCJOn+d7/Ve+sPNvj90ZIK/eWzHZJqvkEO8ecbZMCdbThQqLfXHJAkTR/TXzYrTRMAoK1QDAFu6K5RfXXjuT1lGFLG2xv02Zb8+t8982mOSiqqNSA2VNcM4RtkwJ05nUZt0xTpF4Nildor3OxIAOBRKIYAN2SxWPTQFf31i0GxcjgN3f7mWi3fVaBNuUXffYN8RbKsfIMMuLV/rzuodfsLFWS3aeplSWbHAQCP42N2AAAtY7Va9PTVA1VaUa0lW/I1ad5qxXYOkGFIV6bEaEhPvkEG3FlJeZUyP94mSfr9RX0UFeJvciIA8DycGQLcmI/Nqpm/GqSRiV1UVunQ9vxSBdptmnpZP7OjAWilmZ/vVEFpheIjgjRxZC+z4wCAR6IYAtycv69Nr9yYqkE9wiRJUy7qo+hQvkEG3FlJeZXeyt4vSXpwTLL8fGwmJwIAz8Q0OcADBPn5aNGtw7XjSIn6x4SaHQdAKwX7++rTO8/X4g2H9NO+kWbHAQCPxZkhwEPYfawUQoAHiQkL0G8vSDA7BgB4NIohAAAAAF6JYggAAACAV6IYAgAAAOCVKIYAAAAAeCWKIQAAAABeiWIIAAAAgFeiGAIAAADglSiGAAAAAHgliiEAAAAAXoliCAAAAIBXohgCAAAA4JVaVAy9+OKL6tWrl/z9/TVs2DBlZ2c3uu0777yj1NRUhYWFKSgoSCkpKVqwYEGDbQzD0IMPPqhu3bopICBA6enp2rFjR0uiAQAAAECTNLsYWrRokTIyMjRt2jStXbtW55xzjkaNGqUjR46ccfvw8HDdd999WrFihTZu3KiJEydq4sSJ+vTTT+u3eeqpp/TXv/5Vs2bN0sqVKxUUFKRRo0apvLy85a8MAAAAAH5As4uh5557TpMmTdLEiROVnJysWbNmKTAwUHPmzDnj9hdeeKF+8YtfqF+/fkpISNCUKVM0cOBALVu2TFLNWaEZM2bo/vvv19ixYzVw4EDNnz9fhw4d0rvvvtuqFwcA8B7NmbUwd+5cWSyWBjd/f/8G2zBrAQA8X7OKocrKSq1Zs0bp6enf7cBqVXp6ulasWPGjjzcMQ1lZWcrJydH5558vSdqzZ4/y8vIa7DM0NFTDhg1r0j4BAGjurAVJCgkJ0eHDh+tv+/bta/B7Zi0AgOdrVjFUUFAgh8OhqKioBvdHRUUpLy+v0ccVFRWpU6dOstvtGj16tGbOnKmLL75Ykuof15x9VlRUqLi4uMENAOC9mjtrQZIsFouio6Prb98fh5i1AADeoUO6yQUHB2v9+vVatWqVHnvsMWVkZGjp0qUt3l9mZqZCQ0Prb3FxcW0XFgDgVlo6a6G0tFQ9e/ZUXFycxo4dq82bN9f/riWzFviiDgDcT7OKoYiICNlsNuXn5ze4Pz8/X9HR0Y0/idWqxMREpaSk6E9/+pOuvvpqZWZmSlL945qzz3vuuUdFRUX1twMHDjTnZQAAPEhLZi307dtXc+bM0XvvvaeFCxfK6XRqxIgRys
3NldSyWQt8UQcA7qdZxZDdbteQIUOUlZVVf5/T6VRWVpaGDx/e5P04nU5VVFRIkuLj4xUdHd1gn8XFxVq5cmWj+/Tz81NISEiDGwAATTV8+HDddNNNSklJ0QUXXKB33nlHXbt21ezZs1u8T76oAwD349PcB2RkZGj8+PFKTU1VWlqaZsyYobKyMk2cOFGSdNNNNyk2Nrb+zE9mZqZSU1OVkJCgiooKffTRR1qwYIFefvllSTVztu+88049+uij6tOnj+Lj4/XAAw8oJiZGV155Zdu9UgCAR2rprIXv8/X11aBBg7Rz505JDWctdOvWrcE+U1JSzrgPPz8/+fn5teAVAADM0uxiaNy4cTp69KgefPBB5eXlKSUlRZ988kn9VIL9+/fLav3uhFNZWZluv/125ebmKiAgQElJSVq4cKHGjRtXv83dd9+tsrIy3XrrrSosLNRPfvITffLJJ6e1OQUA4H99f9ZC3ZdodbMWJk+e3KR9OBwObdq0SZdffrmkhrMW6oqfulkLt912W3u8DACACSyGYRhmh2it4uJihYaGqqioiClzANDBXOEzeNGiRRo/frxmz55dP2vh7bff1rZt2xQVFXXarIWHH35Y5557rhITE1VYWKinn35a7777rtasWaPk5GRJ0pNPPqknnnhC8+bNq5+1sHHjRm3ZsqVJX9a5wnEBAG/UnM/fZp8ZAgDA1TR31sKJEyc0adIk5eXlqXPnzhoyZIiWL19eXwhJzFoAAG/AmSEAQKvwGXxmHBcAMEdzPn87ZJ0hAAAAAHA1FEMAAAAAvBLFEAAAAACvRDEEAAAAwCtRDAEAAADwShRDAAAAALwSxRAAAAAAr+QRi67WLZVUXFxschIA8D51n70esGxdm2JsAgBzNGdc8ohiqKSkRJIUFxdnchIA8F4lJSUKDQ01O4bLYGwCAHM1ZVyyGB7wVZ7T6dShQ4cUHBwsi8XS7McXFxcrLi5OBw4cYJXwM+D4NI5j0ziOTeM87dgYhqGSkhLFxMTIamX2dR3GpvbDsflhHJ/GcWwa50nHpjnjkkecGbJarerevXur9xMSEuL2//jtiePTOI5N4zg2jfOkY8MZodMxNrU/js0P4/g0jmPTOE85Nk0dl/gKDwAAAIBXohgCAAAA4JUohiT5+flp2rRp8vPzMzuKS+L4NI5j0ziOTeM4NmgK/p80jmPzwzg+jePYNM5bj41HNFAAAAAAgObizBAAAAAAr0QxBAAAAMArUQwBAAAA8EoUQwAAAAC8EsWQpBdffFG9evWSv7+/hg0bpuzsbLMjmS4zM1NDhw5VcHCwIiMjdeWVVyonJ8fsWC7piSeekMVi0Z133ml2FJdx8OBB3XDDDerSpYsCAgI0YMAArV692uxYpnM4HHrggQcUHx+vgIAAJSQk6JFHHhF9bHAmjE2nY2xqOsamhhiXzoxxiWJIixYtUkZGhqZNm6a1a9fqnHPO0ahRo3TkyBGzo5nqiy++0B133KFvvvlGS5YsUVVVlS655BKVlZWZHc2lrFq1SrNnz9bAgQPNjuIyTpw4oZEjR8rX11cff/yxtmzZomeffVadO3c2O5rpnnzySb388st64YUXtHXrVj355JN66qmnNHPmTLOjwcUwNp0ZY1PTMDY1xLjUOMYlWmtr2LBhGjp0qF544QVJktPpVFxcnH7/+99r6tSpJqdzHUePHlVkZKS++OILnX/++WbHcQmlpaUaPHiwXnrpJT366KNKSUnRjBkzzI5luqlTp+rrr7/WV199ZXYUl/Pzn/9cUVFRev311+vvu+qqqxQQEKCFCxeamAyuhrGpaRibTsfYdDrGpcYxLnn5maHKykqtWbNG6enp9fdZrValp6drxYoVJiZzPUVFRZKk8PBwk5O4jjvuuEOjR49u8P8H0vvvv6/U1FRdc801ioyM1KBBg/Tqq6+aHcsljBgxQllZWdq+fbskacOGDVq2bJkuu+wyk5PBlTA2NR1j0+kYm07HuNQ4xi
XJx+wAZiooKJDD4VBUVFSD+6OiorRt2zaTUrkep9OpO++8UyNHjtTZZ59tdhyX8Pe//11r167VqlWrzI7icnbv3q2XX35ZGRkZuvfee7Vq1Sr94Q9/kN1u1/jx482OZ6qpU6equLhYSUlJstlscjgceuyxx3T99debHQ0uhLGpaRibTsfYdGaMS41jXPLyYghNc8cdd+jbb7/VsmXLzI7iEg4cOKApU6ZoyZIl8vf3NzuOy3E6nUpNTdXjjz8uSRo0aJC+/fZbzZo1y+sHnbffflt/+9vf9Oabb6p///5av3697rzzTsXExHj9sQGai7GpIcamxjEuNY5xycuLoYiICNlsNuXn5ze4Pz8/X9HR0Salci2TJ0/WBx98oC+//FLdu3c3O45LWLNmjY4cOaLBgwfX3+dwOPTll1/qhRdeUEVFhWw2m4kJzdWtWzclJyc3uK9fv37617/+ZVIi13HXXXdp6tSpuu666yRJAwYM0L59+5SZmek1gw5+HGPTj2NsOh1jU+MYlxrHuOTl1wzZ7XYNGTJEWVlZ9fc5nU5lZWVp+PDhJiYzn2EYmjx5sv7973/r888/V3x8vNmRXMZFF12kTZs2af369fW31NRUXX/99Vq/fr3XDjZ1Ro4ceVqr2+3bt6tnz54mJXIdJ0+elNXa8GPXZrPJ6XSalAiuiLGpcYxNjWNsahzjUuMYl7z8zJAkZWRkaPz48UpNTVVaWppmzJihsrIyTZw40exoprrjjjv05ptv6r333lNwcLDy8vIkSaGhoQoICDA5nbmCg4NPm58eFBSkLl26MG9d0h//+EeNGDFCjz/+uK699lplZ2frlVde0SuvvGJ2NNONGTNGjz32mHr06KH+/ftr3bp1eu6553TzzTebHQ0uhrHpzBibGsfY1DjGpcYxLkkyYMycOdPo0aOHYbfbjbS0NOObb74xO5LpJJ3x9sYbb5gdzSVdcMEFxpQpU8yO4TIWL15snH322Yafn5+RlJRkvPLKK2ZHcgnFxcXGlClTjB49ehj+/v5G7969jfvuu8+oqKgwOxpcEGPT6Ribmoex6TuMS2fGuGQYXr/OEAAAAADv5NXXDAEAAADwXhRDAAAAALwSxRAAAAAAr0QxBAAAAMArUQwBAAAA8EoUQwAAAAC8EsUQAAAAAK9EMQQAAADAK1EMAQAAAPBKFEMAAAAAvBLFEAAAAACvRDEEAAAAwCv9PybVjOb1B3HwAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot the test-set loss and accuracy stored in eval_metrics_history\n",
"fig, axs = plt.subplots(1, 2, figsize=(10, 10))\n",
"axs[0].set_title(\"Loss value on test set\")\n",
"axs[0].plot(eval_metrics_history[\"test_loss\"])\n",
"axs[1].set_title(\"Accuracy on test set\")\n",
"axs[1].plot(eval_metrics_history[\"test_accuracy\"])"
]
},
{
"cell_type": "markdown",
"id": "5dee1e2b-7dae-4af7-a4ad-bd81b9076cc9",
"metadata": {},
"source": [
"We can observe that the model starts overfitting after the 5th epoch and that the best accuracy it achieves is around 0.87. Let us also check a few of the model's predictions on the test data:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "959a55dc-2adb-4324-9b9e-114f773c5484",
"metadata": {},
"outputs": [],
"source": [
"data = test_dataset[10]"
]
},
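{
"cell_type": "markdown",
"id": "5b1e9c3a-2f47-4d8e-9a6c-0e3d7f1b8a24",
"metadata": {},
"source": [
"The prediction cell below turns the model output for a single review into a label and a confidence score. Let $z_k$ be the logit of class $k$. The cell applies a softmax over the class axis, picks the class with the largest logit, and reports that class's probability as the confidence:\n",
"\n",
"$$\n",
"p_k = \\frac{\\exp(z_k)}{\\sum_j \\exp(z_j)}, \\qquad \\hat{y} = \\arg\\max_k z_k, \\qquad \\text{confidence} = p_{\\hat{y}}\n",
"$$\n",
"\n",
"The softmax preserves the ordering of the logits, so the class with the largest logit also has the largest probability. Taking the `argmax` of the logits or of the probabilities therefore gives the same label. Note that the printing code maps label `0` to \"positive\" and label `1` to \"negative\"."
]
},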
{
"cell_type": "code",
"execution_count": 60,
"id": "6540737e-2237-49da-a9a6-87468100f061",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"- Text:\n",
" Let me say first off that I am a huge fan of the original series Lonesome Dove and the book it was based from. I have put off watching this sequel for the better part of 10 years due to the bad reviews I'd heard about it. If Tommy Lee Jones wasn't playing Capt. Call I didn't see the point. If Larry McMurtry wasn't involved why should I care? How wrong I was.
This is in so many ways a worthy sequel to Lonesome Dove, maybe even more so than the dark mood of Streets Of Laredo. The story, acting, production, cinematography are all top-notch. Of course the script isn't as colorful as Lonesome Dove but it has it's moments. And, much to my surprise, there are bits of Lonesome Done in this series; the relationship between July and Clara, completely dismissed in the prequel, is brought up here almost identical to the book, a most welcome surprise. The story isn't all roses, it has it's surprises too. By far the biggest surprise is Jon Voight's interpretation of Capt. Call. While not a direct copy of Tommy Lee Jones' his is both faithful and unique to Voight's credit. The cast is fantastic all across the board, and I don't think Rick Schroeder has done a better job of acting than in this series. Oliver Reed practically steals the show here, he is superb in a role that makes you care for his character as equally as you hate him.
It is worth it to watch this if you haven't due to bad criticisms, especially that the DVD is so affordable (I got the 2-disc set for $10.99, you can probably find it cheaper). It is in no way the disappointment that Dead Man's Walk turned out (well, it was for me). And MCMurtry was involved with that one!\n",
"\n",
"- Expected review sentiment: positive\n",
"- Predicted review sentiment: positive, confidence: 0.897\n"
]
}
],
"source": [
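"# Tokenize the raw review text with the TextPreprocessing transform\n",
"text_processing = TextPreprocessing(tokenizer, max_length=max_length)\n",
"processed_data = text_processing.map(data)\n",
"# Switch the model to evaluation mode for inference\n",
"model.eval()\n",
"# Add a batch dimension of size 1 before calling the model\n",
"preds = model(processed_data[\"text\"][None, :])\n",
"pred_label = preds.argmax(axis=-1).item()\n",
"# Convert logits to class probabilities; the confidence is the probability of the predicted class\n",
"confidence = nnx.softmax(preds, axis=-1)\n",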
"\n",
"print(\"- Text:\\n\", data[\"text\"])\n",
"print(\"\")\n",
"print(f\"- Expected review sentiment: {'positive' if data['label'] == 0 else 'negative'}\")\n",
"print(f\"- Predicted review sentiment: {'positive' if pred_label == 0 else 'negative'}, confidence: {confidence[0, pred_label]:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "a8cd4b5d-6361-406c-87db-2d35358e3199",
"metadata": {},
"outputs": [],
"source": [
"data = test_dataset[20]"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "4897165d-e5ef-4528-b3ea-43845dde6b3a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"- Text:\n",
" One of the best TV shows out there, if not the best one. Why? Simple: it has guts to show us real life in prison, without any clichés and predictable twists. This is not Prison Break or any other show, actually comparing to Oz the show Sopranos look like story for children's. Profanity, cursing, shots of explicit violence and using drugs, disgusting scenes of male sexual organs and rapes... all this and more in Oz. But this is not the best part of Oz; the characters are the strongest point of this show; they're all excellent and not annoying, despite the fact we are looking at brutal criminals. The actors are excellent, my favorite are the actors who are playing Ryan O'Reilly and Tobias Beecher, because they're so unique and changing their behavior completely. And most of all... the don't have no remorse for their actions. Overall... Oz is amazing show, the best one out there. Forget about CSI and shows about stupid doctors... this is the deal... OZ!\n",
"\n",
"- Expected review sentiment: positive\n",
"- Predicted review sentiment: negative, confidence: 0.610\n"
]
}
],
"source": [
"text_processing = TextPreprocessing(tokenizer, max_length=max_length)\n",
"processed_data = text_processing.map(data)\n",
"model.eval()\n",
"preds = model(processed_data[\"text\"][None, :])\n",
"pred_label = preds.argmax(axis=-1).item()\n",
"confidence = nnx.softmax(preds, axis=-1)\n",
"\n",
"print(\"- Text:\\n\", data[\"text\"])\n",
"print(\"\")\n",
"print(f\"- Expected review sentiment: {'positive' if data['label'] == 0 else 'negative'}\")\n",
"print(f\"- Predicted review sentiment: {'positive' if pred_label == 0 else 'negative'}, confidence: {confidence[0, pred_label]:.3f}\")"
]
},
{
"cell_type": "markdown",
"id": "c633f0db-057a-42d7-9c73-59b09712d160",
"metadata": {},
"source": [
"## Further reading\n",
"\n",
"In this tutorial we implemented a simple convolutional neural network from scratch and trained it on a text classification dataset. The trained model reaches about 87% classification accuracy on the test set, a result likely limited by its simple convolutional architecture. A possible next step to improve this metric is to replace the 1D CNN with a transformer-based architecture.\n",
"\n",
"- Model checkpointing and exporting using [Orbax](https://orbax.readthedocs.io/en/latest/).\n",
"- Other NLP tutorials in [jax-ai-stack](https://jax-ai-stack.readthedocs.io/en/latest/getting_started.html)."
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}