{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Trainable Distributions\n", "\n", "> In this post, we will take a look at how to make the parameters of a distribution object trainable. This is the summary of the lecture \"Probabilistic Deep Learning with Tensorflow 2\" from Imperial College London.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow_probability, ICL]\n", "- image: images/tfd_trainable_distribution.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "tfd = tfp.distributions\n", "\n", "plt.rcParams['figure.figsize'] = (10, 6)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow Version: 2.5.0\n", "Tensorflow Probability Version: 0.13.0\n" ] } ], "source": [ "print(\"Tensorflow Version: \", tf.__version__)\n", "print(\"Tensorflow Probability Version: \", tfp.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Previously, we defined the mean and standard deviation of a distribution as fixed floating-point values. We can instead make them trainable variables and use an optimizer object to apply gradients obtained from a loss function and data."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "normal = tfd.Normal(loc=tf.Variable(0., name='loc'), scale=1.)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "normal.trainable_variables" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now `loc` is a trainable variable, so `normal` is a trainable distribution.\n", "\n", "For example, we can define a negative log-likelihood (NLL) loss for it:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def nll(X_train):\n", "    return -tf.reduce_mean(normal.log_prob(X_train))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here `normal.log_prob(X_train)` returns a tensor with the same shape as `X_train`, holding one log probability per data point. If we assume that the training data are IID (independent and identically distributed), the log probability of the whole dataset is the sum of the log probabilities of the individual data points. Taking the mean instead of the sum only rescales the loss by 1/N, so it has the same minimizer, and `tf.reduce_mean` turns the result into a scalar loss." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the training loop, we compute the loss and its gradients with respect to the trainable variables using `tf.GradientTape`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@tf.function\n", "def get_loss_and_grads(X_train):\n", "    with tf.GradientTape() as tape:\n", "        tape.watch(normal.trainable_variables)\n", "        loss = nll(X_train)\n", "    grads = tape.gradient(loss, normal.trainable_variables)\n", "    return loss, grads" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we can speed up the gradient computation with the `@tf.function` decorator, which traces the Python function into a TensorFlow computation graph. The explicit `tape.watch` call is optional, since `GradientTape` automatically watches trainable variables accessed inside its context.\n", "\n", "After that, we can write the training loop."
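, "\n", "\n",
"Before writing it, it may help to know what the loop should converge to. As a sketch of our own (not part of the lecture): with `scale` fixed at 1, the mean NLL is, up to an additive constant, the average of `(x - loc)**2 / 2` over the data. Its gradient with respect to `loc` is therefore `loc - mean(x)`, and the maximum likelihood estimate of `loc` is the sample mean. The NumPy-only snippet below applies plain gradient descent with that gradient by hand. It assumes synthetic data drawn from a Normal with `loc=2`, and `X_demo` and `loc_hat` are hypothetical names.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical demo data: 5000 draws from Normal(loc=2, scale=1)\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.normal(loc=2.0, scale=1.0, size=5000)\n",
"\n",
"loc_hat = 0.0          # same starting value as tf.Variable(0., name='loc')\n",
"learning_rate = 0.05   # same learning rate as the SGD optimizer used below\n",
"\n",
"for _ in range(200):\n",
"    # Gradient of the mean NLL w.r.t. loc when scale is fixed at 1\n",
"    grad = loc_hat - X_demo.mean()\n",
"    loc_hat -= learning_rate * grad\n",
"\n",
"print(loc_hat, X_demo.mean())  # the two values should agree closely\n",
"```\n",
"\n",
"Each step shrinks the gap to the sample mean by a factor of `1 - learning_rate`, so with a learning rate of 0.05 the estimate converges geometrically."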
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)\n", "\n", "for _ in range(num_steps):\n", " loss, grads = get_loss_and_grads(X_sample)\n", " optimizer.apply_gradients(zip(grads, normal.trainable_variables))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tutorial" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import CountVectorizer\n", "from sklearn.naive_bayes import BernoulliNB\n", "from sklearn.metrics import f1_score" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Define an exponential distribution\n", "exponential = tfd.Exponential(rate=0.3, name='exp')" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFlCAYAAADYqP0MAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWB0lEQVR4nO3df6xfZ30f8PenThMm2NrQWFWXH9hQV2s6qmS7DZrasaoDGhYpYRIUMyEFKZLXiaid0KS66xSYq0qGblUrLWvJSiRajbkpdJuluMqyQvdDXcAOBKiDMkzqEluMuIS2i6BJnXz2xz2wLzfX8de+38f31+slXd1znnPO9/u5R0fOO8/znHOquwMAwGJ923oXAACwFQlZAAADCFkAAAMIWQAAAwhZAAADCFkAAANctt4FrHTVVVf1rl271rsMAIDzevjhh/+ku3eutm2ukFVVNyf5lSQ7kvx6dx9csf0nk7wzyXNJnk6yr7sfnbb9bJI7pm0/1d0PvNh37dq1K8eOHZunLACAdVVVf3yubecdLqyqHUnuTvLGJNcneVtVXb9itw9196u7+4Yk70vyS9Ox1yfZm+QHktyc5N9OnwcAsKXNMyfrpiQnuvvx7n42yaEkt83u0N1/PrP60iTfeIz8bUkOdfcz3f1HSU5MnwcAsKXNM1x4dZInZtZPJXnNyp2q6p1J3pXk8iQ/NnPsQyuOvXqVY/cl2Zck11133Tx1AwBsaAu7u7C77+7uVyX5mST/4gKPvae7l7p7aefOVeeOAQBsKvOErNNJrp1Zv2ZqO5dDSd50kccCAGwJ84Sso0n2VNXuqro8yxPZD8/uUFV7ZlZvSfL5aflwkr1VdUVV7U6yJ8kn1l42AMDGdt45Wd19tqruTPJAlh/hcG93H6+qA0mOdffhJHdW1euS/GWSrya5fTr2eFXdl+TRJGeTvLO7nxv0twAAbBjV3eff6xJaWlpqz8kCADaDqnq4u5dW2+a1OgAAAwhZAAADCFkAAAMIWQAAAwhZAAADzPNanS1p1/77
X9B28uAt61AJALAV6ckCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABjgsvUuYCPZtf/+F7SdPHjLOlQCAGx2erIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGmCtkVdXNVfVYVZ2oqv2rbH9XVT1aVZ+pqt+rqlfMbHuuqh6Zfg4vsngAgI3qsvPtUFU7ktyd5PVJTiU5WlWHu/vRmd0+lWSpu79WVf8kyfuSvHXa9vXuvmGxZQMAbGzz9GTdlOREdz/e3c8mOZTkttkduvtj3f21afWhJNcstkwAgM1lnpB1dZInZtZPTW3nckeS351Zf0lVHauqh6rqTRdeIgDA5nPe4cILUVVvT7KU5O/NNL+iu09X1SuTfLSqPtvdX1hx3L4k+5LkuuuuW2RJAADrYp6erNNJrp1Zv2Zq+xZV9bokP5fk1u5+5hvt3X16+v14kt9PcuPKY7v7nu5e6u6lnTt3XtAfAACwEc0Tso4m2VNVu6vq8iR7k3zLXYJVdWOS92c5YD05035lVV0xLV+V5IeTzE6YBwDYks47XNjdZ6vqziQPJNmR5N7uPl5VB5Ic6+7DSX4xycuS/HZVJckXu/vWJN+f5P1V9XyWA93BFXclAgBsSXPNyeruI0mOrGi7a2b5dec47g+SvHotBQIAbEae+A4AMICQBQAwgJAFADDAQp+TtRXt2n//C9pOHrxlHSoBADYTPVkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAANctt4FbEa79t//graTB29Zh0oAgI1KTxYAwABCFgDAAEIWAMAAc4Wsqrq5qh6rqhNVtX+V7e+qqker6jNV9XtV9YqZbbdX1eenn9sXWTwAwEZ13pBVVTuS3J3kjUmuT/K2qrp+xW6fSrLU3T+Y5MNJ3jcd+/Ik707ymiQ3JXl3VV25uPIBADameXqybkpyorsf7+5nkxxKctvsDt39se7+2rT6UJJrpuUfT/Jgdz/V3V9N8mCSmxdTOgDAxjVPyLo6yRMz66emtnO5I8nvXsixVbWvqo5V1bEzZ87MURIAwMa20InvVfX2JEtJfvFCjuvue7p7qbuXdu7cuciSAADWxTwh63SSa2fWr5navkVVvS7JzyW5tbufuZBjAQC2mnlC1tEke6pqd1VdnmRvksOzO1TVjUnen+WA9eTMpgeSvKGqrpwmvL9hagMA2NLO+1qd7j5bVXdmORztSHJvdx+vqgNJjnX34SwPD74syW9XVZJ8sbtv7e6nqurnsxzUkuRAdz815C/ZgLx+BwC2r7neXdjdR5IcWdF218zy617k2HuT3HuxBQIAbEae+A4AMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMICQBQAwgJAFADDAZetdwFaxa//9F73fyYO3LLocAGCd6ckCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGEDIAgAYQMgCABhAyAIAGOCy9S6AZNf++1/QdvLg
LetQCQCwKHqyAAAGELIAAAYQsgAABhCyAAAGELIAAAaYK2RV1c1V9VhVnaiq/atsf21VfbKqzlbVm1dse66qHpl+Di+qcACAjey8j3Coqh1J7k7y+iSnkhytqsPd/ejMbl9M8o4k/2yVj/h6d9+w9lIBADaPeZ6TdVOSE939eJJU1aEktyX5Zsjq7pPTtucH1AgAsOnMM1x4dZInZtZPTW3zeklVHauqh6rqTavtUFX7pn2OnTlz5gI+GgBgY7oUE99f0d1LSf5Rkl+uqlet3KG77+nupe5e2rlz5yUoCQBgrHlC1ukk186sXzO1zaW7T0+/H0/y+0luvID6AAA2pXlC1tEke6pqd1VdnmRvkrnuEqyqK6vqimn5qiQ/nJm5XAAAW9V5Q1Z3n01yZ5IHknwuyX3dfbyqDlTVrUlSVT9UVaeSvCXJ+6vq+HT49yc5VlWfTvKxJAdX3JUIALAlzXN3Ybr7SJIjK9rumlk+muVhxJXH/UGSV6+xRgCATccT3wEABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGmOvdhVx6u/bf/4K2kwdvWYdKAICLoScLAGAAIQsAYAAhCwBgACELAGAAIQsAYAAhCwBgACELAGAAIQsAYAAhCwBgAE983wI8HR4ANh4haxNZLUwBABuT4UIAgAGELACAAQwXblHmaQHA+tKTBQAwgJAFADCAkAUAMICQBQAwgJAFADCAkAUAMIBHOGwjHusAAJeOniwAgAGELACAAYQsAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAAeYKWVV1c1U9VlUnqmr/KttfW1WfrKqzVfXmFdtur6rPTz+3L6pwAICN7Lwhq6p2JLk7yRuTXJ/kbVV1/YrdvpjkHUk+tOLYlyd5d5LXJLkpybur6sq1lw0AsLHN05N1U5IT3f14dz+b5FCS22Z36O6T3f2ZJM+vOPbHkzzY3U9191eTPJjk5gXUDQCwoc0Tsq5O8sTM+qmpbR5rORYAYNPaEBPfq2pfVR2rqmNnzpxZ73IAANZsnpB1Osm1M+vXTG3zmOvY7r6nu5e6e2nnzp1zfjQAwMY1T8g6mmRPVe2uqsuT7E1yeM7PfyDJG6rqymnC+xumNgCALe28Iau7zya5M8vh6HNJ7uvu41V1oKpuTZKq+qGqOpXkLUneX1XHp2OfSvLzWQ5qR5McmNoAALa0y+bZqbuPJDmyou2umeWjWR4KXO3Ye5Pcu4YaAQA2nQ0x8R0AYKsRsgAABphruJCta9f++1/QdvLgLetQCQBsLXqyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAbwMFLm4qGlAHBhhCxeYLVABQBcGMOFAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADXLbeBbB57dp//wvaTh68ZR0qAYCNR08WAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwAAeRspwHloKwHakJwsAYAAhCwBgAMOFLNRqQ4MAsB3pyQIAGEDIAgAYwHAh68IdhwBsdXqyAAAGELIAAAYQsgAABpgrZFXVzVX1WFWdqKr9q2y/oqp+a9r+8araNbXvqqqvV9Uj08+vLbh+AIAN6bwT36tqR5K7k7w+yakkR6vqcHc/OrPbHUm+2t3fW1V7k7w3yVunbV/o7hsWWzYAwMY2T0/WTUlOdPfj3f1skkNJbluxz21JPjgtfzjJ36+qWlyZAACbyzwh6+okT8ysn5raVt2nu88m+bMk3zVt211Vn6qq/1ZVf3e1L6iqfVV1rKqOnTlz5oL+AACAjWj0xPcvJbmuu29M8q4kH6qqv7Zyp+6+p7uXuntp586dg0sCABhvnpB1Osm1M+vXTG2r7lNV
lyX5jiRf6e5nuvsrSdLdDyf5QpLvW2vRAAAb3Twh62iSPVW1u6ouT7I3yeEV+xxOcvu0/OYkH+3urqqd08T5VNUrk+xJ8vhiSgcA2LjOe3dhd5+tqjuTPJBkR5J7u/t4VR1Icqy7Dyf5QJLfrKoTSZ7KchBLktcmOVBVf5nk+SQ/2d1PjfhD2Jq8fgeAzWqudxd295EkR1a03TWz/BdJ3rLKcR9J8pE11sg2sVqgAoDNyhPfAQAGELIAAAYQsgAABphrThZsJCbDA7AZ6MkCABhATxZbwrnuTNTDBcB60ZMFADCAniyIeV4ALJ6eLACAAfRksaXpoQJgvQhZcA4CGgBrYbgQAGAAIQsAYADDhWw753qmFgAskp4sAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAATzxHS6Al0YDMC89WQAAAwhZAAADGC6EAQwrAiBkwToSxgC2LiEL1mi1oAQAQhZsMHq3ALYGE98BAAYQsgAABjBcCJvUhQwrGoIEuPSELLhETJAH2F6ELNhCBDmAjUPIgk1gRHgyhAgwlonvAAAD6MkCvknvFsDiCFnAuhPugK1IyAKGEZ6A7UzIAi7YWsLTpbgDUrgDNgIhC9jWBDJgFCEL2JCEH2CzE7KAFzXv8N56DQMCbFRzhayqujnJryTZkeTXu/vgiu1XJPmNJH87yVeSvLW7T07bfjbJHUmeS/JT3f3AwqoHmJOABlxq5w1ZVbUjyd1JXp/kVJKjVXW4ux+d2e2OJF/t7u+tqr1J3pvkrVV1fZK9SX4gyV9P8l+r6vu6+7lF/yEAm8m8w6Ebbdh0o9UDG9k8PVk3JTnR3Y8nSVUdSnJbktmQdVuS90zLH07yb6qqpvZD3f1Mkj+qqhPT5/2vxZQPsHiL7vVadAiZt77NEH42a9hk49mI18g8IevqJE/MrJ9K8ppz7dPdZ6vqz5J819T+0Ipjr77oagE2ofWa17aWALPW77kUn3cpAtpa/7a1/Ef+UoSGEd+xEcPOetkQE9+ral+SfdPq01X12CX42quS/Mkl+J7tyLkdy/kdZ8uf23rvun798PM77993qc7Dor/nRT5vYed2xLlZr+tugd/7Yuf3Fec6aJ6QdTrJtTPr10xtq+1zqqouS/IdWZ4AP8+x6e57ktwzRy0LU1XHunvpUn7nduHcjuX8juPcjuX8juPcjnWx5/fb5tjnaJI9VbW7qi7P8kT2wyv2OZzk9mn5zUk+2t09te+tqiuqaneSPUk+caFFAgBsNuftyZrmWN2Z5IEsP8Lh3u4+XlUHkhzr7sNJPpDkN6eJ7U9lOYhl2u++LE+SP5vkne4sBAC2g7nmZHX3kSRHVrTdNbP8F0neco5jfyHJL6yhxlEu6fDkNuPcjuX8juPcjuX8juPcjnVR57eWR/UAAFikeeZkAQBwgbZdyKqqm6vqsao6UVX717ueraaqTlbVZ6vqkao6tt71bHZVdW9VPVlVfzjT9vKqerCqPj/9vnI9a9ysznFu31NVp6fr95Gq+gfrWeNmVVXXVtXHqurRqjpeVT89tbt2F+BFzq/rd42q6iVV9Ymq+vR0bv/l1L67qj4+ZYffmm4EPP/nbafhwukVQf87M68ISvK2Fa8IYg2q6mSSpe7e0s8aulSq6rVJnk7yG939N6e29yV5qrsPTv+jcGV3/8x61rkZnePcvifJ0939r9azts2uqr4nyfd09yer6q8meTjJm5K8I67dNXuR8/sTcf2uyfS2mpd299NV9e1J/meSn07yriS/092HqurXkny6u3/1fJ+33XqyvvmKoO5+Nsk3XhEEG1J3//cs37E767YkH5yWP5jlf1y5QOc4tyxAd3+puz85Lf/fJJ/L8ts+XLsL8CLnlzXqZU9Pq98+/XSSH8vyawOTC7h2t1vIWu0VQS7Mxeok/6WqHp6e5M/ifXd3f2la/j9Jvns9i9mC7qyqz0zDiYaz1qiqdiW5McnH49pd
uBXnN3H9rllV7aiqR5I8meTBJF9I8qfdfXbaZe7ssN1CFuP9SHf/rSRvTPLOaUiGQaaH/m6fMf/xfjXJq5LckORLSf71ulazyVXVy5J8JMk/7e4/n93m2l27Vc6v63cBuvu57r4hy2+puSnJ37jYz9puIWuu1/xw8br79PT7yST/McsXKIv15WlOxjfmZjy5zvVsGd395ekf2OeT/Lu4fi/aNJ/lI0n+fXf/ztTs2l2Q1c6v63exuvtPk3wsyd9J8p3TawOTC8gO2y1kzfOKIC5SVb10moSZqnppkjck+cMXP4qLMPsaq9uT/Od1rGVL+UYAmPzDuH4vyjR5+ANJPtfdvzSzybW7AOc6v67ftauqnVX1ndPyX8nyjXKfy3LYevO029zX7ra6uzBJpltafzn//xVBG/Fp9JtSVb0yy71XyfLbBD7k/K5NVf2HJD+a5TfAfznJu5P8pyT3JbkuyR8n+YnuNoH7Ap3j3P5olodaOsnJJP94Zg4Rc6qqH0nyP5J8NsnzU/M/z/K8IdfuGr3I+X1bXL9rUlU/mOWJ7Tuy3BF1X3cfmP77dijJy5N8Ksnbu/uZ837edgtZAACXwnYbLgQAuCSELACAAYQsAIABhCwAgAGELACAAYQsAIABhCwAgAGELACAAf4fpone9R8nIi4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot\n", "plt.hist(exponential.sample(5000).numpy(), bins=100, density=True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(,)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define an exponential distribution with a trainable rate parameter\n", "exponential_train = tfd.Exponential(rate=tf.Variable(1., name='rate'), name='exp_train')\n", "exponential_train.trainable_variables" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Define a negative log likelihood\n", "def nll(X_train, distribution):\n", " return -tf.reduce_mean(distribution.log_prob(X_train))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Define a function to compute loss and gradients\n", "@tf.function\n", "def get_loss_and_grads(X_train, distribution):\n", " with tf.GradientTape() as tape:\n", " tape.watch(distribution.trainable_variables)\n", " loss = nll(X_train, distribution)\n", " grads = tape.gradient(loss, distribution.trainable_variables)\n", " return loss, grads" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Optimize\n", "def exponential_dist_optimization(data, distribution):\n", " # Keep results for plotting\n", " train_loss_results = []\n", " train_rate_results = []\n", " \n", " optimizer = tf.keras.optimizers.SGD(learning_rate=0.05)\n", " \n", " num_steps = 10\n", " \n", " for i in range(num_steps):\n", " loss, grads = get_loss_and_grads(data, distribution)\n", " optimizer.apply_gradients(zip(grads, distribution.trainable_variables))\n", " \n", " rate_value = distribution.rate.value()\n", " train_loss_results.append(loss)\n", " train_rate_results.append(rate_value)\n", " \n", " print(\"Step {:03d}: Loss: {:.3f}: Rate: 
{:.3f}\".format(i, loss, rate_value))\n", " \n", " return train_loss_results, train_rate_results" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 000: Loss: 2.896: Rate: 0.669\n", "Step 001: Loss: 2.682: Rate: 0.573\n", "Step 002: Loss: 2.510: Rate: 0.490\n", "Step 003: Loss: 2.383: Rate: 0.422\n", "Step 004: Loss: 2.301: Rate: 0.370\n", "Step 005: Loss: 2.255: Rate: 0.335\n", "Step 006: Loss: 2.235: Rate: 0.314\n", "Step 007: Loss: 2.228: Rate: 0.303\n", "Step 008: Loss: 2.227: Rate: 0.297\n", "Step 009: Loss: 2.226: Rate: 0.295\n" ] } ], "source": [ "# Get some data and train\n", "sampled_data = exponential.sample(5000)\n", "train_loss_results, train_rate_results = exponential_dist_optimization(data=sampled_data, distribution=exponential_train)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Exact rate: 0.3\n", "Predicted rate: 0.29511935\n" ] } ], "source": [ "# Predicted value for the rate parameter\n", "\n", "pred_value = exponential_train.rate.numpy()\n", "exact_value = exponential.rate.numpy()\n", "\n", "print(\"Exact rate: \", exact_value)\n", "print(\"Predicted rate: \", pred_value)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot to see the convergence of the estimated and true parameters\n", "tensor_exact_value = tf.constant(exact_value, shape=[len(train_rate_results)])\n", "\n", "fig, ax = plt.subplots(2, sharex=True)\n", "fig.suptitle('Convergence')\n", "\n", "ax[0].set_ylabel('Loss', fontsize=14)\n", "ax[0].plot(train_loss_results)\n", "\n", "ax[1].set_ylabel('Rate', fontsize=14)\n", "ax[1].set_xlabel('Epoch', fontsize=14)\n", "ax[1].plot(train_rate_results, label='trainable rate variable')\n", "ax[1].plot(tensor_exact_value, label='exact rate')\n", "ax[1].legend()\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# Making a function get_data which:\n", "# 1) Fetches the 20 newsgroup dataset\n", "# 2) Performs a word count on the articles and binarizes the result\n", "# 3) Returns the data as a numpy matrix with the labels\n", "\n", "def get_data(categories):\n", "\n", " newsgroups_train_data = fetch_20newsgroups(data_home='./dataset/20_Newsgroup_Data/',\n", " subset='train', categories=categories)\n", " newsgroups_test_data = fetch_20newsgroups(data_home='./dataset/20_Newsgroup_Data/',\n", " subset='test', categories=categories)\n", "\n", " n_documents = len(newsgroups_train_data['data'])\n", " count_vectorizer = CountVectorizer(input='content', binary=True,max_df=0.25, min_df=1.01/n_documents) \n", " train_binary_bag_of_words = count_vectorizer.fit_transform(newsgroups_train_data['data']) \n", " test_binary_bag_of_words = count_vectorizer.transform(newsgroups_test_data['data']) \n", "\n", " return (train_binary_bag_of_words.todense(), newsgroups_train_data['target']), (test_binary_bag_of_words.todense(), newsgroups_test_data['target'])" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# Defining a function to conduct laplace smoothing. 
This adds a base level of probability for a given feature\n", "# to occur in every class.\n", "\n", "def laplace_smoothing(labels, binary_data, n_classes):\n", "    # Compute the parameter estimates (adjusted fraction of documents in class that contain word)\n", "    n_words = binary_data.shape[1]\n", "    alpha = 1  # Laplace smoothing parameter\n", "    theta = np.zeros([n_classes, n_words])  # stores parameter values - prob. word given class\n", "    for c_k in range(n_classes):  # 0, 1, ..., n_classes - 1\n", "        class_mask = (labels == c_k)\n", "        N = class_mask.sum()  # number of articles in class\n", "        # (documents in class containing the word + alpha) / (N + 2 * alpha): unseen words keep a nonzero probability\n", "        theta[c_k, :] = (binary_data[class_mask, :].sum(axis=0) + alpha)/(N + alpha*2)\n", "\n", "    return theta" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# Next, a function that, given the smoothed per-class word probabilities, returns a Bernoulli distribution with\n", "# batch_shape = number of classes and event_shape = number of features.\n", "\n", "def make_distributions(probs):\n", "    batch_of_bernoullis = tfd.Bernoulli(probs=probs)  # batch_shape (n_classes, n_words)\n", "    dist = tfd.Independent(batch_of_bernoullis, reinterpreted_batch_ndims=1)\n", "    return dist" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# Function which computes the prior probability of every class based on its frequency of occurrence in\n", "# the dataset\n", "\n", "def class_priors(n_classes, labels):\n", "    counts = np.zeros(n_classes)\n", "    for c_k in range(n_classes):\n", "        counts[c_k] = np.sum(np.where(labels==c_k, 1, 0))\n", "    priors = counts / np.sum(counts)\n", "    print('The class priors are {}'.format(priors))\n", "    return priors" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "# The final function predict_sample which, given the distribution, a test sample, and the class priors:\n", "# 1) Computes the class conditional probabilities given the sample\n", "# 2) Forms the joint 
likelihood\n", "# 3) Normalises the joint likelihood and returns the log prob\n", "\n", "def predict_sample(dist, sample, priors):\n", "    cond_probs = dist.log_prob(sample)\n", "    joint_likelihood = tf.add(np.log(priors), cond_probs)\n", "    norm_factor = tf.math.reduce_logsumexp(joint_likelihood, axis=-1, keepdims=True)\n", "    log_prob = joint_likelihood - norm_factor\n", "\n", "    return log_prob" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "# Now we learn the distribution parameters using gradient tape\n", "\n", "def make_distribution_withGT(data, labels, nb_classes):\n", "\n", "    class_data = []\n", "    train_vars = []\n", "    distributions = []\n", "    for c in range(nb_classes):\n", "        # One trainable vector of word probabilities per class, initialised in [0.01, 0.1)\n", "        train_vars.append(tf.Variable(initial_value=np.random.uniform(low=0.01, high=0.1, size=data.shape[-1])))\n", "        distributions.append(tfd.Bernoulli(probs=train_vars[c]))\n", "        class_mask = (labels == c)\n", "        class_data.append(data[class_mask, :])\n", "\n", "    for c_num in range(0, nb_classes):\n", "        optimizer = tf.keras.optimizers.Adam()\n", "        print('\\n%-------------------%')\n", "        print('Class ', c_num)\n", "        print('%-------------------%')\n", "\n", "        for i in range(0, 100):\n", "            loss, grads = get_loss_and_grads(class_data[c_num], distributions[c_num])\n", "            if i % 10 == 0:\n", "                print(\"iter: {}, Loss: {}\".format(i, loss))\n", "            optimizer.apply_gradients(zip(grads, distributions[c_num].trainable_variables))\n", "\n", "        # The variables are unconstrained during training, so clip them into [eta, 1 - eta]:\n", "        # a probability of exactly 0 or 1 would give a log probability of -inf.\n", "        eta = 1e-3\n", "        clipped_probs = tf.clip_by_value(distributions[c_num].trainable_variables,\n", "                                         clip_value_min=eta, clip_value_max=1 - eta)\n", "\n", "        train_vars[c_num] = tf.squeeze(clipped_probs)\n", "\n", "    dist = tfd.Bernoulli(probs=train_vars)\n", "    dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)\n", "\n", "    print(dist)\n", "\n", "    return dist\n" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The class priors are [0.2359882 0.28711898 
0.29154376 0.18534907]\n" ] } ], "source": [ "# Build the same Naive Bayes classifier as in the last tutorial (Laplace-smoothed estimates)\n", "\n", "categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']\n", "\n", "(train_data, train_labels), (test_data, test_labels) = get_data(categories)\n", "\n", "smoothed_counts = laplace_smoothing(labels=train_labels, binary_data=train_data, n_classes=len(categories))\n", "\n", "priors = class_priors(n_classes=len(categories), labels=train_labels)\n", "tf_dist = make_distributions(smoothed_counts)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "%-------------------%\n", "Class 0\n", "%-------------------%\n", "iter: 0, Loss: 0.07864662925861293\n", "iter: 10, Loss: 0.06923920529302693\n", "iter: 20, Loss: 0.060484430932711934\n", "iter: 30, Loss: 0.052378958091660745\n", "iter: 40, Loss: 0.044884874447401975\n", "iter: 50, Loss: 0.03795957675935009\n", "iter: 60, Loss: 0.03156166893228506\n", "iter: 70, Loss: 0.025648909539426928\n", "iter: 80, Loss: 0.020179556307287093\n", "iter: 90, Loss: 0.015095419046706705\n", "\n", "%-------------------%\n", "Class 1\n", "%-------------------%\n", "iter: 0, Loss: 0.07162433404608501\n", "iter: 10, Loss: 0.06226791554203955\n", "iter: 20, Loss: 0.053458417310592254\n", "iter: 30, Loss: 0.04525733720016119\n", "iter: 40, Loss: 0.03764243521857365\n", "iter: 50, Loss: 0.0305879746963904\n", "iter: 60, Loss: 0.02407123784997883\n", "iter: 70, Loss: 0.018063326103411242\n", "iter: 80, Loss: 0.012530662501078415\n", "iter: 90, Loss: 0.007417711392007358\n", "\n", "%-------------------%\n", "Class 2\n", "%-------------------%\n", "iter: 0, Loss: 0.07864916432960509\n", "iter: 10, Loss: 0.06954586738662134\n", "iter: 20, Loss: 0.061138999776087846\n", "iter: 30, Loss: 0.05346207920199955\n", "iter: 40, Loss: 0.046474524854562514\n", "iter: 50, Loss: 0.040144255228274424\n", "iter: 60, Loss: 
0.03443733650612573\n", "iter: 70, Loss: 0.029299458896811317\n", "iter: 80, Loss: 0.024681602429558972\n", "iter: 90, Loss: 0.020525606754514734\n", "\n", "%-------------------%\n", "Class 3\n", "%-------------------%\n", "iter: 0, Loss: 0.07990305803193348\n", "iter: 10, Loss: 0.07064669667549849\n", "iter: 20, Loss: 0.062048971145070374\n", "iter: 30, Loss: 0.05407979822912391\n", "iter: 40, Loss: 0.04669298874331363\n", "iter: 50, Loss: 0.0398430034890397\n", "iter: 60, Loss: 0.033480511857230395\n", "iter: 70, Loss: 0.02756906082189042\n", "iter: 80, Loss: 0.022072121868228288\n", "iter: 90, Loss: 0.016940884899701955\n", "tfp.distributions.Independent(\"IndependentBernoulli\", batch_shape=[4], event_shape=[17495], dtype=int32)\n" ] } ], "source": [ "# Now train the distributions with gradient tape\n", "GT_dist = make_distribution_withGT(data=train_data, labels=train_labels, nb_classes=4)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "f1 0.8265056782070946\n", "f1 0.7848499112849504\n" ] } ], "source": [ "# Compare the two classifiers: GT_dist (trained with gradient tape) first, then tf_dist (Laplace smoothing)\n", "\n", "for dist in [GT_dist, tf_dist]:\n", "    probabilities = []\n", "    for sample, label in zip(test_data, test_labels):\n", "        probabilities.append(predict_sample(dist, sample, priors))\n", "\n", "    probabilities = np.asarray(probabilities)\n", "    predicted_classes = np.argmax(probabilities, axis=-1)\n", "    print('f1 ', f1_score(test_labels, predicted_classes, average='macro'))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }