{ "cells": [ { "cell_type": "markdown", "id": "ee3e1116-f6cd-497e-b617-1d89d5d1f744", "metadata": {}, "source": [ "# Machine Translation with encoder-decoder transformer model\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_machine_translation.ipynb)" ] }, { "cell_type": "markdown", "id": "50f0bd58-dcc6-41f4-9dc4-3a08c8ef751b", "metadata": {}, "source": [ "This tutorial is adapted from [Keras' documentation on English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/), which is itself adapted from the book [Deep Learning with Python, Second Edition by François Chollet](https://www.manning.com/books/deep-learning-with-python-second-edition).\n", "\n", "We step through an encoder-decoder transformer in JAX and train a model for English-to-Spanish translation." ] }, { "cell_type": "code", "execution_count": 1, "id": "dd506ffa-3b91-44f1-92d1-a08ed933e78e", "metadata": {}, "outputs": [], "source": [ "import pathlib\n", "import random\n", "import string\n", "import re\n", "import numpy as np\n", "\n", "import jax.numpy as jnp\n", "import optax\n", "\n", "from flax import nnx\n", "\n", "import tiktoken\n", "import grain.python as grain\n", "import tqdm" ] }, { "cell_type": "markdown", "id": "e1f324b0-140a-48fa-9fcb-d6308f098343", "metadata": {}, "source": [ "## Download the data to a temporary directory and load it into memory\n", "\n", "There are many ways to do this. For simplicity, and so you can see exactly what happens, we download the archive to a temporary directory and extract it there. We then read the file into a Python list of `(english, spanish)` sentence pairs, wrapping each Spanish sentence in `[start]` and `[end]` markers." 
] }, { "cell_type": "code", "execution_count": 2, "id": "102943a5-8724-48e0-8d6a-f56069f03426", "metadata": {}, "outputs": [], "source": [ "import requests\n", "import zipfile\n", "import tempfile\n", "\n", "url = \"http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\"\n", "\n", "with tempfile.TemporaryDirectory() as temp_dir:\n", " temp_path = pathlib.Path(temp_dir)\n", " zip_file_path = temp_path / \"spa-eng.zip\"\n", "\n", " response = requests.get(url)\n", " response.raise_for_status() # fail early if the download did not succeed\n", " zip_file_path.write_bytes(response.content)\n", "\n", " with zipfile.ZipFile(zip_file_path, \"r\") as zip_ref:\n", " zip_ref.extractall(temp_path)\n", "\n", " text_file = temp_path / \"spa-eng\" / \"spa.txt\"\n", "\n", " # read as UTF-8 explicitly (Spanish accents) instead of relying on the platform default encoding\n", " with open(text_file, encoding=\"utf-8\") as f:\n", " lines = f.read().split(\"\\n\")[:-1]\n", " text_pairs = []\n", " for line in lines:\n", " eng, spa = line.split(\"\\t\")\n", " spa = \"[start] \" + spa + \" [end]\"\n", " text_pairs.append((eng, spa))" ] }, { "cell_type": "markdown", "id": "9524904b-fa17-493f-bcfa-335963cb7c45", "metadata": {}, "source": [ "## Build train/validate/test pair sets\n", "We stay close to the original tutorial so it's easy to see what stays the same and what changes. One early difference is that we use an off-the-shelf tokenizer from tiktoken, specifically \"cl100k_base\", which covers a wide range of languages and is fast." 
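, "\n", "\n", "As a worked check of the split below: 15% of the pairs go to validation and another 15% to test. With the 118,964 pairs reported in this run's output, that gives\n", "\n", "$$\n", "n_{\\text{val}} = n_{\\text{test}} = \\lfloor 0.15 \\times 118964 \\rfloor = \\lfloor 17844.6 \\rfloor = 17844, \\qquad n_{\\text{train}} = 118964 - 2 \\times 17844 = 83276.\n", "$$\n", "\n", "The exact counts depend on the downloaded file. Because `int()` truncates, any rounding remainder ends up in the training set."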
] }, { "cell_type": "code", "execution_count": 3, "id": "bee9f1b0-5f74-47dc-a7e1-a4ea3be1ef7f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "118964 total pairs\n", "83276 training pairs\n", "17844 validation pairs\n", "17844 test pairs\n" ] } ], "source": [ "random.shuffle(text_pairs)\n", "num_val_samples = int(0.15 * len(text_pairs))\n", "num_train_samples = len(text_pairs) - 2 * num_val_samples\n", "train_pairs = text_pairs[:num_train_samples]\n", "val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]\n", "test_pairs = text_pairs[num_train_samples + num_val_samples :]\n", "\n", "print(f\"{len(text_pairs)} total pairs\")\n", "print(f\"{len(train_pairs)} training pairs\")\n", "print(f\"{len(val_pairs)} validation pairs\")\n", "print(f\"{len(test_pairs)} test pairs\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "fd1f61fe-e4b7-479d-917e-f609ebe482e9", "metadata": {}, "outputs": [], "source": [ "tokenizer = tiktoken.get_encoding(\"cl100k_base\")" ] }, { "cell_type": "markdown", "id": "a714c4ea-9ff6-4dab-ae9c-1a884d4857e7", "metadata": {}, "source": [ "We strip punctuation (Python's `string.punctuation` plus `¿`) to keep things simple and consistent with the original tutorial. The `[` and `]` characters are kept so that the `[start]` and `[end]` markers survive." 
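, "\n", "\n", "`format_dataset` (defined below) then standardizes both sentences and truncates or pads each one (with id 0) to $T = 20$ token ids. It builds the decoder input and target by shifting the Spanish ids $y_0, \\dots, y_{T-1}$ by one position (teacher forcing):\n", "\n", "$$\n", "x^{\\text{dec}} = (y_0, \\dots, y_{T-2}), \\qquad x^{\\text{tgt}} = (y_1, \\dots, y_{T-1})\n", "$$\n", "\n", "Because the decoder applies a causal mask, position $t$ sees $y_0, \\dots, y_t$ and is trained to predict $y_{t+1}$. Both shifted sequences have length $T - 1 = 19$, while the English encoder input keeps all $T = 20$ ids."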
] }, { "cell_type": "code", "execution_count": 5, "id": "07e054d3-a20c-4aed-8f8a-fb5158df8e5b", "metadata": {}, "outputs": [], "source": [ "strip_chars = string.punctuation + \"¿\"\n", "strip_chars = strip_chars.replace(\"[\", \"\")\n", "strip_chars = strip_chars.replace(\"]\", \"\")\n", "\n", "vocab_size = tokenizer.n_vocab\n", "sequence_length = 20" ] }, { "cell_type": "code", "execution_count": 6, "id": "e2b3e5b3-8466-4c81-99da-0559c88b25ef", "metadata": {}, "outputs": [], "source": [ "def custom_standardization(input_string):\n", " lowercase = input_string.lower()\n", " return re.sub(f\"[{re.escape(strip_chars)}]\", \"\", lowercase)" ] }, { "cell_type": "code", "execution_count": 7, "id": "5bdc0673-9723-45b5-8a42-2152295df69b", "metadata": {}, "outputs": [], "source": [ "def tokenize_and_pad(text, tokenizer, max_length):\n", " # tiktoken's encode() returns a plain list of token ids (https://github.com/openai/tiktoken/blob/main/tiktoken/core.py#L81)\n", " tokens = tokenizer.encode(text)[:max_length]\n", " # right-pad with id 0 up to max_length; adds nothing when tokens already has max_length entries\n", " padded = tokens + [0] * (max_length - len(tokens))\n", " return padded" ] }, { "cell_type": "code", "execution_count": 8, "id": "235b1221-e72d-4793-addd-7bb870bd8e75", "metadata": {}, "outputs": [], "source": [ "def format_dataset(eng, spa, tokenizer, sequence_length):\n", " eng = custom_standardization(eng)\n", " spa = custom_standardization(spa)\n", " eng = tokenize_and_pad(eng, tokenizer, sequence_length)\n", " spa = tokenize_and_pad(spa, tokenizer, sequence_length)\n", " # shift by one token: the decoder sees spa[:-1] and is trained to predict spa[1:]\n", " return {\n", " \"encoder_inputs\": eng,\n", " \"decoder_inputs\": spa[:-1],\n", " \"target_output\": spa[1:],\n", " }" ] }, { "cell_type": "code", "execution_count": 9, "id": "ca013d07-1504-42cc-906f-2fcacc757008", "metadata": {}, "outputs": [], "source": [ "train_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in train_pairs]\n", "val_data = [format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in val_pairs]\n", "test_data = 
[format_dataset(eng, spa, tokenizer, sequence_length) for eng, spa in test_pairs]" ] }, { "cell_type": "markdown", "id": "90bbae98-48dd-4ae4-99bb-92336d7c0a1c", "metadata": {}, "source": [ "At this point we've extracted the data, applied formatting, and tokenized the phrases with padding. The train/validation/test sets are lists of dictionaries like the following:" ] }, { "cell_type": "code", "execution_count": 10, "id": "dcbfa780-553f-41f6-8b3e-55955db78b2a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'encoder_inputs': [72, 1390, 311, 617, 264, 3137, 449, 1461, 922, 856, 3938, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 92820, 7669, 277, 390, 33013, 1645, 78993, 409, 9686, 65744, 510, 408, 60, 0, 0, 0, 0], 'target_output': [60, 92820, 7669, 277, 390, 33013, 1645, 78993, 409, 9686, 65744, 510, 408, 60, 0, 0, 0, 0, 0]}\n" ] } ], "source": [ "## data selection example\n", "print(train_data[135])" ] }, { "cell_type": "markdown", "id": "24c6271b-e359-4aba-a583-f18c40eddba9", "metadata": {}, "source": [ "Because `random.shuffle` is called without a fixed seed, the exact example differs from run to run; the output should look something like\n", "\n", "{'encoder_inputs': [9514, 265, 3339, 264, 2466, 16930, 1618, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_inputs': [29563, 60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0], 'target_output': [60, 1826, 7206, 71086, 37116, 653, 16109, 1493, 54189, 510, 408, 60, 0, 0, 0, 0, 0, 0, 0]}" ] }, { "cell_type": "markdown", "id": "7a906a05-bd17-4a47-afe0-4422d2ea0f50", "metadata": {}, "source": [ "## Define Transformer components: Encoder, Decoder, Positional Embed\n", "\n", "This is very similar to the original source: `ops` becomes `jnp`, and `keras`/`layers` become `nnx`. A few module-specific arguments are added or dropped. Examples include the `rngs` argument that most NNX modules take and `decode=False` in the `nnx.MultiHeadAttention` calls." 
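, "\n", "\n", "To make the decoder's causal mask concrete: `get_causal_attention_mask` sets $M_{ij} = 1$ when $i \\ge j$ and $0$ otherwise, so query position $i$ can only attend to key positions $j \\le i$. For a length-4 sequence (the decoder here actually sees 19 positions), this gives\n", "\n", "$$\n", "M = \\begin{pmatrix} 1 & 0 & 0 & 0 \\\\ 1 & 1 & 0 & 0 \\\\ 1 & 1 & 1 & 0 \\\\ 1 & 1 & 1 & 1 \\end{pmatrix}\n", "$$\n", "\n", "The mask is reshaped to `(1, 1, T, T)` so that it broadcasts over the batch and attention-head dimensions when passed as `mask=causal_mask` to the decoder's self-attention (`attention_1`)."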
] }, { "cell_type": "code", "execution_count": 11, "id": "121bf138-34b3-4be9-a0fc-3bbac81f469a", "metadata": {}, "outputs": [], "source": [ "class TransformerEncoder(nnx.Module):\n", " def __init__(self, embed_dim: int, dense_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):\n", " self.embed_dim = embed_dim\n", " self.dense_dim = dense_dim\n", " self.num_heads = num_heads\n", "\n", " self.attention = nnx.MultiHeadAttention(num_heads=num_heads,\n", " in_features=embed_dim,\n", " decode=False,\n", " rngs=rngs)\n", " self.dense_proj = nnx.Sequential(\n", " nnx.Linear(embed_dim, dense_dim, rngs=rngs),\n", " nnx.relu,\n", " nnx.Linear(dense_dim, embed_dim, rngs=rngs),\n", " )\n", "\n", " self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)\n", " self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)\n", "\n", " def __call__(self, inputs, mask=None):\n", " if mask is not None:\n", " padding_mask = jnp.expand_dims(mask, axis=1).astype(jnp.int32)\n", " else:\n", " padding_mask = None\n", "\n", " attention_output = self.attention(\n", " inputs_q = inputs, inputs_k = inputs, inputs_v = inputs, mask=padding_mask, decode = False\n", " )\n", " proj_input = self.layernorm_1(inputs + attention_output)\n", " proj_output = self.dense_proj(proj_input)\n", " return self.layernorm_2(proj_input + proj_output)\n", "\n", "\n", "class PositionalEmbedding(nnx.Module):\n", " def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, rngs: nnx.Rngs, **kwargs):\n", " self.token_embeddings = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)\n", " self.position_embeddings = nnx.Embed(num_embeddings=sequence_length, features=embed_dim, rngs=rngs)\n", " self.sequence_length = sequence_length\n", " self.vocab_size = vocab_size\n", " self.embed_dim = embed_dim\n", "\n", " def __call__(self, inputs):\n", " length = inputs.shape[1]\n", " positions = jnp.arange(0, length)[None, :]\n", " embedded_tokens = self.token_embeddings(inputs)\n", " embedded_positions = 
self.position_embeddings(positions)\n", " return embedded_tokens + embedded_positions\n", "\n", " def compute_mask(self, inputs, mask=None):\n", " if mask is None:\n", " return None\n", " else:\n", " return jnp.not_equal(inputs, 0)\n", "\n", "class TransformerDecoder(nnx.Module):\n", " def __init__(self, embed_dim: int, latent_dim: int, num_heads: int, rngs: nnx.Rngs, **kwargs):\n", " self.embed_dim = embed_dim\n", " self.latent_dim = latent_dim\n", " self.num_heads = num_heads\n", " self.attention_1 = nnx.MultiHeadAttention(num_heads=num_heads,\n", " in_features=embed_dim,\n", " decode=False,\n", " rngs=rngs)\n", " self.attention_2 = nnx.MultiHeadAttention(num_heads=num_heads,\n", " in_features=embed_dim,\n", " decode=False,\n", " rngs=rngs)\n", "\n", " self.dense_proj = nnx.Sequential(\n", " nnx.Linear(embed_dim, latent_dim, rngs=rngs),\n", " nnx.relu,\n", " nnx.Linear(latent_dim, embed_dim, rngs=rngs),\n", " )\n", " self.layernorm_1 = nnx.LayerNorm(embed_dim, rngs=rngs)\n", " self.layernorm_2 = nnx.LayerNorm(embed_dim, rngs=rngs)\n", " self.layernorm_3 = nnx.LayerNorm(embed_dim, rngs=rngs)\n", "\n", " def __call__(self, inputs, encoder_outputs, mask=None):\n", " causal_mask = self.get_causal_attention_mask(inputs.shape[1])\n", " if mask is not None:\n", " padding_mask = jnp.expand_dims(mask, axis=1).astype(jnp.int32)\n", " padding_mask = jnp.minimum(padding_mask, causal_mask)\n", " else:\n", " padding_mask = None\n", " attention_output_1 = self.attention_1(\n", " inputs_q=inputs, inputs_v=inputs, inputs_k=inputs, mask=causal_mask\n", " )\n", " out_1 = self.layernorm_1(inputs + attention_output_1)\n", "\n", " attention_output_2 = self.attention_2( ## https://github.com/google/flax/blob/main/flax/nnx/nn/attention.py#L403-L405\n", " inputs_q=out_1,\n", " inputs_v=encoder_outputs,\n", " inputs_k=encoder_outputs,\n", " mask=padding_mask,\n", " )\n", " out_2 = self.layernorm_2(out_1 + attention_output_2)\n", "\n", " proj_output = self.dense_proj(out_2)\n", " return 
self.layernorm_3(out_2 + proj_output)\n", "\n", " def get_causal_attention_mask(self, sequence_length):\n", " i = jnp.arange(sequence_length)[:, None]\n", " j = jnp.arange(sequence_length)\n", " mask = (i >= j).astype(jnp.int32)\n", " mask = jnp.reshape(mask, (1, 1, sequence_length, sequence_length))\n", " return mask" ] }, { "cell_type": "markdown", "id": "d033ae31-cc43-4e61-8d7f-cdc6d55b8bf9", "metadata": {}, "source": [ "Here we combine the encoder, decoder, and positional embedding classes into the model that we'll train and later use for inference. A single `PositionalEmbedding` is shared by the English encoder inputs and the Spanish decoder inputs. This works because both languages are tokenized with the same cl100k_base vocabulary." ] }, { "cell_type": "code", "execution_count": 12, "id": "c5dcfaf6-f5cd-40f4-bbf0-2754c0193327", "metadata": {}, "outputs": [], "source": [ "class TransformerModel(nnx.Module):\n", " def __init__(self, sequence_length: int, vocab_size: int, embed_dim: int, latent_dim: int, num_heads: int, dropout_rate: float, rngs: nnx.Rngs):\n", " self.sequence_length = sequence_length\n", " self.vocab_size = vocab_size\n", " self.embed_dim = embed_dim\n", " self.latent_dim = latent_dim\n", " self.num_heads = num_heads\n", " self.dropout_rate = dropout_rate\n", "\n", " self.encoder = TransformerEncoder(embed_dim, latent_dim, num_heads, rngs=rngs)\n", " self.positional_embedding = PositionalEmbedding(sequence_length, vocab_size, embed_dim, rngs=rngs)\n", " self.decoder = TransformerDecoder(embed_dim, latent_dim, num_heads, rngs=rngs)\n", " self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)\n", " self.dense = nnx.Linear(embed_dim, vocab_size, rngs=rngs)\n", "\n", " def __call__(self, encoder_inputs: jnp.ndarray, decoder_inputs: jnp.ndarray, mask: jnp.ndarray | None = None, deterministic: bool | None = None):\n", " x = self.positional_embedding(encoder_inputs)\n", " encoder_outputs = self.encoder(x, mask=mask)\n", "\n", " x = self.positional_embedding(decoder_inputs)\n", " decoder_outputs = self.decoder(x, encoder_outputs, mask=mask)\n", " # nnx.Dropout: deterministic=None (the default here) defers to the flag set by model.train()/model.eval();\n", " # an explicit True/False would override it, so passing False here would keep dropout on during eval\n", "
decoder_outputs = self.dropout(decoder_outputs, deterministic=deterministic)\n", "\n", " logits = self.dense(decoder_outputs)\n", " return logits" ] }, { "cell_type": "markdown", "id": "1744cd95-afcc-4a82-9a00-18fef4f6f7df", "metadata": {}, "source": [ "## Build out Data Loader and Training Definitions\n", "We use Grain only to sample the examples (shuffling the training set) and batch them, since tokenization already happened eagerly above. Moving that work into Grain transforms could be more efficient, but this way it's abundantly clear what's happening: the tokenized dictionaries go in, and batches of NumPy arrays come out under the same keys, `encoder_inputs`, `decoder_inputs`, and `target_output`." ] }, { "cell_type": "code", "execution_count": 13, "id": "1fb8cb44-9012-4802-9286-1efc19dd2ba1", "metadata": {}, "outputs": [], "source": [ "batch_size = 512 # used by the data loaders here and by the training loop later on\n", "\n", "class CustomPreprocessing(grain.MapTransform):\n", " def __init__(self):\n", " pass\n", "\n", " def map(self, data):\n", " # convert each token-id list to a NumPy array so grain.Batch can stack them\n", " return {\n", " \"encoder_inputs\": np.array(data[\"encoder_inputs\"]),\n", " \"decoder_inputs\": np.array(data[\"decoder_inputs\"]),\n", " \"target_output\": np.array(data[\"target_output\"]),\n", " }\n", "\n", "train_sampler = grain.IndexSampler(\n", " len(train_data),\n", " shuffle=True,\n", " seed=12, # Seed for reproducibility\n", " shard_options=grain.NoSharding(), # No sharding since it's a single-device setup\n", " num_epochs=1, # Iterate over the dataset for one epoch\n", ")\n", "\n", "val_sampler = grain.IndexSampler(\n", " len(val_data),\n", " shuffle=False,\n", " seed=12,\n", " shard_options=grain.NoSharding(),\n", " num_epochs=1,\n", ")\n", "\n", "train_loader = grain.DataLoader(\n", " data_source=train_data,\n", " sampler=train_sampler, # Sampler to determine how to access the data\n", " worker_count=4, # Number of child processes launched to parallelize the transformations\n", " worker_buffer_size=2, # Count of output batches to produce in advance per worker\n", " operations=[\n", " CustomPreprocessing(),\n", "
grain.Batch(batch_size=batch_size, drop_remainder=True),\n", " ]\n", ")\n", "\n", "val_loader = grain.DataLoader(\n", " data_source=val_data,\n", " sampler=val_sampler,\n", " worker_count=4,\n", " worker_buffer_size=2,\n", " operations=[\n", " CustomPreprocessing(),\n", " grain.Batch(batch_size=batch_size),\n", " ]\n", ")" ] }, { "cell_type": "markdown", "id": "40d9707d-a73c-47f5-8c12-1f336e526e61", "metadata": {}, "source": [ "Optax doesn't have the exact loss function that the source tutorial uses, but softmax cross-entropy works well here. `optax.softmax_cross_entropy_with_integer_labels` takes the integer token ids directly. If you use `optax.softmax_cross_entropy` instead, the labels must be one-hot encoded first." ] }, { "cell_type": "code", "execution_count": 14, "id": "d2f8e06f-1126-41cc-b8d8-de6bd7a5255a", "metadata": {}, "outputs": [], "source": [ "def compute_loss(logits, labels):\n", " # per-position cross-entropy from integer labels, shape (batch, sequence_length - 1)\n", " loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)\n", " # average over every position, padding (id 0) included\n", " return jnp.mean(loss)" ] }, { "cell_type": "markdown", "id": "0a1b625a-d9e7-4028-bc98-521ce1632450", "metadata": {}, "source": [ "In the original tutorial, Keras handles most of the model and training details internally. Here we make them explicit in our step functions, which are later used in `train_one_epoch` and `evaluate_model`." 
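, "\n", "\n", "Written out, consider a batch of $B$ sequences with $T - 1 = 19$ target positions, logits $z_{b,t} \\in \\mathbb{R}^{V}$ (with $V$ = `vocab_size`) and integer targets $y_{b,t}$. The loss that `train_step` differentiates is\n", "\n", "$$\n", "\\mathcal{L} = \\frac{1}{B\\,(T-1)} \\sum_{b=1}^{B} \\sum_{t=1}^{T-1} \\Big( \\log \\sum_{k=1}^{V} e^{z_{b,t,k}} - z_{b,t,y_{b,t}} \\Big)\n", "$$\n", "\n", "Since `compute_loss` takes a plain `jnp.mean`, padded positions (token id 0) are included in this average. `nnx.metrics.Accuracy` in the evaluation step counts them as well."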
] }, { "cell_type": "code", "execution_count": 15, "id": "279d991f-f129-48b3-9b7e-d143019c18a8", "metadata": {}, "outputs": [], "source": [ "@nnx.jit\n", "def train_step(model, optimizer, batch):\n", " def loss_fn(model, train_encoder_input, train_decoder_input, train_target_input):\n", " logits = model(train_encoder_input, train_decoder_input)\n", " loss = compute_loss(logits, train_target_input)\n", " return loss\n", "\n", " # loss value and gradients w.r.t. the model's parameters, then one optimizer update\n", " grad_fn = nnx.value_and_grad(loss_fn)\n", " loss, grads = grad_fn(model, jnp.array(batch[\"encoder_inputs\"]), jnp.array(batch[\"decoder_inputs\"]), jnp.array(batch[\"target_output\"]))\n", " optimizer.update(grads)\n", " return loss\n", "\n", "@nnx.jit\n", "def eval_step(model, batch, eval_metrics):\n", " labels = jnp.array(batch[\"target_output\"])\n", " logits = model(jnp.array(batch[\"encoder_inputs\"]), jnp.array(batch[\"decoder_inputs\"]))\n", " loss = compute_loss(logits, labels)\n", "\n", " # accumulate the batch loss and token-level accuracy into the running metrics\n", " eval_metrics.update(\n", " loss=loss,\n", " logits=logits,\n", " labels=labels,\n", " )" ] }, { "cell_type": "markdown", "id": "04e53ee9-6da1-431c-8b3f-f619d3fee68f", "metadata": {}, "source": [ "Here, `nnx.MultiMetric` accumulates the evaluation loss and accuracy across batches, while plain Python dictionaries hold the history of those values over training." ] }, { "cell_type": "code", "execution_count": 16, "id": "32a17edc-33d0-41bc-a516-8b8ce45c3ad7", "metadata": {}, "outputs": [], "source": [ "eval_metrics = nnx.MultiMetric(\n", " loss=nnx.metrics.Average('loss'),\n", " accuracy=nnx.metrics.Accuracy(),\n", ")\n", "\n", "train_metrics_history = {\n", " \"train_loss\": [],\n", "}\n", "\n", "eval_metrics_history = {\n", " \"test_loss\": [],\n", " \"test_accuracy\": [],\n", "}" ] }, { "cell_type": "code", "execution_count": 17, "id": "1189a6a6-2cc6-4c87-9f87-b4b800a1513d", "metadata": {}, "outputs": [], "source": [ "## Hyperparameters\n", "rng = nnx.Rngs(0)\n", "embed_dim = 256\n", "latent_dim = 2048\n", "num_heads = 8\n", "dropout_rate = 0.5\n", 
"vocab_size = tokenizer.n_vocab\n", "sequence_length = 20\n", "learning_rate = 1.5e-3\n", "num_epochs = 10" ] }, { "cell_type": "code", "execution_count": 18, "id": "fbeb6101-be11-4a33-9650-a3efd3656855", "metadata": {}, "outputs": [], "source": [ "bar_format = \"{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]\"\n", "train_total_steps = len(train_data) // batch_size\n", "\n", "def train_one_epoch(epoch):\n", " model.train() # Training mode: sets deterministic=False on the Dropout layers\n", " with tqdm.tqdm(\n", " desc=f\"[train] epoch: {epoch}/{num_epochs}, \",\n", " total=train_total_steps,\n", " bar_format=bar_format,\n", " leave=True,\n", " ) as pbar:\n", " for batch in train_loader:\n", " loss = train_step(model, optimizer, batch)\n", " train_metrics_history[\"train_loss\"].append(loss.item())\n", " pbar.set_postfix({\"loss\": loss.item()})\n", " pbar.update(1)\n", "\n", "\n", "def evaluate_model(epoch):\n", " # Compute the metrics on the validation set after each training epoch.\n", " model.eval() # Evaluation mode: sets deterministic=True on the Dropout layers\n", 
"\n", " eval_metrics.reset() # Reset the eval metrics\n", " for val_batch in val_loader:\n", " eval_step(model, val_batch, eval_metrics)\n", "\n", " for metric, value in eval_metrics.compute().items():\n", " eval_metrics_history[f'test_{metric}'].append(value)\n", "\n", " print(f\"[test] epoch: {epoch + 1}/{num_epochs}\")\n", " print(f\"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}\")\n", " print(f\"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "49a1d33a-c2e4-4d48-821b-519f5c0192c7", "metadata": {}, "outputs": [], "source": [ "model = TransformerModel(sequence_length, vocab_size, embed_dim, latent_dim, num_heads, dropout_rate, rngs=rng)\n", "optimizer = nnx.Optimizer(model, optax.adamw(learning_rate))" ] }, { "cell_type": "markdown", "id": "fa7d5601-60c1-4131-a40c-c670f055ce68", "metadata": {}, "source": [ "## Start the Training!\n", "With the data loaders, model, optimizer, and per-epoch train/eval functions set up, it's time to press go. On an RTX 3090 with `batch_size` set to 512, this uses roughly 19GB of VRAM and each epoch takes roughly 18 seconds. (JAX preallocates 75% of GPU memory by default, so the VRAM figure likely reflects that preallocation rather than what the model strictly needs.)" 
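, "\n", "\n", "The training loader uses `drop_remainder=True`. With the 83,276 training pairs from this run and a batch size of 512, each epoch therefore runs\n", "\n", "$$\n", "\\left\\lfloor \\frac{83276}{512} \\right\\rfloor = 162 \\text{ steps}, \\qquad 83276 - 162 \\times 512 = 332 \\text{ pairs left out per epoch},\n", "$$\n", "\n", "which matches `train_total_steps` and the `/162` total in the progress bars below. The validation loader keeps its last, smaller batch."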
] }, { "cell_type": "code", "execution_count": 20, "id": "c764510c-4d98-46ad-b877-8cfc2fa5a9ea", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 0/10, [160/162], loss=1.98 [00:27<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 1/10\n", "- total loss: 1.9655\n", "- Accuracy: 0.6774\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 1/10, [160/162], loss=1.16 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 2/10\n", "- total loss: 1.1961\n", "- Accuracy: 0.7903\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 2/10, [160/162], loss=0.846 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 3/10\n", "- total loss: 1.0054\n", "- Accuracy: 0.8167\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 3/10, [160/162], loss=0.695 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 4/10\n", "- total loss: 0.9351\n", "- Accuracy: 0.8289\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 4/10, [160/162], loss=0.593 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 5/10\n", "- total loss: 0.8976\n", "- Accuracy: 0.8369\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 5/10, [160/162], loss=0.511 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 6/10\n", "- total loss: 0.8876\n", "- Accuracy: 0.8396\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 6/10, [160/162], loss=0.454 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 7/10\n", "- total loss: 0.8857\n", "- Accuracy: 0.8426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 7/10, [160/162], 
loss=0.421 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 8/10\n", "- total loss: 0.8959\n", "- Accuracy: 0.8427\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 8/10, [160/162], loss=0.371 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 9/10\n", "- total loss: 0.9128\n", "- Accuracy: 0.8434\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[train] epoch: 9/10, [160/162], loss=0.341 [00:18<00:00]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] epoch: 10/10\n", "- total loss: 0.9227\n", "- Accuracy: 0.8452\n" ] } ], "source": [ "for epoch in range(num_epochs):\n", " train_one_epoch(epoch)\n", " evaluate_model(epoch)" ] }, { "cell_type": "markdown", "id": "f922eac4-8338-4a0d-bc6d-1f5880079bde", "metadata": {}, "source": [ "We can then plot the training loss over time. A log scale comes in handy here; otherwise it's hard to see the progress after the first 1000 or so steps." 
] }, { "cell_type": "code", "execution_count": 21, "id": "a79ecfa5-d74a-4956-9ee2-cbed86d5a82f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAAGdCAYAAADJ6dNTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABWqklEQVR4nO3deVxUVeMG8GcWhn1RdhDEXVFEXHDL1CSXTNN6y8zMzBZLf2WWmZmalUu7qaSVpVaatpia+76jKCqGuCDiziqyyzr39wdwmWFmWGRWeL6fj5935tw7957jS8zjuWeRCIIggIiIiMhMSE1dASIiIiJVDCdERERkVhhOiIiIyKwwnBAREZFZYTghIiIis8JwQkRERGaF4YSIiIjMCsMJERERmRW5qStQW0qlEnfu3IGjoyMkEompq0NEREQ1IAgCsrOz4ePjA6m06r4Riwsnd+7cgZ+fn6mrQURERA/g5s2baNKkSZXnWFw4cXR0BFDaOCcnJxPXhoiIiGoiKysLfn5+4vd4VSwunJQ/ynFycmI4ISIisjA1GZLBAbFERERkVhhOiIiIyKwwnBAREZFZsbgxJ0Rk/kpKSlBUVGTqahCREclkMsjlcr0s88FwQkR6lZOTg1u3bkEQBFNXhYiMzM7ODt7e3lAoFHW6DsMJEelNSUkJbt26BTs7O7i7u3OhRKIGQhAEFBYWIjU1FQkJCWjVqlW1C61VheGEiPSmqKgIgiDA3d0dtra2pq4OERmRra0trKyscP36dRQWFsLGxuaBr8UBsUSkd+wxIWqY6tJbonYdvVyFiIiISE8YToiI6pEDBw5AIpEgIyPD6Pfu168fpkyZUufrrFq1Ci4uLnW+zoOSSCTYuHGjye5fEx999BE6depUq88EBARg0aJFBqmPvjGcEFGD9+KLL2LEiBGmrgaVGTVqFC5fvmzw+zzIF/yD0nfgevfdd7F3795afebkyZN49dVX9VYHQ+KAWCIiMhtFRUWwtbVtsAOqCwsLazQN18HBAQ4ODrW6tru7+4NWy+jYc1Lm0OVUfLT5PDZH3zF1VYjIzBw8eBChoaGwtraGt7c33n//fRQXF4vH//rrLwQFBcHW1haurq4ICwtDbm4ugNLHLKGhobC3t4eLiwt69+6N69eva71Pr169MH36dLWy1NRUWFlZ4dChQwCAX3/9FV27doWjoyO8vLzw3HPPISUlRWfdtfUOLFq0CAEBAWplK1asQLt27WBjY4O2bdviu+++q/LvJDc3Fy+88AIcHBzg7e2Nr776SuMcbY9HXFxcsGrVKgDAtWvXIJFIsH79evTt2xc2NjZYs2aNRi9DeRt+/fVXBAQEwNnZGc8++yyys7PFc7KzszFmzBjY29vD29sb33zzTZWPmVatWoW5c+ciOjoaEokEEolErBcApKWlYeTIkbCzs0OrVq2wefNmtc/HxMRgyJAhcHBwgKenJ8aOHYu0tDSt9zpw4ADGjx+PzMxM8V4fffQRgNJHLZ988gleeOEFODk5iT0b06dPR+vWrWFnZ4fmzZtj1qxZagsbVv7/tbz378svv4S3tzdcXV0xadIktc9UfqwjkUiwYsWKKtu5efNmtGrVCjY2Nujfvz9Wr15tlMeGDCdlzt3KwKpj1xARr/2Hi4hqTxAE5BUWm+SPvhaBu337Nh577DF069YN0dHRWLZsGX766Sd8+umnAIDExESMHj0aL730Ei5cuIADBw7gySefhCAIKC
4uxogRI9C3b1+cO3cOERERePXVV3XOZhozZgzWrVunVvf169fDx8cHffr0AVDas/DJJ58gOjoaGzduxLVr1/Diiy/WqY1r1qzB7NmzMW/ePFy4cAHz58/HrFmzsHr1ap2fmTZtGg4ePIhNmzZh165dOHDgAE6fPv1A93///ffx1ltv4cKFCxg0aJDWc+Lj47Fx40Zs2bIFW7ZswcGDB7Fw4ULx+NSpU3H06FFs3rwZu3fvxuHDh6usz6hRo/DOO++gffv2SExMRGJiIkaNGiUenzt3Lp555hmcO3cOjz32GMaMGYP09HQAQEZGBh555BGEhITg1KlT2LFjB5KTk/HMM89ovVevXr2waNEiODk5ifd69913xeNffvklgoODcebMGcyaNQsA4OjoiFWrViE2NhbffvstfvzxR3zzzTdV/j3u378f8fHx2L9/P1avXo1Vq1apBS5tqmpnQkIC/ve//2HEiBGIjo7Ga6+9hpkzZ1Z5PX3hY50y5b8slEoTV4SoHrlfVILA2TtNcu/YjwfBTlH3X3Hfffcd/Pz8sHTpUkgkErRt2xZ37tzB9OnTMXv2bCQmJqK4uBhPPvkkmjZtCgAICgoCAKSnpyMzMxOPP/44WrRoAQBo166dzns988wzmDJlCo4cOSKGkbVr12L06NHi76iXXnpJPL958+ZYvHgxunXrhpycnFp385ebM2cOvvrqKzz55JMAgGbNmiE2Nhbff/89xo0bp3F+Tk4OfvrpJ/z2228YMGAAAGD16tVo0qTJA91/ypQp4r11USqVWLVqFRwdHQEAY8eOxd69ezFv3jxkZ2dj9erVWLt2rViflStXwsfHR+f1bG1t4eDgALlcDi8vL43jL774IkaPHg0AmD9/PhYvXozIyEgMHjwYS5cuRUhICObPny+e//PPP8PPzw+XL19G69at1a6lUCjg7OwMiUSi9V6PPPII3nnnHbWyDz/8UHwdEBCAd999F+vWrcN7772ns02NGjXC0qVLIZPJ0LZtWwwdOhR79+7FK6+8ovMzVbXz+++/R5s2bfDFF18AANq0aYOYmBjMmzdP5/X0hT0nZcr/IaPkkttEpOLChQvo2bOnWm9H7969xWX6g4ODMWDAAAQFBeHpp5/Gjz/+iHv37gEAGjdujBdffBGDBg3CsGHD8O233yIxMVHnvdzd3TFw4ECsWbMGQOm/XCMiIjBmzBjxnKioKAwbNgz+/v5wdHRE3759AQA3btx4oPbl5uYiPj4eEyZMEMcxODg44NNPP0V8fLzWz8THx6OwsBDdu3cXyxo3bow2bdo8UB26du1a7TkBAQFiMAEAb29v8XHW1atXUVRUhNDQUPG4s7PzA9cHADp27Ci+tre3h5OTk3i/6Oho7N+/X+3vq23btgCg8++sKtrav379evTu3RteXl5wcHDAhx9+WO3/x+3bt4dMJhPfq/4d6VJVOy9duoRu3bqpna/6d2xI7DkpIy37xcNoQqQ/tlYyxH6svZveGPc2BplMht27d+PYsWPYtWsXlixZgpkzZ+LEiRNo1qwZVq5ciTfffBM7duzA+vXr8eGHH2L37t3o0aOH1uuNGTMGb775JpYsWYK1a9ciKChI7InJzc3FoEGDMGjQIKxZswbu7u64ceMGBg0ahMLCQq3Xk0qlGo+4VMch5OTkAAB+/PFHtbBR3ra6kEgkVd67nL29fbXXsrKy0ri20oBd3VXdLycnB8OGDcNnn32m8Tlvb+9a36ty+8sD6dy5czFo0CA4Oztj3bp1Wsf11LTO+vyMMbDnpEz5v4nYc0KkPxKJBHYKuUn+6GuV2nbt2iEiIkLtS/bo0aNwdHQUH2NIJBL07t0bc+fOxZkzZ6BQKPDPP/+I54eEhGDGjBk4duwYOnTogLVr1+q83xNPPIH8/Hzs2LEDa9euVes1uXjxIu7evYuFCxeiT58+aNu2bbX/MnZ3d0dSUpJa/c+ePSu+9vT0hI+PD65evYqWLVuq/WnWrJnWa7
Zo0QJWVlY4ceKEWHbv3j2N6b/u7u5qPUVxcXHIy8ursr4Ponnz5rCyssLJkyfFsszMzGqnIysUCpSUlNT6fp07d8b58+cREBCg8XemK2jV5l7Hjh1D06ZNMXPmTHTt2hWtWrXSOYjakNq0aYNTp06plan+HRsSw0kZseeE2YSoQcrMzMTZs2fV/ty8eRNvvPEGbt68if/7v//DxYsXsWnTJsyZMwdTp06FVCrFiRMnMH/+fJw6dQo3btzAhg0bkJqainbt2iEhIQEzZsxAREQErl+/jl27diEuLq7KcSf29vYYMWIEZs2ahQsXLojjAQDA398fCoUCS5YswdWrV7F582Z88sknVbarX79+SE1Nxeeff474+HiEh4dj+/btaufMnTsXCxYswOLFi3H58mX8999/WLlyJb7++mut13RwcMCECRMwbdo07Nu3DzExMXjxxRc1li5/5JFHsHTpUpw5cwanTp3CxIkTNf6lrg+Ojo4YN24cpk2bhv379+P8+fOYMGECpFJplSE1ICAACQkJOHv2LNLS0lBQUFCj+02aNAnp6ekYPXo0Tp48ifj4eOzcuRPjx4/XGUACAgKQk5ODvXv3Ii0trcqQ1qpVK9y4cQPr1q1DfHw8Fi9erBZ2jeW1117DxYsXMX36dFy+fBl//PGHOMDW0FtUMJyUKf975jbvRA3TgQMHEBISovZn7ty58PX1xbZt2xAZGYng4GBMnDgREyZMEAcsOjk54dChQ3jsscfQunVrfPjhh/jqq68wZMgQ2NnZ4eLFi3jqqafQunVrvPrqq5g0aRJee+21KusyZswYREdHo0+fPvD39xfL3d3dsWrVKvz5558IDAzEwoUL8eWXX1Z5rXbt2uG7775DeHg4goODERkZqTZTBABefvllrFixAitXrkRQUBD69u2LVatW6ew5AYAvvvgCffr0wbBhwxAWFoaHHnoIXbp0UTvnq6++gp+fH/r06YPnnnsO7777Luzs7Kqs74P6+uuv0bNnTzz++OMICwtD7969xanRujz11FMYPHgw+vfvD3d3d/z+++81upePjw+OHj2KkpISDBw4EEFBQZgyZQpcXFx07i3Tq1cvTJw4EaNGjYK7uzs+//xzndcfPnw43n77bUyePBmdOnXCsWPHxFk8xtSsWTP89ddf2LBhAzp27Ihly5aJs3Wsra0Nem+JYGHfxllZWXB2dkZmZiacnJz0dt2fjiTgky2xGB7sg8WjQ/R2XaKGJD8/HwkJCWjWrFmddiQlqqvc3Fz4+vriq6++woQJE0xdnXpj3rx5WL58OW7evKn1eFW/A2rz/c0BsWWk5T0npq0GERE9gDNnzuDixYsIDQ1FZmYmPv74YwClY3jowX333Xfo1q0bXF1dcfToUXzxxReYPHmywe/LcFKmfMwJB8QSEVmmL7/8EpcuXYJCoUCXLl1w+PBhuLm5mbpaFi0uLg6ffvop0tPT4e/vj3feeQczZsww+H0ZTspwzAkRkeUKCQlBVFSUqatR73zzzTfVrkxrCBwQW0bC2TpERERmgeGkDNc5ISIiMg8MJ2W4zgmR/vDxKFHDpK//9hlOykjFvXVMWw8iS1a+3LmupdSJqH4rX1yurovtmWRA7MiRI3HgwAEMGDAAf/31lymqoIEDYonqTi6Xw87ODqmpqbCystK5IBUR1S+CICAvLw8pKSlwcXGp875MJgknb731Fl566SWsXr3aFLfXSsKN/4jqTCKRwNvbGwkJCSbZC4SITMvFxQVeXl51vo5Jwkm/fv1w4MABU9xaJ65zQqQfCoUCrVq14qMdogbGysqqzj0m5WodTg4dOoQvvvgCUVFRSExMxD///IMRI0aonRMeHo4vvvgCSUlJCA4OxpIlSxAaGqqXChtKxWwdk1aDqF6QSqVcvp6IHlitHwjn5uYiODgY4eHhWo+vX78eU6dOxZw5c3D69GkEBwdj0KBB1W7rbWrlj8Y55oSIiMi0at1zMmTIEAwZMkTn8a
+//hqvvPIKxo8fDwBYvnw5tm7dip9//hnvv/9+rStYUFCgto11VlZWra9RExJwKjEREZE50OtQ+sLCQkRFRSEsLKziBlIpwsLCEBER8UDXXLBgAZydncU/fn5++qquGnG2DofEEhERmZRew0laWhpKSkrg6empVu7p6YmkpCTxfVhYGJ5++mls27YNTZo0qTK4zJgxA5mZmeIfXds015U4IFZpkMsTERFRDZlkts6ePXtqfK61tTWsra0NWJtSEnERNvacEBERmZJee07c3Nwgk8mQnJysVp6cnKyXec+GJOU6J0RERGZBr+FEoVCgS5cu2Lt3r1imVCqxd+9e9OzZU5+30rvyqcScrUNERGRatX6sk5OTgytXrojvExIScPbsWTRu3Bj+/v6YOnUqxo0bh65duyI0NBSLFi1Cbm6uOHvHXEnERdhMXBEiIqIGrtbh5NSpU+jfv7/4furUqQCAcePGYdWqVRg1ahRSU1Mxe/ZsJCUloVOnTtixY4fGIFlzI+XeOkRERGah1uGkX79+1X6BT548GZMnT37gSpkCe06IiIjMA7cMLSP2nJi2GkRERA0ew0kZcbYOH+sQERGZlMWEk/DwcAQGBqJbt26GuQHXOSEiIjILFhNOJk2ahNjYWJw8edIg16/oOTHI5YmIiKiGLCacGFr5OiccEEtERGRaDCdlOOaEiIjIPDCclKlY58S09SAiImroGE7KcUAsERGRWWA4KcON/4iIiMwDw0kZqbhCLOMJERGRKTGclJFwzAkREZFZYDgpw43/iIiIzIPFhBODrxALbvxHRERkDiwmnBh+hdjS/xU4JJaIiMikLCacGJo4IFZp4ooQERE1cAwnZSQcc0JERGQWGE7KyMqe6xRx0AkREZFJMZyUsbGSAQAKikpMXBMiIqKGjeGkTHk4yS/moBMiIiJTYjgpYy0v/asoLFZCyUc7REREJsNwUqa85wQACkvYe0JERGQqDCdlbOQVfxX5HHdCRERkMgwnZeQyKeRlM3byi9hzQkREZCoMJyrKx52w54SIiMh0GE5UiNOJOWOHiIjIZCwmnBh+4z/ASlb611HEAbFEREQmYzHhxNAb/wGAlbx0zAln6xAREZmOxYQTY1CU95zwsQ4REZHJMJyoqHisw0XYiIiITIXhRIVCzjEnREREpsZwoqK854RjToiIiEyH4USFlaxsQCzHnBAREZkMw4kKTiUmIiIyPYYTFQqGEyIiIpNjOFFRMeaEs3WIiIhMheFEhThbh2NOiIiITIbhREV5OMkv5sZ/REREpsJwosJOUbrx3/1ChhMiIiJTYThRYVsWTvIYToiIiEzGYsKJMXYltrOSA2A4ISIiMiWLCSfG2JW44rFOscHuQURERFWzmHBiDHysQ0REZHoMJyrEnpMihhMiIiJTYThRYafgmBMiIiJTYzhRYcfHOkRERCbHcKKCA2KJiIhMj+FEBQfEEhERmR7DiQqOOSEiIjI9hhMVFWNOiiEI3JmYiIjIFBhOVJQ/1lEKQAF3JiYiIjIJhhMVdlYy8TU3/yMiIjINhhMVcpkUClnpX0keF2IjIiIyCYaTSmw5nZiIiMikGE4q4UJsREREpsVwUgnXOiEiIjIthpNKKlaJZTghIiIyBYsJJ+Hh4QgMDES3bt0Meh87Ky7ERkREZEoWE04mTZqE2NhYnDx50qD3sVVZiI2IiIiMz2LCibGIj3U4lZiIiMgkGE4q4YBYIiIi02I4qYRTiYmIiEyL4aSS8p2JuQgbERGRaTCcVGJrxZ4TIiIiU2I4qcTemuucEBERmRLDSSW2Cq5zQkREZEoMJ5XYlT3WyeWYEyIiIpNgOKmEy9cTERGZFsNJJVznhIiIyLQYTiqxE8ec8LEOERGRKTCcVOLhaA0ASMrKhyAIJq4NERFRw8NwUomPiy2kEiC/SInUnAJTV4eIiKjBYTipRCGXwr2s9yQ5k+GEiIjI2BhOtCgPJ6k5+SauCRERUcPDcKKFu0NZOM
lmzwkREZGxMZxoIfacMJwQEREZHcOJFh6ONgAYToiIiEyB4USL8p6TFIYTIiIio7OYcBIeHo7AwEB069bN4PdysbMCAGTeLzL4vYiIiEidxYSTSZMmITY2FidPnjT4vRxtSleJzc7nKrFERETGZjHhxJgcrEt7TnIKGE6IiIiMjeFEC/acEBERmQ7DiRYO1qXhJCu/iPvrEBERGRnDiRaeTjawsZKisFiJKyk5pq4OERFRg8JwooVCLkU7bycAQHwqwwkREZExMZzo4OVUuhBbchbXOiEiIjImhhMdPMVwws3/iIiIjInhRAcPp9JVYtlzQkREZFwMJzp4Olb0nHDGDhERkfEwnOjg7VIaTo5cScMLP0eauDZEREQNB8OJDl2aNhJfH45LQ0FxiQlrQ0RE1HAwnOhgLZepvU/h2BMiIiKjYDipwuzHA8XXfT7fjx8OxXP8CRERkYExnFThpYeaYXL/luL7+dsu4r/bmSasERERUf3HcFKNAe081N6/80c0e0+IiIgMiOGkGkG+zmrv41JyEDx3F47Fp5moRkRERPUbw0k15DIpjs8YgNBmjcWyrPxiPPfjCRPWioiIqP5iOKkBL2cbfPV0sEb52ZsZWHPiOpRKPuYhIiLSF7mpK2ApXOysNMpGhB8FANjIZXiqSxNjV4mIiKheYs9JDTlYyzE82EfrsehbGcatDBERUT3GcFJDEokEi0eH4NrCoRrHZFKJCWpERERUPzGcPIDlz3dWey9nOCEiItIbhpMHMKi9l6mrQEREVG8xnDwAiUS9p+TMjQzTVISIiKgesphwEh4ejsDAQHTr1s3UVdFw6vo9TFh1Eieu3jV1VYiIiCyeRLCwtdizsrLg7OyMzMxMODk5maweKVn5ePaH47ialqtWrm3ALBERUUNXm+9vi+k5MTceTjbYMeVhU1eDiIio3mE4qQOFXPOv72Z6nglqQkREVH8wnNTR3nf6qr0fv+ok8otKTFQbIiIiy8dwUkct3B3U3l9JycGArw6isFhpohoRERFZNoYTPXitb3O197cz7uN2xn0T1YaIiMiyMZzowYwh7TTKRn0fgYOXU/Hr8euwsAlRREREJsVwoieH3+uv9j4luwDjfo7ErI0xOHszwzSVIiIiskAMJ3ri19hO57Hs/GIj1oSIiMiyMZwYgVzGjQGJiIhqiuFEj7StewKAU4uJiIhqgeFEj9a/2gM9mjfGlv97SK38pVWn8HfULRPVioiIyLIwnOhRiH8jrHu1Jzr4Omv0orzzZ7T4Wqnk7B0iIiJdGE4M5NF2nhplBcUlmPvveTT/YBvGrDjOKcZERERaMJwYyLyRHTTK2s7agZVHrwEAjl65iwuJ2UauFRERkfljODEQFzuFRlnljpKcAk4xJiIiqozhxIBaezpUefyZ7yOMVBMiIiLLwXBiQKvGh1Z7zn+3MlFUwk0CiYiIyjGcGJCPiy2+fDoYM4a01XnOsKVH0GnuLgiCgPTcQiPWjoiIyDwxnBjY/7o0wWt9W6CVR+kjHjuFTOOc3MISjPzuGDp/shtbzyUau4pERERmheHESJaP7YIhHbzwx2s9tR4v3xxw9qYYI9aKiIjI/DCcGEkLdwcse74LOvg6I9DbSed5d3MLOYuHiIgaNIYTE1g1vht8XWx1Hu8wZyeGLTmCW/fyjFgrIiIi88BwYgIeTjY4Mr1/lef8dzsT3x+8aqQaERERmQ+GExORSCSI/GBAledcTcsxUm2IiIjMB8OJCXk42WDDG73Qr407xvVsqnH8Zvp9E9SKiIjItBhOTKyzfyOsGh8Kf1d7jWNpOQUoLlFi1/kkZORxDRQiImoY5KauAJUqKC7RKMsrLEHLmdsBAH1aueHXCd2NXS0iIiKjY8+JmXimqx96tXDFa32bY+3LmiHkcFwa/rfsGHaeTzJB7YiIiIyHPSdmws3BGmtf6SG+d7SWI7vSeienrt/Df7+fQezHgyGTSoxdRSIiIqNgz4mZUg0qqgqKlU
hIyzVybYiIiIyH4cRMBTVxxu86AsqFxCwj14aIiMh4GE7MWM8Wrtj65kMa5f/3+xmkZheYoEZERESGx3Bi5gK9neDf2A4AENbOQyzvNm8Ppv5xFkmZ+aaqGhERkUEwnJg5iUSCXW8/jKXPheDz/wWrHdtw+jYmrD5popoREREZBsOJBbCxkuHxjj5obK/Aa32bqx07fycLgiCYqGZERET6x3BiYUL8XDTKIuLvGr8iREREBmIx4SQ8PByBgYHo1q2bqatiUto6Sf6MuoW9F5KNXxkiIiIDkAgW9kwgKysLzs7OyMzMhJOTk6mrY3R7YpPx8i+ntB5b/2oPBPu5wMZKZuRaERERVa02398W03NCpfq1ccfg9l54b3AbuDlYqx0b9cNxtJ21A0qlReVNIiIiNew5sWAxtzPx+JIjGuVBvs4Y3MELk/q3NEGtiIiINLHnpIHo4Oustfy/25n4Yucl3Mm4b+QaERER1R3DiYXTtcQ9AOQVFus8RkREZK4YTixczxau+OeNXlqPZecznBARkeVhOKnHGE6IiMgSMZzUA229nGAtl6JJI1u18lmbYjD1j7M4HJeKR748gD2xXAuFiIjMH2fr1BO5BcWQyyQYsugwrqbl6jwv8oMBOBSXhmHB3rCWcz0UIiIyDs7WaYDsreWwlsvw3fOdqzxv2NIjePfPaCzeG2ekmhEREdUOw0k909bLCZsm9dZ5PDmrAACw6zwf8RARkXliOKmHgv1ccPT9R0xdDSIiogfCcFJP+brY4uzsR3Uet6iBRkRE1KAwnNRjTjZWOo9Z2DhoIiJqQBhO6jGpVKLzGKMJERGZK7mpK0CGte+dvkjKzEcjewWGfHtYLGfHCRERmSuGk3quubsDmrs7aJQnpOVi1dEEtPV2Qo/mriaoGRERkXYMJw3YR//GAgDWvNwd5+9k4pU+zSGR6H4UREREZAwcc9KArHihq9byMStOYP62izhyJQ0/H0nA4EWHkJpdYOTaERERlWI4aUAeaeuBaYPa6DyeklWAj7fE4mJSNn44FG/EmhEREVVgOGlApFIJJvVvqfO4naJirx0+3iEiIlNhOGmAdC1vfze3UHxtr+BwJCIiMg2GkwYo2M8Ff7zWU6P8w40x4msHG4YTIiIyDYaTBiq0WWNU9eTGWs4fDSIiMg1+AzVgL/RoqvPYhxtj8P1BDoolIiLjYzhpwEZ2blLl8QXbLxqpJkRERBUYThqwTn4u1Z6TU1CMUd9H4MudlxBzO9PwlSIiogaP4aSB83SyrvL40n1XcCIhHUv3X8HjS44woBARkcExnDRwElSMinVzUGgcX15p3ElE/F219wJ3ECQiIj1jOGng5j7RHgAwqX8LtPN2qvZ8qbQizBy9koaOc3dhc/Qdg9WPiIgaHolgYf/0zcrKgrOzMzIzM+HkVP2XKVUvI68QzrZW+CvqFqb9da7a8395KRRtvR3Ra8E+FCtLf3yuLRxq6GoSEZEFq833N3tOCC52CkgkEvyvS9Wzd8q98HMkHvv2iBhMiIiI9InhhEQSiURtf52qpOVw12IiIjIMhhNSE+Lv8kCfu3UvD79EXEN+UYl+K0RERA0Owwmp+fqZTghwtcMLPZvWaB2UciPCj2L2pvNYsi/OcJUjIqIGgeGE1Hg62eDAtP74+IkOUB1R0tzdvsrPpeWU7mh88HKqAWtHREQNAcMJ6TT78XYAgP97pCXsFTXbpVgm5Y8UERHVTc2+cahB6tK0MS58PBi2ChnsreX4rwarw0bfzDB8xYiIqF5jOKEq2ZbN3nmlT3MoBQG9Wrhh/MpI3Msr0vkZQRAgkUh0HiciIqoKwwnViEwqwRv9Woqvq5JfpBRDjVIpqK0qS0REVB2GE6q16tZeW3PiOhrZKRDUxBlT/zgLANg06aFqQw0RERHAcEIPoKSadPLp1gsaZXdzC+DhaGOoKhERUT3CqRVUawp57X9ssu7rHqNCRESkiuGEam35853h5mCNxaNDMDzYp0
afyahiAC0REZEqPtahWuvStDFOzhwAiUSC4cE+GNcrAMev3sWhy6k4kZCu9TNXU3PR2b8RjlxJQ5CvMxrZK4xcayIishQMJ/RAVKcKd2naCF2aNkJqdoHOcPLe3+ew+0Iydscmo7m7Pfa9089INSUiIkvDxzqkN/6N7ao8vjs2GUBpLwoREZEuDCekN2N6+GNUVz8sG9NZLGvn7VTlZ67fzcXdnAJDV42IiCwIH+uQ3ljLZfjsfx0BAI939MaWc4l4a0BLWMmkmLD6lNq52/9LRFNXezy2+DAAYO87fdHC3cHodSYiIvMjEQShmiW1zEtWVhacnZ2RmZkJJ6eq/1VOppOVX4T4lByE+DfCyWvpeHp5RLWf2fBGL3T2b2SE2hERkbHV5vubj3XIIJxsrBBSFjRsrWQ1+syOmCS193+euomVRxOQlJmPJXvjkJrNxz9ERA0BH+uQwdnUMJx4OlWsIHssPg3T/joHAJj7bywA4MDlVPz9ei/9V5CIiMwKe07I4Mo3AQSAlh66x5VsOXdHfL0nNkXjeNT1e/qtGBERmSWGEzI4ucqGf39N7KnzvDM3MrD1XCI+2RKL9Fw+wiEiaqj4WIcMzsPRGn1aucFaLoOzrRUa2yuQnluo9dxJa08buXZERGRu2HNCBieRSPDrhO5YMa4rJBIJ7K0rHvN8Xjb1uC4OXk7Fcz8ex427eXW+FhERmR7DCRldBx9n8fUzXf3Q3M2+Ttcb93MkjsXfxYebYupaNSIiMgMmCSdbtmxBmzZt0KpVK6xYscIUVSAT+mREBwwL9sHvr/QAAAzq4FXrayw7EI9eC/bi1r2K3pLM+9z5mIioPjD6mJPi4mJMnToV+/fvh7OzM7p06YKRI0fC1dXV2FUhE3FzsMaS0SHi+7cGtEIbT0fYWElx5Eoafjt+Q+dnP9z4H6KuZ+BCYhYAYOm+K+IxJ5ua/Thn3i+Ck41cbfNCIiIyH0bvOYmMjET79u3h6+sLBwcHDBkyBLt27TJ2NciM2FjJMCLEF4M7eOPTEUGY8FAzOFjL8flTmuNRfjt+QwwmAKCaLxysqw8nx6/eRfDcXfhwIx8BERGZq1qHk0OHDmHYsGHw8fGBRCLBxo0bNc4JDw9HQEAAbGxs0L17d0RGRorH7ty5A19fX/G9r68vbt++/WC1p3pp1uOBODv7UTzTza/ac3+PvCm+1rYSbYlSfXeGr3ddBgCsOaG7d4aIiEyr1uEkNzcXwcHBCA8P13p8/fr1mDp1KubMmYPTp08jODgYgwYNQkqK5qJaRLrIZaU/mt8+26nGn9lw5jb2xCYDAK6m5iDg/a1o8cE2XEnJFs8RYFFbSRERNUi1DidDhgzBp59+ipEjR2o9/vXXX+OVV17B+PHjERgYiOXLl8POzg4///wzAMDHx0etp+T27dvw8fHReb+CggJkZWWp/aGG44lOvnitb/Man//yL6egVAp45KuDYtnC7RcNUTUiIjIQvY45KSwsRFRUFMLCwipuIJUiLCwMERGlu9KGhoYiJiYGt2/fRk5ODrZv345BgwbpvOaCBQvg7Ows/vHzq76rn+qXGUPaIaydR43P//v0rUolEpRvvq26B/ete3lYuP0ikjLz9VBLIiLSF73O1klLS0NJSQk8PT3Vyj09PXHxYum/XuVyOb766iv0798fSqUS7733XpUzdWbMmIGpU6eK77OyshhQGqClz3VGfGoOlu67gu2Vdi+urHzDwHJ5hcVoNmMbAMDd0Vosf3HlSVxJycGJhLv4543e+q80ERE9EJMsXz98+HAMHz68RudaW1vD2tq6+hOpXrOxkqG9jzOGdvSuNpxUdiz+rvg6Nbtiz54rKTkASvf0ISIi86HXxzpubm6QyWRITk5WK09OToaXV+0X2iIiIqKGR6/hRKFQoEuXLti7d69YplQqsXfvXvTsqXs3WqKaeqilm6mrQEREBlbrxzo5OTm4cqViVc6EhAScPXsWjRs3hr+/P6ZOnYpx48
aha9euCA0NxaJFi5Cbm4vx48frteLUMLnYKRDWzhN7LiRXfzIREVmkWoeTU6dOoX///uL78sGq48aNw6pVqzBq1CikpqZi9uzZSEpKQqdOnbBjxw6NQbJED8pOUbHYWvScgQieW/cVhvMKi2GnMMkQLCIiqqTWj3X69esHQRA0/qxatUo8Z/Lkybh+/ToKCgpw4sQJdO/eXZ91pgZuaEdvAIBfY1s421phYKD24NunVc0fASVl5kMQBIxfGYmA97di/rYL4vRjIiIyLpPsSkxUFwMDPfH3672w5f/6AAAKS5Raz+vs36jG18zOL8bBy6nYfykVAPDDoauITeSCf0REpsB+bLI4EokEXZpWBI9Gdgrx9bmPBmLCqpN4LMgbHo42Nb7mE+FHNcouJWWjvY+zRvme2GQUlSgxJMi7ljUnIqKasJhwEh4ejvDwcJSUlJi6KmRmZgxpi7ScAozt0RRONlb4c2IvAKWb/j3VuQkCfZxw6lp6rddHmfpHNNwdrdGnlbtYVlisxMu/nAIAfPV0MIZ38oGVTL0DUhAESFS3SyYiolqRCBb2YD0rKwvOzs7IzMyEk5OTqatDFuK1X09h5/kHn+ET/lxnDO3ojcy8IgR/XDEAd+qjrfHmgFYAgPyiErSdtQMAsO3NPgj04c8nEVG52nx/c8wJNQg6hqVo5WxrpVE2ae1pLN0XpxZMAOCXiGvi67+iKvb0+WjzefwVdQsp2dy3h4iothhOqEGQ6njK8mKvAI0yB2vtTzu/3HVZo0x1+vH9wopHjpHX0vHun9F4ZnlE7SpKREQMJ9QwzHisHdwcFJg+uC3eDmuNFu72ODPrUbRwt9c4VyaVoLO/S42ua68SZOy1hJprd/O0fi47vwgbz9xGdn5RtffILSjGsz9EYOXRhBrViYjI0lnMgFiiumjmZo+TM8PEgapvhZWOE3HXMqPny6eD0aSRLXot3FftdS8kZiHg/a3o38YdnfxqPnV56h/R2B2bjKEdvRH+XOcqz/0l4jqOX03H8avpGN+7WY3vQURkqdhzQg2Gthk0/du6I6Ssl+Th1u64Ov8xhDZrDFsrmca5Vdl/KRXf7NF87KPL7tjSwblbzyVWe25OQfW9K0RE9Ql7TqhBs5bL8M8bvTXLrfSf25VKAYlZ+fB1sRXLrGTVTzm2rPl0RER1x3BCpIW1vKLnJLiJM6JvZYrvA1ztdI4l0ebYlTTEpeTgUnI21p64gW+f7SQec7TRnBlUmZLhhIgaGD7WIdJCpjK956WHmiG4ScVKsQem9cfZ2Y/CzcG6Rtd6bsUJzNl8HmtP3AAAvLXurHjM3lr98VHU9XRMWHUS19JyxTIBTCdE1LAwnBBVo7WnI0KbNVYrc7FTYM/UhzGovfqmg7XZbBAAbqbfxw+H4gEAyVn5eGpZBPZeTMGb685UnKSSTS4lZWPgNwex/b/qx6oQEVkqiwkn4eHhCAwMRLdu3UxdFWog1r3aA98+2wntvJ3wSp/mcHOwxmsPNxePu9gp8PlTweL7n1/silXjQ2t9n/nbLiK/qATd5+8Vy26k5yGnoBh/R91CRl7FgNi31p3B5eQcvL7m9AO2iojI/HH5eqIa0rVnzvGrd6GQS8VdkP+OuoV3/oyu070a2yuQnltYZfm1hUPrdA8iImPi8vVEBqBrM78ezV3FYAIAw4J96nwvbcGkqnIiovqE4YRIzxRyKd4b3Mao95z773l8s7vm66wQEZkzTiUmMoA3+rXEG/1aIju/CEEf7ar+Aw/gXm4h7KxlSM4swMqj1wAAE/o0g1MNpicTEZkzhhMiA3K0sdI5fqSu/j13B+H7ryA5q0Asu5qai05+LmrnXUzKQolSQHsfZxARWQI+1iEysN8mdDfIdWdvOq8WTAAgKfO+2vuC4hIMXnQYQxcfQV5hsc5r5RToPkZEZGwMJ0QG1tbLUaNs0ahOBrnXxN9O49fj18X3qj02gbN3ou8X+5GSla/2mdXHrqHDnJ3YdPa2Qe
pERFRbDCdEBiaVSvDJiA5qZQXFJVrPXfhkUJ3vN2tjDALe34qt5xJxN0f9cdL1u3n4dOsFtbI5m88DUF+5lojIlBhOiIxAZTV8dPJzQUGxUnzftWkjsTyoif7GhUxaexqxd7I0yi8mVZRl51cs8GanqN1OzEREhsIBsURG0Kelu/j691d64Nfj18T3f73eC/lFJbCSSdVCjD689/c5jbLLyTnignIXErPF8rzCEmw8cxsh/i5o6mqv8bmb6XmY/vc5vPJwc/Rv41GreuQWFMNOIdO5VgwRkSr2nBAZgb+rHY5M74/zcwfBViFDQZFS7biNlQwyqQQSiQSRHwzQ+HzlvX0qs5JJxB6Ymoi4ehe37uXh2t1ctfIp68+i7xcHxPeFxUocv3oXBcUleOfPaByLv4vxK09We31BEPDpllj8evw6LiZlof2cnXjvL82gRESkjcWEE+6tQ5auSSM72FuXdlZ2rDTdV5WHkw2iZw9UW8ht7ctVz/gpKhHw8/ia/7cxac1pPPTZ/moDw7ytsXj2h+OYs+k8btzN0ziuVArYHZusMcj29I0MrDiSgFkbY7DsQOnGhn9G3apx/YioYbOYcDJp0iTExsbi5Mnq/9VGZO4ebuWG5c93wYF3+2k97mxnBZnKIxC5TIpxPZsCAF7r2xzvPNpa7fw3B7Sq1eJr91Q2E6zK6ojSmT/rTt5EUqUAAgB/nb6FV345hYGLDqmVZ+RVDMQtUVrU9l1EZAY45oTIBCQSCQZ38KrynB7NXdXezxwaiOGdfNGxiTOsZFIEuNkj5k4mngj2RRst05XrqiZ7gu67kAIAajsnA+qBhOGEiGrLYnpOiBqaYD8XbJzUG5EzS8egKORSdGnaCFay0v9shwX7YMaQdgj0cYJMy0jaNp4VgeXfyQ/V6t5zNsWg2YxtOo/fzSnA75E31KZE/3HyJoDSBd2KVQJJYbFS4/NERFVhzwmRGau8FH113BwUSMspxPLnO0MpAG+sOQ0AaOPliE+eaI9Zm87X6Drlj3N06fLpHo2y9/4+h7zCYnz0byxauFfM9snXsaYLEZEu7Dkhqkc2vN4bP77QFYM7eEO1L8VKJkFYoCeA0hVr/5rYE8uf76L3+3/0bywAID61YhZQfpH2npMle+Pw3l/RNXp8REQNC3tOiOoRf1c7+LvaAQBUlxSRSCTwdrbFmVmPwsFGLj4aMoao6/e0ln+1+zIA4KnOTdC90viacoXFSljJJFwfhaiBYc8JUT3VwVdztdlG9gq1YPLLS6EY0sELK2sxDbmuzt7MwJZzd8T3f0bdwr/RdzTOi0xIR+sPt2Pib1FGqxsRmQeJYGF9qllZWXB2dkZmZiacnJxMXR0is3Y5ORsutlbwcLKp9tzjV+/isx0XIZNIcEpHb0ddnZwZhm7zNMerAEDkBwPU6hnw/lbx9bWFQ9XOzcgrxA+HruKpLk3Qwt3BIHUlIv2qzfc3H+sQ1WOtPWs+xbhHc1f880ZvXE7OxsBvDlX/gQfw/cF4ncfyCms+cLbTx7sBAOtP3kTUrEfrXC8iMi98rENEaryctfeyBHrXvadyxZEEncfuF5Ugv6g0oLy17ozaMdUNDFWnJt/NVd91mYjqB4YTIlLjZGOFn8Z11Sj/aHh7g953yLeH0X3+XhQUl2DTWfUxKI8tPiy+zikoVju263yS2k7LRGT5GE6ISMOAdp7Y+qb6wm3aFnrr29odPjp6Wh5E5v0iHLyUqvWYIAjYei4RZ26oj4d59dcoDF5UGl7yCoux7b9EMcAolQIuJmWhqIQLwRFZEo45ISKtpJWm78pVwklYOw+8N7itOKbl2JU0PLfihF7u++qv2mfnLNoTh2/3xlX52Q83xmDD6dsAgFf6NIOrgzUWbr+IoR29Ef5c52rvffRKGgpLlOjfxgM5BcWwLdstmoiMy2J6TrgrMZFxVQ4nql/Sn44IUhts6+pgrfH57W/1ga+LbZX3sLWS1bg+1QUTQRDEYAIAPx5OwM
LtFwEAW88lVnv9ohIlxqw4gfErT+JKSjY6zNmJMSuO17h+RKQ/FhNOuCsxkXG1cLdHM7fSZehXvxQK1UUHKocKbb0Lbb0cq139dc0r3ete0TL3i6qe7TPk28OIuZ0JQPumhgUqA21/jyzdJ+j41XS91Y+Ias5iwgkRGZdcJsXutx/G1fmPoW9rdxSqjNuwUaj/6nBzUGh8XiKRoEQlBDjbWqGDb8WMHw9Ha3T2b6S3+uYWlMBarvtX2oXELCzacxmZeUXotXAfZmz4TzyWnluI+ypTmasLOkRkWBxzQkQ6yVVWk1UdVKqotPy9i50C/7zRCzZWMrjaK2Bd1rMS3MQFu2KTIZNKcOKDAVDIpPj33B18sfMSvh+r3719/jh1U633QzsJvjtwBYmZ+fg98gYWPBmES0nZGLToEIKbVKyom1tpRhARGRfDCRHViLfKrBxte92EaOkFmf9kEHxcbPFsqB9sygLLE5188UQn32rvt/LFbvhq9yXE3K7ZNOEvdl6q9pyU7HzsuZAsvk/NLsB3B64AAKJvZYrlmfeLtH7+Znoerqblom9r9xrViYgeDMMJEdVIU1d7/DSuK9y0DH7Vxc3Butr1UQa198TO88l4tpsfujdvjLfXRwMA2vs4Ycv/9cGe2GRYW0mRmJmPfm3cETpv7wO34ZxKAAGAQYsOIV3LQm6qPSdKpQBp2ZiaPp/vBwCsfaU7erVw03oPQRC4USFRHTGcEFGNDWjnqfdrLhoVgu0xiXiolRscrOUASsOJo40VACAsUP2ew4N9sDn6DuwUslotea+NtmACqI85KSxRwloiVXtkdDz+rkY42X8xBS+tPgkbuQzrX+uBjk1c6lQ3ooaMG/8RkVk5eS0dggCENmus9fj9whKcvJYOGysZnvk+wuD1+fbZTjh4OVVtmrKjjRwPt3bHpH4tkZydj/5tPNQ2KgQ0Nyskaui48R8RWaxuAdpDSTlbhQwPl435ePXh5vjh0FWD1uetdWc1yrLzi7H1XKK4fsq/kx/SOKfy453TN+7hm92X8ULPADwaqNkDdeteHtwdrVGiFLAlOhH923rA3bHmj9CI6hOGEyKyWB881g6v9GkOFzsrtJq5XeP4i70CsOrYNYPXIzYxU6Ps4y2xmP14IBbtiUOTRraY9tc5AEB8So5GOIm+mYEnwo8i2M8FwU2c8UvEdbTzdsL2t/rUqh430/OgkEvh6aS/LQWITIHhhIgsWlW9C+N0hJPQZo0RmaC/Bda0bd2z8ug1DOngrbGy7Z3MfBSVKCGVSMTF69advAGgNKTcvncfQOm6LLWRlV8kDthNWPAYB+WSReMibERUb+laHn9czwC93kfX1GNtY2Ia2ysw6JtDGL70iLhSreoCcArZg4WKy0nZ4utCbnRIFo7hhIjqhciZAzTKFJVWjA0NaIzVL4WihYe9Xu999Epajc9Nzy3E1bRcnL+ThfTcQsQlZ2Pj2Tvi8TuZ+eLrw3Had2hWlVtQjMTM+0jLKRDLCqtdjI7IvDGcEFG94OFog+XPq686q7rnz+bJvbH+tR7o29odAa72sHrAHoro2QPR2d9FrexILcKJqjsZ+fh4S6zO42N/isQ/Z24BAJYfjMcfp26Kx8pX7O25YC96LtiHS0k54jHVcCIIAle8JYvDcEJE9cbgDl74a2JPtPVyxLyRHdQCiL21XByHYWMlQ/ScgeKxcT2bolcLV6x9uTsCXO3E8rVaNiZ0trOCs62VXup7ISkLSSo9Jdq8vT4a/1t2DAu3X8R7f53DzfQ8LNpzGR0/2oVNZ28jK780eBy5UtHLovpY55MtF9B+zk5E38zQS52JjIEDYomoXuka0Bg7pjwMAMhXWUzNSqr+bzE7hRyONnJk5xfj1b4t4OtiCwAI9nPBtbt5AKCx0Jq23Zfr4nBcGuJScqo979T1e+Lr8kGvgPo051tlA2kB9Z6Tn48mAAC+3HUJv05QD1u3M+5j38UUjAzxxUsrT6KFhwMWPBlU63
YQ6RvDCRHVWzZWMgwM9EReYQn8GttqHN/19sPIKywRgwkATBvUBieupmNsz6YAgBUvdMWyg/Ho3dINjwV5AQDS87QPgK2tA5dS9HIdAEhU6YGZv+0C3hrQGoE+FQtdFRQpcTvjvlpbn1h6BGk5hfgr6haib2Yg8lq6WjhRKgVk5RfBxU5z12kiQ7KYFWLDw8MRHh6OkpISXL58mSvEEpHBVLc/zkOf7VPrqSg3vncA5FIJfjycoFbe3N0eV1Nz9V7P6lxbOFRj5dpNk3oj2M8FADSOAaXTkC8lZ0MQgJVHE7DxzB2sGNdVXPhO1clr6XBzsEYzN/0OMKb6qV6uEDtp0iRMmjRJbBwRkaFUt0aIai/Flv97CI8vOQIAaOZmjxd6BqCDrzPScgox4aFmAEp3Q1667wqsZFL8dCRB6zUNYU9sskbZP2dui+FEm/tFJRi86LBa2cyN/+Hwe4+olcWn5uDp5aVTpa8tHIrkrHxIJKUDk6uz7EA81p+8gT9e6wkPLhhHWnBALBFRLX00LBAA8NlTQWjl6SCWy8vGtTzRyVcMJkDpF/bHT3RAa5VzjeHlX05plDVVGfCrzeVkzTEw9wtLkJSZjyspFWupxNzOVDveff5ehM7bixKl7s748P1XsHRfHD7bcRHX7uZh6f4rNWkGNUAW03NCRGQuxvYMwOMdfdDIXgHVJ+PVTU+uydpofVq54XDcg01NromNZ25j1/lkLHu+s9bjI8KPai3vsWAvgNL1ZDwcbVBUUtHudrN3iK//Pn0LT3VuojF4OCu/CF/svKRWVlxFkKGGjT0nREQPoJF96SBR1UdAVrKqf6UqVYLM8z38tZ4z9dHWeqidbtG3MhFx9a7GsvpVScspFF9fTMxGUYkSn++4qPXc9/46h5n//KdRnq+yCm45PU9+onqE4YSISE+0zQhSFdaudMO/QG8nfDoiCIPbe2mcoyvgfPG/jnWvoIqVR6890OeKlUqsP3kTKdkFOs9Zd/ImRv9wHK/+cgojvzuK4hIlCrSsWiuVSLD2xA18WalHpZxSKeD6XeMPJCbT42MdIqI6WvliNySk5aJL08ZVnuflbIPo2QNhZ12650/4mM5o8cE2tXPkWh4N7Xr7YXg524g7G9eWXCrR2yOUwmIlbqbnVXtexNW74usLidla2yWVSPBBWS/L0I7eaOetPoNjxob/sP7UTcwfGYTnumvvaaL6iT0nRER11L+tB15SGQBbFWc7K7F3RCaVwMFa/d+IrTwc1d43aWSL1p6OUFTzyKgqvVq6VX9SDU387TRsdGyoqEthiRKzN8VolP96/Lr4Oq/SY5+8wmKsL1uu/+vdl9WOZeQVYsG2C7ikstkh1S8MJ0REJrRjSh/x9fTBbdUGkk4b1AZ7pvYFAK3hxMPRGgen9dMoq+zlGganmopLqV0oeGNNFE5eu6dRrmtmz8qjCQicvVN8X7npczafx/eHrmLQokManxUEATG3M9VWBybLw3BCRGRCTRrZ4aXezeDrYovnQksfXTiW9aYMD/YReymklUaPvt6vBSJmDEBTV3ssGtVJLI+cGaZ23uhQfzzc2h3dm1X9yKk27qoMkK2J5Czd41PKLdpzGVFly/TP/Vd9M0R5pa0HTmkJOuUOx6Xh8SVH8MJPkbWqI5kXjjkhIjKx2cMCMevxduLMn4gPBiAjrxBNGqmvSbLgySDM2FA6RsNaLhV7WYZ29Mb+SynoFqAZQD4d0QEA0Ni++iXoL3w8GPGpOVi0Jw6Xk7NxQ8fYkhMJ6TVvXA0djkvD4bg0vNgrQOOYVFo6OLY8oBUrNQfXbo6+g7jkbGSXbYQYeS0dxSVKyOvwOIxMh/+vERGZAdUpyQ7Wco1gApT2gnRp2ggAMDLEVyy3kknx7bMheL5H6X5Abw5oBQB4o18LMcCM7dkUcqkEc4YF4oWyfYNULX++M2wVMnTwdcaKcV3R0sO4C8aVW3XsmkbZzfT76Dh3F6Ku30N+UYlaT0zm/dJ9jt
78/QyW7LuCv0/fEo8dv5qOXyOuqa1Fk5ZTgLQc9Z6cqOvpSMmuenfo+4UlOH3jHpRcm8UoLGZvnXK1WZufiKi+KS5RIju/WFxnRRulUkBcSg5aeTioPQ4qKlHCSiZF5v0iBM/dBQD4c2JPrT0ux66k4bkVJ9DC3R7xJtgXSJv2Pk4YGOiFb/ZUDJDt0rQROvm5VLktwHdjOuOxIG8UFivR+sPtAIC4eUNgJZPi5LV0PL08Agq5FJc/HaLzGs+vOIEjV9Iwd3h7jNPSu0PVq5d76xARESCXSasMJkDp+JQ2Xo4a5eWzhJxtrRAzdxDuZNxHa0/N84DSGT5H338E7g7WGPjNQVy7W/30YUOzkkmx7KD6kvdR1++JY1V0OXcrE48Feav1mOQVlOD8nXS8+2c0gNIp0gCw4vBVHL2ShuVju8BaLsP3B+Nx/k4WjlwpXbX3t+PXGU6MgOGEiKgBcrCW6wwm5XxdSheVOzCtP3IKimElk6DNhzu0njumuz/+OHUTmyY9hJVHE/Bn1C2t59WFjZUUNlYy5BfVYB8AFeUzdwpVFoJLyy3AcytOaJz76dYLAIDNZ++gvY8zFmxXXwm38rL8ZBgcc0JERNVysJbDWq59fZMXewXg0xEdcHb2QAT6OGFivxZo0shWHIyrL8evpiMjr6jWnysoLg0nO84niWUjtewhpDr9uKBYiclrT2ucI5VIcL+wBImZ92tdD2OLTEjH4EWHEGmAAcyGxnBCREQP5KungzG2R1PMfjwQEokE9mVToFu4O+DI9EfEAbqqDr/XH8OCfYxaz98jbyIuORsLVXpBsspm9aiqHEYStCydL5UCQ5ccRp/P9uN2RkVAybxfhBWHryI5S3NgbVZ+Ee7l1m76tT48830ELiZl45nvI4x+77piOCEiohoLcK2YRfRUlyb4ZEQHjTVYquLX2A7uDpoLxRnadwfiqz1nz4UU8fWHG2Pgp2XGVHJWAa6m5qJYKSAuuWIxupn//IdPt17A2J/UHxUJgoCun+xByCe7kVeoGYgAIDW7ANP/OocnvzuKY1cMtyO1JWE4ISKiGtsztS9e6t0M34/tUqPzXeysNMpsFepfPYHeThor236u540O/42+U+vPaFvnJVVlw8OXV5/C17tKNy3cFZsMALicnKMWQgqKlSgsKR3rcuhyGhZsu4Cnlh0THzUBwILtF7D+1E2cvpGhdRwMgAY3hZkDYomIqMbkMilmDwus8fmH3uuP3gv3ITu/GEODvAEAdoqKr561L3dHl4BGsJbLIAgCJv9+BnkFxXi6SxPkFhRrrBZbmUImFb/8hwZ5Y+t/iVrP09fGh5WvuXjfFTzfo6naYNv2c3bi3YFtIAgCBqrsPD3xtyjx9Z7YFAT6OCH6ZgZ2xiShKpPWnkbM7UzsnPKw1n2N/jh5E2tOXMePL3SFh5ONHlpmegwnRERkME42Vtj2Zh9s/S8RT3QqHWsS4u8iHlfdlFAikSD8uc7ie1uVL+JH2npg38WKxy7lTnwwACGf7AYA9GzhqjOcGNLcLeoBShCAL3aW9qiobm5YWf8vD2gtn7UxBtOHtIWDtRxFJUpsPVfapvnbLqCDrzP2xCZj8egQMai893fpbtWf7biEr54JrmtzzILFhJPw8HCEh4ejpISbORERWRK/xnaY2LeF+L5XCzd8N6YzWrhXvQrt8E4+OH3jHkZ08kUje4VGODk4rR9sFRUB5snOvlDIpfB1scUYHY9HDKE8PGija18hGyvdoyp+PX4d94tK8OXTwTh4KVUs/yWiIuj8dvw6xvduprZ2S0ZeIUqUAqb+cRaN7KrfrsCccYVYIiKyCB9tPo9Vx64h/LnO6NG8MVzLBtb+dysTABDUxBlA6dRhXeuxmIuwdh5qA3Ars1fIEDXrUbSdpb0doc0a425OgdrqvQ+3dsfLDzXDCz9rbnp4beFQcYVgU+EKsUREVO/MGRaItwa00lghtzyUlFNYwGZ/VQUTAMgtLMHdKqYfa1
u7pLhEiSQtU5kB4Ns9cfj+UDw2vNELrvbWeP23KDzX3R8jQ3zV9nUyF+b//yARERFKx6RUt3R/+Xnldr39MNpqWcpflY2VFOte7YFAbyf8/XqvOtdTX3ov3Fer84/F31Wb3qzqmz2XkVdYgnlbL+DLnZdw6vo9TP0jGl0+3YOF2y9i6h9nceLqXX1UWy/4WIeIiOqdzLwi5BeXwNPJBjfT89Dn8/1qx9v7OOH8nSwAgKu9AlGzHhWPBby/9YHuOW9kB8z8J+bBK20ED7V0g51CJk59ruzawqEGu3dtvr/Zc0JERPWOs50VPMum1TrZaK61EuBmL77W1/TbEL9GermOIQkQYAk9EgwnRERUrznayNGkkS08nSoWeuvs3whLnwtBKw8HLH62k9r5Ux9trfU6nzzRHiM6+eDDoe20Hm9sr8Anet5PSN+irt/Dbh29JuaEA2KJiKhek0ol2P9uPygFAdfv5uHolTSM7dEUcpkUj3fU3OfnzQGt8PXuyxrlo0P9MbZnAADAViHDzH9ioJBLxQXYbKykcLat6KXxdrZBYqbmANXFo0PQqYkLXv7lJC4n5+iplTVT3Y7Oi/ZcxuMdfdDSo+pp3obGcEJERPVe+RTa1p6OaO1Z9QBZAOjTyg2H49Iwprs/+rXxgLVcCrnKLKDnQv3RpWkjXE7OwZu/nwEA2FjJkKOyoWD5RoiV9WzuCndHa+x6uy9SswvQbd6eujRNrxbticP3B6/iwieDTVoPhhMiIqJKlo7ujINxqRgY6Kl1yXiJRIK2Xk64klLR82Etl0IhrwgwHw9vL+6V4+Nsg/919UNBUQncVfYRcne0xubJvWGnkCPs64MAgIl9W2D5weo3KjSU+0WmX+yU4YSIiKgSZzsrDA/WfORTmeqWPRKJBMOCvbHrfBL6tHJDr5ZuWPtyd/x4+CqmD2mLtl7aZ6h0bOICANgz9WFcSsrB0I7euJ1xv9rNCte+3F0MP442cmTna9/12BIxnBARET2gyqtxWMtl+OGFruL7Xi3d1PYPqkpLD0e09Ch95NTBx0lrOLFXyJBbWNqz0bOFK/7vkZZo7m6Pwe290enjXSgornpMSU0VlyjVHmMZG2frEBERPaBeLUqDR1NXO71et3tzV42y9j5Oao+NJBIJ3hnYBiNDmsBWIcODLPTqqmNRu83V9NoYGsMJERHRA3J3tMbZ2Y9i99t99Xrd9j6aj4AWjw7B8ue7oLWnA35/pYfGcdVOHGdbK7w5oJXGOR881hbx8x8T33cLaKz1/jfS8x6g1vrDxzpERER14GKAHYC1bdDXwt0BLdwdsEtHEFINJ9FzBgIAfom4hoy8IrH81YdbqH3GyVZ7DCjU0+OhB8VwQkREZIbWv9oDayNv4HBcGj5+on2152tb+/X4jAFISMvFvospGNPdX+O4TKr9AUqbavYjMjSGEyIiIjPUvbkrujd3hSAINdo5uIOvM87cyIC9omLqs42VDO28ndDOW/0x0aOBntgdm4zne/ijb2t3TPwtCgDw9TPBZrFTMcMJERGRGatpUFj6XGcs2RuH8b2bVXvu9893Qcb9IjS2V6CFe8VqsE1d7U0eTADuSkxERNTgnb+TifO3s/B01yYGCye1+f5mzwkREVED197HGe19nE1dDRGnEhMREZFZYTghIiIis8JwQkRERGaF4YSIiIjMCsMJERERmRWGEyIiIjIrFhNOwsPDERgYiG7dupm6KkRERGRAXISNiIiIDK42398W03NCREREDQPDCREREZkVhhMiIiIyKwwnREREZFYYToiIiMisWNyuxOWTi7KyskxcEyIiIqqp8u/tmkwStrhwkp2dDQDw8/MzcU2IiIiotrKzs+Hs7FzlORa3zolSqcSdO3fg6OgIiUSi12tnZWXBz88PN2/erJdrqLB9lq++t7G+tw+o/21k+yyfodooCAKys7Ph4+MDqbTqUSUW13MilUrRpEkTg97Dycmp3v7QAWxffVDf21jf2wfU/z
ayfZbPEG2srsekHAfEEhERkVlhOCEiIiKzwnCiwtraGnPmzIG1tbWpq2IQbJ/lq+9trO/tA+p/G9k+y2cObbS4AbFERERUv7HnhIiIiMwKwwkRERGZFYYTIiIiMisMJ0RERGRWGE7KhIeHIyAgADY2NujevTsiIyNNXaUaWbBgAbp16wZHR0d4eHhgxIgRuHTpkto5+fn5mDRpElxdXeHg4ICnnnoKycnJaufcuHEDQ4cOhZ2dHTw8PDBt2jQUFxcbsyk1snDhQkgkEkyZMkUsqw/tu337Np5//nm4urrC1tYWQUFBOHXqlHhcEATMnj0b3t7esLW1RVhYGOLi4tSukZ6ejjFjxsDJyQkuLi6YMGECcnJyjN0UDSUlJZg1axaaNWsGW1tbtGjRAp988ona/hqW1r5Dhw5h2LBh8PHxgUQiwcaNG9WO66s9586dQ58+fWBjYwM/Pz98/vnnhm4agKrbV1RUhOnTpyMoKAj29vbw8fHBCy+8gDt37qhdw1LbV9nEiRMhkUiwaNEitXJzbh9QszZeuHABw4cPh7OzM+zt7dGtWzfcuHFDPG7S360CCevWrRMUCoXw888/C+fPnxdeeeUVwcXFRUhOTjZ11ao1aNAgYeXKlUJMTIxw9uxZ4bHHHhP8/f2FnJwc8ZyJEycKfn5+wt69e4VTp04JPXr0EHr16iUeLy4uFjp06CCEhYUJZ86cEbZt2ya4ubkJM2bMMEWTdIqMjBQCAgKEjh07Cm+99ZZYbuntS09PF5o2bSq8+OKLwokTJ4SrV68KO3fuFK5cuSKes3DhQsHZ2VnYuHGjEB0dLQwfPlxo1qyZcP/+ffGcwYMHC8HBwcLx48eFw4cPCy1bthRGjx5tiiapmTdvnuDq6ips2bJFSEhIEP7880/BwcFB+Pbbb8VzLK1927ZtE2bOnCls2LBBACD8888/asf10Z7MzEzB09NTGDNmjBATEyP8/vvvgq2trfD999+btH0ZGRlCWFiYsH79euHixYtCRESEEBoaKnTp0kXtGpbaPlUbNmwQgoODBR8fH+Gbb75RO2bO7ROE6tt45coVoXHjxsK0adOE06dPC1euXBE2bdqk9r1nyt+tDCeCIISGhgqTJk0S35eUlAg+Pj7CggULTFirB5OSkiIAEA4ePCgIQukvEisrK+HPP/8Uz7lw4YIAQIiIiBAEofSHWCqVCklJSeI5y5YtE5ycnISCggLjNkCH7OxsoVWrVsLu3buFvn37iuGkPrRv+vTpwkMPPaTzuFKpFLy8vIQvvvhCLMvIyBCsra2F33//XRAEQYiNjRUACCdPnhTP2b59uyCRSITbt28brvI1MHToUOGll15SK3vyySeFMWPGCIJg+e2r/ItfX+357rvvhEaNGqn9jE6fPl1o06aNgVukrqov73KRkZECAOH69euCINSP9t26dUvw9fUVYmJihKZNm6qFE0tqnyBob+OoUaOE559/XudnTP27tcE/1iksLERUVBTCwsLEMqlUirCwMERERJiwZg8mMzMTANC4cWMAQFRUFIqKitTa17ZtW/j7+4vti4iIQFBQEDw9PcVzBg0ahKysLJw/f96Itddt0qRJGDp0qFo7gPrRvs2bN6Nr1654+umn4eHhgZCQEPz444/i8YSEBCQlJam10dnZGd27d1dro4uLC7p27SqeExYWBqlUihMnThivMVr06tULe/fuxeXLlwEA0dHROHLkCIYMGQLA8ttXmb7aExERgYcffhgKhUI8Z9CgQbh06RLu3btnpNbUTGZmJiQSCVxcXABYfvuUSiXGjh2LadOmoX379hrH60P7tm7ditatW2PQoEHw8PBA9+7d1R79mPp3a4MPJ2lpaSgpKVH7ywUAT09PJCUlmahWD0apVGLKlCno3bs3OnToAABISkqCQqEQf2mUU21fUlKS1vaXHzO1devW4fTp01iwYIHGsfrQvqtXr2LZsmVo1aoVdu7ciddffx1vvvkmVq9eDaCijl
X9jCYlJcHDw0PtuFwuR+PGjU3exvfffx/PPvss2rZtCysrK4SEhGDKlCkYM2YMAMtvX2X6ao+5/9yWy8/Px/Tp0zF69GhxkzhLb99nn30GuVyON998U+txS29fSkoKcnJysHDhQgwePBi7du3CyJEj8eSTT+LgwYNiHU35u9XidiUm3SZNmoSYmBgcOXLE1FXRm5s3b+Ktt97C7t27YWNjY+rqGIRSqUTXrl0xf/58AEBISAhiYmKwfPlyjBs3zsS1q7s//vgDa9aswdq1a9G+fXucPXsWU6ZMgY+PT71oX0NWVFSEZ555BoIgYNmyZaaujl5ERUXh22+/xenTpyGRSExdHYNQKpUAgCeeeAJvv/02AKBTp044duwYli9fjr59+5qyegDYcwI3NzfIZDKNEcjJycnw8vIyUa1qb/LkydiyZQv279+PJk2aiOVeXl4oLCxERkaG2vmq7fPy8tLa/vJjphQVFYWUlBR07twZcrkccrkcBw8exOLFiyGXy+Hp6WnR7QMAb29vBAYGqpW1a9dOHDVfXseqfka9vLyQkpKidry4uBjp6ekmb+O0adPE3pOgoCCMHTsWb7/9ttgTZuntq0xf7TH3n9vyYHL9+nXs3r1b7DUBLLt9hw8fRkpKCvz9/cXfOdevX8c777yDgIAAsX6W2j6g9HtPLpdX+3vHlL9bG3w4USgU6NKlC/bu3SuWKZVK7N27Fz179jRhzWpGEARMnjwZ//zzD/bt24dmzZqpHe/SpQusrKzU2nfp0iXcuHFDbF/Pnj3x33//qf3HVv7LpvIPr7ENGDAA//33H86ePSv+6dq1K8aMGSO+tuT2AUDv3r01pn9fvnwZTZs2BQA0a9YMXl5eam3MysrCiRMn1NqYkZGBqKgo8Zx9+/ZBqVSie/fuRmiFbnl5eZBK1X/VyGQy8V9vlt6+yvTVnp49e+LQoUMoKioSz9m9ezfatGmDRo0aGak12pUHk7i4OOzZsweurq5qxy25fWPHjsW5c+fUfuf4+Phg2rRp2LlzJwDLbh9Q+r3XrVu3Kn/vmPy7o07DaeuJdevWCdbW1sKqVauE2NhY4dVXXxVcXFzURiCbq9dff11wdnYWDhw4ICQmJop/8vLyxHMmTpwo+Pv7C/v27RNOnTol9OzZU+jZs6d4vHw62MCBA4WzZ88KO3bsENzd3c1mqm1lqrN1BMHy2xcZGSnI5XJh3rx5QlxcnLBmzRrBzs5O+O2338RzFi5cKLi4uAibNm0Szp07JzzxxBNap6aGhIQIJ06cEI4cOSK0atXKLKYSjxs3TvD19RWnEm/YsEFwc3MT3nvvPfEcS2tfdna2cObMGeHMmTMCAOHrr78Wzpw5I85W0Ud7MjIyBE9PT2Hs2LFCTEyMsG7dOsHOzs4oU1Gral9hYaEwfPhwoUmTJsLZs2fVfu+oztCw1PZpU3m2jiCYd/sEofo2btiwQbCyshJ++OEHIS4uTliyZIkgk8mEw4cPi9cw5e9WhpMyS5YsEfz9/QWFQiGEhoYKx48fN3WVagSA1j8rV64Uz7l//77wxhtvCI0aNRLs7OyEkSNHComJiWrXuXbtmjBkyBDB1tZWcHNzE9555x2hqKjIyK2pmcrhpD60799//xU6dOggWFtbC23bthV++OEHteNKpVKYNWuW4OnpKVhbWwsDBgwQLl26pHbO3bt3hdGjRwsODg6Ck5OTMH78eCE7O9uYzdAqKytLeOuttwR/f3/BxsZGaN68uTBz5ky1LzJLa9/+/fu1/nc3btw4QRD0157o6GjhoYceEqytrQVfX19h4cKFJm9fQkKCzt87+/fvt/j2aaMtnJhz+wShZm386aefhJYtWwo2NjZCcHCwsHHjRrVrmPJ3q0QQVJZpJCIiIjKxBj/mhIiIiMwLwwkRERGZFYYTIiIiMisMJ0RERGRWGE6IiIjIrDCcEBERkVlhOCEiIiKzwnBCREREZoXhhIiIiMwKwwkRERGZFYYTIiIiMi
sMJ0RERGRW/h9cfOiJ0cWvQQAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.plot(train_metrics_history[\"train_loss\"], label=\"Loss value during the training\")\n", "plt.yscale('log')\n", "plt.legend()" ] }, { "cell_type": "markdown", "id": "66250bf2-3d88-40ad-87e3-7d2b906fd860", "metadata": {}, "source": [ "And eval set Loss and Accuracy - Accuracy does continue to rise, though it's hard-earned progress after about the 5th epoch. Based on the training statistics, it's fair to say the process starts overfitting after roughly that 5th epoch." ] }, { "cell_type": "code", "execution_count": 22, "id": "64d54051-358b-4de8-b5b3-04bebf18018f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAzoAAANECAYAAAB4mVoFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACfuklEQVR4nOzdeXhU5fnG8XtmsoesLAmbBlDZZBMEEUQQFFGpVETciiLgUtEK2lasiksraitaW5UuIPorlsVSd1FEEBeECoKyyqYgkEBIMoHsmTm/P5IzYcy+TM4s3891zdXkzJmZd1LMyT3v8z6vzTAMQwAAAAAQROxWDwAAAAAAmhpBBwAAAEDQIegAAAAACDoEHQAAAABBh6ADAAAAIOgQdAAAAAAEHYIOAAAAgKBD0AEAAAAQdAg6AAAAAIIOQQch55FHHpHNZrN6GEFp+PDhGj58uNXDAACEmO+//142m00LFy60eijwIwSdELZw4ULZbDZ99dVXVg8FaLAnnnhCb7zxhtXDANBMXnzxRdlsNg0aNMjqoSDEbd++XY888oi+//57q4eCahB0AAQ0gg4QWhYtWqS0tDRt2LBBe/bssXo4CGHbt2/Xo48+StDxYwQdAAAQEPbv368vvvhCc+fOVevWrbVo0SKrh1StvLw8q4cAhDyCDmr19ddfa8yYMYqPj1eLFi00cuRIffnll17nlJSU6NFHH9WZZ56pqKgotWzZUkOHDtXKlSs956Snp2vy5Mnq0KGDIiMj1bZtW1155ZU1fhLypz/9STabTT/88EOl+2bNmqWIiAhlZ2dLkj799FNNmDBBp512miIjI9WxY0fNmDFDBQUFNb6/mup6bTabHnnkEa9jhw4d0i233KKUlBRFRkaqZ8+eWrBgQY2vYSotLdXjjz+uLl26KDIyUmlpaXrggQdUVFTkdV5aWpquuOIKffbZZxo4cKCioqLUuXNnvfrqq3V6Hbfbreeee049e/ZUVFSUUlJSdNttt3l+VpJ0xRVXqHPnzlU+fvDgwRowYIDn+5dfflkXXXSR2rRpo8jISPXo0UMvvfRSncZSlZUrV2ro0KFKTExUixYt1LVrVz3wwANe5xQVFWn27Nk644wzPP9//uY3v/H6WdlsNuXl5emVV16Rz
WaTzWbTzTff3OBxAfBvixYtUlJSki6//HJdffXV1QadnJwczZgxQ2lpaYqMjFSHDh00adIkZWZmes4pLCzUI488orPOOktRUVFq27atrrrqKu3du1eStGbNGtlsNq1Zs8bruau6Ztx8881q0aKF9u7dq8suu0xxcXG64YYbJNXv2rRz505dc801at26taKjo9W1a1f97ne/kyStXr1aNptN//3vfys97rXXXpPNZtO6detq/Pnt27dPEyZMUHJysmJiYnTeeefp3Xff9TrHfN9Lly7VH/7wB3Xo0EFRUVEaOXJknWfQartOZmRkKCwsTI8++milx+7atUs2m01//etfJUlZWVm677771KtXL7Vo0ULx8fEaM2aMtmzZUqex/FRd/l6Ryv6/uPrqq5WcnKyoqCgNGDBAb731luf+hQsXasKECZKkESNGeK5BP/33AmuFWT0A+Ldt27bpggsuUHx8vH7zm98oPDxcf/vb3zR8+HB98sknnhrpRx55RHPmzNHUqVM1cOBA5ebm6quvvtKmTZt08cUXS5LGjx+vbdu26a677lJaWpqOHj2qlStX6sCBA0pLS6vy9a+55hr95je/0dKlS/XrX//a676lS5fqkksuUVJSkiRp2bJlys/P1x133KGWLVtqw4YN+stf/qIff/xRy5Yta5KfR0ZGhs477zzZbDZNnz5drVu31vvvv68pU6YoNzdX99xzT42Pnzp1ql555RVdffXVuvfee7V+/XrNmTNHO3bsqHTx2rNnj66++mpNmTJFN910kxYsWKCbb75Z/fv3V8+ePWt8ndtuu00LFy7U5MmTdffdd2v//v3661//qq+//lqff/65wsPDNXHiRE2aNEn/+9//dO6553oe+8MPP+jLL7/UH//4R8+xl156ST179tTPfvYzhYWF6e2339Yvf/lLud1u3XnnnfX6GW7btk1XXHGFevfurccee0yRkZHas2ePPv/8c885brdbP/vZz/TZZ5/p1ltvVffu3fXtt9/q2Wef1XfffecpVfu///s/z7+5W2+9VZLUpUuXeo0HQOBYtGiRrrrqKkVEROi6667TSy+9VOl32MmTJ3XBBRdox44duuWWW3TOOecoMzNTb731ln788Ue1atVKLpdLV1xxhVatWqVrr71Wv/rVr3TixAmtXLlSW7dubdDvkdLSUo0ePVpDhw7Vn/70J8XExEiq+7Xpm2++0QUXXKDw8HDdeuutSktL0969e/X222/rD3/4g4YPH66OHTtq0aJF+vnPf17p59KlSxcNHjy42vFlZGTo/PPPV35+vu6++261bNlSr7zyin72s5/p9ddfr/ScTz75pOx2u+677z45nU49/fTTuuGGG7R+/foafw51uU6mpKTowgsv1NKlSzV79myvxy9ZskQOh8MTIvbt26c33nhDEyZMUKdOnZSRkaG//e1vuvDCC7V9+3a1a9eu9v9zTlGXv1e2bdumIUOGqH379rr//vsVGxurpUuXaty4cfrPf/6jn//85xo2bJjuvvtuPf/883rggQfUvXt3SfL8L/yEgZD18ssvG5KM//3vf9WeM27cOCMiIsLYu3ev59jhw4eNuLg4Y9iwYZ5jffr0MS6//PJqnyc7O9uQZPzxj3+s9zgHDx5s9O/f3+vYhg0bDEnGq6++6jmWn59f6bFz5swxbDab8cMPP3iOzZ492zj1n/7+/fsNScbLL79c6fGSjNmzZ3u+nzJlitG2bVsjMzPT67xrr73WSEhIqHIMps2bNxuSjKlTp3odv++++wxJxscff+w5dvrppxuSjLVr13qOHT161IiMjDTuvffeal/DMAzj008/NSQZixYt8jq+YsUKr+NOp7PK53v66acr/cyqel+jR482Onfu7HXswgsvNC688MIax/fss88akoxjx45Ve87//d//GXa73fj000+9js+bN8+QZHz++eeeY7GxscZNN91U42sCCHxfffWVIclYuXKlYRiG4Xa7jQ4dOhi/+
tWvvM57+OGHDUnG8uXLKz2H2+02DMMwFixYYEgy5s6dW+05q1evNiQZq1ev9rq/qmvGTTfdZEgy7r///krPV9dr07Bhw4y4uDivY6eOxzAMY9asWUZkZKSRk5PjOXb06FEjLCzM61pVlXvuuceQ5PV79cSJE0anTp2MtLQ0w+Vyeb3v7t27G0VFRZ5z//znPxuSjG+//bbG16nrdfJvf/tblc/Xo0cP46KLLvJ8X1hY6Bmbaf/+/UZkZKTx2GOPeR2r7lp+qtr+XjEMwxg5cqTRq1cvo7Cw0HPM7XYb559/vnHmmWd6ji1btqzKfyPwH5SuoVoul0sffvihxo0b51Xi1LZtW11//fX67LPPlJubK0lKTEzUtm3btHv37iqfKzo6WhEREVqzZo1X+VRdTJw4URs3bvSUE0hln/hERkbqyiuv9HoNU15enjIzM3X++efLMAx9/fXX9XrNqhiGof/85z8aO3asDMNQZmam5zZ69Gg5nU5t2rSp2se/9957kqSZM2d6Hb/33nslqVL5QI8ePXTBBRd4vm/durW6du2qffv21TjOZcuWKSEhQRdffLHXGPv3768WLVpo9erVkuSZ/l+6dKkMw/A8fsmSJTrvvPN02mmneY6d+rN1Op3KzMzUhRdeqH379snpdNY4np9KTEyUJL355ptyu93Vvofu3burW7duXu/hoosukiTPewAQOhYtWqSUlBSNGDFCUlnp6sSJE7V48WK5XC7Pef/5z3/Up0+fSjMU5mPMc1q1aqW77rqr2nMa4o477qh0rC7XpmPHjmnt2rW65ZZbvH73/nQ8kyZNUlFRkV5//XXPsSVLlqi0tFQ33nhjjWN77733NHDgQA0dOtRzrEWLFrr11lv1/fffa/v27V7nT548WREREZ7vzetRTdeg+lwnr7rqKoWFhWnJkiWex2/dulXbt2/XxIkTPcciIyNlt5f9uepyuXT8+HFPyXNN19zq1Pb3SlZWlj7++GNdc801OnHihGf8x48f1+jRo7V7924dOnSo3q8LaxB0UK1jx44pPz9fXbt2rXRf9+7d5Xa7dfDgQUnSY489ppycHJ111lnq1auXfv3rX+ubb77xnB8ZGamnnnpK77//vlJSUjRs2DA9/fTTSk9Pr3UcEyZMkN1u9/wyNAxDy5Yt86wbMh04cEA333yzkpOT1aJFC7Vu3VoXXnihJNX7j/GqHDt2TDk5Ofr73/+u1q1be90mT54sSTp69Gi1j//hhx9kt9t1xhlneB1PTU1VYmJipXVIP73YSVJSUlKtQXH37t1yOp1q06ZNpXGePHnSa4wTJ07UwYMHPXXde/fu1caNG70uMpL0+eefa9SoUYqNjVViYqJat27tWVNT35/txIkTNWTIEE2dOlUpKSm69tprtXTpUq/Qs3v3bm3btq3S+M866yxJNf+cAQQfl8ulxYsXa8SIEdq/f7/27NmjPXv2aNCgQcrIyNCqVas85+7du1dnn312jc+3d+9ede3aVWFhTVfBHxYWpg4dOlQ6Xpdrkxkeaht3t27ddO6553qtTVq0aJHOO++8SteWn/rhhx+qvZ6b95/qp9cgs0y8pmtQfa6TrVq10siRI7V06VLP45csWaKwsDBdddVVnmNut1vPPvuszjzzTEVGRqpVq1Zq3bq1vvnmmwZd22v7e2XPnj0yDEMPPfRQpfdgltlxDQocrNFBkxg2bJj27t2rN998Ux9++KH++c9/6tlnn9W8efM0depUSdI999yjsWPH6o033tAHH3yghx56SHPmzNHHH3+sfv36Vfvc7dq10wUXXKClS5fqgQce0JdffqkDBw7oqaee8pzjcrl08cUXKysrS7/97W/VrVs3xcbG6tChQ7r55purnTmQqv/07tRPCCV5nuPGG2/UTTfdVOVjevfuXe3r1PZ6P+VwOKo8fursS1XcbrfatGlT7SLd1q1be74eO3asYmJitHTpUp1//vlaunSp7Ha7pzZaKvuDYOTIkerWr
Zvmzp2rjh07KiIiQu+9956effbZGn+2VYmOjtbatWu1evVqvfvuu1qxYoWWLFmiiy66SB9++KEcDofcbrd69eqluXPnVvkcHTt2rNdrAghsH3/8sY4cOaLFixdr8eLFle5ftGiRLrnkkiZ9zbpeG0ynzjycem5Dr03VmTRpkn71q1/pxx9/VFFRkb788kvPwv2m1JBrUH2vk9dee60mT56szZs3q2/fvlq6dKlGjhypVq1aec554okn9NBDD+mWW27R448/ruTkZNntdt1zzz0N+vnV9veK+Zz33XefRo8eXeVz1BYq4T8IOqhW69atFRMTo127dlW6b+fOnbLb7V5/cCYnJ2vy5MmaPHmyTp48qWHDhumRRx7xBB2pbKH4vffeq3vvvVe7d+9W37599cwzz+hf//pXjWOZOHGifvnLX2rXrl1asmSJYmJiNHbsWM/93377rb777ju98sormjRpkuf4T7uoVMX8lConJ8fr+E8/3WrdurXi4uLkcrk0atSoWp/3p04//XS53W7t3r3ba7FiRkaGcnJydPrpp9f7OavSpUsXffTRRxoyZIhXyURVYmNjdcUVV2jZsmWaO3eulixZogsuuMBrcefbb7+toqIivfXWW16f8DWmfMxut2vkyJEaOXKk5s6dqyeeeEK/+93vtHr1ao0aNUpdunTRli1bNHLkyFqDYWPKTAAEhkWLFqlNmzZ64YUXKt23fPly/fe//9W8efMUHR2tLl26aOvWrTU+X5cuXbR+/XqVlJQoPDy8ynPqem2oSV2vTWZ5eG3jlsrCwcyZM/Xvf/9bBQUFnuYytTn99NOrvZ6b9zdWfa+T48aN02233eap2Pjuu+80a9Ysr3Nef/11jRgxQvPnz/c6npOT4xWI6qOmv1fM/y/Cw8NrfQ9cf/wfpWuolsPh0CWXXKI333zTqwV0RkaGXnvtNQ0dOtRTOnb8+HGvx7Zo0UJnnHGGpxVwfn6+CgsLvc7p0qWL4uLiKrVWrsr48ePlcDj073//W8uWLdMVV1yh2NhYr7FK3p80GYahP//5z7U+d3x8vFq1aqW1a9d6HX/xxRe9vnc4HBo/frz+85//VHkxOnbsWI2vc9lll0mSnnvuOa/j5qzF5ZdfXutY6+Kaa66Ry+XS448/Xum+0tLSShftiRMn6vDhw/rnP/+pLVu2VLpgVvWzdTqdevnllxs0vqysrErH+vbtK0mefwvXXHONDh06pH/84x+Vzi0oKPDanyI2NrbSewIQPAoKCrR8+XJdccUVuvrqqyvdpk+frhMnTnha/44fP15btmypsg2z+Xts/PjxyszMrHImxDzn9NNPl8PhqPXaUJO6Xptat26tYcOGacGCBTpw4ECV4zG1atVKY8aM0b/+9S8tWrRIl156aZ3+4L/sssu0YcMGrxbUeXl5+vvf/660tDT16NGjzu+rOvW9TiYmJmr06NFaunSpFi9erIiICI0bN67Sc/70Z7Bs2bIGr5Op7e+VNm3aaPjw4frb3/6mI0eO1PgezL9DuAb5L2Z0oAULFmjFihWVjv/qV7/S73//e8+eJ7/85S8VFhamv/3tbyoqKtLTTz/tObdHjx4aPny4+vfvr+TkZH311Vd6/fXXNX36dElln9KMHDlS11xzjXr06KGwsDD997//VUZGhq699tpax9imTRuNGDFCc+fO1YkTJyr9Md6tWzd16dJF9913nw4dOqT4+Hj95z//qXPjg6lTp+rJJ5/U1KlTNWDAAK1du1bfffddpfOefPJJrV69WoMGDdK0adPUo0cPZWVladOmTfroo4+q/CPe1KdPH9100036+9//rpycHF144YXasGGDXnnlFY0bN86zwLaxLrzwQt12222aM2eONm/erEsuuUTh4eHavXu3li1bpj//+c+6+uqrPeebez7cd999novUqS655BJFRERo7Nixuu2223Ty5En94x//UJs2baq8CNTmscce09q1a3X55Zfr9NNP19GjR/Xii
y+qQ4cOnkWyv/jFL7R06VLdfvvtWr16tYYMGSKXy6WdO3dq6dKl+uCDDzz7/PTv318fffSR5s6dq3bt2qlTp06etucAAt9bb72lEydO6Gc/+1mV95933nmezUMnTpyoX//613r99dc1YcIE3XLLLerfv7+ysrL01ltvad68eerTp48mTZqkV199VTNnztSGDRt0wQUXKC8vTx999JF++ctf6sorr1RCQoImTJigv/zlL7LZbOrSpYveeeedeq3PqM+16fnnn9fQoUN1zjnn6NZbb1WnTp30/fff691339XmzZu9zp00aZLn93hVH2pV5f7779e///1vjRkzRnfffbeSk5P1yiuvaP/+/frPf/5Tqeyuoep7nZw4caJuvPFGvfjiixo9erSnYY3piiuu0GOPPabJkyfr/PPP17fffqtFixZVuw9cbWr7e0WSXnjhBQ0dOlS9evXStGnT1LlzZ2VkZGjdunX68ccfPXv49O3bVw6HQ0899ZScTqciIyM9e87BTzRvkzf4E7O9dHW3gwcPGoZhGJs2bTJGjx5ttGjRwoiJiTFGjBhhfPHFF17P9fvf/94YOHCgkZiYaERHRxvdunUz/vCHPxjFxcWGYRhGZmamceeddxrdunUzYmNjjYSEBGPQoEHG0qVL6zzef/zjH4YkIy4uzigoKKh0//bt241Ro0YZLVq0MFq1amVMmzbN2LJlS6V2kz9tL20YZe0/p0yZYiQkJBhxcXHGNddcYxw9erRSe2nDMIyMjAzjzjvvNDp27GiEh4cbqampxsiRI42///3vtb6HkpIS49FHHzU6depkhIeHGx07djRmzZrl1cLSMMraS1fV/rIu7ZtNf//7343+/fsb0dHRRlxcnNGrVy/jN7/5jXH48OFK595www2GJGPUqFFVPtdbb71l9O7d24iKijLS0tKMp556ytOedf/+/fUa36pVq4wrr7zSaNeunREREWG0a9fOuO6664zvvvvO67zi4mLjqaeeMnr27GlERkYaSUlJRv/+/Y1HH33UcDqdnvN27txpDBs2zIiOjjYk0WoaCDJjx441oqKijLy8vGrPufnmm43w8HBPS+Pjx48b06dPN9q3b29EREQYHTp0MG666Savlsf5+fnG7373O8/v49TUVOPqq6/22k7h2LFjxvjx442YmBgjKSnJuO2224ytW7dW2V46Nja2yrHV9dpkGIaxdetW4+c//7mRmJhoREVFGV27djUeeuihSs9ZVFRkJCUlGQkJCVVeD6uzd+9e4+qrr/Y8/8CBA4133nnH6xyzvfSyZcu8jte1fbNh1O86mZub6/n9/a9//avS/YWFhca9995rtG3b1oiOjjaGDBlirFu3rtL1pq7jq+3vFdPevXuNSZMmGampqUZ4eLjRvn1744orrjBef/11r/P+8Y9/GJ07dzYcDgetpv2QzTBqWdkMAAAAv1FaWqp27dpp7NixldauAKjAGh0AAIAA8sYbb+jYsWNeDQ4AVMaMDgAAQABYv369vvnmGz3++ONq1apVgzbMBEIJMzoAAAAB4KWXXtIdd9yhNm3a6NVXX7V6OIDfY0YHAAAAQNBhRgcAAABA0CHoAAAAAAg6AbFhqNvt1uHDhxUXFyebzWb1cAAgZBiGoRMnTqhdu3ZNtqFgsODaBADWqOu1KSCCzuHDh9WxY0erhwEAIevgwYPq0KGD1cPwK1ybAMBatV2bAiLoxMXFSSp7M/Hx8RaPBgBCR25urjp27Oj5PYwKXJsAwBp1vTYFRNAxSwLi4+O5mACABSjNqoxrEwBYq7ZrEwXXAAAAAIIOQQcAAABA0CHoAAAAAAg6BB0AAAAAQYegAwAAACDo1CvozJkzR+eee67i4uLUpk0bjRs3Trt27ar1ccuWLVO3bt0UFRWlXr166b333mvwgAEAAACgNvUKOp988onuvPNOffnll1q5cqVKSkp0ySWXKC8vr9rHfPHFF7ruuus0ZcoUff311xo3bpzGjRunr
Vu3NnrwAAAAAFAVm2EYRkMffOzYMbVp00affPKJhg0bVuU5EydOVF5ent555x3PsfPOO099+/bVvHnz6vQ6ubm5SkhIkNPpZK8CAGhG/P6tHj8bALBGXX//NmqNjtPplCQlJydXe866des0atQor2OjR4/WunXrGvPSAIAg8cILLygtLU1RUVEaNGiQNmzYUOP5zz33nLp27aro6Gh17NhRM2bMUGFhoef+Rx55RDabzevWrVs3r+coLCzUnXfeqZYtW6pFixYaP368MjIyfPL+AADWaHDQcbvduueeezRkyBCdffbZ1Z6Xnp6ulJQUr2MpKSlKT0+v9jFFRUXKzc31ugEAgs+SJUs0c+ZMzZ49W5s2bVKfPn00evRoHT16tMrzX3vtNd1///2aPXu2duzYofnz52vJkiV64IEHvM7r2bOnjhw54rl99tlnXvfPmDFDb7/9tpYtW6ZPPvlEhw8f1lVXXeWz9wkAaH5hDX3gnXfeqa1bt1a6eDSFOXPm6NFHH23y5wUA+Je5c+dq2rRpmjx5siRp3rx5evfdd7VgwQLdf//9lc7/4osvNGTIEF1//fWSpLS0NF133XVav36913lhYWFKTU2t8jWdTqfmz5+v1157TRdddJEk6eWXX1b37t315Zdf6rzzzmvKtwgAsEiDZnSmT5+ud955R6tXr1aHDh1qPDc1NbVSOUBGRka1FyBJmjVrlpxOp+d28ODBhgwTAODHiouLtXHjRq/yZrvdrlGjRlVb3nz++edr48aNnvK2ffv26b333tNll13mdd7u3bvVrl07de7cWTfccIMOHDjguW/jxo0qKSnxet1u3brptNNOq7GsmmoDAAgs9Qo6hmFo+vTp+u9//6uPP/5YnTp1qvUxgwcP1qpVq7yOrVy5UoMHD672MZGRkYqPj/e6AQCCS2ZmplwuV73Km6+//no99thjGjp0qMLDw9WlSxcNHz7cq3Rt0KBBWrhwoVasWKGXXnpJ+/fv1wUXXKATJ05IKiupjoiIUGJiYp1fVyqrNkhISPDcOnbs2MB3DgBoDvUKOnfeeaf+9a9/6bXXXlNcXJzS09OVnp6ugoICzzmTJk3SrFmzPN//6le/0ooVK/TMM89o586deuSRR/TVV19p+vTpTfcuAAAhYc2aNXriiSf04osvatOmTVq+fLneffddPf74455zxowZowkTJqh3794aPXq03nvvPeXk5Gjp0qWNem2qDQAgsNRrjc5LL70kSRo+fLjX8Zdfflk333yzJOnAgQOy2yvy0/nnn6/XXntNDz74oB544AGdeeaZeuONN2psYAAACH6tWrWSw+GoV3nzQw89pF/84heaOnWqJKlXr17Ky8vTrbfeqt/97nde1x9TYmKizjrrLO3Zs0dSWUl1cXGxcnJyvGZ1aiurjoyMVGRkZH3fJgDAIvUuXavqZoYcqezTtoULF3o9bsKECdq1a5eKioq0devWSrXUAIDQExERof79+3uVN7vdbq1atara8ub8/PxKYcbhcEgqu0ZV5eTJk9q7d6/atm0rSerfv7/Cw8O9XnfXrl06cOBAjWXVAIDA0uCuawAANNbMmTN10003acCAARo4cKCee+455eXlebqwTZo0Se3bt9ecOXMkSWPHjtXcuXPVr18/DRo0SHv27NFDDz2ksWPHegLPfffdp7Fjx+r000/X4cOHNXv2bDkcDl133XWSpISEBE2ZMkUzZ85UcnKy4uPjddddd2nw4MF0XAOAIELQAQBYZuLEiTp27Jgefvhhpaenq2/fvlqxYoWnQcFPy6EffPBB2Ww2Pfjggzp06JBat26tsWPH6g9/+IPnnB9//FHXXXedjh8/rtatW2vo0KH68ssv1bp1a885zz77rOx2u8aPH6+ioiKNHj1aL774YvO9cQCAz9mM6ub6/Uhubq4SEhLkdDrpwAYAzYjfv9XjZwMA1qjr798G7aMDAAAAAP6MoAMAAAAg6BB0AAAAAAQdgg4AAACAo
EPQAQAAABB0CDoAAAAAgg5BBwAAAEDQIegAAAAACDoEHQAAAABBh6ADAAAAIOgQdAAAAAAEHYIOAAAAgKATZvUAfG13xgk9/OY2xUaG6Z83DbB6OAAAAEDIyTxZpG9+zNHmAzna/KNTLSIdevGG/j59zaAPOm5DWrfvuJJjI6weCgAAABD0Copd2nrYqS0Hc/T1wRxtOZijH7MLvM6JiwyT223Ibrf5bBxBH3QSY8IlSTn5xT7/YQIAAAChxOU2tPvoCW05mKPNB53afDBH32WckMttVDr3jDYt1KdDovp2TFCfjok+H1vQB52E6LKg4zakE0Wlnu8BAAAA1J1hGDriLNTm8lmazQdz9O0hp/KLXZXObRMXqb4dE9WnY6L6dkxUrw4Jio9q3r/Dgz7oRIU7FB3uUEGJS878EoIOAAAAUAfOghJ9+6NTW37M0dcHcrTlxxwdO1FU6bzYCId6dUhQ345Jntma1Pgo2WzWVlIFfdCRysrXCpwu5RQU6zTFWD0cAAAAwK8Ul7q140iutvxYNlOz+WCO9h3Lq3Sew25Tt9Q4z0xN346J6tK6hRx+uDwkJIJOQnS4jjgLlZNfYvVQAAAAAEsZhqHvj+dr88FsbSlfV7P9cK6KXe5K556WHKM+HRPVp0OC+p2WqB5tExQd4bBg1PUXEkEnKaas41p2frHFIwEAAACaR4nLrZz8EmXnF+tgVr5npuabH51yFlSeAEiKCS8PNWUzNb07JKhli0gLRt40QiLomJ3Xqvo/FAAAAPB3pS63svNLlJNfrKy8YmXnFysrryzEZOcVKyu/WDn5JafcV6wThaXVPl9EmF1nt4tX345J6tMxQX07Juq05BjL19U0pZAKOpSuAQAAwGqlLrdyCkrKAkpesbLLZ12y8spCy6nfm8Emt4bQUhObTUqMDlebuKjyhgFlszVdU+MU7rA38TvzLyERdBKiy0rXCDoAAAD4KZfbUInLXX4zVOpyq7iKr089p6TUrVK3W8XVfV3qlrOgpGLWxZx5aWRoSYgOV3JMhJJiI5QUE6GkmHAlx5Z9nxwTocRTvk+KiVBCdLhfNgpoDiERdJJO2TQUAAAAwcMwDH3zo1PvfHNY+47llYcSt0rLg0lxeVgpqSqslH9dxd6WzSIxJtw7rJwSYJJjw3/yfWiHloYIiaDjKV1jjQ4AAEBQ2Jmeq7e3HNbbW47oQFZ+kz63w25TmN2mCIdd4WF2hdltCnfYFXHK1+FhdoXX8nXCKUEmKTaiPMyEe2ZawoK8dMxqIRF0KkrXmNEBAAAIVPsz8/TOlsN6+5vD+i7jpOd4dLhDo3qkaHDnlooKt5cFDodd4Q5bzV//JKCYIYZZk+AQEkGHGR0AAIDAdCinQO9+UzZz8+0hp+d4hMOu4V1ba2yfdhrZvY1iIkLiz1rUQ0j8izD30aEZAQAAgP87eqJQ73+brre3HNZXP2R7jjvsNg05o5XG9m6rS3qmKiE63MJRwt+FRNBJPKUZgdttyM50JAAAgF/JyS/Wiq3pevubw1q397inQYDNJg1MS9bYPu005uzUgN7AEs0rJIKOmfbdhnSyuFTxUaR/AAAAq50sKtXK7el6e8sRrf3umEpPaX/Wt2OixvZpp8t7tVVqQpSFo0SgComgExXuUHS4QwUlLjnzSwg6AAAAFikscWn1zqN6a8thfbzzqIpK3Z77ureN19g+bXVFr3Y6rWWMhaNEMAiJoCOVla8VOF3Kzi9Wx2T+wwEAAGguxaVufbbnmN7eckQfbktXXrHLc1/nVrG6ok87/axPW53RJs7CUSLYhEzQSYgO1xFnIQ0JAAAAmoHLbejLfcf19pbDen9rupyndL9tnxitK/q01dje7dSzXbxsNtZPo+mFTNChxTQAAIBvud2GNh3I1ttbDuvdb9OVebLIc1/ruEhd3qutxvZpp3NOSyTcwOdCJuhUtJhm01AAAICmYhiGth3O1VtbDuudLYd12FnouS8xJlxjzm6rsX3aa
lCnlmzEiWYVMkGnosU0MzoAAACNYRiGdhw5oRVbj+jtb45of2ae574WkWG6pEeKxvZpp6FntlK4w27hSBHKQiboJESzaSgAAEBDudyGvvo+Sx9sy9CH29P1Y3aB577IMLtGdU/R2D5tNbxrG0WFOywcKVAmZIJOxRodStcAAADqorDEpc92Z+rD7en6aMdRZeVV/B0VGWbXsLNa64rebTWye4paRIbMn5UIECHzLzKJ0jUAAIBaOQtKtHrnUX2wLV2ffHdM+ae0gk6IDtfI7m10SY9UDTurlWIiQuZPSQSgkPnXWVG6xowOAADAqdKdhVq5PV0fbs/Qur3HVeo2PPe1TYjSJT1SNLpnqs7tlMyaGwSMkAk6tJcGAACosOfoSX24PV0fbMvQloM5Xved2aaFRvdM1SU9U9SrfQKtoBGQQi7oOCldAwAAIcjtNvTNIac+2JauD7ela++xPK/7zzktUaN7puriHinq3LqFRaMEmk7IBB3PPjoFJTIMg08mAABA0CtxufXlvuP6cFuGVm7PUHpuxR434Q6bzu/SSpf0TNHF3VPUJj7KwpECTS9kgk5CdNmMjstt6ERRqeKjwi0eEQAAQNPLKyrV2u+O6YNt6fp451HlFpZ67ouNcGh4tzYa3TNVw7u25u8hBLWQCTpR4Q5FhdtVWOKWM7+E/7ABAEDQOH6ySKt2HNWH29O1dnemikvdnvtatYjQxT1SdEmPVJ1/RktFhrHHDUJDyAQdSUqMjlB6SaFy8kvUMdnq0QAAADTcwax8fbg9Qx9sS9dX32fplEZpOr1lTFkzgR4p6ndakhx2SvYRekIr6MSEKz23UNm0mAYAAAFo37GTemvLYX24LUPbj+R63Xd2+3hd0iNVo3um6qyUFqxHRsgLuaAj0WIaAAAEnne/OaJ7lnytElfZ1I3dJg3slOzplNYhKcbiEQL+JbSCTvmmoU5mdAAAQABZ+r+Dun/5N3Ib0nmdkzX+nA4a2T1FybERVg8N8FshFXSSYstndNhLBwAABIh/frpPv393hyTpuoGn6ffjzmbNDVAHIRV0EspndLIJOgAAwM8ZhqFnP9qt51ftliTdNqyz7h/TjbU3QB2FVNCpWKND6RoAAPBfbrehx97ZroVffC9J+vXorvrl8C6EHKAeQivolG8a6mRGBwAA+KlSl1v3L/9Wr2/8UZL0+JU99YvBadYOCghAoRV0YszSNWZ0AACA/ykqdelX/96sFdvS5bDb9KcJvfXzfh2sHhYQkEIs6NBeGgAA+Kf84lLd9n8b9enuTEU47Prr9f10Sc9Uq4cFBKyQDDqUrgEAAH/iLCjRLQv/p40/ZCsmwqF/TBqgIWe0snpYQEALraBT3nUtp6BEhmGwoA8AAFju2IkiTVqwQTuO5Co+KkwLbxmoc05LsnpYQMALraBTPqPjchs6UVSq+Khwi0cEAABC2aGcAt34z/Xan5mnVi0i9X9TBqp723irhwUEBbvVA2hOUeEORYWXvWXK1wAAgJX2HTupCS99of2ZeWqfGK3Xbx9MyAGaUEgFHemU8jWCDgAAsMi2w05d87d1OuwsVJfWsXr9jsFKaxVr9bCAoBJ6QYdNQwEAgIU2/pCla//+pTJPFqtnu3gtvW2w2iZEWz0sIOiE1BodqSLoZDOjAwAAmtmnu4/p1lc3qqDEpXPTkjT/5nNZMwz4SOgFnfLSNSebhgIAgGa0YusR3f3vzSp2uXXhWa0178b+io5wWD0sIGiFXtAxS9eY0QEAAM1k2VcH9dv/fCO3IV3eq62endhXEWEht4IAaFYhF3QSPGt0CDoAAMD3Fny2X4+9s12SNHFARz1xVS857OzlB/hayAWdpJiy0rVsStcAAIAPGYah51ft0bMffSdJmjq0k353eXc2LAeaScgFncToshkd9tEBAAC+YhiGfv/uDs3/bL8kaebFZ+mui84g5ADNKPSCDqVrAADAh1xuQw8s/1ZLvjooSZo9tocmD+lk8aiA0BOCQ
cfcMJTSNQAA0LSKS92asWSz3v32iOw26emr++jq/h2sHhYQkkIw6NB1DQAANL2CYpdu/9dGffLdMUU47Hr+ur669Oy2Vg8LCFmhF3TK99HJKSiRYRjUygIAgEbLLSzRlIX/0/++z1Z0uEN/n9RfF5zZ2uphASEt9IJO+YyOy23oZFGp4tiNGAAANELmySLdtGCDth3OVXxUmF6efK76n55s9bCAkBdyO1VFhTsUFV72tilfAwAAjXE4p0DX/G2dth3OVasWEVp862BCDuAnQi7oSKeUrxF0AABAA+3PzNOEeeu071ie2iVEaeltg9WjXbzVwwJQLjSDjqfFNJ3XAABA/e04kqsJ89bpUE6BOreK1bI7zlfn1i2sHhaAU4TcGh1JSoim8xoAAGiYTQeydfOCDcotLFX3tvF69ZaBah0XafWwAPxESAadJPbSAQAADfDZ7kzd+n9fKb/Ypf6nJ2nBzed6PkAF4F9CMuiwlw4AAKivD7al667Xvlaxy60Lzmylv/2iv2IiQvJPKSAghOR/nQmeNToEHQAAULvlm37Ur1//Ri63oUt7purP1/VVZJjD6mEBqEFIBh26rgEAgLp65YvvNfutbZKkq/t30JNX9VKYIyT7OQEBJSSDTpKndI01OgAAoHovf75fj769XZJ08/lpeviKHrLbbRaPCkBdhGTQSaR0DQAA1GLrIaf+8O4OSdLdF52hGRefJZuNkAMEipCcd02IpusaAACoXmGJS/cs2azS8jU5hBwg8IRk0DFndJzM6AAAgCo8+f5O7Tl6Uq3jIvXEVb0IOUAACsmgU7GPTokMw7B4NAAAwJ98uvuYFn7xvSTpj1f3VnJshLUDAtAgIRl0zBmdUrehk0WlFo8GAAD4i5z8Yt23bIsk6Rfnna7hXdtYPCIADRWSQScq3KHIsLK3TotpAAAgSYZh6Hf/3aqM3CJ1bh2rBy7rbvWQADRCSAYdqaJ8jXU6AABAkt7YfEjvfntEYXabnpvYV9ERbAgKBLKQDTpm+Vo2ndcAAAh5P2bn6+E3yjYFvXvkmerdIdHaAQFotJANOgnR5qahzOgAABDKXG5D9y7dohNFpep3WqJ+ObyL1UMC0ARCNuiwaSgAAJCkf366T+v3ZykmwqHnJvZVmCNk/zwCgkrI/pfsaTGdR+kaAAChavvhXP3pw12SpIev6KHTW8ZaPCIATSVkg04CMzoAAIS0whKXZizZrBKXoVHdUzTx3I5WDwlAEwrZoJMYXbFpKADAOi+88ILS0tIUFRWlQYMGacOGDTWe/9xzz6lr166Kjo5Wx44dNWPGDBUWFnrunzNnjs4991zFxcWpTZs2GjdunHbt2uX1HMOHD5fNZvO63X777T55f/Bff/pgl3ZlnFCrFhF6cnwv2Ww2q4cEoAmFbtApn9FxFlC6BgBWWbJkiWbOnKnZs2dr06ZN6tOnj0aPHq2jR49Wef5rr72m+++/X7Nnz9aOHTs0f/58LVmyRA888IDnnE8++UR33nmnvvzyS61cuVIlJSW65JJLlJeX5/Vc06ZN05EjRzy3p59+2qfvFf7liz2Z+udn+yVJT43vrVYtIi0eEYCmFmb1AKyS5GkvzYwOAFhl7ty5mjZtmiZPnixJmjdvnt59910tWLBA999/f6Xzv/jiCw0ZMkTXX3+9JCktLU3XXXed1q9f7zlnxYoVXo9ZuHCh2rRpo40bN2rYsGGe4zExMUpNTfXF24KfcxaU6N5lWyRJ1w08TSO7p1g8IgC+ELIzOgme0jVmdADACsXFxdq4caNGjRrlOWa32zVq1CitW7euysecf/752rhxo6e8bd++fXrvvfd02WWXVfs6TqdTkpScnOx1fNGiRWrVqpXOPvtszZo1S/n5+Y19SwgQD7+5VUechUprGaMHL+9u9XAA+EjIzuhUlK4xowMAVsjMzJTL5VJKiven6SkpKdq5c2eVj7n++uuVmZmpoUOHyjAMlZaW6vbbb/cqXTuV2+3WPffco
yFDhujss8/2ep7TTz9d7dq10zfffKPf/va32rVrl5YvX17teIuKilRUVOT5Pjc3tz5vF37irS2H9ebmw3LYbZo7sa9iI0P2TyEg6IXsf92efXTyS2QYBgsQASAArFmzRk888YRefPFFDRo0SHv27NGvfvUrPf7443rooYcqnX/nnXdq69at+uyzz7yO33rrrZ6ve/XqpbZt22rkyJHau3evunSperPIOXPm6NFHH23aN4RmdcRZoAf/+60k6c4RZ+ic05IsHhEAXwrZ0jVzH51St6GTRaUWjwYAQk+rVq3kcDiUkZHhdTwjI6PatTMPPfSQfvGLX2jq1Knq1auXfv7zn+uJJ57QnDlz5Ha7vc6dPn263nnnHa1evVodOnSocSyDBg2SJO3Zs6fac2bNmiWn0+m5HTx4sC5vE37C7TZ037Ityi0sVZ8OCbrrojOsHhIAHwvZoBMV7lBkWNnbp8U0ADS/iIgI9e/fX6tWrfIcc7vdWrVqlQYPHlzlY/Lz82W3e1+6HA6HJMkwDM//Tp8+Xf/973/18ccfq1OnTrWOZfPmzZKktm3bVntOZGSk4uPjvW4IHC9/8b0+33NcUeF2zZ3YV+GOkP0TCAgZIVu6JpWVr2XkFslZUCK2CAOA5jdz5kzddNNNGjBggAYOHKjnnntOeXl5ni5skyZNUvv27TVnzhxJ0tixYzV37lz169fPU7r20EMPaezYsZ7Ac+edd+q1117Tm2++qbi4OKWnp0uSEhISFB0drb179+q1117TZZddppYtW+qbb77RjBkzNGzYMPXu3duaHwR8alf6CT21omzd1+8u76EurVtYPCIAzSG0g050hDJyi5jRAQCLTJw4UceOHdPDDz+s9PR09e3bVytWrPA0KDhw4IDXDM6DDz4om82mBx98UIcOHVLr1q01duxY/eEPf/Cc89JLL0kq2xT0VC+//LJuvvlmRURE6KOPPvKEqo4dO2r8+PF68MEHff+G0eyKSl26Z8lmFZe6NaJra9046DSrhwSgmdgMc67fj+Xm5iohIUFOp7NJSwUm/m2d1u/P0l+u66exfdo12fMCQLDw1e/fYMDPJjA8+f5Ozftkr5JiwvXBjGFqExdl9ZAANFJdf/+GdIGqp/MaLaYBAAg66/cd19/W7pUkzbmqNyEHCDGhHXTKNw11smkoAABBJbewRDOXbpFhSBP6d9ClZ1fdyQ9A8ArtoBNbsZcOAAAIHo++tV2HcgrUMTlas3/W0+rhALBAaAed8hmdbIIOAABB4/1vj+g/m36U3SY9e01ftYgM6d5LQMgK7aBTvkbHWUDpGgAAwSAjt1Cz/vutJOn2C7toQFqyxSMCYJXQDjrRlK4BABAsDMPQr1//Rjn5JTq7fbzuGXWW1UMCYKHQDjoxZukaMzoAAAS6//vyB6397pgiw+x69pq+iggL6T9zgJAX0r8BKkrXmNEBACCQ7Tl6Un94d4ck6f4x3XRmSpzFIwJgNYKOykrXAmDfVAAAUIXiUrdmLNmsolK3LjizlW4anGb1kAD4gdAOOuVd10rdhvKKXRaPBgAANMRfPt6tbw85lRAdrj9e3Ud2u83qIQHwAyEddKIjHIosr9/NzmOdDgAAgWbjD1l6YfUeSdITP++l1IQoi0cEwF+EdNCRWKcDAECgOllUqhlLtshtSD/v116X925r9ZAA+JF6B521a9dq7NixateunWw2m954441aH7No0SL16dNHMTExatu2rW655RYdP368IeNtcmb5Gi2mAQAILI+/vV0HsvLVPjFaj17Z0+rhAPAz9Q46eXl56tOnj1544YU6nf/5559r0qRJmjJlirZt26Zly5Zpw4YNmjZtWr0H6wsJZkMCNg0FACBgfLgtXUu+OiibTXrmmj6Kjwq3ekgA/ExYfR8wZswYjRkzps7nr1u3Tmlpabr77rslSZ06ddJtt92mp556qr4v7RNJ5UEnmxkdAAACwrETRZq1/FtJ0q0XdNZ5nVtaPCIA/
sjna3QGDx6sgwcP6r333pNhGMrIyNDrr7+uyy67zNcvXSdm6ZqTTUMBAPB7hmHot//5RsfzitUtNU4zLznL6iEB8FM+DzpDhgzRokWLNHHiREVERCg1NVUJCQk1lr4VFRUpNzfX6+Yrp+6lAwAA/NtrGw7o451HFeGw67lr+yoyzGH1kAD4KZ8Hne3bt+tXv/qVHn74YW3cuFErVqzQ999/r9tvv73ax8yZM0cJCQmeW8eOHX02voo1OgQdAAD82f7MPP3+nR2SpN9c2lXdUuMtHhEAf+bzoDNnzhwNGTJEv/71r9W7d2+NHj1aL774ohYsWKAjR45U+ZhZs2bJ6XR6bgcPHvTZ+JJizK5rlK4BAOCvSl1uzViyWQUlLg3u3FK3DOlk9ZAA+Ll6NyOor/z8fIWFeb+Mw1E2zWwYRpWPiYyMVGRkpK+HJklKjKZ0DQAAf/fX1Xu0+WCO4qLC9Mw1fWS326weEgA/V+8ZnZMnT2rz5s3avHmzJGn//v3avHmzDhw4IKlsNmbSpEme88eOHavly5frpZde0r59+/T555/r7rvv1sCBA9WuXbumeReNQOkaAAD+bfPBHP3l4z2SpN+PO1vtEqMtHhGAQFDvGZ2vvvpKI0aM8Hw/c+ZMSdJNN92khQsX6siRI57QI0k333yzTpw4ob/+9a+69957lZiYqIsuushv2kuzYSgAAP4rv7hUM5ZslsttaGyfdrqyb3urhwQgQNQ76AwfPrzakjNJWrhwYaVjd911l+666676vlSzSIo1S9eKZRiGbDamwgEA8Bd/eHeH9mfmKTU+Sr+/8myrhwMggPi8GYG/M2d0St2G8opdFo8GAACYPt6ZoUXry6pEnrmmj6fcHADqIuSDTlS4XRFhZT8GOq8BAOAfjp8s0m9e/1aSdMuQThpyRiuLRwQg0IR80LHZbEpi01AAAPyGYRiatfxbZZ4s0pltWug3l3a1ekgAAlDIBx2JhgQAAPiTZV/9qA+3ZyjcYdNz1/ZVVLjD6iEBCEAEHZ3aYprSNQAArJTuLNSjb2+TJM28uKt6tkuweEQAAhVBR2waCgCAv/jv14eUV+xSnw4JunVYZ6uHAyCAEXQkJcWYpWvM6AAAYKUV29IlSRMGdJTDzpYPABqOoCMpkWYEAABY7nBOgbYczJHNJl3SM8Xq4QAIcAQdnbpGh6ADAIBVPiifzTn39GS1iYuyeDQAAh1BR3RdAwDAH7y/tSzojD471eKRAAgGBB3plH10WKMDAIAVjp0o0v++z5IkXUrQAdAECDqidA0AAKut3J4hw5B6d0hQ+8Roq4cDIAgQdETpGgAAVjO7rTGbA6CpEHRU0XXNWVAswzAsHg0AAKHFmV+iL/ZkSpIu7UnQAdA0CDqq2EenxGUor9hl8WgAAAgtH+3IUKnbUNeUOHVu3cLq4QAIEgQdSVHhdkWElf0oaEgAAEDzMsvW6LYGoCkRdCTZbDYlRrNpKAAAzS2vqFRrvzsmSRpD0AHQhAg65SrW6RB0AABoLmt2HVNRqVtpLWPULTXO6uEACCIEnXKJ5et0sildAwCg2by/9YiksrI1m81m8WgABBOCTjlK1wAAaF6FJS6t3nlUkjTm7LYWjwZAsCHolKN0DQCA5vXZ7kzlFbvUNiFKvdsnWD0cAEGGoFPObDGdnUfpGgAAzeH9reXd1nqmym6nbA1A0yLolEson9HJYUYHAACfK3G59dGODEnSpXRbA+ADBJ1yidFlMzqs0QEAwPe+3HdczoIStWoRoXPTkq0eDoAgRNApV7FGh9I1AAB8bUV52drFPVLloGwNgA8QdMqZQSebGR0AAHzK5Tb0wTbK1gD4FkGnHKVrAAA0j00HspV5skjxUWEa3Lml1cMBEKQIOuVOLV0zDMPi0QAAELze/7asbG1U9xRFhPGnCADf4LdLOTPolLgM5Re7LB4NAADByTAMfbCtLOhQtgbAlwg65aLDHZ5PlbLzaUgAAIAvfHvIqUM5BYqJc
GjYWa2tHg6AIEbQKWez2ZQYXb6XDut0AADwCXOT0BFd2ygq3GHxaAAEM4LOKSrW6RB0AABoaoZheNpKU7YGwNcIOqeg8xoAAL7zXcZJ7c/MU0SYXSO6tbF6OACCHEHnFBV76bBGBwCApmbO5gw7s5VaRIZZPBoAwY6gcwpK1wAA8J33tx6RJI3uSdkaAN8j6JwiMcYsXWNGBwCApvR9Zp52pp+Qw27TxT1SrB4OgBBA0DlFAl3XAADwiRXle+cM7tzS88EiAPgSQecUSeW/eLMJOgAANKn36bYGoJkRdE5RsUaH0jUAAJrK4ZwCbTmYI5tNuqQnZWsAmgdB5xRsGAoAQNP7oLxsbcDpSWoTF2XxaACECoLOKRIpXQMAoMlVbBLa1uKRAAglBJ1TnFq6ZhiGxaMBACDwZZ4s0v++z5IkjaZsDUAzIuicwgw6JS5D+cUui0cDAEDgW7k9Q25D6t0hQR2SYqweDoAQQtA5RXS4QxGOsh9JDpuGAgDQaGa3NTYJBdDcCDqnsNlsnlmd7Dw6rwEA0BjOghJ9sSdTkjSGttIAmhlB5ycq1ukwowMAQGOs2pGhUrehs1JaqHPrFlYPB0CIIej8RGJ0Wec1WkwDANA479NtDYCFCDo/kVA+o5PDpqEAADRYXlGp1n53TBJlawCsQdD5iaQYNg0FAKCx1uw6pqJSt05vGaNuqXFWDwdACCLo/IS5aWhOPjM6AAA01IptZtlaqmw2m8WjARCKCDo/kRDNjA4AAI1RWOLSxzsyJEmX0lYagEUIOj+R6FmjQ9ABAKAhPt+Tqbxil9omRKlPh0SrhwMgRBF0fiKJ0jUAABrl1E1C7XbK1gBYg6DzE4mUrgEA0GAlLrdWbi8vW6PbGgALEXR+IoHSNQAAGmz9viw5C0rUMjZC56YlWz0cACGMoPMTZtc1Z36JDMOweDQAAASW97cekSRd0jNFDsrWAFiIoPMT5j46xS638otdFo8GAIDA4XIb+mCbWbbW1uLRAAh1BJ2fiA53KMJR9mOhfA0AgLrbdCBbmSeLFBcVpsGdW1o9HAAhjqDzEzabrWKdDp3XAACosxXl3dYu7p6iiDD+xABgLX4LVSEphs5rAADUh2EYnqAzmm5rAPwAQacKidHmXjoEHQAA6mLroVwdyilQdLhDF57V2urhAABBpyoVLaYpXQMAoC7MbmsjurVWVLjD4tEAAEGnSmwaCgBA3Z1atka3NQD+gqBThaRYs3SNGR0AAGqz++hJ7cvMU4TDrou6tbF6OAAgiaBTpQRmdAAAqLP3vy2bzbngzFZqERlm8WgAoAxBpwqJnjU6BB0AAGqzYptZtka3NQD+g6BTBbPrmpMZHQAAavTD8TztOJIrh92mUd1TrB4OAHgQdKpg7qOTzRodAABqZDYhGNy5pWeNKwD4A4JOFRIoXQMAoE7eZ5NQAH6KoFOFxJiK0jXDMCweDQAA/umIs0CbD+bIZpNG96RsDYB/IehUwdxHp9jlVkGJy+LRAEBwe+GFF5SWlqaoqCgNGjRIGzZsqPH85557Tl27dlV0dLQ6duyoGTNmqLCwsF7PWVhYqDvvvFMtW7ZUixYtNH78eGVkZDT5ewt2H5TP5gw4PUlt4qIsHg0AeCPoVCEmwqEIR9mPJpuGBADgM0uWLNHMmTM1e/Zsbdq0SX369NHo0aN19OjRKs9/7bXXdP/992v27NnasWOH5s+fryVLluiBBx6o13POmDFDb7/9tpYtW6ZPPvlEhw8f1lVXXeXz9xtsPGVrPSlbA+B/CDpVsNlsFet0aEgAAD4zd+5cTZs2TZMnT1aPHj00b948xcTEaMGCBVWe/8UXX2jIkCG6/vrrlZaWpksuuUTXXXed14xNbc/pdDo1f/58zZ07VxdddJH69++vl19+WV988YW+/PLLZnnfwSDzZJH+932WJNpKA/BPBJ1qmOVrtJgGAN8oLi7Wxo0bNWrUKM8xu92uUaNGad26dVU+5vzzz9fGjRs9w
Wbfvn167733dNlll9X5OTdu3KiSkhKvc7p166bTTjut2tdFZSu3Z8htSL3aJ6hDUozVwwGASti+uBpsGgoAvpWZmSmXy6WUFO9F7CkpKdq5c2eVj7n++uuVmZmpoUOHyjAMlZaW6vbbb/eUrtXlOdPT0xUREaHExMRK56Snp1c73qKiIhUVFXm+z83NrfN7DUZmW2lmcwD4K2Z0qmF2XmMvHQDwH2vWrNETTzyhF198UZs2bdLy5cv17rvv6vHHH/f5a8+ZM0cJCQmeW8eOHX3+mv7KWVCiL/ZmSiLoAPBfBJ1qmKVrOZSuAYBPtGrVSg6Ho1K3s4yMDKWmVv3H80MPPaRf/OIXmjp1qnr16qWf//zneuKJJzRnzhy53e46PWdqaqqKi4uVk5NT59eVpFmzZsnpdHpuBw8ebMC7Dg6rdmSoxGXorJQW6tK6hdXDAYAqEXSqYZauOSldAwCfiIiIUP/+/bVq1SrPMbfbrVWrVmnw4MFVPiY/P192u/ely+FwSJIMw6jTc/bv31/h4eFe5+zatUsHDhyo9nUlKTIyUvHx8V63UFVRttbW4pEAQPVYo1MNs3SNrmsA4DszZ87UTTfdpAEDBmjgwIF67rnnlJeXp8mTJ0uSJk2apPbt22vOnDmSpLFjx2ru3Lnq16+fBg0apD179uihhx7S2LFjPYGntudMSEjQlClTNHPmTCUnJys+Pl533XWXBg8erPPOO8+aH0QAySsq1SffHZMkXUpbaQB+jKBTDXNGh310AMB3Jk6cqGPHjunhhx9Wenq6+vbtqxUrVniaCRw4cMBrBufBBx+UzWbTgw8+qEOHDql169YaO3as/vCHP9T5OSXp2Wefld1u1/jx41VUVKTRo0frxRdfbL43HsA++e6YikrdOr1ljLq3jbN6OABQLZthGIbVg6hNbm6uEhIS5HQ6m61U4N1vjujO1zZpYFqylt5efSkDAAQzK37/BopQ/dnc9e+v9faWw7ptWGfNuqy71cMBEILq+vuXNTrVqGgvTekaAACSVFji0sc7yho90G0NgL8j6FSD0jUAALx9vidTecUupcZHqU+HRKuHAwA1IuhUw2xG4MwvUQBU9wEA4HOnbhJqt9ssHg0A1IygUw1zH51il1sFJS6LRwMAgLVKXG6tLC9bG023NQABgKBTjZgIh8IdZZ9WsWkoACDUrd+XpZz8ErWMjdDATslWDwcAakXQqYbNZvOUr2Wzlw4AIMSt2HZEknRJzxQ5KFsDEAAIOjUwy9eczOgAAEKY223og22UrQEILASdGlS0mCboAABC16YD2Tp2okhxUWE6v0srq4cDAHVC0KlBQnRZ6RprdAAAoez98m5ro7qnKCKMPx0ABAZ+W9UgybOXDmt0AAChyTAMr7bSABAoCDo1MEvXnJSuAQBC1NZDuTqUU6DocIeGndna6uEAQJ0RdGpgdl3LYUYHABCi3t9a1m1tRLfWio5wWDwaAKg7gk4NEsq7rrFGBwAQirzL1tpaPBoAqB+CTg2SYmhGAAAIXbuPntS+zDxFOOwa0ZWyNQCBhaBTg4r20pSuAQBCjzmbc8GZrRQXFW7xaACgfgg6NaB0DQAQysy20qPptgYgABF0apAUW166VlAiwzAsHg0AAM3nh+N52nEkVw67TRd3T7F6OABQbwSdGiSWz+gUl7pVUOKyeDQAADQfs2ztvM7Jng/+ACCQEHRqEBPhULjDJonyNQBAaFmxjW5rAAIbQacGNptNCdF0XgMAhJYjzgJ9fSBHNps0ugdlawACE0GnFklm5zU2DQUAhIgPysvW+p+WpDbxURaPBgAahqBTi4oW08zoAABCQ0XZGt3WAAQugk4tKF0DAISS4yeLtGF/liRpdE+CDoDARdCpBZuGAgBCycrtGXIbUq/2CeqYHGP1cACgwQg6tahYo8OMDgAg+JmbhFK2BiDQEXRqkRhjlq4xowMACG7OghJ9sTdTEkEHQOAj6NQiIZoZHQBAaPh4Z4ZKX
IbObNNCXVq3sHo4ANAoBJ1a0HUNABAq3v+2rGxtDLM5AIIAQacWSZSuAQBCQH5xqT757pgk6dKz21o8GgBoPIJOLShdAwCEgjW7jqmo1K3TkmPUvW2c1cMBgEYj6NTi1NI1wzAsHg0AAL6xYmtF2ZrNZrN4NADQeASdWphd14pL3SoscVs8GgAAml5RqUsf7zwqSRrN+hwAQYKgU4vYCIfCHWWfbGWzTgcAEIQ+35Opk0WlSo2PUt8OiVYPBwCaBEGnFjabTQnRZkMC1ukAAIKP2W1tdM8U2e2UrQEIDgSdOqhYp8OMDgAg+Hy5/7gk6eIelK0BCB4EnTpIKg86TmZ0AABB6PjJsg/yOiZHWzwSAGg6BJ06MEvXsgk6AIAgU1jiUn6xS5KUFBth8WgAoOkQdOqA0jUAQLAy15+G2W2KiwyzeDQA0HTqHXTWrl2rsWPHql27drLZbHrjjTdqfUxRUZF+97vf6fTTT1dkZKTS0tK0YMGChozXEonRlK4BAIKT2VE0MSac/XMABJV6f3STl5enPn366JZbbtFVV11Vp8dcc801ysjI0Pz583XGGWfoyJEjcrsDZ08acyqf9tIAgGCTnVd2bUuKoWwNQHCpd9AZM2aMxowZU+fzV6xYoU8++UT79u1TcnKyJCktLa2+L2uphPIZHdpLAwCCjbn+lKADINj4fI3OW2+9pQEDBujpp59W+/btddZZZ+m+++5TQUFBtY8pKipSbm6u181KFWt0CDoAgOCSVV6tkBQbbvFIAKBp+XzV4b59+/TZZ58pKipK//3vf5WZmalf/vKXOn78uF5++eUqHzNnzhw9+uijvh5anSWWd11jjQ4AINjkULoGIEj5fEbH7XbLZrNp0aJFGjhwoC677DLNnTtXr7zySrWzOrNmzZLT6fTcDh486Oth1sic0WGNDgAg2Jila4kEHQBBxuczOm3btlX79u2VkJDgOda9e3cZhqEff/xRZ555ZqXHREZGKjIy0tdDq7NTS9cMw6ArDQAgaJgf4iVTugYgyPh8RmfIkCE6fPiwTp486Tn23XffyW63q0OHDr5++SZhfspVXOpWYUngdIsDAKA2Fe2lmdEBEFzqHXROnjypzZs3a/PmzZKk/fv3a/PmzTpw4ICksrKzSZMmec6//vrr1bJlS02ePFnbt2/X2rVr9etf/1q33HKLoqOjm+Zd+FhshENh9rJZHDYNBQAEE7quAQhW9Q46X331lfr166d+/fpJkmbOnKl+/frp4YcfliQdOXLEE3okqUWLFlq5cqVycnI0YMAA3XDDDRo7dqyef/75JnoLvmez2TyfdGXn0ZAAABA8zH10KF0DEGzqvUZn+PDhMgyj2vsXLlxY6Vi3bt20cuXK+r6UX0mMCVfmySJmdAAAQYXSNQDByudrdIJFYvmmobSYBgAEixKXWycKSyVRugYg+BB06ohNQwEAwSan/MM7m01KiKZ0DUBwIejUkWeNDnvpAACCRE75NS0hOlwOO1snAAguBJ06onQNABBs6LgGIJgRdOrIU7pG0AEABIms8o5rSTGUrQEIPgSdOjJL1+i6BgAIFmbpGjM6AIIRQaeOzBmdbGZ0AABBwrym0VoaQDAi6NRRYnTZRYA1OgCAYGE22GGzUADBiKBTRxXtpSldAwAEh+w8NgsFELwIOnV0aumaYRgWjwYAgMaj6xqAYEbQqSPz067iUrcKS9wWjwYAgMajdA1AMCPo1FFshENh5ZupUb4GAAgGZtChdA1AMCLo1JHNZmMvHQBAUMnOo700gOBF0KkH8xMv8xMwAAACldttyFlQvkaH0jUAQYigUw+J0WUXAlpMAwACXW5hidzlvXXMLRQAIJgQdOqhosU0QQcAENiyysvW4iLDFBHGnwMAgg+/2eohofwTL9boAAACndlaOpGyNQBBiqBTD0meZgSs0QEABDbzWkYjAgDBiqBTD3RdAwAEiyw6rgEIcgSdekgovxiwjw4AINCZH9qZ1QoAE
GwIOvVgdl1jRgcAEOjYLBRAsCPo1IM5vU/QAQAEOjPoJMcSdAAEJ4JOPVS0l6Z0DQAQ2LLzKF0DENwIOvWQQOkaACBIULoGINgRdOrBnNEpKnWrsMRl8WgAAGg4StcABDuCTj20iAxTmN0mqeICAQBAIPJsGErpGoAgRdCpB5vNxl46AICAZxgGG4YCCHoEnXpinQ4AINCdLCpVicuQRNABELwIOvVU0WKa0jUAQGAyP6yLCrcrOsJh8WgAwDcIOvVU0WKaGR0AQGDKpmwNQAgg6NRTQjSbhgIAAltWHkEHQPAj6NQTm4YCAAKd+WFdUiwd1wAEL4JOPZk7SOfkMaMDAAhM5owOm4UCCGYEnXpKMJsRMKMDAAhQZkOdZIIOgCBG0KmnRNpLAwACnLlZaBKbhQIIYgSdejLX6DjpugYACFBZ+ZSuAQh+BJ16MjvUZLOPDgAgQHlK12IJOgCCF0GnnhIoXQMABLjs8oY6iZSuAQhiBJ16Mi8KRaVuFZa4LB4NAAD1l82MDoAQQNCppxaRYQqz2yQxqwMACExm0GHDUADBjKBTTzabzTOrwzodAECgKSh2qbDELYnSNQDBjaDTAKzTAQAEKvNDunCHTS0iwyweDQD4DkGnAcx2nE42DQUABJjsU1pL22w2i0cDAL5D0GkANg0FAASqHDYLBRAiCDoNkOjZS4egAwAILFl5NCIAEBoIOg1gLt7MoXQNABrthRdeUFpamqKiojRo0CBt2LCh2nOHDx8um81W6Xb55Zd7zqnqfpvNpj/+8Y+ec9LS0ird/+STT/r0ffqLHDquAQgRrEJsALN0zcmMDgA0ypIlSzRz5kzNmzdPgwYN0nPPPafRo0dr165datOmTaXzly9fruLiig+Zjh8/rj59+mjChAmeY0eOHPF6zPvvv68pU6Zo/PjxXscfe+wxTZs2zfN9XFxcU70tv2ZWIyTFUroGILgRdBogMdYsXWNGBwAaY+7cuZo2bZomT54sSZo3b57effddLViwQPfff3+l85OTk72+X7x4sWJiYryCTmpqqtc5b775pkaMGKHOnTt7HY+Li6t0biigdA1AqKB0rQFoRgAAjVdcXKyNGzdq1KhRnmN2u12jRo3SunXr6vQc8+fP17XXXqvY2Ngq78/IyNC7776rKVOmVLrvySefVMuWLdWvXz/98Y9/VGlpaY2vVVRUpNzcXK9bIKJ0DUCoYEanAcw1Os4Cgg4ANFRmZqZcLpdSUlK8jqekpGjnzp21Pn7Dhg3aunWr5s+fX+05r7zyiuLi4nTVVVd5Hb/77rt1zjnnKDk5WV988YVmzZqlI0eOaO7cudU+15w5c/Too4/WOi5/Z5ausVkogGBH0GmAxOiyT8GY0QEA68yfP1+9evXSwIEDqz1nwYIFuuGGGxQVFeV1fObMmZ6ve/furYiICN12222aM2eOIiMjq3yuWbNmeT0uNzdXHTt2bOS7aH5m2XVyLDM6AIIbpWsNYH4KxhodAGi4Vq1ayeFwKCMjw+t4RkZGrWtn8vLytHjx4ipL0kyffvqpdu3apalTp9Y6lkGDBqm0tFTff/99tedERkYqPj7e6xaITt0wFACCGUGnAcygU1TqVmGJy+LRAEBgioiIUP/+/bVq1SrPMbfbrVWrVmnw4ME1PnbZsmUqKirSjTfeWO058+fPV//+/dWnT59ax7J582bZ7fYqO70Fm5w8NgwFEBooXWuAFpFhcthtcrkN5eSXKDXBYfWQACAgzZw5UzfddJMGDBiggQMH6rnnnlNeXp6nC9ukSZPUvn17zZkzx+tx8+fP17hx49SyZcsqnzc3N1fLli3TM888U+m+devWaf369RoxYoTi4uK0bt06zZgxQzfeeKOSkpKa/k36keJSt04UlTVdoHQNQLAj6DSAzWZTYnS4jucVK6egWKkJUbU/CABQycSJE3Xs2DE9/PDDSk9PV9++fbVixQpPg4IDBw7IbvcuPti1a5c+++wzffjhh9U+7
+LFi2UYhq677rpK90VGRmrx4sV65JFHVFRUpE6dOmnGjBle62+ClbnRtd0mxUcxowMguNkMwzCsHkRtcnNzlZCQIKfT6Tc10SOfWaO9x/L072nnaXCXqj9RBIBA54+/f/1FIP5sdqWf0Ojn1iopJlxfP3yJ1cMBgAap6+9f1ug0kLmI01lAQwIAQGAwGxEkUbYGIAQQdBqITUMBAIGGzUIBhBKCTgMllHeryWHTUABAgMii4xqAEELQaSDz0zD20gEABIpsZnQAhBCCTgOZpWtOStcAAAEihzU6AEIIQaeBzE1DWaMDAAgUFaVrBB0AwY+g00AJ5ReJHLquAQACREUzAtboAAh+BJ0GSmJGBwAQYMw1OonM6AAIAQSdBkqMLp/RIegAAAJEdvk1K5k1OgBCAEGngTxrdChdAwAEiGxK1wCEEIJOA5lBp7DErcISl8WjAQCgZi63IWf53m+UrgEIBQSdBmoRGSaH3SaJ8jUAgP9zFpTIMMq+TmRGB0AIIOg0kM1m8+ylQ/kaAMDfmWVrcVFhCndw+QcQ/PhN1wgJdF4DAASIitbSlK0BCA0EnUYwLxbmxQMAAH/l2SyUjmsAQgRBpxE8pWvM6AAA/Bwd1wCEGoJOI3hK1woIOgAA/0bpGoBQQ9BpBDYNBQAECk/pGkEHQIgg6DRCkqcZAWt0AAD+LYfSNQAhhqDTCIl0XQMABIisvLKgk0gzAgAhgqDTCAlm1zX20QEA+DnzQ7lkStcAhAiCTiPQdQ0AECjougYg1BB0GqFiHx2CDgDAv5lBJ5EZHQAhgqDTCJ41OpSuAQD8mGEYFaVrrNEBECIIOo1g7qNTWOJWYYnL4tEAAFC1E0WlKnUbkio+pAOAYEfQaYS4yDA57DZJkpNNQwEAfiq7vONadLhDUeEOi0cDAM2DoNMINpvN05Agm710AAB+KpuyNQAhiKDTSAnspQMA8HMVjQgoWwMQOgg6jUSLaQCAvzNL15jRARBKCDqNVNFimtI1AIB/MkvXaC0NIJQQdBrJU7pGMwIAgJ/KYbNQACGIoNNIidFsGgoA8G9ZeWbQYUYHQOgg6DSSubDTyaahAAA/ZX4Yx4wOgFBC0Gkk86KRnceMDgDAP5ld15JoRgAghBB0GinBbEbAjA4AwE9RugYgFBF0Gon20gAAf1dRukbQARA6CDqNVLFGh6ADAPA/hmGwYSiAkETQaSTz07Fs9tEBAPihghKXikrdktgwFEBoIeg0krmPTmGJW4UlLotHAwCAN3Oz0AiHXTERDotHAwDNh6DTSHGRYXLYbZIoXwMA+J/svIqyNZvNZvFoAKD5EHQayWazKYGGBAAAP2WWVlO2BiDUEHSagLm4k3U6AAB/Y5au0YgAQKgh6DQBWkwDAPxVNnvoAAhRBJ0mkFh+8XCyaSgAwM+Y1QZJlK4BCDEEnSbAjA4AwF9VbBZK6RqA0ELQaQKJnr10CDoAAP+SRekagBBF0GkC5gJPStcAAP7GU7pG0AEQYgg6TcAMOpSuAQD8jad0LZbSNQChhaDTBMzSNYIOAMDfZHk2DGVGB0BoIeg0AbMZAfvoAAD8TY65YShBB0CIIeg0gYo1OszoAAD8R1GpS3nFLkms0QEQegg6TSAxmtI1AID/Ma9LDrtNcVFhFo8GAJoXQacJJJYv8CwocamwxGXxaAAAKGOWVCdGh8tut1k8GgBoXgSdJhAXGSZH+QWE8jUAgL/Iziu7JiWyWSiAEETQaQI2m00J0bSYBgD4F3NGJzmW9TkAQg9Bp4kkeoIOndcAAP7BU7pGIwIAIYig00TMsoBsZnQAAH7Cs1kopWsAQhBBp4mYn5Y5C5jRAQD4B3Oz0CRK1wCEIIJOE0lkjQ4AwM+YpWvsoQMgFBF0mkhCeVlADl3XAAB+gtI1AKGMoNNEzE/LaEYAAPAXntI1ZnQAhCCCThMxmxFQugYA8Bfmh2+s0QEQiuoddNauXauxY8eqXbt2stlseuONN+r82M8//1xhY
WHq27dvfV/W77GPDgDA31TM6FC6BiD01Dvo5OXlqU+fPnrhhRfq9bicnBxNmjRJI0eOrO9LBgSz6xprdAAA/qDU5VZuYakkStcAhKaw+j5gzJgxGjNmTL1f6Pbbb9f1118vh8NRr1mgQJEUw4ahAAD/4Tzlgzez6gAAQkmzrNF5+eWXtW/fPs2ePbs5Xs4SidFmMwJmdAAA1jNbS8dHhSnMwZJcAKGn3jM69bV7927df//9+vTTTxUWVreXKyoqUlFRkef73NxcXw2vyZjtpQtKXCoscSkq3GHxiAAAoSy7/IO3ZBoRAAhRPv2Ix+Vy6frrr9ejjz6qs846q86PmzNnjhISEjy3jh07+nCUTSM+KkwOu02SlMs6HQCAxbLLGxEksj4HQIjyadA5ceKEvvrqK02fPl1hYWEKCwvTY489pi1btigsLEwff/xxlY+bNWuWnE6n53bw4EFfDrNJ2Gw2Tw10NuVrAACLmaVrdFwDEKp8WroWHx+vb7/91uvYiy++qI8//livv/66OnXqVOXjIiMjFRkZ6cuh+URidLiy8oppSAAAsJz5oRt76AAIVfUOOidPntSePXs83+/fv1+bN29WcnKyTjvtNM2aNUuHDh3Sq6++KrvdrrPPPtvr8W3atFFUVFSl48HAXKdDi2kAgNUqZnQIOgBCU72DzldffaURI0Z4vp85c6Yk6aabbtLChQt15MgRHThwoOlGGEDMiwkzOgAAq2WzWSiAEFfvoDN8+HAZhlHt/QsXLqzx8Y888ogeeeSR+r5sQEiMNvfSYUYHAGAtStcAhDoa6zchStcAAP4ih9I1ACGOoNOE2DQUAOAvsjztpSldAxCaCDpNKCnWLF1jjQ4AwFo5bBgKIMQRdJpQAmt0AAB+wO02PGXUlK4BCFUEnSZk7j7NGh0AgJVOFJbK5S5rHETpGoBQRdBpQmbXNSelawAAC5l76MRGOBQZ5rB4NABgDYJOEzLLA7IpXQMAWMgMOomUrQEIYQSdJmS2ly4ocamwxGXxaAAAocoMOjQiABDKCDpNKC4yTHZb2de5rNMBAFgkO6/sGsT6HAChjKDThOx2W0XnNYIOAMAi2WwWCgAEnabmWaeTR0MCAIA1KF0DAIJOkzPX6TCjAwCwitkUh9I1AKGMoNPEKlpME3QAANYwqwooXQMQygg6TSzJs2kopWsAAGt41uhQugYghBF0mphZusZeOgAAq+SUX4OSKF0DEMIIOk0sMbp8RoegAwCwSBalawBA0Glq5sJPJ6VrAAALGIZRMaND6RqAEEbQaWJm0DE3awMAoDnlF7tU7HJLonQNQGgj6DSxRE8zAoIOAKD5mWVrEWF2RYc7LB4NAFiHoNPEKtpLU7oGAGh+ZtlackyEbDabxaMBAOsQdJpYIhuGAkC9vPDCC0pLS1NUVJQGDRqkDRs2VHvu8OHDZbPZKt0uv/xyzzk333xzpfsvvfRSr+fJysrSDTfcoPj4eCUmJmrKlCk6efKkz95jczJbS7NZKIBQR9BpYmbpWn6xS0WlLotHAwD+bcmSJZo5c6Zmz56tTZs2qU+fPho9erSOHj1a5fnLly/XkSNHPLetW7fK4XBowoQJXuddeumlXuf9+9//9rr/hhtu0LZt27Ry5Uq98847Wrt2rW699Vafvc/m5NlDh45rAEIcQaeJxUWGyV5eKeCkxTQA1Gju3LmaNm2aJk+erB49emjevHmKiYnRggULqjw/OTlZqampntvKlSsVExNTKehERkZ6nZeUlOS5b8eOHVqxYoX++c9/atCgQRo6dKj+8pe/aPHixTp8+LBP329zyC5fo5NMxzUAIY6g08TsdpsSoilfA4DaFBcXa+PGjRo1apTnmN1u16hRo7Ru3bo6Pcf8+fN17bXXKjY21uv4mjVr1KZNG3Xt2lV33HGHjh8/7rlv3bp1SkxM1IABAzzHRo0aJbvdrvXr11f7WkVFRcrNzfW6+SNzw2pK1wCEOoKOD3g6rzGjAwDVy
szMlMvlUkpKitfxlJQUpaen1/r4DRs2aOvWrZo6darX8UsvvVSvvvqqVq1apaeeekqffPKJxowZI5errJw4PT1dbdq08XpMWFiYkpOTa3zdOXPmKCEhwXPr2LFjXd9qs6J0DQDKhFk9gGDk2UuHzmsA4DPz589Xr169NHDgQK/j1157refrXr16qXfv3urSpYvWrFmjkSNHNvj1Zs2apZkzZ3q+z83N9cuwk81moQAgiRkdn6hoMc2MDgBUp1WrVnI4HMrIyPA6npGRodTU1Bofm5eXp8WLF2vKlCm1vk7nzp3VqlUr7dmzR5KUmppaqdlBaWmpsrKyanzdyMhIxcfHe938UY5nRofSNQChjaDjAxWbhjKjAwDViYiIUP/+/bVq1SrPMbfbrVWrVmnw4ME1PnbZsmUqKirSjTfeWOvr/Pjjjzp+/Ljatm0rSRo8eLBycnK0ceNGzzkff/yx3G63Bg0a1MB34z/MDUOZ0QEQ6gg6PuBpRsCMDgDUaObMmfrHP/6hV155RTt27NAdd9yhvLw8TZ48WZI0adIkzZo1q9Lj5s+fr3Hjxqlly5Zex0+ePKlf//rX+vLLL/X9999r1apVuvLKK3XGGWdo9OjRkqTu3bvr0ksv1bRp07RhwwZ9/vnnmj59uq699lq1a9fO92/ax8xrD2t0AIQ61uj4gHlxySboAECNJk6cqGPHjunhhx9Wenq6+vbtqxUrVngaFBw4cEB2u/dncrt27dJnn32mDz/8sNLzORwOffPNN3rllVeUk5Ojdu3a6ZJLLtHjjz+uyMhIz3mLFi3S9OnTNXLkSNntdo0fP17PP/+8b99sM8mmdA0AJBF0fMJsRuCkdA0AajV9+nRNnz69yvvWrFlT6VjXrl1lGEaV50dHR+uDDz6o9TWTk5P12muv1WucgaCwxKX84rLucpSuAQh1lK75gBl0KF0DADQn87oTZrcpLpLPMgGENoKOD7CPDgDACmYjgsSYcNlsNotHAwDWIuj4QKKnGQGlawCA5pPDZqEA4EHQ8QFP6VoBMzoAgOaTTcc1APAg6PhAYnTZBSa/2KWiUpfFowEAhIqs/IrSNQAIdQQdH4iLCpO9vDTayTodAEAzySlfo5NMxzUAIOj4gt1uq9g0lPI1AEAzMUvXEildAwCCjq/QeQ0A0NzYLBQAKhB0fCSBzmsAgGbmCTqUrgEAQcdXktg0FADQzOi6BgAVCDo+4ildK2BGBwDQPLLzKF0DABNBx0cqSteY0QEANA9K1wCgAkHHR9g0FADQnEpcbp0oLJVE6RoASAQdn0nydF2jdA0A4HtmBYHNVlFVAAChjKDjI4k0IwAANCPzg7WE6HA5zF2rASCEEXR8hDU6AIDmRMc1APBG0PERs+uakzU6AIBmkFXecS2RjmsAIImg4zNma89s1ugAAJqBWbqWzIwOAEgi6PhMYnTZhSa/2KWiUpfFowEABDuzdC2RoAMAkgg6PhMXFSZzLSjlawAAXzMrCJJjKV0DAImg4zN2u83TkMBJQwIAgI9le9boMKMDABJBx6fMi002QQcA4GPmjA5d1wCgDEHHhypaTNOQAADgW+aHapSuAUAZgo4PeTYNZY0OAMDHzBkdStcAoAxBx4fM8gFmdAAAvmau0aF0DQDKEHR8qKJ0jRkdAIDvuN2Gp8NnEqVrACCJoONTlK4BAJpDbmGJ3EbZ1+Y+bgAQ6gg6PpRIe2kAQDPIKi9baxEZpogwLu0AIBF0fCop1mwvzRodAIDvmB3XKFsDgAoEHR9ijQ4AoDnksIcOAFRC0PEhs8WnkzU6AAAfMkvXaC0NABUIOj6UyIahAIBmYFYOJMdQugYAJoKOD5klBHnFLhWXui0eDQAgWLFZKABURtDxobioMNlsZV/nFDCrAwDwjWzW6ABAJQQdH7LbbZ6GBLSYBgD4SnZeeekaXdcAwIOg42OedTo0JAAA+AilawBQGUHHx8yLTnYepWsAAN+gdA0AKiPo+FhiDDM6A
ADfYsNQAKiMoONjiazRAQD4kGEYbBgKAFUg6PiYWbpG1zUAgC+cLCpVicuQRNABgFMRdHzMLF3LZkYHAOAD5mahUeF2RUc4LB4NAPgPgo6PUboGAPClrDzK1gCgKgQdH6N0DQDgS3RcA4CqEXR8zFO6lseMDgCg6eXQcQ0AqkTQ8TFzRsdJe2kAgA+YpWtsFgoA3gg6Pmau0TFbfwIA0JTM60syQQcAvBB0fMwsXcsrdqm41G3xaAAAwcazWWgMpWsAcCqCjo/FR4XLZiv7moYEAICmlpVP6RoAVIWg42N2u00JtJgGAPiIp3QtlqADAKci6DQDzzodGhIAAJqY2dUzkdI1APBC0GkGCeZeOszoAACaGPvoAEDVCDrNwFwgmk3nNQBAE8umdA0AqkTQaQaJrNEBAPhAQbFLhSVlHT0pXQMAbwSdZmB2wqHrGgCgKZmzOWF2m1pEhlk8GgDwLwSdZpDg2TSUGR0AQNPxrM+JjZDN3MsAACCJoNMszDU6BB0AQFPKYbNQAKgWQacZULoGAPCFrDw2CwWA6hB0mkECMzoAAB/wbBZK0AGASgg6zSCRNToAAB/INkvXYildA4CfIug0gyTPhqGUrgEAmg6lawBQPYJOMzD3Nsgrdqm41G3xaAAAwYLSNQCoHkGnGcRFhcvs+uksoHwNANA0zNI1NgsFgMoIOs3AYbedspcO5WsAgKbh2UeHGR0AqISg00w8DQmY0QEANJFTNwwFAHgj6DSTBE9DAoIOAKBpZOexYSgAVIeg00wSKV0DADSh4lK3ThaVSpKSmdEBgEoIOs0kiU1DAQBNKKeg7IMzu02Kj2JGBwB+iqDTTMw9DswLEwAAjWGWrSVEh8tut1k8GgDwPwSdZlLRdY0ZHQBA49GIAABqRtBpJuYeB3RdAwA0hRxaSwNAjQg6zSTJ03WN0jUAQONl0XENAGpE0GkmCTQjAAA0ITYLBYCaEXSaSSJrdAAATSiHNToAUCOCTjMxu645WaMDAGgCZulaIqVrAFAlgk4zMWuoTxaVqrjUbfFoAACBzpzRSaZ0DQCqRNBpJnFR4bKVb3PArA4AoLHMNTqJBB0AqBJBp5k47DbPztVONg0FADRSdj5d1wCgJgSdZpRI5zUAQBMxZ3SSaUYAAFUi6DQjs7wgm6ADAGgEl9vwlEFTugYAVSPoNKOKFtOUrgEAGs5ZUCLDKPuarmsAUDWCTjMyL0Y0IwAANIZZthYXFaZwB5dyAKgKvx2bUZKndI0ZHQBAw3k2C6VsDQCqRdBpRgnRNCMAADSeuVkoHdcAoHoEnWbk6bpG6RoAoBHMyoAkOq4BQLUIOs3Is0aHGR0AQCNk51G6BgC1Ieg0o0TW6AAAmoC5TQEd1wCgegSdZpTIGh0AQBMwmxEkM6MDANUi6DQjc0aH9tIAgMbIKi9dS2SNDgBUq95BZ+3atRo7dqzatWsnm82mN954o8bzly9frosvvlitW7dWfHy8Bg8erA8++KCh4w1o5ozOyaJSlbjcFo8GABCozMoAZnQAoHr1Djp5eXnq06ePXnjhhTqdv3btWl188cV67733tHHjRo0YMUJjx47V119/Xe/BBrr46HDZbGVfU74GAGgoT9c11ugAQLXC6vuAMWPGaMyYMXU+/7nnnvP6/oknntCbb76pt99+W/369avvywc0h92m+KhwOQtK5CwoVuu4SKuHBAAIQGbQSWRGBwCq1exrdNxut06cOKHk5OTmfmm/4NlLhxkdAEADGIZRUbrGGh0AqFa9Z3Qa609/+pNOnjypa665ptpzioqKVFRU5Pk+Nze3OYbWLBKjw/WDCDoAgIY5UVSqUrchifbSAFCTZp3Ree211/Too49q6dKlatOmTbXnzZkzRwkJCZ5bx44dm3GUvsVeOgDg7YUXXlBaWpqioqI0aNAgbdiwodpzhw8fLpvNVul2+eWXS5JKSkr029/+Vr169VJsbKzatWunSZMm6fDhw17Pk5aWVuk5nnzyS
Z++z6ZibhYaHe5QVLjD4tEAgP9qtqCzePFiTZ06VUuXLtWoUaNqPHfWrFlyOp2e28GDB5tplL5H6RoAVFiyZIlmzpyp2bNna9OmTerTp49Gjx6to0ePVnn+8uXLdeTIEc9t69atcjgcmjBhgiQpPz9fmzZt0kMPPaRNmzZp+fLl2rVrl372s59Veq7HHnvM67nuuusun77XppJN2RoA1EmzlK79+9//1i233KLFixd7PnWrSWRkpCIjg3Oh/unJMZKkbw45LR4JAFhv7ty5mjZtmiZPnixJmjdvnt59910tWLBA999/f6Xzf7q+c/HixYqJifEEnYSEBK1cudLrnL/+9a8aOHCgDhw4oNNOO81zPC4uTqmpqU39lnyuohEBZWsAUJN6z+icPHlSmzdv1ubNmyVJ+/fv1+bNm3XgwAFJZbMxkyZN8pz/2muvadKkSXrmmWc0aNAgpaenKz09XU5naP6hf2HX1pKktd8dUyl76QAIYcXFxdq4caPXLL/dbteoUaO0bt26Oj3H/Pnzde211yo2Nrbac5xOp2w2mxITE72OP/nkk2rZsqX69eunP/7xjyotLW3Q+2huZulaEh3XAKBG9Z7R+eqrrzRixAjP9zNnzpQk3XTTTVq4cKGOHDniCT2S9Pe//12lpaW68847deedd3qOm+eHmr4dk5QQXdZievPBHA1IC83ucwCQmZkpl8ullJQUr+MpKSnauXNnrY/fsGGDtm7dqvnz51d7TmFhoX7729/quuuuU3x8vOf43XffrXPOOUfJycn64osvNGvWLB05ckRz586t9rn8pVGOWbqWROkaANSo3kFn+PDhMgyj2vt/Gl7WrFlT35cIag67TcPOaq23txzWml3HCDoA0EDz589Xr169NHDgwCrvLykp0TXXXCPDMPTSSy953Wd+SCdJvXv3VkREhG677TbNmTOn2tLpOXPm6NFHH226N9BAOWwWCgB10uz76EAaUV6+tnpX1YttASAUtGrVSg6HQxkZGV7HMzIyal07k5eXp8WLF2vKlClV3m+GnB9++EErV670ms2pyqBBg1RaWqrvv/++2nP8pVFOVh6bhQJAXRB0LDDsrLKgs+1wro7mFlo8GgCwRkREhPr3769Vq1Z5jrndbq1atUqDBw+u8bHLli1TUVGRbrzxxkr3mSFn9+7d+uijj9SyZctax7J582bZ7fYatz6IjIxUfHy8180Kns1CmdEBgBo1+4ahkFq1iFSfDgna8qNTa747pmsGBM8+QQBQHzNnztRNN92kAQMGaODAgXruueeUl5fn6cI2adIktW/fXnPmzPF63Pz58zVu3LhKIaakpERXX321Nm3apHfeeUcul0vp6emSyjq2RUREaN26dVq/fr1GjBihuLg4rVu3TjNmzNCNN96opKSk5nnjjWB2XWONDgDUjKBjkeFd25QFnV1HCToAQtbEiRN17NgxPfzww0pPT1ffvn21YsUKT4OCAwcOyG73Lj7YtWuXPvvsM3344YeVnu/QoUN66623JEl9+/b1um/16tUaPny4IiMjtXjxYj3yyCMqKipSp06dNGPGDK91O/6M0jUAqBuCjkWGd22tP6/arU+/y1SJy61wB1WEAELT9OnTNX369Crvq6qhTdeuXattipOWllZjwxxJOuecc/Tll1/We5z+oqJ0jaADADXhr2uL9O6QqOTYCJ0oKtWmH7KtHg4AIAAYhqEsNgwFgDoh6FjEYbfpwrPM7mvHLB4NACAQFJS4VFxattk0a3QAoGYEHQsNL28zvYY20wCAOjA3C41w2BUb4bB4NADg3wg6Fhp2ZmvZbNLO9BM64iywejgAAD+XnVdRtmaz2SweDQD4N4KOhZJiI9SvY6IkaQ3lawCAWnhaS9OIAABqRdCx2PCuZZvTrd5J+RoAoGZm6VpSLI0IAKA2BB2LjSgPOp/vyfQsMAUAoCpm6RozOgBQO4KOxXq2i1erFpHKK3bpq++zrB4OAMCPeUrX6LgGA
LUi6FjM7tVmmvI1AED1zM1Ck9hDBwBqRdDxAyO6mW2maUgAAKheFqVrAFBnBB0/cMEZreWw27T76EkdzMq3ejgAAD9F1zUAqDuCjh9IiAnXOaclSpLWfMesDgCgajl0XQOAOiPo+AmzzfQnrNMBAFQjy7NhKDM6AFAbgo6fqGgzfVyFJS6LRwMA8Ec55aVryQQdAKgVQcdPdG8bp5T4SBWUuLRhP22mAQDeikpdyisu+yCMNToAUDuCjp+w2WwaflbZrA7d1wAAP2Wuz7HbpLioMItHAwD+j6DjRyraTLNOBwDg7dSOa3a7zeLRAID/I+j4kSFntFKY3aZ9mXn64Xie1cMBAPiR7LyyGZ1ENgsFgDoh6PiRuKhwDUhLkkT5GgDAG3voAED9EHT8jNl9bTXlawCAU3iCTixBBwDqgqDjZ8z9dNbtpc00AKCCZ7NQStcAoE4IOn7mrJQWapcQpaJSt9btO271cAAAfsLcLJTSNQCoG4KOn7HZbLqwfFZnzU7K1wAAZShdA4D6Iej4oRFdy9pMr951TIZhWDwaAIA/yPbM6FC6BgB1QdDxQ0POaKVwh00HsvK1P5M20wAAKTvfbC/NjA4A1AVBxw/FRoZpYKdkSWWzOgAA5JSXriVTugYAdULQ8VNmm+k1tJkGAOjUZgSUrgFAXRB0/JTZZnr9vizlF5daPBoAgJVKXW7lFpZdCyhdA4C6Iej4qS6tY9UhKVrFLre+2EObaQAIZc6CEs/XidHM6ABAXRB0/JTNZvOUr62mfA0AQprZWjo+KkxhDi7dAFAX/Lb0YyO6lbWZXkObaQAIaWbHNRoRAEDdEXT82ODOrRQRZtehnALtOXrS6uEAACxi7qHD+hwAqDuCjh+LjnDovM4tJVG+BgChzCxdo+MaANQdQcfPjehaUb4GAAhNZulaEqVrAFBnBB0/Z7aZ/t/3WTpZRJtpAAhFFTM6BB0AqCuCjp/r1CpWaS1jVOIy9PmeTKuHAwCwQDabhQJAvRF0AoA5q7OGdToAEJIoXQOA+iPoBIDh5et0Vu+kzTQAhKIcStcAoN4IOgHgvM4tFRVuV3puoXZlnLB6OACAZpblaS9N6RoA1BVBJwBEhTt0fpdWkspmdQAAoSWHDUMBoN4IOgHCU77GOh0ACClut6GcgvI1OpSuAUCdEXQCxPCzyhoSbPwhW7mFJRaPBgDQXE4UlsrlLlufSekaANQdQSdAnNYyRl1ax8rlNvTZbtpMA0CoMPfQiY1wKDLMYfFoACBwEHQCiNlmevVOytcAIFRk5ZuNCChbA4D6IOgEkBHmfjrf0WYaAEKFp7V0LGVrAFAfBJ0Acm6nJMVEOHTsRJG2Hc61ejgAgGaQnUcjAgBoCIJOAIkMq2gzvYbuawAQErLZLBQAGoSgE2BGdCtrM71mF/vpAEAoqAg6lK4BQH0QdAKM2ZBg04FsT902ACB4ZZdvFprEZqEAUC8EnQDTPjFaZ6W0kNuQ1tJmGgCCXnYepWsA0BAEnQDk6b7GOh0ACHrZnvbSlK4BQH0QdAKQWb72ya5jcrtpMw0AwSynvHQtmdI1AKgXgk4AGpCWpBaRYTqeV6xvDzmtHg4AwIeyKF0DgAYh6ASgcIddQ88w20zTfQ0AgpVhGJ4ZHZoRAED9EHQClNlmejXrdAAgaOUXu1TsckuivTQA1BdBJ0BdeFbZOp0tP+bo+Mkii0cDAPAFs2wtIsyu6HCHxaMBgMBC0AlQqQlR6t42XoYhfUqbaQAISp5GBDERstlsFo8GAAILQSeADe9K+RoABDNaSwNAwxF0Api5n84n3x2TizbTABB0zKBDxzUAqD+CTgA757RExUWFKSe/RFt+zLF6OACAJpZdvkaHPXQAoP4IOgEszGHXsDPLytfW7KR8DQCCTXb5Gh1K1wCg/gg6Aa5inQ776QBAsKF0DQAajqAT4C4sDzrfHnLq2AnaTANAMMlms1AAaDCCToBrExels9vHSyprSgAACB7mGh02C
wWA+iPoBAGz+xptpgEguFC6BgANR9AJAsPLg86n3x1Tqctt8WgAAE0lh9I1AGgwgk4Q6NsxUYkx4cotLNXXB3OsHg4AoIlkUboGAA1G0AkCDrvN02Z6NW2mASAoFJa4VFDikiQlUroGAPVG0AkSI7qV76dDm2kACApm2ZrDblN8VJjFowGAwEPQCRLDzmwtm03afiRX6c5Cq4cDAGikU8vWbDabxaMBgMBD0AkSLVtEqneHREnSJ99RvgYAgS6nvOMaZWsA0DAEnSAyoivlawAQLMzNQpMJOgDQIASdIOJpM707UyW0mQaAgJblmdGh4xoANARBJ4j0bp+glrEROllUqq++z7Z6OACARsjJY7NQAGgMgk4QsdttuvCs8vI11ukAQEDLZrNQAGgUgk6QudBcp7OTdToAEMiy89ksFAAag6ATZIad2Vp2m7Qr44QO5xRYPRwAQAN5gg4zOgDQIASdIJMUG6F+pyVJovsaAAQyT+kaa3QAoEEIOkFoePk6ndW7WKcDAIEqO4/SNQBoDIJOEBrRrazN9Od7MlVU6rJ4NACAhqB0DQAah6AThHq0jVfruEjlF7toMw0AAajE5daJwlJJlK4BQEMRdILQqW2mV++kfA0AAk1O+focm01KiKZ0DQAagqATpEZ0LStfY50OAASenPKytYTocDnsNotHAwCBiaATpIae2UoOu017j+XpYFa+1cMBANQDHdcAoPEIOkEqITpc/T1tppnVAYBAklXecS2RjmsA0GAEnSA2vJvZZpr9dAAgkJila8nM6ABAgxF0gpi5TueLvZkqLKHNNAAEiqx8c0aHoAMADUXQCWLdUuOUGh+lwhK31u/Psno4AIA6yvGs0aF0DQAaiqATxGw2m4Z3pc00AASa7Dw2CwWAxiLoBLnh5eVrn3zHOh0ACBTZ5aVrdF0DgIYj6AS5IWe0VJjdpv2ZedqfmWf1cAAAdZBN6RoANBpBJ8jFRYXr3LRkSbSZBoBA4ZnRoXQNABqMoBMCzHU6a2gzDQABwbNGh9I1AGgwgk4IGNGtbJ3Oun3HVVBMm2kA8GdutyFnAaVrANBYBJ0QcGabFmqfGK3iUrfW7cu0ejgA4OWFF15QWlqaoqKiNGjQIG3YsKHac4cPHy6bzVbpdvnll3vOMQxDDz/8sNq2bavo6GiNGjVKu3fv9nqerKws3XDDDYqPj1diYqKmTJmikydP+uw91kduYYncRtnX7KMDAA1H0AkBNptNF1K+BsAPLVmyRDNnztTs2bO1adMm9enTR6NHj9bRo1WvKVy+fLmOHDniuW3dulUOh0MTJkzwnPP000/r+eef17x587R+/XrFxsZq9OjRKiws9Jxzww03aNu2bVq5cqXeeecdrV27VrfeeqvP329dZJWXrbWIDFNEGJdpAGgofoOGiBHlbaY/3nlUhmFYPBoAKDN37lxNmzZNkydPVo8ePTRv3jzFxMRowYIFVZ6fnJys1NRUz23lypWKiYnxBB3DMPTcc8/pwQcf1JVXXqnevXvr1Vdf1eHDh/XGG29Iknbs2KEVK1bon//8pwYNGqShQ4fqL3/5ixYvXqzDhw8311uvltlxLZGyNQBoFIJOiDi/S0tFOOz6MbtAe4/RZhqA9YqLi7Vx40aNGjXKc8xut2vUqFFat25dnZ5j/vz5uvbaaxUbGytJ2r9/v9LT072eMyEhQYMGDfI857p165SYmKgBAwZ4zhk1apTsdrvWr19f7WsVFRUpNzfX6+YLOeUd15LpuAYAjULQCRGxkWEa2Ik20wD8R2Zmplwul1JSUryOp6SkKD09vdbHb9iwQVu3btXUqVM9x8zH1fSc6enpatOmjdf9YWFhSk5OrvF158yZo4SEBM+tY8eOtY6xIczSNdbnAEDjEHRCCG2mAQST+fPnq1evXho4cGCzvN6sWbPkdDo9t4MHD/rkdXLKS9eSKV0DgEYh6IQQs830hv1ZyisqtXg0AEJdq1at5HA4lJGR4XU8IyNDq
ampNT42Ly9Pixcv1pQpU7yOm4+r6TlTU1MrNTsoLS1VVlZWja8bGRmp+Ph4r5svmJuFMqMDAI1D0AkhnVvFqmNytIpdbn2x97jVwwEQ4iIiItS/f3+tWrXKc8ztdmvVqlUaPHhwjY9dtmyZioqKdOONN3od79Spk1JTU72eMzc3V+vXr/c85+DBg5WTk6ONGzd6zvn444/ldrs1aNCgpnhrjWIGHTYLBYDGIeiEEJvN5um+tpp1OgD8wMyZM/WPf/xDr7zyinbs2KE77rhDeXl5mjx5siRp0qRJmjVrVqXHzZ8/X+PGjVPLli29jttsNt1zzz36/e9/r7feekvffvutJk2apHbt2mncuHGSpO7du+vSSy/VtGnTtGHDBn3++eeaPn26rr32WrVr187n77k22XnlpWuxlK4BQGOEWT0ANK8RXdvo1XU/aE15m2mbzWb1kACEsIkTJ+rYsWN6+OGHlZ6err59+2rFihWeZgIHDhyQ3e79mdyuXbv02Wef6cMPP6zyOX/zm98oLy9Pt956q3JycjR06FCtWLFCUVFRnnMWLVqk6dOna+TIkbLb7Ro/fryef/55373ReqB0DQCahs0IgE1VcnNzlZCQIKfT6bOa6FBRUOxSn8c+VHGpWx/OGKazUuKsHhIAP8bv3+r56mdzybOf6LuMk/rXlEEaemarJnteAAgWdf39S+laiImOcGhw57JSj9U7KV8DAH9jbhiaROkaADQKQScEjShvM806HQDwL4ZhKDuPZgQA0BQIOiFoeHlDgq++z9aJwhKLRwMAMJ0sKlWpu6yinKADAI1D0AlBaa1i1alVrErdhj7fk2n1cAAA5czNQqPC7YqOcFg8GgAIbASdEDXcLF/beczikQAATFmUrQFAkyHohCizfG3Nd2VtpgEA1qO1NAA0HYJOiBrUKVlR4XZl5BZpx5ETVg8HAKCK0jU2CwWAxiPohKiocIeGdCnbn4HuawDgH8zSNWZ0AKDxCDohzFyn88ku1ukAgD/IyTfX6DCjAwCNVe+gs3btWo0dO1bt2rWTzWbTG2+8Uetj1qxZo3POOUeRkZE644wztHDhwgYMFU3N02b6hyxtOZhj7WAAAJ7NQpOZ0QGARqt30MnLy1OfPn30wgsv1On8/fv36/LLL9eIESO0efNm3XPPPZo6dao++OCDeg8WTatjcozGnJ0qtyHd8a+NnpIJAIA1smhGAABNJqy+DxgzZozGjBlT5/PnzZunTp066ZlnnpEkde/eXZ999pmeffZZjR49ur4vjyb21NW9tTP9hPZn5unuf3+tV24ZKIfdZvWwACAkeUrXaEYAAI3m8zU669at06hRo7yOjR49WuvWrav2MUVFRcrNzfW6wTfio8L10o3nKDrcoc/2ZOrZld9ZPSQACFnZeWWla+yjAwCN5/Ogk56erpSUFK9jKSkpys3NVUFBQZWPmTNnjhISEjy3jh07+nqYIa1baryeHN9LkvTX1Xu0cnuGxSMCgNCUnc+GoQDQVPyy69qsWbPkdDo9t4MHD1o9pKB3Zd/2uvn8NEnSzCWbtT8zz9oBAUAIIugAQNPxedBJTU1VRob3DEFGRobi4+MVHR1d5WMiIyMVHx/vdYPvPXBZdw04PUknikp1x782Kr+41OohAUDIKCh2qbDELYk1OgDQFHwedAYPHqxVq1Z5HVu5cqUGDx7s65dGPUWE2fXCDeeoVYtI7Uw/oVnLv5VhGFYPCwBCgjmbE2a3qUVkvXsFAQB+ot5B5+TJk9q8ebM2b94sqax99ObNm3XgwAFJZWVnkyZN8px/++23a9++ffrNb36jnTt36sUXX9TSpUs1Y8aMpnkHaFIp8VF64fp+cthtenPzYb267gerhwQAIcFTthYbIZuN7pcA0Fj1DjpfffWV+vXrp379+kmSZs6cqX79+unhhx+WJB05csQTeiSpU6dOevfdd7Vy5Ur16dNHzzzzjP75z3/SWtqPDercUrPGdJMkPf7Odm38I
cviEQFA8MvJNzuuUbYGAE2h3nPjw4cPr7GcaeHChVU+5uuvv67vS8FCU4Z20tcHc/TuN0f0y0Wb9M5dF6h1XKTVwwKAoGVu2sxmoQDQNPyy6xqsZ7PZ9NT43jqjTQtl5BZp+mubVOpyWz0sAAha5mahyQQdAGgSBB1Uq0VkmObd2F+xEQ6t35+lpz/YZfWQACBoZZula3RcA4AmQdBBjc5o00J/mtBHkvT3tfv03rdHLB4RAAQnStcAoGkRdFCrMb3a6rZhnSVJv162RXuOnrB4RAAQfChdA4CmRdBBnfx6dFed1zlZecUu3fZ/G3WyiM1EAaApZZWXriXSdQ0AmgRBB3US5rDrL9edo9T4KO09lqffvL6FzUQBoAmZMzpJzOgAQJMg6KDOWsdF6oUbzlG4w6b3vk3XPz/db/WQACBonLphKACg8Qg6qJf+pyfpoSt6SJKeXLFTX+47bvGIACA4ZOexYSgANCWCDurtF+edrp/3ay+X29D01zYp3Vlo9ZAAIKAVl7o9ax8pXQOApkHQQb3ZbDY98fNe6pYap8yTxbrztU0qLmUzUQBoqJyCsrI1u02Kj2ZGBwCaAkEHDRId4dC8G/srLipMG3/I1hPv7bB6SAAQsMyytYTocDnsNotHAwDBgaCDBktrFatnr+krSVr4xfd64+tD1g4IAAJUNh3XAKDJEXTQKKN6pOiui86QJN2//BvtOJJr8YgAIPDk0HENAJocQQeNds+os3TBma1UWOLWHf/aKGdBidVDAoCAkkXHNQBocgQdNJrDbtPz1/ZT+8RofX88X/cu3Sy3m81EAaCuzNK1RErXAKDJEHTQJJJiI/TSjecowmHXRzuO6qVP9lo9JAAIGGbpWjKlawDQZAg6aDK9OyTqsSt7SpKe+XCXPt19zOIRAUBgMEvXEildA4AmQ9BBk7p24GmaOKCj3IZ097+/1qGcAquHBAB+L4euawDQ5Ag6aHKPXtlTvdonKDu/RHf8a6MKS1xWDwkA/BrtpQGg6RF00OSiwh168YZzlBgTrm9+dOrRt7dbPSQA8GvZ+XRdA4CmRtCBT3RMjtGfr+0nm03694YDWvq/g1YPCQD8VjbNCACgyRF04DMXntVaM0edJUl68M2t2nrIafGIAMD/uNyGZ/8x2ksDQNMh6MCn7hxxhkZ2a6PiUrdu/9dGZecVWz0kAPArzoISGeVbj9F1DQCaDkEHPmW32zR3Yl+dlhyjH7MLdM+SzXKxmSgAeJhla3FRYQp3cFkGgKbCb1T4XEJ0uObd2F+RYXZ98t0xPb9qt9VDAgC/Yc5003ENAJoWQQfNoke7/2/v3uObru89jr+T9EpJobQkTaHlLm25X4RBFXWA6EHP3EHcjkyZnIdOVx3Ysx1ApzycE0QnYxOV4VEe28RNN284nUrLEURlsCJyaUsFZmHQhktLU1p6S3L+aBOobbDOhl9IXs/Ho4/SX38tn/wo+eTd7+WXqKXfHiFJ+lXBZ9pY4jS4IgAIDey4BgDBQdDBBTNrXF/d8o1+kqQFf9ypQyfrDK4IAIznv4cOO64BQJci6OCCeuC6bI3J6ClXfbN+8EKhzjRyM1EAkY2pawAQHAQdXFAxUWY9PWeskhNiVFzu0v2v75bXy+YEACKXb+oaO64BQNci6OCCc/SI15M3j5HZJL2644jW/e2Q0SUBgGFO+W4WyogOAHQpgg4MMXlQihZekylJeujNvfrkUJXBFQGAMSpbp671ZI0OAHQpgg4Mc8eUgbpmWKqa3F79cN0OnTzdYHRJAHDBnWLXNQAICoIODGMymfT47JEamJKg8up63fOHT9Ts9hhdFgBcUFVMXQOAoCDowFDWuGitvmWcusVY9NGBk3piQ6nRJQHABeULOj0JOgDQpQg6MNwldquWzxopSXrm/QP6w7ZD7MQGICJ4vd6zU9cSmLoGAF2JoIOQcP2oNP3XZQMkSYtf3a27XmDNDoDwV9PQrGZPyy92uI8OAHQtgg5Cxn3/l
qX/nn6JoswmvbO3QjNWbtZ7eyuMLgsAgsZ3s9D4aIvioi0GVwMA4YWgg5BhMZt0z9Qhej03R0PtVp043ag7fl+ovJd3qvpMk9HlAUCXq2LHNQAIGoIOQs7wPj20/p4c/eCKgTK13lT0mpWbteWzE0aXBgBdyrcRQRL30AGALkfQQUiKjbJo8bVZ+tMPJqlfcjeVV9fre8/9TQ++sUd1jc1GlwcAXcI3dY31OQDQ9Qg6CGnj+/fS2z+6XLd8o58k6Xcfl+nffvWBCsuqDK4MAL4+39S1nkxdA4AuR9BByEuIjdLDNwzX7+ZNUGpinD4/WafZqz/S8ndK1NDsNro8APiXnfLdLJSpawDQ5Qg6uGhMuaS33r13iv5jTB95vC333PnWqg+192i10aUBIcPj8erQyTq9u7dCK/NLdefvC3XVL97XmUZ+KRCKKmu5WSgABEuU0QUAX0WP+Git+M5oXT3Mrvtf26OSihrd8NSHmj91iO68YpCiLGR3RI7ahmaVVNSouNylkgqXistrtK+iRqcb2q9jK3XWaFR6zwtfJM7Ld7PQXkxdA4AuR9DBRema4Q6N799L9726W+8VOfWL90qVX3xMT9w0SoN6dze6PKBLeTxe/bPqjIorXC2hprxGxRUulZ2s6/D8GItZg23dleVIVJbDqixHogbb+H8RinwjOuy6BgBdj6CDi1ZK91j95pZxeu2TI1qyfq92Hj6lmb/+QAuvydTcSf1lNpuMLhH4ynyjNCXnhJqSAKM0kmSzxirLkahMh1XZjkRlpiZqYO8ERTO6eVHwbS/N1DUA6HoEHVzUTCaT/mNsX31jYLL+58+7tGX/CT30ZpHe2+vU47NHqm9SN6NLBDrk9baM0hT5Rmhap5+VVdbJ621/fkejNJmpViV3j73wxaPLnJ26RtABgK5G0EFYSOsZr9/Nm6B1fyvT0rdL9PHBk7pm5Qd68LpszR7fVyYTozswTm1Ds/Y5a85OOyt3fekoTaYv0KQmKsvBKE048nq9qvSP6LBGBwC6GkEHYcNsNumWSf112ZDe+vGfPlVhWZX+55VdendvhZbNGiGbNc7oEhEBauqbtPVgpYqOujo9SuObdsYoTWQ50+RWY7NHEmt0ACAYCDoIOwNSEvTyDyZpzeaD+uWGUhWUHNOMX27Wz28YoZkjHUaXhzB09NQZ5Rc7taHIqa0HT6rJ3T7V9G5dS8MoDXx8NwuNsZiVEGMxuBoACD8EHYQli9mku64cpKsye+velz5VcblLuS/u0Lt70/Szbw1j4S++Fq/Xq+LyGm0ocmpDcYX2HHG1+fzAlASNzujp3xwg02FVCqM0+IKq2rPT1pheCwBdj6CDsJaZmqg3cnP064LP9PT7+7X+06PaevCklt84UlcNtRldHi4iTW6Ptv2jsiXcFDl15NQZ/+dMJmlsRpKmZ9s1PdvOFufoFN+Oa0n84gUAgoKgg7AXE2XWj2cM1dQsm/77T5/q4PFa3bZ2u/5zQrrun5mt7rH8N0DHauqbtKn0uDYUOfV/Jcfkqj+7eUBctFmXDe6tq7PtuirTpt5WRmzw1fimriUlsBEBAAQDr/AQMcZkJOmtey7XY++WaO2Hn+sP2w5ry/4T+sWNozRxYLLR5SFElFefUX6RU+91sN4mOSFG38y0aXq2XZcP6a141lXga/BNXWNEBwCCg6CDiBIfY9GS64dperZdP/nTLh2uPKPvPrtV/5UzQD+eMVRx0bxwjTS+9Ta+zQR2H6lu8/mBKQn+KWljMpJk4Ua06CLcLBQAgougg4g0eVCK3llwuR7+S5Fe/vs/9b9b/qH3S49rxU2jNLJvT6PLQ5A1uT3a/o9KvVfkVH6xU/+sar/eZlpWS7gZbGO9DYLDf7NQpq4BQFAQdBCxrHHReuzGUbo6O1WLXt2t/cdO69tPf6S7rxqsu785mG1/w8zphmZt2ndcG4oqtPEL621io8y6fEiKpmfb9c1MO+ttcEFUMnUNAIKKoIOINy3br
g39kvTTN/borV3l+lXBZ9pYckwrbhqlIXar0eXha6iortcG3/1tDpxUo9vj/1yvhBhNbV1vc9mQFHWL4ekQFxZT1wAguOjsgFruSv7UzWM1Y9hRPfD6Hu0+Uq2ZT25R7pWDddmQFGWmWpXA7mwhz+v1qqSiRvlFTm0odmrXP9uutxlwznqbsay3gcGYugYAwcUrN+Ac/z4qTRMH9NLCV3bp/X3H9cv8Uv0yv1SS1C+5m7Jab/6YmZqobEei+ibFy8yLZUMdq6lX0VGXNpUeV36xU4cr2663GZPeU9OzU1lvg5BTWcuIDgAEE0EH+AJ7YpzWfv9SvbLjiNZ/elQl5S4dq2lQ2ck6lZ2s0zt7K/zndo+N0tBUqzJTrcpyJCrLYdXQ1ETuzRMEzW6PDp6oVXG5S0VHXSoqd6m4vEYnTje0OY/1NrhYnOKGoQAQVLwaAzpgMpl047i+unFcX0nSydMNKqmoUXHri+uSCpc+c57W6YZmFZZVqbCsqs3XZ/Tq1ib8ZDkSlZ7UjdGfTqo+09R6rV3+a77PWaPGZk+7c02mlilpYzOSWu9vw3obhL6GZrdqG92SpF4EHQAICl4NAJ2Q3D1WOYNjlTM4xX+sye3RP1pHGHzhp7jcJaerQYcq63Sosk7vFTn95yfEWFpGfxyJLQEo1aqhqVZZ4yJ3fr7H49XhqrqWUZryGhUdbbmGR06d6fD87rFR/gCZndZyHYfardy4Excd3/ocs0myxtGKASAYeHYF/kXRFrMusVt1id2qb40+e7yytlEl5S4Vt44AlVS4VOo8rdpGt3YcOqUdh061+T7pveKVmXo2/GQ5EpXRK/xGf840urXPeTbMtFybGp1uaO7w/D494/2BJptRMYSZc3dc42caAIKDoAN0sV4JMZo8OEWTzxn9afaN/vjCT+uL/PLqeh2uPKPDlWe04ZzRn26+0Z/Ulhf5mY5EDU21KvEiGP3xer06VtNwzjqalvefn6iVx9v+/Jgos4barf4pftmORGU6EtUjPvQfK/CvqqptGdFJ6sbPOQAEC0EHuACiLGYNsVs1xG7Vv49K8x+vqm30r/1pmfpWo1Jnjeoa3frk0Cl98oXRn9gos+KiLYqLbn0fZVFstPns+2iL4qItree1HPedHxvV+r7NORbFRX3x685+/9gos0ymwL9tbnJ7tP/Yaf8IjW+DAN9uUl+U0j1WWQ5r6yhNyyjWwJQERXFzVkSYKjYiAICgI+gABkpKiNGkQcmaNCjZf6zZ7dHnJ2vPWfdTo5Jyl45W16uh2aOGZo+qO17CEhQxUeYOw1CT26MDx0+ryd1+mMZiNmlgSoJ/HY1vUwabNe7CFQ6EMH/QSSDoAECwEHSAEBNlMWuwzarBNquuP2f0x1XfpOq6JjU0e1Tf5FZDs1v1Tb4/t7z3fVzf7FZDk+fs+yZ3+/Nav77hC+fUN3vkPmeOWWOzR43NHrnqO15LY42L8k85843SDLF3V1w0GwQAgVTV+kZ0mLoGAMFC0AEuEolx0RdsjU6z26P65o4DUkNrkDLJpMG27uqbFH/e6W0A2rt+VJoG26xK7cEoJwAEC0EHQDtRFrO6W8zc+BQIkn7JCeqXnGB0GQAQ1lgBDAAAACDsEHQAAAAAhB2CDgAAAICwQ9ABABjqqaeeUv/+/RUXF6eJEydq27Zt5z3/1KlTys3NlcPhUGxsrC655BK9/fbb/s/3799fJpOp3Vtubq7/nCuvvLLd5++8886gPUYAwIXHSmMAgGFeeukl5eXlafXq1Zo4caJWrlypGTNmaN++fbLZbO3Ob2xs1PTp02Wz2fTnP/9Zffr0UVlZmXr27Ok/Z/v27XK73f6P9+zZo+nTp2v27Nltvtftt9+un/3sZ/6Pu3Xr1vUPEABgGIIOAMAwK1as0O23367bbrtNkrR69Wq99dZbev7557Vo0aJ25z///POqrKzURx99pOjolu3W+/fv3+ac3r17t/n40Ucf1aBBg3TFFVe0Od6tW
zelpqZ24aMBAIQSpq4BAAzR2NiowsJCTZs2zX/MbDZr2rRp+vjjjzv8mvXr12vSpEnKzc2V3W7X8OHDtXTp0jYjOF/8O1544QXNmzev3f2e1q1bp5SUFA0fPlyLFy9WXV3deettaGiQy+Vq8wYACF2M6AAADHHixAm53W7Z7fY2x+12u0pKSjr8moMHD2rjxo2aM2eO3n77be3fv18//OEP1dTUpCVLlrQ7//XXX9epU6f0/e9/v83xm2++Wf369VNaWpp27dqlhQsXat++fXr11VcD1rts2TI99NBDX/2BAgAMQdABAFw0PB6PbDab1qxZI4vFonHjxunIkSN6/PHHOww6zz33nK699lqlpaW1OX7HHXf4/zxixAg5HA5NnTpVBw4c0KBBgzr8uxcvXqy8vDz/xy6XS+np6V30yAAAXY2gAwAwREpKiiwWi5xOZ5vjTqcz4NoZh8Oh6OhoWSwW/7GsrCxVVFSosbFRMTEx/uNlZWXKz88/7yiNz8SJEyVJ+/fvDxh0YmNjFRsb+6XfCwAQGlijAwAwRExMjMaNG6eCggL/MY/Ho4KCAk2aNKnDr8nJydH+/fvl8Xj8x0pLS+VwONqEHElau3atbDabZs6c+aW17Ny5U1JLkAIAhAeCDgDAMHl5eXr22Wf129/+VsXFxbrrrrtUW1vr34Xt1ltv1eLFi/3n33XXXaqsrNT8+fNVWlqqt956S0uXLm1zjxypJTCtXbtWc+fOVVRU28kLBw4c0MMPP6zCwkJ9/vnnWr9+vW699VZNmTJFI0eODP6DBgBcEExdAwAY5jvf+Y6OHz+uBx98UBUVFRo9erTeeecd/wYFhw4dktl89ndy6enpevfdd3Xvvfdq5MiR6tOnj+bPn6+FCxe2+b75+fk6dOiQ5s2b1+7vjImJUX5+vlauXKna2lqlp6dr1qxZ+ulPfxrcBwsAuKBMXq/Xa3QRX8blcqlHjx6qrq5WYmKi0eUAQMTg+Tcwrg0AGKOzz79MXQMAAAAQdgg6AAAAAMIOQQcAAABA2CHoAAAAAAg7BB0AAAAAYYegAwAAACDsEHQAAAAAhB2CDgAAAICwQ9ABAAAAEHYIOgAAAADCDkEHAAAAQNgh6AAAAAAIOwQdAAAAAGGHoAMAAAAg7BB0AAAAAIQdgg4AAACAsEPQAQAAABB2oowuoDO8Xq8kyeVyGVwJAEQW3/Ou73kYZ9GbAMAYne1NF0XQqampkSSlp6cbXAkARKaamhr16NHD6DJCCr0JAIz1Zb3J5L0Ifk3n8Xh09OhRWa1WmUymr/z1LpdL6enpOnz4sBITE4NQ4cWLaxMY1yYwrs35hdP18Xq9qqmpUVpamsxmZjufi94UPFybwLg2gXFtAgu3a9PZ3nRRjOiYzWb17dv3a3+fxMTEsPjHDQauTWBcm8C4NucXLteHkZyO0ZuCj2sTGNcmMK5NYOF0bTrTm/j1HAAAAICwQ9ABAAAAEHYiIujExsZqyZIlio2NNbqUkMO1CYxrExjX5vy4PugMfk4C49oExrUJjGsTWKRem4tiMwIAAAAA+CoiYkQHAAAAQGQh6AAAAAAIOwQdAAAAAGGHoAMAAAAg7IR90HnqqafUv39/xcXFaeLEidq2bZvRJYWEZcuW6dJLL5XVapXNZtMNN9ygffv2GV1WSHr00UdlMpm0YMECo0sJCUeOHNH3vvc9JScnKz4+XiNGjNDf//53o8synNvt1gMPPKABAwYoPj5egwYN0sMPPyz2e0FH6E3t0Zc6j77UHr2pY5Hem8I66Lz00kvKy8vTkiVLtGPHDo0aNUozZszQsWPHjC7NcJs2bVJubq62bt2qDRs2qKmpSVdffbVqa2uNLi2kbN++Xb/5zW80cuRIo0sJCVVVVcrJyVF0dLT++te/qqioSE888YSSkpKMLs1wy5cv1zPPPKNVq1apuLhYy5cv12OPPaYnn3zS6NIQYuhNHaMvdQ59qT16U2CR3pvCenvpiRMn6tJLL
9WqVaskSR6PR+np6brnnnu0aNEig6sLLcePH5fNZtOmTZs0ZcoUo8sJCadPn9bYsWP19NNP6+c//7lGjx6tlStXGl2WoRYtWqQPP/xQH3zwgdGlhJzrrrtOdrtdzz33nP/YrFmzFB8frxdeeMHAyhBq6E2dQ19qj77UMXpTYJHem8J2RKexsVGFhYWaNm2a/5jZbNa0adP08ccfG1hZaKqurpYk9erVy+BKQkdubq5mzpzZ5mco0q1fv17jx4/X7NmzZbPZNGbMGD377LNGlxUSJk+erIKCApWWlkqSPv30U23ZskXXXnutwZUhlNCbOo++1B59qWP0psAivTdFGV1AsJw4cUJut1t2u73NcbvdrpKSEoOqCk0ej0cLFixQTk6Ohg8fbnQ5IeGPf/yjduzYoe3btxtdSkg5ePCgnnnmGeXl5em+++7T9u3b9aMf/UgxMTGaO3eu0eUZatGiRXK5XMrMzJTFYpHb7dYjjzyiOXPmGF0aQgi9qXPoS+3RlwKjNwUW6b0pbIMOOi83N1d79uzRli1bjC4lJBw+fFjz58/Xhg0bFBcXZ3Q5IcXj8Wj8+PFaunSpJGnMmDHas2ePVq9eHfHN5OWXX9a6dev04osvatiwYdq5c6cWLFigtLS0iL82wFdFX2qLvnR+9KbAIr03hW3QSUlJkcVikdPpbHPc6XQqNTXVoKpCz913362//OUv2rx5s/r27Wt0OSGhsLBQx44d09ixY/3H3G63Nm/erFWrVqmhoUEWi8XACo3jcDiUnZ3d5lhWVpZeeeUVgyoKHT/5yU+0aNEiffe735UkjRgxQmVlZVq2bFlENBN0Dr3py9GX2qMvnR+9KbBI701hu0YnJiZG48aNU0FBgf+Yx+NRQUGBJk2aZGBlocHr9eruu+/Wa6+9po0bN2rAgAFGlxQypk6dqt27d2vnzp3+t/Hjx2vOnDnauXNnRDeTnJycdtu9lpaWql+/fgZVFDrq6upkNrd9SrVYLPJ4PAZVhFBEbwqMvhQYfen86E2BRXpvCtsRHUnKy8vT3LlzNX78eE2YMEErV65UbW2tbrvtNqNLM1xubq5efPFFvfHGG7JaraqoqJAk9ejRQ/Hx8QZXZyyr1dpuTnhCQoKSk5Mjfq74vffeq8mTJ2vp0qW66aabtG3bNq1Zs0Zr1qwxujTDXX/99XrkkUeUkZGhYcOG6ZNPPtGKFSs0b948o0tDiKE3dYy+FBh96fzoTYFFfG/yhrknn3zSm5GR4Y2JifFOmDDBu3XrVqNLCgmSOnxbu3at0aWFpCuuuMI7f/58o8sICW+++aZ3+PDh3tjYWG9mZqZ3zZo1RpcUElwul3f+/PnejIwMb1xcnHfgwIHe+++/39vQ0GB0aQhB9Kb26EtfDX2pLXpTxyK9N4X1fXQAAAAARKawXaMDAAAAIHIRdAAAAACEHYIOAAAAgLBD0AEAAAAQdgg6AAAAAMIOQQcAAABA2CHoAAAAAAg7BB0AAAAAYYegAwAAACDsEHQAAAAAhB2CDgAAAICwQ9ABAAAAEHb+HxgWBv6zVDQZAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, axs = plt.subplots(1, 2, figsize=(10, 10))\n", "axs[0].set_title(\"Loss value on eval set\")\n", "axs[0].plot(eval_metrics_history[\"test_loss\"])\n", "axs[1].set_title(\"Accuracy on eval set\")\n", "axs[1].plot(eval_metrics_history[\"test_accuracy\"])" ] }, { "cell_type": "markdown", "id": "a3f7b0ad-ddfa-4ab3-b56f-6ea99385ff6a", "metadata": {}, "source": [ "## Use Model for Inference\n", "After all that, the product of what we were working for: a trained model we can save and load for inference. For people using LLMs recently, this pattern may look rather familiar: an input sentence tokenized into an array and computed 'next' token-by-token. While many recent LLMs are decoder-only, this was an encoder/decoder architecture with the very specific english-to-spanish pattern baked in.\n", "\n", "We've changed a couple things from the source 'use' function, here - because of the tokenizer used, things like `[start]` and `[end]` are no longer single tokens - instead `[start]` is `[29563, 60] = \"[start\" + \"]\"` and `[end]` is `[58308, 60] = \"[end\" + \"]\"` - thus we start with only a single token `[start` and can't only test on `last_token = \"[end\"]`. Otherwise, the main change here is that the input is assumed a single sentence, rather than batch inference." 
] }, { "cell_type": "code", "execution_count": 23, "id": "e4589706-cfd6-4efb-9975-bfa0df75d4f0", "metadata": {}, "outputs": [], "source": [ "def decode_sequence(input_sentence):\n", "\n", "    input_sentence = custom_standardization(input_sentence)\n", "    tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)\n", "\n", "    # \"[start]\" tokenizes as \"[start\" + \"]\", so seed the target with only \"[start\"\n", "    decoded_sentence = \"[start\"\n", "    for i in range(sequence_length):\n", "        tokenized_target_sentence = tokenize_and_pad(decoded_sentence, tokenizer, sequence_length)[:-1]\n", "        predictions = model(jnp.array([tokenized_input_sentence]), jnp.array([tokenized_target_sentence]))\n", "\n", "        # Greedy decoding: pick the most likely token at position i\n", "        sampled_token_index = np.argmax(predictions[0, i, :]).item(0)\n", "        sampled_token = tokenizer.decode([sampled_token_index])\n", "        decoded_sentence += sampled_token\n", "\n", "        # \"[end]\" spans two tokens, so check the decoded text rather than the last token\n", "        if decoded_sentence[-5:] == \"[end]\":\n", "            break\n", "    return decoded_sentence" ] }, { "cell_type": "code", "execution_count": 24, "id": "554c2f72-0bd3-4ed1-804b-5f1a4cc13851", "metadata": {}, "outputs": [], "source": [ "test_eng_texts = [pair[0] for pair in test_pairs]" ] }, { "cell_type": "code", "execution_count": 25, "id": "c1d6edbb-af89-42c9-90c3-d61612b75da3", "metadata": {}, "outputs": [], "source": [ "test_result_pairs = []\n", "for _ in range(10):\n", "    input_sentence = random.choice(test_eng_texts)\n", "    translated = decode_sequence(input_sentence)\n", "\n", "    test_result_pairs.append(f\"[Input]: {input_sentence} [Translation]: {translated}\")" ] }, { "cell_type": "markdown", "id": "258c2172-5a0f-4dee-9b21-f65433183c62", "metadata": {}, "source": [ "## Test Results\n", "Given the model and the data, not too shabby: the output is definitely Spanish-ish. Though when 'making' friends, please don't confuse 'hacer' (to make) with 'comer' (to eat)."
] }, { "cell_type": "code", "execution_count": 26, "id": "4f0ae018-b7cd-4849-b245-c5c647ad1a95", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Input]: We're both way too busy to help you right now. [Translation]: [start] los dos estamos demasiado para ayudar esta mañana [end]\n", "[Input]: Have you eaten dinner? [Translation]: [start] has comido la cena [end]\n", "[Input]: That is the poet I met in Paris. [Translation]: [start] ese es el poeta que conocí en parís [end]\n", "[Input]: It doesn't make sense to me. [Translation]: [start] no me hace falta sentido [end]\n", "[Input]: We're happy. [Translation]: [start] estamos felices [end]\n", "[Input]: What about me? [Translation]: [start] de qué me [end]\n", "[Input]: Make a decision and make it with the confidence that you are right. [Translation]: [start] haz una decisión y tomará la confianza en el confian [end]\n", "[Input]: Put some salt on your meat. [Translation]: [start] ponte algo de sal [end]\n", "[Input]: Tom's deaf. [Translation]: [start] tom es sordo [end]\n", "[Input]: How old are your brothers and sisters? [Translation]: [start] qué edad son tus hermanos [end]\n" ] } ], "source": [ "for i in test_result_pairs:\n", " print(i)" ] }, { "cell_type": "markdown", "id": "5ca18d4c-b3c0-4abb-b5fc-fc96a2264b53", "metadata": {}, "source": [ "Example output from the above cell:\n", "\n", " [Input]: We're going to have a baby. [Translation]: [start] nosotros vamos a tener un bebé [end]\n", " [Input]: You drive too fast. [Translation]: [start] conducís demasiado rápido [end]\n", " [Input]: Let me know if there's anything I can do. [Translation]: [start] déjame saber si hay cualquier cosa que yo pueda hacer [end]\n", " [Input]: Let's go to the kitchen. [Translation]: [start] vayamos a la cocina [end]\n", " [Input]: Tom gasped. [Translation]: [start] tom se quedó sin aliento [end]\n", " [Input]: I was just hanging out with some of my friends. 
[Translation]: [start] estaba escquieto con algunos de mi amigos [end]\n", " [Input]: Tom is in the bathroom. [Translation]: [start] tom está en el cuarto de baño [end]\n", " [Input]: I feel safe here. [Translation]: [start] me siento segura [end]\n", " [Input]: I'm going to need you later. [Translation]: [start] me voy a necesitar después [end]\n", " [Input]: A party is a good place to make friends with other people. [Translation]: [start] una fiesta es un buen lugar de comer amigos con otras personas [end]" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }