{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Relationships\n",
"\n",
"Elements of Data Science\n",
"\n",
"by [Allen Downey](https://allendowney.com)\n",
"\n",
"[MIT License](https://opensource.org/licenses/MIT)\n",
"\n",
"### Goals\n",
"\n",
"This notebook explores relationships between variables.\n",
"\n",
"* We will visualize relationships using scatter plots, box plots, and violin plots,\n",
"\n",
"* And we will quantify relationships using correlation and simple regression.\n",
"\n",
"The most important lesson in this notebook is that you should always visualize the relationship between variables before you try to quantify it; otherwise, you are likely to be misled."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Get the data file\n",
"\n",
"import os\n",
"\n",
"if not os.path.exists('brfss.hdf5'):\n",
" !wget https://github.com/AllenDowney/ElementsOfDataScience/raw/master/brfss.hdf5"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# If we're running on Colab, install empiricaldist\n",
"# https://pypi.org/project/empiricaldist/\n",
"\n",
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"if IN_COLAB:\n",
" !pip install empiricaldist"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exploring relationships\n",
"\n",
"So far we have only looked at one variable at a time. Now it's time to explore relationships between variables.\n",
"\n",
"As a first example, we'll look at the relationship between height and weight.\n",
"\n",
"I'll use data from the Behavioral Risk Factor Surveillance Survey, or BRFSS, which is run by the Centers for Disease Control. The survey includes more than 400,000 respondents, but to keep things manageable, we'll use a random subsample of 100,000."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"brfss = pd.read_hdf('brfss.hdf5', 'brfss')\n",
"brfss.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Like the NSFG, the BRFSS deliberately oversamples some groups, so each respondent has a sampling weight, stored in the `_LLCPWT` column. I used these weights to resample the data, so the subset we just loaded is representative of adult residents of the U.S.\n",
"\n",
"The BRFSS includes hundreds of variables. For the examples in this notebook, I chose just nine. I'll explain what the columns mean as we go along."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"brfss.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A common way to visualize the relationship between two variables is a \"scatter plot\".\n",
"\n",
"Scatter plots are common and readily understood, but they are surprisingly hard to get right.\n",
"\n",
"To demonstrate, we'll explore the relationship between height and weight. First I'll extract the columns for height in centimeters and weight in kilograms."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"height = brfss['HTM4']\n",
"weight = brfss['WTKG3']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make a scatter plot we'll use `plot()` with the style string `o`, which plots a circle for each data point.\n",
"\n",
"Here's how:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.plot(height, weight, 'o')\n",
"\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In general, it looks like taller people are heavier, but there are a few things about this scatter plot that make it hard to interpret.\n",
"\n",
"Most importantly, it is \"overplotted\", which means that there are data points piled on top of each other so you can't tell where there are a lot of points and where there is just one.\n",
"\n",
"When that happens, the results can be really misleading.\n",
"\n",
"One way to improve the plot is to use transparency, which we can do with the keyword argument `alpha`. The lower the value of alpha, the more transparent each data point is. \n",
"\n",
"Here's what it looks like with `alpha=0.02`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(height, weight, 'o', alpha=0.02)\n",
"\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is better, but there are so many data points, the scatter plot is still overplotted. The next step is to make the markers smaller.\n",
"\n",
"With `markersize=1` and a low value of alpha, the scatter plot is less saturated.\n",
"\n",
"Here's what it looks like."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(height, weight, 'o', alpha=0.02, markersize=1)\n",
"\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, this is better, but now we can see that the points fall in discrete columns. That's because most heights were reported in inches and converted to centimeters.\n",
"\n",
"We can break up the columns by adding some random noise to the values; in effect, we are filling in the values that got rounded off.\n",
"\n",
"Adding random noise like this is called \"jittering\".\n",
"\n",
"In this example, I added noise with mean 0 and standard deviation 2.\n",
"\n",
"Here's what the plot looks like when we jitter height."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"height_jitter = height + np.random.normal(0, 2, size=len(brfss))\n",
"\n",
"plt.plot(height_jitter, weight, 'o', alpha=0.02, markersize=1)\n",
"\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The columns are gone, but now we can see that there are rows where people rounded off their weight. We can fix that by jittering weight, too."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"weight_jitter = weight + np.random.normal(0, 2, size=len(brfss))\n",
"\n",
"plt.plot(height_jitter, weight_jitter, 'o', alpha=0.02, markersize=1)\n",
"\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's zoom in on the area where most of the data points are.\n",
"\n",
"The Pyplot functions `xlim()` and `ylim()` set the lower and upper bounds for the x- and y-axis; in this case, we plot heights from 140 to 200 centimeters and weights up to 160 kilograms.\n",
"\n",
"Here's what it looks like."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"weight_jitter = weight + np.random.normal(0, 2, size=len(brfss))\n",
"\n",
"plt.plot(height_jitter, weight_jitter, 'o', alpha=0.02, markersize=1)\n",
"\n",
"plt.xlim([140, 200])\n",
"plt.ylim([0, 160])\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have a reliable picture of the relationship between height and weight.\n",
"\n",
"Below you can see the misleading plot we started with and the more reliable one we ended with. They are clearly different, and they suggest different stories about the relationship between these variables."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Set the figure size\n",
"plt.figure(figsize=(6, 8))\n",
"\n",
"# Create subplots with 2 rows, 1 column, and start plot 1\n",
"plt.subplot(2, 1, 1)\n",
"plt.plot(height, weight, 'o')\n",
"\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height')\n",
"\n",
"# Adjust the layout so the two plots don't overlap\n",
"plt.tight_layout()\n",
"\n",
"# Start plot 2\n",
"plt.subplot(2, 1, 2)\n",
"\n",
"weight_jitter = weight + np.random.normal(0, 2, size=len(brfss))\n",
"\n",
"plt.plot(height_jitter, weight_jitter, 'o', alpha=0.02, markersize=1)\n",
"\n",
"plt.xlim([140, 200])\n",
"plt.ylim([0, 160])\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height')\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The point of this example is that it takes some effort to make an effective scatter plot."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Do people tend to gain weight as they get older? We can answer this question by visualizing the relationship between weight and age. \n",
"\n",
"But before we make a scatter plot, it is a good idea to visualize distributions one variable at a time. So let's look at the distribution of age.\n",
"\n",
"The BRFSS dataset includes a column, `AGE`, which represents each respondent's age in years. To protect respondents' privacy, ages are rounded off into 5-year bins. `AGE` contains the midpoint of the bins.\n",
"\n",
"- Extract the variable `'AGE'` from the DataFrame `brfss` and assign it to `age`.\n",
"\n",
"- Plot the PMF of `age` as a bar chart."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from empiricaldist import Pmf"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Now let's look at the distribution of weight. The column that contains weight in kilograms is `WTKG3`. This column contains many unique values, if we display it using a PMF, it doesn't work very well."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a better view of this distribution, try plotting the CDF.\n",
"\n",
"Also try plotting it on a log-x scale."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"from empiricaldist import Cdf"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Optional Exercise:** Compute the CDF of a normal distribution and compare it with the distribution of weight. Is the normal distribution a good model for this data? What about log-transformed weights?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Now let's make a scatter plot of `weight` versus `age`. Adjust `alpha` and `markersize` to avoid overplotting. Use `ylim` to limit the `y` axis from 0 to 200 kilograms."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** In the previous exercise, the ages fall in columns because they've been rounded into 5-year bins. If we jitter them, the scatter plot will show the relationship more clearly.\n",
"\n",
"- Add random noise to `age` with mean `0` and standard deviation `2.5`.\n",
"- Make a scatter plot and adjust `alpha` and `markersize` again."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing relationships\n",
"\n",
"In the previous section we used scatter plots to visualize relationships between variables, and in the exercises, you explored the relationship between age and weight. In this section, we'll see other ways to visualize these relationships, including boxplots and violin plots.\n",
"\n",
"In a previous exercise, you made a scatter plot of weight versus age that might have looked like this:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# Add jittering to age\n",
"age_jitter = age + np.random.normal(0, 2.5, size=len(brfss))\n",
"\n",
"# Make a scatter plot\n",
"plt.plot(age_jitter, weight, 'o', alpha=0.01, markersize=1)\n",
"\n",
"# Decorate the axes\n",
"plt.ylim([0, 200])\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Weight versus age');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It looks like older people might be heavier, but it is hard to see clearly.\n",
"\n",
"Here's another version of the same plot:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Jitter weight\n",
"weight_jitter = brfss['WTKG3'] + np.random.normal(0, 2, size=len(brfss))\n",
"\n",
"# Jitter age\n",
"age_jitter = age + np.random.normal(0, 0.75, size=len(brfss))\n",
"\n",
"# Make a scatter plot\n",
"plt.plot(age_jitter, weight_jitter, 'o', alpha=0.01, markersize=1)\n",
"\n",
"# Decorate the axes\n",
"plt.ylim([0, 200])\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Weight versus age');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's what I changed:\n",
"\n",
"* First, I jittered the weights, so the horizontal rows are not visible.\n",
"\n",
"* Second, I adjusted the jittering of the weights so there's still space between the columns.\n",
"\n",
"That makes it possible to see the shape of the distribution in each age group, and the differences between groups.\n",
"\n",
"With this view, it looks like weight increases until age 40 or 50, and then starts to decrease."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we take this idea one step farther, we can use KDE to estimate the density function in each column and plot it.\n",
"\n",
"And there's a name for that; it's called a violin plot. Seaborn provides a function that makes violin plots, but before we can use it, we have to get rid of any rows with missing data.\n",
"\n",
"Here's how:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"brfss.shape"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"data = brfss.dropna(subset=['AGE', 'WTKG3'])\n",
"data.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`dropna()` creates a new DataFrame that contains the rows from `brfss` where `AGE` and `WTKG3` are not NaN.\n",
"\n",
"Now we can call `violinplot()`."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"\n",
"sns.violinplot(x='AGE', y='WTKG3', data=data, inner=None)\n",
"\n",
"# Decorate the axes\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Weight versus age');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `x` and `y` arguments mean we want `AGE` on the x-axis and `WTKG3` on the y-axis. `data` is the DataFrame we just created, which contains the variables we're going to plot. The argument `inner=None` simplifies the plot a little.\n",
"\n",
"In the figure, each shape represents the distribution of weight in one age group. The width of these shapes is proportional to the estimated density, so it's like two vertical KDEs plotted back to back (and filled in with nice colors).\n",
"\n",
"There's yet another way to look at data like this, called a \"box plot\". The code to generate a box plot is very similar. "
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"sns.boxplot(x='AGE', y='WTKG3', data=data, whis=10)\n",
"\n",
"# Decorate the axes\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Weight versus age');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I included the argument `whis=10` to turn off a feature we don't need. If you are curious about it, you can read the documentation.\n",
"\n",
"Each box represents the distribution of weight in an age group. The height of each box represents the range from the 25th to the 75th percentile. The line in the middle of each box is the median.\n",
"\n",
"The spines sticking out of the top and bottom show the minimum and maximum values.\n",
"\n",
"In my opinion, this plot gives us the best view of the relationship between weight and age.\n",
"\n",
"Looking at the medians, it seems like people in their 40s are the heaviest; younger and older people are lighter.\n",
"\n",
"Looking at the sizes of the boxes, it seems like people in their 40s have the most variability in weight, too.\n",
"\n",
"These plots also show how skewed the distribution of weight is; that is, the heaviest people are much farther from the median than the lightest people.\n",
"\n",
"For data that skews toward higher values, it is sometimes useful to look at it on a logarithmic scale.\n",
"\n",
"We can do that with the Pyplot function `yscale()`."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"sns.boxplot(x='AGE', y='WTKG3', data=data, whis=10)\n",
"\n",
"# Decorate the axes\n",
"plt.yscale('log')\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg (log scale)')\n",
"plt.title('Weight versus age');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's what it looks like. To show the relationship between age and weight most clearly, this is probably the figure I would use.\n",
"\n",
"In the exercises, you'll have a chance to generate violin and box plots."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Previously we looked at a scatter plot of height and weight, and saw that taller people tend to be heavier. Now let's take a closer look using a box plot. The `brfss` DataFrame contains a column named `_HTMG10` that represents height in centimeters, binned into 10 cm groups.\n",
"\n",
"- Make a boxplot that shows the distribution of weight in each height group.\n",
"\n",
"- Plot the y-axis on a logarithmic scale.\n",
"\n",
"Suggestion: If the labels on the `x` axis collide, you can rotate them like this:\n",
"\n",
"```\n",
"plt.xticks(rotation='45')\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Now let's look at relationships between income and other variables. \n",
"\n",
"In the BRFSS, income is represented as a categorical variable; that is, respondents are assigned to one of 8 income categories. The column name is `INCOME2`. \n",
"\n",
"Before we connect income with anything else, let's look at the distribution by computing the PMF.\n",
"\n",
"- Extract `INCOME2` from the `brfss` DataFrame and assign it to `income`.\n",
"- Plot the PMF of `income` as a bar chart.\n",
"\n",
"Note: You will see that about a third of the respondents are in the highest income group; ideally, it would be better if there were more groups at the high end, but we'll work with what we have."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Generate a violin plot that shows the distribution of height in each income group. Can you see a relationship between these variables?"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Correlation\n",
"\n",
"In the previous lesson, we visualized relationships between pairs of variables. In this lesson we'll learn about the coefficient of correlation, which quantifies the strength of these relationships.\n",
"\n",
"When people say \"correlation\" casually, they might mean any relationship between two variables. In statistics, it usually means Pearson's correlation coefficient, which is a number between -1 and 1 that quantifies the strength of a linear relationship between variables.\n",
"\n",
"To demonstrate, I'll select three columns from the BRFSS dataset, like this. The result is a DataFrame with just those columns."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"columns = ['HTM4', 'WTKG3', 'AGE']\n",
"subset = brfss[columns]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use the `.corr()` method, like this."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"subset.corr()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result is a \"correlation matrix\". Reading across the first row, the correlation of `HTM4` with itself is 1. That's expected; the correlation of anything with itself is 1.\n",
"\n",
"The next entry is more interesting; the correlation of height and weight is about 0.47. It's positive, which means taller people are heavier, and it is moderate in strength, which means it has some predictive value. If you know someone's height, you can make a better guess about their weight, and vice versa.\n",
"\n",
"The correlation between height and age is about -0.09. It's negative, which means that older people tend to be shorter, but it's weak, which means that knowing someone's age would not help much if you were trying to guess their height.\n",
"\n",
"The correlation between age and weight is even smaller. It is tempting to conclude that there is no relationship between age and weight, but we have already seen that there is. So why is the correlation so low?\n",
"\n",
"Remember that the relationship between weight and age looks like this. "
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"data = brfss.dropna(subset=['AGE', 'WTKG3'])\n",
"sns.boxplot(x='AGE', y='WTKG3', data=data, whis=10)\n",
"\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Weight versus age');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" \n",
"\n",
"People in their 40s are the heaviest; younger and older people are lighter. So this relationship is nonlinear.\n",
"\n",
"But correlation only measures linear relationships. If the relationship is nonlinear, correlation generally underestimates how strong it is.\n",
"\n",
"To demonstrate, I'll generate some fake data: `xs` contains equally-spaced points between -1 and 1. `ys` is `xs` squared plus some random noise."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(18)\n",
"xs = np.linspace(-1, 1)\n",
"ys = xs**2 + np.random.normal(0, 0.05, len(xs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's the scatter plot of `xs` and `ys`. "
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(xs, ys, 'o', alpha=0.5)\n",
"plt.xlabel('x')\n",
"plt.ylabel('y')\n",
"plt.title('Scatter plot of a fake dataset');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's clear that this is a strong relationship; if you are given `x`, you can make a much better guess about `y`.\n",
"\n",
"But here's the correlation matrix:"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"np.corrcoef(xs, ys)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even though there is a strong non-linear relationship, the computed correlation is close to 0.\n",
"\n",
"In general, if correlation is high -- that is, close to 1 or -1, you can conclude that there is a strong linear relationship.\n",
"\n",
"But if correlation is close to 0, that doesn't mean there is no relationship; there might be a non-linear relationship.\n",
"\n",
"This is one of the reasons I think correlation is not such a great statistic.\n",
"\n",
"There's another reason to be careful with correlation; it doesn't mean what people take it to mean.\n",
"\n",
"Specifically, correlation says nothing about slope. If we say that two variables are correlated, that means we can use one to predict the other. But that might not be what we care about.\n",
"\n",
"For example, suppose we are concerned about the health effects of weight gain, so we plot weight versus age, from 20 to 50 years old.\n",
"\n",
"I'll generate two fake datasets to demonstrate the point.\n",
"\n",
"In each dataset, `xs` represents age and `ys` represents weight.\n",
"\n",
"I use `np.random.seed` to initialize the random number generator so we get the same results every time we run."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(18)\n",
"xs1 = np.linspace(20, 50)\n",
"ys1 = 75 + 0.02 * xs1 + np.random.normal(0, 0.15, len(xs1))\n",
"\n",
"plt.plot(xs1, ys1, 'o', alpha=0.5)\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Fake dataset #1');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And here's the second dataset:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(18)\n",
"xs2 = np.linspace(20, 50)\n",
"ys2 = 65 + 0.2 * xs2 + np.random.normal(0, 3, len(xs2))\n",
"\n",
"plt.plot(xs2, ys2, 'o', alpha=0.5)\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Fake dataset #2');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I constructed these examples so they look similar, but they have substantially different correlations: "
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"rho1 = np.corrcoef(xs1, ys1)[0][1]\n",
"rho1"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"rho2 = np.corrcoef(xs2, ys2)[0][1]\n",
"rho2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the first example, the correlation is strong, close to 0.75.\n",
"\n",
"In the second example, the correlation is moderate, close to 0.5.\n",
"\n",
"So we might think the first relationship is more important. But look more closely at the `y` axis in both figures.\n",
"\n",
"In the first example, the average weight gain over 30 years is less than 1 kilogram; in the second it is more than 5 kilograms!\n",
"\n",
"If we are concerned about the health effects of weight gain, the second relationship is probably more important, even though the correlation is lower. \n",
"\n",
"The statistic we really care about is the slope of the line, not the coefficient of correlation.\n",
"\n",
"In the next lesson, you'll learn how to estimate that slope. But first, let's practice with correlation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** The purpose of the BRFSS is to explore health risk factors, so it includes questions about diet. The column `_VEGESU1` represents the number of servings of vegetables respondents reported eating per day.\n",
"\n",
"Let's see how this variable relates to age and income.\n",
"\n",
"- From the `brfss` DataFrame, select the columns `'AGE'`, `INCOME2`, and `_VEGESU1`.\n",
"- Compute the correlation matrix for these variables."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** In the previous exercise, the correlation between income and vegetable consumption is about `0.12`. The correlation between age and vegetable consumption is about `-0.01`.\n",
"\n",
"Which of the following are correct interpretations of these results?\n",
"\n",
"- *A*: People in this dataset with higher incomes eat more vegetables.\n",
"- *B*: The relationship between income and vegetable consumption is linear.\n",
"- *C*: Older people eat more vegetables.\n",
"- *D*: There could be a strong non-linear relationship between age and vegetable consumption."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** In general it is a good idea to visualize the relationship between variables *before* you compute a correlation. We didn't do that in the previous example, but it's not too late.\n",
"\n",
"Generate a visualization of the relationship between age and vegetables. How would you describe the relationship, if any?"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple regression\n",
"\n",
"In the previous section we saw that correlation does not always measure what we really want to know. In this section, we look at an alternative: simple linear regression.\n",
"\n",
"Let's look again at the relationship between weight and age. In the previous section, I generated two fake datasets to make a point:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(9, 4))\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(xs1, ys1, 'o', alpha=0.5)\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Fake dataset #1')\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(xs2, ys2, 'o', alpha=0.5)\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Fake dataset #2');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"They look similar, and one on the left has higher correlation, about 0.75 compared to 0.5.\n",
"\n",
"But in this context, the statistic we probably care about is the slope of the line, not the correlation coefficient.\n",
"\n",
"To estimate the slope, we can use `linregress()` from the SciPy `stats` module."
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"from scipy.stats import linregress\n",
"\n",
"# Fake dataset 1\n",
"res1 = linregress(xs1, ys1)\n",
"res1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result is a `LinregressResult` that contains five values: `slope` is the slope of the line of best fit for the data; `intercept` is the intercept.\n",
"\n",
"For Fake dataset #1, the estimated slope is about 0.019 kilograms per year or about 0.56 kilograms over the 30-year range."
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"res1.slope * 30"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are the results for Fake dataset #2. "
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"res2 = linregress(xs2, ys2)\n",
"res2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The estimated slope is almost 10 times higher: about 0.18 kilograms per year or about 5.3 kilograms per 30 years:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"res2.slope * 30"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What's called `rvalue` here is correlation, which confirms what we saw before; the first example has higher correlation, about 0.75 compared to 0.5.\n",
"\n",
"But the strength of the effect, as measured by the slope of the line, is about 10 times higher in the second example.\n",
"\n",
"We can use the results from `linregress()` to compute the line of best fit: first we get the min and max of the observed xs; then we multiply by the slope and add the intercept. And plot the line.\n",
"\n",
"Here's what that looks like for the first example. "
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(xs1, ys1, 'o', alpha=0.5)\n",
"\n",
"fx = np.array([xs1.min(), xs1.max()])\n",
"fy = res1.intercept + res1.slope * fx\n",
"plt.plot(fx, fy, '-')\n",
"\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Fake dataset #1');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And the same thing for the second example. "
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(xs2, ys2, 'o', alpha=0.5)\n",
"\n",
"fx = np.array([xs2.min(), xs2.max()])\n",
"fy = res2.intercept + res2.slope * fx\n",
"plt.plot(fx, fy, '-')\n",
"\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Fake dataset #2');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The visualization here might be misleading unless you look closely at the vertical scales; the slope in the second figure is almost 10 times higher."
]
},
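{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see why correlation and slope can disagree: in simple regression, the slope equals the correlation multiplied by the ratio of the standard deviations, `std(ys) / std(xs)`. The following is a minimal sketch that checks this relationship and confirms that `rvalue` matches the result from `np.corrcoef()`. The arrays `xs` and `ys` are made up for illustration; they are not the fake datasets above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.stats import linregress\n",
"\n",
"# Hypothetical data, generated only for illustration\n",
"rng = np.random.default_rng(17)\n",
"xs = rng.uniform(30, 60, size=500)\n",
"ys = 70 + 0.1 * xs + rng.normal(0, 2, size=500)\n",
"\n",
"res = linregress(xs, ys)\n",
"\n",
"# rvalue is the Pearson correlation of xs and ys\n",
"r = np.corrcoef(xs, ys)[0, 1]\n",
"\n",
"# The slope is the correlation scaled by the ratio of standard deviations\n",
"slope_from_r = r * np.std(ys) / np.std(xs)\n",
"\n",
"print(res.rvalue, r)\n",
"print(res.slope, slope_from_r)\n",
"```\n",
"\n",
"With the correlation held fixed, a larger spread in `ys` relative to `xs` gives a steeper slope, which is why the two statistics answer different questions."
]
},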
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Height and weight\n",
"\n",
"Now let's look at an example with real data.\n",
"\n",
"Here's the scatter plot of height and weight one more time."
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(height_jitter, weight_jitter, 'o', alpha=0.02, markersize=1)\n",
"\n",
"plt.xlim([140, 200])\n",
"plt.ylim([0, 160])\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can compute the regression line. `linregress()` can't handle NaNs, so we have to use `dropna()` to remove rows that are missing the data we need."
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"subset = brfss.dropna(subset=['WTKG3', 'HTM4'])\n",
"\n",
"height_clean = subset['HTM4']\n",
"weight_clean = subset['WTKG3']"
]
},
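{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why this cleaning step matters, here is a minimal sketch with a tiny made-up `DataFrame` (the column names `height` and `weight` and all values are hypothetical). Assuming a recent version of SciPy, `linregress()` does not raise an error when it gets a NaN; the estimated slope simply comes back as NaN, so the problem is easy to miss.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import pandas as pd\n",
"from scipy.stats import linregress\n",
"\n",
"# Hypothetical data with one missing weight\n",
"df = pd.DataFrame({'height': [150.0, 160.0, 170.0, 180.0],\n",
"                   'weight': [55.0, np.nan, 70.0, 80.0]})\n",
"\n",
"# With the NaN included, the estimated slope is not usable\n",
"res_nan = linregress(df['height'], df['weight'])\n",
"\n",
"# After dropping incomplete rows, the fit works as expected\n",
"clean = df.dropna(subset=['height', 'weight'])\n",
"res_ok = linregress(clean['height'], clean['weight'])\n",
"\n",
"print(res_nan.slope, res_ok.slope)\n",
"```\n",
"\n",
"Passing `subset` to `dropna()` removes only the rows that are missing one of the columns we need, so rows with NaNs in unrelated columns are kept."
]
},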
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can compute the linear regression."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"res_hw = linregress(height_clean, weight_clean)\n",
"res_hw"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And here are the results. The slope is about 0.9 kilograms per centimeter, which means that we expect a person one centimeter taller to be almost a kilogram heavier. That's quite a lot.\n",
"\n",
"As before, we can compute the line of best fit:"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"fx = np.array([height_clean.min(), height_clean.max()])\n",
"fy = res_hw.intercept + res_hw.slope * fx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And here's what that looks like. "
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(height_jitter, weight_jitter, 'o', alpha=0.02, markersize=1)\n",
"\n",
"plt.plot(fx, fy, '-')\n",
"\n",
"plt.xlim([140, 200])\n",
"plt.ylim([0, 160])\n",
"plt.xlabel('Height in cm')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Scatter plot of weight versus height');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The slope of this line seems consistent with the scatter plot.\n",
"\n",
"However, linear regression has the same problem as correlation; it only measures the strength of a linear relationship.\n",
"\n",
"Here's the scatter plot of weight versus age, which you saw in a previous exercise."
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"# Add jittering to age\n",
"age_jitter = age + np.random.normal(0, 2.5, size=len(brfss))\n",
"\n",
"# Make a scatter plot\n",
"plt.plot(age_jitter, weight, 'o', alpha=0.01, markersize=1)\n",
"\n",
"plt.ylim([0, 160])\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Weight versus age');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"People in their 40s are the heaviest; younger and older people are lighter. So the relationship is nonlinear.\n",
"\n",
"If we don't look at the scatter plot and blindly compute the regression line, here's what we get."
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"subset = brfss.dropna(subset=['WTKG3', 'AGE'])\n",
"age_clean = subset['AGE']\n",
"weight_clean = subset['WTKG3']\n",
"\n",
"res_aw = linregress(age_clean, weight_clean)\n",
"res_aw"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The estimated slope is only 0.02 kilograms per year, or 0.6 kilograms in 30 years.\n",
"\n",
"And here's what the line of best fit looks like."
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"# Make a scatter plot\n",
"plt.plot(age_jitter, weight, 'o', alpha=0.01, markersize=1)\n",
"\n",
"# Plot the line of best fit\n",
"fx = np.array([age_clean.min(), age_clean.max()])\n",
"fy = res_aw.intercept + res_aw.slope * fx\n",
"plt.plot(fx, fy, '-')\n",
"\n",
"plt.ylim([0, 160])\n",
"plt.xlabel('Age in years')\n",
"plt.ylabel('Weight in kg')\n",
"plt.title('Weight versus age');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A straight line does not capture the relationship between these variables well.\n",
"\n",
"In the next notebook, you'll see how to use multiple regression to estimate non-linear relationships. But first, let's practice simple regression."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Who do you think eats more vegetables, people with low income, or people with high income? Let's find out.\n",
"\n",
"As we've seen previously, the column `INCOME2` represents income level and `_VEGESU1` represents the number of vegetable servings respondents reported eating per day.\n",
"\n",
"Make a scatter plot with vegetable servings versus income, that is, with vegetable servings on the `y` axis and income group on the `x` axis.\n",
"\n",
"You might want to use `ylim` to zoom in on the bottom half of the `y` axis."
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Now let's estimate the slope of the relationship between vegetable consumption and income.\n",
"\n",
"- Use `dropna` to select rows where `INCOME2` and `_VEGESU1` are not NaN.\n",
"\n",
"- Extract `INCOME2` and `_VEGESU1` and compute the simple linear regression of these variables.\n",
"\n",
"What is the slope of the regression\n",
"line? Write a sentence that explains what this slope means in the context of the question we are exploring."
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"# The estimated slope is 0.07, which means that\n",
"# people in higher income groups eat slightly more vegetables\n",
"# on average. Between the lowest and the highest income group\n",
"# the difference is about half a vegetable serving per day."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise:** Finally, plot the regression line on top of the scatter plot.\n"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"# Solution goes here"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}