{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Relationships" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "[Click here to run this notebook on Colab](https://colab.research.google.com/github/AllenDowney/ElementsOfDataScience/blob/master/09_relationships.ipynb) or\n", "[click here to download it](https://github.com/AllenDowney/ElementsOfDataScience/raw/master/09_relationships.ipynb)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This chapter explores relationships between variables.\n", "\n", "* We will visualize relationships using scatter plots, box plots, and violin plots,\n", "\n", "* And we will quantify relationships using correlation and simple regression.\n", "\n", "The most important lesson in this chapter is that you should always visualize the relationship between variables before you try to quantify it; otherwise, you are likely to be misled." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [] }, "outputs": [], "source": [ "from os.path import basename, exists\n", "\n", "def download(url):\n", "    filename = basename(url)\n", "    if not exists(filename):\n", "        from urllib.request import urlretrieve\n", "        local, _ = urlretrieve(url, filename)\n", "        print('Downloaded ' + local)\n", "\n", "download('https://github.com/AllenDowney/' +\n", "         'ElementsOfDataScience/raw/master/brfss.hdf5')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exploring relationships\n", "\n", "So far we have mostly considered one variable at a time. Now it's time to explore relationships between variables.\n", "As a first example, we'll look at the relationship between height and weight.\n", "\n", "We'll use data from the Behavioral Risk Factor Surveillance System (BRFSS), which is run by the Centers for Disease Control. The survey includes more than 400,000 respondents, but to keep things manageable, we'll work with a random subsample of 100,000." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "brfss = pd.read_hdf('brfss.hdf5', 'brfss')\n", "brfss.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are the first few rows." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "brfss.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The BRFSS includes hundreds of variables. For the examples in this chapter, we'll work with just nine.\n", "The ones we'll start with are `HTM4`, which records each respondent's height in cm, and `WTKG3`, which records weight in kg." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "height = brfss['HTM4']\n", "weight = brfss['WTKG3']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To visualize the relationship between these variables, we'll make a **scatter plot**, which shows one marker for each pair of values.\n", "Scatter plots are common and readily understood, but they are surprisingly hard to get right.\n", "\n", "As a first attempt, we'll use `plot` with the style string `o`, which plots a circle for each data point." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.plot(height, weight, 'o')\n", "\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each marker represents the height and weight of one person.\n", "\n", "Based on the shape of the result, it looks like taller people are heavier, but there are a few things about this plot that make it hard to interpret.\n", "Most importantly, it is **overplotted**, which means that there are markers piled on top of each other so you can't tell where there are a lot of data points and where there is just one.\n", "When that happens, the results can be seriously misleading.\n", "\n", "One way to improve the plot is to use transparency, which we can do with the keyword argument `alpha`. The lower the value of alpha, the more transparent each data point is. \n", "\n", "Here's what it looks like with `alpha=0.02`." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "plt.plot(height, weight, 'o', alpha=0.02)\n", "\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is better, but there are so many data points, the scatter plot is still overplotted. The next step is to make the markers smaller.\n", "With `markersize=1` and a low value of alpha, the scatter plot is less saturated.\n", "Here's what it looks like." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "plt.plot(height, weight, 'o', alpha=0.02, markersize=1)\n", "\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, this is better, but now we can see that the points fall in discrete columns. That's because most heights were reported in inches and converted to centimeters.\n", "We can break up the columns by adding some random noise to the values; in effect, we are filling in the values that got rounded off.\n", "Adding random noise like this is called **jittering**.\n", "\n", "We can use NumPy to add noise from a normal distribution with mean 0 and standard deviation 2." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "noise = np.random.normal(0, 2, size=len(brfss))\n", "height_jitter = height + noise" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's what the plot looks like with jittered heights." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "plt.plot(height_jitter, weight, 'o', \n", " alpha=0.02, markersize=1)\n", "\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The columns are gone, but now we can see that there are rows where people rounded off their weight. We can fix that by jittering weight, too." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "noise = np.random.normal(0, 2, size=len(brfss))\n", "weight_jitter = weight + noise" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "plt.plot(height_jitter, weight_jitter, 'o', \n", " alpha=0.02, markersize=1)\n", "\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's zoom in on the area where most of the data points are.\n", "\n", "The functions `xlim` and `ylim` set the lower and upper bounds for the $x$ and $y$-axis; in this case, we plot heights from 140 to 200 centimeters and weights up to 160 kilograms.\n", "\n", "Here's what it looks like." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "plt.plot(height_jitter, weight_jitter, 'o', \n", " alpha=0.02, markersize=1)\n", "\n", "plt.xlim([140, 200])\n", "plt.ylim([0, 160])\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we have a reliable picture of the relationship between height and weight.\n", "\n", "Below you can see the misleading plot we started with and the more reliable one we ended with. They are clearly different, and they suggest different relationships between these variables." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Set the figure size\n", "plt.figure(figsize=(8, 3))\n", "\n", "# Create subplots with 1 row, 2 columns, and start plot 1\n", "plt.subplot(1, 2, 1)\n", "plt.plot(height, weight, 'o')\n", "\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height')\n", "\n", "# Adjust the layout so the two plots don't overlap\n", "plt.tight_layout()\n", "\n", "# Start plot 2\n", "plt.subplot(1, 2, 2)\n", "\n", "plt.plot(height_jitter, weight_jitter, 'o', \n", " alpha=0.02, markersize=1)\n", "\n", "plt.xlim([140, 200])\n", "plt.ylim([0, 160])\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height')\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The point of this example is that it takes some effort to make an effective scatter plot." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** Do people tend to gain weight as they get older? We can answer this question by visualizing the relationship between weight and age. \n", "\n", "But before we make a scatter plot, it is a good idea to visualize distributions one variable at a time. So let's look at the distribution of age.\n", "\n", "The BRFSS dataset includes a column, `AGE`, which represents each respondent's age in years. To protect respondents' privacy, ages are rounded off into 5-year bins. `AGE` contains the midpoint of the bins.\n", "\n", "- Extract the variable `'AGE'` from the DataFrame `brfss` and assign it to `age`.\n", "\n", "- Plot the PMF of `age` as a bar chart, using `Pmf` from `empiricaldist`." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "tags": [] }, "outputs": [], "source": [ "try:\n", " import empiricaldist\n", "except ImportError:\n", " !pip install empiricaldist" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from empiricaldist import Pmf" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** Now let's look at the distribution of weight. The column that contains weight in kilograms is `WTKG3`. Because this column contains many unique values, displaying it as a PMF doesn't work very well." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "Pmf.from_seq(weight).bar()\n", "\n", "plt.xlabel('Weight in kg')\n", "plt.ylabel('PMF')\n", "plt.title('Distribution of weight');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get a better view of this distribution, try plotting the CDF.\n", "\n", "Compute the CDF of a normal distribution with the same mean and standard deviation, and compare it with the distribution of weight. Is the normal distribution a good model for this data? What about log-transformed weights?" 
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** Now let's make a scatter plot of `weight` versus `age`. Adjust `alpha` and `markersize` to avoid overplotting. Use `ylim` to limit the `y` axis from 0 to 200 kilograms." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** In the previous exercise, the ages fall in columns because they've been rounded into 5-year bins. If we jitter them, the scatter plot will show the relationship more clearly.\n", "\n", "- Add random noise to `age` with mean `0` and standard deviation `2.5`.\n", "- Make a scatter plot and adjust `alpha` and `markersize` again." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing relationships\n", "\n", "In the previous section we used scatter plots to visualize relationships between variables, and in the exercises, you explored the relationship between age and weight. In this section, we'll see other ways to visualize these relationships, including boxplots and violin plots.\n", "\n", "Let's start with a scatter plot of weight versus age." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "age = brfss['AGE']\n", "noise = np.random.normal(0, 1.0, size=len(brfss))\n", "age_jitter = age + noise\n", "\n", "plt.plot(age_jitter, weight_jitter, 'o', \n", " alpha=0.01, markersize=1)\n", "\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.ylim([0, 200])\n", "plt.title('Weight versus age');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this version of the scatter plot, the weights are jittered, but there's still space between the columns.\n", "That makes it possible to see the shape of the distribution in each age group, and the differences between groups.\n", "With this view, it looks like weight increases until age 40 or 50, and then starts to decrease.\n", "\n", "If we take this idea one step farther, we can use KDE to estimate the density function in each column and plot it.\n", "And there's a name for that; it's called a **violin plot**. Seaborn provides a function that makes violin plots, but before we can use it, we have to get rid of any rows with missing data.\n", "Here's how:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "data = brfss.dropna(subset=['AGE', 'WTKG3'])\n", "data.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`dropna()` creates a new DataFrame that drops the rows in `brfss` where `AGE` or `WTKG3` are `NaN`.\n", "Now we can call `violinplot`." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "import seaborn as sns\n", "\n", "sns.violinplot(x='AGE', y='WTKG3', data=data, inner=None)\n", "\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Weight versus age');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `x` and `y` arguments mean we want `AGE` on the x-axis and `WTKG3` on the y-axis. 
`data` is the DataFrame we just created, which contains the variables we're going to plot. The argument `inner=None` simplifies the plot a little.\n", "\n", "In the figure, each shape represents the distribution of weight in one age group. The width of these shapes is proportional to the estimated density, so it's like two vertical KDEs plotted back to back.\n", "\n", "Another, related way to look at data like this is called a **box plot**, which represents summary statistics for the values in each group. \n", "The code to generate a box plot is very similar. " ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "sns.boxplot(x='AGE', y='WTKG3', data=data, whis=10)\n", "\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Weight versus age');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The argument `whis=10` turns off a feature we don't need. If you are curious about it, you can [read the documentation](https://seaborn.pydata.org/generated/seaborn.boxplot.html).\n", "\n", "Each box represents the distribution of weight in an age group. The height of each box represents the range from the 25th to the 75th percentile. 
The line in the middle of each box is the median.\n", "The spines sticking out of the top and bottom show the minimum and maximum values.\n", "\n", "In my opinion, this plot gives us the best view of the relationship between weight and age.\n", "\n", "* Looking at the medians, it seems like people in their 40s are the heaviest; younger and older people are lighter.\n", "\n", "* Looking at the sizes of the boxes, it seems like people in their 40s have the most variability in weight, too.\n", "\n", "* These plots also show how skewed the distribution of weight is; that is, the heaviest people are much farther from the median than the lightest people.\n", "\n", "For data that skews toward higher values, it is sometimes useful to look at it on a logarithmic scale.\n", "We can do that with the Pyplot function `yscale`." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "sns.boxplot(x='AGE', y='WTKG3', data=data, whis=10)\n", "\n", "plt.yscale('log')\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg (log scale)')\n", "plt.title('Weight versus age');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On a log scale, the distributions are symmetric, so the spines are the same length, the boxes are close to the middle of the figure, and we can see the relationship between age and weight more clearly.\n", "\n", "In the following exercises, you'll have a chance to generate violin and box plots for other variables." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** Previously we looked at a scatter plot of height and weight, and saw that taller people tend to be heavier. Now let's take a closer look using a box plot. 
The `brfss` DataFrame contains a column named `_HTMG10` that represents height in centimeters, binned into 10 cm groups.\n", "\n", "- Make a boxplot that shows the distribution of weight in each height group.\n", "\n", "- Plot the y-axis on a logarithmic scale.\n", "\n", "Suggestion: If the labels on the `x` axis collide, you can rotate them like this:\n", "\n", "```\n", "plt.xticks(rotation=45)\n", "```" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** As a second example, let's look at the relationship between income and height. \n", "\n", "In the BRFSS, income is represented as a categorical variable; that is, respondents are assigned to one of 8 income categories. The column name is `INCOME2`. \n", "\n", "Before we connect income with anything else, let's look at the distribution by computing the PMF.\n", "\n", "* Extract `INCOME2` from `brfss` and assign it to `income`.\n", "\n", "* Plot the PMF of `income` as a bar chart.\n", "\n", "Note: You will see that about a third of the respondents are in the highest income group; ideally, it would be better if there were more groups at the high end, but we'll work with what we have." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** Generate a violin plot that shows the distribution of height in each income group. Can you see a relationship between these variables?" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Quantifying Correlation\n", "\n", "In the previous section, we visualized relationships between pairs of variables. 
Now we'll learn about the **coefficient of correlation**, which quantifies the strength of these relationships.\n", "\n", "When people say \"correlation\" casually, they might mean any relationship between two variables. In statistics, it usually means Pearson's correlation coefficient, which is a number between `-1` and `1` that quantifies the strength of a linear relationship between variables.\n", "\n", "To demonstrate, we'll select three columns from the BRFSS dataset:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "columns = ['HTM4', 'WTKG3', 'AGE']\n", "subset = brfss[columns]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is a DataFrame with just those columns.\n", "With this subset of the data, we can use the `corr` method, like this:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "subset.corr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is a **correlation matrix**. Reading across the first row, the correlation of `HTM4` with itself is `1`. That's expected; the correlation of anything with itself is `1`.\n", "\n", "The next entry is more interesting; the correlation of height and weight is about `0.47`. It's positive, which means taller people are heavier, and it is moderate in strength, which means it has some predictive value, but not much. If you know someone's height, you can make a somewhat better guess about their weight.\n", "\n", "The correlation between height and age is about `-0.09`. It's negative, which means that older people tend to be shorter, but it's weak, which means that knowing someone's age would not help much if you were trying to guess their height." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The correlation between age and weight is even smaller. It is tempting to conclude that there is no relationship between age and weight, but we have already seen that there is. 
So why is the correlation so low?\n", "Remember that the relationship between weight and age looks like this. " ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "tags": [] }, "outputs": [], "source": [ "data = brfss.dropna(subset=['AGE', 'WTKG3'])\n", "sns.boxplot(x='AGE', y='WTKG3', data=data, whis=10)\n", "\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Weight versus age');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "People in their forties are the heaviest; younger and older people are lighter. So this relationship is nonlinear.\n", "But correlation only measures linear relationships. If the relationship is nonlinear, correlation generally underestimates how strong it is.\n", "\n", "To demonstrate, I'll generate some fake data: `xs` contains equally-spaced points between `-1` and `1`. `ys` is `xs` squared plus some random noise." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "xs = np.linspace(-1, 1)\n", "ys = xs**2 + np.random.normal(0, 0.05, len(xs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's the scatter plot of `xs` and `ys`. 
" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "plt.plot(xs, ys, 'o', alpha=0.5)\n", "plt.xlabel('x')\n", "plt.ylabel('y')\n", "plt.title('Scatter plot of a fake dataset');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's clear that this is a strong relationship; if you are given `x`, you can make a much better guess about `y`.\n", "But here's the correlation matrix:" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "np.corrcoef(xs, ys)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Even though there is a strong non-linear relationship, the computed correlation is close to `0`.\n", "\n", "In general, if correlation is high -- that is, close to `1` or `-1` -- you can conclude that there is a strong linear relationship.\n", "But if correlation is close to `0`, that doesn't mean there is no relationship; there might be a non-linear relationship.\n", "\n", "This is one of the reasons I think correlation is not such a great statistic.\n", "There's another reason to be careful with correlation; it doesn't mean what people take it to mean.\n", "Specifically, correlation says nothing about slope. If we say that two variables are correlated, that means we can use one to predict the other. But that might not be what we care about.\n", "\n", "For example, suppose we are concerned about the health effects of weight gain, so we plot weight versus age from 20 to 50 years old.\n", "I'll generate two fake datasets to demonstrate the point.\n", "In each dataset, `xs` represents age and `ys` represents weight." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "I use `np.random.seed` to initialize the random number generator so we get the same results every time we run." 
] }, { "cell_type": "code", "execution_count": 39, "metadata": { "tags": [] }, "outputs": [], "source": [ "np.random.seed(18)\n", "xs1 = np.linspace(20, 50)\n", "ys1 = 75 + 0.02 * xs1 + np.random.normal(0, 0.15, len(xs1))\n", "\n", "plt.plot(xs1, ys1, 'o', alpha=0.5)\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Fake dataset #1');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here's the second dataset:" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "tags": [] }, "outputs": [], "source": [ "np.random.seed(18)\n", "xs2 = np.linspace(20, 50)\n", "ys2 = 65 + 0.2 * xs2 + np.random.normal(0, 3, len(xs2))\n", "\n", "plt.plot(xs2, ys2, 'o', alpha=0.5)\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Fake dataset #2');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I constructed these examples so they look similar, but they have substantially different correlations: " ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "rho1 = np.corrcoef(xs1, ys1)[0][1]\n", "rho1" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "rho2 = np.corrcoef(xs2, ys2)[0][1]\n", "rho2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the first example, the correlation is strong, close to `0.75`.\n", "In the second example, the correlation is moderate, close to `0.5`.\n", "So we might think the first relationship is more important. But look more closely at the `y` axis in both figures.\n", "\n", "In the first example, the average weight gain over 30 years is less than 1 kilogram; in the second it is more than 5 kilograms!\n", "If we are concerned about the health effects of weight gain, the second relationship is probably more important, even though the correlation is lower. 
\n", "The statistic we really care about is the slope of the line, not the coefficient of correlation.\n", "\n", "In the next section, we'll see how to estimate that slope. But first, let's practice with correlation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** The purpose of the BRFSS is to explore health risk factors, so it includes questions about diet. The column `_VEGESU1` represents the number of servings of vegetables respondents reported eating per day.\n", "\n", "Let's see how this variable relates to age and income.\n", "\n", "- From the `brfss` DataFrame, select the columns `'AGE'`, `INCOME2`, and `_VEGESU1`.\n", "- Compute the correlation matrix for these variables." ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** In the previous exercise, the correlation between income and vegetable consumption is about `0.12`. The correlation between age and vegetable consumption is about `-0.01`.\n", "\n", "Which of the following are correct interpretations of these results?\n", "\n", "- *A*: People in this dataset with higher incomes eat more vegetables.\n", "- *B*: The relationship between income and vegetable consumption is linear.\n", "- *C*: Older people eat more vegetables.\n", "- *D*: There could be a strong non-linear relationship between age and vegetable consumption." ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** In general it is a good idea to visualize the relationship between variables *before* you compute a correlation. We didn't do that in the previous example, but it's not too late.\n", "\n", "Generate a visualization of the relationship between age and vegetables. How would you describe the relationship, if any?" 
] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple Linear Regression\n", "\n", "In the previous section we saw that correlation does not always measure what we really want to know. In this section, we look at an alternative: simple linear regression.\n", "\n", "Let's look again at the relationship between weight and age. In the previous section, I generated two fake datasets to make a point:" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "tags": [] }, "outputs": [], "source": [ "plt.figure(figsize=(8, 3))\n", "\n", "plt.subplot(1, 2, 1)\n", "plt.plot(xs1, ys1, 'o', alpha=0.5)\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Fake dataset #1')\n", "plt.tight_layout()\n", "\n", "plt.subplot(1, 2, 2)\n", "plt.plot(xs2, ys2, 'o', alpha=0.5)\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Fake dataset #2')\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The one on the left has higher correlation, about 0.75 compared to 0.5.\n", "But in this context, the statistic we probably care about is the slope of the line, not the correlation coefficient.\n", "To estimate the slope, we can use `linregress` from the SciPy `stats` library." ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "from scipy.stats import linregress\n", "\n", "res1 = linregress(xs1, ys1)\n", "res1._asdict()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is a `LinregressResult` object that contains five values: `slope` is the slope of the line of best fit for the data; `intercept` is the intercept. We'll interpret some of the other values later.\n", "\n", "For Fake Dataset #1, the estimated slope is about 0.019 kilograms per year or about 0.56 kilograms over the 30-year range." 
] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "res1.slope * 30" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are the results for Fake Dataset #2. " ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "res2 = linregress(xs2, ys2)\n", "res2._asdict()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The estimated slope is almost 10 times higher: about 0.18 kilograms per year or about 5.3 kilograms per 30 years:" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "res2.slope * 30" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What's called `rvalue` here is correlation, which confirms what we saw before; the first example has higher correlation, about 0.75 compared to 0.5.\n", "But the strength of the effect, as measured by the slope of the line, is about 10 times higher in the second example.\n", "\n", "We can use the results from `linregress` to compute the line of best fit: first we get the minimum and maximum of the observed `xs`; then we multiply by the slope and add the intercept.\n", "Here's what that looks like for the first example. " ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "plt.plot(xs1, ys1, 'o', alpha=0.5)\n", "\n", "fx = np.array([xs1.min(), xs1.max()])\n", "fy = res1.intercept + res1.slope * fx\n", "plt.plot(fx, fy, '-')\n", "\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Fake Dataset #1');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here's what it looks like for the second example. 
" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "plt.plot(xs2, ys2, 'o', alpha=0.5)\n", "\n", "fx = np.array([xs2.min(), xs2.max()])\n", "fy = res2.intercept + res2.slope * fx\n", "plt.plot(fx, fy, '-')\n", "\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Fake Dataset #2');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The visualization here might be misleading unless you look closely at the vertical scales; the slope in the second figure is almost 10 times higher." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Regression of Height and Weight\n", "\n", "Now let's look at an example of regression with real data.\n", "Here's the scatter plot of height and weight one more time." ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "plt.plot(height_jitter, weight_jitter, 'o', \n", " alpha=0.02, markersize=1)\n", "\n", "plt.xlim([140, 200])\n", "plt.ylim([0, 160])\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To compute the regression line, we'll use `linregress` again.\n", "But it can't handle `NaN` values, so we have to use `dropna` to remove rows that are missing the data we need." ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "subset = brfss.dropna(subset=['WTKG3', 'HTM4'])\n", "height_clean = subset['HTM4']\n", "weight_clean = subset['WTKG3']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can compute the linear regression." 
] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "res_hw = linregress(height_clean, weight_clean)\n", "res_hw._asdict()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The slope is about 0.92 kilograms per centimeter, which means that we expect a person one centimeter taller to be almost a kilogram heavier. That's quite a lot.\n", "\n", "As before, we can compute the line of best fit:" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "fx = np.array([height_clean.min(), height_clean.max()])\n", "fy = res_hw.intercept + res_hw.slope * fx" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here's what that looks like. " ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "plt.plot(height_jitter, weight_jitter, 'o', alpha=0.02, markersize=1)\n", "\n", "plt.plot(fx, fy, '-')\n", "\n", "plt.xlim([140, 200])\n", "plt.ylim([0, 160])\n", "plt.xlabel('Height in cm')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Scatter plot of weight versus height');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The slope of this line seems consistent with the scatter plot.\n", "\n", "Linear regression has the same problem as correlation; it only measures the strength of a linear relationship.\n", "Here's the scatter plot of weight versus age, which we saw earlier." ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "plt.plot(age_jitter, weight_jitter, 'o', \n", " alpha=0.01, markersize=1)\n", "\n", "plt.ylim([0, 160])\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Weight versus age');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "People in their 40s are the heaviest; younger and older people are lighter. 
So the relationship is nonlinear.\n", "\n", "If we skip the scatter plot and blindly compute the regression line, here's what we get." ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "subset = brfss.dropna(subset=['WTKG3', 'AGE'])\n", "age_clean = subset['AGE']\n", "weight_clean = subset['WTKG3']\n", "\n", "res_aw = linregress(age_clean, weight_clean)\n", "res_aw._asdict()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The estimated slope is only 0.02 kilograms per year, or 0.6 kilograms in 30 years.\n", "And here's what the line of best fit looks like." ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "plt.plot(age_jitter, weight_jitter, 'o', \n", " alpha=0.01, markersize=1)\n", "\n", "fx = np.array([age_clean.min(), age_clean.max()])\n", "fy = res_aw.intercept + res_aw.slope * fx\n", "plt.plot(fx, fy, '-')\n", "\n", "plt.ylim([0, 160])\n", "plt.xlabel('Age in years')\n", "plt.ylabel('Weight in kg')\n", "plt.title('Weight versus age');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A straight line does not capture the relationship between these variables well.\n", "\n", "In the next chapter, you'll see how to use multiple regression to estimate nonlinear relationships. But first, let's practice simple regression." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** Who do you think eats more vegetables, people with low income, or people with high income? Let's find out.\n", "\n", "As we've seen previously, the column `INCOME2` represents income level and `_VEGESU1` represents the number of vegetable servings respondents reported eating per day.\n", "\n", "Make a scatter plot with vegetable servings versus income, that is, with vegetable servings on the `y` axis and income group on the `x` axis.\n", "\n", "You might want to use `ylim` to zoom in on the bottom half of the `y` axis." 
] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** Now let's estimate the slope of the relationship between vegetable consumption and income.\n", "\n", "- Use `dropna` to select rows where `INCOME2` and `_VEGESU1` are not `NaN`.\n", "\n", "- Extract `INCOME2` and `_VEGESU1` and compute the simple linear regression of these variables.\n", "\n", "What is the slope of the regression line? What does this slope mean in the context of the question we are exploring?" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Exercise:** Finally, plot the regression line on top of the scatter plot.\n" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "# Solution goes here" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "This chapter presents three ways to visualize the relationship between two variables: a scatter plot, violin plot, and box plot.\n", "A scatter plot is often a good choice when you are exploring a new dataset, but it can take some attention to avoid overplotting.\n", "Violin and box plots are particularly useful when one of the variables only takes 
on a few discrete values.\n", "\n", "And we considered two ways to quantify the strength of a relationship: the coefficient of correlation and the slope of a regression line.\n", "These statistics capture different aspects of what we might mean by \"strength\".\n", "The coefficient of correlation indicates how well we can predict one variable, given the other.\n", "The slope of the regression line indicates how much difference we expect in one variable as we vary the other.\n", "One or the other might be more relevant, depending on the context." ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "*Elements of Data Science*\n", "\n", "Copyright 2021 [Allen B. Downey](https://allendowney.com)\n", "\n", "License: [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" } }, "nbformat": 4, "nbformat_minor": 2 }