{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## [mlcourse.ai](https://mlcourse.ai) – Open Machine Learning Course \n", "### Author: Kyriacos Kyriacou, kyr\n", " \n", "## Tutorial\n", "### Bring your plots to life with Matplotlib animations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Introduction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Early on in our journey at mlcourse.ai we were taught how to create plots describing our data so that we could gain insight into patterns and relationships that might exist. As the course progressed, we found ourselves constantly creating plots to assist us in all kinds of tasks: whether it was for exploratory data analysis, visualizing results from machine learning (ML) algorithms such as decision trees, or comparing results using validation curves, plots were prominent everywhere.\n", "\n", "In this tutorial, we'll take plotting a step further by learning how to animate plots, using a library we're all familiar with, matplotlib." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Matplotlib animations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Luckily for us, matplotlib already ships with an excellent module for creating animations, `matplotlib.animation`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Matplotlib provides two classes for creating animations: [FuncAnimation](https://matplotlib.org/api/_as_gen/matplotlib.animation.FuncAnimation.html#matplotlib.animation.FuncAnimation) and [ArtistAnimation](https://matplotlib.org/api/_as_gen/matplotlib.animation.ArtistAnimation.html#matplotlib.animation.ArtistAnimation).\n", "\n", "With the former, an animation is created by repeatedly calling an update function, whereas with the latter frames are precomputed and then animated. We'll be taking a look at both in this tutorial.\n", "\n", "When using either of the animation classes, it is imperative to keep a reference to the instance object. The animation is advanced by a timer which the `Animation` object holds the only reference to. If you don't hold that reference, the object will be garbage collected, which stops the animation. 
This differs from other `matplotlib` plotting functions such as `plt.plot()`, whose return value you can safely discard, so it is an important point to remember.\n", "\n", "Let's look at some code for animating the cosine function, and then we'll dive into explaining each function and the `FuncAnimation` API." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import seaborn as sns\n", "\n", "sns.set(rc={\"figure.figsize\": (10, 7)})\n", "\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# to display animation in notebook\n", "from IPython.display import HTML\n", "from matplotlib.animation import ArtistAnimation, FuncAnimation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "ax.set_xlim((0, 2))\n", "ax.set_ylim((-2, 2))\n", "ax.grid(True)\n", "# create a line which we'll animate\n", "(line,) = plt.plot([], [], lw=2)\n", "\n", "# function used to draw a clear frame\n", "def init():\n", "    line.set_data([], [])\n", "    return (line,)\n", "\n", "\n", "# called sequentially on each frame\n", "def update(frame):\n", "    x = np.linspace(0, 2, 1000)\n", "    y = np.cos(2 * np.pi * (x - 0.01 * frame))\n", "    line.set_data(x, y)\n", "    return (line,)\n", "\n", "\n", "# hide static plot\n", "plt.close()\n", "\n", "ani = FuncAnimation(fig, update, frames=100, interval=20, init_func=init, blit=True)\n", "HTML(ani.to_html5_video())\n", "# ani.save('cosine_function.mp4', fps=30)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the important parameters of the `FuncAnimation` API:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`matplotlib.animation.FuncAnimation(fig, func, frames=None, interval=200, init_func=None, blit=False)`\n", "\n", "- **fig**: The figure object that is used to draw \n", "- **func**: The function to call on each 
frame\n", "- **frames**: The source of the values passed to `func` on each frame. An integer `n` is equivalent to `range(n)`; if None, `itertools.count` is used\n", "- **interval**: Delay between consecutive frames in milliseconds\n", "- **init_func**: Function used to draw a clear frame. Called once before the first frame\n", "- **blit**: Controls whether blitting is used. Blitting optimizes the animation by only redrawing parts that have changed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's go through each part of the code together:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "fig, ax = plt.subplots()\n", "ax.set_xlim((0, 2))\n", "ax.set_ylim((-2, 2))\n", "ax.grid(True)\n", "line, = plt.plot([], [], lw=2)\n", "```\n", "\n", "First, we create a figure and axes, and then set the limits of the x and y axes. Then, we create a line (with no points), which we'll be animating in our plot. If we wanted to animate another line, we'd also create it here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "def init():\n", "    line.set_data([], [])\n", "    return line,\n", "```\n", "\n", "Here, we define the init function, which is used to draw a clear frame. The function must return a tuple of all objects that will be animated; in this case it's a single line, so we return a tuple of length 1." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "def update(frame):\n", "    x = np.linspace(0, 2, 1000)\n", "    y = np.cos(2 * np.pi * (x - 0.01 * frame))\n", "    line.set_data(x, y)\n", "    return line,\n", "```\n", "\n", "This is the update function, the guts of the animation. This function is called once for each frame of the animation, in order. It is in this function that we calculate the new coordinates of the object that we are animating.\n", "\n", "The `frame` parameter is the current frame number. The update function can receive additional parameters through the `fargs` parameter of `FuncAnimation`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "ani = FuncAnimation(fig, update, frames=100, interval=20, init_func=init, blit=True)\n", "HTML(ani.to_html5_video())\n", "```\n", "\n", "Finally, we create the `FuncAnimation` object using the parameters explained earlier. As you can see, we are saving a reference to the animation in the `ani` variable. If we didn't, the object could be garbage collected and no animation would take place.\n", "\n", "Using `IPython.display.HTML`, we display the animation in the Jupyter notebook. \n", "\n", "Alternatively, we could have saved the animation to a file using:\n", "\n", "`ani.save('cosine_function.mp4', fps=30)`\n", "\n", "If you have problems saving the animation, ensure you have `ffmpeg` or `mencoder` installed." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From now on, instead of having to call `HTML` each time, we will set a parameter so that simply evaluating the animation object as the last expression of a cell will display the animation in the notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.rcParams[\"animation.html\"] = \"html5\"\n", "ani" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Animating multiple objects" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we're familiar with animating a single object, let's go further by animating multiple objects. 
Let's add a sine line and display the frame number:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "ax.set_xlim((0, 2))\n", "ax.set_ylim((-2, 2))\n", "ax.grid(True)\n", "\n", "frame_text = ax.text(0.05, 1.9, \"\", bbox=dict(facecolor=\"white\", alpha=1))\n", "(cos_line,) = plt.plot([], [], lw=2, c=\"b\", label=\"cos\")\n", "(sin_line,) = plt.plot([], [], lw=2, c=\"r\", label=\"sin\")\n", "plt.legend()\n", "\n", "# function used to draw a clear frame\n", "def init():\n", "    cos_line.set_data([], [])\n", "    sin_line.set_data([], [])\n", "    frame_text.set_text(\"\")\n", "    return cos_line, sin_line, frame_text\n", "\n", "\n", "# called sequentially on each frame\n", "def update(frame):\n", "    x = np.linspace(0, 2, 1000)\n", "    cos_x = np.cos(2 * np.pi * (x - 0.01 * frame))\n", "    sin_x = np.sin(2 * np.pi * (x - 0.01 * frame))\n", "    cos_line.set_data(x, cos_x)\n", "    sin_line.set_data(x, sin_x)\n", "    frame_text.set_text(\"frame = %d\" % frame)\n", "    return cos_line, sin_line, frame_text\n", "\n", "\n", "# hide static plot\n", "plt.close()\n", "\n", "ani = FuncAnimation(fig, update, frames=100, interval=20, init_func=init, blit=True)\n", "ani" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Animating KMeans" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Even though I'm sure looking at a cos/sin animated plot is a lot of fun, let's create some animations which can be useful for us in our ML journey. \n", "\n", "In Topic 7, we learned about clustering algorithms, and in particular K-Means. However, we were only able to plot the final clusters after all iterations had completed. It would be more interesting to see how clusters form and change throughout each iteration. Sounds like an animation is perfect for this!\n", "\n", "We'll be using the iris dataset for this task. 
Since the iris dataset consists of four features, let's use PCA to reduce that down to 2, so we can plot it in 2 dimensions, just as we did in Topic 7:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn import datasets, decomposition\n", "from sklearn.cluster import KMeans\n", "\n", "iris = datasets.load_iris()\n", "X = iris.data\n", "y = iris.target\n", "\n", "pca = decomposition.PCA(n_components=2)\n", "X_centered = X - X.mean(axis=0)\n", "pca.fit(X_centered)\n", "X_pca = pca.transform(X_centered)\n", "X_pca.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's first take a look at our data with the real labels:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(X_pca[y == 0, 0], X_pca[y == 0, 1], \"bo\", label=\"Setosa\")\n", "plt.plot(X_pca[y == 1, 0], X_pca[y == 1, 1], \"go\", label=\"Versicolour\")\n", "plt.plot(X_pca[y == 2, 0], X_pca[y == 2, 1], \"yo\", label=\"Virginica\")\n", "plt.grid(False)\n", "plt.legend(loc=0);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unfortunately, sklearn does not provide a way to view historical data of centroids - that is, we can only see the final centroid locations after n iterations, but we want to see the centroids at each iteration. There are two ways around this: we can either create our own implementation of KMeans which stores and returns the centroid locations after each iteration, or with a little ingenuity we can use KMeans from sklearn to get the centroids on each iteration.\n", "\n", "I'll initialize the centroids myself, and then repeatedly call KMeans with a fixed random state and the same initial centroids, incrementing the `max_iter` parameter by 1 on each call and fitting to the data. This way, on each iteration of our loop, we can look at the centroids. 
It'll make more sense when you read the code.\n", "\n", "For this animation we'll be using the `ArtistAnimation` class." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots()\n", "ax.set_xlim((-3.5, 4))\n", "ax.set_ylim((-1.5, 1.5))\n", "ax.grid(False)\n", "ims = []\n", "\n", "# initial centroids\n", "init = np.array([-1, 0.5, 3, 1, 3, -0.5]).reshape((3, 2))\n", "\n", "for i in range(10):\n", "    txt = plt.text(\n", "        -0.5, 1.4, \"Iteration = {}\".format(i + 1), bbox=dict(facecolor=\"white\", alpha=1)\n", "    )\n", "\n", "    kmeans = KMeans(n_clusters=3, init=init, max_iter=i + 1, random_state=6)\n", "    y = kmeans.fit_predict(X_pca, y)\n", "\n", "    centers = plt.scatter(\n", "        kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], c=\"r\", marker=\"X\"\n", "    )\n", "    (c1,) = plt.plot(X_pca[y == 0, 0], X_pca[y == 0, 1], \"bo\")\n", "    (c2,) = plt.plot(X_pca[y == 1, 0], X_pca[y == 1, 1], \"go\")\n", "    (c3,) = plt.plot(X_pca[y == 2, 0], X_pca[y == 2, 1], \"yo\")\n", "    ims.append([centers, c1, c2, c3, txt])\n", "\n", "# hide static plot\n", "plt.close()\n", "\n", "ani = ArtistAnimation(fig, ims, interval=1500, repeat_delay=1000, blit=False)\n", "ani" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`matplotlib.animation.ArtistAnimation(fig, artists, interval=200, blit=False, repeat_delay=None)`\n", "\n", "- **fig**: The figure object that is used to draw \n", "- **artists**: Each list entry is a collection of matplotlib artists that are made visible on the corresponding frame. More info on artists: https://matplotlib.org/users/artists.html\n", "- **interval**: Delay between consecutive frames in milliseconds\n", "- **blit**: Controls whether blitting is used. 
Blitting optimizes the animation by only redrawing parts that have changed\n", "- **repeat_delay**: If the animation is repeated, adds a delay in milliseconds before repeating the animation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's break down the code:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "fig, ax = plt.subplots()\n", "ax.set_xlim((-3.5, 4))\n", "ax.set_ylim((-1.5, 1.5))\n", "ax.grid(False)\n", "ims = []\n", "```\n", "\n", "Prepare the figure and axes, and create the `ims` list, which will hold the artists for each iteration." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "init = np.array([-1, 0.5, 3, 1, 3, -0.5]).reshape((3,2))\n", "for i in range(10):\n", "    txt = plt.text(-.5, 1.4, 'Iteration = {}'.format(i+1), bbox=dict(facecolor='white', alpha=1)) \n", "    kmeans = KMeans(n_clusters=3, init=init, max_iter=i+1, random_state=6)\n", "    y = kmeans.fit_predict(X_pca, y)\n", "    centers = plt.scatter(kmeans.cluster_centers_[:,0], kmeans.cluster_centers_[:,1], c='r', marker='X')\n", "    c1, = plt.plot(X_pca[y == 0, 0], X_pca[y == 0, 1], 'bo')\n", "    c2, = plt.plot(X_pca[y == 1, 0], X_pca[y == 1, 1], 'go')\n", "    c3, = plt.plot(X_pca[y == 2, 0], X_pca[y == 2, 1], 'yo')\n", "    ims.append([centers, c1, c2, c3, txt])\n", "```\n", "\n", "First, create the initial centroids and place them in the `init` array. Then, loop over the number of iterations we want to run KMeans for.\n", "\n", "In each iteration, we run KMeans up to that iteration, and obtain the cluster predictions and the new centroids for that iteration. Draw each cluster using `plt.plot`, and draw the new centroid locations with `plt.scatter`. All objects returned from `plt.text`, `plt.scatter` and `plt.plot` are artists, which must be collected in a list and appended to `ims`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "plt.close() \n", "ani = ArtistAnimation(fig, ims, interval=1500, repeat_delay=1000, blit=False)\n", "ani\n", "```\n", "\n", "Finally, hide the static plot using `plt.close`, and display the animation using `ArtistAnimation`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Animating Decision Tree Boundaries" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's look at how the Decision Tree boundary changes while changing the `max_depth` parameter. We'll be able to see an animation of the tree fitting to the data and eventually overfitting!\n", "\n", "To generate and display points I'll be using code from Topic 3 of mlcourse.ai." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# generate points\n", "np.random.seed(6)\n", "\n", "class_1 = np.random.normal(size=(50, 2))\n", "class_1_lbl = np.zeros(50)\n", "class_2 = np.random.normal(size=(50, 2), loc=1.8)\n", "class_2_lbl = np.ones(50)\n", "\n", "train_data = np.r_[class_1, class_2]\n", "train_labels = np.r_[class_1_lbl, class_2_lbl]\n", "\n", "plt.scatter(\n", "    train_data[:, 0],\n", "    train_data[:, 1],\n", "    c=train_labels,\n", "    s=100,\n", "    cmap=\"autumn\",\n", "    edgecolors=\"black\",\n", "    linewidth=1,\n", ");" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeClassifier\n", "\n", "\n", "def get_grid(data):\n", "    x_min, x_max = data[:, 0].min() - 1, data[:, 0].max() + 1\n", "    y_min, y_max = data[:, 1].min() - 1, data[:, 1].max() + 1\n", "    return np.meshgrid(np.arange(x_min, x_max, 0.01), np.arange(y_min, y_max, 0.01))\n", "\n", "\n", "fig, ax = plt.subplots()\n", "ax.grid(False)\n", "ims = []\n", "xx, yy = get_grid(train_data)\n", "\n", "for i in range(8):\n", "    txt = plt.text(\n", "        0, 4.5, \"DT max depth = {}\".format(i + 1), bbox=dict(facecolor=\"white\", alpha=1)\n", "    )\n", "\n", "    pred = (\n", "        DecisionTreeClassifier(criterion=\"gini\", max_depth=i + 1, random_state=17)\n", "        .fit(train_data, train_labels)\n", "        .predict(np.c_[xx.ravel(), yy.ravel()])\n", "        .reshape(xx.shape)\n", "    )\n", "    pcm = ax.pcolormesh(xx, yy, pred, cmap=\"autumn\")\n", "\n", "    pts = plt.scatter(\n", "        train_data[:, 0],\n", "        train_data[:, 1],\n", "        c=train_labels,\n", "        s=100,\n", "        cmap=\"autumn\",\n", "        edgecolors=\"black\",\n", "        linewidth=1,\n", "    )\n", "    ims.append([pcm, pts, txt])\n", "\n", "# hide static plot\n", "plt.close()\n", "\n", "ani = ArtistAnimation(fig, ims, interval=2000, repeat_delay=1000, blit=False)\n", "ani" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see above, the tree starts overfitting at max depth = 4. The code is more or less the same as the code for the KMeans visualization, so I will not explain each part." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6. Further Examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Matplotlib has many examples of animations [here](https://matplotlib.org/api/animation_api.html#examples). 
I'll show one below which displays 3-D capabilities:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# taken from https://matplotlib.org/gallery/animation/random_walk.html\n", "\n", "# importing mplot3d registers the \"3d\" projection on older matplotlib versions\n", "import mpl_toolkits.mplot3d.axes3d as p3\n", "\n", "# Fixing random state for reproducibility\n", "np.random.seed(19680801)\n", "\n", "\n", "def Gen_RandLine(length, dims=2):\n", "    \"\"\"\n", "    Create a line using a random walk algorithm\n", "\n", "    length is the number of points for the line.\n", "    dims is the number of dimensions the line has.\n", "    \"\"\"\n", "    lineData = np.empty((dims, length))\n", "    lineData[:, 0] = np.random.rand(dims)\n", "    for index in range(1, length):\n", "        # scaling the random numbers by 0.1 so\n", "        # movement is small compared to position.\n", "        # subtraction by 0.5 is to change the range to [-0.5, 0.5]\n", "        # to allow a line to move backwards.\n", "        step = (np.random.rand(dims) - 0.5) * 0.1\n", "        lineData[:, index] = lineData[:, index - 1] + step\n", "\n", "    return lineData\n", "\n", "\n", "def update_lines(num, dataLines, lines):\n", "    for line, data in zip(lines, dataLines):\n", "        # NOTE: there is no .set_data() for 3 dim data...\n", "        line.set_data(data[0:2, :num])\n", "        line.set_3d_properties(data[2, :num])\n", "    return lines\n", "\n", "\n", "# Attaching 3D axis to the figure\n", "# (add_subplot is used because newer matplotlib versions no longer\n", "# add an Axes3D(fig) instance to the figure automatically)\n", "fig = plt.figure()\n", "ax = fig.add_subplot(111, projection=\"3d\")\n", "\n", "# Fifty lines of random 3-D lines\n", "data = [Gen_RandLine(25, 3) for index in range(50)]\n", "\n", "# Creating fifty line objects.\n", "# NOTE: Can't pass empty arrays into 3d version of plot()\n", "lines = [ax.plot(dat[0, 0:1], dat[1, 0:1], dat[2, 0:1])[0] for dat in data]\n", "\n", "# Setting the axes properties\n", "ax.set_xlim3d([0.0, 1.0])\n", "ax.set_xlabel(\"X\")\n", "\n", "ax.set_ylim3d([0.0, 1.0])\n", "ax.set_ylabel(\"Y\")\n", "\n", "ax.set_zlim3d([0.0, 1.0])\n", "ax.set_zlabel(\"Z\")\n", "\n", "ax.set_title(\"3D Test\")\n", "\n", "plt.close()\n", "\n", "# 
Creating the Animation object\n", "ani = FuncAnimation(fig, update_lines, 25, fargs=(data, lines), interval=50, blit=False)\n", "\n", "ani" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7. Conclusion" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Matplotlib provides great capability for animations straight out of the box! We saw that by using the `FuncAnimation` or the `ArtistAnimation` class, we have a lot of power in our hands for producing animations. We can plot mathematical functions, clusters and decision boundaries from our ML algorithms and even animate 3-D plots.\n", "\n", "Animations can be used to help us understand how our ML algorithms work, or they can be used to produce awesome visuals to show off in your next blog post. Either way, the possibilities are endless.\n", "\n", "Have fun animating!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 8. References" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "1. https://matplotlib.org/api/animation_api.html\n", "2. https://matplotlib.org/api/_as_gen/matplotlib.animation.FuncAnimation.html#matplotlib.animation.FuncAnimation\n", "3. https://matplotlib.org/api/_as_gen/matplotlib.animation.ArtistAnimation.html#matplotlib.animation.ArtistAnimation\n", "4. https://jakevdp.github.io/blog/2012/08/18/matplotlib-animation-tutorial/\n", "5. https://mlcourse.ai/notebooks/blob/master/jupyter_english/topic03_decision_trees_kNN/topic3_decision_trees_kNN.ipynb?flush_cache=true\n", "6. 
https://mlcourse.ai/notebooks/blob/master/jupyter_english/topic07_unsupervised/topic7_pca_clustering.ipynb?flush_cache=true" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }