{ "cells": [ { "cell_type": "markdown", "id": "7c4c44c8", "metadata": {}, "source": [ "## MOSAIKS feature extraction\n", "\n", "This tutorial demonstrates the **MOSAIKS** method for extracting _feature vectors_ from satellite imagery patches for use in downstream modeling tasks. It will show:\n", "- How to extract 1km$^2$ patches of Sentinel 2 multispectral imagery for a list of latitude, longitude points\n", "- How to extract summary features from each of these imagery patches\n", "- How to use the summary features in a linear model of the population density at each point\n", "\n", "### Background\n", "\n", "Consider the case where you have a dataset of latitude and longitude points associated with some dependent variable (for example: population density, weather, housing prices, biodiversity) and, potentially, other independent variables. You would like to model the dependent variable as a function of the independent variables, but instead of including latitude and longitude directly in this model, you would like to include some high dimensional representation of what the Earth looks like at that point (that hopefully explains some of the variance in the dependent variable!). From the computer vision literature, there are various [representation learning techniques](https://en.wikipedia.org/wiki/Feature_learning) that can be used to do this, i.e. extract _features vectors_ from imagery. This notebook gives an implementation of the technique described in [Rolf et al. 2021](https://www.nature.com/articles/s41467-021-24638-z), \"A generalizable and accessible approach to machine learning with global satellite imagery\" called Multi-task Observation using Satellite Imagery & Kitchen Sinks (**MOSAIKS**). For more information about **MOSAIKS** see the [project's webpage](http://www.globalpolicy.science/mosaiks).\n", "\n", "\n", "**Notes**:\n", "- This example uses [Sentinel-2 Level-2A data](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a). The techniques used here apply equally well to other remote-sensing datasets.\n", "- If you're running this on the [Planetary Computer Hub](http://planetarycomputer.microsoft.com/compute), make sure to choose the **GPU - PyTorch** profile when presented with the form to choose your environment." 
] }, { "cell_type": "code", "execution_count": 1, "id": "ab74cb2f", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "import time\n", "import os\n", "\n", "RASTERIO_BEST_PRACTICES = dict( # See https://github.com/pangeo-data/cog-best-practices\n", " CURL_CA_BUNDLE=\"/etc/ssl/certs/ca-certificates.crt\",\n", " GDAL_DISABLE_READDIR_ON_OPEN=\"EMPTY_DIR\",\n", " AWS_NO_SIGN_REQUEST=\"YES\",\n", " GDAL_MAX_RAW_BLOCK_CACHE_SIZE=\"200000000\",\n", " GDAL_SWATH_SIZE=\"200000000\",\n", " VSI_CURL_CACHE_SIZE=\"200000000\",\n", ")\n", "os.environ.update(RASTERIO_BEST_PRACTICES)\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "import rasterio\n", "import rasterio.warp\n", "import rasterio.mask\n", "import shapely.geometry\n", "import geopandas\n", "import dask_geopandas\n", "from sklearn.linear_model import RidgeCV\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import r2_score\n", "from scipy.stats import spearmanr\n", "from scipy.linalg import LinAlgWarning\n", "from dask.distributed import Client\n", "\n", "\n", "warnings.filterwarnings(action=\"ignore\", category=LinAlgWarning, module=\"sklearn\")\n", "\n", "import pystac_client\n", "import planetary_computer as pc" ] }, { "cell_type": "markdown", "id": "e372135a", "metadata": {}, "source": [ "First we define the pytorch model that we will use to extract the features and a helper method. The **MOSAIKS** methodology describes several ways to do this and we use the simplest." ] }, { "cell_type": "code", "execution_count": 2, "id": "e13c154a", "metadata": {}, "outputs": [], "source": [ "def featurize(input_img, model, device):\n", " \"\"\"Helper method for running an image patch through the model.\n", "\n", " Args:\n", " input_img (np.ndarray): Image in (C x H x W) format with a dtype of uint8.\n", " model (torch.nn.Module): Feature extractor network\n", " \"\"\"\n", " assert len(input_img.shape) == 3\n", " input_img = torch.from_numpy(input_img / 255.0).float()\n", " input_img = input_img.to(device)\n", " with torch.no_grad():\n", " feats = model(input_img.unsqueeze(0)).cpu().numpy()\n", " return feats\n", "\n", "\n", "class RCF(nn.Module):\n", " \"\"\"A model for extracting Random Convolution Features (RCF) from input imagery.\"\"\"\n", "\n", " def __init__(self, num_features=16, kernel_size=3, num_input_channels=3):\n", " super(RCF, self).__init__()\n", "\n", " # We create `num_features / 2` filters so require `num_features` to be divisible by 2\n", " assert num_features % 2 == 0\n", "\n", " self.conv1 = nn.Conv2d(\n", " num_input_channels,\n", " num_features // 2,\n", " kernel_size=kernel_size,\n", " stride=1,\n", " padding=0,\n", " dilation=1,\n", " bias=True,\n", " )\n", "\n", " nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)\n", " nn.init.constant_(self.conv1.bias, -1.0)\n", "\n", " def forward(self, x):\n", " x1a = F.relu(self.conv1(x), inplace=True)\n", " x1b = F.relu(-self.conv1(x), inplace=True)\n", "\n", " x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()\n", " x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()\n", "\n", " if len(x1a.shape) == 1: # case where we passed a single input\n", " return torch.cat((x1a, x1b), dim=0)\n", " elif len(x1a.shape) == 2: # case where we passed a batch of > 1 inputs\n", " return torch.cat((x1a, x1b), dim=1)" ] }, { "cell_type": "markdown", "id": 
"42f303c0", "metadata": {}, "source": [ "Next, we initialize the model and pytorch components" ] }, { "cell_type": "code", "execution_count": 3, "id": "60b896be", "metadata": {}, "outputs": [], "source": [ "num_features = 1024\n", "\n", "device = torch.device(\"cuda\")\n", "model = RCF(num_features).eval().to(device)" ] }, { "cell_type": "markdown", "id": "6c804888", "metadata": { "tags": [] }, "source": [ "### Read dataset of (lat, lon) points and corresponding labels" ] }, { "cell_type": "markdown", "id": "6462a61e", "metadata": {}, "source": [ "We read a CSV of 100,000 randomly sampled (lat, lon) points over the US and the corresponding population living roughly within 1km$^2$ of the points from the [Gridded Population of the World](https://sedac.ciesin.columbia.edu/downloads/data/gpw-v4/gpw-v4-population-density-rev10/gpw-v4-population-density-rev10_2015_30_sec_tif.zip) dataset. This data comes from the [Code Ocean capsule](https://codeocean.com/capsule/6456296/tree/v2) that accompanies the Rolf et al. 2021 paper." ] }, { "cell_type": "code", "execution_count": 4, "id": "92ef8030", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | V1 | \n", "V1.1 | \n", "ID | \n", "lon | \n", "lat | \n", "population | \n", "geometry | \n", "
---|---|---|---|---|---|---|---|
1 | \n", "1 | \n", "1 | \n", "1225,1595 | \n", "-103.046735 | \n", "37.932314 | \n", "0.085924 | \n", "POINT (-103.04673 37.93231) | \n", "
2 | \n", "2 | \n", "2 | \n", "1521,2455 | \n", "-91.202438 | \n", "34.647579 | \n", "0.808222 | \n", "POINT (-91.20244 34.64758) | \n", "
4 | \n", "4 | \n", "4 | \n", "828,3849 | \n", "-72.003660 | \n", "42.116711 | \n", "101.286320 | \n", "POINT (-72.00366 42.11671) | \n", "
5 | \n", "5 | \n", "5 | \n", "1530,2831 | \n", "-86.024002 | \n", "34.545552 | \n", "28.181724 | \n", "POINT (-86.02400 34.54555) | \n", "
6 | \n", "6 | \n", "6 | \n", "1097,2696 | \n", "-87.883281 | \n", "39.309455 | \n", "11.923701 | \n", "POINT (-87.88328 39.30945) | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
99995 | \n", "99995 | \n", "99995 | \n", "1164,1275 | \n", "-107.453915 | \n", "38.591906 | \n", "0.006152 | \n", "POINT (-107.45391 38.59191) | \n", "
99996 | \n", "99996 | \n", "99996 | \n", "343,2466 | \n", "-91.050941 | \n", "46.876806 | \n", "1.782686 | \n", "POINT (-91.05094 46.87681) | \n", "
99997 | \n", "99997 | \n", "99997 | \n", "634,3922 | \n", "-70.998272 | \n", "44.067430 | \n", "8.383357 | \n", "POINT (-70.99827 44.06743) | \n", "
99998 | \n", "99998 | \n", "99998 | \n", "357,2280 | \n", "-93.612615 | \n", "46.744854 | \n", "8.110552 | \n", "POINT (-93.61261 46.74485) | \n", "
100000 | \n", "100000 | \n", "100000 | \n", "1449,2691 | \n", "-87.952143 | \n", "35.459263 | \n", "1.972914 | \n", "POINT (-87.95214 35.45926) | \n", "
67968 rows × 7 columns
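Each point is then matched to a Sentinel-2 Level-2A scene from the Planetary Computer STAC catalog; the next preview shows the matched points, with the chosen scene stored in a `stac_item` column. Below is a minimal sketch of finding a low-cloud item for a single point with `pystac_client`. The 2019 date range and the 10% cloud-cover threshold are illustrative assumptions rather than the exact values behind that preview, and reading the item's assets afterwards also requires signing them, for example with `pc.sign`.

```python
# Sketch: find a low-cloud Sentinel-2 L2A item covering one point via the
# Planetary Computer STAC API. The date range and cloud-cover limit are assumptions.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1"
)
point = shapely.geometry.Point(-97.30, 25.96)  # approximately the first point in the preview below
search = catalog.search(
    collections=["sentinel-2-l2a"],
    intersects=shapely.geometry.mapping(point),
    datetime="2019-01-01/2019-12-31",
    query={"eo:cloud_cover": {"lt": 10}},
)
items = list(search.get_items())
least_cloudy_item = min(items, key=lambda item: item.properties["eo:cloud_cover"])
print(least_cloudy_item.id)
```

The imports above include `dask_geopandas` and a Dask `Client`, which suggests the matching is done in bulk over all 67,968 points rather than one point at a time, but the per-point logic is the same.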
\n", "\n", " | V1 | \n", "V1.1 | \n", "ID | \n", "lon | \n", "lat | \n", "population | \n", "geometry | \n", "hd | \n", "stac_item | \n", "
---|---|---|---|---|---|---|---|---|---|
58227 | \n", "58227 | \n", "58227 | \n", "2253,2012 | \n", "-97.303628 | \n", "25.958163 | \n", "0.193628 | \n", "POINT (-97.30363 25.95816) | \n", "360793917 | \n", "{'id': 'S2A_MSIL2A_20191213T170711_R069_T14RPP... | \n", "
27993 | \n", "27993 | \n", "27993 | \n", "2245,1999 | \n", "-97.482670 | \n", "26.057180 | \n", "244.428020 | \n", "POINT (-97.48267 26.05718) | \n", "360931126 | \n", "{'id': 'S2A_MSIL2A_20191213T170711_R069_T14RPP... | \n", "
61808 | \n", "61808 | \n", "61808 | \n", "2246,1986 | \n", "-97.661712 | \n", "26.044807 | \n", "525.789124 | \n", "POINT (-97.66171 26.04481) | \n", "360950765 | \n", "{'id': 'S2A_MSIL2A_20191213T170711_R069_T14RPP... | \n", "
97710 | \n", "97710 | \n", "97710 | \n", "2240,1992 | \n", "-97.579077 | \n", "26.119023 | \n", "70.362312 | \n", "POINT (-97.57908 26.11902) | \n", "360966149 | \n", "{'id': 'S2A_MSIL2A_20191213T170711_R069_T14RPP... | \n", "
55442 | \n", "55442 | \n", "55442 | \n", "2231,1994 | \n", "-97.551532 | \n", "26.230258 | \n", "44.150997 | \n", "POINT (-97.55153 26.23026) | \n", "361043545 | \n", "{'id': 'S2B_MSIL2A_20181213T170709_R069_T14RPQ... | \n", "
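With a matched item in hand, a roughly 1 km patch around each point can be read straight from the item's cloud-optimized GeoTIFF and passed through the featurizer. The sketch below continues from the single-point example above; the 0.005 degree buffer and the use of the RGB `visual` asset are assumptions made for illustration.

```python
# Sketch: featurize a ~1 km patch around the point matched above.
# The buffer size and the "visual" (RGB) asset are illustrative assumptions.
signed_item = pc.sign(least_cloudy_item)
patch_geom = shapely.geometry.mapping(point.buffer(0.005).envelope)

with rasterio.open(signed_item.assets["visual"].href) as src:
    patch_geom_proj = rasterio.warp.transform_geom("EPSG:4326", src.crs, patch_geom)
    patch, _ = rasterio.mask.mask(src, [patch_geom_proj], crop=True)

patch = patch[:3].astype(np.uint8)  # (C, H, W) RGB patch
feature_vector = featurize(patch, model, device)  # length num_features
```

Repeating this for every row (for example with the `Dataset` and `DataLoader` classes imported above) yields a feature matrix with one row of length `num_features` per point, which is what the linear model below is trained on.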
Fitted model:

RidgeCV(alphas=array([1.e-08, 1.e-07, 1.e-06, 1.e-05, 1.e-04, 1.e-03, 1.e-02, 1.e-01,
                      1.e+00, 1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06, 1.e+07,
                      1.e+08]),
        cv=5)
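The alpha grid above corresponds to `np.logspace(-8, 8, 17)` with 5-fold cross-validation. A minimal sketch of producing such a fit and scoring it on held-out points follows; `features_all` and `population_all` stand for the feature matrix and label vector assumed to have been built by featurizing every matched point, and the log transform of the heavy-tailed population label, the 80/20 split, and the random seed are illustrative choices.

```python
# Sketch: fit RidgeCV on the extracted features and evaluate on held-out points.
# features_all (n_points x num_features) and population_all are assumed to exist already.
x_train, x_test, y_train, y_test = train_test_split(
    features_all, np.log10(population_all + 1), test_size=0.2, random_state=0
)

ridge = RidgeCV(alphas=np.logspace(-8, 8, 17), cv=5)
ridge.fit(x_train, y_train)

y_pred = ridge.predict(x_test)
print("R2:", r2_score(y_test, y_pred))
print("Spearman's rho:", spearmanr(y_test, y_pred).correlation)
```

Test-set R2 and Spearman's rank correlation give complementary views of fit quality: the former is sensitive to large errors on the handful of very dense points, while the latter depends only on how well the model orders the points.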