{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "UNAM_GuestLecture_ModelSelectionExamples.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "44xR8l6o2NH7" }, "source": [ "# Introduction\n", "This notebook is intended to demonstrate some simple applications of Bayesian Model Selection in astronomy using python, and will also set out some problems for attendees to solve. Please bear in mind that that these examples and problems all implicitly or explicitly assume Bayesian approaches, so I recommend you remind yourselves of the basic principles first. It also assumes some knowledge from the previous lecture/examples on MCMC.\n", "\n", "The text and graphical elements of this work is licensed under Attribution 4.0 International. To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/. The code elements of this notebook are licensed under the MIT license (text at end).\n", "\n", "\n", "# This example sheet\n", "This will cover a number of examples, including\n", "1. Calculating different test statistics with python\n", " 1. LRT\n", " 2. AIC\n", " 3. BIC\n", " 4. Bayes' factors\n", "2. A short example of Nested Sampling for Bayes' factors\n", "3. Mixture model introduction\n", " 1. Example 1: Outlier rejection\n", " 2. Example 2: Gaussian mixture model for classifying by colour\n", "\n", "\n", "These examples for Nested sampling are setup in *dynesty* and the Gaussian mixture model uses scikit-learn. However, I strongly encourage you to try out other packages and approaches to doing this. In particular, Ultranest may be faster for many problems than dynesty, while scikit-learn fails to account for uncertain input data." ] }, { "cell_type": "code", "metadata": { "id": "ddyMgZmt2D9o" }, "source": [ "" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "42m14j9ABYaS" }, "source": [ "#Calculating test statistics in python\n", "\n", "There are a lot of codes available to calculate useful test statistics in python. However, for completeness here we include some functions to calculate the Likelihood ratio, Akaike information criterion and Bayesian information criterion." ] }, { "cell_type": "code", "metadata": { "id": "ou9gq_WxD2HB" }, "source": [ "from scipy.stats.distributions import chi2\n", "def likelihood_ratio(mll_null, mll_alt, n_null = None, n_alt = None):\n", " lr = (2*(mll_null-mll_alt))\n", " if (n_null is not None) and (n_alt is not None):\n", " p = chi2.sf(lr, np.abs(n_null - n_alt))\n", " return lr, p\n", " return lr\n", "\n", "def aic(mll, npars):\n", " return (2 * npars) - (2 * mll)\n", "\n", "def bic(mll, npars, npoints):\n", " return (npars*np.log(npoints)) - (2 * mll)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Jmk5WFMBG9Tf" }, "source": [ "However, the Bayes Factor is more difficult to calculate immediately, as it involves calculating the evidences of the models." 
] }, { "cell_type": "code", "metadata": { "id": "twTAIdqhHnQD" }, "source": [ "def bayesfac(evidence_null, evidence_alt, log=True):\n", " if log:\n", " return evidence_null - evidence_alt\n", " return evidence_null / evidence_alt" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "900gbWRkHuVF" }, "source": [ "# Using Nested Sampling to calculate the evidence\n", "\n", "Calculating the evidence is hard, because it involves integrating the likelihood over the entire parameter space. If the entire likelihood can be evaluated and integrated analytically, this is feasible, but a likelihood function doesn't have to get very complex before this is unrealistic or even impossible; this is common in typical astronomical applications. \n", "\n", "For these cases, we turn to special integration routines, and Nested Sampling is a very clever algorithm designed with this specific case in mind. The basic idea is to draw Monte Carlo samples from the prior and evaluate the likelihood at those points, then gradually shrink the sampling volume by moving to higher-likelihood locations. By doing this it integrates the function, naturally producing the evidence. The first step to calculating Bayes factors is therefore to fit the data with Nested Sampling using each model being tested.\n", "\n", "There are several python packages available for Nested Sampling, that mostly operate along the same lines as *emcee* treats MCMC. In the following cell we use the package dynesty to fit two models to some data (this will be familiar for everyone who has worked on the problems for the MCMC lecture)." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1ilRjy8mN23C", "outputId": "35ebcd62-c1ef-4a37-bf2c-a4f03f0360dc" }, "source": [ "import numpy as np\n", "from scipy.stats import norm\n", "\n", "x = np.array([ 0. , 0.1010101 , 0.2020202 , 0.3030303 , 0.4040404 ,\n", " 0.50505051, 0.60606061, 0.70707071, 0.80808081, 0.90909091,\n", " 1.01010101, 1.11111111, 1.21212121, 1.31313131, 1.41414141,\n", " 1.51515152, 1.61616162, 1.71717172, 1.81818182, 1.91919192,\n", " 2.02020202, 2.12121212, 2.22222222, 2.32323232, 2.42424242,\n", " 2.52525253, 2.62626263, 2.72727273, 2.82828283, 2.92929293,\n", " 3.03030303, 3.13131313, 3.23232323, 3.33333333, 3.43434343,\n", " 3.53535354, 3.63636364, 3.73737374, 3.83838384, 3.93939394,\n", " 4.04040404, 4.14141414, 4.24242424, 4.34343434, 4.44444444,\n", " 4.54545455, 4.64646465, 4.74747475, 4.84848485, 4.94949495,\n", " 5.05050505, 5.15151515, 5.25252525, 5.35353535, 5.45454545,\n", " 5.55555556, 5.65656566, 5.75757576, 5.85858586, 5.95959596,\n", " 6.06060606, 6.16161616, 6.26262626, 6.36363636, 6.46464646,\n", " 6.56565657, 6.66666667, 6.76767677, 6.86868687, 6.96969697,\n", " 7.07070707, 7.17171717, 7.27272727, 7.37373737, 7.47474747,\n", " 7.57575758, 7.67676768, 7.77777778, 7.87878788, 7.97979798,\n", " 8.08080808, 8.18181818, 8.28282828, 8.38383838, 8.48484848,\n", " 8.58585859, 8.68686869, 8.78787879, 8.88888889, 8.98989899,\n", " 9.09090909, 9.19191919, 9.29292929, 9.39393939, 9.49494949,\n", " 9.5959596 , 9.6969697 , 9.7979798 , 9.8989899 , 10. 
])\n", "y = np.array([0.60614598, 0.6153662 , 0.47419432, 0.64142881, 0.5404495 ,\n", " 0.65958241, 0.5388796 , 0.72837078, 0.56604171, 0.86397498,\n", " 0.4751028 , 0.4119352 , 0.53598582, 1.04707218, 0.83268879,\n", " 0.65939936, 0.71741511, 0.79977691, 0.76297566, 0.56573409,\n", " 0.68881211, 0.93441395, 1.08102487, 1.06692672, 0.91861211,\n", " 1.14857326, 0.98065209, 0.99522437, 0.98278417, 1.11751265,\n", " 0.79828341, 0.18980406, 1.11795621, 1.16912414, 1.08937796,\n", " 1.05315588, 0.95460981, 1.35636458, 1.5210485 , 1.70122012,\n", " 1.8312029 , 2.50786172, 2.80173261, 3.10481785, 2.62385979,\n", " 2.20492704, 1.6121115 , 1.64443925, 1.69162613, 0.96340852,\n", " 0.86319356, 0.61203334, 0.87827908, 0.73654364, 0.99289407,\n", " 0.74860198, 0.5520727 , 0.74797704, 0.60877654, 0.48734763,\n", " 0.41548919, 0.47134843, 0.65060695, 0.55424859, 0.53872972,\n", " 0.42004391, 0.50261679, 0.41800777, 0.14338189, 0.39985653,\n", " 0.34840721, 0.37024026, 0.57578898, 0.16391488, 0.36135355,\n", " 0.36556374, 0.72052741, 0.49379491, 0.58735107, 0.24443858,\n", " 0.44170883, 0.498692 , 0.51745757, 0.45058975, 0.25297367,\n", " 0.57483966, 0.29939929, 0.5941698 , 0.35649874, 0.56844186,\n", " 0.45872986, 0.36447514, 0.52903476, 0.51461255, 0.70671988,\n", " 0.36093269, 0.63303675, 0.76049668, 1.0671897 , 0.58650049])\n", "yerr = np.array([0.13571468, 0.17649925, 0.2427158 , 0.19674616, 0.1374289 ,\n", " 0.13734713, 0.23344185, 0.13494697, 0.16865474, 0.14877157,\n", " 0.1523842 , 0.21276768, 0.21248585, 0.23104686, 0.2186303 ,\n", " 0.16633297, 0.12466968, 0.19250511, 0.23896049, 0.24204692,\n", " 0.11889055, 0.20328089, 0.19548629, 0.1582987 , 0.18590566,\n", " 0.24759034, 0.23386036, 0.14216423, 0.18593009, 0.2357429 ,\n", " 0.21747414, 0.24222248, 0.17103163, 0.1992606 , 0.15450062,\n", " 0.22667307, 0.23436873, 0.11204839, 0.11259136, 0.11087712,\n", " 0.23634978, 0.17765576, 0.24790074, 0.15669243, 0.1050639 ,\n", " 0.11204431, 0.13651244, 0.14456857, 0.24083602, 0.21794055,\n", " 0.18097375, 0.13599461, 0.23872987, 0.10324854, 0.24533545,\n", " 0.12294229, 0.16149936, 0.10591347, 0.21025653, 0.18675912,\n", " 0.16242686, 0.16716389, 0.18856679, 0.11728497, 0.11803215,\n", " 0.15240304, 0.12481872, 0.17082965, 0.24074905, 0.13743163,\n", " 0.23039399, 0.21941084, 0.24978922, 0.19685753, 0.16923487,\n", " 0.22468505, 0.18220194, 0.10126559, 0.16738149, 0.13034307,\n", " 0.16072441, 0.12061447, 0.15581466, 0.18634505, 0.21928294,\n", " 0.12790317, 0.21768046, 0.15576875, 0.19392913, 0.10682687,\n", " 0.18502616, 0.17874781, 0.1710912 , 0.12937686, 0.20394657,\n", " 0.13313035, 0.10067853, 0.24100547, 0.18669291, 0.19211422])\n", "\n", "\n", "def model1(x, centre, width, int_intens, baseline):\n", " #A model for a spectral line described by a gaussian, normalised by integrated intensity\n", " return int_intens*norm.pdf(x, centre, width) + baseline\n", "\n", "def model2(x, centre, width, int_intens, baseline):\n", " #A model for a spectral line described by a gaussian, normalised by integrated intensity\n", " #however, now the baseline is more complex\n", " bl = baseline[0] + baseline[1]*np.sin(baseline[2]*x)\n", " mod = int_intens*norm.pdf(x, centre, width) + bl\n", " return mod\n", "\n", "def ptform1(u):\n", " theta = u.copy()\n", " # now we go over each parameter and transform from U[0, 1) to our desired prior\n", " lo = -10\n", " hi = 10\n", " theta[0] = u[0] * (hi - lo) + lo\n", "\n", " lo = 0\n", " hi = 100\n", " theta[1] = u[1] * (hi - lo) + lo\n", "\n", " lo = 0\n", " hi = 3\n", " 
theta[2] = u[2] * (hi - lo) + lo\n", "\n", " lo = -10\n", " hi = 10\n", " theta[3] = u[3] * (hi - lo) + lo\n", "\n", " return theta\n", "\n", "def ptform2(u):\n", " theta = u.copy()\n", " # now we go over each parameter and transform from U[0, 1) to our desired prior\n", " lo = -10\n", " hi = 10\n", " theta[0] = u[0] * (hi - lo) + lo\n", "\n", " lo = 0\n", " hi = 100\n", " theta[1] = u[1] * (hi - lo) + lo\n", "\n", " lo = 0\n", " hi = 3\n", " theta[2] = u[2] * (hi - lo) + lo\n", "\n", " lo = -2\n", " hi = 2\n", " theta[3] = u[3] * (hi - lo) + lo\n", "\n", " lo = -1\n", " hi = 1\n", " theta[4] = u[4] * (hi - lo) + lo\n", "\n", " lo = -2\n", " hi = 2\n", " theta[5] = u[5] * (hi - lo) + lo\n", " return theta\n", "\n", "def lnlike1(theta, x, y, yerr):\n", " #Gaussian log-likelihood, dropping the constant normalisation terms\n", " #(they are identical for both models, so they cancel in the comparisons)\n", " centre, width, int_intens, baseline = theta\n", " y_model = model1(x, centre, width, int_intens, baseline)\n", " like = -0.5 * (((y_model - y)/yerr)**2).sum()\n", " return like\n", "\n", "\n", "def lnlike2(theta, x, y, yerr):\n", " #As lnlike1, but unpacking the extra baseline parameters of model 2\n", " centre = theta[0]\n", " width = theta[1]\n", " int_intens = theta[2]\n", " baseline = theta[3:]\n", " y_model = model2(x, centre, width, int_intens, baseline)\n", " like = -0.5 * (((y_model - y)/yerr)**2).sum()\n", " return like\n", "\n", "try:\n", " import dynesty\n", "except ImportError:\n", " !pip install dynesty\n", " import dynesty\n", "\n", "sampler1 = dynesty.NestedSampler(lnlike1, ptform1, \n", " ndim = 4, nlive=1500, logl_args=(x,y,yerr))\n", "sampler1.run_nested(dlogz = 0.01) #Model1 is very simple, so this is extremely fast\n", "\n", "results1 = sampler1.results" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Collecting dynesty\n", " Downloading https://files.pythonhosted.org/packages/94/e6/dc4369009259a0a113b3f91e223be9229f71d8350aca6ec8fe978982d2f9/dynesty-1.0.1-py2.py3-none-any.whl (86kB)\n", "Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from dynesty) (1.4.1)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from dynesty) (1.19.5)\n", "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from dynesty) (1.15.0)\n", "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from dynesty) (3.2.2)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->dynesty) (1.3.1)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->dynesty) (0.10.0)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->dynesty) (2.4.7)\n", "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->dynesty) (2.8.1)\n", "Installing collected packages: dynesty\n", "Successfully installed dynesty-1.0.1\n" ], "name": "stdout" }, { "output_type": "stream", "text": [ "9899it [00:24, 235.54it/s, bound: 5 | nc: 2 | ncall: 63414 | eff(%): 15.610 | loglstar: -inf < -657.975 < inf | logz: -664.493 +/- 0.089 | dlogz: 404.221 > 0.010]" ], "name": "stderr" } ] }, { "cell_type": "code", "metadata": { "id": "45iomo51gruZ" }, "source": [ "sampler2 = dynesty.NestedSampler(lnlike2, ptform2, \n", " ndim = 6, nlive=1500, logl_args=(x,y,yerr))\n", "sampler2.run_nested(dlogz = 0.01) #However, model 2 is rather more complex, so this takes ~10 minutes.\n", "\n", "results2 = sampler2.results" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "YfZCv2H1b8Kt" }, "source": [ "These nested sampling runs have given us samples from the posterior, the likelihood values at those samples, and estimates of the evidence. We can use these to estimate some of the test statistics coded above. First, we have to extract the maximum likelihood and the evidence from the results objects." ] }, { "cell_type": "code", "metadata": { "id": "CmcPPqS4hY0q" }, "source": [ "mll1 = np.max(results1['logl'])\n", "ev1 = results1['logz'][-1]\n", "\n", "mll2 = np.max(results2['logl'])\n", "ev2 = results2['logz'][-1]\n", "\n", "aic1 = aic(mll1, 4)\n", "bic1 = bic(mll1, 4, len(x))\n", "aic2 = aic(mll2, 6)\n", "bic2 = bic(mll2, 6, len(x))\n", "\n", "print(\"AIC - Model 1: \", aic1, \"; Model 2: \", aic2, \"; Delta AIC: \", aic1 - aic2)\n", "print(\"BIC - Model 1: \", bic1, \"; Model 2: \", bic2, \"; Delta BIC: \", bic1 - bic2)\n", "\n", "lr = likelihood_ratio(mll1, mll2, n_null = 4, n_alt = 6)\n", "print(\"Likelihood ratio (model 2 / model 1): \", np.exp(lr[0]/2), \"; p-value (assumes N is large!): \", lr[1])\n", "\n", "bf = bayesfac(ev1, ev2)\n", "print(\"Bayes factor (model 1 / model 2): \", np.exp(bf))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "oVKb8ikncs4J" }, "source": [ "All these tests show that model 2 is strongly preferred. Assuming the prior odds are even, the posterior odds favour model 2 by 14 orders of magnitude! It is worth noting that the Bayes factor is much more conservative than the likelihood ratio, because it accounts for the differences in **model complexity** and **prior volumes**. The information criteria, by contrast, do not give us odds of one model versus the other, just which one is better given the penalties for complexity (although there are rules of thumb for interpreting the differences).\n", "\n", "It's worth pointing out that *neither of these models is the one which actually produced the data!* That was the following model:" ] }, { "cell_type": "code", "metadata": { "id": "ZMo9_ad5nQGN" }, "source": [ "from scipy.stats import cauchy\n", "\n", "def true_model(x, centre, width, int_intens, baseline):\n", " bl = baseline[0] + baseline[1]*np.sin(baseline[2]*x)\n", " mod = int_intens*cauchy.pdf(x, centre, width) + bl\n", " return mod" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "mNkwVnrMjJqP" }, "source": [ "Here the line is actually a Cauchy (Lorentzian) profile, rather than a Gaussian. Test this out for yourself by calculating the same statistics for the true model. Bear in mind that the likelihood ratio will no longer be well defined, because the models are no longer nested!"
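, "\n", "\n", "As a starting point, the cell below sketches a log-likelihood for the true model. It is untested, and simply reuses *ptform2*, which works because the true model shares the same six parameters." ] }, { "cell_type": "code", "metadata": {}, "source": [ "def lnlike_true(theta, x, y, yerr):\n", "    #Same parametrisation as model 2, but with a Cauchy line profile\n", "    centre, width, int_intens = theta[0], theta[1], theta[2]\n", "    baseline = theta[3:]\n", "    y_model = true_model(x, centre, width, int_intens, baseline)\n", "    return -0.5 * (((y_model - y)/yerr)**2).sum()\n", "\n", "#sampler_true = dynesty.NestedSampler(lnlike_true, ptform2, ndim=6,\n", "#                                     nlive=1500, logl_args=(x, y, yerr))\n", "#sampler_true.run_nested(dlogz=0.01)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once the run finishes, compare its evidence to those of the two Gaussian models in the same way as above."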
] }, { "cell_type": "markdown", "metadata": { "id": "iiFvtkyRyZFW" }, "source": [ "#Introduction to mixture models\n", "\n", "Mixture models are a powerful technique; although they do not technically do model selection, I have included them here because they do classification, which is often what we are interested in doing when we talk about model selection. The basic principle is to marginalise out some *latent variables* which parametrise the classifications of the data. By treating the likelihood as the linear combination of the likelihoods under the assumption of each class individually, we can infer which class each point belongs. Alternatively, we can marginalise out the classes and see what impact this has on the fitted parameters.\n", "\n", "\n", "## Simple mixture model for outlier rejection\n", "A very common application for mixture models in analysing a dataset which has bad values or outliers. In this case, we don't know ahead of time which points are good and bad, but need to account for this when fitting." ] }, { "cell_type": "code", "metadata": { "id": "IcAzOPPfsIJ0" }, "source": [ "#Need to install tex so matplotlib will cooperate\n", "! sudo apt-get install texlive-latex-recommended \n", "! sudo apt install texlive-latex-extra\n", "! sudo apt install dvipng" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "HSS5N3LknQOH" }, "source": [ "# Author: Jake VanderPlas (adapted to PyMC3 by Brigitta Sipocz)\n", "# License: BSD\n", "# The figure produced by this code is published in the textbook\n", "# \"Statistics, Data Mining, and Machine Learning in Astronomy\" (2013)\n", "# For more information, see http://astroML.github.com\n", "# To report a bug or issue, use the following forum:\n", "# https://groups.google.com/forum/#!forum/astroml-general\n", "import numpy as np\n", "\n", "import pymc3 as pm\n", "\n", "from matplotlib import pyplot as plt\n", "from matplotlib import rc\n", "rc('text', usetex=False) \n", "from theano import shared as tshared\n", "import theano.tensor as tt\n", "\n", "try:\n", " from astroML.datasets import fetch_hogg2010test\n", " from astroML.plotting.mcmc import convert_to_stdev\n", "except ImportError:\n", " !pip install astroML\n", " from astroML.datasets import fetch_hogg2010test\n", " from astroML.plotting.mcmc import convert_to_stdev\n", "\n", "\n", "# ----------------------------------------------------------------------\n", "# This function adjusts matplotlib settings for a uniform feel in the textbook.\n", "# Note that with usetex=True, fonts are rendered with LaTeX. This may\n", "# result in an error if LaTeX is not installed on your system. In that case,\n", "# you can set usetex to False.\n", "#if \"setup_text_plots\" not in globals():\n", "# from astroML.plotting import setup_text_plots\n", "#setup_text_plots(fontsize=8, usetex=True)\n", "\n", "np.random.seed(0)\n", "\n", "# ------------------------------------------------------------\n", "# Get data: this includes outliers. 
We need to convert them to Theano variables\n", "data = fetch_hogg2010test()\n", "xi = tshared(data['x'])\n", "yi = tshared(data['y'])\n", "dyi = tshared(data['sigma_y'])\n", "size = len(data)\n", "\n", "\n", "# ----------------------------------------------------------------------\n", "# Define basic linear model\n", "\n", "def model(xi, theta, intercept):\n", " slope = np.tan(theta)\n", " return slope * xi + intercept\n", "\n", "\n", "# ----------------------------------------------------------------------\n", "# First model: no outlier correction\n", "with pm.Model():\n", " # set priors on model gradient and y-intercept\n", " inter = pm.Uniform('inter', -1000, 1000)\n", " theta = pm.Uniform('theta', -np.pi / 2, np.pi / 2)\n", "\n", " y = pm.Normal('y', mu=model(xi, theta, inter), sd=dyi, observed=yi)\n", "\n", " trace0 = pm.sample(draws=10000, tune=2000)\n", "\n", "\n", "# ----------------------------------------------------------------------\n", "# Second model: nuisance variables correcting for outliers\n", "# This is the mixture model given in equation 17 in Hogg et al\n", "def mixture_likelihood(yi, xi):\n", " \"\"\"Equation 17 of Hogg 2010\"\"\"\n", "\n", " sigmab = tt.exp(log_sigmab)\n", " mu = model(xi, theta, inter)\n", "\n", " Vi = dyi ** 2\n", " Vb = sigmab ** 2\n", "\n", " root2pi = np.sqrt(2 * np.pi)\n", "\n", " L_in = (1. / root2pi / dyi * np.exp(-0.5 * (yi - mu) ** 2 / Vi))\n", "\n", " L_out = (1. / root2pi / np.sqrt(Vi + Vb)\n", " * np.exp(-0.5 * (yi - Yb) ** 2 / (Vi + Vb)))\n", "\n", " return tt.sum(tt.log((1 - Pb) * L_in + Pb * L_out))\n", "\n", "\n", "with pm.Model():\n", " # uniform prior on Pb, the fraction of bad points\n", " Pb = pm.Uniform('Pb', 0, 1.0, testval=0.1)\n", "\n", " # uniform prior on Yb, the centroid of the outlier distribution\n", " Yb = pm.Uniform('Yb', -10000, 10000, testval=0)\n", "\n", " # uniform prior on log(sigmab), the spread of the outlier distribution\n", " log_sigmab = pm.Uniform('log_sigmab', -10, 10, testval=5)\n", "\n", " inter = pm.Uniform('inter', -200, 400)\n", " theta = pm.Uniform('theta', -np.pi / 2, np.pi / 2, testval=np.pi / 4)\n", "\n", " y_mixture = pm.DensityDist('mixturenormal', logp=mixture_likelihood,\n", " observed={'yi': yi, 'xi': xi})\n", "\n", " trace1 = pm.sample(draws=10000, tune=2000)\n", "\n", "\n", "# ----------------------------------------------------------------------\n", "# Third model: marginalizes over the probability that each point is an outlier.\n", "# define priors on beta = (slope, intercept)\n", "\n", "def outlier_likelihood(yi, xi):\n", " \"\"\"likelihood for full outlier posterior\"\"\"\n", "\n", " sigmab = tt.exp(log_sigmab)\n", " mu = model(xi, theta, inter)\n", "\n", " Vi = dyi ** 2\n", " Vb = sigmab ** 2\n", "\n", " logL_in = -0.5 * tt.sum(qi * (np.log(2 * np.pi * Vi)\n", " + (yi - mu) ** 2 / Vi))\n", "\n", " logL_out = -0.5 * tt.sum((1 - qi) * (np.log(2 * np.pi * (Vi + Vb))\n", " + (yi - Yb) ** 2 / (Vi + Vb)))\n", "\n", " return logL_out + logL_in\n", "\n", "\n", "with pm.Model():\n", " # uniform prior on Pb, the fraction of bad points\n", " Pb = pm.Uniform('Pb', 0, 1.0, testval=0.1)\n", "\n", " # uniform prior on Yb, the centroid of the outlier distribution\n", " Yb = pm.Uniform('Yb', -10000, 10000, testval=0)\n", "\n", " # uniform prior on log(sigmab), the spread of the outlier distribution\n", " log_sigmab = pm.Uniform('log_sigmab', -10, 10, testval=5)\n", "\n", " inter = pm.Uniform('inter', -1000, 1000)\n", " theta = pm.Uniform('theta', -np.pi / 2, np.pi / 2)\n", "\n", " # qi is bernoulli 
distributed\n", " qi = pm.Bernoulli('qi', p=1 - Pb, shape=size)\n", "\n", " y_outlier = pm.DensityDist('outliernormal', logp=outlier_likelihood,\n", " observed={'yi': yi, 'xi': xi})\n", "\n", " trace2 = pm.sample(draws=10000, tune=2000)\n", "\n", "\n", "# ------------------------------------------------------------\n", "# plot the data\n", "fig = plt.figure(figsize=(5, 5))\n", "fig.subplots_adjust(left=0.1, right=0.95, wspace=0.25,\n", " bottom=0.1, top=0.95, hspace=0.2)\n", "\n", "# first axes: plot the data\n", "ax1 = fig.add_subplot(221)\n", "ax1.errorbar(data['x'], data['y'], data['sigma_y'], fmt='.k', ecolor='gray', lw=1)\n", "ax1.set_xlabel('x')\n", "ax1.set_ylabel('y')\n", "\n", "#------------------------------------------------------------\n", "# Go through models; compute and plot likelihoods\n", "linestyles = [':', '--', '-']\n", "labels = ['no outlier correction\\n(dotted fit)',\n", " 'mixture model\\n(dashed fit)',\n", " 'outlier rejection\\n(solid fit)']\n", "\n", "x = np.linspace(0, 350, 10)\n", "\n", "bins = [(np.linspace(140, 300, 51), np.linspace(0.6, 1.6, 51)),\n", " (np.linspace(-40, 120, 51), np.linspace(1.8, 2.8, 51)),\n", " (np.linspace(-40, 120, 51), np.linspace(1.8, 2.8, 51))]\n", "\n", "for i, trace in enumerate([trace0, trace1, trace2]):\n", " H2D, bins1, bins2 = np.histogram2d(np.tan(trace['theta']),\n", " trace['inter'], bins=50)\n", " w = np.where(H2D == H2D.max())\n", "\n", " # choose the maximum posterior slope and intercept\n", " slope_best = bins1[w[0][0]]\n", " intercept_best = bins2[w[1][0]]\n", "\n", " # plot the best-fit line\n", " ax1.plot(x, intercept_best + slope_best * x, linestyles[i], c='k')\n", "\n", " # For the model which identifies bad points,\n", " # plot circles around points identified as outliers.\n", " if i == 2:\n", " Pi = trace['qi'].mean(0)\n", " outlier_x = data['x'][Pi < 0.32]\n", " outlier_y = data['y'][Pi < 0.32]\n", " ax1.scatter(outlier_x, outlier_y, lw=1, s=400, alpha=0.5,\n", " facecolors='none', edgecolors='red')\n", "\n", " # plot the likelihood contours\n", " ax = plt.subplot(222 + i)\n", "\n", " H, xbins, ybins = np.histogram2d(trace['inter'],\n", " np.tan(trace['theta']), bins=bins[i])\n", " H[H == 0] = 1E-16\n", " Nsigma = convert_to_stdev(np.log(H))\n", "\n", " ax.contour(0.5 * (xbins[1:] + xbins[:-1]),\n", " 0.5 * (ybins[1:] + ybins[:-1]),\n", " Nsigma.T, levels=[0.683, 0.955], colors='black')\n", "\n", " ax.set_xlabel('intercept')\n", " ax.set_ylabel('slope')\n", " ax.grid(color='gray')\n", " ax.xaxis.set_major_locator(plt.MultipleLocator(40))\n", " ax.yaxis.set_major_locator(plt.MultipleLocator(0.2))\n", "\n", " ax.text(0.96, 0.96, labels[i], ha='right', va='top',\n", " bbox=dict(fc='w', ec='none', alpha=0.5),\n", " transform=ax.transAxes)\n", " ax.set_xlim(bins[i][0][0], bins[i][0][-1])\n", " ax.set_ylim(bins[i][1][0], bins[i][1][-1])\n", "\n", "ax1.set_xlim(0, 350)\n", "ax1.set_ylim(100, 700)\n", "\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "sIpyOth4ykBD" }, "source": [ "## Gaussian mixture model to classify sources\n", "\n", "Here we will put together an example of using a Gaussian mixture model to decompose observed colour-colour diagrams into different classes. To start with, we will grab some data from by extracting a subset of 1234 sources selected from Gaia DR2 with crossmatches in 2MASS." 
] }, { "cell_type": "code", "metadata": { "id": "ZgK3i5umPbgT" }, "source": [ "from astropy.table import Table\n", "\n", "url = 'https://raw.githubusercontent.com/sundarjhu/DAWGI_Lectures_2021/main/'\n", "file = 'Gaia_2MASS_phot_10000sample_data.vot'\n", "table = Table.read(url + file, format = 'votable')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "dueR6DzYPbyS" }, "source": [ "Now that we have this dataset and have extracted some colours, we can decompose this into classes. For starters, we use scikit-learn's implementation of Gaussian mixture models." ] }, { "cell_type": "code", "metadata": { "id": "KgLkj2kPzEE7" }, "source": [ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "\n", "from sklearn.mixture import GaussianMixture as GM\n", "\n", "\n", "\n", "table.pprint()\n", "print(table.colnames)\n", "\n", "data1 = table['bp_rp'].data\n", "d1_unc = np.sqrt(table['phot_bp_mean_mag_err'].data**2 + table['phot_rp_mean_mag_err'].data**2)\n", "data2 = table['tmass_j_mag'].data - table['tmass_ks_mag'].data\n", "d2_unc = np.sqrt(table['tmass_ks_mag_err'].data**2 + table['tmass_j_mag_err'].data**2)\n", "\n", "print(np.nanmin(data1), np.nanmax(data1))\n", "print(np.nanmin(data2), np.nanmax(data2))\n", "data = np.vstack((data1, data2)).T\n", "print(data.shape)\n", "mask =np.logical_and( np.isfinite(data1), np.isfinite(data2))\n", "gm = GM(n_components=2).fit(data[mask, :])\n", "\n", "from scipy import linalg\n", "import matplotlib as mpl\n", "import itertools\n", "color_iter = itertools.cycle(['navy', 'c', 'cornflowerblue', 'gold',\n", " 'darkorange'])\n", "\n", "def plot_results(X, Y_, means, covariances, index, title):\n", " splot = plt.subplot(1, 1, 1)\n", " for i, (mean, covar, color) in enumerate(zip(\n", " means, covariances, color_iter)):\n", " v, w = linalg.eigh(covar)\n", " v = 2. * np.sqrt(2.) * np.sqrt(v)\n", " u = w[0] / linalg.norm(w[0])\n", " # as the DP will not use every component it has access to\n", " # unless it needs it, we shouldn't plot the redundant\n", " # components.\n", " if not np.any(Y_ == i):\n", " continue\n", " plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], .8, color=color)\n", "\n", " # Plot an ellipse to show the Gaussian component\n", " angle = np.arctan(u[1] / u[0])\n", " angle = 180. * angle / np.pi # convert to degrees\n", " ell = mpl.patches.Ellipse(mean, v[0], v[1], 180. + angle, color=color)\n", " ell.set_clip_box(splot.bbox)\n", " ell.set_alpha(0.5)\n", " splot.add_artist(ell)\n", "\n", " plt.xlim(0., 4.)\n", " plt.ylim(-1., 3.)\n", " plt.xlabel(\"BP - RP\")\n", " plt.ylabel(\"J - K_s\")\n", " #plt.xticks(())\n", " #plt.yticks(())\n", " plt.title(title)\n", "\n", "plot_results(data[mask], gm.predict(data[mask]), gm.means_, gm.covariances_, 0,\n", " 'Gaussian Mixture')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "WzdIycdxzGXm" }, "source": [ "We only tried to identify two classes, and it seems to reproduce the structure in this plot reasonably well. We have only compared two colours - [Bp - Rp] and [J - Ks]. 
However, there's nothing stopping us from turning this into an N-colour decomposition of the sources, using all 5 colours in this cross-match.\n", "\n" ] }, { "cell_type": "code", "metadata": { "id": "xSnpkVTL7DNK", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "d9465dd5-ba73-4850-c2b9-f5e76753ea8c" }, "source": [ "data1 = table['bp_rp'].data\n", "d1_unc = np.sqrt(table['phot_bp_mean_mag_err'].data**2 + table['phot_rp_mean_mag_err'].data**2)\n", "data2 = table['phot_g_mean_mag'].data - table['phot_rp_mean_mag'].data\n", "d2_unc = np.sqrt(table['phot_g_mean_mag_err'].data**2 + table['phot_rp_mean_mag_err'].data**2)\n", "data3 = table['phot_rp_mean_mag'].data - table['tmass_j_mag'].data\n", "d3_unc = np.sqrt(table['phot_rp_mean_mag_err'].data**2 + table['tmass_j_mag_err'].data**2)\n", "data4 = table['tmass_j_mag'].data - table['tmass_h_mag'].data\n", "d4_unc = np.sqrt(table['tmass_h_mag_err'].data**2 + table['tmass_j_mag_err'].data**2)\n", "data5 = table['tmass_h_mag'].data - table['tmass_ks_mag'].data\n", "d5_unc = np.sqrt(table['tmass_ks_mag_err'].data**2 + table['tmass_h_mag_err'].data**2)\n", "\n", "print(np.nanmin(data1), np.nanmax(data1))\n", "print(np.nanmin(data2), np.nanmax(data2))\n", "data = np.vstack((data1, data2, data3, data4, data5)).T\n", "print(data.shape)\n", "\n", "mask=np.logical_and(np.isfinite(data5), np.logical_and(np.isfinite(data4), np.logical_and(np.isfinite(data3), np.logical_and(np.isfinite(data1), np.isfinite(data2)))))\n", "gm = GM(n_components=2).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))\n", "gm = GM(n_components=3).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))\n", "gm = GM(n_components=4).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))\n", "gm = GM(n_components=5).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))\n", "gm = GM(n_components=6).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))\n", "gm = GM(n_components=7).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))\n", "gm = GM(n_components=8).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))\n", "gm = GM(n_components=9).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))\n", "gm = GM(n_components=10).fit(data[mask, :])\n", "print(gm.bic(data[mask,:]))" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "0.04345131 4.3446026\n", "0.07635212 1.6862373\n", "(1234, 5)\n", "-11121.982069612812\n", "-11851.11774461043\n", "-12248.949883009822\n", "-12243.698226889865\n", "-12318.92551412302\n", "-12344.832753915784\n", "-12604.106463966884\n", "-12446.384596799428\n", "-12443.159585647332\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 284 }, "id": "27I8ASKo-o1X", "outputId": "679448a9-ffab-46d1-ab16-8fa61f5df9bf" }, "source": [ "gm = GM(n_components=8).fit(data[mask, :])\n", "plot_results(data[mask], gm.predict(data[mask]), gm.means_, gm.covariances_, 0,\n", " 'Gaussian Mixture')" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYYAAAELCAYAAADdriHjAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxcdb3/8dcnkz1t0iUpbekiXSlt6TZtKdByEShShQsPAasoilcjigriVXG5V+RxRS8uF64Ll+hPFB7VCqKI7Mq+Q8rSpCwtBVu60CTd0jTTySTz+f0xk3ROyNKk6WTSvp+PRx/JnO+cM5+ctud9vt/vOSfm7oiIiLTK6u8CREQksygYREQkQMEgIiIBCgYREQlQMIiISICCQUREAhQMIiISoGCQAcnMLjGzL5vZv5nZNWb2+T7e/gfM7ORernuOmTWb2aKUZcvMbKuZLTGzkWb2xS7Wf7Q3nyvSV7L7uwCRnjKzU4Dp7v7vKcvmJb/+B/AOMAP4AfA+4CZ3D5vZfwJZ7n61mf0UWAOcCFwKfBvYBMwGfgicDDQDT5rZdUA1cHzyfacC1wM3AGcC17r78621uPtdZvZ74AdmdiqJ/2fHAbvd/XEzOx0438zuBH4F/A04GtgAvA6MMrPPAC8At7n7VDO7GHh/staVQC2wGYgD24EIUAqsd/c7+mA3yxFMPQYZiOaTOFBjZoVm9ingkmTbP4E8YAhwhruvAhqSbY+nbONYEgfs60gEwASgBLgR2Ab8I+W9G4EcYBIwx93vT27z/4BfAmd0UONbwDPAx4CLgd+1Nrj7P5JfNwEfBc4HNrl7hbs/Dmx191+7+yvA1tTa3X0fcCdQ5e5XAz8BLiMREHXAgq52nMiBUDDIQPQCMAvA3Rvd/bfA6WY2DPicu99E4qCc3269wpTvPw2sBX5DIhT+HXgC+DFwQuubzGwOsMjd/x9QlbLNBk88T6aJRGh05PvA5YkyvbaT9zQCLwMnpSxzS8gCYmaW3a52gPqU7/e5+2/d/TfAg518jsgB01CSDDju/piZTTSzK0icuR8F/B7YBdSY2WXAYqDRzH4HPGdmXwLKgNnJAPk+ibPwl0kMxVwNvEQiLDaTCI7jgZuBIWZWDiwCWsxsD4nhnrnAEmCumRW7ez2AmZ0FzAWKgIuAfyaXlZjZSUBBcv1TSPR0fg78R3J46+vATuBaEqF1F4nhqx3AZDMbn/xMzOxud68zs9+Y2ddIDCe92qc7W45IpofoiYhIKg0liYhIgIJBREQCFAwiIhKQ1slnMxsJXEBicm0Z8AN3r0q2zUm21QLb3f2WdNYmIiIJaZ98NrM8EtddzwE+4+7R5PJ7gYuTV1k8C5zk7i1pLU5ERNJ/uWoyCH5qZt8DPgS03qU5zt3rkt/vJXEX57bW9fLz8z0UCrVtp7S0lLKysvQU3QO1tbUZWVd7qrPvDIQaQXX2tUyus7a2lrq6xOG0sbEx6u7t7+npUrqHkk4GXk8GwGZgbErzBjMrTbYVkbiLs00oFGLv3r3pK7aXwuEwlZWV/V1Gt1Rn3xkINYLq7GsDpU4z6/HIS7p7DLnAd82sksTNQleb2V3ufg7wHeBKM6sDbtQwkohI/0hrMLj7w8DDyZetz445J9n2Eok7T0VEpB8NmMtVS0tL+7uEA1JeXt7fJRwQ1dl3BkKNoDr72kCpk8SVnj0yYB6JEQ6HfSCM54mIZBIzW+Xu4Z6sM2B6DCIikh4KBhERCVAwiIhIgIJBREQCFAwiIhKgYBARkQAFg4iIBCgYREQkQMEgIiIBCgYREQlQMIiISICCQUREAhQMIiISoGAQEZEABYOIiAQoGEREJEDBICIiAQoGEREJUDCIiEhAdro+yMzOAk4G3gYWApe5e1Oy7avA5ORb/+7ud6SrLhERCUpbMABVJA76zWZ2JjAeWJdsiwHPA4OA1WmsSURE2klbMLj7JgAzmwK86e7rUppvcfddZlYIPAQsar9+bW0t4XC47XV5eTnl5eWHuGoRkYGnoqKCioqK1pelPV3f3L1vK+rqw8wWATOAlUCBu9ckl4fdvTL5/Rp3n95+3XA47JWVlWmrVUTkcGBmq9w93P0790v3HMN1wFPA6cBfzWy5u58DnG9mYWA48JN01SQiIu+VzqGk+4D72i3+fbLtqnTVISIiXdPlqiIiEqBgEBGRAAWDiIgEKBhERCRAwSAiIgEKBhERCVAwiIhIgIJBREQCFAwiIhKgYBARkQAFg4iIBCgYREQkQMEgIiIBCgYREQlQMIiISICCQUREAhQMIiISoGAQEZEABYOIiAQoGEREJEDBICIiAQoGEREJyE7XB5nZWcDJwNvAQuAyd29Kts0BLgBqge3ufku66hIRkaC0BQNQBfzd3ZvN7ExgPLAu2fZ94GJ3rzOzZ81shbu3pK5cW1tLOBxue11eXk55eXm6ahcRGTAqKiqoqKhofVna0/XN3fu2ou4+0GwKcIm7fzNlWbW7z0h+/xDwMXfflrpeOBz2ysrKtNYqIjLQmdkqdw93/8790jrHYGaLgFOAa81sRErTBjNrTbUioC6ddYmIyH5pC4bkHEMFMA/4NXC6md2VbP4OcKWZXQnc2H4YSURE0idtcwzufh9wX7vFv0+2vQS8lK5aRESkc7pcVUREAhQMIiISoGAQEZEABYOIiAQoGEREJEDBICIiAQoGEREJUDCIiEiAgkFERAIUDCIiEqBgEBGRAAWDiIgEKBhERCRAwSAiIgEKBhERCVAwiIhIgIJBREQCFAwiIhKgYBARkQAFg4iIBKQ9GMzsU2a2Md2fKyLpE2lyHl+zj0iT93cp0gvZ6fwwM8sBngAaO2j7KjA5+fLv7n5HOmsTkb7z1OtR/vhkI00tcPrx+f1djvRQWoPB3WPAejPrqDkGPA8MAla3b6ytrSUcDre9Li8vp7y8/BBVKiIHxT34VdKqoqKCioqK1pelPV3fvB/+4szsdXc/tt2yIe6+y8wKgYfcfVFqezgc9srKyrTWKSK9E2lyXlgXZf7kPApy33si2F279B0zW+Xu4e7fuV8mTT5PAnD3RqC4n2sRkYNQkGssmZ7f6UH/hXVRbn2skRfWRbvdVn19lIqKVdTXd/9e6RtpHUoCMLOPAiVm9mlgJbDS3c8BzjezMDAc+Em66xKR9Jk/OS/wtSsrV1bzuc/dDUB5+bxDWpck9MtQUm9oKEkkPTJtmKe+PsrKldUsXz6D4uLug6Q7mfbzHWoDfShJRDJAT4Z5eqOnQ0PFxXmUl8/rdSjU10e54YbnuOGGZ6mvj3b482m4KijtQ0kiknlSz6IPZJjnYM7iW4eGIpEYBQU5fdYT6OrzrrjifgAKCnL4xKfmAsGfT8NVQQoGEWk7iwZYMj2fJdO7vvfgYA6ky5fPACASaU7LwXj58hlEIs2As3z5jLaJ8Y5qav3a18NXvdVvw17uPiD+zJs3z0Xk0GiMxv2x6oi/Wxvxm26q9N2793X5/t279x30+w50G52t9847u7tcv7fbd3e/6aZKh6v9ppsqe7xuX3
qsOuKf+cV2f6w60uttAJXew+Otegwi/awvzgoPdhutZ9EVFau6PYtPPZsG+PkvK5k4bzJzJ+bx1z+vaTvLrq+P8oUv3MOKFVVEIjEA9u1rBiA/P4dLLpnd9hkdnaF3dtbe2lu56KKZrFhR1WmtN9/8MldccT+RSDOXX76wR/ujfQ+iv/Tk6q2+pGAQ6SPuTjQGjU3OvibHcbLMyDISf7LAUl7nZBv5Oe8dxumNvtgGvPeAuHdvE6+/Xscf/7iGGTNGUFKSx9NPv8N11z3dts6XLruHkz96KseMyObWG/4OJA7UK1dWs2JFFRdeeBxPP72J225bE6z5hc388pcfBGgLkNZ1YX8APP74hrb3rVxZzbJlk7nppg+xbNlkliwZ38XB29t9PXCtE979raNhr3RQMIh0o7Oz8cZonJrdcTZvb6amPk59Y5y4Q+s7jOQhyZLLWp8SkWxwIJQFhblZLJycQ1G+samumZKiLIryjayOHx3ToY7OLHsyTt763mXLJtPYGOO55zbx8svbeO21Wt54o46nn97ESSeNZerU4UQizSxePI4lS8YzevRgmpodH30MletjfPGq04hEmqmvj6bMJcS44ooHuPDC45g7dxQAjzzyNitWVDF58nBGjRrEihVVLFhwNMuWTW6rafnyGTz++AZWrKhiyZLxge386lfndHvwvuSSOW2T29IzCgaRbqSejZ88LY936ppZvSFGXX0cs8TBPS/HKCkwsrJ6NowTjztNLU5BrrF6QwwjMeQSyoLhg0OMLQ1x9PBshhYZnTxjDOj4zLInE8St7124cDTPPbeFCROGcPrpExgzppijjiqioCCHrVv3UFKSz9ixeWRlGX/4QxVf/eqJfPELYSJNzqx1UVY/ls+XLruHgoJsysvnUV4+j/r6aOAAvXJlNdu27QXgzjtfY/nymcybN4rnn9/MHXe8RkFBdluY/fKXH2zrFdx888sA3Hbbq5x22oRuf6ZMOesfiBQMIt1oPQufOCqHO5+PsKshTmEeDBvU9cG6M7EW2LKjmdHDsskJGflZkJ+zfzvRaAtVVTX41DJqdrVQ+WYThfnGpJE5jC0NUTo464ACaNmyyVx00czAWXhq7ye2r4mVK6tZunQiDQ1NLFo0hmOOGUJDQ4w1a2rZsmUPo0YNBuC11+rYsmUPu3ZF+fCHj2Xjxt2sWFHFsGEFbQf9JdPzmVAyhWef3sjixeOoqFj1nrP11gC6+upTyM3NZtmySVx11T8AuOiimYAHwiz14H7JJbNJ9LNMvYBDTMEg0o2CXGPq0Tk88PI+srOc0uKDuy90y45mVm9ITMKOLwv+F4xGW7jn3nVUrd7GhwzmzUsMvURjTtWGJlZvgNxsY/rYHCaPyqYov/Na7r13XdswTOvB9dGX6rn2hlf41uWzeOOZV/nKVx5kwYLR5OaGWLDgaNav38HSpRMZM6aYWCxONNpMdXUNW7bsoaw0xLCiddx/H/zox2cxblwJu3ZFueKKB4hEYlx++QncccerrFhRRSzWwm23vdpWS+vBPnUO47vf/Rfq66MMGZIPWPLAT6fDP8XFeVx++Qm93OvSEwoGOShHwuMFWuLOo9X7yA1BYd7BPyxg9LDswNdU1dU1VK3exszjj2LGjBFty/NyjLxkryLW7Lz0dhPPvxGh5u13+eR54xg7siCwnfr6KJFIjOuv/0DgILt+1Tqe/MMjVETfZcPbOwHYtSvK2rXb2b07SlVVDTNn7mbEiCIeeuhtVq/exnnnHcsHPziJIYUbGJq7ga21dfz61yO5++51zJzZWqMFvs6dO4rTTpsQ+OzW4aHU4Z2ODvYa/ul/XQaDmf0b8GfgB8CL7l7R1fvlyNNXV8Nksp0NcSJNzrBBffMEmZzQe3sKrVrDYMaMEeTlhTpeP9sYPsh45rmtPPHMNnY357JsSRmvP7uWiz46neLivOTdvg9w/fVncvPNL7Nv3/7LRWfNOoo7//wa2aFmPnR6hMkz5jBlyjDGTRrKthKoWlXHoJwYuYUtjByylof+AUOHlbD45AUUxDbyuzuHMnFyPQBVVTUsWHA0H/7wNCAx3JM6R9BKB/uBpbseQyEwA3gbODz/18tB6a/rrNMpNzt9PaG8vBDz5o0iGm1h1aqtbQERjbZQXV0TCAwDIrv3su2d7fztkTjPPLWLhx6/n/nHDweM668/E7C2x0G0d/riRuZNfZEX6/IZN+kkHt1UR83EAsgaSsO4IuYN3sz0ITWQV8ottzUwdmwJt98+innzRnHjjR/kxz9+hlismb/85Q3uvXfde+YEZODqLhhGAJcCVwLnHfpyZKDpr+us02lQvlGYZzRGncK89IREdXUNd9+9FkjMM6S+njx5GA8+uJ4hQwsYNXowa9bUMmrnPnCnxobzo1+sZvs7tVz7X2HmTHyM8z49nTeXjuf0t2Dd05t5440drFu3A/KO48XdBTwQGsfwlzaydUsDTB0EGyMU1sS44DNLOf/94yl93/mcdNo77NwZ4fbbX2XVqq289NK7/PGP51NfH+UDH6jWZPBhprtguB8YCsSByKEvRyTzZGUZS47L554XI+RmQ3ao43Do6Ky+veAVSZ1/ZuqQEiTCYPr0EexpaGLlH9ewdcseAHJzmznlhN0882IzQ0YMZfugEOGzw7xy18PMG/1dls5exa6hiyjP/zb37YuyuGAlj26ewYknTqK0bChN8SGM2bSD48uGUJ9fSHG4jEhoB7/4yQeYOHEYcBoA5eUjqK+Pkp+fTepVQeohHJ66DAZ3fyrl5e8AzGyZu997SKsSyTAjh4ZYNCWXZ95ooriAtongVO3P8jvS3RVJ1dU1TJ1WxtadLbT4/uX3PfQWr8ci8FQdAGUnlDAt521CxY0sed+bjCwr5YnCpbTk5zLOipgx51hOW5i47n/5+Gd4ePfDkLWHimvu5iMfmcH8Mz7OoEG5NDQ0EYu1UFycR05XSYWuCjqS9OaqpMPz0hORbhxzVA5rN8fYsdeJNEFJIYH7GNqf5b+HRxlTVE3zqCls27CTkcXBnkV1dQ0P/P1V9hYYz+waybrBezn1nRpOH/QUuybPYe7QPWyYPZoZ+VuJ5e7m1LzV7I4nrkY6dlwd66JvsyU2lS05MdYXTuXmjVex9k/PcsZ5s2gaewGzWiJESmZx2rmfgFBiTqg/nxwqmas3wTAwfuWbSB97YV2Ue16Mcv6iArIMNta1UJDrFOUlbnRrnTjuVHQNoca72bHxDO57MAbezIK5Q2jZex9byKF22kzyx7fwRG6UhqO3URYPMal5NUdZLR8f8RhFWVFeGTKBWTlv8Y5PAmCf51FChObQCEqL5jA9BqeUlLBuVxPhad8gvmAtCxfP4IziPDge4NS07CsZ2A44GMwsz92jqMcgR6jUK7Dyc2DLzhZeXN9EbX2ceEszG9+sY+bMTuYXPAoeI1qwlMiYJt5/+ga2jy7gyT2Pssby2QvsjK2nJDfCsBa40F9mQnwPj8VmsTN7F0OzGnglNoFX4gspCh3D2MEziTSt5ZW9JfxLbhW5RUtZFErcpRxrdo4blM200UXM1
vi/9EKnwWBmZ7v731IWfRn4kbvfc+jLEsk8BbnG7LHGrb99keXLZzA4G7a8+BpLzz6OFXf+k0ef3kkTMPO4oRT5amK0UB0tYkbRcLL33s6mWB5VLdOJ59Xw8tRjyM95lxYrIZs4+d7C4vgmQk1FvLxvJm9m5/FuKMRrLZPIqp3Fwgk1xPMm8pGioeRlTUkUlD2fpYUAU9pqbGp2dkecJdPyOpwHETkQXfUYvmRmD7h7k5nNA84AfnSwH2hmnwKucfdxB7stkc4cqjuy2x5MF49A5GkuvTLCTY0P8eVz4xy1628sXngir+49gRe3tTDUGgiRx66mZxmc3UBNTi5v5u2gISsfYw95xBgVb6Q4lk28aTD5WWPZlnUcs4qL2NU8gpOGDmVxdjaMBXgfc7qoq7klEQihLHj/zDyOGZHTZz+zHHm6CoZPApea2Z+A/wHOPdgPM7Mc4AmgsYO2OcAFQC2w3d1vOdjPk8NbVwf/g74ju6Ue9qyEwcsT39ddBcO+zUWn/56jfn8Mp55YCbX/wbzbhzB7Wg2haCGXnNtIdtbTtOSO4+zh29gWmcCG+tncs/cUHgvNZ6flUexNjI01MDheyKysBuKM4onYJCYUDWFhcTF5WQd2d3U87kSbE/sAEk9jnXtMLlPH5AQeyCfSG10Fw3TgYeBPJEJiLvCPg/kwd48B6zt5IuX3gYvdvc7MnjWzFe7ecjCfJ4e31IP//Ml5gZDo9I7s1AN+qLjz5XtWwrufS7Q1Pg71KyD6OkVNq/jXuUW0RHLxwibmzahJbqCR7OQxfXrORgDGDHqN0sK1DG9+gqOjC2jKOpXjQkN4uS7E6PwTiMcKiDucmyxxbyPsJQ4kr/Dw/Vd6GJD638YMhhZlMXV0DqOHhxg+OItQDx/5LdKZroLhFmBX8j13AcXA0YewlnHuXpf8fi9QCmxrbaytrSUcDre9uby8nPLy8kNYjvS37oaDUg/+7XsInd6RnXrAH5L899NSD9u+kDj4ty7PngwtQ4i8tYLq/AJmZRUQir5MCMD3EmIvdHBy/8y++az2afxrwZMU0MgD+ddzdvEWLsvKgpKLIVTMmcn3ujuRJmdv1Ik1Q9yduIPHIe77X8fjiSDIzzUKc42C5J+e/u4HOXJUVFRQUdH2aLvSnq5vid8V3UGD2UmpN7i1f30wzOx1dz+23bJ7gE+29hiAk1J7DOFw2CsrK/vi42WAeHzNPm59rJFPnFLYdpDvLCwOeE6hfc+gaRO8cybEXgUrg5zjiOzcRiTvHYbYXrK6vucLgGZCeGgC78QuZNSEr1GQX3KwP7pInzGzVe4e7v6d+3XaY2gfAn0YCh8FSszs08BKYKW7nwN8B7jSzOqAGzWMJB0NB3U0fDRjfC7VG5oCofCeoPBm8JZEMNT/Dbb/HIrOIF53C1nxOsgFvBaaHqOgCAreU03izD1GCLKGUZ19IccNHklBVpTsIV+GnDImHOodIpImnfYYMo16DALBA35rSCycnMtz65r29yxa6ln/xvXs3PEIOwblc2rxTnJaNiVCoTECO5thOLARmECnp0dxYF98EM/ET+GUnKfILv4Y5E2BkkuC8xMiGaxPewwpGy129/relyVyaLT2JGaMz2Xa6Ebmj7gJahtp3nU/ExseZ/eQAl6PjyOn6Y39KxUm/wBMBFKGiuJxaLsoKOdYssb+ncLcMcnHyIkcOQ7kzuf/InFzm0i/C0wyH9vECSNuZtPGWk7wvxBa9SoUQnYeUAQl2RFmN6+jeTdkFwKpl/bHSYRCHCg6DQafTVbJJYm2jq5aEjmCHEgw7DrkVYik6uiS0uSy+cecxUi7i7HF0+DX55DbsDcxtj+W5LOAgvKy49AEvJJszwlBXgGULAfbDvkLYejngyEwRFe7yZGt22Bw9/9MRyEibZKXlK7dEmPspC9QENoDmz8J/7yTghdhSkeXJWwlMWPcenwvAnYCJTnASfDxW6FwTJp+AJGBrTdPVxU5eJ3daAZQuIza6DnE195K9j/+HQbvS5z1lwDDSNwb314zsJbEvQXjh8OIPDjuTiicf4h/EJHDT1cP0fskcF+7xRF333NoS5Ijwu6boeaKxDOHSi6BnTdCw3M0rR1MqOpWysY4ZXEglvxjya87OthWVjZMXQhzz4ayz2tuQOQgddVj+CaJR2Gk3i2UY2ZV7v75Q1uWHBEiwD3XQs1XaH34Q25r21Ygj8S/vgJgO9B6SpJdCNMWw8m/1vCQyCHQVTBc6u6Ptl9oZl88dOXIYa2lHrbeCM/cD5ufghaAmk7ey/5HLe4dDcfPhRNuVBCIpEFXdz4/2n6ZmQ0C3eApB6h1HqFhMtx9EUR3QHP0AFbMgpxCOO2zcOzVGhoSSbMeTT67ewNw5SGqRQ4X0Xp46X/g1Rtg584DW6fwaJj1WZj3FchTEIj0J12VJD3T2dVE0Xp45ltQ/WtozoKWSDcbMph+McydA8P1iAmRTKJgkJ5pfWx1PA6hD8DuN+HPZ0PLvi5WMggVwbELYfYZsOsqGPl/upFMJEMpGOTAtNTD9pvh1bfg1WKIfCHZ0MFDGEMFkF0ALU0w6gQ462YYPGb/dvKHJnocIpKRFAzStWg9PP9D4q/8jKxoQ9fvzcqH2Z+HE6/ufJ4gVKyegkiGUzDIe0Xr4ZUbYdOjULcK9tR29MvKEiwX5lzWdRiIyICiYJD9tr8G938KRp2UuKqonTgQH30K2RPPgmkXwdv3wtTlCgSRw4yC4Ui2ZxM8+lXILYa1f0pMKDfXQ131/vcMLks8h8hCZJ17F1mjUp49dLyGhEQORwqGI0W0Ht5YmTjDh8RQ0Us/g4bNwfdlF0HzXhgxD6ZckJgzUI9A5IiiYDgSROvhoS/Aaytg0+Nw1Hx44qr97SUTILIDRs6D9/8MNj+hISKRI5iC4XDUvnfQGgojFyS/zofFP4TNT4LlwGn/u/9yUoDh0/qnbhHJCAqGw0ikYTcbH/kVJbUPM3LnfUTX3c87+Scy6fUVicnixT9sN2H8jf4uWUQyUFqDwcxygP8GVgOzga+5eyzZdi2JX8MCsLKjh/hJxyINu9n82K8YtONZptbc0bY8759/4YXs+cSn/5wpp34iEQaaMBaRbqS7x3AesMXdf2tmXwfOBW5PtjUCrwNDk1+lK9F6mtb8gRdC51G64XdMXfd1ALaUfZiGwqmU7K1m0IRFjCv+NGOnlUGudbNBEZGEdAfDBGBT8vutBB/h/XN332VmY4DfAMtSV6ytrSUcDre9Li8vp7z8yDr7jTQ5L6yLMn9yHgVvrCT3kUtZn9PI8AmJg/7zWf9KbNYvOWnWiLZ1TuqvYkWk31RUVFBRUdH6srSn66c7GNYD45PfjwLeSmmbBFQCdcBR7VcsKyujsrLykBeYyV5YF+XWxxK/vWbJ1OU0NTsTQ+cxfmIeTaMHEwudx9xpZf1cpYj0t9QTZzOr6+n65t7BQ9AOkQ7mGH4BfN3dP2tmNwHPAxOBp9z9ntR1w+GwHwnBUFffwpOvRTlzTgHA/h5CrgV7DBoaEpEDYGar3D3c/Tv3S2uPITnR3P4X/Xw2
2fa5dNaSqR6p2seDr0TJz4HCvKz9PYTp+RTkGkum5/dzhSJyuNPlqhkgtSewcGoea7c0M2FkDmNLE3898yfn9XOFInIkUTBkgNS5g8XH5fHJ9xcxoiREbrZ6CCKSfgqGDNDaI5g/OQ8zY8xw/bWISP/RESgDaO5ARDJJp79/RUREjkwKhj4WaXIeX7OPSFP6LgMWEelLCoY+1jqR/MK6aH+XIiLSK5pj6GOpE8kiIgORgqGPaSJZRAY6DSWJiEiAgkFERAIUDCIiEqBgEBGRgCM2GHS/gYhIx47YYND9BiIiHTtiL1fV/QYiIh07YoNB9xuIiHTsiB1KEhGRjh2WwaCJZRGR3jssg0ETyyIivXdYzjFoYllEpPfSGgxmlgP8N7AamA18zd1jybalwAISvZhn3f3B3n6OJpZFRHov3UNJ5wFb3P23wBbg3JS27wE/AK4Fri33usIAAAdTSURBVElzXSIikpTuoaQJwKbk91uTr1sNdvcWADMb3H7F2tpawuFw2+vy8nLKy8sPYakiIgNTRUUFFRUVrS9Le7p+uoNhPTA++f0o4K2Utj1mFgIMqG+/YllZGZWVlYe+QhGRAS71xNnM6nq6frqHku4ERpvZp4DRwMtm9qtk23eBbwLfAq5Oc10iIpKU1h5DcqL5ynaLP5tsexDo9YSziIj0jcPyPgYREek9BYOIiAQoGEREJEDBICIiAQoGEREJUDCIiEiAgkFERAIUDCIiEqBgEBGRAAWDiIgEKBhERCRAwSAiIgEKBhERCVAwiIhIgIJBREQCFAwiIhKgYBARkQAFg4iIBCgYREQkQMEgIiIBCgYREQnITtcHmdl44EvAOmCwu/+4XfufgLrky2vcfUu6ahMRkf3S2WP4BrDS3W8CTjGzke3aa4AngdXsDwgREUmzPu8xmNlRwIp2i98FSpNfIREC41NeA3zL3XeZ2ceBK4DrUjdQW1tLOBxue11eXk55eXkfVy8iMvBVVFRQUVHR+rK0p+ubu/dtRZ19kNkvgJvdvdLM/gZ81t3fTbYVAyPdfa2ZLQXOcvevpK4fDoe9srIyLbWKiBwuzGyVu4e7f+d+aZtjINED+JKZhYHH3P1dM/soUAjcBXzLzB4GTgB+3MV2RETkEEpbj+FgqccgItJzvekx6HJVEREJUDCIiEiAgkFERAIUDCIiEqBgEBGRAAWDiIgEKBhERCRAwSAiIgEKBhERCVAwiIhIgIJBREQCFAwiIhKgYBARkQAFg4iIBCgYREQkQMEgIiIBCgYREQlQMIiISICCQUREAhQMIiISMGCCoba2tr9LOCAVFRX9XcIBUZ19ZyDUCKqzrw2UOoHSnq6QtmAws0Fm9j0zu7c369fV1fV1SYfEQPnHojr7zkCoEVRnXxsodQJlPV0h+1BU0YkS4E/Awo4azexiYBhwFHCbu7+UxtpERCQpbcHg7pvNLKejNjMLAZe6+4lmNgK4Gfhg6nsaGxujZtaSsqgWyMRuRKmZZWJd7anOvjMQagTV2dcyuc5S9vcUQj1duc+DwcyOAla0W/yuu3+8i9XKgEYAd68xs/Ht3+Du+X1XpYiIdKbPg8HdtwGn93C1WqAQINlj2NDXdYmIyIFJ5+RzMfARYIyZnZVcdpKZXe3uLcD/mdkVwJXAd9JVl4iIBJm793cNcoiZ2aeAa9x9XH/X0pWBUqdIJjOzHwJD3P3S3m4jnVclHZDkBPV/A6uB2cDX3D2WbFsKLCDR03nW3R/M0DqvJXGFFcBKd3+0X4qkrc4nSM7htGubA1xAYihvu7vfkubyUmvpqs6vApOTL//u7neks7Z2tZwFnAy8TeIKu8vcvSnZlkn7s6s6M2J/mtlIEvtrJ7AM+IG7VyXbMmlfdlVnRuzLVmZ2LDAGaGi3vEf7M+OCATgP2OLuvzWzrwPnArcn275H4h+7AU8C/RYMdF1nI/A6MDT5td8kw2q9mXXU/H3gYnevM7NnzWxFclgv7bqpMwY8DwwiEcT9qYrEAaDZzM4ExgPrkm0Zsz+7qTMj9qe7v2tmFcBlQAuwNqU5Y/ZlN3VmxL5M8W/Ab4AL2y3v0f7MxGCYAGxKfr81+brV4NYfxswGp7uwdrqq8+fuvsvMxpD4S1qW7uIO0Dh3b73cbi+JS9y29WM9nbkluT8LgYeARf1ViLtvAjCzKcCb7r4upTlj9mc3dWbS/owCPzWz7wEfAlrPuDNmX0KXdWbMvjSzC4C/0vFxvUf7MxMfibEeGJn8fhTwVkrbHjMLmVk2UJ/2yoK6qnNS8msdiRv2MtUGM2u9Xb6IzLwvBJL7090bgeJ+rgUzWwScAlybvIquVUbtzy7qzIj9aWYnp+yvzcDYlOaM2Zfd1JkR+zJpCol6PgBMNbPTUtp6tD8zbvK5g7H7XwBfd/fPtptjeM7dH8jQOm8i0b2cCDzl7vf0V50AZvZR4KfAt4GVJOY9zkkZd6wjMe74u34ss6s6fwj8ExgObHX33/RjjWcB1wFPkRgq/CuwPNP2Zzd1ZsT+NLP3kxiSrQT+Bbga+FkG7suu6syIfdnKzIYB3wTmA98Fvtqb/ZlxwSAiIv0rE4eSRESkHykYREQkQMEgIiIBCgYREQlQMIiISICCQUREAhQMIrTdxLTVzD5jZt80s58ll5+VsvwaM/vyAWzrrHbbuq632xLpD7qPQSTJzF5392Mt8cCmV919WrvlhcAT7j6vh9uqcvcZvd2WSLpl4rOSRPpLKPno7ynAvSnLS5LLjwLKWxeaWRbwv0CLu1/eblslZvZpEr+i9svtlr9nWyKZRD0GkaTWs/nk9zcAf3P3f6Qu7+m2zGw+8BV3/1j7zxDJVJpjECExx0DibP4zZlZO4vdprE4+c6gk+fVAt9W6zqnu/gIQNbNv9GZbIv1BPQYREQlQj0FERAIUDCIiEqBgEBGRAAWDiIgEKBhERCRAwSAiIgEKBhERCVAwiIhIwP8H+DKASYJKR/4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "tags": [], "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": "6N-h6NFP7DaE" }, "source": [ "So it looks like **8** components are required to optimally explain this dataset\n", "\n", "However, although sklearn is simple, it doesn't support uncertainty properly. We can try the same things with pyGMMis to use uncertainty on the colours. This is left as an exercise for the reader" ] }, { "cell_type": "code", "metadata": { "id": "-bN15QQBOrdG" }, "source": [ "try:\n", " import pygmmis\n", "except ImportError:\n", " !pip install pygmmis\n", " import pygmmis\n", "\n", "gmm = pygmmis.GMM(K=2, D=5) # K components, D dimensions" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "d_yss1XomaeY" }, "source": [ "\"\"\" Copyright 2021 Peter Scicluna\n", "\n", "Permission is hereby granted, free of charge, to any person obtaining a copy of \n", "this software and associated documentation files (the \"Software\"), to deal in \n", "the Software without restriction, including without limitation the rights to \n", "use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of \n", "the Software, and to permit persons to whom the Software is furnished to do so, \n", "subject to the following conditions:\n", "\n", "The above copyright notice and this permission notice shall be included in all \n", "copies or substantial portions of the Software.\n", "\n", "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR \n", "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS \n", "FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR \n", "COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER \n", "IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN \n", "CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\"\"\"" ], "execution_count": null, "outputs": [] } ] }