{ "metadata": { "name": "11_photoz_regression" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [ { "cell_type": "heading", "level": 1, "metadata": {}, "source": [ "Regression Example: Photometric Redshifts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we will explore a regression problem important in the field of astronomy:\n", "the estimation of galaxy redshifts from photometry.\n", "\n", "Data in astronomy most often come in one of two forms: spectral data and photometric\n", "data. Spectra are high-resolution measurements of the energy of a source as a\n", "function of wavelength. Photometry, usually measured on a logarithmic scale\n", "called *magnitudes*, can be thought of as the integral of the spectrum through\n", "a broad filter.\n", "\n", "Run the following code to see an example of a stellar spectrum (in particular,\n", "the spectrum of the star Vega) along with the five filters used in the\n", "Sloan Digital Sky Survey:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "%pylab inline" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "from figures import plot_sdss_filters\n", "plot_sdss_filters()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Photometric Redshifts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One interesting regression problem which appears often in astronomy is the photometric\n", "determination of galaxy redshifts.\n", "\n", "In the current standard cosmological model, the universe began nearly 14 billion years ago,\n", "in an explosive event commonly known as the Big Bang. Since then, the very fabric\n", "of space has been expanding, so that distant galaxies appear to be moving away from\n", "us at very high speeds. 
The uniformity of this expansion means that there is a\n", "relationship between the distance to a galaxy and the speed at which it appears to be\n", "receding from us (this relationship is known as Hubble\u2019s Law, named after Edwin Hubble).\n", "This recession speed leads to a shift in the frequency of photons, very similar to the\n", "more familiar Doppler shift that causes the pitch of a siren to change as an emergency\n", "vehicle passes by. If a galaxy or star were moving toward us, its light would be shifted\n", "to higher frequencies, or blue-shifted. Because the universe is expanding,\n", "distant galaxies appear to be red-shifted: their photons are shifted to lower frequencies.\n", "\n", "In cosmology, the redshift is measured with the parameter $z$, defined in terms of the\n", "observed wavelength $\\lambda_{obs}$ and the emitted wavelength $\\lambda_{em}$:\n", "\n", "$\\lambda_{obs} = (1 + z)\\lambda_{em}$\n", "\n", "When a spectrum can be obtained, determining the redshift is rather straightforward:\n", "if you can localize the spectral fingerprint of a common element, such as hydrogen,\n", "then the redshift can be computed using simple arithmetic. But the task becomes much\n", "more difficult when only photometric observations are available.\n", "\n", "Because of the spectrum shift, an identical source at different redshifts will have a\n", "different color through each pair of filters. See the following figure:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "from figures import plot_redshifts\n", "plot_redshifts()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This again shows the spectrum of the star Vega ($\\alpha$-Lyr), but\n", "at three different redshifts. The SDSS ugriz filters are shown in gray for reference.\n", "\n", "At redshift z=0.0, the spectrum is bright in the u and g filters,\n", "but dim in the i and z filters. 
At redshift z=0.8, the opposite\n", "is the case. This suggests the possibility of determining redshift\n", "from photometry alone. The situation is complicated by the fact that\n", "each individual source has unique spectral characteristics, but\n", "these photometric redshifts are nevertheless widely used in astronomical applications." ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Motivation: Dark Energy, Dark Matter, and the Fate of the Universe" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The photometric redshift problem is very important. Future astronomical\n", "surveys hope to image billions of very faint galaxies, and use this data \n", "to inform our view of the universe as a whole: its history, its geometry, \n", "and its fate. Obtaining an accurate estimate of the redshift of each of \n", "these galaxies is a pivotal part of this task. Because these surveys will \n", "image so many extremely faint galaxies, there is no possibility of obtaining \n", "a spectrum for each one. Thus sophisticated photometric redshift codes will \n", "be required to advance our understanding of the universe, including more \n", "precisely understanding the nature of the dark energy that is currently \n", "accelerating the cosmic expansion." ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Decision Tree Regression" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we will address the photometric redshift problem using 50,000 observations\n", "from the Sloan Digital Sky Survey. This example draws from examples available\n", "through the ``astroML`` Python package, which can be found here:\n", "\n", "- http://astroml.github.com/book_figures/chapter9/fig_photoz_tree.html\n", "- http://astroml.github.com/book_figures/chapter9/fig_photoz_forest.html\n", "\n", "We'll start by downloading the data. This fetch function actually generates\n", "an SQL query and downloads the data from the SDSS database. 
The results\n", "will be cached locally so that subsequent calls to the function don't result\n", "in another download." ] }, { "cell_type": "code", "collapsed": false, "input": [ "import numpy as np\n", "from datasets import fetch_sdss_galaxy_mags" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "# This will download a ~3MB file the first time you call the function\n", "data = fetch_sdss_galaxy_mags()\n", "\n", "print(data.shape)\n", "print(data.dtype)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we'll extract the data. Because the relative magnitudes are easier to\n", "calibrate than the absolute magnitude, we'll work with what astronomers call\n", "*colors*, the difference of two magnitudes. Because magnitudes are a logarithmic\n", "measure of the flux, a color is the logarithm of a flux ratio, so it does not\n", "depend on the overall brightness of the source." ] }, { "cell_type": "code", "collapsed": false, "input": [ "redshift = data['redshift']\n", "mags = np.vstack([data[f] for f in 'ugriz']).transpose()\n", "# colors from adjacent filters: u-g, g-r, r-i, i-z\n", "colors = mags[:, :-1] - mags[:, 1:]\n", "print(colors.shape)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll split the data into training and validation sets, and create a decision tree\n", "regressor to fit to the data:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn import cross_validation\n", "ctrain, ctest, ztrain, ztest = cross_validation.train_test_split(colors, redshift)\n", "\n", "from sklearn.tree import DecisionTreeRegressor\n", "clf = DecisionTreeRegressor()" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by running a 4-fold cross validation and see how we're doing. The\n", "cross validation here prints the $R^2$ score, which is at most 1 and can be negative for a very poor fit. 
The\n", "closer the score is to 1, the better:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "print(cross_validation.cross_val_score(clf, colors, redshift, cv=4))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another way we can visualize the results is to scatter-plot the predicted redshifts against the true values:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "# We'll use this function several times below\n", "def plot_redshifts(ztrue, zpred):\n", "    \"\"\"scatter-plot the true vs predicted redshifts\"\"\"\n", "    fig, ax = plt.subplots(figsize=(8, 8))\n", "    ax.plot(ztrue, zpred, '.k')\n", "\n", "    # plot the one-to-one line (dashed) and offsets of +/- 0.2 in z (dotted)\n", "    ax.plot([0, 3], [0, 3], '--k')\n", "    ax.plot([0, 3], [0.2, 3.2], ':k')\n", "    ax.plot([0.2, 3.2], [0, 3], ':k')\n", "\n", "    ax.text(2.9, 0.1,\n", "            \"RMS = %.2g\" % np.sqrt(np.mean((ztrue - zpred) ** 2)),\n", "            ha='right', va='bottom')\n", "\n", "    ax.set_xlim(0, 3)\n", "    ax.set_ylim(0, 3)\n", "\n", "    ax.set_xlabel('True redshift')\n", "    ax.set_ylabel('Predicted redshift')" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "clf = DecisionTreeRegressor()\n", "clf.fit(ctrain, ztrain)\n", "zpred = clf.predict(ctest)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "plot_redshifts(ztest, zpred)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see several things here: first, there are some regions of redshift space where\n", "the results are not very precise: they can vary by $\\pm 0.2$ or so.\n", "\n", "Second, there are many *catastrophic outliers*: values where the prediction is\n", "completely wrong. Both these sources of error are important to minimize, or\n", "at the very least, statistically characterize. We'll explore some ways to do\n", "this below." 
] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Optimizing the Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The question of how to improve the model goes back to the discussion of\n", "Learning Curves from earlier. We'll start by attempting to optimize the\n", "model itself through cross-validation.\n", "\n", "The decision tree regressors have a few hyperparameters, but one of the\n", "more important is the *depth*. This limits how many successive splits the tree\n", "can make along any branch in the process of computing a fit. Here we'll plot the validation\n", "curve for the maximum depth:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "# we'll explore results for max_depth from 1 to 20\n", "max_depth_array = np.arange(1, 21)\n", "train_error = np.zeros(len(max_depth_array))\n", "test_error = np.zeros(len(max_depth_array))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "for i, max_depth in enumerate(max_depth_array):\n", "    clf = DecisionTreeRegressor(max_depth=max_depth)\n", "    clf.fit(ctrain, ztrain)\n", "\n", "    ztrain_pred = clf.predict(ctrain)\n", "    ztest_pred = clf.predict(ctest)\n", "\n", "    train_error[i] = np.sqrt(np.mean((ztrain_pred - ztrain) ** 2))\n", "    test_error[i] = np.sqrt(np.mean((ztest_pred - ztest) ** 2))" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "plt.plot(max_depth_array, train_error, label='training')\n", "plt.plot(max_depth_array, test_error, label='validation')\n", "plt.grid()\n", "plt.legend(loc=3)\n", "plt.xlabel('max_depth')\n", "plt.ylabel('RMS error')" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see a very clean curve which looks much like what we'd expect: the\n", "training error decreases consistently as the model over-fits it more\n", "and more, while the validation error turns over at some optimal 
value.\n", "\n", "Note that scikit-learn has a set of tools which automate this sort\n", "of calculation: they are in the ``grid_search`` module. As of scikit-learn version 0.13,\n", "the interface for grid search is still evolving, so the following code\n", "may have to be adjusted when newer scikit-learn versions are released:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "from sklearn.grid_search import GridSearchCV\n", "clf = DecisionTreeRegressor()\n", "grid = GridSearchCV(clf, param_grid=dict(max_depth=max_depth_array))\n", "grid.fit(colors, redshift)\n", "print(grid.best_params_)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The grid search uses a full cross-validation rather than a single validation set:\n", "this is what leads to the larger optimal value of `max_depth`." ] }, { "cell_type": "heading", "level": 3, "metadata": {}, "source": [ "Plotting the optimal model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can read from the validation curve above that the optimal RMS is obtained with a\n", "max depth of about 7. Let's see what this looks like:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "clf = DecisionTreeRegressor(max_depth=7)\n", "clf.fit(ctrain, ztrain)\n", "zpred = clf.predict(ctest)\n", "plot_redshifts(ztest, zpred)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ugh... not pretty.\n", "\n", "Even though *by eye* this looks like a much worse fit than we had above, it actually\n", "does have a better RMS residual (0.21 vs. 0.27). This is a good illustration that\n", "**the form of the loss or score function is very important**." 
] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "An Alternative Loss Function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We said above that along with being concerned about RMS error, we're also\n", "concerned about the level of outliers in the predictions. With that in mind, we\n", "could propose another loss function particular to this problem: the fraction\n", "of points with an absolute deviation greater than 0.2 (that is, outside the\n", "dotted lines in the scatter-plot).\n", "\n", "We'll define the function to compute this, and then create the validation\n", "curve for this metric:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "def outlier_fraction(y_pred, y_true, cutoff=0.2):\n", "    \"\"\"fraction of predictions that deviate from the truth by more than cutoff\"\"\"\n", "    return np.mean(np.abs(y_pred - y_true) > cutoff)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "train_outfrac = np.zeros(len(max_depth_array))\n", "test_outfrac = np.zeros(len(max_depth_array))\n", "\n", "for i, max_depth in enumerate(max_depth_array):\n", "    clf = DecisionTreeRegressor(max_depth=max_depth)\n", "    clf.fit(ctrain, ztrain)\n", "\n", "    ztrain_pred = clf.predict(ctrain)\n", "    ztest_pred = clf.predict(ctest)\n", "\n", "    train_outfrac[i] = outlier_fraction(ztrain_pred, ztrain)\n", "    test_outfrac[i] = outlier_fraction(ztest_pred, ztest)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "code", "collapsed": false, "input": [ "plt.plot(max_depth_array, train_outfrac, label='training')\n", "plt.plot(max_depth_array, test_outfrac, label='validation')\n", "plt.grid()\n", "plt.legend(loc=3)\n", "plt.xlabel('max_depth')\n", "plt.ylabel('outlier fraction')" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This outlier-based loss function settles on a much deeper tree. 
Let's\n", "see what the result looks like with a max depth of 20:" ] }, { "cell_type": "code", "collapsed": false, "input": [ "clf = DecisionTreeRegressor(max_depth=20)\n", "clf.fit(ctrain, ztrain)\n", "zpred = clf.predict(ctest)\n", "plot_redshifts(ztest, zpred)" ], "language": "python", "metadata": {}, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unfortunately, this is about as far as we can go with a simple decision tree\n", "trained on this data. There are two more possibilities we can pursue, though:\n", "\n", "- Optimize the data: observe more samples, or observe more features of each sample\n", "  (using learning curves to determine which is better)\n", "- Optimize the model: use a more sophisticated estimator \n", "\n", "These approaches will be exercises below." ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Exercise: Optimizing the Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Earlier we showed how *learning curves* can be used to determine the best course\n", "of action when a model is under-performing: should we gather more samples? Gather\n", "more features? Seek a more sophisticated model?\n", "\n", "The goal of this exercise is to use learning curves to answer this question:\n", "**how should astronomers spend their resources when trying to improve photometric redshifts?**" ] }, { "cell_type": "heading", "level": 2, "metadata": {}, "source": [ "Exercise: Better Results Through Ensemble Methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One way to improve upon Decision Trees is to use *Ensemble Methods*.\n", "Ensemble methods (which include *boosting* and *bagging*) combine many\n", "estimators, often randomized ones, which can reduce over-fitting and lead to very\n", "powerful learning algorithms.\n", "\n", "It is interesting to see how small an RMS you can attain on the photometric\n", "redshift problem using a more sophisticated method. 
Try one of the following:\n", "\n", "- ``sklearn.ensemble.RandomForestRegressor``\n", "- ``sklearn.ensemble.GradientBoostingRegressor``\n", "- ``sklearn.ensemble.ExtraTreesRegressor``\n", "\n", "You can read more about each of these methods in the scikit-learn documentation:\n", "\n", "- http://scikit-learn.org/stable/modules/ensemble.html\n", "\n", "Each method has hyperparameters which need to be determined using cross-validation\n", "steps like those above. Can you use ensemble methods to reduce the RMS error for\n", "the test set to below 0.1?\n", "\n", "Here you can adjust several hyperparameters, but the important ones will be\n", "the number of estimators ``n_estimators`` as well as the maximum depth\n", "``max_depth`` that we saw above." ] } ], "metadata": {} } ] }