{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# The TransformedDistribution class\n", "\n", "> In this post, we take a look at transformed distribution objects. This is a summary of the lecture \"Probabilistic Deep Learning with TensorFlow 2\" from Imperial College London.\n", "\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- author: Chanseok Kang\n", "- categories: [Python, Coursera, Tensorflow_probability, ICL]\n", "- image: images/transformed_dist.png" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import tensorflow_probability as tfp\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "tfd = tfp.distributions\n", "tfpl = tfp.layers\n", "tfb = tfp.bijectors\n", "\n", "plt.rcParams['figure.figsize'] = (10, 6)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensorflow Version: 2.5.0\n", "Tensorflow Probability Version: 0.13.0\n" ] } ], "source": [ "print(\"Tensorflow Version: \", tf.__version__)\n", "print(\"Tensorflow Probability Version: \", tfp.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n", "\n", "A `TransformedDistribution` is a distribution defined by a base distribution and a bijector object. TensorFlow Probability offers the [TransformedDistribution](https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/TransformedDistribution) class with a consistent API, so it supports the same methods and properties as other distributions."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "normal = tfd.Normal(loc=0., scale=1.)\n", "z = normal.sample(3)\n", "z" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scale_and_shift = tfb.Chain([tfb.Shift(1.), tfb.Scale(2.)])\n", "x = scale_and_shift.forward(z)\n", "x" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_prob_z = normal.log_prob(z)\n", "log_prob_z" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_prob_x = (log_prob_z - scale_and_shift.forward_log_det_jacobian(z, event_ndims=0))\n", "log_prob_x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the `event_ndims` argument is the number of rightmost dimensions of `z` that make up the event shape. With `event_ndims=0`, each element of the tensor `z` is treated as a separate scalar event, so the log of the Jacobian determinant is calculated for each element." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, we can express it with the log-det-Jacobian of the inverse transformation, evaluated at `x`." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_prob_x = (log_prob_z + scale_and_shift.inverse_log_det_jacobian(x, event_ndims=0))\n", "log_prob_x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is the same when `z` is recovered from `x` with the inverse transformation."
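, "\n", "\n", "As a reference for the computations in this section, here is the change-of-variables formula they evaluate. The symbols $f$ (the bijector), $p_Z$ (base density), $p_X$ (transformed density) and $J$ (Jacobian) are notation introduced here for illustration, not names taken from the lecture. For $x = f(z)$ with $f$ invertible and differentiable:\n", "\n", "$$\\log p_X(x) = \\log p_Z(z) - \\log\\left|\\det J_f(z)\\right| = \\log p_Z\\left(f^{-1}(x)\\right) + \\log\\left|\\det J_{f^{-1}}(x)\\right|$$\n", "\n", "For the chain above, $f(z) = 2z + 1$, so $\\log\\left|\\det J_f(z)\\right| = \\log 2$ for every element, and $\\log p_X(x) = \\log p_Z(z) - \\log 2$."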
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_prob_x = (normal.log_prob(scale_and_shift.inverse(x)) + scale_and_shift.inverse_log_det_jacobian(x, event_ndims=0))\n", "log_prob_x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You may notice that the log probability of `x` can be calculated from either `z` or `x`. In practice, the last expression, which uses only `x`, is the one used in most cases. The reason is that `z` comes from the base distribution, so in terms of analysis it is a latent variable, whereas `x` comes from the data distribution and is the output of the transformed distribution, which is what we actually observe. With this approach, the transformation is expressed as a bijector (an invertible function), and its parameters can be learned by maximum likelihood." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "# Base distribution    Transformation    Data distribution\n", "#      z ~ P0      <=>    x = f(z)     <=>     x ~ P1\n", "\n", "# Training: log-likelihood of observed data x\n", "log_prob_x = (base_dist.log_prob(bijector.inverse(x)) + bijector.inverse_log_det_jacobian(x, event_ndims=0))\n", "\n", "# Sampling: push a base sample through the bijector\n", "x_sample = bijector.forward(base_dist.sample())\n", "```" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "normal = tfd.Normal(loc=0., scale=1.)\n", "z = normal.sample(3)\n", "z" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp = tfb.Exp()\n", "x = exp.forward(z)\n", "x" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "log_normal = tfd.TransformedDistribution(normal, exp)\n", "log_normal" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Above expression is same with like this," ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_normal = exp(normal)\n", "log_normal" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_normal.sample()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_normal.log_prob(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also define specific `event_shape` and `batch_shape` for transformedDistribtion." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "normal = tfd.Normal(loc=0., scale=1.)\n", "scale_tril = [[1., 0.], [1., 1.]]\n", "scale = tfb.ScaleMatvecTriL(scale_tril=scale_tril)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Multivariate Normal distribution\n", "mvn = tfd.TransformedDistribution(tfd.Sample(normal, sample_shape=[2]), scale)\n", "mvn" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scale_tril = [[[1., 0.], [1., 1.]], [[0.5, 0.], [-1., 0.5]]]\n", "scale = tfb.ScaleMatvecTriL(scale_tril=scale_tril)\n", "\n", "mvn = tfd.TransformedDistribution(tfd.Sample(tfd.Normal(loc=[0., 0.], scale=1.), sample_shape=[2], ), scale)\n", "mvn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tutorial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TransformedDistribution" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# Parameters\n", "\n", "n = 10000\n", "loc = 0\n", "scale = 0.5" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# Normal distribution\n", "\n", "normal = tfd.Normal(loc=loc, scale=scale)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "batch shape: ()\n", "event shape: ()\n" ] } ], "source": [ "# Display event and batch shape\n", "\n", "print('batch shape: ', normal.batch_shape)\n", "print('event shape: ', normal.event_shape)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# Exponential bijector\n", "exp = tfb.Exp()" ] }, { 
"cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# log normal transformed distribution using exp bijector and normal distribution\n", "\n", "log_normal_td = exp(normal)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "batch shape: ()\n", "event shape: ()\n" ] } ], "source": [ "# Display event and batch shape\n", "\n", "print('batch shape: ', log_normal_td.batch_shape)\n", "print('event shape: ', log_normal_td.event_shape)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# Base distribution\n", "\n", "z = normal.sample(n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plots" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# Plot z density" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlUAAAFlCAYAAADClB2CAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAXDUlEQVR4nO3dcayd530X8O9vzrJKW9kKudtK4tQBPDEzpnZc0k5FULYWnFqyC22HMw0a1GKGlq1sFcJlU6jCP24RmwRko1lXtUy0aQhseKpHKG2nAVor33ahnROymeAtDoV4bemYtjXz+uOPe1JOb8/1Pbafe++5934+kpXzvufxeX959N7r73ne533e6u4AAHB9vmK7CwAA2A2EKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABbtiuA99000194MCB7To8AMDcPvaxj/1mdy9dqc22haoDBw5kZWVluw4PADC3qvr1jdq4/AcAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADDAXKGqqg5X1eNVdb6qTs54/9aq+nBV/XJVfaKqXjm+VACAxbVhqKqqfUnuS3JHkkNJ7qyqQ2ua/UiSB7v7RUmOJ/nx0YUCACyyeUaqbk9yvruf6O5nkjyQ5NiaNp3kD01ef22S/zmuRACAxTdPqLo5yZNT2xcn+6a9Jcn3VNXFJGeSfP+sD6qqE1W1UlUrly5duoZyAQAW0w2DPufOJO/q7n9SVd+e5Ker6lu6+wvTjbr7/iT3J8ny8nIPOjbAFx04+f4v23fh1JFtqATYa+YJVU8l2T+1fctk37TXJzmcJN39S1X1nCQ3JXl6RJHA7iUEAbvFPKHqbJKDVXVbVsPU8STfvabNbyT5ziTvqqpvTvKcJK7vAcMIX8Ci2zBUdfflqro7ycNJ9iV5Z3efq6p7k6x09+kkb0ryk1X1g1mdtH5Xd7u8B1yTWQEKYNHNNaequ89kdQL69L57pl4/muSlY0s
DANg5rKgOADCAUAUAMMCoJRUAdjyT4YHrYaQKAGAAoQoAYAChCgBgAHOqgC1j/SlgNzNSBQAwgFAFADCAUAUAMIA5VQBXab25Yda0gr3NSBUAwABCFQDAAC7/AXvy8SyWdwBGM1IFADCAUAUAMIBQBQAwgDlVwI5lXhSwSIxUAQAMYKQK2BRGkYC9xkgVAMAAQhUAwABCFQDAAOZUATPtxVXWAa6HkSoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAdz9B1w3q6cDzDlSVVWHq+rxqjpfVSdnvP9jVfXI5M+vVtX/GV4pAMAC23Ckqqr2JbkvySuSXExytqpOd/ejz7bp7h+cav/9SV60CbUCACyseUaqbk9yvruf6O5nkjyQ5NgV2t+Z5L0jigMA2CnmCVU3J3lyavviZN+XqaoXJLktyYfWef9EVa1U1cqlS5eutlYAgIU1+u6/40ke6u4/mPVmd9/f3cvdvby0tDT40AAA22eeUPVUkv1T27dM9s1yPC79AQB70Dyh6mySg1V1W1XdmNXgdHpto6r6k0mel+SXxpYIALD4NgxV3X05yd1JHk7yWJIHu/tcVd1bVUenmh5P8kB39+aUCgCwuOZa/LO7zyQ5s2bfPWu23zKuLACAncVjagAABvCYGmBuHkcDsD4jVQAAAwhVAAADCFUAAAMIVQAAAwhVAAADCFUAAANYUgF2iVnLHVw4dWQbKlk8loIAtoKRKgCAAYxUAQxitBD2NiNVAAADCFUAAAMIVQAAAwhVAAADCFUAAAMIVQAAA1hSAWATWWYB9g4jVQAAAwhVAAADCFUAAAOYUwVwBR7GDMzLSBUAwABCFQDAAEIVAMAAQhUAwABCFQDAAEIVAMAAQhUAwABCFQDAAHOFqqo6XFWPV9X5qjq5TpvvqqpHq+pcVb1nbJkAAIttwxXVq2pfkvuSvCLJxSRnq+p0dz861eZgkjcneWl3f7aqvn6zCgbY6Wat0n7h1JFtqAQYaZ6RqtuTnO/uJ7r7mSQPJDm2ps3fSnJfd382Sbr76bFlAgAstnlC1c1JnpzavjjZN+2bknxTVf2XqvpIVR2e9UFVdaKqVqpq5dKlS9dWMQDAAho1Uf2GJAeTvCzJnUl+sqq+bm2j7r6/u5e7e3lpaWnQoQEAtt+Gc6qSPJVk/9T2LZN90y4m+Wh3/36S/1FVv5rVkHV2SJXANTF3B2DrzDNSdTbJwaq6rapuTHI8yek1bX42q6NUqaqbsno58IlxZQIALLYNQ1V3X05yd5KHkzyW5MHuPldV91bV0Umzh5N8uqoeTfLhJH+vuz+9WUUDACya6u5tOfDy8nKvrKxsy7FhN5p1qY+dzaVaWBxV9bHuXr5SGyuqAwAMMM9EdWAbmWwOsDMYqQIAGECoAgAYQKgCABhAqAIAGECoAgAYQKgCABhAqAIAGECoAgAYQKgCABhAqAIAGECoAgAYwLP/YAea9TxAALaXkSoAgAGEKgCAAVz+A1hQsy7zXjh1ZBsqAeZhpAoAYAChCgBgAKEKAGAAoQoAYAChCgBgAKEKAGAAoQoAYADrVMEC8fgZgJ1LqALYQSwICovL5T8AgAGEKgCAAYQqAIABhCoAgAHmClVVdbiqHq+q81V1csb7d1XVpap6ZPLnDeNLBQBYXBve/VdV+5Lcl+QVSS4mOVtVp7v70TVN39fdd29CjQAAC2+ekarbk5zv7ie6+5kkDyQ5trllAQDsLPOEqpuTPDm1fXGyb61XV9Unquqhqto/64Oq6kRVrVTVyqVLl66hXACAxTRqovrPJTnQ3d+a5ANJ3j2rUXff393L3b28tLQ06NAAANtvnlD1VJLpkadbJvu+qLs/3d2fn2y+I8mfGVMeAMDOME+oOpvkYFXdVlU3Jjme5PR0g6p6/tTm0SSPjSsRAGDxbXj3X3dfrqq7kzycZF+Sd3b3uaq6N8lKd59O8gNVdTTJ5SSfSXLXJtYMwBTPA4TFMNcDlbv7TJIza/bdM/X
6zUnePLY0AICdw4rqAAADzDVSBYw365INADuXkSoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABbtjuAmAvOHDy/dtdAgCbzEgVAMAARqoAdqFZo6MXTh3Zhkpg7zBSBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIAlFWAwC30C7E1zjVRV1eGqeryqzlfVySu0e3VVdVUtjysRAGDxbRiqqmpfkvuS3JHkUJI7q+rQjHbPTfLGJB8dXSQAwKKbZ6Tq9iTnu/uJ7n4myQNJjs1o94+SvDXJ7w2sDwBgR5gnVN2c5Mmp7YuTfV9UVd+WZH93X3EySVWdqKqVqlq5dOnSVRcLALCorvvuv6r6iiQ/muRNG7Xt7vu7e7m7l5eWlq730AAAC2OeUPVUkv1T27dM9j3ruUm+JckvVNWFJC9JctpkdQBgL5knVJ1NcrCqbquqG5McT3L62Te7+3PdfVN3H+juA0k+kuRod69sSsUAAAtow1DV3ZeT3J3k4SSPJXmwu89V1b1VdXSzCwQA2AnmWvyzu88kObNm3z3rtH3Z9ZcFALCzWFEdgC8x66kAF04d2YZKYGfx7D8AgAGEKgCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAASz+Cddh1iKJsKgs6gmby0gVAMAAQhUAwABCFQDAAEIVAMAAQhUAwABCFQDAAEIVAMAAQhUAwABCFQDAAFZUhzXWWyXdytPsRp4KAOMYqQIAGMBIFczJN3oArsRIFQDAAEIVAMAAQhUAwABCFQDAAEIVAMAAQhUAwABCFQDAAEIVAMAAc4WqqjpcVY9X1fmqOjnj/e+tqk9W1SNV9Z+r6tD4UgEAFteGoaqq9iW5L8kdSQ4luXNGaHpPd//p7n5hkrcl+dHRhQIALLJ5RqpuT3K+u5/o7meSPJDk2HSD7v6tqc2vTtLjSgQAWHzzPPvv5iRPTm1fTPLitY2q6vuS/FCSG5N8x6wPqqoTSU4kya233nq1tQIALKxhE9W7+77u/uNJ/n6SH1mnzf3dvdzdy0tLS6MODQCw7eYJVU8l2T+1fctk33oeSPKq66gJAGDHmSdUnU1ysKpuq6obkxxPcnq6QVUdnNo8kuTXxpUIALD4NpxT1d2Xq+ruJA8n2Zfknd19rqruTbLS3aeT3F1VL0/y+0k+m+R1m1k0AMCimWeierr7TJIza/bdM/X6jYPrAgDYUayoDgAwwFwjVQAwjwMn3z9z/4VTR7a4Eth6RqoAAAYQqgAABhCqAAAGMKcKgA2tN1cK+P+MVAEADCBUAQAMIFQBAAwgVAEADCBUAQAM4O4/9gx3LwGwmYxUAQAMIFQBAAwgVAEADCBUAQAMIFQBAAzg7j8ANt2su28vnDqyDZXA5jFSBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADCAUAUAMIBQBQAwgFAFADDAXKGqqg5X1eNVdb6qTs54/4eq6tGq+kRVfbCqXjC+VACAxbVhqKqqfUnuS3JHkkNJ7qyqQ2ua/XKS5e7+1iQPJXnb6EIBABbZPCNVtyc5391PdPczSR5Icmy6QXd/uLt/Z7L5kSS3jC0TAGCx3TBHm5uTPDm1fTHJi6/Q/vVJfv56igJg9ztw8v1ftu/CqSPbUAmMMU+omltVfU+S5SR/YZ33TyQ5kSS33nrryEMDAGyreS7/PZVk/9T2LZN9X6KqXp7kh5Mc7e7Pz/qg7r6/u5e7e3lpaela6gUAWEjzjFSdTXKwqm7Lapg6nuS7pxtU1YuSvD3J4e5+eniVcAWzLiEAwFbbcKSquy8nuTvJw0keS/Jgd5+rqnur6uik2T9O8jVJ/nVVPVJVpzetYgCABTTXnKruPpPkzJp990y9fvngugAAdhQrqgMADCBUAQA
MMHRJBRjJGjYA7CRGqgAABhCqAAAGEKoAAAYQqgAABhCqAAAGEKoAAAYQqgAABhCqAAAGsPgnO8qsBUEBYBEIVSwEYQmAnc7lPwCAAYQqAIABhCoAgAGEKgCAAYQqAIABhCoAgAGEKgCAAaxTBcDCmLVm3YVTR7ahErh6RqoAAAYQqgAABhCqAAAGMKcKgIVmnhU7hVAFwI4z70PYhS+2kst/AAADGKliy837DRPgerl0yFYyUgUAMIBQBQAwwFyhqqoOV9XjVXW+qk7OeP/PV9XHq+pyVb1mfJkAAIttw1BVVfuS3JfkjiSHktxZVYfWNPuNJHclec/oAgEAdoJ5JqrfnuR8dz+RJFX1QJJjSR59tkF3X5i894VNqBEAYOHNE6puTvLk1PbFJC/enHLYbdzpB8BesaUT1avqRFWtVNXKpUuXtvLQAACbap5Q9VSS/VPbt0z2XbXuvr+7l7t7eWlp6Vo+AgBgIc0Tqs4mOVhVt1XVjUmOJzm9uWUBAOwsG4aq7r6c5O4kDyd5LMmD3X2uqu6tqqNJUlV/tqouJnltkrdX1bnNLBoAYNHM9Zia7j6T5MyaffdMvT6b1cuCAAB7khXVAQAGEKoAAAaY6/IfzMOaVADsZUaqAAAGMFIFwJ4y76j6hVNHrvnz5v277C5GqgAABhCqAAAGcPmPDRnaBljlhhyuRKgCgBkEKK6Wy38AAAMYqeKa+AYHAF/KSBUAwABCFQDAAEIVAMAA5lTtUevNibJUAgBcGyNVAAADCFUAAAMIVQAAAwhVAAADCFUAAAO4+w8AtoC7rnc/oWqXmfVDezU/sB4/AwDXRqjaAwQlgMU17+/oWV+Qr/eLNGOZUwUAMIBQBQAwgFAFADCAOVULaN5r5OZKASymzfj97Hf+4hOqttD1/ED4YQKAxSZUbRIhCAD2FnOqAAAGmCtUVdXhqnq8qs5X1ckZ739VVb1v8v5Hq+rA8EoBABbYhpf/qmpfkvuSvCLJxSRnq+p0dz861ez1ST7b3X+iqo4neWuSv7YZBV+Nq7kEdz2LpbnUB8BOM3rhUAuRzjen6vYk57v7iSSpqgeSHEsyHaqOJXnL5PVDSf55VVV398Bat5wTBICdZiu+6M97jM34d3SR/22eJ1TdnOTJqe2LSV68XpvuvlxVn0vyR5L85ogit8L1nCAAsFv5d29+W3r3X1WdSHJisvnbVfX4Vh5/G9yUHRQst5i+mU2/rE/frE/frE/fzDazX+qtm3/gzTjG4M9c75x5wUZ/cZ5Q9VSS/VPbt0z2zWpzsapuSPK1ST699oO6+/4k989xzF2hqla6e3m761hE+mY2/bI+fbM+fbM+fTObflnf9fTNPHf/nU1ysKpuq6obkxxPcnpNm9NJXjd5/ZokH9rp86kAAK7GhiNVkzlSdyd5OMm+JO/s7nNVdW+Sle4+neSnkvx0VZ1P8pmsBi8AgD1jrjlV3X0myZk1++6Zev17SV47trRdYc9c6rwG+mY2/bI+fbM+fbM+fTObflnfNfdNuUoHAHD9PKYGAGAAoWqgqnptVZ2rqi9U1bp3DlTVhar6ZFU9UlUrW1njdrmKvrniI5F2m6r6w1X1gar6tcl/n7dOuz+YnC+PVNXaG0V2FY/FWt8cfXNXVV2aOlfesB11brWqemdVPV1Vv7LO+1VV/3TSb5+oqm/b6hq3yxx987Kq+tzUOXPPrHa7TVXtr6oPV9Wjk3+b3jijzVWfN0LVWL+S5K8m+cU52v7F7n7hHrqldcO+mXok0h1JDiW5s6oObU152+Zkkg9298EkH5xsz/K7k/Plhd19dOvK21pzngNffCxWkh/L6mOxdr2r+Pl439S58o4tLXL7vCvJ4Su8f0eSg5M/J5L8xBbUtCjelSv3TZL8p6lz5t4tqGkRXE7ypu4+lOQlSb5vxs/TVZ83QtVA3f1Yd+/2BU2vyZx988VHInX3M0mefSTSbnYsybsnr9+d5FXbV8pCmOccmO6zh5J8Z1X
VFta4Xfbiz8dcuvsXs3rn+XqOJfmXveojSb6uqp6/NdVtrzn6Zk/q7k9198cnr/9vksey+nSYaVd93ghV26OT/Ieq+thklXlWzXok0tqTfLf5hu7+1OT1/0ryDeu0e05VrVTVR6rqVVtT2raY5xz4ksdiJXn2sVi73bw/H6+eXKp4qKr2z3h/L9qLv1uuxrdX1X+tqp+vqj+13cVstckUghcl+eiat676vNnSx9TsBlX1H5N844y3fri7/92cH/Pnuvupqvr6JB+oqv82+Taxow3qm13nSv0yvdHdXVXr3Y77gsk588eSfKiqPtnd/310rex4P5fkvd39+ar621kd0fuOba6JxfbxrP5++e2qemWSn83q5a49oaq+Jsm/SfJ3u/u3rvfzhKqr1N0vH/AZT03++3RV/UxWh/V3fKga0DfzPBJpx7lSv1TV/66q53f3pybDyk+v8xnPnjNPVNUvZPVb1W4MVcMei7ULbdg33T3dD+9I8rYtqGsn2JW/W0aYDhLdfaaqfryqburuXf+8xKr6yqwGqn/V3f92RpOrPm9c/ttiVfXVVfXcZ18n+UtZncTNfI9E2m2mH/H0uiRfNqJXVc+rqq+avL4pyUuTPLplFW4tj8Va34Z9s2a+x9GszhNhtZ/+xuRurpck+dzUZfc9raq+8dk5iVV1e1Zzwa7/kjL5f/6pJI9194+u0+yqzxsjVQNV1V9J8s+SLCV5f1U90t1/uar+aJJ3dPcrszpn5mcm5/ANSd7T3f9+24reIvP0zXqPRNrGsrfCqSQPVtXrk/x6ku9KklpdduJ7u/sNSb45ydur6gtZ/YV3qrt3ZajyWKz1zdk3P1BVR7N6Z9Nnkty1bQVvoap6b5KXJbmpqi4m+YdJvjJJuvtfZPWJIK9Mcj7J7yT5m9tT6dabo29ek+TvVNXlJL+b5Pge+ZLy0iR/Pcknq+qRyb5/kOTW5NrPGyuqAwAM4PIfAMAAQhUAwABCFQDAAEIVAMAAQhUAwABCFQDAAEIVAMAAQhUAwAD/D3tMF0zwLRdBAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.hist(z.numpy(), bins=100, density=True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# Transformed distribution\n", "\n", "x = log_normal_td.sample(n)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlcAAAFlCAYAAADGYc2/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAST0lEQVR4nO3df6xf933X8dd7MWF0Ky0j3oA47GYjHYRqo8ULhQrYaIvSZoqR+LFEFHWiWiRESmHVkMtQNAUJuSsaVFoYhK5kjK5RyMqwiEdWdQUktJa47dotCdmszDTOOuKVUn5MWxbx5o/7Db61r+tvk7d97o/HQ7Ly/Z57/P2+9ZF98/T5nntOdXcAAJjxFUsPAACwl4grAIBB4goAYJC4AgAYJK4AAAaJKwCAQQeWeuNrrrmmNzY2lnp7AIC1ffzjH/+17j64zr6LxdXGxkZOnjy51NsDAKytqv7ruvv6WBAAYJC4AgAYJK4AAAaJKwCAQeIKAGCQuAIAGCSuAAAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYNCBpQfYbTaOPnTBttPHbllgEgBgJ3LkCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYJK4AAAa5ztWXsN01rQAAvhRHrgAABokrAIBB4goAYJC4AgAYJK4AAAaJKwCAQeIKAGCQuAIAGCSuAAAGiSsAgEHiCgBgkHsLDtjuHoSnj92ywCQAwNIcuQIAGCSuAAAGiSsAgEFrxVVV3VxVT1TVqao6us3Xf39VfaSqPllVn66qN82PCgCw810yrqrqqiT3JHljkhuT3F5VN563299N8kB3vyrJbUn+8fSgAAC7wTpHrm5Kcqq7n+zuZ5Pcn+TIeft0kt+5evyyJL8yNyIAwO6xTlxdm+SpLc/PrLZt9f1J3lxVZ5KcSPK27V6oqu6oqpNVdfLs2bMvYFwAgJ1t6oT225Pc192HkrwpyY9V1QWv3d33dvfh7j588ODBobcGANg51omrp5Nct+X5odW2rd6a5IEk6e6fTfKVSa6ZGBAAYDdZJ64eSXJDVV1fVVdn84T14+ft85kkr0uSqvpD2Ywrn/sBAPvOJeOqu59LcmeSh5M8ns2fCny0qu6uqltXu70jyXdX1aeSfCDJd3V3X66hAQB2qrXuLdjdJ7J5ovrWbXdtefxYktfOjgYAsPu4QjsAwCBxBQAwSFwBAAwSVwAAg8QVAMAgcQUAMEhcAQAMElcAAIPEFQDAIHEFADBIXAEADBJXAACDxBUAwCBxBQAwSFwBAAwSVwAAg8QVAMAgcQUAMEhcAQAMElcAAIPEFQDAIHEFADBIXAEADBJXAACDxBUAwKADSw+wV20cfeiCbaeP3bLAJADAleTIFQDAIHEFADBIXAEADBJXAACDxBUAwCBxBQAwSFwBAAwSVwAAg8QVAMAgcQUAMEhcAQAMElcAAIPEFQDAIHEFADBIXAEADBJXAACDxBUAwCBxBQAw6MDSA+wUG0cfWnoEAGAPcOQKAGCQuAIAGCSuAAAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYJC
4AgAYJK4AAAaJKwCAQeIKAGCQuAIAGCSuAAAGiSsAgEHiCgBg0FpxVVU3V9UTVXWqqo5eZJ+/VFWPVdWjVfXjs2MCAOwOBy61Q1VdleSeJG9IcibJI1V1vLsf27LPDUnemeS13f35qvrayzUwAMBOts6Rq5uSnOruJ7v72ST3Jzly3j7fneSe7v58knT3M7NjAgDsDuvE1bVJntry/Mxq21avSPKKqvpPVfXRqrp5uxeqqjuq6mRVnTx79uwLmxgAYAebOqH9QJIbknxbktuT/LOqevn5O3X3vd19uLsPHzx4cOitAQB2jnXi6ukk1215fmi1baszSY5392919y8n+cVsxhYAwL6yTlw9kuSGqrq+qq5OcluS4+ft85PZPGqVqrommx8TPjk3JgDA7nDJuOru55LcmeThJI8neaC7H62qu6vq1tVuDyf5XFU9luQjSb63uz93uYYGANipLnkphiTp7hNJTpy37a4tjzvJ96x+AQDsW67QDgAwSFwBAAwSVwAAg8QVAMAgcQUAMEhcAQAMElcAAIPWus4VMzaOPnTBttPHbllgEgDgcnHkCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYJK4AAAaJKwCAQeIKAGCQuAIAGCSuAAAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYJK4AAAaJKwCAQeIKAGCQuAIAGCSuAAAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYJK4AAAaJKwCAQeIKAGCQuAIAGCSuAAAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYJK4AAAaJKwCAQeIKAGCQuAIAGCSuAAAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYdGCdnarq5iTvSXJVkvd297GL7PfnkzyY5Fu7++TYlHvYxtGHLth2+tgtC0wCAEy45JGrqroqyT1J3pjkxiS3V9WN2+z30iRvT/Kx6SEBAHaLdT4WvCnJqe5+srufTXJ/kiPb7Pf3krwryW8MzgcAsKusE1fXJnlqy/Mzq23/X1W9Osl13X3hZ1xfvN8dVXWyqk6ePXv2yx4WAGCne9EntFfVVyT5wSTvuNS+3X1vdx/u7sMHDx58sW8NALDjrBNXTye5bsvzQ6ttz3tpklcm+fdVdTrJa5Icr6rDU0MCAOwW68TVI0luqKrrq+rqJLclOf78F7v7C919TXdvdPdGko8mudVPCwIA+9El46q7n0tyZ5KHkzye5IHufrSq7q6qWy/3gAAAu8la17nq7hNJTpy37a6L7PttL34sAIDdyRXaAQAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYJK4AAAaJKwCAQeIKAGCQuAIAGLTWjZu5sjaOPnTBttPHbllgEgDgy+XIFQDAIHEFADBIXAEADBJXAACDxBUAwCBxBQAwSFwBAAwSVwAAg8QVAMAgcQUAMEhcAQAMElcAAIPEFQDAIHEFADBIXAEADBJXAACDxBUAwCBxBQAw6MDSAyxh4+hDS48AAOxRjlwBAAwSVwAAg8QVAMAgcQUAMEhcAQAMElcAAIPEFQDAIHEFADBIXAEADBJXAACDxBUAwCBxBQAwSFwBAAwSVwAAg8QVAMAgcQUAMEhcAQAMElcAAIPEFQDAIHEFADDowNIDsJ6Now9dsO30sVsWmAQA+FIcuQIAGCSuAAAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYtFZcVdXNVfVEVZ2qqqPbfP17quqxqvp0VX24qr5+flQAgJ3vknFVVVcluSfJG5PcmOT2qrrxvN0+meRwd39zkgeT/MD0oAAAu8E6R65uSnKqu5/s7meT3J/kyNYduvsj3f3rq6cfTXJodkwAgN1hnbi6NslTW56fWW27mLcm+antvlBVd1TVyao6efbs2fWnBADYJUZPaK+qNyc5nOT
d2329u+/t7sPdffjgwYOTbw0AsCMcWGOfp5Nct+X5odW2L1JVr0/yfUn+dHf/5sx4AAC7yzpHrh5JckNVXV9VVye5LcnxrTtU1auS/NMkt3b3M/NjAgDsDpeMq+5+LsmdSR5O8niSB7r70aq6u6puXe327iRfneRfVdXPVdXxi7wcAMCets7HgunuE0lOnLftri2PXz88FwDAruQK7QAAg8QVAMAgcQUAMEhcAQAMElcAAIPEFQDAoLUuxcDOtHH0oQu2nT52ywKTAADPc+QKAGCQuAIAGCSuAAAGiSsAgEHiCgBgkLgCABgkrgAABokrAIBB4goAYJC4AgAYJK4AAAa5t+Ae436DALAsR64AAAaJKwCAQeIKAGCQuAIAGOSE9n3ASe4AcOU4cgUAMEhcAQAMElcAAIPEFQDAIHEFADBIXAEADBJXAACDxBUAwCBxBQAwSFwBAAwSVwAAg8QVAMAgN27ep7a7mXPihs4A8GI5cgUAMEhcAQAMElcAAIPEFQDAIHEFADBIXAEADBJXAACDXOeKL7Ld9a9c+woA1ufIFQDAIHEFADBIXAEADNrT51xd7P55fHmchwUA63PkCgBgkLgCABgkrgAABu3pc664fJyHBQDbE1eMEVwA4GNBAIBR4goAYJC4AgAYJK4AAAY5oZ3Lat2r5DvxHYC9wpErAIBB4goAYNBaHwtW1c1J3pPkqiTv7e5j5339tyf5F0n+aJLPJfnO7j49Oyp7mWtkAbBXXDKuquqqJPckeUOSM0keqarj3f3Ylt3emuTz3f0Hquq2JO9K8p2XY2D2NxEGwE63zpGrm5Kc6u4nk6Sq7k9yJMnWuDqS5PtXjx9M8kNVVd3dg7Oyz6x7Mvy6+61LrAHwYqwTV9cmeWrL8zNJ/tjF9unu56rqC0l+d5JfmxgSrqQXG2vbxdm6R9wcmQPY/a7opRiq6o4kd6ye/u+qemL4La6JoEusw1ZXfC3qXcvstwZ/Ls6xFudYi3OsxTnWYtPWdfj6dX/TOnH1dJLrtjw/tNq23T5nqupAkpdl88T2L9Ld9ya5d93hvlxVdbK7D1+u198trMM51uIca3GOtTjHWpxjLc6xFpte6DqscymGR5LcUFXXV9XVSW5Lcvy8fY4necvq8V9I8jPOtwIA9qNLHrlanUN1Z5KHs3kphvd196NVdXeSk919PMmPJPmxqjqV5L9nM8AAAPadtc656u4TSU6ct+2uLY9/I8lfnB3tBblsHznuMtbhHGtxjrU4x1qcYy3OsRbnWItNL2gdyqd3AABz3P4GAGDQnoirqrq5qp6oqlNVdXTpeZZSVddV1Ueq6rGqerSq3r70TEurqquq6pNV9W+XnmVJVfXyqnqwqv5LVT1eVX986ZmWUlV/a/X34xeq6gNV9ZVLz3SlVNX7quqZqvqFLdu+pqo+VFW/tPrv71pyxivhIuvw7tXfj09X1b+uqpcvOOIVs91abPnaO6qqq+qaJWa70i62FlX1ttWfjUer6gfWea1dH1dbbs/zxiQ3Jrm9qm5cdqrFPJfkHd19Y5LXJPnr+3gtnvf2JI8vPcQO8J4k/667/2CSb8k+XZOqujbJ30hyuLtfmc0f0tlPP4BzX5Kbz9t2NMmHu/uGJB9ePd/r7suF6/ChJK/s7m9O8otJ3nmlh1rIfblwLVJV1yX5s0k+c6UHWtB9OW8tqurbs3kXmm/p7j+c5B+s80K7Pq6y5fY83f1skudvz7PvdPdnu/sTq8f/K5v/A7122amWU1WHktyS5L1Lz7KkqnpZkj+VzZ/qTXc/293/Y9GhlnUgye9YXZPvJUl+ZeF5rpju/o/Z/InurY4k+dHV4x9N8ueu5ExL2G4duvunu/u51dOPZvOajnveRf5MJMk/TPK3k+ybE7MvshZ/Lcmx7v7N1T7PrPNaeyGutrs9z74NiudV1UaSVyX52MKjLOkfZfObw/9deI6lXZ/kbJJ/vvqI9L1V9VVLD7WE7n46m//y/EySzyb
5Qnf/9LJTLe7ruvuzq8e/muTrlhxmh/irSX5q6SGWUlVHkjzd3Z9aepYd4BVJ/mRVfayq/kNVfes6v2kvxBXnqaqvTvITSf5md//PpedZQlV9R5JnuvvjS8+yAxxI8uokP9zdr0ryf7I/Pvq5wOp8oiPZDM7fl+SrqurNy061c6wu/rxvjlRsp6q+L5unWLx/6VmWUFUvSfJ3ktx1qX33iQNJviabp9p8b5IHqqou9Zv2Qlytc3uefaOqfls2w+r93f3BpedZ0GuT3FpVp7P5UfGfqap/uexIizmT5Ex3P38U88FsxtZ+9Pokv9zdZ7v7t5J8MMmfWHimpf23qvq9SbL671ofe+xFVfVdSb4jyV/ex3cZ+cZs/uPjU6vvn4eSfKKqfs+iUy3nTJIP9qb/nM1PQi55gv9eiKt1bs+zL6xq+keSPN7dP7j0PEvq7nd296Hu3sjmn4mf6e59eYSiu381yVNV9U2rTa9L8tiCIy3pM0leU1UvWf19eV326cn9W2y9fdlbkvybBWdZTFXdnM3TCG7t7l9fep6ldPfPd/fXdvfG6vvnmSSvXn0f2Y9+Msm3J0lVvSLJ1Vnjhta7Pq5WJyA+f3uex5M80N2PLjvVYl6b5K9k8yjNz61+vWnpodgR3pbk/VX16SR/JMnfX3acZayO3j2Y5BNJfj6b3wP3zZWoq+oDSX42yTdV1ZmqemuSY0neUFW/lM0je8eWnPFKuMg6/FCSlyb50Op75z9ZdMgr5CJrsS9dZC3el+QbVpdnuD/JW9Y5qukK7QAAg3b9kSsAgJ1EXAEADBJXAACDxBUAwCBxBQAwSFwBAAwSVwAAg8QVAMCg/wesW8643YZm/gAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot x density\n", "plt.hist(x.numpy(), bins=100, density=True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "# Define log normal distribution\n", "\n", "log_normal = tfd.LogNormal(loc=loc, scale=scale)\n", "l = log_normal.sample(n)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlMAAAFlCAYAAADPim3FAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAARkElEQVR4nO3dcayld17X8c93Z1jRZdlN6Gg2neptYiFuCHHXSdGsgY3smq5DWhJFWgMBXOg/lqxZohnUVKz/zEJC3MRqrF1kQaCpi5CJM1hMWAMaFjqFJdCWJZM62KmaDuuKrgZr9esf9yy53Lmzc3a+Z+acc+/rlUx6znOf3vvtM037nt/znOep7g4AADfnDeseAABgm4kpAIABMQUAMCCmAAAGxBQAwICYAgAYOL6uH3zHHXf0zs7Oun48AMDSnnvuud/u7hMHfW1tMbWzs5OLFy+u68cDACytqn7rel9zmg8AYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAgJgCABgQUwAAA2IKAGBATAEADIgpAIABMQUAMCCmAAAGjq97AG7ezpnz12y7fPb0GiYBgKPLyhQAwICYAgAYEFMAAANiCgBgQEwBAAyIKQCAATEFADAgpgAABsQUAMCAmAIAGBBTAAADYgoAYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAgJgCABg4vu4BuNbOmfPXbLt89vQaJgEAbkRMbYmDAgsAWD+n+QAABqxMHTJOEQLA7WVlCgBgQEwBAAyIKQCAAddMrZlP6QHAdrMyBQAwIKYAAAbEFADAgGumjgD3ngKAW8fKFADAgJgCABgQUwAAA2IKAGBATAEADIgpAIABMQUAMCCmAAAGxBQAwICYAgAYEFMAAANiCgBgQEwBAAyIKQCAgePL7FRV9yX5cJJjSZ7s7rP7vv5Hk3w0yVsX+5zp7gurHZVV2jlz/sDtl8+evs2TAMB2u+HKVFUdS/J4kvcleXuSh6rq7ft2+ztJnu7udyR5MMk/WvWgAACbaJnTfPcmudTdL3X3a0meSvLAvn06yZcuXr8lyX9a3YgAAJtrmdN8dyZ5ec/7K0m+et8+35vkZ6rqu5K8Kcl7VjIdAMCGW9UF6A8l+aHuPpnkLyT5kaq65ntX1cNVdbGqLl69enVFPxoAYH2WialXkty15/3Jxba93p/k6STp7l9I8sVJ7tj/jbr7ie4+1d2nTpw4cXMTAwBskGVO8z2b5J6quju7EfVgkr+yb5//mOTrkvxQVf2J7MaUpad9rvcJOgBge91wZaq7X0/ySJJnkryY3U/tPV9Vj1XV/YvdvjvJd1bVryb58STf1t19q4YGANgUS91nanHPqAv7tj265/ULSd612tEAADafO6ADAAyIKQCAATEFADAgpgAABsQUAMCAmAIAGBBTAAADYgoAYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAgJgCABgQUwAAA2IKAGBATAEADIgpAIABMQUAMHB8
3QOwWXbOnL9m2+Wzp9cwCQBsBytTAAADYgoAYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAgJgCABhw085b5KCbXwIAh4+VKQCAATEFADDgNB83xTP8AGCXlSkAgAExBQAwIKYAAAbEFADAgJgCABgQUwAAA2IKAGBATAEADIgpAIABMQUAMCCmAAAGxBQAwICYAgAYEFMAAANiCgBgQEwBAAwcX/cAbL6dM+fXPQIAbCwrUwAAA2IKAGBATAEADIgpAIABMQUAMCCmAAAGxBQAwICYAgAYWCqmquq+qvpUVV2qqjPX2ecvV9ULVfV8Vf3YascEANhMN7wDelUdS/J4kvcmuZLk2ao6190v7NnnniTfk+Rd3f2ZqvrDt2pgAIBNsszK1L1JLnX3S939WpKnkjywb5/vTPJ4d38mSbr71dWOCQCwmZaJqTuTvLzn/ZXFtr2+PMmXV9W/r6pPVNV9B32jqnq4qi5W1cWrV6/e3MQAABtkVRegH09yT5J3J3koyT+tqrfu36m7n+juU9196sSJEyv60QAA67NMTL2S5K49708utu11Jcm57v4/3f0fkvxmduMKAOBQWyamnk1yT1XdXVVvTPJgknP79vmp7K5KparuyO5pv5dWNyYAwGa6YUx19+tJHknyTJIXkzzd3c9X1WNVdf9it2eSfLqqXkjy8SR/o7s/fauGBgDYFDe8NUKSdPeFJBf2bXt0z+tO8sHFLwCAI8Md0AEABsQUAMCAmAIAGBBTAAADYgoAYEBMAQAMLHVrBD6/nTPn1z0CALAmVqYAAAbEFADAgJgCABgQUwAAA2IKAGBATAEADLg1Aitz0C0iLp89vYZJAOD2sTIFADAgpgAABsQUAMCAmAIAGBBTAAADYgoAYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAgJgCABgQUwAAA2IKAGBATAEADBxf9wAcbjtnzl+z7fLZ02uYBABuDStTAAADYgoAYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAgJgCABgQUwAAA2IKAGBATAEADIgpAIABMQUAMHB83QNw9OycOX/NtstnT69hEgCYszIFADAgpgAABsQUAMCAmAIAGHABOhvBRekAbCsrUwAAA2IKAGBATAEADLhm6gt00LU9AMDRZWUKAGBATAEADIgpAICBpWKqqu6rqk9V1aWqOvN59vuLVdVVdWp1IwIAbK4bXoBeVceSPJ7kvUmuJHm2qs519wv79ntzkg8k+cVbMShHjxt5ArANllmZujfJpe5+qbtfS/JUkgcO2O/vJ/lQkt9d4XwAABttmZi6M8nLe95fWWz7PVX1ziR3dbf7BgAAR8r4AvSqekOSH0jy3Uvs+3BVXayqi1evXp3+aACAtVsmpl5Jctee9ycX2z7nzUm+Msm/rarLSf50knMHXYTe3U9096nuPnXixImbnxoAYEMsE1PPJrmnqu6uqjcmeTDJuc99sbt/p7vv6O6d7t5J8okk93f3xVsyMQDABrlhTHX360keSfJMkheTPN3dz1fVY1V1/60eEABgky31bL7uvpDkwr5tj15n33fPxwIA2A7ugA4AMCCmAAAGxBQAwICYAgAYEFMAAANiCgBgQEwBAAyIKQCAATEFADAgpgAABsQUAMCAmAIAGBBTAAADYgoAYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAgJgCABgQUwAAA2IKAGDg+LoH2GQ7Z86vewSWcNDv0+Wzp9cwCQBHkZUpAIABMQUAMCCmAAAGxBQAwICYAgAYEFMAAANujcBWcbsKADaNlSkAgAErUxxKbuQJwO1iZQoAYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAgJgCABgQUwAAA2IKAGBATAEADIgpAIABMQUAMCCmAAAGxBQAwICYAgAYOL7uAWCdds6cv2bb5bOn1zAJANvKyhQAwICYAgAYEFMAAANiCgBgQEwBAAyIKQCAATEFADAgpgAABty0E/Y56EaeiZt5
AnCwpWKqqu5L8uEkx5I82d1n9339g0m+I8nrSa4m+avd/VsrnhVGrhdJADBxw9N8VXUsyeNJ3pfk7Ukeqqq379vtV5Kc6u6vSvKxJN+36kEBADbRMtdM3ZvkUne/1N2vJXkqyQN7d+juj3f3/1q8/USSk6sdEwBgMy0TU3cmeXnP+yuLbdfz/iQ/fdAXqurhqrpYVRevXr26/JQAABtqpZ/mq6pvTnIqyfcf9PXufqK7T3X3qRMnTqzyRwMArMUyF6C/kuSuPe9PLrb9PlX1niR/O8nXdvf/Xs14AACbbZmVqWeT3FNVd1fVG5M8mOTc3h2q6h1J/kmS+7v71dWPCQCwmW4YU939epJHkjyT5MUkT3f381X1WFXdv9jt+5N8SZJ/UVWfrKpz1/l2AACHylL3meruC0ku7Nv26J7X71nxXAAAW8HjZAAABsQUAMCAmAIAGBBTAAADYgoAYEBMAQAMiCkAgAExBQAwIKYAAAbEFADAwFKPkzkKds6cX/cIAMAWElOwpIOC+/LZ02uYBIBN4jQfAMCAlSkYsFoFgJUpAIABMQUAMOA0H6yYU38AR4uVKQCAATEFADAgpgAABsQUAMCAmAIAGBBTAAADbo0At8GyD9J2CwWA7WNlCgBgQEwBAAyIKQCAATEFADAgpgAABsQUAMCAmAIAGHCfKdhwB92jyv2oADaHlSkAgAErU7CFrFYBbA4xBRtk2cfOALA5nOYDABgQUwAAA07zwSHhOiqA9bAyBQAwYGUKsKoFMCCm4BBbdSRd79OGwgs4yo5kTPn4OQCwKq6ZAgAYOJIrU3CU3a6VWddhAUeFlSkAgAExBQAwIKYAAAZcMwWM+YQscJSJKeBAAglgOWIKWCuf+gO2nWumAAAGrEwBt41Th8BhZGUKAGDgUK9M+VMwbCfXUQHb5FDHFMCtcL0/qAk+OJrEFLAVJqtVk1XqLySQrKjB0bRUTFXVfUk+nORYkie7++y+r/+BJD+c5E8l+XSSb+ruy6sdFYAJsQe3xg1jqqqOJXk8yXuTXEnybFWd6+4X9uz2/iSf6e4/XlUPJvlQkm+6FQMDfM7tiIPbde2l0IHttczK1L1JLnX3S0lSVU8leSDJ3ph6IMn3Ll5/LMk/rKrq7l7hrAA3tGkfPJnMs0mBtUmzXM82zMjhtExM3Znk5T3vryT56uvt092vV9XvJPmyJL+9iiEBuL51XU92u8L1oH+WZX+2wDp8NvH39LZegF5VDyd5ePH2s1X1qS/gb78j4mzVHNPVc0xXzzHdoz60kv226pgu+8+8ru+3sFXHdEssfUxv0e/pfn/sel9YJqZeSXLXnvcnF9sO2udKVR1P8pbsXoj++3T3E0meWOJnXqOqLnb3qZv5ezmYY7p6junqOaar55iunmO6ett0TJe5A/qzSe6pqrur6o1JHkxybt8+55J86+L1X0rys66XAgCOghuuTC2ugXokyTPZvTXCD3b381X1WJKL3X0uyUeS/EhVXUryX7MbXAAAh95S10x194UkF/Zte3TP699N8o2rHe0aN3V6kM/LMV09x3T1HNPVc0xXzzFdva05puVsHADAzVvmmikAAK5jK2Kqqu6rqk9V1aWqOrPuebZdVf1gVb1aVb++7lkOi6q6q6o+XlUvVNXzVfWBdc+07arqi6vql6rqVxfH9O+te6bDoKqOVdWvVNW/Wvcsh0FVXa6qX6uqT1bVxXXPcxhU1Vur6mNV9RtV9WJV/Zl1z3QjG3+ab/E4m9/MnsfZJHlo3+Ns+AJU1dck+WySH+7ur1z3PIdBVb0tydu6+5er6s1JnkvyDf49vXlVVUne1N2fraovSvLvknyguz+x5tG2WlV9MMmpJF/a3V+/7nm2XVVdTnKqu91jakWq6qNJfr67n1zcReAPdfd/W/NYn9c2rEz93uNsuvu1JJ97nA03qbt/LrufumRFuvs/d/cvL17/jyQvZvfJANyk3vXZxdsvWvza7D/9bbiqOpnkdJIn
1z0LHKSq3pLka7J7l4B092ubHlLJdsTUQY+z8T8pNlZV7SR5R5JfXPMoW29xSuqTSV5N8m+62zGd+QdJ/maS/7fmOQ6TTvIzVfXc4ikfzNyd5GqSf7Y4Hf1kVb1p3UPdyDbEFGyNqvqSJD+R5K93939f9zzbrrv/b3f/yew+eeHeqnJa+iZV1dcnebW7n1v3LIfMn+3udyZ5X5K/triMgpt3PMk7k/zj7n5Hkv+ZZOOvld6GmFrmcTawdovren4iyY92979c9zyHyWKZ/+NJ7lvzKNvsXUnuX1zj81SSP1dV/3y9I22/7n5l8ddXk/xkdi9N4eZdSXJlzyr0x7IbVxttG2JqmcfZwFotLpb+SJIXu/sH1j3PYVBVJ6rqrYvXfzC7H0L5jbUOtcW6+3u6+2R372T3v6M/293fvOaxtlpVvWnxgZMsTkX9+SQ+JT3Q3f8lyctV9RWLTV+XZOM/yLPUHdDX6XqPs1nzWFutqn48ybuT3FFVV5L83e7+yHqn2nrvSvItSX5tcY1PkvytxdMDuDlvS/LRxSd635Dk6e72cX42yR9J8pO7f5bK8SQ/1t3/er0jHQrfleRHFwsoLyX59jXPc0Mbf2sEAIBNtg2n+QAANpaYAgAYEFMAAANiCgBgQEwBAAyIKQCAATEFADAgpgAABv4/W37FqTp0RPsAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.hist(l.numpy(), bins=100, density=True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Log probability" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "# Log prob of LogNormal\n", "\n", "log_prob = log_normal.log_prob(x)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "# Log prob of log normal transformed distribution\n", "\n", "log_prob_td = log_normal_td.log_prob(x)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check log probs\n", "\n", "tf.norm(log_prob - log_prob_td)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Event shape and batch shape" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "# Set a scaling lower triangular matrix\n", "\n", "tril = tf.random.normal((2, 4, 4))\n", "scale_low_tri = tf.linalg.LinearOperatorLowerTriangular(tril)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# view of scale_low_tri\n", "\n", "scale_low_tri.to_dense()" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "# Define scale linear operator\n", "\n", "scale_lin_op = tfb.ScaleMatvecLinearOperator(scale_low_tri)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "# Define scale linear operator transformed distribution with a batch and event shape\n", "\n", "mvn = tfd.TransformedDistribution(tfd.Sample(tfd.Normal(loc=[0., 0.], scale=1.), sample_shape=[4]), scale_lin_op)" ] }, { 
"cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "batch shape: (2,)\n", "event shape: (4,)\n" ] } ], "source": [ "# Display event and batch shape\n", "\n", "print('batch shape: ', mvn.batch_shape)\n", "print('event shape: ', mvn.event_shape)" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10000, 2, 4)\n" ] } ], "source": [ "# Sample\n", "\n", "y1 = mvn.sample(sample_shape=(n,))\n", "print(y1.shape)" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define a MultivariateNormalLinearOperator distribution\n", "\n", "mvn2 = tfd.MultivariateNormalLinearOperator(loc=0, scale=scale_low_tri)\n", "mvn2" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TensorShape([10000, 2, 4])" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# sample\n", "y2 = mvn2.sample(sample_shape=(n, ))\n", "y2.shape" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check\n", "\n", "xn = normal.sample((n, 2, 4))\n", "tf.norm(mvn.log_prob(xn) - mvn2.log_prob(xn)) / tf.norm(mvn.log_prob(xn))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }