{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 9. Mixture Models and EM" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from sklearn.datasets import fetch_openml\n", "%matplotlib inline\n", "\n", "from prml.clustering import KMeans\n", "from prml.rv import (\n", " MultivariateGaussianMixture,\n", " BernoulliMixture\n", ")\n", "\n", "np.random.seed(2222)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9.1 K-means Clustering" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Training data: three 2-D Gaussian blobs (unit variance, 100 points each),\n", "# one per entry of blob_centers. Building them in a single expression avoids\n", "# binding x1 to a blob and then rebinding it to a grid axis below.\n", "blob_centers = [(-5, -5), (5, -5), (0, 5)]\n", "x_train = np.vstack([\n", "    np.random.normal(size=(100, 2)) + np.array(center)\n", "    for center in blob_centers\n", "])\n", "\n", "# Evaluation grid: 100 x 100 points covering [-10, 10] x [-10, 10], flattened\n", "# to shape (10000, 2) so a model can label every grid point in one call.\n", "x0, x1 = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))\n", "x = np.array([x0, x1]).reshape(2, -1).T" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Fit K-means with three clusters and label every training point.\n", "kmeans = KMeans(n_clusters=3)\n", "kmeans.fit(x_train)\n", "cluster = kmeans.predict(x_train)\n", "\n", "# Draw the labelled points, the learned centers, and the cluster regions\n", "# obtained by labelling every point of the evaluation grid.\n", "ax = plt.gca()\n", "ax.scatter(x_train[:, 0], x_train[:, 1], c=cluster)\n", "ax.scatter(\n", "    kmeans.centers[:, 0], kmeans.centers[:, 1],\n", "    s=200, marker='X', lw=2, c=['purple', 'cyan', 'yellow'], edgecolor=\"white\",\n", ")\n", "ax.contourf(x0, x1, kmeans.predict(x).reshape(100, 100), alpha=0.1)\n", "ax.set_xlim(-10, 10)\n", "ax.set_ylim(-10, 10)\n", "ax.set_aspect('equal', adjustable='box')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## 9.2 Mixture of Gaussians" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "gmm = MultivariateGaussianMixture(n_components=3)\n", "gmm.fit(x_train)\n", "p = gmm.classify_proba(x_train)\n", "\n", "plt.scatter(x_train[:, 0], x_train[:, 1], c=p)\n", "plt.scatter(gmm.mu[:, 0], gmm.mu[:, 1], s=200, marker='X', lw=2, c=['red', 'green', 'blue'], edgecolor=\"white\")\n", "plt.xlim(-10, 10)\n", "plt.ylim(-10, 10)\n", "plt.gca().set_aspect(\"equal\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### 9.3.3 Mixtures of Bernoulli distributions" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Load MNIST as NumPy arrays; the labels in y are compared as strings below.\n", "x, y = fetch_openml(\"mnist_784\", return_X_y=True, as_frame=False)\n", "\n", "# Draw 200 images (np.random.choice samples with replacement) of each digit 0-4.\n", "x_train = []\n", "for i in [0, 1, 2, 3, 4]:\n", "    x_train.append(x[np.random.choice(np.where(y == str(i))[0], 200)])\n", "x_train = np.concatenate(x_train, axis=0)\n", "\n", "# Binarize the pixels for the Bernoulli mixture. Use the builtin float:\n", "# the np.float alias was deprecated in NumPy 1.20 and later removed.\n", "x_train = (x_train > 127).astype(float)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Fit a five-component Bernoulli mixture to the training images.\n", "bmm = BernoulliMixture(n_components=5)\n", "bmm.fit(x_train)\n", "\n", "# Show each component's mean vector as a 28 x 28 grayscale image, side by side.\n", "canvas = plt.figure(figsize=(20, 5))\n", "for position, component_mean in enumerate(bmm.mu, start=1):\n", "    panel = canvas.add_subplot(1, 5, position)\n", "    panel.imshow(component_mean.reshape(28, 28), cmap=\"gray\")\n", "    panel.axis('off')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 1 }