{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Kzqlx7fpXRnJ" }, "source": [ "# Part 3: Train a diffusion model for image generation\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)\n", "\n", "This tutorial guides you through developing and training a simple diffusion model with the [U-Net architecture](https://en.wikipedia.org/wiki/U-Net) for image generation, using JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io). It builds upon the previous tutorial, [Variational autoencoder (VAE) and debugging in JAX](https://jax-ai-stack.readthedocs.io/en/latest/digits_vae.html), which focuses on training a simpler generative model, the VAE.\n", "\n", "In this tutorial, you'll learn how to:\n", "\n", "- Load and preprocess the dataset\n", "- Define the diffusion model with Flax\n", "- Create the loss and training functions\n", "- Train the model (with Google Colab’s Cloud TPU v2)\n", "- Visualize and track the model’s progress\n", "\n", "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers building neural networks with Flax, Optax and JAX." ] }, { "cell_type": "markdown", "metadata": { "id": "gwaaMmjXt7n7" }, "source": [ "## Setup\n", "\n", "JAX for AI (the stack) installation is covered in the [installation guide](https://docs.jaxstack.ai/en/latest/install.html). 
JAX (the library) installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site.\n", "\n", "Start by importing JAX, JAX NumPy, Flax NNX, Optax, Matplotlib and scikit-learn, along with a few standard-library helpers:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dVVACvmDuDCM" }, "outputs": [], "source": [ "import jax\n", "import optax\n", "from flax import nnx\n", "import jax.numpy as jnp\n", "from functools import partial\n", "import matplotlib.pyplot as plt\n", "from sklearn.datasets import load_digits\n", "from typing import Tuple, Callable, List, Optional\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": { "id": "tQ5KGMyrYG2H" }, "source": [ "**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.\n", "\n", "Check the available JAX devices, or [`jax.Device`s](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). 
On a Colab TPU v2 runtime, the output of the cell below shows a list of 8 (eight) devices:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ldmtemzPBO5z", "outputId": "d21720a2-65cd-4a5c-ef86-3a0912e36c34" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check the available JAX devices.\n", "jax.devices()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading and preprocessing the data\n", "\n", "We'll use the self-contained [scikit-learn `digits` dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html), which is small enough for fast experimentation, to demonstrate diffusion model training. For simplicity, we'll focus on generating only the digit '1' (one).\n", "\n", "Preparing the data involves the following steps:\n", "\n", "1. Loading the dataset\n", "2. Filtering the images of '1' (one)\n", "3. Normalizing pixel values\n", "4. Converting the data into `jax.Array`s\n", "5. 
Reshaping the data, and splitting it into training and test sets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 303 }, "id": "jNizSH6uuXY4", "outputId": "112723a1-fd36-46b2-946d-6d789f5a33ed" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training set size: 172\n", "Test set size: 10\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPwAAAD7CAYAAABOrvnfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJXklEQVR4nO1da3PiSLY8trHxu3s8M7sxu/v/f9nG3t2J6e6Zfti8uR86Up0keUoCIxBGGaEAYxDi6GSdZ1WdLZfLZfTo0eMkcH7oC+jRo8f+0BO+R48TQk/4Hj1OCD3he/Q4IfSE79HjhNATvkePE0JP+B49Tgg94Xv0OCH0hO/R44QwaPrGs7Oz2v+fnZ3FxcVFDAaDuLq6ipubm7i/v493797F09NT/PLLL/H09BTv37+Pu7u7uL6+jsvLy7i4uIiIiOVyGYvFIubzecxmsxiPxzEajeLr16/x+fPn+PjxY3z69Ck+ffoUf/31Vzw/P8doNIrJZBLz+Tzm83ksl8vYR/Pgtt9RJ8cmn1c5397exsPDQzw9PcXf//73+O233+K3336Lv//97/H09BSPj49xd3cXw+EwLi8vYzAYxPn597F+Pp/HdDqN5+fn+Pz5c3z48CH+97//xb///e/4z3/+E//3f/8Xf/zxR3z58iWen59jPB7HbDaL+Xwei8Xi1bLelxwht/Pz8+qADCHH4XAYd3d38fDwEO/fv4+ffvopfv755/jpp5/i8fExHh4e4u7uLm5vb+P29rZ6fnNzE1dXV5Vcl8tlzGazmEwmlQ6Px+Pq4NdHo1G8vLzEy8tLpc+j0Sien5/j27dv8eXLl/jrr7/i8+fP8fXr1+oeTKfTlXswn88byeEgFn6xWKy9hhuC5/zYFK8l06nD3YOIqAaH7DNdRxNj5XB+fl7JRN/j/s7ew6/za/p+lnPd922LVgm/6eiNEVhHYh6ReaR2N8MJ8tRQ+v2bKqrKuaSwxwT9Tap72f9xOJngvPrIn49YHUj0O9359Zpfg9YtPJMeLru+HrF+A5jkFxcXa+R3N+SYFXCXcLJVqLwgU/5f6ag7fxdR+n3OuOjhjFBJF/n/fF73XXo+vk5+7bVoHMNvAo2jXVzN8V+mUIivEGux4PFYUsZTnQiYKYezRJkXlSk3Dr6nxyDnzOIqwVnfzs/P13SQ/3bkjYhKPtBPHjQGg0HM5/O18+iB//F5d2HYWiG8AxQks/I8kkFYEd8HhuFwGFdXV3F1dbWSdFKhL5fL6vM45zEoYxtw8SP+ZouRhU5q2VS5F4tF9Qh0XdaZB8lEHgwGcXl5GZeXlzEcDqu/oYP4H/RQiQmdgw4PBoMqEQ2iLxaLWCwWcXl5GbPZLKbT6do1qG4jsc2/ZRvsnPDOsuMHwqrrgRvA2ePFYlH9yNlsFqPRqCI9MqKslHxuKKIOAKcO5w2x/JXgjgyz2WzFO1PZdk3WJcvuSA5SI2t/fX0d19fXMRwOq7+Z/EpS1jfoPntDPNgiu46qFD/nc7On0HkLzxaAia8uPRTs8vJypbRxdnYWi8UixuPxirCHw2GMRqPKRWLBLhaLFaGeOjQRpa5
5ZuFUoXHowK33smsyZ9KrFwNyQ6eY5NfX13FzcxM3NzdVKY7/B13kEIAH0ez7mbRcisYAgJK0ehH6+W2I3xrheVRTomsNF4SH4PHj5vN5XFxcxHK5jMlkEjc3N5WgcaO0HnkslmfXaJKYU4VnpWOLzorGVg8uKGTNMu/6AMuDHf9OJrwjOpMdNXcmO2TDOSU2Opy009ic38vcmM/nMZlMKoPGnlfnLbwqBsjJVh4/goV/fn5eER4WHiMsRuOrq6uVppvMrX9rcNYj+xuvZcmqiKi18mztLy8v1+4flLtrss6y8mzhmeyw7iA2N9eohedBENYd8oyISgfVo5jNZinh2aV3OSvN4ONxE7m3lqXnxwhv5dnCa8cTiA7ij0ajlVgKcTwOHkjgquLcbz2W5wQcv5aViljpkStx8bpae47jLy4uYj6fV+diGXdJ1q7Epd4MiMVW21l6DAYZ2ZXw8FxhuAaDQUwmkzWyz2azKnmHDj2cW5OnnbXwTHqN9dyAgBuBGzAYDKr/z2YzS3aXvJvP55V176ICvgZ8szWOU1K715js/BpbHVcaqitHcZ6my7Jmeejv4jievUiN6/E63HhXOnPGRuUfETZB50p0buDG+fmxCVp16V3G3g0AfBNA+svLy+rz0+l0xcXhTD1bHRaqZqLfEpzVxut4dM0cWQyvA0CpRKeewjZKt29kA5/zbDSu18w9u9lajmMZqMcDA4T3LBaLlXNk9+u1Fl2xlxhea+8utsbNYKsSEVW9Ussm/DfqmIiB2Mq/FZTccya4KqISVy0ND7iOGPrdSnz8H4/HMLjy73C1eP47q1TUudmQBcsXgH5qGBCxnvNyTWuvkfHeGm8ifsTxCpdQgSBdPOkSSoiD2LVXK38sCqnI3HZngZ1r7iwIJ4D4O/S5egdZSHEMcINa1nuQydXJsIkcMLiqpedKFief2XhpDwuQVaRK2CvhGe4iWRCaXMpKKbDwaFxwrj3O/dbIzqTX2rK6qK4kxHCKxN/rrqnutS5B5civuwExI7jmQbKwRmvx7m+uwXPibjKZxHQ6rfQapVC1/NvgINNjM7IDzhXVurA7slbHritjHUqW3WXVtZ7OGWWVDWPX7mNX4EIPzUs4r6gUU5fInn0ng4kOkut8eRCfBwQNhze9P3ux8Fks7eL4LJ5k5eaM6mQyqYShh3PrjwlqidQaaUYdpObkkpKdu7YAl1A9Nlk1RaZbzoKXiK7nrINLVmeE1wUzYOldfL8pDu7Sly5eLRqX4pj0cH+QvNNkio6yx6TMWTKtrm6eNc5ANhH1VZQswXpMcIMmh41ZHL9NnO7gZKxNNkp2Pti15xWdtr0ne8nS8/OSJXGxo1p4NEigNspkn0wmVSOExqrHaOEBTZqplXftsJpRVksW8cPz0nkIrl/C9VEwui5bjeFLMbsLe9zAUUKm99x4xtYdpH95eaksPMfzblmxbYjfWgzvyFx34H2MLHHn6qSs5C4bncVTxwCXsFMl1cOVkjR2Z7mXiO88smNz/UtZeiQ76+riTeBkxV2mnGCeTqfpendMdjfjVL+vCfa24k3247MfERFrCs5k1+mKOJxyZyN1l+GyyiW3ftMyUoQnO09u0h7v7H5tUx7aJ5To6sJzuMPPWdY4D9fNnS45r8jJEWTGI6x7KXHnZppuilbmw7tMr5st14T46tIr2bn/eDweV9NmtTECdVC9ri4jS9oxebPau2agszjUDcTqerpZjnXNIYdGllxj0msDl3qKmKrNpG+CzH2HRWed1RVscTDpOQld8oaboPXpschG4pEbC3jSvw4CEauTQpx1v7m5qYSH0RILZbDVw7nPz1dn0B1LXF+qWkB5XYedkl2ReVx8T9x9cveKz9klOOuelXjZU+Qyb1PXvhSrM9k5SYeYXZesxoH3awy/Lel3Svjl0jcXgOxuZhCPXohV+IfwTcJNAdHZImHRAJ7xNJlMKtJr3z5fb9eUlJG580x658o7dxQohVdMcm3+0Pv
lrD3O3zVoKORWtbm5uVmZPKOVjVI46Fx511DD8TrIjjXp+VGTd47wB7XwTHZVKLXqbJVdN5FaYbbwILwuxAAX6eXlZWXDBZ5Yg2tySasuw5HdEd7F6xmy+JKVVC2TdjQ6i9M1ONmxLoHoPAeeF7jIyrxZ4rOUiWei8/Ht27dq4wkcvCmFi+O3Tdq1PluOLbxzbZT0qkTswuImaTyD3VMwZ56bTlCqyzLOEcfn2mfNN0r8UsyusbqSnuNMLQ/xfdpFu2dbcN6RuvNMeLeiDbv0dZl61fdSnZ13mwHZsbMMrDwn8Disek3upPUY3sUyzsqrMvECGRE/4nhd1oqtu641pgs2aNmp6404alE4JndtoVmzCCecMndeLbyWj5xL77ab6oIsXWWjLvnLi13wMmpuxZm67HzJwnO8zu47XHi8xu585tJvg1az9E65HNHV4uP/l5eXK8IF6YfD4YpQmexZSykTHmTvqmV3CuVIzw0kEbGmkM664zdrXiVz691A7Cx7Vy28uvRnZ+urK7nyLq9Vpy59BhfKuuYabp/VhJ3W5DPPF9+3Kfay4k2WDGJBsOs4mUyqltnl8sf6dBGrSTxOvOhEGm3E4XnyiOezWKzLKBG/Sa8BD77s2pemZer/XBm1K7LLLDvH71ySc1n6bAkrLZNmcBbekV4P1OPbDJ9aj+Gz8kT24zHSYQEMrGIb8aMFNCuvuN5xJrzLKnfVyjNcPZmf1xHdWYbMlXdzskv9El2SnSO7kt51bOp685qhh/5t0sugFSnVcbXkeE8WOu3Kk2p1qykVgibtxuNxXF1dVbHL8/NztYX0+fn3xSuvr6/XJnzA2mu3mdtIwE08yFz7YyB/RD53ve7a1d1scugA6ZSuSzKrI7lbU0E9RF5DgJOj2yTt2INVVx5kdxa9FDq9Bq3E8BwnqjuPBJveiKurq/j69WvVTLJcLitPgF0rTbQx6bUxB4TnBKBzZyO6nanPyMU9C+49dfmUJrF6Rnr3vYdEnUV304jdJhTZunVNyZ4l7LSphq083uO8ql3nSFprvMncm8lkUr2fM84YfUG82Wy2UmbTm6DnyPrs3UYVGpt2mewOjnSaLc8+47LvrrmGZaYtnThfV6AhjvYqMNFxMMHZQOg2UnU1eMBZdyY7N9hwdh6k116UtpKhrbn0mg2GAOoy0Jx5R5lE9/iCu1X9CHHV8H6MnlpzhkIjc88lOlx/1+CsKx7V0jurrsm6rLFGV1iZz+f2XF2Ds/DZ7EF26dWyl9acLxEdz1nHuBynRM866bJ4fVcy31sdfjab1bpFZ2dnK7EPRsbr6+u4u7uztXltqtD1xUF4rduj+46VpCuKzJ6SAxMcz905+DmHNNoU4ursbVqZtqDVCzerUBcLcZNndJ340uSjiB/ydR2l2mwD4qOTzrXO1nlrr0Hru8dCAEqsTFnxOmfyUZ6LWCX5YDCoyJE1VSjhdVWcuhbUY4daCPVyeABkK7Orum/b0ApFkxg+WxnIrR3QZPIR63IWv3NTjWub1cRyW4Ps3lprmfgZ8D+1Qjxg8E3hfAHXozWehyVz2dfSyH0oZGEPoJ6Nm6NdUlCOzV1JLqu34xxdQ0Z613QDHXEr3GStyXVkZ/3Oys+w5uzKcwbfxe/8PbtCqy49rDmXvzgmVGjpDUoXERWRNYbnKa8uY4+4FMQfj8dpq2SX3HoGX59OfVVljfghR/5tLsmXlec4ZDg2OEuP17gdOWJ1r3ZN/OlrgJOj9i04smdHtqJNW2g9aYe/NTEWsSpYxDusoPg8LDvXTGHh2SvAd2SkB9nrXLYuEN9ZLe0Yw+/gedt1fd+aTc4aapxlP7RMmkItM2SncMlO9oD4UJ3QQdMlQLWxTBe40LnuR+/SA6WYPWJdufUHc0JOLTwe4Tnw+zlrDyuPPbe7THbAkV0TUNnKtNwowoNixPqSVlknHd7bFXlsg6YlNLbOvHAK55z0866/xC05zdNgdcELTpoerYUHnJAy0rv34b1
aZwfhl8tlZekx0kasDhA4D26EW8mki5l6wJWZdPKHm0Og7qsbTDPCH1NmvoS6PIazzJqZn81mVZs3yzAjuyM6J+p0gQue637UFh5uEP+9WPxYZqrus9oFx4sWgKyLxSKurq4qd40Jf3l5uZKcwk1p0lThXLhDwGWcWQ7cYMTk59VXtQpRR3bn2h8L+etKmS65pm3e3NF5dnZW6RGfXwnvrDqX4niBi1JZLkva7RJ7XZe+SVyvigaFZ4sGsuJ9UG6cG//D37DwWRdVifSHgks+aW5Cuwp1lmCpWYSVS4kOUmTK1lXy63XpQJXF3OyKZysHsRz1PC5mZwuvlh1k1z56LoW2hdYJn5EGhALptV4fsdo2yyMnFBvEZIvGAwTHrlgRh0mhNxbXgsHj0IrtSkt17rzzXjhhpVZbm5JY4TRb7wh1DFAjUyIrD5aI4yE/PJaSdby8tM6OYzeey3Ft9s4r9pK0U+upB0NvCjeGaAZUrZirs+Kc0+nUWkGOdzVR1ZZb1URW7jUeADVph+eulsxgMmdWvlR3PyaSa/jCxGKyslvNcTwn8GCU+Nxa2uTmJT6fPmp3o2tyagutEV5JrfVj1zxS50q5GAxz3UFYnsrIrr7OhOIFCnEOtQR4fkjowKjLWnEij+vy2YAaESnJdQAoDXqHlksJmizW3+UmCrEhuby8XGs+0pmabkDR50622dyEfcmzFcJn8Sc/Zoeb3cSWOBN6xI/NAdmdx0Awm83WZkTxbDrOF9Rlqg+l7FA6brDRgbREdKeMWXedKu+xQAdtvo/8G9VrZPJnK80otNafea/Kg+z9+0DrFl5rx84audZGvJ9nymXLBmvszq23GFHn83l1Hp55N5lMqtzB+fl5Gk91CWp166oeeG/mimbddurCdk0OAAZrR3YND7mN2M0YrFtWyhHcGSye16FhF5eLdXBtuzq0c8KXSklcMnIxNHeLsWXnnWawIg6/P/sexF2w2li3HtNux+NxNSDg867N8RCxvAMrMa4Hz7Okj3u/kl2tm6sJdxVMdjzyPXdzBVw+SKcI60IgLIPMmLl18nSzC26y4d+Aa+ZBBf/bJXZK+MyF0bpxtuMrE11XJ9HliFCe01lP2o2Hsh0SdyiRoI8Zgh8MBivJFLiAEetznvcJ/V5NurHlKmXc9TOq+JxY0t7wrpMecEk69tp0sQ8tySF+v7q6stOEASX7YrGorDZCR172Wg0Mck2c9d9XTN+ahWcXmxf950MXn2QXSHveQWC49W56oy5qAIGiUwoZWaw0AmKfn5+vLQtcIn2bwHc0aZaBQuKRiaoDAz9Xd9ateqOTaLpu7dWdhw6C9M591/XmYFR4IRDNoHPJGN/H3w3Z8gYXd3d3awYGyBKnbVn51mJ4jmd4FZq7u7tqOx+3iKC6/ppx11jJLWLAW0dHNCM8Bhu2ckr6iP1a+Mw9x7WB7C7+VkVVd97FsW4+/DEk7ZiIPNCB8NAfR/qSS581w3BfA3QL362Ex3F7e1vli9hr4mrTbLa6KEsbcm/VwkPQsMQqBF78n5/r7q8MCFvP7Sw8boa69LgBGHFBeIz2KNOwteRY+BBQ685kV3dcXXLuSmQiu7hdie5c2i4OAHpNbCXh1nOLNZNb56aXVpBl/VMPEgTGvgocu+PAd0L+6CfRRHRbaK0s5xJ3SnwVBtav4xjcdYnx92irKRP/8vIyIqKK5fEdd3d3K4sOKOHZErjk3b6J7yx8lnBTEnPtVwcCDQNK8T+uo+vQnAWsvA5obN1dPO/cepe8g37qpCY1QvoIopdavNvAzgifdYhx7ditG8/k1507tVuOlRHg87o2Uy638cDCLj3OUSL8vl1bjgszl1xdU+eS8sKhGak1Dj3UwPZaZOU5zthnzTYay/NAgM8gQac9D8iua1mOy8xZj35EedLPrtFqHR5o0lyjLhDPXIr44S6xokLYSnhOAuJaMKqD8Lih7CWwApRq0vuCI3sp8aR/I+sMxSzVmB3YuzoWaDz
vQiF27bnvnXVRZYr8juo14DpIMyIf0lNqtZfedRyViK/b/sDdiVhf9gpC40qAzoaDS48bjVo+CB0R1fey1ec4HgNNF1x6texwC3VXE97CCDJwhOffpCXVY4bmG1zCEwM+lqDi3g6uKrHF1/Za7fx016FhmAu5Mk+rDexl8oyDa15QS4+bwA0VHJNG/KgGgOic8cdnQVi22hHfw4HhcLiyv1fWeHGIbHUWu+vkIchvOByuTMl8eXmpZnzhkedeM9l1MNZ2XbwPj8cQz0d4K4+Bj2ezqSzd+gKuVAb5uGqKhhG8/XPd8lZtYS8r3mQ/oER6uPUcx2vmOGK1f57deu69dw0O5+fncXV1tUZ2jdtceWsf4HhULQTW/mNSYtC7vr6u5lxDfhzH6pLIGdmZ6LhXxwaWoVp4ZOeZ7FxZ0jBRZ2binBx2sp7od+mKOOxhZMnhNrDXBTCy0pZr0uHRVTP16orqgMG99BFhlRcW0SVnMgu/L1deFVUV6OLioioxQkkHg8HKRgeoeIDwg8GgKk1qa2dm1fdRJmoLdTLkmJ3X/2PjoWVil9PQtl7WF50Syx6Fzoffl2HZm0ufxVURnrTaIquxkrqjLjuqDQxYDgsjON8QdvW6krTD93F5SZsz8ByekS6DjBwIl+ZYubQzkr0izb8cMzLLCxlpCVlr6OxtqvwjfiSGI37sq8DlP90pVhewdA0+bRC/1TXtABaE/i+iPBlB3XqHOpd0uVxWngKOy8vLldJbqetM8wb7gpbK+DpVZtne4+zWQylRjizJz5H9WElfyoVgNSWN37UZh3cxZqOCTj6XZ1K94qy/a2N+kxY+y3RzvV5n02XulH5WLZMmWfj/OK+OxnjUTOqhYng8Z8LO5z/6wzlTzx4LHyhPIo7n38Ky0sPh2Ejv4njNibDbjX6MrK+hlMzl74jwOyg5o7IvskfsOYbPXtNYW8t23LigcFYoU1jtU4blw8H9zPP5vFoJR2/soRpSVGG5oYRDElVMeASc8My8LJWf5j+OCTxo4m8+HOldB6PrXHRJNpevcoOM9gXs06DsvSxX+kEaTyrxMyJrQi57D1s1ZOpZ6Pw5HhwOYeEBl3xyLqpTIo3XOaRSl54f9flbgBJTyaiDgJLbvVc9JZybv8/dl32TnHGwOnwGjR23dTf1PXxTmED6nXwTOCFzKMLrtZeIr2VLp5TZb1CZR6x3kh0rnPzwnAdHlR2/lsm06fe7+6TXoh5CGzgo4TPXXpXLkb2kgCULz48RP+JaVXh4F+oWdgElpXPuJX+mx3fUyW4bMut59Ln7bN17do2zZa8FPXqcDI5vdkSPHj22Rk/4Hj1OCD3he/Q4IfSE79HjhNATvkePE0JP+B49Tgg94Xv0OCH0hO/R44TQE75HjxNCT/gePU4IPeF79Dgh9ITv0eOE0BO+R48TQuPpsU2miPIUVrfsNO+keXd3F/f39/Hw8BA//fRT/PLLL/Hrr7/Gr7/+Gk9PT/H4+Bi3t7fVzinL5XJltxAs0vj8/Bxfv36Nv/76Kz58+BB//PFHfPjwIT58+BCfP3+Ob9++xfPz89oigroe+KaLVG47yXCXU211Ki/vpsubetzd3cXj42M8PT3F3/72t/jHP/4R//rXv+K3336LX3/9Nd6/fx/39/fVKre8MAhWgnl5eYmvX7/G58+f48OHD/G///0v/vvf/8bvv/8ef/zxR/z555/x5cuXatVcrK2n23e5eeWHlmN2bpYrNiu5v7+Pn376KX799df45z//Gf/617/iH//4R/ztb3+Lp6eneHh4WNkyjfdIxBz78Xgcz8/P8eXLl/j06VN8/PgxPn78GH/++Wd8/vw5vnz5Uh2fP3+uXv/69Wuly9BhnrvfBHu18KycbgmlbJELNz9e35+9N3vtLSJbxKJuDYHS66U1Ct4ySouwRKyuiKSfKy3Swufe5npee55WCO+WnFLF4SWRdSBwilsncJwju47Sa28BqpS8DiC8LX5tmzXnN/3
MMcpaZci6yRZfF07VxVPrFm3R15oYrl2gNQvvyMor0vJSv27L3Gwl2ghPdqfwDm9lvY/st/MagCpn3dHUKSnOqctbOUvmFrhkAhwbnB6p3HiLZ5al09MSwfV79Rr0uft7G7R6V/jms3B4+Wld61utkBKfN0pwoy2+t3RNimMdBFQxlOy8Oy/vqJvJN2J9HbuStVIL5zyApvfl0GgiS7d8utv8xA2mpe/k73Vk3+Xg2ep20So8LPLP+3axMrJVUsFFRLWwpFqp17qqxwwlXbYNN2+bpMTPSM3n1+9wrq8bDI4BGdl1QxTeqJR3KNYBwXlR7nv4eUme/NnXYueEz1xMFZxTRN6p020GuVwu1zwBR/7MBX0ryJTEbXeM6giyxrojammnmWyBT1VKtfL6Xj6n/o6uILPsvKEky5C3oGLvCbkSPSJibSn0zINS6+4GjG3RioXnCytZHQhNN+3TEZNJv1gs0vjUhQBvDaowzrJrGRSlUCa+s/J8fmfVnSK7qguT3l1zl0Io1VU8qoHSwRNy5N1lnauvsgHps0HbyXWXln5nhHcXxooCskNQvFGfc5U0LgLR8YibkiliJpguKdumUOXU/AjvegrZgug42Cop4TNLn1kfJQk2ruTBtktWPAP/LmfdmewsR+edMuFZPqWeA34f/ubriliP47eV615ceggBgoFCqluvrpFLyM3n8+qGYFsofi+ugR8jjpvoEbklch4UlPT29rZqclILD6uUZez1u5rEmXyt2f+6hCxU4fhdvVHnlWpCVI1VNnDydfCAo8nokgHbFK0m7dQCsULq/tsqNP48j5DYsZOtvIuXnIDqSN/VQSFz/bLEkrNIapkg66aJzszCOxzTphfOa1ILPxwOV1x59UxZbzN95O9iZJa7Lp7fFq3F8E5JmfSuXOTqm0r4LJ50bpDbSWQXbZ37RCZHHkQ5sQTrzgQvWffMCjnZlLZHynZy6aqMnYV1HpMmmV0IigSoKysrQZ2XBNLr33ydu4rjW91qit0Urb1DmEp6587w+Vg4WTY0Yn3DwGMjekReEgNZOQnKcxR0vgJcexfDO6+I5eNk6DatdH87mXdB9pkVrfOa1Mq7uD2rUtRtWaaDQxNPahu0EsPzI36Iawhh8mcukRPCxcVFFb87V9RZntduCLhvOLJnAyYU8O7uLu7u7uLh4SHu7++rg+N4KKqL4Rk8ociRWbdUzvZRL5H9ELJXsqvBaJKw07yTaxxzlv0117sr7MWlZ0Fk3V9ZIwifk//nkkmAWiG1QF0meoQvvWXNNFBKnn34+PhYHff399VAwJaea/EuDMr2NgfRQW6efej2pmfZA4eSfZ1lV4PEZIfs2ItqkvzMwqPSNe6a5Iy9uPTq1m8Su8PyoH6p/1MrDyVl5Szty91V8mcxJbuWHKPf3t7G/f19PD4+xrt37+Ldu3crhNd6vE7dVBliyqUSHWSfTqcV2TEFVqfCMtkP7do7sjvL7noZMFjywJmFR84AveZ6d42dEj5LhOjoWXLpHYlxnqxpgaH7e6t1yqx8V0jvvCPE67Dq7KLzAesOwsO1hyuvJSWtikSElR1bcJCdiT4ej1fmvuuaA4f0qpxOZjkRR3bOhbC3pBUP12jDj6/ZfnqX2EunnUuIIE7i17L+4yweypIaTGrn1nfVumfWhxNImQLiNRD+8fExHh4eVtxPjt+zTruIH4MmD5LOsjvr7hYYyax827J0cs0MEXudGBDZldfQiL2kunD0Ndi1rFqfPBMRqesEy1Xqhc/IrkmmiFXr7mJ3zTgDhyZ9lkhSheSGGsTrrIgPDw8rMTz+B0vEk5dcWc5l4dmd50e29Ex8jef3TXQnTzxX4+PIjkGROxVdA1MWwzt9rfvdOiCqR7RL2bWWpdfXXAktq6ez4mfQJBALx5WQXNx+aKJH1JNdlZHjSrbyeM5uPGfmXTdYlp2P8Jl5Jj0fGrdr/M4DbJsyz5JymdekzWAYGLVZiUmuTTc8YcZ9dwZneFR
Oqqe7CI1aT9o591wTb/q+EkpNNNuQ/JCkd9aHOwizWjBbHFh2rbm7ONPNPQCyUqYm7jSBByvuXPh9elJOx/TRWXZun+WYXfMj2XwE9lC14pHBkdzlTlSXd4FWCe+gsZW+pnDkVjddD4CfdxWlbLF2znHXHMftruyWdYO5sCliXc7OrdfkZ6ac+66ElNx29SZ1INV2ZPWc1Lq7qdyZTB2cp6PExiMWqYz4ocuv9ZT2TvhNRnzMjMuSb04Ju+Kq18G5m6qMzo13jTUumaQuJ8fr6k2pxckabdTCM9lLSdF9Jel04ORDk8Xswmu9HbkQzpFwpUMnerFlL7nz2aDqZKuDqnpg26JVwmcXVnLL9bM84rGAXJNHVnJzNwP/36RksiuoRXLTW9nyIEnnYvbMEtVNgXVQmWdkV1mz9TlEgs5l39llZ6LrIMo9DSA8Vzs4J+LmI2RzEkpw3qp2LmrzEhNfz7MJWrfwpVg6K5GpxcGjI7sjfR3Z8bp+1z7gEktcetPpmFxj5yy8WiBNKGXlIr4OpzzOnXdKmLnvh5AlnivZOZxxrbIsOx4wueIBGWtpk8OkJjko1W0eVFWXudLBeyfUGccm2KtL7+K8ugQbFFMtTR3hl8tlmjk9hFVXuKy8lt44ieSsOrvyjuyl2DJTmMzdZBfelTozj61tObssvJtF6JppNNnJVh6WHu91rjw37DTJztflSbTioa79LmS6c8K7i9nEwruRkN/vYh0+MoGodT+kW19nmbhHPpvX7spEbl21CO/6qYfj7oObBZe59Pu28CpD59LDjedyG5MdnhJXN9S9z9z4bZttHAdUrqzXGGgj1st226A1C88KoNY8OxzpQUYVEAuKs/MgPcNVBroAZ51cFpnj+tI8BK2pLxbriybyIMCyxfvV63LJ0ixXckg4K+/i96xt1tXadZELbZ8thYpsRJoYwSxZygMqf25btO7SZxbEWYfM0td5B02VT8neBdce18Fk5Hq8a0iK8K73xcVF9RxAaINjsVikk2XwHI9Nwq8uAb8HcnS1d+2u08VD3NyOTerqPJDWvT/TbafTTt7byL8VwmcjWmbV9W/+zCZ1Tf1sk+vqCiAH/Xu5XK4kzdDGig01odxwBa+urmI2m625nTqAqEegsmuijPpePU9XUadTrI+87gKgunZ29mOR1U2+P7sOTgKWrnUbg9V6WY7jQWeVNLNeZz2cK5XFUFwu4mtyzw8FRybO1o7H4xgMBjEej+Pl5WUlIYfBYDKZ2I094Ibi/TwIYDCFkjrvqvQ6/tcl6KDjEo9u9t9kMlmppzsvZz6fr4QIg8Fg5fcjccdeWhNoKOI8uiY1/qbYKeEdmZjMrMwuJiy55tmPbjISZmGCu+59gN0+VSooH2d++ZrZ8mC7bF0u2XWD6UCwWCxiMBisXZcqO//PPXYFTpZa6hoMBtUkn9FoVCXhQFIMBldXVzGZTOwCqyxDGLGrq6uIWF/WqgQmOicb9ciSgtuSfi+NN86Fz5o51GWM8M0VsFg6KrrJIO44NJxVn06ncXFxEePxuHITdVoqXPlv376trY+um3vwdFgdDFgOnMjjR77WroHDPR3ItXQ7mUzW4np20xeLRUyn07VVlPm51vSvr69XciXD4bBKkmYue50ea55Byd85C6/IboQjPA8K/HknIBeTKtmbJgP3DVVUtkjT6bT6n1qb0WgULy8v8fz8HDc3N/Hly5eV8p2bXMMDAsh/e3u7kr0/OzurvIk6t93Ji+/XIaDe0nw+j/Pz8yr2nkwm1XvdtS6Xy5hOpzEajVY8I03k8f+ur6+rchlbZ+giX5OD02Wuuuisxrrmnk0GgNaSdkoste7sdrnsrworI76Le9y1ZNnlQ5Gen7OlwGscw0MBnKXmZZRRVkJDCRN/MpnEzc3NilUvTZM9BrjQiD0jtuAaonBoNB6P1zZD0TCJm3em02l1zzg3MhgMqioID+wR+dTdrN+fiZ/F9TjfJjq8l6SdltBc44bL0kf
4JpVSksN9f9fdepedn81m6QQQHgC0O+/u7i5dXw5KiM8im8/Z5S7IZRvwfZ3P51YfeFYf6yInPRHXK+HhKU0mk4rsFxcXcXV1VQ0ASOzhOkqkbBK/w7I3bY9ugr102kWsu9gQdlMy1mUtXdzkaplNrnUfULePKwqwEnq4AYAVEgTP5IfPDIfDtTXnVG7HBE00np2dVaQHtBIUEdVvR+jE8kS+BMRHHkXJDpmrEXNQHXWuvSbseL3BzsfwEeWVafi1TchYl5F3Qj907O7A1+RcUx7VM2sA68QtmPoeXY0VZC9VRboiozpoTkRJz+/jx4gfgwASorDWeD4cDlcms8BAnZ+fV4Os/k91XK2wM0yZ0dJq1C6w9/nw2yJLuNV5Btn/u6LQztJz7OlKkGwRVBHZxXcLS9YtH33sUNLz6/wYsbpYJ+TCAygSp9fX15W8EBLBvYfVV+uuMXwd3D0uWfNOluUiypvludcVLsGluQFWXIzq23gOh4L+RjfSR6wOAGiTxfth7UFsXmDSzb7ikCq7libI7u++oTJ0uRH3voioYm/kNLikNxwOV3QMg+r19XWMx+M1ub528HQ6Xfd7N8FeLby6KjqK1bnqbAV1RRA84v/uZncdau3x6AYA/k0oQymx3aKTXArFd2iSKUPp/4ciesS6NXVkzw7XQYcWZR0cI77/zqurqxiPx5X7vwuyl3JNu0Srq9byc0fuLOmmrgzfoBLZnVt1LGQHStfLis3Wnn+7ykSXomoqHxdHuv/rc36NY+t9QL8HpOf8iOoSW3TOdzivKOK7Hr+8vMTt7W1Fdjc9uxRe6rWWBiM2cu43borWLbyz5KVeYYa7Qdo6WZq2mSUMjw0lt05lA0UB2VVuKqPS+RmlZBM/f20WeVuodwTwwOjkhVmGOqvOLUAR8b1nHu48h0zsMdVVOXTAdfcH52DS78ID2HvSTsmeNRQoMuVWy+ZG2GMlegZWbv29THrIh5ubmjQ4cRuqG5S1u9ElmpyXtk+56GtMepQ98YjuPJDerTITEVVH3Gg0WtthR70n/c36+1mXNyH9a7EXwmfWvM5tBJxw2O0pKfFbJju/hkcX/mSDZJ1LnxHdtTdrq7M7dxdIj98GOejvcqEh3oPJN5oMdbKs+616n5wHy6TflW63urfcLt6bKU2di17nVr0FZHFgZi2y5BI/z0iedYLx9Fs8z8Iq/a425eLce/WMOAfChGein52dVZOatHNRB9Mmv09DKR6Y8ajnLXlmm8rzaOrwikNmhbsEZ9WZ4DpBqRT6gAwguq4Uo6vFoOaPujRbO5yPKwJ8zfuQi/4uvM6/08X3+Azc/WzSlyNi02vT+6X3Ktufr+631mEvK964Ecm5oQ58s1ysmMWNxzgZ5LUoWQyeqFRqA3Wtnrrss5u0o229sIwAJ7Wc9W1bLkCdu4/38MDJORAXIm1K9Ij1xTlwfzQRiO/fVZ1/78tUR/jFEkuuShb/ZxNoNF5768iIntXm1RJz+SrixwywxWJRLfSgJNeNHKbT6cp5XMUFMfMh8yo8IGWGST2liNV5Geqx6LmbXAN/hzZKadPUJrmXOuxlEUuN6RzJ6+KSUkxZtwjGWyV9JkN1C12SSZNCEeu1c5CeFRPTbHnTyvF4XN1TnrPN5+K4ma//UGDi82t1h3svn8/9pizs0nvErdDaEu0s/Dbya3W2HD9XojvXMnPvS1Y9WwxDiX+o+vC+oC6iKpK6iqxAAIdICgwkTPi7u7t4fn6OyWSyZt0hex4wNHvflftRdz3qmWb/b/pdmmfJCI/XtN9k0+9k7H22nGYdXSa5jvhQSjeNcJNy3zGDY2HnIqpl1/nxLoYvlU6Xy++LcuiqOvf39ysTSHA+jU2xhDbcepzz0FCvw/2fH/n1Ta+f5YMBFPcGi53gUOuelZ83vYa9bEThLLpmOvUHlWIVkN0NAJlb/5bgXGOWm8brvECD66nX0hnKU7yi7WKxqGJ13RFnNBpVg4sOOm4xxq6iSYj
pXPttvkd54CY9uRmOr4nfI/aUtGOrreWHrHSE51w2AdSCd12R2oIbVKEUTHaX/dVyj7raWMqazz2fz9M928bjcbWwI6+8m22n3AXr3gTqzm9ThuPz6ODsMvQld35byw60uqYdflREHmPqXG3tX2b377Xxy1tFpkDOtVc3sRTP45ETpLxcs9vBZTKZrKyJr9WTY0HJ0jdJ6NWd2yXusiSrNuO48KIp9ramXVZ6yF6/vLys3HJYBB0MXhPLvCU4q+GSdpy8cwrFPeUAEm2uPq/722n1RPskugwOk1SvNM/UxPUvfY+e1w3SWUXlteFE6zG8y0gigYQVWYfDYZWowMFLJ3Pvs+sac7G/xqVvFVnSDq2gg8FgJQOM5BC7jDhA3Pl8btdScxWSrFoCz4A/r+jaIFBH9Cy5jM+685W+h8/FyVVH+iz8OnjSzgnN1RpBdsw+urq6qtYGhzsYEZXViYgV4TiXdBd1ymOBjvTOWvA2Six3Z/Vxj9xUZkZdE5TreOx61aSJdWe5ul73ukSfnlvPC71m4vP3Of3eBjsjPCd83AjJCjgajWx/Nu/BDVeSkz56LheLZvHUW3b7nTuv1YvLy8uVsg+XfjAY6G6pKKVlcstKeMcYszNKpFQvs6leuTwAt87qfncabmUWflPsxaXnBN1oNFrbWUMJz/E7GjbUrdeSk5vzfQpJPlgnVDQ40QnSwZtSsjPpkTdxiTYnv9e2LnctU6/XwvLEJBqtcDSx7tl34XNq4ZXwpUapTsTw/MPVwmNF0PF4vJb5RRw/Ho8rKw9F5ngSQsqSGqfm1jNxNG8BwiKeZ6Kraw9F47o5Blt8VwYm/bH3P7DegPSsV/AitwkjXbJPw1Q3/+GoLDyPYkz8y8vLlUeXtIPyIImE82bxuxL/FMA3nn8zx9aatOMYP1MwuPPH6pZvgyyWd56jepCbkDA7r7sPTPZSeNUUe2mtZZLix5TKc+w6nZ2dVdv3wJop0TNBvGXrrlArD9deZcU7qDjFyjylOvDAcMyDhEuGajy/rYzcd/E5M+OV5aW2wd5aa9kNch12THKtC2urrI54u3B1jhnq2mcJJ1UotTJZ+YcTsacCp78uAbypTDJ9xUCdeRK70u/WF8DIYhYdKdV1cck3nSzizn9qigkoMVVRtIrhyH7K8mNkrn1JNq+RmbtfTs93gb1kWDJhsTJGlPd0ryN66btPBZkFqhuAs/uya2U7NpTInenfrmTlzrmLc7dKeHeB/BorW8lib3reHt+RKWSdIvf4gbYHPh549/F9Z8v+DvfocTI47qJpjx49NkJP+B49Tgg94Xv0OCH0hO/R44TQE75HjxNCT/gePU4IPeF79Dgh9ITv0eOE0BO+R48TQk/4Hj1OCD3he/Q4IfSE79HjhNATvkePE0LjBTB2sWwRr1WOjQsGg0FcXV2tbE6InUlvbm7i4eEhHh8f4/379/Hzzz/HL7/8Ek9PT/H+/fu4v7+P6+vrlaWtASz4MJlM4uXlJb58+RJ//fVXfPjwIX7//ff4+PFjfPr0Kf7666/49u1bPD8/x8vLS7y8vMTz83OMRqNqnT3dzC9ifcHIXcuR16TjlX2xnxsOyO3+/r6S09PTUzw9PcW7d+/i4eEhbm5uYjgcriz5HbG6IOh4PK7k9Oeff1by+fTpU3z58iW+fv0aLy8vlVxUPttueLjtZM0mcnQyZJnd3d1VO+Di+bt37yoZ/vzzz/H09BSPj49xc3MTFxcXMZvN4tu3b/Hp06f4/fff4z//+U/897//jQ8fPsTHjx/X5OT2eN/F+vKKpuc6uIUvrWuuyyW79c71xmfLJzcl2rEtflna0aUkN/defu21y1B3GdmCH/x8EzJmOxZ3UXZ72T2W4RRLR2J+nm1j1MbGB8e2NEAmS/5b/8eKiVVpmwwMdd93LGBi8zpym6xa43bW4SXBN5HpvrE3wtcRHYduM6ybVpTIX/oOh2wZqOy1roIti3pGLCeW83K5rD6n/3fKyht
7Hjuy5aOydeUAlYnqIe/0o5tp4vOHxl4tfEZ0xFe8Aw1iLd6LHLvSOPLzeXGTSgOCc+GOZacaF66o4vFGH3zgfbwwKDb7cBtEdklZXwO3lp8uEV1a+tyFm6WddFXekDnrKOvqvrAXwjvXUskOonPy7vr6eiWBh62oQHwVriN8prjZYo5dXvpaXUm1yDwAKtGhjLwdtO6oogrLFks9hC65qXVwa/rpism6caPb2kllwGRHQvXq6qpK0g0Gg7UNOlkPD6FXrRM+I7vbVw5Ev7u7i/v7+7i/v1/JomIQYMLz3uS8B11EpNYqItZuulv8/9BQBeHXWJYlbwlVEFQyXJY+4jv5NZzCudRlddfVdTjrzju96FZPmk1XV56rJnwMh8OVvRVwPmzDrUZp32iV8E0tO8pyt7e3VRnu/v4+Hh4eKuLjbxBfLT0Tmt11dasAKLuO5pl1b2Mp4qYokV6tOZOdN+jkTTrVwsPdxP5++KyGTvrd2SDUhcEyYnVHY/zNgzuTfTAYrO18pFuWscfjws/r6+uVEqVuHcVu/aHQGuHrYh51429vb+Pu7q6qJT8+PlYEx8GWnmvL2HiSbyzHppo5ZTKrO+dqpYcieOl/zrVUsmvNnneHjYhK8ZHEm8/nK4ODkl57HY7FrXfhG/dp4Lfx/nuw+Dxo8BbmLF+Q/fb2dsV4KOlh5XG+Q8TxrRA+I7tzObnp5uHhYaXR5vHxsWqI4Fie43mNSyNWd2GBQrMrGxHFTSizxohDWi51KbN40rmaUEoQmMGEn81ma/kRjemzRKm73q5Y+ohYIzsIiGYaeDZMdo3j2aNiwqNZTHMAvF9iVlk6+qRdHdlZKSE0xO1s4ZnwnMS7vr6urBUUGN8Tsdo4o9lndfmbZGm7kLgrxe9MQI7BXXyZWXgoHvaIB+ERAmhYxPG8XmMXwQZAXXouB+vGpkz4LIZnl143Qp1OpzEajVbyTIf2inZGeBdnOsvOiggCwx3iWJ1delh0jkfZ8mj2GgStS9i5zS3r3Pl9EV+VQuXrynDqQfHACrlr0m42m0XEd0vPZGcL76oBuBanwF2y7kx2V5KDpdets115TnMmnCe5vr5eCQewBbomlt8E4esskEskgehwh/DILjyew6prkk6TdRE/XDcdcNx7OIZnRSgl7PYNtaJqXZX0sEBKfFY8PT/i0qxmn5UAjw2O8Ofn52uxtpsXwKRnGWeZeoQIXBnpQnXj1YR32Vq17C6JxG48x+hMfvYAYNU1kaQjJm4ol0EYrg7r4vguuPKMjPhKRi2jueYbtvCoZLASa6iQlTdV9l136/GY1eHdkeVznNc6HA4r684GzuU9Ig7jCb2K8GrRS2Rnq84zvFB6e3h4WEvKcXa5RHbnUmajKVt2JX0pcdcFK5/JWePqElk1l+Hez/LlhKgb3EuJu64AvxPPXbae6/IlneDypHqv8EBBfvasuiKvrQnviA4rwYLgeB3xOCw5iP7+/fsqQ89ddZyccx1grIgRnpTOwnPZREtyuoc6n/fQ1t4NYiXy66DolE0bQnRAwXvZLe1CLLoNNHnHzzmsy6y9uvbOrR8OhzGZTNb01oWfh8BWhFc3rlR246QcknCop9/d3VUJOszd5o467hZzCu2uRV9zjTTqytVl6Q+F7Lfx4OoIXyJ6nUfkiK/f48KLrkNr3plr7wjPbdfaiFPqumPSZx5pyVi1gZ1aeCY7fjCTHaTmrjks3MBZecTscIs0QeesXERZaBnZXcNN10pyqig6wJbiazcIqkvvvicbJNTKu+s9tCfUBKwPLqeTxfjqnjPZUYqbTCYV6V0ClD/f+Rjeje7OsoPsSM5pQw2sPJfi0DvPSbqsNLSJZeFRve6GZv30XVBiWPWI9XkCbhBwZN/0+/CoJD8Gq+7ALnmJ9E4fuBU54sc9qLPwTfJO+8LWLr3GelnWkttmnfvOE2RAdo1/HNk3Rcml77KFj/DzsN2RWZEe3+GSd9lz5/JrDM9kR8LPNTqV9Bg
Dz76s/U5ceiU7E54nxoDgIL5rm0Wyro7smypydhM1RusS0SN8ck7LaE1Ir+R3Ccm6g72eLshmF3C/Xz081QUOrThxx63J2uxUl2PZF+k3IryLnflHa2aeu+m4sYYPEJ1bZndJ9rqRHM+z2vuh43c8qkfFpUouWbquLpVXU7JD8blUVdeNeGwDQWlQ1Pfp48XFxZrHyO3JrvlJia9eR9vYOoZXi1MiOy9qoQeI3gbZM9RZMn7PoeDiZ9dBp/Xfpm2cdQTnpCa/lq2+egxEd/knp1/asu282bOzsypjD9mhPZlLdK48h/OzF8HZ/zZludUy1S6edJ10XJIrEV0XtFCyu2vYFipgp6xdUV6XJ+GBleXH8wx4YKiz8K4kBYJrXzgvSc2Dghssuwolrxou53rzPAUOnZbLH+sCLpff+zrQVquJO52noDPx9qWDW1l4p4Q6XZDjdiW6kt0tWbXrJgWnkDqqdsFSuUFOE0ScG9GDJxe5ujngYnMmM+aG8xxxPB+Px9Vkky7OP3AoWWsms/Ok1DLzugBsmCDLbCDmHJc28miOpK14fqsYHs81WcfuO8fnHKdzjV2tu8580+/keOetI0uKwrrrxh08kOpUWBwRfmpwnWUfjUbVAbJnmyscA1x1iTtEsyqIxuIRseL+w8qzQVPyj0ajNcLz0XY8v7WFV0V0STq16iU33sXtEasdUvx4CnDelK6y4sIkWJOsnZMti2tEYrLjYMK73VS6Vt1guHxIRnadWagDAP8PcTxPSMIqOurBujXv6srAbVj5V9Xhs9ZC3QqJCV7qjWeSZz/0VEnPysekd2TXdmQnW47fM9JPp9MVsquL77bg6ipKFQ+17CzzrLGJS6NYtgphkZJca/Max/Oh92rX2HpvOWfptRbPP65puYjjas2eus81Jb7L9MMt49e6MpA4xdQ+Bx1gNVNfyoWwVVGy8+KOugoME/3YsvVZHM8DoxI6O1xZbbFYrCXpXNIObbhYfCNLUreBV3faMVFYaBFhf4SzLjxCwk1y59Dzb2PtlexuIOga+JrVvdQmj2wlIIbL0meWPjuOjewKJS+sNf7HyEq3/F7nAbjwQI2gW7e+rYRdxJaEzxouNPutSsSWA6uFItHB1kgFwCPgYDCoyiF4bVvhZJ6C5gz2DXddbAlQBi15UZqhdy69Dr6axOPDEf0YyV6y2gz2fFhv2SrjfBxzq84q0TVvpRa+bTQmfGZJS24hx4Gc/EF9GDFPqSWUSyWDwffLxUgcERsLqnTD9zXKbnKtfM0R61ZelaZk3d0AraTNsvdugskxEH0bEmVGCoYKXihWs8HfavmdR+YsP+6X6v7By3IMR3II5PLyMkajkXU9YdVR63XlI42nuKsMgmaUSF8X+5c6+g5p5fk6nduYKVP2ewAmuSOxvubev68mkbaRue5cnmRDxfp7dXVVhZ5osXXLWkOPWYfd/XLhL3tlu8LGhHeuIJN9NBpZ8vCohdIF14vZpdcGCE5SOfepqUBcLsDFcl0guxvAnMuYeUZOmfn3ZJn6rEykrvuxufIOmSxg1UH2l5eXGA6HKyEosvEgbsRqbwPOzQlXrmZht5smyexdYiPCs9Ko24MlednaukUAQXZMg3WLBLDl4g4+XoqJFX2TxJ0jt/tbY7NDK7YmL7NyUYn4dck6vleuMUTP85bAvxveJ8j+/PxchaEYELCtFM+GWy6Xdj171mVXxco8s0649C7ew2g1Ho/XBgRWHgiLCc8WXruduIMP66ezC4tlhptYZOdx1FnHNlyqbdDU2mfXz9ByXHbgvfz3sVt0hywHBXf++fk5hsNhZcU1kYkJM0jicWIz4sdkHF0cRueOlKoqu8RWSTsnpOl0uvJeF9PgMxAmapKlFkZY9uVyuWLxN2n4KAnRhR34uwvKreTVHIf7G88VXEVh97OUpNPPZn8fI1y4wnMJELu/vLysJJpZ7/EcCWXot7r1roxamjDWFvG3svD44Wrlz87OYjKZVO/jJho
AVh4zilwbI/eMg+zn5+dV7INRNHM561AnzLZH2W1R6klwJTiGJt3Y+1IFPqYsfB003MtCFJYFMvLoMHx+fq7k7OSDRDJcfp5FiHtUakrLYvg2jM6r6vBsFc7Pz2M2m614ASXC66wuuPYc52A21tnZ9/o7WjpLa89ti30kTDZFRmC29nAZs1yEokmDTdZUc+zkV7iwRmcLooysCTv3OWTreX37iKjCVG6QQoWqzRmiDlsl7fBcSd+EMHgvsvQ6YUHbDyOi6h931n1X6LpSa47B5R1KWXourbn6su6t5kh/7HBZ+VL8zmSPiDXrzTKC3kbEigwjojJYvLotr1/PLr2W53adR9rawju3nrOLLFD+HEbDy8vLakse7UTiGB2jI7bxKfVwNxWOi02PJQvdhPgZtLLCZOeJMY70Ed0fFJuAdfP8/LzSW/ZQOXHMZGfLrzvF8pLqsPKce4IOs0fApWmdHt5Wpv5VMTxiFlwc98RrNp+FgI32OJZhd4ddorOzs2pwyKZkNrlmfdRrcwd/pgsoWfM6V57jVO6b4NVseKGLbeTcZaiRwr2H3rJFnUwmK1NeubKEhV3YI5pOp1XVCdl8yBr5p8vLy5XvRt4LDT2uPNcGXrW3HGd9+bWs1AMBsQuPLD2mdrIgI6JK1tXNwW56rRmpM4+kC3AxvCM793drokrdVo5RX15e0gUu2gqh9gkmO/52+SUd5NUTgrx41R88x4rLbKkBhKQ8KCM5+PLyshbLOyu/K2u/tUvPKBHe9ddz5xFIP51OV9x2uFwXFxdriy5kyTrNyJauO7PqbhA5pKLz79HYzrny+vtVyflesGVHV5m69c47OlZkhHcxvRoquN+IvTX04feUFs+ABecGNN54JesLqdPtptiJhY9YbejACMZx0sXFRZXJ1BIcJ4v4RiBrr9YmW2WliUD0xmbWXn/fMUNzLWq1nFuPe6Hlz2OFy/Eo6ZFlZ3npDE/nZepACm81IlYy8Pgb+YLxeLyyWjN3S7ZVk3/V5Bl1M5h4ID0uHKU7nf3GK39ExIpgspjyNfG7QhuD9HNdhZN79j6XYHWJOzew1p3/2JB5pzoYqDHh7Dxk6qa3cpiFZh0YLi6hwqN16zryOZED2JURerWFd7GFIz0nSEB416CAgxXQNYa4mNuhS3X1baAeh/5+Dp1c5YLPoZYrq8FvkyM5ZrhcDnupXIFSr1RXAHKTZ5CUxvOIKK7tiAOD0a56TSJeSfgI33KpMUc2APDoCmEwwZXor3G3NcbFY6mevc33tAUluBKV3UomPzwml5/QWBWWDO9nOPnsMpl0aGhuxE1Wcg1OLqPO7+cKFEKGiKgsfLbmHRvCXSaTX014vQjNhrJisKLwCIgBgBNyTPgs3t4USvLMHdMa6KGU2iWTNAHKYQ+veY68CayVekROgbLBtJQ4PDbSu6SnzjrURjBuBtO15rPNJnTpMUzAwT0onWMwGKyEuFmCehvsjPAAX4gjvxKKlTkiVqwNW6KmLnyG0g3Wg3sL9DftC87NVKLzSkJuRtam8ipVAEqJpGMguxu02ErzTE1uAHMHrxisey7ocuzsurOFx1LWOohgGeuIqBLeu+yF2DnhGRn5I6JyNzNXUz9f+g5XU2VkZNeFNjAC47oOZeVLZEf5hzeH0HkIvMACfl927UrsbPWhupDnGKx8Nogp2XWPBSYjH7wJCC8V7vYLAOGhq3DpdaCYTCZVWMWl0aMgPINjen7NHfgfPyrqSB6x3oCiN9nVSpFU7IqVZ9KjWYPXisfUTY4XUQ+GW+/IWPJ4MkJw5piTW4eST1NkOqB6ALKDiM568/95WzV+Da8r4aFXEc0Jv2lVqg57IzzgiJ9B66RNz68oKTdbeTx2zcpzoo5r56PRaKXRQzcDQd8Dl3XwOyJijeRKAnZ3mfQ8GHYh19EEJevOZMeCK7y1uW5pznv7ZVtKKZFxj0BmLstpy+5i8X03G547sivslfBMoE1GrF0oEcen6sazledVdJx
S7xNq4dF/zVaeEz6j0aiyLtnkF8CRXGNancWV5TmOwcpH5JunsHfEZAbp7+/vK9Kre8/xNyfomMhorkFMHhH2fbDwy+WPxV7Ypd8F9m7hHTb5Mdl73ev6mip55rLiPexhHJL02qbsuuSgQG6jRwxi/Fs0EaeDnyM9D4jsNXTZsjs4K6+k501RQXi8pgk8EF2z8rpzLA+WumsQ8gE8JZxXzqkLcZuiE4TfFThcaOLas+ta6po6BPQmM+mRuGPSl7rluBwEaBnKLS+W5Tm0mUrLr10if10Zzv1ul4nHtudssZnMXEPnUhu/ju9fLBbV//g7xuNxzGbfd6LBrFEuS+8Cb4rwEetEcckqV3fVUg2Tfd9KrAOXZuu1NbZJiyzXc3XQ00VI9ICMkBNQK4+ja2QHXElO9UDjeI3FNY5Xq+7WdeA+eUd4dukRv0dENTuUk3bASVr4uh9dIj0TWUsyTSx820qdJTW5D0HbYrNVa3DopBD8Dq3duyYSXUcdCzay13HIPMcmqEvaaf1dS2xw67VZxg2SbilqR/ibm5u4v7+vXPnBYFBNvQXh31wMvy1YCJkL7Cx8RKy4qC4T3YXEnZYrtQORSV2y9PhtOGfE6kqqcB9dmycfOp/BlTC7aunVtXchjS7EAusOsjvCawik5+L/RUTlFc3n8xgOh1XcvlwuYzAYxPX1deXaa1/+UZbldg0lvT66ioCWndS6M9kZXcjWs2tfR3YcrsQIwoPImkTSpBO+A22fKB1xCfOQMmoCV5pjgrKFd/V0EF5ddSY1D6SaBI34cS91qfXz8/MYDodr1n2TJrQmOHrCA1ryy9px3c3OrPyhE3f4XXiEZdA43rn2fMC6cMcdfj8yxrA4Wm7CczTzqHWHnCK6TXZGlrBj6+5ceo7h4Rk5HVJdUndeE8tnZ9+Xcbu5udl6GbemeBOEV1K4UgYL18XwavVLibt9DAAuXKlL4JUOLsvxb0c8jnIfK7y69LwfGkKF2Wx28EGxDmrZ9R4z6TnMcY00KMtBJpmHyOVOHRQ1iXhxsb4qc0b2k0zaOTjBOHc+IlZuBI/AOhofWolLpEe5Rudil46IHwuMsJeD+FHLSM5tdf0Kh851ZHD3z5VmmZz6ezVrz7V3lok2LmWGg68hYjW0ylppewtfgIvpFdkoy//LRuR9I0tGRsTKo9bpdTELbReO+JFAcpOIMCGnadij6BLxAZe4Y+Jr552bIuuy79zL4bwJlhHidR10SguY7BJvjvARfl06rj+zwqu1iujmLjSZB6MLWKjbz94AnwPZddd84+JR5/3ooNlFuMQrP2bWWON8dfv1yPI+7hH3gj0i5Fd2naRTvDnCu44k5xrpjdDnXUSp+qC5C/d/VSItqWm5yikvD4p4dKTqmnUvQX+f5nI0NneGwoU2fO7s+yArbmICegtfg00THO6maHLrGFCX4OP3sEsPlNxQHgjwXiX9MSAb2LPYWgdANyjWufHuO7Jry+L1XZP+TRH+lKCeTEkxSrkMft7EGr0l8ODOz+sSbfyaPjYlO0jOj9n7domz5TH5Xj169HgVjs9/7dGjx9boCd+jxwmhJ3yPHieEnvA9epwQesL36HFC6Anfo8cJoSd8jx4nhJ7wPXqcEHrC9+hxQvh/s98sUacTB7MAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Load and preprocess the `digits` dataset.\n", "digits = load_digits()\n", "# Filter for digit '1' (one) images.\n", "images = digits.images[digits.target == 1]\n", "# Normalize pixel values into floating-point arrays in the `[0, 1]` interval.\n", "images = images / 16.0\n", "# Convert to `jax.Array`s.\n", "images = jnp.asarray(images)\n", "# Reshape to `(num_images, height, width, channels)` for convolutional layers.\n", "images = images.reshape(-1, 8, 8, 1)\n", "\n", "# Split the dataset into training and test sets (5% for testing).\n", "images_train, images_test = train_test_split(images, test_size=0.05, random_state=42)\n", "print(f\"Training set size: {images_train.shape[0]}\")\n", "print(f\"Test set size: {images_test.shape[0]}\")\n", "\n", "# Visualize sample images.\n", "fig, axes = plt.subplots(3, 3, figsize=(3, 3))\n", "for i, ax in enumerate(axes.flat):\n", "    if i < len(images_train):\n", "        ax.imshow(images_train[i, ..., 0], cmap='gray', interpolation='gaussian')\n", "    ax.axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "exKxj9OcG0yk" }, "source": [ "## Defining the diffusion model with Flax\n", "\n", "In this section, we’ll develop the various parts of the [diffusion model](https://en.wikipedia.org/wiki/Diffusion_model) and then put them all together.\n", "\n", "### The U-Net architecture\n", "\n", "For this example, we’ll use the [U-Net architecture](https://en.wikipedia.org/wiki/U-Net), a convolutional neural network, as the backbone of the diffusion model. 
The U-Net consists of the following:\n", "\n", "- An [encoder](https://en.wikipedia.org/wiki/Autoencoder) path that [downsamples](https://en.wikipedia.org/wiki/Downsampling_(signal_processing)) the input image, extracting features.\n", "- A bridge with a (self-)[attention mechanism](https://en.wikipedia.org/wiki/Attention_(machine_learning)) that connects the encoder with the decoder.\n", "- A [decoder](https://en.wikipedia.org/wiki/Autoencoder) path that [upsamples](https://en.wikipedia.org/wiki/Upsampling) the feature representations learned by the encoder, reconstructing the output image.\n", "- [Skip connections](https://en.wikipedia.org/wiki/Residual_neural_network#Residual_connection) that concatenate the feature maps of each encoder stage with the input of the matching decoder stage.\n", "\n", "Let's define a class called `UNet` by subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) and using, among other things, [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) (dense layers for the time embedding and time projections, and for the query, key and value projections in self-attention), [`flax.nnx.LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm) (layer normalization), and [`flax.nnx.Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv) (convolution for the output layer)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "F4pxdITOuk79" }, "outputs": [], "source": [ "class UNet(nnx.Module):\n", " def __init__(self,\n", " in_channels: int,\n", " out_channels: int,\n", " features: int,\n", " time_emb_dim: int = 128,\n", " *,\n", " rngs: nnx.Rngs):\n", " \"\"\"\n", " Initialize the U-Net architecture with time embedding.\n", " \"\"\"\n", " self.features = features\n", "\n", " # Time embedding layers for diffusion timestep conditioning.\n", " self.time_mlp_1 = nnx.Linear(in_features=time_emb_dim, out_features=time_emb_dim, rngs=rngs)\n", " self.time_mlp_2 = nnx.Linear(in_features=time_emb_dim, out_features=time_emb_dim, rngs=rngs)\n", "\n", " # Time projection layers for different scales.\n", " self.time_proj1 = nnx.Linear(in_features=time_emb_dim, out_features=features, rngs=rngs)\n", " self.time_proj2 = nnx.Linear(in_features=time_emb_dim, out_features=features * 2, rngs=rngs)\n", " self.time_proj3 = nnx.Linear(in_features=time_emb_dim, out_features=features * 4, rngs=rngs)\n", " self.time_proj4 = nnx.Linear(in_features=time_emb_dim, out_features=features * 8, rngs=rngs)\n", "\n", " # The encoder path.\n", " self.down_conv1 = self._create_residual_block(in_channels, features, rngs)\n", " self.down_conv2 = self._create_residual_block(features, features * 2, rngs)\n", " self.down_conv3 = self._create_residual_block(features * 2, features * 4, rngs)\n", " self.down_conv4 = self._create_residual_block(features * 4, features * 8, rngs)\n", "\n", " # Multi-head self-attention blocks.\n", " self.attention1 = self._create_attention_block(features * 4, rngs)\n", " self.attention2 = self._create_attention_block(features * 8, rngs)\n", "\n", " # The bridge connecting the encoder and the decoder.\n", " self.bridge_down = self._create_residual_block(features * 8, features * 16, rngs)\n", " self.bridge_attention = self._create_attention_block(features * 16, rngs)\n", " self.bridge_up = self._create_residual_block(features * 
16, features * 16, rngs)\n", "\n", " # Decoder path with skip connections.\n", " self.up_conv4 = self._create_residual_block(features * 24, features * 8, rngs)\n", " self.up_conv3 = self._create_residual_block(features * 12, features * 4, rngs)\n", " self.up_conv2 = self._create_residual_block(features * 6, features * 2, rngs)\n", " self.up_conv1 = self._create_residual_block(features * 3, features, rngs)\n", "\n", " # Output layers.\n", " self.final_norm = nnx.LayerNorm(features, rngs=rngs)\n", " self.final_conv = nnx.Conv(in_features=features,\n", " out_features=out_channels,\n", " kernel_size=(3, 3),\n", " strides=(1, 1),\n", " padding=((1, 1), (1, 1)),\n", " rngs=rngs)\n", "\n", " def _create_attention_block(self, channels: int, rngs: nnx.Rngs) -> Callable:\n", " \"\"\"Creates a self-attention block with learned query, key, value projections.\n", "\n", " Args:\n", " channels (int): The number of channels in the input feature maps.\n", " rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX pseudo-random number generator (PRNG) keys.\n", "\n", " Returns:\n", " Callable: A function representing a forward pass through the attention block.\n", "\n", " \"\"\"\n", " query_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)\n", " key_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)\n", " value_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)\n", "\n", " def forward(x: jax.Array) -> jax.Array:\n", " \"\"\"Applies self-attention to the input.\n", "\n", " Args:\n", " x (jax.Array): The input tensor with the shape `[batch, height, width, channels]` (or `B, H, W, C`).\n", "\n", " Returns:\n", " jax.Array: The output tensor after applying self-attention.\n", " \"\"\"\n", "\n", " # Shape: batch, height, width, channels.\n", " B, H, W, C = x.shape\n", " scale = jnp.sqrt(C).astype(x.dtype)\n", "\n", " # Project the input into query, key, value projections.\n", " q 
= query_proj(x)\n", " k = key_proj(x)\n", " v = value_proj(x)\n", "\n", " # Reshape for the attention computation.\n", " q = q.reshape(B, H * W, C)\n", " k = k.reshape(B, H * W, C)\n", " v = v.reshape(B, H * W, C)\n", "\n", " # Compute the scaled dot-product attention.\n", " attention = jnp.einsum('bic,bjc->bij', q, k) / scale # Scaled dot-product.\n", " attention = jax.nn.softmax(attention, axis=-1) # Softmax.\n", "\n", " # Weight the values by the attention scores.\n", " out = jnp.einsum('bij,bjc->bic', attention, v)\n", " out = out.reshape(B, H, W, C)\n", "\n", " return x + out # A ResNet-style residual connection.\n", "\n", " return forward\n", "\n", " def _create_residual_block(self,\n", " in_channels: int,\n", " out_channels: int,\n", " rngs: nnx.Rngs) -> Callable:\n", " \"\"\"Creates a residual block with two convolutions and normalization.\n", "\n", " Args:\n", " in_channels (int): Number of input channels.\n", " out_channels (int): Number of output channels.\n", " rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX PRNG keys.\n", "\n", " Returns:\n", " Callable: A function that represents the forward pass through the residual block.\n", " \"\"\"\n", "\n", " # Convolutional layers with layer normalization.\n", " conv1 = nnx.Conv(in_features=in_channels,\n", " out_features=out_channels,\n", " kernel_size=(3, 3),\n", " strides=(1, 1),\n", " padding=((1, 1), (1, 1)),\n", " rngs=rngs)\n", " norm1 = nnx.LayerNorm(out_channels, rngs=rngs)\n", " conv2 = nnx.Conv(in_features=out_channels,\n", " out_features=out_channels,\n", " kernel_size=(3, 3),\n", " strides=(1, 1),\n", " padding=((1, 1), (1, 1)),\n", " rngs=rngs)\n", " norm2 = nnx.LayerNorm(out_channels, rngs=rngs)\n", "\n", " # 1x1 convolution shortcut that projects the input to `out_channels`\n", " # (applied in every block, even when the channel count is unchanged).\n", " shortcut = nnx.Conv(in_features=in_channels,\n", " out_features=out_channels,\n", " kernel_size=(1, 1),\n", " strides=(1, 1),\n", " rngs=rngs)\n", "\n", " # The forward pass through the residual block.\n", " def forward(x: jax.Array) -> jax.Array:\n", " identity = shortcut(x)\n", "\n", " x = conv1(x)\n", " x = norm1(x)\n", " x = nnx.gelu(x)\n", "\n", " x = conv2(x)\n", " x = norm2(x)\n", " x = nnx.gelu(x)\n", "\n", " return x + identity\n", "\n", " return forward\n", "\n", " def _pos_encoding(self, t: jax.Array, dim: int) -> jax.Array:\n", " \"\"\"Applies sinusoidal positional encoding for time embedding.\n", "\n", " Args:\n", " t (jax.Array): The diffusion timesteps, with shape `(batch,)`.\n", " dim (int): The dimension of the output positional encoding.\n", "\n", " Returns:\n", " jax.Array: The sinusoidal positional embedding per timestep.\n", "\n", " \"\"\"\n", " # Calculate half the embedding dimension.\n", " half_dim = dim // 2\n", " # Compute the logarithmic scaling factor for sinusoidal frequencies.\n", " emb = jnp.log(10000.0) / (half_dim - 1)\n", " # Generate a range of sinusoidal frequencies.\n", " emb = jnp.exp(jnp.arange(half_dim) * -emb)\n", " # Multiply each timestep by every frequency (an outer product).\n", " emb = t[:, None] * emb[None, :]\n", " # Concatenate sine and cosine components for richer representation.\n", " emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1)\n", " return emb\n", "\n", " def _downsample(self, x: jax.Array) -> jax.Array:\n", " \"\"\"Downsamples the input feature map with max pooling.\"\"\"\n", " return nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')\n", "\n", " def _upsample(self, x: jax.Array, target_size: int) -> jax.Array:\n", " \"\"\"Upsamples the input feature map using nearest neighbor interpolation.\"\"\"\n", " return jax.image.resize(x,\n", " (x.shape[0], target_size, target_size, x.shape[3]),\n", " method='nearest')\n", "\n", " def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:\n", " \"\"\"Perform the forward pass through the U-Net using time embeddings.\"\"\"\n", "\n", " # Time embedding and projection.\n", " t_emb = self._pos_encoding(t, self.time_mlp_1.in_features) # Sinusoidal positional encoding 
for time.\n", " t_emb = self.time_mlp_1(t_emb) # Project the time embedding.\n", " t_emb = nnx.gelu(t_emb) # Activation function: `flax.nnx.gelu` (GELU).\n", " t_emb = self.time_mlp_2(t_emb)\n", "\n", " # Project the time embedding to the channel width of each encoder block and\n", " # add singleton spatial axes so that it broadcasts over height and width.\n", " t_emb1 = self.time_proj1(t_emb)[:, None, None, :]\n", " t_emb2 = self.time_proj2(t_emb)[:, None, None, :]\n", " t_emb3 = self.time_proj3(t_emb)[:, None, None, :]\n", " t_emb4 = self.time_proj4(t_emb)[:, None, None, :]\n", "\n", " # The encoder path with time injection.\n", " d1 = self.down_conv1(x)\n", " t_emb1 = jnp.broadcast_to(t_emb1, d1.shape) # Broadcast the time embedding to match the feature map shape.\n", " d1 = d1 + t_emb1 # Add the time embedding to the feature map.\n", "\n", " d2 = self.down_conv2(self._downsample(d1))\n", " t_emb2 = jnp.broadcast_to(t_emb2, d2.shape)\n", " d2 = d2 + t_emb2\n", "\n", " d3 = self.down_conv3(self._downsample(d2))\n", " d3 = self.attention1(d3) # Apply self-attention.\n", " t_emb3 = jnp.broadcast_to(t_emb3, d3.shape)\n", " d3 = d3 + t_emb3\n", "\n", " d4 = self.down_conv4(self._downsample(d3))\n", " d4 = self.attention2(d4)\n", " t_emb4 = jnp.broadcast_to(t_emb4, d4.shape)\n", " d4 = d4 + t_emb4\n", "\n", " # The bridge.\n", " b = self._downsample(d4)\n", " b = self.bridge_down(b)\n", " b = self.bridge_attention(b)\n", " b = self.bridge_up(b)\n", "\n", " # The decoder path with skip connections.\n", " u4 = self.up_conv4(jnp.concatenate([self._upsample(b, d4.shape[1]), d4], axis=-1))\n", " u3 = self.up_conv3(jnp.concatenate([self._upsample(u4, d3.shape[1]), d3], axis=-1))\n", " u2 = self.up_conv2(jnp.concatenate([self._upsample(u3, d2.shape[1]), d2], axis=-1))\n", " u1 = self.up_conv1(jnp.concatenate([self._upsample(u2, d1.shape[1]), d1], axis=-1))\n", "\n", " # Final layers.\n", " x = self.final_norm(u1)\n", " x = nnx.gelu(x)\n", " return self.final_conv(x)" ] }, { "cell_type": "markdown", "metadata": { 
"id": "XJaqiL07HD9D" }, "source": [ "### Defining the diffusion model\n", "\n", "Here, we define a `DiffusionModel` class that wraps the `UNet` defined above and precomputes the noise schedule used by the diffusion process. The `DiffusionModel` class implements:\n", "\n", "- Forward diffusion (adding noise to clean images)\n", "- Reverse diffusion (iteratively denoising, using the U-Net's noise predictions)\n", "- A cosine noise schedule, with the per-step noise levels (betas) clipped to `[beta_start, beta_end]`" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "4ml8cKFUvCdv" }, "outputs": [], "source": [ "class DiffusionModel:\n", " def __init__(self,\n", " model: UNet,\n", " num_steps: int,\n", " beta_start: float,\n", " beta_end: float):\n", " \"\"\"Initialize diffusion process parameters.\n", "\n", " Args:\n", " model (UNet): The U-Net model for image generation.\n", " num_steps (int): The number of diffusion steps in the process.\n", " beta_start (float): The minimum value of beta, the per-step noise level.\n", " beta_end (float): The maximum value of beta.\n", " \"\"\"\n", " self.model = model\n", " self.num_steps = num_steps\n", "\n", " # Noise schedule parameters.\n", " self.beta = self._cosine_beta_schedule(num_steps, beta_start, beta_end)\n", " self.alpha = 1 - self.beta\n", " self.alpha_cumulative = jnp.cumprod(self.alpha)\n", "\n", " self.sqrt_alpha_cumulative = jnp.sqrt(self.alpha_cumulative)\n", " self.sqrt_one_minus_alpha_cumulative = jnp.sqrt(1 - self.alpha_cumulative)\n", " self.sqrt_recip_alpha = jnp.sqrt(1 / self.alpha)\n", "\n", " self.posterior_variance = self.beta * (1 - self.alpha_cumulative) / (1 - self.alpha_cumulative + 1e-7)\n", "\n", " def _cosine_beta_schedule(self,\n", " num_steps: int,\n", " beta_start: float,\n", " beta_end: float) -> jax.Array:\n", " \"\"\"Cosine schedule for the per-step noise levels (betas), clipped to `[beta_start, beta_end]`.\"\"\"\n", " steps = jnp.linspace(0, num_steps, num_steps + 1)\n", " x = steps / num_steps\n", " alphas = jnp.cos(((x + 0.008) / 1.008) * jnp.pi * 0.5) ** 2\n", " alphas = alphas / alphas[0]\n", " betas = 1 - 
(alphas[1:] / alphas[:-1])\n", " betas = jnp.clip(betas, beta_start, beta_end)\n", " return jnp.concatenate([betas[0:1], betas])\n", "\n", " def forward(self,\n", " x: jax.Array,\n", " t: jax.Array,\n", " key: jax.Array) -> Tuple[jax.Array, jax.Array]:\n", " \"\"\"Forward diffusion process: adds noise to `x` according to the schedule.\n", "\n", " Args:\n", " x (jax.Array): A batch of clean images with shape `(batch, height, width, channels)`.\n", " t (jax.Array): The timestep for each image, with shape `(batch,)`.\n", " key (jax.Array): A JAX PRNG key for generating random noise.\n", "\n", " Returns:\n", " Tuple[jax.Array, jax.Array]: The noisy images and the noise that was added.\n", " \"\"\"\n", " noise = jax.random.normal(key, x.shape)\n", " noisy_x = (\n", " jnp.sqrt(self.alpha_cumulative[t])[:, None, None, None] * x +\n", " jnp.sqrt(1 - self.alpha_cumulative[t])[:, None, None, None] * noise\n", " )\n", " return noisy_x, noise\n", "\n", " def reverse(self, x: jax.Array, key: jax.Array) -> jax.Array:\n", " \"\"\"Performs the reverse diffusion process, denoising the input image gradually.\n", "\n", " Args:\n", " x (jax.Array): A batch of pure Gaussian noise images to start denoising from.\n", " key (jax.Array): A JAX PRNG key for the random noise.\n", "\n", " Returns:\n", " jax.Array: The denoised images after all reverse steps.\n", " \"\"\"\n", " x_t = x\n", " for t in reversed(range(self.num_steps)):\n", " t_batch = jnp.array([t] * x.shape[0])\n", " predicted = self.model(x_t, t_batch) # Predicted noise using the U-Net.\n", "\n", " key, subkey = jax.random.split(key) # Split the JAX PRNG key.\n", " noise = jax.random.normal(subkey, x_t.shape) if t > 0 else 0 # Sample noise for this step (none at the final step, t = 0).\n", "\n", " # The denoising step.\n", " x_t = (1 / jnp.sqrt(self.alpha[t])) * (\n", " x_t - ((1 - self.alpha[t]) / jnp.sqrt(1 - self.alpha_cumulative[t])) * predicted\n", " ) + jnp.sqrt(self.beta[t]) * noise\n", "\n", " return x_t # The final denoised image."
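 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The forward and reverse steps above can be written out explicitly. With $\\alpha_t = 1 - \\beta_t$ and $\\bar{\\alpha}_t = \\prod_{s=0}^{t} \\alpha_s$ (`alpha_cumulative`), `forward()` samples a noisy image in closed form:\n", "\n", "$$x_t = \\sqrt{\\bar{\\alpha}_t}\\, x_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\, \\epsilon, \\qquad \\epsilon \\sim \\mathcal{N}(0, I),$$\n", "\n", "and each iteration of `reverse()` applies the update\n", "\n", "$$x_{t-1} = \\frac{1}{\\sqrt{\\alpha_t}} \\left( x_t - \\frac{1 - \\alpha_t}{\\sqrt{1 - \\bar{\\alpha}_t}}\\, \\epsilon_\\theta(x_t, t) \\right) + \\sqrt{\\beta_t}\\, z,$$\n", "\n", "where $\\epsilon_\\theta$ is the U-Net's noise prediction and $z \\sim \\mathcal{N}(0, I)$ for $t > 0$, with $z = 0$ at the final step ($t = 0$).\n", "\n", "As an optional illustration, the next cell builds a throwaway `DiffusionModel` with the same schedule settings used later for training. It passes `model=None`, which works here because only `reverse()` calls the U-Net. The cell then noises a few training images at increasing timesteps with `forward()`. The underscore-prefixed names are hypothetical helpers for this demo only." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative demo (hypothetical helper names): noise a few training images\n", "# at increasing timesteps with the closed-form forward process.\n", "_demo = DiffusionModel(model=None, num_steps=1000, beta_start=1e-4, beta_end=0.02)\n", "_ts = jnp.array([0, 250, 500, 999])\n", "_noisy, _noise = _demo.forward(images_train[:4], _ts, jax.random.PRNGKey(0))\n", "\n", "fig, axes = plt.subplots(1, 4, figsize=(6, 1.8))\n", "for ax, img, _t in zip(axes, _noisy, _ts):\n", "    ax.imshow(img[..., 0], cmap='gray')\n", "    ax.set_title(f\"t={int(_t)}\")\n", "    ax.axis('off')\n", "plt.show()"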
 ] }, { "cell_type": "markdown", "metadata": { "id": "wKnYRqMAI06f" }, "source": [ "## Defining the loss function and training step\n", "\n", "In this section, we’ll define the components for training our diffusion model, including:\n", "\n", "- The loss function (`loss_fn()`), which weights the squared noise-prediction error with a function of the [signal-to-noise ratio (SNR)](https://en.wikipedia.org/wiki/Signal-to-noise_ratio) and adds a gradient penalty; and\n", "- The training step (`train_step()`) with [gradient clipping](https://arxiv.org/pdf/1905.11881) for stability." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rq9Ic8WYCCJI" }, "outputs": [], "source": [ "def loss_fn(model: UNet,\n", " images: jax.Array,\n", " t: jax.Array,\n", " noise: jax.Array,\n", " sqrt_alpha_cumulative: jax.Array,\n", " sqrt_one_minus_alpha_cumulative: jax.Array) -> jax.Array:\n", " \"\"\"Computes the noise-prediction loss with SNR-based weighting and a gradient penalty.\n", "\n", " Args:\n", " model (UNet): The U-Net model for image generation.\n", " images (jax.Array): A batch of images used for training.\n", " t (jax.Array): The timestep(s) at which the noise is added to each image.\n", " noise (jax.Array): The noise added to the images.\n", " sqrt_alpha_cumulative (jax.Array): Square root of cumulative alpha values.\n", " sqrt_one_minus_alpha_cumulative (jax.Array): Square root of (1 - cumulative alpha values).\n", "\n", " Returns:\n", " jax.Array: The total loss value.\n", " \"\"\"\n", "\n", " # Generate noisy images.\n", " noisy_images = (\n", " sqrt_alpha_cumulative[t][:, None, None, None] * images +\n", " sqrt_one_minus_alpha_cumulative[t][:, None, None, None] * noise\n", " )\n", "\n", " # Predict the noise using the U-Net.\n", " predicted = model(noisy_images, t)\n", "\n", " # Compute the SNR-weighted loss. `snr` holds the square root of the\n", " # signal-to-noise ratio, sqrt(alpha_bar_t / (1 - alpha_bar_t)).\n", " snr = (sqrt_alpha_cumulative[t] / sqrt_one_minus_alpha_cumulative[t])[:, None, None, None]\n", " loss_weights = snr / (1 + snr)\n", "\n", " squared_error = (noise - predicted) ** 2\n", " main_loss = 
jnp.mean(loss_weights * squared_error)\n", "\n", " # Add a gradient penalty (regularization): the mean squared gradient of the model's\n", " # mean output with respect to the noisy input, scaled by 0.02.\n", " grad = jax.grad(lambda x: model(x, t).mean())(noisy_images)\n", " grad_penalty = 0.02 * (jnp.square(grad).mean())\n", "\n", " # The total loss.\n", " return main_loss + grad_penalty\n", "\n", "# Flax NNX JIT-compilation for performance (`flax.nnx.jit`).\n", "@nnx.jit\n", "def train_step(model: UNet,\n", " optimizer: nnx.Optimizer,\n", " images: jax.Array,\n", " t: jax.Array,\n", " noise: jax.Array,\n", " sqrt_alpha_cumulative: jax.Array,\n", " sqrt_one_minus_alpha_cumulative: jax.Array) -> jax.Array:\n", " \"\"\"Performs a single training step with gradient clipping.\n", "\n", " Args:\n", " model (UNet): The U-Net model for image generation that is being trained.\n", " optimizer (flax.nnx.Optimizer): The Flax NNX optimizer for parameter updates.\n", " images (jax.Array): A batch of images used for training.\n", " t (jax.Array): The timestep(s) at which the noise is added to each image.\n", " noise (jax.Array): The noise added to the images during training.\n", " sqrt_alpha_cumulative (jax.Array): Square root of cumulative alpha values from the diffusion schedule.\n", " sqrt_one_minus_alpha_cumulative (jax.Array): Square root of (1 - cumulative alpha values) from the diffusion schedule.\n", "\n", " Returns:\n", " jax.Array: The loss value computed in this step, before the parameter update.\n", " \"\"\"\n", " # Compute the loss and gradients with `flax.nnx.value_and_grad`.\n", " loss, grads = nnx.value_and_grad(loss_fn)(\n", " model, images, t, noise,\n", " sqrt_alpha_cumulative, sqrt_one_minus_alpha_cumulative\n", " )\n", "\n", " # Clip each gradient element to [-0.3, 0.3]; the optimizer (configured later) also clips by global norm.\n", " clip_threshold = 0.3\n", " grads = jax.tree_util.tree_map(\n", " lambda g: jnp.clip(g, -clip_threshold, clip_threshold),\n", " grads\n", " )\n", " # Update the parameters using the optimizer.\n", " optimizer.update(grads)\n", " # Return the loss computed before the update.\n", " return loss" ] }, 
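{ "cell_type": "markdown", "metadata": {}, "source": [ "This is how the weighting in `loss_fn()` treats different timesteps. The ratio it calls `snr` is $r_t = \\sqrt{\\bar{\\alpha}_t / (1 - \\bar{\\alpha}_t)}$, the square root of the signal-to-noise ratio, and the per-timestep weight is $w_t = r_t / (1 + r_t)$. Because $\\bar{\\alpha}_t$ decreases as $t$ grows, $w_t$ is close to 1 for lightly noised inputs and becomes small for the noisiest ones.\n", "\n", "The optional cell below prints $w_t$ at a few timesteps for a throwaway schedule built with the same settings used for training. The underscore-prefixed names are hypothetical helpers for this illustration only." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative check (hypothetical helper names): the per-timestep loss weight\n", "# used in `loss_fn()`, computed from a throwaway noise schedule.\n", "_sched = DiffusionModel(model=None, num_steps=1000, beta_start=1e-4, beta_end=0.02)\n", "_r = _sched.sqrt_alpha_cumulative / _sched.sqrt_one_minus_alpha_cumulative\n", "_w = _r / (1 + _r)\n", "for _t in [0, 250, 500, 750, 999]:\n", "    print(f\"t={_t:4d}  weight={float(_w[_t]):.4f}\")" ] }, 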
{ "cell_type": "markdown", "metadata": { "id": "4slhkQ6vI5tZ" }, "source": [ "### Model training configuration\n", "\n", "Next, we’ll configure the model and the optimizer before implementing the training loop.\n", "\n", "We need to set up:\n", "\n", "- Model hyperparameters\n", "- An optimizer with a learning rate schedule (linear warmup followed by cosine decay)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w4CwR-6ivIjS" }, "outputs": [], "source": [ "# Set the model and training hyperparameters.\n", "key = jax.random.PRNGKey(42) # PRNG seed for reproducibility.\n", "in_channels = 1\n", "out_channels = 1\n", "features = 64 # Base number of feature channels in the U-Net.\n", "num_steps = 1000 # Number of diffusion timesteps.\n", "num_epochs = 5000 # Each epoch is one training step on the full training set.\n", "batch_size = 64 # Not used by the training loop, which uses the whole training set at each step.\n", "learning_rate = 1e-4\n", "beta_start = 1e-4 # The minimum value of beta (the per-step noise level).\n", "beta_end = 0.02 # The maximum value of beta.\n", "\n", "# Initialize model components.\n", "key, subkey = jax.random.split(key) # Split the JAX PRNG key for initialization.\n", "model = UNet(in_channels, out_channels, features, rngs=nnx.Rngs(default=subkey)) # Instantiate the U-Net.\n", "\n", "diffusion = DiffusionModel(\n", " model=model,\n", " num_steps=num_steps,\n", " beta_start=beta_start,\n", " beta_end=beta_end\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yLjb_t026uy3", "outputId": "2cda0980-ac98-4fd7-ee3a-02728a64f1f7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input shape: (1, 8, 8, 1)\n", "Output shape: (1, 8, 8, 1)\n", "\n", "Model initialized successfully\n" ] } ], "source": [ "# Learning rate schedule configuration.\n", "# Start with the warmup, then cosine decay.\n", "warmup_steps = 1000 # Number of linear warmup steps.\n", "total_steps = num_epochs # Total number of training steps.\n", "\n", "# Combine two schedules with `optax.join_schedules`:\n", "# a linear warmup (`optax.linear_schedule`)\n", "# 
and cosine learning rate decay (`optax.cosine_decay_schedule`).\n", "schedule_fn = optax.join_schedules(\n", " schedules=[\n", " optax.linear_schedule(\n", " init_value=0.0,\n", " end_value=learning_rate,\n", " transition_steps=warmup_steps\n", " ),\n", " optax.cosine_decay_schedule(\n", " init_value=learning_rate,\n", " decay_steps=total_steps - warmup_steps,\n", " alpha=0.01\n", " )\n", " ],\n", " boundaries=[warmup_steps] # Where the schedule transitions from the warmup to cosine decay.\n", ")\n", "\n", "# Optimizer configuration (AdamW) with gradient clipping.\n", "optimizer = nnx.ModelAndOptimizer(model, optax.chain(\n", " optax.clip_by_global_norm(0.5), # Gradient clipping.\n", " optax.adamw(\n", " learning_rate=schedule_fn,\n", " weight_decay=2e-5,\n", " b1=0.9,\n", " b2=0.999,\n", " eps=1e-8\n", " )\n", "))\n", "\n", "# Sanity-check the forward pass with a dummy input (the model was already initialized above).\n", "dummy_input = jnp.ones((1, 8, 8, 1))\n", "dummy_t = jnp.zeros((1,), dtype=jnp.int32)\n", "output = model(dummy_input, dummy_t)\n", "\n", "print(\"Input shape:\", dummy_input.shape)\n", "print(\"Output shape:\", output.shape)\n", "print(\"\\nModel initialized successfully\")" ] }, { "cell_type": "markdown", "metadata": { "id": "LrzTfkDPJm2X" }, "source": [ "### Implementing the training loop\n", "\n", "Finally, we implement the main training loop for the diffusion model. Each epoch is a single training step on the whole training set, and the loop includes:\n", "\n", "- A progressive timestep sampling strategy that increasingly favors low-noise (small `t`) timesteps as training progresses\n", "- [Exponential moving average (EMA)](https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average) loss tracking\n", "- Fresh Gaussian noise sampled at every training step" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZnQqHCAoVfi1", "outputId": "a105e2de-ba88-44d0-bad5-3a9a69e54826" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0, Loss: 1.2441\n", "Epoch 100, Loss: 1.1178\n", "Epoch 200, Loss: 0.8737\n", "Epoch 300, Loss: 0.7176\n", "Epoch 400, Loss: 0.6327\n", 
"Epoch 500, Loss: 0.5682\n", "Epoch 600, Loss: 0.5024\n", "Epoch 700, Loss: 0.4417\n", "Epoch 800, Loss: 0.3805\n", "Epoch 900, Loss: 0.3254\n", "Epoch 1000, Loss: 0.2803\n", "Epoch 1100, Loss: 0.2534\n", "Epoch 1200, Loss: 0.2339\n", "Epoch 1300, Loss: 0.2221\n", "Epoch 1400, Loss: 0.2141\n", "Epoch 1500, Loss: 0.2085\n", "Epoch 1600, Loss: 0.2046\n", "Epoch 1700, Loss: 0.1991\n", "Epoch 1800, Loss: 0.1951\n", "Epoch 1900, Loss: 0.1923\n", "Epoch 2000, Loss: 0.1919\n", "Epoch 2100, Loss: 0.1913\n", "Epoch 2200, Loss: 0.1888\n", "Epoch 2300, Loss: 0.1858\n", "Epoch 2400, Loss: 0.1861\n", "Epoch 2500, Loss: 0.1867\n", "Epoch 2600, Loss: 0.1855\n", "Epoch 2700, Loss: 0.1832\n", "Epoch 2800, Loss: 0.1834\n", "Epoch 2900, Loss: 0.1839\n", "Epoch 3000, Loss: 0.1844\n", "Epoch 3100, Loss: 0.1838\n", "Epoch 3200, Loss: 0.1816\n", "Epoch 3300, Loss: 0.1824\n", "Epoch 3400, Loss: 0.1815\n", "Epoch 3500, Loss: 0.1823\n", "Epoch 3600, Loss: 0.1834\n", "Epoch 3700, Loss: 0.1823\n", "Epoch 3800, Loss: 0.1811\n", "Epoch 3900, Loss: 0.1806\n", "Epoch 4000, Loss: 0.1804\n", "Epoch 4100, Loss: 0.1814\n", "Epoch 4200, Loss: 0.1802\n", "Epoch 4300, Loss: 0.1813\n", "Epoch 4400, Loss: 0.1799\n", "Epoch 4500, Loss: 0.1811\n", "Epoch 4600, Loss: 0.1820\n", "Epoch 4700, Loss: 0.1829\n", "Epoch 4800, Loss: 0.1828\n", "Epoch 4900, Loss: 0.1832\n", "Epoch 5000, Loss: 0.1827\n", "\n", "Training completed.\n" ] } ], "source": [ "# Initialize training metrics.\n", "losses: List[float] = [] # Store the EMA loss history.\n", "moving_avg_loss: Optional[float] = None # The EMA of the loss value.\n", "beta: float = 0.99 # The EMA decay factor for loss smoothing.\n", "\n", "for epoch in range(num_epochs + 1):\n", " # Split the JAX PRNG key for independent random operations.\n", " key, subkey1 = jax.random.split(key)\n", " key, subkey2 = jax.random.split(key)\n", "\n", " # Progressive timestep sampling: the sampling weights decrease linearly with `t`,\n", " # and the weight of the noisiest timesteps shrinks toward zero as training progresses.\n", " # This helps the model focus on fine details in later epochs while maintaining stability.\n", " progress = epoch / num_epochs\n", " t_weights = jnp.linspace(1.0, 0.1 * (1.0 - progress), num_steps)\n", " t = jax.random.choice(\n", " subkey1,\n", " num_steps,\n", " shape=(images_train.shape[0],),\n", " p=t_weights/t_weights.sum()\n", " )\n", "\n", " # Sample fresh Gaussian noise for every training image (the whole training set is one batch).\n", " noise = jax.random.normal(subkey2, images_train.shape)\n", "\n", " # Execute the training step with noise prediction and parameter updates.\n", " loss = train_step(\n", " model, optimizer, images_train, t, noise,\n", " diffusion.sqrt_alpha_cumulative, diffusion.sqrt_one_minus_alpha_cumulative\n", " )\n", "\n", " # Update the exponential moving average (EMA) of the loss for smoother tracking.\n", " if moving_avg_loss is None:\n", " moving_avg_loss = loss\n", " else:\n", " moving_avg_loss = beta * moving_avg_loss + (1 - beta) * loss\n", "\n", " losses.append(moving_avg_loss)\n", "\n", " # Log the training progress at regular intervals.\n", " if epoch % 100 == 0:\n", " print(f\"Epoch {epoch}, Loss: {moving_avg_loss:.4f}\")\n", "\n", "print(\"\\nTraining completed.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "s-iqch4HKlBV" }, "source": [ "### Training loss visualization\n", "\n", "To visualize the training loss, we can plot the EMA loss history on a logarithmic scale. This makes both the rapid early decrease and the much smaller later improvements visible.\n", "\n", "According to the training log, the EMA loss falls from about 1.24 at epoch 0 to about 0.19 by epoch 2,000, then levels off at roughly 0.18 for the rest of training. This suggests that the optimization has converged. A low training loss alone does not guarantee good samples, so it is also worth inspecting generated images."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 487 }, "id": "1bjvWNCcbN24", "outputId": "457fd13f-377f-4021-ddc2-e36940b42550" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3EAAAHWCAYAAADZ8gAzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABn5ElEQVR4nO3dd3hUZcLG4WcmvRdCEkIJvQQkIF2kCKFERAF1UXBF3JVVg6sia12xi2UtqyLYKRZAFrBQlI4gSA0tEFroJCGU9D7n+wOZz5hQEgJnJvnd15VrnTPvOeeZ8G7gyWkWwzAMAQAAAACcgtXsAAAAAACAS0eJAwAAAAAnQokDAAAAACdCiQMAAAAAJ0KJAwAAAAAnQokDAAAAACdCiQMAAAAAJ0KJAwAAAAAnQokDAAAAACdCiQMAXFH33HOP6tevX6F1n3/+eVkslsoNBKfRs2dP9ezZ0+wYAOBwKHEAUE1ZLJZL+lq+fLnZUU1xzz33yNfX1+wYl8QwDE2bNk3du3dXYGCgvL29dc011+jFF19Udna22fHsDhw4cMnz7sCBA2bHBQCHZTEMwzA7BADg6vvyyy9LvJ46daoWLVqkadOmlVjep08fhYWFVXg/hYWFstls8vDwKPe6RUVFKioqkqenZ4X3X1H33HOPZs2apaysrKu+7/IoLi7WsGHDNHPmTHXr1k1DhgyRt7e3fvnlF3399deKiorS4sWLL+vPsLJkZ2drzpw5JZa99dZbOnLkiN55550SywcPHiw3NzdJkru7+1XLCADOgBIHAJAkjR49WhMmTNDF/lrIycmRt7f3VUplHmcpcePHj9fTTz+tsWPH6s033yzx3g8//KBBgwapb9++WrBgwVXNdanz5KabbtL27ds58gYA5cDplACA8+rZs6datWqljRs3qnv37vL29tbTTz8tSfruu+80YMAARUREyMPDQ40aNdJLL72k4uLiEtv48zVx506p+89//qOPP/5YjRo1koeHhzp06KD169eXWLesa+IsFotGjx6tuXPnqlWrVvLw8FDLli21cOHCUvmXL1+u9u3by9PTU40aNdJHH31U6dfZffvtt2rXrp28vLwUEhKiu+66S0ePHi0xJjk5WSNHjlSdOnXk4eGhWrVq6ZZbbilRXDZs2KB+/fopJCREXl5eatCgge69994L7js3N1dvvvmmmjZtqvHjx5d6f+DAgRoxYoQWLlyotWvXSjpbmho2bFjm9rp06aL27duXWPbll1/aP19wcLDuuOMOHT58uMSYC82Ty/Hna+KWL18ui8WimTNn6oUXXlDt2rXl5+en2267Tenp6crPz9cjjzyi0NBQ+fr6auTIkcrPzy+13Uv5TADgyFzNDgAAcGwnT55UbGys7rjjDt1111320/ImT54sX19fjRkzRr6+vlq6dKnGjRunjIyMUkeEyvL1118rMzNT//jHP2SxWPTGG29oyJAh2r9/v/00uvNZtWqVZs+erQcffFB+fn567733dOutt+rQoUOqUaOGJGnz5s3q37+/atWqpRdeeEHFxcV68cUXVbNmzcv/pvxu8uTJGjlypDp06KDx48crJSVF//3vf7V69Wpt3rxZgYGBkqRbb71VO3bs0EMPPaT69esrNTVVixYt0qFDh+yv+/btq5o1a+rJJ59UYGCgDhw4oNmzZ1/0+3D69Gk9/PDDcnUt+6/0u+++W1988YV+/PFHde7cWUOHDtXdd9+t9evXq0OHDvZxBw8e1Nq
1a0v82b3yyit69tln9Ze//EV///vfdeLECb3//vvq3r17ic8nnX+eXAnjx4+Xl5eXnnzySe3du1fvv/++3NzcZLVadfr0aT3//PNau3atJk+erAYNGmjcuHEV+kwA4LAMAAAMw4iLizP+/NdCjx49DEnGpEmTSo3Pyckptewf//iH4e3tbeTl5dmXjRgxwoiMjLS/TkpKMiQZNWrUME6dOmVf/t133xmSjB9++MG+7LnnniuVSZLh7u5u7N27175sy5YthiTj/fffty8bOHCg4e3tbRw9etS+bM+ePYarq2upbZZlxIgRho+Pz3nfLygoMEJDQ41WrVoZubm59uU//vijIckYN26cYRiGcfr0aUOS8eabb553W3PmzDEkGevXr79orj969913DUnGnDlzzjvm1KlThiRjyJAhhmEYRnp6uuHh4WE89thjJca98cYbhsViMQ4ePGgYhmEcOHDAcHFxMV555ZUS47Zt22a4urqWWH6heXIxAwYMKDE//qhHjx5Gjx497K+XLVtmSDJatWplFBQU2JffeeedhsViMWJjY0us36VLlxLbLs9nAgBHxumUAIAL8vDw0MiRI0st9/Lysv93Zmam0tLS1K1bN+Xk5GjXrl0X3e7QoUMVFBRkf92tWzdJ0v79+y+6bkxMjBo1amR/3bp1a/n7+9vXLS4u1uLFizVo0CBFRETYxzVu3FixsbEX3f6l2LBhg1JTU/Xggw+WuPHKgAED1Lx5c82bN0/S2e+Tu7u7li9frtOnT5e5rXNHf3788UcVFhZecobMzExJkp+f33nHnHsvIyNDkuTv76/Y2FjNnDmzxPWPM2bMUOfOnVWvXj1J0uzZs2Wz2fSXv/xFaWlp9q/w8HA1adJEy5YtK7Gf882TK+Huu+8ucbS2U6dOMgyj1OmnnTp10uHDh1VUVCSp/J8JABwVJQ4AcEG1a9cu8+6AO3bs0ODBgxUQECB/f3/VrFlTd911lyQpPT39ots9VxbOOVfozld0LrTuufXPrZuamqrc3Fw1bty41LiyllXEwYMHJUnNmjUr9V7z5s3t73t4eOj111/XggULFBYWpu7du+uNN95QcnKyfXyPHj1066236oUXXlBISIhuueUWffHFF2Vez/VH5wrauTJXlrKK3tChQ3X48GGtWbNGkrRv3z5t3LhRQ4cOtY/Zs2ePDMNQkyZNVLNmzRJfO3fuVGpqaon9nG+eXAl//vMPCAiQJNWtW7fUcpvNZp+P5f1MAOCouCYOAHBBfzzids6ZM2fUo0cP+fv768UXX1SjRo3k6empTZs26YknnpDNZrvodl1cXMpcblzCTZMvZ10zPPLIIxo4cKDmzp2rn376Sc8++6zGjx+vpUuXqm3btrJYLJo1a5bWrl2rH374QT/99JPuvfdevfXWW1q7du15n1fXokULSdLWrVs1aNCgMsds3bpVkhQVFWVfNnDgQHl7e2vmzJm67rrrNHPmTFmtVt1+++32MTabTRaLRQsWLCjz+/3nTGXNkyvlfH/+F5sX5f1MAOCoKHEAgHJbvny5Tp48qdmzZ6t79+725UlJSSam+n+hoaHy9PTU3r17S71X1rKKiIyMlCQlJiaqV69eJd5LTEy0v39Oo0aN9Nhjj+mxxx7Tnj171KZNG7311lslntfXuXNnde7cWa+88oq+/vprDR8+XNOnT9ff//73MjNcf/31CgwM1Ndff61nnnmmzGIydepUSWfvSnmOj4+PbrrpJn377bd6++23NWPGDHXr1q3EqaeNGjWSYRhq0KCBmjZtWs7vjmOqip8JQPXE6ZQAgHI7Vxb+eOSroKBAH374oVmRSnBxcVFMTIzmzp2rY8eO2Zfv3bu30p6X1r59e4WGhmrSpEklTntcsGCBdu7cqQEDBkg6+7y0vLy8Eus2atRIfn5+9vVOnz5d6ihimzZtJOmCp1R6e3tr7NixSkxM1DPPPFPq/Xnz5mny5Mnq16+fOnfuXOK9oUOH6tixY/r
000+1ZcuWEqdSStKQIUPk4uKiF154oVQ2wzB08uTJ8+ZyVFXxMwGonjgSBwAot+uuu05BQUEaMWKE/vnPf8pisWjatGkOdTrj888/r59//lldu3bVAw88oOLiYn3wwQdq1aqV4uPjL2kbhYWFevnll0stDw4O1oMPPqjXX39dI0eOVI8ePXTnnXfaHzFQv359Pfroo5Kk3bt3q3fv3vrLX/6iqKgoubq6as6cOUpJSdEdd9whSZoyZYo+/PBDDR48WI0aNVJmZqY++eQT+fv768Ybb7xgxieffFKbN2/W66+/rjVr1ujWW2+Vl5eXVq1apS+//FItWrTQlClTSq134403ys/PT2PHjpWLi4tuvfXWEu83atRIL7/8sp566ikdOHBAgwYNkp+fn5KSkjRnzhyNGjVKY8eOvaTvo6Ooip8JQPVEiQMAlFuNGjX0448/6rHHHtO///1vBQUF6a677lLv3r3Vr18/s+NJktq1a6cFCxZo7NixevbZZ1W3bl29+OKL2rlz5yXdPVM6e3Tx2WefLbW8UaNGevDBB3XPPffI29tbr732mp544gn5+Pho8ODBev311+13nKxbt67uvPNOLVmyRNOmTZOrq6uaN2+umTNn2otTjx49tG7dOk2fPl0pKSkKCAhQx44d9dVXX6lBgwYXzOji4qKZM2dq6tSp+vTTT/Xss8+qoKBAjRo10nPPPafHHntMPj4+pdbz9PTUzTffrK+++koxMTEKDQ0tNebJJ59U06ZN9c477+iFF16wf56+ffvq5ptvvqTvoaOpip8JQPVjMRzp16YAAFxhgwYN0o4dO7Rnzx6zowAAUCFcEwcAqLJyc3NLvN6zZ4/mz5+vnj17mhMIAIBKwJE4AECVVatWLd1zzz1q2LChDh48qIkTJyo/P1+bN29WkyZNzI4HAECFcE0cAKDK6t+/v7755hslJyfLw8NDXbp00auvvkqBAwA4NY7EAQAAAIAT4Zo4AAAAAHAilDgAAAAAcCJcE2cim82mY8eOyc/PTxaLxew4AAAAAExiGIYyMzMVEREhq/XCx9oocSY6duyY6tata3YMAAAAAA7i8OHDqlOnzgXHUOJM5OfnJ+nsH5S/v7+pWQoLC/Xzzz+rb9++cnNzMzULnANzBuXFnEF5MWdQXswZlJcjzZmMjAzVrVvX3hEuhBJnonOnUPr7+ztEifP29pa/v7/pExjOgTmD8mLOoLyYMygv5gzKyxHnzKVcZsWNTQAAAADAiVDiAAAAAMCJUOIAAAAAwIlQ4gAAAADAiVDiAAAAAMCJUOIAAAAAwIlQ4gAAAADAiVDiAAAAAMCJUOIAAAAAwIlQ4gAAAADAiVDiAAAAAMCJUOIAAAAAwIlQ4iBJMgxDOUVmpwAAAABwMZQ4SJKmrD2k17e4aOuRdLOjAAAAALgAShxUWGzTtxuO6kyBRXd+tl7fbjhsdiQAAAAA50GJg9xcrJp+X0e1CrKpoMimf83aque+267CYpvZ0QAAAAD8CSUOkiQ/T1f9rZlN/7yhkSRpypqDGv7JbzqRmW9yMgAAAAB/RImDndUiPdSrkT65u718PVy17sAp3fzBKh08mW12NAAAAAC/o8ShlD5RYZob11UNQ3x0PD1P905er/ScQrNjAQAAABAlDufRONRX34zqrFoBntp3IlsPfLWRa+QAAAAAB0CJw3mF+XvqsxEd5O3uol/3ndS477abHQkAAACo9ihxuKCoCH+9f2dbWS3SN+sOa8G242ZHAgAAAKo1ShwuqneLMD3Ys7Ek6dnvdnB9HAAAAGAiShwuyUO9G6thTR+lZeXrtYU7zY4DAAAAVFuUOFwSD1cXvTaktaSzp1Wu3X/S5EQAAABA9USJwyXr2CBYd3asJ0l6es427lYJAAAAmIASh3J5Mra5Qnzdtf9EtmZvOmJ2HAAAAKDaocShXAK83HR/j0aSpPeX7uVoHAAAAHCVUeJQbsM7RSrE10NHTudyNA4AAAC4yihxKDc
vdxfd36OhJOnTX5JkGIbJiQAAAIDqgxKHChnaoa583F20JzVLv+7jTpUAAADA1UKJQ4X4ebppyLV1JElTfj1gbhgAAACgGqHEocL+2iVSkrR0V6rSsvJNTgMAAABUD5Q4VFjTMD9F1wlQkc3Q3M1HzY4DAAAAVAuUOFyW29rXlSTN2niEG5wAAAAAVwElDpfl5tYRcne1aldyprYfzTA7DgAAAFDlUeJwWQK83dSvZbgkadbGwyanAQAAAKo+Shwu2+3tzt6lcm78MeUXFZucBgAAAKjaKHG4bF0bhyjc31PpuYVannjC7DgAAABAlUaJw2VzsVo0MLqWJOn7+GMmpwEAAACqNkocKsXN0bUlSYt3pigrv8jkNAAAAEDVRYlDpWhV218NQ3yUX2TTkp0pZscBAAAAqixKHCqFxWJR7DVn71K5YFuyyWkAAACAqosSh0oT2+rsdXHLd6cqp4BTKgEAAIArgRKHStMywl91gryUV2jTCu5SCQAAAFwRlDhUGovFothWv59SuZ1TKgEAAIArgRKHStX/91Mql+5K5cHfAAAAwBVAiUOlals3UGH+HsrKL9KqPWlmxwEAAACqHErcZfrxxx/VrFkzNWnSRJ9++qnZcUxntVrUv+XZUyp/3sGjBgAAAIDKRom7DEVFRRozZoyWLl2qzZs3680339TJkyfNjmW6PlFnS9ySXamy2QyT0wAAAABVCyXuMqxbt04tW7ZU7dq15evrq9jYWP38889mxzJdxwbB8vNwVVpWvrYcOWN2HAAAAKBKqdYlbuXKlRo4cKAiIiJksVg0d+7cUmMmTJig+vXry9PTU506ddK6devs7x07dky1a9e2v65du7aOHj16NaI7NHdXq7o3qylJWrIz1eQ0AAAAQNVSrUtcdna2oqOjNWHChDLfnzFjhsaMGaPnnntOmzZtUnR0tPr166fUVIrJxcS0CJUkLd7JdXEAAABAZXI1O4CZYmNjFRsbe9733377bd13330aOXKkJGnSpEmaN2+ePv/8cz355JOKiIgoceTt6NGj6tix43m3l5+fr/z8fPvrjIwMSVJhYaEKCwsv9+NclnP7r6wc1zcMlovVol3Jmdqfmq66Qd6Vsl04jsqeM6j6mDMoL+YMyos5g/JypDlTngwWwzC484TOPqh6zpw5GjRokCSpoKBA3t7emjVrln2ZJI0YMUJnzpzRd999p6KiIrVo0ULLly9XQECA2rVrp19//VU1atQocx/PP/+8XnjhhVLLv/76a3l7V72S8/4OF+3NsGhI/WL1qMU0AwAAAM4nJydHw4YNU3p6uvz9/S84tlofibuQtLQ0FRcXKywsrMTysLAw7dq1S5Lk6uqqt956SzfccINsNpsef/zx8xY4SXrqqac0ZswY++uMjAzVrVtXffv2vegf1JVWWFioRYsWqU+fPnJzc6uUbaYEHtSrCxJ13FpTN97YvlK2CcdxJeYMqjbmDMqLOYPyYs6gvBxpzpw7S+9SUOIu080336ybb775ksZ6eHjIw8Oj1HI3NzfTJ805lZmlX6taenVBotYdOK2cQinA2zE+IyqXI81fOAfmDMqLOYPyYs6gvBxhzpRn/9X6xiYXEhISIhcXF6WklLwxR0pKisLDw01K5Vwia/ioaZivim2Glu/mZjAAAABAZaDEnYe7u7vatWunJUuW2JfZbDYtWbJEXbp0MTGZc+kTdfZ01J8TuEslAAAAUBmqdYnLyspSfHy84uPjJUlJSUmKj4/XoUOHJEljxozRJ598oilTpmjnzp164IEHlJ2dbb9bJS4upsXZErci8YTyi4pNTgMAAAA4v2p9TdyGDRt0ww032F+fu+nIiBEjNHnyZA0dOlQnTpzQuHHjlJycrDZt2mjhwoWlbnaC84uuE6hQPw+lZubrt/2n1L1pTbMjAQAAAE6tWpe4nj176mJPWBg9erRGjx59lRJVPVarRb1bhOmbdYe0KCGFEgcAAABcpmp9OiWujj5RoZK
kxTtTLlqaAQAAAFwYJQ5X3HWNQuTl5qLj6XnacezSn38BAAAAoDRKHK44TzcXdW8aIom7VAIAAACXixKHq6JP1Nln6y2mxAEAAACXhRKHq6JX81BZLVLC8QwdOZ1jdhwAAADAaVHicFUE+7irfWSwJI7GAQAAAJeDEoerpk/U2efrLd6ZanISAAAAwHlR4nDVxPxe4tbuP6n03EKT0wAAAADOiRKHq6ZBiI8ah/qqyGZoeSJH4wAAAICKoMThquKUSgAAAODyUOJwVcW0OFvilu9KVUGRzeQ0AAAAgPOhxJlgwoQJioqKUocOHcyOctW1rRuoEF8PZeYXaV3SKbPjAAAAAE6HEmeCuLg4JSQkaP369WZHueqsVotiWoRKkhYlJJucBgAAAHA+lDhcdedOqVyUkCLDMExOAwAAADgXShyuuuubhMjTzapj6XlKOJ5hdhwAAADAqVDicNV5urmoW5Oaks4ejQMAAABw6ShxMMW5Rw1Q4gAAAIDyocTBFL2bh8pikXYcy9CxM7lmxwEAAACcBiUOpqjh66F29YIkcTQOAAAAKA9KHEzTr2W4JGnetuMmJwEAAACcByUOphnQupYkaf2BUzqezimVAAAAwKWgxME0EYFe6lA/SIYhzdvK0TgAAADgUlDiYKqB0RGSpB+2HDM5CQAAAOAcKHEwVWyrWrJapC1H0nXwZLbZcQAAAACHR4mDqWr6eahr4xBJ0o+cUgkAAABcFCUOphvY+uwpld/Hc0olAAAAcDGUOJiuX8twublYlJiSqcTkTLPjAAAAAA6NEgfTBXi7qUfTmpKkH7dyNA4AAAC4EEocHMK5u1R+v+WYDMMwOQ0AAADguChxcAh9osLk4+6igydztOHgabPjAAAAAA6LEgeH4O3uqhuvqSVJmrXhiMlpAAAAAMdFiYPDuK1dHUnSvG3HlVtQbHIaAAAAwDFR4kwwYcIERUVFqUOHDmZHcSgd6gerbrCXsvKL9NOOZLPjAAAAAA6JEmeCuLg4JSQkaP369WZHcShWq0W3Xnv2aNzMDYdNTgMAAAA4JkocHMrt7evKapF+3XdSe1OzzI4DAAAAOBxKHBxK7UAv9W4RJkn6+rdDJqcBAAAAHA8lDg7nzo51JUlz44+qoMhmchoAAADAsVDi4HC6N6mpUD8Pncou0OKdKWbHAQAAABwKJQ4Ox9XFan/cwDfrOKUSAAAA+CNKHBzS0A5nT6lctTdNh0/lmJwGAAAAcByUODikyBo+6tq4hgxD+nLtQbPjAAAAAA6DEgeHNfK6BpLO3qUyI6/Q5DQAAACAY6DEwWH1ah6qJqG+yswv4nEDAAAAwO8ocXBYVqtFo7o3lCR9vipJ+UXFJicCAAAAzEeJg0O7pU1thft7KjUzX3M3HzU7DgAAAGA6ShwcmrurVX+7/uy1cR+t3C+bzTA5EQAAAGAuShwc3p2d6snP01X7T2RrEQ//BgAAQDVHiYPD8/Vw1V87R0qSJq3YJ8PgaBwAAACqL0ocnMI9XevL3dWqzYfOaP2B02bHAQAAAExDiYNTCPXz1K3X1pEkTVy+1+Q0AAAAgHkocXAa/+jeUFaLtCzxhDYd4mgcAAAAqidKHJxG/RAfDfn9aNx/F+8xOQ0AAABgDkocnMo/ezWR1SKt2H1CCccyzI4DAAAAXHWUODiVejW8NaB1hCTplfkJ3KkSAAAA1Q4lDk7nX32bycPVqtV7T+rHrcfNjgMAAABcVZQ4OJ16Nbz1QM9GkqTXF+5SflGxyYkAAACAq4cSZ4IJEyYoKipKHTp0MDuK0xrVvaHC/D105HSupvx6wOw4AAAAwFVDiTNBXFycEhIStH79erOjOC1vd1c91reZJOn9pXt1OrvA5EQAAADA1UGJg9O69do6ah7up8y8Ir23lEcOAAAAoHqgxMFpuVgt+veAKEnStDUHlZSWbXIiAAAA4MqjxMGpXd8kRDc0q6kim6HXF+wyOw4AAABwxVHi4PSeurGFrBZp4Y5
krUs6ZXYcAAAA4IqixMHpNQ3z0x0d60mSXpmXIJuNB4ADAACg6qLEoUp4NKapfNxdtOVIuv636YjZcQAAAIArhhKHKqGmn4dG92oiSRq/YJdSMvJMTgQAAABcGZQ4VBn3Xl9fLSP8dSq7QGNmxnNaJQAAAKokShyqDA9XF713Z1t5ublo9d6T+mjlfrMjAQAAAJWOEocqpVFNX71wc0tJ0ls/Jyr+8BlzAwEAAACVjBKHKuf29nU0oHUtFdkM/fObzcrMKzQ7EgAAAFBpKHGociwWi14dfI1qB3rp0KkcPTNnuwyD6+MAAABQNVDiUCUFeLnpvTvbyMVq0fdbjunTX5LMjgQAAABUCkocqqx2kcH694AWkqQ3f0pUYnKmyYkAAACAy0eJQ5V2z3X11bt5qAqKbRr77RYVFdvMjgQAAABcFkocqjSLxaJXh1wjf09XbTuazmMHAAAA4PQocajywvw9NW7g2ccOvL1ot7YfTTc5EQAAAFBxlDhUC7deW1uxrcJVbDP05OytyissNjsSAAAAUCGUOFQLFotFzw1sqSBvN20/mqF/z+WxAwAAAHBOlDhUG+EBnvpg2LWyWqRZG49o8q8HzI4EAAAAlBslrhINHjxYQUFBuu2228yOgvPo2jhET8Y2lyS9PG+n1iWdMjkRAAAAUD6UuEr08MMPa+rUqWbHwEXc162hBrWJULHN0KhpG7QrOcPsSAAAAMAlo8RVop49e8rPz8/sGLiIc48daBnhrzM5hXrwq03Kzi8yOxYAAABwSRyixB09elR33XWXatSoIS8vL11zzTXasGFDpW1/5cqVGjhwoCIiImSxWDR37twyx02YMEH169eXp6enOnXqpHXr1lVaBjgWb3dXTftbJ4X7e2r/iWyN+26H2ZEAAACAS2J6iTt9+rS6du0qNzc3LViwQAkJCXrrrbcUFBRU5vjVq1ersLCw1PKEhASlpKSUuU52draio6M1YcKE8+aYMWOGxowZo+eee06bNm1SdHS0+vXrp9TUVPuYNm3aqFWrVqW+jh07Vs5PDUcQ7OOu94e1lcUi/W/TEa3em2Z2JAAAAOCiXM0O8Prrr6tu3br64osv7MsaNGhQ5libzaa4uDg1adJE06dPl4uLiyQpMTFRvXr10pgxY/T444+XWi82NlaxsbEXzPH222/rvvvu08iRIyVJkyZN0rx58/T555/rySeflCTFx8dX5CPCgXWoH6y7OkVq2tqDuv/Ljfph9PWqH+JjdiwAAADgvEw/Evf999+rffv2uv322xUaGqq2bdvqk08+KXOs1WrV/PnztXnzZt19992y2Wzat2+fevXqpUGDBpVZ4C5FQUGBNm7cqJiYmBL7iomJ0Zo1ayq0zQuZMGGCoqKi1KFDh0rfNsrv8f7NdE3tAGXmFWnst1tUWGwzOxIAAABwXqaXuP3792vixIlq0qSJfvrpJz3wwAP65z//qSlTppQ5PiIiQkuXLtWqVas0bNgw9erVSzExMZo4cWKFM6Slpam4uFhhYWElloeFhSk5OfmStxMTE6Pbb79d8+fPV506dc5bAOPi4pSQkKD169dXODMqj5+nmybeda183F204eBpPTV7Gw8CBwAAgMMy/XRKm82m9u3b69VXX5UktW3bVtu3b9ekSZM0YsSIMtepV6+epk2bph49eqhhw4b67LPPZLFYrmbsMi1evNjsCKigOkHeeu/Otrpv6gbN2nhE7SODdEfHembHAgAAAEox/UhcrVq1FBUVVWJZixYtdOjQofOuk5KSolGjRmngwIHKycnRo48+elkZQkJC5OLiUurGKCkpKQoPD7+sbcN59G4Rpsf7n30Q+Is/JujYmVyTEwEAAAClmV7iunbtqsTExBLLdu/ercjIyDLHp6WlqXfv3mrRooVmz56tJUuWaMaMGRo7dmyFM7i7u6tdu3ZasmSJfZnNZtOSJUvUpUuXCm8Xzue+bg3VPjJIOQXFeuEHHjsAAAAAx2N6iXv00Ue
1du1avfrqq9q7d6++/vprffzxx4qLiys11mazKTY2VpGRkZoxY4ZcXV0VFRWlRYsW6YsvvtA777xT5j6ysrIUHx9vv7tkUlKS4uPjSxztGzNmjD755BNNmTJFO3fu1AMPPKDs7Gz73SpRPbhYLXpl8DVysVr0044ULd1V9mMrAAAAALOYfk1chw4dNGfOHD311FN68cUX1aBBA7377rsaPnx4qbFWq1WvvvqqunXrJnd3d/vy6OhoLV68WDVr1ixzHxs2bNANN9xgfz1mzBhJ0ogRIzR58mRJ0tChQ3XixAmNGzdOycnJatOmjRYuXFjqZieo+pqF++lv1zfQxyv369m5O9ThkWD5ebqZHQsAAACQ5AAlTpJuuukm3XTTTZc0tk+fPmUub9u27XnX6dmz5yXdbXD06NEaPXr0JeVA1fZw7yZasP24Dp/K1Qs/JOg/t0ebHQkAAACQ5ACnUwKOyMfDVW/d3kYWizRr4xH9sueE2ZEAAAAASZQ44Lw6NgjWiC71JUn/nrtdeYXF5gYCAAAARIkDLmhsv2YK8/fQwZM5+njlfrPjAAAAAJQ44EJ8PVz19I0tJEnvL92jzYdOm5wIAAAA1R0lDriIm6MjFNsqXIXFhh6ftVWFxTazIwEAAKAao8QBF2GxWDR+yDWq4eOuPalZ+mxVktmRAAAAUI1R4oBLEOjtrqd+P63yv4v3KCUjz+REAAAAqK4occAluvXa2rq2XqByC4v17uI9ZscBAABANUWJAy6RxWKxH42bsf6QVu1JMzkRAAAAqiNKHFAOHeoH644OdWUzpCf+t1U5BUVmRwIAAEA1Q4kDymncwCjVCfLS0TO5mrR8n9lxAAAAUM1Q4oBy8nb//2fHfbh8nxKOZZicCAAAANUJJQ6ogNhW4erVPFRFNkPPzN0mm80wOxIAAACqCUocUAEWi0WvDG4lH3cXbT50RtPXHzY7EgAAAKoJShxQQbUCvDSmbzNJ0msLdiotK9/kRAAAAKgOKHHAZRjRJVJRtfyVkVekV+ftNDsOAAAAqgFKHHAZXF2semVwK1ks0uzNR/XrPp4dBwAAgCuLEgdcprb1gjS8Uz1J0r/nbldeYbHJiQAAAFCVUeKASvCvfs0V4uuh/Sey9fai3WbHAQAAQBVGiQMqQYCXm14bco0k6ZNf9mtd0imTEwEAAKCqosQBlSQmKky3t6sjw5DGfrtF2flFZkcCAABAFUSJAyrRuIFRqh3opUOncjRpxT6z4wAAAKAKosQBlcjP003P3hQlSZq4fJ/2pmaanAgAAABVDSUOqGT9Woapbb1AFdkMjZq6kbtVAgAAoFJVqMQdPnxYR44csb9et26dHnnkEX388ceVFgxwVhaLRe/d0Va+Hq7an5atCcv2mh0JAAAAVUiFStywYcO0bNkySVJycrL69OmjdevW6ZlnntGLL75YqQGrogkTJigqKkodOnQwOwqukLrB3nrjttaSpEkr9ungyWyTEwEAAKCqqFCJ2759uzp27ChJmjlzplq1aqVff/1VX331lSZPnlyZ+aqkuLg4JSQkaP369WZHwRUU2ypc3ZqEqLDY0GMzt6io2GZ2JAAAAFQBFSpxhYWF8vDwkCQtXrxYN998sySpefPmOn78eOWlA5yYxWLRy4Nayc/DVRsOntbMDUcuvhIAAABwERUqcS1bttSkSZP0yy+/aNGiRerfv78k6dixY6pRo0alBgScWWQNHz3ap6kk6e1FicrMKzQ5EQAAAJxdhUrc66+/ro8++kg9e/bUnXfeqejoaEnS999/bz/NEsBZd3WOVIMQH6VlFfDsOAAAAFw214qs1LNnT6WlpSkjI0NBQUH25aNGjZK3t3elhQOqAndXq56+sYXum7pBn/ySpNvb1VX9EB+zYwEAAMBJVehIXG5urvLz8+0F7uDBg3r33XeVmJio0NDQSg0IVAUxLULVtXENFRTZ9NA3m1XITU4AAABQQRUqcbfccoumTp0qSTpz5ow6deq
kt956S4MGDdLEiRMrNSBQFVgsFv3n9mgFertp29F0Tfn1gNmRAAAA4KQqVOI2bdqkbt26SZJmzZqlsLAwHTx4UFOnTtV7771XqQGBqqJWgJee6N9ckvTfxXt0IjPf5EQAAABwRhUqcTk5OfLz85Mk/fzzzxoyZIisVqs6d+6sgwcPVmpAoCr5S/u6uqZ2gDLzi/TmT7vMjgMAAAAnVKES17hxY82dO1eHDx/WTz/9pL59+0qSUlNT5e/vX6kBgarExWrR8ze3lCTN3HBE8YfPmBsIAAAATqdCJW7cuHEaO3as6tevr44dO6pLly6Szh6Va9u2baUGBKqadpFBGnJtbUnSk//bqrzCYpMTAQAAwJlUqMTddtttOnTokDZs2KCffvrJvrx379565513Ki0cUFU9Gdtcgd5u2pWcqTcWJpodBwAAAE6kQiVOksLDw9W2bVsdO3ZMR44ckSR17NhRzZs3r7RwQFUV6uepN25tLUn64tckrdqTZnIiAAAAOIsKlTibzaYXX3xRAQEBioyMVGRkpAIDA/XSSy/JZuP5V8Cl6NsyXHd2rCfDkF6elyCbzTA7EgAAAJxAhUrcM888ow8++ECvvfaaNm/erM2bN+vVV1/V+++/r2effbayMwJV1hP9m8nPw1W7kjM1c8Nhs+MAAADACbhWZKUpU6bo008/1c0332xf1rp1a9WuXVsPPvigXnnllUoLCFRlgd7u+kePhvrPz7v1wg8Jur5JiOoEeZsdCwAAAA6sQkfiTp06Vea1b82bN9epU6cuOxRQnYzq3kiNQ32VW1is95fsNTsOAAAAHFyFSlx0dLQ++OCDUss/+OADtW7d+rJDAdWJu6tVr/9+k5NZm44oKS3b5EQAAABwZBU6nfKNN97QgAEDtHjxYvsz4tasWaPDhw9r/vz5lRoQqA7aRQapZ7OaWp54QmO/3aJZ93eRxWIxOxYAAAAcUIWOxPXo0UO7d+/W4MGDdebMGZ05c0ZDhgzRjh07NG3atMrOCFQLL93SSj7uLtp48LSWJaaaHQcAAAAOqkJH4iQpIiKi1A1MtmzZos8++0wff/zxZQcDqpu6wd4a3jlSH6/cr/8u3qMeTUPlYuVoHAAAAEqq8MO+AVS+v3drIF8PV205kq6paw6YHQcAAAAOiBIHOJBQP089EXv2zq//+SlRyel5JicCAACAo6HEAQ5meMd6alsvUNkFxXp5XoLZcQAAAOBgynVN3JAhQy74/pkzZy4nCwBJVqtFLw9qpYHvr9KPW49raIcT6takptmxAAAA4CDKdSQuICDggl+RkZG6++67r1RWoNpoGRGgEdfVlyQ9O3e78gqLzQ0EAAAAh1GuI3FffPHFlcoB4E/G9GmqeVuP68DJHL06f6devKWV2ZEAAADgALgmDnBQfp5ueu3WayRJ09YeVPzhM+YGAgAAgEOgxAEOrFfzMPWNCpNhSJ/+st/sOAAAAHAAlDjAwT0S01SStGB7so6n55qcBgAAAGajxAEOLirCX50aBKvYZmjamoNmxwEAAIDJKHGAExjZtYEk6avfDik9t9DkNAAAADATJQ5wAn2iwtQ41FfpuYWa8usBs+MAAADARJS4SjR48GAFBQXptttuMzsKqhgXq0UP9WosSfp45X6lZOSZnAgAAABmocRVoocfflhTp041OwaqqJtaRyi6bqCy8os0Ydles+MAAADAJJS4StSzZ0/5+fmZHQNVlIvVoif6N5MkTV9/mKNxAAAA1ZRDlbjXXntNFotFjzzySKVud+XKlRo4cKAiIiJksVg0d+7cMsdNmDBB9evXl6enpzp16qR169ZVag7gcnVpWEPtI4NUUGTTJyt5bhwAAEB15DAlbv369froo4/UunXrC45bvXq1CgtL350vISFBKSkpZa6TnZ2t6OhoTZgw4bzbnTFjhsaMGaPnnntOmzZtUnR0tPr166fU1FT7mDZt2qhVq1alvo4dO3aJnxK4PBaLRaN/vzbuq98O6WR
WvsmJAAAAcLU5RInLysrS8OHD9cknnygoKOi842w2m+Li4jRs2DAVFxfblycmJqpXr16aMmVKmevFxsbq5Zdf1uDBg8+77bffflv33XefRo4cqaioKE2aNEne3t76/PPP7WPi4+O1ffv2Ul8REREV+NRAxfRoWlPX1A5QbmGxXvghwew4AAAAuMocosTFxcVpwIABiomJueA4q9Wq+fPna/Pmzbr77rtls9m0b98+9erVS4MGDdLjjz9eof0XFBRo48aNJfZvtVoVExOjNWvWVGibFzJhwgRFRUWpQ4cOlb5tVH0Wi0Uv3NJSVov0/ZZj+nVfmtmRAAAAcBWZXuKmT5+uTZs2afz48Zc0PiIiQkuXLtWqVas0bNgw9erVSzExMZo4cWKFM6Slpam4uFhhYWElloeFhSk5OfmStxMTE6Pbb79d8+fPV506dc5bAOPi4pSQkKD169dXODOqt2vrBWloh3qSpJd+3CnDMExOBAAAgKvF1cydHz58WA8//LAWLVokT0/PS16vXr16mjZtmnr06KGGDRvqs88+k8ViuYJJL83ixYvNjoBqZGzfppqz+Yh2Hs/Qb0mn1LlhDbMjAQAA4Cow9Ujcxo0blZqaqmuvvVaurq5ydXXVihUr9N5778nV1bXEdW9/lJKSolGjRmngwIHKycnRo48+elk5QkJC5OLiUurGKCkpKQoPD7+sbQNXSg1fDw25to4k6Y2Fu0xOAwAAgKvF1BLXu3dvbdu2TfHx8fav9u3ba/jw4YqPj5eLi0upddLS0tS7d2+1aNFCs2fP1pIlSzRjxgyNHTu2wjnc3d3Vrl07LVmyxL7MZrNpyZIl6tKlS4W3C1xpD/duIncXqzYdOqPdKZlmxwEAAMBVYOrplH5+fmrVqlWJZT4+PqpRo0ap5dLZYhUbG6vIyEjNmDFDrq6uioqK0qJFi9SrVy/Vrl27zKNyWVlZ2rt3r/11UlKS4uPjFRwcrHr1zl5XNGbMGI0YMULt27dXx44d9e677yo7O1sjR46s5E8NVJ4wf091bxqixTtT9cXqAxo/5BqzIwEAAOAKM7XElZfVatWrr76qbt26yd3d3b48OjpaixcvVs2aNctcb8OGDbrhhhvsr8eMGSNJGjFihCZPnixJGjp0qE6cOKFx48YpOTlZbdq00cKFC0vd7ARwNKO6N9Linan636YjGtu3qWr4epgdCQAAAFeQw5W45cuXX/D9Pn36lLm8bdu2512nZ8+el3T3vtGjR2v06NEXHQc4kg71g9S6ToC2HknXV78d0j97NzE7EgAAAK4g0x8xAODyWCwW/e36BpKkz1cnKTOv0OREAAAAuJIocUAVcFPrCDWs6aMzOYX6YvUBs+MAAADgCqLEAVWAi9WiR2KaSpI++WW/0nM4GgcAAFBVUeKAKuKma2qpaZivMvOK9Omq/WbHAQAAwBVCiQOqCKvVokd/Pxr3+aoknc4uMDkRAAAArgRKHFCF9GsZrqha/souKNZHKzkaBwAAUBVR4oAqxGq16NE+Z4/GTV1zQGdyOBoHAABQ1VDigCompkWoWtTyV05Bsb767ZDZcQAAAFDJKHFAFWOxWPTXzpGSpA+W7lVaVr7JiQAAAFCZKHFAFXR7+zpqHu6n3MJiTVtz0Ow4AAAAqESUOKAKcnOxanSvxpLOXhuXU1BkciIAAABUFkocUEX1bxmuesHeOp1TqP9tOmp2HAAAAFQSShxQRbm6WHXPdfUlSZ+s3K/CYpu5gQAAAFApKHFAFfaXDnUV4OWmQ6dy9MXqJLPjAAAAoBJQ4oAqzNfDVf/q10yS9O7iPTqRyZ0qAQAAnB0lDqjihneqp+i6gcopKNaEZXvNjgMAAIDLRIkDqjiLxaInfj8a9/Vvh3T0TK7JiQAAAHA5KHFANXBd4xB1aVhDBcU2nhsHAADg5ChxQDUx4rpISdKsjUeUV1hschoAAABUFCUOqCZ6NQ9TRICn0rLyNX3dIbPjAAAAoIIocUA14e5
qVVyvxpKkD5fv42gcAACAk6LEAdXI7e3qKiLAU6mZHI0DAABwVpQ4oBrhaBwAAIDzo8QB1QxH4wAAAJwbJQ6oZjgaBwAA4NwocUA19MejcTw3DgAAwLlQ4oBqyN3Vqn/2biJJemfxbp3JKTA5EQAAAC4VJQ6opoZ2qKsWtfyVU1CsKb9yNA4AAMBZUOKAaspiseiBno0kSdPWHuDaOAAAACdBiQOqsdhW4YoI8FRaVoHmbD5qdhwAAABcAkocUI25uVh17/UNJEkfLN3L0TgAAAAnQIkDqrm7Okcq3N9TR8/k6su1XBsHAADg6ChxQDXn6eaiR/ucvVPlRJ4bBwAA4PAocQB067V1VDvQSyezCzR7E9fGAQAAODJKHAC5/uHauM9XJ8kwDJMTAQAA4HwocQAkSX9pX0c+7i7am5qllXvSzI4DAACA86DEAZAk+Xm66fb2dSVJT8/epuz8IpMTAQAAoCyUOAB2j8Y0VZ0gLx09k6tPf0kyOw4AAADKQImrRIMHD1ZQUJBuu+02s6MAFRLg7aYn+jeXJH2wbI82HjxtciIAAAD8GSWuEj388MOaOnWq2TGAy3JT61rqGxWmwmJDL/6YYHYcAAAA/AklrhL17NlTfn5+ZscALovFYtHLg1vJxWrRlsNnNHvTEbMjAQAA4A9ML3ETJ05U69at5e/vL39/f3Xp0kULFiyo1H2sXLlSAwcOVEREhCwWi+bOnVvmuAkTJqh+/fry9PRUp06dtG7dukrNATiLUD9P/bPX2QeAv/XzbuUX8QBwAAAAR2F6iatTp45ee+01bdy4URs2bFCvXr10yy23aMeOHWWOX716tQoLC0stT0hIUEpKSpnrZGdnKzo6WhMmTDhvjhkzZmjMmDF67rnntGnTJkVHR6tfv35KTU21j2nTpo1atWpV6uvYsWPl/NSA4/tHj4YK9fPQ0TO5mrmBo3EAAACOwvQSN3DgQN14441q0qSJmjZtqldeeUW+vr5au3ZtqbE2m01xcXEaNmyYiov//8hAYmKievXqpSlTppS5j9jYWL388ssaPHjweXO8/fbbuu+++zRy5EhFRUVp0qRJ8vb21ueff24fEx8fr+3bt5f6ioiIuIzvAOCYPN1c9EDPRpKkKb8e4AHgAAAADsL0EvdHxcXFmj59urKzs9WlS5dS71utVs2fP1+bN2/W3XffLZvNpn379qlXr14aNGiQHn/88Qrtt6CgQBs3blRMTEyJfcXExGjNmjUV/jznM2HCBEVFRalDhw6Vvm2gMt3Wro683M4+APy3pFNmxwEAAIAcpMRt27ZNvr6+8vDw0P333685c+YoKiqqzLERERFaunSpVq1apWHDhqlXr16KiYnRxIkTK7z/tLQ0FRcXKywsrMTysLAwJScnX/J2YmJidPvtt2v+/PmqU6fOeQtgXFycEhIStH79+gpnBq4GP083DWpbW5I0dc0Bc8MAAABAkuRqdgBJatasmeLj45Wenq5Zs2ZpxIgRWrFixXmLXL169TRt2jT16NFDDRs21GeffSaLxXKVU5e2ePFisyMAlW7EdZH6Zt0h/bQjRcfO5Coi0MvsSAAAANWaQxyJc3d3V+PGjdWuXTuNHz9e0dHR+u9//3ve8SkpKRo1apQGDhyonJwcPfroo5e1/5CQELm4uJS6MUpKSorCw8Mva9uAs2se7q/ODYNVbDM0dc1Bs+MAAABUew5R4v7MZrMpPz+/zPfS0tLUu3dvtWjRQrNnz9aSJUs0Y8YMjR07tsL7c3d3V7t27bRkyZISGZYsWVLmtXlAdfO36xtKOnuDk13JGSanAQAAqN5MP53yqaeeUmxsrOrVq6fMzEx9/fXXWr58uX766adSY202m2JjYxUZGakZM2bI1dVVUVFRWrRokXr16qXatWuXeVQuKytLe/futb9OSkpSfHy8goODVa9ePUnSmDFjNGLECLVv314dO3bUu+++q+zsbI0cOfLKfXjAScS0CFXXxjW0eu9JPT1
7m2Y/2NXsSAAAANWW6SUuNTVVd999t44fP66AgAC1bt1aP/30k/r06VNqrNVq1auvvqpu3brJ3d3dvjw6OlqLFy9WzZo1y9zHhg0bdMMNN9hfjxkzRpI0YsQITZ48WZI0dOhQnThxQuPGjVNycrLatGmjhQsXlrrZCVAdWSwWvfOXNury2lJtOnRGe1MzFRnkaXYsAACAasn0EvfZZ5+Va3xZ5U6S2rZte951evbseUnPuBo9erRGjx5drjxAdRHq76lezUO1KCFFb/28W+8NbW12JAAAgGrJIa+JA+CYHo1pKlerRQu2J+unHSkXXwEAAACVjhIH4JJFRfjrHz3O3uTk5fm7lF9sciAAAIBqiBIHoFwe6tVEtQO9lJyRrwWH+RECAABwtfEvMADl4unmopcGtZQkLT9u0f4T2SYnAgAAqF4ocQDKrVfzMN3QLESGLHrz591mxwEAAKhWKHEAKmRsnyayytDiXSe0KIGbnAAAAFwtlDgAFdI0zE/da519dMd/fkpUUbHN5EQAAADVAyUOQIX1rW1ToJebElMyNXXNQbPjAAAAVAuUOAAV5uMmPRrTWJL02oJd2puaaXIiAACAqo8SB+Cy3NG+jjrUD1JBsU0xb6/U/hNZZkcCAACo0ihxAC6L1WrRhOHX2l/3emuFktPzTEwEAABQtVHiAFy2UD9PffzXdvbX//hyo2w2w8REAAAAVRclDkCl6NsyXBOGnT0it+XwGU1Zc8DcQAAAAFUUJQ5ApRnQupZeuqWlJOnV+Tu1LumUyYkAAACqHkocgEp1V+dI9WsZpsJiQ3/5aI2OnM4xOxIAAECVQokDUKksFovevD1avh6ukqT3luwxOREAAEDVQokDUOn8Pd008a6z18fN3HBEvd9arsJim8mpAAAAqgZKHIAroluTmvr3gBaSpH0nsvXGwl0yDO5YCQAAcLkocQCumL93a6jRNzSWJH3yS5L+y6mVAAAAl40SB+CKeqxvUz3Ys5Ek6d3Fe/T2z4kmJwIAAHBulDgAV5TFYtG/+jXTHR3qSpLeW7pXv+0/aXIqAAAA50WJA3DFWSwWvXBLS7laLZKk+6Zu0MaDp01OBQAA4JwocQCuCg9XFy18pLskKSOvSMM+WavFCSkmpwIAAHA+lDgAV03jUF9tfraPujUJUX6RTaOmbdCK3SfMjgUAAOBUKHEArqogH3d9fk8HDWoTIZsh/fObzcrIKzQ7FgAAgNOgxAG46txcrHr9ttaqE+Sl9NxCtX7+Z63lZicAAACXhBIHwBQeri56aVAr++t3F+82MQ0AAIDzoMQBMM0NzUL186Nnb3aydv8pLdnJjU4AAAAuhhIHwFRNw/w0rFM9SdLfpmxQ3NebZBiGyakAAAAcFyUOgOkeiWli/+95W49r2Ce/yWajyAEAAJSFEgfAdKF+ntr1Un91aVhDkrRm/0m9NC+BI3IAAABloMQBcAiebi76ZlRnvTu0jSTpi9UHNHHFPnNDAQAAOCBKHACHMqhtbf17QAtJ0hsLEzVj/SGTEwEAADgWShwAh/P3bg11b9cGkqQn/rdNvf6zXJsOnTY5FQAAgGOgxAFwSE/f2Fw9mtaUJO1Py1bcV5uUnJ5ncioAAADzUeIAOCRXF6smj+yg/z3QRe4uVh1Pz1Pc15vMjgUAAGA6ShwAh2WxWNQuMljfjOosSdp48LRi3l6hYh4/AAAAqjFKHACH1y4ySLGtwiVJe1Oz1Ojp+Tp8KsfkVAAAAOagxAFwChPvaqfb2tWxv+72xjKt2pNmYiIAAABzUOIAOI03b2ute66rb39912e/adCE1TqZlW9eKAAAgKuMEgfAaVgsFj1/c0ute7q3fVn84TO66f1V2puaZWIyAACAq4cSB8DphPp7atdL/fVE/+bydDt758qYt1fopR8TuOkJAACo8ihxAJySp5uLHujZSEsf66kaPu6SpM9WJanR0/M1dc0BGQZlDgAAVE2UOABOLSLQS28PbVNi2bjvdui
eL9ZT5AAAQJVEiQPg9Ho0rakDrw3QB8Pa2pet2H1CDZ6ar1fmJejIaR5HAAAAqg5XswMAQGW5qXWEbmodoTcW7tKHy/dJkj75JUmf/JKkIdfWVlQtfxXbDP3t+gZydeF3WAAAwDlR4gBUOY/3b67rm4Ro2Ce/2ZfN3nRUs3VUkjR+wS6t+FdPRdbwMSsiAABAhfGraABV0nWNQnTgtQH6cPi1kqQQX/cS7/d4c7m++u2gGdEAAAAuC0fiAFRpN15TSwdeGyBJKiy26a5Pf9NvSackSc/M2a4pvx7Qp3d3UL0a3mbGBAAAuGQciQNQbbi5WDXjH12066X+ahLqK0nanZKl7m8u07S1HJUDAADOgRIHoNrxdHPR/Ie7lVj27Nztqv/kPP3r2y1Ky8o3KRkAAMDFcTolgGrJzcWqA68NUGGxTY9Mj9e8bcclSd9uPKJvNx6RJI3p01QjrquvAC83M6MCAACUQIkDUK25uVj1wbC2arc6SPO3HdfWI+kqKLZJkt5etFtvL9otSZoxqrOahfvJw9VFbi4WHlEAAABMQ4kDUO1ZLBbde30D3Xt9AyWlZev+aRuVmJJZYszQj9eWeP3mba11e/u6VzMmAACAJEocAJTQIMRHPz3aXZKUU1Ckr9Ye0ivzd5Ya969ZWxXo7a4+UWFXOyIAAKjmKHEAcB7e7q66r3tDjbiuvrYeOaOafh5KyyrQCz/s0NYj6bpv6gZJ0pBra+sf3RupaZivLBaLDMOQxWIxOT0AAKiqKHEAcBHurla1rx8sSYqs4aP/PXCdbvlgtRKOZ0iSZm86qtmbjpZYZ+q9HdW9ac2rnhUAAFR9lDgAKCc3F6teGtRKt0789bxj7v58nSSpaZiv2tcPVt+oMPVoWpMjdAAA4LJR4gCgAtpFBunAawPsr9fsO6nH/7dFmXlFOpNTaF++OyVLu1Oy9PVvhyRJ1zWqoVcHX6P6IT5XPTMAAKgaKHEAUAm6NKqhXx7vJUnKKyzWCz/s0DfrDpca9+u+k+r5n+VqHu6nT0e0167jmXp6zjalZuarQYiPvh/dVX6ePJcOAACcHyWuEg0ePFjLly9X7969NWvWLLPjADCJp5uLxg9prfFDWtuXfb/lmP75zWb7613Jmbr+9WUl1ktKy9Y1z/+s0Tc01t+7NZCL1aL4w2dUK8BTjUP9rlp+AADg2Chxlejhhx/WvffeqylTppgdBYCDuTk6QgNb19LpnEK9u3i3pq45eN6xHyzbqw+W7S2x7K7O9fTSLa24pg4AAFDiKlPPnj21fPlys2MAcFAWi0XBPu568ZZWGnFdfc3belyNavqqf6twuVgtSs3M0/BPftOe1KxS63659pC+XHtIY/s2VdwNjSlzAABUY1azA4wfP14dOnSQn5+fQkNDNWjQICUmJlbqPlauXKmBAwcqIiJCFotFc+fOLXPchAkTVL9+fXl6eqpTp05at25dpeYAgHMa1fTVP3s30YDWteRiPVvIQv08tWhMD+18sb/u69ZATcN8Nev+LhreqZ59vf/8vFsNnpqvT3/ZL8MwzIoPAABMZHqJW7FiheLi4rR27VotWrRIhYWF6tu3r7Kzs8scv3r1ahUWFpZanpCQoJSUlDLXyc7OVnR0tCZMmHDeHDNmzNCYMWP03HPPadOmTYqOjla/fv2UmppqH9OmTRu1atWq1NexY8fK+akB4Py83F30zIAo/fxoD7WvH6xXBl+jDf+OUfvIIPuYl+ftVIOn5qv+k/N0x8drdCCt7J+ZAACg6jH9dMqFCxeWeD158mSFhoZq48aN6t69e4n3bDab4uLi1KRJE02fPl0uLi6SpMTERPXq1UtjxozR448/XmofsbGxio2NvWCOt99+W/fdd59GjhwpSZo0aZLmzZunzz//XE8++aQkKT4+vqIfs4QJEyZowoQJKi4urpTtAaj6Qnw9NOuB67Tx4CmNmrpRJ7ML7O+t3X9KPf+zXJLUu3moBrWtLW93F3VtHCJPNxe
TEgMAgCvF9BL3Z+np6ZKk4ODgUu9ZrVbNnz9f3bt31913361p06YpKSlJvXr10qBBg8oscJeioKBAGzdu1FNPPVViXzExMVqzZk3FPsgFxMXFKS4uThkZGQoICKj07QOoutpFBmvjs32Umpmnqb8eLHUDlCW7UrVkV2qJZROGXauezWrKx8PhfuQDAIAKcKi/0W02mx555BF17dpVrVq1KnNMRESEli5dqm7dumnYsGFas2aNYmJiNHHixArvNy0tTcXFxQoLCyuxPCwsTLt27brk7cTExGjLli3Kzs5WnTp19O2336pLly4VzgUA5xPq56mx/ZppbL9myiss1qQV+zR701EdOpVTamzc15skSfWCvfVITBMNubaODMPg5igAADgphypxcXFx2r59u1atWnXBcfXq1dO0adPUo0cPNWzYUJ999plD/GNk8eLFZkcAUA15urnokZimeiSmqWw2Q8kZeQrydtez323XrI1H7OMOncrRmJlbNGbmllLb8HZ30b1dG8jHw1W3tqutUD/Pq/kRAABAOZh+Y5NzRo8erR9//FHLli1TnTp1Ljg2JSVFo0aN0sCBA5WTk6NHH330svYdEhIiFxeXUjdGSUlJUXh4+GVtGwCuJqvVoohAL3m5u+g/t0frwGsDFD+uj3o3D73gejkFxfpg2V69vnCXOr6yRAu3J6uw2HaVUgMAgPIw/UicYRh66KGHNGfOHC1fvlwNGjS44Pi0tDT17t1bLVq00Lfffqvdu3erZ8+e8vDw0H/+858KZXB3d1e7du20ZMkSDRo0SNLZUzuXLFmi0aNHV2ibAOAoAr3d9dk9HZRXWKxx323X6ZxC7TuRpf0nshXg5ab03NJ3/L3/y42SpLrBXnqif3NNX3dYx87kKr/Ipm5NQvRUbAsFeLtddjabzdChUznafPi0pq87rJ3HM9S9aU2N6dNUdYK85e7qML9rBADAYZhe4uLi4vT111/ru+++k5+fn5KTkyVJAQEB8vLyKjHWZrMpNjZWkZGRmjFjhlxdXRUVFaVFixapV69eql27dplH5bKysrR37/9f/J+UlKT4+HgFBwerXr2zz18aM2aMRowYofbt26tjx4569913lZ2dbb9bJQA4O083F71xW/R538/KL9KKxBN64n9blZVfJEk6fCpXo7/eXGLc9PWHNX39Yfvr7+K6KrpuoAzDULHNUFpWgVIz81QrwEu+Hq7ycndRWla+3lu8W1PXuurhNT8rsoa3cguKlZqZXyrHj1uP68etx+2v/3Z9A93VOVIRgZ7ycOVumwAAmF7izt2QpGfPniWWf/HFF7rnnntKLLNarXr11VfVrVs3ubu725dHR0dr8eLFqlmzZpn72LBhg2644Qb76zFjxkiSRowYocmTJ0uShg4dqhMnTmjcuHFKTk5WmzZttHDhwlI3OwGAqsrXw1UDWtfSjdeEK6/Qpsm/HtDrCy9+c6dbJqwu974Onix9A5bz+WxVkj5blVRi2ZR7O6p7kxCHuB4aAICrzfQSZxhGucb36dOnzOVt27Y97zo9e/a8pP2MHj2a0ycBVHsWi0Ve7i56oGcj3dGhropshmr6eUiS0nMKlVdUrPeX7tGXaw+Ve9vhXoaScy0aGB2hdvUCtf7AaQ2MrqWYFmFysZ4tZEU2QwVFNn3120G9Or/sEjni83WSpP890EXX1gvS4VO5OpmdrzZ1Ayl2AIAqz/QSBwBwXEE+7iVeB3i7KUBuennQNXp50DVKzczTI9PjlZqZr3rB3moW7qe/do6UzTC0PPGEjp7J1crdJxTu76kHejTQ0a2rdeONN8rN7ez1dPd0LX0dtJuLRW4uVo3q3kijujfS/hNZ+mL1Afl5umrSin2y/eF3crdOLPksTz9PVzUM8dGYvs10IC1bp3MKNLJrAwV4Xf71ewAAOApKHACgwkL9PPX1fZ3LfO+uzpGSpCf6N5ckFRYW6ujW8u+jYU1fvTTo7LNDH+/fXOk5hfp8dZLWJZ3
Smv0nS4zNzCvSliPp9iN1kvTu4j2adFc79W8Vruz8IuUWFivE18P+fl5hsbYeSVe4v6e8PVzk7e4ib3dz/no0DEM/J6Roy+EzWrP/pO7sWE83R0fI041rAQEA/48SBwBwKgHebnq0T1NJUlpWvr5ae0idGwbrx63HNW3twTLXOXe3zUtRN9hLM0Z1UUTg2ZtrbTl8Rln5RbquUY0LnqppsxnKKSyWr8el/dWakVeoPSmZysgr0ldrD2nxzpRSYzYfOqPHZ23Vk7HN1bVRiOqHeMvPk6OKAFDdUeIAAE4rxNdDD8c0kSR1alhDLw1qpWKbodM5Barh4678IptenpdQruv3Dp/KVb93Virz9zt0no+bi0UDW0co2Mddn/7hxisWi9SjaU3lF9rkYrUop6BIjWr6am3SScW0CNPDvZvox63H9e+52y+axdfDVVn5RXptwf9fGxhZw1tPxbbQmn1pCvR21+zNR9S9SU3Vr+GjVrUD9OHyvVq7/6Se6N9cf+/W8JI/9+U4kJatdQdOaUjb2nJ1uTqPhTiTU6DC4v+/XhMAqhNKHACgSnGxWuynS3q6uejlQdcouk6g/jVrq9rWC1STUF+F+nnqg2VnHz1zbb1A/atfc1kt0umcQn2wbI+2H8246H4Kiw3N3ny01HLDkJYnniixbNOhM5KkL1Yf0BerD5S5vTB/D6Vk5Ov2dnX07wFRCvB204nMfD30zSat3X/KPu7gyZxSRxa/+q10SX153k79uPW4Xrv1GoX6eSr49+sbC4pscrVaZP39RjKGYWjlnjR5uFrVqUGwpLM3l3FzsWr70XQ9PP3sIyZuaBaqm9tEqHWdQEnSqewCPT17mxbtTFHx7xcqPvG/rbq+cYgahvjI1cUqH3cXxfVqXOLREHtTM3X/l5vk4+GqyGBvdW1cQ3WDvTVrwxGdzinQntQsXVM7QMM7RernHcc1d6OLZqRuUPemodp48LQWJaTIy81FuYXF9m22rhOgNnUD1bFBsJLT83Qqu0BHTufq+y3H1LlhsN68LVq1fz+yeu5zV1WFxTb9fcoGrdh9Qv/o0VBj+jS97EdzGIbBDYMAB2Mxynt7SFSajIwMBQQEKD09Xf7+/qZmKSws1Pz580vccAC4EOYMystZ5kxBkU2f/LJfP+1I1raj6fr6753VJMxXP+1I1ofL9im3sFj+nq7ycHVRYkqmfb16wd6KrhuoH7YckyQ1DPFRem6havp5aFdyZpn7WvBwN0UEel30xisFRTb9ui9Nz3+/QwfK8XiG86kV4KmB0RH6eOX+cq876a52Ss3M07jvdlzyOu4uVhUU28q9ryvho7+2U7+W4WbHuCJyC4o1/NO19l8anNOhfpDWHzhtfz2kbW093r+5zuQWqHn4+f/9MX/bcT0xa6tC/T30w0PXm3ataHk4y88ZR5GZV6g3f0rUqO4NFeDlpj2pWVqXdEodGwTr2npBJcYahqEtR9JVv4a3ArzctC7plAqLDX2z/pDm/f5szwAvN0XXDdSt19bWooQU/ZyQojdvay1/Tze1jPBXsI/7RY/WZ+cXKT23UL/uO6nrG4coPMCzxC8S0nMKZbFKblar0rLyVTfYW5K0YvcJLd2ZolB/T0VF+GvkF+slSeNuipKL1aJuTUJUw9dDyxNTFVXLXw1r+srFalF+foEWLFigAQPMnzPl6QaUOBNR4uDMmDMor6o4Zw6ezNb38cd093X1L1jEioptslosKrIZ2nk8QzkFxWoQ4qPwAM9y7zOvsFibD51RxwbB9scynJOVXyQfdxdZLBb9sOWYxn23XadzCsu9j/JoGOKj525uqZT0PK1NOqkmoX76YnVSmQ9y/6MbrwnX6r0nlZ77//mGd6pX5lHFP4sI8NSzN0VpT2qW9p3I0u6ULO08fvGjp+eE+XvolUHXqHeLUFksFp3OLpC3h0uJI1aGYaiw2JC768VPD03PKdSmQ6cVFeGvhGMZ6tYkRDZDl7RueS1OSNEve04oNTN
fdYK8dHN0bZ3JLdCSnama/OuBcm+vYU0ffXFPB83ZfFQrd59QcnqeXF2sOnSq9C8LmoX52X9x0S4ySH+7voGOncmVh5uLhnWsV2o+mqEq/pypbCkZeTqdU6DjZ/I0cvL68467s2NdPdm/hYpsNs3ZfFQvz9tZKftfNranGoT4lFhWWGzTzA2H9cyc0qeZ//nI+5UwoG6x3h0Va/qcocQ5CUocnBlzBuXFnLn6DMPQoVM5+nDZPs3YcFhS6aMykjSiS6RuaB6qD5buVQ1fd43q3kir96ZpXdIpjR9yjQxD8nSzauysrVq5++ypor4ervpgWFv1bBZaar+5BcUqNgwlHMvQTzuSVVBk09JdqfJwtSq3sFhT7+2oJmF+MgxDCccz5OnmooYhPvbftBcU2XQ8PVcR/u72OVNoWJRbUKwavmVfA2cYhmzG2dNpk9KylZyep7b1ArX1SLreX7pHv+47aT/t84/qBXuXKCw3R0cor7BYPyf8/41mRnVvqE4NgtWmbqCCfdztOXclZ+j2SWuUmVf29ZP1gr11Q7Oayi+yafr6s9//IG83ZecXq6DYppp+HoptFa4Wtfw1ff1hvXlba4X5ecrHw0VFNkOTVuzT7E1H1bNZTfl4uGri8n1l/0H/ydi+TTW6VxO9+dMufbPusE5lF1zSepWhd/NQBXi76bZ2ddSlYQ1J0pHTuTpyOleNQ321dFeKXKxW2QxDXRrWsB9FqSzn+zmTV1isMTPj5e/ppvjDZ+TuapWfp6sSk7NUw8ddiSmZem5glEZ0qV/mKbdbDp/R2v0n9e7iPWoa7qdJd12rWgFe9vcNw9A36w7r6Tnb1K1JiI6czlV+YbHG9G2mwmKbXp23s9R1tkPb19VTNzZXem6hgn3cS9206FJOY913Ikujpm7Qicx8ebi5KK+gWK4uFnVsEKycgmJtPHhaOQXFemVwK7WKCNCDX23S0TO5FfnWXtS5U8KdkavF0OIx3VW/prn/HqfEOQlKHJwZcwblxZxxHPlFxTp6OlcNa/qWa73i348kNg/3uyo3MKnsOWMYht5etFvvL91bCekcz/09Gml4p3oXLUbxh8/o6OlczY0/qkV/KKu1AjwV5O2uw6dyVDvIS4Pb1tao7g31yS/79dGK/WpY00dHTueqa+MQzdp4pFIy940K0z3X1dea/Sc1dc1BGYahjLwiNQ711ZmcQj3Uq7Fq+Lqrde1A1avx/5/rZFa+Jv96QNc1ClFULX8FeLvp4Mlsvf1zon5NPKYTeWfLT4ivu5qF+2nt/lNllvg/u6FZTX04vJ283F10OrtAH63cr0kryi7PT/Rvrt4tQjXi83U6np5XKd+PP3KxWtS1cYieim2uX/ac0PdbjmlI2zq6uU2EDqRlK7puoIZ/+pvWJZ26+MbOo1/LMI3oUl+bD59RSkae7uhQTw1r+mjCsr1l/v/khZtb6kRmvrzcXXRX50gZhiFfD9cSPw/Scwp1LD1XRcWGmoX7aXdKpiICvfTZqv2au/nYRUvkbe3qqFuTEF1TO0C93lphXx4R4KnkjDz9pX1dNQ711e6UTOUV2vT9lmNqHxmkF29ppagIfxUW2+RisciQlF1QpH2pWWpTN1CHT+Vq+7F0dW9aUwVFNn20cp8+WrFf/p6uGtEoT/+8w/y/myhxToISB2fGnEF5MWdQXldyzkxbe1Bv/5woX09XuVgsahbup/wim/adyNLhU2f/kXlN7QBtO5p+0W0Nuba2/to5Uhl5Rarp66Ff96XpdE6BZqw/opPZ+frjv7Tuua6+diVnlLhZTXm8MzRaA1tH6FR2gWr6eehYep6y8ooUWcO7Qs8TPHwqR7/sSdMtbSLkc4mPx5DOFvr5246rW5MQ7U7J0vLEVE359YCyC67caW99o8Lk4+GqOWXcUOhquZSjTX6eruc9OitJzcP9znudbEU1DPGRr6erktKyL7jvUD8PPdizke7uUl9pWfkK9T//Kd2/7Dmhr9Ye0sFTOfpnr8b
q3yq8Um5wM3XNgTKvqX1pUCsN71jvqt98yJH+bipPN3D8K1QBAAAq2V87R+qvvz+Q/kJyC4q1YPtxRdbw0cz1h5VXVKwF25L1t24NFNMiVO0ig0utExVx9h9f/+p39kH3xTbjvNeLFdsMnczKVw1fDy3dlapVe06ooNhQnSAv1Q321s3REZLO3uzhjyXr3D++z911s6LqBntrWKd65V7PxWrRwN+zdWwQrI4NgvV4/+b29zccOKUpaw7qZFa+2kUGaXinSAX7uKvYZigtK18FxTZ9/dshffaHx3O4WC0XPFr2x1NcL1Wgt5vO5BTqpta19GRsc9UK8FJ2QZH8Pd3spyue+98lO1P0tykbytzOMze20PVNQtQ83E/LE0/oxR8TlJSWbX+/a+MaevO2aPvzJc/Jyi+Su4tVbi6WEgUoM69Qb/28Wz4eLhrctraW7Tqhj3/ZrxN/uJbU39NVGRcoZOc8NzBKI7s2sL8+d3zm3P4Mw9DSXakK8nEvcbOSCxU4SerWpKa6Nal50f2X191d6uvuLvUrfbvVDUfiTMSRODgz5gzKizmD8mLOVH2nsgs05dcDurlNhBrV9C1xbaN09tTfwmJDo6Zu0K/7TtrXG9y2tl4Z3EorEk/oyOlczd9+XG3rBumxmEZa/PNCxcbGyt3dvdx5ktPz9Nmq/SqyGWoVEaCb20TI7TynDu9JyZTFYlFNP4+L3mG2otJzCrX9WLoia3irdqCXliWmKjE5SyG+7lp/4JQGRkdckaJVnTjSzxmOxAEAAMDhBfu469E+Te2vLRaLXP5w0NLD1UUertLX93WWdPZmMmF+ngr6/bmHsdfUkiTd1/3sg+0LCwvt26mI8ABPPTMg6pLGNgnzq9A+yiPA201dG4fYX/dqHqZezcMkSbe3r3vF9w/HRYkDAACAU7jQc+2A6uTK31oKAAAAAFBpKHEAAAAA4EQocQAAAADgRChxAAAAAOBEKHEAAAAA4EQocQAAAADgRChxAAAAAOBEKHEAAAAA4EQocQAAAADgRChxAAAAAOBEKHEAAAAA4EQocQAAAADgRChxAAAAAOBEKHEAAAAA4ERczQ5QnRmGIUnKyMgwOYlUWFionJwcZWRkyM3Nzew4cALMGZQXcwblxZxBeTFnUF6ONGfOdYJzHeFCKHEmyszMlCTVrVvX5CQAAAAAHEFmZqYCAgIuOMZiXErVwxVhs9l07Ngx+fn5yWKxmJolIyNDdevW1eHDh+Xv729qFjgH5gzKizmD8mLOoLyYMygvR5ozhmEoMzNTERERslovfNUbR+JMZLVaVadOHbNjlODv72/6BIZzYc6gvJgzKC/mDMqLOYPycpQ5c7EjcOdwYxMAAAAAcCKUOAAAAABwIpQ4SJI8PDz03HPPycPDw+wocBLMGZQXcwblxZxBeTFnUF7OOme4sQkAAAAAOBGOxAEAAACAE6HEAQAAAIATocQBAAAAgBOhxAEAAACAE6HEQZI0YcIE1a9fX56enurUqZPWrVtndiRcBStXrtTAgQMVEREhi8WiuXPnlnjfMAyNGzdOtWrVkpeXl2JiYrRnz54SY06dOqXhw4fL399fgYGB+tvf/qasrKwSY7Zu3apu3brJ09NTdevW1RtvvHGlPxqukPHjx6tDhw7y8/NTaGioBg0apMTExBJj8vLyFBcXpxo1asjX11e33nqrUlJSSow5dOiQBgwYIG9vb4WGhupf//qXioqKSoxZvny5rr32Wnl4eKhx48aaPHnylf54uAImTpyo1q1b2x+k26VLFy1YsMD+PvMFF/Laa6/JYrHokUcesS9jzuCPnn/+eVkslhJfzZs3t79fZeeLgWpv+vTphru7u/H5558bO3bsMO677z4jMDDQSElJMTsarrD58+cbzzzzjDF79mxDkjFnzpwS77/22mtGQECAMXfuXGPLli3GzTffbDRo0MDIzc21j+nfv78RHR1trF271vjll1+Mxo0bG3feeaf9/fT0dCMsLMw
YPny4sX37duObb74xvLy8jI8++uhqfUxUon79+hlffPGFsX37diM+Pt648cYbjXr16hlZWVn2Mffff79Rt25dY8mSJcaGDRuMzp07G9ddd539/aKiIqNVq1ZGTEyMsXnzZmP+/PlGSEiI8dRTT9nH7N+/3/D29jbGjBljJCQkGO+//77h4uJiLFy48Kp+Xly+77//3pg3b56xe/duIzEx0Xj66acNNzc3Y/v27YZhMF9wfuvWrTPq169vtG7d2nj44Yfty5kz+KPnnnvOaNmypXH8+HH714kTJ+zvV9X5QomD0bFjRyMuLs7+uri42IiIiDDGjx9vYipcbX8ucTabzQgPDzfefPNN+7IzZ84YHh4exjfffGMYhmEkJCQYkoz169fbxyxYsMCwWCzG0aNHDcMwjA8//NAICgoy8vPz7WOeeOIJo1mzZlf4E+FqSE1NNSQZK1asMAzj7Bxxc3Mzvv32W/uYnTt3GpKMNWvWGIZx9pcHVqvVSE5Oto+ZOHGi4e/vb58njz/+uNGyZcsS+xo6dKjRr1+/K/2RcBUEBQUZn376KfMF55WZmWk0adLEWLRokdGjRw97iWPO4M+ee+45Izo6usz3qvJ84XTKaq6goEAbN25UTEyMfZnValVMTIzWrFljYjKYLSkpScnJySXmRkBAgDp16mSfG2vWrFFgYKDat29vHxMTEyOr1arffvvNPqZ79+5yd3e3j+nXr58SExN1+vTpq/RpcKWkp6dLkoKDgyVJGzduVGFhYYl507x5c9WrV6/EvLnmmmsUFhZmH9OvXz9lZGRox44d9jF/3Ma5Mfxccm7FxcWaPn26srOz1aVLF+YLzisuLk4DBgwo9efKnEFZ9uzZo4iICDVs2FDDhw/XoUOHJFXt+UKJq+bS0tJUXFxcYuJKUlhYmJKTk01KBUdw7s//QnMjOTlZoaGhJd53dXVVcHBwiTFlbeOP+4BzstlseuSRR9S1a1e1atVK0tk/U3d3dwUGBpYY++d5c7E5cb4xGRkZys3NvRIfB1fQtm3b5OvrKw8PD91///2aM2eOoqKimC8o0/Tp07Vp0yaNHz++1HvMGfxZp06dNHnyZC1cuFATJ05UUlKSunXrpszMzCo9X1xN2SsAwOnFxcVp+/btWrVqldlR4OCaNWum+Ph4paena9asWRoxYoRWrFhhdiw4oMOHD+vhhx/WokWL5OnpaXYcOIHY2Fj7f7du3VqdOnVSZGSkZs6cKS8vLxOTXVkciavmQkJC5OLiUuouPSkpKQoPDzcpFRzBuT//C82N8PBwpaamlni/qKhIp06dKjGmrG38cR9wPqNHj9aPP/6oZcuWqU6dOvbl4eHhKigo0JkzZ0qM//O8udicON8Yf3//Kv2XclXl7u6uxo0bq127dho/fryio6P13//+l/mCUjZu3KjU1FRde+21cnV1laurq1asWKH33ntPrq6uCgsLY87gggIDA9W0aVPt3bu3Sv+MocRVc+7u7mrXrp2WLFliX2az2bRkyRJ16dLFxGQwW4MGDRQeHl5ibmRkZOi3336zz40uXbrozJkz2rhxo33M0qVLZbPZ1KlTJ/uYlStXqrCw0D5m0aJFatasmYKCgq7Sp0FlMQxDo0eP1pw5c7R06VI1aNCgxPvt2rWTm5tbiXmTmJioQ4cOlZg327ZtK/ELgEWLFsnf319RUVH2MX/cxrkx/FyqGmw2m/Lz85kvKKV3797atm2b4uPj7V/t27fX8OHD7f/NnMGFZGVlad++fapVq1bV/hlj2i1V4DCmT59ueHh4GJMnTzYSEhKMUaNGGYGBgSXu0oOqKTMz09i8ebOxefNmQ5Lx9ttvG5s3bzYOHjxoGMbZRwwEBgYa3333nbF161bjlltuKfMRA23btjV+++03Y9WqVUaTJk1KPGLgzJkzRlhYmPHXv/7V2L59uzF9+nTD29ubRww4qQceeMAICAgwli9fXuJ2zjk5OfYx999/v1GvXj1j6dKlxoY
NG4wuXboYXbp0sb9/7nbOffv2NeLj442FCxcaNWvWLPN2zv/617+MnTt3GhMmTDD9ds6omCeffNJYsWKFkZSUZGzdutV48sknDYvFYvz888+GYTBfcHF/vDulYTBnUNJjjz1mLF++3EhKSjJWr15txMTEGCEhIUZqaqphGFV3vlDiYBiGYbz//vtGvXr1DHd3d6Njx47G2rVrzY6Eq2DZsmWGpFJfI0aMMAzj7GMGnn32WSMsLMzw8PAwevfubSQmJpbYxsmTJ40777zT8PX1Nfz9/Y2RI0camZmZJcZs2bLFuP766w0PDw+jdu3axmuvvXa1PiIqWVnzRZLxxRdf2Mfk5uYaDz74oBEUFGR4e3sbgwcPNo4fP15iOwcOHDBiY2MNLy8vIyQkxHjssceMwsLCEmOWLVtmtGnTxnB3dzcaNmxYYh9wHvfee68RGRlpuLu7GzVr1jR69+5tL3CGwXzBxf25xDFn8EdDhw41atWqZbi7uxu1a9c2hg4dauzdu9f+flWdLxbDMAxzjgECAAAAAMqLa+IAAAAAwIlQ4gAAAADAiVDiAAAAAMCJUOIAAAAAwIlQ4gAAAADAiVDiAAAAAMCJUOIAAAAAwIlQ4gAAAADAiVDiAABwUhaLRXPnzjU7BgDgKqPEAQBQAffcc48sFkupr/79+5sdDQBQxbmaHQAAAGfVv39/ffHFFyWWeXh4mJQGAFBdcCQOAIAK8vDwUHh4eImvoKAgSWdPdZw4caJiY2Pl5eWlhg0batasWSXW37Ztm3r16iUvLy/VqFFDo0aNUlZWVokxn3/+uVq2bCkPDw/VqlVLo0ePLvF+WlqaBg8eLG9vbzVp0kTff//9lf3QAADTUeIAALhCnn32Wd16663asmWLhg8frjvuuEM7d+6UJGVnZ6tfv34KCgrS+vXr9e2332rx4sUlStrEiRMVFxenUaNGadu2bfr+++/VuHHjEvt44YUX9Je//EVbt27VjTfeqOHDh+vUqVNX9XMCAK4ui2EYhtkhAABwNvfcc4++/PJLeXp6llj+9NNP6+mnn5bFYtH999+viRMn2t/r3Lmzrr32Wn344Yf65JNP9MQTT+jw4cPy8fGRJM2fP18DBw7UsWPHFBYWptq1a2vkyJF6+eWXy8xgsVj073//Wy+99JKks8XQ19dXCxYs4No8AKjCuCYOAIAKuuGGG0qUNEkKDg62/3eXLl1KvNelSxfFx8dLknbu3Kno6Gh7gZOkrl27ymazKTExURaLRceOHVPv3r0vmKF169b2//bx8ZG/v79SU1Mr+pEAAE6AEgcAQAX5+PiUOr2xsnh5eV3SODc3txKvLRaLbDbblYgEAHAQXBMHAMAVsnbt2lKvW7RoIUlq0aKFtmzZouzsbPv7q1evltVqVbNmzeTn56f69etryZIlVzUzAMDxcSQOAIAKys/PV3Jycollrq6uCgkJkSR9++23at++va6//np99dVXWrdunT777DNJ0vDhw/Xcc89pxIgRev7553XixAk99NBD+utf/6qwsDBJ0vPPP6/7779foaGhio2NVWZmplavXq2HHnro6n5QAIBDocQBAFBBCxcuVK1atUosa9asmXbt2iXp7J0jp0+frgcffFC1atXSN998o6ioKEmSt7e3fvrpJz388MPq0KGDvL29deutt+rtt9+2b2vEiBHKy8vTO++8o7FjxyokJES33Xbb1fuAAACHxN0pAQC4AiwWi+bMmaNBgwaZHQUAUMVwTRwAAAAAOBFKHAAAAAA4Ea6JAwDgCuBqBQDAlcKROAAAAABwIpQ4AAAAAHAilDgAAAAAcCKUOAAAAABwIpQ4AAAAAHAilDgAAAAAcCKUOAAAAABwIpQ4AAAAAHAi/wd3JINUGeRRXgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot the training loss history with logarithmic scaling.\n", "plt.figure(figsize=(10, 5)) # Create a figure with a wide aspect ratio for clarity.\n", "plt.plot(losses) # losses: List[float] - historical EMA loss values.\n", "plt.title('Training Loss Over Time')\n", "plt.xlabel('Epoch')\n", "plt.ylabel('Loss')\n", "plt.yscale('log') # Use the log scale to better visualize exponential decay.\n", "plt.grid(True) # Add a grid for easier value reading.\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "M2ql0KwYJLqn" }, "source": [ "## Visualization functions\n", "\n", "Next, create several utilities for:\n", "\n", "- Sample generation\n", "- Forward/reverse process visualization\n", "- Training progress tracking" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 212 }, "id": "thP6DDl56iXM", "outputId": "68b47408-bbc1-40e8-fb90-47d43230984e" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAykAAADDCAYAAACGXV4TAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAg4UlEQVR4nO3de1DVdf7H8TcXORggouANRBDwGqFSbqkIbJbbRitmul2WxG62baa75mb1s8x2crKamGmrtabExrU0SqPpajlitrbrpolipoJU3iBAUFAr4Pv7o/GshH7eyuHEB3k+Zpqp8zrfcz7nw/d8z/fd95zP28dxHEcAAAAAwBK+bT0AAAAAADgVRQoAAAAAq1CkAAAAALAKRQoAAAAAq1CkAAAAALAKRQoAAAAAq1CkAAAAALAKRQoAAAAAq1CkAAAAALBKhylS5s+fLz4+Pi3aNjc3V3x8fKS0tLR1B3WK0tJS8fHxkdzcXK89BwAAANAetIsipaioSP7whz9IZGSkuFwu6dOnj9x0001SVFTU1kMDAAAA0Mp8HMdx2noQJm+++abccMMN0q1bN7n11lslNjZWSktL5aWXXpLKykp57bXXZOLEierj1NfXS319vQQGBp7zGBoaGuTHH38Ul8vV4qsxmtLSUomNjZUlS5ZIdna2V54DAAAAaA/823oAJsXFxZKVlSX9+/eX9evXS0REhDubOXOmpKSkSFZWlhQWFkr//v1P+xh1dXUSFBQk/v7+4u/fspfr5+cnfn5+LdoWAAAAwLmx+uteTzzxhBw7dkxeeOGFJgWKiEh4eLgsXrxY6urqZNGiRSLyv9+d7NixQ2688UYJCwuTMWPGNMlOdfz4cbnnnnskPDxcQkJC5He/+53s379ffHx8ZP78+e77ne43KTExMZKRkSEbNmyQkSNHSmBgoPTv319eeeWVJs9RVVUl9957ryQmJkpwcLB06dJFrrrqKtm6dWsrzhQAAABw/rD6Ssrbb78tMTExkpKSctp87NixEhMTI++8806T2ydPniwJCQny2GOPienbbNnZ2bJy5UrJysqSSy+9VAoKCuTqq68+6/Ht2bNHrrvuOrn11ltl6tSp8vLLL0t2drYkJyfL0KFDRUSkpKREVq9eLZMnT5bY2FgpKyuTxYsXS2pqquzYsUP69Olz1s8HAAAAdATWFik1NTVy4MABmTBhgvF+F110keTn58vRo0fdtyUlJcny5cuN223evFlWrlwps2bNkqefflpERO666y6ZNm3aWV/l+Oqrr2T9+vXuImrKlCnSt29fWbJkiTz55JMiIpKYmCi7du0SX9//XbTKysqSQYMGyUsvvSTz5s07q+cCAAAAOgprv+51sugICQkx3u9kfuTIEfdtd955p/r477//voj8VJicasaMGWc9xiFDhjS5yhMRESEDBw6UkpIS920ul8tdoDQ0NEhlZaUEBwfLwIEDZfPmzWf9XAAAAEBHYW2RcrL4OPUKyemcrpiJjY1VH//rr78WX1/fZveNj48/6zFGR0c3uy0sLEwOHz7s/u/GxkZ5+umnJSEhQVwul4SHh0tERIQUFhZKTU3NWT8XAAAA0FFYW6SEhoZK7969pbCw0Hi/wsJCiYyMlC5durhv69y5s7eHJyJyxhW/Tv0dzGOPPSZ/+ctfZOzYsbJs2TL54IMPZM2aNTJ06FBpbGz8RcYJAAAAtCfW/iZFRCQjI0NefPFF2bBhg3uVrlN98sknUlpaKtOnTz/nx+7Xr580NjbK3r17JSEhwX37nj17PBrzz+Xl5Ul6erq89NJLTW6vrq6W8PDwVn0uAAAA4Hxg7ZUUEZE5c+ZI586dZfr06VJZWdkkq6qqkjvvvFMuuOACmTNnzjk/9vjx40VE5Lnnnmty+zPPPNPyAZ+Gn59fsxXGXn/9ddm/f3+rPg8AAABwvrD6SkpCQoIsXbpUbrrpJklMTGzWcb6iokJeffVViYuLO+fHTk5OlkmTJklOTo5UVla
6lyDetWuXiEirdZbPyMiQBQsWyLRp02TUqFGybds2+ec//3nG5pMAAABAR2d1kSLyU8+TQYMGycKFC92FSffu3SU9PV0eeOABufDCC1v82K+88or06tVLXn31VVm1apWMGzdOVqxYIQMHDpTAwMBWGf8DDzwgdXV1snz5clmxYoWMGDFC3nnnHZk7d26rPD4AAABwvvFxTN0OO6AvvvhChg8fLsuWLZObbrqprYcDAAAAdDhW/ybF244fP97stpycHPH19ZWxY8e2wYgAAAAAWP91L29atGiRfP7555Keni7+/v7y3nvvyXvvvSd33HGH9O3bt62HBwAAAHRIHfrrXmvWrJFHHnlEduzYIbW1tRIdHS1ZWVny4IMPir9/h67fAAAAgDbToYsUAAAAAPbp0L9JAQAAAGAfihQAAAAAVqFIAQAAAGAVa38d3lod31sqMzPTmOfm5hrzWbNmebS9N38q5Onczp8/35iXlpYac+21e5vNc6vR5rZfv37GvKamxpivW7fOmGvvC2//xI3jAvvumdi877b1frt69Wpj3rVrV2Oelpbm0fOfz3OrvaeffvppYz5t2jRj3pbHBBHvny8MGzbMmGvvW2+zed/tCOdiXEkBAAAAYBWKFAAAAABWoUgBAAAAYBWKFAAAAABWoUgBAAAAYBWKFAAAAABWsXYJYm/TllTUlmYLDQ015trSb+2Z9tq0ZfHaetk7m8XExBhzbZlWjbbfVldXe/T47R3HhZZj32072jKu2n7dkedOo82dtsSw5nw+Jojo+9aECROMubZvf/HFF+c2oPNIRzgX40oKAAAAAKtQpAAAAACwCkUKAAAAAKtQpAAAAACwCkUKAAAAAKtQpAAAAACwCkUKAAAAAKu02z4p2trZ2vrQ2trcS5cuNeZTp0415h2Z1g+Bdc/PTOs14W3tYd10E44LbYd9t+1o+7XWY2bdunWtNxjLaPtlTk6OMdf6pHhKO2a197+Npz14Zs2aZcyzs7M9evzz2flwLsaVFAAAAABWoUgBAAAAYBWKFAAAAABWoUgBAAAAYBWKFAAAAABWoUgBAAAAYBWKFAAAAABWabd9Urp27erR9sOHDzfm2vrQHbkfgqfrxmvrmmvrop/PtH4HniooKDDm7X1Nfo4LbYd913u0Xh5afx+Np70sbOZpb6StW7cac0+PCTb0ovAmT88XMjMzjbl2zD+f9+2OcC7GlRQAAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVmm3fVK0NfE9XTPf034L5zNt7Wxt7pcsWWLMtZ4ApaWlxhxnlpub29ZD8CqOC+ev83nf1fodzJw505hrPWRSU1ONeUxMjDFvz7TPK+3zRutjov3ttD4p5/sxRevFofWx2bt3rzHvyH1SOsK5GFdSAAAAAFiFIgUAAACAVShSAAAAAFiFIgUAAACAVShSAAAAAFiFIgUAAACAVShSAAAAAFil3fZJ8TZP140/n9fm1l6b1s9AW9s7MzPTmGtrd7dn53O/gvMBx4UzY99tOW2/WLp0qTHXjolbtmzx6PnbM+21aX1QNMOGDfNo+47+vtF6bdTU1Bhz7XxCy9uzjnAuxpUUAAAAAFahSAEAAABgFYoUAAAAAFahSAEAAABgFYoUAAAAAFahSAEAAABgFYoUAAAAAFahT8oZeLp2edeuXVtlHDZKS0vzKNfmRlu7W9t+3bp1xtxm2pr9/fr1M+bamvLteW5swHHhzNh3W06bu+zsbK8+v6e9Pjoyjglm2r6r5aGhoR5t7+0+OW2pI5yLcSUFAAAAgFUoUgAAAABYhSIFAAAAgFUoUgAAAABYhSIFAAAAgFUoUgAAAABYhSIFAAAAgFXok3IG2vrP2pr+57PMzExjPnPmTK8+/8MPP2zMtfHZTFvzvbS01KPH19aMhxnHhTNj37XXW2+91dZDwBl4+r6wnfb6PO3Ro/VR0c4X2rOOcC7GlRQ
AAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVvFxHMdp60EAAAAAwElcSQEAAABgFYoUAAAAAFahSAEAAABgFYoUAAAAAFahSOmgSktLxcfHR3Jzc9t6KAAAAEAT51yk5Obmio+Pj/sff39/iYyMlOzsbNm/f783xthmnnvuuTY/ibdhDAAAAMAvyb+lGy5YsEBiY2PlxIkT8tlnn0lubq5s2LBBtm/fLoGBga05xjbz3HPPSXh4uGRnZ3foMQAAAAC/pBYXKVdddZVcfPHFIiJy2223SXh4uDz++OOSn58vU6ZMabUBthd1dXUSFBTU1sMAAAAA2r1W+01KSkqKiIgUFxe7b9u5c6dcd9110q1bNwkMDJSLL75Y8vPzm21bXV0tf/7znyUmJkZcLpdERUXJzTffLBUVFe77lJeXy6233io9e/aUwMBASUpKkqVLlzZ5nJO/s3jyySflhRdekLi4OHG5XHLJJZfIpk2bmtz30KFDMm3aNImKihKXyyW9e/eWCRMmSGlpqYiIxMTESFFRkRQUFLi/2paWliYi//vKW0FBgdx1113So0cPiYqKEhGR7OxsiYmJafYa58+fLz4+Ps1uX7ZsmYwcOVIuuOACCQsLk7Fjx8qHH36ojuHkvM2aNUv69u0rLpdL4uPj5fHHH5fGxsZm85udnS2hoaHStWtXmTp1qlRXVzcbCwAAAGCDFl9J+bmTJ/dhYWEiIlJUVCSjR4+WyMhImTt3rgQFBcnKlSslMzNT3njjDZk4caKIiNTW1kpKSop8+eWXcsstt8iIESOkoqJC8vPzZd++fRIeHi7Hjx+XtLQ02bNnj9x9990SGxsrr7/+umRnZ0t1dbXMnDmzyViWL18uR48elenTp4uPj48sWrRIrr32WikpKZFOnTqJiMikSZOkqKhIZsyYITExMVJeXi5r1qyRb775RmJiYiQnJ0dmzJghwcHB8uCDD4qISM+ePZs8z1133SURERHy0EMPSV1d3TnP2SOPPCLz58+XUaNGyYIFCyQgIED+/e9/y9q1a+XKK680juHYsWOSmpoq+/fvl+nTp0t0dLT861//kvvvv18OHjwoOTk5IiLiOI5MmDBBNmzYIHfeeacMHjxYVq1aJVOnTj3n8QIAAAC/COccLVmyxBER56OPPnK+++4759tvv3Xy8vKciIgIx+VyOd9++63jOI5z+eWXO4mJic6JEyfc2zY2NjqjRo1yEhIS3Lc99NBDjog4b775ZrPnamxsdBzHcXJychwRcZYtW+bOfvjhB+eyyy5zgoODnSNHjjiO4zh79+51RMTp3r27U1VV5b7vW2+95YiI8/bbbzuO4ziHDx92RMR54oknjK916NChTmpq6hnnYMyYMU59fX2TbOrUqU6/fv2abfPwww87p0737t27HV9fX2fixIlOQ0PDaV+3aQyPPvqoExQU5OzatavJ7XPnznX8/Pycb775xnEcx1m9erUjIs6iRYvc96mvr3dSUlIcEXGWLFlyppcPAAAAtIkWf91r3LhxEhERIX379pXrrrtOgoKCJD8/X6KioqSqqkrWrl0rU6ZMkaNHj0pFRYVUVFRIZWWljB8/Xnbv3u1eCeyNN96QpKQk95WVU538etS7774rvXr1khtuuMGdderUSe655x6pra2VgoKCJtv9/ve/d1/REfnfV9FKSkpERKRz584SEBAg69atk8OHD7d0CuT2228XPz+/Fm27evVqaWxslIceekh8fZv+GU73tbCfe/311yUlJUXCwsLc81tRUSHjxo2ThoYGWb9+vYj8NHf+/v7yxz/+0b2tn5+fzJgxo0XjBgAAALytxV/3evbZZ2XAgAFSU1MjL7/8sqxfv15cLpeIiOzZs0ccx5F58+bJvHnzTrt9eXm5REZGSnFxsUyaNMn4XF9//bUkJCQ0O5kfPHiwOz9VdHR0k/8+WbCcLEhcLpc8/vjjMnv2bOnZs6d
ceumlkpGRITfffLP06tXrLGdAJDY29qzv+3PFxcXi6+srQ4YMadH2u3fvlsLCQomIiDhtXl5eLiI/zU3v3r0lODi4ST5w4MAWPS8AAADgbS0uUkaOHOle3SszM1PGjBkjN954o3z11VfuH27fe++9Mn78+NNuHx8f39KnVp3p6objOO5/nzVrllxzzTWyevVq+eCDD2TevHmycOFCWbt2rQwfPvysnqdz587NbjvTVZCGhoazesyz1djYKFdccYX89a9/PW0+YMCAVn0+AAAA4JfSKj+c9/Pzk4ULF0p6err8/e9/l1tuuUVEfvpK1rhx44zbxsXFyfbt24336devnxQWFkpjY2OTqyk7d+505y0RFxcns2fPltmzZ8vu3btl2LBh8tRTT8myZctE5Oy+dvVzYWFhp1056+dXe+Li4qSxsVF27Nghw4YNO+PjnWkMcXFxUltbq85vv3795OOPP5ba2tomV1O++uor43YAAABAW2m1JYjT0tJk5MiRkpOTI126dJG0tDRZvHixHDx4sNl9v/vuO/e/T5o0SbZu3SqrVq1qdr+TVz5++9vfyqFDh2TFihXurL6+Xp555hkJDg6W1NTUcxrrsWPH5MSJE01ui4uLk5CQEPn+++/dtwUFBZ3zUr1xcXFSU1MjhYWF7tsOHjzY7PVlZmaKr6+vLFiwoNmSwade8TnTGKZMmSIbN26UDz74oFlWXV0t9fX1IvLT3NXX18vzzz/vzhsaGuSZZ545p9cFAAAA/FJabQliEZE5c+bI5MmTJTc3V5599lkZM2aMJCYmyu233y79+/eXsrIy2bhxo+zbt0+2bt3q3iYvL08mT54st9xyiyQnJ0tVVZXk5+fLP/7xD0lKSpI77rhDFi9eLNnZ2fL5559LTEyM5OXlyaeffio5OTkSEhJyTuPctWuXXH755TJlyhQZMmSI+Pv7y6pVq6SsrEyuv/569/2Sk5Pl+eefl7/97W8SHx8vPXr0kF//+tfGx77++uvlvvvuk4kTJ8o999wjx44dk+eff14GDBggmzdvdt8vPj5eHnzwQXn00UclJSVFrr32WnG5XLJp0ybp06ePLFy40DiGOXPmSH5+vmRkZEh2drYkJydLXV2dbNu2TfLy8qS0tFTCw8PlmmuukdGjR8vcuXOltLRUhgwZIm+++abU1NSc05wBAAAAv5hzXQ7s5PK7mzZtapY1NDQ4cXFxTlxcnFNfX+8UFxc7N998s9OrVy+nU6dOTmRkpJORkeHk5eU12a6ystK5++67ncjISCcgIMCJiopypk6d6lRUVLjvU1ZW5kybNs0JDw93AgICnMTExGbL555cgvh0SwuLiPPwww87juM4FRUVzp/+9Cdn0KBBTlBQkBMaGur86le/clauXNlkm0OHDjlXX321ExIS4oiIeylg0xw4juN8+OGHzoUXXugEBAQ4AwcOdJYtW9ZsCeKTXn75ZWf48OGOy+VywsLCnNTUVGfNmjXqGBzHcY4ePercf//9Tnx8vBMQEOCEh4c7o0aNcp588knnhx9+aDK/WVlZTpcuXZzQ0FAnKyvL2bJlC0sQAwAAwEo+jnPKd4sAAAAAoI212m9SAAAAAKA1UKQAAAAAsApFCgAAAACrUKQAAAAAsApFCgAAAACrUKQAAAAAsEqrNnNsTX5+fsY8ISHBmB8+fNiYp6WlefT4n3zyiTHftm2bMQ8ODjbm33zzjTH3REZGhjEPCAgw5trcXXTRRcY8KCjImK9cudKYb9myxZj7+ppr7w8//NCYe0JrLOpyuYx5aGioMZ81a5Yxf/fdd435sWPHjHlJSYkxr66uNuZHjx415p7q2bOnMdfeVz/++KMxHzRokDEvKysz5unp6cZ8+fLlxlw77h08eNCYe8LHx8eYd+vWzZjX1tYa80mTJhlz7ZirvW/Ly8uN+YEDB4z58ePHjbkntLnt0aOHMR88eLAxHzJkiDGvq6sz5qc2Gj4d7X195MgRY15VVWXMPRE
VFWXMw8LCjLl2TLzsssuMeVJSkjHX5vazzz4z5toxYc+ePcbcU9HR0cY8MjLSmGvnA9oxTdu3u3fvbszfeustY659pm3fvt2YeyI2NtaYax1CtLnV9v3i4mJjXllZacy1uevatasx//LLL425CFdSAAAAAFiGIgUAAACAVShSAAAAAFiFIgUAAACAVShSAAAAAFiFIgUAAACAVShSAAAAAFjF2j4p2vrQhw4dMuZaP4SUlBRjrvVZ0Z5f6xmgrZvvTdq68I2NjcZcW9e9f//+xlxbO1vrFaL1cdFenzf5+5vfUloflQEDBhhzrRfEyJEjjfnevXuN+aZNm4y51uPG27QeOFquve+0NffDw8M92l77+2vHjbak7dvaax82bJgxLy0tNebacalLly7GXOtx05a0sWufhyNGjDDmO3bsMOYnTpzwKNf6E3mT9p7Xerj07t3bmMfFxRlzbb/XellofTi0cxlv0/q0aMe0iIgIY671ltJ6cWjnC0OHDjXmWh8bb2poaDDm2vt+9+7dxlyb286dOxtz7b2hHRe0842zwZUUAAAAAFahSAEAAABgFYoUAAAAAFahSAEAAABgFYoUAAAAAFahSAEAAABgFYoUAAAAAFZpt31S6urqjPmgQYOMeZ8+fYx5SUmJMdfWXg8MDDTm9fX1xtybtD4jWo+YCy64wJgfPXrUo+fX1r3X1hbXxudNWh8ObV3x6OhoYz5+/HhjrvWy+PTTT415Xl6eMe/UqZMx9zZtzX7tfdWtWzdj3rdvX2N+2WWXGfPvv//emGt/f62HkDfFxMQYc23uJ06c6FH+4osvGnOtl0hRUZExd7lcxtybtF4RVVVVxlzbL/bt22fMtc+j+Ph4Y75z505jrh13vEl7T2ufR9rnhZZrn0da3y5tv9TOhbxN69OinS9UVlYac6331uTJk4251kdl0aJFxrwtaec62vsqOTnZmM+ZM8eYv/baa8Z848aNxlzrEaSd650NrqQAAAAAsApFCgAAAACrUKQAAAAAsApFCgAAAACrUKQAAAAAsApFCgAAAACrUKQAAAAAsIq1fVK0PicXXXSRMU9KSjLmpaWlxvziiy825itWrDDmP/zwgzHXeg54U+/evY251o8gPDzcmGvrovfq1cuYa3Oj9aJoy34IISEhxtzTuRkzZowx37ZtmzHX5jYsLMyYl5WVGXNv09bs97Q/kbbmfufOnY15QkKCMa+oqDDmbdnjp7a21phrvTp69uxpzLXjhnZcOnDggDHfunWrMT9+/Lgx9yat30FoaKgxb2xsNOabN2825sHBwcZc6++k7ZcHDx405t6k7bdar4b09HRjru3X2jFV66OifZ5pr8/b9u/fb8y1Y6rWA0jrSacd87X51d732vx7k3a+oM1tXFycMS8vL/fo+T2d+9bYd7mSAgAAAMAqFCkAAAAArEKRAgAAAMAqFCkAAAAArEKRAgAAAMAqFCkAAAAArEKRAgAAAMAq1vZJ0da2LiwsNOZff/21MU9MTDTm/fv3N+bR0dHGvK6uzphrfVS8SVvbWlubW+thk5ycbMy1tbO//fZbY66t+X/ixAlj7k1aHxRtTX2tF8WhQ4eMudYv4bPPPjPmWp+UtlxTXkRft1173xUXFxtzbd/6v//7P2Ou9VEZNWqUMS8oKDDm3qT9bbV+B1qPGm3fGjZsmDHXjgudOnUy5lqvEm8KCgoy5trnneM4xvyOO+4w5tpxZ/ny5cZc+9trr8+btPe01uNFm9vRo0cb827duhlz7VzE19f8/4q1HjbeFhMTY8x79OhhzLXeU9rn9bXXXmvMFyxYYMy1zwztM9ebtNeujX3jxo3GXDtm3n333cZc+zx86qmnjLnWo+hscCUFAAAAgFUoUgAAAABYhSIFAAAAgFUoUgAAAABYhSIFAAAAgFUoUgAAAABYhSIFAAAAgFWs7ZOira2trT2urd2t9fro0qWLMT927Jgxr6mpMeYul8uYe5PWD0Fbu1vbXlt
bW9tem3ttbrV18b1J69PRt29fY67151m1apUx1+a2qKjImGv9ELQ1/b1N6/eg9RTQ5j8yMtKYb9myxZhrPRO2bt3q0fN7k9a7Sevl8Z///MeY33fffcZ89uzZxnz+/PnGXNv3Gxsbjbk3HT161Jhr/XV69uxpzNPT04259rfT+qQMHTrUmGv9obwpNDTUmMfGxhrzvXv3GvNXXnnFmD/66KPGXPs81M4ltHMZb9N6XWifCVp/oksvvdSYf/TRR8Zc+/tp42vL44LWM047T9Rem/Z5p+2bWg+h7du3G/OJEyca87PBlRQAAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVrG2T4pGW1de6zfRtWtXY75//36PHl/r4xISEmLMvSk/P9+Yjx492phrr03LMzIyjLm2drenfVi8aeDAgcZc6+PxxRdfGHOth43WY0Zbs17rQZOYmGjMvS0wMNCYa397bU3+2267zZj/5je/MeYlJSXGvE+fPsZcW5fem7S51dbkf//99415amqqMS8rKzPmWo8c7b3RlnN75MgRY3748GFjPnnyZGN+6NAhYz527FhjnpKSYszfffddY669Pm/SPg+0vllanxKtF0RBQYExT05ONuZaLwutx463afNXUVFhzMPDw415v379jLnW0y4sLMyYa72/tP3Hm7TPY62HS3FxsTHXzpPnzJljzK+44gpjftVVVxlz7Tz6bHAlBQAAAIBVKFIAAAAAWIUiBQAAAIBVKFIAAAAAWIUiBQAAAIBVKFIAAAAAWIUiBQAAAIBVrO2Toq1p39DQYMy19ZnLy8uNeVZWljHfsWOHMf/vf/9rzLX1sb1J6yeg9SuIjo425nV1dcb8yy+/NOY7d+405tra39ra4t40YsQIY75x40Zj3r17d2Ou9eHQ9utevXoZ8+rqamOurUnvbVrPAK3ngKf7/po1a4y5Nr7jx48bc+3v50319fXGXOsn8OOPPxpzrTeU1gdF62GjHXe0Y7I3ab0mtPflgQMHjPkll1xizJcsWWLMk5KSjLnWA+e7774z5t4UFxdnzLVeElpfryuvvNKYf/LJJ8Z88ODBxjwhIcGYr1271ph7m3YupfX+crlcxnzfvn3G/OOPPzbm3bp1M+ZRUVHGXPv7eJPWeyogIMCYe9r76b333jPmWn8mrd/g5s2bz3VIzXAlBQAAAIBVKFIAAAAAWIUiBQAAAIBVKFIAAAAAWIUiBQAAAIBVKFIAAAAAWIUiBQAAAIBVfBxPF1oGAAAAgFbElRQAAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVqFIAQAAAGAVihQAAAAAVvl/dR+Mj1zyZycAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "@partial(nnx.jit, static_argnums=(3,))\n", "def reverse_diffusion_batch(model: UNet,\n", "                            x: jax.Array,\n", "                            key: jax.Array,\n", "                            num_steps: int) -> jax.Array:\n", "    \"\"\"Efficiently generates samples from the trained diffusion model using batched reverse diffusion (with `jax.lax.scan`).\n", "\n", "    Args:\n", "        model (UNet): The trained U-Net model for image generation.\n", "        x (jax.Array): A batch of noisy images (or pure Gaussian noise).\n", "        key (jax.Array): A JAX PRNG key for generating random noise.\n", "        num_steps (int): Number of denoising steps in the reverse diffusion process.\n", "\n", "    Returns:\n", "        jax.Array: The batch of denoised images after `num_steps` iterations.\n", "    \"\"\"\n", "    # Define the schedule for beta (noise level) and alpha (signal strength).\n", "    beta = jnp.linspace(1e-4, 0.02, num_steps)\n", "    alpha = 1 - beta\n", "    alpha_cumulative = jnp.cumprod(alpha)\n", "\n", "    def scan_step(carry: Tuple[jax.Array, jax.Array],\n", "                  step: int) -> Tuple[Tuple[jax.Array, jax.Array], jax.Array]:\n", "        \"\"\"Applies a single denoising step.\"\"\"\n", "        # Carry-over information.\n", "        x, key = carry\n", "\n", "        # Create a batch of timesteps for the current iteration.\n", "        t_batch = jnp.full((x.shape[0],), step)\n", "\n", "        # Predict the noise using the U-Net model.\n", "        predicted = model(x, t_batch)\n", "\n", "        # Sample fresh Gaussian noise at every step except the last one (t = 0),\n", "        # where the update is deterministic.\n", "        key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "        noise = jnp.where(step > 0, jax.random.normal(subkey, x.shape), 0)\n", "\n", "        # Update the image with one reverse-diffusion (denoising) step.\n", "        x_new = 1 / jnp.sqrt(alpha[step]) * (\n", "            x - (1 - alpha[step]) / jnp.sqrt(1 - alpha_cumulative[step]) * predicted\n", "        ) + jnp.sqrt(beta[step]) * noise\n", "\n", "        # Return the updated image and carry-over information.\n", "        return (x_new, key), x_new\n", "\n", "    # Iterate over the timesteps in reverse order: num_steps - 1, ..., 0.\n", "    steps = jnp.arange(num_steps - 1, -1, -1)\n", "    (final_x, 
_), _ = jax.lax.scan(scan_step, (x, key), steps)\n", "    return final_x\n", "\n", "def plot_samples(model: UNet,\n", "                 diffusion: DiffusionModel,\n", "                 images: jax.Array,\n", "                 key: jax.Array,\n", "                 num_samples: int = 9) -> None:\n", "    \"\"\"Visualizes original vs. reconstructed images.\"\"\"\n", "    # Select random images, using a fresh subkey so that `key` is never reused.\n", "    key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "    indices = jax.random.randint(subkey, (num_samples,), 0, len(images))\n", "    samples = images[indices]\n", "\n", "    # Apply the forward (noising) process up to the last timestep.\n", "    key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "    noisy = diffusion.forward(samples, jnp.full((num_samples,), diffusion.num_steps - 1), subkey)[0]\n", "\n", "    # Denoise the images with batched reverse diffusion.\n", "    key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "    reconstructed = reverse_diffusion_batch(model, noisy, subkey, diffusion.num_steps)\n", "\n", "    fig, axes = plt.subplots(2, num_samples, figsize=(8, 2))\n", "\n", "    for i in range(num_samples):\n", "        axes[0, i].imshow(samples[i, ..., 0], cmap='gray')\n", "        axes[0, i].axis('off')\n", "        axes[1, i].imshow(reconstructed[i, ..., 0], cmap='gray')\n", "        axes[1, i].axis('off')\n", "\n", "    axes[0, 0].set_title('Original')\n", "    axes[1, 0].set_title('Reconstructed')\n", "    plt.tight_layout()\n", "    plt.show()\n", "\n", "# Create a plot of original vs. reconstructed images.\n", "key, subkey = jax.random.split(key)  # Split the JAX PRNG key.\n", "plot_samples(model, diffusion, images_test, subkey)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 277 }, "id": "iqfjpn8havnI", "outputId": "756595c5-5380-46fd-f625-e21b4581381e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Full Forward and Reverse Process:\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAwEAAADhCAYAAACQs7kKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADHhUlEQVR4nOx9eZRcVdX9rh6rqrszkYExISFMCQIC8iNMYYyCDAKBgMgggwwSBEWUeVZYzDJpUBEjyIyCAhImQeAD/EQlGD6EEEQIEJL0PHff3x9Z+/V+p25VujuVpB65e61a1V316g3nnXvO2eece1/KOecQEBAQEBAQEBAQELDGoGx1n0BAQEBAQEBAQEBAwKpFIAEBAQEBAQEBAQEBaxgCCQgICAgICAgICAhYwxBIQEBAQEBAQEBAQMAahkACAgICAgICAgICAtYwBBIQEBAQEBAQEBAQsIYhkICAgICAgICAgICANQyBBAQEBAQEBAQEBASsYQgkICAgICAgICAgIGANQyABAQEBqxzPPfccUqkUnnvuudV9KhE23HBDHHvssav7NNZI7Lbbbthtt91in33yySeYPn061lprLaRSKdxwww0AgH//+9+YNm0ahg4dilQqhd/97ndFPZcFCxYglUrhV7/6VVH3GxAQEFBqCCQgIGA14le/+hVSqZT39cMf/nB1n16AB/Y+DRkyBFOnTsUf//jH1X1qJYFjjz02Jp/a2lpMmDAB06dPx4MPPoje3t5+7efMM8/En/70J5xzzjmYPXs2vvKVrwAAjjnmGLzxxhu44oorMHv2bGy33XYr83JWKazshgwZgq222grXXnstOjo6VvfpBQQEfM5QsbpPICAgALj00ksxfvz42GdbbLHFajqbgOVh7733xtFHHw3nHN5//33cdttt2H///fH444/jy1/+8uo+vdWO6upq/PznPwcAtLW14f3338ejjz6K6dOnY7fddsPvf/97DBkyJNr+ySefzNnHM888gwMPPBBnnXVW9FlbWxtefvllnHfeeTjttNNWyrmPGzcObW1tqKysXCn7Xx5UdvX19XjwwQdx1lln4bXXXsM999yzWs4pICDg84lAAgICSgD77LPPSslotrS0oKampuj7XR6cc2hvb0cmk1nlx14V2GSTTfCNb3wj+v+QQw7BpEmTcOONN5YkCVjV96OioiImHwC4/PLLceWVV+Kcc87BiSeeiHvvvTf6rqqqKmcfn376KYYNGxb7bNGiRQCQ83kxkUqlkE6nV9r+lwcru1NPPRX/7//9P9x777247rrrsO666+b85vM+3gICAlYOQjtQQEAC8Mwzz2CXXXZBTU0Nhg0bhgMPPBDz5s2LbXPxxRcjlUrhX//6F77+9a9j+PDh2HnnnfHII48glUrhn//8Z7Ttgw8+iFQqhYMPPji2j8033xwzZsyI/r/jjjuwxx57YPTo0aiursakSZNw22235ZzfhhtuiP322w9/+tOfsN122yGTyeBnP/sZAOC///0vvva1r6GmpgajR4/GmWee2e/Whvfffx+nnnoqNt10U2QyGay11lo49NBDsWDBgth2bKt68cUX8d3vfhejRo1CTU0NDjrooChwJJxzuPzyy7H++usjm81i9913x5tvvtmv88mHzTffHCNHjsS7774b+7yjowMXXXQRJk6ciOrqamywwQY4++yzY9e/xRZbYPfdd8/ZZ29vL9Zbbz1Mnz499tkNN9yAyZMnI51OY8yYMTjppJOwdOnS2G8L3Y85c+Zg5513xrBhw1BbW4tNN90U55577oDPezD44Q9/iGnTpuH+++/H22+/HX2ucwJ4L51zuOWWW6LWmIsvvhjjxo0DAHz/+99HKpXChhtuCGBZGw3/VnBMKJZ3/fnmBAxkDL7zzjs49thjMWzYMAwdOhTf/OY30draOiiZlZWVRbKh3he6v/Pnz8ehhx6KESNGIJvNYocddvC2qrW3t+Piiy/GJptsgnQ6jXXWWQcHH3x
wTIf7q29//etf8eUvfxkjR45EJpPB+PHjcdxxx8W2ueeee7Dtttuirq4OQ4YMwRe+8AXceOONg5JJQEBAcRAqAQEBJYCGhgZ89tlnsc9GjhwJAHjqqaewzz77YMKECbj44ovR1taGm266CTvttBP+9re/5QQ/hx56KDbeeGP86Ec/gnMOO++8M1KpFJ5//nlsueWWAIAXXngBZWVl+Mtf/hL9btGiRXjrrbdibRa33XYbJk+ejAMOOAAVFRV49NFHceqpp6K3txff/va3Y8f9v//7PxxxxBE46aSTcOKJJ2LTTTdFW1sb9txzT/znP//B6aefjnXXXRezZ8/GM8880y+5vPbaa3jppZdw+OGHY/3118eCBQtw2223YbfddsO//vUvZLPZ2PYzZ87E8OHDcdFFF2HBggW44YYbcNppp8WyzhdeeCEuv/xy7Lvvvth3333xt7/9DdOmTUNnZ2e/zsmHhoYGLF26FBtttFH0WW9vLw444AD85S9/wbe+9S1svvnmeOONN3D99dfj7bffjia0zpgxAxdffDE+/vhjrL322tHv//KXv+Cjjz7C4YcfHn120kkn4Ve/+hW++c1v4vTTT8d7772Hm2++Ga+//jpefPHFWAuL7368+eab2G+//bDlllvi0ksvRXV1Nd555x28+OKLAz7vweKoo47Ck08+iTlz5mCTTTbJ+X7XXXfF7NmzcdRRR0VtVwCw5ZZbYtiwYTjzzDNxxBFHYN9990Vtbe2Ajt2f6/dhoGPwsMMOw/jx4/HjH/8Yf/vb3/Dzn/8co0ePxlVXXTWg8yUYmK+11lrRZ777+8knn2DHHXdEa2srTj/9dKy11lq48847ccABB+CBBx7AQQcdBADo6enBfvvth6effhqHH344vvOd76CpqQlz5szB3LlzIz3uj759+umnmDZtGkaNGoUf/vCHGDZsGBYsWICHHnooOtc5c+bgiCOOwJ577hnJYN68eXjxxRfxne98Z1AyCQgIKAJcQEDAasMdd9zhAHhfxNZbb+1Gjx7tFi9eHH32j3/8w5WVlbmjjz46+uyiiy5yANwRRxyRc5zJkye7ww47LPp/m222cYceeqgD4ObNm+ecc+6hhx5yANw//vGPaLvW1tacfX35y192EyZMiH02btw4B8A98cQTsc9vuOEGB8Ddd9990WctLS1u4sSJDoB79tlnC8rHd/yXX37ZAXC//vWvo88ox7322sv19vZGn5955pmuvLzc1dfXO+ec+/TTT11VVZX76le/Gtvu3HPPdQDcMcccU/B8nHMOgDv++OPdokWL3Keffur++te/uq985SsOgLv66quj7WbPnu3KysrcCy+8EPv9T3/6UwfAvfjii8455/7v//7PAXA33XRTbLtTTz3V1dbWRjJ44YUXHAB31113xbZ74okncj7Pdz+uv/56B8AtWrQo7/X197zz4ZhjjnE1NTV5v3/99dcdAHfmmWdGn02dOtVNnTo1th0A9+1vfzv22XvvvZcjZx5z3LhxOcfimCD6c/08xh133BF9NtAxeNxxx8X2edBBB7m11lor7zH1OmpqatyiRYvcokWL3DvvvON+9KMfuVQq5bbccstou3z394wzznAAYveuqanJjR8/3m244Yaup6fHOefcL3/5SwfAXXfddTnnwHHRX317+OGHHQD32muv5b2u73znO27IkCGuu7t7uTIICAhYdQjtQAEBJYBbbrkFc+bMib0AYOHChfj73/+OY489FiNGjIi233LLLbH33nvjsccey9nXySefnPPZLrvsghdeeAEA0NTUhH/84x/41re+hZEjR0afv/DCCxg2bFhsQrL2GLNaMXXqVMyfPx8NDQ2xY4wfPz6nH/6xxx7DOuusE2tpyWaz+Na3vtUvuejxu7q6sHjxYkycOBHDhg3D3/72t5ztv/Wtb8XaP3bZZRf09PTg/fffB7Aso9vZ2YmZM2fGtjvjjDP6dT7EL37xC4waNQqjR4/Gdttth6e
ffhpnn302vvvd70bb3H///dh8882x2Wab4bPPPotee+yxBwDg2WefBbBsfsHWW28dq1b09PTggQcewP777x/J4P7778fQoUOx9957x/a37bbbora2Ntof4bsf7KX//e9/n3eVnv6e92DB7H1TU9MK7Wcw6M/1WxRjDO6yyy5YvHgxGhsbl3u8lpYWjBo1CqNGjcLEiRNx7rnnYsqUKXj44Ydj2+Ubb9tvvz123nnn6LPa2lp861vfwoIFC/Cvf/0LwLJ2wJEjR2LmzJk5x+e46K++UaZ/+MMf0NXV5b2mYcOGoaWlJbJrAQEBpYFAAgICSgDbb7899tprr9gLQBS8brrppjm/2XzzzfHZZ5+hpaUl9rldZQhYFoQsXLgQ77zzDl566SWkUilMmTIlRg5eeOEF7LTTTigr6zMLL774Ivbaa6+oD3rUqFFR/7SPBFi8//77mDhxYk5ftu96fGhra8OFF16IDTbYANXV1Rg5ciRGjRqF+vr6nOMDwNixY2P/Dx8+HACiHmbKc+ONN45tN2rUqGjb/uDAAw/EnDlz8Mc//jHqA29tbY3J7t///jfefPPNKKDjiy0wn376abTtjBkz8OKLL+LDDz8EsOw5Cp9++mlsfsa///1vNDQ0YPTo0Tn7bG5uju0P8N+PGTNmYKeddsIJJ5yAMWPG4PDDD8d9990XC4gHct6DQXNzMwCgrq5uhfYzGPTn+i0GMwaXp4eFkE6no0TA888/jw8++AAvvvgiJkyYENsu33jLd556Le+++y423XRTVFTk7wjur75NnToVhxxyCC655BKMHDkSBx54IO64447Y/JFTTz0Vm2yyCfbZZx+sv/76OO644/DEE08sVxYBAQErF2FOQEDA5wy+FUKYGXz++ecxf/58bLPNNqipqcEuu+yCn/zkJ2hubsbrr7+OK664IvrNu+++iz333BObbbYZrrvuOmywwQaoqqrCY489huuvvz4ncFoZK5PMnDkTd9xxB8444wxMmTIlekDU4Ycf7g3cysvLvftxzhX1vNZff/2IqO27774YOXIkTjvtNOy+++7RZOve3l584QtfwHXXXefdxwYbbBD9PWPGDJxzzjm4//77ccYZZ+C+++7D0KFDo7Xxub/Ro0fjrrvu8u5v1KhRsf999yOTyeD555/Hs88+iz/+8Y944okncO+992KPPfbAk08+ifLy8gGd92Awd+5cAMDEiRNXaD8KSzKJnp6e2P/9uf5iYEX0sLy8PNKtQljZKwH1V99SqRQeeOAB/M///A8effRR/OlPf8Jxxx2Ha6+9Fv/zP/+D2tpajB49Gn//+9/xpz/9CY8//jgef/xx3HHHHTj66KNx5513rtTrCAgIyI9AAgICShhcDeX//u//cr576623MHLkyH4tATp27FiMHTsWL7zwAubPn49ddtkFwLJJmN/97ndx//33o6enB7vuumv0m0cffRQdHR145JFHYpnNgbSDjBs3DnPnzoVzLhao+a7HhwceeADHHHMMrr322uiz9vZ21NfX9/sc7PkAy7KcmlldtGhRv7K0+XDSSSfh+uuvx/nnn4+DDjoIqVQKG220Ef7xj39gzz33zBukEuPHj8f222+Pe++9F6eddhoeeughfO1rX0N1dXW0zUYbbYSnnnoKO+200woFgGVlZdhzzz2x55574rrrrsOPfvQjnHfeeXj22Wex1157Dei8B4PZs2cjlUph7733Lto+hw8f7tUJZr4Vy7t+i2KNwVWBcePG5T1Pfg8s06VXXnkFXV1deZ+HMFB922GHHbDDDjvgiiuuwN13340jjzwS99xzD0444QQAy5aB3X///bH//vujt7cXp556Kn72s5/hggsuKCohDAgI6D9CO1BAQAljnXXWwdZbb40777wzFuTMnTsXTz75JPbdd99+72uXXXbBM888g1dffTUiAVtvvTXq6upw5ZVXIpPJYNttt422ZzZTs5cNDQ244447+n3MfffdFx999BEeeOC
B6LPW1lbMmjWrX78vLy/PyZ7edNNNORne/mKvvfZCZWUlbrrppth+b7jhhkHtj6ioqMD3vvc9zJs3D7///e8BLFsh5sMPP8Ttt9+es31bW1tOC8mMGTPwP//zP/jlL3+Jzz77LNYKxP319PTgsssuy9lfd3d3v4jRkiVLcj7beuutASBq3xjoeQ8EV155JZ588knMmDEjpyVrRbDRRhuhoaEhtgzuwoULc/ro+3P9FsUcgysb++67L1599VW8/PLL0WctLS2YNWsWNtxwQ0yaNAnAsudafPbZZ7j55ptz9sFx0V99W7p0ac4YtTJdvHhx7PuysrJopbLwJOSAgNWHUAkICChxXH311dhnn30wZcoUHH/88dHyhEOHDsXFF1/c7/3ssssuuOuuu5BKpaL2oPLycuy4447405/+hN122y320KZp06ZF2buTTjoJzc3NuP322zF69GgsXLiwX8c88cQTcfPNN+Poo4/G//7v/2KdddbB7Nmzc5b2zIf99tsPs2fPxtChQzFp0iS8/PLLeOqpp2JLJQ4Eo0aNwllnnYUf//jH2G+//bDvvvvi9ddfx+OPPx4tyTpYHHvssbjwwgtx1VVX4Wtf+xqOOuoo3HfffTj55JPx7LPPYqeddkJPTw/eeust3HfffdEa78Rhhx2Gs846C2eddRZGjBiRk5WeOnUqTjrpJPz4xz/G3//+d0ybNg2VlZX497//jfvvvx833nhjbAK2D5deeimef/55fPWrX8W4cePw6aef4tZbb8X6668f6cRAz9uH7u5u/OY3vwGwrHLz/vvv45FHHsE///lP7L777v0mgf3F4Ycfjh/84Ac46KCDcPrpp6O1tRW33XYbNtlkk9gE8v5cvw/FGoMrGz/84Q/x29/+Fvvssw9OP/10jBgxAnfeeSfee+89PPjgg9GclaOPPhq//vWv8d3vfjdKCrS0tOCpp57CqaeeigMPPLDf+nbnnXfi1ltvxUEHHYSNNtoITU1NuP322zFkyJCIIJ1wwglYsmQJ9thjD6y//vp4//33cdNNN2HrrbeO5isEBASsBqy+hYkCAgK4tGWh5fWcc+6pp55yO+20k8tkMm7IkCFu//33d//6179i23B5wnzLH7755psOgNt8881jn19++eUOgLvgggtyfvPII4+4Lbfc0qXTabfhhhu6q666Klpe8L333ou2GzdunPvqV7/qPe7777/vDjjgAJfNZt3IkSPdd77znWiZweUtEbp06VL3zW9+040cOdLV1ta6L3/5y+6tt95y48aNiy3nmU+Ozz77bM5xenp63CWXXOLWWWcdl8lk3G677ebmzp2bs898gGfpSuLiiy+OHa+zs9NdddVVbvLkya66utoNHz7cbbvttu6SSy5xDQ0NOb/faaedHAB3wgkn5D3+rFmz3LbbbusymYyrq6tzX/jCF9zZZ5/tPvroo2ibfPfj6aefdgceeKBbd911XVVVlVt33XXdEUcc4d5+++3YdgM9b8UxxxwTW+o2m826DTfc0B1yyCHugQceiJapVKzoEqHOOffkk0+6LbbYwlVVVblNN93U/eY3v8lZIrQ/1+9bItS5FRuD1E8dMz4sb3lVotB4e/fdd9306dPdsGHDXDqddttvv737wx/+kLNda2urO++889z48eNdZWWlW3vttd306dPdu+++G9tuefr2t7/9zR1xxBFu7Nixrrq62o0ePdrtt99+7q9//Wu0jwceeMBNmzbNjR492lVVVbmxY8e6k046yS1cuHC51xoQELDykHKuyDPmAgICAgICAgICAgJKGmFOQEBAQEBAQEBAQMAahkACAgICAgICAgICAtYwBBIQEBAQEBAQEBAQsIYhkICAgICAgICAgICANQyBBAQEBAQEBAQEBASsYQgkICAgICAgICAgIGANQyABAQEBAQEBAQEBAWsYAgkICAgICAgICAgIWMMQSEBAQEBAQEBAQEDAGoZAAgICAgICAgICAgLWMAQSEBA
QEBAQEBAQELCGIZCAgICAgICAgICAgDUMgQQEBAQEBAQEBAQErGEIJCAgICAgICAgICBgDUMgAQEBAQEBAQEBAQFrGAIJCAgICAgICAgICFjDEEhAQEBAQEBAQEBAwBqGQAICAgICAgICAgIC1jAEEhAQEBAQEBAQEBCwhiGQgICAgICAgICAgIA1DIEEBAQEBAQEBAQEBKxhCCQgICAgICAgICAgYA1DIAEBAQEBAQEBAQEBaxgCCQgICAgICAgICAhYwxBIQEBAQEBAQEBAQMAahkACAgICAgICAgICAtYwBBIQEBAQEBAQEBAQsIYhkICAgICAgICAgICANQyBBAQEBAQEBAQEBASsYQgkICAgICAgICAgIGANQyABAQEBAQEBAQEBAWsYAgkICAgICAgICAgIWMMQSEBAQEBAQEBAQEDAGoZAAgICAgICAgICAgLWMAQSEBAQEBAQEBAQELCGIZCAgICAgICAgICAgDUMgQQEBAQEBAQEBAQErGEIJCAgICAgICAgICBgDUMgAQEBAQEBAQEBAQFrGAIJCAgICAgICAgICFjDEEhAQEBAQEBAQEBAwBqGQAICAgICAgICAgIC1jAEEhAQEBAQEBAQEBCwhiGQgICAgICAgICAgIA1DIEEBAQEBAQEBAQEBKxhCCQgICAgICAgICAgYA1DIAEBAQEBAQEBAQEBaxgCCQgICAgICAgICAhYw7DSSMBLL72Eiy++GPX19UXb5yOPPIJtttkG6XQaY8eOxUUXXYTu7u6i7X91o9gyu/fee/GNb3wDG2+8MVKpFHbbbTfvdq+99hpOO+00TJ48GTU1NRg7diwOO+wwvP322wX339XVhUmTJiGVSuGaa64pyjmvLBRTtosXL8bVV1+NXXfdFaNGjcKwYcOwww474N57713ub6+44gqkUilsscUWOd91dXXhkksuwYQJE1BdXY0JEybg8ssvL3kdL7bennnmmdhmm20wYsQIZLNZbL755rj44ovR3Nwc2+7YY49FKpXK+/rwww+9+6+vr8fo0aORSqXwwAMPFOWcVxZWhh0l3n33XaTTaaRSKfz1r3/N+X7OnDnYeeedkc1mMXz4cEyfPh0LFizI2a65uRlnnHEG1l9/fVRXV2PzzTfHbbfdVvTzXRlYGfJtamrC2WefjfHjx6O6uhrrrbcepk+fjtbW1mibp59+Gscddxw22WQTZLNZTJgwASeccAIWLlxYcN9rqu4+99xzBcf6FVdckfe3J554IlKpFPbbb7+c75Kqu8XW2/b2dvz4xz/GpEmTkM1msd566+HQQw/Fm2++Gdtut912y3sPKisr8+5/ebamlFBs2Q5Ex1aLzXUrCVdffbUD4N57772i7O+xxx5zqVTK7b777m7WrFlu5syZrqyszJ188slF2X8poNgymzp1qqutrXW77767Gz58uJs6dap3u0MOOcStvfbabubMme722293l112mRszZoyrqalxb7zxRt79X3vtta6mpsYBcFdffXVRznlloZiyffTRR11lZaU78MAD3Q033OBuvvlmt/vuuzsA7sILL8z7uw8++MBls1lXU1PjJk+enPP9YYcd5lKplDv++OPdbbfd5o455hgHwJ144okrfM4rE8XW25122smdfvrp7ic/+YmbNWuWO+WUU1x1dbXbaaedXE9PT7TdSy+95GbPnh17/frXv3bZbNZNmjQp7/5nzpwZ6e39999flHNeWSi2bBX7779/JIfXXnst9t2jjz7qysrK3HbbbeduvPFGd9lll7mRI0e69dZbz3366afRdt3d3W7HHXd0VVVV7swzz3S33nqrO/DAAx0Ad8UVVxT9nIuNYsu3vr7ebbXVVm6ttdZy55xzjvvFL37hrrzySvfVr37VLVmyJNpu2223dePHj3dnn322u/32290555zj6urq3JgxY9zChQvz7n9N1d2PP/44Z6zPnj3bTZs2zQFwr776qvd3r732mquoqHDpdNp99at
fjX2XZN0ttt4efPDBrqKiwp1yyinu9ttvd5dccokbPXq0q6urcwsWLIi2e/LJJ3PuwU9/+lMHwO27775591/I1pQaiinbgejY6rK5iSEBkyZNcltttZXr6uqKPjvvvPNcKpVy8+bNK8oxVjeKLbP//Oc/UdA0efLkvCTgxRdfdB0dHbHP3n77bVddXe2OPPJI728++eQTN3ToUHfppZeucSRg/vz5McPonHO9vb1ujz32cNXV1a65udn7uxkzZrg99tjDTZ06NYcEvPrqqw6Au+CCC2Kff+9733OpVMr94x//WOHzXllYmYEqcc011zgA7uWXXy643QsvvFDQGL7xxhuuoqIi0ts1KZBSPPHEE66qqsqdf/75Xsc8adIkN3HixJhd+Pvf/+7Kysrcd7/73eiz++67zwFwv/jFL2K/P+SQQ1w6nXaffPJJUc+72Ci2fE855RQ3bNgwN3/+/ILb/fnPf44RWn4GwJ133nne3wTdzcXEiRPdxhtv7P2ut7fXTZkyxR133HFu3LhxOSQgybpbTNn+97//dQDcWWedFfv8mWeecQDcddddV/D3s2fPdgDcXXfd5f1+ebam1FBM2Q5Ex1aXzV0pJOCiiy5yAHJegxXqm2++6QC4W265Jfb5hx9+6AC4yy67rAhnvXpRbJlZFCIB+bDNNtu4bbbZxvvdN7/5Tbf99tu7+fPnlzwJWNmyJX7yk584AO6f//xnznd//vOfXXl5ufvnP//pJQHXXnutA+DefPPN2OevvfaaA+DOPffcop5rsbCqZPvAAw84AO7xxx8vuN0pp5ziUqlU3uPvscce7tBDD3XPPvtsyQdSK0u2nZ2dbtNNN3Xf//733R133JHjmBcvXuwAuO9///s5v508ebJbd911o/9nzpzpALiWlpbYdvfff78D4GbNmrVC57oyUWz5Ll261KXTaXf22Wc755zr6Ohw7e3tA9rHiBEj3MEHH+z9LuhuHK+88ooD4C6++GLv93feeaerq6tzCxcu9JKApOpusWU7b948rw/n57fddlvB3++zzz6upqbGm/xanq0pNRRbtv3VsdVpcyuwEnDwwQfj7bffxm9/+1tcf/31GDlyJABg1KhRaGhoQFdX13L3kU6nUVtbCwB4/fXXAQDbbbddbJt1110X66+/fvR9klFsma0onHP45JNPMHny5JzvXn31Vdx55534y1/+glQqVZTjrUysKtl+/PHHABDtn+jp6cHMmTNxwgkn4Atf+IL3tx0dHQCATCYT+zybzQIA/vd//3e557g6sLJk293djfr6enR2dmLu3Lk4//zzUVdXh+233z7vfrq6unDfffdhxx13xIYbbpjz/f3334+XXnoJ8+bN8/ZZlhpWlmxvuOEGLF26FOeffz4eeuihnN/k00VgmT6++eab+Pjjj7H22mujo6MD5eXlqKqqytkOWKa3J554Yv8ueBWj2PL9y1/+gvb2dkycOBHTp0/H7373O/T29mLKlCm45ZZbsPXWWxfcV3NzM5qbm3PsBxB014e77roLAHDkkUfmfNfU1IQf/OAHOPfcc7H22mt7f59U3S22bDfaaCOsv/76uPbaa7Hpppvii1/8Ij766KNoXsvhhx+edz+LFi3CnDlzMGPGDNTU1OR8vzxbU2ootmz7q2Or1eYOiDIMAPlKKlOnTvUyLfs65phjcvb1n//8J+c4X/rSl9wOO+ywsi5jlaKYMrMYaCWAJT5bcurt7XXbb7+9O+KII5xzzr333nslXwlwbuXK1rllTH706NFul112yfnu5ptvdkOHDo36+nyVgAcffNABcLNnz459zn7LLbbYYuAXvYqwMmT78ssvx7bZdNNN3bPPPlvwPB599FEHwN16660537W2trqxY8e6c845xznnEpFNda74sl24cKGrq6tzP/vZz5xzzpud6+npccOGDXN77rln7LefffZZ1Nf717/+1TnXV8F64YUXYtv+8Ic/dADcfvvtVyRJrBwUU77XXXe
dA+DWWmstt/3227u77rrL3XrrrW7MmDFu+PDh7qOPPip4LpdddpkD4J5++unY50F3c9Hd3e3GjBnjtt9+e+/3Z511lhs/fnxUifFVApKsu8WW7SuvvOI22mij2Dbbbrttwfkpzjl30003OQDusccey/muP7amFFFM2fZXx1anzV0plYBCuPbaa7F06dLlbrfuuutGf7e1tQEAqqurc7ZLp9NobGws3gmWIAYjsxXBW2+9hW9/+9uYMmUKjjnmmNh3v/rVr/DGG2+U/MoU/UUxZNvb24sjjzwS9fX1uOmmm2LfLV68GBdeeCEuuOACjBo1Ku8+9t13X4wbNw5nnXUWstkstt12W7zyyis477zzUFFREY2BJGFFZDtp0iTMmTMHLS0teOmll/DUU0/lrA5kcffdd6OyshKHHXZYzndXXnklurq6cO655/b/AkoYg5XtD37wg2glmnwoKyvDSSedhKuuugrnnHMOjjvuODQ2NuLss89GZ2cngD6b/PWvfx2XXnopjjvuONxyyy3YeOON8eSTT+LWW2+NbZc0DEa+1M9UKoWnn346ygZ+8YtfjKoBl19+uXc/zz//PC655BIcdthh2GOPPWLfBd3NxdNPP41PPvnEK5O3334bN954I3772996Ywbi86i7g5Xt8OHDsfXWW+PQQw/FDjvsgHfeeQc//vGPceihh2LOnDlIp9Pe/dx9990YNWoU9t5775zv+mNrkoTByLa/OrZabe6AKMMAUMzJFWt6JaAY6G8lYOHChW7ChAlugw02cB9++GHsu4aGBjdmzJjYCjhJrwQUA6eeeqoD4H7961/nfHfyySfnTPbxVQKcc27u3Llu0qRJUUahurra3XjjjW706NFuq622Kvp5FwurYgLgXXfd5crKytzf//537/dNTU0um816syDvvfeey2Qy7pe//GX0WdKzqYPByy+/7FKplHvmmWeiz/Jl5zo6Otzxxx/vysrKIn2cNm2aO/nkkx0A9/rrr0fb/vnPf3Zjx46NthsyZIi78847HQB34IEHrvB5r0ysDD/1zW9+M+e78ePHu9133937u3nz5rkRI0a4rbfe2jU2Nsa+C7rrx9FHH+3Ky8vdxx9/nPPdV77ylRxf56sEOJdc3S2mbOvr692YMWPcNddcE/v8ueeey1tZdc65d9991wFwp512Ws53A7E1pYZi621/dWx12dxVXglYsmRJxGwKIZPJYOjQoQCAddZZBwCwcOFCbLDBBrHtFi5cWLBP+POAwchsMGhoaMA+++yD+vp6vPDCCznZgmuuuQadnZ2YMWNG1Jf63//+FwCwdOlSLFiwAOuuu25Or1opY0Vle8kll+DWW2/FlVdeiaOOOir23b///W/MmjULN9xwAz766KPo8/b2dnR1dWHBggUYMmQIRowYAQCYPHky5s6di3/9619YunQpJk2ahEwmgzPPPBNTp05dwStd9Sim3h588ME46qijcM8992CrrbbK+f53v/sdWltbvf3BF154IdZbbz3stttukd5y/saiRYuwYMECjB07FmVlyXl24mBke/bZZ2OXXXbB+PHjIzl89tlnAJbZ0f/85z8YO3YsAKCqqgo///nPccUVV+Dtt9/GmDFjsMkmm+DrX/86ysrKMHHixOgYu+66K+bPn4833ngDLS0t2GqrrSJ932STTYp52asMg5Ev7eWYMWNyths9erQ3i/jBBx9g2rRpGDp0KB577DHU1dXFvg+6m2sX2tra8PDDD2OvvfbKkfUzzzyDJ554Ag899FBs7kR3dzfa2tqwYMECjBgxAkOGDAHw+dPdwcj2wQcfxCeffIIDDjggts3UqVMxZMgQvPjiizjllFNy9nH33XcD8M/JGIitSQoGq7f91bHVZnOLQnU84JJ+xeirmjt3rgPyrw506aWXrqzLWKUopswsllcJaGtrc7vssovLZrPupZde8m7DdesLvZStlhJWhmxvvvlmB8CdccYZ3mMyY1fo9Z3vfKfgef/
xj390AKK+ylLEytRbor6+3gFwp5xyivf7r3zlK662tjZnxYT+nsfSpUsHceUrH8WU7bhx4wpuO3To0ILn0t3d7dZZZx03ZcqU5Z73Lbfc4gC4P/3pT4O46lWHYsr3rbfecgDcUUcdlXOcDTbYwO29996xzz777DO32WabudGjR7u3337be35Bd3Nxzz33OMBfeWW2udDr+uuvL3jeSdDdYsr2Rz/6kQOQs9R6b2+vq6mpcTNmzPCew+abb+422mgj73cramtWJ1aFP+uvjq0Km7vSKgGcKW6fujaYvqrJkydjs802w6xZs3DSSSehvLwcAHDbbbchlUph+vTpxTvx1Yhiymwg6OnpwYwZM/Dyyy/j97//PaZMmeLd7vTTT8fXvva12GeffvopTjrpJBx77LE48MADMX78+EGdw8pGsWV777334vTTT8eRRx6J6667zvubLbbYAg8//HDO5+effz6amppw4403YqONNsp7zLa2NlxwwQVYZ511cMQRRyz3HFcXiinb+vp61NTU5Dx98uc//zmA3BXCgGUZ0aeeegpHHHFEtEKC4vLLL4+yUMTcuXNxwQUX4Oyzz8aUKVO8K1uUAoop21mzZsWeWgssy5zedNNNuOaaa7DZZpsV3Nc111yDhQsX5sx7sVi0aBGuuuoqbLnllthrr72We46rE8WU76abboqtttoKv//97/HZZ59FK4s8+eST+OCDDzBz5sxo25aWFuy777748MMP8eyzz2LjjTf27j/obi7uvvtuZLNZHHTQQTnf7bHHHl6b+61vfQvjxo3Deeedl3eFNiA5ultM2TJzfM899+Diiy+OPn/kkUfQ0tKCL37xizm/f/311zFv3jxccMEF3v2vqK1ZnVjZcdhAdGxV2NyVRgK23XZbAMB5552Hww8/HJWVldh///2jzweKq6++GgcccACmTZuGww8/HHPnzsXNN9+ME044AZtvvnkxT321odgye/755/H8888DWKYkLS0t0cS0XXfdFbvuuisA4Hvf+x4eeeQR7L///liyZAl+85vfxPbzjW98AwCwzTbbYJtttol9x1Lf5MmTcwhCKaGYsn311Vdx9NFHY6211sKee+4ZLVVH7LjjjpgwYQJGjhzplckNN9wAADnfHXbYYVh33XUxadIkNDY24pe//CXmz5+PP/7xjzltAqWEYsr2ueeew+mnn47p06dj4403RmdnJ1544QU89NBD2G677SJdVNx7773o7u72lqUBYOedd875bNiwYQCAL33pS2uM3k6bNi3nMzq6qVOnxgjWb37zGzz44IPYddddUVtbi6eeegr33XcfTjjhBBxyyCGxfUydOhVTpkzBxIkT8fHHH2PWrFlobm7GH/7wh5JvUym2zb3++uux9957Y+edd8ZJJ52EhoYGXHfdddhkk01iLRVHHnkkXn31VRx33HGYN28e5s2bF31XW1sb6WTQ3TiWLFmCxx9/HIcccoh3+dCxY8d620zOOOMMjBkzJkdeSdXdYsp2//33x+TJk3HppZfi/fffjyYG33zzzVhnnXVw/PHH5/ym0PKswMBsTamh2HrbXx1bbTZ3QHWDAeKyyy5z6623XjTRYUUnWjz88MNu6623dtXV1W799dd3559/vuvs7CzOyZYIiimzfA++AOAuuuiiaLvllbkKISkTg50rnmyXV3K+4447Cv4+38Tgq666ym222WYunU674cOHuwMOOKBk26ssiiXbd955xx199NFuwoQJLpPJuHQ67SZPnuwuuuiivE9i3mGHHdzo0aNdd3d3v4+TlMmVzhXfjiryTdZ75ZVX3K677uqGDx/u0um022qrrdxPf/pT19vbm7OPM888002YMMFVV1e7UaNGua9//evu3XffLdo5rmwUW75z5sxxO+ywg0un027EiBHuqKOOyllqsVC7xLhx4wruf03WXS6Z/Mgjjwzod/kmBidZd4sp2yVLlrgzzzzTbbLJJq66utqNHDnSHX744d4nX/f
09Lj11lsv74NE8yEpE4OdK65s+6tjq8vmppxzbmC0ISAgICAgICAgICAgySjdeldAQEBAQEBAQEBAwEpBIAEBAQEBAQEBAQEBaxgCCQgICAgICAgICAhYwxBIQEBAQEBAQEBAQMAahkACAgICAgICAgICAtYwBBIQEBAQEBAQEBAQsIYhkICAgICAgICAgICANQyBBAQEBAQEBAQEBASsYagY6A+6u7v7tZ19Bhn/d87BOYeenh50d3ejo6MDTU1N+Oyzz/DRRx/hgw8+wEcffYRFixahpaUF7e3t6O7uRm9vL8rKylBVVYW6ujqstdZa2GCDDbDhhhtigw02wMiRI1FXV4d0Oo3y8vK8j06uqqoa6CWvUjz55JNIpVJIpVIoLy9HRUUFKisrUVVVhcrKSlRUVETXp9fY29uLzs5OtLe3o6mpCY2NjWhqakJzc3P0amhoQFNTE5qamtDa2or29nZ0dnaio6MDPT096O3tRSqVQnV1NWprazF8+HCMHDkSa621FoYOHYq6ujrU1taipqYGdXV1qKmpQXV1dXROqVQKG2+88WqUXmH84x//yNG/zs7O6NXV1YXu7m50d3ejp6cnelF3KXfeh4qKCpSVlSGVSkXvwDId7+3tzdk/98vvibKyMpSXl6OqqgrV1dVIp9OoqqpCVVUVKioqIn3YbbfdVofY+oV0Oh3pTiaTwbBhwzBy5EiMHj0ao0ePxpAhQ5DJZCJdARCTQ29vb3RvqPvZbBZDhgzBiBEjsNZaa2HIkCGorKxEW1sbPvnkE8yfPx9vvfUW5s+fj8WLF6Orqwu1tbUYM2YMxo0bhw022ACjRo3CiBEjIn3NZrORfHk/U6kURo4cuTrFVxDPPvtszB5QxplMBul0GmVlZejp6UFraysaGxuxZMkS1NfXo6mpCW1tbejo6EBXVxe6uroi/aMec18cxwDQ09ODrq6uyD7wtxwbOj6I8vLymF3iWEilUrj00ktXi9z6gyOOOCIavz7ZVlZWIpVKRTLp6OiIZEJZ8lp57XzRRtBO0F6nUqkcf0h7QdvEz/gbypQv4kc/+tEqltjAcOyxx0b2rbKyEjU1NaitrcXQoUNRW1sbkzGvvaurK/JLLS0taGhoQGNjI5qbm9HR0YHe3t5Id+vq6pDNZpHNZlFdXR3ZzcrKysjOdHV1oa2tDQ0NDaivr8fSpUvR2NiIlpYWNDc3R/6ypaUFbW1t0b3t7e2N6Xip4cEHH4zOsbOzE62trbFr6enpQVVVFYYOHYoxY8ZgnXXWwahRo5DNZtHV1YXFixfj/fffx7///W8sWLAAixYtQnt7OyorK1FXV4cRI0ZgxIgRGDJkCGprayM5015Qvt3d3ZGdaWtri2RHe6W+jHFCKpXCoYceupolmB+zZs0CgGhc0h7SJjIm1RiN19fb24ve3t7IXvB3vD8NDQ2RDra3t0f6nMlkUFtbG8VZdXV1kf8bMWIEhg8fjtraWmQyGVRWVubEg2pzd955535f64BJgO8Bw2qU+rsPNX72RSNIZeILQOxvNaSfF6ijIFRONEqUH7dXGaiz0ABVB6C+ysvLo9+rE7dOTR0Qz6mnpyemfKUM1TsLn1x8v/ftL5VKxYJ6AFGQxPtmdb7Qsek0fUFVqaKysjI6bxomvgrJlNDAiKB+0Qh3dXWht7c3CmppiHUfAHJsCwNXBhdlZWVwzsUIVimDY4y6psF4V1cXysrKIlkpafXJ3Wd7+Ttum2+MKKzcVXeTorNA35hWn5Lv2nldtJdWVvZ6feNd76OCdrinpyfST/UFKl/fsUoV6kvUFhSCXifJFG0h96F/A3FZ9/T05Hxn74HaVn2Vl5eju7s79vtShfoV/ZvwjXsd75SJ2m2SJxtYqt2wfpS6qrGE9Wk8Nu9dEmI2nzx9cZX9TnWXctUXg/eqqqrIXtvv1IdaGard5uf2XAaCAZOAfMLqz/f5nI++1MHpyypxoYAuydCbSWXSQJsGzgZZvsGvTtkX1Kt
D031TQVURrWNncMVz5AAvZeQzYr7BzevxBad2n+pgfMGnZk9prHlcvuv9UMPB/0sdzEhXVlZGmXYaNKuDqieUNZCbIWU2hRUuyqOtrS2W4fYRAc3g8Pc6nkha8lUMSwldXV05zpO6yWwUyRGJER1xPvLO79Xucr+0sb7AQvWXsJlv6msSiADHOWXo8yv8HlhGdoFl16Z+yQZFTGKpPuZLwhAaPGiA1p8ERamCWX4bzFvd8NlCDZxYJQUQVUg0UKfOaiCvfk3HjQ24bFa1srIS3d3dJS9njaM0LgLiPk0rTV1dXRHZ1EoWZeycyyEB1nbwpbZTiRmA6F7oNlrlKnXYpIBNkjJuop76fJwv3uA9yGazkX2gbUin09GL3R/cN8+F91Hvq71PA0VRSEB/YQkAs3tsl2CpVd9Z/qPAVYFtlvXzADoZhXMukhUdkjVklpEqSWBwZg0rj8Xf0GBXV1fHXhrQKRlhAJaUjKrPyWvgz4Hty8JZWKfPfat+awZbW4G4X5vZ4v2hAaDckxBM1dbWxkhATU1NRAaskVRd1WBUs6vqtFpbW1FeXh5lvdvb29HQ0ICWlpaodG8DJNqVtra2SM+7u7sjMqGtXKWO9vZ2AH2ZJW3NqaqqisgBS9YM6G1W3ldN5HZAX0BMZ22JQKEMueqtvd+lDAYrwLLr8JEfm0CprKyMkSebXVXnzHGvWX7NavO4hJWx6rQGvaUuV6K6ujryRxps+3RDs6hqA9LpdDTOKUf1VwBiwRH/p3w1QNaWOvowbdtUglHqMUVXV1dsrNpgVf0Tr5G2hC1E9N/pdBrZbDa6N9oyyX3YSoKSDBIJ7lsTO7Q9+SpmpQjbpgr06aeNt3wJO03maTu36mcqlYqqAfSb2Ww2avnJZDKxGID3kUkCmyAYLIrSDkRhLO93qkTa98d+qfb2drS1tcXeOzs7I6dF4afTaW9mtb/nUspgwKKsmQ5fyRB7IjVjZMt7fGn5icpaUVERDV7+DyAykOz9Y/+fOneek5ahkiBz20rlC8g1s5rPEagu838NCnxVLWsYNdNieyc5L4B92mp0ShUjRoyI9I36QyLAErMGUoS9F+rMurq6APTNd+HY6OjoiPqE29vbY05Ig9i2trYoe9jZ2ektsQKlby+am5tjgaAtKVOPlHhq0kSDR9peBkQqd8062UqtjgPf2KHdsMSv1MFxSRlWVFTEkktAnz1VndFqH4OrVCoVyVbbVoE4kWAw7Mse+l6+Klqp6yyRzWZjSSa+fO02ahctSaC+sh9bK11An93gPdH+cyW+DHh1v7Q13IaVhFJHZ2dnrGKq7ZE20cIESEtLS7Qdf19RUYFMJhP5Kk0E+mwC26WAeLsXZU45atXA18ZdytBYwcZVAKJ3m3C11T6tCnZ0dKC6ujrS/6qqqmhOAIkX512wIsDtad+ZeM2XOBiMbFeYBGipKF/GlO86SEkAWNonEWhpaUFraytaW1ujkj/3QcNJR6eBcRIUqz9gSY5MkUZPCVFZ2bIJ0prxs1l9KqyvH43sk7/V8h6DuEwm41VGBg7McCdJ7jYDqoPbQnvuOJgVGrzqJGOdgGnb3AgNVtWA8t6k0+nYxMQk9KeOGjUqlvGjQeM1aHbYyluzLkB8vgkXD1BH1NnZGU1050Q01XvNinOCXFtbm/deJkF/m5qaYo5Fs0yFgibdVskRHYkSB20zVJtuibMFbQb1lvdcg7BSBoMeytFHfDTbp1lsZlNJANra2qJxrkGpHedAbkBh2y41uLJVtFKXqaKmpiY6X7Vx+UiAVgL02unjrH1VeWuAyoyp7o9EgtUJJRDMigOI7Hip24aOjg4AiOmsTTRZEsBKonMuNsm6pqYmiq34W12YwlYBmKBRvdVttfVFK5eWHJcqfG1LSnQIHaM6fu14JlFjdp9klCSA8qbfZ7ylNof74P3TZMPyqrWFsMLtQGoo85V7NOjSgIlVABIBrQLwb5IA7p9KbNuBPi9ghl6rHAyE6GRYRtL
ypvahAYX7zNnzCPQZEM2E2RUytBKgRkWNTxKQL7ujDgHoC0K1JcgXOOr90fY1bf/R+6jH8WUSfNWApJCAESNGAEDsOkggdSUeX8CqWW4lBLbNgo6EgRerhwxoNUCiY29tbY1NntUMLpAMEtDc3AzAPwFXW0tIwLSyp8GjElYd/8x++9q2fASY5wLEKxNqN/IFeqUG7dvP1w5EmWiLHjPMGohpvzU/0+wsdZB2Vm2CkjpL7pQQcD9JgVYC7HVqkG6rHgBygqqqqqqcFmK+NGOr8wNUbkr2+K5JSd4rblfqtkGDQeqa9f/qr5lM4djn9VZWViKbzUZtbvR9KicgbpP13vFvrVpzGyZkKOOkxAsaK2hVRVuf1d9Ywm7nmnCfXFWJMRxXZqR+W7+vY526qp0HamNWGQnQg+Qz8L6qgCqrTtZjwM/Mvwb/ujxoRUVFLPD/vAX/hC33qswYYHJgMzPS2dmZk8nyKYQNHrQaoJkqdXZ2aVJ7D3WpvFK/Hzr5ke++Aa1tTnQUhGXcti1AlxC0eqrZGQ0MfEGAvpIQTA0dOhRAPGtql4bzZTX5G6BP1/k/5arkSpdzpeMG+gIGGmltO9SgTLNSGjiUMlpbWwH49dYSAJbxCZslsskTfXV1dcXulbas8bjaGgMgp8LI49s5RKUKW6n26YQSdW1noSzpzK28NHtKmwIg+q31pXovNZOowWvSqgHV1dU5waJNXFm91sDHkgdmlzUjS92190+Thz5CwHunwRrjDGv3SxF2wZR8ffckAbSlDPL5zi6CysrKWOsU4UskaMusxg4kyIwVWGHRqoAvy15qoAx8SRc7zoE+XfPZZNpBnTPB/TMhqyRAW4B0e02i05dRnnaRgoFglU0MVkXVbF5bW1u0Pq8SADp4Onkb9PfHCObLYpUy9Gbbl+1t1LYoLn2oA47vWprWcmE+NmsV35ZVlZSQySZhYLOE6RvYKgfNhBQiukC8DK3zXOw8AJUhjYHK2xccr0iJb1Ujk8nkyJROwbas2HcNkHzBGHWaLXHMKNHx20DB1wbAQEHJa1JaCdluQmjmWh2OJatMCtj5U3rdSpBYEdDAwBIOfsb961yjQgQiqVje+Oc2PltpyYANGqzuKSHxHcNns0sd2n9fKJiyNk99Eu0zyZMmRtRO+AJ3X2WL+9V3bpskGduqqS8xanWO79ZukhT5bISOfxsj8Bj2nmoyR+MFndNRymD12I5tuza/xlWqe3YuBHVTq4u2ZcgmXVWf+Vsbb2iMN9hWq0GRgP4MDnuTKSwGrAz+m5ubo3eW91VZyMqBuNPzBU6+QZ40tLW1RbKyzJmGkNDWKn7PNghWV9hGZLPTqpD6vrweM3sPmaFNAgno6OiI9EJbcGyWKB8B0Oy/zV5rdUSXauT+bHVMDatdLUOJHBBvWShV0NmrwfQ5ev5tHbklYvycxhhAjnPS46VSqdiqSraKwmNomTcpJIstJ9bpUsbW+TPzRyeixAlAzCEpEbAEgHrLv4G+NjY9PrNX3EaTFaWut77qlNVFALExrrZWV/2yFQMSJRs0KbRaoNlWZlmZdaX+FtpXKSIfmfTJXf2P9elMnLACbucYaQsmt/OBOs9tlSDbZFmp2wWgf+2MvuSM2lQG7VrRz2dn6a8s2bf3k2A80tHREcUjgw1WVyW0LYqgbdQJ0+qrfZVWTSjSFjNuUn31xQb6ua3aaiuytgkNRmcHTAIGY3z0xHXCHp+expU+9Em2dklFVWBdHcNnZJNiIH1obm6OESYOThpCXpv2mTFDCixTXgb+JFT61DqtsFgGqxk+KrbdxpIMEowkkABODlXmDcQNHP+3jlZLy3zn52oItB1IAwP2AAP+zIkGUAx8NQgudZ0mCVC52QqAkh7b88iWP22FAPqCL05qU8Km9y+VSsV6KnViNZ2c7o+/S4Kz10qAdeaabQP6CIAGVrpClSYSNKtEqG5rj6+vTUXvJe0xHaElvaUKvT5t9VFdo//ShRmAPrKjLWe
spgDxydTaa879anaP2/Me2ftcaBJ4KWN5JECDHA3E+RufvaCt0DY0XSjEBmF80Z/ZlQnVZvMeJyFbTXlSt2jnfNUA6jeTJJo80EnUKjdFPn3Mp5fUZV2qmTGJBr+lCl2OndCVE3XFHiZhqX9q/3ScUvcYk1HX1C7oRGDrn3wVQh8GqrcrnQQoAdAMMlf3qK+vR319PRoaGiISoI+5Z/BrHU6hRyYnGY2NjQByH6wGIJbZ5KDXjLMqmW2t4qu9vT1n9RoqGpU6lUrFes10Gw12eR+TUuJTEsCAhX8DuWV9wpZR1UjaKoDKGIhnTpVsaRbFBsOcDE8CoOdSqqiqqor9bzN5lhyoY6cjY2+kDcB6enpQXV2d09/P41BGurKSTlAF4uuGl5WVxXS21PVWKwGUnf5P6CQxwK+3WkkF+loKdNlFvrPVh8dhhUWXrtWqLO8jdT8J8JEcn1/RpJTNlKo9YLAFxFtYqa92fg8DOCVvhAYFNhBOip9Toqp6onZNM6i2sqJVUrUXtAcc80oCbBst4w4lAEyKsRVZE2TaKljKULlafVC7piSAVTtWl1Tn1F5o4kYrZEoEfHNX+HuVu678SN9Y6pUArUBrOxpX8OEKU5SfVu21i0VJKGVi27SZ0FKCxSSg2hrC3neeG2O3lU4C+gvNgGigpMFjY2NjRAQaGxujliAOVM0mWeZvSUA+wzgYoaxOcCUQIDd74mOWOlvcLr2qPdTas25XrtFyNCdGWSVWcmBXdmpvb09EJYAZVQ4YIJ65A3INqnXyBAcp74GSI5b6NEtDo6vHsARAs6haEk8CGJj6YAMbm1HiGKVD0evu7u6OTVTXLD73zX3pikRKAkhu9Tck0EkgAUyGqC2zZWqgr4TNgEeJkmaalJBqhlRbUjTg5THY+kMnqM5K7XwSyv0E9U8DbtviQDnQwdvEia126bNe+OK9s8GvzfRZQqG6zdagJFUDlKTmS7IQtrWC16cVP12mmvGAtgKpb1Ofx3GhJMC2H9ve6lLXYSXshK+VR22kLtGqMZqSAV/FRl9KBjQJoMfXe6kxSWtrayxRUapgQoo+HOgj+LSBqVTfEu5aXbWrVzE+sCSVoH53dnZG41y/twRA5c372N8qgQ8rXAmwB7VtEloB4OBraWmJ2oHYBkQCwAFpewNppG05SoOGJBjF5YFP9NMgUQedzeCpsWMQyixHW1tbzjq9PgLgUzJl9WoQldGq8U3Cuspk6wxYSQas0bRVGB2UdhvrbHQ+i6/M5ws6bQbFtmgkAfkqJ3z3jU1foOXTeZuh5XbchwZK+pRrbQfi/eB95e9K3RkBuVlk1R8bDFJndd1zbqMBrtoP1T+f3mvAq8EE92erijZrXsrwJVqsDmqlShMp2sJmH5Bm26e0ZYPgeNf/1dZ2d3fH2gi5TVKqg0Cfn7J+y35mdY7Xqf5J7xNJgO6LctGAlmMdQGyFHN8DSZXU2XtVirCZYvVjthoI5JIw2yJp57HlIwW+pGS+6jhjP41BklAJ4Ni0JNZWP9hK5UvOMt7Vp1Fbcsn900YUismo97wPOk8G6EtMrvJKgDolWx71lYT4am5ujr1YllNGrsegYbYPULDBf9KJgDWalmVT3how2lYUsm7NSisB0NYdrbRo36CWXtVZ2bkaqqClPrDpiHXSpA1eAMTkpXID4pN8nXOxB4SRALDvUTNXfPkImGZklGwlCTZQsZ9R9nrdVga2ZKoBKNBnhG2FUBMDSgaoy2pAuU8g/qC8Uke+8jy/IzTY1MBdg3dLAnR9cACx+6Qk126nsPcv33alBk006flqsGTloUFjb2+vd01vgnqquqbVE3suvp54O160wljqYCub1d987SO+tgdrHwhrK1T31HbTPtskmV10g0/ftsFVqUIrmTa5p/5K/ZzNSDNu0OSrkv18sLrH46kNUNuahISAhR3POgZ9ts3aZ/5G74neF7Xjdq6Vj2QpIfGda76uheWh6O1A2iJh20Y4CVgnAzc0NMR
IAJecJJu3q7fYvjQfEbCw2bNShhpH/VuDQ5uBtpl+BqZ2fXXN2BG6vjcfDFZdXY2amhqk0+lY8KTvvDe6CkZSYEuXWslgYE/Z6TKo/K1mU+0cDJ3YzrYKIHe9azsvg3IGctd2TwJ8lRK+6/ik0WOLiwY1WjFU56SZI5uRUgLACoDOG+IELusMlUSXOvQabD+u9uWrjNjWp+OUL8rCjm1+zmMxSCBsRVadotqhJLUF6XNXKEPNCGsihE5YZaa2UG2gOn7N0vIY1i4AuVlyW+W2wUESSEBLSwuA3KSW7xooT82Osl2Luqj2gnbbtgNp1UYXImG3gS5oQTutz9YgSl2+7BqgftGnqL1UIsTFFehvuKgIfZa2YNNWaELQBqoai2mrmp4Lz4M2iMcudbtAO6nknrLlKoNMKOr8P1YFua1tYbXzuZh45UM1tdWd29JW20qO3ge1RyudBCgzUegJ6ARgtv5wDsDixYuxePFiLFmyBEuXLkVjYyMaGxtjlQAqox5THZ5v7fH+nHcSAiqfMQIQC17oXH0Za9uaYucFaMlPJ6PU1NSgrq4uCv41kwogZlxZiuXkmKqqqpwSVinCsmwgXrrjNepqBkpMuQ/tY7fbsu+xq6srCnZ9uqwPveJg12BAA4ckQINFINc+2AALiJfq1cDa1Tq0nGoz4NqOYecL0ZFxWwYO6iALZVdKBRxndAK+hRF4HXzgj80m+wgD9Vh1UuXT09PjrSQCy+4dJxgC8WDD1z5Xqmhvb4+1Rmg2WhNN/D9fYoZVRJWd7VPXsj3vkRIvoI/Aqc/Te23PodRRX18f+99mS32VKgZgqt89PT2xoFfnYFkiwPuhNqW1tRVNTU1R67E+adw3zzAJJKu1tTWWOFLirRl4PkOIZIr+hsSAxIi6T3vKpIpdCMASAL1/6hc5JnS+VllZWSIWDtAFEZjw6O7ujhYXsc9loi9j/Ebd4e9t1Ullx3kGdlU7m3RgyxDtq87f1JhhoBh0JcAXVCsJ6OjoQEtLS5Ttb2xsRH19PZYsWYJFixZhyZIlsVWBWKLTgItC4M1Q5cuXFSnkdEp9UAPLllrkNdiMpQ5yzTTZVh+twmhQpc8KSKVSkbJVVlaipqYGI0aMQF1dHTKZTKwkqkZVs4uceGnbWkoVdklQoG9iJAcnAx9tXbM6aY0dK13a6kY56cNyeOyKigq0t7dHZMtHAjSIKnW5ArkkwIKyVb21hlTb2ZS02hI3x7slANoSZFevYX81Dafe71JHOp2OkQB7fTZrpfqdr7Tc29sb276iosK7iACdii5pByyzTel0OjYW8rW5lTJIAngddNY28QQgIvaWCGgl0bZN2odeArltbXYOjK30WCLQn+p3qWDJkiV5v+N1MEDk0osMyjUTahNdmuRia4+1FXzRrmglgIEo7W9NTU0sAabJhlJFS0tLjAQAuZPRLSmlvjPZxSqJrmjHBF82m0Umk4ky2qqbWiXT72h3eJ9IKqqrq+GcizLjpW4bGMxrJYAkgL7ExhG8Th2vbW1t0VOBaQM0g0+5sQqgMYF2LFCWqt+MXVY0BitaO5DtCWUg1dDQEAX8JAF86bKgukQXS0ZKArQUYqsASTCG/YVd0lDX/7dLT2lVwGaefGsha1ZPlYxGcNiwYRg+fHgUdDBToGsnk0BQaUt9MCu0FGcrAZQndVcz+zarodkBDVxJBnR1IM3I0rFUVlbGiJmPBCStJahQdkfb29R4cXyrzfBN1uM+1A5oFsW2xthASVsQ+FvNxJa6fJUEaNCojte2/OicHiC+agjthTpz/R2zpM65KLhiBk+z3wwObH83UepyBfoeIMgVU+z8H8pZSZNeq16jtlhpS0q+SgDvAZBbsbFLBuo9T5LP00qAbRUkka+uro6qctlsNvqMK7fpij763BsmXmzsoJVy/T3bgPjAKiaFqqqqIv9XW1sbkdtSb3Flq1UhaKuOPnOGyS6SI/osYFmyLJPJRCSKgaZ9hgbfWc1WwqyEmIk
aAHkXxyg1ZDKZ6G+ONeoXiZT6IE1CsVNC21R9La0qS8qXHRhKPli1tS2GPT09MVK12kiADUSpNBygfCDY0qVLUV9fj6VLl6KhoQEtLS3eJwRrWdV3QTbwT4ox7C+oRAwgbabU9trxxlP2djKUVgS0EkAHr+W/bDYbGUHnXJQBZ3uWlsnVCCSlR1UdOM9XZUcSoFUTvpgd1LY3DZKsw+f29r7oSgm6fxoGIHeZt1I3mEDhVcIIXiMdvhpXZqlIAnRxAFt25r50n5oU0EDXBsFKsJICZiaB3Kd/WmLE/lK7hKfVQ4XdH/Wa+kuCR3uk+9WEQlLsgIK6qFUqoM+u5rsmS6q0tQ1Aji3Wcay2U6sBfM+3BDaPq++ljtbW1uhvbVfgdVdUVEQkM5vNxnSMgQ9tgVZdtfWSmX1tRVMiwN9qIoznwHuRTqejlthMJhOR21JGvuqr2graQEtebdsrl/lmZl8rrUpsNXDVKqISVx6HPlVtEwl1qUPbsm18y3Gv9pZtmqwo2mqetmED/moAKwJawdbjMxbUpISv/WugGDQJ0ABFnYz28bI1gi1B+TL/vlJGvoux2yYpUOoP8snVBprWGfjmBORbq5YDk/AprU5GsS+bbUxKdipfL60G9nag+3TS3iMlR1olsDLxjRXeL0sCkqbX1mGq/GzLjS21a3+vOmoaNqurmr23wRj3xW2om8wI6jhQ41nKyDdPSEmpBq2aTabjAHJXnNL7oKSYWVINAqyMOzs7o98yU5XE6qy1bUBuYgtAjh1WW6r7oqxsC4uOATpvbaUC+iZPsspqV2fRLGBSoNetY1XnlrACY6vcPrtsW19t+4/9Tueycb/aTqiBl62klToJ8AV+Pp+hvkV/a5e01lZfG2PxnfLWqqTqpcpNyRjPQ99LGSSk2u6nXRYalDOTr/ZX7SD1W5NPtkrt63DR/Svsb1fUjw2KBFgHTOHoMpV8QhwrAZwX0NzcHLVYqDLpBVGYejxlPLb39PME7S212WLNTiu4vWb+7epA+SamKNtUI6tZG91WJ7JoiTBfgF1KGIhR14GpA0xJj81asYwKIFZaZatBvl5ikgYNjJOW8ctmswDic0hssO1zBqrrWiWxrUBsr+I7P+d9Up1l5k9lni+bkwT56tOY8xFQjk29Nm6vgZePACkZ4HZcJYy2RtthdDvKlxUBYGDjbHVDl4W0/b+0oWx50Kq1tlYC8QSDtd82ANWVP0hWaWv0YWy0Pepb7dgodeiD0+jvqTe2kuKrnqi/12vWFjjul7aAQZq2SQCI2r74ty6AoYRWzy8JsPbNxkZaCbA6agNXrUTZldYsSQb6yDHHkdoSG7MlKfFSU1MD51yMHFG21FHKg61WtuVPSYDORVObDOQ+n4Wf+5I2PKbaYZtAHCgGTALsiaux06fCcVIw5wJwEjCrABQslS6VWjbjmsLW42mGT7N5VuGTYhgLoa2tLeeatZVHKwGaCVBmb+cB6KofmsnXrDWdHI9Px04HRSPA0ilfHPxKKEoVyzs/DiYOWDoY6xA0YAKWGUL2UAKI9en5+gYZ/NK5A31BVqnLMB/q6uqiMcj+fuqQLdUTGghotsWC+9WVlHSs0wYUat3wvScFrATYrKi2r/D6aU+pZ5yHoQkCW2Wx7QLaRgEgsh8MYElq9ZwYiBBJqQZkMplYAK59uFzjHliWnNGJpbZlTR0x74tOsqZsaAvUefO4nBhYU1MTyZi/7e7ujlXQk5IAY0bVJgypg9rCRz3lUpY6J0Dlp4E8iQDlTPko2G5EQkIfWFVVFbX+0Fax9ZXV2VKGZuw1HvM9a8k3eRdALOjnPquqqlBbWxtNliYh1aCYsYEmwzRBBsQfZKrnlIRYra6uLpZQAhBVmPUzAFFMRNnYrD71WPenJAyIz5PRJJdWe+2Yt8m0wWLQJEAvSh0MlwVtbm6OnghMEtDU1BRlpqk4DCJpDGh4NcDVkp9te7FKVerKtTxwso9Ptjajb42rtlPoaktUWBo
/LTszEONEbq4SQkelgSwNBldy4OpA2hOcNKjTZ2Cjg1ZLnoQGS8zkUe/ouLVVwJaftceSzkYzMtYQJAFDhw6NZMYsvHMumoRnJ+5ZqB4DcQMI9AW5DAiU+PN+cTtbhaBsNTPD+54E+XKODuWjmT/b7sBgR6+P45t2Qck9g1KgL7taXV0dqxLYgMKOd7VDK+qQVjVqamqiYJxtOEDf8omUlba36rNAVJfVltgqLrdjggHoSybwt5yXVVNTE1UUdaI8/aolIKWM2traWMbfJvMAxMgT+9RpL/m5TRqStGlMoF0ItlrIvxkM6ypEJNnd3d1Rl0ISqlkkAZpd7+npifRF4yxeL+cKaZBKndcedxKAbDYbJWkZYzD+sHMraIdpZ32JTNsiVKogCWBCq7e3N5oQbAP4ysrK2CIfOl9CuwaU0GpVAUBsf0qs1MbyPR8RGCwG3Q5ky3j2wWA0mFybt6mpCS0tLVF5WUtzDHZ58ZqB8VUcfG0GSWCX/QEnUqmTtxNO9Zp1QCrj1n4/NaA0CNqiwiChpaUlyojYB4SogchkMhEJ0JUUSj2g8gV+Wl7z6Y8aN0s2KX8NLtlC4ZvrYtuw1AgwgNMWi3znXIqgs+/p6VuHmpU9JbT9bWfwzY/QQN8SYTWqSoZ1+Ved85KUFjYA0fJ6dCKagaMtpLx0gh7QtyKTVgdpg20ShroJ9CVhtNRtWwz4nq/ykgSwjU3tnAY8tJ9dXV3RYhZcNljbMm0lgD5R571plZsBHICYTWYAlslkYqSPlQA+U6ejoyMRcq6trQUQT2p1dHSgvLw8ugbqHGXG5wFo1pry0xVoNFFIEsCAVRMKvKccG3bde9orXQAiCXaBOsSxqX3rXMqSyS2f/LQNTudHMMlXW1sbVUoARBUovpMgK9Gg/9KYTdvEfa1dpYiamppYko6JUV87pLZra3u0vgN98+TU52tCRmViJwcvD6uUBOiJ2vK0b0Ua63x6enpiSseL16UwqTy+42hJKl/wTyeZRNAAKrP3OWDNuKkhoHxUNpoJpWJqXyqVnQSEPcHMjGmAoP3v2jOYhIBKM29AbumNg1GzQFoJsDJl4M7tKFdbMrX3TokAnY4eS4OyUs+YEJwYpZlT20OuGWvrCGxbir12JfwqP9vzzyCDDx/U/k1dvzkpLWwAonWgKQ9ej5IAQit51G+tKDIoZQVFyZH2seoqZZxrwKDKV5a24yYp4HM8dFwCfTYR6Htomn1R7qqHGoDaJBWQu1oW0BeMMSvNF5MzSkrYasvsZKmDvdWUJ3UO6PNb2j7C8QsgZ5yqnAHk7EdbVjQTzgCY8tUquNoaJQ5JgFaU1Y7ZuAzoyzQr8bR+nO+sRpGMaqVVk7VaLdc4Re2ErVjalbJKFWwFBhDNd9LA3FZBtTtFg387T8r3W99cAiUZhbL9/f2sEFaoEmArAj7jZwMgzSgzG63ZAO339wXzloT4jGuSocvxAXGHq5l77fHv7u6OBVw2kATiKwBp5p7KRqaumRcNAvJlwjU4KfWASkmAtocoqeKA1s9pwOzEHl/GSOcU2O3zHRvoI652XCUlK2XbQPS8td0PiE/0V93m/dEA3zoVlUu+0iiPoQbWOiTqtf1dKYIEi3JgtQPIneCnzogyJCHVuQTcDzNZlKfVR5s1tKSWKFQZKGXohEcNNIF45k7lSkKlds8uk0qd5/4pTw307ZNqNcGiS2cqsdCKbxJIAElWb29v9HwUDVDVzzFIVCLGIMqunAL4bY62uHHcMHHgIwFAcpcPVl8BxJ9Sy3ZelYdWQbW6b58OrNV+VvptckdjByUBWknJl+xJgt7mO3eSdRsjKPlU/dQ4y8aq+WIPABEBUJuysmxrUR8WptAMk87i1xUQMplMFGByLWzt9VeF8rGhJDmb/kLbpfhus5YcfOoUAERy7OrqQnV1day6oo7KlrR0FQzuRx2cvmzZnEFVEioB+Ry+6hcdj04Mo5HToNQXxHNbX1uFxUAcTRIqWyzta8A
JIDb+tVrl6w3VzJztabdy1t9Q3mpIGXyUl/c9AEodfZKQTqdjZIiBkjpa6qe2RWg1yTo0rcwA8YnpGlxphgrI7UG1SYGkJWVsX75vXGuQaDOeSpSUvAKIAjHqm237YTulrrRme4qTDrZV2MoKJ1mqrlEfNRFm9U+hPlDbDWkvlIBR9mwHslVKJdlJyFQDiLXjKMnkdbCCyGCf5JPz+LTarxUobfXNZrNRJRBY1qngq+ZQ/ppY4W8Ye/A+JIFo8YGJTJCqDnG+DuMr+h7Vb+vrKDM7H4ufW78I5PpDtS1A/nkCA8Wg2oF8B9WgihfH2fc1NTXo6Vn2mPny8vLoM/absYeNQtInrtr90uj6mFGpK1Z/oNfNm2/bGNSBs/WKDF9LgcAyQ6ATg3V1FVYQlN3zO9v6YrMlNOq+5S9LFfkeuqQOl3JVeWhmg47fBqaUPbenUdb9K9HQQMJWCXzBSKmDj7BXYqoZOC2zq8PlO+Wvfar8XrdX0BZwfNDxa2+rrqql82V8si9VkAQwYFcSAMRlpAER9c5XieKYZuBkyb4dG7S/NgkAxAOBQm2apQhbDdLxR9jEiRICJQCazbctLLalUoMvLglqe46VsPraMbUNrFSRzWYjWdH30MdbIqvEQJNQmj1WndfxrHMFmZQiseJ94H3Rh1pZ+5oEnSUsWeI1l5WVRXNKgD791UnRnGBtCQBX/dMYTR92pf6ex7ZVKh6Tx+exlEyUOhobG2NxFq+LbXpVVVWxZHVvb2+kh2VlZbEVHXUVRa1kc7xr+7DuD/A/XZjQ5A63H4z+FqUSkC9oJZOsq6uLAnz2nNXV1aGmpiZSsKamJvT29kaz+8m+9OLt30kMlpYHBpnqOBjQcEITWTZ7fKuqqmJGlaAy6opCDMxUwdl3CvQFqNrKogGGlqP1OQFJuAfa/5tvcHEwq3PiwNLMhgWNMDOyAGIOWx08B7t9AjSJQxJ1urGxEUA8w8wgiQsAqIO3bW/afqITd3XVGxsMUL7MaFG+dn6SfXaJyl1XcSpVKAlgW4Vd6k8dMV/aYmXbe3yBu2a2VW+VZFHeOvlYA4B8k+JLFflaBG3GTb/ztWHYoErJqJ2IrjbHZ9uBeDUWiPdw63KOpQ6d4MwgUHUHiGcydbI10CdbW8ViAoxzf5i5BfoIE5M+tj1WiZ4v28rjljo0uccXM/laubPXxf/pv21FQIkA98UxrsSKUAJHG626zLEAJIdk1dfXR39rpYVjW+dZ8anLOk+Hc7NoE+2qTLQB1GUlFL4FLawN0aqBzj3k+Q4ERSMBlgBQmWpra6Olk3p6eqI1aIcMGRKRAH7f2dmJxsbGiDVSkfJlbwsx+SS0UPigrSZAfNkuu1wXqwC+iX38jEvZ2ecraLCWSqVi69X72oGU8VOxNRhJgqx1fWnVV0sCNJjUF8mTBZm9Zq0of3XyGjTpxFldLk9hWzBKGc3NzTmfkRjx2kkWdfEAtgNSdjbQUdKgwZnNquoxbAZHV77i/zTYSShNs6xPJ6SVPw2mgNxWQSCuP3ToHNOUhy45qXpL2XCca4LHJiTUTgD5W+FKCb5gSe0Zx7zt+edvSIy0zce+lAho9YR2WldrA/p8gCZjdFsuaZkEEsDMqbayWd3VyrOFEkqCOs5x3djYiPb29sjesNWYcrWEmdBEkJ5TUhIw2o5i+/t11T7A/7RpXRGIL106VZ9crf5e5aSZbbU53L8GvUlKbjU2NuYkC5kA0eCdz1Xi0ql81+oJZa8y5nin3OibuPyvdsPw3tqVsejjfAvBDAQrbEXyVQFY5qypqYkqAOxZVxLAdhTnHJqbm2NG01emXhNgbySdMkt2GmhyIFo5aTZEnyDMzKhm+n09/cpAbaCvQZbNoJQyfCV6vT6gb3BZUkWZWxmr7C20DKsPZNGWInXylnjx7ySAE1V98mWgROOpFSzqoF23W9drZ/aQGQ/Vb+u0fBkqyp2l2ba2tsgeJYEEUDdtRlNbTOw1aAuf1SUSAa0
EciUloG8te5UxoTIvRNSSQACAvgexqWx0/HFlEG2D0vYgrZpoFpVJG06u1OypBr2aqQbiywbaHncN9NRelDLURy0veaetItxG5cTfaysbK3xtbW2xKhnbWCh7/Z1ttdB2W9tyUcqgHwH6bANbsJkwZJJFA0atYOnDPjVY9yVbfffMdgtoMMr92AeVJcGntbW1Aegb34xJdfxRlvr8AG3pUTlSrppMoew5xtmRofcKiC+IoS10dk7MYCuvgyYBvoCTjlrXmuWEQTp1VgLq6uqiUhOXEdXeKRt0quH09auXuiMfCFS26nx0dR8SKyqBBjPqMHS5Vh2Edp1pW4HgsmC6T0IzO/mMeinCnuPygm3VOxssqu77nLuv118dmxoHJQK+/SZBttaxAvH+frb2sd1KMx0kAXRedGAAYk9X1CyTrqiiJIvHVmfPLIwGZwBiDrKUodU56qBmp7TvWQNUHzHg/ihTJgeYgQL6ni6qlQSrk9wP96s22VYTSxnUBcJeH/VTA0VWZkhONYNqiax+zkBBHbgeiy0uvIf8n/ZBA5KkkACtOlud8Nk1bbnk+NT/Wf3gfdHnE3Fb7UsnObOyJ3Qehy/hlST4EiO0vST1TABwe/6GUNtA/WMlwK4wZv2hEgdbOdT185NAsrS6QRmqDlJHtGWS16fdBHqtKm+9P4y1tMKrMQXvnZ1P4YvPOC4GghWqBCgB0HKJLVOw1QdA9FTE2tra2AoefPosBWqFoD2vdF6WeSa9FQiIrzBRaMBo6VqzdcoyM5lMzMnbp/tSflq+Zy+x9ldqoGV705JCAIB4oErYc9d5D74+Z/2dtgepg7HtLWz94e9saVYDLM0M6v0vdRnz+nzGUScEAn1BLVuFSAJqampQW1sbkQA6ecpMjatvfWvaDdoMtQk8Vm1tLVKpFLLZbMyGlDJ0WUX79FkGpFo6thlNq6Pah8r5Em1tbbF5GpQ/0NduoNlqngdlbSdnJkGuwLInXRNKYjTLRh9GIkmnzM9JLm2Zn2NW7aq2a9Gpq13SymxPT0/0BFttyyKhSIJ82abDgNLOGbFJFAAx+dKXazDJ/y0RoH3Q8WFbrWzrkW0HSlJlm+fp88e+6p8N+Lmd6rzaBCaoSEibm5uj3nfNPtuKirXLlgQkIWbwVVttYs7GX/oE61QqFZtsra1A+rnqeWdnJ9LpdNTCzRjNRwh4jvo+WAyKBKggVAi+thROBNZSMyed2F7XfBPe6HS0pUUDtRXphyo1qCHyMUItF1tDRuPI7dVBt7W1oampKbZCAGWpWUbNoGomhStY2MGQJOjyq5Sxtv4AywJVPtyuvb09egS7Lu3lI0CaPdFgl8fjvm3ZUEv9WsGxRrPUwbYKzQRpGRXom29BuXGFBf5tn1LJ7AiDJ2ZkWfa3yyvaIIs2gzJmn6wuW5iEQKq5uTmmX5w0xrYHDXw0KaO6w2u1E/yp562trZHj0ayTzZCyDYPfW0Jhnwpd6vZ41KhRAOItjjpPggTU5+DZ+6ur/KgPo9/SYFf1TjN7WrVloKtVM2YkmexJQhsbADQ1NeVUnfQpykDu8z0ARJVuIO6HaE/sktYkSmxlA3JbOai/qpeagLEBaqnbXeoJr1PHuq002etVaFLKVgl5jO7u7mj+BRdu8VWorE3WRFCSCFYmk4kl5TQutZV+Jr/Z3k5/o/Mz9PkL+gwGEq1UKhWzodyvxhOExizazjVYvR0wCbBZVGWXFJiWqbPZbDToeWFkRWSYPT09UZCpiqIMVTMJ+uRLm81OOjjo7IBRpq7BjpaptK+Pf9Npt7a2RkG83reysrJYsKSypJJZpbWBRVKgqyn4WD4DGiUAdFrWsWhZmr/VTArQN6mNL10VwydrZqt1QlZSelRZ1bMkQMkMjRe/y2QyETFiEEUSQGfOpYPV4dgMrLYCUea8j7QRDCS4GIGOg1IHAykNVBnwMAPtaz0D4plPDT5pQ7VdkBUr7k9JqQa2anMpbyUXSSEAADBmzJiYXLU
ywr5gmz2urKyMrd6hlSgd47bXVwNQJVHWb2pAp1Vvjiut1JY6uMqK6h3HpU0aqnxIgoD44hh8/g2ThkB8AQddjthmpzULDuQ+eDRp0MSLtkqRFAGI+WtCK9jcnuNdk62tra0xXW5vb0dTUxNaW1sjYsr9aXKXCV7eN1+FpdSJgJIArbYwntIMfllZWeT/WEGlv7HLsLIdlfEuk1r0cboaHhMD2o0A5M41tAsPrHQSQFgHqoGlZkfsZBQNUgFEF6nMUbNYNkOTjwR8nioBKktbUlbjmK+sSdnTmXBCpDpxbRFg8KsyZMClE9744BD+zlYQSh1aCVAioJPXGDzqJGpWS6jvNLa2HUjbIpi5U2dv11i250L919UdktKjqhMslQTYoEaDHS4fqCSgpqYmliFhQKoZEspWS6y2V1grhTwuszU6WTMJstWVl1TfbMuPvjTwtw5fK4oMynzLpmpwqhPS1CnZHuEk2QNgWSVACXx7e3u0klS+tjIGmhrEWsLuI176OZCbneV+aJcUSjiSIlugb+lg24am7WQAYkEVgyggd/Iu50MwTtD7pLqrLZlKrmz7kd6jpMURNmFIeZAw6gp0+YJwnw/nvBRb5WZyrKWlJfYcGN4PfQgs448kZf8VfCAYkNtvr3ZBA3ZWCIH4ggG+NiC+67w0XSqU+7DxghIBEmMmwjSuHggGXQmgUdISmmYBdWDZLJWWRm0mRVdisIqoDswOXoUt9yUJtqxjjZxuowZTs3YcjKy2sKzNikAmk4nK3fwO6JuUpSRB96mtHWpsk2I4taymhpG6TCdlgxvrnGw7kdVR2xPNlXM0E67H5yBXw8ESY1JIgFYDfe1SlC+rWAx4+FtWnPikSrakMLvHV77ed3sfVTd1bGSzWQwZMgTZbDbnyaGlivb29uhvla/N/NighmVm1Vvf9vybuqzZWB5TbUS+/m6i1O2AYujQoTECX1FREWX7mHG2JED9G5D7pGBfxUV9l8LeF2sXbGXXtqyUOpTAqlwKtePZa9NgnkGTLzawRFj10MqWv+M2NrGWBB1m9hlAzrXxemhzAeTVG42trE3g2CD51+Qrj6sT5rUH3rZ2Jwm6Ilqh5Iq2tGlLvPXlvvYojXHtogJ6L7S9yyaJ7XzNwcQKK1QJUPBm0xlbRdLf2cBWJ/1YQSmjL3SBSRm4ywOvLx8B0uu32RLNsjLo4fa2XG3LhJoRsAGwHkOz4bw3anBKGXot9m/N6qtDptOhMdUXgNhABnINqgazPuKgmS415Pb/Unf6Kk87/tVhqJP1lVv1OvONad82uh3vG0kDs7fMwOgrCSTArtJjyZQlWipX7QsmCeXL2lraYiVVqs88F0uOCbVdSQHnh+jEXVYEWMn2TeJXgsnPKX9C24v0gY3cHogv8GAJnrZ4+OxCEqAtVXz3JfHUfvBdr9H6JStL3juFtqhpO7Jup0kb21pc6v6MD0OzsJ9Z/wbEn+gLxO+PHfOUi01E8l1bBu1qQOoX9T2JsL7dV0EC/C3zNvmiNkf36Yvl1D5zn5bQrQh5HVQlwGf41QHxQu3v9G8GQnqh2uOUyWSishSdnjoty3yTlB3pLywDtZk9zYjkUzYgt+Snf2sZVfunfeV+7R0Ekuf0fXrLv/UayNw5MOk4tM2FmQJ9UuLydNA6OF+Wj/ux977UdZvVDiB3+WAbqOp2lClXAVK91knauvIM127X//UY7E9lYKAkwLeCUKnLViscQDzL59M9JVVsq9Igk7/VTLUSAF1Ske1x3K/aADvuC93rUgXbaxiwcIlrVj51NSaCcqBc+RmQuzQtn8Hg64O3iQRLAHhOvC92knYS5EsSAPgDTf1cZaLLLpKYcVlrnbjNe0eyxHfqdnt7e2S/fXOsNNOtwV0SoC0rCl8FxMIm92zspkkErVTT7/EYdjlcW53Nd16lDtVNG4Np8K+6w9/Q7mo1wQbutK2Uveq03Y8mxAmr57q/VVIJsJm4Qmw0n0PQjBZfdNTs26Wzp7P
TSZM6GcI3+WV551eq4Plqlt2W9wltB9Cbz+3ZQ63zKFjO0/4znQxEtplOp6PeePYM62oNlowkERqEqsNVI1ZZGX9IleobjZqu7e0r5dvMnr70M9vjSgORBJLV1NSU9zubRbUEiKSL+qkrUnBVq5aWltij2bW6QJkxOGKyQI+pFbKenp5o1ackJA9sn6eSdyCXVKou6fYMau269dRtBlW0rQxiScjU3mg1zEf4kkIE6F9oa+0kPtvC48v4aaClwSTlx4dZaR+1ZvWVsFF/9Zx4DI4T3t8kQEkA4M8I+/RXs8m9vcsmV3N1u66uLrS0tETBD0k+ZcrglbJ3zkVkwFdZzUdKSh11dXU5iYF8bVb5gkMdt1ph1Ey3Zq6ZqOX+2Zdu9dImh3XcJEG+nFgN+OeOcJzbZcQ5btnOaxOpjGm56p36e/v0dk2U059xPPC8GM/5qgP9RVGeGOwrg+j39l2314vUh4zV1tZGEwMpWJ0IwdUBtI2gEBFIClQ2loGytUS/B5Bj2EgA6Iy4RKiudkNl09VBGGDRaOqLD7IAEOv1S1JpGvBPCtYgHECkdwyaNOBSg6mVAiUHNhOrpVKtJOjvdPKsBrl6zFIGe3+1wmQdiQbk+gRGdfr8m46cE9Gam5ujpenY1uKrPOpqDJo11XHEfQy2fLqqwUDVZplt9olBogaTGpzSkXMhBg2KuGQwg33aXT6EScmDr69VbZDNjpcytI3EVqS1Nxfoy95pYFSIALGK0traiubm5iiTTXmrfuqxeTw7f8aSgCTIV+ez6Fiz2X8lRDrhnz6vvb09SoJxbhufQcA5a/wtEwitra1IpVLRXA8bK1iymgR5Kmpra3Oq/XaCdL7qN+GrZBF2ngT3q22A5eXlEWn2tRrbc0iKzVW9JZTM6NwoVvsYkHN8p9PpiATQ7loiylgi3/0D+iowHAuWVNh5iwPFCi0RagPWfNvaAWYzKOrAuMxUNpuN2LwyIl0SUFuDkjaA88EOQq0G0Ahqf58Navk7VgGYRVESQGevDFbbAnp7e2Nr5fOdwZkGvtqSVOqwxt9mnmyriA4yWxakjHn9Wklgq4oGRzb4t6V97gOIl6iTklFtbm6OOQ3bYmazGr61pOlESDapo21tbWhtbY3Wsqf+adma+2YAwYeO2aoCV7bgA1lsX3spgnpBGVKmHL82G6pPolQSpM7LRwK0d53HU1vBSbLcr03E9CchU2pQ2diJdpxToiV/jm2fHVE9BPpK9S0tLWhpafGSAN4rXdpSbb6epxKBJJIAQsebys72lVOf1R+xdUL/10QLdVDbh3QVML1PtjLpqxKUMmpqanIShbS5rHj6qgKEkvZC41Z9ni4Iwn1oK5CvBSiJlRZtgbTQyggTrKxUs4LAxVlsJcVXhfLtX0F7wXtKu6IPdtPYbaDyXeEnBucjAv0JXvIFYwwQ0ul0lBHUEos+HMgqcFIGcD7YG2izqdomQtjrpjPh9toKZHurdSlFfbCSfS4DAw5f5i8pgaolor4WCg3EfT2AtsrC8p4NBGy7i7a9+VqA1KnrQC51mRKtra0A4g9NU93RbCvn/DDbqg7cVkSot6wCcIKfTlCjTlZVVUXHSKfTqK2tjSbPMaPNtgx9CFypt7Nptlodqo5dEietnNKeavWKOq16rhknDfppO7QCyePoHC2tPiSZBORrSdHAylf51EoAZa8lfuobM9d05rxXfGf7AL9XG+CrWiahAstkFKF6TD3R9hKOYyWpGnzSttg5AdyH3gdW/LgCntVNJa+2UsBzKWXQ3gHxp1JrLKTzCH12TnWrUByltsNXgdRkrI0F1cZo5aKUocuJ893GWKqLJAL6jCCgr7pIYq8VEuvnLSnVzwDExjztFkkA/dkqJwE8+eVVBOz2QG5PoDV0VVVVUfaZF67B1Oct+FeoPGnYqDTqtLiNTwa2/O9b+lIDXAZDNBpaxipUekoKASCsM7DVKHUi6rA18CJ8s/V95WbNuDDgV/3Vwa3HThJ
0DX8G70oCqFca0BKUq2ZZ9XMGpro0HY+llRo6FgYFzMYAy8YU+5N9gUQpQ0m/ZtP0+oE4WbBEVGWqFUMG+ul0OiYHzfjxnurkNNvW5Stz8zxKGRq4WyespF0DJf6O776xTqgNZpKF94Pkin9rkGWDtnznVOrQ8WoDQb1ufqY2USsqDObtfDZbiSUscbD3y95nyluD4VIPVEkCeI38m+PUZze4TSF5APFJ/pp44Da8Z/SblkDpsfiej4iUIuxy4jbGsqRIH7xIMBnFNiDuSxOKOhZ0bOc7LtBXEeY5aEt3ocpPPqwwCfDBKoLPEVgn4VNGTk4j1MkkLeM0GOhgHejvbAXBZrO1PUANhX5nf6ulvVIPnPoDq0e2HKzZQQ54NYa+oJ/79b18xFWPp/faVxEqZWjLFIMdrQbQEXNiH8c2PyfxVJ23jt9OhrWVGg3mNFMNIKpiAX293TaILlX4bIAd4/yMfy9Pxxgs2FYKbWVTe5FKpWK2QvU5ye1AgD+gL9Qm0l+CYwMF1XfqrS5Nmi+Lp8QjH9koVdiAxJf9tP5Es9n6O18ySvejpEj9m5IAa+91InGSyCvQR/rVf6tf4hhe3rVY/2WrCUA8FsmXwM13nFL3XT74ErH2e9Uv+iomw0jwbbJVbba2yfriYCtn7teOC9oRVh4HipVCAgYLn5FTYfiE9XnD8gL/gQwoG7T7ynJ2Ow328+1jMOeSFPj0yqeDPoaeb3++d3u8fIFeUpCPOCp5KkRGNdDUTLV1cno8S0wJa0fsZ7q/JBFaX5ZNZee7Dp9++Yhvvsqqz24UCgIGMjZKCfnO1Ufw+7uffPbXZkUHqoNJClSB3HGrf/v0Sf+2eqe2RGF/p8fS/diKJO0NW+SSZg8YpNoxqNvo+/L25xvHvn2tCXEYQRkziQTkPvnbFzdZ32e/09/YilihMcH/1df6zqG/SLmkaHxAQEBAQEBAQEBAQFFQ+k2FAQEBAQEBAQEBAQFFRSABAQEBAQEBAQEBAWsYAgkICAgICAgICAgIWMMwIBLw0ksv4eKLL0Z9fX1RDt7c3IwzzjgD66+/Pqqrq7H55pvjtttu8247Z84c7Lzzzshmsxg+fDimT5+OBQsWrNA+Sw3Flm97ezt+/OMfY9KkSchms1hvvfVw6KGH4s0334xtt9tuu3kn9aVSqWh1FR/efffd6FkOf/3rX4tyzisLxZTtc889l1deqVQKV1xxRd7fnnjiiUilUthvv/1yvkuq7hZbbwGgqakJZ599NsaPH4/q6mqst956mD59evQ8AgB4+umncdxxx2GTTTZBNpvFhAkTcMIJJ2DhwoUF911fX4/Ro0cjlUrhgQceKNo5rwysDNkSyxu/weYOHGeeeSa22WYbjBgxAtlsFptvvjkuvvji6InaxLHHHlvQhnz44Yfe/a+purt48WJcffXV2HXXXTFq1CgMGzYMO+ywA+69997l/vaKK65AKpXCFltskfNdV1cXLrnkEkyYMAHV1dWYMGECLr/88tgSkaWIYuvtvffei2984xvYeOONkUqlsNtuu3m3e+2113Daaadh8uTJqKmpwdixY3HYYYfh7bffLrj/rq4uTJo0CalUCtdcc01RzrkUsDLs8yOPPIJtttkG6XQaY8eOxUUXXbRy9dENAFdffbUD4N57772B/MyL7u5ut+OOO7qqqip35plnultvvdUdeOCBDoC74oorYts++uijrqyszG233XbuxhtvdJdddpkbOXKkW2+99dynn346qH2WIoopX+ecO/jgg11FRYU75ZRT3O233+4uueQSN3r0aFdXV+cWLFgQbffkk0+62bNnx14//elPHQC377775t3//vvv72pqahwA99prrxXlnFcWiinbjz/+OEdes2fPdtOmTXMA3Kuvvur93WuvveYqKipcOp12X/3qV2PfJVl3i6239fX1bquttnJrrbWWO+ecc9wvfvELd+WVV7qvfvWrbsmSJdF22267rRs/frw7++yz3e2
33+7OOeccV1dX58aMGeMWLlyYd/8zZ86M9Pb+++8vyjmvLBRbtopC4zfY3MFhp512cqeffrr7yU9+4mbNmuVOOeUUV11d7XbaaSfX09MTbffSSy/l2I9f//rXLpvNukmTJuXd/5qqu48++qirrKx0Bx54oLvhhhvczTff7HbffXcHwF144YV5f/fBBx+4bDbrampq3OTJk3O+P+yww1wqlXLHH3+8u+2229wxxxzjALgTTzxxhc95ZaLYejt16lRXW1vrdt99dzd8+HA3depU73aHHHKIW3vttd3MmTPd7bff7i677DI3ZswYV1NT49544428+7/22msjvb366quLcs6lgGLfh8cee8ylUim3++67u1mzZrmZM2e6srIyd/LJJxdl/z6sNhJw3333OQDuF7/4RezzQw45xKXTaffJJ59En02aNMlNnDjRdXR0RJ/9/e9/d2VlZe673/3uoPZZiiimfP/73/86AO6ss86Kff7MM884AO66664r+PvZs2c7AO6uu+7yfv/EE0+4qqoqd/75569xJCAfJk6c6DbeeGPvd729vW7KlCnuuOOOc+PGjcshAUnW3WLL9pRTTnHDhg1z8+fPL7jdn//851hgxc8AuPPOO8/7mzfeeMNVVFS4Sy+9dI0LpBTLG7/B5hYP11xzjQPgXn755YLbvfDCCwXJ05qsu/Pnz48lrpxbZlP32GMPV11d7Zqbm72/mzFjhttjjz3c1KlTc0jAq6++6gC4Cy64IPb59773PZdKpdw//vGPFT7vlYVi6+1//vOfyJZOnjw5Lwl48cUXYzbBOefefvttV11d7Y488kjvbz755BM3dOjQSG8DCciPSZMmua222sp1dXVFn5133nkulUq5efPmFeUYFv0mARdddJEDkPMa7MXPnDnTAXAtLS2xz++//34HwM2aNcs559zixYsdAPf9738/Zx+TJ09266677oD3WYootnznzZvnHXD8/Lbbbiv4+3322cfV1NR4jWtnZ6fbdNNN3fe//313xx13lDwJKLZsfXjllVccAHfxxRd7v7/zzjtdXV2dW7hwoZcEJFV3iy3bpUuXunQ67c4++2znnHMdHR2uvb19QPsYMWKEO/jgg73f7bHHHu7QQw91zz77bMkHUitLb5c3foPNfa+ox3nggQccAPf4448X3O6UU05xqVQq7/GD7ubiJz/5iQPg/vnPf+Z89+c//9mVl5e7f/7zn14ScO211zoA7s0334x9/tprrzkA7txzzy3quRYLK1u2hUhAPmyzzTZum2228X73zW9+022//fZu/vz5nysSUOz78OabbzoA7pZbbol9/uGHHzoA7rLLLivCWeei3w8LO/jgg/H222/jt7/9La6//nqMHDkSADBq1Cg0NDT060ll6XQatbW1AICOjo7osfOKbDYLAPjf//1fnHjiidFjmDOZTM7+stks3nzzTXz88cdYe+21+73PUkSx5bvRRhth/fXXx7XXXotNN90UX/ziF/HRRx9FfdaHH3543v0sWrQIc+bMwYwZM1BTU5Pz/Q033IClS5fi/PPPx0MPPTTIK151KLZsfbjrrrsAAEceeWTOd01NTfjBD36Ac889F2uvvbb390nV3WLL9i9/+Qva29sxceJETJ8+Hb/73e/Q29uLKVOm4JZbbsHWW29dcF/Nzc1obm6OzkNx//3346WXXsK8efO8ve2lhpWlt8sbv8Hmrph8u7u7UV9fj87OTsydOxfnn38+6urqsP322+fdT1dXF+677z7suOOO2HDDDXO+D7rrx8cffwwAOeO9p6cHM2fOxAknnIAvfOEL3t/m03PV3VLEqpJtf+GcwyeffILJkyfnfPfqq6/izjvvxF/+8pfP3UPFin0fXn/9dQDAdtttF9tm3XXXxfrrrx99X3QMhDHkK31MnTrVy4js65hjjol+Qxb+wgsvxPb1wx/+0AFw++23n3POuZ6eHjds2DC35557xrb77LPPoh6zv/71rwPaZ6m
imPJ1bll2eqONNopts+222xbsl3bOuZtuuskBcI899ljOdwsXLnR1dXXuZz/7mXPOJaIS4FzxZavo7u52Y8aMcdtvv733+7POOsuNHz8+ymj7KgFJ1t1iyva6665zANxaa63ltt9+e3fXXXe5W2+91Y0ZM8YNHz7cffTRRwXP5bLLLnMA3NNPPx37vLW11Y0dO9adc845zjmXiGyqc8XX2/6M32BzV8wuvPzyy7FtNt10U/fss88WPI9HH33UAXC33nprzndBd/1YvHixGz16tNtll11yvrv55pvd0KFDo/krvkrAgw8+6AC42bNnxz7nfLgttthi4Be9irAyZTvQSgBbh21LYG9vr9t+++3dEUcc4Zxz7r333vtcVQKcK+594L7+85//5BznS1/6ktthhx1WyjX0uxJQCNdeey2WLl263O3WXXfd6O+vf/3ruPTSS3HcccfhlltuwcYbb4wnn3wSt956KwCgra0NwLJH3J900km46qqrcM455+C4445DY2Mjzj77bHR2dsa27e8+k4bByBcAhg8fjq233hqHHnoodthhB7zzzjv48Y9/jEMPPRRz5sxBOp327ufuu+/GqFGjsPfee+d894Mf/CBaheXzgMHKVvH000/jk08+wbnnnpvz3dtvv40bb7wRv/3tb1FdXZ13H59H3R2MbLmKSiqVwtNPPx1lSb74xS9G1YDLL7/cu5/nn38el1xyCQ477DDssccese+uvPJKdHV1ee9REjFYve3P+A02d8XswqRJkzBnzhy0tLTgpZdewlNPPZWzOpDF3XffjcrKShx22GE53wXdzUVvby+OPPJI1NfX46abbop9t3jxYlx44YW44IILMGrUqLz72HfffTFu3DicddZZyGaz2HbbbfHKK6/gvPPOQ0VFRSJ1txiyHQjeeustfPvb38aUKVNwzDHHxL771a9+hTfeeKPkV7FaGRjMfaC++eKEdDqNxsbG4p2gYiCModiTIP785z+7sWPHRqxoyJAh7s4773QA3IEHHhht19HR4Y4//nhXVlYWbTtt2jR38sknOwDu9ddfH/A+SxHFlG99fb0bM2aMu+aaa2KfP/fcc3kzTs459+677zoA7rTTTsv57uWXX3apVMo988wz0WdJrwQUA0cffbQrLy93H3/8cc53X/nKV3KyKr5KgHPJ1d1iypb7+uY3v5nz3fjx493uu+/u/d28efPciBEj3NZbb+0aGxtj37333nsuk8m4X/7yl9FnSc+mDgYDGb/B5hYPd911lysrK3N///vfvd83NTW5bDbrrZoE3fXj1FNPdQDcr3/965zvTj755JxJ7b5KgHPOzZ07102aNCnS3erqanfjjTe60aNHu6222qro510srEzZ9rcSsHDhQjdhwgS3wQYbuA8//DD2XUNDgxszZkxs5aY1qRKwIvtKZCVgyZIlUYaoEDKZDIYOHRr9v+uuu2L+/Pl444030NLSgq222gofffQRAGCTTTaJtquqqsLPf/5zXHHFFXj77bcxZswYbLLJJvj617+OsrIyTJw4ccD7TBIGI98HH3wQn3zyCQ444IDYNlOnTsWQIUPw4osv4pRTTsnZx9133w3A39t+9tlnY5dddsH48eOjvtTPPvsMALBw4UL85z//wdixYwd0basbg9Vdoq2tDQ8//DD22msvjBkzJvbdM888gyeeeAIPPfRQrI+3u7sbbW1tWLBgAUaMGIEhQ4YA+Pzp7mBky8yIlSUAjB492ptd+eCDDzBt2jQMHToUjz32GOrq6mLfX3jhhVhvvfWw2267RfeBvcSLFi3CggULMHbsWJSVJefZiYOR7UDGb7C5K2YXFAcffDCOOuoo3HPPPdhqq61yvv/d736H1tZWr80Nupsr20suuQS33norrrzyShx11FGx7/79739j1qxZuOGGGyIdBJY9M6erqwsLFizAkCFDMGLECADA5MmTMXfuXPzrX//C0qVLMWnSJGQyGZx55pmYOnXqCl7pqkc
x9bYQGhoasM8++6C+vh4vvPBCTmXhmmuuQWdnJ2bMmBHp7X//+18AwNKlS7FgwQKsu+66OXOJPi8YzH1YZ511ACyzxRtssEFsu4ULFxacU7RCGAhj4FJnK6vHzznnbrnlFgfA/elPfyq4XXd3t1tnnXXclClTirbP1Y1iyvdHP/qRA5CzrFRvb6+rqalxM2bM8J7D5ptv7jbaaCPvd+PGjSt4/KFDh67I5a9UrCzdveeee/JmpJhlLfS6/vrrC553EnS3mLJ96623HAB31FFH5Rxngw02cHvvvXfss88++8xtttlmbvTo0e7tt9/2nl9/zmPp0qUrKoaVgmLKdkXHb7C5g/Np9fX1DoA75ZRTvN9/5StfcbW1tTkrLPX3PNYE3SVuvvlmB8CdccYZ3mOySlLo9Z3vfKfgef/xj390AKJ5M6WIlam3y6sEtLW1uV122cVls1n30ksvebfh8xYKvbSamFQU8z7MnTvXAflXB7r00ktXyjUMqBLAlWLs09GK1Ye2aNEiXHXVVdhyyy2x1157Fdz2mmuuwcKFC3P6AVdkn6sbxZQvM3D33HMPLr744ujzRx55BC0tLfjiF7+Y8/vXX38d8+bNwwUXXODd/6xZs2JPbAWWZbtvuukmXHPNNdhss82We46rCytLd++++25ks1kcdNBBOd/tscceePjhh3M+/9a3voVx48bhvPPOy7tyBZAc3S2mbDfddFNstdVW+P3vf4/PPvssWnHhySefxAcffICZM2dG27a0tGDffffFhx9+iGeffRYbb7yxd/+XX355lPEm5s6diwsuuABnn302pkyZ4l0FqxRQTNmu6PgNNjcXKt/6+nrU1NTkPGX95z//OYDcVT+AZbJ66qmncMQRR0Sr0iiC7vbh3nvvxemnn44jjzwS1113nfc3W2yxhdfmnn/++WhqasKNN96IjTbaKO8x29racMEFF2CdddbBEUccsdxzXF1Y2bFYPvT09GDGjBl4+eWX8fvf/x5Tpkzxbnf66afja1/7WuyzTz/9FCeddBKOPfZYHHjggRg/fvygzqGUUMz7MHnyZGy22WaYNWsWTjrpJJSXlwMAbrvtNqRSKUyfPr14Jy4YEAnYdtttAQDnnXceDj/8cFRWVmL//fePPh8opk6diilTpmDixIn4+OOPMWvWLDQ3N+MPf/hDrLz5m9/8Bg8++CB23XVX1NbW4qmnnsJ9992HE044AYcccsig9lmKKKZ8999/f0yePBmXXnop3n///Whi8M0334x11lkHxx9/fM5vCi1zCQDTpk3L+YzKP3XqVK+TKxUUW3eBZSW/xx9/HIcccoh3ubWxY8d626POOOMMjBkzJsdIJlV3iy3b66+/HnvvvTd23nlnnHTSSWhoaMB1112HTTbZJNbCduSRR+LVV1/Fcccdh3nz5mHevHnRd7W1tZF8d95555xjDBs2DADwpS99Kec+lBKKKduBjN9gcwcu3+eeew6nn346pk+fjo033hidnZ144YUX8NBDD2G77bbDN77xjZzf3Hvvveju7s5rc4PuLsOrr76Ko48+GmuttRb23HPPyFcRO+64IyZMmICRI0d6ZXLDDTcAQM53hx12GNZdd11MmjQJjY2N+OUvf4n58+fjj3/8Y05bYSmh2Db3+eefx/PPPw9gGTFtaWmJFmDYddddseuuuwIAvve97+GRRx7B/vvvjyVLluA3v/lNbD/U8W222QbbbLNN7Du2BU2ePLmk9XYgKPZ9uPrqq3HAAQdg2rRpOPzwwzF37lzcfPPNOOGEE7D55psX89T7MNDSwWWXXebWW2+9aMLYikyIOPPMM92ECRNcdXW1GzVqlPv617/u3n333ZztXnnlFbfrrru64cOHu3Q67bbaaiv305/+1PX29g56n6WKYsp3yZIl7swzz3SbbLKJq66udiNHjnSHH36490msPT09br311sv7wI98SMrEYOeKK1vn+paSe+SRRwb0u3wTg5Osu8WW7Zw5c9w
OO+zg0um0GzFihDvqqKNylrYt1N4ybty4gvtPyuRK54ovW0W+8Rts7sDxzjvvuKOPPtpNmDDBZTIZl06n3eTJk91FF12U94m2O+ywgxs9erTr7u7u93HWRN1dXmvlHXfcUfD3+SYGX3XVVW6zzTZz6XTaDR8+3B1wwAGJaVMppl3I9+ArAO6iiy6Ktltem0shfB4nBjtXfPv88MMPu6233tpVV1e79ddf351//vmus7OzOCfrQco551YOvQgICAgICAgICAgIKEWUdq02ICAgICAgICAgIKDoCCQgICAgICAgICAgYA1DIAEBAQEBAQEBAQEBaxgCCQgICAgICAgICAhYwxBIQEBAQEBAQEBAQMAahkACAgICAgICAgICAtYwBBIQEBAQEBAQEBAQsIYhkICAgICAgICAgICANQwVA/3BrFmzAAC9vb3o7u5GV1cX2tvb0d7ejs7OTnR3d6O3txfl5eWoqKhAZWUlKioqUF5ejt7eXvT29qKrqwsdHR3R71pbW9HU1ISGhgYsXboUjY2NaG9vR29vLyorK5HJZFBbW4u6urrofciQIRgxYgRGjBiB4cOHo7a2FplMBpWVlbFjlpWVIZVKAQBSqZT3MeylhDvuuAPOOfT29qKnpwednZ2RrDo6OtDd3Q3nHFKpVEy+ZWVlcM6hu7sbHR0daGtrQ1tbG9rb29HW1hbJt6GhAY2Njejo6IBzDpWVlchms6itrcWQIUMiGQ8bNgzDhg3DyJEjY/Ktqqryypcy3n777VezBPPjV7/6FYA+3e3s7ERnZydaW1vR1dWFrq6umGwp17Kysug31FfKt7W1Fc3NzWhsbERDQwNaWlrQ2dkJAKisrERNTU0kV76GDh3q1V3KlmPHyna77bZbXaJbLn7961/H9FbHOPW2p6cHZWVlKC8vj3QHAHp6eqJtW1pa0NLSgtbWVrS2tqKlpQXNzc1obm5GS0sLuru7kUqlUF1djZqaGgwdOhRDhw6N9Hfo0KEYPnw4RowYgWHDhuXV2/Ly8phsv/jFL65O8RXE3Xffjd7eXgCIbG5HRwc6OjpiNhdYZuMoV9qD9vZ2NDU1obm5GU1NTWhpaYl0t7W1NbLD1P10Oo0hQ4Zg2LBhGDp0aKS/Q4cOxbBhwzB8+HAMGzYMNTU1SKfTkWypu1a2W2yxxeoRXD/w2GOPAUDkm9QudHV1RbLVZ2pyW9oO2tbGxsZIntxHT08Pent7I5uSyWQiWVJv+RoyZAiGDBmC2tpaZLNZVFdXRzqr9khlu+GGG64OsfUbr732Gpxz0Yt2oLu7O/qbNoPvGle0tLREcUFDQwNaW1vR2dkZybSsrCzSvaqqKtTU1ETxwZAhQ1BTU4NsNotMJoNsNhv9bXVW/RnHTyqVwrBhw1avAAvgvffeA4BIN2l/+aJMKU+NJdra2tDY2IilS5di8eLFkWx7enpQXl4eyau2tjaSYU1NTSwOy2QykY7auIuy5N++WKGiYsDh5yrDp59+CgCx2FFfzrlYfMaYgP6qvr4+km19fT3a29sBAJlMBsOHD8fIkSOx1lprYdiwYairq0NNTU30on6Wl5cDiN9XjiPqvr70fAci20HdBZ4IYQVkA2/9HECkHGrcNKCtqqpCT08PnHM53/mcuE/5+bk9l1KHytU+zFkHkMrZfq4EjAGZDtKqqqpIobidDmA6G8qY91vly/PjPQUQ+7tUYXUXQCRHBvsqV/2O79Q/q7tVVVURuQIQfV5eXp7jbFR3NXi28tR7W8rI9+Bx1U+9bhus8jsrW6u3wDK52ASDEjaVmdVb/TwpsmWAn88e6HWojaUuU150xFa23d3dkc0tLy+PHJDqq02o+GxCEmUL9MmXyKejlD8TLj5HTPtAOXN7ADn+S+WqsuXx1KdRxkmTrc8u5NNblanPv/liDJWf6ng+ufKcVH85VlR/eZ5JhMrIJ1OfXJVI6LuVgY017DYqWytn+/tSR6Frt59Z+ardUPuo5MwX3NuXT3d
98G3bXwyYBKhDsoaRg4/BoXXMfPddLNlqNpuNCay8vBzpdDp6MdjivnkuZLxKCpIURBFWrgBiyqXytYrHoN45h6qqqpwAs6OjIyZfbl9dXR3JV5m9yrenpyciFPw/n5EtVVjdBXJlq85FyaZWt5REUe+ov8wOsMqSTqejjAlly/3yXJj9IpIoWzuegbiBpN7q/0p4SJQ0INWsYVdXV7Qds3/UV75Tvj4DTL3VwDgpsgXy21vKklCiRTkoqaquro7GsjqgsrIydHd3o6KiImZvq6urY1UUteHcD3WXdpznAJR+IOWzCUCfXbBBkf6tNkIDfNpgIB7c0nepTDUbrUEoZWs/0/uaFPhkaHWUBJSoqKiIdMkmZcrLy6PtbTJGiaslXPkSA/pZkmKGfARLiaL1b77MvN1fvvvkS475zoeypN9MYnJAzzHfOftkbH2LxgfAMr1mnKCVsHxkAPDHhIU+HygGTQK0REohVFRUxIy/zXwAiA1cGsXKysqY0FKpVBQI0NmzNJXJZGLlPG2D8WVyk6BwCqsIQDy7r/JlMKQBlVY/1ECqQysrK0N1dXVEAjKZTFTqYxmQARUHMsvkDKiSVmEB8svWp7tqNHWQ+7632fuOjg4AQFVVVVT+11JfOp2Ojsf2OOou9ThpuuszYKqDNjOt8vQFYlb+DKK6urqQSi1rB1LZZrPZWGsK98vEAI2uBh5JkW9/bC7tq14ft6WdTafTsWyyBgdssSgrK/PaA8qWx1N5MiCzDjAJ8rWOFuizt5STVgpUt3t7e2MVlerq6sjZ894wyKT+klyxLUV1lrJT4qoVwiT6NF+QwmuwhJW+XxNUtmqlekr/R3KrLyVZWiHgOdHu2kxt0myDDzx39UuUJ2OqfDKhbPl7G+Da6pVt/bNVNR0vVraZTGblCqIIyBf8811lovEu4y3aScYEqVQqas0iCdCkjP6vY8GOI838r3ISYJkz0OewAcSckWWeasR4UcxQa5a0qqoqmhNQVlYWGU06I81Y0wnRYFrF9/1dyqCTtiU0ypVZJt+gVDDjx+BdM/zV1dVR20p5eTmqq6ujPlQ6JjonGmeffIHkyBXw6y7Jqi+Q8rVAMEPlq54w0NJKAAMqylYdvwb+K1LOKwXkk61WPGzwqXK1PaTMmOqrvb09CrI434L91Cpb2hFWZZhRTCpUtpYE8FoJdciaEFD5s5pFHe7o6IiyUwyo2FtdV1cX2V3OueL+rE1Iov7mI6BaVfUFNlqt7urqQjqdjuYP0EbQlnOfJAHUWSWv1p9xX4VkmwR5WwIF9NkFrbZQzhy3lK3t29cqC+2EVgVpB2gLlBBwnGiyRTPVSZBnPtjqq453X8beVvv5uWakuT9bKbd/q73hPrQakLTEAOCvBPg+oyzYet3d3R2TKwkA41kA0VxNzjvi3CO+WxIA+Nu08lUHBopBkQB7IGY9rDPKpzz6P40os/ssR1NoHOg6qHV+gO6DzN4qczFKJqsK2hbC89bMPmEZug42W47q6OiIGUQ6fiUBSrA0g0L5dnZ25pA3X+mqlEG2rVDZaiDukyWh2aOuri60tbVFgSonGTOIUALra63QzKzqrgZ++l6q8NkFBpxWbwHEghtes1YHSV5Vdzs6OiK50CZwwhoDWtoFBhFWb61NsPpQiuA1A316wKDSJ3Pd1lfBYqDPyau64IDaA1ZZdAKgVgI0gFK55pvDUIrQgIeg3vqSSGr/KIfq6urIgQOI+SKCekkSwMmV2uKqQSrb3/SYSbEFinznav2ZtX0cv7SXjA10wiP1mD6N9kBJAOWr7cO8575xVeicSw2Fxr6SK+vLaBftwhcMPvk/92njNk2OqX0BELMJ9px8/5cq8p2nvTbKw85f0w4VLiZCvUun02hra4vIQTqdjuwHX0qSFRorAIgSBSuSRBwwCdAgVdkdnYNlTTZYtcLiPqurq6ML5wRLnahGJ2/bXygIOjEtp+jKDkkZ2CyxEzYQtaxaX7Y0ReXIZDJob2+
PlU47OjpilRbN/uuENp2Qpn2UVNYkOX2ru76yNL+z7z7d5iDX0jMDAqDPSWkLm7bGAXHdVdlSh5Pi+FW2QF910Cc3C45PEga2AnZ1daGqqiqSL9vRAESOX4NUX8+6Bm229KptNqUMDVZ0rOeTrQaN1onQVnOxANpVykpJAFvYSK5Ub61sNYhIks3tj2wJDVRpG7Wvv7KyMvqOBM1mC6mzrLzSp9nEjbW5tBN6HklAvqSLzx5o6xOAaI6KnfBLkkCZa0XW1xakizHwOJawJjFhqFlhIp/eKgHQQNUSACUBSgRUj32tQDYwteeYFJn6UIjIMH5gEkt1lWSeqzVyTGez2WhFNiYMdDUyJQG+ao6PCOj/A8Wg24FUMSxT9J2wDVI1q8Sgk0oIIJokaFsDdP6A7aVi8ETB8rMkDW6VrxpMVS4rX0LvhRo+zTBx/8x08V5ks9ko4KIR5rZ2kiZLV2wH8GXTShE+3VVZ+UqbamhV321Gm9umUikvCaD+ajbKBqYqW500lETZAshrF/Jl5Z1zUSDlCxLYZsEgSzN/NpOaT7a6NKEGgKUMK1vbOmUTIrYSSidlJ61yXwysAEQkQGVLggXEJ61qosXKNik2Vyffqr21Law2U025aoKLv9FsKPep7YI2UFWHr0E/x7/qa5IIFoCcQJJ/qx3Va0+lUrFEWL7Ega1uMcGY76V2V8eGyjmJRACI94fnS1YBuYkZIG4vOH5TqZR3DKtM8pENvvtedh+ljEKBv+8afDKnTFlt5dw/Lu3Md9sWpGQMyF11UclBvvMZCAZMAnRSqBpAO9nEGjRCP9OMilVUNbA6gVhbVLg/dUwaRJEI2HMoZajDB+LLqdr+X83C6WDj9dqgi/ulwWT21ZZctZSo8lP5KtFKimOigVNipNUp64ytcwDiTk1lzc+4P6BviVCfbIG+bKoaYJVtkghsIdlqi45er9VZXxBkHQ/tgs/Rs6ULgFdvVb6+e1uqUJtg51yRtGv1w84r8hEhDeRtosYnXx7Xzv3gvfQ5sSRAxy7gT6QAcVur/+eTrTppvWc2ccZ3PRcrW91/UnSW0Cwl/bZtKwEQs69qF+z1sxqo+/P5H5uMsJlTa+MtydJzKlXY87Mky2aHVYfz6S4XXrBj2lan7bHzEQBfbJJkWHnqWPUlSxngU2+1BbMQEdCqosZuvvNZkSoAMMhKAE+OYKCuJXl1DjaYpOC0HNXZ2Rn1SVGQ9sJtNoFC8CmdZqiSEEQRnZ2d3mwzqyFWvpyEqopYVlYWy6ZQvjoZhfK1BtSW9oBcZddjJSlbravDaLaPJWW7Yg+QmzkilCTpg1h4P3gc/Z2VbaEAON9xSxVWtkCu3jLDxLGpTsj3tyX1bIfgcSxJUKiMfU4/SXZBV5wBEGubYiYZWHatWp2jfeDY54PbVJ4qA9pxTnLLp3vL01slt6UO2kGrt0yO2ACdtpMOmw9e0gev0XawAqCtQVyhZXmypW3JR4yToruUm/Yt0/bqimx2bNKe6oMvOaGS1UBWcZRUsYPAEja++4JVX3VAf1eq0ESfL05S2WqQSrnSHljb0NvbG60a2N7eHnUIcPUrS0Y1EC30SoI9yAcf6eEYpY+ycuWiFNqhUl5eHntoG1+co6X+r6KiInaPgXgVwBKAfFWz5WFQlQDNSAF9bQ/ZbDa2wgEDInX6aggIBl36FEwOcJv5U+JgGfvyBm2pD2oAkYHjC8iVL4N8zptQp64GzCoqgwEqmmYS6fjtai75DGISZGmhuptKpSInz/YH1V2+K5lV52JL9SQC1EttwbDZUeuMLJIqW5vds618zDAxUCI51WyJZkRsSw9tB3WUS4ayd1idHpEvAFCUurxVbxlMkWCxpxxAFNTTLjCxwhefFEwZ28wpK1YAokms1F1fVZH/+1DqMiVIXtWhavVZ561x/OrTbFtaWtDY2IjGxkY0Nzejra0t1g6oKzFRjlVVVTGHb1sK+beiPzIvRfCalWgqCWD1jjKjn9Knr/L
Fp13T1qTT6ZxEiS4kon5Oq9tEIdubBBn7AnAgTgL4OeMFlSv1l3aBuss27JaWltgiLNreyvlaahsKwWeTkwaVse2MYDJAn3ivtlb9mJKA1tbW2OR2fQK8VgLyyVi/G6xcB90OBPQ9mEYHnjr7VKqvP5oXr4FUvnK9lvvs8mA6mG02xdcXpxOJkqB8lBtZIACvfClXJVo2gPJlWJUoUDaaAdTytE9eWtLmO7ctdfla3eW7ypbb9fT0RCsocZAzSOWLwb9mXgk+UExlaycEF5KvyjkJsqXeapWQsuV8E26nfadqQPne1tYWqyKqsysrK4vuE9utdH12IHc+B5D7YBftgx9sBmVVobOzMwqWVLYM1Clb6iWAiGy1trZGTp9BFBMBKlPKMp1OR/u3S1fy/uYLOFS+QDJsAnVRJ45SHrx2vVb6McqWcq2vr0djY2OkuwBiMuUKICQBdPg6Kdjqq7a9+mxCEmCz9UC8bZJ6TZ9v9bapqQlNTU0R0WppaYnGg3YNaMzAhS80oPXBN+5L3RYoNEucbzxyO5IAtQfUXVayWlpaohiBSQadWK0T220cZ/vWLawuJ0V/CV/lSCtWVl9JrFRHGevyN5Q7k2SZTCYnRrb+L9+5rYjeDmp1IBpMDSbpjDOZDFKpVM763Jot5Ut70bQ9gGCGq7OzM2Kd+r3PyauTXF7GtRTBAajy1XYrypfVAF6vZZhajrJtO0qwgL6HCdmMqs8p2XYa7isJ8qUcrO7S4afT6ej6SBiUYLHcz8FNImDlylI1EM9W24wqkOvcGaCyitDfTMvqhuqtXht1iw+G6ejoiD2EjhVAGlG+KDMdA7xXmlFlVsoGU1px1Mqaj7yWOnxVT3XS6XQ6umbO9aGtbW9vj5x9Q0MDmpqaIkIL9D28kdlq6p2Wp7U6aUvPvqSAZstKXXd1fGnvvs7nYZBqdVazqgxUuUQwgJwWCuf6niKusuU9yEdcVfYMQpISSKkOaBCoq9jlky1brGxFgDKjTAnGIdRv2xakgShhiVWSWlbUh/l0R7ezcmUlQCuFKrPy8nK0trZG1XJdVdBmt3UulkU+PS11u6DwEQDbBqT6ype2V2l8q4kExhP6vBYfCejPOQ4WAyYBNIjKrnVQs/SmQQGh7Ekzq1oVIDRTm68/2g5gndTGcwXia+aWOrT9xmc0ScC0NUq3p1LqElQaTHF/2qO9vD5TlTHvB2VM450E48mMhT1XdUj83waUdvCSCFgnw2CVQaoOZo4bHRP5gn/N0iZBd20bGoDYeNTxrESIOqtESzMolAPvD++hzQCyh5LQrC5tA+VLYpck2dpz1YqGknGVLfXV9q77WlYoZy4fqr2s+YJTnQCudpefJU22hC/4tnLVpf1ob7WKxX3baku+BRVsUkCP6yNgpS5XYnny1Yw1ZcvgSucF2OQhf1NRUYGOjo4oSLWthIUqrrZixf+TQF4VhZKdtLG+uWsqV+3GABBr0bStmtquosfXe6lQWwyU/pLM+WDJgE9XGXcpUbKyURtSaHUge19tAqZYGDAJAHKXLFKG5LvBvhIVBWErAGoYLLHQTJ7uJ1+5j/vTG1fqyDeQtPSmZXn+xicXla0GZr7ss92HGtlCsPMzShmavVRokK+BP7e1pVVbEiS0QqJEQj+398lC5UjZJkFvAb9xsnbBR+StM1YDC8QztFamKlubiQZyn1+g+0mS3lqdsVkpIK6bhNpq/T5f0KDHy2cblofPg2xVvioz3zX5Ai9foK7BkE0E8OXbv1ZofcmwJCCffHlNGrDnS/rZ/QC5vebL02sS5kKtNEmpbOeDL3hUmfrIp62Qqr8r9Ftry31/0x5QvwvFa6WGfONd5av+Sskn0FfB1mQY5aDkzLYM6XHynUOxyMCASYD2jWsrCjPQHEhkSEC85MxtOREV6OvBBOJZb5ZO9YFB+tRAZhnUYemL7RRJIgFKfLSVij19AKJ2FWZKKT9t69ESNAehykW314eq6EPcGBzznmi/oGa
nkuKU6Eh1opgyciC/bKuqqtDV1RVlm1jGtplm7VPXB9ao7tII+FpU8uluqcPaBTVw1FugL8Pk01vqIr/XDLcuE6zPXdAHD+qcIdV3kjWVrxrbUodNggB9K35x3gqAKEsKxFtaVBdZYdHKDL/Thy5ZnVV7b+1JkvVWyaMGPqxM8TrsXDW7/C9fuoQ2Za/f65NvKVe19UCc/FOutDUDaREoBVC+NpGi89MsmaUe6eRstQ2UibUB1ndavVRCbCsrPuKRBBn7kn8a69hrUqKjctXJ64zV7GIslrxyf7pfH9lTUpwEmQKFJ+LrdfqIj1Zp9SGiJPPa0cH92wROPsJl5VsMIjBgElBVVRUL1J1bVlpva2tDKpWKlrikIaUh42Q+njR/b0smWuLWJ1ryMeA6wZJBHMv8ZGEM0Hj8pBAAANEDpShjAF756qRUDlpbxuTkKV9ZmgbAPm6dPYCpVComS8qXD2yiE7RlwVKG1V0gLlsGsmyjoGyrq6tzDBh1mkGXkgBO2Eyn05FcGVzZ9fIpV5Jim01ICsFSEqCybW9vRyqVihIE1FvnXFTC13Yp6qlNEujTrq3OMhgo9IwL6jETEL5sWKmCK3hRvkDfKiqcKAkgRmbV8dhSP5dRBBALBDjnqKamJiZbBq88tuql2gbfKlilDs5P0aQGxz+ASFdJXoG+3nPqYU1NTdRSQfvIfau+Wp3V4NXOxdLxoDJOWpDKhJIlWZQlgEi+9CM6RyWTyaC2tjbqWWeCkTYim83GbCyThpbAWhLA42rSIEk2AYhXSAmblPNlpTUwVR0GEC1jqQ+200SLEiofAfARkuVVaJICXweFJZm+BGsmk4nsYllZWSyOVTlaefkqkErm+P+KEoEBkwAG8zqoGEgxOLQZKwZSmmlua2uLnJEGUhpI0EGpQnL/FAAzttZQcgmrpGVOODlVSUBXV1eOfJVZU74c4Co7Oicg3rtrFdX3CHvNjqicbZCaFBLAlU9UP63ualmegar+hnpfWVmZI1slAcsjWLZEa1vjbHm81MGJv/lkq5kPXnN5eXl0T9TAVlRUoL29PZKtypVBqS9YtZUA+1KypXJNqt5yEiplq4kXrvNNUJ/phOwKNloNqKmpiWSrlQG1u6qbVneTFKhyZSWtBvT09ER6ayuGACLZWj1LpVIR8aIe074yIVBTUxNLCijBssEA58BQtjbAK3XZAn0kQFtN6J/thGDKUAksK4msytCn0QZrgpCy1UqWrQzky1QntQoAICfwVrtmSQB1MpPJoKOjA9lsFnV1dZGNZqWF2/BlSavNhmtw7JMx/9b3UobvHFWHGXfRN2nwr08Hp/9nrJpKpaI4y5Iq7tfnu2zlSoN/SwwGigGTAHUsPCgzQO3t7TmZfGXkNGzq0Dn4Vcg6aDmYqYQqMM3wqaG0zj5pJACIl5YoX8s8qXT6tyoky9PaW22DVdsuYEkWZasOSQlBkuRL2RKpVCpyRtRdJUrqQGxWiav+WBLA7zRrzZetYvkGuwZVSXJKTA4QZWVlMbuQL3ui8uXnNKBKAtSesCVIHRRJsBINn4yTVmEB+mTrswnagqmOmYkBW8YvLy/PIVjaaqXBqm0NUrsLFNZhyrbUCSxJANAnX00kEWrndJ6U2kglAUCcYDGospUAtSvLC1CTltAC+lb4U+gcNeqHVpC0QqhVLJIA3hdLXpVcKQGwdkfhy74mRb75eutVT/KRAGaos9lstIBIeXl5DsHytWP7WoMKVQP03f6dNNCPMaFi4ylfFYDtsJStPkBTk7pA7lwvm2Rg940lAvwtz3EgGHQlAIg/iU4zQFou1VI+syPq+O3KNb5qABXXGktlScws6nuSstQEM0w6iH2TTWgANbMPICewYrAA5K5Eo7LVzImSAF/WWuWblL5qoI8E2IFmM/rK5JVUaf8wsybU+XxVLF9WijLVVjaffNUhlXowxeSAylZ7fqlLqm++/kigL4DVHmxLBKxs7aQr2gRbGfDJttSheqvZd12
9hw7C9pvrWKZcfG1slKO2BuWbE0BHVChITYp82eoH+HVXr4HOmjYCiD/pFoi3Wvn0VeVqg1Sf89bsv5VrEuRrCROvx7bq8l3ly+10BRu2awLxpcnzBaqEDU59hDZJcgX8C7TYcanLqNL++ogAs9Ta6WGTgz57Yv/mcWzPfFJkmg92bOZrA9JqgCYBGRczXuM2duzrfdMWIr7o03znRAxU1gMmASxt0Pjpg6kYFOnJVlVVxZwUL5j/W2Pr67uyhtIKg7C/VRKQFGiZWfv+mfWjLPnQsHzyBfqysT752qBVZWsDMiqiT74MCJIAJVi6BCJ11zoZrhGu106UlZXFJg4C8aUTNbi1stWsrO++WKcPlL4RzWQy0Xlr9k7nV7DUD/S1+HAMU4+VLGiAq+TLTlizxlD11AYXSZWtT29VtgCiaqs6CX5mM61Wtj7yoDIkbPaP+8gn21KHzvexxNVH8pVsatsfgy3aSx8xUz9GqLxsgAogZmOTpLMEKy3q09RO5Kt4arCkfdWMGSgvJlc0+M9XDfRlT5UMDLadYnVBKwGqw3rN+pn6J5UtH2QHIDbvRTsNbKCqNt7GBipjfU+KzlpYm8f3fLGqkn/Kl7YB6JvzorEXkLtELh8uaGMHK8cV1dsBk4CamprIGbGnj4OZSqFBjk78BXKXAFWHpApLoViFBnKzoj6Gr86J/ycBlC8dPEv/qiC81oqKimjFGms8gb4A3vaX++Srf+cLrLh//tYqZKnLWGXLSauUrc5NoePVTKsGAfxfW6Rsdln/tiU9fq/QMqN+X+oyJfLJVtdLBvpW9GJ2lLKk4VTnrkECYStT3d3dUdYQQI6z0t/x+6TJNpvNAoivEka9VL2lPdbKFdDndJQs5EsOALmVMt2vzajyt0B8KcCkwCZd6M9Ul3W822SUVr3p6PNVt1V3NYji5wwIVJcJe3+SAiUBlC/1VjOeFpSvEi0maVS+miRU+er67So3Bl42o2oTikkA7SWQuwqQjZt0nGsyRUlsKpWKJsQDyCGvPI7KtpDcaGstIbB/lyLyxT/2MyUAGiPY+azO9c3xUdlrFVznwGiCTI/lSxTkO7/+YsAkoK6uLtZGASCaTKKfAX2Pnteeds2G2IlPzKwqafA5KWVENrBVgSSRfdbV1UVGkg9PIjvXyWlA39NouWSlzexr8KnLhqoRteQMQMxA+sqldlAnRcZDhgyJVVUARGybsuW1cPlFLmurxlMz1XzZDK2vBQvIP5tfZaiBalJAvVXZMjh1zsWyqgycdMIwZQr0kVe7so0mC/R+6NNvfVUbIl8mpdQxZMiQmE1IpVKRHWDGWoNIrbJoMMX2QbUDlK/aibKyssjJk0jQednqloJEIUkgeaUcmXRJpeITrTWg0eukLmulS2WrQZhOvrSy1bZDLf37UOoBlEI7B1SufFcZ2UCHcmZGlYGUlS2h1fPOzs7IntAuMTurWVkg97kNPHapg1VVJZa8Vv08HxHwESwuMQ4gFtgCfQTAPvXdl6whbCCbFOTzH/q3rxpgKwFaZdEKlrYGqXz5YEfam3yJGiA34TVY+Q6aBFAZent7YyetATxXUOGAVKXS/l0tvWpVAYhPIuL2vmyfKjm31fekgIEqs6lWvhqoUr5c9jCffDnAbQYGiD/KHejLltgBXagElRQZq+4yo9/R0RGTjZ2cxglTNvsHxJ0OgBgJAPzLefkqLcszOEmA6q2VrQZTQF9mWle2od6yOqAZJ5udBXLJKwMESwR8WeukydbaXMpWxzflrgSLTkaz+EqwGCxRttRnIL5aTm9vb94+diDujJIm29ra2pje9vb2RkRL7aVm4Kx+sX+a8iYJYFZPA1YmzKx9pnyVEFhCm8SAipUWayc14eRr+yNYNeQ9YKWFtpb+UMcBCZbGF6yEWb9G+1QoeVCq0BZKrbRq8sUSASDetqqTsLVzw26viVrGcloptPY4X7DK/ZU6BnKOPhKg8y40ltUKltoBte8aS2t8pokdm9B
apZWAmpqaWNaks7MztnKEBp36aHXtLdN3oK93TwNTLcNqAOpbSaEQkqBwCitfDfCBvkCSstWXLUtRvhrkAnH56gDmd7YHsFCWJEnyZVZK+x+ViSsBVbn6JgUDfZOw+VsdzLxPCjp4XyBhZZ0kuQLL9Fbb1azeqnPSAFQrK2rkSEZJAPLJ1gYTOh9jeWQgKVCbQLuqiwFoFt9WBi0BYH+5ZmXp/DV4AHJlqxPfbMURiC9ZmBRwLgtJE5MvGohS/5SsAojZAaCvFZOBqCYWtOVFA10NpHSlIq3sKBlbXkWx1MAMvo5huwId9Zdy4bVp1lTJrNpmTYyRwPmy1zZQVVCHOTaSIFegbw6QyotyBnJXmLKJUq0GULe0DdZ2DNC+tLW1AYjrLvfnix1scjYJMs53fr4qsq8aoKsE0a4y0a32Um0sbQ/vI+MR+kc7P0OTAysi10E9MZgXweyJBklWATUrYsvVvAAVpP7WxzCVZBQKmJKaXS0kX1sBseVPXyXAKgp/q/K1zNJmRpYXpCZBrgBifdH6RETVS9tC5asEMOtMqFMplHnRQFXvkdXlJBIBXf5UV5WxJWWgz+Bppl+NqAZJPp237QPcJwNVJQK2jz2J8s1nEzRAJCgHBrVAPKDUSaZ2UprqPhAnAXzuCseQTpr3tVPYQKtUwew7AC+JBPoCF1st0bY0BpK0C+qjfBUs/dxmUxU8rhKBJIH6xr/VHlhiTvnaz/hb/czX9kKdZnugbZMBcis53K6srO+pzEkIUoG+5J7VkUIBrBIB1V0SCt2WyNcVoHqbSqVi1cJCVSy7/ySiUOeJkgGd9wr0JRqVXGmCQHWbY0GXwtXf2/tvCVd/MWASYAMcIM4omWHSoFJ7WSkoNZDclw2a9HstqRTKpCYdarTUyWhvpD6aXgNVlYFmAe39IqyDAhALUjUg+zzI2WcElbUrMU2l+uYJqK4qrP7yM3UuWga3QYZmaAvJNwkyt4aNeqNr+wPxwFIzpipHZv3y9ZnSuWvmWjPV+vKV+pMgT4WPVNLmUncJtQn6W60cAPnXR1fZqn3Q+Vs62dNXcVGUuqx941cTKSQ73EZb29QXKZlSf6YBqtVnzaTatlf+Pum+zmdz1afZrL3aEP29+kS7X8pRSawmcjRQtVVC6jC/twQ4Ccjngyw5twFmPvuq2+rCA2wl5O9tprqzszOaQ6cVCe4viZVChY1T+beVP+FLbgHxZYXVV/rmFGrrLO8FbRJln+/4/cWASQD71HUiL5091wrXiXrOuVgQpBdP50LHxX1p+U+3tWUUG6gS1lAkCSpfXbaSy6QBfY+yZ3ZE5auTpvgAJX5HpdEAVFm9lpttRsxnUPQ9CWhra4vprjJtbQtQg0VZM7DiQKRs+Z1mmihzNaQaMFC/NWttyYO+JwFWb61saQfUcLEfXcv7XEedmVXbF6nZD50n4CvD5iufJkmuwDK9VULpky31mdemtkPXpdZxr7YC6HPOOkeARI1Onj2urAbYCmXSwHltbCWhLNgrrVliX+Dka4HgfbKT2fk7HScVFRXo6OiIPRwPiM8fyiffJMhcW3Xsqj5MHAJ9vkt1UltYNA5Qm+Fb7KKsrCz6TicSA7lLEyuZSxrB0vZeW8GzMRKAmPzUn1k5ahu3tl0x2GfLkK76pqvh2HYv/jZJiRifj/ARUEsClMT6iKrKV3WWsmUyUlcbY+JXW2F9lQDKeaAYMAlobGyMDWoGoJxcQtaiAT+VqaysLHoEeFdXV/Q0Ol6MOjdlRsqQrKLb0i1vhGZf9AaWOhoaGgAgZuyAZfKtq6uLsUJuw8nXzFxRvjpfQ50KZ/arbHWFEJWt7dUGEDPIQHLIltVdVqfo7LnakjojrhKkk6jS6XRsghQDKg5kADn6q6xfJw5Zp6RGI0mlf59dSKWWPceitrY20kt1OkwiUG76ADCdMMVASlsL1Jn5Agv7pFvNmOQz3KWKpqYmr95WVVWhpqYmCtDVaTO41aqBPqCN+9OFBmz
11tpaOiPaYe5PyVXSCGxzc3PMSdPesu1Jbaq+bIuJEnlNNGhQa7PZ3F59JuFc30R3i6TIFgBaW1sBxNdAB/pIFu0hbYEmBDRhqO/aAst4Qm0ssOyetLe3o7KyEh0dHbFsthLhJJNYbXvSpAhlqg+0Uz+lAb5959/t7e3o6OiIBavU18rKSrS3t0f2gLpaXV3tnSOqvwVKnwAA+UmA9c+2sqL/a1ZfFwqgbNWOWN+lPo/xhsYl7PBQWfq6FfqDAZOA+vr6HKEwkOLa3lSktrY2tLW1RQ4JQOSwNPNn+7I1w6qEgkppgylb3uNNsJMIk2A888k3nU5HmT9VppaWlkhBAESZpc7OTmQymZwnLXMQs7RHY0n5aiZcmakaTlt+VedVyqivr/dWiXyybW1tRWtra6SrNHRVVVVRRlSDIL4Y0GsPICf80EhoaxcDXpuZspmwUsfSpUujv/PJls6ltbU1Cma52g3JZjqdjlUD1MgxIOJn3AdtBG2CLktse+M1GCv14J+wNoHnTdtJJ0F7oAESAwDK165NbZ00ZcUgVquO7e3tUbaccqT+W8eYFL1l0oXgOVP/GKCzGkMfpIGlJk10PzYBAORWa2kz0ul0zKmz0mKTW0kjWUoCAMRsIKupthuA9lLlZKsBHPuaVNBtNUbQZZ71Sa08DyW/SYIux65jjuOdn/PaLYHyvbhCjQarNtHCe8btKVdruwHEbEsSgn/C6oKOO1uh0s99BEDtcVtbG1pbW9HW1hbTcfV/XFaU+pnNZnMSCpqwJQYr30FVAmxfH1mgBu9qNBlg8mI169zT0xN75Lc6KSov105V46tsVzMm3K+WCJPklBobGwEgFlSq82a2pLW1Fc45tLW1RUaTDFGJE6st2putLUIAIseuhoAVA51Aa7OEeh+TINuGhoZooOikHZ9se3t7o2CVsqVM6Hg0iLdlUMqG94KGQLOLGkhp8KWvpASrDKZ0noNm+wBE1ZPu7u5ItjqeOaeIk1Bty5T2X1MmvD90iPqQMeqq2gatKCZFb9Xm2koSgxlWAnUZRjoZoG89ew2C7IR33hsGDCQW1H0uSUiQvPFvTbwkxd62tLTk9O6qrGgTAMSe0UB/pOTTzssCcqsj1EENrih3ADFdJ4ED4k+/TYJcCa4kA8TbeNX+MqupzxdRX28r+xy/lL9mSDUI4zFttpo2G0DsniWNYNHmaYZdbQTQt6qPbVWzlRStJCoJoA1Rwst7pkuLMhHBfWslN4mVlkIkQAmAJj/s/6qbmvRqa2tDS0tLlLChDOn/tApQXV0dq8hY/6XVx8ESrQGTAA5qdUTah6uMXte31wvQXjUqCwMhdW40AFwTW4N7IG4YbfuAZlyS4pCAZfJVp8wBpAEVjRrbVChfvtTQ6gRhzYyQZGk2jwZCMwdAvD9VMzFJC1StbH26y0HZ1tYWTYZUmWgJ2c5HUQdOx0ZDrbKlHDXr5yMBSQpU29vbY7IFEAs6qcO9vb3Rmurazgb4DS91VvercyyAvuyqtlDxN7xfQHzlnCQFqlZvAcT6/IG+SXu0ndqSQhnoNXM/PlvMIFX3oQ6d7THclvBVWUrdLtC32KSSJpbYA60TsGkXeH3UR80s+4Izm5GlLmtFQB8kZifNJjFQtQTLVvxpd7Xa7AvuNfjSYN8GX5oA5FgoLy+PtSKzPVntsCW4pQ6OPZWv1T0bhKv8NFOtLZokAioryphjXquyDFRtywoTMPxd0ogA4ese8JGCfMRA24G0i0O7ZHQFId5Lbe32VQGsD1tl7UBkHFQEddQatNJBabuOBjta5lQnpNUABg18jLJl7TS8th1FjaUa4qQMbB3AWjJW+TKrqkZVAyOFGmBWXZg1ZKCq2X4g3meojkjPi/sGkrEkoJbPVD9sX75WTFQPfYZMAwgydw1wOQbs79QQq27aoD8ppWrbygDEV7FR2erqPZSdLW2q49DWNN3eNzHVOjkG/JShtQtJQD69tWV/G8DaqgmhuqTVXO5Xq1N
WrhpkacDr268SsFIFZcvzBfyyVZ1V8m+JJANO7sfaRw3KfEGtJnN8la8k6S0QX+LUyldlYit9Pn9iAy7uS/epuu6zBZqM0YSkz64nFXo9Kk9LRhU2gNXxrn8DfdUTO2fTN7/QEi0gebLlWPXZOX6fj+RoHMV3O86VYFVUVMTmJSoB8LW86Tn4/u4PBkwC7EC1TtiyfvY+8yLIHDnBUluB9HMOaJa60+l0VJqio/IRAp6jvicJDIzUOfsGsGaz2U+qMlfZ8mXly8yW9rjT+ZDkMXNlj6+PLE8KeD2qtypbIE4K9JHqdOzay0/CSrlyDgYzh2wl4HfaWmAzfDy2dYBJka+VrZUv4JctM4Dch+qrypdzC1ghpCy1T50VQitbdfQM5JIEnXdC2eo8EpWzla06MF47yZPKWqssfGdmj0GAklnNxOq9thnsUodO5FdfZm2C9We6IpsSVh3DVv+10sKVgfS+2uSPJs3s3KOkQPXB96LslOhz3NOnq0/S8axVb/X/tjKt94AxhQZTen5JSmrxHK3e+WIFjSeYJGAbmq+yrXZGg1QeT6HEgfaCiQRu7/O1pQzrgzURZ2Wb738LH4m139l2YPo1rbBa+RKDle2AvWEmk4kZLruyCU8c6FsBoKamJpbtU8fO90wmg0wmg2w2i3Q6HWW2U6lUbPY696ttQYRmFFgWV+OaBOVLp9MA4hk6DWqoCJQ/5UvSxACdgakGqNlsFtlsNpq/Qfmy5E9Dqo+6J9TR2UAkKQM7nU7nOHRtFdGBbmVLh0NHRSJg5UsSoFUqbSfSZQgpN3Vs2iebJNly+dpCsvXZBWaUGGRqCxHlTDvBMU3Z0kAyWNL+SrVJNgj2EZRSBvWWY8+3+gazrVa2uhQuv+fY9ZEArfBqb6rOibH2SM/LLkJQ6mCLpZWP2gKgrxdaCT3nB9nMtE00UA46P4XbsorNe8Lj2swrZWt9WqmDS4BqgOQ7f9UhJrUA5KzYpgSJeqoVKvVLdp4FZW+z3JSpLtKQBPna54MUelG2mpS1QadtYeH8KtpoOx6sf7ItMepLtfKbJIIF5G8H0uu3xEnlQ1iioAk/3da2E9lqIceGr7K2ykkAlYuMnn12OrgY1HJSGR2JLzudzWaRyWSiVS+09UidvfZK0UBQkHb/SXNK2WwWQPzGUr50GOqcrXw1G6gvJVlsWdGeSStfu0wbYeXLoCAJhjObzXp1F0Be2TrnYqtLqDHVTDXla0mAypZ9qQxWbZsGZcsVFnSSfBJkC+RmpWz2AkBMtnQ0AGKG1MpXKwE6r0Lvl65cY/WRAR4JhTqlpMjWZuqAvtYrK1tmiuzKHjYraEmADa6AvqQL0BesKZh40NVBfC1wpQiSVyAuHyDe/gjEZct2St/iCDb7ysBe51/xeLQFSrAsAWGAyjGgdqvUoQsD2MBUs6Ka1KKucQUau2yobx4AgyTNbDNwpWyVBPiO63u2SClDSQAQT3JaXeR4VwIAICY72w7EbbS33yZV1PcrCVCbo/Y8Kf5MdYXvPhJgk3Z8V9moHaCsKA/GB5YUaELdEgAbO9jq2kAxKBJA6OC2vU9q2FKpZS1AFK5OxrRtQHxn5YBC0nWUbTkcQIwI0CGxJUazAKUOBvWEKoWW5ilfTmjVB9v4AlXKQ2VCRWLpXwNdTjrWNZg1W6b7svM1ShX5dFcHlZIhnUCtDkN12OqvTszk/nQCoXVstmVFs40MVpPgkFS2QNwh+fSWwQHXY+dv1Ihq1UWXntMKlt4vribiy84qwWLFxrcMaSmCsvWdp+01ZUADIPb0TkKdhdphbVWxY4IBrw2cCM2SM8mgDr+UYe2twkdeudxteXl5LDtdSM5KWNWeU44aVFk/pTZBCWwSZAsg0kULmzUG+uRL26BtK7S7thqg98i2pqlt0H3w+ITqL+WbhIy1koB8wSoQb8NkbATECYD2pvtIrcZjqru+uUPW9mo8khR/ZisB+UiABvkal9qXBvmq17QJ6vtUNrYa4Ls
3SvQGQ7AGTAKs4tnBrO88SS61ZpWRmShVENsKYTOCvlUCtH9d2ZkucZkExQP6yqd2UNPBWMfE4IbyVRLADJJdHlRLVhpo2YfgWFZv5ct7khSn5JOtlacNfvTdN4jtBHjVXdVB6q41mDYLYLPgSQlUC9kFjlO1C5otIbRC45OvnbxK+bCqkEota38hUbCrLdiSeFICVeotkPswOStfoC+Y0qqsr5Svk13VgfT29kYVAs2m6r6szlqbkBTZ2rlNPtn6AlX+TUduS/W2RE+Z0dZoZdxmTfkbn12w1ddSh/Yt22DK5194naqvBHXJ1yphq9F6P5lssLIFkKO/tDlJsLm+bLWtLhMae7HVhySW1009ZheG2g2Ns5i08lUD9Hx4XI0ZkhKL6fnZKoetaPnsYCFCoLKg7Qbi8+qsfHxEhOdmbc1KJwE+aEDuYyzKWvTk7Unztyw/6T5VmVSIVF4ViGVSVnClDF9JzyfXfIZQ95FPvpZIWfla2TI7aBU0n2KWKqxslWD5gipbduM+fLoL5C6RaJ1bvoBBnZh17EmULRAPVvOVNX378Bkz3Q8DfCVq2kJkMyTcJp9h1XFTqvA5V0uwfHrl24fqr90XCQCPo/YAQIxYAYg5OZ/zoV0pZfjsrS/JpCSLv1OZFtKhfJlZ23aluqw6W0i+pY5CwYyNEXTs2ooVx762pgG5S9BqIoH7tMkbq7P2pfspZfh0V/+2MrVytS0sGl/Z7DX3oYksTYb5Mt/Le5Uy8slWv7djU6vYvsqAja2o1xqvaRVKk7X55GXPYTAEa1BLhAK5g9k3sUTfeZGaSVFhqqPjagDOuWhdVW2p8LGp6ILEYXGNbO6v1BUP6Ovt9WX8bDBl+yNtJsWnPCRZZLc66drn/LXdAOjLPnDCK4Bof6Uu33yytbrK8qh1/pRJPmKZT7bcryVa/IwZFp9s8wWwpQYNHvMFqHzZ5dGAeGDU29ubk0GkbPk/ZauZahpRADnOi3/39vbGFhhIgmw18NbKlZbx+bfOk1Kyb52E2m+dsKr7sfZWgyvuS6sIvb29secHJEG21p8V0lvtTe/Put0AYveAOmx9ogYRmo227Sv0YySupS5bwP+wNN/LZyNUvlaulqTp3+rH+K7VRP1bx4PaGNrmUoYN9m1MtrwkncZelhzlqzRRH9Ve2+y3r+Ki58Xff17gqwLYYJ/zLO13GltxP6z2FaooWEKVL4nYHwyYBPgeVe0b0NYh0XGw906DLG5fUVER9aJzYNL4kghoQGVLs2rQ9ZHWSQhQCRIXID/Rorz0Cb+UCUvy6XQa3d3dkULxfii7pGPRx4groWCgyuwrP2cQpsqbBPn6ZOvTW8rEPjKdA1RbHrg9H/DDQai6q4/8ZsDKfaneqmx967CXMvi0TyD3eRI2kOKTaFUmtv2HxIiLDaiO2TFgA1Y6d20fYlKBpCxJsrU21xc8qSy4bKqu6KHORK/dZrzsOPAFrNyfOj/qLd+TYhM0+dEfW6sPUdIAn+DfWqHNRy40WKW+agCgcwS6u/ueAK3BbalDdbAQ0VK7axf98BEu/t4SMvokXxWWL22RpXx7e3ujBzclLakF5Mq2UALRV3nxVZzyEVuVqV3NzddiqASL/q/UZQv4H16ZT3a+FlZt6+3s7Iz+ZqzLZK2OZ+qnLryihMDXXrS8iuHyMGAS0N7e7hUWX2o0+YhkOiQycgaoShZsb5kqj3V26th0smC+TEI+JlyKUPla40nHQfnqU+doEChfNaB2xQPLzm02RUkWJ27q5M6enh60t7fHjIyeb6mikGytw29ra4s9Nh3oW9oyk8lEBEvLd76ASvVXs7LUd3VCOn50+1KXK9BHAgoR2O7u7uhpia2trZFsSQLsCjO2p9yX6bKOjbIFcku6dPSakUqybH1Jl7a2tki2JEeUiV2hg/uygYNmp9TZadBvK1TOLXuCOeWbFCh5tfplCRH9GWWriRF9WZ0vlNFWufpWtWOAQBKShMBfYQksEG9v9SV
efE+f9fl3m7zSaqTaZA3OdOU1JQFa8U4CAQDykwBeh61SqX/36aGvK8P6eFvNLkQE1M5QrhwvpQ61tXz3kQAg3jlBGSsBYGDf1dUVmwur4wDo6zTQZ+Tkm3NoY2UlAgPFgEkAjabvRirjYyDV0tISEQEAURCliulr1/E5EvsZSUV5eXlkHJhFYJCs7RhJcE7qlCw0u0yH1NzcjNbW1ligWlNTE1MwzcwNJLCkojnnotYgK1+t9JS6fH0EllBdZKDa3NyMlpaW6PqUwHKwkwQAhTMnfNdBy2BV7yszfppt1KxhqaKQbHn+qrctLS1oaWmJMrEkWLxeVkkA/1wWwmcE1VgCfW1a1N9CmdxShA1U9W/VHY7JlpYWtLa2RsER7SSXQSQBtckVDSqA+JPcNaDSfQCIydbqbKnLVoNUIHeSpQ04SWAZpALxdgAlAXZpSw1QuS1Xz9NAlc5fz0kTNUkisNrCp+++CokSAa2eqv/WZBXvCZNg3LcvU2rlq0RLzzNJsETdElj1R/kqBFaH1JfxnfJWvfX1uNsFSGgzuI+kVFiA5ZMAfqb+J181QCsBdoEWdhNookY7DVRPS6YSwH5aG9gQ6lTU4dPYarDKAIrZKgrWMnrb86SfAYg5fC37M5OrbS6lDsrJJ19fVrW1tTVySkDfSiJa9ucABuKTV7n/5SkWf5dKpWIlf1YhlAiUMny6y3fNwqtseY0AovKd7d8H4o5as9I2QPANZm5LckEiYFtmShmULZArX5WNrWKRPOgkKf7O9vFra4rqrRpJJVi2yqLtMtpeWOrQQNUaeJWtBkWtra1RcMTskgZgJAEa4KpsGaCqTG0wRd0FEBs3GsCVOlRvgT5bAPTJVuXEBIFWX3WSH1ujLOHUllhuw55gADmyVd/I39tlhZMAG1z7Albqir1OvVYbzGonAX+j8vWtmOebE6BBatKqrzam8WWrC7UH5UtW2f35SKetCto5ARqT6bkmgQAA/mSL/Yyw9tEnE62OUD+1FdHXtpavAmCrNuozVwkJ0EFtT4jC0QyT9lJGB62oiFpWdCKfj6Fq4ESn7jsuEH/qa29vb3Rc9volwXCq4/TdZJu9o+OlfHt7l02qpBG1EyxtZlmVTyf96PGV0dPxKxHRakspo5Bs1eBZ2TJQ7e3tjc0L0LGggRjQlyVQR0/58vgc2Dy2EgstdSeBBDCbAeTqj3VASgQYfDKoZKaEgbxzLuboNVDVyX10ONYgE0r0dE5CEjKAvmBaCZKVLcc+dYfkiiuraOCjPdja3sIMtQb6PK46O44Xnqedv1Xq0AQJofYA8NtcfQo4t1Nbq/bAkgANVnk8vtTX6b713iapsu3zCbYawM9sssDXuqJJQiUD1GeOC70XQP5kl55T0oiAyrZQxtpWqG2WG8g/UVd/o8kZX3JWq4a2VZDnmrR2NqsDPr3wyULlYP+2MtIq6/J+v7zgfzAka8AkQINHHXRWSFrGoyEElimB9vb5mCo/912kdUpWMQmbMbAZn1KFBjnOOa98uZ0aSy1NW8NIZ6Xy0AwUj6cBqX6mCq5BnTq6JASqVrZa6VASYK+P10bd1TI1dY5y0KBIq1NKoHgO+ao9KlcGHaUuW6s3OhZ9mSkdn/ydzj/RpQC1EsDPNENtAyIfwVMSre0ESdJboO/abHBl7ac+zRZYlvGmfGlbtbXCTtLWrLa9rz6HlC+IS5JsAeT4FxtMWbsA9PnCfHrusyEkuJZgqYy5b56HzyaVOqx8ldgA/uw19cjXsmLvB9+t3e1PoGZJls2cl7p88wWovoy1JQn5YO12IVg7qzGYfSdKPVGYD/3RBV+AbmVi/b52AuQjWb79qlx9SYyB4P+3ay9LCgIxAABn//+jd0+pSsWEFb0wNd0nBUGI80gGPy4C4stidTTU6n6q5LuKu+uEdZKfqte6Wl4Tuh069Vqv95cTmEk3gNVJKc6dJ6kQDbH+lvk6uuvsCrgn69rONGDWZLU+pZpWVbq2etXu6mS
fr6XG+MnuXF+Nb048u9jlyT621yS19u9usJz6wS7tNict02Td3Vsu7mN7LX7zvprUdqb41rFnh6cBWY7N1JfzPeZFkS4B62Kbz1dNK3o1Qd5lPqu68Ta/7+arq2PvjLWdabFit/hOsbx73NXnpjzkne3v5DFP1cXov/4br0NdpM77u4T+bqw+fQqw1lo/vzu1dAAA4Gt7/UELAAD4miIAAAAOowgAAIDDKAIAAOAwigAAADiMIgAAAA6jCAAAgMMoAgAA4DCKAAAAOMwf17P62xpOeH0AAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "@partial(nnx.jit, static_argnums=(3,))\n",
"def compute_forward_sequence(model: UNet,\n",
"                             image: jax.Array,\n",
"                             key: jax.Array,\n",
"                             num_vis_steps: int) -> jax.Array:\n",
"    \"\"\"Compute noisy versions of `image` at evenly spaced timesteps.\n",
"\n",
"    `model` is not used by the forward process.\n",
"    \"\"\"\n",
"    # Prepare the image sequence and the noise schedule.\n",
"    # Assumes a 1000-step linear beta schedule from 1e-4 to 0.02.\n",
"    image_repeated = jnp.repeat(image[None], num_vis_steps, axis=0)\n",
"    timesteps = jnp.linspace(0, 999, num_vis_steps).astype(jnp.int32)\n",
"    beta = jnp.linspace(1e-4, 0.02, 1000)\n",
"    alpha = 1 - beta\n",
"    alpha_cumulative = jnp.cumprod(alpha)\n",
"\n",
"    # Sample each x_t in closed form: sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.\n",
"    noise = jax.random.normal(key, image_repeated.shape)\n",
"    noisy_images = (\n",
"        jnp.sqrt(alpha_cumulative[timesteps])[:, None, None, None] * image_repeated +\n",
"        jnp.sqrt(1 - alpha_cumulative[timesteps])[:, None, None, None] * noise\n",
"    )\n",
"    return noisy_images\n",
"\n",
"@partial(nnx.jit, static_argnums=(3,))\n",
"def compute_reverse_sequence(model: UNet,\n",
"                             noisy_image: jax.Array,\n",
"                             key: jax.Array,\n",
"                             num_vis_steps: int) -> jax.Array:\n",
"    \"\"\"Approximate the reverse diffusion sequence for visualization.\n",
"\n",
"    The image is fully denoised once. The intermediate frames are linear\n",
"    interpolations between the noisy input and the denoised result, not the\n",
"    actual intermediate reverse-diffusion states.\n",
"    \"\"\"\n",
"    # Denoise completely, then build the interpolation sequence.\n",
"    final_image = reverse_diffusion_batch(model, noisy_image[None], key, 1000)[0]\n",
"    weights = jnp.linspace(0, 1, num_vis_steps)\n",
"    reverse_sequence = (\n",
"        (1 - weights)[:, None, None, None] * noisy_image +\n",
"        weights[:, None, None, None] * final_image\n",
"    )\n",
"    return reverse_sequence\n",
"\n",
"def plot_forward_and_reverse(model: UNet,\n",
"                             diffusion: DiffusionModel,\n",
"                             image: jax.Array,\n",
"                             key: jax.Array,\n",
"                             num_steps: int = 9) -> None:\n",
"    \"\"\"Plot the forward and (approximate) reverse diffusion processes.\"\"\"\n",
"    # Compute the forward and reverse sequences.\n",
"    key1, key2 = jax.random.split(key)\n",
"    forward_sequence = compute_forward_sequence(model, image, key1, num_steps)\n",
"    reverse_sequence = compute_reverse_sequence(model, forward_sequence[-1], key2, num_steps)\n",
"\n",
"    # Plot the grid.\n",
"    fig, (ax1, ax2) = plt.subplots(2, num_steps, figsize=(8, 2))\n",
"    fig.suptitle('Forward and reverse diffusion process', y=1.1)\n",
"\n",
"    timesteps = jnp.linspace(0, diffusion.num_steps-1, num_steps).astype(jnp.int32)\n",
"\n",
"    # Visualize forward diffusion.\n",
"    for i in range(num_steps):\n",
"        ax1[i].imshow(forward_sequence[i, ..., 0], cmap='binary', interpolation='gaussian')\n",
"        # Hide ticks instead of calling `axis('off')`, which would also hide the row label.\n",
"        ax1[i].set_xticks([])\n",
"        ax1[i].set_yticks([])\n",
"        ax1[i].set_title(f't={timesteps[i]}')\n",
"    ax1[0].set_ylabel('Forward', rotation=90, labelpad=10)\n",
"\n",
"    # Visualize reverse diffusion.\n",
"    for i in range(num_steps):\n",
"        ax2[i].imshow(reverse_sequence[i, ..., 0], cmap='binary', interpolation='gaussian')\n",
"        ax2[i].set_xticks([])\n",
"        ax2[i].set_yticks([])\n",
"        ax2[i].set_title(f't={timesteps[num_steps-1-i]}')\n",
"    ax2[0].set_ylabel('Reverse', rotation=90, labelpad=10)\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"# Create a plot.\n",
"key, subkey = jax.random.split(key)\n",
"print(\"\\nFull forward and reverse diffusion processes:\")\n",
"plot_forward_and_reverse(model, diffusion, images_test[0], subkey)" ] }, { "cell_type": "markdown", "metadata": { "id": "o43bRWpiM6Mt" }, "source": [ "## Summary\n", "\n", "In this tutorial, you implemented a simple diffusion model with JAX and Flax NNX and trained it with Optax. The model uses a U-Net architecture with attention mechanisms, the training step is compiled with Flax NNX's JIT transform (`flax.nnx.jit`), and you visualized the forward and reverse diffusion processes."
] } ], "metadata": { "accelerator": "TPU", "colab": { "gpuType": "V28", "machine_shape": "hm", "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 0 }