{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bivariate Gaussian surface\n",
    "\n",
    "Evaluates the density of a 2-dimensional Gaussian with mean `mu` and covariance `Sigma` on a square grid, draws it as a Plotly 3-D surface with projected z-contours, and (in the last cell) exports the figure as an HTML fragment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import plotly.graph_objects as go\n",
    "import plotly.io as pio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid over which the density is evaluated: N points per axis on [-GRID_LIMIT, GRID_LIMIT].\n",
    "N = 60\n",
    "GRID_LIMIT = 3.5\n",
    "\n",
    "# Mean vector and covariance matrix\n",
    "mu = np.array([0., 0.])\n",
    "Sigma = np.array([[2.5, 1.], [1., 2.5]])\n",
    "\n",
    "# Upper limit of the (hidden) z axis. It must stay above the peak density\n",
    "# 1 / (2*pi*sqrt(det(Sigma))) (about 0.0695 for the Sigma above), otherwise the\n",
    "# top of the surface is clipped; the density cell below asserts this.\n",
    "Z_MAX = 0.07"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multivariate_gaussian(pos, mu, Sigma):\n",
    "    \"\"\"Return the multivariate Gaussian density evaluated on array pos.\n",
    "\n",
    "    pos is an array constructed by packing the meshed arrays of variables\n",
    "    x_1, x_2, x_3, ..., x_k into its _last_ dimension, i.e. shape (..., k).\n",
    "    mu is the length-k mean vector and Sigma the (k, k) covariance matrix.\n",
    "\n",
    "    Returns an array of shape pos.shape[:-1] holding the density at each point.\n",
    "    \"\"\"\n",
    "    n = mu.shape[0]\n",
    "    Sigma_det = np.linalg.det(Sigma)\n",
    "    Sigma_inv = np.linalg.inv(Sigma)\n",
    "    # Normalisation constant sqrt((2*pi)^n * det(Sigma)); called norm_const rather\n",
    "    # than N so it is not confused with the grid size N from the config cell.\n",
    "    norm_const = np.sqrt((2 * np.pi)**n * Sigma_det)\n",
    "    # This einsum call calculates (x-mu)^T . Sigma^-1 . (x-mu) in a vectorized\n",
    "    # way across all the points packed into pos.\n",
    "    diff = pos - mu\n",
    "    fac = np.einsum('...k,kl,...l->...', diff, Sigma_inv, diff)\n",
    "\n",
    "    return np.exp(-fac / 2) / norm_const"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Our 2-dimensional distribution will be over variables X and Y\n",
    "x = np.linspace(-GRID_LIMIT, GRID_LIMIT, N)\n",
    "y = np.linspace(-GRID_LIMIT, GRID_LIMIT, N)\n",
    "X, Y = np.meshgrid(x, y)\n",
    "\n",
    "# Pack X and Y into a single 3-dimensional array of shape (N, N, 2)\n",
    "pos = np.stack((X, Y), axis=-1)\n",
    "\n",
    "# The distribution on the variables X, Y packed into pos.\n",
    "Z = multivariate_gaussian(pos, mu, Sigma)\n",
    "assert Z.max() <= Z_MAX, f\"peak density {Z.max():.4f} exceeds the z-axis limit Z_MAX={Z_MAX}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surface with z-contours coloured by the colormap and projected onto the z plane;\n",
    "# the colour bar and all axes are hidden so only the surface is shown.\n",
    "fig = go.Figure(data=go.Surface(x=X, y=Y, z=Z))\n",
    "fig.update_traces(contours_z=dict(show=True, usecolormap=True, highlightcolor=\"limegreen\", project_z=True),\n",
    "                  showscale=False, opacity=0.6)\n",
    "fig.update_layout(scene=dict(\n",
    "                      xaxis=dict(visible=False),\n",
    "                      yaxis=dict(visible=False),\n",
    "                      zaxis=dict(visible=False, range=[0, Z_MAX]),\n",
    "                      domain=dict(y=[0.1, 1])),\n",
    "                  height=700,\n",
    "                  margin=dict(r=0, l=0, b=0, t=0, pad=0))\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
"# Export the figure as an HTML fragment (full_html=False: no <html>/<body>\n",
    "# wrapper) that loads plotly.js from a CDN instead of embedding it.\n",
    "# Pass include_mathjax='cdn' as well if the figure ever contains LaTeX.\n",
    "pio.write_html(\n",
    "    fig,\n",
    "    file='../_includes/figures/figure.html',\n",
    "    full_html=False,\n",
    "    include_plotlyjs='cdn',\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}