{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Trainable Distributions\n", "\n", "> In this post, we will take a look at how to make the parameters of a distribution object trainable. This is a summary of the lecture \"Probabilistic Deep Learning with Tensorflow 2\" from Imperial College London.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow_probability, ICL]\n", "- image: images/tfd_trainable_distribution.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "tfd = tfp.distributions\n", "\n", "plt.rcParams['figure.figsize'] = (10, 6)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow Version: 2.5.0\n", "Tensorflow Probability Version: 0.13.0\n" ] } ], "source": [ "print(\"Tensorflow Version: \", tf.__version__)\n", "print(\"Tensorflow Probability Version: \", tfp.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Previously, we defined the mean and standard deviation of a distribution as fixed floating-point values. If we instead wrap a parameter in a `tf.Variable`, we can use an optimizer object to update it with gradients computed from a loss function and data."
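, "\n", "\n", "To make this concrete, here is a minimal, framework-free sketch of the same idea in NumPy. It assumes a Normal model with a fixed scale of 1, so the gradient of the mean negative log-likelihood with respect to the mean $\\mu$ is $\\mu - \\bar{x}$; the sketch writes that gradient out by hand instead of using `tf.GradientTape`. The synthetic data, random seed, and number of steps are made up for illustration.\n", "\n", "```python\n", "import numpy as np\n", "\n", "# Hypothetical synthetic data: 1,000 draws from N(2, 1)\n", "rng = np.random.default_rng(0)\n", "x = rng.normal(loc=2.0, scale=1.0, size=1000)\n", "\n", "def nll(mu, x):\n", "    # Mean negative log-likelihood of N(mu, 1) over the data\n", "    return np.mean(0.5 * (x - mu) ** 2 + 0.5 * np.log(2 * np.pi))\n", "\n", "def grad_nll(mu, x):\n", "    # Derivative of nll with respect to mu, worked out by hand\n", "    return mu - np.mean(x)\n", "\n", "mu = 0.0              # plays the role of tf.Variable(0., name='loc')\n", "learning_rate = 0.05  # same learning rate as the SGD optimizer in this post\n", "for step in range(200):\n", "    mu -= learning_rate * grad_nll(mu, x)  # plain SGD update\n", "\n", "print(f\"estimated mu = {mu:.4f}, sample mean = {np.mean(x):.4f}\")\n", "```\n", "\n", "Each update moves $\\mu$ a fixed fraction of the way toward the sample mean, which is the maximum-likelihood estimate. With `tf.keras.optimizers.SGD` (no momentum), `apply_gradients` performs the same update, `var -= learning_rate * grad`, for every trainable variable."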
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "normal = tfd.Normal(loc=tf.Variable(0., name='loc'), scale=1.)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "normal.trainable_variables" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The distribution now has a trainable variable (`loc`), which makes it a trainable distribution.\n", "\n", "For example, we can use it to define a negative log-likelihood (NLL) loss:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def nll(X_train):\n", "    return -tf.reduce_mean(normal.log_prob(X_train))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here `normal.log_prob(X_train)` returns a tensor with the same shape as `X_train`, holding one log probability per data point. If we assume the training data are IID (independent and identically distributed), the log-likelihood of the whole dataset is the sum of the log probabilities of the individual data points. `nll` takes the mean instead of the sum, which only rescales the loss by the number of data points and leaves its minimizer unchanged." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the training loop, we compute the loss and its gradients with respect to the trainable variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@tf.function\n", "def get_loss_and_grads(X_train):\n", "    with tf.GradientTape() as tape:\n", "        tape.watch(normal.trainable_variables)\n", "        loss = nll(X_train)\n", "    grads = tape.gradient(loss, normal.trainable_variables)\n", "    return loss, grads" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we can speed up the gradient computation with the `@tf.function` decorator, which traces the Python function into a TensorFlow computation graph. The `tape.watch` call is optional here, because `tf.GradientTape` automatically watches trainable variables.\n", "\n", "With this function, we can write the training loop."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)\n", "\n", "for _ in range(num_steps):\n", " loss, grads = get_loss_and_grads(X_sample)\n", " optimizer.apply_gradients(zip(grads, normal.trainable_variables))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tutorial" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import CountVectorizer\n", "from sklearn.naive_bayes import BernoulliNB\n", "from sklearn.metrics import f1_score" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Define an exponential distribution\n", "exponential = tfd.Exponential(rate=0.3, name='exp')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFlCAYAAADYqP0MAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWB0lEQVR4nO3df6xfZ30f8PenThMm2NrQWFWXH9hQV2s6qmS7DZrasaoDGhYpYRIUMyEFKZLXiaid0KS66xSYq0qGblUrLWvJSiRajbkpdJuluMqyQvdDXcAOBKiDMkzqEluMuIS2i6BJnXz2xz2wLzfX8de+38f31+slXd1znnPO9/u5R0fOO8/znHOquwMAwGJ923oXAACwFQlZAAADCFkAAAMIWQAAAwhZAAADCFkAAANctt4FrHTVVVf1rl271rsMAIDzevjhh/+ku3eutm2ukFVVNyf5lSQ7kvx6dx9csf0nk7wzyXNJnk6yr7sfnbb9bJI7pm0/1d0PvNh37dq1K8eOHZunLACAdVVVf3yubecdLqyqHUnuTvLGJNcneVtVXb9itw9196u7+4Yk70vyS9Ox1yfZm+QHktyc5N9OnwcAsKXNMyfrpiQnuvvx7n42yaEkt83u0N1/PrP60iTfeIz8bUkOdfcz3f1HSU5MnwcAsKXNM1x4dZInZtZPJXnNyp2q6p1J3pXk8iQ/NnPsQyuOvXqVY/cl2Zck11133Tx1AwBsaAu7u7C77+7uVyX5mST/4gKPvae7l7p7aefOVeeOAQBsKvOErNNJrp1Zv2ZqO5dDSd50kccCAGwJ84Sso0n2VNXuqro8yxPZD8/uUFV7ZlZvSfL5aflwkr1VdUVV7U6yJ8kn1l42AMDGdt45Wd19tqruTPJAlh/hcG93H6+qA0mOdffhJHdW1euS/GWSrya5fTr2eFXdl+TRJGeTvLO7nxv0twAAbBjV3eff6xJaWlpqz8kCADaDqnq4u5dW2+a1OgAAAwhZAAADCFkAAAMIWQAAAwhZAAADzPNanS1p1/77
X9B28uAt61AJALAV6ckCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABjgsvUuYCPZtf/+F7SdPHjLOlQCAGx2erIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGmCtkVdXNVfVYVZ2oqv2rbH9XVT1aVZ+pqt+rqlfMbHuuqh6Zfg4vsngAgI3qsvPtUFU7ktyd5PVJTiU5WlWHu/vRmd0+lWSpu79WVf8kyfuSvHXa9vXuvmGxZQMAbGzz9GTdlOREdz/e3c8mOZTkttkduvtj3f21afWhJNcstkwAgM1lnpB1dZInZtZPTW3nckeS351Zf0lVHauqh6rqTRdeIgDA5nPe4cILUVVvT7KU5O/NNL+iu09X1SuTfLSqPtvdX1hx3L4k+5LkuuuuW2RJAADrYp6erNNJrp1Zv2Zq+xZV9bokP5fk1u5+5hvt3X16+v14kt9PcuPKY7v7nu5e6u6lnTt3XtAfAACwEc0Tso4m2VNVu6vq8iR7k3zLXYJVdWOS92c5YD05035lVV0xLV+V5IeTzE6YBwDYks47XNjdZ6vqziQPJNmR5N7uPl5VB5Ic6+7DSX4xycuS/HZVJckXu/vWJN+f5P1V9XyWA93BFXclAgBsSXPNyeruI0mOrGi7a2b5dec47g+SvHotBQIAbEae+A4AMICQBQAwgJAFADDAQp+TtRXt2n//C9pOHrxlHSoBADYTPVkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAANctt4FbEa79t//graTB29Zh0oAgI1KTxYAwABCFgDAAEIWAMAAc4Wsqrq5qh6rqhNVtX+V7e+qqker6jNV9XtV9YqZbbdX1eenn9sXWTwAwEZ13pBVVTuS3J3kjUmuT/K2qrp+xW6fSrLU3T+Y5MNJ3jcd+/Ik707ymiQ3JXl3VV25uPIBADameXqybkpyorsf7+5nkxxKctvsDt39se7+2rT6UJJrpuUfT/Jgdz/V3V9N8mCSmxdTOgDAxjVPyLo6yRMz66emtnO5I8nvXsixVbWvqo5V1bEzZ87MURIAwMa20InvVfX2JEtJfvFCjuvue7p7qbuXdu7cuciSAADWxTwh63SSa2fWr5navkVVvS7JzyW5tbufuZBjAQC2mnlC1tEke6pqd1VdnmRvksOzO1TVjUnen+WA9eTMpgeSvKGqrpwmvL9hagMA2NLO+1qd7j5bVXdmORztSHJvdx+vqgNJjnX34SwPD74syW9XVZJ8sbtv7e6nqurnsxzUkuRAdz815C/ZgLx+BwC2r7neXdjdR5IcWdF218zy617k2HuT3HuxBQIAbEae+A4AMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADDAZetdwFaxa//9F73fyYO3LLocAGCd6ckCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGOCy9S6AZNf++1/QdvLg
LetQCQCwKHqyAAAGELIAAAYQsgAABhCyAAAGELIAAAaYK2RV1c1V9VhVnaiq/atsf21VfbKqzlbVm1dse66qHpl+Di+qcACAjey8j3Coqh1J7k7y+iSnkhytqsPd/ejMbl9M8o4k/2yVj/h6d9+w9lIBADaPeZ6TdVOSE939eJJU1aEktyX5Zsjq7pPTtucH1AgAsOnMM1x4dZInZtZPTW3zeklVHauqh6rqTavtUFX7pn2OnTlz5gI+GgBgY7oUE99f0d1LSf5Rkl+uqlet3KG77+nupe5e2rlz5yUoCQBgrHlC1ukk186sXzO1zaW7T0+/H0/y+0luvID6AAA2pXlC1tEke6pqd1VdnmRvkrnuEqyqK6vqimn5qiQ/nJm5XAAAW9V5Q1Z3n01yZ5IHknwuyX3dfbyqDlTVrUlSVT9UVaeSvCXJ+6vq+HT49yc5VlWfTvKxJAdX3JUIALAlzXN3Ybr7SJIjK9rumlk+muVhxJXH/UGSV6+xRgCATccT3wEABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGmOvdhVx6u/bf/4K2kwdvWYdKAICLoScLAGAAIQsAYAAhCwBgACELAGAAIQsAYAAhCwBgACELAGAAIQsAYAAhCwBgAE983wI8HR4ANh4haxNZLUwBABuT4UIAgAGELACAAQwXblHmaQHA+tKTBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMIBHOGwjHusAAJeOniwAgAGELACAAYQsAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAAeYKWVV1c1U9VlUnqmr/KttfW1WfrKqzVfXmFdtur6rPTz+3L6pwAICN7Lwhq6p2JLk7yRuTXJ/kbVV1/YrdvpjkHUk+tOLYlyd5d5LXJLkpybur6sq1lw0AsLHN05N1U5IT3f14dz+b5FCS22Z36O6T3f2ZJM+vOPbHkzzY3U9191eTPJjk5gXUDQCwoc0Tsq5O8sTM+qmpbR5rORYAYNPaEBPfq2pfVR2rqmNnzpxZ73IAANZsnpB1Osm1M+vXTG3zmOvY7r6nu5e6e2nnzp1zfjQAwMY1T8g6mmRPVe2uqsuT7E1yeM7PfyDJG6rqymnC+xumNgCALe28Iau7zya5M8vh6HNJ7uvu41V1oKpuTZKq+qGqOpXkLUneX1XHp2OfSvLzWQ5qR5McmNoAALa0y+bZqbuPJDmyou2umeWjWR4KXO3Ye5Pcu4YaAQA2nQ0x8R0AYKsRsgAABphruJCta9f++1/QdvLgLetQCQBsLXqyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAbwMFLm4qGlAHBhhCxeYLVABQBcGMOFAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADXLbeBbB57dp//wvaTh68ZR0qAYCNR08WAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwAAeRspwHloKwHakJwsAYAAhCwBgAMOFLNRqQ4MAsB3pyQIAGEDIAgAYwHAh68IdhwBsdXqyAAAGELIAAAYQsgAABpgrZFXVzVX1WFWdqKr9q2y/oqp+a9r+8araNbXvqqqvV9Uj08+vLbh+AIAN6bwT36tqR5K7k7w+yakkR6vqcHc/OrPbHUm+2t3fW1V7k7w3yVunbV/o7hsWWzYAwMY2T0/WTUlOdPfj3f1skkNJbluxz21JPjgtfzjJ36+qWlyZAACbyzwh6+okT8ysn5raVt2nu88m+bMk3zVt211Vn6qq/1ZVf3e1L6iqfVV1rKqOnTlz5oL+AACAjWj0xPcvJbmuu29M8q4kH6qqv7Zyp+6+p7uXuntp586dg0sCABhvnpB1Osm1M+vXTG2r7lNV
lyX5jiRf6e5nuvsrSdLdDyf5QpLvW2vRAAAb3Twh62iSPVW1u6ouT7I3yeEV+xxOcvu0/OYkH+3urqqd08T5VNUrk+xJ8vhiSgcA2LjOe3dhd5+tqjuTPJBkR5J7u/t4VR1Icqy7Dyf5QJLfrKoTSZ7KchBLktcmOVBVf5nk+SQ/2d1PjfhD2Jq8fgeAzWqudxd295EkR1a03TWz/BdJ3rLKcR9J8pE11sg2sVqgAoDNyhPfAQAGELIAAAYQsgAABphrThZsJCbDA7AZ6MkCABhATxZbwrnuTNTDBcB60ZMFADCAniyIeV4ALJ6eLACAAfRksaXpoQJgvQhZcA4CGgBrYbgQAGAAIQsAYADDhWw753qmFgAskp4sAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAATzxHS6Al0YDMC89WQAAAwhZAAADGC6EAQwrAiBkwToSxgC2LiEL1mi1oAQAQhZsMHq3ALYGE98BAAYQsgAABjBcCJvUhQwrGoIEuPSELLhETJAH2F6ELNhCBDmAjUPIgk1gRHgyhAgwlonvAAAD6MkCvknvFsDiCFnAuhPugK1IyAKGEZ6A7UzIAi7YWsLTpbgDUrgDNgIhC9jWBDJgFCEL2JCEH2CzE7KAFzXv8N56DQMCbFRzhayqujnJryTZkeTXu/vgiu1XJPmNJH87yVeSvLW7T07bfjbJHUmeS/JT3f3AwqoHmJOABlxq5w1ZVbUjyd1JXp/kVJKjVXW4ux+d2e2OJF/t7u+tqr1J3pvkrVV1fZK9SX4gyV9P8l+r6vu6+7lF/yEAm8m8w6Ebbdh0o9UDG9k8PVk3JTnR3Y8nSVUdSnJbktmQdVuS90zLH07yb6qqpvZD3f1Mkj+qqhPT5/2vxZQPsHiL7vVadAiZt77NEH42a9hk49mI18g8IevqJE/MrJ9K8ppz7dPdZ6vqz5J819T+0Ipjr77oagE2ofWa17aWALPW77kUn3cpAtpa/7a1/Ef+UoSGEd+xEcPOetkQE9+ral+SfdPq01X12CX42quS/Mkl+J7tyLkdy/kdZ8uf23rvun798PM77993qc7Dor/nRT5vYed2xLlZr+tugd/7Yuf3Fec6aJ6QdTrJtTPr10xtq+1zqqouS/IdWZ4AP8+x6e57ktwzRy0LU1XHunvpUn7nduHcjuX8juPcjuX8juPcjnWx5/fb5tjnaJI9VbW7qi7P8kT2wyv2OZzk9mn5zUk+2t09te+tqiuqaneSPUk+caFFAgBsNuftyZrmWN2Z5IEsP8Lh3u4+XlUHkhzr7sNJPpDkN6eJ7U9lOYhl2u++LE+SP5vkne4sBAC2g7nmZHX3kSRHVrTdNbP8F0neco5jfyHJL6yhxlEu6fDkNuPcjuX8juPcjuX8juPcjnVR57eWR/UAAFikeeZkAQBwgbZdyKqqm6vqsao6UVX717ueraaqTlbVZ6vqkao6tt71bHZVdW9VPVlVfzjT9vKqerCqPj/9vnI9a9ysznFu31NVp6fr95Gq+gfrWeNmVVXXVtXHqurRqjpeVT89tbt2F+BFzq/rd42q6iVV9Ymq+vR0bv/l1L67qj4+ZYffmm4EPP/nbafhwukVQf87M68ISvK2Fa8IYg2q6mSSpe7e0s8aulSq6rVJnk7yG939N6e29yV5qrsPTv+jcGV3/8x61rkZnePcvifJ0939r9azts2uqr4nyfd09yer6q8meTjJm5K8I67dNXuR8/sTcf2uyfS2mpd299NV9e1J/meSn07yriS/092HqurXkny6u3/1fJ+33XqyvvmKoO5+Nsk3XhEEG1J3//cs37E767YkH5yWP5jlf1y5QOc4tyxAd3+puz85Lf/fJJ/L8ts+XLsL8CLnlzXqZU9Pq98+/XSSH8vyawOTC7h2t1vIWu0VQS7Mxeok/6WqHp6e5M/ifXd3f2la/j9Jvns9i9mC7qyqz0zDiYaz1qiqdiW5McnH49pd
uBXnN3H9rllV7aiqR5I8meTBJF9I8qfdfXbaZe7ssN1CFuP9SHf/rSRvTPLOaUiGQaaH/m6fMf/xfjXJq5LckORLSf71ulazyVXVy5J8JMk/7e4/n93m2l27Vc6v63cBuvu57r4hy2+puSnJ37jYz9puIWuu1/xw8br79PT7yST/McsXKIv15WlOxjfmZjy5zvVsGd395ekf2OeT/Lu4fi/aNJ/lI0n+fXf/ztTs2l2Q1c6v63exuvtPk3wsyd9J8p3TawOTC8gO2y1kzfOKIC5SVb10moSZqnppkjck+cMXP4qLMPsaq9uT/Od1rGVL+UYAmPzDuH4vyjR5+ANJPtfdvzSzybW7AOc6v67ftauqnVX1ndPyX8nyjXKfy3LYevO029zX7ra6uzBJpltafzn//xVBG/Fp9JtSVb0yy71XyfLbBD7k/K5NVf2HJD+a5TfAfznJu5P8pyT3JbkuyR8n+YnuNoH7Ap3j3P5olodaOsnJJP94Zg4Rc6qqH0nyP5J8NsnzU/M/z/K8IdfuGr3I+X1bXL9rUlU/mOWJ7Tuy3BF1X3cfmP77dijJy5N8Ksnbu/uZ837edgtZAACXwnYbLgQAuCSELACAAYQsAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAAf4fpone9R8nIi4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot\n", "plt.hist(exponential.sample(5000).numpy(), bins=100, density=True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define an exponential distribution with a trainable rate parameter\n", "exponential_train = tfd.Exponential(rate=tf.Variable(1., name='rate'), name='exp_train')\n", "exponential_train.trainable_variables" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Define a negative log likelihood\n", "def nll(X_train, distribution):\n", "    return -tf.reduce_mean(distribution.log_prob(X_train))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Define a function to compute loss and gradients\n", "@tf.function\n", "def get_loss_and_grads(X_train, distribution):\n", "    with tf.GradientTape() as tape:\n", "        tape.watch(distribution.trainable_variables)\n", "        loss = nll(X_train, distribution)\n", "    grads = tape.gradient(loss, distribution.trainable_variables)\n", "    return loss, grads" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Optimize\n", "def exponential_dist_optimization(data, distribution):\n", "    # Keep results for plotting\n", "    train_loss_results = []\n", "    train_rate_results = []\n", "    \n", "    optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)\n", "    \n", "    num_steps = 10\n", "    \n", "    for i in range(num_steps):\n", "        loss, grads = get_loss_and_grads(data, distribution)\n", "        optimizer.apply_gradients(zip(grads, distribution.trainable_variables))\n", "        \n", "        rate_value = distribution.rate.value()\n", "        train_loss_results.append(loss)\n", "        train_rate_results.append(rate_value)\n", "        \n", "        print(\"Step {:03d}: Loss: {:.3f}: Rate: {:.3f}\".format(i, loss, rate_value))\n", "    \n", "    return train_loss_results, train_rate_results" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 000: Loss: 2.896: Rate: 0.669\n", "Step 001: Loss: 2.682: Rate: 0.573\n", "Step 002: Loss: 2.510: Rate: 0.490\n", "Step 003: Loss: 2.383: Rate: 0.422\n", "Step 004: Loss: 2.301: Rate: 0.370\n", "Step 005: Loss: 2.255: Rate: 0.335\n", "Step 006: Loss: 2.235: Rate: 0.314\n", "Step 007: Loss: 2.228: Rate: 0.303\n", "Step 008: Loss: 2.227: Rate: 0.297\n", "Step 009: Loss: 2.226: Rate: 0.295\n" ] } ], "source": [ "# Get some data and train\n", "sampled_data = exponential.sample(5000)\n", "train_loss_results, train_rate_results = exponential_dist_optimization(data=sampled_data, distribution=exponential_train)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Exact rate: 0.3\n", "Predicted rate: 0.29511935\n" ] } ], "source": [ "# Predicted value for the rate parameter\n", "\n", "pred_value = exponential_train.rate.numpy()\n", "exact_value = exponential.rate.numpy()\n", "\n", "print(\"Exact rate: \", exact_value)\n", "print(\"Predicted rate: \", pred_value)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAmUAAAGiCAYAAACmirG2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABPPklEQVR4nO3dd3hVVfr28e+TXggJCYQOoUrvTZoKSrGOYxm76Aj2Mvo66ow6fZyfo45dx4KMvRdUrKA0RQiIgID00EkgJKSStt4/zkk4CRAIJDknyf25rkz22XvtfZ5wnHCz9tprmXMOEREREfGvIH8XICIiIiIKZSIiIiIBQaFMREREJAAolImIiIgEAIUyERERkQCgUCYiIiISABTKRERERAKAQpmI1Cgzu8TMks0s28x2mNlnZjbS33WJiAQahTIRqTFmdjvwKPBPoDnQDngaOMePZZUxsxB/1yAiUkqhTERqhJnFAn8FbnTOve+cy3HOFTrnPnbO3Wlm4Wb2qJlt9349ambh3nNPNrOtZnaHmaV6e9iu8h4bamY7zSzY573ONbNl3u0gM7vbzNab2R4ze9vM4r3HkszMmdlvzWwzMMvMgs3sYTPbbWYbzewmb5uQ0p/DzF701rDNzP5e+t5mNsnM5pnZQ2a213v+RJ+64s3sJe/Pt9fMPvQ5dqaZLTWzDDP7zsz61PRnIiKBTaFMRGrKiUAE8MFhjv8RGAb0A/oCQ4B7fY63AGKB1sBvgafMrIlz7gcgBxjj0/YS4HXv9s3Ar4CTgFbAXuCpCu99EtAdGA9MBiZ66xjgPdfXNKAI6Az0B8YB1/gcHwr8AjQFHgReNDPzHnsFiAJ6AonAfwDMrD8wFbgWSAD+C0wvDaUi0jCZ1r4UkZpgZpcCDzvnWhzm+HrgZufcDO/r8cB/nXNJZnYy8BkQ45wr8h5PBc52zi0ws78DrZxzV5tZDLAT6OGcSzGzVcBNzrmZ3vNaApuBSKANsBHo5Jzb4D0+C3jLOfdf7+tTga+AUDyBaTMQ55zL8x6/GJjinDvFzCYB9zrnOnuPReEJjC0BA7YBCc65vRV+9meA3c65+3z2/eK97uwq/2GLSL2g8RQiUlP2AE3NLKQ0WFXQCkjxeZ3i3Vd2foXzcoFG3u3Xge/M7Hrg18AS51zptdoDH5hZic+5xXjGtJXaUqGOLYc51h5PONtxoPOLoAptdpZuOOdyve0aAfFAesVA5nPdK83sZp99YZT/+UWkgdHtSxGpKd8D+zn4dmCp7XjCSal23n1H5JxbiSfETaT8rUvwBKaJzrk4n68I59w230v4bO/A04NWqm2Fa+0Hmvpcq7FzrudRlLkFiDezuMMc+0eFGqOcc28cxXVFpJ5SKBORGuGcywTuxzMW7FdmFmVmoWY20cweBN4A7jWzZmbW1Nv21Sq8xevArcBo4B2f/c8C/zCz9gDe61f2tOfbwK1m1toboO7y+Rl2AF8CD5tZY+9DBJ3M7KSj+Pl34LkF+7SZNfH+7KO9h58HrvM+tGBmFm1mZ3hvxYpIA6VQJiI1xjn3MHA7ngH8aXh6iG4CPgT+DiQDy4DlwBLvvqP1Bp4B+7Occ7t99j8GTAe+NLMsYAGewfiH8zye4LUM+BGYgWdgf7H3+BV4bi2uxPPQwLt4xowdjcuBQmA1kArcBuCcS8bzgMGT3muuAyYd5TVFpJ7SQH8RER/eKS2edc61P2JjEZFqpJ4yEWnQzCzSzE43sxAzaw38icNP4yEiUmPUUyYiDZp3GovZQDcgD/gUuNU5t8+vhYlIg6NQJiIiIhIAdPtSREREJAAolImIiIgEAIUyERERkQCgUCYiIiISABTKRERERAKAQpmIiIhIAFAoExEREQkACmUiIiIiAUChTERERCQAKJSJiIiIBACFMhEREZEAoFAmIiIiEgAUykREREQCgEKZiIiISABQKBMREREJAAplIiIiIgFAoUxEREQkACiUiYiIiAQAhTIRERGRAKBQJiI
iIhIAFMpEREREAoBCmYiIiEgAUCgTERERCQAKZSIiIiIBQKFMREREJAAolImIiIgEAIUyERERkQCgUCYiIiISABTKRERERAJAiL8LOF5NmzZ1SUlJ/i5DRERE5IgWL1682znX7FDH6nwoS0pKIjk52d9liIiIiByRmaUc7phuX4qIiIgEAIUyERERkQCgUCYiIiISABTKjqCkxPHS/I3kFxb7uxQRERGpxxTKjmDhpnT+8vFKfvXUfDbuzvF3OSIiIlJPKZQdwbCOCbx01WB27svnrCfm8dnyHf4uSUREROohhbKjcMoJiXx6yyg6JTbi+teW8NePV1JQVOLvskRERKQeUSg7Sq3jInnn2hOZNDyJqfM3ctFz37M9I8/fZYmIiEg9oVBWBWEhQfz57J48eUl/ftmZxRmPz2X2mjR/lyUiIiL1gELZMTizTyum3zySxJgIJr20kP98tYbiEufvskRERKQOUyg7Rp2aNeLDG0fw6/5teGzmWq6cupA92fv9XZaIiIjUUQplxyEyLJiHLujD/53Xm4Wb0jnj8Xkkb0r3d1kiIiJSBymUHScz4zeD2/HBDcMJDw3ioucW8MLcDTin25kiIiJy9BTKqknPVrF8fPNIxnZP5O+fruK6VxezL7/Q32WJiIhIHaFQVo0aR4Ty7GUDufeM7sxclcpZT8zj5+2Z/i5LRERE6gCFsmpmZlwzqiNvThnG/sISzn36O95atFm3M0VERKRSCmU1ZFBSPJ/eMpKhHeK5673l/L93lpFXoEXNRURE5NAUympQQqNwpl01hFvHduH9H7dy7tPz2ZCW7e+yREREJAAplNWw4CDjd6d15X9XDWHXvnzOfnI+ny7TouYiIiJSnkJZLRndtRmf3jKKrs0bcePrS/jz9J+1qLmIiIiUqbVQZmZtzewbM1tpZj+b2a2HaBNrZh+b2U/eNlfVVn21oVVcJG9OOZGrR3Rg2nebuPC/37NNi5qLiIgItdtTVgTc4ZzrAQwDbjSzHhXa3AisdM71BU4GHjazsFqsscaFhQRx/1k9eObSAaxLzeaMx+fy7S+p/i5LRERE/KzWQplzbodzbol3OwtYBbSu2AyIMTMDGgHpeMJcvTOxd0s+vnkkLRpHcNW0RTz85S9a1FxERKQB88uYMjNLAvoDP1Q49CTQHdgOLAdudc7V24FXHZpG8+GNI7hgYBuemLWOK6b+wG4tai4iItIg1XooM7NGwHvAbc65fRUOjweWAq2AfsCTZtb4ENeYYmbJZpaclpZWwxXXrIjQYB48vy8Pnt+H5E17OePxuSzSouYiIiINTq2GMjMLxRPIXnPOvX+IJlcB7zuPdcBGoFvFRs6555xzg5xzg5o1a1azRdeSCwe15YMbRhAZGsxFzy3guTnrtQqAiIhIA1KbT18a8CKwyjn3yGGabQbGets3B04ANtROhf7Xo1Vjpt88knE9mvPPGau59pXFZOZpUXMREZGGoDZ7ykYAlwNjzGyp9+t0M7vOzK7ztvkbMNzMlgMzgbucc7trsUa/axwRytOXDuD+M3swa7VnUfMV27SouYiISH1ndf0W2aBBg1xycrK/y6gRi1P2ctPrS9iTU8Cfz+rJxUPa4ulwFBERkbrIzBY75wYd6phm9A9gA9s34dNbRjG0Qzx/+GA5d7z9E7kF9XKGEBERkQZPoSzAxUeHMe2qIfzu1K58sHQbv3pqPutStai5iIhIfaNQVgcEBxm3ntqFl68ewu7sAs5+ch7Tf9ru77JERESkGimU1SGjujTj01tG0r1lY25540fu/2gF+4uK/V2WiIiIVAOFsjqmZWwkb04ZxuRRHXj5+xQufPZ7tqTn+rssEREROU4KZXVQaHAQfzyjB89eNoANaTmc+cQ8Zq3e5e+yRERE5DgolNVhE3p5FjVvHRfJ1dOSefDz1RQV19ulQkVEROo1hbI6LqlpNO/fMJyLBrfl6W/Xc9mLP5Cale/vskRERKSKFMrqgYjQYP51Xh8
euqAvS7dkcMbj81iwYY+/yxIREZEqUCirR84f2IYPbxxBTHgIlzy/gGe+XU9JSd1esUFERKShUCirZ7q1aMxHN41gYu+W/N/nq5n8cjKZuVrUXEREJNAplNVDMRGhPHlxf/58Vg/mrE3jjCfmsmxrhr/LEhERkUoolNVTZsakER14+9oTKSlxnP/M97yyIIW6vgC9iIhIfaVQVs/1b+dZ1PzETgnc9+EKbntrKTn7tai5iIhIoFEoawCaRIfx0qTB/L9xXfn4p+2c89R81u7K8ndZIiIi4kOhrIEICjJuGtOFV347lIzcAs5+cj4fLd3m77JERETES6GsgRnRuSmf3jKKXq0bc+ubS7n3w+Va1FxERCQAKJQ1QM0bR/D65GFcO7ojry7YzPnPfM/mPVrUXERExJ8Uyhqo0OAg7jm9O89dPpBNe3IY9+hsnv52HQVFWjtTRETEHxTKGrhxPVvw+W2jGd2lGQ9+/gunPz5XSzSJiIj4gUKZ0DoukueuGMSLVw4iv7CYi55bwO1vL2V39n5/lyYiItJgKJRJmbHdm/PV707ihpM78fFP2xn78Gxe+yFF62eKiIjUAoUyKScyLJjfT+jGZ7eOonvLGP74wQp+/cx3rNiW6e/SRERE6jWFMjmkzokxvDF5GI9c2Jct6bmc/eQ8/vLxz2Tla3FzERGRmqBQJodlZvx6QBtm3XEyFw9px7TvNjH24dl8smy71tAUERGpZgplckSxUaH849zefHDDCJrFhHPT6z9yxdSFbNqd4+/SRERE6g2FMjlq/drG8dGNI/jTWT34cXMG4x6dw6NfryG/UCsCiIiIHC+FMqmSkOAgrhrRgZl3nMS4Hs159Ou1THh0DnPXpvm7NBERkTpNoUyOSfPGETx5yQBe+e0QAC5/cSE3vb6EXfvy/VyZiIhI3aRQJsdlVJdmfH7baG47tQtfrtzF2Idn89L8jRQVa7kmERGRqlAok+MWERrMbad25cvbRtO/XRx/+Xgl5zw1n6VbMvxdmoiISJ1x3KHMzEKroxCp+5KaRvPy1UN48pL+pGXt59yn5/PHD5aTmau5zURERI6kSqHMzG4xs/N8Xr8I5JnZL2Z2whHObWtm35jZSjP72cxuPUy7k81sqbfN7KrUJ/5nZpzZpxUz7ziJScOTeGPhZsY+8i3vL9mquc1EREQqUdWesluANAAzGw1cCFwCLAUePsK5RcAdzrkewDDgRjPr4dvAzOKAp4GznXM9gQuqWJ8EiJiIUP50Vk+m3zSSNk2iuP3tn7j4+QWsS83yd2kiIiIBqaqhrDWw0bt9FvCOc+5t4M94gtZhOed2OOeWeLezgFXe6/m6BHjfObfZ2y61ivVJgOnVOpb3rx/OP87txcrt+5j42Fwe/Hw1eQWa20xERMRXVUPZPiDRu30aMNO7XQhEHO1FzCwJ6A/8UOFQV6CJmX1rZovN7IrDnD/FzJLNLDktTfNjBbqgIOPSoe2Z9f9O5qy+rXj62/Wc9p/ZzFy1y9+liYiIBIyqhrIvgefN7AWgM/CZd39PDvSgVcrMGgHvAbc55/ZVOBwCDATOAMYD95lZ14rXcM4955wb5Jwb1KxZsyr+COIvTRuF88iF/XhzyjAiQoP57f+SmfJyMtsy8vxdmoiIiN9VNZTdCMwHmgHnO+fSvfsHAG8c6WTvk5rvAa85594/RJOtwBfOuRzn3G5gDtC3ijVKgBvWMYEZt4zi9xNOYM7aNE57ZDb/nb2eQs1tJiIiDZjV1hNxZmbA/4B059xth2nTHXgSTy9ZGLAQuMg5t+Jw1x00aJBLTk6u/oKlVmxJz+UvH//M16tSOaF5DH8/txeDk+L9XZaIiEiNMLPFzrlBhzpW1SkxevhOfWFmp5nZq2Z2j5kFH+H0EcDlwBjvlBdLzex0M7vOzK4DcM6tAj4HluEJZC9UFsik7msbH8ULVw7mucsHkr2/iAue/Z473/mJ9JwCf5cmIiJSq6rUU2ZmC4BHnXNvmll
b4BfgW6AP8Ipz7p4aqbIS6imrP3ILinhs5lpenLuRRhEh3D2hGxcOaktQkPm7NBERkWpRbT1lQDdgiXf7fOAH59zpeHrALj72EkUgKiyEeyZ259NbRtE1MYa731/OBf/9nlU7Kj4PIiIiUv9UNZQFA6X3lcYCM7zb64Hm1VWUNGwntIjhrWuH8e/z+7Bxdw5nPjGPv3+ykuz9Rf4uTUREpMZUNZStAK43s1F4Qtnn3v2tgd3VWZg0bGbGBYPaMvP2k7hwUBtemLeRUx+ezWfLd2i5JhERqZeqGsruAibjGUf2hnNuuXf/2XgG5otUqybRYTzw6z68d/1wmkSHcf1rS7h62iI278n1d2kiIiLVqspTYnifsmzsnNvrsy8JyPXHskga6N9wFBWXMO27TfznqzUUlThuHtOZyaM7Eh5ypAd/RUREAkN1DvTHOVcM5JlZLzPraWYRzrlNWqdSalpIcBDXjOrIzDtOZmz3RB76cg0TH5vLd+t051xEROq+qs5TFmJm/wb2Aj8By4G9Zvagd7Z+kRrXIjaCpy8dyLSrBlNU7LjkhR+47c0fSc3K93dpIiIix6yqPWUPApcB1+FZPLwLcD2eKTEeqN7SRCp38gmJfPm70dwypjMzlu9k7MOzefn7TRSX6EEAERGpe6o6eexO4Grn3IwK+8/AM/t+y2qu74g0pkwANqRlc99HK5i/bg992sTyj1/1pnebWH+XJSIiUk51jimLxTMnWUXrgbgqXkuk2nRs1ohXfzuUxy7qx47MfM55ah5/+mgF+/IL/V2aiIjIUalqKPsJuOUQ+2/1HhPxGzPjnH6tmXnHSVw+rD2vLEhh7MOz+WjpNs1tJiIiAa+qty9H45nFfxuwwLt7GNAKmOicm1ftFR6Bbl/K4SzfmskfP1zOsq2ZnNgxgZvHdObETgmYaS1NERHxj8puXx7LPGWtgBvxrIMJsApPULvNOXfh8RR6LBTKpDLFJY7XF27msa/XsDu7gJ6tGjN5VEfO6NOS0OAqzwgjIiJyXKo1lB3mDfoCS5xztT6Lp0KZHI38wmI+/HEbz8/dwPq0HFrGRnD1iA5cNKQtMRGazUVERGqHQpmIV0mJ45tfUnl+7gYWbEgnJjyEi4e2Y9LwJFrFRfq7PBERqecqC2UhtV2MiD8FBRljuzdnbPfmLNuawfNzN/LivI1MnbeRs/q24ppRHejZSlNpiIhI7VNPmTR4W9JzeWn+Jt5ctJncgmJGdE5g8qiOnNS1mR4KEBGRanXcty/NbPoRmjQGRimUSV2WmVvI6ws389L8jaRm7eeE5jFcM6oDZ/drpUXPRUSkWlRHKHvpaN7IOXdVFWs7bgplUt0KikqY/tN2Xpi7gdU7s0iMCWfSiCQuHdKe2Cg9FCAiIseuxgf6+5NCmdQU5xxz1+7m+bkbmLt2N1FhwVw4qC2/HdmBtvFR/i5PRETqIIUykeO0cvs+Xpi7gek/bafEOSb2bsmUUR3p2zbO36WJiEgdolAmUk12ZOYxbf4mXv9hM1n7ixjSIZ4pozoyplsiQUF6KEBERCqnUCZSzbLyC3lr0RamztvI9sx8OjaLZvKojpzbvzURoXooQEREDk2hTKSGFBaXMGP5Dp6fu4EV2/aREB3GFScmcfmJ7YmPDvN3eSIiEmAUykRqmHOO7zfs4fk5G/jmlzQiQoM4f2AbfjuyIx2aRvu7PBERCRCa0V+khpkZwzs1ZXinpqzZlcULczfw9qKtvPbDZsb1aM6U0R0Z2D7e32WKiEgAU0+ZSA1Jzcrn5e9SeGVBCpl5hfRvF8eUUR0Z17MFwXooQESkQdLtSxE/yi0o4p3krbwwbwNb0vNoFx/FNaM6cP7ANkSFqbNaRKQhUSgTCQDFJY4vft7Jc3M2sHRLBnFRoVw+rD1XnJhEs5hwf5cnIiK1QKFMJIA451icspfn5mzgq1W7CA0K4tcDWnPNqA50Tozxd3kiIlKDNNBfJICYGYOS4hmUFM+GtGxenLe
Rdxdv5c1FWxjTLZHJozoyrGM8Zhp3JiLSkNRaT5mZtQVeBpoDDnjOOffYYdoOBr4HLnLOvVvZddVTJvXBnuz9vLIghZe/TyE9p4DerWOZPLojp/dqQUhwkL/LExGRahIQty/NrCXQ0jm3xMxigMXAr5xzKyu0Cwa+AvKBqQpl0pDkFxbz3pKtvDB3Ixt359A6LpKrR3bgN4Pb0ihcHdsiInVdZaGs1v4J7pzb4Zxb4t3OAlYBrQ/R9GbgPSC1tmoTCRQRocFcOrQ9M28/ieevGETruEj+9slKTnxgJg98toqdmfn+LlFERGqIXwb6m1kSMAfo5Zzb57O/NfA6cAowFfjkUD1lZjYFmALQrl27gSkpKbVRtohf/Lh5Ly/M3chnK3YQZMbZ/VoxeVRHurds7O/SRESkigLi9qVPMY2A2cA/nHPvVzj2DvCwc26BmU3jMKHMl25fSkOxeU8uU+dv5K1FW8grLGZUl6ZMGd2RkZ2b6qEAEZE6ImBCmZmFAp8AXzjnHjnE8Y1A6d8uTYFcYIpz7sPDXVOhTBqajNwCXvthM9O+20Ra1n66tYhh8qiOnN67JZFhwf4uT0REKhEQocw8/5T/H5DunLvtKNpPQz1lIoe1v6iYj5Zu5/k5G1ibmk1YSBDDOyUwtlsip3RLpE2TKH+XKCIiFQRKKBsJzAWWAyXe3X8A2gE4556t0H4aCmUiR+Sc4/v1e/hq1S5mrU4lZU8uACc0j2FM90TGdkukf7smWm9TRCQABEQoqykKZSIHOOfYsDuHWatSmbl6F8mb9lJU4oiLCuXkrs0Y0705J3VpRmxUqL9LFRFpkBTKRBqozLxC5q5NY9aqVL5dk0Z6TgHBQcbA9k0Y2y2RMd0S6ZzYSA8KiIjUEoUyEaG4xLF0SwazVu9i1uo0Vu3wzEbTNj6Ssd2ac0q3RIZ2iCciVA8LiIjUFIUyETnI9ow8Zq1O5ZvVqcxbt5v9RSVEhQUzsnNTxngfFmjeOMLfZYqI1CsKZSJSqbyCYr7fsJtZq1OZtSqV7d6VA3q3juWUbp6HBXq3jiVIDwuIiBwXhTIROWrOOVbvzPIEtNWpLNm8F+egaaNwTjmhGWO7JzKySzOtxSkicgwUykTkmKXnFDB7TSozV6Uye00aWflFhAYbQzskMKZbImO7J9I+IdrfZYqI1AkKZSJSLQqLS1icspdZq1OZuWoX69NyAOjYLNr7NGdzBiU1ITQ4yM+ViogEJoUyEakRKXtyym5z/rAhnYLiEmIiQhjdtRljuyVy8gmJxEeH+btMEZGAoVAmIjUue38R89buZtbqXXzzSxppWfsxg/5t4xjbvTljuiXSrUWM5kQTkQZNoUxEalVJiWPF9kxmrkrlm19SWbY1E4CWsRFl49BO7NhUC6iLSIOjUCYifpW6L59vfvHc5py7dje5BcWEhwQxwjsn2phuibSKi/R3mSIiNU6hTEQCxv6iYn7YkO55WGD1Lrak5wHQrUUMY7t7Alq/tlpAXUTqJ4UyEQlIzjnWp2V7n+ZMJTllL8UljiZRoZxygmdVgdFdmxEbqQXURaR+UCgTkTohM7eQOWvTPMs//ZJKRm4hwUHG4KQmDOuYQJfEGLo0b0RSQjRhIZp2Q0TqHoUyEalzPAuo72XmKs9YtF92ZVH66yo4yGifEEXnZo3o0rwRXRJj6JzYiE7NGunhAREJaAplIlLn5RUUs2F3NutSs1m7y/s9NYuUPbkUlXh+j5lB67hIuiQ2onOiN6w192w3jtAtUBHxv8pCmRavE5E6ITIsmJ6tYunZKrbc/oKiElL25LA2tTSoeb7PX7+HgqKSsnbNG4eXBbVOiY3o4v1KaBRe2z+KiMghKZSJSJ0WFhJEl+YxdGkeU25/cYlj695c1u46ENTWpWbxTvIWcgqKy9o1iQotF9Q6J3puibZoHKGJbkWkVimUiUi95Bl3Fk37hGhO7dG8bL9zjh2Z+eWC2rrUbD5bsYM3cgvL2jUKDykf1Ly
9bK2bRGq6DhGpEQplItKgmBmt4iJpFRfJSV2ble13zrEnp8A7Xi2r7FbonDVpvLt4a1m78JAgOjU7ENRKe9baJ0RrIXYROS4KZSIieMJa00bhNG0UzomdEsody8wtZF1a1oGHDNKyWZyyl+k/bS9rExJkJDWNLgtqpV+dmjUiIlRPhIrIkSmUiYgcQWxUKAPbxzOwfXy5/Tn7i9iQlsNan5611Tuz+OLnnXgfCMUM2sV7pu/o3LyRdxoPzxQejcL1K1hEDtBvBBGRYxQdHkLvNrH0blP+idD9RcVs3J1Trmdt3a5s5qxNo7D4wDRELWMjynrUOjaNJj46nLioUO9XGHGRoUSFBeuBA5EGQqFMRKSahYcE061FY7q1aFxuf1FxCZvTc30eMvDMtfbmwi3kFRYf8lphwUHERoUSF+kJa7GRYTTxCW6xkaE0iQrzHjuwP1phTqTOUSgTEaklIcFBdGzWiI7NGjG+54H9JSWO1Kz9ZOQVkJFb6P0qICPPs52ZV8DenEIy8grYujeXn7cXsje3gPzCksO+V2iwERvpCWueQOe77fs6rFzvnMKciP8olImI+FlQkNEiNoIWsRFVOi+/sJhMb3Dbm1tQFuAycgu9ge5AyNuWkcfP2zPJyC08bK8ceB5Y8L19Wto7FxcVSpOoUGJ99jfx9tTFRYXSKDxEYU7kOCmUiYjUURGhwUSEBtO8cdXD3L68Qvb69Mhl5np64vbmFpYLd9sy8lm5fR8ZeYXkFhw5zMVG+ga68r1zEaHBhIUEERrs+fJsG2HBvvuMsOBgQkPswD5vW80PJ/WdQpmISANTGuYSqxjm9hcVe8NbxVusB3rnMr29djsy81m9M4uM3IJyKygcjyCjLKSFlga6kAPBzRPirCzwle3ztg33CYSec8wnHHrale476BrB5j3u094nUJaGxiAzzPB8x/P0rXoQ5WgplImIyFEJDwkmsfExhrm8QvYXllBQXEJhcQmFRe7AdnEJBUXe78WOwrLt0v3uQDuftoVFrmxfoU/bguISsvcXlXufsnOKD7TxXRu1JpmB4QlqQd4XQT7B7cA+I8gb4kq/lx4vC3pG2bZv8PNtQ+l7BYHhcy3f9+AQ1/J9bzjEPs9r35+rbJtyL3y/edtaxcOHuMYR2h/m/Q5/jcrbl8/KnhfDOsZzTr/W+ItCmYiI1KjwkGASYwJvAl3nHEUlrnx4Ky4pFwpLA2FBke8xd3BA9LYFz4MbDihxzjNfnfe7w/O9xDlwB4670n1l5zjvPk+NzrctntfOHbiW8/4sJSUH3uNAm9JaKl7Lu10CxZQcuJbPtR2l1zxwrQN/dj5/jhX+TCvu4whtD97v294dvK/cxavhej5tYyNDOaffoa9fGxTKRESkQTKzstudhPm7GhGotYXazKytmX1jZivN7Gczu/UQbS41s2VmttzMvjOzvrVVn4iIiIg/1WZPWRFwh3NuiZnFAIvN7Cvn3EqfNhuBk5xze81sIvAcMLQWaxQRERHxi1oLZc65HcAO73aWma0CWgMrfdp853PKAqBNbdUnIiIi4k+1dvvSl5klAf2BHypp9lvgs8OcP8XMks0sOS0trQYqFBEREaldtR7KzKwR8B5wm3Nu32HanIInlN11qOPOueecc4Occ4OaNWtWc8WKiIiI1BJzh3u2tCbezCwU+AT4wjn3yGHa9AE+ACY659YcxTXTgJRqLfTQmgK7a+F9pGbo86v79BnWffoM6zZ9ftWjvXPukD1KtRbKzDMT3P+AdOfcbYdp0w6YBVxRYXyZ35lZsnNukL/rkGOjz6/u02dY9+kzrNv0+dW82nz6cgRwObDczJZ69/0BaAfgnHsWuB9IAJ72zuZbpP8AREREpCGozacv51F+BYRDtbkGuKZ2KhIREREJHH55+rKOes7fBchx0edX9+kzrPv0GdZt+vxqWK0O9BcRERGRQ1NPmYiIiEgAUCgTERERCQAKZSIiIiIBQKFMREREJAAolImIiIg
EAIUyERERkQCgUCYiIiISABTKRERERAKAQpmIiIhIAFAoExEREQkACmUiIiIiAUChTERERCQAKJSJiIiIBACFMhEREZEAoFAmIiIiEgAUykREREQCgEKZiIiISABQKBMREREJAAplIiIiIgFAoUxEREQkACiUiYiIiAQAhTIRERGRAKBQJiIiIhIAFMpEREREAkCIvws4Xk2bNnVJSUn+LkNERETkiBYvXrzbOdfsUMfqfChLSkoiOTnZ32WIiIiIHJGZpRzumG5fioiIiAQAhTIRERGRAKBQJiIiIhIA6vyYsppWWFzCf75aw1UjOtAsJtzf5YiISA0rLCxk69at5Ofn+7sUqcMiIiJo06YNoaGhR32OQtkR/LQlg+fnbuCVBSn8fvwJXDK0PcFB5u+yRESkhmzdupWYmBiSkpIw0+97qTrnHHv27GHr1q106NDhqM/T7csjGJQUz+e3jaZ361ju++hnfv30fJZvzfR3WSIiUkPy8/NJSEhQIJNjZmYkJCRUubdVoewodGrWiNeuGcpjF/Vje2Y+5zw1jz99tIJ9+YX+Lk1ERGqAApkcr2P5b0ih7CiZGef0a83MO07i8mHteWVBCmMems1HS7fhnPN3eSIiUk9kZGTw9NNPH9O5p59+OhkZGcd07qRJk3j33XcP2v/tt99y5plnHtM1j+Sf//xnjVy3Mvfffz9ff/11pW1OPvnkQ86BOm3aNG666aaaKk2hrKoaR4Tyl3N68dGNI2kVF8Gtby7l0hd+YF1qtr9LExGReqCyUFZUVFTpuTNmzCAuLq4Gqjo2R6q3tkNZcXExf/3rXzn11FNr9X2PlkLZMerdJpYPbhjB337Vi+XbMpn42Bwe+uIX8guL/V2aiIjUYXfffTfr16+nX79+3HnnnXz77beMGjWKs88+mx49egDwq1/9ioEDB9KzZ0+ee+65snOTkpLYvXs3mzZtonv37kyePJmePXsybtw48vLyAHj++ecZPHgwffv25bzzziM3N7fs/K+//ppBgwbRtWtXPvnkk4Nqy8nJ4eqrr2bIkCH079+fjz766KA2R1vv3XffTV5eHv369ePSSy8F4NVXX2XIkCH069ePa6+9luLi8n+nfv7551xwwQXl3qu0F+/6669n0KBB9OzZkz/96U/l/kzuuusuBgwYwDvvvFOuR/Cvf/0rgwcPplevXkyZMqXcna9XXnmFfv360atXLxYuXHjQz5mWlsZ5553H4MGDGTx4MPPnzz/4w6wiPX15HIKDjMuHtWdCzxY8MGMVT36zjo9+2sZfz+7FKd0S/V2eiIgcp798/DMrt++r1mv2aNWYP53V87DH//Wvf7FixQqWLl0KeILHkiVLWLFiRdmTfFOnTiU+Pp68vDwGDx7MeeedR0JCQrnrrF27ljfeeIPnn3+eCy+8kPfee4/LLruMX//610yePBmAe++9lxdffJGbb74ZgE2bNrFw4ULWr1/PKaecwrp168pd8x//+Adjxoxh6tSpZGRkMGTIEE499VSio6PLtTuaev/1r3/x5JNPlv2cq1at4q233mL+/PmEhoZyww038Nprr3HFFVeUXffUU09lypQp5OTkEB0dzVtvvcVFF11UVlt8fDzFxcWMHTuWZcuW0adPHwASEhJYsmQJ4Al2pW666Sbuv/9+AC6//HI++eQTzjrrLAByc3NZunQpc+bM4eqrr2bFihXlfsZbb72V3/3ud4wcOZLNmzczfvx4Vq1addjP9WgolFWDZjHhPPKbflwwqC33fbSCq6YtYnzP5vzprJ60iov0d3kiIlLHDRkypNzUCo8//jgffPABAFu2bGHt2rUHhbIOHTrQr18/AAYOHMimTZsAWLFiBffeey8ZGRlkZ2czfvz4snMuvPBCgoKC6NKlCx07dmT16tXlrvnll18yffp0HnroIcDzpOrmzZvp3r37cdc7c+ZMFi9ezODBgwHIy8sjMbF8B0dISAgTJkzg448/5vzzz+fTTz/lwQc
fBODtt9/mueeeo6ioiB07drBy5cqyUPab3/zmkH+u33zzDQ8++CC5ubmkp6fTs2fPslB28cUXAzB69Gj27dt30Fi9r7/+mpUrV5a93rdvH9nZ2TRq1OiQ73U0FMqq0YmdEphxyyien7uBJ2at5dRHZnPr2C5cPbIDocG6UywiUtdU1qNVm3x7or799lu+/vprvv/+e6Kiojj55JMPOfVCePiBCc+Dg4PLbl9OmjSJDz/8kL59+zJt2jS+/fbbsnYVnxis+No5x3vvvccJJ5xQ7fU657jyyit54IEHKr32RRddxJNPPkl8fDyDBg0iJiaGjRs38tBDD7Fo0SKaNGnCpEmTyr1HxZ488ATKG264geTkZNq2bcuf//zncucc6c+ipKSEBQsWEBERUWm9VaGkUM3CQoK48ZTOfPW7kxjeKYEHPlvNmY/PY9GmdH+XJiIidUBMTAxZWVmHPZ6ZmUmTJk2Iiopi9erVLFiwoErXz8rKomXLlhQWFvLaa6+VO/bOO+9QUlLC+vXr2bBhw0Hha/z48TzxxBNlY69+/PHHI75fZfWGhoZSWOiZXmrs2LG8++67pKamApCenk5KSspB1zvppJNYsmQJzz//fNmty3379hEdHU1sbCy7du3is88+O2JdpQGsadOmZGdnH/Tk6VtvvQXAvHnziI2NJTY2ttzxcePG8cQTT5S9Lr0NezwUympI2/goXrhyMM9dPpDs/UVc8Oz33PnOT+zJ3u/v0kREJIAlJCQwYsQIevXqxZ133nnQ8QkTJlBUVET37t25++67GTZsWJWu/7e//Y2hQ4cyYsQIunXrVu5Yu3btGDJkCBMnTuTZZ589qBfovvvuo7CwkD59+tCzZ0/uu+++I75fZfVOmTKFPn36cOmll9KjRw/+/ve/M27cOPr06cNpp53Gjh07DrpecHAwZ555Jp999lnZIP++ffvSv39/unXrxiWXXMKIESOOWFdcXByTJ0+mV69ejB8/vuy2aamIiAj69+/Pddddx4svvnjQ+Y8//jjJycn06dOHHj168Oyzzx7xPY/E6vocW4MGDXKHmkskkOQWFPH4zHW8MHcDjSJCuGtCN34zqC1BWq5JRCTgrFq16qAxUiLH4lD/LZnZYufcoEO1V09ZLYgKC+Huid2YcesoujaP4Z73l3Pes9/x83Yt1yQiIiIeCmW1qGvzGN6aMoyHL+jL5j25nPXEPP768Uqy91c+uZ6IiIjUfwpltczMOG9gG2becRIXDWnHS99tZOzD3/Lpsh1arklERKQBUyjzk7ioMP55bm/ev344TRuFc+PrS7jypUVs2p3j79JERETEDxTK/Kx/uyZ8dOMI/nRWD5ak7GXco3P4z1drtFyTiIhIA6NQFgBCgoO4akQHZt1xEuN7tuCxmWuZ8Ogc5qxJ83dpIiIiUksUygJIYuMInri4P6/8dghmxhVTF3Lj60vYmXnwzMciIiLH4sMPPyy3PNCxWLp0KTNmzKimiqSUQlkAGtWlGZ/dOorbT+vKVyt3ceojs3lx3kaKikv8XZqIiNRxRxvKiooOPzOAQlnNUCgLUBGhwdwytgtf/W40A9s34W+frOTsJ+ezZPNef5cmIiI17NVXX2XIkCH069ePa6+9luLiYhYtWkSfPn3Iz88nJyeHnj17smLFCrKzsxk7diwDBgygd+/efPTRR2XXefnll+nTpw99+/bl8ssv57vvvmP69Onceeed9OvXj/Xr15d730mTJnHdddcxdOhQfv/737Nw4UJOPPFE+vfvz/Dhw/nll18oKCjg/vvv56233qJfv3689dZb5OTkcPXVVzNkyBD69+9frgY5eprRvw5wzvH5ip385eOV7NyXz8VD2nLXhG7ERYX5uzQRkXqn3Czsn90NO5dX7xu06A0T/1Xp+//+97/n/fffJzQ0lBtuuIFhw4ZxxRVXcO+995Kfn09eXh5t2rThnnvuoaioiNzcXBo3bszu3bsZNmwYa9euZeXKlZx77rl
89913NG3alPT0dOLj45k0aRJnnnkm559//kHvPWnSJHbv3s1HH31EcHAw+/btIyoqipCQEL7++mueeeYZ3nvvPaZNm0ZycjJPPvkkAH/4wx/o0aMHl112GRkZGQwZMoQff/zxkAuBNyRVndE/pFaqOlDIBOAxIBh4wTl30H+VZnYh8GfAAT855y6pzRoDkZkxsXdLRnVtxqNfreGl7zbxxc+7uGdiN84f2OagletFRKTumjlzJosXLy5bizEvL4/ExEQA7r//fgYPHkxERASPP/444PmH+x/+8AfmzJlDUFAQ27ZtY9euXcyaNYsLLriApk2bAhAfH39U73/BBRcQHBwMeBYTv/LKK1m7di1mVrZ4eEVffvkl06dP56GHHgI8i31v3rxZy1VVUa2FMjMLBp4CTgO2AovMbLpzbqVPmy7APcAI59xeM0usrfrqgkbhIdx7Zg/OG9iGP36wnDvfXcbbyVv4+696c0KLGH+XJyJS/1TSo1VTnHNceeWVPPDAAwcd27NnD9nZ2RQWFpKfn090dDSvvfYaaWlpLF68mNDQUJKSksjPP/YHxHx7t+677z5OOeUUPvjgAzZt2sTJJ5982Jrfe+89TjjhhGN+X6ndMWVDgHXOuQ3OuQLgTeCcCm0mA0855/YCOOdSa7G+OqN7y8a8e91w/u+83qxNzeaMx+fywIxV5Gi5JhGROm/s2LG8++67pKZ6/gpMT08nJSUFgGuvvZa//e1vXHrppdx1112ApzcrMTGR0NBQvvnmm7K2Y8aM4Z133mHPnj1l1wGIiYkhKyvrqGrJzMykdevWAEybNq1sf8VrjB8/nieeeKJsZZoff/zxWH/8Bq02Q1lrYIvP663efb66Al3NbL6ZLfDe7jyImU0xs2QzS05La5hzeQUFGb8Z3I5Zd5zMeQPa8N85Gzjtkdl8vmKnlmsSEanDevTowd///nfGjRtHnz59OO2009ixYwcvv/wyoaGhXHLJJdx9990sWrSIWbNmcemll5KcnEzv3r15+eWX6datGwA9e/bkj3/8IyeddBJ9+/bl9ttvB+Ciiy7i3//+N/379z9ooH9Fv//977nnnnvo379/uacxTznlFFauXFk20P++++6jsLCQPn360LNnT+67776a+wOqx2ptoL+ZnQ9McM5d4319OTDUOXeTT5tPgELgQqANMAfo7ZzLONx1G8JA/6ORvCmdez9cweqdWYzplshfzu5J2/gof5clIlLnHGpwtsixqOpA/9rsKdsGtPV53ca7z9dWYLpzrtA5txFYA3SppfrqtEFJ8Xx880juPaM7Czbs4dRHZvPkrLXsL9JyTSIiInVBbYayRUAXM+tgZmHARcD0Cm0+BE4GMLOmeG5nbqjFGuu00OAgrhnVkZl3nMSYbok89OUaJj42l+/W7fZ3aSIiInIEtRbKnHNFwE3AF8Aq4G3n3M9m9lczO9vb7Atgj5mtBL4B7nTO7amtGuuLlrGRPHPZQF66ajBFxY5LXviBW9/8kdQsLdckIiISqDR5bD2XX1jM09+s49nZGwgPDeLO8Sdw6dD2BAdpbjMRkUNZtWoV3bp10xyQclycc6xevTpgx5SJH0SEBnP7uBP4/LZR9GkTy/0f/cyvnprPsq0Z/i5NRCQgRUREsGfPHj3JLsfMOceePXuIiIio0nnqKWtAnHN8vGwHf/tkJbuz93PZ0Pb8v/EnEBsZ6u/SREQCRmFhIVu3bj2uCVhFIiIiaNOmDaGh5f+OraynTKGsAdqXX8gjX67h5e83ER8dxh/P6M6v+rVWV72IiEgN0+1LKadxRCh/Prsn028aSesmUfzurZ+44NnvWbQp3d+liYiINFgKZQ1Yr9axvH/9cP55bm9S0nO54Nnv+e20Razeuc/fpYmIiDQ4un0pAOQWFPHS/E08O3s92fuLOLdfa353WletCiAiIlKNNKZMjlpGbgHPzF7PtPmbKHGOS4e256YxnWnaKNzfpYmIiNR5CmVSZTsy83h
85lreTt5KeIhnpYDJozoQE6EnNUVERI6VQpkcs/Vp2Tzy5Ro+Xb6D+OgwbjylM5cObUdEaLC/SxMREalzFMrkuP20JYN/f/EL89btpnVcJLed2oVfD2ijlQFERESqQFNiyHHr2zaOV68Zyqu/HUpCozDufHcZEx6dw5c/79Ss1yIiItVAoUyqZGSXpnx04wievnQAxSWOKa8s5tfPfMeCDVo3XkRE5HgolEmVmRmn927Jl78bzb9+3ZsdGflc9NwCrpy6kJ+3Z/q7PBERkTpJY8rkuOUXFvO/7zbx9Lfrycwr5Oy+rbj9tK4kNY32d2kiIiIBRQP9pVZk5hXy3Jz1vDhvI0XFjouGtOWWMV1IbBzh79JEREQCgkKZ1KrUffk8Pmstby7cQmhwEFePTGLK6E7ERmqOMxERadgUysQvNu3O4ZGv1jD9p+3ERoZyw8mduHJ4kuY4ExGRBkuhTPxqxbZM/v3FL8xek0aLxhHcdmoXzh/YhpBgPWciIiINi+YpE7/q1TqW/109hDcmD6NlXAR3v7+ccY/OYcbyHZrjTERExEuhTGrNiZ0SeP/64Tx3+UCCzbjhtSWc89R85q/b7e/SRERE/E6hTGqVmTGuZws+v200/z6/D7uz9nPpCz9w2Qs/sGxrhr/LExER8RuNKRO/yi8s5rUfNvPkrLXszS3kjN4tuX1cVzo1a+Tv0kRERKqdBvpLwMvKL+T5uRt5Ye4G9heVcOGgNtw6tistYjXHmYiI1B8KZVJn7M7ez5Oz1vHaDykEmTFpRBLXn9SJuKgwf5cmIiJy3BTKpM7Zkp7Lf75awwdLtxETHsJ1J3fiquEdiAzTHGciIlJ3VWsoM7MI4EygE/Bf51yGmXUC9jrn0o+72ipSKKvfVu/cx0Nf/MLXq1JJjAnnlrFd+M3gtoRqjjMREamDqi2UmVln4CsgBogDujrnNpjZQ0Ccc+6aaqi3ShTKGoZFm9L5v89Wk5yyl6SEKO4YdwJn9G5JUJD5uzQREZGjVp2Txz6KJ5Q1B/J89k8HTjmKQiaY2S9mts7M7j7E8UlmlmZmS71ftR7yJDANTornnetOZOqkQUSEBnPzGz9y1pPzmL0mTRPQiohIvRBSxfbDgWHOuWKzcj0Um4FWlZ1oZsHAU8BpwFZgkZlNd86trND0LefcTVWsSxoAM2NMt+ac1DWR6T9t4+Ev13Dl1IUM6xjPXRO60b9dE3+XKCIicsyOZWBO6CH2tQMyj3DeEGCdc26Dc64AeBM45xjeXxq44CDj3P5tmHXHyfzl7J6sS83m3Ke/49pXklmXmuXv8kRERI5JVUPZl8DtPq+dmTUG/gJ8eoRzWwNbfF5v9e6r6DwzW2Zm75pZ20NdyMymmFmymSWnpaVVoXypT8JCgrhyeBKz7zyF20/ryvx1exj3nznc+c5PbMvIO/IFREREAkhVB/q3Ar7xvuwI/Ah0BnYBo51zh01IZnY+MKH0YQAzuxwY6nur0swSgGzn3H4zuxb4jXNuTGU1aaC/lErPKeDpb9bx8vcpYHDFsPbccEpn4qM1x5mIiASG6p4SIxK4GBiAp6dtCfCac67SrgkzOxH4s3NuvPf1PQDOuQcO0z4YSHfOxVZ2XYUyqWhbRh6PfrWG95ZsJSoshCmjO/LbkR2IDq/qEEoREZHqVZ1TYowGvnPOFVXYHwIMd87NqeTcEGANMBbYBiwCLnHO/ezTpqVzbod3+1zgLufcsMpqUiiTw1m7K4uHvvyFL37eRXx0GJcNbcdlw9qT2FhLN4mIiH9UZygrBlo651Ir7E8AUp1zlU63bman45lWIxiY6pz7h5n9FUh2zk03sweAs4EiIB243jm3urJrKpTJkSzZvJenv1nHzNWphAQZZ/VpxdUjO9CrdaWdsCIiItWuOkNZCdC84tgxM+uKJ1g1Pq5Kj4FCmRytjbtz+N93m3g7eQu5BcUM6RDP1SM6cFqP5gRrEloREakFxx3KzGy6d/MM4Gtgv8/hYKA
XsMo5N+E4a60yhTKpqsy8Qt5J3sJL8zexLSOPNk0imTQ8iQsHt6VxxKFmfBEREake1RHKXvJuXgm8TfnZ/AuATcDzzrndx1dq1SmUybEqKi7h61W7mDpvEws3pRMdFswFg9oyaXgSSU2j/V2eiIjUQ9V5+/JPwEPOuZzqKu54KZRJdVi+NZOX5m/k42XbKSpxjO3WnKtHJnFixwQqrF4hIiJyzKp1SoxAo1Am1Sl1Xz6vLkjh1R82k55TQLcWMVw9ogNn92tFRGilz7GIiIgcUXXPU3YVnnnK2gHlZuV0znU81iKPlUKZ1IT8wmKmL93O1PkbWb0zi4ToMC7VlBoiInKcKgtlVVpmyczuBB4GFgNJwIfACiAemHpcVYoEkIjQYC4c3JbPbh3F69cMpX+7OJ74Zh0j/m8Wt7+1lBXbjrTUq4iISNVUdUzZGuAPzrl3zSwL6Ouc22Bm9wHtnHOTa6rQw1FPmdQWTakhIiLHqzoH+ucC3Zxzm80sFRjnnFtqZp2Bhc65+Oop+egplElt05QaIiJyrKrt9iWwE2jq3U4BTvRudwbq9hMDIkcpNjKUa0Z1ZPadJ/PsZQNoFRvJ3z9dxYn/nMmfp//Mpt0B83CyiIjUIVVdoXkWnmWQlgAvAv8xswvxLE7+djXXJhLQQoKDmNCrJRN6tSybUuO1H1L43/ebGNstkatHdODETppSQ0REjk5Vb18GAUGlC5Kb2W+AEXgWGv/IObelRqqshG5fSiDRlBoiIlKZGp2nzMxaAPcBVzvnIo/rYsdAoUwCkabUEBGRQznuMWVmFmdmr5lZmpltN7NbzONPwHpgKHB1NdYsUqdpSg0REamqox1T9k9gNPA/YALwH+A0IBo43Tk3u2bKE6nbzIzhnZsyvHPTclNqvP/jNk2pISIi5RztguQpwG+dc1+bWUdgHfC4c+62Gq7viHT7UuoaTakhItJwHfeYMjMrBNo757Z7X+cCg51zP1drpcdAoUzqqqLiEr5etYup8zaxcFM60WHBXDCoLZOGJ5HUNNrf5YmISA2oLJQd7e3LIKDQ53UxkHu8hYk0ZJpSQ0REfB1tT1kJ8BWw37trIjCbCsHMOXd2dRd4JOopk/pEU2qIiNRv1XH78qWjeSPn3FVVrO24KZRJfaQpNURE6qcanafM3xTKpD5zzvH9+j1Mnb+RmatTCQkyzurTiqtHdqBX61h/lyciIlVUHWPKRMQPNKWGiEjDoZ4ykTrmcFNqnNu/NQmNwv1dnoiIVEK3L0XqoYpTagQZDOuYwOm9WzKhVwuaKqCJiAQchTKRem7Vjn18umwHM5bvYMPuHIIMhnZI4PQ+LRnfszmJMXo4QEQkECiUiTQQzjlW78xixvIdfLp8BxvScjCDIUnxnNHH04OmgCYi4j8KZSINkHOONbuy+XS5pwdtXWo2ZjA4KZ4zerdkYq8Wml5DRKSWKZSJCGt2ZZXd4lxbGtDax3N67xZM7N2S5gpoIiI1LmBCmZlNAB4DgoEXnHP/Oky784B38ayvWWniUigTqbq1u7LKetDW7PIEtIHtmnB675ac3rslLWIV0EREakJAhDIzCwbWAKcBW4FFwMXOuZUV2sUAnwJhwE0KZSI1a11qFjOW72TG8h2s3pkFwMD2pQGtBS1jI/1coYhI/REooexE4M/OufHe1/cAOOceqNDuUTzrbN4J/D+FMpHasz4tmxnLPA8JlAa0Ae3iynrQWsUpoImIHI9ACWXnAxOcc9d4X18ODHXO3eTTZgDwR+fceWb2LYcJZWY2BZgC0K5du4EpKSm18SOINCgb0rK9T3HuZNWOfQD0axvneUigdwvaNInyc4UiInVPnQhlZhYEzAImOec2VRbKfKmnTKTmbdydwwzvGLSft3sCWt+2cZzRuwUTe7WkbbwCmojI0QiUUFbp7UsziwXWA9neU1oA6cDZlQUzhTKR2rVpdw4zVngC2opt3oDWJrbsFqcCmojI4QVKKAvBM9B/LLANz0D
/S5xzPx+m/beop0wkoKXsySl7SGD5tkwAerf2BLQzerekXYICmoiIr4AIZd5CTgcexTMlxlTn3D/M7K9AsnNueoW236JQJlJnbN6TW9aDtmyrJ6D1at24LKC1T4j2c4UiIv4XMKGsJiiUiQSeLem5fLbC85DAT1syAOjZ6kBAS2qqgCYiDZNCmYj4zda9uXy2fCefLt/BUm9A696yMWf0bsHpvVvSsVkj/xYoIlKLFMpEJCBsy8jjM+9i6T9uzgCgW4sYzujdktP7tKSTApqI1HMKZSIScLZn5PHZCs9DAotT9gKegFb6FGfnRAU0Eal/FMpEJKDtyMzjM+9TnMnegNa1eaOyMWhdmsf4uUIRkeqhUCYidcbOzHw+8z7FmZyyF+egS2IjxnZvzsD2TRjQLo6ERuH+LlNE5JgolIlInbRrXz6fr/A8JLAkZS9FJZ7fVx2aRtO/XRwD2zdhYPsmdEmMITjI/FytiMiRKZSJSJ2XX1jMsq2ZLE7Zy5LNe1mSspc9OQUANAoPoX+7OAa0a8KA9k3o1zaO2MhQP1csInKwykJZSG0XIyJyLCJCgxnSIZ4hHeIBcM6xOT2XxSl7vUEtgydmraXEgRl0TYxhQPsDQa1j02jM1JsmIoFLPWUiUm9k5Rfy05ZMlmzeW9ajlpVfBECTqNCygDagXRP6to0lKkz/LhWR2qWeMhFpEGIiQhnZpSkjuzQFoKTEsT4t26c3bS8zV6cCEBxk9GjZmAHt4sqCWpsmkepNExG/UU+ZiDQoe3MK+HHLXpakZLA4ZS9Lt2SQV1gMQGJMuPcJT0+PWq/WjQkPCfZzxSJSn6inTETEq0l0GGO6NWdMt+YAFBWXsHpnVrlbnp+t2AlAWHAQvVo3LnvKc0C7JiQ2jvBn+SJSj6mnTESkgtR9+Z4nPDd7etOWb82koLgEgDZNIssC2sD2TejWIoaQ4CA/VywidYWmxBAROQ77i4r5efs+lnjHpi1O2Utq1n4AIkOD6ds2tqw3rX/bJjSJDvNzxSISqBTKRESqkXOObRl5LE7Zy4/e3rSVO/ZR7J3ctmOzaAZ6x6UNbN+Ezs0aEaTJbUUEhTIRkRqXW1B0YHJb79i0vbmFAMREhNC/XRMGem959m0bS0yEJrcVaYg00F9EpIZFhYUwrGMCwzomAJ7etI27c8omtl2SspdHZ67BeSe3PaF5TLmxae0TojQdh0gDp54yEZFasi+/kKXe251LNu9l6eYMsvZ7JreNiwolKSGa9glRtIuPom18FO3jo2iXEEXzmAjd/hSpJ9RTJiISABpHhDK6azNGd20GQHGJY21qFotT9rJiWyYpezzLRn3803ZKfP69HBYSRNsmkbSLj6J9QjRt46O821G0bRJFZJjmUhOpDxTKRET8JDjI6NaiMd1aNC63v7C4hG1789icnnvga08uKem5LNyYTk5Bcbn2zWLCPb1qpT1s3t62dglRNGsUrtuiInWEQpmISIAJDQ4iqWk0SU2jDzrmnGNvbiEpe3LYnJ7LlvRcUvZ4gtuCDXv4YOk2fEelRIQGeQJafLT3e2RZb1ubJpFEhKqXTSRQKJSJiNQhZkZ8dBjx0WH0b9fkoOP5hcVsy8gr613z7Wmbv2532ZJSnmtBi8YRB8aveXvX2nm346PD1MsmUosUykRE6pGI0GA6NWtEp2aNDjrmnGN3dgGb03O8QS2PlPQctqTnMntNWtmEuKUahYd4x69Flh/LFh9Fq7hIwkK0koFIdVIoExFpIMyMZjHhNIsJZ2D7+IOO5xUUs3XvgduhpV/r03L45pc0CopKytoGGbSMjSz/tGhCaWiLJjZK87CJVJVCmYiIABAZFkyX5jF0aR5z0LGSEkdq1n6f26E5Zdtfr9rF7uyCcu0bR4TQLsET0Ep72JrFhBMXFUpcZCixUaHERoYSHqIxbSKlFMpEROSIgoKMFrERtIiNYEiHg3vZcvYXlYU034cPVu3Yx5crd1JYfOg5MaPCgr0hLYy
4yFBPaIsKJTYyrCzAxUWF0jgylLjSfVGhRIYGa7yb1DsKZSIictyiw0Po3rIx3Vs2PuhYcYlj57589uYUkJFbSEae53tmXiEZuaX7CsnMLWR9WrbndW4hBcUlh3gnj7DgIGJ9QlvFEFcu5HmPxUaFEhMeojAnAatWQ5mZTQAeA4KBF5xz/6pw/DrgRqAYyAamOOdW1maNIiJSvYKDjNZxkbSOizzqc5xz5BeWlAU4T4g7EODKvc4tZFtGHiu3Z5KRV0huhXncKtYSG3ngFqonuIV59vm+9tmOi/T01AVrVQWpYbUWyswsGHgKOA3YCiwys+kVQtfrzrlnve3PBh4BJtRWjSIiEhjMjMiwYCLDImkZe/RhDmB/UTGZ3p630gCXkVvg7Zkr31O3O7uAdd7euaz8okqv2zgixBPSvOPh4nx642IjQ4mJCCEiNJjwkCDCvd8jKvkeEmTqtZNyarOnbAiwzjm3AcDM3gTOAcpCmXNun0/7aCAwFub87G7YudzfVYiIyFEIBxK9X5UKAhp5vwCHo6jEUVTsKCop8X4/sF1c4igqLqEoy1GU6dOu5PB/VRV4v7IOV4IZQVb63bNtQQfvKztmRtAxHjfzjA00wPO/cpAWvWHiv47crobUZihrDWzxeb0VGFqxkZndCNwOhAFjDnUhM5sCTAFo165dtRcqIiINj2GEBhmhQeAZZXN0HJ7AVlziKHFQ4pz3y7PtSg7eV+Icznmeaj3UOSXOc73CkpKDjjvnjrvHomKY84Q2IyjoMGHOu88MT6jzbnv+3A68Lo16vq8Nb1vfc8F73Oc6Pq+N0vYV3hOf69TDYBlwA/2dc08BT5nZJcC9wJWHaPMc8BzAoEGDar43zY+pWUREApvh+cu0Nv9CLSouYX9RCfmFxUf1ff9RtjvwvYT9RcXs934vfV1U4sot4+VPQeYZIxgcZIQEBXm/e3oKQ8r2+74OKtsf7HPcd/vkJolc5sefqTb/G9oGtPV53ca773DeBJ6p0YpERETqoJDgIEKCg4gOr/2+lZISR7G3J6+o5EAvYVFJifcWr6dXr/RY6a1fzzkl5V4XlTiKi32uc4g2vtcoKvFeu9jbrrSNzzWKShwlZbWVlKvxoJqLHfmFxWX7M/MKa/3P01dtfpqLgC5m1gFPGLsIuMS3gZl1cc6t9b48A1iLiIiIBIygICMIQ2vZV79aC2XOuSIzuwn4As/N+qnOuZ/N7K9AsnNuOnCTmZ0KFAJ7OcStSxEREZH6qFb7PZ1zM4AZFfbd77N9a23WIyIiIhIogvxdgIiIiIgolImIiIgEBIUyERERkQBgLlAmHDlGZpYGpNTCWzUFdtfC+0jN0OdX9+kzrPv0GdZt+vyqR3vnXLNDHajzoay2mFmyc26Qv+uQY6PPr+7TZ1j36TOs2/T51TzdvhQREREJAAplIiIiIgFAoezoPefvAuS46POr+/QZ1n36DOs2fX41TGPKRERERAKAespEREREAoBCmYiIiEgAUCg7AjObYGa/mNk6M7vb3/VI1ZhZWzP7xsxWmtnPZqb1VesgMws2sx/N7BN/1yJVZ2ZxZvauma02s1VmdqK/a5KqMbPfeX+HrjCzN8wswt811UcKZZUws2DgKWAi0AO42Mx6+LcqqaIi4A7nXA9gGHCjPsM66VZglb+LkGP2GPC5c64b0Bd9lnWKmbUGbgEGOed6AcHARf6tqn5SKKvcEGCdc26Dc64AeBM4x881SRU453Y455Z4t7Pw/GXQ2r9VSVWYWRvgDOAFf9ciVWdmscBo4EUA51yBcy7Dr0XJsQgBIs0sBIgCtvu5nnpJoaxyrYEtPq+3or/Q6ywzSwL6Az/4uRSpmkeB3wMlfq5Djk0HIA14yXsL+gUzi/Z3UXL0nHPbgIeAzcAOINM596V/q6qfFMqkQTCzRsB7wG3OuX3+rkeOjpmdCaQ65xb7uxY5ZiHAAOAZ51x/IAfQ+Nw6xMya4LlL1AFoBUSb2WX+rap+Uii
r3Dagrc/rNt59UoeYWSieQPaac+59f9cjVTICONvMNuEZPjDGzF71b0lSRVuBrc650h7qd/GENKk7TgU2OufSnHOFwPvAcD/XVC8plFVuEdDFzDqYWRiegY3T/VyTVIGZGZ6xLKucc4/4ux6pGufcPc65Ns65JDz//5vlnNO/0OsQ59xOYIuZneDdNRZY6ceSpOo2A8PMLMr7O3UselijRoT4u4BA5pwrMrObgC/wPG0y1Tn3s5/LkqoZAVwOLDezpd59f3DOzfBfSSINzs3Aa95/3G4ArvJzPVIFzrkfzOxdYAmeJ9p/REsu1QgtsyQiIiISAHT7UkRERCQAKJSJiIiIBACFMhEREZEAoFAmIiIiEgAUykREREQCgEKZiEg1MTNnZuf7uw4RqZsUykSkXjCzad5QVPFrgb9rExE5Gpo8VkTqk6/xTBbsq8AfhYiIVJV6ykSkPtnvnNtZ4Ssdym4t3mRmn5pZrpmlVFxU2cx6m9nXZpZnZune3rfYCm2uNLPlZrbfzHaZ2f8q1BBvZu+YWY6ZbdDCzSJytBTKRKQh+Que9Wv74Vkm5mUzGwRgZtF4llTLBoYA5+JZdHlq6clmdi3wX+AloA9wOrCiwnvcD3wE9AXeAqaaWbsa+4lEpN7QMksiUi+Y2TTgMiC/wqGnnHN3mZkDXnDOTfY552tgp3PuMjObDDwEtHHOZXmPnwx8A3Rxzq0zs63Aq865uw9TgwP+5Zy7x/s6BNgHTHHOvVp9P62I1EcaUyYi9ckcYEqFfRk+299XOPY9cIZ3uzuwrDSQeX0HlAA9zGwf0BqYeYQalpVuOOeKzCwNSDyq6kWkQVMoE5H6JNc5t64GrluVWwqFhzhXQ0VE5Ij0i0JEGpJhh3i9yru9CuhtZjE+x4fj+T25yjmXCmwDxtZ4lSLSIKmnTETqk3Aza1FhX7FzLs27/WszWwR8C5yPJ2AN9R57Dc+DAC+b2f1AEzyD+t/36X37B/AfM9sFfApEAWOdcw/X1A8kIg2HQpmI1CenAjsq7NsGtPFu/xk4D3gcSAOucs4tAnDO5ZrZeOBRYCGeBwY+Am4tvZBz7hkzKwDuAP4PSAdm1NDPIiINjJ6+FJEGwftk5AXOuXf9XYuIyKFoTJmIiIhIAFAoExEREQkAun0pIiIiEgDUUyYiIiISABTKRERERAKAQpmIiIhIAFAoExEREQkACmUiIiIiAeD/A3ZIgfSfTk4iAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot to see the convergence of the estimated and true parameters\n", "tensor_exact_value = tf.constant(exact_value, shape=[len(train_rate_results)])\n", "\n", "fig, ax = plt.subplots(2, sharex=True)\n", "fig.suptitle('Convergence')\n", "\n", "ax[0].set_ylabel('Loss', fontsize=14)\n", "ax[0].plot(train_loss_results)\n", "\n", "ax[1].set_ylabel('Rate', fontsize=14)\n", "ax[1].set_xlabel('Epoch', fontsize=14)\n", "ax[1].plot(train_rate_results, label='trainable rate variable')\n", "ax[1].plot(tensor_exact_value, label='exact rate')\n", "ax[1].legend()\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# A function get_data that:\n", "# 1) Fetches the 20 Newsgroups dataset for the given categories\n", "# 2) Performs a word count on the articles and binarizes the result\n", "# 3) Returns the data as numpy matrices together with the labels\n", "\n", "def get_data(categories):\n", "\n", "    newsgroups_train_data = fetch_20newsgroups(data_home='./dataset/20_Newsgroup_Data/',\n", "                                               subset='train', categories=categories)\n", "    newsgroups_test_data = fetch_20newsgroups(data_home='./dataset/20_Newsgroup_Data/',\n", "                                              subset='test', categories=categories)\n", "\n", "    n_documents = len(newsgroups_train_data['data'])\n", "    count_vectorizer = CountVectorizer(input='content', binary=True, max_df=0.25, min_df=1.01/n_documents)\n", "    train_binary_bag_of_words = count_vectorizer.fit_transform(newsgroups_train_data['data'])\n", "    test_binary_bag_of_words = count_vectorizer.transform(newsgroups_test_data['data'])\n", "\n", "    return (train_binary_bag_of_words.todense(), newsgroups_train_data['target']), (test_binary_bag_of_words.todense(), newsgroups_test_data['target'])" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# A function that applies Laplace smoothing.
This adds a base level of probability for a given feature\n", "# to occur in every class.\n", "\n", "def laplace_smoothing(labels, binary_data, n_classes):\n", "    # Compute the parameter estimates (adjusted fraction of documents in class that contain word)\n", "    n_words = binary_data.shape[1]\n", "    alpha = 1  # parameter for Laplace smoothing\n", "    theta = np.zeros([n_classes, n_words])  # stores parameter values - prob. word given class\n", "    for c_k in range(n_classes):  # 0, 1, ..., n_classes - 1\n", "        class_mask = (labels == c_k)\n", "        N = class_mask.sum()  # number of articles in class\n", "        theta[c_k, :] = (binary_data[class_mask, :].sum(axis=0) + alpha) / (N + alpha * 2)\n", "\n", "    return theta" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# Next, a function that, given the smoothed feature probabilities, returns a Bernoulli distribution with\n", "# batch_shape=number of classes and event_shape=number of features.\n", "\n", "def make_distributions(probs):\n", "    batch_of_bernoullis = tfd.Bernoulli(probs=probs)  # shape (n_classes, n_words)\n", "    dist = tfd.Independent(batch_of_bernoullis, reinterpreted_batch_ndims=1)\n", "    return dist" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# A function that computes the prior probability of every class based on its frequency of occurrence in\n", "# the dataset\n", "\n", "def class_priors(n_classes, labels):\n", "    counts = np.zeros(n_classes)\n", "    for c_k in range(n_classes):\n", "        counts[c_k] = np.sum(np.where(labels == c_k, 1, 0))\n", "    priors = counts / np.sum(counts)\n", "    print('The class priors are {}'.format(priors))\n", "    return priors" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "# The final function, predict_sample, takes the distribution, a test sample, and the class priors, and:\n", "# 1) Computes the class-conditional log probabilities of the sample\n", "# 2) Forms the joint
likelihood\n", "# 3) Normalises the joint likelihood and returns the log prob\n", "\n", "def predict_sample(dist, sample, priors):\n", "    cond_probs = dist.log_prob(sample)\n", "    joint_likelihood = tf.add(np.log(priors), cond_probs)\n", "    norm_factor = tf.math.reduce_logsumexp(joint_likelihood, axis=-1, keepdims=True)\n", "    log_prob = joint_likelihood - norm_factor\n", "\n", "    return log_prob" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "# Now we learn the distribution using gradient tape\n", "\n", "def make_distribution_withGT(data, labels, nb_classes):\n", "\n", "    class_data = []\n", "    train_vars = []\n", "    distributions = []\n", "    for c in range(nb_classes):\n", "        train_vars.append(tf.Variable(initial_value=np.random.uniform(low=0.01, high=0.1, size=data.shape[-1])))\n", "        distributions.append(tfd.Bernoulli(probs=train_vars[c]))\n", "        class_mask = (labels == c)\n", "        class_data.append(data[class_mask, :])\n", "\n", "    for c_num in range(nb_classes):\n", "        optimizer = tf.keras.optimizers.Adam()\n", "        print('\\n%-------------------%')\n", "        print('Class ', c_num)\n", "        print('%-------------------%')\n", "\n", "        for i in range(100):\n", "            loss, grads = get_loss_and_grads(class_data[c_num], distributions[c_num])\n", "            if i % 10 == 0:\n", "                print(\"iter: {}, Loss: {}\".format(i, loss))\n", "            optimizer.apply_gradients(zip(grads, distributions[c_num].trainable_variables))\n", "\n", "        # Clip the learned probabilities to [eta, 1] before building the final distribution\n", "        eta = 1e-3\n", "        clipped_probs = tf.clip_by_value(distributions[c_num].trainable_variables,\n", "                                         clip_value_min=eta, clip_value_max=1)\n", "        train_vars[c_num] = tf.squeeze(clipped_probs)\n", "\n", "    dist = tfd.Bernoulli(probs=train_vars)\n", "    dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)\n", "\n", "    print(dist)\n", "\n", "    return dist\n" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The class priors are [0.2359882 0.28711898
0.29154376 0.18534907]\n" ] } ], "source": [ "# Build the same Naive Bayes classifier as in the last tutorial\n", "\n", "categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']\n", "\n", "(train_data, train_labels), (test_data, test_labels) = get_data(categories)\n", "\n", "smoothed_counts = laplace_smoothing(labels=train_labels, binary_data=train_data, n_classes=len(categories))\n", "\n", "priors = class_priors(n_classes=len(categories), labels=train_labels)\n", "tf_dist = make_distributions(smoothed_counts)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "%-------------------%\n", "Class 0\n", "%-------------------%\n", "iter: 0, Loss: 0.07864662925861293\n", "iter: 10, Loss: 0.06923920529302693\n", "iter: 20, Loss: 0.060484430932711934\n", "iter: 30, Loss: 0.052378958091660745\n", "iter: 40, Loss: 0.044884874447401975\n", "iter: 50, Loss: 0.03795957675935009\n", "iter: 60, Loss: 0.03156166893228506\n", "iter: 70, Loss: 0.025648909539426928\n", "iter: 80, Loss: 0.020179556307287093\n", "iter: 90, Loss: 0.015095419046706705\n", "\n", "%-------------------%\n", "Class 1\n", "%-------------------%\n", "iter: 0, Loss: 0.07162433404608501\n", "iter: 10, Loss: 0.06226791554203955\n", "iter: 20, Loss: 0.053458417310592254\n", "iter: 30, Loss: 0.04525733720016119\n", "iter: 40, Loss: 0.03764243521857365\n", "iter: 50, Loss: 0.0305879746963904\n", "iter: 60, Loss: 0.02407123784997883\n", "iter: 70, Loss: 0.018063326103411242\n", "iter: 80, Loss: 0.012530662501078415\n", "iter: 90, Loss: 0.007417711392007358\n", "\n", "%-------------------%\n", "Class 2\n", "%-------------------%\n", "iter: 0, Loss: 0.07864916432960509\n", "iter: 10, Loss: 0.06954586738662134\n", "iter: 20, Loss: 0.061138999776087846\n", "iter: 30, Loss: 0.05346207920199955\n", "iter: 40, Loss: 0.046474524854562514\n", "iter: 50, Loss: 0.040144255228274424\n", "iter: 60, Loss:
0.03443733650612573\n", "iter: 70, Loss: 0.029299458896811317\n", "iter: 80, Loss: 0.024681602429558972\n", "iter: 90, Loss: 0.020525606754514734\n", "\n", "%-------------------%\n", "Class 3\n", "%-------------------%\n", "iter: 0, Loss: 0.07990305803193348\n", "iter: 10, Loss: 0.07064669667549849\n", "iter: 20, Loss: 0.062048971145070374\n", "iter: 30, Loss: 0.05407979822912391\n", "iter: 40, Loss: 0.04669298874331363\n", "iter: 50, Loss: 0.0398430034890397\n", "iter: 60, Loss: 0.033480511857230395\n", "iter: 70, Loss: 0.02756906082189042\n", "iter: 80, Loss: 0.022072121868228288\n", "iter: 90, Loss: 0.016940884899701955\n", "tfp.distributions.Independent(\"IndependentBernoulli\", batch_shape=[4], event_shape=[17495], dtype=int32)\n" ] } ], "source": [ "# Now train the distributions with gradient tape\n", "GT_dist = make_distribution_withGT(data=train_data, labels=train_labels, nb_classes=4)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "f1 0.8265056782070946\n", "f1 0.7848499112849504\n" ] } ], "source": [ "# Compare the macro F1 scores of the two classifiers (gradient-trained first, then Laplace-smoothed)\n", "\n", "for dist in [GT_dist, tf_dist]:\n", "    probabilities = []\n", "    for sample, label in zip(test_data, test_labels):\n", "        probabilities.append(predict_sample(dist, sample, priors))\n", "\n", "    probabilities = np.asarray(probabilities)\n", "    predicted_classes = np.argmax(probabilities, axis=-1)\n", "    print('f1 ', f1_score(test_labels, predicted_classes, average='macro'))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }