{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Function optimization in SciPy\n", "\n", "> \"What's new?\" is an interesting and broadening eternal question, but one\n", "> which, if pursued exclusively, results only in an endless parade of trivia\n", "> and fashion, the silt of tomorrow. I would like, instead, to be concerned\n", "> with the question \"What is best?\", a question which cuts deeply rather than\n", "> broadly, a question whose answers tend to move the silt downstream.\n", ">\n", "> — Robert M Pirsig, *Zen and the Art of Motorcycle Maintenance*\n", "\n", "When hanging a picture on the wall, it is sometimes difficult to get it\n", "straight. You make an adjustment, step back, evaluate the picture's\n", "horizontality, and repeat. This is a process of *optimization*: we're\n", "changing the orientation of the portrait until it satisfies our\n", "demand—that it makes a zero angle with the horizon.\n", "\n", "In mathematics, our demand is called a \"cost function\", and the\n", "orientation of the portrait the \"parameter\". In a typical\n", "optimization problem, we vary the parameters until the cost function\n", "is minimized.\n", "\n", "Consider, for example, the shifted parabola, $f(x) = (x - 3)^2$.\n", "We'd like to find the value of x that minimizes this cost function. We\n", "know that this function, with parameter $x$, has a minimum at 3,\n", "because we can calculate the derivative, set it to zero, and see that $2 (x - 3) = 0$, i.e. $x = 3$.\n", "\n", "But, if this function were much more complicated (for example if the\n", "expression had many terms, multiple points of zero derivative,\n", "contained non-linearities, or was dependent on more variables), using\n", "a hand calculation would become arduous.\n", "\n", "You can think of the cost function as representing a landscape, where we\n", "are trying to find the lowest point. 
That analogy immediately\n", "highlights one of the hard parts of this problem: if you are standing\n", "in any valley, with mountains surrounding you, how do you know whether\n", "you are in the lowest valley, or whether this valley just seems low because it is\n", "surrounded by particularly tall mountains? In optimization parlance: how\n", "do you know whether you are trapped in a *local\n", "minimum*? Most optimization algorithms make some\n", "attempt to address the issue[^line_search].\n", "\n", "[^line_search]: Optimization algorithms handle this issue in various\n", " ways, but two common approaches are line searches and\n", " trust regions. With a *line search*, you try to find\n", " the cost function minimum along a specific dimension,\n", " and then successively attempt the same along the other\n", " dimensions. With *trust regions*, we move our guess\n", " for the minimum in the direction we expect it to be;\n", " if we see that we are indeed approaching the minimum\n", " as expected, we repeat the procedure with increased\n", " confidence. If not, we lower our confidence and\n", " search a wider area.\n", "\n", "\n", "\n", "\n", "There are many different optimization algorithms to choose from (see\n", "figure). You get to choose whether your cost function takes a scalar\n", "or a vector as input (i.e., do you have one or multiple parameters to\n", "optimize?). There are those that require the cost function gradient\n", "to be given and those that automatically estimate it. Some only\n", "search for parameters in a given area (*constrained optimization*),\n", "and others examine the entire parameter space.\n", "\n", "## Optimization in SciPy: `scipy.optimize`\n", "\n", "In the rest of this chapter, we are going to use SciPy's `optimize` module to\n", "align two images. 
Applications of image alignment, or *registration*, include\n", "panorama stitching, combination of different brain scans, super-resolution\n", "imaging, and, in astronomy, object denoising (noise reduction) through the\n", "combination of multiple exposures.\n", "\n", "We start, as usual, by setting up our plotting environment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Make plots appear inline, set custom plotting style\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "plt.style.use('style/elegant.mplstyle')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start with the simplest version of the problem: we have two\n", "images, one shifted relative to the other. We wish to recover the\n", "shift that will best align our images.\n", "\n", "Our optimization function will \"jiggle\" one of the images, and see\n", "whether jiggling it in one direction or another reduces their\n", "dissimilarity. By doing this repeatedly, we can try to find the\n", "correct alignment.\n", "\n", "### An example: computing optimal image shift\n", "\n", "You'll remember our astronaut — Eileen Collins — from chapter 3.\n", "We will be shifting this image by 50 pixels to the right then comparing it back\n", "to the original until we\n", "find the shift that best matches. Obviously this is a silly thing to do, as we\n", "know the original position, but this way we know the truth, and we can check\n", "how our algorithm is doing. Here's the original and shifted image." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skimage import data, color\n", "from scipy import ndimage as ndi\n", "\n", "astronaut = color.rgb2gray(data.astronaut())\n", "shifted = ndi.shift(astronaut, (0, 50))\n", "\n", "fig, axes = plt.subplots(nrows=1, ncols=2)\n", "axes[0].imshow(astronaut)\n", "axes[0].set_title('Original')\n", "axes[1].imshow(shifted)\n", "axes[1].set_title('Shifted');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "For the optimization algorithm to do its work, we need some way of\n", "defining \"dissimilarity\"—i.e., the cost function. The easiest way to do this is to\n", "simply calculate the average of the squared differences, often called the\n", "*mean squared error*, or MSE." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def mse(arr1, arr2):\n", " \"\"\"Compute the mean squared error between two arrays.\"\"\"\n", " return np.mean((arr1 - arr2)**2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This will return 0 when the images are perfectly aligned, and a higher\n", "value otherwise. 
With this cost function, we can check whether two images are aligned:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ncol = astronaut.shape[1]\n", "\n", "# Cover a distance of 90% of the length in columns,\n", "# with one value per percentage point\n", "shifts = np.linspace(-0.9 * ncol, 0.9 * ncol, 181)\n", "mse_costs = []\n", "\n", "for shift in shifts:\n", " shifted_back = ndi.shift(shifted, (0, shift))\n", " mse_costs.append(mse(astronaut, shifted_back))\n", "\n", "fig, ax = plt.subplots()\n", "ax.plot(shifts, mse_costs)\n", "ax.set_xlabel('Shift')\n", "ax.set_ylabel('MSE');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "With the cost function defined, we can ask `scipy.optimize.minimize`\n", "to search for optimal parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy import optimize\n", "\n", "def astronaut_shift_error(shift, image):\n", " corrected = ndi.shift(image, (0, shift))\n", " return mse(astronaut, corrected)\n", "\n", "res = optimize.minimize(astronaut_shift_error, 0, args=(shifted,),\n", " method='Powell')\n", "\n", "print(f'The optimal shift for correction is: {res.x}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It worked! 
We shifted it by +50 pixels, and, thanks to our MSE measure, SciPy's\n", "`optimize.minimize` function has given us the correct amount of shift (-50) to\n", "get it back to its original state.\n", "\n", "It turns out, however, that this was a particularly easy optimization problem,\n", "which brings us to the principal difficulty of this kind of\n", "alignment: sometimes, the MSE has to get worse before it gets better.\n", "\n", "Let's look again at shifting images, starting with the unmodified image:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ncol = astronaut.shape[1]\n", "\n", "# Cover a distance of 90% of the length in columns,\n", "# with one value per percentage point\n", "shifts = np.linspace(-0.9 * ncol, 0.9 * ncol, 181)\n", "mse_costs = []\n", "\n", "for shift in shifts:\n", " shifted1 = ndi.shift(astronaut, (0, shift))\n", " mse_costs.append(mse(astronaut, shifted1))\n", "\n", "fig, ax = plt.subplots()\n", "ax.plot(shifts, mse_costs)\n", "ax.set_xlabel('Shift')\n", "ax.set_ylabel('MSE');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "Starting at zero shift, have a look at the MSE value as the shift becomes\n", "increasingly negative: it increases consistently until around -300\n", "pixels of shift, where it starts to decrease again! Only slightly, but it\n", "decreases nonetheless. The MSE bottoms out at around -400, before it\n", "increases again. This is called a *local minimum*.\n", "Because optimization methods only have access to \"nearby\"\n", "values of the cost function, if the function improves by moving in the \"wrong\"\n", "direction, the `minimize` process will move that way regardless. 
So, if we\n", "start with an image shifted by -340 pixels:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "shifted2 = ndi.shift(astronaut, (0, -340))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`minimize` will shift it by a further 40 pixels or so,\n", "instead of recovering the original image:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res = optimize.minimize(astronaut_shift_error, 0, args=(shifted2,),\n", "                        method='Powell')\n", "\n", "print(f'The optimal shift for correction is {res.x}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The common solution to this problem is to smooth or downscale the images, which\n", "in turn smooths the objective function. Have a look at the\n", "same plot, after having smoothed the images with a Gaussian filter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skimage import filters\n", "\n", "astronaut_smooth = filters.gaussian(astronaut, sigma=20)\n", "\n", "mse_costs_smooth = []\n", "shifts = np.linspace(-0.9 * ncol, 0.9 * ncol, 181)\n", "for shift in shifts:\n", "    shifted3 = ndi.shift(astronaut_smooth, (0, shift))\n", "    mse_costs_smooth.append(mse(astronaut_smooth, shifted3))\n", "\n", "fig, ax = plt.subplots()\n", "ax.plot(shifts, mse_costs, label='original')\n", "ax.plot(shifts, mse_costs_smooth, label='smoothed')\n", "ax.legend(loc='lower right')\n", "ax.set_xlabel('Shift')\n", "ax.set_ylabel('MSE');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "As you can see, with some rather extreme smoothing, the \"funnel\" of\n", "the error function becomes wider, and less bumpy. Rather than smoothing the\n", "function itself, we can get a similar effect by blurring the images before\n", "comparing them. 
Therefore, modern alignment\n", "software uses what's called a *Gaussian pyramid*, which is a set of\n", "progressively lower resolution versions of the same image. We align\n", "the lower resolution (blurrier) images first, to get an\n", "approximate alignment, and then progressively refine the alignment\n", "with sharper images." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def downsample2x(image):\n", "    \"\"\"Downsample an image by a factor of 2 along each axis.\"\"\"\n", "    offsets = [((s + 1) % 2) / 2 for s in image.shape]\n", "    slices = [slice(offset, end, 2)\n", "              for offset, end in zip(offsets, image.shape)]\n", "    coords = np.mgrid[slices]\n", "    return ndi.map_coordinates(image, coords, order=1)\n", "\n", "\n", "def gaussian_pyramid(image, levels=6):\n", "    \"\"\"Make a Gaussian image pyramid.\n", "\n", "    Parameters\n", "    ----------\n", "    image : array of float\n", "        The input image.\n", "    levels : int, optional\n", "        The number of levels in the pyramid.\n", "\n", "    Returns\n", "    -------\n", "    pyramid : iterator of array of float\n", "        An iterator of Gaussian pyramid levels, starting with the top\n", "        (lowest resolution) level.\n", "    \"\"\"\n", "    pyramid = [image]\n", "\n", "    for level in range(levels - 1):\n", "        # blur before downsampling to avoid aliasing\n", "        blurred = ndi.gaussian_filter(image, sigma=2/3)\n", "        image = downsample2x(blurred)\n", "        pyramid.append(image)\n", "\n", "    return reversed(pyramid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see how the 1D alignment looks along that pyramid:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "shifts = np.linspace(-0.9 * ncol, 0.9 * ncol, 181)\n", "nlevels = 8\n", "costs = np.empty((nlevels, len(shifts)), dtype=float)\n", "astronaut_pyramid = list(gaussian_pyramid(astronaut, levels=nlevels))\n", "for col, shift in enumerate(shifts):\n", "    shifted = ndi.shift(astronaut, (0, shift))\n", "    shifted_pyramid = gaussian_pyramid(shifted, levels=nlevels)\n", "    for row, image in enumerate(shifted_pyramid):\n", "        costs[row, col] = mse(astronaut_pyramid[row], image)\n", "\n", "fig, ax = plt.subplots()\n", "for level, cost in enumerate(costs):\n", "    ax.plot(shifts, cost, label='Level %d' % (nlevels - level))\n", "ax.legend(loc='lower right', frameon=True, framealpha=0.9)\n", "ax.set_xlabel('Shift')\n", "ax.set_ylabel('MSE');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "As you can see, at the highest level of the pyramid, that bump at a shift of\n", "about -325 disappears. We can therefore get an approximate alignment at that\n", "level, then pop down to the lower levels to refine that alignment.\n", "\n", "## Image registration with `optimize`\n", "\n", "Let's automate that, and try with a \"real\" alignment, with three parameters:\n", "rotation, translation in the row dimension, and translation in the\n", "column dimension. This is called a \"*rigid* registration\" because there are no\n", "deformations of any kind (scaling, skew, or other stretching). The object is\n", "considered solid and moved around (including rotation) until a match is found.\n", "\n", "To simplify the code, we'll use the scikit-image *transform* module to compute\n", "the shift and rotation of the image. SciPy's `optimize` requires a vector of\n", "parameters as input. 
We first make a\n", "function that will take such a vector and produce a rigid transformation with\n", "the right parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skimage import transform\n", "\n", "def make_rigid_transform(param):\n", " r, tc, tr = param\n", " return transform.SimilarityTransform(rotation=r,\n", " translation=(tc, tr))\n", "\n", "rotated = transform.rotate(astronaut, 45)\n", "\n", "fig, axes = plt.subplots(nrows=1, ncols=2)\n", "axes[0].imshow(astronaut)\n", "axes[0].set_title('Original')\n", "axes[1].imshow(rotated)\n", "axes[1].set_title('Rotated');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "Next, we need a cost function. This is just MSE, but SciPy requires a specific\n", "format: the first argument needs to be the *parameter vector*, which it is\n", "optimizing. Subsequent arguments can be passed through the `args` keyword as a\n", "tuple, but must remain fixed: only the parameter vector can be optimized. 
In\n", "our case, that vector holds the rotation angle and the two translation parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def cost_mse(param, reference_image, target_image):\n", "    transformation = make_rigid_transform(param)\n", "    transformed = transform.warp(target_image, transformation, order=3)\n", "    return mse(reference_image, transformed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we write our alignment function, which optimizes our cost function\n", "*at each level of the Gaussian pyramid*, using the result of the previous\n", "level as a starting point for the next one:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def align(reference, target, cost=cost_mse):\n", "    nlevels = 7\n", "    pyramid_ref = gaussian_pyramid(reference, levels=nlevels)\n", "    pyramid_tgt = gaussian_pyramid(target, levels=nlevels)\n", "\n", "    levels = range(nlevels, 0, -1)\n", "    image_pairs = zip(pyramid_ref, pyramid_tgt)\n", "\n", "    p = np.zeros(3)\n", "\n", "    for n, (ref, tgt) in zip(levels, image_pairs):\n", "        # each level doubles the resolution, so double the translation\n", "        p[1:] *= 2\n", "\n", "        res = optimize.minimize(cost, p, args=(ref, tgt), method='Powell')\n", "        p = res.x\n", "\n", "        # print current level, overwriting each time (like a progress bar)\n", "        print(f'Level: {n}, Angle: {np.rad2deg(res.x[0]) :.3}, '\n", "              f'Offset: ({res.x[1] * 2**n :.3}, {res.x[2] * 2**n :.3}), '\n", "              f'Cost: {res.fun :.3}', end='\\r')\n", "\n", "    print('')  # newline when alignment complete\n", "    return make_rigid_transform(p)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try it with our astronaut image. We rotate it by 60 degrees and add some\n", "noise to it. Can SciPy recover the correct transform?" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skimage import util\n", "\n", "theta = 60\n", "rotated = transform.rotate(astronaut, theta)\n", "rotated = util.random_noise(rotated, mode='gaussian',\n", " seed=0, mean=0, var=1e-3)\n", "\n", "tf = align(astronaut, rotated)\n", "corrected = transform.warp(rotated, tf, order=3)\n", "\n", "f, (ax0, ax1, ax2) = plt.subplots(1, 3)\n", "ax0.imshow(astronaut)\n", "ax0.set_title('Original')\n", "ax1.imshow(rotated)\n", "ax1.set_title('Rotated')\n", "ax2.imshow(corrected)\n", "ax2.set_title('Registered')\n", "for ax in (ax0, ax1, ax2):\n", " ax.axis('off')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "We're feeling pretty good now. But our choice of parameters actually masked\n", "the difficulty of optimization: Let's see what happens with a rotation of\n", "50 degrees, which is *closer* to the original image:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "theta = 50\n", "rotated = transform.rotate(astronaut, theta)\n", "rotated = util.random_noise(rotated, mode='gaussian',\n", " seed=0, mean=0, var=1e-3)\n", "\n", "tf = align(astronaut, rotated)\n", "corrected = transform.warp(rotated, tf, order=3)\n", "\n", "f, (ax0, ax1, ax2) = plt.subplots(1, 3)\n", "ax0.imshow(astronaut)\n", "ax0.set_title('Original')\n", "ax1.imshow(rotated)\n", "ax1.set_title('Rotated')\n", "ax2.imshow(corrected)\n", "ax2.set_title('Registered')\n", "for ax in (ax0, ax1, ax2):\n", " ax.axis('off')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "Even though we started closer to the original image, we failed to\n", "recover the correct rotation. This is because optimization techniques can get\n", "stuck in local minima, little bumps on the road to success, as we saw above\n", "with the shift-only alignment. 
They can therefore be quite sensitive to the\n", "starting parameters.\n", "\n", "## Avoiding local minima with basin hopping\n", "\n", "A 1997 algorithm devised by David Wales and Jonathan Doyle [^basinhop], called\n", "*basin-hopping*, attempts to avoid local minima by trying an optimization from\n", "some initial parameters, then moving away from the found local minimum in a\n", "random direction, and optimizing again. By choosing an appropriate step size\n", "for these random moves, the algorithm can avoid falling into the same local\n", "minimum twice, and thus explore a much larger area of the parameter space than\n", "simple gradient-based optimization methods.\n", "\n", "We leave it as an exercise to incorporate SciPy's implementation of basin-hopping\n", "into our alignment function. You'll need it for later parts of the chapter, so\n", "feel free to peek at the solution at the end of the book if you're stuck.\n", "\n", "[^basinhop]: David J. Wales and Jonathan P.K. Doyle (1997). 
Global Optimization\n", "    by Basin-Hopping and the Lowest Energy Structures of Lennard-Jones\n", "    Clusters Containing up to 110 Atoms.\n", "    **Journal of Physical Chemistry 101(28):5111–5116**\n", "    DOI: 10.1021/jp970984n\n", "\n", "\n", "\n", "**Exercise:** Try modifying the `align` function to use\n", "`scipy.optimize.basinhopping`, which has explicit strategies to avoid local minima.\n", "\n", "*Hint:* limit using basin-hopping to just the top levels of the pyramid, as it is\n", "a slower optimization approach, and could take rather long to run at full image\n", "resolution.\n", "\n", "\n", "\n", "**Solution:** We use basin-hopping at the higher levels of the pyramid, but use\n", "Powell's method for the lower levels, because basin-hopping is too\n", "computationally expensive to run at full resolution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def align(reference, target, cost=cost_mse, nlevels=7, method='Powell'):\n", "    pyramid_ref = gaussian_pyramid(reference, levels=nlevels)\n", "    pyramid_tgt = gaussian_pyramid(target, levels=nlevels)\n", "\n", "    levels = range(nlevels, 0, -1)\n", "    image_pairs = zip(pyramid_ref, pyramid_tgt)\n", "\n", "    p = np.zeros(3)\n", "\n", "    for n, (ref, tgt) in zip(levels, image_pairs):\n", "        p[1:] *= 2\n", "        if method.upper() == 'BH':\n", "            res = optimize.basinhopping(cost, p,\n", "                                        minimizer_kwargs={'args': (ref, tgt)})\n", "            if n <= 4:  # avoid basin-hopping in lower levels\n", "                method = 'Powell'\n", "        else:\n", "            res = optimize.minimize(cost, p, args=(ref, tgt), method='Powell')\n", "        p = res.x\n", "        # print current level, overwriting each time (like a progress bar)\n", "        print(f'Level: {n}, Angle: {np.rad2deg(res.x[0]) :.3}, '\n", "              f'Offset: ({res.x[1] * 2**n :.3}, {res.x[2] * 2**n :.3}), '\n", "              f'Cost: {res.fun :.3}', end='\\r')\n", "\n", "    print('')  # newline when alignment complete\n", "    return make_rigid_transform(p)" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "Now let's try that alignment:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skimage import util\n", "\n", "theta = 50\n", "rotated = transform.rotate(astronaut, theta)\n", "rotated = util.random_noise(rotated, mode='gaussian',\n", " seed=0, mean=0, var=1e-3)\n", "\n", "tf = align(astronaut, rotated, nlevels=8, method='BH')\n", "corrected = transform.warp(rotated, tf, order=3)\n", "\n", "f, (ax0, ax1, ax2) = plt.subplots(1, 3)\n", "ax0.imshow(astronaut)\n", "ax0.set_title('Original')\n", "ax1.imshow(rotated)\n", "ax1.set_title('Rotated')\n", "ax2.imshow(corrected)\n", "ax2.set_title('Registered')\n", "for ax in (ax0, ax1, ax2):\n", " ax.axis('off')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "Success! Basin-hopping was able to recover the correct alignment, even in the\n", "problematic case in which the `minimize` function got stuck.\n", "\n", "\n", "\n", "\n", "\n", "## \"What is best?\": Choosing the right objective function\n", "\n", "At this point, we have a working registration approach, which is most\n", "excellent. But it turns out that we've only solved the easiest of registration\n", "problems: aligning images of the same *modality*. This means that we expect\n", "bright pixels in the reference image to match up to bright pixels in the test\n", "image.\n", "\n", "We now move on to aligning different color channels of the same image,\n", "where we can no longer rely on the channels having the same modality. This task\n", "has historical significance: between 1909 and 1915, the photographer Sergei\n", "Mikhailovich Prokudin-Gorskii produced color photographs of the Russian empire\n", "before color photography had been invented. 
He did this by taking three\n", "different monochrome pictures of a scene, each with a different color filter\n", "placed in front of the lens.\n", "\n", "Aligning bright pixels together, as the MSE implicitly does, won't work in\n", "this case. Take, for example, these three pictures of a stained glass window\n", "in the Church of Saint John the Theologian, taken from the [Library of Congress\n", "Prokudin-Gorskii Collection](http://www.loc.gov/pictures/item/prk2000000263/):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from skimage import io\n", "stained_glass = io.imread('data/00998v.jpg') / 255 # use float image in [0, 1]\n", "fig, ax = plt.subplots(figsize=(4.8, 7))\n", "ax.imshow(stained_glass)\n", "ax.axis('off');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "Take a look at St John's robes: they look pitch black in one image, gray in\n", "another, and bright white in the third! This would result in a terrible MSE\n", "score, even with perfect alignment.\n", "\n", "Let's see what we can do with this. 
We start by splitting the plate into its\n", "component channels:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "nrows = stained_glass.shape[0]\n", "step = nrows // 3\n", "channels = (stained_glass[:step],\n", " stained_glass[step:2*step],\n", " stained_glass[2*step:3*step])\n", "channel_names = ['blue', 'green', 'red']\n", "fig, axes = plt.subplots(1, 3)\n", "for ax, image, name in zip(axes, channels, channel_names):\n", " ax.imshow(image)\n", " ax.axis('off')\n", " ax.set_title(name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "First, we overlay all three images to verify that the alignment indeed needs to\n", "be fine-tuned between the three channels:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "blue, green, red = channels\n", "original = np.dstack((red, green, blue))\n", "fig, ax = plt.subplots(figsize=(4.8, 4.8), tight_layout=True)\n", "ax.imshow(original)\n", "ax.axis('off');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "You can see by the color \"halos\" around objects in the image that the colors\n", "are close to alignment, but not quite. Let's try to align them in the same\n", "way that we aligned the astronaut image above, using the MSE. We use one color\n", "channel, green, as the reference image, and align the blue and red channels to\n", "that." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print('*** Aligning blue to green ***')\n", "tf = align(green, blue)\n", "cblue = transform.warp(blue, tf, order=3)\n", "\n", "print('*** Aligning red to green ***')\n", "tf = align(green, red)\n", "cred = transform.warp(red, tf, order=3)\n", "\n", "corrected = np.dstack((cred, green, cblue))\n", "f, (ax0, ax1) = plt.subplots(1, 2)\n", "ax0.imshow(original)\n", "ax0.set_title('Original')\n", "ax1.imshow(corrected)\n", "ax1.set_title('Corrected')\n", "for ax in (ax0, ax1):\n", "    ax.axis('off')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "The alignment is a little bit better than with the raw images, because the red\n", "and the green channels are correctly aligned, probably thanks to the giant\n", "yellow patch of sky. However, the blue channel is still off, because the bright\n", "spots of blue don't coincide with the green channel. That means that the MSE\n", "will be lower when the channels are *mis*-aligned so that blue patches overlap\n", "with some bright green spots.\n", "\n", "We turn instead to a measure called *normalized mutual information* (NMI),\n", "which measures correlations between the different brightness bands of the\n", "different images. When the images are perfectly aligned, any object of uniform\n", "color will create a large correlation between the shades of the different\n", "component channels, and a correspondingly large NMI value. In a sense, NMI\n", "measures how easy it would be to predict a pixel value of one image given the\n", "value of the corresponding pixel in the other. It was defined in the paper\n", "\"Studholme, C., Hill, D.L.G., Hawkes, D.J., *An Overlap Invariant Entropy Measure\n", "of 3D Medical Image Alignment*, Patt. Rec. 32, 71–86 (1999)\":\n", "\n", "$$I(X, Y) = \\frac{H(X) + H(Y)}{H(X, Y)},$$\n", "\n", "where $H(X)$ is the *entropy* of $X$, and $H(X, Y)$ is the joint\n", "entropy of $X$ and $Y$. 
The numerator describes the entropy of the\n", "two images, seen separately, and the denominator the total entropy if\n", "they are observed together. Values can vary between 1 (minimally\n", "aligned) and 2 (maximally aligned)[^mi_calc]. See Chapter 5 for a\n", "more in-depth discussion of entropy.\n", "\n", "[^mi_calc]: A quick handwavy explanation is that entropy is calculated\n", "    from the histogram of the quantity under consideration.\n", "    If $X = Y$, then the joint histogram $(X, Y)$ is diagonal,\n", "    and that diagonal is the same as that of either $X$ or\n", "    $Y$. Thus, $H(X) = H(Y) = H(X, Y)$ and $I(X, Y) = 2$.\n", "    Conversely, if $X$ and $Y$ are independent, then\n", "    $H(X, Y) = H(X) + H(Y)$ and $I(X, Y) = 1$.\n", "\n", "In Python code, this becomes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy.stats import entropy\n", "\n", "def normalized_mutual_information(A, B):\n", "    \"\"\"Compute the normalized mutual information.\n", "\n", "    The normalized mutual information is given by:\n", "\n", "                  H(A) + H(B)\n", "        Y(A, B) = -----------\n", "                    H(A, B)\n", "\n", "    where H(X) is the entropy ``- sum(x log x) for x in X``.\n", "\n", "    Parameters\n", "    ----------\n", "    A, B : ndarray\n", "        Images to be registered.\n", "\n", "    Returns\n", "    -------\n", "    nmi : float\n", "        The normalized mutual information between the two arrays, computed at a\n", "        granularity of 100 bins per axis (10,000 bins total).\n", "    \"\"\"\n", "    hist, bin_edges = np.histogramdd([np.ravel(A), np.ravel(B)], bins=100)\n", "    hist /= np.sum(hist)\n", "\n", "    H_A = entropy(np.sum(hist, axis=1))  # marginal distribution of A\n", "    H_B = entropy(np.sum(hist, axis=0))  # marginal distribution of B\n", "    H_AB = entropy(np.ravel(hist))\n", "\n", "    return (H_A + H_B) / H_AB" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we define a *cost function* to optimize, as we defined `cost_mse` above:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def cost_nmi(param, reference_image, target_image):\n", "    transformation = make_rigid_transform(param)\n", "    transformed = transform.warp(target_image, transformation, order=3)\n", "    # the optimizer minimizes, so negate the NMI in order to maximize it\n", "    return -normalized_mutual_information(reference_image, transformed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we use this with our basinhopping-optimizing aligner:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print('*** Aligning blue to green ***')\n", "tf = align(green, blue, cost=cost_nmi)\n", "cblue = transform.warp(blue, tf, order=3)\n", "\n", "print('*** Aligning red to green ***')\n", "tf = align(green, red, cost=cost_nmi)\n", "cred = transform.warp(red, tf, order=3)\n", "\n", "corrected = np.dstack((cred, green, cblue))\n", "fig, ax = plt.subplots(figsize=(4.8, 4.8), tight_layout=True)\n", "ax.imshow(corrected)\n", "ax.axis('off');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "What a glorious image! Realize that this artifact was created before color\n", "photography existed! Notice God's pearly white robes, John's white beard,\n", "and the white pages of the book held by Prochorus, his scribe, all of which\n", "were missing from the MSE-based alignment, but look wonderfully clear using NMI.\n", "Notice also the realistic gold of the candlesticks in the foreground.\n", "\n", "We've illustrated the two key concepts in function optimization in this\n", "chapter: understanding local minima and how to avoid them, and choosing the\n", "right function to optimize to achieve a particular objective. Solving these\n", "allows you to apply optimization to a wide array of scientific problems!" ] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 4 }